# %%
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import freeze, unfreeze
from transformers import FlaxAutoModelForCausalLM

from partitions import set_partitions

# Load the causal-LM checkpoint. With `from_pretrained`'s default arguments the
# parameters are materialised eagerly and exposed as `model.params`.
# NOTE(review): "gpt-neo-125m" is resolved as a local directory (or hub id);
# the public hub checkpoint is "EleutherAI/gpt-neo-125m" — confirm which one
# is intended.
model = FlaxAutoModelForCausalLM.from_pretrained(
    "gpt-neo-125m",
)

# Parameter pytree used by the cells below.
params = model.params
# %%
# Dump the shape of every parameter array to a JSON file for inspection.
import json

# Same tree structure as `params`, with each array replaced by its shape tuple.
shape_dict = jax.tree.map(jnp.shape, params)

# `params` may be a flax FrozenDict (the partition cell also unfreezes it), and
# json only serialises real dicts, so convert to plain nested dicts first.
# Shape tuples are written as JSON lists.
with open('gpt-neo-125m.json', 'w', encoding='utf-8') as f:
    json.dump(unfreeze(shape_dict), fp=f, sort_keys=True, indent=2)
# %%
# Build the partition spec for the parameter tree. `set_partitions` comes from
# the local `partitions` module (not visible here) and receives a plain,
# mutable copy of the parameters via `unfreeze`.
# NOTE(review): presumably returns a tree mirroring `params` with one
# partition rule per leaf — confirm against `partitions.set_partitions`.
param_spec = set_partitions(unfreeze(params))

# %%