官方文档: torch.bmm(), version=1.2.0
torch.bmm()用于实现矩阵的乘法,其引用方式为:
1 | torch.bmm(input, mat2, deterministic = False, out = None) |
除了需要满足input和mat2的维度一致,相似于线性代数的矩阵相乘,bmm()的矩阵乘法也约束了当input的维度为$(b\times n\times m)$,mat2的维度应为$(b\times m\times p)$,相乘结果out维度为$(b\times n\times p)$。
1 | # example |