Topic 16 Trees (Coding)
Learning Goals
- Implement trees in tidymodels
- Visualize learned trees and use visualizations to explore predictions made from trees
- Interpret variable importance measures from trees
Trees in tidymodels
To build tree models in tidymodels, first load the required packages and set the seed for the random number generator to ensure reproducible results:
library(dplyr)
library(readr)
library(ggplot2)
library(tidymodels)
tidymodels_prefer()
set.seed(___) # Pick your favorite number to fill in the parentheses
To fit a classification tree, we can adapt the following:
ct_spec <- decision_tree() %>%
    set_engine(engine = "rpart") %>%
    set_args(
        cost_complexity = NULL, # default is 0.01 (used for pruning a tree)
        min_n = NULL,
        tree_depth = NULL
    ) %>%
    set_mode("classification") # change this for regression tree

data_rec <- recipe(___ ~ ___, data = ______)

data_wf <- workflow() %>%
    add_model(ct_spec) %>%
    add_recipe(data_rec)

fit_mod <- data_wf %>% # or use tune_grid() to tune any of the parameters above
    fit(data = _____)
Visualizing and interpreting the “best” tree
# Plot the tree (make sure to load the rpart.plot package first)
fit_mod %>%
    extract_fit_engine() %>%
    rpart.plot()

# Get variable importance metrics
# Sum of the goodness of split measures (impurity reduction) for each split for which it was the primary variable.
fit_mod %>%
    extract_fit_engine() %>%
    pluck('variable.importance')
Exercises
You will need to install the rpart.plot package before proceeding.
Bias-variance tradeoff plot
If tune_out is your output from tune_grid(), use autoplot to inspect the plot of test performance vs. tuning parameters:
autoplot(tune_out) + theme_classic()
Inspecting evaluation metrics
Use collect_metrics() to view evaluation metrics for all tuning parameter combinations. Use filter() to only show metrics for the best values of the tuning parameters.
Interpret these metrics in the context of your data. Is this tree model good? Interpret in light of the no-information rate for your dataset.
tune_out %>%
    collect_metrics() %>%
    filter(___)
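As a reminder, the no-information rate is the accuracy you would get by always predicting the most common outcome class. One quick way to compute it (a sketch; YOUR_DATA and YOUR_OUTCOME are placeholders for your dataset and outcome variable) is:

# No-information rate: the proportion of the most common outcome class
YOUR_DATA %>%
    count(YOUR_OUTCOME) %>%
    mutate(prop = n / sum(n)) %>%
    arrange(desc(prop))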
Visualizing your tree
If final_fit is your output after finalize_workflow() and fit(), use the following to plot your tree. Look at page 3 of the rpart.plot package vignette to understand what the plot components mean.
Take a look at what variables were used for splitting early on (near the top) and what variables tended to be chosen for splitting overall. Do these make sense contextually?
final_fit %>% extract_fit_engine() %>% rpart.plot()
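If the default display is hard to read, rpart.plot() also accepts display options described in that vignette. One possible variation (just an example of the options, not required):

final_fit %>%
    extract_fit_engine() %>%
    rpart.plot(extra = 4, roundint = FALSE) # extra = 4 shows per-class probabilities at each node; roundint = FALSE silences integer-rounding warnings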
Variable importance
We can obtain numerical variable importance measures from trees. These measure, roughly, “the total decrease in node impurities from splitting on the variable” (even if the variable isn’t ultimately used in the split).
What are the 3 most important predictors by this measure? Does this agree with what you might have expected based on the plot of the best fitted tree? What might greedy behavior have to do with this?
final_fit %>%
    extract_fit_engine() %>%
    pluck("variable.importance")
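Since the importance output is a named numeric vector, a minimal sketch for pulling out the top few values is:

# Show the 3 largest importance values
final_fit %>%
    extract_fit_engine() %>%
    pluck("variable.importance") %>%
    sort(decreasing = TRUE) %>%
    head(3)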
Predictions and exploring error
Classification setting
If your outcome is categorical: use the following to add hard and soft predictions to your dataset.
YOUR_DATA <- YOUR_DATA %>%
    mutate(
        pred_prob = predict(final_fit, new_data = YOUR_DATA, type = "prob"),  # soft predictions: one .pred_ column of probabilities per class
        pred_class = predict(final_fit, new_data = YOUR_DATA, type = "class")$.pred_class  # hard predictions: .pred_class is the column name in the predict() output
    )
Make some plots exploring the relationship between soft predictions and the predictors. Do the same for hard predictions. This will help you visualize what relationships the tree model is learning.
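As a starting point, here is one possible sketch for each (PREDICTOR and YOUR_CATEGORY are placeholders for one of your predictors and one of your outcome levels; the probability columns inside pred_prob are named .pred_ followed by the level):

# Soft predictions: predicted probability of one class vs. a quantitative predictor
ggplot(YOUR_DATA, aes(x = PREDICTOR, y = pred_prob$.pred_YOUR_CATEGORY)) +
    geom_point() +
    theme_classic()

# Hard predictions: distribution of a quantitative predictor by predicted class
ggplot(YOUR_DATA, aes(x = pred_class, y = PREDICTOR)) +
    geom_boxplot() +
    theme_classic()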
We can also make plots to explore misclassification rates. Use the code below to create a misclassified variable that indicates whether or not a case was misclassified. Explore how this variable relates to your predictors.
mutate(misclassified = YOUR_OUTCOME!=pred_class)
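One possible exploration (a sketch; CAT_PREDICTOR is a placeholder for one of your categorical predictors) is the misclassification rate within levels of a categorical predictor:

# Misclassification rate by level of a categorical predictor
YOUR_DATA %>%
    mutate(misclassified = YOUR_OUTCOME != pred_class) %>%
    group_by(CAT_PREDICTOR) %>%
    summarize(misclass_rate = mean(misclassified))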
Regression setting
If your outcome is quantitative: use the following to add predictions to your dataset.
YOUR_DATA <- YOUR_DATA %>%
    mutate(
        pred = predict(final_fit, new_data = YOUR_DATA)$.pred  # .pred is the column name in the predict() output
    )
- Within mutate(), include the calculations needed to compute residuals, and make residual plots to assess any systematic errors in your model (a sketch follows below).
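For instance, a sketch along these lines (PREDICTOR is a placeholder for one of your predictors) computes residuals and plots them against a predictor:

# Compute residuals and plot them against a predictor
YOUR_DATA <- YOUR_DATA %>%
    mutate(resid = YOUR_OUTCOME - pred)

ggplot(YOUR_DATA, aes(x = PREDICTOR, y = resid)) +
    geom_point() +
    geom_hline(yintercept = 0, color = "red") +
    theme_classic()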
Backup dataset
If you weren’t able to get trees working for your project, explore the following dataset on urban land cover.
Context: Our goal will be to classify types of urban land cover in small subregions within a high resolution aerial image of a land region. Data from the UCI Machine Learning Repository include the observed type of land cover (determined by human eye) and “spectral, size, shape, and texture information” computed from the image. See this page for the data codebook.
library(dplyr)
library(readr)
library(ggplot2)
library(rpart.plot)
library(tidymodels)
tidymodels_prefer()
# Read in the data
land <- read_csv("https://www.dropbox.com/s/r59esfepjw7qsg0/land_cover_training.csv?dl=1")

# There are 9 land types, but we'll focus on 3 of them
land <- land %>%
    filter(class %in% c("asphalt", "grass", "tree"))

set.seed(123) # don't change this

data_fold <- vfold_cv(land, v = 10)

ct_spec_tune <- decision_tree() %>%
    set_engine(engine = 'rpart') %>%
    set_args(cost_complexity = tune(),
             min_n = 2,
             tree_depth = NULL) %>%
    set_mode('classification')

data_rec <- recipe(class ~ ., data = land)

data_wf_tune <- workflow() %>%
    add_model(ct_spec_tune) %>%
    add_recipe(data_rec)

param_grid <- grid_regular(cost_complexity(range = c(-5, 1)), levels = 10)

tune_res <- tune_grid(
    data_wf_tune,
    resamples = data_fold,
    grid = param_grid,
    metrics = metric_set(accuracy, roc_auc) # change this for regression trees
)

best_complexity <- select_by_one_std_err(tune_res, metric = 'accuracy', desc(cost_complexity))

data_wf_final <- finalize_workflow(data_wf_tune, best_complexity)

land_final_fit <- fit(data_wf_final, data = land)
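From here, the tools from earlier in this topic apply directly to the land cover fit, for example:

# Inspect the bias-variance tradeoff across cost_complexity values
autoplot(tune_res) + theme_classic()

# Visualize the final tree
land_final_fit %>%
    extract_fit_engine() %>%
    rpart.plot()

# Variable importance
land_final_fit %>%
    extract_fit_engine() %>%
    pluck("variable.importance")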