# %% [markdown]
# # T5 implementation using JAX with pjit

# MARK: START
# %%
# optionally simulate an 8-device setup on CPU
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = False

flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    flags += " --xla_force_host_platform_device_count=8"  # simulate 8 devices
    # enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    os.environ["JAX_PLATFORMS"] = "cpu"
else:
    # GPU flags
    flags += (
        " --xla_gpu_enable_triton_softmax_fusion=true "
        # "--xla_gpu_triton_gemm_any=false "
        # "--xla_gpu_enable_async_collectives=true "
        # "--xla_gpu_enable_latency_hiding_scheduler=true "
        # "--xla_gpu_enable_highest_priority_async_stream=true "
    )
os.environ["XLA_FLAGS"] = flags

import functools
from functools import partial
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.experimental.pjit import pjit
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from ml_collections import ConfigDict
import optax
import logging
import time
from datasets import Dataset, load_from_disk

from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
import flax.core

from tqdm import tqdm

from dataload import DataPrepare

PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

if USE_CPU_ONLY:
    jax.config.update("jax_platform_name", "cpu")
else:
    jax.config.update("jax_default_matmul_precision", "bfloat16")

# # %%
# import jax
# import jax.numpy as jnp
# import optax
# import numpy as np
# from functools import partial
# from typing import Callable, Optional
# import math
#
# # jax.config.update("jax_default_matmul_precision", "tensorfloat32")
# jax.config.update("jax_default_matmul_precision", "bfloat16")
# # jax.config.update("jax_enable_x64", False)
# # enable cache
# jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
#
# # from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
#
# from flax import jax_utils, traverse_util
# from flax.jax_utils import pad_shard_unpad, unreplicate
# from flax.training import train_state
# from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
# import flax.core

# %%
# get platform type
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
save_path = 't5_80_1_bf16'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])

# training constants
seed = 117
num_epochs = 5
batch_size = 384  # 384 works best on this setup
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs

warmup_steps = 0
learning_rate = 2e-5

weight_decay = 0.01
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0

num_beams = 1
val_max_target_length = 128

predict_with_generate = True
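# %% [markdown]
# Optional sanity check: the batch-size arithmetic above assumes `jax.device_count()`
# devices are visible, either real GPUs or the 8 simulated CPU devices from
# `--xla_force_host_platform_device_count=8`. The cell below only prints values
# already defined in this script; it introduces no new names.

# %%
print(f"devices visible: {jax.device_count()} -> {jax.devices()}")
print(f"global train batch size: {train_batch_size}")
print(f"steps per epoch: {steps_per_epoch}, total steps: {total_train_steps}")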
# %%
# prepare data: initialize the DataPrepare object with its config
data_config = ConfigDict(
    dict(
        max_length=86,
        pad_token_id=0,
        decoder_start_token_id=0,
    )
)

dataprep = DataPrepare(split_datasets['train'], data_config)
# # example usage
# # %%
# seed = 117
# rng = jax.random.PRNGKey(seed)
# train_loader = dataprep.data_loader(rng, batch_size=1)
# batch = next(iter(train_loader))

# %%
# model
from transformers import FlaxT5ForConditionalGeneration
from transformers import T5Config

config = T5Config()

model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")

# gradient checkpointing trades compute for memory; useful for transformer models
model.enable_gradient_checkpointing()

# Cast parameters to bf16, except layer norm weights.
# If you don't want to cast certain parameters (for example layer norm scale and bias),
# pass a mask as follows.
flat_params = traverse_util.flatten_dict(model.params)
mask = {
    path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
}
mask = traverse_util.unflatten_dict(mask)
model.params = model.to_bf16(model.params, mask)

# %%
# device mesh and shardings
from jax.sharding import NamedSharding
from jax.experimental import mesh_utils
from pjit_partition import set_partitions

devices = np.asarray(jax.devices())
mesh_axis_names = ("data",)
mesh = Mesh(devices, mesh_axis_names)
# shard along the leading (batch) dimension over the data axis
sharding = NamedSharding(mesh, P(*mesh_axis_names))
# fully replicated (used for the train state and scalar outputs)
replicated_sharding = NamedSharding(mesh, P())

# %% [markdown]
# # Model

# %%
# Initialize our training
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)

# %%
# optimization functions

def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear decay learning rate function."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate,
        end_value=0,
        transition_steps=num_train_steps - num_warmup_steps,
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn


# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
    training_size,
    train_batch_size,
    num_train_epochs,
    warmup_steps,
    learning_rate,
)
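# %% [markdown]
# Optional check: with `warmup_steps = 0`, the joined schedule above reduces to a
# plain linear decay from `learning_rate` down to 0 over `total_train_steps`. The
# cell below only evaluates the schedule at a few steps; it uses no new names.

# %%
for step in [0, total_train_steps // 2, total_train_steps - 1]:
    print(f"step {step}: lr = {float(linear_decay_lr_schedule_fn(step)):.2e}")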
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# boolean mask with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    # find out all LayerNorm parameters
    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
    layer_norm_named_params = {
        layer[-2:]
        for layer_norm_name in layer_norm_candidates
        for layer in flat_params.keys()
        if layer_norm_name in "".join(layer).lower()
    }
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)


# create adamw optimizer
adamw = optax.adamw(
    learning_rate=linear_decay_lr_schedule_fn,
    b1=adam_beta1,
    b2=adam_beta2,
    eps=adam_epsilon,
    weight_decay=weight_decay,
    mask=decay_mask_fn,
)

# %%
# state will serve as our "params": it bundles params, optimizer state, and the
# dropout rng. Flax's stock TrainState has no dropout_rng field, so extend it.
class TrainState(train_state.TrainState):
    dropout_rng: jnp.ndarray


state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=adamw,
    dropout_rng=dropout_rng,
)


# label smoothed cross entropy
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    """
    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
    """
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing_factor
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    normalizing_constant = -(
        confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    )
    soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)

    loss = optax.softmax_cross_entropy(logits, soft_labels)
    loss = loss - normalizing_constant

    # ignore padded tokens from loss
    loss = loss * padding_mask
    loss = loss.sum()
    num_labels = padding_mask.sum()
    return loss, num_labels


# MARK: train_step
# Define gradient update step fn
def train_step(state, batch):
    label_smoothing_factor = 0.0
    dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

    def compute_loss(params):
        labels = batch.pop("labels")
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
        return loss, num_labels

    # compute gradients through computational graph
    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, num_labels), grad = grad_fn(state.params)

    # Under jax.jit the step sees the logical global batch, so loss and num_labels
    # are already global totals; the compiler inserts any cross-device reductions
    # (no jax.lax.psum over a named axis as in the pmap version).
    # true loss = total loss / total labels (left unnormalized here)
    # loss = loss / num_labels

    # true grad = total grad / total labels
    grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)

    new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

    metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
    return new_state, metrics


# max_length = (
#     val_max_target_length if val_max_target_length is not None else model.config.max_length
# )
# num_beams = num_beams if num_beams is not None else model.config.num_beams
# gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

# Create the compiled, sharded version of the train step:
# the state is replicated, the batch is sharded along the data axis.
p_train_step = jax.jit(
    train_step,
    # state for first, batch for second
    in_shardings=(replicated_sharding, sharding),
    out_shardings=(replicated_sharding, replicated_sharding),
    donate_argnames=("state",),
)
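# %% [markdown]
# The comment above mentions a parallel train *and* eval step, but only `train_step`
# is defined. Below is a minimal sketch (an assumption, not part of the original
# script, and never called in the loop) of what a matching `eval_step` could look
# like, reusing `loss_fn` and jitted with the same shardings as `p_train_step`.

# %%
def eval_step(state, batch):
    labels = batch.pop("labels")
    logits = state.apply_fn(**batch, params=state.params, train=False)[0]
    loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
    # report per-label loss so batches with different padding are comparable
    return {"loss": loss / num_labels}


p_eval_step = jax.jit(
    eval_step,
    # state replicated, batch sharded along the data axis (same as p_train_step)
    in_shardings=(replicated_sharding, sharding),
    out_shardings=replicated_sharding,
)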
# %%
print("***** Running training *****")
print(f"  Num examples = {training_size}")
print(f"  Num Epochs = {num_epochs}")
print(f"  Instantaneous batch size per device = {per_device_train_batch_size}")
print(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}")
print(f"  Total optimization steps = {total_train_steps}")

# %%
# jax.profiler.start_trace("./traces")

# Explicitly place the train state on the mesh (fully replicated). jit would do
# this transfer automatically on the first call; doing it up front makes the
# placement explicit.
state = jax.device_put(state, replicated_sharding)

# Example: a single batch sharded along the data axis. The training loop below
# feeds host batches directly and lets `in_shardings` handle the transfer.
# example_batch = next(dataprep.data_loader(rng, batch_size=batch_size))
# sharded_batch = jax.device_put(example_batch, sharding)

# %%
rng, input_rng = jax.random.split(rng)

train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
    train_start = time.time()

    # Create sampling rng
    train_metrics = []
    rng, data_rng = jax.random.split(rng)
    train_loader = dataprep.data_loader(data_rng, batch_size=batch_size)
    steps_per_epoch = training_size // train_batch_size
    # Generate an epoch by iterating over the shuffled train dataset
    for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
        batch = next(train_loader)
        # batch = shard(batch)  # only needed for pmap; jit shards via in_shardings
        state, train_metric = p_train_step(state, batch)
        train_metrics.append(train_metric)

    train_time = time.time() - train_start

    # with jit (unlike pmap) the metrics are already global scalars, so no unreplicate
    train_metric['loss'].block_until_ready()

    epochs.write(
        # f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, "
        f"Epoch... ({epoch + 1}/{num_epochs} | "
        # f"Learning Rate: {train_metric['learning_rate']}, "
        f"Last train time: {train_time})"
    )
# jax.profiler.stop_trace()

# %%
# output_dir = save_path
# # save checkpoint after each epoch and push checkpoint to the hub
# if jax.process_index() == 0:
#     params = jax.device_get(state.params)
#     params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
#     model.save_pretrained(output_dir, params=params)
#     tokenizer.save_pretrained(output_dir)
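# %% [markdown]
# `predict_with_generate`, `num_beams`, and `val_max_target_length` are defined above
# but never used. Below is a rough sketch of how they could drive generation for
# evaluation: `generate_step` is a hypothetical helper, the exact `generate` keyword
# arguments depend on the installed transformers version, and the batch is assumed to
# provide `input_ids` and `attention_mask` as produced by the data loader.

# %%
# hypothetical helper: decoding with the (otherwise unused) generation settings above
def generate_step(params, batch):
    output_ids = model.generate(
        batch["input_ids"],
        attention_mask=batch["attention_mask"],
        params=params,
        max_length=val_max_target_length,
        num_beams=num_beams,
    )
    return output_ids.sequences


# example usage (assumes `batch` comes from dataprep.data_loader):
# pred_ids = generate_step(state.params, batch)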