from causal_falsify.algorithms.abstract import AbstractFalsificationAlgorithm
import numpy as np
import pandas as pd
from typing import List
from causal_falsify.utils.cond_indep import (
kcit_rbf,
fisherz,
)
# (Stray "[docs]" Sphinx source-link marker removed: HTML-extraction residue, not code.)
class TransportabilityTest(AbstractFalsificationAlgorithm):
    """Joint falsification test of transportability and unconfoundedness.

    Tests whether the outcome is conditionally independent of the source
    indicator given the covariates and the treatment.
    """

    # Names accepted for ``cond_indep_test``; kept as a tuple so membership is
    # checked with ``==`` exactly like the original if/elif chain.
    _SUPPORTED_TESTS = ("kcit_rbf", "fisherz")

    def __init__(
        self,
        cond_indep_test: str = "kcit_rbf",
        max_sample_size: int | float | None = np.inf,
        seed: int | None = None,
    ) -> None:
        """
        Transportability-based test.

        Inspired by the benchmarking framework in:
        Dahabreh et al., 2024. "Using Trial and Observational Data to Assess
        Effectiveness: Trial Emulation, Transportability, Benchmarking, and
        Joint Analysis"

        Performs a joint test for transportability and unconfoundedness across
        sources. A rejection indicates that at least one of the two conditions
        is likely violated.

        Parameters
        ----------
        cond_indep_test : str
            Conditional independence test to use. Options are:
            - 'kcit_rbf': Kernel-based conditional independence test with RBF kernel.
            - 'fisherz': Fisher z-transform test for linear conditional independence.
        max_sample_size : int, optional
            Maximum number of samples to use during testing. Helps control
            runtime for large datasets. Defaults to ``np.inf`` (use all
            samples); ``None`` is accepted as a synonym for "no limit".
        seed : int, optional
            Seed for the random generator used when subsampling data (only
            relevant if max_sample_size is smaller than the dataset size).

        Raises
        ------
        ValueError
            If `cond_indep_test` is not one of the supported options, or if
            `max_sample_size` is not larger than zero.
        """
        super().__init__()

        # The docstring historically advertised ``None`` as "use all samples".
        if max_sample_size is None:
            max_sample_size = np.inf
        # ``not x > 0`` (rather than ``x <= 0``) also rejects NaN.
        if not max_sample_size > 0:
            raise ValueError("max_sample_size must be larger than zero")
        # Fail at construction time, as documented, instead of only in test().
        if cond_indep_test not in self._SUPPORTED_TESTS:
            raise ValueError(f"Unsupported cond_indep_test: {cond_indep_test}")

        self.cond_indep_test = cond_indep_test
        self.max_sample_size_test = max_sample_size
        self.rng = np.random.default_rng(seed)

    def _select_test_func(self):
        """Return the conditional independence test named by ``self.cond_indep_test``."""
        if self.cond_indep_test not in self._SUPPORTED_TESTS:
            raise ValueError(f"Unsupported cond_indep_test: {self.cond_indep_test}")
        # Looked up at call time so the module-level names are resolved lazily.
        if self.cond_indep_test == "kcit_rbf":
            return kcit_rbf
        return fisherz

    def test(
        self,
        data: pd.DataFrame,
        covariate_vars: List[str],
        treatment_var: str,
        outcome_var: str,
        source_var: str,
    ) -> float:
        """
        Perform falsification test for joint test of unconfoundedness and transportability.

        Parameters
        ----------
        data : pd.DataFrame
            DataFrame containing all required columns.
        covariate_vars : List[str]
            Covariate column names to condition on.
        treatment_var : str
            Treatment column name.
        outcome_var : str
            Outcome column name.
        source_var : str
            Source/environment indicator column name.

        Returns
        -------
        float
            p-value of the test of outcome independent of source given
            covariates and treatment; a low p-value suggests that
            unconfoundedness and/or transportability may be violated.

        Raises
        ------
        ValueError
            If a required column is missing from `data`, or if the configured
            conditional independence test is not supported.
        """
        # Accept any sequence of names (e.g. a tuple), not only a list.
        covariate_vars = list(covariate_vars)

        # Validate required columns
        required_cols = set(covariate_vars + [treatment_var, outcome_var, source_var])
        missing = required_cols.difference(data.columns)
        if missing:
            raise ValueError(f"Missing columns in data: {missing}")

        # Resolve the test up front so a bad configuration fails before any work.
        test_func = self._select_test_func()

        # Extract arrays for the test; double brackets keep the arrays 2-D.
        outcome = data[[outcome_var]].values  # shape (n_samples, 1)
        treatment = data[[treatment_var]].values  # shape (n_samples, 1)
        source = data[[source_var]].values  # shape (n_samples, 1)
        covariates = data[covariate_vars].values  # shape (n_samples, n_covariates)

        # Subsample if necessary
        if outcome.shape[0] > self.max_sample_size_test:
            outcome, source, covariates, treatment = self.subsample_data(
                outcome, source, covariates, treatment
            )

        # Test if outcome is independent of source conditional on covariates and treatment
        conditioning_vars = np.hstack([covariates, treatment])
        return test_func(outcome, source, conditioning_vars)

    def subsample_data(
        self,
        outcome: np.ndarray,
        source: np.ndarray,
        covariates: np.ndarray,
        treatment: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Subsample data to limit the number of samples while preserving the source distribution.

        Parameters
        ----------
        outcome : np.ndarray of shape (n_samples, 1)
            Outcome variable for each sample.
        source : np.ndarray of shape (n_samples, 1)
            Source indicator for each sample.
        covariates : np.ndarray of shape (n_samples, n_covariates)
            Observed covariates for each sample.
        treatment : np.ndarray of shape (n_samples, 1)
            Treatment assignment for each sample.

        Returns
        -------
        outcome_sub : np.ndarray of shape (n_subsamples, 1)
            Subsampled outcomes.
        source_sub : np.ndarray of shape (n_subsamples, 1)
            Subsampled source indicators.
        covariates_sub : np.ndarray of shape (n_subsamples, n_covariates)
            Subsampled covariates.
        treatment_sub : np.ndarray of shape (n_subsamples, 1)
            Subsampled treatment assignments.

        Notes
        -----
        - The method ensures that each source is represented approximately
          proportionally to its frequency in the original data.
        - If the total number of selected samples exceeds `self.max_sample_size_test`,
          a random subset of the selected samples is drawn to enforce the limit.
        - If the limit is unset/non-finite or not smaller than the number of
          rows, the inputs are returned unchanged.
        """
        limit = self.max_sample_size_test
        n_total = source.shape[0]

        # Nothing to cut: avoids int(inf) overflow and pointless reshuffling.
        if limit is None or not np.isfinite(limit) or limit >= n_total:
            return outcome, source, covariates, treatment

        # rng.choice requires an integer size; a finite float limit is truncated.
        limit = int(limit)

        source_flat = source.ravel()  # hoisted: invariant across the loop below
        unique_sources, counts = np.unique(source_flat, return_counts=True)
        proportions = counts / counts.sum()

        per_source_picks = []
        for src_value, proportion in zip(unique_sources, proportions):
            src_indices = np.flatnonzero(source_flat == src_value)
            # Each source's quota is its share of the limit, capped by its size.
            n_samples = min(len(src_indices), int(np.round(proportion * limit)))
            per_source_picks.append(
                self.rng.choice(src_indices, n_samples, replace=False)
            )
        sampled_indices = np.concatenate(per_source_picks).astype(np.intp)

        # Rounding the per-source quotas can overshoot the limit by a few rows.
        if len(sampled_indices) > limit:
            sampled_indices = self.rng.choice(sampled_indices, limit, replace=False)

        return (
            outcome[sampled_indices, :],
            source[sampled_indices, :],
            covariates[sampled_indices, :],
            treatment[sampled_indices, :],
        )