diff --git a/sparse_matrix.cc b/sparse_matrix.cc index 5972be66c09f087edf446591814dac16b6b9fef0..4e6ca0de547ae9813645b303a95d364b7e6d895d 100644 --- a/sparse_matrix.cc +++ b/sparse_matrix.cc @@ -6,7 +6,6 @@ /* CCOPYRIGHT */ - #include <iostream> #include <iomanip> #include <strstream> @@ -222,6 +221,30 @@ namespace MISCMATHS { } } + void multiply(const DiagonalMatrix& lm, const SparseMatrix& rm, SparseMatrix& ret) + { + Tracer_Plus tr("SparseMatrix::multiply"); + + int nrows = lm.Nrows(); + int ncols = rm.Ncols(); + + if(lm.Ncols() != rm.Nrows()) throw Exception("Rows and cols don't match in SparseMatrix::multiply"); + + ret.ReSize(nrows,ncols); + + for(int j = 1; j<=nrows; j++) + { + const SparseMatrix::Row& row = rm.row(j); + for(SparseMatrix::Row::const_iterator it=row.begin();it!=row.end();it++) + { + int c = (*it).first+1; + double val = (*it).second; + ret.insert(j,c,val*lm(j,j)); + } + } + + } + void multiply(const SparseMatrix& lm, const SparseMatrix::Row& rm, ColumnVector& ret) { Tracer_Plus tr("SparseMatrix::multiply3"); diff --git a/sparse_matrix.h b/sparse_matrix.h index c793bbe41f416998feb427159e0d31f4c5c14f5d..491ef3f1cebe5bd0652b67789c30f4c33fc15cfc 100644 --- a/sparse_matrix.h +++ b/sparse_matrix.h @@ -128,6 +128,7 @@ namespace MISCMATHS { }; void multiply(const SparseMatrix& lm, const SparseMatrix& rm, SparseMatrix& ret); + void multiply(const DiagonalMatrix& lm, const SparseMatrix& rm, SparseMatrix& ret); void multiply(const SparseMatrix& lm, const ColumnVector& rm, ColumnVector& ret);