要训练模型,你需要一个损失函数loss function和一个优化器optimizer。
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
一个简单的训练函数如下:
def train(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
model.train()
for batch, (X, y) in enumerate(dataloader):
X, y = X.to(device), y.to(device)
# Compute prediction error
pred = model(X)
loss = loss_fn(pred, y)
# Backpropagation
loss.backward()
optimizer.step()
optimizer.zero_grad()
if batch % 100 == 0:
loss, current = loss.item(), (batch + 1) * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
这个 train()
函数的作用是什么?
这个函数用于训练神经网络,它的核心流程是:
- 把模型设为训练模式 (
model.train()
)。 - 遍历训练数据 并将其分批(batch)送入模型。
- 计算预测结果和损失 (
loss_fn(pred, y)
)。 - 反向传播 (
loss.backward()
),计算梯度。 - 更新模型参数 (
optimizer.step()
)。 - 清除旧的梯度信息 (
optimizer.zero_grad()
)。 - 每 100 个 batch 记录一次损失值。
代码逐行解析
1. 定义训练函数
def train(dataloader, model, loss_fn, optimizer):
参数解释:
dataloader
:数据加载器,把训练数据分成小批次(batch)。model
:要训练的神经网络模型。loss_fn
:损失函数(例如CrossEntropyLoss
),用于衡量预测结果和真实标签的差距。optimizer
:优化器(例如SGD
、Adam
),用于调整模型参数,使损失降低。
2. 计算训练集大小
size = len(dataloader.dataset)
- 获取整个训练集的数据总数,以便后面计算进度。
3. 设定模型为训练模式
model.train()
- 让模型进入训练模式,这会影响BatchNorm 和 Dropout 层的行为:
- BatchNorm:会计算新的均值和方差。
- Dropout:会随机丢弃一部分神经元,提高泛化能力。
4. 遍历训练数据
for batch, (X, y) in enumerate(dataloader):
dataloader
会自动把数据拆成多个小 batch 并逐个取出。X
:当前 batch 的输入数据。y
:当前 batch 的真实标签。
5. 把数据转移到 GPU(如果可用)
X, y = X.to(device), y.to(device)
- 让
X
和y
进入 GPU 或 CPU,加快计算速度。
6. 计算预测结果
pred = model(X)
- 让
X
通过模型,得到预测结果pred
。
7. 计算损失
loss = loss_fn(pred, y)
- 使用损失函数
loss_fn
来计算当前 batch 的损失值。
8. 反向传播计算梯度
loss.backward()
- 计算损失对每个参数的导数(梯度),用于优化参数。
9. 更新模型参数
optimizer.step()
- 使用优化器更新模型的参数:
- SGD/Adam 这些优化器会读取梯度信息,然后更新参数,使损失下降。
10. 清除梯度
optimizer.zero_grad()
- 由于 PyTorch 默认会累积梯度,所以每次更新完参数后,需要清空旧的梯度。
11. 打印训练进度
if batch % 100 == 0:
loss, current = loss.item(), (batch + 1) * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
- 每 100 个 batch 打印一次:
loss.item()
获取数值损失值。current
计算当前已经处理的数据量。- 显示训练进度,如:
loss: 0.325678 [6400/60000]
重点总结
步骤 | 作用 |
---|---|
1. 设置训练模式 | model.train() |
2. 遍历数据 | 取出 X, y 并移动到 device |
3. 前向传播 | pred = model(X) 计算预测结果 |
4. 计算损失 | loss = loss_fn(pred, y) |
5. 反向传播 | loss.backward() 计算梯度 |
6. 参数更新 | optimizer.step() 更新模型 |
7. 清空梯度 | optimizer.zero_grad() |
8. 记录损失 | print() 监控训练情况 |
希望这个详细解析对你有帮助!如果有疑问,欢迎继续交流!