본문 바로가기

code log/PyTorch log

(3)
PyTorch: gather def _gather_feat(feat, ind, mask=None): dim = feat.size(2) ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) feat = feat.gather(1, ind) if mask is not None: mask = mask.unsqueeze(2).expand_as(feat) feat = feat[mask] feat = feat.view(-1, dim) return feat torch.gather torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor dim에 해당하는 axis에서 해당하는 index의 value만을 gather하는 함..
PyTorch: cumsum torch.cumsum() torch.cumsum(input, dim, *, dtype=None, out=None) → Tensor dim=N의 방향으로 누적합을 구하는 함수 만약 input이 N size의 백터라면, 결과는 N size의 같은 크기의 백터를 뱉는다. 실제 사용 예: semantic segmentation처럼 dense label을 만들 때, ignore vlaue를 제외한 mask를 만들거나 foreground mask를 생성할 때 binary mask를 이용하여 원하는 값을 걸러낸다. 보통의 label은 1 channel인데, class 개수만큼의 channel를 만들고, 그 위치에 해당하는 label 값을 넣고 싶을 때 아래와 같이 할 수 있다. label_expand = torch..
PyTorch: matmul, mm, bmm torch.matmul vector 및 matrix 간의 다양한 곱을 수행한다. broadcast 기능을 제공하며 가장 일반적으로 사용되나, broadcast 기능이 도리어 debug point가 될 수 있다. broadcast 기능은 아래의 예제와 같이 T1(10, 3, 4) T2(4)을 곱할 때, 맨 앞의 dim이 3개 일 때는 첫 dim을 batch로 간주하고 T1 (3, 4) tensor의 10개의 batch와 각각 T2(4)랑 곱을 해주는 것이다. torch.matmul(input, other, *, out=None) → Tensor torch.mm torch.matmul과 차이점은 broadcast가 안 된다는 점이다. 즉 mm은 정확하게 matrix 곱의 사이즈가 맞아야 사용이 가능하다. 따..