Pytorch入门(5) - 快速上手使用自己训练的模型

这个帖子讲解了 如何保存和加载 PyTorch 模型,并用加载后的模型进行预测。我们分步解析它的作用。


1. 保存模型(Saving Models)

torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

:pushpin: 作用

  • model.state_dict()
    • 获取 模型的参数字典state_dict),即 weightsbias 等信息。
  • torch.save(..., "model.pth")
    • 序列化(保存)模型参数,并存储到 "model.pth" 文件中。
  • 这样以后可以重新加载参数,无需重新训练模型。

:open_file_folder: model.pth 文件

这个文件只存储 模型的参数(weights/biases),不包含模型结构。


2. 加载模型(Loading Models)

model = NeuralNetwork().to(device)
model.load_state_dict(torch.load("model.pth", weights_only=True))

:pushpin: 作用

  • NeuralNetwork().to(device)
    • 重新创建模型结构,并移动到正确的计算设备(GPU/CPU)。
  • torch.load("model.pth")
    • 加载 model.pth 文件中保存的参数。
  • model.load_state_dict(...)
    • 将加载的参数填充到新建的模型中。
  • 参数 weights_only=True
    • 仅加载权重参数,确保所有参数匹配。

如果 model.pth 文件与 NeuralNetwork() 结构不匹配,或pytorch版本不同,可能会报错。

在 PyTorch 中,weights_onlytorch.load() 的一个无效参数。weights_only 是 PyTorch 1.9.0 新增的参数,通常用于 torch.load() 在加载模型时只加载模型的权重(而不是整个模型的其他信息)。如果你在加载模型时遇到错误,可能是因为你使用的 PyTorch 版本不支持这个参数。

删除 weights_only=True

torch.load("model.pth", weights_only=True) 这行代码修改为:

model.load_state_dict(torch.load("model.pth"))

这行代码将直接加载 model.pth 文件中的权重信息,无需使用 weights_only

输出:

<All keys matched successfully>

表示模型参数成功加载。


3. 用模型进行预测

classes = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

:pushpin: 作用

  • classesFashionMNIST 数据集中 10 个类别的名称。
  • 这些类别索引对应模型的输出(0-9)。

4. 运行推理(Inference)

model.eval()  # 设定模型为评估模式
x, y = test_data[0][0], test_data[0][1]  # 取出测试集的第一张图片和标签
with torch.no_grad():  # 关闭梯度计算,加速推理
    x = x.to(device)  # 将数据移到相同的设备(CPU/GPU)
    pred = model(x)  # 进行前向传播,得到预测结果
    predicted, actual = classes[pred[0].argmax(0)], classes[y]  # 获取预测类别和真实类别
    print(f'Predicted: "{predicted}", Actual: "{actual}"')

:pushpin: 详细解析

  1. model.eval()

    • 设定模型为 评估模式(evaluation mode),关闭 dropout、BatchNorm 等影响推理的操作。
  2. x, y = test_data[0][0], test_data[0][1]

    • 从测试集取出第一张图像 (x) 和对应标签 (y)。
    • xTensor 格式的 28x28 灰度图像
    • y 是该图像的 真实类别索引(0-9)。
  3. with torch.no_grad():

    • 禁用梯度计算,减少内存占用,加速推理。
  4. x = x.to(device)

    • 确保数据和模型在同一设备(GPU/CPU)。
  5. pred = model(x)

    • 通过神经网络 前向传播,得到预测结果 pred
  6. predicted, actual = classes[pred[0].argmax(0)], classes[y]

    • pred[0]logits(未归一化的得分)
    • argmax(0) 取最大值的索引(预测类别)。
    • classes[...] 获取该索引对应的类别名称。
    • y 是真实类别的索引,对应 classes[y]
  7. 打印结果

    Predicted: "Ankle boot", Actual: "Ankle boot"
    
    • 预测结果正确 :white_check_mark:

总结

:pushpin: 这段代码的作用

  1. 保存模型 (torch.save()):存储 state_dict 参数到 model.pth
  2. 加载模型 (torch.load() + model.load_state_dict()):重建模型并填充参数。
  3. 设为评估模式 (model.eval()):确保推理时行为正确。
  4. 进行预测
    • 获取一张测试图片。
    • 经过模型推理,得到预测类别。
    • 打印 预测结果 vs 真实类别

:white_check_mark: 这样,你的模型就可以随时加载并进行预测了! :rocket:


问题:

为什么在 argmax 函数中要使用 0 参数?

答案:

argmax(0) 中的 0 表示计算 最大值的维度。这个参数告诉 PyTorch 在多维张量中沿哪个轴(维度)查找最大值。

理解 argmax 和维度

argmax 用来返回最大值的索引。对于一个张量,你可以指定 dim(维度),让 argmax 在该维度上进行查找。例如,给定一个二维张量,你可以选择沿 查找最大值。

  • dim=0:沿着每一列查找最大值,即返回每一列最大值所在的行索引。
  • dim=1:沿着每一行查找最大值,即返回每一行最大值所在的列索引。

示例:二维张量(多个样本的预测结果)

假设我们有一个包含两个样本的 2D 张量,每个样本的输出是 10 个类的分数(即 10 类的 logits)。

import torch

# 示例:二维张量(2 个样本,10 个类)
logits = torch.tensor([[0.2, 2.5, 1.3, 0.1, 0.6, 3.1, 0.3, 2.0, 0.4, 1.0], 
                       [1.1, 0.5, 2.0, 0.3, 0.9, 0.4, 1.2, 2.7, 0.8, 1.5]])

# 沿着每一行(dim=1)查找最大值的索引
predictions_dim1 = torch.argmax(logits, dim=1)
print(predictions_dim1)

输出:

tensor([5, 7])

解释

  • 第一个样本(第一行):最大值是 3.1,对应的索引是 5(类别 “Sandal”)。
  • 第二个样本(第二行):最大值是 2.7,对应的索引是 7(类别 “Sneaker”)。

因此,argmax(dim=1) 返回 [5, 7],表示两个样本的预测类别分别是 57

如果我们改为 dim=0,结果会是:

# 沿着每一列(dim=0)查找最大值的索引
predictions_dim0 = torch.argmax(logits, dim=0)
print(predictions_dim0)

输出:

tensor([1, 0, 1, 0, 0, 0, 0, 1, 1, 1])

解释

  • argmax(dim=0) 会沿着 每一列 查找最大值,并返回每列最大值的行索引。
    • 第 1 列的最大值是 2.5,在第 1 行,所以返回 1
    • 第 2 列的最大值是 2.0,在第 1 行,所以返回 1
    • 以此类推。

为什么在原始代码中使用 argmax(0)

在原始代码中:

predicted = classes[pred[0].argmax(0)]
  • pred[0] 是模型的输出(logits),它是一个一维张量,表示对于单个样本的每个类别的预测分数。
  • 使用 argmax(0) 是为了从这个一维张量中找到最大值的索引(即预测的类别)。
  • 对于一维张量来说,dim=0 就是找到该张量中最大值的索引。

总结

  • argmax(0) 是用于找到 最大值的索引,特别是在处理一维张量时,dim=0 就是直接返回最大值的索引。
  • 对于 二维张量dim=0 会沿着 每一列 查找最大值,而 dim=1 会沿着 每一行 查找最大值。

在分类任务中,我们通常使用 argmax(0) 来找到模型输出中最大分数的索引,从而确定预测的类别。