机器学习模型测试代码详解(面向编程小白)
这部分代码是对上面几篇文章的一个补充,旨在得到每张图像的预测和真实标签的对比。
# 测试模型
model.eval() # 设置为评估模式
with torch.no_grad():
correct = 0
total = 0
all_preds = []
all_labels = []
all_images = []
for images, labels in test_dataloader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
# 保存一些图像、预测和标签用于可视化
all_images.append(images.cpu())
all_preds.append(predicted.cpu())
all_labels.append(labels.cpu())
# 只保存足够的样本用于可视化
if len(all_images) >= 1: # 只需要一个batch就足够了
break
print(f'测试集准确率: {100 * correct / total:.2f}%')
# 可视化预测结果
def visualize_predictions(images, predictions, labels, num_samples=9):
# 类别名称
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# 创建一个3x3的子图
fig, axes = plt.subplots(3, 3, figsize=(12, 12))
# 展平axes数组以便于索引
axes = axes.flatten()
# 确保我们只展示指定数量的样本
images = torch.cat(all_images, 0)[:num_samples]
predictions = torch.cat(all_preds, 0)[:num_samples]
labels = torch.cat(all_labels, 0)[:num_samples]
for i in range(num_samples):
# 获取图像并转换为numpy数组
img = images[i].squeeze().numpy()
# 获取预测和真实标签
pred = classes[predictions[i]]
true_label = classes[labels[i]]
# 在子图上显示图像
axes[i].imshow(img, cmap='gray')
# 设置标题为预测和真实标签
title_color = 'green' if pred == true_label else 'red'
axes[i].set_title(f'Pred: {pred}\nTrue: {true_label}', color=title_color)
# 关闭坐标轴
axes[i].axis('off')
plt.tight_layout()
plt.show()
# 可视化一些预测结果
visualize_predictions(all_images, all_preds, all_labels)
这段代码主要是在测试一个已经训练好的机器学习模型,并且将预测结果进行可视化展示。我会分步骤解释这段代码的功能。
第一部分:测试模型性能
model.eval() # 设置为评估模式
- 这行代码告诉模型:“嘿,现在不是学习的时候,而是考试的时候了!”
eval()
表示将模型设置为评估模式,这样模型就不会更新它的参数了
with torch.no_grad():
- 这行代码告诉PyTorch:“接下来的操作不需要计算梯度”
- 在测试时我们只需要得到预测结果,不需要反向传播来更新模型,这样可以节省内存并加速计算
correct = 0 # 记录正确预测的样本数
total = 0 # 记录总样本数
all_preds = [] # 存储所有预测结果
all_labels = [] # 存储所有真实标签
all_images = [] # 存储所有图像
- 初始化一些变量来记录测试结果
for images, labels in test_dataloader:
- 这是一个循环,从测试数据加载器中一批一批地获取数据
images
是图像数据,labels
是对应的真实标签
images = images.to(device)
labels = labels.to(device)
- 将数据送到计算设备上(可能是GPU或CPU)
outputs = model(images)
- 用模型对图像进行预测,得到输出
_, predicted = torch.max(outputs.data, 1)
torch.max
找出每一行中最大值的索引- 这行代码的意思是找出模型预测的最可能的类别
total += labels.size(0)
correct += (predicted == labels).sum().item()
- 累加样本总数
- 计算有多少预测和真实标签相匹配,并累加到
correct
变量
# 保存一些图像、预测和标签用于可视化
all_images.append(images.cpu())
all_preds.append(predicted.cpu())
all_labels.append(labels.cpu())
- 将当前批次的图像、预测结果和真实标签保存下来,用于后面的可视化
.cpu()
是将数据从GPU移回CPU,因为可视化需要在CPU上进行
# 只保存足够的样本用于可视化
if len(all_images) >= 1: # 只需要一个batch就足够了
break
- 这段代码限制了我们只保存一批数据,因为我们只需要少量样本来展示
print(f'测试集准确率: {100 * correct / total:.2f}%')
- 计算并打印模型在测试集上的准确率
第二部分:可视化预测结果
def visualize_predictions(images, predictions, labels, num_samples=9):
- 定义一个函数来可视化预测结果
- 参数分别是图像、预测结果、真实标签,以及要展示的样本数量(默认为9个)
# 类别名称
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
- 这是一个列表,包含了所有可能的类别名称
- 这段代码是用于Fashion-MNIST数据集,包含了10类不同的服装或鞋子
# 创建一个3x3的子图
fig, axes = plt.subplots(3, 3, figsize=(12, 12))
# 展平axes数组以便于索引
axes = axes.flatten()
- 创建一个3行3列的图表布局,总共可以显示9张图像
flatten()
是将二维数组转换成一维数组,便于后面的循环操作
# 确保我们只展示指定数量的样本
images = torch.cat(all_images, 0)[:num_samples]
predictions = torch.cat(all_preds, 0)[:num_samples]
labels = torch.cat(all_labels, 0)[:num_samples]
torch.cat
是将列表中的所有张量连接起来- 这里取前
num_samples
个样本用于展示
for i in range(num_samples):
- 对每一个样本进行处理
# 获取图像并转换为numpy数组
img = images[i].squeeze().numpy()
- 获取第i个图像
squeeze()
去掉尺寸为1的维度.numpy()
将PyTorch张量转换为NumPy数组,因为Matplotlib使用NumPy数组来显示图像
# 获取预测和真实标签
pred = classes[predictions[i]]
true_label = classes[labels[i]]
- 根据预测和真实标签的索引,获取对应的类别名称
# 在子图上显示图像
axes[i].imshow(img, cmap='gray')
- 在第i个子图上显示图像
cmap='gray'
表示使用灰度颜色映射,因为Fashion-MNIST的图像是灰度的
# 设置标题为预测和真实标签
title_color = 'green' if pred == true_label else 'red'
axes[i].set_title(f'Pred: {pred}\nTrue: {true_label}', color=title_color)
- 设置子图的标题,包含预测和真实标签
- 如果预测正确,标题显示为绿色;如果预测错误,标题显示为红色
# 关闭坐标轴
axes[i].axis('off')
- 关闭坐标轴,使图像看起来更整洁
plt.tight_layout()
plt.show()
tight_layout()
调整子图之间的间距,使显示更美观show()
显示整个图表
# 可视化一些预测结果
visualize_predictions(all_images, all_preds, all_labels)
- 调用前面定义的函数,使用收集到的图像、预测和标签来可视化结果
总结:这段代码测试了一个机器学习模型在时尚物品识别任务上的表现,并将部分预测结果以可视化的方式展示出来,方便我们直观地了解模型的性能。