# learn_jax/t5_jax_prediction.py

# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.16.4
# ---
# %% [markdown]
# # prediction code
# ## import and process test data
# %%
# import libraries
import pandas as pd
import matplotlib.pyplot as plt
from datasets import Dataset, DatasetDict
import jax
import jax.numpy as jnp
import optax
import numpy as np
from functools import partial
from typing import Callable, Optional
import math
# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "high")
jax.config.update("jax_enable_x64", False)
from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
import datasets
from datasets import Dataset, load_dataset
import evaluate
from tqdm import tqdm
from datasets import load_from_disk
import nltk # Here to have a nice missing dependency error message early on
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
import time
# %%
# data_path = f"../make_data/select_db/data_mapping_filtered.csv"
# data_path = f"../make_data_2/select_db/dataset/1/train_all.csv"
data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv'
# data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv'
# Ensure to include 'ships_idx' in the fields list
fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']
# Load the dataset
df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)
def process_df(df):
    """Convert the tag table into prompt/target records for the seq2seq model.

    Args:
        df: DataFrame containing at least the columns ``tag_name``,
            ``tag_description``, ``thing`` and ``property``.

    Returns:
        A list with one dict per row (in row order) holding:
        ``input``  -- "<NAME>{tag_name}<NAME><DESC>{tag_description}<DESC>",
        ``output`` -- "<THING_START>{thing}<THING_END><PROPERTY_START>{property}<PROPERTY_END>",
        ``answer`` -- "{thing} {property}",
        ``answer_thing`` / ``answer_property`` -- the raw ``thing`` / ``property`` values.
    """
    # to_dict("records") over just the needed columns avoids building a pandas
    # Series for every row, which is what made the iterrows() version slow.
    records = df[["tag_name", "tag_description", "thing", "property"]].to_dict("records")
    return [
        {
            "input": f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC>",
            # Alternative input formats tried earlier:
            # 'input': f"<DESC>{row['tag_description']}<DESC>",
            # 'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>",
            # 'input': f"<DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>",
            "output": f"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>",
            "answer": f"{row['thing']} {row['property']}",
            "answer_thing": row["thing"],
            "answer_property": row["property"],
        }
        for row in records
    ]
# Wrap the processed records in a Hugging Face Dataset
# (the original note: this step took ~1 minute without batching).
test_dataset = Dataset.from_list(process_df(df=df))
# %% [markdown]
# ## Load model for attributes
# %%
# Load the fine-tuned checkpoint: configuration first, then the Flax model,
# both from the same local path.
model_name_or_path = "t5_80_1"  # Replace with your specific model name
config = AutoConfig.from_pretrained(model_name_or_path)
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
# %% [markdown]
# ## Tokenizer
# %%
# prepare tokenizer
from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
# Extra special tokens used as field markers in the input/output strings.
# NOTE(review): "SIG", "UNIT" and "DATA_TYPE" lack the angle brackets the other
# markers use (and the commented-out input format writes "<UNIT>"). They are left
# unchanged because the added token ids must match the tokenizer the checkpoint
# was trained with -- confirm against the training script.
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
# Fixed padding/truncation length used by preprocess_function.
max_length = 86
# Fetch shift_tokens_right from the module that defines the loaded model class
# (fromlist previously misspelled the name as "shift_tokens_tight").
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
# Tokenize a batch of dataset rows. padding="max_length" is used because the
# jitted/pmapped functions need fixed-length inputs.
def preprocess_function(example):
    """Tokenize a batch of examples for the seq2seq model.

    Args:
        example: a batch as passed by ``Dataset.map(batched=True)``; the
            ``"input"`` column is the source text and ``"output"`` the target text.

    Returns:
        The tokenizer encoding of the sources (``input_ids``, ``attention_mask``)
        extended with ``labels`` (target token ids), ``decoder_input_ids``
        (labels shifted right via ``shift_tokens_right_fn``) and
        ``decoder_attention_mask`` (attention mask of the targets).
    """
    sources = example["input"]
    targets = example["output"]

    model_inputs = tokenizer(
        sources,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )
    # Tokenize the targets on their own. The previous call also passed the
    # sources, which made ``input_ids`` the *source* ids, so the labels,
    # decoder inputs and decoder mask were all silently built from the inputs.
    target_encoding = tokenizer(
        text_target=targets,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )
    model_inputs["labels"] = target_encoding["input_ids"]
    decoder_input_ids = shift_tokens_right_fn(
        target_encoding["input_ids"], config.pad_token_id, config.decoder_start_token_id
    )
    model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
    # decoder_attention_mask lets a loss ignore padded target positions.
    model_inputs["decoder_attention_mask"] = target_encoding["attention_mask"]
    return model_inputs
# Tokenize the split in batches with preprocess_function; the raw text columns
# are removed so only the tokenized arrays remain.
test_dataset = test_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=test_dataset.column_names,
    num_proc=1,
)
def data_loader(rng: jax.random.PRNGKey, dataset: "Dataset", batch_size: int, shuffle: bool = False, drop_last=True):
    """Yield batches of ``batch_size`` examples from ``dataset`` as dicts of numpy arrays.

    Args:
        rng: PRNG key; only used when ``shuffle`` is True.
        dataset: indexable dataset; ``dataset[idx]`` with an index array must
            return a mapping of column name -> sequence of values.
        batch_size: number of examples per batch.
        shuffle: visit the examples in a random permutation drawn from ``rng``.
        drop_last: if True, drop the trailing incomplete batch. If False, every
            batch holds exactly ``batch_size`` examples except possibly the last,
            which holds the remaining 1..``batch_size`` examples.

    Yields:
        ``{column: np.ndarray}`` for each batch. Nothing is yielded for an empty dataset.
    """
    num_examples = len(dataset)
    if shuffle:
        batch_idx = np.asarray(jax.random.permutation(rng, num_examples))
    else:
        batch_idx = np.arange(num_examples)

    if drop_last:
        steps_per_epoch = num_examples // batch_size
        batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
        batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    else:
        # Fixed-size chunks. np.array_split spread the remainder over several
        # batches (10 examples, batch 4 -> 4, 3, 3) and raised on an empty dataset.
        batch_idx = [batch_idx[start : start + batch_size] for start in range(0, num_examples, batch_size)]

    for idx in batch_idx:
        batch = dataset[idx]
        yield {k: np.array(v) for k, v in batch.items()}
# %% [markdown]
# # Model setup and prediction
# %%
# -- reproducibility --
seed = 117

# -- epochs and batch sizes (global batch = per-device batch x device count) --
num_epochs = 80
num_train_epochs = num_epochs
batch_size = 96
per_device_train_batch_size = batch_size
per_device_eval_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
eval_batch_size = per_device_eval_batch_size * jax.device_count()
# NOTE(review): these step counts are derived from the *test* split and appear
# to be informational only here -- confirm before relying on them.
steps_per_epoch = len(test_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs

# -- optimizer hyper-parameters (consumed by the schedule / AdamW set up below) --
warmup_steps = 0
learning_rate = 5e-5
weight_decay = 0.0
adam_beta1, adam_beta2 = 0.9, 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0

# -- generation settings --
num_beams = 1
val_max_target_length = None
predict_with_generate = True

# -- PRNG: keep the remaining key in `rng`, split one off for dropout --
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)
def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Build a piecewise-linear learning-rate schedule.

    The rate rises linearly from 0 to ``learning_rate`` during the first
    ``num_warmup_steps`` steps. It then falls linearly to 0 over the remaining
    steps, where the total is ``num_train_epochs`` epochs of
    ``train_ds_size // train_batch_size`` steps each.
    """
    total_steps = (train_ds_size // train_batch_size) * num_train_epochs
    warmup = optax.linear_schedule(
        init_value=0.0,
        end_value=learning_rate,
        transition_steps=num_warmup_steps,
    )
    decay = optax.linear_schedule(
        init_value=learning_rate,
        end_value=0,
        transition_steps=total_steps - num_warmup_steps,
    )
    return optax.join_schedules(schedules=[warmup, decay], boundaries=[num_warmup_steps])
# Learning-rate schedule, sized from the loaded dataset and the global train batch size.
linear_decay_lr_schedule_fn = create_learning_rate_fn(
    train_ds_size=len(test_dataset),
    train_batch_size=train_batch_size,
    num_train_epochs=num_train_epochs,
    num_warmup_steps=warmup_steps,
    learning_rate=learning_rate,
)
# Weight-decay mask for optax: True marks a parameter that should be decayed;
# biases and LayerNorm parameters are excluded. (weight_decay is 0.0 in this
# script, so the mask has no numerical effect here.)
def decay_mask_fn(params):
    """Return a pytree of booleans with the same structure as ``params``.

    A leaf is False if its last key is ``"bias"``, or if its last two path
    components equal those of any parameter whose concatenated, lower-cased
    path contains "layernorm", "layer_norm" or "ln". Otherwise it is True.
    """
    flat = traverse_util.flatten_dict(params)
    norm_markers = ("layernorm", "layer_norm", "ln")

    # NOTE(review): path components are joined without a separator, so "ln"
    # can also match across a component boundary -- confirm this is intended.
    norm_suffixes = set()
    for path in flat:
        joined = "".join(path).lower()
        if any(marker in joined for marker in norm_markers):
            norm_suffixes.add(path[-2:])

    mask = {
        path: path[-1] != "bias" and path[-2:] not in norm_suffixes
        for path in flat
    }
    return traverse_util.unflatten_dict(mask)
# AdamW driven by the linear learning-rate schedule; decay_mask_fn selects
# which parameters receive weight decay.
adamw = optax.adamw(
    learning_rate=linear_decay_lr_schedule_fn,
    weight_decay=weight_decay,
    mask=decay_mask_fn,
    b1=adam_beta1,
    b2=adam_beta2,
    eps=adam_epsilon,
)
# %%
# Load a fresh copy of the checkpoint for this cell (the original note: to
# "prevent leakage of variables" from earlier cells).
model_name_or_path = "t5_80_1"  # Replace with your specific model name
config = AutoConfig.from_pretrained(model_name_or_path)
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
# Train-state container, carried over from the training setup.
class TrainState(train_state.TrainState):
    """Flax TrainState that additionally carries a dropout PRNG key."""

    dropout_rng: jnp.ndarray

    def replicate(self):
        """Replicate the state across devices; the dropout key is sharded with shard_prng_key."""
        replicated_state = jax_utils.replicate(self)
        sharded_rng = shard_prng_key(self.dropout_rng)
        return replicated_state.replace(dropout_rng=sharded_rng)
# Parameters come straight from the checkpoint loaded above.
params = model.params
# Cast every parameter leaf to bfloat16. NOTE(review): if the checkpoint is
# float32 this lowers weight precision for generation -- confirm it is wanted.
params_bf16 = jax.tree_util.tree_map(lambda leaf: leaf.astype(jnp.bfloat16), params)
# Bundle the params with the optimizer and dropout key; the prediction loop
# below only reads state.params.
state = TrainState.create(
    apply_fn=model.__call__,
    params=params_bf16,
    tx=adamw,
    dropout_rng=dropout_rng,
)
# Generation settings. This rebinds the module-level `max_length` (the
# tokenizer padding length earlier) to the generation length.
# NOTE(review): with val_max_target_length = None this falls back to
# model.config.max_length (the transformers config default is 20 unless the
# checkpoint overrides it) -- confirm it is long enough for the targets.
max_length = (
    val_max_target_length if val_max_target_length is not None else model.config.max_length
)
num_beams = num_beams if num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}


def generate_step(params, batch):
    """Generate output token ids for one per-device batch.

    ``params`` is passed to ``model.generate`` explicitly rather than assigned
    to ``model.params``. Assigning inside the pmap-traced function would store
    tracers on the global ``model`` object.

    Returns:
        The ``sequences`` field of the generation output.
    """
    output_ids = model.generate(
        batch["input_ids"],
        attention_mask=batch["attention_mask"],
        params=params,
        **gen_kwargs,
    )
    return output_ids.sequences


# Data-parallel generation step; "batch" names the mapped device axis.
p_generate_step = jax.pmap(generate_step, "batch")
# Replicate the train state (params included) onto each device for pmap.
state = state.replicate()

pred_metrics = []  # NOTE(review): not filled in this cell; kept for compatibility.
pred_generations = []
pred_labels = []

# No shuffling is requested, so input_rng does not affect batch order.
rng, input_rng = jax.random.split(rng)
pred_loader = data_loader(input_rng, test_dataset, eval_batch_size, drop_last=False)
pred_steps = math.ceil(len(test_dataset) / eval_batch_size)

print("***** Running prediction *****")
print(f" Num examples = {len(test_dataset)}")
print(f" Num steps = {pred_steps}")
print(f" Instantaneous batch size per device = {per_device_eval_batch_size}")
print(f" Total prediction batch size (w. parallel & distributed) = {eval_batch_size}")

for batch in tqdm(pred_loader, total=pred_steps, desc="Predicting...", position=0, leave=False):
    labels = batch["labels"]
    if predict_with_generate:
        # pad_shard_unpad pads/shards the batch across devices, runs the pmapped
        # step, and strips the padding from the result, so a smaller final batch works.
        generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
        pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
    pred_labels.extend(labels)
# %%
# save predictions to parquet
# Decode token-id sequences back to text.
def decode_preds(preds):
    """Decode a batch of token-id sequences into strings.

    If ``preds`` is a tuple (e.g. a model output carrying extra items), only its
    first element is decoded.

    NOTE(review): skip_special_tokens=True also drops the added field markers
    (<THING_START>, <PROPERTY_END>, ...), so thing and property are no longer
    delimited in the text -- confirm downstream parsing expects this.

    Returns:
        The list of decoded strings from ``tokenizer.batch_decode``.
    """
    if isinstance(preds, tuple):
        preds = preds[0]
    return tokenizer.batch_decode(preds, skip_special_tokens=True)
# Decode the *generated* sequences (previously the labels were exported here,
# so the file contained no predictions) and save them as Parquet.
import os

os.makedirs("exports", exist_ok=True)  # make sure the output directory exists
df = pd.DataFrame(decode_preds(pred_generations))
df.to_parquet("exports/output_file.parquet", engine="pyarrow")  # or use engine="fastparquet"
# %%