这个帖子讲解了 如何保存和加载 PyTorch 模型,并用加载后的模型进行预测。我们分步解析它的作用。
1. 保存模型(Saving Models)
torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")
作用
model.state_dict()
:- 获取 模型的参数字典(
state_dict
),即weights
和bias
等信息。
- 获取 模型的参数字典(
torch.save(..., "model.pth")
:- 序列化(保存)模型参数,并存储到
"model.pth"
文件中。
- 序列化(保存)模型参数,并存储到
- 这样以后可以重新加载参数,无需重新训练模型。
model.pth
文件
这个文件只存储 模型的参数(weights/biases),不包含模型结构。
2. 加载模型(Loading Models)
model = NeuralNetwork().to(device)
model.load_state_dict(torch.load("model.pth", weights_only=True))
作用
NeuralNetwork().to(device)
- 重新创建模型结构,并移动到正确的计算设备(GPU/CPU)。
torch.load("model.pth")
- 加载
model.pth
文件中保存的参数。
- 加载
model.load_state_dict(...)
- 将加载的参数填充到新建的模型中。
- 参数
weights_only=True
- 仅加载权重参数,确保所有参数匹配。
如果 model.pth
文件与 NeuralNetwork()
结构不匹配,或pytorch版本不同,可能会报错。
在 PyTorch 中,weights_only
是 torch.load()
的一个无效参数。weights_only
是 PyTorch 1.9.0 新增的参数,通常用于 torch.load()
在加载模型时只加载模型的权重(而不是整个模型的其他信息)。如果你在加载模型时遇到错误,可能是因为你使用的 PyTorch 版本不支持这个参数。
删除 weights_only=True
将 torch.load("model.pth", weights_only=True)
这行代码修改为:
model.load_state_dict(torch.load("model.pth"))
这行代码将直接加载 model.pth
文件中的权重信息,无需使用 weights_only
。
输出:
<All keys matched successfully>
表示模型参数成功加载。
3. 用模型进行预测
classes = [
"T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
"Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]
作用
classes
是 FashionMNIST 数据集中 10 个类别的名称。- 这些类别索引对应模型的输出(0-9)。
4. 运行推理(Inference)
model.eval() # 设定模型为评估模式
x, y = test_data[0][0], test_data[0][1] # 取出测试集的第一张图片和标签
with torch.no_grad(): # 关闭梯度计算,加速推理
x = x.to(device) # 将数据移到相同的设备(CPU/GPU)
pred = model(x) # 进行前向传播,得到预测结果
predicted, actual = classes[pred[0].argmax(0)], classes[y] # 获取预测类别和真实类别
print(f'Predicted: "{predicted}", Actual: "{actual}"')
详细解析
-
model.eval()
- 设定模型为 评估模式(evaluation mode),关闭 dropout、BatchNorm 等影响推理的操作。
-
x, y = test_data[0][0], test_data[0][1]
- 从测试集取出第一张图像 (
x
) 和对应标签 (y
)。 x
是Tensor
格式的 28x28 灰度图像。y
是该图像的 真实类别索引(0-9)。
- 从测试集取出第一张图像 (
-
with torch.no_grad():
- 禁用梯度计算,减少内存占用,加速推理。
-
x = x.to(device)
- 确保数据和模型在同一设备(GPU/CPU)。
-
pred = model(x)
- 通过神经网络 前向传播,得到预测结果
pred
。
- 通过神经网络 前向传播,得到预测结果
-
predicted, actual = classes[pred[0].argmax(0)], classes[y]
pred[0]
是 logits(未归一化的得分)。argmax(0)
取最大值的索引(预测类别)。classes[...]
获取该索引对应的类别名称。y
是真实类别的索引,对应classes[y]
。
-
打印结果
Predicted: "Ankle boot", Actual: "Ankle boot"
- 预测结果正确
- 预测结果正确
总结
这段代码的作用
- 保存模型 (
torch.save()
):存储state_dict
参数到model.pth
。 - 加载模型 (
torch.load() + model.load_state_dict()
):重建模型并填充参数。 - 设为评估模式 (
model.eval()
):确保推理时行为正确。 - 进行预测:
- 获取一张测试图片。
- 经过模型推理,得到预测类别。
- 打印 预测结果 vs 真实类别。
这样,你的模型就可以随时加载并进行预测了!
问题:
为什么在 argmax
函数中要使用 0
参数?
答案:
argmax(0)
中的 0
表示计算 最大值的维度。这个参数告诉 PyTorch 在多维张量中沿哪个轴(维度)查找最大值。
理解 argmax
和维度:
argmax
用来返回最大值的索引。对于一个张量,你可以指定 dim
(维度),让 argmax
在该维度上进行查找。例如,给定一个二维张量,你可以选择沿 行 或 列 查找最大值。
dim=0
:沿着每一列查找最大值,即返回每一列最大值所在的行索引。dim=1
:沿着每一行查找最大值,即返回每一行最大值所在的列索引。
示例:二维张量(多个样本的预测结果)
假设我们有一个包含两个样本的 2D 张量,每个样本的输出是 10 个类的分数(即 10 类的 logits)。
import torch
# 示例:二维张量(2 个样本,10 个类)
logits = torch.tensor([[0.2, 2.5, 1.3, 0.1, 0.6, 3.1, 0.3, 2.0, 0.4, 1.0],
[1.1, 0.5, 2.0, 0.3, 0.9, 0.4, 1.2, 2.7, 0.8, 1.5]])
# 沿着每一行(dim=1)查找最大值的索引
predictions_dim1 = torch.argmax(logits, dim=1)
print(predictions_dim1)
输出:
tensor([5, 7])
解释:
- 第一个样本(第一行):最大值是
3.1
,对应的索引是5
(类别 “Sandal”)。 - 第二个样本(第二行):最大值是
2.7
,对应的索引是7
(类别 “Sneaker”)。
因此,argmax(dim=1)
返回 [5, 7]
,表示两个样本的预测类别分别是 5
和 7
。
如果我们改为 dim=0
,结果会是:
# 沿着每一列(dim=0)查找最大值的索引
predictions_dim0 = torch.argmax(logits, dim=0)
print(predictions_dim0)
输出:
tensor([1, 0, 1, 0, 0, 0, 0, 1, 1, 1])
解释:
argmax(dim=0)
会沿着 每一列 查找最大值,并返回每列最大值的行索引。- 第 1 列的最大值是
2.5
,在第 1 行,所以返回1
。 - 第 2 列的最大值是
2.0
,在第 1 行,所以返回1
。 - 以此类推。
- 第 1 列的最大值是
为什么在原始代码中使用 argmax(0)
?
在原始代码中:
predicted = classes[pred[0].argmax(0)]
pred[0]
是模型的输出(logits),它是一个一维张量,表示对于单个样本的每个类别的预测分数。- 使用
argmax(0)
是为了从这个一维张量中找到最大值的索引(即预测的类别)。 - 对于一维张量来说,
dim=0
就是找到该张量中最大值的索引。
总结:
argmax(0)
是用于找到 最大值的索引,特别是在处理一维张量时,dim=0
就是直接返回最大值的索引。- 对于 二维张量,
dim=0
会沿着 每一列 查找最大值,而dim=1
会沿着 每一行 查找最大值。
在分类任务中,我们通常使用 argmax(0)
来找到模型输出中最大分数的索引,从而确定预测的类别。