#ifndef DIAGONALMATRIX_H #define DIAGONALMATRIX_H #include "MatrixDef.h" namespace ATC_matrix { /** * @class DiagonalMatrix * @brief Class for storing data as a diagonal matrix */ template class DiagonalMatrix : public Matrix { public: explicit DiagonalMatrix(INDEX nRows=0, bool zero=0); DiagonalMatrix(const DiagonalMatrix& c); DiagonalMatrix(const Vector& v); virtual ~DiagonalMatrix(); //* resizes the matrix, ignores nCols, optionally zeros void reset(INDEX rows, INDEX cols=0, bool zero=true); //* resizes the matrix, ignores nCols, optionally copies what fits void resize(INDEX rows, INDEX cols=0, bool copy=false); //* resets based on full copy of vector v void reset(const Vector& v); //* resets based on full copy of a DiagonalMatrix void reset(const DiagonalMatrix& v); //* resets based on a one column DenseMatrix void reset(const DenseMatrix& c); //* resizes the matrix, ignores nCols, optionally copies what fits void copy(const T * ptr, INDEX rows, INDEX cols=0); //* resets based on a "shallow" copy of a vector void shallowreset(const Vector &v); //* resets based on a "shallow" copy of a DiagonalMatrix void shallowreset(const DiagonalMatrix &c); //* resets based on a "shallow" copy of one column DenseMatrix void shallowreset(const DenseMatrix &c); T& operator()(INDEX i, INDEX j); T operator()(INDEX i, INDEX j) const; T& operator[](INDEX i); T operator[](INDEX i) const; INDEX nRows() const; INDEX nCols() const; T* ptr() const; void write_restart(FILE *f) const; // Dump matrix contents to screen (not defined for all datatypes) std::string to_string(int p=myPrecision) const { return _data->to_string(); } using Matrix::matlab; void matlab(std::ostream &o, const std::string &s="D") const; // overloaded operators DiagonalMatrix& operator=(const T s); DiagonalMatrix& operator=(const DiagonalMatrix &C); //DiagonalMatrix& operator=(const Vector &C); INDEX size() const { return _data->size(); } //* computes the inverse of this matrix DiagonalMatrix& inv_this(); //* returns a copy of the inverse of this matrix DiagonalMatrix inv() const; // DiagonalMatrix-matrix multiplication function virtual void MultAB(const Matrix& B, DenseMatrix& C) const { GCK(*this, B, this->nCols()!=B.nRows(), "DiagonalMatrix-Matrix multiplication"); for (INDEX i=0; i &r); DiagonalMatrix& operator=(const Vector &c) {} DiagonalMatrix& operator=(const Matrix &c) {} private: void _delete(); Vector *_data; }; //----------------------------------------------------------------------------- // DiagonalMatrix-DiagonalMatrix multiplication //----------------------------------------------------------------------------- template DiagonalMatrix operator*(const DiagonalMatrix& A, const DiagonalMatrix& B) { SSCK(A, B, "DiagonalMatrix-DiagonalMatrix multiplication"); DiagonalMatrix R(A); for (INDEX i=0; i DenseMatrix operator*(const DiagonalMatrix& A, const Matrix &B) { DenseMatrix C(A.nRows(), B.nCols(), true); A.MultAB(B, C); return C; } //----------------------------------------------------------------------------- // matrix-DiagonalMatrix multiplication //----------------------------------------------------------------------------- template DenseMatrix operator*(const Matrix &B, const DiagonalMatrix& A) { GCK(B, A, B.nCols()!=A.nRows(), "Matrix-DiagonalMatrix multiplication"); DenseMatrix R(B); // makes a copy of r to return for (INDEX j=0; j DenseVector operator*(const DiagonalMatrix& A, const Vector &b) { GCK(A, b, A.nCols()!=b.size(), "DiagonalMatrix-Vector multiplication"); DenseVector r(b); // makes a copy of r to return for (INDEX i=0; i DenseVector operator*(const Vector &b, const DiagonalMatrix& A) { GCK(b, A, b.size()!=A.nRows(), "Matrix-DiagonalMatrix multiplication"); DenseVector r(b); // makes a copy of r to return for (INDEX i=0; i SparseMatrix operator*(const DiagonalMatrix &A, const SparseMatrix& B) { GCK(A, B, A.nCols()!=B.nRows() ,"DiagonalMatrix-SparseMatrix multiplication"); SparseMatrix R(B); CloneVector d(A); R.row_scale(d); return R; } //----------------------------------------------------------------------------- // DiagonalMatrix-scalar multiplication //----------------------------------------------------------------------------- template DiagonalMatrix operator*(DiagonalMatrix &A, const T s) { DiagonalMatrix R(A); R *= s; return R; } //----------------------------------------------------------------------------- // Commute with DiagonalMatrix * double //----------------------------------------------------------------------------- template DiagonalMatrix operator*(const T s, const DiagonalMatrix& A) { DiagonalMatrix R(A); R *= s; return R; } //----------------------------------------------------------------------------- // DiagonalMatrix addition //----------------------------------------------------------------------------- template DiagonalMatrix operator+(const DiagonalMatrix &A, const DiagonalMatrix &B) { DiagonalMatrix R(A); R+=B; return R; } //----------------------------------------------------------------------------- // DiagonalMatrix subtraction //----------------------------------------------------------------------------- template DiagonalMatrix operator-(const DiagonalMatrix &A, const DiagonalMatrix &B) { DiagonalMatrix R(A); return R-=B; } //----------------------------------------------------------------------------- // template member definitions //----------------------------------------------------------------------------- //----------------------------------------------------------------------------- // Default constructor - optionally zeros the matrix //----------------------------------------------------------------------------- template DiagonalMatrix::DiagonalMatrix(INDEX rows, bool zero) : _data(NULL) { reset(rows, zero); } //----------------------------------------------------------------------------- // copy constructor - makes a full copy //----------------------------------------------------------------------------- template DiagonalMatrix::DiagonalMatrix(const DiagonalMatrix& c) : Matrix(), _data(NULL) { reset(c); } //----------------------------------------------------------------------------- // copy constructor from vector //----------------------------------------------------------------------------- template DiagonalMatrix::DiagonalMatrix(const Vector& v) : Matrix(), _data(NULL) { reset(v); } //----------------------------------------------------------------------------- // destructor //----------------------------------------------------------------------------- template DiagonalMatrix::~DiagonalMatrix() { _delete(); } //----------------------------------------------------------------------------- // deletes the data stored by this matrix //----------------------------------------------------------------------------- template void DiagonalMatrix::_delete() { if (_data) delete _data; } //----------------------------------------------------------------------------- // resizes the matrix, ignores nCols, optionally zeros //----------------------------------------------------------------------------- template void DiagonalMatrix::reset(INDEX rows, INDEX cols, bool zero) { _delete(); _data = new DenseVector(rows, zero); } //----------------------------------------------------------------------------- // resizes the matrix, ignores nCols, optionally copies what fits //----------------------------------------------------------------------------- template void DiagonalMatrix::resize(INDEX rows, INDEX cols, bool copy) { _data->resize(rows, copy); } //----------------------------------------------------------------------------- // changes the diagonal of the matrix to a vector v (makes a copy) //----------------------------------------------------------------------------- template void DiagonalMatrix::reset(const Vector& v) { if (&v == _data) return; // check for self-reset _delete(); _data = new DenseVector(v); } //----------------------------------------------------------------------------- // copys from another DiagonalMatrix //----------------------------------------------------------------------------- template void DiagonalMatrix::reset(const DiagonalMatrix& c) { reset(*(c._data)); } //----------------------------------------------------------------------------- // copys from a single column matrix //----------------------------------------------------------------------------- template void DiagonalMatrix::reset(const DenseMatrix& c) { GCHK(c.nCols()!=1,"DiagonalMatrix reset from DenseMatrix"); copy(c.ptr(),c.nRows(),c.nRows()); } //----------------------------------------------------------------------------- // resizes the matrix and copies data //----------------------------------------------------------------------------- template void DiagonalMatrix::copy(const T * ptr, INDEX rows, INDEX cols) { if (_data) _data->reset(rows, false); else _data = new DenseVector(rows, false); _data->copy(ptr,rows,cols); } //----------------------------------------------------------------------------- // shallow reset from another DiagonalMatrix //----------------------------------------------------------------------------- template void DiagonalMatrix::shallowreset(const DiagonalMatrix &c) { _delete(); _data = new CloneVector(*(c._data)); } //----------------------------------------------------------------------------- // shallow reset from Vector //----------------------------------------------------------------------------- template void DiagonalMatrix::shallowreset(const Vector &v) { _delete(); _data = new CloneVector(v); } //----------------------------------------------------------------------------- // shallow reset from a DenseMatrix //----------------------------------------------------------------------------- template void DiagonalMatrix::shallowreset(const DenseMatrix &c) { _delete(); _data = new CloneVector(c,CLONE_COL); } //----------------------------------------------------------------------------- // reference indexing operator - must throw an error if i!=j //----------------------------------------------------------------------------- template T& DiagonalMatrix::operator()(INDEX i, INDEX j) { GCK(*this,*this,i!=j,"DiagonalMatrix: tried to index off diagonal"); return (*this)[i]; } //----------------------------------------------------------------------------- // value indexing operator - returns 0 if i!=j //----------------------------------------------------------------------------- template T DiagonalMatrix::operator()(INDEX i, INDEX j) const { return (i==j) ? (*_data)(i) : (T)0; } //----------------------------------------------------------------------------- // flat reference indexing operator //----------------------------------------------------------------------------- template T& DiagonalMatrix::operator[](INDEX i) { return (*_data)(i); } //----------------------------------------------------------------------------- // flat value indexing operator //----------------------------------------------------------------------------- template T DiagonalMatrix::operator[](INDEX i) const { return (*_data)(i); } //----------------------------------------------------------------------------- // returns the number of rows //----------------------------------------------------------------------------- template INDEX DiagonalMatrix::nRows() const { return _data->size(); } //----------------------------------------------------------------------------- // returns the number of columns (same as nCols()) //----------------------------------------------------------------------------- template INDEX DiagonalMatrix::nCols() const { return _data->size(); } //----------------------------------------------------------------------------- // returns a pointer to the diagonal values, dangerous! //----------------------------------------------------------------------------- template T* DiagonalMatrix::ptr() const { return _data->ptr(); } //----------------------------------------------------------------------------- // writes the diagonal to a binary data restart file //----------------------------------------------------------------------------- template void DiagonalMatrix::write_restart(FILE *f) const { _data->write_restart(f); } //----------------------------------------------------------------------------- // sets the diagonal to a constant //----------------------------------------------------------------------------- template DiagonalMatrix& DiagonalMatrix::operator=(const T v) { this->set_all_elements_to(v); return *this; } //----------------------------------------------------------------------------- // assignment operator with another diagonal matrix //----------------------------------------------------------------------------- template DiagonalMatrix& DiagonalMatrix::operator=(const DiagonalMatrix& C) { reset(C); return *this; } //----------------------------------------------------------------------------- // writes a matlab command to duplicate this sparse matrix //----------------------------------------------------------------------------- template void DiagonalMatrix::matlab(std::ostream &o, const std::string &s) const { _data->matlab(o, s); o << s <<"=diag("< DiagonalMatrix& DiagonalMatrix::inv_this() { for(INDEX i=0; iminabs() / _data->maxabs(); if (min_max > 1e-14) return *this; std::cout << "DiagonalMatrix::inv_this(): Warning: Matrix is badly scaled."; std::cout << " RCOND = "< DiagonalMatrix DiagonalMatrix::inv() const { DiagonalMatrix invA(*this); // Make copy of A to invert for(INDEX i=0; iminabs() / _data->maxabs(); if (min_max > 1e-14) return invA; std::cout << "DiagonalMatrix::inv(): Warning: Matrix is badly scaled."; std::cout << " RCOND = "< inv(const DiagonalMatrix& A) { return A.inv(); } //----------------------------------------------------------------------------- // general diagonalmatrix assigment //----------------------------------------------------------------------------- template void DiagonalMatrix::_set_equal(const Matrix &r) { this->resize(r.nRows(), r.nCols()); const Matrix *pr = &r; const SparseMatrix *ps = dynamic_cast*> (pr); const DiagonalMatrix *pd = dynamic_cast*> (pr); const Vector *pv = dynamic_cast*> (pr); if (ps) this->reset(ps->diag()); else if (pd) this->reset(*pd); else if (pv) this->reset(*pv); else { std::cout <<"Error in general diagonal matrix assignment\n"; exit(1); } } //----------------------------------------------------------------------------- // casts a generic matrix pointer into a DiagonalMatrix pointer - null if fail //----------------------------------------------------------------------------- template const DiagonalMatrix *diag_cast(const Matrix *m) { return dynamic_cast*>(m); } } // end namespace #endif