jax에서 모델의 checkpoint를 저장하려할때 위와 같은 오류가 발생하였다.
typeerror: can not serialize 'array' object
기존에 사용하던 flax의 버전이 0.5.3이였는데, 0.6.1 버전으로 업데이트 해 주면, 쉽게 해결할 수 있다.
pip uninstall flax
pip install flax==0.6.1
다만, 이렇게 하면 flax의 optimizer 라이브러리(flax 0.5.3)가 optax라는 라이브러리(flax 0.6.1)로 바뀌어서 해당 부분을 전부 수정해야하는 번거로움이 있다.
참고 자료
https://github.com/google/jax/issues/13540
cannot serialize 'Array' object when jax_array = True · Issue #13540 · google/jax
Description import jax.numpy as jnp from flax import serialization s = jnp.zeros(5) t = serialization.msgpack_serialize(s) print(len(t)) jax 0.3.25: works OK (prints 36) jax HEAD: File "jaxbug.py",...
github.com
댓글