Commits

Jure Žbontar  committed b395085

Link BLAS.

  • Participants
  • Parent commits 6d6fe73

Comments (0)

Files changed (3)

+package numeric
+
+/*
+#cgo LDFLAGS: -lblas
+extern void dscal_(int *n, double *alpha, double *x, int *incx);
+extern double dnrm2_(int *n, double *x, int *incx);
+extern double dasum_(int *n, double *x, int *incx);
+extern void dcopy_(int *n, double *x, int *incx, double *y, int *incy);
+*/
+import "C"
+
+var one = C.int(1)
+
+func Scal(alpha float64, A *Matrix) *Matrix {
+	C.dscal_(&A.c_size, (*C.double)(&alpha), A.c_data, &one)
+	return A
+}
+
+func Nrm2(A *Matrix) float64 {
+	return float64(C.dnrm2_(&A.c_size, A.c_data, &one))
+}
+
+func Asum(A *Matrix) float64 {
+	return float64(C.dasum_(&A.c_size, A.c_data, &one))
+}
+
+func blasCopy(A, B *Matrix) {
+	AssertDimensions(A, B.NRows, B.NCols)
+
+	C.dcopy_(&A.c_size, A.c_data, &one, B.c_data, &one)
+}
-package main
+package numeric
 
 import (
+	"C"
 	"bytes"
 	"fmt"
+	"unsafe"
 )
 
 type Matrix struct {
-	data         []float64
-	NRows, NCols int
+	data                 []float64
+	NRows, NCols, s0, s1 int
+
+	c_size C.int
+	c_data *C.double
 }
 
-func (a *Matrix) Get(i, j int) float64 {
-	return a.data[i*a.NCols+j]
+func New(data []float64, nrows, ncols int) *Matrix {
+	if nrows*ncols != len(data) {
+		panic("Size mismatch.")
+	}
+
+	return &Matrix{
+		data: data, 
+		NRows: nrows, 
+		NCols: ncols,
+		s0: ncols, 
+		s1: 1,
+		c_size: C.int(nrows * ncols), 
+		c_data: (*C.double)(unsafe.Pointer(&data[0])),
+	}
 }
 
-func (a *Matrix) String() string {
+func (A *Matrix) Copy() *Matrix {
+	B := *A
+	B.data = make([]float64, len(A.data))
+	B.c_data = (*C.double)(unsafe.Pointer(&B.data[0]))
+
+	blasCopy(A, &B)
+	return &B
+}
+
+func (A *Matrix) View() *Matrix {
+	B := *A
+	return &B
+}
+
+func (A *Matrix) Get(i, j int) float64 {
+	return A.data[i*A.s0+j*A.s1]
+}
+
+func (A *Matrix) String() string {
 	var buf bytes.Buffer
 
-	for i := 0; i < a.NRows; i++ {
-		for j := 0; j < a.NCols; j++ {
-			buf.WriteString(fmt.Sprintf("%e ", a.Get(i, j)))
+	for i := 0; i < A.NRows; i++ {
+		for j := 0; j < A.NCols; j++ {
+			buf.WriteString(fmt.Sprintf("%e ", A.Get(i, j)))
 		}
 		buf.WriteString("\n")
 	}
 	return buf.String()
 }
 
-func AssertRowVector(a *Matrix) {
-	if a.NRows != 1 {
-		panic("Expeced a row vector.")
+func (A *Matrix) Transpose() *Matrix {
+	B := A.View()
+	B.NRows, B.NCols = A.NCols, A.NRows
+	B.s0, B.s1 = A.s1, A.s0
+	return B
+}
+
+func AssertDimensions(A *Matrix, nrows, ncols int) {
+	if nrows > 0 && A.NRows != nrows {
+		panic(fmt.Sprintf("Expected %d rows, got %d.", nrows, A.NRows))
+	}
+
+	if ncols > 0 && A.NCols != ncols {
+		panic(fmt.Sprintf("Expected %d columns, got %d.", ncols, A.NCols))
 	}
 }
 
-func AssertColVector(a *Matrix) {
-	if a.NCols != 1 {
-		panic("Expeced a column vector.")
+func Addc(A *Matrix, c float64) *Matrix {
+	for i := range A.data {
+		A.data[i] += c
 	}
+	return A
 }
 
-func AssertDotSize(a, b *Matrix) {
-	if a.NCols != b.NRows {
-		panic(fmt.Sprintf("Size mismatch. Trying to multiply %dx%d and %dx%d.", 
-			a.NRows, a.NCols, b.NRows, b.NCols))
+func Add(A *Matrix, B *Matrix) *Matrix {
+	AssertDimensions(A, B.NRows, B.NCols)
+
+	for i := range A.data {
+		A.data[i] += B.data[i]
 	}
+	return A
 }
 
-func Scal(alpha float64, a *Matrix) *Matrix {
-	b := ZerosLike(a)
-	for i := range b.data {
-		b.data[i] = a.data[i] * alpha
+func Mul(A *Matrix, B *Matrix) *Matrix {
+	AssertDimensions(A, B.NRows, B.NCols)
+
+	for i := range A.data {
+		A.data[i] *= B.data[i]
 	}
-	return b
+	return A
 }
 
-func Dot(a *Matrix, b *Matrix) float64 {
-	AssertRowVector(a)
-	AssertColVector(b)
-	AssertDotSize(a, b)
+func Dot(A *Matrix, B *Matrix) float64 {
+	AssertDimensions(A, 1, B.NRows)
+	AssertDimensions(B, 0, 1)
 
 	acc := 0.0
-	for i := range a.data {
-		acc += a.data[i] * b.data[i]
+	for i := range A.data {
+		acc += A.data[i] * B.data[i]
 	}
 
 	return acc
 }
 
 func Zeros(nrows, ncols int) *Matrix {
-	return &Matrix{make([]float64, nrows*ncols), nrows, ncols}
+	return New(make([]float64, nrows*ncols), nrows, ncols)
 }
 
-func ZerosLike(a *Matrix) *Matrix {
-	return Zeros(a.NRows, a.NCols)
+func ZerosLike(A *Matrix) *Matrix {
+	return Zeros(A.NRows, A.NCols)
 }
 
 func Ones(nrows, ncols int) *Matrix {
-	a := Zeros(nrows, ncols)
-	for i := range a.data {
-		a.data[i] = 1
+	A := Zeros(nrows, ncols)
+	for i := range A.data {
+		A.data[i] = 1
 	}
-	return a
+	return A
 }
-
-func main() {
-	a := Scal(2.3, Ones(1, 4))
-	b := Scal(4.1, Ones(4, 1))
-
-	fmt.Println(Dot(a, b))
-}

File matrix_test.go

+package numeric
+
+import "testing"
+
+func TestMisc(t *testing.T) {
+	A := New([]float64{1, 5, 2, 4, 3}, 1, 5)
+	println(A.Copy().String())
+}