Source

neglect / neglect / cpplib / include / neglect / matrix.hpp

#ifndef _INC_NEGLECT_MATRIX_HPP_
#define _INC_NEGLECT_MATRIX_HPP_

#include <neglect/boot.hpp>
#include <neglect/vector.hpp>
#include <neglect/math.hpp>

namespace neglect {

    /* matrices are stored column major and columns are represented
       as vectors.  Because of that it's currently not possible to
       create matrices with a height != 2, 3 or 4. */

    template <typename T, size_t M, size_t N>
    class matrix {
    public:
        static const size_t rows = M;
        static const size_t columns = N;
        static const size_t dimensions = M * N;

        /* default constructor, warning: does not initialize the matrix */
        matrix()
        {
        }

        /* Fills the matrix from any indexable sequence that provides at
           least `dimensions` elements in column-major order: seq[0 .. M-1]
           is column 0, seq[M .. 2M-1] is column 1, and so on.
           Writes through ptr(); NOTE(review): this relies on the column
           vectors being laid out back to back without padding. */
        template <class Sequence>
        matrix(Sequence seq)
        {
            T *dst = this->ptr();
            // size_t counter: `dimensions` is unsigned, an int counter
            // mixed signed and unsigned in the comparison.
            for (size_t i = 0; i < dimensions; i++)
                dst[i] = seq[i];
        }

        /* returns an M x N matrix with every element set to T() */
        static neglect::matrix<T, M, N> zero()
        {
            neglect::matrix<T, M, N> rv;
            rv.set_zero();
            return rv;
        }

        /* returns the identity matrix (square matrices only, see
           set_identity) */
        static neglect::matrix<T, M, N> identity()
        {
            neglect::matrix<T, M, N> rv;
            rv.set_identity();
            return rv;
        }

        /* Sets every element to T().  Each element is written through the
           column vector's operator[] rather than by assigning
           default-constructed vectors, so the result does not depend on
           whether vector's default constructor initializes its
           components. */
        void set_zero()
        {
            for (size_t c = 0; c < N; c++)
                for (size_t r = 0; r < M; r++)
                    m_cols[c][r] = T();
        }

        /* Turns the matrix into the identity; fails at compile time when
           the matrix is not square. */
        void set_identity()
        {
            BOOST_STATIC_ASSERT(M == N);
            set_zero();
            for (size_t i = 0; i < M; i++)
                m_cols[i][i] = T(1);
        }

        /* Pointer to the first element (row 0 of column 0); the storage is
           accessed as `dimensions` consecutive T values, column-major. */
        T *ptr()
        {
            return &m_cols[0].x;
        }

        const T *ptr() const
        {
            return &m_cols[0].x;
        }

        /* element access by (row, column) */
        T &operator()(size_t row, size_t column)
        {
            return m_cols[column][row];
        }

        T operator()(size_t row, size_t column) const
        {
            return m_cols[column][row];
        }

        /* column access: m[c] is column c, so m[c][r] is m(r, c) */
        neglect::vector<T, M> &operator[](size_t column)
        {
            return m_cols[column];
        }

        const neglect::vector<T, M> &operator[](size_t column) const
        {
            return m_cols[column];
        }

    private:
        neglect::vector<T, M> m_cols[N];
    };

    // common typedefs: square float and int matrices of size 2, 3 and 4
    typedef matrix<float, 2, 2> mat2;
    typedef matrix<int, 2, 2> mati2;
    typedef matrix<float, 3, 3> mat3;
    typedef matrix<int, 3, 3> mati3;
    typedef matrix<float, 4, 4> mat4;
    typedef matrix<int, 4, 4> mati4;

    /* exact element-by-element equality; matrices of different shapes
       never compare equal (use almost_equal for a tolerant comparison) */
    template <typename TA, size_t MA, size_t NA,
              typename TB, size_t MB, size_t NB>
    bool operator==(const matrix<TA, MA, NA> &lhs,
                    const matrix<TB, MB, NB> &rhs)
    {
        const bool same_shape = (MA == MB) && (NA == NB);
        if (!same_shape)
            return false;

        const TA *a = lhs.ptr();
        const TB *b = rhs.ptr();
        for (size_t idx = 0; idx != lhs.dimensions; ++idx) {
            if (a[idx] != b[idx])
                return false;
        }
        return true;
    }

    /* logical negation of operator== */
    template <typename TA, size_t MA, size_t NA,
              typename TB, size_t MB, size_t NB>
    bool operator!=(const matrix<TA, MA, NA> &lhs,
                    const matrix<TB, MB, NB> &rhs)
    {
        const bool equal = (lhs == rhs);
        return !equal;
    }

    /* tolerant comparison: true when both matrices have the same shape and
       every pair of corresponding elements is accepted by the element-level
       almost_equal() overload (defined outside this file) */
    template <typename TA, size_t MA, size_t NA,
              typename TB, size_t MB, size_t NB>
    bool almost_equal(const matrix<TA, MA, NA> &lhs,
                      const matrix<TB, MB, NB> &rhs)
    {
        if (MA == MB && NA == NB) {
            const TA *a = lhs.ptr();
            const TB *b = rhs.ptr();
            for (size_t idx = 0; idx != lhs.dimensions; ++idx) {
                if (!almost_equal(a[idx], b[idx]))
                    return false;
            }
            return true;
        }
        return false;
    }

    // matrix mathematical functions

    /* returns the N x M transpose: result(c, r) == mat(r, c) */
    template <typename T, size_t M, size_t N>
    matrix<T, N, M> transpose(const matrix<T, M, N> &mat)
    {
        matrix<T, N, M> result;
        for (size_t r = 0; r < M; ++r)
            for (size_t c = 0; c < N; ++c)
                result(c, r) = mat(r, c);
        return result;
    }

    // matrix arithmetic operators

    /* matrix product (M x N) * (N x P) -> (M x P):
       product(r, c) = sum over k of lhs(r, k) * rhs(k, c) */
    template <typename T, size_t M, size_t N, size_t P>
    matrix<T, M, P> operator*(const matrix<T, M, N> &lhs,
                              const matrix<T, N, P> &rhs)
    {
        matrix<T, M, P> product;
        for (size_t c = 0; c < P; ++c) {
            for (size_t r = 0; r < M; ++r) {
                T acc = T();
                for (size_t k = 0; k < N; ++k)
                    acc += lhs(r, k) * rhs(k, c);
                product(r, c) = acc;
            }
        }
        return product;
    }

    /* In-place multiplication, lhs = lhs * rhs.  Only offered when rhs is
       square (N x N), because then the product has lhs's dimensions.

       The product is accumulated into a temporary and copied back at the
       end.  Writing into lhs while sums are still being computed would
       feed already-overwritten elements of lhs into later terms of the
       same row.  It would also feed overwritten elements of rhs into later
       rows when both arguments are the same object (m *= m). */
    template <typename T, size_t M, size_t N>
    matrix<T, M, N> &operator*=(matrix<T, M, N> &lhs,
                                const matrix<T, N, N> &rhs)
    {
        matrix<T, M, N> product;
        for (size_t i = 0; i < M; i++)
            for (size_t j = 0; j < N; j++) {
                T sum = T();
                for (size_t k = 0; k < N; k++)
                    sum += lhs(i, k) * rhs(k, j);
                product(i, j) = sum;
            }
        lhs = product;
        return lhs;
    }

    /* multiply a matrix with a column vector:
       rv[i] = sum over k of lhs(i, k) * rhs[k] */
    template <typename T, size_t M, size_t N>
    vector<T, M> operator*(const matrix<T, M, N> &lhs,
                           const vector<T, N> &rhs)
    {
        vector<T, M> rv;
        for (size_t i = 0; i < M; i++) {
            // accumulate the whole dot product before storing it; assigning
            // each term directly to rv[i] would keep only the last one
            T sum = T();
            for (size_t k = 0; k < N; k++)
                sum += lhs(i, k) * rhs[k];
            rv[i] = sum;
        }
        return rv;
    }

    /* matrix addition: element-wise lhs += rhs, returns lhs */
    template <typename T, size_t M, size_t N>
    matrix<T, M, N> &operator+=(matrix<T, M, N> &lhs,
                                const matrix<T, M, N> &rhs)
    {
        T *ptr_a = lhs.ptr();
        // rhs is const, so ptr() yields const T*; it cannot bind to T*.
        const T *ptr_b = rhs.ptr();
        for (size_t i = 0; i < (M * N); i++)
            ptr_a[i] += ptr_b[i];
        return lhs;
    }

    /* element-wise sum of two equally sized matrices, built on operator+= */
    template <typename T, size_t M, size_t N>
    matrix<T, M, N> operator+(const matrix<T, M, N> &lhs,
                              const matrix<T, M, N> &rhs)
    {
        matrix<T, M, N> sum(lhs);
        sum += rhs;
        return sum;
    }

    /* matrix subtraction: element-wise lhs -= rhs, returns lhs */
    template <typename T, size_t M, size_t N>
    matrix<T, M, N> &operator-=(matrix<T, M, N> &lhs,
                                const matrix<T, M, N> &rhs)
    {
        T *ptr_a = lhs.ptr();
        // rhs is const, so ptr() yields const T*; it cannot bind to T*.
        const T *ptr_b = rhs.ptr();
        for (size_t i = 0; i < (M * N); i++)
            ptr_a[i] -= ptr_b[i];
        return lhs;
    }

    /* element-wise difference of two equally sized matrices, built on
       operator-= */
    template <typename T, size_t M, size_t N>
    matrix<T, M, N> operator-(const matrix<T, M, N> &lhs,
                              const matrix<T, M, N> &rhs)
    {
        matrix<T, M, N> difference(lhs);
        difference -= rhs;
        return difference;
    }

    /* create a smaller matrix by excluding a row and column: returns the
       (N-1) x (N-1) matrix left after deleting `row` and `column` from mat
       (the minor used by determinant() and adjugate()) */
    template <typename T, size_t N>
    matrix<T, N - 1, N - 1> exclude(const matrix<T, N, N> &mat,
                                    size_t row, size_t column)
    {
        assert(row < N && column < N);
        matrix<T, N - 1, N - 1> rv;
        size_t dst_col = 0;
        for (size_t src_col = 0; src_col < N; ++src_col) {
            if (src_col == column)
                continue;
            size_t dst_row = 0;
            for (size_t src_row = 0; src_row < N; ++src_row) {
                if (src_row == row)
                    continue;
                rv(dst_row, dst_col) = mat(src_row, src_col);
                ++dst_row;
            }
            ++dst_col;
        }
        return rv;
    }

    template <typename T>
    T exclude(const matrix<T, 2, 2> &mat, size_t row, size_t column)
    {
        assert(row < 2 && column < 2);
        return mat[1 - column][1 - row];
    }

    /* determinant of a square matrix.

       General case: Laplace (cofactor) expansion along column 0,
         det = sum_i (-1)^i * mat(i, 0) * det(mat without row i, column 0).
       The 2x2 and 3x3 overloads below are more specialized and end the
       recursion.  They are declared later, but because matrix lives in
       namespace neglect they are found by argument-dependent lookup when
       this template is instantiated. */
    template <typename T, size_t N>
    T determinant(const matrix<T, N, N> &mat)
    {
        T sum = T();
        for (size_t i = 0; i < N; i++) {
            // (-1)^i from the parity of i: no pow() call, exact for
            // integer T as well
            const T sign = (i % 2 == 0) ? T(1) : T(-1);
            sum += mat[0][i] * sign * determinant(exclude(mat, i, 0));
        }
        return sum;
    }

    /* 2x2 determinant: for [[a, b], [c, d]] this is a*d - b*c */
    template <typename T>
    T determinant(const matrix<T, 2, 2> &mat)
    {
        const T a = mat(0, 0), b = mat(0, 1);
        const T c = mat(1, 0), d = mat(1, 1);
        return a * d - b * c;
    }

    /* 3x3 determinant via the rule of Sarrus; rows are (a b c), (d e f),
       (g h i) */
    template <typename T>
    T determinant(const matrix<T, 3, 3> &mat)
    {
        const T a = mat(0, 0), b = mat(0, 1), c = mat(0, 2);
        const T d = mat(1, 0), e = mat(1, 1), f = mat(1, 2);
        const T g = mat(2, 0), h = mat(2, 1), i = mat(2, 2);
        return a * e * i + b * f * g + c * d * h
             - a * f * h - b * d * i - c * e * g;
    }

    /* return the adjugate matrix, i.e. the transpose of the cofactor matrix:
         adj(i, j) = (-1)^(i + j) * det(mat without row j and column i).
       Note the swapped indices passed to exclude(): using exclude(mat, i, j)
       yields the cofactor matrix itself, which differs from the adjugate
       unless the cofactor matrix happens to be symmetric. */
    template <typename T, size_t N>
    matrix<T, N, N> adjugate(const matrix<T, N, N> &mat)
    {
        matrix<T, N, N> rv;
        for (size_t i = 0; i < N; i++)
            for (size_t j = 0; j < N; j++) {
                // (-1)^(i + j) from the parity, without calling pow()
                const T sign = ((i + j) % 2 == 0) ? T(1) : T(-1);
                rv(i, j) = sign * determinant(exclude(mat, j, i));
            }
        return rv;
    }

    /* 2x2 adjugate: [[a, b], [c, d]] -> [[d, -b], [-c, a]] */
    template <typename T>
    matrix<T, 2, 2> adjugate(const matrix<T, 2, 2> &mat)
    {
        matrix<T, 2, 2> rv;
        rv(0, 0) =  mat(1, 1);
        rv(0, 1) = -mat(0, 1);
        rv(1, 0) = -mat(1, 0);
        rv(1, 1) =  mat(0, 0);
        return rv;
    }

    /* invert a matrix: adjugate(mat) divided element-wise by
         sum_k mat(k, 0) * adj(0, k),
       which is det(mat) (expansion along column 0) when adjugate() returns
       the true adjugate.
       NOTE: a singular matrix is not detected.  The division then yields
       inf/NaN for IEEE floating-point T and is undefined behaviour for
       integer T. */
    template <typename T, size_t N>
    matrix<T, N, N> inverse(const matrix<T, N, N> &mat)
    {
        matrix<T, N, N> inv = adjugate(mat);

        T det = T();
        for (size_t k = 0; k < N; ++k)
            det += mat(k, 0) * inv(0, k);

        for (size_t c = 0; c < N; ++c)
            for (size_t r = 0; r < N; ++r)
                inv(r, c) /= det;

        return inv;
    }

    /* creates a 4x4 scale matrix: diagonal (sx, sy, sz, 1), zero elsewhere */
    template <typename T>
    matrix<T, 4, 4> scale_matrix(T sx, T sy, T sz)
    {
        matrix<T, 4, 4> rv = matrix<T, 4, 4>::zero();
        rv(0, 0) = sx;
        rv(1, 1) = sy;
        rv(2, 2) = sz;
        rv(3, 3) = T(1);
        return rv;
    }

    /* creates a 4x4 translation matrix: the identity with (tx, ty, tz)
       stored in the first three rows of the last column */
    template <typename T>
    matrix<T, 4, 4> translation_matrix(T tx, T ty, T tz)
    {
        matrix<T, 4, 4> rv = matrix<T, 4, 4>::identity();
        rv(0, 3) = tx;
        rv(1, 3) = ty;
        rv(2, 3) = tz;
        return rv;
    }

    /* creates a 4x4 rotation matrix for a rotation of `angle` degrees about
       the axis (x, y, z), using the standard axis-angle formula (the same
       matrix glRotate builds), written column by column.  The axis is
       normalized here; a zero-length axis divides by zero. */
    template <typename T>
    matrix<T, 4, 4> rotation_matrix(T angle, T x, T y, T z)
    {
        T rad_angle = math::deg_to_rad(angle);
        T sin_a = math::sin(rad_angle), cos_a = math::cos(rad_angle);
        T length = math::sqrt(x * x + y * y + z * z);
        x /= length;
        y /= length;
        z /= length;
        T sx = sin_a * x, sy = sin_a * y, sz = sin_a * z;
        T ic = T(1) - cos_a;

        matrix<T, 4, 4> m;
        m[0] = vector<T, 4>(ic * x * x + cos_a,
                            ic * x * y + sz,
                            ic * x * z - sy,
                            T());
        // row 2 of column 1 and row 1 of column 2 carry the x-axis term
        // (+/- sin * x); they previously used sz, leaving sx unused.
        m[1] = vector<T, 4>(ic * x * y - sz,
                            ic * y * y + cos_a,
                            ic * y * z + sx,
                            T());
        m[2] = vector<T, 4>(ic * x * z + sy,
                            ic * y * z - sx,
                            ic * z * z + cos_a,
                            T());
        m[3] = vector<T, 4>(T(), T(), T(), T(1));
        return m;
    }
}

#endif
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.