广播机制

广播机制#

定义#

广播机制规则,如果遵守以下规则,则两个tensor是“可广播的”:

  • 每个tensor至少有一个维度;

  • 遍历tensor所有维度时,从末尾开始遍历(从右往左开始遍历)(从后往前开始遍历),两个tensor存在下列情况:

    • tensor维度相等。

    • tensor维度不等且其中一个维度为1。

    • tensor维度不等且其中一个维度不存在。

如果两个tensor是“可广播的”,则计算过程遵循下列规则:

  • 如果两个tensor的维度不同,则在维度较小的tensor的前面增加维度,使它们维度相等。

  • 对于每个维度,计算结果的维度值取两个tensor中较大的那个值。

  • 两个tensor扩展维度的过程是将数值进行复制。

举例#

import torch

相同维度,一定可以 broadcasting。

# 相同维度,一定可以 broadcasting
x=torch.ones(5,7,3)
y=torch.ones(5,7,3)
z = x+y
x.shape,y.shape,z.shape
(torch.Size([5, 7, 3]), torch.Size([5, 7, 3]), torch.Size([5, 7, 3]))

x和y不能被广播,因为x没有符合“至少有一个维度”,所以不可以broadcasting。

# x和y不能被广播,因为x没有符合“至少有一个维度”,所以不可以broadcasting
x=torch.ones((0,))
y=torch.ones(5,7,3)
z = x+y
x.shape,y.shape,z.shape
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[3], line 4
      2 x=torch.ones((0,))
      3 y=torch.ones(5,7,3)
----> 4 z = x+y
      5 x.shape,y.shape,z.shape

RuntimeError: The size of tensor a (0) must match the size of tensor b (3) at non-singleton dimension 2

x 和 y 可以广播。

# x 和 y 可以广播
x=torch.ones(5,3,4,1)
y=torch.ones(  3,1,1)
z = x+y
x.shape,y.shape,z.shape
# 从尾部维度开始遍历
# 1st尾部维度: x和y相同,都为1。
# 2nd尾部维度: y为1,x为4,符合维度不等且其中一个维度为1,则广播为4。
# 3rd尾部维度: x和y相同,都为3。
# 4th尾部维度: y维度不存在,x为5,符合维度不等且其中一个维度不存在,则广播为5。
(torch.Size([5, 3, 4, 1]), torch.Size([3, 1, 1]), torch.Size([5, 3, 4, 1]))

x 和 y 不可以广播,因为倒数第三维度x为2,y为3,不符合维度不等且其中一个维度为1。

# x 和 y 不可以广播,因为倒数第三维度x为2,y为3,不符合维度不等且其中一个维度为1。
x=torch.ones(5,2,4,1)
y=torch.ones(  3,1,1)
z = x+y
x.shape,y.shape,z.shape
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[5], line 4
      2 x=torch.ones(5,2,4,1)
      3 y=torch.ones(  3,1,1)
----> 4 z = x+y
      5 x.shape,y.shape,z.shape

RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1

x 和 y 可以广播,在维度较小y前面增加维度,使它们维度相等,同时使他们维度大小相同。

# x 和 y 可以广播,在维度较小y前面增加维度,使它们维度相等。
x=torch.ones(5,2,4,1)
y=torch.ones(1,1)
z = x+y
x.shape,y.shape,z.shape
(torch.Size([5, 2, 4, 1]), torch.Size([1, 1]), torch.Size([5, 2, 4, 1]))