Trucos. Parte 1. Submodel trick

estadística
tidymodels
2025
Published

February 9, 2025

Listening

El otro día atendí a un webminar junto con Aitor en el que se hablaba de algunos truquillos del tidymodels.

No soy muy fan del tuneo de hiperparámetros, pero es innegable que es algo que está ahí.

Nota

Este post está basado en lo leído en Efficient Machine Learnig with R

Submodel trick

Por ejemplo, si estamos haciendo un modelo de árboles con boosting, cada árbol se construye sobre el anterior. Por tanto si tengo un modelo con 200 árboles, puedo usarlo para ver que predicciones daría un modelo con 50 árboles, puesto que para llegar a 200 ha tenido que pasar por 50.

¿Por qué es esto útil? Si tenemos un grid de parámetros que incluya probar mismo modelo de boosting con esta configuración

Show the code
(grid <- expand.grid(
                    n_trees = c(10, 20, 100), 
                    learn_rate = c(0.1, 0.2, 2)
))
#>   n_trees learn_rate
#> 1      10        0.1
#> 2      20        0.1
#> 3     100        0.1
#> 4      10        0.2
#> 5      20        0.2
#> 6     100        0.2
#> 7      10        2.0
#> 8      20        2.0
#> 9     100        2.0

En realidad sólo tendríamos que ajustar 3 modelos, los correspondientes a n_trees = 100 y podríamos usar esos modelos para predecir con cualquier número de árboles del 1 al 100.

Veamos.

Funciones auxiliares para simular dataset

Show the code

bin_roughly <- function(x) {
  n_levels <- sample(1:4, 1)
  cutpoints <- sort(sample(x, n_levels))
  x <- rowSums(vapply(cutpoints, `>`, logical(length(x)),  x))
  factor(x, labels = paste0("level_", 1:(n_levels+1)))
}

simulate_regression <- function(n_rows) {
  modeldata::sim_regression(n_rows) |>
    select(-c(predictor_16:predictor_20)) |>
    mutate(across(contains("_1"), bin_roughly))
}

simulate_classification <- function(n_rows, n_levels) {
  modeldata::sim_classification(n_rows, num_linear = 12) |>
    mutate(across(contains("_1"), bin_roughly))
}

__ tidymodels y bonsai para ajustar modelos de boosting__

Show the code

library(tidymodels) # modelling framework
library(workflows)
library(bonsai)     # models like lightgm
library(future)     # parallel processing

Simulo datos clasificación

Show the code

set.seed(1)
d <- simulate_classification(3e4)
d
#> # A tibble: 30,000 × 18
#>    class   two_factor_1 two_factor_2 non_linear_1 non_linear_2 non_linear_3
#>    <fct>   <fct>               <dbl> <fct>               <dbl>        <dbl>
#>  1 class_1 level_4           -0.439  level_1             0.637        0.462
#>  2 class_1 level_3            0.764  level_1             0.658        0.266
#>  3 class_1 level_3           -1.33   level_1             0.657        0.349
#>  4 class_1 level_1            1.87   level_2             0.971        0.836
#>  5 class_2 level_2            0.109  level_1             0.534        0.984
#>  6 class_1 level_5           -0.0446 level_1             0.959        0.803
#>  7 class_1 level_2            1.13   level_2             0.593        0.711
#>  8 class_1 level_2            1.45   level_2             0.396        0.582
#>  9 class_2 level_2            1.13   level_2             0.438        0.277
#> 10 class_2 level_2           -1.04   level_1             0.624        0.737
#> # ℹ 29,990 more rows
#> # ℹ 12 more variables: linear_01 <dbl>, linear_02 <dbl>, linear_03 <dbl>,
#> #   linear_04 <dbl>, linear_05 <dbl>, linear_06 <dbl>, linear_07 <dbl>,
#> #   linear_08 <dbl>, linear_09 <dbl>, linear_10 <fct>, linear_11 <fct>,
#> #   linear_12 <fct>

# split en train test
d_split <- initial_split(d)
d_train <- training(d_split)
d_test <- testing(d_split)

# folds sobre train
d_folds <- vfold_cv(d_train, v = 5)

Modelo 1

Modelo con trees = 100 y learn_rate = 0.1.

Estimando este modelo, luego podemos hacer predicciones considerando modelos con menos número de árboles sin necesidad de estimarlos por separado.

Show the code

mod1_spec <- 
  boost_tree( trees = 100, learn_rate = 0.1)  |> 
  set_mode("classification")  |> 
  set_engine(engine = "xgboost")

recipe1 <- recipe(
                  class ~ ., 
                  data = d_train)  |>
        step_dummy(all_nominal_predictors())  |> 
        prep()

d_train_bake <- bake(recipe1, d_train)

tictoc::tic()
wf1_fit <- fit(mod1_spec,formula = class ~ .,  d_train_bake)
tictoc::toc()
#> 17.441 sec elapsed

En tidymodels para poder usar el submodel trick directamente teneemos la función multi_predict pero no tiene implementado la interfaz de fórmula por lo que tenemos que usar fit_xy

Show the code

 wf1_fit_to_multipredict  <-  
  mod1_spec |>  
  fit_xy(x = d_train_bake  |>  dplyr::select(-class), y = d_train_bake$class)

Ahora podemos usar la función multi_predict para obtener las prediciones que se obtendrían con un modelo con menos árboles. Es decir, ajustamos un solo modelo con 100 árboles, pero podemos “podar” ese modelo y obtener predicciones usando menos árboles sin tener que reestimar.

Show the code

test_bake <- bake(recipe1, d_test)

pred_test_100_trees <- multi_predict(wf1_fit_to_multipredict,
                                     type = "prob",
                                     new_data = test_bake |>  dplyr::select(-class),
                                     trees = 100 )

pred_test_10_trees <- multi_predict(wf1_fit_to_multipredict,
                                     type = "prob",
                                     new_data = test_bake |>  dplyr::select(-class),
                                     trees = c(10))

head(pred_test_100_trees)
#> # A tibble: 6 × 1
#>   .pred           
#>   <list>          
#> 1 <tibble [1 × 3]>
#> 2 <tibble [1 × 3]>
#> 3 <tibble [1 × 3]>
#> 4 <tibble [1 × 3]>
#> 5 <tibble [1 × 3]>
#> 6 <tibble [1 × 3]>
head(pred_test_10_trees)
#> # A tibble: 6 × 1
#>   .pred           
#>   <list>          
#> 1 <tibble [1 × 3]>
#> 2 <tibble [1 × 3]>
#> 3 <tibble [1 × 3]>
#> 4 <tibble [1 × 3]>
#> 5 <tibble [1 × 3]>
#> 6 <tibble [1 × 3]>

# también se puede hacer con varios árboles a la vez
pred_test_10_20_trees <- multi_predict(wf1_fit_to_multipredict,
                                     type = "prob",
                                     new_data = test_bake |>  dplyr::select(-class),
                                     trees = c(10, 20))

head(pred_test_10_20_trees)
#> # A tibble: 6 × 1
#>   .pred           
#>   <list>          
#> 1 <tibble [2 × 3]>
#> 2 <tibble [2 × 3]>
#> 3 <tibble [2 × 3]>
#> 4 <tibble [2 × 3]>
#> 5 <tibble [2 × 3]>
#> 6 <tibble [2 × 3]>

Podemos ver las predicciones con el modelo completo (100 árboles), con el submodelo (10 árboles) y con los dos submodelos de 10 y 20 árboles.

Show the code

pred_test_100_trees  |>  
  unnest(.pred)  |> 
  slice_head( n = 10)
#> # A tibble: 10 × 3
#>    trees .pred_class_1 .pred_class_2
#>    <dbl>         <dbl>         <dbl>
#>  1   100       0.0647         0.935 
#>  2   100       0.804          0.196 
#>  3   100       0.0584         0.942 
#>  4   100       0.0221         0.978 
#>  5   100       0.590          0.410 
#>  6   100       0.979          0.0210
#>  7   100       0.355          0.645 
#>  8   100       0.00470        0.995 
#>  9   100       0.535          0.465 
#> 10   100       0.951          0.0486
Show the code

pred_test_10_trees  |>  
  unnest(.pred)  |> 
  slice_head( n = 10)
#> # A tibble: 10 × 3
#>    trees .pred_class_1 .pred_class_2
#>    <dbl>         <dbl>         <dbl>
#>  1    10         0.220         0.780
#>  2    10         0.687         0.313
#>  3    10         0.228         0.772
#>  4    10         0.264         0.736
#>  5    10         0.465         0.535
#>  6    10         0.759         0.241
#>  7    10         0.465         0.535
#>  8    10         0.184         0.816
#>  9    10         0.544         0.456
#> 10    10         0.720         0.280
Show the code

pred_test_10_20_trees  |>  
  unnest(.pred)  |> 
  slice_head( n = 10)
#> # A tibble: 10 × 3
#>    trees .pred_class_1 .pred_class_2
#>    <dbl>         <dbl>         <dbl>
#>  1    10         0.220         0.780
#>  2    20         0.111         0.889
#>  3    10         0.687         0.313
#>  4    20         0.726         0.274
#>  5    10         0.228         0.772
#>  6    20         0.124         0.876
#>  7    10         0.264         0.736
#>  8    20         0.135         0.865
#>  9    10         0.465         0.535
#> 10    20         0.490         0.510

Comparamos métrica de roc_auc entre el modelo de 100 árboles y el submodelo de 10

Show the code

pred_test_100_trees  |>  
  unnest(.pred)  |> 
  bind_cols(d_test |> select(class)) |> 
  roc_auc(truth = class, .pred_class_1 )
#> # A tibble: 1 × 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.952

pred_test_10_trees  |>  
  unnest(.pred)  |> 
  bind_cols(d_test |> select(class)) |> 
  roc_auc(truth = class, .pred_class_1 )
#> # A tibble: 1 × 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.945

¿De qué nos sirve este truco y como funciona en tidymodels?

Cuando hacemos validación cruzada para encontrar los mejores hiperparámetros, este truco sirve para no tener que ejecutar varios modelos. Para que funcione en tidymodels hay que pasarle un grid en formato tibble o data.frame a la función de tune_grid. Veámoslo

Hacemos el “tuneado” dejando que sea tidymodels quien haga el grid, en este caso no se utiliza el submodel trick. Si tenemos 4 combinaciones y 5 folds, tidymodels ejecutará 4 x 5 modelos.

Nota: Cuando haya muchos datos y muchos folds es mejor usar plan(multicore) o plan(multisession) para que tidymodels paralelice y haga fold en un proceso. Pero en ese caso es muy importante poner que la paralelización nativa con xgboost o lightgbm utilice un solo hilo (o como mucho 2), porque si no entra en conflicto la paralelización de future con la nativa de esas librerías basada en OpenMP. Gracias a Jordi Rosell por las pistas.

Show the code

bt <- 
  boost_tree( trees = tune(), learn_rate = tune()) |>
  set_mode("classification")

# engine xgboost
bt
#> Boosted Tree Model Specification (classification)
#> 
#> Main Arguments:
#>   trees = tune()
#>   learn_rate = tune()
#> 
#> Computational engine: xgboost

Si dejamos a tidymodels hacer el grid de parámetros usará combinaciones aleatorias de los parámetros. Por ejemplo

Show the code

# Extraer hiperparámetros y definir rango
param_grid <- hardhat::extract_parameter_set_dials(bt) |>
  update(trees = trees(range = c(1, 140)), learn_rate = learn_rate(range= c(-1, 1)))

set.seed(49)
grid_random <- grid_random(param_grid, size = 16)



grid_random
#> # A tibble: 16 × 2
#>    trees learn_rate
#>    <int>      <dbl>
#>  1    19      7.78 
#>  2    89      0.202
#>  3    32      2.23 
#>  4    47      2.65 
#>  5   131      0.158
#>  6    85      1.27 
#>  7   124      0.139
#>  8    97      1.22 
#>  9    84      6.22 
#> 10    66      1.41 
#> 11    29      1.62 
#> 12    46      0.262
#> 13     9      1.32 
#> 14    28      4.31 
#> 15   135      0.268
#> 16   137      0.358

Y sería raro que se tenga mismo valor de learn_rate para distintos valores de trees. Recordemos que en estos modelos el submodel trick funciona solo para trees. Si tuviéramos mismo valor de learn_rate y diferentes valores de trees bastaría con ajustar el modelo con mayor número de árboles.

Veamos cuanta tarda en ajustar estas 16 combinaciones de parámetros, sobre los 5 folds. Es decir 80 modelos.

Show the code

tictoc::tic()
    basic <- 
      tune_grid(
        object = bt,
        preprocessor = class ~ .,
        resamples = d_folds,
        grid = grid_random
      )
tictoc::toc()
#> 679.145 sec elapsed
Show the code

(metricas_basic <- collect_metrics(basic))
#> # A tibble: 48 × 8
#>    trees learn_rate .metric     .estimator   mean     n std_err .config         
#>    <int>      <dbl> <chr>       <chr>       <dbl> <int>   <dbl> <chr>           
#>  1   124      0.139 accuracy    binary     0.883      5 0.00231 Preprocessor1_M…
#>  2   124      0.139 brier_class binary     0.0827     5 0.00156 Preprocessor1_M…
#>  3   124      0.139 roc_auc     binary     0.957      5 0.00167 Preprocessor1_M…
#>  4   131      0.158 accuracy    binary     0.881      5 0.00204 Preprocessor1_M…
#>  5   131      0.158 brier_class binary     0.0838     5 0.00152 Preprocessor1_M…
#>  6   131      0.158 roc_auc     binary     0.956      5 0.00165 Preprocessor1_M…
#>  7    89      0.202 accuracy    binary     0.882      5 0.00259 Preprocessor1_M…
#>  8    89      0.202 brier_class binary     0.0837     5 0.00165 Preprocessor1_M…
#>  9    89      0.202 roc_auc     binary     0.956      5 0.00181 Preprocessor1_M…
#> 10    46      0.262 accuracy    binary     0.883      5 0.00204 Preprocessor1_M…
#> # ℹ 38 more rows

metricas_basic  |>  
  filter(.metric == "roc_auc")  |> 
  arrange(desc(mean))
#> # A tibble: 16 × 8
#>    trees learn_rate .metric .estimator  mean     n std_err .config              
#>    <int>      <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
#>  1   124      0.139 roc_auc binary     0.957     5 0.00167 Preprocessor1_Model01
#>  2    46      0.262 roc_auc binary     0.956     5 0.00179 Preprocessor1_Model04
#>  3   131      0.158 roc_auc binary     0.956     5 0.00165 Preprocessor1_Model02
#>  4    89      0.202 roc_auc binary     0.956     5 0.00181 Preprocessor1_Model03
#>  5   135      0.268 roc_auc binary     0.954     5 0.00170 Preprocessor1_Model05
#>  6   137      0.358 roc_auc binary     0.951     5 0.00157 Preprocessor1_Model06
#>  7     9      1.32  roc_auc binary     0.944     5 0.00225 Preprocessor1_Model09
#>  8    97      1.22  roc_auc binary     0.939     5 0.00225 Preprocessor1_Model07
#>  9    85      1.27  roc_auc binary     0.937     5 0.00217 Preprocessor1_Model08
#> 10    66      1.41  roc_auc binary     0.932     5 0.00145 Preprocessor1_Model10
#> 11    29      1.62  roc_auc binary     0.918     5 0.00514 Preprocessor1_Model11
#> 12    84      6.22  roc_auc binary     0.717     5 0.0344  Preprocessor1_Model15
#> 13    28      4.31  roc_auc binary     0.695     5 0.0524  Preprocessor1_Model14
#> 14    19      7.78  roc_auc binary     0.665     5 0.0562  Preprocessor1_Model16
#> 15    32      2.23  roc_auc binary     0.639     5 0.0352  Preprocessor1_Model12
#> 16    47      2.65  roc_auc binary     0.595     5 0.0287  Preprocessor1_Model13

Para aprovechar el submodel trick hay que construir un grid dónde a igual combinación de otros parámetros tengamos diferentes valores de trees

Show the code

grid_regular <- grid_regular(param_grid, levels = 4)

grid_regular
#> # A tibble: 16 × 2
#>    trees learn_rate
#>    <int>      <dbl>
#>  1     1      0.1  
#>  2    47      0.1  
#>  3    93      0.1  
#>  4   140      0.1  
#>  5     1      0.464
#>  6    47      0.464
#>  7    93      0.464
#>  8   140      0.464
#>  9     1      2.15 
#> 10    47      2.15 
#> 11    93      2.15 
#> 12   140      2.15 
#> 13     1     10    
#> 14    47     10    
#> 15    93     10    
#> 16   140     10

Ahora en vez de tener que ajustar las 16 combinaciones solo hace falta ajustar las 4 dónde trees = 140 y tidymodels lo tiene en cuenta.

Show the code

tictoc::tic()
    xgb_with_sub_trick <- 
      tune_grid(
        object = bt,
        preprocessor = class ~ .,
        resamples = d_folds,
        grid = grid_regular
      )
tictoc::toc()
#> 259.045 sec elapsed
Show the code

(metricas_xgb_sub_trick <- collect_metrics(xgb_with_sub_trick))
#> # A tibble: 48 × 8
#>    trees learn_rate .metric     .estimator   mean     n  std_err .config        
#>    <int>      <dbl> <chr>       <chr>       <dbl> <int>    <dbl> <chr>          
#>  1     1        0.1 accuracy    binary     0.864      5 0.00356  Preprocessor1_…
#>  2     1        0.1 brier_class binary     0.220      5 0.000194 Preprocessor1_…
#>  3     1        0.1 roc_auc     binary     0.943      5 0.00237  Preprocessor1_…
#>  4    47        0.1 accuracy    binary     0.886      5 0.00118  Preprocessor1_…
#>  5    47        0.1 brier_class binary     0.0812     5 0.00154  Preprocessor1_…
#>  6    47        0.1 roc_auc     binary     0.958      5 0.00187  Preprocessor1_…
#>  7    93        0.1 accuracy    binary     0.885      5 0.00188  Preprocessor1_…
#>  8    93        0.1 brier_class binary     0.0813     5 0.00155  Preprocessor1_…
#>  9    93        0.1 roc_auc     binary     0.958      5 0.00175  Preprocessor1_…
#> 10   140        0.1 accuracy    binary     0.883      5 0.00166  Preprocessor1_…
#> # ℹ 38 more rows

metricas_xgb_sub_trick  |>  
  filter(.metric == "roc_auc")  |> 
  arrange(desc(mean))
#> # A tibble: 16 × 8
#>    trees learn_rate .metric .estimator  mean     n std_err .config              
#>    <int>      <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
#>  1    47      0.1   roc_auc binary     0.958     5 0.00187 Preprocessor1_Model02
#>  2    93      0.1   roc_auc binary     0.958     5 0.00175 Preprocessor1_Model03
#>  3   140      0.1   roc_auc binary     0.957     5 0.00163 Preprocessor1_Model04
#>  4    47      0.464 roc_auc binary     0.952     5 0.00174 Preprocessor1_Model06
#>  5    93      0.464 roc_auc binary     0.950     5 0.00150 Preprocessor1_Model07
#>  6   140      0.464 roc_auc binary     0.950     5 0.00147 Preprocessor1_Model08
#>  7     1      0.1   roc_auc binary     0.943     5 0.00237 Preprocessor1_Model01
#>  8     1      0.464 roc_auc binary     0.943     5 0.00237 Preprocessor1_Model05
#>  9     1      2.15  roc_auc binary     0.943     5 0.00237 Preprocessor1_Model09
#> 10     1     10     roc_auc binary     0.941     5 0.00222 Preprocessor1_Model13
#> 11    47     10     roc_auc binary     0.622     5 0.0394  Preprocessor1_Model14
#> 12    93     10     roc_auc binary     0.622     5 0.0394  Preprocessor1_Model15
#> 13   140     10     roc_auc binary     0.622     5 0.0394  Preprocessor1_Model16
#> 14    47      2.15  roc_auc binary     0.590     5 0.0510  Preprocessor1_Model10
#> 15    93      2.15  roc_auc binary     0.590     5 0.0510  Preprocessor1_Model11
#> 16   140      2.15  roc_auc binary     0.590     5 0.0510  Preprocessor1_Model12

Y debería haber tardado menos.

Vemos las métricas

Truco adicional

Otra que cosa que se puede hacer, es simplemente cambiando el engine a uno que funcione más rápido. por ejempolo a lightgbm

Show the code

bt_lgb <- bt |> set_engine("lightgbm")


tictoc::tic()
lgb_with_sub_trick <-  
      tune_grid(
        object = bt_lgb,
        preprocessor = class ~ .,
        resamples = d_folds,
        grid = grid_regular
      )
tictoc::toc()   
#> 49.919 sec elapsed
Show the code

(metricas_lgb_with_sub_trick <- collect_metrics(lgb_with_sub_trick))
#> # A tibble: 48 × 8
#>    trees learn_rate .metric     .estimator   mean     n  std_err .config        
#>    <int>      <dbl> <chr>       <chr>       <dbl> <int>    <dbl> <chr>          
#>  1     1        0.1 accuracy    binary     0.578      5 0.0240   Preprocessor1_…
#>  2     1        0.1 brier_class binary     0.217      5 0.000366 Preprocessor1_…
#>  3     1        0.1 roc_auc     binary     0.947      5 0.00232  Preprocessor1_…
#>  4    47        0.1 accuracy    binary     0.885      5 0.00161  Preprocessor1_…
#>  5    47        0.1 brier_class binary     0.0810     5 0.00143  Preprocessor1_…
#>  6    47        0.1 roc_auc     binary     0.958      5 0.00163  Preprocessor1_…
#>  7    93        0.1 accuracy    binary     0.885      5 0.00189  Preprocessor1_…
#>  8    93        0.1 brier_class binary     0.0814     5 0.00152  Preprocessor1_…
#>  9    93        0.1 roc_auc     binary     0.958      5 0.00165  Preprocessor1_…
#> 10   140        0.1 accuracy    binary     0.883      5 0.00200  Preprocessor1_…
#> # ℹ 38 more rows

metricas_lgb_with_sub_trick  |> 
  filter(.metric == "roc_auc")  |> 
  arrange(desc(mean))
#> # A tibble: 16 × 8
#>    trees learn_rate .metric .estimator  mean     n std_err .config              
#>    <int>      <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
#>  1    47      0.1   roc_auc binary     0.958     5 0.00163 Preprocessor1_Model02
#>  2    93      0.1   roc_auc binary     0.958     5 0.00165 Preprocessor1_Model03
#>  3   140      0.1   roc_auc binary     0.957     5 0.00165 Preprocessor1_Model04
#>  4    47      0.464 roc_auc binary     0.950     5 0.00170 Preprocessor1_Model06
#>  5    93      0.464 roc_auc binary     0.948     5 0.00144 Preprocessor1_Model07
#>  6   140      0.464 roc_auc binary     0.948     5 0.00145 Preprocessor1_Model08
#>  7     1      0.1   roc_auc binary     0.947     5 0.00232 Preprocessor1_Model01
#>  8     1      0.464 roc_auc binary     0.947     5 0.00232 Preprocessor1_Model05
#>  9     1      2.15  roc_auc binary     0.947     5 0.00232 Preprocessor1_Model09
#> 10     1     10     roc_auc binary     0.947     5 0.00232 Preprocessor1_Model13
#> 11    47      2.15  roc_auc binary     0.597     5 0.0282  Preprocessor1_Model10
#> 12    93      2.15  roc_auc binary     0.597     5 0.0282  Preprocessor1_Model11
#> 13   140      2.15  roc_auc binary     0.597     5 0.0282  Preprocessor1_Model12
#> 14    47     10     roc_auc binary     0.538     5 0.0354  Preprocessor1_Model14
#> 15    93     10     roc_auc binary     0.538     5 0.0354  Preprocessor1_Model15
#> 16   140     10     roc_auc binary     0.538     5 0.0354  Preprocessor1_Model16

Si os fijáis en las salidas al final de cada tictoc::toc() se ve que con el truco de los submodelos ganamos velocidad, pero que si además cambiamos de “engine” puede llegar a ser 10 veces más rápido.

En próximas entradas contaré algún truquillo más, como los método de racing, que permite que cuando estemos haciendo el “tuning”, se descarten modelos sin tener qeu esperar a que se ajusten en todos los folds.

Un saludo.