0%

bmm function(待完善)

官方文档: torch.bmm(), version=1.2.0

torch.bmm()用于实现矩阵的乘法,其引用方式为:

1
2
3
4
5
6
torch.bmm(input, mat2, deterministic = False, out = None)
# 参数:
# !!!input和mat2的都需要是3-D tensor,即维度等于3
# input(Tensor): 第一个用于矩阵乘法的batch
# mat2(Tensor): 第二个用于矩阵乘法的batch
# out(Tensor, optional): 输出结果

除了需要满足inputmat2的维度一致,相似于线性代数的矩阵相乘,bmm()的矩阵乘法也约束了当input的维度为(b×n×m)mat2的维度应为(b×m×p),相乘结果out维度为(b×n×p)

1
2
3
4
5
6
7
8
9
10
11
12
# example
import torch
# optional
out1 = torch.empty(0)
input = torch.randn(3,4,5)
mat2 = torch.randn(3,5,6)
# method 1
torch.bmm(input, mat2, out = out1)
# method 2
torch.bmm(input, mat2, out = mat2)
# method 3
out2 = torch.bmm(input, mat2)