From d2dd72227f0a462886be239a9157fac9a093cbc0 Mon Sep 17 00:00:00 2001 From: Richard Wong Date: Sat, 14 Sep 2024 02:02:45 +0900 Subject: [PATCH] Feat: introduced efficient train data dtype, jit train step, bfloat16 mat mul --- .gitignore | 1 + t5_jax.py | 227 +++++++++++++++++++++++++++---------------- t5_jax_prediction.py | 2 +- 3 files changed, 144 insertions(+), 86 deletions(-) diff --git a/.gitignore b/.gitignore index d20f0d4..ef0f135 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ t5_*/ exports/ modified_t5_model/ +traces/ diff --git a/t5_jax.py b/t5_jax.py index 952d1a3..e35e0e5 100644 --- a/t5_jax.py +++ b/t5_jax.py @@ -16,6 +16,7 @@ # %% [markdown] # # T5 implementation using jax + # %% import jax import jax.numpy as jnp @@ -26,15 +27,15 @@ from typing import Callable, Optional import math # jax.config.update("jax_default_matmul_precision", "tensorfloat32") -jax.config.update("jax_default_matmul_precision", "high") -jax.config.update("jax_enable_x64", False) +jax.config.update("jax_default_matmul_precision", "bfloat16") +# jax.config.update("jax_enable_x64", False) # enable cache jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) -from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig +# from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig import datasets @@ -56,6 +57,7 @@ import flax.core import time + # %% import os os.environ['XLA_FLAGS'] = ( @@ -63,14 +65,17 @@ os.environ['XLA_FLAGS'] = ( ) os.environ.update({ + "TOKENIZERS_PARALLELISM" : "false", "CUDA_DEVICE_MAX_CONNECTIONS" : "1", "NCCL_LL128_BUFFSIZE": "-2", "NCCL_LL_BUFFSIZE": "-2", "NCCL_PROTO": "SIMPLE,LL,LL128", - "XLA_PYTHON_CLIENT_MEM_FRACTION" : ".95" + "XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.99", + # "XLA_PYTHON_CLIENT_PREALLOCATE" : "false" }) # %% +# get platform type from jax.lib import xla_bridge print(xla_bridge.get_backend().platform) @@ -82,28 +87,40 @@ except (LookupError, OSError): print("error") -# %% [markdown] -# ## Prepare datasets - # %% -# load model -model_name_or_path = "t5-small" # Replace with your specific model name - -# Load configuration -config = AutoConfig.from_pretrained(model_name_or_path, - force_download=False) - - -# %% -# Path to saved combined_dataset +# config options file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval' -save_path = 't5_80_1' +save_path = 't5_80_1_bf16' # file_path = 'combined_data' split_datasets = load_from_disk(file_path) +training_size = len(split_datasets['train']) +# Store some constant +seed = 117 +num_epochs = 80 +batch_size = 384 # 384 is the best +num_train_epochs = num_epochs +per_device_train_batch_size = batch_size +train_batch_size = per_device_train_batch_size * jax.device_count() +per_device_eval_batch_size = batch_size +eval_batch_size = per_device_eval_batch_size * jax.device_count() +steps_per_epoch = training_size // train_batch_size +total_train_steps = steps_per_epoch * num_epochs + +warmup_steps = 0 +learning_rate = 2e-5 + +weight_decay = 0.01 +adam_beta1 = 0.9 +adam_beta2 = 0.999 +adam_epsilon = 1e-8 +label_smoothing_factor = 0.0 + +num_beams = 1 +val_max_target_length = 128 + +predict_with_generate = True -# %% -split_datasets['train'][0] # %% from transformers import T5TokenizerFast @@ -134,16 +151,48 @@ len(tokenizer) # %% -model_path = './t5_80_1' -# model_path = 't5=base' -model = FlaxAutoModelForSeq2SeqLM.from_pretrained( - pretrained_model_name_or_path=model_path, - dtype=jax.numpy.float32 -) +# model_path = './t5_80_1' +# model_path = 't5-base' +# model = FlaxAutoModelForSeq2SeqLM.from_pretrained( +# pretrained_model_name_or_path=model_path, +# dtype=jax.numpy.bfloat16 +# ) +# from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration +# from t5_model.configuration_t5 import T5Config +from transformers import FlaxT5ForConditionalGeneration +from transformers import T5Config + +config = T5Config() # %% -model.params_shape_tree['shared'] +# If you want don't want to cast certain parameters (for example layer norm bias and scale) +# then pass the mask as follows +from flax import traverse_util + +model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base") +# useful for transformer model +model.enable_gradient_checkpointing() + +# enable bf16 except for layer_norm +flat_params = traverse_util.flatten_dict(model.params) +mask = { + path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params +} +mask = traverse_util.unflatten_dict(mask) +model.params = model.to_bf16(model.params, mask) + +# %% +# # Function to extract shape and dtype without showing values +# def format_param(param): +# return f"shape={param.shape}, dtype={param.dtype}" +# +# # Use jax.tree_map to apply the formatter across the parameter tree +# formatted_params = jax.tree.map(format_param, model.params) +# +# # Pretty-print the tree +# for k, v in formatted_params.items(): +# print(f"{k}: {v}") # %% model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"]) @@ -164,6 +213,7 @@ def preprocess_function(example): targets = example['output'] # text_target sets the corresponding label to inputs # there is no need to create a separate 'labels' + # produce input_ids and decoder_input_ids model_inputs = tokenizer( inputs, max_length=max_length, @@ -179,39 +229,68 @@ def preprocess_function(example): return_tensors="np" ) + # for loss computation model_inputs["labels"] = labels["input_ids"] + # make decoder input ids decoder_input_ids = shift_tokens_right_fn( labels["input_ids"], config.pad_token_id, config.decoder_start_token_id ) + # require by model model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids) - # We need decoder_attention_mask so we can ignore pad tokens from loss model_inputs["decoder_attention_mask"] = labels["attention_mask"] return model_inputs +# %% +# temp + # map maps function to each "row" in the dataset # aka the data in the immediate nesting -tokenized_datasets = split_datasets.map( +token_datasets = split_datasets.map( preprocess_function, batched=True, num_proc=1, + # if we do not remove, we keep the original data remove_columns=split_datasets["train"].column_names, ) - -tokenized_datasets.set_format(type='numpy', - columns=['input_ids', 'attention_mask', - 'labels', 'decoder_input_ids', - 'decoder_attention_mask']) +train_dataset = token_datasets["train"] # %% -train_dataset = tokenized_datasets["train"] -eval_dataset = tokenized_datasets["validation"] + +token_datasets.set_format( + type='numpy', + columns=[ + 'input_ids', 'attention_mask', + 'labels', 'decoder_input_ids', + 'decoder_attention_mask'] +) +# %% +# check values +for name in ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'decoder_attention_mask']: + int_array = train_dataset[name] + if np.all((int_array >= 0) & (int_array <= 65535)): + uint16_array = int_array.astype(np.uint16) + else: + raise ValueError("Values are out of range for uint16") # %% -train_dataset[0] +from datasets import ClassLabel, Value, Sequence +features = train_dataset.features.copy() +features['input_ids'] = Sequence(Value('uint16')) +features['attention_mask'] = Sequence(Value('bool')) +features['labels'] = Sequence(Value('uint16')) +features['decoder_input_ids'] = Sequence(Value('uint16')) +features['decoder_attention_mask'] = Sequence(Value('bool')) +train_dataset = train_dataset.cast(features) + + + +# %% +# temp +print('data type check: ', train_dataset['decoder_attention_mask'].dtype) # %% def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True): @@ -243,32 +322,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf # %% [markdown] # # Model -# %% -# Store some constant -seed = 117 -num_epochs = 40 -batch_size = 32 -num_train_epochs = num_epochs -per_device_train_batch_size = batch_size -train_batch_size = per_device_train_batch_size * jax.device_count() -per_device_eval_batch_size = batch_size -eval_batch_size = per_device_eval_batch_size * jax.device_count() -steps_per_epoch = len(train_dataset) // train_batch_size -total_train_steps = steps_per_epoch * num_epochs -warmup_steps = 0 -learning_rate = 2e-5 - -weight_decay = 0.01 -adam_beta1 = 0.9 -adam_beta2 = 0.999 -adam_epsilon = 1e-8 -label_smoothing_factor = 0.0 - -num_beams = 1 -val_max_target_length = 128 - -predict_with_generate = True # %% @@ -342,12 +396,11 @@ class TrainState(train_state.TrainState): # set bf16 for model params # model.params = model.to_bf16(model.params) -params = model.params # Cast parameters to bfloat16 if desired # params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params) # Setup train state -state = TrainState.create(apply_fn=model.__call__, params=params, tx=adamw, dropout_rng=dropout_rng) +state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng) # label smoothed cross entropy def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): @@ -373,6 +426,7 @@ def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): return loss, num_labels # Define gradient update step fn +@jax.jit def train_step(state, batch, label_smoothing_factor=0.0): dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) @@ -406,17 +460,17 @@ max_length = ( num_beams = num_beams if num_beams is not None else model.config.num_beams gen_kwargs = {"max_length": max_length, "num_beams": num_beams} -def generate_step(params, batch): - model.params = params - output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs) - return output_ids.sequences +# def generate_step(params, batch): +# model.params = params +# output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs) +# return output_ids.sequences # Create parallel version of the train and eval step p_train_step = jax.pmap( partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,) ) # p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch") -p_generate_step = jax.pmap(generate_step, "batch") +# p_generate_step = jax.pmap(generate_step, "batch") # Replicate the train state on each device state = state.replicate() @@ -435,41 +489,44 @@ print(f" Total optimization steps = {total_train_steps}") # %% +# jax.profiler.start_trace("./traces") +rng, input_rng = jax.random.split(rng) train_time = 0 epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) -# epochs = range(num_epochs) for epoch in epochs: - # ======================== Training ================================ train_start = time.time() # Create sampling rng - rng, input_rng = jax.random.split(rng) train_metrics = [] - - # Generate an epoch by shuffling sampling indices from the train dataset train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True) steps_per_epoch = len(train_dataset) // train_batch_size - # train + # Generate an epoch by shuffling sampling indices from the train dataset for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): batch = next(train_loader) batch = shard(batch) state, train_metric = p_train_step(state, batch) train_metrics.append(train_metric) - train_time += time.time() - train_start + train_time = time.time() - train_start train_metric = unreplicate(train_metric) + train_metric['loss'].block_until_ready() + + epochs.write( - f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:" - f" {train_metric['learning_rate']})" + f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, " + f"Learning Rate:{train_metric['learning_rate']}, " + f"Last train time: {train_time})" ) +# jax.profiler.stop_trace() +# %% - output_dir = save_path - # save checkpoint after each epoch and push checkpoint to the hub - if jax.process_index() == 0: - params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) - model.save_pretrained(output_dir, params=params) - tokenizer.save_pretrained(output_dir) - +# output_dir = save_path +# # save checkpoint after each epoch and push checkpoint to the hub +# if jax.process_index() == 0: +# params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) +# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params) +# model.save_pretrained(output_dir, params=params) +# tokenizer.save_pretrained(output_dir) diff --git a/t5_jax_prediction.py b/t5_jax_prediction.py index 5961448..49a751c 100644 --- a/t5_jax_prediction.py +++ b/t5_jax_prediction.py @@ -92,7 +92,7 @@ test_dataset = Dataset.from_list(process_df(df)) # %% # load model -model_name_or_path = "./t5_80_1" # Replace with your specific model name +model_name_or_path = "./t5_80_1_bf16" # Replace with your specific model name # Load configuration config = AutoConfig.from_pretrained(model_name_or_path)