PyTorch: gather
def _gather_feat(feat, ind, mask=None): dim = feat.size(2) ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) feat = feat.gather(1, ind) if mask is not None: mask = mask.unsqueeze(2).expand_as(feat) feat = feat[mask] feat = feat.view(-1, dim) return feat torch.gather torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor dim에 해당하는 axis에서 해당하는 index의 value만을 gather하는 함..
PyTorch: matmul, mm, bmm
torch.matmul vector 및 matrix 간의 다양한 곱을 수행한다. broadcast 기능을 제공하며 가장 일반적으로 사용되나, broadcast 기능이 도리어 debug point가 될 수 있다. broadcast 기능은 아래의 예제와 같이 T1(10, 3, 4) T2(4)을 곱할 때, 맨 앞의 dim이 3개 일 때는 첫 dim을 batch로 간주하고 T1 (3, 4) tensor의 10개의 batch와 각각 T2(4)랑 곱을 해주는 것이다. torch.matmul(input, other, *, out=None) → Tensor torch.mm torch.matmul과 차이점은 broadcast가 안 된다는 점이다. 즉 mm은 정확하게 matrix 곱의 사이즈가 맞아야 사용이 가능하다. 따..