# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.16.4
# ---

# %% [markdown]
# # prediction code
# ## import and process test data

# %%
# import libraries
import pandas as pd
import matplotlib.pyplot as plt

from datasets import Dataset, DatasetDict

import jax
import jax.numpy as jnp
import optax
import numpy as np
from functools import partial
from typing import Callable, Optional
import math

# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "high")

jax.config.update("jax_enable_x64", False)

from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig

import datasets
from datasets import Dataset
import evaluate
from tqdm import tqdm

import nltk  # Here to have a nice missing dependency error message early on

from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key

import time

# %%
# Alternative data sources used previously:
# data_path = f"../make_data/select_db/data_mapping_filtered.csv"
# data_path = f"../make_data_2/select_db/dataset/1/train_all.csv"
# data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv'
data_path = '/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv'

# Ensure to include 'ships_idx' in the fields list
fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']

# Load the dataset
df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)


def process_df(df):
    """Turn each row of `df` into an ``{'input', 'output'}`` text pair.

    'input' is the concatenation of the row's `tag_name` and
    `tag_description`; 'output' is the concatenation of `thing` and
    `property`.  Values are formatted with f-strings, so missing values
    become the text "nan".

    Other input layouts were tried before (description only, and variants
    that also append `unit`); only the name + description layout is active.

    NOTE(review): the fields are joined with no separator or marker tokens
    in this file -- confirm that this matches what the model was trained on.
    """
    # Iterating the columns in lockstep yields the same per-row strings as
    # DataFrame.iterrows() without building a Series object for every row.
    columns = zip(df['tag_name'], df['tag_description'], df['thing'], df['property'])
    return [
        {
            'input': f"{tag_name}{tag_description}",
            'output': f"{thing}{prop}",
        }
        for tag_name, tag_description, thing, prop in columns
    ]


test_dataset = Dataset.from_list(process_df(df))


# %% [markdown]
# ## Load model for attributes

# %%
# load model
model_name_or_path = "./t5_80_1"  # Replace with your specific model name

# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path)

# Load model
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
    pretrained_model_name_or_path=model_name_or_path
)


# %% [markdown]
# ## Tokenizer

# %%
# prepare tokenizer
from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
# Define additional special tokens
# NOTE(review): the first six entries are empty strings in this file.  The
# decoding code further down relies on token ids 32100-32103 acting as
# start/end markers, so these entries were presumably marker tokens whose
# text has been lost -- confirm against the training script.
additional_special_tokens = ["", "", "", "", "", "", "SIG", "UNIT", "DATA_TYPE"]
# Add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})

# Fixed tokenised length for both inputs and targets.
max_length = 86

# Look up `shift_tokens_right` in the module that defines the loaded model class.
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")


# given a dataset entry, run it through the tokenizer
# Setting padding="max_length" as we need fixed length inputs for jitted functions
def preprocess_function(example):
    """Tokenise a batch of examples and attach the decoder-side arrays.

    `example` is a batch from `Dataset.map(..., batched=True)` with 'input'
    and 'output' columns.  Both are padded/truncated to `max_length`.

    Returns the tokenizer output for the inputs, extended with:
      * 'labels'                 -- token ids of the targets,
      * 'decoder_input_ids'      -- the labels passed through
                                    `shift_tokens_right_fn`,
      * 'decoder_attention_mask' -- attention mask of the targets.
    """
    inputs = example['input']
    targets = example['output']

    model_inputs = tokenizer(
        inputs,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )
    # `text_target` makes the tokenizer treat the text as target-side text.
    labels = tokenizer(
        text_target=targets,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )

    model_inputs["labels"] = labels["input_ids"]
    decoder_input_ids = shift_tokens_right_fn(
        labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
    )
    model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)

    # We need decoder_attention_mask so we can ignore pad tokens from loss
    model_inputs["decoder_attention_mask"] = labels["attention_mask"]

    return model_inputs


# map maps function to each "row" in the dataset
# aka the data in the immediate nesting
test_dataset = test_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=test_dataset.column_names,
)


def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last: bool = True):
    """Yield batches from `dataset` as dicts of numpy arrays.

    Args:
        rng: PRNG key, only used to permute the example order when
            `shuffle` is True.
        dataset: dataset to read from; indexed with an array of row indices.
        batch_size: number of examples per batch (upper bound when
            `drop_last` is False).
        shuffle: visit the examples in a random order.
        drop_last: if True, every batch has exactly `batch_size` examples
            and leftover examples are skipped.  If False, all examples are
            returned, split into ``ceil(len(dataset) / batch_size)`` batches
            of near-equal size (`np.array_split`), so batches may be smaller
            than `batch_size`.

    Yields:
        One ``{column: np.ndarray}`` dict per batch.  Nothing is yielded for
        an empty dataset.
    """
    num_examples = len(dataset)

    if shuffle:
        batch_idx = np.asarray(jax.random.permutation(rng, num_examples))
    else:
        batch_idx = np.arange(num_examples)

    if drop_last:
        steps_per_epoch = num_examples // batch_size
        batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
        batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    else:
        steps_per_epoch = math.ceil(num_examples / batch_size)
        if steps_per_epoch == 0:
            # np.array_split rejects a split into zero sections.
            return
        batch_idx = np.array_split(batch_idx, steps_per_epoch)

    for idx in batch_idx:
        batch = dataset[idx]
        yield {k: np.array(v) for k, v in batch.items()}


# %% [markdown]
# # model generation

# %%
seed = 117
num_epochs = 80
batch_size = 96
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(test_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs

num_beams = 1
val_max_target_length = 128

predict_with_generate = True


# Initialize the random number generators
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)


# %%
# reload model to prevent leakage of variables
# load model
model_name_or_path = "t5_80_1_bf16"  # Replace with your specific model name

# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path)

# Load model
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
    model_name_or_path
)

params = model.params
# ensure full size floats
params_f32 = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
# we need to replicate model over devices
replicated_params = jax.device_put_replicated(params_f32, jax.devices())


# Define generation function
# Kept separate from the tokenisation `max_length` above so that re-running
# the preprocessing cell after this one still pads to the original length.
gen_max_length = (
    val_max_target_length if val_max_target_length is not None else model.config.max_length
)
num_beams = num_beams if num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": gen_max_length, "num_beams": num_beams}


def generate_step(params, batch):
    """Generate output token ids for one per-device batch.

    Calls `model.generate` on the batch's 'input_ids' / 'attention_mask'
    with the given `params` and the module-level `gen_kwargs`, and returns
    the `.sequences` of the result.
    """
    output_ids = model.generate(
        batch["input_ids"],
        attention_mask=batch["attention_mask"],
        params=params,
        **gen_kwargs,
    )
    return output_ids.sequences


# Create parallel version of the generation step
p_generate_step = jax.pmap(generate_step, "batch")
# Wrap once, outside the loop; the wrapper itself does not depend on the batch.
padded_generate_step = pad_shard_unpad(p_generate_step)


pred_generations = []
pred_labels = []

rng, input_rng = jax.random.split(rng)

pred_loader = data_loader(input_rng, test_dataset, eval_batch_size, drop_last=False)
pred_steps = math.ceil(len(test_dataset) / eval_batch_size)

print("***** Running prediction *****")
print(f" Num examples = {len(test_dataset)}")
print(f" Num steps = {pred_steps}")
print(f" Instantaneous batch size per device = {per_device_eval_batch_size}")
print(f" Total test batch size (w. parallel & distributed) = {eval_batch_size}")


for batch in tqdm(pred_loader, total=pred_steps, desc="Predicting..."):
    labels = batch["labels"]

    # generation
    generated_ids = padded_generate_step(replicated_params, batch)
    pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
    pred_labels.extend(labels)


# %% [markdown]
# # process predictions

# %%
# code to get special token ids
# sentence = ""
# tokens = tokenizer.tokenize(sentence)
# print("Tokens:", tokens)
# # Get the IDs (integer indices) of specific tokens
# token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens]
# print("Token IDs:", token_ids)


# %%
# Token ids that delimit the "thing" and "property" spans in a sequence.
# NOTE(review): hard-coded; presumably the ids assigned to the added special
# tokens -- verify with the snippet in the cell above.
THING_START_ID = 32100
THING_END_ID = 32101
PROPERTY_START_ID = 32102
PROPERTY_END_ID = 32103


# extract sequence and decode
def extract_seq(tokens, start_value, end_value):
    """Return the tokens strictly between a start and an end marker.

    The span starts after the first occurrence of `start_value` and stops
    before the first occurrence of `end_value` that follows it.

    Args:
        tokens: 1-D sequence of token ids (numpy array or list).
        start_value: id of the start marker.
        end_value: id of the end marker.

    Returns:
        A numpy array with the enclosed token ids (possibly empty), or None
        when there is no start marker or no end marker after it.
    """
    tokens = np.asarray(tokens)

    start_hits = np.flatnonzero(tokens == start_value)
    if start_hits.size == 0:
        return None
    span_begin = start_hits[0] + 1

    # Only look for the end marker after the start marker; an end marker
    # that appears earlier must not close the span.
    end_hits = np.flatnonzero(tokens[span_begin:] == end_value)
    if end_hits.size == 0:
        return None
    span_end = span_begin + end_hits[0]

    return tokens[span_begin:span_end]


def process_tensor_output(tokens):
    """Decode the "thing" and "property" texts from one token sequence.

    Returns a ``(p_thing, p_property)`` tuple; an element is None when its
    marker pair is not found in `tokens`.  Special tokens inside a span are
    kept in the decoded text.
    """
    thing_seq = extract_seq(tokens, THING_START_ID, THING_END_ID)
    property_seq = extract_seq(tokens, PROPERTY_START_ID, PROPERTY_END_ID)

    p_thing = None
    p_property = None
    if thing_seq is not None:
        p_thing = tokenizer.decode(thing_seq, skip_special_tokens=False)
    if property_seq is not None:
        p_property = tokenizer.decode(property_seq, skip_special_tokens=False)
    return p_thing, p_property


# %%
# decode prediction labels
def decode_preds(tokens_list):
    """Decode every sequence in `tokens_list`.

    Returns two lists of equal length: the decoded "thing" texts and the
    decoded "property" texts (None where a span could not be extracted).
    """
    decoded = [process_tensor_output(tokens) for tokens in tokens_list]
    thing_prediction_list = [p_thing for p_thing, _ in decoded]
    property_prediction_list = [p_property for _, p_property in decoded]
    return thing_prediction_list, property_prediction_list


thing_prediction_list, property_prediction_list = decode_preds(pred_generations)

# %%
# add labels too
thing_actual_list, property_actual_list = decode_preds(pred_labels)

# Convert the lists to a Pandas DataFrame
df = pd.DataFrame({
    'p_thing': thing_prediction_list,
    'p_property': property_prediction_list,
    'thing': thing_actual_list,
    'property': property_actual_list,
})

df['p_thing_correct'] = df['p_thing'] == df['thing']
df['p_property_correct'] = df['p_property'] == df['property']

# %%
# The mean of a boolean column is the fraction of True values.
print("thing accuracy", df['p_thing_correct'].mean())
print("property accuracy", df['p_property_correct'].mean())
print("total accuracy", (df['p_property_correct'] & df['p_thing_correct']).mean())

# %%
df[~df["p_property_correct"]]

# %%
df['p_thing']

# %%
# Save the DataFrame as a Parquet file (using pyarrow or fastparquet)
# df.to_parquet("exports/output_file.parquet", engine="pyarrow") # or use engine="fastparquet"