
How do you use Python's torchvision efficiently?

Of course! Here is a comprehensive guide to torchvision, a powerful and essential library in the PyTorch ecosystem.


What is torchvision?

torchvision is the computer-vision library of the PyTorch project. Think of it as the "Swiss Army knife" for vision tasks in PyTorch: it provides ready-to-use loaders for popular datasets, pre-trained model architectures, and common image transformations.

Its main purpose is to make it incredibly easy to load data, preprocess images, and use state-of-the-art pre-trained models.

The library is imported as torchvision; its four most commonly used submodules are:

  1. torchvision.datasets: A collection of popular vision datasets (e.g., MNIST, CIFAR-10, ImageNet, COCO).
  2. torchvision.models: A collection of pre-trained model architectures (e.g., ResNet, VGG, AlexNet, MobileNet).
  3. torchvision.transforms: Common image transformations (e.g., resizing, cropping, flipping, converting to tensors).
  4. torchvision.utils: Useful utility functions (e.g., making a grid of images for visualization).

torchvision.datasets: Loading Data Easily

This submodule is a huge time-saver. Instead of writing custom data loaders from scratch, you can load standard datasets with just a few lines of code.

Example: Loading the CIFAR-10 Dataset

CIFAR-10 is a dataset of 60,000 32x32 color images in 10 classes.

import torch
from torchvision import datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
# Define a series of transformations to apply to the images
# 1. ToTensor() converts a PIL Image or numpy array to a PyTorch tensor
# 2. Normalize() normalizes the tensor with a mean and standard deviation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Mean and std for each channel
])
# Download and load the training data
train_set = datasets.CIFAR10(root='./data', train=True,
                              download=True, transform=transform)
# Download and load the test data
test_set = datasets.CIFAR10(root='./data', train=False,
                             download=True, transform=transform)
# Create data loaders to iterate over the data
# DataLoader wraps an iterable over a dataset and supports automatic batching
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False)
# Example of iterating through the data
for images, labels in train_loader:
    print(f"Batch shape: {images.shape}") # Should be [batch_size, channels, height, width]
    print(f"Labels shape: {labels.shape}")
    break # Just print the first batch
# The 10 classes in CIFAR-10
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

Key dataset parameters:

  • root: The directory where the dataset will be stored.
  • train: True for the training set, False for the test set.
  • download: True to download the data if it's not already present.
  • transform: A function/sequence of functions to apply to the data.
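The same parameters carry over to the other dataset classes. For your own images arranged in one folder per class, datasets.ImageFolder follows the same pattern; the sketch below assumes a hypothetical ./my_data directory containing subfolders such as cats/ and dogs/.

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Hypothetical layout: ./my_data/cats/xxx.jpg, ./my_data/dogs/yyy.jpg, ...
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
custom_set = datasets.ImageFolder(root='./my_data', transform=transform)
custom_loader = DataLoader(custom_set, batch_size=32, shuffle=True)
print(custom_set.classes) # Class names are inferred from the folder names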

torchvision.models: Using Pre-trained Models

This is one of the most powerful features. You can load models that have already been trained on massive datasets like ImageNet. This is useful for:

  • Transfer Learning: Using a pre-trained model as a feature extractor or fine-tuning it for your own specific task.
  • Feature Extraction: Getting high-quality features from images without training a model from scratch.

Example: Loading a Pre-trained ResNet-18

import torch.nn as nn
import torchvision.models as models
# Load a ResNet-18 pre-trained on ImageNet
# (weights=... is the current API; the older pretrained=True argument is deprecated)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
# The model is now loaded with weights trained on ImageNet
print(model)
# If you want to use it for a different number of classes (e.g., 10 for CIFAR-10)
# You need to replace the final layer
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10) # Replace the final fully connected layer with a 10-class head
# Now the model is ready to be fine-tuned on your 10-class dataset
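
For the feature-extraction use case mentioned above, a common approach is to freeze the pre-trained backbone and train only the new head. A minimal sketch, assuming the same 10-class setup:

import torch
import torch.nn as nn
import torchvision.models as models
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
# Freeze every pre-trained parameter so it is not updated during training
for param in model.parameters():
    param.requires_grad = False
# The newly created head has requires_grad=True by default
model.fc = nn.Linear(model.fc.in_features, 10)
# Only the head's parameters are handed to the optimizer
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)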

Commonly available models:

  • resnet18, resnet34, resnet50, ...
  • alexnet
  • vgg16, vgg19
  • densenet121
  • inception_v3
  • mobilenet_v2, mobilenet_v3_small, mobilenet_v3_large
  • shufflenet_v2
  • googlenet
  • efficientnet_b0, efficientnet_b1, ...
  • vit_b_16 (Vision Transformer)
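
On recent torchvision releases (roughly 0.14 and later), you can also discover and load these architectures by name instead of calling each constructor directly:

from torchvision import models
# List the available classification models (torchvision >= 0.14)
print(models.list_models(module=models)[:10])
# Load any of them by name together with its default pre-trained weights
model = models.get_model("mobilenet_v3_small", weights="DEFAULT")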

torchvision.transforms: Preprocessing Images

Before feeding images to a neural network, you almost always need to preprocess them. The transforms module provides a wide range of common operations.

The Compose Transform

You often want to chain multiple transformations together. transforms.Compose does exactly that.

import torchvision.transforms as transforms
# Define a more complex transformation pipeline
# 1. Resize the image to 256x256
# 2. Crop the central 224x224 pixels (common for models like ResNet)
# 3. Randomly flip the image horizontally (for data augmentation)
# 4. Convert the image to a PyTorch Tensor
# 5. Normalize the tensor using ImageNet mean and std
transform_pipeline = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# You would then pass this pipeline to your dataset
# train_set = datasets.ImageFolder(root='path/to/data', transform=transform_pipeline)

Common Transforms:

  • Geometric: Resize, CenterCrop, RandomCrop, RandomResizedCrop, Pad, RandomHorizontalFlip, RandomVerticalFlip, RandomRotation.
  • Color/Intensity: ColorJitter (change brightness, contrast, saturation, hue), Grayscale, RandomGrayscale.
  • Conversion: ToTensor, PILToTensor, Lambda.
  • Normalization: Normalize.
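
When working with a pre-trained model, recent torchvision versions also bundle the matching preprocessing pipeline with the weights, so you do not have to remember the ImageNet mean and std yourself:

from torchvision import models
weights = models.ResNet18_Weights.DEFAULT
preprocess = weights.transforms() # Resize, crop, tensor conversion and normalization matching the weights
# Use it exactly like a hand-written Compose pipeline
# val_set = datasets.ImageFolder(root='path/to/data', transform=preprocess)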

torchvision.utils: Helper Functions

This submodule contains useful utilities, with make_grid being the most common for visualization.

Example: Visualizing a Batch of Images

import torchvision.utils as vutils
import matplotlib.pyplot as plt
import numpy as np
# Grab one batch of images from the train_loader defined earlier
# images shape: [batch_size, C, H, W]
images, _ = next(iter(train_loader))
# Create a grid of images
# nrow sets the number of images per row in the grid
grid = vutils.make_grid(images, nrow=8, normalize=True)
# Convert the grid to a numpy array for plotting
# We need to move the channel dimension to the last (H, W, C)
np_grid = grid.numpy().transpose((1, 2, 0))
# Plot the grid
plt.figure(figsize=(12, 12))
plt.imshow(np_grid)
plt.title("A Batch of Images from CIFAR-10")
plt.axis('off') # Hide the axes
plt.show()
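
torchvision.utils also provides save_image, which writes a batch or grid straight to disk. A minimal sketch reusing the same images batch (the file name is just an example):

import torchvision.utils as vutils
# Save the batch as a single grid image on disk
vutils.save_image(images, 'batch_preview.png', nrow=8, normalize=True)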

Putting It All Together: A Simple Training Loop

Here is a complete, minimal example of how these pieces fit together to train a model on CIFAR-10.

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
# 1. Define Transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 2. Load Datasets
train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)
# 3. Load a Pre-trained Model and Modify It
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10) # CIFAR-10 has 10 classes
model = model.to(device)
# 4. Define Loss Function and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# 5. Training Loop
num_epochs = 5
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        # Zero the parameter gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")
print("Finished Training")
# 6. Evaluation (Example)
model.eval() # Switch to evaluation mode (affects BatchNorm and Dropout layers)
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f} %')
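
After training you will usually want to persist the weights. The standard PyTorch pattern is to save the model's state_dict and load it back into a freshly built architecture (the file name below is just an example):

# Save the fine-tuned weights
torch.save(model.state_dict(), 'resnet18_cifar10.pth')
# Later: rebuild the same architecture and load the weights back
model2 = models.resnet18(weights=None)
model2.fc = nn.Linear(model2.fc.in_features, 10)
model2.load_state_dict(torch.load('resnet18_cifar10.pth', map_location=device))
model2.eval()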

Summary: Why is torchvision so important?

  • Saves Time: No need to write boilerplate code for data loading or model definition.
  • Standardization: Provides a consistent and standard way to handle common vision tasks.
  • Best Practices: The datasets, models, and transforms are built and maintained by the PyTorch team, ensuring they are efficient and follow best practices.
  • Foundation for Research: It provides the building blocks necessary to quickly prototype and implement new computer vision research ideas.