Commits

astaric committed dcbaa22

Refactor OWParallelGraph.

Comments (0)

Files changed (1)

Orange/OrangeWidgets/Visualize/OWParallelGraph.py

         orngScaleData.setData(self, data, subsetData, **args)
         self.domainContingency = None
 
-
     # update shown data. Set attributes, coloring by className ....
     def updateData(self, attributes, midLabels=None, updateAxisScale=1):
         self.removeDrawingCurves(removeLegendItems=0, removeMarkers=1)  # don't delete legend items
 
         self.visualizedAttributes = attributes
         self.visualizedMidLabels = midLabels
-        for name in self.selectionConditions.keys():        # keep only conditions that are related to the currently visualized attributes
+        self.filter_stale_conditions()
+
+        # set the limits for panning
+        self.xPanningInfo = (1, 0, len(attributes) - 1)
+        self.yPanningInfo = (0, 0, 0)
+
+        self.update_scale(attributes, updateAxisScale)
+
+        length = len(attributes)
+        indices = [self.attributeNameIndex[label] for label in attributes]
+
+        xs = range(length)
+        dataSize = len(self.scaledData[0])
+
+        if self.dataHasDiscreteClass:
+            self.discPalette.setNumberOfColors(len(self.dataDomain.classVar.values))
+
+        validData = self.getValidList(indices)
+
+        self.draw_lines(attributes, dataSize, indices, validData)
+
+        if self.showDistributions and self.dataHasDiscreteClass and self.haveData:
+            self.draw_distributions(validData, indices)
+
+        self.draw_axes(attributes)
+
+        if self.showStatistics and self.haveData:
+            self.draw_statistics(indices, length)
+
+        if midLabels:
+            self.draw_midlabels(midLabels)
+
+        if self.enabledLegend == 1 and self.dataHasDiscreteClass:
+            self.draw_legend(attributes)
+        else:
+            self.legend().clear()
+            self.oldLegendKeys = []
+
+        self.replot()
+
+    def filter_stale_conditions(self):
+        for name in self.selectionConditions.keys():
             if name not in self.visualizedAttributes:
                 self.selectionConditions.pop(name)
 
-        # set the limits for panning
-        self.xPanningInfo = (1, 0, len(attributes) - 1)
-        self.yPanningInfo = (
-        0, 0, 0)   # we don't enable panning in y direction so it doesn't matter what values we put in for the limits
-
+    def update_scale(self, attributes, updateAxisScale):
         if updateAxisScale:
             if self.showAttrValues:
                 self.setAxisScale(QwtPlot.yLeft, -0.04, 1.04, 1)
                 M = self.axisScaleDiv(QwtPlot.xBottom).interval().maxValue()
                 if m < 0 or M > len(attributes) - 2:
                     self.setAxisScale(QwtPlot.xBottom, 0, len(attributes) - 1, 1)
-
         self.setAxisScaleDraw(QwtPlot.xBottom,
                               DiscreteAxisScaleDraw([self.getAttributeLabel(attr) for attr in attributes]))
         #self.setAxisScaleDraw(QwtPlot.yLeft, HiddenScaleDraw())
         self.setAxisMaxMajor(QwtPlot.xBottom, len(attributes))
         self.setAxisMaxMinor(QwtPlot.xBottom, 0)
 
-        length = len(attributes)
-        indices = [self.attributeNameIndex[label] for label in attributes]
-
-        xs = range(length)
-        dataSize = len(self.scaledData[0])
-
-        if self.dataHasDiscreteClass:
-            self.discPalette.setNumberOfColors(len(self.dataDomain.classVar.values))
-
-
-        # ############################################
-        # draw the data
-        # ############################################
+    def draw_lines(self, attributes, dataSize, indices, validData):
         subsetIdsToDraw = self.haveSubsetData and dict(
             [(self.rawSubsetData[i].id, 1) for i in self.getValidSubsetIndices(indices)]) or {}
-        validData = self.getValidList(indices)
         mainCurves = {}
         subCurves = {}
         conditions = dict([(name, attributes.index(name)) for name in self.selectionConditions.keys()])
-
         for i in range(dataSize):
             if not validData[i]:
                 continue
                 curve.setCurveAttribute(QwtPlotCurve.Fitted)
             curve.attach(self)
 
-
-
-        # ############################################
-        # do we want to show distributions with discrete attributes
-        if self.showDistributions and self.dataHasDiscreteClass and self.haveData:
-            self.showDistributionValues(validData, indices)
-
-        # ############################################
-        # draw vertical lines that represent attributes
+    def draw_axes(self, attributes):
         for i in range(len(attributes)):
             self.addCurve("", lineWidth=2, style=QwtPlotCurve.Lines, symbol=QwtSymbol.NoSymbol, xData=[i, i],
                           yData=[0, 1])
                         self.addMarker(attrVals[pos], i + 0.01, float(1 + 2 * pos) / float(2 * valsLen),
                                        alignment=Qt.AlignRight | Qt.AlignVCenter, bold=1, brushColor=Qt.white)
 
-        # ##############################################
-        # show lines that represent standard deviation or quartiles
-        # ##############################################
-        if self.showStatistics and self.haveData:
-            data = []
-            for i in range(length):
-                if self.dataDomain[indices[i]].varType != orange.VarTypes.Continuous:
-                    data.append([()])
-                    continue  # only for continuous attributes
-                array = numpy.compress(numpy.equal(self.validDataArray[indices[i]], 1),
-                                       self.scaledData[indices[i]])  # remove missing values
+    def draw_midlabels(self, midLabels):
+        for j in range(len(midLabels)):
+            self.addMarker(midLabels[j], j + 0.5, 1.0, alignment=Qt.AlignCenter | Qt.AlignTop)
 
-                if not self.dataHasClass or self.dataHasContinuousClass:    # no class
+    def draw_statistics(self, indices, length):
+        data = []
+        for i in range(length):
+            if self.dataDomain[indices[i]].varType != orange.VarTypes.Continuous:
+                data.append([()])
+                continue  # only for continuous attributes
+            array = numpy.compress(numpy.equal(self.validDataArray[indices[i]], 1),
+                                   self.scaledData[indices[i]])  # remove missing values
+
+            if not self.dataHasClass or self.dataHasContinuousClass:    # no class
+                if self.showStatistics == MEANS:
+                    m = array.mean()
+                    dev = array.std()
+                    data.append([(m - dev, m, m + dev)])
+                elif self.showStatistics == MEDIAN:
+                    sorted = numpy.sort(array)
+                    if len(sorted) > 0:
+                        data.append([(sorted[int(len(sorted) / 4.0)], sorted[int(len(sorted) / 2.0)],
+                                      sorted[int(len(sorted) * 0.75)])])
+                    else:
+                        data.append([(0, 0, 0)])
+            else:
+                curr = []
+                classValues = getVariableValuesSorted(self.dataDomain.classVar)
+                classValueIndices = getVariableValueIndices(self.dataDomain.classVar)
+                for c in range(len(classValues)):
+                    scaledVal = ((classValueIndices[classValues[c]] * 2) + 1) / float(2 * len(classValueIndices))
+                    nonMissingValues = numpy.compress(numpy.equal(self.validDataArray[indices[i]], 1),
+                                                      self.noJitteringScaledData[
+                                                          self.dataClassIndex])  # remove missing values
+                    arr_c = numpy.compress(numpy.equal(nonMissingValues, scaledVal), array)
+                    if len(arr_c) == 0:
+                        curr.append((0, 0, 0));
+                        continue
                     if self.showStatistics == MEANS:
-                        m = array.mean()
-                        dev = array.std()
-                        data.append([(m - dev, m, m + dev)])
+                        m = arr_c.mean()
+                        dev = arr_c.std()
+                        curr.append((m - dev, m, m + dev))
                     elif self.showStatistics == MEDIAN:
-                        sorted = numpy.sort(array)
-                        if len(sorted) > 0:
-                            data.append([(sorted[int(len(sorted) / 4.0)], sorted[int(len(sorted) / 2.0)],
-                                          sorted[int(len(sorted) * 0.75)])])
-                        else:
-                            data.append([(0, 0, 0)])
-                else:
-                    curr = []
-                    classValues = getVariableValuesSorted(self.dataDomain.classVar)
-                    classValueIndices = getVariableValueIndices(self.dataDomain.classVar)
-                    for c in range(len(classValues)):
-                        scaledVal = ((classValueIndices[classValues[c]] * 2) + 1) / float(2 * len(classValueIndices))
-                        nonMissingValues = numpy.compress(numpy.equal(self.validDataArray[indices[i]], 1),
-                                                          self.noJitteringScaledData[
-                                                              self.dataClassIndex])  # remove missing values
-                        arr_c = numpy.compress(numpy.equal(nonMissingValues, scaledVal), array)
-                        if len(arr_c) == 0:
-                            curr.append((0, 0, 0));
-                            continue
-                        if self.showStatistics == MEANS:
-                            m = arr_c.mean()
-                            dev = arr_c.std()
-                            curr.append((m - dev, m, m + dev))
-                        elif self.showStatistics == MEDIAN:
-                            sorted = numpy.sort(arr_c)
-                            curr.append((sorted[int(len(arr_c) / 4.0)], sorted[int(len(arr_c) / 2.0)],
-                                         sorted[int(len(arr_c) * 0.75)]))
-                    data.append(curr)
+                        sorted = numpy.sort(arr_c)
+                        curr.append((
+                            sorted[int(len(arr_c) / 4.0)], sorted[int(len(arr_c) / 2.0)],
+                            sorted[int(len(arr_c) * 0.75)]))
+                data.append(curr)
 
-            # draw vertical lines
-            for i in range(len(data)):
-                for c in range(len(data[i])):
-                    if data[i][c] == (): continue
-                    x = i - 0.03 * (len(data[i]) - 1) / 2.0 + c * 0.03
-                    col = QColor(self.discPalette[c])
-                    col.setAlpha(self.alphaValue2)
-                    self.addCurve("", col, col, 3, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=[x, x, x],
-                                  yData=[data[i][c][0], data[i][c][1], data[i][c][2]], lineWidth=4)
-                    self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=[x - 0.03, x + 0.03],
-                                  yData=[data[i][c][0], data[i][c][0]], lineWidth=4)
-                    self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=[x - 0.03, x + 0.03],
-                                  yData=[data[i][c][1], data[i][c][1]], lineWidth=4)
-                    self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=[x - 0.03, x + 0.03],
-                                  yData=[data[i][c][2], data[i][c][2]], lineWidth=4)
-
-            # draw lines with mean/median values
-            classCount = 1
-            if not self.dataHasClass or self.dataHasContinuousClass:
-                classCount = 1 # no class
-            else:
-                classCount = len(self.dataDomain.classVar.values)
-            for c in range(classCount):
-                diff = - 0.03 * (classCount - 1) / 2.0 + c * 0.03
-                ys = []
-                xs = []
-                for i in range(len(data)):
-                    if data[i] != [()]:
-                        ys.append(data[i][c][1]); xs.append(i + diff)
-                    else:
-                        if len(xs) > 1:
-                            col = QColor(self.discPalette[c])
-                            col.setAlpha(self.alphaValue2)
-                            self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=xs, yData=ys,
-                                          lineWidth=4)
-                        xs = [];
-                        ys = []
+        # draw vertical lines
+        for i in range(len(data)):
+            for c in range(len(data[i])):
+                if data[i][c] == (): continue
+                x = i - 0.03 * (len(data[i]) - 1) / 2.0 + c * 0.03
                 col = QColor(self.discPalette[c])
                 col.setAlpha(self.alphaValue2)
-                self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=xs, yData=ys, lineWidth=4)
+                self.addCurve("", col, col, 3, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=[x, x, x],
+                              yData=[data[i][c][0], data[i][c][1], data[i][c][2]], lineWidth=4)
+                self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=[x - 0.03, x + 0.03],
+                              yData=[data[i][c][0], data[i][c][0]], lineWidth=4)
+                self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=[x - 0.03, x + 0.03],
+                              yData=[data[i][c][1], data[i][c][1]], lineWidth=4)
+                self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=[x - 0.03, x + 0.03],
+                              yData=[data[i][c][2], data[i][c][2]], lineWidth=4)
 
+        # draw lines with mean/median values
+        classCount = 1
+        if not self.dataHasClass or self.dataHasContinuousClass:
+            classCount = 1 # no class
+        else:
+            classCount = len(self.dataDomain.classVar.values)
+        for c in range(classCount):
+            diff = - 0.03 * (classCount - 1) / 2.0 + c * 0.03
+            ys = []
+            xs = []
+            for i in range(len(data)):
+                if data[i] != [()]:
+                    ys.append(data[i][c][1]);
+                    xs.append(i + diff)
+                else:
+                    if len(xs) > 1:
+                        col = QColor(self.discPalette[c])
+                        col.setAlpha(self.alphaValue2)
+                        self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=xs, yData=ys,
+                                      lineWidth=4)
+                    xs = [];
+                    ys = []
+            col = QColor(self.discPalette[c])
+            col.setAlpha(self.alphaValue2)
+            self.addCurve("", col, col, 1, QwtPlotCurve.Lines, QwtSymbol.NoSymbol, xData=xs, yData=ys, lineWidth=4)
 
-        # ##################################################
-        # show labels in the middle of the axis
-        if midLabels:
-            for j in range(len(midLabels)):
-                self.addMarker(midLabels[j], j + 0.5, 1.0, alignment=Qt.AlignCenter | Qt.AlignTop)
+    def draw_legend(self, attributes):
+        if self.dataDomain.classVar.varType == orange.VarTypes.Discrete:
+            legendKeys = []
+            varValues = getVariableValuesSorted(self.dataDomain.classVar)
+            #self.addCurve("<b>" + self.dataDomain.classVar.name + ":</b>", QColor(0,0,0), QColor(0,0,0), 0, symbol = QwtSymbol.NoSymbol, enableLegend = 1)
+            for ind in range(len(varValues)):
+                #self.addCurve(varValues[ind], self.discPalette[ind], self.discPalette[ind], 15, symbol = QwtSymbol.Rect, enableLegend = 1)
+                legendKeys.append((varValues[ind], self.discPalette[ind]))
+            if legendKeys != self.oldLegendKeys:
+                self.oldLegendKeys = legendKeys
+                self.legend().clear()
+                self.addCurve("<b>" + self.dataDomain.classVar.name + ":</b>", QColor(0, 0, 0), QColor(0, 0, 0), 0,
+                              symbol=QwtSymbol.NoSymbol, enableLegend=1)
+                for (name, color) in legendKeys:
+                    self.addCurve(name, color, color, 15, symbol=QwtSymbol.Rect, enableLegend=1)
+        else:
+            l = len(attributes) - 1
+            xs = [l * 1.15, l * 1.20, l * 1.20, l * 1.15]
+            count = 200;
+            height = 1 / 200.
+            for i in range(count):
+                y = i / float(count)
+                col = self.contPalette[y]
+                curve = PolygonCurve(QPen(col), QBrush(col), xData=xs, yData=[y, y, y + height, y + height])
+                curve.attach(self)
 
-        # show the legend
-        if self.enabledLegend == 1 and self.dataHasDiscreteClass:
-            if self.dataDomain.classVar.varType == orange.VarTypes.Discrete:
-                legendKeys = []
-                varValues = getVariableValuesSorted(self.dataDomain.classVar)
-                #self.addCurve("<b>" + self.dataDomain.classVar.name + ":</b>", QColor(0,0,0), QColor(0,0,0), 0, symbol = QwtSymbol.NoSymbol, enableLegend = 1)
-                for ind in range(len(varValues)):
-                    #self.addCurve(varValues[ind], self.discPalette[ind], self.discPalette[ind], 15, symbol = QwtSymbol.Rect, enableLegend = 1)
-                    legendKeys.append((varValues[ind], self.discPalette[ind]))
-                if legendKeys != self.oldLegendKeys:
-                    self.oldLegendKeys = legendKeys
-                    self.legend().clear()
-                    self.addCurve("<b>" + self.dataDomain.classVar.name + ":</b>", QColor(0, 0, 0), QColor(0, 0, 0), 0,
-                                  symbol=QwtSymbol.NoSymbol, enableLegend=1)
-                    for (name, color) in legendKeys:
-                        self.addCurve(name, color, color, 15, symbol=QwtSymbol.Rect, enableLegend=1)
-            else:
-                l = len(attributes) - 1
-                xs = [l * 1.15, l * 1.20, l * 1.20, l * 1.15]
-                count = 200;
-                height = 1 / 200.
-                for i in range(count):
-                    y = i / float(count)
-                    col = self.contPalette[y]
-                    curve = PolygonCurve(QPen(col), QBrush(col), xData=xs, yData=[y, y, y + height, y + height])
-                    curve.attach(self)
-
-                # add markers for min and max value of color attribute
-                [minVal, maxVal] = self.attrValues[self.dataDomain.classVar.name]
-                decimals = self.dataDomain.classVar.numberOfDecimals
-                self.addMarker("%%.%df" % (decimals) % (minVal), xs[0] - l * 0.02, 0.04, Qt.AlignLeft)
-                self.addMarker("%%.%df" % (decimals) % (maxVal), xs[0] - l * 0.02, 1.0 - 0.04, Qt.AlignLeft)
-        else:
-            self.legend().clear()
-            self.oldLegendKeys = []
-
-        self.replot()
-
+            # add markers for min and max value of color attribute
+            [minVal, maxVal] = self.attrValues[self.dataDomain.classVar.name]
+            decimals = self.dataDomain.classVar.numberOfDecimals
+            self.addMarker("%%.%df" % (decimals) % (minVal), xs[0] - l * 0.02, 0.04, Qt.AlignLeft)
+            self.addMarker("%%.%df" % (decimals) % (maxVal), xs[0] - l * 0.02, 1.0 - 0.04, Qt.AlignLeft)
 
     # ##########################################
     # SHOW DISTRIBUTION BAR GRAPH
-    def showDistributionValues(self, validData, indices):
+    def draw_distributions(self, validData, indices):
         # create color table
         clsCount = len(self.dataDomain.classVar.values)
         #if clsCount < 1: clsCount = 1.0
                     height = 0.7 / float(clsCount * attrLen)
 
                     yLowBott = yOff + float(clsCount * height) / 2.0 - i * height
-                    curve = PolygonCurve(QPen(newColor), QBrush(newColor),
-                                         xData=[graphAttrIndex, graphAttrIndex + width, graphAttrIndex + width,
-                                                graphAttrIndex],
-                                         yData=[yLowBott, yLowBott, yLowBott - height, yLowBott - height], tooltip=(
-                        self.dataDomain[index].name, variableValueSorted[j], len(self.rawData),
-                        [(clsVal, attrValCont[clsVal]) for clsVal in classValueSorted]))
+                    xData = [graphAttrIndex, graphAttrIndex + width, graphAttrIndex + width, graphAttrIndex]
+                    yData = [yLowBott, yLowBott, yLowBott - height, yLowBott - height]
+                    tooltip = (self.dataDomain[index].name, variableValueSorted[j], len(self.rawData),
+                               [(clsVal, attrValCont[clsVal]) for clsVal in classValueSorted])
+                    curve = PolygonCurve(QPen(newColor), QBrush(newColor), xData, yData, tooltip)
                     curve.attach(self)
 
 
                 if attr.varType == orange.VarTypes.Continuous:
                     condition = self.selectionConditions.get(attr.name, [0, 1])
                     val = self.attrValues[attr.name][0] + condition[pos] * (
-                    self.attrValues[attr.name][1] - self.attrValues[attr.name][0])
+                        self.attrValues[attr.name][1] - self.attrValues[attr.name][0])
                     strVal = attr.name + "= %%.%df" % (attr.numberOfDecimals) % (val)
                     QToolTip.showText(ev.globalPos(), strVal)
             else:
                         count = sum([v[1] for v in dist])
                         if count == 0: continue
                         tooltipText = "Attribute: <b>%s</b><br>Value: <b>%s</b><br>Total instances: <b>%i</b> (%.1f%%)<br>Class distribution:<br>" % (
-                        name, value, count, 100.0 * count / float(total))
+                            name, value, count, 100.0 * count / float(total))
                         for (val, n) in dist:
                             tooltipText += "&nbsp; &nbsp; <b>%s</b> : <b>%i</b> (%.1f%%)<br>" % (
-                            val, n, 100.0 * float(n) / float(count))
+                                val, n, 100.0 * float(n) / float(count))
                         QToolTip.showText(ev.globalPos(), tooltipText[:-4])
 
         elif ev.type() == QEvent.MouseMove:
 
             if attr.varType == orange.VarTypes.Continuous:
                 val = self.attrValues[attr.name][0] + oldCondition[pos] * (
-                self.attrValues[attr.name][1] - self.attrValues[attr.name][0])
+                    self.attrValues[attr.name][1] - self.attrValues[attr.name][0])
                 strVal = attr.name + "= %%.%df" % (attr.numberOfDecimals) % (val)
                 QToolTip.showText(e.globalPos(), strVal)
             if self.sendSelectionOnUpdate and self.autoSendSelectionCallback: