Commits

Fazlul Shahriar committed d1b361b

add 1 level2 function

Comments (0)

Files changed (3)

 import "C"
 
 const (
-	RowMajor  = C.enum_CBLAS_ORDER(101)
-	ColMajor  = C.enum_CBLAS_ORDER(102)
+	rowMajor  = C.enum_CBLAS_ORDER(101)
+	colMajor  = C.enum_CBLAS_ORDER(102)
 	NoTrans   = C.enum_CBLAS_TRANSPOSE(111)
 	Trans     = C.enum_CBLAS_TRANSPOSE(112)
 	ConjTrans = C.enum_CBLAS_TRANSPOSE(113)
 	Stride int          // step-size for next element in Block
 	Data   []complex128 // the data
 }
+
+type Float32Matrix struct {
+	Rows int
+	Cols int
+	Lda  int // leading dimension
+	Data []float32
+}
+
+type Float64Matrix struct {
+	Rows int
+	Cols int
+	Lda  int // leading dimension
+	Data []float64
+}
+
+type Complex64Matrix struct {
+	Rows int
+	Cols int
+	Lda  int // leading dimension
+	Data []complex64
+}
+
+type Complex128Matrix struct {
+	Rows int
+	Cols int
+	Lda  int // leading dimension
+	Data []complex64
+}
 */
 import "C"
 
+import "unsafe"
+
+// Sgemv updates y: y <- alpha*transA(A)*x + beta*y.
+func Sgemv(transA uint32, alpha float32, A *Float32Matrix, x *Float32Vector,
+beta float32, y *Float32Vector) {
+	C.cblas_sgemv(rowMajor, transA, C.int(A.Rows), C.int(A.Cols), C.float(alpha),
+		(*C.float)(unsafe.Pointer(&A.Data[0])), C.int(A.Lda),
+		(*C.float)(unsafe.Pointer(&x.Data[0])), C.int(x.Stride), C.float(beta),
+		(*C.float)(unsafe.Pointer(&y.Data[0])), C.int(y.Stride))
+}
+
 /*
-void cblas_sgemv(const enum CBLAS_ORDER Order,
-                 const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
-                 const float alpha, const float *A, const int lda,
-                 const float *X, const int incX, const float beta,
-                 float *Y, const int incY);
 void cblas_sgbmv(const enum CBLAS_ORDER Order,
                  const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
                  const int KL, const int KU, const float alpha,
+// Copyright 2011 Fazlul Shahriar. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cblas
+
+import "testing"
+
+func F32Matrix(x []float32, cols int) *Float32Matrix {
+	return &Float32Matrix{Rows: len(x) / cols, Cols: cols, Lda: cols, Data: x}
+}
+
+func F32VectorEq(x, y *Float32Vector) bool {
+	if x.Size != y.Size {
+		return false
+	}
+	for i := 0; i < x.Size; i++ {
+		if x.Data[i*x.Stride] != y.Data[i*y.Stride] {
+			return false
+		}
+	}
+	return true
+}
+
+func TestSgemv(t *testing.T) {
+	x := F32Vector([]float32{1, 2, 3})
+	y := F32Vector([]float32{-5, 7, 4})
+	A := F32Matrix([]float32{
+		1, 0, 0,
+		0, 1, 0,
+		0, 0, 1,
+	},
+		3)
+	Sgemv(NoTrans, 1, A, x, 1, y)
+	y1 := F32Vector([]float32{-4, 9, 7})
+	if !F32VectorEq(y, y1) {
+		t.Errorf("Invalid update of y %v; expected %v\n", y, y1)
+	}
+}