580 lines
18 KiB
Python
580 lines
18 KiB
Python
# ---
|
|
# jupyter:
|
|
# jupytext:
|
|
# formats: ipynb,py:percent
|
|
# text_representation:
|
|
# extension: .py
|
|
# format_name: percent
|
|
# format_version: '1.3'
|
|
# jupytext_version: 1.16.4
|
|
# kernelspec:
|
|
# display_name: jax
|
|
# language: python
|
|
# name: python3
|
|
# ---
|
|
|
|
# %% [markdown]
|
|
# # T5 implementation using jax
|
|
|
|
|
|
# %%
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import optax
|
|
import numpy as np
|
|
from functools import partial
|
|
from typing import Callable, Optional
|
|
import math
|
|
|
|
# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
|
|
jax.config.update("jax_default_matmul_precision", "bfloat16")
|
|
# jax.config.update("jax_enable_x64", False)
|
|
# enable cache
|
|
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
|
|
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
|
|
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
|
|
|
|
|
|
# from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
|
|
|
|
|
|
import datasets
|
|
from datasets import Dataset, load_dataset
|
|
import evaluate
|
|
from tqdm import tqdm
|
|
from datasets import load_from_disk
|
|
|
|
|
|
import nltk # Here to have a nice missing dependency error message early on
|
|
from typing import Dict, Any, Union
|
|
|
|
from flax import jax_utils, traverse_util
|
|
from flax.jax_utils import pad_shard_unpad, unreplicate
|
|
from flax.training import train_state
|
|
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
|
from flax.core.frozen_dict import FrozenDict, unfreeze
|
|
import flax.core
|
|
|
|
|
|
import time
|
|
|
|
|
|
|
|
# %%
|
|
import os
|
|
# os.environ['XLA_FLAGS'] = (
|
|
# '--xla_gpu_triton_gemm_any=true '
|
|
# '--xla_gpu_enable_custom_fusions=true '
|
|
# '--xla_gpu_enable_address_computation_fusion=true'
|
|
# )
|
|
flags = (
|
|
'--xla_gpu_enable_triton_softmax_fusion=true '
|
|
'--xla_gpu_triton_gemm_any=True '
|
|
# '--xla_gpu_enable_async_collectives=true '
|
|
'--xla_gpu_enable_latency_hiding_scheduler=true '
|
|
'--xla_gpu_enable_highest_priority_async_stream=true '
|
|
)
|
|
os.environ["XLA_FLAGS"] = flags
|
|
|
|
|
|
os.environ.update({
|
|
"TOKENIZERS_PARALLELISM" : "false",
|
|
"CUDA_DEVICE_MAX_CONNECTIONS" : "1",
|
|
"NCCL_LL128_BUFFSIZE": "-2",
|
|
"NCCL_LL_BUFFSIZE": "-2",
|
|
"NCCL_PROTO": "SIMPLE,LL,LL128",
|
|
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.99",
|
|
# "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
|
|
})
|
|
|
|
# %%
|
|
# get platform type
|
|
from jax.extend.backend import get_backend
|
|
print(get_backend().platform)
|
|
|
|
|
|
# %%
|
|
# Probe for the NLTK "punkt" tokenizer data up front. The check stays
# best-effort (nothing is raised), but the message now states what is missing
# and how to obtain it instead of printing a bare "error".
try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError) as err:
    print(
        "NLTK resource 'tokenizers/punkt' was not found "
        f"({type(err).__name__}); run nltk.download('punkt') if it is needed."
    )
|
|
|
|
|
|
# %%
|
|
# config options
|
|
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
|
|
save_path = 't5_5e_1_pmap'
|
|
# file_path = 'combined_data'
|
|
split_datasets = load_from_disk(file_path)
|
|
training_size = len(split_datasets['train'])
|
|
# Store some constants
|
|
seed = 117
|
|
num_epochs = 5
|
|
batch_size = 32 # 384 is the best
|
|
num_train_epochs = num_epochs
|
|
per_device_train_batch_size = batch_size
|
|
train_batch_size = per_device_train_batch_size * jax.device_count()
|
|
per_device_eval_batch_size = batch_size
|
|
eval_batch_size = per_device_eval_batch_size * jax.device_count()
|
|
steps_per_epoch = training_size // train_batch_size
|
|
total_train_steps = steps_per_epoch * num_epochs
|
|
|
|
warmup_steps = 0
|
|
learning_rate = 2e-5
|
|
|
|
weight_decay = 0.01
|
|
adam_beta1 = 0.9
|
|
adam_beta2 = 0.999
|
|
adam_epsilon = 1e-8
|
|
label_smoothing_factor = 0.0
|
|
|
|
num_beams = 1
|
|
val_max_target_length = 86
|
|
|
|
predict_with_generate = True
|
|
|
|
|
|
|
|
# %%
|
|
from transformers import T5TokenizerFast
|
|
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
|
|
# Define additional special tokens
|
|
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "<SIG>", "<UNIT>", "<DATA_TYPE>"]
|
|
# Add the additional special tokens to the tokenizer
|
|
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
|
|
|
|
max_length = 86
|
|
|
|
# %%
|
|
len(tokenizer)
|
|
|
|
# %%
|
|
# load pytorch model first
|
|
# from transformers import AutoModelForSeq2SeqLM
|
|
# model_checkpoint = "t5-base"
|
|
# model_pt = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
|
|
# # important! after extending tokens vocab
|
|
# model_pt.resize_token_embeddings(len(tokenizer))
|
|
# model_pt.save_pretrained('./modified_t5_model')
|
|
# model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
|
|
# pretrained_model_name_or_path="modified_t5_model",
|
|
# dtype=jax.numpy.bfloat16,
|
|
# from_pt=True
|
|
# )
|
|
|
|
|
|
# %%
|
|
# model_path = './t5_80_1'
|
|
# model_path = 't5-base'
|
|
# model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
|
|
# pretrained_model_name_or_path=model_path,
|
|
# dtype=jax.numpy.bfloat16
|
|
# )
|
|
# from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration
|
|
# from t5_model.configuration_t5 import T5Config
|
|
from transformers import FlaxT5ForConditionalGeneration
|
|
from transformers import T5Config
|
|
|
|
config = T5Config()
|
|
|
|
|
|
# %%
|
|
# If you don't want to cast certain parameters (for example layer norm bias and scale)
|
|
# then pass the mask as follows
|
|
from flax import traverse_util
|
|
|
|
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
|
|
# useful for transformer model
|
|
# model.enable_gradient_checkpointing()
|
|
|
|
# enable bf16 except for layer_norm
|
|
# flat_params = traverse_util.flatten_dict(model.params)
|
|
# mask = {
|
|
# path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
|
|
# }
|
|
# mask = traverse_util.unflatten_dict(mask)
|
|
# # borrowed from transformers modeling_flax_utils
|
|
# def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
|
|
# """
|
|
# Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
|
|
# """
|
|
#
|
|
# # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
|
|
# def conditional_cast(param):
|
|
# if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
|
|
# param = param.astype(dtype)
|
|
# return param
|
|
#
|
|
# if mask is None:
|
|
# return jax.tree_util.tree_map(conditional_cast, params)
|
|
#
|
|
# flat_params = traverse_util.flatten_dict(params)
|
|
# flat_mask, _ = jax.tree_util.tree_flatten(mask)
|
|
#
|
|
# for masked, key in zip(flat_mask, sorted(flat_params.keys())):
|
|
# if masked:
|
|
# flat_params[key] = conditional_cast(flat_params[key])
|
|
#
|
|
# return traverse_util.unflatten_dict(flat_params)
|
|
#
|
|
# # Cast parameters to bfloat16 if desired
|
|
# # params = jax.tree.tree_map(lambda x: x.astype(jnp.bfloat16), params)
|
|
# # instead of casting the whole thing, we cast only certain parts of the tree
|
|
# params = cast_floating_to(model.params, jnp.bfloat16, mask)
|
|
|
|
|
|
# %%
|
|
# # Function to extract shape and dtype without showing values
|
|
# def format_param(param):
|
|
# return f"shape={param.shape}, dtype={param.dtype}"
|
|
#
|
|
# # Use jax.tree_map to apply the formatter across the parameter tree
|
|
# formatted_params = jax.tree.map(format_param, model.params)
|
|
#
|
|
# # Pretty-print the tree
|
|
# for k, v in formatted_params.items():
|
|
# print(f"{k}: {v}")
|
|
|
|
# %%
|
|
# Look up `shift_tokens_right` on the module that defines the loaded model class.
# A non-empty `fromlist` makes `__import__` return the named submodule itself
# instead of the top-level package; the entry previously contained a typo
# ("shift_tokens_tight"), so it now names the attribute that is actually read.
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
shift_tokens_right_fn = model_module.shift_tokens_right
|
|
|
|
|
|
|
|
# %%
|
|
# In Flax, for seq2seq models we need to pass `decoder_input_ids`
|
|
# as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
|
|
# for that dynamically import the `shift_tokens_right` function from the model file
|
|
|
|
|
|
# given a dataset entry, run it through the tokenizer
|
|
# Setting padding="max_length" as we need fixed length inputs for jitted functions
|
|
def preprocess_function(example):
    """Tokenize a batch of examples into fixed-length seq2seq model inputs.

    Args:
        example: a batch from the dataset providing the ``'input'`` and
            ``'output'`` text columns.

    Returns:
        The tokenizer output for the inputs (``input_ids`` and
        ``attention_mask``) extended with ``labels``, ``decoder_input_ids``
        and ``decoder_attention_mask``. Every sequence is padded or truncated
        to the module-level ``max_length`` because jitted functions need
        fixed-length inputs.
    """
    encoder_inputs = tokenizer(
        example['input'],
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )
    # `text_target` tokenizes the strings as targets, so no separate label
    # tokenizer setup is needed.
    target_encoding = tokenizer(
        text_target=example['output'],
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )
    target_ids = target_encoding["input_ids"]

    # Used for the loss computation.
    encoder_inputs["labels"] = target_ids

    # Flax models do not accept `labels`, so the decoder inputs are built here
    # by shifting the targets one position to the right. The start/pad ids are
    # read from the loaded checkpoint's config: a default-constructed
    # `T5Config()` is not guaranteed to define `decoder_start_token_id`
    # (NOTE(review): confirm against the installed transformers version).
    shifted_ids = shift_tokens_right_fn(
        target_ids, model.config.pad_token_id, model.config.decoder_start_token_id
    )
    encoder_inputs["decoder_input_ids"] = np.asarray(shifted_ids)

    # The decoder attention mask doubles as the padding mask for the loss.
    encoder_inputs["decoder_attention_mask"] = target_encoding["attention_mask"]

    return encoder_inputs
|
|
|
|
# %%
|
|
# temp
|
|
|
|
# map maps function to each "row" in the dataset
|
|
# aka the data in the immediate nesting
|
|
token_datasets = split_datasets.map(
|
|
preprocess_function,
|
|
batched=True,
|
|
num_proc=1,
|
|
# if we do not remove, we keep the original data
|
|
remove_columns=split_datasets["train"].column_names,
|
|
)
|
|
|
|
train_dataset = token_datasets["train"]
|
|
|
|
# %%
|
|
|
|
token_datasets.set_format(
|
|
type='numpy',
|
|
columns=[
|
|
'input_ids', 'attention_mask',
|
|
'labels', 'decoder_input_ids',
|
|
'decoder_attention_mask']
|
|
)
|
|
# %%
|
|
# check values
|
|
# Sanity check before the columns are narrowed to smaller dtypes: every stored
# value must fit into an unsigned 16-bit integer. Only the range is verified
# here; no converted copy of the column is allocated.
for name in ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'decoder_attention_mask']:
    int_array = train_dataset[name]
    if not np.all((int_array >= 0) & (int_array <= 65535)):
        raise ValueError(f"Values are out of range for uint16 (column '{name}')")
|
|
|
|
# %%
|
|
|
|
from datasets import ClassLabel, Value, Sequence

# Re-declare the column types so token ids are stored as uint16 and the two
# attention masks as bool, then cast the training split to that schema.
features = train_dataset.features.copy()
for column in ('input_ids', 'labels', 'decoder_input_ids'):
    features[column] = Sequence(Value('uint16'))
for column in ('attention_mask', 'decoder_attention_mask'):
    features[column] = Sequence(Value('bool'))
train_dataset = train_dataset.cast(features)
|
|
|
|
|
|
|
|
# %%
|
|
# temp
|
|
print('data type check: ', train_dataset['decoder_attention_mask'].dtype)
|
|
|
|
# %%
|
|
def data_loader(rng: jax.random.PRNGKey, dataset: "Dataset", batch_size: int, shuffle: bool = False, drop_last=True):
    """Yield batches of ``batch_size`` examples from ``dataset`` as dicts of jnp arrays.

    Args:
        rng: PRNG key used to permute the example order; only read when
            ``shuffle`` is true.
        dataset: object supporting ``len()`` and indexing with an integer
            index array, returning a mapping of column name to values.
        batch_size: number of examples per batch.
        shuffle: visit the examples in a random order when true.
        drop_last: when true, the trailing incomplete batch is skipped. When
            false, every example is yielded and only the final batch may be
            shorter (between 1 and ``batch_size`` examples).

    Yields:
        One ``{column: jnp.ndarray}`` dict per batch. An empty dataset yields
        nothing.
    """
    num_examples = len(dataset)
    if shuffle:
        batch_idx = np.asarray(jax.random.permutation(rng, num_examples))
    else:
        batch_idx = np.arange(num_examples)

    if drop_last:
        steps_per_epoch = num_examples // batch_size
        # Skip the incomplete batch, then view the rest as (steps, batch_size).
        batch_idx = batch_idx[: steps_per_epoch * batch_size].reshape((steps_per_epoch, batch_size))
    else:
        # Slice at multiples of batch_size so that only the last batch can be
        # short; np.array_split would instead balance the sizes across all
        # batches and fails when asked for zero sections.
        batch_idx = [batch_idx[start : start + batch_size] for start in range(0, num_examples, batch_size)]

    for idx in batch_idx:
        batch = dataset[idx]
        yield {k: jnp.array(v) for k, v in batch.items()}
|
|
|
|
|
|
# %% [markdown]
|
|
# # Model
|
|
#
|
|
#
|
|
#
|
|
|
|
# %%
|
|
|
|
# Initialize our training
|
|
rng = jax.random.PRNGKey(seed)
|
|
rng, dropout_rng = jax.random.split(rng)
|
|
|
|
|
|
# %%
|
|
# optimization functions
|
|
|
|
def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Build a learning-rate schedule with linear warmup followed by linear decay.

    The rate rises from 0.0 to ``learning_rate`` over ``num_warmup_steps``
    and then falls linearly to 0 over the remaining steps, where the total
    step count is ``(train_ds_size // train_batch_size) * num_train_epochs``.
    """
    total_steps = (train_ds_size // train_batch_size) * num_train_epochs
    warmup = optax.linear_schedule(
        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
    )
    decay = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=total_steps - num_warmup_steps
    )
    # Switch from the warmup segment to the decay segment at `num_warmup_steps`.
    return optax.join_schedules(schedules=[warmup, decay], boundaries=[num_warmup_steps])
|
|
|
|
|
|
# Create learning rate schedule
|
|
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
|
len(train_dataset),
|
|
train_batch_size,
|
|
num_train_epochs,
|
|
warmup_steps,
|
|
learning_rate,
|
|
)
|
|
|
|
# We use Optax's "masking" functionality to not apply weight decay
|
|
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
|
# mask boolean with the same structure as the parameters.
|
|
# The mask is True for parameters that should be decayed.
|
|
def decay_mask_fn(params):
    """Return a boolean tree shaped like ``params``; True marks weights to decay.

    A parameter is excluded from weight decay when its last path component is
    ``"bias"`` or when the final two components of its path match those of any
    parameter whose joined, lower-cased path contains one of the layer-norm
    markers.
    """
    layer_norm_markers = ("layernorm", "layer_norm", "ln")
    flat = traverse_util.flatten_dict(params)

    def _looks_like_layer_norm(path):
        joined = "".join(path).lower()
        return any(marker in joined for marker in layer_norm_markers)

    # Path suffixes (last two components) of every layer-norm parameter.
    excluded_suffixes = {path[-2:] for path in flat if _looks_like_layer_norm(path)}

    decay_flags = {
        path: not (path[-1] == "bias" or path[-2:] in excluded_suffixes)
        for path in flat
    }
    return traverse_util.unflatten_dict(decay_flags)
|
|
|
|
# create adam optimizer
|
|
adamw = optax.adamw(
|
|
learning_rate=linear_decay_lr_schedule_fn,
|
|
b1=adam_beta1,
|
|
b2=adam_beta2,
|
|
eps=adam_epsilon,
|
|
weight_decay=weight_decay,
|
|
mask=decay_mask_fn,
|
|
)
|
|
|
|
|
|
# %%
|
|
# Training functions
|
|
class TrainState(train_state.TrainState):
    """Flax train state that additionally carries the dropout PRNG key."""

    # PRNG key consumed (and refreshed) by the training step for dropout.
    dropout_rng: jnp.ndarray

    def replicate(self):
        """Return a copy of the state prepared for data-parallel execution.

        The whole state is replicated with ``jax_utils.replicate``; the
        dropout key is then replaced by ``shard_prng_key(self.dropout_rng)``
        (presumably one distinct key per local device — confirm against the
        flax helper).
        """
        replicated_state = jax_utils.replicate(self)
        per_device_keys = shard_prng_key(self.dropout_rng)
        return replicated_state.replace(dropout_rng=per_device_keys)
|
|
|
|
# set bf16 for model params
|
|
# model.params = model.to_bf16(model.params)
|
|
# Cast parameters to bfloat16 if desired
|
|
# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
|
|
|
|
# Setup train state
|
|
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
|
|
|
|
# label smoothed cross entropy
|
|
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    """Label-smoothed softmax cross entropy, summed over unmasked positions.

    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104

    Returns:
        A ``(loss, num_labels)`` pair: the masked loss summed over all
        positions, and the sum of ``padding_mask``. The caller is expected to
        do the normalisation.
    """
    num_classes = logits.shape[-1]
    on_value = 1.0 - label_smoothing_factor
    off_value = (1.0 - on_value) / (num_classes - 1)

    # Constant subtracted from every position's loss; the 1e-20 keeps the log
    # finite when no smoothing is applied (off_value == 0).
    smoothing_offset = -(
        on_value * jnp.log(on_value) + (num_classes - 1) * off_value * jnp.log(off_value + 1e-20)
    )

    smoothed_targets = onehot(labels, num_classes, on_value=on_value, off_value=off_value)
    per_position_loss = optax.softmax_cross_entropy(logits, smoothed_targets) - smoothing_offset

    # Zero out the padded positions before summing.
    total_loss = (per_position_loss * padding_mask).sum()
    return total_loss, padding_mask.sum()
|
|
|
|
# Define gradient update step fn
|
|
@jax.jit
def train_step(state, batch, label_smoothing_factor=0.0):
    """Run one optimisation step and return the updated train state.

    Must be executed under a mapped axis named ``"batch"`` (e.g. via
    ``jax.pmap(..., "batch")``), because gradients and label counts are summed
    over that axis with ``jax.lax.psum``.

    Args:
        state: train state providing ``apply_fn``, ``params`` and ``dropout_rng``.
        batch: mapping holding ``"labels"``, ``"decoder_attention_mask"`` and
            the remaining keyword inputs of ``state.apply_fn``.
        label_smoothing_factor: forwarded to ``loss_fn``.

    Returns:
        The state after applying the gradients, carrying a fresh dropout key.
    """
    dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

    # Separate the labels from the model inputs without mutating `batch`;
    # popping from the input inside the differentiated closure is a side
    # effect in code that JAX traces.
    labels = batch["labels"]
    model_inputs = {key: value for key, value in batch.items() if key != "labels"}

    def compute_loss(params):
        logits = state.apply_fn(**model_inputs, params=params, dropout_rng=dropout_rng, train=True)[0]
        # Returns (summed loss, number of unmasked label positions).
        return loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)

    # The label count rides along as auxiliary output of the differentiated fn.
    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (_, num_labels), grad = grad_fn(state.params)
    num_labels = jax.lax.psum(num_labels, "batch")

    # true grad = total grad / total number of labels across the "batch" axis
    grad = jax.lax.psum(grad, "batch")
    grad = jax.tree_util.tree_map(lambda g: g / num_labels, grad)

    return state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
|
|
|
|
# Define generation function
|
|
max_length = (
|
|
val_max_target_length if val_max_target_length is not None else model.config.max_length
|
|
)
|
|
num_beams = num_beams if num_beams is not None else model.config.num_beams
|
|
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
|
|
|
|
# def generate_step(params, batch):
|
|
# model.params = params
|
|
# output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
|
|
# return output_ids.sequences
|
|
|
|
# Create parallel version of the train and eval step
|
|
p_train_step = jax.pmap(
|
|
partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,)
|
|
)
|
|
# p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch")
|
|
# p_generate_step = jax.pmap(generate_step, "batch")
|
|
|
|
# Replicate the train state on each device
|
|
state = state.replicate()
|
|
|
|
|
|
|
|
# %%
|
|
|
|
|
|
print("***** Running training *****")
|
|
print(f" Num examples = {len(train_dataset)}")
|
|
print(f" Num Epochs = {num_epochs}")
|
|
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
|
|
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
|
|
print(f" Total optimization steps = {total_train_steps}")
|
|
|
|
|
|
# %%
|
|
# jax.profiler.start_trace("./traces")
|
|
|
|
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
    train_start = time.time()

    # Create sampling rng: derive a fresh key every epoch. Reusing one key for
    # all epochs would make the loader produce the identical shuffle each time.
    rng, input_rng = jax.random.split(rng)
    train_metrics = []
    train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
    steps_per_epoch = len(train_dataset) // train_batch_size

    # One pass over the shuffled training data.
    for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
        batch = shard(next(train_loader))
        state = p_train_step(state, batch)

    # Device work is dispatched asynchronously; wait for the last update to
    # finish so the reported time covers the actual computation.
    jax.block_until_ready(state.params)
    train_time = time.time() - train_start

    epochs.write(
        f"Epoch... ({epoch + 1}/{num_epochs} | "
        f"Last train time: {train_time})"
    )
|
|
# jax.profiler.stop_trace()
|
|
# %%
|
|
|
|
output_dir = save_path
# Write the final weights and the tokenizer once training has finished; only
# the process with index 0 performs the save.
if jax.process_index() == 0:
    # Keep the first replica of every parameter and copy it to host memory.
    first_replica = jax.tree_util.tree_map(lambda leaf: leaf[0], state.params)
    params = jax.tree_util.tree_map(
        lambda leaf: leaf.astype(jnp.float32),
        jax.device_get(first_replica),
    )
    model.save_pretrained(output_dir, params=params)
    tokenizer.save_pretrained(output_dir)
|
|
|
|
# %%
|