Extract predicted components from a Random Planted Forest
Source: R/predict_components.R
predict_components.Rd
Prediction components are a functional decomposition of the model prediction. The sum of all components equals the overall predicted value for an observation.
Arguments
- object
A fit object of class `rpf`.
- new_data
Data for new observations to predict.
- max_interaction
`integer` or `NULL`: Maximum degree of interactions to consider. By default, the `max_interaction` parameter from the `rpf` object is used. Must be between `1` (main effects only) and the `max_interaction` of the `rpf` object.
- predictors
`character` or `NULL`: Vector of one or more column names of predictor variables in `new_data` to extract components for. If `NULL`, all variables and their interactions are returned.
Value
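A `list` with elements:
- `m` (`data.table`): Components for each main effect and interaction term, representing the functional decomposition of the prediction. When all components are extracted, they sum up to the prediction together with the intercept. For multiclass classification, the number of output columns is multiplied by the number of levels in the outcome.
- `intercept` (`numeric(1)`): Expected value of the prediction.
- `x` (`data.table`): Copy of `new_data` containing the predictors selected by `predictors`.
- `remainder` (`numeric`): Present when `max_interaction` is lower than the `max_interaction` of the `rpf` object (see the examples). It holds the per-observation sum of the omitted higher-order components, so that `m`, `intercept`, and `remainder` together sum up to the prediction.
- `target_levels` (`character`): For multiclass classification only. Vector of target levels that can be used to disassemble `m`, since its column names include both the term and the target level.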
Details
Extracts all possible components up to `max_interaction` degrees, capped at the value set when calling `rpf()`. The intercept is always included.
Optionally, `predictors` can be specified to only include components involving the given variables.
If `max_interaction` is greater than `length(predictors)`, `max_interaction` will be lowered accordingly.
Note
Depending on the number of predictors and `max_interaction`, the number of components can grow rapidly, up to `sum(choose(ncol(new_data), seq_len(max_interaction)))`.
Examples
# Regression task, only some predictors
train <- mtcars[1:20, 1:4]
test <- mtcars[21:32, 1:4]
set.seed(23)
rpfit <- rpf(mpg ~ ., data = train, max_interaction = 3, ntrees = 30)
# Extract all components, including main effects and interaction terms up to `max_interaction`
(components <- predict_components(rpfit, test))
#> $m
#> cyl disp hp cyl:disp cyl:hp disp:hp
#> <num> <num> <num> <num> <num> <num>
#> 1: 1.3171056 2.4939663 1.1179246 -0.01482717 -0.04027828 -0.42825535
#> 2: -0.8481770 -1.1761913 0.2848618 0.09476257 0.06668192 0.01536748
#> 3: -0.8481770 -1.1761913 0.2848618 0.09476257 0.06668192 0.01536748
#> 4: -0.8481770 -1.1761913 -3.4381764 0.09476257 0.06607252 -0.36467941
#> 5: -0.8481770 -1.3837212 0.1665770 0.09270936 0.08321648 -0.02961805
#> 6: 1.3171056 4.6572097 3.3746638 0.93009887 0.30303006 0.04533263
#> 7: 1.3171056 2.4939663 3.3746638 -0.01482717 0.30303006 -0.19752497
#> 8: 1.3171056 4.6572097 1.1318509 0.93009887 -0.35303353 -0.10831095
#> 9: -0.8481770 -1.1761913 -3.4381764 0.09476257 0.06607252 -0.36467941
#> 10: -0.6441787 0.6437832 0.1665770 0.13550304 0.09153429 -0.24911970
#> 11: -0.8481770 -1.1761913 -3.4381764 0.09476257 0.06607252 -0.36467941
#> 12: 1.3171056 2.4939663 0.4186601 -0.01482717 -0.04027828 -0.27031439
#> cyl:disp:hp
#> <num>
#> 1: -0.046862620
#> 2: 0.009109767
#> 3: 0.009109767
#> 4: 0.084106744
#> 5: 0.008909305
#> 6: 0.077872127
#> 7: 0.101920822
#> 8: 0.011412415
#> 9: 0.084106744
#> 10: -0.024506959
#> 11: 0.084106744
#> 12: -0.046862620
#>
#> $intercept
#> [1] 20.2796
#>
#> $x
#> cyl disp hp
#> <num> <num> <num>
#> 1: 4 120.1 97
#> 2: 8 318.0 150
#> 3: 8 304.0 150
#> 4: 8 350.0 245
#> 5: 8 400.0 175
#> 6: 4 79.0 66
#> 7: 4 120.3 91
#> 8: 4 95.1 113
#> 9: 8 351.0 264
#> 10: 6 145.0 175
#> 11: 8 301.0 335
#> 12: 4 121.0 109
#>
#> attr(,"class")
#> [1] "glex" "rpf_components" "list"
# sums to prediction
cbind(
m_sum = rowSums(components$m) + components$intercept,
prediction = predict(rpfit, test)
)
#> m_sum .pred
#> 1 24.67837 24.67837
#> 2 18.72601 18.72601
#> 3 18.72601 18.72601
#> 4 14.69731 14.69731
#> 5 18.36949 18.36949
#> 6 30.98491 30.98491
#> 7 27.65793 27.65793
#> 8 27.86593 27.86593
#> 9 14.69731 14.69731
#> 10 20.39919 20.39919
#> 11 14.69731 14.69731
#> 12 24.13705 24.13705
# Only get components with interactions of a lower degree, ignoring 3-way interactions
predict_components(rpfit, test, max_interaction = 2)
#> $m
#> cyl disp hp cyl:disp cyl:hp disp:hp
#> <num> <num> <num> <num> <num> <num>
#> 1: 1.3171056 2.4939663 1.1179246 -0.01482717 -0.04027828 -0.42825535
#> 2: -0.8481770 -1.1761913 0.2848618 0.09476257 0.06668192 0.01536748
#> 3: -0.8481770 -1.1761913 0.2848618 0.09476257 0.06668192 0.01536748
#> 4: -0.8481770 -1.1761913 -3.4381764 0.09476257 0.06607252 -0.36467941
#> 5: -0.8481770 -1.3837212 0.1665770 0.09270936 0.08321648 -0.02961805
#> 6: 1.3171056 4.6572097 3.3746638 0.93009887 0.30303006 0.04533263
#> 7: 1.3171056 2.4939663 3.3746638 -0.01482717 0.30303006 -0.19752497
#> 8: 1.3171056 4.6572097 1.1318509 0.93009887 -0.35303353 -0.10831095
#> 9: -0.8481770 -1.1761913 -3.4381764 0.09476257 0.06607252 -0.36467941
#> 10: -0.6441787 0.6437832 0.1665770 0.13550304 0.09153429 -0.24911970
#> 11: -0.8481770 -1.1761913 -3.4381764 0.09476257 0.06607252 -0.36467941
#> 12: 1.3171056 2.4939663 0.4186601 -0.01482717 -0.04027828 -0.27031439
#>
#> $intercept
#> [1] 20.2796
#>
#> $x
#> cyl disp hp
#> <num> <num> <num>
#> 1: 4 120.1 97
#> 2: 8 318.0 150
#> 3: 8 304.0 150
#> 4: 8 350.0 245
#> 5: 8 400.0 175
#> 6: 4 79.0 66
#> 7: 4 120.3 91
#> 8: 4 95.1 113
#> 9: 8 351.0 264
#> 10: 6 145.0 175
#> 11: 8 301.0 335
#> 12: 4 121.0 109
#>
#> $remainder
#> [1] -0.046862620 0.009109767 0.009109767 0.084106744 0.008909305
#> [6] 0.077872127 0.101920822 0.011412415 0.084106744 -0.024506959
#> [11] 0.084106744 -0.046862620
#>
#> attr(,"class")
#> [1] "glex" "rpf_components" "list"
# Only retrieve main effects
(main_effects <- predict_components(rpfit, test, max_interaction = 1))
#> $m
#> cyl disp hp
#> <num> <num> <num>
#> 1: 1.3171056 2.4939663 1.1179246
#> 2: -0.8481770 -1.1761913 0.2848618
#> 3: -0.8481770 -1.1761913 0.2848618
#> 4: -0.8481770 -1.1761913 -3.4381764
#> 5: -0.8481770 -1.3837212 0.1665770
#> 6: 1.3171056 4.6572097 3.3746638
#> 7: 1.3171056 2.4939663 3.3746638
#> 8: 1.3171056 4.6572097 1.1318509
#> 9: -0.8481770 -1.1761913 -3.4381764
#> 10: -0.6441787 0.6437832 0.1665770
#> 11: -0.8481770 -1.1761913 -3.4381764
#> 12: 1.3171056 2.4939663 0.4186601
#>
#> $intercept
#> [1] 20.2796
#>
#> $x
#> cyl disp hp
#> <num> <num> <num>
#> 1: 4 120.1 97
#> 2: 8 318.0 150
#> 3: 8 304.0 150
#> 4: 8 350.0 245
#> 5: 8 400.0 175
#> 6: 4 79.0 66
#> 7: 4 120.3 91
#> 8: 4 95.1 113
#> 9: 8 351.0 264
#> 10: 6 145.0 175
#> 11: 8 301.0 335
#> 12: 4 121.0 109
#>
#> $remainder
#> [1] -0.53022343 0.18592174 0.18592174 -0.11973758 0.15521709 1.35633369
#> [7] 0.19259874 0.48016681 -0.11973758 -0.04658934 -0.11973758 -0.37228247
#>
#> attr(,"class")
#> [1] "glex" "rpf_components" "list"
# The difference is the combined contribution of interaction effects
cbind(
m_sum = rowSums(main_effects$m) + main_effects$intercept,
prediction = predict(rpfit, test)
)
#> m_sum .pred
#> 1 25.20859 24.67837
#> 2 18.54009 18.72601
#> 3 18.54009 18.72601
#> 4 14.81705 14.69731
#> 5 18.21428 18.36949
#> 6 29.62858 30.98491
#> 7 27.46533 27.65793
#> 8 27.38576 27.86593
#> 9 14.81705 14.69731
#> 10 20.44578 20.39919
#> 11 14.81705 14.69731
#> 12 24.50933 24.13705