learn_jax/t5_jax_retrieval.py

625 lines
20 KiB
Python
Raw Normal View History

# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.16.4
# kernelspec:
# display_name: jax
# language: python
# name: python3
# ---
# %% [markdown]
# # T5 implementation using jax
# %% [markdown]
# ## import
# %% [raw]
# import json
# import logging
# import math
# import os
# import sys
# import time
# from dataclasses import asdict, dataclass, field
# from enum import Enum
# from functools import partial
# from pathlib import Path
# from typing import Callable, Optional
#
# import datasets
# import evaluate
# import jax
# import jax.numpy as jnp
# import nltk # Here to have a nice missing dependency error message early on
# import numpy as np
# import optax
# from datasets import Dataset, load_dataset
# from filelock import FileLock
# from flax import jax_utils, traverse_util
# from flax.jax_utils import pad_shard_unpad, unreplicate
# from flax.training import train_state
# from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
# from tqdm import tqdm
#
# import transformers
# from transformers import (
# CONFIG_MAPPING,
# FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
# AutoConfig,
# AutoTokenizer,
# FlaxAutoModelForSeq2SeqLM,
# HfArgumentParser,
# is_tensorboard_available,
# )
# from transformers.utils import is_offline_mode, send_example_telemetry
#
#
# logger = logging.getLogger(__name__)
#
# try:
# nltk.data.find("tokenizers/punkt")
# except (LookupError, OSError):
# if is_offline_mode():
# raise LookupError(
# "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
# )
# with FileLock(".lock") as lock:
# nltk.download("punkt", quiet=True)
#
#
# MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
# MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
# %%
import jax
import jax.numpy as jnp
import optax
import numpy as np
from functools import partial
from typing import Callable, Optional
import math
# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "high")
jax.config.update("jax_enable_x64", False)
from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
import datasets
from datasets import Dataset, load_dataset
import evaluate
import nltk # Here to have a nice missing dependency error message early on
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
import time
# %%
import os
# XLA compiler flags, handed over through the XLA_FLAGS environment variable.
# NOTE(review): this assignment replaces any XLA_FLAGS value already present in the
# environment, and it runs after `import jax` in the previous cell — presumably the
# flags are still honoured because no backend has been initialised yet; confirm.
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_triton_softmax_fusion=True '
    '--xla_gpu_triton_gemm_any=True '
)
# Device selection and NCCL settings for this process.
# NOTE(review): the CUDA_VISIBLE_DEVICES value has a space after each comma —
# verify that all four GPUs are actually visible with this spelling.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "0, 1, 2, 3",
    "NCCL_LL128_BUFFSIZE": "-2",
    "NCCL_LL_BUFFSIZE": "-2",
    "NCCL_PROTO": "SIMPLE,LL,LL128",
})
# %%
# Report which platform JAX selected for computation. `jax.default_backend()` is the
# public API for this; the private `jax.lib.xla_bridge` module used previously is
# deprecated and is not importable on recent JAX releases.
print(jax.default_backend())
# %%
# nltk.download('punkt')
# Make sure the NLTK "punkt" sentence-tokenizer data is present, downloading it if needed.
try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    # Only this fallback path needs these two names, and the active import cell of
    # this script does not bring them in, so import them here.
    from filelock import FileLock
    from transformers.utils import is_offline_mode

    if is_offline_mode():
        raise LookupError(
            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
        )
    # Hold a file lock while downloading so concurrent runs in the same working
    # directory do not download at the same time.
    with FileLock(".lock"):
        nltk.download("punkt", quiet=True)
# %% [markdown]
# ## Prepare datasets
# %%
# load model
# Checkpoint identifier; used for both the configuration and the weights below.
model_name_or_path = "t5-small" # Replace with your specific model name
# Load configuration
# `config` supplies pad_token_id / decoder_start_token_id to the preprocessing step.
config = AutoConfig.from_pretrained(model_name_or_path)
# Load model
# The Flax model object is used for the forward pass, generation and checkpoint saving.
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
    model_name_or_path
)
# %%
# Look up `shift_tokens_right` in the module that defines the loaded model class.
# It is used during preprocessing to build `decoder_input_ids` from the labels.
# A non-empty `fromlist` makes `__import__` return the leaf module instead of the
# top-level package.
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
shift_tokens_right_fn = model_module.shift_tokens_right
# %%
from tqdm import tqdm
from datasets import load_from_disk
# Path to saved combined_dataset
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
file_path = '/home/richard/Projects/learn_t5/retrieval/combined_data_t5'
# Directory that the model and tokenizer are written to by the checkpoint step of the training loop.
save_path = 't5_80_1_retrieval'
# file_path = 'combined_data'
# The loaded object is indexed by split name later on ("train", "validation").
split_datasets = load_from_disk(file_path)
# prepare tokenizer
from transformers import T5TokenizerFast
# NOTE(review): the tokenizer is loaded from "t5-base" while the model checkpoint is
# "t5-small" — presumably both share the same vocabulary; confirm.
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
# Define additional special tokens
# additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
# Define additional special tokens
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>",
                             "<CONTEXT>", "<EXAMPLE>", "<INPUT>", "<OUTPUT>"]
# Add the additional special tokens to the tokenizer
# NOTE(review): the model's embedding matrix is not resized anywhere in this script —
# verify that the checkpoint's vocabulary size already covers the ids of the added tokens.
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
# Fixed sequence length used for padding/truncation of both inputs and targets.
max_length = 300
# In Flax, for seq2seq models we need to pass `decoder_input_ids`
# as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
# for that dynamically import the `shift_tokens_right` function from the model file
# given a dataset entry, run it through the tokenizer
# Setting padding="max_length" as we need fixed length inputs for jitted functions
def preprocess_function(example):
    """Tokenize a batch of examples into fixed-length encoder and decoder arrays.

    Args:
        example: Mapping with an ``'input'`` entry (source texts) and an
            ``'output'`` entry (target texts), as passed by ``Dataset.map``.

    Returns:
        The source encoding returned by the tokenizer, extended with
        ``labels`` (target token ids), ``decoder_input_ids`` (the labels passed
        through ``shift_tokens_right_fn``) and ``decoder_attention_mask`` (the
        target attention mask, used to ignore padding in the loss).

    The sequence length is read from the module-level ``max_length`` at call time.
    """
    sources = example['input']
    targets = example['output']
    # Setting padding="max_length" as we need fixed length inputs for jitted functions
    model_inputs = tokenizer(
        sources,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )
    # Tokenize the targets on their own. When `text` and `text_target` are passed in
    # the same call, the returned `input_ids` / `attention_mask` describe the source
    # text (the target ids only appear under `labels`), so reading them as labels
    # would make the decoder learn to reproduce the input.
    labels = tokenizer(
        text_target=targets,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )
    model_inputs["labels"] = labels["input_ids"]
    # Flax models do not accept `labels`, so the decoder inputs are prepared here.
    decoder_input_ids = shift_tokens_right_fn(
        labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
    )
    model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
    # We need decoder_attention_mask so we can ignore pad tokens from loss
    model_inputs["decoder_attention_mask"] = labels["attention_mask"]
    return model_inputs
# Apply preprocess_function to every split. With batched=True the function receives
# a batch of rows at a time (a mapping from column name to a list of values) rather
# than a single row. The original columns of the "train" split are dropped so only
# the tokenizer outputs remain.
tokenized_datasets = split_datasets.map(
    preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=split_datasets["train"].column_names,
)
# %%
# Bare expression: displays the object when run as a notebook cell.
tokenized_datasets
# %%
# Splits consumed by the training loop and by the (currently disabled) evaluation code.
train_dataset = tokenized_datasets["train"]
eval_dataset = tokenized_datasets["validation"]
# %%
def data_loader(rng: jax.random.PRNGKey, dataset: "Dataset", batch_size: int, shuffle: bool = False, drop_last=True):
    """Yield batches of examples from `dataset` as dicts of NumPy arrays.

    Args:
        rng: PRNG key used to permute the example order; only read when
            `shuffle` is true.
        dataset: Object supporting `len()` and indexing with an integer index
            array, where indexing returns a mapping of column name to values.
        batch_size: Number of examples per batch.
        shuffle: Whether to permute the example order before batching.
        drop_last: If true, every batch has exactly `batch_size` examples and
            the leftover examples are skipped. If false, all examples are
            split into `ceil(len(dataset) / batch_size)` batches of nearly
            equal size, so no batch exceeds `batch_size` but several may be
            smaller.

    Yields:
        One dict per batch, mapping each column name to an `np.ndarray`.
        An empty dataset yields nothing.
    """
    num_examples = len(dataset)
    if shuffle:
        batch_idx = np.asarray(jax.random.permutation(rng, num_examples))
    else:
        batch_idx = np.arange(num_examples)

    if drop_last:
        steps_per_epoch = num_examples // batch_size
        batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
        batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    else:
        steps_per_epoch = math.ceil(num_examples / batch_size)
        if steps_per_epoch == 0:
            # np.array_split refuses to split into zero sections; an empty dataset
            # simply produces no batches.
            return
        batch_idx = np.array_split(batch_idx, steps_per_epoch)

    for idx in batch_idx:
        batch = dataset[idx]
        yield {k: np.array(v) for k, v in batch.items()}
# %% [markdown]
# Now we have model inputs in terms of the variable tokenized_datasets
# %%
# metric
# Loaded once at module level; compute_metrics reads its "score" field and reports it as "bleu".
metric = evaluate.load("sacrebleu")
def postprocess_text(preds, labels):
    """Strip every string and rewrite it with one sentence per line.

    Args:
        preds: Iterable of predicted strings.
        labels: Iterable of reference strings.

    Returns:
        A `(preds, labels)` tuple of lists in which each string has been
        stripped, split into sentences with NLTK and re-joined with newlines.
    """

    def _one_sentence_per_line(text):
        # rougeLSum expects newline after each sentence
        return "\n".join(nltk.sent_tokenize(text.strip()))

    return (
        [_one_sentence_per_line(pred) for pred in preds],
        [_one_sentence_per_line(label) for label in labels],
    )
# def compute_metrics(preds, labels):
# decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
# decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
#
# # Some simple post-processing
# decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
#
# result = metric.compute(predictions=decoded_preds, references=decoded_labels)
# result = {k: round(v * 100, 4) for k, v in result.items()}
# prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
# result["gen_len"] = np.mean(prediction_lens)
# return result
def compute_metrics(preds, labels):
    """Decode predictions and labels and score them with the module-level metric.

    Args:
        preds: Predicted token ids, or a tuple whose first element holds them.
        labels: Reference token ids; positions equal to -100 are replaced by
            the tokenizer's pad id before decoding.

    Returns:
        A dict with the single key ``"bleu"`` holding the metric's ``"score"``.
    """
    # Some models hand back more than the predictions; keep the first entry only.
    predicted_ids = preds[0] if isinstance(preds, tuple) else preds
    hypothesis_texts = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)

    # -100 cannot be decoded, so substitute the pad token id at those positions.
    decodable_labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    reference_texts = tokenizer.batch_decode(decodable_labels, skip_special_tokens=True)

    hypotheses = [text.strip() for text in hypothesis_texts]
    # Each prediction is paired with a list of references (one reference here).
    references = [[text.strip()] for text in reference_texts]

    scores = metric.compute(predictions=hypotheses, references=references)
    return {"bleu": scores["score"]}
# %% [markdown]
# # Model
# %%
# Training configuration constants.
seed = 117
num_epochs = 80
batch_size = 36
num_train_epochs = num_epochs
# Per-device sizes are multiplied by the number of devices to get the global batch sizes.
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
# Floor division: examples that do not fill a complete global batch are not counted.
steps_per_epoch = len(train_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
# Learning-rate schedule and AdamW settings.
warmup_steps = 0
learning_rate = 5e-5
weight_decay = 0.0
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
# Passed to the train and eval steps; 0.0 leaves the targets as plain one-hot vectors.
label_smoothing_factor = 0.0
# Generation settings; a value of None falls back to the model config further down.
num_beams = 1
val_max_target_length = None
# Only referenced by the commented-out evaluation code in the visible part of this script.
predict_with_generate = True
# %%
# Initialize our training
# Split the seed key into a main key (reused for shuffling) and a dropout key for the train state.
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)
# %%
# optimization functions
def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Build a learning-rate schedule with a linear warmup followed by a linear decay.

    Args:
        train_ds_size: Number of training examples.
        train_batch_size: Global batch size; incomplete batches are not counted.
        num_train_epochs: Number of passes over the training data.
        num_warmup_steps: Steps spent ramping from 0 up to `learning_rate`.
        learning_rate: Peak learning rate reached at the end of the warmup.

    Returns:
        A schedule function mapping a step count to a learning rate.
    """
    total_steps = (train_ds_size // train_batch_size) * num_train_epochs
    ramp_up = optax.linear_schedule(
        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
    )
    ramp_down = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=total_steps - num_warmup_steps
    )
    # Switch from the warmup ramp to the decay ramp once the warmup steps are done.
    return optax.join_schedules(schedules=[ramp_up, ramp_down], boundaries=[num_warmup_steps])
# Create learning rate schedule
# The same callable is handed to the optimizer and queried inside train_step to
# report the current learning rate.
linear_decay_lr_schedule_fn = create_learning_rate_fn(
    len(train_dataset),
    train_batch_size,
    num_train_epochs,
    warmup_steps,
    learning_rate,
)
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
    """Return a boolean mask tree marking which parameters get weight decay.

    Args:
        params: Nested parameter dict.

    Returns:
        A dict with the same nesting as `params`. A leaf is True unless its
        last path component is ``"bias"`` or its last two path components match
        those of a parameter whose concatenated path contains a LayerNorm marker.
    """
    flat = traverse_util.flatten_dict(params)
    markers = ("layernorm", "layer_norm", "ln")

    def _mentions_layer_norm(path):
        # Path components are concatenated without a separator before matching.
        joined = "".join(path).lower()
        return any(marker in joined for marker in markers)

    # Identify LayerNorm parameters by the last two components of their path.
    norm_suffixes = {path[-2:] for path in flat if _mentions_layer_norm(path)}

    mask = {}
    for path in flat:
        is_bias = path[-1] == "bias"
        mask[path] = not is_bias and path[-2:] not in norm_suffixes
    return traverse_util.unflatten_dict(mask)
# create adam optimizer
# AdamW driven by the schedule defined above; decay_mask_fn selects which
# parameters weight decay applies to.
# NOTE(review): weight_decay is configured as 0.0 in this script, so presumably the
# mask has no practical effect at the moment — confirm if decay is enabled later.
adamw = optax.adamw(
    learning_rate=linear_decay_lr_schedule_fn,
    b1=adam_beta1,
    b2=adam_beta2,
    eps=adam_epsilon,
    weight_decay=weight_decay,
    mask=decay_mask_fn,
)
# %%
# Training functions
class TrainState(train_state.TrainState):
    """Train state that additionally carries the PRNG key used for dropout."""

    dropout_rng: jnp.ndarray

    def replicate(self):
        """Return a copy replicated across devices, with the dropout key sharded per device."""
        replicated = jax_utils.replicate(self)
        per_device_keys = shard_prng_key(self.dropout_rng)
        return replicated.replace(dropout_rng=per_device_keys)
# Start from the parameters held by the loaded model object.
params = model.params
# Cast every parameter leaf to bfloat16.
# NOTE(review): the train state is built from the bfloat16 copy, so presumably the
# optimizer updates bfloat16 weights directly with no float32 master copy — confirm
# this precision is intended.
params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
# Setup train state
# The model's __call__ is stored as apply_fn and invoked from train_step.
state = TrainState.create(apply_fn=model.__call__, params=params_bf16, tx=adamw, dropout_rng=dropout_rng)
# label smoothed cross entropy
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    """Compute the summed label-smoothed cross entropy over non-padded positions.

    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104

    Args:
        logits: Model outputs; the size of the last axis is taken as the vocabulary size.
        labels: Target token ids.
        padding_mask: Multiplied into the per-position loss, so zero entries
            remove a position from the total.
        label_smoothing_factor: Probability mass moved away from the true
            class; 0.0 gives plain one-hot targets.

    Returns:
        A `(loss, num_labels)` pair: the masked loss summed over all positions
        and the sum of `padding_mask`. The caller performs the division.
    """
    num_classes = logits.shape[-1]
    on_value = 1.0 - label_smoothing_factor
    off_value = (1.0 - on_value) / (num_classes - 1)
    # Constant subtracted from every position's loss; it depends only on the
    # smoothing values. The small epsilon keeps the log finite when off_value is 0.
    entropy_offset = -(
        on_value * jnp.log(on_value) + (num_classes - 1) * off_value * jnp.log(off_value + 1e-20)
    )
    smoothed_targets = onehot(labels, num_classes, on_value=on_value, off_value=off_value)
    per_position = optax.softmax_cross_entropy(logits, smoothed_targets) - entropy_offset
    # ignore padded tokens from loss
    masked_total = (per_position * padding_mask).sum()
    return masked_total, padding_mask.sum()
# Define gradient update step fn
def train_step(state, batch, label_smoothing_factor=0.0):
    """Run one optimisation step; meant to be executed under `jax.pmap` with axis name "batch".

    Args:
        state: Current train state (parameters, optimizer state, dropout key).
        batch: Dict of model inputs that also contains "labels" and
            "decoder_attention_mask". The "labels" entry is popped from it.
        label_smoothing_factor: Forwarded to `loss_fn`.

    Returns:
        `(new_state, metrics)` where `metrics` holds the loss averaged over all
        counted label positions across devices and the current learning rate.
    """
    step_rng, next_rng = jax.random.split(state.dropout_rng)

    def compute_loss(params):
        targets = batch.pop("labels")
        outputs = state.apply_fn(**batch, params=params, dropout_rng=step_rng, train=True)
        return loss_fn(outputs[0], targets, batch["decoder_attention_mask"], label_smoothing_factor)

    # The summed loss is differentiated; the label count rides along as auxiliary output.
    differentiate = jax.value_and_grad(compute_loss, has_aux=True)
    (summed_loss, token_count), grads = differentiate(state.params)

    # Normalise by the label count gathered from every device on the "batch" axis.
    token_count = jax.lax.psum(token_count, "batch")
    mean_loss = jax.lax.psum(summed_loss, "batch") / token_count
    grads = jax.lax.psum(grads, "batch")
    grads = jax.tree_util.tree_map(lambda g: g / token_count, grads)

    updated_state = state.apply_gradients(grads=grads, dropout_rng=next_rng)
    step_metrics = {"loss": mean_loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
    return updated_state, step_metrics
# Define eval fn
def eval_step(params, batch, label_smoothing_factor=0.0):
    """Compute the evaluation loss for one batch; meant to run under `jax.pmap` with axis name "batch".

    Args:
        params: Model parameters.
        batch: Dict of model inputs that also contains "labels" and
            "decoder_attention_mask". The "labels" entry is popped from it.
        label_smoothing_factor: Forwarded to `loss_fn`.

    Returns:
        A dict with the key "loss": the summed loss divided by the label count,
        both accumulated across devices.
    """
    targets = batch.pop("labels")
    logits = model(**batch, params=params, train=False)[0]
    summed_loss, token_count = loss_fn(
        logits, targets, batch["decoder_attention_mask"], label_smoothing_factor
    )
    token_count = jax.lax.psum(token_count, "batch")
    # true loss = total loss / total samples
    return {"loss": jax.lax.psum(summed_loss, "batch") / token_count}
# Generation settings: explicit values win, otherwise fall back to the model config.
# NOTE(review): this rebinds the module-level `max_length` that preprocess_function
# reads at call time — re-running the preprocessing after this point would pad and
# truncate to the generation length instead of the tokenization length.
max_length = (
    val_max_target_length if val_max_target_length is not None else model.config.max_length
)
num_beams = num_beams if num_beams is not None else model.config.num_beams
# Keyword arguments forwarded to model.generate by generate_step.
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
def generate_step(params, batch):
    """Generate output token ids for a batch using the given parameters.

    Args:
        params: Parameters to generate with; they are assigned to `model.params`.
        batch: Dict containing "input_ids" and "attention_mask".

    Returns:
        The `sequences` attribute of the generation output.
    """
    model.params = params
    source_ids, source_mask = batch["input_ids"], batch["attention_mask"]
    return model.generate(source_ids, attention_mask=source_mask, **gen_kwargs).sequences
# Create parallel version of the train and eval step
# "batch" is the pmap axis name; it must match the axis name used by the psum
# calls inside train_step and eval_step.
# donate_argnums=(0,) donates the first argument (the train state) to the call.
p_train_step = jax.pmap(
    partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,)
)
p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch")
p_generate_step = jax.pmap(generate_step, "batch")
# Replicate the train state on each device
state = state.replicate()
# %%
# Print a summary of the run configuration before training starts.
print("***** Running training *****")
print(f" Num examples = {len(train_dataset)}")
print(f" Num Epochs = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
print(f" Total optimization steps = {total_train_steps}")
# %%
# Main training loop: one shuffled pass over the training set per epoch, followed by
# a checkpoint of the model and tokenizer.
train_time = 0
# The number of full global batches is the same for every epoch, so compute it once.
steps_per_epoch = len(train_dataset) // train_batch_size
if steps_per_epoch == 0:
    # With no full batch the inner loop never runs, and the epoch summary would
    # otherwise fail on an unbound `train_metric`.
    raise ValueError(
        f"train_dataset has {len(train_dataset)} examples, which is fewer than one "
        f"global batch of {train_batch_size}; nothing would be trained"
    )
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
# epochs = range(num_epochs)
for epoch in epochs:
    # ======================== Training ================================
    train_start = time.time()
    # Create sampling rng
    rng, input_rng = jax.random.split(rng)
    train_metrics = []
    # Generate an epoch by shuffling sampling indices from the train dataset
    train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
    # train
    for batch in tqdm(train_loader, total=steps_per_epoch, desc="Training...", position=1, leave=False):
        # Split the global batch along its leading axis, one slice per device.
        batch = shard(batch)
        state, train_metric = p_train_step(state, batch)
        train_metrics.append(train_metric)
    train_time += time.time() - train_start
    # Report the metrics of the last step of the epoch.
    train_metric = unreplicate(train_metric)
    epochs.write(
        f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
        f" {train_metric['learning_rate']})"
    )
    # ======================== Evaluating ==============================
    # eval_metrics = []
    # eval_preds = []
    # eval_labels = []
    # eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
    # eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
    # for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
    #     # Model forward
    #     batch = next(eval_loader)
    #     labels = batch["labels"]
    #     metrics = pad_shard_unpad(p_eval_step, static_return=True)(
    #         state.params, batch, min_device_batch=per_device_eval_batch_size
    #     )
    #     eval_metrics.append(metrics)
    #     # generation
    #     if predict_with_generate:
    #         generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
    #         eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
    #         eval_labels.extend(labels)
    # # normalize eval metrics
    # eval_metrics = get_metrics(eval_metrics)
    # eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
    # compute metrics
    # rouge_desc = ""
    # if predict_with_generate:
    #     rouge_metrics = compute_metrics(eval_preds, eval_labels)
    #     eval_metrics.update(rouge_metrics)
    #     rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
    # # Print metrics and update progress bar
    # desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
    # epochs.write(desc)
    # epochs.desc = desc
    # Save metrics
    # if has_tensorboard and jax.process_index() == 0:
    #     cur_step = epoch * (len(train_dataset) // train_batch_size)
    #     write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
    output_dir = save_path
    # save checkpoint after each epoch and push checkpoint to the hub
    # Only the first process writes, so multi-host runs do not write the same files twice.
    if jax.process_index() == 0:
        # Take the first device's copy of the replicated parameters and move it to host memory.
        params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
        model.save_pretrained(output_dir, params=params)
        tokenizer.save_pretrained(output_dir)
# %% [markdown]
# #