C plus plus:Tutorials:TemplateMatrix

From GDWiki

Jump to: navigation, search

Borrows heavily from the templated vector class math::vector<N,T>.

#ifndef MATH_MATRIX_HPP
#define MATH_MATRIX_HPP
 
#include <cstddef>
#include <algorithm>
#include <functional>
#include <numeric>
#include <cmath>
#include <cassert>
#include <stdexcept>
 
#ifndef MATH_MATRIX_NO_IO
#include <istream>
#include <ostream>
#endif
 
// By default matrix is an aggregate (public storage, no constructors), so it
// can be brace-initialised. Define MATH_MATRIX_NO_AGGREGATE before including
// this header to get the constructor-based variant guarded by
// `#ifndef MATH_MATRIX_AGGREGATE` instead.
#if !defined(MATH_MATRIX_AGGREGATE) && !defined(MATH_MATRIX_NO_AGGREGATE)
#define MATH_MATRIX_AGGREGATE 1
#endif
 
#ifdef BOOST_PREVENT_MACRO_SUBSTITUTION
#ifndef USE_BOOST
#ifndef NO_USE_BOOST
#define USE_BOOST 1
#endif
#endif
#endif
 
#ifdef USE_BOOST
#include <boost/call_traits.hpp>
#endif
 
#ifdef USE_BOOST
#include <boost/static_assert.hpp>
#else
#ifndef BOOST_STATIC_ASSERT
// Stand-in for Boost's compile-time assertion when Boost is not used.
// A false condition produces an array of size -1, which every compiler
// rejects; a size of 0 would be silently accepted as an extension by some
// compilers. The line number is pasted into the typedef name so several
// assertions can share one scope without redeclaring the same name.
#define MATH_MATRIX_JOIN_IMPL(a, b) a##b
#define MATH_MATRIX_JOIN(a, b) MATH_MATRIX_JOIN_IMPL(a, b)
#define BOOST_STATIC_ASSERT(c) \
    typedef char MATH_MATRIX_JOIN(ASSERTION_CHECK_, __LINE__)[(c) ? 1 : -1]
#endif
#endif
 
namespace math {
 
/**
 * Fixed-size R x C matrix whose R*C elements of type T are kept in a single
 * flat array in row-major order: element (r, c) lives at index c + r*C.
 *
 * When MATH_MATRIX_AGGREGATE is defined the storage is public and no
 * constructors are declared, so the class can be brace-initialised. Otherwise
 * the storage is private and the constructors below are available.
 */
template < std::size_t R = 4, std::size_t C = 4, typename T = float >
class matrix {
    enum { N = R * C };
#ifdef MATH_MATRIX_AGGREGATE
  public:
#endif
    T m_[N];
  public:
    typedef T value_type;
    typedef value_type const const_value_type;
 
    typedef value_type& reference;
    typedef const_value_type& const_reference;
 
    typedef value_type* pointer;
    typedef const_value_type* const_pointer;
 
    typedef std::size_t size_type;
    typedef std::ptrdiff_t difference_type;
 
    // How scalars are passed to member functions.
#ifdef USE_BOOST
    // call_traits lives in namespace boost and must be qualified here.
    typedef typename boost::call_traits<value_type>::param_type parameter_type;
#else
//    typedef const_reference parameter_type;
    typedef const_value_type parameter_type;
#endif
 
    /// Copies the N elements of a built-in array of exactly N elements.
    template <typename U>
    matrix& operator=(U (&a)[N]) {
        std::copy( a, a+N, m_ );
        return *this;
    }  
 
    /// Element-wise converting assignment from a same-shaped matrix of U.
    template <typename U>
    matrix& operator=(matrix<R,C,U> const &m) {
        // Must be qualified: the iterators are raw pointers, so
        // argument-dependent lookup would not find std::copy.
        std::copy( m.begin(), m.begin()+N, m_ );
        return *this;
    }
 
#ifndef MATH_MATRIX_AGGREGATE
    /// Sets every element to zero.
    matrix() {
        std::fill_n( m_, size(), static_cast<value_type>(0) );
    }
 
    /// Copies the elements of a built-in array of exactly N elements.
    template <typename U> /* implicit */ 
    matrix(U (&a)[N]) {
        *this = a;
    }
 
    /// Sets the first n elements (clamped to N) to x and the rest to zero.
    explicit matrix(parameter_type x, size_type n = N) {
        if ( n > N ) n = N;
        std::fill( m_, m_+n, x );
        std::fill( m_+n, m_+N, static_cast<value_type>(0) );
    }
 
    /// Copies the first n elements (clamped to N) from p and zeroes the rest.
    explicit matrix(const_pointer p, size_type n = N) {
        if ( n > N ) n = N;
        std::copy( p, p+n, m_ );
        std::fill_n( m_+n, N-n, static_cast<value_type>(0) );
    }
 
    /// Element-wise converting construction from a same-shaped matrix of U.
    template <typename U>
    explicit matrix(matrix<R,C,U> const &m) {
        *this = m;
    }
#endif // MATH_MATRIX_AGGREGATE
 
    /// Pointer to the first element of the flat storage.
    const_pointer m() const {
        return m_;
    }
 
    pointer m() { 
        return const_cast<pointer>(const_cast<matrix const &>(*this).m()); 
    }
 
    /// Compile-time checked access to element (r, c).
    template <size_type r, size_type c>
    const_reference get() const { 
        // One combined assertion: a fallback BOOST_STATIC_ASSERT that
        // typedefs a fixed name must not be expanded twice in one scope.
        BOOST_STATIC_ASSERT(r < R && c < C);
        return m_[c+r*C];
    }
    template <size_type r, size_type c>
    reference get() { 
        // Forward to the const overload. `template` is required because the
        // object expression has a dependent type.
        return const_cast<reference>(
                const_cast<matrix const &>(*this).template get<r, c>()
               );
    }
 
    /// Compile-time checked access to flat element i.
    template <size_type i>
    const_reference get() const { 
        BOOST_STATIC_ASSERT(i < N);
        return m_[i];
    }
    template <size_type i>
    reference get() { 
        return const_cast<reference>(
                const_cast<matrix const &>(*this).template get<i>()
               );
    }
 
    // Iteration over the flat storage.
    pointer begin() { return m_; }
    const_pointer begin() const { return m_; }
    pointer end() { return m_+N; }
    const_pointer end() const { return m_+N; }
    static size_type size() { return N; }
    // N==0 is impossible, since it would mean a 0-length array
    static bool empty() { return false; }
 
#ifndef MATH_MATRIX_NO_IMPLICIT_CONVERSION
    // Implicit decay to a pointer to the first element; this also provides
    // built-in subscripting on the flat storage.
    operator pointer() { return m(); }
    operator const_pointer() const { return m(); }
#else
    const_reference operator*() const {
        // This typedef makes instantiating the function an error unless N==1.
        // A size of -1 is used because a size of 0 is accepted by some
        // compilers as an extension.
        typedef int SIZE_CHECK[(N == 1) ? 1 : -1];
        return *m();
    }
 
    reference operator*() { 
        return const_cast<reference>(*const_cast<matrix const &>(*this)); 
    }
 
    /// Flat element access; the index is checked with assert only.
    const_reference operator[](size_type i) const {
        assert( i < N || !"=> index out of range" );
        return m_[i];
    }
 
    reference operator[](size_type i) { 
        return const_cast<reference>(
                const_cast<matrix const &>(*this).operator[](i)
               );
    }
#endif
 
    // operator() is for access as though it were an RxC matrix
    // (row i, column j); indices are checked with assert only.
    const_reference operator()(size_type i, size_type j) const {
        assert( i < R || !"=> index out of range" );
        assert( j < C || !"=> index out of range" );
        return m_[j+(i*C)];
    }
 
    reference operator()(size_type i, size_type j) {
        return const_cast<reference>(
                    const_cast<matrix const &>(*this).operator()(i, j)
               );
    }
 
    /// Element-wise addition.
    matrix& operator+=(matrix const &mat) {
        std::transform( m(), m()+size(),
                        mat.m(),
                        m(),
                        std::plus<value_type>() );
        return *this;
    }
 
    /// Element-wise subtraction.
    matrix& operator-=(matrix const &mat) {
        std::transform( m(), m()+size(),
                        mat.m(),
                        m(),
                        std::minus<value_type>() );
        return *this;
    }
 
    /// Multiplies every element by the scalar c.
    matrix& operator*=(parameter_type c) {
        // Plain loop instead of std::bind2nd, which was removed in C++17.
        for ( pointer p = begin(), last = end(); p != last; ++p ) {
            *p = *p * c;
        }
        return *this;
    }                                          
 
    /// Divides every element by the scalar c.
    matrix& operator/=(parameter_type c) {
        for ( pointer p = begin(), last = end(); p != last; ++p ) {
            *p = *p / c;
        }
        return *this;
    }
 
    /// Identity matrix; only available for square matrices (R == C).
    static const matrix identity() {
        BOOST_STATIC_ASSERT(R==C);
        matrix ret;
        // In aggregate mode a default-constructed matrix is uninitialised,
        // so the off-diagonal elements must be zeroed explicitly.
        std::fill( ret.begin(), ret.end(), static_cast<value_type>(0) );
        for ( std::size_t i = 0; i < R; ++i ) {
            ret(i,i) = static_cast<value_type>(1);
        }
        return ret;
    }
};
 
// Two matrices of the same shape and element type are equal when every pair
// of corresponding elements compares equal.
template < std::size_t R, std::size_t C, typename T >
bool operator==(matrix<R,C,T> const &lhs, matrix<R,C,T> const &rhs) {
    typedef typename matrix<R,C,T>::const_pointer cursor;

    cursor right = rhs.begin();
    for ( cursor left = lhs.begin(), stop = lhs.end(); left != stop; ++left, ++right ) {
        if ( !( *left == *right ) ) {
            return false;
        }
    }
    return true;
}
// Inequality is defined as the negation of operator==.
template < std::size_t R, std::size_t C, typename T >
bool operator!=(matrix<R,C,T> const &lhs, matrix<R,C,T> const &rhs) {
    bool const same = ( lhs == rhs );
    return !same;
}
 
// Element-wise sum. The left operand is taken by value and reused as the
// result.
template < std::size_t R, std::size_t C, typename T >
matrix<R,C,T> operator+(matrix<R,C,T> lhs, matrix<R,C,T> const &rhs) {
    lhs += rhs;
    return lhs;
}
// Element-wise difference. The left operand is taken by value and reused as
// the result.
template < std::size_t R, std::size_t C, typename T >
matrix<R,C,T> operator-(matrix<R,C,T> lhs, matrix<R,C,T> const &rhs) {
    lhs -= rhs;
    return lhs;
}
 
// Matrix times scalar: scales every element of a copy of mat by c.
template < std::size_t R, std::size_t C, typename T >
matrix<R,C,T> operator*(matrix<R,C,T> mat,
                        typename matrix<R,C,T>::parameter_type c) 
{
    mat *= c;
    return mat;
}
// Scalar times matrix: same result as the matrix-times-scalar overload.
template < std::size_t R, std::size_t C, typename T >
matrix<R,C,T> operator*(typename matrix<R,C,T>::parameter_type c,
                        matrix<R,C,T> mat) 
{
    mat *= c;
    return mat;
}
 
/* Matrix product. The column count of lhs (C1) must equal the row count of
 * rhs; the template signature enforces this. The result is R1 x C2. */
template < std::size_t R1, std::size_t C1, std::size_t C2, typename T >
const matrix<R1,C2,T> operator*(matrix<R1,C1,T> const& lhs, 
                                matrix<C1,C2,T> const& rhs) 
{
    // The accumulation below needs a zeroed result. In aggregate mode that
    // takes an explicit empty initialiser; otherwise the default constructor
    // already zero-fills.
#ifdef MATH_MATRIX_AGGREGATE
    matrix<R1,C2,T> product = {};
#else
    matrix<R1,C2,T> product;
#endif

    for ( std::size_t row = 0; row != R1; ++row ) {
        for ( std::size_t col = 0; col != C2; ++col ) {
            // Dot product of lhs row `row` with rhs column `col`.
            typename matrix<R1,C2,T>::reference cell = product(row, col);
            for ( std::size_t inner = 0; inner != C1; ++inner ) {
                cell += lhs(row, inner) * rhs(inner, col);
            }
        }
    }

    return product;
}
 
// Matrix divided by scalar: divides every element of a copy of mat by c.
template < std::size_t R, std::size_t C, typename T >
matrix<R,C,T> operator/(matrix<R,C,T> mat,
                        typename matrix<R,C,T>::parameter_type c) 
{
    mat /= c;
    return mat;
}
 
// Unary minus: returns a copy of mat with every element negated.
template < std::size_t R, std::size_t C, typename T >
matrix<R,C,T> operator-(matrix<R,C,T> mat) {
    typedef typename matrix<R,C,T>::pointer cursor;

    for ( cursor p = mat.begin(), stop = mat.end(); p != stop; ++p ) {
        *p = -*p;
    }
    return mat;
}
 
// Returns the C x R transpose of an R x C matrix: result(c, r) == mat(r, c).
template <std::size_t R, std::size_t C, typename T>
matrix<C,R,T> transpose(matrix<R,C,T> const& mat) {
    matrix<C,R,T> flipped;

    // Walk the source in its own row/column order; every element of the
    // result is written exactly once.
    for (std::size_t srcRow = 0; srcRow != R; ++srcRow) {
        for (std::size_t srcCol = 0; srcCol != C; ++srcCol) {
            flipped(srcCol, srcRow) = mat(srcRow, srcCol);
        }
    }

    return flipped;
}
 
// Returns the inverse of a square D x D matrix, computed by Gauss-Jordan
// elimination with partial pivoting.
//
// The previous implementation returned (1/det) * mat, which is not the
// inverse, and only compiled for sizes that have a det() overload.
//
// T is expected to behave like a field (e.g. float or double); with integer
// element types the divisions truncate.
//
// Throws std::runtime_error("inverse matrix does not exist") when no non-zero
// pivot can be found, i.e. the matrix is singular.
template <std::size_t D, typename T>
matrix<D,D,T> inverse(matrix<D,D,T> const& mat) {
    T const zero = static_cast<T>(0);

    matrix<D,D,T> work = mat;   // reduced in place to the identity
    matrix<D,D,T> inv;          // starts as the identity, ends as the inverse

    // Fill every element explicitly: a default-constructed matrix may be
    // uninitialised when the class is an aggregate.
    for (std::size_t r = 0; r < D; ++r) {
        for (std::size_t c = 0; c < D; ++c) {
            inv(r, c) = static_cast<T>(r == c ? 1 : 0);
        }
    }

    for (std::size_t col = 0; col < D; ++col) {
        // Partial pivoting: pick the row at or below the diagonal whose
        // entry in this column has the largest magnitude.
        std::size_t pivot = col;
        T best = work(col, col);
        if (best < zero) best = -best;
        for (std::size_t r = col + 1; r < D; ++r) {
            T cand = work(r, col);
            if (cand < zero) cand = -cand;
            if (best < cand) {
                best = cand;
                pivot = r;
            }
        }

        if (work(pivot, col) == zero) {
            throw std::runtime_error("inverse matrix does not exist");
        }

        if (pivot != col) {
            for (std::size_t c = 0; c < D; ++c) {
                std::swap(work(pivot, c), work(col, c));
                std::swap(inv(pivot, c), inv(col, c));
            }
        }

        // Scale the pivot row so the diagonal entry becomes 1.
        T const scale = work(col, col);
        for (std::size_t c = 0; c < D; ++c) {
            work(col, c) /= scale;
            inv(col, c) /= scale;
        }

        // Eliminate this column from every other row.
        for (std::size_t r = 0; r < D; ++r) {
            if (r == col) continue;
            T const factor = work(r, col);
            if (factor == zero) continue;
            for (std::size_t c = 0; c < D; ++c) {
                work(r, c) -= factor * work(col, c);
                inv(r, c) -= factor * inv(col, c);
            }
        }
    }

    return inv;
}
 
// Determinant of a 1x1 matrix: its only element.
template <typename T>
T det(matrix<1,1,T> const& mat) {
    return mat(0, 0);
}
 
// Determinant of a 2x2 matrix | a b ; c d | = a*d - c*b.
template <typename T>
T det(matrix<2,2,T> const& mat) {
    T const a = mat(0, 0);
    T const b = mat(0, 1);
    T const c = mat(1, 0);
    T const d = mat(1, 1);
    return a*d - c*b;
}
 
// Determinant of a 3x3 matrix | a b c ; d e f ; g h i | by cofactor
// expansion along the first column.
template <typename T>
T det(matrix<3,3,T> const& mat) {
    T const a = mat(0, 0);
    T const b = mat(0, 1);
    T const c = mat(0, 2);
    T const d = mat(1, 0);
    T const e = mat(1, 1);
    T const f = mat(1, 2);
    T const g = mat(2, 0);
    T const h = mat(2, 1);
    T const i = mat(2, 2);

    return a*(e*i - h*f) -
           d*(b*i - h*c) +
           g*(b*f - e*c);
}
 
#ifndef MATH_MATRIX_NO_IO
template < std::size_t R, std::size_t C, typename T >
std::ostream& operator<<(std::ostream &sink, matrix<R,C,T> const &mat) {    
    for ( std::size_t i = 0; i < R; ++i ) {
        sink << "\n[ ";
 
        for ( std::size_t j = 0; j < C; ++j ) {
            sink << mat(i,j) << '\t';
        }
 
        sink << ']';
    }
 
    return sink;
}
 
// Reads mat.size() elements, in flat storage order, from text of the form
// "(e0,e1,...,eLast)": an opening '(', elements separated by ',', and a
// closing ')' after the last element.
//
// Returns the stream. If the stream is not good() on entry it is returned
// untouched. When an unexpected delimiter is found, the character is put back
// and the stream state is set to failbit.
//
// NOTE(review): this parenthesised, comma-separated format is not the
// bracketed, tab-separated text that operator<< in this file writes.
template < std::size_t R, std::size_t C, typename T >
std::istream& operator>>(std::istream &source, matrix<R,C,T> &mat) {
    if ( !source.good() ) return source;
 
    // The first non-whitespace character must be the opening parenthesis.
    char c = '\0';
    source >> c;
    if ( c != '(' ) {
        source.unget();
        // clear(failbit) replaces the whole state with failbit only.
        source.clear( std::ios::failbit );
        return source;
    }
 
    for ( std::size_t i = 0; i < mat.size(); ++i ) {
 
        source >> mat[i];
        // Drop failbit but keep the other state bits; a malformed element is
        // instead caught by the delimiter check below.
        source.clear( source.rdstate() & ~std::ios::failbit );
        // Stop if any remaining state bit (eofbit/badbit) is set.
        if ( !source.good() ) return source;
 
        // Reset c so a failed read cannot be mistaken for a valid delimiter.
        c = '\0';
        source >> c;
        source.clear( source.rdstate() & ~std::ios::failbit );
        // Expect ')' after the last element and ',' after every other one.
        if ( c != ( (i == mat.size()-1) ? ')' : ',' ) ) {
            source.unget();
            source.clear( std::ios::failbit );
            return source;
        }
        if ( !source.good() ) return source;
 
    }
 
    return source;
}
#endif // MATH_MATRIX_NO_IO
 
} // namespace math
 
#endif
Personal tools
Categories