# MARK: START
# %%
# let's make a 4-device simulator
import sys
sys.dont_write_bytecode = True
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True

flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    flags += " --xla_force_host_platform_device_count=4"  # simulate 4 devices on the host
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    os.environ["JAX_PLATFORMS"] = "cpu"
else:
    # GPU flags
    flags += (
        "--xla_gpu_enable_triton_softmax_fusion=true "
        "--xla_gpu_triton_gemm_any=false "
        "--xla_gpu_enable_async_collectives=true "
        "--xla_gpu_enable_latency_hiding_scheduler=true "
        "--xla_gpu_enable_highest_priority_async_stream=true "
    )
os.environ["XLA_FLAGS"] = flags

import functools
from functools import partial
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
# from jax.experimental.pjit import pjit  # superseded by jax.jit
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec
from ml_collections import ConfigDict
import optax
import logging
import time
from datasets import Dataset, load_from_disk
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.core.frozen_dict import freeze, unfreeze
import flax.core
from partitions import set_partitions
from tqdm import tqdm

from dataload import DataPrepare

PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

if USE_CPU_ONLY:
    jax.config.update('jax_platform_name', 'cpu')
else:
    jax.config.update("jax_default_matmul_precision", "bfloat16")

# %%
# get platform type
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
save_path = 't5_80_1_bf16'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])

# store some constants
seed = 117
num_epochs = 5
batch_size = 2  # 384 is the best
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs

warmup_steps = 0
learning_rate = 2e-5

weight_decay = 0.01
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0

num_beams = 1
val_max_target_length = 128
predict_with_generate = True
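# %%
# sanity check (illustration only; the rest of the script does not depend on it):
# with --xla_force_host_platform_device_count=4 and JAX_PLATFORMS=cpu set above,
# jax should report 4 host (CPU) devices, so train_batch_size becomes 2 * 4 = 8.
print(jax.devices())       # expect 4 CPU devices when USE_CPU_ONLY is True
print(jax.device_count())  # expect 4
print(train_batch_size, eval_batch_size, steps_per_epoch, total_train_steps)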
# %%
# prepare data
# init object
# e.g. Config
data_config = ConfigDict(
    dict(
        max_length=86,
        pad_token_id=0,
        decoder_start_token_id=0
    )
)

dataprep = DataPrepare(split_datasets['train'], data_config)
# # example usage
# # %%
seed = 117
rng = jax.random.PRNGKey(seed)
train_loader = dataprep.data_loader(rng, batch_size=batch_size)
batch = next(iter(train_loader))
# batch

# %%
# model
from t5_model.pure_t5 import FlaxT5ForConditionalGenerationModule as model_init
# from t5_model.pure_t5 import FlaxT5DenseActDense as model_init
from t5_model.pure_t5 import make_config
config = make_config()
model = model_init(config)

# %%
from transformers import FlaxT5ForConditionalGeneration
from transformers import T5Config
model, params = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=False)

# useful for transformer model
# model.enable_gradient_checkpointing()

# enable bf16 except for layer_norm
# from flax import traverse_util
# flat_params = traverse_util.flatten_dict(model.params)
# mask = {
#     path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
# }
# mask = traverse_util.unflatten_dict(mask)
# model.params = model.to_bf16(model.params, mask)

##################################################################
# set partition on model
# %%
# # let's output the model parameters to a json file for study
# import json
# shape_dict = jax.tree.map(jnp.shape, params)
# # print(json.dumps(shape_dict, sort_keys=True, indent=4))
# with open('t5.json', 'w') as f:
#     json.dump(shape_dict, fp=f, sort_keys=True, indent=2)

# MARK: setup mesh
# %%
device_mesh = mesh_utils.create_device_mesh((2, 2))
print(device_mesh)

mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)

def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
    return NamedSharding(mesh, pspec)

##################################################
# optimizers
# %%
def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear_decay learning rate function."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn

# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
    training_size,
    train_batch_size,
    num_train_epochs,
    warmup_steps,
    learning_rate,
)
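# %%
# quick look at the schedule (illustration only): with warmup_steps = 0 this is a
# plain linear decay from learning_rate down to 0 over total_train_steps steps.
for probe_step in [0, total_train_steps // 2, total_train_steps]:
    print(probe_step, linear_decay_lr_schedule_fn(probe_step))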
# %%
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# boolean mask with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    # find out all LayerNorm parameters
    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
    layer_norm_named_params = {
        layer[-2:]
        for layer_norm_name in layer_norm_candidates
        for layer in flat_params.keys()
        if layer_norm_name in "".join(layer).lower()
    }
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)

# create the AdamW optimizer
adamw = optax.adamw(
    learning_rate=linear_decay_lr_schedule_fn,
    b1=adam_beta1,
    b2=adam_beta2,
    eps=adam_epsilon,
    weight_decay=weight_decay,
    mask=decay_mask_fn,
)

# %%
# specify sharding

# shard data: split the batch dimension across the 'data' mesh axis,
# replicate along the 'model' axis
x_sharding = mesh_sharding(PartitionSpec('data', None))
batch = {key: jax.device_put(jnp.array(value), x_sharding) for key, value in batch.items()}

# Defining the required dimensions for the self-attention layer input
# batch_size = 2
# seq_length = 768
# n_heads = 12
# head_dim = 768

# %%
# Create a large array with the shape (batch_size, seq_length, n_heads, head_dim)
# large_input = np.random.rand(2, 768, 768)
# batch = jax.device_put(large_input, x_sharding)

# %%
# jax.debug.visualize_array_sharding(batch['input_ids'])

# %%
# shard output
# we will shard the state by tracking its output from jax.eval_shape of the init
# define an init function to return a TrainState
# def init_fn(rng, batch, model, optimizer):
#     # be careful with the model init:
#     # imported models might have complicated init methods
#     variables = model.init(rng,
#                            input_ids=batch['input_ids'],
#                            attention_mask=batch['attention_mask'],
#                            decoder_input_ids=batch['decoder_attention_mask'],
#                            decoder_attention_mask=batch['decoder_attention_mask']
#                            )
#     state = train_state.TrainState.create(  # Create a `TrainState`.
#         apply_fn=model.apply,
#         params=variables['params'],
#         tx=optimizer)
#     return state

def init_fn(rng, batch, model, optimizer):
    # be careful with the model init:
    # imported models might have complicated init methods
    variables = model.init(
        rng,
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        decoder_input_ids=batch['decoder_attention_mask'],
        decoder_attention_mask=batch['decoder_attention_mask']
    )
    state = train_state.TrainState.create(  # Create a `TrainState`.
        apply_fn=model.apply,
        params=variables['params'],
        tx=optimizer)
    return state

# %%
# alternative
# def init_fn(rng, batch, model, optimizer):
#     # be careful with the model init:
#     # imported models might have complicated init methods
#     variables = model.init(
#         rng, batch
#     )
#     state = train_state.TrainState.create(  # Create a `TrainState`.
#         apply_fn=model.apply,
#         params=variables['params'],
#         tx=optimizer)
#     return state

# %%
# Create an abstract closure to wrap the function before feeding it in,
# because `jax.eval_shape` only takes pytrees as arguments.
# eval_shape(fn, rng_key, x)
# is used to perform shape inference and
# returns a nested PyTree containing jax.ShapeDtypeStruct objects as leaves
rng, init_rng = jax.random.split(rng)
abstract_variables = jax.eval_shape(
    functools.partial(init_fn, model=model, optimizer=adamw),
    init_rng,
    batch)
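# %%
# minimal illustration of what jax.eval_shape returns (toy function, not the T5 init):
# the function is traced with abstract values only, so nothing is allocated or
# computed, and every leaf of the result is a jax.ShapeDtypeStruct.
def _toy_matmul(x, w):
    return x @ w

print(jax.eval_shape(
    _toy_matmul,
    jax.ShapeDtypeStruct((2, 86), jnp.float32),
    jax.ShapeDtypeStruct((86, 512), jnp.float32),
))  # -> ShapeDtypeStruct(shape=(2, 512), dtype=float32)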
# %%
# This `state_sharding` has the same pytree structure as `state`, the output
# of the `init_fn`.
# flax.linen.get_sharding
# extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh
# jax.sharding: describes how a jax.Array is laid out across devices
state_sharding = nn.get_sharding(abstract_variables, mesh)
print(state_sharding)
# warning: do not have a singleton None in your nn partition definitions, it will mess with your sanity

# %%
jit_init_fn = jax.jit(
    init_fn,
    static_argnames=('model', 'optimizer'),  # skip model and optimizer
    in_shardings=(mesh_sharding(()), x_sharding),  # for PRNG key and data
    out_shardings=state_sharding
)

rng, init_rng = jax.random.split(rng)
initialized_state = jit_init_fn(rng, batch, model, adamw)

# %%
# we can analyze the params structure
# for weight, partitioned in initialized_state.params['decoder'].items():
#     print(f'Sharding of {weight}: {partitioned}')
# jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
# jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)
jax.tree.map(jnp.shape, initialized_state.params['decoder'])

# %%
print(initialized_state.params['decoder']['block']['0']['layer']['0']['SelfAttention']['k']['kernel'].value.sharding)
print(initialized_state.step)
print(initialized_state.step.sharding)

# %%
# train step
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    """
    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
    """
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing_factor
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    normalizing_constant = -(
        confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    )
    soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)

    loss = optax.softmax_cross_entropy(logits, soft_labels)
    loss = loss - normalizing_constant

    # ignore padded tokens from loss
    loss = loss * padding_mask
    loss = loss.sum()
    num_labels = padding_mask.sum()
    return loss, num_labels

# %%
# single-device code annotated with jax.jit
@functools.partial(
    jax.jit,
    # in_shardings=(state_sharding, x_sharding),
    out_shardings=state_sharding
)
def train_step(state, batch):
    label_smoothing_factor = 0.0
    # dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

    def compute_loss(params):
        labels = batch.pop("labels")
        logits = state.apply_fn(
            {'params': params},
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            decoder_input_ids=batch['decoder_attention_mask'],
            decoder_attention_mask=batch['decoder_attention_mask'],
        )[0]
        loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
        return loss, num_labels

    # compute gradients through the computational graph,
    # allowing the auxiliary num_labels value to pass through
    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, num_labels), grad = grad_fn(state.params)
    # num_labels = jax.lax.psum(num_labels, "batch")

    # true grad = total grad / total samples
    # grad = jax.lax.psum(grad, "batch")
    # grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
    new_state = state.apply_gradients(grads=grad)

    # metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
    return new_state
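# %%
# quick sanity check of loss_fn on toy inputs (illustration only, not part of training):
# with label_smoothing_factor=0.0 it reduces to plain softmax cross entropy, and
# positions where padding_mask == 0 contribute nothing to the summed loss.
toy_logits = jnp.zeros((1, 3, 5))        # (batch, seq, vocab); uniform logits
toy_labels = jnp.array([[1, 2, 0]])
toy_mask = jnp.array([[1.0, 1.0, 0.0]])  # last position is padding
toy_loss, toy_num = loss_fn(toy_logits, toy_labels, toy_mask, label_smoothing_factor=0.0)
print(toy_loss / toy_num)                # ~log(5) ~= 1.609 per real token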
# %%
# variables = model.init(
#     rng,
#     input_ids=batch['input_ids'],
#     attention_mask=batch['attention_mask'],
#     decoder_input_ids=batch['decoder_attention_mask'],
#     decoder_attention_mask=batch['decoder_attention_mask']
# )
# x_sharding = mesh_sharding(PartitionSpec('data', None))  # shard across the 'data' axis
# batch = {key: jax.device_put(jnp.array(value), x_sharding) for key, value in batch.items()}

# run a single training step under the mesh
# (a sketch of a full training loop is appended at the end of this file)
with mesh:
    new_state = train_step(initialized_state, batch)

# %%

# %%
#############################################################
# # we cannot integrate our model pspec with train_state,
# # so we just shard separately
# # update: we also cannot use the method of modifying a partitionspec tree;
# # we have to do it the RIGHT way, following the flax_pjit_tutorial to the letter
#
# # %%
# def get_optim_initial_state(params):
#     params = params
#     state = adamw.init(params)
#     return tuple((state)), params
#
# # %%
# # create partitions for model
# from partitions import set_partitions
# # set_partitions freezes the params on return
# model_param_spec = set_partitions(unfreeze(params))
#
# # %%
# params_shapes = jax.tree.map(lambda x: x.shape, params)
# # actually a tuple
# optim_state_shapes = jax.eval_shape(get_optim_initial_state, params_shapes)
#
# # %%
# # get pspec for opt_state
# def get_opt_spec(x):
#     if isinstance(x, dict):
#         return unfreeze(model_param_spec)
#     return PartitionSpec()
#
# # this function replaces the empty model params spec with the 'model_param_spec'
# opt_state_spec, param_spec = jax.tree.map(
#     get_opt_spec, optim_state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
# )
#
# # %%
# # model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=True)
# # store on cpu
# model.params = jax.tree_util.tree_map(lambda x: np.asarray(x), model.params)
#
# # %%
# device_mesh = mesh_utils.create_device_mesh((2, 2))
# print(device_mesh)
#
# mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
# print(mesh)
#
# def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
#     return NamedSharding(mesh, pspec)
#
# # opt_state_sharding = mesh_sharding(opt_state_spec)
# # param_sharding = mesh_sharding(param_spec)
#
# # %%
# opt_state_sharding = nn.get_sharding(opt_state_spec, mesh)
# param_sharding = nn.get_sharding(param_spec, mesh)
#
# # %%
# # jit the get_initial_state function to shard params and init optimizer state in
# # a sharded way
# from jax.experimental.pjit import pjit
#
# with mesh:
#     p_get_initial_state = pjit(
#         get_optim_initial_state,
#         in_shardings=None,
#         out_shardings=(opt_state_spec, param_spec),
#     )
#
# # Convert your PartitionSpec to NamedSharding for model params
# param_sharding = NamedSharding(mesh, freeze(param_spec))
# # Use device_put with sharding to move params onto the mesh
# sharded_params = jax.device_put(freeze(params), param_sharding)
#
# with mesh:
#     # params is already frozen
#     sharded_opt_state, sharded_params = p_get_initial_state(unfreeze(sharded_params))
#
# # %%
#
# # give up on this section
#############################################################
# # create train state
#
# # %%
# # Initialize random key and input for initialization
# rng = jax.random.PRNGKey(seed)
# loader_rng, rng = jax.random.split(rng)
# train_loader = dataprep.data_loader(rng, batch_size=2)
# batch = next(iter(train_loader))
#
# # use the T5 base model to do this
# from transformers import FlaxAutoModel
# model, params = FlaxAutoModel.from_pretrained(
#     't5-base',
#     _do_init=False
# )
# t5_module = model.module
#
# # %%
# init_rng, rng = jax.random.split(rng)
# variables = t5_module.init(init_rng,
#                            input_ids=batch['input_ids'],
#                            attention_mask=batch['attention_mask'],
#                            decoder_input_ids=batch['decoder_attention_mask'],
#                            decoder_attention_mask=batch['decoder_attention_mask']
#                            )
# params = variables['params']
#
# # create an init function
# # %%
# # we will shard state by tracking its output upon jax.eval_shape after init
# # define an init function to return a TrainState
# def init_fn(rng: jax.random.PRNGKey, batch=batch, model=t5_module, optimizer=adamw) -> train_state.TrainState:
#     init_rng, rng = jax.random.split(rng)
#     variables = model.init(
#         init_rng,
#         input_ids=batch['input_ids'],
#         attention_mask=batch['attention_mask'],
#         decoder_input_ids=batch['decoder_attention_mask'],
#         decoder_attention_mask=batch['decoder_attention_mask']
#     )
#     params = variables.pop("params")
#     state = train_state.TrainState.create(
#         apply_fn=model.__call__,
#         params=params,
#         tx=optimizer,
#     )
#     return state
#
# # model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=True)
# # Create an abstract closure to wrap the function before feeding it in
# # because `jax.eval_shape` only takes pytrees as arguments.
# # eval_shape(fn, rng_key, x)
# # used to perform shape inference
# # returns a nested PyTree containing jax.ShapeDtypeStruct objects as leaves
# init_rng, rng = jax.random.split(rng)
# abstract_variables = jax.eval_shape(
#     functools.partial(init_fn, model=t5_module, optimizer=adamw),
#     init_rng,
#     batch
# )
#
# # %%
# # let's make our mesh
#
# device_mesh = mesh_utils.create_device_mesh((2, 2))
# print(device_mesh)
#
# mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
# print(mesh)
#
# def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
#     return NamedSharding(mesh, pspec)
#
# # %%
# # making a jax-compatible batch
#
# # %%
# x_sharding = mesh_sharding(PartitionSpec('data', None))  # shard across the 'data' axis
# # batch = jax.device_put(batch, x_sharding)
# # jax.debug.visualize_array_sharding(batch)
#
# # %%
# state_sharding = nn.get_sharding(abstract_variables, mesh)
# print(state_sharding)
#
# # %%
# # integrate model_param_specs and state_out_specs
#
# # %%
# # i want to make a Sharding object
# # model_sharding = mesh_sharding(model_param_spec)
#
# # %%
# jit_init_fn = jax.jit(
#     init_fn,  # rng, batch, model, optimizer
#     static_argnames=('model', 'optimizer'),  # skip model and optimizer
#     in_shardings=(mesh_sharding(()), x_sharding),  # mesh_sharding(()), mesh_sharding(())),  # for PRNG key and data
#     out_shardings=state_sharding
# )
#
# # %%
# init_rng, rng = jax.random.split(rng)
# initialized_state = jit_init_fn(
#     init_rng,
#     batch,
#     t5_module,
#     adamw)
#
# # jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
# # jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)
#
# # %%

# %%
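# %%
# MARK: training loop sketch
# a minimal sketch of how the pieces above could be strung into a full loop.
# illustration only: it assumes dataprep.data_loader keeps yielding dicts with the
# same keys as `batch` (including 'labels'), and it does no metric logging,
# evaluation, or checkpointing.
with mesh:
    state = initialized_state
    for epoch in range(num_epochs):
        rng, loader_rng = jax.random.split(rng)
        train_loader = dataprep.data_loader(loader_rng, batch_size=batch_size)
        for raw_batch in tqdm(train_loader, total=steps_per_epoch, desc=f"epoch {epoch}"):
            # place each field on the mesh, sharded along the 'data' axis
            sharded_batch = {key: jax.device_put(jnp.array(value), x_sharding)
                             for key, value in raw_batch.items()}
            state = train_step(state, sharded_batch)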