SHAP for Classification Models

The following tutorial demonstrates the use of SHAP in the context of a classification model. The example predicts customer churn for a bank. The data set is available from OpenML.

In [1]:
# Data handling 
import pandas as pd
import numpy as np
# EDA
import seaborn as sns
# Preprocessing
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
# Model training and evaluation
import xgboost as xgb
from sklearn.metrics import accuracy_score, precision_score, recall_score
# Model explanation
import shap
import matplotlib.pyplot as plt

Loading the data

In [2]:
df = pd.read_csv("dataset.csv", encoding="utf-8")
In [3]:
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10000 entries, 0 to 9999
Data columns (total 14 columns):
 #   Column           Non-Null Count  Dtype  
---  ------           --------------  -----  
 0   RowNumber        10000 non-null  int64  
 1   CustomerId       10000 non-null  int64  
 2   Surname          10000 non-null  object 
 3   CreditScore      10000 non-null  int64  
 4   Geography        10000 non-null  object 
 5   Gender           10000 non-null  object 
 6   Age              10000 non-null  int64  
 7   Tenure           10000 non-null  int64  
 8   Balance          10000 non-null  float64
 9   NumOfProducts    10000 non-null  int64  
 10  HasCrCard        10000 non-null  int64  
 11  IsActiveMember   10000 non-null  int64  
 12  EstimatedSalary  10000 non-null  float64
 13  Exited           10000 non-null  int64  
dtypes: float64(2), int64(9), object(3)
memory usage: 1.1+ MB
In [4]:
# Dropping unnecessary columns
df.drop(columns=["RowNumber", "CustomerId", "Surname"], inplace=True)
In [5]:
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10000 entries, 0 to 9999
Data columns (total 11 columns):
 #   Column           Non-Null Count  Dtype  
---  ------           --------------  -----  
 0   CreditScore      10000 non-null  int64  
 1   Geography        10000 non-null  object 
 2   Gender           10000 non-null  object 
 3   Age              10000 non-null  int64  
 4   Tenure           10000 non-null  int64  
 5   Balance          10000 non-null  float64
 6   NumOfProducts    10000 non-null  int64  
 7   HasCrCard        10000 non-null  int64  
 8   IsActiveMember   10000 non-null  int64  
 9   EstimatedSalary  10000 non-null  float64
 10  Exited           10000 non-null  int64  
dtypes: float64(2), int64(7), object(2)
memory usage: 859.5+ KB

For a description of the variables see OpenML. There are no missing values in the data set.

Exploratory Data Analysis

Let's look at the share of customers who exited the bank in the data set.

In [6]:
df['Exited'].mean()
Out[6]:
0.2037
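Because "Exited" is coded as 0/1, its mean is the share of customers who left. A minimal sketch of that equivalence, using a small made-up series rather than the bank data:

```python
import pandas as pd

# Hypothetical 0/1 churn flags for ten customers (not taken from the data set)
exited = pd.Series([1, 0, 0, 1, 0, 0, 0, 1, 0, 0])

# Mean of a 0/1 variable ...
share_via_mean = exited.mean()

# ... equals the number of ones divided by the number of rows
share_via_count = (exited == 1).sum() / len(exited)

print(share_via_mean, share_via_count)
```

The same reasoning applies to group-wise means: the mean of "Exited" within a group is the churn rate of that group.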

We will look at the effect of age in the model later on. For this purpose, we bin the numerical age variable into five-year groups.

In [7]:
df["Age"].describe()
Out[7]:
count    10000.000000
mean        38.921800
std         10.487806
min         18.000000
25%         32.000000
50%         37.000000
75%         44.000000
max         92.000000
Name: Age, dtype: float64
In [8]:
bins = list(range(15, df['Age'].max() + 5, 5))
labels = [f"{b} - {b+4}" for b in bins[:-1]]
df['AgeGroup'] = pd.cut(df['Age'], bins=bins, labels=labels, right=False)
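The binning above uses right=False, so each interval includes its lower bound and excludes its upper bound: a 20-year-old falls into "20 - 24", not "15 - 19". A small sketch with made-up ages illustrates the edge cases:

```python
import pandas as pd

# Hypothetical ages chosen to sit on the bin edges
ages = pd.Series([18, 19, 20, 24, 25, 92])

bins = list(range(15, int(ages.max()) + 5, 5))   # 15, 20, ..., 95
labels = [f"{b} - {b+4}" for b in bins[:-1]]     # "15 - 19", "20 - 24", ...

# right=False makes the intervals left-closed: [15, 20), [20, 25), ...
groups = pd.cut(ages, bins=bins, labels=labels, right=False)
print(groups.tolist())
```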

Next, we calculate the percentage of customers who churned per age group. The variable "Exited" contains the information about churn. As it is coded as a dummy variable, we can use the mean for an age group as the proportion of customers who churned.

In [9]:
age_group_exited_ratio = df.groupby('AgeGroup')['Exited'].mean().reset_index()
age_group_exited_ratio['Exited'] = age_group_exited_ratio['Exited'] * 100 
age_group_exited_ratio 
Out[9]:
AgeGroup Exited
0 15 - 19 6.122449
1 20 - 24 9.068627
2 25 - 29 7.094595
3 30 - 34 8.145240
4 35 - 39 13.301560
5 40 - 44 23.670054
6 45 - 49 43.386243
7 50 - 54 56.920078
8 55 - 59 54.775281
9 60 - 64 42.622951
10 65 - 69 21.374046
11 70 - 74 14.432990
12 75 - 79 0.000000
13 80 - 84 9.090909
14 85 - 89 0.000000
15 90 - 94 0.000000

Churn is highest among customers aged 45 to 64, peaking at about 57% in the 50 - 54 group, and is much lower for younger and older customers. Let us visualize the relationship between churn and age group using the seaborn package.

In [10]:
sns.barplot(x='AgeGroup', y='Exited', data=age_group_exited_ratio)
plt.xlabel('Age Group')
plt.ylabel('Churn (%)')
plt.title('Churn (%) by Age Group')
plt.xticks(rotation=45)  
plt.show()

Finally, let us look at the distribution of the estimated salary and credit score.

In [11]:
df["EstimatedSalary"].describe()
Out[11]:
count     10000.000000
mean     100090.239881
std       57510.492818
min          11.580000
25%       51002.110000
50%      100193.915000
75%      149388.247500
max      199992.480000
Name: EstimatedSalary, dtype: float64
In [12]:
df["CreditScore"].describe()
Out[12]:
count    10000.000000
mean       650.528800
std         96.653299
min        350.000000
25%        584.000000
50%        652.000000
75%        718.000000
max        850.000000
Name: CreditScore, dtype: float64

Preparing the data for machine learning

To prepare for the machine learning part, we split the data into predictor variables (X) and target variable / label (y).

In [13]:
X_raw = df.loc[:, df.columns != "Exited"]
y = df.loc[:, df.columns == "Exited"]

Next, we will use one-hot encoding for the categorical variables.

In [14]:
# One hot encoding for all object columns 

## Split X by data type
X_raw_numerical = X_raw.select_dtypes(exclude=['object'])
X_raw_categorical = X_raw.select_dtypes(include=['object'])

X_raw_categorical_ohe = pd.get_dummies(X_raw_categorical)

# Combine numerical columns and one hot encoded categorical columns
X = pd.concat([X_raw_numerical, X_raw_categorical_ohe], axis=1)

X.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10000 entries, 0 to 9999
Data columns (total 14 columns):
 #   Column             Non-Null Count  Dtype   
---  ------             --------------  -----   
 0   CreditScore        10000 non-null  int64   
 1   Age                10000 non-null  int64   
 2   Tenure             10000 non-null  int64   
 3   Balance            10000 non-null  float64 
 4   NumOfProducts      10000 non-null  int64   
 5   HasCrCard          10000 non-null  int64   
 6   IsActiveMember     10000 non-null  int64   
 7   EstimatedSalary    10000 non-null  float64 
 8   AgeGroup           10000 non-null  category
 9   Geography_France   10000 non-null  uint8   
 10  Geography_Germany  10000 non-null  uint8   
 11  Geography_Spain    10000 non-null  uint8   
 12  Gender_Female      10000 non-null  uint8   
 13  Gender_Male        10000 non-null  uint8   
dtypes: category(1), float64(2), int64(6), uint8(5)
memory usage: 684.4 KB
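To see what one-hot encoding does to the column names, here is a small sketch on a made-up frame with the same two categorical columns. The dtype of the resulting columns depends on the pandas version (uint8 in the output above, bool in newer releases), so the sketch casts to int before summing.

```python
import pandas as pd

# Hypothetical three-row frame, not the bank data
toy = pd.DataFrame({
    "Geography": ["France", "Spain", "Germany"],
    "Gender": ["Female", "Male", "Female"],
})

encoded = pd.get_dummies(toy)

# Each category becomes its own column named "<column>_<category>"
print(list(encoded.columns))

# Every row has exactly one active indicator per original column
print(encoded.astype(int).sum(axis=1).tolist())
```

Note that the original names "Geography" and "Gender" no longer exist as columns after encoding.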
In [15]:
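# Note: these are the raw column names of df. The one-hot columns in X (e.g. "Geography_France")
# do not match "Geography" or "Gender", so the isin() filter used later keeps only the eight
# numeric columns. Build the list from X.columns instead to include the encoded columns.
cols_for_training = list(df.columns) # use all columns by default
cols_not_for_training = ["AgeGroup"] # define unwanted columns
cols_for_training = list(set(cols_for_training) - set(cols_not_for_training)) # remove unwanted columns from list
print(cols_for_training)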
['Tenure', 'Gender', 'Exited', 'Geography', 'IsActiveMember', 'Balance', 'EstimatedSalary', 'HasCrCard', 'Age', 'CreditScore', 'NumOfProducts']
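The printed list contains the raw names "Geography" and "Gender", while X holds their one-hot encoded versions. A short sketch with hypothetical column names shows how an isin() filter behaves in that situation, and one possible alternative that starts from the encoded columns:

```python
import pandas as pd

# Hypothetical subset of encoded column names, mirroring the structure of X
encoded_columns = pd.Index(["Age", "Balance", "AgeGroup", "Geography_France", "Gender_Male"])

# Raw names as they appear in the original data frame
raw_names = ["Age", "Balance", "Geography", "Gender", "Exited"]

# isin() compares whole names, so "Geography_France" does not match "Geography"
kept = list(encoded_columns[encoded_columns.isin(raw_names)])
print(kept)

# Alternative: start from the encoded columns and remove only the unwanted ones
unwanted = {"AgeGroup"}
wanted = [c for c in encoded_columns if c not in unwanted]
print(wanted)
```

The list comprehension also preserves the column order, which a set difference does not guarantee.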
In [16]:
# Split into training and test data sets  (stratified)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)
In [17]:
# Converting to XGB data format
dtrain = xgb.DMatrix(X_train.loc[:, X_train.columns.isin(cols_for_training)], label=y_train)
dtest = xgb.DMatrix(X_test.loc[:, X_test.columns.isin(cols_for_training)], label=y_test)

Setting up XGB

In [18]:
# Setting parameters
params = {
    'objective': 'binary:logistic',
    'max_depth': 10,
    'eta': 0.3,
    'nthread': 4
}
In [19]:
num_round = 1000
xgb_classifier = xgb.train(params, dtrain, num_round)

Model evaluation

In [20]:
## Model evaluation

## Generate predictions for test data
preds = xgb_classifier.predict(dtest)
preds_binary = (preds > 0.5).astype(int)

## Calculate metrics

positive_ratio = y_test["Exited"].mean()  # select by label; positional [0] on a labeled Series is deprecated
print(f"Proportion of positive cases in the test set: {positive_ratio:.2f}")

accuracy = accuracy_score(y_test, preds_binary)
print(f"Accuracy: {accuracy:.2f}")

precision = precision_score(y_test, preds_binary)
recall = recall_score(y_test, preds_binary)

print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")
Proportion of positive cases in the test set: 0.20
Accuracy: 0.83
Precision: 0.61
Recall: 0.46

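The model is better than a naive baseline, but only moderately so. Since 20% of the test customers churned, always predicting "no churn" would already reach an accuracy of 0.80, compared to 0.83 for the model. A recall of 0.46 means the model identifies fewer than half of the customers who actually churned, and a precision of 0.61 means that about 6 in 10 customers flagged as churners did leave.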
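To make the three metrics concrete, the following sketch computes them by hand from confusion-matrix counts. The counts are hypothetical and chosen for easy arithmetic; they are not the counts of the model above.

```python
# Hypothetical confusion-matrix counts for 20 customers
tp = 3   # churners correctly flagged
fp = 2   # non-churners wrongly flagged
fn = 3   # churners the model missed
tn = 12  # non-churners correctly left alone

total = tp + fp + fn + tn

accuracy = (tp + tn) / total          # share of all predictions that are correct
precision = tp / (tp + fp)            # share of flagged customers who really churned
recall = tp / (tp + fn)               # share of churners the model found

# Accuracy of always predicting the majority class ("no churn")
baseline_accuracy = (tn + fp) / total

print(accuracy, precision, recall, baseline_accuracy)
```

With imbalanced classes, comparing accuracy against this majority-class baseline is a quick sanity check before reading too much into the headline number.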

In [21]:
results = X_test.copy()
results['Actual'] = y_test
results['Score'] = preds
results['Predicted'] = preds_binary
In [22]:
results.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 3000 entries, 6417 to 9704
Data columns (total 17 columns):
 #   Column             Non-Null Count  Dtype   
---  ------             --------------  -----   
 0   CreditScore        3000 non-null   int64   
 1   Age                3000 non-null   int64   
 2   Tenure             3000 non-null   int64   
 3   Balance            3000 non-null   float64 
 4   NumOfProducts      3000 non-null   int64   
 5   HasCrCard          3000 non-null   int64   
 6   IsActiveMember     3000 non-null   int64   
 7   EstimatedSalary    3000 non-null   float64 
 8   AgeGroup           3000 non-null   category
 9   Geography_France   3000 non-null   uint8   
 10  Geography_Germany  3000 non-null   uint8   
 11  Geography_Spain    3000 non-null   uint8   
 12  Gender_Female      3000 non-null   uint8   
 13  Gender_Male        3000 non-null   uint8   
 14  Actual             3000 non-null   int64   
 15  Score              3000 non-null   float32 
 16  Predicted          3000 non-null   int32   
dtypes: category(1), float32(1), float64(2), int32(1), int64(7), uint8(5)
memory usage: 276.1 KB

Exploring Model predictions

Earlier we saw the relationship between age group and churn in the actual data. Let us now compare this to the predictions generated by the model.

In [23]:
# Plot Scores and Age
# Pass figsize to boxplot directly; a separate plt.figure() call only creates an empty extra figure
results.boxplot(column='Score', by='AgeGroup', grid=False, figsize=(10, 6))
plt.xlabel('Age Group')
plt.ylabel('Score')
plt.title('Scores by Age Group')
plt.xticks(rotation=45)
plt.suptitle('')
plt.show()

We observe the same pattern. The model's predictions track the pattern in the actual data.

SHAP

Global explanations

In [24]:
X_shap = X_test.loc[:, X_test.columns.isin(cols_for_training)]

We now initialize SHAP. Since we are using a tree-based model, we can use the TreeExplainer.

In [25]:
explainer = shap.TreeExplainer(xgb_classifier)

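We can now calculate SHAP values for the test data set, for which we already have predictions. Note that no background data set is passed to TreeExplainer here, so the explainer does not use the test data as a background distribution; it relies on the statistics stored in the trees themselves (the path-dependent approach). To use an explicit background data set that supplies the marginal distribution of the feature values, pass it as the data argument when creating the explainer. Usually (a sample of) the training data is used for this.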

In [26]:
shap_values = explainer.shap_values(X_shap)
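To make concrete what a SHAP value is, the sketch below computes exact Shapley values by brute force for a hypothetical three-feature function, averaging over a small made-up background set. This is the background-data (interventional) formulation, used here only because it is easy to write down; it is not the algorithm TreeExplainer runs, and the toy function is not the XGBoost model.

```python
from itertools import combinations
from math import factorial

def model(row):
    # Hypothetical toy model, not the classifier from this tutorial
    return 2.0 * row[0] + row[1] * row[2]

def coalition_value(x, background, subset):
    # Expected model output when the features in `subset` are fixed to x
    # and the remaining features are taken from each background row
    total = 0.0
    for b in background:
        mixed = [x[i] if i in subset else b[i] for i in range(len(x))]
        total += model(mixed)
    return total / len(background)

def shapley_values(x, background):
    n = len(x)
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            for subset in combinations(others, size):
                # Shapley weight of a coalition of this size
                weight = factorial(size) * factorial(n - size - 1) / factorial(n)
                with_i = coalition_value(x, background, set(subset) | {i})
                without_i = coalition_value(x, background, set(subset))
                phi[i] += weight * (with_i - without_i)
    return phi

background = [[0.0, 0.0, 0.0], [1.0, 2.0, 1.0], [2.0, 1.0, 3.0]]  # made-up rows
x = [3.0, 2.0, 2.0]                                               # instance to explain

base_value = sum(model(b) for b in background) / len(background)
phi = shapley_values(x, background)

# Additivity: base value plus all contributions reproduces the prediction
print(base_value + sum(phi), model(x))
```

The additivity shown in the last line is the property that the summary and force plots rely on.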

The SHAP summary plot with plot_type="bar" ranks the features by their mean absolute SHAP value, so the most important features appear at the top.

In [27]:
shap.summary_plot(shap_values, X_shap, plot_type="bar")
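The bar heights in this plot are mean absolute SHAP values per feature. A minimal sketch of that aggregation, assuming a matrix with one row per instance and one column per feature, filled with made-up numbers:

```python
import numpy as np

feature_names = ["Age", "Balance", "Tenure"]

# Hypothetical SHAP values: 3 instances x 3 features (not from the model above)
shap_matrix = np.array([
    [0.8, -0.2, 0.05],
    [-0.6, 0.3, -0.05],
    [1.0, -0.1, 0.0],
])

# Global importance: average magnitude of each feature's contribution
importance = np.abs(shap_matrix).mean(axis=0)

# Sort features from most to least important
order = np.argsort(importance)[::-1]
ranked = [(feature_names[i], float(importance[i])) for i in order]
print(ranked)
```

Taking the absolute value first matters: positive and negative contributions would otherwise cancel out, and a feature with strong effects in both directions could look unimportant.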

Another version of the summary plot shows the most important features (in descending order from top to bottom) and visualizes the relationship between feature values and their impact on the model output. Age has a high impact on model predictions. Customers with higher age (red dots) tend to have positive SHAP values, driving the predicted score for churn upwards. In other words, higher age in the model is associated with more churn.

In [28]:
shap.summary_plot(shap_values, X_shap)

Local explanations: explaining individual predictions

Let us look at some individual predictions. We isolate the most likely and the least likely customer to churn.

In [29]:
high_pred_index = np.argmax(preds)
high_pred_instance = X_shap.iloc[high_pred_index:high_pred_index+1]
In [30]:
shap_values_high = explainer.shap_values(high_pred_instance)

The force plot shows how much each feature value contributes to the predicted value relative to the base value (the expected model output, explainer.expected_value). The length of each arrow corresponds to the magnitude of the SHAP value, and the direction and color show whether the feature pushes the prediction up or down. With the default settings used here, these values are on the log-odds scale of the XGBoost model rather than the probability scale. In this example, the relatively high age and number of products are the main determinants of the high score on churn.
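SHAP values are additive: the base value plus the contributions of all features gives the model output for the instance. The sketch below walks through that arithmetic, assuming log-odds output (the raw output of a binary:logistic model) and purely hypothetical numbers rather than values read from this model.

```python
import math

def sigmoid(z):
    # Converts log-odds to a probability
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical base value and contributions, all in log-odds
base_value = -1.4
contributions = {"Age": 1.9, "NumOfProducts": 1.1, "IsActiveMember": -0.3}

# Arrows pointing right add to the base value, arrows pointing left subtract
margin = base_value + sum(contributions.values())

# The score reported by predict() corresponds to the sigmoid of the margin
score = sigmoid(margin)
print(round(margin, 3), round(score, 3))
```

Reading the plot this way explains why the numbers on its axis can be negative or larger than 1 even though the predicted score is a probability.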

In [31]:
shap.initjs()
shap.force_plot(explainer.expected_value, shap_values_high, high_pred_instance)
In [32]:
low_pred_index = np.argmin(preds)
low_pred_instance = X_shap.iloc[low_pred_index:low_pred_index+1]  
In [33]:
shap_values_low = explainer.shap_values(low_pred_instance)

In the example of the low risk customer, the relatively low age, the estimated salary and the credit score are the main determinants of the model's prediction.

In [34]:
shap.force_plot(explainer.expected_value, shap_values_low, low_pred_instance)