1. Jernej Kos
  2. python-ml-explain

Commits

Jernej Kos  committed 0a05bd0

Improved plotting of instance explanations.

  • Participants
  • Parent commits 9e545c8
  • Branches default

Comments (0)

Files changed (1)

File explainer.py

View file
 import numpy as np
 import matplotlib.pyplot as plt
+import matplotlib.patches as mpatches
 
 def explain_instance(model, data, instance, iterations = 200):
   """
   """
   cls_typ = instance.dtypes[-1].type
   features = list(data.columns[:-1])
+  actual = np.asarray(instance[data.columns[-1]])[0]
   instance = instance[features]
   prediction = cls_typ(model.predict(np.asarray(instance))[0])
   p_index = sorted(data[data.columns[-1]].unique()).index(prediction)
       contribution += model.predict_proba(tmp)[0][p_index] - model.predict_proba(tmp2)[0][p_index]
 
     contribution /= iterations
-    explanation.append((feature, contribution))
+    explanation.append((feature, contribution, np.asarray(instance[feature])[0]))
 
-  return dict(explanation = explanation, prediction = prediction)
+  return dict(explanation = explanation, prediction = prediction, actual = actual,
+    model_name = model.__class__.__name__)
 
 def explain_value(model, data, feature, value, iterations = 200):
   """
   """
   plt.clf()
   plt.figure(1, figsize = (6, 4))
-  left_bar = plt.axes((0.1, 0.1, 0.85, 0.7))
+  bar_ax = plt.axes((0.1, 0.1, 0.85, 0.7))
   result['explanation'] = result['explanation'][::-1]
   N = len(result['explanation'])
-  contributions = [contribution for _, contribution in result['explanation']]
+  contributions = [contribution for _, contribution, _ in result['explanation']]
   maxc = max(contributions) + 0.1
 
-  left_bar.barh(np.arange(N) + 0.3, contributions, height = 0.55)
+  bar_ax.barh(np.arange(N) + 0.3, contributions, height = 0.55)
   
-  left_bar.hlines(np.arange(N) + 0.05, -maxc, maxc, linestyles = 'dashed')
-  left_bar.axvline(x = 0.0, color = 'black', linewidth = 2)
+  bar_ax.hlines(np.arange(N) + 0.05, -maxc, maxc, linestyles = 'dashed')
+  bar_ax.axvline(x = 0.0, color = 'black', linewidth = 2)
 
-  left_bar.set_yticks(np.arange(len(result['explanation'])) + 0.5)
-  left_bar.set_yticklabels([feature for feature, _ in result['explanation']])
-  left_bar.set_xlim(left = -maxc, right = maxc)
-  for t in left_bar.get_xticklines(): t.set_marker(None)
-  for t in left_bar.get_yticklines(): t.set_marker(None)
-  left_bar.set_frame_on(False)
+  bar_ax.set_yticks(np.arange(len(result['explanation'])) + 0.5)
+  bar_ax.set_yticklabels([feature for feature, _, _ in result['explanation']])
+  bar_ax.set_xlim(left = -maxc, right = maxc)
+  for t in bar_ax.get_xticklines(): t.set_marker(None)
+  for t in bar_ax.get_yticklines(): t.set_marker(None)
+  bar_ax.set_frame_on(False)
+
+  info_ax = plt.axes((0.1, 0.84, 0.85, 0.1))
+  info_ax.set_frame_on(False)
+  info_ax.set_axis_off()
+
+  p = mpatches.Rectangle((0.0, 0.01), 0.99, 0.99, facecolor = "lightblue",
+    edgecolor = "black", alpha = 0.5)
+  info_ax.add_patch(p)
+  
+  info_ax.text(0.01, 0.6, "Prediction: R = %s" % result['prediction'],
+    fontsize = 13, fontweight = 'bold')
+  info_ax.text(0.01, 0.2, "Actual value: R = %s" % result['actual'],
+    fontsize = 13, fontweight = 'bold')
+  info_ax.text(0.98, 0.6, result['model_name'], fontsize = 13,
+    horizontalalignment = 'right')
+  
+  for idx, (feature, contribution, value) in enumerate(result['explanation']):
+    bar_ax.text(maxc, 0.5 + idx, "%.2f" % contribution, horizontalalignment = 'right',
+      bbox = dict(facecolor='red', alpha = 0.5, pad = 10.0))
+  
+    bar_ax.text(-maxc + maxc/15., 0.5 + idx, "%s" % value, horizontalalignment = 'left',
+      bbox = dict(facecolor='green', alpha = 0.5, pad = 15.0))
 
   plt.savefig(filename)