# MARK: START
# %%
# let's make a 4-device simulator
import sys
sys.dont_write_bytecode = True
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True

flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    flags += " --xla_force_host_platform_device_count=4"  # simulate 4 devices
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    os.environ["JAX_PLATFORMS"] = "cpu"
else:
    # GPU flags
    flags += (
        "--xla_gpu_enable_triton_softmax_fusion=true "
        "--xla_gpu_triton_gemm_any=false "
        "--xla_gpu_enable_async_collectives=true "
        "--xla_gpu_enable_latency_hiding_scheduler=true "
        "--xla_gpu_enable_highest_priority_async_stream=true "
    )
os.environ["XLA_FLAGS"] = flags

import functools
from functools import partial
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
# from jax.experimental.pjit import pjit  # superseded by jax.jit
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec
from ml_collections import ConfigDict
import optax
import logging
import time
from datasets import Dataset, load_from_disk

from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.core.frozen_dict import freeze, unfreeze
import flax.core

from partitions import set_partitions

from tqdm import tqdm

from dataload import DataPrepare


PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

if USE_CPU_ONLY:
    jax.config.update('jax_platform_name', 'cpu')
else:
    jax.config.update("jax_default_matmul_precision", "bfloat16")


# %%
# get platform type
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
save_path = 't5_80_1_bf16'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])
# Store some constants
seed = 117
num_epochs = 5
batch_size = 2  # 384 is the best
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs

warmup_steps = 0
learning_rate = 2e-5

weight_decay = 0.01
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0

num_beams = 1
val_max_target_length = 128

predict_with_generate = True

# %%
# prepare data
# init object
# e.g. Config
data_config = ConfigDict(
    dict(
        max_length=86,
        pad_token_id=0,
        decoder_start_token_id=0
    )
)

dataprep = DataPrepare(split_datasets['train'], data_config)
# # example usage
# # %%
seed = 117
rng = jax.random.PRNGKey(seed)
train_loader = dataprep.data_loader(rng, batch_size=batch_size)
batch = next(iter(train_loader))
# batch

# %%
# model


from t5_model.pure_t5 import FlaxT5ForConditionalGenerationModule as model_init
# from t5_model.pure_t5 import FlaxT5DenseActDense as model_init
from t5_model.pure_t5 import make_config
config = make_config()
model = model_init(config)

# %%
# note: this cell rebinds `model` to the HuggingFace wrapper, which is not a
# flax nn.Module (it exposes `.module` rather than `.init`/`.apply`); the
# sharded init further below expects the pure_t5 module from the previous cell
from transformers import FlaxT5ForConditionalGeneration
from transformers import T5Config
model, params = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=False)


# useful for transformer model
# model.enable_gradient_checkpointing()

# enable bf16 except for layer_norm
# from flax import traverse_util
# flat_params = traverse_util.flatten_dict(model.params)
# mask = {
#     path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
# }
# mask = traverse_util.unflatten_dict(mask)
# model.params = model.to_bf16(model.params, mask)


##################################################################
# set partition on model

# %%
# # let's output the model parameters to a json file for study
# import json
# shape_dict = jax.tree.map(jnp.shape, params)
# # print(json.dumps(shape_dict, sort_keys=True, indent=4))
# with open('t5.json', 'w') as f:
#     json.dump(shape_dict, fp=f, sort_keys=True, indent=2)

# MARK: setup mesh
# %%
device_mesh = mesh_utils.create_device_mesh((2, 2))
print(device_mesh)

mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)

def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
    return NamedSharding(mesh, pspec)
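
# %%
# quick illustrative check (not used later): shard a small array across the
# ('data', 'model') mesh axes to confirm the 2x2 mesh behaves as expected
_toy = jnp.arange(16.0).reshape(4, 4)
_toy = jax.device_put(_toy, mesh_sharding(PartitionSpec('data', 'model')))
jax.debug.visualize_array_sharding(_toy)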

##################################################
# optimizers
# %%

def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear_decay learning rate function."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn


# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
    training_size,
    train_batch_size,
    num_train_epochs,
    warmup_steps,
    learning_rate,
)
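
# quick sanity check (illustrative only): with warmup_steps=0 the schedule is a
# straight linear decay from learning_rate down to 0 over the training run
for _step in (0, total_train_steps // 2, total_train_steps):
    print(_step, linear_decay_lr_schedule_fn(_step))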

# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    # find out all LayerNorm parameters
    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
    layer_norm_named_params = {
        layer[-2:]
        for layer_norm_name in layer_norm_candidates
        for layer in flat_params.keys()
        if layer_norm_name in "".join(layer).lower()
    }
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)
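
# quick sanity check (illustrative only): biases and layer-norm weights should
# come back False (no weight decay), dense kernels True
_toy_params = {
    "dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.ones((2,))},
    "final_layer_norm": {"weight": jnp.ones((2,))},
}
pprint(decay_mask_fn(_toy_params))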

# create adam optimizer
adamw = optax.adamw(
    learning_rate=linear_decay_lr_schedule_fn,
    b1=adam_beta1,
    b2=adam_beta2,
    eps=adam_epsilon,
    weight_decay=weight_decay,
    mask=decay_mask_fn,
)


# %%
# specify sharding

# shard data
x_sharding = mesh_sharding(PartitionSpec('data', None))  # shard the batch dim across the 'data' axis, replicate the rest
batch = {key: jax.device_put(jnp.array(value), x_sharding) for key, value in batch.items()}
# Defining the required dimensions for the self-attention layer input
# batch_size = 2
# seq_length = 768
# n_heads = 12
# head_dim = 768
# %%

# Create a large array with the shape (batch_size, seq_length, n_heads, head_dim)
# large_input = np.random.rand(2,768,768)
# batch = jax.device_put(large_input, x_sharding)

# %%
# jax.debug.visualize_array_sharding(batch['input_ids'])

# %%
# shard output
# we will shard state by tracking its output upon jax.eval_shape after init
# define an init function to return a TrainState
def init_fn(rng, batch, model, optimizer):
    # be careful with the model init
    # imported models might have complicated init methods
    # note: decoder_attention_mask doubles as a dummy decoder_input_ids here;
    # only shapes matter for parameter initialization
    variables = model.init(
        rng,
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        decoder_input_ids=batch['decoder_attention_mask'],
        decoder_attention_mask=batch['decoder_attention_mask']
    )
    state = train_state.TrainState.create(  # Create a `TrainState`.
        apply_fn=model.apply,
        params=variables['params'],
        tx=optimizer)
    return state

# %%
# alternative
# def init_fn(rng, batch, model, optimizer):
#     # be careful with the model init
#     # imported models might have complicated init methods
#     variables = model.init(
#         rng, batch
#     )
#     state = train_state.TrainState.create(  # Create a `TrainState`.
#         apply_fn=model.apply,
#         params=variables['params'],
#         tx=optimizer)
#     return state

# %%
# Create an abstract closure to wrap the function before feeding it in
# because `jax.eval_shape` only takes pytrees as arguments.
# eval_shape(fn, rng_key, x)
# used to perform shape inference
# returns a nested PyTree containing jax.ShapeDtypeStruct objects as leaves
rng, init_rng = jax.random.split(rng)
abstract_variables = jax.eval_shape(
    functools.partial(init_fn, model=model, optimizer=adamw), init_rng, batch)
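
# quick check (illustrative only): eval_shape traced init_fn without allocating
# real arrays, so the leaves of the result are jax.ShapeDtypeStruct objects
print(type(jax.tree_util.tree_leaves(abstract_variables)[0]))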

# %%
# This `state_sharding` has the same pytree structure as `state`, the output
# of the `init_fn`.
# flax.linen.get_sharding
# extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh
# jax.sharding: describes how a jax.Array is laid out across devices
state_sharding = nn.get_sharding(abstract_variables, mesh)
print(state_sharding)

# warning: do not have singleton None in your nn.with_partitioning specs, it will screw with your sanity


# %%
jit_init_fn = jax.jit(
    init_fn,
    static_argnames=('model', 'optimizer'),  # skip model and optimizer
    in_shardings=(mesh_sharding(()), x_sharding),  # for PRNG key and data
    out_shardings=state_sharding
)


rng, init_rng = jax.random.split(rng)
initialized_state = jit_init_fn(init_rng, batch, model, adamw)

# %%
# we can analyze the params structure
# for weight, partitioned in initialized_state.params['decoder'].items():
#     print(f'Sharding of {weight}: {partitioned}')
# jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
# jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)
jax.tree.map(jnp.shape, initialized_state.params['decoder'])


# %%
print(initialized_state.params['decoder']['block']['0']['layer']['0']['SelfAttention']['k']['kernel'].value.sharding)
print(initialized_state.step)
print(initialized_state.step.sharding)


# %%
# train step
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    """
    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
    """
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing_factor
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    normalizing_constant = -(
        confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    )
    soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)

    loss = optax.softmax_cross_entropy(logits, soft_labels)
    loss = loss - normalizing_constant

    # ignore padded tokens from loss
    loss = loss * padding_mask
    loss = loss.sum()
    num_labels = padding_mask.sum()
    return loss, num_labels
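
# quick sanity check (illustrative only): with uniform (all-zero) logits the
# per-token loss should equal log(vocab_size)
_logits = jnp.zeros((2, 3, 5))                    # (batch, seq, vocab)
_labels = jnp.array([[1, 2, 0], [3, 4, 0]])
_pad = jnp.array([[1, 1, 0], [1, 1, 0]])
_l, _n = loss_fn(_logits, _labels, _pad, label_smoothing_factor=0.0)
print(_l / _n, jnp.log(5.0))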

# %%

# single-device step annotated with jax.jit
@functools.partial(
    jax.jit,
    # in_shardings=(state_sharding, x_sharding),
    out_shardings=state_sharding
)
def train_step(state, batch):
    label_smoothing_factor = 0.0
    # dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

    def compute_loss(params):
        labels = batch.pop("labels")
        logits = state.apply_fn(
            {'params': params},
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            decoder_input_ids=batch['decoder_attention_mask'],
            decoder_attention_mask=batch['decoder_attention_mask'],
        )[0]
        loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
        return loss, num_labels

    # compute gradients through computational graph
    # allow values to pass through
    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, num_labels), grad = grad_fn(state.params)
    # num_labels = jax.lax.psum(num_labels, "batch")

    # true grad = total grad / total samples
    # grad = jax.lax.psum(grad, "batch")
    # grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
    new_state = state.apply_gradients(grads=grad)

    # metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
    return new_state


# %%

# variables = model.init(
#     rng,
#     input_ids=batch['input_ids'],
#     attention_mask=batch['attention_mask'],
#     decoder_input_ids=batch['decoder_attention_mask'],
#     decoder_attention_mask=batch['decoder_attention_mask']
# )
# x_sharding = mesh_sharding(PartitionSpec('data', None))  # shard the batch dim across the 'data' axis
# batch = {key: jax.device_put(jnp.array(value), x_sharding) for key, value in batch.items()}

with mesh:
    new_state = train_step(initialized_state, batch)
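
# %%
# a minimal training-loop sketch (illustrative; assumes dataprep.data_loader can
# be re-created each epoch and yields dicts with the same keys as `batch` above)
with mesh:
    state = initialized_state
    for epoch in range(num_epochs):
        rng, loader_rng = jax.random.split(rng)
        epoch_loader = dataprep.data_loader(loader_rng, batch_size=batch_size)
        for train_batch in tqdm(epoch_loader, total=steps_per_epoch, desc=f"epoch {epoch}"):
            # place each batch on the mesh with the same sharding as above
            train_batch = {k: jax.device_put(jnp.array(v), x_sharding) for k, v in train_batch.items()}
            state = train_step(state, train_batch)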
# %%

# # %%
# #############################################################
# # we cannot integrate our model pspec with train_state
# # we just shard separately
# # update: we also cannot use the method of modifying a partitionspec tree
# # we have to do it the RIGHT way, following flax_pjit_tutorial to the letter
#
# # %%
# def get_optim_initial_state(params):
#     params = params
#     state = adamw.init(params)
#     return tuple((state)), params
#
# # %%
# # create partitions for model
# from partitions import set_partitions
# # set_partitions freezes the params on return
# model_param_spec = set_partitions(unfreeze(params))
#
# # %%
# params_shapes = jax.tree.map(lambda x: x.shape, params)
# # actually tuple
# optim_state_shapes = jax.eval_shape(get_optim_initial_state, params_shapes)
#
# # %%
# # get pspec for opt_state
# def get_opt_spec(x):
#     if isinstance(x, dict):
#         return unfreeze(model_param_spec)
#     return PartitionSpec()
#
# # this function replaces the empty model params spec with the 'model_param_spec'
# opt_state_spec, param_spec = jax.tree.map(
#     get_opt_spec, optim_state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
# )
#
# # %%
#
# model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=True)
# # store on cpu
# model.params = jax.tree_util.tree_map(lambda x: np.asarray(x), model.params)
#
# # %%
# device_mesh = mesh_utils.create_device_mesh((2,2))
# print(device_mesh)
#
# mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
# print(mesh)
#
# def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
#     return NamedSharding(mesh, pspec)
#
#
# # opt_state_sharding = mesh_sharding(opt_state_spec)
# # param_sharding = mesh_sharding(param_spec)
#
#
# # %%
# opt_state_sharding = nn.get_sharding(opt_state_spec, mesh)
# param_sharding = nn.get_sharding(param_spec, mesh)
#
# # %%
# # jit the get_initial_state function to shard params and init optimizer state in
# # a sharded way
# from jax.experimental.pjit import pjit
#
# with mesh:
#     p_get_initial_state = pjit(
#         get_optim_initial_state,
#         in_shardings=None,
#         out_shardings=(opt_state_spec, param_spec),
#     )
#
# # Convert your PartitionSpec to NamedSharding for model params
# param_sharding = NamedSharding(mesh, freeze(param_spec))
# # Use device_put with sharding to move params onto the mesh
# sharded_params = jax.device_put(freeze(params), param_sharding)
#
# with mesh:
#     # params is already frozen
#     sharded_opt_state, sharded_params = p_get_initial_state(unfreeze(sharded_params))
#
# # %%
#
# # give up this section
# #############################################################
# # create train state
#
# # %%
# # Initialize random key and input for initialization
# rng = jax.random.PRNGKey(seed)
# loader_rng, rng = jax.random.split(rng)
# train_loader = dataprep.data_loader(rng, batch_size=2)
# batch = next(iter(train_loader))
#
# # use the T5 base model to do this
# from transformers import FlaxAutoModel
# model, params = FlaxAutoModel.from_pretrained(
#     't5-base',
#     _do_init=False
# )
# t5_module = model.module
#
# # %%
# init_rng, rng = jax.random.split(rng)
# variables = t5_module.init(init_rng,
#     input_ids=batch['input_ids'],
#     attention_mask=batch['attention_mask'],
#     decoder_input_ids=batch['decoder_attention_mask'],
#     decoder_attention_mask=batch['decoder_attention_mask']
# )
# params = variables['params']
#
# # create an init function
# # %%
# # we will shard state by tracking its output upon jax.eval_shape after init
# # define an init function to return a TrainState
# def init_fn(rng: jax.random.PRNGKey, batch=batch, model=t5_module, optimizer=adamw) -> train_state.TrainState:
#     init_rng, rng = jax.random.split(rng)
#     variables = model.init(
#         init_rng,
#         input_ids=batch['input_ids'],
#         attention_mask=batch['attention_mask'],
#         decoder_input_ids=batch['decoder_attention_mask'],
#         decoder_attention_mask=batch['decoder_attention_mask']
#     )
#     params = variables.pop("params")
#     state = train_state.TrainState.create(
#         apply_fn=model.__call__,
#         params=params,
#         tx=optimizer,
#     )
#     return state
#
# # model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=True)
# # Create an abstract closure to wrap the function before feeding it in
# # because `jax.eval_shape` only takes pytrees as arguments.
# # eval_shape(fn, rng_key, x)
# # used to perform shape inference
# # returns a nested PyTree containing jax.ShapeDtypeStruct objects as leaves
# init_rng, rng = jax.random.split(rng)
# abstract_variables = jax.eval_shape(
#     functools.partial(init_fn, model=t5_module, optimizer=adamw),
#     init_rng,
#     batch
# )
#
# # %%
# # let's make our mesh
#
# device_mesh = mesh_utils.create_device_mesh((2,2))
# print(device_mesh)
#
# mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
# print(mesh)
#
# def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
#     return NamedSharding(mesh, pspec)
#
# # %%
# # making jax compatible batch
#
# # %%
# x_sharding = mesh_sharding(PartitionSpec('data', None))  # replicate across data axis
# # batch = jax.device_put(batch), x_sharding)
# # jax.debug.visualize_array_sharding(batch)
#
# # %%
# state_sharding = nn.get_sharding(abstract_variables, mesh)
# print(state_sharding)
#
# # %%
# # integrate model_param_specs and state_out_specs
#
# # %%
# # i want to make a Sharding object
# # model_sharding = mesh_sharding(model_param_spec)
#
# # %%
# jit_init_fn = jax.jit(
#     init_fn,  # rng, batch, model, optimizer
#     static_argnames=('model', 'optimizer'),  # skip model and optimizer
#     in_shardings=(mesh_sharding(()), x_sharding),  # for PRNG key and data
#     out_shardings=state_sharding
# )
#
# # %%
#
# init_rng, rng = jax.random.split(rng)
# initialized_state = jit_init_fn(
#     init_rng,
#     batch,
#     t5_module,
#     adamw)
#
# # jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
# # jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)
#
#
# # %%
#
# # %%
#

# %%