Commits

Aleš Erjavec  committed 3995884

Using PyQwt5 for PCA scree plot.

  • Participants
  • Parent commits 86b1eb3

Comments (0)

Files changed (1)

File Orange/OrangeWidgets/Unsupervised/OWPCA.py

 <priority>3050</priority>
 
 """
-import Orange
-import Orange.utils.addons
+import sys
 
-from OWWidget import *
-import OWGUI
+import numpy as np
+
+from PyQt4.Qwt5 import QwtPlot, QwtPlotCurve, QwtSymbol
+from PyQt4.QtCore import pyqtSignal as Signal, pyqtSlot as Slot
 
 import Orange
 import Orange.projection.linear as plinear
 
-import numpy as np
-import sys
+from OWWidget import *
+from OWGraph import OWGraph
 
-from plot.owplot import OWPlot
-from plot.owcurve import OWCurve
-from plot import owaxis
+import OWGUI
 
 
-class ScreePlot(OWPlot):
-    def __init__(self, parent=None, name="Scree Plot"):
-        OWPlot.__init__(self, parent, name=name)
-        self.cutoff_curve = CutoffCurve([0.0, 0.0], [0.0, 1.0],
-                x_axis_key=owaxis.xBottom, y_axis_key=owaxis.yLeft)
-        self.cutoff_curve.setVisible(False)
-        self.cutoff_curve.set_style(OWCurve.Lines)
-        self.add_custom_curve(self.cutoff_curve)
+def plot_curve(title=None, pen=None, brush=None, style=QwtPlotCurve.Lines,
+               symbol=QwtSymbol.Ellipse, legend=True, antialias=True,
+               auto_scale=True, xaxis=QwtPlot.xBottom, yaxis=QwtPlot.yLeft):
+    curve = QwtPlotCurve(title or "")
+    return configure_curve(curve, pen=pen, brush=brush, style=style,
+                           symbol=symbol, legend=legend, antialias=antialias,
+                           auto_scale=auto_scale, xaxis=xaxis, yaxis=yaxis)
 
-    def is_cutoff_enabled(self):
-        return self.cutoff_curve and self.cutoff_curve.isVisible()
 
-    def set_cutoff_curve_enabled(self, state):
-        self.cutoff_curve.setVisible(state)
+def configure_curve(curve, title=None, pen=None, brush=None,
+          style=QwtPlotCurve.Lines, symbol=QwtSymbol.Ellipse,
+          legend=True, antialias=True, auto_scale=True,
+          xaxis=QwtPlot.xBottom, yaxis=QwtPlot.yLeft):
+    if title is not None:
+        curve.setTitle(title)
+    if pen is not None:
+        curve.setPen(pen)
 
-    def set_cutoff_value(self, value):
-        xmin, xmax = self.x_scale()
-        x = min(max(value, xmin), xmax)
-        self.cutoff_curve.set_data([x, x], [0.0, 1.0])
+    if brush is not None:
+        curve.setBrush(brush)
+
+    if not isinstance(symbol, QwtSymbol):
+        symbol_ = QwtSymbol()
+        symbol_.setStyle(symbol)
+        symbol = symbol_
+
+    curve.setStyle(style)
+    curve.setSymbol(QwtSymbol(symbol))
+    curve.setRenderHint(QwtPlotCurve.RenderAntialiased, antialias)
+    curve.setItemAttribute(QwtPlotCurve.Legend, legend)
+    curve.setItemAttribute(QwtPlotCurve.AutoScale, auto_scale)
+    curve.setAxis(xaxis, yaxis)
+    return curve
+
+
+class PlotTool(QObject):
+    """
+    A base class for Plot tools that operate on QwtPlot's canvas
+    widget by installing itself as its event filter.
+
+    """
+    cursor = Qt.ArrowCursor
+
+    def __init__(self, parent=None, graph=None):
+        QObject.__init__(self, parent)
+        self.__graph = None
+        self.__oldCursor = None
+        self.setGraph(graph)
+
+    def setGraph(self, graph):
+        """
+        Install this tool to operate on ``graph``.
+        """
+        if self.__graph is graph:
+            return
+
+        if self.__graph is not None:
+            self.uninstall(self.__graph)
+
+        self.__graph = graph
+
+        if graph is not None:
+            self.install(graph)
+
+    def graph(self):
+        return self.__graph
+
+    def install(self, graph):
+        canvas = graph.canvas()
+        canvas.setMouseTracking(True)
+        canvas.installEventFilter(self)
+        canvas.destroyed.connect(self.__on_destroyed)
+        self.__oldCursor = canvas.cursor()
+        canvas.setCursor(self.cursor)
+
+    def uninstall(self, graph):
+        canvas = graph.canvas()
+        canvas.removeEventFilter(self)
+        canvas.setCursor(self.__oldCursor)
+        canvas.destroyed.disconnect(self.__on_destroyed)
+        self.__oldCursor = None
+
+    def eventFilter(self, obj, event):
+        if obj is self.__graph.canvas():
+            return self.canvasEvent(event)
+        return False
+
+    def canvasEvent(self, event):
+        """
+        Main handler for a canvas events.
+        """
+        if event.type() == QEvent.MouseButtonPress:
+            return self.mousePressEvent(event)
+        elif event.type() == QEvent.MouseButtonRelease:
+            return self.mouseReleaseEvent(event)
+        elif event.type() == QEvent.MouseButtonDblClick:
+            return self.mouseDoubleClickEvent(event)
+        elif event.type() == QEvent.MouseMove:
+            return self.mouseMoveEvent(event)
+        elif event.type() == QEvent.Leave:
+            return self.leaveEvent(event)
+        elif event.type() == QEvent.Enter:
+            return self.enterEvent(event)
+        return False
+
+    # These are actually event filters (note the return values)
+    def mousePressEvent(self, event):
+        return False
+
+    def mouseMoveEvent(self, event):
+        return False
+
+    def mouseReleaseEvent(self, event):
+        return False
+
+    def mouseDoubleClickEvent(self, event):
+        return False
+
+    def enterEvent(self, event):
+        return False
+
+    def leaveEvent(self, event):
+        return False
+
+    def keyPressEvent(self, event):
+        return False
+
+    def transform(self, point, xaxis=QwtPlot.xBottom, yaxis=QwtPlot.yLeft):
+        """
+        Transform a QPointF from plot coordinates to canvas local coordinates.
+        """
+        x = self.__graph.transform(xaxis, point.x())
+        y = self.__graph.transform(yaxis, point.y())
+        return QPoint(x, y)
+
+    def invTransform(self, point, xaxis=QwtPlot.xBottom, yaxis=QwtPlot.yLeft):
+        """
+        Transform a QPoint from canvas local coordinates to plot coordinates.
+        """
+        x = self.__graph.invTransform(xaxis, point.x())
+        y = self.__graph.invTransform(yaxis, point.y())
+        return QPointF(x, y)
+
+    @Slot()
+    def __on_destroyed(self, obj):
+        obj.removeEventFilter(self)
+
+
+class CutoffControler(PlotTool):
+
+    class CutoffCurve(QwtPlotCurve):
+        pass
+
+    cutoffChanged = Signal(float)
+    cutoffMoved = Signal(float)
+    cutoffPressed = Signal()
+    cutoffReleased = Signal()
+
+    NoState, Drag = 0, 1
+
+    def __init__(self, parent=None, graph=None):
+        self.__curve = None
+        self.__range = (0, 1)
+        self.__cutoff = 0
+        super(CutoffControler, self).__init__(parent, graph)
+        self._state = self.NoState
+
+    def install(self, graph):
+        super(CutoffControler, self).install(graph)
+        assert self.__curve is None
+        self.__curve = CutoffControler.CutoffCurve("")
+        configure_curve(self.__curve, symbol=QwtSymbol.NoSymbol, legend=False)
+        self.__curve.setData([self.__cutoff, self.__cutoff], [0.0, 1.0])
+        self.__curve.attach(graph)
+
+    def uninstall(self, graph):
+        super(CutoffControler, self).uninstall(graph)
+        self.__curve.detach()
+        self.__curve = None
+
+    def _toRange(self, value):
+        minval, maxval = self.__range
+        return max(min(value, maxval), minval)
 
     def mousePressEvent(self, event):
-        if self.isLegendEvent(event, QGraphicsView.mousePressEvent):
-            return
-
-        if self.is_cutoff_enabled() and event.buttons() & Qt.LeftButton:
-            pos = self.mapToScene(event.pos())
-            x, _ = self.map_from_graph(pos)
-            xmin, xmax = self.x_scale()
-            if x >= xmin - 0.1 and x <= xmax + 0.1:
-                x = min(max(x, xmin), xmax)
-                self.cutoff_curve.set_data([x, x], [0.0, 1.0])
-                self.emit_cutoff_moved(x)
-        return QGraphicsView.mousePressEvent(self, event)
+        if event.button() == Qt.LeftButton:
+            cut = self.invTransform(event.pos()).x()
+            self.setCutoff(cut)
+            self.cutoffPressed.emit()
+            self._state = self.Drag
+        return True
 
     def mouseMoveEvent(self, event):
-        if self.isLegendEvent(event, QGraphicsView.mouseMoveEvent):
-            return
+        if self._state == self.Drag:
+            cut = self._toRange(self.invTransform(event.pos()).x())
+            self.setCutoff(cut)
+            self.cutoffMoved.emit(cut)
+        else:
+            cx = self.transform(QPointF(self.cutoff(), 0)).x()
+            if abs(cx - event.pos().x()) < 2:
+                self.graph().canvas().setCursor(Qt.SizeHorCursor)
+            else:
+                self.graph().canvas().setCursor(self.cursor)
+        return True
 
-        if self.is_cutoff_enabled() and event.buttons() & Qt.LeftButton:
-            pos = self.mapToScene(event.pos())
-            x, _ = self.map_from_graph(pos)
-            xmin, xmax = self.x_scale()
-            if x >= xmin - 0.5 and x <= xmax + 0.5:
-                x = min(max(x, xmin), xmax)
-                self.cutoff_curve.set_data([x, x], [0.0, 1.0])
-                self.emit_cutoff_moved(x)
-        elif self.is_cutoff_enabled() and \
-                self.is_pos_over_cutoff_line(event.pos()):
-            self.setCursor(Qt.SizeHorCursor)
-        else:
-            self.setCursor(Qt.ArrowCursor)
+    def mouseReleaseEvent(self, event):
+        if event.button() == Qt.LeftButton and self._state == self.Drag:
+            cut = self._toRange(self.invTransform(event.pos()).x())
+            self.setCutoff(cut)
+            self.cutoffReleased.emit()
+            self._state = self.NoState
+        return True
 
-        return QGraphicsView.mouseMoveEvent(self, event)
+    def setCutoff(self, cutoff):
+        minval, maxval = self.__range
+        cutoff = max(min(cutoff, maxval), minval)
+        if self.__cutoff != cutoff:
+            self.__cutoff = cutoff
+            if self.__curve is not None:
+                self.__curve.setData([cutoff, cutoff], [0.0, 1.0])
+            self.cutoffChanged.emit(cutoff)
+            if self.graph() is not None:
+                self.graph().replot()
 
-    def mouseReleaseEvene(self, event):
-        return QGraphicsView.mouseReleaseEvent(self, event)
+    def cutoff(self):
+        return self.__cutoff
 
-    def x_scale(self):
-        ax = self.axes[owaxis.xBottom]
-        if ax.labels:
-            return 0, len(ax.labels) - 1
-        elif ax.scale:
-            return ax.scale[0], ax.scale[1]
-        else:
-            raise ValueError
+    def setRange(self, minval, maxval):
+        maxval = max(minval, maxval)
+        if self.__range != (minval, maxval):
+            self.__range = (minval, maxval)
+            self.setCutoff(max(min(self.cutoff(), maxval), minval))
 
-    def emit_cutoff_moved(self, x):
-        self.emit(SIGNAL("cutoff_moved(double)"), x)
 
-    def set_axis_labels(self, *args):
-        OWPlot.set_axis_labels(self, *args)
-        self.map_transform = self.transform_for_axes()
+class Graph(OWGraph):
+    def __init__(self, *args, **kwargs):
+        super(Graph, self).__init__(*args, **kwargs)
+        self.gridCurve.attach(self)
 
-    def is_pos_over_cutoff_line(self, pos):
-        x1 = self.inv_transform(owaxis.xBottom, pos.x() - 1.5)
-        x2 = self.inv_transform(owaxis.xBottom, pos.x() + 1.5)
-        y = self.inv_transform(owaxis.yLeft, pos.y())
-        if y < 0.0 or y > 1.0:
-            return False
-        curve_data = self.cutoff_curve.data()
-        if not curve_data:
-            return False
-        cutoff = curve_data[0][0]
-        return x1 < cutoff and cutoff < x2
+    # bypass the OWGraph event handlers
+    def mousePressEvent(self, event):
+        QwtPlot.mousePressEvent(self, event)
 
+    def mouseMoveEvent(self, event):
+        QwtPlot.mouseMoveEvent(self, event)
 
-class CutoffCurve(OWCurve):
-    def __init__(self, *args, **kwargs):
-        OWCurve.__init__(self, *args, **kwargs)
-        self.setAcceptHoverEvents(True)
-        self.setCursor(Qt.SizeHorCursor)
+    def mouseReleaseEvent(self, event):
+        QwtPlot.mouseReleaseEvent(self, event)
 
 
 class OWPCA(OWWidget):
                          callback=self.update_components)
         OWGUI.setStopper(self, b, cb, "changed_flag", self.update_components)
 
-        self.scree_plot = ScreePlot(self)
-#        self.scree_plot.set_main_title("Scree Plot")
-#        self.scree_plot.set_show_main_title(True)
-        self.scree_plot.set_axis_title(owaxis.xBottom, "Principal Components")
-        self.scree_plot.set_show_axis_title(owaxis.xBottom, 1)
-        self.scree_plot.set_axis_title(owaxis.yLeft, "Proportion of Variance")
-        self.scree_plot.set_show_axis_title(owaxis.yLeft, 1)
+        self.plot = Graph()
+        canvas = self.plot.canvas()
+        canvas.setFrameStyle(QFrame.StyledPanel)
+        self.mainArea.layout().addWidget(self.plot)
+        self.plot.setAxisTitle(QwtPlot.yLeft, "Proportion of Variance")
+        self.plot.setAxisTitle(QwtPlot.xBottom, "Principal Components")
+        self.plot.setAxisScale(QwtPlot.yLeft, 0.0, 1.0)
+        self.plot.enableGridXB(True)
+        self.plot.enableGridYL(True)
+        self.plot.setGridColor(Qt.lightGray)
 
-        self.variance_curve = self.scree_plot.add_curve(
-                        "Variance",
-                        Qt.red, Qt.red, 2,
-                        xData=[],
-                        yData=[],
-                        style=OWCurve.Lines,
-                        enableLegend=True,
-                        lineWidth=2,
-                        autoScale=1,
-                        x_axis_key=owaxis.xBottom,
-                        y_axis_key=owaxis.yLeft,
-                        )
+        self.variance_curve = plot_curve(
+            "Variance",
+            pen=QPen(Qt.red, 2),
+            symbol=QwtSymbol.NoSymbol,
+            xaxis=QwtPlot.xBottom,
+            yaxis=QwtPlot.yLeft
+        )
+        self.cumulative_variance_curve = plot_curve(
+            "Cumulative Variance",
+            pen=QPen(Qt.darkYellow, 2),
+            symbol=QwtSymbol.NoSymbol,
+            xaxis=QwtPlot.xBottom,
+            yaxis=QwtPlot.yLeft
+        )
 
-        self.cumulative_variance_curve = self.scree_plot.add_curve(
-                        "Cumulative Variance",
-                        Qt.darkYellow, Qt.darkYellow, 2,
-                        xData=[],
-                        yData=[],
-                        style=OWCurve.Lines,
-                        enableLegend=True,
-                        lineWidth=2,
-                        autoScale=1,
-                        x_axis_key=owaxis.xBottom,
-                        y_axis_key=owaxis.yLeft,
-                        )
+        self.variance_curve.attach(self.plot)
+        self.cumulative_variance_curve.attach(self.plot)
 
-        self.mainArea.layout().addWidget(self.scree_plot)
-        self.connect(self.scree_plot,
-                     SIGNAL("cutoff_moved(double)"),
-                     self.on_cutoff_moved
-                     )
+        self.selection_tool = CutoffControler(parent=self.plot.canvas())
+        self.selection_tool.cutoffMoved.connect(self.on_cutoff_moved)
 
-        self.connect(self.graphButton,
-                     SIGNAL("clicked()"),
-                     self.scree_plot.save_to_file)
-
+        self.graphButton.clicked.connect(self.saveToFile)
         self.components = None
         self.variances = None
         self.variances_sum = None
         self.resize(800, 400)
 
     def clear(self):
-        """Clear widget state
+        """
+        Clear (reset) the widget state.
         """
         self.data = None
-        self.scree_plot.set_cutoff_curve_enabled(False)
+        self.selection_tool.setGraph(None)
         self.clear_cached()
         self.variance_curve.setVisible(False)
         self.cumulative_variance_curve.setVisible(False)
 
         """
         pca = self.construct_pca_all_comp()
-        self.projector_full = projector = pca(self.data)
+        self.projector_full = pca(self.data)
 
         self.variances = self.projector_full.variances
         self.variances /= np.sum(self.variances)
         components = self.projector_full.projection
         input_domain = self.projector_full.input_domain
         variances = self.projector_full.variances
-        variance_sum = self.projector_full.variance_sum
 
         # Get selected components (based on max_components and
         # variance_coverd)
 
     def update_scree_plot(self):
         x_space = np.arange(0, len(self.variances))
-        self.scree_plot.set_axis_enabled(owaxis.xBottom, True)
-        self.scree_plot.set_axis_enabled(owaxis.yLeft, True)
-        self.scree_plot.set_axis_labels(owaxis.xBottom,
-                                        ["PC" + str(i + 1) for i in x_space])
+        self.plot.enableAxis(QwtPlot.xBottom, True)
+        self.plot.enableAxis(QwtPlot.yLeft, True)
+        if len(x_space) <= 5:
+            self.plot.setXlabels(["PC" + str(i + 1) for i in x_space])
+        else:
+            # Restore continuous plot scale
+            # TODO: disable minor ticks
+            self.plot.setXlabels(None)
 
-        self.variance_curve.set_data(x_space, self.variances)
-        self.cumulative_variance_curve.set_data(x_space, self.variances_cumsum)
+        self.variance_curve.setData(x_space, self.variances)
+        self.cumulative_variance_curve.setData(x_space, self.variances_cumsum)
         self.variance_curve.setVisible(True)
         self.cumulative_variance_curve.setVisible(True)
 
-        self.scree_plot.set_cutoff_curve_enabled(True)
-        self.scree_plot.replot()
+        self.selection_tool.setRange(0, len(self.variances) - 1)
+        self.selection_tool.setGraph(self.plot)
+        self.plot.replot()
 
     def on_cutoff_moved(self, value):
         """Cutoff curve was moved by the user.
 
         variance = self.variances_cumsum[max_components - 1] * 100.0
         if variance < self.variance_covered:
-            cutoff = float(max_components - 1)
+            cutoff = max_components - 1
         else:
             cutoff = np.searchsorted(self.variances_cumsum,
                                      self.variance_covered / 100.0)
-        self.scree_plot.set_cutoff_value(cutoff + 0.5)
+
+        self.selection_tool.setCutoff(float(cutoff + 0.5))
 
     def number_of_selected_components(self):
         """How many components are selected.
             self.reportRaw(summary)
 
             self.reportSection("Scree Plot")
-            self.reportImage(self.scree_plot.save_to_file_direct)
+            self.reportImage(self.plot.saveToFileDirect)
+
+    def saveToFile(self):
+        self.plot.saveToFile()
 
 
 def append_metas(dest, source):
     data = Orange.data.Table("iris")
     w.set_data(data)
     w.show()
+    w.set_data(Orange.data.Table("brown-selected"))
     app.exec_()