Pytorch入门(6)- Dataset-Dataloader载入图像并可视化

这个帖子主要内容是讲如何使用Dataset载入FashionMNIST并可视化载入的图像。文章解释了每一行代码的含义,适合纯小白。

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

以上代码在快速上手帖子中有介绍:Pytorch入门(1) - 快速上手载入Data

运行完上面的代码后,再运行下面的代码就可以实现简单的可视化。

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

这段代码的主要功能是从一个名为 training_data 的数据集中随机选择一些图像,并将它们显示在一个 3x3 的网格中。每个图像下方会显示其对应的标签。以下是代码的详细解释:

1. labels_map 字典

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
  • labels_map 是一个字典,它将数字标签(0到9)映射到对应的服装类别名称。例如,标签 0 对应 "T-Shirt",标签 1 对应 "Trouser",依此类推。

2. 创建图像显示区域

figure = plt.figure(figsize=(8, 8))
  • 使用 matplotlib 库创建一个大小为 8x8 英寸的图像显示区域。

3. 设置网格的行和列

cols, rows = 3, 3
  • 设置网格的列数和行数,这里是一个 3x3 的网格,总共显示 9 张图像。

4. 循环显示图像

for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
  • for i in range(1, cols * rows + 1):
    循环 9 次(因为 cols * rows = 9),每次循环显示一张图像。

  • sample_idx = torch.randint(len(training_data), size=(1,)).item()
    training_data 数据集中随机选择一个索引。torch.randint 生成一个随机整数,范围是从 0 到 len(training_data) - 1

  • img, label = training_data[sample_idx]
    根据随机选择的索引 sample_idx,从 training_data 数据集中获取对应的图像 img 和标签 label

  • figure.add_subplot(rows, cols, i)
    figure 中添加一个子图,位置由 i 决定。i 从 1 到 9,表示在 3x3 网格中的位置。

  • plt.title(labels_map[label])
    设置子图的标题为图像对应的标签名称。labels_map[label] 将数字标签转换为对应的服装类别名称。

  • plt.axis("off")
    关闭子图的坐标轴显示,使图像更清晰。

  • plt.imshow(img.squeeze(), cmap="gray")
    显示图像。img.squeeze() 用于去除图像中可能存在的单维度(例如,如果图像是 1x28x28,则 squeeze() 会将其变为 28x28)。cmap="gray" 表示以灰度图的形式显示图像。

5. 显示图像

plt.show()
  • 最后,调用 plt.show() 显示整个图像网格。

总结

这段代码的作用是从 training_data 数据集中随机选择 9 张图像,并将它们显示在一个 3x3 的网格中。每张图像下方会显示其对应的服装类别名称。图像以灰度图的形式显示,并且不显示坐标轴。