Thursday, 2 April 2026

DEIT

 Great question. The answer is rooted in a fundamental architectural difference between CNNs and Transformers.


The Core Problem: Transformers Have No Inductive Biases

CNNs come with two hardcoded assumptions baked into their architecture:

  1. Translation equivariance — if a cat moves 10px right, the feature map shifts 10px right. The same filter fires regardless of position.
  2. Locality — each neuron only looks at a small neighborhood. Features are built hierarchically from edges → textures → parts → objects.

These are not learned. They are structural priors that the architecture enforces by design.

A Vision Transformer has neither of these. Every patch attends to every other patch from layer 1. There is no notion of "nearby patches matter more." The model must learn spatial relationships entirely from data.


What This Means for Sample Efficiency

A ResNet-50 trained on 1,000 images can still learn something useful, because the convolution kernel already knows "look locally, slide across the image."

A ViT trained on 1,000 images learns almost nothing, because it has no prior about what to look at. The attention mechanism is completely unconstrained — it needs to discover spatial structure from scratch.

This is why the original ViT paper (Dosovitskiy et al.) required JFT-300M (300 million images) to outperform CNNs. On ImageNet alone (1.2M images) it was worse than ResNets.

DeiT's entire contribution is: "can we make ViT work on ImageNet-scale data without 300M images?" The answer is yes, but only if you aggressively compensate for the missing inductive bias through augmentation.


How Augmentation Compensates for Each Missing Bias

1. RandAugment → teaches translation + viewpoint invariance

Since there's no translation equivariance built in, the model has to see the same object in many positions and orientations to learn that position doesn't change class identity. RandAugment randomly applies rotations, shears, color shifts, posterize, solarize etc. — forcing the model to see each training image in dozens of transformed versions across epochs.

Without this, a ViT would learn "cat at position (14,14) → class cat" rather than "cat anywhere → class cat."

2. Random Erasing → forces global reasoning

CNNs naturally learn local texture features. ViTs can learn global features but won't necessarily — they might also latch onto local patches if those are sufficient for classification.

Random Erasing randomly blacks out rectangular regions of the image. This forces the model to not rely on any single patch or region. If the face of a dog is erased, the model must use the body, the paws, the fur texture across other patches. This directly trains the attention mechanism to actually distribute attention globally — which is the entire theoretical advantage of Transformers.

3. Mixup (α=0.8) → smooths the decision boundary

Mixup creates convex combinations of two images and their labels:

x_mixed = λ·x_i + (1-λ)·x_j
y_mixed = λ·y_i + (1-λ)·y_j

For a CNN this is a regularization trick. For a ViT it does something deeper — the model must produce intermediate confidence for mixed inputs. This prevents the model from being overconfident on any single patch's signal and forces it to build distributed, uncertain representations rather than firing hard on one dominant patch token.

It is essentially regularizing the attention weights themselves indirectly — you can't stake everything on one patch if the image is a 60/40 blend of a cat and a truck.
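The two formulas above translate almost directly into code; a minimal sketch operating on one-hot labels:

```python
import torch

def mixup(x, y_onehot, alpha=0.8):
    """Blend a batch with a shuffled copy of itself (minimal sketch)."""
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(x.size(0))
    x_mixed = lam * x + (1 - lam) * x[perm]
    y_mixed = lam * y_onehot + (1 - lam) * y_onehot[perm]
    return x_mixed, y_mixed
```

Because the label is mixed with the same λ as the pixels, the model is rewarded for intermediate confidence, never for a hard one-hot prediction.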

4. CutMix (α=1.0) → teaches patch-level spatial reasoning

CutMix is more aggressive than Mixup. Instead of blending pixels, it cuts a rectangular region from one image and pastes it into another, with labels mixed proportionally to area:

A dog image with a 40% region replaced by a car patch → label: 60% dog, 40% car

This is almost surgically designed for patch-based models. The model receives an image where some patch tokens literally come from a different class. It must:

  • Identify which patches belong to which object
  • Weigh them proportionally
  • Not get confused by the foreign patches

This directly trains the attention mechanism to be discriminative at the patch token level — exactly what self-attention needs to do well to be useful.
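A from-scratch sketch of CutMix (in practice timm's Mixup helper implements both Mixup and CutMix; this is just to show the mechanics):

```python
import torch

def cutmix(x, y_onehot, alpha=1.0):
    """Paste a random box from a shuffled batch; mix labels by area (sketch)."""
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(x.size(0))
    H, W = x.shape[-2:]
    cut_h = int(H * (1 - lam) ** 0.5)   # box covering ~(1 - lam) of the area
    cut_w = int(W * (1 - lam) ** 0.5)
    cy, cx = torch.randint(H, (1,)).item(), torch.randint(W, (1,)).item()
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, H)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, W)
    x = x.clone()
    x[..., y1:y2, x1:x2] = x[perm][..., y1:y2, x1:x2]
    lam = 1 - (y2 - y1) * (x2 - x1) / (H * W)  # actual area after clipping
    return x, lam * y_onehot + (1 - lam) * y_onehot[perm]
```

Note that λ is recomputed from the clipped box, so the label mix always matches the pixels the model actually sees.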

5. Repeated Augmentation → addresses the small dataset problem

The paper also uses repeated augmentation: each sampled image appears several times in the same batch, each copy passing through the augmentation pipeline independently, rather than appearing once. This multiplies the number of distinct views of every image without adding new data. For a model with no locality prior, seeing more augmented views of the same image acts much like having more training data.
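A simplified, single-process sketch of such a sampler (for distributed training, timm ships a RepeatAugSampler that handles sharding properly):

```python
import torch
from torch.utils.data import Sampler

class RepeatedAugSampler(Sampler):
    """Yield each dataset index `repeats` times per epoch (minimal sketch).

    Repeats land next to each other, so each batch contains several
    independently augmented views of the same images.
    """
    def __init__(self, dataset_len, repeats=2):
        self.dataset_len = dataset_len
        self.repeats = repeats

    def __iter__(self):
        order = torch.randperm(self.dataset_len)
        return iter(order.repeat_interleave(self.repeats).tolist())

    def __len__(self):
        return self.dataset_len * self.repeats
```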


The Unified Reason

Every one of these augmentations is solving the same underlying problem stated differently:

A ViT patch tokenizer destroys spatial structure. Self-attention must reconstruct it. That reconstruction requires enormous amounts of evidence. Augmentation manufactures that evidence artificially.

CNNs don't need this because the convolutional structure preserves spatial relationships for free. The filter literally slides — spatial structure is never destroyed in the first place.

This is also why DeiT's augmentation choices are not arbitrary — they were selected specifically because each one targets a different axis of spatial/semantic invariance that ViTs must learn from scratch:

Augmentation   | Invariance being taught
RandAugment    | Viewpoint, color, photometric
Random Erasing | Partial occlusion, global context
Mixup          | Soft label calibration, feature interpolation
CutMix         | Patch-level spatial discrimination
Repeated Aug   | Sample diversity, low-data compensation

Remove any one of them and accuracy drops measurably. Remove all of them and you get a model that barely trains — which is exactly what you'd see if you ran your original 3-epoch script.

And what else does the paper mention? Is all the rest correct, or is anything missing?


Yes, the paper mentions several other things beyond augmentation. Let me go through everything:


1. Knowledge Distillation via the Distillation Token — The Core DeiT Contribution

This is what separates DeiT from vanilla ViT entirely. The paper introduces a second learnable token appended alongside the class token.

Input sequence: [CLS token] [DIST token] [patch tokens...]
  • The [CLS] token is trained with standard cross-entropy against ground truth labels
  • The [DIST] token is trained to mimic the teacher's output
  • Both tokens attend to all patch tokens through the full transformer depth
  • At inference, you average both: (cls_out + dist_out) / 2

Your code already does this correctly. The key insight the paper makes is that these two tokens learn complementary representations — the class token learns from labels, the distillation token learns the teacher's "view" of the data. The paper shows experimentally that they are not redundant; combining them consistently beats using either alone.
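A sketch of how the two tokens are read out (class and attribute names here are illustrative, not necessarily what your deit.py uses):

```python
import torch
import torch.nn as nn

class DeiTHead(nn.Module):
    """Reads out the [CLS] and [DIST] tokens with separate linear heads."""
    def __init__(self, d_model, num_classes):
        super().__init__()
        self.cls_head = nn.Linear(d_model, num_classes)
        self.dist_head = nn.Linear(d_model, num_classes)

    def forward(self, tokens):                   # tokens: (B, 2 + N, d_model)
        cls_out = self.cls_head(tokens[:, 0])    # [CLS] token
        dist_out = self.dist_head(tokens[:, 1])  # [DIST] token
        if self.training:
            return cls_out, dist_out             # two separate losses in training
        return (cls_out + dist_out) / 2          # averaged at inference
```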


2. Hard Distillation vs Soft Distillation — Paper Compares Both

The paper actually experiments with both variants and finds hard distillation slightly outperforms soft.

Soft distillation (what most KD literature uses):

python
import torch.nn.functional as F

# Teacher softmax probabilities as targets
T = 3.0  # temperature
soft_teacher = F.softmax(teacher_out / T, dim=1)
loss_distill = F.kl_div(
    F.log_softmax(dist_out / T, dim=1),
    soft_teacher,
    reduction="batchmean",
) * (T ** 2)

Hard distillation (what the paper recommends and you use):

python
# ce_loss is plain cross-entropy, e.g. F.cross_entropy
teacher_label = teacher_out.argmax(dim=1)
loss_distill = ce_loss(dist_out, teacher_label)

The paper's reasoning: hard labels from a strong teacher like RegNet carry enough signal without needing soft probability distributions. The argmax already encodes what the teacher is confident about.
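Combined with the paper's equal 0.5/0.5 weighting of the two objectives, the full hard-distillation loss is a few lines (a sketch; `cls_out` and `dist_out` are the two token outputs):

```python
import torch.nn.functional as F

def deit_hard_distill_loss(cls_out, dist_out, teacher_out, labels):
    """0.5 * CE(cls, labels) + 0.5 * CE(dist, teacher argmax)."""
    teacher_label = teacher_out.argmax(dim=1)
    return (0.5 * F.cross_entropy(cls_out, labels)
            + 0.5 * F.cross_entropy(dist_out, teacher_label))
```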


3. Teacher Choice Matters — RegNet, Not ResNet

The paper specifically recommends using RegNetY-16GF as the teacher, not ResNet. The reasoning given is important:

  • RegNet is a ConvNet, so it has the inductive biases ViT lacks
  • The distillation token learns to mimic a model that already understands spatial structure natively
  • This transfers the CNN's inductive bias into the ViT through the distillation signal

The paper explicitly shows that using another ViT as teacher performs worse than using a ConvNet teacher. The teacher's architectural prior matters, not just its accuracy.

For your CIFAR-10 setup, a well-trained ResNet-50 is a reasonable substitute, but worth knowing the paper is opinionated about this.


4. Positional Embeddings — Learned, Not Fixed

The paper uses learned positional embeddings, not sinusoidal ones. This is already standard in most ViT implementations but worth verifying in your deit.py:

python
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, d_model))
# +2 for CLS token and DIST token

The +2 is important. Many ViT implementations only add +1 for the CLS token. If your positional embedding only accounts for +1, the distillation token has no positional encoding, which hurts.


5. Stochastic Depth (DropPath) — Not Regular Dropout

The paper uses stochastic depth (DropPath), not standard dropout on attention weights. This is a significant difference.

Regular dropout randomly zeros individual neurons. Stochastic depth randomly drops entire residual blocks during training:

python
from timm.layers import DropPath

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, dropout, drop_path_rate):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff   = FeedForward(d_model, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.ff(self.norm2(x)))
        return x

The drop path rate is linearly increased across layers: early layers get a low rate, deep layers get a higher rate. For DeiT-Small the paper uses a max rate of 0.1:

python
# When building layers:
drop_path_rates = [x.item() for x in torch.linspace(0, 0.1, num_layers)]

6. Pre-norm, Not Post-norm

The paper uses pre-norm (LayerNorm before attention/FFN), not post-norm. This is a training stability choice that matters more as depth increases:

python
# ✅ Pre-norm (what paper uses)
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.ff(self.norm2(x)))

# ❌ Post-norm (original "Attention is All You Need" style)
x = self.norm1(x + self.attn(x))
x = self.norm2(x + self.ff(x))

7. Bias in the QKV Projection

A smaller detail: the reference ViT/DeiT implementations (the original JAX code and timm) actually do include bias terms in the QKV projection (qkv_bias=True). If your deit.py sets bias=False, that is a deviation from the paper's setup, not its convention:

python
self.qkv = nn.Linear(d_model, 3 * d_model, bias=True)  # qkv_bias=True, as in timm

Complete Checklist Against the Paper

Component          | Paper Spec                    | Status
Distillation token | Learnable, attends with CLS   | Check your deit.py
Positional embed   | Learned, size num_patches + 2 | Check +2 not +1
Distillation type  | Hard label (teacher argmax)   | ✅ Your code
Loss weighting     | 0.5 CLS + 0.5 DIST            | ✅ Your code
Inference          | (cls + dist) / 2              | ✅ Your code
Optimizer          | AdamW, wd=0.05                | ❌ Fix needed
LR schedule        | Cosine annealing              | ❌ Fix needed
Label smoothing    | ε = 0.1                       | ❌ Fix needed
Augmentation       | RandAugment + Erasing         | ❌ Fix needed
Mixup              | α = 0.8                       | ❌ Fix needed
CutMix             | α = 1.0                       | ❌ Fix needed
Stochastic depth   | DropPath, max 0.1, linear     | ❌ Likely missing
Pre-norm           | LayerNorm before attn/ff      | Check your deit.py
QKV bias           | Bias included (qkv_bias=True) | Check your deit.py
Teacher type       | ConvNet (RegNet preferred)    | ⚠️ ResNet is acceptable

The three things most likely missing from your deit.py itself are the +2 positional embedding, DropPath, and pre-norm ordering. Worth auditing those directly.
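The ❌ training-side rows in the checklist are each a one-liner to fix; a sketch (the lr, T_max, and stand-in model are placeholders for your script's actual values):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 10)  # stand-in for your DeiT model

# AdamW with wd=0.05, cosine LR decay, and label smoothing of 0.1,
# matching the paper's training recipe.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
```

Call `scheduler.step()` once per epoch; `T_max` should equal the total number of training epochs so the learning rate decays to near zero by the end.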
