def _gather_feat(feat, ind, mask=None):
dim = feat.size(2)
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
feat = feat.gather(1, ind)
if mask is not None:
mask = mask.unsqueeze(2).expand_as(feat)
feat = feat[mask]
feat = feat.view(-1, dim)
return feat
torch.gather
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
dim에 해당하는 axis에서 해당하는 index의 value만을 gather하는 함수
3차원 tensor output을 예시로 들면,
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
Example:
>>> t = torch.tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
tensor([[ 1, 1],
[ 4, 3]])
dim = 1이고, index는 [[0,0], [1,0]] 이다.
z = index [ i, j ]
0 = index [0,0]
0 = index [0,1]
1 = index [1,0]
0 = index [1,1]
result [i , j] = t [i ,z]
result [0 , 0] = t [0 ,0]
result [0 , 1] = t [0 ,0]
result [1 , 0] = t [1 ,1]
result [1 , 1] = t [1 ,0]
따라서 최종적인 결과는 result = [[1, 1], [4, 3]] 이 된다.
실제 code에서 사용한 예시로는 FariMOT: utils.py 코드
def _gather_feat(feat, ind, mask=None):
dim = feat.size(2)
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
feat = feat.gather(1, ind)
if mask is not None:
mask = mask.unsqueeze(2).expand_as(feat)
feat = feat[mask]
feat = feat.view(-1, dim)
return feat
nms를 하기 위해서 sorting한 후에 ind 500개를 만들어 놓고, feat에서 500개만을 선택해서 가져오는 것이다.
ind에는 뽑아야 하는 feat의 idx가 있기 때문에 이에 해당하는 값만을 feat에서 가져온다.
참고자료:
잘 설명해주신 블로그입니다.
velog.io/@nawnoes/torch.gather%EB%9E%80
pytorch.org/docs/stable/generated/torch.gather.html
'code log > PyTorch log' 카테고리의 다른 글
PyTorch: cumsum (0) | 2020.12.03 |
---|---|
PyTorch: matmul, mm, bmm (0) | 2020.12.03 |