본문 바로가기
AI Theory/DL Basic

[Pytorch] hook은 왜 있는 것일까?

by climba 2022. 9. 29.

0. 들어가기에 앞서

Hook을 이용하여 기존 Class 바탕의 코드를 작성할 필요 없이 상태 값과 여러 React의 기능을 사용할 수 있습니다. - reactjs

Hook(후킹)은 소프트웨어 공학 용어로, 운영 체제나 응용 소프트웨어 등의 각종 컴퓨터 프로그램에서 소프트웨어 구성 요소 간에 발생하는 함수 호출, 메시지, 이벤트 등을 중간에서 바꾸거나 가로채는 명령, 방법, 기술이나 행위를 말한다.

(위키백과)

 

이렇게 말하면 조금 어려울 수 있는데, 쉽게 말해서 

  • 프로그램의 실행 로직을 분석하거나
  • 프로그램에 추가적인 기능을 제공하고 싶을 때

사용되는 것 것이 hook이라고 생각하면 된다.

 

pytorch 뿐만 아니라 react,C# 등에서도 자주 사용되는 방식으로 이번 포스팅에서는 hook이 무엇이고, pytorch에서는 어떤 방식으로 hook이 사용되는지에 대해 설명하겠다.

 

1. hook

아래와 같은 코드를 살펴보자.

def program_A(x):
    print('program A processing!')
    return x + 3

def program_B(x):
    print('program B processing!')
    return x - 3

class Package(object):
    """프로그램 A와 B를 묶어놓은 패키지 코드"""
    def __init__(self):
        self.programs = [program_A, program_B]

    def __call__(self, x):
        for program in self.programs:
            x = program(x)

        return x

이런 코드 중간에 x + 5를 수행하는 program_C라는 새로운 함수를 적용하고 싶다고 생각해보자.

우선 program_C를 만든 후 Package 클래스 안의 코드를 살피며 self.programs에 [program_A, program_B, program_C]로 수정해야 할 것이다.

 

그러나 대부분의 모듈에서는 이렇게 실제 패키지를 수정하기 어려운 경우도 많고, 설령 가능하다고 하더라도 그 과정이 되게 복잡한 경우가 많을 것이다.

 

hook은 이처럼 외부에서 만든 모듈이나 함수 등의 custom 코드를 패키지 내부에서 손쉽게 사용할 수 있게 한다.

hook 기능을 추가한 프로그램 코드는 아래와 같다.

def program_A(x):
    print('program A processing!')
    return x + 3

def program_B(x):
    print('program B processing!')
    return x - 3

class Package(object):
    """프로그램 A와 B를 묶어놓은 패키지 코드"""
    def __init__(self):
        self.programs = [program_A, program_B]

        # hooks
        self.pre_hooks = []
        self.hooks = []

    def __call__(self, x):
        for program in self.programs:
            
            # pre_hook
            if self.pre_hooks:
                for hook in self.pre_hooks:
                    output = hook(x)
                    if output:
                        x = output
                        
            # 실제 사용하는 프로그램 코드
            x = program(x)
            

            # hook
            if self.hooks:
                for hook in self.hooks:
                    output = hook(x)
                    if output:
                        x = output

        return x

여기서 한가지 중요한 점은 실제 사용하는 프로그램 코드인 x = program(x)의 전, 후로 각각 pre_hook과 hook이 있다는 것이다.

pre_hook은 프로그램 실행 전에 hook을 사용할 수 있게하고, hook은 프로그램 실행 이후에만 사용 가능하다는 차이점이 있다.

 

2. hook in pytorch

pytorch에는 두가지 종류의 hook이 있다.

  • Tensor에 적용하는 hook
  • Module에 적용하는 hook

2-1. Tensor hook

우선 Tensor에 적용하는 hook은 "_backward_hooks" 밖에 없다.

"_backward_hooks"는 Tensor의 그래디언트가 계산될 때 마다 호출된다.

사용 코드 예시는 아래와 같다.

v = torch.tensor([0., 0., 0.], requires_grad=True)
h = v.register_hook(lambda grad: grad * 2)  # double the gradient
v.backward(torch.tensor([1., 2., 3.]))
v.grad
'''
 2
 4
 6
[torch.FloatTensor of size (3,)]
'''
h.remove()  # removes the hook

register_hook을 통해 사용하고자 하는 hook을 등록하고(여기선 grad에 2를 곱해주는 함수를 등록하였다), v.backward를 실행시키면 input으로 받은 [1., 2., 3.]에 대하여 grad = grad x 2가 적용되고 이를 v.grad로 확인해보면  2, 4, 6이 출력된다.

그 후 h.remove()를 통해 해당 hook을 지울 수 있다.

2-2. Module hook

nn.Module에 등록하는 모든 hook은 "__dict__"를 이용해 한번에 확인 할 수 있다.

nn.Module에 있는 hook 및 method들의 설명은 여기에 정리해놨다.

 

[Pytorch] torch.nn.Module에는 어떤 method가 있을까?

torch.nn.Module 은 여러 기능을 모아두는 상자 역할을 한다. 하나의 nn.Module 상자가 여러 pytorch의 기능을 모아둘수도 있고, 다른 상자를 포함할수도 있다. 공식 Document 에 올라와있는 nn.Module의 여러

hyunsooworld.tistory.com

 

nn.Module에는 총 3개의 hook이 있는데 아래와 같다.

💡이외에도 state_dict_hooks도 있긴 하지만 이는 "load_state_dict" 함수에서 내부적으로 사용되는 것이다.

2-2-1. foward_hook

nn.Module에서의 forward_hook 사용 예제는 아래와 같다.

import torch
from torch import nn

class Add(nn.Module):
    def __init__(self):
        super().__init__() 

    def forward(self, x1, x2):
        output = torch.add(x1, x2)

        return output

# 모델 생성
add = Add()

answer = []


# pre_hook를 이용하면 x1, x2 값을 알아낼 수 있다.
def pre_hook(module, input):
    answer.append(input[0])
    answer.append(input[1])
    pass
    
add.register_forward_pre_hook(pre_hook)

# hook를 이용하면 output 값을 알아낼 수 있다.
def hook(module, input, output):
    answer.append(output[0])
    pass
    
add.register_forward_hook(hook)


# 아래 코드는 수정하실 필요가 없습니다!
x1 = torch.rand(1)
x2 = torch.rand(1)

output = add(x1, x2)

if answer == [x1, x2, output]:
    print("🎉🎉🎉 성공!!! 🎉🎉🎉")
else:
    print("🦆 다시 도전해봐요!")

2-2-2. backward_hook

backward_hook의 예시는 아래와 같다.

backward_hook을 사용하면 역전파 과정에서 뒤로 전파되는 gradient 값을 알아낼 수 있다.

import torch
from torch import nn
from torch.nn.parameter import Parameter

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.W = Parameter(torch.Tensor([5]))

    def forward(self, x1, x2):
        output = x1 * x2
        output = output * self.W

        return output

# 모델 생성
model = Model()

answer = []

# hook를 이용해서 x1.grad, x2.grad, output.grad 값을 알아낼 수 있다.
def module_hook(module, grad_input, grad_output):
    answer.append(grad_input[0])
    answer.append(grad_input[1])
    answer.append(grad_output[0])
    # return answer  => 절대 return answer을 하지 말자!!
    pass

model.register_full_backward_hook(module_hook)

# 아래 코드는 수정하실 필요가 없습니다!
x1 = torch.rand(1, requires_grad=True)
x2 = torch.rand(1, requires_grad=True)

output = model(x1, x2)
output.retain_grad()
output.backward()
# model.W.register_full_backward_hook
print(answer)
if answer == [x1.grad, x2.grad, output.grad]:
    print("🎉🎉🎉 성공!!! 🎉🎉🎉")
else:
    print("🦆 다시 도전해봐요!")

3. 마무리

이번 포스팅에서는 hook에 대해 알아보았다.

hook은 pytorch뿐만 아니라 다양한 곳에서 사용되기 때문에 개념을 이해하고 있으면 좋을 것 같다.

특히 pytorch(딥러닝)에서는 역전파 과정에서 전파되는 값들을 원하는 방식으로 수정할 수 있다는 점에서 되게 유용하다고 생각했다.

 

hook을 적절하게 활용한다면

  • gradient값의 변화를 시각화
  • gradient값이 특정 임계값을 넘으면 gradient exploding 경고 알림
  • 특정 tensor의 gradient값이 너무 커지거나 작아지는 현상이 관측되면 해당 tensor 한정으로 gradient clipping

등을 할 수 있을 것이다.

딥러닝에서의 hook 사용 예시는 아래 링크를 참고하면 조금 더 도움이 될 것이다.

 

댓글