我们除了train() (可以看上个帖子)来训练模型,还需要test()来检查评估模型训练的效果。这里使用了test dataset来测试模型训练情况:
def test(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
这个 test()
函数的作用是什么?
该函数评估模型的性能:
- 切换到评估模式 (
model.eval()
)。 - 关闭梯度计算 (
torch.no_grad()
) 以加速测试。 - 遍历测试数据,计算预测结果。
- 计算平均损失和准确率。
- 打印最终的测试结果。
代码逐行解析
1. 定义测试函数
def test(dataloader, model, loss_fn):
dataloader
:测试数据加载器。model
:已经训练好的神经网络。loss_fn
:损失函数,用于评估预测质量。
2. 计算测试集大小
size = len(dataloader.dataset)
num_batches = len(dataloader)
size
→ 测试集总样本数。num_batches
→ mini-batch 的数量。
3. 切换到评估模式
model.eval()
- 让模型进入评估模式,关闭 Dropout、BatchNorm 的动态行为。
4. 关闭梯度计算
with torch.no_grad():
- 测试时不需要计算梯度,所以关闭它来节省显存和加快推理速度。
5. 初始化测试损失和正确预测数
test_loss, correct = 0, 0
test_loss
→ 累积所有 batch 的损失值。correct
→ 记录预测正确的样本数。
6. 遍历测试数据
for X, y in dataloader:
- 取出
X
和y
进行推理。
7. 移动数据到 GPU
X, y = X.to(device), y.to(device)
- 让
X
和y
进入GPU 或 CPU 进行计算。
8. 计算预测结果
pred = model(X)
- 让
X
通过模型,得到预测结果pred
。
9. 计算损失
test_loss += loss_fn(pred, y).item()
- 计算当前 batch 的损失,并累积到
test_loss
。
10. 计算正确预测数
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
pred.argmax(1)
选出最大值的索引,即预测类别。- 计算与真实
y
相等的样本数量,并累加到correct
。
11. 计算平均损失和准确率
test_loss /= num_batches
correct /= size
- 计算平均损失。
- 计算准确率(正确数 / 总数)。
12. 打印测试结果
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
- 输出最终准确率和平均损失。
重点总结
评估模型性能。
不计算梯度,节省资源。
衡量泛化能力,防止过拟合。
希望这个详细解析对你有帮助!如果有疑问,欢迎继续交流!