본문 바로가기

jax 오류2

[jax 오류] jax에서 debugging하는 방법 (how to print BatchTracer class in jax) jax에서 변수들을 출력해보면, Tracedwith with val = ~ 이런 식으로 나오는 경우가 많다. (BatchTracer Class) 이는 jax가 함수들을 compile한 후 한번에 실행하기 때문에, print 할 수 있는 실제 값이 없고, 그 shape(껍데기)만 출력하기 때문이다. (자세한 이유는 https://github.com/google/jax/issues/196에 정확하고, 상세하게 나와있으니 참고하면 좋을 것 같다.) 그러나 d3pm 코드를 이해하는 과정에서 변수에 어떤 값이 할당되어 있는지 확인해야하는 경우가 많았기 때문에 debugging 방법이 정말 없는지 열심히 찾아보았다. 최종적으로 발견한 방법은 크게 두 가지가 있었는데, 우리의 코드에서는 두번째 방법으로 해결하였다. .. 2023. 3. 25.
[jax 오류] INTERNAL: nvlink exited with non-zero error code 65280, output: nvlink error d3pm code를 뜯어보던 과정에서 jax 관련 오류가 발생하였다. INTERNAL: nvlink exited with non-zero error code 65280, output: nvlink error 대충 위와 같은 오류였는데, 디버깅을 해봐도 non-zero(?)와 관련된 부분은 찾을 수 없었다. 그러던 중 jax의 공식 레포(https://github.com/google/jax)에서 jax가 TPU에서 제공 된다는 것을 알게되었고, GPU 설정이 아닌 TPU 환경에서 실행해야 할 것 같다는 생각이 들었다. 참고로 개발 환경이 colab pro + vscode였는데, colab - 런타임 - 런타임 유형 변경 - 하드웨어 가속기 TPU 선택 을 통해 문제를 해결할 수 있었다. 참고로 tpu는 구.. 2023. 3. 15.