Markus Mottl avatar Markus Mottl committed 13ed9ac

Improved speed of Octave test suite

Comments (0)

Files changed (1)

 function res = kf(x, y, a, b)
   [dim, n1] = size(x);
   n2 = size(y, 2);
-  repmat(sum(y' .* y', 2), 1, n1);
   r2 = repmat(sum(x' .* x', 2), 1, n2) - 2 * x' * y + repmat(sum(y' .* y', 2)', n1, 1);
   res = eval_rbf2(r2, a, b);
   [dim, N] = size(res);
   res = kf(x, y, log_sf2, inv_ell2_e);
 end
 
+function res = kf_diag(x, a, b)
+  r2 = zeros(size(x, 2), 1);
+  res = eval_rbf2(r2, a, b);
+end
+
+function res = k_diag(x)
+  global log_sf2 inv_ell2;
+  res = kf_diag(x, log_sf2, inv_ell2);
+end
+
+function res = k_diag_e(x)
+  global log_sf2 log_sf2_e inv_ell2 inv_ell2_e epsilon;
+  res = kf_diag(x, log_sf2, inv_ell2_e);
+end
+
 
 %%%%%%%%%%%%%%%%%%%% Covariance matrices %%%%%%%%%%%%%%%%%%%%
 
 Km = k(inducing_points, inducing_points);
-
 Km_e = k_e(inducing_points, inducing_points);
 dKm = (Km_e - Km) / epsilon;
 
 Knm_e = k_e(inducing_points, inputs);
 dKnm = (Knm_e - Knm) / epsilon;
 
-Kn = k(inputs, inputs);
-Kn_e = k_e(inputs, inputs);
-dKn = (Kn_e - Kn) / epsilon;
+Kn_diag = k_diag(inputs);
+Kn_e_diag = k_diag_e(inputs);
+dKn_diag = (Kn_e_diag - Kn_diag) / epsilon;
 
 
 %%%%%%%%%%%%%%%%%%%% Main definitions %%%%%%%%%%%%%%%%%%%%
 
 cholKm = chol(Km);
 V = Knm / cholKm;
-Qn = V * V';
 
-lam = diag(diag(Kn - Qn));
-lam_sigma2 = lam + sigma2 * eye(N);
-inv_lam_sigma2 = inv(lam_sigma2);
-inv_lam_sigma = sqrt(inv_lam_sigma2);
+r = Kn_diag - sum(V .^ 2, 2);
+s = r + sigma2;
+is = ones(size(s, 1), 1) ./ s;
+is_2 = sqrt(is);
 
-Knm_ = inv_lam_sigma * Knm;
+inv_lam_sigma = repmat(is_2, 1, size(Knm, 2));
+
+Knm_ = inv_lam_sigma .* Knm;
 
 [Q, R] = qr([Knm_; chol(Km)], 1);
 SF = diag(sign(diag(R)));
 
 B = Km + Knm_' * Knm_;
 
-r = diag(lam);
-s = diag(lam_sigma2);
-is = diag(inv_lam_sigma2);
-
-
 %%%%%%%%%%%%%%%%%%%% Standard %%%%%%%%%%%%%%%%%%%%
 
 %%%%%% Log evidence
     2*sum(log(diag(R))) - 2*sum(log(diag(cholKm))) + sum(log(s)) ...
     + N * log(2*pi))
 
-S = inv_lam_sigma * Q / R';
+S = inv_lam_sigma .* Q / R';
 t = S'*y;
-e = y - Knm*t;
-u = is .* e;
+u = is .* (y - Knm*t);
 l2 = -0.5*(u'*y)
 
 l = l1 + l2
 
 U = V / cholKm';
 
-v1 = is .* (ones(size(Q, 1), 1) - diag(Q * Q'));
-U1 = diag(sqrt(v1)) * U;
+v1 = is .* (ones(size(Q, 1), 1) - sum(Q .^ 2, 2));
+U1 = repmat(sqrt(v1), 1, size(U, 2)) .* U;
 W1 = T - U1'*U1;
-X1 = S - diag(v1)*U;
+X1 = S - repmat(v1, 1, size(U, 2)) .* U;
 
-dl1 = -0.5*(v1' * diag(dKn) - trace(W1'*dKm)) - trace(X1'*dKnm)
+dl1 = -0.5*(v1' * dKn_diag - trace(W1'*dKm)) - trace(X1'*dKnm)
 
 v2 = u .* u;
-U2 = diag(u)*U;
+U2 = repmat(u, 1, size(U, 2)) .* U;
 W2 = t*t' - U2'*U2;
-X2 = u*t' - diag(v2)*U;
+X2 = u*t' - repmat(v2, 1, size(U, 2)) .* U;
 
-dl2 = 0.5*(v2' * diag(dKn) - trace(W2'*dKm)) + trace(X2'*dKnm)
+dl2 = 0.5*(v2' * dKn_diag - trace(W2'*dKm)) + trace(X2'*dKnm)
 
 dl = dl1 + dl2
 
 
 %%%%%% Log evidence derivative
 
-vv1 = is .* (2*ones(size(Q, 1), 1) - is .* r - diag(Q * Q'));
-vU1 = diag(sqrt(vv1)) * U;
+vv1 = is .* (2*ones(size(Q, 1), 1) - is .* r - sum(Q .* 2, 2));
+vU1 = repmat(sqrt(vv1), 1, size(U, 2)) .* U;
 vW1 = T - vU1'*vU1;
-vX1 = S - diag(vv1) * U;
+vX1 = S - repmat(vv1, 1, size(U, 2)) .* U;
 
-vdl1 = -0.5*(vv1' * diag(dKn) - trace(vW1'*dKm)) - trace(vX1'*dKnm)
+vdl1 = -0.5*(vv1' * dKn_diag - trace(vW1'*dKm)) - trace(vX1'*dKnm)
 vdl = vdl1 + dl2
 
 
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.