Pytorch入门 (8)- 可视化你的预测结果

机器学习模型测试代码详解(面向编程小白)

这部分代码是对上面几篇文章的一个补充,旨在得到每张图像的预测和真实标签的对比。

# 测试模型
model.eval()  # 设置为评估模式
with torch.no_grad():
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    all_images = []
    for images, labels in test_dataloader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        # 保存一些图像、预测和标签用于可视化
        all_images.append(images.cpu())
        all_preds.append(predicted.cpu())
        all_labels.append(labels.cpu())
        
        # 只保存足够的样本用于可视化
        if len(all_images) >= 1:  # 只需要一个batch就足够了
            break
            
    print(f'测试集准确率: {100 * correct / total:.2f}%')

# 可视化预测结果
def visualize_predictions(images, predictions, labels, num_samples=9):
    # 类别名称
    classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
    
    # 创建一个3x3的子图
    fig, axes = plt.subplots(3, 3, figsize=(12, 12))
    
    # 展平axes数组以便于索引
    axes = axes.flatten()
    
    # 确保我们只展示指定数量的样本
    images = torch.cat(all_images, 0)[:num_samples]
    predictions = torch.cat(all_preds, 0)[:num_samples]
    labels = torch.cat(all_labels, 0)[:num_samples]
    
    for i in range(num_samples):
        # 获取图像并转换为numpy数组
        img = images[i].squeeze().numpy()
        
        # 获取预测和真实标签
        pred = classes[predictions[i]]
        true_label = classes[labels[i]]
        
        # 在子图上显示图像
        axes[i].imshow(img, cmap='gray')
        
        # 设置标题为预测和真实标签
        title_color = 'green' if pred == true_label else 'red'
        axes[i].set_title(f'Pred: {pred}\nTrue: {true_label}', color=title_color)
        
        # 关闭坐标轴
        axes[i].axis('off')
    
    plt.tight_layout()
    plt.show()

# 可视化一些预测结果
visualize_predictions(all_images, all_preds, all_labels)

这段代码主要是在测试一个已经训练好的机器学习模型,并且将预测结果进行可视化展示。我会分步骤解释这段代码的功能。

第一部分:测试模型性能

model.eval()  # 设置为评估模式
  • 这行代码告诉模型:“嘿,现在不是学习的时候,而是考试的时候了!”
  • eval()表示将模型设置为评估模式,这样模型就不会更新它的参数了
with torch.no_grad():
  • 这行代码告诉PyTorch:“接下来的操作不需要计算梯度”
  • 在测试时我们只需要得到预测结果,不需要反向传播来更新模型,这样可以节省内存并加速计算
    correct = 0  # 记录正确预测的样本数
    total = 0    # 记录总样本数
    all_preds = []    # 存储所有预测结果
    all_labels = []   # 存储所有真实标签
    all_images = []   # 存储所有图像
  • 初始化一些变量来记录测试结果
    for images, labels in test_dataloader:
  • 这是一个循环,从测试数据加载器中一批一批地获取数据
  • images是图像数据,labels是对应的真实标签
        images = images.to(device)
        labels = labels.to(device)
  • 将数据送到计算设备上(可能是GPU或CPU)
        outputs = model(images)
  • 用模型对图像进行预测,得到输出
        _, predicted = torch.max(outputs.data, 1)
  • torch.max找出每一行中最大值的索引
  • 这行代码的意思是找出模型预测的最可能的类别
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
  • 累加样本总数
  • 计算有多少预测和真实标签相匹配,并累加到correct变量
        # 保存一些图像、预测和标签用于可视化
        all_images.append(images.cpu())
        all_preds.append(predicted.cpu())
        all_labels.append(labels.cpu())
  • 将当前批次的图像、预测结果和真实标签保存下来,用于后面的可视化
  • .cpu()是将数据从GPU移回CPU,因为可视化需要在CPU上进行
        # 只保存足够的样本用于可视化
        if len(all_images) >= 1:  # 只需要一个batch就足够了
            break
  • 这段代码限制了我们只保存一批数据,因为我们只需要少量样本来展示
    print(f'测试集准确率: {100 * correct / total:.2f}%')
  • 计算并打印模型在测试集上的准确率

第二部分:可视化预测结果

def visualize_predictions(images, predictions, labels, num_samples=9):
  • 定义一个函数来可视化预测结果
  • 参数分别是图像、预测结果、真实标签,以及要展示的样本数量(默认为9个)
    # 类别名称
    classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
  • 这是一个列表,包含了所有可能的类别名称
  • 这段代码是用于Fashion-MNIST数据集,包含了10类不同的服装或鞋子
    # 创建一个3x3的子图
    fig, axes = plt.subplots(3, 3, figsize=(12, 12))
    
    # 展平axes数组以便于索引
    axes = axes.flatten()
  • 创建一个3行3列的图表布局,总共可以显示9张图像
  • flatten()是将二维数组转换成一维数组,便于后面的循环操作
    # 确保我们只展示指定数量的样本
    images = torch.cat(all_images, 0)[:num_samples]
    predictions = torch.cat(all_preds, 0)[:num_samples]
    labels = torch.cat(all_labels, 0)[:num_samples]
  • torch.cat是将列表中的所有张量连接起来
  • 这里取前num_samples个样本用于展示
    for i in range(num_samples):
  • 对每一个样本进行处理
        # 获取图像并转换为numpy数组
        img = images[i].squeeze().numpy()
  • 获取第i个图像
  • squeeze()去掉尺寸为1的维度
  • .numpy()将PyTorch张量转换为NumPy数组,因为Matplotlib使用NumPy数组来显示图像
        # 获取预测和真实标签
        pred = classes[predictions[i]]
        true_label = classes[labels[i]]
  • 根据预测和真实标签的索引,获取对应的类别名称
        # 在子图上显示图像
        axes[i].imshow(img, cmap='gray')
  • 在第i个子图上显示图像
  • cmap='gray'表示使用灰度颜色映射,因为Fashion-MNIST的图像是灰度的
        # 设置标题为预测和真实标签
        title_color = 'green' if pred == true_label else 'red'
        axes[i].set_title(f'Pred: {pred}\nTrue: {true_label}', color=title_color)
  • 设置子图的标题,包含预测和真实标签
  • 如果预测正确,标题显示为绿色;如果预测错误,标题显示为红色
        # 关闭坐标轴
        axes[i].axis('off')
  • 关闭坐标轴,使图像看起来更整洁
    plt.tight_layout()
    plt.show()
  • tight_layout()调整子图之间的间距,使显示更美观
  • show()显示整个图表
# 可视化一些预测结果
visualize_predictions(all_images, all_preds, all_labels)
  • 调用前面定义的函数,使用收集到的图像、预测和标签来可视化结果

总结:这段代码测试了一个机器学习模型在时尚物品识别任务上的表现,并将部分预测结果以可视化的方式展示出来,方便我们直观地了解模型的性能。