U-Net实战教学(6)- 训练模型代码解读

实话说,这是阿图见过的写得最复杂的U-Net训练程序。不得不说,Claude 3.7写程序还是非常高级的,但这种级别的程序很可能让初学者望而却步。好在如今是AI时代,每个部分都可以让AI解释清楚,学起来倒也不算太麻烦。下面是完整的程序train.py:

import os
# Set environment variable to suppress OpenMP warnings
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import argparse

from model import UNet, BCEDiceLoss
from dataset import RetinaDataset, get_train_transform, get_validation_transform

def _dice_coefficient(pred, target, eps=1e-8):
    """Return the Dice overlap 2*sum(pred*target) / (sum(pred) + sum(target) + eps).

    Sums run over every element of the inputs, so passing a whole batch yields
    a pooled score rather than a mean of per-sample scores. ``eps`` keeps the
    division finite when both inputs are all zeros.
    """
    intersection = (pred * target).sum()
    return (2. * intersection) / (pred.sum() + target.sum() + eps)


def _save_checkpoint(path, epoch, model, optimizer, val_loss, dice_score):
    """Write the epoch index, model/optimizer state dicts and validation metrics to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': val_loss,
        'dice_score': dice_score
    }, path)


def _save_prediction_samples(model, loader, device, prediction_dir, epoch, max_samples=4):
    """Save input / ground-truth / prediction figures for the first batch of ``loader``.

    Writes up to ``max_samples`` PNGs named ``epoch_<epoch+1>_sample_<i+1>.png``
    into ``prediction_dir``. The caller is expected to have put the model in
    eval mode.
    """
    with torch.no_grad():
        vis_batch = next(iter(loader))
        vis_images = vis_batch['image'].to(device)
        vis_masks = vis_batch['mask'].to(device)
        vis_outputs = model(vis_images)
        # NOTE(review): thresholding at 0.5 assumes UNet returns probabilities
        # (sigmoid applied inside the model) -- confirm against model.py.
        vis_pred = (vis_outputs > 0.5).float()

        for i in range(min(max_samples, vis_images.size(0))):
            # Per-sample Dice, so each figure's title describes the sample it shows
            # rather than the whole batch.
            sample_dice = _dice_coefficient(vis_pred[i], vis_masks[i])

            fig, axes = plt.subplots(1, 3, figsize=(15, 5))
            panels = (
                (vis_images[i, 0], 'Input Image'),
                (vis_masks[i, 0], 'Ground Truth'),
                (vis_outputs[i, 0], f'Prediction (Dice: {sample_dice.item():.4f})'),
            )
            for ax, (tensor, title) in zip(axes, panels):
                ax.imshow(tensor.cpu().numpy(), cmap='gray')
                ax.set_title(title)
                ax.axis('off')

            plt.tight_layout()
            plt.savefig(os.path.join(prediction_dir, f'epoch_{epoch+1}_sample_{i+1}.png'))
            plt.close(fig)


def _evaluate(model, loader, criterion, device):
    """Return (mean loss, mean Dice) over the batches of ``loader``.

    Dice is computed per batch (pooled over the batch) and then averaged over
    batches. The caller is expected to have put the model in eval mode.
    """
    total_loss = 0.0
    total_dice = 0.0
    with torch.no_grad():
        for batch in loader:
            images = batch['image'].to(device)
            masks = batch['mask'].to(device)

            outputs = model(images)
            total_loss += criterion(outputs, masks).item()

            pred = (outputs > 0.5).float()
            total_dice += _dice_coefficient(pred, masks).item()

    num_batches = len(loader)
    return total_loss / num_batches, total_dice / num_batches


def _plot_training_curves(train_losses, val_losses, dice_scores, path):
    """Save a two-panel figure (train/val loss, validation Dice) indexed by epoch to ``path``."""
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Validation Loss')

    plt.subplot(1, 2, 2)
    plt.plot(dice_scores, label='Dice Score')
    plt.xlabel('Epochs')
    plt.ylabel('Dice Score')
    plt.legend()
    plt.title('Validation Dice Score')

    plt.tight_layout()
    plt.savefig(path)
    plt.close()


def train_model(args):
    """Train a U-Net on the retina dataset and return the trained model.

    Reads from ``args``: data_dir, output_dir, batch_size, epochs,
    learning_rate, val_split, num_workers, bilinear.

    Side effects: creates ``<output_dir>/checkpoints`` and
    ``<output_dir>/predictions``. Writes ``best_model.pth`` whenever the
    validation loss improves, ``model_epoch_<k>.pth`` every 10 epochs,
    prediction figures every 5 epochs (epochs 1, 6, 11, ...) and
    ``training_curves.png`` at the end.

    Raises:
        ValueError: if ``args.val_split`` leaves the training or the
            validation subset empty.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    checkpoint_dir = os.path.join(args.output_dir, 'checkpoints')
    prediction_dir = os.path.join(args.output_dir, 'predictions')
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(prediction_dir, exist_ok=True)

    img_dir = os.path.join(args.data_dir, 'training/images')
    mask_dir = os.path.join(args.data_dir, 'training/masks')

    # random_split() returns Subset views that share ONE underlying dataset, so
    # overriding that dataset's transform for validation would also strip the
    # augmentation from the training split. Build one dataset object per
    # transform and index both with the same random split instead.
    # NOTE(review): this assumes RetinaDataset enumerates files in the same
    # deterministic order on every construction -- confirm in dataset.py.
    train_base = RetinaDataset(img_dir=img_dir, mask_dir=mask_dir, transform=get_train_transform())
    val_base = RetinaDataset(img_dir=img_dir, mask_dir=mask_dir, transform=get_validation_transform())

    val_size = int(len(train_base) * args.val_split)
    train_size = len(train_base) - val_size
    if train_size == 0 or val_size == 0:
        raise ValueError(
            f"val_split={args.val_split} on {len(train_base)} samples leaves "
            f"{train_size} training and {val_size} validation samples; both must be > 0"
        )

    train_dataset, val_view = random_split(train_base, [train_size, val_size])
    val_dataset = Subset(val_base, val_view.indices)

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,  # fixed order, so the visualized first batch is the same every time
        num_workers=args.num_workers
    )

    print(f"Training on {len(train_dataset)} samples, validating on {len(val_dataset)} samples")

    model = UNet(n_channels=1, n_classes=1, bilinear=args.bilinear)
    model.to(device)

    criterion = BCEDiceLoss(weight_bce=0.5, weight_dice=0.5)
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    # `verbose` is deprecated in recent PyTorch; LR reductions are printed below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=5
    )

    best_val_loss = float('inf')
    train_losses = []
    val_losses = []
    dice_scores = []

    for epoch in range(args.epochs):
        model.train()
        epoch_loss = 0.0

        with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{args.epochs}", unit="batch") as pbar:
            for batch in train_loader:
                images = batch['image'].to(device)
                masks = batch['mask'].to(device)

                outputs = model(images)
                loss = criterion(outputs, masks)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                epoch_loss += batch_loss
                pbar.update(1)
                pbar.set_postfix(loss=batch_loss)

        train_loss = epoch_loss / len(train_loader)
        train_losses.append(train_loss)

        model.eval()
        if epoch % 5 == 0:
            _save_prediction_samples(model, val_loader, device, prediction_dir, epoch)
        val_loss, dice_score = _evaluate(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        dice_scores.append(dice_score)

        previous_lr = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        if current_lr < previous_lr:
            print(f"Reducing learning rate to {current_lr:.2e}")

        print(f"Epoch {epoch+1}/{args.epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Dice Score: {dice_score:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            _save_checkpoint(os.path.join(checkpoint_dir, 'best_model.pth'),
                             epoch, model, optimizer, val_loss, dice_score)
            print(f"Saved best model with validation loss: {val_loss:.4f}")

        if (epoch + 1) % 10 == 0:
            _save_checkpoint(os.path.join(checkpoint_dir, f'model_epoch_{epoch+1}.pth'),
                             epoch, model, optimizer, val_loss, dice_score)

    _plot_training_curves(train_losses, val_losses, dice_scores,
                          os.path.join(args.output_dir, 'training_curves.png'))

    print("Training completed!")
    return model


if __name__ == "__main__":
    # Build the command-line interface; each option becomes an attribute on the
    # namespace handed to train_model().
    cli = argparse.ArgumentParser(description='Train U-Net for retina segmentation')
    typed_options = (
        ('--data_dir', str, '.', 'Path to data directory'),
        ('--output_dir', str, 'output', 'Path to output directory'),
        ('--batch_size', int, 4, 'Batch size'),
        ('--epochs', int, 50, 'Number of epochs'),
        ('--learning_rate', float, 0.001, 'Learning rate'),
        ('--val_split', float, 0.2, 'Validation split ratio'),
        ('--num_workers', int, 4, 'Number of workers for data loading'),
    )
    for flag, value_type, fallback, help_text in typed_options:
        cli.add_argument(flag, type=value_type, default=fallback, help=help_text)
    # Boolean switch: present on the command line -> True, absent -> False.
    cli.add_argument('--bilinear', action='store_true', help='Use bilinear upsampling')

    train_model(cli.parse_args())

下面是对 train.py 文件的详细解读,适合编程小白理解。这个文件主要用于训练一个基于 U-Net 模型的图像分割网络,特别是用于视网膜分割任务。我们将逐步分析代码的各个部分,并解释其中使用的编程技术。

1. 导入必要的库

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import argparse

from model import UNet, BCEDiceLoss
from dataset import RetinaDataset, get_train_transform, get_validation_transform
  • os: 用于与操作系统交互,比如创建文件夹。
  • torch: PyTorch库,用于深度学习。
  • torch.nn: 包含神经网络的构建模块。
  • torch.optim: 包含优化算法(如Adam)。
  • DataLoader: 用于加载数据集的工具。
  • random_split: 用于将数据集随机分割为训练集和验证集。
  • numpy: 用于数值计算。
  • matplotlib.pyplot: 用于绘图,展示训练过程中的损失和Dice分数。
  • tqdm: 用于显示进度条。
  • argparse: 用于处理命令行参数。

2. 定义训练模型的函数

def train_model(args):
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
  • train_model: 这是主要的训练函数,接收命令行参数。
  • device: 检查是否有可用的GPU,如果有则使用GPU,否则使用CPU。

3. 创建输出目录

os.makedirs(args.output_dir, exist_ok=True)
os.makedirs(os.path.join(args.output_dir, 'checkpoints'), exist_ok=True)
os.makedirs(os.path.join(args.output_dir, 'predictions'), exist_ok=True)
  • os.makedirs: 创建输出目录和子目录,用于保存模型检查点和预测结果。

4. 创建数据集

dataset = RetinaDataset(
    img_dir=os.path.join(args.data_dir, 'training/images'),
    mask_dir=os.path.join(args.data_dir, 'training/masks'),
    transform=get_train_transform()
)
  • RetinaDataset: 自定义的数据集类,用于加载图像和对应的掩膜(mask)。
  • transform: 对图像和掩膜施加的预处理/数据增强变换,这里传入的是 get_train_transform() 返回的训练用变换。

5. 数据集分割

val_size = int(len(dataset) * args.val_split)
train_size = len(dataset) - val_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
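  • random_split: 将数据集随机分割为训练集和验证集。注意:random_split 返回的两个子集(Subset)共享同一个底层数据集对象。因此,如果通过 val_dataset.dataset.transform 给验证集换用验证变换,训练集的变换也会被一起替换,训练时就不再做数据增强。要让两者使用不同的变换,应分别构建两个数据集对象,再按同一组随机索引各取一部分。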

6. 创建数据加载器

train_loader = DataLoader(
    train_dataset, 
    batch_size=args.batch_size, 
    shuffle=True, 
    num_workers=args.num_workers
)

val_loader = DataLoader(
    val_dataset, 
    batch_size=args.batch_size, 
    shuffle=False, 
    num_workers=args.num_workers
)
  • DataLoader: 用于按批次(batch_size)加载数据。num_workers 指定用多少个子进程并行读取数据,设为 0 则在主进程中加载。shuffle=True 表示每个 epoch 都打乱训练数据的顺序;验证集使用 shuffle=False,保持固定顺序。

7. 创建模型、损失函数和优化器

model = UNet(n_channels=1, n_classes=1, bilinear=args.bilinear)
model.to(device)

criterion = BCEDiceLoss(weight_bce=0.5, weight_dice=0.5)
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
  • UNet: 自定义的U-Net模型,用于图像分割。
  • BCEDiceLoss: 自定义的损失函数,结合了二元交叉熵损失和Dice损失。
  • optim.Adam: Adam优化器,用于更新模型参数。

8. 学习率调度器

scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    mode='min', 
    factor=0.1, 
    patience=5, 
    verbose=True
)
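  • ReduceLROnPlateau: 学习率调度器,当验证损失不再下降时降低学习率。mode='min' 表示监控的指标越小越好。当验证损失连续超过 5 个 epoch(patience=5)没有改善时,学习率会乘以 0.1(factor=0.1)。verbose=True 会在学习率变化时打印提示,但该参数在较新版本的 PyTorch 中已被弃用。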

9. 训练循环

for epoch in range(args.epochs):
    model.train()
    epoch_loss = 0
    
    with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{args.epochs}", unit="batch") as pbar:
        for batch in train_loader:
            images = batch['image'].to(device)
            masks = batch['mask'].to(device)
            
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks)
            
            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            pbar.update(1)
            pbar.set_postfix(loss=loss.item())
  • 训练循环: 遍历每个epoch,对每个批次进行前向传播、计算损失、反向传播和优化。optimizer.zero_grad() 先清空上一步累积的梯度,loss.backward() 计算梯度,optimizer.step() 再根据梯度更新参数。
  • tqdm: 显示训练进度条。

10. 验证阶段

model.eval()
val_loss = 0
dice_score = 0

with torch.no_grad():
    for batch in val_loader:
        images = batch['image'].to(device)
        masks = batch['mask'].to(device)
        
        outputs = model(images)
        loss = criterion(outputs, masks)
        val_loss += loss.item()
        
        # Calculate Dice score
        pred = (outputs > 0.5).float()
        intersection = (pred * masks).sum()
        dice = (2. * intersection) / (pred.sum() + masks.sum() + 1e-8)
        dice_score += dice.item()
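  • 验证阶段: 在验证集上评估模型性能,计算验证损失和Dice分数。model.eval() 把模型切换到评估模式,torch.no_grad() 关闭梯度计算以节省显存和时间。模型输出以 0.5 为阈值二值化得到预测掩膜 P,再用 Dice = 2|P∩G| / (|P| + |G|) 衡量它与真实掩膜 G 的重合程度,取值越接近 1 越好。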

11. 保存模型

if val_loss < best_val_loss:
    best_val_loss = val_loss
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': val_loss,
        'dice_score': dice_score
    }, os.path.join(args.output_dir, 'checkpoints', 'best_model.pth'))
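  • torch.save: 只有当验证损失低于此前最好的结果时,才保存 best_model.pth。保存内容是一个字典,包括当前 epoch、模型参数(model.state_dict())、优化器状态、验证损失和 Dice 分数,以便后续加载模型做推理或恢复训练。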

12. 绘制训练曲线

plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.title('Training and Validation Loss')
  • matplotlib: 用于绘制训练和验证损失的曲线,帮助可视化模型的训练过程。

13. 主程序入口

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train U-Net for retina segmentation')
    parser.add_argument('--data_dir', type=str, default='.', help='Path to data directory')
    parser.add_argument('--output_dir', type=str, default='output', help='Path to output directory')
    parser.add_argument('--batch_size', type=int, default=4, help='Batch size')
    parser.add_argument('--epochs', type=int, default=50, help='Number of epochs')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--val_split', type=float, default=0.2, help='Validation split ratio')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of workers for data loading')
    parser.add_argument('--bilinear', action='store_true', help='Use bilinear upsampling')
    
    args = parser.parse_args()
    
    train_model(args)
  • argparse: 处理命令行参数,使得用户可以在运行程序时指定数据目录、输出目录、批量大小等参数。
  • train_model(args): 调用训练函数,开始训练过程。

总结

这个 train.py 文件展示了如何使用 PyTorch 进行深度学习模型的训练,包括数据加载、模型构建、损失计算、优化、验证和结果可视化等步骤。通过逐步分析代码,编程小白可以理解深度学习训练的基本流程和相关技术。