Commits

Christoph Dann committed 58a03b9

confidence interval plot

Comments (0)

Files changed (1)

                                        fancybox=True, shadow=True, ncol=2)
         return fig
 
+    def plot_avg_confint95(self, x, y, pad_x=False, pad_y=False, xbars=False, ybars=True,
+                     colors=None, markers=None, xerror_every=1,
+                     legend=True, **kwargs):
+        """
+        plots quantity y over x (means and standard error of the mean).
+        The quantities are specified by their id strings,
+        i.e. "return" or "learning steps"
+
+        pad_x, pad_y: if not enough observations are present for some results,
+                should they be filled with the value of the last available obervation?
+        xbars, ybars: show standard error of the mean for the respective quantity
+        colors: dictionary which maps experiment keys to colors
+        markers: dictionary which maps experiment keys to markers
+        xerror_exery: show horizontal error bars only every .. observation
+        legend: show legend below plot
+
+        return the figure handle of the created plot
+        """
+        style = {"linewidth": 2, "alpha": .7, "linestyle": "-", "markersize": 7,
+                 }
+        if colors is None:
+            colors = dict([(l, default_colors[i % len(default_colors)]) for i, l in enumerate(self.data.keys())])
+        if markers is None:
+            markers = dict([(l, default_markers[i % len(default_markers)]) for i, l in enumerate(self.data.keys())])
+        style.update(kwargs)
+        min_ = np.inf
+        max_ = - np.inf
+        fig = plt.figure()
+        for label, results in self.data.items():
+            style["color"] = colors[label]
+            style["marker"] = markers[label]
+            y_mean, y_std, y_num = avg_quantity(results, y, pad_y)
+            y_sem = y_std / np.sqrt(y_num)
+            x_mean, x_std, x_num = avg_quantity(results, x, pad_x)
+            x_sem = x_std / np.sqrt(x_num)
 
+            if xbars:
+                plt.errorbar(x_mean, y_mean, xerr=x_sem, label=label,
+                             ecolor="k", errorevery=xerror_every, **style)
+            else:
+                plt.plot(x_mean, y_mean, label=label, **style)
+
+            if ybars:
+                plt.fill_between(x_mean, y_mean - y_sem*1.96, y_mean + y_sem*1.96,
+                                 alpha=.3, color=style["color"])
+                max_ = max(np.max(y_mean + y_sem), max_)
+                min_ = min(np.min(y_mean - y_sem), min_)
+            else:
+                max_ = max(y_mean.max(), max_)
+                min_ = min(y_mean.min(), min_)
+
+        # adjust visible space
+        y_lim = [min_-.1*abs(max_-min_),max_+.1*abs(max_-min_)]
+        if min_ != max_:
+            plt.ylim(y_lim)
+
+        # axis labels
+        xlabel = default_labels[x] if x in default_labels else x
+        ylabel = default_labels[y] if y in default_labels else y
+        plt.xlabel(xlabel, fontsize=16)
+        plt.ylabel(ylabel, fontsize=16)
+
+        if legend:
+            box = plt.gca().get_position()
+            plt.gca().set_position([box.x0, box.y0 + box.height * 0.2,
+                                    box.width, box.height * 0.8])
+            legend_handle = plt.legend(loc='upper center',
+                                       bbox_to_anchor=(0.5, -0.15),
+                                       fancybox=True, shadow=True, ncol=2)
+        return fig
 def save_figure(figure, filename):
     figure.savefig(filename, transparent=True, pad_inches=.1, bbox_inches='tight')