
Staggered DiD & Modern Estimators

Most real-world policy rollouts don't happen all at once. A product feature rolls out to different regions in different quarters. A labor law passes in some states in 2018, others in 2020. This staggered adoption is ubiquitous — and it quietly breaks the most common econometric tool for policy evaluation. Two-way fixed effects (TWFE) regression, the workhorse of applied causal inference, produces biased estimates under staggered treatment adoption. Understanding why, and how to fix it, is essential for modern causal analysis.

Theory

Staggered treatment adoption across units

[Figure: adoption timeline over twelve periods (W1-W12) for eight units (two early adopters, two mid adopters, two late adopters, two never-treated), with pre-treatment and treated periods shaded and each cohort's adoption time marked.]

The TWFE estimator. The standard TWFE DiD regression is:

Y_{it} = \alpha_i + \lambda_t + \theta D_{it} + \varepsilon_{it}

where α_i are unit fixed effects, λ_t are time fixed effects, and D_it is a binary treatment indicator. This looks reasonable — but under staggered adoption, θ̂ is a weighted average of 2×2 DiD comparisons with potentially negative weights.

Why it had to be this way. Goodman-Bacon (2021) showed that the TWFE estimator decomposes into a weighted sum of all possible 2×2 DiD comparisons. The "forbidden comparison" is the problem: early adopters (already treated) are used as implicit controls for late adopters. If the treatment effect is heterogeneous over time (treatment effects grow after adoption), using already-treated units as controls produces negatively weighted comparisons that can flip the sign of the estimate.
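The sign of this bias is easy to reproduce. Below is a minimal synthetic example (all numbers made up): two cohorts, no never-treated units, no noise, and effects that grow by 1 per period after adoption. The TWFE coefficient lands well below the true average effect among treated observations:

```python
import pandas as pd
import statsmodels.formula.api as smf

# Synthetic panel: early cohort adopts at t=3, late cohort at t=7, no
# never-treated units; effects grow by 1 per period after adoption.
rows = []
for unit in range(10):
    g = 3 if unit < 5 else 7
    for t in range(10):
        effect = (t - g + 1) if t >= g else 0
        rows.append({'unit': unit, 'time': t, 'treated': int(t >= g),
                     'outcome': 1.0 * unit + 0.5 * t + effect})
df = pd.DataFrame(rows)

# True average effect over treated observations vs. the TWFE coefficient.
true_att = df.loc[df['treated'] == 1].eval('outcome - (1.0 * unit + 0.5 * time)').mean()
twfe = smf.ols('outcome ~ treated + C(unit) + C(time)', data=df).fit().params['treated']
print(f'true ATT: {true_att:.2f}, TWFE: {twfe:.2f}')
```

Every treated observation here has a strictly positive effect, yet the TWFE estimate is dragged down by the negatively weighted late-vs-already-treated comparison.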

The negative weights problem. The TWFE weight on a comparison between cohort g (adopts at g) and cohort g' (adopts at g') in the time period [g', T] is proportional to:

\omega_{gg'} \propto n_g n_{g'} \cdot \frac{g' - g}{T} \cdot \left(1 - \frac{g' - g}{T}\right)

These weights sum to 1, but some can be negative when the estimand requires comparing cohort g (which is already treated) to cohort g' (just now adopting) using g's treated-period outcomes as the "control."
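Plugging hypothetical numbers into the timing component of this formula makes it concrete (cohort dates and sizes below are made up):

```python
# Timing component of the TWFE weight formula, for a 12-period panel with
# cohorts adopting at g = 2 and g' = 6 (hypothetical sizes of 5 units each).
T = 12
g, g_prime = 2, 6
n_g, n_gp = 5, 5
timing = (g_prime - g) / T * (1 - (g_prime - g) / T)
weight = n_g * n_gp * timing
print(round(timing, 4), round(weight, 3))  # 0.2222 5.556
```

The timing term peaks when the adoption gap is half the panel length, so mid-panel cohort pairs get the most weight, regardless of whether they are clean comparisons.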

Callaway-Sant'Anna (2021). Define group-time ATTs: ATT(g, t) = E[Y_t(1) − Y_t(0) | G = g], where G = g means the unit adopted treatment at time g. Estimate each ATT(g, t) using only clean comparisons (never-treated or not-yet-treated as controls). Aggregate to a single ATT using user-specified weights.

Sun-Abraham (2021). Estimate a fully saturated event study with cohort-time interactions. The coefficient on the interaction "cohort g × event time ℓ" is the ATT for cohort g at event time ℓ, free of contamination.

Stacked DiD. Restrict each 2×2 comparison to clean control groups (never-treated or not-yet-treated units only). Stack the clean datasets and run TWFE within the stacked dataset.

Walkthrough

Scenario: a software company rolls out a new feature to different regions in Q1, Q2, and Q3; two regions never receive the feature. We want to estimate the effect on retention.
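Before estimating anything, it helps to have a panel in hand. A minimal synthetic version of this scenario (region names, baselines, and effect sizes are all hypothetical):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
# Hypothetical rollout: six regions adopt in quarter 1, 2, or 3; two never do.
cohort_map = {'R1': 1, 'R2': 1, 'R3': 2, 'R4': 2,
              'R5': 3, 'R6': 3, 'R7': None, 'R8': None}
rows = []
for region, g in cohort_map.items():
    base = rng.normal(70, 5)  # region-level baseline retention
    for t in range(1, 9):     # eight quarters of data
        treated = int(g is not None and t >= g)
        effect = 2.0 * (t - g + 1) if treated else 0.0
        rows.append({'unit': region, 'time': t, 'cohort': g, 'treated': treated,
                     'outcome': base + 0.3 * t + effect + rng.normal(0, 1)})
panel = pd.DataFrame(rows)
```

The `cohort` column is NaN for never-treated regions, matching the convention the estimator functions in this walkthrough expect.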

Step 1: Estimate the naive TWFE benchmark for comparison.

python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
 
def twfe_did(df: pd.DataFrame) -> dict:
    """Estimate TWFE DiD and return coefficient + SE."""
    # df must have columns: unit, time, outcome, treated (0/1)
    model = smf.ols('outcome ~ treated + C(unit) + C(time)', data=df).fit()
    coef = float(model.params['treated'])
    se = float(model.bse['treated'])
    return {'twfe_ate': round(coef, 6), 'se': round(se, 6)}

Step 2: Callaway-Sant'Anna estimator.

python
def callaway_santanna(
    df: pd.DataFrame,
    outcome_col: str = 'outcome',
    unit_col: str = 'unit',
    time_col: str = 'time',
    cohort_col: str = 'cohort',   # adoption time (NaN = never treated)
) -> pd.DataFrame:
    """Estimate group-time ATTs using not-yet-treated as controls.
 
    Returns DataFrame with columns: cohort, time, att, se.
    """
    rows = []
    cohorts = sorted(df[cohort_col].dropna().unique())
    T_all = sorted(df[time_col].unique())
 
    for g in cohorts:
        treated_units = df[df[cohort_col] == g][unit_col].unique()
        for t in T_all:
            if t < g - 1:
                continue  # only estimate from one period before adoption onwards
            # Clean controls: never-treated or not-yet-treated at t,
            # excluding cohort g itself (relevant when t = g - 1).
            control_units = df[
                (df[cohort_col].isna()) |
                ((df[cohort_col] > t) & (df[cohort_col] != g))
            ][unit_col].unique()
 
            base_t = g - 1
            y_treat_t = df[(df[unit_col].isin(treated_units)) & (df[time_col] == t)][outcome_col].mean()
            y_treat_base = df[(df[unit_col].isin(treated_units)) & (df[time_col] == base_t)][outcome_col].mean()
            y_ctrl_t = df[(df[unit_col].isin(control_units)) & (df[time_col] == t)][outcome_col].mean()
            y_ctrl_base = df[(df[unit_col].isin(control_units)) & (df[time_col] == base_t)][outcome_col].mean()
 
            att = (y_treat_t - y_treat_base) - (y_ctrl_t - y_ctrl_base)
            n_treat = len(treated_units)
            n_ctrl = len(control_units)
            # Crude plug-in SE: ignores serial correlation within units;
            # use a (multiplier) bootstrap in production.
            se_approx = np.sqrt(
                df[(df[unit_col].isin(treated_units)) & (df[time_col].isin([t, base_t]))][outcome_col].var() / n_treat +
                df[(df[unit_col].isin(control_units)) & (df[time_col].isin([t, base_t]))][outcome_col].var() / n_ctrl
            )
            rows.append({'cohort': g, 'time': t, 'att': round(att, 6), 'se': round(se_approx, 6)})
 
    return pd.DataFrame(rows)

Step 3: Aggregate ATT.

python
def aggregate_att(att_df: pd.DataFrame) -> dict:
    """Aggregate group-time ATTs to a single ATT (equal weight across cohort-times)."""
    post_att = att_df[att_df['time'] >= att_df['cohort']]  # drop placebo pre-periods
    overall_att = float(post_att['att'].mean())
    # Naive SE treating the ATT(g, t) estimates as independent; bootstrap in practice.
    overall_se = float(post_att['se'].mean() / np.sqrt(len(post_att)))
    return {
        'overall_att': round(overall_att, 6),
        'se': round(overall_se, 6),
        'n_cohort_times': len(post_att),
    }
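Besides a single overall number, the group-time ATTs can be aggregated into a dynamic (event-study) profile by averaging along event time. A sketch with hypothetical ATT values:

```python
import pandas as pd

# Hypothetical group-time ATTs (two cohorts adopting at t=2 and t=3).
att_df = pd.DataFrame({
    'cohort': [2, 2, 2, 3, 3],
    'time':   [2, 3, 4, 3, 4],
    'att':    [1.0, 2.0, 3.0, 1.2, 2.2],
})
att_df['event_time'] = att_df['time'] - att_df['cohort']
dynamic = att_df.groupby('event_time')['att'].mean()
print(dynamic)  # e=0 averages cohorts 2 and 3: (1.0 + 1.2) / 2 = 1.1
```

The event-time profile is usually more informative than the overall ATT when effects grow or fade after adoption.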

Analysis & Evaluation

Where your intuition breaks. A statistically significant TWFE coefficient feels like strong evidence. Under staggered adoption with heterogeneous treatment effects, it is a weighted average of 2×2 comparisons in which some weights are negative, which can pull the coefficient below the true average effect — or even flip its sign. The telltale sign: run a Goodman-Bacon decomposition and check for 2×2 comparisons with negative weights and large magnitude. If the negatively weighted comparisons dominate, TWFE is unreliable.

| Estimator | Assumption | Handles het. TE | Computational cost |
| --- | --- | --- | --- |
| TWFE | Homogeneous treatment effects | No | Low |
| Callaway-Sant'Anna | Parallel trends per cohort | Yes | Medium |
| Sun-Abraham | Parallel trends, saturated | Yes | Low |
| Stacked DiD | Parallel trends, clean controls | Yes | Medium |
| Borusyak-Jaravel-Spiess | Parallel trends, imputation | Yes | Medium |

Pre-testing parallel trends. Under staggered adoption, the standard pre-trend test (plotting event-study coefficients for pre-adoption periods) requires a heterogeneity-robust estimator, not TWFE. Estimate Sun-Abraham event-study coefficients instead; if the pre-period coefficients are near zero, parallel trends is plausible.
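Given an event-study table like the one `sun_abraham_event_study` returns, a simple screening rule flags pre-period coefficients whose 95% intervals exclude zero (the numbers below are hypothetical):

```python
import pandas as pd

# Hypothetical event-study output: att/se per event time, base period omitted.
es = pd.DataFrame({
    'event_time': [-3, -2, 0, 1],
    'att':        [0.05, -0.02, 1.10, 2.00],
    'se':         [0.10, 0.09, 0.15, 0.20],
})
pre = es[es['event_time'] < 0]
# Flag pre-period coefficients whose 95% CI excludes zero.
violations = pre[pre['att'].abs() - 1.96 * pre['se'] > 0]
print(f'pre-trend violations: {len(violations)}')  # 0 -> parallel trends plausible
```

A clean screen is necessary but not sufficient: pre-trend tests can be underpowered, so treat a pass as plausibility evidence, not proof.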

⚠️Warning

The TWFE negative weights problem is not hypothetical. Multiple empirical replication studies have found that TWFE gave the wrong sign on the treatment effect in published papers. Always run a Goodman-Bacon decomposition or Callaway-Sant'Anna estimator as a robustness check when treatment adoption is staggered.

Production-Ready Code

python
"""
Staggered DiD production pipeline.
TWFE, Callaway-Sant'Anna ATTs, Sun-Abraham event study,
and stacked DiD with clean comparison sets.
"""
 
from __future__ import annotations
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
from scipy.stats import norm
 
 
def check_negative_twfe_weights(
    df: pd.DataFrame,
    unit_col: str = 'unit',
    time_col: str = 'time',
    cohort_col: str = 'cohort',
) -> dict:
    """Diagnose potential negative TWFE weights via Goodman-Bacon decomposition.
 
    Returns fraction of variation explained by "bad" (already-treated) comparisons.
    """
    cohorts = sorted(df[cohort_col].dropna().unique())
    T_all = sorted(df[time_col].unique())
    T = len(T_all)
    bad_share = 0.0
    total_share = 0.0
 
    for i, g in enumerate(cohorts):
        for g_prime in cohorts[i+1:]:
            n_g = df[df[cohort_col] == g][unit_col].nunique()
            n_gp = df[df[cohort_col] == g_prime][unit_col].nunique()
            # Timing component of the weight formula from the theory section.
            timing_weight = (g_prime - g) / T * (1 - (g_prime - g) / T)
            size_weight = n_g * n_gp * timing_weight
            total_share += abs(size_weight)
            # Heuristic: the later cohort's post-period share proxies for how
            # much of this comparison uses already-treated units as controls.
            bad_share += abs(size_weight) * ((T - g_prime) / T)
 
    return {
        'bad_comparison_share': round(bad_share / (total_share + 1e-9), 3),
        'warning': (
            "High TWFE contamination: Callaway-Sant'Anna or Sun-Abraham recommended"
            if bad_share / (total_share + 1e-9) > 0.3 else None
        ),
    }
 
 
def sun_abraham_event_study(
    df: pd.DataFrame,
    outcome_col: str = 'outcome',
    unit_col: str = 'unit',
    time_col: str = 'time',
    cohort_col: str = 'cohort',
    n_pre: int = 4,
    n_post: int = 6,
) -> pd.DataFrame:
    """Saturated event study with cohort x event-time interactions (Sun-Abraham 2021).
 
    Returns DataFrame with columns: event_time, att, se, ci_lower, ci_upper.
    """
    df = df.copy()
    df['event_time'] = df[time_col] - df[cohort_col]
    # Keep never-treated units (the comparison group) plus treated units
    # inside the event window; event time -1 is the omitted base period.
    in_window = df['event_time'].between(-n_pre, n_post) & (df['event_time'] != -1)
    df = df[df[cohort_col].isna() | in_window]

    # Never-treated rows form the reference category of the interaction.
    df['interaction'] = np.where(
        df[cohort_col].isna(),
        'control',
        'c' + df[cohort_col].fillna(0).astype(int).astype(str)
        + '_e' + df['event_time'].fillna(0).astype(int).astype(str),
    )

    formula = (
        f'{outcome_col} ~ C(interaction, Treatment("control")) '
        f'+ C({unit_col}) + C({time_col})'
    )
    model = smf.ols(formula, data=df).fit()
 
    rows = []
    for et in range(-n_pre, n_post + 1):
        if et == -1:
            rows.append({'event_time': et, 'att': 0.0, 'se': 0.0, 'ci_lower': 0.0, 'ci_upper': 0.0})
            continue
        # Patsy parameter names end with the event-time suffix, e.g. "...c3_e2]".
        matching_params = [k for k in model.params.index if k.endswith(f'_e{et}]')]
        if not matching_params:
            continue
        # Equal-weight average across cohorts; Sun-Abraham proper weights by
        # cohort shares, so treat this as a simplification.
        att = float(model.params[matching_params].mean())
        se = float(model.bse[matching_params].mean())
        z = norm.ppf(0.975)
        rows.append({
            'event_time': et,
            'att': round(att, 6),
            'se': round(se, 6),
            'ci_lower': round(att - z * se, 6),
            'ci_upper': round(att + z * se, 6),
        })
 
    return pd.DataFrame(rows).sort_values('event_time')
 
 
def stacked_did(
    df: pd.DataFrame,
    outcome_col: str = 'outcome',
    unit_col: str = 'unit',
    time_col: str = 'time',
    cohort_col: str = 'cohort',
    n_pre: int = 3,
    n_post: int = 5,
) -> dict:
    """Stacked DiD: restrict each cohort comparison to clean controls only."""
    cohorts = sorted(df[cohort_col].dropna().unique())
    stacked_frames = []
 
    for g in cohorts:
        treated = df[df[cohort_col] == g].copy()
        clean_ctrl = df[
            df[cohort_col].isna() | (df[cohort_col] > g + n_post)
        ].copy()
        window = range(int(g) - n_pre, int(g) + n_post + 1)
        treated = treated[treated[time_col].isin(window)].copy()
        clean_ctrl = clean_ctrl[clean_ctrl[time_col].isin(window)].copy()
        treated['stack_cohort'] = g
        clean_ctrl['stack_cohort'] = g
        treated['post'] = (treated[time_col] >= g).astype(int)
        clean_ctrl['post'] = (clean_ctrl[time_col] >= g).astype(int)
        treated['treated_unit'] = 1
        clean_ctrl['treated_unit'] = 0
        stacked_frames.extend([treated, clean_ctrl])
 
    stacked = pd.concat(stacked_frames, ignore_index=True)
    stacked['did'] = stacked['treated_unit'] * stacked['post']
    # Stack-specific unit and time fixed effects; cluster SEs by unit because
    # the same unit can appear in several stacks.
    model = smf.ols(
        f'{outcome_col} ~ did + C({unit_col}):C(stack_cohort) + C({time_col}):C(stack_cohort)',
        data=stacked,
    ).fit(cov_type='cluster', cov_kwds={'groups': stacked[unit_col]})
 
    return {
        'att': round(float(model.params['did']), 6),
        'se': round(float(model.bse['did']), 6),
        'pvalue': round(float(model.pvalues['did']), 4),
        'n_stacked_obs': len(stacked),
    }
