这个帖子讲解了 如何保存和加载 PyTorch 模型,并用加载后的模型进行预测。我们分步解析它的作用。
1. 保存模型(Saving Models)
torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")
作用
model.state_dict():- 获取 模型的参数字典(
state_dict),即weights和bias等信息。
- 获取 模型的参数字典(
torch.save(..., "model.pth"):- 序列化(保存)模型参数,并存储到
"model.pth"文件中。
- 序列化(保存)模型参数,并存储到
- 这样以后可以重新加载参数,无需重新训练模型。
model.pth 文件
这个文件只存储 模型的参数(weights/biases),不包含模型结构。
2. 加载模型(Loading Models)
model = NeuralNetwork().to(device)
model.load_state_dict(torch.load("model.pth", weights_only=True))
作用
NeuralNetwork().to(device)- 重新创建模型结构,并移动到正确的计算设备(GPU/CPU)。
torch.load("model.pth")- 加载
model.pth文件中保存的参数。
- 加载
model.load_state_dict(...)- 将加载的参数填充到新建的模型中。
- 参数
weights_only=True- 仅加载权重参数,确保所有参数匹配。
如果 model.pth 文件与 NeuralNetwork() 结构不匹配,或pytorch版本不同,可能会报错。
在 PyTorch 中,weights_only 是 torch.load() 的一个无效参数。weights_only 是 PyTorch 1.9.0 新增的参数,通常用于 torch.load() 在加载模型时只加载模型的权重(而不是整个模型的其他信息)。如果你在加载模型时遇到错误,可能是因为你使用的 PyTorch 版本不支持这个参数。
删除 weights_only=True
将 torch.load("model.pth", weights_only=True) 这行代码修改为:
model.load_state_dict(torch.load("model.pth"))
这行代码将直接加载 model.pth 文件中的权重信息,无需使用 weights_only。
输出:
<All keys matched successfully>
表示模型参数成功加载。
3. 用模型进行预测
classes = [
"T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
"Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]
作用
classes是 FashionMNIST 数据集中 10 个类别的名称。- 这些类别索引对应模型的输出(0-9)。
4. 运行推理(Inference)
model.eval() # 设定模型为评估模式
x, y = test_data[0][0], test_data[0][1] # 取出测试集的第一张图片和标签
with torch.no_grad(): # 关闭梯度计算,加速推理
x = x.to(device) # 将数据移到相同的设备(CPU/GPU)
pred = model(x) # 进行前向传播,得到预测结果
predicted, actual = classes[pred[0].argmax(0)], classes[y] # 获取预测类别和真实类别
print(f'Predicted: "{predicted}", Actual: "{actual}"')
详细解析
-
model.eval()- 设定模型为 评估模式(evaluation mode),关闭 dropout、BatchNorm 等影响推理的操作。
-
x, y = test_data[0][0], test_data[0][1]- 从测试集取出第一张图像 (
x) 和对应标签 (y)。 x是Tensor格式的 28x28 灰度图像。y是该图像的 真实类别索引(0-9)。
- 从测试集取出第一张图像 (
-
with torch.no_grad():- 禁用梯度计算,减少内存占用,加速推理。
-
x = x.to(device)- 确保数据和模型在同一设备(GPU/CPU)。
-
pred = model(x)- 通过神经网络 前向传播,得到预测结果
pred。
- 通过神经网络 前向传播,得到预测结果
-
predicted, actual = classes[pred[0].argmax(0)], classes[y]pred[0]是 logits(未归一化的得分)。argmax(0)取最大值的索引(预测类别)。classes[...]获取该索引对应的类别名称。y是真实类别的索引,对应classes[y]。
-
打印结果
Predicted: "Ankle boot", Actual: "Ankle boot"- 预测结果正确

- 预测结果正确
总结
这段代码的作用
- 保存模型 (
torch.save()):存储state_dict参数到model.pth。 - 加载模型 (
torch.load() + model.load_state_dict()):重建模型并填充参数。 - 设为评估模式 (
model.eval()):确保推理时行为正确。 - 进行预测:
- 获取一张测试图片。
- 经过模型推理,得到预测类别。
- 打印 预测结果 vs 真实类别。
这样,你的模型就可以随时加载并进行预测了! ![]()
问题:
为什么在 argmax 函数中要使用 0 参数?
答案:
argmax(0) 中的 0 表示计算 最大值的维度。这个参数告诉 PyTorch 在多维张量中沿哪个轴(维度)查找最大值。
理解 argmax 和维度:
argmax 用来返回最大值的索引。对于一个张量,你可以指定 dim(维度),让 argmax 在该维度上进行查找。例如,给定一个二维张量,你可以选择沿 行 或 列 查找最大值。
dim=0:沿着每一列查找最大值,即返回每一列最大值所在的行索引。dim=1:沿着每一行查找最大值,即返回每一行最大值所在的列索引。
示例:二维张量(多个样本的预测结果)
假设我们有一个包含两个样本的 2D 张量,每个样本的输出是 10 个类的分数(即 10 类的 logits)。
import torch
# 示例:二维张量(2 个样本,10 个类)
logits = torch.tensor([[0.2, 2.5, 1.3, 0.1, 0.6, 3.1, 0.3, 2.0, 0.4, 1.0],
[1.1, 0.5, 2.0, 0.3, 0.9, 0.4, 1.2, 2.7, 0.8, 1.5]])
# 沿着每一行(dim=1)查找最大值的索引
predictions_dim1 = torch.argmax(logits, dim=1)
print(predictions_dim1)
输出:
tensor([5, 7])
解释:
- 第一个样本(第一行):最大值是
3.1,对应的索引是5(类别 “Sandal”)。 - 第二个样本(第二行):最大值是
2.7,对应的索引是7(类别 “Sneaker”)。
因此,argmax(dim=1) 返回 [5, 7],表示两个样本的预测类别分别是 5 和 7。
如果我们改为 dim=0,结果会是:
# 沿着每一列(dim=0)查找最大值的索引
predictions_dim0 = torch.argmax(logits, dim=0)
print(predictions_dim0)
输出:
tensor([1, 0, 1, 0, 0, 0, 0, 1, 1, 1])
解释:
argmax(dim=0)会沿着 每一列 查找最大值,并返回每列最大值的行索引。- 第 1 列的最大值是
2.5,在第 1 行,所以返回1。 - 第 2 列的最大值是
2.0,在第 1 行,所以返回1。 - 以此类推。
- 第 1 列的最大值是
为什么在原始代码中使用 argmax(0)?
在原始代码中:
predicted = classes[pred[0].argmax(0)]
pred[0]是模型的输出(logits),它是一个一维张量,表示对于单个样本的每个类别的预测分数。- 使用
argmax(0)是为了从这个一维张量中找到最大值的索引(即预测的类别)。 - 对于一维张量来说,
dim=0就是找到该张量中最大值的索引。
总结:
argmax(0)是用于找到 最大值的索引,特别是在处理一维张量时,dim=0就是直接返回最大值的索引。- 对于 二维张量,
dim=0会沿着 每一列 查找最大值,而dim=1会沿着 每一行 查找最大值。
在分类任务中,我们通常使用 argmax(0) 来找到模型输出中最大分数的索引,从而确定预测的类别。