Jernej Kos committed e753d67

Some fixes in model explanation.


Files changed (1)

 import matplotlib.pyplot as plt
 import matplotlib.patches as mpatches
 
-def explain_instance(model, data, instance, iterations = 200):
+def explain_instance(model, data, instance, iterations = 300):
   """
   Explains a dataset instance using the specified model. The returned
   explanation is a dictionary containing contribution of each individual
   return dict(explanation = explanation, prediction = prediction, actual = actual,
     model_name = model.__class__.__name__)
 
-def explain_value(model, data, feature, value, iterations = 200):
+def explain_value(model, data, feature, value, iterations = 300):
   """
   Explains a single value of a single feature. The returned explanation
   is a dictionary containing the contribution mean and standard deviation.
     # Use first instance to get attribute format
     instance1 = data[0:1]
     # Replace all attributes with random values
-    for ifeature in data.columns[:-1]:
+    for ifeature in data.columns:
       instance1[ifeature] = select_random_value(data[ifeature])
     # Make another instance and replace the chosen feature with a
     # pre-selected value
     instance2 = instance1.copy()
     instance2[feature] = value
-    # Compute the predicted class
-    prediction = cls_typ(model.predict(np.asarray(instance2))[0])
-    p_index = sorted(orig_data[orig_data.columns[-1]].unique()).index(prediction)
     # Append contribution
-    contribs.append(
-      model.predict_proba(instance2)[0][p_index] - \
-      model.predict_proba(instance1)[0][p_index]
-    )
+    contribs.append(model.predict_proba(instance2)[0][0] - \
+      model.predict_proba(instance1)[0][0])
   
   return dict(mean = np.mean(contribs), std = np.std(contribs))
 
-def explain_discrete_model(model, data, iterations = 200):
+def explain_discrete_model(model, data, resolution = 50, iterations = 300):
   """
   Explains the complete model (all values of all features).
   
   :param model: Scikit-learn model instance with enabled probability
     estimation
   :param data: Dataset used for estimating feature range
+  :param resolution: Resolution for continuous attributes
   :param iterations: Number of iterations in Monte Carlo simulation
   """
   explanation = []
   for feature in data.columns[:-1]:
-    values, means, stds = [], [], []
-    for value in sorted(data[feature].unique()):
-      e = explain_value(model, data, feature, value)
-      values.append(value)
+    means, stds = [], []
+    
+    if data[feature].dtype.kind == 'i':
+      values = sorted(data[feature].unique())
+    elif data[feature].dtype.kind == 'f':
+      values = np.linspace(0, 1, resolution)
+    
+    for value in values:
+      e = explain_value(model, data, feature, value, iterations)
       means.append(e['mean'])
       stds.append(e['std'])
     
     bar_ax.text(maxc, 0.5 + idx, "%.2f" % contribution, horizontalalignment = 'right',
       bbox = dict(facecolor='red', alpha = 0.5, pad = 10.0))
   
+    value = round(value, 2) if isinstance(value, float) else value
     bar_ax.text(-maxc + maxc/15., 0.5 + idx, "%s" % value, horizontalalignment = 'left',
       bbox = dict(facecolor='green', alpha = 0.5, pad = 15.0))
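
For readers new to the method, the following self-contained sketch illustrates the Monte Carlo contribution estimate that explain_value performs: draw random instances, pin one feature to the chosen value, and average the resulting shift in the predicted probability. The dataset, classifier and the contribution helper below are illustrative assumptions, not code from this repository; the helper only mirrors the simplified form in this commit, which tracks the probability of the first class.

import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()
X = pd.DataFrame(iris.data, columns=iris.feature_names)
model = RandomForestClassifier(n_estimators=50, random_state=0)
model.fit(X, iris.target)

def contribution(model, X, feature, value, iterations=300):
    # Hypothetical helper, analogous in spirit to explain_value.
    rng = np.random.default_rng(0)
    contribs = []
    for _ in range(iterations):
        # Build a random instance by sampling each feature independently
        # from its observed values.
        instance1 = pd.DataFrame({c: [rng.choice(X[c].values)] for c in X.columns})
        # Copy it and pin the chosen feature to the pre-selected value.
        instance2 = instance1.copy()
        instance2[feature] = value
        # Contribution: change in the predicted probability of the first class.
        contribs.append(model.predict_proba(instance2)[0][0]
                        - model.predict_proba(instance1)[0][0])
    return np.mean(contribs), np.std(contribs)

mean, std = contribution(model, X, 'petal width (cm)', 0.2)
print("contribution of petal width = 0.2: %.3f +/- %.3f" % (mean, std))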
 
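The new resolution parameter only affects continuous features: integer-typed columns keep their observed unique values, while float-typed columns are swept over an evenly spaced grid on [0, 1], which assumes those features have been scaled to that range beforehand. A small sketch of that value-grid selection (the column names are illustrative; the committed function additionally skips the last column, which holds the class label):

import numpy as np
import pandas as pd

data = pd.DataFrame({
    'age_group': [1, 2, 2, 3],       # int64 column -> treated as discrete
    'height': [0.1, 0.4, 0.7, 0.9],  # float64 column -> swept on a [0, 1] grid
})

resolution = 50
for feature in data.columns:
    if data[feature].dtype.kind == 'i':
        values = sorted(data[feature].unique())   # -> [1, 2, 3]
    elif data[feature].dtype.kind == 'f':
        values = np.linspace(0, 1, resolution)    # -> 50 evenly spaced points
    print(feature, len(values))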