这个帖子主要内容是讲如何使用Dataset载入FashionMNIST并可视化载入的图像。文章解释了每一行代码的含义,适合纯小白。
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
以上代码在快速上手帖子中有介绍:Pytorch入门(1) - 快速上手载入Data
运行完上面的代码后,再运行下面的代码就可以实现简单的可视化。
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item()
img, label = training_data[sample_idx]
figure.add_subplot(rows, cols, i)
plt.title(labels_map[label])
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray")
plt.show()
这段代码的主要功能是从一个名为 training_data
的数据集中随机选择一些图像,并将它们显示在一个 3x3 的网格中。每个图像下方会显示其对应的标签。以下是代码的详细解释:
1. labels_map
字典
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
labels_map
是一个字典,它将数字标签(0到9)映射到对应的服装类别名称。例如,标签0
对应"T-Shirt"
,标签1
对应"Trouser"
,依此类推。
2. 创建图像显示区域
figure = plt.figure(figsize=(8, 8))
- 使用
matplotlib
库创建一个大小为 8x8 英寸的图像显示区域。
3. 设置网格的行和列
cols, rows = 3, 3
- 设置网格的列数和行数,这里是一个 3x3 的网格,总共显示 9 张图像。
4. 循环显示图像
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item()
img, label = training_data[sample_idx]
figure.add_subplot(rows, cols, i)
plt.title(labels_map[label])
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray")
-
for i in range(1, cols * rows + 1):
循环 9 次(因为cols * rows = 9
),每次循环显示一张图像。 -
sample_idx = torch.randint(len(training_data), size=(1,)).item()
从training_data
数据集中随机选择一个索引。torch.randint
生成一个随机整数,范围是从 0 到len(training_data) - 1
。 -
img, label = training_data[sample_idx]
根据随机选择的索引sample_idx
,从training_data
数据集中获取对应的图像img
和标签label
。 -
figure.add_subplot(rows, cols, i)
在figure
中添加一个子图,位置由i
决定。i
从 1 到 9,表示在 3x3 网格中的位置。 -
plt.title(labels_map[label])
设置子图的标题为图像对应的标签名称。labels_map[label]
将数字标签转换为对应的服装类别名称。 -
plt.axis("off")
关闭子图的坐标轴显示,使图像更清晰。 -
plt.imshow(img.squeeze(), cmap="gray")
显示图像。img.squeeze()
用于去除图像中可能存在的单维度(例如,如果图像是 1x28x28,则squeeze()
会将其变为 28x28)。cmap="gray"
表示以灰度图的形式显示图像。
5. 显示图像
plt.show()
- 最后,调用
plt.show()
显示整个图像网格。
总结
这段代码的作用是从 training_data
数据集中随机选择 9 张图像,并将它们显示在一个 3x3 的网格中。每张图像下方会显示其对应的服装类别名称。图像以灰度图的形式显示,并且不显示坐标轴。