Guide to Transfer Learning with PyTorch
Using custom images with a pre-trained model like ResNet-18 is common in transfer learning to adapt the model to a specific task or dataset. Here’s why we do it:
1. ResNet-18’s Original Purpose
ResNet-18 is pre-trained on the ImageNet dataset, which contains 1,000 general classes (like “tiger,” “airplane,” “guitar,” etc.). This is useful for recognizing those classes, but it doesn’t help if you need the model to classify images from a different set of categories (e.g., “cats” vs. “dogs” or specific product categories for an e-commerce site).
2. Transfer Learning Benefits
- Leverages Pre-Trained Knowledge: ResNet-18’s initial layers have already learned to recognize general image features (like edges, textures, and shapes). Instead of training a model from scratch, which requires lots of data and computing power, transfer learning allows us to use this feature recognition foundation and fine-tune the model for a new, specific task.
- Reduced Training Data Requirements: Since the model already “knows” basic image features, training with a smaller custom dataset is often sufficient to make it accurate for the new task.
3. Custom Images for Specific Needs
By fine-tuning on a set of custom images, we adapt ResNet-18’s pre-trained knowledge to our specific classes. For example, if you want a classifier that distinguishes between “cats” and “dogs,” you’ll need to train the model on those specific images. ResNet-18’s original 1,000 classes won’t cover custom classes unless fine-tuned.
4. Updating the Model’s Final Layer
The final layer of ResNet-18 is specifically designed to output predictions for 1,000 classes in the ImageNet dataset. When using custom classes, we modify this last layer to output predictions for our specific categories (like 2 classes for “cats” and “dogs” or however many classes your dataset requires).
Practical Steps in Transfer Learning:
- Load ResNet-18: Start with ResNet-18 pre-trained on ImageNet.
- Freeze Early Layers: Optionally, freeze the earlier layers (the ones capturing general features).
- Replace Final Layer: Replace the final fully connected layer with one that matches the number of classes in the custom dataset.
- Fine-Tune with Custom Images: Train the model on your dataset, adapting it to recognize custom classes.
Using transfer learning with custom images lets you leverage the strengths of a powerful model like ResNet-18 while adapting it to your own data.
Transfer learning is a powerful technique that leverages pre-trained models for specific tasks, reducing the need for large datasets and extensive computation. In this guide, we’ll walk through a complete workflow to implement transfer learning using a pre-trained ResNet-18 model in PyTorch.
Prerequisites
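Before starting, make sure you have Python installed, along with PyTorch, torchvision, matplotlib, scikit-learn, and seaborn (plus FastAPI, Uvicorn, and python-multipart for the deployment section). The full list is in the requirements.txt file in the Complete Code section. You can set up a virtual environment to manage dependencies: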
```shell
# Create a virtual environment
python -m venv venv

# Activate the virtual environment
venv\Scripts\activate        # Windows
source venv/bin/activate     # macOS/Linux

# Upgrade pip and install requirements
# (optionally add --cache-dir <path> to choose where pip caches downloads)
python -m pip install --upgrade pip
pip install -r requirements.txt

# Optional: pin the exact installed versions (this overwrites requirements.txt)
pip freeze > requirements.txt
```
Step 1: Data Preparation and Transformation
For transfer learning, we’ll use custom data organized into `train` and `test` folders, with one subfolder per category inside each of them (for example, `data/train/cat`, `data/train/dog`, `data/test/cat`, and `data/test/dog`).
Data Transformations
To prepare images for the model, we’ll resize them to 224×224 pixels (the standard input size used when ResNet-18 was pre-trained on ImageNet). We’ll also normalize the color channels with the same ImageNet statistics that the pre-trained model expects:
```python
from torchvision import transforms

data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
```
This code block creates a dictionary called `data_transforms` that defines two sets of transformations: one for training data (`'train'`) and one for testing data (`'test'`). Both are identical here. Defining them separately makes it easy to change one later, for example to add data augmentation to the training set only.
Explanation of Each Transformation
- `transforms.Compose([...])`: Combines multiple transformations into a single operation. The transformations are applied in order.
- `transforms.Resize((224, 224))`: Resizes each image to exactly 224×224 pixels, the input size commonly used by CNNs such as ResNet that were pre-trained on ImageNet. Because a fixed (height, width) tuple is given, images with a different aspect ratio are stretched.
- `transforms.ToTensor()`: Converts the image to a PyTorch tensor of shape (channels, height, width) and scales pixel values from [0, 255] to [0.0, 1.0].
- `transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])`: For each RGB channel, subtracts the given mean and divides by the given standard deviation, i.e. output = (input - mean) / std. These values are the ImageNet channel statistics. The pre-trained weights expect inputs normalized this way, so using them keeps your images on the same scale the network saw during pre-training.
Loading the Dataset
We use `ImageFolder` to load images. It automatically assigns labels based on folder names, which makes the dataset easy to handle:
```python
from torchvision import datasets
from torch.utils.data import DataLoader

data_dir = 'data'
image_datasets = {
    x: datasets.ImageFolder(f"{data_dir}/{x}", data_transforms[x])
    for x in ['train', 'test']
}
dataloaders = {
    x: DataLoader(image_datasets[x], batch_size=32, shuffle=True)
    for x in ['train', 'test']
}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
class_names = image_datasets['train'].classes
```
This code block sets up the datasets and data loaders for training and testing using the `torchvision.datasets.ImageFolder` class and `torch.utils.data.DataLoader`. Here’s a breakdown of each part:
Explanation
- `data_dir = 'data'`: Sets the root directory of the image datasets. It contains one subdirectory for training data and one for testing data (`data/train` and `data/test`).
Loading Datasets:
```python
image_datasets = {
    x: datasets.ImageFolder(f"{data_dir}/{x}", data_transforms[x])
    for x in ['train', 'test']
}
```
- `datasets.ImageFolder`: A dataset class that loads images from a directory tree and applies the specified transformations. It expects one subdirectory per class label inside each main folder (`train` and `test`), and it assigns integer labels in alphabetical order of those folder names.
- The transformations defined earlier (`data_transforms`) are applied as each image is loaded, so every image is resized, converted to a tensor, and normalized.
- `image_datasets`: A dictionary containing the `train` and `test` datasets, accessible as `image_datasets['train']` and `image_datasets['test']`.
Creating Data Loaders:
```python
dataloaders = {
    x: DataLoader(image_datasets[x], batch_size=32, shuffle=True)
    for x in ['train', 'test']
}
```
- `DataLoader`: Wraps each dataset in an iterable that yields batches, which makes batch processing during training and testing straightforward.
- `batch_size=32`: Each iteration yields 32 images and their labels (the last batch may be smaller).
- `shuffle=True`: Reshuffles the data at every epoch, which helps the model generalize during training. For the test set, shuffling has no effect on the computed accuracy or confusion matrix. `shuffle=False` is the usual choice there because it gives a reproducible order.
Dataset Sizes and Class Names:
```python
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
class_names = image_datasets['train'].classes
```
- `dataset_sizes`: A dictionary storing the number of images in the `train` and `test` datasets. It is used later to average loss and accuracy over all samples.
- `class_names`: The class labels (directory names) of the `train` dataset, listed in label-index order. It is useful for interpreting model predictions later.
Step 2: Setting Up the Model
We’ll use ResNet-18, a popular model for image classification, and modify its final fully connected layer to match the number of classes in our dataset:
```python
import torch
import torch.nn as nn
from torchvision import models

# Load ResNet-18 with ImageNet weights (older torchvision versions: pretrained=True)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, len(class_names))
```
Loading the Pre-Trained Model:
```python
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
```
- `models.resnet18(...)`: Builds a ResNet-18 model and loads weights that were already trained on the ImageNet dataset. Reusing these previously learned features (edges, textures, object parts) is most helpful when your images share general visual characteristics with ImageNet photos.
- `weights=models.ResNet18_Weights.DEFAULT`: Requests the best available ImageNet weights (currently `IMAGENET1K_V1`), which are downloaded on first use. Older tutorials and torchvision versions before 0.13 write this as `pretrained=True`, which newer releases deprecate. Starting from pre-trained weights usually reaches good accuracy in far fewer epochs than training from scratch.
Extracting the Number of Features:
```python
num_features = model.fc.in_features
```
- The final layer (`fc`) of ResNet-18 is a fully connected layer that maps the model’s features to the 1,000 ImageNet classes.
- `model.fc.in_features` gives the number of inputs to that layer (512 for ResNet-18). These inputs are the output of the last convolutional stage after global average pooling, flattened into a vector.
Replacing the Final Layer:
```python
model.fc = nn.Linear(num_features, len(class_names))
```
- `nn.Linear(num_features, len(class_names))`: Replaces the original final layer with a new, randomly initialized one whose output size is the number of classes in your dataset, `len(class_names)`.
- This tailors the output layer to the new classification task while keeping the pre-trained weights of every other layer. Note that nothing is frozen here: all layers remain trainable, so the optimizer fine-tunes the whole network. To use the backbone purely as a fixed feature extractor, set `requires_grad = False` on its parameters before replacing `fc`.
Step 3: Define the Loss Function and Optimizer
For our classifier, we’ll use cross-entropy loss and the SGD optimizer with a learning rate of 0.001 and momentum of 0.9:
```python
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
```
- `nn.CrossEntropyLoss()`: The standard loss for multi-class classification. It takes raw, unnormalized scores (logits) and combines `LogSoftmax` and `NLLLoss` in a single class, so the model should not apply softmax itself. Cross-entropy measures how far the predicted class probabilities are from the true labels, and minimizing it raises the predicted probability of the correct class.
- `torch.optim.SGD`: Stochastic Gradient Descent. After each backward pass, it updates every parameter using the gradient of the loss with respect to that parameter.
- `model.parameters()`: Tells the optimizer which parameters to update. Here it passes every parameter in the network, and since nothing was frozen, the whole model is fine-tuned. If you freeze layers (`requires_grad = False`), those parameters receive no gradients and SGD leaves them unchanged.
- `lr=0.001`: The learning rate scales the size of each update. A small value such as 0.001 is a common choice for fine-tuning because it changes the pre-trained weights gradually.
- `momentum=0.9`: SGD keeps a running “velocity” that accumulates past gradients. At each step the previous velocity is multiplied by 0.9 before the new gradient is added, and the weights move along that velocity. This smooths noisy updates and speeds up progress in directions where successive gradients agree.
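Both claims above can be checked numerically. The sketch below first confirms that `CrossEntropyLoss` equals `NLLLoss` applied to `log_softmax` of the logits. It then writes the momentum update out by hand (v ← momentum · v + grad, w ← w − lr · v) and compares it with `torch.optim.SGD` on a toy quadratic. The logits, labels, and loss function are made up purely for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# 1) CrossEntropyLoss == NLLLoss applied to LogSoftmax of the logits.
logits = torch.randn(4, 3)            # batch of 4 samples, 3 classes
labels = torch.tensor([0, 2, 1, 2])
ce = nn.CrossEntropyLoss()(logits, labels)
manual_ce = F.nll_loss(F.log_softmax(logits, dim=1), labels)
print(torch.allclose(ce, manual_ce))  # True

# 2) SGD with momentum, written out by hand:
#    v <- momentum * v + grad ;  w <- w - lr * v
lr, momentum = 0.1, 0.9
w_torch = torch.tensor([5.0, -2.0], requires_grad=True)
opt = torch.optim.SGD([w_torch], lr=lr, momentum=momentum)

w_manual = torch.tensor([5.0, -2.0])
velocity = torch.zeros(2)
for _ in range(5):
    # Toy loss = sum((w - 3)^2), whose gradient is 2 * (w - 3).
    opt.zero_grad()
    ((w_torch - 3) ** 2).sum().backward()
    opt.step()

    grad = 2 * (w_manual - 3)
    velocity = momentum * velocity + grad
    w_manual = w_manual - lr * velocity

print(torch.allclose(w_torch.detach(), w_manual))  # True
```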
Step 4: Training the Model
The training function iterates through the dataset for a specified number of epochs, computing the loss and adjusting weights accordingly. Here’s a simple training function:
```python
import os


def train_model(model, criterion, optimizer, num_epochs=5):
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        model.train()  # training-mode behaviour (e.g. batch norm uses batch statistics)
        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in dataloaders['train']:
            optimizer.zero_grad()              # clear gradients from the previous batch
            outputs = model(inputs)            # forward pass
            loss = criterion(outputs, labels)
            loss.backward()                    # backpropagate
            optimizer.step()                   # update the weights

            running_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / dataset_sizes['train']
        epoch_acc = running_corrects.double() / dataset_sizes['train']
        print(f'Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
    return model


model_trained = train_model(model, criterion, optimizer)
os.makedirs('models', exist_ok=True)  # torch.save does not create missing folders
torch.save(model_trained.state_dict(), 'models/transfer_learned_model.pth')
print("Model saved to models/transfer_learned_model.pth")
```
4.1 Function Definition:
```python
def train_model(model, criterion, optimizer, num_epochs=5):
```
This function trains the model for a specified number of epochs (default is 5).
Parameters:
- `model`: The neural network to train.
- `criterion`: The loss function that measures prediction error.
- `optimizer`: The optimizer that adjusts the model’s weights.
- `num_epochs`: The number of epochs, i.e. full passes over the training dataset.
4.2 Training Loop:
```python
for epoch in range(num_epochs):
    print(f'Epoch {epoch}/{num_epochs - 1}')
    model.train()
```
Iterates over each epoch, displaying the current epoch.
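- `model.train()`: Puts the model in training mode. In this mode, batch normalization layers normalize with the current batch’s statistics and update their running averages, and dropout layers (if any) randomly zero activations. ResNet-18 uses batch normalization but no dropout.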
4.3 Batch Processing:
```python
for inputs, labels in dataloaders['train']:
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
```
- `for inputs, labels in dataloaders['train']`: Loops over mini-batches of the training data.
- `optimizer.zero_grad()`: Resets gradients to zero before backpropagation, since PyTorch accumulates them by default.
- `outputs = model(inputs)`: Performs a forward pass to get the model’s predictions.
- `loss = criterion(outputs, labels)`: Calculates the loss by comparing the predictions (`outputs`) with the actual labels (`labels`).
- `loss.backward()`: Computes the gradients of the loss with respect to the model parameters.
- `optimizer.step()`: Updates the model parameters based on the computed gradients.
4.4 Tracking Loss and Accuracy:
```python
running_loss += loss.item() * inputs.size(0)
_, preds = torch.max(outputs, 1)
running_corrects += torch.sum(preds == labels.data)
```
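- `running_loss`: Accumulates the total loss over the epoch. `CrossEntropyLoss` returns the mean loss of a batch, so multiplying by the batch size (`inputs.size(0)`) recovers the batch’s total. Dividing the epoch total by the number of samples then gives a correct per-sample average, even when the last batch is smaller.
- `torch.max(outputs, 1)`: Returns the highest score and its index along the class dimension. The index is the predicted class for each input.
- `running_corrects`: Counts the number of correct predictions.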
4.5 Epoch Loss and Accuracy:
```python
epoch_loss = running_loss / dataset_sizes['train']
epoch_acc = running_corrects.double() / dataset_sizes['train']
```
Computes the average loss and accuracy for the epoch by dividing by the total number of training samples, `dataset_sizes['train']`.
4.6 Model Saving:
```python
model_trained = train_model(model, criterion, optimizer)
torch.save(model_trained.state_dict(), 'models/transfer_learned_model.pth')
```
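`torch.save` writes the trained model’s state dictionary (its parameter and buffer tensors) to a file, so the weights can be reloaded for inference or further training. The `models/` folder must already exist, because `torch.save` does not create it. Reloading requires building the same architecture first, including the replaced `fc` layer.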
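The training loop above always runs on the CPU. The following sketch shows a device-aware variant that trains on a GPU when one is available. `train_model_on_device` is a hypothetical name. Unlike the guide’s function, it takes the `DataLoader` as an argument instead of reading the global `dataloaders` dict, and it returns the per-epoch average losses. The toy `TensorDataset` and linear model exist only to make the example runnable on its own:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_model_on_device(model, criterion, optimizer, loader, num_epochs=5):
    """Sketch: like train_model, but moves each batch to the model's device."""
    device = next(model.parameters()).device  # train wherever the model lives
    history = []
    for epoch in range(num_epochs):
        model.train()
        running_loss, running_corrects, seen = 0.0, 0, 0
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (outputs.argmax(dim=1) == labels).sum().item()
            seen += inputs.size(0)
        history.append(running_loss / seen)
        print(f"Epoch {epoch}: loss {history[-1]:.4f} acc {running_corrects / seen:.4f}")
    return model, history


# Toy usage: move the model first, then build the optimizer from its parameters.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)
x = torch.randn(64, 2)
y = (x[:, 0] > 0).long()  # linearly separable toy labels
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)
toy = nn.Linear(2, 2).to(device)
opt = torch.optim.SGD(toy.parameters(), lr=0.1, momentum=0.9)
toy, history = train_model_on_device(toy, nn.CrossEntropyLoss(), opt, loader, num_epochs=10)
```

If you train on a GPU this way, remember to move test batches to the same device during evaluation, and call `.cpu()` on tensors before converting them with `.numpy()`.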
Step 5: Evaluating the Model
Once trained, it’s crucial to evaluate the model on the test set. This code loads the model, sets it to evaluation mode, and calculates the accuracy and confusion matrix:
```python
from sklearn.metrics import accuracy_score, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# Reload the saved weights and switch to evaluation mode
model.load_state_dict(torch.load('models/transfer_learned_model.pth'))
model.eval()

true_labels = []
predictions = []
with torch.no_grad():  # gradients are not needed for inference
    for inputs, labels in dataloaders['test']:
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        true_labels.extend(labels.numpy())
        predictions.extend(preds.numpy())

accuracy = accuracy_score(true_labels, predictions)
print(f"Test Accuracy: {accuracy:.4f}")

# Rows are true classes, columns are predicted classes
cm = confusion_matrix(true_labels, predictions)
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.show()
```
5.1 Loading and Setting the Model to Evaluation Mode:
```python
model.load_state_dict(torch.load('models/transfer_learned_model.pth'))
model.eval()
```
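- `model.load_state_dict`: Loads the saved weights from the specified path, restoring the parameters learned during training. The model must have the same architecture (including the replaced `fc` layer) as when it was saved.
- `model.eval()`: Switches the model to evaluation mode. Dropout layers are disabled, and batch normalization layers use the running statistics collected during training instead of per-batch statistics. As a result, predictions are deterministic and do not depend on how the test set is batched.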
5.2 Initializing Lists for True Labels and Predictions:
```python
true_labels = []
predictions = []
```
These lists store the true and predicted labels for each sample, which will be used for accuracy calculation and confusion matrix generation.
5.3 Inference Loop:
```python
with torch.no_grad():
    for inputs, labels in dataloaders['test']:
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        true_labels.extend(labels.numpy())
        predictions.extend(preds.numpy())
```
- `torch.no_grad()`: Disables gradient tracking to save memory and computation, since gradients are not needed during evaluation.
- `for inputs, labels in dataloaders['test']`: Loops over the test dataset in batches.
- `outputs = model(inputs)`: Performs a forward pass to get the model’s predictions.
- `_, preds = torch.max(outputs, 1)`: Retrieves the predicted class (the index of the highest score) for each input.
- `true_labels.extend(labels.numpy())` and `predictions.extend(preds.numpy())`: Append the true and predicted labels of each batch to the respective lists.
5.4 Calculating Accuracy:
```python
accuracy = accuracy_score(true_labels, predictions)
print(f"Test Accuracy: {accuracy:.4f}")
```
- `accuracy_score(true_labels, predictions)`: Computes the overall accuracy on the test set, i.e. the fraction of test images whose predicted label matches the true label.
5.5 Confusion Matrix and Visualization:
```python
cm = confusion_matrix(true_labels, predictions)
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.show()
```
- `confusion_matrix(true_labels, predictions)`: Generates a confusion matrix whose entry at row i, column j counts the test images of true class i that were predicted as class j. The diagonal holds the correct predictions.
- `sns.heatmap`: Visualizes the confusion matrix with annotations showing the exact counts (`fmt="d"` formats them as integers). The `"Blues"` color map distinguishes high and low counts, and class names label both axes for easy interpretation.
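Accuracy alone can hide a weak class. The following sketch adds per-class metrics with scikit-learn’s `classification_report` and a row-normalized confusion matrix (`normalize="true"`, available since scikit-learn 0.22). The four toy labels stand in for the `true_labels` and `predictions` lists built above:

```python
from sklearn.metrics import classification_report, confusion_matrix

# Toy labels standing in for true_labels / predictions from the evaluation loop.
class_names = ["cat", "dog"]
true_labels = [0, 0, 1, 1]
predictions = [0, 1, 1, 1]

print(confusion_matrix(true_labels, predictions))
# [[1 1]
#  [0 2]]

# normalize="true" divides each row by its total, so the diagonal shows per-class recall.
print(confusion_matrix(true_labels, predictions, normalize="true"))

# Per-class precision, recall and F1, plus overall accuracy, in one table.
print(classification_report(true_labels, predictions, target_names=class_names))
```

Here overall accuracy is 0.75, but the report shows that only half of the cats were recognized. This is exactly the kind of imbalance a single accuracy number obscures.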
Step 6: Next Steps
Now that you have a trained model, you can:
- Optimize the Model: Experiment with hyperparameters such as the learning rate, batch size, or number of epochs.
- Deploy the Model: Save it in a deployable format (such as TorchScript or ONNX).
- Expand with More Classes or Data: Increase the robustness of your model by adding more data or fine-tuning with specific requirements.
Complete Code
requirements.txt
```text
torch
torchvision
numpy
matplotlib
scikit-learn
seaborn
fastapi
uvicorn
python-multipart
```
main.py
```python
# python -m venv venv
# venv\Scripts\activate
# python.exe -m pip install --upgrade pip
# pip install -r requirements.txt --cache-dir "D:/internship/transfer_learning/.cache"
# pip freeze > requirements.txt
import os

import torch
import torch.nn as nn
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Set up transformations for the training and testing datasets
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Load datasets
data_dir = 'data'
image_datasets = {
    x: datasets.ImageFolder(f"{data_dir}/{x}", data_transforms[x])
    for x in ['train', 'test']
}
dataloaders = {
    x: DataLoader(image_datasets[x], batch_size=32, shuffle=True)
    for x in ['train', 'test']
}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
class_names = image_datasets['train'].classes

# Load a pre-trained model (ImageNet weights; older torchvision: pretrained=True)
# and modify the final layer
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, len(class_names))

# Set up loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)


# Training function
def train_model(model, criterion, optimizer, num_epochs=5):
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        model.train()
        running_loss = 0.0
        running_corrects = 0
        for inputs, labels in dataloaders['train']:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            running_corrects += torch.sum(preds == labels.data)
        epoch_loss = running_loss / dataset_sizes['train']
        epoch_acc = running_corrects.double() / dataset_sizes['train']
        print(f'Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
    return model


# Train and save the model
model_trained = train_model(model, criterion, optimizer)
os.makedirs('models', exist_ok=True)  # torch.save does not create missing folders
torch.save(model_trained.state_dict(), 'models/transfer_learned_model.pth')
print("Model saved to models/transfer_learned_model.pth")

# Load the model and set it to evaluation mode
model.load_state_dict(torch.load('models/transfer_learned_model.pth'))
model.eval()

# Evaluate on the test set
true_labels = []
predictions = []
with torch.no_grad():
    for inputs, labels in dataloaders['test']:
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        true_labels.extend(labels.numpy())
        predictions.extend(preds.numpy())

# Calculate accuracy
accuracy = accuracy_score(true_labels, predictions)
print(f"Test Accuracy: {accuracy:.4f}")

# Confusion matrix
cm = confusion_matrix(true_labels, predictions)
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.show()
```
Deploying a Transfer Learning Model with FastAPI and PyTorch
predict.py
```python
# uvicorn predict:app --reload
# curl -X POST "http://127.0.0.1:8000/predict/" -H "Content-Type: multipart/form-data" -F "file=@path_to_your_image.jpg"
import io

import torch
import torch.nn as nn
from fastapi import FastAPI, File, UploadFile
from PIL import Image
from torchvision import datasets, models, transforms

app = FastAPI()

# Load the class names from the training dataset (same alphabetical order as in training)
data_dir = 'data/train'  # Path to your training data
dataset = datasets.ImageFolder(data_dir)
class_names = dataset.classes  # Get the class names

# Rebuild the architecture; no pre-trained download is needed because
# the fine-tuned weights are loaded right after
model = models.resnet18()
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, len(class_names))  # must match the classes used in training
model.load_state_dict(torch.load("models/transfer_learned_model.pth", map_location="cpu"))
model.eval()

# Same preprocessing as the 'test' transform used during evaluation
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    # convert("RGB") handles grayscale and RGBA uploads, which would otherwise break Normalize
    img = Image.open(io.BytesIO(await file.read())).convert("RGB")
    img = transform(img).unsqueeze(0)  # add a batch dimension: (1, 3, 224, 224)
    with torch.no_grad():
        outputs = model(img)
        _, predicted = torch.max(outputs, 1)
    class_name = class_names[predicted.item()]  # Map index to class name
    return {"prediction": class_name}
```
Explanation of Each Part
- Imports and FastAPI Initialization: Imports the necessary libraries and creates the FastAPI app.
- Load Class Names: Reads the class names from the training data directory using `ImageFolder`, so predicted indices can be mapped to readable class names.
- Model Loading: Builds ResNet-18, replaces the final layer to match the number of classes in your custom dataset, and loads the trained weights from `models/transfer_learned_model.pth`.
- Image Preprocessing: Sets up the same transformations (resize, tensor conversion, normalization) that were used for the test images and that the pre-trained ResNet-18 expects.
- Prediction Endpoint: Receives an image file, applies the transformations, runs the model, and returns the predicted class name as a JSON response.
Making Predictions
To test the API, start the server with `uvicorn predict:app --reload`, then use `curl` to send an image to it:
```shell
curl -X POST "http://127.0.0.1:8000/predict/" -H "Content-Type: multipart/form-data" -F "file=@path_to_your_image.jpg"
```
Replace `path_to_your_image.jpg` with the actual path to an image file.
Expected Response
The server returns a JSON response with the predicted class name, for example:
```json
{
  "prediction": "dog"
}
```
The value is always one of your class folder names (for the cat/dog example, `"cat"` or `"dog"`).
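The endpoint returns only the top class. If you also want a confidence score, one option is to apply softmax to the logits and report the top-k classes with their probabilities. This is an assumption that changes the response format, not part of the original API. `top_k_predictions` is a hypothetical helper:

```python
import torch


def top_k_predictions(logits: torch.Tensor, class_names, k=2):
    """Hypothetical helper: turn one image's logits into (class, probability) pairs."""
    probs = torch.softmax(logits, dim=1)[0]  # logits shape: (1, num_classes)
    k = min(k, probs.numel())
    values, indices = torch.topk(probs, k)   # sorted from most to least likely
    return [(class_names[i], round(v, 4)) for v, i in zip(values.tolist(), indices.tolist())]


logits = torch.tensor([[0.0, 2.0]])
print(top_k_predictions(logits, ["cat", "dog"]))  # [('dog', 0.8808), ('cat', 0.1192)]
```

Inside `predict`, you could return something like `{"prediction": class_name, "top_k": top_k_predictions(outputs, class_names)}`. Keep in mind that softmax probabilities from a fine-tuned network are often overconfident, so treat them as relative scores rather than calibrated probabilities.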
Explanation of Model Fine-Tuning and Deployment Process
- Fine-Tuning: Transfer learning lets us take a model like ResNet-18, pre-trained on the ImageNet dataset, and fine-tune it on a custom dataset. The only structural change is replacing the model’s last layer.
- Deployment:
  - FastAPI provides a simple way to wrap the model in an API endpoint for real-time predictions.
  - The API receives image files, preprocesses them the same way as the test images, and returns predictions based on the model’s training.