Source code for causal_falsify.utils.mint

import numpy as np
import jax.numpy as jnp
from functools import partial
from jax.scipy.linalg import solve
from jax import grad, jit, random, lax
from sklearn.preprocessing import PolynomialFeatures
import warnings


def create_polynomial_representation(
    X, degree, use_sklearn=False, interaction_only=False
) -> np.ndarray:
    """
    Generate a polynomial feature representation of the input data.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        The input data to be transformed into polynomial features. Non-array
        inputs (e.g. nested lists) are converted with ``np.asarray``.
    degree : int
        The degree of the polynomial features to be generated. Must be greater
        than 1.
    use_sklearn : bool, optional, default=False
        If True, use sklearn's PolynomialFeatures for transformation.
        If False, generate polynomial features manually (only powers of
        individual features, no cross-terms).
    interaction_only : bool, optional, default=False
        If True and `use_sklearn` is True, only interaction features are
        produced: features that are products of at most `degree` distinct
        input features (no powers of single features). Has no effect if
        `use_sklearn` is False; in that case a ``UserWarning`` is emitted.

    Returns
    -------
    X_poly : ndarray of shape (n_samples, n_output_features)
        The matrix of polynomial features. In the manual path the columns are
        ordered feature by feature: ``x_1, x_1**2, ..., x_1**degree, x_2, ...``,
        giving ``n_features * degree`` columns (0 columns if there are no
        input features).

    Raises
    ------
    ValueError
        If `degree` is less than or equal to 1.

    Notes
    -----
    - When `use_sklearn` is False, only powers of individual features are
      generated (no interaction/cross terms).
    - When `use_sklearn` is True, both interaction and power terms are
      generated according to the parameters (no bias column).
    """
    if degree <= 1:
        raise ValueError("Degree must be larger than 1.")
    if interaction_only and not use_sklearn:
        # Route through the warnings machinery so callers can filter or
        # escalate this instead of it being printed to stdout.
        warnings.warn(
            "interaction_only has no effect as use_sklearn = False.",
            stacklevel=2,
        )
    if use_sklearn:
        return PolynomialFeatures(
            degree=degree, interaction_only=interaction_only, include_bias=False
        ).fit_transform(X)

    # The contract is "array-like"; ``.shape`` below requires an ndarray.
    X = np.asarray(X)
    n_samples, n_features = X.shape[0], X.shape[1]

    # Empty covariate set: keep the row count, produce no feature columns.
    if n_features == 0:
        return np.empty((n_samples, 0))

    # Python-int exponents keep the input dtype (float32 stays float32).
    columns = [
        X[:, feature_idx] ** power
        for feature_idx in range(n_features)
        for power in range(1, degree + 1)
    ]
    return np.column_stack(columns)
############################################################### # Test based on computing Frobenius norm of off-diagonal block ###############################################################
def compute_offdiag_block_frobnorm(data_x, data_y) -> float:
    """
    Compute the Frobenius norm of the cross-covariance block between two datasets.

    The two datasets (same number of rows) are concatenated column-wise, the
    sample covariance matrix of the combined columns is computed, and the
    off-diagonal block holding the covariances between the columns of
    `data_x` and the columns of `data_y` is extracted. Its Frobenius norm is
    returned.

    Parameters
    ----------
    data_x : array-like
        Shape (n_samples, n_features_x), or 1-D (n_samples,), which is treated
        as a single column. Converted with ``np.asarray``.
    data_y : array-like
        Shape (n_samples, n_features_y), or 1-D (n_samples,), which is treated
        as a single column. Converted with ``np.asarray``.

    Returns
    -------
    float
        Frobenius norm of the (n_features_x, n_features_y) cross-covariance
        block. It is 0.0 when either side has no columns.

    Raises
    ------
    AssertionError
        If the number of rows of `data_x` and `data_y` differ, or if the
        concatenated data fails `validate_matrix` (NaN or infinite entries).

    Notes
    -----
    ``np.cov`` is used with its default normalisation (n_samples - 1), so
    fewer than two samples yields NaN.
    """
    data_x = np.asarray(data_x)
    data_y = np.asarray(data_y)

    # Treat 1-D inputs as a single variable observed n_samples times.
    if data_x.ndim == 1:
        data_x = data_x.reshape(-1, 1)
    if data_y.ndim == 1:
        data_y = data_y.reshape(-1, 1)

    dim_x, dim_y = data_x.shape[1], data_y.shape[1]
    assert data_x.shape[0] == data_y.shape[0], "first dimension be the same"

    coefs = np.hstack([data_x, data_y])
    validate_matrix(coefs)

    # np.cov returns a 0-d array when only one column is present; keep it 2-D
    # so the block slicing below is always valid.
    covariance_matrix = np.atleast_2d(np.cov(coefs, rowvar=False))
    offdiag_block = covariance_matrix[:dim_x, dim_x:]
    assert offdiag_block.shape == (dim_x, dim_y)

    return np.linalg.norm(offdiag_block, "fro")
def permutation_independence_test(
    data_x: np.ndarray,
    data_y: np.ndarray,
    n_bootstraps: int = 1000,
    random_state=None,
) -> float:
    """
    Permutation test of independence between two datasets.

    The observed statistic is the Frobenius norm of the cross-covariance
    block (`compute_offdiag_block_frobnorm`). Under the null hypothesis of
    independence, shuffling the rows of `data_x` leaves its distribution
    unchanged. The p-value is therefore estimated by comparing the observed
    statistic with statistics computed after independent row permutations
    of `data_x`.

    Parameters
    ----------
    data_x : np.ndarray
        The first dataset, with samples along the first axis.
    data_y : np.ndarray
        The second dataset, with samples along the first axis.
    n_bootstraps : int, optional
        Number of permutations to perform (default is 1000). Must be >= 1.
    random_state : int, np.random.RandomState or None, optional
        Source of randomness. ``None`` creates a fresh, unseeded
        ``RandomState``. An integer seeds a new ``RandomState``. Any other
        object is used as-is and must provide a ``permutation`` method.

    Returns
    -------
    float
        Fraction of permuted statistics that are strictly greater than the
        observed statistic. No +1 correction is applied, so the value can be
        exactly 0.

    Raises
    ------
    ValueError
        If `n_bootstraps` is smaller than 1.
    """
    if n_bootstraps < 1:
        raise ValueError("n_bootstraps must be a positive integer.")
    if random_state is None or isinstance(random_state, (int, np.integer)):
        random_state = np.random.RandomState(random_state)

    observed_frob_norm = compute_offdiag_block_frobnorm(data_x, data_y)

    # permutation() shuffles along the first axis, i.e. it permutes samples.
    permuted_frob_norms = np.array(
        [
            compute_offdiag_block_frobnorm(random_state.permutation(data_x), data_y)
            for _ in range(n_bootstraps)
        ]
    )
    return np.mean(observed_frob_norm < permuted_frob_norms)
def bootstrapped_permutation_independence_test(
    data_x: np.ndarray,
    data_y: np.ndarray,
    resampled_data_x: np.ndarray,
    resampled_data_y: np.ndarray,
    random_state=None,
) -> float:
    """
    Bootstrapped permutation independence test between two datasets.

    The observed statistic is the cross-covariance Frobenius norm between
    `data_x` and `data_y` (`compute_offdiag_block_frobnorm`). For every
    bootstrap replicate ``j``, the rows of ``resampled_data_x[j]`` are
    permuted and the statistic is recomputed against ``resampled_data_y[j]``.
    This yields a reference distribution under independence.

    Parameters
    ----------
    data_x : np.ndarray
        Original data for X, shape (n_samples, n_features_x) or (n_samples,).
    data_y : np.ndarray
        Original data for Y, shape (n_samples, n_features_y) or (n_samples,).
    resampled_data_x : np.ndarray
        Bootstrapped samples of `data_x`, shape
        (n_bootstraps, n_samples, n_features_x), or (n_bootstraps, n_samples)
        for a single feature.
    resampled_data_y : np.ndarray
        Bootstrapped samples of `data_y`, shape
        (n_bootstraps, n_samples, n_features_y), or (n_bootstraps, n_samples)
        for a single feature.
    random_state : int, np.random.RandomState or None, optional
        Source of randomness. ``None`` creates a fresh, unseeded
        ``RandomState``. An integer seeds a new ``RandomState``. Any other
        object is used as-is and must provide a ``permutation`` method.

    Returns
    -------
    float
        Fraction of bootstrapped statistics strictly greater than the observed
        statistic. It is interpreted as the p-value of the test.

    Raises
    ------
    AssertionError
        If `resampled_data_x` and `resampled_data_y` hold a different number
        of bootstrap replicates.
    ValueError
        If there are no bootstrap replicates.
    """
    if random_state is None or isinstance(random_state, (int, np.integer)):
        random_state = np.random.RandomState(random_state)

    n_bootstraps = resampled_data_x.shape[0]
    assert resampled_data_x.shape[:1] == resampled_data_y.shape[:1]
    if n_bootstraps < 1:
        raise ValueError("resampled data must contain at least one bootstrap.")

    observed_frob_norm = compute_offdiag_block_frobnorm(data_x, data_y)

    # Index the replicate directly instead of ``[j, :, :].squeeze()``.
    # - Supports stacks of 1-D samples (2-D inputs).
    # - Never collapses the sample axis when a replicate has a single row.
    # - 1-D replicates are reshaped to a column by compute_offdiag_block_frobnorm.
    resampled_frob_norms = np.array(
        [
            compute_offdiag_block_frobnorm(
                random_state.permutation(resampled_data_x[j]),
                resampled_data_y[j],
            )
            for j in range(n_bootstraps)
        ]
    )
    return np.mean(observed_frob_norm < resampled_frob_norms)
########################################## # Utils ##########################################
def validate_matrix(matrix: np.ndarray):
    """
    Validate that the input is a finite, 2-dimensional NumPy array.

    Parameters
    ----------
    matrix : np.ndarray
        The matrix to validate.

    Raises
    ------
    AssertionError
        Raised when any of the following holds (checked in this order):
        - the input is not a NumPy array;
        - the matrix contains NaN values;
        - the matrix contains infinite values;
        - the matrix is not 2-dimensional.
    """
    assert isinstance(matrix, np.ndarray), "Input must be a NumPy array."
    # Use NumPy, like the isinf check below. The previous jnp.isnan forced a
    # NumPy -> JAX conversion of the whole matrix just to look for NaNs.
    assert not np.isnan(matrix).any(), f"Matrix contains NaN values: {matrix}"
    assert not np.isinf(matrix).any(), "Matrix contains infinite values."
    assert matrix.ndim == 2, "Matrix must be 2-dimensional."
############################################################### # Methods for estimating linear models ###############################################################
[docs] def fit_logistic_regression( X: jnp.ndarray, Y: jnp.ndarray, alpha: float = 0 ) -> jnp.ndarray: """ Fit a logistic regression model using JAX and gradient descent (binary cross-entropy). Implementation notes -------------------- - Uses a numerically-stable logits-based binary cross-entropy loss. - Inputs `X` and `Y` are converted to JAX arrays and cast to a float dtype inside the function (`float32` by default). - Training is executed inside a JIT-compiled loop using `jax.lax.fori_loop` for efficiency. The function itself is decorated with `@jit`. - L2 regularization is applied as `alpha * sum(params**2)`. Passing `alpha=0` results in no regularization. Parameters ---------- X : array-like, shape (n_samples, n_features) Design matrix (can include an intercept column). Y : array-like, shape (n_samples,) Binary target values in {0, 1} (the implementation expects 0/1 labels). alpha : float, optional (default=0) Regularization strength for ridge (L2) penalty. May be a traced JAX scalar when called from JITted code; the implementation avoids Python boolean checks on `alpha`. Returns ------- params : jax.numpy.ndarray, shape (n_features,) Fitted logistic regression coefficients (dtype `float32`). Notes ----- - The number of gradient-descent iterations and learning rate are currently hard-coded (1000 iterations, learning rate 0.1). If you need tunable training behaviour, consider adding explicit arguments or switching to an optimizer library such as `optax`. 
""" # Ensure inputs are JAX arrays with float dtype X = jnp.asarray(X) Y = jnp.asarray(Y).astype(jnp.float32) # Define binary cross-entropy (BCE) loss from logits (numerically stable) # Uses formulation: BCE(logits, y) = max(l,0) - l*y + log(1 + exp(-abs(l))) def logistic_loss(parms, x, y, alp): logits = x @ parms # stable per-example loss per_example = ( jnp.maximum(logits, 0) - logits * y + jnp.log1p(jnp.exp(-jnp.abs(logits))) ) loss = jnp.mean(per_example) # Always add regularization term (alpha can be 0) loss = loss + alp * jnp.sum(parms**2) return loss # Initial guess for parameters (weights) init_params = jnp.zeros(X.shape[1], dtype=jnp.float32) # Compute the gradient of the loss function # Use a traced loop via lax.fori_loop for efficient JIT compilation def body(i, parms): grads = grad(logistic_loss)(parms, X, Y, alpha) return parms - 0.1 * grads num_iters = 1000 params = lax.fori_loop(0, num_iters, body, init_params) return params
def fit_linear_regression(
    X: jnp.ndarray, Y: jnp.ndarray, alpha: float = 0.0, unpenalized_index=-1
) -> jnp.ndarray:
    """
    Fit a (ridge) linear regression by solving the normal equations with JAX.

    The solved system is ``(X.T @ X + alpha * P) @ params = X.T @ Y``, where
    ``P`` is the identity matrix with the diagonal entry for
    `unpenalized_index` set to zero.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Transformed feature matrix, including an intercept column if desired.
    Y : array-like of shape (n_samples,) or (n_samples, n_targets)
        Target variable.
    alpha : float, optional (default=0)
        Ridge (L2) penalty strength. ``0`` gives ordinary least squares.
    unpenalized_index : int or None, optional (default=-1)
        Column of `X` excluded from the penalty, typically the intercept.
        The default ``-1`` reproduces the historical behaviour of treating the
        LAST column as the intercept. Pass the intercept's index if it lives
        elsewhere, or ``None`` to penalise every coefficient.

    Returns
    -------
    params : jax.numpy.ndarray of shape (n_features,) or (n_features, n_targets)
        Fitted linear regression coefficients.
    """
    penalty = jnp.eye(X.shape[1])
    if unpenalized_index is not None:
        # Leave the intercept (or chosen column) unshrunk.
        penalty = penalty.at[unpenalized_index, unpenalized_index].set(0)
    return solve(X.T @ X + alpha * penalty, X.T @ Y)
def cross_val_mse(
    X: jnp.ndarray, Y: jnp.ndarray, model_fn, num_folds: int, predict_fn=None
) -> float:
    """
    Perform contiguous (unshuffled) k-fold cross-validation and return the
    average validation mean squared error.

    Fold sizes follow the usual k-fold convention: when ``n_samples`` is not
    divisible by `num_folds`, the first ``n_samples % num_folds`` folds get
    one extra sample. Every sample is therefore used for validation exactly
    once.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Transformed feature matrix, including an intercept term if desired.
    Y : array-like of shape (n_samples,) or (n_samples, n_targets)
        Target variable.
    model_fn : callable
        ``model_fn(X_train, Y_train) -> params``.
    num_folds : int
        Number of folds. Must satisfy ``2 <= num_folds <= n_samples``.
    predict_fn : callable, optional
        ``predict_fn(X_val, params) -> predictions``. Defaults to the linear
        predictor ``X_val @ params``.

    Returns
    -------
    float
        Mean of the per-fold validation MSEs (a JAX scalar).

    Raises
    ------
    ValueError
        If `num_folds` is outside ``[2, n_samples]``.
    """
    n = X.shape[0]
    if not 2 <= num_folds <= n:
        raise ValueError(
            f"num_folds must be between 2 and the number of samples ({n}), "
            f"got {num_folds}."
        )
    if predict_fn is None:

        def predict_fn(X_val, params):
            return X_val @ params

    base_size, n_larger = divmod(n, num_folds)
    fold_mses = []
    start = 0
    for fold in range(num_folds):
        stop = start + base_size + (1 if fold < n_larger else 0)
        val_indices = jnp.arange(start, stop)
        train_indices = jnp.concatenate(
            [jnp.arange(0, start), jnp.arange(stop, n)]
        )

        params = model_fn(X[train_indices], Y[train_indices])
        preds = predict_fn(X[val_indices], params)
        fold_mses.append(jnp.mean((Y[val_indices] - preds) ** 2))
        start = stop

    return jnp.mean(jnp.array(fold_mses))
def fit_model_jax(
    X: jnp.ndarray,
    Y: jnp.ndarray,
    binary_response: bool = False,
) -> tuple[jnp.ndarray, float]:
    """
    Fit a linear or logistic nuisance model and attach a 3-fold CV diagnostic.

    Parameters
    ----------
    X : jax.numpy.ndarray
        Design matrix, shape ``(n_samples, n_features)``. Must have more rows
        than columns.
    Y : jax.numpy.ndarray
        Target vector, shape ``(n_samples,)``.
    binary_response : bool, optional
        Chooses `fit_logistic_regression` when True, `fit_linear_regression`
        otherwise. This is a plain Python branch; inside JIT-compiled code,
        pass the model function itself as a static argument instead (as
        `bootstrap_model_fitting_jax` does).

    Returns
    -------
    params_outcome : jax.numpy.ndarray
        Fitted coefficients flattened to shape ``(n_features,)``.
    model_mse : float
        Mean validation MSE from `cross_val_mse` with 3 folds, or ``np.nan``
        (with a warning) if cross-validation raised.

    Notes
    -----
    NOTE(review): `cross_val_mse` scores the linear predictor
    ``X_val @ params``. For logistic fits that is the logit scale, not a
    probability, so interpret the diagnostic accordingly.
    """
    n_samples, n_features = X.shape[0], X.shape[1]
    assert n_samples > n_features, "need more samples than features"

    model_fn = fit_logistic_regression if binary_response else fit_linear_regression
    fitted = model_fn(X=X, Y=Y)

    # The CV score is diagnostic only; never let it abort the fit.
    try:
        cv_mse = cross_val_mse(X, Y, model_fn, num_folds=3)
    except Exception as err:
        warnings.warn(f"Cross-validation failed: {str(err)}")
        cv_mse = np.nan

    return jnp.asarray(fitted).reshape(-1), cv_mse
@partial(jit, static_argnames=["outcome_model_fn", "treatment_model_fn"])
def bootstrap_model_fitting_jax(
    Y: jnp.ndarray,
    T: jnp.ndarray,
    tf_X: jnp.ndarray,
    tf_XT: jnp.ndarray,
    outcome_model_fn,
    treatment_model_fn,
    key,
):
    """
    Fit outcome and treatment models on one bootstrap resample of the data.

    Rows are resampled with replacement using JAX random primitives. Resampling
    is repeated until the draw contains enough distinct rows for both models
    to be estimable. The supplied model functions are then fitted on the
    resampled data.

    Parameters
    ----------
    Y : jnp.ndarray
        Outcome array of shape ``(n_samples,)``.
    T : jnp.ndarray
        Treatment array of shape ``(n_samples,)``.
    tf_X : jnp.ndarray
        Transformed covariates for the treatment model,
        shape ``(n_samples, n_features)``.
    tf_XT : jnp.ndarray
        Transformed covariates for the outcome model,
        shape ``(n_samples, n_features_outcome)``.
    outcome_model_fn : callable
        ``fn(X, Y) -> (params, mse)``, fitted on ``(tf_XT, Y)``.
    treatment_model_fn : callable
        Same interface, fitted on ``(tf_X, T)``.
    key : jax.random.PRNGKey
        PRNG key used for resampling.

    Returns
    -------
    resampled_params_outcome, resampled_params_treatment : tuple
        Fitted parameters of the outcome and treatment models on the resample.

    Raises
    ------
    AssertionError
        If the inputs disagree on the number of rows, or if there are not more
        rows than ``max(n_features, n_features_outcome) + 1``.

    Notes
    -----
    - The function is JIT-compiled with `outcome_model_fn` and
      `treatment_model_fn` declared static via ``static_argnames``. They must
      therefore be hashable Python callables (``functools.partial`` wrappers
      are fine), not JAX arrays.
    - Shapes are static under ``jit``, so the assertions run once, at trace
      time.
    """
    n_samples = Y.shape[0]
    # Catch mismatched inputs here rather than relying on indexing to notice.
    assert (
        T.shape[0] == tf_X.shape[0] == tf_XT.shape[0] == n_samples
    ), "Y, T, tf_X and tf_XT must have the same number of rows"

    _, subkey = random.split(key)

    # Both nuisance models are fitted on the same resample, so it needs more
    # distinct rows than the wider of the two design matrices has columns.
    min_sample_size_needed_for_estimation = max(tf_X.shape[1], tf_XT.shape[1]) + 1
    assert (
        n_samples > min_sample_size_needed_for_estimation
    ), f"need more samples than {min_sample_size_needed_for_estimation}"

    resampled_indices = resample_until_enough_unique(
        subkey, n_samples, min_sample_size_needed_for_estimation
    )

    resampled_Y = Y[resampled_indices]
    resampled_T = T[resampled_indices]
    resampled_tf_X = tf_X[resampled_indices]
    resampled_tf_XT = tf_XT[resampled_indices]

    resampled_params_outcome, _ = outcome_model_fn(resampled_tf_XT, resampled_Y)
    resampled_params_treatment, _ = treatment_model_fn(resampled_tf_X, resampled_T)

    return resampled_params_outcome, resampled_params_treatment
def resample_until_enough_unique(subkey, n_resamples, min_sample_size):
    """
    Draw bootstrap indices (with replacement) until at least `min_sample_size`
    distinct indices are present.

    Parameters
    ----------
    subkey : jax.random.PRNGKey
        PRNG key for JAX random operations. It is split before use and never
        consumed directly, so no key is reused across draws.
    n_resamples : int
        Population size and sample size: indices are drawn from
        ``range(n_resamples)``, ``n_resamples`` at a time. Must be concrete,
        because it determines the output shape.
    min_sample_size : int
        Minimum number of distinct indices required.

    Returns
    -------
    resampled_indices : jax.numpy.ndarray
        Integer array of shape ``(n_resamples,)`` containing at least
        `min_sample_size` distinct values.

    Raises
    ------
    ValueError
        If both sizes are concrete integers and
        ``min_sample_size > n_resamples``. In that case the loop could never
        terminate.

    Notes
    -----
    The retry loop is a ``jax.lax.while_loop``, so the function can be traced
    under ``jit``. If `min_sample_size` is a traced value, the impossibility
    check above cannot run, and the caller must guarantee
    ``min_sample_size <= n_resamples``.
    """
    concrete_sizes = isinstance(min_sample_size, (int, np.integer)) and isinstance(
        n_resamples, (int, np.integer)
    )
    if concrete_sizes and min_sample_size > n_resamples:
        raise ValueError(
            f"min_sample_size ({min_sample_size}) cannot exceed "
            f"n_resamples ({n_resamples})."
        )

    def draw(key):
        return random.choice(key, n_resamples, shape=(n_resamples,), replace=True)

    def count_unique(x):
        # Sorted order makes each new distinct value a change point.
        x = jnp.sort(x)
        return 1 + (x[1:] != x[:-1]).sum()

    def condition_fn(state):
        _, indices = state
        return jnp.asarray(count_unique(indices) < min_sample_size)

    def body_fn(state):
        key, _ = state
        key, draw_key = random.split(key)
        return key, draw(draw_key)

    # Split before the first draw. The carried key is then independent of
    # the key consumed by that draw.
    key, draw_key = random.split(subkey)
    _, resampled_indices = lax.while_loop(
        condition_fn, body_fn, (key, draw(draw_key))
    )
    return resampled_indices