jax에서 변수들을 출력해보면, Traced<ShapedArray(float32[256,28])>with<BatchTrace(level=1/1)> with val = ~ 이런 식으로 나오는 경우가 많다. (BatchTracer Class)
이는 jax가 함수들을 compile한 후 한번에 실행하기 때문에, print 할 수 있는 실제 값이 없고, 그 shape(껍데기)만 출력하기 때문이다.
(자세한 이유는 https://github.com/google/jax/issues/196에 정확하고, 상세하게 나와있으니 참고하면 좋을 것 같다.)
그러나 d3pm 코드를 이해하는 과정에서 변수에 어떤 값이 할당되어 있는지 확인해야하는 경우가 많았기 때문에 debugging 방법이 정말 없는지 열심히 찾아보았다.
최종적으로 발견한 방법은 크게 두 가지가 있었는데, 우리의 코드에서는 두번째 방법으로 해결하였다.
첫번째 방법은 아래와 같이 jax.experimental.host_callback에서 call 함수를 활용하는 것이다.
import jax
import jax.numpy as jnp
from jax.experimental.host_callback import call
@jax.jit
def selu(x, alpha=1.67, lmbda=1.05):
print(x)
call(lambda x: print(f"x: {x}"), x)
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1000000,))
selu(x)
그러나 이 방법은 모든 함수를 @jax.jit 데코레이터로 모든 활용되는 함수를 감싼 후 실행해야하는데, d3pm 같은 논문을 구현한 코드는 사용된 함수들이 너무 많고 복잡하기 때문에 활용하기 쉽지 않았다.
두번째 방법은 비교적 간단한데, 바로 jax.debug.print(“now {}“, 변수)를 활용하는 것이다.
다만 이때 주의해야될 것은 vscode에서 debugging tool을 활용하면 안되고, 그냥 바로 python 파일을 실행해야한다. 아마 위 이유와 비슷한 이유인 것 같은데, 코드를 함수 단위 혹은 클래스 단위로 컴파일해서 활용하는 jax의 특성 상 line by line으로 debugging 하면서 출력이 안되는 것 같다. 또한 꼭 출력 formatting을 "now {}", 변수 형식으로 해야하는지도 아직 잘 모르겠다.
(참고 자료 : https://jax.readthedocs.io/en/latest/debugging/print_breakpoint.html)
여러모로 에러가 많이 뜨는 jax여서 코드 이해부터 쉽지 않지만, 하나씩 에러가 해결되니까 조금은 이해가 되는 것 같다.
블로그 글을 정리하다가 알게 되었는데, jax.debug.breakpoint()를 활용하면 해당 부분에서 모든 변수를 다 출력해볼 수 있어서 훨씬 편리한 것 같다!!
참고로 jdb(jax.debug.breakpoint()에서 command는 다음과 같다.
Debugger commands:
- help - prints out available commands
- p - evaluates an expression and prints its result
- pp - evaluates an expression and pretty-prints its result
- u(p) - go up a stack frame (위로 이동)
- d(own) - go down a stack frame (아래로 이동)
- w(here)/bt - print out a backtrace
- l(ist) - print out code context (현재 어떤 위치에서 debugging 하는것인지 확인하려면 l)
- c(ont(inue)) - resumes the execution of the program (계속 진행하려면 c)
- q(uit)/exit - exits the program (does not work on TPU) (나가려면 q)
댓글