본문 바로가기

code log/PyTorch log

PyTorch: matmul, mm, bmm

torch.matmul

vector 및 matrix 간의 다양한 곱을 수행한다. 

broadcast 기능을 제공하며 가장 일반적으로 사용되나, broadcast 기능이 도리어 debug point가 될 수 있다.

broadcast 기능은 아래의 예제와 같이 T1(10, 3, 4) T2(4)을 곱할 때, 맨 앞의 dim이 3개 일 때는 첫 dim을 batch로 간주하고 T1 (3, 4) tensor의 10개의 batch와 각각 T2(4)랑 곱을 해주는 것이다. 

torch.matmul(input, other, *, out=None) → Tensor

torch.mm

torch.matmul과 차이점은 broadcast가 안 된다는 점이다.

즉 mm은 정확하게 matrix 곱의 사이즈가 맞아야 사용이 가능하다.

따라서 내가 작성한 코드가 의도대로 작동하는 지 확인을 위해서 mm의 사용이 적절하다는 생각이 든다. (debug point)

 

torch.mm(input, mat2, *, out=None) → Tensor

input의 size: (n x m) 

mat2의 size: (m x p)

output의 size: (n x p)

 

torch.bmm

torch.matmul과 차이점은 broadcast가 안 된다는 점이다.

즉 mm은 정확하게 matrix 곱의 사이즈가 맞아야 사용이 가능하다.

따라서 내가 작성한 코드가 의도대로 작동하는 지 확인을 위해서 mm의 사용이 적절하다는 생각이 든다. (debug point)

 

 

 

 

'code log > PyTorch log' 카테고리의 다른 글

PyTorch: gather  (0) 2021.02.04
PyTorch: cumsum  (0) 2020.12.03