
ViT: Vision Transformer (2020, Google Research)

ViT proves that pure Transformers can process images independently of CNNs — a complete architecture walkthrough from Patch Embedding to MLP Head.


1. ViT Overview

Summary: The main contribution is applying an existing architecture, the Transformer, to computer vision. The training method and the scale of the pre-training dataset are the key reasons ViT matches or beats the state of the art on ImageNet. ViT showed that Transformers can replace CNNs, validated the emergent capabilities of large-scale data plus large models in the image domain, and laid the groundwork for multimodal work over the following two years. Its design follows a principle of minimal modification: the original Transformer is applied almost unchanged to image classification. The researchers showed that CNNs are not always necessary; a Transformer alone can perform very well on classification, especially when trained on large-scale datasets. ViT models pre-trained at scale also outperform CNNs after being transferred to medium- or small-scale datasets.

![Image](images/ViT(2020, Google Research)_11.png)

1.1 Background and Significance of ViT

  1. The success of Transformers in NLP inspired their application in CV. The enormous success of self-attention in NLP motivated researchers to explore Transformers for computer vision.
  2. ViT demonstrated that Transformers can process images independently of CNNs. ViT is a model proposed by Google in 2020 that directly applies the Transformer to image classification without any modifications to the Transformer architecture. It divides images into fixed-size patches as tokens, introduces positional encoding to preserve spatial information, and feeds them into the Transformer. Stacked Transformer layers capture both local and global relationships within the image. ViT discards convolution operations entirely, relying solely on self-attention mechanisms to capture long-range dependencies in images, demonstrating that a pure Transformer can efficiently process image data.
  3. ViT's "scale is all you need" principle. After pre-training on very large datasets (such as Google's JFT-300M) and transferring to medium- or small-scale datasets (such as ImageNet, CIFAR-100, and VTAB), ViT can surpass state-of-the-art CNN models. For example, ViT-H/14 pre-trained on JFT-300M reaches 88.55% top-1 accuracy on ImageNet-1K, demonstrating the potential of Transformers in computer vision.

1.2 Comparison of ViT and CNN

| | ViT | CNN |
|---|---|---|
| Core operation | Self-attention (global association) | Convolution (local filtering) |
| Inductive bias | No explicit spatial bias, data-driven | Strong spatial locality, translation equivariance |
| Data requirements | Relies on large-scale pre-training | Efficient on small to medium datasets |
| Long-range dependency modeling | Global interaction in a single layer | Requires deep stacking or special modules (e.g., Non-local) |
| Computational complexity | Quadratic in the number of patches (often mitigated with windowed or sparse attention) | Linear in image size (number of pixels) |
| Interpretability | Native attention heatmaps | Relies on class activation maps (CAM) |
| Typical use cases | Large-scale pre-training, cross-modal tasks | Real-time detection, edge devices |

Key distinctions:

Feature modeling mechanism: local perception vs. global association

  • CNN progressively expands the receptive field through local convolutional layers, relying on hierarchical stacking to capture features, but is limited in modeling long-range dependencies.
  • ViT uses self-attention to dynamically establish global associations between image patches in a single layer, computing attention weights between any two positions. It is especially good at capturing semantic connections across regions (e.g., co-recognition of a cat's head and tail) without requiring deep stacking.

Input processing paradigm: pixel array vs. serialized blocks

  • CNN directly processes raw pixel arrays and relies on weight sharing: the same convolution kernel slides over every location, so all spatial positions are treated equally.
  • ViT divides images into fixed-size patches (e.g., 16×16) and linearly maps them into sequences, introducing learnable positional encoding to retain spatial information, and uses attention weights to dynamically focus on key regions (e.g., enhancing the subject's weight in classification tasks).

Data efficiency and scalability: small-data advantage vs. large-data potential

  • CNN, thanks to its locality prior, converges faster and is less prone to overfitting in small-data scenarios (e.g., ImageNet-1K).
  • ViT lacks spatial bias and requires very large-scale pre-training (e.g., JFT-300M) to realize its potential, but performance improves significantly as data grows (e.g., ViT-H/14 reaches 88.55% accuracy), with stronger transfer learning generalization, indicating more universal representations.

Computational efficiency and deployment: real-time performance vs. interpretability

  • CNN dominates in real-time inference (e.g., mobile) due to linear computational complexity and hardware optimizations.
  • ViT's self-attention is quadratic in the number of patches, which makes high-resolution images expensive to process; this can be mitigated with coarser patching (larger patches) or sparse attention.

Structural characteristics and information flow

  • Information retention: ViT's design without downsampling (e.g., ViT-Base maintains a constant sequence length) preserves more spatial details, benefiting dense prediction tasks (e.g., segmentation); CNN progressively compresses spatial dimensions through pooling, which may cause loss of fine-grained information.
  • Residual connections: ViT's skip connections have a greater impact on feature stability and performance; experiments show that removing them causes a much larger accuracy drop than in CNN (e.g., ResNet), indicating that ViT relies more heavily on identity paths to maintain gradient flow.
  • Interpretability: ViT's attention heatmaps provide intuitive explanations, whereas CNNs rely on post-hoc visualization (e.g., CAM).
  • Representation consistency: Feature maps at different depths in ViT are more similar to each other, indicating that ViT iteratively refines global representations through attention, rather than CNN's hierarchical feature reconstruction.

Key points:

  • Replacement or complement: ViT does not completely replace CNN; it shows advantages in large-data, strong semantic association tasks (e.g., cross-modal understanding), while CNN remains competitive in resource-constrained, small-data scenarios.
  • Trend of fusion: frontier works explore the convergence of the two families. CoAtNet combines convolution and attention in a hybrid architecture, while ConvNeXt modernizes a pure CNN with Transformer-style design choices. Both aim to combine local efficiency with global modeling, signaling that the boundary between the two is gradually blurring.
  • Since Transformers require large amounts of data to achieve high accuracy, the data collection process can lengthen project timelines. In low-data settings, CNNs generally outperform Transformers.
  • At scale, ViT reaches a given accuracy with less pre-training compute than comparable ResNets (the original paper reports roughly 2 to 4 times less). When training time is limited, choose between the two by weighing computational efficiency against accuracy.

2. ViT Model in Detail

The overall ViT architecture can be broken down into 3 modules (5 steps), as shown in the diagram above:

  1. Linear Projection of Flattened Patches: obtain the complete linear embeddings
    • Patch Embedding
    • Classification Token
    • Positional Embedding
  2. Transformer Encoder: extract features with n stacked Transformer Encoder layers
  3. MLP Head: the final classification layer

One-sentence summary: we divide the image into fixed-size patches, linearly embed each patch, prepend an additional learnable classification token [cls] and add a positional embedding to the sequence, then feed the resulting vector sequence into a standard Transformer encoder to obtain feature representations; for classification, an MLP head is applied to the output at the [cls] position.
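The five steps can be sketched end to end with PyTorch's built-in encoder layers. This is a minimal sketch under stated assumptions: the tiny sizes (32×32 inputs, 8×8 patches, dim=64, 2 layers) are illustrative only, and `nn.TransformerEncoderLayer(norm_first=True)` stands in for the custom Encoder Block developed later in this article.

```python
import torch
import torch.nn as nn


class TinyViT(nn.Module):
    """Minimal ViT-style classifier built from PyTorch modules (illustrative sizes, not a paper config)."""

    def __init__(self, img_size=32, patch_size=8, in_channels=3, dim=64, depth=2, heads=4, num_classes=10):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        # Step 1: patch embedding (Conv2d with kernel_size == stride == patch size)
        self.patch_embed = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)
        # Step 2: learnable [cls] token; Step 3: learnable 1D positional embedding
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, dim) * 0.02)
        # Step 4: pre-norm Transformer encoder (LayerNorm before attention and before the MLP)
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=4 * dim,
                                           activation="gelu", batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.norm = nn.LayerNorm(dim)  # final LayerNorm after the encoder
        # Step 5: classification head applied to the [cls] output
        self.head = nn.Linear(dim, num_classes)

    def forward(self, img):
        x = self.patch_embed(img).flatten(2).transpose(1, 2)   # (B, N, dim)
        cls = self.cls_token.expand(x.shape[0], -1, -1)        # (B, 1, dim)
        x = torch.cat((cls, x), dim=1) + self.pos_embed        # (B, N + 1, dim)
        x = self.norm(self.encoder(x))                         # (B, N + 1, dim)
        return self.head(x[:, 0])                              # (B, num_classes)


model = TinyViT()
logits = model(torch.randn(2, 3, 32, 32))
print(logits.shape)  # torch.Size([2, 10])
```

Swapping the built-in layer for a hand-written Encoder Block changes nothing about the data flow: tokens in, tokens of the same shape out.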

2.1 Patch Embedding

Purpose: Obtain a 1D vector that can be fed into the Transformer by splitting the image into patches and performing patch embedding. What method can we use to convert multi-dimensional image data into 1D data similar to NLP?

  • Pixel-by-pixel expansion: if each pixel is treated as an independent token (like a single word in NLP), a 224×224 image yields 224×224 = 50,176 tokens. Because self-attention compares every token with every other token, this is computationally prohibitive. For comparison, BERT processes sequences of only 512 tokens, yet a scaled-up BERT-style model with 481 billion parameters reportedly still required about 20 hours of training on 2,048 TPUv4 chips. A sequence roughly 100 times longer is clearly infeasible.
  • Axis-wise expansion: decompose the image into 1D sequences along the horizontal and vertical axes and compute self-attention along each axis separately. This reduces the attention cost from O((HW)²) to O(HW(H+W)), but a single layer cannot directly relate positions that share neither a row nor a column (e.g., diagonal neighbors).
  • Using processed feature maps as Transformer input: first extract feature maps with a CNN (e.g., ResNet50), then feed them to the Transformer. For example, a ResNet50 backbone can turn an image into a 14×14 feature map (i.e., 196 tokens), greatly reducing the sequence length. ViT's "hybrid" variant follows this approach.
  • Dividing the image into local window blocks: split the image into local windows (e.g., 7×7) and compute self-attention only within each window, similar to local convolution. This greatly reduces the cost (to O(HW·K²) for window size K), and cross-window interaction is achieved through shifted windows or hierarchical structures (e.g., Swin Transformer).
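To make the trade-offs concrete, the snippet below counts the attention scores each strategy computes per head per layer for a 224×224 image. It is a back-of-the-envelope sketch: the axial and window variants are applied directly to pixels here purely for comparison (Swin actually applies windows to patch tokens), and constant factors are ignored.

```python
def attention_pairs(num_tokens: int) -> int:
    """Entries in one full self-attention matrix (every token attends to every token)."""
    return num_tokens * num_tokens


H = W = 224   # input resolution in pixels
P = 16        # ViT patch size
K = 7         # local window size for window attention

pixel_level = attention_pairs(H * W)                  # each pixel is a token: 50,176 tokens
axial = H * W * (H + W)                               # each pixel attends along its row and its column
window = (H * W) * (K * K)                            # each pixel attends only inside its KxK window
cnn_features = attention_pairs(14 * 14)               # 14x14 CNN feature map -> 196 tokens
vit_patches = attention_pairs((H // P) * (W // P))    # 16x16 patches -> 196 tokens

for name, cost in [("pixel-level", pixel_level), ("axial", axial), ("window 7x7", window),
                   ("CNN feature map", cnn_features), ("ViT 16x16 patches", vit_patches)]:
    print(f"{name:>18}: {cost:,} attention scores per head per layer")
```

Full pixel-level attention needs about 2.5 billion scores per head and layer, while both the CNN feature map and 16×16 patches reduce this to 38,416.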

2.1.1 Patch Blocks

Purpose: split the image into patch blocks. The Transformer expects a 2D matrix input of shape (N, D), where N is the sequence length and D is the dimension of each vector in the sequence (D = 768 for ViT-Base). To feed an image into the Transformer Encoder, we must convert a 3D image of shape $H \times W \times C$ into a 2D input:

$$H \times W \times C \;\Rightarrow\; N \times (P^2 \cdot C), \qquad N = HW/P^2$$

ViT does this by evenly dividing the image into windows, called patches, and treating each patch like one word in an NLP Transformer. The image becomes a sequence of $N = HW/P^2$ patches, each flattened to dimension $P^2 \cdot C$, where P is the patch size and C is the number of channels.

![Image](images/ViT(2020, Google Research)_121.png)

Example: data flow for patch blocks. Original image size: 224×224×3 (H×W×C). With patch size P = 16, each patch is 16×16×3, and there are (224/16)×(224/16) = 14×14 = 196 patches.

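```python
import torch
from einops import rearrange  # Implemented using the einops library
img, p = torch.randn(2, 3, 224, 224), 16
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)  # (2, 196, 768)
```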
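If einops is not available, the same rearrangement can be written with plain reshapes and one axis permutation. A minimal NumPy sketch (the `patchify` name is ours, not from a library); the values inside each patch are ordered (p1, p2, c), as in the einops pattern.

```python
import numpy as np


def patchify(img: np.ndarray, p: int) -> np.ndarray:
    """(B, C, H, W) -> (B, N, p*p*C), matching einops 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)'."""
    b, c, H, W = img.shape
    assert H % p == 0 and W % p == 0, "image size must be divisible by the patch size"
    h, w = H // p, W // p
    x = img.reshape(b, c, h, p, w, p)        # split H into (h, p1) and W into (w, p2)
    x = x.transpose(0, 2, 4, 3, 5, 1)        # -> (b, h, w, p1, p2, c)
    return x.reshape(b, h * w, p * p * c)    # -> (b, h*w, p1*p2*c)


img = np.random.rand(2, 3, 224, 224).astype(np.float32)
patches = patchify(img, 16)
print(patches.shape)  # (2, 196, 768)
```

Patches are numbered row by row, so patch 1 is the second patch in the top row.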

2.1.2 Patch Embedding

Purpose: turn each patch into a token. After patching, each patch is flattened into a 1D vector and mapped by a linear transformation (a fully connected layer) to a fixed dimension D, giving the embedding vector (token) of that patch; D is the Transformer's input dimension. Example: a 16×16×3 patch is flattened to a 1×768 vector (768 = 16×16×3) and projected to D = 768 dimensions (for ViT-B/16 the flattened size and D happen to coincide). With 14×14 = 196 patches, we get a matrix $X_{patch}$ of size 196×768.

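```python
import torch, torch.nn as nn
patch_dim, dim = 16 * 16 * 3, 768                         # flattened patch size P*P*C and embedding dim D
patch_to_embedding = nn.Linear(patch_dim, dim)            # one projection shared by all patches
x = patch_to_embedding(torch.randn(2, 196, patch_dim))    # (b, 196, 768) -> (b, 196, 768)
```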

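In practice, patch splitting and the linear projection are usually merged into a single 2D convolution whose kernel size and stride both equal the patch size. This is mathematically equivalent to flattening each patch and applying a shared Linear layer, and it is how the official implementation and popular libraries such as timm compute patch embeddings.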

![Image](images/ViT(2020, Google Research)_136.png)

Example: data flow for Patch Embedding. With embedding dimension embed_dim = 768, i.e., 768 convolution kernels of size 16×16×3 applied with stride 16, convolving a 224×224×3 image produces a 14×14 grid of 768-dimensional outputs, which is flattened into a patch-embedding matrix $X_{patch}$ of size 196×768.

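```python
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    """Split an image into patches and project each one to embed_dim with a strided Conv2d."""

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768, norm_layer=None):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # kernel_size == stride == patch_size: each output position sees exactly one non-overlapping patch
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        # Optional normalization; Identity by default, as in the standard ViT
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)  # Result shape: (batch_size, embed_dim, num_patches_H, num_patches_W)
        x = x.flatten(2)  # Flatten output to (batch_size, embed_dim, num_patches)
        x = x.transpose(1, 2)  # Transpose to (batch_size, num_patches, embed_dim)
        x = self.norm(x)
        return x


if __name__ == "__main__":
    img_size = 224       # Image size
    patch_size = 16      # Size of each patch
    in_channels = 3      # Number of image channels
    embed_dim = 768      # Patch embedding dimension

    patch_embedding = PatchEmbedding(img_size=img_size, patch_size=patch_size, in_channels=in_channels,
                                     embed_dim=embed_dim)
    batch_size = 2
    x = torch.randn(batch_size, in_channels, img_size, img_size)
    output = patch_embedding(x)

    print("Final output shape:", output.shape)  # torch.Size([2, 196, 768])
```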
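The claimed equivalence between "patchify + Linear" and a strided Conv2d can be checked numerically. In this sketch we copy the Linear weights into a Conv2d and compare outputs; note that the patches are flattened here in (c, p1, p2) order so that they line up with the Conv2d weight layout (out_channels, in_channels, kH, kW), which differs from the (p1 p2 c) order of the einops pattern.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
P, C, D = 16, 3, 768
img = torch.randn(2, C, 224, 224)

# Route 1: split into patches (flattened in c, p1, p2 order), then a shared Linear layer
linear = nn.Linear(C * P * P, D)
patches = img.unfold(2, P, P).unfold(3, P, P)                  # (B, C, 14, 14, P, P)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(2, 14 * 14, C * P * P)
out_linear = linear(patches)                                    # (B, 196, D)

# Route 2: Conv2d with kernel_size = stride = P, using the same weights
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
with torch.no_grad():
    conv.weight.copy_(linear.weight.view(D, C, P, P))           # (D, C*P*P) -> (D, C, P, P)
    conv.bias.copy_(linear.bias)
out_conv = conv(img).flatten(2).transpose(1, 2)                 # (B, 196, D)

print(torch.allclose(out_linear, out_conv, atol=1e-4))          # True
```

The convolution route is preferred in practice because it avoids materializing the flattened patch tensor.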

Why must images be split into patches and then converted to tokens?

  1. Reduce model computation. In a Transformer with an input sequence of length $N$, self-attention costs $O(N^2)$ because each token computes attention scores with every token (including itself). In ViT, the number of patches $N = HW/P^2$ depends on the patch size $P$. When $P$ is too small, $N$ grows and the cost skyrockets, so the patch size must be chosen to balance accuracy and computational efficiency.
  2. Raw image data contains a lot of redundant information. Unlike natural language, where each word carries rich semantics, neighboring pixel values in an image are often similar. An excessively small patch size (e.g., $P = 1$) therefore makes the computation overly fine-grained, increasing the cost without significantly improving performance. The same redundancy is one reason self-supervised models based on pixel reconstruction, such as MAE (Masked Autoencoder), can mask a large fraction of the patches (75% by default) and still succeed.
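A quick calculation shows how sharply the patch size drives the cost: halving $P$ multiplies $N$ by 4 and the attention matrix by 16. A minimal sketch assuming a square 224×224 input and ignoring the extra [cls] token.

```python
def num_tokens(img_size: int, patch_size: int) -> int:
    """Sequence length N = HW / P^2 for a square image (the [cls] token is not counted)."""
    assert img_size % patch_size == 0
    return (img_size // patch_size) ** 2


for p in (32, 16, 14, 8, 1):
    n = num_tokens(224, p)
    print(f"P={p:>2}: N={n:>6,}  attention matrix N^2={n * n:>15,}")
```

P = 16 gives the 196 tokens used by ViT-B/16, while P = 1 brings back the 50,176-token pixel-level sequence.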

2.2 Classification Token

Purpose: a classification token [cls] is concatenated to the patch embeddings so that the Transformer has a slot that gathers information from the whole image for downstream tasks such as classification. The idea comes from BERT: a single [cls] token is prepended to the sequence of patch embeddings, at the first position. [cls] is a learnable embedding vector that acts as a "placeholder" for the global information of the image, and it is fed into the Transformer Encoder together with the patch embeddings. Through training, the model learns to make the output at the [cls] position represent the whole image, and this output is used to predict the class (or for other downstream tasks). In other words, [cls] ends up encoding the category of the image described by all the other patch embeddings. Two other ways to understand it:

  • ViT only uses the Encoder part of the Transformer, not the Decoder. The role of [cls] is somewhat similar to a Query in a Decoder, with the corresponding Keys and Values coming from the patch embeddings.
  • The other embeddings represent the features of individual patches, while [cls] aggregates information from all patches into a new embedding that represents the entire image.

![Image](images/ViT(2020, Google Research)_148.png)

Example: data flow for the Classification Token. The [cls] token is a 1×768 vector. Concatenating it before the patch embeddings (196×768) yields a 197×768 matrix $X_{cls+patch}$. For classification, the output embedding at the [cls] position is used.

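```python
import torch
import torch.nn as nn
from einops import repeat

b, dim = 2, 768
x = torch.randn(b, 196, dim)                          # patch embeddings from the previous step
cls_token = nn.Parameter(torch.randn(1, 1, dim))      # one learnable [cls] vector shared by all images

# Forward code
# Broadcast [cls] to the batch: (1, 1, 768) -> (b, 1, 768)
cls_tokens = repeat(cls_token, '() n d -> b n d', b=b)
# Prepend [cls] to the patch sequence: (b, 1, 768) + (b, 196, 768) -> (b, 197, 768)
x = torch.cat((cls_tokens, x), dim=1)
```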

2.3 Positional Embedding

Purpose: self-attention by itself does not take input order into account (permuting the input tokens merely permutes the outputs), so Positional Embedding explicitly encodes the spatial position of each patch in the original image. ViT adds a positional embedding $E_{pos}$ to the sequence. In ViT, $E_{pos}$ is learnable, so the model learns a representation for each position from data (absolute positional information). It has the same shape as $X_{cls+patch}$, so the two can be added directly to form the final input of the Transformer Encoder. Analysis in the ViT paper shows that positions closer together tend to have more similar positional embeddings, and that a row-column structure emerges: patches in the same row or column have similar embeddings.

Positional embeddings can be generated in two ways:

  • Fixed-algorithm positional embedding (e.g., sinusoidal)
  • Learnable positional embedding

For positional embeddings in CV, the original paper compares the following options:

  • Ignoring positional information
  • Treating the image like NLP text, considering only 1D positional information (absolute positional information)
  • Considering CV-specific 2D positional information (absolute positional information)
  • Relative positional encoding: considering both relative and absolute positional information

Providing no positional information clearly hurts accuracy, while the 1D, 2D, and relative variants perform about the same, so ViT uses the simple learnable 1D embedding.

Example: data flow for Positional Embedding. The positional embedding is a learnable 197×768 tensor that is added element-wise to $X_{cls+patch}$ (also 197×768) to obtain the final patch embeddings $X_{input}$.

![Image](images/ViT(2020, Google Research)_162.png)

```python
import torch, torch.nn as nn
num_patches, dim = 196, 768
# +1 because of the additional [cls] token at the start of the sequence
pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
```
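For contrast with ViT's learnable embedding, here is a minimal sketch of the "fixed-algorithm" alternative: the 1D sine/cosine scheme from the original Transformer, computed for 197 positions. ViT itself does not use this; the function name and the base 10000 follow the standard Transformer formulation and are assumptions of this example.

```python
import numpy as np


def sinusoidal_pos_embed(num_positions: int, dim: int) -> np.ndarray:
    """Fixed 1D sin/cos positional embedding (Transformer-style), shape (num_positions, dim)."""
    assert dim % 2 == 0
    pos = np.arange(num_positions)[:, None]                    # (L, 1)
    freqs = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))      # (dim/2,) geometric frequencies
    angles = pos * freqs                                       # (L, dim/2)
    emb = np.zeros((num_positions, dim))
    emb[:, 0::2] = np.sin(angles)                              # even dims: sine
    emb[:, 1::2] = np.cos(angles)                              # odd dims: cosine
    return emb


pe = sinusoidal_pos_embed(196 + 1, 768)   # one row per token, including [cls]
print(pe.shape)  # (197, 768)
```

Nothing here is trained, so it adds no parameters; the learnable version instead lets the model discover the 2D row-column structure on its own.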

Interpretation of $X_{input}$ with dimensions 197×768:

  • First row [cls] token: A vector representing the global information of the entire image, whose values are updated during training, aggregating information from all image patches to better capture image features.
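  • Subsequent 196 rows of image patch features: each row holds the features of one image patch and reflects local information from that part of the image. These features come from the patch embedding step, a linear projection often implemented as a convolution whose stride equals its kernel size, so the patches do not overlap.
  • Positional embedding: added element-wise to every row (including [cls]), so each row also encodes its position without adding extra rows or columns.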

2.4 Transformer Encoder

Purpose: extract features from the input sequence. The Transformer Encoder input combines the Patch Embedding, the Classification Token, and the Positional Embedding: [cls] is concatenated to the patch embeddings, and the positional embedding is added element-wise:

$$X_{input} = X_{cls+patch} + E_{pos}$$

The Transformer Encoder is a stack of L identical Encoder Blocks (L = 12 for ViT-Base, i.e., 12 Encoder Blocks). Each Encoder Block consists of two sub-blocks:

  • LayerNorm + Multi-Head Attention
  • LayerNorm + MLP
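For reference, the ViT paper writes the encoder input and these two pre-norm sub-blocks as follows, where $E$ is the patch projection, $x_p^i$ the $i$-th flattened patch, and $z_L^0$ the [cls] position after the last of the L blocks:

$$
\begin{aligned}
z_0 &= [x_{cls};\, x_p^1 E;\, x_p^2 E;\, \dots;\, x_p^N E] + E_{pos}, && E \in \mathbb{R}^{(P^2 \cdot C) \times D},\; E_{pos} \in \mathbb{R}^{(N+1) \times D} \\
z'_\ell &= \mathrm{MSA}(\mathrm{LN}(z_{\ell-1})) + z_{\ell-1}, && \ell = 1, \dots, L \\
z_\ell &= \mathrm{MLP}(\mathrm{LN}(z'_\ell)) + z'_\ell, && \ell = 1, \dots, L \\
y &= \mathrm{LN}(z_L^0)
\end{aligned}
$$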

![Image](images/ViT(2020, Google Research)_174.png)

Key points:

  1. The input shape and output shape of the Transformer Encoder remain unchanged.
  2. The reference implementation applies one additional LayerNorm after the entire Transformer Encoder, i.e., after the last Encoder Block.
  3. LayerNorm: although BatchNorm is more common in CV, ViT uses LN as in NLP. BN normalizes each channel across all samples in a batch (inter-sample), while LN normalizes across all features of each individual sample, here each token (intra-sample). By the time LN is applied in ViT, the image has already been split into patches and each patch is handled like a "word", so the model follows the same per-token logic as NLP Transformers, and LN fits that logic.
  4. Multi-Head Attention: lets the network attend to several representation subspaces at once, extracting more accurate representations and richer features; this is analogous to using multiple kernels in a CNN to extract different features.
  5. What does attention learn in ViT? The paper measures the average distance over which each head attends. In shallow layers, some heads attend mostly to nearby patches while others already attend across much of the image; as the network deepens, the attention distance of all heads grows. This resembles the layer-by-layer expansion of the receptive field in CNNs.
  6. MLP block: two Linear (fully connected) layers with a GELU activation and Dropout, arranged as an inverted bottleneck. The first fully connected layer expands the feature dimension 4x, from 197×768 to 197×3072; the second restores it, from 197×3072 to 197×768.
    • Bottleneck structure (e.g., ResNet): first reduce, then restore the channel count. Reduce channels → convolution → restore channels. Purpose: reduce computation while preserving the network's expressive power.
    • Inverted Bottleneck structure (e.g., MobileNetV2): first expand, then reduce the channel count. Expand channels → nonlinear transformation (e.g., depthwise separable convolution or MLP) → restore the original channels. Purpose: perform complex feature transformations in a higher-dimensional space to enhance feature expressiveness.
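The difference between the two normalizations is only which axis the statistics are computed over. A small sketch on a ViT-shaped tensor (batch of 4, 197 tokens, 768 features); the input is deliberately shifted and scaled so the effect is visible.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
tokens = torch.randn(4, 197, 768) * 3 + 5     # (batch, tokens, features), deliberately not zero-mean

# LayerNorm: statistics over the 768 features of each token, independently per token and per sample
ln = nn.LayerNorm(768)
y_ln = ln(tokens)
print(y_ln.mean(dim=-1).abs().max())           # ~0 for every (sample, token)

# BatchNorm1d expects (batch, channels, length): statistics per feature, pooled over batch and tokens
bn = nn.BatchNorm1d(768)
y_bn = bn(tokens.transpose(1, 2)).transpose(1, 2)
print(y_bn.mean(dim=(0, 1)).abs().max())       # ~0 for every feature, but not for every token
```

Because LN never looks across samples, a token's normalized value does not depend on what else is in the batch, which is exactly the per-"word" behavior described above.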

![Image](images/ViT(2020, Google Research)_178.png)

Example: data flow in the Transformer Encoder. The first Encoder Block receives the 197×768 input matrix; every subsequent Encoder Block receives the 197×768 output of the previous block. Inside an Encoder Block, the input first passes through LayerNorm (shape unchanged) and then into Multi-Head Attention. There, a Linear layer maps the input to a 197×2304 (2304 = 768×3) qkv matrix, which is split into three 197×768 matrices q, k, and v. Each of these is reshaped into 12×197×64 to form 12 attention heads. Each head computes $\mathrm{softmax}(qk^\top/\sqrt{64})\,v$, the 12 head outputs are concatenated back to 197×768, and a Linear projection mixes them. The result is added to the block input (residual connection), followed by the second LayerNorm, the MLP, and another residual connection.
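The attention data flow above can be sketched as a small module. Its name and constructor arguments are chosen to match how the `EncoderBlock` code calls `MultiheadAttention(dim, num_heads, qkv_bias=..., attention_dropout_ratio=..., proj_drop=...)`; the internals are our assumption, modeled on common ViT implementations, and this is a different class from `torch.nn.MultiheadAttention`.

```python
import torch
import torch.nn as nn


class MultiheadAttention(nn.Module):
    """Minimal multi-head self-attention following the 197x768 data flow (sketch, not a reference impl)."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, attention_dropout_ratio=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads                    # 768 / 12 = 64
        self.scale = self.head_dim ** -0.5                  # 1 / sqrt(d_k)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)   # 768 -> 2304
        self.attn_drop = nn.Dropout(attention_dropout_ratio)
        self.proj = nn.Linear(dim, dim)                     # mix the concatenated heads
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape                                   # (B, 197, 768)
        qkv = self.qkv(x)                                   # (B, 197, 2304)
        qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]                    # each (B, 12, 197, 64)
        attn = (q @ k.transpose(-2, -1)) * self.scale       # (B, 12, 197, 197)
        attn = self.attn_drop(attn.softmax(dim=-1))
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)   # (B, 197, 768): heads concatenated
        return self.proj_drop(self.proj(out))


mha = MultiheadAttention(768, num_heads=12, qkv_bias=True)
print(mha(torch.randn(2, 197, 768)).shape)  # torch.Size([2, 197, 768])
```

Shuffling the input tokens only shuffles the outputs in the same way, which is why ViT needs the positional embedding described earlier.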

```python
import torch
import torch.nn as nn

# DropPath (Stochastic Depth) and MultiheadAttention are assumed to be defined elsewhere;
# a DropPath implementation is available in timm (e.g., `from timm.layers import DropPath`).


class MLP(nn.Module):
    """MLP block: Linear (expand) -> activation -> Dropout -> Linear (compress) -> Dropout."""

    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop_ratio=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Define the first fully connected layer (expand, e.g., 768 -> 3072)
        self.fc1 = nn.Linear(in_features, hidden_features)
        # Activation function (GELU by default)
        self.act = act_layer()
        # Define the second fully connected layer (compress back to original dimension)
        self.fc2 = nn.Linear(hidden_features, out_features)
        # Dropout layer
        self.drop = nn.Dropout(drop_ratio)

    def forward(self, x):
        # Forward propagation
        x = self.fc1(x)      # Expand dimension
        x = self.act(x)      # Activation function
        x = self.drop(x)     # Dropout
        x = self.fc2(x)      # Compress back to original dimension
        x = self.drop(x)     # Dropout
        return x

class EncoderBlock(nn.Module):
    """
    Transformer Encoder Block
    A single-layer Block consists of LayerNorm, Attention, and MLP.
    In the MLP block, the MLP layer's hidden dimension is 4x the input dimension.
    """
    mlp_ratio = 4  # MLP hidden layer dimension expansion ratio (standard ViT design)

    def __init__(
            self,
            dim,                         # Dimension of input tokens
            num_heads,                   # Number of attention heads
            qkv_bias=False,              # Whether to use bias in qkv
            drop_ratio=0.,               # Dropout ratio after projection
            attention_dropout_ratio=0.,  # Dropout ratio for attention weights
            drop_path_ratio=0.,          # DropPath ratio (for Stochastic Depth)
            norm_layer=nn.LayerNorm,     # Normalization layer type (default LayerNorm)
            act_layer=nn.GELU            # Activation function type (default GELU)
    ):
        super(EncoderBlock, self).__init__()
        # First normalization layer (before attention)
        self.norm1 = norm_layer(dim)

        # Multi-head attention module
        self.attention = MultiheadAttention(
            dim, num_heads,
            qkv_bias=qkv_bias,
            attention_dropout_ratio=attention_dropout_ratio,
            proj_drop=drop_ratio
        )

        # Stochastic Depth module (identity mapping if ratio is 0)
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()

        # Second normalization layer (before MLP)
        self.norm2 = norm_layer(dim)

        # Compute MLP hidden layer dimension (4x input dimension)
        mlp_hidden_dim = int(dim * self.mlp_ratio)

        # MLP module
        self.mlp = MLP(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            drop_ratio=drop_ratio,
            act_layer=act_layer
        )

    def forward(self, x):
        # Residual connection 1: Multi-head attention part
        # Normalize, then attend, then DropPath, then add residual
        x = x + self.drop_path(self.attention(self.norm1(x)))

        # Residual connection 2: MLP part
        # Normalize, then MLP, then DropPath, then add residual
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class TransformerEncoder(nn.Module):
    """
    Stack L Transformer Encoder Blocks
    to form the complete Transformer Encoder
    """
    # Minimal completion (assumed): the original listing is truncated at this point.
    def __init__(self, dim, depth, num_heads, norm_layer=nn.LayerNorm, **block_kwargs):
        super().__init__()
        self.blocks = nn.Sequential(*[EncoderBlock(dim, num_heads, **block_kwargs) for _ in range(depth)])
        self.norm = norm_layer(dim)  # final LayerNorm after the last Encoder Block

    def forward(self, x):
        return self.norm(self.blocks(x))  # shape unchanged, e.g., (B, 197, 768)
```

2.5 MLP Head

Purpose: map the final [cls] representation to class predictions. The MLP Head typically includes the following layers:

  • Fully connected layer (Linear): transforms the 768-dimensional [cls] feature vector into a hidden representation.
  • Activation function: introduces nonlinearity (tanh in the original pre-training head; ReLU or GELU in some reimplementations).
  • Output layer: another fully connected layer that outputs one classification logit (unnormalized score) per class, e.g., 1000 for ImageNet-1K.
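The two head variants used in the original paper (an MLP with one tanh hidden layer during pre-training, a single Linear layer when fine-tuning) can be sketched as follows. The hidden width of the pre-training head is set to D here purely as an assumption for illustration.

```python
import torch
import torch.nn as nn

D, num_classes = 768, 1000

# Pre-training head: one hidden layer with tanh, then the class logits
pretrain_head = nn.Sequential(
    nn.Linear(D, D),              # hidden width = D (assumed for illustration)
    nn.Tanh(),
    nn.Linear(D, num_classes),
)

# Fine-tuning head: a single Linear layer (D x K)
finetune_head = nn.Linear(D, num_classes)

encoder_out = torch.randn(2, 197, D)    # stand-in for the Transformer Encoder output
cls_feature = encoder_out[:, 0]         # (2, 768): only the [cls] position is used
print(pretrain_head(cls_feature).shape, finetune_head(cls_feature).shape)
```

Both heads output raw logits; softmax is applied inside the loss (e.g., `nn.CrossEntropyLoss`) rather than in the head.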

![Image](images/ViT(2020, Google Research)_193.png)

Example: from the 197×768 Transformer Encoder output, extract the 1×768 [cls] vector and feed it to the MLP head. In the original paper, the head used for pre-training (e.g., on ImageNet-21K) is Linear + tanh + Linear; when fine-tuning on ImageNet-1K or custom data, a single Linear layer is used. ViT main class code:

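```python
import torch
import torch.nn as nn

# Assumes PatchEmbedding and EncoderBlock from the listings above.
# The class header and __init__ are a reconstruction (assumed); the original listing
# starts part-way through the weight-initialization function.


class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, drop_ratio=0.):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size=img_size, patch_size=patch_size,
                                          in_channels=in_channels, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(drop_ratio)
        self.blocks = nn.Sequential(*[EncoderBlock(embed_dim, num_heads, qkv_bias=True, drop_ratio=drop_ratio)
                                      for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)          # final LayerNorm after the encoder
        self.head = nn.Linear(embed_dim, num_classes)
        nn.init.trunc_normal_(self.pos_embed, std=.02)
        nn.init.trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # Linear layer: truncated normal initialization
            nn.init.trunc_normal_(m.weight, std=.01)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            # Convolutional layer: Kaiming normal initialization (for patch embedding)
            nn.init.kaiming_normal_(m.weight, mode="fan_out")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            # LayerNorm layer: weights initialized to 1, biases to 0
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)

    # Forward propagation
    def forward(self, x):
        # Input x shape: [B, C, H, W] (e.g., [B,3,224,224])

        # Step 1: Convert image to patch embeddings
        x = self.patch_embed(x)  # Output shape: [B, 196, 768] (assuming hidden_dim=768)

        # Step 2: Expand CLS token to current batch size
        # Original shape: [1,1,768] → expanded: [B,1,768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)

        # Step 3: Prepend CLS token to patch sequence
        # Shape after concatenation: [B, 196+1, 768] → [B,197,768]
        x = torch.cat((cls_token, x), dim=1)

        # Step 4: Add positional encoding and apply Dropout
        # pos_embed shape automatically broadcasts to [B,197,768]
        x = self.pos_drop(x + self.pos_embed)

        # Step 5: Pass through stacked Transformer encoder
        # Output shape remains: [B,197,768]
        x = self.blocks(x)

        # Step 6: Final LayerNorm after the last Encoder Block
        x = self.norm(x)

        # Step 7: Extract only the CLS token features (index 0)
        # Output shape: [B,768]
        x = x[:, 0]  # Equivalent to x[:, 0, :]

        # Step 8: Pass through classification head to get final logits
        # Output shape: [B, num_classes] (e.g., [B,1000])
        x = self.head(x)

        return x
```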

2.6 ViT Workflow

Official animated workflow:

![Image](images/ViT(2020, Google Research)_203.png)

Consolidated ViT diagram from Sections 2.1 to 2.5:

![Image](images/ViT(2020, Google Research)_205.png)

Detailed flowchart:

![Image](images/ViT(2020, Google Research)_207.png)

Data flow in Transformer Encoder: Aman Arora's blog has a clear data flow diagram, as shown below:

![Image](images/ViT(2020, Google Research)_210.png)

For an animated flow diagram, see A Visual Guide to Vision Transformers: https://blog.mdturp.ch/posts/2024-04-05-visual_guide_to_vision_transformer.html

3. ViT Training and Fine-Tuning

Training method: ViT typically uses large-scale pre-training followed by transfer learning (fine-tuning), which exploits what the model learns from very large datasets while adapting it to the data of each task. Datasets used for pre-training:

  • ILSVRC-2012 ImageNet: 1k classes, 1.3M images
  • ImageNet-21k: 21k classes, 14M images
  • JFT-300M: 18k classes, 303M high-resolution images

Downstream datasets the pre-trained models are transferred to:

  • CIFAR-10/100
  • Oxford-IIIT Pets
  • Oxford Flowers-102
  • VTAB

The original paper defines three ViT model sizes:

| Model | Layers | Hidden Size | MLP Size | Heads | Params |
|---|---|---|---|---|---|
| ViT-Base | 12 | 768 | 3072 | 12 | 86M |
| ViT-Large | 24 | 1024 | 4096 | 16 | 307M |
| ViT-Huge | 32 | 1280 | 5120 | 16 | 632M |

  • Layers: number of Encoder Blocks stacked in the Transformer Encoder
  • Hidden Size: the dimension D (vector length) of each token after the embedding layer
  • MLP Size: the number of units in the first fully connected layer of the MLP block (4x Hidden Size)
  • Heads: number of attention heads in Multi-Head Attention

Model names combine size and patch size: ViT-L/16 denotes ViT-Large with 16×16 input patches.
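The parameter counts in the table can be approximately reproduced from the other columns. This sketch assumes a 224×224 input, 16×16 patches, biases on the qkv and MLP layers, and a single 1000-class Linear head; small differences from the table are expected because the total depends on the head and input resolution assumed.

```python
def vit_param_count(layers, hidden, mlp, img_size=224, patch=16, channels=3, num_classes=1000):
    """Count ViT parameters (weights + biases) under the assumptions stated above."""
    n_patches = (img_size // patch) ** 2
    patch_embed = (channels * patch * patch) * hidden + hidden    # patch projection
    cls_and_pos = hidden + (n_patches + 1) * hidden               # [cls] token + positional embedding
    per_block = (
        2 * (2 * hidden)                      # two LayerNorms (weight + bias)
        + hidden * 3 * hidden + 3 * hidden    # qkv projection
        + hidden * hidden + hidden            # attention output projection
        + hidden * mlp + mlp                  # MLP fc1 (expand)
        + mlp * hidden + hidden               # MLP fc2 (compress)
    )
    final_norm = 2 * hidden
    head = hidden * num_classes + num_classes
    return patch_embed + cls_and_pos + layers * per_block + final_norm + head


base = vit_param_count(layers=12, hidden=768, mlp=3072)
large = vit_param_count(layers=24, hidden=1024, mlp=4096)
print(f"ViT-B/16: {base:,}  ViT-L/16: {large:,}")  # ViT-B/16: 86,567,656  ViT-L/16: 304,326,632
```

Almost all parameters sit in the Encoder Blocks, and within each block the MLP holds about two thirds of them.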

3.1 Pre-training Phase

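  1. Objective: pre-train ViT on very large datasets (e.g., JFT-300M, ImageNet-21K) to learn general visual features.
    • Train ViT directly with a supervised classification objective.
      • A prediction head maps the Transformer's D-dimensional [cls] output to K class logits (the paper uses an MLP with one hidden layer during pre-training).
      • D is the output dimension of the Transformer.
      • K is the number of classes in the dataset being trained on.
    • The paper uses Adam (β1 = 0.9, β2 = 0.999) with a high weight decay of 0.1 to improve training stability and transfer; many reimplementations use AdamW.
    • A linear learning-rate warm-up followed by learning-rate decay improves training effectiveness.
    • Stochastic Depth (randomly dropping residual branches, the DropPath in the encoder code above) is a common regularizer in ViT training recipes; it shortens the effective depth during training, which eases gradient flow and reduces overfitting.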
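A linear warm-up followed by linear decay can be written as a plain function of the step. This is a framework-free sketch; `base_lr`, `warmup_steps`, and `total_steps` are placeholders rather than the paper's actual settings. The same function could be wrapped in `torch.optim.lr_scheduler.LambdaLR` by dividing out `base_lr`.

```python
def warmup_linear_decay(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Linear warm-up from 0 to base_lr, then linear decay back to 0 at total_steps."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    remaining = max(total_steps - step, 0)
    return base_lr * remaining / (total_steps - warmup_steps)


schedule = [warmup_linear_decay(s, 1e-3, warmup_steps=1_000, total_steps=10_000)
            for s in (0, 500, 1_000, 5_500, 10_000)]
print(schedule)
```

The warm-up keeps early updates small while Adam's moment estimates are still noisy; the decay then anneals the learning rate toward zero by the end of training.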

3.2 Transfer Learning / Fine-Tuning Phase

A pre-trained ViT is a powerful feature extractor whose output features can serve many downstream tasks, such as more complex classification, object detection, etc. For these tasks, we feed new data to the pre-trained model while keeping its main architecture as unchanged as possible. For example, we can freeze all ViT parameters, add a new head on top of the output layer, and update only the new head's parameters during training. This makes full use of ViT's pre-trained feature extraction while adapting it to different task requirements.
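The freeze-and-add-a-head setup can be sketched as follows. The `backbone` here is a hypothetical stand-in for a pre-trained ViT that returns the D-dimensional [cls] feature, and K = 37 (the number of Oxford-IIIT Pets breeds) is just an example. To fine-tune the whole network instead, as the paper does, skip the `requires_grad_(False)` call.

```python
import torch
import torch.nn as nn

D, K = 768, 37                      # feature dim and number of target classes (e.g., 37 pet breeds)

# Hypothetical stand-in for a pre-trained ViT that returns the D-dim [cls] feature
backbone = nn.Sequential(nn.Linear(D, D), nn.GELU(), nn.LayerNorm(D))
backbone.requires_grad_(False)      # freeze all pre-trained parameters

head = nn.Linear(D, K)              # new D x K head for the downstream task
nn.init.zeros_(head.weight)         # zero-initialized, as described in the ViT paper
nn.init.zeros_(head.bias)

# Only parameters that still require gradients (the head) go to the optimizer
trainable = [p for p in list(backbone.parameters()) + list(head.parameters()) if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.01, momentum=0.9)

logits = head(backbone(torch.randn(4, D)))
print(logits.shape, sum(p.numel() for p in trainable))   # torch.Size([4, 37]) 28453
```

With a zero-initialized head, every class starts with the same logit, so training begins from a neutral prediction instead of a random one.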

  1. Objective: transfer the pre-trained ViT to smaller downstream datasets (e.g., ImageNet-1K, CIFAR-100) to adapt it to specific tasks.
  2. Key approaches:
    • Remove the pre-trained prediction head and attach a zero-initialized D×K feed-forward layer, where K is the number of classes in the new dataset.
    • Fine-tuning at a higher resolution than pre-training is often beneficial. The patch size stays unchanged, so a larger input image yields more patches ($N = HW/P^2$ increases) and a longer sequence.
    • ViT's learned positional embeddings are tied to the patch grid of the fixed-size pre-training inputs.
    • When the input size changes, the positional embeddings cannot be reused directly; they are resized to the new patch grid by 2D interpolation (e.g., bilinear) to keep spatial information consistent.
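Resizing the learned positional embeddings can be sketched with `torch.nn.functional.interpolate`. In this sketch (the function name is ours), the [cls] entry, which has no spatial position, is kept unchanged and only the patch grid is reshaped to 2D, interpolated bilinearly, and flattened back; some implementations use bicubic interpolation instead.

```python
import torch
import torch.nn.functional as F


def resize_pos_embed(pos_embed: torch.Tensor, new_grid: int) -> torch.Tensor:
    """Interpolate a (1, 1 + g*g, D) learned positional embedding to (1, 1 + new_grid**2, D)."""
    cls_pos, grid_pos = pos_embed[:, :1], pos_embed[:, 1:]       # keep [cls] separate
    old_grid = int(grid_pos.shape[1] ** 0.5)
    D = grid_pos.shape[-1]
    grid_pos = grid_pos.reshape(1, old_grid, old_grid, D).permute(0, 3, 1, 2)   # (1, D, g, g)
    grid_pos = F.interpolate(grid_pos, size=(new_grid, new_grid), mode="bilinear", align_corners=False)
    grid_pos = grid_pos.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, D)
    return torch.cat([cls_pos, grid_pos], dim=1)


pos_embed = torch.randn(1, 1 + 14 * 14, 768)      # pre-trained at 224x224 with P=16
new_pos = resize_pos_embed(pos_embed, 384 // 16)   # fine-tune at 384x384 -> 24x24 grid
print(new_pos.shape)  # torch.Size([1, 577, 768])
```

Because the grid is interpolated in 2D rather than as a flat 1D sequence, neighboring patches keep similar embeddings after resizing.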

4. Development of ViT

ViT demonstrated the emergent capabilities of large-scale data + large models in the image domain, and also laid the groundwork for the subsequent development of multimodal tasks. Outstanding applications of ViT: Image classification, image caption generation, image segmentation, anomaly detection, action recognition, autonomous driving, etc.

![Image](images/ViT(2020, Google Research)_235.png)
