0. 들어가기에 앞서
Hook을 이용하여 기존 Class 바탕의 코드를 작성할 필요 없이 상태 값과 여러 React의 기능을 사용할 수 있습니다. - reactjs
Hook(후킹)은 소프트웨어 공학 용어로, 운영 체제나 응용 소프트웨어 등의 각종 컴퓨터 프로그램에서 소프트웨어 구성 요소 간에 발생하는 함수 호출, 메시지, 이벤트 등을 중간에서 바꾸거나 가로채는 명령, 방법, 기술이나 행위를 말한다.
이렇게 말하면 조금 어려울 수 있는데, 쉽게 말해서
- 프로그램의 실행 로직을 분석하거나
- 프로그램에 추가적인 기능을 제공하고 싶을 때
사용되는 것 것이 hook이라고 생각하면 된다.
pytorch 뿐만 아니라 react,C# 등에서도 자주 사용되는 방식으로 이번 포스팅에서는 hook이 무엇이고, pytorch에서는 어떤 방식으로 hook이 사용되는지에 대해 설명하겠다.
1. hook
아래와 같은 코드를 살펴보자.
def program_A(x):
print('program A processing!')
return x + 3
def program_B(x):
print('program B processing!')
return x - 3
class Package(object):
"""프로그램 A와 B를 묶어놓은 패키지 코드"""
def __init__(self):
self.programs = [program_A, program_B]
def __call__(self, x):
for program in self.programs:
x = program(x)
return x
이런 코드 중간에 x + 5를 수행하는 program_C라는 새로운 함수를 적용하고 싶다고 생각해보자.
우선 program_C를 만든 후 Package 클래스 안의 코드를 살피며 self.programs에 [program_A, program_B, program_C]로 수정해야 할 것이다.
그러나 대부분의 모듈에서는 이렇게 실제 패키지를 수정하기 어려운 경우도 많고, 설령 가능하다고 하더라도 그 과정이 되게 복잡한 경우가 많을 것이다.
hook은 이처럼 외부에서 만든 모듈이나 함수 등의 custom 코드를 패키지 내부에서 손쉽게 사용할 수 있게 한다.
hook 기능을 추가한 프로그램 코드는 아래와 같다.
def program_A(x):
print('program A processing!')
return x + 3
def program_B(x):
print('program B processing!')
return x - 3
class Package(object):
"""프로그램 A와 B를 묶어놓은 패키지 코드"""
def __init__(self):
self.programs = [program_A, program_B]
# hooks
self.pre_hooks = []
self.hooks = []
def __call__(self, x):
for program in self.programs:
# pre_hook
if self.pre_hooks:
for hook in self.pre_hooks:
output = hook(x)
if output:
x = output
# 실제 사용하는 프로그램 코드
x = program(x)
# hook
if self.hooks:
for hook in self.hooks:
output = hook(x)
if output:
x = output
return x
여기서 한가지 중요한 점은 실제 사용하는 프로그램 코드인 x = program(x)의 전, 후로 각각 pre_hook과 hook이 있다는 것이다.
pre_hook은 프로그램 실행 전에 hook을 사용할 수 있게하고, hook은 프로그램 실행 이후에만 사용 가능하다는 차이점이 있다.
2. hook in pytorch
pytorch에는 두가지 종류의 hook이 있다.
- Tensor에 적용하는 hook
- Module에 적용하는 hook
2-1. Tensor hook
우선 Tensor에 적용하는 hook은 "_backward_hooks" 밖에 없다.
"_backward_hooks"는 Tensor의 그래디언트가 계산될 때 마다 호출된다.
사용 코드 예시는 아래와 같다.
v = torch.tensor([0., 0., 0.], requires_grad=True)
h = v.register_hook(lambda grad: grad * 2) # double the gradient
v.backward(torch.tensor([1., 2., 3.]))
v.grad
'''
2
4
6
[torch.FloatTensor of size (3,)]
'''
h.remove() # removes the hook
register_hook을 통해 사용하고자 하는 hook을 등록하고(여기선 grad에 2를 곱해주는 함수를 등록하였다), v.backward를 실행시키면 input으로 받은 [1., 2., 3.]에 대하여 grad = grad x 2가 적용되고 이를 v.grad로 확인해보면 2, 4, 6이 출력된다.
그 후 h.remove()를 통해 해당 hook을 지울 수 있다.
2-2. Module hook
nn.Module에 등록하는 모든 hook은 "__dict__"를 이용해 한번에 확인 할 수 있다.
nn.Module에 있는 hook 및 method들의 설명은 여기에 정리해놨다.
[Pytorch] torch.nn.Module에는 어떤 method가 있을까?
torch.nn.Module 은 여러 기능을 모아두는 상자 역할을 한다. 하나의 nn.Module 상자가 여러 pytorch의 기능을 모아둘수도 있고, 다른 상자를 포함할수도 있다. 공식 Document 에 올라와있는 nn.Module의 여러
hyunsooworld.tistory.com
nn.Module에는 총 3개의 hook이 있는데 아래와 같다.
- register_full_backward_hook( hook ) (=>register_backward_hook( )에서 바뀜)
- 모듈에 역방향 hook를 등록
- register_full_backward_hook - PyTorch 공식 문서
- register_forward_hook( hook )
- 모듈에 정방향 후크를 등록
- forward( )에는 영향을 미치지 않고, 내부 입력은 수정 할 수 있다
- register_forward_hook - PyTorch 공식 문서
- register_forward_pre_hook( hook )
- 모듈에 정방향 프리훅을 등록
- register_forward_pre_hook - PyTorch 공식 문서
💡이외에도 state_dict_hooks도 있긴 하지만 이는 "load_state_dict" 함수에서 내부적으로 사용되는 것이다.
2-2-1. foward_hook
nn.Module에서의 forward_hook 사용 예제는 아래와 같다.
import torch
from torch import nn
class Add(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x1, x2):
output = torch.add(x1, x2)
return output
# 모델 생성
add = Add()
answer = []
# pre_hook를 이용하면 x1, x2 값을 알아낼 수 있다.
def pre_hook(module, input):
answer.append(input[0])
answer.append(input[1])
pass
add.register_forward_pre_hook(pre_hook)
# hook를 이용하면 output 값을 알아낼 수 있다.
def hook(module, input, output):
answer.append(output[0])
pass
add.register_forward_hook(hook)
# 아래 코드는 수정하실 필요가 없습니다!
x1 = torch.rand(1)
x2 = torch.rand(1)
output = add(x1, x2)
if answer == [x1, x2, output]:
print("🎉🎉🎉 성공!!! 🎉🎉🎉")
else:
print("🦆 다시 도전해봐요!")
2-2-2. backward_hook
backward_hook의 예시는 아래와 같다.
backward_hook을 사용하면 역전파 과정에서 뒤로 전파되는 gradient 값을 알아낼 수 있다.
import torch
from torch import nn
from torch.nn.parameter import Parameter
class Model(nn.Module):
def __init__(self):
super().__init__()
self.W = Parameter(torch.Tensor([5]))
def forward(self, x1, x2):
output = x1 * x2
output = output * self.W
return output
# 모델 생성
model = Model()
answer = []
# hook를 이용해서 x1.grad, x2.grad, output.grad 값을 알아낼 수 있다.
def module_hook(module, grad_input, grad_output):
answer.append(grad_input[0])
answer.append(grad_input[1])
answer.append(grad_output[0])
# return answer => 절대 return answer을 하지 말자!!
pass
model.register_full_backward_hook(module_hook)
# 아래 코드는 수정하실 필요가 없습니다!
x1 = torch.rand(1, requires_grad=True)
x2 = torch.rand(1, requires_grad=True)
output = model(x1, x2)
output.retain_grad()
output.backward()
# model.W.register_full_backward_hook
print(answer)
if answer == [x1.grad, x2.grad, output.grad]:
print("🎉🎉🎉 성공!!! 🎉🎉🎉")
else:
print("🦆 다시 도전해봐요!")
3. 마무리
이번 포스팅에서는 hook에 대해 알아보았다.
hook은 pytorch뿐만 아니라 다양한 곳에서 사용되기 때문에 개념을 이해하고 있으면 좋을 것 같다.
특히 pytorch(딥러닝)에서는 역전파 과정에서 전파되는 값들을 원하는 방식으로 수정할 수 있다는 점에서 되게 유용하다고 생각했다.
hook을 적절하게 활용한다면
- gradient값의 변화를 시각화
- gradient값이 특정 임계값을 넘으면 gradient exploding 경고 알림
- 특정 tensor의 gradient값이 너무 커지거나 작아지는 현상이 관측되면 해당 tensor 한정으로 gradient clipping
등을 할 수 있을 것이다.
딥러닝에서의 hook 사용 예시는 아래 링크를 참고하면 조금 더 도움이 될 것이다.
- How to Use PyTorch Hooks - Medium
- PyTorch 101, Part 5: Understanding Hooks - Paperspace blog
- PyTorch Hooks Explained - In-depth Tutorial - YouTube
'AI Theory > DL Basic' 카테고리의 다른 글
HuggingFace custom dataset 저장하는 방법 (1) | 2023.06.18 |
---|---|
[Pytorch] torch.nn.Module에는 어떤 method가 있을까? (0) | 2022.09.29 |
[Pytorch] torch.gather vs torch.Tensor.scatter_ (pytorch 인덱싱 방법) (0) | 2022.09.29 |
딥러닝을 위한 경사하강법(Gradient Descent) (1) | 2022.09.23 |
Inception v1,v2,v3,v4는 무엇이 다른가 (+ CNN의 역사) (1) | 2022.03.30 |
댓글