Implements conditional sampling assuming features follow a multivariate Gaussian distribution. Computes conditional distributions analytically using standard formulas for multivariate normal distributions.
Details
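For a joint Gaussian distribution \(X \sim N(\mu, \Sigma)\), partitioned as \(X = (X_A, X_B)\) with \(\mu = (\mu_A, \mu_B)\) and \(\Sigma\) split accordingly into the blocks \(\Sigma_{AA}, \Sigma_{AB}, \Sigma_{BA}, \Sigma_{BB}\), the conditional distribution of \(X_B\) given \(X_A = x_A\) is: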
$$X_B | X_A = x_A \sim N(\mu_{B|A}, \Sigma_{B|A})$$
where: $$\mu_{B|A} = \mu_B + \Sigma_{BA} \Sigma_{AA}^{-1} (x_A - \mu_A)$$ $$\Sigma_{B|A} = \Sigma_{BB} - \Sigma_{BA} \Sigma_{AA}^{-1} \Sigma_{AB}$$
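A minimal base-R sketch of these two formulas, assuming \(\mu\) and \(\Sigma\) are already known and that A and B are given as index (or name) vectors. The helper `conditional_gaussian()` is hypothetical and not part of xplainfi; it calls `solve()` on \(\Sigma_{AA}\) instead of forming \(\Sigma_{AA}^{-1}\) explicitly, which is the usual numerically safer choice.

```r
# Hypothetical helper (not part of xplainfi): parameters of X_B | X_A = x_A
conditional_gaussian = function(mu, Sigma, A, B, x_A) {
  S_AA = Sigma[A, A, drop = FALSE]
  S_AB = Sigma[A, B, drop = FALSE]
  S_BA = Sigma[B, A, drop = FALSE]
  S_BB = Sigma[B, B, drop = FALSE]
  # mu_{B|A} = mu_B + Sigma_BA Sigma_AA^{-1} (x_A - mu_A), via a linear solve
  mu_cond = mu[B] + S_BA %*% solve(S_AA, x_A - mu[A])
  # Sigma_{B|A} = Sigma_BB - Sigma_BA Sigma_AA^{-1} Sigma_AB
  Sigma_cond = S_BB - S_BA %*% solve(S_AA, S_AB)
  list(mean = drop(mu_cond), cov = Sigma_cond)
}

# Bivariate example: unit variances, correlation 0.8, condition X2 on X1 = 1
mu = c(0, 0)
Sigma = matrix(c(1, 0.8, 0.8, 1), nrow = 2)
res = conditional_gaussian(mu, Sigma, A = 1, B = 2, x_A = 1)
res$mean  # 0.8
res$cov   # 0.36, i.e. 1 - 0.8^2
```

In the bivariate case the formulas reduce to the familiar \(\rho x_A\) mean and \(1 - \rho^2\) variance for standardized variables.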
This is equivalent to the regression formulation used by fippy: $$\beta = \Sigma_{BA} \Sigma_{AA}^{-1}$$ $$\mu_{B|A} = \mu_B + \beta (x_A - \mu_A)$$ $$\Sigma_{B|A} = \Sigma_{BB} - \beta \Sigma_{AB}$$
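To illustrate why the two formulations agree, the base-R sketch below uses simulated data with arbitrary, illustrative coefficients. It computes \(\beta = \Sigma_{BA} \Sigma_{AA}^{-1}\) from the sample covariance and compares it with the slopes of an ordinary least-squares regression of \(X_B\) on \(X_A\) (with intercept); it also checks that \(\Sigma_{BB} - \beta \Sigma_{AB}\) equals the residual variance. This demonstrates the algebra only and is not a claim about how xplainfi or fippy compute \(\beta\) internally.

```r
set.seed(1)
n = 500
# Simulated data: conditioning features x1, x2 (A) and target feature x3 (B)
x1 = rnorm(n)
x2 = 0.5 * x1 + rnorm(n)
x3 = 1 + 2 * x1 - x2 + rnorm(n)
X = cbind(x1 = x1, x2 = x2, x3 = x3)
A = c("x1", "x2")
B = "x3"

S = cov(X)
# beta = Sigma_BA Sigma_AA^{-1}, computed as t(solve(Sigma_AA, Sigma_AB))
beta = t(solve(S[A, A], S[A, B, drop = FALSE]))
# Sigma_{B|A} = Sigma_BB - beta Sigma_AB
Sigma_cond = S[B, B] - beta %*% S[A, B]

# OLS fit of x3 on (x1, x2); drop the intercept to keep only the slopes
fit = lm(x3 ~ x1 + x2, data = as.data.frame(X))
ols = coef(fit)[-1]

print(beta)                   # same values as the OLS slopes
print(ols)
print(var(residuals(fit)))    # same value as Sigma_cond
```

The agreement is exact (up to floating-point error) because the OLS slopes equal \(S_{AA}^{-1} S_{AB}\) for the sample covariance \(S\); the \(n - 1\) denominators cancel.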
Assumptions:
Features are approximately multivariate normal
Only continuous features are supported
Advantages:
Very fast (closed-form solution)
Reproducible (deterministic given a seed)
No hyperparameters
Memory efficient (the fitted model consists of a mean vector and a covariance matrix)
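Because the conditional distribution is available in closed form, drawing samples requires only a matrix factorization and standard normal draws. The base-R sketch below assumes sampling works via a Cholesky factor; xplainfi's actual internals may differ (for example, in the factorization or RNG calls used), and `rmvnorm_chol()` is a hypothetical helper. It shows that fixing the seed with `set.seed()` reproduces the same draws.

```r
# Hypothetical helper: n draws from N(mean, cov) via a Cholesky factor
rmvnorm_chol = function(n, mean, cov) {
  U = chol(cov)  # upper-triangular U with t(U) %*% U == cov
  Z = matrix(rnorm(n * length(mean)), nrow = n)  # i.i.d. N(0, 1) entries
  # Each row of Z %*% U has covariance t(U) %*% U = cov; then shift by mean
  sweep(Z %*% U, 2, mean, "+")
}

# Illustrative conditional mean and covariance
mean_cond = c(0.8, -0.2)
cov_cond = matrix(c(0.36, 0.1, 0.1, 0.5), nrow = 2)

set.seed(42)
draws1 = rmvnorm_chol(5, mean_cond, cov_cond)
set.seed(42)
draws2 = rmvnorm_chol(5, mean_cond, cov_cond)
identical(draws1, draws2)  # TRUE: same seed, same draws
```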
Limitations:
Strong distributional assumption (joint normality of all features)
May produce out-of-range values for bounded features
Cannot handle categorical features
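The out-of-range limitation is easy to see with features bounded on [0, 1]. The base-R sketch below uses illustrative, independent uniform data, fits a mean vector and covariance matrix, and samples from the resulting conditional Gaussian for x2 given x1 = 0.5 (scalar form of the formulas above). A noticeable share of the draws falls outside [0, 1] because the normal distribution has unbounded support.

```r
set.seed(123)
n = 1000
# Bounded features: independent uniforms on [0, 1]
X = cbind(x1 = runif(n), x2 = runif(n))
mu = colMeans(X)
Sigma = cov(X)

# Conditional Gaussian for x2 | x1 = 0.5 (scalar version of the formulas)
x_A = 0.5
m = mu["x2"] + Sigma["x2", "x1"] / Sigma["x1", "x1"] * (x_A - mu["x1"])
s2 = Sigma["x2", "x2"] - Sigma["x2", "x1"]^2 / Sigma["x1", "x1"]
draws = rnorm(10000, mean = m, sd = sqrt(s2))

# Share of sampled values outside the feature's valid range [0, 1]
mean(draws < 0 | draws > 1)
```

Clipping or transforming bounded features before fitting are common workarounds, but each changes the model and is outside what this sampler describes.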
References
Anderson T (2003). An Introduction to Multivariate Statistical Analysis, 3rd edition. Wiley-Interscience, Hoboken, NJ. ISBN 9780471360919.
Super classes
xplainfi::FeatureSampler -> xplainfi::ConditionalSampler -> ConditionalGaussianSampler
Public fields
feature_types (character()) Feature types supported by the sampler.
mu (numeric()) Mean vector estimated from training data.
sigma (matrix()) Covariance matrix estimated from training data.
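The fields map onto the quantities in the formulas above. The base-R sketch below assumes the estimates are the sample mean (`colMeans()`) and the sample covariance (`cov()`, denominator n - 1) of the training features, and that they carry feature names; the exact estimator and naming used by xplainfi are assumptions here. Named objects let the blocks \(\Sigma_{AA}\), \(\Sigma_{BA}\), and so on be extracted by feature name.

```r
set.seed(7)
# Illustrative training data with three numeric features
train = data.frame(a = rnorm(200), b = rnorm(200), c = rnorm(200))
train$c = train$c + train$a  # make c depend on a

mu = colMeans(train)  # named numeric vector, analogous to the mu field (assumed)
sigma = cov(train)    # named covariance matrix, analogous to sigma (assumed)

# Partition by feature name: A = conditioning set, B = features to sample
A = c("a", "b")
B = "c"
sigma_AA = sigma[A, A, drop = FALSE]
sigma_BA = sigma[B, A, drop = FALSE]
mu_A = mu[A]
dim(sigma_AA)  # 2 2
dim(sigma_BA)  # 1 2
```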
Methods
Method new()
Creates a new ConditionalGaussianSampler.
Usage
ConditionalGaussianSampler$new(task, conditioning_set = NULL)
Arguments
task (mlr3::Task) Task to sample from. Must contain only numeric/integer features.
conditioning_set (character | NULL) Default conditioning set to use in $sample().
Examples
library(mlr3)
library(xplainfi)
task = tgen("friedman1")$generate(n = 100)
sampler = ConditionalGaussianSampler$new(task)
# Sample important2 and important3 conditioned on important1
test_data = task$data(rows = 1:5)
sampled = sampler$sample_newdata(
feature = c("important2", "important3"),
newdata = test_data,
conditioning_set = "important1"
)