Commits

Christoph Dann committed e8640d2

HashedTabular representation for PST policy evaluation

  • Parent commits 08bb128
  • Branches poleval


Files changed (3)

File Representations/Tabular.py

 
 class QTabular(Tabular, QFunRepresentation):
     pass
+
+
+
+class HashedTabular(Representation):
+    """
+    WARNING: EXPERIMENTAL!!
+
+
+    Tabular Representation with Hashing Trick
+    """
+
+    BIG_INT = 2147483647
+
+    def __init__(self, domain, logger, memory,
+                 safety="super",
+                 seed=0):
+        """
+        TODO
+        """
+        self.features_num = memory
+        super(HashedTabular, self).__init__(domain, logger)
+        # now only hashing stuff
+        self.seed = seed
+        self.safety = safety
+        if safety == "super":
+            # store the full state vector of every occupied slot for exact
+            # collision detection
+            size = self.domain.state_space_dims
+            self.check_data = -np.ones((self.features_num, size), dtype=np.long)
+        else:
+            # "lazy" and "none" store only a second hash value per slot
+            self.check_data = -np.ones(self.features_num, dtype=np.long)
+        self.counts = np.zeros(self.features_num, dtype=np.long)
+        self.collisions = 0
+        self.R = np.random.RandomState(seed).randint(self.BIG_INT / 4, size=self.features_num).astype(np.int)
+
+        if safety == "none":
+            try:
+                import hashing as h
+                f = lambda self, A: h.physical_addr(A, self.R, self.check_data,
+                                                   self.counts)[0]
+                self._physical_addr = type(HashedTabular._physical_addr)(f, self, HashedTabular)
+                print "Use cython extension for hashing trick"
+            except Exception, e:
+                print e
+                print "Cython extension for hashing trick not available"
+
+    def phi_nonTerminal(self, s):
+        phi = np.zeros(self.features_num)
+        sn = np.array(s, dtype="int")
+        j = self._physical_addr(sn)
+        phi[j] = 1
+        return phi
+
+
+    def _hash(self, A, increment=449, max=None):
+        """
+        hashing without collision detection
+        """
+        # TODO implement in cython if speed needs to be improved
+        max = self.features_num if max is None else max
+        return int(self.R[np.mod(A + np.arange(len(A))*increment, self.features_num)].sum()) % max
+
+    def _physical_addr(self, A):
+        """
+        Map a virtual vector address A to a physical address i.e. the actual
+        feature number.
+        This is the actual hashing trick
+        """
+        h1 = self._hash(A)
+        if self.safety == "super":
+            # use full value to detect collisions
+            check_val = A
+        else:
+            # use second hash
+            check_val = self._hash(A, increment=457, max=self.BIG_INT)
+
+        if self.counts[h1] == 0:
+            # first time, set up data
+            self.check_data[h1] = check_val
+            self.counts[h1] += 1
+            return h1
+        elif np.all(check_val == self.check_data[h1]):
+            # clear hit, everything's fine
+            self.counts[h1] += 1
+            return h1
+        elif self.safety == "none":
+            self.collisions += 1
+            return h1
+        else:
+            h2 = 1 + 2 * self._hash(A, max=self.BIG_INT / 4)
+            for i in xrange(self.features_num):
+                h1 = (h1 + h2) % self.features_num
+                if self.counts[h1] == 0 or np.all(self.check_data[h1] == check_val):
+                    self.check_data[h1] = check_val
+                    self.counts[h1] += 1
+                    return h1
+            self.collisions += 1
+            #self.logger.log("Tile memory too small")
+            return h1
+
+    def featureType(self):
+        return bool
+
+class QHashedTabular(HashedTabular, QFunRepresentation):
+    pass
+
+

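The hashing trick above maps a potentially huge discrete state space onto a fixed number of physical features: a state vector is hashed to one slot, a check value detects collisions, and double hashing probes for a free slot. The following is a minimal standalone sketch of that scheme, not the committed implementation; the memory size, probe increments, and sample state are illustrative.

import numpy as np

MEMORY = 16                          # illustrative hash table size
BIG_INT = 2147483647
R = np.random.RandomState(0).randint(BIG_INT // 4, size=MEMORY)

def h(A, increment=449, max_val=None):
    # hash a state vector: sum randomly chosen table entries, one per dimension
    max_val = MEMORY if max_val is None else max_val
    return int(R[np.mod(A + np.arange(len(A)) * increment, MEMORY)].sum()) % max_val

def physical_addr(A, check_data, counts):
    # double hashing with open addressing, analogous to _physical_addr above
    h1 = h(A)
    check_val = h(A, increment=457, max_val=BIG_INT)   # second hash detects collisions
    h2 = 1 + 2 * h(A, max_val=BIG_INT // 4)            # odd probe step size
    for _ in range(MEMORY):
        if counts[h1] == 0 or check_data[h1] == check_val:
            check_data[h1] = check_val
            counts[h1] += 1
            return h1
        h1 = (h1 + h2) % MEMORY
    return h1                                          # table full: accept the collision

check_data = -np.ones(MEMORY, dtype=int)
counts = np.zeros(MEMORY, dtype=int)
s = np.array([3, 1, 0, 2])           # a discrete state vector
phi = np.zeros(MEMORY)
phi[physical_addr(s, check_data, counts)] = 1.0   # one-hot feature vector, as in phi_nonTerminal
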
File Representations/__init__.py

 
 
-from Tabular import Tabular, QTabular
+from Tabular import Tabular, QTabular, HashedTabular, QHashedTabular
 from IncrementalTabular import IncrementalTabular, QIncrementalTabular
 from IndependentDiscretization import IndependentDiscretization, QIndependentDiscretization
 from IndependentDiscretization import IndependentDiscretizationCompact, QIndependentDiscretizationCompact

File examples/uav/poleval/tabular.py

+from Tools import Logger
+from ValueEstimators.TDLearning import TDLearning
+from Representations import HashedTabular
+from Domains import PST
+from Tools import __rlpy_location__
+from Experiments.PolicyEvaluationExperiment import PolicyEvaluationExperiment
+from Policies import StoredPolicy
+import numpy as np
+from hyperopt import hp
+
+param_space = {'lambda_': hp.uniform("lambda_", 0., 1.),
+               'boyan_N0': hp.loguniform("boyan_N0", np.log(1e1), np.log(1e5)),
+               'initial_alpha': hp.loguniform("initial_alpha", np.log(1e-2), np.log(1))}
+
+
+def make_experiment(id=1, path="./Results/Temp/{domain}/poleval/tabular/",
+                    lambda_=0.701309,
+                    boyan_N0=1375.098,
+                    initial_alpha=0.016329):
+    logger = Logger()
+    max_steps = 500000
+    sparsify = 1
+    domain = PST(NUM_UAV=4, motionNoise=0, logger=logger)
+    pol = StoredPolicy(filename=__rlpy_location__ + "/Policies/PST_4UAV_mediocre_policy_nocache.pck")
+    representation = HashedTabular(domain, logger, memory=20000, safety="super")
+    estimator = TDLearning(representation=representation, lambda_=lambda_,
+                           boyan_N0=boyan_N0, initial_alpha=initial_alpha, alpha_decay_mode="boyan")
+    experiment = PolicyEvaluationExperiment(estimator, domain, pol, max_steps=max_steps, num_checks=20,
+                                            path=path, log_interval=10, id=id)
+    experiment.num_eval_points_per_dim = 20
+    experiment.num_traj_V = 300
+    experiment.num_traj_stationary = 300
+    return experiment
+
+if __name__ == '__main__':
+    from Tools.run import run_profiled
+    #run_profiled(make_experiment)
+    experiment = make_experiment(1)
+    experiment.run()
+    #experiment.plot()
+    #experiment.save()
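
param_space defines a hyperopt search space over the TD(lambda) hyperparameters (lambda_, boyan_N0, initial_alpha). A minimal sketch of drawing one configuration from it and building an experiment follows; the full fmin/TPE optimization loop and how the evaluation error is read back out of PolicyEvaluationExperiment are left out, since they are not part of this commit.

from hyperopt.pyll.stochastic import sample

# draw one random configuration from the search space defined above,
# e.g. {'lambda_': 0.42, 'boyan_N0': 3100.5, 'initial_alpha': 0.08}
params = sample(param_space)
experiment = make_experiment(id=1, **params)
experiment.run()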