# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.16.4
# ---

# %% [markdown]
# # Prediction code
# ## Import and process test data

# %%
# import libraries
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

import pandas as pd
import matplotlib.pyplot as plt

from datasets import Dataset, DatasetDict

import jax
import jax.numpy as jnp
import optax
import numpy as np
from functools import partial
from typing import Callable, Optional
import math

# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "bfloat16")
jax.config.update("jax_enable_x64", False)

# from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig

import datasets
import evaluate
from tqdm import tqdm

import nltk  # here to get a nice missing-dependency error message early on

from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key

from ml_collections import ConfigDict

import time

from parallel.dataload import DataPrepare
import orbax.checkpoint as ocp

# %%
# data_path = f"../make_data/select_db/data_mapping_filtered.csv"
# data_path = f"../make_data_2/select_db/dataset/1/train_all.csv"
model_name_or_path = "./model_checkpoints/simple"  # replace with your specific model checkpoint
data_path = '/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv'
# data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv'

# make sure 'ships_idx' is included in the fields list
fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']

# load the dataset
df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)


def process_df(df):
    # build one input/output string pair per row; the commented variants are
    # alternative input field combinations
    output_list = [{
        'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC>",
        # 'input': f"<DESC>{row['tag_description']}<DESC>",
        # 'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>",
        # 'input': f"<DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>",
        'output': f"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>",
        # 'answer': f"{row['thing']} {row['property']}",
        # 'answer_thing': row['thing'],
        # 'answer_property': row['property'],
    } for _, row in df.iterrows()]
    return output_list


# takes about 1 minute to run without batching
test_dataset = Dataset.from_list(process_df(df))
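
# %% [markdown]
# The `iterrows` loop above is the slow path. The cell below is a vectorized
# sketch of the same construction using a batched `datasets.map`; it should
# produce identical columns, but treat it as an optional alternative rather
# than part of the original pipeline. `build_io` is a helper introduced here
# for illustration only.

# %%
# sketch: build the same 'input'/'output' strings with a batched map instead
# of row-wise iterrows
def build_io(batch):
    inputs = [f"<NAME>{n}<NAME><DESC>{d}<DESC>"
              for n, d in zip(batch['tag_name'], batch['tag_description'])]
    outputs = [f"<THING_START>{t}<THING_END><PROPERTY_START>{p}<PROPERTY_END>"
               for t, p in zip(batch['thing'], batch['property'])]
    return {'input': inputs, 'output': outputs}

# uncomment to use the batched construction instead:
# test_dataset = Dataset.from_pandas(df).map(
#     build_io, batched=True, remove_columns=df.columns.tolist())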

# %%
# from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration
from transformers import FlaxT5ForConditionalGeneration
# model_name_or_path = "./t5_80_1"  # replace with your specific model name
model = FlaxT5ForConditionalGeneration.from_pretrained(model_name_or_path)
params = model.params
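
# %% [markdown]
# Optional: since the default matmul precision is already set to bfloat16
# above, the parameters themselves can also be cast to bf16 to reduce device
# memory during generation. This is a minimal sketch, not part of the original
# pipeline; re-check prediction accuracy before swapping it in for `params`.

# %%
# sketch: cast every parameter leaf to bfloat16. `params_bf16` is a
# hypothetical name used for illustration only.
params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)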

# %%
seed = 117
batch_size = 128
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(test_dataset) // train_batch_size

num_beams = 1
val_max_target_length = 128

predict_with_generate = True

# initialize the PRNG used for prediction
rng = jax.random.PRNGKey(seed)
# rng, dropout_rng = jax.random.split(rng)

print("preparing data")
data_config = ConfigDict(
    dict(
        max_length=128,
        pad_token_id=0,
        decoder_start_token_id=0,
    )
)

dataprep = DataPrepare(test_dataset, data_config)

# %%
# replicate the model parameters over all devices for pmap
replicated_params = jax.device_put_replicated(params, jax.devices())


# define the generation function
max_length = (
    val_max_target_length if val_max_target_length is not None else model.config.max_length
)
num_beams = num_beams if num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}


def generate_step(params, batch):
    output_ids = model.generate(
        batch["input_ids"],
        attention_mask=batch["attention_mask"],
        params=params,
        **gen_kwargs)
    return output_ids.sequences


# create a parallel version of the generate step
p_generate_step = jax.pmap(generate_step, "batch")

pred_generations = []
pred_labels = []
decoder_input_list = []

rng, input_rng = jax.random.split(rng)

pred_loader = dataprep.data_loader(input_rng, batch_size=batch_size, shuffle=False, drop_last=False)
pred_steps = math.ceil(len(test_dataset) / eval_batch_size)

print("***** Running prediction *****")
print(f" Num examples = {len(test_dataset)}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total test batch size (w. parallel & distributed) = {train_batch_size}")

for _ in tqdm(range(pred_steps), desc="Predicting..."):
    batch = next(pred_loader)
    labels = batch["labels"]
    decoder_input = batch["decoder_input_ids"]

    # generation: pad_shard_unpad lets us call the pmap'ed function with
    # inputs whose batch size is not divisible by the number of devices
    generated_ids = pad_shard_unpad(p_generate_step)(replicated_params, batch)
    pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
    pred_labels.extend(labels)
    decoder_input_list.extend(decoder_input)
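
# %% [markdown]
# The `pad_shard_unpad` call above hides the padding arithmetic: the ragged
# final batch is padded up to a multiple of the device count, sharded across
# devices, processed, and the padding rows are stripped from the result. The
# cell below is a standalone toy sketch of that behaviour (not part of the
# pipeline); `toy_fn` mirrors the real call in that argument 0 is static.

# %%
# sketch: a pmap'ed toy function applied to 10 rows, a count that is
# generally not divisible by the device count; pad_shard_unpad handles
# the remainder for us
toy_fn = jax.pmap(lambda scale, x: x * scale, axis_name="batch",
                  static_broadcasted_argnums=(0,))
toy_out = pad_shard_unpad(toy_fn)(2, np.arange(10).reshape(10, 1))
print(toy_out.shape)  # leading dimension restored to the original 10 rows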

# %% [markdown]
# # Process predictions

# %%
from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=False)
# define additional special tokens
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "<SIG>", "<UNIT>", "<DATA_TYPE>"]
# add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
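
# %% [markdown]
# The decoding cells below hard-code IDs 32100-32103 for the thing/property
# markers. As a sanity check (a sketch, assuming the same `t5-base` base
# tokenizer and the same token order as at training time), verify that the
# freshly added specials landed on exactly those IDs:

# %%
# sanity check: added special tokens are appended after the 32100 tokens of
# the base t5-base vocabulary (32000 sentencepiece tokens + 100 extra_ids)
for tok, expected in [("<THING_START>", 32100), ("<THING_END>", 32101),
                      ("<PROPERTY_START>", 32102), ("<PROPERTY_END>", 32103)]:
    assert tokenizer.convert_tokens_to_ids(tok) == expected, tok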

# %%
# code to get special token ids
sentence = "<THING_START><THING_END><PROPERTY_START><PROPERTY_END><NAME><DESC><SIG><UNIT><DATA_TYPE>"
tokens = tokenizer.tokenize(sentence)
print("Tokens:", tokens)
# get the IDs (integer indices) of specific tokens
token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens]
print("Token IDs:", token_ids)

# %%
# extract the token subsequence between a start marker and an end marker
def extract_seq(tokens, start_value, end_value):
    if start_value not in tokens or end_value not in tokens:
        return None  # or handle this case according to your requirements
    start_id = np.where(tokens == start_value)[0][0]
    end_id = np.where(tokens == end_value)[0][0]
    return tokens[start_id + 1:end_id]
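
# %% [markdown]
# A deterministic toy illustration of `extract_seq` (a sketch on synthetic
# IDs): the returned slice excludes the marker tokens themselves, and missing
# markers yield `None`.

# %%
toy_ids = np.array([32100, 5, 7, 32101, 32102, 9, 32103])
print(extract_seq(toy_ids, 32100, 32101))  # -> [5 7]
print(extract_seq(toy_ids, 32104, 32105))  # -> None (markers absent)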

# %%
# example: inspect one prediction and its extracted <THING> span
i = 2
print(pred_generations[i])
print(extract_seq(pred_generations[i], 32100, 32101))

# %%
def process_tensor_output(tokens):
    thing_seq = extract_seq(tokens, 32100, 32101)  # 32100 = <THING_START>, 32101 = <THING_END>
    property_seq = extract_seq(tokens, 32102, 32103)  # 32102 = <PROPERTY_START>, 32103 = <PROPERTY_END>
    p_thing = None
    p_property = None
    if thing_seq is not None:
        p_thing = tokenizer.decode(thing_seq, skip_special_tokens=False)  # retain <COLLIDE>
    if property_seq is not None:
        p_property = tokenizer.decode(property_seq, skip_special_tokens=False)  # retain <COLLIDE>
    return p_thing, p_property

# %%
# decode token sequences into thing/property strings
def decode_preds(tokens_list):
    thing_prediction_list = []
    property_prediction_list = []
    for tokens in tokens_list:
        p_thing, p_property = process_tensor_output(tokens)
        thing_prediction_list.append(p_thing)
        property_prediction_list.append(p_property)
    return thing_prediction_list, property_prediction_list


thing_prediction_list, property_prediction_list = decode_preds(pred_generations)

# %%
# decode the ground-truth labels as well
thing_actual_list, property_actual_list = decode_preds(pred_labels)

# convert the lists to a pandas DataFrame
# note: this reuses the name `df`, replacing the raw input frame loaded earlier
df = pd.DataFrame({'p_thing': thing_prediction_list,
                   'p_property': property_prediction_list,
                   'thing': thing_actual_list,
                   'property': property_actual_list})

df['p_thing_correct'] = df['p_thing'] == df['thing']
df['p_property_correct'] = df['p_property'] == df['property']

# %%
print("thing accuracy", sum(df['p_thing_correct']) / len(df))
print("property accuracy", sum(df['p_property_correct']) / len(df))
print("total accuracy", sum(df['p_property_correct'] & df['p_thing_correct']) / len(df))
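
# %% [markdown]
# Since the loader ran with `shuffle=False` and `drop_last=False`, the
# prediction rows should still be in CSV order, so per-ship accuracy can be
# recovered by re-reading `ships_idx`. This is a hedged sketch, not part of
# the original analysis; the length check guards the row-alignment assumption.

# %%
ships = pd.read_csv(data_path, skipinitialspace=True, usecols=['ships_idx'])
if len(ships) == len(df):
    per_ship = (df.assign(ships_idx=ships['ships_idx'].values)
                  .groupby('ships_idx')[['p_thing_correct', 'p_property_correct']]
                  .mean())
    print(per_ship.head())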

# %%
# df[~df["p_property_correct"]]

# %%
# df['p_thing']

# %%
# save the DataFrame as a Parquet file (using pyarrow or fastparquet)
# df.to_parquet("exports/output_file.parquet", engine="pyarrow")  # or use engine="fastparquet"