Commits

Jernej Kos committed a93061e

Implemented basic plotting of model explanations.

Comments (0)

Files changed (1)

 
 def select_random_value(feature):
   """
-  Selects a random value from a feature's range.
+  Selects a random value from a feature's range. It assumes that
+  integer-typed features are discrete and float-typed features
+  are numeric.
 
   :param feature: DataFrame describing feature's values
   """
   if feature.dtype.kind == 'i':
     return np.random.choice(feature.unique(), 1)
+  elif feature.dtype.kind == 'f':
+    return np.random.rand() * (feature.max() - feature.min()) + feature.min()
  
   raise TypeError, "Unsupported feature type!"
 
     saved to
   """
   plt.clf()
-  plt.figure(1, figsize = (6, 4))
+  fig = plt.figure(figsize = (8, 6))
   bar_ax = plt.axes((0.1, 0.1, 0.85, 0.7))
   result['explanation'] = result['explanation'][::-1]
   N = len(result['explanation'])
       bbox = dict(facecolor='green', alpha = 0.5, pad = 15.0))
 
   plt.savefig(filename)
+  plt.close(fig)
 
 def plot_model_explanation(result, filename = "output.png"):
-  # TODO
-  pass
+  """
+  Plots an instance explanation generated by `explain_discrete_model`.
 
+  :param result: Explanation generated by `explain_discrete_model`
+  :param filename: File where the output visualization should be
+    saved to
+  """
+  plt.clf()
+  fig = plt.figure(figsize = (8, 12))
+  height = 0.95 / len(result)
+  result = result[::-1]
+  for idx, (feature, data) in enumerate(result):
+    maxc = np.max(data['means']) + 0.1
+    subplot = plt.axes((0.1, 0.05 + idx*height, 0.85, height - 0.05))
+    subplot.set_title(feature)
+    subplot.set_ylim(bottom = -maxc, top = maxc)
+    subplot.plot(data['values'], data['means'], 'o')
+  
+  plt.savefig(filename)
+  plt.close(fig)
+