
Training Data at Scale

Training a large model is only possible if data reaches the GPU fast enough to keep it busy. The bottleneck is almost never the GPU — it's the data pipeline: reading from disk, decoding images or text, applying augmentations, and delivering batches fast enough to keep compute saturated. Fixing data loading throughput is often higher ROI than buying faster hardware.

How It Works

Data loading strategies and GPU utilization

Background workers preload the next batch while the GPU trains on the current one, overlapping I/O with compute. With good overlap, GPU utilization reaches roughly 85%, though some idle time remains between batches.

At scale the cost is real: an A100 at $3/hr running at 30% utilization wastes roughly $1,500/month in idle compute per GPU, or about $18k/month across a dozen GPUs.

There are three broad strategies. In the naive case, the GPU trains for a short burst and then idles while the CPU loads the next batch. Prefetching overlaps data loading with training — background workers prepare the next batch while the current one is training. Streaming (WebDataset) removes disk entirely, pulling data directly from object storage at network speed.

GPU training runs in two alternating phases: compute (the GPU processes a batch) and I/O (the CPU loads the next batch). If I/O takes longer than compute, the GPU sits idle. The entire goal of data pipeline optimization is to make I/O time invisible by overlapping it with compute — the GPU should never wait. This is not an optimization problem; it is a scheduling problem, and the solution is always some form of prefetching.
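The effect of overlap can be seen with a back-of-the-envelope timing model — a sketch, not a profiler; the 80 ms / 60 ms figures are made-up illustrative numbers:

```python
def step_time(compute_s: float, io_s: float, prefetch: bool) -> float:
    """Wall-clock time of one training step under a simple two-phase model."""
    if prefetch:
        # I/O for the next batch runs concurrently with compute on the current
        # one, so the step takes as long as the slower of the two phases.
        return max(compute_s, io_s)
    # Without prefetching the two phases run back to back.
    return compute_s + io_s

def gpu_utilization(compute_s: float, io_s: float, prefetch: bool) -> float:
    """Fraction of the step during which the GPU is actually computing."""
    return compute_s / step_time(compute_s, io_s, prefetch)

# Example: 80 ms of compute, 60 ms of data loading per batch.
print(gpu_utilization(0.08, 0.06, prefetch=False))  # ~0.57: GPU idles 43% of the time
print(gpu_utilization(0.08, 0.06, prefetch=True))   # 1.0: I/O fully hidden
```

As long as I/O per batch is shorter than compute per batch, prefetching hides it completely; once I/O exceeds compute, the pipeline is loader-bound no matter how deep the prefetch queue is.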

The GPU utilization problem

GPU time costs money. An A100 at $3/hr running at 30% utilization is paying for 70% idle time. At scale across dozens of GPUs and weeks of training, data pipeline inefficiency is a six-figure cost.

DataLoader parallelism

PyTorch DataLoader controls the number of background workers that preload data:

python
from torch.utils.data import DataLoader
 
loader = DataLoader(
    dataset,
    batch_size=256,
    num_workers=8,            # parallel CPU workers
    prefetch_factor=2,        # batches to prefetch per worker
    pin_memory=True,          # page-locked memory for faster GPU transfer
    persistent_workers=True,  # avoid worker restart overhead
)

num_workers=0 (default) runs loading in the main process. Set num_workers to 4–16 depending on CPU core count and I/O bandwidth. pin_memory=True enables faster CPU→GPU transfer via DMA.

WebDataset for object storage streaming

For large-scale training, local disk doesn't scale. WebDataset treats data as a stream of tar shards stored in S3 or GCS:

python
import webdataset as wds
 
dataset = (
    wds.WebDataset("s3://my-bucket/train-{000000..002999}.tar")
    .shuffle(1000)          # shuffle within a buffer
    .decode("torchrgb")     # decode images to tensors
    .to_tuple("jpg", "cls") # extract image and label
    .batched(256)
)
 
loader = DataLoader(dataset, batch_size=None, num_workers=8)  # batching already done by .batched()

Shards are fetched in parallel. As one shard is consumed, the next is already downloading. No local disk required.

Data format and throughput

| Format | Throughput | Random access | Use case |
| --- | --- | --- | --- |
| Raw JPEG/PNG | Low (decode overhead) | Yes | Prototyping; decode-bound at scale |
| Parquet | Medium | Column-level yes, row-level partial | Tabular/NLP data |
| tar shards (WebDataset) | High | No | Large-scale streaming |
| Memory-mapped numpy | Very high | Yes | Pre-tokenized text |

For LLM training, pre-tokenized data stored as memory-mapped numpy arrays (np.memmap) gives near-zero loading overhead — the OS maps the file into address space and pages it in on demand.

Memory-mapped files became the format of choice for LLM training because the alternative — loading multi-terabyte tokenized datasets into RAM — is impossible on any practical machine. np.memmap works because modern OS virtual memory systems are better at this than application-level caching: the OS knows which pages are actively being accessed, can prefetch them based on access patterns, and can evict cold pages under memory pressure without any explicit application logic. The training loop just reads array slices; the OS transparently makes that work against a file that may be larger than available RAM.
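A minimal sketch of the pattern — the file name, dtype, and batch geometry here are illustrative assumptions, and the random tokens stand in for a real tokenized corpus:

```python
import numpy as np

# Offline: write the tokenized corpus as a flat binary file.
# uint16 is enough for vocabularies under 65,536 tokens.
tokens = np.random.randint(0, 50_000, size=1_000_000, dtype=np.uint16)  # stand-in corpus
tokens.tofile("train.bin")

# Training: map the file instead of loading it; pages fault in on demand.
data = np.memmap("train.bin", dtype=np.uint16, mode="r")

def get_batch(step: int, batch_size: int = 8, seq_len: int = 1024) -> np.ndarray:
    """Slice contiguous sequences out of the mapped array — no explicit reads."""
    span = batch_size * seq_len
    start = (step * span) % (len(data) - span)
    return np.asarray(data[start : start + span]).reshape(batch_size, seq_len)

print(get_batch(0).shape)  # (8, 1024)
```

The loader does no decoding and no copying beyond the final `np.asarray`; throughput is limited only by the page cache and the storage device underneath it.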

Design Tradeoffs

Where Your Intuition Breaks

The instinct when training is slow is to get faster GPUs. GPU utilization profiling almost always reveals that GPUs are not the bottleneck — they're already idle 40–60% of the time waiting for data. A faster GPU makes the compute phase shorter, which makes the idle fraction larger, not smaller. The correct first step is always to measure utilization before spending on hardware: if nvidia-smi shows GPU utilization below 80% during training, the data pipeline is the bottleneck and must be fixed before hardware upgrades provide any benefit. This is especially common when migrating to larger models or higher resolutions — the compute cost per batch increases, revealing data loading bottlenecks that were previously hidden.

Shuffling at scale

Training requires shuffled data to avoid gradient biases. Global shuffle is expensive at scale:

  • Offline shuffle: pre-shuffle data before training. Eliminates runtime overhead but requires re-shuffling between epochs.
  • Buffer shuffle: wds.shuffle(1000) — shuffle within a sliding window. Fast but not globally random. Adequate for most tasks.
  • Epoch-level shard randomization: randomize shard order each epoch for approximate global randomness.
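To make the buffer-shuffle tradeoff concrete, here is a minimal implementation of the idea behind `wds.shuffle` — a sketch of the technique, not the library's actual code:

```python
import random
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

def buffer_shuffle(stream: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Yield items in pseudo-random order using a fixed-size buffer.

    Each incoming item evicts a random occupant of the buffer, so an item can
    only move a bounded distance from its original position: fast and
    memory-cheap, but not a global shuffle.
    """
    rng = random.Random(seed)
    buf: list = []
    for item in stream:
        if len(buf) < buffer_size:
            buf.append(item)
            continue
        idx = rng.randrange(buffer_size)
        yield buf[idx]
        buf[idx] = item
    rng.shuffle(buf)   # drain what's left in random order
    yield from buf

out = list(buffer_shuffle(range(10), buffer_size=4, seed=42))
print(sorted(out) == list(range(10)))  # True: same items, different order
```

A larger buffer gives better randomness at the cost of memory and startup latency, which is why combining a modest buffer with epoch-level shard randomization is usually the practical compromise.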

For LLM pre-training, data is typically shuffled once before tokenization and packed into fixed-length sequences — tokens from different documents are concatenated with separator tokens.
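Sequence packing can be sketched in a few lines — the separator id `EOS = 0` is an assumption; real pipelines use the tokenizer's actual end-of-document token:

```python
import numpy as np

EOS = 0  # assumed separator / end-of-document token id

def pack_sequences(docs: list, seq_len: int) -> np.ndarray:
    """Concatenate tokenized documents with EOS separators, then cut the
    stream into fixed-length rows; a trailing remainder that cannot fill a
    full row is dropped."""
    stream: list = []
    for doc in docs:
        stream.extend(doc)
        stream.append(EOS)
    n_rows = len(stream) // seq_len
    return np.asarray(stream[: n_rows * seq_len], dtype=np.int64).reshape(n_rows, seq_len)

docs = [[5, 6, 7], [8, 9], [10, 11, 12, 13]]
packed = pack_sequences(docs, seq_len=4)
print(packed)
# [[ 5  6  7  0]
#  [ 8  9  0 10]
#  [11 12 13  0]]
```

Note that a document can straddle a row boundary (the `10` above); the attention mask or loss masking, not the packing step, decides whether tokens attend across document boundaries.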

Augmentation: CPU vs GPU

Heavy augmentation (RandomResizedCrop, ColorJitter) on CPU can be a bottleneck. NVIDIA DALI moves image augmentation to the GPU, freeing the CPU entirely. Effective when augmentation, not raw I/O, is the bottleneck.

For text, pre-tokenize offline and store token IDs — avoid re-tokenizing at training time.

Distributed data loading

With multiple GPUs (DDP), each GPU needs its own shard of data. DistributedSampler partitions across ranks:

python
from torch.utils.data.distributed import DistributedSampler
 
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
loader  = DataLoader(dataset, batch_size=256, sampler=sampler)

Global effective batch size = batch_size × world_size. Gradient accumulation increases effective batch size without increasing memory per GPU.
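A minimal sketch of gradient accumulation — the tiny `nn.Linear` model, the learning rate, and the batch shapes are placeholders, not a recommended configuration:

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 2)                       # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
accum_steps = 4                                # effective batch = 4 × micro-batch

updates = 0
optimizer.zero_grad()
for step in range(8):
    inputs = torch.randn(32, 16)               # micro-batch of 32 samples
    targets = torch.randint(0, 2, (32,))
    loss = criterion(model(inputs), targets)
    (loss / accum_steps).backward()            # scale so accumulated grads average
    if (step + 1) % accum_steps == 0:
        optimizer.step()                       # one update per 128 effective samples
        optimizer.zero_grad()
        updates += 1

print(updates)  # 2 optimizer updates over 8 micro-batches
```

Memory per GPU stays at the micro-batch size; only the optimizer update frequency changes, which is why accumulation composes cleanly with DDP's per-rank sharding.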

In Practice

Profiling the data pipeline

Before optimizing, measure. Look for DataLoader.__iter__ taking more time than model.forward():

python
import torch.profiler as profiler
 
with profiler.profile(activities=[profiler.ProfilerActivity.CPU, profiler.ProfilerActivity.CUDA]) as prof:
    for i, batch in enumerate(loader):
        model(batch)
        if i == 5:
            break
 
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

Estimating required throughput

Target throughput = tokens per step × steps per second. For a 7B LLM training on 100B tokens with batch size 1M tokens and 0.5 steps/sec:

  • Compute-bound training time: 100B / (1M × 0.5) = 200,000 s ≈ 2.3 days
  • If the dataloader caps at 250k tokens/sec, the run stretches to 4.6 days, doubling cost

Pre-compute this before starting a long run.
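The calculation above generalizes to a one-liner worth keeping in a launch checklist — a sketch; the figures fed in are the example numbers from this section:

```python
def training_days(total_tokens: float, tokens_per_step: float,
                  steps_per_sec: float, loader_tokens_per_sec: float) -> float:
    """Wall-clock days for a run, accounting for a dataloader throughput cap."""
    compute_rate = tokens_per_step * steps_per_sec          # tokens/sec GPUs can consume
    effective_rate = min(compute_rate, loader_tokens_per_sec)
    return total_tokens / effective_rate / 86_400           # seconds → days

print(round(training_days(100e9, 1e6, 0.5, 500e3), 1))  # 2.3 (loader keeps up)
print(round(training_days(100e9, 1e6, 0.5, 250e3), 1))  # 4.6 (loader-bound, 2x cost)
```

If `effective_rate` comes out below `compute_rate`, fix the pipeline before launching; every percentage point of shortfall is paid for in GPU-hours for the entire run.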

Checkpointing and resumability

Long training runs fail. Data pipelines must support deterministic resumption:

  • Save the dataset shard index and shuffle seed alongside model weights
  • On resume, reconstruct the identical data order from that point
  • WebDataset supports shard-level resumption by seeking to the correct shard offset

Without deterministic data order, a resumed run might re-train on already-seen data (causing overfitting) or skip data entirely.
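One way to implement this — a sketch with assumed field names (`shard_idx`, `sample_idx`, `shuffle_seed`), not a standard checkpoint schema:

```python
import random
import torch

def save_checkpoint(path, model, optimizer, epoch, shard_idx, sample_idx, shuffle_seed):
    """Persist the data-pipeline position alongside model and optimizer state."""
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "data_state": {
            "epoch": epoch,
            "shard_idx": shard_idx,        # which shard was being consumed
            "sample_idx": sample_idx,      # offset within that shard
            "shuffle_seed": shuffle_seed,  # reproduces the shard/sample order
        },
    }, path)

def resume_shard_order(shards, epoch, shuffle_seed):
    """Rebuild the exact shard order for a given epoch from the saved seed."""
    order = list(shards)
    random.Random(shuffle_seed + epoch).shuffle(order)
    return order

# Same seed + epoch always reproduces the same order, so a resumed run can
# skip to (shard_idx, sample_idx) and continue on identical data:
a = resume_shard_order(["s0", "s1", "s2", "s3"], epoch=2, shuffle_seed=1234)
b = resume_shard_order(["s0", "s1", "s2", "s3"], epoch=2, shuffle_seed=1234)
print(a == b)  # True
```

The key design choice is deriving all randomness from a saved seed rather than global RNG state, so the order is reconstructible from the checkpoint alone.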

Production Patterns

WebDataset loader for S3-backed shards

Store training data as tar shards in S3. Each shard holds ~500MB and contains paired samples (image + label, or text + metadata). The loader streams shards in parallel with no local disk required:

python
import webdataset as wds
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as T
 
SHARD_URL = "s3://my-bucket/imagenet/train-{000000..001281}.tar"
 
transform = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
 
dataset = (
    wds.WebDataset(
        SHARD_URL,
        resampled=True,           # infinite streaming: reshuffle shards each epoch
        shardshuffle=True,        # randomize shard fetch order
        cache_dir="/tmp/wds_cache",  # optional: cache recently-used shards locally
        nodesplitter=wds.split_by_node,  # DDP: each node gets disjoint shards
    )
    .shuffle(2000)                     # in-memory buffer shuffle across shard boundaries
    .decode("torchrgb")                # JPEG → torch.Tensor
    .to_tuple("jpg", "cls")
    .map_tuple(transform, lambda x: x)
    .batched(256, partial=False)       # drop incomplete final batch
)
 
loader = DataLoader(
    dataset,
    batch_size=None,       # batching handled by .batched() above
    num_workers=8,
    pin_memory=True,
    prefetch_factor=3,
)

For text data, replace .decode("torchrgb") with a custom decode step that reads pre-tokenized numpy arrays directly from the tar entries.
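A sketch of such a decode step — the `"npy"` key is an assumption about how the shards were written, and the standalone demo below substitutes for a real tar entry:

```python
import io
import numpy as np

def decode_tokens(sample: dict) -> dict:
    """Replace raw .npy bytes in a WebDataset sample with the decoded array.

    Assumes each tar entry stores pre-tokenized ids as a saved numpy array
    under the 'npy' key; the key name depends on the shard-writing pipeline.
    """
    out = dict(sample)
    if "npy" in out:
        out["npy"] = np.load(io.BytesIO(out["npy"]))
    return out

# With WebDataset this would be attached via .map(decode_tokens) in place of
# .decode("torchrgb"). Standalone demonstration with fabricated bytes:
buf = io.BytesIO()
np.save(buf, np.arange(8, dtype=np.uint16))
sample = {"__key__": "doc-000", "npy": buf.getvalue()}
decoded = decode_tokens(sample)
print(decoded["npy"])  # [0 1 2 3 4 5 6 7]
```

Because the arrays are already tokenized, the "decode" is just deserialization; no tokenizer runs at training time.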

Stratified sampling for imbalanced classes

Class imbalance is common in fraud detection, medical imaging, and click-through prediction. Two approaches: SQL-level downsampling before training, and PyTorch WeightedRandomSampler at load time.

SQL downsampling (offline, during dataset export):

sql
-- Keep all positives; downsample negatives to a 10:1 ratio (BigQuery syntax).
-- LIMIT cannot take a subquery, so rank the negatives and filter with a
-- scalar subquery instead.
WITH positives AS (
    SELECT * FROM training_events WHERE label = 1
),
negatives AS (
    SELECT *,
        -- deterministic pseudo-random order: reproducible given the fixed salt '42'
        ROW_NUMBER() OVER (
            ORDER BY FARM_FINGERPRINT(CONCAT(CAST(event_id AS STRING), '42'))
        ) AS rn
    FROM training_events
    WHERE label = 0
)
SELECT * FROM positives
UNION ALL
SELECT * EXCEPT (rn) FROM negatives
WHERE rn <= (SELECT COUNT(*) * 10 FROM positives)

PyTorch WeightedRandomSampler (online, no disk overhead):

python
import torch
from torch.utils.data import WeightedRandomSampler
 
labels = torch.tensor(dataset.labels)          # 0 or 1 per sample
class_counts = torch.bincount(labels)          # [n_neg, n_pos]
class_weights = 1.0 / class_counts.float()    # inverse frequency
sample_weights = class_weights[labels]         # one weight per sample
 
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,                          # sample with replacement
)
 
loader = DataLoader(dataset, batch_size=256, sampler=sampler, num_workers=4)

Use the SQL approach when you can afford the storage (cleaner, reproducible). Use WeightedRandomSampler when iterating quickly and the dataset lives in memory.

Streaming data to GPU with prefetching

The default PyTorch pipeline stalls the GPU while CPU decodes the next batch. A prefetch queue keeps the GPU fed:

python
import threading
import queue
import torch
from torch.utils.data import DataLoader
 
class CUDAPrefetcher:
    """Async prefetch to a GPU-pinned queue, overlapping transfer with forward pass."""
 
    def __init__(self, loader: DataLoader, device: torch.device, queue_size: int = 3):
        self.loader = loader
        self.device = device
        self.queue: queue.Queue = queue.Queue(maxsize=queue_size)
        self._thread = threading.Thread(target=self._fill_queue, daemon=True)
        self._thread.start()
 
    def _fill_queue(self) -> None:
        for batch in self.loader:
            inputs, targets = batch
            # Non-blocking copy: returns immediately, transfer happens in background
            inputs  = inputs.to(self.device, non_blocking=True)
            targets = targets.to(self.device, non_blocking=True)
            self.queue.put((inputs, targets))
        self.queue.put(None)  # sentinel
 
    def __iter__(self):
        while True:
            item = self.queue.get()
            if item is None:
                return
            yield item
 
# Usage
device = torch.device("cuda")
base_loader = DataLoader(dataset, batch_size=256, num_workers=8, pin_memory=True)
prefetcher  = CUDAPrefetcher(base_loader, device, queue_size=4)
 
for inputs, targets in prefetcher:
    # inputs and targets already on GPU before this line executes
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

pin_memory=True on the DataLoader is a prerequisite — it allocates CPU tensors in page-locked memory, enabling the non_blocking=True DMA transfer. On multi-GPU runs, instantiate one prefetcher per rank.
