Lecture 21: Survival analysis

UBC 2024-25


import matplotlib.pyplot as plt
import numpy as np

import pandas as pd
from sklearn.compose import ColumnTransformer, make_column_transformer
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import (
    cross_val_predict,
    cross_validate,
    train_test_split,
)
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.preprocessing import (
    OneHotEncoder,
    StandardScaler,
)
import sys
import os

sys.path.append(os.path.join(os.path.abspath(".."), "code"))
from utils import *

plt.rcParams["font.size"] = 12

# does lifelines try to mess with this?
pd.options.display.max_rows = 10

import warnings
DATA_DIR = os.path.join(os.path.abspath(".."), "data/")
import lifelines

Learning objectives

  • Explain what right-censored data is.

  • Explain the problem with treating right-censored data the same as “regular” data.

  • Determine whether survival analysis is an appropriate tool for a given problem.

  • Apply survival analysis in Python using the lifelines package.

  • Interpret a survival curve, such as the Kaplan-Meier curve.

  • Interpret the coefficients of a fitted Cox proportional hazards model.

  • Make predictions for existing individuals and interpret these predictions.

❓❓ Questions for you

(iClicker) Exercise 21.1

When we used a Ridge model with hour and day of the week as numeric features, how many coefficients did we get? (Remember that the hours were recorded at 3-hour intervals.)

  • (A) 1

  • (B) 2

  • (C) 15

  • (D) 56

  • (E) ???

How about when hours and day of the week were one-hot encoded?

(iClicker) Exercise 21.2

Select all of the following statements which are TRUE.

  • (A) We need to be careful when splitting the data when working with time series data.

  • (B) Cross-validation in time series can be randomly applied like in other machine learning tasks.

  • (C) In time series forecasting, the future value of a series can only be predicted based on its past values and cannot incorporate other variables.

  • (D) When we used a RandomForestRegressor model on the POSIX time feature, it predicted a straight line on the test data because tree-based models are inherently unable to extrapolate (i.e., make predictions outside the range of the training data).


  • Time series analysis is used when there is a temporal aspect in the data.

  • Data splitting: Data should be split based on time to avoid future data leaking into the training set.

  • Essential questions for Exploratory Data Analysis (EDA):

    • What is the frequency of data collection (e.g., hourly, daily)?

    • How many time series are present within the dataset?

    • Are there any gaps or missing values in the data?

  • Feature engineering

    • Derived new features from the date/time column.

    • Appropriately encoded features based on the chosen model.

    • Created lag features to incorporate past values for prediction.

  • Baseline model approach: Employ a simple model, such as using today’s target value to predict tomorrow’s, as a starting point for comparison.

  • Cross-validation for time series: In sklearn, use TimeSeriesSplit as the cv parameter in functions like cross_validate or cross_val_score for time-appropriate validation.

  • Strategies for long-term forecasting:

    • Generate forecasts for sequential time steps by assuming the predictions for the previous steps are accurate.

  • Trends

    • A ‘days since’ feature to capture the trend over time
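The recap above mentions time-ordered splitting, lag features, and using today's value as a baseline for tomorrow. Below is a minimal sketch on a made-up daily series (the `sales` column and its values are hypothetical) showing how `shift` creates a lag feature and that `TimeSeriesSplit` always validates on rows that come after the training rows.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import TimeSeriesSplit

# Hypothetical daily series: 12 days of made-up "sales" values
ts = pd.DataFrame(
    {"sales": np.arange(12, dtype=float)},
    index=pd.date_range("2024-01-01", periods=12, freq="D"),
)

# Lag feature: yesterday's value as a predictor for today
# (on its own, this column is also the "today predicts tomorrow" baseline)
ts["sales_lag1"] = ts["sales"].shift(1)
ts = ts.dropna()  # the first day has no "yesterday"

# Each validation fold comes strictly after its training fold
tscv = TimeSeriesSplit(n_splits=3)
splits = list(tscv.split(ts))
for train_idx, valid_idx in splits:
    print(
        f"train rows {train_idx.min()}-{train_idx.max()}, "
        f"valid rows {valid_idx.min()}-{valid_idx.max()}"
    )
```

The same `tscv` object can be passed as `cv=tscv` to `cross_validate` or `cross_val_score`.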

Customer churn

  • Customer churn, also known as customer attrition, refers to the phenomenon where customers or subscribers stop doing business with a company or service.

  • The bar chart below shows the monthly subscriber churn rates for various streaming services.

  • Imagine that you are working for a subscription-based telecom company.

  • They want to predict when a specific customer will churn so that they can come up with retention strategies for different customer segments.

  • We want to model “time to churn” to understand different factors affecting customer churn.

  • Is it possible to use machine learning to predict whether a specific customer will churn?

Let’s work with the Customer Churn Dataset, which was collected at a single fixed point in time.

df = pd.read_csv(DATA_DIR + "WA_Fn-UseC_-Telco-Customer-Churn.csv")
train_df, test_df = train_test_split(df, random_state=123)
customerID gender SeniorCitizen Partner Dependents tenure PhoneService MultipleLines InternetService OnlineSecurity ... DeviceProtection TechSupport StreamingTV StreamingMovies Contract PaperlessBilling PaymentMethod MonthlyCharges TotalCharges Churn
6464 4726-DLWQN Male 1 No No 50 Yes Yes DSL Yes ... No No Yes No Month-to-month Yes Bank transfer (automatic) 70.35 3454.6 No
5707 4537-DKTAL Female 0 No No 2 Yes No DSL No ... No No No No Month-to-month No Electronic check 45.55 84.4 No
3442 0468-YRPXN Male 0 No No 29 Yes No Fiber optic No ... Yes Yes Yes Yes Month-to-month Yes Credit card (automatic) 98.80 2807.1 No
3932 1304-NECVQ Female 1 No No 2 Yes Yes Fiber optic No ... Yes No No No Month-to-month Yes Electronic check 78.55 149.55 Yes
6124 7153-CHRBV Female 0 Yes Yes 57 Yes No DSL Yes ... Yes Yes No No One year Yes Mailed check 59.30 3274.35 No

5 rows × 21 columns

  • We are interested in predicting customer churn: the “Churn” column.

  • How will you approach this problem with the approaches we have seen so far?

  • How about treating this as a binary classification problem where we predict Churn (Yes/No) from the other columns?

  • Before we look into survival analysis, let’s treat it as a binary classification problem where we predict whether a customer churned or not.

(5282, 21)
No     3912
Yes    1370
Name: count, dtype: int64
<class 'pandas.core.frame.DataFrame'>
Index: 5282 entries, 6464 to 3582
Data columns (total 21 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   customerID        5282 non-null   object 
 1   gender            5282 non-null   object 
 2   SeniorCitizen     5282 non-null   int64  
 3   Partner           5282 non-null   object 
 4   Dependents        5282 non-null   object 
 5   tenure            5282 non-null   int64  
 6   PhoneService      5282 non-null   object 
 7   MultipleLines     5282 non-null   object 
 8   InternetService   5282 non-null   object 
 9   OnlineSecurity    5282 non-null   object 
 10  OnlineBackup      5282 non-null   object 
 11  DeviceProtection  5282 non-null   object 
 12  TechSupport       5282 non-null   object 
 13  StreamingTV       5282 non-null   object 
 14  StreamingMovies   5282 non-null   object 
 15  Contract          5282 non-null   object 
 16  PaperlessBilling  5282 non-null   object 
 17  PaymentMethod     5282 non-null   object 
 18  MonthlyCharges    5282 non-null   float64
 19  TotalCharges      5282 non-null   object 
 20  Churn             5282 non-null   object 
dtypes: float64(1), int64(2), object(18)
memory usage: 907.8+ KB

Question: Does this mean there is no missing data?

Ok, let’s try our usual approach:

numeric_features = ["tenure", "MonthlyCharges", "TotalCharges"]
drop_features = ["customerID"]
passthrough_features = ["SeniorCitizen"]
target_column = ["Churn"]
# the rest are categorical
categorical_features = list(
    set(train_df.columns)
    - set(numeric_features)
    - set(passthrough_features)
    - set(drop_features)
    - set(target_column)
)
preprocessor = make_column_transformer(
    (StandardScaler(), numeric_features),
    (OneHotEncoder(), categorical_features),
    ("passthrough", passthrough_features),
    ("drop", drop_features),
)
preprocessor.fit(train_df);
ValueError                                Traceback (most recent call last)
----> 1 preprocessor.fit(train_df);
...
ValueError: could not convert string to float: ' '

Hmmm, one of the numeric features is causing problems?

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7043 entries, 0 to 7042
Data columns (total 21 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   customerID        7043 non-null   object 
 1   gender            7043 non-null   object 
 2   SeniorCitizen     7043 non-null   int64  
 3   Partner           7043 non-null   object 
 4   Dependents        7043 non-null   object 
 5   tenure            7043 non-null   int64  
 6   PhoneService      7043 non-null   object 
 7   MultipleLines     7043 non-null   object 
 8   InternetService   7043 non-null   object 
 9   OnlineSecurity    7043 non-null   object 
 10  OnlineBackup      7043 non-null   object 
 11  DeviceProtection  7043 non-null   object 
 12  TechSupport       7043 non-null   object 
 13  StreamingTV       7043 non-null   object 
 14  StreamingMovies   7043 non-null   object 
 15  Contract          7043 non-null   object 
 16  PaperlessBilling  7043 non-null   object 
 17  PaymentMethod     7043 non-null   object 
 18  MonthlyCharges    7043 non-null   float64
 19  TotalCharges      7043 non-null   object 
 20  Churn             7043 non-null   object 
dtypes: float64(1), int64(2), object(18)
memory usage: 1.1+ MB

Oh, looks like TotalCharges is not a numeric type. What if we change the type of this column to float?

train_df["TotalCharges"] = train_df["TotalCharges"].astype(float)
ValueError                                Traceback (most recent call last)
Cell In[11], line 1
----> 1 train_df["TotalCharges"] = train_df["TotalCharges"].astype(float)
...
ValueError: could not convert string to float: ' '


for val in train_df["TotalCharges"]:
    try:
        float(val)
    except ValueError:
        print(val)

Any ideas?

Well, it turns out we can’t see those problematic values because they are whitespace!

for val in train_df["TotalCharges"]:
    try:
        float(val)
    except ValueError:
        print('"%s"' % val)
" "
" "
" "
" "
" "
" "
" "
" "

Let’s replace the whitespace entries with NaN.

train_df = train_df.assign(
    TotalCharges=train_df["TotalCharges"].replace(" ", np.nan).astype(float)
)
test_df = test_df.assign(
    TotalCharges=test_df["TotalCharges"].replace(" ", np.nan).astype(float)
)
<class 'pandas.core.frame.DataFrame'>
Index: 5282 entries, 6464 to 3582
Data columns (total 21 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   customerID        5282 non-null   object 
 1   gender            5282 non-null   object 
 2   SeniorCitizen     5282 non-null   int64  
 3   Partner           5282 non-null   object 
 4   Dependents        5282 non-null   object 
 5   tenure            5282 non-null   int64  
 6   PhoneService      5282 non-null   object 
 7   MultipleLines     5282 non-null   object 
 8   InternetService   5282 non-null   object 
 9   OnlineSecurity    5282 non-null   object 
 10  OnlineBackup      5282 non-null   object 
 11  DeviceProtection  5282 non-null   object 
 12  TechSupport       5282 non-null   object 
 13  StreamingTV       5282 non-null   object 
 14  StreamingMovies   5282 non-null   object 
 15  Contract          5282 non-null   object 
 16  PaperlessBilling  5282 non-null   object 
 17  PaymentMethod     5282 non-null   object 
 18  MonthlyCharges    5282 non-null   float64
 19  TotalCharges      5274 non-null   float64
 20  Churn             5282 non-null   object 
dtypes: float64(2), int64(2), object(17)
memory usage: 907.8+ KB
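The `replace(" ", np.nan)` call above only catches entries that are exactly a single space. A more general option is `pd.to_numeric(..., errors="coerce")`, which turns anything that cannot be parsed as a number into NaN and makes it easy to list the offending raw values. A minimal sketch on a small made-up Series (the values are hypothetical, chosen to mimic the blank TotalCharges entries):

```python
import pandas as pd

# Hypothetical stand-in for the TotalCharges column: strings, one of them whitespace
total_charges = pd.Series(["3454.6", "84.4", " ", "2807.1"])

# Unparseable entries (here the single space) become NaN instead of raising
as_numeric = pd.to_numeric(total_charges, errors="coerce")

# Show the raw values that failed to parse, quoted with repr so whitespace is visible
bad_values = total_charges[as_numeric.isna()]
print(bad_values.map(repr).tolist())
```

On this lecture's data, `pd.to_numeric(train_df["TotalCharges"], errors="coerce")` should yield the same 8 missing values as the `replace` approach, since the loop above found only single-space entries.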

But now we have missing values, so we need to add imputation for the numeric features to our preprocessor.

preprocessor = make_column_transformer(
    (
        make_pipeline(SimpleImputer(strategy="median"), StandardScaler()),
        numeric_features,
    ),
    (OneHotEncoder(handle_unknown="ignore"), categorical_features),
    ("passthrough", passthrough_features),
    ("drop", drop_features),
)

Now let’s try that again…

preprocessor.fit(train_df);

It worked! Let’s get the column names of the transformed data from the column transformer.

new_columns = (
    numeric_features
    + preprocessor.named_transformers_["onehotencoder"]
    .get_feature_names_out(categorical_features)
    .tolist()
    + passthrough_features
)
X_train_enc = pd.DataFrame(
    preprocessor.transform(train_df), index=train_df.index, columns=new_columns
)
# Transform the test set (not the training set) for X_test_enc
X_test_enc = pd.DataFrame(
    preprocessor.transform(test_df), index=test_df.index, columns=new_columns
)
tenure MonthlyCharges TotalCharges Dependents_No Dependents_Yes PhoneService_No PhoneService_Yes TechSupport_No TechSupport_No internet service TechSupport_Yes ... Contract_Month-to-month Contract_One year Contract_Two year StreamingMovies_No StreamingMovies_No internet service StreamingMovies_Yes InternetService_DSL InternetService_Fiber optic InternetService_No SeniorCitizen
6464 0.707712 0.185175 0.513678 1.0 0.0 0.0 1.0 1.0 0.0 0.0 ... 1.0 0.0 0.0 1.0 0.0 0.0 1.0 0.0 0.0 1.0
5707 -1.248999 -0.641538 -0.979562 1.0 0.0 0.0 1.0 1.0 0.0 0.0 ... 1.0 0.0 0.0 1.0 0.0 0.0 1.0 0.0 0.0 0.0
3442 -0.148349 1.133562 0.226789 1.0 0.0 0.0 1.0 0.0 0.0 1.0 ... 1.0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 0.0 0.0
3932 -1.248999 0.458524 -0.950696 1.0 0.0 0.0 1.0 1.0 0.0 0.0 ... 1.0 0.0 0.0 1.0 0.0 0.0 0.0 1.0 0.0 1.0
6124 0.993065 -0.183179 0.433814 0.0 1.0 0.0 1.0 0.0 0.0 1.0 ... 0.0 1.0 0.0 1.0 0.0 0.0 1.0 0.0 0.0 0.0

5 rows × 45 columns

results = {}
X_train = train_df.drop(columns=["Churn"])
X_test = test_df.drop(columns=["Churn"])

y_train = train_df["Churn"]
y_test = test_df["Churn"]


dc = DummyClassifier()
results["dummy"] = mean_std_cross_val_scores(
    dc, X_train, y_train, return_train_score=True
)
fit_time score_time test_score train_score
dummy 0.003 (+/- 0.000) 0.002 (+/- 0.000) 0.741 (+/- 0.000) 0.741 (+/- 0.000)

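The dummy model's score is fairly high (0.741) because of class imbalance: by default it always predicts the majority class, "No", which makes up about 74% of the training set.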

No     3912
Yes    1370
Name: count, dtype: int64
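Since `DummyClassifier` with its default strategy (`"prior"` in recent scikit-learn versions) always predicts the most frequent class, its accuracy is simply the proportion of "No" labels. The arithmetic, using the counts above:

```python
# Class counts from the training split shown above
counts = {"No": 3912, "Yes": 1370}

# Always predicting the majority class is correct for every majority-class example
majority_fraction = max(counts.values()) / sum(counts.values())
print(round(majority_fraction, 3))  # → 0.741
```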


lr = make_pipeline(preprocessor, LogisticRegression(max_iter=1000))
results["logistic regression"] = mean_std_cross_val_scores(
    lr, X_train, y_train, return_train_score=True
)
fit_time score_time test_score train_score
dummy 0.003 (+/- 0.000) 0.002 (+/- 0.000) 0.741 (+/- 0.000) 0.741 (+/- 0.000)
logistic regression 0.037 (+/- 0.003) 0.010 (+/- 0.003) 0.804 (+/- 0.014) 0.808 (+/- 0.002)
confusion_matrix(y_train, cross_val_predict(lr, X_train, y_train))
array([[3515,  397],
       [ 636,  734]])
  • Logistic regression beats the dummy model (0.804 vs. 0.741 validation accuracy).

  • But there are many false negatives: 636 of the 1,370 churners are predicted as "No".
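With "Yes" (churn) as the positive class, the bottom-left entry of the confusion matrix counts churners that the model labelled "No". A quick sketch that turns the logistic regression matrix above into recall and precision:

```python
import numpy as np

# Confusion matrix for logistic regression from above:
# rows = true class (No, Yes), columns = predicted class (No, Yes)
cm = np.array([[3515, 397],
               [636, 734]])
tn, fp, fn, tp = cm.ravel()

recall = tp / (tp + fn)      # fraction of actual churners the model catches
precision = tp / (tp + fp)   # fraction of predicted churners who actually churned
print(f"recall={recall:.3f}, precision={precision:.3f}")  # → recall=0.536, precision=0.649
```

So close to half of the churners (636 / 1370 ≈ 0.46) are missed, even though overall accuracy looks reasonable.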


Let’s try random forest model.

rf = make_pipeline(preprocessor, RandomForestClassifier(n_estimators=100))
results["random forest"] = mean_std_cross_val_scores(
    rf, X_train, y_train, return_train_score=True
)
fit_time score_time test_score train_score
dummy 0.003 (+/- 0.000) 0.002 (+/- 0.000) 0.741 (+/- 0.000) 0.741 (+/- 0.000)
logistic regression 0.037 (+/- 0.003) 0.010 (+/- 0.003) 0.804 (+/- 0.014) 0.808 (+/- 0.002)
random forest 0.412 (+/- 0.047) 0.021 (+/- 0.000) 0.789 (+/- 0.014) 0.998 (+/- 0.000)
confusion_matrix(y_train, cross_val_predict(rf, X_train, y_train))
array([[3537,  375],
       [ 740,  630]])
  • Random forest does not improve on logistic regression (0.789 vs. 0.804 validation accuracy) and overfits heavily (0.998 training accuracy).

  • We might try hyperparameter optimization to further improve the score.

  • But after trying out all the usual things should we be happy with the scores?

  • Are we doing anything fundamentally wrong when we treat this problem as a binary classification?

The rest of the class is about what is wrong with what we just did!

Censoring and survival analysis

Time to event and censoring

  • When we treat the problem as binary classification, we predict whether a customer has churned at one particular point in time: when the data was collected.

  • If a customer has not churned yet, wouldn’t it be more useful to know when they are likely to churn, so that we can offer them promotions, etc.?

  • Here we are actually interested in the time until churn occurs.

There are many situations where you want to analyze the time until an event occurs. For example,

  • the time until a customer leaves a subscription service (this dataset)

  • the time until a disease kills its host

  • the time until a piece of equipment breaks

  • the time that someone unemployed will take to land a new job

  • the time you wait for your turn to get surgery

Although this branch of statistics is usually referred to as Survival Analysis, the event in question does not need to be related to actual “survival”. The important thing is to understand that we are interested in the time until something happens, or whether or not something will happen in a certain time frame.

In our dataset there is a column called “tenure”, which encodes this temporal aspect of the data.

6464 50
5707 2
3442 29
3932 2
6124 57
  • The tenure column is the number of months the customer has stayed with the company.

  • But we only have this information up to the point when the data was collected.

❓❓ Questions for you

But why is this different? Can’t you just use the techniques you learned so far (e.g., regression models) to predict the time (tenure in our case)? Take a minute to think about this. What could be possible scenarios for the duration column?

The answer would be yes if you could observe the actual time in all occurrences, but you usually cannot. Frequently, there will be some kind of censoring which will not allow you to observe the exact time that the event happened for all units/individuals that are being studied.

train_df[["tenure", "Churn"]].head()
tenure Churn
6464 50 No
5707 2 No
3442 29 No
3932 2 Yes
6124 57 No
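  • For customers with Churn = "No", tenure is not their time to churn. It is only a lower bound, because they had not churned yet when the data was collected. These observations are right-censored, which means we don’t have correct target values to train or test our model.

  • This is a problem!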
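One way to make the censoring explicit is to split the target into two columns: a duration and an event indicator (1 if churn was observed, 0 if the customer was still subscribed when the data was collected). This (duration, event) pair is the format lifelines fitters expect, e.g. `KaplanMeierFitter.fit(durations, event_observed)`. A sketch on the five training rows shown above (the column names `duration` and `event` are our own choice):

```python
import pandas as pd

# The five training rows shown above (index, tenure, Churn)
toy = pd.DataFrame(
    {"tenure": [50, 2, 29, 2, 57], "Churn": ["No", "No", "No", "Yes", "No"]},
    index=[6464, 5707, 3442, 3932, 6124],
)

# Duration: months observed so far; event: 1 if churn happened during observation
toy["duration"] = toy["tenure"]
toy["event"] = (toy["Churn"] == "Yes").astype(int)

# For censored rows (event == 0) the true time to churn is only known to be >= duration
censored = toy[toy["event"] == 0]
print(censored[["duration", "event"]])
```

Only customer 3932 has an observed churn time; for the other four rows, tenure says "at least this many months".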

Let’s consider some approaches to deal with this censoring issue.

Approach 1: Only consider the examples where “Churn”=Yes

Let’s just consider the cases for which we have the time, to obtain the average subscription length.

train_df_churn = train_df.query(
    "Churn == 'Yes'"
)  # Consider only examples where the customers churned.
test_df_churn = test_df.query(
    "Churn == 'Yes'"
)  # Consider only examples where the customers churned.
customerID gender SeniorCitizen Partner Dependents tenure PhoneService MultipleLines InternetService OnlineSecurity ... DeviceProtection TechSupport StreamingTV StreamingMovies Contract PaperlessBilling PaymentMethod MonthlyCharges TotalCharges Churn
3932 1304-NECVQ Female 1 No No 2 Yes Yes Fiber optic No ... Yes No No No Month-to-month Yes Electronic check 78.55 149.55 Yes
301 8098-LLAZX Female 1 No No 4 Yes Yes Fiber optic No ... No No Yes Yes Month-to-month Yes Electronic check 95.45 396.10 Yes
5540 3803-KMQFW Female 0 Yes Yes 1 Yes No No No internet service ... No internet service No internet service No internet service No internet service Month-to-month No Mailed check 20.55 20.55 Yes
4084 2777-PHDEI Female 0 No No 1 Yes No Fiber optic No ... No No Yes No Month-to-month No Electronic check 78.05 78.05 Yes
3272 6772-KSATR Male 0 No No 1 Yes Yes Fiber optic Yes ... No No No No Month-to-month Yes Electronic check 81.70 81.70 Yes

5 rows × 21 columns

(5282, 21)
(1370, 21)
['tenure', 'MonthlyCharges', 'TotalCharges']
preprocessing_notenure = make_column_transformer(
    (
        make_pipeline(SimpleImputer(strategy="median"), StandardScaler()),
        numeric_features[1:],  # Getting rid of the tenure column
    ),
    (OneHotEncoder(handle_unknown="ignore"), categorical_features),
    ("passthrough", passthrough_features),
    ("drop", drop_features),
)
customerID gender SeniorCitizen Partner Dependents tenure PhoneService MultipleLines InternetService OnlineSecurity ... DeviceProtection TechSupport StreamingTV StreamingMovies Contract PaperlessBilling PaymentMethod MonthlyCharges TotalCharges Churn
3932 1304-NECVQ Female 1 No No 2 Yes Yes Fiber optic No ... Yes No No No Month-to-month Yes Electronic check 78.55 149.55 Yes
301 8098-LLAZX Female 1 No No 4 Yes Yes Fiber optic No ... No No Yes Yes Month-to-month Yes Electronic check 95.45 396.10 Yes
5540 3803-KMQFW Female 0 Yes Yes 1 Yes No No No internet service ... No internet service No internet service No internet service No internet service Month-to-month No Mailed check 20.55 20.55 Yes
4084 2777-PHDEI Female 0 No No 1 Yes No Fiber optic No ... No No Yes No Month-to-month No Electronic check 78.05 78.05 Yes
3272 6772-KSATR Male 0 No No 1 Yes Yes Fiber optic Yes ... No No No No Month-to-month Yes Electronic check 81.70 81.70 Yes
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
4169 3663-MITLP Female 0 No No 15 Yes No Fiber optic No ... Yes No Yes Yes Month-to-month Yes Electronic check 101.25 1457.25 Yes
4143 4822-YCXMX Male 0 No No 25 Yes Yes Fiber optic No ... No No No Yes Month-to-month Yes Electronic check 84.80 2043.45 Yes
6257 1977-STDKI Female 1 No No 1 Yes No Fiber optic No ... No No No No Month-to-month Yes Electronic check 73.00 73.00 Yes
5857 0378-NHQXU Female 0 Yes Yes 17 Yes Yes Fiber optic No ... Yes No Yes No Month-to-month No Electronic check 88.25 1460.65 Yes
1346 2845-HSJCY Female 0 Yes Yes 14 Yes Yes Fiber optic No ... Yes No No Yes Month-to-month Yes Electronic check 87.25 1258.60 Yes

1370 rows × 21 columns

tenure_lm = make_pipeline(preprocessing_notenure, Ridge())
3932     2
301      4
5540     1
4084     1
3272     1
4169    15
4143    25
6257     1
5857    17
1346    14
Name: tenure, Length: 1370, dtype: int64
tenure_lm.fit(train_df_churn.drop(columns=["tenure"]), train_df_churn["tenure"]);
We can look at the “concordance index”, which is more interpretable:

cph.score(train_df_surv, scoring_method="concordance_index")
cph.score(test_df_surv, scoring_method="concordance_index")

From the lifelines documentation:

Another censoring-sensitive measure is the concordance-index, also known as the c-index. This measure evaluates the accuracy of the ranking of predicted time. It is in fact a generalization of AUC, another common loss function, and is interpreted similarly:

  • 0.5 is the expected result from random predictions,

  • 1.0 is perfect concordance and,

  • 0.0 is perfect anti-concordance (multiply predictions with -1 to get 1.0)

Here is an excellent introduction & description of the c-index for new users.
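To make the ranking idea behind the c-index concrete, here is a minimal from-scratch sketch of Harrell's concordance index. It is an illustration, not the lifelines implementation: the function name `harrell_c_index` and the toy times are hypothetical, pairs with tied event times are simply skipped (lifelines treats some ties differently), and, following the convention in the quote above, a higher predicted value means a longer predicted survival time.

```python
import numpy as np


def harrell_c_index(event_times, predicted_times, event_observed):
    """Sketch of Harrell's C: higher prediction = longer predicted survival."""
    event_times = np.asarray(event_times, dtype=float)
    predicted_times = np.asarray(predicted_times, dtype=float)
    event_observed = np.asarray(event_observed, dtype=bool)

    concordant, comparable = 0.0, 0
    n = len(event_times)
    for i in range(n):
        if not event_observed[i]:
            # A censored customer is never known to have churned first
            continue
        for j in range(n):
            # Comparable pair: i churned strictly before j's observed time,
            # so we know i's true time is shorter (even if j is censored)
            if event_times[i] < event_times[j]:
                comparable += 1
                if predicted_times[i] < predicted_times[j]:
                    concordant += 1.0  # model agrees that i churns first
                elif predicted_times[i] == predicted_times[j]:
                    concordant += 0.5  # tied predictions count as half
    return concordant / comparable


# Hypothetical toy data: 0 in `events` marks a right-censored customer
times = [1, 2, 3, 4, 5]
events = [1, 1, 0, 1, 1]
print(harrell_c_index(times, times, events))                 # perfect ranking
print(harrell_c_index(times, [-t for t in times], events))   # reversed ranking
print(harrell_c_index(times, [0] * 5, events))               # constant predictions
```

The three calls reproduce the reference points listed above: 1.0 for perfect concordance, 0.0 for perfect anti-concordance, and 0.5 for uninformative (constant) predictions.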

null_distribution    chi squared
degrees_freedom      43
test_name            log-likelihood ratio test

test_statistic       p    -log2(p)
       2206.68  <0.005         inf
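The log-likelihood ratio test above compares the fitted Cox model with a model that has no covariates. A minimal sketch of the calculation, assuming the usual form of the test (statistic = 2 × difference in log-likelihoods, compared with a chi-squared distribution whose degrees of freedom equal the number of extra parameters); the helper name `likelihood_ratio_test` is hypothetical:

```python
from scipy.stats import chi2


def likelihood_ratio_test(ll_null, ll_full, df):
    """Return the LR statistic and its p-value.

    ll_null: log-likelihood of the model without covariates
    ll_full: log-likelihood of the fitted model
    df:      number of extra parameters in the fitted model
    """
    statistic = 2.0 * (ll_full - ll_null)
    p_value = chi2.sf(statistic, df)  # upper tail of the chi-squared distribution
    return statistic, p_value


# Plugging in the statistic and degrees of freedom reported above.
# The p-value is far below 0.005; it is small enough that it may underflow
# to 0.0 in float64, which is consistent with -log2(p) being shown as inf.
p = chi2.sf(2206.68, 43)
print(p)
```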
The ``p_value_threshold`` is set at 0.01. Even under the null hypothesis of no violations, some
covariates will be below the threshold by chance. This is compounded when there are many covariates.
Similarly, when there are lots of observations, even minor deviances from the proportional hazard
assumption will be flagged.

With that in mind, it's best to use a combination of statistical tests and visual tests to determine
the most serious violations. Produce visual plots using ``check_assumptions(..., show_plots=True)``
and looking for non-constant lines. See link [A] below for a full example.
null_distribution    chi squared
degrees_of_freedom   1
model                <lifelines.CoxPHFitter: fitted with 5282 total...
test_name            proportional_hazard_test

                     test_statistic   p   -log2(p)
Contract_Month-to-month km 0.07 0.80 0.33
rank 0.00 0.97 0.04
Contract_One year km 14.52 <0.005 12.81
rank 10.11 <0.005 9.41
Contract_Two year km 8.86 <0.005 8.42
rank 7.69 0.01 7.49
Dependents_No km 0.07 0.79 0.34
rank 0.07 0.79 0.34
Dependents_Yes km 0.07 0.79 0.34
rank 0.07 0.79 0.34
DeviceProtection_No km 0.07 0.79 0.34
rank 0.08 0.77 0.37
DeviceProtection_No internet service km 0.25 0.62 0.69
rank 0.26 0.61 0.72
DeviceProtection_Yes km 0.70 0.40 1.31
rank 0.76 0.38 1.38
InternetService_DSL km 0.32 0.57 0.81
rank 0.28 0.59 0.75
InternetService_Fiber optic km 1.02 0.31 1.68
rank 0.98 0.32 1.64
InternetService_No km 0.25 0.62 0.69
rank 0.26 0.61 0.72
MonthlyCharges km 1.65 0.20 2.33
rank 1.72 0.19 2.40
MultipleLines_No km 1.57 0.21 2.25
rank 1.87 0.17 2.55
MultipleLines_No phone service km 0.03 0.86 0.21
rank 0.05 0.83 0.27
MultipleLines_Yes km 1.92 0.17 2.59
rank 2.35 0.13 3.00
OnlineBackup_No km 0.29 0.59 0.76
rank 0.24 0.63 0.68
OnlineBackup_No internet service km 0.25 0.62 0.69
rank 0.26 0.61 0.72
OnlineBackup_Yes km 1.25 0.26 1.92
rank 1.17 0.28 1.84
OnlineSecurity_No km 0.02 0.88 0.19
rank 0.09 0.77 0.38
OnlineSecurity_No internet service km 0.25 0.62 0.69
rank 0.26 0.61 0.72
OnlineSecurity_Yes km 0.56 0.46 1.13
rank 0.85 0.36 1.49
PaperlessBilling_No km 0.03 0.86 0.21
rank 0.02 0.90 0.16
PaperlessBilling_Yes km 0.03 0.86 0.21
rank 0.02 0.90 0.16
Partner_No km 0.27 0.60 0.73
rank 0.37 0.54 0.88
Partner_Yes km 0.27 0.60 0.73
rank 0.37 0.54 0.88
PaymentMethod_Bank transfer (automatic) km 0.44 0.51 0.98
rank 0.51 0.48 1.07
PaymentMethod_Credit card (automatic) km 1.46 0.23 2.14
rank 1.70 0.19 2.38
PaymentMethod_Electronic check km 0.06 0.81 0.30
rank 0.05 0.82 0.29
PaymentMethod_Mailed check km 2.36 0.12 3.01
rank 2.85 0.09 3.45
PhoneService_No km 0.03 0.86 0.21
rank 0.05 0.83 0.27
PhoneService_Yes km 0.03 0.86 0.21
rank 0.05 0.83 0.27
SeniorCitizen km 0.00 0.95 0.08
rank 0.00 0.95 0.08
StreamingMovies_No km 1.10 0.30 1.76
rank 1.25 0.26 1.93
StreamingMovies_No internet service km 0.25 0.62 0.69
rank 0.26 0.61 0.72
StreamingMovies_Yes km 2.45 0.12 3.09
rank 2.73 0.10 3.35
StreamingTV_No km 1.09 0.30 1.76
rank 0.90 0.34 1.55
StreamingTV_No internet service km 0.25 0.62 0.69
rank 0.26 0.61 0.72
StreamingTV_Yes km 2.48 0.12 3.12
rank 2.23 0.14 2.89
TechSupport_No km 0.49 0.49 1.04
rank 0.50 0.48 1.07
TechSupport_No internet service km 0.25 0.62 0.69
rank 0.26 0.61 0.72
TechSupport_Yes km 1.92 0.17 2.59
rank 2.01 0.16 2.68
gender_Female km 0.22 0.64 0.65
rank 0.08 0.78 0.35
gender_Male km 0.22 0.64 0.65
rank 0.08 0.78 0.35
1. Variable 'Contract_One year' failed the non-proportional test: p-value is 0.0001.

   Advice: with so few unique values (only 2), you can include `strata=['Contract_One year', ...]`
in the call in `.fit`. See documentation in link [E] below.

2. Variable 'Contract_Two year' failed the non-proportional test: p-value is 0.0029.

   Advice: with so few unique values (only 2), you can include `strata=['Contract_Two year', ...]`
in the call in `.fit`. See documentation in link [E] below.

[A]  https://lifelines.readthedocs.io/en/latest/jupyter_notebooks/Proportional%20hazard%20assumption.html
[B]  https://lifelines.readthedocs.io/en/latest/jupyter_notebooks/Proportional%20hazard%20assumption.html#Bin-variable-and-stratify-on-it
[C]  https://lifelines.readthedocs.io/en/latest/jupyter_notebooks/Proportional%20hazard%20assumption.html#Introduce-time-varying-covariates
[D]  https://lifelines.readthedocs.io/en/latest/jupyter_notebooks/Proportional%20hazard%20assumption.html#Modify-the-functional-form
[E]  https://lifelines.readthedocs.io/en/latest/jupyter_notebooks/Proportional%20hazard%20assumption.html#Stratification
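The check above flags `Contract_One year` and `Contract_Two year` as violating the proportional hazards assumption. As an illustration of what that assumption means (not of how lifelines tests it), the sketch below uses a hypothetical one-covariate Cox model with a Weibull-shaped baseline hazard. Under the Cox model h(t | x) = h0(t) · exp(beta · x), the hazard ratio between two customers is the same at every time t; the second, made-up hazard lets the covariate's effect fade over time, so the ratio changes with t, which is the kind of pattern the test is designed to detect. All parameter values are assumptions chosen for the example.

```python
import numpy as np

# Hypothetical parameters: coefficient and Weibull baseline shape/scale
beta = 0.8
k, lam = 1.5, 20.0


def baseline_hazard(t):
    # Weibull baseline hazard h0(t) = (k / lam) * (t / lam) ** (k - 1)
    return (k / lam) * (t / lam) ** (k - 1)


def cox_hazard(t, x):
    # Proportional hazards: the covariate multiplies h0(t) by a constant factor
    return baseline_hazard(t) * np.exp(beta * x)


def non_ph_hazard(t, x):
    # Violates the assumption: the covariate's effect decays over time
    return baseline_hazard(t) * np.exp(beta * x * np.exp(-t / 24))


t = np.linspace(1, 72, 8)  # months
ratio = cox_hazard(t, 1.0) / cox_hazard(t, 0.0)
ratio_non_ph = non_ph_hazard(t, 1.0) / non_ph_hazard(t, 0.0)

print(ratio)         # constant: exp(beta) at every time point
print(ratio_non_ph)  # shrinks as t grows, so the hazards are not proportional
```

Note that the baseline hazard cancels in both ratios; only how the covariate enters the model determines whether the ratio depends on time.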

Other approaches / what did we not cover?

There are many other approaches to modelling in survival analysis:

  • Time-varying proportional hazards.

    • What if some of the features change over time, e.g., plan type or number of lines?

  • Approaches based on deep learning, e.g. the pysurvival package.

  • Random survival forests.

  • And more…
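Time-varying covariates are usually handled by splitting each individual's history into "long format": one row per interval during which the features were constant, with `start`/`stop` times and an event flag that is 1 only on the interval where the event happened. lifelines has a `CoxTimeVaryingFitter` for data in this shape. The sketch below only builds such a table with pandas; the two customers, the `switch_month` column, and the `two_year` indicator are hypothetical.

```python
import pandas as pd

# Hypothetical customers: A switches to a two-year contract in month 4
customers = pd.DataFrame(
    {
        "customerID": ["A", "B"],
        "tenure": [10, 7],          # months observed
        "churned": [1, 0],          # 1 = churned, 0 = right-censored
        "switch_month": [4, None],  # month of the contract change (if any)
    }
)

rows = []
for c in customers.itertuples():
    if pd.notna(c.switch_month) and c.switch_month < c.tenure:
        # Before the switch: month-to-month, no event in this interval
        rows.append((c.customerID, 0, c.switch_month, 0, 0))
        # After the switch: two-year contract; the churn (if any) happens here
        rows.append((c.customerID, c.switch_month, c.tenure, 1, c.churned))
    else:
        # No change: a single interval covering the whole tenure
        rows.append((c.customerID, 0, c.tenure, 0, c.churned))

long_df = pd.DataFrame(
    rows, columns=["customerID", "start", "stop", "two_year", "event"]
)
print(long_df)
```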

Types of censoring

There are also various types and sub-types of censoring we didn’t cover:

  • What we did today is called “right censoring”

  • Sub-types within right censoring

    • Did everyone join at the same time?

    • Are there other reasons the data might be censored at random times, e.g., the person died?

  • Left censoring

  • Interval censoring
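One way to see how these censoring types differ is to write every observation as an interval [lower, upper] that is known to contain the true event time: an observed churn is [t, t], right censoring is [t, ∞) ("still a customer at t"), left censoring is [0, t] ("churned at some point before t"), and interval censoring is [a, b]. This is a common convention, not something specific to this lecture; the helper name and the example rows are hypothetical.

```python
import numpy as np
import pandas as pd


def right_censored_to_interval(durations, events):
    """Convert (duration, event) pairs into [lower, upper] bounds on the event time."""
    durations = np.asarray(durations, dtype=float)
    events = np.asarray(events, dtype=bool)
    lower = durations.copy()
    # Churned: the time is known exactly; censored: it lies somewhere after the duration
    upper = np.where(events, durations, np.inf)
    return lower, upper


# Hypothetical examples of each censoring type as [lower, upper] bounds (months)
examples = pd.DataFrame(
    {
        "censoring": ["none (churn observed)", "right", "left", "interval"],
        "lower": [5.0, 12.0, 0.0, 3.0],
        "upper": [5.0, np.inf, 4.0, 6.0],
    }
)
print(examples)

# The right-censored data from this lecture fits the same representation
lower, upper = right_censored_to_interval([5, 12], [1, 0])
print(lower, upper)  # lower = 5 and 12; upper = 5 and inf
```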


  • Censoring and incorrect approaches to handling it

    • Throw away people who haven’t churned

    • Assume everyone churns today

  • Predicting tenure vs. predicting churn

  • Survival analysis encompasses both of these and handles censoring properly

  • And it can make rich and interesting predictions!

  • Kaplan-Meier (KM) model -> does not use the features

  • Cox proportional hazards (CPH) model -> like linear regression, uses the features
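The summary's claims about the two incorrect approaches can be checked with a small simulation. The sketch below assumes exponential churn times (mean 24 months) and customers observed for a uniformly random 0 to 72 months; these numbers are assumptions, not the Telco data. It compares the median tenure from "throw away people who haven't churned" and "assume everyone churns today" with a from-scratch Kaplan-Meier (product-limit) estimate, which accounts for right censoring. The `kaplan_meier` function is a simplified illustration, not the lifelines `KaplanMeierFitter`.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5000
true_mean = 24.0                          # assumed mean time to churn (months)
T = rng.exponential(true_mean, size=n)    # true churn times (never fully observed)
C = rng.uniform(0, 72, size=n)            # how long each customer has been observed
durations = np.minimum(T, C)              # what the data actually records
events = (T <= C).astype(int)             # 1 = churned, 0 = right-censored

# Incorrect approach 1: throw away customers who haven't churned
naive_drop = np.median(durations[events == 1])
# Incorrect approach 2: assume everyone churns today
naive_all = np.median(durations)


def kaplan_meier(durations, events):
    """Product-limit estimate of S(t) at each distinct observed churn time."""
    times = np.unique(durations[events == 1])
    surv, s = [], 1.0
    for t in times:
        at_risk = np.sum(durations >= t)                     # still observed just before t
        churned = np.sum((durations == t) & (events == 1))   # churns at time t
        s *= 1.0 - churned / at_risk
        surv.append(s)
    return times, np.array(surv)


times, surv = kaplan_meier(durations, events)
km_median = times[np.argmax(surv <= 0.5)]  # first time S(t) drops to 0.5 or below
true_median = true_mean * np.log(2)        # median of the exponential distribution

print(f"true median:                {true_median:.1f}")
print(f"drop non-churned customers: {naive_drop:.1f}")
print(f"everyone churns today:      {naive_all:.1f}")
print(f"Kaplan-Meier median:        {km_median:.1f}")
```

Both naive medians come out well below the true median, because each one treats censored customers as if they churned earlier than they really will (or ignores them), while the Kaplan-Meier estimate lands close to the true value.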


Other analyses of this same dataset:

  • https://medium.com/@zachary.james.angell/applying-survival-analysis-to-customer-churn-40b5a809b05a

  • https://towardsdatascience.com/churn-prediction-and-prevention-in-python-2d454e5fd9a5 (Cox)

  • https://towardsdatascience.com/survival-analysis-in-python-a-model-for-customer-churn-e737c5242822

  • https://towardsdatascience.com/survival-analysis-intuition-implementation-in-python-504fde4fcf8e

lifelines documentation:

  • https://lifelines.readthedocs.io/en/latest/Survival%20analysis%20with%20lifelines.html

  • https://lifelines.readthedocs.io/en/latest/Survival%20Analysis%20intro.html#introduction-to-survival-analysis