Neural-Path/Notes

Heterogeneous Treatment Effects

The Average Treatment Effect tells you whether a treatment works on average — but averages mask the distribution. A drug that helps young patients and harms elderly ones has a positive ATE if the young outnumber the elderly, but deploying it uniformly is a mistake. A push notification feature with a positive ATE for power users and a negative one for casual users shouldn't be shipped to everyone. Estimating the Conditional Average Treatment Effect — how the effect varies as a function of individual characteristics — is the next level of causal inference and is now central to product analytics, clinical trials, and policy evaluation. The tools for this (meta-learners, causal forests) live at the intersection of causal inference and machine learning, and require careful thinking about what can and cannot be identified from data.

Theory

CATE meta-learner architectures

X-Learner: imputed counterfactuals + propensity weighting

Stage 1: Fit two base outcome models (same as the T-Learner): μ̂₁ on treated, μ̂₀ on control. Then estimate cross-arm residuals: D̃ᵢ¹ = Yᵢ − μ̂₀(Xᵢ) for treated units, D̃ᵢ⁰ = μ̂₁(Xᵢ) − Yᵢ for control units.

Stage 2: Fit CATE models on the imputed residuals: τ̂₁ regresses D̃¹ on X (treated), τ̂₀ regresses D̃⁰ on X (control). Each model uses imputed individual effects.

Combine: Weight by the propensity score g(x) = P(T=1 | X=x): τ̂(x) = g(x)·τ̂₁(x) + (1−g(x))·τ̂₀(x). This upweights the estimate from the arm with more data.

⚠ Pitfall: requires propensity score estimation; sensitive to propensity-model misspecification in the tails.

✓ Use when: treatment is highly imbalanced (e.g., 5% treated), data is observational, and CATE accuracy is the priority.

Causal forests (Wager & Athey, 2018) extend beyond meta-learners with honest splitting and asymptotic confidence intervals.

A single average treatment effect hides the most actionable information: who benefits most, who benefits least, and who might be harmed. Personalized medicine, targeted advertising, and policy design all depend on knowing not just that a treatment works on average, but where in the population it works and why. The S-, T-, and X-learner architectures covered below differ in how they estimate this heterogeneity — the architectural choice determines which estimation errors dominate.

CATE and why ATE is insufficient

The Conditional Average Treatment Effect (CATE) is:

$\tau(x) = \mathbb{E}[Y_i(1) - Y_i(0) \mid X_i = x]$

The CATE must be a function of $x$ rather than a scalar because the treatment effect genuinely varies across units — different biological pathways, different price sensitivities, different baseline engagement levels. Estimating it nonparametrically requires recovering two conditional expectations simultaneously ($\mathbb{E}[Y \mid X, T=1]$ and $\mathbb{E}[Y \mid X, T=0]$), which means the difficulty scales with the complexity of the outcome surfaces, not just the average effect. This is why meta-learners matter: different learner architectures make different bias-variance tradeoffs depending on how imbalanced the treatment arms are and how smooth the CATE function is.

This is a function of covariates, not a scalar. The ATE is just its expectation: $\text{ATE} = \mathbb{E}[\tau(X_i)]$.

CATE informs three decisions beyond what ATE can:

  1. Targeting: deploy treatment only to units where $\tau(x) > c$ (positive net lift)
  2. Personalization: choose the best treatment per individual (argmax over arms)
  3. Mechanism understanding: which features drive heterogeneity, and why?

Under unconfoundedness ($Y_i(t) \perp\!\!\!\perp T_i \mid X_i$), the CATE is nonparametrically identified:

$\tau(x) = \mathbb{E}[Y \mid X=x, T=1] - \mathbb{E}[Y \mid X=x, T=0]$

The challenge is estimating two conditional expectations simultaneously with finite data, without overfitting or letting regularization bias the estimated difference.
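The identification formula can be sanity-checked directly on simulated data with a single binary covariate, where the conditional expectations reduce to stratum means. This is a minimal sketch with an invented data-generating process:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000
x = rng.binomial(1, 0.5, n)                      # one binary covariate
t = rng.binomial(1, 0.5, n)                      # randomized treatment
tau = np.where(x == 1, 0.8, -0.3)                # true CATE per stratum
y = tau * t + x + rng.normal(0, 0.5, n)

for xv in (0, 1):
    m1 = y[(x == xv) & (t == 1)].mean()          # E[Y | X=xv, T=1]
    m0 = y[(x == xv) & (t == 0)].mean()          # E[Y | X=xv, T=0]
    print(f"tau_hat({xv}) = {m1 - m0:+.3f}")

ate_hat = y[t == 1].mean() - y[t == 0].mean()
print(f"ate_hat    = {ate_hat:+.3f}")            # ≈ 0.5·0.8 + 0.5·(−0.3) = 0.25
```

Stratifying recovers the per-group effects, while the unconditional difference in means recovers only their average — exactly the information the ATE discards.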

S-Learner

The S-Learner (single learner) fits one outcome model with treatment as a feature:

$\hat{\mu}(x, t) = \hat{\mathbb{E}}[Y \mid X=x, T=t]$

CATE estimate: $\hat{\tau}(x) = \hat{\mu}(x, 1) - \hat{\mu}(x, 0)$

Problem: regularization can shrink the coefficient on T toward zero, making the model act as if treatment has no effect even when it does. Tree-based S-learners are particularly prone to this — if T has weak marginal importance, it may not appear in any splits, producing a flat $\hat{\tau}(x) \approx 0$ everywhere.
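This shrinkage failure is easy to reproduce. A minimal sketch using a Lasso-based S-learner on synthetic data (the data-generating process and penalty values are illustrative, not from the source):

```python
import numpy as np
from sklearn.linear_model import Lasso

rng = np.random.default_rng(1)
n = 500
X = rng.normal(0, 1, (n, 20))                   # strong baseline signal in X
T = rng.binomial(1, 0.5, n)
beta = rng.normal(1.0, 0.1, 20)
y = 0.3 * T + X @ beta + rng.normal(0, 0.5, n)  # true treatment effect = 0.3

features = np.column_stack([X, T])              # S-learner: T is just another column
coefs = {}
for alpha in (0.01, 0.5):
    coefs[alpha] = Lasso(alpha=alpha).fit(features, y).coef_[-1]
    print(f"alpha={alpha}: coefficient on T = {coefs[alpha]:+.3f}")
# The heavier penalty shrinks the T coefficient toward zero, so the implied
# CATE mu_hat(x,1) - mu_hat(x,0) is biased toward "no effect".
```

The same mechanism operates in tree ensembles, just through split selection rather than coefficient shrinkage.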

T-Learner

The T-Learner (two-model learner) fits separate outcome models per arm:

$\hat{\mu}_1(x) = \hat{\mathbb{E}}[Y \mid X=x, T=1], \quad \hat{\mu}_0(x) = \hat{\mathbb{E}}[Y \mid X=x, T=0]$

CATE estimate: $\hat{\tau}(x) = \hat{\mu}_1(x) - \hat{\mu}_0(x)$

Problem: in regions of covariate space where the treatment group is small, $\hat{\mu}_1$ is poorly estimated and must extrapolate. The difference of two noisy estimates can be high-variance, and extrapolation errors compound.

X-Learner

The X-Learner (cross learner) was developed to handle imbalanced treatment (Künzel et al., 2019). It uses imputed counterfactuals to get more stable CATE estimates:

Stage 1: Fit base models $\hat{\mu}_0, \hat{\mu}_1$ (same as the T-Learner).

Stage 2: Impute individual treatment effects using cross-arm predictions:

$\tilde{D}_i^1 = Y_i - \hat{\mu}_0(X_i)$ (for treated units, $T_i = 1$)

$\tilde{D}_i^0 = \hat{\mu}_1(X_i) - Y_i$ (for control units, $T_i = 0$)

Fit CATE models on these imputed effects:

$\hat{\tau}_1(x) = \hat{\mathbb{E}}[\tilde{D}^1 \mid X=x], \quad \hat{\tau}_0(x) = \hat{\mathbb{E}}[\tilde{D}^0 \mid X=x]$

Stage 3: Combine using the propensity score $g(x) = P(T=1 \mid X=x)$:

$\hat{\tau}(x) = g(x)\,\hat{\tau}_1(x) + (1 - g(x))\,\hat{\tau}_0(x)$

The propensity weighting upweights the CATE estimate from the group with more data. When the control group is much larger (common in A/B tests with 10% treatment), $g(x) \approx 0.1$, so $\hat{\tau}_0$ — estimated on abundant control data — dominates.
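The three stages can be sketched from scratch with scikit-learn. This is an illustrative implementation on invented imbalanced data, not the EconML version used in the walkthrough below:

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.linear_model import LogisticRegression

def x_learner(X, T, Y, base=GradientBoostingRegressor):
    """Three-stage X-learner; returns a callable tau_hat(x)."""
    treated, control = (T == 1), (T == 0)

    # Stage 1: per-arm outcome models
    mu1 = base().fit(X[treated], Y[treated])
    mu0 = base().fit(X[control], Y[control])

    # Stage 2: imputed effects, then a CATE model per arm
    d1 = Y[treated] - mu0.predict(X[treated])      # Y_i - mu0(X_i)
    d0 = mu1.predict(X[control]) - Y[control]      # mu1(X_i) - Y_i
    tau1 = base().fit(X[treated], d1)
    tau0 = base().fit(X[control], d0)

    # Stage 3: propensity-weighted combination
    g = LogisticRegression().fit(X, T)
    def tau_hat(x):
        p = g.predict_proba(x)[:, 1]
        return p * tau1.predict(x) + (1 - p) * tau0.predict(x)
    return tau_hat

# Imbalanced toy data: ~10% treated, effect flips sign at x0 = 0
rng = np.random.default_rng(0)
n = 4000
X = rng.normal(0, 1, (n, 3))
T = rng.binomial(1, 0.1, n)
tau = np.where(X[:, 0] > 0, 0.8, -0.3)
Y = tau * T + X[:, 0] + rng.normal(0, 0.5, n)

est = x_learner(X, T, Y)(X)
print("corr(est, true tau):", np.corrcoef(est, tau)[0, 1].round(2))
```

Because the control arm is large, τ̂₀ is well estimated, and with g(x) ≈ 0.1 it dominates the combination — the mechanism the propensity weighting is designed to exploit.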

The architecture summary in the Theory section above condenses these differences.

Causal Forests

Causal Forests (Wager and Athey, 2018) extend random forests to estimate CATE with valid pointwise confidence intervals. Two key modifications to standard random forests:

Honest estimation: Each tree splits the data into two halves — a splitting sample (used to find splits) and an estimation sample (used to compute leaf means). This prevents overfitting that would invalidate inference.
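Honest splitting can be sketched with a plain decision tree: learn structure on one half of the data, compute leaf-level effects on the other. This toy version (invented data; the transformed-outcome proxy assumes randomized 50/50 treatment) shows the pattern:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
n = 2000
X = rng.normal(0, 1, (n, 2))
T = rng.binomial(1, 0.5, n)
Y = np.where(X[:, 0] > 0, 1.0, 0.0) * T + rng.normal(0, 1, n)  # effect 1 iff x0 > 0

half = n // 2
split_idx, est_idx = np.arange(half), np.arange(half, n)

# Splitting sample: learn tree structure on a transformed-outcome proxy,
# 2*Y*(2T-1) = Y*T/0.5 - Y*(1-T)/0.5, unbiased for tau when P(T=1) = 0.5
proxy = 2 * Y[split_idx] * (2 * T[split_idx] - 1)
tree = DecisionTreeRegressor(max_depth=2, min_samples_leaf=100)
tree.fit(X[split_idx], proxy)

# Estimation sample: leaf effects from data the splits never saw
leaves = tree.apply(X[est_idx])
tau_by_leaf = {}
for leaf in np.unique(leaves):
    idx = est_idx[leaves == leaf]
    tau_by_leaf[int(leaf)] = (Y[idx][T[idx] == 1].mean()
                              - Y[idx][T[idx] == 0].mean())
    print(f"leaf {leaf}: tau_hat = {tau_by_leaf[int(leaf)]:+.2f} (n={len(idx)})")
```

Because the leaf means come from held-out data, they are not biased by the split search — the property that makes the forest's confidence intervals valid.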

Causal splitting criterion: Instead of minimizing MSE of Y, splits maximize the heterogeneity of treatment effects across child nodes. For a candidate split $(j, c)$:

$\Delta(j,c) = \frac{n_L n_R}{n}\left(\hat{\tau}_L - \hat{\tau}_R\right)^2 - \text{penalty}$

where $\hat{\tau}_L$, $\hat{\tau}_R$ are CATE estimates in the left/right child nodes.
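A toy version of this criterion (penalty term omitted, difference-in-means CATE estimates per node, invented data) shows the score peaking at the true effect discontinuity:

```python
import numpy as np

def split_score(x1, T, Y, c):
    """Delta(j,c) without the penalty: (n_L n_R / n) * (tau_L - tau_R)^2."""
    left, right = x1 <= c, x1 > c
    def dim(mask):  # difference-in-means CATE estimate within a node
        return Y[mask & (T == 1)].mean() - Y[mask & (T == 0)].mean()
    nL, nR, n = left.sum(), right.sum(), len(Y)
    return nL * nR / n * (dim(left) - dim(right)) ** 2

rng = np.random.default_rng(0)
n = 5000
x1 = rng.normal(0, 1, n)
T = rng.binomial(1, 0.5, n)
Y = np.where(x1 > 0, 1.0, -0.5) * T + rng.normal(0, 1, n)  # effect jumps at x1 = 0

scores = {c: split_score(x1, T, Y, c) for c in (-1.0, 0.0, 1.0)}
for c, s in scores.items():
    print(f"split at x1 <= {c:+.1f}: score = {s:.1f}")
```

A split at the true discontinuity separates the −0.5 and +1.0 effect regimes cleanly, so its score dwarfs the off-center candidates.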

Asymptotic normality: The forest estimate $\hat{\tau}(x)$ is asymptotically normal and centered on $\tau(x)$ under mild regularity conditions:

$\frac{\hat{\tau}(x) - \tau(x)}{\hat{\sigma}(x)} \xrightarrow{d} \mathcal{N}(0,1)$

where $\hat{\sigma}^2(x)$ is estimated via the infinitesimal jackknife. This gives valid confidence intervals without distributional assumptions — a major advantage over meta-learners, which require bootstrap inference.

Double Robustness via R-Learner / GRF: The Generalized Random Forest framework (Athey et al., 2019) uses pseudo-outcomes based on the Robinson decomposition:

$\tilde{Y}_i = Y_i - m(X_i), \quad \tilde{T}_i = T_i - e(X_i)$

The CATE estimate minimizes:

$\hat{\tau} = \arg\min_\tau \sum_i \left(\tilde{Y}_i - \tau(X_i)\,\tilde{T}_i\right)^2$

This is doubly robust: consistent even if one of $m(x)$ or $e(x)$ is misspecified.
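A from-scratch R-learner can be sketched with cross-fitted nuisances. The weighted-regression step below is one standard way to solve the minimization (rewrite the objective as a regression of $\tilde{Y}/\tilde{T}$ on $X$ with weights $\tilde{T}^2$); the data-generating process is invented for illustration:

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor, GradientBoostingClassifier
from sklearn.model_selection import cross_val_predict

rng = np.random.default_rng(0)
n = 4000
X = rng.normal(0, 1, (n, 3))
T = rng.binomial(1, 0.5, n)
tau = np.where(X[:, 0] > 0, 0.8, -0.3)
Y = tau * T + X[:, 1] + rng.normal(0, 0.5, n)

# Cross-fitted nuisances: m(x) = E[Y|X], e(x) = P(T=1|X)
m_hat = cross_val_predict(GradientBoostingRegressor(), X, Y, cv=5)
e_hat = cross_val_predict(GradientBoostingClassifier(), X, T,
                          method="predict_proba", cv=5)[:, 1]
e_hat = np.clip(e_hat, 0.05, 0.95)              # guard against extreme propensities

Y_res, T_res = Y - m_hat, T - e_hat
# argmin_tau sum (Y_res - tau(X)*T_res)^2 is a weighted regression of
# (Y_res / T_res) on X with sample weights T_res^2
tau_model = GradientBoostingRegressor().fit(
    X, Y_res / T_res, sample_weight=T_res ** 2
)
print("corr with true tau:", np.corrcoef(tau_model.predict(X), tau)[0, 1].round(2))
```

The residualization removes the baseline surface before the CATE is fit, which is why a misspecified $m(x)$ or $e(x)$ does less damage than in a plug-in T-learner.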

Policy learning and uplift modeling

Given $\hat{\tau}(x)$, the optimal treatment policy is:

$\pi^*(x) = \mathbf{1}[\hat{\tau}(x) > c]$

where $c$ is the cost of treatment (or 0 if treatment is free). This uplift modeling problem asks: who should we target to maximize aggregate effect while respecting budget constraints?

Evaluation: Unlike outcome prediction, CATE cannot be evaluated directly — you never observe both $Y_i(1)$ and $Y_i(0)$ for the same unit. Two valid evaluation approaches:

  1. RATE (Rank-Weighted Average Treatment Effect): Sort units by $\hat{\tau}(x)$ and measure whether units ranked high actually respond better. Uses inverse propensity weighting to compute unbiased estimates per decile.

  2. AUTOC (Area Under the TOC Curve): Measures how much benefit is concentrated in the top-ranked fraction of units.

Qini coefficient (analogous to AUC for uplift) = area between the uplift curve and the random policy baseline.

Walkthrough

Meta-learners with EconML

python
import numpy as np
from econml.metalearners import SLearner, TLearner, XLearner
from econml.dml import CausalForestDML
from sklearn.ensemble import GradientBoostingRegressor, GradientBoostingClassifier
from sklearn.model_selection import train_test_split
 
# Simulate data: Y = treatment * tau(X) + baseline(X) + noise
np.random.seed(42)
n, d = 2000, 5
X = np.random.randn(n, d)
# True CATE: positive for x0 > 0, negative for x0 < 0
tau_true = np.where(X[:, 0] > 0, 0.8, -0.3)
T = np.random.binomial(1, 0.5, n)
Y = tau_true * T + X[:, 0] + np.random.randn(n) * 0.5
 
X_tr, X_te, T_tr, T_te, Y_tr, Y_te, tau_tr, tau_te = train_test_split(
    X, T, Y, tau_true, test_size=0.3, random_state=0
)
 
gbr = GradientBoostingRegressor(n_estimators=100, max_depth=3)
 
# S-Learner
sl = SLearner(overall_model=gbr)
sl.fit(Y_tr, T_tr, X=X_tr)
tau_sl = sl.effect(X_te)
 
# T-Learner
tl = TLearner(models=[gbr, gbr])
tl.fit(Y_tr, T_tr, X=X_tr)
tau_tl = tl.effect(X_te)
 
# X-Learner (with propensity score)
gbc = GradientBoostingClassifier(n_estimators=100, max_depth=3)
xl = XLearner(models=[gbr, gbr], propensity_model=gbc)
xl.fit(Y_tr, T_tr, X=X_tr)
tau_xl = xl.effect(X_te)
 
# Causal Forest (doubly robust via R-learner)
cf = CausalForestDML(
    n_estimators=200,
    min_samples_leaf=10,
    max_features='sqrt',
    random_state=42,
    cv=3,
)
cf.fit(Y_tr, T_tr, X=X_tr, W=None)  # W = extra controls
tau_cf = cf.effect(X_te)
 
# Evaluation: MSE vs true CATE
for name, est in [('S-Learner', tau_sl), ('T-Learner', tau_tl),
                  ('X-Learner', tau_xl), ('Causal Forest', tau_cf)]:
    mse = np.mean((est - tau_te) ** 2)
    print(f'{name:15s}: CATE MSE = {mse:.4f}')

Confidence intervals from Causal Forest

python
# Pointwise CIs from the infinitesimal jackknife
tau_point, tau_lb, tau_ub = cf.effect_interval(X_te, alpha=0.05)
 
# Check: which units have significantly positive CATE?
sig_positive = (tau_lb > 0)
print(f'Units with significant positive CATE: {sig_positive.mean():.1%}')
 
# Policy: target top quartile by estimated CATE
quartile = np.percentile(tau_point, 75)
policy = (tau_point > quartile)
avg_effect_targeted = tau_te[policy].mean()
avg_effect_all = tau_te.mean()
print(f'ATE over all: {avg_effect_all:.3f}')
print(f'ATE over targeted quartile: {avg_effect_targeted:.3f}')

RATE evaluation for uplift

python
from sklearn.model_selection import cross_val_predict
 
# Cross-fitted propensity scores for off-policy evaluation
e_hat = cross_val_predict(gbc, X_te, T_te, method='predict_proba', cv=5)[:, 1]
e_hat = np.clip(e_hat, 0.05, 0.95)              # avoid extreme weights
 
# IPW pseudo-outcome: unbiased for the individual effect under unconfoundedness
psi = (T_te / e_hat - (1 - T_te) / (1 - e_hat)) * Y_te
 
# Uplift curve: sort by predicted CATE, track the running mean pseudo-outcome
order = np.argsort(-tau_point)                  # highest predicted CATE first
cum_uplift = np.cumsum(psi[order]) / np.arange(1, len(psi) + 1)
 
# Qini-style coefficient = area between the curve and the random-policy baseline
qini = float(np.trapz(cum_uplift - psi.mean()) / len(psi))

Analysis & Evaluation

Where Your Intuition Breaks

It is tempting to assume that more complex CATE models capture more heterogeneity and therefore produce better targeting policies. In practice, CATE estimation is a high-variance problem: the model must estimate two conditional means and take their difference, compounding estimation error. A flexible model that fits training data well can produce a wildly inaccurate CATE surface through overfitting, especially in regions of covariate space with few treated units. Causal Forests provide honest confidence intervals that grow wide in data-sparse regions precisely to signal this. A complex model with a narrow CI on training data and a wide CI on held-out units is the canonical failure mode — the heterogeneity is fitted noise, not signal.

Meta-learner selection guide

| | S-Learner | T-Learner | X-Learner | Causal Forest |
|---|---|---|---|---|
| Sample size needed | Smallest | Moderate | Best with imbalance | Large |
| Treatment balance | Any | Balanced | Imbalanced (works best) | Any |
| Regularization bias | High risk | Moderate | Lower | Low (honest) |
| Valid CIs | Bootstrap only | Bootstrap only | Bootstrap only | Asymptotic |
| Continuous treatment | Yes | Multi-arm only | No | Yes (GRF) |
| Implementation | Trivial | Simple | Moderate | EconML/GRF |
| Best for | Baseline, large n | Clean RCT | Observational, imbalanced | Best all-round |

CATE vs ATE: when it matters

| Signal | Action |
|---|---|
| CATE variance low, ATE positive | Ship to everyone — heterogeneity doesn't change the decision |
| CATE bimodal (positive + negative) | Targeting — deploy only to positive-effect subgroup |
| CATE correlated with observable features | Personalization — use CATE as a ranking signal |
| CATE uncertainty high everywhere | Run a larger experiment — you don't have enough data |

Common pitfalls

Snooping on CATE: If you compute CATE then report the subgroup with the highest effect, you're doing post-hoc subgroup analysis. The confidence intervals are invalid. Specify the CATE estimation protocol in advance.

Confounding in observational CATE: Meta-learners inherit the unconfoundedness assumption. In observational data, $\hat{\tau}(x)$ estimates $\mathbb{E}[Y \mid T=1, X=x] - \mathbb{E}[Y \mid T=0, X=x]$, which is a causal effect only under unconfoundedness. Conditioning on more X can introduce collider bias if you condition on mediators.

Sample splitting for valid inference: Computing $\hat{\tau}$ and then testing whether high-$\hat{\tau}$ units have higher realized outcomes on the same data inflates Type I error. Use held-out data or cross-fitting.

RATE vs pseudo-R²: Don't evaluate CATE with $R^2$ between $\hat{\tau}(x)$ and realized outcomes — there is no ground-truth individual effect. Use RATE, AUTOC, or Qini on held-out experimental data only.

Production-Ready Code

Production CATE estimation follows a train/eval split: fit the causal model on training data, evaluate CATE quality on held-out data using the AUUC curve, then fit a shallow policy tree to produce human-readable targeting rules. Deploying raw CATE scores without an AUUC check is a common failure — the model may rank units correctly on average but fail in the tails where targeting decisions matter most.

python
# production_hte.py
# CausalForestDML CATE pipeline, AUUC evaluation, and policy tree targeting rules.
# Install: pip install econml scikit-learn
 
from __future__ import annotations
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingRegressor, GradientBoostingClassifier
 
 
def cate_pipeline(
    X: np.ndarray,
    T: np.ndarray,
    Y: np.ndarray,
    feature_names: list[str] | None = None,
    test_size: float = 0.3,
    random_state: int = 42,
) -> dict:
    """
    End-to-end CATE estimation using CausalForestDML.
    Splits data into train/eval; returns model + evaluation metrics.
    """
    from econml.dml import CausalForestDML
 
    X_tr, X_te, T_tr, T_te, Y_tr, Y_te = train_test_split(
        X, T, Y, test_size=test_size, random_state=random_state,
    )
    model = CausalForestDML(
        model_y=GradientBoostingRegressor(n_estimators=100, random_state=random_state),
        model_t=GradientBoostingClassifier(n_estimators=100, random_state=random_state),
        n_estimators=500,
        min_samples_leaf=10,
        random_state=random_state,
    )
    model.fit(Y_tr, T_tr, X=X_tr)
    cate_te = model.effect(X_te)
    lb, ub  = model.effect_interval(X_te, alpha=0.10)
 
    return {
        "model":        model,
        "X_test":       X_te,
        "T_test":       T_te,
        "Y_test":       Y_te,
        "cate_test":    cate_te,
        "cate_lb":      lb,
        "cate_ub":      ub,
        "cate_mean":    round(float(cate_te.mean()), 5),
        "cate_std":     round(float(cate_te.std()), 5),
        "cate_p10":     round(float(np.percentile(cate_te, 10)), 5),
        "cate_p90":     round(float(np.percentile(cate_te, 90)), 5),
        "pct_positive": round(float((cate_te > 0).mean() * 100), 1),
    }
 
 
def auuc_score(Y: np.ndarray, T: np.ndarray, cate_hat: np.ndarray) -> float:
    """
    Area Under the Uplift Curve.
    Sort units by predicted CATE descending; measure cumulative uplift vs. random.
    Higher = better CATE ranking. Use on held-out experimental data only.
    """
    order = np.argsort(-cate_hat)
    Y_o, T_o = Y[order], T[order]
    # Cumulative per-arm response rates, vectorized (avoids an O(n^2) loop)
    t_seen = np.maximum(np.cumsum(T_o), 1)
    c_seen = np.maximum(np.cumsum(1 - T_o), 1)
    r_t = np.cumsum(T_o * Y_o) / t_seen
    r_c = np.cumsum((1 - T_o) * Y_o) / c_seen
    uplift_vals = r_t - r_c
    return round(float(np.trapz(uplift_vals) / len(Y)), 6)
 
 
def policy_tree_report(
    model,
    X_test: np.ndarray,
    feature_names: list[str] | None = None,
    max_depth: int = 2,
) -> dict:
    """
    Fit a shallow policy tree on learned CATEs.
    Produces human-readable targeting rules for ops/product teams.
    """
    from econml.policy import PolicyTree
    cate = model.effect(X_test)
    tree = PolicyTree(max_depth=max_depth, random_state=42)
    tree.fit(X_test, cate)
    rules = tree.export_text(feature_names=feature_names)
    pct_targeted = float((tree.predict(X_test) > 0).mean() * 100)
    return {
        "targeting_rules":  rules,
        "pct_targeted":     round(pct_targeted, 1),
        "expected_lift_if_targeted": round(float(cate[tree.predict(X_test) > 0].mean()), 5),
    }
 
 
# ── Example ───────────────────────────────────────────────────────────────────
rng = np.random.default_rng(42)
n = 3_000
X = rng.normal(0, 1, (n, 6))
T = rng.binomial(1, 0.5, n).astype(float)
true_cate = 2.0 * X[:, 0]                              # effect varies with X[:,0]
Y = true_cate * T + X[:, 1] + rng.normal(0, 1, n)
feature_names = [f"feature_{i}" for i in range(6)]
 
result = cate_pipeline(X, T, Y, feature_names=feature_names)
print(f"CATE mean: {result['cate_mean']}, pct positive: {result['pct_positive']}%")
 
auuc = auuc_score(result["Y_test"], result["T_test"], result["cate_test"])
print(f"AUUC: {auuc}")  # > 0 means better-than-random targeting
 
rules = policy_tree_report(result["model"], result["X_test"], feature_names=feature_names)
print(rules["targeting_rules"])
print(f"Target {rules['pct_targeted']}% of population")
