Linear Regression#

At its core, linear regression says that for a given value x there is a corresponding value y on the x–y plane. Our goal is to learn this model (informally: the equation of a straight line).
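For example, with the line \(y = 2x + 3\) (the same target line used in the code below), an input of \(x = 5\) predicts \(y = 2 \times 5 + 3 = 13\); training means recovering the slope 2 and the intercept 3 from noisy samples.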

In linear regression, the data are modeled with linear predictor functions, and the unknown model parameters are also estimated from the data. Such models are called linear models. The most common form of linear regression models the conditional mean of y given X as an affine function of X. Less commonly, the model may instead represent the conditional median, or some other quantile of the conditional distribution of y given X, as a linear function of X. Linear regression has many practical uses, which fall into the following two broad categories:

  1. If the goal is prediction or forecasting, linear regression can be used to fit a predictive model to an observed dataset of y and X values. Once such a model is fitted, given a new value of X without its paired y, the fitted model can be used to predict the corresponding y.

  2. Given a variable y and a set of variables \(X_1, \ldots, X_p\) that may be related to y, linear regression analysis can be used to quantify the strength of the relationship between y and the \(X_j\), to assess which \(X_j\) are unrelated to y, and to identify which subsets of the \(X_j\) contain redundant information about y.

The training process is visualized more intuitively by the figures produced below.

from IPython import display
import time
from mindspore.train.callback import Callback
from mindspore import Model
from mindspore import Tensor
from mindspore.common.initializer import Normal
from mindspore import nn
from mindspore import dataset as ds
import matplotlib.pyplot as plt
import numpy as np
from mindspore import context

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
def get_data(num, w=2.0, b=3.0):
    # Generate `num` noisy samples from the target line y = w * x + b
    for _ in range(num):
        x = np.random.uniform(-10.0, 10.0)
        noise = np.random.normal(0, 1)
        y = x * w + b + noise
        yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)
eval_data = list(get_data(50))
x_target_label = np.arange(-10, 10, 0.1)
y_target_label = x_target_label * 2 + 3
x_eval_label,y_eval_label = zip(*eval_data)

plt.scatter(x_eval_label, y_eval_label, color="red", s=5)
plt.plot(x_target_label, y_target_label, color="green")
plt.title("Eval data")
plt.show()
def create_dataset(num_data, batch_size=16, repeat_size=1):
    input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
    input_data = input_data.batch(batch_size)
    input_data = input_data.repeat(repeat_size)
    return input_data
data_number = 1600
batch_number = 16
repeat_number = 1

ds_train = create_dataset(data_number, batch_size=batch_number, repeat_size=repeat_number)
print("The dataset size of ds_train:", ds_train.get_dataset_size())
dict_datasets = next(ds_train.create_dict_iterator())

print(dict_datasets.keys())
print("The x label value shape:", dict_datasets["data"].shape)
print("The y label value shape:", dict_datasets["label"].shape)
class LinearNet(nn.Cell):
    def __init__(self):
        super(LinearNet, self).__init__()
        # A single fully connected layer: 1 input feature -> 1 output,
        # with weight and bias both initialized from N(0, 0.02)
        self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))

    def construct(self, x):
        x = self.fc(x)
        return x
net = LinearNet()
model_params = net.trainable_params()
for param in model_params:
    print(param, param.asnumpy())
x_model_label = np.arange(-10, 10, 0.1)
y_model_label = (x_model_label * Tensor(model_params[0]).asnumpy()[0][0] +
                 Tensor(model_params[1]).asnumpy()[0])

plt.scatter(x_eval_label, y_eval_label, color="red", s=5)
plt.plot(x_model_label, y_model_label, color="blue")
plt.plot(x_target_label, y_target_label, color="green")
plt.show()
net = LinearNet()
net_loss = nn.loss.MSELoss()
opt = nn.Momentum(net.trainable_params(), learning_rate=0.005, momentum=0.9)
model = Model(net, net_loss, opt)

def plot_model_and_datasets(net, eval_data):
    weight = net.trainable_params()[0]
    bias = net.trainable_params()[1]
    x = np.arange(-10, 10, 0.1)
    y = x * Tensor(weight).asnumpy()[0][0] + Tensor(bias).asnumpy()[0]
    x1, y1 = zip(*eval_data)
    x_target = x
    y_target = x_target * 2 + 3

    plt.axis([-11, 11, -20, 25])
    plt.scatter(x1, y1, color="red", s=5)
    plt.plot(x, y, color="blue")
    plt.plot(x_target, y_target, color="green")
    plt.show()
    time.sleep(0.02)
class ImageShowCallback(Callback):
    def __init__(self, net, eval_data):
        self.net = net
        self.eval_data = eval_data

    def step_end(self, run_context):
        plot_model_and_datasets(self.net, self.eval_data)
        display.clear_output(wait=True)
epoch = 1
imageshow_cb = ImageShowCallback(net, eval_data)
model.train(epoch, ds_train, callbacks=[imageshow_cb], dataset_sink_mode=False)

plot_model_and_datasets(net, eval_data)
for param in net.trainable_params():
    print(param, param.asnumpy())
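
With training finished, a quick way to check the fit is to run the evaluation samples through the network and compute the mean squared error by hand. The following is a minimal sketch using the net and eval_data defined above; it is not part of the original tutorial code.

# Minimal sketch (not from the original tutorial): evaluate the fitted line on eval_data by hand.
x_eval = np.array([x for x, _ in eval_data], dtype=np.float32)   # shape (50, 1)
y_eval = np.array([y for _, y in eval_data], dtype=np.float32)   # shape (50, 1)
pred = net(Tensor(x_eval)).asnumpy()
print("Eval MSE:", np.mean((pred - y_eval) ** 2))
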
%matplotlib inline
import random
import torch
from d2l import torch as d2l

Generating the Dataset#

In the code below, we generate a dataset containing 1000 samples, where each sample has 2 features drawn from a standard normal distribution. Our synthetic dataset is a matrix \(\mathbf{X}\in \mathbb{R}^{1000 \times 2}\).

We use the linear model parameters \(\mathbf{w} = [2, -3.4]^\top\), \(b = 4.2\), and a noise term \(\epsilon\) to generate the dataset and its labels:

\[\mathbf{y}= \mathbf{X} \mathbf{w} + b + \boldsymbol{\epsilon}.\]

\(\epsilon\) can be viewed as the potential observation error in the model's predictions and labels. Here we assume the standard assumptions hold, i.e. \(\epsilon\) follows a normal distribution with mean 0. To keep the problem simple, we set its standard deviation to 0.01.

def synthetic_data(w, b, num_examples):  #@save
    """Generate y = Xw + b + noise"""
    # torch.normal arguments: mean, standard deviation, and the output shape
    # (on standard deviation, see https://zh.wikihow.com/%E8%AE%A1%E7%AE%97%E6%A0%87%E5%87%86%E5%B7%AE)
    X = torch.normal(0, 1, (num_examples, len(w)))
    print(X)
    # Matrix-vector multiplication of two tensors is done with torch.matmul in PyTorch
    y = torch.matmul(X, w) + b
    #print(y)
    y += torch.normal(0, 0.01, y.shape)
    #print(y)
    # On Tensor.shape and Tensor.size():
    # -1 means "infer this dimension from the total number of elements"
    return X, y.reshape((-1, 1))

# Create the true-parameter tensor
true_w = torch.tensor([2, -3.4])
print(true_w)
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
tensor([ 2.0000, -3.4000])
tensor([[-0.2266, -1.3060],
        [ 0.1052,  0.3722],
        [ 0.0629,  0.5947],
        ...,
        [ 0.0245,  1.1694],
        [-0.0106,  2.2218],
        [ 1.0490,  1.2395]])

Each row of features contains a two-dimensional data sample, and each row of labels contains a one-dimensional label value (a scalar).

print('features:', features[0],'\nlabel:', labels[0])
print(features)
features: tensor([-0.2266, -1.3060]) 
label: tensor([8.1863])
tensor([[-0.2266, -1.3060],
        [ 0.1052,  0.3722],
        [ 0.0629,  0.5947],
        ...,
        [ 0.0245,  1.1694],
        [-0.0106,  2.2218],
        [ 1.0490,  1.2395]])
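
As a quick sanity check of the formula \(\mathbf{y}= \mathbf{X} \mathbf{w} + b + \boldsymbol{\epsilon}\), the first sample printed above gives \(2 \times (-0.2266) - 3.4 \times (-1.3060) + 4.2 \approx 8.187\), which matches label[0] = 8.1863 up to the small added noise.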

By generating a scatter plot of the second feature features[:, 1] against labels, we can directly observe the linear relationship between the two.

d2l.set_figsize()
d2l.plt.scatter(features[:, (1)].detach().numpy(), labels.detach().numpy(), 1);
[Scatter plot of features[:, 1] against labels, showing the linear relationship]

Reading the Dataset#

We define a data_iter function that takes a batch size, a feature matrix, and a label vector as inputs and generates minibatches of size batch_size. Each minibatch consists of a set of features and labels.

Before using the code below, first read up on the following usage.

How yield works

See https://blog.csdn.net/mieleizhi0522/article/details/82142856/ directly.
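
In short: a function containing yield returns a generator that pauses at each yield and resumes where it left off on the next request. A tiny standalone sketch (the function name count_up is just for illustration, not part of d2l):

def count_up(n):
    for i in range(n):
        yield i           # pause here, hand back i, and resume from this point next time

gen = count_up(3)
print(next(gen))   # 0
print(next(gen))   # 1
print(list(gen))   # [2] -- whatever values remain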

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    #print(num_examples)
    indices = list(range(num_examples))
    #print(indices)
    # The samples are read in random order, with no particular sequence, so shuffle the positions
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):  # step from 0 to num_examples (1000) in increments of batch_size
        batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])  # take the slice [i : i + batch_size] of the shuffled indices
        #print(batch_indices)
        #https://blog.csdn.net/mieleizhi0522/article/details/82142856/
        yield features[batch_indices], labels[batch_indices]  # execution resumes here on each request, indexing features and labels with the shuffled positions above

Typically, we take advantage of GPU parallelism by processing reasonably sized minibatches. Each sample can be run through the model in parallel, and the gradient of each sample's loss can also be computed in parallel. A GPU can handle a few hundred samples in scarcely more time than it takes to handle one.

To get an intuitive feel for minibatch computation, let us read the first minibatch of data samples and print it. The feature shape of each batch shows the batch size and the number of input features; likewise, the first dimension of the label batch equals batch_size.

batch_size = 10

for X, y in data_iter(batch_size, features, labels):
    print(X, '\n', y)
    break
tensor([[-0.3870, -1.0756],
        [ 1.2208,  0.3579],
        [-0.7829,  0.2493],
        [-0.3913,  1.7142],
        [ 0.0617, -1.6198],
        [-1.2673, -1.0734],
        [ 1.4152, -1.0266],
        [-1.1467, -0.3620],
        [-1.0045, -0.2672],
        [ 1.5604,  1.1649]]) 
 tensor([[ 7.1007],
        [ 5.4431],
        [ 1.8114],
        [-2.4055],
        [ 9.8235],
        [ 5.3071],
        [10.5071],
        [ 3.1315],
        [ 3.0812],
        [ 3.3516]])

Initializing Model Parameters#

In the code below, we initialize the weights by sampling random numbers from a normal distribution with mean 0 and standard deviation 0.01, and we initialize the bias to 0.

w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)

Defining the Model#

We simply compute the matrix-vector product of the input features \(\mathbf{X}\) and the model weights \(\mathbf{w}\), and then add the bias \(b\). Note that \(\mathbf{Xw}\) above is a vector, while \(b\) is a scalar. Recall the broadcasting mechanism: when we add a scalar to a vector, the scalar is added to every component of the vector.

def linreg(X, w, b):  #@save
    """The linear regression model"""
    return torch.matmul(X, w) + b
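
As a small illustration of that broadcasting rule (throwaway tensors, not part of the model):

# Broadcasting demo: adding a scalar to a vector adds it to every component.
Xw_demo = torch.tensor([1.0, 2.0, 3.0])   # stands in for the vector Xw
b_demo = torch.tensor(0.5)                # stands in for the scalar bias b
print(Xw_demo + b_demo)                   # tensor([1.5000, 2.5000, 3.5000])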

Defining the Loss Function#

Because we need to compute the gradient of the loss function, we should define the loss function first. In the implementation, we need to reshape the true values y to the same shape as the predictions y_hat.

def squared_loss(y_hat, y):  #@save
    """Squared loss"""
    #print(y_hat.shape)
    #print(y_hat)
    # Reshape the true values y to the same shape as the predictions y_hat
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
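
The reshape guards against a subtle broadcasting trap: if the labels arrive as a flat vector of shape (n,) while the predictions have shape (n, 1), subtracting them produces an (n, n) matrix instead of elementwise differences. A quick sketch with throwaway tensors:

# Without the reshape, broadcasting silently produces an (n, n) matrix of differences.
y_hat_demo = torch.ones(3, 1)
y_demo = torch.zeros(3)
print((y_hat_demo - y_demo).shape)                            # torch.Size([3, 3]) -- not what we want
print((y_hat_demo - y_demo.reshape(y_hat_demo.shape)).shape)  # torch.Size([3, 1]) -- elementwise, as intended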

Defining the Optimization Algorithm#

The function below implements the minibatch stochastic gradient descent update. It takes the set of model parameters, a learning rate, and a batch size as inputs. The size of each update step is determined by the learning rate lr. Because the loss we compute is a sum over a batch of samples, we normalize the step by the batch size (batch_size), so that the step size does not depend on our choice of batch size.

def sgd(params, lr, batch_size):  #@save
    """Minibatch stochastic gradient descent"""
    # `with` ensures the necessary "cleanup" happens whether or not an exception occurs,
    # e.g. files being closed automatically or locks being acquired and released.
    # no_grad turns gradient tracking off for the parameter updates below.
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            # The gradients must be cleared, otherwise they accumulate
            param.grad.zero_()
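
In symbols, each call to sgd applies the update below, which simply restates the loop above (\(\mathcal{B}\) is the current minibatch and \(\eta\) is the learning rate lr):

\[(\mathbf{w}, b) \leftarrow (\mathbf{w}, b) - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \partial_{(\mathbf{w},b)} l(\mathbf{x}^{(i)}, y^{(i)}, \mathbf{w}, b).\]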

Training#

Now that we have all the pieces needed for model training, we can implement the main training loop. Understanding this code is crucial, because once you work in deep learning you will see nearly identical training loops over and over again. In each iteration, we read a minibatch of training samples and run them through our model to obtain a set of predictions. After computing the loss, we run backpropagation, storing the gradient of each parameter. Finally, we call the optimization algorithm sgd to update the model parameters.

To summarize, we will execute the following loop:

  • Initialize the parameters

  • Repeat the following until done:

    • Compute the gradient \(\mathbf{g} \leftarrow \partial_{(\mathbf{w},b)} \frac{1}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} l(\mathbf{x}^{(i)}, y^{(i)}, \mathbf{w}, b)\)

    • Update the parameters \((\mathbf{w}, b) \leftarrow (\mathbf{w}, b) - \eta \mathbf{g}\)

In each epoch, we use the data_iter function to iterate over the whole dataset, using every sample in the training set exactly once (assuming the number of samples is divisible by the batch size). The number of epochs num_epochs and the learning rate lr here are both hyperparameters, set to 3 and 0.03 respectively. Setting hyperparameters is tricky and requires tuning by trial and error.

lr = 0.03
num_epochs = 3
net = linreg
loss = squared_loss

for epoch in range(num_epochs):

    for X, y in data_iter(batch_size, features, labels):
        # print(X)
        # print(y)
        # print(X.shape)
        # print(y.shape)
        l = loss(net(X, w, b), y)  # minibatch loss on X and y
        # Because l has shape (batch_size, 1) rather than being a scalar, all the elements of l
        # are summed together and the gradient with respect to [w, b] is computed from that sum
        c = l.sum()  # printing this value shows the loss falling steadily
        print(c)
        c.backward()
        sgd([w, b], lr, batch_size)  # update the parameters using their gradients; the learning rate lr is 0.03
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')
tensor(144.6424, grad_fn=<SumBackward0>)
tensor(231.2420, grad_fn=<SumBackward0>)
tensor(163.4158, grad_fn=<SumBackward0>)
...
tensor(0.2750, grad_fn=<SumBackward0>)
tensor(0.5190, grad_fn=<SumBackward0>)
epoch 1, loss 0.044810
tensor(0.5909, grad_fn=<SumBackward0>)
tensor(0.4514, grad_fn=<SumBackward0>)
...
tensor(0.0030, grad_fn=<SumBackward0>)
tensor(0.0028, grad_fn=<SumBackward0>)
epoch 2, loss 0.000175
tensor(0.0013, grad_fn=<SumBackward0>)
tensor(0.0008, grad_fn=<SumBackward0>)
...
tensor(0.0009, grad_fn=<SumBackward0>)
tensor(0.0006, grad_fn=<SumBackward0>)
epoch 3, loss 0.000050

Because we synthesized the dataset ourselves, we know exactly what the true parameters are. We can therefore evaluate how successful training was by comparing the true parameters with those learned through training. Indeed, the learned parameters turn out to be very close to the true ones.

print(f'error in estimating w: {true_w - w.reshape(true_w.shape)}')
print(f'error in estimating b: {true_b - b}')
error in estimating w: tensor([ 0.0008, -0.0010], grad_fn=<SubBackward0>)
error in estimating b: tensor([0.0002], grad_fn=<RsubBackward1>)
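
As an optional extra check (a small sketch, not part of the original text), the learned values themselves can be printed next to the true ones:

# Optional sketch: print the learned parameters alongside the true ones.
print('learned w:', w.reshape(true_w.shape).detach(), '  true w:', true_w)
print('learned b:', b.detach(), '  true b:', true_b)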

Complete Code#

import random
import torch
from d2l import torch as d2l
def synthetic_data(w, b, num_examples):  #@save
    """Generate y = Xw + b + noise"""
    X = torch.normal(0, 1, (num_examples, len(w)))
    y = torch.matmul(X, w) + b
    y += torch.normal(0, 0.01, y.shape)
    return X, y.reshape((-1, 1))

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)  # generate the features and labels
def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    # The samples are read in random order, with no particular sequence
    random.shuffle(indices)
    print(indices)
    for i in range(0, num_examples, batch_size):  # step from 0 to num_examples in increments of batch_size
        batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])  # take the slice [i : i + batch_size] of the shuffled indices
        #print(batch_indices)
        yield features[batch_indices], labels[batch_indices]  # execution resumes here on each request, indexing with the shuffled positions above



w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
def linreg(X, w, b):  #@save
    """The linear regression model"""
    return torch.matmul(X, w) + b
def squared_loss(y_hat, y):  #@save
    """Squared loss"""
    # Compare the predictions with the original labels to see how large the loss is
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
def sgd(params, lr, batch_size):  #@save
    """Minibatch stochastic gradient descent"""
    with torch.no_grad():
        for param in params:  # update w first, then b
            param -= lr * param.grad / batch_size
            param.grad.zero_()



for epoch in range(3):
    for X, y in data_iter(10, features, labels):
        # print(X)
        # print(y)
        # print(X.shape)
        # print(y.shape)
        l = squared_loss(linreg(X, w, b), y)  # minibatch loss on X and y
        # Because l has shape (batch_size, 1) rather than being a scalar, all the elements of l
        # are summed together and the gradient with respect to [w, b] is computed from that sum
        c = l.sum()  # printing this value shows the loss falling steadily
        print(c)
        c.backward()
        sgd([w, b], 0.03, 10)  # update the parameters using their gradients; the learning rate lr is 0.03
    with torch.no_grad():
        train_l = squared_loss(linreg(features, w, b), labels)
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')
print(f'error in estimating w: {true_w - w.reshape(true_w.shape)}')
print(f'error in estimating b: {true_b - b}')
[967, 462, 867, 633, 961, 756, 532, 379, 700, 356, ..., 686, 575, 257, 467, 137, 727, 657]
tensor(128.7001, grad_fn=<SumBackward0>)
tensor(171.3044, grad_fn=<SumBackward0>)
tensor(291.9825, grad_fn=<SumBackward0>)
...
tensor(0.2040, grad_fn=<SumBackward0>)
tensor(0.2032, grad_fn=<SumBackward0>)
epoch 1, loss 0.027386
[294, 398, 960, 771, 529, 986, 101, 528, 449, 198, ..., 891, 904, 951, 648]
tensor(0.4449, grad_fn=<SumBackward0>)
tensor(0.2951, grad_fn=<SumBackward0>)
...
tensor(0.0009, grad_fn=<SumBackward0>)
tensor(0.0012, grad_fn=<SumBackward0>)
epoch 2, loss 0.000096
[676, 289, 769, 293, 123, 869, 125, 374, 619, 942, ..., 754, 937, 101, 991]
tensor(0.0008, grad_fn=<SumBackward0>)
tensor(0.0002, grad_fn=<SumBackward0>)
...
tensor(0.0007, grad_fn=<SumBackward0>)
tensor(0.0003, grad_fn=<SumBackward0>)
epoch 3, loss 0.000051
error in estimating w: tensor([-0.0005, -0.0002], grad_fn=<SubBackward0>)
error in estimating b: tensor([-0.0003], grad_fn=<RsubBackward1>)