# learn_jax/parallel/gptneo_partition_test.py
#
# 25 lines
# 585 B
# Python

# %%
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import freeze, unfreeze
from partitions import set_partitions
from transformers import FlaxAutoModelForCausalLM
# Build the model straight from the pretrained checkpoint and keep a handle
# on its parameter tree for the cells below.
# NOTE(review): "gpt-neo-125m" carries no organisation prefix, so it presumably
# points at a local checkpoint directory - confirm.
model = FlaxAutoModelForCausalLM.from_pretrained("gpt-neo-125m")
params = model.params
# %%
import json
# Map every parameter array to its shape, keeping the nesting of the parameter
# tree, and write the result to disk for offline inspection.
# `params` is unfrozen first (as the set_partitions call further down also
# does): if it is a flax FrozenDict, jax.tree.map would hand back a FrozenDict,
# which json cannot serialize.
shape_dict = jax.tree.map(jnp.shape, unfreeze(params))
# print(json.dumps(shape_dict, sort_keys=True, indent=4))
# Shape tuples are emitted as JSON arrays; keys are sorted for stable diffs.
with open('gpt-neo-125m.json', 'w', encoding='utf-8') as f:
    json.dump(shape_dict, fp=f, sort_keys=True, indent=2)
# %%
# Hand an unfrozen copy of the parameter tree to the project-local
# set_partitions helper and keep whatever it returns.
# NOTE(review): presumably the result is a per-parameter partition spec with
# the same nesting as `params` - verify against partitions.set_partitions.
param_spec = set_partitions(
    unfreeze(params),
)
# %%