# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.16.4
# ---

# %% [markdown]
# # prediction code
# ## import and process test data


# %%
# import libraries
import math
import time
from functools import partial
from typing import Callable, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

import datasets
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
import evaluate
import nltk  # Here to have a nice missing dependency error message early on

import jax
import jax.numpy as jnp
import optax

from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key

from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig

# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "high")
jax.config.update("jax_enable_x64", False)

# %%
# data_path = "../make_data/select_db/data_mapping_filtered.csv"
# data_path = "../make_data_2/select_db/dataset/1/train_all.csv"
data_path = '/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv'
# data_path = '/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv'

# Ensure to include 'ships_idx' in the fields list
fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']

# Load the dataset
df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)


def process_df(df):
    output_list = [{
        'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC>",
        # 'input': f"<DESC>{row['tag_description']}<DESC>",
        # 'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>",
        # 'input': f"<DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>",
        'output': f"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>",
        'answer': f"{row['thing']} {row['property']}",
        'answer_thing': row['thing'],
        'answer_property': row['property'],
    } for _, row in df.iterrows()]

    return output_list


# takes 1 minute to run without batching
test_dataset = Dataset.from_list(process_df(df))
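
# %%
# Quick sanity check (a minimal sketch): the printed fields simply mirror the
# dict keys built in process_df above.
print(test_dataset)
print(test_dataset[0]['input'])
print(test_dataset[0]['output'])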

# %% [markdown]
# ## Load model for attributes

# %%
# load model
model_name_or_path = "t5_80_1"  # Replace with your specific model name

# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path)

# Load model
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
    model_name_or_path
)

# %% [markdown]
# ## Tokenizer

# %%
# prepare tokenizer
from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
# Define additional special tokens; these must match the tokens the model was
# trained with, or the embeddings will not line up
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
# Add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})

max_length = 86

model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")


# given a dataset entry, run it through the tokenizer
# Setting padding="max_length" as we need fixed length inputs for jitted functions
def preprocess_function(example):
    input = example['input']
    target = example['output']
    # tokenize the inputs
    model_inputs = tokenizer(
        input,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np"
    )
    # tokenize the targets; text_target routes the text through the
    # tokenizer's target-tokenization mode
    labels = tokenizer(
        text_target=target,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np"
    )

    model_inputs["labels"] = labels["input_ids"]
    decoder_input_ids = shift_tokens_right_fn(
        labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
    )
    model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)

    # We need decoder_attention_mask so we can ignore pad tokens from loss
    model_inputs["decoder_attention_mask"] = labels["attention_mask"]

    return model_inputs


# map applies the function to each batch of rows in the dataset
test_dataset = test_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=test_dataset.column_names,
)
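
# %%
# Minimal round-trip check (a sketch): decode one preprocessed row back to
# text to verify the tokenization; special tokens are kept to show the markup.
_row = test_dataset[0]
print(tokenizer.decode(_row['input_ids'], skip_special_tokens=False))
print(tokenizer.decode(_row['labels'], skip_special_tokens=False))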


def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
    """
    Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete
    and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
    """
    if shuffle:
        batch_idx = jax.random.permutation(rng, len(dataset))
        batch_idx = np.asarray(batch_idx)
    else:
        batch_idx = np.arange(len(dataset))

    if drop_last:
        steps_per_epoch = len(dataset) // batch_size
        batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
        batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    else:
        steps_per_epoch = math.ceil(len(dataset) / batch_size)
        batch_idx = np.array_split(batch_idx, steps_per_epoch)

    for idx in batch_idx:
        batch = dataset[idx]
        batch = {k: np.array(v) for k, v in batch.items()}

        yield batch
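
# %%
# Usage sketch for data_loader (a peek, not part of the pipeline): pull one
# small batch and inspect the array shapes it yields.
_peek = next(data_loader(jax.random.PRNGKey(0), test_dataset, 4, shuffle=False, drop_last=False))
print({k: v.shape for k, v in _peek.items()})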

# %% [markdown]
# # Model Inference

# %%
seed = 117
num_epochs = 80
batch_size = 96
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(test_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs

warmup_steps = 0
learning_rate = 5e-5

weight_decay = 0.0
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0

num_beams = 1
val_max_target_length = None

predict_with_generate = True


# Initialize our RNG (used below to drive the data loader)
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)


def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear decay learning rate function."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn


# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
    len(test_dataset),
    train_batch_size,
    num_train_epochs,
    warmup_steps,
    learning_rate,
)
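
# %%
# Minimal sketch: sample the schedule at a few step counts to confirm the
# linear decay from learning_rate down to 0 over the computed steps.
for _step in [0, total_train_steps // 2, total_train_steps]:
    print(_step, linear_decay_lr_schedule_fn(_step))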

# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    # find out all LayerNorm parameters
    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
    layer_norm_named_params = {
        layer[-2:]
        for layer_norm_name in layer_norm_candidates
        for layer in flat_params.keys()
        if layer_norm_name in "".join(layer).lower()
    }
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)


# create adam optimizer
adamw = optax.adamw(
    learning_rate=linear_decay_lr_schedule_fn,
    b1=adam_beta1,
    b2=adam_beta2,
    eps=adam_epsilon,
    weight_decay=weight_decay,
    mask=decay_mask_fn,
)
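
# %%
# Minimal sketch with a hypothetical toy parameter tree (names invented for
# illustration): kernels should be decayed (True), biases and LayerNorm
# scales should not (False).
_toy_params = {
    "encoder": {
        "dense": {"kernel": np.zeros((2, 2)), "bias": np.zeros(2)},
        "layer_norm": {"scale": np.ones(2)},
    }
}
print(decay_mask_fn(_toy_params))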

# %%
# reload model to prevent leakage of variables
# load model
model_name_or_path = "t5_80_1"  # Replace with your specific model name

# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path)

# Load model
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
    model_name_or_path
)


# Train state wrapper; used here only so the parameters (and a dropout rng)
# can be replicated across devices for generation
class TrainState(train_state.TrainState):
    dropout_rng: jnp.ndarray

    def replicate(self):
        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))


# Parameters come pre-initialized from the from_pretrained call above
params = model.params
# Cast parameters to bfloat16 if desired
params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)

# Setup train state
state = TrainState.create(apply_fn=model.__call__, params=params_bf16, tx=adamw, dropout_rng=dropout_rng)
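
# %%
# Minimal sketch: confirm the bfloat16 cast by inspecting one parameter leaf.
print(jax.tree_util.tree_leaves(state.params)[0].dtype)  # expect bfloat16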


# Define generation function
max_length = (
    val_max_target_length if val_max_target_length is not None else model.config.max_length
)
num_beams = num_beams if num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}


def generate_step(params, batch):
    model.params = params
    output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
    return output_ids.sequences


# Create parallel version of the generate step
p_generate_step = jax.pmap(generate_step, "batch")

# Replicate the train state on each device
state = state.replicate()


pred_metrics = []
pred_generations = []
pred_labels = []

rng, input_rng = jax.random.split(rng)

pred_loader = data_loader(input_rng, test_dataset, eval_batch_size, drop_last=False)
pred_steps = math.ceil(len(test_dataset) / eval_batch_size)

print("***** Running prediction *****")
print(f"  Num examples = {len(test_dataset)}")
print(f"  Num steps = {pred_steps}")
print(f"  Instantaneous batch size per device = {per_device_eval_batch_size}")
print(f"  Total test batch size (w. parallel & distributed) = {eval_batch_size}")


for _ in tqdm(range(pred_steps), desc="Predicting...", position=0, leave=False):
    # Model forward
    batch = next(pred_loader)
    labels = batch["labels"]

    # generation
    if predict_with_generate:
        generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
        pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
        pred_labels.extend(labels)


# Print metrics
# desc = f"Predict Loss: {pred_metrics['loss']}"
# print(desc)
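
# %%
# Rough accuracy sketch (an illustration, not part of the original export):
# exact string match between decoded generations and decoded labels.
_dec_preds = tokenizer.batch_decode(pred_generations, skip_special_tokens=True)
_dec_labels = tokenizer.batch_decode(pred_labels, skip_special_tokens=True)
_matches = sum(p.strip() == l.strip() for p, l in zip(_dec_preds, _dec_labels))
print(f"exact match: {_matches}/{len(_dec_labels)}")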

# %%
# save predictions to parquet
import os


# decode predicted token ids back to text
def decode_preds(preds):
    # In case the model returns more than the prediction logits
    if isinstance(preds, tuple):
        preds = preds[0]

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_preds = [pred.strip() for pred in decoded_preds]

    return decoded_preds


# Convert the decoded generations to a Pandas DataFrame
df = pd.DataFrame(decode_preds(pred_generations))

# Save the DataFrame as a Parquet file (using pyarrow or fastparquet)
os.makedirs("exports", exist_ok=True)  # ensure the output directory exists
df.to_parquet("exports/output_file.parquet", engine="pyarrow")  # or use engine="fastparquet"
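
# %%
# Minimal sketch: read the export back to verify the round trip.
preds_df = pd.read_parquet("exports/output_file.parquet")
print(preds_df.head())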


# %%