Deep Learning Feb 2026 40 min read

Training ResNet from Scratch: From First Principles to Multi-GPU ImageNet

Build it yourself, train it yourself, watch it fail yourself

A hands-on guide to implementing ResNet from scratch in PyTorch, training on CIFAR-100 and ImageNet, and scaling to multi-GPU with Distributed Data Parallel.

Why This Blog Exists

Training ResNet-50 from scratch on ImageNet is something only about 10,000 people in the world have done. Most ML practitioners download torchvision.models.resnet50(pretrained=True) and call it a day. There's nothing wrong with that — but if you want to actually understand what's happening under the hood, you need to build it yourself, train it yourself, and watch it fail yourself.

This blog documents my journey of doing exactly that, progressing from training a ResNet on CIFAR-100 (targeting 73% top-1 accuracy) to scaling up to full ImageNet-1K on AWS EC2 with multi-GPU training using PyTorch's Distributed Data Parallel (DDP), targeting 75% top-1 accuracy.

If you're the kind of person who wants to know why ResNets use skip connections (and not just that they do), why One Cycle Policy works (and not just how to call the scheduler), and why multi-GPU training requires gradient synchronization (and not just which wrapper to use) — this blog is for you.

01 The Problem That Created ResNet

The Degradation Problem — Why Deeper ≠ Better

Before ResNet, the deep learning community operated under a seemingly reasonable assumption: more layers should mean more capacity, which should mean better performance. VGG proved that going from 8 layers (AlexNet) to 19 layers worked great. So why not 50? 100? 152?

Here's what actually happened: when researchers stacked plain convolutional layers beyond ~20 layers, training accuracy itself got worse. Not just test accuracy — training accuracy. This is crucial to understand because it rules out overfitting as the explanation. If a 56-layer network has worse training loss than a 20-layer network, something is fundamentally broken in the optimization.

The thought experiment that motivated ResNet: Imagine you have a 20-layer network that works well. Now take those exact same 20 layers and add 36 more on top, but make those 36 layers do nothing (identity mappings). This 56-layer network should, at minimum, perform exactly as well as the 20-layer one: the added layers just pass data through unchanged. But in practice, plain networks couldn't even learn this identity mapping! Notably, He et al. argue this isn't simply vanishing gradients (their plain networks used Batch Normalization, which keeps forward and backward signals healthy); the real problem is that approximating an identity mapping with a stack of nonlinear layers is genuinely hard for the solver.

The Residual Learning Insight

The ResNet authors (He et al., 2015) had an elegant insight: instead of asking layers to learn the desired mapping H(x) directly, ask them to learn the residual F(x) = H(x) - x. Then the output becomes:

y = F(x) + x

This is the famous skip connection. Why is learning the residual easier? Because if the optimal mapping is close to identity (which it often is in deeper layers), then F(x) just needs to be close to zero. Pushing weights toward zero is much easier for SGD than learning an identity mapping from scratch.

Concrete example: Suppose layer 25 in your network receives a feature map that's already pretty good. In a plain network, layer 25 must learn to output something that's essentially the same as its input — a complex identity function through nonlinear activations. In a ResNet, layer 25 just needs to output approximately zero (the residual), and the skip connection handles the rest. Small residuals mean small gradients, which are numerically stable.

A powerful mental model: Think of ResNet as an ensemble of shallower networks. When you "unravel" the skip connections, a ResNet with 3 residual blocks actually contains paths of length 1, 2, and 3, plus combinations. The network effectively learns at multiple depths simultaneously, and the skip connections ensure gradients can always flow through the shorter paths even if longer paths have vanishing gradients.
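The gradient story can be checked directly with autograd. For y = F(x) + x, the local derivative is F′(x) + 1, so even when F′(x) is almost zero the gradient survives. A toy sketch (my illustration, using a scalar F(x) = w·x with a deliberately tiny w):

```python
import torch

# F(x) = w * x with a tiny weight, so F'(x) is nearly zero
x = torch.tensor(2.0, requires_grad=True)
w = torch.tensor(1e-4)

y_plain = w * x        # plain layer: dy/dx = w, nearly vanished
y_resid = w * x + x    # residual layer: dy/dx = w + 1, alive and well

g_plain, = torch.autograd.grad(y_plain, x)
g_resid, = torch.autograd.grad(y_resid, x)
print(g_plain.item())  # ~1e-4
print(g_resid.item())  # ~1.0001: the +x term contributes a constant 1
```

The "+1" in the derivative is exactly the short path of the ensemble view: no matter how badly the residual branch is conditioned, the identity path carries gradient through unchanged.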

02 ResNet Architecture — Every Layer Explained

The ResNet Family

Variant Block Type Block Config (per stage) Total Params
ResNet-18 Basic (2 convs) 2, 2, 2, 2 ~11M
ResNet-34 Basic (2 convs) 3, 4, 6, 3 ~21M
ResNet-50 Bottleneck (3 convs) 3, 4, 6, 3 ~25M
ResNet-101 Bottleneck (3 convs) 3, 4, 23, 3 ~44M
ResNet-152 Bottleneck (3 convs) 3, 8, 36, 3 ~60M

The Stem: First Contact with the Image

Every ResNet begins with: Input (224×224×3) → Conv 7×7, 64 filters, stride 2 → BN → ReLU → MaxPool 3×3, stride 2

After the stem, you're at 56×56×64. ResNet uses only one MaxPooling in the entire network. All other spatial reduction happens through strided convolutions.

Why 7×7 in the stem? The first layer needs a large receptive field to capture meaningful low-level features (edges, textures) from raw pixels, and its stride of 2 immediately halves the resolution to keep the rest of the network affordable.

Basic Block (ResNet-18/34)

Input (x) ───────────────────────────┐
    │                                │ (identity shortcut)
Conv 3×3 → BN → ReLU                 │
    │                                │
Conv 3×3 → BN                        │
    │                                │
    + ◄──────────────────────────────┘
    │
  ReLU
    │
 Output

Two 3×3 convolutions with a skip connection. The BN-ReLU ordering here is V1 style: Conv → BN → ReLU. In V2 (pre-activation ResNet), it's BN → ReLU → Conv, which gives slightly better gradient flow.

Bottleneck Block (ResNet-50/101/152)

Input (x) ──────────────────────────────────┐
    │                                       │ (1×1 projection if dims change)
Conv 1×1 (reduce) → BN → ReLU               │
    │                                       │
Conv 3×3          → BN → ReLU               │
    │                                       │
Conv 1×1 (expand) → BN                      │
    │                                       │
    + ◄─────────────────────────────────────┘
    │
  ReLU
    │
 Output

The bottleneck uses a "squeeze-expand" pattern. Why bottleneck? A single 3×3 conv mapping 256 → 256 channels costs 3 × 3 × 256 × 256 = 589,824 parameters. The bottleneck sequence (1×1: 256 → 64, then 3×3: 64 → 64, then 1×1: 64 → 256) costs 16,384 + 36,864 + 16,384 = 69,632, roughly 8.5× fewer parameters for a similar receptive-field contribution.
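The arithmetic, spelled out as plain bookkeeping (no framework needed):

```python
# Plain design: one 3×3 conv mapping 256 → 256 channels
plain = 3 * 3 * 256 * 256

# Bottleneck design: 1×1 squeeze (256→64), 3×3 (64→64), 1×1 expand (64→256)
squeeze = 1 * 1 * 256 * 64
spatial = 3 * 3 * 64 * 64
expand  = 1 * 1 * 64 * 256
bottleneck = squeeze + spatial + expand

print(plain)                          # 589824
print(bottleneck)                     # 69632
print(round(plain / bottleneck, 1))   # 8.5
```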

Handling Dimension Mismatches

Two types of skip connections:

  • Identity shortcut — when dims match, just add.
  • Projection shortcut — when dims change, use 1×1 conv with stride 2.

The Four Stages

Stage Output Size Channels Note
Stage 1 56×56 64 (Basic) / 256 (Bottleneck) No spatial reduction
Stage 2 28×28 128 / 512 First conv has stride 2
Stage 3 14×14 256 / 1024 First conv has stride 2
Stage 4 7×7 512 / 2048 First conv has stride 2

After Stage 4: Global Average Pooling (7×7 → 1×1) → FC layer → 1000 classes (ImageNet) or 100 classes (CIFAR-100).
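To make the head concrete, here is the shape bookkeeping for one ResNet-50 feature map, sketched with standalone PyTorch modules (a freshly initialized Linear, purely to show shapes):

```python
import torch
import torch.nn as nn

# Stage 4 of ResNet-50 emits a 7×7 map with 2048 channels per image
feat = torch.randn(1, 2048, 7, 7)

pooled = nn.AdaptiveAvgPool2d((1, 1))(feat)  # average each 7×7 map → (1, 2048, 1, 1)
flat = torch.flatten(pooled, 1)              # drop spatial dims → (1, 2048)
logits = nn.Linear(2048, 1000)(flat)         # one logit per class → (1, 1000)

print(flat.shape, logits.shape)
```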

03 Implementation from Scratch in PyTorch

The Basic Block

Python resnet.py
import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    expansion = 1  # output channels = out_channels × expansion

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample  # projection shortcut

    def forward(self, x):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity  # THE skip connection
        out = self.relu(out)
        return out

Why bias=False? Because Batch Normalization immediately follows the convolution and includes its own learnable bias (β parameter). Having bias in both conv and BN is redundant.

The Bottleneck Block

Python resnet.py
class Bottleneck(nn.Module):
    expansion = 4  # output channels = out_channels × 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # 1×1 squeeze
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # 3×3 spatial conv (stride applied here for spatial reduction)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # 1×1 expand
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out

The Full ResNet

Python resnet.py
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        self.in_channels = 64

        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four stages
        self.layer1 = self._make_layer(block, 64,  layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Weight initialization (Kaiming/He init)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * block.expansion,
                         kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * block.expansion),
            )

        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels * block.expansion
        for _ in range(1, num_blocks):
            layers.append(block(self.in_channels, out_channels))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


# Factory functions
def resnet18(num_classes=1000):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)

def resnet34(num_classes=1000):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)

def resnet50(num_classes=1000):
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)

def resnet101(num_classes=1000):
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)

CIFAR-100 Adaptation

CIFAR-100 images are 32×32 — much smaller than ImageNet's 224×224. Replace the stem:

Python resnet_cifar.py
# Replace the stem for CIFAR-sized inputs (32×32)
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
# Remove maxpool entirely — 32×32 is already small!
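With this stem the resolution stays at 32×32 going into stage 1, and the strided stages then reduce it 32 → 16 → 8 → 4, so global average pooling sees a 4×4 map instead of 7×7. A quick sanity check of the new stem's output shape:

```python
import torch
import torch.nn as nn

# CIFAR stem: 3×3, stride 1, padding 1 preserves the 32×32 resolution
stem = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
x = torch.randn(1, 3, 32, 32)
print(stem(x).shape)  # torch.Size([1, 64, 32, 32])
```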

04 The Training Recipe — Making It Actually Work

Optimizer: SGD with Momentum (Not Adam!)

For large-scale image classification, SGD with momentum consistently outperforms Adam in final accuracy, even though Adam converges faster initially.

Why? Adam's adaptive per-parameter learning rates tend to find sharp minima — solutions that have low training loss but generalize poorly. SGD with momentum tends to find flat minima that generalize much better.

Python train.py
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4  # L2 regularization
)

Learning Rate: The One Cycle Policy

Proposed by Leslie Smith. The recipe:

  • Find max LR using LR range test
  • Step 1 (warmup): Linearly increase LR from max_lr/10 to max_lr over ~45% of training
  • Step 2 (decay): Linearly decrease LR from max_lr back to max_lr/10 over ~45%
  • Annihilation (final ~10%): Decrease from max_lr/10 to max_lr/1000

Why does high LR in the middle work? A high learning rate acts as a regularizer: it bounces the optimizer out of sharp, narrow minima, forcing it to find flatter regions that generalize better.

Cyclic momentum: when LR is high, use lower momentum (~0.85); when LR is low, use higher momentum (~0.95). PyTorch's OneCycleLR does this for you by default (cycle_momentum=True, base_momentum=0.85, max_momentum=0.95).

Python train.py
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.1,
    epochs=100,
    steps_per_epoch=len(train_loader),
    pct_start=0.3,
    anneal_strategy='cos',
    div_factor=10,
    final_div_factor=100
)

Real results: Using One Cycle Policy, ResNet-56 achieved 91.54% on CIFAR-10 in 9,310 iterations vs ~64,000 without — a ~7× speedup.

Learning Rate Finder

Python lr_finder.py
from torch_lr_finder import LRFinder

model = resnet50(num_classes=100)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-7, momentum=0.9, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

lr_finder = LRFinder(model, optimizer, criterion, device="cuda")
lr_finder.range_test(train_loader, end_lr=10, num_iter=200)
lr_finder.plot()
lr_finder.reset()

Data Augmentation Strategy

PMDA (Poor Man's Data Augmentation):

Python augmentation.py
import torchvision.transforms as T

train_transforms = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(p=0.5),
    T.ToTensor(),
    T.Normalize(mean=[0.5071, 0.4867, 0.4408],
                std=[0.2675, 0.2565, 0.2761])
])

MMDA — CutOut:

Python augmentation.py
class CutOut:
    def __init__(self, n_holes=1, length=16):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        mask = torch.ones(h, w)
        for _ in range(self.n_holes):
            y = torch.randint(0, h, (1,)).item()
            x = torch.randint(0, w, (1,)).item()
            y1 = max(0, y - self.length // 2)
            y2 = min(h, y + self.length // 2)
            x1 = max(0, x - self.length // 2)
            x2 = min(w, x + self.length // 2)
            mask[y1:y2, x1:x2] = 0
        mask = mask.unsqueeze(0).expand_as(img)
        return img * mask

MixUp:

Python augmentation.py
import numpy as np

def mixup_data(x, y, alpha=0.2):
    # Sample the mixing coefficient lam from Beta(alpha, alpha)
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam
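MixUp also needs a matching loss: interpolate the two per-target losses with the same λ used to mix the inputs. A minimal companion helper (the name is mine; the recipe follows the MixUp paper):

```python
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    # Same convex combination for the loss as for the inputs:
    # loss = lam * L(pred, y_a) + (1 - lam) * L(pred, y_b)
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
```

In the training loop you would call mixup_data on each batch, forward the mixed images, and replace `criterion(outputs, labels)` with `mixup_criterion(criterion, outputs, y_a, y_b, lam)`.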

ImageNet augmentation:

Python augmentation.py
train_transforms = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])
])

val_transforms = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])
])

The Complete Single-GPU Training Loop

Python train.py
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

def train_one_epoch(model, loader, criterion, optimizer, scheduler, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in tqdm(loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()  # Step per ITERATION for OneCycleLR

        running_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    return running_loss / total, 100. * correct / total


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in tqdm(loader, desc="Evaluating"):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)

        running_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    return running_loss / total, 100. * correct / total


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = resnet50(num_classes=100).to(device)

train_dataset = datasets.CIFAR100(root='./data', train=True,
                                   transform=train_transforms, download=True)
val_dataset = datasets.CIFAR100(root='./data', train=False,
                                 transform=val_transforms)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                          num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False,
                        num_workers=4, pin_memory=True)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.1, epochs=100,
    steps_per_epoch=len(train_loader)
)

best_acc = 0.0
for epoch in range(100):
    train_loss, train_acc = train_one_epoch(
        model, train_loader, criterion, optimizer, scheduler, device
    )
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)

    print(f"Epoch {epoch+1}/100 | "
          f"Train Loss: {train_loss:.4f} Acc: {train_acc:.2f}% | "
          f"Val Loss: {val_loss:.4f} Acc: {val_acc:.2f}%")

    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), 'best_model.pth')
        print(f"  → New best: {best_acc:.2f}%")

05 Scaling to Multi-GPU with Distributed Data Parallel (DDP)

Why DDP and Not DataParallel?

Feature DataParallel DDP
Process model Single process, multiple threads Multiple processes
GIL bottleneck Yes (Python GIL) No (separate processes)
Memory usage GPU 0 gets more Even across all GPUs
Communication Gather/scatter on GPU 0 Ring all-reduce (distributed)
Scaling Poor beyond 2–4 GPUs Near-linear scaling

Understanding Ring All-Reduce

GPU 0 → GPU 1 → GPU 2 → GPU 3 → GPU 0

Phase 1 (Scatter-Reduce): each GPU's gradient is split into N chunks. At every step, each GPU sends one chunk to its neighbor and adds the chunk it receives into its own copy. After N−1 steps, each GPU holds the fully reduced version of exactly one chunk.

Phase 2 (All-Gather): each GPU circulates its reduced chunk around the ring, with receivers overwriting their stale copies. After another N−1 steps, every GPU has the complete summed gradient.

Each GPU sends and receives only 2(N−1)/N times the gradient size, so per-GPU communication stays nearly constant (approaching 2×) no matter how many GPUs you add.
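To build intuition, here is a toy simulation of the two phases with plain Python lists. Numbers stand in for gradient chunks; this sketches the algorithm only, not NCCL's actual implementation:

```python
def ring_all_reduce(grads):
    """grads[i] is GPU i's gradient, pre-split into len(grads) chunks."""
    n = len(grads)
    chunks = [list(g) for g in grads]

    # Phase 1: scatter-reduce. At step s, GPU i sends chunk (i - s) % n to
    # GPU i+1, which adds it in. After n-1 steps, GPU i holds the fully
    # reduced chunk (i + 1) % n.
    for step in range(n - 1):
        sends = [(i, (i - step) % n) for i in range(n)]
        vals = [chunks[i][c] for i, c in sends]   # snapshot: sends happen simultaneously
        for (i, c), v in zip(sends, vals):
            chunks[(i + 1) % n][c] += v

    # Phase 2: all-gather. At step s, GPU i forwards its freshest complete
    # chunk (i + 1 - s) % n; the receiver overwrites its stale copy.
    for step in range(n - 1):
        sends = [(i, (i + 1 - step) % n) for i in range(n)]
        vals = [chunks[i][c] for i, c in sends]
        for (i, c), v in zip(sends, vals):
            chunks[(i + 1) % n][c] = v

    return chunks

# 4 "GPUs", GPU i holding gradient [i+1, i+1, i+1, i+1]; chunk sums are all 10
result = ring_all_reduce([[i + 1] * 4 for i in range(4)])
print(result)  # every GPU ends with [10, 10, 10, 10]
```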

Setting Up DDP

Step 1: Process Group Initialization

Python ddp_setup.py
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup():
    dist.destroy_process_group()

Step 2: Distributed Data Sampler

Python ddp_setup.py
train_dataset = datasets.ImageFolder(
    '/data/imagenet/train',
    transform=train_transforms
)

train_sampler = DistributedSampler(
    train_dataset,
    num_replicas=world_size,
    rank=rank,
    shuffle=True
)

train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    sampler=train_sampler,
    num_workers=4,
    pin_memory=True,
    drop_last=True
)

Critical: Call train_sampler.set_epoch(epoch) at the start of every epoch. The sampler seeds its shuffle with the epoch number; without this call, every epoch replays the same shuffle order on every GPU, which measurably hurts generalization.

Step 3: Wrap the Model

Python ddp_setup.py
def create_model(rank, num_classes=1000):
    model = resnet50(num_classes=num_classes).to(rank)
    model = DDP(model, device_ids=[rank])
    return model

Step 4: Full DDP Training Script

Python train_ddp.py
import torch.multiprocessing as mp
from datetime import timedelta

def train_ddp(rank, world_size, epochs=90):
    setup(rank, world_size)

    model = create_model(rank)

    train_dataset = datasets.ImageFolder(
        '/data/imagenet/train', transform=train_transforms
    )
    train_sampler = DistributedSampler(train_dataset, world_size, rank)
    train_loader = DataLoader(
        train_dataset, batch_size=64, sampler=train_sampler,
        num_workers=8, pin_memory=True, drop_last=True
    )

    val_dataset = datasets.ImageFolder(
        '/data/imagenet/val', transform=val_transforms
    )
    val_loader = DataLoader(
        val_dataset, batch_size=128, shuffle=False,
        num_workers=4, pin_memory=True
    )

    criterion = nn.CrossEntropyLoss().to(rank)
    optimizer = torch.optim.SGD(
        model.parameters(), lr=0.1 * (world_size / 4),
        momentum=0.9, weight_decay=1e-4
    )
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=0.1 * (world_size / 4),
        epochs=epochs,
        steps_per_epoch=len(train_loader)
    )

    for epoch in range(epochs):
        train_sampler.set_epoch(epoch)
        model.train()

        for images, labels in train_loader:
            images = images.to(rank, non_blocking=True)
            labels = labels.to(rank, non_blocking=True)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

        if rank == 0:
            val_loss, val_acc = evaluate(model.module, val_loader, criterion, rank)
            print(f"Epoch {epoch+1}/{epochs} | Val Acc: {val_acc:.2f}%")

            if (epoch + 1) % 10 == 0:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.module.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                }, f'checkpoint_epoch_{epoch+1}.pt')

    cleanup()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(train_ddp, args=(world_size,), nprocs=world_size, join=True)

Launching with torchrun

Bash terminal
# Single node, 4 GPUs
torchrun --nproc_per_node=4 train.py

# Multi-node (2 nodes, 4 GPUs each)
# On node 0:
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
         --master_addr=<NODE0_IP> --master_port=12355 train.py

# On node 1:
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 \
         --master_addr=<NODE0_IP> --master_port=12355 train.py

With torchrun:

Python train_ddp.py
def setup_torchrun():
    dist.init_process_group("nccl")
    rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(rank)
    return rank

06 AWS EC2 Setup for Multi-GPU Training

Instance Selection

Instance GPUs GPU Memory Cost (approx)
p3.8xlarge 4 × V100 4 × 16GB ~$12/hr
p3.16xlarge 8 × V100 8 × 16GB ~$24/hr
p4d.24xlarge 8 × A100 8 × 40GB ~$32/hr

Quick Setup Checklist

Bash setup.sh
# 1. Launch instance with Deep Learning AMI

# 2. Download ImageNet (~150GB)
#    /data/imagenet/
#    ├── train/
#    │   ├── n01440764/
#    │   └── n15075141/
#    └── val/
#        ├── n01440764/
#        └── n15075141/

# 3. Install dependencies
pip install torch torchvision tqdm tensorboard

# 4. Verify GPU availability
python -c "import torch; print(f'GPUs: {torch.cuda.device_count()}')"

# 5. Launch training
torchrun --nproc_per_node=4 train_imagenet.py

Pro Tips

Start small before going big. Train on a small subset on Colab first.

Use spot instances for 60–70% savings. Implement checkpoint saving.

Monitor with TensorBoard:

Python monitoring.py
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/resnet50_imagenet')
writer.add_scalar('Loss/train', train_loss, epoch)
writer.add_scalar('Accuracy/train', train_acc, epoch)
writer.add_scalar('Accuracy/val', val_acc, epoch)
writer.add_scalar('LR', scheduler.get_last_lr()[0], epoch)

GPU monitoring: Use nvidia-smi dmon to monitor GPU utilization. If utilization is < 80%, increase num_workers in your DataLoader to keep the GPU fed with data.

07 Checkpoint & Resume

Python checkpoint.py
def save_checkpoint(model, optimizer, scheduler, epoch, step, path):
    checkpoint = {
        'epoch': epoch,
        'step': step,
        'model_state_dict': model.module.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }
    torch.save(checkpoint, path)
    print(f"Checkpoint saved: {path}")


def load_checkpoint(model, optimizer, scheduler, path, rank):
    checkpoint = torch.load(path, map_location=f'cuda:{rank}')
    model.module.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    start_step = checkpoint['step']
    print(f"Resumed from epoch {start_epoch}, step {start_step}")
    return start_epoch, start_step

08 Common Pitfalls and Debugging Guide

Pitfall 1: Forgetting set_epoch()

Python pitfalls.py
# WRONG
for epoch in range(num_epochs):
    for data in train_loader:
        ...

# RIGHT
for epoch in range(num_epochs):
    train_sampler.set_epoch(epoch)  # DON'T FORGET THIS
    for data in train_loader:
        ...

Pitfall 2: Accessing model.module vs model

Python pitfalls.py
# Saving
torch.save(model.module.state_dict(), 'model.pth')  # ✅

# Loading
model.module.load_state_dict(torch.load('model.pth'))  # ✅
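A related gotcha: if you save model.state_dict() from a DDP model (without .module), every key gains a "module." prefix and will refuse to load into a plain model. A small helper to strip it (my addition; removeprefix needs Python 3.9+):

```python
def strip_module_prefix(state_dict):
    # DDP prefixes every parameter name with "module."; remove it so the
    # checkpoint loads into an unwrapped model
    return {k.removeprefix('module.'): v for k, v in state_dict.items()}
```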

Pitfall 3: Batch Norm + Small Per-GPU Batch Size

Python pitfalls.py
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DDP(model, device_ids=[rank])

Pitfall 4: LR Not Scaled with Batch Size

Linear scaling rule: new_lr = base_lr × (new_batch_size / base_batch_size). Add warmup of 5 epochs when using large LRs.
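As arithmetic, the rule and its warmup look like this (the 0.1-at-batch-256 baseline follows the common Goyal et al. convention; treat the numbers as an example, not gospel):

```python
def scaled_lr(base_lr, base_batch, effective_batch):
    # Linear scaling rule: LR grows in proportion to the effective batch size
    return base_lr * effective_batch / base_batch

def warmup_lr(target_lr, step, warmup_steps):
    # Linear warmup: ramp from near zero up to target_lr over warmup_steps
    return target_lr * min(1.0, (step + 1) / warmup_steps)

# 8 GPUs × 64 images per GPU = effective batch 512; base recipe is 0.1 @ 256
target = scaled_lr(0.1, 256, 512)
print(target)  # 0.2
```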

Pitfall 5: OOM on ImageNet

Use mixed precision:

Python mixed_precision.py
scaler = torch.cuda.amp.GradScaler()

for images, labels in train_loader:
    optimizer.zero_grad()

    with torch.cuda.amp.autocast():
        outputs = model(images)
        loss = criterion(outputs, labels)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()

09 Target Benchmarks

Task Target Typical Epochs
ResNet on CIFAR-100 73% top-1 accuracy ~100
ResNet-50 on ImageNet-1K 75% top-1 accuracy ~90

The ImageNet target of 75% top-1 is a well-known benchmark. With modern tricks (MixUp, CutMix, longer training, cosine LR), you can push to 78–80%.

10 Deployment on HuggingFace Spaces

Python app.py
# app.py for Gradio on HuggingFace Spaces
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image

model = resnet50(num_classes=1000)
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
model.eval()

with open('imagenet_classes.txt') as f:
    labels = [line.strip() for line in f.readlines()]

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
])

def predict(image):
    img = transform(image).unsqueeze(0)
    with torch.no_grad():
        outputs = model(img)
        probs = torch.nn.functional.softmax(outputs[0], dim=0)

    top5_prob, top5_idx = torch.topk(probs, 5)
    return {labels[idx]: prob.item() for prob, idx in zip(top5_prob, top5_idx)}

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title="ResNet-50 ImageNet Classifier (Trained from Scratch)"
)
demo.launch()

Summary: The Complete Recipe

  • 1. Understand the architecture — ResNet solves the degradation problem through skip connections.
  • 2. Implement from scratch — Build BasicBlock, Bottleneck, and the full ResNet.
  • 3. Train with the right recipe — SGD with momentum, One Cycle Policy, appropriate augmentation.
  • 4. Scale with DDP — DistributedDataParallel with ring all-reduce. Scale LR with effective batch size.
  • 5. Be resilient — Checkpoint frequently, handle spot instance terminations, monitor GPU utilization.
  • 6. Deploy — Share on HuggingFace Spaces.
Training ResNet-50 from scratch on full ImageNet is something only about 10,000 people in the world have done. If you follow this guide to the end, you'll be one of them.