目录

一、为什么需要Dataset和DataLoader?

二、Dataset:数据的“说明书”

2.1 核心作用

2.2 自定义Dataset的步骤

2.3 关键细节

三、DataLoader:数据的“流水线工”

3.1 核心作用

3.2 使用方法

3.3 关键参数说明

四、实战:用Dataset+DataLoader训练模型

4.1 完整流程

4.2 简单训练示例

五、总结


在PyTorch中,数据加载是模型训练的第一步,而DatasetDataLoader则是这一流程的核心组件。它们分工明确,共同解决了“如何高效、灵活地读取数据”的问题。本文将用最通俗的语言,带你快速掌握这两个类的用法与区别。


一、为什么需要Dataset和DataLoader?

在传统机器学习中,我们常手动读取数据(如用pandas读CSV、用PIL读图片),但这种方式在深度学习中效率极低:

  • 数据量庞大时,逐样本读取会严重拖慢训练速度;
  • 无法方便地实现“批量加载”“随机打乱”“多线程加速”等需求;
  • 代码冗余,难以复用(换数据集就要重写读取逻辑)。

PyTorch的DatasetDataLoader正是为解决这些问题设计的:

  • ​Dataset​​:定义数据的“抽象结构”,告诉PyTorch“数据长什么样”(如何获取单个样本)。
  • ​DataLoader​​:基于Dataset,实现“批量加载+迭代”的自动化流程(如自动打乱、分批次、多线程加速)。

二、Dataset:数据的“说明书”

2.1 核心作用

Dataset是PyTorch对数据的​​抽象封装​​,它定义了两个核心方法,告诉PyTorch如何获取数据:

  • __len__():返回数据集的总样本数(必须实现)。
  • __getitem__(index):根据索引index返回单个样本(必须实现)。

2.2 自定义Dataset的步骤

要使用自定义数据集,只需继承torch.utils.data.Dataset,并重写上述两个方法即可。

​示例:图片分类数据集​
假设我们有一个图片数据集,结构如下:

dataset/
    train/
        class1/
            img1.jpg
            img2.jpg
        class2/
            img3.jpg
            img4.jpg
    val/
        ...

我们需要为训练集定义一个Dataset,用于按索引返回“图片+标签”。

import os
from PIL import Image
import torch
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        """
        初始化函数
        :param root_dir: 数据集根目录(如dataset/train)
        :param transform: 数据预处理函数(如缩放、归一化)
        """
        self.root_dir = root_dir
        self.transform = transform
        
        # 获取所有图片路径和对应的标签(类别名→索引)
        self.classes = sorted(os.listdir(root_dir))  # 类别列表(如['class1', 'class2'])
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}  # 类别名→索引映射
        self.images = self._load_images()  # 所有图片路径列表(如[img1.jpg路径, img2.jpg路径,...])

    def _load_images(self):
        """辅助函数:遍历目录,收集所有图片路径"""
        images = []
        for cls in self.classes:
            cls_dir = os.path.join(self.root_dir, cls)
            for img_name in os.listdir(cls_dir):
                if img_name.endswith(('.jpg', '.png')):  # 过滤非图片文件
                    img_path = os.path.join(cls_dir, img_name)
                    images.append((img_path, self.class_to_idx[cls]))  # (图片路径, 标签)
        return images

    def __len__(self):
        """返回数据集总样本数"""
        return len(self.images)

    def __getitem__(self, index):
        """根据索引返回单个样本(图片+标签)"""
        img_path, label = self.images[index]
        
        # 读取图片并转换为Tensor(需PIL支持)
        image = Image.open(img_path).convert('RGB')  # 统一转为RGB格式
        
        # 应用预处理(如归一化、缩放)
        if self.transform:
            image = self.transform(image)
        
        return image, torch.tensor(label, dtype=torch.long)  # 返回(图片Tensor,标签Tensor)

2.3 关键细节

  • transform参数:通常传入torchvision.transforms中的组合变换(如transforms.Compose([transforms.Resize(224), transforms.ToTensor()])),用于数据增强或标准化。
  • __getitem__的返回值:通常是(数据, 标签)的元组,但也可根据任务调整(如目标检测返回(图片, 边框, 标签))。

三、DataLoader:数据的“流水线工”

3.1 核心作用

Dataset解决了“如何获取单个样本”的问题,但训练时需要“批量处理数据”(如每次输入32张图片)。DataLoader的作用就是:

  • Dataset中的样本按批次(batch_size)打包;
  • 支持随机打乱(shuffle=True)训练数据;
  • 多线程加速加载(num_workers>0),避免训练时卡顿;
  • 自动处理数据迭代(for batch in dataloader:即可遍历所有批次)。

3.2 使用方法

通过torch.utils.data.DataLoader类实例化,传入Dataset和关键参数即可。

​示例:加载上面的ImageDataset​

from torchvision import transforms
from torch.utils.data import DataLoader

# 定义数据预处理(缩放+转Tensor+归一化)
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # 统一图片尺寸
    transforms.ToTensor(),          # 转为Tensor(范围0-1)
    transforms.Normalize(           # 标准化(ImageNet均值/标准差)
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# 初始化Dataset
train_dataset = ImageDataset(
    root_dir='dataset/train',
    transform=transform  # 应用预处理
)

# 初始化DataLoader
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=32,       # 每批32个样本
    shuffle=True,        # 训练时随机打乱
    num_workers=4,       # 4个线程加速加载(根据CPU核心数调整)
    pin_memory=True      # 加速GPU数据传输(若用GPU训练,建议开启)
)

3.3 关键参数说明

参数 说明
batch_size 每批次的样本数(常用32、64、128,根据显存调整)。
shuffle 是否打乱数据(训练时设为True,验证/测试时设为False)。
num_workers 加载数据的线程数(0表示主进程加载,>0可加速,但需注意内存占用)。
pin_memory 是否将数据固定在内存(GPU训练时设为True,减少数据拷贝时间)。
drop_last 是否丢弃最后一个不完整的批次(若总样本数不能被batch_size整除时有用)。

四、实战:用Dataset+DataLoader训练模型

4.1 完整流程

  1. 定义Dataset(自定义或使用PyTorch内置的ImageFolderMNIST等);
  2. 定义数据预处理(transform);
  3. DataLoader加载Dataset,得到可迭代的批次数据;
  4. 在训练循环中遍历DataLoader,逐批喂数据给模型。

4.2 简单训练示例

import torch.nn as nn
import torch.optim as optim

# 假设我们有一个简单的CNN模型
class SimpleCNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 111 * 111, num_classes)  # 假设输入224x224,经卷积池化后尺寸计算
        )

    def forward(self, x):
        return self.layers(x)

# 初始化模型、损失函数、优化器
model = SimpleCNN(num_classes=2)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练循环
num_epochs = 10
for epoch in range(num_epochs):
    model.train()  # 训练模式
    running_loss = 0.0
    
    # 遍历DataLoader的每个批次
    for batch_idx, (images, labels) in enumerate(train_dataloader):
        # 1. 清空梯度
        optimizer.zero_grad()
        
        # 2. 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # 3. 反向传播+优化
        loss.backward()
        optimizer.step()
        
        # 统计损失
        running_loss += loss.item()
        if (batch_idx + 1) % 10 == 0:  # 每10个批次打印一次
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(train_dataloader)}], Loss: {running_loss/10:.4f}')
            running_loss = 0.0

五、总结

  • ​Dataset​​:数据的“说明书”,定义如何获取单个样本(必须实现__len____getitem__)。
  • ​DataLoader​​:数据的“流水线工”,负责批量加载、打乱、加速,将Dataset转化为可迭代的批次数据。
  • ​配合使用​​:通过Dataset封装数据逻辑,DataLoader处理工程细节,让模型训练更高效、灵活。

下次遇到数据加载问题,记得先想:我的Dataset定义对吗?DataLoader的参数调好了吗? 掌握这两个类,PyTorch数据处理不再难!

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐