torch.cumsum()
torch.cumsum(input, dim, *, dtype=None, out=None) → Tensor
dim=N의 방향으로 누적합을 구하는 함수
만약 input이 N size의 백터라면, 결과는 N size의 같은 크기의 백터를 뱉는다.
실제 사용 예:
semantic segmentation처럼 dense label을 만들 때, ignore vlaue를 제외한 mask를 만들거나 foreground mask를 생성할 때
binary mask를 이용하여 원하는 값을 걸러낸다.
보통의 label은 1 channel인데, class 개수만큼의 channel를 만들고, 그 위치에 해당하는 label 값을 넣고 싶을 때
아래와 같이 할 수 있다.
label_expand = torch.unsqueeze(label, 1).repeat(1, int(num_class), 1, 1) # broadcast label to # of class
labels = label_expand.clone()
labels[labels != ignore_label] = 1.0
labels[labels == ignore_label] = 0.0
labels_valid = labels.clone()
labels_invalid = (1.0 - labels.clone())
labels = torch.cumsum(labels, dim=1)
labels[labels != label_expand + 1] = 0.0
'code log > PyTorch log' 카테고리의 다른 글
PyTorch: gather (0) | 2021.02.04 |
---|---|
PyTorch: matmul, mm, bmm (0) | 2020.12.03 |