PyTorch 张量的广播(broadcasting)机制允许在不同形状的张量之间进行算术运算。广播会按照一定的规则自动扩展较小张量的维度,使其与较大张量的维度匹配,然后进行逐元素(element-wise)运算。
以下是 PyTorch 广播的基本规则:
- 如果两个张量的维度不同,将较小的维度扩展为与较大的维度相同。扩展的方式是在较小的维度前面补 1。
- 如果两个张量在某一个维度上的大小相同,或者其中一个张量在该维度上的大小为 1,则这两个张量在该维度上可以进行广播。
- 如果两个张量在所有维度上的大小都相同,那么它们可以直接进行逐元素运算。否则,会抛出错误。
下面是一个简单的例子来说明 PyTorch 中的广播机制:
import torch # 创建两个张量 a = torch.tensor([[1., 2.], [3., 4.]]) b = torch.tensor([1., 2.]) # 广播 b 到与 a 相同的形状 b_expanded = b.expand(-1, -1) # 进行逐元素运算 result = a + b_expanded print(result)
输出结果:
tensor([[2., 4.], [4., 6.]])
在这个例子中,我们首先创建了一个形状为 (2, 2) 的张量 a
和一个形状为 (2,) 的张量 b
。然后,我们使用 expand
方法将 b
扩展为与 a
相同的形状。最后,我们对两个张量进行逐元素加法运算,得到结果。