Causal Forests & CATE Estimation

The average treatment effect (ATE) is a useful summary, but it can hide enormous heterogeneity. A drug with zero average effect might help young patients and harm older ones. A discount with no average impact on revenue might lead profitable customers to buy more while attracting unprofitable, price-sensitive customers. Estimating whom the treatment helps, via the Conditional Average Treatment Effect (CATE), is the key to personalized decisions, targeted policies, and understanding mechanisms.

Theory

[Figure: Causal forest: honest splits + adaptive neighborhood. Treated and control units in the (X₁, X₂) plane, point size = CATE; an honest split, a query point x and its forest neighbors; high-CATE region in the top-right.]

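Meta-learners. The simplest approaches to CATE estimation treat it as supervised learning. The training data are $\{Y_i, D_i, X_i\}_{i=1}^n$, with outcome $Y_i$, binary treatment $D_i \in \{0, 1\}$, and covariates $X_i$. The target is $\tau(x) = \mathbb{E}[Y(1) - Y(0) \mid X = x]$: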

  • S-learner: Fit a single model $\hat{\mu}(x, d)$ of $\mathbb{E}[Y \mid X=x, D=d]$ and compute $\hat{\tau}(x) = \hat{\mu}(x, 1) - \hat{\mu}(x, 0)$. This is simple, but $D$ is just one feature among many, so a regularized learner may shrink its contribution (e.g., trees that rarely split on $D$). That biases $\hat{\tau}(x)$ toward zero.

  • T-learner: Fit separate models, $\hat{\mu}_0(x)$ on control units and $\hat{\mu}_1(x)$ on treated units, and compute $\hat{\tau}(x) = \hat{\mu}_1(x) - \hat{\mu}_0(x)$. $D$ cannot be regularized away. However, each model extrapolates in regions where its arm has few units, and the difference of two separately regularized fits can show spurious heterogeneity.

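  • X-learner: First fit $\hat{\mu}_0$ and $\hat{\mu}_1$ as in the T-learner. Then impute a treatment effect for each unit: $\tilde{\tau}_i^{(1)} = Y_i - \hat{\mu}_0(X_i)$ for treated units and $\tilde{\tau}_i^{(0)} = \hat{\mu}_1(X_i) - Y_i$ for controls. Regress each set of imputed effects on $X$, giving $\hat{\tau}_1(x)$ from treated units and $\hat{\tau}_0(x)$ from controls. Combine them as $\hat{\tau}(x) = g(x)\,\hat{\tau}_0(x) + (1 - g(x))\,\hat{\tau}_1(x)$, typically with $g(x) = \hat{e}(x)$, the estimated propensity score $\Pr(D = 1 \mid X = x)$. This works better than the T-learner when treatment assignment is unbalanced, because the larger arm's outcome model does most of the imputation.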

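  • R-learner (Robinson decomposition). Let $m(x) = \mathbb{E}[Y \mid X = x]$ be the outcome regression and $e(x) = \Pr(D = 1 \mid X = x)$ the propensity score. The R-learner targets the CATE by minimizing the residual-on-residual loss from the partially linear model: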

$$\hat{\tau} = \arg\min_{\tau} \sum_i \left[(Y_i - \hat{m}(X_i)) - (D_i - \hat{e}(X_i))\,\tau(X_i)\right]^2$$

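This loss is Neyman-orthogonal: errors in $\hat{m}$ and $\hat{e}$ enter the error of $\hat{\tau}$ only through their product, not to first order. With cross-fitted nuisance estimates, $\hat{\tau}$ can therefore converge nearly as fast as if $m$ and $e$ were known, even when each nuisance model is learned at a slower rate.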
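A minimal sketch of the four meta-learners. It assumes NumPy arrays `X` (n, p), binary `D`, and `Y`, with scikit-learn gradient boosting as the base learner. The function names are hypothetical, not a library API. Several choices are illustrative only:

* For the X-learner, `g(x)` is an in-sample propensity estimate.
* The R-learner uses 5-fold cross-fitting and clips propensities to [0.01, 0.99].
* The R-learner relies on the identity $[(Y - \hat{m}) - (D - \hat{e})\tau]^2 = (D - \hat{e})^2\,[(Y - \hat{m})/(D - \hat{e}) - \tau]^2$. This turns the R-loss into a weighted regression that any learner accepting `sample_weight` can minimize.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
from sklearn.model_selection import cross_val_predict


def _gbr():
    # Base learner for every outcome/effect model (an illustrative choice).
    return GradientBoostingRegressor(random_state=0)


def s_learner(X, D, Y):
    # One model on [X, D]; CATE = prediction at D=1 minus prediction at D=0.
    model = _gbr().fit(np.column_stack([X, D]), Y)
    ones, zeros = np.ones(len(X)), np.zeros(len(X))
    return model.predict(np.column_stack([X, ones])) - model.predict(np.column_stack([X, zeros]))


def t_learner(X, D, Y):
    # Separate outcome models per arm; CATE = mu1(x) - mu0(x).
    mu0 = _gbr().fit(X[D == 0], Y[D == 0])
    mu1 = _gbr().fit(X[D == 1], Y[D == 1])
    return mu1.predict(X) - mu0.predict(X)


def x_learner(X, D, Y):
    t, c = D == 1, D == 0
    mu0, mu1 = _gbr().fit(X[c], Y[c]), _gbr().fit(X[t], Y[t])
    # Imputed effects: observed outcome vs. counterfactual prediction from the other arm.
    tau1 = _gbr().fit(X[t], Y[t] - mu0.predict(X[t]))
    tau0 = _gbr().fit(X[c], mu1.predict(X[c]) - Y[c])
    # g(x) = estimated propensity (in-sample here for brevity).
    g = GradientBoostingClassifier(random_state=0).fit(X, D).predict_proba(X)[:, 1]
    return g * tau0.predict(X) + (1 - g) * tau1.predict(X)


def r_learner(X, D, Y, n_folds=5):
    # Cross-fitted nuisances: m(x) = E[Y | X], e(x) = P(D = 1 | X).
    m_hat = cross_val_predict(_gbr(), X, Y, cv=n_folds)
    e_hat = cross_val_predict(GradientBoostingClassifier(random_state=0), X, D,
                              cv=n_folds, method="predict_proba")[:, 1]
    d_res = D - np.clip(e_hat, 0.01, 0.99)   # |d_res| >= 0.01, safe to divide
    y_res = Y - m_hat
    # R-loss as a weighted regression on the pseudo-outcome y_res / d_res.
    final = _gbr().fit(X, y_res / d_res, sample_weight=d_res ** 2)
    return final.predict(X)
```

All four functions return one CATE estimate per row of `X`. Only the T-, X-, and R-learners can be expected to track strong heterogeneity reliably. The S-learner depends on the base model choosing to use `D`.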

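Generalized Random Forests (GRF). Wager and Athey (2018) introduced causal forests, and Athey, Tibshirani, and Wager (2019) generalized them into GRF, which adapts random forests to estimate CATE directly. The forest defines adaptive weights $\alpha_i(x)$: how much each training unit contributes to the estimate at query point $x$. These weights depend on how often unit $i$ and the query point end up in the same leaf. With $B$ trees and $L_b(x)$ the leaf of tree $b$ that contains $x$, $\alpha_i(x) = \frac{1}{B}\sum_{b=1}^{B} \frac{\mathbf{1}\{X_i \in L_b(x)\}}{|L_b(x)|}$. The CATE estimate is: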

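$$\hat{\tau}(x) = \arg\min_{\theta} \sum_{i=1}^n \alpha_i(x) \left[(Y_i - \bar{Y}_{\alpha(x)}) - (D_i - \bar{D}_{\alpha(x)})\,\theta\right]^2$$

Here $\bar{Y}_{\alpha(x)} = \sum_i \alpha_i(x) Y_i$ and $\bar{D}_{\alpha(x)} = \sum_i \alpha_i(x) D_i$ are $\alpha$-weighted local means. In practice (e.g., in grf and CausalForestDML), $Y$ and $D$ are first residualized against $\hat{m}$ and $\hat{e}$, as in the R-learner.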
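The sketch below computes $\alpha_i(x)$ from leaf co-membership and then solves the weighted objective above in closed form: $\hat{\theta} = \sum_i \alpha_i \tilde{Y}_i \tilde{D}_i / \sum_i \alpha_i \tilde{D}_i^2$, where tildes denote centering at the weighted means. It is not GRF itself. As a stand-in for GRF's heterogeneity-targeting splits, it grows an ordinary scikit-learn forest on the transformed outcome $Y^* = Y(D - e)/(e(1-e))$. That assumes a randomized experiment with known propensity `e`. The helper names are hypothetical, and no honesty is applied.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor


def fit_stand_in_forest(X, Y, D, e=0.5, seed=0):
    # E[Y* | X] = tau(X) when D is randomized with known propensity e,
    # so splits on Y* chase treatment-effect heterogeneity.
    y_star = Y * (D - e) / (e * (1 - e))
    forest = RandomForestRegressor(n_estimators=200, min_samples_leaf=100, random_state=seed)
    return forest.fit(X, y_star)


def forest_weights(forest, X_train, x_query):
    """alpha_i(x): average over trees of 1{same leaf as x} / (leaf size)."""
    train_leaves = forest.apply(X_train)                      # (n, B) leaf ids
    query_leaves = forest.apply(x_query.reshape(1, -1))[0]    # (B,) leaf ids for x
    same = train_leaves == query_leaves                       # (n, B) co-membership
    # Each column sums to >= 1 because the forest was fit on X_train.
    return (same / same.sum(axis=0)).mean(axis=1)             # weights sum to 1


def local_cate(alpha, Y, D):
    # Weighted residual-on-residual regression around the alpha-weighted means.
    y_c = Y - alpha @ Y
    d_c = D - alpha @ D
    return float(np.sum(alpha * y_c * d_c) / np.sum(alpha * d_c ** 2))
```

Training units that never share a leaf with `x` get weight zero, so the estimate at `x` comes from an adaptive neighborhood rather than a fixed-radius window.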

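Honesty. Causal tree splits are chosen to maximize heterogeneity in estimated effects. If the same data both select the splits and estimate the effects within the resulting leaves, the leaf estimates are biased in the direction of the selection. Differences between leaves are exaggerated, because splits tend to land where noise happened to look like heterogeneity. Honest splitting divides the training data: one half determines the splits, and the other half estimates effects within leaves. This removes the selection bias at the cost of some variance, since each step uses only half the data.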

Why it had to be this way. The honesty requirement follows from the same insight as cross-fitting in Double ML. Using the same data both to select a model and to estimate its parameters introduces an overfitting bias that need not shrink fast enough relative to sampling noise, which breaks the usual confidence intervals. Honesty removes this bias by making the effect estimate in each leaf independent of how the leaf was chosen.
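A minimal single-tree sketch of honest estimation, under these assumptions:

* Treatment is randomized with a known propensity `e`.
* The splitting half fits a scikit-learn regression tree on the transformed outcome $Y(D - e)/(e(1-e))$. This is a simple substitute for the causal-tree splitting criterion, not that criterion itself.
* The function names are hypothetical.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor


def honest_tree_cate(X, Y, D, e=0.5, min_leaf=200, seed=0):
    """Honest tree sketch: one half of the data picks splits, the other estimates leaf effects."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(Y))
    split_idx, est_idx = idx[: len(Y) // 2], idx[len(Y) // 2 :]

    # Splitting half: regress the transformed outcome, whose mean given X is tau(X).
    y_star = Y[split_idx] * (D[split_idx] - e) / (e * (1 - e))
    tree = DecisionTreeRegressor(min_samples_leaf=min_leaf, random_state=seed)
    tree.fit(X[split_idx], y_star)

    # Estimation half: difference in means inside each leaf. These units played
    # no role in choosing the splits, so the split search cannot bias them.
    leaves = tree.apply(X[est_idx])
    y_est, d_est = Y[est_idx], D[est_idx]
    effects = {}
    for leaf in np.unique(leaves):
        in_leaf = leaves == leaf
        y, d = y_est[in_leaf], d_est[in_leaf]
        if d.min() == d.max():        # only one arm present: no estimate
            effects[leaf] = np.nan
        else:
            effects[leaf] = y[d == 1].mean() - y[d == 0].mean()
    return tree, effects


def predict_honest(tree, effects, X_new):
    # Leaves that received no estimation-half units map to NaN.
    return np.array([effects.get(leaf, np.nan) for leaf in tree.apply(X_new)])
```

A causal forest repeats this on many subsamples with randomized splits and averages the results, which is what produces the weights $\alpha_i(x)$.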

Walkthrough

Scenario: Personalized discount targeting. We have 10,000 customers with treatment (discount offered) and outcome (purchase within 7 days). We want to estimate each customer's CATE and design a targeting policy.
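With real data the true CATE is never observed, so a simulated stand-in is useful for checking the steps below. The simulator is hypothetical and makes these assumptions:

* The offer is randomized 50/50.
* The covariates (age, income in thousands, past purchases) are made up.
* Effect sizes are made up and keep every purchase probability inside [0, 1], so `true_cate` is exactly $\Pr(Y=1 \mid X, T=1) - \Pr(Y=1 \mid X, T=0)$.

```python
import numpy as np


def simulate_discount_data(n=10_000, seed=0):
    """Hypothetical simulator: 0/1 purchase outcome with a known, heterogeneous discount effect."""
    rng = np.random.default_rng(seed)
    age = rng.uniform(18, 70, n)
    income_k = rng.lognormal(mean=np.log(50_000), sigma=0.4, size=n) / 1_000
    past_purchases = rng.poisson(2, n)
    X = np.column_stack([age, income_k, past_purchases])
    T = rng.binomial(1, 0.5, n)                        # randomized discount offer

    base = 0.1 + 0.05 * np.minimum(past_purchases, 6)  # P(purchase | no offer), in [0.1, 0.4]
    # Made-up effect: larger for younger, lower-income customers, bounded to [-0.1, 0.3].
    true_cate = np.clip(0.15 - 0.003 * (age - 18) - 0.001 * (income_k - 50), -0.1, 0.3)
    Y = rng.binomial(1, base + T * true_cate)          # probabilities stay in [0, 0.7]
    return X, T, Y, true_cate
```

The arrays match the argument order of `estimate_cate(Y, T, X, W)` with `W=None`. `true_cate` lets you check a ranking against ground truth.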

Step 1: Estimate CATE with EconML CausalForestDML.

```python
from __future__ import annotations

import numpy as np
from econml.dml import CausalForestDML
from sklearn.ensemble import GradientBoostingRegressor, GradientBoostingClassifier


def estimate_cate(
    Y: np.ndarray,                 # (n,) outcome, e.g. purchase within 7 days
    T: np.ndarray,                 # (n,) 0/1 treatment (discount offered)
    X: np.ndarray,                 # (n, p) covariates that drive effect heterogeneity
    W: np.ndarray | None = None,   # (n, q) extra controls for confounding only
) -> tuple[np.ndarray, np.ndarray, np.ndarray, CausalForestDML]:
    """Fit CausalForestDML; return CATE estimates, 95% CI bounds, and the fitted model."""
    est = CausalForestDML(
        model_y=GradientBoostingRegressor(n_estimators=100, random_state=0),
        # A classifier for T needs discrete_treatment=True so that econml uses
        # predicted probabilities, not hard 0/1 labels, as the propensity.
        model_t=GradientBoostingClassifier(n_estimators=100, random_state=0),
        discrete_treatment=True,
        n_estimators=200,
        min_samples_leaf=10,
        max_depth=None,
        random_state=42,
        verbose=0,
    )
    est.fit(Y, T, X=X, W=W)
    cate = est.effect(X)                          # in-sample point estimates
    lb, ub = est.effect_interval(X, alpha=0.05)   # pointwise 95% intervals
    return cate, lb, ub, est
```

Step 2: Evaluate CATE quality with AUUC, then summarize the targeting policy as a shallow tree.

```python
from __future__ import annotations

import numpy as np

# np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever exists.
_trapezoid = getattr(np, "trapezoid", None) or np.trapz


def auuc_score(
    cate_estimates: np.ndarray,
    Y: np.ndarray,
    T: np.ndarray,
) -> float:
    """Area under the uplift curve: how well CATE estimates rank targeting benefit.

    For the top-k units ranked by predicted CATE, the curve plots the treated-minus-
    control difference in mean outcomes. A good ranking gives large uplift at small k.
    Under random ranking the curve hovers around the overall ATE, so compare the score
    with that baseline, not with 0.5. The within-top-k difference in means is only
    unbiased if T was randomized in this evaluation sample.
    """
    order = np.argsort(-cate_estimates, kind="stable")   # highest predicted benefit first
    y = np.asarray(Y, dtype=float)[order]
    t = np.asarray(T)[order] == 1
    n = len(y)
    k = np.arange(1, n + 1)
    # Cumulative counts and sums give every top-k uplift in O(n) instead of O(n^2).
    n_t = np.cumsum(t)
    n_c = k - n_t
    sum_t = np.cumsum(np.where(t, y, 0.0))
    sum_c = np.cumsum(np.where(t, 0.0, y))
    with np.errstate(divide="ignore", invalid="ignore"):
        # Uplift is defined as 0 until both arms appear in the top k.
        uplift = np.where((n_t > 0) & (n_c > 0), sum_t / n_t - sum_c / n_c, 0.0)
    return float(_trapezoid(uplift, dx=1 / n))


def policy_tree_report(
    est,                        # fitted CausalForestDML
    X: np.ndarray,
    feature_names: list[str] | None = None,
    max_depth: int = 2,
    treatment_cost: float = 0.0,
) -> str:
    """Fit a shallow econml PolicyTree on CATE estimates and print its rules."""
    from econml.policy import PolicyTree
    cate = est.effect(X)
    # Reward per action: column 0 = control (0), column 1 = treat (CATE minus cost).
    rewards = np.column_stack([np.zeros_like(cate), cate - treatment_cost])
    policy = PolicyTree(max_depth=max_depth)
    policy.fit(X, rewards)
    tree = policy.tree_
    rules: list[str] = []

    def traverse(node_id: int, depth: int = 0) -> None:
        indent = "  " * depth
        if tree.children_left[node_id] == -1:          # leaf node
            # Assumes node values hold the mean reward per action (index 0 = control).
            best = int(np.argmax(np.ravel(tree.value[node_id])))
            rules.append(f"{indent}-> {'TREAT' if best == 1 else 'CONTROL'}")
            return
        f = tree.feature[node_id]
        feat = feature_names[f] if feature_names else f"X[{f}]"
        rules.append(f"{indent}if {feat} <= {tree.threshold[node_id]:.3f}:")
        traverse(tree.children_left[node_id], depth + 1)
        rules.append(f"{indent}else:")
        traverse(tree.children_right[node_id], depth + 1)

    traverse(0)
    return "\n".join(rules)
```

Analysis & Evaluation

Where your intuition breaks. CATE estimates look like point predictions: you get a number for each unit, and it feels concrete. But CATE estimates are extremely noisy at the individual level. The confidence interval for a single unit's CATE is often too wide to support individual-level decisions.

CATE estimates are more reliable for two things: ranking units (who benefits most vs. least) and estimating subgroup averages. The uplift curve and its area (AUUC) measure ranking quality rather than individual-level accuracy. That is also the only kind of quality you can check on real data. RMSE against the true CATE would require observing both potential outcomes for the same unit, which never happens. So AUUC, not RMSE, is the practical evaluation metric.
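A small simulation makes the contrast concrete. The function name and all numbers are hypothetical, and per-unit RMSE is computable here only because the simulation knows the true effects. The estimates carry noise three times larger than the spread of the true effects, yet ranking by them still concentrates benefit in the top decile.

```python
import numpy as np


def ranking_vs_rmse(n=20_000, signal_sd=0.05, noise_sd=0.15, seed=0):
    """Simulate noisy CATE estimates; compare per-unit error with ranking value."""
    rng = np.random.default_rng(seed)
    true_tau = rng.normal(0.05, signal_sd, n)           # true individual effects (simulation only)
    tau_hat = true_tau + rng.normal(0, noise_sd, n)     # estimates: noise 3x the signal spread
    rmse = float(np.sqrt(np.mean((tau_hat - true_tau) ** 2)))
    top = np.argsort(-tau_hat)[: n // 10]               # top decile by *estimated* effect
    return rmse, float(true_tau[top].mean()), float(true_tau.mean())


rmse, top_decile, overall = ranking_vs_rmse()
print(f"RMSE {rmse:.3f} | true effect: top decile {top_decile:.3f}, everyone {overall:.3f}")
```

In expectation the RMSE is about 0.15 and the top decile's true mean effect is roughly 0.08, against 0.05 overall. Individual estimates are poor, but the ranking is still useful for targeting.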

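| Meta-learner | Works well when | Weakness |
|---|---|---|
| S-learner | Effect heterogeneity is small | Can shrink the treatment effect toward zero |
| T-learner | Treatment is balanced and n is large | Extrapolates in non-overlap regions |
| X-learner | Treatment is unbalanced | Needs good outcome models in both arms |
| R-learner | Good nuisance models are available | More complex: needs cross-fitting; noisy where $\hat{e}(x)$ is near 0 or 1 |
| GRF/CausalForest | Any setting; pointwise asymptotic CIs needed | Requires the econml or grf package |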

Policy learning. Once you have CATE estimates, a simple policy treats every unit with $\hat{\tau}(x) > c$, where $c$ is the cost of treatment in outcome units. A PolicyTree turns this into interpretable rules (e.g., if age $> 35$ and income $> 50$k, treat). Estimate the policy's value on a held-out test set, not on the training data used to fit the CATE model and the policy.
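Estimating a policy's value on held-out data is itself a causal estimate, because each unit reveals only one potential outcome. Below is a minimal inverse-propensity-weighting (IPW) sketch. It assumes the held-out sample comes from a randomized experiment with a known treatment probability, and `ipw_policy_value` is a hypothetical helper. For observational data, a doubly robust score would be the usual substitute.

```python
import numpy as np


def ipw_policy_value(Y, T, treat, propensity=0.5, cost=0.0):
    """IPW estimate of the mean net outcome if exactly the units with treat == 1 are treated.

    Assumes T was randomized with known P(T = 1) = propensity in this held-out sample.
    """
    Y, T, treat = (np.asarray(a) for a in (Y, T, treat))
    p_match = np.where(treat == 1, propensity, 1 - propensity)  # P(T = policy action)
    match = T == treat                                          # units whose T agrees with the policy
    value = np.mean(match * Y / p_match)                        # estimates E[Y(policy(X))]
    return float(value - cost * np.mean(treat))                 # subtract cost per treated unit
```

A typical comparison scores the threshold rule `treat = (cate_test > c).astype(int)` against treat-all (`np.ones`) and treat-none (`np.zeros`) baselines on the same held-out rows.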

💡Intuition

CATE is not ATE for subgroups. The ATE in a predefined subgroup (e.g., women over 40) can be estimated directly by comparing treated and control units within that subgroup, provided treatment is randomized or confounders are adjusted for. A CATE estimate $\hat{\tau}(x)$ is the output of a model fitted across all units. It carries the usual biases of model-based estimation: regularization, extrapolation, and misspecification. Averaging $\hat{\tau}(x)$ over a subgroup agrees with the direct subgroup estimate when the model is well specified. Even so, CATE estimates from a single model should not be treated as ground truth for specific subgroups.
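A small simulation (all quantities hypothetical) shows how the two can disagree, under these assumptions:

* The true effect is U-shaped in age.
* Treatment is randomized, so the subgroup difference in means is valid.
* The CATE model is a deliberately misspecified T-learner that fits a straight line in age for each arm.

```python
import numpy as np


def subgroup_check(seed=0, n=20_000):
    """Direct subgroup ATE vs. the average of a misspecified model's CATE (hypothetical setup)."""
    rng = np.random.default_rng(seed)
    age = rng.uniform(20, 70, n)
    T = rng.binomial(1, 0.5, n)                  # randomized treatment
    tau = (age - 45) ** 2 / 250                  # true effect: U-shaped in age
    Y = 1.0 + 0.01 * age + T * tau + rng.normal(0, 1, n)

    # Misspecified T-learner: ordinary least squares on [1, age] in each arm.
    A = np.column_stack([np.ones(n), age])
    b1, *_ = np.linalg.lstsq(A[T == 1], Y[T == 1], rcond=None)
    b0, *_ = np.linalg.lstsq(A[T == 0], Y[T == 0], rcond=None)
    tau_hat = A @ (b1 - b0)

    g = age > 60                                 # subgroup of interest
    direct = Y[g & (T == 1)].mean() - Y[g & (T == 0)].mean()
    model_avg = tau_hat[g].mean()
    return direct, model_avg, tau[g].mean()      # last value: true subgroup ATE
```

By construction, the true ATE for ages above 60 is about 1.63, and the direct comparison recovers it. The linear model cannot bend, so its CATE stays near the population-wide mean effect of about 0.83 at every age, and so does its subgroup average.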

Production-Ready Code

```python
"""
CATE estimation production pipeline.
CausalForestDML fit on a train split, AUUC evaluation
on the held-out split, and scoring of new users.
"""

from __future__ import annotations
from dataclasses import dataclass
import numpy as np
import pandas as pd
from sklearn.ensemble import GradientBoostingRegressor, GradientBoostingClassifier
from sklearn.model_selection import train_test_split

# np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever exists.
_trapezoid = getattr(np, "trapezoid", None) or np.trapz


@dataclass
class CATEPipelineResult:
    cate_train: np.ndarray
    cate_test: np.ndarray
    ci_lower_test: np.ndarray
    ci_upper_test: np.ndarray
    auuc: float
    model: object   # fitted estimator for scoring


def cate_pipeline(
    Y: np.ndarray,
    T: np.ndarray,
    X: np.ndarray,
    test_size: float = 0.3,
    random_state: int = 42,
) -> CATEPipelineResult:
    """Fit CausalForestDML on a train split; evaluate CATE ranking on the test split."""
    from econml.dml import CausalForestDML
    X_tr, X_te, Y_tr, Y_te, T_tr, T_te = train_test_split(
        X, Y, T, test_size=test_size, random_state=random_state
    )
    est = CausalForestDML(
        model_y=GradientBoostingRegressor(n_estimators=100, random_state=0),
        model_t=GradientBoostingClassifier(n_estimators=100, random_state=0),
        discrete_treatment=True,   # required when model_t is a classifier
        n_estimators=200,
        min_samples_leaf=10,
        random_state=random_state,
        verbose=0,
    )
    est.fit(Y_tr, T_tr, X=X_tr)
    cate_train = est.effect(X_tr)
    cate_test = est.effect(X_te)
    lb, ub = est.effect_interval(X_te, alpha=0.05)
    # Held-out AUUC; only unbiased if T was randomized in the test split.
    auuc = _auuc(cate_test, Y_te, T_te)
    return CATEPipelineResult(
        cate_train=cate_train, cate_test=cate_test,
        ci_lower_test=lb, ci_upper_test=ub,
        auuc=auuc, model=est,
    )


def _auuc(cate: np.ndarray, Y: np.ndarray, T: np.ndarray) -> float:
    """Area under the uplift curve, computed with cumulative sums in O(n)."""
    order = np.argsort(-cate, kind="stable")
    y = np.asarray(Y, dtype=float)[order]
    t = np.asarray(T)[order] == 1
    n = len(y)
    n_t = np.cumsum(t)
    n_c = np.arange(1, n + 1) - n_t
    sum_t = np.cumsum(np.where(t, y, 0.0))
    sum_c = np.cumsum(np.where(t, 0.0, y))
    with np.errstate(divide="ignore", invalid="ignore"):
        uplift = np.where((n_t > 0) & (n_c > 0), sum_t / n_t - sum_c / n_c, 0.0)
    return float(_trapezoid(uplift, dx=1 / n))


def score_new_users(
    model,
    X_new: np.ndarray,
    treatment_cost: float = 0.0,
) -> pd.DataFrame:
    """Score new users and produce targeting decisions.

    Returns DataFrame with cate, ci_lower, ci_upper, treat columns.
    """
    cate = model.effect(X_new)
    lb, ub = model.effect_interval(X_new, alpha=0.05)
    return pd.DataFrame({
        'cate': cate.round(6),
        'ci_lower': lb.round(6),
        'ci_upper': ub.round(6),
        # Treat when the point estimate exceeds the per-unit cost (in outcome units).
        'treat': (cate > treatment_cost).astype(int),
    })
```
