U-Net实战教学(3)- 模型中的卷积层详细介绍

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

深入解析 DoubleConv 模块

这段代码定义了一个 神经网络层模块,用于 U-Net 结构中。它的作用是 执行两次卷积操作,每次卷积后都跟着批归一化(BatchNorm)和 ReLU 激活函数,因此称为 DoubleConv


1. DoubleConv 是什么?

  • 这是一个 PyTorch 神经网络模块,用于图像特征提取
  • 两层卷积(Conv2d)+ 批归一化(BatchNorm)+ ReLU 激活 组成。
  • 适用于深度学习中的 图像分割任务,帮助网络学习到更有效的特征。

2. DoubleConv 代码拆解

(1) __init__() 初始化

def __init__(self, in_channels, out_channels, mid_channels=None):

参数解析

  • in_channels:输入通道数(通常是前一层的输出通道数)。
  • out_channels:输出通道数(表示这个模块最终的通道数)。
  • mid_channels:中间层的通道数(如果不提供,就等于 out_channels)。

(2) super().__init__() 继承 nn.Module

super().__init__()
  • 继承 nn.Module,使 DoubleConv 变成一个 PyTorch 模块,可以在更大的网络(如 U-Net)中使用。

(3) 设定 mid_channels

if not mid_channels:
    mid_channels = out_channels
  • 如果 mid_channels 为空(None),则让 mid_channels = out_channels
  • 这样可以默认让两层卷积的输入/输出通道数相等

(4) 定义 self.double_conv

self.double_conv = nn.Sequential(
    nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(mid_channels),
    nn.ReLU(inplace=True),
    nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(out_channels),
    nn.ReLU(inplace=True)
)

理解 nn.Sequential()

  • nn.Sequential() 是一个 容器,可以把多个层(层按顺序执行)组合在一起
  • 这样可以减少代码量,使得 DoubleConv 结构更清晰。

(5) 第一层卷积

nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False)
  • in_channelsmid_channels:输入 in_channels,输出 mid_channels(通道数变换)。
  • kernel_size=3:使用 3×3 卷积核(是最常见的卷积核尺寸)。
  • padding=1
    • 3×3 的卷积核在 padding=1 时,输入和输出的尺寸不会变化
    • 如果输入是 HxW,输出仍然是 HxW(保持大小)。
  • bias=False
    • 这里 bias=False,因为 BatchNorm2d 已经有偏置参数了,所以不需要额外的偏置。

(6) 批归一化(BatchNorm)

nn.BatchNorm2d(mid_channels)
  • 作用
    • 使数据归一化(减去均值,除以标准差)。
    • 加速训练,稳定梯度,减少梯度消失问题。
  • mid_channels
    • 批归一化的输入通道数要和卷积层的输出通道数一致。

(7) ReLU 激活函数

nn.ReLU(inplace=True)
  • 作用
    • ReLU(x) = max(0, x),让负数变为 0,增加非线性,提高模型能力。
    • inplace=True 表示 直接修改输入数据,减少显存占用,加快计算。

(8) 第二层卷积

nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False)
nn.BatchNorm2d(out_channels)
nn.ReLU(inplace=True)
  • 与第一层类似,但输入通道变成 mid_channels,输出变为 out_channels
  • 这样就完成了 两次 3×3 卷积

3. forward() 前向传播

def forward(self, x):
    return self.double_conv(x)
  • 输入 x(通常是一个 多通道的特征图)。
  • 经过 self.double_conv(两次卷积+归一化+激活)。
  • 输出 x(尺寸不变,通道数变化)。

4. DoubleConv 运行示例

假设输入图像 大小为 256×256,通道数为 3(RGB 图像)。

import torch

# 创建一个 DoubleConv 层
double_conv = DoubleConv(in_channels=3, out_channels=64)

# 生成一个 (批次大小, 通道数, 高度, 宽度) 的输入
x = torch.randn(1, 3, 256, 256)  # 批次大小=1, 3通道, 256x256 图像

# 通过 DoubleConv 前向传播
output = double_conv(x)

print(output.shape)  # 输出形状

运行结果

torch.Size([1, 64, 256, 256])

说明:

  • 输入是 (1, 3, 256, 256)
  • 经过 DoubleConv(3, 64) 后,通道数变成 64,但大小仍然是 256×256(因为 padding=1 )。

5. 视觉化 DoubleConv 的过程

输入 3×3 卷积:

输入: [3通道,256x256]
↓
Conv2D (3x3) → BatchNorm → ReLU
↓
[中间通道(默认等于out_channels),256x256]
↓
Conv2D (3x3) → BatchNorm → ReLU
↓
输出:[out_channels,256x256]

6. 总结

  • DoubleConv两个 3×3 卷积 + 批归一化 + ReLU 激活 组成。
  • 输入 in_channels,输出 out_channels,保证特征提取能力。
  • padding=1 确保尺寸不变,避免形状变化影响后续计算。
  • ReLU 非线性变换,提高模型学习能力。
  • BatchNorm 稳定训练过程,减少梯度消失问题。

你可以把 DoubleConv 想象成一个小型的特征提取器,它让 U-Net 能够学习更复杂的特征! :rocket: