Customer churn example¶
In this example, we will be building a model to predict customer churn. Customer churn is a common use case that is mission-critical for most companies. As motivation, we wrote a use case deep-dive on customer churn which goes over much of the methodology that we'll employ here, and we highly recommend it as a companion piece to this documentation. Here, we will walk through building out a churn example using a sample dataset. This is a technical guide that you can use with your free Continual trial to jumpstart a customer churn use case at your organization.
First, log in to the Continual UI and
create a new project. In this
example, we'll refer to the project name as
customer_churn_example. Also, ensure that you have installed the Continual CLI and are
logged into Continual via the CLI and set
customer_churn_example as your default project.
Next, we'll need a suitable dataset. For this example, we'll use the KKBox example available at Kaggle. KKBox is an Asian music subscription service, so we'll be able to use their data to try to predict if customers will continue using their service or churn. You may either download the data directly from the Kaggle URL or via the Kaggle CLI:
kaggle competitions download -c kkbox-churn-prediction-challenge
This will download the data as
.7z files. We'll now want to uncompress these files. These files are not small, so it may take several minutes to decompress them all. When finished, you'll have a directory with the following files:
This use case has 3 main data sources:
Transactions: a history of all customer transactions (roughly spanning the years 2015 to 2017).
Members: information on KKBox subscribers: gender, time of registration, etc.
User logs: listening metrics for all users for any day they used the service: total minutes listened, number of unique songs played, etc.
The training dataset contains user IDs with labels, but part of this exercise is learning how to construct the churn label from scratch (so you can do something similar on your own data), so we'll ignore this data for purposes of this exercise. We also notice that each data set has a "v2" associated with it. The gist is that the Kaggle dataset was updated with newer data that we can append to the original data. This data exists in the "v2" CSVs.
Now that we have our files, we'll want to move these into our cloud data warehouse. In this example, we'll be using Snowflake as our data warehouse, but you should be able to perform similar operations on any supported data warehouse.
Let us know if you try a different path and have issues.
We'll also call out now that the size of the data here is a bit larger than you may find in a lot of tutorials. The uncompressed CSVs comprise roughly 30 GB, and we have over 400M user logs with 6M users and about 20M transactions. This is more than enough data to train a model on, and you're certainly welcome to sample the data into smaller sets before moving this into your data warehouse. We'll proceed by using all of the data, for any who are interested.
Getting this much data into the cloud is a non-trivial task, and it's actually the most difficult part of this tutorial. Approaches vary here and, since we're in a tutorial, there's no wrong answer. If you have access to a data integration tool, like Fivetran, I'd recommend making your life easy and uploading via their file support function. It's also worth noting that every cloud data warehouse has some support for this workflow: Snowflake, BigQuery, Redshift, Databricks, etc.
Since we are using Snowflake, we'll provide a quick solution using
snowsql that involves loading the files into a staging area and then copying them into tables in Snowflake. Before running the following script, make sure you create the table definitions in your target Snowflake schema. We have a quick DDL you can run in Snowflake here. In our example, we'll use
KKBOX.CHURN as the main schema for the data. Then, you can run the staging script in
CREATE STAGE IF NOT EXISTS customer_churn;

-- Upload a local file to the stage; repeat for every local file
PUT file:///path/to/file/transactions.csv @customer_churn;

-- Copy the staged (gzip-compressed) file into its table; repeat for every staged file
COPY INTO KKBOX.CHURN.TRANSACTIONS
  FROM @customer_churn
  files = ('transactions.csv.gz')
  file_format = (type = CSV skip_header = 1);

-- Check that your table now has data
SELECT * FROM KKBOX.CHURN.TRANSACTIONS LIMIT 10;
The above doesn't work in the Snowflake web UI. You'll need to execute it via snowsql.
Snowflake will compress your staged files, so don't forget to add the
.gz to the end of your file names during the
COPY INTO command.
When finished, don't forget to delete the internal stage.
Now we celebrate! We have our source data in our data warehouse and the hardest part of the tutorial is over. If you're following along with the naming schema in the tutorial you should now have three tables populated:
KKBOX.CHURN.USER_LOGS. Now let's get started with our ML use case.
Prediction & Feature Engineering¶
Our companion blog post covers how to define churn, and we recommend giving it a quick read. To recap, for any churn problem we would want to know:
What is the prediction period? I.e., how often will we be making churn predictions? This creates a natural cutoff date where we can make predictions for the rest of the period and aggregate features up to this point.
What is the churn threshold? I.e., how many days past expiration will we allow a customer before we consider it a churn?
What is our time range for each transaction? I.e., what are the start and end dates? When we are making a new prediction, we only consider customers who have an expiring service/contract during the associated window.
These concepts are visualized in the following graphic:
How you pick these will often be decided by your business needs. For our tutorial, we'll select the following:
Our prediction period is: every month.
Our churn threshold is: 30 days.
The TRANSACTIONS table thankfully has a
membership_expire_date that we can use as start and end dates.
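To make the definition concrete, here is a minimal Python sketch of how a churn label could be derived from these choices. It is not the SQL used in the GitHub repo: the tuple layout `(user, start_date, expire_date)`, the function name `label_churn`, and the sample rows are all assumptions for illustration, and the as-of date simply stands in for "today".

```python
from collections import defaultdict
from datetime import date, timedelta

CHURN_THRESHOLD_DAYS = 30      # churn threshold chosen in this tutorial
AS_OF = date(2017, 4, 1)       # assumed "today" for this sketch


def label_churn(transactions, as_of=AS_OF, threshold_days=CHURN_THRESHOLD_DAYS):
    """Return (user, expire_date, is_churn) rows.

    is_churn is None when the churn window is still open and no label exists yet.
    """
    by_user = defaultdict(list)
    for user, start, expire in transactions:
        by_user[user].append((start, expire))

    labels = []
    for user, rows in by_user.items():
        rows.sort()  # order each user's transactions by start date
        for i, (_start, expire) in enumerate(rows):
            deadline = expire + timedelta(days=threshold_days)
            next_start = rows[i + 1][0] if i + 1 < len(rows) else None
            if next_start is not None and next_start <= deadline:
                is_churn = False   # renewed inside the threshold
            elif as_of > deadline:
                is_churn = True    # window closed without a renewal
            else:
                is_churn = None    # too early to tell
            labels.append((user, expire, is_churn))
    return labels


# Hypothetical sample rows: (user, start_date, expire_date)
rows = [
    ("u1", date(2016, 12, 1), date(2016, 12, 31)),
    ("u1", date(2017, 1, 5), date(2017, 2, 4)),     # renewed 5 days after expiry
    ("u2", date(2016, 12, 1), date(2016, 12, 31)),  # never renewed
    ("u3", date(2017, 3, 1), date(2017, 3, 31)),    # expiry too recent to label
]
labels = label_churn(rows)
for row in labels:
    print(row)
```

The `None` case matters in practice: rows whose churn window has not closed yet should be predicted on, not trained on.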
Now that we have our prediction defined, we can build our tables. For those that want to fast forward to the results, our GitHub repo has a dbt project. You can use this to quickly build all the required tables for this use case. This should more or less run as is, but you'll need to either create a
continual profile in your
profiles.yml file or change the profile in the
dbt_project.yml file. As you begin to explore this use case more you may also be interested in modifying the variables set in
dbt_project.yml, and if you deviated from our naming convention above, you'll also want to update the table name in the
sources.yaml file. When ready, simply execute:
For everyone else, you'll want to run the SQL script in the GitHub repo in Snowflake. This will create all the required tables for Continual to build a churn model. You should be able to cut and paste the script if you've followed along with our naming schema. If you've made any deviations, update the table names as necessary in the script and then run it in Snowflake.
Note that user_logs_agg is a rather large table that can take some time to build if you didn't sample your data before uploading it to the data warehouse.
After running the dbt job or the
feature-engineering.sql script, we'll primarily be interested in the following views/tables:
members_all: This has all user demographic data.
transactions_final: This has all transaction data.
user_logs_all: This has all user log data.
user_logs_agg: This has user log data aggregated up to each prediction period.
churn_model_definition: This contains all our churn labels.
The next step will be to register these resources into Continual and begin building a model.
Some may be interested in the operations we're performing here. We'll provide some quick insights into what's going on. If you're not interested in the nitty-gritty, feel free to skip to the next section.
More Details on Prediction & Feature Engineering¶
When inspecting transactions in the source table, you may notice that transactions are not completely sequential or disjoint. At first, you may assume they should be: i.e. a customer joins for
X months, at the end of that experience, they either renew or churn. However, customers can decide to cancel their service at any time, and these are logged as transactions and will have the
is_cancel field populated. Further inspection reveals that it's also possible that a user cancels and then immediately signs up for another service. This is essentially an "upgrade" transaction, but it gets split into a couple of rows in the database.
An observation we can make is that the cancellations are not actually important to building the prediction definition. The main thing we need for training purposes in the
churn_model_definition table is to know when the next transaction occurs for every transaction and compare that to the expiration date in the row to see if that happens inside or outside the churn threshold. Assuming a cancellation happens, one of two things follows:
A subsequent transaction occurs that represents an upgrade. If we remove the cancellation, the original transaction would compare its expiration date to the start date of the upgrade transaction. We would conclude that no churn has happened, so this is great.
No subsequent transaction occurs. If we remove the cancellation, the original transaction would not find the next transaction and we would conclude that churn occurred. This is also great.
So, we can safely remove these and it doesn't mess up our data at all (in fact, the expiration date seems to be inconsistently populated for cancellations, so this also alleviates another issue we'd otherwise have to deal with). This is what is hashed out in the
Another thing you'll notice about transactions, if you stare at the data long enough, is that most customers tend to subscribe one month at a time or longer, but some of the data is on the sub-month or week level. We want our prediction period to be a month, so to handle these we group any transactions that occur in the same month together and aggregate the values. If you don't handle this situation, our
time_index will be non-unique for some transactions and this will result in some data blowing up as Continual starts building training data sets (see: joining feature sets for more info).
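The monthly grouping can be sketched in a few lines of Python. This is only an illustration, not the repo's SQL: the tuple layout `(msno, transaction_date, amount_paid, expire_date)` and the choice of aggregates (sum of payments, latest expiration, row count) are assumptions, since the text only says the values are aggregated.

```python
from collections import defaultdict
from datetime import date


def aggregate_by_month(transactions):
    """Collapse each user's transactions within a calendar month into one row."""
    groups = defaultdict(list)
    for msno, txn_date, amount, expire in transactions:
        month = txn_date.replace(day=1)  # first day of the month acts as the time index
        groups[(msno, month)].append((amount, expire))

    rows = []
    for (msno, month), items in sorted(groups.items()):
        rows.append({
            "msno": msno,
            "time_index": month,
            "total_paid": sum(amount for amount, _ in items),   # assumed aggregate
            "expire_date": max(expire for _, expire in items),  # assumed aggregate
            "n_transactions": len(items),
        })
    return rows


# Hypothetical weekly subscriber: three transactions in January, one in February.
weekly = [
    ("u1", date(2017, 1, 2), 50, date(2017, 1, 9)),
    ("u1", date(2017, 1, 9), 50, date(2017, 1, 16)),
    ("u1", date(2017, 1, 16), 50, date(2017, 1, 23)),
    ("u1", date(2017, 2, 1), 149, date(2017, 3, 1)),
]
monthly = aggregate_by_month(weekly)
for row in monthly:
    print(row)
```

After aggregation, each `(msno, time_index)` pair appears exactly once, which is the property the join logic relies on.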
One additional thing we can do is featurize cancellations by computing whether the most recent transaction for a customer, as of the cutoff date, was a cancellation. Since this happens strictly before the cutoff date, we are not leaking information from the future into the model (which is very important). There are scenarios where a customer cancels before the cutoff date and, if this isn't part of an upgrade, there is no next transaction. When we go to make the prediction, we then know that the last contact with the customer was a cancellation, which seems like very relevant information regarding their likelihood of churning. It could also be the case that a cancellation occurs and the customer still renews at a later date. We could alternatively treat cancellations as a separate interaction altogether and filter them out of our normal churn analysis. Depending on your requirements, there are various actions you can take on this information.
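A rough Python sketch of this cancellation feature follows. The tuple layout `(msno, transaction_date, is_cancel)` and the function name are hypothetical; the only rule taken from the text is that transactions on or after the cutoff are ignored so nothing leaks from the future.

```python
from datetime import date


def last_was_cancel(transactions, cutoff):
    """Return {msno: bool}: was the latest transaction strictly before cutoff a cancellation?"""
    latest = {}
    for msno, txn_date, is_cancel in transactions:
        if txn_date >= cutoff:
            continue  # on or after the cutoff: skip to avoid leaking future information
        if msno not in latest or txn_date > latest[msno][0]:
            latest[msno] = (txn_date, is_cancel)
    return {msno: flag for msno, (_, flag) in latest.items()}


# Hypothetical rows: (msno, transaction_date, is_cancel)
history = [
    ("u1", date(2017, 1, 3), False),
    ("u1", date(2017, 1, 20), True),   # cancelled before the cutoff
    ("u2", date(2017, 1, 10), False),
    ("u2", date(2017, 2, 5), True),    # cancelled after the cutoff: must be ignored
]
feature = last_was_cancel(history, cutoff=date(2017, 2, 1))
print(feature)
```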
One last note! The Kaggle data set has a "v2" update that refreshes the data as of April 1st, 2017. So, for all intents and purposes, "today's date" is stuck at that date. We use that in our model definition to determine how many days past expiration a transaction is. As mentioned before, we have a lot of data here! More than we need to train on, actually. To prevent bringing in all the data during model training, we create a user-defined split in our model to control the size of the training (6 months), validation (1 month), and test (1 month) data sets. The test set ends in February of 2017: since today's date is April 1st, 2017, not enough time has passed to know whether transactions expiring in March churned or not.
In the real world, you're likely to hit several problems like this as well, so it's always a good idea to review the results of the Data Profiling to see if you can detect any oddities.
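For illustration, the user-defined split described above could be expressed as a simple mapping from prediction month to split name. This sketch assumes the three splits are contiguous and ordered train, validation, test, which puts training in July through December 2016, validation in January 2017, and test in February 2017. Only the test month is stated in the text; the other boundaries, and the names `TRAIN` and `TEST`, are assumptions.

```python
from datetime import date


def assign_split(month_start):
    """Map the first day of a prediction month to a split name, or None to exclude it."""
    if date(2016, 7, 1) <= month_start <= date(2016, 12, 1):
        return "TRAIN"   # assumed: the 6 months before validation
    if month_start == date(2017, 1, 1):
        return "VALI"    # assumed: the month before the test month
    if month_start == date(2017, 2, 1):
        return "TEST"    # the test set ends in February 2017
    return None          # older data, or months that cannot be labeled yet


for month in (date(2016, 6, 1), date(2016, 9, 1), date(2017, 1, 1),
              date(2017, 2, 1), date(2017, 3, 1)):
    print(month, assign_split(month))
```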
Building Feature Sets & a Model¶
Now we are ready to build our feature sets and model. If you are following along with dbt, you can simply execute the following on top of the dbt project:
For everyone else, the GitHub repo contains a folder with all the Continual
yaml files we need. If you navigate to the
./continual directory you can simply execute the following:
continual push ./featuresets/ ./models/
If you have changed the table locations in any previous steps, you'll need to modify the
query fields in these yaml files with the new location of the tables.
Congratulations, you've built your first churn model! Whichever command you execute will print a link to the Continual Web UI where you can monitor the status of the job. The model will take a few hours to complete, so you can refresh your coffee, respond to some emails, and read some web comics while you wait.
To get a little more into the details, regardless of the approach taken, four feature sets and one model are created in Continual. These correspond to the views/tables created in the previous step, and are summarized by this diagram (which you can also check out on your Model "Schema" tab):
The four feature sets are registered in the
kkbox_user entity. The model connects to this entity via the
msno column, which acts as the index for all feature sets in the entity. Continual will then combine all the data into a training set for the model. In a real-world example, these tables may be broken into many entities:
sales, etc. For the purposes of this example they are all identified by the
msno, so we've decided to keep them in one entity rather than fabricating more entities.
Evaluating the Model¶
Before we discuss the performance of our model, it's always good to have an idea of a baseline to compare it to. If you query the
churn_model_definition table you'll find that the churn percentage for any month is typically around 8-10%. This represents an imbalanced data set. In our blog, we specify that for these problems we wish to avoid accuracy as the model performance metric because it may lead the system to always predict no churn. As a baseline model, we know that a model that always predicts no churn is correct approximately 90% of the time. This sounds pretty good, but this model never predicts churn, so it's actually useless to the business.
It's always important to remember business context when working through an ML problem. If we are starting this example from scratch, our business may be interested in us correctly predicting any churn. Anything is better than nothing, so let's take our victories as baby steps. More sophisticated businesses may have more demanding requirements. If we're currently capturing 10% of churn, we might be enthused by bumping that number up to 20%, for example. However, it's also important to temper expectations and impose some reality checks on the process. An aggressive business could ask "why can't we catch all churn?" And, indeed we could do that: here's a model that always predicts churn. I predicted all churned customers correctly, but at the expense of putting everyone else in the churn bucket. From an operational standpoint, this is also a useless model.
A more reasonable baseline model may be simply predicting churn at the same frequency within which it manifests itself in the training data set. If we take 10% churn to be our expected value, then this model would:
Correctly guess "no churn" (i.e. True Negative) 81% of the time (i.e. 90% chance of guessing no churn * 90% chance of it not actually being churn).
Correctly guess "churn" (i.e. True Positive) 1% of the time (i.e. 10% chance of guessing churn * 10% chance of it actually being churn).
Incorrectly guess "churn" (i.e. False Positive) 9% of the time and incorrectly guess "no churn" (i.e. False Negative) 9% of the time (i.e. 10% or 90% chance of guessing * 90% or 10% chance of it actually being the opposite).
We can visualize this analysis in a confusion matrix.
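The arithmetic behind this baseline is easy to reproduce. The short Python sketch below (function name is our own) multiplies the chance of each guess by the chance of each actual outcome, assuming the guess is independent of the outcome and uses the same 10% churn rate.

```python
def random_baseline_matrix(churn_rate):
    """Expected confusion-matrix shares for a model that guesses churn at the base rate."""
    p = churn_rate
    return {
        "tn": (1 - p) * (1 - p),  # guessed no churn, actually no churn
        "fp": p * (1 - p),        # guessed churn, actually no churn
        "fn": (1 - p) * p,        # guessed no churn, actually churn
        "tp": p * p,              # guessed churn, actually churn
    }


matrix = random_baseline_matrix(0.10)
for cell, share in matrix.items():
    print(f"{cell}: {share:.0%}")
```

The four shares always sum to 1, which is a quick sanity check when trying other churn rates.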
The business is likely most interested in predicting as much churn as possible. So, a metric we should consider to satisfy the business is recall -- which in this use case measures how often we predict churn out of the entire churn population (i.e. True Positive/(True Positive + False Negative)). For the sake of argument, let's say that we're given a goal of hitting a recall of 40%. The thing the business cares about next is how many churn predictions are not correct -- i.e. the number of false positives that our model produces. This is captured by the precision metric (i.e. True Positive/(True Positive + False Positive)). Again, let's assume that they are looking for a number in the area of 40% here as well. In some businesses they'll actually provide real dollar values to each outcome and you can optimize your strategy using these.
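As a quick check on these definitions, the following Python sketch computes precision and recall for the three naive strategies discussed earlier, using expected counts per 100 customers at a 10% churn rate. The counts are derived from that assumed rate, not measured from the dataset.

```python
def precision_recall(tp, fp, fn):
    """Return (precision, recall); a value is None when its denominator is zero."""
    precision = tp / (tp + fp) if (tp + fp) else None
    recall = tp / (tp + fn) if (tp + fn) else None
    return precision, recall


# Expected counts per 100 customers, assuming a 10% churn rate.
baselines = {
    "never predicts churn": precision_recall(tp=0, fp=0, fn=10),
    "always predicts churn": precision_recall(tp=10, fp=90, fn=0),
    "guesses at the base rate": precision_recall(tp=1, fp=9, fn=9),
}
for name, (precision, recall) in baselines.items():
    print(f"{name}: precision={precision}, recall={recall}")
```

None of the three comes close to a 40% target on both metrics at once, which is why they are useful as floors rather than goals.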
When your model has finished training, you can open up the model version from the "Change" page and start to inspect the results. The project in the GitHub repo sets up the model to optimize via the ROC AUC, which is generally a solid starting point for imbalanced problems like churn. On our model version overview page, we see an ensemble method performed the best in this run with an AUC of 0.90 on the validation set.
By navigating over to the model insights page, we can check out the confusion matrix.
We can see that our recall is just under 27% and our precision is at 70% for the test set. This is a pretty encouraging start. This model is doing much better than our baseline model of random guessing, but we are still a little short of our business expectations.
One thing to notice is that our precision is very good here, and we have a lot of room to move this value around. By default, the system uses a threshold of 0.5 for binary classification problems. What this means is when the model makes a prediction it creates a weighted score for each class (in this case
True or False). The class with the highest weight (in a binary problem this would be weight over 0.5) is what Continual uses as the prediction, but it's perfectly acceptable to override this value with your own threshold. By moving the threshold down, the model will begin making more churn predictions. Some of these will be correct, some will be incorrect. So, we'll be increasing the number of True Positives and False Positives, but also decreasing the number of False Negatives. Essentially, we're trading precision for higher recall. We can quickly perform this exercise to determine a threshold value that satisfies our business criteria.
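The trade-off is easy to see on a toy example. The Python sketch below is not Continual's implementation: the scores and labels are made up, and the helper names are our own. It recomputes precision and recall at several thresholds and picks the highest threshold that meets a minimum precision and recall.

```python
def precision_recall_at(scores, actuals, threshold):
    """Precision and recall when every score >= threshold is predicted as churn."""
    pairs = list(zip(scores, actuals))
    tp = sum(1 for s, a in pairs if s >= threshold and a)
    fp = sum(1 for s, a in pairs if s >= threshold and not a)
    fn = sum(1 for s, a in pairs if s < threshold and a)
    precision = tp / (tp + fp) if (tp + fp) else None
    recall = tp / (tp + fn) if (tp + fn) else None
    return precision, recall


def highest_threshold_meeting(scores, actuals, min_precision, min_recall, candidates):
    """Return the largest candidate threshold satisfying both minimums, else None."""
    for threshold in sorted(candidates, reverse=True):
        precision, recall = precision_recall_at(scores, actuals, threshold)
        if precision is None or recall is None:
            continue
        if precision >= min_precision and recall >= min_recall:
            return threshold
    return None


# Made-up churn scores and true labels for ten customers.
scores = [0.95, 0.8, 0.6, 0.45, 0.4, 0.35, 0.3, 0.1, 0.05, 0.02]
actuals = [True, True, False, True, False, False, True, False, False, False]

for t in (0.5, 0.4, 0.3):
    print(t, precision_recall_at(scores, actuals, t))

chosen = highest_threshold_meeting(
    scores, actuals, min_precision=0.5, min_recall=0.7, candidates=[0.5, 0.4, 0.3]
)
print("chosen threshold:", chosen)
```

In this toy data, lowering the threshold from 0.5 to 0.3 raises recall from 0.5 to 1.0 while precision slips from about 0.67 to about 0.57. As with the real model, the threshold should be chosen on the validation split and then confirmed on the test split.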
For classification problems, Continual provides a score associated with the prediction that you can use to modify the predictions themselves. If you're new to Continual and unsure where Continual is storing your predictions, you can find that information on the Overview page of your batch_prediction. This will include a query like the one below which you can use in your data warehouse to access your predictions.
SELECT * FROM my_feature_store.customer_churn_example.model_user_churn_predictions_history WHERE batch_prediction='projects/customer_churn_example/models/user_churn/batchPredictions/c7sc4kokgv0kd9dhis3g'
We can modify this query slightly to see how changing the threshold affects precision and recall. (Note: you'd typically want to set the threshold based on the values in the validation set.)
with predictions as (
  SELECT
    msno,
    ts,
    features:split as split,
    to_boolean(features:is_churn) as is_churn_actual,
    case when is_churn_true_prediction_score >= <threshold_value> then True else False end as is_churn_prediction,
    is_churn_true_prediction_score
  FROM my_feature_store.customer_churn_example.model_user_churn_predictions_history
  WHERE batch_prediction='projects/customer_churn_example/models/user_churn/batchPredictions/c7sc4kokgv0kd9dhis3g'
    and split = 'VALI'
)
select
  count_if(is_churn_actual = False and is_churn_prediction = False) as tn,
  count_if(is_churn_actual = False and is_churn_prediction = True) as fp,
  count_if(is_churn_actual = True and is_churn_prediction = False) as fn,
  count_if(is_churn_actual = True and is_churn_prediction = True) as tp,
  tp / (tp + fp) as precision,
  tp / (tp + fn) as recall
from predictions
We know our business wants a recall of at least 40%, so we can simply decrease the threshold value in
when is_churn_true_prediction_score >= <threshold_value> then True until our recall reaches above 40%. With my model, I find with a threshold of 0.425, our recall is right around 40% and our precision is around 70%. Our business specified precision needed to be above 40%, so this is a satisfactory value. A bit more experimentation reveals that the precision stays above 40% until about a threshold of 0.20 (with a recall of almost 70% in this case!). So, this gives us a good threshold range to work with. If the business wants to capture more churn, we can do so, but at the expense of allowing more false positives. On this particular run, we find that a threshold set at 0.35 on the validation set has a precision and recall over 40% and also translates to a precision and recall over 40% on the test set. This is the threshold recommendation that we would make.
With a bit of quick SQL work, we can create the following table for the business to use to analyze predictions with the desired threshold:
create or replace view my_feature_store.customer_churn_example.model_user_churn_predictions_final as (
  SELECT
    msno,
    ts,
    case when is_churn_true_prediction_score >= 0.35 then True else False end as is_churn_prediction
  FROM my_feature_store.customer_churn_example.model_user_churn_predictions
);
Congratulations! This concludes the customer churn tutorial. Live long and prosper 🖖, and get in touch if you have any questions.