`
helpbs
  • 浏览: 1164785 次
文章分类
社区版块
存档分类
最新评论

矩阵类

 
阅读更多

此次的矩阵类支持双下标访问(m[i][j]),并带有越界检查能力(通过 assert 实现,仅在调试版本中生效)。

用例:

jks::CMatrix<int> m(3,4);
int i,j;
for (i=0;i<m.getHeight();i++)
{
for (j=0;j<m.getWidth();j++)
{
m[i][j] = i*10+j;
}
}

cerr<<m;

=========================================================================

头文件


#if !defined(__JKS_MATRIX_HPP_)
#define __JKS_MATRIX_HPP_

#if _MSC_VER > 1000
#pragma once
#endif // _MSC_VER > 1000

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <iostream>
#include <typeinfo>
namespace jks
{
//////////////////////////////////////////////////////////////////////////

template<typename T>
class CMatrix
{
    // Row proxy returned by CMatrix::operator[] so that m[row][col] works.
    // The column index is range-checked with assert (debug builds only).
    class CLine
    {
        T* _pLine;       // first element of the selected row
        CMatrix* _pMat;  // owning matrix, used for the column bound

    public:
        CLine(T* pLine,CMatrix* pMat):_pLine(pLine),_pMat(pMat) {};
        T& operator[] (long nCol)
        {assert(nCol>=0 && nCol<_pMat->getWidth());return _pLine[nCol];}

    };

    long _sLn,_sCol;  // number of rows, number of columns
    T* _pData;        // row-major elements (_sLn*_sCol); NULL for a default-constructed matrix
public:
    CMatrix();
    CMatrix(T * arrAddress,long arrWidth);//1 x arrWidth matrix (a single row) copied from arrAddress
    CMatrix(T * arrAddress,long arrHeight,long arrWidth);//arrHeight x arrWidth matrix copied from arrAddress (row-major)
    CMatrix(long Height,long Width);//allocates a Height x Width matrix
    CMatrix(const CMatrix<T> &);//copy constructor (deep copy)
    virtual ~CMatrix(void);

    //Properties
    long getHeight() const {return _sLn;}   // number of rows
    long getWidth() const {return _sCol;}   // number of columns
    // NOTE(review): no definition of this overload is visible in this file.
    T* getData(long& arrHeight,long& arrWidth,T* p=NULL) const;
    T* getData(int& arrHeight,int& arrWidth,T* p=NULL) const
    {
        long t1 = arrHeight;
        long t2 = arrWidth;
        arrHeight = _sLn;
        arrWidth = _sCol;
        return getData(t1,t2,p);
    }

    // Enables `if (m)`: non-null exactly when isValid() is true.
    operator void* () {return isValid()?this:0;}
    // True when element storage has been allocated. The former `this!=NULL`
    // test was dropped: calling a member through a null pointer is already
    // undefined behaviour, so compilers are free to remove such a check.
    bool isValid() const
    {
        return _pData!=NULL;
    }
    // True when (nLn, nCol) addresses an existing element.
    bool isInDomain(long nLn,long nCol) const
    {
        return nLn>=0 && nLn<_sLn && nCol>=0 && nCol<_sCol;
    }
    int isVector();//0: 1x1 (treated as a scalar); 1: one column and a row count other than 1; -1: column count other than 1

    //Arithmetic
    CMatrix operator+(CMatrix<T> &);
    CMatrix operator-(CMatrix<T> &);
    CMatrix operator*(CMatrix<T> &);
    // The free operators/functions below are befriended as *templates* so that
    // they name the function templates defined after the class. Plain
    // (non-template) friend declarations would win overload resolution and,
    // having no definition, fail at link time; they would also leave the
    // templates without access to private members.
    template<typename U> friend CMatrix<U> operator*(double alpha,CMatrix<U> &);//scalar * matrix
    CMatrix operator*(double alpha);//matrix * scalar
    CMatrix operator/(CMatrix<T> &);//divide by a 1x1 matrix used as a scalar
    CMatrix operator/(double sub);//divide every element by a scalar
    CMatrix operator+=(CMatrix<T> &);
    CMatrix operator-=(CMatrix<T> &);
    CMatrix operator*=(CMatrix<T> &);
    CMatrix operator*=(double alpha);//matrix * scalar
    CMatrix & operator = (CMatrix<T> &);//assignment
    // Assignment from a const matrix or a temporary (e.g. `m = a + b`), which
    // the non-const overload above cannot bind to. The copy is built first,
    // so *this is unchanged if allocation or an element copy throws.
    CMatrix & operator = (const CMatrix<T> & m)
    {
        if (&m != this)
        {
            CMatrix<T> tmp(m);
            T* p = _pData; _pData = tmp._pData; tmp._pData = p;
            long n = _sLn; _sLn = tmp._sLn; tmp._sLn = n;
            n = _sCol; _sCol = tmp._sCol; tmp._sCol = n;
        }
        return *this; // tmp now owns (and frees) the previous storage
    }
    CLine operator[](long heightPos);//row proxy, enabling m[i][j]
    // NOTE(review): no definitions of sqrt/abs/sum/multiply are visible in this file.
    template<typename U> friend CMatrix<U> sqrt(CMatrix<U> m);//square root
    template<typename U> friend double abs(CMatrix<U> &);//absolute value (norm)
    template<typename U> friend double sum(CMatrix<U> &);//sum of elements
    template<typename U> friend CMatrix<U> multiply(CMatrix<U> &m1,CMatrix<U> & m2);//element-wise product
    template<typename U> friend U operator+(double dbl,CMatrix<U> &);
    template<typename U> friend U operator+(CMatrix<U> &,double dbl);
    template<typename U> friend U operator-(double dbl,CMatrix<U> &);
    template<typename U> friend U operator-(CMatrix<U> &,double dbl);
    const T* c_ptr() const;//read-only pointer to the element storage
    template<typename U> friend bool operator == (CMatrix<U> &,CMatrix<U> &);

    //Output
    template<typename U> friend std::ostream& operator<<(std::ostream &,CMatrix<U> &);

public:
    //Public settings
    static float m_fPrecision;//tolerance used by == for float elements
    static double m_dPrecision;//tolerance used by == for double elements

};

//////////////////////////////////////////////////////////////////////////
//函数实现
// Absolute tolerance consulted by operator== for float elements (1e-7).
template<typename T>
float CMatrix<T>::m_fPrecision = static_cast<float>(1e-7);
// Absolute tolerance consulted by operator== for double elements (1e-20).
// NOTE(review): as an absolute bound this is effectively an exact comparison
// for values of ordinary magnitude.
template<typename T>
double CMatrix<T>::m_dPrecision = 1.0e-20;

//////////////////////////////////////////////////////////////////////////
//构造与析构
// Default constructor: an empty 0 x 0 matrix with no storage, so isValid()
// reports false until data is assigned.
template<typename T>
CMatrix<T>::CMatrix(void)
    : _sLn(0), _sCol(0), _pData(NULL)
{
}

// Builds a 1 x arrWidth matrix (a single row) holding a copy of the first
// arrWidth elements of arrAddress. The caller keeps ownership of arrAddress.
template<typename T>
CMatrix<T>::CMatrix(T * arrAddress,long arrWidth)
    : _sLn(1), _sCol(arrWidth), _pData(new T[arrWidth])
{
    // std::copy instead of memcpy: uses T's assignment, so it is also correct
    // for element types that are not trivially copyable, and it is well
    // defined for an empty range with a NULL source.
    try
    {
        std::copy(arrAddress, arrAddress + arrWidth, _pData);
    }
    catch (...)
    {
        delete [] _pData; // the destructor will not run for a failed constructor
        throw;
    }
}

// Builds an arrHeight x arrWidth matrix from a copy of arrHeight*arrWidth
// elements at arrAddress, taken in row-major order. The caller keeps
// ownership of arrAddress.
template<typename T>
CMatrix<T>::CMatrix(T * arrAddress,long arrHeight,long arrWidth)
    : _sLn(arrHeight), _sCol(arrWidth), _pData(new T[arrWidth*arrHeight])
{
    // std::copy instead of memcpy: correct for non-trivially-copyable T and
    // well defined for an empty range with a NULL source.
    try
    {
        std::copy(arrAddress, arrAddress + arrWidth*arrHeight, _pData);
    }
    catch (...)
    {
        delete [] _pData; // the destructor will not run for a failed constructor
        throw;
    }
}

// Allocates a height x width matrix. The elements are value-initialized
// (zero for arithmetic T), so reading an element before writing it is well
// defined. Previously they were left uninitialized.
template<typename T>
CMatrix<T>::CMatrix(long height,long width)
    : _sLn(height), _sCol(width), _pData(new T[height*width]())
{
}

// Copy constructor: deep-copies the dimensions and elements of m.
// Copying a matrix without storage (default-constructed) yields another
// matrix without storage, so isValid() agrees between source and copy.
template<typename T>
CMatrix<T>::CMatrix(const CMatrix<T> & m)
    : _sLn(m._sLn), _sCol(m._sCol), _pData(NULL)
{
    if (m._pData == NULL)
        return;

    const long n = _sLn*_sCol;
    _pData = new T[n];
    // std::copy instead of memcpy: correct for non-trivially-copyable T.
    try
    {
        std::copy(m._pData, m._pData + n, _pData);
    }
    catch (...)
    {
        delete [] _pData; // the destructor will not run for a failed constructor
        throw;
    }
}

// Destructor: frees the element storage and leaves the object in the empty
// (NULL data, 0 x 0) state.
template<typename T>
CMatrix<T>::~CMatrix()
{
    delete [] _pData; // delete[] on a NULL pointer is a no-op
    _pData = NULL;
    _sLn = _sCol = 0;
}

//////////////////////////////////////////////////////////////////////////
//运算

// Element-wise sum *this + m1. Both operands must have the same dimensions
// (checked by assert). Returns a new matrix; neither operand is modified.
template<typename T>
CMatrix<T> CMatrix<T>::operator +(CMatrix &m1)
{
    assert(m1._sLn==_sLn && m1._sCol==_sCol);
    // Fill the result directly instead of a scratch buffer that was copied
    // and then deleted (and leaked if that copy threw).
    CMatrix<T> m(_sLn, _sCol);
    const long n = _sLn*_sCol;
    for (long k = 0; k < n; ++k)
    {
        // Left operand first, so non-commutative T (e.g. strings) adds in order.
        m._pData[k] = _pData[k] + m1._pData[k];
    }
    return m;
}

// Element-wise difference *this - m1. Both operands must have the same
// dimensions (checked by assert). Returns a new matrix; neither operand is
// modified.
template<typename T>
CMatrix<T> CMatrix<T>::operator -(CMatrix &m1)
{
    assert(m1._sLn==_sLn && m1._sCol==_sCol);
    // Fill the result directly instead of a scratch buffer that was copied
    // and then deleted (and leaked if that copy threw).
    CMatrix<T> m(_sLn, _sCol);
    const long n = _sLn*_sCol;
    for (long k = 0; k < n; ++k)
    {
        m._pData[k] = _pData[k] - m1._pData[k];
    }
    return m;
}

// Matrix product with scalar promotion. A 1x1 operand (isVector()==0) is
// treated as a scalar:
//   scalar * scalar   -> 1x1 matrix holding the product;
//   scalar * matrix / matrix * scalar -> every element scaled;
//   otherwise the usual (r x k) * (k x c) product.
// Any other shape combination trips assert(0) and returns a copy of *this.
template<typename T>
CMatrix<T> CMatrix<T>::operator *(CMatrix &m1)
{
    const bool lhsScalar = !this->isVector();
    const bool rhsScalar = !m1.isVector();

    if (lhsScalar && rhsScalar)
    {
        // Previously unreachable: its guard duplicated the scalar*matrix case,
        // so 1x1 * 1x1 fell through to assert(0).
        CMatrix<T> m(1, 1);
        m._pData[0] = _pData[0]*m1._pData[0];
        return m;
    }
    // Return the results directly: the old `m = <temporary>` needed the
    // assignment operator to bind a temporary to a non-const reference.
    if (lhsScalar)
        return _pData[0]*m1;        // scalar * matrix
    if (rhsScalar)
        return *this*m1._pData[0];  // matrix * scalar

    if (_sCol == m1._sLn)
    {
        CMatrix<T> m(_sLn, m1._sCol);
        for (long i=0;i<_sLn;i++)
        {
            for (long j=0;j<m1._sCol;j++)
            {
                // NOTE(review): accumulates in double as before; this loses
                // precision for wide integer T and rejects non-arithmetic T.
                double sum = 0;
                for (long k=0;k<_sCol;k++)
                {
                    sum += _pData[_sCol*i+k]*m1._pData[m1._sCol*k+j];
                }
                m._pData[m1._sCol*i+j] = sum;
            }
        }
        return m;
    }

    assert(0);//unsupported operand shapes
    return *this;
}

// Scalar * matrix: returns a copy of m1 with every element multiplied by
// alpha. The product is computed as double, then converted back to T.
template<typename T>
CMatrix<T> operator*(double alpha,CMatrix<T> & m1)
{
    CMatrix<T> m=m1;
    // Public accessors only: this free template must not depend on being
    // befriended by CMatrix to read the dimensions.
    const long rows = m.getHeight();
    const long cols = m.getWidth();
    for(long i=0;i<rows;i++){
        for(long j=0;j<cols;j++){
            m[i][j]=alpha*m1[i][j];
        }
    }
    return m;
}

// Matrix * scalar: forwards to the scalar * matrix operator, which scales a
// copy of this matrix and leaves *this untouched.
template<typename T>
CMatrix<T> CMatrix<T>::operator*(double alpha)
{
    CMatrix<T>& self = *this;
    return alpha*self;
}

// Compound addition: *this becomes *this + m (same dimensions required).
// Returns a copy of the updated matrix (the declared return type is by value).
template<typename T>
CMatrix<T> CMatrix<T>::operator+=(CMatrix<T> & m)
{
    // The original returned *this+m without storing it, so *this never
    // changed. The sum is bound to a named object first because the
    // assignment operator takes a non-const reference.
    CMatrix<T> result = *this+m;
    *this = result;
    return *this;
}


// Compound subtraction: *this becomes *this - m (same dimensions required).
// Returns a copy of the updated matrix (the declared return type is by value).
template<typename T>
CMatrix<T> CMatrix<T>::operator-=(CMatrix<T> & m)
{
    // The original returned *this-m without storing it, so *this never
    // changed. The difference is bound to a named object first because the
    // assignment operator takes a non-const reference.
    CMatrix<T> result = *this-m;
    *this = result;
    return *this;
}


// Compound scaling: every element of *this is multiplied by alpha.
// Returns a copy of the updated matrix (the declared return type is by value).
template<typename T>
CMatrix<T> CMatrix<T>::operator *=(double alpha)
{
    // The original returned *this*alpha without storing it, so *this never
    // changed. The result is named first because the assignment operator
    // takes a non-const reference.
    CMatrix<T> result = *this*alpha;
    *this = result;
    return *this;
}

// Compound multiplication: *this becomes *this * m1 (matrix or scalar
// product, as defined by operator*). The dimensions of *this may change.
// Returns a copy of the updated matrix (the declared return type is by value).
template<typename T>
CMatrix<T> CMatrix<T>::operator *=(CMatrix<T> & m1)
{
    // The original returned *this*m1 without storing it, so *this never
    // changed. The product is computed into a separate object first, which
    // also keeps `m *= m` correct.
    CMatrix<T> result = *this*m1;
    *this = result;
    return *this;
}


// Read-only pointer to the row-major element storage. It is NULL for a
// default-constructed matrix.
template<typename T>
const T* CMatrix<T>::c_ptr () const
{
    const T* storage = this->_pData;
    return storage;
}


// Division by a 1x1 matrix: its single element is used as a scalar divisor.
// Both "m1 is 1x1" and "divisor is non-zero" are checked by assert.
template<typename T>
CMatrix<T> CMatrix<T>::operator /(CMatrix<T> &m1)
{
    assert(m1._sCol==1 && m1._sLn==1);
    T& divisor = m1[0][0];
    assert(m1[0][0]!=0);
    return (*this)/divisor;
}

// Divides every element by sub (asserted non-zero) and returns the result as
// a new matrix. Each quotient is computed as double and converted back to T.
template<typename T>
CMatrix<T> CMatrix<T>::operator /(double sub)
{
    assert(sub!=0);
    CMatrix<T> quotient = *this;
    const long count = _sLn*_sCol;
    for (long k = 0; k < count; ++k)
    {
        quotient._pData[k] = _pData[k]/sub;
    }
    return quotient;
}

// Copy-assignment: replaces the contents of *this with a deep copy of m.
// The new storage is built completely before the old storage is released,
// so *this is unchanged if allocation or an element copy throws. (The old
// code updated the dimensions and freed the data first.)
template<typename T>
CMatrix<T> & CMatrix<T>::operator =(CMatrix<T> & m)
{
    if(&m==this) return *this;

    T* fresh = NULL;
    if (m._pData != NULL) // an empty source yields an empty (NULL) target
    {
        const long n = m._sLn*m._sCol;
        fresh = new T[n];
        // std::copy instead of memcpy: correct for non-trivially-copyable T.
        try
        {
            std::copy(m._pData, m._pData + n, fresh);
        }
        catch (...)
        {
            delete [] fresh;
            throw;
        }
    }

    delete [] _pData;
    _pData = fresh;
    _sLn = m._sLn;
    _sCol = m._sCol;
    return *this;
}

template<typename T>
bool operator == (CMatrix<T>& m1,CMatrix<T>& m2)
{
if (&m1 == &m2)
{
return true;
}

if (m1.getWidth()!=m2.getWidth() || m1.getHeight()!=m2.getHeight())
{
return false;
}

T *p1,*p2;
p1 = m1[0];
p2 = m2[0];
long sum = m1.getHeight()*m1.getWidth();
long i;
if(typeid(T) == typeid(float))
for (i=0;i<sum;i++)
{
if (fabs(*p1++ - *p2++)<CMatrix<T>::m_fPrecision)
{
return false;
}
}
if(typeid(T) == typeid(double))
for (i=0;i<sum;i++)
{
if (fabs(*p1++ - *p2++)<CMatrix<T>::m_dPrecision)
{
return false;
}
}
else
for (i=0;i<sum;i++)
{
if (*p1++ != *p2++)
{
return false;
}
}

return true;
}

// double + 1x1 matrix: the single element is added to dbl, and the sum is
// returned converted to T. Asserts that m is 1x1.
template<typename T>
T operator+(double dbl,CMatrix<T> & m)
{
    assert(m.getHeight()==1 && m.getWidth()==1);
    T& element = m[0][0];
    return dbl+element;
}

// 1x1 matrix + double: forwards to the (double, matrix) overload, which
// asserts that m is 1x1.
template<typename T>
T operator+(CMatrix<T> & m,double dbl)
{
    const T total = dbl+m;
    return total;
}

// double - 1x1 matrix: subtracts the single element from dbl, and the
// difference is returned converted to T. Asserts that m is 1x1.
template<typename T>
T operator-(double dbl,CMatrix<T> & m)
{
    assert(m.getHeight()==1 && m.getWidth()==1);
    T& element = m[0][0];
    return dbl-element;
}

// 1x1 matrix - double: returns the single element minus dbl, converted to
// T. Asserts that m is 1x1.
template<typename T>
T operator-(CMatrix<T> & m,double dbl)
{
    // Computed directly instead of as -(dbl - m). For unsigned T that
    // intermediate could be negative, and converting a negative double to an
    // unsigned type is undefined behaviour.
    assert(m.getHeight()==1 && m.getWidth()==1);
    return m[0][0]-dbl;
}

template<typename T>
CMatrix<T>::CLine CMatrix<T>::operator [](long heightPos)
{
assert(isValid());
assert(heightPos>=0 && heightPos<=_sLn);//报错

CLine rLine(_pData+heightPos*_sCol,this);
return rLine;//取回的是行头指针
}

//////////////////////////////////////////////////////////////////////////
//输出

// Writes one element for operator<<. unsigned char is printed as a number
// rather than as a character; every other T uses its own stream operator.
// Overloads replace the old typeid branch, whose (int) cast had to compile
// even for element types that cannot be cast to int.
template<typename T>
inline void matrixWriteElement(std::ostream & os, const T & value)
{
    os<<value;
}

inline void matrixWriteElement(std::ostream & os, unsigned char value)
{
    os<<static_cast<int>(value);
}

// Prints a header line "Sum Ln:<rows> Sum Col:<cols>", then one line per row.
// Each element is followed by a tab character.
template<typename T>
std::ostream & operator<<(std::ostream & os,jks::CMatrix<T> & m)
{
    // Public accessors only, so this template does not rely on friendship.
    const long rows = m.getHeight();
    const long cols = m.getWidth();
    const T* data = m.c_ptr();

    os<<"Sum Ln:"<<rows<<" "<<"Sum Col:"<<cols<<std::endl;
    for (long i=0;i<rows;i++)
    {
        for (long j=0;j<cols;j++)
        {
            matrixWriteElement(os, data[i*cols+j]);
            os<<"\t"; // was "/t", a typo for the tab escape
        }
        os<<std::endl;
    }

    return os;
}

//////////////////////////////////////////////////////////////////////////
// Shape classification:
//   0  -> exactly 1x1 (used as a scalar by the arithmetic operators);
//   1  -> a single column whose row count is not 1;
//   -1 -> the column count is not 1 (row vectors and general matrices alike).
template<typename T>
int CMatrix<T>::isVector()
{
    if (_sCol != 1)
        return -1;
    return (_sLn == 1) ? 0 : 1;
}

//////////////////////////////////////////////////////////////////////////
}//
#endif // !defined(__JKS_MATRIX_HPP_)

分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics