在 PyTorch 中,定义一个自定义数据集类(Dataset
)需要实现三个核心函数:__init__
、__len__
和 __getitem__
。下面我们将通过一个具体的例子来详细解释这些函数的作用和实现方式。假设我们有一个 FashionMNIST 数据集,图像存储在一个名为 img_dir
的目录中,而标签则存储在一个单独的 CSV 文件 annotations_file
中。
1. 导入必要的库
首先,我们需要导入一些必要的库:
import os
import pandas as pd
from torchvision.io import read_image
os
: 用于处理文件路径和目录操作。pandas as pd
: 用于读取和处理 CSV 文件(通常存储图像文件名和标签)。read_image
: 从torchvision.io
导入的函数,用于加载图像文件并将其转换为 PyTorch 张量(Tensor)。
2. 定义 CustomImageDataset
类
接下来,我们定义一个自定义的数据集类 CustomImageDataset
,继承自 PyTorch 的 Dataset
类:
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
- 这是一个自定义的数据集类,继承自 PyTorch 的
Dataset
类。PyTorch 的Dataset
类是所有数据集的基类,自定义数据集需要实现__len__
和__getitem__
方法。
3. __init__
方法:初始化数据集
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
annotations_file
: 是一个 CSV 文件的路径,文件内容通常是图像文件名和对应的标签。img_dir
: 是图像文件存储的目录路径。transform
: 是一个可选的图像预处理函数(例如,调整大小、归一化等)。target_transform
: 是一个可选的标签预处理函数(例如,将标签转换为 one-hot 编码)。self.img_labels = pd.read_csv(annotations_file)
: 读取 CSV 文件,将其存储为一个 Pandas DataFrame。DataFrame 的每一行通常包含图像文件名和对应的标签。self.img_dir = img_dir
: 存储图像目录路径。self.transform
和self.target_transform
: 存储图像和标签的预处理函数。
4. __len__
方法:返回数据集的大小
def __len__(self):
return len(self.img_labels)
- 返回数据集中样本的数量(即 CSV 文件中的行数)。
5. __getitem__
方法:获取单个样本
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
idx
: 是数据集中样本的索引(从 0 到len(dataset)-1
)。img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
:
根据索引idx
,从 CSV 文件中获取图像文件名,并将其与图像目录路径self.img_dir
拼接,得到完整的图像文件路径。image = read_image(img_path)
:
使用read_image
函数加载图像文件,并将其转换为 PyTorch 张量(Tensor)。label = self.img_labels.iloc[idx, 1]
:
从 CSV 文件中获取对应图像的标签。if self.transform: image = self.transform(image)
:
如果定义了图像预处理函数self.transform
,则对图像进行预处理。if self.target_transform: label = self.target_transform(label)
:
如果定义了标签预处理函数self.target_transform
,则对标签进行预处理。return image, label
:
返回处理后的图像和标签。
6. 使用 DataLoader
加载数据
为了在训练模型时方便地加载数据,我们可以使用 PyTorch 的 DataLoader
。DataLoader
可以将数据集分成小批次(batch),并支持多线程加载数据和打乱数据顺序。
下面是载入数据并可视化的代码,我们将详细解释每一段的内容。
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
6.1 导入 DataLoader
from torch.utils.data import DataLoader
DataLoader
是 PyTorch 提供的一个工具,用于将数据集分成小批次(batch),并支持多线程加载数据、打乱数据等功能。它通常用于训练深度学习模型。
6.2 创建 DataLoader
实例
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
training_data
和test_data
:
这是两个数据集对象,通常是 PyTorch 的Dataset
实例(例如前面定义的CustomImageDataset
)。batch_size=64
:
每个批次包含 64 个样本(图像和标签)。shuffle=True
:
在每个 epoch 开始时打乱数据顺序,确保模型训练时不会受到数据顺序的影响。
6.3 从 DataLoader
中获取一个批次的数据
train_features, train_labels = next(iter(train_dataloader))
iter(train_dataloader)
:
将train_dataloader
转换为一个迭代器。next(...)
:
从迭代器中获取下一个批次的数据。这里获取的是第一个批次的数据。train_features
:
是一个包含 64 张图像的张量(Tensor),形状为(64, C, H, W)
,其中:64
是批次大小,C
是图像的通道数(例如,灰度图为 1,RGB 图为 3),H
和W
分别是图像的高度和宽度。
train_labels
:
是一个包含 64 个标签的张量,形状为(64,)
。
6.4 打印批次数据的形状
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
train_features.size()
:
打印图像张量的形状,例如torch.Size([64, 1, 28, 28])
,表示有 64 张 28x28 的灰度图像。train_labels.size()
:
打印标签张量的形状,例如torch.Size([64])
,表示有 64 个标签。
6.5 显示第一张图像和标签
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
train_features[0]
:
获取批次中的第一张图像,形状为(1, 28, 28)
(假设是灰度图)。squeeze()
:
去除图像张量中的单维度,将形状从(1, 28, 28)
变为(28, 28)
,以便用matplotlib
显示。plt.imshow(img, cmap="gray")
:
使用matplotlib
显示图像,cmap="gray"
表示以灰度图的形式显示。plt.show()
:
显示图像窗口。print(f"Label: {label}")
:
打印第一张图像的标签。
7. 完整使用流程
假设你已经定义了一个 CustomImageDataset
,并创建了 training_data
和 test_data
数据集,那么完整的流程如下:
# 创建数据集实例
training_data = CustomImageDataset(annotations_file="train_labels.csv", img_dir="train_images/")
test_data = CustomImageDataset(annotations_file="test_labels.csv", img_dir="test_images/")
# 创建 DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
# 获取一个批次的数据并显示
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}") # 打印图像张量的形状
print(f"Labels batch shape: {train_labels.size()}") # 打印标签张量的形状
# 显示第一张图像和标签
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
8. 总结
-
CustomImageDataset
的作用:
这是一个自定义的数据集类,用于从指定目录加载图像数据和对应的标签。它支持对图像和标签进行预处理,并且可以与 PyTorch 的数据加载器(如DataLoader
)配合使用。 -
DataLoader
的作用:
将数据集分成小批次,方便模型训练。它支持多线程加载数据和打乱数据顺序。 -
完整流程:
- 定义数据集(如
CustomImageDataset
)。 - 创建
DataLoader
。 - 从
DataLoader
中获取数据并检查。
- 定义数据集(如
通过以上步骤,你可以轻松地加载和预处理自定义数据集,并将其用于训练深度学习模型。