多层感知机的简洁实现

多层感知机的简洁实现#

import torch
from torch import nn
from d2l import torch as d2l

#下载模型使用
import os
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms,datasets
import matplotlib.pyplot as plt

我们添加了2个全连接层(之前我们只添加了1个全连接层)。 第一层是隐藏层,它包含256个隐藏单元,并使用了ReLU激活函数。 第二层是输出层。

net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 256),
                    nn.ReLU(),
                    nn.Linear(256, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);

训练过程的实现与我们实现softmax回归时完全相同, 这种模块化设计使我们能够将与模型架构有关的内容独立出来。

batch_size, lr, num_epochs = 256, 0.1, 10
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=lr)

还是有自己的方式下载文件,这样可以指定位置

image_size = 28
data_transform = transforms.Compose([
    #transforms.ToPILImage(),   # 将torch.Tensor或numpy.ndarray类型图像转为PIL.Image类型图像。这段里面可以移除transforms.ToPILImage(),因为 FashionMNIST 数据集已经是 PIL.Image 类型
    transforms.Resize(image_size),#按给定尺寸对图像进行缩放
    transforms.ToTensor() #将PIL.Image或numpy.ndarray类型图像转为torch.Tensor类型图像
])
# train表示是否是训练集,download表示是否需要下载,transform表示是否需要进行数据变换
train_data = datasets.FashionMNIST(root='../raw/data/', train=True, download=True, transform=data_transform)
test_data = datasets.FashionMNIST(root='../raw/data/', train=False, download=True, transform=data_transform)
batch_size = 256
num_workers = 0  #mac 不知道为什么变为4也报错   # 对于Windows用户,这里应设置为0,否则会出现多线程错误
# DataLoader是一个用于生成batch数据的迭代器,可以设置batch_size、shuffle、num_workers等参数
#batch_size是指每个批次中包含的样本数量。shuffle=True表示在每个epoch开始时,将训练数据集打乱顺序,以增加模型的泛化能力。num_workers是指用于数据加载的线程数量,可以加快数据加载的速度。drop_last=True表示如果训练数据集的样本数量不能被batch_size整除,最后一个不完整的批次将被丢弃。
train_iter = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
test_iter = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)

进行训练

d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
../_images/3f013c57fd734a108cef195c0fbb910c26652ffd16320abff11b1aabf2631309.svg