# Source code for causal_falsify.algorithms.mint

import warnings
from functools import partial
from typing import Dict, List, Optional

import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax import config, random, vmap

# Enable 64-bit (double-precision) arrays in JAX. This runs at import time and
# changes the shared `jax.config` object, so it is not local to this module.
# NOTE(review): presumably required for numerically stable model fitting in
# causal_falsify.utils.mint — confirm before removing.
config.update("jax_enable_x64", True)

from causal_falsify.algorithms.abstract import AbstractFalsificationAlgorithm
from causal_falsify.utils.mint import (
    create_polynomial_representation,
    bootstrap_model_fitting_jax,
    fit_model_jax,
    permutation_independence_test,
    bootstrapped_permutation_independence_test,
)


class MINT(AbstractFalsificationAlgorithm):
    """Mechanism INdependent Test (MINT) falsification algorithm.

    From "Falsification of Unconfoundedness by Testing Independence of Causal
    Mechanisms", Karlsson and Krijthe, ICML 2025
    (https://arxiv.org/abs/2502.06231).
    """

    def __init__(
        self,
        feature_representation: str = "linear",
        feature_representation_params: Optional[dict] = None,
        binary_treatment: bool = False,
        binary_outcome: bool = False,
        min_samples_per_env: int = 25,
        independence_test_args: Optional[dict] = None,
        n_bootstraps: Optional[int] = 1000,
    ) -> None:
        """
        Configure a Mechanism INdependent Test (MINT).

        MINT is a joint test of (i) unconfoundedness and (ii) independence of
        the causal mechanisms (outcome model and treatment model) across
        sources. A rejection falsifies this joint hypothesis, i.e. at least
        one of the two conditions is violated.

        Parameters
        ----------
        feature_representation : str, optional
            Feature representation to use: "linear" (identity map, default)
            or "poly" (built with ``create_polynomial_representation``).
        feature_representation_params : dict, optional
            Keyword arguments for the feature representation (e.g. "degree"
            for "poly"). The dict is copied, so the caller's object is never
            modified. Default: no extra parameters.
        binary_treatment : bool, optional
            Whether the treatment is binary 0/1. Default is False, in which
            case a continuous treatment is assumed.
        binary_outcome : bool, optional
            Whether the outcome is binary 0/1. Default is False, in which
            case a continuous outcome is assumed.
        min_samples_per_env : int, optional
            Minimum number of samples an environment must have to be used;
            smaller environments are skipped. Default is 25.
        independence_test_args : dict, optional
            Extra keyword arguments forwarded to the independence test
            (copied on construction). Default: no extra arguments.
        n_bootstraps : int or None, optional
            Number of bootstrap iterations (default 1000). If None (or 0),
            no bootstrapping is used; bootstrapping is however strongly
            recommended.
        """
        self.feature_representation = feature_representation
        # Copy the dicts: avoids the shared-mutable-default pitfall and keeps
        # this instance from ever mutating objects owned by the caller.
        self.feature_representation_params = dict(feature_representation_params or {})
        self.binary_treatment = binary_treatment
        self.binary_outcome = binary_outcome
        self.min_samples_per_env = min_samples_per_env
        self.independence_test_args = dict(independence_test_args or {})
        self.n_bootstraps = n_bootstraps
        # Diagnostics of the most recent successful test() call, exposed via
        # get_diagnostics(); None until test() has completed once.
        self._last_model_fit_diagnostics = None

    @staticmethod
    def _add_intercept(term):
        """Append a column of ones (intercept term) to a 2-D feature matrix."""
        return jnp.hstack([term, jnp.ones((term.shape[0], 1))])

    def test(
        self,
        data: pd.DataFrame,
        covariate_vars: List[str],
        treatment_var: str,
        outcome_var: str,
        source_var: str,
    ) -> float:
        """
        Perform the joint falsification test of unconfoundedness and
        independence of causal mechanisms.

        For every source/environment with at least ``min_samples_per_env``
        rows, an outcome model (outcome ~ covariates + treatment) and a
        treatment model (treatment ~ covariates) are fitted; the fitted
        coefficients are then tested for independence across environments.

        Parameters
        ----------
        data : pandas.DataFrame
            DataFrame containing all data from all environments.
        covariate_vars : list of str
            List of covariate column names (may be empty).
        treatment_var : str
            Name of the treatment column.
        outcome_var : str
            Name of the outcome column.
        source_var : str
            Name of the source/environment column.

        Returns
        -------
        float
            p-value of the independence test. A low p-value is evidence
            against the joint hypothesis, e.g. it may indicate unmeasured
            confounding.

        Raises
        ------
        ValueError
            If required columns are missing, if a binary flag is set but the
            corresponding column is not 0/1, if the feature representation is
            invalid, or if no environment has enough samples.
        """
        # Normalise to a list so pandas treats it as a set of column labels.
        if isinstance(covariate_vars, str):
            covariate_vars = [covariate_vars]
        covariate_vars = list(covariate_vars)

        # Validate required columns
        required_cols = set(covariate_vars) | {treatment_var, outcome_var, source_var}
        missing = required_cols.difference(data.columns)
        if missing:
            raise ValueError(f"Missing columns in data: {missing}")

        # Extract arrays for the test
        outcome = data[[outcome_var]].values  # shape (n_samples, 1)
        treatment = data[[treatment_var]].values  # shape (n_samples, 1)
        # Flattened to 1-D so per-environment boolean masks are always 1-D
        # (a squeezed (n, 1) mask collapses to 0-d when n == 1).
        source = data[[source_var]].values.ravel()  # shape (n_samples,)
        covariates = data[covariate_vars].values  # shape (n_samples, n_covariates)

        # Validate if binary_treatment/binary_outcome is set to True
        if self.binary_treatment and not np.isin(treatment, [0, 1]).all():
            raise ValueError(
                "binary_treatment is True but treatment contains values other than 0 and 1"
            )
        if self.binary_outcome and not np.isin(outcome, [0, 1]).all():
            raise ValueError(
                "binary_outcome is True but outcome contains values other than 0 and 1"
            )

        fit_treatment_model_func = partial(
            fit_model_jax, binary_response=self.binary_treatment
        )
        fit_outcome_model_func = partial(
            fit_model_jax, binary_response=self.binary_outcome
        )

        # One feature map serves both models (they use identical settings).
        phi = self.get_feature_representation()

        # Bootstrap keys are loop-invariant, so build them once.
        # NOTE(review): every environment reuses the same keys (seed 0), so
        # equally sized environments presumably get identical resampling
        # draws — confirm this coupling is intended.
        bootstrap_keys = (
            random.split(random.PRNGKey(0), self.n_bootstraps)
            if self.n_bootstraps
            else None
        )

        coef_outcome_mech, coef_treatment_mech = [], []
        resampled_coef_outcome_mech, resampled_coef_treatment_mech = [], []
        model_fit_diagnostics = {
            "source_label": [],
            "outcome_model_mse": [],
            "treatment_model_mse": [],
        }

        unique_sources = np.unique(source)
        for source_label in unique_sources:
            source_index = source == source_label
            # Skip environments that are too small to fit the models.
            if np.count_nonzero(source_index) < self.min_samples_per_env:
                continue

            covariates_source = jnp.array(covariates[source_index, :])
            treatment_source = jnp.array(treatment[source_index, :])
            outcome_source = jnp.array(outcome[source_index, :])

            # Features (with intercept) for the treatment model: covariates only;
            # for the outcome model: covariates and treatment. With no
            # covariates the transformed covariate block may be empty.
            tf_outcome_source = self._add_intercept(phi(covariates_source))
            tf_outcome_treatment_source = self._add_intercept(
                phi(jnp.concatenate([covariates_source, treatment_source], axis=1))
            )

            params_outcome_mech, outcome_model_mse = fit_outcome_model_func(
                X=tf_outcome_treatment_source, Y=outcome_source
            )
            params_treatment_mech, treatment_model_mse = fit_treatment_model_func(
                X=tf_outcome_source,
                Y=treatment_source,
            )

            coef_outcome_mech.append(params_outcome_mech)
            coef_treatment_mech.append(params_treatment_mech)
            model_fit_diagnostics["source_label"].append(source_label)
            model_fit_diagnostics["outcome_model_mse"].append(outcome_model_mse)
            model_fit_diagnostics["treatment_model_mse"].append(treatment_model_mse)

            if bootstrap_keys is not None:
                # Map only over the keys; data and fit functions are shared.
                resampled_params = vmap(
                    bootstrap_model_fitting_jax,
                    in_axes=(None, None, None, None, None, None, 0),
                )(
                    outcome_source,
                    treatment_source,
                    tf_outcome_source,
                    tf_outcome_treatment_source,
                    fit_outcome_model_func,
                    fit_treatment_model_func,
                    bootstrap_keys,
                )
                resampled_coef_outcome_mech.append(resampled_params[0])
                resampled_coef_treatment_mech.append(resampled_params[1])

        # Validate that at least one environment has enough samples
        if not coef_outcome_mech:
            raise ValueError(
                f"No environments have at least {self.min_samples_per_env} samples. "
                f"Found {len(unique_sources)} environment(s) but all were skipped. "
                f"Either reduce min_samples_per_env or provide more data per environment."
            )

        # One row of coefficients per retained environment.
        coef_outcome_mech = np.array(jnp.vstack(coef_outcome_mech))
        coef_treatment_mech = np.array(jnp.vstack(coef_treatment_mech))

        if bootstrap_keys is not None:
            # Stack along axis 1 so the environment axis follows the bootstrap axis.
            resampled_coef_outcome_mech = np.array(
                jnp.stack(resampled_coef_outcome_mech, axis=1)
            )
            resampled_coef_treatment_mech = np.array(
                jnp.stack(resampled_coef_treatment_mech, axis=1)
            )
            pval = self.run_bootstrapped_independence_test(
                coef_outcome_mech,
                coef_treatment_mech,
                resampled_coef_outcome_mech,
                resampled_coef_treatment_mech,
            )
        else:
            pval = self.run_independence_test(coef_outcome_mech, coef_treatment_mech)

        # save diagnostics from this run
        self._last_model_fit_diagnostics = model_fit_diagnostics
        return pval

    def get_diagnostics(self) -> Optional[Dict]:
        """
        Return per-environment model-fit diagnostics from the most recent
        successful test() call.

        Returns
        -------
        dict or None
            Dict with keys "source_label", "outcome_model_mse" and
            "treatment_model_mse" (parallel lists, one entry per retained
            environment), or None if test() has not completed yet.
        """
        return self._last_model_fit_diagnostics

    def get_feature_representation(self):
        """
        Return a function mapping a feature matrix to its representation.

        Raises
        ------
        ValueError
            If ``feature_representation`` is not "linear" or "poly".

        Returns
        -------
        Callable
            Feature map. For "poly", a missing "degree" defaults to 3 (with a
            warning); the instance's parameter dict is not modified.
        """
        if self.feature_representation == "linear":
            return lambda x: x
        if self.feature_representation == "poly":
            # Snapshot the params so the returned closure is unaffected by
            # later changes and nothing shared is ever mutated.
            params = dict(self.feature_representation_params)
            if "degree" not in params:
                warnings.warn(
                    "'degree' not provided in feature_representation_params. "
                    "Using default value of 3.",
                    stacklevel=2,
                )
                params["degree"] = 3
            return lambda x: create_polynomial_representation(x, **params)
        raise ValueError(
            f"Invalid feature representation: {self.feature_representation}"
        )

    def run_independence_test(self, data_x, data_y):
        """
        Run a permutation independence test between data_x and data_y.

        ``independence_test_args`` given at construction are forwarded as
        keyword arguments.

        Returns
        -------
        float
            p-value from the test.
        """
        return permutation_independence_test(
            data_x=data_x, data_y=data_y, **self.independence_test_args
        )

    def run_bootstrapped_independence_test(
        self, data_x, data_y, resampled_data_x, resampled_data_y
    ):
        """
        Run a bootstrapped permutation independence test between data_x and
        data_y, using their bootstrap resamples.

        ``independence_test_args`` given at construction are forwarded as
        keyword arguments.

        Returns
        -------
        float
            p-value from the test.
        """
        return bootstrapped_permutation_independence_test(
            data_x=data_x,
            data_y=data_y,
            resampled_data_x=resampled_data_x,
            resampled_data_y=resampled_data_y,
            **self.independence_test_args,
        )