7-Distill

7-Distill#

模型蒸馏 (Knowledge Distillation, KD) 是一种机器学习模型压缩方法,它用于将大型模型(教师模型)的知识迁移到较小的模型(学生模型)中.

KD 背后的核心思想是将教师模型的综合知识转化为更精简、更有效的表示. 学生模型是一个较小的模型,目标是学习教师模型的行为,而不是直接从原始数据中学习.

大模型的 KD 有白盒蒸馏与黑盒蒸馏两个派别,对于本次实验代码中两个模型均为 MiniMind 开源模型,支持对教师模型内部结构的访问,因此在训练过程中,我们能够获取教师模型的 softmax 概率分布并用作软标签(soft labels),让小模型学习软标签,并使用 KL-Loss 来优化模型的参数,而不是直接学习输出 Token 的硬标签. 对于下一章蒸馏推理模型中,由于我们面向推理数据集进行蒸馏,并不存在输出 Token 的概率分布让我们学习,这种面向输出数据学习的蒸馏方式被称为黑盒蒸馏.

此笔记本的完整实现见主仓库 ../raw/data/Minimind/train_distillation.py

# -------------------- 导入标准和第三方库 --------------------
# 导入 os 库,用于与操作系统交互
import os
# 导入 argparse 库,用于解析命令行参数
import argparse
# 导入 time 库,用于计时
import time
# 导入 math 库,用于数学运算
import math
# 导入 warnings 库,用于控制警告信息
import warnings

# 导入 pandas 库,用于数据分析
import pandas as pd
# 导入 PyTorch 核心库
import torch
# 导入 PyTorch 的 functional 模块,包含常用函数如 softmax, log_softmax, kl_div
import torch.nn.functional as F
# 导入 PyTorch 的分布式训练库
import torch.distributed as dist
# 导入一个上下文管理器
from contextlib import nullcontext

# 导入 PyTorch 的优化器和神经网络模块
from torch import optim, nn
# 导入 PyTorch 的分布式数据并行工具
from torch.nn.parallel import DistributedDataParallel
# 导入 PyTorch 的数据加载器和分布式采样器
from torch.utils.data import DataLoader, DistributedSampler
# -------------------- 导入 Hugging Face 和自定义模块 --------------------
# 从 transformers 库导入 AutoTokenizer 和 AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM
# 从我们自己的 model 文件夹中导入之前编写好的模块
from model.model import MiniMindLM # 我们的语言模型
from model.LMConfig import LMConfig # 模型配置类
# 蒸馏过程通常使用与 SFT 相同格式的数据,所以我们复用 SFTDataset
from model.dataset import SFTDataset 

# 设置警告过滤器,让程序忽略所有的警告信息,使输出更整洁
warnings.filterwarnings('ignore')

可选参数设置#

首先,查看训练的可选参数,这些参数在实际使用时通过解析命令行进行导入,我们用 class 进行包装.

class args:
    # out_dir: str = "out" # pytorch 格式权重文件保存位置 我们只展示训练过程 所以不使用
    epochs: int = 1 # 训练轮数
    batch_size: int = 2 # pretrain 数据集仅两个样本,设置 batch 为 2
    learning_rate: float = 5e-4 # 学习率
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    dtype: str = 'bfloat16' # 16 bit 浮点数:8 bit 指数 + 7 bit 尾数
    # use_wandb: bool = False # 是否使用 wandb 我们不使用
    wandb_project: str = 'MiniMind-Notebook'
    num_workers: int = 1 # 工作进程数
    # ddp:bool = False # 单机多卡
    accumulation_steps: int = 1 # 梯度累积步数
    grad_clip: float = 1.0 # 梯度剪裁
    warmup_iters: int = 0 # 学习率热启动
    log_interval: int = 1 # 每一步打印日志 仅用于观察
    # save_interval: int = 100 # checkpoint 保存点 我们不使用
    local_rank: int = 1 # device 设备号
    dim: int = 512 # 词嵌入维度 模型超参数
    n_layers: int = 2 # MiniMind Block 数量 模型超参数
    max_seq_len: int = 512 # 序列长度阈值
    use_moe: bool = False # 是否启用混合专家
    data_path: str = '../raw/data/Minimind/toydata/sft_data.jsonl' # 数据集路径
# 打印出 `args` 中设置的 `device`,确认程序将在哪个设备上运行(CPU 或 GPU)
print(f'查看工作设备 {args.device}')
查看工作设备 cuda

接下来,我们对分词器、MiniMind 教师/学生模型以及数据迭代器执行初始化.

# 定义一个函数来初始化学生模型 (Student Model) 和分词器
def init_student_model(lm_config):
    # 1. 加载分词器
    tokenizer = AutoTokenizer.from_pretrained('../raw/data/Minimind/model/minimind_tokenizer')
    
    # 2. 根据传入的 `lm_config` (学生模型的配置) 创建模型
    model = MiniMindLM(lm_config)
    
    # 在真实流程中,学生模型可以从一个预训练好的较小模型开始,这里我们假设它是随机初始化的
    moe_path = '_moe' if lm_config.use_moe else ''
    # ckp = f'./out/full_sft_{lm_config.dim}{moe_path}.pth'
    # state_dict = torch.load(ckp, map_location=args.device)
    # model.load_state_dict(state_dict, strict=False)

    # 3. 打印学生模型的参数量
    print(f'学生模型(LLM)总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
    
    # 4. 将学生模型移动到指定设备并返回
    model = model.to(args.device)
    return model, tokenizer


# 定义一个函数来初始化教师模型 (Teacher Model)
def init_teacher_model(lm_config):
    # 1. 根据传入的 `lm_config` (教师模型的配置) 创建模型
    model = MiniMindLM(lm_config)
    
    # 教师模型必须是一个已经训练好的、性能强大的模型
    moe_path = '_moe' if lm_config.use_moe else ''
    # ckp = f'./out/full_sft_{lm_config.dim}{moe_path}.pth'
    # state_dict = torch.load(ckp, map_location=args.device)
    # model.load_state_dict(state_dict, strict=False)
    
    # 2. 打印教师模型的参数量
    print(f'教师模型(LLM)总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
    
    # 3. 将教师模型移动到指定设备并返回
    model = model.to(args.device)
    return model
# --- 1. 创建不同的模型配置 ---
# 创建学生模型的配置:较小(维度512,层数1)
lm_config_student = LMConfig(dim=512, n_layers=1, max_seq_len=512)
# 创建教师模型的配置:较大(维度768,层数2)
lm_config_teacher = LMConfig(dim=768, n_layers=2, max_seq_len=512)

# --- 2. 初始化模型 ---
# 初始化学生模型和分词器
model, tokenizer = init_student_model(lm_config_student)
# 初始化教师模型
teacher_model = init_teacher_model(lm_config_teacher)

# --- 3. 初始化数据集 ---
# 蒸馏过程使用的数据格式与 SFT 相同,所以复用 SFTDataset
train_ds = SFTDataset(args.data_path, tokenizer, max_length=lm_config_student.max_seq_len)

# --- 4. 初始化数据加载器 (DataLoader) ---
train_loader = DataLoader(
    train_ds,
    batch_size=args.batch_size,
    pin_memory=True,
    drop_last=False,
    shuffle=False,
    num_workers=args.num_workers,
)

# --- 5. 打印确认信息 ---
print(f'模型位于设备:{model.device}, 词表长度:{tokenizer.vocab_size}, DataLoader:{train_loader}')
学生模型(LLM)总参数量:6.096 百万
教师模型(LLM)总参数量:17.305 百万
模型位于设备:cuda:0, 词表长度:6400, DataLoader:<torch.utils.data.dataloader.DataLoader object at 0x0000025BC85F06A0>

启动训练#

接下来,我们定义 MiniMind LoRA 微调所使用的优化器,损失函数和学习率调度,并进行一轮简单的训练.

# 这部分代码与之前的训练阶段基本相同

# 1. 定义学习率调度函数 (余弦退火)
def get_lr(current_step, total_steps, lr):
    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))

# 2. 设置混合精度训练的梯度缩放器 (Scaler)
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))

# 3. 初始化优化器 (AdamW)
#    优化器只接收学生模型 `model.parameters()` 的参数,因为我们只训练学生模型
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

# 4. 设置自动混合精度上下文 (Autocast)
device_type = "cuda" if "cuda" in args.device else "cpu"
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()

损失函数方面,使用 KL Loss 方法.

KL Loss 中,损失是 KL 散度,衡量学生模型和教师模型在面对相同输入时,在输出层产生的分类 logits 分布之间的距离. 直观理解上,就是让学生模型的输出尽量向教师模型的输出概率靠近.

\[D_{KL}(P||Q)=\sum_i P(i)\log\frac{P(i)}{Q(i)}\]

其中,\(P(i)\) 代表教师模型的概率分布,\(Q(i)\) 代表学生模型的预测分布.

# 定义知识蒸馏的核心损失函数:KL 散度损失
def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
    # `temperature` (温度) 是一个超参数,用于平滑概率分布。
    #  - T > 1 会使概率分布更平滑,让学生模型学习到更多关于类别间关系的“暗知识”。
    #  - T = 1 就是标准的 softmax。
    
    # 1. 计算教师模型的“软标签” (soft labels)
    #    `with torch.no_grad()`: 教师模型的计算不需要梯度
    with torch.no_grad():
        # a. `teacher_logits / temperature`: 用温度平滑 logits
        # b. `F.softmax(..., dim=-1)`: 计算平滑后的概率分布
        # c. `.detach()`: 显式地将结果从计算图中分离,确保没有梯度流过
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1).detach()

    # 2. 计算学生模型的对数概率分布
    #    同样用温度进行平滑,然后计算 log_softmax
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)

    # 3. 计算 KL 散度
    #    `F.kl_div` 是 PyTorch 内置的 KL 散度计算函数
    #    它衡量 `student_log_probs` 和 `teacher_probs` 这两个分布的差异
    kl = F.kl_div(
        student_log_probs,
        teacher_probs,
        reduction=reduction  # `batchmean` 表示损失会在 batch 和 N 维度上求平均
    )
    
    # 4. 乘以 `temperature ** 2`
    #    这是为了在反向传播时,将梯度恢复到与 T=1 时相近的尺度,保持训练稳定
    return (temperature ** 2) * kl

接下来,我们来看训练函数.

# 定义蒸馏训练的 epoch 函数
# `alpha`: 一个超参数,用于平衡 CE loss 和 Distill loss 的权重
# `temperature`: 蒸馏温度
def train_epoch(epoch, alpha=0.0, temperature=1.0):
    start_time = time.time()
    
    # 确保教师模型处于评估模式且不计算梯度
    if teacher_model is not None:
        teacher_model.eval()
        teacher_model.requires_grad_(False)

    for step, (X, Y, loss_mask) in enumerate(train_loader):
        # ... (数据移动到设备,学习率更新,与之前相同)
        X, Y, loss_mask = X.to(args.device), Y.to(args.device), loss_mask.to(args.device)
        # ... (lr update)

        # --- 1. 学生模型前向传播 ---
        with ctx:
            res = model(X)
            student_logits = res.logits

        # --- 2. 教师模型前向传播 ---
        if teacher_model is not None:
            with torch.no_grad():
                teacher_logits = teacher_model(X).logits
                # 如果教师和学生的词汇表大小不同,需要对齐
                vocab_size_student = student_logits.size(-1)
                teacher_logits = teacher_logits[..., :vocab_size_student]

        # --- 3. 计算损失 ---
        # a) 计算传统的交叉熵损失 (CE Loss),让学生模型学习真实标签
        loss_mask_flat = loss_mask.view(-1)
        ce_loss = F.cross_entropy(student_logits.view(-1, student_logits.size(-1)), Y.view(-1), reduction='none')
        ce_loss = torch.sum(ce_loss * loss_mask_flat) / loss_mask_flat.sum()
        if lm_config_student.use_moe:
            ce_loss += res.aux_loss # 如果学生是 MoE 模型,加上辅助损失

        # b) 计算蒸馏损失 (Distillation Loss),让学生模型学习教师模型的输出分布
        if teacher_model is not None:
            # `loss_mask_flat == 1` 用于筛选出需要计算损失的 token
            distill_loss = distillation_loss_fn(
                student_logits.view(-1, student_logits.size(-1))[loss_mask_flat == 1],
                teacher_logits.view(-1, teacher_logits.size(-1))[loss_mask_flat == 1],
                temperature=temperature
            )
        else:
            distill_loss = torch.tensor(0.0, device=args.device)

        # c) 计算总损失:将 CE loss 和 Distill loss 加权求和
        #    alpha=0.0 表示只使用蒸馏损失 (纯蒸馏)
        #    alpha=1.0 表示只使用 CE 损失 (等同于 SFT)
        #    0 < alpha < 1 表示混合两种损失
        loss = alpha * ce_loss + (1 - alpha) * distill_loss

        # --- 4. 反向传播和权重更新 (与之前相同) ---
        scaler.scale(loss).backward()
        if (step + 1) % args.accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        # --- 5. 打印日志 (与之前相同) ---
        if step % args.log_interval == 0:
            spend_time = time.time() - start_time
            print(
                'Epoch:[{}/{}]({}/{}) loss:{:.4f} lr:{:.12f} epoch_Time:{}min:'.format(
                    epoch,
                    args.epochs - 1,
                    step,
                    iter_per_epoch,
                    loss.item(),
                    optimizer.param_groups[-1]['lr'],
                    spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
                )
            )

接下来,我们启动一个 Epoch 的训练进行观察.

# 计算每个 epoch 的迭代次数
iter_per_epoch = len(train_loader) # 结果是 1

# 开始主训练循环
for epoch in range(args.epochs):
    # 调用 `train_epoch` 函数,开始蒸馏训练
    # 默认 alpha=0.0, temperature=1.0,表示进行纯蒸馏
    train_epoch(epoch)
Epoch:[0/0](0/1) loss:0.3252 lr:0.000550000000 epoch_Time:0.0min:
# 蒸馏训练演示结束后,同时删除学生模型和教师模型,释放 GPU 显存
del model, teacher_model