# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.16.4
# ---
# %% [markdown]
# # prediction code
# ## import and process test data
# %%
# import libraries
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
import pandas as pd
import matplotlib.pyplot as plt
from datasets import Dataset, DatasetDict
import jax
import jax.numpy as jnp
import optax
import numpy as np
from functools import partial
from typing import Callable, Optional
import math
# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "bfloat16")
jax.config.update("jax_enable_x64", False)
# from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
import datasets
from datasets import Dataset
import evaluate
from tqdm import tqdm
import nltk # Here to have a nice missing dependency error message early on
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from ml_collections import ConfigDict
import time
from parallel.dataload import DataPrepare
import orbax.checkpoint as ocp
# %%
# data_path = f"../make_data/select_db/data_mapping_filtered.csv"
# data_path = f"../make_data_2/select_db/dataset/1/train_all.csv"
data_path = '/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv'
# data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv'
# Make sure to include 'ships_idx' in the fields list
fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']
# Load the dataset
df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)
def process_df(df):
    output_list = [{
        'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC>",
        # 'input': f"<DESC>{row['tag_description']}<DESC>",
        # 'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>",
        # 'input': f"<DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>",
        'output': f"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>",
        # 'answer': f"{row['thing']} {row['property']}",
        # 'answer_thing': row['thing'],
        # 'answer_property': row['property'],
    } for _, row in df.iterrows()]
    return output_list
# takes 1 minute to run without batching
test_dataset = Dataset.from_list(process_df(df))
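# %%
# quick sanity check: print one processed record to confirm the marker format
# (index 0 is arbitrary)
print(test_dataset[0]['input'])
print(test_dataset[0]['output'])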
# %%
# from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration
from transformers import FlaxT5ForConditionalGeneration
# model_name_or_path = "./t5_80_1" # Replace with your specific model name
model_name_or_path = "./model_checkpoints/simple_test" # Replace with your specific model name
model = FlaxT5ForConditionalGeneration.from_pretrained(model_name_or_path)
params = model.params
# %%
seed = 117
batch_size = 128
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(test_dataset) // train_batch_size
num_beams = 1
val_max_target_length = 128
predict_with_generate = True
# Initialize our prediction
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)
print("preparing data")
data_config = ConfigDict(
    dict(
        max_length=128,
        pad_token_id=0,
        decoder_start_token_id=0,
    )
)
dataprep = DataPrepare(test_dataset, data_config)
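# %%
# a minimal sketch of pulling one batch from DataPrepare to inspect shapes,
# assuming data_loader yields dicts of arrays keyed as used in the loop below
_peek_rng = jax.random.PRNGKey(0)
_peek = next(dataprep.data_loader(_peek_rng, batch_size=batch_size, shuffle=False, drop_last=False))
for k, v in _peek.items():
    print(k, v.shape)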
# %%
# re-seed the PRNG so generation is reproducible from this point
seed = 117
rng = jax.random.PRNGKey(seed)
# %%
# replicate the trained parameters across all local devices for pmap
replicated_params = jax.device_put_replicated(params, jax.devices())
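# %%
# sanity check: every leaf of the replicated tree gains a leading device axis,
# so its first dimension should equal jax.device_count()
_flat = traverse_util.flatten_dict(replicated_params)
_key, _leaf = next(iter(_flat.items()))
print(_key, _leaf.shape, "devices:", jax.device_count())
# %%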
# Define generation function
max_length = (
    val_max_target_length if val_max_target_length is not None else model.config.max_length
)
num_beams = num_beams if num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
def generate_step(params, batch):
    output_ids = model.generate(
        batch["input_ids"],
        attention_mask=batch["attention_mask"],
        params=params,
        **gen_kwargs,
    )
    return output_ids.sequences
# Create a parallel (pmapped) version of the generate step
p_generate_step = jax.pmap(generate_step, "batch")
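# %%
# a minimal sketch (not used below) of the shard -> pmap -> gather flow that
# `pad_shard_unpad` automates; it assumes the batch size divides evenly by
# jax.device_count(), whereas pad_shard_unpad also pads ragged batches and
# strips the padding from the output
def manual_generate(params_repl, batch):
    sharded = shard(batch)  # reshape each leaf to [num_devices, per_device, ...]
    ids = p_generate_step(params_repl, sharded)
    return jax.device_get(ids.reshape(-1, gen_kwargs["max_length"]))
# %%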
pred_generations = []
pred_labels = []
decoder_input_list = []
rng, input_rng = jax.random.split(rng)
pred_loader = dataprep.data_loader(input_rng, batch_size=batch_size, shuffle=False, drop_last=False)
pred_steps = math.ceil(len(test_dataset) / eval_batch_size)
print("***** Running training *****")
print(f" Num examples = {len(test_dataset)}")
# print(f" Num steps = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total test batch size (w. parallel & distributed) = {train_batch_size}")
for _ in tqdm(range(pred_steps), desc="Predicting..."):
    # Model forward
    batch = next(pred_loader)
    labels = batch["labels"]
    decoder_input = batch["decoder_input_ids"]
    # generation
    generated_ids = pad_shard_unpad(p_generate_step)(replicated_params, batch)
    pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
    pred_labels.extend(labels)
    decoder_input_list.extend(decoder_input)
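# %%
# sanity check: pad_shard_unpad should yield exactly one generation per example,
# even though drop_last=False leaves the final batch ragged
print(len(pred_generations), len(pred_labels), len(test_dataset))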
# %% [markdown]
# # process predictions
# %%
from transformers import T5TokenizerFast
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=False)
# Define additional special tokens
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "<SIG>", "<UNIT>", "<DATA_TYPE>"]
# Add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
# %%
# code to get special token ids (each marker appears once, in the order added above)
sentence = "<THING_START><THING_END><PROPERTY_START><PROPERTY_END><NAME><DESC><SIG><UNIT><DATA_TYPE>"
tokens = tokenizer.tokenize(sentence)
print("Tokens:", tokens)
# Get the IDs (integer indices) of specific tokens
token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens]
print("Token IDs:", token_ids)
# %%
# extract sequence and decode
def extract_seq(tokens, start_value, end_value):
    if start_value not in tokens or end_value not in tokens:
        return None  # or handle this case according to your requirements
    start_id = np.where(tokens == start_value)[0][0]
    end_id = np.where(tokens == end_value)[0][0]
    return tokens[start_id + 1:end_id]
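# %%
# toy self-test of extract_seq, using the marker IDs printed above
_toy = np.array([0, 32100, 5, 7, 32101, 1])
print(extract_seq(_toy, 32100, 32101))  # expected: [5 7]
print(extract_seq(_toy, 32102, 32103))  # expected: None (markers absent)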
# %%
# spot-check one generated sequence and its extracted <THING> span
i = 2
print(pred_generations[i])
print(extract_seq(pred_generations[i], 32100, 32101))
# %%
def process_tensor_output(tokens):
    thing_seq = extract_seq(tokens, 32100, 32101)  # 32100 = <THING_START>, 32101 = <THING_END>
    property_seq = extract_seq(tokens, 32102, 32103)  # 32102 = <PROPERTY_START>, 32103 = <PROPERTY_END>
    p_thing = None
    p_property = None
    if thing_seq is not None:
        p_thing = tokenizer.decode(thing_seq, skip_special_tokens=False)  # retain <COLLIDE>
    if property_seq is not None:
        p_property = tokenizer.decode(property_seq, skip_special_tokens=False)  # retain <COLLIDE>
    return p_thing, p_property
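# %%
# round-trip check: wrap a known phrase in the marker IDs and decode it back
# (any short phrase works; 'temperature' is an arbitrary choice)
_ids = list(tokenizer("temperature", add_special_tokens=False).input_ids)
_toy = np.array([32100] + _ids + [32101, 32102] + _ids + [32103, 1])
print(process_tensor_output(_toy))  # expected: ('temperature', 'temperature')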
# %%
# decode prediction labels
def decode_preds(tokens_list):
    thing_prediction_list = []
    property_prediction_list = []
    for tokens in tokens_list:
        p_thing, p_property = process_tensor_output(tokens)
        thing_prediction_list.append(p_thing)
        property_prediction_list.append(p_property)
    return thing_prediction_list, property_prediction_list
thing_prediction_list, property_prediction_list = decode_preds(pred_generations)
# %%
# add labels too
thing_actual_list, property_actual_list = decode_preds(pred_labels)
# Convert the lists to a Pandas DataFrame
df = pd.DataFrame({'p_thing': thing_prediction_list,
                   'p_property': property_prediction_list,
                   'thing': thing_actual_list,
                   'property': property_actual_list})
df['p_thing_correct'] = df['p_thing'] == df['thing']
df['p_property_correct'] = df['p_property'] == df['property']
# %%
print("thing accuracy", sum(df['p_thing_correct'])/len(df))
print("property accuracy", sum(df['p_property_correct'])/len(df))
print("total accuracy", sum(df['p_property_correct'] & df['p_thing_correct'])/len(df))
# %%
# inspect the rows where the property prediction is wrong
df[~df["p_property_correct"]]
# %%
# eyeball the raw thing predictions
df['p_thing']
# %%
# Save the DataFrame as a Parquet file (using pyarrow or fastparquet)
# df.to_parquet("exports/output_file.parquet", engine="pyarrow") # or use engine="fastparquet"