Skip to contents

Prediction components are a functional decomposition of the model prediction. The sum of all components and the intercept equals the overall predicted value for an observation.

Usage

predict_components(object, new_data, max_interaction = NULL, predictors = NULL)

Arguments

object

A fit object of class rpf.

new_data

Data for new observations to predict.

max_interaction

integer or NULL: Maximum degree of interactions to consider. If NULL (the default), the max_interaction parameter of the rpf object is used. Must be between 1 (main effects only) and the max_interaction of the rpf object.

predictors

character or NULL: Vector of one or more column names of predictor variables in new_data to extract components for. If NULL, all variables and their interactions are returned.

Value

A list with elements:

  • m (data.table): Components for each main effect and interaction term, representing the functional decomposition of the prediction. All components together with the intercept (and the remainder, if present) sum up to the prediction. For multiclass classification, the number of output columns is multiplied by the number of levels in the outcome.

  • intercept (numeric(1)): Expected value of the prediction.

  • x (data.table): Copy of new_data containing predictors selected by predictors.

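  • target_levels (character): For multiclass classification only: Vector of target levels which can be used to disassemble m, as column names include both the term and the target level.

  • remainder (numeric): Present when not all components are returned, e.g. if max_interaction is lower than that of the rpf object. It holds the combined contribution of all omitted components for each observation, such that rowSums(m) + intercept + remainder equals the prediction.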

Details

Extracts all components up to degree max_interaction, which cannot exceed the value set when calling rpf(). The intercept is always included. Optionally, predictors can be specified to include only components involving the given variables. If max_interaction is greater than length(predictors), max_interaction is lowered accordingly.

Note

Depending on the number of predictors and max_interaction, the number of components can increase drastically, up to sum(choose(ncol(new_data), seq_len(max_interaction))).

Examples


# Regression task, only some predictors
train <-  mtcars[1:20, 1:4]
test <-  mtcars[21:32, 1:4]

set.seed(23)
rpfit <- rpf(mpg ~ ., data = train, max_interaction = 3, ntrees = 30)

# Extract all components, including main effects and interaction terms up to `max_interaction`
(components <- predict_components(rpfit, test))
#> $m
#>            cyl       disp         hp    cyl:disp      cyl:hp     disp:hp
#>          <num>      <num>      <num>       <num>       <num>       <num>
#>  1:  1.3171056  2.4939663  1.1179246 -0.01482717 -0.04027828 -0.42825535
#>  2: -0.8481770 -1.1761913  0.2848618  0.09476257  0.06668192  0.01536748
#>  3: -0.8481770 -1.1761913  0.2848618  0.09476257  0.06668192  0.01536748
#>  4: -0.8481770 -1.1761913 -3.4381764  0.09476257  0.06607252 -0.36467941
#>  5: -0.8481770 -1.3837212  0.1665770  0.09270936  0.08321648 -0.02961805
#>  6:  1.3171056  4.6572097  3.3746638  0.93009887  0.30303006  0.04533263
#>  7:  1.3171056  2.4939663  3.3746638 -0.01482717  0.30303006 -0.19752497
#>  8:  1.3171056  4.6572097  1.1318509  0.93009887 -0.35303353 -0.10831095
#>  9: -0.8481770 -1.1761913 -3.4381764  0.09476257  0.06607252 -0.36467941
#> 10: -0.6441787  0.6437832  0.1665770  0.13550304  0.09153429 -0.24911970
#> 11: -0.8481770 -1.1761913 -3.4381764  0.09476257  0.06607252 -0.36467941
#> 12:  1.3171056  2.4939663  0.4186601 -0.01482717 -0.04027828 -0.27031439
#>      cyl:disp:hp
#>            <num>
#>  1: -0.046862620
#>  2:  0.009109767
#>  3:  0.009109767
#>  4:  0.084106744
#>  5:  0.008909305
#>  6:  0.077872127
#>  7:  0.101920822
#>  8:  0.011412415
#>  9:  0.084106744
#> 10: -0.024506959
#> 11:  0.084106744
#> 12: -0.046862620
#> 
#> $intercept
#> [1] 20.2796
#> 
#> $x
#>       cyl  disp    hp
#>     <num> <num> <num>
#>  1:     4 120.1    97
#>  2:     8 318.0   150
#>  3:     8 304.0   150
#>  4:     8 350.0   245
#>  5:     8 400.0   175
#>  6:     4  79.0    66
#>  7:     4 120.3    91
#>  8:     4  95.1   113
#>  9:     8 351.0   264
#> 10:     6 145.0   175
#> 11:     8 301.0   335
#> 12:     4 121.0   109
#> 
#> attr(,"class")
#> [1] "glex"           "rpf_components" "list"          

# sums to prediction
cbind(
  m_sum = rowSums(components$m) + components$intercept,
  prediction = predict(rpfit, test)
)
#>       m_sum    .pred
#> 1  24.67837 24.67837
#> 2  18.72601 18.72601
#> 3  18.72601 18.72601
#> 4  14.69731 14.69731
#> 5  18.36949 18.36949
#> 6  30.98491 30.98491
#> 7  27.65793 27.65793
#> 8  27.86593 27.86593
#> 9  14.69731 14.69731
#> 10 20.39919 20.39919
#> 11 14.69731 14.69731
#> 12 24.13705 24.13705

# Only get components with interactions of a lower degree, ignoring 3-way interactions
predict_components(rpfit, test, max_interaction = 2)
#> $m
#>            cyl       disp         hp    cyl:disp      cyl:hp     disp:hp
#>          <num>      <num>      <num>       <num>       <num>       <num>
#>  1:  1.3171056  2.4939663  1.1179246 -0.01482717 -0.04027828 -0.42825535
#>  2: -0.8481770 -1.1761913  0.2848618  0.09476257  0.06668192  0.01536748
#>  3: -0.8481770 -1.1761913  0.2848618  0.09476257  0.06668192  0.01536748
#>  4: -0.8481770 -1.1761913 -3.4381764  0.09476257  0.06607252 -0.36467941
#>  5: -0.8481770 -1.3837212  0.1665770  0.09270936  0.08321648 -0.02961805
#>  6:  1.3171056  4.6572097  3.3746638  0.93009887  0.30303006  0.04533263
#>  7:  1.3171056  2.4939663  3.3746638 -0.01482717  0.30303006 -0.19752497
#>  8:  1.3171056  4.6572097  1.1318509  0.93009887 -0.35303353 -0.10831095
#>  9: -0.8481770 -1.1761913 -3.4381764  0.09476257  0.06607252 -0.36467941
#> 10: -0.6441787  0.6437832  0.1665770  0.13550304  0.09153429 -0.24911970
#> 11: -0.8481770 -1.1761913 -3.4381764  0.09476257  0.06607252 -0.36467941
#> 12:  1.3171056  2.4939663  0.4186601 -0.01482717 -0.04027828 -0.27031439
#> 
#> $intercept
#> [1] 20.2796
#> 
#> $x
#>       cyl  disp    hp
#>     <num> <num> <num>
#>  1:     4 120.1    97
#>  2:     8 318.0   150
#>  3:     8 304.0   150
#>  4:     8 350.0   245
#>  5:     8 400.0   175
#>  6:     4  79.0    66
#>  7:     4 120.3    91
#>  8:     4  95.1   113
#>  9:     8 351.0   264
#> 10:     6 145.0   175
#> 11:     8 301.0   335
#> 12:     4 121.0   109
#> 
#> $remainder
#>  [1] -0.046862620  0.009109767  0.009109767  0.084106744  0.008909305
#>  [6]  0.077872127  0.101920822  0.011412415  0.084106744 -0.024506959
#> [11]  0.084106744 -0.046862620
#> 
#> attr(,"class")
#> [1] "glex"           "rpf_components" "list"          

# Only retrieve main effects
(main_effects <- predict_components(rpfit, test, max_interaction = 1))
#> $m
#>            cyl       disp         hp
#>          <num>      <num>      <num>
#>  1:  1.3171056  2.4939663  1.1179246
#>  2: -0.8481770 -1.1761913  0.2848618
#>  3: -0.8481770 -1.1761913  0.2848618
#>  4: -0.8481770 -1.1761913 -3.4381764
#>  5: -0.8481770 -1.3837212  0.1665770
#>  6:  1.3171056  4.6572097  3.3746638
#>  7:  1.3171056  2.4939663  3.3746638
#>  8:  1.3171056  4.6572097  1.1318509
#>  9: -0.8481770 -1.1761913 -3.4381764
#> 10: -0.6441787  0.6437832  0.1665770
#> 11: -0.8481770 -1.1761913 -3.4381764
#> 12:  1.3171056  2.4939663  0.4186601
#> 
#> $intercept
#> [1] 20.2796
#> 
#> $x
#>       cyl  disp    hp
#>     <num> <num> <num>
#>  1:     4 120.1    97
#>  2:     8 318.0   150
#>  3:     8 304.0   150
#>  4:     8 350.0   245
#>  5:     8 400.0   175
#>  6:     4  79.0    66
#>  7:     4 120.3    91
#>  8:     4  95.1   113
#>  9:     8 351.0   264
#> 10:     6 145.0   175
#> 11:     8 301.0   335
#> 12:     4 121.0   109
#> 
#> $remainder
#>  [1] -0.53022343  0.18592174  0.18592174 -0.11973758  0.15521709  1.35633369
#>  [7]  0.19259874  0.48016681 -0.11973758 -0.04658934 -0.11973758 -0.37228247
#> 
#> attr(,"class")
#> [1] "glex"           "rpf_components" "list"          

# The difference is the combined contribution of interaction effects
cbind(
  m_sum = rowSums(main_effects$m) + main_effects$intercept,
  prediction = predict(rpfit, test)
)
#>       m_sum    .pred
#> 1  25.20859 24.67837
#> 2  18.54009 18.72601
#> 3  18.54009 18.72601
#> 4  14.81705 14.69731
#> 5  18.21428 18.36949
#> 6  29.62858 30.98491
#> 7  27.46533 27.65793
#> 8  27.38576 27.86593
#> 9  14.81705 14.69731
#> 10 20.44578 20.39919
#> 11 14.81705 14.69731
#> 12 24.50933 24.13705