Linear Regression#

The core idea of linear regression is that, on the x and y axes, every input value x has a corresponding output value y; our task is to learn that model (informally: the equation of a straight line).

In linear regression, the data are modeled using linear predictor functions, and the unknown model parameters are estimated from the data. Such models are called linear models. Most commonly, linear regression models the conditional mean of y given X as an affine function of X. Less commonly, the conditional median, or some other quantile of the conditional distribution of y given X, is expressed as a linear function of X. Linear regression has many practical uses, which fall into the following two broad categories:

  1. If the goal is prediction or mapping, linear regression can be used to fit a predictive model to an observed dataset of y and X values. After fitting such a model, given a new value of X without its paired y, the fitted model can be used to predict the corresponding y (see the sketch after this list).

  2. Given a variable y and a set of variables \(X_1, \ldots, X_p\) that may be related to y, linear regression analysis can be used to quantify the strength of the relationship between y and each \(X_j\), to identify those \(X_j\) that carry no information about y, and to identify which subsets of the \(X_j\) contain redundant information about y.
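As a minimal illustration of the first use case, the best-fitting line for a set of (x, y) pairs can be found in closed form with ordinary least squares. The sketch below uses plain NumPy on made-up sample data; the variable names and the true parameters w=2, b=3 are chosen only for this example and are not part of the tutorials that follow.

import numpy as np

# Hypothetical example: recover the line y = 2x + 3 from noisy observations.
rng = np.random.default_rng(0)
x = rng.uniform(-10, 10, size=100)
y = 2.0 * x + 3.0 + rng.normal(0, 1, size=100)   # noisy samples of y = 2x + 3

# Ordinary least squares: fit y ≈ w*x + b by solving the least-squares problem.
A = np.stack([x, np.ones_like(x)], axis=1)       # design matrix [x, 1]
(w_hat, b_hat), *_ = np.linalg.lstsq(A, y, rcond=None)
print(w_hat, b_hat)                              # should be close to 2 and 3

# Predict y for a new x that has no paired label.
x_new = 5.0
print(w_hat * x_new + b_hat)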

The MindSpore example below makes the training process more intuitive by plotting the fitted line as training progresses.

from IPython import display
import time
from mindspore.train.callback import Callback
from mindspore import Model
from mindspore import Tensor
from mindspore.common.initializer import Normal
from mindspore import nn
from mindspore import dataset as ds
import matplotlib.pyplot as plt
import numpy as np
from mindspore import context

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
def get_data(num, w=2.0, b=3.0):
    # Yield `num` noisy samples of the target line y = w*x + b
    for _ in range(num):
        x = np.random.uniform(-10.0, 10.0)
        noise = np.random.normal(0, 1)
        y = x * w + b + noise
        yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)
eval_data = list(get_data(50))
x_target_label = np.arange(-10, 10, 0.1)
y_target_label = x_target_label * 2 + 3
x_eval_label,y_eval_label = zip(*eval_data)

plt.scatter(x_eval_label, y_eval_label, color="red", s=5)
plt.plot(x_target_label, y_target_label, color="green")
plt.title("Eval data")
plt.show()
def create_dataset(num_data, batch_size=16, repeat_size=1):
    input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
    input_data = input_data.batch(batch_size)
    input_data = input_data.repeat(repeat_size)
    return input_data
data_number = 1600
batch_number = 16
repeat_number = 1

ds_train = create_dataset(data_number, batch_size=batch_number, repeat_size=repeat_number)
print("The dataset size of ds_train:", ds_train.get_dataset_size())
dict_datasets = next(ds_train.create_dict_iterator())

print(dict_datasets.keys())
print("The x label value shape:", dict_datasets["data"].shape)
print("The y label value shape:", dict_datasets["label"].shape)
class LinearNet(nn.Cell):
    def __init__(self):
        super(LinearNet, self).__init__()
        # A single fully connected layer: 1 input -> 1 output,
        # with both weight and bias initialized from N(0, 0.02)
        self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))

    def construct(self, x):
        x = self.fc(x)
        return x
net = LinearNet()
model_params = net.trainable_params()
for param in model_params:
    print(param, param.asnumpy())
x_model_label = np.arange(-10, 10, 0.1)
y_model_label = (x_model_label * Tensor(model_params[0]).asnumpy()[0][0] +
                 Tensor(model_params[1]).asnumpy()[0])

plt.scatter(x_eval_label, y_eval_label, color="red", s=5)
plt.plot(x_model_label, y_model_label, color="blue")
plt.plot(x_target_label, y_target_label, color="green")
plt.show()
net = LinearNet()
net_loss = nn.loss.MSELoss()
opt = nn.Momentum(net.trainable_params(), learning_rate=0.005, momentum=0.9)
model = Model(net, net_loss, opt)

def plot_model_and_datasets(net, eval_data):
    weight = net.trainable_params()[0]
    bias = net.trainable_params()[1]
    x = np.arange(-10, 10, 0.1)
    y = x * Tensor(weight).asnumpy()[0][0] + Tensor(bias).asnumpy()[0]
    x1, y1 = zip(*eval_data)
    x_target = x
    y_target = x_target * 2 + 3

    plt.axis([-11, 11, -20, 25])
    plt.scatter(x1, y1, color="red", s=5)
    plt.plot(x, y, color="blue")
    plt.plot(x_target, y_target, color="green")
    plt.show()
    time.sleep(0.02)
class ImageShowCallback(Callback):
    def __init__(self, net, eval_data):
        self.net = net
        self.eval_data = eval_data

    def step_end(self, run_context):
        plot_model_and_datasets(self.net, self.eval_data)
        display.clear_output(wait=True)
epoch = 1
imageshow_cb = ImageShowCallback(net, eval_data)
model.train(epoch, ds_train, callbacks=[imageshow_cb], dataset_sink_mode=False)

plot_model_and_datasets(net, eval_data)
for param in net.trainable_params():
    print(param, param.asnumpy())
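After training, the learned weight and bias should be close to the target values w=2 and b=3. As a quick sanity check (a sketch, assuming the cells above have been run; the input value 5.0 is arbitrary), the trained network can be called directly on a new input:

# Predict y for a new x with the trained net; expect roughly 2 * 5 + 3 = 13.
x_new = Tensor(np.array([[5.0]], dtype=np.float32))
print(net(x_new))

The rest of this page re-implements linear regression from scratch in PyTorch, following d2l.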
%matplotlib inline
import random
import torch
from d2l import torch as d2l

Generating the Dataset#

In the code below, we generate a dataset containing 1000 examples, each with 2 features drawn from a standard normal distribution. Our synthetic dataset is thus a matrix \(\mathbf{X}\in \mathbb{R}^{1000 \times 2}\).

We use the linear model parameters \(\mathbf{w} = [2, -3.4]^\top\), \(b = 4.2\), and a noise term \(\epsilon\) to generate the dataset and its labels:

\[\mathbf{y}= \mathbf{X} \mathbf{w} + b + \boldsymbol{\epsilon}.\]

\(\epsilon\) can be viewed as the potential observation error in the model's predictions and labels. Here we take the standard assumption to hold, namely that \(\epsilon\) follows a normal distribution with mean 0. To simplify the problem, we set its standard deviation to 0.01.

def synthetic_data(w, b, num_examples):  #@save
    """Generate y = Xw + b + noise"""
    # mean (Tensor) – the mean of the normal distribution
    # std (Tensor) – the standard deviation https://zh.wikihow.com/%E8%AE%A1%E7%AE%97%E6%A0%87%E5%87%86%E5%B7%AE
    # out (Tensor) – optional output tensor
    X = torch.normal(0, 1, (num_examples, len(w)))
    print(X)
    # torch.matmul computes the matrix product of two tensors
    y = torch.matmul(X, w) + b
    #print(y)
    y += torch.normal(0, 0.01, y.shape)
    #print(y)
    # tensor.shape vs tensor.size()
    # -1 lets reshape infer that dimension from the total number of elements
    return X, y.reshape((-1, 1))

# create the tensor of true parameters
true_w = torch.tensor([2, -3.4])
print(true_w)
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
tensor([ 2.0000, -3.4000])
tensor([[-1.4401,  0.7539],
        [ 0.5159,  0.0320],
        [-0.9519,  0.5600],
        ...,
        [-0.3998, -1.1616],
        [ 0.9922, -0.3611],
        [-1.9691, -1.0929]])

Each row of features contains a two-dimensional data example, and each row of labels contains a one-dimensional label value (a scalar).

print('features:', features[0],'\nlabel:', labels[0])
print(features)
features: tensor([-1.4401,  0.7539]) 
label: tensor([-1.2494])
tensor([[-1.4401,  0.7539],
        [ 0.5159,  0.0320],
        [-0.9519,  0.5600],
        ...,
        [-0.3998, -1.1616],
        [ 0.9922, -0.3611],
        [-1.9691, -1.0929]])

By plotting the second feature features[:, 1] against labels in a scatter plot, we can visually observe the linear relationship between the two.

d2l.set_figsize()
d2l.plt.scatter(features[:, (1)].detach().numpy(), labels.detach().numpy(), 1);
[Figure: scatter plot of features[:, 1] versus labels]

Reading the Dataset#

We define a data_iter function that takes the batch size, the feature matrix, and the label vector as inputs and yields minibatches of size batch_size. Each minibatch consists of a group of features and labels.

Before using the code below, first read up on how yield works; see https://blog.csdn.net/mieleizhi0522/article/details/82142856/ for a reference.
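As a minimal reminder (a toy example, unrelated to the dataset), a generator function pauses at each yield and resumes from that point on the next iteration:

def count_up_to(n):
    i = 0
    while i < n:
        yield i      # hand i back to the caller and pause here
        i += 1       # resume from this line on the next iteration

for value in count_up_to(3):
    print(value)     # prints 0, 1, 2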

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    #print(num_examples)
    indices = list(range(num_examples))
    #print(indices)
    # The examples are read in random order, with no particular sequence, so shuffle the indices
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):  # step through 0, batch_size, 2*batch_size, ... up to num_examples
        batch_indices = indices[i: min(i + batch_size, num_examples)]  # the next slice of at most batch_size shuffled indices
        #print(batch_indices)
        #https://blog.csdn.net/mieleizhi0522/article/details/82142856/
        yield features[batch_indices], labels[batch_indices]  # each iteration resumes after the previous yield; the shuffled indices select the examples

Typically, we take advantage of the GPU's ability to parallelize operations and process minibatches of reasonable size. Each example can be passed through the model in parallel, and the gradient of each example's loss can also be computed in parallel, so a GPU can process several hundred examples in scarcely more time than it takes to process one.

To get a feel for minibatch computation, let's read and print the first small batch of data examples. The shape of the features in each batch tells us both the batch size and the number of input features; likewise, the batch of labels has batch_size rows, one label per example.

batch_size = 10

for X, y in data_iter(batch_size, features, labels):
    print(X, '\n', y)
    break
tensor([[ 2.2541,  0.6574],
        [-0.9650,  2.5686],
        [-1.9461, -0.1835],
        [-2.0508,  0.0197],
        [-1.0988,  1.5771],
        [-0.2417, -1.4114],
        [ 0.8945, -2.3486],
        [ 0.4615, -0.3796],
        [ 0.6798, -0.3848],
        [-0.2723, -0.4583]]) 
 tensor([[ 6.4765],
        [-6.4602],
        [ 0.9265],
        [ 0.0422],
        [-3.3593],
        [ 8.4987],
        [13.9808],
        [ 6.4193],
        [ 6.8528],
        [ 5.2233]])

Initializing Model Parameters#

In the code below, we initialize the weights by sampling random numbers from a normal distribution with mean 0 and standard deviation 0.01, and we initialize the bias to 0.

w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)

Defining the Model#

We simply compute the matrix-vector product of the input features \(\mathbf{X}\) and the model weights \(\mathbf{w}\), and add the bias \(b\). Note that \(\mathbf{Xw}\) is a vector while \(b\) is a scalar. Recall the broadcasting mechanism: when we add a scalar to a vector, the scalar is added to every component of the vector.

def linreg(X, w, b):  #@save
    """线性回归模型"""
    return torch.matmul(X, w) + b
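A quick illustration of the broadcasting rule mentioned above (a standalone toy example, independent of the model code):

# Adding a scalar to a vector adds it to every component.
v = torch.tensor([1.0, 2.0, 3.0])
print(v + 4.2)   # tensor([5.2000, 6.2000, 7.2000])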

Defining the Loss Function#

Because we need to compute the gradient of the loss function, we should define the loss function first. In the implementation, we need to reshape the true values y to have the same shape as the predictions y_hat.

def squared_loss(y_hat, y):  #@save
    """Squared loss"""
    #print(y_hat.shape)
    #print(y_hat)
    # reshape the true values y to the same shape as the predictions y_hat
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
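The reshape matters because of broadcasting: if y_hat has shape (n, 1) while y has shape (n,), the subtraction silently broadcasts to an (n, n) matrix instead of an elementwise difference. A small sketch of the pitfall:

y_hat = torch.ones(3, 1)
y = torch.zeros(3)
print((y_hat - y).shape)                        # torch.Size([3, 3]) -- not what we want
print((y_hat - y.reshape(y_hat.shape)).shape)   # torch.Size([3, 1]) -- elementwise, as intended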

Defining the Optimization Algorithm#

The function below implements the minibatch stochastic gradient descent update. It takes the set of model parameters, the learning rate, and the batch size as inputs. The size of each update step is determined by the learning rate lr. Because the loss we compute is a sum over a batch of examples, we normalize the step size by the batch size (batch_size), so that the magnitude of a step does not depend on our choice of batch size.

def sgd(params, lr, batch_size):  #@save
    """Minibatch stochastic gradient descent"""
    # The `with` statement is used when accessing a resource: it guarantees the necessary
    # "clean-up" runs whether or not an exception occurs, e.g. closing a file or releasing a lock.
    # no_grad disables gradient tracking during the parameter update
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            # the gradients must be cleared, otherwise they accumulate
            param.grad.zero_()
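For comparison, the same update can be delegated to PyTorch's built-in optimizer. This is only a sketch of an equivalent design choice, not part of the implementation below: torch.optim.SGD does not divide by the batch size, so the loss is averaged with .mean() instead of summed.

# Hypothetical alternative using the built-in optimizer (assumes w, b, linreg, squared_loss,
# data_iter, features and labels from the cells above).
optimizer = torch.optim.SGD([w, b], lr=0.03)
for X, y in data_iter(10, features, labels):
    optimizer.zero_grad()
    l = squared_loss(linreg(X, w, b), y).mean()  # mean instead of sum
    l.backward()
    optimizer.step()
    break  # a single step, just for illustration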

Training#

Now that we have all of the pieces needed for training, we can implement the main training loop. Understanding this code is crucial, because you will see nearly identical training loops over and over again in deep learning. In each iteration, we read a minibatch of training examples and pass them through the model to obtain a set of predictions. After computing the loss, we run backpropagation, storing the gradient of each parameter. Finally, we call the optimization algorithm sgd to update the model parameters.

To summarize, we will execute the following loop:

  • Initialize the parameters

  • Repeat until done:

    • Compute the gradient \(\mathbf{g} \leftarrow \partial_{(\mathbf{w},b)} \frac{1}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} l(\mathbf{x}^{(i)}, y^{(i)}, \mathbf{w}, b)\)

    • Update the parameters \((\mathbf{w}, b) \leftarrow (\mathbf{w}, b) - \eta \mathbf{g}\)

In each epoch, we use the data_iter function to iterate over the entire dataset, using every example in the training set exactly once (assuming the number of examples is divisible by the batch size). The number of epochs num_epochs and the learning rate lr are both hyperparameters, set here to 3 and 0.03 respectively. Setting hyperparameters is tricky and requires tuning by trial and error.

lr = 0.03
num_epochs = 3
net = linreg
loss = squared_loss

for epoch in range(num_epochs):

    for X, y in data_iter(batch_size, features, labels):
        # print(X)
        # print(y)
        # print(X.shape)
        # print(y.shape)
        l = loss(net(X, w, b), y)  # minibatch loss for X and y
        # Because l has shape (batch_size, 1) rather than being a scalar, we sum its
        # elements and compute the gradient with respect to [w, b] from that sum
        c = l.sum()  # printing this value shows the loss gradually decreasing
        print(c)
        c.backward()
        sgd([w, b], lr, batch_size)  # update the parameters using their gradients; the learning rate lr is 0.03
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')
tensor(201.6803, grad_fn=<SumBackward0>)
tensor(101.2996, grad_fn=<SumBackward0>)
tensor(256.3406, grad_fn=<SumBackward0>)
... (per-batch loss values decreasing, output truncated) ...
tensor(0.5312, grad_fn=<SumBackward0>)
epoch 1, loss 0.048707
... (output truncated) ...
tensor(0.0032, grad_fn=<SumBackward0>)
epoch 2, loss 0.000238
... (output truncated) ...
tensor(0.0003, grad_fn=<SumBackward0>)
epoch 3, loss 0.000050

Because we synthesized the dataset ourselves, we know exactly what the true parameters are. We can therefore evaluate how successful training was by comparing the true parameters with those learned through training. Indeed, the learned parameters turn out to be very close to the true ones.

print(f'Error in estimating w: {true_w - w.reshape(true_w.shape)}')
print(f'Error in estimating b: {true_b - b}')
Error in estimating w: tensor([ 0.0001, -0.0012], grad_fn=<SubBackward0>)
Error in estimating b: tensor([0.0005], grad_fn=<RsubBackward1>)
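Since the problem is linear, we can also cross-check the learned parameters against the closed-form least-squares solution computed directly from features and labels. This is only a verification sketch (it assumes a PyTorch version that provides torch.linalg.lstsq), not part of the original code:

# Append a column of ones so the bias is estimated together with the weights.
A = torch.cat([features, torch.ones(len(features), 1)], dim=1)
theta = torch.linalg.lstsq(A, labels).solution   # shape (3, 1): [w1, w2, b]
print(theta.flatten())                           # should be close to [2, -3.4, 4.2]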

Complete Code#

import random
import torch
from d2l import torch as d2l
def synthetic_data(w, b, num_examples):  #@save
    """Generate y = Xw + b + noise"""
    X = torch.normal(0, 1, (num_examples, len(w)))
    y = torch.matmul(X, w) + b
    y += torch.normal(0, 0.01, y.shape)
    return X, y.reshape((-1, 1))

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)  # generate the features and labels
def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    # The examples are read in random order, with no particular sequence
    random.shuffle(indices)
    print(indices)
    for i in range(0, num_examples, batch_size):  # step through 0, batch_size, 2*batch_size, ... up to num_examples
        batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])  # the next slice of at most batch_size shuffled indices
        #print(batch_indices)
        yield features[batch_indices], labels[batch_indices]  # each iteration resumes after the previous yield; the shuffled indices select the examples


w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
def linreg(X, w, b):  #@save
    """The linear regression model"""
    return torch.matmul(X, w) + b
def squared_loss(y_hat, y):  #@save
    """Squared loss"""
    # compare the predictions against the true labels
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
def sgd(params, lr, batch_size):  #@save
    """Minibatch stochastic gradient descent"""
    with torch.no_grad():
        for param in params:  # update w first, then b
            param -= lr * param.grad / batch_size
            param.grad.zero_()


for epoch in range(3):
    for X, y in data_iter(10, features, labels):
        # print(X)
        # print(y)
        # print(X.shape)
        # print(y.shape)
        l = squared_loss(linreg(X, w, b), y)  # minibatch loss for X and y
        # Because l has shape (batch_size, 1) rather than being a scalar, we sum its
        # elements and compute the gradient with respect to [w, b] from that sum
        c = l.sum()  # printing this value shows the loss gradually decreasing
        print(c)
        c.backward()
        sgd([w, b], 0.03, 10)  # update the parameters using their gradients; the learning rate lr is 0.03
    with torch.no_grad():
        train_l = squared_loss(linreg(features, w, b), labels)
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')
print(f'Error in estimating w: {true_w - w.reshape(true_w.shape)}')
print(f'Error in estimating b: {true_b - b}')
[934, 827, 883, 678, 487, 139, 288, 564, 730, 310, ...]  (the 1000 shuffled indices printed by data_iter, truncated)
tensor(232.9591, grad_fn=<SumBackward0>)
tensor(159.5916, grad_fn=<SumBackward0>)
... (per-batch loss values decreasing, output truncated) ...
tensor(0.4928, grad_fn=<SumBackward0>)
epoch 1, loss 0.034670
[62, 952, 512, 227, 221, 243, 909, 102, 298, 982, ...]  (indices reshuffled for epoch 2, truncated)
... (output truncated) ...
tensor(0.0009, grad_fn=<SumBackward0>)
epoch 2, loss 0.000126
[561, 81, 126, 675, 494, 417, 333, 59, 526, 910, ...]  (indices reshuffled for epoch 3, truncated)
... (output truncated) ...
tensor(0.0007, grad_fn=<SumBackward0>)
epoch 3, loss 0.000049
Error in estimating w: tensor([-0.0002, -0.0001], grad_fn=<SubBackward0>)
Error in estimating b: tensor([0.0012], grad_fn=<RsubBackward1>)