library(xplainfi)
library(mlr3learners)
#> Loading required package: mlr3

# Data manip and visualization
library(data.table)
library(ggplot2)

Multiple (work-in-progress) inference methods are possible with the underlying implementation, but the API around them is still being worked out.

Setup

We use a simple linear DGP for demonstration purposes where

  • \(X_1\) and \(X_2\) are strongly correlated (r = 0.7)
  • \(X_1\) and \(X_3\) have an effect on \(Y\)
  • \(X_2\) and \(X_4\) don’t have an effect
task = sim_dgp_correlated(n = 2000, r = 0.7)
learner = lrn("regr.ranger", num.trees = 500)
measure = msr("regr.mse")
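
To make the correlation structure concrete, here is a minimal base-R sketch of a DGP with the same qualitative structure. The coefficients (2 for x1, 1 for x3) and the standard normal noise are hypothetical choices for illustration only; they are not necessarily what sim_dgp_correlated() uses internally.

```r
# Hypothetical sketch of a correlated linear DGP; the coefficients are
# assumptions, not the ones used by sim_dgp_correlated()
simulate_correlated_dgp <- function(n = 2000, r = 0.7) {
  x1 <- rnorm(n)
  # x2 = r * x1 + noise gives cor(x1, x2) = r and unit variance for x2
  x2 <- r * x1 + sqrt(1 - r^2) * rnorm(n)
  x3 <- rnorm(n)
  x4 <- rnorm(n)
  # Only x1 and x3 enter the outcome; x2 is merely correlated with x1
  y <- 2 * x1 + 1 * x3 + rnorm(n)
  data.frame(x1, x2, x3, x4, y)
}

set.seed(2025)
dat <- simulate_correlated_dgp()
round(cor(dat$x1, dat$x2), 2)
```

Under a DGP like this, a marginal method such as PFI can still attribute some importance to x2 through its correlation with x1, which makes x2 a useful test case for the conditional methods below.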

DAG for correlated features DGP

Variance-correction

When we calculate PFI using an appropriate resampling scheme, such as subsampling with 15 repeats, we can use the approach recommended by Molnar et al. (2023), which is based on the correction proposed by Nadeau & Bengio (2003).

By default, the $importance() method of any importance measure does not output variances or confidence intervals; it merely averages over resampling iterations and over repeats within resampling iterations (iter_repeat here).

pfi = PFI$new(
    task = task,
    learner = learner,
    resampling = rsmp("subsampling", repeats = 15),
    measure = measure,
    n_repeats = 10 # for stability within resampling iters
)

pfi$compute()
pfi$importance()
#> Key: <feature>
#>    feature   importance
#>     <char>        <num>
#> 1:      x1  6.451498301
#> 2:      x2  0.094961175
#> 3:      x3  1.801051193
#> 4:      x4 -0.001016452

If we want unadjusted confidence intervals based on a t-distribution, we can ask for them, but note these are too narrow / optimistic and hence invalid for inference:

pfi_ci_raw = pfi$importance(ci_method = "raw")
pfi_ci_raw
#> Key: <feature>
#>    feature   importance           se   conf_lower    conf_upper
#>     <char>        <num>        <num>        <num>         <num>
#> 1:      x1  6.451498301 0.0777539664  6.284732629  6.6182639734
#> 2:      x2  0.094961175 0.0041677716  0.086022194  0.1039001566
#> 3:      x3  1.801051193 0.0157425383  1.767286806  1.8348155792
#> 4:      x4 -0.001016452 0.0003145202 -0.001691031 -0.0003418737

Analogously we can retrieve the Nadeau & Bengio-adjusted standard errors and derived confidence intervals which were demonstrated to have better (but still imperfect) coverage:

pfi_ci_corrected = pfi$importance(ci_method = "nadeau_bengio")
pfi_ci_corrected
#> Key: <feature>
#>    feature   importance           se   conf_lower   conf_upper
#>     <char>        <num>        <num>        <num>        <num>
#> 1:      x1  6.451498301 0.2267648327  5.965136107 6.9378604959
#> 2:      x2  0.094961175 0.0121550589  0.068891167 0.1210311840
#> 3:      x3  1.801051193 0.0459121794  1.702579361 1.8995230240
#> 4:      x4 -0.001016452 0.0009172796 -0.002983821 0.0009509165
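
The two parametric intervals differ only in their standard error. With \(J\) resampling iterations, sample variance \(\hat\sigma^2\) of the per-iteration importance scores, and \(n_1\) training and \(n_2\) test observations per iteration, the raw standard error is \(\sqrt{\hat\sigma^2 / J}\), while the Nadeau & Bengio correction inflates the variance to \((1/J + n_2/n_1)\,\hat\sigma^2\); both intervals then use \(t_{J-1}\) quantiles. The base-R sketch below illustrates this on simulated per-iteration scores. The scores and the 2/3 training fraction are placeholders, and this is not necessarily how xplainfi computes the intervals internally.

```r
# Sketch: raw vs. Nadeau & Bengio-corrected t-intervals for a mean importance.
# 'scores' are simulated placeholder per-iteration importances, not real results.
nb_intervals <- function(scores, n_train, n_test, level = 0.95) {
  J <- length(scores)
  est <- mean(scores)
  s2 <- var(scores)                               # variance across iterations
  se_raw <- sqrt(s2 / J)                          # ignores overlapping training sets
  se_nb <- sqrt((1 / J + n_test / n_train) * s2)  # Nadeau & Bengio correction
  q <- qt(1 - (1 - level) / 2, df = J - 1)
  data.frame(
    method = c("raw", "nadeau_bengio"),
    importance = est,
    se = c(se_raw, se_nb),
    conf_lower = est - q * c(se_raw, se_nb),
    conf_upper = est + q * c(se_raw, se_nb)
  )
}

set.seed(1)
n <- 2000
n_train <- round(2 / 3 * n)  # assumed subsampling training fraction
scores <- rnorm(15, mean = 6.5, sd = 0.3)
res_nb <- nb_intervals(scores, n_train = n_train, n_test = n - n_train)
res_nb
```

With \(J = 15\) and \(n_2/n_1 \approx 0.5\), the corrected standard error is about \(\sqrt{1 + 15 \cdot 0.5} \approx 2.9\) times the raw one, which is in line with the ratio of the se columns in the two outputs above (0.2268 / 0.0778).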

Empirical quantiles

Both the "raw" and "nadeau_bengio" methods assume normally distributed importance scores and construct parametric confidence intervals from the t-distribution. As an alternative, we can use empirical quantiles to construct confidence-like intervals, which come without any coverage guarantees.

pfi_ci_quantile = pfi$importance(ci_method = "quantile")
pfi_ci_quantile
#> Key: <feature>
#>    feature   importance   conf_lower   conf_upper
#>     <char>        <num>        <num>        <num>
#> 1:      x1  6.451498301  5.953713874 6.9244524536
#> 2:      x2  0.094961175  0.074956969 0.1216748325
#> 3:      x3  1.801051193  1.695615492 1.9012580196
#> 4:      x4 -0.001016452 -0.003636112 0.0005228665
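
Conceptually, the quantile method takes empirical quantiles of the per-iteration importance scores. A base-R sketch of that idea, assuming R's default (type 7) quantile() and the mean as point estimate; xplainfi's exact quantile definition may differ, and the scores are simulated placeholders:

```r
# Sketch: percentile-style interval from per-iteration importance scores.
# 'scores' is a simulated placeholder, not output from xplainfi.
quantile_interval <- function(scores, level = 0.95) {
  alpha <- 1 - level
  qs <- quantile(scores, probs = c(alpha / 2, 1 - alpha / 2), names = FALSE)
  data.frame(importance = mean(scores), conf_lower = qs[1], conf_upper = qs[2])
}

set.seed(1)
scores <- rnorm(15, mean = 6.5, sd = 0.3)
res_q <- quantile_interval(scores)
res_q
```

With only 15 iterations, the 2.5% and 97.5% quantiles are driven by the few most extreme scores, so these intervals can be unstable and need not be symmetric around the estimate.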

To highlight the differences between parametric and empirical approaches, we visualize all methods:

pfi_cis = rbindlist(
    list(
        pfi_ci_raw[, type := "raw"],
        pfi_ci_corrected[, type := "nadeau_bengio"],
        pfi_ci_quantile[, type := "quantile"]
    ),
    fill = TRUE
)

ggplot(pfi_cis, aes(y = feature, color = type)) +
    geom_errorbar(
        aes(xmin = conf_lower, xmax = conf_upper),
        position = position_dodge(width = 0.6),
        width = .5
    ) +
    geom_point(aes(x = importance), position = position_dodge(width = 0.6)) +
    scale_color_brewer(palette = "Set2") +
    labs(
        title = "Parametric & non-parametric CI methods",
        subtitle = "RF with 15 subsampling iterations",
        color = NULL
    ) +
    theme_minimal(base_size = 14) +
    theme(legend.position = "bottom")

The plot highlights just how optimistic (narrow) the unadjusted raw confidence intervals are.

Conditional predictive impact (CPI)

CPI is implemented in the cpi package and provides conditional variable importance using knockoffs. It works with mlr3, and its output on our data looks like this:

library(cpi)

resampling = rsmp("cv", folds = 5)
resampling$instantiate(task)
cpi_res = cpi(
    task = task,
    learner = learner,
    resampling = resampling,
    measure = measure,
    test = "t"
)
setDT(cpi_res)
setnames(cpi_res, "Variable", "feature")
cpi_res[, method := "CPI"]

cpi_res
#>    feature          CPI          SE   test   statistic     estimate
#>     <char>        <num>       <num> <char>       <num>        <num>
#> 1:      x1 4.554770e+00 0.142063012      t 32.06161590 4.554770e+00
#> 2:      x2 2.479253e-03 0.003012508      t  0.82298643 2.479253e-03
#> 3:      x3 1.848159e+00 0.058189043      t 31.76129690 1.848159e+00
#> 4:      x4 4.283886e-05 0.000696889      t  0.06147158 4.283886e-05
#>          p.value        ci.lo method
#>            <num>        <num> <char>
#> 1: 1.209450e-182  4.320988517    CPI
#> 2:  2.053069e-01 -0.002478179    CPI
#> 3: 6.961317e-180  1.752402643    CPI
#> 4:  4.754949e-01 -0.001103973    CPI
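
The columns above can be related to observation-wise loss differences \(\Delta_i = L(y_i, \hat f(\tilde x_i)) - L(y_i, \hat f(x_i))\), where \(\tilde x_i\) replaces feature \(j\) by its knockoff: CPI is the mean of \(\Delta_i\), SE is \(\mathrm{sd}(\Delta)/\sqrt{n}\), statistic is CPI / SE, and ci.lo is the lower bound of a one-sided 95% t-interval (for x1, 4.5548 - 1.6456 × 0.1421 ≈ 4.321). The sketch below uses simulated \(\Delta_i\) as placeholders; it illustrates the test, not how cpi computes the losses.

```r
# Sketch of the CPI one-sided t-test on observation-wise loss differences.
# 'delta' is simulated here; in practice it is loss(knockoff) - loss(original).
cpi_t_test <- function(delta, level = 0.95) {
  tt <- t.test(delta, alternative = "greater", conf.level = level)
  data.frame(
    CPI = mean(delta),
    SE = sd(delta) / sqrt(length(delta)),
    statistic = unname(tt$statistic),
    p.value = tt$p.value,
    ci.lo = tt$conf.int[1]  # the upper bound is Inf for a one-sided test
  )
}

set.seed(1)
delta <- rnorm(2000, mean = 0.05, sd = 1)
res_cpi <- cpi_t_test(delta)
res_cpi
```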

CPI with knockoffs

xplainfi also provides knockoffs via KnockoffSampler and KnockoffGaussianSampler. The latter implements the second-order Gaussian knockoffs that cpi also uses by default, so we can recreate cpi's results using CFI with this sampler.
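
For intuition about what a second-order Gaussian knockoff sampler does, here is a minimal base-R sketch of the equicorrelated construction from the model-X knockoff literature. It treats the standardized features as multivariate normal with correlation matrix \(\Sigma\) estimated from the data, sets \(D = sI\) with \(s\) slightly below \(\min(1, 2\lambda_{\min}(\Sigma))\), and samples \(\tilde X \mid X \sim N(X - X\Sigma^{-1}D,\ 2D - D\Sigma^{-1}D)\). This is an illustration under those assumptions only: KnockoffGaussianSampler and cpi may choose \(D\) differently (e.g. via semidefinite programming) and handle scaling differently.

```r
# Sketch: equicorrelated second-order Gaussian knockoffs (illustrative only;
# not the implementation used by KnockoffGaussianSampler or cpi)
gaussian_knockoffs <- function(X, shrink = 0.99) {
  X <- scale(X)                      # work with standardized features
  Sigma <- cor(X)                    # estimated second moments
  p <- ncol(X)
  lambda_min <- min(eigen(Sigma, symmetric = TRUE, only.values = TRUE)$values)
  # Equicorrelated choice D = s * I; shrinking keeps 2 * Sigma - D positive definite
  s <- shrink * min(1, 2 * lambda_min)
  D <- diag(s, p)
  Sigma_inv <- solve(Sigma)
  # Conditional distribution of the knockoffs given X (one row per observation)
  mu <- X - X %*% Sigma_inv %*% D
  V <- 2 * D - D %*% Sigma_inv %*% D
  V <- (V + t(V)) / 2                # symmetrize against rounding error
  R <- chol(V)                       # upper triangular, t(R) %*% R equals V
  Z <- matrix(rnorm(nrow(X) * p), nrow(X), p)
  X_ko <- mu + Z %*% R
  colnames(X_ko) <- paste0(colnames(X), "_ko")
  X_ko
}

set.seed(1)
n <- 2000
x1 <- rnorm(n)
X <- cbind(x1 = x1, x2 = 0.7 * x1 + sqrt(1 - 0.7^2) * rnorm(n),
           x3 = rnorm(n), x4 = rnorm(n))
X_ko <- gaussian_knockoffs(X)
round(cor(scale(X), X_ko), 2)
```

If the construction works, each knockoff reproduces the correlations with the other original features (about 0.7 between x1 and x2_ko), while the correlation between a feature and its own knockoff is reduced to roughly \(1 - s\).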

CFI with a knockoff sampler supports CPI inference directly via ci_method = "cpi":

knockoff_gaussian = KnockoffGaussianSampler$new(task)

cfi = CFI$new(
    task = task,
    learner = learner,
    resampling = resampling,
    measure = measure,
    sampler = knockoff_gaussian
)

cfi$compute()

# CPI uses observation-wise losses with one-sided t-test by default
cfi_cpi_res = cfi$importance(ci_method = "cpi")
cfi_cpi_res
#> Key: <feature>
#>    feature    importance           se  statistic       p.value   conf_lower
#>     <char>         <num>        <num>      <num>         <num>        <num>
#> 1:      x1  4.6427559476 0.1478836872 31.3946456 1.587955e-176  4.399396149
#> 2:      x2  0.0026423506 0.0033612948  0.7861109  2.159478e-01 -0.002889051
#> 3:      x3  1.7620712330 0.0582790374 30.2350779 5.393726e-166  1.666166302
#> 4:      x4 -0.0002589976 0.0007047195 -0.3675188  6.433645e-01 -0.001418696
#>    conf_upper
#>         <num>
#> 1:        Inf
#> 2:        Inf
#> 3:        Inf
#> 4:        Inf

# Rename columns to match cpi package output for comparison
setnames(cfi_cpi_res, c("importance", "conf_lower"), c("CPI", "ci.lo"))
cfi_cpi_res[, method := "CFI+Knockoffs"]

The results should be very similar to those computed by cpi(), so let’s compare them:

rbindlist(list(cpi_res, cfi_cpi_res), fill = TRUE) |>
    ggplot(aes(y = feature, x = CPI, color = method)) +
    geom_point(position = position_dodge(width = 0.3)) +
    geom_errorbar(
        aes(xmin = ci.lo, xmax = CPI),
        position = position_dodge(width = 0.3),
        width = 0.5
    ) +
    scale_color_brewer(palette = "Dark2") +
    labs(
        title = "CPI and CFI with Knockoff sampler",
        subtitle = "RF with 5-fold CV",
        color = NULL
    ) +
    theme_minimal(base_size = 14) +
    theme(legend.position = "top")

A notable caveat of the knockoff approach is that knockoffs are not readily available for mixed data (i.e., data with categorical features).

CPI with ARF

An alternative is to use adversarial random forests (ARF) as the conditional sampler instead of knockoffs (CITE cARFi), which works analogously:

arf_sampler = ConditionalARFSampler$new(
    task = task,
    finite_bounds = "local",
    min_node_size = 20,
    epsilon = 1e-15
)

cfi_arf = CFI$new(
    task = task,
    learner = learner,
    resampling = resampling,
    measure = measure,
    sampler = arf_sampler
)

cfi_arf$compute()

# CPI uses observation-wise losses with one-sided t-test
cfi_arf_res = cfi_arf$importance(ci_method = "cpi")
cfi_arf_res
#> Key: <feature>
#>    feature   importance           se statistic       p.value    conf_lower
#>     <char>        <num>        <num>     <num>         <num>         <num>
#> 1:      x1  4.030582334 0.1301510177 30.968504 1.222410e-172  3.8164037048
#> 2:      x2  0.004954503 0.0032955819  1.503377  6.644987e-02 -0.0004687601
#> 3:      x3  1.709224166 0.0535822827 31.899055 3.781474e-181  1.6210482908
#> 4:      x4 -0.001110782 0.0007197487 -1.543292  9.385409e-01 -0.0022952124
#>    conf_upper
#>         <num>
#> 1:        Inf
#> 2:        Inf
#> 3:        Inf
#> 4:        Inf

# Rename columns to match cpi package output for comparison
setnames(cfi_arf_res, c("importance", "conf_lower"), c("CPI", "ci.lo"))
cfi_arf_res[, method := "CFI+ARF"]

We can now compare all three methods:

rbindlist(list(cpi_res, cfi_cpi_res, cfi_arf_res), fill = TRUE) |>
    ggplot(aes(y = feature, x = CPI, color = method)) +
    geom_point(position = position_dodge(width = 0.3)) +
    geom_errorbar(
        aes(xmin = ci.lo, xmax = CPI),
        position = position_dodge(width = 0.3),
        width = 0.5
    ) +
    scale_color_brewer(palette = "Dark2") +
    labs(
        title = "CPI and CFI with Knockoffs and ARF",
        subtitle = "RF with 5-fold CV",
        color = NULL
    ) +
    theme_minimal(base_size = 14) +
    theme(legend.position = "top")

As expected, the ARF-based results differ more from those of both knockoff-based approaches, but all three are roughly in agreement.

Statistical tests with CPI

Besides the default t-test, CPI also supports the Wilcoxon signed-rank test, the Fisher permutation test, and the binomial test:

(cpi_res_wilcoxon = cfi_arf$importance(ci_method = "cpi", test = "wilcoxon"))
#> Key: <feature>
#>    feature   importance           se statistic       p.value    conf_lower
#>     <char>        <num>        <num>     <num>         <num>         <num>
#> 1:      x1  4.030582334 0.1301510177   1967386 3.716827e-309  2.5659825422
#> 2:      x2  0.004954503 0.0032955819    992983  5.390444e-01 -0.0008211084
#> 3:      x3  1.709224166 0.0535822827   1929081 1.697037e-288  1.1465162882
#> 4:      x4 -0.001110782 0.0007197487    964259  8.802801e-01 -0.0005923333
#>    conf_upper
#>         <num>
#> 1:        Inf
#> 2:        Inf
#> 3:        Inf
#> 4:        Inf
# Fisher test with same default for B as in cpi()
(cpi_res_fisher = cfi_arf$importance(ci_method = "cpi", test = "fisher", B = 1999))
#> Key: <feature>
#>    feature   importance           se p.value    conf_lower conf_upper
#>     <char>        <num>        <num>   <num>         <num>      <num>
#> 1:      x1  4.030582334 0.1301510177  0.0005  3.7836382934        Inf
#> 2:      x2  0.004954503 0.0032955819  0.0750 -0.0006203853        Inf
#> 3:      x3  1.709224166 0.0535822827  0.0005  1.5992602629        Inf
#> 4:      x4 -0.001110782 0.0007197487  0.9320 -0.0023060964        Inf
(cpi_res_binom = cfi_arf$importance(ci_method = "cpi", test = "binomial"))
#> Key: <feature>
#>    feature   importance           se statistic       p.value conf_lower
#>     <char>        <num>        <num>     <num>         <num>      <num>
#> 1:      x1  4.030582334 0.1301510177      1861  0.000000e+00  0.9204192
#> 2:      x2  0.004954503 0.0032955819       991  6.645228e-01  0.4768744
#> 3:      x3  1.709224166 0.0535822827      1781 3.333550e-304  0.8783373
#> 4:      x4 -0.001110782 0.0007197487       962  9.574574e-01  0.4624027
#>    conf_upper
#>         <num>
#> 1:        Inf
#> 2:        Inf
#> 3:        Inf
#> 4:        Inf

rbindlist(
    list(
        cfi_arf$importance(ci_method = "cpi")[, test := "t"],
        cpi_res_wilcoxon[, test := "Wilcoxon"],
        cpi_res_fisher[, test := "Fisher"],
        cpi_res_binom[, test := "Binomial"]
    ),
    fill = TRUE
) |>
    ggplot(aes(y = feature, x = importance, color = test)) +
    geom_point(position = position_dodge(width = 0.3)) +
    geom_errorbar(
        aes(xmin = conf_lower, xmax = importance),
        position = position_dodge(width = 0.3),
        width = 0.5
    ) +
    scale_color_brewer(palette = "Dark2") +
    labs(
        title = "CPI test with CFI/ARF",
        subtitle = "RF with 5-fold CV",
        color = "Test"
    ) +
    theme_minimal(base_size = 14) +
    theme(legend.position = "top")

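Given the widths of the resulting confidence intervals, the Fisher test or the t-test is generally recommended. Note that for the binomial test, conf_lower refers to the proportion of observations with a positive loss difference (e.g., 1861 of 2000 for x1), not to the importance scale, so its interval is not directly comparable to the others in the plot.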
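
The Fisher and binomial variants can be sketched in base R on the same kind of observation-wise loss differences. The Fisher test below is a sign-flip permutation test of the mean difference with p-value \((1 + \#\{\text{permuted} \geq \text{observed}\})/(B + 1)\); with B = 1999 its smallest attainable value is 1/2000 = 0.0005, matching the p-values for x1 and x3 above. The binomial test counts positive differences. The simulated delta and the exact handling of ties are assumptions; cpi and xplainfi may differ in such details.

```r
# Sketches of two distribution-free CPI tests on loss differences 'delta'
# (simulated placeholders here)
fisher_sign_flip_test <- function(delta, B = 1999) {
  observed <- mean(delta)
  # Under H0 each difference is symmetric around 0, so randomly flipping
  # signs yields draws from the null distribution of the mean
  permuted <- replicate(
    B,
    mean(delta * sample(c(-1, 1), length(delta), replace = TRUE))
  )
  (1 + sum(permuted >= observed)) / (B + 1)  # one-sided p-value
}

binomial_sign_test <- function(delta) {
  # How often does replacing feature j increase the loss?
  binom.test(sum(delta > 0), length(delta), p = 0.5, alternative = "greater")
}

set.seed(1)
delta <- rnorm(2000, mean = 0.1, sd = 1)
p_fisher <- fisher_sign_flip_test(delta)
p_binom <- binomial_sign_test(delta)$p.value
c(fisher = p_fisher, binomial = p_binom)
```

Because the permutation p-value is bounded below by \(1/(B + 1)\), a larger B is needed to resolve very small p-values.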

LOCO (WIP)

(CITATION) proposed inference for LOCO based on the median of the observation-wise differences between the absolute errors of the refitted model without feature \(j\) and of the full model:

\[ \theta_j = \mathrm{med}\left( |Y - \hat{f}_{n_1}^{-j}(X)| - |Y - \hat{f}_{n_1}(X)| \big| D_1 \right) \]

If we apply LOCO as implemented in xplainfi with the absolute error measure (regr.mae) and set its aggregation function to the median, we get something different:

measure_mae = msr("regr.mae")
measure_mae$aggregator = median

loco = LOCO$new(
    task = task,
    learner = learner,
    resampling = rsmp("holdout"),
    measure = measure_mae
)

loco$compute()
loco$importance()
#> Key: <feature>
#>    feature importance
#>     <char>      <num>
#> 1:      x1 0.98148070
#> 2:      x2 0.05038684
#> 3:      x3 0.62258117
#> 4:      x4 0.05334160

This is not exactly what the authors propose, because $score() applies the aggregation function (median) within each resampling iteration first and takes the difference afterwards, i.e.

\[ \theta_j = \mathrm{med}\left(|Y - \hat{f}_{n_1}^{-j}(X)| \big| D_1 \right) - \mathrm{med}\left(|Y - \hat{f}_{n_1}(X)| \big| D_1 \right) \]

In the default case, where the arithmetic mean is used, it does not matter whether we calculate the difference of the means or the mean of the differences, but for the median it does.
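
A three-observation toy example with made-up losses shows the difference:

```r
# Difference of medians vs. median of differences on made-up losses
loss_post <- c(1, 2, 10)
loss_baseline <- c(3, 1, 2)

median(loss_post - loss_baseline)          # median of differences: 1
median(loss_post) - median(loss_baseline)  # difference of medians: 0
mean(loss_post - loss_baseline)            # mean of differences: 7/3
mean(loss_post) - mean(loss_baseline)      # difference of means: also 7/3
```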

We can, however, reconstruct it by using the observation-wise losses (in this case, the absolute error):

loco_obsloss = loco$obs_loss()
head(loco_obsloss)
#>    feature iter_rsmp iter_repeat row_ids loss_baseline loss_post obs_importance
#>     <char>     <int>       <num>   <int>         <num>     <num>          <num>
#> 1:      x1         1           1       4    0.11147320 1.0227631     0.91128990
#> 2:      x1         1           1       7    0.30561825 0.8431147     0.53749646
#> 3:      x1         1           1      15    0.23994174 2.5361659     2.29622412
#> 4:      x1         1           1      20    0.21063594 0.1680135    -0.04262239
#> 5:      x1         1           1      22    0.07570713 2.9775583     2.90185118
#> 6:      x1         1           1      23    0.02199827 2.5518444     2.52984614

obs_importance here refers to the difference loss_post - loss_baseline, so

  • loss_baseline = \(|Y - \hat{f}_{n_1}(X)|\)
  • loss_post = \(|Y - \hat{f}_{n_1}^{-j}(X)|\)
  • obs_importance = loss_post - loss_baseline

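Hence, by taking the median of obs_importance for each feature \(j\) within a resampling iteration \(k\), we can construct \(\theta_j(D_k)\) as proposed. Since we use a single holdout split here, grouping by feature alone suffices; with multiple resampling iterations, we would also group by iter_rsmp: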

loco_thetas = loco_obsloss[, list(theta = median(obs_importance)), by = c("feature")]
loco_thetas
#>    feature      theta
#>     <char>      <num>
#> 1:      x1 0.83911185
#> 2:      x2 0.02684860
#> 3:      x3 0.51935341
#> 4:      x4 0.03344017

The authors then propose constructing distribution-free confidence intervals, e.g., using a sign test or a Wilcoxon test. For example, we can use wilcox.test() to compute confidence intervals around the estimated pseudo-median:

loco_wilcox_ci = loco_obsloss[,
    {
        tt <- wilcox.test(
            obs_importance,
            conf.int = TRUE,
            conf.level = 0.95
        )
        .(
            statistic = tt$statistic,
            estimate = tt$estimate, # the pseudomedian importance
            p.value = tt$p.value,
            conf_lower = tt$conf.int[1],
            conf_upper = tt$conf.int[2]
        )
    },
    by = feature
]

loco_wilcox_ci
#>    feature statistic   estimate      p.value conf_lower conf_upper
#>     <char>     <num>      <num>        <num>      <num>      <num>
#> 1:      x1    216917 0.91943944 1.012455e-99 0.85057682 0.98969443
#> 2:      x2    144722 0.03419386 2.149646e-11 0.02424473 0.04449679
#> 3:      x3    213532 0.59231575 1.508448e-93 0.54444761 0.64071906
#> 4:      x4    162767 0.04423316 5.712254e-25 0.03565144 0.05305982
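
For the sign-test alternative, a distribution-free confidence interval for the median of obs_importance can be built from order statistics: with \(k\) chosen from the Binomial(\(n\), 0.5) distribution, \([x_{(k)}, x_{(n-k+1)}]\) covers the median with probability at least the nominal level for continuous data. The helper below is a hypothetical base-R sketch, not part of xplainfi; it could be applied per feature, e.g. inside a loco_obsloss[, ..., by = feature] call as above. The input here is simulated.

```r
# Hypothetical helper: sign test and order-statistic CI for the median
sign_test_median_ci <- function(x, conf.level = 0.95) {
  x <- sort(x)
  n <- length(x)
  alpha <- 1 - conf.level
  # k is the smallest value with P(B <= k) >= alpha/2, so P(B <= k - 1) < alpha/2
  k <- qbinom(alpha / 2, n, 0.5)
  stopifnot(k >= 1)  # too few observations for this confidence level otherwise
  nonzero <- x[x != 0]  # the sign test discards exact zeros
  test <- binom.test(sum(nonzero > 0), length(nonzero), p = 0.5)
  list(
    estimate = median(x),
    p.value = test$p.value,
    conf_lower = x[k],
    conf_upper = x[n - k + 1]
  )
}

set.seed(1)
x <- rnorm(500, mean = 0.3)
res_sign <- sign_test_median_ci(x)
str(res_sign)
```

Compared with the Wilcoxon interval, which targets the pseudo-median and assumes a symmetric distribution of the differences, the sign-test interval targets the median itself at the cost of being somewhat wider.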

Note: The Wilcoxon-based approach above still needs to be checked against the literature to ensure that it corresponds to what was proposed and that the results are valid.

The main point of this section is to illustrate that access to intermediate results (i.e., observation-wise losses) and the freedom to choose the measure enable flexible inference.