杰瑞科技汇

Dataloader如何高效批量加载数据?

DataLoader 是 PyTorch 中一个非常重要且核心的工具,它就像一个数据加载的“管家”,负责高效、便捷地将你的数据喂给神经网络模型进行训练。

Dataloader如何高效批量加载数据?-图1
(图片来源网络,侵删)

DataLoader 是什么?为什么需要它?

想象一下,你要训练一个模型,你有一个包含 10 万张图片的数据集,如果一次性把这 10 万张图片都加载到内存里,电脑很可能会直接“爆炸”,即使内存够用,在训练时,你也需要一个一个或一小批一小批地取出数据,进行前向传播、计算损失、反向传播和更新参数。

手动实现这个过程会非常繁琐且容易出错。DataLoader 就是为了解决这个问题而生的,它的主要作用是:

  1. 批量处理:将数据集分成一个个小批次,模型每次处理一个批次,而不是单个样本,这能充分利用 GPU 的并行计算能力,大大提高训练效率。
  2. 数据打乱:在每个 epoch(训练轮次)开始时,可以随机打乱数据的顺序,这对于防止模型学习到数据中的特定顺序(比如总是先看到猫再看到狗)导致的过拟合至关重要。
  3. 多进程加载:使用多个子进程来并行加载数据,从而将数据加载和模型计算在 CPU 和 GPU 上同时进行,减少 GPU 的等待时间,提升整体训练速度。
  4. 便捷的数据处理:可以方便地与 Dataset 类配合,Dataset 负责存储数据和索引数据,而 DataLoader 负责按需批量地取出数据。

Dataset 定义了“你的数据在哪里以及如何获取一个样本”,而 DataLoader 定义了“如何从 Dataset 中高效地、批量地、可配置地获取数据”


DataLoader 的核心组件

要使用 DataLoader,你通常需要两个核心组件:

Dataloader如何高效批量加载数据?-图2
(图片来源网络,侵删)

a. Dataset

Dataset 是一个抽象类,你需要继承它并实现两个方法:

  • __len__(): 返回数据集中样本的总数。
  • getitem__(index): 根给定的索引 index,返回一个样本(一个图像张量及其对应的标签)。

示例:创建一个自定义的 Dataset

import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
# 假设我们的数据集结构如下:
# dataset/
#   ├── cat/
#   │   ├── cat1.jpg
#   │   └── cat2.jpg
#   └── dog/
#       ├── dog1.jpg
#       └── dog2.jpg
class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        """
        初始化函数
        :param annotations_file: 标签文件路径 (这里我们简化,用文件夹名作为标签)
        :param img_dir: 图片所在的根目录
        :param transform: 对图片应用的变换
        :param target_transform: 对标签应用的变换
        """
        self.img_labels = [] # 存储图片路径和标签
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        # 遍历文件夹,创建图片路径和标签的列表
        # 这里我们假设文件夹名就是类别名
        for label in os.listdir(img_dir):
            class_dir = os.path.join(img_dir, label)
            if os.path.isdir(class_dir):
                for img_name in os.listdir(class_dir):
                    self.img_labels.append((os.path.join(class_dir, img_name), label))
    def __len__(self):
        """返回数据集的大小"""
        return len(self.img_labels)
    def __getitem__(self, idx):
        """
        根据索引获取一个样本
        :param idx: 索引
        :return: 一个样本 (图片, 标签)
        """
        img_path, label = self.img_labels[idx]
        image = Image.open(img_path).convert("RGB") # 用PIL打开图片
        # 将标签转换为数字 (e.g., 'cat' -> 0, 'dog' -> 1)
        label_map = {'cat': 0, 'dog': 1}
        label = label_map[label]
        # 应用变换
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
# 定义一些变换,比如将图片转换为Tensor
from torchvision import transforms
transform = transforms.Compose([
    transforms.Resize((64, 64)), # 调整大小
    transforms.ToTensor()       # 转换为Tensor并归一化到[0, 1]
])
# 创建Dataset实例
dataset = CustomImageDataset(annotations_file=None, img_dir='dataset', transform=transform)
print(f"Dataset size: {len(dataset)}")
# 获取第一个样本
sample, label = dataset[0]
print(f"Sample shape: {sample.shape}, Label: {label}")

b. DataLoader

现在我们有了 Dataset,就可以用它来创建 DataLoader 了。

DataLoader 的常用参数:

Dataloader如何高效批量加载数据?-图3
(图片来源网络,侵删)
  • dataset: (必需) 要加载的 Dataset 对象。
  • batch_size: (默认为1) 每个批次包含的样本数量。
  • shuffle: (默认为False) 是否在每个 epoch 开始时打乱数据顺序。在训练集上通常设为 True,在验证集和测试集上设为 False
  • num_workers: (默认为0) 数据加载时使用多少个子进程,0 表示主进程加载,大于 0 可以实现多进程并行加载,对于大型数据集能显著提升速度
  • collate_fn: (可选) 一个函数,用于将一个批次的数据样本组合成一个批次,当你批次中的样本大小不一(NLP 中的句子长度不同)时,这个函数非常有用,默认情况下,它会自动处理相同形状的样本。
  • drop_last: (默认为False) 如果数据集大小不能被 batch_size 整除,是否丢弃最后一个不完整的批次。

如何使用 DataLoader (完整示例)

我们接着上面的 CustomImageDataset 例子,来看如何使用 DataLoader 进行训练。

# 1. 准备数据集
# (使用上面定义的 CustomImageDataset 和 transform)
dataset = CustomImageDataset(img_dir='dataset', transform=transform)
# 2. 创建 DataLoader
# batch_size=4: 每次取4张图片
# shuffle=True: 训练时打乱数据
# num_workers=2: 使用2个子进程加载数据
# 注意: num_workers > 0 时,最好在 if __name__ == '__main__': 的代码块中运行,避免Windows下多进程问题
train_loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
# 3. 在训练循环中使用 DataLoader
# 假设我们有一个简单的模型
model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(64*64*3, 64), # 输入是64x64x3的图片
    torch.nn.ReLU(),
    torch.nn.Linear(64, 2)       # 输出是2个类别 (cat, dog)
)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.S.Adam(model.parameters(), lr=0.001)
print("\nStarting training...")
# 模拟训练一个epoch
num_epochs = 1
for epoch in range(num_epochs):
    print(f"--- Epoch {epoch+1}/{num_epochs} ---")
    # enumerate(train_loader) 会返回两个值:
    # 1. batch_idx: 当前批次的索引
    # 2. (images, labels): 当前批次的数据和标签
    for batch_idx, (images, labels) in enumerate(train_loader):
        # 1. 前向传播
        # images 的形状是 [batch_size, channels, height, width], e.g., [4, 3, 64, 64]
        # labels 的形状是 [batch_size], e.g., [4]
        outputs = model(images)
        # 2. 计算损失
        loss = loss_fn(outputs, labels)
        # 3. 反向传播和优化
        optimizer.zero_grad() # 清空过往梯度
        loss.backward()       # 计算当前梯度
        optimizer.step()      # 更新权重
        if (batch_idx + 1) % 2 == 0:
            print(f"Batch {batch_idx+1}/{len(train_loader)}, Loss: {loss.item():.4f}")
print("Training finished.")

输出示例:

Dataset size: 4
Sample shape: torch.Size([3, 64, 64]), Label: 0
Starting training...
--- Epoch 1/1 ---
Batch 2/1, Loss: 0.6931
Training finished.

(注意:如果你的数据集很小,len(train_loader) 可能等于 1,所以这里只打印一次)


collate_fn 详解

collate_fnDataLoader 的一个高级但非常有用的功能,它的作用是定义如何将一个列表的样本组合成一个批次

默认情况下,DataLoader 会假设你的所有样本(每个样本的图像张量和标签张量)都具有相同的形状,然后将它们简单地堆叠起来(使用 torch.stack)。

但在某些情况下,默认行为不适用:

  • NLP 任务:句子的长度不同,不能直接堆叠。
  • 目标检测:一张图片可能有不同数量的边界框。

示例:处理变长序列

假设我们有一个文本数据集,每个样本是一个单词的索引列表,长度不一。

from torch.nn.utils.rnn import pad_sequence
# 假设的文本数据集
text_data = [
    torch.tensor([1, 2, 3]),
    torch.tensor([4, 5]),
    torch.tensor([6, 7, 8, 9]),
    torch.tensor([10])
]
# 定义 collate_fn
def collate_fn(batch):
    # batch 是一个列表,列表中的每个元素都是一个样本 (tensor,)
    #  [tensor([1, 2, 3]), tensor([4, 5]), ...]
    # 使用 pad_sequence 将不同长度的张量填充到相同长度
    # padding_value=0 表示用0进行填充
    padded_sequences = pad_sequence(batch, batch_first=True, padding_value=0)
    # 返回填充后的批次
    return padded_sequences
# 创建 DataLoader
loader = DataLoader(text_data, batch_size=2, shuffle=True, collate_fn=collate_fn)
print("\nUsing custom collate_fn for variable-length sequences:")
for batch in loader:
    print(f"Padded batch shape: {batch.shape}")
    print(f"Batch content:\n{batch}\n")

输出示例:

Using custom collate_fn for variable-length sequences:
Padded batch shape: torch.Size([2, 4])
Batch content:
tensor([[ 1,  2,  3,  0],
        [ 6,  7,  8,  9]])
Padded batch shape: torch.Size([2, 3])
Batch content:
tensor([[ 4,  5,  0],
        [10,  0,  0]])

可以看到,collate_fn 成功地将不同长度的序列填充到了批次中最长的长度。


最佳实践与注意事项

  1. num_workers 和 Windows:在 Windows 系统下使用 num_workers > 0 时,如果你的训练脚本直接运行(而不是在 if __name__ == '__main__': 块内),可能会遇到多进程相关的错误。最佳实践是将主训练逻辑放在 if __name__ == '__main__':
  2. num_workers 的选择num_workers 的值不是越大越好,通常设置为 4, 8, 16 等根据你的 CPU 核心数来定,过多的子进程可能会导致 CPU 竞争,反而降低效率,建议从 4 开始尝试。
  3. 数据加载瓶颈DataLoader 的速度跟不上 GPU 的训练速度,GPU 就会空闲等待,这说明你的数据加载可能成为了瓶颈,可以尝试:
    • 增加 num_workers
    • 优化 Dataset 中的 __getitem__ 方法,例如使用更快的图像解码库(如 OpenCV)。
    • 使用 pin_memory=True(见下一点)。
  4. pin_memory=True:当你在使用 GPU 时,可以设置 pin_memory=True,这会让 DataLoader 将数据加载到“锁页内存”(Pinned Memory)中,锁页内存的 CPU 到 GPU 的数据传输速度比非锁页内存更快,它会占用更多内存,但可以显著减少数据传输的等待时间。
概念 作用 关键点
Dataset 定义数据的存储和索引方式。 必须实现 __len__getitem__
DataLoader 定义数据的加载和批处理方式。 负责 batch_size, shuffle, num_workers 等。
collate_fn 自定义如何组合一个批次的数据。 用于处理变长序列、不规则数据等。

掌握 DataLoader 是使用 PyTorch 进行深度学习项目的基础,理解它的工作原理,能够帮助你构建高效、可扩展的训练流程。

分享:
扫描分享到社交APP
上一篇
下一篇