Commits

Olivier Grisel committed f42c146

early work on input centering / scaling

Comments (0)

Files changed (3)

src/sgd/architecture.py

 from sgd.common import InvalidModuleException
 from sgd.common import PicklableObject
 from sgd.common import SgdException
+from sgd.datachunker import CenteringScaler
 from sgd.learner import SgdLearner
 from sgd.learner import MultiprocessLearner
 
                                    validation_interval=validation_interval,
                                    improvement_tol=improvement_tol)
 
+    def get_scaler(self, chunker):
+        """Return a CenteringScaler adapting *chunker* (centers/scales its chunks)."""
+        return CenteringScaler(chunker)
+
     def get_current_t(self):
         return self._c_structure.t
 

src/sgd/datachunker.py

 # THE SOFTWARE.
 """Implementations of chunkers that implement the generator protocol"""
 
+import numpy as np
+
 class BaseChunker(object):
     """Base class to be overriden by implementations
 
     def number_of_remaining_chunks(self):
         return self.chunker.number_of_remaining_chunks()
 
+
class CenteringScaler(object):
    """Adapter that centers and scales to unit variance the chunks of a
    wrapped chunker.

    On construction, up to ``sample_max_size`` samples are streamed from
    ``chunker`` (which is ``reset()`` afterwards) to estimate per-feature
    mean and standard deviation in two passes.  Chunks later pulled through
    ``next()`` get the mean subtracted and are divided by the standard
    deviation; zero-variance (constant) features are left unscaled instead
    of producing inf/nan.

    Labels are currently returned raw (the label scaling is deliberately
    commented out below); only their statistics are computed.

    NOTE(review): assumes chunks are 2-D arrays of shape
    (n_samples, n_features) — ``mean.shape[0]`` would fail for 1-D labels;
    confirm against the chunker implementations.
    """

    def __init__(self, chunker, sample_max_size=10000):
        self.chunker = chunker

        # First pass: accumulate per-feature sums to estimate the mean.
        # Accumulate in float64 so integer-typed chunks do not silently
        # truncate when dividing by the sample count below.
        first_input, first_label = chunker.next()
        self.mean_input = first_input.sum(axis=0, dtype=np.float64)
        self.mean_label = first_label.sum(axis=0, dtype=np.float64)
        count = first_input.shape[0]

        for input_, label in chunker:
            count += input_.shape[0]
            self.mean_input += input_.sum(axis=0)
            self.mean_label += label.sum(axis=0)

            # Cap the estimation sample to keep construction cheap on
            # very large datasets.
            if count > sample_max_size:
                break

        self.mean_input /= count
        self.mean_label /= count

        chunker.reset()

        # Second pass: accumulate squared deviations from the mean to
        # estimate the (biased, 1/n) standard deviation.
        stddev_input = np.zeros(self.mean_input.shape[0],
                                dtype=self.mean_input.dtype)
        stddev_label = np.zeros(self.mean_label.shape[0],
                                dtype=self.mean_label.dtype)
        count = 0
        for input_, label in chunker:
            count += input_.shape[0]
            stddev_input += ((input_ - self.mean_input) ** 2).sum(axis=0)
            stddev_label += ((label - self.mean_label) ** 2).sum(axis=0)
            if count > sample_max_size:
                break

        stddev_input /= count
        stddev_label /= count
        # In-place sqrt (second argument is the output array).
        np.sqrt(stddev_input, stddev_input)
        np.sqrt(stddev_label, stddev_label)

        # Zero-variance components would give inf scales (and nan outputs);
        # leave such constant features unscaled instead.
        stddev_input[stddev_input == 0.0] = 1.0
        stddev_label[stddev_label == 0.0] = 1.0
        self.scale_input = 1.0 / stddev_input
        self.scale_label = 1.0 / stddev_label

        chunker.reset()

    def next(self):
        """Return the next chunk with inputs centered and scaled.

        Labels are passed through unchanged for now (scaling commented out).
        """
        input_, label = self.chunker.next()
        return ((input_ - self.mean_input) * self.scale_input,
                (label))# - self.mean_label) * self.scale_label)

    # Modern iterator-protocol alias; the project code calls .next() directly.
    __next__ = next

    def __iter__(self):
        return self

    def get_validation_data(self):
        """Return the validation pair (centered/scaled input, raw label),
        or (None, None) when the underlying chunker provides none."""
        v_input, v_label = self.chunker.get_validation_data()
        if v_input is not None:
            return ((v_input - self.mean_input) * self.scale_input,
                    (v_label))# - self.mean_label) * self.scale_label)
        else:
            return None, None
+

src/sgd/learner.py

         if t0 is not None:
             self.architecture.set_current_t(t0)
 
+        #scaler = self.architecture.get_scaler(chunker)
+        scaler = chunker
+
         last_loss = 1e6
 
-        validation_input, validation_label = chunker.get_validation_data()
+        validation_input, validation_label = scaler.get_validation_data()
         consecutive_improv_miss = 0
 
         for i in xrange(epochs):
-            for j, (input_, label) in enumerate(chunker):
+            for j, (input_, label) in enumerate(scaler):
 
+                # validate
                 if (validation_input is not None
                     and j % validation_interval == 0):