class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
深入解析 DoubleConv
模块
这段代码定义了一个 神经网络层模块,用于 U-Net 结构中。它的作用是 执行两次卷积操作,每次卷积后都跟着批归一化(BatchNorm)和 ReLU 激活函数,因此称为 DoubleConv
。
1. DoubleConv
是什么?
- 这是一个 PyTorch 神经网络模块,用于图像特征提取。
- 由 两层卷积(Conv2d)+ 批归一化(BatchNorm)+ ReLU 激活 组成。
- 适用于深度学习中的 图像分割任务,帮助网络学习到更有效的特征。
2. DoubleConv
代码拆解
(1) __init__()
初始化
def __init__(self, in_channels, out_channels, mid_channels=None):
参数解析:
in_channels
:输入通道数(通常是前一层的输出通道数)。out_channels
:输出通道数(表示这个模块最终的通道数)。mid_channels
:中间层的通道数(如果不提供,就等于out_channels
)。
(2) super().__init__()
继承 nn.Module
super().__init__()
- 继承
nn.Module
,使DoubleConv
变成一个 PyTorch 模块,可以在更大的网络(如 U-Net)中使用。
(3) 设定 mid_channels
if not mid_channels:
mid_channels = out_channels
- 如果
mid_channels
为空(None),则让mid_channels = out_channels
。 - 这样可以默认让两层卷积的输入/输出通道数相等。
(4) 定义 self.double_conv
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
理解 nn.Sequential()
nn.Sequential()
是一个 容器,可以把多个层(层按顺序执行)组合在一起。- 这样可以减少代码量,使得
DoubleConv
结构更清晰。
(5) 第一层卷积
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False)
in_channels
→mid_channels
:输入in_channels
,输出mid_channels
(通道数变换)。kernel_size=3
:使用 3×3 卷积核(是最常见的卷积核尺寸)。padding=1
:- 3×3 的卷积核在
padding=1
时,输入和输出的尺寸不会变化。 - 如果输入是
HxW
,输出仍然是HxW
(保持大小)。
- 3×3 的卷积核在
bias=False
:- 这里
bias=False
,因为BatchNorm2d
已经有偏置参数了,所以不需要额外的偏置。
- 这里
(6) 批归一化(BatchNorm)
nn.BatchNorm2d(mid_channels)
- 作用:
- 使数据归一化(减去均值,除以标准差)。
- 加速训练,稳定梯度,减少梯度消失问题。
mid_channels
:- 批归一化的输入通道数要和卷积层的输出通道数一致。
(7) ReLU 激活函数
nn.ReLU(inplace=True)
- 作用:
ReLU(x) = max(0, x)
,让负数变为 0,增加非线性,提高模型能力。inplace=True
表示 直接修改输入数据,减少显存占用,加快计算。
(8) 第二层卷积
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False)
nn.BatchNorm2d(out_channels)
nn.ReLU(inplace=True)
- 与第一层类似,但输入通道变成
mid_channels
,输出变为out_channels
。 - 这样就完成了 两次 3×3 卷积。
3. forward()
前向传播
def forward(self, x):
return self.double_conv(x)
- 输入
x
(通常是一个 多通道的特征图)。 - 经过
self.double_conv
(两次卷积+归一化+激活)。 - 输出
x
(尺寸不变,通道数变化)。
4. DoubleConv
运行示例
假设输入图像 大小为 256×256,通道数为 3(RGB 图像)。
import torch
# 创建一个 DoubleConv 层
double_conv = DoubleConv(in_channels=3, out_channels=64)
# 生成一个 (批次大小, 通道数, 高度, 宽度) 的输入
x = torch.randn(1, 3, 256, 256) # 批次大小=1, 3通道, 256x256 图像
# 通过 DoubleConv 前向传播
output = double_conv(x)
print(output.shape) # 输出形状
运行结果
torch.Size([1, 64, 256, 256])
说明:
- 输入是 (1, 3, 256, 256)。
- 经过
DoubleConv(3, 64)
后,通道数变成 64,但大小仍然是 256×256(因为padding=1
)。
5. 视觉化 DoubleConv
的过程
输入 3×3
卷积:
输入: [3通道,256x256]
↓
Conv2D (3x3) → BatchNorm → ReLU
↓
[中间通道(默认等于out_channels),256x256]
↓
Conv2D (3x3) → BatchNorm → ReLU
↓
输出:[out_channels,256x256]
6. 总结
DoubleConv
由 两个 3×3 卷积 + 批归一化 + ReLU 激活 组成。- 输入
in_channels
,输出out_channels
,保证特征提取能力。 padding=1
确保尺寸不变,避免形状变化影响后续计算。- ReLU 非线性变换,提高模型学习能力。
- BatchNorm 稳定训练过程,减少梯度消失问题。
你可以把 DoubleConv
想象成一个小型的特征提取器,它让 U-Net 能够学习更复杂的特征!