Commits

Jernej Kos committed 91c1441

Added first implementation of ml-explainer functions.

Comments (0)

Files changed (4)

+syntax: glob
+*.pyc
+Requires development versions of numpy, scipy and matplotlib.
+from pandas import *
+import numpy as np
+from sklearn import svm
+import explainer
+
+a1 = np.random.choice([0, 1], 1001, replace = True)
+a2 = np.random.choice([0, 1], 1001, replace = True)
+a3 = np.random.choice([0, 1], 1001, replace = True)
+a4 = np.random.choice([0, 1, 2, 3], 1001, replace = True)
+a5 = np.random.choice([0, 1], 1001, replace = True)
+c = ((a1 ^ a2) + a4) | a5
+
+df = DataFrame(dict(a1 = a1, a2 = a2, a3 = a3, a4 = a4, a5 = a5, c = c))
+learn = df[:1000]
+test = df[1000:]
+
+mdl = svm.SVC(probability = True)
+mdl.fit(learn[['a1', 'a2', 'a3', 'a4', 'a5']], learn['c'])
+
+print test
+result = explainer.explain_instance(mdl, learn, test)
+print result
+explainer.plot_instance_explanation(result)
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+def explain_instance(model, data, instance, iterations = 200):
+  """
+  Explains a dataset instance using the specified model. The returned
+  explanation is a dictionary containing contribution of each individual
+  feature and the model's prediction.
+
+  NOTE: It currently only works for discrete features!
+
+  :param model: Scikit-learn model instance with enabled probability
+    estimation
+  :param data: Dataset used for estimating feature range
+  :param instance: Instance to explain
+  :param iterations: Number of iterations in Monte Carlo simulation
+  """
+  cls_typ = instance.dtypes[-1].type
+  features = list(data.columns[:-1])
+  instance = instance[features]
+  prediction = cls_typ(model.predict(instance)[0])
+  p_index = sorted(data[data.columns[-1]].unique()).index(prediction)
+  data = data[features]
+
+  nFeatures = len(data.columns)
+  explanation = []
+  for feature in data.columns:
+    contribution = 0.0
+    for j in xrange(iterations):
+      perm = np.random.choice(data.columns, nFeatures, replace = False)
+      tmp = instance.copy()
+      idx = 0
+      while perm[idx] != feature:
+        tmp[perm[idx]] = select_random_value(data[perm[idx]])
+        idx += 1
+
+      tmp2 = tmp.copy()
+      tmp2[perm[idx]] = select_random_value(data[perm[idx]])
+      contribution += model.predict_proba(tmp)[0][p_index] - model.predict_proba(tmp2)[0][p_index]
+
+    contribution /= iterations
+    explanation.append((feature, contribution))
+
+  return dict(explanation = explanation, prediction = prediction)
+
+def select_random_value(feature):
+  """
+  Selects a random value from a feature's range.
+
+  :param feature: DataFrame describing feature's values
+  """
+  if feature.dtype.kind == 'i':
+    return np.random.choice(feature.unique(), 1)
+ 
+  raise TypeError, "Unsupported feature type!"
+
+def plot_instance_explanation(result, filename = "output.png"):
+  """
+  Plots an instance explanation generated by `explain_instance`.
+
+  :param result: Explanation generated by `explain_instance`
+  :param filename: File where the output visualization should be
+    saved to
+  """
+  plt.figure(1, figsize = (6, 4))
+  left_bar = plt.axes((0.1, 0.1, 0.85, 0.7))
+  result['explanation'] = result['explanation'][::-1]
+  N = len(result['explanation'])
+  contributions = [contribution for _, contribution in result['explanation']]
+  maxc = max(contributions) + 0.1
+
+  left_bar.barh(np.arange(N) + 0.3, contributions, height = 0.55)
+  
+  left_bar.hlines(np.arange(N) + 0.05, -maxc, maxc, linestyles = 'dashed')
+  left_bar.axvline(x = 0.0, color = 'black', linewidth = 2)
+
+  left_bar.set_yticks(np.arange(len(result['explanation'])) + 0.5)
+  left_bar.set_yticklabels([feature for feature, _ in result['explanation']])
+  left_bar.set_xlim(left = -maxc, right = maxc)
+  for t in left_bar.get_xticklines(): t.set_marker(None)
+  for t in left_bar.get_yticklines(): t.set_marker(None)
+  left_bar.set_frame_on(False)
+
+  plt.savefig(filename)
+
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.