U-Net实战教学(7)- 使用模型(Inference)

在做完了模型的训练,我们需要使用模型对不在训练集中的图像进行测试和使用,这个过程就是inference. 下面是AI辅助编程写下的Inference的程序。

import os
# Set environment variable to suppress OpenMP warnings
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import argparse
from tqdm import tqdm
import glob

from model import UNet
from dataset import RetinaDataset, get_test_transform


def inference(args):
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)
    
    # Load model
    model = UNet(n_channels=1, n_classes=1, bilinear=args.bilinear)
    
    # Load checkpoint
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    
    print(f"Loaded model from {args.checkpoint}")
    print(f"Model trained for {checkpoint['epoch']+1} epochs")
    print(f"Validation loss: {checkpoint['loss']:.4f}, Dice score: {checkpoint.get('dice_score', 'N/A')}")
    
    # Create test dataset
    test_dataset = RetinaDataset(
        img_dir=os.path.join(args.data_dir, 'test/images'),
        transform=get_test_transform(),
        is_test=True
    )
    
    # Process each test image
    with torch.no_grad():
        for idx in tqdm(range(len(test_dataset)), desc="Processing test images"):
            sample = test_dataset[idx]
            # Ensure image is float32 (not float64/double)
            image = sample['image'].float().unsqueeze(0).to(device)  # Add batch dimension
            filename = sample['filename']
            
            # Forward pass
            output = model(image)
            
            # Convert to binary mask
            pred_mask = (output > args.threshold).float()
            
            # Save results
            fig, axes = plt.subplots(1, 2, figsize=(12, 6))
            
            # Original image
            axes[0].imshow(image.squeeze().cpu().numpy(), cmap='gray')
            axes[0].set_title('Original Image')
            axes[0].axis('off')
            
            # Predicted mask
            axes[1].imshow(pred_mask.squeeze().cpu().numpy(), cmap='gray')
            axes[1].set_title('Predicted Mask')
            axes[1].axis('off')
            
            plt.tight_layout()
            plt.savefig(os.path.join(args.output_dir, f"{os.path.splitext(filename)[0]}_prediction.png"))
            plt.close(fig)
            
            # Save the mask as an image
            mask_img = Image.fromarray((pred_mask.squeeze().cpu().numpy() * 255).astype(np.uint8))
            mask_img.save(os.path.join(args.output_dir, f"{os.path.splitext(filename)[0]}_mask.png"))
    
    print(f"Inference completed. Results saved to {args.output_dir}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Inference with trained U-Net model')
    parser.add_argument('--data_dir', type=str, default='.', help='Path to data directory')
    parser.add_argument('--output_dir', type=str, default='predictions', help='Path to output directory')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
    parser.add_argument('--threshold', type=float, default=0.5, help='Threshold for binary segmentation')
    parser.add_argument('--bilinear', action='store_true', help='Use bilinear upsampling')
    
    args = parser.parse_args()
    
    inference(args)

下面是上文代码的详细解读,适合编程小白理解。这个文件主要用于使用训练好的 U-Net 模型进行图像分割的推理(Inference),即对新图像进行预测。我们将逐步分析代码的各个部分,并解释其中使用的编程技术。

1. 导入必要的库

import os
# Set environment variable to suppress OpenMP warnings
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import argparse
from tqdm import tqdm
import glob

from model import UNet
from dataset import RetinaDataset, get_test_transform
  • os: 用于与操作系统交互,比如创建文件夹。
  • torch: PyTorch库,用于深度学习。
  • numpy: 用于数值计算。
  • matplotlib.pyplot: 用于绘图,展示原始图像和预测结果。
  • PIL (Python Imaging Library): 用于图像处理。
  • argparse: 用于处理命令行参数。
  • tqdm: 用于显示进度条。
  • glob: 用于查找符合特定规则的文件路径。

2. 定义推理函数

def inference(args):
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
  • inference: 这是主要的推理函数,接收命令行参数。
  • device: 检查是否有可用的GPU,如果有则使用GPU,否则使用CPU。

3. 创建输出目录

os.makedirs(args.output_dir, exist_ok=True)
  • os.makedirs: 创建输出目录,用于保存预测结果。

4. 加载模型

model = UNet(n_channels=1, n_classes=1, bilinear=args.bilinear)

# Load checkpoint
checkpoint = torch.load(args.checkpoint, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()
  • UNet: 自定义的U-Net模型,用于图像分割。
  • torch.load: 加载保存的模型检查点。
  • load_state_dict: 将检查点中的模型参数加载到模型中。
  • model.eval(): 将模型设置为评估模式,禁用 dropout 和 batch normalization。

5. 打印模型信息

print(f"Loaded model from {args.checkpoint}")
print(f"Model trained for {checkpoint['epoch']+1} epochs")
print(f"Validation loss: {checkpoint['loss']:.4f}, Dice score: {checkpoint.get('dice_score', 'N/A')}")
  • 打印加载的模型信息,包括训练的轮数、验证损失和Dice分数。

6. 创建测试数据集

test_dataset = RetinaDataset(
    img_dir=os.path.join(args.data_dir, 'test/images'),
    transform=get_test_transform(),
    is_test=True
)
  • RetinaDataset: 自定义的数据集类,用于加载测试图像。
  • transform: 应用测试时的图像变换。

7. 处理每个测试图像

with torch.no_grad():
    for idx in tqdm(range(len(test_dataset)), desc="Processing test images"):
        sample = test_dataset[idx]
        image = sample['image'].float().unsqueeze(0).to(device)  # Add batch dimension
        filename = sample['filename']
        
        # Forward pass
        output = model(image)
        
        # Convert to binary mask
        pred_mask = (output > args.threshold).float()
  • torch.no_grad(): 禁用梯度计算,以节省内存和加快推理速度。
  • tqdm: 显示处理进度条。
  • unsqueeze(0): 为图像添加一个批次维度。
  • Forward pass: 将图像输入模型,得到输出。
  • pred_mask: 将模型输出转换为二进制掩膜,使用指定的阈值。

8. 保存结果

fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# Original image
axes[0].imshow(image.squeeze().cpu().numpy(), cmap='gray')
axes[0].set_title('Original Image')
axes[0].axis('off')

# Predicted mask
axes[1].imshow(pred_mask.squeeze().cpu().numpy(), cmap='gray')
axes[1].set_title('Predicted Mask')
axes[1].axis('off')

plt.tight_layout()
plt.savefig(os.path.join(args.output_dir, f"{os.path.splitext(filename)[0]}_prediction.png"))
plt.close(fig)

# Save the mask as an image
mask_img = Image.fromarray((pred_mask.squeeze().cpu().numpy() * 255).astype(np.uint8))
mask_img.save(os.path.join(args.output_dir, f"{os.path.splitext(filename)[0]}_mask.png"))
  • matplotlib: 用于绘制原始图像和预测掩膜,并保存为图像文件。
  • Image.fromarray: 将预测掩膜转换为图像格式并保存。

9. 完成推理

print(f"Inference completed. Results saved to {args.output_dir}")
  • 打印推理完成的信息,指明结果保存的位置。

10. 主程序入口

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Inference with trained U-Net model')
    parser.add_argument('--data_dir', type=str, default='.', help='Path to data directory')
    parser.add_argument('--output_dir', type=str, default='predictions', help='Path to output directory')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
    parser.add_argument('--threshold', type=float, default=0.5, help='Threshold for binary segmentation')
    parser.add_argument('--bilinear', action='store_true', help='Use bilinear upsampling')
    
    args = parser.parse_args()
    
    inference(args)
  • argparse: 处理命令行参数,使得用户可以在运行程序时指定数据目录、输出目录、模型检查点路径等参数。
  • inference(args): 调用推理函数,开始推理过程。

总结

这个 inference.py 文件展示了如何使用训练好的 U-Net 模型对新图像进行推理,包括加载模型、处理测试数据、生成预测掩膜以及保存结果等步骤。通过逐步分析代码,编程小白可以理解推理的基本流程和相关技术。