Heterogeneous Treatment Effects
The Average Treatment Effect tells you whether a treatment works on average — but averages mask the distribution. A drug that helps young patients and harms elderly ones has a positive ATE if the young outnumber the elderly, but deploying it uniformly is a mistake. A push notification feature with a positive ATE for power users and a negative one for casual users shouldn't be shipped to everyone. Estimating the Conditional Average Treatment Effect — how the effect varies as a function of individual characteristics — is the next level of causal inference and is now central to product analytics, clinical trials, and policy evaluation. The tools for this (meta-learners, causal forests) live at the intersection of causal inference and machine learning, and require careful thinking about what can and cannot be identified from data.
Theory
[Diagram: S-, T-, and X-learner architectures. The X-learner combines imputed counterfactuals with propensity weighting; causal forests (Wager & Athey, 2018) extend beyond meta-learners with honest splitting and asymptotic confidence intervals.]
A single average treatment effect hides the most actionable information: who benefits most, who benefits least, and who might be harmed. Personalized medicine, targeted advertising, and policy design all depend on knowing not just that a treatment works on average, but where in the population it works and why. The diagram above shows how S-, T-, and X-learners differ in how they estimate this heterogeneity — the architectural choice determines which estimation errors dominate.
CATE and why ATE is insufficient
The Conditional Average Treatment Effect (CATE) is:

$$\tau(x) = \mathbb{E}[Y(1) - Y(0) \mid X = x]$$
The CATE must be a function of $x$ rather than a scalar because the treatment effect genuinely varies across units — different biological pathways, different price sensitivities, different baseline engagement levels. Estimating it nonparametrically requires recovering two conditional expectations simultaneously ($\mu_1(x) = \mathbb{E}[Y \mid T = 1, X = x]$ and $\mu_0(x) = \mathbb{E}[Y \mid T = 0, X = x]$), which means the difficulty scales with the complexity of the outcome surfaces, not just the average effect. This is why meta-learners matter: different learner architectures make different bias-variance tradeoffs depending on how imbalanced the treatment arms are and how smooth the CATE function is.
This is a function of covariates, not a scalar. The ATE is just its expectation: $\text{ATE} = \mathbb{E}[\tau(X)]$.
CATE informs three decisions beyond what ATE can:
- Targeting: deploy treatment only to units where $\hat{\tau}(x) > 0$ (positive net lift)
- Personalization: choose the best treatment per individual (argmax over arms)
- Mechanism understanding: which features drive heterogeneity, and why?
Under unconfoundedness ($(Y(0), Y(1)) \perp\!\!\!\perp T \mid X$) plus overlap ($0 < e(x) < 1$), the CATE is nonparametrically identified:

$$\tau(x) = \mathbb{E}[Y \mid T = 1, X = x] - \mathbb{E}[Y \mid T = 0, X = x]$$
The challenge is estimating two conditional expectations simultaneously with finite data, without overfitting or letting regularization bias the estimated difference.
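When covariates are discrete, the identification formula can be applied directly: estimate each conditional mean with a within-stratum average and difference them. A minimal sketch on simulated data (variable names are mine) — note how a positive ATE hides a stratum where the effect is negative:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000
x = rng.integers(0, 2, n)            # one binary covariate: two strata
t = rng.binomial(1, 0.5, n)          # randomized treatment
# true effect: +1.0 in stratum x=1, -0.5 in stratum x=0
y = t * np.where(x == 1, 1.0, -0.5) + rng.normal(0, 1, n)

for stratum in (0, 1):
    m = x == stratum
    # plug-in CATE: E[Y | T=1, X=x] - E[Y | T=0, X=x]
    tau_hat = y[m & (t == 1)].mean() - y[m & (t == 0)].mean()
    print(f"stratum x={stratum}: tau_hat = {tau_hat:+.2f}")

ate_hat = y[t == 1].mean() - y[t == 0].mean()
print(f"ATE = {ate_hat:+.2f}")       # positive overall, despite harm in x=0
```

The ATE lands near $0.5 \cdot 1.0 + 0.5 \cdot (-0.5) = +0.25$ even though half the population is harmed — exactly the averaging problem CATE estimation addresses.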
S-Learner
The S-Learner (single learner) fits one outcome model with treatment as a feature:

$$\hat{\mu}(x, t) \approx \mathbb{E}[Y \mid X = x, T = t]$$

CATE estimate: $\hat{\tau}(x) = \hat{\mu}(x, 1) - \hat{\mu}(x, 0)$
Problem: regularization can shrink the coefficient on T toward zero, making the model act as if treatment has no effect even when it does. Tree-based S-learners are particularly prone to this — if T has weak marginal importance, it may not appear in any splits, producing a flat $\hat{\tau}(x) = 0$ everywhere.
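Hand-rolling the S-learner makes the mechanism concrete: one supervised model with T appended to the features, then two predictions per unit. A sketch using sklearn's gradient booster on simulated data (model choice and names are illustrative):

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

rng = np.random.default_rng(1)
n = 2000
X = rng.normal(0, 1, (n, 3))
T = rng.binomial(1, 0.5, n)
Y = T * (X[:, 0] > 0) + X[:, 0] + rng.normal(0, 0.3, n)  # effect = 1 iff x0 > 0

# One model, treatment as an ordinary feature
mu = GradientBoostingRegressor(n_estimators=200, max_depth=3)
mu.fit(np.column_stack([X, T]), Y)

# CATE: predict under T=1 and T=0 for the same X, then difference
X1 = np.column_stack([X, np.ones(n)])
X0 = np.column_stack([X, np.zeros(n)])
tau_hat = mu.predict(X1) - mu.predict(X0)
print(tau_hat[X[:, 0] > 0].mean(), tau_hat[X[:, 0] <= 0].mean())
```

If the booster never splits on the treatment column, both predictions coincide and `tau_hat` collapses to zero — the failure mode described above.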
T-Learner
The T-Learner (two-model learner) fits separate outcome models per arm:

$$\hat{\mu}_1(x) \approx \mathbb{E}[Y \mid X = x, T = 1], \qquad \hat{\mu}_0(x) \approx \mathbb{E}[Y \mid X = x, T = 0]$$

CATE estimate: $\hat{\tau}(x) = \hat{\mu}_1(x) - \hat{\mu}_0(x)$
Problem: in regions of covariate space where the treatment group is small, $\hat{\mu}_1$ is poorly estimated and must extrapolate. The difference of two noisy estimates can be high-variance, and extrapolation errors compound.
X-Learner
The X-Learner (cross learner) was developed to handle imbalanced treatment (Künzel et al., 2019). It uses imputed counterfactuals to get more stable CATE estimates:
Stage 1: Fit base models $\hat{\mu}_0$, $\hat{\mu}_1$ (same as the T-Learner).
Stage 2: Impute individual treatment effects using cross-arm predictions:

$$\tilde{D}_i^1 = Y_i - \hat{\mu}_0(X_i) \ \text{(treated units)}, \qquad \tilde{D}_i^0 = \hat{\mu}_1(X_i) - Y_i \ \text{(control units)}$$

Fit CATE models on these imputed effects: $\hat{\tau}_1(x)$ on the treated units' $\tilde{D}_i^1$ and $\hat{\tau}_0(x)$ on the control units' $\tilde{D}_i^0$.
Stage 3: Combine using the propensity score $e(x) = P(T = 1 \mid X = x)$:

$$\hat{\tau}(x) = e(x)\,\hat{\tau}_0(x) + (1 - e(x))\,\hat{\tau}_1(x)$$
The propensity weighting upweights the CATE estimate from the group with more data. When the control group is much larger (common in A/B tests with 10% treatment), $e(x) \approx 0.1$, so $\hat{\tau}_1(x)$ — whose imputed effects lean on $\hat{\mu}_0$, estimated from abundant control data — dominates.
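The three stages translate directly to code. A from-scratch sketch with sklearn regressors under an imbalanced randomized design (since treatment is assigned at a fixed 10% rate by design, the propensity here is a known constant rather than a fitted model):

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

rng = np.random.default_rng(2)
n = 4000
X = rng.normal(0, 1, (n, 3))
T = rng.binomial(1, 0.1, n)                     # imbalanced: ~10% treated
Y = T * np.where(X[:, 0] > 0, 1.0, 0.0) + X[:, 1] + rng.normal(0, 0.3, n)

# Stage 1: per-arm outcome models (as in the T-learner)
mu0 = GradientBoostingRegressor().fit(X[T == 0], Y[T == 0])
mu1 = GradientBoostingRegressor().fit(X[T == 1], Y[T == 1])

# Stage 2: imputed individual effects via cross-arm predictions
d1 = Y[T == 1] - mu0.predict(X[T == 1])         # treated: observed - imputed control
d0 = mu1.predict(X[T == 0]) - Y[T == 0]         # control: imputed treated - observed
tau1 = GradientBoostingRegressor().fit(X[T == 1], d1)
tau0 = GradientBoostingRegressor().fit(X[T == 0], d0)

# Stage 3: propensity-weighted combination (e(x) = 0.1, known by design)
e = 0.1
tau_hat = e * tau0.predict(X) + (1 - e) * tau1.predict(X)
```

With only ~400 treated units, most of the weight (0.9) lands on `tau1`, which borrows strength from `mu0` fitted on ~3,600 controls — the mechanism the paragraph above describes.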
The diagram above shows the architectural differences between all three meta-learners.
Causal Forests
Causal Forests (Wager and Athey, 2018) extend random forests to estimate CATE with valid pointwise confidence intervals. Two key modifications to standard random forests:
Honest estimation: Each tree splits the data into two halves — a splitting sample (used to find splits) and an estimation sample (used to compute leaf means). This prevents overfitting that would invalidate inference.
Causal splitting criterion: Instead of minimizing the MSE of Y, splits maximize the heterogeneity of treatment effects across the child nodes. For a candidate split of a parent node $P$ into children $L$ and $R$:

$$\Delta(L, R) = \frac{n_L\, n_R}{n_P^2}\,(\hat{\tau}_L - \hat{\tau}_R)^2$$

where $\hat{\tau}_L$, $\hat{\tau}_R$ are CATE estimates in the left/right child nodes and $n_P = n_L + n_R$.
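Both modifications — honest estimation and the heterogeneity criterion — fit in a few lines for a single split. A simplified single-split "stump" sketch of my own (a real causal forest subsamples and aggregates many such trees): the threshold is chosen on the splitting half, and leaf effects are estimated on the held-out half.

```python
import numpy as np

def honest_stump(X, T, Y, feature=0, seed=0):
    """One honest split on one feature: choose the threshold on the
    splitting half, estimate child effects on the other half."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(Y))
    split_half, est_half = idx[: len(Y) // 2], idx[len(Y) // 2:]

    def node_tau(node_mask, half):      # difference-in-means effect in a node
        m = np.zeros(len(Y), dtype=bool)
        m[half] = True
        m &= node_mask
        return Y[m & (T == 1)].mean() - Y[m & (T == 0)].mean()

    best_s, best_crit = None, -np.inf
    for s in np.quantile(X[:, feature], np.linspace(0.2, 0.8, 13)):
        left = X[:, feature] <= s
        n_l = left[split_half].sum()
        n_r = len(split_half) - n_l
        # heterogeneity criterion: (n_L n_R / n_P^2) (tau_L - tau_R)^2
        crit = n_l * n_r / len(split_half) ** 2 * (
            node_tau(left, split_half) - node_tau(~left, split_half)) ** 2
        if crit > best_crit:
            best_s, best_crit = s, crit

    left = X[:, feature] <= best_s
    return best_s, node_tau(left, est_half), node_tau(~left, est_half)

rng = np.random.default_rng(3)
n = 4000
X = rng.normal(0, 1, (n, 2))
T = rng.binomial(1, 0.5, n)
Y = T * np.where(X[:, 0] > 0, 1.0, -1.0) + rng.normal(0, 0.5, n)
s, tau_left, tau_right = honest_stump(X, T, Y)
print(f"split at x0 <= {s:.2f}: tau_left = {tau_left:.2f}, tau_right = {tau_right:.2f}")
```

Because the leaf effects come from data the split never saw, they are not biased upward by the split search — the point of honesty.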
Asymptotic normality: The forest estimate $\hat{\tau}(x)$ is asymptotically normal and centered on the true $\tau(x)$ under mild regularity conditions:

$$\frac{\hat{\tau}(x) - \tau(x)}{\sigma_n(x)} \xrightarrow{d} \mathcal{N}(0, 1)$$

where $\sigma_n^2(x)$ is estimated via the infinitesimal jackknife. This gives valid confidence intervals without distributional assumptions — a major advantage over meta-learners, which require bootstrap inference.
Double Robustness via R-Learner / GRF: The Generalized Random Forest framework (Athey et al., 2019) uses a pseudo-outcome based on the Robinson decomposition:

$$Y_i - m(X_i) = \tau(X_i)\,(T_i - e(X_i)) + \varepsilon_i, \qquad m(x) = \mathbb{E}[Y \mid X = x], \quad e(x) = \mathbb{E}[T \mid X = x]$$

The CATE estimate minimizes:

$$\hat{\tau} = \arg\min_{\tau} \sum_i \big[(Y_i - \hat{m}(X_i)) - \tau(X_i)\,(T_i - \hat{e}(X_i))\big]^2$$

This is doubly robust: consistent even if one of $\hat{m}$ or $\hat{e}$ is misspecified.
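The decomposition turns CATE estimation into a residual-on-residual regression. A minimal R-learner sketch where I assume the CATE is linear in X (my simplification, not part of the framework) and the propensity is known (randomized 50/50); `beta` then recovers the CATE coefficients:

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import cross_val_predict

rng = np.random.default_rng(4)
n = 3000
X = rng.normal(0, 1, (n, 3))
T = rng.binomial(1, 0.5, n).astype(float)
tau_true = 0.5 + X[:, 0]                    # linear CATE
Y = tau_true * T + X[:, 1] + rng.normal(0, 0.5, n)

# Cross-fitted nuisance m(x) = E[Y|X]; e(x) is known by randomization
m_hat = cross_val_predict(GradientBoostingRegressor(), X, Y, cv=5)
e_hat = np.full(n, 0.5)

# Minimize sum [(Y - m) - tau(X)(T - e)]^2 with tau(x) = b0 + x'b
Y_res, T_res = Y - m_hat, T - e_hat
Z = np.column_stack([np.ones(n), X]) * T_res[:, None]   # basis scaled by T residual
beta, *_ = np.linalg.lstsq(Z, Y_res, rcond=None)
tau_hat = np.column_stack([np.ones(n), X]) @ beta
print(np.round(beta[:2], 2))   # intercept and x0 slope, near [0.5, 1.0]
```

Because the treatment residual multiplies the basis, errors in `m_hat` that don't correlate with `T_res` wash out — the orthogonality behind the robustness claim.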
Policy learning and uplift modeling
Given $\hat{\tau}(x)$, the optimal treatment policy is:

$$\pi^*(x) = \mathbb{1}\{\hat{\tau}(x) > c\}$$

where $c$ is the per-unit cost of treatment (or 0 if treatment is free). This uplift modeling problem asks: who should we target to maximize aggregate effect while respecting budget constraints?
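The decision rule is a one-liner; adding a budget constraint means keeping only the top-ranked units by net lift. A sketch (`targeting_policy` and its arguments are hypothetical names, not a library API):

```python
import numpy as np

def targeting_policy(cate_hat, cost=0.0, budget_frac=None):
    """Treat units whose estimated lift exceeds the cost; optionally keep
    only the top budget_frac fraction ranked by estimated net lift."""
    net = cate_hat - cost
    treat = net > 0
    if budget_frac is not None:
        k = int(np.floor(budget_frac * len(net)))
        cutoff = np.sort(net)[::-1][k - 1] if k > 0 else np.inf
        treat &= net >= cutoff
    return treat

cate_hat = np.array([0.9, 0.4, -0.2, 0.1, 0.6])
print(targeting_policy(cate_hat, cost=0.3))                   # [ True  True False False  True]
print(targeting_policy(cate_hat, cost=0.3, budget_frac=0.4))  # [ True False False False  True]
```

With a cost of 0.3 the rule drops the two units whose estimated lift doesn't cover it; the budget cap then keeps only the top 40% of the remaining candidates.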
Evaluation: Unlike outcome prediction, CATE cannot be evaluated directly — you never observe both $Y_i(1)$ and $Y_i(0)$ for the same unit. Valid evaluation approaches include:

- RATE (Rank-Weighted Average Treatment Effect): Sort units by $\hat{\tau}(x)$ and measure whether units ranked high actually respond better. Uses inverse propensity weighting to compute unbiased estimates per decile.
- AUTOC (Area Under the TOC Curve): Measures how much benefit is concentrated in the top-ranked fraction of units.
- Qini coefficient (analogous to AUC for uplift): the area between the uplift curve and the random-policy baseline.
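One concrete version of the ranking check: bin units by predicted CATE, then estimate the realized effect per bin with Horvitz-Thompson (IPW) weighting. A sketch assuming a randomized experiment with known propensity `e` (function name and the oracle ranking are illustrative):

```python
import numpy as np

def uplift_by_decile(Y, T, cate_hat, e=0.5, n_bins=10):
    """IPW estimate of the realized treatment effect within each bin of
    predicted CATE, best-ranked bin first. Decaying values = good ranking."""
    order = np.argsort(-cate_hat)
    bins = np.array_split(order, n_bins)
    out = []
    for b in bins:
        # Horvitz-Thompson difference: unbiased under known propensity e
        out.append(np.mean(T[b] * Y[b] / e - (1 - T[b]) * Y[b] / (1 - e)))
    return np.array(out)

rng = np.random.default_rng(5)
n = 20_000
tau = rng.normal(0.5, 0.5, n)                 # heterogeneous true effects
T = rng.binomial(1, 0.5, n)
Y = tau * T + rng.normal(0, 0.5, n)
lifts = uplift_by_decile(Y, T, cate_hat=tau)  # oracle ranking for illustration
print(np.round(lifts, 2))                     # roughly decreasing across deciles
```

A good CATE model produces a clearly decaying profile; a flat profile means the ranking carries no real heterogeneity signal.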
Walkthrough
Meta-learners with EconML
```python
import numpy as np
from econml.metalearners import SLearner, TLearner, XLearner
from econml.dml import CausalForestDML
from sklearn.ensemble import GradientBoostingRegressor, GradientBoostingClassifier
from sklearn.model_selection import train_test_split

# Simulate data: Y = treatment * tau(X) + baseline(X) + noise
np.random.seed(42)
n, d = 2000, 5
X = np.random.randn(n, d)

# True CATE: positive for x0 > 0, negative for x0 < 0
tau_true = np.where(X[:, 0] > 0, 0.8, -0.3)
T = np.random.binomial(1, 0.5, n)
Y = tau_true * T + X[:, 0] + np.random.randn(n) * 0.5

X_tr, X_te, T_tr, T_te, Y_tr, Y_te, tau_tr, tau_te = train_test_split(
    X, T, Y, tau_true, test_size=0.3, random_state=0
)

gbr = GradientBoostingRegressor(n_estimators=100, max_depth=3)

# S-Learner
sl = SLearner(overall_model=gbr)
sl.fit(Y_tr, T_tr, X=X_tr)
tau_sl = sl.effect(X_te)

# T-Learner
tl = TLearner(models=[gbr, gbr])
tl.fit(Y_tr, T_tr, X=X_tr)
tau_tl = tl.effect(X_te)

# X-Learner (with propensity score)
gbc = GradientBoostingClassifier(n_estimators=100, max_depth=3)
xl = XLearner(models=[gbr, gbr], propensity_model=gbc)
xl.fit(Y_tr, T_tr, X=X_tr)
tau_xl = xl.effect(X_te)

# Causal Forest (doubly robust via R-learner)
cf = CausalForestDML(
    n_estimators=200,
    min_samples_leaf=10,
    max_features='sqrt',
    discrete_treatment=True,  # binary T: fit a classifier for the propensity
    random_state=42,
    cv=3,
)
cf.fit(Y_tr, T_tr, X=X_tr, W=None)  # W = extra controls
tau_cf = cf.effect(X_te)

# Evaluation: MSE vs true CATE
for name, est in [('S-Learner', tau_sl), ('T-Learner', tau_tl),
                  ('X-Learner', tau_xl), ('Causal Forest', tau_cf)]:
    mse = np.mean((est - tau_te) ** 2)
    print(f'{name:15s}: CATE MSE = {mse:.4f}')
```

Confidence intervals from Causal Forest
```python
# Pointwise CIs from the infinitesimal jackknife
tau_point = cf.effect(X_te)
tau_lb, tau_ub = cf.effect_interval(X_te, alpha=0.05)  # returns (lower, upper)

# Check: which units have significantly positive CATE?
sig_positive = (tau_lb > 0)
print(f'Units with significant positive CATE: {sig_positive.mean():.1%}')

# Policy: target top quartile by estimated CATE
quartile = np.percentile(tau_point, 75)
policy = (tau_point > quartile)
avg_effect_targeted = tau_te[policy].mean()
avg_effect_all = tau_te.mean()
print(f'ATE over all: {avg_effect_all:.3f}')
print(f'ATE over targeted quartile: {avg_effect_targeted:.3f}')
```

RATE evaluation for uplift
```python
from sklearn.model_selection import cross_val_predict

# Cross-fitted propensity scores for inverse propensity weighting
e_hat = cross_val_predict(gbc, X_te, T_te, method='predict_proba', cv=5)[:, 1]
e_hat = np.clip(e_hat, 0.05, 0.95)  # guard against extreme weights

# IPW transformed outcome: an unbiased (but noisy) signal for the individual
# effect; full AIPW would add outcome-model correction terms
psi = (T_te / e_hat - (1 - T_te) / (1 - e_hat)) * Y_te

# Qini-style curve: sort by predicted CATE, accumulate the IPW signal
order = np.argsort(-tau_point)  # highest predicted CATE first
qini_curve = np.cumsum(psi[order]) / len(psi)

# Qini coefficient = area between the curve and the random-policy diagonal
diagonal = np.linspace(0, qini_curve[-1], len(psi))
qini = (qini_curve - diagonal).mean()
print(f'Qini coefficient: {qini:.4f}')
```

Analysis & Evaluation
Where Your Intuition Breaks
Intuition says that a more complex CATE model captures more heterogeneity and therefore produces a better targeting policy. In reality, CATE estimation is a high-variance problem: the model must estimate two conditional means and take their difference, compounding estimation error. A flexible model that fits the training data well can produce a wildly inaccurate CATE surface through overfitting, especially in regions of covariate space with few treated units. Causal Forests provide honest confidence intervals that grow wide in data-sparse regions precisely to signal this. A complex model with narrow CIs on training data and wide CIs on held-out units is the canonical failure mode — the heterogeneity is fitted noise, not signal.
Meta-learner selection guide
| | S-Learner | T-Learner | X-Learner | Causal Forest |
|---|---|---|---|---|
| Sample size needed | Smallest | Moderate | Best with imbalance | Large |
| Treatment balance | Any | Balanced | Imbalanced (works best) | Any |
| Regularization bias | High risk | Moderate | Lower | Low (honest) |
| Valid CIs | Bootstrap only | Bootstrap only | Bootstrap only | Asymptotic (built-in) |
| Continuous treatment | Yes | No (discrete arms only) | No (binary only) | Yes (GRF) |
| Implementation | Trivial | Simple | Moderate | EconML/GRF |
| Best for | Baseline, large n | Clean RCT | Observational, imbalanced | Best all-round |
CATE vs ATE: when it matters
| Signal | Action |
|---|---|
| CATE variance low, ATE positive | Ship to everyone — heterogeneity doesn't change the decision |
| CATE bimodal (positive + negative) | Targeting — deploy only to positive-effect subgroup |
| CATE correlated with observable features | Personalization — use CATE as a ranking signal |
| CATE uncertainty high everywhere | Run a larger experiment — you don't have enough data |
Common pitfalls
Snooping on CATE: If you compute CATE then report the subgroup with the highest effect, you're doing post-hoc subgroup analysis. The confidence intervals are invalid. Specify the CATE estimation protocol in advance.
Confounding in observational CATE: Meta-learners inherit the unconfoundedness assumption. In observational data, $\hat{\tau}(x)$ estimates $\mathbb{E}[Y \mid T = 1, X = x] - \mathbb{E}[Y \mid T = 0, X = x]$, which is a causal effect only under unconfoundedness. Adding more covariates is not automatically safer: conditioning on mediators or colliders introduces new bias.
Sample splitting for valid inference: Computing $\hat{\tau}(x)$ and then testing whether high-$\hat{\tau}$ units have higher realized outcomes on the same data inflates Type I error. Use held-out data or cross-fitting.
RATE vs pseudo-$R^2$: Don't evaluate CATE with $R^2$ against observed outcomes — there is no ground-truth individual effect to score against. Use RATE, AUTOC, or Qini on held-out experimental data only.
Production-Ready Code
Production CATE estimation follows a train/eval split: fit the causal model on training data, evaluate CATE quality on held-out data using the AUUC curve, then fit a shallow policy tree to produce human-readable targeting rules. Deploying raw CATE scores without an AUUC check is a common failure — the model may rank units correctly on average but fail in the tails where targeting decisions matter most.
```python
# production_hte.py
# CausalForestDML CATE pipeline, AUUC evaluation, and policy tree targeting rules.
# Install: pip install econml scikit-learn
from __future__ import annotations

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingRegressor, GradientBoostingClassifier


def cate_pipeline(
    X: np.ndarray,
    T: np.ndarray,
    Y: np.ndarray,
    feature_names: list[str] | None = None,
    test_size: float = 0.3,
    random_state: int = 42,
) -> dict:
    """
    End-to-end CATE estimation using CausalForestDML.
    Splits data into train/eval; returns model + evaluation metrics.
    """
    from econml.dml import CausalForestDML

    X_tr, X_te, T_tr, T_te, Y_tr, Y_te = train_test_split(
        X, T, Y, test_size=test_size, random_state=random_state,
    )
    model = CausalForestDML(
        model_y=GradientBoostingRegressor(n_estimators=100, random_state=random_state),
        model_t=GradientBoostingClassifier(n_estimators=100, random_state=random_state),
        discrete_treatment=True,  # required when model_t is a classifier
        n_estimators=500,
        min_samples_leaf=10,
        random_state=random_state,
    )
    model.fit(Y_tr, T_tr, X=X_tr)
    cate_te = model.effect(X_te)
    lb, ub = model.effect_interval(X_te, alpha=0.10)
    return {
        "model": model,
        "X_test": X_te,
        "T_test": T_te,
        "Y_test": Y_te,
        "cate_test": cate_te,
        "cate_lb": lb,
        "cate_ub": ub,
        "cate_mean": round(float(cate_te.mean()), 5),
        "cate_std": round(float(cate_te.std()), 5),
        "cate_p10": round(float(np.percentile(cate_te, 10)), 5),
        "cate_p90": round(float(np.percentile(cate_te, 90)), 5),
        "pct_positive": round(float((cate_te > 0).mean() * 100), 1),
    }


def auuc_score(Y: np.ndarray, T: np.ndarray, cate_hat: np.ndarray) -> float:
    """
    Area Under the Uplift Curve.
    Sort units by predicted CATE descending; measure cumulative uplift vs. random.
    Higher = better CATE ranking. Use on held-out experimental data only.
    """
    order = np.argsort(-cate_hat)
    Y_o, T_o = Y[order], T[order]
    y_t_cum = y_c_cum = 0.0   # running outcome sums per arm
    t_seen = c_seen = 0.0     # running unit counts per arm
    uplift_vals = []
    for y_i, t_i in zip(Y_o, T_o):
        y_t_cum += t_i * y_i
        y_c_cum += (1 - t_i) * y_i
        t_seen += t_i
        c_seen += 1 - t_i
        r_t = y_t_cum / max(t_seen, 1.0)
        r_c = y_c_cum / max(c_seen, 1.0)
        uplift_vals.append(r_t - r_c)
    return round(float(np.mean(uplift_vals)), 6)


def policy_tree_report(
    model,
    X_test: np.ndarray,
    feature_names: list[str] | None = None,
    max_depth: int = 2,
) -> dict:
    """
    Fit a shallow decision tree on the learned CATEs to extract
    human-readable targeting rules for ops/product teams.
    (econml also ships dedicated policy learners; a plain sklearn tree
    on the CATE surface is the simplest readable-rules approach.)
    """
    from sklearn.tree import DecisionTreeRegressor, export_text

    cate = model.effect(X_test)
    tree = DecisionTreeRegressor(max_depth=max_depth, random_state=42)
    tree.fit(X_test, cate)
    rules = export_text(tree, feature_names=feature_names)
    targeted = tree.predict(X_test) > 0  # leaves with positive predicted lift
    return {
        "targeting_rules": rules,
        "pct_targeted": round(float(targeted.mean() * 100), 1),
        "expected_lift_if_targeted": round(float(cate[targeted].mean()), 5),
    }


# ── Example ───────────────────────────────────────────────────────────────────
rng = np.random.default_rng(42)
n = 3_000
X = rng.normal(0, 1, (n, 6))
T = rng.binomial(1, 0.5, n).astype(float)
true_cate = 2.0 * X[:, 0]  # effect varies with X[:, 0]
Y = true_cate * T + X[:, 1] + rng.normal(0, 1, n)
feature_names = [f"feature_{i}" for i in range(6)]

result = cate_pipeline(X, T, Y, feature_names=feature_names)
print(f"CATE mean: {result['cate_mean']}, pct positive: {result['pct_positive']}%")

auuc = auuc_score(result["Y_test"], result["T_test"], result["cate_test"])
print(f"AUUC: {auuc}")  # > 0 means better-than-random targeting

rules = policy_tree_report(result["model"], result["X_test"], feature_names=feature_names)
print(rules["targeting_rules"])
print(f"Target {rules['pct_targeted']}% of population")
```