Base class for SAGE (Shapley Additive Global Importance) feature importance, which distributes model performance across features using Shapley values with marginalization. This is an abstract class; use MarginalSAGE or ConditionalSAGE.
Details
SAGE uses Shapley values to fairly distribute the total prediction performance among all features. Unlike perturbation-based methods, SAGE marginalizes features by integrating over their distribution. This is approximated by averaging predictions over a reference dataset.
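For illustration only, a minimal sketch of this averaging step (marginalize, predict_fn, and reference are hypothetical names, not part of the package): features in the coalition S are taken from the observation of interest, the remaining features are filled in from reference rows, and the predictions are averaged.

# Sketch: approximate the prediction for observation x when only features in S are "known",
# by averaging over reference rows for the remaining features.
marginalize <- function(x, S, reference, predict_fn) {
  preds <- vapply(seq_len(nrow(reference)), function(i) {
    x_mixed <- reference[i, ]   # out-of-coalition features come from the reference data
    x_mixed[S] <- x[S]          # coalition features are kept from x
    predict_fn(x_mixed)
  }, numeric(1))
  mean(preds)                   # Monte Carlo approximation of the integral over the feature distribution
}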
Standard Error Calculation: The standard errors (SE) reported in
$convergence_history reflect the uncertainty in Shapley value estimation
across different random permutations within a single resampling iteration.
These SEs quantify the Monte Carlo sampling error for a fixed trained model
and are only valid for inference about the importance of features for that
specific model. They do not capture broader uncertainty from model variability
across different train/test splits or resampling iterations.
References
Covert I, Lundberg S, Lee S (2020). “Understanding Global Feature Contributions With Additive Importance Measures.” In Advances in Neural Information Processing Systems, volume 33, 17212–17223. https://proceedings.neurips.cc/paper/2020/hash/c7bf0b7c1a86d5eb3be2c722cf2cf746-Abstract.html.
Public fields
n_permutations (integer(1)) Number of permutations to sample.
convergence_history (data.table) History of SAGE values during computation.
converged (logical(1)) Whether convergence was detected.
n_permutations_used (integer(1)) Actual number of permutations used.
Methods
Method new()
Creates a new instance of the SAGE class.
Usage
SAGE$new(
task,
learner,
measure,
resampling = NULL,
features = NULL,
n_permutations = 10L,
batch_size = 5000L,
n_samples = 100L,
early_stopping = TRUE,
se_threshold = 0.01,
min_permutations = 3L,
check_interval = 1L
)
Arguments
task, learner, measure, resampling, features Passed to FeatureImportanceMethod.
n_permutations (integer(1): 10L) Number of permutations to sample for SAGE value estimation. The total number of evaluated coalitions is 1 (empty) + n_permutations * n_features.
batch_size (integer(1): 5000L) Maximum number of observations to process in a single prediction call.
n_samples (integer(1): 100L) Number of samples to use for marginalizing out-of-coalition features. For MarginalSAGE, this is the number of marginal data samples ("background data" in other implementations). For ConditionalSAGE, this is the number of conditional samples per test instance retrieved from sampler.
early_stopping (logical(1): TRUE) Whether to enable early stopping based on convergence detection.
se_threshold (numeric(1): 0.01) Convergence threshold for the relative standard error. Convergence is detected when the maximum relative SE across all features falls below this threshold. The relative SE is the SE divided by the range of importance values (max - min), making it scale-invariant across different loss metrics. The default of 0.01 means convergence is declared when the relative SE is below 1% of the importance range.
min_permutations (integer(1): 3L) Minimum number of permutations before checking for convergence.
check_interval (integer(1): 1L) Check convergence every N permutations.
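A short usage sketch, assuming the concrete MarginalSAGE subclass exposes this constructor and that the package providing it is already loaded (task, learner, and measure are standard mlr3 objects):

library(mlr3)
library(mlr3learners)

task    <- tsk("mtcars")
learner <- lrn("regr.ranger")
measure <- msr("regr.mse")

# SAGE itself is abstract, so construct the marginal variant.
sage <- MarginalSAGE$new(
  task           = task,
  learner        = learner,
  measure        = measure,
  n_permutations = 20L,
  n_samples      = 100L
)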
Method compute()
Compute SAGE values.
Usage
SAGE$compute(
store_backends = TRUE,
batch_size = NULL,
early_stopping = NULL,
se_threshold = NULL,
min_permutations = NULL,
check_interval = NULL
)
Arguments
store_backends (logical(1)) Whether to store data backends.
batch_size (integer(1): 5000L) Maximum number of observations to process in a single prediction call.
early_stopping (logical(1): TRUE) Whether to check for convergence and stop early.
se_threshold (numeric(1): 0.01) Convergence threshold for the relative standard error. The SE is normalized by the range of importance values (max - min) to make convergence detection scale-invariant. The default of 0.01 means convergence when the relative SE is below 1%.
min_permutations (integer(1): 3L) Minimum number of permutations before checking convergence.
check_interval (integer(1): 1L) Check convergence every N permutations.
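A sketch of calling compute() with the convergence-related settings overridden, continuing the constructor example above (the inspected field names are those listed under Public fields):

# Compute SAGE values, stopping early once the relative SE drops below 2%.
sage$compute(
  early_stopping   = TRUE,
  se_threshold     = 0.02,
  min_permutations = 5L
)

# Inspect convergence diagnostics exposed as public fields.
sage$converged            # TRUE if the SE threshold was reached
sage$n_permutations_used  # actual number of permutations evaluated
sage$convergence_history  # data.table of SAGE values during computation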