Overview

This page is part of a code repository to reproduce the example analyses in the manuscript “Cross validation for model selection: a primer with examples from ecology” by L.A. Yates, Z. Aandahl, S.A. Richards, and B.W. Brook (2022). This tutorial is designed to be used alongside the manuscript.

The scat dataset is sourced from the R package caret. Reid (2015) collected animal feces in coastal California. The data consist of DNA-verified species designations as well as fields related to the time and place of collection and the morphology of the scat itself. In this example, the aim of the analysis is to predict the biological family (felid or canid) of each scat observation based on eight morphological characteristics, the scat location, and the carbon-to-nitrogen ratio (see Table 1 of the original data publication for variable descriptions).

Load and prepare scat data

Data preparations include: collapsing the response from three categories to two (i.e. binary classification), centering and scaling numerical predictors, and log-transforming two of the variables (based on advice in the original publication).

data(scat,package = "caret")

prep_data <- function(data){
  data %>% as_tibble %>% 
    select(y = Species, Number, Length, Diameter, Taper, TI, Mass, CN, Location, ropey, segmented) %>% 
    rename_with(str_to_lower) %>% # TI: taper index
    mutate(across(c(mass,cn),log)) %>% # following original data publication
    mutate(number = as.double(number)) %>% 
    drop_na() %>% 
    mutate(across(where(is_double), ~ .x %>% scale %>% as.numeric)) %>% # centre and scale predictors
    mutate(y = fct_collapse(y, canid = c("coyote","gray_fox"), felid = c("bobcat")) %>% 
             fct_relevel("canid","felid"),
           location = fct_collapse(location, mid = c("middle"), edge = c("edge","off_edge")) %>% 
             factor(labels = c(0,1))) %>% 
    mutate(across(c(ropey,segmented), factor))
}

scat <- prep_data(scat) # N.B. fail = canid (factor-level 1), success = felid (factor-level 2)
vars <- names(scat %>% select(-y))
scat
## # A tibble: 91 × 11
##    y     number  length diameter   taper      ti   mass     cn location ropey
##    <fct>  <dbl>   <dbl>    <dbl>   <dbl>   <dbl>  <dbl>  <dbl> <fct>    <fct>
##  1 canid -0.524 -0.0149   1.87    0.950   0.0233  0.651  0.450 0        0    
##  2 canid -0.524  1.25     1.79    0.634  -0.144   0.784  1.42  0        0    
##  3 felid -0.524 -0.156    0.112  -0.722  -0.715  -0.169  0.285 1        1    
##  4 canid -0.524 -0.297   -0.0659 -0.182  -0.243  -0.332  1.48  1        1    
##  5 canid  0.858 -0.438    0.595  -0.485  -0.627   1.26   1.20  0        0    
##  6 canid  0.167 -0.156    0.722   0.0679 -0.262   0.501  0.645 0        1    
##  7 felid  1.55  -1.00    -0.676  -1.27   -1.07    0.562 -1.10  0        1    
##  8 felid  2.93  -1.14     0.900  -0.538  -0.715   1.31  -0.975 0        0    
##  9 felid -0.524  0.408   -0.218   0.107   0.0528  0.679 -0.855 0        0    
## 10 felid -1.21   3.09    -0.0913 -0.400  -0.410   0.204  0.112 1        1    
## # … with 81 more rows, and 1 more variable: segmented <fct>
scat %>% select(where(is_double)) %>% cor %>% corrplot::corrplot(method = "number")

Part 1A: logistic regression using MCC (discrete selection)

In this part, we perform exhaustive model selection, fitting all 1024 possible combinations of predictors. The models are Bernoulli-distributed generalized linear models (\(\texttt{logit}\) link), fit using maximum likelihood via the base function glm. Model performance is compared using the Matthews correlation coefficient (MCC). To mitigate overfitting, the modified one-standard-error (OSE) rule is applied to select a preferred model.

Generate list of model formulas and dimensions

We use the MuMIn package to generate the model formulas:

forms_all <- glm(y~., data = scat, family = binomial) %>% MuMIn::dredge(evaluate = F) %>% map(as.formula)
dim_all <- forms_all %>% map_dbl(~ .x %>% terms %>% attr("variables") %>% {length(.)-2})

For example, the tenth model is

forms_all[[10]]
## y ~ cn + location + 1

Creating a confusion matrix and computing MCC

The predictive performance of a binary (two-class) classifier can be summarised as a \(2\times2\) matrix called a confusion matrix. Labeling one class positive and the other negative, the matrix entries are the counts of the four prediction outcomes: true positives (TP), false positives (FP), false negatives (FN), and true negatives (TN):
\[\left(\begin{array}{cc}\mathrm{TP}&\mathrm{FP}\\\mathrm{FN}&\mathrm{TN}\end{array}\right)\]

MCC is computed from the entries of a confusion matrix and takes values between \(-1\) and \(+1\), where \(+1\) indicates perfect prediction, \(0\) indicates prediction no better than random, and \(-1\) indicates the worst possible prediction: \[\mathrm{MCC} = \frac{TP\cdot TN - FP\cdot FN}{\sqrt{(TP+FP)(TP+FN)(TN+FP)(TN+FN)}}\]

A confusion matrix is easily generated from the predictions of a fitted model object, but care must be taken to place the TP count in the top left (although this does not affect symmetric metrics like MCC). Using the prevalence as the probability threshold for assigning the response category (logistic regression models class probabilities), a confusion matrix can be computed as follows:

fit <- glm(forms_all[[10]], family = binomial, data = scat)
# compute the prevalence to use as the probability threshold
thresh <- scat %>% pull(y) %>% {. == "felid"} %>% {sum(.)/length(.)} # 0.549
# re-order levels so that TRUE (1) precedes FALSE (0)
pred <- as.numeric(predict(fit, type = "response") >= thresh) %>% factor(levels = c(1,0))
obs <- as.numeric(fit$data$y != levels(fit$data$y)[1]) %>% factor(levels = c(1,0))
cm <- tibble(pred,obs) %>% table
cm
##     obs
## pred  1  0
##    1 37 14
##    0 13 27

Now we define a function to compute MCC:

# computes MCC from a confusion matrix `cm`
mcc <- function(cm){
  (cm[1,1]*cm[2,2] - cm[1,2]*cm[2,1])/
    (sqrt((cm[1,1]+cm[1,2])*(cm[1,1]+cm[2,1])*(cm[2,2]+cm[1,2])*(cm[2,2]+cm[2,1])))
}

mcc(cm)
## [1] 0.3995122
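
As a check, substituting the entries of this confusion matrix (TP = 37, FP = 14, FN = 13, TN = 27) into the MCC formula gives

\[\mathrm{MCC} = \frac{37\cdot 27 - 14\cdot 13}{\sqrt{(37+14)(37+13)(27+14)(27+13)}} = \frac{817}{\sqrt{4182000}} \approx 0.3995,\]

in agreement with the value returned by mcc(cm).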

Using cross validation to estimate predictive MCC

So far we have computed an MCC estimate using within-sample data. To obtain an out-of-sample, or predictive, estimate of MCC we use repeated \(k\)-fold CV. For each repeat, we generate a single confusion matrix and corresponding MCC estimate. We use the rsample package (part of the tidymodels package set) to perform the data splitting: \(10\)-fold CV, repeated 50 times.

set.seed(770); cv_data_1 <- vfold_cv(scat,10,50) %>% mutate(thresh = thresh)
cv_data_1
## #  10-fold cross-validation repeated 50 times 
## # A tibble: 500 × 4
##    splits          id       id2    thresh
##    <list>          <chr>    <chr>   <dbl>
##  1 <split [81/10]> Repeat01 Fold01  0.549
##  2 <split [82/9]>  Repeat01 Fold02  0.549
##  3 <split [82/9]>  Repeat01 Fold03  0.549
##  4 <split [82/9]>  Repeat01 Fold04  0.549
##  5 <split [82/9]>  Repeat01 Fold05  0.549
##  6 <split [82/9]>  Repeat01 Fold06  0.549
##  7 <split [82/9]>  Repeat01 Fold07  0.549
##  8 <split [82/9]>  Repeat01 Fold08  0.549
##  9 <split [82/9]>  Repeat01 Fold09  0.549
## 10 <split [82/9]>  Repeat01 Fold10  0.549
## # … with 490 more rows

The returned rsample object is a tibble, and we add a column to store the probability threshold (here fixed at the prevalence, 0.549, but we will allow it to be tuned in a subsequent analysis). The column splits contains the \(500\) (\(10\times50\)) splits of the original data. For each split, the training and test data can be extracted using the functions analysis and assessment, respectively. For example:

split <- cv_data_1$splits[[1]] # take the 1st of 500 splits
analysis(split)
## # A tibble: 81 × 11
##    y     number  length diameter   taper      ti   mass     cn location ropey
##    <fct>  <dbl>   <dbl>    <dbl>   <dbl>   <dbl>  <dbl>  <dbl> <fct>    <fct>
##  1 canid -0.524 -0.0149   1.87    0.950   0.0233  0.651  0.450 0        0    
##  2 canid -0.524  1.25     1.79    0.634  -0.144   0.784  1.42  0        0    
##  3 felid -0.524 -0.156    0.112  -0.722  -0.715  -0.169  0.285 1        1    
##  4 canid -0.524 -0.297   -0.0659 -0.182  -0.243  -0.332  1.48  1        1    
##  5 canid  0.858 -0.438    0.595  -0.485  -0.627   1.26   1.20  0        0    
##  6 canid  0.167 -0.156    0.722   0.0679 -0.262   0.501  0.645 0        1    
##  7 felid  1.55  -1.00    -0.676  -1.27   -1.07    0.562 -1.10  0        1    
##  8 felid  2.93  -1.14     0.900  -0.538  -0.715   1.31  -0.975 0        0    
##  9 felid -1.21   3.09    -0.0913 -0.400  -0.410   0.204  0.112 1        1    
## 10 canid -1.21  -0.438   -1.39   -0.841  -0.459  -0.146  3.03  1        1    
## # … with 71 more rows, and 1 more variable: segmented <fct>
assessment(split)
## # A tibble: 10 × 11
##    y     number length diameter   taper      ti    mass      cn location ropey
##    <fct>  <dbl>  <dbl>    <dbl>   <dbl>   <dbl>   <dbl>   <dbl> <fct>    <fct>
##  1 felid -0.524  0.408   -0.218  0.107   0.0528  0.679  -0.855  0        0    
##  2 canid -1.21   0.267   -1.59  -1.03   -0.617  -1.45   -0.213  1        1    
##  3 felid -1.21  -0.156    0.824 -1.06   -1.06    0.0215  0.0135 1        0    
##  4 canid -1.21   0.690   -0.930  0.476   0.762  -0.237   0.720  0        1    
##  5 felid  0.858  0.690    0.773 -1.02   -1.03    0.884  -0.312  0        0    
##  6 felid  0.167  0.267    0.925  0.641   0.0823  0.875  -0.262  0        1    
##  7 felid -0.524 -1.00    -0.803 -0.452  -0.252  -1.15    0.410  1        0    
##  8 canid  0.167 -1.00    -0.651 -1.64    0.407  -0.887  -0.573  0        1    
##  9 canid -1.21   1.25    -0.346  0.753   0.673  -1.05   -0.363  0        1    
## 10 felid -0.524 -0.579    0.875 -0.0836 -0.400   0.223  -0.739  1        1    
## # … with 1 more variable: segmented <fct>

The following function takes a repeated \(k\)-fold rsample object, a model formula, and a metric (e.g., MCC) and returns a tibble containing a metric estimate for each \(k\)-fold repeat.

fit_confusion_glm <- function(cv_data, formula = NULL, metric){
  # fit a glm to the training (analysis) data of a split and return the out-of-sample
  # confusion-matrix entries for the test (assessment) data
  get_confusion_matrix <- function(form, split, thresh){
    glm(form, data = analysis(split), family = binomial()) %>% 
      {tibble(pred = predict(.,newdata = assessment(split), type = "response") >= thresh, 
              y = assessment(split)$y != levels(assessment(split)$y)[1])} %>% 
      mutate(across(everything(), ~ .x %>% as.numeric() %>% factor(levels = c(1,0)))) %>% 
      table %>% {c(tp = .[1,1], fp = .[1,2], fn = .[2,1], tn = .[2,2])}
  }
  # use the supplied formula, or the per-split formulas stored in the rsample object (see Part 3)
  if(!is.null(formula)) form <- list(formula) else form <- cv_data$form_glm
  # sum the fold-level entries within each repeat, then compute the metric
  with(cv_data, mapply(get_confusion_matrix, form, splits, thresh, SIMPLIFY = T)) %>% t %>% 
    as_tibble %>% 
    bind_cols(rep = cv_data$id, .) %>% 
    group_by(rep) %>% 
    summarise(metric = matrix(c(sum(tp),sum(fn),sum(fp),sum(tn)),2,2) %>% metric) %>% 
    mutate(rep = 1:n())
}

For example, applying the function to one of the models gives:

mcc_f10 <- fit_confusion_glm(cv_data_1, forms_all[[10]],mcc)
mcc_f10
## # A tibble: 50 × 2
##      rep metric
##    <int>  <dbl>
##  1     1  0.332
##  2     2  0.311
##  3     3  0.308
##  4     4  0.353
##  5     5  0.355
##  6     6  0.308
##  7     7  0.376
##  8     8  0.379
##  9     9  0.330
## 10    10  0.332
## # … with 40 more rows
mcc_f10 %>% pull(metric) %>% {c(mean = mean(.), sd = sd(.))}
##       mean         sd 
## 0.34423001 0.02514177

Model selection using the modified OSE rule

We estimate MCC for all 1024 models using the repeated \(k\)-fold data-splitting scheme created above.

mcc_est <- pbmclapply(forms_all, function(form) fit_confusion_glm(cv_data_1,form, mcc)$metric, mc.cores = 40) %>% bind_cols() # (220 seconds using 40 cores)

The following function takes the repeated CV MCC estimates and computes the following summary statistics for each model: the mean MCC, the standard error of the MCC (\(\sigma_m\)), the standard error of the MCC difference with respect to the best model (\(\sigma_m^{{\mathrm{diff}}}\)), and the modified standard error (\(\sigma_m^{{\mathrm{adj}}}\)) (see the manuscript for definitions):

# computes summary statistics for model score estimates
make_plot_data <- function(metric_data, levels){
  best_model <- metric_data %>% map_dbl(mean) %>% which.max() %>% names()
  tibble(model = factor(names(metric_data), levels = levels), 
         metric = metric_data %>% map_dbl(mean),
         se = metric_data %>% map_dbl(sd),
         se_diff = metric_data %>% map(~ .x - metric_data[[best_model]]) %>% map_dbl(sd),
         se_mod = sqrt(1 -cor(metric_data)[best_model,])*se[best_model])
}
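
In the notation used above (writing \(\mathrm{MCC}_{m,r}\) for the estimate for model \(m\) on CV repeat \(r\), and \(b\) for the best-scoring model), the quantities computed by this function are

\[\sigma_m = \mathrm{sd}_r\left(\mathrm{MCC}_{m,r}\right),\qquad \sigma_m^{\mathrm{diff}} = \mathrm{sd}_r\left(\mathrm{MCC}_{m,r}-\mathrm{MCC}_{b,r}\right),\qquad \sigma_m^{\mathrm{adj}} = \sqrt{1-\rho_{m,b}}\,\sigma_b,\]

where \(\rho_{m,b}\) is the correlation, across repeats, of the estimates for model \(m\) with those for model \(b\). This is simply a restatement of the code above; see the manuscript for the formal definitions.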

We apply the function to the MCC estimates and filter the results to retain only the best-performing model at each dimension (i.e., number of model parameters):

metric_summary <- make_plot_data(mcc_est, levels = names(mcc_est)) %>% mutate(dim = dim_all)
## Warning in cor(metric_data): the standard deviation is zero
metric_by_dim <- metric_summary %>% 
  group_by(dim) %>% 
  arrange(-metric) %>% 
  slice(1) %>% 
  ungroup() %>% 
  arrange(dim) %>% 
  filter(!dim %in% c(0,9)) # reduce model set to be plotted

Finally, we apply the modified OSE rule by determining the least complex model whose mean MCC estimate lies within \(\sigma_m^{{\mathrm{adj}}}\) of the best estimate. We plot the results.

best_mod_ose_dim <- metric_by_dim %>% 
  filter(metric + se_mod >= max(metric)) %>% 
  filter(dim == min(dim)) %>% pull(dim)

# main plot for scat example in manuscript
metric_by_dim %>% 
  ggplot(aes(factor(dim, levels = dim))) +
    geom_linerange(aes(ymin = metric - se_mod, ymax = metric + se_mod), col = "black") +
    geom_point(aes(y = metric), size = 2) +
    theme_classic() +
    geom_point(aes(y = metric), shape = 1, size = 6, col = "black", 
               data = ~ .x %>% filter(dim == best_mod_ose_dim)) + # circle the OSE-preferred model
    labs(subtitle = "CV estimates for scat models", x = "Number of parameters",y = "MCC")

The model with two predictors is identified using the modified OSE rule, but we can see that the one-predictor model is very close to being selected. The corresponding model formulas at each dimension are:

forms_all[metric_by_dim$model]
## $`2`
## y ~ cn + 1
## 
## $`34`
## y ~ cn + number + 1
## 
## $`42`
## y ~ cn + location + number + 1
## 
## $`52`
## y ~ cn + diameter + mass + number + 1
## 
## $`56`
## y ~ cn + diameter + length + mass + number + 1
## 
## $`64`
## y ~ cn + diameter + length + location + mass + number + 1
## 
## $`320`
## y ~ cn + diameter + length + location + mass + number + taper + 
##     1
## 
## $`480`
## y ~ cn + diameter + length + location + mass + ropey + segmented + 
##     taper + 1
## 
## $`1024`
## y ~ cn + diameter + length + location + mass + number + ropey + 
##     segmented + taper + ti + 1

Part 1B: logistic regression using MCC (penalised regression)

As an alternative to discrete selection, we take the global logistic model and use penalised regression to shrink or regularise the parameter estimates. We use elastic-net regression, which comprises a family of penalised regression models indexed by a hyperparameter \(\alpha\), \(0\leq\alpha\leq1\), for which the extreme values \(\alpha = 1\) and \(\alpha=0\) correspond to LASSO (least absolute shrinkage and selection operator) and ridge regression, respectively. The LASSO is able to shrink some parameter estimates all the way to zero, performing effective variable selection, whereas ridge regression shrinks all parameters towards, but never exactly to, zero. Mathematically, the regularised parameters \(\widehat{\boldsymbol\theta} =(\widehat\theta_1,\widehat\theta_2,...,\widehat\theta_p)\) are those minimising the penalised function

\[\begin{equation} \label{eq:elastic_net} f(\mathbf{x};\mathbf{\boldsymbol{\widehat\theta}}) + \lambda\left(\alpha||\boldsymbol{\widehat\theta}||^1 + \frac{(1-\alpha)}{2}||\boldsymbol{\widehat\theta}||^2\right), \end{equation}\]

where \(f\) is the objective function (usually mean squared error or negative log density), and \(||\boldsymbol\theta||^1 = \sum_{j=1}^p |\theta_j|\) and \(||\boldsymbol\theta||^2 = \sum_{j=1}^p \theta_j^2\) are the \(L_1\)- and \(L_2\)-norm of the vector of model parameters, respectively. The regularisation parameter \(\lambda\) determines the strength of the penalty, implementing a trade-off between the size of the model’s parameter estimates (the shrinkage or effective complexity) and the minimised value of the (unconstrained) objective function \(f\).
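
In particular, setting \(\alpha=1\) reduces the penalty to \(\lambda\,||\boldsymbol{\widehat\theta}||^1\) (the lasso), while \(\alpha=0\) leaves only \(\frac{\lambda}{2}||\boldsymbol{\widehat\theta}||^2\) (ridge regression).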

Load package and perform an initial fit

Elastic-net regression is implemented in the R package glmnet. The main CV fitting function, cv.glmnet, accepts a fixed value of \(\alpha\) (alpha), the GLM specification for the global model, the number of folds (nfolds), and the loss function or measure to be cross validated (type.measure, five options available). A vector of candidate values of the regularisation parameter \(\lambda\) (lambda) can be supplied, otherwise a default set is generated. For example:

library(glmnet)

fit.net <- cv.glmnet(x = scat %>% select(-y) %>% as.data.frame() %>% makeX(), 
                     y = scat %>% pull(y), 
                     family = "binomial", 
                     nfolds = 10,
                     type.measure = "class", # misclassification loss
                     alpha = 1) # lasso
fit.net %>% plot

fit.net$glmnet.fit %>% plot(xvar = "lambda")




Penalised LASSO using MCC

Unfortunately, it is not possible to use a custom-defined loss function in glmnet, but with a little effort we can still use the package to perform penalised regression for our chosen metric: MCC.

We start by defining a function to compute the predictive confusion matrix across a supplied range of \(\lambda\) values for a given rsample split.

# compute out-of-sample (i.e., CV) confusion-matrix entries for all lambda values for a given split
get_cm <- function(split, rep, fold, alpha, lambda, thresh){
  glmnet(x = analysis(split) %>% select(-y) %>% as.data.frame() %>% makeX(),
         y = analysis(split) %>% pull(y),
         family = "binomial",
         alpha = alpha,
         lambda = lambda) %>% 
    predict(newx = assessment(split) %>% select(-y) %>% as.data.frame() %>% makeX(),
            type = "response") %>%
    as_tibble %>% 
    rename_with(~str_remove(.x,stringr::fixed("s"))) %>% 
    imap_dfr(~ tibble(y_pred = as.numeric(.x > thresh), 
                 y = assessment(split) %>% pull(y) %>% as.numeric() %>% {.-1},
                 tp = as.numeric(y==1 & y_pred==1),
                 fp = as.numeric(y==0 & y_pred==1),
                 fn = as.numeric(y==1 & y_pred==0),
                 tn = as.numeric(y==0 & y_pred==0),
                 rep = rep, fold = fold, lambda = lambda[as.numeric(.y)+1]))
}

For lasso (alpha = 1), we specify (after a little trial and error) the following sequence of candidate \(\mathrm{log}\,\lambda\) values:

log_lambda_lasso <- seq(-1.6,-4, length.out = 100)

and we compute the confusion matrices over all \(\lambda\) for the full set of splits cv_data_1 used in the previous analysis.

# compute confusion-matrix entries and aggregate across folds within a given repetition
# 2 seconds with 40 cores
cm_lasso <- pbmclapply(1:nrow(cv_data_1), 
                       function(i) get_cm(cv_data_1$splits[[i]], cv_data_1$id[[i]], cv_data_1$id2[[i]], 
                                          alpha = 1, 
                                          lambda = exp(log_lambda_lasso),
                                          thresh = thresh),
                       mc.cores = 40) %>% 
  bind_rows() %>% as_tibble() 

Within each repetition and \(\lambda\) value, we sum the confusion-matrix entries and compute MCC:

mcc_lasso <- cm_lasso %>% 
  group_by(rep, lambda) %>% 
  summarise(metric = matrix(c(sum(tp),sum(fp),sum(fn),sum(tn)),2,2) %>% mcc)
## `summarise()` has grouped output by 'rep'. You can override using the `.groups`
## argument.
mcc_lasso # a row for each repetition for each lambda value
## # A tibble: 5,000 × 3
## # Groups:   rep [50]
##    rep      lambda metric
##    <chr>     <dbl>  <dbl>
##  1 Repeat01 0.0183  0.334
##  2 Repeat01 0.0188  0.334
##  3 Repeat01 0.0193  0.334
##  4 Repeat01 0.0198  0.334
##  5 Repeat01 0.0203  0.334
##  6 Repeat01 0.0208  0.334
##  7 Repeat01 0.0213  0.334
##  8 Repeat01 0.0219  0.334
##  9 Repeat01 0.0224  0.334
## 10 Repeat01 0.0230  0.334
## # … with 4,990 more rows

To aid plot creation we compute the values of lambda that maximise the score:

# lambda value at best mean MCC estimate
lambda_best_lasso <- mcc_lasso %>% 
  group_by(lambda) %>%  
  summarise(metric = mean(metric)) %>% 
  filter(metric == max(metric, na.rm = T)) %>% 
  pull(lambda)

# MCC scores for all reps at best lambda value
mcc_best_lasso <- mcc_lasso %>% 
  filter(lambda == lambda_best_lasso) %>%
  arrange(lambda,rep) %>% pull(metric)

To select \(\lambda\), we compute the (modified) standard errors

metric_plot_data <- mcc_lasso %>% 
  group_by(lambda) %>% 
  arrange(lambda,rep) %>%
  summarise(se_ose = sd(metric),
            se_diff  = sd(metric - mcc_best_lasso),
            se_best = sd(mcc_best_lasso),
            rho_best_m = (se_diff^2 - se_ose^2 - se_best^2)/(-2*se_ose*se_best),
            se_mod = sqrt(1-(rho_best_m))*se_best,
            metric_diff = mean(metric - mcc_best_lasso),
            metric = mean(metric))

metric_mod_ose <- metric_plot_data %>% 
    filter(metric + se_mod > max(metric, na.rm = T)) %>% 
    filter(lambda == max(lambda))

metric_ose <- metric_plot_data %>% 
    filter(se_ose >= abs(metric_diff)) %>% 
    filter(lambda == max(lambda))

which allows us to plot the tuning results and apply the modified OSE rule:

plot_lasso <- 
  metric_plot_data %>% 
  ggplot(aes(x = log(lambda))) +
  geom_vline(aes(xintercept = metric_mod_ose$lambda %>% log), col = "grey70", lty = "dashed") +
  geom_vline(aes(xintercept = lambda_best_lasso %>% log), col = "grey70", lty = "dashed") +  
  geom_linerange(aes(ymax = metric + se_mod, ymin = metric -se_mod, col = "m-OSE"), size = 0.5, 
                 data = metric_mod_ose, position = position_nudge(0.0)) +
  geom_linerange(aes(ymax = metric + se_ose, ymin = metric - se_ose, col = "OSE"), size = 0.5, 
                 data = metric_ose, position = position_nudge(0.0)) +
  geom_point(aes(y = metric), col = "grey60") + 
  geom_point(aes(y = metric), col = "grey10", data = ~ .x %>% filter(metric == max(metric))) + 
  geom_line(aes(y = metric), col = "grey30", lty = "solid") + 
  geom_point(aes(y = metric), shape = 1, size = 6, data = metric_mod_ose, col = "black") +
  geom_point(aes(y = metric), size = 1, data = metric_mod_ose) +
  geom_point(aes(y = metric), size = 1, data = ~ .x %>% filter(lambda == lambda_best_lasso), col = "black") +
  labs(subtitle = NULL, x = NULL, y = "MCC") +
  scale_colour_manual(name = NULL, values = c("black",blues9[8]), 
                      labels = c(modOSE = expression(sigma^"diff"),OSE = expression(sigma^"best"))) +
  theme_classic() +
  theme(panel.grid = element_blank()) 
plot_lasso

Note that the ordinary OSE based on \(\sigma^{\mathrm{best}}\) selects a more heavily regularised model than the modified OSE based on \(\sigma^{\mathrm{mod}}\) (see manuscript for further details).

To see the effect of \(\lambda\) on the parameter estimates we refit the penalised regression models (using all of the data) and plot the corresponding trajectories:

fit_glmnet <- glmnet(x = scat %>% select(-y) %>% as.data.frame() %>% makeX(),
       y = scat$y,
       family = "binomial",
       lambda = exp(log_lambda_lasso),
       alpha = 1)

plot_est <- fit_glmnet$beta %>% as.matrix %>% t %>% as_tibble() %>% 
  mutate(lambda = log_lambda_lasso) %>%  #, metric = metric_plot_data$metric) %>% 
  pivot_longer(!any_of(c("lambda","metric")), values_to = "estimate", names_to = "coefficient") %>% 
  filter(estimate != 0 | coefficient == "cn") %>% 
  ggplot(aes(lambda,estimate)) + 
  geom_vline(aes(xintercept = metric_mod_ose$lambda %>% log), col = "grey70", lty = "dashed") +
  geom_vline(aes(xintercept = lambda_best_lasso %>% log), col = "grey70", lty = "dashed") +
  geom_line(aes(col = coefficient)) +
  geom_point(aes(col = coefficient), size =1, data = ~ .x %>% filter(lambda == metric_mod_ose$lambda %>% log)) +
  geom_hline(aes(yintercept = 0), lty = "longdash") +
  labs(y = "Parameter estimates", x = expression(paste("log(",lambda,")"))) +
  theme_classic() +
  theme(legend.position = "none")

plot_est

The above code is easily re-used to compute ridge estimates (i.e., \(\alpha= 0\)), or indeed any other \(\alpha\) values within the elastic-net family.
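
For instance, a minimal sketch of the ridge analogue re-uses get_cm with alpha = 0; the \(\log\,\lambda\) grid below is illustrative only and would need its own trial-and-error adjustment:

log_lambda_ridge <- seq(1, -4, length.out = 100) # illustrative grid; ridge typically needs larger lambda values

cm_ridge <- pbmclapply(1:nrow(cv_data_1), 
                       function(i) get_cm(cv_data_1$splits[[i]], cv_data_1$id[[i]], cv_data_1$id2[[i]], 
                                          alpha = 0, 
                                          lambda = exp(log_lambda_ridge),
                                          thresh = thresh),
                       mc.cores = 40) %>% 
  bind_rows() %>% as_tibble()

mcc_ridge <- cm_ridge %>% 
  group_by(rep, lambda) %>% 
  summarise(metric = matrix(c(sum(tp),sum(fp),sum(fn),sum(tn)),2,2) %>% mcc, .groups = "drop")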

Tuning \(\alpha\) in elastic-net regression

Finally, to tune \(\alpha\) as well as \(\lambda\), we can nest the call to get_cm:

  alpha <- seq(0, 1, 0.025) # candidate alpha values
  lambda <- exp(log_lambda_lasso) # candidate lambda values (re-using the lasso grid; widen if needed)
  MAX_CORES <- 40 # number of cores to use

  metric_alpha <- lapply(alpha, function(a){
    pbmclapply(1:nrow(cv_data_1), function(i) get_cm(cv_data_1$splits[[i]], 
                                                     cv_data_1$id[[i]], 
                                                     cv_data_1$id2[[i]], 
                                                     alpha = a, lambda, thresh),
               mc.cores = MAX_CORES) %>% 
      bind_rows %>% 
      group_by(rep, lambda) %>% 
      summarise(cm11 = sum(tp), cm12 = sum(fp), cm21 = sum(fn), cm22 = sum(tn), .groups = "keep") %>% 
      summarise(metric = (cm11*cm22 - cm12*cm21)/(sqrt((cm11+cm12)*(cm11+cm21)*(cm22+cm12)*(cm22+cm21))), .groups = "drop") %>% 
      group_by(lambda) %>%  
      summarise(metric = mean(metric)) %>% 
      filter(metric == max(metric, na.rm = T)) %>% 
      pull(metric) # return best mean metric estimate among all lambda values
    })

The results suggest that LASSO (\(\alpha=1\)) and ridge (\(\alpha=0\)) have comparable performance, whereas the elastic net estimate \(\alpha = 0.175\) is clearly superior:

  tibble(metric = metric_alpha %>% unlist, alpha = alpha) %>% 
    ggplot(aes(alpha,metric)) + geom_line() + theme_classic() +
    geom_point(data = ~ .x %>% filter(metric == max(metric))) +
    labs(subtitle = "Tuning alpha for elastic-net regularisation", y = "MCC")




Part 2: Regularising priors and Bayesian projective inference

Reference model

In a Bayesian setting, lasso-type and ridge-type regression can be implemented through an appropriate choice of prior distribution. The direct analogue of the frequentist lasso is the Laplacian (two-sided exponential) prior (Hastie et al. 2015), and Gaussian priors are equivalent to ridge regression. An alternative choice is the regularised horseshoe prior, which provides support for a proper subset of parameters to be far from zero, using prior information to guide both the number and the degree of regularisation of the non-zero parameter estimates.

Here we fit four models, using a different prior for each, and we compare their predictive performance using Pareto-smoothed importance-sampling (PSIS) leave-one-out (LOO) CV, an easily computed approximate CV method for use with fitted Bayesian models. The four priors are:

  1. uniform prior (improper),
  2. weakly-informative Gaussian priors (i.e., weak ridge-type regression),
  3. Laplacian priors (i.e., LASSO), and
  4. regularised horseshoe priors.

We fit the models using the package \(\texttt{brms}\): a front-end for the \(\texttt{Stan}\) engine, which implements state-of-the-art Hamiltonian Monte Carlo (HMC) methods combined with the no-U-turn sampler (NUTS) for fast, efficient model fitting.

  fit.flat <- brm(y ~ ., family=bernoulli(), 
                   data=scat, chains=4, iter=2000, cores = 4,
                   save_pars = save_pars(all = TRUE))

  fit.ridge <- brm(y ~ ., 
                   family = bernoulli(), 
                   prior = prior(normal(0,10)), 
                   data = scat, 
                   chains = 4, 
                   iter = 2000, 
                   cores = 4)

  fit.lasso <- brm(y ~ ., 
                   family = bernoulli(), 
                   data = scat,
                   prior = set_prior("lasso(1)"),
                   chains = 4, 
                   iter = 2000, 
                   cores = 4)
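
  # NOTE: tau0, the global scale of the horseshoe prior used below, is not defined in
  # this excerpt. A common choice (an assumption here, following Piironen & Vehtari,
  # not necessarily the value used in the original script) is
  # tau0 = p0/(p - p0)/sqrt(n), with p0 a prior guess of the number of relevant predictors:
  p0 <- 2                    # assumed number of relevant predictors
  p <- length(vars)          # number of candidate predictors (10)
  n <- nrow(scat)            # number of observations (91)
  tau0 <- p0/(p - p0)/sqrt(n)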

  fit.hs <- brm(y ~ ., family=bernoulli(), data=scat,
                  prior=prior(horseshoe(scale_global = tau0, scale_slab = 1), class=b),
                  chains=4, iter=2000, cores = 4)

The models took about half a minute to compile, but less than 2 seconds to sample.

Combining the models into a list, we run PSIS-LOO and compare the results:

fits <- list(hs = fit.hs, lasso = fit.lasso, ridge = fit.ridge, flat = fit.flat)
fits %>% map(loo, cores = 10) %>% loo_compare()
## Warning: Found 2 observations with a pareto_k > 0.7 in model '.x[[i]]'. It is
## recommended to set 'moment_match = TRUE' in order to perform moment matching for
## problematic observations.
## Warning: Found 3 observations with a pareto_k > 0.7 in model '.x[[i]]'. It is
## recommended to set 'moment_match = TRUE' in order to perform moment matching for
## problematic observations.
##       elpd_diff se_diff
## hs     0.0       0.0   
## lasso -2.1       0.9   
## ridge -7.9       3.4   
## flat  -8.4       3.5

Although the LOO scores have been estimated for all four models, the pareto_k > 0.7 diagnostic tells us that the LOO approximation may have failed for a few data points in the less-regularised models. To address this, we refit the model (reloo = T) for each problematic observation, i.e., manually leaving out the flagged test points:

future::plan("multisession", workers = 10)
fits.loo <- fits %>% map(loo, reloo = T, future = T)

Using the future package for parallelisation, the models are quickly refit, giving the adjusted summary:

fits.loo %>% loo_compare
##       elpd_diff se_diff
## hs     0.0       0.0   
## lasso -2.1       0.9   
## ridge -8.0       3.4   
## flat  -8.5       3.6

The horseshoe model is the best performing of the four, and it is also the most regularised. The degree of regularisation can be seen by comparing the effective number of parameters (p_loo) between models:

fits.loo %>% loo_compare %>% print(simplify = F)
##       elpd_diff se_diff elpd_loo se_elpd_loo p_loo se_p_loo looic se_looic
## hs      0.0       0.0   -52.0      4.1         3.8   0.5    104.0   8.3   
## lasso  -2.1       0.9   -54.1      4.4         7.0   0.9    108.1   8.7   
## ridge  -8.0       3.4   -60.0      6.8        14.7   2.2    120.0  13.5   
## flat   -8.5       3.6   -60.5      6.9        15.2   2.3    121.0  13.7

It is interesting to ‘see’ the effect of the regularising priors by examining the (marginal) posterior density plots.

library(bayesplot)
# plot posteriors 
plot_post <- function(fit){
  fit %>% as_tibble %>% select(-starts_with("l")) %>% 
    mcmc_areas(area_method = "scaled", prob_outer = 0.98) +
    xlim(c(-2.8,2)) +
    theme_classic()
}

ggpubr::ggarrange(
  fit.flat %>% plot_post + labs(subtitle = "Uninformative (flat)"), 
  fit.ridge %>% plot_post + labs(subtitle = "Weakly informative (ridge)"),
  fit.lasso %>% plot_post + labs(subtitle = "LASSO"),
  fit.hs %>% plot_post + labs(subtitle = "Horseshoe"),
  ncol = 1,
  labels = "AUTO"
) 

Projective Inference

Having specified, fitted, and compared the candidate reference models, the projective-inference step is easily implemented using the \(\texttt{projpred}\) package. Taking the horseshoe variant as the best available reference model, we use cross-validated (LOO) forward step selection.

library(projpred)
vs <- cv_varsel(fit.hs, cv_method = "LOO", method = "forward")
vs %>% 
  plot(stats = c("elpd"), deltas = T) + 
  theme_classic() +
  theme(strip.text = element_blank(),
        strip.background = element_blank(),
        legend.position = "none") +
  labs(y = expression(Delta*"ELPD"))

solution_terms(vs)
##  [1] "cn"        "number"    "segmented" "taper"     "location"  "mass"     
##  [7] "ropey"     "length"    "ti"        "diameter"

The one-predictor model (cn) is an obvious choice, and its selection is what we would have expected after seeing the posterior distribution of the reference model. The last step in the Bayesian projective inference (BPI) process is to project the posterior of the reference model onto the selected submodel (here, using all 4000 draws):

proj <- project(vs, nterms = 1, ndraws = 4000)

To see the effect of the projection, we first fit the same one-predictor submodel to the original data set (as if we had decided a priori that this was the preferred model):

fit.cn <- brm(y ~ cn, family=bernoulli(), data=scat,chains=4, iter=2000, cores = 4)

Finally, we plot the (marginal) posteriors of the cn parameter for the reference (ref), projected (proj), and non-projected (non-proj) models:

  tibble(ref = fit.hs %>% as_tibble() %>% pull(b_cn),
       proj = proj %>% as.matrix %>% {.[,"b_cn"]},
       `non-proj` = fit.cn %>% as_tibble() %>% pull(b_cn)) %>% 
  mcmc_areas() +
  labs(x = expression(theta["carbon-nitrogen ratio"])) +
  theme_classic()

The larger effect size of the non-projected model can be interpreted as a selection-induced bias, arising due to a failure to account for selection uncertainty.

Part 3: Nested CV to compare logistic and random forest models

In this section we illustrate nested CV. Comparing parametric models to ‘black-box’ machine-learning models is not commonly done, and for our choice of metric (MCC) the implementation requires some custom functions. Rather than show all of the gory details, we source a separate script and outline the basic steps (see the GitHub repo for the script).

To get started, we generate the nested data splits using the \(\texttt{rsample}\) package

set.seed(7193) # sample(1e4,1)
cv_data_2 <- nested_cv(scat, outside = vfold_cv(v = 10, repeats = 50), inside = vfold_cv(v = 10))
vars <- names(scat %>% select(-y)); vars

The generated data-splitting scheme is 10-fold CV repeated 50 times (the outer layer), where each of these 500 outer training sets contains a second (inner) layer of 10-fold CV splits; there are 5000 inner training sets altogether. The inner splits (training and test) are used to tune the hyperparameters of the model for each outer training set. Each tuned model is then fit to the corresponding outer training set, after which prediction to the outer test set and loss calculations proceed in the usual manner.
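
To make the inner tuning loop concrete, the following is a minimal sketch of the kind of search tune_rf performs for a single outer training set. It is an illustration only, not the sourced implementation: it assumes the randomForest package, treats felid as the positive class, re-uses the fixed prevalence threshold thresh, and ignores the probability threshold that tune_rf also tunes.

library(randomForest)

tune_mtry_one_split <- function(inner_splits, mtry_vec, ntree, thresh, metric = mcc){
  scores <- map_dbl(mtry_vec, function(m){
    cm <- map(inner_splits$splits, function(split){
      fit <- randomForest(y ~ ., data = analysis(split), mtry = m, ntree = ntree)
      prob <- predict(fit, newdata = assessment(split), type = "prob")[, "felid"]
      pred <- factor(as.numeric(prob >= thresh), levels = c(1, 0))
      obs  <- factor(as.numeric(assessment(split)$y == "felid"), levels = c(1, 0))
      table(pred, obs)
    }) %>% reduce(`+`) # pool the fold-level confusion matrices
    metric(cm)
  })
  mtry_vec[which.max(scores)] # return the best candidate value
}

# e.g., tune mtry for the first outer training set (mtry_vec and ntree are defined below):
# tune_mtry_one_split(cv_data_2$inner_resamples[[1]], mtry_vec, ntree, thresh)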

We load the pre-written functions and perform the hyperparameter tuning first. The random forest models require the number of trees (ntree), which we fix, and mtry, the number of predictors sampled as split candidates: the latter is the hyperparameter that we tune.

source("files_scat/scat_funs.R")

# specify RF hyper-parameters
ntree <- 500 # set a fixed value
mtry_vec <- 1:(ncol(scat)-1)  # candidate mtry values

# run the tuning functions
rf_tune_values <- tune_rf(cv_data_2, mtry_vec, ntree, metric = mcc)  # 1:17 seconds for 50 repeats with 40 cores
glm_step_tune_values <- tune_glm_step(cv_data_2, vars, metric = mcc) # 1:40 seconds for 50 repeats with 40 cores

# add tuned hyperparameters values to the rsample object
cv_data_2$thresh_rf <- rf_tune_values$threshold
cv_data_2$thresh <- glm_step_tune_values$threshold
cv_data_2$mtry_rf <- rf_tune_values$mtry
cv_data_2$form_glm <- glm_step_tune_values$form

Let’s have a look at the rsample object with the tuned values added:

cv_data_2
## # Nested resampling:
## #  outer: 10-fold cross-validation repeated 50 times
## #  inner: 10-fold cross-validation
## # A tibble: 500 × 10
##    splits          id       id2    inner_resamples     thresh_rf thresh_glm_step
##    <list>          <chr>    <chr>  <list>                  <dbl>           <dbl>
##  1 <split [81/10]> Repeat01 Fold01 <vfold_cv [10 × 2]>      0.58            0.19
##  2 <split [82/9]>  Repeat01 Fold02 <vfold_cv [10 × 2]>      0.46            0.67
##  3 <split [82/9]>  Repeat01 Fold03 <vfold_cv [10 × 2]>      0.44            0.34
##  4 <split [82/9]>  Repeat01 Fold04 <vfold_cv [10 × 2]>      0.55            0.55
##  5 <split [82/9]>  Repeat01 Fold05 <vfold_cv [10 × 2]>      0.47            0.15
##  6 <split [82/9]>  Repeat01 Fold06 <vfold_cv [10 × 2]>      0.43            0.57
##  7 <split [82/9]>  Repeat01 Fold07 <vfold_cv [10 × 2]>      0.48            0.59
##  8 <split [82/9]>  Repeat01 Fold08 <vfold_cv [10 × 2]>      0.25            0.27
##  9 <split [82/9]>  Repeat01 Fold09 <vfold_cv [10 × 2]>      0.45            0.57
## 10 <split [82/9]>  Repeat01 Fold10 <vfold_cv [10 × 2]>      0.44            0.33
## # … with 490 more rows, and 4 more variables: thresh_glm_all <dbl>,
## #   mtry_rf <int>, form_glm_step <chr>, form_glm_all <list>

Now we fit the models, which is much faster than the tuning step as we are only using the 500 outer folds. For the random forest models, we fit two alternatives: 1) only the subset of the most important (best) predictors is kept for each fold (keeping mtry predictors in total); and 2) all predictors are kept, but the tuned mtry values are still used for fitting.

  fits_rf_best <- fit_rf(cv_data_2, ntree = ntree, type = "best", metric = mcc) # 3 seconds with 40 cores
  fits_rf_all<- fit_rf(cv_data_2, ntree = ntree, type = "all", metric = mcc) # 2 seconds with 40 cores
  fits_glm_step <- fit_confusion_glm(cv_data_2, metric = mcc)

Finally, we collate and prepare the results

mcc_data_2 <- tibble(glm_step = fits_glm_step$metric,
                     rf_best = fits_rf_best$metric,
                     rf_all = fits_rf_all$metric)

mcc_plot_data_2 <- make_plot_data(mcc_data_2, names(mcc_data_2))

and plot the model-comparison figure

plot_model_comparisons(mcc_plot_data_2, "se_mod") +
  labs(title = "Model comparison", subtitle = "Stage 2: nested CV with tuned hyper-parameter, score = MCC")

The random forest models are better predictive models than the glm, with the best-subset variant the better performer of the two random forest alternatives. Interestingly, the tuned mtry values across the 500 outer folds were spread fairly uniformly over the range of possible values (1–10), suggesting that other predictors (not just cn) do add value, but that the linearity of the glm may have been too inflexible for their inclusion to be useful.

# table of the number of variables selected across the 500 outer folds
cv_data_2$mtry_rf %>% table # random forest
## .
##  1  2  3  4  5  6  7  8  9 10 
## 23 40 54 49 61 49 66 61 43 54
cv_data_2$form_glm_step %>% str_split(" +") %>% map_dbl(~ .x %in% vars %>% sum) %>% table # glm
## .
##   1   2   3   4   5   6   7   8   9  10 
##  39 110  96  94  71  58  18   9   4   1

Session Information

sessionInfo()
## R version 4.2.1 (2022-06-23)
## Platform: x86_64-pc-linux-gnu (64-bit)
## Running under: Ubuntu 20.04.5 LTS
## 
## Matrix products: default
## BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.9.0
## LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.9.0
## 
## locale:
##  [1] LC_CTYPE=en_AU.UTF-8       LC_NUMERIC=C              
##  [3] LC_TIME=en_AU.UTF-8        LC_COLLATE=en_AU.UTF-8    
##  [5] LC_MONETARY=en_AU.UTF-8    LC_MESSAGES=en_AU.UTF-8   
##  [7] LC_PAPER=en_AU.UTF-8       LC_NAME=C                 
##  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
## [11] LC_MEASUREMENT=en_AU.UTF-8 LC_IDENTIFICATION=C       
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] bayesplot_1.9.0    projpred_2.1.2     brms_2.17.0        Rcpp_1.0.9        
##  [5] glmnet_4.1-4       Matrix_1.4-1       knitr_1.39         kableExtra_1.3.4  
##  [9] yardstick_1.0.0    workflowsets_1.0.0 workflows_1.0.0    tune_1.0.0        
## [13] rsample_1.1.0      recipes_1.0.1      parsnip_1.0.0      modeldata_1.0.0   
## [17] infer_1.0.2        dials_1.0.0        scales_1.2.0       broom_1.0.0       
## [21] tidymodels_1.0.0   forcats_0.5.1      stringr_1.4.1      dplyr_1.0.9       
## [25] purrr_0.3.4        readr_2.1.2        tidyr_1.2.0        tibble_3.1.7      
## [29] ggplot2_3.3.6      tidyverse_1.3.2   
## 
## loaded via a namespace (and not attached):
##   [1] utf8_1.2.2           lme4_1.1-30          tidyselect_1.1.2    
##   [4] htmlwidgets_1.5.4    grid_4.2.1           munsell_0.5.0       
##   [7] codetools_0.2-18     DT_0.24              future_1.26.1       
##  [10] miniUI_0.1.1.1       withr_2.5.0          Brobdingnag_1.2-7   
##  [13] colorspace_2.0-3     highr_0.9            rstudioapi_0.14     
##  [16] stats4_4.2.1         ggsignif_0.6.3       listenv_0.8.0       
##  [19] labeling_0.4.2       rstan_2.21.5         DiceDesign_1.9      
##  [22] farver_2.1.1         bridgesampling_1.1-2 coda_0.19-4         
##  [25] parallelly_1.32.1    vctrs_0.4.1          generics_0.1.3      
##  [28] ipred_0.9-13         xfun_0.32            R6_2.5.1            
##  [31] markdown_1.1         gamm4_0.2-6          lhs_1.1.5           
##  [34] cachem_1.0.6         assertthat_0.2.1     promises_1.2.0.1    
##  [37] nnet_7.3-17          googlesheets4_1.0.0  gtable_0.3.0        
##  [40] globals_0.16.0       processx_3.7.0       timeDate_4021.104   
##  [43] rlang_1.0.4          systemfonts_1.0.4    splines_4.2.1       
##  [46] rstatix_0.7.0        gargle_1.2.0         checkmate_2.1.0     
##  [49] inline_0.3.19        yaml_2.3.5           reshape2_1.4.4      
##  [52] abind_1.4-5          modelr_0.1.8         threejs_0.3.3       
##  [55] crosstalk_1.2.0      backports_1.4.1      httpuv_1.6.5        
##  [58] tensorA_0.36.2       tools_4.2.1          lava_1.6.10         
##  [61] ellipsis_0.3.2       jquerylib_0.1.4      posterior_1.2.2     
##  [64] ggridges_0.5.3       plyr_1.8.7           base64enc_0.1-3     
##  [67] ps_1.7.1             prettyunits_1.1.1    ggpubr_0.4.0        
##  [70] rpart_4.1.16         cowplot_1.1.1        zoo_1.8-10          
##  [73] haven_2.5.0          fs_1.5.2             furrr_0.3.1         
##  [76] magrittr_2.0.3       colourpicker_1.1.1   reprex_2.0.1        
##  [79] GPfit_1.0-8          googledrive_2.0.0    mvtnorm_1.1-3       
##  [82] matrixStats_0.62.0   hms_1.1.1            shinyjs_2.1.0       
##  [85] mime_0.12            evaluate_0.16        xtable_1.8-4        
##  [88] shinystan_2.6.0      readxl_1.4.0         gridExtra_2.3       
##  [91] shape_1.4.6          rstantools_2.2.0     MuMIn_1.46.0        
##  [94] compiler_4.2.1       crayon_1.5.1         minqa_1.2.4         
##  [97] StanHeaders_2.21.0-7 htmltools_0.5.3      mgcv_1.8-40         
## [100] later_1.3.0          tzdb_0.3.0           RcppParallel_5.1.5  
## [103] lubridate_1.8.0      DBI_1.1.3            corrplot_0.92       
## [106] dbplyr_2.2.1         MASS_7.3-58.1        boot_1.3-28         
## [109] car_3.0-13           cli_3.3.0            parallel_4.2.1      
## [112] gower_1.0.0          igraph_1.3.3         pkgconfig_2.0.3     
## [115] xml2_1.3.3           foreach_1.5.2        dygraphs_1.1.1.6    
## [118] svglite_2.1.0        bslib_0.4.0          hardhat_1.2.0       
## [121] webshot_0.5.3        prodlim_2019.11.13   rvest_1.0.2         
## [124] distributional_0.3.0 callr_3.7.1          digest_0.6.29       
## [127] rmarkdown_2.14       cellranger_1.1.0     shiny_1.7.1         
## [130] gtools_3.9.3         nloptr_2.0.3         lifecycle_1.0.1     
## [133] nlme_3.1-159         jsonlite_1.8.0       carData_3.0-5       
## [136] viridisLite_0.4.1    fansi_1.0.3          pillar_1.7.0        
## [139] lattice_0.20-45      loo_2.5.1            fastmap_1.1.0       
## [142] httr_1.4.3           pkgbuild_1.3.1       survival_3.4-0      
## [145] glue_1.6.2           xts_0.12.1           diffobj_0.3.5       
## [148] shinythemes_1.2.0    iterators_1.0.14     class_7.3-20        
## [151] stringi_1.7.8        sass_0.4.2           future.apply_1.9.0
Reid, Rachel E. B. 2015. “A Morphometric Modeling Approach to Distinguishing Among Bobcat, Coyote and Gray Fox Scats.” Wildlife Biology 21 (5): 254–62. https://doi.org/10.2981/wlb.00105.