Linear Regression#
At its core, linear regression says that along the x and y axes, every input value x has a corresponding value y. What we need to obtain is this model (informally: the formula of a straight line).
In linear regression, the data are modeled with a linear prediction function, and the unknown model parameters are likewise estimated from the data. Such models are called linear models. The most common formulation of linear regression takes the conditional mean of y given the value of X to be an affine function of X. Less commonly, the model describes the median, or some other quantile of the conditional distribution of y given X, as a linear function of X. Linear regression has many practical uses, which fall into the following two broad categories:
If the goal is prediction or mapping, linear regression can be used to fit a predictive model to an observed dataset of y and X values. Once such a model has been fitted, then for a new value of X whose paired y is not given, the fitted model can be used to predict a value of y.
Given a variable y and a number of variables \({\displaystyle X_{1}},...,{\displaystyle X_{p}}\) that may be related to y, linear regression analysis can be used to quantify the strength of the relationship between y and the \({\displaystyle X_{j}}\), to assess which \({\displaystyle X_{j}}\) are unrelated to y, and to identify which subsets of the \({\displaystyle X_{j}}\) contain redundant information about y.
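As a minimal, illustrative sketch of the "fit then predict" use case (not part of the tutorials below), here is a straight-line fit with NumPy followed by a prediction for a new x; the data and variable names are made up for illustration:
import numpy as np

x = np.array([0.0, 1.0, 2.0, 3.0])
y = 2.0 * x + 3.0 + np.random.normal(0, 0.1, x.shape)  # data that roughly follows y = 2x + 3

slope, intercept = np.polyfit(x, y, deg=1)  # least-squares fit of a degree-1 polynomial
print(slope, intercept)                     # should be close to 2 and 3
print(slope * 5.0 + intercept)              # predict y for a new input x = 5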
The plots produced during training below give a more intuitive view of the whole process.
from IPython import display
import time
from mindspore.train.callback import Callback
from mindspore import Model
from mindspore import Tensor
from mindspore.common.initializer import Normal
from mindspore import nn
from mindspore import dataset as ds
import matplotlib.pyplot as plt
import numpy as np
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
def get_data(num, w=2.0, b=3.0):
    for _ in range(num):
        x = np.random.uniform(-10.0, 10.0)
        noise = np.random.normal(0, 1)
        y = x * w + b + noise
        yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)
eval_data = list(get_data(50))
x_target_label = np.arange(-10, 10, 0.1)
y_target_label = x_target_label * 2 + 3
x_eval_label,y_eval_label = zip(*eval_data)
plt.scatter(x_eval_label, y_eval_label, color="red", s=5)
plt.plot(x_target_label, y_target_label, color="green")
plt.title("Eval data")
plt.show()
def create_dataset(num_data, batch_size=16, repeat_size=1):
    input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
    input_data = input_data.batch(batch_size)
    input_data = input_data.repeat(repeat_size)
    return input_data
data_number = 1600
batch_number = 16
repeat_number = 1
ds_train = create_dataset(data_number, batch_size=batch_number, repeat_size=repeat_number)
print("The dataset size of ds_train:", ds_train.get_dataset_size())
dict_datasets = next(ds_train.create_dict_iterator())
print(dict_datasets.keys())
print("The x label value shape:", dict_datasets["data"].shape)
print("The y label value shape:", dict_datasets["label"].shape)
class LinearNet(nn.Cell):
    def __init__(self):
        super(LinearNet, self).__init__()
        self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))

    def construct(self, x):
        x = self.fc(x)
        return x
net = LinearNet()
model_params = net.trainable_params()
for param in model_params:
    print(param, param.asnumpy())
x_model_label = np.arange(-10, 10, 0.1)
y_model_label = (x_model_label * Tensor(model_params[0]).asnumpy()[0][0] +
Tensor(model_params[1]).asnumpy()[0])
plt.scatter(x_eval_label, y_eval_label, color="red", s=5)
plt.plot(x_model_label, y_model_label, color="blue")
plt.plot(x_target_label, y_target_label, color="green")
plt.show()
net = LinearNet()
net_loss = nn.loss.MSELoss()
opt = nn.Momentum(net.trainable_params(), learning_rate=0.005, momentum=0.9)
model = Model(net, net_loss, opt)
def plot_model_and_datasets(net, eval_data):
    weight = net.trainable_params()[0]
    bias = net.trainable_params()[1]
    x = np.arange(-10, 10, 0.1)
    y = x * Tensor(weight).asnumpy()[0][0] + Tensor(bias).asnumpy()[0]
    x1, y1 = zip(*eval_data)
    x_target = x
    y_target = x_target * 2 + 3

    plt.axis([-11, 11, -20, 25])
    plt.scatter(x1, y1, color="red", s=5)
    plt.plot(x, y, color="blue")
    plt.plot(x_target, y_target, color="green")
    plt.show()
    time.sleep(0.02)
class ImageShowCallback(Callback):
    def __init__(self, net, eval_data):
        self.net = net
        self.eval_data = eval_data

    def step_end(self, run_context):
        plot_model_and_datasets(self.net, self.eval_data)
        display.clear_output(wait=True)
epoch = 1
imageshow_cb = ImageShowCallback(net, eval_data)
model.train(epoch, ds_train, callbacks=[imageshow_cb], dataset_sink_mode=False)
plot_model_and_datasets(net, eval_data)
for param in net.trainable_params():
    print(param, param.asnumpy())
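With training finished, the trained network can be called directly to make a prediction. A minimal sketch (not part of the original tutorial), assuming the net trained above is still in scope:
x_new = Tensor(np.array([[5.0]], dtype=np.float32))  # a new input, x = 5
print(net(x_new))  # forward pass of the trained nn.Cell; should be close to 2 * 5 + 3 = 13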

%matplotlib inline
import random
import torch
from d2l import torch as d2l
Generating the Dataset#
In the code below, we generate a dataset containing 1000 samples, where each sample has 2 features drawn from a standard normal distribution. Our synthetic dataset is therefore a matrix \(\mathbf{X}\in \mathbb{R}^{1000 \times 2}\).
We generate the dataset and its labels using the linear model parameters \(\mathbf{w} = [2, -3.4]^\top\), \(b = 4.2\) and a noise term \(\epsilon\): \(\mathbf{y} = \mathbf{X}\mathbf{w} + b + \epsilon\).
\(\epsilon\) can be viewed as the potential observation error made when the model predicts the labels. Here we take the standard assumptions to hold, namely that \(\epsilon\) follows a normal distribution with mean 0. To simplify the problem, we set its standard deviation to 0.01.
def synthetic_data(w, b, num_examples): #@save
    """Generate y = Xw + b + noise."""
    # torch.normal(mean, std, size) arguments:
    #   mean (Tensor) – the mean (average)
    #   std (Tensor) – the standard deviation  https://zh.wikihow.com/%E8%AE%A1%E7%AE%97%E6%A0%87%E5%87%86%E5%B7%AE
    #   out (Tensor) – optional output tensor
    X = torch.normal(0, 1, (num_examples, len(w)))
    print(X)
    # Matrix multiplication of two tensors is done with torch.matmul in PyTorch
    y = torch.matmul(X, w) + b
    #print(y)
    y += torch.normal(0, 0.01, y.shape)
    #print(y)
    # torch.shape vs. torch.size()
    # -1 means this dimension is inferred from the total number of elements
    return X, y.reshape((-1, 1))
# Create the tensor of true weights
true_w = torch.tensor([2, -3.4])
print(true_w)
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
tensor([ 2.0000, -3.4000])
tensor([[-0.0550, -1.8695],
[-0.7796, 1.2877],
[ 1.4891, 0.0385],
...,
[-0.5137, 0.3570],
[-1.3170, 0.1718],
[ 0.4080, -2.8939]])
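A quick shape check (illustrative, not part of the original text) confirms the sizes of the generated tensors:
print(features.shape)  # torch.Size([1000, 2]): 1000 samples with 2 features each
print(labels.shape)    # torch.Size([1000, 1]): one scalar label per sample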
Each row of features contains a two-dimensional data sample, and each row of labels contains a one-dimensional label value (a scalar).
print('features:', features[0],'\nlabel:', labels[0])
print(features)
features: tensor([-0.0550, -1.8695])
label: tensor([10.4416])
tensor([[-0.0550, -1.8695],
[-0.7796, 1.2877],
[ 1.4891, 0.0385],
...,
[-0.5137, 0.3570],
[-1.3170, 0.1718],
[ 0.4080, -2.8939]])
By generating a scatter plot of the second feature features[:, 1] against labels, we can visually observe the linear relationship between the two.
d2l.set_figsize()
d2l.plt.scatter(features[:, (1)].detach().numpy(), labels.detach().numpy(), 1);
Reading the Dataset#
We define a data_iter function that takes a batch size, a feature matrix and a label vector as input, and yields minibatches of size batch_size. Each minibatch consists of a set of features and labels.
Before using the code below, first read up on the following:
How yield works
See https://blog.csdn.net/mieleizhi0522/article/details/82142856/
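As a minimal illustration of yield (not part of the original text): a function containing yield becomes a generator that pauses at each yield and resumes from that point on the next iteration:
def count_up(n):
    i = 0
    while i < n:
        yield i   # pause here and hand back i; resume here on the next iteration
        i += 1

for value in count_up(3):
    print(value)  # prints 0, 1, 2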
def data_iter(batch_size, features, labels):
    num_examples = len(features)
    #print(num_examples)
    indices = list(range(num_examples))
    #print(indices)
    # The samples are read in random order, not in any particular sequence: shuffle the indices
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):  # loop from 0 to 1000, advancing by batch_size each time
        batch_indices = indices[i: min(i + batch_size, num_examples)]  # take the slice [i : i+batch_size] of the 1000 shuffled indices
        #print(batch_indices)
        #https://blog.csdn.net/mieleizhi0522/article/details/82142856/
        yield features[batch_indices], labels[batch_indices]  # each call resumes after the previous yield; the shuffled indices are used to index features and labels
Typically, we take advantage of the GPU's ability to parallelize operations and process reasonably sized minibatches. Each sample can be fed through the model in parallel, and the gradient of the loss for each sample can also be computed in parallel. A GPU can process several hundred samples in scarcely more time than it takes to process one.
To build some intuition for minibatch computation, let us read and print the first small batch of data samples. The feature shape of each batch shows both the batch size and the number of input features; likewise, the batch of labels has a shape given by batch_size.
batch_size = 10
for X, y in data_iter(batch_size, features, labels):
    print(X, '\n', y)
    break
tensor([[-0.4661, 0.6059],
[ 1.6231, -0.7025],
[ 0.2688, -0.3114],
[-0.2076, 1.0464],
[-0.9737, 2.2190],
[-1.8458, -0.6901],
[ 0.7811, 1.0996],
[ 0.4051, 0.9514],
[ 0.9066, 0.5997],
[ 1.1037, -0.4426]])
tensor([[ 1.2227],
[ 9.8492],
[ 5.8022],
[ 0.2352],
[-5.3089],
[ 2.8795],
[ 2.0207],
[ 1.7710],
[ 3.9736],
[ 7.9158]])
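As a quick sanity check (illustrative, not part of the original text): with 1000 samples and a batch size of 10, data_iter produces 100 minibatches per epoch:
print(len(list(data_iter(batch_size, features, labels))))  # 100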
Initializing Model Parameters#
In the code below, we initialize the weights by sampling random numbers from a normal distribution with mean 0 and standard deviation 0.01, and we initialize the bias to 0.
w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
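Printing the freshly initialized parameters (an illustrative check, exact values will differ) shows a small random (2, 1) weight tensor and a zero bias, both tracking gradients:
print(w)  # e.g. tensor([[ 0.0063], [-0.0117]], requires_grad=True); values will differ
print(b)  # tensor([0.], requires_grad=True)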
Defining the Model#
We only need to compute the matrix-vector product of the input features \(\mathbf{X}\) and the model weights \(\mathbf{w}\), and then add the bias \(b\). Note that \(\mathbf{Xw}\) above is a vector while \(b\) is a scalar. Recall the broadcasting mechanism: when we add a scalar to a vector, the scalar is added to every component of the vector.
def linreg(X, w, b): #@save
    """The linear regression model."""
    return torch.matmul(X, w) + b
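A tiny sketch (illustrative only) of the broadcasting rule mentioned above, adding a scalar to a vector:
v = torch.tensor([1.0, 2.0, 3.0])
print(v + 10.0)  # tensor([11., 12., 13.]): the scalar 10.0 is added to every component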
Defining the Loss Function#
Because we need to compute the gradient of the loss function, we should define the loss function first. In the implementation, we need to reshape the true value y so that it has the same shape as the predicted value y_hat.
def squared_loss(y_hat, y): #@save
    """Squared loss."""
    #print(y_hat.shape)
    #print(y_hat)
    # Reshape the true value y to the same shape as the prediction y_hat
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
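The reshape matters because of broadcasting: subtracting a shape (3,) label vector from shape (3, 1) predictions silently produces a (3, 3) matrix instead of an element-wise difference. A small illustrative sketch with made-up values:
y_hat = torch.tensor([[1.0], [2.0], [3.0]])    # predictions, shape (3, 1)
y = torch.tensor([1.0, 2.0, 3.0])              # labels, shape (3,)
print((y_hat - y).shape)                       # torch.Size([3, 3]): unintended broadcasting
print((y_hat - y.reshape(y_hat.shape)).shape)  # torch.Size([3, 1]): element-wise, as intended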
Defining the Optimization Algorithm#
The function below implements the minibatch stochastic gradient descent update. It takes the set of model parameters, a learning rate and a batch size as input. The size of each update step is determined by the learning rate lr. Because the loss we compute is a sum over a batch of samples, we normalize the step size by the batch size (batch_size), so that the magnitude of a step does not depend on our choice of batch size.
def sgd(params, lr, batch_size): #@save
    """Minibatch stochastic gradient descent."""
    # A `with` statement is used when accessing a resource; it guarantees that the necessary
    # clean-up runs whether or not an exception occurs, e.g. closing a file after use or
    # acquiring and releasing a lock in a thread.
    # no_grad disables gradient tracking for the updates below
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            # The gradients must be cleared, otherwise they accumulate across steps
            param.grad.zero_()
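The call to param.grad.zero_() is needed because PyTorch accumulates gradients across backward passes. A small illustrative sketch:
p = torch.tensor([1.0], requires_grad=True)
(p * 2).sum().backward()
print(p.grad)   # tensor([2.])
(p * 2).sum().backward()
print(p.grad)   # tensor([4.]): the second gradient was added on top of the first
p.grad.zero_()
print(p.grad)   # tensor([0.]): cleared, ready for the next step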
Training#
Now that we have all of the pieces needed for model training, we can implement the main training loop.
Understanding this code is essential, because once you start working in deep learning, you will see essentially the same training loop over and over again.
In each iteration, we read a minibatch of training samples and pass it through our model to obtain a set of predictions. After computing the loss, we run backpropagation, which stores the gradient of every parameter. Finally, we call the optimization algorithm sgd to update the model parameters.
To summarize, we will execute the following loop:
Initialize the parameters
Repeat the following until done:
Compute the gradient \(\mathbf{g} \leftarrow \partial_{(\mathbf{w},b)} \frac{1}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} l(\mathbf{x}^{(i)}, y^{(i)}, \mathbf{w}, b)\)
Update the parameters \((\mathbf{w}, b) \leftarrow (\mathbf{w}, b) - \eta \mathbf{g}\)
In each epoch, we use the data_iter function to sweep over the entire dataset, visiting every sample in the training set once (assuming the number of samples is divisible by the batch size). The number of epochs num_epochs and the learning rate lr are both hyperparameters, set here to 3 and 0.03 respectively. Setting hyperparameters is tricky and requires adjustment by trial and error.
lr = 0.03
num_epochs = 3
net = linreg
loss = squared_loss
for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        # print(X)
        # print(y)
        # print(X.shape)
        # print(y.shape)
        l = loss(net(X, w, b), y)  # minibatch loss for X and y
        # Because l has shape (batch_size, 1) rather than being a scalar, all the elements
        # of l are summed, and the gradient with respect to [w, b] is computed from that sum
        c = l.sum()  # printing this value makes the gradual decrease easy to see
        print(c)
        c.backward()
        sgd([w, b], lr, batch_size)  # update the parameters using their gradients; the learning rate lr is 0.03
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')
tensor(226.2876, grad_fn=<SumBackward0>)
tensor(263.4118, grad_fn=<SumBackward0>)
tensor(206.6820, grad_fn=<SumBackward0>)
...
tensor(0.2316, grad_fn=<SumBackward0>)
tensor(0.4955, grad_fn=<SumBackward0>)
epoch 1, loss 0.036205
tensor(0.3098, grad_fn=<SumBackward0>)
...
tensor(0.0013, grad_fn=<SumBackward0>)
epoch 2, loss 0.000129
tensor(0.0012, grad_fn=<SumBackward0>)
...
tensor(0.0004, grad_fn=<SumBackward0>)
epoch 3, loss 0.000049
Because we synthesized the dataset ourselves, we know exactly what the true parameters are. We can therefore evaluate how well training succeeded by comparing the true parameters with the parameters learned through training. In fact, the true parameters and the learned parameters turn out to be very close.
print(f'Estimation error of w: {true_w - w.reshape(true_w.shape)}')
print(f'Estimation error of b: {true_b - b}')
Estimation error of w: tensor([ 0.0005, -0.0008], grad_fn=<SubBackward0>)
Estimation error of b: tensor([0.0003], grad_fn=<RsubBackward1>)
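As a final illustrative check (not in the original text), we can feed a new input through the learned model and compare against the true function \(y = 2x_1 - 3.4x_2 + 4.2\); for \(x = [1, 1]\) the noise-free target is 2.8:
x_new = torch.tensor([[1.0, 1.0]])
print(linreg(x_new, w, b))        # should be close to tensor([[2.8000]])
print(2 * 1.0 - 3.4 * 1.0 + 4.2)  # the noise-free target: 2.8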
Complete Code#
import random
import torch
from d2l import torch as d2l
def synthetic_data(w, b, num_examples): #@save
    """Generate y = Xw + b + noise."""
    X = torch.normal(0, 1, (num_examples, len(w)))
    y = torch.matmul(X, w) + b
    y += torch.normal(0, 0.01, y.shape)
    return X, y.reshape((-1, 1))

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)  # generate the features and labels

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    # The samples are read in random order, not in any particular sequence
    random.shuffle(indices)
    print(indices)
    for i in range(0, num_examples, batch_size):  # loop from 0 to 1000, advancing by batch_size (10) each time
        batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])  # take the slice [i : i+batch_size] of the 1000 shuffled indices
        #print(batch_indices)
        yield features[batch_indices], labels[batch_indices]  # each call resumes after the previous yield; the shuffled indices are used to index features and labels

w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)

def linreg(X, w, b): #@save
    """The linear regression model."""
    return torch.matmul(X, w) + b

def squared_loss(y_hat, y): #@save
    """Squared loss."""
    # Compare the computed values against the original labels to see how large the loss is
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

def sgd(params, lr, batch_size): #@save
    """Minibatch stochastic gradient descent."""
    with torch.no_grad():
        for param in params:  # update w first, then b
            param -= lr * param.grad / batch_size
            param.grad.zero_()

for epoch in range(3):
    for X, y in data_iter(10, features, labels):
        # print(X)
        # print(y)
        # print(X.shape)
        # print(y.shape)
        l = squared_loss(linreg(X, w, b), y)  # minibatch loss for X and y
        # Because l has shape (batch_size, 1) rather than being a scalar, all the elements
        # of l are summed, and the gradient with respect to [w, b] is computed from that sum
        c = l.sum()  # printing this value makes the gradual decrease easy to see
        print(c)
        c.backward()
        sgd([w, b], 0.03, 10)  # update the parameters using their gradients; the learning rate lr is 0.03
    with torch.no_grad():
        train_l = squared_loss(linreg(features, w, b), labels)
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')

print(f'Estimation error of w: {true_w - w.reshape(true_w.shape)}')
print(f'Estimation error of b: {true_b - b}')
[523, 17, 275, 195, 76, 381, 462, 983, 714, 82, 619, 454, 293, 127, 692, 457, ...]
tensor(67.7827, grad_fn=<SumBackward0>)
tensor(117.1533, grad_fn=<SumBackward0>)
...
tensor(0.3387, grad_fn=<SumBackward0>)
epoch 1, loss 0.046351
[953, 633, 549, 662, 673, 440, 197, 184, 600, 75, 540, 382, 470, 177, 650, 14, ...]
tensor(0.2391, grad_fn=<SumBackward0>)
...
tensor(0.0026, grad_fn=<SumBackward0>)
epoch 2, loss 0.000188
[917, 691, 49, 210, 584, 666, 361, 586, 429, 839, 684, 401, 842, 598, 747, 708, ...]
tensor(0.0025, grad_fn=<SumBackward0>)
...
tensor(0.0003, grad_fn=<SumBackward0>)
epoch 3, loss 0.000050
Estimation error of w: tensor([ 6.2823e-05, -6.0439e-04], grad_fn=<SubBackward0>)
Estimation error of b: tensor([0.0009], grad_fn=<RsubBackward1>)