PyTorch for image classification - Part 1

Loading the data and Model Selection

1- Introduction

This tutorial shows how to classify images using a pretrained Residual Neural Network (ResNet). Image classification is a supervised learning problem: the objective is to train a model that learns the relationship between input features and their corresponding labels. The output of a classification model is a discrete label or category, indicating the class to which the input belongs.

We will demonstrate the following concepts:

  • Efficiently loading a dataset off the disk.

  • Pulling a pre-trained model and fine-tuning it.

  • Mitigating overfitting with data augmentation and dropout.

And we will follow a standard machine-learning workflow:

  • Examine and understand the data

  • Build an input pipeline

  • Build the model

  • Train the model

  • Test the model

  • Improve the model and repeat the process

In addition, we will demonstrate how to save models locally, and cover the following along the way:

  • Creating a custom dataset

  • Building a DNN model in PyTorch

  • Training the MLP model classifier in PyTorch

  • Plotting loss and accuracy curves on training and test data

  • Evaluating model performance with a confusion matrix

2- Setup

Import PyTorch and other necessary libraries:

import numpy as np
import pandas as pd 
import random

from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm_notebook as tqdm

import os
import time
from copy import deepcopy
from PIL import Image
from PIL import ImageOps
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
from torchvision.models import resnet50, resnet101
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import SubsetRandomSampler
import torch.nn.functional as F

from torchsummary import summary

import seaborn as sn

from utils import flatten_list, read_files_in_path

It is important to specify where the tensors and models are stored and processed: on a CPU or on a GPU. The following code selects the default GPU if one is available and falls back to the CPU otherwise.

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Using {device} for inference')
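Selecting a device does not move anything by itself: both the model and every batch have to be sent to it explicitly. A minimal sketch of that step, using a toy linear layer and a random batch (both hypothetical stand-ins, not part of this tutorial's pipeline):

```python
import torch

# Same device selection as in the tutorial.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# A toy layer and a toy batch (hypothetical), both moved to the selected device.
layer = torch.nn.Linear(4, 2).to(device)
batch = torch.randn(8, 4).to(device)

# The forward pass works because parameters and inputs live on the same device.
output = layer(batch)
print(output.shape, output.device.type)
```

If the model and the batch end up on different devices, PyTorch raises a runtime error, so it is usually worth moving both in the same place in the training loop.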

3- Train and Test data

The dataset is divided into a training set and a validation set, and can be any collection of images of your choice. We assume that the training set, located at BASE_PATH_TRAIN, is what the model learns from, while the validation set at BASE_PATH_VALID is used to check the model's accuracy on images it has not seen during training. Each path is assumed to contain one subfolder per class.

BASE_PATH_TRAIN = "Images/Train"
BASE_PATH_VALID = "Images/Valid"

Let's look into the dataset and extract the number of classes in our training set, excluding the generic "Others" category.

classes = os.listdir(BASE_PATH_TRAIN)
classes = [x for x in classes if "Others" not in x]
print(f'Number of classes: {len(classes)}: {classes}')

Next, we index the full path of each image together with its category, and shuffle the rows of the training set. Shuffling removes any ordering by class, so that each batch contains a mix of categories, which helps the model generalize. Note that this shuffles the data only once; to reshuffle at every epoch, pass shuffle=True to the DataLoader.

training_data = read_files_in_path(path = BASE_PATH_TRAIN)
training_data = training_data.sample(frac=1).reset_index(drop=True)

validation_data = read_files_in_path(path = BASE_PATH_VALID)
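The helper read_files_in_path comes from a local utils module that is not shown in this post. The sketch below is only a guess at what such a helper could look like, assuming one subfolder per class and a result with an image-path column followed by a labels column; the function name read_files_in_path_sketch and the column name image_path are hypothetical.

```python
import os
import tempfile

import pandas as pd


def read_files_in_path_sketch(path):
    """Hypothetical stand-in: one row per file, labelled by its parent folder."""
    rows = []
    for label in sorted(os.listdir(path)):
        class_dir = os.path.join(path, label)
        if not os.path.isdir(class_dir):
            continue  # skip stray files at the top level
        for file_name in sorted(os.listdir(class_dir)):
            rows.append({"image_path": os.path.join(class_dir, file_name),
                         "labels": label})
    return pd.DataFrame(rows, columns=["image_path", "labels"])


# Small demonstration on a temporary folder tree with empty placeholder files.
with tempfile.TemporaryDirectory() as root:
    for label, names in {"cats": ["a.jpg"], "dogs": ["b.jpg", "c.jpg"]}.items():
        os.makedirs(os.path.join(root, label))
        for name in names:
            open(os.path.join(root, label, name), "w").close()
    demo = read_files_in_path_sketch(root)

print(demo["labels"].tolist())
```

Whether the real helper also filters out the "Others" folder is not stated, so this sketch leaves every folder in.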

4- Decoding the Dataset: Label Encoding

Next, we introduce an important step: label encoding, which converts the class names into integers. This not only facilitates the training of the model but also streamlines the interpretation of results.

lb = LabelEncoder()
training_data['encoded_labels'] = lb.fit_transform(training_data['labels'])
validation_data['encoded_labels'] = lb.transform(validation_data['labels'])

# Construct the Mapping Dictionary
mapping = dict(zip(lb.classes_, lb.transform(lb.classes_)))
reversed_mapping = {value: key for key, value in mapping.items()}

The above code not only encodes the categorical labels but also constructs a dictionary that maps classes to their corresponding encoded values.
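To illustrate how the reversed mapping is used later, here is a small sketch that decodes predicted class indices back into class names. The toy labels and the predicted indices are made up for the example.

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

# Toy labels (hypothetical) standing in for training_data['labels'].
labels = ["dog", "cat", "bird", "cat"]

lb = LabelEncoder()
encoded = lb.fit_transform(labels)  # classes are sorted, so bird=0, cat=1, dog=2

# Same construction as in the tutorial, with keys cast to plain ints.
mapping = dict(zip(lb.classes_, lb.transform(lb.classes_)))
reversed_mapping = {int(value): key for key, value in mapping.items()}

# Decode a batch of predicted class indices back to class names.
predicted = np.array([2, 0, 1])
decoded = [str(reversed_mapping[int(i)]) for i in predicted]
print(decoded)
```

LabelEncoder also offers inverse_transform, which performs the same decoding without a hand-built dictionary.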

5- Crafting a Data Pipeline and built-in data transformation

For this project, we write our own data loader from scratch, mainly because it gives us full control over the transformations we want to apply to the images. The pipeline loads the images from disk, applies the specified transformations, and makes the images ready for the model.

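class CreatePipeline(Dataset):
    def __init__(self, data, transform):
        super(CreatePipeline, self).__init__()
        # Rows of (image path, label, encoded label) as a NumPy array
        self.data = data.values
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, x):
        image, _, label = self.data[x]
        # Load the image from disk and force three RGB channels
        im = np.asarray(Image.open(image).convert('RGB'))
        if self.transform is not None:
            im = self.transform(im)

        return im, label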

train_pipeline = CreatePipeline(training_data, image_transforms['train'])
train_loader = DataLoader(train_pipeline, batch_size=BATCH_SIZE)

validation_pipeline = CreatePipeline(validation_data, image_transforms['valid'])
validation_loader = DataLoader(validation_pipeline, batch_size=BATCH_SIZE)

where image_transforms is a dictionary of image transformations for the training and validation datasets. Images pass through a transformation phase (random crops, color jittering, etc.) so that the ResNet model is exposed to a rich and varied dataset. These random transformations make the model more robust, reduce overfitting, and help it generalize better. Note that image_transforms and BATCH_SIZE must be defined before the pipelines above are created.
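To see what a DataLoader yields without needing any image files, the following sketch wraps random tensors in a small Dataset and iterates over it once. ToyImageDataset and its sizes are hypothetical; the point is only the Dataset protocol (__len__, __getitem__) and the resulting batch shapes.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyImageDataset(Dataset):
    """Hypothetical in-memory dataset: random 'images' with integer labels."""

    def __init__(self, num_samples=10, num_classes=3, size=32):
        generator = torch.Generator().manual_seed(0)
        self.images = torch.rand(num_samples, 3, size, size, generator=generator)
        self.labels = torch.randint(0, num_classes, (num_samples,), generator=generator)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        return self.images[index], self.labels[index]


# shuffle=True draws a new sample order at the start of every epoch.
loader = DataLoader(ToyImageDataset(), batch_size=4, shuffle=True)

batch_shapes = [(tuple(images.shape), tuple(labels.shape)) for images, labels in loader]
for images_shape, labels_shape in batch_shapes:
    print(images_shape, labels_shape)
```

With 10 samples and a batch size of 4, the last batch is smaller than the others; passing drop_last=True to DataLoader discards it instead.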

The list of transformations is below:

image_transforms = {'train':   transforms.Compose([transforms.ToPILImage(),
                               transforms.Resize((SIZE, SIZE)),
                               transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
                               transforms.ColorJitter(brightness=.1, hue=.1),
                               transforms.ToTensor(),  # Normalize expects a tensor, not a PIL image
                               transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))]),

                    'valid':   transforms.Compose([transforms.ToPILImage(),
                               transforms.ColorJitter(brightness=.1, hue=.1),
                               transforms.Resize((SIZE, SIZE)),
                               transforms.ToTensor(),  # Normalize expects a tensor, not a PIL image
                               transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])}

Applying distinct transformations to the training and validation datasets balances data augmentation (which helps the model generalize better during training) against a fair evaluation of the model's performance on unseen data during validation.

6- Model Selection

In the realm of deep learning, the choice of neural network architecture is a critical decision that can significantly impact the model's performance. ResNet, short for Residual Network, stands out as a powerful and influential architecture due to its ability to address challenges associated with training very deep networks.

Here's why ResNet is a compelling choice for our use case, image classification:

  1. The Vanishing Gradient Problem: One common challenge in training deep neural networks is the vanishing gradient problem, where the gradients diminish as they backpropagate through numerous layers. This can hinder the training of deep networks. ResNet introduces skip connections or residual connections (shortcut paths), allowing information to bypass certain layers. This architecture mitigates the vanishing gradient problem and makes it possible to train extremely deep networks. The skip connections also allow the model to learn identity mappings, making it easier to optimize and converge during training.

  2. Performance on Image Classification: ResNet architectures have demonstrated remarkable performance on image classification tasks, notably winning the ImageNet classification challenge (ILSVRC) in 2015. The ability to capture intricate features and patterns in images contributes to their success.

  3. Transfer Learning Capability: Pre-trained ResNet models on large datasets (e.g., ImageNet) can be used as a starting point for a variety of computer vision tasks. Leveraging transfer learning allows models to benefit from features learned on diverse datasets.

  4. Adaptability to Different Scales: The skip connections in ResNet allow the model to adapt to features at different scales, capturing both low-level and high-level information in images.

  5. Consistent State-of-the-Art Performance: ResNet variants have consistently achieved state-of-the-art performance in various computer vision tasks, making them a reliable choice for image classification.
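The skip connection described in point 1 can be sketched in a few lines. The block below is a simplified, hypothetical residual block (ToyResidualBlock is not a torchvision class, and the ResNet-50 and ResNet-101 models imported earlier use a deeper bottleneck variant); it only illustrates the idea of adding the input back to the output of a small stack of layers.

```python
import torch
import torch.nn as nn


class ToyResidualBlock(nn.Module):
    """Simplified residual block: output = relu(F(x) + x)."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        return self.relu(residual + x)  # skip connection: add the input back


block = ToyResidualBlock(channels=8)
x = torch.randn(2, 8, 16, 16)
y = block(x)
print(y.shape)

# With the last batch-norm scale zeroed, the residual branch outputs zeros,
# so the block reduces to relu(x): only the shortcut path carries the signal.
nn.init.zeros_(block.bn2.weight)
identity_out = block(x)
print(torch.allclose(identity_out, torch.relu(x)))
```

Because the input is added back unchanged, gradients can flow through the shortcut even when the convolutional branch contributes little, which is the mechanism behind the easier optimization mentioned above.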

7- Model Hyperparameters

With our training and validation datasets prepared, we now set the hyperparameters for training:

# Hyperparameters for Training
BATCH_SIZE = 16
EPOCHS = 24
LR = 0.1

SIZE = 224
NUM_CLASSES = len(lb.classes_)

RANDOM_SEED = 42
# MODEL_STORE_NAME is also set here

The hyperparameters above set the batch size to 16 images, so each training step processes 16 samples. We train for 24 epochs, each epoch being one complete pass through the training dataset. The learning rate (LR) of 0.1 controls the size of each weight update. Each image is resized to 224x224 pixels (SIZE), and NUM_CLASSES is taken from the label encoder. RANDOM_SEED is fixed at 42 to make the training process reproducible. We also set MODEL_STORE_NAME.
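A seed only helps reproducibility once it is applied to every random number generator involved. One possible way to do that is sketched below; the helper name seed_everything is hypothetical and not taken from this tutorial.

```python
import random

import numpy as np
import torch

RANDOM_SEED = 42  # value stated in the tutorial


def seed_everything(seed):
    """Hypothetical helper: seed the Python, NumPy and PyTorch generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# Re-seeding before each draw yields the same random tensor twice.
seed_everything(RANDOM_SEED)
first = torch.rand(3)
seed_everything(RANDOM_SEED)
second = torch.rand(3)
print(torch.equal(first, second))
```

Seeding makes runs repeatable on the same setup, but some GPU operations can still be non-deterministic, so small differences between runs remain possible.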

Part 2 of this series will be published soon!!