Training ResNet from Scratch: From First Principles to Multi-GPU ImageNet
Build it yourself, train it yourself, watch it fail yourself
A hands-on guide to implementing ResNet from scratch in PyTorch, training on CIFAR-100 and ImageNet, and scaling to multi-GPU with Distributed Data Parallel.
Why This Blog Exists
Training ResNet-50 from scratch on ImageNet is something only about 10,000 people in the world have done. Most ML practitioners download torchvision.models.resnet50(pretrained=True) and call it a day. There's nothing wrong with that — but if you want to actually understand what's happening under the hood, you need to build it yourself, train it yourself, and watch it fail yourself.
This blog documents my journey of doing exactly that, progressing from training a ResNet on CIFAR-100 (targeting 73% top-1 accuracy) to scaling up to full ImageNet-1K on AWS EC2 with multi-GPU training using PyTorch's Distributed Data Parallel (DDP), targeting 75% top-1 accuracy.
If you're the kind of person who wants to know why ResNets use skip connections (and not just that they do), why One Cycle Policy works (and not just how to call the scheduler), and why multi-GPU training requires gradient synchronization (and not just which wrapper to use) — this blog is for you.
01 The Problem That Created ResNet
The Degradation Problem — Why Deeper ≠ Better
Before ResNet, the deep learning community operated under a seemingly reasonable assumption: more layers should mean more capacity, which should mean better performance. VGG proved that going from 8 layers (AlexNet) to 19 layers worked great. So why not 50? 100? 152?
Here's what actually happened: when researchers stacked plain convolutional layers beyond ~20 layers, training accuracy itself got worse. Not just test accuracy — training accuracy. This is crucial to understand because it rules out overfitting as the explanation. If a 56-layer network has worse training loss than a 20-layer network, something is fundamentally broken in the optimization.
The Residual Learning Insight
The ResNet authors (He et al., 2015) had an elegant insight: instead of asking layers to learn the desired mapping H(x) directly, ask them to learn the residual F(x) = H(x) - x. Then the output becomes:
y = F(x) + x
This is the famous skip connection. Why is learning the residual easier? Because if the optimal mapping is close to identity (which it often is in deeper layers), then F(x) just needs to be close to zero. Pushing weights toward zero is much easier for SGD than learning an identity mapping from scratch.
Concrete example: Suppose layer 25 in your network receives a feature map that's already pretty good. In a plain network, layer 25 must learn to output something that's essentially the same as its input — a complex identity function through nonlinear activations. In a ResNet, layer 25 just needs to output approximately zero (the residual), and the skip connection handles the rest. Small residuals mean small gradients, which are numerically stable.
A powerful mental model: Think of ResNet as an ensemble of shallower networks. When you "unravel" the skip connections, a ResNet with 3 residual blocks actually contains paths of length 1, 2, and 3, plus combinations. The network effectively learns at multiple depths simultaneously, and the skip connections ensure gradients can always flow through the shorter paths even if longer paths have vanishing gradients.
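The identity path is easy to verify numerically. Here's a minimal sketch with a toy linear residual block (not the conv blocks built later): when the residual branch F is zeroed out, the block collapses to the exact identity and gradients pass through unchanged.

```python
import torch
import torch.nn as nn

# Toy residual block: y = F(x) + x, with F a single linear layer.
class ToyResidual(nn.Module):
    def __init__(self, dim=4):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, x):
        return self.fc(x) + x  # the skip connection

block = ToyResidual()
# Zero the residual branch: the block becomes the identity function.
nn.init.zeros_(block.fc.weight)
nn.init.zeros_(block.fc.bias)

x = torch.randn(2, 4, requires_grad=True)
y = block(x)
assert torch.allclose(y, x)          # F(x) = 0  →  y = x

y.sum().backward()
assert torch.allclose(x.grad, torch.ones_like(x))  # gradient flows as identity
```

This is the skip connection's whole trick: even before the residual branch has learned anything, the gradient reaching x is at least the identity contribution, so shorter paths never vanish.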
02 ResNet Architecture — Every Layer Explained
The ResNet Family
| Variant | Block Type | Block Config (per stage) | Total Params |
|---|---|---|---|
| ResNet-18 | Basic (2 convs) | 2, 2, 2, 2 | ~11M |
| ResNet-34 | Basic (2 convs) | 3, 4, 6, 3 | ~21M |
| ResNet-50 | Bottleneck (3 convs) | 3, 4, 6, 3 | ~25M |
| ResNet-101 | Bottleneck (3 convs) | 3, 4, 23, 3 | ~44M |
| ResNet-152 | Bottleneck (3 convs) | 3, 8, 36, 3 | ~60M |
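The naming convention counts weighted layers: the stem conv, every conv inside the blocks, and the final FC. A quick arithmetic check of the table's depths (no framework needed):

```python
def resnet_depth(blocks_per_stage, convs_per_block):
    # 1 stem conv + all block convs + 1 final FC layer
    return 1 + sum(blocks_per_stage) * convs_per_block + 1

print(resnet_depth([2, 2, 2, 2], 2))   # → 18
print(resnet_depth([3, 4, 6, 3], 2))   # → 34
print(resnet_depth([3, 4, 6, 3], 3))   # → 50
print(resnet_depth([3, 4, 23, 3], 3))  # → 101
print(resnet_depth([3, 8, 36, 3], 3))  # → 152
```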
The Stem: First Contact with the Image
Every ResNet begins with: Input (224×224×3) → Conv 7×7, 64 filters, stride 2 → BN → ReLU → MaxPool 3×3, stride 2
After the stem, you're at 56×56×64. ResNet uses only one MaxPooling in the entire network. All other spatial reduction happens through strided convolutions.
Why 7×7 in the stem? The first layer needs a large receptive field to capture meaningful low-level features (edges, textures) from raw pixels.
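The stem's shape arithmetic follows from the standard conv output-size formula. A quick sketch (`conv_out` is an illustrative helper, not a PyTorch API):

```python
def conv_out(size, kernel, stride, padding):
    """Output spatial size of a conv/pool layer (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1

s = conv_out(224, kernel=7, stride=2, padding=3)  # 7×7 conv, stride 2 → 112
s = conv_out(s, kernel=3, stride=2, padding=1)    # 3×3 maxpool, stride 2 → 56
print(s)  # → 56
```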
Basic Block (ResNet-18/34)
Input (x) ─────────────────────────┐
│ │ (identity shortcut)
├─→ Conv 3×3 → BN → ReLU │
├─→ Conv 3×3 → BN │
│ │
└──── + ◄─────────────────────────┘
│
ReLU
│
Output
Two 3×3 convolutions with a skip connection. The BN-ReLU ordering here is V1 style: Conv → BN → ReLU. In V2 (pre-activation ResNet), it's BN → ReLU → Conv, which gives slightly better gradient flow.
Bottleneck Block (ResNet-50/101/152)
Input (x) ──────────────────────────────────┐
│ │ (1×1 shortcut if dims change)
├─→ Conv 1×1 (reduce) → BN → ReLU │
├─→ Conv 3×3 → BN → ReLU │
├─→ Conv 1×1 (expand) → BN │
│ │
└──── + ◄──────────────────────────────────┘
│
ReLU
│
Output
The bottleneck uses a "squeeze-expand" pattern. Why bottleneck? A single 3×3 conv on 256 channels costs 589,824 parameters. The bottleneck sequence costs 69,632 parameters — roughly 8.5× fewer parameters for a similar receptive field contribution.
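The counts above come from c_in × c_out × k² per conv (ResNet convs are bias-free). A quick check:

```python
def conv_params(c_in, c_out, k):
    return c_in * c_out * k * k  # bias=False, as in ResNet convs

plain = conv_params(256, 256, 3)
squeeze_expand = (conv_params(256, 64, 1)     # 1×1 reduce
                  + conv_params(64, 64, 3)    # 3×3 spatial
                  + conv_params(64, 256, 1))  # 1×1 expand
print(plain, squeeze_expand)  # → 589824 69632
```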
Handling Dimension Mismatches
Two types of skip connections:
- Identity shortcut — when dims match, just add.
- Projection shortcut — when dims change, use 1×1 conv with stride 2.
The Four Stages
| Stage | Output Size | Channels | Note |
|---|---|---|---|
| Stage 1 | 56×56 | 64 (Basic) / 256 (Bottleneck) | No spatial reduction |
| Stage 2 | 28×28 | 128 / 512 | First conv has stride 2 |
| Stage 3 | 14×14 | 256 / 1024 | First conv has stride 2 |
| Stage 4 | 7×7 | 512 / 2048 | First conv has stride 2 |
After Stage 4: Global Average Pooling (7×7 → 1×1) → FC layer → 1000 classes (ImageNet) or 100 classes (CIFAR-100).
03 Implementation from Scratch in PyTorch
The Basic Block
import torch
import torch.nn as nn
class BasicBlock(nn.Module):
    expansion = 1  # output channels = out_channels × expansion
def __init__(self, in_channels, out_channels, stride=1, downsample=None):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.downsample = downsample # projection shortcut
def forward(self, x):
identity = x
out = self.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity # THE skip connection
out = self.relu(out)
return out
The Bottleneck Block
class Bottleneck(nn.Module):
expansion = 4 # output channels = out_channels × 4
def __init__(self, in_channels, out_channels, stride=1, downsample=None):
super().__init__()
# 1×1 squeeze
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
# 3×3 spatial conv (stride applied here for spatial reduction)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
# 1×1 expand
self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion,
kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
def forward(self, x):
identity = x
out = self.relu(self.bn1(self.conv1(x)))
out = self.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
The Full ResNet
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000):
super().__init__()
self.in_channels = 64
# Stem
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# Four stages
self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
# Classification head
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
# Weight initialization (Kaiming/He init)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def _make_layer(self, block, out_channels, num_blocks, stride):
downsample = None
if stride != 1 or self.in_channels != out_channels * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.in_channels, out_channels * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_channels * block.expansion),
)
layers = []
layers.append(block(self.in_channels, out_channels, stride, downsample))
self.in_channels = out_channels * block.expansion
for _ in range(1, num_blocks):
layers.append(block(self.in_channels, out_channels))
return nn.Sequential(*layers)
def forward(self, x):
x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
# Factory functions
def resnet18(num_classes=1000):
return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
def resnet34(num_classes=1000):
return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)
def resnet50(num_classes=1000):
return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)
def resnet101(num_classes=1000):
return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)
CIFAR-100 Adaptation
CIFAR-100 images are 32×32 — much smaller than ImageNet's 224×224. Replace the stem:
# Replace the stem for CIFAR-sized inputs (32×32)
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.Identity()  # skip the maxpool entirely — 32×32 is already small
04 The Training Recipe — Making It Actually Work
Optimizer: SGD with Momentum (Not Adam!)
For large-scale image classification, SGD with momentum consistently outperforms Adam in final accuracy, even though Adam converges faster initially.
Why? Adam's adaptive per-parameter learning rates tend to find sharp minima — solutions that have low training loss but generalize poorly. SGD with momentum tends to find flat minima that generalize much better.
optimizer = torch.optim.SGD(
model.parameters(),
lr=0.1,
momentum=0.9,
weight_decay=1e-4 # L2 regularization
)
Learning Rate: The One Cycle Policy
Proposed by Leslie Smith. The recipe:
- Find max LR using LR range test
- Step 1 (warmup): Linearly increase LR from max_lr/10 to max_lr over ~45% of training
- Step 2 (decay): Linearly decrease LR from max_lr back to max_lr/10 over ~45%
- Annihilation (final ~10%): Decrease from max_lr/10 to max_lr/1000
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=0.1,
epochs=100,
steps_per_epoch=len(train_loader),
pct_start=0.3,
anneal_strategy='cos',
div_factor=10,
final_div_factor=100
)
Real results: Using One Cycle Policy, ResNet-56 achieved 91.54% on CIFAR-10 in 9,310 iterations vs ~64,000 without — a ~7× speedup.
Learning Rate Finder
from torch_lr_finder import LRFinder
model = resnet50(num_classes=100)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-7, momentum=0.9, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
lr_finder = LRFinder(model, optimizer, criterion, device="cuda")
lr_finder.range_test(train_loader, end_lr=10, num_iter=200)
lr_finder.plot()
lr_finder.reset()
Data Augmentation Strategy
PMDA (Poor Man's Data Augmentation):
import torchvision.transforms as T
train_transforms = T.Compose([
T.RandomCrop(32, padding=4),
T.RandomHorizontalFlip(p=0.5),
T.ToTensor(),
T.Normalize(mean=[0.5071, 0.4867, 0.4408],
std=[0.2675, 0.2565, 0.2761])
])
MMDA — CutOut:
class CutOut:
def __init__(self, n_holes=1, length=16):
self.n_holes = n_holes
self.length = length
def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = torch.ones(h, w)
for _ in range(self.n_holes):
y = torch.randint(0, h, (1,)).item()
x = torch.randint(0, w, (1,)).item()
y1 = max(0, y - self.length // 2)
y2 = min(h, y + self.length // 2)
x1 = max(0, x - self.length // 2)
x2 = min(w, x + self.length // 2)
mask[y1:y2, x1:x2] = 0
mask = mask.unsqueeze(0).expand_as(img)
return img * mask
MixUp:
import numpy as np

def mixup_data(x, y, alpha=0.2):
lam = np.random.beta(alpha, alpha)
batch_size = x.size(0)
index = torch.randperm(batch_size).to(x.device)
mixed_x = lam * x + (1 - lam) * x[index]
y_a, y_b = y, y[index]
return mixed_x, y_a, y_b, lam
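MixUp also needs a matching loss: the criterion is applied against both labels and blended with the same λ. A small companion helper (the name `mixup_criterion` follows common usage; it's not part of the code above):

```python
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the losses against both labels with the MixUp weight lam."""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

# In the training loop (sketch):
#   mixed_x, y_a, y_b, lam = mixup_data(images, labels)
#   loss = mixup_criterion(criterion, model(mixed_x), y_a, y_b, lam)
```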
ImageNet augmentation:
train_transforms = T.Compose([
T.RandomResizedCrop(224),
T.RandomHorizontalFlip(),
T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
val_transforms = T.Compose([
T.Resize(256),
T.CenterCrop(224),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
The Complete Single-GPU Training Loop
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
def train_one_epoch(model, loader, criterion, optimizer, scheduler, device):
model.train()
running_loss = 0.0
correct = 0
total = 0
for images, labels in tqdm(loader, desc="Training"):
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
scheduler.step() # Step per ITERATION for OneCycleLR
running_loss += loss.item() * images.size(0)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
return running_loss / total, 100. * correct / total
@torch.no_grad()
def evaluate(model, loader, criterion, device):
model.eval()
running_loss = 0.0
correct = 0
total = 0
for images, labels in tqdm(loader, desc="Evaluating"):
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
running_loss += loss.item() * images.size(0)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
return running_loss / total, 100. * correct / total
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = resnet50(num_classes=100).to(device)
train_dataset = datasets.CIFAR100(root='./data', train=True,
transform=train_transforms, download=True)
val_dataset = datasets.CIFAR100(root='./data', train=False,
transform=val_transforms)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True,
num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False,
num_workers=4, pin_memory=True)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer, max_lr=0.1, epochs=100,
steps_per_epoch=len(train_loader)
)
best_acc = 0.0
for epoch in range(100):
train_loss, train_acc = train_one_epoch(
model, train_loader, criterion, optimizer, scheduler, device
)
val_loss, val_acc = evaluate(model, val_loader, criterion, device)
print(f"Epoch {epoch+1}/100 | "
f"Train Loss: {train_loss:.4f} Acc: {train_acc:.2f}% | "
f"Val Loss: {val_loss:.4f} Acc: {val_acc:.2f}%")
if val_acc > best_acc:
best_acc = val_acc
torch.save(model.state_dict(), 'best_model.pth')
print(f" → New best: {best_acc:.2f}%")
05 Scaling to Multi-GPU with Distributed Data Parallel (DDP)
Why DDP and Not DataParallel?
| Feature | DataParallel | DDP |
|---|---|---|
| Process model | Single process, multiple threads | Multiple processes |
| GIL bottleneck | Yes (Python GIL) | No (separate processes) |
| Memory usage | GPU 0 gets more | Even across all GPUs |
| Communication | Gather/scatter on GPU 0 | Ring all-reduce (distributed) |
| Scaling | Poor beyond 2–4 GPUs | Near-linear scaling |
Understanding Ring All-Reduce
GPU 0 → GPU 1 → GPU 2 → GPU 3 → GPU 0
Phase 1 (Scatter-Reduce): each GPU sends one gradient chunk to its neighbor, which accumulates the partial sum. After N−1 steps, each GPU holds the fully reduced version of one chunk.
Phase 2 (All-Gather): each GPU circulates its fully reduced chunk around the ring. After N−1 steps, every GPU has the complete, summed gradient.
Each GPU sends and receives only 2(N−1)/N times the gradient size — so per-GPU communication stays bounded below 2× the gradient, no matter how many GPUs you add.
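The two phases can be simulated in a few lines of plain Python. This is a didactic sketch of the ring algorithm, not NCCL's actual implementation: each of N workers splits its gradient into N chunks, partial sums travel around the ring for N−1 steps, then the fully reduced chunks circulate for another N−1 steps.

```python
def ring_all_reduce(grads):
    """grads[w][c]: chunk c of worker w's gradient. Returns the buffers after
    scatter-reduce + all-gather: every worker holds the elementwise sum."""
    n = len(grads)
    buf = [list(g) for g in grads]
    # Phase 1: scatter-reduce. At each step, worker w sends chunk (w - step) % n
    # to its neighbor, which accumulates it. After n-1 steps, worker w holds
    # the fully reduced chunk (w + 1) % n.
    for step in range(n - 1):
        sends = [(w, (w - step) % n) for w in range(n)]
        vals = [buf[w][c] for w, c in sends]       # snapshot before applying
        for (w, c), v in zip(sends, vals):
            buf[(w + 1) % n][c] += v
    # Phase 2: all-gather. Circulate each reduced chunk around the ring.
    for step in range(n - 1):
        sends = [(w, (w + 1 - step) % n) for w in range(n)]
        vals = [buf[w][c] for w, c in sends]
        for (w, c), v in zip(sends, vals):
            buf[(w + 1) % n][c] = v
    return buf

print(ring_all_reduce([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
# → [[12, 15, 18], [12, 15, 18], [12, 15, 18]]
```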
Setting Up DDP
Step 1: Process Group Initialization
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def cleanup():
dist.destroy_process_group()
Step 2: Distributed Data Sampler
train_dataset = datasets.ImageFolder(
'/data/imagenet/train',
transform=train_transforms
)
train_sampler = DistributedSampler(
train_dataset,
num_replicas=world_size,
rank=rank,
shuffle=True
)
train_loader = DataLoader(
train_dataset,
batch_size=64,
sampler=train_sampler,
num_workers=4,
pin_memory=True,
drop_last=True
)
Call train_sampler.set_epoch(epoch) at the start of each epoch. Without this, each GPU sees the same data order every epoch, which significantly hurts generalization.
Step 3: Wrap the Model
def create_model(rank, num_classes=1000):
model = resnet50(num_classes=num_classes).to(rank)
model = DDP(model, device_ids=[rank])
return model
Step 4: Full DDP Training Script
import torch.multiprocessing as mp
from datetime import timedelta
def train_ddp(rank, world_size, epochs=90):
setup(rank, world_size)
model = create_model(rank)
train_dataset = datasets.ImageFolder(
'/data/imagenet/train', transform=train_transforms
)
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
train_loader = DataLoader(
train_dataset, batch_size=64, sampler=train_sampler,
num_workers=8, pin_memory=True, drop_last=True
)
val_dataset = datasets.ImageFolder(
'/data/imagenet/val', transform=val_transforms
)
val_loader = DataLoader(
val_dataset, batch_size=128, shuffle=False,
num_workers=4, pin_memory=True
)
criterion = nn.CrossEntropyLoss().to(rank)
optimizer = torch.optim.SGD(
model.parameters(), lr=0.1 * (world_size / 4),
momentum=0.9, weight_decay=1e-4
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=0.1 * (world_size / 4),
epochs=epochs,
steps_per_epoch=len(train_loader)
)
for epoch in range(epochs):
train_sampler.set_epoch(epoch)
model.train()
for images, labels in train_loader:
images = images.to(rank, non_blocking=True)
labels = labels.to(rank, non_blocking=True)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
scheduler.step()
if rank == 0:
val_loss, val_acc = evaluate(model.module, val_loader, criterion, rank)
print(f"Epoch {epoch+1}/{epochs} | Val Acc: {val_acc:.2f}%")
if (epoch + 1) % 10 == 0:
torch.save({
'epoch': epoch,
'model_state_dict': model.module.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
}, f'checkpoint_epoch_{epoch+1}.pt')
cleanup()
if __name__ == "__main__":
world_size = torch.cuda.device_count()
mp.spawn(train_ddp, args=(world_size,), nprocs=world_size, join=True)
Launching with torchrun
# Single node, 4 GPUs
torchrun --nproc_per_node=4 train.py
# Multi-node (2 nodes, 4 GPUs each)
# On node 0:
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
--master_addr=<NODE0_IP> --master_port=12355 train.py
# On node 1:
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 \
--master_addr=<NODE0_IP> --master_port=12355 train.py
With torchrun:
def setup_torchrun():
    dist.init_process_group("nccl")  # torchrun sets RANK/WORLD_SIZE/LOCAL_RANK
    local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this node
    torch.cuda.set_device(local_rank)
    return local_rank
06 AWS EC2 Setup for Multi-GPU Training
Instance Selection
| Instance | GPUs | GPU Memory | Cost (approx) |
|---|---|---|---|
| p3.8xlarge | 4 × V100 | 4 × 16GB | ~$12/hr |
| p3.16xlarge | 8 × V100 | 8 × 16GB | ~$24/hr |
| p4d.24xlarge | 8 × A100 | 8 × 40GB | ~$32/hr |
Quick Setup Checklist
# 1. Launch instance with Deep Learning AMI
# 2. Download ImageNet (~150GB)
# /data/imagenet/
# ├── train/
# │ ├── n01440764/
# │ └── n15075141/
# └── val/
# ├── n01440764/
# └── n15075141/
# 3. Install dependencies
pip install torch torchvision tqdm tensorboard
# 4. Verify GPU availability
python -c "import torch; print(f'GPUs: {torch.cuda.device_count()}')"
# 5. Launch training
torchrun --nproc_per_node=4 train_imagenet.py
Pro Tips
Start small before going big. Train on a small subset on Colab first.
Use spot instances for 60–70% savings. Implement checkpoint saving.
Monitor with TensorBoard:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('runs/resnet50_imagenet')
writer.add_scalar('Loss/train', train_loss, epoch)
writer.add_scalar('Accuracy/train', train_acc, epoch)
writer.add_scalar('Accuracy/val', val_acc, epoch)
writer.add_scalar('LR', scheduler.get_last_lr()[0], epoch)
Use nvidia-smi dmon to monitor GPU utilization. If utilization is below 80%, increase num_workers in your DataLoader to keep the GPUs fed with data.
07 Checkpoint & Resume
def save_checkpoint(model, optimizer, scheduler, epoch, step, path):
checkpoint = {
'epoch': epoch,
'step': step,
'model_state_dict': model.module.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
}
torch.save(checkpoint, path)
print(f"Checkpoint saved: {path}")
def load_checkpoint(model, optimizer, scheduler, path, rank):
checkpoint = torch.load(path, map_location=f'cuda:{rank}')
model.module.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
start_epoch = checkpoint['epoch'] + 1
start_step = checkpoint['step']
print(f"Resumed from epoch {start_epoch}, step {start_step}")
return start_epoch, start_step
08 Common Pitfalls and Debugging Guide
Pitfall 1: Forgetting set_epoch()
# WRONG
for epoch in range(num_epochs):
for data in train_loader:
...
# RIGHT
for epoch in range(num_epochs):
train_sampler.set_epoch(epoch) # DON'T FORGET THIS
for data in train_loader:
...
Pitfall 2: Accessing model.module vs model
# Saving
torch.save(model.module.state_dict(), 'model.pth') # ✅
# Loading
model.module.load_state_dict(torch.load('model.pth')) # ✅
Pitfall 3: Batch Norm + Small Per-GPU Batch Size
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DDP(model, device_ids=[rank])
Pitfall 4: LR Not Scaled with Batch Size
Linear scaling rule: new_lr = base_lr × (new_batch_size / base_batch_size). Add warmup of 5 epochs when using large LRs.
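A sketch of both rules together (function names are illustrative, not a PyTorch API):

```python
def scaled_lr(base_lr, base_batch, batch):
    """Linear scaling rule: LR grows with the effective (global) batch size."""
    return base_lr * batch / base_batch

def warmup_lr(target_lr, epoch, warmup_epochs=5):
    """Linear warmup: ramp from target_lr/warmup_epochs up to target_lr."""
    if epoch < warmup_epochs:
        return target_lr * (epoch + 1) / warmup_epochs
    return target_lr

# 4 GPUs × 256 per GPU = 1024 effective batch vs. a 256-batch baseline:
lr = scaled_lr(0.1, base_batch=256, batch=1024)
print(lr)  # → 0.4
print([round(warmup_lr(lr, e), 2) for e in range(6)])
# → [0.08, 0.16, 0.24, 0.32, 0.4, 0.4]
```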
Pitfall 5: OOM on ImageNet
Use mixed precision:
scaler = torch.cuda.amp.GradScaler()
for images, labels in train_loader:
optimizer.zero_grad()
with torch.cuda.amp.autocast():
outputs = model(images)
loss = criterion(outputs, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
scheduler.step()
09 Target Benchmarks
| Task | Target | Typical Epochs |
|---|---|---|
| ResNet on CIFAR-100 | 73% top-1 accuracy | ~100 |
| ResNet-50 on ImageNet-1K | 75% top-1 accuracy | ~90 |
The ImageNet target of 75% top-1 is a well-known benchmark. With modern tricks (MixUp, CutMix, longer training, cosine LR), you can push to 78–80%.
10 Deployment on HuggingFace Spaces
# app.py for Gradio on HuggingFace Spaces
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
model = resnet50(num_classes=1000)
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
model.eval()
with open('imagenet_classes.txt') as f:
labels = [line.strip() for line in f.readlines()]
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def predict(image):
img = transform(image).unsqueeze(0)
with torch.no_grad():
outputs = model(img)
probs = torch.nn.functional.softmax(outputs[0], dim=0)
top5_prob, top5_idx = torch.topk(probs, 5)
    return {labels[idx.item()]: prob.item() for prob, idx in zip(top5_prob, top5_idx)}
demo = gr.Interface(
fn=predict,
inputs=gr.Image(type="pil"),
outputs=gr.Label(num_top_classes=5),
title="ResNet-50 ImageNet Classifier (Trained from Scratch)"
)
demo.launch()
Summary: The Complete Recipe
- 1. Understand the architecture — ResNet solves the degradation problem through skip connections.
- 2. Implement from scratch — Build BasicBlock, Bottleneck, and the full ResNet.
- 3. Train with the right recipe — SGD with momentum, One Cycle Policy, appropriate augmentation.
- 4. Scale with DDP — DistributedDataParallel with ring all-reduce. Scale LR with effective batch size.
- 5. Be resilient — Checkpoint frequently, handle spot instance terminations, monitor GPU utilization.
- 6. Deploy — Share on HuggingFace Spaces.
Training ResNet-50 from scratch on full ImageNet is something only about 10,000 people in the world have done. If you follow this guide to the end, you'll be one of them.