
Vision Transformer (ViT) for Image Classification

May 25th, 2025

In 2019, during my bachelor's degree, I started with a simple project of recognizing handwritten digits using a neural network. This was a project on Kaggle (Digit Recognizer) and I struggled to understand the basics of neural networks like forward and backward propagation. I took a course at my university which introduced me to the basics of neural networks, and another one by DeepLearning.AI. I built a simple neural network with a single hidden layer using just the numpy package. This was my first experience with implementing a neural network from scratch, and I was able to achieve a decent accuracy of around 95% on the MNIST dataset.

After that, I started exploring different approaches to image classification, particularly convolutional neural networks (CNNs). I learned about CNN architecture, how they work, and how to implement them using Keras. When I submitted my results to Kaggle for benchmarking, I achieved around 98% accuracy on the MNIST dataset.

I currently work as an AI Engineer, where I focus on developing agents and tools based on large language models (LLMs). However, I still have a keen interest in computer vision and image classification. I decided to revisit the handwritten digit recognition project, but this time with a more advanced approach: the Vision Transformer (ViT). In this post, I compare the ViT with my previous CNN-based approach, using the Kaggle Digit Recognizer competition as a common testbed so the two can be benchmarked against each other.

Foundations: Convolutional Neural Networks (CNNs)

Convolutional Neural Networks (CNNs) have been the go-to solution for computer vision tasks for years. Their design, inspired by how the human visual system works, allows them to learn patterns and features in images step by step, making them highly effective for image-related tasks such as image classification, object detection, and medical image analysis.

CNN Fundamentals

At their core, CNNs leverage three key principles:

  • Local Connectivity: CNNs exploit the spatial structure of images by enforcing local connectivity patterns. This means that each neuron in a convolutional layer is only connected to a small region of the input image, allowing the network to learn local features such as edges and textures.
  • Parameter Sharing: CNNs use the same set of weights (filters or kernels) across different spatial locations in the input image. This weight sharing significantly reduces the number of parameters in the model, making it more efficient and easier to train (the sketch after this list makes the savings concrete).
  • Translation Invariance: By combining convolutional and pooling layers, CNNs achieve translation invariance, meaning that the model can recognize objects in images regardless of their position. This is particularly useful for tasks like object detection and image classification.
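To make the parameter-sharing point concrete, here is a minimal Keras sketch (layer sizes are illustrative, not taken from my original project) comparing the parameter count of a small convolutional layer with that of a dense layer over the same 28x28 input:

python
from keras import models, layers

# A 3x3 convolution over a 28x28x1 input: each of the 32 filters reuses the same
# 3*3*1 weights at every spatial location, so the layer has only
# 32 * (3*3*1) + 32 = 320 parameters.
conv = models.Sequential([layers.Conv2D(32, (3, 3), input_shape=(28, 28, 1))])

# A dense layer mapping the same flattened input to 32 units needs a separate
# weight for every pixel-unit pair: 28*28*1 * 32 + 32 = 25,120 parameters.
dense = models.Sequential([layers.Flatten(input_shape=(28, 28, 1)), layers.Dense(32)])

conv.summary()   # ~320 trainable parameters
dense.summary()  # ~25,120 trainable parameters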

CNN Architecture Components

A typical CNN architecture consists of several key components:

  • Convolutional Layers: The core building block of CNNs, where filters (small matrices of weights) slide over the input image to extract features. Each filter learns to detect specific patterns, such as edges, corners, or textures, by performing convolution operations with the input image. Multiple filters are applied simultaneously to capture different types of features, and each filter produces a feature map that highlights where specific patterns are found in the image. Below is an illustration of a convolutional layer with a 3x3 filter: Convolutional Layer Visualization (source: medium.com)
  • Pooling Layers: These layers reduce the spatial dimensions of the feature maps while retaining the most important information. Pooling helps downsample the feature maps, reducing computational complexity and making the model more robust to small translations and distortions in the input image. There are different types of pooling, such as max pooling (which selects the maximum value from a small region) and average pooling (which calculates the average value from a small region of the feature map). Below is an illustration of a max pooling layer with a 2x2 filter: Pooling Layer Visualization (source: medium.com)
  • Fully Connected Layers: After several convolutional and pooling layers, the high-level features are flattened and passed through one or more fully connected layers. These layers are similar to traditional neural networks and are responsible for making the final predictions based on the learned features. The output of the last fully connected layer is typically passed through a softmax activation function to obtain probabilities for each class in the classification task.

My CNN Implementation

Here's a simple implementation of a CNN using Keras for the MNIST digit classification task:

python
from keras import models, layers

model = models.Sequential()
# Feature extraction: convolution + pooling, with dropout for regularization
model.add(layers.Conv2D(32, (3, 3), input_shape=(28, 28, 1), activation='relu'))
model.add(layers.MaxPooling2D(pool_size=(2, 2)))
model.add(layers.Dropout(0.2))
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
# Classification: flatten the feature maps and pass them through dense layers
model.add(layers.Flatten())
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dropout(0.2))
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dropout(0.2))
model.add(layers.Dense(10, activation='softmax'))

model.summary()

Do I need this many layers? Not really. The MNIST dataset is relatively simple, and a smaller architecture would suffice. Back then, I was just experimenting with different architectures to see how they performed depending on the number of layers and neurons.

History of CNNs: AlexNet and Beyond

The era of Deep Learning in computer vision began with the introduction of AlexNet in 2012, which won the ImageNet competition with a top-5 error rate of 15.3%. This eight-layer network, with its five convolutional layers and three fully connected layers, introduced key innovations like the ReLU activation function and dropout regularization, which remain standard practice today. You can visualize architectures like these using tools such as https://alexlenail.me/NN-SVG/.

Following AlexNet's success, architectures like VGG, ResNet, and Inception pushed the boundaries of what was possible, introducing deeper networks, skip connections to combat vanishing gradients, and multi-scale feature extraction.

Enter the Transformer: From NLP to Vision

While CNNs dominated computer vision for years, the introduction of the Transformer architecture in 2017 revolutionized natural language processing (NLP) by enabling models to learn long-range dependencies and context much more effectively. The key innovation of the Transformer is the self-attention mechanism, which allows the model to weigh the importance of different parts of the input sequence when making predictions. This mechanism enables the model to capture global relationships in the data, making it particularly effective for tasks like language translation and text generation.

Core Transformer Components

The Transformer architecture consists of several key components that work together to process input sequences:

  • Self-Attention: This mechanism allows the model to weigh the importance of different words in a sequence, capturing long-range dependencies and contextual relationships. It computes attention scores using Query (Q), Key (K), and Value (V) matrices; a minimal code sketch follows this list. Below is an illustration of Scaled Dot-Product Attention (from the original paper): Scaled Dot-Product Attention Visualization (source: arxiv.org)
  • Multi-Head Attention: This component extends self-attention by allowing the model to focus on different parts of the input sequence simultaneously. It splits the input into multiple heads, each with its own set of Q, K, and V matrices, and computes attention scores independently. The outputs of all heads are then concatenated and linearly transformed to produce the final output. This allows the model to capture diverse relationships in the data. Below is an illustration of Multi-Head Attention (from the original paper): Multi-Head Attention Visualization (source: arxiv.org)
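To ground the Q/K/V description above, here is a minimal PyTorch sketch of scaled dot-product attention (the tensor shapes in the toy example are arbitrary assumptions):

python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    # Q, K, V: (batch, seq_len, d_k)
    d_k = Q.size(-1)
    # Similarity of every query with every key, scaled by sqrt(d_k)
    scores = Q @ K.transpose(-2, -1) / d_k**0.5   # (batch, seq_len, seq_len)
    weights = F.softmax(scores, dim=-1)           # each row sums to 1
    return weights @ V                            # weighted sum of the value vectors

# Toy example: a batch of 2 sequences, 5 tokens each, embedding size 64
Q = K = V = torch.randn(2, 5, 64)
print(scaled_dot_product_attention(Q, K, V).shape)  # torch.Size([2, 5, 64])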

Vision Transformers: Architecture and Implementation

The success of Transformers in NLP (like BERT and GPT) led to their adaptation for computer vision tasks, resulting in the paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale, commonly known as the Vision Transformer (ViT). The key difference between CNNs and ViTs lies in how they process images: while CNNs use convolutional layers to extract local features, ViTs treat images as sequences of patches, similar to how Transformers process text.

ViT Architecture Overview

At a high level, the ViT architecture leverages the power of self-attention to capture global relationships in images. The key steps in the ViT architecture are as follows:

  • Patch Embedding Layer: The input image is treated like a sentence. It is divided into fixed-size patches (e.g., 16x16 pixels), which are analogous to words (or tokens). These patches are flattened and linearly projected into an embedding space. A special [CLS] token is prepended to the sequence of patch embeddings to represent the entire image.
  • Transformer Encoder: This is the core of the ViT architecture. The sequence of patch embeddings, now combined with position embeddings to retain spatial information, is fed into a series of standard Transformer encoder blocks. Each block applies multi-head self-attention and a feed-forward network, allowing every patch to interact with and learn from every other patch. Below is an illustration of how Embedded Patches are processed through the Transformer Encoder: Transformer Encoder Visualization (source: arxiv.org)
  • Classification Head: After the patch embeddings have been processed by the Transformer encoder, the output corresponding to the [CLS] token, which acts as a holistic representation of the image, is passed through a classification head, typically a fully connected layer followed by a softmax activation function, to produce the class probabilities. A rough shape walk-through of these three stages follows this list.
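As a mental model, here is how the tensor shapes evolve through these stages for ViT-B/16 on a 224x224 RGB input (the sizes are the standard ViT-B/16 hyperparameters; treat this as a rough trace rather than exact library internals):

python
# ViT-B/16 on a 224x224 RGB image (shapes shown for intuition only)
# image                        -> (3, 224, 224)
# split into 16x16 patches     -> 196 patches, each flattened to 16*16*3 = 768 values
# linear projection            -> (196, 768)   # 768 is the ViT-B/16 embedding dimension
# prepend [CLS] token          -> (197, 768)
# add position embeddings      -> (197, 768)
# 12 Transformer encoder blocks, shape preserved -> (197, 768)
# take the [CLS] output        -> (768,) -> classification head -> (num_classes,)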

Patch Embedding: Turning Images into Sequences

Let's look into how ViT translates a 2D image into a 1D sequence that a Transformer can understand.

  • Image Partitioning: An image of size H * W is divided into N non-overlapping patches of size P * P. The number of patches is calculated as N = (H * W) / (P * P). For example, if we have a 224x224 image and we partition it into 16x16 patches, we will have N = (224 * 224) / (16 * 16) = 196 patches.
  • Linear Projection: Each patch is flattened into a long vector of size P * P * C (where C is the number of channels, e.g., 3 for RGB images). This vector is then mapped to the model's desired embedding dimension (e.g., 768 for ViT-B/16) using a trainable linear projection layer. This step is very important because it ensures that the patch embeddings are in a suitable space for the subsequent Transformer layers.
  • The [CLS] Token: Inspired by BERT's use of a special token for sentence classification, ViT adds a learnable [CLS] token to the beginning of the patch sequence. The self-attention mechanism ensures that this token aggregates information from all other patches. The final output of this token is then used for the classification task, acting as a holistic representation of the image (see the sketch below).
Vision Transformer Architecture (source: arxiv.org)
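Putting these steps together, here is a minimal patch-embedding sketch in PyTorch. The hyperparameters (img_size=224, patch_size=16, embed_dim=768) are illustrative, and the strided Conv2d used to split and project the patches in one step is a common implementation shortcut, not necessarily how any particular library does it:

python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # kernel_size == stride == patch_size: splits the image into non-overlapping
        # patches and applies the trainable linear projection in a single operation
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, x):                                 # x: (batch, 3, 224, 224)
        x = self.proj(x)                                  # (batch, 768, 14, 14)
        x = x.flatten(2).transpose(1, 2)                  # (batch, 196, 768)
        cls = self.cls_token.expand(x.size(0), -1, -1)    # (batch, 1, 768)
        return torch.cat([cls, x], dim=1)                 # (batch, 197, 768)

print(PatchEmbedding()(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 197, 768])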

Position Embeddings

A standard Transformer has no built-in notion of the order of its input sequence: without positional information, self-attention treats the input as an unordered set, so shuffling the tokens simply shuffles the corresponding outputs. For images this is a problem, because the spatial arrangement of patches is crucial for understanding the image.

To solve this, ViT introduces position embeddings. These are learnable parameters that are added to the patch embeddings to provide spatial information. By learning these positions during training, the model can understand the relative spatial location of each patch. Without them, the ViT would see the image as a mere "bag of patches," losing all structural information.
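Concretely, a learnable position embedding is just another parameter tensor added element-wise to the token sequence; a minimal sketch, assuming the same ViT-B/16 sizes as above:

python
import torch
import torch.nn as nn

embed_dim, num_patches = 768, 196
# One learnable vector per position: 196 patches plus the [CLS] token
pos_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

tokens = torch.randn(8, num_patches + 1, embed_dim)  # (batch, 197, 768) from the patch embedding
tokens = tokens + pos_embedding                      # broadcast over the batch dimension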

ViT itself uses this simple learned position embedding, but more advanced techniques have been developed to improve spatial understanding. For example, the Swin Transformer introduces a hierarchical structure that allows the model to capture multi-scale features, similar to how CNNs work; it computes self-attention within shifted windows, which helps the model learn both local and global features effectively. Swin Transformer Architecture (source: arxiv.org)

Self-Attention for Vision

Now that we have the patch embeddings and position embeddings, we can apply the self-attention mechanism to them. Self-attention allows the model to weigh the importance of different patches when making predictions: for any given patch, it computes a weighted sum over the representations of all patches (a context vector), which is then used to update the representation of that patch, allowing the model to incorporate information from its surroundings.

For example, if a patch contains a part of a digit, the self-attention mechanism can help the model understand how that patch relates to other patches in the image, such as the other parts of the digit or the background. This allows the model to capture global relationships in the image, which is crucial for accurate classification. Multi-Head Attention further enhances this by allowing the model to learn different types of relationships simultaneously in parallel heads.
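As a sketch of what this looks like in code, PyTorch's built-in nn.MultiheadAttention can be applied directly to an embedded patch sequence (the batch size is arbitrary; 768 and 12 are the ViT-B/16 embedding dimension and head count):

python
import torch
import torch.nn as nn

embed_dim, num_heads = 768, 12
attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

tokens = torch.randn(8, 197, embed_dim)  # (batch, 1 [CLS] token + 196 patches, embed_dim)
# Self-attention: queries, keys, and values all come from the same sequence,
# so every patch (and the [CLS] token) can attend to every other patch
out, weights = attn(tokens, tokens, tokens)
print(out.shape, weights.shape)  # torch.Size([8, 197, 768]) torch.Size([8, 197, 197])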

The main difference between CNNs and ViTs is that CNNs learn local features through convolutional filters, while ViTs learn global relationships through self-attention. This allows ViTs to capture long-range dependencies and contextual information in the image, which is particularly useful for complex tasks like object detection and segmentation.

My ViT Implementation for Digit Recognition

Here's a simple implementation of a Vision Transformer using PyTorch for the MNIST digit classification task:

python
import torch
import torch.nn as nn
from torchvision.models import vit_b_16, ViT_B_16_Weights

class SimpleViT(nn.Module):
    def __init__(self, num_classes=10, input_size=224, patch_size=16):
        super(SimpleViT, self).__init__()

        # Start from a ViT-B/16 backbone pretrained on ImageNet
        self.vit = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)

        # Modify the classifier head for 10 classes (digits 0-9)
        self.vit.heads = nn.Sequential(
            nn.Linear(self.vit.heads[0].in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

        # Freeze early layers to prevent overfitting on small dataset
        for param in self.vit.encoder.layers[:6].parameters():
            param.requires_grad = False

    def forward(self, x):
        return self.vit(x)
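Since vit_b_16 is pretrained on 224x224 RGB images while MNIST digits are 28x28 grayscale, the inputs have to be resized and expanded to three channels before the forward pass. Below is a minimal sketch of that preprocessing (the exact transforms and normalization used in training may differ):

python
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

# Upscale to 224x224, repeat the single gray channel three times, and apply ImageNet statistics
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

digit = Image.fromarray(np.random.randint(0, 255, (28, 28), dtype=np.uint8))  # stand-in for an MNIST digit
x = preprocess(digit).unsqueeze(0)  # (1, 3, 224, 224)

model = SimpleViT(num_classes=10)   # defined above; downloads the pretrained weights on first use
model.eval()
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 10])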

Conclusion

To wrap up, Vision Transformers (ViTs) bring a new paradigm to image classification by leveraging the power of self-attention and the Transformer architecture. They treat images as sequences of patches, allowing the model to capture global relationships and contextual information effectively. This is a significant shift from the traditional CNN approach, which relies on local feature extraction through convolutional layers.

Models like Google's Gemini and OpenAI's GPT-4 Vision are not just vision models or language models; they are multimodal models that can process both text and images. They represent the convergence of visual and textual understanding, enabling more comprehensive AI systems that can interpret and generate content across different modalities.

Let's compare these multimodal models with the classification models discussed above:

  • The Goal: While a CNN or a ViT might classify an image (e.g., "this is a cat"), a multimodal model can understand the context of that image in relation to text (e.g., "this cat is sitting on a couch"). The task shifts from classification to understanding and generating content across modalities.
  • The Architecture: At their core, these multimodal models still use the Transformer architecture, but they are designed to handle both text and image inputs. They incorporate mechanisms to process and integrate information from both modalities, allowing them to learn richer representations.

In summary, Vision Transformers represent a significant advancement in computer vision, enabling models to learn global relationships and contextual information effectively. They are a step towards more comprehensive AI systems that can understand and generate content across different modalities, paving the way for more advanced applications in the future.

References