import warnings
from typing import Dict, List, Any
import numpy as np
import pandas as pd
from scipy.stats import combine_pvalues
from sklearn.exceptions import ConvergenceWarning
from causal_falsify.algorithms.abstract import AbstractFalsificationAlgorithm
from causal_falsify.utils.cond_indep import (
kcit_rbf,
fisherz,
)
warnings.simplefilter("ignore", ConvergenceWarning)
[docs]
class HGIC(AbstractFalsificationAlgorithm):
def __init__(
self,
cond_indep_test: str = "kcit_rbf",
max_tests: int = -1,
min_test_sample_size: int = 25,
method_pval_combination: str = "tippett",
) -> None:
"""
Hierarchical Graphical Independence Constraint (HGIC) algorithm.
Implements the method from:
"Detecting Hidden Confounding in Observational Data Using Multiple Environments"
Karlsson and Krijthe, NeurIPS 2023
(https://arxiv.org/abs/2205.13935)
Jointly tests for independence between causal mechanisms and unconfoundedness across sources.
A rejection falsifies both conditions jointly.
Parameters
----------
cond_indep_test : str, default="kcit_rbf"
Conditional independence test to use. Options: "kcit_rbf" or "fisherz".
max_tests : int, default=-1
Maximum number of pairwise tests to perform. Use -1 for unlimited.
min_test_sample_size : int, default=25
Minimum number of sources per test.
method_pval_combination : str, default="tippett"
Method to combine p-values. Options: "tippett", "fisher", or "stouffer".
"""
super().__init__()
if max_tests == 0:
raise ValueError("max_tests must be non-zero (use -1 for unlimited).")
self.cond_indep_test = cond_indep_test
self.max_tests = max_tests
self.min_test_sample_size = min_test_sample_size
self.method_pval_combination = method_pval_combination
# self.feature_scaling = self.independence_test_args.pop("feature_scaling", False)
[docs]
def test(
self,
data: pd.DataFrame,
covariate_vars: List[str],
treatment_var: str,
outcome_var: str,
source_var: str,
) -> float:
"""
Perform falsification test for joint test of unconfoundedness and independence of causal mechanisms.
Parameters
----------
data : pd.DataFrame
DataFrame containing all required columns.
covariate_vars : List[str]
Covariate column names to condition on.
treatment_var : str
Treatment column name.
outcome_var : str
Outcome column name.
source_var : str
Source/environment indicator column name.
Returns
-------
float
p-value of the falsification test; low p-value implies unmeasured confounding may be present.
"""
# Validate required columns
required_cols = set(covariate_vars + [treatment_var, outcome_var, source_var])
missing = required_cols.difference(data.columns)
if missing:
raise ValueError(f"Missing columns in data: {missing}")
# Partition data by source variable
grouped_data = {
str(source_label): df for source_label, df in data.groupby(source_var)
}
# Sort sources by sample size
sample_sizes = sorted(
[(source_label, len(df)) for source_label, df in grouped_data.items()],
key=lambda x: x[1],
reverse=True,
)
if len(sample_sizes) < 2:
raise ValueError("At least two sources are required for the test.")
max_samples_available = sample_sizes[1][1]
max_tests = self.max_tests if self.max_tests > 0 else max_samples_available
pval_list, samples_used, df_list = [], [], []
n = 0
while n < max_samples_available - 1:
valid_sources = [source for source, size in sample_sizes if size > n + 1]
if len(valid_sources) < self.min_test_sample_size:
n += 2
continue
df_tmp = self._construct_pair_df(
grouped_data,
covariate_vars,
treatment_var,
outcome_var,
valid_sources,
n,
)
df_list.append(df_tmp)
n += 2
if len(df_list) >= max_tests:
break
if not df_list:
raise ValueError("No valid sample pairs available for testing.")
cond_vars = (
[f"{v}_i" for v in covariate_vars]
+ [f"{v}_j" for v in covariate_vars]
+ ["treatment_i"]
)
if self.cond_indep_test == "kcit_rbf":
test_func = kcit_rbf
elif self.cond_indep_test == "fisherz":
test_func = fisherz
else:
raise ValueError(f"Unsupported cond_indep_test: {self.cond_indep_test}")
for df in df_list:
pval = test_func(
df["treatment_j"].values.reshape(-1, 1),
df["outcome_i"].values.reshape(-1, 1),
df[cond_vars].values,
)
if np.isnan(pval):
print("Warning: Computed NaN p-value.")
else:
pval_list.append(pval)
samples_used.append(len(df))
if not pval_list:
raise RuntimeError("No valid p-values were computed.")
weights = (
[np.sqrt(n) for n in samples_used]
if self.method_pval_combination == "stouffer"
else None
)
_, combined_pval = combine_pvalues(
pval_list, method=self.method_pval_combination, weights=weights
)
return combined_pval
def _construct_pair_df(
self,
data: Dict[str, pd.DataFrame],
covariates: List[str],
treatment_var: str,
outcome_var: str,
sources: List[str],
index: int,
) -> pd.DataFrame:
"""
Construct a pairwise sample DataFrame for statistical testing.
Parameters
----------
data : Dict[str, pd.DataFrame]
Dictionary mapping source/environment labels to their respective DataFrames.
covariates : List[str]
List of observed covariate column names.
treatment_var : str
Name of the treatment variable column.
outcome_var : str
Name of the outcome variable column.
sources : List[str]
List of source/environment labels to sample from.
index : int
Starting index for sampling within each source.
Returns
-------
pd.DataFrame
Constructed sample DataFrame containing paired samples for testing.
"""
get = lambda source, col: data[source][col].values
df_data = {
"treatment_i": np.array(
[get(source, treatment_var)[index] for source in sources]
),
"outcome_i": np.array(
[get(source, outcome_var)[index] for source in sources]
),
"treatment_j": np.array(
[get(source, treatment_var)[index + 1] for source in sources]
),
"outcome_j": np.array(
[get(source, outcome_var)[index + 1] for source in sources]
),
}
for var in covariates:
df_data[f"{var}_i"] = np.array(
[get(source, var)[index] for source in sources]
)
df_data[f"{var}_j"] = np.array(
[get(source, var)[index + 1] for source in sources]
)
return pd.DataFrame(data=df_data)