Commits

Jernej Kos committed a53cc12

Improved model explanation plot.

  • Participants
  • Parent commits e753d67

Comments (0)

Files changed (1)

File explainer.py

   """
   plt.clf()
   fig = plt.figure(figsize = (8, 12))
-  height = 0.95 / len(result)
+  height = 0.97 / len(result)
   result = result[::-1]
   for idx, (feature, data) in enumerate(result):
     maxc = np.max(data['means']) + 0.1
     subplot = plt.axes((0.1, 0.05 + idx*height, 0.85, height - 0.05))
     subplot.set_title(feature)
+    subplot.plot(data['values'], data['means'], 'bo')
+    subplot.axhline(y = 0.0, color = 'black', linewidth = 1, linestyle = '--')
     subplot.set_ylim(bottom = -maxc, top = maxc)
-    subplot.plot(data['values'], data['means'], 'o')
+    subplot.set_xlim(left = -0.02, right = 1.02)
+    
+    subplot = subplot.twinx()
+    subplot.set_axis_off()
+    subplot.plot(data['values'], data['stds'], 'go')
+    subplot.set_ylim(bottom = -1.0, top = 1.0)
+    subplot.set_xlim(left = -0.02, right = 1.02)
   
   plt.savefig(filename)
   plt.close(fig)