重0开始使用 CNN 进行手写数字识别#

这个 Jupyter Notebook 将引导你使用 PyTorch 框架完成 CNN 模型训练手写数字识别的整个过程。我们同样使用 MNIST 数据集。

目标:

  1. 使用 torchvision 加载并预处理 MNIST 数据集。

  2. 构建一个继承自 torch.nn.Module 的卷积神经网络(CNN)模型。

  3. 定义损失函数和优化器。

  4. 编写训练循环来训练模型。

  5. 在测试集上评估模型的性能。

  6. 可视化预测结果。

1. 导入必要的库#

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

代码详解 & 为什么这么写?#

  • import torch: 导入 PyTorch 核心库,所有张量(Tensor)操作和基本功能都来源于此。

  • import torch.nn as nn: nn 是 neural network 的缩写。这个模块包含了构建神经网络所需的所有基本组件,比如卷积层 (nn.Conv2d)、线性层 (nn.Linear)、激活函数 (nn.ReLU) 以及所有模型的基类 (nn.Module)。我们用 as nn 是一个约定俗成的简写,方便调用。

  • import torch.optim as optim: 包含了各种优化算法,比如我们后面用到的 Adam (optim.Adam) 和 SGD。优化器的作用是根据模型计算出的误差来更新模型的权重。

  • from torch.utils.data import DataLoader: 这是一个非常重要且方便的工具。它能帮我们把庞大的数据集自动打包成一小批一小批(batch)的数据,并且可以实现自动洗牌(shuffle),让模型训练更有效率。

  • from torchvision import datasets, transforms: torchvision 是 PyTorch 官方的计算机视觉工具库。datasets 里包含了许多常用的数据集(比如 MNIST),可以直接下载使用。transforms 包含了各种对图像进行预处理的操作(比如转换成张量、归一化等)。

  • import matplotlib.pyplot as plt 和 import numpy as np: 这两个是数据科学的常用工具。Matplotlib 用于数据可视化(画图),NumPy 用于处理数组,这里我们主要用它来辅助 Matplotlib 显示图像。

2. 定义超参数和设备#

我们首先定义一些训练过程中会用到的超参数,并设置计算设备(优先使用 GPU)。

# 定义超参数
EPOCHS = 10  # Epoch 指的是“轮次”,一个 Epoch 代表模型完整地学习了一遍所有的训练数据。10 个 Epoch 就是让模型把整个数据集重复学习 10 遍。
BATCH_SIZE = 64  # “批大小”指的是模型每次更新权重时看多少张图片。我们不能一次把所有 60000 张训练图片都塞进内存,所以我们把它们分成一小批一小批地喂给模型。这里每批 64 张。
LEARNING_RATE = 0.001  # “学习率”控制了模型权重更新的幅度。太大了模型可能不稳定,太小了模型学习得太慢。这是一个需要调试的关键超参数。

# 设置设备 (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
Using device: cpu

3. 加载和预处理 MNIST 数据集#

我们将使用 torchvision 来加载数据。在加载过程中,我们会定义一个 transform 流水线来进行数据预处理:

transforms.ToTensor(): 将 PIL 图像或 NumPy 数组转换为 PyTorch 张量,并将像素值从 [0, 255] 缩放到 [0.0, 1.0]

那么问题来了,为什么我们需要将像素值从 [0, 255] 缩放到 [0.0, 1.0]#

加快模型收敛速度(最重要的原因)

想象一下神经网络中的一个简单计算:

output = weight * input + bias

在训练过程中,模型的目标是调整 weight 和 bias 来让 output 接近我们期望的值。这个调整是通过梯度下降来完成的,调整的幅度(步长)由学习率 (Learning Rate)梯度 (Gradient) 共同决定。

  • 如果不缩放 (input 在 0-255 之间):

    • input 的值很大(比如 150, 220)。

    • 为了让 output 发生一点点变化,weight 只需要改变一个非常非常小的值。

    • 这会导致损失函数对于权重的梯度非常大。想象一个非常陡峭的山坡,你稍微动一下就可能“滚”出很远。

    • 大的梯度意味着优化器在更新权重时会进行非常大的跳跃,这使得训练过程非常不稳定,容易在最优解附近来回震荡,难以收敛。为了稳定训练,你不得不设置一个非常小的学习率,但这又会大大减慢训练速度。

  • 如果缩放 (input 在 0.0-1.0 之间):

    • input 的值很小(比如 0.5, 0.8)。

    • 现在,输入值和权重值的大小在同一个数量级上。

    • 这使得损失函数的梯度变得更小、更稳定。就像一个平缓的山坡,你可以稳健地、一步一步地走向谷底(最优解)。

    • 因为梯度稳定,我们可以使用一个相对较大的学习率,从而大大加快模型的收敛速度

一个简单的比喻:

假设你要调整房间的温度(目标)和你身上穿的衣服(权重)。

  • 不缩放的情况:温度计用的是开尔文(比如 293 K),而你的衣服只有“穿”或“不穿”两个选项。这两个单位的尺度差别巨大,你很难做出精细的调整。

  • 缩放的情况:温度计用的是一个“舒适度”指标,从 0 到 1。你的衣服也可以分为多个等级,比如“短袖”(0.2)、“毛衣”(0.6)、“羽绒服”(1.0)。现在,输入和权重的尺度匹配了,你可以更容易地学习到“当舒适度为 0.1 时,我应该穿羽绒服”这样的关系。

避免激活函数饱和,防止梯度消失

在 CNN 中,我们常用 ReLU 作为激活函数。但在早期的神经网络中,Sigmoid 和 Tanh 函数非常流行。它们都有一个共同的问题:饱和区

  • Sigmoid 函数:将任何输入映射到 (0, 1) 之间。当输入值的绝对值很大时(比如 > 5 或 < -5),函数的曲线变得非常平坦,其导数(梯度)几乎为 0。

  • Tanh 函数:将任何输入映射到 (-1, 1) 之间,同样在两端有饱和区。

如果输入像素值是 [0, 255],那么经过权重和偏置的计算后(weight * input + bias),得到的结果很容易就落入激活函数的饱和区。

后果是什么?

在反向传播时,权重的更新量正比于这个梯度。如果梯度为 0,那么无论链式法则前面部分的梯度有多大,传到这里的梯度也变成了 0。这意味着这个神经元的权重将不会得到任何更新。这个现象被称为“梯度消失”(Vanishing Gradients),它会导致网络学习极其缓慢甚至完全停止。

通过将输入缩放到 [0, 1],我们可以让大部分计算结果落在激活函数的“敏感区”(梯度较大的区域),从而保证梯度能够有效地回传,让网络持续学习。

数值稳定性和一致性

  • 一致性:在更复杂的模型中,你可能会有多种不同来源的输入数据。比如,除了图片像素值(0-255),你可能还有另一个特征是“图片的平均亮度”(0-1)。如果不对它们进行缩放,那么像素值这个特征在数值上会完全主导模型的初始学习过程,因为它比另一个特征大得多。将所有特征都缩放到相似的范围(如 [0, 1])可以确保它们在模型学习中处于一个“公平的起跑线”。

  • 数值稳定性:在计算机进行浮点数运算时,处理绝对值较小的数字通常比处理大数更稳定,可以减少出现数值溢出(overflow)或下溢(underflow)的风险。 定的均值和标准差对张量进行归一化。这里我们将数据归一化到 [-1.0, 1.0] 的范围,这有助于模型训练。

# 定义数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图片的像素值从 [0, 255] 的整数范围,自动缩放到 [0.0, 1.0] 的浮点数范围。
    transforms.Normalize((0.5,), (0.5,))   # 用给定的均值和标准差对张量进行归一化。这里我们将数据归一化到 `[-1.0, 1.0]` 的范围,这有助于模型训练。
    # 计算公式为:output = (input - mean) / std

])

# 下载并加载训练数据集
train_dataset = datasets.MNIST(root='../raw/data/', train=True, download=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# 下载并加载测试数据集
test_dataset = datasets.MNIST(root='../raw/data/', train=False, download=True, transform=transform)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)
# train=True 表示加载训练集。
# shuffle=True 它会在每个 epoch 开始前,把所有数据的顺序打乱。这可以防止模型学习到数据的排列顺序,从而帮助模型更好地泛化。测试集则不需要打乱

4. 可视化部分数据#

主要是用来看加载的数据的,这段代码也可以不用

# 获取一个批次的数据
data_iter = iter(train_loader)
images, labels = next(data_iter)

# 可视化
plt.figure(figsize=(10, 5))
for i in range(10):
    plt.subplot(2, 5, i + 1)
    # 反归一化以便正确显示
    img = images[i] / 2 + 0.5  # 因为我们之前用 transforms.Normalize((0.5,), (0.5,)) 处理过它
                               # 所以我们需要反归一化来进行处理
                               # 归一化公式: output = (input - 0.5) / 0.5
    plt.imshow(img.squeeze(), cmap='gray')  # squeeze() 它的作用是移除张量中所有尺寸为 1 的维度。
                                            # img 的原始形状是 (1, 28, 28)
                                            # imshow 函数在处理二维数据(灰度图)时,期望的输入形状是 (Height, Width)
    plt.title(f"Label: {labels[i].item()}")
    plt.axis('off')
plt.show()
../_images/f4060861c553741895a44c852221cfabe6ef1b6c7d8b5d1697f7f9402a33713c.png

images, labels = next(data_iter) 这段代码怎么理解

images 是什么?#

一句话概括:它是一个包含了BATCH_SIZE张图片数据的集合,已经经过了预处理,随时可以送入模型。

  • 数据类型: 它是一个 PyTorch 张量 (Tensor)。这是 PyTorch 中用于所有计算的基本数据结构,你可以把它看作是能利用 GPU 加速的、功能更强大的 NumPy 数组。

  • 数据形状 (Shape): 它的形状是 (BATCH_SIZE, Channels, Height, Width)。根据我们代码中的设置,具体形状就是 (64, 1, 28, 28)

    • 64: 这是批次大小 (Batch Size)。代表这个 images 张量里包含了 64 张不同的图片。

    • 1: 这是颜色通道数 (Channels)。因为 MNIST 是灰度图,所以只有 1 个通道。如果是彩色图片(RGB),这里会是 3。

    • 28: 这是每张图片的高度 (Height),单位是像素。

    • 28: 这是每张图片的宽度 (Width),单位是像素。

  • 数据内容: 里面的值是浮点数 (float)。因为我们定义了 transforms,原始的 0-255 的整数像素值已经被转换成了 [-1.0, 1.0] 范围内的浮点数。

  • 它的用途: 这是我们模型的输入数据。在训练循环中,我们会把它喂给模型,像这样:outputs = model(images)

图形化理解 images (形状: (64, 1, 28, 28))

你可以把它想象成一个叠了 64 层的“三明治”,每一层就是一张 28x28 的单通道图片。

      +---------------+  <-- Image 1 (1x28x28)
     /               /|
    /               / |
   +---------------+  |
  /               /|  +
 /               / | /|
+---------------+  |/ |
|               |  +  |
| (64 images)   | /|  /  <-- 这是一个包含64个元素的批次 (Batch)
|               |/ | /
+---------------+  |/
|               |  +
|               | /
+---------------+

labels 是什么?#

一句话概括:它是一个包含了与 images 中 64 张图片一一对应的正确答案(标签)的列表。

  • 数据类型: 它也是一个 PyTorch 张量 (Tensor)

  • 数据形状 (Shape): 它的形状是 (BATCH_SIZE,)。根据我们的设置,具体形状就是 (64,)

    • 这代表它是一个包含 64 个元素的一维向量(或者说列表)。

  • 数据内容: 里面的值是整数 (integer),范围从 09labels 张量中的第 i 个数字,就是 images 张量中第 i 张图片所代表的真实数字。

    • 例如,labels 可能是这样的:tensor([5, 0, 4, 1, 9, 2, ..., 8])

    • labels[0] 的值是 5,意味着 images[0] 这张图片上画的是数字“5”。

    • labels[1] 的值是 0,意味着 images[1] 这张图片上画的是数字“0”。

    • 以此类推…

  • 它的用途: 这是理想答案或“标准答案” (Ground Truth)。我们会用它和模型的预测结果 outputs 一起,来计算损失值,像这样:loss = criterion(outputs, labels)


总结表格#

特性

images

labels

用途

模型的输入 (Input)

正确的答案 (Ground Truth)

数据类型

PyTorch 张量 (Tensor)

PyTorch 张量 (Tensor)

形状

(64, 1, 28, 28) (四维)

(64,) (一维)

数值类型

浮点数 (float)

整数 (integer)

数值范围

[-1.0, 1.0]

09

对应关系

包含 64 张图片的数据

包含与 64 张图片一一对应的数字标签

5. 构建 CNN 模型#

在 PyTorch 中,我们通常通过创建一个继承自 torch.nn.Module 的类来定义模型结构。我们在 __init__ 方法中定义网络层,在 forward 方法中定义数据的前向传播路径。

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # 输入图像: 1x28x28 (channel, height, width)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
            # -> 16x28x28
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2) # -> 16x14x14
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2), # -> 32x14x14
            nn.ReLU(),
            nn.MaxPool2d(2) # -> 32x7x7
        )
        # 全连接层
        self.out = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x): # 根据上面的配置,传入进来的X的形状: (64, 1, 28, 28)
        x = self.conv1(x) 
        x = self.conv2(x)
        # 展平操作
        x = x.view(x.size(0), -1) # batch_size, 32*7*7
        output = self.out(x)
        return output

# 实例化模型并移动到设备
model = CNN().to(device)
print(model)
CNN(
  (conv1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (out): Linear(in_features=1568, out_features=10, bias=True)
)

追踪一批数据 (Batch) 的变形记#

假设我们的 BATCH_SIZE 是 64。那么,输入到 forward 方法的 x 是一个形状 (Shape) 为 (64, 1, 28, 28) 的张量。

  • 64: 批次里有 64 张图片。

  • 1: 每张图片是灰度的,只有 1 个颜色通道。

  • 28, 28: 每张图片的高度和宽度是 28 像素。

现在,我们来追踪这批数据在 forward 函数中的每一步变化:

初始状态: x 的形状: (64, 1, 28, 28)


第 1 步: x = self.conv1(x) 这套流水线包含三道工序:

  1. nn.Conv2d(1, 16, 5, padding=2):

    • 16 个 5x5 的卷积核在 28x28 的图像上滑动。

    • 因为 padding=2,输出的图像尺寸不变,还是 28x28。

    • 因为有 16 个卷积核,所以输出通道数从 1 变为 16。

    • 形状变为: (64, 16, 28, 28)

  2. nn.ReLU(): 激活函数,对每个元素进行计算,不改变形状

  3. nn.MaxPool2d(2):

    • 在 2x2 的窗口内取最大值,这会将图像的高度和宽度都减半

    • 28 / 2 = 14。

    • 形状变为: (64, 16, 14, 14)


第 2 步: x = self.conv2(x) 这套流水线也包含三道工序:

  1. nn.Conv2d(16, 32, 5, padding=2):

    • 输入通道数是 16,输出通道数是 32。

    • 尺寸 14x14 同样因为 padding 而保持不变。

    • 形状变为: (64, 32, 14, 14)

  2. nn.ReLU(): 不改变形状

  3. nn.MaxPool2d(2):

    • 高度和宽度再次减半。14 / 2 = 7。

    • 形状变为: (64, 32, 7, 7)

到这里,“特征提取”阶段结束。我们得到了 64 张图片,每张图片都被浓缩成了一个 32x7x7 的特征图谱。


第 3 步: x = x.view(x.size(0), -1) (展平)

  • 这个操作会保持第一个维度(批次大小 64)不变,然后将后面的所有维度 (32, 7, 7) 拉成一个一维向量。

  • 向量的长度 = 32 * 7 * 7 = 1568

  • 形状变为: (64, 1568)

  • 现在我们有 64 个样本,每个样本都由一个包含 1568 个浓缩特征值的长向量表示。


第 4 步: output = self.out(x) (全连接层)

  • self.out 是 nn.Linear(32 * 7 * 7, 10),也就是 nn.Linear(1568, 10)。

  • 它接收一个长度为 1568 的向量,输出一个长度为 10 的向量。

  • 它对批次中的每一张图片都执行这个操作。

  • 最终输出形状: (64, 10)


最终结果: 我们输入了一批 (64, 1, 28, 28) 的图像,最终得到了一个 (64, 10) 的张量。这个张量的每一行都代表一张输入图片,该行包含 10 个分数(logits),分别对应模型预测这张图片是数字 0, 1, 2, …, 9 的可能性。接下来的损失函数(CrossEntropyLoss)就会使用这个输出来计算误差。

总结表格#

操作

输入形状

输出形状

逻辑说明

初始输入

(64, 1, 28, 28)

一批 64 张 28x28 的灰度图

self.conv1

(64, 1, 28, 28)

(64, 16, 14, 14)

提取初步特征,尺寸减半,深度增加

self.conv2

(64, 16, 14, 14)

(64, 32, 7, 7)

提取更复杂特征,尺寸再次减半,深度再增加

x.view()

(64, 32, 7, 7)

(64, 1568)

将三维特征图谱展平成一维特征向量

self.out

(64, 1568)

(64, 10)

根据特征向量,为 10 个类别打分

6. 定义损失函数和优化器#

# 损失函数: 交叉熵损失,它内部已经包含了 Softmax。是用于多分类问题的标准损失函数。它衡量的是模型输出的概率分布与真实的标签之间的差距。
criterion = nn.CrossEntropyLoss()  

# 优化器
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

7. 训练模型#

这是 PyTorch 中最核心的训练循环。对于每个批次的数据:

  1. 将数据移动到指定设备。

  2. 通过模型进行前向传播得到预测输出。

  3. 计算损失。

  4. 将梯度清零 (optimizer.zero_grad())。

  5. 反向传播计算梯度 (loss.backward())。

  6. 更新模型权重 (optimizer.step())。

print("Starting Training...")
total_step = len(train_loader)

for epoch in range(EPOCHS):
    for i, (images, labels) in enumerate(train_loader):
        # 将数据移动到设备
        images = images.to(device)
        labels = labels.to(device)

        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{EPOCHS}], Step [{i + 1}/{total_step}], Loss: {loss.item():.4f}')
            
print("Training Finished!")
Starting Training...
Epoch [1/10], Step [100/938], Loss: 0.3346
Epoch [1/10], Step [200/938], Loss: 0.1026
Epoch [1/10], Step [300/938], Loss: 0.1289
Epoch [1/10], Step [400/938], Loss: 0.1737
Epoch [1/10], Step [500/938], Loss: 0.0475
Epoch [1/10], Step [600/938], Loss: 0.0556
Epoch [1/10], Step [700/938], Loss: 0.0440
Epoch [1/10], Step [800/938], Loss: 0.0572
Epoch [1/10], Step [900/938], Loss: 0.0749
Epoch [2/10], Step [100/938], Loss: 0.0164
Epoch [2/10], Step [200/938], Loss: 0.0191
Epoch [2/10], Step [300/938], Loss: 0.0391
Epoch [2/10], Step [400/938], Loss: 0.1859
Epoch [2/10], Step [500/938], Loss: 0.0769
Epoch [2/10], Step [600/938], Loss: 0.0077
Epoch [2/10], Step [700/938], Loss: 0.0302
Epoch [2/10], Step [800/938], Loss: 0.0603
Epoch [2/10], Step [900/938], Loss: 0.1187
Epoch [3/10], Step [100/938], Loss: 0.0200
Epoch [3/10], Step [200/938], Loss: 0.0223
Epoch [3/10], Step [300/938], Loss: 0.0096
Epoch [3/10], Step [400/938], Loss: 0.0014
Epoch [3/10], Step [500/938], Loss: 0.0099
Epoch [3/10], Step [600/938], Loss: 0.0054
Epoch [3/10], Step [700/938], Loss: 0.0630
Epoch [3/10], Step [800/938], Loss: 0.0145
Epoch [3/10], Step [900/938], Loss: 0.0204
Epoch [4/10], Step [100/938], Loss: 0.0067
Epoch [4/10], Step [200/938], Loss: 0.0032
Epoch [4/10], Step [300/938], Loss: 0.0229
Epoch [4/10], Step [400/938], Loss: 0.0768
Epoch [4/10], Step [500/938], Loss: 0.0402
Epoch [4/10], Step [600/938], Loss: 0.0097
Epoch [4/10], Step [700/938], Loss: 0.0091
Epoch [4/10], Step [800/938], Loss: 0.0096
Epoch [4/10], Step [900/938], Loss: 0.0044
Epoch [5/10], Step [100/938], Loss: 0.0085
Epoch [5/10], Step [200/938], Loss: 0.0462
Epoch [5/10], Step [300/938], Loss: 0.0185
Epoch [5/10], Step [400/938], Loss: 0.0135
Epoch [5/10], Step [500/938], Loss: 0.0020
Epoch [5/10], Step [600/938], Loss: 0.0041
Epoch [5/10], Step [700/938], Loss: 0.0678
Epoch [5/10], Step [800/938], Loss: 0.0012
Epoch [5/10], Step [900/938], Loss: 0.2107
Epoch [6/10], Step [100/938], Loss: 0.0051
Epoch [6/10], Step [200/938], Loss: 0.0025
Epoch [6/10], Step [300/938], Loss: 0.0006
Epoch [6/10], Step [400/938], Loss: 0.0111
Epoch [6/10], Step [500/938], Loss: 0.0096
Epoch [6/10], Step [600/938], Loss: 0.0012
Epoch [6/10], Step [700/938], Loss: 0.0066
Epoch [6/10], Step [800/938], Loss: 0.0522
Epoch [6/10], Step [900/938], Loss: 0.0022
Epoch [7/10], Step [100/938], Loss: 0.0057
Epoch [7/10], Step [200/938], Loss: 0.0022
Epoch [7/10], Step [300/938], Loss: 0.0024
Epoch [7/10], Step [400/938], Loss: 0.0220
Epoch [7/10], Step [500/938], Loss: 0.0006
Epoch [7/10], Step [600/938], Loss: 0.0044
Epoch [7/10], Step [700/938], Loss: 0.0246
Epoch [7/10], Step [800/938], Loss: 0.0362
Epoch [7/10], Step [900/938], Loss: 0.0238
Epoch [8/10], Step [100/938], Loss: 0.0080
Epoch [8/10], Step [200/938], Loss: 0.0013
Epoch [8/10], Step [300/938], Loss: 0.0073
Epoch [8/10], Step [400/938], Loss: 0.0026
Epoch [8/10], Step [500/938], Loss: 0.0007
Epoch [8/10], Step [600/938], Loss: 0.0003
Epoch [8/10], Step [700/938], Loss: 0.0061
Epoch [8/10], Step [800/938], Loss: 0.0003
Epoch [8/10], Step [900/938], Loss: 0.0001
Epoch [9/10], Step [100/938], Loss: 0.0014
Epoch [9/10], Step [200/938], Loss: 0.0001
Epoch [9/10], Step [300/938], Loss: 0.0063
Epoch [9/10], Step [400/938], Loss: 0.0018
Epoch [9/10], Step [500/938], Loss: 0.0256
Epoch [9/10], Step [600/938], Loss: 0.0138
Epoch [9/10], Step [700/938], Loss: 0.0011
Epoch [9/10], Step [800/938], Loss: 0.0007
Epoch [9/10], Step [900/938], Loss: 0.0046
Epoch [10/10], Step [100/938], Loss: 0.0038
Epoch [10/10], Step [200/938], Loss: 0.0056
Epoch [10/10], Step [300/938], Loss: 0.0166
Epoch [10/10], Step [400/938], Loss: 0.0003
Epoch [10/10], Step [500/938], Loss: 0.0017
Epoch [10/10], Step [600/938], Loss: 0.0213
Epoch [10/10], Step [700/938], Loss: 0.0002
Epoch [10/10], Step [800/938], Loss: 0.0016
Epoch [10/10], Step [900/938], Loss: 0.0005
Training Finished!

8. 评估模型#

在评估阶段,我们不需要计算梯度,所以使用 with torch.no_grad(): 来节省计算资源。我们将模型设置为评估模式 model.eval(),这会关闭像 Dropout 和 BatchNorm 这样的层(尽管我们这个简单模型没有使用它们)。

model.eval()  # 设置模型为评估模式
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f'Accuracy of the model on the 10000 test images: {100 * correct / total:.2f} %')
Accuracy of the model on the 10000 test images: 99.10 %

9. 可视化预测结果#

# 获取一个批次的测试数据
data_iter = iter(test_loader)
images, labels = next(data_iter)
images_on_device = images.to(device)

# 预测
outputs = model(images_on_device)
_, predicted = torch.max(outputs.data, 1)

# 可视化
plt.figure(figsize=(15, 8))
for i in range(15):
    if i >= len(images): break
    plt.subplot(3, 5, i + 1)
    img = images[i] / 2 + 0.5 # 反归一化
    plt.imshow(img.squeeze(), cmap='gray')
    
    pred_label = predicted[i].item()
    true_label = labels[i].item()
    color = 'green' if pred_label == true_label else 'red'
    
    plt.title(f"Pred: {pred_label} (True: {true_label})", color=color)
    plt.axis('off')
plt.show()
../_images/419fc3175fd8392217d9b1aa303ccd7b971be39dd93402d8cb955091e340586e.png

总结#

CNN 数据流的图形化分解#

我们再次追踪一张手写数字“8”的图片,看它如何在网络中一步步变形。

第 0 步: 输入 (Input)#

一切的开始:一张标准的、经过预处理的28x28像素灰度图。

+--------------------------+
|       Input Image        |
|      (Number '8')        |
|                          |
|        ██████            |
|      ██      ██          |
|        ██████            |
|      ██      ██          |
|        ██████            |
|                          |
+--------------------------+
  Shape: [1, 28, 28]
(通道=1, 高=28, 宽=28)

第 1 步: 进入 self.conv1 流水线 (提取基础特征)#

图像变厚、变小。网络通过16个不同的“滤镜”扫描图像,提取出边缘、角点等基础特征,然后将信息进行浓缩。

     Input Image                      Feature Maps (after Conv2d)             Feature Maps (after Pool)
  Shape: [1, 28, 28]                   Shape: [16, 28, 28]                   Shape: [16, 14, 14]
+--------------------+                                                        
|                    |                                                        
|      (Image)       | --- Conv2d (1 -> 16) --->  +-------+  --- MaxPool2d --->  +-----+
|                    |                            | ▒▒▒▒▒ |      (Size / 2)      | ▒▒▒ |
+--------------------+                            | ▒▒▒▒▒ |                      | ▒▒▒ |
                                                  |  ...  | (16 maps)            | ... | (16 maps)
                                                  +-------+                      +-----+

第 2 步: 进入 self.conv2 流水线 (提取复杂特征)#

特征图变得更厚、更小。网络在基础特征之上,组合出更复杂的形状,比如构成数字“8”的小圆圈和弧线。

  Input Feature Maps                 Feature Maps (after Conv2d)             Feature Maps (after Pool)
 Shape: [16, 14, 14]                  Shape: [32, 14, 14]                   Shape: [32, 7, 7]
                                                                          
  +----------------+                                                        
  |                |                                                        
  | (16 map stack) | --- Conv2d (16 -> 32) -->  +-------+  --- MaxPool2d --->  +-----+
  |                |                            | ▓▓▓▓▓ |      (Size / 2)      | ▓▓▓ |
  +----------------+                            | ▓▓▓▓▓ |                      | ▓▓▓ |
                                                |  ...  | (32 maps)            | ... | (32 maps)
                                                +-------+                      +-----+

第 3 步: 展平 (Flatten / View)#

从三维空间到一维空间。为了让分类器能够理解,我们将这个浓缩后的三维“特征立方体”强行拉平成一个一维的“特征长条”。

      Input "Feature Cube"                             Output "Feature Vector"
       Shape: [32, 7, 7]                                  Shape: [1568]

      +---------+
     /         /|
    /         / | ----> (32 layers deep)
   +---------+  |
   |         |  |
   | (7x7)   |  +
   |         | /
   +---------+          ------ Flatten ------>        [▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓]
  (32 * 7 * 7 = 1568)                                   (A vector with 1568 values)

第 4 步: self.out 分类决策 (Classification)#

根据“特征指纹”进行投票,得出最终结论。

       Input "Feature Vector"                               Output "Scores"
          Shape: [1568]                                       Shape: [10]

                                                              +-----------+
                                                              | Score for 0 |
                                                              | Score for 1 |
[▓▓▓▓▓...▓▓▓▓▓]  --- Fully Connected Layer --->                |    ...    |
                                                              | Score for 8 | <-- 🏆 Highest Score
                                                              |    ...    |
                                                              | Score for 9 |
                                                              +-----------+

最终结论: 模型看到数字“8”对应的分数最高,因此预测这张图片就是 “8”

https://poloclub.github.io/cnn-explainer/