本文内容接着上文继续详细解读相关程序每个部分的含义。帮助梳理编程上下采样时的概念性问题。
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
def forward(self, x):
return self.maxpool_conv(x)
class Up(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
# if you have padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
详细解析 Down(下采样)和 Up(上采样)
在 U-Net 结构中:
- 下采样(Down) 主要用于 降低特征图的分辨率,同时提取高层特征。
- 上采样(Up) 主要用于 恢复特征图的分辨率,并结合下采样阶段的特征以恢复细节。
1. Down(下采样)详细解析
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
def forward(self, x):
return self.maxpool_conv(x)
(1) __init__() 方法
def __init__(self, in_channels, out_channels):
参数解析
in_channels:输入通道数(上一层特征图的通道数)。out_channels:输出通道数(这一层特征图的通道数)。
(2) 定义 self.maxpool_conv
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
self.maxpool_conv 由两部分组成:
nn.MaxPool2d(2):最大池化,将图像分辨率减少一半(高宽都变成原来的一半)。DoubleConv(in_channels, out_channels):双层 3x3 卷积,提取特征。
(3) forward() 方法
def forward(self, x):
return self.maxpool_conv(x)
- 直接将输入
x传入self.maxpool_conv。 - 先进行 最大池化,再经过 两层卷积。
(4) Down 的执行过程
假设输入 x 形状为 (1, 64, 256, 256)(批量大小 1,通道数 64,大小 256x256)。
| 操作 | 说明 | 输出形状 |
|---|---|---|
MaxPool2d(2) |
2×2 池化,降低分辨率 | (1, 64, 128, 128) |
DoubleConv(64, 128) |
两次 3×3 卷积,通道变为 128 | (1, 128, 128, 128) |
2. Up(上采样)详细解析
class Up(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
(1) __init__() 方法
def __init__(self, in_channels, out_channels, bilinear=True):
参数解析
in_channels:输入通道数(上一层的输出通道)。out_channels:输出通道数(经过DoubleConv之后的通道数)。bilinear:True:使用 双线性插值(更节省显存)。False:使用 反卷积(转置卷积)(学习能力更强)。
(2) 选择上采样方法
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
两种不同的上采样方式:
-
双线性插值 (
bilinear=True)nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)- 直接将图像放大 2 倍(不涉及学习参数)。
- 使用
DoubleConv(in_channels, out_channels, in_channels // 2)进行特征提取。
-
转置卷积 (
bilinear=False)nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)- 通过 可学习的卷积操作 来恢复分辨率。
DoubleConv(in_channels, out_channels)进一步提取特征。
(3) forward() 方法
def forward(self, x1, x2):
x1:来自上一层上采样的特征图(低分辨率)。x2:来自U-Net 跳跃连接的特征图(高分辨率)。
(4) 进行上采样
x1 = self.up(x1)
x1被上采样(放大 2 倍)。
(5) 计算尺寸差距
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1经过Upsample之后,尺寸可能和x2不完全匹配,因此需要计算两者的差距。
(6) 进行填充 (F.pad())
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
F.pad()对x1进行填充,使x1和x2尺寸匹配:diffX // 2:左侧填充的像素数diffX - diffX // 2:右侧填充的像素数diffY // 2:顶部填充的像素数diffY - diffY // 2:底部填充的像素数
(7) 跳跃连接 torch.cat()
x = torch.cat([x2, x1], dim=1)
- 将
x2(下采样时保存的高分辨率特征图)与x1(上采样后的特征图)沿通道维度拼接。 dim=1代表沿 通道方向 拼接。
(8) DoubleConv 提取特征
return self.conv(x)
- 拼接后的
x通过DoubleConv,得到最终输出。
3. Up 运行示例
假设:
x1.shape = (1, 512, 32, 32)(低分辨率特征)x2.shape = (1, 256, 64, 64)(跳跃连接特征)
通过 Up(512, 256),最终输出:
torch.Size([1, 256, 64, 64])
说明:
x1上采样后变为64×64,并与x2拼接。DoubleConv使输出通道数变成256。
总结
Down:最大池化 +DoubleConv(降低分辨率,提取特征)。Up:上采样 +DoubleConv(恢复分辨率,并结合高分辨率特征)。
问题:为什么可以像函数一样直接使用 DoubleConv()?
你可能觉得 DoubleConv 是一个 类,但它为什么可以像函数一样直接用,比如:
self.conv = DoubleConv(in_channels, out_channels)
这个问题的核心是:在 Python 中,类本质上是可调用对象(callable objects),当你实例化它时,相当于调用了它的 __init__() 方法。
1. Python 类的实例化
在 Python 里,类是模板,而对象是具体的实例。当我们创建一个对象时,我们实际上是在调用类的 __init__ 方法。
例如:
class Example:
def __init__(self, value):
self.value = value
# 创建对象
e = Example(10) # 这相当于调用 Example.__init__(e, 10)
等价于
e = Example.__call__(10)
2. DoubleConv 是 nn.Module 的子类
class DoubleConv(nn.Module):
- 由于
DoubleConv继承了torch.nn.Module,所以DoubleConv本质上是一个 神经网络模块。 - 当我们调用
DoubleConv(in_channels, out_channels)时,它会 自动调用__init__()方法,创建一个DoubleConv层。
示例:
conv_layer = DoubleConv(3, 64) # 这里 conv_layer 是一个包含两层卷积的神经网络层
这其实是:
conv_layer = DoubleConv.__call__(3, 64)
等价于
conv_layer = DoubleConv.__new__(DoubleConv)
conv_layer.__init__(3, 64)
最终 conv_layer 是 DoubleConv 的一个实例,可以像普通 PyTorch 层那样使用。
3. 为什么可以像函数一样使用?
在 PyTorch 中,所有 nn.Module 的子类都默认实现了 __call__() 方法,它会自动调用 forward() 方法。因此:
x = torch.randn(1, 3, 256, 256) # 创建一个 3 通道 256x256 的张量
output = conv_layer(x) # 这实际上是调用 conv_layer.forward(x)
这等价于:
output = conv_layer.forward(x)
所以在 Up 里面这样写:
self.conv = DoubleConv(in_channels, out_channels)
意味着:
self.conv是DoubleConv的一个 实例(对象),它已经初始化好网络层。- 之后调用
self.conv(x)时,实际上是在调用self.conv.forward(x),它会执行DoubleConv里的forward()方法。
4. nn.Sequential() 进一步简化调用
在 Down 里,我们这样用:
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
这里 nn.Sequential() 是 一个模块容器,它会顺序执行里面的层。因此:
DoubleConv(in_channels, out_channels)创建了一个DoubleConv实例,它本质上是一个nn.Module。nn.Sequential()允许我们直接用self.maxpool_conv(x)来调用整个块,它会自动 先池化,再卷积。
5. 直观理解
你可以把 DoubleConv 想象成一个 自定义的神经网络层:
nn.Conv2d也是一个类,但我们通常这样用:
因为conv = nn.Conv2d(3, 64, kernel_size=3) output = conv(x)nn.Conv2d继承了nn.Module,它可以像函数一样被调用。- 同理,
DoubleConv也是nn.Module的子类,所以它也可以像函数一样调用。
6. 总结
DoubleConv是一个nn.Module的子类,当我们调用DoubleConv(in_channels, out_channels)时,它创建了一个包含两层卷积的神经网络模块。- 在 PyTorch 中,所有
nn.Module子类都实现了__call__()方法,它会自动调用forward(),所以可以直接conv_layer(x)来使用。 - 这样我们可以把
DoubleConv当作普通函数一样用,比如:self.conv = DoubleConv(in_channels, out_channels) x = self.conv(input) # 其实是在调用 self.conv.forward(input) - 这种设计使得 PyTorch 代码简洁,可以像使用
nn.Conv2d那样使用自定义的DoubleConv层。
所以,虽然 DoubleConv 是一个类,但它的行为更像是一个 可调用对象(callable object),这正是 PyTorch 的设计哲学,让我们可以更方便地构建和调用神经网络模块! ![]()
问题: 为什么在 U-Net 中需要 Padding(填充)
def forward(self, x1, x2):
x1 = self.up(x1)
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
在 U-Net 中,上采样(upsample) 后得到的特征图 x1 往往会比编码器传过来的 x2 稍微小一点,原因是:
- 下采样(如 max pooling)会丢掉像素。
- 上采样无法完美还原原来的尺寸。
但在 U-Net 中我们需要将 x1 和 x2 进行拼接(torch.cat),它们的 高和宽必须完全一致,否则就会报错。
举个例子(用具体数字来说明)
假设:
- 编码器传过来的
x2形状是[batch=1, channels=64, height=100, width=100] - 上采样后的
x1是[1, 64, 98, 98]
所以:
diffY = 100 - 98 = 2
diffX = 100 - 98 = 2
我们就需要对 x1 进行 padding,让它变成和 x2 一样大。
F.pad(x1, […]) 是什么意思?
F.pad(x1, [左, 右, 上, 下])
在代码中:
F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
带入数值(假设 diffX = diffY = 2):
左边 = 2 // 2 = 1
右边 = 2 - 1 = 1
上边 = 2 // 2 = 1
下边 = 2 - 1 = 1
最终就是:
F.pad(x1, [1, 1, 1, 1])
表示:
- 宽度方向左右各加 1 个像素
- 高度方向上下各加 1 个像素
填充后,x1 变成 [1, 64, 100, 100],与 x2 对齐,就可以拼接了。
假如差值是奇数怎么办?
比如 diffX = 3:
左边 = 3 // 2 = 1
右边 = 3 - 1 = 2
所以就是 [1, 2] 的非对称填充,这是允许的,PyTorch 能处理。
那为什么不用 F.interpolate 直接改尺寸呢?
你也可以用:
F.interpolate(x1, size=(x2.size(2), x2.size(3)))
来强制改尺寸,但:
- Padding 更精确。
- 它能保留特征图的空间对齐中心,对语义分割这类任务来说非常关键。
总结表格
| 特征图 | 高度 | 宽度 | 说明 |
|---|---|---|---|
x2 |
100 | 100 | 编码器输出 |
x1 |
98 | 98 | 上采样结果 |
| Padding | 上下各加 1,左右各加 1 | 使得 x1 尺寸变为 100×100 |
最终目的:
对 x1 进行填充,让它和 x2 尺寸一致 → 之后才能进行拼接。