PyTorch einsum
1) Matrix multiplication
Previously: torch.matmul(aten, bten) or aten.mm(bten)
einsum: torch.einsum("ij, jk -> ik", aten, bten)
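A minimal sketch that checks the einsum string against torch.matmul and aten.mm, and spells out what the subscripts mean as explicit loops. The shapes and values of aten and bten are assumptions chosen here for illustration (small integers so equality checks are exact); they are not from the original notes.

```python
import torch

# Illustrative inputs (assumed): shapes (2, 3) and (3, 4), integer values
aten = torch.arange(6).reshape(2, 3)
bten = torch.arange(12).reshape(3, 4)

# "ij, jk -> ik": j appears in both inputs but not in the output,
# so einsum multiplies along j and sums it away
via_einsum = torch.einsum("ij, jk -> ik", aten, bten)

# The same contraction as explicit loops: one output entry per (i, k),
# each the sum over the shared index j
rows, inner = aten.shape
cols = bten.shape[1]
via_loops = torch.zeros(rows, cols, dtype=aten.dtype)
for i in range(rows):
    for k in range(cols):
        for j in range(inner):
            via_loops[i, k] += aten[i, j] * bten[j, k]

print(via_einsum)  # rows [20, 23, 26, 29] and [56, 68, 80, 92]
print(torch.equal(via_einsum, torch.matmul(aten, bten)))  # True
print(torch.equal(via_einsum, aten.mm(bten)))             # True
print(torch.equal(via_einsum, via_loops))                 # True
```

The loop version is only there to make the index bookkeeping visible; for real work, einsum or matmul is far faster.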
2) Batch Matrix Multiplication
Previously: torch.bmm(batch_tensor_1, batch_tensor_2)
einsum: torch.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2)
# input batch tensors to work with
In [13]: batch_tensor_1 = torch.arange(2 * 4 * 3).reshape(2, 4, 3)
In [14]: batch_tensor_2 = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
In [15]: torch.bmm(batch_tensor_1, batch_tensor_2)
Out[15]:
tensor([[[  20,   23,   26,   29],
         [  56,   68,   80,   92],
         [  92,  113,  134,  155],
         [ 128,  158,  188,  218]],

        [[ 632,  671,  710,  749],
         [ 776,  824,  872,  920],
         [ 920,  977, 1034, 1091],
         [1064, 1130, 1196, 1262]]])
# sanity check with the shapes
In [16]: torch.bmm(batch_tensor_1, batch_tensor_2).shape
Out[16]: torch.Size([2, 4, 4])
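The session above only runs torch.bmm. A short sketch, reusing the same batch_tensor_1 and batch_tensor_2, confirms that the einsum form gives the identical result. The per-batch mm loop (via_loop) is an illustrative addition, not part of the original notes, showing that the batch index b is simply carried through while j is summed.

```python
import torch

# Same inputs as the session above: shapes (2, 4, 3) and (2, 3, 4)
batch_tensor_1 = torch.arange(2 * 4 * 3).reshape(2, 4, 3)
batch_tensor_2 = torch.arange(2 * 3 * 4).reshape(2, 3, 4)

# "bij, bjk -> bik": b stays in the output (batch index), j is summed over
via_einsum = torch.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2)

# Equivalent view: one ordinary matrix product per batch element, stacked back
# (iterating a tensor yields its slices along dim 0)
via_loop = torch.stack([a.mm(b) for a, b in zip(batch_tensor_1, batch_tensor_2)])

print(via_einsum.shape)  # torch.Size([2, 4, 4])
print(torch.equal(via_einsum, torch.bmm(batch_tensor_1, batch_tensor_2)))  # True
print(torch.equal(via_einsum, via_loop))  # True
print(via_einsum[1, 3])  # tensor([1064, 1130, 1196, 1262])
```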