# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.16.4
#   kernelspec:
#     display_name: jax
#     language: python
#     name: python3
# ---
# %% [markdown]
# # T5 implementation using JAX
# %%
import jax
import jax.numpy as jnp
import optax
import numpy as np
from functools import partial
from typing import Callable, Optional
import math
import flax.linen as nn
# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "bfloat16")
# jax.config.update("jax_enable_x64", False)
# enable the persistent compilation cache
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
# from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
import datasets
from datasets import Dataset, load_dataset
import evaluate
from tqdm import tqdm
from datasets import load_from_disk
import nltk # Here to have a nice missing dependency error message early on
from typing import Dict, Any, Union
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.core.frozen_dict import FrozenDict, unfreeze
import flax.core
import time
# %%
import os
# os.environ['XLA_FLAGS'] = (
# '--xla_gpu_triton_gemm_any=true '
# '--xla_gpu_enable_custom_fusions=true '
# '--xla_gpu_enable_address_computation_fusion=true'
# )
flags = (
'--xla_gpu_enable_triton_softmax_fusion=true '
'--xla_gpu_triton_gemm_any=True '
# '--xla_gpu_enable_async_collectives=true '
'--xla_gpu_enable_latency_hiding_scheduler=true '
'--xla_gpu_enable_highest_priority_async_stream=true '
)
os.environ["XLA_FLAGS"] = flags
os.environ.update({
"TOKENIZERS_PARALLELISM" : "false",
"CUDA_DEVICE_MAX_CONNECTIONS" : "1",
"NCCL_LL128_BUFFSIZE": "-2",
"NCCL_LL_BUFFSIZE": "-2",
"NCCL_PROTO": "SIMPLE,LL,LL128",
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.8",
# "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
})
# %%
# get platform type
from jax.extend.backend import get_backend
print(get_backend().platform)
# %%
try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    print("NLTK 'punkt' tokenizer data is missing; run nltk.download('punkt') before computing text metrics")
# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/original/'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])
# Store some constants
seed = 117
num_epochs = 40
batch_size = 64  # per-device batch size; 384 worked best in earlier runs
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
warmup_steps = 0
learning_rate = 2e-4
weight_decay = 0.01
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0
num_beams = 1
val_max_target_length = 128
predict_with_generate = True
# %%
from transformers import T5TokenizerFast
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
# Define additional special tokens
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "<SIG>", "<UNIT>", "<DATA_TYPE>"]
# Add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
max_length = 128
# %%
len(tokenizer)
# %%
# load pytorch model first
# from transformers import AutoModelForSeq2SeqLM
# model_checkpoint = "t5-base"
# model_pt = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
# # important! after extending tokens vocab
# model_pt.resize_token_embeddings(len(tokenizer))
# model_pt.save_pretrained('./modified_t5_model')
# model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
# pretrained_model_name_or_path="modified_t5_model",
# dtype=jax.numpy.bfloat16,
# from_pt=True
# )
# %%
# model_path = './t5_80_1'
# model_path = 't5-base'
# model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
# pretrained_model_name_or_path=model_path,
# dtype=jax.numpy.bfloat16
# )
# from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration
# from t5_model.configuration_t5 import T5Config
from transformers import FlaxT5ForConditionalGeneration
model = FlaxT5ForConditionalGeneration.from_pretrained(
    "t5-base",
    dtype=jnp.bfloat16,
    gradient_checkpointing=True
)
# take the config from the loaded checkpoint so pad_token_id and
# decoder_start_token_id are guaranteed to match the model
config = model.config
params = model.params
# Cast most parameters to bfloat16: the feed-forward wi/wo weights, the attention
# q/k/v projections, and the shared embedding. Layer norms and the attention output
# projection `o` stay in float32 for numerical stability.
def create_mask_for_layer_norm(params):
flat_params = traverse_util.flatten_dict(params)
mask = {
# path: not (
# (path[-2] == "layer_norm" and path[-1] == "weight") or
# (path[-2] == "final_layer_norm" and path[-1] == "weight") or
# (path[-2] == "o" and path[-1] == "kernel")
# )
# for path in flat_params
path: (
(path[-2] == "wi" and path[-1] == "weight") or
(path[-2] == "wo" and path[-1] == "weight") or
(path[-2] == "k" and path[-1] == "kernel") or
(path[-2] == "q" and path[-1] == "kernel") or
(path[-2] == "v" and path[-1] == "kernel") or
(path[-2] == "shared" and path[-1] == "embedding")
) for path in flat_params
}
mask = traverse_util.unflatten_dict(mask)
return mask
# borrowed from transformers modeling_flax_utils
def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
"""
Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
"""
# taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
def conditional_cast(param):
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
param = param.astype(dtype)
return param
if mask is None:
return jax.tree_util.tree_map(conditional_cast, params)
flat_params = traverse_util.flatten_dict(params)
flat_mask, _ = jax.tree_util.tree_flatten(mask)
for masked, key in zip(flat_mask, sorted(flat_params.keys())):
if masked:
flat_params[key] = conditional_cast(flat_params[key])
return traverse_util.unflatten_dict(flat_params)
mask = create_mask_for_layer_norm(params)
# override params with the bfloat16-cast version
params = cast_floating_to(params, jnp.bfloat16, mask)
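# %%
# Optional check (illustrative): count parameter leaves per dtype to confirm the cast
# mask left the layer norms and the attention output projection `o` in float32.
from collections import Counter
print(Counter(str(leaf.dtype) for leaf in jax.tree_util.tree_leaves(params)))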
# %%
# # Function to extract shape and dtype without showing values
# def format_param(param):
# return f"shape={param.shape}, dtype={param.dtype}"
#
# # Use jax.tree_map to apply the formatter across the parameter tree
# formatted_params = jax.tree.map(format_param, model.params)
#
# # Pretty-print the tree
# for k, v in formatted_params.items():
# print(f"{k}: {v}")
# %%
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")  # noqa: B009
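# %%
# Illustrative example (not used in training): `shift_tokens_right` prepends the decoder
# start token and drops the last position, turning labels into decoder inputs.
# For t5-base both pad_token_id and decoder_start_token_id are 0.
_demo_labels = np.array([[37, 123, 5, 1, 0]])
print(shift_tokens_right_fn(_demo_labels, config.pad_token_id, config.decoder_start_token_id))
# expected: [[  0  37 123   5   1]]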
# %%
# In Flax, seq2seq models do not accept `labels`, so we must build `decoder_input_ids`
# ourselves with the `shift_tokens_right` function dynamically imported from the model
# module above.
# Given a dataset entry, run it through the tokenizer.
# padding="max_length" gives fixed-length inputs, which the jitted train step requires.
def preprocess_function(example):
inputs = example['input']
targets = example['output']
# text_target sets the corresponding label to inputs
# there is no need to create a separate 'labels'
# produce input_ids and decoder_input_ids
model_inputs = tokenizer(
inputs,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="np"
)
labels = tokenizer(
text_target=targets,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="np"
)
# for loss computation
model_inputs["labels"] = labels["input_ids"]
# make decoder input ids
decoder_input_ids = shift_tokens_right_fn(
labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
)
    # required by the model
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
# We need decoder_attention_mask so we can ignore pad tokens from loss
model_inputs["decoder_attention_mask"] = labels["attention_mask"]
return model_inputs
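# %%
# Quick illustration using the same 'input'/'output' fields as above (example strings are
# made up): every field comes back as a fixed-length (batch, max_length) array.
_demo_row = preprocess_function({"input": ["pump, flow rate"], "output": ["<THING_START>pump<THING_END>"]})
print({k: np.asarray(v).shape for k, v in _demo_row.items()})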
# %%
# Dataset.map applies preprocess_function to each batch of rows
# (i.e. the records at the immediate level of nesting)
token_datasets = split_datasets.map(
preprocess_function,
batched=True,
num_proc=1,
    # drop the original text columns so only the tokenized features remain
remove_columns=split_datasets["train"].column_names,
)
train_dataset = token_datasets["train"]
# %%
token_datasets.set_format(
type='numpy',
columns=[
'input_ids', 'attention_mask',
'labels', 'decoder_input_ids',
'decoder_attention_mask']
)
# %%
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
"""
Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
"""
if shuffle:
batch_idx = jax.random.permutation(rng, len(dataset))
batch_idx = np.asarray(batch_idx)
else:
batch_idx = np.arange(len(dataset))
if drop_last:
steps_per_epoch = len(dataset) // batch_size
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
else:
steps_per_epoch = math.ceil(len(dataset) / batch_size)
batch_idx = np.array_split(batch_idx, steps_per_epoch)
for idx in batch_idx:
batch = dataset[idx]
batch = {k: v for k, v in batch.items()}
yield batch
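# %%
# Optional sketch: pull one shuffled batch and confirm every field has shape
# (train_batch_size, max_length) before it is sharded across devices.
_demo_batch = next(data_loader(jax.random.PRNGKey(0), train_dataset, train_batch_size, shuffle=True))
print({k: v.shape for k, v in _demo_batch.items()})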
# %%
# Initialize our training
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)
# %%
# optimization functions
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
return schedule_fn
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
len(train_dataset),
train_batch_size,
num_train_epochs,
warmup_steps,
learning_rate,
)
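# %%
# Optional check: with warmup_steps = 0 the schedule starts at learning_rate and
# decays linearly to 0 at total_train_steps.
for _step in (0, total_train_steps // 2, total_train_steps - 1):
    print(_step, float(linear_decay_lr_schedule_fn(_step)))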
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["final_layer_norm", "layer_norm"]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
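# %%
# Optional check: layer-norm scales (and any biases) should be excluded from weight decay.
_flat_decay_mask = traverse_util.flatten_dict(decay_mask_fn(params))
print(sum(_flat_decay_mask.values()), "of", len(_flat_decay_mask), "parameter tensors receive weight decay")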
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=adam_beta1,
b2=adam_beta2,
eps=adam_epsilon,
weight_decay=weight_decay,
mask=decay_mask_fn,
)
# %%
# Training functions
class TrainState(train_state.TrainState):
dropout_rng: jnp.ndarray
    # replicate() copies the train state to every local device (simple data
    # parallelism) and folds the dropout rng so each device gets a distinct key
def replicate(self):
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
# Setup train state
# input all the state here
state = TrainState.create(apply_fn=model.__call__, params=params, tx=adamw, dropout_rng=dropout_rng)
# label smoothed cross entropy
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
"""
The label smoothing implementation is adapted from Flax's official example:
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
"""
vocab_size = logits.shape[-1]
confidence = 1.0 - label_smoothing_factor
low_confidence = (1.0 - confidence) / (vocab_size - 1)
normalizing_constant = -(
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
)
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
loss = optax.softmax_cross_entropy(logits, soft_labels)
loss = loss - normalizing_constant
# ignore padded tokens from loss
loss = loss * padding_mask
loss = loss.sum()
num_labels = padding_mask.sum()
return loss, num_labels
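# %%
# Tiny worked example (illustrative): with label_smoothing_factor = 0 this is plain
# cross-entropy, masked so padded positions contribute nothing.
_logits = jnp.zeros((1, 3, 5))          # uniform prediction over a toy 5-token vocab
_labels = jnp.array([[1, 2, 0]])
_pad_mask = jnp.array([[1, 1, 0]])      # third position is padding
_loss_sum, _num = loss_fn(_logits, _labels, _pad_mask)
print(float(_loss_sum) / float(_num))   # ~log(5) ≈ 1.609 per non-padded token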
# Define gradient update step fn
# (no @jax.jit needed here: jax.pmap below already compiles the step and binds the
#  "batch" axis name that the psum collectives rely on)
def train_step(state, batch, label_smoothing_factor=0.0):
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
def compute_loss(params):
labels = batch.pop("labels")
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
return loss, num_labels
# compute gradients through computational graph
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
(loss, num_labels), grad = grad_fn(state.params)
num_labels = jax.lax.psum(num_labels, "batch")
    # true loss = total loss / total number of non-padding tokens
loss = jax.lax.psum(loss, "batch")
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
    # true grad = total grad / total number of non-padding tokens
grad = jax.lax.psum(grad, "batch")
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
return new_state, metrics
# return new_state
# Define generation function
max_length = (
val_max_target_length if val_max_target_length is not None else model.config.max_length
)
num_beams = num_beams if num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
# def generate_step(params, batch):
# model.params = params
# output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
# return output_ids.sequences
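# Hedged sketch (kept commented out, like the pmapped `generate_step` above) of how the
# same `gen_kwargs` would drive single-host decoding after training:
# _enc = tokenizer(["example input text"], return_tensors="np", padding="max_length",
#                  max_length=max_length, truncation=True)
# _out = model.generate(_enc["input_ids"], attention_mask=_enc["attention_mask"], **gen_kwargs)
# print(tokenizer.batch_decode(_out.sequences, skip_special_tokens=True))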
# Create parallel version of the train and eval step
p_train_step = jax.pmap(
partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,)
)
# p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch")
# p_generate_step = jax.pmap(generate_step, "batch")
# Replicate the train state on each device
state = state.replicate()
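# %%
# Optional check: after replication every parameter leaf carries a leading axis of size
# jax.device_count(), the same leading axis `shard` adds to every batch below.
print(jax.device_count(), jax.tree_util.tree_leaves(state.params)[0].shape)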
# %%
print("***** Running training *****")
print(f" Num examples = {len(train_dataset)}")
print(f" Num Epochs = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
print(f" Total optimization steps = {total_train_steps}")
# %%
# jax.profiler.start_trace("./traces")
rng, input_rng = jax.random.split(rng)
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
    train_start = time.time()
    # Create sampling rng (split per epoch so each epoch sees a different shuffle order)
    rng, input_rng = jax.random.split(rng)
    train_metrics = []
    train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
steps_per_epoch = len(train_dataset) // train_batch_size
# Generate an epoch by shuffling sampling indices from the train dataset
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
batch = next(train_loader)
batch = shard(batch)
state, train_metric = p_train_step(state, batch)
# train_metrics.append(train_metric)
train_time = time.time() - train_start
train_metric = unreplicate(train_metric)
train_metric['loss'].block_until_ready()
epochs.write(
# f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, "
f"Epoch... ({epoch + 1}/{num_epochs} | "
f"Learning Rate:{train_metric['learning_rate']}, "
f"Last train time: {train_time})"
)
# jax.profiler.stop_trace()
# %%
output_dir = save_path
# save the final checkpoint and tokenizer (main process only)
if jax.process_index() == 0:
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
model.save_pretrained(output_dir, params=params)
tokenizer.save_pretrained(output_dir)
# %%
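# Hedged sketch (illustrative, commented out): reloading the saved checkpoint for inference.
# model_reloaded = FlaxT5ForConditionalGeneration.from_pretrained(save_path)
# tokenizer_reloaded = T5TokenizerFast.from_pretrained(save_path)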