library(ggplot2) # graphics library
library(ISLR) # contains code and data from the textbook
library(knitr) # contains kable() function
library(tree) # For the tree-fitting 'tree' function
library(rpart) # For nicer tree fitting
library(partykit) # For nicer tree plotting
## Loading required package: grid
library(MASS) # For Boston data
options(scipen = 4) # Suppresses scientific notation
You will need the Carseats data set from the ISLR library in order to complete this exercise. Please run all of the code indicated in §8.3.1 of ISLR, even if I don’t explicitly ask you to do so in this document.

Run the View() command on the Carseats data to see what the data set looks like.

# View(Carseats)
Next we construct a binary High variable for the purpose of classification. Our goal will be to classify whether Carseat sales in a store are high or not.

High <- with(Carseats, ifelse(Sales <= 8, "No", "Yes"))
Carseats <- data.frame(Carseats, High)
prop.table(table(Carseats$High))
##
## No Yes
## 0.59 0.41
Use the tree command to fit a decision tree to every variable in the Carseats data other than Sales. Run summary on your tree object. Run the plot and text commands on your tree object.

# We don't want to use the Sales variable as an input because
# our outcome variable, High, is derived from Sales
tree.carseats <- tree::tree(High ~ . - Sales, Carseats)
summary(tree.carseats)
##
## Classification tree:
## tree::tree(formula = High ~ . - Sales, data = Carseats)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Income" "CompPrice" "Population"
## [6] "Advertising" "Age" "US"
## Number of terminal nodes: 27
## Residual mean deviance: 0.4575 = 170.7 / 373
## Misclassification error rate: 0.09 = 36 / 400
plot(tree.carseats)
text(tree.carseats, pretty = 0)
summary(tree.carseats)$used
## [1] ShelveLoc Price Income CompPrice Population Advertising
## [7] Age US
## 11 Levels: <leaf> CompPrice Income Advertising Population ... US
names(Carseats)[which(!(names(Carseats) %in% summary(tree.carseats)$used))]
## [1] "Sales" "Education" "Urban" "High"
# This reports two numbers: the misclassification count and the total sample size
summary(tree.carseats)$misclass
## [1] 36 400
misclass.rate <- summary(tree.carseats)$misclass[1] / summary(tree.carseats)$misclass[2]
misclass.rate
## [1] 0.09
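As a quick sanity check (a sketch, not part of the lab's original code), the same rate can be recomputed by predicting back onto the data the tree was fit to:

# Recompute the training misclassification rate via predict()
train.pred <- predict(tree.carseats, Carseats, type = "class")
mean(train.pred != Carseats$High)  # should agree with the 0.09 above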
Print the tree object to get a text representation of the fitted tree, and compare it to the plot. (You do not need to write down an answer.)

tree.carseats
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 400 541.500 No ( 0.59000 0.41000 )
## 2) ShelveLoc: Bad,Medium 315 390.600 No ( 0.68889 0.31111 )
## 4) Price < 92.5 46 56.530 Yes ( 0.30435 0.69565 )
## 8) Income < 57 10 12.220 No ( 0.70000 0.30000 )
## 16) CompPrice < 110.5 5 0.000 No ( 1.00000 0.00000 ) *
## 17) CompPrice > 110.5 5 6.730 Yes ( 0.40000 0.60000 ) *
## 9) Income > 57 36 35.470 Yes ( 0.19444 0.80556 )
## 18) Population < 207.5 16 21.170 Yes ( 0.37500 0.62500 ) *
## 19) Population > 207.5 20 7.941 Yes ( 0.05000 0.95000 ) *
## 5) Price > 92.5 269 299.800 No ( 0.75465 0.24535 )
## 10) Advertising < 13.5 224 213.200 No ( 0.81696 0.18304 )
## 20) CompPrice < 124.5 96 44.890 No ( 0.93750 0.06250 )
## 40) Price < 106.5 38 33.150 No ( 0.84211 0.15789 )
## 80) Population < 177 12 16.300 No ( 0.58333 0.41667 )
## 160) Income < 60.5 6 0.000 No ( 1.00000 0.00000 ) *
## 161) Income > 60.5 6 5.407 Yes ( 0.16667 0.83333 ) *
## 81) Population > 177 26 8.477 No ( 0.96154 0.03846 ) *
## 41) Price > 106.5 58 0.000 No ( 1.00000 0.00000 ) *
## 21) CompPrice > 124.5 128 150.200 No ( 0.72656 0.27344 )
## 42) Price < 122.5 51 70.680 Yes ( 0.49020 0.50980 )
## 84) ShelveLoc: Bad 11 6.702 No ( 0.90909 0.09091 ) *
## 85) ShelveLoc: Medium 40 52.930 Yes ( 0.37500 0.62500 )
## 170) Price < 109.5 16 7.481 Yes ( 0.06250 0.93750 ) *
## 171) Price > 109.5 24 32.600 No ( 0.58333 0.41667 )
## 342) Age < 49.5 13 16.050 Yes ( 0.30769 0.69231 ) *
## 343) Age > 49.5 11 6.702 No ( 0.90909 0.09091 ) *
## 43) Price > 122.5 77 55.540 No ( 0.88312 0.11688 )
## 86) CompPrice < 147.5 58 17.400 No ( 0.96552 0.03448 ) *
## 87) CompPrice > 147.5 19 25.010 No ( 0.63158 0.36842 )
## 174) Price < 147 12 16.300 Yes ( 0.41667 0.58333 )
## 348) CompPrice < 152.5 7 5.742 Yes ( 0.14286 0.85714 ) *
## 349) CompPrice > 152.5 5 5.004 No ( 0.80000 0.20000 ) *
## 175) Price > 147 7 0.000 No ( 1.00000 0.00000 ) *
## 11) Advertising > 13.5 45 61.830 Yes ( 0.44444 0.55556 )
## 22) Age < 54.5 25 25.020 Yes ( 0.20000 0.80000 )
## 44) CompPrice < 130.5 14 18.250 Yes ( 0.35714 0.64286 )
## 88) Income < 100 9 12.370 No ( 0.55556 0.44444 ) *
## 89) Income > 100 5 0.000 Yes ( 0.00000 1.00000 ) *
## 45) CompPrice > 130.5 11 0.000 Yes ( 0.00000 1.00000 ) *
## 23) Age > 54.5 20 22.490 No ( 0.75000 0.25000 )
## 46) CompPrice < 122.5 10 0.000 No ( 1.00000 0.00000 ) *
## 47) CompPrice > 122.5 10 13.860 No ( 0.50000 0.50000 )
## 94) Price < 125 5 0.000 Yes ( 0.00000 1.00000 ) *
## 95) Price > 125 5 0.000 No ( 1.00000 0.00000 ) *
## 3) ShelveLoc: Good 85 90.330 Yes ( 0.22353 0.77647 )
## 6) Price < 135 68 49.260 Yes ( 0.11765 0.88235 )
## 12) US: No 17 22.070 Yes ( 0.35294 0.64706 )
## 24) Price < 109 8 0.000 Yes ( 0.00000 1.00000 ) *
## 25) Price > 109 9 11.460 No ( 0.66667 0.33333 ) *
## 13) US: Yes 51 16.880 Yes ( 0.03922 0.96078 ) *
## 7) Price > 135 17 22.070 No ( 0.64706 0.35294 )
## 14) Income < 46 6 0.000 No ( 1.00000 0.00000 ) *
## 15) Income > 46 11 15.160 Yes ( 0.45455 0.54545 ) *
Consider node 18 in the printout above. What does the * at the end of the row mean? How many High = yes observations are there at this node?

18) Population < 207.5 16 21.170 Yes ( 0.37500 0.62500 ) *

The * indicates that this node is a terminal (leaf) node. There are 16 observations at this node, and 0.62500 * 16 = 10 of them have High = yes.
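As a sketch, we can verify these counts directly from the data by following the splits down to node 18 (this assumes, per the tree package's convention, that ties at a numeric cutpoint go to the right-hand branch, hence >= for Income):

# Path to node 18: ShelveLoc in {Bad, Medium}, Price < 92.5,
# Income >= 57, Population < 207.5
with(Carseats, table(High[ShelveLoc != "Good" & Price < 92.5 &
                            Income >= 57 & Population < 207.5]))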
We now split the data into training and test sets, and fit the tree.carseats model to just the training data.

set.seed(2)
train <- sample(1:nrow(Carseats), 200)
Carseats.test <- Carseats[-train,]
High.test <- High[-train]
tree.carseats <- tree(High ~ . - Sales, Carseats, subset = train)
tree.pred <- predict(tree.carseats, Carseats.test, type = "class")
table(tree.pred, High.test)
## High.test
## tree.pred No Yes
## No 86 27
## Yes 30 57
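Before pruning, it is worth recording the unpruned tree's test misclassification rate, computed from the confusion matrix above:

# Unpruned test misclassification rate: 1 - (86 + 57)/200 = 0.285
conf.unpruned <- table(tree.pred, High.test)
1 - sum(diag(conf.unpruned)) / sum(conf.unpruned)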
Use the cv.tree command to carry out 10-fold cross-validation pruning on tree.carseats. You'll need to supply FUN = prune.misclass as an argument to ensure that the error metric is taken to be the number of misclassifications instead of the deviance.

cv.carseats <- cv.tree(tree.carseats, FUN = prune.misclass)
names(cv.carseats)
## [1] "size" "dev" "k" "method"
cv.carseats
## $size
## [1] 19 17 14 13 9 7 3 2 1
##
## $dev
## [1] 53 53 50 50 48 51 67 66 80
##
## $k
## [1] -Inf 0.0000000 0.6666667 1.0000000 1.7500000 2.0000000
## [7] 4.2500000 5.0000000 23.0000000
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
# Index of tree with minimum error
min.idx <- which.min(cv.carseats$dev)
min.idx
## [1] 5
# Number of leaves in that tree
cv.carseats$size[min.idx]
## [1] 9
# Number of misclassifications (this is a count)
cv.carseats$dev[min.idx]
## [1] 48
# Misclassification rate
cv.carseats$dev[min.idx] / length(train)
## [1] 0.24
par(mfrow = c(1,2))
plot(cv.carseats$size, cv.carseats$dev, type="b")
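# (Sketch) Mark the CV-minimizing tree size on the size plot
abline(v = cv.carseats$size[min.idx], lty = 2)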
plot(cv.carseats$k, cv.carseats$dev, type="b")
Use the prune.misclass command to prune tree.carseats down to the subtree that has the lowest CV error. Plot this tree, and overlay text. Compare the variables that get used in this tree to those that get used in the unpruned tree. Does the pruned tree wind up using fewer variables?

par(mfrow = c(1,1))
prune.carseats <- prune.misclass(tree.carseats, best = 9)
plot(prune.carseats)
text(prune.carseats, pretty=0)
# Variables used
summary(prune.carseats)$used
## [1] ShelveLoc Price Advertising Age CompPrice
## 11 Levels: <leaf> CompPrice Income Advertising Population ... US
# Variables that are used in one model but not the other
c(setdiff(summary(prune.carseats)$used, summary(tree.carseats)$used),
setdiff(summary(tree.carseats)$used, summary(prune.carseats)$used))
## [1] "Income" "Population"
tree.pred <- predict(prune.carseats, Carseats.test, type="class")
confusion.pred <- table(tree.pred, High.test)
confusion.pred
## High.test
## tree.pred No Yes
## No 94 24
## Yes 22 60
# Misclassification rate
1 - sum(diag(confusion.pred)) / sum(confusion.pred)
## [1] 0.23
Pruning improves performance on the test data: the misclassification rate drops from 0.285 for the unpruned tree to 0.23 for the pruned one. First, the pruning level was chosen by cross-validation on the training data, Carseats.train, rather than by growing the tree as deep as possible. Second, unpruned trees can greatly overfit data.

Next, we refit the model using the rpart command instead of tree. This results in a fit that can be converted into a party object, and plotted in a more aesthetically pleasing way. The code below illustrates how to perform this conversion and how to get a tree out of it. There are some accompanying questions below the code.

# Fit a decision tree using rpart
# Note: when you fit a tree using rpart, the fitting routine automatically
# performs 10-fold CV and stores the errors for later use
# (such as for pruning the tree)
carseats.rpart <- rpart(High ~ . - Sales, Carseats, method = "class", subset = train)
# Plot the CV error curve for the tree
plotcp(carseats.rpart)
# Identify the value of the complexity parameter that produces
# the lowest CV error
cp.min <- carseats.rpart$cptable[which.min(carseats.rpart$cptable[,"xerror"]),"CP"]
# Prune using the CV error minimizing choice of the complexity parameter cp
carseats.rpart.pruned <- prune(carseats.rpart, cp = cp.min)
# Convert pruned tree to a party object
carseats.party <- as.party(carseats.rpart.pruned)
# Plot
plot(carseats.party)
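If you prefer numbers to a plot, the cross-validation errors that rpart stored can also be printed directly (standard rpart functionality; output not shown here):

# Display the complexity parameter table: CP, tree size, and CV error
printcp(carseats.rpart)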
These questions prompt you to follow along with §8.3.2 in ISL. We’ll once again be working with the Boston data set.
set.seed(1)
train <- sample(1:nrow(Boston), nrow(Boston)/2)
tree.boston <- tree(medv ~ ., Boston, subset = train)
summary(tree.boston)
##
## Regression tree:
## tree(formula = medv ~ ., data = Boston, subset = train)
## Variables actually used in tree construction:
## [1] "lstat" "rm" "dis"
## Number of terminal nodes: 8
## Residual mean deviance: 12.65 = 3099 / 245
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -14.10000 -2.04200 -0.05357 0.00000 1.96000 12.60000
plot(tree.boston)
text(tree.boston, pretty=0)
lstat (the percentage of the population that is of lower socioeconomic status) is the first variable that is split on. The split occurs at lstat = 9.715.

The numbers shown at the leaves refer to the average medv (median home value) of the training observations in that leaf node.
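The fitted value at each leaf can also be read off the tree object itself rather than the plot (a sketch relying on the tree object's frame component):

# Average medv at each terminal node
tree.boston$frame[tree.boston$frame$var == "<leaf>", "yval"]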
set.seed(2)
for(i in 1:10) {
train.b <- sample(1:nrow(Boston), nrow(Boston)/2)
tree.boston.b <- tree(medv ~ . , Boston, subset=train.b)
print(as.character(summary(tree.boston.b)$used))
}
## [1] "lstat" "rm" "dis" "nox" "crim"
## [1] "lstat" "rm" "dis" "nox" "crim"
## [1] "rm" "lstat" "crim" "ptratio"
## [1] "rm" "lstat" "age" "crim"
## [1] "lstat" "rm" "dis" "ptratio" "nox" "crim"
## [1] "lstat" "rm" "nox"
## [1] "lstat" "rm" "ptratio" "tax" "nox"
## [1] "lstat" "rm" "dis" "indus" "nox" "crim"
## [1] "lstat" "rm" "dis" "age" "crim"
## [1] "rm" "lstat" "dis"
Every replication uses lstat and rm, and the first variable split on is always either lstat or rm.

In the output of cv.tree below, dev stands for "deviance", which for regression (prediction) problems is the residual sum of squares (i.e., the MSE summed rather than averaged over observations).

# Run CV to find best level at which to prune
cv.boston <- cv.tree(tree.boston)
# Construct a plot (dev = MSE on y-axis)
plot(cv.boston$size,cv.boston$dev,type='b')
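# (Sketch) Report the tree size with the lowest CV deviance
cv.boston$size[which.min(cv.boston$dev)]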
# Prune the tree, display pruned tree
prune.boston <- prune.tree(tree.boston, best = 5)
plot(prune.boston)
text(prune.boston, pretty = 0)
# Get predictions on the test set (from the unpruned tree, as in ISLR §8.3.2)
yhat.tree <- predict(tree.boston, newdata = Boston[-train, ])
boston.test <- Boston[-train,"medv"]
# Construct plot of observed values (x-axis) vs predicted values (y-axis)
plot(yhat.tree, boston.test)
# Add a diagonal line
abline(0, 1)
# Calculate test set MSE
mean((yhat.tree - boston.test)^2)
## [1] 25.04559
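Since medv is recorded in thousands of dollars, the square root of the test MSE is easier to interpret:

# Test RMSE: sqrt(25.04559) is about 5.005, i.e. test predictions are
# typically within roughly $5,005 of the true median home value
sqrt(mean((yhat.tree - boston.test)^2))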