본문 바로가기

code log/PyTorch log

PyTorch: gather

def _gather_feat(feat, ind, mask=None):
    dim  = feat.size(2)
    ind  = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
    feat = feat.gather(1, ind)
    if mask is not None:
        mask = mask.unsqueeze(2).expand_as(feat)
        feat = feat[mask]
        feat = feat.view(-1, dim)
    return feat

torch.gather

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

 

dim에 해당하는 axis에서 해당하는 index의 value만을 gather하는 함수

3차원 tensor output을 예시로 들면,

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

 

Example:

>>> t = torch.tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
tensor([[ 1,  1],
        [ 4,  3]])

dim = 1이고, index는 [[0,0], [1,0]] 이다.

 

z = index [ i, j ]

0 = index [0,0]

0 = index [0,1]

1 = index [1,0]

0 = index [1,1]

 

result [i , j] = t [i ,z]

result [0 , 0] = t [0 ,0]

result [0 , 1] = t [0 ,0]

result [1 , 0] = t [1 ,1]

result [1 , 1] = t [1 ,0]

 

따라서 최종적인 결과는 result = [[1, 1], [4, 3]] 이 된다. 

 

 

실제 code에서 사용한 예시로는 FariMOT: utils.py 코드 

def _gather_feat(feat, ind, mask=None):
    dim  = feat.size(2)
    ind  = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
    feat = feat.gather(1, ind)
    if mask is not None:
        mask = mask.unsqueeze(2).expand_as(feat)
        feat = feat[mask]
        feat = feat.view(-1, dim)
    return feat

nms를 하기 위해서 sorting한 후에 ind 500개를 만들어 놓고, feat에서 500개만을 선택해서 가져오는 것이다. 

ind에는 뽑아야 하는 feat의 idx가 있기 때문에 이에 해당하는 값만을 feat에서 가져온다. 

 

 

 

 

참고자료:

 

잘 설명해주신 블로그입니다. 

velog.io/@nawnoes/torch.gather%EB%9E%80

 

torch.gather란

간혹 깃헙을 보다보면 torch gether를 볼수가 있는데 어떻게 동작하는 건지 잘 이해되지 않아 정리해본다.input 텐서가 입력으로 주어지고, 차원 dim을 따라서 각 행으로부터 값을 취해, 새로운 텐서

velog.io

 

pytorch.org/docs/stable/generated/torch.gather.html

 

torch.gather — PyTorch 1.7.0 documentation

Shortcuts

pytorch.org

 

'code log > PyTorch log' 카테고리의 다른 글

PyTorch: cumsum  (0) 2020.12.03
PyTorch: matmul, mm, bmm  (0) 2020.12.03