본문 바로가기

code log/PyTorch log

PyTorch: cumsum

torch.cumsum()

torch.cumsum(input, dim, *, dtype=None, out=None) → Tensor

dim=N의 방향으로 누적합을 구하는 함수

 

만약 input이 N size의 백터라면, 결과는 N size의 같은 크기의 백터를 뱉는다. 

 

 

 

실제 사용 예:

semantic segmentation처럼 dense label을 만들 때, ignore vlaue를 제외한 mask를 만들거나 foreground mask를 생성할 때

binary mask를 이용하여 원하는 값을 걸러낸다. 

 

보통의 label은 1 channel인데, class 개수만큼의 channel를 만들고, 그 위치에 해당하는 label 값을 넣고 싶을 때

아래와 같이 할 수 있다. 

label_expand = torch.unsqueeze(label, 1).repeat(1, int(num_class), 1, 1) # broadcast label to # of class
labels = label_expand.clone()
labels[labels != ignore_label] = 1.0
labels[labels == ignore_label] = 0.0
labels_valid = labels.clone()
labels_invalid = (1.0 - labels.clone())
labels = torch.cumsum(labels, dim=1)
labels[labels != label_expand + 1] = 0.0

 

'code log > PyTorch log' 카테고리의 다른 글

PyTorch: gather  (0) 2021.02.04
PyTorch: matmul, mm, bmm  (0) 2020.12.03