library(tidyverse)
## ── Attaching packages ──────────────────
## ✔ ggplot2 3.2.1 ✔ purrr 0.3.2
## ✔ tibble 2.1.3 ✔ dplyr 0.8.3
## ✔ tidyr 1.0.0 ✔ stringr 1.4.0
## ✔ readr 1.3.1 ✔ forcats 0.4.0
## ── Conflicts ── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
library(knitr)
library(pROC)
## Type 'citation("pROC")' for a citation.
##
## Attaching package: 'pROC'
## The following objects are masked from 'package:stats':
##
## cov, smooth, var
library(rpart)
options(scipen = 4)
Unless you’ve already taken a class on data mining or machine learning, a lot of the analytics tasks you’ve undertaken have probably taken the form of “inference” problems: estimating the association between an input and an outcome, or testing whether some coefficient of interest is statistically significantly different from zero.
These are all what Mullainathan and Spiess call “\(\hat \beta\)” problems (“beta-hat problems”). Essentially, these are all problems that begin with you writing down a model
\[ y = X\beta + \epsilon, \] estimating \(\beta\), and making some conclusions about the world based on those estimates and corresponding statistical significance analyses.
Prediction problems are different. When we’re doing prediction, we aren’t interested in \(\beta\). Instead, we’re interested in being able to accurately predict \(y\) from information \(x\). These are what M&S call “\(\hat y\)” problems (“y-hat problems”). Prediction is a very useful paradigm to know about for a number of reasons. First, prediction problems are ubiquitous, even in policy settings. Second, prediction is much easier than inference. Whereas inference often relies on various assumptions holding, prediction is largely assumption-free.
These notes introduce you to prediction in its most common form—binary classification—and teach you the basics of training and evaluating different classifiers.
Many of the problems you’ll come across in the future will likely be classification problems (and if they’re not, you can typically turn them into classification problems). These are problems where your outcome variable \(y\) is binary or categorical. E.g., \(y\) might be the indicator that an email is spam, that a transaction is fraudulent, or that a student graduates from college. There are certainly notable cases where a good prediction of a quantitative outcome is desired (e.g., stock prices, home values, crop yields). But “most” problems do wind up being ones of classification.
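As a small illustration of that last point (a sketch using the built-in mtcars data, not anything from these notes), a quantitative outcome can be recast as a binary one simply by thresholding it, after which any classifier applies.
# Dichotomize a quantitative outcome: is this car's mpg above the median?
mtcars$high_mpg <- as.numeric(mtcars$mpg > median(mtcars$mpg))
table(mtcars$high_mpg)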
Let’s get started. We’ll use a marketing data set where we have observations on whether bank customers who were contacted by the bank’s sales team opened a term deposit (“subscribed”). Let’s start by loading the data.
marketing <- read_delim("http://www.andrew.cmu.edu/user/achoulde/94842/data/bank-full.csv",
delim = ";")
## Parsed with column specification:
## cols(
## age = col_double(),
## job = col_character(),
## marital = col_character(),
## education = col_character(),
## default = col_character(),
## balance = col_double(),
## housing = col_character(),
## loan = col_character(),
## contact = col_character(),
## day = col_double(),
## month = col_character(),
## duration = col_double(),
## campaign = col_double(),
## pdays = col_double(),
## previous = col_double(),
## poutcome = col_character(),
## y = col_character()
## )
What does the data contain?
str(marketing)
## Classes 'spec_tbl_df', 'tbl_df', 'tbl' and 'data.frame': 45211 obs. of 17 variables:
## $ age : num 58 44 33 47 33 35 28 42 58 43 ...
## $ job : chr "management" "technician" "entrepreneur" "blue-collar" ...
## $ marital : chr "married" "single" "married" "married" ...
## $ education: chr "tertiary" "secondary" "secondary" "unknown" ...
## $ default : chr "no" "no" "no" "no" ...
## $ balance : num 2143 29 2 1506 1 ...
## $ housing : chr "yes" "yes" "yes" "yes" ...
## $ loan : chr "no" "no" "yes" "no" ...
## $ contact : chr "unknown" "unknown" "unknown" "unknown" ...
## $ day : num 5 5 5 5 5 5 5 5 5 5 ...
## $ month : chr "may" "may" "may" "may" ...
## $ duration : num 261 151 76 92 198 139 217 380 50 55 ...
## $ campaign : num 1 1 1 1 1 1 1 1 1 1 ...
## $ pdays : num -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 ...
## $ previous : num 0 0 0 0 0 0 0 0 0 0 ...
## $ poutcome : chr "unknown" "unknown" "unknown" "unknown" ...
## $ y : chr "no" "no" "no" "no" ...
## - attr(*, "spec")=
## .. cols(
## .. age = col_double(),
## .. job = col_character(),
## .. marital = col_character(),
## .. education = col_character(),
## .. default = col_character(),
## .. balance = col_double(),
## .. housing = col_character(),
## .. loan = col_character(),
## .. contact = col_character(),
## .. day = col_double(),
## .. month = col_character(),
## .. duration = col_double(),
## .. campaign = col_double(),
## .. pdays = col_double(),
## .. previous = col_double(),
## .. poutcome = col_character(),
## .. y = col_character()
## .. )
marketing <- marketing %>%
mutate(y = as.numeric(y == "yes"))
Our outcome variable here is y, whether or not a person opens an account. You’ll see above that we transformed the original yes/no y into an indicator that a person subscribes.
You’re already familiar with linear regression, and you could certainly regress y on the other variables in the data using linear regression to construct a model. That approach turns out to be ill-suited to the analysis of binary outcome data. Among other things, when \(y\) is either \(0\) or \(1\), it is odd to fit a line to predict this type of outcome. For one, the linear model generally won’t be constrained to predict values between 0 and 1; it can give negative values or values greater than 1. Linear increases on the probability scale are also unlikely to be a good description of the association between a given input \(x\) and the outcome \(y\).
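To see the first point concretely, here is a quick check you could run (a sketch, not part of the analysis; the exact fitted range will depend on the data).
# Linear probability model: fitted values are not constrained to [0, 1]
marketing.lpm <- lm(y ~ ., data = marketing)
range(fitted(marketing.lpm))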
The standard approach to modeling binary outcome data in regression is to use logistic regression. This is an example of a generalized linear model (GLM). GLMs generalize what you know about linear regression to outcome variable types that aren’t well modeled by the “gaussian” linear model \(y = X\beta + \epsilon\). Binary outcomes \(y\) clearly aren’t generated by this sort of process.
The setup for logistic regression is essentially as follows. The observed outcome \(y\) for an observation with features \(x\) is thought to come from a Bernoulli\((p(x))\) random variable, where the success probability \(p(x)\) is parameterized by
\[ \log\left( \frac{p}{1 - p} \right) = \beta_0 + \beta_1 x_1 + \dots + \beta_p x_p \] Essentially this is saying that instead of modeling \(y\) (or, technically, \(E(y | x)\)) as a linear function of \(x\), we’ll model \(y\) as a Bernoulli realization where the success probability \(p\) is a function of \(x\). The “linear” part of the generalized linear model is what’s being illustrated in the above expression: A transformation of \(p\) is being modeled as a linear function of \(x\). You can compare this with the standard linear regression model, which says that \(y \sim N(\mu(x), \sigma^2)\), where the mean \(\mu = E(Y \mid x)\) is a linear function of \(x\):
\[ \mu = \beta_0 + \beta_1 x_1 + \dots + \beta_p x_p \]
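If the log-odds transformation feels abstract, here is a small numeric sketch (not part of the analysis): in R, qlogis() is the log-odds function above and plogis() is its inverse, so a linear predictor of 0 corresponds to a success probability of 0.5.
eta <- c(-4, -1, 0, 1, 4)   # hypothetical values of the linear predictor
plogis(eta)                 # implied success probabilities p
qlogis(plogis(eta))         # recovers the log-odds, log(p / (1 - p))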
So let’s fit a logistic regression model and look at what we find.
marketing.glm <- glm(y ~ ., data = marketing, family = binomial())
summary(marketing.glm)
##
## Call:
## glm(formula = y ~ ., family = binomial(), data = marketing)
##
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -5.7286 -0.3744 -0.2530 -0.1502 3.4288
##
## Coefficients:
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) -2.535637780 0.183703164 -13.803 < 2e-16 ***
## age 0.000112719 0.002205165 0.051 0.959233
## jobblue-collar -0.309872593 0.072669201 -4.264 2.01e-05 ***
## jobentrepreneur -0.357103762 0.125564459 -2.844 0.004455 **
## jobhousemaid -0.504001652 0.136469021 -3.693 0.000221 ***
## jobmanagement -0.165278440 0.073292526 -2.255 0.024130 *
## jobretired 0.252362639 0.097217516 2.596 0.009436 **
## jobself-employed -0.298336079 0.111996400 -2.664 0.007726 **
## jobservices -0.223797106 0.084064904 -2.662 0.007763 **
## jobstudent 0.382135715 0.109029897 3.505 0.000457 ***
## jobtechnician -0.176016548 0.068931178 -2.554 0.010664 *
## jobunemployed -0.176713126 0.111642461 -1.583 0.113456
## jobunknown -0.313264379 0.233463307 -1.342 0.179656
## maritalmarried -0.179453495 0.058910580 -3.046 0.002318 **
## maritalsingle 0.092497647 0.067260667 1.375 0.169066
## educationsecondary 0.183528258 0.064792557 2.833 0.004618 **
## educationtertiary 0.378941502 0.075319068 5.031 4.88e-07 ***
## educationunknown 0.250478833 0.103896567 2.411 0.015915 *
## defaultyes -0.016681215 0.162837013 -0.102 0.918407
## balance 0.000012835 0.000005148 2.493 0.012651 *
## housingyes -0.675384337 0.043869060 -15.395 < 2e-16 ***
## loanyes -0.425371663 0.059989904 -7.091 1.33e-12 ***
## contacttelephone -0.163374330 0.075185612 -2.173 0.029784 *
## contactunknown -1.623216856 0.073171806 -22.184 < 2e-16 ***
## day 0.009968922 0.002496619 3.993 6.53e-05 ***
## monthaug -0.693907553 0.078474461 -8.842 < 2e-16 ***
## monthdec 0.691124324 0.176682753 3.912 9.17e-05 ***
## monthfeb -0.147320938 0.089413545 -1.648 0.099427 .
## monthjan -1.261718795 0.121702801 -10.367 < 2e-16 ***
## monthjul -0.830795589 0.077404978 -10.733 < 2e-16 ***
## monthjun 0.453622601 0.093669266 4.843 1.28e-06 ***
## monthmar 1.589890543 0.119853742 13.265 < 2e-16 ***
## monthmay -0.399111424 0.072285121 -5.521 3.36e-08 ***
## monthnov -0.873398521 0.084409802 -10.347 < 2e-16 ***
## monthoct 0.881437433 0.108030525 8.159 3.37e-16 ***
## monthsep 0.874058052 0.119497320 7.314 2.58e-13 ***
## duration 0.004193695 0.000064532 64.986 < 2e-16 ***
## campaign -0.090781782 0.010137033 -8.955 < 2e-16 ***
## pdays -0.000102685 0.000306089 -0.335 0.737268
## previous 0.010152353 0.006502908 1.561 0.118476
## poutcomeother 0.203478400 0.089855382 2.265 0.023543 *
## poutcomesuccess 2.291056017 0.082348964 27.821 < 2e-16 ***
## poutcomeunknown -0.091793506 0.093474710 -0.982 0.326093
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 32631 on 45210 degrees of freedom
## Residual deviance: 21562 on 45168 degrees of freedom
## AIC: 21648
##
## Number of Fisher Scoring iterations: 6
# Initialize the seed for random number generation
set.seed(12345)
# Upsample the data to artificially overcome sample imbalance
marketing.more.idx <- sample(which(marketing$y == 1), 15000, replace = TRUE)
marketing.upsample <- rbind(marketing,
marketing[marketing.more.idx, ])
# Trim job strings to 5 characters
# marketing.upsample <- transform(marketing.upsample, job = strtrim(job, 5))
# Randomly select 20% of the data to be held out for model validation
test.indexes <- sample(1:nrow(marketing.upsample),
round(0.2 * nrow(marketing.upsample)))
train.indexes <- setdiff(1:nrow(marketing.upsample), test.indexes)
# Just pull the covariates available to marketers (cols 1:8) and the outcome (col 17)
marketing.train <- marketing.upsample[train.indexes, c(1:8, 17)]
marketing.test <- marketing.upsample[test.indexes, c(1:8, 17)]
When building models it is important that we hold out a subset of our data, typically called a “test set”, “validation set”, or “holdout set”. In this example we’re holding out a random 20% of our data. The purpose of this test set is to ensure that we get reasonable estimates of the prediction accuracy of our model even if we make mistakes during our “training” process that result in “overfitting”.
When you have a large number of covariates, it’s easy to overfit the data. When you overfit the training data, you get a model that describes the training data really well, but which doesn’t give good predictions on unseen data.
(Figure source: http://pingax.com/regularization-implementation-r/)
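To make the overfitting idea concrete, here is a hedged sketch (not part of the analysis that follows) that grows a deliberately over-flexible tree and compares its accuracy on the training data with its accuracy on the held-out test data; a large gap between the two is the signature of overfitting.
# Sketch only: grow an intentionally over-flexible tree (tiny minsplit, near-zero cp)
overfit.tree <- rpart(as.factor(y) ~ ., data = marketing.train,
                      control = rpart.control(minsplit = 5, cp = 0.0001))
# Helper: threshold the predicted probability of class "1" at 0.5
accuracy <- function(model, df) {
  phat <- predict(model, newdata = df)[, "1"]
  mean(as.numeric(phat > 0.5) == df$y)
}
accuracy(overfit.tree, marketing.train)  # typically very high on the training data...
accuracy(overfit.tree, marketing.test)   # ...and noticeably lower on held-out data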
library(glmnet) # Regularized regression
## Loading required package: Matrix
##
## Attaching package: 'Matrix'
## The following objects are masked from 'package:tidyr':
##
## expand, pack, unpack
## Loaded glmnet 3.0-1
library(ranger) # random forests
We’ll start by fitting a logistic regression model to the training data.
marketing.glm <- glm(y ~ ., data = marketing.train, family = binomial())
pred.test.glm <- as.numeric(predict(marketing.glm, newdata = marketing.test, type = "response") > 0.5)
The code above fits a logistic regression model to the training data and then gets predicted probabilities for the test data. The as.numeric(... > 0.5) step thresholds those probabilities at 0.5 to form predictions of whether each person subscribes.
# Confusion matrix for logistic regression
conf.glm <- table(marketing.test$y, pred.test.glm)
conf.glm
## pred.test.glm
## 0 1
## 0 7302 724
## 1 3005 1011
# How accurate is our model?
sum(diag(conf.glm)) / sum(conf.glm)
## [1] 0.6903338
That’s way better than 50%! But… is 50% accuracy really the baseline we want? You often hear that something is “better than a coin flip” or “no better than a coin flip”. Is a fair coin flip really the right baseline? Generally, no. Let’s look at what fraction of our training data actually subscribed.
mean(marketing.train$y)
## [1] 0.3378314
Hmm… So if we guessed that no one subscribes, our accuracy would already be 0.6621686. That makes our accuracy of 0.6903338 a lot less impressive by comparison.
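For reference, the same kind of baseline can be computed directly on the held-out test set with a one-line sketch.
# Accuracy of the trivial "predict that no one subscribes" rule on the test set
mean(marketing.test$y == 0)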
### Regularized logistic regression, with parameters tuned through cross-validation
# Extract y column
y.marketing <- marketing.train$y
# Get a numeric design matrix x
x.marketing <- model.matrix(~ . - y - 1, data = marketing.train)
x.marketing.test <- model.matrix(~ . - y - 1, data = marketing.test)
# Run cross-validated regularized regression
marketing.cv.glmnet <- cv.glmnet(x.marketing, y.marketing, family = "binomial")
# Have a look at the cv error plot
plot(marketing.cv.glmnet)
Let’s get our predictions for the test data
# Extract predictions from model selected by the 1se rule (simplest model within 1 standard error from the minimum)
pred.test.glmnet <- predict(marketing.cv.glmnet, x.marketing.test, s = "lambda.1se", type = "class")
# Confusion matrix for regularized logistic regression
conf.glmnet <- table(marketing.test$y, pred.test.glmnet)
conf.glmnet
## pred.test.glmnet
## 0 1
## 0 7570 456
## 1 3267 749
How did we do?
sum(diag(conf.glmnet)) / sum(conf.glmnet)
## [1] 0.6908321
Well… that wasn’t any better…
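One thing worth checking (a sketch; output not shown here) is which coefficients the lasso actually kept at the penalty chosen by the 1se rule, and how the two candidate penalty values compare.
# Coefficients that are exactly zero at lambda.1se were dropped by the lasso
coef(marketing.cv.glmnet, s = "lambda.1se")
# Penalty minimizing the CV error vs. the larger "1se" penalty
marketing.cv.glmnet$lambda.min
marketing.cv.glmnet$lambda.1se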
library(partykit)
## Loading required package: grid
## Loading required package: libcoin
## Loading required package: mvtnorm
marketing.tree <- rpart(as.factor(y) ~ ., data = marketing.train,
control = rpart.control(minsplit=50, cp=0.002))
marketing.party <- as.party(marketing.tree)
plot(marketing.party, gp = gpar(fontsize = 10))
pred.test.tree <- as.numeric(predict(marketing.tree, newdata = marketing.test)[,"1"] > 0.5)
# Confusion matrix for tree model
conf.tree <- table(marketing.test$y, pred.test.tree)
conf.tree
## pred.test.tree
## 0 1
## 0 7400 626
## 1 2883 1133
How did we do?
sum(diag(conf.tree)) / sum(conf.tree)
## [1] 0.7086032
That’s a little better.
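If you want to poke at the fitted tree a little more, two quick checks (a sketch; output not shown) are the complexity-parameter table and the split-based variable importance.
# Cross-validated error at each pruning level of the tree
printcp(marketing.tree)
# Variables ranked by how much their splits improved the fit
marketing.tree$variable.importance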
marketing.rf <- ranger(y ~ ., data = marketing.train, importance = 'impurity')
pred.test.rf <- as.numeric(predict(marketing.rf, data = marketing.test)$predictions > 0.5)
# Confusion matrix for random forest model
conf.rf <- table(marketing.test$y, pred.test.rf)
conf.rf
## pred.test.rf
## 0 1
## 0 7450 576
## 1 2275 1741
How did we do?
sum(diag(conf.rf)) / sum(conf.rf)
## [1] 0.7632453
Way better!
But is overall accuracy really what we care about? How will we use this model in the future? Presumably we’ll be using the model to help guide a new marketing campaign. In that case our task will be to select a subset of new customers who we should contact, instead of contacting everyone. How do we think about our model’s performance in that type of setting?
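Here is a hedged sketch of that use case: score the held-out customers with the random forest and contact only those with the highest predicted probability of subscribing. (The budget of 1,000 contacts is made up purely for illustration.)
# Predicted probability of subscribing for each held-out customer
scores <- predict(marketing.rf, data = marketing.test)$predictions
# Keep the 1000 customers the model is most optimistic about
contact.list <- marketing.test[order(scores, decreasing = TRUE)[1:1000], ]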
Here’s a function that calculates a bunch of classification metrics from a model’s confusion table (true classes in the rows, predicted classes in the columns, in the order 0 then 1). We’ll apply it to each of our models.
classSummary <- function(tbl) {
  # tbl: confusion matrix with true classes in rows, predicted classes in
  # columns, both ordered (0, 1)
  n <- sum(tbl)
  prev <- sum(tbl[2,]) / sum(tbl)   # prevalence: fraction of actual positives
  acc <- sum(diag(tbl)) / n         # overall accuracy
  prop.pos <- sum(tbl[,2]) / n      # fraction of cases predicted positive
  ppv <- tbl[2,2] / sum(tbl[,2])    # positive predictive value (precision)
  fpr <- tbl[1,2] / sum(tbl[1,])    # false positive rate
  fnr <- tbl[2,1] / sum(tbl[2,])    # false negative rate
  spec <- 1 - fpr                   # specificity (true negative rate)
  sens <- 1 - fnr                   # sensitivity (true positive rate, recall)
  lr.pos <- sens / fpr              # positive likelihood ratio, TPR / FPR
  lr.neg <- fnr / spec              # negative likelihood ratio, FNR / TNR
  out <- data.frame(value = round(c(n, prev, acc,
                                    prop.pos,
                                    ppv, fpr, fnr, spec, sens,
                                    lr.pos, lr.neg), 3))
  rownames(out) <- c("count",
                     "prevalence",
                     "accuracy",
                     "prop.positive",
                     "PPV",
                     "FPR",
                     "FNR",
                     "Specificity (TNR)",
                     "Sensitivity (TPR)",
                     "LR+",
                     "LR-")
  out
}
classSummary(conf.glm)
## value
## count 12042.000
## prevalence 0.333
## accuracy 0.690
## prop.positive 0.144
## PPV 0.583
## FPR 0.090
## FNR 0.748
## Specificity (TNR) 0.910
## Sensitivity (TPR) 0.252
## LR+ 2.791
## LR- 0.822
classSummary(conf.glmnet)
## value
## count 12042.000
## prevalence 0.333
## accuracy 0.691
## prop.positive 0.100
## PPV 0.622
## FPR 0.057
## FNR 0.813
## Specificity (TNR) 0.943
## Sensitivity (TPR) 0.187
## LR+ 3.283
## LR- 0.862
classSummary(conf.tree)
## value
## count 12042.000
## prevalence 0.333
## accuracy 0.709
## prop.positive 0.146
## PPV 0.644
## FPR 0.078
## FNR 0.718
## Specificity (TNR) 0.922
## Sensitivity (TPR) 0.282
## LR+ 3.617
## LR- 0.779
classSummary(conf.rf)
## value
## count 12042.000
## prevalence 0.333
## accuracy 0.763
## prop.positive 0.192
## PPV 0.751
## FPR 0.072
## FNR 0.566
## Specificity (TNR) 0.928
## Sensitivity (TPR) 0.434
## LR+ 6.041
## LR- 0.610
Let’s bind those together to make them easier to compare
tibble(metric = rownames(classSummary(conf.glm)),
logistic = classSummary(conf.glm)$value,
lasso = classSummary(conf.glmnet)$value,
tree = classSummary(conf.tree)$value,
rf = classSummary(conf.rf)$value)
## # A tibble: 11 x 5
## metric logistic lasso tree rf
## <chr> <dbl> <dbl> <dbl> <dbl>
## 1 count 12042 12042 12042 12042
## 2 prevalence 0.333 0.333 0.333 0.333
## 3 accuracy 0.69 0.691 0.709 0.763
## 4 prop.positive 0.144 0.1 0.146 0.192
## 5 PPV 0.583 0.622 0.644 0.751
## 6 FPR 0.09 0.057 0.078 0.072
## 7 FNR 0.748 0.813 0.718 0.566
## 8 Specificity (TNR) 0.91 0.943 0.922 0.928
## 9 Sensitivity (TPR) 0.252 0.187 0.282 0.434
## 10 LR+ 2.79 3.28 3.62 6.04
## 11 LR- 0.822 0.862 0.779 0.61
marketing.preds <- tibble(glm = predict(marketing.glm, newdata = marketing.test, type = "response"),
lasso = predict(marketing.cv.glmnet, x.marketing.test, s = "lambda.min", type = "response")[,1],
tree = predict(marketing.tree, newdata = marketing.test)[,"1"],
rf = predict(marketing.rf, data = marketing.test)$predictions,
y = marketing.test$y)
Let’s look at ROC curves and the AUC. ROC curves trace out the TPR on the y axis and the FPR on the x axis as we vary the threshold used for classification.
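Before turning to pROC, here is a small sketch of what the curve is tracing: every threshold t gives one (FPR, TPR) pair, and sweeping t from 0 to 1 traces out the full curve. (The thresholds below are arbitrary illustrative values.)
# (FPR, TPR) for the random forest at a few illustrative thresholds
roc_point <- function(phat, y, t) {
  pred <- as.numeric(phat > t)
  c(FPR = mean(pred[y == 0]), TPR = mean(pred[y == 1]))
}
sapply(c(0.25, 0.5, 0.75),
       function(t) roc_point(marketing.preds$rf, marketing.preds$y, t))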
roc.list <- with(marketing.preds, roc(y ~ glm + lasso + tree + rf))
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
plot(roc.list[[1]])
plot(roc.list[[2]], col = "red", add = TRUE)
plot(roc.list[[3]], col = "purple", add = TRUE)
plot(roc.list[[4]], col = "steelblue", add = TRUE)
Let’s calculate the AUCs for these (the areas under the curve). The AUC summarizes the ROC curve in a single number: 0.5 corresponds to a classifier that does no better than random guessing, and 1 corresponds to perfect separation of the two classes.
with(marketing.preds, auc(y ~ glm))
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Area under the curve: 0.665
with(marketing.preds, auc(y ~ lasso))
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Area under the curve: 0.665
with(marketing.preds, auc(y ~ rf))
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Area under the curve: 0.847
OK… but what variables are important? We don’t have p-values or coefficient estimates, but we do have “importance” measures that tell us how much each variable contributes to the model’s predictions.
sort(marketing.rf$variable.importance)
## default loan marital education job housing age
## 21.16536 119.14167 152.43088 175.14662 305.50842 357.57898 791.81364
## balance
## 876.18534
library(edarf)
pd <- partial_dependence(marketing.rf,
data = marketing.test,
vars = c("balance"))
plot_pd(pd)
One of the reasons that the logistic models might not be performing well is that variables like balance appear to have non-linear relationships with the outcome. There looks to be a sharp discontinuity in the relationship between the outcome and balance, as modeled by the random forest.
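One hedged follow-up (a sketch, not something fit above) would be to let the logistic regression capture a non-linear effect of balance, for example by swapping the linear balance term for a natural spline and checking whether test performance improves.
library(splines)
# Replace the linear balance term with a natural spline (5 df chosen arbitrarily)
marketing.glm.ns <- glm(y ~ . - balance + ns(balance, df = 5),
                        data = marketing.train, family = binomial())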