什么是计算图(Computational Graph)?
可以把 计算图 想象成一个流程图,跟踪张量(数据)是如何一步步被运算的。你对张量进行的每一个操作(如加法、乘法等)都会成为 图中的一个节点。
例如,执行以下代码时:
import torch
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2 # x 的平方
z = y + 3 # 加 3
PyTorch 会创建如下计算图:
x → (平方) → y → (加3) → z
由于 x
设置了 requires_grad=True
,PyTorch 记录了 x
如何影响 z
,这样就可以进行自动求导。
autograd 如何使用计算图?
PyTorch 自动构建 计算图并完成两个主要工作:
- 前向传播(Forward Pass):执行计算操作,生成新的张量(如
y = x²
)。 - 反向传播(Backward Pass):当你调用
.backward()
时,PyTorch 会沿着计算图 反向 传播,使用 链式法则 计算梯度。
示例:
z.backward() # 计算梯度
print(x.grad) # 输出 x 的梯度
PyTorch 计算 z
对 x
的梯度,即 dz/dx = 2x,当 x=2
时,结果是 4
。
什么是有向无环图(DAG, Directed Acyclic Graph)?
- 有向(Directed) → 计算图中的箭头表示数据流动的方向。
- 无环(Acyclic) → 计算图中不会有循环,计算只能沿着箭头的方向进行。
- 图(Graph) → 计算图连接了输入(如
x
)和输出(如z
)。
在 PyTorch 中,DAG 记录所有计算过程,使得梯度能够 自动计算。
为什么 PyTorch 的计算图是“动态的”?
与某些其他深度学习框架不同,PyTorch 在每次调用 .backward()
时 都会重新构建计算图,这意味着:
你可以使用
if
语句、循环等控制流,而不会影响计算图的构建。
计算图 不会固定,你可以在每次迭代中动态更改计算过程。
示例
for i in range(3):
x = torch.tensor(float(i), requires_grad=True)
y = x**3 if i % 2 == 0 else x**2
y.backward()
print(f"x={i}, 梯度={x.grad}")
每次循环,PyTorch 都会 创建新的计算图,从而支持不同的计算逻辑。
总结
计算图:跟踪张量运算的流程图
autograd 自动构建计算图
前向传播 计算输出值
反向传播 使用 链式法则 计算梯度
PyTorch 的计算图是动态的,支持灵活的运算