본문 바로가기
기타/오류 뿌수기

[jax 오류] jax에서 debugging하는 방법 (how to print BatchTracer class in jax)

by climba 2023. 3. 25.

x_0, transition_probs, samples 등의 변수에 어떤 값이 할당되어 있는지 값을 보고싶은데 ... 자꾸 shape만 출력된다

jax에서 변수들을 출력해보면, Traced<ShapedArray(float32[256,28])>with<BatchTrace(level=1/1)> with val = ~ 이런 식으로 나오는 경우가 많다. (BatchTracer Class)

 

이는 jax가 함수들을 compile한 후 한번에 실행하기 때문에, print 할 수 있는 실제 값이 없고, 그 shape(껍데기)만 출력하기 때문이다.

(자세한 이유는 https://github.com/google/jax/issues/196에 정확하고, 상세하게 나와있으니 참고하면 좋을 것 같다.)

그러나 d3pm 코드를 이해하는 과정에서 변수에 어떤 값이 할당되어 있는지 확인해야하는 경우가 많았기 때문에 debugging 방법이 정말 없는지 열심히 찾아보았다.

 

최종적으로 발견한 방법은 크게 두 가지가 있었는데, 우리의 코드에서는 두번째 방법으로 해결하였다.

첫번째 방법은 아래와 같이 jax.experimental.host_callback에서 call 함수를 활용하는 것이다.

import jax
import jax.numpy as jnp
from jax.experimental.host_callback import call

@jax.jit
def selu(x, alpha=1.67, lmbda=1.05):
  print(x)
  call(lambda x: print(f"x: {x}"), x)
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1000000,))
selu(x)

그러나 이 방법은 모든 함수를 @jax.jit 데코레이터로 모든 활용되는 함수를 감싼 후 실행해야하는데, d3pm 같은 논문을 구현한 코드는 사용된 함수들이 너무 많고 복잡하기 때문에 활용하기 쉽지 않았다.

 

두번째 방법은 비교적 간단한데, 바로 jax.debug.print(“now {}“, 변수)를 활용하는 것이다.

다만 이때 주의해야될 것은 vscode에서 debugging tool을 활용하면 안되고, 그냥 바로 python 파일을 실행해야한다. 아마 위 이유와 비슷한 이유인 것 같은데, 코드를 함수 단위 혹은 클래스 단위로 컴파일해서 활용하는 jax의 특성 상 line by line으로 debugging 하면서 출력이 안되는 것 같다. 또한 꼭 출력 formatting을 "now {}", 변수 형식으로 해야하는지도 아직 잘 모르겠다.

(참고 자료 : https://jax.readthedocs.io/en/latest/debugging/print_breakpoint.html)

 

jax.debug.print and jax.debug.breakpoint — JAX documentation

jax.debug.print and jax.debug.breakpoint The jax.debug package offers some useful tools for inspecting values inside of JIT-ted functions. Debugging with jax.debug.print and other debugging callbacks TL;DR Use jax.debug.print() to print values to stdout in

jax.readthedocs.io

vscode에 있는 이 debugging tool을 사용하면 안된다!!

여러모로 에러가 많이 뜨는 jax여서 코드 이해부터 쉽지 않지만, 하나씩 에러가 해결되니까 조금은 이해가 되는 것 같다.

 

 

 

 


블로그 글을 정리하다가 알게 되었는데, jax.debug.breakpoint()를 활용하면 해당 부분에서 모든 변수를 다 출력해볼 수 있어서 훨씬 편리한 것 같다!!

참고로 jdb(jax.debug.breakpoint()에서 command는 다음과 같다.

Debugger commands:

  • help - prints out available commands
  • p - evaluates an expression and prints its result
  • pp - evaluates an expression and pretty-prints its result
  • u(p) - go up a stack frame (위로 이동)
  • d(own) - go down a stack frame (아래로 이동)
  • w(here)/bt - print out a backtrace
  • l(ist) - print out code context (현재 어떤 위치에서 debugging 하는것인지 확인하려면 l)
  • c(ont(inue)) - resumes the execution of the program (계속 진행하려면 c)
  • q(uit)/exit - exits the program (does not work on TPU) (나가려면 q)

댓글