// Author: Ce Liu (c) Dec, 2009; celiu@mit.edu // Modified By: Deepak Pathak (c) 2016; pathak@berkeley.edu #pragma once #include "stdio.h" #include "Vector.hpp" #include "project.hpp" #ifdef _QT #include #endif #include using namespace std; template class Matrix { private: int nRow,nCol; double* pData; static bool IsDispInfo; public: Matrix(void); Matrix(int _nrow,int _ncol,double* data=NULL); Matrix(const Matrix& matrix); ~Matrix(void); void releaseData(); void copyData(const Matrix& matrix); void allocate(const Matrix& matrix); void allocate(int _nrow,int _ncol); void reset(); bool dimMatch(const Matrix& matrix) const; bool dimcheck(const Matrix& matrix) const; void loadData(int _nrow,int _ncol,T* data); static void enableDispInfo(bool dispInfo=false){IsDispInfo=dispInfo;}; // display the matrix void printMatrix(); void identity(int ndim); // function to access the member variables inline int nrow() const{return nRow;}; inline int ncol() const{return nCol;}; inline double* data() {return pData;}; inline const double* data() const {return (const double*)pData;}; inline double operator [](int index) const{return pData[index];}; inline double& operator[](int index) {return pData[index];}; inline double data(int row,int col)const {return pData[row*nCol+col];}; inline double& data(int row,int col) {return pData[row*nCol+col];}; bool matchDimension(int _nrow,int _ncol) const {if(nRow==_nrow && nCol==_ncol) return true; else return false;}; bool matchDimension(const Matrix& matrix) const {return matchDimension(matrix.nrow(),matrix.ncol());}; // functions to check dimensions bool checkDimRight(const Vector& vector) const; bool checkDimRight(const Matrix& matrix) const; bool checkDimLeft(const Vector& vector) const; bool checkDimLeft(const Matrix& matrix) const; // functions for matrix computation void Multiply(Vector& result,const Vector& vect) const; void Multiply(Matrix& result,const Matrix& matrix) const; void transpose(Matrix& result) const; void fromVector(const Vector& vect); double norm2() const; double sum() const { double total = 0; for(int i = 0;i& matrix); Matrix& operator+=(double val); Matrix& operator-=(double val); Matrix& operator*=(double val); Matrix& operator/=(double val); Matrix& operator+=(const Matrix& matrix); Matrix& operator-=(const Matrix& matrix); Matrix& operator*=(const Matrix& matrix); Matrix& operator/=(const Matrix& matrix); friend Vector operator*(const Matrix& matrix,const Vector& vect); friend Matrix operator*(const Matrix& matrix1,const Matrix& matrix2); // solve linear systems void SolveLinearSystem(Vector& result,const Vector& b) const; void ConjugateGradient(Vector& result,const Vector& b) const; #ifdef _QT bool writeMatrix(QFile& file) const; bool readMatrix(QFile& file); #endif #ifdef _MATLAB void readMatrix(const mxArray* prhs); void writeMatrix(mxArray*& prhs) const; #endif }; template bool Matrix::IsDispInfo=false; template Matrix::Matrix(void) { nRow=nCol=0; pData=NULL; } template Matrix::Matrix(int nrow,int ncol,double* data) { nRow=nrow; nCol=ncol; pData=new T[nRow*nCol]; if(data==NULL) memset(pData,0,sizeof(T)*nRow*nCol); else memcpy(pData,data,sizeof(T)*nRow*nCol); } template Matrix::Matrix(const Matrix& matrix) { nRow=nCol=0; pData=NULL; copyData(matrix); } template Matrix::~Matrix(void) { releaseData(); } template void Matrix::releaseData() { if(pData!=NULL) delete pData; pData=NULL; nRow=nCol=0; } template void Matrix::copyData(const Matrix &matrix) { if(!dimMatch(matrix)) allocate(matrix); memcpy(pData,matrix.pData,sizeof(T)*nRow*nCol); } template bool Matrix::dimMatch(const Matrix& matrix) const { if(nCol==matrix.nCol && nRow==matrix.nRow) return true; else return false; } template bool Matrix::dimcheck(const Matrix& matrix) const { if(!dimMatch(matrix)) { cout<<"The dimensions of the matrices don't match!"< void Matrix::reset() { if(pData!=NULL) memset(pData,0,sizeof(T)*nRow*nCol); } template void Matrix::allocate(int nrow,int ncol) { releaseData(); nRow=nrow; nCol=ncol; if(nRow*nCol>0) { pData=new T[nRow*nCol]; memset(pData,0,sizeof(T)*nRow*nCol); } } template void Matrix::allocate(const Matrix& matrix) { allocate(matrix.nRow,matrix.nCol); } template void Matrix::loadData(int _nrow, int _ncol, T *data) { if(!matchDimension(_nrow,_ncol)) allocate(_nrow,_ncol); memcpy(pData,data,sizeof(T)*nRow*nCol); } template void Matrix::printMatrix() { for(int i=0;i void Matrix::identity(int ndim) { allocate(ndim,ndim); reset(); for(int i=0;i bool Matrix::checkDimRight(const Vector& vect) const { if(nCol==vect.dim()) return true; else { cout<<"The matrix and vector don't match in multiplication!"< bool Matrix::checkDimRight(const Matrix &matrix) const { if(nCol==matrix.nrow()) return true; else { cout<<"The matrix and matrix don't match in multiplication!"< bool Matrix::checkDimLeft(const Vector& vect) const { if(nRow==vect.dim()) return true; else { cout<<"The vector and matrix don't match in multiplication!"< bool Matrix::checkDimLeft(const Matrix &matrix) const { if(nRow==matrix.ncol()) return true; else { cout<<"The matrix and matrix don't match in multiplication!"< void Matrix::Multiply(Vector &result, const Vector&vect) const { checkDimRight(vect); if(result.dim()!=nRow) result.allocate(nRow); for(int i=0;i void Matrix::Multiply(Matrix &result, const Matrix &matrix) const { checkDimRight(matrix); if(!result.matchDimension(nRow,matrix.nCol)) result.allocate(nRow,matrix.nCol); for(int i=0;i void Matrix::transpose(Matrix &result) const { if(!result.matchDimension(nCol,nRow)) result.allocate(nCol,nRow); for(int i=0;i void Matrix::fromVector(const Vector&vect) { if(!matchDimension(vect.dim(),1)) allocate(vect.dim(),1); memcpy(pData,vect.data(),sizeof(double)*vect.dim()); } template double Matrix::norm2() const { if(pData==NULL) return 0; double temp=0; for(int i=0;i Matrix& Matrix::operator=(const Matrix& matrix) { copyData(matrix); return *this; } template Matrix& Matrix::operator +=(double val) { for(int i=0;i Matrix& Matrix::operator -=(double val) { for(int i=0;i Matrix& Matrix::operator *=(double val) { for(int i=0;i Matrix& Matrix::operator /=(double val) { for(int i=0;i Matrix& Matrix::operator +=(const Matrix &matrix) { dimcheck(matrix); for(int i=0;i Matrix& Matrix::operator -=(const Matrix &matrix) { dimcheck(matrix); for(int i=0;i Matrix& Matrix::operator *=(const Matrix &matrix) { dimcheck(matrix); for(int i=0;i Matrix& Matrix::operator /=(const Matrix &matrix) { dimcheck(matrix); for(int i=0;i Vector operator*(const Matrix& matrix,const Vector& vect) { Vector result; matrix.Multiply(result,vect); return result; } template Matrix operator*(const Matrix& matrix1,const Matrix& matrix2) { Matrix result; matrix1.Multiply(result,matrix2); return result; } //-------------------------------------------------------------------------------------------------- // function for conjugate gradient method //-------------------------------------------------------------------------------------------------- template void Matrix::ConjugateGradient(Vector &result, const Vector&b) const { if(nCol!=nRow) { cout<<"Error: when solving Ax=b, A is not square!"< r(b),p,q; result.reset(); int nIterations=nRow*5; Vector rou(nIterations); for(int k=0;k void Matrix::SolveLinearSystem(Vector &result, const Vector&b) const { if(nCol==nRow) { ConjugateGradient(result,b); return; } if(nRow AT,ATA; transpose(AT); AT.Multiply(ATA,*this); Vector ATb; AT.Multiply(ATb,b); ATA.ConjugateGradient(result,ATb); } #ifdef _QT template bool Matrix::writeMatrix(QFile &file) const { file.write((char *)&nRow,sizeof(int)); file.write((char *)&nCol,sizeof(int)); if(file.write((char *)pData,sizeof(double)*nRow*nCol)!=sizeof(double)*nRow*nCol) return false; return true; } template bool Matrix::readMatrix(QFile &file) { releaseData(); file.read((char *)&nRow,sizeof(int)); file.read((char *)&nCol,sizeof(int)); if(nRow*nCol>0) { allocate(nRow,nCol); if(file.read((char *)pData,sizeof(double)*nRow*nCol)!=sizeof(double)*nRow*nCol) return false; } return true; } #endif #ifdef _MATLAB template void Matrix::readMatrix(const mxArray* prhs) { if(pData!=NULL) delete pData; int nElements = mxGetNumberOfDimensions(prhs); if(nElements>2) mexErrMsgTxt("A matrix is expected to be loaded!"); const int* dims = mxGetDimensions(prhs); allocate(dims[0],dims[1]); double* data = (double*)mxGetData(prhs); for(int i =0; i void Matrix::writeMatrix(mxArray*& plhs) const { int dims[2]; dims[0]=nRow;dims[1]=nCol; plhs=mxCreateNumericArray(2, dims,mxDOUBLE_CLASS, mxREAL); double* data = (double *)mxGetData(plhs); for(int i =0; i