U-Net实战教学(4)- 下采样上采样模块详细介绍

本文内容接着上文继续详细解读相关程序每个部分的含义。帮助梳理编程上下采样时的概念性问题。


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

详细解析 Down(下采样)和 Up(上采样)

U-Net 结构中:

  • 下采样(Down) 主要用于 降低特征图的分辨率,同时提取高层特征
  • 上采样(Up) 主要用于 恢复特征图的分辨率,并结合下采样阶段的特征以恢复细节。

1. Down(下采样)详细解析

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

(1) __init__() 方法

def __init__(self, in_channels, out_channels):

参数解析

  • in_channels:输入通道数(上一层特征图的通道数)。
  • out_channels:输出通道数(这一层特征图的通道数)。

(2) 定义 self.maxpool_conv

self.maxpool_conv = nn.Sequential(
    nn.MaxPool2d(2),
    DoubleConv(in_channels, out_channels)
)

self.maxpool_conv 由两部分组成:

  1. nn.MaxPool2d(2)最大池化,将图像分辨率减少一半(高宽都变成原来的一半)。
  2. DoubleConv(in_channels, out_channels)双层 3x3 卷积,提取特征。

(3) forward() 方法

def forward(self, x):
    return self.maxpool_conv(x)
  • 直接将输入 x 传入 self.maxpool_conv
  • 先进行 最大池化,再经过 两层卷积

(4) Down 的执行过程

假设输入 x 形状为 (1, 64, 256, 256)(批量大小 1,通道数 64,大小 256x256)。

操作 说明 输出形状
MaxPool2d(2) 2×2 池化,降低分辨率 (1, 64, 128, 128)
DoubleConv(64, 128) 两次 3×3 卷积,通道变为 128 (1, 128, 128, 128)

2. Up(上采样)详细解析

class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

(1) __init__() 方法

def __init__(self, in_channels, out_channels, bilinear=True):

参数解析

  • in_channels:输入通道数(上一层的输出通道)。
  • out_channels:输出通道数(经过 DoubleConv 之后的通道数)。
  • bilinear
    • True:使用 双线性插值(更节省显存)。
    • False:使用 反卷积(转置卷积)(学习能力更强)。

(2) 选择上采样方法

if bilinear:
    self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
    self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
    self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
    self.conv = DoubleConv(in_channels, out_channels)

两种不同的上采样方式:

  1. 双线性插值 (bilinear=True)

    • nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
    • 直接将图像放大 2 倍(不涉及学习参数)。
    • 使用 DoubleConv(in_channels, out_channels, in_channels // 2) 进行特征提取。
  2. 转置卷积 (bilinear=False)

    • nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
    • 通过 可学习的卷积操作 来恢复分辨率。
    • DoubleConv(in_channels, out_channels) 进一步提取特征。

(3) forward() 方法

def forward(self, x1, x2):
  • x1:来自上一层上采样的特征图(低分辨率)。
  • x2:来自U-Net 跳跃连接的特征图(高分辨率)。

(4) 进行上采样

x1 = self.up(x1)
  • x1 被上采样(放大 2 倍)。

(5) 计算尺寸差距

diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
  • x1 经过 Upsample 之后,尺寸可能和 x2 不完全匹配,因此需要计算两者的差距。

(6) 进行填充 (F.pad())

x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                diffY // 2, diffY - diffY // 2])
  • F.pad() x1 进行填充,使 x1x2 尺寸匹配:
    • diffX // 2:左侧填充的像素数
    • diffX - diffX // 2:右侧填充的像素数
    • diffY // 2:顶部填充的像素数
    • diffY - diffY // 2:底部填充的像素数

(7) 跳跃连接 torch.cat()

x = torch.cat([x2, x1], dim=1)
  • x2(下采样时保存的高分辨率特征图)与 x1(上采样后的特征图)沿通道维度拼接
  • dim=1 代表沿 通道方向 拼接。

(8) DoubleConv 提取特征

return self.conv(x)
  • 拼接后的 x 通过 DoubleConv,得到最终输出

3. Up 运行示例

假设:

  • x1.shape = (1, 512, 32, 32)(低分辨率特征)
  • x2.shape = (1, 256, 64, 64)(跳跃连接特征)

通过 Up(512, 256),最终输出:

torch.Size([1, 256, 64, 64])

说明:

  • x1 上采样后变为 64×64,并与 x2 拼接。
  • DoubleConv 使输出通道数变成 256

总结

  • Down:最大池化 + DoubleConv(降低分辨率,提取特征)。
  • Up:上采样 + DoubleConv(恢复分辨率,并结合高分辨率特征)。

问题:为什么可以像函数一样直接使用 DoubleConv()

你可能觉得 DoubleConv 是一个 ,但它为什么可以像函数一样直接用,比如:

self.conv = DoubleConv(in_channels, out_channels)

这个问题的核心是:在 Python 中,类本质上是可调用对象(callable objects),当你实例化它时,相当于调用了它的 __init__() 方法


1. Python 类的实例化

在 Python 里,类是模板,而对象是具体的实例。当我们创建一个对象时,我们实际上是在调用类的 __init__ 方法。

例如:

class Example:
    def __init__(self, value):
        self.value = value

# 创建对象
e = Example(10)  # 这相当于调用 Example.__init__(e, 10)

等价于

e = Example.__call__(10)

2. DoubleConvnn.Module 的子类

class DoubleConv(nn.Module):
  • 由于 DoubleConv 继承了 torch.nn.Module,所以 DoubleConv 本质上是一个 神经网络模块
  • 当我们调用 DoubleConv(in_channels, out_channels) 时,它会 自动调用 __init__() 方法,创建一个 DoubleConv 层。

示例:

conv_layer = DoubleConv(3, 64)  # 这里 conv_layer 是一个包含两层卷积的神经网络层

这其实是:

conv_layer = DoubleConv.__call__(3, 64)

等价于

conv_layer = DoubleConv.__new__(DoubleConv)
conv_layer.__init__(3, 64)

最终 conv_layerDoubleConv 的一个实例,可以像普通 PyTorch 层那样使用。


3. 为什么可以像函数一样使用?

在 PyTorch 中,所有 nn.Module 的子类都默认实现了 __call__() 方法,它会自动调用 forward() 方法。因此:

x = torch.randn(1, 3, 256, 256)  # 创建一个 3 通道 256x256 的张量
output = conv_layer(x)  # 这实际上是调用 conv_layer.forward(x)

这等价于:

output = conv_layer.forward(x)

所以在 Up 里面这样写:

self.conv = DoubleConv(in_channels, out_channels)

意味着:

  • self.convDoubleConv 的一个 实例(对象),它已经初始化好网络层。
  • 之后调用 self.conv(x) 时,实际上是在调用 self.conv.forward(x),它会执行 DoubleConv 里的 forward() 方法。

4. nn.Sequential() 进一步简化调用

Down 里,我们这样用:

self.maxpool_conv = nn.Sequential(
    nn.MaxPool2d(2),
    DoubleConv(in_channels, out_channels)
)

这里 nn.Sequential()一个模块容器,它会顺序执行里面的层。因此:

  • DoubleConv(in_channels, out_channels) 创建了一个 DoubleConv 实例,它本质上是一个 nn.Module
  • nn.Sequential() 允许我们直接用 self.maxpool_conv(x) 来调用整个块,它会自动 先池化,再卷积

5. 直观理解

你可以把 DoubleConv 想象成一个 自定义的神经网络层

  • nn.Conv2d 也是一个类,但我们通常这样用:
    conv = nn.Conv2d(3, 64, kernel_size=3)
    output = conv(x)
    
    因为 nn.Conv2d 继承了 nn.Module,它可以像函数一样被调用。
  • 同理,DoubleConv 也是 nn.Module 的子类,所以它也可以像函数一样调用。

6. 总结

  1. DoubleConv 是一个 nn.Module 的子类,当我们调用 DoubleConv(in_channels, out_channels) 时,它创建了一个包含两层卷积的神经网络模块。
  2. 在 PyTorch 中,所有 nn.Module 子类都实现了 __call__() 方法,它会自动调用 forward(),所以可以直接 conv_layer(x) 来使用。
  3. 这样我们可以把 DoubleConv 当作普通函数一样用,比如:
    self.conv = DoubleConv(in_channels, out_channels)
    x = self.conv(input)  # 其实是在调用 self.conv.forward(input)
    
  4. 这种设计使得 PyTorch 代码简洁,可以像使用 nn.Conv2d 那样使用自定义的 DoubleConv 层。

所以,虽然 DoubleConv 是一个类,但它的行为更像是一个 可调用对象(callable object),这正是 PyTorch 的设计哲学,让我们可以更方便地构建和调用神经网络模块! :rocket:


问题: 为什么在 U-Net 中需要 Padding(填充)

def forward(self, x1, x2):
        x1 = self.up(x1)

        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

在 U-Net 中,上采样(upsample) 后得到的特征图 x1 往往会比编码器传过来的 x2 稍微小一点,原因是:

  • 下采样(如 max pooling)会丢掉像素。
  • 上采样无法完美还原原来的尺寸。

:bulb: 但在 U-Net 中我们需要将 x1x2 进行拼接(torch.cat),它们的 高和宽必须完全一致,否则就会报错。


:abacus: 举个例子(用具体数字来说明)

假设:

  • 编码器传过来的 x2 形状是 [batch=1, channels=64, height=100, width=100]
  • 上采样后的 x1[1, 64, 98, 98]

所以:

diffY = 100 - 98 = 2
diffX = 100 - 98 = 2

我们就需要对 x1 进行 padding,让它变成和 x2 一样大。


:package: F.pad(x1, […]) 是什么意思?

F.pad(x1, [左, 右, 上, 下])

在代码中:

F.pad(x1, [diffX // 2, diffX - diffX // 2,
           diffY // 2, diffY - diffY // 2])

带入数值(假设 diffX = diffY = 2):

左边   = 2 // 2 = 1
右边   = 2 - 1 = 1
上边   = 2 // 2 = 1
下边   = 2 - 1 = 1

最终就是:

F.pad(x1, [1, 1, 1, 1])

表示:

  • 宽度方向左右各加 1 个像素
  • 高度方向上下各加 1 个像素

:arrow_right: 填充后,x1 变成 [1, 64, 100, 100],与 x2 对齐,就可以拼接了。


:repeat: 假如差值是奇数怎么办?

比如 diffX = 3

左边 = 3 // 2 = 1
右边 = 3 - 1 = 2

所以就是 [1, 2] 的非对称填充,这是允许的,PyTorch 能处理。


:brain: 那为什么不用 F.interpolate 直接改尺寸呢?

你也可以用:

F.interpolate(x1, size=(x2.size(2), x2.size(3)))

来强制改尺寸,但:

  • Padding 更精确。
  • 它能保留特征图的空间对齐中心,对语义分割这类任务来说非常关键。

:thread: 总结表格

特征图 高度 宽度 说明
x2 100 100 编码器输出
x1 98 98 上采样结果
Padding 上下各加 1,左右各加 1 使得 x1 尺寸变为 100×100

最终目的:
:white_check_mark:x1 进行填充,让它和 x2 尺寸一致 → 之后才能进行拼接。