Transfer Learning in PyTorch
A comprehensive introduction to PyTorch for deep learning.
Transfer Learning with PyTorch: Enhancing Your Classification Model
What is Transfer Learning?
Transfer learning allows us to leverage pre-trained models, using the patterns and knowledge they have acquired from large datasets to solve new, but related, problems. This technique is highly valuable in scenarios where acquiring a massive dataset or training a model from scratch is impractical.
Why Use Transfer Learning?
- Leverage Existing Models: Use neural network architectures that are already proven to work on similar problems.
- Efficiency: Achieve better results than training from scratch, with less data and fewer computational resources, by building on models that have already learned useful patterns.
Setting Up Transfer Learning with PyTorch
Let’s dive into the practical steps to set up a transfer learning workflow using PyTorch, focusing on an image classification task with FoodVision Mini (pizza, steak, sushi).
0. Setup
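Import the required libraries, installing torchinfo if it is missing, and download the going_modular helper scripts (data_setup and engine) if they are not already available. The lines starting with ! are shell commands, so this cell needs a notebook environment such as Jupyter or Google Colab.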
# Imports and setup
import torch
import torchvision
from torch import nn
from torchvision import transforms
try:
    from torchinfo import summary
except ImportError:
    !pip install torchinfo
    from torchinfo import summary
# Download going_modular scripts
try:
    from going_modular.going_modular import data_setup, engine
except ImportError:
    !git clone https://github.com/mrdbourke/pytorch-deep-learning
    !mv pytorch-deep-learning/going_modular .
    !rm -rf pytorch-deep-learning
    from going_modular.going_modular import data_setup, engine
# Setup device
device = "cuda" if torch.cuda.is_available() else "cpu"
1. Get Data
Download and prepare the dataset.
import os
import zipfile
from pathlib import Path
import requests
# Setup data path
data_path = Path("data/")
image_path = data_path / "pizza_steak_sushi"
if not image_path.is_dir():
    image_path.mkdir(parents=True, exist_ok=True)
    with open(data_path / "pizza_steak_sushi.zip", "wb") as f:
        request = requests.get("https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip")
        f.write(request.content)
    with zipfile.ZipFile(data_path / "pizza_steak_sushi.zip", "r") as zip_ref:
        zip_ref.extractall(image_path)
    os.remove(data_path / "pizza_steak_sushi.zip")
train_dir = image_path / "train"
test_dir = image_path / "test"
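After the download, each split is expected to contain one subfolder per class. A quick sanity check, sketched below, counts the images in each class folder. The helper name count_images_per_class is hypothetical, and the sketch assumes the images are .jpg files laid out as split/class_name/*.jpg.

```python
from pathlib import Path


def count_images_per_class(split_dir: Path) -> dict[str, int]:
    """Return {class_name: number of .jpg files} for split_dir/class_name/*.jpg.

    Hypothetical helper for illustration; assumes one subfolder per class.
    """
    return {
        class_dir.name: len(list(class_dir.glob("*.jpg")))
        for class_dir in sorted(split_dir.iterdir())
        if class_dir.is_dir()
    }


# Usage: report counts for each split, skipping splits that do not exist yet.
for split in ("train", "test"):
    split_dir = Path("data/pizza_steak_sushi") / split
    if split_dir.is_dir():
        print(split, count_images_per_class(split_dir))
```

If a class folder is missing or reports zero images, the extraction step probably did not complete and is worth rerunning.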
2. Create Datasets and DataLoaders
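A pretrained model expects inputs prepared the same way as the data it was trained on. Here that means resizing images to 224x224, converting them to tensors, and normalizing each channel with the ImageNet mean and standard deviation before building the data loaders.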
# Define manual transforms
manual_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Create DataLoaders
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
    train_dir=train_dir,
    test_dir=test_dir,
    transform=manual_transforms,
    batch_size=32
)
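To make the Normalize step concrete, here is a small pure-Python sketch of the arithmetic it applies to each channel: (value - mean) / std. It assumes pixel values have already been scaled to the range [0, 1], as ToTensor does for standard 8-bit images. The function name normalize_pixel is hypothetical and only serves the illustration.

```python
# ImageNet channel statistics, the same values passed to transforms.Normalize.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Apply (value - mean) / std to each channel of one RGB pixel in [0, 1].

    Hypothetical helper for illustration only.
    """
    return [(v - m) / s for v, m, s in zip(rgb, mean, std)]


black = normalize_pixel([0.0, 0.0, 0.0])
white = normalize_pixel([1.0, 1.0, 1.0])
print("black:", [round(v, 3) for v in black])
print("white:", [round(v, 3) for v in white])
```

A pixel equal to the channel mean maps to zero, so normalized inputs are roughly centered around zero, which is the input range the pretrained weights were fitted to.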
3. Load a Pretrained Model
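Load EfficientNet-B0 with its default pretrained weights and move it to the target device. The summary then lists each layer's input and output shapes, its parameter count, and whether it is trainable.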
# Load pretrained EfficientNet_B0
weights = torchvision.models.EfficientNet_B0_Weights.DEFAULT
model = torchvision.models.efficientnet_b0(weights=weights).to(device)
# Print model summary
summary(model, input_size=(32, 3, 224, 224), col_names=["input_size", "output_size", "num_params", "trainable"], col_width=20, row_settings=["var_names"])
4. Customize the Model
Freeze the base layers and adjust the output layer for our specific task.
# Freeze base layers
for param in model.features.parameters():
    param.requires_grad = False
# Modify the classifier head
model.classifier[1] = nn.Linear(in_features=model.classifier[1].in_features, out_features=len(class_names)).to(device)
# Verify model changes
summary(model, input_size=(32, 3, 224, 224), col_names=["input_size", "output_size", "num_params", "trainable"], col_width=20, row_settings=["var_names"])
5. Train the Model
Use a suitable optimizer and loss function, then train the model. Because the base layers are frozen, only the classifier's parameters are passed to the optimizer.
# Set up loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=1e-3)
# Train the model
engine.train(model=model, train_dataloader=train_dataloader, test_dataloader=test_dataloader, optimizer=optimizer, loss_fn=loss_fn, epochs=5, device=device)
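The internals of engine.train are not shown here, so the following is only a sketch of what a single optimization step does to a model with a frozen base. It uses a hypothetical stand-in network, TinyNet, with random inputs and labels in place of EfficientNet-B0 and the food images. The point is to check that the step changes the head while leaving the frozen layers untouched.

```python
import torch
from torch import nn

torch.manual_seed(0)


class TinyNet(nn.Module):
    """Hypothetical stand-in: 'features' is the frozen base, 'classifier' the new head."""

    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.classifier = nn.Sequential(nn.Dropout(p=0.2), nn.Linear(16, num_classes))

    def forward(self, x):
        return self.classifier(self.features(x))


model = TinyNet()
for param in model.features.parameters():
    param.requires_grad = False  # freeze the base

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=1e-3)

# Snapshot weights so we can compare them after the update.
base_before = model.features[0].weight.clone()
head_before = model.classifier[1].weight.clone()

# Assumption: random data standing in for one batch of images and labels.
X = torch.randn(4, 8)
y = torch.tensor([0, 1, 2, 0])

model.train()
loss = loss_fn(model(X), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

base_unchanged = torch.equal(base_before, model.features[0].weight)
head_changed = not torch.equal(head_before, model.classifier[1].weight)
print(f"base unchanged: {base_unchanged} | head changed: {head_changed}")
```

Passing only the classifier's parameters to the optimizer is not strictly required once requires_grad is False on the base, but it makes the intent explicit.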
Conclusion
By following these steps, you can use transfer learning to adapt a pretrained model to a custom task with relatively little data and training time. Freezing the base layers and training only a new classifier head saves compute while reusing the features the model learned from a much larger dataset.
Last updated 17 Aug 2024, 12:31 +0200 .