# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.16.4
#   kernelspec:
#     display_name: jax
#     language: python
#     name: python3
# ---

# %% [markdown]
# # T5 implementation using jax

# %% [markdown]
# ## import

# %% [raw]
# import json
# import logging
# import math
# import os
# import sys
# import time
# from dataclasses import asdict, dataclass, field
# from enum import Enum
# from functools import partial
# from pathlib import Path
# from typing import Callable, Optional
#
# import datasets
# import evaluate
# import jax
# import jax.numpy as jnp
# import nltk  # Here to have a nice missing dependency error message early on
# import numpy as np
# import optax
# from datasets import Dataset, load_dataset
# from filelock import FileLock
# from flax import jax_utils, traverse_util
# from flax.jax_utils import pad_shard_unpad, unreplicate
# from flax.training import train_state
# from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
# from tqdm import tqdm
#
# import transformers
# from transformers import (
#     CONFIG_MAPPING,
#     FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
#     AutoConfig,
#     AutoTokenizer,
#     FlaxAutoModelForSeq2SeqLM,
#     HfArgumentParser,
#     is_tensorboard_available,
# )
# from transformers.utils import is_offline_mode, send_example_telemetry
#
#
# logger = logging.getLogger(__name__)
#
# try:
#     nltk.data.find("tokenizers/punkt")
# except (LookupError, OSError):
#     if is_offline_mode():
#         raise LookupError(
#             "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
#         )
#     with FileLock(".lock") as lock:
#         nltk.download("punkt", quiet=True)
#
#
# MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
# MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

# %%
import jax
import jax.numpy as jnp
import optax
import numpy as np
from functools import partial
from typing import Callable, Optional
import math

# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "high")
jax.config.update("jax_enable_x64", False)

from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
from transformers.utils import is_offline_mode  # used by the nltk download guard below
import datasets
from datasets import Dataset, load_dataset
import evaluate
import nltk  # Here to have a nice missing dependency error message early on
from filelock import FileLock  # used by the nltk download guard below
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
import time

# %%
import os

os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_triton_softmax_fusion=True '
    '--xla_gpu_triton_gemm_any=True '
)

os.environ.update({
    "CUDA_VISIBLE_DEVICES": "0, 1, 2, 3",
    "NCCL_LL128_BUFFSIZE": "-2",
    "NCCL_LL_BUFFSIZE": "-2",
    "NCCL_PROTO": "SIMPLE,LL,LL128",
})

# %%
from jax.lib import xla_bridge

print(xla_bridge.get_backend().platform)
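
# Optional check (illustrative, not required for training): the pmap-based training loop
# below shards every batch across all visible accelerators, so it is worth confirming how
# many devices JAX actually sees after setting CUDA_VISIBLE_DEVICES above.
print(jax.device_count(), "device(s):", jax.devices())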

# %%
# nltk.download('punkt')
try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    if is_offline_mode():
        raise LookupError(
            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
        )
    with FileLock(".lock") as lock:
        nltk.download("punkt", quiet=True)

# %% [markdown]
# ## Prepare datasets

# %%
# load model
model_name_or_path = "t5-small"  # Replace with your specific model name

# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path)

# Load model
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
    model_name_or_path
)

# %%
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")

# %%
from tqdm import tqdm
from datasets import load_from_disk

# Path to saved combined_dataset
file_path = '/home/richard/Projects/learn_t5/retrieval/combined_data_t5'
save_path = 't5_80_1_retrieval'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)

# prepare tokenizer
from transformers import T5TokenizerFast

# Note: the tokenizer checkpoint ("t5-base") differs from the model checkpoint ("t5-small");
# both share the same T5 SentencePiece vocabulary, so the token ids line up.
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)

# Define additional special tokens
# additional_special_tokens = ["", "", "", "", "", "", "SIG", "UNIT", "DATA_TYPE"]
# Define additional special tokens
additional_special_tokens = ["", "", "", "", "", "", "", "", "", ""]

# Add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})

max_length = 300

# In Flax, seq2seq models do not accept `labels`, so we need to pass `decoder_input_ids`
# and prepare them ourselves. For that we dynamically import the `shift_tokens_right`
# function from the model file (see the cell above).

# Given a dataset entry, run it through the tokenizer.
# Setting padding="max_length" as we need fixed length inputs for jitted functions.
def preprocess_function(example):
    input = example['input']
    target = example['output']
    # `text_target` tokenizes the targets in the same call and stores them under "labels",
    # so there is no need to build them by hand here
    model_inputs = tokenizer(
        input,
        text_target=target,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np"
    )
    # tokenize the targets on their own as well, so their attention mask can be reused
    # as the decoder attention mask
    labels = tokenizer(
        text_target=target,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np"
    )

    model_inputs["labels"] = labels["input_ids"]
    decoder_input_ids = shift_tokens_right_fn(
        labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
    )
    model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)

    # We need decoder_attention_mask so we can ignore pad tokens in the loss
    model_inputs["decoder_attention_mask"] = labels["attention_mask"]

    return model_inputs

# `map` applies the preprocessing function to each batch of rows in the dataset
tokenized_datasets = split_datasets.map(
    preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=split_datasets["train"].column_names,
)

# %%
tokenized_datasets

# %%
train_dataset = tokenized_datasets["train"]
eval_dataset = tokenized_datasets["validation"]
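
# %% [markdown]
# Optional sanity check (illustrative sketch, not required for training): for any preprocessed
# example, `decoder_input_ids` should be the label ids shifted one position to the right, with
# the decoder start token prepended. `_example` is a throwaway name used only for this check.

# %%
_example = train_dataset[0]
print("labels            :", _example["labels"][:8])
print("decoder_input_ids :", _example["decoder_input_ids"][:8])
print("decoder start id  :", config.decoder_start_token_id)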

# %%
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
    """
    Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be
    incomplete, and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
    """
    if shuffle:
        batch_idx = jax.random.permutation(rng, len(dataset))
        batch_idx = np.asarray(batch_idx)
    else:
        batch_idx = np.arange(len(dataset))

    if drop_last:
        steps_per_epoch = len(dataset) // batch_size
        batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
        batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    else:
        steps_per_epoch = math.ceil(len(dataset) / batch_size)
        batch_idx = np.array_split(batch_idx, steps_per_epoch)

    for idx in batch_idx:
        batch = dataset[idx]
        batch = {k: np.array(v) for k, v in batch.items()}

        yield batch

# %% [markdown]
# Now we have the model inputs in the variable `tokenized_datasets`.

# %%
# metric
metric = evaluate.load("sacrebleu")

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects a newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels

# def compute_metrics(preds, labels):
#     decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
#     decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
#
#     # Some simple post-processing
#     decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
#
#     result = metric.compute(predictions=decoded_preds, references=decoded_labels)
#     result = {k: round(v * 100, 4) for k, v in result.items()}
#     prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
#     result["gen_len"] = np.mean(prediction_lens)
#     return result

def compute_metrics(preds, labels):
    # In case the model returns more than the prediction logits
    if isinstance(preds, tuple):
        preds = preds[0]

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Replace -100s in the labels as we can't decode them
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing; sacrebleu expects a list of references per prediction
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [[label.strip()] for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    return {"bleu": result["score"]}

# %% [markdown]
# # Model

# %%
# Store some constants
seed = 117
num_epochs = 80
batch_size = 36
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(train_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs

warmup_steps = 0
learning_rate = 5e-5

weight_decay = 0.0
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0

num_beams = 1
val_max_target_length = None

predict_with_generate = True

# %%
# Initialize our training
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)

# %%
# optimization functions

def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear_decay learning rate function."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn
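
# Illustrative example (hypothetical sizes, not the values used below): with 1,000 training
# examples, batch size 10, 2 epochs and 20 warmup steps, the schedule ramps linearly from 0
# to the peak learning rate over the first 20 steps, then decays linearly back to 0 over the
# remaining 180 steps.
_demo_schedule = create_learning_rate_fn(1000, 10, 2, 20, 5e-5)
# _demo_schedule(0) -> 0.0, _demo_schedule(20) -> 5e-5, _demo_schedule(200) -> 0.0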

# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
    len(train_dataset),
    train_batch_size,
    num_train_epochs,
    warmup_steps,
    learning_rate,
)

# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    # find out all LayerNorm parameters
    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
    layer_norm_named_params = {
        layer[-2:]
        for layer_norm_name in layer_norm_candidates
        for layer in flat_params.keys()
        if layer_norm_name in "".join(layer).lower()
    }
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)

# create adam optimizer
adamw = optax.adamw(
    learning_rate=linear_decay_lr_schedule_fn,
    b1=adam_beta1,
    b2=adam_beta2,
    eps=adam_epsilon,
    weight_decay=weight_decay,
    mask=decay_mask_fn,
)

# %%
# Training functions
class TrainState(train_state.TrainState):
    dropout_rng: jnp.ndarray

    def replicate(self):
        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))

# The parameters come from the pretrained checkpoint loaded above
params = model.params

# Cast parameters to bfloat16 if desired
params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)

# Setup train state
state = TrainState.create(apply_fn=model.__call__, params=params_bf16, tx=adamw, dropout_rng=dropout_rng)

# label smoothed cross entropy
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    """
    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
    """
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing_factor
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    normalizing_constant = -(
        confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    )
    soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)

    loss = optax.softmax_cross_entropy(logits, soft_labels)
    loss = loss - normalizing_constant

    # ignore padded tokens from loss
    loss = loss * padding_mask
    loss = loss.sum()
    num_labels = padding_mask.sum()
    return loss, num_labels
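
# Worked example (hypothetical numbers, not executed): with vocab_size = 5 and
# label_smoothing_factor = 0.1, confidence = 0.9 and low_confidence = 0.1 / 4 = 0.025,
# so the soft label for gold id 2 is [0.025, 0.025, 0.9, 0.025, 0.025]. The
# `normalizing_constant` is the entropy of that soft target, so the minimum achievable
# per-token loss is shifted to 0. With label_smoothing_factor = 0.0 (the value used above)
# the loss reduces to plain cross entropy on one-hot labels.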
"learning_rate": linear_decay_lr_schedule_fn(state.step)} return new_state, metrics # Define eval fn def eval_step(params, batch, label_smoothing_factor=0.0): labels = batch.pop("labels") logits = model(**batch, params=params, train=False)[0] loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) num_labels = jax.lax.psum(num_labels, "batch") # true loss = total loss / total samples loss = jax.lax.psum(loss, "batch") loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) metrics = {"loss": loss} return metrics # Define generation function max_length = ( val_max_target_length if val_max_target_length is not None else model.config.max_length ) num_beams = num_beams if num_beams is not None else model.config.num_beams gen_kwargs = {"max_length": max_length, "num_beams": num_beams} def generate_step(params, batch): model.params = params output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs) return output_ids.sequences # Create parallel version of the train and eval step p_train_step = jax.pmap( partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,) ) p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch") p_generate_step = jax.pmap(generate_step, "batch") # Replicate the train state on each device state = state.replicate() # %% print("***** Running training *****") print(f" Num examples = {len(train_dataset)}") print(f" Num Epochs = {num_epochs}") print(f" Instantaneous batch size per device = {per_device_train_batch_size}") print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}") print(f" Total optimization steps = {total_train_steps}") # %% train_time = 0 epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) # epochs = range(num_epochs) for epoch in epochs: # ======================== Training ================================ train_start = time.time() # Create sampling rng rng, input_rng = jax.random.split(rng) train_metrics = [] # Generate an epoch by shuffling sampling indices from the train dataset train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True) steps_per_epoch = len(train_dataset) // train_batch_size # train for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): batch = next(train_loader) batch = shard(batch) state, train_metric = p_train_step(state, batch) train_metrics.append(train_metric) train_time += time.time() - train_start train_metric = unreplicate(train_metric) epochs.write( f"Epoch... 

# %%
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
# epochs = range(num_epochs)
for epoch in epochs:
    # ======================== Training ================================
    train_start = time.time()

    # Create sampling rng
    rng, input_rng = jax.random.split(rng)
    train_metrics = []

    # Generate an epoch by shuffling sampling indices from the train dataset
    train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
    steps_per_epoch = len(train_dataset) // train_batch_size
    # train
    for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
        batch = next(train_loader)
        batch = shard(batch)
        state, train_metric = p_train_step(state, batch)
        train_metrics.append(train_metric)

    train_time += time.time() - train_start

    train_metric = unreplicate(train_metric)

    epochs.write(
        f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
        f" {train_metric['learning_rate']})"
    )

    # ======================== Evaluating ==============================
    # eval_metrics = []
    # eval_preds = []
    # eval_labels = []
    #
    # eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
    # eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
    # for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
    #     # Model forward
    #     batch = next(eval_loader)
    #     labels = batch["labels"]
    #
    #     metrics = pad_shard_unpad(p_eval_step, static_return=True)(
    #         state.params, batch, min_device_batch=per_device_eval_batch_size
    #     )
    #     eval_metrics.append(metrics)
    #
    #     # generation
    #     if predict_with_generate:
    #         generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
    #         eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
    #         eval_labels.extend(labels)
    #
    # # normalize eval metrics
    # eval_metrics = get_metrics(eval_metrics)
    # eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

    # compute metrics
    # rouge_desc = ""
    # if predict_with_generate:
    #     rouge_metrics = compute_metrics(eval_preds, eval_labels)
    #     eval_metrics.update(rouge_metrics)
    #     rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
    #
    # # Print metrics and update progress bar
    # desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
    # epochs.write(desc)
    # epochs.desc = desc

    # Save metrics
    # if has_tensorboard and jax.process_index() == 0:
    #     cur_step = epoch * (len(train_dataset) // train_batch_size)
    #     write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)

    output_dir = save_path
    # save checkpoint after each epoch and push checkpoint to the hub
    if jax.process_index() == 0:
        params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
        model.save_pretrained(output_dir, params=params)
        tokenizer.save_pretrained(output_dir)
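
# %% [markdown]
# Optional: reload the saved checkpoint for a quick generation test. This is a minimal sketch,
# assuming the training loop above has finished and written the model and tokenizer to
# `save_path`; the prompt string and the `reloaded_*` names are placeholders to adapt to your
# own input format.

# %%
reloaded_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(save_path)
reloaded_tokenizer = T5TokenizerFast.from_pretrained(save_path)
sample = reloaded_tokenizer(["example input"], return_tensors="np")
out = reloaded_model.generate(sample["input_ids"], attention_mask=sample["attention_mask"], **gen_kwargs)
print(reloaded_tokenizer.batch_decode(out.sequences, skip_special_tokens=True))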