# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.16.4
#   kernelspec:
#     display_name: jax
#     language: python
#     name: python3
# ---

# %% [markdown]
# # T5 implementation using jax

# %%
# Runtime environment for XLA / NCCL / tokenizers.
# XLA reads XLA_FLAGS and XLA_PYTHON_CLIENT_* when the JAX backend is first
# initialised, so they are set here, before `jax` is even imported, to make
# sure they take effect.
import os

# os.environ['XLA_FLAGS'] = (
#     '--xla_gpu_triton_gemm_any=true '
#     '--xla_gpu_enable_custom_fusions=true '
#     '--xla_gpu_enable_address_computation_fusion=true'
# )
flags = (
    '--xla_gpu_enable_triton_softmax_fusion=true '
    '--xla_gpu_triton_gemm_any=True '
    # '--xla_gpu_enable_async_collectives=true '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
    '--xla_gpu_enable_highest_priority_async_stream=true '
)
os.environ["XLA_FLAGS"] = flags
os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
    "NCCL_LL128_BUFFSIZE": "-2",
    "NCCL_LL_BUFFSIZE": "-2",
    "NCCL_PROTO": "SIMPLE,LL,LL128",
    "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.8",
    # "XLA_PYTHON_CLIENT_PREALLOCATE": "false"
})

# %%
import jax
import jax.numpy as jnp
import optax
import numpy as np
from functools import partial
from typing import Callable, Optional
import math
import flax.linen as nn

# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "bfloat16")
# jax.config.update("jax_enable_x64", False)

# enable the persistent compilation cache
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

# from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
import datasets
from datasets import Dataset, load_dataset
import evaluate
from tqdm import tqdm
from datasets import load_from_disk
import nltk  # Here to have a nice missing dependency error message early on

from typing import Dict, Any, Union
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.core.frozen_dict import FrozenDict, unfreeze
import flax.core
import time

# %%
# get platform type (e.g. "gpu" / "cpu")
from jax.extend.backend import get_backend

print(get_backend().platform)

# %%
# Check for the NLTK "punkt" data. Nothing else in this notebook calls nltk,
# so a missing resource is reported rather than treated as fatal.
try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    print("nltk 'punkt' tokenizer data not found; run nltk.download('punkt') if it is needed")

# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/original/'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])

# Store some constant
seed = 117
num_epochs = 40
batch_size = 64  # 384 is the best
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
warmup_steps = 0
learning_rate = 2e-4
weight_decay = 0.01
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0
num_beams = 1
val_max_target_length = 128
predict_with_generate = True

# %%
from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained(
    "t5-base", return_tensors="np", clean_up_tokenization_spaces=True
)

# Define additional special tokens
# NOTE(review): every entry is an empty string. These were possibly meant to
# be bracketed marker tokens that got lost somewhere - confirm the intended
# token strings before relying on them.
additional_special_tokens = ["", "", "", "", "", "", "", "", ""]
# Add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
max_length = 128

# %%
len(tokenizer)

# %%
# load pytorch model first
# from transformers import AutoModelForSeq2SeqLM
# model_checkpoint = "t5-base"
# model_pt = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
# # important!
# # after extending tokens vocab
# model_pt.resize_token_embeddings(len(tokenizer))
# model_pt.save_pretrained('./modified_t5_model')
# model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
#     pretrained_model_name_or_path="modified_t5_model",
#     dtype=jax.numpy.bfloat16,
#     from_pt=True
# )

# %%
# model_path = './t5_80_1'
# model_path = 't5-base'
# model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
#     pretrained_model_name_or_path=model_path,
#     dtype=jax.numpy.bfloat16
# )
# from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration
# from t5_model.configuration_t5 import T5Config
from transformers import FlaxT5ForConditionalGeneration
from transformers import T5Config

config = T5Config()
model = FlaxT5ForConditionalGeneration.from_pretrained(
    "t5-base",
    dtype=jnp.bfloat16,
    gradient_checkpointing=True,
)
params = model.params

# (module name, leaf name) pairs whose parameters are cast to bfloat16:
# dense feed-forward kernels, attention q/k/v kernels and the shared
# embedding. Flax `Dense` layers store their matrix under "kernel" (as the
# q/k/v entries already assumed). The previous check for wi/wo "weight" never
# matched, so the feed-forward layers silently stayed in their original dtype.
# The "weight" spellings are kept so any model that does use them still
# matches. wi_0/wi_1 cover gated-FFN T5 variants.
# Earlier variant (kept for reference): cast everything except the layer-norm
# weights and the attention output projection "o" kernel.
_BF16_LEAVES = frozenset({
    ("wi", "kernel"), ("wi", "weight"),
    ("wi_0", "kernel"), ("wi_1", "kernel"),
    ("wo", "kernel"), ("wo", "weight"),
    ("q", "kernel"),
    ("k", "kernel"),
    ("v", "kernel"),
    ("shared", "embedding"),
})


def create_mask_for_layer_norm(params):
    """Build a bool pytree selecting the parameters to cast to bfloat16.

    Despite the name, True marks parameters that ARE cast: dense/attention
    kernels and the shared embedding listed in ``_BF16_LEAVES``. Everything
    else (layer norms, the attention output projection "o", relative
    attention biases, ...) is False and keeps its dtype.

    Args:
        params: nested dict / FrozenDict of model parameters.

    Returns:
        A nested dict with the same structure as ``params`` and bool leaves.
    """
    flat_params = traverse_util.flatten_dict(params)
    mask = {path: path[-2:] in _BF16_LEAVES for path in flat_params}
    return traverse_util.unflatten_dict(mask)


# borrowed from transformers modeling_flax_utils
def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
    """Cast floating-point leaves of the parameter pytree ``params`` to ``dtype``.

    Args:
        params: nested dict / FrozenDict of parameters.
        dtype: target floating dtype, e.g. ``jnp.bfloat16``.
        mask: optional bool pytree. For a (Frozen)dict mask, each parameter
            is looked up by its path and cast only if the mask says True.
            Any other pytree is matched to the parameters in sorted-path order
            (legacy behaviour). ``None`` casts every floating-point leaf.

    Returns:
        The parameters with the selected floating-point leaves cast. When a
        mask is given, the result is a plain nested dict.
    """
    # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
    def conditional_cast(param):
        if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
            param = param.astype(dtype)
        return param

    if mask is None:
        return jax.tree_util.tree_map(conditional_cast, params)

    flat_params = traverse_util.flatten_dict(params)
    if isinstance(mask, (dict, FrozenDict)):
        # Match mask entries to parameters by path, so the result does not
        # depend on two separate flattenings happening to agree in order.
        flat_mask = traverse_util.flatten_dict(mask)
        selected = [key for key in flat_params if flat_mask.get(key, False)]
    else:
        leaves, _ = jax.tree_util.tree_flatten(mask)
        selected = [key for masked, key in zip(leaves, sorted(flat_params.keys())) if masked]

    for key in selected:
        flat_params[key] = conditional_cast(flat_params[key])
    return traverse_util.unflatten_dict(flat_params)


mask = create_mask_for_layer_norm(params)
# override params with bfloat version
params = cast_floating_to(params, jnp.bfloat16, mask)

# %%
# # Function to extract shape and dtype without showing values
# def format_param(param):
#     return f"shape={param.shape}, dtype={param.dtype}"
#
# # Use jax.tree_map to apply the formatter across the parameter tree
# formatted_params = jax.tree.map(format_param, model.params)
#
# # Pretty-print the tree
# for k, v in formatted_params.items():
#     print(f"{k}: {v}")

# %%
# Import `shift_tokens_right` from the module that defines the loaded model
# class, so decoder inputs are built exactly as that model expects.
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")  # noqa: B009


# %%
# In Flax, seq2seq models take `decoder_input_ids` instead of `labels`, so
# they are prepared here with the model's own `shift_tokens_right`.
# padding="max_length" gives the fixed-length inputs jitted functions need.
def preprocess_function(example):
    """Tokenize a batch of input/output text pairs for seq2seq training.

    Args:
        example: a batch from ``Dataset.map(batched=True)``; the ``'input'``
            and ``'output'`` columns are read.

    Returns:
        The tokenizer output for the inputs (``input_ids``,
        ``attention_mask``) extended with ``labels`` (target token ids),
        ``decoder_input_ids`` (the labels shifted right by
        ``shift_tokens_right``) and ``decoder_attention_mask`` (the labels'
        attention mask, used to drop pad positions from the loss).
    """
    inputs = example['input']
    targets = example['output']

    model_inputs = tokenizer(
        inputs,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )
    labels = tokenizer(
        text_target=targets,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )

    # for loss computation
    model_inputs["labels"] = labels["input_ids"]

    # Take pad / decoder-start ids from the loaded checkpoint's config rather
    # than a default-constructed T5Config, so they always match the model.
    decoder_input_ids = shift_tokens_right_fn(
        labels["input_ids"], model.config.pad_token_id, model.config.decoder_start_token_id
    )
    # required by the model
    model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
    # We need decoder_attention_mask so we can ignore pad tokens from loss
    model_inputs["decoder_attention_mask"] = labels["attention_mask"]
    return model_inputs


# %%
# `map` applies preprocess_function to batches of rows. The original text
# columns are removed so only the tokenized fields remain.
token_datasets = split_datasets.map(
    preprocess_function,
    batched=True,
    num_proc=1,
    # if we do not remove, we keep the original data
    remove_columns=split_datasets["train"].column_names,
)
train_dataset = token_datasets["train"]

# %%
token_datasets.set_format(
    type='numpy',
    columns=[
        'input_ids',
        'attention_mask',
        'labels',
        'decoder_input_ids',
        'decoder_attention_mask',
    ],
)


# %%
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
    """
    Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
    and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.

    An empty dataset yields no batches (previously `drop_last=False` raised
    inside `np.array_split`).
    """
    num_examples = len(dataset)
    if num_examples == 0:
        return

    if shuffle:
        batch_idx = np.asarray(jax.random.permutation(rng, num_examples))
    else:
        batch_idx = np.arange(num_examples)

    if drop_last:
        steps_per_epoch = num_examples // batch_size
        batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
        batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    else:
        steps_per_epoch = math.ceil(num_examples / batch_size)
        batch_idx = np.array_split(batch_idx, steps_per_epoch)

    for idx in batch_idx:
        # copy into a plain dict so the batch can be sharded as a pytree
        yield dict(dataset[idx])


# %%
# Initialize our training
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)


# %%
# optimization functions
def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear_decay learning rate function.

    The rate rises linearly from 0 to ``learning_rate`` over
    ``num_warmup_steps`` and then decays linearly to 0 at the last step,
    where the step count is ``(train_ds_size // train_batch_size) *
    num_train_epochs``.
    """
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn


# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
    len(train_dataset),
    train_batch_size,
    num_train_epochs,
    warmup_steps,
    learning_rate,
)

# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
    """Return a bool pytree that is True where AdamW weight decay applies.

    Parameters named "bias" are excluded. So is any parameter whose
    (module, leaf) name pair occurs on a path containing "layer_norm" or
    "final_layer_norm".
    """
    flat_params = traverse_util.flatten_dict(params)
    norm_names = ("final_layer_norm", "layer_norm")

    # collect the trailing (module, leaf) pairs that live under a layer norm
    norm_leaf_pairs = set()
    for path in flat_params:
        joined = "".join(path).lower()
        if any(name in joined for name in norm_names):
            norm_leaf_pairs.add(path[-2:])

    decay_flags = {
        path: not (path[-1] == "bias" or path[-2:] in norm_leaf_pairs)
        for path in flat_params
    }
    return traverse_util.unflatten_dict(decay_flags)


# create adam optimizer (decay applied only where decay_mask_fn is True)
adamw = optax.adamw(
    learning_rate=linear_decay_lr_schedule_fn,
    b1=adam_beta1,
    b2=adam_beta2,
    eps=adam_epsilon,
    weight_decay=weight_decay,
    mask=decay_mask_fn,
)


# %%
# Training functions
class TrainState(train_state.TrainState):
    """Flax TrainState that also carries the dropout PRNG key."""

    dropout_rng: jnp.ndarray

    def replicate(self):
        """Copy the state onto every local device for data parallelism.

        The whole state is broadcast with `jax_utils.replicate`. The dropout
        key is then swapped for `shard_prng_key(self.dropout_rng)`, which
        gives each device its own key instead of an identical copy.
        """
        replicated = jax_utils.replicate(self)
        per_device_keys = shard_prng_key(self.dropout_rng)
        return replicated.replace(dropout_rng=per_device_keys)


# Setup train state: bundle the model forward fn, the (partially bf16)
# params, the AdamW optimizer and the dropout key.
state = TrainState.create(
    apply_fn=model.__call__,
    params=params,
    tx=adamw,
    dropout_rng=dropout_rng,
)


# label smoothed cross entropy
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    """Label-smoothed cross entropy summed over non-padding positions.

    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104

    Args:
        logits: scores whose last axis is the vocabulary.
        labels: integer target ids, one-hot encoded over that vocabulary.
        padding_mask: multiplies the per-token loss; 0 removes a position.
        label_smoothing_factor: probability mass moved off the target class.

    Returns:
        (summed masked loss, padding_mask.sum()); the caller divides the
        first by the second to get a per-token mean.
    """
    num_classes = logits.shape[-1]
    on_value = 1.0 - label_smoothing_factor
    off_value = (1.0 - on_value) / (num_classes - 1)
    # Entropy of the smoothed target distribution. Subtracting it makes the
    # minimum achievable loss 0. The 1e-20 avoids log(0) when smoothing is 0.
    normalizer = -(
        on_value * jnp.log(on_value)
        + (num_classes - 1) * off_value * jnp.log(off_value + 1e-20)
    )
    targets = onehot(labels, num_classes, on_value=on_value, off_value=off_value)
    per_token = optax.softmax_cross_entropy(logits, targets) - normalizer
    # ignore padded tokens from loss
    masked = per_token * padding_mask
    return masked.sum(), padding_mask.sum()


# Define gradient update step fn
# NOTE: no `@jax.jit` here. `jax.pmap` below already compiles the step, and
# the `psum` collectives over the "batch" axis only make sense inside it.
def train_step(state, batch, label_smoothing_factor=0.0):
    """Run one data-parallel optimisation step on this device's shard.

    Must be called under ``jax.pmap(..., axis_name="batch")``. Loss, token
    count and gradients are summed across devices, so the update uses the
    global per-token mean.

    Args:
        state: replicated TrainState (params, optimizer state, dropout key).
        batch: dict of model inputs plus "labels"; left unmodified.
        label_smoothing_factor: forwarded to ``loss_fn``.

    Returns:
        (new_state, metrics). ``metrics["loss"]`` is the global mean per-token
        loss. ``metrics["learning_rate"]`` is the schedule value at the step
        just taken.
    """
    dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

    labels = batch["labels"]
    # Model inputs without the labels. They are built without `batch.pop`, so
    # `batch` stays intact and compute_loss is side-effect free.
    model_inputs = {k: v for k, v in batch.items() if k != "labels"}

    def compute_loss(params):
        logits = state.apply_fn(**model_inputs, params=params, dropout_rng=dropout_rng, train=True)[0]
        return loss_fn(logits, labels, model_inputs["decoder_attention_mask"], label_smoothing_factor)

    # compute gradients through computational graph
    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, num_labels), grad = grad_fn(state.params)

    # true loss / grad = total over all devices / total scored tokens.
    # The count is clamped to 1 so an all-padding global batch gives 0, not NaN.
    num_labels = jnp.maximum(jax.lax.psum(num_labels, "batch"), 1)
    loss = jax.lax.psum(loss, "batch") / num_labels
    grad = jax.lax.psum(grad, "batch")
    grad = jax.tree_util.tree_map(lambda g: g / num_labels, grad)

    new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
    metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
    return new_state, metrics


# Define generation settings
max_length = val_max_target_length if val_max_target_length is not None else model.config.max_length
num_beams = num_beams if num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

# def generate_step(params, batch):
#     model.params = params
#     output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
#     return output_ids.sequences

# Create parallel version of the train and eval step
p_train_step = jax.pmap(
    partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,)
)
# p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch")
# p_generate_step = jax.pmap(generate_step, "batch")

# Replicate the train state on each device
state = state.replicate()

# %%
print("***** Running training *****")
print(f" Num examples = {len(train_dataset)}")
print(f" Num Epochs = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
print(f" Total optimization steps = {total_train_steps}")

# %%
# jax.profiler.start_trace("./traces")
steps_per_epoch = len(train_dataset) // train_batch_size
if steps_per_epoch == 0:
    raise ValueError(
        f"train_dataset has {len(train_dataset)} examples, fewer than one global batch "
        f"({train_batch_size}); nothing to train on"
    )

train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
    train_start = time.time()

    # Create sampling rng: a fresh key per epoch so each epoch gets a new
    # shuffle. Splitting once outside the loop reused one permutation for
    # every epoch.
    rng, input_rng = jax.random.split(rng)

    train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
    # Generate an epoch by shuffling sampling indices from the train dataset
    for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
        batch = shard(next(train_loader))
        state, train_metric = p_train_step(state, batch)

    train_metric = unreplicate(train_metric)
    # JAX dispatch is asynchronous: wait for the last step before timing.
    train_metric['loss'].block_until_ready()
    train_time = time.time() - train_start

    epochs.write(
        # f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, "
        f"Epoch... ({epoch + 1}/{num_epochs} | "
        f"Learning Rate:{train_metric['learning_rate']}, "
        f"Last train time: {train_time})"
    )
# jax.profiler.stop_trace()

# %%
output_dir = save_path
# Save the final checkpoint once, after training, from the first process only.
if jax.process_index() == 0:
    # take device 0's copy of the replicated params and move it to host
    params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
    # upcast (partially bf16) params so the checkpoint is stored in float32
    params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
    model.save_pretrained(output_dir, params=params)
    tokenizer.save_pretrained(output_dir)

# %%