Extract predicted components from a Random Planted Forest
Source: `R/predict_components.R`
`predict_components.Rd`

Prediction components are a functional decomposition of the model prediction. The sum of all components, together with the intercept, equals the overall predicted value for an observation.
Arguments
- object
- A fit object of class `rpf`.
- new_data
- Data for new observations to predict. 
- max_interaction
- `integer` or `NULL`: Maximum degree of interactions to consider. By default, the `max_interaction` parameter of the `rpf` object is used. Must be between `1` (main effects only) and the `max_interaction` of the `rpf` object.
- predictors
- `character` or `NULL`: Vector of one or more column names of predictor variables in `new_data` for which to extract components. If `NULL`, all variables and their interactions are returned.
Value
A list with elements:
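- `m` (`data.table`): Components for each main effect and interaction term, representing the functional decomposition of the prediction. All components together with the intercept sum up to the prediction. For multiclass classification, the number of output columns is multiplied by the number of levels in the outcome.
- `intercept` (`numeric(1)`): Expected value of the prediction.
- `x` (`data.table`): Copy of `new_data` containing the predictors selected by `predictors`.
- `target_levels` (`character`): For multiclass classification only: vector of target levels, which can be used to split up the columns of `m`, as the column names include both the term and the target level.
- `remainder`: Present in the examples below whenever `max_interaction` is lower than that of the `rpf` object. For each observation, it holds the combined contribution of the omitted higher-order interaction terms, so that `m`, `intercept`, and `remainder` together still sum to the prediction.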
Details
Extracts all possible components up to degree `max_interaction`, which is itself capped by the value set when calling `rpf()`. The intercept is always included.
Optionally, `predictors` can be specified to include only components that involve the given variables.
If `max_interaction` is greater than `length(predictors)`, `max_interaction` is lowered accordingly.
Note
Depending on the number of predictors and `max_interaction`, the number of components can grow rapidly: it equals
`sum(choose(ncol(new_data), seq_len(max_interaction)))`. For example, 3 predictors with `max_interaction = 3` yield 3 + 3 + 1 = 7 components.
Examples
# Regression task, only some predictors
train <-  mtcars[1:20, 1:4]
test <-  mtcars[21:32, 1:4]
set.seed(23)
rpfit <- rpf(mpg ~ ., data = train, max_interaction = 3, ntrees = 30)
# Extract all components, including main effects and interaction terms up to `max_interaction`
(components <- predict_components(rpfit, test))
#> $m
#>            cyl       disp         hp    cyl:disp      cyl:hp     disp:hp
#>          <num>      <num>      <num>       <num>       <num>       <num>
#>  1:  1.3171056  2.4939663  1.1179246 -0.01482717 -0.04027828 -0.42825535
#>  2: -0.8481770 -1.1761913  0.2848618  0.09476257  0.06668192  0.01536748
#>  3: -0.8481770 -1.1761913  0.2848618  0.09476257  0.06668192  0.01536748
#>  4: -0.8481770 -1.1761913 -3.4381764  0.09476257  0.06607252 -0.36467941
#>  5: -0.8481770 -1.3837212  0.1665770  0.09270936  0.08321648 -0.02961805
#>  6:  1.3171056  4.6572097  3.3746638  0.93009887  0.30303006  0.04533263
#>  7:  1.3171056  2.4939663  3.3746638 -0.01482717  0.30303006 -0.19752497
#>  8:  1.3171056  4.6572097  1.1318509  0.93009887 -0.35303353 -0.10831095
#>  9: -0.8481770 -1.1761913 -3.4381764  0.09476257  0.06607252 -0.36467941
#> 10: -0.6441787  0.6437832  0.1665770  0.13550304  0.09153429 -0.24911970
#> 11: -0.8481770 -1.1761913 -3.4381764  0.09476257  0.06607252 -0.36467941
#> 12:  1.3171056  2.4939663  0.4186601 -0.01482717 -0.04027828 -0.27031439
#>      cyl:disp:hp
#>            <num>
#>  1: -0.046862620
#>  2:  0.009109767
#>  3:  0.009109767
#>  4:  0.084106744
#>  5:  0.008909305
#>  6:  0.077872127
#>  7:  0.101920822
#>  8:  0.011412415
#>  9:  0.084106744
#> 10: -0.024506959
#> 11:  0.084106744
#> 12: -0.046862620
#> 
#> $intercept
#> [1] 20.2796
#> 
#> $x
#>       cyl  disp    hp
#>     <num> <num> <num>
#>  1:     4 120.1    97
#>  2:     8 318.0   150
#>  3:     8 304.0   150
#>  4:     8 350.0   245
#>  5:     8 400.0   175
#>  6:     4  79.0    66
#>  7:     4 120.3    91
#>  8:     4  95.1   113
#>  9:     8 351.0   264
#> 10:     6 145.0   175
#> 11:     8 301.0   335
#> 12:     4 121.0   109
#> 
#> attr(,"class")
#> [1] "glex"           "rpf_components" "list"          
# sums to prediction
cbind(
  m_sum = rowSums(components$m) + components$intercept,
  prediction = predict(rpfit, test)
)
#>       m_sum    .pred
#> 1  24.67837 24.67837
#> 2  18.72601 18.72601
#> 3  18.72601 18.72601
#> 4  14.69731 14.69731
#> 5  18.36949 18.36949
#> 6  30.98491 30.98491
#> 7  27.65793 27.65793
#> 8  27.86593 27.86593
#> 9  14.69731 14.69731
#> 10 20.39919 20.39919
#> 11 14.69731 14.69731
#> 12 24.13705 24.13705
# Only get components with interactions of a lower degree, ignoring 3-way interactions
predict_components(rpfit, test, max_interaction = 2)
#> $m
#>            cyl       disp         hp    cyl:disp      cyl:hp     disp:hp
#>          <num>      <num>      <num>       <num>       <num>       <num>
#>  1:  1.3171056  2.4939663  1.1179246 -0.01482717 -0.04027828 -0.42825535
#>  2: -0.8481770 -1.1761913  0.2848618  0.09476257  0.06668192  0.01536748
#>  3: -0.8481770 -1.1761913  0.2848618  0.09476257  0.06668192  0.01536748
#>  4: -0.8481770 -1.1761913 -3.4381764  0.09476257  0.06607252 -0.36467941
#>  5: -0.8481770 -1.3837212  0.1665770  0.09270936  0.08321648 -0.02961805
#>  6:  1.3171056  4.6572097  3.3746638  0.93009887  0.30303006  0.04533263
#>  7:  1.3171056  2.4939663  3.3746638 -0.01482717  0.30303006 -0.19752497
#>  8:  1.3171056  4.6572097  1.1318509  0.93009887 -0.35303353 -0.10831095
#>  9: -0.8481770 -1.1761913 -3.4381764  0.09476257  0.06607252 -0.36467941
#> 10: -0.6441787  0.6437832  0.1665770  0.13550304  0.09153429 -0.24911970
#> 11: -0.8481770 -1.1761913 -3.4381764  0.09476257  0.06607252 -0.36467941
#> 12:  1.3171056  2.4939663  0.4186601 -0.01482717 -0.04027828 -0.27031439
#> 
#> $intercept
#> [1] 20.2796
#> 
#> $x
#>       cyl  disp    hp
#>     <num> <num> <num>
#>  1:     4 120.1    97
#>  2:     8 318.0   150
#>  3:     8 304.0   150
#>  4:     8 350.0   245
#>  5:     8 400.0   175
#>  6:     4  79.0    66
#>  7:     4 120.3    91
#>  8:     4  95.1   113
#>  9:     8 351.0   264
#> 10:     6 145.0   175
#> 11:     8 301.0   335
#> 12:     4 121.0   109
#> 
#> $remainder
#>  [1] -0.046862620  0.009109767  0.009109767  0.084106744  0.008909305
#>  [6]  0.077872127  0.101920822  0.011412415  0.084106744 -0.024506959
#> [11]  0.084106744 -0.046862620
#> 
#> attr(,"class")
#> [1] "glex"           "rpf_components" "list"          
# Only retrieve main effects
(main_effects <- predict_components(rpfit, test, max_interaction = 1))
#> $m
#>            cyl       disp         hp
#>          <num>      <num>      <num>
#>  1:  1.3171056  2.4939663  1.1179246
#>  2: -0.8481770 -1.1761913  0.2848618
#>  3: -0.8481770 -1.1761913  0.2848618
#>  4: -0.8481770 -1.1761913 -3.4381764
#>  5: -0.8481770 -1.3837212  0.1665770
#>  6:  1.3171056  4.6572097  3.3746638
#>  7:  1.3171056  2.4939663  3.3746638
#>  8:  1.3171056  4.6572097  1.1318509
#>  9: -0.8481770 -1.1761913 -3.4381764
#> 10: -0.6441787  0.6437832  0.1665770
#> 11: -0.8481770 -1.1761913 -3.4381764
#> 12:  1.3171056  2.4939663  0.4186601
#> 
#> $intercept
#> [1] 20.2796
#> 
#> $x
#>       cyl  disp    hp
#>     <num> <num> <num>
#>  1:     4 120.1    97
#>  2:     8 318.0   150
#>  3:     8 304.0   150
#>  4:     8 350.0   245
#>  5:     8 400.0   175
#>  6:     4  79.0    66
#>  7:     4 120.3    91
#>  8:     4  95.1   113
#>  9:     8 351.0   264
#> 10:     6 145.0   175
#> 11:     8 301.0   335
#> 12:     4 121.0   109
#> 
#> $remainder
#>  [1] -0.53022343  0.18592174  0.18592174 -0.11973758  0.15521709  1.35633369
#>  [7]  0.19259874  0.48016681 -0.11973758 -0.04658934 -0.11973758 -0.37228247
#> 
#> attr(,"class")
#> [1] "glex"           "rpf_components" "list"          
# The difference is the combined contribution of interaction effects
cbind(
  m_sum = rowSums(main_effects$m) + main_effects$intercept,
  prediction = predict(rpfit, test)
)
#>       m_sum    .pred
#> 1  25.20859 24.67837
#> 2  18.54009 18.72601
#> 3  18.54009 18.72601
#> 4  14.81705 14.69731
#> 5  18.21428 18.36949
#> 6  29.62858 30.98491
#> 7  27.46533 27.65793
#> 8  27.38576 27.86593
#> 9  14.81705 14.69731
#> 10 20.44578 20.39919
#> 11 14.81705 14.69731
#> 12 24.50933 24.13705