Pytorch入门(4)- 快速上手模型训练test()

我们除了train() (可以看上个帖子)来训练模型,还需要test()来检查评估模型训练的效果。这里使用了test dataset来测试模型训练情况:

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

:rocket: 这个 test() 函数的作用是什么?

该函数评估模型的性能

  1. 切换到评估模式 (model.eval())。
  2. 关闭梯度计算 (torch.no_grad()) 以加速测试。
  3. 遍历测试数据,计算预测结果。
  4. 计算平均损失和准确率
  5. 打印最终的测试结果

:mag: 代码逐行解析

:small_blue_diamond: 1. 定义测试函数

def test(dataloader, model, loss_fn):
  • dataloader:测试数据加载器。
  • model:已经训练好的神经网络。
  • loss_fn:损失函数,用于评估预测质量。

:small_blue_diamond: 2. 计算测试集大小

size = len(dataloader.dataset)
num_batches = len(dataloader)
  • size → 测试集总样本数。
  • num_batches → mini-batch 的数量。

:small_blue_diamond: 3. 切换到评估模式

model.eval()
  • 让模型进入评估模式,关闭 Dropout、BatchNorm 的动态行为。

:small_blue_diamond: 4. 关闭梯度计算

with torch.no_grad():
  • 测试时不需要计算梯度,所以关闭它来节省显存加快推理速度

:small_blue_diamond: 5. 初始化测试损失和正确预测数

test_loss, correct = 0, 0
  • test_loss → 累积所有 batch 的损失值。
  • correct → 记录预测正确的样本数。

:small_blue_diamond: 6. 遍历测试数据

for X, y in dataloader:
  • 取出 Xy 进行推理。

:small_blue_diamond: 7. 移动数据到 GPU

X, y = X.to(device), y.to(device)
  • Xy 进入GPU 或 CPU 进行计算。

:small_blue_diamond: 8. 计算预测结果

pred = model(X)
  • X 通过模型,得到预测结果 pred

:small_blue_diamond: 9. 计算损失

test_loss += loss_fn(pred, y).item()
  • 计算当前 batch 的损失,并累积到 test_loss

:small_blue_diamond: 10. 计算正确预测数

correct += (pred.argmax(1) == y).type(torch.float).sum().item()
  • pred.argmax(1) 选出最大值的索引,即预测类别。
  • 计算与真实 y 相等的样本数量,并累加到 correct

:small_blue_diamond: 11. 计算平均损失和准确率

test_loss /= num_batches
correct /= size
  • 计算平均损失
  • 计算准确率(正确数 / 总数)。

:small_blue_diamond: 12. 打印测试结果

print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
  • 输出最终准确率平均损失

:dart: 重点总结

:white_check_mark: 评估模型性能
:white_check_mark: 不计算梯度,节省资源
:white_check_mark: 衡量泛化能力,防止过拟合


:rocket: 希望这个详细解析对你有帮助!如果有疑问,欢迎继续交流! :dart: