Topic 7 Splines

Learning Goals

  • Explain the advantages of splines over global transformations and other types of piecewise polynomials
  • Explain how splines are constructed by drawing connections to variable transformations and least squares
  • Explain how the number of knots relates to the bias-variance tradeoff


Slides from today are available here.




Splines in caret

To build models with splines in caret, we proceed with the same structure for train() as we use for ordinary linear regression models. (Why can we just use least squares?) Refer to code from Topic 3: Overfitting & CV for a refresher.

To work with splines, we’ll use the splines package. The ns() function in that package builds the transformations (basis functions) needed to fit a natural spline for a quantitative predictor. This involves a small update to the formula:

# Before: all linear terms
ls_mod <- train(
    y ~ quant_x1 + quant_x2,
    ...
)

# With splines
ls_mod <- train(
    y ~ ns(quant_x1, df = 3) + ns(quant_x2, df = 3),
    ...
)
  • The df argument in ns() stands for degrees of freedom:
    • df = number of interior knots + 1 (ns() places these knots at quantiles of the predictor)
    • The degrees of freedom are the number of coefficients in the transformation functions that are free to vary (essentially the number of underlying parameters behind the transformations).
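To make the link between ns(), degrees of freedom, and least squares concrete, here is a minimal sketch on simulated data. The variables x and y are made up for illustration and are not part of the College data. It shows that ns(x, df = 3) is just a matrix of 3 transformed columns built from 2 interior knots, and that fitting a spline is ordinary least squares on those columns.

```r
library(splines)

# Simulated data (hypothetical, for illustration only)
set.seed(1)
x <- runif(200, 0, 10)
y <- sin(x) + rnorm(200, sd = 0.3)

# ns() returns a basis matrix: one column per degree of freedom
basis <- ns(x, df = 3)
dim(basis)            # 200 rows, 3 columns
attr(basis, "knots")  # 2 interior knots, placed at quantiles of x

# Fitting with ns() in the formula is least squares on the basis columns
fit_formula <- lm(y ~ ns(x, df = 3))
fit_manual  <- lm(y ~ basis)

# Both fits have an intercept plus 3 spline coefficients
length(coef(fit_formula))

# The fitted values agree
same_fit <- all.equal(unname(fitted(fit_formula)), unname(fitted(fit_manual)))
same_fit
```

Because the spline terms enter the model as ordinary predictors, the same train() structure used for linear regression carries over unchanged.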




Exercises

You can download a template RMarkdown file to start from here.

The splines package ships with base R, so it does not need to be installed with install.packages(). Loading it with library(splines) is enough.

We’ll continue using the College dataset in the ISLR package to explore splines. You can use ?College in the Console to look at the data codebook.

library(caret)
library(ggplot2)
library(dplyr)
library(readr)
library(ISLR)
library(splines)

data(College)

# A little data cleaning
college_clean <- College %>% 
    mutate(school = rownames(College)) %>% 
    filter(Grad.Rate <= 100) # Remove one school with grad rate of 118%
rownames(college_clean) <- NULL # Remove school names as row names

Exercise 1: Evaluating a fully linear model

We will model Grad.Rate as a function of 4 predictors: Private, Terminal, Expend, and S.F.Ratio.

  1. Make scatterplots with 2 different smoothing lines to explore potential nonlinearity. Adding the following to the normal scatterplot code will create a smooth (curved) blue trend line and a red linear trend line.

    geom_smooth(color = "blue", se = FALSE) +
    geom_smooth(method = "lm", color = "red", se = FALSE)
  2. Use caret to fit an ordinary linear regression model (no splines yet) with the following specifications:

    • Use 8-fold CV.
    • Use mean absolute error (MAE) to select a final model.
    • Select the simplest model for which the metric is within one standard error of the best metric.
    set.seed(___)
    ls_mod <- train(
    
    )
  3. Make plots of the residuals vs. the 3 quantitative predictors to evaluate the appropriateness of linear terms.

    ls_mod_data <- college_clean %>%
        mutate(
            pred = predict(ls_mod, newdata = college_clean),
            resid = ___
        )
    
    ggplot(ls_mod_data, ???) +
        ___ +
        ___ +
        geom_hline(yintercept = 0, color = "red")

Exercise 2: Evaluating a spline model

We’ll extend our linear regression model with spline functions of the quantitative predictors (leave Private as is).

  1. What tuning parameter is associated with splines? How do high/low values of this parameter relate to bias and variance?

  2. Update your code from Exercise 1 to model the 3 quantitative predictors with natural splines that have 2 knots (= 3 degrees of freedom).

    set.seed(___)
    spline_mod <- train(
    
    )
  3. Make plots of the residuals vs. the 3 quantitative predictors to evaluate if splines improved the model.

    spline_mod_data <- ___
    
    # Residual plots

Extra! Variable scaling

What is your intuition about whether variable scaling matters for the performance of splines?

Check your intuition by reusing code from Exercise 2, except with preProcess = "scale" inside train(). Call this spline_mod2.

How do the predictions from spline_mod and spline_mod2 compare? You could use a plot to compare or check out the all.equal() function.
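If all.equal() is new to you, the sketch below shows how it behaves on small made-up vectors (a, b, and c_vec are hypothetical names, unrelated to the models above). It returns TRUE when two objects match up to a small numerical tolerance, and otherwise returns a character description of the difference, so wrapping it in isTRUE() gives a clean TRUE/FALSE.

```r
# Toy vectors (hypothetical, for illustration only)
a <- c(1, 2, 3)
b <- a + 1e-10   # differs only by tiny numerical noise
c_vec <- a + 0.5 # differs meaningfully

# Equal up to the default tolerance
close_enough <- all.equal(a, b)
close_enough

# Not equal: all.equal() returns a description rather than FALSE
description <- all.equal(a, c_vec)
description

# isTRUE() converts either outcome to a single logical value
isTRUE(all.equal(a, b))
isTRUE(all.equal(a, c_vec))
```

The same pattern applies to two vectors of predictions, as long as both are computed on the same data.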