Aleš Erjavec avatar Aleš Erjavec committed 1085854

Refactor Rank widget. Allow adding extra score functions through entry points.

Comments (0)

Files changed (1)

Orange/OrangeWidgets/Data/OWRank.py

 <contact>Janez Demsar (janez.demsar(@at@)fri.uni-lj.si)</contact>
 <priority>1102</priority>
 """
+
+from collections import namedtuple
+from functools import partial
+
+import pkg_resources
+
 from OWWidget import *
-    
+
 import OWGUI
 import orange
 
-from functools import partial
+from Orange.regression.earth import ScoreEarthImportance
+from orngSVM import MeasureAttribute_SVMWeights
+from orngEnsemble import MeasureAttribute_randomForests
+
 
 def _toPyObject(variant):
     val = variant.toPyObject()
     """ Return a 2D table with shape filed with ``fill``
     """
     return [[fill for j in range(shape[1])] for i in range(shape[0])]
- 
-from Orange.regression.earth import ScoreEarthImportance
-from orngSVM import MeasureAttribute_SVMWeights
-from orngEnsemble import MeasureAttribute_randomForests
 
-MEASURE_PARAMS = {ScoreEarthImportance: \
-                    [{"name": "t",
-                      "type": int,
-                      "display_name": "Num. models.",
-                      "range": range(1, 21),
-                      "default": 10,
-                      "doc": "Number of models to train for feature scoring."
-                      },
-                     {"name": "terms",
-                      "type": int,
-                      "display_name": "Max. num of terms",
-                      "range": range(3, 200),
-                      "default": 10,
-                      "doc": "Maximum number of terms in the forward pass" 
-                      },
-                     {"name": "degree", 
-                      "type": int,
-                      "display_name": "Max. term degree",
-                      "range": range(1, 4),
-                      "default": 2,
-                      "doc": "Maximum degree of terms included in the model." 
-                     },
-#                     {"name": "score_what",
-#                      "type": int,
-#                      "display_name": "Score what",
-#                      "range": range(0, 3),
-#                      "display_role": ["Num. Subsets", "RSS", "GCV"]
-#                      "default": 2,
-#                      "doc": ""}
-                     ],
-                  orange.MeasureAttribute_relief: \
-                     [{"name": "k",
-                       "type": int,
-                       "display_name": "Neighbours",
-                       "range": range(1, 21),
-                       "default": 10,
-                       "doc": "Number of neighbors to consider."},
-                      {"name":"m",
-                       "type": int,
-                       "display_name": "Examples",
-                       "range": range(20, 101),
-                       "default": 20,
-                       "doc": ""}
-                      ],
-                  MeasureAttribute_randomForests:\
-                     [{"name": "trees",
-                       "type": int,
-                       "display_name": "Num. of trees",
-                       "range": range(20, 101),
-                       "default": 100,
-                       "doc": "Number of trees in the random forest."}
-                      ]
-                  }
 
+MEASURE_PARAMS = {
+    ScoreEarthImportance: [
+        {"name": "t",
+         "type": int,
+         "display_name": "Num. models.",
+         "range": (1, 20),
+         "default": 10,
+         "doc": "Number of models to train for feature scoring."},
+        {"name": "terms",
+         "type": int,
+         "display_name": "Max. num of terms",
+         "range": (3, 200),
+         "default": 10,
+         "doc": "Maximum number of terms in the forward pass"},
+        {"name": "degree",
+         "type": int,
+         "display_name": "Max. term degree",
+         "range": (1, 3),
+         "default": 2,
+         "doc": "Maximum degree of terms included in the model."}
+    ],
+    orange.MeasureAttribute_relief: [
+        {"name": "k",
+         "type": int,
+         "display_name": "Neighbours",
+         "range": (1, 20),
+         "default": 10,
+         "doc": "Number of neighbors to consider."},
+        {"name":"m",
+         "type": int,
+         "display_name": "Examples",
+         "range": (20, 100),
+         "default": 20,
+         "doc": ""}
+        ],
+    MeasureAttribute_randomForests: [
+        {"name": "trees",
+         "type": int,
+         "display_name": "Num. of trees",
+         "range": (20, 100),
+         "default": 100,
+         "doc": "Number of trees in the random forest."}
+        ]
+    }
 
-MEASURES = [("ReliefF", "ReliefF", orange.MeasureAttribute_relief),
-            ("Information Gain", "Inf. gain", orange.MeasureAttribute_info),
-            ("Gain Ratio", "Gain Ratio", orange.MeasureAttribute_gainRatio),
-            ("Gini Gain", "Gini", orange.MeasureAttribute_gini),
-            ("Log Odds Ratio", "log OR", orange.MeasureAttribute_logOddsRatio),
-            ("MSE", "MSE", orange.MeasureAttribute_MSE),
-            ("Earth Importance", "Earth imp.", ScoreEarthImportance),
-            ("Linear SVM Weights", "SVM weight", MeasureAttribute_SVMWeights),
-            ("Random Forests", "RF", MeasureAttribute_randomForests),
-            ]
 
-MEASURES_HANDLES_CONTINUOUS = {"ReliefF": True,
-                               "Earth Importance": True,
-                               "Linear SVM Weights": True,
-                               "Random Forests": True,
-                               }
+_score_meta = namedtuple(
+    "_score_meta",
+    ["name",
+     "shortname",
+     "score",
+     "params",
+     "supports_regression",
+     "supports_classification",
+     "handles_discrete",
+     "handles_continuous"]
+)
 
-MEASURES_SUPPORTS_REGRESSION = {"ReliefF": True,
-                                "MSE": True,
-                                "Earth Importance": True,
-                                "Random Forests": True,
-                                }
 
-MEASURES_SUPPORTS_CLASSIFICATION = {"MSE": False,
-                                    "Random Forests": True,
-                                    }
+class score_meta(_score_meta):
+    # Add sensible defaults to __new__
+    def __new__(cls, name, shortname, score, params=None,
+                supports_regression=True, supports_classification=True,
+                handles_continuous=True, handles_discrete=True):
+        return _score_meta.__new__(
+            cls, name, shortname, score, params,
+            supports_regression, supports_classification,
+            handles_discrete, handles_continuous
+        )
 
-MEASURES_DEFAULT_SELECTED = dict([(mname, True) for mname, _, _ in MEASURES[:6]] + \
-                                 [(mname, False) for mname, _, _ in MEASURES[6:]]) # The Earth imp. and SVM are not selected by default
+
+# Default scores.
+SCORES = [
+    score_meta(
+        "ReliefF", "ReliefF", orange.MeasureAttribute_relief,
+        params=MEASURE_PARAMS[orange.MeasureAttribute_relief],
+        handles_continuous=True,
+        handles_discrete=True),
+    score_meta(
+        "Information Gain", "Inf. gain", orange.MeasureAttribute_info,
+        params=None,
+        supports_regression=False,
+        supports_classification=True,
+        handles_continuous=False,
+        handles_discrete=True),
+    score_meta(
+        "Gain Ratio", "Gain Ratio", orange.MeasureAttribute_gainRatio,
+        params=None,
+        supports_regression=False,
+        handles_continuous=False,
+        handles_discrete=True),
+    score_meta(
+        "Gini Gain", "Gini", orange.MeasureAttribute_gini,
+        params=None,
+        supports_regression=False,
+        supports_classification=True,
+        handles_continuous=False),
+    score_meta(
+        "Log Odds Ratio", "log OR", orange.MeasureAttribute_logOddsRatio,
+        params=None,
+        supports_regression=False,
+        handles_continuous=False),
+    score_meta(
+        "MSE", "MSE", orange.MeasureAttribute_MSE,
+        params=None,
+        supports_classification=False,
+        handles_continuous=False),
+    score_meta(
+        "Linear SVM Weights", "SVM weight", MeasureAttribute_SVMWeights,
+        params=None),
+    score_meta(
+        "Random Forests", "RF", MeasureAttribute_randomForests,
+        params=MEASURE_PARAMS[MeasureAttribute_randomForests]),
+    score_meta(
+        "Earth Importance", "Earth imp.", ScoreEarthImportance,
+        params=MEASURE_PARAMS[ScoreEarthImportance],
+    )
+]
+
+_DEFAULT_SELECTED = set(m.name for m in SCORES[:6])
 
 
 class MethodParameter(object):
         self.default = default
         self.doc = doc
 
-def supports_classification(name):
-    return MEASURES_SUPPORTS_CLASSIFICATION.get(name, True)
-
-def supports_regression(name):
-    return MEASURES_SUPPORTS_REGRESSION.get(name, False)
-
-def handles_continuous(name):
-    return MEASURES_HANDLES_CONTINUOUS.get(name, False)
 
 def measure_parameters(measure):
-    return [MethodParameter(**args) for args in MEASURE_PARAMS.get(measure, [])]
+    return [MethodParameter(**args) for args in (measure.params or [])]
+
 
 def param_attr_name(measure, param):
-    """ Name of the OWRank widget's member where the parameter is stored. 
+    """Name of the OWRank widget's member where the parameter is stored.
     """
     return "param_" + measure.__name__ + "_" + param.name
-        
+
+
+def drop_exceptions(iterable, exceptions=(Exception,)):
+    iterable = iter(iterable)
+    while True:
+        try:
+            yield next(iterable)
+        except StopIteration:
+            raise
+        except BaseException as ex:
+            if not isinstance(ex, exceptions):
+                raise
+
+
+def load_ep_drop_exceptions(entry_point):
+    for ep in pkg_resources.iter_entry_points(entry_point):
+        try:
+            yield ep.load()
+        except Exception:
+            log = logging.getLogger(__name__)
+            log.debug("", exc_info=True)
+
+
+def all_measures():
+    iter_ep = load_ep_drop_exceptions("orange.widgets.feature_score")
+    scores = [m for m in iter_ep if isinstance(m, score_meta)]
+    return SCORES + scores
+
+
 class OWRank(OWWidget):
-    settingsList =  ["nDecimals", "nIntervals", "sortBy", "nSelected", "selectMethod", "autoApply", "showDistributions", "distColorRgb"]
+    settingsList = [
+        "nDecimals", "nIntervals", "sortBy", "nSelected",
+        "selectMethod", "autoApply", "showDistributions",
+        "distColorRgb"
+    ]
 
-    def __init__(self,parent=None, signalManager = None):
+    def __init__(self, parent=None, signalManager=None):
         OWWidget.__init__(self, parent, signalManager, "Rank")
 
         self.inputs = [("Data", ExampleTable, self.setData)]
         self.nSelected = 5
         self.autoApply = True
         self.showDistributions = 1
-        self.distColorRgb = (220,220,220, 255)
+        self.distColorRgb = (220, 220, 220, 255)
         self.distColor = QColor(*self.distColorRgb)
-        self.minmax = {}
-        self.selectedMeasures = dict(MEASURES_DEFAULT_SELECTED)
+
+        self.all_measures = all_measures()
+
+        self.selectedMeasures = dict(
+            [(name, True) for name in _DEFAULT_SELECTED] +
+            [(m.name, False)
+             for m in self.all_measures[len(_DEFAULT_SELECTED):]]
+        )
+
         self.data = None
-        
-#        self.measure_parameters = AttributeDict()
-#        self.measure_parameters = {}
-        
+
         self.methodParamAttrs = []
-        for _, _, m in MEASURES:
-            params = measure_parameters(m) or []
+        for m in self.all_measures:
+            params = measure_parameters(m)
             for p in params:
-                setattr(self, param_attr_name(m, p), p.default)
-                self.methodParamAttrs.append(param_attr_name(m, p))
+                name_mangled = param_attr_name(m.score, p)
+                setattr(self, name_mangled, p.default)
+                self.methodParamAttrs.append(name_mangled)
+
         self.settingsList = self.settingsList + self.methodParamAttrs
-        
-        self.loadSettings() 
 
-        labelWidth = 80
-        
-        self.discMeasures = [name for name, short, _ in MEASURES \
-                             if supports_classification(name)]
-        
-        self.contMeasures = [name for name, short, _ in MEASURES \
-                             if supports_regression(name)]
-        
-        self.discMeasuresShort = [short for name, short, _ in MEASURES \
-                                  if supports_classification(name)]
-        
-        self.contMeasuresShort = [short for name, short, _ in MEASURES \
-                                  if supports_regression(name)]
-        
-        self.discEstimators = [measure for name, _, measure in MEASURES \
-                               if supports_classification(name)]
-        
-        self.contEstimators = [measure for name, _, measure in MEASURES \
-                               if supports_regression(name)]
-        
-        self.discHandlesContinuous = map(handles_continuous, self.discMeasures)
-        self.contHandlesContinuous = map(handles_continuous, self.contMeasures)
+        self.loadSettings()
 
-        # The stacked layout for Classification/Regression measures
-#        self.stackedWidget = OWGUI.widgetBox(self.controlArea, margin=0,
-#                                             addSpace=True)
-        
+        self.discMeasures = [m for m in self.all_measures
+                             if m.supports_classification]
+        self.contMeasures = [m for m in self.all_measures
+                             if m.supports_regression]
+
         self.stackedLayout = QStackedLayout()
         self.stackedLayout.setContentsMargins(0, 0, 0, 0)
         self.stackedWidget = OWGUI.widgetBox(self.controlArea, margin=0,
                                              orientation=self.stackedLayout,
                                              addSpace=True)
-#        self.stackedWidget.layout().addLayout(self.stackedLayout)
+
         # Discrete class scoring
         discreteBox = OWGUI.widgetBox(self.stackedWidget, "Scoring",
                                       addSpace=False,
                                       addToLayout=False)
         self.stackedLayout.addWidget(discreteBox)
-        
+
         # Continuous class scoring
         continuousBox = OWGUI.widgetBox(self.stackedWidget, "Scoring",
                                         addSpace=False,
                                         addToLayout=False)
         self.stackedLayout.addWidget(continuousBox)
-        
-        def measure_control(container, name, measure):
-            """ Construct UI control for measure.
+
+        def measure_control(container, measure):
+            """Construct UI control for `measure` (measure_meta instance).
             """
+            name = measure.name
             params = measure_parameters(measure)
             if params:
-                hbox = OWGUI.widgetBox(container, orientation = "horizontal")
+                hbox = OWGUI.widgetBox(container, orientation="horizontal")
                 OWGUI.checkBox(hbox, self.selectedMeasures, name, name,
-                               callback=partial(self.measuresSelectionChanged, name),
+                               callback=partial(self.measuresSelectionChanged,
+                                                measure),
                                tooltip="Enable " + name)
-                smallWidget = OWGUI.SmallWidgetLabel(hbox, pixmap=1, box=name + " Parameters",
-                                                     tooltip="Show " + name + "Parameters")
+
+                smallWidget = OWGUI.SmallWidgetLabel(
+                    hbox, pixmap=1, box=name + " Parameters",
+                    tooltip="Show " + name + "Parameters")
+
                 for param in params:
-                    OWGUI.spin(smallWidget.widget, self, param_attr_name(measure, param),
+                    OWGUI.spin(smallWidget.widget, self,
+                               param_attr_name(measure.score, param),
                                param.range[0], param.range[-1],
-                               label=param.display_name, 
+                               label=param.display_name,
                                tooltip=param.doc,
-                               callback=partial(self.measureParamChanged, name, param),
+                               callback=partial(
+                                    self.measureParamChanged, measure, param),
                                callbackOnReturn=True)
-                
+
                 OWGUI.button(smallWidget.widget, self, "Load defaults",
-                             callback=partial(self.loadMeasureDefaults, name))
+                             callback=partial(self.loadMeasureDefaults,
+                                              measure))
             else:
                 OWGUI.checkBox(container, self.selectedMeasures, name, name,
-                               callback=partial(self.measuresSelectionChanged, name),
+                               callback=partial(self.measuresSelectionChanged,
+                                                measure),
                                tooltip="Enable " + name)
-        
-        for name, short_name, measure in MEASURES:
-            if supports_classification(name):
-                measure_control(discreteBox, name, measure)
-                    
-            if supports_regression(name):
-                measure_control(continuousBox, name, measure)
-        
-        
+
+        for measure in self.all_measures:
+            if measure.supports_classification:
+                measure_control(discreteBox, measure)
+
+            if measure.supports_regression:
+                measure_control(continuousBox, measure)
+
         OWGUI.comboBox(discreteBox, self, "sortBy", label = "Sort by"+"  ",
                        items = ["No Sorting", "Attribute Name", "Number of Values"] + \
-                               [name for name in self.discMeasures],
+                               [m.name for m in self.discMeasures],
                        orientation=0, valueType = int,
                        callback=self.sortingChanged)
         
         OWGUI.comboBox(continuousBox, self, "sortBy", label = "Sort by"+"  ",
                        items = ["No Sorting", "Attribute Name", "Number of Values"] + \
-                               [name for name in self.contMeasures],
+                               [m.name for m in self.contMeasures],
                        orientation=0, valueType = int,
                        callback=self.sortingChanged)
 
 #        self.discRanksView.horizontalHeader().restoreState(self.discRanksHeaderState)
         
         self.discRanksModel = QStandardItemModel(self)
-        self.discRanksModel.setHorizontalHeaderLabels(["Attribute", "#"] + self.discMeasuresShort)
+        self.discRanksModel.setHorizontalHeaderLabels(
+            ["Attribute", "#"] + [m.shortname for m in self.discMeasures]
+        )
         self.discRanksProxyModel = MySortProxyModel(self)
         self.discRanksProxyModel.setSourceModel(self.discRanksModel)
         self.discRanksView.setModel(self.discRanksProxyModel)
 #        self.contRanksView.horizontalHeader().restoreState(self.contRanksHeaderState)
         
         self.contRanksModel = QStandardItemModel(self)
-        self.contRanksModel.setHorizontalHeaderLabels(["Attribute", "#"] + self.contMeasuresShort)
+        self.contRanksModel.setHorizontalHeaderLabels(
+            ["Attribute", "#"] + [m.shortname for m in self.contMeasures]
+        )
         self.contRanksProxyModel = MySortProxyModel(self)
         self.contRanksProxyModel.setSourceModel(self.contRanksModel)
         self.contRanksView.setModel(self.contRanksProxyModel)
         """
         self.ranksViewStack.setCurrentIndex(index)
         self.stackedLayout.setCurrentIndex(index)
-        
+
         if index == 0:
             self.ranksView = self.discRanksView
             self.ranksModel = self.discRanksModel
             self.ranksProxyModel = self.discRanksProxyModel
             self.measures = self.discMeasures
-            self.handlesContinuous = self.discHandlesContinuous
-            self.estimators = self.discEstimators
         else:
             self.ranksView = self.contRanksView
             self.ranksModel = self.contRanksModel
             self.ranksProxyModel = self.contRanksProxyModel
             self.measures = self.contMeasures
-            self.handlesContinuous = self.contHandlesContinuous
-            self.estimators = self.contEstimators
-            
+
         self.updateVisibleScoreColumns()
             
     def setData(self, data):
         if not self.data:
             return
         
-        estimators = self.estimators
+#         estimators = self.estimators
         measures = self.measures
-        handlesContinous = self.handlesContinuous
-        self.warning(range(max(len(self.discEstimators), len(self.contEstimators))))
-        
+#         handlesContinous = self.handlesContinuous
+        # Invalidate all warnings
+        self.warning(range(max(len(self.discMeasures),
+                               len(self.contMeasures))))
+
         if measuresMask is None:
             # Update all selected measures
-            measuresMask = [self.selectedMeasures.get(m) for m in measures]
-        
-        for measure_index, (est, meas, mask) in enumerate(zip(
-                estimators, measures, measuresMask)):
+            measuresMask = [self.selectedMeasures.get(m.name)
+                            for m in measures]
+
+        for measure_index, (meas, mask) in enumerate(zip(measures, measuresMask)):
             if not mask:
                 continue
-            handles = MEASURES_HANDLES_CONTINUOUS.get(meas, False)
-            params = measure_parameters(est)
-            estimator = est()
+
+            params = measure_parameters(meas)
+            estimator = meas.score()
             if params:
                 for p in params:
                     setattr(estimator, p.name,
-                            getattr(self, param_attr_name(est, p)))
-                    
-            if not handles:
+                            getattr(self, param_attr_name(meas.score, p)))
+
+            if not meas.handles_continuous:
                 data = self.getDiscretizedData()
                 attr_map = data.attrDict
                 data = self.data
             else:
                 attr_map, data = {}, self.data
+
             attr_scores = []
             for i, attr in enumerate(data.domain.attributes):
                 attr = attr_map.get(attr, attr)
                     try:
                         s = estimator(attr, data)
                     except Exception, ex:
-                        self.warning(measure_index, "Error evaluating %r: %r" % (meas, str(ex)))
+                        self.warning(measure_index, "Error evaluating %r: %r" % (meas.name, str(ex)))
                         # TODO: store exception message (for widget info or item tooltip)
-                    if meas == "Log Odds Ratio" and s is not None:
+                    if meas.name == "Log Odds Ratio" and s is not None:
                         if s == -999999:
                             attr = u"-\u221E"
                         elif s == 999999:
             self.discretizedData = self.data.select(orange.Domain(at, self.data.domain.classVar))
             self.discretizedData.setattr("attrDict", attrDict)
         return self.discretizedData
-        
+
     def discretizationChanged(self):
         self.discretizedData = None
-        self.updateScores([not b for b in self.handlesContinuous])
+        self.updateScores([not m.handles_continuous for m in self.measures])
         self.autoSelection()
-        
-    def measureParamChanged(self, name, param=None):
-        index = self.measures.index(name)
-        measure = self.estimators[index]
+
+    def measureParamChanged(self, measure, param=None):
+        index = self.measures.index(measure)
         mask = [i == index for i, _ in enumerate(self.measures)]
         self.updateScores(mask)
     
-    def loadMeasureDefaults(self, name):
-        index = self.measures.index(name)
-        measure = self.estimators[index]
+    def loadMeasureDefaults(self, measure):
+#         index = self.measures.index(measure)
+#         measure = self.estimators[index]
         params = measure_parameters(measure)
         for i, p in enumerate(params):
-            setattr(self, param_attr_name(measure, p), p.default)
-        self.measureParamChanged(name)
+            setattr(self, param_attr_name(measure.score, p), p.default)
+        self.measureParamChanged(measure)
         
     def autoSelection(self):
         selModel = self.ranksView.selectionModel()
             self.ranksView.setSortingEnabled(True)
 
     def setLogORTitle(self):
-        var =self.data.domain.classVar    
+        var = self.data.domain.classVar
         if len(var.values) == 2:
             title = "log OR (for %r)" % var.values[1][:10]
         else:
             title = "log OR"
-        if "Log Odds Ratio" in self.discEstimators:
-            index = self.discMeasures.index("Log Odds Ratio")
-            item = PyStandardItem(title)
-            self.ranksModel.setHorizontalHeaderItem(index + 2, item)
+#         if "Log Odds Ratio" in self.discEstimators:
+#             index = self.discMeasures.index("Log Odds Ratio")
+        index = [m.name for m in self.discMeasures].index("Log Odds Ratio")
 
-    def measuresSelectionChanged(self, name=None):
-        """ Measure selection has changed. Update column visibility.
+        item = PyStandardItem(title)
+        self.ranksModel.setHorizontalHeaderItem(index + 2, item)
+
+    def measuresSelectionChanged(self, measure=None):
+        """Measure selection has changed. Update column visibility.
         """
-        if name is None:
+        if measure is None:
             # Update all scores
             measuresMask = None
         else:
             # Update scores for shown column if they are not yet computed.
-            shown = self.selectedMeasures.get(name, False)
-            index = self.measures.index(name)
+            shown = self.selectedMeasures.get(measure.name, False)
+            index = self.measures.index(measure)
             if all(s is None for s in self.measure_scores[index]) and shown:
-                measuresMask = [n == name for n in self.measures]
+                measuresMask = [m == measure for m in self.measures]
             else:
                 measuresMask = [False] * len(self.measures)
         self.updateScores(measuresMask)
         """ Update the visible columns of the scores view.
         """
         for i, measure in enumerate(self.measures):
-            shown = self.selectedMeasures.get(measure)
+            shown = self.selectedMeasures.get(measure.name)
             self.ranksView.setColumnHidden(i + 2, not shown)
 
     def sortByColumn(self, col):
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.