# %%
"""Dump GPT-Neo parameter shapes to JSON and build a partition spec.

Notebook-style script (``# %%`` cells):

1. load a Flax causal-LM checkpoint and take its parameter pytree,
2. write the shape of every parameter leaf to ``<checkpoint>.json``,
3. pass the unfrozen parameter tree to the project-local ``set_partitions``.
"""
import json

import jax
import jax.numpy as jnp
from flax.core.frozen_dict import freeze, unfreeze
from partitions import set_partitions
from transformers import FlaxAutoModelForCausalLM

# Checkpoint handed to ``from_pretrained``.
# NOTE(review): "gpt-neo-125m" is presumably a local directory; the Hugging
# Face Hub ID is "EleutherAI/gpt-neo-125M". Confirm which one is intended.
MODEL_NAME = "gpt-neo-125m"
# Name the shape dump after the checkpoint so the two cannot drift apart.
# Only the last path segment is used, so a Hub ID such as "org/model" does
# not turn into a nested path.
SHAPES_PATH = f"{MODEL_NAME.rsplit('/', 1)[-1]}.json"

# Load the Flax model; ``model.params`` is its parameter pytree.
model = FlaxAutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
)
params = model.params

# %%
# Replace every array leaf with its shape tuple (json writes tuples as lists).
shape_dict = jax.tree.map(jnp.shape, params)
with open(SHAPES_PATH, "w", encoding="utf-8") as f:
    json.dump(shape_dict, fp=f, sort_keys=True, indent=2)

# %%
# ``unfreeze`` yields a plain mutable dict copy of the parameter tree.
# NOTE(review): set_partitions lives in the project-local ``partitions``
# module; presumably it returns a pytree of partition specs mirroring
# ``params`` — confirm there.
param_spec = set_partitions(unfreeze(params))

# %%