官方文档: torch.bmm(), version=1.2.0
torch.bmm()
用于实现矩阵的乘法,其引用方式为:
1 | torch.bmm(input, mat2, deterministic = False, out = None) |
除了需要满足input
和mat2
的维度一致,相似于线性代数的矩阵相乘,bmm()
的矩阵乘法也约束了当input
的维度为$(b\times n\times m)$,mat2
的维度应为$(b\times m\times p)$,相乘结果out
维度为$(b\times n\times p)$。
1 | # example |