1. index_select (파이토치 인덱싱)
파이토치에서 인덱싱은 기본적으로 index_select 함수를 사용한다.
💡 참고로 pytorch에서의 dim은 numpy에서의 axis로 생각하면 된다.
- torch.index_select(input, dim, *, out = None) → Tensor
- Parameters
- input ( Tensor ) - 입력값
- dim ( int ) - 인덱스 할 기준 차원
- index ( IntTensor or LongTensor ) - 인덱싱할 인덱스를 포함하는 1차원 텐서
- Keyword Arguments
- out ( Tensor, optional ) - 출력 텐서
- Parameters
x = torch.randn(3, 4)
x
'''
tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],
[-0.4664, 0.2647, -0.1228, -1.1068],
[-1.1734, -0.6571, 0.7230, -0.6004]])
'''
indices = torch.tensor([0, 2])
torch.index_select(x, 0, indices)
'''
tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],
[-1.1734, -0.6571, 0.7230, -0.6004]])
'''
torch.index_select(x, 1, indices)
'''
tensor([[ 0.1427, -0.5414],
[-0.4664, -0.1228],
[-1.1734, 0.7230]])
'''
dim이 0이면 행(dim=0)에 대한 인덱싱을 진행하고, dim이 1이면 열(dim=1)에 대한 인덱싱을 진행한다.
그렇다면 대각선 요소와 같이 각 행 혹은 열 별로 인덱싱 할 값이 다른 경우는 어떡할까?
바로 torch.gather을 사용해서 이를 해결할 수 있는데, 고차원의 경우 torch.gather 및 뒤에 나올 torch.scatter 개념이 헷갈릴 수 있으니 잘 이해해보자.
2. torch.gather
- torch.gather(input, dim, index, *, sparse_grad = False, out = None) → Tensor
- Parameters
- input ( Tensor ) - the source tensor
- dim ( int ) - the axis along which to index (인덱싱할 기준 차원)
- index ( LongTensor ) - the indices of elements to gather
- Keyword Arguments
- sparse_grad ( bool, optional ) - If True, gradient w.r.t. input will be a sparse tensor
- out ( Tensor, optional ) - the destination tensor
- Parameters
t = torch.tensor([[1, 2], [3, 4]])
torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
'''
tensor([[ 1, 1],
[ 4, 3]])
'''
위 코드를 예시로 설명하면, t라는 source tensor에 대해 dim = 1 기준(열 기준)으로 [0, 0] 인덱스에 해당하는 [1, 1] 와 [1, 0]에 해당하는 [4, 3]를 출력한다.
여기서, 열 기준이라는 것은 첫 번째 행 기준 열 방향으로 아래로 내려간다는 것을 의미한다.
2차원 예시는 비교적 간단하여 쉽게 이해할 수 있지만, 3차원 이상의 경우 조금 헷갈릴 수 있으므로 조금 더 자세히 봐보자.
공식 document에 따르면 3-D tensor에 대한 output은 아래와 같이 정의된다.
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
dim이 0, 1, 2임에 따라 input의 index가 바뀌는 것이다.
A = torch.Tensor([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]])
2 x 2 x 2 크기의 A를 그림으로 표시하면 아래와 같고, 우리가 인덱싱하고 싶은 값은 빨간색 동그라미 친 1, 4, 5, 8 이다

우선 앞에 있는 2x2와 뒤에 있는 2x2 모두 인덱싱 할 것이 있으므로 dim = 2를 입력해야 할 것이다.
또한 각각의 2x2 matrix는 최대 인덱스가 2 이므로 [[0]],[1]]을 두번 적어서 [[[0]],[1]] (앞), [[0]],[1]] (뒤)] 이런식으로 인덱스를 주면 될 것이다. 따라서 정답 코드는 다음과 같다.
output = torch.gather(A,2,torch.tensor([[[0],[1]],[[0],[1]]])).view(2,2)
'''
view(2,2)를 사용하지 않으면
tensor([[[1.],
[4.]],
[[5.],
[8.]]])
위와같은 2x2x1 shape이 된다.
따라서 view(2,2)를 통해
tensor([[1., 4.],
[5., 8.]])
이런식으로 출력해준다.
'''
이를 확장시켜 임의의 크기의 3D tensor에서 대각선 요소를 가져오려면 아래와 같이 할 수 있다.
import torch
# TODO : 임의의 크기의 3D tensor에서 대각선 요소 가져와 2D로 반환하는 함수를 만드세요!
def get_diag_element_3D(A):
i = min(A.shape[1],A.shape[2])
tsr = [[[i] for i in range(i)]] * A.shape[0]
tsr = torch.tensor(tsr)
output = torch.gather(A,2,tsr).view(tsr.shape[0],tsr.shape[1])
return output
C = 2
H = 3
W = 3
# 아래 코드는 수정하실 필요가 없습니다!
A = torch.tensor([i for i in range(1, C*H*W + 1)])
A = A.view(C, H, W)
print(f"원본 3D 행렬\n{A}")
print("-" * 30)
print(f"대각선 요소를 모은 2D 행렬")
get_diag_element_3D(A)
'''
원본 3D 행렬
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],
[[10, 11, 12],
[13, 14, 15],
[16, 17, 18]]])
------------------------------
대각선 요소를 모은 2D 행렬
tensor([[ 1, 5, 9],
[10, 14, 18]])
'''
약간의 설명을 추가하면 get_diag_element_3D 함수에서 tsr이라는 index리스트의 i번째 요소는 [i]가 될 것이다.
여기서 i는 행,열 중 작은 값을 나타내고 예를 들어 n x 3 x 4 이라 하면 i는 3이되고, tsr = [[0], [1], [2]] 이 된다.
이 tsr을 n개 만들어 모든 대각선 요소를 인덱싱 할 수 있는 것이다.
3. torch.Tensor.scatter_
- torch.Tensor.scatter_ (dim, index, src, reduce = None) → Tensor
- Parameters
- dim (int) – the axis along which to index
- index (LongTensor) – the indices of elements to scatter, can be either empty or of the same dimensionality as src . When empty, the operation returns self unchanged.
- src (Tensor or float) – the source element(s) to scatter.
- reduce (str, optional) – reduction operation to apply, can be either 'add' or 'multiply'.
print(torch.zeros(3,5))
'''
tensor([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]])
'''
src = torch.arange(1, 11).reshape((2, 5))
src
'''
tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10]])
'''
위의 src를 참고해서 코드를 짜보면
index = torch.tensor([[0, 1, 2, 0]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
'''
tensor([[1, 0, 0, 4, 0],
[0, 2, 0, 0, 0],
[0, 0, 3, 0, 0]])
'''
이런식으로 특정 인덱스의 값을 바꿔 줄 수 있다. index = [0, 1, 2, 0]으로 되어있는데, scatter_의 dim 값에 따라 적용이 달라진다.
scatter_ 에서도 마찬가지로 처음에 들어가는 dim이 0 이기 때문에 gather에서와 같은 맥락으로 행 방향로 살펴가며
첫 번째 행 → index : 0 => [1] 을 0번째 인덱스에 대입
두 번째 행 → index : 1 => [2] 을 1번째 인덱스에 대입
세 번째 행 → index : 2 => [3] 을 2번째 인덱스에 대입
네 번째 행(다시 첫번째 행) → index : 0 => [4] 을 0번째 인덱스에 대입
(index의 요소가 총 4개이므로 네번째 열 까지만 살펴본다.)
이런식으로 값을 바꿔준다.
조금 더 어려운 케이스로,
index = torch.tensor([[0, 1, 2], [0, 1, 4]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
'''
tensor([[1, 2, 3, 0, 0],
[6, 7, 0, 0, 8],
[0, 0, 0, 0, 0]])
'''
위와 같은 경우에는 dim이 1이기 때문에 열 방향로 살펴가며
첫 번째 열 → index : [0, 1, 2] => [1, 2, 3]을 0, 1, 2번째 인덱스에 대입
두 번째 열 → index : [0, 1, 4] => [6, 7, 8]을 0, 1, 4번째 인덱스에 대입
(index의 요소가 총 2개이므로 두번째 행 까지만 살펴본다.)
4. 마무리하며
pytorch에서 가장 중요한 것은 결국 차원인 것 같다. dim = 0 , dim = 1, dim = 2가 대부분의 함수에서 비슷한 맥락으로 사용되는 것 같으니 한번 잘 이해해두면 나중에 다른 함수나 모듈을 이해 할 때도 도움이 많이 될 것이다.
'AI Theory > DL Basic' 카테고리의 다른 글
[Pytorch] hook은 왜 있는 것일까? (1) | 2022.09.29 |
---|---|
[Pytorch] torch.nn.Module에는 어떤 method가 있을까? (0) | 2022.09.29 |
딥러닝을 위한 경사하강법(Gradient Descent) (1) | 2022.09.23 |
Inception v1,v2,v3,v4는 무엇이 다른가 (+ CNN의 역사) (1) | 2022.03.30 |
[딥러닝기초] 머신러닝 흐름 파악하기 (0) | 2022.03.16 |
댓글