Pytorch入门(1) - 快速上手载入Data

首先你需要载入必要的库(libraries)

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

PyTorch 提供了一些特定领域的库,如 TorchTextTorchVisionTorchAudio,它们都包含数据集。在本教程中,我们将使用 TorchVision 的数据集。

torchvision.datasets 模块包含适用于许多现实世界计算机视觉任务的数据集,例如 CIFARCOCO(完整列表请查看官方文档)。在本教程中,我们使用 FashionMNIST 数据集。

每个 TorchVision 数据集都包含两个参数:

  • transform:用于修改样本(图像)的转换方法。
  • target_transform:用于修改标签的转换方法。
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

这段代码用于下载并加载 FashionMNIST 数据集,用于训练和测试深度学习模型。我们逐行解析其作用:


1. 下载训练数据

training_data = datasets.FashionMNIST(
    root="data",       # 存储数据的路径
    train=True,        # True 表示加载训练集
    download=True,     # 如果数据集不存在,则自动下载
    transform=ToTensor()  # 数据转换为 PyTorch 张量
)

:pushpin: 详细解析

  • datasets.FashionMNIST(...)

    • 这个方法从 PyTorch 的 torchvision.datasets 模块中加载 FashionMNIST 数据集
    • FashionMNIST 是一个类似 MNIST(手写数字)但更复杂的数据集,包含 10 种不同的服饰类别,每张图片是 28×28 像素的灰度图像
  • 参数解释

    • root="data"
      • 指定数据存储路径(如 "data" 文件夹)。
      • 如果数据已经存在,则直接加载;如果不存在,则下载到 "data" 目录。
    • train=True
      • 只加载 训练集(FashionMNIST 训练集包含 60,000 张图片)。
    • download=True
      • 如果数据集 未下载,则从 官方 PyTorch 数据源 下载数据。
    • transform=ToTensor()
      • 将数据转换为 PyTorch 张量格式 (torch.Tensor),以便用于神经网络训练。

2. 下载测试数据

test_data = datasets.FashionMNIST(
    root="data",
    train=False,       # False 表示加载测试集
    download=True,
    transform=ToTensor(),
)

:pushpin: 作用

  • 这部分代码与训练数据的下载过程类似,唯一的区别是:
    • train=False
      • 这里下载的是 测试集(FashionMNIST 测试集包含 10,000 张图片)。

3. 为什么要使用 ToTensor()

transform=ToTensor() 主要做了两件事:

  1. 把 PIL 图片转换成 PyTorch 张量 (torch.Tensor)
    • 这样就可以直接输入到 PyTorch 模型进行训练。
  2. 归一化到 [0,1] 范围
    • 原始图片像素值是 0-255,而 ToTensor() 会将其归一化为 0-1 之间的浮点数
      [
      \text{new_pixel} = \frac{\text{original_pixel}}{255}
      ]
    • 归一化可以加快模型收敛,提高训练稳定性。

4. FashionMNIST 数据集简介

FashionMNIST 是一个用于图像分类的公开数据集,包含 10 个服饰类别

类别索引 服饰类别
0 T 恤 / 上衣
1 裤子
2 套头衫
3 连衣裙
4 外套
5 凉鞋
6 衬衫
7 运动鞋
8 手提包
9 短靴

每张图片是 28×28 的灰度图,类似于 MNIST(手写数字数据集),但比 MNIST 更具挑战性,因为服饰比手写数字更复杂。


5. 代码执行后的文件存储

如果 root="data",那么执行代码后,数据将被下载到如下目录:

data/
│── FashionMNIST/
│   ├── processed/       # 处理后的数据
│   ├── raw/             # 原始数据
│   ├── training.pt      # 训练集(PyTorch 格式)
│   ├── test.pt          # 测试集(PyTorch 格式)
  • processed/training.pt: 预处理后的训练数据
  • processed/test.pt: 预处理后的测试数据
  • raw/: 原始数据集(可能是 .gz 压缩文件)

总结

:pushpin: 这段代码的作用:

  1. 下载 FashionMNIST 训练集(train=True,共 60,000 张图片)。
  2. 下载 FashionMNIST 测试集(train=False,共 10,000 张图片)。
  3. 转换数据 为 PyTorch Tensor,归一化到 [0,1]
  4. 存储数据data/FashionMNIST/ 目录中,以 .pt 格式保存。

这段代码通常用于训练服饰分类神经网络,你可以用 DataLoader 来加载它们进行训练:

batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

这样就可以进行 mini-batch 训练 了!:rocket: