Markus Mottl avatar Markus Mottl committed 77ee5e2

Added Mat.prod_diag and Mat.prod_trace

Comments (0)

Files changed (5)

+2009-04-28:  Added new functions:
+
+               * Mat.prod_diag
+               * Mat.prod_trace
+
 2009-04-25:  Major API changes:
 
              Removed function:
 name="lacaml"
-version="5.0.1"
+version="5.0.2"
 description="LACAML - BLAS/LAPACK-interface for OCaml"
 
 requires="lacaml.core"
   ignore (get_dim_vec loc alphas_str ofs 1 alphas n_str (Some n));
   direct_scal_cols ~m ~n ~ofs ~alphas ~ar ~ac ~a
 
-external direct_axpy_mat :
+external direct_mat_axpy :
   m : int ->
   n : int ->
   alpha : num_type ->
   yr : int ->
   yc : int ->
   y : mat ->
-  unit = "lacaml_NPRECaxpy_mat_stub_bc" "lacaml_NPRECaxpy_mat_stub"
+  unit = "lacaml_NPRECmat_axpy_stub_bc" "lacaml_NPRECmat_axpy_stub"
 
 let axpy ?m ?n ?(alpha = one) ?(xr = 1) ?(xc = 1) ~x ?(yr = 1) ?(yc = 1) y =
   let loc = "Lacaml.Impl.NPREC.Mat.axpy" in
   let m = get_dim1_mat loc x_str x xr m_str m in
   let n = get_dim2_mat loc x_str x xc n_str n in
   check_dim_mat loc y_str yr yc y m n;
-  direct_axpy_mat ~m ~n ~alpha ~xr ~xc ~x ~yr ~yc ~y
+  direct_mat_axpy ~m ~n ~alpha ~xr ~xc ~x ~yr ~yc ~y
+
+let vec_create n = Array1.create prec fortran_layout n
+
+external direct_prod_diag :
+  transa : char ->
+  n : int ->
+  k : int ->
+  ar : int ->
+  ac : int ->
+  a : mat ->
+  br : int ->
+  bc : int ->
+  b : mat ->
+  ofsy : int ->
+  y : vec ->
+  alpha : num_type ->
+  beta : num_type ->
+  unit = "lacaml_NPRECprod_diag_stub_bc" "lacaml_NPRECprod_diag_stub"
+
+let prod_diag ?n ?k ?(beta = zero) ?(ofsy = 1) ?y
+  ?(transa = `N) ?(alpha = one) ?(ar = 1) ?(ac = 1) a ?(br = 1) ?(bc = 1) b =
+  let loc = "Lacaml.Impl.NPREC.Mat.prod_diag" in
+  let n = get_rows_mat_tr loc a_str a ar ac transa n_str n in
+  let k = get_cols_mat_tr loc a_str a ar ac transa k_str k in
+  check_dim_mat loc b_str br bc b k n;
+  let transa = get_trans_char transa in
+  let y = get_vec loc y_str y ofsy 1 n vec_create in
+  direct_prod_diag ~transa ~n ~k ~ar ~ac ~a ~br ~bc ~b ~ofsy ~y ~alpha ~beta;
+  y
+
+external direct_prod_trace :
+  transa : char ->
+  n : int ->
+  k : int ->
+  ar : int ->
+  ac : int ->
+  a : mat ->
+  br : int ->
+  bc : int ->
+  b : mat ->
+  num_type = "lacaml_NPRECprod_trace_stub_bc" "lacaml_NPRECprod_trace_stub"
+
+let prod_trace ?n ?k ?(transa = `N) ?(ar = 1) ?(ac = 1) a
+  ?(br = 1) ?(bc = 1) b =
+  let loc = "Lacaml.Impl.NPREC.Mat.prod_trace" in
+  let n = get_rows_mat_tr loc a_str a ar ac transa n_str n in
+  let k = get_cols_mat_tr loc a_str a ar ac transa k_str k in
+  check_dim_mat loc b_str br bc b k n;
+  let transa = get_trans_char transa in
+  direct_prod_trace ~transa ~n ~k ~ar ~ac ~a ~br ~bc ~b
+
+
+(* Iterators over matrices *)
 
 let map f ?m ?n ?(br = 1) ?(bc = 1) ?b ?(ar = 1) ?(ac = 1) a =
   let loc = "Lacaml.Impl.NPREC.Mat.map" in
 *)
 
 
-(** {6 Arithmetic operations} *)
+(** {6 Arithmetic and other matrix operations} *)
 
 val scal :
   ?m : int -> ?n : int -> num_type -> ?ar : int -> ?ac : int -> mat -> unit
 (** [axpy ?m ?n ?alpha ?xr ?xc ~x ?yr ?yc y] BLAS [axpy] function for
     matrices. *)
 
+val prod_diag :
+  ?n : int ->
+  ?k : int ->
+  ?beta : num_type ->
+  ?ofsy : int ->
+  ?y : vec ->
+  ?transa : trans3 ->
+  ?alpha : num_type ->
+  ?ar : int ->
+  ?ac : int ->
+  mat ->
+  ?br : int ->
+  ?bc : int ->
+  mat ->
+  vec
+(** [prod_diag ?n ?k ?beta ?ofsy ?y ?transa ?alpha ?ar ?ac a ?br ?bc b]
+    computes the diagonal of the product of the (sub-)matrices [a]
+    and [b] (taking into account potential transposing), multiplying
+    it with [alpha] and adding [beta] times [y], storing the result in
+    [y] starting at the specified offset.  [n] elements of the diagonal
+    will be computed, and [k] elements of the matrices will be part of
+    the dot product associated with each diagonal element.
+
+    @param n default = number of rows of [a] (or tr [a]) and
+                       number of columns of [b]
+    @param k default = number of columns of [a] (or tr [a]) and
+                       number of rows of [b]
+    @param beta default = [0]
+    @param ofsy default = [1]
+    @param y default = fresh vector of size [n + ofsy - 1]
+    @param transa default = [`N]
+    @param alpha default = [1]
+    @param ar default = [1]
+    @param ac default = [1]
+    @param br default = [1]
+    @param bc default = [1]
+*)
+
+val prod_trace :
+  ?n : int ->
+  ?k : int ->
+  ?transa : trans3 ->
+  ?ar : int ->
+  ?ac : int ->
+  mat ->
+  ?br : int ->
+  ?bc : int ->
+  mat ->
+  num_type
+(** [prod_trace ?n ?k ?transa ?ar ?ac a ?br ?bc b] computes the trace
+    of the product of the (sub-)matrices [a] and [b] (taking into account
+    potential transposing) [n] elements of the diagonal will be computed,
+    and [k] elements of the matrices will be part of the dot product
+    associated with each diagonal element.
+
+    @param n default = number of rows of [a] (or tr [a]) and
+                       number of columns of [b]
+    @param k default = number of columns of [a] (or tr [a]) and
+                       number of rows of [b]
+    @param transa default = [`N]
+    @param ar default = [1]
+    @param ac default = [1]
+    @param br default = [1]
+    @param bc default = [1]
+*)
+
 
 (** {6 Iterators over matrices} *)
 
 
 #include <string.h>
 #include "lacaml_macros.h"
+#include "utils_c.h"
 #include "f2c.h"
 
-static integer ONE = 1;
+static integer integer_one = 1;
+static char transb_N = 'N';
+static NUMBER number_one = NUMBER_ONE;
+
+
+/* scal */
+
+extern void FUN(scal)(
+  integer *N,
+  NUMBER *ALPHA,
+  NUMBER *X, integer *INCX);
 
 CAMLprim value LFUN(scal_mat_stub)(
   value vM, value vN,
   caml_enter_blocking_section();
     if (rows_A == M) {
       integer MN = M * N;
-      FUN(scal)(&MN, pALPHA, A_data, &ONE);
+      FUN(scal)(&MN, pALPHA, A_data, &integer_one);
     } else {
       NUMBER *A_src = A_data + rows_A * (N - 1);
       while (A_src >= A_data) {
-        FUN(scal)(&M, pALPHA, A_src, &ONE);
+        FUN(scal)(&M, pALPHA, A_src, &integer_one);
         A_src -= rows_A;
       }
     }
     argv[0], argv[1], argv[2], argv[3], argv[4], argv[5]);
 }
 
+
+/* scal_cols */
+
 CAMLprim value LFUN(scal_cols_stub)(
   value vM, value vN,
   value vOFSALPHAs,
   VEC_PARAMS(ALPHAs);
   MAT_PARAMS(A);
 
+  NUMBER *A_src = A_data + rows_A * (N - 1);
+  NUMBER *ALPHAs_src = ALPHAs_data + (N - 1);
+
   caml_enter_blocking_section();
-    NUMBER *A_src = A_data + rows_A * (N - 1);
-    NUMBER *ALPHAs_src = ALPHAs_data + (N - 1);
     while (A_src >= A_data) {
-      FUN(scal)(&M, ALPHAs_src, A_src, &ONE);
+      FUN(scal)(&M, ALPHAs_src, A_src, &integer_one);
       A_src -= rows_A;
       ALPHAs_src--;
     }
     argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6]);
 }
 
-CAMLprim value LFUN(axpy_mat_stub)(
+
+/* mat_axpy */
+
+extern void FUN(axpy)(
+  integer *N,
+  NUMBER *ALPHA,
+  NUMBER *X, integer *INCX,
+  NUMBER *Y, integer *INCY);
+
+CAMLprim value LFUN(mat_axpy_stub)(
   value vM, value vN,
   value vALPHA,
   value vXR, value vXC, value vX,
   caml_enter_blocking_section();
     if (rows_X == M && rows_Y == M) {
       integer MN = M * N;
-      FUN(axpy)(&MN, pALPHA, X_data, &ONE, Y_data, &ONE);
+      FUN(axpy)(&MN, pALPHA, X_data, &integer_one, Y_data, &integer_one);
     } else {
       NUMBER *X_src = X_data + rows_X * (N - 1);
       NUMBER *Y_dst = Y_data + rows_Y * (N - 1);
       while (X_src >= X_data) {
-        FUN(axpy)(&M, pALPHA, X_src, &ONE, Y_dst, &ONE);
+        FUN(axpy)(&M, pALPHA, X_src, &integer_one, Y_dst, &integer_one);
         X_src -= rows_X;
         Y_dst -= rows_Y;
       }
   CAMLreturn(Val_unit);
 }
 
-CAMLprim value LFUN(axpy_mat_stub_bc)(value *argv, int argn)
+CAMLprim value LFUN(mat_axpy_stub_bc)(value *argv, int argn)
 {
-  return LFUN(axpy_mat_stub)(
+  return LFUN(mat_axpy_stub)(
     argv[0], argv[1], argv[2], argv[3], argv[4],
     argv[5], argv[6], argv[7], argv[8]);
 }
+
+
+/* prod_diag */
+
+extern void FUN(gemm)(
+  char *TRANSA, char *TRANSB,
+  integer *M, integer *N, integer *K,
+  NUMBER *ALPHA,
+  NUMBER *A, integer *LDA,
+  NUMBER *B, integer *LDB,
+  NUMBER *BETA,
+  NUMBER *C, integer *LDC);
+
+CAMLprim value LFUN(prod_diag_stub)(
+  value vTRANSA,
+  value vN, value vK,
+  value vAR, value vAC, value vA,
+  value vBR, value vBC, value vB,
+  value vOFSY,
+  value vY,
+  value vALPHA,
+  value vBETA
+  )
+{
+  CAMLparam3(vA, vB, vY);
+
+  integer GET_INT(N), GET_INT(K);
+  char GET_INT(TRANSA);
+
+  CREATE_NUMBERP(ALPHA);
+  CREATE_NUMBERP(BETA);
+
+  MAT_PARAMS(A);
+  MAT_PARAMS(B);
+  VEC_PARAMS(Y);
+
+  int incr_A = (TRANSA == 'N') ? 1 : rows_A;
+
+  INIT_NUMBER(ALPHA);
+  INIT_NUMBER(BETA);
+
+  caml_enter_blocking_section();  /* Allow other threads */
+  while (N--) {
+    /* TODO: quite inefficient for small K (> factor 2 for ten elements).
+       Optimize by essentially reimplementing gemm, possibly using "dot"
+       at each step, but hoisting all initializations and checks out of
+       the loop. */
+    FUN(gemm)(
+      &TRANSA, &transb_N,
+      &integer_one, &integer_one, &K,
+      pALPHA,
+      A_data, &rows_A,
+      B_data, &rows_B,
+      pBETA,
+      Y_data, &integer_one);
+    A_data += incr_A;
+    B_data += rows_B;
+    Y_data++;
+  }
+  caml_leave_blocking_section();  /* Disallow other threads */
+
+  CAMLreturn(Val_unit);
+}
+
+CAMLprim value LFUN(prod_diag_stub_bc)(value *argv, int argn)
+{
+  return LFUN(prod_diag_stub)(
+    argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6],
+    argv[7], argv[8], argv[9], argv[10], argv[11], argv[12]);
+}
+
+/* prod_trace */
+
+CAMLprim value LFUN(prod_trace_stub)(
+  value vTRANSA,
+  value vN, value vK,
+  value vAR, value vAC, value vA,
+  value vBR, value vBC, value vB)
+{
+  CAMLparam2(vA, vB);
+
+  integer GET_INT(N), GET_INT(K);
+  char GET_INT(TRANSA);
+
+  MAT_PARAMS(A);
+  MAT_PARAMS(B);
+
+  int incr_A = (TRANSA == 'N') ? 1 : rows_A;
+
+  NUMBER res = NUMBER_ZERO;
+
+  caml_enter_blocking_section();  /* Allow other threads */
+  while (N--) {
+    /* TODO: quite inefficient for small K (> factor 2 for ten elements).
+       Optimize by essentially reimplementing gemm, possibly using "dot"
+       at each step, but hoisting all initializations and checks out of
+       the loop. */
+    FUN(gemm)(
+      &TRANSA, &transb_N,
+      &integer_one, &integer_one, &K,
+      &number_one,
+      A_data, &rows_A,
+      B_data, &rows_B,
+      &number_one,
+      &res, &integer_one);
+    A_data += incr_A;
+    B_data += rows_B;
+  }
+  caml_leave_blocking_section();  /* Disallow other threads */
+
+  CAMLreturn(COPY_NUMBER(res));
+}
+
+CAMLprim value LFUN(prod_trace_stub_bc)(value *argv, int argn)
+{
+  return LFUN(prod_trace_stub)(
+    argv[0], argv[1], argv[2], argv[3], argv[4],
+    argv[5], argv[6], argv[7], argv[8]);
+}
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.