sd_boostedtrees_flu.py

  1"""
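Factor mapping of the flu case with boosted trees.

Loads stored flu experiments, fits an AdaBoost classifier of depth-3 decision
trees to whether the last value of the ``deceased_population_region 1``
outcome exceeds 1,000,000, and draws a 5x5 grid for the five most important
features: pairwise factor maps below the diagonal and per-class histograms
on the diagonal.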
  3"""
  4
  5import matplotlib.pyplot as plt
  6import numpy as np
  7import seaborn as sns
  8from matplotlib.collections import CircleCollection
  9from sklearn.ensemble import AdaBoostClassifier
 10from sklearn.tree import DecisionTreeClassifier
 11
 12from ema_workbench import load_results, ema_logging
 13from ema_workbench.analysis import feature_scoring
 14
 15ema_logging.log_to_stderr(ema_logging.INFO)
 16
 17
 18def plot_factormap(x1, x2, ax, bdt, nominal):
 19    """helper function for plotting a 2d factor map"""
 20    x_min, x_max = x[:, x1].min(), x[:, x1].max()
 21    y_min, y_max = x[:, x2].min(), x[:, x2].max()
 22    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 500), np.linspace(y_min, y_max, 500))
 23
 24    grid = np.ones((xx.ravel().shape[0], x.shape[1])) * nominal
 25    grid[:, x1] = xx.ravel()
 26    grid[:, x2] = yy.ravel()
 27
 28    Z = bdt.predict(grid)
 29    Z = Z.reshape(xx.shape)
 30
 31    ax.contourf(xx, yy, Z, cmap=plt.cm.Paired, alpha=0.5)  # @UndefinedVariable
 32
 33    for i in (0, 1):
 34        idx = y == i
 35        ax.scatter(x[idx, x1], x[idx, x2], s=5)
 36    ax.set_xlabel(columns[x1])
 37    ax.set_ylabel(columns[x2])
 38
 39
 40def plot_diag(x1, ax):
 41    x_min, x_max = x[:, x1].min(), x[:, x1].max()
 42    for i in (0, 1):
 43        idx = y == i
 44        ax.hist(x[idx, x1], range=(x_min, x_max), alpha=0.5)
 45
 46
 47# load data
 48experiments, outcomes = load_results("./data/1000 flu cases with policies.tar.gz")
 49
 50# transform to numpy array with proper recoding of cateogorical variables
 51x, columns = feature_scoring._prepare_experiments(experiments)
 52y = outcomes["deceased_population_region 1"][:, -1] > 1000000
 53
 54# establish mean case for factor maps
 55# this is questionable in particular in case of categorical dimensions
 56minima = x.min(axis=0)
 57maxima = x.max(axis=0)
 58nominal = minima + (maxima - minima) / 2
 59
 60# fit the boosted tree
 61bdt = AdaBoostClassifier(DecisionTreeClassifier(max_depth=3), algorithm="SAMME", n_estimators=200)
 62bdt.fit(x, y)
 63
 64# determine which dimensions are most important
 65sorted_indices = np.argsort(bdt.feature_importances_)[::-1]
 66
 67# do the actual plotting
 68# this is a quick hack, tying it to seaborn Pairgrid is probably
 69# the more elegant solution, but is tricky with what arguments
 70# can be passed to the plotting function
 71fig, axes = plt.subplots(ncols=5, nrows=5, figsize=(15, 15))
 72
 73for i, row in enumerate(axes):
 74    for j, ax in enumerate(row):
 75        if i > j:
 76            plot_factormap(sorted_indices[j], sorted_indices[i], ax, bdt, nominal)
 77        elif i == j:
 78            plot_diag(sorted_indices[j], ax)
 79        else:
 80            ax.set_xticks([])
 81            ax.set_yticks([])
 82            ax.axis("off")
 83
 84        if j > 0:
 85            ax.set_yticklabels([])
 86            ax.set_ylabel("")
 87        if i < len(axes) - 1:
 88            ax.set_xticklabels([])
 89            ax.set_xlabel("")
 90
 91# add the legend
 92# Draw a full-figure legend outside the grid
 93handles = [
 94    CircleCollection([10], color=sns.color_palette()[0]),
 95    CircleCollection([10], color=sns.color_palette()[1]),
 96]
 97
 98legend = fig.legend(handles, ["False", "True"], scatterpoints=1)
 99
100plt.tight_layout()
101plt.show()