jax에서 모델의 checkpoint를 저장하려할때 위와 같은 오류가 발생하였다.
typeerror: can not serialize 'array' object
기존에 사용하던 flax의 버전이 0.5.3이였는데, 0.6.1 버전으로 업데이트 해 주면, 쉽게 해결할 수 있다.
pip uninstall flax
pip install flax==0.6.1
다만, 이렇게 하면 flax의 optimizer 라이브러리(flax 0.5.3)가 optax라는 라이브러리(flax 0.6.1)로 바뀌어서 해당 부분을 전부 수정해야하는 번거로움이 있다.
참고 자료
https://github.com/google/jax/issues/13540
댓글