首先你需要载入必要的库(libraries)
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
PyTorch 提供了一些特定领域的库,如 TorchText、TorchVision 和 TorchAudio,它们都包含数据集。在本教程中,我们将使用 TorchVision 的数据集。
torchvision.datasets
模块包含适用于许多现实世界计算机视觉任务的数据集,例如 CIFAR 和 COCO(完整列表请查看官方文档)。在本教程中,我们使用 FashionMNIST 数据集。
每个 TorchVision 数据集都包含两个参数:
transform
:用于修改样本(图像)的转换方法。target_transform
:用于修改标签的转换方法。
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
)
# Download test data from open datasets.
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor(),
)
这段代码用于下载并加载 FashionMNIST 数据集,用于训练和测试深度学习模型。我们逐行解析其作用:
1. 下载训练数据
training_data = datasets.FashionMNIST(
root="data", # 存储数据的路径
train=True, # True 表示加载训练集
download=True, # 如果数据集不存在,则自动下载
transform=ToTensor() # 数据转换为 PyTorch 张量
)
详细解析
-
datasets.FashionMNIST(...)
- 这个方法从 PyTorch 的
torchvision.datasets
模块中加载 FashionMNIST 数据集。 FashionMNIST
是一个类似MNIST
(手写数字)但更复杂的数据集,包含 10 种不同的服饰类别,每张图片是 28×28 像素的灰度图像。
- 这个方法从 PyTorch 的
-
参数解释
root="data"
- 指定数据存储路径(如
"data"
文件夹)。 - 如果数据已经存在,则直接加载;如果不存在,则下载到
"data"
目录。
- 指定数据存储路径(如
train=True
- 只加载 训练集(FashionMNIST 训练集包含 60,000 张图片)。
download=True
- 如果数据集 未下载,则从 官方 PyTorch 数据源 下载数据。
transform=ToTensor()
- 将数据转换为 PyTorch 张量格式 (
torch.Tensor
),以便用于神经网络训练。
- 将数据转换为 PyTorch 张量格式 (
2. 下载测试数据
test_data = datasets.FashionMNIST(
root="data",
train=False, # False 表示加载测试集
download=True,
transform=ToTensor(),
)
作用
- 这部分代码与训练数据的下载过程类似,唯一的区别是:
train=False
- 这里下载的是 测试集(FashionMNIST 测试集包含 10,000 张图片)。
3. 为什么要使用 ToTensor()
?
transform=ToTensor()
主要做了两件事:
- 把 PIL 图片转换成 PyTorch 张量 (
torch.Tensor
)- 这样就可以直接输入到 PyTorch 模型进行训练。
- 归一化到
[0,1]
范围- 原始图片像素值是
0-255
,而ToTensor()
会将其归一化为 0-1 之间的浮点数:
[
\text{new_pixel} = \frac{\text{original_pixel}}{255}
] - 归一化可以加快模型收敛,提高训练稳定性。
- 原始图片像素值是
4. FashionMNIST 数据集简介
FashionMNIST 是一个用于图像分类的公开数据集,包含 10 个服饰类别:
类别索引 | 服饰类别 |
---|---|
0 | T 恤 / 上衣 |
1 | 裤子 |
2 | 套头衫 |
3 | 连衣裙 |
4 | 外套 |
5 | 凉鞋 |
6 | 衬衫 |
7 | 运动鞋 |
8 | 手提包 |
9 | 短靴 |
每张图片是 28×28 的灰度图,类似于 MNIST(手写数字数据集),但比 MNIST 更具挑战性,因为服饰比手写数字更复杂。
5. 代码执行后的文件存储
如果 root="data"
,那么执行代码后,数据将被下载到如下目录:
data/
│── FashionMNIST/
│ ├── processed/ # 处理后的数据
│ ├── raw/ # 原始数据
│ ├── training.pt # 训练集(PyTorch 格式)
│ ├── test.pt # 测试集(PyTorch 格式)
processed/training.pt
: 预处理后的训练数据processed/test.pt
: 预处理后的测试数据raw/
: 原始数据集(可能是.gz
压缩文件)
总结
这段代码的作用:
- 下载 FashionMNIST 训练集(
train=True
,共 60,000 张图片)。 - 下载 FashionMNIST 测试集(
train=False
,共 10,000 张图片)。 - 转换数据 为 PyTorch
Tensor
,归一化到[0,1]
。 - 存储数据 在
data/FashionMNIST/
目录中,以.pt
格式保存。
这段代码通常用于训练服饰分类神经网络,你可以用 DataLoader
来加载它们进行训练:
batch_size = 64
# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
for X, y in test_dataloader:
print(f"Shape of X [N, C, H, W]: {X.shape}")
print(f"Shape of y: {y.shape} {y.dtype}")
break
这样就可以进行 mini-batch 训练 了!