# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.16.4
# ---

# %% [markdown]
# # prediction code
# ## import and process test data

# %%
# import libraries
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

import math
import time
from functools import partial
from typing import Callable, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import optax

# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "bfloat16")
jax.config.update("jax_enable_x64", False)

# from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
import datasets
from datasets import Dataset, DatasetDict
import evaluate
import nltk  # imported here to get a clear missing-dependency error early on
from tqdm import tqdm

from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from ml_collections import ConfigDict
import orbax.checkpoint as ocp

from parallel.dataload import DataPrepare

# %%
model_name_or_path = "./model_checkpoints/simple"  # replace with your specific model name
# data_path = "../make_data/select_db/data_mapping_filtered.csv"
# data_path = "../make_data_2/select_db/dataset/1/train_all.csv"
data_path = '/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv'
# data_path = '/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv'

# make sure to include 'ships_idx' in the fields list
fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']

# load the test dataset
df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)


def process_df(df):
    # serialize each row into an input string (tag fields) and an output
    # string (thing/property wrapped in marker tokens); the <...> marker
    # names are assumptions, since the original angle-bracket tokens were
    # stripped during extraction
    output_list = [{
        'input': f"<NAME>{row['tag_name']}<DESC>{row['tag_description']}",
        # 'input': f"<DESC>{row['tag_description']}",
        # 'input': f"<NAME>{row['tag_name']}<DESC>{row['tag_description']}<UNIT>{row['unit']}",
        # 'input': f"<DESC>{row['tag_description']}<UNIT>{row['unit']}",
        'output': f"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>",
        # 'answer': f"{row['thing']} {row['property']}",
        # 'answer_thing': row['thing'],
        # 'answer_property': row['property'],
    } for _, row in df.iterrows()]
    return output_list


# takes about a minute to run without batching
test_dataset = Dataset.from_list(process_df(df))

# %%
# from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration
from transformers import FlaxT5ForConditionalGeneration

# model_name_or_path = "./t5_80_1"  # replace with your specific model name
model = FlaxT5ForConditionalGeneration.from_pretrained(model_name_or_path)
params = model.params

# %%
seed = 117
batch_size = 128
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(test_dataset) // eval_batch_size
num_beams = 1
val_max_target_length = 128
predict_with_generate = True

# initialize our prediction rng
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)

print("preparing data")
data_config = ConfigDict(
    dict(
        max_length=128,
        pad_token_id=0,            # T5 pad token id
        decoder_start_token_id=0,  # T5 uses the pad token as decoder start
    )
)
dataprep = DataPrepare(test_dataset, data_config)
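# %% [markdown]
# A quick sanity check (illustrative, not required for the run): print one
# processed example and peek at one batch from the loader to confirm the
# serialized formatting and tensor shapes before generation. The
# `data_loader` call mirrors the one used for prediction further down.

# %%
print(test_dataset[0]['input'])
print(test_dataset[0]['output'])
peek_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=False, drop_last=False)
peek_batch = next(peek_loader)
print({k: v.shape for k, v in peek_batch.items()})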
# %%
# params were loaded from the checkpoint above; replicate them so that each
# device holds a copy for the pmapped generation step
replicated_params = jax.device_put_replicated(params, jax.devices())

# define the generation settings
max_length = (
    val_max_target_length if val_max_target_length is not None else model.config.max_length
)
num_beams = num_beams if num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}


def generate_step(params, batch):
    output_ids = model.generate(
        batch["input_ids"],
        attention_mask=batch["attention_mask"],
        params=params,
        **gen_kwargs,
    )
    return output_ids.sequences


# create a parallel version of the generation step
p_generate_step = jax.pmap(generate_step, "batch")

pred_generations = []
pred_labels = []
decoder_input_list = []

rng, input_rng = jax.random.split(rng)
pred_loader = dataprep.data_loader(input_rng, batch_size=batch_size, shuffle=False, drop_last=False)
pred_steps = math.ceil(len(test_dataset) / eval_batch_size)

print("***** Running prediction *****")
print(f"  Num examples = {len(test_dataset)}")
print(f"  Instantaneous batch size per device = {per_device_eval_batch_size}")
print(f"  Total test batch size (w. parallel & distributed) = {eval_batch_size}")

for _ in tqdm(range(pred_steps), desc="Predicting..."):
    batch = next(pred_loader)
    labels = batch["labels"]
    decoder_input = batch["decoder_input_ids"]

    # generation: pad_shard_unpad pads the last partial batch, shards it
    # across devices, and unpads the generated sequences afterwards
    generated_ids = pad_shard_unpad(p_generate_step)(replicated_params, batch)
    pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
    pred_labels.extend(labels)
    decoder_input_list.extend(decoder_input)

# %% [markdown]
# # process predictions

# %%
from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained(
    "t5-base", return_tensors="np", clean_up_tokenization_spaces=False
)
# define additional special tokens; the marker names are assumptions (the
# original angle-bracket tokens were stripped during extraction), but the
# first four must map to ids 32100-32103 for the extraction code below
additional_special_tokens = [
    "<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>",
    "<NAME>", "<DESC>", "<UNIT>",
]
# add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})

# %%
# code to get special token ids
sentence = "<THING_START><THING_END><PROPERTY_START><PROPERTY_END>"
tokens = tokenizer.tokenize(sentence)
print("Tokens:", tokens)

# get the ids (integer indices) of the special tokens
token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens]
print("Token IDs:", token_ids)

# %%
# extract the id subsequence between a start marker and an end marker
def extract_seq(tokens, start_value, end_value):
    if start_value not in tokens or end_value not in tokens:
        return None  # or handle this case according to your requirements
    start_id = np.where(tokens == start_value)[0][0]
    end_id = np.where(tokens == end_value)[0][0]
    return tokens[start_id + 1:end_id]


# %%
i = 2
print(pred_generations[i])
print(extract_seq(pred_generations[i], 32100, 32101))


# %%
def process_tensor_output(tokens):
    thing_seq = extract_seq(tokens, 32100, 32101)     # 32100 = <THING_START>, 32101 = <THING_END>
    property_seq = extract_seq(tokens, 32102, 32103)  # 32102 = <PROPERTY_START>, 32103 = <PROPERTY_END>
    p_thing = None
    p_property = None
    if thing_seq is not None:
        p_thing = tokenizer.decode(thing_seq, skip_special_tokens=False)  # retain marker tokens inside the span
    if property_seq is not None:
        p_property = tokenizer.decode(property_seq, skip_special_tokens=False)  # retain marker tokens inside the span
    return p_thing, p_property
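# %% [markdown]
# A minimal sanity check of the extraction logic on a hand-built id sequence
# (toy values; assumes the marker ids 32100-32103 assigned when the special
# tokens were added above).

# %%
toy = np.array([0, 32100, 5, 6, 32101, 32102, 7, 32103, 1])
print(extract_seq(toy, 32100, 32101))  # expected: [5 6]
print(process_tensor_output(toy))      # decodes both marked spans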
# %%
# decode prediction labels
def decode_preds(tokens_list):
    thing_prediction_list = []
    property_prediction_list = []
    for tokens in tokens_list:
        p_thing, p_property = process_tensor_output(tokens)
        thing_prediction_list.append(p_thing)
        property_prediction_list.append(p_property)
    return thing_prediction_list, property_prediction_list


thing_prediction_list, property_prediction_list = decode_preds(pred_generations)

# %%
# decode the labels too, then collect everything in a pandas DataFrame
thing_actual_list, property_actual_list = decode_preds(pred_labels)

df = pd.DataFrame({
    'p_thing': thing_prediction_list,
    'p_property': property_prediction_list,
    'thing': thing_actual_list,
    'property': property_actual_list,
})
df['p_thing_correct'] = df['p_thing'] == df['thing']
df['p_property_correct'] = df['p_property'] == df['property']

# %%
print("thing accuracy", sum(df['p_thing_correct']) / len(df))
print("property accuracy", sum(df['p_property_correct']) / len(df))
print("total accuracy", sum(df['p_property_correct'] & df['p_thing_correct']) / len(df))

# %%
# df[~df["p_property_correct"]]

# %%
# df['p_thing']

# %%
# save the DataFrame as a Parquet file (using pyarrow or fastparquet)
# df.to_parquet("exports/output_file.parquet", engine="pyarrow")  # or engine="fastparquet"
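# %% [markdown]
# A sketch for error analysis before export (assumes the prediction rows are
# positionally aligned with the test CSV, which should hold since the loader
# ran with shuffle=False and drop_last=False): re-read the source fields,
# join the predictions back on, and inspect the mismatches.

# %%
df_test = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)
df_export = pd.concat([df_test.reset_index(drop=True), df.reset_index(drop=True)], axis=1)
print(df_export[~df_export['p_thing_correct']].head())
# df_export.to_parquet("exports/output_file.parquet", engine="pyarrow")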