Commits

Aleš Erjavec  committed 4e79625

Added test for 'forward_pass'.

  • Participants
  • Parent commits 234f628

Comments (0)

Files changed (2)

File orangekit/earth/tests/__init__.py

Empty file added.

File orangekit/earth/tests/test_core.py

+import unittest
+
+import numpy
+
+from .. import core
+
+
+class TestEarth(unittest.TestCase):
+    def test_forward(self):
+        x = numpy.linspace(0, 1, num=100, endpoint=True)
+        y = numpy.hstack((numpy.zeros(50), (x[50:] - 0.5)))
+
+        n, best_set, bx, dirs, cuts = \
+            core.forward_pass(x.reshape(-1, 1), y, degree=1, terms=3,
+                               penalty=0, new_var_penalty=0)
+
+        self.assertEqual(n, 3)
+        self.assertTrue((best_set == [True, True, True]).all())
+
+        self.assertTrue(numpy.allclose(cuts.ravel(),
+                                       [0, 0.5, 0.5], atol=0.05))
+
+        self.assertTrue((dirs.ravel() == [0, 1, -1]).all())
+
+        # raise y by 1.0 and repeat
+        y += 1.0
+
+        n, best_set, bx, dirs, cuts = \
+            core.forward_pass(x.reshape(-1, 1), y, degree=1, terms=3,
+                               penalty=0, new_var_penalty=0)
+
+        self.assertEqual(n, 3)
+        self.assertTrue((best_set == [True, True, True]).all())
+
+        self.assertTrue(numpy.allclose(cuts.ravel(),
+                                       [0, 0.5, 0.5], atol=0.05))
+
+        self.assertTrue((dirs.ravel() == [0, 1, -1]).all())