Commits

Jure Žbontar committed 6d6fe73

Initial import.

Comments (0)

Files changed (1)

+package main
+
+import (
+	"bytes"
+	"fmt"
+)
+
+type Matrix struct {
+	data         []float64
+	NRows, NCols int
+}
+
+func (a *Matrix) Get(i, j int) float64 {
+	return a.data[i*a.NCols+j]
+}
+
+func (a *Matrix) String() string {
+	var buf bytes.Buffer
+
+	for i := 0; i < a.NRows; i++ {
+		for j := 0; j < a.NCols; j++ {
+			buf.WriteString(fmt.Sprintf("%e ", a.Get(i, j)))
+		}
+		buf.WriteString("\n")
+	}
+	return buf.String()
+}
+
+func AssertRowVector(a *Matrix) {
+	if a.NRows != 1 {
+		panic("Expeced a row vector.")
+	}
+}
+
+func AssertColVector(a *Matrix) {
+	if a.NCols != 1 {
+		panic("Expeced a column vector.")
+	}
+}
+
+func AssertDotSize(a, b *Matrix) {
+	if a.NCols != b.NRows {
+		panic(fmt.Sprintf("Size mismatch. Trying to multiply %dx%d and %dx%d.", 
+			a.NRows, a.NCols, b.NRows, b.NCols))
+	}
+}
+
+func Scal(alpha float64, a *Matrix) *Matrix {
+	b := ZerosLike(a)
+	for i := range b.data {
+		b.data[i] = a.data[i] * alpha
+	}
+	return b
+}
+
+func Dot(a *Matrix, b *Matrix) float64 {
+	AssertRowVector(a)
+	AssertColVector(b)
+	AssertDotSize(a, b)
+
+	acc := 0.0
+	for i := range a.data {
+		acc += a.data[i] * b.data[i]
+	}
+
+	return acc
+}
+
+func Zeros(nrows, ncols int) *Matrix {
+	return &Matrix{make([]float64, nrows*ncols), nrows, ncols}
+}
+
+func ZerosLike(a *Matrix) *Matrix {
+	return Zeros(a.NRows, a.NCols)
+}
+
+func Ones(nrows, ncols int) *Matrix {
+	a := Zeros(nrows, ncols)
+	for i := range a.data {
+		a.data[i] = 1
+	}
+	return a
+}
+
+func main() {
+	a := Scal(2.3, Ones(1, 4))
+	b := Scal(4.1, Ones(4, 1))
+
+	fmt.Println(Dot(a, b))
+}