Guide to Transfer Learning with PyTorch

Fine-tuning a pre-trained model such as ResNet-18 on your own images is a common transfer-learning technique for adapting the model to a specific task or dataset. Here’s why we do it:

1. ResNet-18’s Original Purpose

ResNet-18 is pre-trained on the ImageNet dataset, which contains 1,000 general classes (like “tiger,” “airplane,” “guitar,” etc.). This is useful for recognizing those classes, but it doesn’t help if you need the model to classify images from a different set of categories (e.g., “cats” vs. “dogs” or specific product categories for an e-commerce site).

2. Transfer Learning Benefits

  • Leverages Pre-Trained Knowledge: ResNet-18’s initial layers have already learned to recognize general image features (like edges, textures, and shapes). Instead of training a model from scratch, which requires lots of data and computing power, transfer learning allows us to use this feature recognition foundation and fine-tune the model for a new, specific task.
  • Reduced Training Data Requirements: Since the model already “knows” basic image features, training with a smaller custom dataset is often sufficient to make it accurate for the new task.

3. Custom Images for Specific Needs

By fine-tuning on a set of custom images, we adapt ResNet-18’s pre-trained knowledge to our specific classes. For example, if you want a classifier that distinguishes between “cats” and “dogs,” you’ll need to train the model on labeled images of cats and dogs. Out of the box, the model can only predict its original 1,000 ImageNet classes; it cannot output your custom labels until it has been adapted and fine-tuned.

4. Updating the Model’s Final Layer

The final layer of ResNet-18 is specifically designed to output predictions for 1,000 classes in the ImageNet dataset. When using custom classes, we modify this last layer to output predictions for our specific categories (like 2 classes for “cats” and “dogs” or however many classes your dataset requires).

Practical Steps in Transfer Learning:

  1. Load ResNet-18: Start with ResNet-18 pre-trained on ImageNet.
  2. Freeze Early Layers: Optionally, freeze the earlier layers (the ones capturing general features).
  3. Replace Final Layer: Replace the final fully connected layer with one that matches the number of classes in the custom dataset.
  4. Fine-Tune with Custom Images: Train the model on your dataset, adapting it to recognize custom classes.

Transfer learning with custom images lets you leverage the strengths of a powerful model like ResNet-18 while specializing it for your own data.

Transfer learning is a powerful technique that leverages pre-trained models for specific tasks, reducing the need for large datasets and extensive computation. In this guide, we’ll walk through a complete workflow to implement transfer learning using a pre-trained ResNet-18 model in PyTorch.

Prerequisites

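Before starting, ensure you have Python installed, along with PyTorch, torchvision, matplotlib, scikit-learn, and seaborn (plus FastAPI, Uvicorn, and python-multipart for the deployment section; all of them are listed in the requirements.txt under Complete Code). You can set up a virtual environment to manage dependencies: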

# Create a virtual environment
python -m venv venv
# Activate the virtual environment
venv\Scripts\activate        # Windows
source venv/bin/activate     # macOS/Linux
# Upgrade pip and install requirements
python -m pip install --upgrade pip
pip install -r requirements.txt   # optionally add --cache-dir <path> to control where pip caches downloads
# Optional: pin the exact installed versions for reproducibility
pip freeze > requirements.txt

Step 1: Data Preparation and Transformation

For transfer learning, we’ll use custom data organized into train and test folders, with each category in a separate folder under train and test.
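ImageFolder (used below) assigns integer labels by sorting the class-folder names, so train and test must contain the same set of folders; otherwise the same index could mean different classes in the two splits. With two example classes, the expected layout is data/train/cats/, data/train/dogs/, data/test/cats/ and data/test/dogs/ (the class names are placeholders). The helper below, check_split_classes, is a hypothetical sketch that mirrors ImageFolder’s folder scan using only the standard library:

```python
import os


def check_split_classes(data_dir="data"):
    """Return the sorted class-folder names, raising if train/ and test/ disagree.

    Hypothetical helper; it mirrors how ImageFolder discovers classes
    (immediate subdirectories, sorted by name).
    """
    splits = {}
    for split in ("train", "test"):
        split_dir = os.path.join(data_dir, split)
        # Each immediate subdirectory is treated as one class.
        splits[split] = sorted(
            entry.name for entry in os.scandir(split_dir) if entry.is_dir()
        )
    if splits["train"] != splits["test"]:
        raise ValueError(
            f"Class folders differ: train={splits['train']}, test={splits['test']}"
        )
    return splits["train"]
```

Calling check_split_classes("data") before training fails fast on a missing or misspelled class folder instead of silently shifting the labels.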

Data Transformations

To prepare images for the model, we’ll resize them to 224×224 pixels (the input size ResNet-18 was trained on for ImageNet) and normalize the color channels with the same statistics used to train the pre-trained model:

from torchvision import transforms

data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

This code block is creating a dictionary called data_transforms, which defines two sets of transformations: one for training data ('train') and one for testing data ('test'). Both transformations are identical here, but setting them separately allows you to easily modify them later if necessary.

Explanation of Each Transformation

  • transforms.Compose([...]): Combines multiple transformations into a single operation. Each transformation will be applied sequentially.
  • transforms.Resize((224, 224)): Resizes each image to a fixed size of 224×224 pixels. This size is commonly used as the input for many convolutional neural networks (CNNs), like ResNet, as they are often pre-trained on images of this size.
  • transforms.ToTensor(): Converts a PIL image (height × width × channels) into a PyTorch float tensor of shape channels × height × width and scales pixel values from [0, 255] to [0.0, 1.0], the range the subsequent normalization step expects.
  • transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]): For each RGB channel, subtracts the given mean and divides by the given standard deviation, i.e. output = (input - mean) / std. These values are the per-channel mean and standard deviation of the ImageNet training images, so normalized inputs have roughly zero mean and unit variance, which is what the pre-trained weights expect.

Loading the Dataset

We use ImageFolder to load images, which automatically assigns labels based on folder names, allowing for easy dataset handling:

from torchvision import datasets
from torch.utils.data import DataLoader

data_dir = 'data'
image_datasets = {
    x: datasets.ImageFolder(f"{data_dir}/{x}", data_transforms[x])
    for x in ['train', 'test']
}
dataloaders = {
    x: DataLoader(image_datasets[x], batch_size=32, shuffle=True)
    for x in ['train', 'test']
}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
class_names = image_datasets['train'].classes

This code block sets up the dataset and data loaders for training and testing, using the torchvision.datasets.ImageFolder class and torch.utils.data.DataLoader. Here’s a breakdown of each part:

Explanation

  1. data_dir = 'data':
    • Sets the directory containing the image datasets, with subdirectories for training and testing data (e.g., data/train and data/test).
  2. Loading Datasets:
image_datasets = {
    x: datasets.ImageFolder(f"{data_dir}/{x}", data_transforms[x])
    for x in ['train', 'test']
}

datasets.ImageFolder: A dataset class that loads images from a directory tree and applies the specified transformations.

The dataset expects subdirectories named according to the class labels within each main folder (train and test).

The images are loaded with transformations defined earlier (data_transforms), ensuring the images are resized, converted to tensors, and normalized.

image_datasets: A dictionary containing train and test datasets, which can be accessed as image_datasets['train'] and image_datasets['test'].

3. Creating Data Loaders:

dataloaders = {
    x: DataLoader(image_datasets[x], batch_size=32, shuffle=True)
    for x in ['train', 'test']
}

DataLoader: Wraps the datasets into iterable objects, allowing easy batch processing during training and testing.

batch_size=32: Groups the images into mini-batches of 32, so each batch is a tensor of shape [32, 3, 224, 224] (the final batch may be smaller).

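shuffle=True: Reshuffles the data at every epoch so the model doesn’t see samples in the same order each time, which helps training generalize. For the test set, shuffling is unnecessary: accuracy and the confusion matrix don’t depend on sample order, so you could pass shuffle=(x == 'train') to keep the test order fixed.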

4. Dataset Sizes and Class Names:

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
class_names = image_datasets['train'].classes

dataset_sizes: A dictionary holding the number of images in the train and test datasets; it is used later to average the loss and accuracy over an epoch.

class_names: Retrieves the class labels (directory names) within the train dataset, which is useful for interpreting model predictions later.

Step 2: Setting Up the Model

We’ll use ResNet-18, a popular model for image classification, and modify its final fully connected layer to match the number of classes in our dataset:

import torch
import torch.nn as nn
from torchvision import models

model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, len(class_names))

model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT):

  • Loads a ResNet-18 model whose weights were pre-trained on the ImageNet dataset. Using a pre-trained model lets you leverage previously learned features, which is particularly useful if the new dataset resembles ImageNet in terms of general image features (e.g., objects, textures).
  • weights=models.ResNet18_Weights.DEFAULT: Selects the best available ImageNet weights. This argument replaces the older pretrained=True, which has been deprecated since torchvision 0.13. Starting from pre-trained weights typically speeds up training and improves accuracy in transfer learning tasks.

Extracting the Number of Features:

num_features = model.fc.in_features

The final layer (fc) of ResNet-18 is a fully connected layer that maps the model’s features to the number of classes in the ImageNet dataset (1,000 classes).

model.fc.in_features gives the number of input features for this final layer. For ResNet-18 this is 512: the output of the last convolutional stage after global average pooling, flattened into a vector and fed into the fully connected layer.

Replacing the Final Layer:

model.fc = nn.Linear(num_features, len(class_names))

nn.Linear(num_features, len(class_names)): Replaces the original final fully connected layer with a new one that maps to the number of classes in your dataset, represented by len(class_names).

By doing this, we tailor the model to our dataset, changing the output layer to match the number of classes in the new classification task while keeping all other layers and their pre-trained weights. Note that this code does not freeze anything: every layer remains trainable, so the whole network is fine-tuned during training.

Step 3: Define the Loss Function and Optimizer

For our classifier, we’ll use cross-entropy loss and the SGD optimizer with a learning rate of 0.001 and a momentum of 0.9:

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

nn.CrossEntropyLoss(): A commonly used loss function for multi-class classification. It combines LogSoftmax and NLLLoss in a single class, so it expects the model’s raw, unnormalized scores (logits) rather than probabilities.

Purpose: Cross-entropy loss measures the difference between the predicted probabilities and the actual class labels. Minimizing this loss helps the model improve its predictions.
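As a worked check of that definition, the loss for one sample is -log(softmax(logits)[true class]), averaged over the batch. The sketch below uses made-up logits for two samples and three classes and compares the hand-computed value with nn.CrossEntropyLoss:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Made-up raw scores (logits) for 2 images over 3 classes, plus true labels.
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
labels = torch.tensor([0, 2])

# Cross-entropy = mean over the batch of -log(softmax(logits)[true class]).
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(labels)), labels].mean()

builtin = nn.CrossEntropyLoss()(logits, labels)
print(torch.allclose(manual, builtin))  # True
```

This is also why the model’s outputs are passed to the criterion directly: applying softmax first would apply it twice.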

torch.optim.SGD: Uses Stochastic Gradient Descent (SGD) for optimization. SGD updates the model’s weights based on the gradients of the loss function with respect to each parameter.

model.parameters(): Tells the optimizer which parameters to update. Here this is every parameter in the network, because no layers were frozen; if you freeze layers, pass only the trainable ones (for example, model.fc.parameters()).

lr=0.001: The learning rate is a scaling factor that determines the step size at each update. A small learning rate like 0.001 helps the model converge gradually.

momentum=0.9: Momentum keeps a running velocity, an exponentially decaying sum of past gradients, and moves the weights along that velocity instead of the raw current gradient. This smooths out noisy mini-batch gradients and speeds up progress along directions where successive gradients agree, which usually makes convergence faster and more stable.
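With the default settings (dampening=0, no Nesterov, no weight decay), PyTorch’s SGD applies v = momentum * v + grad followed by w = w - lr * v. The sketch below minimizes the toy function f(w) = w² (chosen only for illustration, with lr=0.1) and checks that torch.optim.SGD matches a hand-written version of that rule:

```python
import torch

lr, momentum = 0.1, 0.9

# One scalar parameter, minimising the toy loss f(w) = w**2 (gradient 2*w).
w = torch.tensor([1.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=lr, momentum=momentum)

w_manual = 1.0
velocity = 0.0
for step in range(3):
    optimizer.zero_grad()
    loss = (w ** 2).sum()
    loss.backward()
    optimizer.step()

    # Same update by hand: accumulate velocity, then step along it.
    grad = 2 * w_manual
    velocity = momentum * velocity + grad
    w_manual = w_manual - lr * velocity

    print(step, round(w.item(), 4), round(w_manual, 4))
```

The weight moves 1.0 → 0.8 → 0.46 → 0.062: each step is larger than plain SGD would take because the velocity keeps earlier gradients.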

Step 4: Training the Model

The training function iterates through the dataset for a specified number of epochs, computing the loss and adjusting weights accordingly. Here’s a simple training function:

def train_model(model, criterion, optimizer, num_epochs=5):
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        model.train()

        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in dataloaders['train']:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / dataset_sizes['train']
        epoch_acc = running_corrects.double() / dataset_sizes['train']

        print(f'Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

    return model

import os

os.makedirs('models', exist_ok=True)  # torch.save does not create missing directories
model_trained = train_model(model, criterion, optimizer)
torch.save(model_trained.state_dict(), 'models/transfer_learned_model.pth')
print("Model saved to models/transfer_learned_model.pth")
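train_model runs entirely on the CPU, since neither the model nor the batches are moved to a GPU. A minimal sketch of one epoch with device placement is below; train_one_epoch is a hypothetical helper that follows the same steps as the loop above and also returns the epoch’s mean loss and accuracy:

```python
import torch

# Use a GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train_one_epoch(model, loader, criterion, optimizer, device):
    """One pass over `loader`; returns (mean loss, accuracy). Hypothetical helper."""
    model.train()
    running_loss, running_corrects, seen = 0.0, 0, 0
    for inputs, labels in loader:
        # Each batch must live on the same device as the model's parameters.
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        running_corrects += (outputs.argmax(dim=1) == labels).sum().item()
        seen += inputs.size(0)
    return running_loss / seen, running_corrects / seen
```

Move the model once before training with model = model.to(device), then call train_one_epoch(model, dataloaders['train'], criterion, optimizer, device) once per epoch.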

4.1 Function Definition:

def train_model(model, criterion, optimizer, num_epochs=5):

This function trains the model for a specified number of epochs (default is 5).

Parameters:

  • model: The neural network model to train.
  • criterion: The loss function used to measure prediction error.
  • optimizer: The optimizer used to adjust model weights.
  • num_epochs: Number of training epochs, or full passes over the dataset.

4.2 Training Loop:

for epoch in range(num_epochs):
    print(f'Epoch {epoch}/{num_epochs - 1}')
    model.train()

Iterates over each epoch, displaying the current epoch.

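model.train(): Puts the model in training mode, so layers that behave differently during training act accordingly: dropout is active, and batch normalization normalizes with each batch’s statistics while updating its running averages. (ResNet-18 contains batch normalization but no dropout.)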

4.3 Batch Processing:

for inputs, labels in dataloaders['train']:
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

for inputs, labels in dataloaders['train']: Loops over mini-batches in the training data.

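optimizer.zero_grad(): Clears the gradients left over from the previous batch. PyTorch accumulates gradients on every backward() call, so without this step each update would mix gradients from several batches.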

outputs = model(inputs): Performs a forward pass to get model predictions.

loss = criterion(outputs, labels): Calculates loss by comparing predictions (outputs) with the actual labels (labels).

loss.backward(): Computes the gradients of the loss with respect to model parameters.

optimizer.step(): Updates model parameters based on the computed gradients.
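The accumulation behavior that makes optimizer.zero_grad() necessary can be seen on a single toy parameter. The function 2 * w below is an arbitrary example whose gradient is always 2:

```python
import torch

w = torch.tensor([3.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

# d(2w)/dw = 2. Two backward passes without clearing add up to 4.
(2 * w).sum().backward()
(2 * w).sum().backward()
accumulated = w.grad.clone()
print(accumulated)  # tensor([4.])

# zero_grad() clears the accumulated gradient, so the next pass sees only its own.
optimizer.zero_grad()
(2 * w).sum().backward()
print(w.grad)  # tensor([2.])
```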

4.4 Tracking Loss and Accuracy:

running_loss += loss.item() * inputs.size(0)
_, preds = torch.max(outputs, 1)
running_corrects += torch.sum(preds == labels.data)

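running_loss accumulates the total loss for the epoch. loss.item() is the mean loss over the batch, so multiplying it by the batch size (inputs.size(0)) gives the batch’s summed loss; the total is later divided by the number of samples.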

torch.max(outputs, 1): Returns, for each input, the highest score and its index along the class dimension; the index (preds) is the predicted class.

running_corrects tracks the number of correct predictions.
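A small worked example of these two lines, using made-up logits for three images and two classes:

```python
import torch

# Made-up logits for 3 images over 2 classes, and their true labels.
outputs = torch.tensor([[1.2, -0.3],
                        [0.1,  0.9],
                        [2.0,  2.5]])
labels = torch.tensor([0, 0, 1])

values, preds = torch.max(outputs, 1)  # max along the class dimension
print(values)                      # tensor([1.2000, 0.9000, 2.5000])
print(preds)                       # tensor([0, 1, 1])
print(torch.sum(preds == labels))  # tensor(2): rows 0 and 2 are correct
```

outputs.argmax(dim=1) returns the same indices when only the predicted class is needed.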

4.5 Epoch Loss and Accuracy:

epoch_loss = running_loss / dataset_sizes['train']
epoch_acc = running_corrects.double() / dataset_sizes['train']

Computes the average loss and accuracy for the epoch based on total samples in dataset_sizes['train'].

4.6 Model Saving:

model_trained = train_model(model, criterion, optimizer)
torch.save(model_trained.state_dict(), 'models/transfer_learned_model.pth')

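torch.save saves the trained model’s state dictionary (its learned weights) to a file, allowing reloading for inference or further training. The models directory must already exist, since torch.save does not create it.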

Step 5: Evaluating the Model

Once trained, it’s crucial to evaluate the model on the test set. This code loads the model, sets it to evaluation mode, and calculates the accuracy and confusion matrix:

from sklearn.metrics import accuracy_score, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

model.load_state_dict(torch.load('models/transfer_learned_model.pth'))
model.eval()

true_labels = []
predictions = []

with torch.no_grad():
    for inputs, labels in dataloaders['test']:
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        true_labels.extend(labels.numpy())
        predictions.extend(preds.numpy())

accuracy = accuracy_score(true_labels, predictions)
print(f"Test Accuracy: {accuracy:.4f}")

cm = confusion_matrix(true_labels, predictions)
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.show()

5.1 Loading and Setting the Model to Evaluation Mode:

model.load_state_dict(torch.load('models/transfer_learned_model.pth'))
model.eval()

model.load_state_dict: Loads the saved model weights from the specified path. This restores the model’s learned parameters from training.

model.eval(): Sets the model to evaluation mode. This disables dropout and makes batch normalization use the running statistics collected during training instead of per-batch statistics, so predictions are consistent and don’t depend on which other images share the batch.
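The difference between the two modes is easy to observe on a standalone batch-normalization layer. The input below is an arbitrary random batch shifted far from mean 0 and standard deviation 1:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(num_features=3)
x = torch.randn(8, 3) * 5 + 10  # a batch far from mean 0 / std 1

bn.train()
y_train = bn(x)  # normalised with this batch's own mean and variance
print(y_train.mean(dim=0))  # close to 0 for every feature

bn.eval()
running_mean_before = bn.running_mean.clone()
y_eval = bn(x)   # normalised with the stored running statistics instead
print(torch.allclose(y_train, y_eval))  # False: the two modes differ
print(torch.equal(bn.running_mean, running_mean_before))  # True: eval does not update stats
```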

5.2 Initializing Lists for True Labels and Predictions:

true_labels = []
predictions = []

These lists store the true and predicted labels for each sample, which will be used for accuracy calculation and confusion matrix generation.

5.3 Inference Loop:

with torch.no_grad():
    for inputs, labels in dataloaders['test']:
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        true_labels.extend(labels.numpy())
        predictions.extend(preds.numpy())

torch.no_grad(): Disables gradient calculations to save memory and computation, as gradients are not needed during evaluation.

for inputs, labels in dataloaders['test']: Loops over the test dataset in batches.

outputs = model(inputs): Performs a forward pass to get the model’s predictions.

_, preds = torch.max(outputs, 1): Retrieves the predicted class with the highest score for each input.

true_labels.extend(labels.numpy()) and predictions.extend(preds.numpy()): Appends the true and predicted labels for each batch to the respective lists.

5.4 Calculating Accuracy:

accuracy = accuracy_score(true_labels, predictions)
print(f"Test Accuracy: {accuracy:.4f}")

accuracy_score(true_labels, predictions): Computes the overall accuracy of the model on the test set by comparing true and predicted labels.

5.5 Confusion Matrix and Visualization:

cm = confusion_matrix(true_labels, predictions)
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.show()

confusion_matrix(true_labels, predictions): Generates a confusion matrix, which shows the count of correct and incorrect predictions for each class.

sns.heatmap: Visualizes the confusion matrix with annotations to show exact counts. The color map "Blues" is used to differentiate high and low counts, with class names labeled on both axes for easy interpretation.
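To read the matrix: rows are true classes and columns are predicted classes, so correct predictions sit on the diagonal. The sketch below uses made-up labels for six test images (0 = "cats", 1 = "dogs") and shows that accuracy equals the diagonal sum divided by the total count:

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix

# Made-up labels for 6 test images: 0 = "cats", 1 = "dogs".
true_labels = [0, 0, 0, 1, 1, 1]
predictions = [0, 0, 1, 1, 1, 0]

cm = confusion_matrix(true_labels, predictions)
print(cm)
# [[2 1]    row 0: true cats -> 2 predicted cat, 1 predicted dog
#  [1 2]]   row 1: true dogs -> 1 predicted cat, 2 predicted dog

# Correct predictions sit on the diagonal, so accuracy = trace / total.
print(np.trace(cm) / cm.sum(), accuracy_score(true_labels, predictions))
```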

Step 6: Next Steps

Now that you have a trained model, you can:

  • Optimize the Model: Experiment with different hyperparameters or learning rates.
  • Deploy the Model: Save it in a deployable format (such as TorchScript or ONNX).
  • Expand with More Classes or Data: Increase the robustness of your model by adding more data or fine-tuning with specific requirements.

Complete Code

requirements.txt

torch
torchvision
numpy
matplotlib
scikit-learn
seaborn
fastapi
uvicorn
python-multipart

main.py

# Setup (see Prerequisites):
# python -m venv venv
# venv\Scripts\activate          (Windows; use `source venv/bin/activate` on macOS/Linux)
# python -m pip install --upgrade pip
# pip install -r requirements.txt

import torch
import torch.nn as nn
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Set up transformations for the training and testing datasets
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Load datasets
data_dir = 'data'
image_datasets = {
    x: datasets.ImageFolder(f"{data_dir}/{x}", data_transforms[x])
    for x in ['train', 'test']
}
dataloaders = {
    x: DataLoader(image_datasets[x], batch_size=32, shuffle=True)
    for x in ['train', 'test']
}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
class_names = image_datasets['train'].classes

# Load a pre-trained model and modify the final layer
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # replaces the deprecated pretrained=True
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, len(class_names))

# Set up loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Training function
def train_model(model, criterion, optimizer, num_epochs=5):
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        model.train()

        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in dataloaders['train']:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / dataset_sizes['train']
        epoch_acc = running_corrects.double() / dataset_sizes['train']

        print(f'Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

    return model

# Train and save the model
import os

os.makedirs('models', exist_ok=True)  # torch.save does not create missing directories
model_trained = train_model(model, criterion, optimizer)
torch.save(model_trained.state_dict(), 'models/transfer_learned_model.pth')
print("Model saved to models/transfer_learned_model.pth")

# Load the model and set it to evaluation mode
model.load_state_dict(torch.load('models/transfer_learned_model.pth'))
model.eval()

# Evaluate on the test set
true_labels = []
predictions = []

with torch.no_grad():
    for inputs, labels in dataloaders['test']:
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        true_labels.extend(labels.numpy())
        predictions.extend(preds.numpy())

# Calculate accuracy
accuracy = accuracy_score(true_labels, predictions)
print(f"Test Accuracy: {accuracy:.4f}")

# Confusion matrix
cm = confusion_matrix(true_labels, predictions)
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.show()

Deploying a Transfer Learning Model with FastAPI and PyTorch

predict.py

# uvicorn predict:app --reload
# curl -X POST "http://127.0.0.1:8000/predict/" -H "Content-Type: multipart/form-data" -F "file=@path_to_your_image.jpg"

from fastapi import FastAPI, File, UploadFile
from torchvision import models, transforms, datasets
from PIL import Image
import torch
import torch.nn as nn
import io

app = FastAPI()

# Load the class names from the training dataset
data_dir = 'data/train'  # Path to your training data
dataset = datasets.ImageFolder(data_dir)
class_names = dataset.classes  # Get the class names

# Load the saved model
model = models.resnet18()  # architecture only; the trained weights are loaded below
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, len(class_names))  # must match the number of classes used in training
model.load_state_dict(torch.load("models/transfer_learned_model.pth", map_location="cpu"))
model.eval()

# Define the transformation for input image
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    img = Image.open(io.BytesIO(await file.read())).convert("RGB")  # ensure 3 channels (e.g. grayscale or RGBA uploads)
    img = transform(img).unsqueeze(0)
    with torch.no_grad():
        outputs = model(img)
        _, predicted = torch.max(outputs, 1)
        class_name = class_names[predicted.item()]  # Map index to class name
    return {"prediction": class_name}

Explanation of Each Part

  1. Imports and FastAPI Initialization:
    • Imports necessary libraries and initializes the FastAPI app.
  2. Load Class Names:
    • Loads the class names from the dataset directory using ImageFolder, allowing the model to map predictions to readable class names.
  3. Model Loading:
    • Loads the ResNet-18 model and modifies the final layer to match the number of classes in your custom dataset. The trained weights are then loaded from models/transfer_learned_model.pth.
  4. Image Preprocessing:
    • Sets up the same preprocessing (resize, tensor conversion, normalization) used for the test set during training, which also matches the ImageNet statistics the pre-trained ResNet-18 expects.
  5. Prediction Endpoint:
    • Receives an image file, applies transformations, performs prediction, and returns the class name as a JSON response.

Making Predictions

To test the API, use curl to send an image to the server:

curl -X POST "http://127.0.0.1:8000/predict/" -H "Content-Type: multipart/form-data" -F "file=@path_to_your_image.jpg"

Replace path_to_your_image.jpg with the actual path to an image file.

Expected Response

The server will return a JSON response with the predicted class name, for example:

{
    "prediction": "dog"
}

The value is one of the class folder names from your training data (e.g., "cat" or "dog"), depending on the model output.

Explanation of Model Fine-Tuning and Deployment Process

  1. Fine-Tuning:
    • Transfer learning enables us to take a model like ResNet-18, pre-trained on the ImageNet dataset, replace its final layer to match our classes, and fine-tune it on a custom dataset instead of training from scratch.
  2. Deployment:
    • FastAPI provides a simple way to wrap our model in an API endpoint for real-time predictions.
    • The API receives image files, processes them, and outputs predictions based on the model’s training.