实话说,这是阿图见过的写得最麻烦的U-Net训练程序,不得不说Claude3.7写程序还是非常高级的,这种级别的程序很可能让初学者望而却步。但好在AI时代,每个部分都可以让AI解释清楚,学起来倒也不算太麻烦。下面是完整的程序train.py:
import os

# Set environment variable to suppress OpenMP warnings (must be set before
# libraries that load OpenMP, e.g. torch/numpy, are imported).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import argparse
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split
from tqdm import tqdm

from model import UNet, BCEDiceLoss
from dataset import RetinaDataset, get_train_transform, get_validation_transform
def _dice_coefficient(pred, target, eps=1e-8):
    """Return the Dice coefficient between a binarized prediction and a mask.

    Both tensors are flattened implicitly by the elementwise product / sums;
    `eps` guards against division by zero when both are empty.
    """
    intersection = (pred * target).sum()
    return (2. * intersection) / (pred.sum() + target.sum() + eps)


def _save_checkpoint(path, epoch, model, optimizer, val_loss, dice_score):
    """Serialize model/optimizer state plus current metrics to `path`."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': val_loss,
        'dice_score': dice_score
    }, path)


def _save_visualizations(model, val_loader, device, output_dir, epoch):
    """Save side-by-side (input / ground truth / prediction) PNGs.

    Uses up to 4 samples from the first validation batch.  The caller is
    expected to hold torch.no_grad() and have the model in eval() mode.
    """
    vis_batch = next(iter(val_loader))
    vis_images = vis_batch['image'].to(device)
    vis_masks = vis_batch['mask'].to(device)
    vis_outputs = model(vis_images)
    # NOTE(review): outputs are thresholded at 0.5 directly.  If the model
    # emits raw logits (BCEWithLogits-style loss), a sigmoid belongs here
    # first -- confirm against model.py / BCEDiceLoss.
    vis_pred = (vis_outputs > 0.5).float()
    vis_dice = _dice_coefficient(vis_pred, vis_masks)
    for i in range(min(4, vis_images.size(0))):
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        # Original image (single channel, hence index [i, 0]).
        axes[0].imshow(vis_images[i, 0].cpu().numpy(), cmap='gray')
        axes[0].set_title('Input Image')
        axes[0].axis('off')
        # Ground truth mask.
        axes[1].imshow(vis_masks[i, 0].cpu().numpy(), cmap='gray')
        axes[1].set_title('Ground Truth')
        axes[1].axis('off')
        # Raw (un-thresholded) prediction map.
        axes[2].imshow(vis_outputs[i, 0].cpu().numpy(), cmap='gray')
        axes[2].set_title(f'Prediction (Dice: {vis_dice.item():.4f})')
        axes[2].axis('off')
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'predictions',
                                 f'epoch_{epoch+1}_sample_{i+1}.png'))
        plt.close(fig)


def _plot_training_curves(train_losses, val_losses, dice_scores, output_dir):
    """Save loss and Dice-score curves to <output_dir>/training_curves.png."""
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Validation Loss')
    plt.subplot(1, 2, 2)
    plt.plot(dice_scores, label='Dice Score')
    plt.xlabel('Epochs')
    plt.ylabel('Dice Score')
    plt.legend()
    plt.title('Validation Dice Score')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'training_curves.png'))
    plt.close()


def train_model(args):
    """Train a U-Net for retina vessel segmentation.

    Args:
        args: argparse.Namespace with fields data_dir, output_dir,
            batch_size, epochs, learning_rate, val_split, num_workers
            and bilinear (see the CLI definition at the bottom of this file).

    Returns:
        The trained model (weights as of the final epoch; the best
        checkpoint by validation loss is saved to disk separately).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Output layout: checkpoints/ for model weights, predictions/ for
    # periodic visualizations of validation samples.
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'checkpoints'), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'predictions'), exist_ok=True)

    # BUGFIX: the original code called random_split() on a single dataset and
    # then overwrote `val_dataset.dataset.transform`.  Both Subsets returned
    # by random_split share the SAME underlying dataset object, so that
    # assignment also silently replaced the training augmentation with the
    # validation transform.  Instead, build two dataset views over the same
    # files -- one per transform -- and split by index.
    train_base = RetinaDataset(
        img_dir=os.path.join(args.data_dir, 'training/images'),
        mask_dir=os.path.join(args.data_dir, 'training/masks'),
        transform=get_train_transform()
    )
    val_base = RetinaDataset(
        img_dir=os.path.join(args.data_dir, 'training/images'),
        mask_dir=os.path.join(args.data_dir, 'training/masks'),
        transform=get_validation_transform()
    )
    val_size = int(len(train_base) * args.val_split)
    indices = torch.randperm(len(train_base)).tolist()
    train_dataset = Subset(train_base, indices[val_size:])
    val_dataset = Subset(val_base, indices[:val_size])

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers
    )
    print(f"Training on {len(train_dataset)} samples, validating on {len(val_dataset)} samples")

    # Single-channel input (grayscale fundus image), single-channel output
    # (vessel probability/score map).
    model = UNet(n_channels=1, n_classes=1, bilinear=args.bilinear)
    model.to(device)

    criterion = BCEDiceLoss(weight_bce=0.5, weight_dice=0.5)
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    # Halve-by-10 the learning rate after 5 epochs without val-loss
    # improvement.  (`verbose=True` was dropped: deprecated and removed in
    # recent PyTorch releases.)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=5
    )

    best_val_loss = float('inf')
    train_losses = []
    val_losses = []
    dice_scores = []

    for epoch in range(args.epochs):
        # ---- Training phase ----
        model.train()
        epoch_loss = 0
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{args.epochs}", unit="batch") as pbar:
            for batch in train_loader:
                images = batch['image'].to(device)
                masks = batch['mask'].to(device)
                outputs = model(images)
                loss = criterion(outputs, masks)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                pbar.update(1)
                pbar.set_postfix(loss=loss.item())
        train_loss = epoch_loss / len(train_loader)
        train_losses.append(train_loss)

        # ---- Validation phase ----
        model.eval()
        val_loss = 0
        dice_score = 0
        with torch.no_grad():
            # Every 5 epochs, dump prediction images for visual inspection.
            if epoch % 5 == 0:
                _save_visualizations(model, val_loader, device,
                                     args.output_dir, epoch)
            for batch in val_loader:
                images = batch['image'].to(device)
                masks = batch['mask'].to(device)
                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item()
                # NOTE(review): same caveat as in _save_visualizations --
                # this thresholds raw outputs; confirm whether a sigmoid is
                # needed for logit outputs.
                pred = (outputs > 0.5).float()
                dice_score += _dice_coefficient(pred, masks).item()
        val_loss = val_loss / len(val_loader)
        val_losses.append(val_loss)
        dice_score = dice_score / len(val_loader)
        dice_scores.append(dice_score)

        # Scheduler steps on the validation loss (mode='min').
        scheduler.step(val_loss)

        print(f"Epoch {epoch+1}/{args.epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Dice Score: {dice_score:.4f}")

        # Track the best model by validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            _save_checkpoint(
                os.path.join(args.output_dir, 'checkpoints', 'best_model.pth'),
                epoch, model, optimizer, val_loss, dice_score)
            print(f"Saved best model with validation loss: {val_loss:.4f}")

        # Periodic checkpoint every 10 epochs, regardless of quality.
        if (epoch + 1) % 10 == 0:
            _save_checkpoint(
                os.path.join(args.output_dir, 'checkpoints', f'model_epoch_{epoch+1}.pth'),
                epoch, model, optimizer, val_loss, dice_score)

    _plot_training_curves(train_losses, val_losses, dice_scores, args.output_dir)
    print("Training completed!")
    return model
if __name__ == "__main__":
    # Command-line entry point: parse hyperparameters / paths, then train.
    cli = argparse.ArgumentParser(description='Train U-Net for retina segmentation')
    cli.add_argument('--data_dir', type=str, default='.', help='Path to data directory')
    cli.add_argument('--output_dir', type=str, default='output', help='Path to output directory')
    cli.add_argument('--batch_size', type=int, default=4, help='Batch size')
    cli.add_argument('--epochs', type=int, default=50, help='Number of epochs')
    cli.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate')
    cli.add_argument('--val_split', type=float, default=0.2, help='Validation split ratio')
    cli.add_argument('--num_workers', type=int, default=4, help='Number of workers for data loading')
    cli.add_argument('--bilinear', action='store_true', help='Use bilinear upsampling')
    train_model(cli.parse_args())
下面是对 train.py 文件的详细解读,适合编程小白理解。这个文件主要用于训练一个基于 U-Net 模型的图像分割网络,特别是用于视网膜分割任务。我们将逐步分析代码的各个部分,并解释其中使用的编程技术。
1. 导入必要的库
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import argparse
from model import UNet, BCEDiceLoss
from dataset import RetinaDataset, get_train_transform, get_validation_transform
- os: 用于与操作系统交互,比如创建文件夹。
- torch: PyTorch库,用于深度学习。
- torch.nn: 包含神经网络的构建模块。
- torch.optim: 包含优化算法(如Adam)。
- DataLoader: 用于加载数据集的工具。
- random_split: 用于将数据集随机分割为训练集和验证集。
- numpy: 用于数值计算。
- matplotlib.pyplot: 用于绘图,展示训练过程中的损失和Dice分数。
- tqdm: 用于显示进度条。
- argparse: 用于处理命令行参数。
- model / dataset: 项目自定义模块,分别提供 UNet 模型与 BCEDiceLoss 损失函数、数据集类与数据变换函数。(注:time 虽被导入,但在本程序中并未实际使用。)
2. 定义训练模型的函数
def train_model(args):
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
- train_model: 这是主要的训练函数,接收命令行参数。
- device: 检查是否有可用的GPU,如果有则使用GPU,否则使用CPU。
3. 创建输出目录
os.makedirs(args.output_dir, exist_ok=True)
os.makedirs(os.path.join(args.output_dir, 'checkpoints'), exist_ok=True)
os.makedirs(os.path.join(args.output_dir, 'predictions'), exist_ok=True)
- os.makedirs: 创建输出目录和子目录,用于保存模型检查点和预测结果。
4. 创建数据集
dataset = RetinaDataset(
img_dir=os.path.join(args.data_dir, 'training/images'),
mask_dir=os.path.join(args.data_dir, 'training/masks'),
transform=get_train_transform()
)
- RetinaDataset: 自定义的数据集类,用于加载图像和对应的掩膜(mask)。
- transform: 数据增强的变换操作。
5. 数据集分割
val_size = int(len(dataset) * args.val_split)
train_size = len(dataset) - val_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
- random_split: 将数据集随机分割为训练集和验证集。需要注意的是,程序随后通过 val_dataset.dataset.transform 覆盖验证集的变换,但 random_split 返回的两个子集共享同一个底层数据集对象,这一赋值会同时改变训练集使用的数据增强——这是使用 random_split 时的常见隐患。
6. 创建数据加载器
train_loader = DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers
)
val_loader = DataLoader(
val_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers
)
- DataLoader: 用于批量加载数据,支持多进程加载(通过 num_workers 参数)。
7. 创建模型、损失函数和优化器
model = UNet(n_channels=1, n_classes=1, bilinear=args.bilinear)
model.to(device)
criterion = BCEDiceLoss(weight_bce=0.5, weight_dice=0.5)
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
- UNet: 自定义的U-Net模型,用于图像分割。
- BCEDiceLoss: 自定义的损失函数,结合了二元交叉熵损失和Dice损失。
- optim.Adam: Adam优化器,用于更新模型参数。
8. 学习率调度器
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode='min',
factor=0.1,
patience=5,
verbose=True
)
- ReduceLROnPlateau: 学习率调度器,当验证损失不再下降时,降低学习率。
9. 训练循环
for epoch in range(args.epochs):
model.train()
epoch_loss = 0
with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{args.epochs}", unit="batch") as pbar:
for batch in train_loader:
images = batch['image'].to(device)
masks = batch['mask'].to(device)
# Forward pass
outputs = model(images)
loss = criterion(outputs, masks)
# Backward pass and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_loss += loss.item()
pbar.update(1)
pbar.set_postfix(loss=loss.item())
- 训练循环: 遍历每个epoch,进行前向传播、计算损失、反向传播和优化。
- tqdm: 显示训练进度条。
10. 验证阶段
model.eval()
val_loss = 0
dice_score = 0
with torch.no_grad():
for batch in val_loader:
images = batch['image'].to(device)
masks = batch['mask'].to(device)
outputs = model(images)
loss = criterion(outputs, masks)
val_loss += loss.item()
# Calculate Dice score
pred = (outputs > 0.5).float()
intersection = (pred * masks).sum()
dice = (2. * intersection) / (pred.sum() + masks.sum() + 1e-8)
dice_score += dice.item()
- 验证阶段: 在验证集上评估模型性能,计算验证损失和Dice分数。
11. 保存模型
if val_loss < best_val_loss:
best_val_loss = val_loss
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': val_loss,
'dice_score': dice_score
}, os.path.join(args.output_dir, 'checkpoints', 'best_model.pth'))
- torch.save: 保存模型的状态字典,以便后续加载和使用。
12. 绘制训练曲线
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.title('Training and Validation Loss')
- matplotlib: 用于绘制训练和验证损失的曲线,帮助可视化模型的训练过程。
13. 主程序入口
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Train U-Net for retina segmentation')
parser.add_argument('--data_dir', type=str, default='.', help='Path to data directory')
parser.add_argument('--output_dir', type=str, default='output', help='Path to output directory')
parser.add_argument('--batch_size', type=int, default=4, help='Batch size')
parser.add_argument('--epochs', type=int, default=50, help='Number of epochs')
parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate')
parser.add_argument('--val_split', type=float, default=0.2, help='Validation split ratio')
parser.add_argument('--num_workers', type=int, default=4, help='Number of workers for data loading')
parser.add_argument('--bilinear', action='store_true', help='Use bilinear upsampling')
args = parser.parse_args()
train_model(args)
- argparse: 处理命令行参数,使得用户可以在运行程序时指定数据目录、输出目录、批量大小等参数。
- train_model(args): 调用训练函数,开始训练过程。
总结
这个 train.py 文件展示了如何使用 PyTorch 进行深度学习模型的训练,包括数据加载、模型构建、损失计算、优化、验证和结果可视化等步骤。通过逐步分析代码,编程小白可以理解深度学习训练的基本流程和相关技术。