在PyTorch中,可以使用torch.cat()
函数来拼接张量。torch.cat()
函数需要指定沿着哪个维度进行拼接。以下是一些示例:
- 沿第一个维度(axis=0)拼接两个相同形状的张量:
import torch tensor1 = torch.randn(2, 3) tensor2 = torch.randn(2, 3) result = torch.cat((tensor1, tensor2), dim=0) print(result.shape) # 输出:(4, 3)
- 沿第二个维度(axis=1)拼接两个相同形状的张量:
import torch tensor1 = torch.randn(2, 3) tensor2 = torch.randn(2, 3) result = torch.cat((tensor1, tensor2), dim=1) print(result.shape) # 输出:(2, 6)
- 沿第三个维度(axis=2)拼接两个相同形状的张量:
import torch tensor1 = torch.randn(2, 3, 4) tensor2 = torch.randn(2, 3, 4) result = torch.cat((tensor1, tensor2), dim=2) print(result.shape) # 输出:(2, 3, 8)
请注意,要沿指定维度拼接张量,它们的形状必须相同。例如,如果沿第一个维度拼接,张量的形状必须为(batch_size, input_dim1, input_dim2, ...)
。