Pytorch入门(7) - Dataset-Dataloader载入自己的图像(分类问题)

在 PyTorch 中,定义一个自定义数据集类(Dataset)需要实现三个核心函数:__init____len____getitem__。下面我们将通过一个具体的例子来详细解释这些函数的作用和实现方式。假设我们有一个 FashionMNIST 数据集,图像存储在一个名为 img_dir 的目录中,而标签则存储在一个单独的 CSV 文件 annotations_file 中。

1. 导入必要的库

首先,我们需要导入一些必要的库:

import os
import pandas as pd
from torchvision.io import read_image
  • os: 用于处理文件路径和目录操作。
  • pandas as pd: 用于读取和处理 CSV 文件(通常存储图像文件名和标签)。
  • read_image: 从 torchvision.io 导入的函数,用于加载图像文件并将其转换为 PyTorch 张量(Tensor)。

2. 定义 CustomImageDataset

接下来,我们定义一个自定义的数据集类 CustomImageDataset,继承自 PyTorch 的 Dataset 类:

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
  • 这是一个自定义的数据集类,继承自 PyTorch 的 Dataset 类。PyTorch 的 Dataset 类是所有数据集的基类,自定义数据集需要实现 __len____getitem__ 方法。

3. __init__ 方法:初始化数据集

def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform
  • annotations_file: 是一个 CSV 文件的路径,文件内容通常是图像文件名和对应的标签。
  • img_dir: 是图像文件存储的目录路径。
  • transform: 是一个可选的图像预处理函数(例如,调整大小、归一化等)。
  • target_transform: 是一个可选的标签预处理函数(例如,将标签转换为 one-hot 编码)。
  • self.img_labels = pd.read_csv(annotations_file): 读取 CSV 文件,将其存储为一个 Pandas DataFrame。DataFrame 的每一行通常包含图像文件名和对应的标签。
  • self.img_dir = img_dir: 存储图像目录路径。
  • self.transformself.target_transform: 存储图像和标签的预处理函数。

4. __len__ 方法:返回数据集的大小

def __len__(self):
    return len(self.img_labels)
  • 返回数据集中样本的数量(即 CSV 文件中的行数)。

5. __getitem__ 方法:获取单个样本

def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label
  • idx: 是数据集中样本的索引(从 0 到 len(dataset)-1)。
  • img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]):
    根据索引 idx,从 CSV 文件中获取图像文件名,并将其与图像目录路径 self.img_dir 拼接,得到完整的图像文件路径。
  • image = read_image(img_path):
    使用 read_image 函数加载图像文件,并将其转换为 PyTorch 张量(Tensor)。
  • label = self.img_labels.iloc[idx, 1]:
    从 CSV 文件中获取对应图像的标签。
  • if self.transform: image = self.transform(image):
    如果定义了图像预处理函数 self.transform,则对图像进行预处理。
  • if self.target_transform: label = self.target_transform(label):
    如果定义了标签预处理函数 self.target_transform,则对标签进行预处理。
  • return image, label:
    返回处理后的图像和标签。

6. 使用 DataLoader 加载数据

为了在训练模型时方便地加载数据,我们可以使用 PyTorch 的 DataLoaderDataLoader 可以将数据集分成小批次(batch),并支持多线程加载数据和打乱数据顺序。

下面是载入数据并可视化的代码,我们将详细解释每一段的内容。

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

6.1 导入 DataLoader

from torch.utils.data import DataLoader
  • DataLoader 是 PyTorch 提供的一个工具,用于将数据集分成小批次(batch),并支持多线程加载数据、打乱数据等功能。它通常用于训练深度学习模型。

6.2 创建 DataLoader 实例

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
  • training_datatest_data:
    这是两个数据集对象,通常是 PyTorch 的 Dataset 实例(例如前面定义的 CustomImageDataset)。
  • batch_size=64:
    每个批次包含 64 个样本(图像和标签)。
  • shuffle=True:
    在每个 epoch 开始时打乱数据顺序,确保模型训练时不会受到数据顺序的影响。

6.3 从 DataLoader 中获取一个批次的数据

train_features, train_labels = next(iter(train_dataloader))
  • iter(train_dataloader):
    train_dataloader 转换为一个迭代器。
  • next(...):
    从迭代器中获取下一个批次的数据。这里获取的是第一个批次的数据。
  • train_features:
    是一个包含 64 张图像的张量(Tensor),形状为 (64, C, H, W),其中:
    • 64 是批次大小,
    • C 是图像的通道数(例如,灰度图为 1,RGB 图为 3),
    • HW 分别是图像的高度和宽度。
  • train_labels:
    是一个包含 64 个标签的张量,形状为 (64,)

6.4 打印批次数据的形状

print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
  • train_features.size():
    打印图像张量的形状,例如 torch.Size([64, 1, 28, 28]),表示有 64 张 28x28 的灰度图像。
  • train_labels.size():
    打印标签张量的形状,例如 torch.Size([64]),表示有 64 个标签。

6.5 显示第一张图像和标签

img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
  • train_features[0]:
    获取批次中的第一张图像,形状为 (1, 28, 28)(假设是灰度图)。
  • squeeze():
    去除图像张量中的单维度,将形状从 (1, 28, 28) 变为 (28, 28),以便用 matplotlib 显示。
  • plt.imshow(img, cmap="gray"):
    使用 matplotlib 显示图像,cmap="gray" 表示以灰度图的形式显示。
  • plt.show():
    显示图像窗口。
  • print(f"Label: {label}"):
    打印第一张图像的标签。

7. 完整使用流程

假设你已经定义了一个 CustomImageDataset,并创建了 training_datatest_data 数据集,那么完整的流程如下:

# 创建数据集实例
training_data = CustomImageDataset(annotations_file="train_labels.csv", img_dir="train_images/")
test_data = CustomImageDataset(annotations_file="test_labels.csv", img_dir="test_images/")

# 创建 DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

# 获取一个批次的数据并显示
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")  # 打印图像张量的形状
print(f"Labels batch shape: {train_labels.size()}")     # 打印标签张量的形状

# 显示第一张图像和标签
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

8. 总结

  • CustomImageDataset 的作用
    这是一个自定义的数据集类,用于从指定目录加载图像数据和对应的标签。它支持对图像和标签进行预处理,并且可以与 PyTorch 的数据加载器(如 DataLoader)配合使用。

  • DataLoader 的作用
    将数据集分成小批次,方便模型训练。它支持多线程加载数据和打乱数据顺序。

  • 完整流程

    1. 定义数据集(如 CustomImageDataset)。
    2. 创建 DataLoader
    3. DataLoader 中获取数据并检查。

通过以上步骤,你可以轻松地加载和预处理自定义数据集,并将其用于训练深度学习模型。