Source

go-gsl / gsl / matrix.go

Full commit
// Copyright (C) 2010  The Go-GSL Authors.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program.  If not, see <http://www.gnu.org/licenses/>.

package gsl

/*
#include <gsl/gsl_matrix.h>
*/
import "C"

import (
	"strconv"
	"unsafe"
)

// Matrix is a dense matrix of float64 values laid out like GSL's gsl_matrix:
// element (i, j) is stored at Block[Data + i*Tda + j].
type Matrix struct {
	// Size1 and Size2 are the number of rows and columns.
	Size1, Size2 int

	// Tda is the row stride in elements. It may exceed Size2 when the
	// matrix is a view into a wider matrix.
	Tda int

	// Data is the index of element (0, 0) within Block (an offset, not a
	// pointer).
	Data int

	// Block is the backing storage, which views may share with their parent.
	Block []float64
}

// MatrixView is a Matrix that aliases another matrix's storage, as returned
// by Submatrix. It is a distinct defined type and so does not carry Matrix's
// methods; convert with (*Matrix)(v) to use them.
type MatrixView Matrix

// NewMatrix allocates a zero-filled n1×n2 matrix. Rows are stored
// contiguously with no padding (Tda == n2), starting at offset 0 of Block.
func NewMatrix(n1, n2 int) *Matrix {
	m := new(Matrix)
	m.Size1, m.Size2 = n1, n2
	m.Tda = n2
	m.Block = make([]float64, n1*n2)
	// m.Data keeps its zero value: element (0, 0) is Block[0].
	return m
}

// String renders m as a bracketed list of rows, for example
//
//	[[1, 2],
//	 [3, 4]]
//
// Each element is formatted with strconv's 'g' format at 6 significant
// digits. A matrix with no rows renders as "[]".
func (m *Matrix) String() string {
	// Append into one byte buffer. Concatenating with += in the nested loop
	// re-copied the whole prefix on every append, which is quadratic in the
	// number of elements.
	buf := []byte{'['}
	for i := 0; i < m.Size1; i++ {
		if i > 0 {
			buf = append(buf, ",\n "...)
		}
		buf = append(buf, '[')
		for j := 0; j < m.Size2; j++ {
			if j > 0 {
				buf = append(buf, ", "...)
			}
			buf = strconv.AppendFloat(buf, m.Get(i, j), 'g', 6, 64)
		}
		buf = append(buf, ']')
	}
	buf = append(buf, ']')
	return string(buf)
}

// toGSL returns a gsl_matrix header describing m, for passing m to GSL
// routines. The header aliases m.Block (nothing is copied). Its data field
// points at element (0, 0), i.e. Block[Data], and it is marked as
// non-owning (owner = 0).
//
// NOTE(review): the header is allocated in Go memory and itself holds Go
// pointers (data and block). The cgo pointer-passing rules forbid passing
// such memory to C, and the runtime's cgocheck may panic at the call site.
// Confirm the intended fix (e.g. runtime.Pinner or C-allocated headers).
func (m *Matrix) toGSL() *C.gsl_matrix {
	// Take an element address only when m has at least one element, so a
	// matrix with a zero dimension does not panic with index out of range.
	var elems *C.double
	if m.Size1 > 0 && m.Size2 > 0 {
		elems = (*C.double)(unsafe.Pointer(&m.Block[m.Data]))
	}

	// The block must point at the first float64 of the backing array. Use
	// &m.Block[0], not &m.Block: the latter is the address of the Go slice
	// header, so GSL would read its pointer/len/cap words as doubles.
	var base *C.double
	if len(m.Block) > 0 {
		base = (*C.double)(unsafe.Pointer(&m.Block[0]))
	}

	return &C.gsl_matrix{
		size1: C.size_t(m.Size1),
		size2: C.size_t(m.Size2),
		tda:   C.size_t(m.Tda),
		data:  elems,
		block: &C.gsl_block{
			size: C.size_t(len(m.Block)),
			data: base,
		},
		owner: 0,
	}
}

// Accessing matrix elements

// Get returns the element at row i, column j of m.
//
// It panics if (i, j) lies outside the Size1×Size2 matrix. The check is
// explicit because the flat index m.Data + i*m.Tda + j can still fall inside
// Block for an out-of-range j. That index would alias the next row (or a
// view's parent padding), so the slice bounds check alone would not catch it.
func (m *Matrix) Get(i, j int) float64 {
	if i < 0 || i >= m.Size1 || j < 0 || j >= m.Size2 {
		panic("gsl: matrix index out of range")
	}
	return m.Block[m.Data+i*m.Tda+j]
}

// Set stores x at row i, column j of m. When m is a view, the write is also
// visible in every matrix sharing its Block.
//
// It panics if (i, j) lies outside the Size1×Size2 matrix. Without the
// explicit check, an out-of-range j whose flat index still fits in Block
// would silently overwrite an element of a different row.
func (m *Matrix) Set(i, j int, x float64) {
	if i < 0 || i >= m.Size1 || j < 0 || j >= m.Size2 {
		panic("gsl: matrix index out of range")
	}
	m.Block[m.Data+i*m.Tda+j] = x
}

// Initializing matrix elements

// SetAll assigns x to every element of m, using gsl_matrix_set_all.
func (m *Matrix) SetAll(x float64) {
	hdr := m.toGSL()
	C.gsl_matrix_set_all(hdr, C.double(x))
}

// SetZero sets every element of m to zero, using gsl_matrix_set_zero.
func (m *Matrix) SetZero() {
	hdr := m.toGSL()
	C.gsl_matrix_set_zero(hdr)
}

// SetIdentity overwrites m with the identity pattern (ones on the diagonal,
// zeros elsewhere), using gsl_matrix_set_identity.
func (m *Matrix) SetIdentity() {
	hdr := m.toGSL()
	C.gsl_matrix_set_identity(hdr)
}

// Matrix views

// Submatrix returns an n1×n2 view of m whose element (0, 0) is m's element
// (k1, k2). The view shares m's backing storage (nothing is copied), so
// writes through either are visible in both. It keeps m's row stride (Tda).
//
// It panics if the requested window is not contained in m. Without this
// check an oversized k2+n2 does not fail; the view silently wraps into the
// following rows of m.
func (m *Matrix) Submatrix(k1, k2, n1, n2 int) *MatrixView {
	if k1 < 0 || k2 < 0 || n1 < 0 || n2 < 0 ||
		k1+n1 > m.Size1 || k2+n2 > m.Size2 {
		panic("gsl: Submatrix window out of range")
	}
	// Block is re-sliced to start at m's own element (0, 0), so the view's
	// Data offset is relative to that origin.
	return &MatrixView{
		Size1: n1,
		Size2: n2,
		Tda:   m.Tda,
		Data:  k1*m.Tda + k2,
		Block: m.Block[m.Data:],
	}
}

// Copying rows and columns

// GetRow copies row i of m into a newly allocated vector of length m.Size2.
// A non-zero status from gsl_matrix_get_row is returned as NewError(status),
// with a nil vector.
func (m *Matrix) GetRow(i int) (*Vector, error) {
	row := NewVector(m.Size2)
	if status := int(C.gsl_matrix_get_row(row.toGSL(), m.toGSL(), C.size_t(i))); status != 0 {
		return nil, NewError(status)
	}
	return row, nil
}

// GetCol copies column j of m into a newly allocated vector of length
// m.Size1. A non-zero status from gsl_matrix_get_col is returned as
// NewError(status), with a nil vector.
func (m *Matrix) GetCol(j int) (*Vector, error) {
	v := NewVector(m.Size1)
	// Must be gsl_matrix_get_col. Calling _get_row here copied row j instead,
	// and on non-square matrices it failed on the length mismatch.
	n := int(C.gsl_matrix_get_col(v.toGSL(), m.toGSL(), C.size_t(j)))
	if n != 0 {
		return nil, NewError(n)
	}
	return v, nil
}

// SetRow copies v into row i of m. The status from gsl_matrix_set_row is
// passed through errnoToError.
func (m *Matrix) SetRow(i int, v *Vector) error {
	dst := m.toGSL()
	src := v.toGSL()
	status := C.gsl_matrix_set_row(dst, C.size_t(i), src)
	return errnoToError(status)
}

// SetCol copies v into column j of m. The status from gsl_matrix_set_col is
// passed through errnoToError.
func (m *Matrix) SetCol(j int, v *Vector) error {
	dst := m.toGSL()
	src := v.toGSL()
	status := C.gsl_matrix_set_col(dst, C.size_t(j), src)
	return errnoToError(status)
}

// Matrix operations

// Add adds n to m element by element, storing the result in m
// (gsl_matrix_add). The GSL status is passed through errnoToError.
func (m *Matrix) Add(n *Matrix) error {
	dst, src := m.toGSL(), n.toGSL()
	status := C.gsl_matrix_add(dst, src)
	return errnoToError(status)
}

// Sub subtracts n from m element by element, storing the result in m
// (gsl_matrix_sub). The GSL status is passed through errnoToError.
func (m *Matrix) Sub(n *Matrix) error {
	dst, src := m.toGSL(), n.toGSL()
	status := C.gsl_matrix_sub(dst, src)
	return errnoToError(status)
}

// MulElements multiplies m by n element by element (not a matrix product),
// storing the result in m (gsl_matrix_mul_elements). The GSL status is
// passed through errnoToError.
func (m *Matrix) MulElements(n *Matrix) error {
	dst, src := m.toGSL(), n.toGSL()
	status := C.gsl_matrix_mul_elements(dst, src)
	return errnoToError(status)
}

// DivElements divides m by n element by element, storing the result in m
// (gsl_matrix_div_elements). The GSL status is passed through errnoToError.
func (m *Matrix) DivElements(n *Matrix) error {
	dst, src := m.toGSL(), n.toGSL()
	status := C.gsl_matrix_div_elements(dst, src)
	return errnoToError(status)
}

// Scale multiplies every element of m by x in place (gsl_matrix_scale).
// The GSL status is passed through errnoToError.
func (m *Matrix) Scale(x float64) error {
	hdr := m.toGSL()
	status := C.gsl_matrix_scale(hdr, C.double(x))
	return errnoToError(status)
}

// AddConstant adds x to every element of m in place
// (gsl_matrix_add_constant). The GSL status is passed through errnoToError.
func (m *Matrix) AddConstant(x float64) error {
	hdr := m.toGSL()
	status := C.gsl_matrix_add_constant(hdr, C.double(x))
	return errnoToError(status)
}

// Finding maximum and minimum elements of matrices

// Max returns the value reported by gsl_matrix_max for m, i.e. its largest
// element.
func (m *Matrix) Max() float64 {
	largest := C.gsl_matrix_max(m.toGSL())
	return float64(largest)
}

// Min returns the value reported by gsl_matrix_min for m, i.e. its smallest
// element.
func (m *Matrix) Min() float64 {
	smallest := C.gsl_matrix_min(m.toGSL())
	return float64(smallest)
}

// Minmax returns the smallest and largest elements of m, as computed in a
// single pass by gsl_matrix_minmax.
func (m *Matrix) Minmax() (min, max float64) {
	// Receive into C.double locals and convert, instead of reinterpreting
	// the float64 result addresses through unsafe.Pointer.
	var lo, hi C.double
	C.gsl_matrix_minmax(m.toGSL(), &lo, &hi)
	return float64(lo), float64(hi)
}

// MaxIndex returns the row and column of the element that
// gsl_matrix_max_index reports as the maximum of m.
func (m *Matrix) MaxIndex() (imax, jmax int) {
	var row, col C.size_t
	C.gsl_matrix_max_index(m.toGSL(), &row, &col)
	return int(row), int(col)
}

// MinIndex returns the row and column of the element that
// gsl_matrix_min_index reports as the minimum of m.
func (m *Matrix) MinIndex() (imin, jmin int) {
	var row, col C.size_t
	C.gsl_matrix_min_index(m.toGSL(), &row, &col)
	return int(row), int(col)
}

// MinmaxIndex returns the (row, column) positions of the minimum and maximum
// elements of m, as reported by a single gsl_matrix_minmax_index call.
func (m *Matrix) MinmaxIndex() (imin, jmin, imax, jmax int) {
	var minRow, minCol, maxRow, maxCol C.size_t
	C.gsl_matrix_minmax_index(m.toGSL(), &minRow, &minCol, &maxRow, &maxCol)
	return int(minRow), int(minCol), int(maxRow), int(maxCol)
}

// Matrix properties

// Isnull reports whether gsl_matrix_isnull returns non-zero for m, which
// per the GSL manual means every element of m is zero.
func (m *Matrix) Isnull() bool {
	flag := C.gsl_matrix_isnull(m.toGSL())
	return flag != 0
}

// Ispos reports whether gsl_matrix_ispos returns non-zero for m, which per
// the GSL manual means every element of m is strictly positive.
func (m *Matrix) Ispos() bool {
	flag := C.gsl_matrix_ispos(m.toGSL())
	return flag != 0
}

// Isneg reports whether gsl_matrix_isneg returns non-zero for m, which per
// the GSL manual means every element of m is strictly negative.
func (m *Matrix) Isneg() bool {
	flag := C.gsl_matrix_isneg(m.toGSL())
	return flag != 0
}

// Isnonneg reports whether gsl_matrix_isnonneg returns non-zero for m, which
// per the GSL manual means no element of m is negative.
func (m *Matrix) Isnonneg() bool {
	flag := C.gsl_matrix_isnonneg(m.toGSL())
	return flag != 0
}