Building and Training a Simple GAN Using PyTorch
Generative Adversarial Networks (GANs) are a type of deep learning model consisting of two neural networks—a generator and a discriminator—that compete against each other to produce high-quality, realistic data. They were introduced by Ian Goodfellow and his team in 2014 and have since become widely used in applications requiring synthetic data generation, especially in image synthesis.
Here’s how GANs work:
- Generator: The generator network creates fake data that resembles the real data. It takes random noise as input and generates data (e.g., an image) from this noise. The generator’s goal is to create realistic data that can “fool” the discriminator.
- Discriminator: The discriminator is a binary classifier trained to distinguish between real data (from the actual dataset) and fake data (generated by the generator). It outputs a probability indicating whether a given input is real or fake.
- Adversarial Training: Both networks are trained in a loop. The generator tries to improve its ability to produce realistic data, while the discriminator tries to get better at identifying fake data. As the training progresses, the generator gets better at creating realistic samples, and the discriminator becomes better at spotting fakes—until, ideally, the generator produces data that the discriminator cannot reliably distinguish from real data.
Training Process
The training process involves a minimax game, where:
- The generator aims to maximize the probability of the discriminator making an error.
- The discriminator aims to minimize this probability, accurately differentiating real data from fake.
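The minimax game described above is usually written as a single value function. The block below is the standard form from the original GAN formulation. The symbols are notation introduced here for illustration: D(x) is the Discriminator’s estimated probability that x is real, G(z) is the Generator’s output for noise z, p_data is the real data distribution, and p_z is the noise distribution.

```latex
\min_{G} \max_{D} \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_{z}(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

The Discriminator tries to make V large (confident, correct predictions), while the Generator tries to make it small. In practice, many implementations instead train the Generator to maximize log D(G(z)), which tends to give stronger gradients early in training.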
Applications
GANs have found applications across many fields, including:
- Image Generation: GANs are used to create realistic images of people, animals, and objects that do not exist in reality.
- Style Transfer and Super-Resolution: GANs can change the style of an image or improve the resolution of images.
- Text-to-Image Synthesis: GAN-based models such as StackGAN and AttnGAN generate images from textual descriptions.
- Video and Audio Synthesis: GANs are also applied to generate video frames or audio data, such as speech synthesis or music composition.
- Anomaly Detection: GANs can detect anomalies by learning to generate normal samples and flagging deviations as anomalies.
Advantages and Challenges
- Advantages: GANs can produce high-quality, diverse, and realistic samples, and they can be adapted to many types of data, including images, audio, and text.
- Challenges: GANs can be challenging to train due to issues like mode collapse (where the generator only produces a few similar outputs), and they require careful tuning of hyperparameters. Moreover, GAN-generated content can be used for malicious purposes, such as deepfakes.
Overall, GANs represent a significant advancement in generative modeling, enabling the creation of high-quality synthetic data across various domains.
In this article, we’ll walk through building and training a GAN using PyTorch to generate images similar to those from the MNIST dataset. MNIST consists of 28×28 grayscale images of handwritten digits from 0 to 9, which makes it small enough for quick experimentation.
Prerequisites
Before we start, make sure you have Python installed, along with the necessary libraries such as PyTorch, torchvision, numpy, and matplotlib. Here’s how to set up your environment.
Create a virtual environment:
python -m venv venv
Activate the virtual environment:
- On Windows:
venv\Scripts\activate
- On macOS/Linux:
source venv/bin/activate
Install dependencies (the --cache-dir option is optional, and the path shown is specific to the author’s machine):
python -m pip install --upgrade pip
pip install torch torchvision numpy matplotlib --cache-dir D:/internship/gan/digits/.cache
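Before moving on, it can help to confirm that the packages are visible to the interpreter in the activated environment. The following is a minimal sketch that uses only the standard library; the variable names `required` and `missing` are illustrative, not part of the original setup.

```python
import importlib.util

# Packages listed in the prerequisites above.
required = ["torch", "torchvision", "numpy", "matplotlib"]

# find_spec returns None when a top-level package cannot be located.
missing = [name for name in required if importlib.util.find_spec(name) is None]

if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are installed.")
```

If anything is reported as missing, re-run the pip command inside the activated virtual environment.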
Step 1: Preparing the Dataset
To train a GAN, we need a dataset. We’ll use MNIST, which will be automatically downloaded through torchvision. We will normalize the images to a range of [-1, 1], as this helps in stabilizing GAN training.
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Define transformation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize images to [-1, 1]
])
# Load MNIST dataset
train_dataset = datasets.MNIST(root='./content/drive/MyDrive/Colab Notebooks/gan', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
transforms.Compose([...]): Combines multiple transformations into one pipeline. Here, we apply two:
- transforms.ToTensor(): Converts the images from PIL (Python Imaging Library) format to a PyTorch tensor. It also scales the pixel values from the usual range of 0 to 255 down to a 0.0 to 1.0 range.
- transforms.Normalize((0.5,), (0.5,)): Normalizes the pixel values of the image to the range [-1, 1]. The first parameter (0.5,) specifies the mean to subtract from each pixel, while the second parameter (0.5,) is the standard deviation each pixel is divided by. Normalizing to the range [-1, 1] is common in GANs because it helps stabilize training and makes the data compatible with certain neural network activations, like Tanh.
datasets.MNIST(...): Loads the MNIST dataset from torchvision’s dataset library.
- root='./content/drive/MyDrive/Colab Notebooks/gan': Specifies where to save or load the dataset. If it doesn’t already exist in this location, it will be downloaded.
- train=True: Indicates that we want the training portion of the MNIST dataset (not the test portion).
- transform=transform: Applies the transformations we defined earlier (converting to a tensor and normalizing).
- download=True: Downloads the dataset if it’s not already in the specified directory.
DataLoader: Wraps the dataset in an iterable, providing convenient ways to batch and shuffle the data for training.
- train_dataset: The dataset to load, here set to the training portion of MNIST.
- batch_size=64: Specifies the number of images per batch, allowing the model to process multiple images at once. A batch size of 64 is often a good balance between memory usage and training stability.
- shuffle=True: Shuffles the dataset at each epoch so the model doesn’t see the images in the same order each time, which can improve generalization.
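To make the normalization arithmetic concrete, here is a small sketch in plain Python that applies the same formula as transforms.Normalize, (value - mean) / std, to individual pixel values. The helper names `normalize` and `denormalize` are hypothetical and exist only for this illustration.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Same arithmetic as transforms.Normalize, applied to one value."""
    return (pixel - mean) / std


def denormalize(value, mean=0.5, std=0.5):
    """Inverse mapping, from [-1, 1] back to [0, 1]."""
    return value * std + mean


for raw in (0, 128, 255):
    scaled = raw / 255.0            # the scaling ToTensor performs
    centered = normalize(scaled)    # now in [-1, 1]
    print(raw, round(scaled, 4), round(centered, 4))
```

The darkest pixel (0) maps to -1 and the brightest (255) maps to 1, which is the same range a Tanh output layer produces.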
Step 2: Define the Generator and Discriminator Models
A GAN has two parts:
- Generator – Creates images similar to those in the training data.
- Discriminator – Evaluates the authenticity of images (real or generated).
import torch.nn as nn

# Generator Model
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(True),
            nn.Linear(512, 784),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)  # Reshape to image format
        return img

# Discriminator Model
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)  # Flatten image
        validity = self.model(img_flat)
        return validity
Generator Model
The Generator takes a random noise vector as input and transforms it into a 28×28 image that resembles an MNIST digit.
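The reshaping between flat vectors and image tensors can be checked in isolation. The sketch below uses a random tensor as a stand-in for the Generator’s final 784-value output (it does not run the real model); the batch size of 16 is an arbitrary assumption.

```python
import torch

batch_size = 16  # arbitrary value for illustration

# Stand-in for the Generator's last layer: 784 values per sample in [-1, 1].
flat = torch.tanh(torch.randn(batch_size, 784))

# Same reshape as Generator.forward: (batch, 784) -> (batch, 1, 28, 28).
img = flat.view(flat.size(0), 1, 28, 28)

# Same flatten as Discriminator.forward: (batch, 1, 28, 28) -> (batch, 784).
img_flat = img.view(img.size(0), -1)

print(img.shape, img_flat.shape)
```

Because view only changes how the same values are indexed, flattening the image again recovers the original vector exactly.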
Breakdown of Generator Layers:
- latent_dim: The size of the input noise vector (e.g., 100). This is the “seed” for generating images.
- Linear Layers: Each layer successively transforms the noise into a higher-dimensional feature space (100 → 128 → 256 → 512 → 784), ending in a vector with one value per pixel of a 28×28 image.
- ReLU Activations: Introduce non-linearity, helping the network learn complex patterns.
- Linear vs. Non-Linear Transformations: Without activation functions, each layer in a neural network would just apply a linear transformation (i.e., matrix multiplication plus a bias) to the input. Stacking linear transformations results in a single, equivalent linear transformation. This limits the network’s ability to learn anything complex, because all it could represent is a linear function of its input.
- BatchNorm: Stabilizes training by normalizing the inputs to a layer across each mini-batch of data, hence the name Batch Normalization. This helps the Generator produce more consistent outputs across training batches and reduces the risk of the instability that often plagues GAN training.
- Tanh Activation: The output layer uses Tanh to scale the output to a range of [-1, 1], matching the normalized range of MNIST images. Because generated and real images then share the same pixel range, the Discriminator cannot separate them by range alone and has to judge the image content.
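The point about stacked linear layers collapsing into one can be verified with scalar "layers". This is a minimal sketch in plain Python; the weights and the helper names (`linear`, `stacked`, `collapsed`, `with_relu`) are made up for the illustration.

```python
def linear(w, b):
    """Return a one-dimensional linear layer x -> w * x + b."""
    return lambda x: w * x + b


def relu(x):
    return max(0.0, x)


layer1 = linear(2.0, -1.0)
layer2 = linear(-3.0, 0.5)

# Two linear layers applied back to back...
stacked = lambda x: layer2(layer1(x))

# ...equal one linear layer with w = w2 * w1 and b = w2 * b1 + b2.
collapsed = linear(-3.0 * 2.0, -3.0 * -1.0 + 0.5)

# Putting a ReLU between the layers breaks this equivalence.
with_relu = lambda x: layer2(relu(layer1(x)))

for x in (-2.0, 0.0, 1.0, 3.0):
    print(x, stacked(x), collapsed(x), with_relu(x))
```

The first two output columns always agree, while the ReLU version bends at the point where the first layer’s output crosses zero, which is what lets deeper networks represent non-linear functions.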
Discriminator Model
The Discriminator is a binary classifier that takes an image (real or generated) and outputs a probability indicating whether the image is real or fake.
Breakdown of Discriminator Layers:
- Input Layer: Accepts a 784-dimensional vector (flattened 28×28 image).
- LeakyReLU Activations: Allows a small, non-zero gradient when the input is negative, which helps prevent dying neurons.
- In a standard ReLU activation, any negative input is set to zero, which results in a gradient of zero for that neuron. If a neuron consistently receives negative inputs, it can “die” because it stops learning entirely (its gradient is zero).
- LeakyReLU addresses this by multiplying negative inputs by a small slope instead of setting them to zero (PyTorch’s default slope is 0.01; the code above uses 0.2). This small slope keeps neurons “alive” even if they receive mostly negative values, ensuring they can still contribute to learning.
- Sigmoid Activation: The output layer uses Sigmoid to squash its input to a value between 0 and 1, representing the probability that the input image is real. This suits binary classification tasks like the Discriminator’s, which must classify images as either “real” or “fake”. ReLU, by contrast, has an unbounded output and is better suited to hidden layers than to an output layer that must produce a probability.
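The activation behaviour described above can be reproduced with a few lines of plain Python. This is an illustrative sketch, not PyTorch’s implementation; the slope of 0.2 matches the value passed to nn.LeakyReLU in the model code, and the function names are hypothetical.

```python
import math


def relu_grad(x):
    """Gradient of ReLU: zero for every negative input."""
    return 1.0 if x > 0 else 0.0


def leaky_relu(x, slope=0.2):
    """Pass positive inputs through; scale negative inputs by `slope`."""
    return x if x > 0 else slope * x


def leaky_relu_grad(x, slope=0.2):
    """Gradient of LeakyReLU: never exactly zero."""
    return 1.0 if x > 0 else slope


def sigmoid(x):
    """Squash any real number into the open interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


for x in (-5.0, -0.5, 0.5, 5.0):
    print(x, relu_grad(x), leaky_relu_grad(x), round(sigmoid(x), 4))
```

For negative inputs the ReLU gradient is 0 while the LeakyReLU gradient stays at 0.2, and the sigmoid column always stays strictly between 0 and 1.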
Step 3: Training the GAN
The training process alternates between updating the Generator and the Discriminator to balance the GAN. Here’s the full training loop.
import torch.optim as optim

# Hyperparameters
latent_dim = 100
num_epochs = 200
learning_rate = 0.0002
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize models and optimizers
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(train_loader):
        # Move images to device
        imgs = imgs.to(device)

        # Ground truths for real and fake images
        valid = torch.ones(imgs.size(0), 1, device=device)
        fake = torch.zeros(imgs.size(0), 1, device=device)

        # Train Generator
        optimizer_G.zero_grad()
        z = torch.randn(imgs.size(0), latent_dim, device=device)
        gen_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # Train Discriminator
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

    # Report the losses from the last batch of the epoch
    print(f"Epoch [{epoch+1}/{num_epochs}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")
- latent_dim is the length of the random vector input to the Generator.
- num_epochs specifies how many complete passes are made through the training data.
- learning_rate is the step size used by the optimizer for updating weights.
- device ensures the training can be performed on a GPU if available, for faster computation.
The Generator and Discriminator are initialized and moved to the specified device. Binary Cross-Entropy (BCE) loss is used to measure the discrepancy between predicted and actual labels, helping the models learn. The Adam optimizer, with the given learning rate, is used for both the Generator and the Discriminator.
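To see what the BCE loss rewards, the sketch below computes the binary cross-entropy formula for a single prediction in plain Python. The function name `bce` and the clamping constant `eps` are assumptions made for this illustration; nn.BCELoss applies the same formula per element and, by default, averages over the batch.

```python
import math


def bce(p, target, eps=1e-12):
    """Binary cross-entropy for one predicted probability and a 0/1 target."""
    p = min(max(p, eps), 1.0 - eps)  # keep log() away from 0
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))


# Discriminator output p on a generated image, scored two ways:
for p in (0.1, 0.5, 0.9):
    g_view = bce(p, 1)  # Generator's loss: wants p close to 1
    d_view = bce(p, 0)  # Discriminator's fake loss: wants p close to 0
    print(p, round(g_view, 4), round(d_view, 4))
```

The same Discriminator output is scored in opposite directions by the two networks, which is the adversarial part of the setup.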
Training Loop: The main training loop runs for num_epochs epochs, and each epoch iterates through all images in the train_loader batch by batch.
valid and fake are ground truth labels used to calculate the BCE loss for real and fake images respectively. Real images are labeled as 1, while fake images are labeled as 0.
Train the Generator:
- Generate a batch of random noise vectors z and use the Generator to create images (gen_imgs).
- Pass the generated images through the Discriminator and calculate the loss (g_loss) as if they were real (valid label). The goal is to “fool” the Discriminator into thinking generated images are real.
- Backpropagate and update the Generator’s weights to minimize g_loss.
Train the Discriminator:
- Calculate real_loss by passing actual images through the Discriminator and comparing the output to valid labels.
- Calculate fake_loss by passing the generated images (gen_imgs) through the Discriminator with the label fake. The images are detached (gen_imgs.detach()) so that this step does not send gradients back into the Generator.
- Average the losses to get d_loss, then backpropagate and update the Discriminator’s weights to better distinguish real from fake images.
After each epoch, print the Discriminator (d_loss) and Generator (g_loss) losses from the last batch, providing insight into how well each model is learning over time.
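The role of detach() in the Discriminator step can be demonstrated with tiny stand-in networks. The sketch below is not the article’s model: `gen` and `disc` are small hypothetical layers chosen so the example runs quickly on CPU.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-ins for the real networks (hypothetical sizes).
gen = nn.Linear(4, 3)
disc = nn.Sequential(nn.Linear(3, 1), nn.Sigmoid())
loss_fn = nn.BCELoss()

z = torch.randn(8, 4)
fake_imgs = gen(z)
fake_labels = torch.zeros(8, 1)

# Discriminator step on generated samples, with the graph cut at detach().
d_loss = loss_fn(disc(fake_imgs.detach()), fake_labels)
d_loss.backward()

# Gradients reach the discriminator but not the generator.
gen_has_grad = gen.weight.grad is not None
disc_has_grad = disc[0].weight.grad is not None
print(gen_has_grad, disc_has_grad)
```

Without detach(), the backward pass would also compute gradients for the generator’s weights during the Discriminator update, which is wasted work and not part of that step’s objective.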
Step 4: Generate and Save Images
Once training is complete, we can generate new images with the trained Generator. Below is a helper function to create a grid of generated images and save them as a single image file.
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

def save_generated_images(generator, latent_dim, num_images=25, output_path="/content/drive/MyDrive/Colab Notebooks/gan/generated_images.png"):
    generator.eval()
    with torch.no_grad():
        z = torch.randn(num_images, latent_dim, device=device)
        gen_imgs = generator(z)
        gen_imgs = gen_imgs * 0.5 + 0.5  # Rescale images to [0, 1]
        gen_imgs = gen_imgs.cpu().numpy()

    # Create a 5x5 grid of generated images
    img_grid = np.zeros((5 * 28, 5 * 28))
    for i in range(5):
        for j in range(5):
            img_grid[i*28:(i+1)*28, j*28:(j+1)*28] = gen_imgs[i * 5 + j, 0]

    # Save the grid as an image
    img = Image.fromarray((img_grid * 255).astype(np.uint8))
    img.save(output_path)
    print(f"Generated images saved to {output_path}")

save_generated_images(generator, latent_dim)
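The grid indexing and pixel rescaling used by the helper can be checked without any image data. This is a plain-Python sketch; `tile_bounds` and `to_uint8` are hypothetical helper names that mirror the slicing and the `* 0.5 + 0.5` then `* 255` arithmetic in the function above.

```python
def tile_bounds(index, grid_cols=5, tile=28):
    """Return (row_start, row_end, col_start, col_end) for image `index`."""
    row, col = divmod(index, grid_cols)
    return (row * tile, (row + 1) * tile, col * tile, (col + 1) * tile)


def to_uint8(value):
    """Map a Tanh output in [-1, 1] to an 8-bit pixel value."""
    rescaled = value * 0.5 + 0.5   # back to [0, 1]
    return int(rescaled * 255)     # truncates, as astype(np.uint8) does here


for index in (0, 7, 24):
    print(index, tile_bounds(index))

for value in (-1.0, 0.0, 1.0):
    print(value, to_uint8(value))
```

Image number i * 5 + j lands in row i and column j of the grid, so the 25 generated samples fill the 140×140 canvas exactly.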
Complete Code
# python -m venv venv
# venv/Scripts/activate
# python.exe -m pip install --upgrade pip
# pip install torch torchvision numpy matplotlib --cache-dir D:/internship/gan/digits/.cache
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
# Hyperparameters
latent_dim = 100
num_epochs = 200
batch_size = 64
learning_rate = 0.0002
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Step 1: Prepare Dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize images to [-1, 1]
])
train_dataset = datasets.MNIST(root='./content/drive/MyDrive/Colab Notebooks/gan', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Step 2: Define the Generator and Discriminator
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(True),
            nn.Linear(512, 784),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)  # Reshape to image format
        return img

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)  # Flatten image
        validity = self.model(img_flat)
        return validity
# Initialize models
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
# Loss function and optimizers
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)
# Step 3: Train the GAN
for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(train_loader):
        # Move images to device
        imgs = imgs.to(device)

        # Ground truths
        valid = torch.ones(imgs.size(0), 1, device=device)
        fake = torch.zeros(imgs.size(0), 1, device=device)

        # Train Generator
        optimizer_G.zero_grad()
        z = torch.randn(imgs.size(0), latent_dim, device=device)
        gen_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # Train Discriminator
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

    # Report the losses from the last batch of the epoch
    print(f"Epoch [{epoch+1}/{num_epochs}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")
# Step 4: Generate and Save Images
def save_generated_images(generator, latent_dim, num_images=25, output_path="/content/drive/MyDrive/Colab Notebooks/gan/generated_images.png"):
    generator.eval()
    with torch.no_grad():
        z = torch.randn(num_images, latent_dim, device=device)
        gen_imgs = generator(z)
        gen_imgs = gen_imgs * 0.5 + 0.5  # Rescale images to [0, 1]
        gen_imgs = gen_imgs.cpu().numpy()

    # Create a 5x5 grid of generated images
    img_grid = np.zeros((5 * 28, 5 * 28))
    for i in range(5):
        for j in range(5):
            img_grid[i*28:(i+1)*28, j*28:(j+1)*28] = gen_imgs[i * 5 + j, 0]

    # Save the grid as an image
    img = Image.fromarray((img_grid * 255).astype(np.uint8))
    img.save(output_path)
    print(f"Generated images saved to {output_path}")

# Call function to save images
save_generated_images(generator, latent_dim)
Summary
In this article, we explored building and training a simple GAN using PyTorch. The Generator learns to produce realistic images by tricking the Discriminator, which is trained to distinguish between real and generated images. Over time, this adversarial process allows the Generator to produce images that increasingly resemble the handwritten digits in the MNIST dataset.