Great question. The answer is rooted in a fundamental architectural difference between CNNs and Transformers.
The Core Problem: Transformers Lack the Spatial Inductive Biases of CNNs
CNNs come with two hardcoded assumptions baked into their architecture:
- Translation equivariance — if a cat moves 10px right, the feature map shifts 10px right. The same filter fires regardless of position.
- Locality — each neuron only looks at a small neighborhood. Features are built hierarchically from edges → textures → parts → objects.
These are not learned. They are structural priors that the architecture enforces by design.
A Vision Transformer has neither of these. Every patch attends to every other patch from layer 1. There is no notion of "nearby patches matter more." The model must learn spatial relationships entirely from data.
What This Means for Sample Efficiency
A ResNet-50 trained on 1,000 images can still learn something useful, because the convolution kernel already knows "look locally, slide across the image."
A ViT trained on 1,000 images learns almost nothing, because it has no prior about what to look at. The attention mechanism is completely unconstrained — it needs to discover spatial structure from scratch.
This is why the original ViT paper (Dosovitskiy et al.) required JFT-300M (300 million images) to outperform CNNs. On ImageNet alone (1.2M images) it was worse than ResNets.
DeiT's entire contribution is: "can we make ViT work on ImageNet-scale data without 300M images?" The answer is yes, but only if you aggressively compensate for the missing inductive bias through augmentation.
How Augmentation Compensates for Each Missing Bias
1. RandAugment → teaches translation + viewpoint invariance
Since there's no translation equivariance built in, the model has to see the same object in many positions and orientations to learn that position doesn't change class identity. RandAugment randomly applies operations such as rotation, shear, color shift, posterize, and solarize, forcing the model to see each training image in dozens of transformed versions across epochs.
Without this, a ViT would learn "cat at position (14,14) → class cat" rather than "cat anywhere → class cat."
2. Random Erasing → forces global reasoning
CNNs naturally learn local texture features. ViTs can learn global features but won't necessarily — they might also latch onto local patches if those are sufficient for classification.
Random Erasing randomly blacks out rectangular regions of the image. This forces the model to not rely on any single patch or region. If the face of a dog is erased, the model must use the body, the paws, the fur texture across other patches. This directly trains the attention mechanism to actually distribute attention globally — which is the entire theoretical advantage of Transformers.
3. Mixup (α=0.8) → smooths the decision boundary
Mixup creates convex combinations of two images and their labels:
```
x_mixed = λ·x_i + (1-λ)·x_j
y_mixed = λ·y_i + (1-λ)·y_j
```

For a CNN this is a regularization trick. For a ViT it does something deeper: the model must produce intermediate confidence for mixed inputs. This prevents the model from being overconfident on any single patch's signal and forces it to build distributed, uncertain representations rather than firing hard on one dominant patch token.
It is essentially regularizing the attention weights themselves indirectly — you can't stake everything on one patch if the image is a 60/40 blend of a cat and a truck.
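The blending above can be sketched in a few lines. This is the standard mixup recipe (DeiT itself uses timm's `Mixup` implementation); the function name and `num_classes` default are mine:

```python
import torch
import torch.nn.functional as F

def mixup(x, y, alpha=0.8, num_classes=10):
    # Sample the mixing coefficient from Beta(alpha, alpha)
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(x.size(0))          # pair each image with a shuffled partner
    x_mixed = lam * x + (1 - lam) * x[perm]   # convex combination of pixels
    y_onehot = F.one_hot(y, num_classes).float()
    y_mixed = lam * y_onehot + (1 - lam) * y_onehot[perm]  # and of labels
    return x_mixed, y_mixed
```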
4. CutMix (α=1.0) → teaches patch-level spatial reasoning
CutMix is more aggressive than Mixup. Instead of blending pixels, it cuts a rectangular region from one image and pastes it into another, with labels mixed proportionally to area:

```
A dog image with a 40% region replaced by a car patch → label: 60% dog, 40% car
```

This is almost surgically designed for patch-based models. The model receives an image where some patch tokens literally come from a different class. It must:
- Identify which patches belong to which object
- Weigh them proportionally
- Not get confused by the foreign patches
This directly trains the attention mechanism to be discriminative at the patch token level — exactly what self-attention needs to do well to be useful.
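The paste-and-reweight logic can be sketched directly. This is illustrative (DeiT uses timm's batch-level implementation); note the label weight is recomputed from the actual clipped box area:

```python
import torch
import torch.nn.functional as F

def cutmix(x, y, alpha=1.0, num_classes=10):
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(x.size(0))
    H, W = x.shape[-2:]
    # Box side scales with sqrt(1 - lam) so its area fraction is (1 - lam)
    cut_h, cut_w = int(H * (1 - lam) ** 0.5), int(W * (1 - lam) ** 0.5)
    cy, cx = torch.randint(H, (1,)).item(), torch.randint(W, (1,)).item()
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, H)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, W)
    x = x.clone()
    x[:, :, y1:y2, x1:x2] = x[perm, :, y1:y2, x1:x2]  # paste foreign region
    lam = 1 - ((y2 - y1) * (x2 - x1)) / (H * W)       # true surviving area
    y_onehot = F.one_hot(y, num_classes).float()
    return x, lam * y_onehot + (1 - lam) * y_onehot[perm]
```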
5. Repeated Augmentation → addresses the small dataset problem
The paper also uses repeated augmentation: each sampled image appears multiple times in a batch, each copy with a different augmentation, rather than once. This multiplies the number of augmented views seen without adding new data. For a model with weak spatial priors, each extra independent view of an image supplies additional, if correlated, training signal.
The Unified Reason
Every one of these augmentations is solving the same underlying problem stated differently:
A ViT patch tokenizer destroys spatial structure. Self-attention must reconstruct it. That reconstruction requires enormous amounts of evidence. Augmentation manufactures that evidence artificially.
CNNs don't need this because the convolutional structure preserves spatial relationships for free. The filter literally slides — spatial structure is never destroyed in the first place.
This is also why DeiT's augmentation choices are not arbitrary — they were selected specifically because each one targets a different axis of spatial/semantic invariance that ViTs must learn from scratch:
| Augmentation | Invariance being taught |
|---|---|
| RandAugment | Viewpoint, color, photometric |
| Random Erasing | Partial occlusion, global context |
| Mixup | Soft label calibration, feature interpolation |
| CutMix | Patch-level spatial discrimination |
| Repeated Aug | Sample diversity, low-data compensation |
Remove any one of them and accuracy drops measurably. Remove all of them and you get a model that barely trains — which is exactly what you'd see if you ran your original 3-epoch script.
Yes, the paper mentions several other things beyond augmentation. Let me go through everything:
1. Knowledge Distillation via the Distillation Token — The Core DeiT Contribution
This is what separates DeiT from vanilla ViT entirely. The paper introduces a second learnable token appended alongside the class token.
```
Input sequence: [CLS token] [DIST token] [patch tokens...]
```

- The `[CLS]` token is trained with standard cross-entropy against ground-truth labels
- The `[DIST]` token is trained to mimic the teacher's output
- Both tokens attend to all patch tokens through the full transformer depth
- At inference, you average both: `(cls_out + dist_out) / 2`
Your code already does this correctly. The key insight the paper makes is that these two tokens learn complementary representations — the class token learns from labels, the distillation token learns the teacher's "view" of the data. The paper shows experimentally that they are not redundant; combining them consistently beats using either alone.
2. Hard Distillation vs Soft Distillation — Paper Compares Both
The paper actually experiments with both variants and finds hard distillation slightly outperforms soft.
Soft distillation (what most KD literature uses):
```python
# Teacher softmax probabilities as targets
T = 3.0  # temperature
soft_teacher = F.softmax(teacher_out / T, dim=1)
loss_distill = kl_loss(
    F.log_softmax(dist_out / T, dim=1),
    soft_teacher
) * (T ** 2)
```

Hard distillation (what the paper recommends and you use):

```python
teacher_label = teacher_out.argmax(dim=1)
loss_distill = ce_loss(dist_out, teacher_label)
```

The paper's reasoning: hard labels from a strong teacher like RegNet carry enough signal without needing soft probability distributions. The argmax already encodes what the teacher is confident about.
3. Teacher Choice Matters — RegNet, Not ResNet
The paper specifically recommends using RegNetY-16GF as the teacher, not ResNet. The reasoning given is important:
- RegNet is a ConvNet, so it has the inductive biases ViT lacks
- The distillation token learns to mimic a model that already understands spatial structure natively
- This transfers the CNN's inductive bias into the ViT through the distillation signal
The paper explicitly shows that using another ViT as teacher performs worse than using a ConvNet teacher. The teacher's architectural prior matters, not just its accuracy.
For your CIFAR-10 setup, a well-trained ResNet-50 is a reasonable substitute, but worth knowing the paper is opinionated about this.
4. Positional Embeddings — Learned, Not Fixed
The paper uses learned positional embeddings, not sinusoidal ones. This is already standard in most ViT implementations but worth verifying in your deit.py:
```python
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, d_model))
# +2 for CLS token and DIST token
```

The +2 is important. Many ViT implementations only add +1 for the CLS token. If your positional embedding only accounts for +1, the distillation token has no positional encoding, which hurts.
5. Stochastic Depth (DropPath) — Not Regular Dropout
The paper uses stochastic depth (DropPath), not standard dropout on attention weights. This is a significant difference.
Regular dropout randomly zeros individual neurons. Stochastic depth randomly drops entire residual blocks during training:
```python
from timm.layers import DropPath

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, dropout, drop_path_rate):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = FeedForward(d_model, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.ff(self.norm2(x)))
        return x
```

The drop path rate is linearly increased across layers: early layers get a low rate, deep layers get a higher rate. For DeiT-Small the paper uses a max rate of 0.1:

```python
# When building layers:
drop_path_rates = [x.item() for x in torch.linspace(0, 0.1, num_layers)]
```

6. Pre-norm, Not Post-norm
The paper uses pre-norm (LayerNorm before attention/FFN), not post-norm. This is a training stability choice that matters more as depth increases:
```python
# ✅ Pre-norm (what paper uses)
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.ff(self.norm2(x)))

# ❌ Post-norm (original "Attention is All You Need" style)
x = self.norm1(x + self.attn(x))
x = self.norm2(x + self.ff(x))
```

7. No Bias in QKV Projection
A smaller detail but the paper follows the ViT convention of no bias terms in the QKV projection matrices:
```python
self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
```

Complete Checklist Against the Paper
| Component | Paper Spec | Status |
|---|---|---|
| Distillation token | Learnable, attends with CLS | Check your deit.py |
| Positional embed | Learned, size num_patches + 2 | Check +2 not +1 |
| Distillation type | Hard label (teacher argmax) | ✅ Your code |
| Loss weighting | 0.5 CLS + 0.5 DIST | ✅ Your code |
| Inference | (cls + dist) / 2 | ✅ Your code |
| Optimizer | AdamW, wd=0.05 | ❌ Fix needed |
| LR schedule | Cosine annealing | ❌ Fix needed |
| Label smoothing | ε = 0.1 | ❌ Fix needed |
| Augmentation | RandAugment + Erasing | ❌ Fix needed |
| Mixup | α = 0.8 | ❌ Fix needed |
| CutMix | α = 1.0 | ❌ Fix needed |
| Stochastic depth | DropPath, max 0.1, linear | ❌ Likely missing |
| Pre-norm | LayerNorm before attn/ff | Check your deit.py |
| QKV bias | No bias | Check your deit.py |
| Teacher type | ConvNet (RegNet preferred) | ⚠️ ResNet is acceptable |
The three things most likely missing from your deit.py itself are the +2 positional embedding, DropPath, and pre-norm ordering. Worth auditing those directly.