反向传播#

我们书接上回,链式法则的内容

第三站:工厂升级——引入“误差”与“反向传播”#

一个好的工厂不仅要能生产,还要能按订单生产

  • 目标: 顾客订单是 10个 汉堡。

  • 现状: 我们现在的机器设置生产出了 12个 汉堡。

  • 误差(Loss): 我们生产多了,有了误差。我们用损失函数 \(\text{Loss} = (y - y_{\text{target}})^2\) 来量化这个误差。

我们的新任务是:调整旋钮 wv,来减小这个误差 Loss

为了智能地调整,我们需要知道:

  • 误差 Loss 对旋钮 v 有多敏感?(即 \(\frac{\partial \text{Loss}}{\partial v}\)

  • 误差 Loss 对旋钮 w 有多敏感?(即 \(\frac{\partial \text{Loss}}{\partial w}\)

这个寻找“责任人”的过程,必须从后往前

  1. 最终误差:我们知道 Loss 的大小。

  2. 追溯到机器BLoss 的产生和最终产量 y 直接相关。而 y 是由机器B的旋钮 v 决定的。所以我们可以计算出 Lossv 的“责任”。

  3. 继续追溯到机器Ay 的变化也受到肉饼 u 的影响,而 u 是由机器A的旋钮 w 决定的。我们可以沿着链条,把“责任”一直传递给 w

这个从最终误差出发,利用链式法则,反向追溯并计算每个参数“责任”(梯度)的过程,就是反向传播(Backpropagation)

第四站:上代码!模拟工厂的智能调优#

现在,我们用代码来精确模拟上面那个“调整旋钮”的过程。

场景设定:

  • 机器A 的功能是:肉饼数量 u = 旋钮w * 牛肉量x。

  • 机器B 的功能是:汉堡数量 y = 旋钮v * 肉饼数量u。这完美对应了“1个肉饼产出v个汉堡”的比喻。

我们来看看PyTorch如何帮我们自动“甩锅”。

import torch

# --- 1. 初始化我们的“工厂” ---
# 旋钮 w 和 v 是我们需要学习和调整的参数
w = torch.tensor(3.0, requires_grad=True) # 机器A的旋钮,初始值为3
v = torch.tensor(2.0, requires_grad=True) # 机器B的旋钮,初始值为2

# 输入的原材料
x = torch.tensor(2.0) # 固定的2公斤牛肉

# 顾客的订单 (我们的目标)
target_y = torch.tensor(10.0) 

# --- 2. 正向传播:生产一次汉堡 ---
# 就像工厂运作一样,从头到尾计算一次
# 机器A: u = w * x
u = w * x 
# 机器B: y = v * u
y = v * u 

print(f"原材料 x: {x}, 旋钮 w: {w.item()}, 旋钮 v: {v.item()}")
print(f"中间产品 u (肉饼): {u.item()}")
print(f"最终产品 y (汉堡): {y.item()}")
print(f"顾客订单 target_y: {target_y.item()}")

# --- 3. 计算误差 (Loss) ---
# 我们生产了多少,和订单差了多少?
# 这里用最简单的差的平方作为误差函数
loss = (y - target_y)**2
print(f"当前的误差 Loss: {loss.item()}")

# --- 4. 反向传播:开始自动“甩锅”!---
# 这是魔法发生的地方。从 loss 开始,反向计算梯度
loss.backward()

# --- 5. 查看结果:每个旋钮分到了多少“责任”? ---
# PyTorch已经帮我们算好了 loss 对每个旋钮的梯度
# 也就是 d(loss)/dw 和 d(loss)/dv

# 我们来手动验证一下,用链式法则:
# loss = (y - target_y)^2 = (v*u - target_y)^2 = (v*w*x - target_y)^2

# d(loss)/dv 是多少?
# d(loss)/dv = d(loss)/dy * dy/dv
# d(loss)/dy = 2 * (y - target_y) = 2 * (12 - 10) = 4
# dy/dv = u = 6
# 所以 d(loss)/dv = 4 * 6 = 24
print(f"Loss对旋钮v的梯度 (dLoss/dv): {v.grad.item()} (手动计算结果: 24.0)")

# d(loss)/dw 是多少?
# d(loss)/dw = d(loss)/dy * dy/du * du/dw
# d(loss)/dy = 4 (上面算过了)
# dy/du = v = 2
# du/dw = x = 2
# 所以 d(loss)/dw = 4 * 2 * 2 = 16
print(f"Loss对旋钮w的梯度 (dLoss/dw): {w.grad.item()} (手动计算结果: 16.0)")
原材料 x: 2.0, 旋钮 w: 3.0, 旋钮 v: 2.0
中间产品 u (肉饼): 6.0
最终产品 y (汉堡): 12.0
顾客订单 target_y: 10.0
当前的误差 Loss: 4.0
Loss对旋钮v的梯度 (dLoss/dv): 24.0 (手动计算结果: 24.0)
Loss对旋钮w的梯度 (dLoss/dw): 16.0 (手动计算结果: 16.0)

第五站:揭秘魔法——手动推导loss.backward()#

PyTorch的一行代码背后到底发生了什么?让我们手动用链式法则来验证一下结果。

A. 计算 Loss 对旋钮 v 的梯度 (\(\frac{\partial \text{Loss}}{\partial v}\))#

目标: 弄清楚旋钮 v 对最终误差 Loss 的影响。 路径: \(v \rightarrow y \rightarrow \text{Loss}\) 链式法则公式: $\( \frac{\partial \text{Loss}}{\partial v} = \frac{\partial \text{Loss}}{\partial y} \cdot \frac{\partial y}{\partial v} \)$

  1. 计算第一环 \(\frac{\partial \text{Loss}}{\partial y}\):

    • 含义: 最终汉堡数 y 每增加1,误差 Loss 会增加多少。

    • 计算: \(\text{Loss} = (y - y_{\text{target}})^2 \implies \frac{\partial \text{Loss}}{\partial y} = 2(y - y_{\text{target}})\)

    • 代入数值: \(2 \cdot (12.0 - 10.0) = 4.0\)

  2. 计算第二环 \(\frac{\partial y}{\partial v}\):

    • 含义: 旋钮 v 每增加1,汉堡产量 y 会增加多少。

    • 计算: \(y = v \cdot u \implies \frac{\partial y}{\partial v} = u\)

    • 代入数值: \(u = 6.0\)

  3. 合并结果: $\( \frac{\partial \text{Loss}}{\partial v} = 4.0 \cdot 6.0 = 24.0 \)$ 这与PyTorch计算的 v.grad (24.0) 完全一致!

B. 计算 Loss 对旋钮 w 的梯度 (\(\frac{\partial \text{Loss}}{\partial w}\))#

目标: 弄清楚旋钮 w 对最终误差 Loss 的影响。 路径: \(w \rightarrow u \rightarrow y \rightarrow \text{Loss}\) 链式法则公式: $\( \frac{\partial \text{Loss}}{\partial w} = \frac{\partial \text{Loss}}{\partial y} \cdot \frac{\partial y}{\partial u} \cdot \frac{\partial u}{\partial w} \)$

  1. 计算第一环 \(\frac{\partial \text{Loss}}{\partial y}\):

    • 含义: 同上,这个结果可以复用

    • 数值: \(4.0\)

  2. 计算第二环 \(\frac{\partial y}{\partial u}\):

    • 含义: 肉饼 u 每增加1,汉堡产量 y 会增加多少。

    • 计算: \(y = v \cdot u \implies \frac{\partial y}{\partial u} = v\)

    • 代入数值: \(v = 2.0\)

  3. 计算第三环 \(\frac{\partial u}{\partial w}\):

    • 含义: 旋钮 w 每增加1,肉饼产量 u 会增加多少。

    • 计算: \(u = w \cdot x \implies \frac{\partial u}{\partial w} = x\)

    • 代入数值: \(x = 2.0\)

  4. 合并结果: $\( \frac{\partial \text{Loss}}{\partial w} = 4.0 \cdot 2.0 \cdot 2.0 = 16.0 \)$ 这也与PyTorch计算的 w.grad (16.0) 完全一致!

最终总结#

现在,你应该对链式法则有了透彻的理解:

  1. 它是什么? 一种计算“函数套函数”(复合函数)导数的方法。核心思想是**“将一长串间接影响,拆解为一环扣一环的直接影响的乘积”**。

  2. 它在神经网络中的作用? 它就是反向传播的数学引擎。从最终的误差出发,利用链式法则,一步步反向计算出误差对网络中每一个参数(权重)的梯度(“责任”或“敏感度”)。

  3. PyTorch为我们做了什么? 你只需定义好正向的计算流程(工厂如何生产)。当你调用 .backward() 时,PyTorch会自动构建计算图,并完美地执行上述所有链式法则的运算,将梯度结果存放在 .grad 属性中,让你能轻松地更新参数、优化网络。

理解了链式法则,你就掌握了驱动深度学习模型自动学习的核心秘密。