Code walk-through for glmnet and tidymodels in R

Introduction

This tutorial was originally made for members of my lab who use glmnet to model neurochemical data collected with fast-scan cyclic voltammetry. Materials and content are adapted from Professor Lucy McGowan and Julia Silge. It shows how to perform elastic net (and/or ridge or lasso) regression on a sample dataset. I made an accompanying video, which can be accessed here, and all materials can be downloaded here.

If you have any questions, just reach out! This was made for members of my lab, but anyone is welcome to use it.

# load packages
library(tidyverse)
library(tidymodels)
library(tune)
library(vip)
library(tictoc)
library(doParallel)
# Read the music dataset from disk.
music <- read_csv(file = "music.csv")

# Fix the RNG seed so the split and the folds below are reproducible.
set.seed(18)

# Split the rows into a training and a test set: 50% training, 50% test.
music_split <- music %>% initial_split(prop = 0.5)
music_train <- music_split %>% training() # training portion of the split
music_test <- music_split %>% testing() # testing portion of the split

# 10-fold cross-validation folds, built from the training set only.
train_cv <- music_train %>% vfold_cv(v = 10)
# Preprocessing recipe: predict latitude (lat) from all other variables,
# scaling every predictor so they can be properly compared.
netRec <- step_scale(
  recipe(lat ~ ., data = music_train),
  all_predictors()
)

# Just preprocess the data (estimate the recipe's steps).
netPrep <- prep(netRec)

# Model specification -- what we want to fit.
# A linear regression with a tunable penalty (lambda) and a fixed mixture
# (alpha), fitted by the "glmnet" engine.
# mixture = 1 is lasso, mixture = 0 is ridge; anything in between is
# elastic net.
netSpec <- set_engine(
  linear_reg(penalty = tune(), mixture = 0.5),
  "glmnet"
)

# Bundle the recipe and the model specification into one workflow.
wf <- workflow()
wf <- add_recipe(wf, netRec)
wf <- add_model(wf, netSpec)

# Candidate penalty values to try: 0, 0.5, 1, ..., 10.
netGrid <- expand_grid(penalty = seq(from = 0, to = 10, by = 0.5))
# Register a parallel backend for the tuning run.
doParallel::registerDoParallel()

# Start a timer labelled "parallel".
tic("parallel")

# Reset the seed right before tuning.
set.seed(18)

# Tune the elastic net workflow: evaluate every penalty in netGrid on the
# cross-validation folds.
netTuned <- wf %>%
  tune_grid(
    resamples = train_cv, # cross-validation folds
    grid = netGrid # grid of tunable values
  )
## 
## Attaching package: 'rlang'
## The following objects are masked from 'package:purrr':
## 
##     %@%, as_function, flatten, flatten_chr, flatten_dbl, flatten_int,
##     flatten_lgl, flatten_raw, invoke, list_along, modify, prepend,
##     splice
## The following object is masked from 'package:magrittr':
## 
##     set_names
## 
## Attaching package: 'vctrs'
## The following object is masked from 'package:dplyr':
## 
##     data_frame
## The following object is masked from 'package:tibble':
## 
##     data_frame
## Loading required package: Matrix
## 
## Attaching package: 'Matrix'
## The following objects are masked from 'package:tidyr':
## 
##     expand, pack, unpack
## Loaded glmnet 4.0-2
toc() # stop the "parallel" timer started by tic() and print the elapsed time
## parallel: 4.866 sec elapsed
# Preview the tuning curve: mean value of each metric across the penalty grid,
# one panel per metric with free axes.
ggplot(
  collect_metrics(netTuned),
  aes(x = penalty, y = mean, colour = .metric)
) +
  geom_line(size = 1.5, show.legend = FALSE) +
  facet_wrap(facets = vars(.metric), nrow = 2, scales = "free")

# Select the penalty that yields the lowest cross-validated rmse.
# `metric` is passed by name: newer tune releases expect it to be named, and
# the named form also works with older releases.
bestPenalty <- netTuned %>% select_best(metric = "rmse")
# Finalize the workflow: replace tune() with the best penalty.
final <- finalize_workflow(wf, bestPenalty)
final %>%
  # fit on the training data
  fit(music_train) %>%
  # pull the underlying parsnip fit out of the workflow
  # NOTE(review): newer workflows releases deprecate pull_workflow_fit() in
  # favour of extract_fit_parsnip(); switch once the installed version has it.
  pull_workflow_fit() %>%
  # Variable importance from the glmnet coefficients. `lambda` is set to the
  # tuned penalty; without it vi() falls back to its default lambda for
  # glmnet fits, so the plot would not describe the model we selected.
  vi(lambda = bestPenalty$penalty) %>%
  # use the absolute value of importance and
  # reorder variables by their absolute importance
  mutate(
    Importance = abs(Importance),
    Variable = fct_reorder(Variable, Importance)
  ) %>%
  # plot, colouring the bars by coefficient sign
  ggplot(aes(x = Importance, y = Variable, fill = Sign)) +
  geom_col()

# Final evaluation: fit the finalized workflow on the training portion of the
# split, evaluate it on the held-out testing portion, and collect the metrics.
final %>%
  last_fit(split = music_split) %>%
  collect_metrics()
## # A tibble: 2 x 4
##   .metric .estimator .estimate .config             
##   <chr>   <chr>          <dbl> <chr>               
## 1 rmse    standard      16.9   Preprocessor1_Model1
## 2 rsq     standard       0.168 Preprocessor1_Model1
Previous