00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #ifndef EIGEN_SOLVETRIANGULAR_H
00026 #define EIGEN_SOLVETRIANGULAR_H
00027
00028 namespace internal {
00029
00030
00031
// Forward declaration of the single right-hand-side-vector triangular solve
// kernel. It is used by triangular_solver_selector<...,NoUnrolling,1> below;
// its definition is not part of this section of the file.
template<typename LhsScalar, typename RhsScalar, typename Index, int Side, int Mode, bool Conjugate, int StorageOrder>
struct triangular_solve_vector;

// Forward declaration of the multiple right-hand-side (matrix) triangular
// solve kernel. It is used by triangular_solver_selector<...,NoUnrolling,Dynamic>
// below; its definition is not part of this section of the file.
template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder>
struct triangular_solve_matrix;
00037
00038
00039 template<typename Lhs, typename Rhs, int Side>
00040 class trsolve_traits
00041 {
00042 private:
00043 enum {
00044 RhsIsVectorAtCompileTime = (Side==OnTheLeft ? Rhs::ColsAtCompileTime : Rhs::RowsAtCompileTime)==1
00045 };
00046 public:
00047 enum {
00048 Unrolling = (RhsIsVectorAtCompileTime && Rhs::SizeAtCompileTime != Dynamic && Rhs::SizeAtCompileTime <= 8)
00049 ? CompleteUnrolling : NoUnrolling,
00050 RhsVectors = RhsIsVectorAtCompileTime ? 1 : Dynamic
00051 };
00052 };
00053
// Dispatches a triangular solve to an implementation chosen at compile time.
// The Unrolling and RhsVectors arguments default to the values computed by
// trsolve_traits. The primary template is only declared here; the partial
// specializations below cover (NoUnrolling,1), (NoUnrolling,Dynamic), and
// (CompleteUnrolling,1) for OnTheLeft and OnTheRight.
template<typename Lhs, typename Rhs,
  int Side, // OnTheLeft or OnTheRight
  int Mode, // triangular-part flags (tested against Upper, Lower, UnitDiag below)
  int Unrolling = trsolve_traits<Lhs,Rhs,Side>::Unrolling,
  int RhsVectors = trsolve_traits<Lhs,Rhs,Side>::RhsVectors
  >
struct triangular_solver_selector;
00061
00062 template<typename Lhs, typename Rhs, int Side, int Mode>
00063 struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,1>
00064 {
00065 typedef typename Lhs::Scalar LhsScalar;
00066 typedef typename Rhs::Scalar RhsScalar;
00067 typedef blas_traits<Lhs> LhsProductTraits;
00068 typedef typename LhsProductTraits::ExtractType ActualLhsType;
00069 typedef Map<Matrix<RhsScalar,Dynamic,1>, Aligned> MappedRhs;
00070 static void run(const Lhs& lhs, Rhs& rhs)
00071 {
00072 ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
00073
00074
00075
00076 bool useRhsDirectly = Rhs::InnerStrideAtCompileTime==1 || rhs.innerStride()==1;
00077 RhsScalar* actualRhs;
00078 if(useRhsDirectly)
00079 {
00080 actualRhs = &rhs.coeffRef(0);
00081 }
00082 else
00083 {
00084 actualRhs = ei_aligned_stack_new(RhsScalar,rhs.size());
00085 MappedRhs(actualRhs,rhs.size()) = rhs;
00086 }
00087
00088 triangular_solve_vector<LhsScalar, RhsScalar, typename Lhs::Index, Side, Mode, LhsProductTraits::NeedToConjugate,
00089 (int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor>
00090 ::run(actualLhs.cols(), actualLhs.data(), actualLhs.outerStride(), actualRhs);
00091
00092 if(!useRhsDirectly)
00093 {
00094 rhs = MappedRhs(actualRhs, rhs.size());
00095 ei_aligned_stack_delete(RhsScalar, actualRhs, rhs.size());
00096 }
00097 }
00098 };
00099
00100
00101 template<typename Lhs, typename Rhs, int Side, int Mode>
00102 struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,Dynamic>
00103 {
00104 typedef typename Rhs::Scalar Scalar;
00105 typedef typename Rhs::Index Index;
00106 typedef blas_traits<Lhs> LhsProductTraits;
00107 typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType;
00108 static void run(const Lhs& lhs, Rhs& rhs)
00109 {
00110 const ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
00111 triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,(int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor,
00112 (Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor>
00113 ::run(lhs.rows(), Side==OnTheLeft? rhs.cols() : rhs.rows(), &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.outerStride());
00114 }
00115 };
00116
00117
00118
00119
00120
// Compile-time unrolled substitution for a fixed-size vector right-hand side.
// Index is the current step, starting at 0. Stop becomes true once Index
// reaches Size, which selects the empty terminating specialization.
template<typename Lhs, typename Rhs, int Mode, int Index, int Size,
         bool Stop = Index==Size>
struct triangular_solver_unroller;
00124
// One step of the unrolled substitution, solving for a single coefficient of
// rhs in place and then recursing to step Index+1.
template<typename Lhs, typename Rhs, int Mode, int Index, int Size>
struct triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,false> {
  enum {
    IsLower = ((Mode&Lower)==Lower),
    // Row solved at this step: top-down for Lower (forward substitution),
    // bottom-up otherwise (back substitution).
    I = IsLower ? Index : Size - Index - 1,
    // First of the Index coefficients already solved by earlier steps:
    // rows [0, I) for Lower, rows [I+1, Size) otherwise.
    S = IsLower ? 0 : I+1
  };
  static void run(const Lhs& lhs, Rhs& rhs)
  {
    // Remove the contribution of the already-solved unknowns.
    if (Index>0)
      rhs.coeffRef(I) -= lhs.row(I).template segment<Index>(S).transpose()
                         .cwiseProduct(rhs.template segment<Index>(S)).sum();

    // With UnitDiag the diagonal is implicitly 1, so no division is needed.
    if(!(Mode & UnitDiag))
      rhs.coeffRef(I) /= lhs.coeff(I,I);

    triangular_solver_unroller<Lhs,Rhs,Mode,Index+1,Size>::run(lhs,rhs);
  }
};
00144
// Recursion terminator of the unrolled substitution (Index == Size): nothing
// is left to solve.
template<typename Lhs, typename Rhs, int Mode, int Index, int Size>
struct triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,true> {
  static void run(const Lhs&, Rhs&) {}
};
00149
00150 template<typename Lhs, typename Rhs, int Mode>
00151 struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,CompleteUnrolling,1> {
00152 static void run(const Lhs& lhs, Rhs& rhs)
00153 { triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); }
00154 };
00155
00156 template<typename Lhs, typename Rhs, int Mode>
00157 struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,CompleteUnrolling,1> {
00158 static void run(const Lhs& lhs, Rhs& rhs)
00159 {
00160 Transpose<Lhs> trLhs(lhs);
00161 Transpose<Rhs> trRhs(rhs);
00162
00163 triangular_solver_unroller<Transpose<Lhs>,Transpose<Rhs>,
00164 ((Mode&Upper)==Upper ? Lower : Upper) | (Mode&UnitDiag),
00165 0,Rhs::SizeAtCompileTime>::run(trLhs,trRhs);
00166 }
00167 };
00168
00169 }
00170
00171
00172
00173
00174
00175
00176
00177
00178
00179
00180
00181
00182
00183
00184 template<typename MatrixType, unsigned int Mode>
00185 template<int Side, typename OtherDerived>
00186 void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<OtherDerived>& _other) const
00187 {
00188 OtherDerived& other = _other.const_cast_derived();
00189 eigen_assert(cols() == rows());
00190 eigen_assert( (Side==OnTheLeft && cols() == other.rows()) || (Side==OnTheRight && cols() == other.cols()) );
00191 eigen_assert(!(Mode & ZeroDiag));
00192 eigen_assert(Mode & (Upper|Lower));
00193
00194 enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit && OtherDerived::IsVectorAtCompileTime };
00195 typedef typename internal::conditional<copy,
00196 typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy;
00197 OtherCopy otherCopy(other);
00198
00199 internal::triangular_solver_selector<MatrixType, typename internal::remove_reference<OtherCopy>::type,
00200 Side, Mode>::run(nestedExpression(), otherCopy);
00201
00202 if (copy)
00203 other = otherCopy;
00204 }
00205
00206
00207
00208
00209
00210
00211
00212
00213
00214
00215
00216
00217
00218
00219
00220
00221
00222
00223
00224
00225
00226
00227
00228
00229
00230
00231
00232
00233
00234
00235
00236
00237 template<typename Derived, unsigned int Mode>
00238 template<int Side, typename RhsDerived>
00239 typename internal::plain_matrix_type_column_major<RhsDerived>::type
00240 TriangularView<Derived,Mode>::solve(const MatrixBase<RhsDerived>& rhs) const
00241 {
00242 typename internal::plain_matrix_type_column_major<RhsDerived>::type res(rhs);
00243 solveInPlace<Side>(res);
00244 return res;
00245 }
00246
00247 #endif // EIGEN_SOLVETRIANGULAR_H