bmm function(待完善) 发表于 2020-08-24 更新于 2020-08-29 分类于 Pytorch function 阅读次数: 15 本文字数: 595 阅读时长 ≈ 1 分钟 官方文档: torch.bmm(), version=1.2.0 torch.bmm()用于实现矩阵的乘法,其引用方式为: 123456torch.bmm(input, mat2, deterministic = False, out = None)# 参数:# !!!input和mat2的都需要是3-D tensor,即维度等于3# input(Tensor): 第一个用于矩阵乘法的batch# mat2(Tensor): 第二个用于矩阵乘法的batch# out(Tensor, optional): 输出结果 除了需要满足input和mat2的维度一致,相似于线性代数的矩阵相乘,bmm()的矩阵乘法也约束了当input的维度为(b×n×m),mat2的维度应为(b×m×p),相乘结果out维度为(b×n×p)。 123456789101112# exampleimport torch# optionalout1 = torch.empty(0)input = torch.randn(3,4,5)mat2 = torch.randn(3,5,6)# method 1torch.bmm(input, mat2, out = out1)# method 2torch.bmm(input, mat2, out = mat2)# method 3out2 = torch.bmm(input, mat2)