causal_falsify.utils.mint module
- causal_falsify.utils.mint.bootstrap_model_fitting_jax(Y, T, tf_X, tf_XT, outcome_model_fn, treatment_model_fn, key)[source]
Fit outcome and treatment models on a bootstrap resample of the data.
This function performs bootstrap resampling (with replacement) using JAX random primitives and then fits the provided model functions on the resampled data.
- Parameters:
Y (jnp.ndarray) – Outcome array of shape (n_samples,).
T (jnp.ndarray) – Treatment array of shape (n_samples,).
tf_X (jnp.ndarray) – Transformed covariate matrix for the treatment model of shape (n_samples, n_features).
tf_XT (jnp.ndarray) – Transformed covariate matrix for the outcome model of shape (n_samples, n_features_outcome).
outcome_model_fn (callable) – Callable that fits an outcome model. Signature should be fn(X, Y) and return (params, mse) where params is an array of fitted coefficients.
treatment_model_fn (callable) – Callable that fits a treatment model. Same interface as outcome_model_fn.
key (jax.random.PRNGKey) – JAX PRNGKey used for resampling.
- Returns:
resampled_params_outcome, resampled_params_treatment – The fitted parameters for outcome and treatment models on the resampled data.
Notes
- outcome_model_fn and treatment_model_fn must be plain Python callables (functools.partial wrappers are fine), not JAX tracers or arrays. The function is JIT-compiled, and both callables are marked static via static_argnums so that JAX does not try to trace them.
- Resampling is implemented with JAX operations, so it runs inside the compiled function.
- Return type:
tuple
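The resample-then-fit pattern with a static callable can be sketched as follows. This is a minimal illustration, not the library's implementation: it fits a single model rather than separate outcome and treatment models, and the names least_squares_fit and bootstrap_fit are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


def least_squares_fit(X, Y):
    """Hypothetical model function following the documented (params, mse) contract."""
    params = jnp.linalg.solve(X.T @ X, X.T @ Y)
    mse = jnp.mean((X @ params - Y) ** 2)
    return params, mse


# Argument 2 (the callable) is static: JAX hashes it instead of tracing it.
@partial(jax.jit, static_argnums=(2,))
def bootstrap_fit(X, Y, model_fn, key):
    n = X.shape[0]
    # Draw n row indices with replacement using the supplied PRNG key.
    idx = jax.random.randint(key, (n,), 0, n)
    params, _ = model_fn(X[idx], Y[idx])
    return params


key = jax.random.PRNGKey(0)
x = jnp.linspace(-1.0, 1.0, 50)
X = jnp.stack([jnp.ones_like(x), x], axis=1)  # intercept column plus one feature
Y = 1.0 + 2.0 * x
params = bootstrap_fit(X, Y, least_squares_fit, key)
```

Passing a different function object for model_fn triggers a recompilation, which is the usual cost of marking an argument static.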
- causal_falsify.utils.mint.bootstrapped_permutation_independence_test(data_x, data_y, resampled_data_x, resampled_data_y, random_state=None)[source]
Performs a bootstrapped permutation independence test between two datasets.
This function computes the observed off-diagonal block Frobenius norm between data_x and data_y, then compares it to the distribution of norms obtained by permuting the resampled versions of data_x and data_y. The returned value is the proportion of times the observed statistic is less than the bootstrapped statistics, which can be interpreted as a p-value for the independence test.
- Parameters:
data_x (np.ndarray) – The original data array for variable X, of shape (n_samples, n_features_x).
data_y (np.ndarray) – The original data array for variable Y, of shape (n_samples, n_features_y).
resampled_data_x (np.ndarray) – Bootstrapped samples of data_x, of shape (n_bootstraps, n_samples, n_features_x).
resampled_data_y (np.ndarray) – Bootstrapped samples of data_y, of shape (n_bootstraps, n_samples, n_features_y).
random_state (np.random.RandomState or None, optional) – Random state for reproducibility. If None, a new RandomState is created.
- Returns:
The proportion of bootstrapped statistics greater than the observed statistic, representing the p-value for the independence test.
- Return type:
float
- Raises:
AssertionError – If the number of bootstraps in resampled_data_x and resampled_data_y do not match.
Notes
This function relies on compute_offdiag_block_frobnorm to compute the test statistic.
- causal_falsify.utils.mint.compute_offdiag_block_frobnorm(data_x, data_y)[source]
Compute the Frobenius norm of the off-diagonal block of the covariance matrix between two datasets.
Given two datasets with the same number of samples, this function concatenates them, computes the covariance matrix, extracts the off-diagonal block corresponding to the covariances between the two datasets, and returns its Frobenius norm.
- Parameters:
data_x (np.ndarray) – A 2D array of shape (n_samples, n_features_x) representing the first dataset. Can also be 1D with shape (n_samples,), which will be reshaped to (n_samples, 1).
data_y (np.ndarray) – A 2D array of shape (n_samples, n_features_y) representing the second dataset. Can also be 1D with shape (n_samples,), which will be reshaped to (n_samples, 1).
- Returns:
The Frobenius norm of the off-diagonal block of the covariance matrix between data_x and data_y.
- Return type:
float
- Raises:
AssertionError – If the number of samples (first dimension) in data_x and data_y do not match.
ValueError – If the input matrices are not valid as determined by validate_matrix.
Notes
The off-diagonal block refers to the submatrix of the covariance matrix that captures the covariances between the features of data_x and data_y.
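The statistic described above can be sketched in NumPy as follows. The function name offdiag_block_frobnorm is hypothetical, and the sketch assumes the sample covariance (np.cov defaults) and omits the validate_matrix checks.

```python
import numpy as np


def offdiag_block_frobnorm(data_x, data_y):
    """Frobenius norm of the cross-covariance block between data_x and data_y."""
    x = np.asarray(data_x, dtype=float)
    y = np.asarray(data_y, dtype=float)
    # Promote 1D inputs to single-column matrices.
    if x.ndim == 1:
        x = x.reshape(-1, 1)
    if y.ndim == 1:
        y = y.reshape(-1, 1)
    assert x.shape[0] == y.shape[0], "inputs must have the same number of samples"
    # Covariance of the concatenated features; columns are variables.
    cov = np.cov(np.hstack([x, y]), rowvar=False)
    # Rows index features of x, columns index features of y.
    block = cov[: x.shape[1], x.shape[1]:]
    return float(np.linalg.norm(block, "fro"))


x = np.array([0.0, 1.0, 2.0, 3.0])
stat_dependent = offdiag_block_frobnorm(x, 2.0 * x)
stat_constant = offdiag_block_frobnorm(x, np.ones(4))
```

A value near zero indicates no linear association between the two feature sets; the statistic does not detect purely nonlinear dependence.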
- causal_falsify.utils.mint.create_polynomial_representation(X, degree, use_sklearn=False, interaction_only=False)[source]
Generate a polynomial feature representation of the input data.
- Parameters:
X (array-like of shape (n_samples, n_features)) – The input data to be transformed into polynomial features.
degree (int) – The degree of the polynomial features to be generated. Must be greater than 1.
use_sklearn (bool, optional, default=False) – If True, use sklearn’s PolynomialFeatures for transformation. If False, generate polynomial features manually (only powers of individual features, no cross-terms).
interaction_only (bool, optional, default=False) – If True and use_sklearn is True, only interaction features are produced: features that are products of at most degree distinct input features (no powers of single features). Has no effect if use_sklearn is False.
- Returns:
X_poly – The matrix of polynomial features.
- Return type:
ndarray of shape (n_samples, n_output_features)
- Raises:
ValueError – If degree is less than or equal to 1.
Notes
When use_sklearn is False, only powers of individual features are generated (no interaction/cross terms).
When use_sklearn is True, both interaction and power terms are generated according to the parameters.
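The manual (use_sklearn=False) behaviour can be illustrated with a short NumPy sketch. It assumes the output stacks powers 1 through degree of each feature and adds no intercept column; the documentation does not state the exact column layout, and the name power_features is hypothetical.

```python
import numpy as np


def power_features(X, degree):
    """Stack element-wise powers 1..degree of each feature, without cross-terms."""
    if degree <= 1:
        raise ValueError("degree must be greater than 1")
    X = np.asarray(X, dtype=float)
    # Each block X ** p has the same shape as X; blocks are placed side by side.
    return np.hstack([X ** p for p in range(1, degree + 1)])


X = np.array([[1.0, 2.0], [3.0, 4.0]])
X_poly = power_features(X, degree=3)  # shape (2, 6): two features times three powers
```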
- causal_falsify.utils.mint.cross_val_mse(X, Y, model_fn, num_folds)[source]
Perform k-fold cross-validation and compute the mean squared error (MSE).
- Parameters:
X (array-like of shape (n_samples, n_features)) – Transformed feature matrix, including intercept term if desired.
Y (array-like of shape (n_samples,) or (n_samples, n_targets)) – Target variable.
model_fn (callable) – Function to fit the model. Should take (X_train, Y_train) and return fitted parameters.
num_folds (int) – Number of folds for cross-validation.
- Returns:
Mean squared error averaged across all folds.
- Return type:
float
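A k-fold loop of this kind can be sketched in NumPy as follows. The sketch assumes contiguous, unshuffled folds and that predictions are computed as X @ params; neither detail is confirmed by the documentation, and the names kfold_mse and least_squares are hypothetical.

```python
import numpy as np


def kfold_mse(X, Y, model_fn, num_folds):
    """Average held-out mean squared error over num_folds folds."""
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    n = X.shape[0]
    fold_errors = []
    for test_idx in np.array_split(np.arange(n), num_folds):
        # Train on every row that is not in the current held-out fold.
        train_mask = np.ones(n, dtype=bool)
        train_mask[test_idx] = False
        params = model_fn(X[train_mask], Y[train_mask])
        residuals = Y[test_idx] - X[test_idx] @ params
        fold_errors.append(np.mean(residuals ** 2))
    return float(np.mean(fold_errors))


def least_squares(X, Y):
    """Hypothetical model_fn: ordinary least squares coefficients."""
    return np.linalg.lstsq(X, Y, rcond=None)[0]


t = np.arange(30, dtype=float)
X = np.column_stack([np.ones_like(t), t])
Y = 1.0 + 2.0 * t
mse = kfold_mse(X, Y, least_squares, num_folds=5)
```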
- causal_falsify.utils.mint.fit_linear_regression(X, Y, alpha=0.0)[source]
Fit a linear regression model using JAX.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Transformed feature matrix, including intercept term if desired.
Y (array-like of shape (n_samples,) or (n_samples, n_targets)) – Target variable.
alpha (float, optional, default=0.0) – Regularization strength for ridge (L2) penalty. Set to 0 for ordinary least squares.
- Returns:
params – Fitted linear regression coefficients.
- Return type:
jax.numpy.ndarray of shape (n_features,) or (n_features, n_targets)
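For reference, a ridge fit has the closed form (XᵀX + alpha·I)⁻¹XᵀY, which reduces to ordinary least squares at alpha = 0. The NumPy sketch below shows that formula; it assumes every coefficient (including any intercept column) is penalized, which the documentation does not specify, and the name ridge_fit is hypothetical.

```python
import numpy as np


def ridge_fit(X, Y, alpha=0.0):
    """Solve the regularized normal equations for the coefficient vector."""
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    gram = X.T @ X + alpha * np.eye(X.shape[1])
    # Solving the linear system avoids forming an explicit inverse.
    return np.linalg.solve(gram, X.T @ Y)


t = np.linspace(-1.0, 1.0, 40)
X = np.column_stack([np.ones_like(t), t])
Y = 0.5 - 3.0 * t
params_ols = ridge_fit(X, Y)                # alpha = 0: ordinary least squares
params_ridge = ridge_fit(X, Y, alpha=10.0)  # coefficients shrink toward zero
```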
- causal_falsify.utils.mint.fit_logistic_regression(X, Y, alpha=0)[source]
Fit a logistic regression model using JAX and gradient descent (binary cross-entropy).
Implementation notes
Uses a numerically-stable logits-based binary cross-entropy loss.
Inputs X and Y are converted to JAX arrays and cast to a float dtype inside the function (float32 by default).
Training is executed inside a JIT-compiled loop using jax.lax.fori_loop for efficiency. The function itself is decorated with @jit.
L2 regularization is applied as alpha * sum(params**2). Passing alpha=0 results in no regularization.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Design matrix (can include an intercept column).
Y (array-like of shape (n_samples,)) – Binary target values in {0, 1} (the implementation expects 0/1 labels).
alpha (float, optional (default=0)) – Regularization strength for ridge (L2) penalty. May be a traced JAX scalar when called from JITted code; the implementation avoids Python boolean checks on alpha.
- Returns:
params – Fitted logistic regression coefficients (dtype float32).
- Return type:
jax.numpy.ndarray of shape (n_features,)
Notes
The number of gradient-descent iterations and learning rate are currently hard-coded (1000 iterations, learning rate 0.1). If you need tunable training behaviour, consider adding explicit arguments or switching to an optimizer library such as optax.
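The training loop described in the implementation notes can be sketched as follows. This is an illustration under stated assumptions, not the library's code: zero initialization and a mean (rather than summed) cross-entropy are guesses, and the names bce_with_logits and fit_logistic are hypothetical. The iteration count and learning rate follow the values quoted above.

```python
import jax
import jax.numpy as jnp


def bce_with_logits(params, X, Y, alpha):
    """Stable binary cross-entropy on logits plus an L2 penalty."""
    logits = X @ params
    # max(z, 0) - z * y + log(1 + exp(-|z|)) never exponentiates a large positive number.
    per_sample = jnp.maximum(logits, 0.0) - logits * Y + jnp.log1p(jnp.exp(-jnp.abs(logits)))
    return jnp.mean(per_sample) + alpha * jnp.sum(params ** 2)


@jax.jit
def fit_logistic(X, Y, alpha):
    X = jnp.asarray(X, dtype=jnp.float32)
    Y = jnp.asarray(Y, dtype=jnp.float32)
    grad_fn = jax.grad(bce_with_logits)  # gradient with respect to params

    def step(_, params):
        # One plain gradient-descent update with learning rate 0.1.
        return params - 0.1 * grad_fn(params, X, Y, alpha)

    init = jnp.zeros(X.shape[1], dtype=jnp.float32)  # assumed initialization
    return jax.lax.fori_loop(0, 1000, step, init)


x = jnp.linspace(-2.0, 2.0, 50)
X = jnp.stack([jnp.ones_like(x), x], axis=1)
Y = (x > 0).astype(jnp.float32)
params = fit_logistic(X, Y, 0.0)
```

Because alpha enters only through arithmetic, it can be a traced scalar and no Python branch on its value is needed.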
- causal_falsify.utils.mint.fit_model_jax(X, Y, binary_response=False)[source]
Fit a nuisance model (linear or logistic) and evaluate its performance via cross-validation.
This function selects between fit_linear_regression and fit_logistic_regression based on the binary_response flag and returns the fitted parameters along with a cross-validated mean-squared error diagnostic.
Important
binary_response is treated as a plain Python boolean here. When calling fit_model_jax from JIT-compiled code, prefer passing function objects directly to the caller (see bootstrap_model_fitting_jax) to avoid tracing boolean values.
- Parameters:
X (jax.numpy.ndarray) – Design matrix of shape (n_samples, n_features).
Y (jax.numpy.ndarray) – Target vector of shape (n_samples,).
binary_response (bool, optional) – If True, fit a logistic regression model; otherwise fit a linear model.
- Returns:
params_outcome (jax.numpy.ndarray) – The fitted model parameters (shape (n_features,)).
model_mse (float) – Cross-validated mean squared error for diagnostic purposes.
- Return type:
tuple[Array, float]
- causal_falsify.utils.mint.permutation_independence_test(data_x, data_y, n_bootstraps=1000, random_state=None)[source]
Performs a permutation-based independence test between two datasets.
This function tests the null hypothesis that data_x and data_y are independent by comparing the observed off-diagonal block Frobenius norm to the distribution obtained by permuting data_x. The p-value is estimated as the proportion of permuted statistics greater than the observed statistic.
- Parameters:
data_x (np.ndarray) – The first dataset, with samples along the first axis.
data_y (np.ndarray) – The second dataset, with samples along the first axis.
n_bootstraps (int, optional) – Number of permutations to perform (default is 1000).
random_state (np.random.RandomState or None, optional) – Random state for reproducibility. If None, a new RandomState is created.
- Returns:
The estimated p-value for the independence test.
- Return type:
float
Notes
Requires the function compute_offdiag_block_frobnorm to compute the test statistic.
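The permutation procedure can be sketched in NumPy as follows. The names cross_cov_norm and permutation_pvalue are hypothetical, the sketch expects 2D inputs, and details such as tie handling may differ from the library.

```python
import numpy as np


def cross_cov_norm(x, y):
    """Frobenius norm of the cross-covariance block between 2D arrays x and y."""
    p = x.shape[1]
    cov = np.cov(np.hstack([x, y]), rowvar=False)
    return np.linalg.norm(cov[:p, p:], "fro")


def permutation_pvalue(x, y, n_permutations=1000, random_state=None):
    rng = random_state if random_state is not None else np.random.RandomState()
    observed = cross_cov_norm(x, y)
    # Shuffling the rows of x breaks any pairing with y, simulating the null hypothesis.
    null_stats = np.array([
        cross_cov_norm(x[rng.permutation(x.shape[0])], y)
        for _ in range(n_permutations)
    ])
    # Share of permuted statistics that exceed the observed one.
    return float(np.mean(null_stats > observed))


rng = np.random.RandomState(0)
x = rng.normal(size=(100, 1))
y_dependent = x + 0.1 * rng.normal(size=(100, 1))
p_dependent = permutation_pvalue(x, y_dependent, n_permutations=200, random_state=rng)
```

A small p-value is evidence against independence; with strongly correlated inputs such as these, almost no permuted statistic should exceed the observed one.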
- causal_falsify.utils.mint.resample_until_enough_unique(subkey, n_resamples, min_sample_size)[source]
Repeatedly resample indices (with replacement) until the sample contains at least min_sample_size unique indices.
- Parameters:
subkey (jax.random.PRNGKey) – PRNG key for JAX random operations.
n_resamples (int) – Number of indices to sample in each iteration (sample size).
min_sample_size (int) – Minimum required number of unique indices in the resampled set.
- Returns:
resampled_indices – Integer array of shape (n_resamples,) containing resampled indices. The returned array is guaranteed to contain at least min_sample_size distinct values when the function returns.
- Return type:
jax.numpy.ndarray
Notes
The function uses jax.lax.while_loop internally to remain compatible with JIT tracing. The condition and body are expressed with JAX primitives.
If min_sample_size > n_resamples the loop cannot succeed; the caller should ensure min_sample_size <= n_resamples to avoid an infinite loop.
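A JIT-compatible version of this loop can be sketched as follows. It assumes indices are drawn from range(n_resamples) and counts distinct values with a boolean mask, since jnp.unique needs a static output size under tracing; the name resample_min_unique is hypothetical and the library may count uniques differently.

```python
from functools import partial

import jax
import jax.numpy as jnp


# n sets array shapes, so it must be static; min_unique may be traced.
@partial(jax.jit, static_argnums=(1,))
def resample_min_unique(key, n, min_unique):
    def draw(k):
        # Split so that every attempt uses fresh randomness.
        k, sub = jax.random.split(k)
        return k, jax.random.randint(sub, (n,), 0, n)

    def n_unique(idx):
        # Mark each drawn index once, then count the marks.
        return jnp.zeros(n, dtype=bool).at[idx].set(True).sum()

    def cond(state):
        _, idx = state
        return n_unique(idx) < min_unique  # keep looping while too few distinct indices

    def body(state):
        k, _ = state
        return draw(k)

    _, idx = jax.lax.while_loop(cond, body, draw(key))
    return idx


idx = resample_min_unique(jax.random.PRNGKey(0), 20, 10)
```

As the note above warns, calling this with min_unique greater than n would never terminate.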
- causal_falsify.utils.mint.validate_matrix(matrix)[source]
Validates that the input matrix is a proper 2-dimensional NumPy array without NaN or infinite values.
- Parameters:
matrix (np.ndarray) – The matrix to validate.
- Raises:
AssertionError – If the input is not a NumPy array, is not 2-dimensional, or contains NaN or infinite values.
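The listed checks can be sketched as follows. The name check_matrix, the order of the checks, and the error messages are assumptions made for illustration.

```python
import numpy as np


def check_matrix(matrix):
    """Raise AssertionError unless matrix is a finite 2D NumPy array."""
    assert isinstance(matrix, np.ndarray), "input must be a NumPy array"
    assert matrix.ndim == 2, "input must be 2-dimensional"
    assert not np.isnan(matrix).any(), "input contains NaN values"
    assert not np.isinf(matrix).any(), "input contains infinite values"


check_matrix(np.eye(2))  # passes silently

try:
    check_matrix(np.array([[1.0, np.nan]]))
    nan_rejected = False
except AssertionError:
    nan_rejected = True
```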