Commits

Aleš Erjavec committed 9980267

Added test for EarthLearner on a muti-target problem.

Comments (0)

Files changed (1)

Orange/testing/unit/tests/test_earth.py

+import Orange
 from Orange.misc import testing
 from Orange.misc.testing import datasets_driven, test_on_data, test_on_datasets
 from Orange.regression import earth
-import Orange
+
 try:
     import unittest2 as unittest
 except:
         from Orange.regression.earth import ScoreEarthImportance
         self.measure = ScoreEarthImportance(t=5, score_what="rss")
 
-
+@datasets_driven(datasets=["multitarget-synthetic"])
+class TestEarthMultitarget(unittest.TestCase):
+    @test_on_data
+    def test_multi_target_on_data(self, dataset):
+        self.learner = earth.EarthLearner(degree=2, terms=10)
+        
+        self.predictor = self.multi_target_test(self.learner, dataset)
+        
+        self.assertTrue(bool(self.predictor.multitarget))
+        
+        s = str(self.predictor)
+        self.assertEqual(s, self.predictor.to_string())
+        self.assertNotEqual(s, self.predictor.to_string(3, 6))
+        
+    
+    def multi_target_test(self, learner, data):
+        indices = Orange.data.sample.SubsetIndices2(p0=0.3)(data)
+        learn = data.select(indices, 1)
+        test = data.select(indices, 0)
+        
+        predictor = learner(learn)
+        self.assertIsInstance(predictor, Orange.classification.Classifier)
+        self.multi_target_predictor_interface(predictor, learn.domain)
+        
+        from Orange.evaluation import testing as _testing
+        
+        r = _testing.test_on_data([predictor], test)
+        
+        def all_values(vals):
+            for v in vals:
+                self.assertIsInstance(v, Orange.core.Value)
+                
+        def all_dists(dist):
+            for d in dist:
+                self.assertIsInstance(d, Orange.core.Distribution)
+                
+        for ex in test:
+            preds = predictor(ex, Orange.core.GetValue)
+            all_values(preds)
+            
+            dist = predictor(ex, Orange.core.GetProbabilities)
+            all_dists(dist)
+            
+            preds, dist = predictor(ex, Orange.core.GetBoth)
+            all_values(preds)
+            all_dists(dist)
+            
+            for d in dist:
+                if isinstance(d, Orange.core.ContDistribution):
+                    dist_sum = sum(d.values())
+                else:
+                    dist_sum = sum(d)
+                    
+                self.assertGreater(dist_sum, 0.0)
+                self.assertLess(abs(dist_sum - 1.0), 1e-3)
+            
+        return predictor
+    
+    def multi_target_predictor_interface(self, predictor, domain):
+        self.assertTrue(hasattr(predictor, "class_vars"))
+        self.assertIsInstance(predictor.class_vars, (list, Orange.core.VarList))
+        self.assertTrue(all(c1 == c2 for c1, c2 in \
+                            zip(predictor.class_vars, domain.class_vars)))
+        
+    
 #@datasets_driven(datasets=testing.REGRESSION_DATASETS,)
 #class TestScoreRSS(testing.MeasureAttributeTestCase):
 #    def setUp(self):
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.