causal_falsify.utils.mint module

causal_falsify.utils.mint.bootstrap_model_fitting_jax(Y, T, tf_X, tf_XT, outcome_model_fn, treatment_model_fn, key)[source]

Fit outcome and treatment models on a bootstrap resample of the data.

This function performs bootstrap resampling (with replacement) using JAX random primitives and then fits the provided model functions on the resampled data.

Parameters:
  • Y (jnp.ndarray) – Outcome array of shape (n_samples,).

  • T (jnp.ndarray) – Treatment array of shape (n_samples,).

  • tf_X (jnp.ndarray) – Transformed covariate matrix for the treatment model of shape (n_samples, n_features).

  • tf_XT (jnp.ndarray) – Transformed covariate matrix for the outcome model of shape (n_samples, n_features_outcome).

  • outcome_model_fn (callable) – Callable that fits an outcome model. Signature should be fn(X, Y) and return (params, mse) where params is an array of fitted coefficients.

  • treatment_model_fn (callable) – Callable that fits a treatment model. Same interface as outcome_model_fn.

  • key (jax.random.PRNGKey) – JAX PRNGKey used for resampling.

Returns:

resampled_params_outcome, resampled_params_treatment – The fitted parameters for outcome and treatment models on the resampled data.

Notes

  • The function expects outcome_model_fn and treatment_model_fn to be plain Python callables (functools.partial wrappers are fine). The function is JIT-compiled, and the two callable arguments are marked static via static_argnums, so they must be passed as Python callables (not JAX tracers or arrays).

  • Resampling uses JAX operations; keeping the model callables static avoids tracing Python callables.

Return type:

tuple
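
The static-callable pattern described in the notes can be sketched as follows. This is an illustration under stated assumptions, not the module's source: bootstrap_fit_sketch and least_squares_fn are hypothetical names, the model function is a plain normal-equations solve, and the resample is a single draw of n indices with replacement.

```python
from functools import partial

import jax
import jax.numpy as jnp


def least_squares_fn(X, y):
    """Hypothetical model function returning (params, in-sample MSE)."""
    params = jnp.linalg.solve(X.T @ X, X.T @ y)
    mse = jnp.mean((X @ params - y) ** 2)
    return params, mse


# Positional arguments 4 and 5 (the two callables) are marked static, so JAX
# treats them as compile-time constants instead of trying to trace them.
@partial(jax.jit, static_argnums=(4, 5))
def bootstrap_fit_sketch(Y, T, tf_X, tf_XT, outcome_model_fn, treatment_model_fn, key):
    n = Y.shape[0]
    # Draw n row indices with replacement.
    idx = jax.random.randint(key, (n,), 0, n)
    params_outcome, _ = outcome_model_fn(tf_XT[idx], Y[idx])
    params_treatment, _ = treatment_model_fn(tf_X[idx], T[idx])
    return params_outcome, params_treatment


# Synthetic data, for illustration only.
k_x, k_t, k_y, k_boot = jax.random.split(jax.random.PRNGKey(0), 4)
X = jax.random.normal(k_x, (200, 2))
T = X @ jnp.array([1.0, -1.0]) + 0.1 * jax.random.normal(k_t, (200,))
Y = 2.0 * T + X[:, 0] + 0.1 * jax.random.normal(k_y, (200,))
tf_X = jnp.column_stack([jnp.ones(200), X])   # treatment-model features
tf_XT = jnp.column_stack([tf_X, T])           # outcome-model features

params_outcome, params_treatment = bootstrap_fit_sketch(
    Y, T, tf_X, tf_XT, least_squares_fn, least_squares_fn, k_boot
)
```

Static arguments must be hashable. Plain functions and functools.partial objects are, which is why they can be passed here while arrays cannot.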

causal_falsify.utils.mint.bootstrapped_permutation_independence_test(data_x, data_y, resampled_data_x, resampled_data_y, random_state=None)[source]

Performs a bootstrapped permutation independence test between two datasets.

This function computes the observed off-diagonal block Frobenius norm between data_x and data_y, then compares it to the distribution of norms obtained by permuting the resampled versions of data_x and data_y. The returned value is the proportion of times the observed statistic is less than the bootstrapped statistics, which can be interpreted as a p-value for the independence test.

Parameters:
  • data_x (np.ndarray) – The original data array for variable X, of shape (n_samples, n_features_x).

  • data_y (np.ndarray) – The original data array for variable Y, of shape (n_samples, n_features_y).

  • resampled_data_x (np.ndarray) – Bootstrapped samples of data_x, of shape (n_bootstraps, n_samples, n_features_x).

  • resampled_data_y (np.ndarray) – Bootstrapped samples of data_y, of shape (n_bootstraps, n_samples, n_features_y).

  • random_state (np.random.RandomState or None, optional) – Random state for reproducibility. If None, a new RandomState is created.

Returns:

The proportion of bootstrapped statistics greater than the observed statistic, representing the p-value for the independence test.

Return type:

float

Raises:

AssertionError – If the number of bootstraps in resampled_data_x and resampled_data_y do not match.

Notes

This function relies on compute_offdiag_block_frobnorm to compute the test statistic.
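
The comparison described above can be sketched in NumPy as follows. The helper names are hypothetical, the statistic uses np.cov defaults, and permuting the rows of the resampled X block is an assumption about how the null distribution is built; the module's implementation may differ in these details.

```python
import numpy as np


def offdiag_block_frobnorm(x, y):
    """Frobenius norm of the cross-covariance block between x and y."""
    cov = np.cov(np.hstack([x, y]), rowvar=False)
    return np.linalg.norm(cov[: x.shape[1], x.shape[1]:], ord="fro")


def bootstrapped_permutation_pvalue(data_x, data_y, resampled_x, resampled_y, rng):
    assert resampled_x.shape[0] == resampled_y.shape[0]
    observed = offdiag_block_frobnorm(data_x, data_y)
    null_stats = np.empty(resampled_x.shape[0])
    for b in range(resampled_x.shape[0]):
        # Permuting the rows of one block breaks its pairing with the other.
        perm = rng.permutation(resampled_x.shape[1])
        null_stats[b] = offdiag_block_frobnorm(resampled_x[b][perm], resampled_y[b])
    # Proportion of null statistics that exceed the observed one.
    return float(np.mean(null_stats > observed))


rng = np.random.RandomState(0)
n, n_boot = 100, 200
data_x = rng.normal(size=(n, 2))
data_y = data_x + 0.1 * rng.normal(size=(n, 2))  # strongly dependent on data_x
boot_idx = rng.randint(0, n, size=(n_boot, n))   # bootstrap row indices
p_value = bootstrapped_permutation_pvalue(
    data_x, data_y, data_x[boot_idx], data_y[boot_idx], rng
)
```

Because data_y is nearly a copy of data_x in this synthetic example, the observed statistic sits far above the permuted ones and the returned proportion is small.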

causal_falsify.utils.mint.compute_offdiag_block_frobnorm(data_x, data_y)[source]

Compute the Frobenius norm of the off-diagonal block of the covariance matrix between two datasets.

Given two datasets with the same number of samples, this function concatenates them, computes the covariance matrix, extracts the off-diagonal block corresponding to the covariances between the two datasets, and returns its Frobenius norm.

Parameters:
  • data_x (np.ndarray) – A 2D array of shape (n_samples, n_features_x) representing the first dataset. Can also be 1D with shape (n_samples,), which will be reshaped to (n_samples, 1).

  • data_y (np.ndarray) – A 2D array of shape (n_samples, n_features_y) representing the second dataset. Can also be 1D with shape (n_samples,), which will be reshaped to (n_samples, 1).

Returns:

The Frobenius norm of the off-diagonal block of the covariance matrix between data_x and data_y.

Return type:

float

Raises:
  • AssertionError – If the number of samples (first dimension) in data_x and data_y do not match.

  • ValueError – If the input matrices are not valid as determined by validate_matrix.

Notes

The off-diagonal block refers to the submatrix of the covariance matrix that captures the covariances between the features of data_x and data_y.
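
A minimal NumPy sketch of the statistic, assuming np.cov with its default normalization (n_samples - 1); the function name is hypothetical and the module may normalize differently.

```python
import numpy as np


def offdiag_block_frobnorm_sketch(data_x, data_y):
    # Promote 1D inputs to column vectors, as described above.
    x = data_x.reshape(-1, 1) if data_x.ndim == 1 else data_x
    y = data_y.reshape(-1, 1) if data_y.ndim == 1 else data_y
    assert x.shape[0] == y.shape[0], "sample counts must match"
    p = x.shape[1]
    # Covariance of the concatenated features, columns as variables.
    cov = np.cov(np.hstack([x, y]), rowvar=False)
    # Rows belong to data_x features, columns to data_y features.
    cross_block = cov[:p, p:]
    return float(np.linalg.norm(cross_block, ord="fro"))


x = np.array([1.0, 2.0, 3.0, 4.0])
y = 2.0 * x
stat = offdiag_block_frobnorm_sketch(x, y)  # equals 2 * Var(x) = 10 / 3
```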

causal_falsify.utils.mint.create_polynomial_representation(X, degree, use_sklearn=False, interaction_only=False)[source]

Generate a polynomial feature representation of the input data.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – The input data to be transformed into polynomial features.

  • degree (int) – The degree of the polynomial features to be generated. Must be greater than 1.

  • use_sklearn (bool, optional, default=False) – If True, use sklearn’s PolynomialFeatures for transformation. If False, generate polynomial features manually (only powers of individual features, no cross-terms).

  • interaction_only (bool, optional, default=False) – If True and use_sklearn is True, only interaction features are produced: features that are products of at most degree distinct input features (no powers of single features). Has no effect if use_sklearn is False.

Returns:

X_poly – The matrix of polynomial features.

Return type:

ndarray of shape (n_samples, n_output_features)

Raises:

ValueError – If degree is less than or equal to 1.

Notes

  • When use_sklearn is False, only powers of individual features are generated (no interaction/cross terms).

  • When use_sklearn is True, both interaction and power terms are generated according to the parameters.
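
The manual branch (use_sklearn=False) can be illustrated with a short NumPy sketch. The function name is hypothetical, and the column ordering and the absence of an intercept column are assumptions rather than documented behaviour.

```python
import numpy as np


def polynomial_powers_sketch(X, degree):
    if degree <= 1:
        raise ValueError("degree must be greater than 1")
    X = np.asarray(X, dtype=float)
    # Stack X, X**2, ..., X**degree column-wise; no cross terms are formed.
    return np.hstack([X ** d for d in range(1, degree + 1)])


X = np.array([[1.0, 2.0], [3.0, 4.0]])
X_poly = polynomial_powers_sketch(X, degree=3)
# Columns: x1, x2, x1**2, x2**2, x1**3, x2**3
```

With two input features and degree 3 this yields six columns, and none of them is a product such as x1 * x2.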

causal_falsify.utils.mint.cross_val_mse(X, Y, model_fn, num_folds)[source]

Perform k-fold cross-validation and compute the mean squared error (MSE).

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Transformed feature matrix, including intercept term if desired.

  • Y (array-like of shape (n_samples,) or (n_samples, n_targets)) – Target variable.

  • model_fn (callable) – Function to fit the model. Should take (X_train, Y_train) and return fitted parameters.

  • num_folds (int) – Number of folds for cross-validation.

Returns:

Mean squared error averaged across all folds.

Return type:

float
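
A k-fold loop of this kind might look like the NumPy sketch below. It assumes contiguous, unshuffled folds and a linear predictor X @ params; ols_fn and cross_val_mse_sketch are hypothetical names, and the module's fold construction may differ.

```python
import numpy as np


def ols_fn(X_train, Y_train):
    """Hypothetical model_fn: least-squares coefficients."""
    return np.linalg.lstsq(X_train, Y_train, rcond=None)[0]


def cross_val_mse_sketch(X, Y, model_fn, num_folds):
    folds = np.array_split(np.arange(X.shape[0]), num_folds)
    fold_mses = []
    for test_idx in folds:
        # Train on every row that is not in the held-out fold.
        train_mask = np.ones(X.shape[0], dtype=bool)
        train_mask[test_idx] = False
        params = model_fn(X[train_mask], Y[train_mask])
        # Assumes a linear predictor X @ params.
        residuals = Y[test_idx] - X[test_idx] @ params
        fold_mses.append(np.mean(residuals ** 2))
    return float(np.mean(fold_mses))


rng = np.random.RandomState(0)
X = np.column_stack([np.ones(100), rng.normal(size=100)])
Y = X @ np.array([1.0, 2.0])  # noiseless, so the held-out error is near zero
mse = cross_val_mse_sketch(X, Y, ols_fn, num_folds=5)
```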

causal_falsify.utils.mint.fit_linear_regression(X, Y, alpha=0.0)[source]

Fit a linear regression model using JAX.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Transformed feature matrix, including intercept term if desired.

  • Y (array-like of shape (n_samples,) or (n_samples, n_targets)) – Target variable.

  • alpha (float, optional (default=0.0)) – Regularization strength for the ridge (L2) penalty. Set to 0 for ordinary least squares.

Returns:

params – Fitted linear regression coefficients.

Return type:

jax.numpy.ndarray of shape (n_features,) or (n_features, n_targets)
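
For reference, the ridge estimate has the closed form (XᵀX + alpha·I)⁻¹ XᵀY. The sketch below uses NumPy rather than JAX and a hypothetical function name; whether the module uses this exact solve, or penalizes the intercept column, is not stated here and is assumed.

```python
import numpy as np


def ridge_closed_form(X, Y, alpha=0.0):
    # Solve (X^T X + alpha * I) params = X^T Y.
    gram = X.T @ X + alpha * np.eye(X.shape[1])
    return np.linalg.solve(gram, X.T @ Y)


X = np.column_stack([np.ones(4), np.array([0.0, 1.0, 2.0, 3.0])])
Y = np.array([1.0, 3.0, 5.0, 7.0])  # exactly 1 + 2 * x

params_ols = ridge_closed_form(X, Y)                 # alpha=0: ordinary least squares
params_ridge = ridge_closed_form(X, Y, alpha=10.0)   # coefficients shrink towards zero
```

With alpha=0 the fit recovers the intercept 1 and slope 2 exactly; a positive alpha reduces the norm of the coefficient vector.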

causal_falsify.utils.mint.fit_logistic_regression(X, Y, alpha=0)[source]

Fit a logistic regression model using JAX and gradient descent (binary cross-entropy).

Implementation notes

  • Uses a numerically-stable logits-based binary cross-entropy loss.

  • Inputs X and Y are converted to JAX arrays and cast to a float dtype inside the function (float32 by default).

  • Training is executed inside a JIT-compiled loop using jax.lax.fori_loop for efficiency. The function itself is decorated with @jit.

  • L2 regularization is applied as alpha * sum(params**2). Passing alpha=0 results in no regularization.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Design matrix (can include an intercept column).

  • Y (array-like of shape (n_samples,)) – Binary target values in {0, 1} (the implementation expects 0/1 labels).

  • alpha (float, optional (default=0)) – Regularization strength for ridge (L2) penalty. May be a traced JAX scalar when called from JITted code; the implementation avoids Python boolean checks on alpha.

Returns:

params – Fitted logistic regression coefficients (dtype float32).

Return type:

jax.numpy.ndarray of shape (n_features,)

Notes

  • The number of gradient-descent iterations and learning rate are currently hard-coded (1000 iterations, learning rate 0.1). If you need tunable training behaviour, consider adding explicit arguments or switching to an optimizer library such as optax.
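
The ingredients listed above (logits-based loss, fori_loop, fixed step count and learning rate) can be combined as in the following sketch. The function names are hypothetical, and averaging the loss over samples and starting from zero coefficients are assumptions, not confirmed details of the module.

```python
import jax
import jax.numpy as jnp


def bce_with_logits(params, X, Y, alpha):
    logits = X @ params
    # Stable form of -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))]:
    # max(z, 0) - z*y + log(1 + exp(-|z|)).
    per_sample = (
        jnp.maximum(logits, 0.0) - logits * Y + jnp.log1p(jnp.exp(-jnp.abs(logits)))
    )
    # L2 penalty applied as alpha * sum(params**2); no Python branch on alpha.
    return jnp.mean(per_sample) + alpha * jnp.sum(params ** 2)


@jax.jit
def fit_logistic_sketch(X, Y, alpha=0.0):
    X = jnp.asarray(X, dtype=jnp.float32)
    Y = jnp.asarray(Y, dtype=jnp.float32)
    grad_fn = jax.grad(bce_with_logits)  # gradient w.r.t. params

    def step(_, params):
        # Plain gradient descent with a fixed learning rate of 0.1.
        return params - 0.1 * grad_fn(params, X, Y, alpha)

    init = jnp.zeros(X.shape[1], dtype=jnp.float32)
    return jax.lax.fori_loop(0, 1000, step, init)


x = jnp.linspace(-2.0, 2.0, 50)
X = jnp.column_stack([jnp.ones(50), x])
Y = (x > 0).astype(jnp.float32)  # label is 1 exactly when x is positive
params = fit_logistic_sketch(X, Y)
```

On this toy data the fitted slope is positive, so the sign of X @ params separates the two classes.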

causal_falsify.utils.mint.fit_model_jax(X, Y, binary_response=False)[source]

Fit a nuisance model (linear or logistic) and evaluate its performance via cross-validation.

This function selects between fit_linear_regression and fit_logistic_regression based on the binary_response flag and returns the fitted parameters along with a cross-validated mean-squared error diagnostic.

Important

  • binary_response is treated as a plain Python boolean. When the fit runs inside JIT-compiled code, pass the model function objects to the caller instead (see bootstrap_model_fitting_jax) so that no boolean value is traced.

Parameters:
  • X (jax.numpy.ndarray) – Design matrix of shape (n_samples, n_features).

  • Y (jax.numpy.ndarray) – Target vector of shape (n_samples,).

  • binary_response (bool, optional) – If True, fit a logistic regression model; otherwise fit a linear model.

Return type:

tuple[Array, float]

Returns:

  • params_outcome (jax.numpy.ndarray) – The fitted model parameters (shape (n_features,)).

  • model_mse (float) – Cross-validated mean squared error for diagnostic purposes.

causal_falsify.utils.mint.permutation_independence_test(data_x, data_y, n_bootstraps=1000, random_state=None)[source]

Performs a permutation-based independence test between two datasets.

This function tests the null hypothesis that data_x and data_y are independent by comparing the observed off-diagonal block Frobenius norm to the distribution obtained by permuting data_x. The p-value is estimated as the proportion of permuted statistics greater than the observed statistic.

Parameters:
  • data_x (np.ndarray) – The first dataset, with samples along the first axis.

  • data_y (np.ndarray) – The second dataset, with samples along the first axis.

  • n_bootstraps (int, optional) – Number of permutations to perform (default is 1000).

  • random_state (np.random.RandomState or None, optional) – Random state for reproducibility. If None, a new RandomState is created.

Returns:

The estimated p-value for the independence test.

Return type:

float

Notes

Requires the function compute_offdiag_block_frobnorm to compute the test statistic.
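
A minimal NumPy sketch of the procedure, assuming the statistic is the Frobenius norm of the cross-covariance block computed with np.cov defaults; the function names are hypothetical.

```python
import numpy as np


def cross_cov_norm(x, y):
    """Frobenius norm of the cross-covariance block between x and y."""
    cov = np.cov(np.hstack([x, y]), rowvar=False)
    return np.linalg.norm(cov[: x.shape[1], x.shape[1]:], ord="fro")


def permutation_pvalue_sketch(data_x, data_y, n_permutations=1000, random_state=None):
    rng = random_state if random_state is not None else np.random.RandomState()
    observed = cross_cov_norm(data_x, data_y)
    # Shuffling the rows of data_x simulates independence from data_y.
    permuted = np.array([
        cross_cov_norm(data_x[rng.permutation(data_x.shape[0])], data_y)
        for _ in range(n_permutations)
    ])
    # Proportion of permuted statistics greater than the observed one.
    return float(np.mean(permuted > observed))


rng = np.random.RandomState(0)
x = rng.normal(size=(200, 2))
y_dependent = x + 0.1 * rng.normal(size=(200, 2))
p_dependent = permutation_pvalue_sketch(x, y_dependent, 200, rng)
```

Passing an explicit RandomState makes the permutations, and therefore the estimated p-value, reproducible.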

causal_falsify.utils.mint.resample_until_enough_unique(subkey, n_resamples, min_sample_size)[source]

Repeatedly resample indices (with replacement) until the sample contains at least min_sample_size unique indices.

Parameters:
  • subkey (jax.random.PRNGKey) – PRNG key for JAX random operations.

  • n_resamples (int) – Number of indices to sample in each iteration (sample size).

  • min_sample_size (int) – Minimum required number of unique indices in the resampled set.

Returns:

resampled_indices – Integer array of shape (n_resamples,) containing resampled indices. The returned array is guaranteed to contain at least min_sample_size distinct values when the function returns.

Return type:

jax.numpy.ndarray

Notes

  • The function uses jax.lax.while_loop internally to remain compatible with JIT tracing. The condition and body are expressed with JAX primitives.

  • If min_sample_size > n_resamples the loop cannot succeed; the caller should ensure min_sample_size <= n_resamples to avoid an infinite loop.
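
The loop structure can be sketched with jax.lax.while_loop as below. The names are hypothetical, and drawing indices from range(n_resamples) is an assumption, since the signature carries no separate population size. Distinct values are counted by sorting because jnp.unique has a data-dependent output shape.

```python
import jax
import jax.numpy as jnp


def count_unique(indices):
    # Sort, then count the positions where the value changes.
    s = jnp.sort(indices)
    return 1 + jnp.sum(s[1:] != s[:-1])


def resample_sketch(subkey, n_resamples, min_sample_size):
    def draw(key):
        key, draw_key = jax.random.split(key)
        # Assumption: indices are drawn from range(n_resamples).
        return key, jax.random.randint(draw_key, (n_resamples,), 0, n_resamples)

    def cond(state):
        _, indices = state
        # Keep looping while there are too few distinct indices.
        return count_unique(indices) < min_sample_size

    def body(state):
        key, _ = state
        return draw(key)

    _, indices = jax.lax.while_loop(cond, body, draw(subkey))
    return indices


indices = resample_sketch(jax.random.PRNGKey(0), n_resamples=50, min_sample_size=20)
```

If such a function is wrapped in jax.jit, n_resamples has to be a static argument because it determines an array shape.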

causal_falsify.utils.mint.validate_matrix(matrix)[source]

Validates that the input matrix is a proper 2-dimensional NumPy array without NaN or infinite values.

Parameters:

matrix (np.ndarray) – The matrix to validate.

Raises:

AssertionError – If the input is not a NumPy array, is not 2-dimensional, or contains NaN or infinite values.
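
The checks listed above amount to a handful of assertions. The sketch below uses a hypothetical name, and the order of the checks and the messages are assumptions.

```python
import numpy as np


def validate_matrix_sketch(matrix):
    assert isinstance(matrix, np.ndarray), "expected a NumPy array"
    assert matrix.ndim == 2, "expected a 2-dimensional array"
    assert not np.isnan(matrix).any(), "matrix contains NaN values"
    assert not np.isinf(matrix).any(), "matrix contains infinite values"


validate_matrix_sketch(np.eye(2))  # a valid matrix passes silently

try:
    validate_matrix_sketch(np.array([[1.0, np.nan]]))
    rejected = False
except AssertionError:
    rejected = True  # the NaN entry triggers the assertion
```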