反向传播#
我们书接上回,链式法则的内容
第三站:工厂升级——引入“误差”与“反向传播”#
一个好的工厂不仅要能生产,还要能按订单生产。
目标: 顾客订单是 10个 汉堡。
现状: 我们现在的机器设置生产出了 12个 汉堡。
误差(Loss): 我们生产多了,有了误差。我们用损失函数 \(\text{Loss} = (y - y_{\text{target}})^2\) 来量化这个误差。
我们的新任务是:调整旋钮 w 和 v,来减小这个误差 Loss。
为了智能地调整,我们需要知道:
误差
Loss对旋钮v有多敏感?(即 \(\frac{\partial \text{Loss}}{\partial v}\))误差
Loss对旋钮w有多敏感?(即 \(\frac{\partial \text{Loss}}{\partial w}\))
这个寻找“责任人”的过程,必须从后往前:
最终误差:我们知道
Loss的大小。追溯到机器B:
Loss的产生和最终产量y直接相关。而y是由机器B的旋钮v决定的。所以我们可以计算出Loss对v的“责任”。继续追溯到机器A:
y的变化也受到肉饼u的影响,而u是由机器A的旋钮w决定的。我们可以沿着链条,把“责任”一直传递给w。
这个从最终误差出发,利用链式法则,反向追溯并计算每个参数“责任”(梯度)的过程,就是反向传播(Backpropagation)。
第四站:上代码!模拟工厂的智能调优#
现在,我们用代码来精确模拟上面那个“调整旋钮”的过程。
场景设定:
机器A 的功能是:肉饼数量 u = 旋钮w * 牛肉量x。
机器B 的功能是:汉堡数量 y = 旋钮v * 肉饼数量u。这完美对应了“1个肉饼产出v个汉堡”的比喻。
我们来看看PyTorch如何帮我们自动“甩锅”。
import torch
# --- 1. 初始化我们的“工厂” ---
# 旋钮 w 和 v 是我们需要学习和调整的参数
w = torch.tensor(3.0, requires_grad=True) # 机器A的旋钮,初始值为3
v = torch.tensor(2.0, requires_grad=True) # 机器B的旋钮,初始值为2
# 输入的原材料
x = torch.tensor(2.0) # 固定的2公斤牛肉
# 顾客的订单 (我们的目标)
target_y = torch.tensor(10.0) 
# --- 2. 正向传播:生产一次汉堡 ---
# 就像工厂运作一样,从头到尾计算一次
# 机器A: u = w * x
u = w * x 
# 机器B: y = v * u
y = v * u 
print(f"原材料 x: {x}, 旋钮 w: {w.item()}, 旋钮 v: {v.item()}")
print(f"中间产品 u (肉饼): {u.item()}")
print(f"最终产品 y (汉堡): {y.item()}")
print(f"顾客订单 target_y: {target_y.item()}")
# --- 3. 计算误差 (Loss) ---
# 我们生产了多少,和订单差了多少?
# 这里用最简单的差的平方作为误差函数
loss = (y - target_y)**2
print(f"当前的误差 Loss: {loss.item()}")
# --- 4. 反向传播:开始自动“甩锅”!---
# 这是魔法发生的地方。从 loss 开始,反向计算梯度
loss.backward()
# --- 5. 查看结果:每个旋钮分到了多少“责任”? ---
# PyTorch已经帮我们算好了 loss 对每个旋钮的梯度
# 也就是 d(loss)/dw 和 d(loss)/dv
# 我们来手动验证一下,用链式法则:
# loss = (y - target_y)^2 = (v*u - target_y)^2 = (v*w*x - target_y)^2
# d(loss)/dv 是多少?
# d(loss)/dv = d(loss)/dy * dy/dv
# d(loss)/dy = 2 * (y - target_y) = 2 * (12 - 10) = 4
# dy/dv = u = 6
# 所以 d(loss)/dv = 4 * 6 = 24
print(f"Loss对旋钮v的梯度 (dLoss/dv): {v.grad.item()} (手动计算结果: 24.0)")
# d(loss)/dw 是多少?
# d(loss)/dw = d(loss)/dy * dy/du * du/dw
# d(loss)/dy = 4 (上面算过了)
# dy/du = v = 2
# du/dw = x = 2
# 所以 d(loss)/dw = 4 * 2 * 2 = 16
print(f"Loss对旋钮w的梯度 (dLoss/dw): {w.grad.item()} (手动计算结果: 16.0)")
原材料 x: 2.0, 旋钮 w: 3.0, 旋钮 v: 2.0
中间产品 u (肉饼): 6.0
最终产品 y (汉堡): 12.0
顾客订单 target_y: 10.0
当前的误差 Loss: 4.0
Loss对旋钮v的梯度 (dLoss/dv): 24.0 (手动计算结果: 24.0)
Loss对旋钮w的梯度 (dLoss/dw): 16.0 (手动计算结果: 16.0)
第五站:揭秘魔法——手动推导loss.backward()#
PyTorch的一行代码背后到底发生了什么?让我们手动用链式法则来验证一下结果。
A. 计算 Loss 对旋钮 v 的梯度 (\(\frac{\partial \text{Loss}}{\partial v}\))#
目标: 弄清楚旋钮 v 对最终误差 Loss 的影响。
路径: \(v \rightarrow y \rightarrow \text{Loss}\)
链式法则公式:
$\(
\frac{\partial \text{Loss}}{\partial v} = \frac{\partial \text{Loss}}{\partial y} \cdot \frac{\partial y}{\partial v}
\)$
计算第一环 \(\frac{\partial \text{Loss}}{\partial y}\):
含义: 最终汉堡数
y每增加1,误差Loss会增加多少。计算: \(\text{Loss} = (y - y_{\text{target}})^2 \implies \frac{\partial \text{Loss}}{\partial y} = 2(y - y_{\text{target}})\)
代入数值: \(2 \cdot (12.0 - 10.0) = 4.0\)
计算第二环 \(\frac{\partial y}{\partial v}\):
含义: 旋钮
v每增加1,汉堡产量y会增加多少。计算: \(y = v \cdot u \implies \frac{\partial y}{\partial v} = u\)
代入数值: \(u = 6.0\)
合并结果: $\( \frac{\partial \text{Loss}}{\partial v} = 4.0 \cdot 6.0 = 24.0 \)$ 这与PyTorch计算的
v.grad(24.0) 完全一致!
B. 计算 Loss 对旋钮 w 的梯度 (\(\frac{\partial \text{Loss}}{\partial w}\))#
目标: 弄清楚旋钮 w 对最终误差 Loss 的影响。
路径: \(w \rightarrow u \rightarrow y \rightarrow \text{Loss}\)
链式法则公式:
$\(
\frac{\partial \text{Loss}}{\partial w} = \frac{\partial \text{Loss}}{\partial y} \cdot \frac{\partial y}{\partial u} \cdot \frac{\partial u}{\partial w}
\)$
计算第一环 \(\frac{\partial \text{Loss}}{\partial y}\):
含义: 同上,这个结果可以复用!
数值: \(4.0\)
计算第二环 \(\frac{\partial y}{\partial u}\):
含义: 肉饼
u每增加1,汉堡产量y会增加多少。计算: \(y = v \cdot u \implies \frac{\partial y}{\partial u} = v\)
代入数值: \(v = 2.0\)
计算第三环 \(\frac{\partial u}{\partial w}\):
含义: 旋钮
w每增加1,肉饼产量u会增加多少。计算: \(u = w \cdot x \implies \frac{\partial u}{\partial w} = x\)
代入数值: \(x = 2.0\)
合并结果: $\( \frac{\partial \text{Loss}}{\partial w} = 4.0 \cdot 2.0 \cdot 2.0 = 16.0 \)$ 这也与PyTorch计算的
w.grad(16.0) 完全一致!
最终总结#
现在,你应该对链式法则有了透彻的理解:
它是什么? 一种计算“函数套函数”(复合函数)导数的方法。核心思想是**“将一长串间接影响,拆解为一环扣一环的直接影响的乘积”**。
它在神经网络中的作用? 它就是反向传播的数学引擎。从最终的误差出发,利用链式法则,一步步反向计算出误差对网络中每一个参数(权重)的梯度(“责任”或“敏感度”)。
PyTorch为我们做了什么? 你只需定义好正向的计算流程(工厂如何生产)。当你调用
.backward()时,PyTorch会自动构建计算图,并完美地执行上述所有链式法则的运算,将梯度结果存放在.grad属性中,让你能轻松地更新参数、优化网络。
理解了链式法则,你就掌握了驱动深度学习模型自动学习的核心秘密。