Machine learning and statistical learning

SHAP: an example with the Titanic dataset

Author

Ewen Gallic

In this exercise, the aim is to build a statistical model that predicts whether a passenger of the Titanic survived or not. The data come from http://biostat.mc.vanderbilt.edu/wiki/pub/Main/DataSets/titanic.html. As written on this website:

The titanic and titanic2 data frames describe the survival status of individual passengers on the Titanic. The titanic data frame does not contain information from the crew, but it does contain actual ages of half of the passengers. The principal source for data about Titanic passengers is the Encyclopedia Titanica. The datasets used here were begun by a variety of researchers. One of the original sources is Eaton & Haas (1994) Titanic: Triumph and Tragedy, Patrick Stephens Ltd, which includes a passenger list created by many researchers and edited by Michael A. Findlay.

From this source, let us focus on two samples: a training sample (titanic_train.csv) and a test sample (titanic_test.csv).

Both samples come from the same source. On the first one, we will train a model to predict the probability of surviving. On the second one, we will test the predictive capacities of the model. Once a model is validated, we will use SHAP to explain, for each individual, the deviation of the prediction from the average predicted probability of surviving.

Variable Description Type
pclass Passenger Class (1 = 1st; 2 = 2nd; 3 = 3rd) Double
survived Survival (0 = No; 1 = Yes) Double
name Name of the passenger Character
sex Gender of the passenger Character
age Age of the passenger Double
sibsp Number of siblings / Spouses aboard Double
parch Number of parents / Children aboard Double
ticket Ticket Number Character
fare Passenger fare (in British pounds) Double
cabin Cabin Character
embarked Port of Embarkation (C=Cherbourg, Q=Queenstown, S=Southampton) Character
boat Lifeboat Character
body Body Identification Number Double
home.dest Home / Destination Character

Note: when the variable age is estimated, its value ends in .5. Fares are expressed in pre-1970 British pounds.

First, let us import pandas and numpy:

import pandas as pd
import numpy as np

Let us import the train and test datasets:

titanic_train = pd.read_csv("titanic_train.csv")
titanic_test = pd.read_csv("titanic_test.csv")

A quick glance at the training set:

titanic_train
pclass survived name sex age sibsp parch ticket fare cabin embarked boat body home.dest
0 3 0 Chronopoulos, Mr. Demetrios male 18.0 1 0 2680 14.4542 NaN C NaN NaN Greece
1 3 0 Barry, Miss. Julia female 27.0 0 0 330844 7.8792 NaN Q NaN NaN New York, NY
2 3 0 Bengtsson, Mr. John Viktor male 26.0 0 0 347068 7.7750 NaN S NaN NaN Krakudden, Sweden Moune, IL
3 1 0 Giglio, Mr. Victor male 24.0 0 0 PC 17593 79.2000 B86 C NaN NaN NaN
4 2 0 Pernot, Mr. Rene male NaN 0 0 SC/PARIS 2131 15.0500 NaN C NaN NaN NaN
... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
1042 2 1 Louch, Mrs. Charles Alexander (Alice Adelaide ... female 42.0 1 0 SC/AH 3085 26.0000 NaN S NaN NaN Weston-Super-Mare, Somerset
1043 3 0 Burke, Mr. Jeremiah male 19.0 0 0 365222 6.7500 NaN Q NaN NaN Co Cork, Ireland Charlestown, MA
1044 2 0 Cotterill, Mr. Henry "Harry" male 21.0 0 0 29107 11.5000 NaN S NaN NaN Penzance, Cornwall / Akron, OH
1045 3 1 Riordan, Miss. Johanna "Hannah" female NaN 0 0 334915 7.7208 NaN Q 13 NaN NaN
1046 3 0 Hampe, Mr. Leon male 20.0 0 0 345769 9.5000 NaN S NaN NaN NaN

1047 rows × 14 columns

print("Nb obs Train: ", titanic_train.shape)
print("Nb obs Test: ", titanic_test.shape)
Nb obs Train:  (1047, 14)
Nb obs Test:  (262, 14)

Pre-processing of the data (1/2)

As we need to transform the type of some columns, let us merge all the individuals into a single table, keeping track of whether each individual belongs to the training or the test sample.

titanic_train["sample"] = "train"
titanic_test["sample"] = "test"
titanic = pd.concat([titanic_train, titanic_test], ignore_index=True)
titanic.shape
(1309, 15)

Let us create an ID:

titanic["id"] = titanic.index + 1

We could change the target variable to a categorical (factor) type; the conversion is left commented out below, since we keep the numeric 0/1 coding for the models.

#titanic["survived"] = pd.Categorical(titanic["survived"])
#titanic["survived"] = titanic["survived"].cat.rename_categories({0:"Died", 1:"Survived"})

Similarly, pclass, sex and embarked could be turned into categorical variables; these conversions are also left commented out (the categorical encoding is handled later with dummy variables).

#titanic["pclass"] = pd.Categorical(titanic["pclass"])
#titanic["pclass"] = titanic["pclass"].cat.rename_categories({3:"3rd", 2:"2nd", 1:"1st"})

#titanic["sex"] = pd.Categorical(titanic["sex"])

#titanic["embarked"] = pd.Categorical(titanic["embarked"])
#titanic["embarked"] = titanic["embarked"].cat.rename_categories({"C":"Cherbourg", "Q":"Queenstown", "S":"Southampton"})
titanic
pclass survived name sex age sibsp parch ticket fare cabin embarked boat body home.dest sample id
0 3 0 Chronopoulos, Mr. Demetrios male 18.0 1 0 2680 14.4542 NaN C NaN NaN Greece train 1
1 3 0 Barry, Miss. Julia female 27.0 0 0 330844 7.8792 NaN Q NaN NaN New York, NY train 2
2 3 0 Bengtsson, Mr. John Viktor male 26.0 0 0 347068 7.7750 NaN S NaN NaN Krakudden, Sweden Moune, IL train 3
3 1 0 Giglio, Mr. Victor male 24.0 0 0 PC 17593 79.2000 B86 C NaN NaN NaN train 4
4 2 0 Pernot, Mr. Rene male NaN 0 0 SC/PARIS 2131 15.0500 NaN C NaN NaN NaN train 5
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
1304 3 0 Windelov, Mr. Einar male 21.0 0 0 SOTON/OQ 3101317 7.2500 NaN S NaN NaN NaN test 1305
1305 3 0 Wittevrongel, Mr. Camille male 36.0 0 0 345771 9.5000 NaN S NaN NaN NaN test 1306
1306 3 1 Yasbeck, Mrs. Antoni (Selini Alexander) female 15.0 1 0 2659 14.4542 NaN C NaN NaN NaN test 1307
1307 3 0 Zabour, Miss. Hileni female 14.5 1 0 2665 14.4542 NaN C NaN 328.0 NaN test 1308
1308 3 0 Zakarian, Mr. Mapriededer male 26.5 0 0 2656 7.2250 NaN C NaN 304.0 NaN test 1309

1309 rows × 16 columns

Descriptive statistics

The variable survived is coded as 0 if the passenger died and 1 if they survived. The percentages of survivors in the training and the test samples are about 38% and 36%, respectively.

import researchpy as rp
rp.summary_cat(titanic["survived"])
Variable Outcome Count Percent
0 survived 0 809 61.8
1 1 500 38.2

Number of observations for the target variable in the train and in the test sets:

rp.crosstab(titanic["survived"], titanic["sample"])
sample
sample test train All
survived
0 167 642 809
1 95 405 500
All 262 1047 1309

The corresponding proportions:

rp.crosstab(titanic["survived"], titanic["sample"], prop = "col")
sample
sample test train All
survived
0 63.74 61.32 61.8
1 36.26 38.68 38.2
All 100.00 100.00 100.0

Let us have a look at the count of observations for each category of the three following variables: pclass, sex, and embarked:

for v in ["pclass", "sex", "embarked"]:
    print("--------------------\n")
    print("Variable: ", v,)
    display(rp.summary_cat(titanic[v]))
--------------------

Variable:  pclass
Variable Outcome Count Percent
0 pclass 3 709 54.16
1 1 323 24.68
2 2 277 21.16
--------------------

Variable:  sex
Variable Outcome Count Percent
0 sex male 843 64.4
1 female 466 35.6
--------------------

Variable:  embarked
Variable Outcome Count Percent
0 embarked S 914 69.93
1 C 270 20.66
2 Q 123 9.41

Let us have a look at the distribution of the class of the passengers. This variable is a proxy for economic status, with class 1 representing, on average, relatively wealthier individuals.

import matplotlib.pyplot as plt
import seaborn as sns

In a Jupyter notebook, to render the figures as vector (SVG) graphics:

%config InlineBackend.figure_format = 'svg'

Then the graph:

plt.figure(figsize = (8,4))
ax = sns.countplot(x="pclass", data=titanic, color="grey")
ax.set_title('Distribution of Passenger Class')
ax.set_xlabel('Passenger Class')
ax.set_ylabel('Count')
Text(0, 0.5, 'Count')

Let us display the survival rate among each of these categorical variables:

plt.figure(figsize = (8,4))
ax = sns.barplot(x='pclass', y='survived', data=titanic, color="grey")
ax.set_title('Survival rate by passenger class')
ax.set_xlabel('Passenger Class')
ax.set_ylabel('Survival Rate')
Text(0, 0.5, 'Survival Rate')

Same for gender:

plt.figure(figsize = (8,4))
ax = sns.countplot(x="sex", data=titanic, color="grey", order = ["female", "male"])
ax.set_title('Distribution of Passenger Gender')
ax.set_xlabel('Passenger gender')
ax.set_ylabel('Count')
Text(0, 0.5, 'Count')

plt.figure(figsize = (8,4))
ax = sns.barplot(x='sex', y='survived', data=titanic, color="grey", order = ["female", "male"])
ax.set_title('Survival rate by passenger Gender')
ax.set_xlabel('Passenger gender')
ax.set_ylabel('Survival Rate')
Text(0, 0.5, 'Survival Rate')

And for port of embarkation:

plt.figure(figsize = (8,4))
ax = sns.countplot(x="embarked", data=titanic, color="grey")
ax.set_title('Distribution of Passenger Embarkation')
ax.set_xlabel('Port of Embarkation')
ax.set_ylabel('Count')
Text(0, 0.5, 'Count')

plt.figure(figsize = (8,4))
ax = sns.barplot(x='embarked', y='survived', data=titanic, color="grey")
ax.set_title('Survival rate by passenger Embarkation')
ax.set_xlabel('Port of Embarkation')
ax.set_ylabel('Survival Rate')
Text(0, 0.5, 'Survival Rate')

Let us have a look at the age:

titanic[["age", "sample"]].apply(lambda x: min(x))
age       0.1667
sample      test
dtype: object
titanic.groupby("sample")["age"].apply(lambda x: {"min": x.min(), "max": x.max()})
sample     
test    min     0.7500
        max    76.0000
train   min     0.1667
        max    80.0000
Name: age, dtype: float64
titanic[["age", "sample"]].groupby("sample").apply(lambda x: {"min": x.min(), "max":x.max()})
sample
test        {'min': [0.75, 'test'], 'max': [76.0, 'test']}
train    {'min': [0.1667, 'train'], 'max': [80.0, 'trai...
dtype: object
plt.figure(figsize = (12,6))
ax = sns.histplot(x='age', data=titanic, color="grey")
ax.set_title('Age distribution')
ax.set_xlabel('Age')
ax.set_ylabel('Count')
Text(0, 0.5, 'Count')

We can construct age classes to visualize the survival rate across them.

conditions = [
    (titanic["age"].lt(15)),
    (titanic["age"].ge(15) & titanic["age"].lt(20)),
    (titanic["age"].ge(20) & titanic["age"].lt(25)),
    (titanic["age"].ge(25) & titanic["age"].lt(35)),
    (titanic["age"].ge(35) & titanic["age"].lt(45)),
    (titanic["age"].ge(45)),
]
choices = ["<15", "[15,19]", "[20,24]", "[25,34]", "[35,44]","> 44"]

titanic["age_class"] = np.select(conditions, choices)
rp.summary_cat(titanic["age_class"])
Variable Outcome Count Percent
0 age_class [25,34] 292 22.31
1 0 263 20.09
2 [20,24] 184 14.06
3 > 44 176 13.45
4 [35,44] 169 12.91
5 [15,19] 116 8.86
6 <15 109 8.33

Number of missing values (np.select assigns its default value, 0, stored here as the string "0", when age is missing):

sum(titanic["age_class"] == "0")
263

The survival rate depending on the age can then be plotted:

plt.figure(figsize = (8,4))
ax = sns.barplot(x='age_class', y='survived', data=titanic, color="grey", order = choices)
ax.set_title('Survival rate by passenger Age')
ax.set_xlabel('Age Class')
ax.set_ylabel('Survival Rate')
Text(0, 0.5, 'Survival Rate')

There seems to be a non-linear relationship between age and survival rate. Let us look at how it varies depending on the gender of the passenger:

plt.figure(figsize = (8,4))
ax = sns.barplot(x='age_class', y='survived', data=titanic, order = choices, hue="sex")
ax.set_title('Survival rate by passenger Age')
ax.set_xlabel('Age Class')
ax.set_ylabel('Survival Rate')
Text(0, 0.5, 'Survival Rate')

What about the number of siblings / spouses? First, we should note that the third quartile of the number of siblings or spouses aboard is equal to 1: most passengers were not traveling with a sibling or spouse.
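To check this (the computation is not shown in the original notebook), we can look at the quartiles of sibsp directly:

# Distribution of the number of siblings / spouses aboard
titanic["sibsp"].describe().round(2)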

plt.figure(figsize = (8,4))
ax = sns.boxplot(y="sibsp", x="survived", data=titanic)
ax.set_title('Number of siblings / spouses depending on the survival status')
ax.set_xlabel('Survived')
ax.set_ylabel('Number of siblings / spouses')
Text(0, 0.5, 'Number of siblings / spouses')

We can also look at the number of parents or children. Again, most of the passengers were not accompanied by any parent or child: the 80th percentile is equal to 1.

titanic["parch"].describe().round(2)
count    1309.00
mean        0.39
std         0.87
min         0.00
25%         0.00
50%         0.00
75%         0.00
max         9.00
Name: parch, dtype: float64
np.quantile(titanic["parch"], q=.8)
1.0
plt.figure(figsize = (8,4))
ax = sns.boxplot(y="parch", x="survived", data=titanic)
ax.set_title('Number of parents / children depending on the survival status')
ax.set_xlabel('Survived')
ax.set_ylabel('Number of parents / children')
Text(0, 0.5, 'Number of parents / children')

Train a model

Let us train a model to predict the probability of surviving on the Titanic. Let us consider first a simple decision tree.

df_train = titanic[["survived", "pclass", "sex", "age", "sibsp", "parch", "fare", "embarked", "sample"]]
df_train = df_train[df_train["sample"] == "train"]
df_train = df_train.drop("sample", axis = 1)

df_test = titanic[["survived", "pclass", "sex", "age", "sibsp", "parch", "fare", "embarked", "sample"]]
df_test = df_test[df_test["sample"] != "train"]
df_test = df_test.drop("sample", axis = 1)

Let us drop rows with NaN values (missing values should be handled more carefully in a real project).

df_train = df_train.dropna()
df_test = df_test.dropna()

A quick glance at the training set:

df_train
survived pclass sex age sibsp parch fare embarked
0 0 3 male 18.0 1 0 14.4542 C
1 0 3 female 27.0 0 0 7.8792 Q
2 0 3 male 26.0 0 0 7.7750 S
3 0 1 male 24.0 0 0 79.2000 C
5 0 3 female 39.0 0 5 29.1250 Q
... ... ... ... ... ... ... ... ...
1040 1 2 female 12.0 2 1 39.0000 S
1042 1 2 female 42.0 1 0 26.0000 S
1043 0 3 male 19.0 0 0 6.7500 Q
1044 0 2 male 21.0 0 0 11.5000 S
1046 0 3 male 20.0 0 0 9.5000 S

829 rows × 8 columns

And at the test set:

df_test
survived pclass sex age sibsp parch fare embarked
1047 1 1 male 37.0 1 1 52.5542 S
1048 0 1 male 25.0 0 0 26.0000 C
1049 1 1 male 28.0 0 0 26.5500 S
1050 0 1 male 45.0 0 0 35.5000 S
1051 1 1 female 30.0 0 0 164.8667 S
... ... ... ... ... ... ... ... ...
1304 0 3 male 21.0 0 0 7.2500 S
1305 0 3 male 36.0 0 0 9.5000 S
1306 1 3 female 15.0 1 0 14.4542 C
1307 0 3 female 14.5 1 0 14.4542 C
1308 0 3 male 26.5 0 0 7.2250 C

214 rows × 8 columns

Let us create dummy variables for the categorical variables. Note that since pclass was kept numeric (its conversion to a categorical type was left commented out above), get_dummies() will leave it as a single numeric column; only sex and embarked are one-hot encoded. A defensive sketch for aligning the train and test dummy columns is shown after the test-set output below.

categ_var = ["pclass", "sex", "embarked"]

On the training set:

df_train = pd.concat(
    [df_train.iloc[:,np.where(np.logical_not(df_train.columns.isin(categ_var)))[0]],
    pd.get_dummies(df_train[categ_var],drop_first=True)],
    axis = 1
)
display(df_train)
survived age sibsp parch fare pclass sex_male embarked_Q embarked_S
0 0 18.0 1 0 14.4542 3 1 0 0
1 0 27.0 0 0 7.8792 3 0 1 0
2 0 26.0 0 0 7.7750 3 1 0 1
3 0 24.0 0 0 79.2000 1 1 0 0
5 0 39.0 0 5 29.1250 3 0 1 0
... ... ... ... ... ... ... ... ... ...
1040 1 12.0 2 1 39.0000 2 0 0 1
1042 1 42.0 1 0 26.0000 2 0 0 1
1043 0 19.0 0 0 6.7500 3 1 1 0
1044 0 21.0 0 0 11.5000 2 1 0 1
1046 0 20.0 0 0 9.5000 3 1 0 1

829 rows × 9 columns

On the test set:

df_test = pd.concat(
    [df_test.iloc[:,np.where(np.logical_not(df_test.columns.isin(categ_var)))[0]],
    pd.get_dummies(df_test[categ_var],drop_first=True)],
    axis = 1
)
display(df_test)
survived age sibsp parch fare pclass sex_male embarked_Q embarked_S
1047 1 37.0 1 1 52.5542 1 1 0 1
1048 0 25.0 0 0 26.0000 1 1 0 0
1049 1 28.0 0 0 26.5500 1 1 0 1
1050 0 45.0 0 0 35.5000 1 1 0 1
1051 1 30.0 0 0 164.8667 1 0 0 1
... ... ... ... ... ... ... ... ... ...
1304 0 21.0 0 0 7.2500 3 1 0 1
1305 0 36.0 0 0 9.5000 3 1 0 1
1306 1 15.0 1 0 14.4542 3 0 0 0
1307 0 14.5 1 0 14.4542 3 0 0 0
1308 0 26.5 0 0 7.2250 3 1 0 0

214 rows × 9 columns
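Applying get_dummies() separately to the training and the test sets can produce mismatched columns when a category is absent from one of the samples; it works here only because both samples contain every category. A minimal defensive sketch (not in the original notebook; it is a no-op here since the columns already match) reindexes the test columns on the training ones:

# Align the test columns on the training columns; any missing dummy is filled with 0
df_test = df_test.reindex(columns=df_train.columns, fill_value=0)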

X_train = df_train
X_train = X_train.drop("survived", axis=1)
y_train = df_train["survived"]

X_test = df_test
X_test = X_test.drop("survived", axis=1)
y_test = df_test["survived"]

Now, let us build a decision tree to classify our observations. First, we can import the DecisionTreeClassifier class from the tree module of the sklearn library.

from sklearn.tree import DecisionTreeClassifier

Then we can train the decision tree:

# Decision Tree classifier
tree_classifier = DecisionTreeClassifier(max_depth=3)

# Let us train it
tree_classifier = tree_classifier.fit(X_train,y_train)

Predictions (out of sample):

y_pred = tree_classifier.predict(X_test)

What is the accuracy on the test set? Let us load the metrics module from the sklearn library.

from sklearn import metrics

The accuracy on the test set:

print("Accuracy: ",metrics.accuracy_score(y_test, y_pred))
Accuracy:  0.822429906542056

The confusion table:

metrics.confusion_matrix(y_test, y_pred)
array([[119,  15],
       [ 23,  57]])

Now let us visualize the decision tree. scikit-learn's plot_tree() only requires matplotlib; the graphviz and pydotplus libraries are only needed for the optional Graphviz export shown below. They can be installed with pip:

#pip install graphviz
#pip install pydotplus

Let us import the tree module from the sklearn library.

from sklearn import tree

Then we can use the plot_tree() function from the tree module to plot the decision tree:

plt.figure(figsize = (12,6))
tree.plot_tree(tree_classifier, filled = True,
              feature_names = X_train.columns,
              class_names = ["Died", "Survived"]);
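Alternatively, for a Graphviz rendering of the same tree, a possible sketch (assuming graphviz and pydotplus are installed; the output file name is arbitrary):

from sklearn.tree import export_graphviz
import pydotplus

# Export the fitted tree to DOT format, then render it as a PNG file
dot_data = export_graphviz(tree_classifier, out_file=None,
                           feature_names=list(X_train.columns),
                           class_names=["Died", "Survived"],
                           filled=True)
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_png("titanic_tree.png")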

Now, let us fit a random forest. Note that we do not perform any hyperparameter tuning here, although this should be done to improve the predictive capacities of the model. Let us import the RandomForestClassifier class from the ensemble module of the sklearn library.

from sklearn.ensemble import RandomForestClassifier

Then we can use that class to train a random forest on the training set. Let us set the minimum number of observations in terminal leaves to 10, and let us draw 3 candidate variables each time a split is attempted. First, let us set the hyperparameters:

random_forest_classifier = RandomForestClassifier(n_estimators = 1000, 
                                                  min_samples_leaf = 10, max_features=3)

The model can then be trained:

random_forest_classifier.fit(X_train, y_train)
RandomForestClassifier(max_features=3, min_samples_leaf=10, n_estimators=1000)

With this kind of model, there is no straightforward way to interpret the results. There is, however, a metric that gives the relative importance of the variables (it does not, however, distinguish between positive and negative effects).
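For illustration, a minimal sketch (not part of the original notebook) of how these importances could be inspected through the feature_importances_ attribute of the fitted forest:

# Relative importance of each feature in the fitted random forest
importances = pd.Series(random_forest_classifier.feature_importances_,
                        index=X_train.columns).sort_values(ascending=False)
print(importances.round(3))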

Assessing the predictive capacities of the model

Predictions on the training sample

predict_tree_train = tree_classifier.predict_proba(X_train)
predict_rf_train = random_forest_classifier.predict_proba(X_train)
The two columns returned by predict_proba() are:

  • first column: predicted probability of class 0 (died)
  • second column: predicted probability of class 1 (survived)

predict_tree_train[range(10)]
array([[0.85824742, 0.14175258],
       [0.46      , 0.54      ],
       [0.85824742, 0.14175258],
       [0.62037037, 0.37962963],
       [0.85      , 0.15      ],
       [0.85824742, 0.14175258],
       [0.85824742, 0.14175258],
       [0.85      , 0.15      ],
       [0.85824742, 0.14175258],
       [0.46      , 0.54      ]])

Predictions on the test sample

predict_tree_test = tree_classifier.predict_proba(X_test)
predict_rf_test =random_forest_classifier.predict_proba(X_test)

Let us set the probability threshold for the classifier, i.e., the value above which the class “Survived” is predicted.

tau = .5
np.where(predict_tree_train[:,1] >= tau, "Survived", "Died")
array(['Died', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died',
       'Died', 'Survived', 'Survived', 'Survived', 'Died', 'Survived',
       'Died', 'Survived', 'Died', 'Survived', 'Survived', 'Died', 'Died',
       'Survived', 'Died', 'Survived', 'Survived', 'Survived', 'Died',
       'Died', 'Died', 'Died', 'Survived', 'Died', 'Survived', 'Survived',
       'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Survived',
       'Survived', 'Survived', 'Survived', 'Died', 'Died', 'Died', 'Died',
       'Died', 'Survived', 'Survived', 'Died', 'Died', 'Died', 'Survived',
       'Survived', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Survived',
       'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Survived',
       'Died', 'Died', 'Survived', 'Died', 'Survived', 'Died', 'Died',
       'Died', 'Died', 'Survived', 'Survived', 'Died', 'Died', 'Died',
       'Survived', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Died',
       'Died', 'Died', 'Survived', 'Survived', 'Died', 'Died', 'Died',
       'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Survived',
       'Survived', 'Died', 'Died', 'Survived', 'Died', 'Died', 'Survived',
       'Died', 'Died', 'Died', 'Survived', 'Died', 'Died', 'Died', 'Died',
       'Survived', 'Survived', 'Survived', 'Died', 'Died', 'Died', 'Died',
       'Died', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Survived',
       'Died', 'Died', 'Died', 'Survived', 'Died', 'Died', 'Died', 'Died',
       'Survived', 'Died', 'Survived', 'Died', 'Died', 'Died', 'Survived',
       'Survived', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died',
       'Survived', 'Survived', 'Survived', 'Died', 'Died', 'Died',
       'Survived', 'Survived', 'Died', 'Died', 'Survived', 'Died',
       'Survived', 'Survived', 'Survived', 'Died', 'Survived', 'Survived',
       'Died', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died',
       'Died', 'Died', 'Survived', 'Survived', 'Survived', 'Died', 'Died',
       'Died', 'Survived', 'Died', 'Died', 'Died', 'Survived', 'Survived',
       'Died', 'Survived', 'Died', 'Died', 'Died', 'Survived', 'Died',
       'Survived', 'Died', 'Died', 'Survived', 'Died', 'Survived', 'Died',
       'Died', 'Died', 'Died', 'Died', 'Died', 'Survived', 'Died',
       'Survived', 'Survived', 'Survived', 'Died', 'Survived', 'Died',
       'Survived', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Survived',
       'Survived', 'Survived', 'Survived', 'Survived', 'Survived', 'Died',
       'Survived', 'Survived', 'Died', 'Survived', 'Died', 'Died', 'Died',
       'Died', 'Survived', 'Died', 'Survived', 'Survived', 'Survived',
       'Died', 'Died', 'Died', 'Died', 'Survived', 'Died', 'Died', 'Died',
       'Died', 'Survived', 'Survived', 'Died', 'Survived', 'Survived',
       'Survived', 'Died', 'Died', 'Died', 'Survived', 'Died', 'Survived',
       'Survived', 'Survived', 'Died', 'Survived', 'Died', 'Died', 'Died',
       'Survived', 'Died', 'Died', 'Survived', 'Died', 'Died', 'Died',
       'Died', 'Died', 'Died', 'Survived', 'Died', 'Died', 'Survived',
       'Survived', 'Died', 'Survived', 'Died', 'Survived', 'Died', 'Died',
       'Died', 'Survived', 'Survived', 'Died', 'Died', 'Died', 'Survived',
       'Died', 'Died', 'Died', 'Died', 'Died', 'Survived', 'Died', 'Died',
       'Survived', 'Died', 'Survived', 'Survived', 'Died', 'Died',
       'Survived', 'Survived', 'Died', 'Survived', 'Died', 'Died', 'Died',
       'Survived', 'Died', 'Died', 'Survived', 'Died', 'Died', 'Died',
       'Died', 'Died', 'Survived', 'Died', 'Survived', 'Died', 'Survived',
       'Died', 'Died', 'Died', 'Died', 'Survived', 'Died', 'Survived',
       'Died', 'Died', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Died',
       'Survived', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Survived',
       'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died',
       'Survived', 'Died', 'Survived', 'Survived', 'Survived', 'Died',
       'Survived', 'Died', 'Survived', 'Survived', 'Died', 'Survived',
       'Died', 'Died', 'Survived', 'Died', 'Survived', 'Survived',
       'Survived', 'Died', 'Died', 'Died', 'Survived', 'Survived',
       'Survived', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died',
       'Survived', 'Died', 'Died', 'Survived', 'Died', 'Died', 'Survived',
       'Died', 'Died', 'Died', 'Died', 'Survived', 'Survived', 'Died',
       'Survived', 'Died', 'Survived', 'Died', 'Survived', 'Died', 'Died',
       'Survived', 'Survived', 'Survived', 'Survived', 'Died', 'Died',
       'Died', 'Died', 'Survived', 'Survived', 'Died', 'Survived', 'Died',
       'Survived', 'Survived', 'Survived', 'Died', 'Survived', 'Survived',
       'Died', 'Survived', 'Died', 'Survived', 'Died', 'Survived',
       'Survived', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Survived',
       'Died', 'Died', 'Died', 'Died', 'Died', 'Survived', 'Survived',
       'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died',
       'Survived', 'Died', 'Survived', 'Died', 'Died', 'Died', 'Survived',
       'Survived', 'Survived', 'Survived', 'Died', 'Survived', 'Died',
       'Died', 'Survived', 'Survived', 'Died', 'Died', 'Survived', 'Died',
       'Died', 'Died', 'Died', 'Died', 'Died', 'Survived', 'Died', 'Died',
       'Died', 'Survived', 'Survived', 'Died', 'Survived', 'Died', 'Died',
       'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died',
       'Survived', 'Survived', 'Survived', 'Survived', 'Died', 'Died',
       'Survived', 'Survived', 'Survived', 'Died', 'Survived', 'Died',
       'Died', 'Survived', 'Died', 'Survived', 'Died', 'Died', 'Died',
       'Died', 'Survived', 'Died', 'Survived', 'Survived', 'Died',
       'Survived', 'Survived', 'Survived', 'Died', 'Died', 'Survived',
       'Died', 'Survived', 'Survived', 'Survived', 'Died', 'Survived',
       'Died', 'Survived', 'Died', 'Survived', 'Died', 'Survived', 'Died',
       'Died', 'Died', 'Died', 'Survived', 'Survived', 'Survived',
       'Survived', 'Survived', 'Died', 'Died', 'Died', 'Survived',
       'Survived', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Died',
       'Died', 'Died', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Died',
       'Died', 'Survived', 'Died', 'Survived', 'Survived', 'Died', 'Died',
       'Died', 'Died', 'Died', 'Survived', 'Died', 'Survived', 'Survived',
       'Died', 'Died', 'Survived', 'Died', 'Survived', 'Died', 'Died',
       'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died',
       'Died', 'Died', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Died',
       'Survived', 'Died', 'Survived', 'Died', 'Survived', 'Died', 'Died',
       'Died', 'Died', 'Survived', 'Died', 'Died', 'Died', 'Survived',
       'Died', 'Died', 'Died', 'Died', 'Survived', 'Survived', 'Survived',
       'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died',
       'Died', 'Survived', 'Survived', 'Survived', 'Died', 'Died', 'Died',
       'Died', 'Died', 'Died', 'Died', 'Survived', 'Died', 'Died', 'Died',
       'Died', 'Died', 'Survived', 'Survived', 'Survived', 'Died',
       'Survived', 'Survived', 'Died', 'Survived', 'Died', 'Died', 'Died',
       'Died', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died',
       'Died', 'Survived', 'Died', 'Survived', 'Survived', 'Survived',
       'Survived', 'Died', 'Survived', 'Survived', 'Survived', 'Died',
       'Died', 'Survived', 'Survived', 'Died', 'Survived', 'Survived',
       'Died', 'Died', 'Died', 'Died', 'Survived', 'Died', 'Died', 'Died',
       'Died', 'Survived', 'Died', 'Survived', 'Died', 'Died', 'Survived',
       'Died', 'Died', 'Died', 'Died', 'Survived', 'Survived', 'Died',
       'Died', 'Died', 'Survived', 'Survived', 'Died', 'Died', 'Survived',
       'Survived', 'Died', 'Survived', 'Survived', 'Died', 'Died', 'Died',
       'Died', 'Died', 'Survived', 'Died', 'Survived', 'Survived',
       'Survived', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Survived',
       'Survived', 'Survived', 'Died', 'Survived', 'Died', 'Died',
       'Survived', 'Survived', 'Survived', 'Survived', 'Died', 'Died',
       'Died', 'Survived', 'Died', 'Survived', 'Died', 'Survived', 'Died',
       'Died', 'Survived', 'Died', 'Survived', 'Died', 'Died', 'Died',
       'Survived', 'Survived', 'Survived', 'Died', 'Died', 'Survived',
       'Died', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Died',
       'Survived', 'Died', 'Died', 'Died', 'Died', 'Survived', 'Survived',
       'Survived', 'Died', 'Died', 'Died', 'Survived', 'Survived', 'Died',
       'Survived', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Died',
       'Survived', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Died',
       'Died', 'Died', 'Died', 'Died', 'Survived', 'Survived', 'Died',
       'Died', 'Died'], dtype='<U8')
y_pred_train_tree = np.where(predict_tree_train[:,1] >= tau, 1, 0)
y_pred_train_rf = np.where(predict_rf_train[:,1] >= tau, 1, 0)

y_pred_test_tree = np.where(predict_tree_test[:,1] >= tau, 1, 0)
y_pred_test_rf = np.where(predict_rf_test[:,1] >= tau, 1, 0)

Let us define a small function to get some goodness-of-fit metrics:

def metrics_fit(y_obs, y_pred):
    res = {
        "Accuracy": metrics.accuracy_score(y_obs, y_pred),
        "Kappa": metrics.cohen_kappa_score(y_obs, y_pred),
        "Sensitivity (True Pos Rate)": metrics.recall_score(y_obs, y_pred),
        "F1": metrics.f1_score(y_obs, y_pred)
    }
    return(res)

This user-defined function can be applied to assess the predictive power of the model on the train set, for the decision tree:

metrics_fit(y_train, y_pred_train_tree)
{'Accuracy': 0.8033775633293124,
 'Kappa': 0.5890510527132113,
 'Sensitivity (True Pos Rate)': 0.7101449275362319,
 'F1': 0.7503828483920367}

And for the random forest:

metrics_fit(y_train, y_pred_train_rf)
{'Accuracy': 0.8299155609167672,
 'Kappa': 0.6390231398986452,
 'Sensitivity (True Pos Rate)': 0.6898550724637681,
 'F1': 0.7714748784440844}

It can also be applied on the test set for the decision tree:

metrics_fit(y_test, y_pred_test_tree)
{'Accuracy': 0.822429906542056,
 'Kappa': 0.6129093678598629,
 'Sensitivity (True Pos Rate)': 0.7125,
 'F1': 0.75}

And for the random forest:

metrics_fit(y_test, y_pred_test_rf)
{'Accuracy': 0.822429906542056,
 'Kappa': 0.6047822706065318,
 'Sensitivity (True Pos Rate)': 0.6625,
 'F1': 0.736111111111111}

Explaining the predicted probabilities with SHAP

Now, let us turn to SHAP to explain, for a given individual, how the values of the explanatory variables drive the deviation of the predicted probability of surviving from the average prediction.

# pip install shap
import shap

Let us consider a single individual, Edith Eileen, who is characterized as follows (she actually survived and lived more than 100 years!):

example = titanic[titanic["name"].str.contains("Edith Eileen")]
ind_example = example.index[0]
ind_example
539
np.where(X_train.index == ind_example)[0]
array([432])
X_example = X_train[X_train.index.isin([ind_example])]
X_example
age sibsp parch fare pclass sex_male embarked_Q embarked_S
539 15.0 0 2 39.0 2 0 0 1

Let us build our predict function:

def my_predict_f_prob(x):
    pred_prob = random_forest_classifier.predict_proba(x)
    return(pred_prob)

The probability that our individual dies (first column) or survives (second column):

pred_example = my_predict_f_prob(X_example)
pred_example
array([[0.07657917, 0.92342083]])

Then let us build the explainer, with SHAP KernelExplainer. The function takes three arguments:

  • model: the model to be explained (or the prediction function)
  • data: background dataset
  • link: a generalized linear model link to connect the feature importance values to the model output (identity by default)

explainer = shap.KernelExplainer(model=my_predict_f_prob, data=X_train)
X does not have valid feature names, but RandomForestClassifier was fitted with feature names
Using 829 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.
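A quick check (this cell is not in the original notebook): the explainer stores base values in its expected_value attribute, i.e., the average of the predicted probabilities (died, survived) over the background data passed above.

# Average predicted probabilities over the background data, one per output column
explainer.expected_value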

And let us estimate the Shapley values for our training example:

shap_values_example = explainer.shap_values(X=X_example, nsamples = 100)
X does not have valid feature names, but RandomForestClassifier was fitted with feature names
X does not have valid feature names, but RandomForestClassifier was fitted with feature names
shap_values_example
[array([[-0.02516427, -0.01288884, -0.02583533, -0.01695874, -0.1218961 ,
         -0.30537386,  0.0003784 ,  0.00115635]]),
 array([[ 0.02516427,  0.01288884,  0.02583533,  0.01695874,  0.1218961 ,
          0.30537386, -0.0003784 , -0.00115635]])]

The Shapley values explain, for this specific individual, how each variable moves the model's prediction away from the average prediction.

shap_values_example[0].sum()
-0.5065823870461259

The average predicted probability of dying (the first column of predict_proba()) over the whole training sample:

avg_pred = my_predict_f_prob(X_train)
avg_pred = avg_pred[:,0].mean()
avg_pred
0.5831615566676123

We can see that the sum of these values is equal to the deviation of this individual's predicted probability of dying from the average predicted probability of dying in the whole training sample:

pred_example[0][0] - avg_pred
-0.5065823870461261
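As a quick sanity check (not in the original notebook), this local-accuracy property can be verified numerically:

# Sum of the class-0 Shapley values vs. deviation of the predicted
# probability of dying from its training-sample average
np.isclose(shap_values_example[0].sum(), pred_example[0][0] - avg_pred)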

The package offers some nice plots to explain a single example:

shap.initjs()
shap.force_plot(explainer.expected_value[0], shap_values_example[0], X_example)
shap.force_plot(explainer.expected_value[1], shap_values_example[1], X_example)

More options are available, see https://github.com/slundberg/shap
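For instance, a global summary plot could be sketched as follows (hypothetical, not run here: KernelExplainer is slow, so both the background data and the explained observations are restricted to small samples; with the shap version used in this notebook, shap_values() returns one array per output class):

# Summarize the background with 50 samples and explain 100 test observations
background = shap.sample(X_train, 50)
explainer_small = shap.KernelExplainer(model=my_predict_f_prob, data=background)
shap_values_test = explainer_small.shap_values(X=X_test.iloc[:100], nsamples=100)

# Beeswarm-style summary of the Shapley values for the "survived" output
shap.summary_plot(shap_values_test[1], X_test.iloc[:100])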

My heart will go on