Feat: introduced efficient train data dtype, jit train step, bfloat16 matmul

Richard Wong 2024-09-14 02:02:45 +09:00
parent edd9c3551f
commit d2dd72227f
3 changed files with 144 additions and 86 deletions

1
.gitignore vendored
View File

@@ -2,3 +2,4 @@
t5_*/
exports/
modified_t5_model/
traces/

225
t5_jax.py
View File

@@ -16,6 +16,7 @@
# %% [markdown]
# # T5 implementation using jax
# %%
import jax
import jax.numpy as jnp
@@ -26,15 +27,15 @@ from typing import Callable, Optional
import math
# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "high")
jax.config.update("jax_enable_x64", False)
jax.config.update("jax_default_matmul_precision", "bfloat16")
# jax.config.update("jax_enable_x64", False)
# enable cache
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
# from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
import datasets
@@ -56,6 +57,7 @@ import flax.core
import time
# %%
import os
os.environ['XLA_FLAGS'] = (
@@ -63,14 +65,17 @@ os.environ['XLA_FLAGS'] = (
)
os.environ.update({
"TOKENIZERS_PARALLELISM" : "false",
"CUDA_DEVICE_MAX_CONNECTIONS" : "1",
"NCCL_LL128_BUFFSIZE": "-2",
"NCCL_LL_BUFFSIZE": "-2",
"NCCL_PROTO": "SIMPLE,LL,LL128",
"XLA_PYTHON_CLIENT_MEM_FRACTION" : ".95"
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.99",
# "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
})
# %%
# get platform type
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
@@ -82,28 +87,40 @@ except (LookupError, OSError):
print("error")
# %% [markdown]
# ## Prepare datasets
# %%
# load model
model_name_or_path = "t5-small" # Replace with your specific model name
# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path,
force_download=False)
# %%
# Path to saved combined_dataset
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
save_path = 't5_80_1'
save_path = 't5_80_1_bf16'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])
# Store some constants
seed = 117
num_epochs = 80
batch_size = 384 # 384 is the best
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
warmup_steps = 0
learning_rate = 2e-5
weight_decay = 0.01
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0
num_beams = 1
val_max_target_length = 128
predict_with_generate = True
# %%
split_datasets['train'][0]
# %%
from transformers import T5TokenizerFast
@@ -134,16 +151,48 @@ len(tokenizer)
# %%
model_path = './t5_80_1'
# model_path = 't5=base'
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
pretrained_model_name_or_path=model_path,
dtype=jax.numpy.float32
)
# model_path = './t5_80_1'
# model_path = 't5-base'
# model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
# pretrained_model_name_or_path=model_path,
# dtype=jax.numpy.bfloat16
# )
# from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration
# from t5_model.configuration_t5 import T5Config
from transformers import FlaxT5ForConditionalGeneration
from transformers import T5Config
config = T5Config()
# %%
model.params_shape_tree['shared']
# If you don't want to cast certain parameters (for example layer norm bias and scale)
# then pass the mask as follows
from flax import traverse_util
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
# gradient checkpointing trades extra compute for lower activation memory, useful for transformer models
model.enable_gradient_checkpointing()
# enable bf16 except for layer_norm
flat_params = traverse_util.flatten_dict(model.params)
mask = {
path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
}
mask = traverse_util.unflatten_dict(mask)
model.params = model.to_bf16(model.params, mask)
# %%
# # Function to extract shape and dtype without showing values
# def format_param(param):
# return f"shape={param.shape}, dtype={param.dtype}"
#
# # Use jax.tree_map to apply the formatter across the parameter tree
# formatted_params = jax.tree.map(format_param, model.params)
#
# # Pretty-print the tree
# for k, v in formatted_params.items():
# print(f"{k}: {v}")
# %%
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
@@ -164,6 +213,7 @@ def preprocess_function(example):
targets = example['output']
# text_target sets the corresponding label to inputs
# there is no need to create a separate 'labels'
# tokenize the inputs to produce input_ids and attention_mask
model_inputs = tokenizer(
inputs,
max_length=max_length,
@@ -179,39 +229,68 @@ def preprocess_function(example):
return_tensors="np"
)
# for loss computation
model_inputs["labels"] = labels["input_ids"]
# make decoder input ids
decoder_input_ids = shift_tokens_right_fn(
labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
)
# required by the model
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
# We need decoder_attention_mask so we can ignore pad tokens from loss
model_inputs["decoder_attention_mask"] = labels["attention_mask"]
return model_inputs
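# %%
# rough sketch (not part of the commit) of what shift_tokens_right_fn does when
# building decoder_input_ids: shift labels one position to the right, prepend
# decoder_start_token_id, and map any -100 placeholders back to the pad token
_labels = np.array([[100, 200, 300, 1]])
_shifted = np.zeros_like(_labels)
_shifted[:, 1:] = _labels[:, :-1]
_shifted[:, 0] = config.decoder_start_token_id
_shifted = np.where(_shifted == -100, config.pad_token_id, _shifted)
print(_shifted)  # [[0 100 200 300]] with T5's default decoder_start_token_id of 0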
# %%
# .map() applies preprocess_function to every example ("row") in the dataset,
# i.e. the entries at the immediate level of nesting
tokenized_datasets = split_datasets.map(
token_datasets = split_datasets.map(
preprocess_function,
batched=True,
num_proc=1,
# if we do not remove them, the original columns are kept alongside the tokenized ones
remove_columns=split_datasets["train"].column_names,
)
train_dataset = token_datasets["train"]
tokenized_datasets.set_format(type='numpy',
columns=['input_ids', 'attention_mask',
# %%
token_datasets.set_format(
type='numpy',
columns=[
'input_ids', 'attention_mask',
'labels', 'decoder_input_ids',
'decoder_attention_mask'])
'decoder_attention_mask']
)
# %%
# sanity check: all token ids and mask values must fit in uint16 before downcasting
for name in ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'decoder_attention_mask']:
    int_array = train_dataset[name]
    if not np.all((int_array >= 0) & (int_array <= 65535)):
        raise ValueError(f"Values in '{name}' are out of range for uint16")
# %%
train_dataset = tokenized_datasets["train"]
eval_dataset = tokenized_datasets["validation"]
from datasets import ClassLabel, Value, Sequence
features = train_dataset.features.copy()
features['input_ids'] = Sequence(Value('uint16'))
features['attention_mask'] = Sequence(Value('bool'))
features['labels'] = Sequence(Value('uint16'))
features['decoder_input_ids'] = Sequence(Value('uint16'))
features['decoder_attention_mask'] = Sequence(Value('bool'))
train_dataset = train_dataset.cast(features)
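# %%
# quick check (illustrative, not in the commit): after the cast, the dataset
# stores token ids as uint16 and masks as bool instead of the default int64,
# roughly a 4x and 8x reduction in storage for those columns
print(train_dataset.features['input_ids'])
print(train_dataset.features['decoder_attention_mask'])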
# %%
train_dataset[0]
# confirm the stored dtype after the cast
print('data type check: ', train_dataset['decoder_attention_mask'].dtype)
# %%
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
@@ -243,32 +322,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
# %% [markdown]
# # Model
# %%
# Store some constant
seed = 117
num_epochs = 40
batch_size = 32
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(train_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
warmup_steps = 0
learning_rate = 2e-5
weight_decay = 0.01
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0
num_beams = 1
val_max_target_length = 128
predict_with_generate = True
# %%
@@ -342,12 +396,11 @@ class TrainState(train_state.TrainState):
# set bf16 for model params
# model.params = model.to_bf16(model.params)
params = model.params
# Cast parameters to bfloat16 if desired
# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
# Setup train state
state = TrainState.create(apply_fn=model.__call__, params=params, tx=adamw, dropout_rng=dropout_rng)
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
# label smoothed cross entropy
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
@@ -373,6 +426,7 @@ def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
return loss, num_labels
# Define gradient update step fn
@jax.jit
def train_step(state, batch, label_smoothing_factor=0.0):
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
@@ -406,17 +460,17 @@ max_length = (
num_beams = num_beams if num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
def generate_step(params, batch):
model.params = params
output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
return output_ids.sequences
# def generate_step(params, batch):
# model.params = params
# output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
# return output_ids.sequences
# Create parallel version of the train and eval step
p_train_step = jax.pmap(
partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,)
)
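# donate_argnums=(0,) donates the old TrainState's device buffers to the call,
# letting XLA reuse that memory for the updated state instead of holding two copies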
# p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch")
p_generate_step = jax.pmap(generate_step, "batch")
# p_generate_step = jax.pmap(generate_step, "batch")
# Replicate the train state on each device
state = state.replicate()
@@ -435,41 +489,44 @@ print(f" Total optimization steps = {total_train_steps}")
# %%
# jax.profiler.start_trace("./traces")
rng, input_rng = jax.random.split(rng)
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
# epochs = range(num_epochs)
for epoch in epochs:
# ======================== Training ================================
train_start = time.time()
# Create sampling rng
rng, input_rng = jax.random.split(rng)
train_metrics = []
# Generate an epoch by shuffling sampling indices from the train dataset
train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
steps_per_epoch = len(train_dataset) // train_batch_size
# train
# Generate an epoch by shuffling sampling indices from the train dataset
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
batch = next(train_loader)
batch = shard(batch)
state, train_metric = p_train_step(state, batch)
train_metrics.append(train_metric)
train_time += time.time() - train_start
train_time = time.time() - train_start
train_metric = unreplicate(train_metric)
train_metric['loss'].block_until_ready()
epochs.write(
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
f" {train_metric['learning_rate']})"
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, "
f"Learning Rate:{train_metric['learning_rate']}, "
f"Last train time: {train_time})"
)
# jax.profiler.stop_trace()
# %%
output_dir = save_path
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
model.save_pretrained(output_dir, params=params)
tokenizer.save_pretrained(output_dir)
# output_dir = save_path
# # save checkpoint after each epoch and push checkpoint to the hub
# if jax.process_index() == 0:
# params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
# model.save_pretrained(output_dir, params=params)
# tokenizer.save_pretrained(output_dir)
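# (the tree_map above would cast the bf16 params back to float32 before saving,
# so the written checkpoint stays a full-precision one)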

View File

@@ -92,7 +92,7 @@ test_dataset = Dataset.from_list(process_df(df))
# %%
# load model
model_name_or_path = "./t5_80_1" # Replace with your specific model name
model_name_or_path = "./t5_80_1_bf16" # Replace with your specific model name
# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path)