# %%
# Environment setup. These variables are written BEFORE jax is imported below
# so that they are already in place when the backend starts up.
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = False

# Start from whatever XLA_FLAGS the caller already exported and extend it.
flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    # Simulate 4 host devices (matches the 2x2 device mesh built later).
    flags += " --xla_force_host_platform_device_count=4"
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    os.environ["JAX_PLATFORMS"] = "cpu"
else:
    # GPU flags
    gpu_flags = (
        "--xla_gpu_enable_triton_softmax_fusion=true",
        "--xla_gpu_triton_gemm_any=True",
        # "--xla_gpu_enable_async_collectives=true",
        "--xla_gpu_enable_latency_hiding_scheduler=true",
        "--xla_gpu_enable_highest_priority_async_stream=true",
    )
    # Append instead of overwriting so pre-existing XLA_FLAGS are respected;
    # skip flags already present so re-running this cell does not pile up
    # duplicates.
    flags = " ".join([flags, *(f for f in gpu_flags if f not in flags)]).strip()
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

os.environ["XLA_FLAGS"] = flags
os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
    "NCCL_LL128_BUFFSIZE": "-2",
    "NCCL_LL_BUFFSIZE": "-2",
    "NCCL_PROTO": "SIMPLE,LL,LL128",
    "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.90",
    # "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
})

# Standard library
import functools
import logging
import time
from functools import partial
from pprint import pprint
from typing import Any, Callable, Dict, Sequence, Tuple, Union

# Third party
import flax.core
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from datasets import Dataset, load_from_disk
from flax import jax_utils, serialization, struct, traverse_util
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
# model checkpointing and saving utilities
from flax.training import checkpoints, train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
# from jax.experimental.pjit import pjit # superseded by jax.jit
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from ml_collections import ConfigDict
from tqdm import tqdm

# Project local
from parallel.dataload import DataPrepare
from parallel.partitions import set_partitions

# for memory tracking
# from jax_smi import initialise_tracking
# initialise_tracking()

# Type aliases used in annotations further down.
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

if USE_CPU_ONLY:
    jax.config.update('jax_platform_name', 'cpu')
else:
    jax.config.update("jax_default_matmul_precision", "bfloat16")
    # Persistent compilation cache: the size / compile-time thresholds are set
    # to -1 / 0 here.
    jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
    jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
    jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

# %%
## get platform type
from jax.extend.backend import get_backend

print(get_backend().platform)
print(jax.devices())

# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/'
save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/simple_test/'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])

# Store some constant
seed = 117
num_epochs = 5
batch_size = 64
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs

warmup_steps = 0
learning_rate = 5e-5
weight_decay = 0.01
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0

num_beams = 1
val_max_target_length = 128
predict_with_generate = True

# %%
# prepare data
# init object
# e.g. Config
print("preparing data")
data_config = ConfigDict(
    dict(
        max_length=128,
        pad_token_id=0,
        decoder_start_token_id=0,
    )
)
dataprep = DataPrepare(split_datasets['train'], data_config)

# # example usage
# # %%
# `seed` is the value defined with the other constants above.
rng = jax.random.PRNGKey(seed)
train_loader = dataprep.data_loader(rng, batch_size=batch_size)
# One sample batch, kept around for inspection in later cells.
batch = next(iter(train_loader))

# %%
# from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom
from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model

main_model = custom_model.from_pretrained(
    "t5-base",
    dtype=jnp.bfloat16,
    gradient_checkpointing=True,
)
params = main_model.params
# pretrained_params = model.params
# The bare module (used for `apply`) is kept separately from the wrapper.
model = main_model.module

# %%
# # testing config hashability
# # some explanation:
# # The PreTrainedModel class loads a T5Config model that is not hashable because
# # it is a complicated class that pretends to be a dataclass.
# # The solution is to extract a dict from it, then make a ConfigDict from
# # ml_collections library so that we can get values via the "." operator.
# # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to
# # modify the config before passing to the next layer
# from transformers import T5Config
# from t5_model.configuration_t5 import FrozenT5Config
# from ml_collections import ConfigDict, FrozenConfigDict
#
# config = T5Config.from_pretrained("t5-base").to_dict()
# config.pop('architectures')
# config.pop('id2label')
# # test if it works
# frozen_config = FrozenConfigDict(config)
# # test hash
# hash(frozen_config)

# %%
# # print model
# rng, input_rng = jax.random.split(rng)
# model.tabulate(
#     input_rng,
#     input_ids=batch['input_ids'],
#     attention_mask=batch['attention_mask'],
#     decoder_input_ids=batch['decoder_input_ids'],
#     decoder_attention_mask=batch['decoder_attention_mask'],
#     console_kwargs={"force_jupyter": True}
# )

# %%
# create mesh
print("creating mesh")
device_mesh = mesh_utils.create_device_mesh((2, 2))
print(device_mesh)
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)


def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
    """Wrap a PartitionSpec into a NamedSharding on the global `mesh`."""
    return NamedSharding(mesh, pspec, memory_kind="device")


# Leading (batch) dimension is split over the 'data' axis; the second
# dimension is not partitioned.
x_sharding = mesh_sharding(PartitionSpec('data', None))
model_sharding = mesh_sharding(PartitionSpec(None, 'model'))


# %%
# optimizers
def create_learning_rate_fn(
    train_ds_size: int,
    train_batch_size: int,
    num_train_epochs: int,
    num_warmup_steps: int,
    learning_rate: float,
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear_decay learning rate function.

    Args:
        train_ds_size: Number of training examples.
        train_batch_size: Number of examples consumed per optimizer step.
        num_train_epochs: Number of passes over the training set.
        num_warmup_steps: Steps spent ramping from 0 to `learning_rate`.
        learning_rate: Peak learning rate reached at the end of warmup.

    Returns:
        A schedule mapping the step count to a learning rate.
    """
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
    )
    decay_fn = optax.linear_schedule(
        init_value=learning_rate,
        end_value=0,
        transition_steps=num_train_steps - num_warmup_steps,
    )
    schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
    )
    return schedule_fn


# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
    training_size,
    train_batch_size,
    num_train_epochs,
    warmup_steps,
    learning_rate,
)


# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
    """Build the weight-decay mask (True = apply decay) for `params`."""
    flat_params = traverse_util.flatten_dict(params)
    # find out all LayerNorm parameters
    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
    layer_norm_named_params = {
        layer[-2:]
        for layer_norm_name in layer_norm_candidates
        for layer in flat_params.keys()
        if layer_norm_name in "".join(layer).lower()
    }
    flat_mask = {
        path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params)
        for path in flat_params
    }
    return traverse_util.unflatten_dict(flat_mask)


# create adam optimizer
adamw = optax.adamw(
    learning_rate=linear_decay_lr_schedule_fn,
    b1=adam_beta1,
    b2=adam_beta2,
    eps=adam_epsilon,
    weight_decay=weight_decay,
    mask=decay_mask_fn,
)

# %%
print("compile")


# enable bf16
# enable only for dense, some transformer sections, and shared
def create_mask_for_layer_norm(params):
    """Return a boolean mask marking the parameters selected for bf16 casting.

    A leaf is True when its last two path components are one of the
    (module, leaf-name) pairs listed below; every other leaf is False.
    """
    # NOTE(review): 'wi'/'wo' are matched against a leaf called "weight" while
    # 'k'/'q'/'v' are matched against "kernel" - confirm against the actual
    # parameter names of the model before relying on this mask.
    selected = {
        ("wi", "weight"),
        ("wo", "weight"),
        ("k", "kernel"),
        ("q", "kernel"),
        ("v", "kernel"),
        ("shared", "embedding"),
    }
    # Previous (inverse) formulation, kept for reference:
    # path: not (
    #     (path[-2] == "layer_norm" and path[-1] == "weight") or
    #     (path[-2] == "final_layer_norm" and path[-1] == "weight") or
    #     (path[-2] == "o" and path[-1] == "kernel")
    # )
    flat_params = traverse_util.flatten_dict(params)
    mask = {path: tuple(path[-2:]) in selected for path in flat_params}
    return traverse_util.unflatten_dict(mask)


# borrowed from transformers modeling_flax_utils
def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
    """
    Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.

    When `mask` is given, only leaves whose mask entry is truthy are cast.
    """

    # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
    def conditional_cast(param):
        if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
            param = param.astype(dtype)
        return param

    if mask is None:
        return jax.tree_util.tree_map(conditional_cast, params)

    flat_params = traverse_util.flatten_dict(params)
    flat_mask, _ = jax.tree_util.tree_flatten(mask)

    # Mask leaves are paired with the parameter paths in sorted order.
    for masked, key in zip(flat_mask, sorted(flat_params.keys())):
        if masked:
            flat_params[key] = conditional_cast(flat_params[key])

    return traverse_util.unflatten_dict(flat_params)


# create init_fn to produce sharded state
def init_fn(params, model, optimizer):
    """Create the `TrainState` from pretrained `params`.

    Args:
        params: Parameter tree to train.
        model: Module whose `apply` becomes the state's `apply_fn`.
        optimizer: Gradient transformation stored as the state's `tx`.

    Returns:
        A freshly created `train_state.TrainState`.
    """
    # do be careful with the model init
    # imported models might have complicated init methods
    # mask = create_mask_for_layer_norm(params)
    # override params with bfloat version
    # params= cast_floating_to(params, jnp.bfloat16, mask)
    state = train_state.TrainState.create(  # Create a `TrainState`.
        apply_fn=model.apply,
        params=params,
        tx=optimizer,
    )
    return state


# Trace init_fn abstractly (shapes/dtypes only) to learn the state structure.
abstract_variables = jax.eval_shape(
    functools.partial(init_fn, model=model, optimizer=adamw), params)
# jax.sharding: describes how a jax.Array is laid out across devices
state_sharding = nn.get_sharding(abstract_variables, mesh)
# print(state_sharding)

# %%
# replace the params tree with the new modified tree
# create partitions for model
from parallel.partitions import set_partitions

# set_partitions freezes the params on return
model_part_spec = set_partitions(unfreeze(params))
# p is already a partition spec
model_named_sharding = jax.tree.map(mesh_sharding, model_part_spec)


# get pspec for opt_state
def get_opt_spec(x):
    """Map a state sub-tree to its sharding.

    Dict sub-trees get the model's named sharding; anything else gets an
    empty PartitionSpec.
    """
    if isinstance(x, dict):
        return unfreeze(model_named_sharding)
    # return an empty partspec
    return mesh_sharding(PartitionSpec())


# this function replaces the empty model params spec with the 'model_named_shard'
state_sharding = jax.tree.map(
    get_opt_spec,
    state_sharding,
    is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
)

jit_init_fn = jax.jit(
    init_fn,
    static_argnames=('model', 'optimizer'),  # skip model and optimizer
    in_shardings=mesh_sharding(PartitionSpec()),  # we don't shard params explicitly
    out_shardings=state_sharding,  # but returned initialized_state is sharded
)
initialized_state = jit_init_fn(params, model, adamw)


# %%
# train step
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    """Label-smoothed cross entropy averaged over non-padded positions.

    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104

    Args:
        logits: Model outputs; the last axis is the vocabulary.
        labels: Target token ids.
        padding_mask: Multiplied element-wise into the per-token loss, so
            zero entries remove a position from the loss.
        label_smoothing_factor: Probability mass spread over wrong classes.

    Returns:
        Scalar loss.
    """
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing_factor
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    normalizing_constant = -(
        confidence * jnp.log(confidence)
        + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    )
    soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)

    # Compute the loss in float32 regardless of the model dtype.
    logits = jnp.asarray(logits, dtype=jnp.float32)
    soft_labels = soft_labels.astype(jnp.float32)
    loss = optax.softmax_cross_entropy(logits, soft_labels)
    loss = loss - normalizing_constant

    # ignore padded tokens from loss
    loss = loss * padding_mask
    # Normalise by the number of unmasked tokens. A plain `.mean()` would also
    # divide by the padded positions, making the loss depend on how much
    # padding a batch happens to contain.
    num_labels = padding_mask.sum()
    return loss.sum() / jnp.maximum(num_labels, 1)


# %%
# gradient accumulation
def accumulate_gradients_loop(
    state,
    batch,
    minibatch_size: int,
    loss_fn: Callable,
) -> Tuple[PyTree, Metrics]:
    """Calculate gradients and metrics for a batch using gradient accumulation.

    Args:
        state: Current training state.
        batch: Full training batch; split along the leading axis.
        minibatch_size: Number of examples per accumulation step. The batch is
            cut into `batch_size // minibatch_size` minibatches; a trailing
            remainder is not used.
        loss_fn: Called as `loss_fn(params, minibatch)`; returns a scalar loss.

    Returns:
        Tuple of the gradients averaged over the minibatches and a metrics
        dict with the mean minibatch "loss" and the "learning_rate" for the
        current step.

    Raises:
        ValueError: If the batch is smaller than `minibatch_size`.
    """
    batch_size = batch['input_ids'].shape[0]
    num_minibatches = batch_size // minibatch_size
    if num_minibatches == 0:
        raise ValueError(
            f"batch size {batch_size} is smaller than minibatch_size {minibatch_size}"
        )

    # Define gradient function for single minibatch.
    # If has_aux is True then a tuple of ((value, auxiliary_data), gradient) is returned.
    # otherwise it returns (value, gradient), where value is the actual output
    # of the function, hence the "value" of the namesake
    grad_fn = jax.value_and_grad(loss_fn, has_aux=False)

    # Prepare loop variables.
    grads = None
    loss_sum = None
    for minibatch_idx in range(num_minibatches):
        with jax.named_scope(f"minibatch_{minibatch_idx}"):
            # Split the batch into minibatches.
            start = minibatch_idx * minibatch_size
            end = start + minibatch_size
            minibatch = jax.tree.map(lambda x: x[start:end], batch)  # noqa: B023
            # Calculate the loss and gradients for the minibatch.
            loss, step_grads = grad_fn(state.params, minibatch)

        # Accumulate gradients and loss across minibatches.
        if grads is None:
            grads, loss_sum = step_grads, loss
        else:
            grads = jax.tree.map(jnp.add, grads, step_grads)
            loss_sum = loss_sum + loss

    # Average gradients over minibatches.
    grads = jax.tree.map(lambda g: g / num_minibatches, grads)

    with jax.named_scope("sync_metrics"):
        # The loss is averaged like the gradients; the learning rate depends
        # only on the step, so it is read once rather than summed per minibatch.
        metrics = {
            "loss": loss_sum / num_minibatches,
            "learning_rate": linear_decay_lr_schedule_fn(state.step),
        }
    return grads, metrics


# Number of examples per gradient-accumulation step inside `train_step`.
grad_accum_minibatch_size = 32


# single device code annotated with jax.jit
@functools.partial(
    jax.jit,
    # state is state_sharding initialized from init_fn
    # x_sharding is data sharded explicitly later
    in_shardings=(state_sharding, x_sharding),
    out_shardings=(state_sharding, mesh_sharding(PartitionSpec())),
    donate_argnames=('state',),
)
def train_step(state, batch):
    """Run one optimizer step with gradient accumulation.

    Args:
        state: Current `TrainState`; donated to the computation.
        batch: Dict with 'input_ids', 'attention_mask', 'decoder_input_ids',
            'decoder_attention_mask' and 'labels'.

    Returns:
        Tuple of the updated state and the step metrics dict.
    """
    label_smoothing_factor = 0.0
    # dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

    def compute_loss(params, batch):
        # check constraints
        # frozen dict not allowed as sharding object
        params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding))
        batch = jax.lax.with_sharding_constraint(batch, x_sharding)
        logits = state.apply_fn(
            {'params': params},
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            decoder_input_ids=batch['decoder_input_ids'],
            decoder_attention_mask=batch['decoder_attention_mask'],
        )[0]  # zero because output is some structure, where first is the logit
        # use labels here
        # loss, num_labels = loss_fn(
        loss = loss_fn(
            logits,
            batch["labels"],
            batch["decoder_attention_mask"],
            label_smoothing_factor)
        return loss  # , num_labels

    # # compute gradients through computational graph
    # # allow values to pass through
    # grad_fn = jax.value_and_grad(compute_loss, has_aux=False)
    # (loss), grad = grad_fn(state.params, batch)
    # # num_labels = jax.lax.psum(num_labels, "batch")
    # new_state = state.apply_gradients(grads=grad)
    # with jax.named_scope("sync_metrics"):
    #     step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}

    # use gradient accumulation
    grads, step_metrics = accumulate_gradients_loop(
        state=state,
        batch=batch,
        minibatch_size=grad_accum_minibatch_size,
        loss_fn=compute_loss,
    )
    new_state = state.apply_gradients(grads=grads)

    return new_state, step_metrics


# %%
# explore data sharding
sharded_batch = next(iter(train_loader))
sharded_batch = jax.device_put(sharded_batch, x_sharding)
# jax.debug.visualize_array_sharding(sharded_batch['input_ids'])
# jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding'])

# %%
# # prep 1 step
# print("1 step for jit-ting")
# with mesh:
#     state, metrics = train_step(initialized_state, sharded_batch)

# %%
# %%
# tr
print("***** Running training *****")
print(f" Num examples = {training_size}")
print(f" Num Epochs = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
print(f" Total optimization steps = {total_train_steps}")

# %%
# jax.profiler.start_trace("./traces")

print("*" * 10)
print("training start")

train_time = 0
state = initialized_state

steps_per_epoch = training_size // train_batch_size
if steps_per_epoch == 0:
    raise ValueError(
        f"training set ({training_size} examples) is smaller than one "
        f"global batch ({train_batch_size})"
    )

epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
    train_start = time.time()

    # Create sampling rng: a fresh key per epoch so each epoch gets its own
    # shuffle instead of reusing the same key every time.
    rng, input_rng = jax.random.split(rng)
    train_metrics = []

    # The loader must yield the same global batch size that steps_per_epoch,
    # the LR schedule and the banner above are computed from; otherwise an
    # "epoch" only covers part of the training set.
    train_loader = dataprep.data_loader(
        input_rng, batch_size=train_batch_size, shuffle=True, drop_last=True
    )
    train_iter = iter(train_loader)

    # Generate an epoch by shuffling sampling indices from the train dataset
    for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
        batch = next(train_iter)
        batch = jax.device_put(batch, x_sharding)
        with mesh:
            state, train_metric = train_step(state, batch)
        # train_metrics.append(train_metric)

    # Wait for the last step once per epoch so the measured time includes the
    # work that was dispatched asynchronously.
    train_metric['loss'].block_until_ready()
    train_time = time.time() - train_start

    epochs.write(
        f"Epoch... ({epoch + 1}/{num_epochs} | "
        f"Loss: {train_metric['loss']}, "
        f"Learning Rate:{train_metric['learning_rate']}, "
        f"Last train time: {train_time})"
    )

# jax.profiler.stop_trace()

# %%
# try out
gather_state = jax.device_get(state)
gather_batch = jax.device_get(batch)
logits = gather_state.apply_fn(
    {'params': gather_state.params},
    input_ids=gather_batch['input_ids'],
    attention_mask=gather_batch['attention_mask'],
    decoder_input_ids=gather_batch['decoder_input_ids'],
    decoder_attention_mask=gather_batch['decoder_attention_mask'],
)[0]  # zero because output is some structure, where first is the logit
probs = nn.softmax(logits, axis=-1)
predicted = jnp.argmax(probs, axis=-1)
print("sample output")
print(predicted[1])

# %%
main_model = custom_model.from_pretrained('t5-base')
output_dir = save_path
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
    params = jax.device_get(state.params)
    main_model.save_pretrained(output_dir, params=params)

# %%