Training Data at Scale
Training a large model is only possible if data reaches the GPU fast enough to keep it busy. The bottleneck is almost never the GPU — it's the data pipeline: reading from disk, decoding images or text, applying augmentations, and delivering batches fast enough to keep compute saturated. Fixing data loading throughput is often higher ROI than buying faster hardware.
How It Works
Compare the three strategies. In the naive case, the GPU trains for a short burst and then idles while the CPU loads the next batch. Prefetching overlaps data loading with training — background workers prepare the next batch while the current one is training. Streaming (WebDataset) removes local disk entirely, pulling data directly from object storage at network speed.
GPU training runs in two alternating phases: compute (the GPU processes a batch) and I/O (the CPU loads the next batch). If I/O takes longer than compute, the GPU sits idle. The entire goal of data pipeline optimization is to make I/O time invisible by overlapping it with compute — the GPU should never wait. This is not an optimization problem; it is a scheduling problem, and the solution is always some form of prefetching.
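A back-of-the-envelope model makes the scheduling claim concrete: with prefetching, per-step time drops from compute + I/O to max(compute, I/O). The timings below are hypothetical:

```python
def step_time(compute_s: float, io_s: float, prefetch: bool) -> float:
    """Wall-clock time for one training step."""
    if prefetch:
        # I/O for the next batch runs concurrently with compute on the
        # current batch, so the slower of the two sets the pace.
        return max(compute_s, io_s)
    # Naive loop: load, then train, strictly in sequence.
    return compute_s + io_s

compute, io = 0.10, 0.08  # seconds per batch (hypothetical)
naive = step_time(compute, io, prefetch=False)
overlapped = step_time(compute, io, prefetch=True)
print(f"GPU utilization: naive {compute / naive:.0%}, prefetch {compute / overlapped:.0%}")
```

As long as I/O stays shorter than compute, the overlapped pipeline keeps the GPU at 100% utilization; the moment I/O exceeds compute, I/O becomes the pace-setter.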
The GPU utilization problem
GPU time costs money. An A100 at $3/hr running at 30% utilization is paying for 70% idle time. At scale across dozens of GPUs and weeks of training, data pipeline inefficiency is a six-figure cost.
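Spelling out the arithmetic with illustrative numbers (a hypothetical 64-GPU, 8-week run):

```python
def idle_cost(hourly_rate: float, utilization: float, gpus: int, hours: float) -> float:
    """Dollars paid for GPU-hours spent waiting on data."""
    return hourly_rate * gpus * hours * (1.0 - utilization)

# 64 A100s at $3/hr for 8 weeks, running at 30% utilization
cost = idle_cost(3.0, 0.30, gpus=64, hours=8 * 7 * 24)
print(f"${cost:,.0f} paid for idle GPU time")
```

At these numbers the idle time alone is roughly $180k, which is the six-figure cost the paragraph above refers to.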
DataLoader parallelism
PyTorch DataLoader controls the number of background workers that preload data:
from torch.utils.data import DataLoader
loader = DataLoader(
    dataset,
    batch_size=256,
    num_workers=8,           # parallel CPU workers
    prefetch_factor=2,       # batches to prefetch per worker
    pin_memory=True,         # page-locked memory for faster GPU transfer
    persistent_workers=True, # avoid worker restart overhead
)
num_workers=0 (default) runs loading in the main process. Set num_workers to 4–16 depending on CPU core count and I/O bandwidth. pin_memory=True enables faster CPU→GPU transfer via DMA.
WebDataset for object storage streaming
For large-scale training, local disk doesn't scale. WebDataset treats data as a stream of tar shards stored in S3 or GCS:
import webdataset as wds
dataset = (
    wds.WebDataset("s3://my-bucket/train-{000000..002999}.tar")
    .shuffle(1000)          # shuffle within a buffer
    .decode("torchrgb")     # decode images to tensors
    .to_tuple("jpg", "cls") # extract image and label
    .batched(256)
)
loader = DataLoader(dataset, batch_size=None, num_workers=8)  # batching done by .batched()
Shards are fetched in parallel. As one shard is consumed, the next is already downloading. No local disk required.
Data format and throughput
| Format | Throughput | Random access | Use case |
|---|---|---|---|
| Raw JPEG/PNG | Low (decode overhead) | Yes | Small datasets only; avoid for large-scale training |
| Parquet | Medium | Columns yes, rows partial | Tabular/NLP data |
| tar shards (WebDataset) | High | No | Large-scale streaming |
| Memory-mapped numpy | Very high | Yes | Pre-tokenized text |
For LLM training, pre-tokenized data stored as memory-mapped numpy arrays (np.memmap) gives near-zero loading overhead — the OS maps the file into address space and pages it in on demand.
Memory-mapped files became the format of choice for LLM training because the alternative — loading multi-terabyte tokenized datasets into RAM — is infeasible on any practical machine. np.memmap works because modern OS virtual memory systems are better at this than application-level caching: the OS knows which pages are actively being accessed, can prefetch them based on access patterns, and can evict cold pages under memory pressure without any explicit application logic. The training loop just reads array slices; the OS transparently makes that work against a file that may be larger than available RAM.
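A minimal sketch of the pattern with numpy (file path and sizes are illustrative; real pipelines write billions of uint16 token IDs the same way):

```python
import os
import tempfile

import numpy as np

# Write a small "tokenized dataset" to disk; uint16 covers vocabularies < 65,536
path = os.path.join(tempfile.mkdtemp(), "train_tokens.bin")
np.arange(10_000, dtype=np.uint16).tofile(path)

# Map the file instead of loading it: pages fault in on first access
data = np.memmap(path, dtype=np.uint16, mode="r")

# A training step just slices a window of token IDs out of the map
seq_len, offset = 512, 1_000
batch = np.asarray(data[offset : offset + seq_len])  # copy the slice into RAM
print(batch.shape, batch[0], batch[-1])  # (512,) 1000 1511
```

The slice copy is the only allocation per step; the rest of the file never touches RAM unless read.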
Design Tradeoffs
Where Your Intuition Breaks
The instinct when training is slow is to get faster GPUs. GPU utilization profiling almost always reveals that GPUs are not the bottleneck — they're already idle 40–60% of the time waiting for data. A faster GPU makes the compute phase shorter, which makes the idle fraction larger, not smaller. The correct first step is always to measure utilization before spending on hardware: if nvidia-smi shows GPU utilization below 80% during training, the data pipeline is the bottleneck and must be fixed before hardware upgrades provide any benefit. This is especially common when migrating to larger models or higher resolutions — the compute cost per batch increases, revealing data loading bottlenecks that were previously hidden.
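Two lines of arithmetic show why a faster GPU makes the idle fraction worse (hypothetical timings, assuming an overlapped pipeline where step time is max(compute, I/O)):

```python
def utilization(compute_s: float, io_s: float) -> float:
    """GPU busy fraction when data loading overlaps compute."""
    return compute_s / max(compute_s, io_s)

io = 0.12                      # data pipeline time per batch: fixed
print(utilization(0.20, io))   # current GPU: pipeline keeps up, 100% busy
print(utilization(0.10, io))   # 2x faster GPU: now I/O-bound, ~83% busy
```

The data pipeline time did not change; halving compute time simply exposed it.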
Shuffling at scale
Training requires shuffled data to avoid gradient biases. Global shuffle is expensive at scale:
- Offline shuffle: pre-shuffle data before training. Eliminates runtime overhead but requires re-shuffling between epochs.
- Buffer shuffle: wds.shuffle(1000) — shuffle within a sliding window. Fast but not globally random. Adequate for most tasks.
- Epoch-level shard randomization: randomize shard order each epoch for approximate global randomness.
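A sliding-window shuffle in the spirit of wds.shuffle can be sketched in a few lines (buffer_shuffle is a hypothetical helper, not the WebDataset API):

```python
import random
from typing import Iterable, Iterator

def buffer_shuffle(stream: Iterable, bufsize: int, seed: int = 0) -> Iterator:
    """Hold up to `bufsize` items; emit a random one as each new item arrives."""
    rng = random.Random(seed)
    buf: list = []
    for item in stream:
        buf.append(item)
        if len(buf) >= bufsize:
            yield buf.pop(rng.randrange(len(buf)))
    while buf:  # drain the tail once the stream ends
        yield buf.pop(rng.randrange(len(buf)))

out = list(buffer_shuffle(range(10), bufsize=4, seed=42))
print(out)  # a permutation of 0..9; items land near their original position
```

An item can appear at most bufsize − 1 positions earlier than its source index, which is why a small buffer is only locally random.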
For LLM pre-training, data is typically shuffled once before tokenization and packed into fixed-length sequences — tokens from different documents are concatenated with separator tokens.
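Packing can be sketched in a few lines (separator id 0 is an assumption; real tokenizers define their own end-of-document token):

```python
def pack_sequences(docs: list[list[int]], seq_len: int, sep: int = 0) -> list[list[int]]:
    """Concatenate tokenized docs with a separator, then cut fixed-length sequences.
    Leftover tokens that don't fill a final sequence are dropped."""
    stream: list[int] = []
    for doc in docs:
        stream.extend(doc)
        stream.append(sep)  # document boundary
    return [stream[i : i + seq_len] for i in range(0, len(stream) - seq_len + 1, seq_len)]

docs = [[5, 6, 7], [8, 9], [10, 11, 12, 13]]
print(pack_sequences(docs, seq_len=4))
# [[5, 6, 7, 0], [8, 9, 0, 10], [11, 12, 13, 0]]
```

Note that a sequence can start mid-document; attention masking across separators, if used, is handled by the model side, not the packer.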
Augmentation: CPU vs GPU
Heavy augmentation (RandomResizedCrop, ColorJitter) on CPU can be a bottleneck. NVIDIA DALI moves image augmentation to the GPU, taking the load off the CPU entirely. Effective when augmentation, not raw I/O, is the bottleneck.
For text, pre-tokenize offline and store token IDs — avoid re-tokenizing at training time.
Distributed data loading
With multiple GPUs (DDP), each GPU needs its own shard of data. DistributedSampler partitions across ranks:
from torch.utils.data.distributed import DistributedSampler
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
loader = DataLoader(dataset, batch_size=256, sampler=sampler)
Call sampler.set_epoch(epoch) at the start of each epoch so the shuffle differs across epochs. Global effective batch size = batch_size × world_size. Gradient accumulation increases effective batch size without increasing memory per GPU.
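What DistributedSampler does can be pictured as rank-strided index partitioning (a simplification: the real sampler also shuffles per epoch and pads so every rank gets the same count):

```python
def partition(num_samples: int, world_size: int, rank: int) -> list[int]:
    """Each rank takes every world_size-th index, offset by its own rank."""
    return list(range(rank, num_samples, world_size))

parts = [partition(10, world_size=4, rank=r) for r in range(4)]
print(parts)  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
```

Every sample lands on exactly one rank, so no GPU trains on another GPU's data.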
In Practice
Profiling the data pipeline
Before optimizing, measure. Look for DataLoader.__iter__ taking more time than model.forward():
import torch.profiler as profiler

with profiler.profile(activities=[profiler.ProfilerActivity.CPU, profiler.ProfilerActivity.CUDA]) as prof:
    for i, batch in enumerate(loader):
        model(batch)
        if i == 5:
            break
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
Estimating required throughput
Target throughput = tokens per step × steps per second. For a 7B LLM training on 100B tokens with batch size 1M tokens and 0.5 steps/sec:
- Training time: 100B / (1M × 0.5) = 2.3 days
- If the dataloader caps at 250k tokens/sec (half the required 500k), the run stretches to 4.6 days, doubling the cost
Pre-compute this before starting a long run.
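The estimate is easy to script; using the numbers above, a loader delivering only half the required 500k tokens/sec doubles the run:

```python
def training_days(total_tokens: float, tokens_per_sec: float) -> float:
    """Wall-clock days to stream total_tokens at a given throughput."""
    return total_tokens / tokens_per_sec / 86_400

print(round(training_days(100e9, 500e3), 1))  # 2.3  (full 500k tokens/sec)
print(round(training_days(100e9, 250e3), 1))  # 4.6  (loader-capped at 250k)
```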
Checkpointing and resumability
Long training runs fail. Data pipelines must support deterministic resumption:
- Save the dataset shard index and shuffle seed alongside model weights
- On resume, reconstruct the identical data order from that point
- WebDataset supports shard-level resumption by seeking to the correct shard offset
Without deterministic data order, a resumed run might re-train on already-seen data (causing overfitting) or skip data entirely.
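A minimal sketch of deterministic resumption (checkpoint field names are illustrative):

```python
import random

def epoch_order(num_shards: int, seed: int, epoch: int) -> list[int]:
    """Deterministic shard order: the same (seed, epoch) reproduces it exactly."""
    rng = random.Random(seed * 100_000 + epoch)  # one RNG stream per (seed, epoch)
    order = list(range(num_shards))
    rng.shuffle(order)
    return order

# Checkpoint: store (seed, epoch, shards_consumed) next to the model weights
ckpt = {"seed": 1234, "epoch": 7, "shards_consumed": 58}

# Resume: rebuild the identical order and skip what was already trained on
order = epoch_order(3000, ckpt["seed"], ckpt["epoch"])
remaining = order[ckpt["shards_consumed"]:]
print(len(remaining))  # 2942 shards left in this epoch
```

The same idea extends below shard granularity by also storing the sample offset within the current shard.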
Production Patterns
WebDataset loader for S3-backed shards
Store training data as tar shards in S3. Each shard holds ~500MB and contains paired samples (image + label, or text + metadata). The loader streams shards in parallel with no local disk required:
import webdataset as wds
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as T
SHARD_URL = "s3://my-bucket/imagenet/train-{000000..001281}.tar"
transform = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
dataset = (
    wds.WebDataset(
        SHARD_URL,
        resampled=True,                 # infinite streaming: reshuffle shards each epoch
        shardshuffle=True,              # randomize shard fetch order
        cache_dir="/tmp/wds_cache",     # optional: cache recently-used shards locally
        nodesplitter=wds.split_by_node, # DDP: each node gets disjoint shards
    )
    .shuffle(2000)                      # in-memory buffer shuffle across shard boundaries
    .decode("torchrgb")                 # JPEG → torch.Tensor
    .to_tuple("jpg", "cls")
    .map_tuple(transform, lambda x: x)
    .batched(256, partial=False)        # drop incomplete final batch
)
loader = DataLoader(
    dataset,
    batch_size=None, # batching handled by .batched() above
    num_workers=8,
    pin_memory=True,
    prefetch_factor=3,
)
For text data, replace .decode("torchrgb") with a custom decode that reads pre-tokenized numpy arrays directly from the tar entries.
Stratified sampling for imbalanced classes
Class imbalance is common in fraud detection, medical imaging, and click-through prediction. Two approaches: SQL-level downsampling before training, and PyTorch WeightedRandomSampler at load time.
SQL downsampling (offline, during dataset export):
-- Keep all positives; downsample negatives to 10:1 ratio
WITH positives AS (
SELECT * FROM training_events WHERE label = 1
),
negatives AS (
SELECT *
FROM training_events
WHERE label = 0
-- deterministic hash ordering: reproducible sample given the fixed '42' salt
ORDER BY farm_fingerprint(concat(cast(event_id AS STRING), '42'))
LIMIT (SELECT COUNT(*) * 10 FROM positives)
)
SELECT * FROM positives
UNION ALL
SELECT * FROM negatives
PyTorch WeightedRandomSampler (online, no disk overhead):
import torch
from torch.utils.data import WeightedRandomSampler
labels = torch.tensor(dataset.labels) # 0 or 1 per sample
class_counts = torch.bincount(labels) # [n_neg, n_pos]
class_weights = 1.0 / class_counts.float() # inverse frequency
sample_weights = class_weights[labels] # one weight per sample
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True, # sample with replacement
)
loader = DataLoader(dataset, batch_size=256, sampler=sampler, num_workers=4)
Use the SQL approach when you can afford the storage (cleaner, reproducible). Use WeightedRandomSampler when iterating quickly and the dataset lives in memory.
Streaming data to GPU with prefetching
The default PyTorch pipeline stalls the GPU while CPU decodes the next batch. A prefetch queue keeps the GPU fed:
import threading
import queue
import torch
class CUDAPrefetcher:
    """Async prefetch to a GPU-bound queue, overlapping transfer with forward pass."""

    def __init__(self, loader: DataLoader, device: torch.device, queue_size: int = 3):
        self.loader = loader
        self.device = device
        self.queue: queue.Queue = queue.Queue(maxsize=queue_size)
        self._thread = threading.Thread(target=self._fill_queue, daemon=True)
        self._thread.start()

    def _fill_queue(self) -> None:
        for batch in self.loader:
            inputs, targets = batch
            # Non-blocking copy: returns immediately, transfer happens in background
            inputs = inputs.to(self.device, non_blocking=True)
            targets = targets.to(self.device, non_blocking=True)
            self.queue.put((inputs, targets))
        self.queue.put(None)  # sentinel: end of epoch

    def __iter__(self):
        while True:
            item = self.queue.get()
            if item is None:
                return
            yield item

# Usage
device = torch.device("cuda")
base_loader = DataLoader(dataset, batch_size=256, num_workers=8, pin_memory=True)
prefetcher = CUDAPrefetcher(base_loader, device, queue_size=4)

for inputs, targets in prefetcher:
    # inputs and targets already on GPU before this line executes
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
pin_memory=True on the DataLoader is a prerequisite — it allocates CPU tensors in page-locked memory, enabling the non_blocking=True DMA transfer. On multi-GPU runs, instantiate one prefetcher per rank.