Pytorch入门(3) - 快速上手模型训练train()

要训练模型,你需要一个损失函数loss function和一个优化器optimizer。

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

一个简单的训练函数如下:

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
  
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)
  
        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
  
        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]") 

:rocket: 这个 train() 函数的作用是什么?

这个函数用于训练神经网络,它的核心流程是:

  1. 把模型设为训练模式 (model.train())。
  2. 遍历训练数据 并将其分批(batch)送入模型。
  3. 计算预测结果和损失 (loss_fn(pred, y))。
  4. 反向传播 (loss.backward()),计算梯度。
  5. 更新模型参数 (optimizer.step())。
  6. 清除旧的梯度信息 (optimizer.zero_grad())。
  7. 每 100 个 batch 记录一次损失值

:mag: 代码逐行解析

:small_blue_diamond: 1. 定义训练函数

def train(dataloader, model, loss_fn, optimizer):

参数解释:

  • dataloader:数据加载器,把训练数据分成小批次(batch)。
  • model:要训练的神经网络模型。
  • loss_fn:损失函数(例如 CrossEntropyLoss),用于衡量预测结果和真实标签的差距。
  • optimizer:优化器(例如 SGDAdam),用于调整模型参数,使损失降低。

:small_blue_diamond: 2. 计算训练集大小

size = len(dataloader.dataset)
  • 获取整个训练集的数据总数,以便后面计算进度。

:small_blue_diamond: 3. 设定模型为训练模式

model.train()
  • 让模型进入训练模式,这会影响BatchNormDropout 层的行为:
    • BatchNorm:会计算新的均值和方差。
    • Dropout:会随机丢弃一部分神经元,提高泛化能力。

:small_blue_diamond: 4. 遍历训练数据

for batch, (X, y) in enumerate(dataloader):
  • dataloader自动把数据拆成多个小 batch 并逐个取出。
  • X:当前 batch 的输入数据。
  • y:当前 batch 的真实标签。

:small_blue_diamond: 5. 把数据转移到 GPU(如果可用)

X, y = X.to(device), y.to(device)
  • Xy 进入 GPUCPU,加快计算速度。

:small_blue_diamond: 6. 计算预测结果

pred = model(X)
  • X 通过模型,得到预测结果 pred

:small_blue_diamond: 7. 计算损失

loss = loss_fn(pred, y)
  • 使用损失函数 loss_fn 来计算当前 batch 的损失值。

:small_blue_diamond: 8. 反向传播计算梯度

loss.backward()
  • 计算损失对每个参数的导数(梯度),用于优化参数。

:small_blue_diamond: 9. 更新模型参数

optimizer.step()
  • 使用优化器更新模型的参数:
    • SGD/Adam 这些优化器会读取梯度信息,然后更新参数,使损失下降。

:small_blue_diamond: 10. 清除梯度

optimizer.zero_grad()
  • 由于 PyTorch 默认会累积梯度,所以每次更新完参数后,需要清空旧的梯度。

:small_blue_diamond: 11. 打印训练进度

if batch % 100 == 0:
    loss, current = loss.item(), (batch + 1) * len(X)
    print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
  • 每 100 个 batch 打印一次:
    • loss.item() 获取数值损失值。
    • current 计算当前已经处理的数据量。
    • 显示训练进度,如:
      loss: 0.325678  [6400/60000]
      

:hammer_and_wrench: 重点总结

步骤 作用
1. 设置训练模式 model.train()
2. 遍历数据 取出 X, y 并移动到 device
3. 前向传播 pred = model(X) 计算预测结果
4. 计算损失 loss = loss_fn(pred, y)
5. 反向传播 loss.backward() 计算梯度
6. 参数更新 optimizer.step() 更新模型
7. 清空梯度 optimizer.zero_grad()
8. 记录损失 print() 监控训练情况

:rocket: 希望这个详细解析对你有帮助!如果有疑问,欢迎继续交流! :dart: