Feat: introduced efficient train data dtype, jit train step, bfloat16 matmul
parent edd9c3551f
commit d2dd72227f
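In short, this commit does three things in t5_jax.py: it asks XLA to run matmuls in bfloat16 precision, stores the tokenized training data in compact dtypes (uint16 token ids, bool attention masks), and JIT-compiles the train step. Below is a minimal, self-contained sketch of the same three ideas, using toy shapes and hypothetical variable names rather than the notebook's actual model and dataset:

import jax
import jax.numpy as jnp
import numpy as np

# 1. request bfloat16 matmul precision from XLA (mirrors the config change in the diff)
jax.config.update("jax_default_matmul_precision", "bfloat16")

# 2. a jitted toy train step: one SGD update on a linear least-squares loss
@jax.jit
def train_step(w, x, y, lr=1e-2):
    def loss_fn(w):
        return jnp.mean((x @ w - y) ** 2)
    loss, grad = jax.value_and_grad(loss_fn)(w)
    return w - lr * grad, loss

w = jnp.zeros((16, 1))
x = jnp.ones((8, 16))
y = jnp.ones((8, 1))
w, loss = train_step(w, x, y)

# 3. token ids fit in uint16 (the T5 vocab is well under 65536) and masks fit in bool,
#    which shrinks the in-memory dataset compared to default int64 arrays
token_ids = np.array([[37, 1782, 3, 9, 1]], dtype=np.uint16)
attention_mask = np.ones_like(token_ids, dtype=bool)
print(loss, token_ids.dtype, attention_mask.dtype)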
@@ -2,3 +2,4 @@
 t5_*/
 exports/
 modified_t5_model/
+traces/
t5_jax.py (227 lines changed)
@@ -16,6 +16,7 @@
 # %% [markdown]
 # # T5 implementation using jax
 
+
 # %%
 import jax
 import jax.numpy as jnp
@@ -26,15 +27,15 @@ from typing import Callable, Optional
 import math
 
 # jax.config.update("jax_default_matmul_precision", "tensorfloat32")
-jax.config.update("jax_default_matmul_precision", "high")
-jax.config.update("jax_enable_x64", False)
+jax.config.update("jax_default_matmul_precision", "bfloat16")
+# jax.config.update("jax_enable_x64", False)
 # enable cache
 jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
 jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
 jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
 
 
-from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
+# from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
 
 
 import datasets
@@ -56,6 +57,7 @@ import flax.core
 import time
 
 
+
 # %%
 import os
 os.environ['XLA_FLAGS'] = (
@@ -63,14 +65,17 @@ os.environ['XLA_FLAGS'] = (
 )
 
 os.environ.update({
+    "TOKENIZERS_PARALLELISM" : "false",
     "CUDA_DEVICE_MAX_CONNECTIONS" : "1",
     "NCCL_LL128_BUFFSIZE": "-2",
     "NCCL_LL_BUFFSIZE": "-2",
     "NCCL_PROTO": "SIMPLE,LL,LL128",
-    "XLA_PYTHON_CLIENT_MEM_FRACTION" : ".95"
+    "XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.99",
+    # "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
 })
 
 # %%
+# get platform type
 from jax.lib import xla_bridge
 print(xla_bridge.get_backend().platform)
 
@@ -82,28 +87,40 @@ except (LookupError, OSError):
     print("error")
 
 
-# %% [markdown]
-# ## Prepare datasets
 
 # %%
-# load model
-model_name_or_path = "t5-small" # Replace with your specific model name
-
-# Load configuration
-config = AutoConfig.from_pretrained(model_name_or_path,
-                                    force_download=False)
-
-
-# %%
-# Path to saved combined_dataset
+# config options
 file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
-save_path = 't5_80_1'
+save_path = 't5_80_1_bf16'
 # file_path = 'combined_data'
 split_datasets = load_from_disk(file_path)
+training_size = len(split_datasets['train'])
+# Store some constant
+seed = 117
+num_epochs = 80
+batch_size = 384 # 384 is the best
+num_train_epochs = num_epochs
+per_device_train_batch_size = batch_size
+train_batch_size = per_device_train_batch_size * jax.device_count()
+per_device_eval_batch_size = batch_size
+eval_batch_size = per_device_eval_batch_size * jax.device_count()
+steps_per_epoch = training_size // train_batch_size
+total_train_steps = steps_per_epoch * num_epochs
+
+warmup_steps = 0
+learning_rate = 2e-5
+
+weight_decay = 0.01
+adam_beta1 = 0.9
+adam_beta2 = 0.999
+adam_epsilon = 1e-8
+label_smoothing_factor = 0.0
+
+num_beams = 1
+val_max_target_length = 128
+
+predict_with_generate = True
 
-# %%
-
-split_datasets['train'][0]
 
 # %%
 from transformers import T5TokenizerFast
@@ -134,16 +151,48 @@ len(tokenizer)
 
 
 # %%
-model_path = './t5_80_1'
-# model_path = 't5=base'
-model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
-    pretrained_model_name_or_path=model_path,
-    dtype=jax.numpy.float32
-)
+# model_path = './t5_80_1'
+# model_path = 't5-base'
+# model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
+#     pretrained_model_name_or_path=model_path,
+#     dtype=jax.numpy.bfloat16
+# )
+# from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration
+# from t5_model.configuration_t5 import T5Config
+from transformers import FlaxT5ForConditionalGeneration
+from transformers import T5Config
+
+config = T5Config()
 
 
 # %%
-model.params_shape_tree['shared']
+# If you don't want to cast certain parameters (for example layer norm bias and scale)
+# then pass the mask as follows
+from flax import traverse_util
+
+model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
+# useful for transformer model
+model.enable_gradient_checkpointing()
+
+# enable bf16 except for layer_norm
+flat_params = traverse_util.flatten_dict(model.params)
+mask = {
+    path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
+}
+mask = traverse_util.unflatten_dict(mask)
+model.params = model.to_bf16(model.params, mask)
+
+# %%
+# # Function to extract shape and dtype without showing values
+# def format_param(param):
+#     return f"shape={param.shape}, dtype={param.dtype}"
+#
+# # Use jax.tree_map to apply the formatter across the parameter tree
+# formatted_params = jax.tree.map(format_param, model.params)
+#
+# # Pretty-print the tree
+# for k, v in formatted_params.items():
+#     print(f"{k}: {v}")
 
 # %%
 model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])
@@ -164,6 +213,7 @@ def preprocess_function(example):
     targets = example['output']
     # text_target sets the corresponding label to inputs
    # there is no need to create a separate 'labels'
+    # produce input_ids and decoder_input_ids
     model_inputs = tokenizer(
         inputs,
         max_length=max_length,
@@ -179,39 +229,68 @@ def preprocess_function(example):
         return_tensors="np"
     )
 
+    # for loss computation
     model_inputs["labels"] = labels["input_ids"]
+    # make decoder input ids
     decoder_input_ids = shift_tokens_right_fn(
         labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
     )
+    # require by model
     model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
 
     # We need decoder_attention_mask so we can ignore pad tokens from loss
     model_inputs["decoder_attention_mask"] = labels["attention_mask"]
 
     return model_inputs
 
+# %%
+# temp
+
 # map maps function to each "row" in the dataset
 # aka the data in the immediate nesting
-tokenized_datasets = split_datasets.map(
+token_datasets = split_datasets.map(
     preprocess_function,
     batched=True,
     num_proc=1,
+    # if we do not remove, we keep the original data
     remove_columns=split_datasets["train"].column_names,
 )
 
-tokenized_datasets.set_format(type='numpy',
-                              columns=['input_ids', 'attention_mask',
-                                       'labels', 'decoder_input_ids',
-                                       'decoder_attention_mask'])
+train_dataset = token_datasets["train"]
 
 # %%
-train_dataset = tokenized_datasets["train"]
-eval_dataset = tokenized_datasets["validation"]
+token_datasets.set_format(
+    type='numpy',
+    columns=[
+        'input_ids', 'attention_mask',
+        'labels', 'decoder_input_ids',
+        'decoder_attention_mask']
+)
+
+# %%
+# check values
+for name in ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'decoder_attention_mask']:
+    int_array = train_dataset[name]
+    if np.all((int_array >= 0) & (int_array <= 65535)):
+        uint16_array = int_array.astype(np.uint16)
+    else:
+        raise ValueError("Values are out of range for uint16")
 
 # %%
-train_dataset[0]
+from datasets import ClassLabel, Value, Sequence
+features = train_dataset.features.copy()
+features['input_ids'] = Sequence(Value('uint16'))
+features['attention_mask'] = Sequence(Value('bool'))
+features['labels'] = Sequence(Value('uint16'))
+features['decoder_input_ids'] = Sequence(Value('uint16'))
+features['decoder_attention_mask'] = Sequence(Value('bool'))
+train_dataset = train_dataset.cast(features)
+
 
+# %%
+# temp
+print('data type check: ', train_dataset['decoder_attention_mask'].dtype)
 
 # %%
 def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
@@ -243,32 +322,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
 # %% [markdown]
 # # Model
 
-# %%
-# Store some constant
-seed = 117
-num_epochs = 40
-batch_size = 32
-num_train_epochs = num_epochs
-per_device_train_batch_size = batch_size
-train_batch_size = per_device_train_batch_size * jax.device_count()
-per_device_eval_batch_size = batch_size
-eval_batch_size = per_device_eval_batch_size * jax.device_count()
-steps_per_epoch = len(train_dataset) // train_batch_size
-total_train_steps = steps_per_epoch * num_epochs
-
-warmup_steps = 0
-learning_rate = 2e-5
-
-weight_decay = 0.01
-adam_beta1 = 0.9
-adam_beta2 = 0.999
-adam_epsilon = 1e-8
-label_smoothing_factor = 0.0
-
-num_beams = 1
-val_max_target_length = 128
-
-predict_with_generate = True
 
 
 # %%
@@ -342,12 +396,11 @@ class TrainState(train_state.TrainState):
 
 # set bf16 for model params
 # model.params = model.to_bf16(model.params)
-params = model.params
 # Cast parameters to bfloat16 if desired
 # params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
 
 # Setup train state
-state = TrainState.create(apply_fn=model.__call__, params=params, tx=adamw, dropout_rng=dropout_rng)
+state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
 
 # label smoothed cross entropy
 def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
@@ -373,6 +426,7 @@ def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
     return loss, num_labels
 
 # Define gradient update step fn
+@jax.jit
 def train_step(state, batch, label_smoothing_factor=0.0):
     dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
 
@@ -406,17 +460,17 @@ max_length = (
 num_beams = num_beams if num_beams is not None else model.config.num_beams
 gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
 
-def generate_step(params, batch):
-    model.params = params
-    output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
-    return output_ids.sequences
+# def generate_step(params, batch):
+#     model.params = params
+#     output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
+#     return output_ids.sequences
 
 # Create parallel version of the train and eval step
 p_train_step = jax.pmap(
     partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,)
 )
 # p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch")
-p_generate_step = jax.pmap(generate_step, "batch")
+# p_generate_step = jax.pmap(generate_step, "batch")
 
 # Replicate the train state on each device
 state = state.replicate()
@@ -435,41 +489,44 @@ print(f" Total optimization steps = {total_train_steps}")
 
 
 # %%
+# jax.profiler.start_trace("./traces")
+
+rng, input_rng = jax.random.split(rng)
 train_time = 0
 epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
-# epochs = range(num_epochs)
 for epoch in epochs:
-    # ======================== Training ================================
     train_start = time.time()
 
     # Create sampling rng
-    rng, input_rng = jax.random.split(rng)
     train_metrics = []
 
-    # Generate an epoch by shuffling sampling indices from the train dataset
     train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
     steps_per_epoch = len(train_dataset) // train_batch_size
-    # train
+    # Generate an epoch by shuffling sampling indices from the train dataset
     for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
         batch = next(train_loader)
         batch = shard(batch)
         state, train_metric = p_train_step(state, batch)
         train_metrics.append(train_metric)
 
-    train_time += time.time() - train_start
+    train_time = time.time() - train_start
 
     train_metric = unreplicate(train_metric)
+    train_metric['loss'].block_until_ready()
 
 
     epochs.write(
-        f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
-        f" {train_metric['learning_rate']})"
+        f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, "
+        f"Learning Rate:{train_metric['learning_rate']}, "
+        f"Last train time: {train_time})"
     )
-output_dir = save_path
-# save checkpoint after each epoch and push checkpoint to the hub
-if jax.process_index() == 0:
-    params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
-    model.save_pretrained(output_dir, params=params)
-    tokenizer.save_pretrained(output_dir)
+# jax.profiler.stop_trace()
+
+# %%
+
+# output_dir = save_path
+# # save checkpoint after each epoch and push checkpoint to the hub
+# if jax.process_index() == 0:
+#     params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
+#     params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
+#     model.save_pretrained(output_dir, params=params)
+#     tokenizer.save_pretrained(output_dir)
@@ -92,7 +92,7 @@ test_dataset = Dataset.from_list(process_df(df))
 
 # %%
 # load model
-model_name_or_path = "./t5_80_1" # Replace with your specific model name
+model_name_or_path = "./t5_80_1_bf16" # Replace with your specific model name
 
 # Load configuration
 config = AutoConfig.from_pretrained(model_name_or_path)
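Not part of the commit, but a quick way to sanity-check the masked to_bf16 cast made in t5_jax.py above is to count leaf dtypes in the parameter tree. The sketch below is illustrative only; it reuses the diff's mask construction and assumes transformers and flax are installed:

from collections import Counter

import jax
from flax import traverse_util
from transformers import FlaxT5ForConditionalGeneration

model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")

# same mask construction as in the diff: cast everything except layer_norm weights
flat_params = traverse_util.flatten_dict(model.params)
mask = {
    path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
}
model.params = model.to_bf16(model.params, traverse_util.unflatten_dict(mask))

# expect mostly bfloat16 leaves, with most layer_norm scales still in float32
dtype_counts = Counter(str(x.dtype) for x in jax.tree_util.tree_leaves(model.params))
print(dtype_counts)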