본문 바로가기
기타/오류 뿌수기

[jax 오류] typeerror: can not serialize 'array' object (checkpoint 저장 오류)

by climba 2023. 4. 30.

jax에서 모델의 checkpoint를 저장하려할때 위와 같은 오류가 발생하였다.

typeerror: can not serialize 'array' object

 

기존에 사용하던 flax의 버전이 0.5.3이였는데, 0.6.1 버전으로 업데이트 해 주면, 쉽게 해결할 수 있다.

pip uninstall flax
pip install flax==0.6.1

 

다만, 이렇게 하면 flax의 optimizer 라이브러리(flax 0.5.3)가 optax라는 라이브러리(flax 0.6.1)로 바뀌어서 해당 부분을 전부 수정해야하는 번거로움이 있다.

 

참고 자료

https://github.com/google/jax/issues/13540

 

cannot serialize 'Array' object when jax_array = True · Issue #13540 · google/jax

Description import jax.numpy as jnp from flax import serialization s = jnp.zeros(5) t = serialization.msgpack_serialize(s) print(len(t)) jax 0.3.25: works OK (prints 36) jax HEAD: File "jaxbug.py",...

github.com

 

댓글