# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.16.4
# ---

# %% [markdown]
# # prediction code
# ## import and process test data

# %%
# import libraries
import pandas as pd
import matplotlib.pyplot as plt

from datasets import Dataset, DatasetDict

import jax
import jax.numpy as jnp
import optax
import numpy as np
from functools import partial
from typing import Callable, Optional
import math

# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "bfloat16")
jax.config.update("jax_enable_x64", False)

from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig

import datasets
import evaluate
from tqdm import tqdm

import nltk  # Here to have a nice missing dependency error message early on

from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key

import time

# %%
# data_path = f"../make_data/select_db/data_mapping_filtered.csv"
# data_path = f"../make_data_2/select_db/dataset/1/train_all.csv"
data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv'
# data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv'

# Be sure to include 'ships_idx' in the fields list
fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']

# Load the dataset
df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)

def process_df(df):
    output_list = [{
        'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC>",
        # 'input': f"<DESC>{row['tag_description']}<DESC>",
        # 'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>",
        # 'input': f"<DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>",
        'output': f"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>",
        # 'answer': f"{row['thing']} {row['property']}",
        # 'answer_thing': row['thing'],
        # 'answer_property': row['property'],
    } for _, row in df.iterrows()]
    return output_list


# takes 1 minute to run without batching
test_dataset = Dataset.from_list(process_df(df))
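
# quick sanity check (illustrative, not required): inspect one processed entry
# print(test_dataset[0])  # -> {'input': '<NAME>...<NAME><DESC>...<DESC>', 'output': '<THING_START>...<PROPERTY_END>'}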

# %% [markdown]
# ## Load model for attributes

# %%
# load model
model_name_or_path = "./t5_80_1_bf16"  # Replace with your specific model name

# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path)

# Load model
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
    pretrained_model_name_or_path=model_name_or_path
)

# %% [markdown]
# ## Tokenizer

# %%
# prepare tokenizer
from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)

# Define additional special tokens
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]

# Add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
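
# note: add_special_tokens appends these tokens to the vocabulary, so (assuming the
# stock 32100-token t5-base vocab) they receive ids 32100, 32101, ... in the order
# listed above; the hard-coded ids in process_tensor_output below rely on this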

max_length = 86

# fetch shift_tokens_right from the model's own module
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
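
# illustrative example of the shift (assuming pad_token_id = 0 and
# decoder_start_token_id = 0, as in the stock T5 config):
#   labels:            [734, 22, 1, 0]
#   decoder_input_ids: [0, 734, 22, 1]
# i.e. the decoder sees the target shifted right by one for teacher forcing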


# given a dataset entry, run it through the tokenizer
# Setting padding="max_length" as we need fixed length inputs for jitted functions
def preprocess_function(example):
    inputs = example['input']
    targets = example['output']
    # text_target sets the corresponding label to inputs;
    # there is no need to create a separate 'labels' field
    model_inputs = tokenizer(
        inputs,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np"
    )
    labels = tokenizer(
        text_target=targets,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np"
    )

    model_inputs["labels"] = labels["input_ids"]
    decoder_input_ids = shift_tokens_right_fn(
        labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
    )
    model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)

    # we need decoder_attention_mask so we can ignore pad tokens in the loss
    model_inputs["decoder_attention_mask"] = labels["attention_mask"]

    return model_inputs


# map applies preprocess_function to the dataset; with batched=True it
# receives a batch of rows at a time rather than a single row
test_dataset = test_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=test_dataset.column_names,
)


def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
    """
    Returns batches of size `batch_size` from `dataset`. If `drop_last` is `False`, the final
    batch may be incomplete and range in size from 1 to `batch_size`. Shuffles the example
    order first if `shuffle` is `True`.
    """
    if shuffle:
        batch_idx = jax.random.permutation(rng, len(dataset))
        batch_idx = np.asarray(batch_idx)
    else:
        batch_idx = np.arange(len(dataset))

    if drop_last:
        steps_per_epoch = len(dataset) // batch_size
        batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
        batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    else:
        steps_per_epoch = math.ceil(len(dataset) / batch_size)
        batch_idx = np.array_split(batch_idx, steps_per_epoch)

    for idx in batch_idx:
        batch = dataset[idx]
        batch = {k: np.array(v) for k, v in batch.items()}
        yield batch
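
# illustrative usage (not executed here): pull one batch from the loader
# loader = data_loader(rng, test_dataset, batch_size=8, shuffle=False)
# batch = next(loader)
# batch["input_ids"].shape  # -> (8, 86), i.e. (batch_size, max_length)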


# %% [markdown]
# # model generation

# %%
seed = 117
num_epochs = 80
batch_size = 96

num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(test_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs

num_beams = 1
val_max_target_length = 128

predict_with_generate = True

# Initialize the random number generators
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)

# %%

# reload model to prevent leakage of variables
# load model
model_name_or_path = "t5_80_1"  # Replace with your specific model name

# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path)

# Load model
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
    model_name_or_path
)

# model.params is initialized by the from_pretrained call above
params = model.params

# ensure full-size floats: cast every parameter up to float32
params_f32 = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
# we need to replicate the model parameters over devices
replicated_params = jax.device_put_replicated(params_f32, jax.devices())
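
# quick check (illustrative): every leaf of replicated_params now carries a
# leading device axis of length jax.device_count()
# jax.tree_util.tree_leaves(replicated_params)[0].shape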

# Define generation function
max_length = (
    val_max_target_length if val_max_target_length is not None else model.config.max_length
)
num_beams = num_beams if num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}


def generate_step(params, batch):
    output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], params=params, **gen_kwargs)
    return output_ids.sequences


# Create a parallel version of the generate step
p_generate_step = jax.pmap(generate_step, "batch")
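
# note: wrapping p_generate_step in pad_shard_unpad (used in the loop below)
# pads an incomplete final batch up to a shardable size, splits it across the
# devices, and strips the padding from the outputs again, which is what makes
# drop_last=False safe here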

pred_generations = []
pred_labels = []

rng, input_rng = jax.random.split(rng)

pred_loader = data_loader(input_rng, test_dataset, eval_batch_size, drop_last=False)
pred_steps = math.ceil(len(test_dataset) / eval_batch_size)

print("***** Running prediction *****")
print(f"  Num examples = {len(test_dataset)}")
print(f"  Num steps = {pred_steps}")
print(f"  Instantaneous batch size per device = {per_device_eval_batch_size}")
print(f"  Total test batch size (w. parallel & distributed) = {eval_batch_size}")

for _ in tqdm(range(pred_steps), desc="Predicting..."):
    # Model forward
    batch = next(pred_loader)
    labels = batch["labels"]

    # generation
    generated_ids = pad_shard_unpad(p_generate_step)(replicated_params, batch)
    pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
    pred_labels.extend(labels)


# %% [markdown]
# # process predictions

# %%
# code to get special token ids
# sentence = "<THING_START><THING_END><PROPERTY_START><PROPERTY_END><NAME><DESC><DESC><UNIT>"
# tokens = tokenizer.tokenize(sentence)
# print("Tokens:", tokens)
# # Get the IDs (integer indices) of specific tokens
# token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens]
# print("Token IDs:", token_ids)
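
# sanity check (illustrative): the hard-coded ids used below assume the special
# tokens were appended to the 32100-token t5-base vocab in the order they were added
assert tokenizer.convert_tokens_to_ids("<THING_START>") == 32100
assert tokenizer.convert_tokens_to_ids("<THING_END>") == 32101
assert tokenizer.convert_tokens_to_ids("<PROPERTY_START>") == 32102
assert tokenizer.convert_tokens_to_ids("<PROPERTY_END>") == 32103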

# %%
# extract sequence and decode
def extract_seq(tokens, start_value, end_value):
    if start_value not in tokens or end_value not in tokens:
        return None  # or handle this case according to your requirements
    start_id = np.where(tokens == start_value)[0][0]
    end_id = np.where(tokens == end_value)[0][0]
    return tokens[start_id + 1:end_id]


def process_tensor_output(tokens):
    thing_seq = extract_seq(tokens, 32100, 32101)  # 32100 = <THING_START>, 32101 = <THING_END>
    property_seq = extract_seq(tokens, 32102, 32103)  # 32102 = <PROPERTY_START>, 32103 = <PROPERTY_END>
    p_thing = None
    p_property = None
    if thing_seq is not None:
        p_thing = tokenizer.decode(thing_seq, skip_special_tokens=False)  # retain <COLLIDE>
    if property_seq is not None:
        p_property = tokenizer.decode(property_seq, skip_special_tokens=False)  # retain <COLLIDE>
    return p_thing, p_property

# %%
# decode prediction labels
def decode_preds(tokens_list):
    thing_prediction_list = []
    property_prediction_list = []
    for tokens in tokens_list:
        p_thing, p_property = process_tensor_output(tokens)
        thing_prediction_list.append(p_thing)
        property_prediction_list.append(p_property)
    return thing_prediction_list, property_prediction_list


thing_prediction_list, property_prediction_list = decode_preds(pred_generations)

# %%
# add labels too
thing_actual_list, property_actual_list = decode_preds(pred_labels)

# Convert the lists to a Pandas DataFrame
df = pd.DataFrame({'p_thing': thing_prediction_list,
                   'p_property': property_prediction_list,
                   'thing': thing_actual_list,
                   'property': property_actual_list})

df['p_thing_correct'] = df['p_thing'] == df['thing']
df['p_property_correct'] = df['p_property'] == df['property']

# %%
print("thing accuracy", sum(df['p_thing_correct']) / len(df))
print("property accuracy", sum(df['p_property_correct']) / len(df))
print("total accuracy", sum(df['p_property_correct'] & df['p_thing_correct']) / len(df))

# %%
# inspect the rows where the property prediction was wrong
df[~df["p_property_correct"]]

# %%
df['p_thing']

# %%
# Save the DataFrame as a Parquet file (using pyarrow or fastparquet)
# df.to_parquet("exports/output_file.parquet", engine="pyarrow")  # or use engine="fastparquet"