This page is part of a code repository to reproduce the example analyses in the manuscript “Cross validation for model selection: a primer with examples from ecology” by L.A. Yates, Z. Aandahl, S.A. Richards, and B.W. Brook (2022). This tutorial is designed to be used alongside the manuscript.
The scat dataset is sourced from the R package caret. Reid (2015) collected animal feces in coastal California. The data consist of DNA-verified species designations as well as fields related to the time and place of the collection and the morphology of the scat itself. In this example, the aim of the analysis is to predict the biological family (felid or canid) for each scat observation, based on eight morphological characteristics, the scat location, and the carbon-to-nitrogen ratio (see Table 1 of the original data publication for variable descriptions).
Data preparations include: collapsing the response from three categories to two (i.e. binary classification), centering and scaling numerical predictors, and log-transforming two of the variables (based on advice in the original publication).
library(tidyverse) # loads dplyr, tidyr, forcats, stringr, and purrr, used throughout
data(scat, package = "caret")
prep_data <- function(data){
  data %>%
    as_tibble() %>%
    select(y = Species, Number, Length, Diameter, Taper, TI, Mass, CN, Location, ropey, segmented) %>%
    rename_with(str_to_lower) %>% # TI: taper index
    mutate(across(c(mass, cn), log)) %>% # following original data publication
    mutate(number = as.double(number)) %>%
    drop_na() %>%
    mutate(across(where(is_double), ~ .x %>% scale %>% as.numeric)) %>% # centre and scale predictors
    mutate(y = fct_collapse(y, canid = c("coyote","gray_fox"), felid = c("bobcat")) %>%
             fct_relevel("canid","felid"),
           location = fct_collapse(location, mid = c("middle"), edge = c("edge","off_edge")) %>%
             factor(labels = c(0,1))) %>%
    mutate(across(c(ropey, segmented), factor))
}
scat <- prep_data(scat) # N.B. fail = canid (factor-level 1), success = felid (factor-level 2)
vars <- names(scat %>% select(-y))
scat
## # A tibble: 91 × 11
## y number length diameter taper ti mass cn location ropey
## <fct> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <fct> <fct>
## 1 canid -0.524 -0.0149 1.87 0.950 0.0233 0.651 0.450 0 0
## 2 canid -0.524 1.25 1.79 0.634 -0.144 0.784 1.42 0 0
## 3 felid -0.524 -0.156 0.112 -0.722 -0.715 -0.169 0.285 1 1
## 4 canid -0.524 -0.297 -0.0659 -0.182 -0.243 -0.332 1.48 1 1
## 5 canid 0.858 -0.438 0.595 -0.485 -0.627 1.26 1.20 0 0
## 6 canid 0.167 -0.156 0.722 0.0679 -0.262 0.501 0.645 0 1
## 7 felid 1.55 -1.00 -0.676 -1.27 -1.07 0.562 -1.10 0 1
## 8 felid 2.93 -1.14 0.900 -0.538 -0.715 1.31 -0.975 0 0
## 9 felid -0.524 0.408 -0.218 0.107 0.0528 0.679 -0.855 0 0
## 10 felid -1.21 3.09 -0.0913 -0.400 -0.410 0.204 0.112 1 1
## # … with 81 more rows, and 1 more variable: segmented <fct>
scat %>% select(where(is_double)) %>% cor %>% corrplot::corrplot(method = "number")
In this part, we perform exhaustive model selection, fitting all 1024 possible combinations of predictors. The models are Bernoulli-distributed generalized linear models (\(\texttt{logit}\) link), fit using maximum likelihood via the base function glm. Model performance is compared using the Matthews correlation coefficient (MCC). To mitigate overfitting, the modified one-standard-error (OSE) rule is applied to select a preferred model.
We use the MuMIn package to generate the model formulas:
forms_all <- glm(y~., data = scat, family = binomial) %>% MuMIn::dredge(evaluate = F) %>% map(as.formula)
dim_all <- forms_all %>% map_dbl(~ .x %>% terms %>% attr("variables") %>% {length(.)-2})
For example, the tenth model is
forms_all[[10]]
## y ~ cn + location + 1
The predictive performance of a binary (two-class) classifier can be
summarised as a \(2\times2\) matrix
called a confusion matrix. Labeling one class positive and the other
negative, the matrix entries are the counts of the four prediction
outcomes: true positives (TP), false positives (FP), false negatives
(FN), and true negatives (TN):
\[\left(\begin{array}{cc}\mathrm{TP}&\mathrm{FP}\\\mathrm{FN}&\mathrm{TN}\end{array}\right)\]
MCC is computed from the entries of a confusion matrix and takes values between \(-1\) and \(+1\), where \(+1\) indicates perfect prediction, \(0\) means the prediction is no better than random, and \(-1\) indicates the worst possible prediction: \[\mathrm{MCC} = \frac{TP\cdot TN - FP\cdot FN}{\sqrt{(TP+FP)(TP+FN)(TN+FP)(TN+FN)}}\]
A confusion matrix is easily generated from the predictions of a fitted model object, but care must be taken to place the TP count in the top left (this does not affect symmetric metrics like MCC). Using the prevalence as the probability threshold to determine the predicted response category (logistic regression models class probabilities), a confusion matrix can be computed as follows:
fit <- glm(forms_all[[10]], binomial, scat)
# compute the prevalence for the threshold
thresh <- scat %>% pull(y) %>% {.=="felid"} %>% {sum(.)/length(.)} # 0.549
# re-order levels so that TRUE (1) precedes FALSE (0)
pred <- as.numeric(predict(fit, type = "response") >= thresh) %>% factor(levels = c(1,0))
obs <- as.numeric(fit$data$y != levels(fit$data$y)[1]) %>% factor(levels = c(1,0))
cm <- tibble(pred, obs) %>% table
cm
## obs
## pred 1 0
## 1 37 14
## 0 13 27
Now we define a function to compute MCC:
# computes MCC from a confusion matrix `cm`
mcc <- function(cm){
  (cm[1,1]*cm[2,2] - cm[1,2]*cm[2,1]) /
    (sqrt((cm[1,1]+cm[1,2])*(cm[1,1]+cm[2,1])*(cm[2,2]+cm[1,2])*(cm[2,2]+cm[2,1])))
}
mcc(cm)
## [1] 0.3995122
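As a quick sanity check, substituting the four counts from the confusion matrix above directly into the MCC formula reproduces the value returned by mcc(cm):
# direct evaluation of the MCC formula for TP = 37, FP = 14, FN = 13, TN = 27 (from the table above)
tp <- 37; fp <- 14; fn <- 13; tn <- 27
(tp*tn - fp*fn) / sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)) # 0.3995122, matching mcc(cm)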
So far we have computed an MCC estimate using within-sample data. To obtain an out-of-sample or predictive estimate of MCC we use repeated \(k\)-fold CV. For each repeat, we generate a single confusion matrix and a corresponding MCC estimate. We use the rsample package (part of the tidymodels package set) to perform the data splitting: \(10\)-fold CV, repeated 50 times.
set.seed(770); cv_data_1 <- vfold_cv(scat, v = 10, repeats = 50) %>% mutate(thresh = thresh)
cv_data_1
## # 10-fold cross-validation repeated 50 times
## # A tibble: 500 × 4
## splits id id2 thresh
## <list> <chr> <chr> <dbl>
## 1 <split [81/10]> Repeat01 Fold01 0.549
## 2 <split [82/9]> Repeat01 Fold02 0.549
## 3 <split [82/9]> Repeat01 Fold03 0.549
## 4 <split [82/9]> Repeat01 Fold04 0.549
## 5 <split [82/9]> Repeat01 Fold05 0.549
## 6 <split [82/9]> Repeat01 Fold06 0.549
## 7 <split [82/9]> Repeat01 Fold07 0.549
## 8 <split [82/9]> Repeat01 Fold08 0.549
## 9 <split [82/9]> Repeat01 Fold09 0.549
## 10 <split [82/9]> Repeat01 Fold10 0.549
## # … with 490 more rows
The returned rsample object is a tibble, to which we add a column storing the probability threshold (here fixed at the prevalence, \(0.549\), but we will allow it to be tuned in a subsequent analysis). The column splits contains \(500\) (\(10\times50\)) splits of the original data. For each split, the training and test data can be extracted using the functions analysis and assessment, respectively. For example:
split <- cv_data_1$splits[[1]] # take the 1st of 500 splits
analysis(split)
## # A tibble: 81 × 11
## y number length diameter taper ti mass cn location ropey
## <fct> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <fct> <fct>
## 1 canid -0.524 -0.0149 1.87 0.950 0.0233 0.651 0.450 0 0
## 2 canid -0.524 1.25 1.79 0.634 -0.144 0.784 1.42 0 0
## 3 felid -0.524 -0.156 0.112 -0.722 -0.715 -0.169 0.285 1 1
## 4 canid -0.524 -0.297 -0.0659 -0.182 -0.243 -0.332 1.48 1 1
## 5 canid 0.858 -0.438 0.595 -0.485 -0.627 1.26 1.20 0 0
## 6 canid 0.167 -0.156 0.722 0.0679 -0.262 0.501 0.645 0 1
## 7 felid 1.55 -1.00 -0.676 -1.27 -1.07 0.562 -1.10 0 1
## 8 felid 2.93 -1.14 0.900 -0.538 -0.715 1.31 -0.975 0 0
## 9 felid -1.21 3.09 -0.0913 -0.400 -0.410 0.204 0.112 1 1
## 10 canid -1.21 -0.438 -1.39 -0.841 -0.459 -0.146 3.03 1 1
## # … with 71 more rows, and 1 more variable: segmented <fct>
assessment(split)
## # A tibble: 10 × 11
## y number length diameter taper ti mass cn location ropey
## <fct> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <fct> <fct>
## 1 felid -0.524 0.408 -0.218 0.107 0.0528 0.679 -0.855 0 0
## 2 canid -1.21 0.267 -1.59 -1.03 -0.617 -1.45 -0.213 1 1
## 3 felid -1.21 -0.156 0.824 -1.06 -1.06 0.0215 0.0135 1 0
## 4 canid -1.21 0.690 -0.930 0.476 0.762 -0.237 0.720 0 1
## 5 felid 0.858 0.690 0.773 -1.02 -1.03 0.884 -0.312 0 0
## 6 felid 0.167 0.267 0.925 0.641 0.0823 0.875 -0.262 0 1
## 7 felid -0.524 -1.00 -0.803 -0.452 -0.252 -1.15 0.410 1 0
## 8 canid 0.167 -1.00 -0.651 -1.64 0.407 -0.887 -0.573 0 1
## 9 canid -1.21 1.25 -0.346 0.753 0.673 -1.05 -0.363 0 1
## 10 felid -0.524 -0.579 0.875 -0.0836 -0.400 0.223 -0.739 1 1
## # … with 1 more variable: segmented <fct>
The following function takes a repeated \(k\)-fold rsample object, a model formula, and a metric (e.g., MCC) and returns a tibble containing a metric estimate for each \(k\)-fold repeat.
fit_confusion_glm <- function(cv_data, formula = NULL, metric){
  # out-of-sample confusion-matrix entries for a single split
  get_confusion_matrix <- function(form, split, thresh){
    glm(form, data = analysis(split), family = binomial()) %>%
      {tibble(pred = predict(., newdata = assessment(split), type = "response") >= thresh,
              y = assessment(split)$y != levels(assessment(split)$y)[1])} %>%
      mutate(across(everything(), ~ .x %>% as.numeric() %>% factor(levels = c(1,0)))) %>%
      table %>% {c(tp = .[1,1], fp = .[1,2], fn = .[2,1], tn = .[2,2])}
  }
  # use the supplied formula, or the per-split formulas stored in the rsample object
  if(!is.null(formula)) form <- list(formula) else form <- cv_data$form_glm
  with(cv_data, mapply(get_confusion_matrix, form, splits, thresh, SIMPLIFY = T)) %>% t %>%
    as_tibble %>%
    bind_cols(rep = cv_data$id, .) %>%
    group_by(rep) %>%
    summarise(metric = matrix(c(sum(tp), sum(fn), sum(fp), sum(tn)), 2, 2) %>% metric) %>% # aggregate folds, then score
    mutate(rep = 1:n())
}
For example, applying the function to one of the models gives:
mcc_f10 <- fit_confusion_glm(cv_data_1, forms_all[[10]],mcc)
mcc_f10
## # A tibble: 50 × 2
## rep metric
## <int> <dbl>
## 1 1 0.332
## 2 2 0.311
## 3 3 0.308
## 4 4 0.353
## 5 5 0.355
## 6 6 0.308
## 7 7 0.376
## 8 8 0.379
## 9 9 0.330
## 10 10 0.332
## # … with 40 more rows
mcc_f10 %>% pull(metric) %>% {c(mean = mean(.), sd = sd(.))}
## mean sd
## 0.34423001 0.02514177
We estimate MCC for all 1024 models using the repeated \(k\)-fold data-splitting scheme created above.
mcc_est <- pbmclapply(forms_all, function(form) fit_confusion_glm(cv_data_1,form, mcc)$metric, mc.cores = 40) %>% bind_cols() # (220 seconds using 40 cores)
The following function takes the repeated-CV MCC estimates and computes the following summary statistics for each model: the mean MCC, the standard error of the MCC (\(\sigma_m\)), the standard error of the MCC difference with respect to the best model (\(\sigma_m^{{\mathrm{diff}}}\)), and the modified standard error (\(\sigma_m^{{\mathrm{adj}}}\)) (see manuscript for definitions):
# computes summary statistics for model score estimates
make_plot_data <- function(metric_data, levels){
  best_model <- metric_data %>% map_dbl(mean) %>% which.max() %>% names()
  tibble(model = factor(names(metric_data), levels = levels),
         metric = metric_data %>% map_dbl(mean),
         se = metric_data %>% map_dbl(sd),
         se_diff = metric_data %>% map(~ .x - metric_data[[best_model]]) %>% map_dbl(sd),
         se_mod = sqrt(1 - cor(metric_data)[best_model, ]) * se[best_model])
}
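In the notation used above, the two uncertainty measures computed for each model \(m\) relative to the best-scoring model (as computed in make_plot_data) are
\[\sigma_m^{\mathrm{diff}} = \mathrm{sd}_r\!\left(\mathrm{MCC}_{m,r} - \mathrm{MCC}_{\mathrm{best},r}\right), \qquad \sigma_m^{\mathrm{adj}} = \sqrt{1 - \rho_{m,\mathrm{best}}}\;\sigma_{\mathrm{best}},\]
where the standard deviations are taken over the CV repeats \(r\), \(\sigma_{\mathrm{best}}\) is the standard error of the best model's repeat-level scores, and \(\rho_{m,\mathrm{best}}\) is the correlation between the repeat-level scores of model \(m\) and those of the best model (see the manuscript for the derivation and rationale).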
We apply the function to the MCC estimates and filter the results to select only the best-performing model at each dimension (i.e., the number of model parameters).
metric_summary <- make_plot_data(mcc_est, levels = names(mcc_est)) %>% mutate(dim = dim_all)
## Warning in cor(metric_data): the standard deviation is zero
metric_by_dim <- metric_summary %>%
group_by(dim) %>%
arrange(-metric) %>%
slice(1) %>%
ungroup() %>%
arrange(dim) %>%
filter(!dim %in% c(0,9)) # reduce model set to be plotted
Finally, we apply the modified OSE rule by determining the least complex model whose mean MCC estimate lies within \(\sigma_m^{{\mathrm{adj}}}\) of the best estimate. We plot the results.
best_mod_ose_dim <- metric_by_dim %>%
filter(metric + se_mod >= max(metric)) %>%
filter(dim == min(dim)) %>% pull(dim)
# main plot for scat example in manuscript
metric_by_dim %>%
ggplot(aes(factor(dim, levels = dim))) +
geom_linerange(aes(ymin = metric - se_mod, ymax = metric + se_mod), col = "black") +
geom_point(aes(y = metric), size = 2) +
theme_classic() +
geom_point(aes(y = metric), shape = 1, size = 6, col = "black",
data = ~ .x %>% filter(dim == best_mod_ose_dim)) + # circle the OSE-preferred model
labs(subtitle = "CV estimates for scat models", x = "Number of parameters",y = "MCC")
The model with two predictors is identified using the modified OSE rule, but we can see that the one-predictor model is very close to being selected. The corresponding model formulas at each dimension are:
forms_all[metric_by_dim$model]
## $`2`
## y ~ cn + 1
##
## $`34`
## y ~ cn + number + 1
##
## $`42`
## y ~ cn + location + number + 1
##
## $`52`
## y ~ cn + diameter + mass + number + 1
##
## $`56`
## y ~ cn + diameter + length + mass + number + 1
##
## $`64`
## y ~ cn + diameter + length + location + mass + number + 1
##
## $`320`
## y ~ cn + diameter + length + location + mass + number + taper +
## 1
##
## $`480`
## y ~ cn + diameter + length + location + mass + ropey + segmented +
## taper + 1
##
## $`1024`
## y ~ cn + diameter + length + location + mass + number + ropey +
## segmented + taper + ti + 1
As an alternative to discrete selection, we take the global logistic model and use penalised regression to shrink or regularise the parameter estimates. We use elastic-net regression, which comprises a family of penalised regression models indexed by a hyperparameter \(\alpha\), \(0\leq\alpha\leq1\), for which the extreme values \(\alpha = 1\) and \(\alpha = 0\) correspond to the LASSO (least absolute shrinkage and selection operator) and ridge regression, respectively. The LASSO is able to shrink some parameter estimates to exactly zero, performing effective variable selection, whereas ridge regression shrinks all parameters towards, but never exactly to, zero. Mathematically, the regularised parameters \(\widehat{\boldsymbol\theta} =(\widehat\theta_1,\widehat\theta_2,...,\widehat\theta_p)\) are those minimising the penalised function
\[\begin{equation} \label{eq:elastic_net} f(\mathbf{x};\mathbf{\boldsymbol{\widehat\theta}}) + \lambda\left(\alpha||\boldsymbol{\widehat\theta}||^1 + \frac{(1-\alpha)}{2}||\boldsymbol{\widehat\theta}||^2\right), \end{equation}\]
where \(f\) is the objective function (usually mean squared error or negative log density), and \(||\boldsymbol\theta||^1 = \sum_{j=1}^p |\theta_j|\) and \(||\boldsymbol\theta||^2 = \sum_{j=1}^p \theta_j^2\) are the \(L_1\)- and \(L_2\)-norm of the vector of model parameters, respectively. The regularisation parameter \(\lambda\) determines the strength of the penalty, implementing a trade-off between the size of the model’s parameter estimates (the shrinkage or effective complexity) and the minimised value of the (unconstrained) objective function \(f\).
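Substituting the extreme values of \(\alpha\) into the penalty makes the two special cases explicit:
\[\alpha = 1:\quad f(\mathbf{x};\boldsymbol{\widehat\theta}) + \lambda\,||\boldsymbol{\widehat\theta}||^1 \;\;(\text{LASSO}), \qquad \alpha = 0:\quad f(\mathbf{x};\boldsymbol{\widehat\theta}) + \frac{\lambda}{2}\,||\boldsymbol{\widehat\theta}||^2 \;\;(\text{ridge}).\]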
Elastic-net regression is implemented in the R package glmnet. The main CV fitting function, cv.glmnet, accepts a fixed value of \(\alpha\) (alpha), the GLM specification for the global model, the number of folds (nfolds), and the loss function or measure to be cross validated (type.measure, five options available). A vector of candidate values of the regularisation parameter \(\lambda\) (lambda) can be supplied, otherwise a default set is generated. For example:
library(glmnet)
fit.net <- cv.glmnet(x = scat %>% select(-y) %>% as.data.frame() %>% makeX(),
                     y = scat %>% pull(y),
                     family = "binomial",
                     nfolds = 10,
                     type.measure = "class", # misclassification loss
                     alpha = 1) # lasso
fit.net %>% plot
fit.net$glmnet.fit %>% plot(xvar = "lambda")
Unfortunately, it is not possible to use a custom-defined loss function in glmnet, but with a little effort we can still use the package to perform penalised regression for our chosen metric: MCC.
We start by defining a function to compute the predictive confusion matrix across a supplied range of \(\lambda\) values for a given rsample split.
# compute out-of-sample (i.e., CV) confusion-matrix entries for all lambda values for a given split
get_cm <- function(split, rep, fold, alpha, lambda, thresh){
  glmnet(x = analysis(split) %>% select(-y) %>% as.data.frame() %>% makeX(),
         y = analysis(split) %>% pull(y),
         family = "binomial",
         alpha = alpha,
         lambda = lambda) %>%
    predict(newx = assessment(split) %>% select(-y) %>% as.data.frame() %>% makeX(),
            type = "response") %>%
    as_tibble %>%
    rename_with(~ str_remove(.x, stringr::fixed("s"))) %>% # columns s0, s1, ... index the lambda sequence
    imap_dfr(~ tibble(y_pred = as.numeric(.x > thresh),
                      y = assessment(split) %>% pull(y) %>% as.numeric() %>% {.-1},
                      tp = as.numeric(y==1 & y_pred==1),
                      fp = as.numeric(y==0 & y_pred==1),
                      fn = as.numeric(y==1 & y_pred==0),
                      tn = as.numeric(y==0 & y_pred==0),
                      rep = rep, fold = fold, lambda = lambda[as.numeric(.y)+1]))
}
For lasso (alpha = 1), we specify (after a little trial and error) the following sequence of candidate \(\mathrm{log}\,\lambda\) values:
log_lambda_lasso <- seq(-1.6,-4, length.out = 100)
and we compute the confusion matrices over all \(\lambda\) for the full set of splits cv_data_1 used in the previous analysis.
# compute confusion-matrix entries and aggregate across folds within a given repetition
# 2 seconds with 40 cores
cm_lasso <- pbmclapply(1:nrow(cv_data_1),
function(i) get_cm(cv_data_1$splits[[i]], cv_data_1$id[[i]], cv_data_1$id2[[i]],
alpha = 1,
lambda = exp(log_lambda_lasso),
thresh = thresh),
mc.cores = 40) %>%
bind_rows() %>% as_tibble()
Within each repetition and \(\lambda\) value, we sum the confusion-matrix entries and compute MCC:
mcc_lasso <- cm_lasso %>%
group_by(rep, lambda) %>%
summarise(metric = matrix(c(sum(tp),sum(fp),sum(fn),sum(tn)),2,2) %>% mcc)
## `summarise()` has grouped output by 'rep'. You can override using the `.groups`
## argument.
mcc_lasso # a row for each repetition for each lambda value
## # A tibble: 5,000 × 3
## # Groups: rep [50]
## rep lambda metric
## <chr> <dbl> <dbl>
## 1 Repeat01 0.0183 0.334
## 2 Repeat01 0.0188 0.334
## 3 Repeat01 0.0193 0.334
## 4 Repeat01 0.0198 0.334
## 5 Repeat01 0.0203 0.334
## 6 Repeat01 0.0208 0.334
## 7 Repeat01 0.0213 0.334
## 8 Repeat01 0.0219 0.334
## 9 Repeat01 0.0224 0.334
## 10 Repeat01 0.0230 0.334
## # … with 4,990 more rows
To aid plot creation, we compute the value of lambda that maximises the mean score:
# lambda value at best mean MCC estimate
lambda_best_lasso <- mcc_lasso %>%
group_by(lambda) %>%
summarise(metric = mean(metric)) %>%
filter(metric == max(metric, na.rm = T)) %>%
pull(lambda)
# MCC scores for all reps at best lambda value
mcc_best_lasso <- mcc_lasso %>%
filter(lambda == lambda_best_lasso) %>%
arrange(lambda,rep) %>% pull(metric)
To select \(\lambda\), we compute the (modified) standard errors
metric_plot_data <- mcc_lasso %>%
group_by(lambda) %>%
arrange(lambda,rep) %>%
summarise(se_ose = sd(metric),
se_diff = sd(metric - mcc_best_lasso),
se_best = sd(mcc_best_lasso),
rho_best_m = (se_diff^2 - se_ose^2 - se_best^2)/(-2*se_ose*se_best),
se_mod = sqrt(1-(rho_best_m))*se_best,
metric_diff = mean(metric - mcc_best_lasso),
metric = mean(metric))
metric_mod_ose <- metric_plot_data %>%
filter(metric + se_mod > max(metric, na.rm = T)) %>%
filter(lambda == max(lambda))
metric_ose <- metric_plot_data %>%
filter(se_ose >= abs(metric_diff)) %>%
filter(lambda == max(lambda))
which allows us to plot the tuning results and apply the modified OSE rule:
plot_lasso <-
metric_plot_data %>%
ggplot(aes(x = log(lambda))) +
geom_vline(aes(xintercept = metric_mod_ose$lambda %>% log), col = "grey70", lty = "dashed") +
geom_vline(aes(xintercept = lambda_best_lasso %>% log), col = "grey70", lty = "dashed") +
geom_linerange(aes(ymax = metric + se_mod, ymin = metric -se_mod, col = "m-OSE"), size = 0.5,
data = metric_mod_ose, position = position_nudge(0.0)) +
geom_linerange(aes(ymax = metric + se_ose, ymin = metric - se_ose, col = "OSE"), size = 0.5,
data = metric_ose, position = position_nudge(0.0)) +
geom_point(aes(y = metric), col = "grey60") +
geom_point(aes(y = metric), col = "grey10", data = ~ .x %>% filter(metric == max(metric))) +
geom_line(aes(y = metric), col = "grey30", lty = "solid") +
geom_point(aes(y = metric), shape = 1, size = 6, data = metric_mod_ose, col = "black") +
geom_point(aes(y = metric), size = 1, data = metric_mod_ose) +
geom_point(aes(y = metric), size = 1, data = ~ .x %>% filter(lambda == lambda_best_lasso), col = "black") +
labs(subtitle = NULL, x = NULL, y = "MCC") +
scale_colour_manual(name = NULL, values = c("black",blues9[8]),
labels = c(modOSE = expression(sigma^"diff"),OSE = expression(sigma^"best"))) +
theme_classic() +
theme(panel.grid = element_blank())
plot_lasso
Note that the ordinary OSE based on \(\sigma^{\mathrm{best}}\) selects a more heavily regularised model than the modified OSE based on \(\sigma^{\mathrm{mod}}\) (see manuscript for further details).
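We can print the two selected penalties to see this directly (reusing the objects defined above; a larger \(\log\lambda\) corresponds to stronger regularisation):
# modified-OSE choice vs ordinary-OSE choice of log(lambda)
log(c(mod_ose = metric_mod_ose$lambda, ose = metric_ose$lambda))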
To see the effect of \(\lambda\) on the parameter estimates we refit the penalised regression models (using all of the data) and plot the corresponding trajectories:
fit_glmnet <- glmnet(x = scat %>% select(-y) %>% as.data.frame() %>% makeX(),
y = scat$y,
family = "binomial",
lambda = exp(log_lambda_lasso),
alpha = 1)
plot_est <- fit_glmnet$beta %>% as.matrix %>% t %>% as_tibble() %>%
mutate(lambda = log_lambda_lasso) %>% #, metric = metric_plot_data$metric) %>%
pivot_longer(!any_of(c("lambda","metric")), values_to = "estimate", names_to = "coefficient") %>%
filter(estimate != 0 | coefficient == "cn") %>%
ggplot(aes(lambda,estimate)) +
geom_vline(aes(xintercept = metric_mod_ose$lambda %>% log), col = "grey70", lty = "dashed") +
geom_vline(aes(xintercept = lambda_best_lasso %>% log), col = "grey70", lty = "dashed") +
geom_line(aes(col = coefficient)) +
geom_point(aes(col = coefficient), size =1, data = ~ .x %>% filter(lambda == metric_mod_ose$lambda %>% log)) +
geom_hline(aes(yintercept = 0), lty = "longdash") +
labs(y = "Parameter estimates", x = expression(paste("log(",lambda,")"))) +
theme_classic() +
theme(legend.position = "none")
plot_est
The above code is easily re-used to compute ridge estimates (i.e., \(\alpha= 0\)), or indeed any other \(\alpha\) values within the elastic-net family.
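For instance, a ridge analogue of the lasso run above might look like the following sketch; the \(\lambda\) grid here (log_lambda_ridge) is an assumed choice and, as for the lasso, would need to be settled by a little trial and error:
# sketch: repeat the CV confusion-matrix calculation with alpha = 0 (ridge)
log_lambda_ridge <- seq(2, -4, length.out = 100) # assumed grid; adjust as required
cm_ridge <- pbmclapply(1:nrow(cv_data_1),
                       function(i) get_cm(cv_data_1$splits[[i]], cv_data_1$id[[i]], cv_data_1$id2[[i]],
                                          alpha = 0,
                                          lambda = exp(log_lambda_ridge),
                                          thresh = thresh),
                       mc.cores = 40) %>%
  bind_rows() %>% as_tibble()
# the downstream summaries (grouping by rep and lambda, computing MCC, etc.) are unchanged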
Finally, to tune \(\alpha\) as well as \(\lambda\), we can nest the call to get_cm. A \(\lambda\) grid must be supplied for each candidate \(\alpha\); below we simply reuse the lasso grid, although in practice it may need adjusting:
alpha <- seq(0, 1, 0.025) # candidate alpha values
lambda <- exp(log_lambda_lasso) # lambda grid (assumption: reuse the lasso sequence for every alpha)
metric_alpha <- lapply(alpha, function(a){
  pbmclapply(1:nrow(cv_data_1), function(i) get_cm(cv_data_1$splits[[i]],
                                                   cv_data_1$id[[i]],
                                                   cv_data_1$id2[[i]],
                                                   alpha = a, lambda, thresh),
             mc.cores = 40) %>%
    bind_rows %>%
    group_by(rep, lambda) %>%
    summarise(cm11 = sum(tp), cm12 = sum(fp), cm21 = sum(fn), cm22 = sum(tn), .groups = "keep") %>%
    summarise(metric = (cm11*cm22 - cm12*cm21)/(sqrt((cm11+cm12)*(cm11+cm21)*(cm22+cm12)*(cm22+cm21))), .groups = "drop") %>%
    group_by(lambda) %>%
    summarise(metric = mean(metric)) %>%
    filter(metric == max(metric, na.rm = T)) %>%
    pull(metric) # return the best mean metric estimate among all lambda values
})
The results suggest that LASSO (\(\alpha=1\)) and ridge (\(\alpha=0\)) have comparable performance, whereas the elastic net estimate \(\alpha = 0.175\) is clearly superior:
tibble(metric = metric_alpha %>% unlist, alpha = alpha) %>%
ggplot(aes(alpha,metric)) + geom_line() + theme_classic() +
geom_point(data = ~ .x %>% filter(metric == max(metric))) +
labs(subtitle = "Tuning alpha for elastic-net regularisation", y = "MCC")
In a Bayesian setting, lasso-type and ridge-type regression can be implemented through an appropriate choice of prior distribution. The direct analogue of the frequentist lasso is the Laplacian (two-sided exponential) prior (Hastie et al. 2015), and Gaussian priors are equivalent to ridge regression. An alternative choice is the regularised horseshoe prior, which provides support for a proper subset of parameters to be far from zero, using prior information to guide both the number of non-zero parameters and the degree of regularisation of their estimates.
Here we fit three models, using a different regularising prior for each, and we compare their predictive performance using Pareto-smoothed importance sampling (PSIS) leave-one-out (LOO) CV—an easily computed approximate CV method for use with fitted Bayesian models. The three regularising priors are Gaussian (ridge-type), Laplacian (lasso-type), and the regularised horseshoe; a model with uninformative (flat) priors is also fit as a baseline.
We fit the models using the package \(\texttt{brms}\): a front-end for the \(\texttt{Stan}\) engine, which implements Hamiltonian Monte Carlo (HMC) sampling combined with the no-U-turn sampler (NUTS) for fast, efficient model fitting.
fit.flat <- brm(y ~ ., family=bernoulli(),
data=scat, chains=4, iter=2000, cores = 4,
save_pars = save_pars(all = TRUE))
fit.ridge <- brm(y ~ .,
family = bernoulli(),
prior = prior(normal(0,10)),
data = scat,
chains = 4,
iter = 2000,
cores = 4)
fit.lasso <- brm(y ~ .,
family = bernoulli(),
data = scat,
prior = set_prior("lasso(1)"),
chains = 4,
iter = 2000,
cores = 4)
# NB: tau0, the global scale of the horseshoe prior, is assumed to be defined beforehand
# (e.g., following the Piironen & Vehtari (2017) recommendation based on a prior guess of the number of non-zero coefficients)
fit.hs <- brm(y ~ ., family=bernoulli(), data=scat,
              prior=prior(horseshoe(scale_global = tau0, scale_slab = 1), class=b),
              chains=4, iter=2000, cores = 4)
The models took about half a minute to compile, but less than 2 seconds to sample.
Combining the models into a list, we run PSIS-LOO and compare the results:
fits <- list(hs = fit.hs, lasso = fit.lasso, ridge = fit.ridge, flat = fit.flat)
fits %>% map(loo, cores = 10) %>% loo_compare()
## Warning: Found 2 observations with a pareto_k > 0.7 in model '.x[[i]]'. It is
## recommended to set 'moment_match = TRUE' in order to perform moment matching for
## problematic observations.
## Warning: Found 3 observations with a pareto_k > 0.7 in model '.x[[i]]'. It is
## recommended to set 'moment_match = TRUE' in order to perform moment matching for
## problematic observations.
## elpd_diff se_diff
## hs 0.0 0.0
## lasso -2.1 0.9
## ridge -7.9 3.4
## flat -8.4 3.5
Although the LOO scores have been estimated for all four models, the pareto_k > 0.7 diagnostic tells us that the LOO approximation may have failed for a few data points in the less regularised models. To address this, we refit (reloo = T) the model for each problematic observation; i.e., manually leaving out the flagged test points:
future::plan("multisession", workers = 10)
fits.loo <- fits %>% map(loo, reloo = T, future = T)
Using the future package for parallelisation, the models are quickly refit, giving the adjusted summary:
fits.loo %>% loo_compare
## elpd_diff se_diff
## hs 0.0 0.0
## lasso -2.1 0.9
## ridge -8.0 3.4
## flat -8.5 3.6
The horseshoe model is the best performing of the models, and it is also the most regularised. The degree of regularisation can be seen by comparing the effective number of parameters (p_loo) between the models:
fits.loo %>% loo_compare %>% print(simplify = F)
## elpd_diff se_diff elpd_loo se_elpd_loo p_loo se_p_loo looic se_looic
## hs 0.0 0.0 -52.0 4.1 3.8 0.5 104.0 8.3
## lasso -2.1 0.9 -54.1 4.4 7.0 0.9 108.1 8.7
## ridge -8.0 3.4 -60.0 6.8 14.7 2.2 120.0 13.5
## flat -8.5 3.6 -60.5 6.9 15.2 2.3 121.0 13.7
It is interesting to ‘see’ the effect of the regularising priors by examining the (marginal) posterior density plots.
library(bayesplot)
# plot posteriors
plot_post <- function(fit){
fit %>% as_tibble %>% select(-starts_with("l")) %>%
mcmc_areas(area_method = "scaled", prob_outer = 0.98) +
xlim(c(-2.8,2)) +
theme_classic()
}
ggpubr::ggarrange(
fit.flat %>% plot_post + labs(subtitle = "Uninformative (flat)"),
fit.ridge %>% plot_post + labs(subtitle = "Weakly informative (ridge)"),
fit.lasso %>% plot_post + labs(subtitle = "LASSO"),
fit.hs %>% plot_post + labs(subtitle = "Horseshoe"),
ncol = 1,
labels = "AUTO"
)
Having specified, fit, and compared the candidate reference models, we can implement the predictive-inference step easily using the \(\texttt{projpred}\) package. Taking the horseshoe variant as the best available reference model, we use cross-validated forward stepwise selection.
library(projpred)
vs <- cv_varsel(fit.hs, cv_method = "LOO", method = "forward")
vs %>%
plot(stats = c("elpd"), deltas = T) +
theme_classic() +
theme(strip.text = element_blank(),
strip.background = element_blank(),
legend.position = "none") +
labs(y = expression(Delta*"ELPD"))
solution_terms(vs)
## [1] "cn" "number" "segmented" "taper" "location" "mass"
## [7] "ropey" "length" "ti" "diameter"
The one-predictor model (cn) is an obvious choice, and its selection is what we would have expected after seeing the posterior distribution of the reference model. The last step in the BPI process is to project the posterior of the reference model onto the selected submodel (here, using all 4000 draws):
proj <- project(vs, nterms = 1, ndraws = 4000)
To see the effect of the projection, we first fit the same one-predictor submodel to the original data set (as if we had decided a priori that this was the preferred model):
fit.cn <- brm(y ~ cn, family=bernoulli(), data=scat,chains=4, iter=2000, cores = 4)
Finally, we plot the (marginal) posteriors of the cn parameter for the reference (ref), projected (proj), and non-projected (non-proj) models:
tibble(ref = fit.hs %>% as_tibble() %>% pull(b_cn),
proj = proj %>% as.matrix %>% {.[,"b_cn"]},
`non-proj` = fit.cn %>% as_tibble() %>% pull(b_cn)) %>%
mcmc_areas() +
labs(x = expression(theta["carbon-nitrogen ratio"])) +
theme_classic()
The larger effect size of the non-projected model can be interpreted as a selection-induced bias, arising due to a failure to account for selection uncertainty.
In this section we illustrate nested CV. Comparing parametric models to ‘black-box’ machine-learning models is not commonly done, and for our choice of metric, MCC, the implementation requires some custom functions. Rather than show all of the gory details, we source a separate script and outline the basic steps (see the GitHub repo for the script).
To get started, we generate the nested data splits using the \(\texttt{rsample}\) package
set.seed(7193) # sample(1e4,1)
cv_data_2 <- nested_cv(scat, outside = vfold_cv(v = 10, repeats = 50), inside = vfold_cv(v = 10))
vars <- names(scat %>% select(-y)); vars
The generated data-splitting scheme is 10-fold CV repeated 50 times (the outer layer), where each of these 500 training sets contains a second (inner) layer of 10-fold CV splits—there are 5000 inner training sets altogether. The inner splits (training and test) are used to tune the hyperparameters of the model for each separate outer training set. Each tuned model is then fit to the corresponding outer training set, after which prediction to the test set and loss calculations are performed in the usual manner.
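Before tuning, it can help to see how the nesting is stored: each row of the outer rsample object carries its own set of inner splits, which can be inspected with the same accessor functions used earlier (an illustrative check only):
# inspect the nesting structure of the first outer split
outer_split <- cv_data_2$splits[[1]] # the first outer split (roughly 90% train / 10% test)
inner_set <- cv_data_2$inner_resamples[[1]] # 10-fold CV splits of that outer training set
analysis(inner_set$splits[[1]]) # training data for the first inner fold
assessment(inner_set$splits[[1]]) # corresponding inner test data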
We load the pre-written functions and perform the hyperparameter tuning first. The random forest models require the number of trees (ntree) to be set, as well as mtry, the number of predictors considered at each split: the latter is the hyperparameter that we tune.
source("files_scat/scat_funs.R")
# specify RF hyper-parameters
ntree <- 500 # set a fixed value
mtry_vec <- 1:(ncol(scat)-1) # candidate mtry values
# run the tuning functions
rf_tune_values <- tune_rf(cv_data_2, mtry_vec, ntree, metric = mcc) # 1:17 seconds for 50 repeats with 40 cores
glm_step_tune_values <- tune_glm_step(cv_data_2, vars, metric = mcc) # 1:40 seconds for 50 repeats with 40 cores
# add tuned hyperparameters values to the rsample object
cv_data_2$thresh_rf <- rf_tune_values$threshold
cv_data_2$thresh <- glm_step_tune_values$threshold
cv_data_2$mtry_rf <- rf_tune_values$mtry
cv_data_2$form_glm <- glm_step_tune_values$form
Let’s have a look at the rsample object with the tuned values added:
cv_data_2
## # Nested resampling:
## # outer: 10-fold cross-validation repeated 50 times
## # inner: 10-fold cross-validation
## # A tibble: 500 × 10
## splits id id2 inner_resamples thresh_rf thresh_glm_step
## <list> <chr> <chr> <list> <dbl> <dbl>
## 1 <split [81/10]> Repeat01 Fold01 <vfold_cv [10 × 2]> 0.58 0.19
## 2 <split [82/9]> Repeat01 Fold02 <vfold_cv [10 × 2]> 0.46 0.67
## 3 <split [82/9]> Repeat01 Fold03 <vfold_cv [10 × 2]> 0.44 0.34
## 4 <split [82/9]> Repeat01 Fold04 <vfold_cv [10 × 2]> 0.55 0.55
## 5 <split [82/9]> Repeat01 Fold05 <vfold_cv [10 × 2]> 0.47 0.15
## 6 <split [82/9]> Repeat01 Fold06 <vfold_cv [10 × 2]> 0.43 0.57
## 7 <split [82/9]> Repeat01 Fold07 <vfold_cv [10 × 2]> 0.48 0.59
## 8 <split [82/9]> Repeat01 Fold08 <vfold_cv [10 × 2]> 0.25 0.27
## 9 <split [82/9]> Repeat01 Fold09 <vfold_cv [10 × 2]> 0.45 0.57
## 10 <split [82/9]> Repeat01 Fold10 <vfold_cv [10 × 2]> 0.44 0.33
## # … with 490 more rows, and 4 more variables: thresh_glm_all <dbl>,
## # mtry_rf <int>, form_glm_step <chr>, form_glm_all <list>
Now we fit the models, which is much faster than tuning as we are only using the 500 outer folds. For the random forest models, we fit two alternatives: 1) only the subset of the most important (best) predictors is kept for each fold (mtry predictors in total); 2) all predictors are kept, but the tuned mtry values are still used for fitting.
fits_rf_best <- fit_rf(cv_data_2, ntree = ntree, type = "best", metric = mcc) # 3 seconds with 40 cores
fits_rf_all <- fit_rf(cv_data_2, ntree = ntree, type = "all", metric = mcc) # 2 seconds with 40 cores
fits_glm_step <- fit_confusion_glm(cv_data_2, metric = mcc)
Finally we collate and prepare the results
mcc_data_2 <- tibble(glm_step = fits_glm_step$metric,
rf_best = fits_rf_best$metric,
rf_all = fits_rf_all$metric)
mcc_plot_data_2 <- make_plot_data(mcc_data_2, names(mcc_data_2))
and plot the model-comparison figure
plot_model_comparisons(mcc_plot_data_2, "se_mod") +
labs(title = "Model comparison", subtitle = "Stage 2: nested CV with tuned hyper-parameter, score = MCC")
The random forest model is a better predictive model than the glm, with the best-subset model the better performer of the two random forest variants. Interestingly, the tuned mtry values across the 500 outer folds varied fairly uniformly across the range of possible values (1–10), suggesting that other predictors (not just cn) do add value, but that the linearity of the glm may have been too inflexible for their inclusion to be useful.
# tables: number of variables selected across the 500 outer folds
cv_data_2$mtry_rf %>% table # random forest
## .
## 1 2 3 4 5 6 7 8 9 10
## 23 40 54 49 61 49 66 61 43 54
cv_data_2$form_glm_step %>% str_split(" +") %>% map_dbl(~ .x %in% vars %>% sum) %>% table # glm
## .
## 1 2 3 4 5 6 7 8 9 10
## 39 110 96 94 71 58 18 9 4 1
sessionInfo()
## R version 4.2.1 (2022-06-23)
## Platform: x86_64-pc-linux-gnu (64-bit)
## Running under: Ubuntu 20.04.5 LTS
##
## Matrix products: default
## BLAS: /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.9.0
## LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.9.0
##
## locale:
## [1] LC_CTYPE=en_AU.UTF-8 LC_NUMERIC=C
## [3] LC_TIME=en_AU.UTF-8 LC_COLLATE=en_AU.UTF-8
## [5] LC_MONETARY=en_AU.UTF-8 LC_MESSAGES=en_AU.UTF-8
## [7] LC_PAPER=en_AU.UTF-8 LC_NAME=C
## [9] LC_ADDRESS=C LC_TELEPHONE=C
## [11] LC_MEASUREMENT=en_AU.UTF-8 LC_IDENTIFICATION=C
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] bayesplot_1.9.0 projpred_2.1.2 brms_2.17.0 Rcpp_1.0.9
## [5] glmnet_4.1-4 Matrix_1.4-1 knitr_1.39 kableExtra_1.3.4
## [9] yardstick_1.0.0 workflowsets_1.0.0 workflows_1.0.0 tune_1.0.0
## [13] rsample_1.1.0 recipes_1.0.1 parsnip_1.0.0 modeldata_1.0.0
## [17] infer_1.0.2 dials_1.0.0 scales_1.2.0 broom_1.0.0
## [21] tidymodels_1.0.0 forcats_0.5.1 stringr_1.4.1 dplyr_1.0.9
## [25] purrr_0.3.4 readr_2.1.2 tidyr_1.2.0 tibble_3.1.7
## [29] ggplot2_3.3.6 tidyverse_1.3.2
##
## loaded via a namespace (and not attached):
## [1] utf8_1.2.2 lme4_1.1-30 tidyselect_1.1.2
## [4] htmlwidgets_1.5.4 grid_4.2.1 munsell_0.5.0
## [7] codetools_0.2-18 DT_0.24 future_1.26.1
## [10] miniUI_0.1.1.1 withr_2.5.0 Brobdingnag_1.2-7
## [13] colorspace_2.0-3 highr_0.9 rstudioapi_0.14
## [16] stats4_4.2.1 ggsignif_0.6.3 listenv_0.8.0
## [19] labeling_0.4.2 rstan_2.21.5 DiceDesign_1.9
## [22] farver_2.1.1 bridgesampling_1.1-2 coda_0.19-4
## [25] parallelly_1.32.1 vctrs_0.4.1 generics_0.1.3
## [28] ipred_0.9-13 xfun_0.32 R6_2.5.1
## [31] markdown_1.1 gamm4_0.2-6 lhs_1.1.5
## [34] cachem_1.0.6 assertthat_0.2.1 promises_1.2.0.1
## [37] nnet_7.3-17 googlesheets4_1.0.0 gtable_0.3.0
## [40] globals_0.16.0 processx_3.7.0 timeDate_4021.104
## [43] rlang_1.0.4 systemfonts_1.0.4 splines_4.2.1
## [46] rstatix_0.7.0 gargle_1.2.0 checkmate_2.1.0
## [49] inline_0.3.19 yaml_2.3.5 reshape2_1.4.4
## [52] abind_1.4-5 modelr_0.1.8 threejs_0.3.3
## [55] crosstalk_1.2.0 backports_1.4.1 httpuv_1.6.5
## [58] tensorA_0.36.2 tools_4.2.1 lava_1.6.10
## [61] ellipsis_0.3.2 jquerylib_0.1.4 posterior_1.2.2
## [64] ggridges_0.5.3 plyr_1.8.7 base64enc_0.1-3
## [67] ps_1.7.1 prettyunits_1.1.1 ggpubr_0.4.0
## [70] rpart_4.1.16 cowplot_1.1.1 zoo_1.8-10
## [73] haven_2.5.0 fs_1.5.2 furrr_0.3.1
## [76] magrittr_2.0.3 colourpicker_1.1.1 reprex_2.0.1
## [79] GPfit_1.0-8 googledrive_2.0.0 mvtnorm_1.1-3
## [82] matrixStats_0.62.0 hms_1.1.1 shinyjs_2.1.0
## [85] mime_0.12 evaluate_0.16 xtable_1.8-4
## [88] shinystan_2.6.0 readxl_1.4.0 gridExtra_2.3
## [91] shape_1.4.6 rstantools_2.2.0 MuMIn_1.46.0
## [94] compiler_4.2.1 crayon_1.5.1 minqa_1.2.4
## [97] StanHeaders_2.21.0-7 htmltools_0.5.3 mgcv_1.8-40
## [100] later_1.3.0 tzdb_0.3.0 RcppParallel_5.1.5
## [103] lubridate_1.8.0 DBI_1.1.3 corrplot_0.92
## [106] dbplyr_2.2.1 MASS_7.3-58.1 boot_1.3-28
## [109] car_3.0-13 cli_3.3.0 parallel_4.2.1
## [112] gower_1.0.0 igraph_1.3.3 pkgconfig_2.0.3
## [115] xml2_1.3.3 foreach_1.5.2 dygraphs_1.1.1.6
## [118] svglite_2.1.0 bslib_0.4.0 hardhat_1.2.0
## [121] webshot_0.5.3 prodlim_2019.11.13 rvest_1.0.2
## [124] distributional_0.3.0 callr_3.7.1 digest_0.6.29
## [127] rmarkdown_2.14 cellranger_1.1.0 shiny_1.7.1
## [130] gtools_3.9.3 nloptr_2.0.3 lifecycle_1.0.1
## [133] nlme_3.1-159 jsonlite_1.8.0 carData_3.0-5
## [136] viridisLite_0.4.1 fansi_1.0.3 pillar_1.7.0
## [139] lattice_0.20-45 loo_2.5.1 fastmap_1.1.0
## [142] httr_1.4.3 pkgbuild_1.3.1 survival_3.4-0
## [145] glue_1.6.2 xts_0.12.1 diffobj_0.3.5
## [148] shinythemes_1.2.0 iterators_1.0.14 class_7.3-20
## [151] stringi_1.7.8 sass_0.4.2 future.apply_1.9.0