Machine learning and statistical learning
SHAP: an example with the Titanic dataset
In this exercise, the aim is to build a statistical model that predicts whether a passenger of the Titanic survived or not. The data come from http://biostat.mc.vanderbilt.edu/wiki/pub/Main/DataSets/titanic.html. As written on this website:
The titanic and titanic2 data frames describe the survival status of individual passengers on the Titanic. The titanic data frame does not contain information from the crew, but it does contain actual ages of half of the passengers. The principal source for data about Titanic passengers is the Encyclopedia Titanica. The datasets used here were begun by a variety of researchers. One of the original sources is Eaton & Haas (1994) Titanic: Triumph and Tragedy, Patrick Stephens Ltd, which includes a passenger list created by many researchers and edited by Michael A. Findlay.
From this source, let us focus on two samples:
- titanic_train: a dataset containing 1,047 rows;
- titanic_test: a dataset containing 262 rows.
Both samples come from the same source. On the first one, we will train a model to predict the probability of surviving. On the second one, we will test the predictive capacities of the model. Once a model is validated, we will use SHAP to explain, for each individual, the deviation of the prediction from the average predicted probability of surviving.
| Variable | Description | Type |
|---|---|---|
| pclass | Passenger Class (1 = 1st; 2 = 2nd; 3 = 3rd) | Double |
| survived | Survival (0 = No; 1 = Yes) | Double |
| name | Name of the passenger | Character |
| sex | Gender of the passenger | Character |
| age | Age of the passenger | Double |
| sibsp | Number of siblings / Spouses aboard | Double |
| parch | Number of parents / Children aboard | Double |
| ticket | Ticket Number | Character |
| fare | Passenger fare (in British pounds) | Double |
| cabin | Cabin | Character |
| embarked | Port of Embarkation (C=Cherbourg, Q=Queenstown, S=Southampton) | Character |
| boat | Lifeboat | Character |
| body | Body Identification Number | Double |
| home.dest | Home / Destination | Character |
Note: when the age of a passenger is estimated, the recorded value ends in .5. Fares are expressed in pre-1970 British pounds.
First, let us import pandas and numpy:
import pandas as pd
import numpy as np
Let us import the train and test datasets:
titanic_train = pd.read_csv("titanic_train.csv")
titanic_test = pd.read_csv("titanic_test.csv")
A quick glance at the training set:
titanic_train
| | pclass | survived | name | sex | age | sibsp | parch | ticket | fare | cabin | embarked | boat | body | home.dest |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 3 | 0 | Chronopoulos, Mr. Demetrios | male | 18.0 | 1 | 0 | 2680 | 14.4542 | NaN | C | NaN | NaN | Greece |
| 1 | 3 | 0 | Barry, Miss. Julia | female | 27.0 | 0 | 0 | 330844 | 7.8792 | NaN | Q | NaN | NaN | New York, NY |
| 2 | 3 | 0 | Bengtsson, Mr. John Viktor | male | 26.0 | 0 | 0 | 347068 | 7.7750 | NaN | S | NaN | NaN | Krakudden, Sweden Moune, IL |
| 3 | 1 | 0 | Giglio, Mr. Victor | male | 24.0 | 0 | 0 | PC 17593 | 79.2000 | B86 | C | NaN | NaN | NaN |
| 4 | 2 | 0 | Pernot, Mr. Rene | male | NaN | 0 | 0 | SC/PARIS 2131 | 15.0500 | NaN | C | NaN | NaN | NaN |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1042 | 2 | 1 | Louch, Mrs. Charles Alexander (Alice Adelaide ... | female | 42.0 | 1 | 0 | SC/AH 3085 | 26.0000 | NaN | S | NaN | NaN | Weston-Super-Mare, Somerset |
| 1043 | 3 | 0 | Burke, Mr. Jeremiah | male | 19.0 | 0 | 0 | 365222 | 6.7500 | NaN | Q | NaN | NaN | Co Cork, Ireland Charlestown, MA |
| 1044 | 2 | 0 | Cotterill, Mr. Henry "Harry" | male | 21.0 | 0 | 0 | 29107 | 11.5000 | NaN | S | NaN | NaN | Penzance, Cornwall / Akron, OH |
| 1045 | 3 | 1 | Riordan, Miss. Johanna "Hannah" | female | NaN | 0 | 0 | 334915 | 7.7208 | NaN | Q | 13 | NaN | NaN |
| 1046 | 3 | 0 | Hampe, Mr. Leon | male | 20.0 | 0 | 0 | 345769 | 9.5000 | NaN | S | NaN | NaN | NaN |
1047 rows × 14 columns
print("Nb obs Train: ", titanic_train.shape)
print("Nb obs Test: ", titanic_test.shape)Nb obs Train: (1047, 14)
Nb obs Test: (262, 14)
Pre-processing of the data (1/2)
As we need to transform the type of some columns, let us merge all the individuals into a single table, keeping track of whether each individual belongs to the training or the testing sample.
titanic_train["sample"] = "train"
titanic_test["sample"] = "test"
titanic = pd.concat([titanic_train, titanic_test], ignore_index=True)
titanic.shape
(1309, 15)
Let us create an ID:
titanic["id"] = titanic.index + 1Let us change the target variable to a factor.
#titanic["survived"] = pd.Categorical(titanic["survived"])
#titanic["survived"] = titanic["survived"].cat.rename_categories({0:"Died", 1:"Survived"})Let us also turn pclass, sex and embarked to factors.
#titanic["pclass"] = pd.Categorical(titanic["pclass"])
#titanic["pclass"] = titanic["pclass"].cat.rename_categories({3:"3rd", 2:"2nd", 1:"1st"})
#titanic["sex"] = pd.Categorical(titanic["sex"])
#titanic["embarked"] = pd.Categorical(titanic["embarked"])
#titanic["embarked"] = titanic["embarked"].cat.rename_categories({"C":"Cherbourg", "Q":"Queenstown", "S":"Southampton"})titanic| pclass | survived | name | sex | age | sibsp | parch | ticket | fare | cabin | embarked | boat | body | home.dest | sample | id | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 3 | 0 | Chronopoulos, Mr. Demetrios | male | 18.0 | 1 | 0 | 2680 | 14.4542 | NaN | C | NaN | NaN | Greece | train | 1 |
| 1 | 3 | 0 | Barry, Miss. Julia | female | 27.0 | 0 | 0 | 330844 | 7.8792 | NaN | Q | NaN | NaN | New York, NY | train | 2 |
| 2 | 3 | 0 | Bengtsson, Mr. John Viktor | male | 26.0 | 0 | 0 | 347068 | 7.7750 | NaN | S | NaN | NaN | Krakudden, Sweden Moune, IL | train | 3 |
| 3 | 1 | 0 | Giglio, Mr. Victor | male | 24.0 | 0 | 0 | PC 17593 | 79.2000 | B86 | C | NaN | NaN | NaN | train | 4 |
| 4 | 2 | 0 | Pernot, Mr. Rene | male | NaN | 0 | 0 | SC/PARIS 2131 | 15.0500 | NaN | C | NaN | NaN | NaN | train | 5 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1304 | 3 | 0 | Windelov, Mr. Einar | male | 21.0 | 0 | 0 | SOTON/OQ 3101317 | 7.2500 | NaN | S | NaN | NaN | NaN | test | 1305 |
| 1305 | 3 | 0 | Wittevrongel, Mr. Camille | male | 36.0 | 0 | 0 | 345771 | 9.5000 | NaN | S | NaN | NaN | NaN | test | 1306 |
| 1306 | 3 | 1 | Yasbeck, Mrs. Antoni (Selini Alexander) | female | 15.0 | 1 | 0 | 2659 | 14.4542 | NaN | C | NaN | NaN | NaN | test | 1307 |
| 1307 | 3 | 0 | Zabour, Miss. Hileni | female | 14.5 | 1 | 0 | 2665 | 14.4542 | NaN | C | NaN | 328.0 | NaN | test | 1308 |
| 1308 | 3 | 0 | Zakarian, Mr. Mapriededer | male | 26.5 | 0 | 0 | 2656 | 7.2250 | NaN | C | NaN | 304.0 | NaN | test | 1309 |
1309 rows × 16 columns
Descriptive statistics
The variable survived is coded as 0 if the passenger died and 1 if they survived. The percentage of survivors in the training and the testing samples is 38% and 36%, respectively.
import researchpy as rp
rp.summary_cat(titanic["survived"])
| | Variable | Outcome | Count | Percent |
|---|---|---|---|---|
| 0 | survived | 0 | 809 | 61.8 |
| 1 | | 1 | 500 | 38.2 |
Number of observations for the target variable in the train and test sets:
rp.crosstab(titanic["survived"], titanic["sample"])| sample | |||
|---|---|---|---|
| sample | test | train | All |
| survived | |||
| 0 | 167 | 642 | 809 |
| 1 | 95 | 405 | 500 |
| All | 262 | 1047 | 1309 |
The corresponding proportions:
rp.crosstab(titanic["survived"], titanic["sample"], prop = "col")| sample | |||
|---|---|---|---|
| sample | test | train | All |
| survived | |||
| 0 | 63.74 | 61.32 | 61.8 |
| 1 | 36.26 | 38.68 | 38.2 |
| All | 100.00 | 100.00 | 100.0 |
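The same proportions can also be obtained with pandas alone; a quick sketch using pd.crosstab with the normalize argument:
(pd.crosstab(titanic["survived"], titanic["sample"], normalize="columns") * 100).round(2)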
Let us have a look at the count of observations for each category of the three following variables: pclass, sex, and embarked:
for v in ["pclass", "sex", "embarked"]:
print("--------------------\n")
print("Variable: ", v,)
display(rp.summary_cat(titanic[v]))--------------------
Variable: pclass
| | Variable | Outcome | Count | Percent |
|---|---|---|---|---|
| 0 | pclass | 3 | 709 | 54.16 |
| 1 | | 1 | 323 | 24.68 |
| 2 | | 2 | 277 | 21.16 |
--------------------
Variable: sex
| | Variable | Outcome | Count | Percent |
|---|---|---|---|---|
| 0 | sex | male | 843 | 64.4 |
| 1 | | female | 466 | 35.6 |
--------------------
Variable: embarked
| | Variable | Outcome | Count | Percent |
|---|---|---|---|---|
| 0 | embarked | S | 914 | 69.93 |
| 1 | | C | 270 | 20.66 |
| 2 | | Q | 123 | 9.41 |
Let us have a look at the distribution of the class of the passengers. This variable is a proxy for economic status, first-class passengers being, on average, relatively wealthier.
import matplotlib.pyplot as plt
import seaborn as sns
In a Jupyter notebook, to obtain vector (SVG) images when creating figures:
%config InlineBackend.figure_format = 'svg'
Then the graph:
plt.figure(figsize = (8,4))
ax = sns.countplot(x="pclass", data=titanic, color="grey")
ax.set_title('Distribution of Passenger Class')
ax.set_xlabel('Passenger Class')
ax.set_ylabel('Count')
Let us display the survival rate among each of these categorical variables:
plt.figure(figsize = (8,4))
ax = sns.barplot(x='pclass', y='survived', data=titanic, color="grey")
ax.set_title('Survival rate by passenger class')
ax.set_xlabel('Passenger Class')
ax.set_ylabel('Survival Rate')
Same for gender:
plt.figure(figsize = (8,4))
ax = sns.countplot(x="sex", data=titanic, color="grey", order = ["female", "male"])
ax.set_title('Distribution of Passenger Gender')
ax.set_xlabel('Passenger gender')
ax.set_ylabel('Count')
plt.figure(figsize = (8,4))
ax = sns.barplot(x='sex', y='survived', data=titanic, color="grey", order = ["female", "male"])
ax.set_title('Survival rate by passenger Gender')
ax.set_xlabel('Passenger gender')
ax.set_ylabel('Survival Rate')
And for port of embarkation:
plt.figure(figsize = (8,4))
ax = sns.countplot(x="embarked", data=titanic, color="grey")
ax.set_title('Distribution of Passenger Embarkation')
ax.set_xlabel('Port of Embarkation')
ax.set_ylabel('Count')
plt.figure(figsize = (8,4))
ax = sns.barplot(x='embarked', y='survived', data=titanic, color="grey")
ax.set_title('Survival rate by passenger Embarkation')
ax.set_xlabel('Port of Embarkation')
ax.set_ylabel('Survival Rate')
Let us have a look at the age:
titanic[["age", "sample"]].apply(lambda x: min(x))age 0.1667
sample test
dtype: object
titanic.groupby("sample")["age"].apply(lambda x: {"min": x.min(), "max": x.max()})sample
test min 0.7500
max 76.0000
train min 0.1667
max 80.0000
Name: age, dtype: float64
titanic[["age", "sample"]].groupby("sample").apply(lambda x: {"min": x.min(), "max":x.max()})sample
test {'min': [0.75, 'test'], 'max': [76.0, 'test']}
train {'min': [0.1667, 'train'], 'max': [80.0, 'trai...
dtype: object
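A more compact way to obtain the same minima and maxima is to combine groupby with agg:
titanic.groupby("sample")["age"].agg(["min", "max"])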
plt.figure(figsize = (12,6))
ax = sns.histplot(x='age', data=titanic, color="grey")
ax.set_title('Age distribution')
ax.set_xlabel('Age')
ax.set_ylabel('Count')
We can construct age classes to visualize the survival rate across age groups.
conditions = [
(titanic["age"].lt(15)),
(titanic["age"].ge(15) & titanic["age"].lt(20)),
(titanic["age"].ge(20) & titanic["age"].lt(25)),
(titanic["age"].ge(25) & titanic["age"].lt(35)),
(titanic["age"].ge(35) & titanic["age"].lt(45)),
(titanic["age"].ge(45)),
]
choices = ["<15", "[15,19]", "[20,24]", "[25,34]", "[35,44]","> 44"]
titanic["age_class"] = np.select(conditions, choices)rp.summary_cat(titanic["age_class"])| Variable | Outcome | Count | Percent | |
|---|---|---|---|---|
| 0 | age_class | [25,34] | 292 | 22.31 |
| 1 | | 0 | 263 | 20.09 |
| 2 | | [20,24] | 184 | 14.06 |
| 3 | | > 44 | 176 | 13.45 |
| 4 | | [35,44] | 169 | 12.91 |
| 5 | | [15,19] | 116 | 8.86 |
| 6 | | <15 | 109 | 8.33 |
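Note that an equivalent construction could rely on pd.cut; a minimal sketch, left commented out (unlike np.select, which labels missing ages with the string "0", pd.cut would leave them as NaN):
# titanic["age_class"] = pd.cut(titanic["age"],
#                               bins=[0, 15, 20, 25, 35, 45, np.inf],
#                               labels=choices, right=False)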
Number of missing values:
sum(titanic["age_class"] == "0")263
The survival rate depending on the age can then be plotted:
plt.figure(figsize = (8,4))
ax = sns.barplot(x='age_class', y='survived', data=titanic, color="grey", order = choices)
ax.set_title('Survival rate by passenger Age')
ax.set_xlabel('Age Class')
ax.set_ylabel('Survival Rate')
There seems to be a non-linear relationship between age and survival rate. Let us look at how it varies depending on the gender of the passenger:
plt.figure(figsize = (8,4))
ax = sns.barplot(x='age_class', y='survived', data=titanic, order = choices, hue="sex")
ax.set_title('Survival rate by passenger Age')
ax.set_xlabel('Age Class')
ax.set_ylabel('Survival Rate')
What about the number of siblings / spouses? First, we should note that the 3rd quartile of the number of siblings or spouses on board is equal to 1. So, most passengers were not traveling with their siblings or spouse.
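This can be checked from the summary statistics of sibsp:
titanic["sibsp"].describe().round(2)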
plt.figure(figsize = (8,4))
ax = sns.boxplot(y="sibsp", x="survived", data=titanic)
ax.set_title('Number of siblings / spouses depending on the survival status')
ax.set_xlabel('Survived')
ax.set_ylabel('Number of siblings / spouses')
We can also look at the number of parents or children. Again, most of the passengers were not accompanied by any parents or children: the 80% empirical quantile is equal to 1.
titanic["parch"].describe().round(2)count 1309.00
mean 0.39
std 0.87
min 0.00
25% 0.00
50% 0.00
75% 0.00
max 9.00
Name: parch, dtype: float64
np.quantile(titanic["parch"], q=.8)1.0
plt.figure(figsize = (8,4))
ax = sns.boxplot(y="parch", x="survived", data=titanic)
ax.set_title('Number of parents / children depending on the survival status')
ax.set_xlabel('Survived')
ax.set_ylabel('Number of parents / children')
Train a model
Let us train a model to predict the probability of surviving on the Titanic, starting with a simple decision tree.
df_train = titanic[["survived", "pclass", "sex", "age", "sibsp", "parch", "fare", "embarked", "sample"]]
df_train = df_train[df_train["sample"] == "train"]
df_train = df_train.drop("sample", axis = 1)
df_test = titanic[["survived", "pclass", "sex", "age", "sibsp", "parch", "fare", "embarked", "sample"]]
df_test = df_test[df_test["sample"] != "train"]
df_test = df_test.drop("sample", axis = 1)Let us remove NaN values (this should be done more properly in a real project).
df_train = df_train.dropna()
df_test = df_test.dropna()
A quick glance at the training set:
df_train
| | survived | pclass | sex | age | sibsp | parch | fare | embarked |
|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 3 | male | 18.0 | 1 | 0 | 14.4542 | C |
| 1 | 0 | 3 | female | 27.0 | 0 | 0 | 7.8792 | Q |
| 2 | 0 | 3 | male | 26.0 | 0 | 0 | 7.7750 | S |
| 3 | 0 | 1 | male | 24.0 | 0 | 0 | 79.2000 | C |
| 5 | 0 | 3 | female | 39.0 | 0 | 5 | 29.1250 | Q |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1040 | 1 | 2 | female | 12.0 | 2 | 1 | 39.0000 | S |
| 1042 | 1 | 2 | female | 42.0 | 1 | 0 | 26.0000 | S |
| 1043 | 0 | 3 | male | 19.0 | 0 | 0 | 6.7500 | Q |
| 1044 | 0 | 2 | male | 21.0 | 0 | 0 | 11.5000 | S |
| 1046 | 0 | 3 | male | 20.0 | 0 | 0 | 9.5000 | S |
829 rows × 8 columns
And at the test set:
df_test
| | survived | pclass | sex | age | sibsp | parch | fare | embarked |
|---|---|---|---|---|---|---|---|---|
| 1047 | 1 | 1 | male | 37.0 | 1 | 1 | 52.5542 | S |
| 1048 | 0 | 1 | male | 25.0 | 0 | 0 | 26.0000 | C |
| 1049 | 1 | 1 | male | 28.0 | 0 | 0 | 26.5500 | S |
| 1050 | 0 | 1 | male | 45.0 | 0 | 0 | 35.5000 | S |
| 1051 | 1 | 1 | female | 30.0 | 0 | 0 | 164.8667 | S |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1304 | 0 | 3 | male | 21.0 | 0 | 0 | 7.2500 | S |
| 1305 | 0 | 3 | male | 36.0 | 0 | 0 | 9.5000 | S |
| 1306 | 1 | 3 | female | 15.0 | 1 | 0 | 14.4542 | C |
| 1307 | 0 | 3 | female | 14.5 | 1 | 0 | 14.4542 | C |
| 1308 | 0 | 3 | male | 26.5 | 0 | 0 | 7.2250 | C |
214 rows × 8 columns
Let us create dummy variables for the categorical variables. Note that pclass is stored as a number, so pd.get_dummies leaves it unchanged and it enters the model as a numeric feature:
categ_var = ["pclass", "sex", "embarked"]On the training set:
df_train = pd.concat(
[df_train.iloc[:,np.where(np.logical_not(df_train.columns.isin(categ_var)))[0]],
pd.get_dummies(df_train[categ_var],drop_first=True)],
axis = 1
)
display(df_train)
| | survived | age | sibsp | parch | fare | pclass | sex_male | embarked_Q | embarked_S |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 18.0 | 1 | 0 | 14.4542 | 3 | 1 | 0 | 0 |
| 1 | 0 | 27.0 | 0 | 0 | 7.8792 | 3 | 0 | 1 | 0 |
| 2 | 0 | 26.0 | 0 | 0 | 7.7750 | 3 | 1 | 0 | 1 |
| 3 | 0 | 24.0 | 0 | 0 | 79.2000 | 1 | 1 | 0 | 0 |
| 5 | 0 | 39.0 | 0 | 5 | 29.1250 | 3 | 0 | 1 | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1040 | 1 | 12.0 | 2 | 1 | 39.0000 | 2 | 0 | 0 | 1 |
| 1042 | 1 | 42.0 | 1 | 0 | 26.0000 | 2 | 0 | 0 | 1 |
| 1043 | 0 | 19.0 | 0 | 0 | 6.7500 | 3 | 1 | 1 | 0 |
| 1044 | 0 | 21.0 | 0 | 0 | 11.5000 | 2 | 1 | 0 | 1 |
| 1046 | 0 | 20.0 | 0 | 0 | 9.5000 | 3 | 1 | 0 | 1 |
829 rows × 9 columns
On the test set:
df_test = pd.concat(
[df_test.iloc[:,np.where(np.logical_not(df_test.columns.isin(categ_var)))[0]],
pd.get_dummies(df_test[categ_var],drop_first=True)],
axis = 1
)
display(df_test)
| | survived | age | sibsp | parch | fare | pclass | sex_male | embarked_Q | embarked_S |
|---|---|---|---|---|---|---|---|---|---|
| 1047 | 1 | 37.0 | 1 | 1 | 52.5542 | 1 | 1 | 0 | 1 |
| 1048 | 0 | 25.0 | 0 | 0 | 26.0000 | 1 | 1 | 0 | 0 |
| 1049 | 1 | 28.0 | 0 | 0 | 26.5500 | 1 | 1 | 0 | 1 |
| 1050 | 0 | 45.0 | 0 | 0 | 35.5000 | 1 | 1 | 0 | 1 |
| 1051 | 1 | 30.0 | 0 | 0 | 164.8667 | 1 | 0 | 0 | 1 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1304 | 0 | 21.0 | 0 | 0 | 7.2500 | 3 | 1 | 0 | 1 |
| 1305 | 0 | 36.0 | 0 | 0 | 9.5000 | 3 | 1 | 0 | 1 |
| 1306 | 1 | 15.0 | 1 | 0 | 14.4542 | 3 | 0 | 0 | 0 |
| 1307 | 0 | 14.5 | 1 | 0 | 14.4542 | 3 | 0 | 0 | 0 |
| 1308 | 0 | 26.5 | 0 | 0 | 7.2250 | 3 | 1 | 0 | 0 |
214 rows × 9 columns
X_train = df_train
X_train = X_train.drop("survived", axis=1)
y_train = df_train["survived"]
X_test = df_test
X_test = X_test.drop("survived", axis=1)
y_test = df_test["survived"]Now, let us build a decision tree to classify our observations. First, we can import the DecisionTreeClassifier() function from the tree module of the sklearn library.
from sklearn.tree import DecisionTreeClassifier
Then we can train the decision tree:
# Decision Tree classifier
tree_classifier = DecisionTreeClassifier(max_depth=3)
# Let us train it
tree_classifier = tree_classifier.fit(X_train,y_train)
Predictions (out of sample):
y_pred = tree_classifier.predict(X_test)
What is the accuracy on the test set? Let us load the metrics module from the sklearn library.
from sklearn import metrics
The accuracy on the test set:
print("Accuracy: ",metrics.accuracy_score(y_test, y_pred))Accuracy: 0.822429906542056
The confusion table:
metrics.confusion_matrix(y_test, y_pred)
array([[119, 15],
[ 23, 57]])
Now let us visualize the decision tree. Two libraries, graphviz and pydotplus, can be installed to export trees, although sklearn's own plotting function is sufficient here.
#pip install graphviz
#pip install pydotplus
Let us import the tree module from the sklearn library.
from sklearn import tree
Then we can use the plot_tree() function from the tree module to plot the decision tree:
plt.figure(figsize = (12,6))
tree.plot_tree(tree_classifier, filled = True,
feature_names = X_train.columns,
class_names = ["Died", "Survived"]);Now, let us fit a random forest. Note that the we do not proceed here to fine tuning, which should be done to improve the predictive capacities of the model. Let us import the RandomForestClassifier() function from the ensemble module of the sklearn library.
from sklearn.ensemble import RandomForestClassifier
Then we can use this class to train a random forest on the training set. Let us set the minimum number of observations in terminal leaves to 10, and let us consider 3 candidate features each time a split is attempted. First, let us set the hyperparameters:
random_forest_classifier = RandomForestClassifier(n_estimators = 1000,
min_samples_leaf = 10, max_features=3)The model can then be trained:
random_forest_classifier.fit(X_train, y_train)
RandomForestClassifier(max_features=3, min_samples_leaf=10, n_estimators=1000)
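As mentioned above, no fine tuning is performed in this exercise. For reference, a minimal sketch of how the hyperparameters could be tuned by cross-validated grid search (the grid values are purely illustrative and the search is not run here):
from sklearn.model_selection import GridSearchCV
# Illustrative grid over two hyperparameters
param_grid = {"min_samples_leaf": [5, 10, 20], "max_features": [2, 3, 4]}
grid_search = GridSearchCV(RandomForestClassifier(n_estimators=500),
                           param_grid, cv=5, scoring="accuracy")
# grid_search.fit(X_train, y_train)
# grid_search.best_params_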
With this kind of model, there is no straightforward way to interpret the results. There is, however, a metric that gives the relative importance of the variables (although it does not disentangle positive from negative effects).
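For instance, the impurity-based importances computed by sklearn can be displayed as follows (a quick sketch):
pd.Series(random_forest_classifier.feature_importances_,
          index=X_train.columns).sort_values(ascending=False)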
Assessing the predictive capacities of the model
Predictions on the training sample
predict_tree_train = tree_classifier.predict_proba(X_train)
predict_rf_train = random_forest_classifier.predict_proba(X_train)
- First column: prediction class 0 (died)
- Second column: prediction class 1 (survived)
predict_tree_train[range(10)]
array([[0.85824742, 0.14175258],
[0.46 , 0.54 ],
[0.85824742, 0.14175258],
[0.62037037, 0.37962963],
[0.85 , 0.15 ],
[0.85824742, 0.14175258],
[0.85824742, 0.14175258],
[0.85 , 0.15 ],
[0.85824742, 0.14175258],
[0.46 , 0.54 ]])
Predictions on the test sample
predict_tree_test = tree_classifier.predict_proba(X_test)
predict_rf_test = random_forest_classifier.predict_proba(X_test)
Let us set the probability threshold for the classifier, i.e., the value above which the class “Survived” is predicted.
tau = .5
np.where(predict_tree_train[:,1] >= tau, "Survived", "Died")
array(['Died', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died',
'Died', 'Survived', 'Survived', 'Survived', 'Died', 'Survived',
'Died', 'Survived', 'Died', 'Survived', 'Survived', 'Died', 'Died',
'Survived', 'Died', 'Survived', 'Survived', 'Survived', 'Died',
'Died', 'Died', 'Died', 'Survived', 'Died', 'Survived', 'Survived',
'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Survived',
'Survived', 'Survived', 'Survived', 'Died', 'Died', 'Died', 'Died',
'Died', 'Survived', 'Survived', 'Died', 'Died', 'Died', 'Survived',
'Survived', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Survived',
'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Survived',
'Died', 'Died', 'Survived', 'Died', 'Survived', 'Died', 'Died',
'Died', 'Died', 'Survived', 'Survived', 'Died', 'Died', 'Died',
'Survived', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Died',
'Died', 'Died', 'Survived', 'Survived', 'Died', 'Died', 'Died',
'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Survived',
'Survived', 'Died', 'Died', 'Survived', 'Died', 'Died', 'Survived',
'Died', 'Died', 'Died', 'Survived', 'Died', 'Died', 'Died', 'Died',
'Survived', 'Survived', 'Survived', 'Died', 'Died', 'Died', 'Died',
'Died', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Survived',
'Died', 'Died', 'Died', 'Survived', 'Died', 'Died', 'Died', 'Died',
'Survived', 'Died', 'Survived', 'Died', 'Died', 'Died', 'Survived',
'Survived', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died',
'Survived', 'Survived', 'Survived', 'Died', 'Died', 'Died',
'Survived', 'Survived', 'Died', 'Died', 'Survived', 'Died',
'Survived', 'Survived', 'Survived', 'Died', 'Survived', 'Survived',
'Died', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died',
'Died', 'Died', 'Survived', 'Survived', 'Survived', 'Died', 'Died',
'Died', 'Survived', 'Died', 'Died', 'Died', 'Survived', 'Survived',
'Died', 'Survived', 'Died', 'Died', 'Died', 'Survived', 'Died',
'Survived', 'Died', 'Died', 'Survived', 'Died', 'Survived', 'Died',
'Died', 'Died', 'Died', 'Died', 'Died', 'Survived', 'Died',
'Survived', 'Survived', 'Survived', 'Died', 'Survived', 'Died',
'Survived', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Survived',
'Survived', 'Survived', 'Survived', 'Survived', 'Survived', 'Died',
'Survived', 'Survived', 'Died', 'Survived', 'Died', 'Died', 'Died',
'Died', 'Survived', 'Died', 'Survived', 'Survived', 'Survived',
'Died', 'Died', 'Died', 'Died', 'Survived', 'Died', 'Died', 'Died',
'Died', 'Survived', 'Survived', 'Died', 'Survived', 'Survived',
'Survived', 'Died', 'Died', 'Died', 'Survived', 'Died', 'Survived',
'Survived', 'Survived', 'Died', 'Survived', 'Died', 'Died', 'Died',
'Survived', 'Died', 'Died', 'Survived', 'Died', 'Died', 'Died',
'Died', 'Died', 'Died', 'Survived', 'Died', 'Died', 'Survived',
'Survived', 'Died', 'Survived', 'Died', 'Survived', 'Died', 'Died',
'Died', 'Survived', 'Survived', 'Died', 'Died', 'Died', 'Survived',
'Died', 'Died', 'Died', 'Died', 'Died', 'Survived', 'Died', 'Died',
'Survived', 'Died', 'Survived', 'Survived', 'Died', 'Died',
'Survived', 'Survived', 'Died', 'Survived', 'Died', 'Died', 'Died',
'Survived', 'Died', 'Died', 'Survived', 'Died', 'Died', 'Died',
'Died', 'Died', 'Survived', 'Died', 'Survived', 'Died', 'Survived',
'Died', 'Died', 'Died', 'Died', 'Survived', 'Died', 'Survived',
'Died', 'Died', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Died',
'Survived', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Survived',
'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died',
'Survived', 'Died', 'Survived', 'Survived', 'Survived', 'Died',
'Survived', 'Died', 'Survived', 'Survived', 'Died', 'Survived',
'Died', 'Died', 'Survived', 'Died', 'Survived', 'Survived',
'Survived', 'Died', 'Died', 'Died', 'Survived', 'Survived',
'Survived', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died',
'Survived', 'Died', 'Died', 'Survived', 'Died', 'Died', 'Survived',
'Died', 'Died', 'Died', 'Died', 'Survived', 'Survived', 'Died',
'Survived', 'Died', 'Survived', 'Died', 'Survived', 'Died', 'Died',
'Survived', 'Survived', 'Survived', 'Survived', 'Died', 'Died',
'Died', 'Died', 'Survived', 'Survived', 'Died', 'Survived', 'Died',
'Survived', 'Survived', 'Survived', 'Died', 'Survived', 'Survived',
'Died', 'Survived', 'Died', 'Survived', 'Died', 'Survived',
'Survived', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Survived',
'Died', 'Died', 'Died', 'Died', 'Died', 'Survived', 'Survived',
'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died',
'Survived', 'Died', 'Survived', 'Died', 'Died', 'Died', 'Survived',
'Survived', 'Survived', 'Survived', 'Died', 'Survived', 'Died',
'Died', 'Survived', 'Survived', 'Died', 'Died', 'Survived', 'Died',
'Died', 'Died', 'Died', 'Died', 'Died', 'Survived', 'Died', 'Died',
'Died', 'Survived', 'Survived', 'Died', 'Survived', 'Died', 'Died',
'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died',
'Survived', 'Survived', 'Survived', 'Survived', 'Died', 'Died',
'Survived', 'Survived', 'Survived', 'Died', 'Survived', 'Died',
'Died', 'Survived', 'Died', 'Survived', 'Died', 'Died', 'Died',
'Died', 'Survived', 'Died', 'Survived', 'Survived', 'Died',
'Survived', 'Survived', 'Survived', 'Died', 'Died', 'Survived',
'Died', 'Survived', 'Survived', 'Survived', 'Died', 'Survived',
'Died', 'Survived', 'Died', 'Survived', 'Died', 'Survived', 'Died',
'Died', 'Died', 'Died', 'Survived', 'Survived', 'Survived',
'Survived', 'Survived', 'Died', 'Died', 'Died', 'Survived',
'Survived', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Died',
'Died', 'Died', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Died',
'Died', 'Survived', 'Died', 'Survived', 'Survived', 'Died', 'Died',
'Died', 'Died', 'Died', 'Survived', 'Died', 'Survived', 'Survived',
'Died', 'Died', 'Survived', 'Died', 'Survived', 'Died', 'Died',
'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died',
'Died', 'Died', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Died',
'Survived', 'Died', 'Survived', 'Died', 'Survived', 'Died', 'Died',
'Died', 'Died', 'Survived', 'Died', 'Died', 'Died', 'Survived',
'Died', 'Died', 'Died', 'Died', 'Survived', 'Survived', 'Survived',
'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died',
'Died', 'Survived', 'Survived', 'Survived', 'Died', 'Died', 'Died',
'Died', 'Died', 'Died', 'Died', 'Survived', 'Died', 'Died', 'Died',
'Died', 'Died', 'Survived', 'Survived', 'Survived', 'Died',
'Survived', 'Survived', 'Died', 'Survived', 'Died', 'Died', 'Died',
'Died', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Died', 'Died',
'Died', 'Survived', 'Died', 'Survived', 'Survived', 'Survived',
'Survived', 'Died', 'Survived', 'Survived', 'Survived', 'Died',
'Died', 'Survived', 'Survived', 'Died', 'Survived', 'Survived',
'Died', 'Died', 'Died', 'Died', 'Survived', 'Died', 'Died', 'Died',
'Died', 'Survived', 'Died', 'Survived', 'Died', 'Died', 'Survived',
'Died', 'Died', 'Died', 'Died', 'Survived', 'Survived', 'Died',
'Died', 'Died', 'Survived', 'Survived', 'Died', 'Died', 'Survived',
'Survived', 'Died', 'Survived', 'Survived', 'Died', 'Died', 'Died',
'Died', 'Died', 'Survived', 'Died', 'Survived', 'Survived',
'Survived', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Survived',
'Survived', 'Survived', 'Died', 'Survived', 'Died', 'Died',
'Survived', 'Survived', 'Survived', 'Survived', 'Died', 'Died',
'Died', 'Survived', 'Died', 'Survived', 'Died', 'Survived', 'Died',
'Died', 'Survived', 'Died', 'Survived', 'Died', 'Died', 'Died',
'Survived', 'Survived', 'Survived', 'Died', 'Died', 'Survived',
'Died', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Died',
'Survived', 'Died', 'Died', 'Died', 'Died', 'Survived', 'Survived',
'Survived', 'Died', 'Died', 'Died', 'Survived', 'Survived', 'Died',
'Survived', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Died',
'Survived', 'Survived', 'Died', 'Died', 'Died', 'Died', 'Died',
'Died', 'Died', 'Died', 'Died', 'Survived', 'Survived', 'Died',
'Died', 'Died'], dtype='<U8')
y_pred_train_tree = np.where(predict_tree_train[:,1] >= tau, 1, 0)
y_pred_train_rf = np.where(predict_rf_train[:,1] >= tau, 1, 0)
y_pred_test_tree = np.where(predict_tree_test[:,1] >= tau, 1, 0)
y_pred_test_rf = np.where(predict_rf_test[:,1] >= tau, 1, 0)
Let us define a small function to get some goodness-of-fit metrics:
def metrics_fit(y_obs, y_pred):
    res = {
        "Accuracy": metrics.accuracy_score(y_obs, y_pred),
        "Kappa": metrics.cohen_kappa_score(y_obs, y_pred),
        "Sensitivity (True Pos Rate)": metrics.recall_score(y_obs, y_pred),
        "F1": metrics.f1_score(y_obs, y_pred)
    }
    return(res)
This user-defined function can be applied to assess the predictive power of the model on the train set, for the decision tree:
metrics_fit(y_train, y_pred_train_tree)
{'Accuracy': 0.8033775633293124,
'Kappa': 0.5890510527132113,
'Sensitivity (True Pos Rate)': 0.7101449275362319,
'F1': 0.7503828483920367}
And for the random forest:
metrics_fit(y_train, y_pred_train_rf)
{'Accuracy': 0.8299155609167672,
'Kappa': 0.6390231398986452,
'Sensitivity (True Pos Rate)': 0.6898550724637681,
'F1': 0.7714748784440844}
It can also be applied on the test set for the decision tree:
metrics_fit(y_test, y_pred_test_tree)
{'Accuracy': 0.822429906542056,
'Kappa': 0.6129093678598629,
'Sensitivity (True Pos Rate)': 0.7125,
'F1': 0.75}
And for the random forest:
metrics_fit(y_test, y_pred_test_rf)
{'Accuracy': 0.822429906542056,
'Kappa': 0.6047822706065318,
'Sensitivity (True Pos Rate)': 0.6625,
'F1': 0.736111111111111}
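Another common way to compare the two models is the area under the ROC curve, computed from the predicted probabilities rather than from the predicted classes (a short sketch):
print("AUC decision tree: ", metrics.roc_auc_score(y_test, predict_tree_test[:,1]))
print("AUC random forest: ", metrics.roc_auc_score(y_test, predict_rf_test[:,1]))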
Explaining the predicted probabilities with SHAP
Now, let us turn to SHAP to explain, for a given individual, how the specific values of the explanatory variables shift the prediction away from the average predicted probability of surviving.
# pip install shap
import shap
Let us consider a single individual, Edith Eileen, who is characterized as follows (she actually survived and lived more than 100 years!):
example = titanic[titanic["name"].str.contains("Edith Eileen")]ind_example = example.index[0]
ind_example539
np.where(X_train.index == ind_example)[0]
array([432])
X_example = X_train[X_train.index.isin([ind_example])]
X_example
| | age | sibsp | parch | fare | pclass | sex_male | embarked_Q | embarked_S |
|---|---|---|---|---|---|---|---|---|
| 539 | 15.0 | 0 | 2 | 39.0 | 2 | 0 | 0 | 1 |
Let us build our predict function:
def my_predict_f_prob(x):
    pred_prob = random_forest_classifier.predict_proba(x)
    return(pred_prob)
The probability that our individual dies (first column) or survives (second column):
pred_example = my_predict_f_prob(X_example)
pred_example
array([[0.07657917, 0.92342083]])
Then let us build the explainer, with SHAP's KernelExplainer. The function takes three arguments:
- model: the model to be explained (or its prediction function)
- data: the background dataset
- link: a generalized linear model link to connect the feature importance values to the model output (identity by default)
explainer = shap.KernelExplainer(model=my_predict_f_prob, data=X_train)
X does not have valid feature names, but RandomForestClassifier was fitted with feature names
Using 829 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.
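As the warning suggests, the background dataset can be summarized to speed up the computations; a minimal sketch with an arbitrary choice of 50 clusters (left commented out, as we keep the full training sample here):
# background = shap.kmeans(X_train, 50)
# explainer = shap.KernelExplainer(model=my_predict_f_prob, data=background)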
And let us estimate the Shapley values for our training example:
shap_values_example = explainer.shap_values(X=X_example, nsamples = 100)
X does not have valid feature names, but RandomForestClassifier was fitted with feature names
shap_values_example
[array([[-0.02516427, -0.01288884, -0.02583533, -0.01695874, -0.1218961 ,
-0.30537386, 0.0003784 , 0.00115635]]),
array([[ 0.02516427, 0.01288884, 0.02583533, 0.01695874, 0.1218961 ,
0.30537386, -0.0003784 , -0.00115635]])]
For a given individual, the Shapley values explain how each variable moves the model's prediction away from the average prediction. Here, the first element of shap_values_example corresponds to class 0 (died) and the second to class 1 (survived).
shap_values_example[0].sum()
-0.5065823870461259
The average predicted probability for class 0 (dying) over the whole training sample:
avg_pred = my_predict_f_prob(X_train)
avg_pred = avg_pred[:,0].mean()
avg_pred0.5831615566676123
We can see that the sum of these Shapley values (for class 0) is equal to the deviation of the individual's predicted probability of dying from the average predicted probability of dying in the whole training sample:
pred_example[0][0] - avg_pred
-0.5065823870461261
The package offers some nice plots to explain at least a single example:
shap.initjs()
shap.force_plot(explainer.expected_value[0], shap_values_example[0], X_example)
shap.force_plot(explainer.expected_value[1], shap_values_example[1], X_example)
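Beyond a single individual, the Shapley values can be computed for a (small) subset of observations and summarized graphically with shap.summary_plot; a minimal sketch, left commented out because KernelExplainer is slow on many rows:
# shap_values_test = explainer.shap_values(X=X_test.iloc[:50], nsamples=100)
# shap.summary_plot(shap_values_test[1], X_test.iloc[:50])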
More options are available, see https://github.com/slundberg/shap