Topic 4 Cross-Validation
4.1 Discussion
Overfitting
From ISLR, page 32:
When a given method yields a small training MSE but a large test MSE, we are said to be overfitting the data. This happens because our statistical learning procedure is working too hard to find patterns in the training data, and may be picking up some patterns that are just caused by random chance rather than by true [trends].
When we overfit the training data, the test MSE will be very large because the supposed patterns that the method found in the training data simply don’t exist in the test data. Note that regardless of whether or not overfitting has occurred, we almost always expect the training MSE to be smaller than the test MSE because most statistical learning methods either directly or indirectly seek to minimize the training MSE. Overfitting refers specifically to the case in which a less flexible model would have yielded a smaller test MSE.
“Flexibility” in linear regression refers to the number of coefficients
- More coefficients (more predictors) = more flexible
- Fewer coefficients (fewer predictors) = less flexible
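To make the training vs. test MSE gap concrete, here is a minimal sketch (not part of the original reading) that simulates data in which only one predictor matters, fits a less flexible and a more flexible linear model, and compares their training and test MSE. All object names below are our own.

set.seed(1)

# Simulate a training set and a test set from the same truth: y depends only on x1
n <- 100
train <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
test <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
train$y <- 2 * train$x1 + rnorm(n)
test$y <- 2 * test$x1 + rnorm(n)

# Less flexible model (x1 only) vs. more flexible model (x1, x2, x3)
mod_small <- lm(y ~ x1, data = train)
mod_big <- lm(y ~ x1 + x2 + x3, data = train)

# Training MSE: the more flexible model is always at least as small
mean(residuals(mod_small)^2)
mean(residuals(mod_big)^2)

# Test MSE: the more flexible model can be larger (overfitting)
mean((test$y - predict(mod_small, newdata = test))^2)
mean((test$y - predict(mod_big, newdata = test))^2)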
How do we prevent overfitting?
- Want an accuracy metric that allows us to choose which of several models will be most accurate on new data
- Adjusted R-squared is a good idea but restricted to linear regression
- Cross-validation is a much more general technique that allows estimation of the true error rate on new data (the test error)
- Use statistical learning methods specifically designed to discourage including weak/useless predictors
- Shrinkage methods (coming soon!)
4.2 Exercises
Goals
- For what purposes would you use cross-validation?
- How is cross-validation useful for preventing overfitting?
You can download a template RMarkdown file to start from here.
You’ll continue working on the body fat dataset from last time. The following variables were recorded.
fatBrozek
: Percent body fat using Brozek’s equation: 457/Density - 414.2

fatSiri
: Percent body fat using Siri’s equation: 495/Density - 450

density
: Density determined from underwater weighing (gm/cm^3)

age
: Age (years)

weight
: Weight (lbs)

height
: Height (inches)

neck
: Neck circumference (cm)

chest
: Chest circumference (cm)

abdomen
: Abdomen circumference (cm)

hip
: Hip circumference (cm)

thigh
: Thigh circumference (cm)

knee
: Knee circumference (cm)

ankle
: Ankle circumference (cm)

biceps
: Biceps (extended) circumference (cm)

forearm
: Forearm circumference (cm)

wrist
: Wrist circumference (cm)
The focus is on predicting body fat percentage using Siri’s equation (fatSiri) from easily measured variables: age, weight, height, and circumferences.
library(ggplot2)
library(dplyr)
bodyfat_train <- read.csv("https://www.dropbox.com/s/js2gxnazybokbzh/bodyfat_train.csv?dl=1")
# Remove the fatBrozek and density variables
bodyfat_train <- bodyfat_train %>%
select(-fatBrozek, -density)
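Before fitting any models, it can help to confirm that the data loaded as expected. A quick check (not required by the exercises):

# Look at the number of rows/columns and the first few cases
dim(bodyfat_train)
head(bodyfat_train)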
4 models
Consider the 4 models below:

mod1 <- lm(fatSiri ~ age+weight+neck+abdomen+thigh+forearm, data = bodyfat_train)
mod2 <- lm(fatSiri ~ age+weight+neck+abdomen+thigh+biceps+forearm, data = bodyfat_train)
mod3 <- lm(fatSiri ~ age+weight+neck+chest+abdomen+hip+thigh+biceps+forearm, data = bodyfat_train)
mod4 <- lm(fatSiri ~ ., data = bodyfat_train) # The . means all predictors
- Before looking at the summary tables, predict:
- Which model will have the highest R squared?
- Which model will have the lowest training MSE?
- Find/compute the R squared and MSE for the 4 models to check your answers in part a. (One way to compute these is sketched after this list.)
- Which model do you think will perform worst on new test data? Why?
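For computing the R squared and training MSE by hand, the sketch below shows one approach using base R functions (the model objects mod1 through mod4 come from the code above; repeat or loop over the other models as needed):

# R squared is stored in the model summary
summary(mod1)$r.squared

# Training MSE: the average squared residual on the training data
mean(residuals(mod1)^2)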
We’ll use the caret package to perform cross-validation (and to run many different machine learning methods throughout the course). The caret package is a great resource for the machine learning community because it aggregates machine learning methods written by many different authors across many different R packages into a single package. The advantage is that instead of learning 10 different styles of code for 10 different machine learning methods, we can use fairly similar code throughout the course.

Install the caret package by running install.packages("caret") in the Console.
- Cross-validation with the caret package

Use the code below to perform 10-fold cross-validation for mod1 to estimate the test MSE (\(\text{CV}_{(10)}\)).

# Load the package
library(caret)

# Set up what type of cross-validation is desired
train_ctrl_cv10 <- trainControl(method = "cv", number = 10)

# To ensure the same results each time the code is run
set.seed(253)

# Fit (train) the model as written in mod1
# Also supply information about the type of CV desired for evaluating the model with the trControl argument
# na.action = na.omit prevents errors if the data has missing values
mod1_cv10 <- train(
    fatSiri ~ age+weight+neck+abdomen+thigh+forearm,
    data = bodyfat_train,
    method = "lm",
    trControl = train_ctrl_cv10,
    na.action = na.omit
)

# The $ extracts components of an object
# Peek at the "resample" part of mod1_cv10 - what info does it contain?
mod1_cv10$resample

# Estimate of test MSE
# RMSE = square root of MSE
mean(mod1_cv10$resample$RMSE^2)
- Adapt the code above to perform:
  - 10-fold CV for model 2
  - LOOCV for model 1

In doing so, look carefully at the structure of the code. What parts need to be repeated? What parts don’t? (Hint: nrow(dataset) gives the number of cases in a dataset.) If you get stuck on the LOOCV piece, a sketch of one possible setup follows.
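The sketch below treats LOOCV as k-fold CV with k equal to the number of cases (caret also accepts trainControl(method = "LOOCV") directly). It is one possible setup to check your work against, not the only way to do it.

# LOOCV is k-fold CV with k equal to the number of cases
train_ctrl_loocv <- trainControl(method = "cv", number = nrow(bodyfat_train))

set.seed(253)
mod1_loocv <- train(
    fatSiri ~ age+weight+neck+abdomen+thigh+forearm,
    data = bodyfat_train,
    method = "lm",
    trControl = train_ctrl_loocv,
    na.action = na.omit
)

# LOOCV estimate of the test MSE
mean(mod1_loocv$resample$RMSE^2)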
- Looking at the evaluation metrics

A completed table of evaluation metrics is below.

- Which model performed the best on the training data?
- Which model performed best on the test set?
- Which model would be preferred using \(\text{CV}_{(10)}\) or LOOCV estimates of the test error?
- How is cross-validation helping us avoid overfitting?
| Model | \(R^2\) | Training MSE | \(\text{CV}_{(10)}\) | LOOCV | Test set MSE |
|---|---|---|---|---|---|
| mod1 | 0.8103 | 14.52153 | 17.21062 | 18.16816 | 23.92333 |
| mod2 | 0.8146 | 14.18762 | 19.64114 | 19.29848 | 23.90547 |
| mod3 | 0.8160 | 14.08022 | 21.24115 | 20.28958 | 23.63958 |
| mod4 | 0.8162 | 14.06917 | 21.88440 | 21.26073 | 24.65370 |
- Practical issues: choosing \(k\)
- What do you think are the pros/cons of low vs. high \(k\)?
- If possible, it is advisable to choose \(k\) to be a divisor of the sample size. Why do you think that is? (A small illustration follows this list.)
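To see the divisor point concretely, here is a small illustration with a made-up sample size (not part of the original exercise); it simply counts how many cases land in each fold:

# Hypothetical sample size for illustration only
n <- 100

# k = 10 divides n, so every fold has exactly 10 cases
table(rep(1:10, length.out = n))

# k = 7 does not divide n, so fold sizes are unequal (two folds of 15, five of 14)
table(rep(1:7, length.out = n))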
Extra! Writing R functions
If you’re interested in learning about writing R functions, look at the following function that the instructor used to fill out the above evaluation metrics table.

bodyfat_test <- read.csv("https://www.dropbox.com/s/7gizws208u0oywq/bodyfat_test.csv?dl=1")

evaluate_model <- function(formula) {
    train_ctrl_cv10 <- trainControl(method = "cv", number = 10)
    train_ctrl_loocv <- trainControl(method = "cv", number = nrow(bodyfat_train))
    mod_cv10 <- train(formula, data = bodyfat_train, method = "lm", trControl = train_ctrl_cv10, na.action = na.omit)
    mod_loocv <- train(formula, data = bodyfat_train, method = "lm", trControl = train_ctrl_loocv, na.action = na.omit)
    model_predictions <- predict(mod_cv10, newdata = bodyfat_test)
    test_mse <- mean((bodyfat_test$fatSiri - model_predictions)^2)
    c(
        cv10 = mean(mod_cv10$resample$RMSE^2),
        loocv = mean(mod_loocv$resample$RMSE^2),
        test = test_mse
    )
}

set.seed(253)
evaluate_model(fatSiri ~ age+weight+neck+abdomen+thigh+forearm)
evaluate_model(fatSiri ~ age+weight+neck+abdomen+thigh+biceps+forearm)
evaluate_model(fatSiri ~ age+weight+neck+chest+abdomen+hip+thigh+biceps+forearm)
evaluate_model(fatSiri ~ .)
- Step through each line and see if you can understand the structure.
- How would you modify the function to work on arbitrary data? How would you have to change the function arguments (within the parentheses on the first line)?
- How would you modify the function to allow the user to choose \(k\)? (One possible generalization is sketched below.)
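If you want to check your thinking on the last two questions, here is one possible generalization (a sketch only; the argument names data and k are our own choices, and the test-set piece is dropped to keep it short):

# A more general version: the user supplies the formula, the dataset, and the number of folds k
evaluate_model_general <- function(formula, data, k = 10) {
    train_ctrl <- trainControl(method = "cv", number = k)
    mod_cv <- train(formula, data = data, method = "lm",
        trControl = train_ctrl, na.action = na.omit)
    # Estimated test MSE from k-fold CV
    mean(mod_cv$resample$RMSE^2)
}

set.seed(253)
evaluate_model_general(fatSiri ~ age+weight+neck+abdomen+thigh+forearm,
    data = bodyfat_train, k = 10)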