U-Net实战教学(5)- 构建U-Net

详细解析 OutConvUNet 代码


1. OutConv 解析

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

(1) __init__() 构造函数

def __init__(self, in_channels, out_channels):

参数解释

  • in_channels:输入通道数(上采样后输出的特征图通道数)。
  • out_channels:输出通道数(U-Net 最终的类别数)。
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
  • 这里用 nn.Conv2d 创建了一个 1×1 卷积层
  • kernel_size=1(1×1 卷积)的作用:
    • 降低通道数(从 64 通道变为 n_classes)。
    • 逐像素处理,但不改变特征图的 H × W 大小。
    • 用于 最终的类别预测(比如分割任务中,每个像素预测类别)。

(2) forward() 前向传播

def forward(self, x):
    return self.conv(x)
  • 直接将输入 x 通过 self.conv(x) 计算。
  • 输出是 通道数为 n_classes 的特征图,用于最终的像素级分类。

(3) OutConv 运行示例

假设 x 的形状是 (1, 64, 256, 256)(64 个通道,大小 256×256)。

out_layer = OutConv(64, 1)
x = torch.randn(1, 64, 256, 256)
output = out_layer(x)
print(output.shape)  # 预期输出 (1, 1, 256, 256)

说明

  • x 输入有 64 个通道,经过 1×1 卷积后,输出变成 1 个通道(即 n_classes=1)。
  • 形状仍然是 256×256,没有改变空间分辨率。

2. UNet 解析

class UNet(nn.Module):
    def __init__(self, n_channels=1, n_classes=1, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return torch.sigmoid(logits)  # Apply sigmoid for binary segmentation

上面是构建U-Net部分的程序,使用了前面几个帖子的内容,包括DoubleConv,上转换模块,下转换模块,和上面提到的OutConv。以下是这部分的详细解释。

class UNet(nn.Module):
    def __init__(self, n_channels=1, n_classes=1, bilinear=False):

参数解释

  • n_channels=1:输入通道数(灰度图是 1,RGB 图像是 3)。
  • n_classes=1:输出通道数(1 表示二分类问题,多个表示多类别分割)。
  • bilinear=False
    • False:使用 转置卷积(更强的学习能力)
    • True:使用 双线性插值(减少参数,提高推理速度)

(1) 定义 U-Net 结构

self.inc = DoubleConv(n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
factor = 2 if bilinear else 1
self.down4 = Down(512, 1024 // factor)
  • 下采样部分(编码器):Down 结构逐步降低分辨率,并增加通道数。
self.up1 = Up(1024, 512 // factor, bilinear)
self.up2 = Up(512, 256 // factor, bilinear)
self.up3 = Up(256, 128 // factor, bilinear)
self.up4 = Up(128, 64, bilinear)
  • 上采样部分(解码器):Up 结构逐步恢复分辨率,并结合跳跃连接特征。
self.outc = OutConv(64, n_classes)
  • OutConv(64, n_classes) 最终输出分割结果

(2) forward() 前向传播

def forward(self, x):

输入 x 是一个 n_channels 通道的图片,如 (1, 1, 256, 256)


(3) 下采样部分

x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
步骤 作用 特征图形状(假设 n_channels=1,输入 256×256
self.inc(x) DoubleConv(1, 64) (1, 64, 256, 256)
self.down1(x1) MaxPool + DoubleConv(64, 128) (1, 128, 128, 128)
self.down2(x2) MaxPool + DoubleConv(128, 256) (1, 256, 64, 64)
self.down3(x3) MaxPool + DoubleConv(256, 512) (1, 512, 32, 32)
self.down4(x4) MaxPool + DoubleConv(512, 1024//factor) (1, 1024, 16, 16)

(4) 上采样部分

x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
  • Up 结构包含 上采样(恢复分辨率)+ 跳跃连接(增强细节)
步骤 作用 特征图形状
self.up1(x5, x4) 上采样 16×16 → 32×32,跳跃连接 (1, 512, 32, 32)
self.up2(x, x3) 上采样 32×32 → 64×64,跳跃连接 (1, 256, 64, 64)
self.up3(x, x2) 上采样 64×64 → 128×128,跳跃连接 (1, 128, 128, 128)
self.up4(x, x1) 上采样 128×128 → 256×256,跳跃连接 (1, 64, 256, 256)

(5) 输出层

logits = self.outc(x)
return torch.sigmoid(logits)  # Apply sigmoid for binary segmentation
  • OutConv(64, n_classes) 将通道数从 64 降为 n_classes(如 1 通道的二分类分割)。
  • torch.sigmoid(logits) 归一化输出,使得像素值在 [0,1] 之间(适用于二分类任务)。

3. UNet 运行示例

unet = UNet(n_channels=1, n_classes=1)
x = torch.randn(1, 1, 256, 256)  # 假设输入是 1 通道 256x256
output = unet(x)
print(output.shape)  # 预期输出 (1, 1, 256, 256)

4. 总结

  1. OutConv 作用

    • 使用 1×1 卷积 进行 通道降维,使最终输出 n_classes 个通道。
    • 保持 H × W 不变,仅改变通道数。
  2. UNet 作用

    • 编码(Down):最大池化 + DoubleConv 降低分辨率。
    • 解码(Up):上采样 + DoubleConv 恢复分辨率,并融合编码特征。
    • 输出(OutConv):转换为 n_classes 维的分割结果。
    • torch.sigmoid() 让输出归一化到 [0,1],适用于二分类任务。

U-Net 是一个全卷积神经网络(FCN),专门用于 像素级分类任务,比如医学图像分割!:rocket: