
Gradient Descent for Optimizing the Huber Loss

Created by Nathan Jacobs
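
function [v, G] = huberobj(w, X, Y, r)
%HUBEROBJ Huber loss objective and gradient for a linear model
%
%   [v, G] = HUBEROBJ(w, X, Y, r) evaluates the Huber loss of the residuals
%   e = X*w - Y with threshold r > 0. The goal is for X*w to be close to Y.
%
%   For a scalar residual e, the Huber loss is
%
%           v = (1/2) * e^2,            when |e| < r
%             = r * |e| - r^2 / 2,      when |e| >= r
%
%   If e is a vector, then v is the sum of the loss values over all
%   components. The derivative of the loss with respect to e is
%
%           g = e,                      when |e| < r
%             = r * sign(e),            when |e| >= r
%
%   and, by the chain rule, the gradient with respect to w is G = X' * g.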

% Created by Dahua Lin, on Jan 15, 2012 (http://code.google.com/p/smitoolbox/)
% Modified by Nathan Jacobs on Nov 2, 2015

Z = X*w;
E = Z - Y;

Ea = abs(E);
Eb = min(Ea, r);

v = 0.5 * Eb .* (2 * Ea - Eb);
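% The line above is a compact form of the piecewise definition. With
% a = |e| and b = min(a, r):
%   if a <  r, then b = a and 0.5*b*(2*a - b) = a^2 / 2
%   if a >= r, then b = r and 0.5*b*(2*a - b) = r*a - r^2 / 2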

% figure(1); clf;
% plot(E, v, '.');
% pause(.01)

if size(v, 1) > 1
    v = sum(v, 1);
end

G = Eb .* sign(E);   % dv/de per residual: e if |e| < r, r*sign(e) otherwise
G = X'*G;            % chain rule: gradient with respect to w


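% ----------------------------------------------------------------------
% Save the huberobj function as huberobj.m and run the demo below as a
% separate script: as written in one file, MATLAB would treat these lines
% as part of the function body.
% ----------------------------------------------------------------------

%% make some fake data

% x is N-by-d doubles, about 1.6 GB at these sizes; shrink d and N for a
% quick test.
d = 4000;
N = 50000;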

w_gt = randn(d,1);
x = randn(N,d);
% Noise-free targets: the loss is minimized (at zero) by w = w_gt.
y = x*w_gt;

%% estimate w by minimizing the Huber loss with fminunc

% if the residual is less than this, the loss is quadratic
% if the residual is more than this, the loss is linear
quadratic_region = 3; 

localCost = @(w) huberobj(w,x,y,quadratic_region);
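opt = optimset('fminunc');
opt.Display = 'iter';
opt.GradObj = 'on';   % huberobj returns the gradient as its second output
% HessPattern is only used by fminunc's trust-region algorithm, so select it
% explicitly; newer MATLAB releases default to 'quasi-newton', which ignores
% HessPattern. The true Hessian, X'*diag(|e| < r)*X, is dense, so the
% diagonal pattern is an approximation that keeps each finite-difference
% Hessian estimate cheap when d = 4000.
opt.Algorithm = 'trust-region';
opt.HessPattern = sparse(eye(d));
w_est = fminunc(localCost, randn(size(w_gt)), opt);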
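Because fminunc trusts whatever gradient huberobj returns, a quick finite-difference check is worthwhile before a long run. The sketch below is self-contained: huber_v and huber_g are hypothetical anonymous-function stand-ins that mirror the formulas in huberobj, and the sizes are small so the loop over coordinates is fast. It compares the analytic gradient with central differences of the loss; since the Huber loss is continuously differentiable, the relative error should be tiny.

```matlab
% Finite-difference check of the Huber gradient on a small random problem.
% huber_v and huber_g are hypothetical stand-ins that mirror huberobj.
huber_v = @(w, X, Y, r) sum(0.5 * min(abs(X*w - Y), r) .* ...
                            (2 * abs(X*w - Y) - min(abs(X*w - Y), r)));
huber_g = @(w, X, Y, r) X' * (min(abs(X*w - Y), r) .* sign(X*w - Y));

d = 5; N = 40; r = 1;
X = randn(N, d);
Y = randn(N, 1);          % random targets, so both loss regions are exercised
w = randn(d, 1);

G = huber_g(w, X, Y, r);  % analytic gradient
h = 1e-6;                 % finite-difference step
G_fd = zeros(d, 1);
for k = 1:d
    step = zeros(d, 1);
    step(k) = h;
    % central difference of the loss along coordinate k
    G_fd(k) = (huber_v(w + step, X, Y, r) - huber_v(w - step, X, Y, r)) / (2 * h);
end

rel_err = norm(G - G_fd) / max(norm(G), eps);
fprintf('relative gradient error: %.2e\n', rel_err);
```

With huberobj.m on the path, the same check works by calling `[v, G] = huberobj(w, X, Y, r)` in place of the stand-ins. optimset also has a DerivativeCheck option that asks fminunc to run a comparable comparison itself.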

Comments

  1. Nathan Jacobs

    @ted_zhai Here is a simple example of using nonlinear optimization to minimize the Huber loss.
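The title speaks of gradient descent, while the snippet hands the problem to fminunc, which uses a trust-region or quasi-Newton method. For comparison, here is a minimal fixed-step gradient descent sketch on the same objective, again with the hypothetical stand-ins huber_v and huber_g and small synthetic data. The step size 1/L uses L = ||X||_2^2, a Lipschitz constant of the gradient (the Huber derivative has slope at most 1), which guarantees that no step increases the loss.

```matlab
% Fixed-step gradient descent on the Huber objective (small synthetic data).
% huber_v and huber_g are hypothetical stand-ins that mirror huberobj.
huber_v = @(w, X, Y, r) sum(0.5 * min(abs(X*w - Y), r) .* ...
                            (2 * abs(X*w - Y) - min(abs(X*w - Y), r)));
huber_g = @(w, X, Y, r) X' * (min(abs(X*w - Y), r) .* sign(X*w - Y));

d = 10; N = 200; r = 3;
w_gt = randn(d, 1);
X = randn(N, d);
Y = X * w_gt;             % noise-free targets, as in the snippet's demo

L = norm(X)^2;            % Lipschitz constant of the gradient (spectral norm squared)
step = 1 / L;             % fixed step that cannot increase the loss
w = randn(d, 1);          % random start, as in the snippet's demo
for iter = 1:500
    g = huber_g(w, X, Y, r);
    if norm(g) < 1e-10    % stop once the gradient is numerically zero
        break;
    end
    w = w - step * g;
end

rel_err = norm(w - w_gt) / norm(w_gt);
fprintf('iterations: %d, loss: %.2e, relative error: %.2e\n', ...
        iter, huber_v(w, X, Y, r), rel_err);
```

At the demo's full size, norm(x) on a 50000-by-4000 matrix is expensive; normest(x) gives a cheaper estimate (shrink the step slightly to allow for estimation error). Fixed-step descent can also need many more iterations than fminunc when x'*x is ill-conditioned.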
