From edd9c3551f9c8781922bb751bcfa52d00a20f880 Mon Sep 17 00:00:00 2001 From: Richard Wong Date: Thu, 12 Sep 2024 22:57:19 +0900 Subject: [PATCH] Feat: implement working prediction --- .gitignore | 1 + check_time.py | 27 ++++ t5_jax.py | 293 +++++++++++-------------------------------- t5_jax_prediction.py | 201 +++++++++++++---------------- 4 files changed, 189 insertions(+), 333 deletions(-) create mode 100644 check_time.py diff --git a/.gitignore b/.gitignore index bd49e38..d20f0d4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ *.ipynb t5_*/ exports/ +modified_t5_model/ diff --git a/check_time.py b/check_time.py new file mode 100644 index 0000000..4e8c074 --- /dev/null +++ b/check_time.py @@ -0,0 +1,27 @@ +from pathlib import Path + +# Define the folder to check +folder = Path(".") + +# Get all .py and .ipynb files in the folder +py_files = {file.stem: file for file in folder.glob("*.py")} +ipynb_files = {file.stem: file for file in folder.glob("*.ipynb")} + +# Check for linked .py and .ipynb files +all_newer = True + +for stem, py_file in py_files.items(): + if stem in ipynb_files: + ipynb_file = ipynb_files[stem] + + # Compare the modification times + if py_file.stat().st_mtime > ipynb_file.stat().st_mtime: + print(f"{py_file} is newer than {ipynb_file}.") + else: + print(f"{py_file} is not newer than {ipynb_file}.") + all_newer = False + +if all_newer: + print("All linked .py files are newer than their corresponding .ipynb files.") +else: + print("Some .py files are not newer than their corresponding .ipynb files.") \ No newline at end of file diff --git a/t5_jax.py b/t5_jax.py index 240b005..952d1a3 100644 --- a/t5_jax.py +++ b/t5_jax.py @@ -16,67 +16,6 @@ # %% [markdown] # # T5 implementation using jax -# %% [markdown] -# ## import - -# %% [raw] -# import json -# import logging -# import math -# import os -# import sys -# import time -# from dataclasses import asdict, dataclass, field -# from enum import Enum -# from functools import partial -# from pathlib import Path -# from typing import Callable, Optional -# -# import datasets -# import evaluate -# import jax -# import jax.numpy as jnp -# import nltk # Here to have a nice missing dependency error message early on -# import numpy as np -# import optax -# from datasets import Dataset, load_dataset -# from filelock import FileLock -# from flax import jax_utils, traverse_util -# from flax.jax_utils import pad_shard_unpad, unreplicate -# from flax.training import train_state -# from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key -# from tqdm import tqdm -# -# import transformers -# from transformers import ( -# CONFIG_MAPPING, -# FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, -# AutoConfig, -# AutoTokenizer, -# FlaxAutoModelForSeq2SeqLM, -# HfArgumentParser, -# is_tensorboard_available, -# ) -# from transformers.utils import is_offline_mode, send_example_telemetry -# -# -# logger = logging.getLogger(__name__) -# -# try: -# nltk.data.find("tokenizers/punkt") -# except (LookupError, OSError): -# if is_offline_mode(): -# raise LookupError( -# "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" -# ) -# with FileLock(".lock") as lock: -# nltk.download("punkt", quiet=True) -# -# -# MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys()) -# MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) - - # %% import jax import jax.numpy as jnp @@ -88,8 +27,11 @@ import math # jax.config.update("jax_default_matmul_precision", 
"tensorfloat32") jax.config.update("jax_default_matmul_precision", "high") - jax.config.update("jax_enable_x64", False) +# enable cache +jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") +jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) +jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig @@ -108,6 +50,7 @@ from flax import jax_utils, traverse_util from flax.jax_utils import pad_shard_unpad, unreplicate from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key +import flax.core import time @@ -116,14 +59,15 @@ import time # %% import os os.environ['XLA_FLAGS'] = ( - '--xla_gpu_enable_triton_softmax_fusion=True ' - '--xla_gpu_triton_gemm_any=True ' + '--xla_gpu_triton_gemm_any=true --xla_gpu_enable_custom_fusions=true --xla_gpu_enable_address_computation_fusion=true' ) os.environ.update({ + "CUDA_DEVICE_MAX_CONNECTIONS" : "1", "NCCL_LL128_BUFFSIZE": "-2", "NCCL_LL_BUFFSIZE": "-2", "NCCL_PROTO": "SIMPLE,LL,LL128", + "XLA_PYTHON_CLIENT_MEM_FRACTION" : ".95" }) # %% @@ -132,17 +76,10 @@ print(xla_bridge.get_backend().platform) # %% -# nltk.download('punkt') try: nltk.data.find("tokenizers/punkt") except (LookupError, OSError): - if is_offline_mode(): - raise LookupError( - "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" - ) - with FileLock(".lock") as lock: - nltk.download("punkt", quiet=True) - + print("error") # %% [markdown] @@ -153,17 +90,8 @@ except (LookupError, OSError): model_name_or_path = "t5-small" # Replace with your specific model name # Load configuration -config = AutoConfig.from_pretrained(model_name_or_path) - -# Load model -model = FlaxAutoModelForSeq2SeqLM.from_pretrained( - model_name_or_path -) - - -# %% -model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"]) -shift_tokens_right_fn = getattr(model_module, "shift_tokens_right") +config = AutoConfig.from_pretrained(model_name_or_path, + force_download=False) # %% @@ -173,7 +101,11 @@ save_path = 't5_80_1' # file_path = 'combined_data' split_datasets = load_from_disk(file_path) -# prepare tokenizer +# %% + +split_datasets['train'][0] + +# %% from transformers import T5TokenizerFast tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True) # Define additional special tokens @@ -183,6 +115,43 @@ tokenizer.add_special_tokens({"additional_special_tokens": additional_special_to max_length = 86 +# %% +len(tokenizer) + +# %% +# load pytorch model first +# from transformers import AutoModelForSeq2SeqLM +# model_checkpoint = "t5-base" +# model_pt = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint) +# # important! 
after extending tokens vocab
+# model_pt.resize_token_embeddings(len(tokenizer))
+# model_pt.save_pretrained('./modified_t5_model')
+# model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
+#     pretrained_model_name_or_path="modified_t5_model",
+#     dtype=jax.numpy.bfloat16,
+#     from_pt=True
+# )
+
+
+# %%
+model_path = './t5_80_1'
+# model_path = 't5-base'
+model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
+    pretrained_model_name_or_path=model_path,
+    dtype=jax.numpy.float32
+)
+
+
+# %%
+model.params_shape_tree['shared']
+
+# %%
+model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
+shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
+
+
+
+# %%
 # In Flax, for seq2seq models we need to pass `decoder_input_ids`
 # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
 # for that dynamically import the `shift_tokens_right` function from the model file
@@ -191,21 +160,19 @@ max_length = 86
 # given a dataset entry, run it through the tokenizer
 # Setting padding="max_length" as we need fixed length inputs for jitted functions
 def preprocess_function(example):
-    input = example['input']
-    target = example['output']
+    inputs = example['input']
+    targets = example['output']
     # text_target sets the corresponding label to inputs
     # there is no need to create a separate 'labels'
     model_inputs = tokenizer(
-        input,
-        text_target=target,
+        inputs,
         max_length=max_length,
         padding="max_length",
         truncation=True,
         return_tensors="np"
     )
     labels = tokenizer(
-        input,
-        text_target=target,
+        text_target=targets,
         max_length=max_length,
         padding="max_length",
         truncation=True,
@@ -233,15 +200,18 @@ tokenized_datasets = split_datasets.map(
 )
-
-
-# %%
-tokenized_datasets
+tokenized_datasets.set_format(type='numpy',
+                              columns=['input_ids', 'attention_mask',
+                                       'labels', 'decoder_input_ids',
+                                       'decoder_attention_mask'])

 # %%
 train_dataset = tokenized_datasets["train"]
 eval_dataset = tokenized_datasets["validation"]

+# %%
+train_dataset[0]
+
 # %%

 def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
@@ -270,65 +240,14 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
     yield batch

-
-# %% [markdown]
-# Now we have model inputs in terms of the variable tokenized_datasets
-
-# %%
-# metric
-metric = evaluate.load("sacrebleu")
-
-def postprocess_text(preds, labels):
-    preds = [pred.strip() for pred in preds]
-    labels = [label.strip() for label in labels]
-
-    # rougeLSum expects newline after each sentence
-    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
-    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
-
-    return preds, labels
-
-# def compute_metrics(preds, labels):
-#     decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
-#     decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
-#
-#     # Some simple post-processing
-#     decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
-#
-#     result = metric.compute(predictions=decoded_preds, references=decoded_labels)
-#     result = {k: round(v * 100, 4) for k, v in result.items()}
-#     prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
-#     result["gen_len"] = np.mean(prediction_lens)
-#     return result
-
-def compute_metrics(preds, labels):
-    # In case the model returns more than the prediction logits
-    if isinstance(preds, tuple):
-        preds = preds[0]
-
-    decoded_preds = tokenizer.batch_decode(preds, 
skip_special_tokens=True) - - # Replace -100s in the labels as we can't decode them - labels = np.where(labels != -100, labels, tokenizer.pad_token_id) - decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) - - # Some simple post-processing - decoded_preds = [pred.strip() for pred in decoded_preds] - decoded_labels = [[label.strip()] for label in decoded_labels] - - result = metric.compute(predictions=decoded_preds, references=decoded_labels) - return {"bleu": result["score"]} - - - # %% [markdown] # # Model # %% # Store some constant seed = 117 -num_epochs = 80 -batch_size = 96 +num_epochs = 40 +batch_size = 32 num_train_epochs = num_epochs per_device_train_batch_size = batch_size train_batch_size = per_device_train_batch_size * jax.device_count() @@ -338,16 +257,16 @@ steps_per_epoch = len(train_dataset) // train_batch_size total_train_steps = steps_per_epoch * num_epochs warmup_steps = 0 -learning_rate = 5e-5 +learning_rate = 2e-5 -weight_decay = 0.0 +weight_decay = 0.01 adam_beta1 = 0.9 adam_beta2 = 0.999 adam_epsilon = 1e-8 label_smoothing_factor = 0.0 num_beams = 1 -val_max_target_length = None +val_max_target_length = 128 predict_with_generate = True @@ -421,15 +340,14 @@ class TrainState(train_state.TrainState): def replicate(self): return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) -# Ensure model.params is properly initialized (this is just an example) -# Normally you would get this from a model initialization call with dummy input +# set bf16 for model params +# model.params = model.to_bf16(model.params) params = model.params # Cast parameters to bfloat16 if desired -params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params) - +# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params) # Setup train state -state = TrainState.create(apply_fn=model.__call__, params=params_bf16, tx=adamw, dropout_rng=dropout_rng) +state = TrainState.create(apply_fn=model.__call__, params=params, tx=adamw, dropout_rng=dropout_rng) # label smoothed cross entropy def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): @@ -481,21 +399,6 @@ def train_step(state, batch, label_smoothing_factor=0.0): metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} return new_state, metrics -# Define eval fn -def eval_step(params, batch, label_smoothing_factor=0.0): - labels = batch.pop("labels") - logits = model(**batch, params=params, train=False)[0] - - loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) - num_labels = jax.lax.psum(num_labels, "batch") - - # true loss = total loss / total samples - loss = jax.lax.psum(loss, "batch") - loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) - - metrics = {"loss": loss} - return metrics - # Define generation function max_length = ( val_max_target_length if val_max_target_length is not None else model.config.max_length @@ -512,7 +415,7 @@ def generate_step(params, batch): p_train_step = jax.pmap( partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,) ) -p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch") +# p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch") p_generate_step = jax.pmap(generate_step, "batch") # Replicate the train state on each device @@ -563,50 +466,6 @@ for epoch in epochs: f" {train_metric['learning_rate']})" ) - # 
======================== Evaluating ==============================
-    eval_metrics = []
-    eval_preds = []
-    eval_labels = []
-
-    eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
-    eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
-    for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
-        # Model forward
-        batch = next(eval_loader)
-        labels = batch["labels"]
-
-        metrics = pad_shard_unpad(p_eval_step, static_return=True)(
-            state.params, batch, min_device_batch=per_device_eval_batch_size
-        )
-        eval_metrics.append(metrics)
-
-        # generation
-        if predict_with_generate:
-            generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
-            eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
-            eval_labels.extend(labels)
-
-    # normalize eval metrics
-    eval_metrics = get_metrics(eval_metrics)
-    eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
-
-    # compute metrics
-    rouge_desc = ""
-    if predict_with_generate:
-        rouge_metrics = compute_metrics(eval_preds, eval_labels)
-        eval_metrics.update(rouge_metrics)
-        rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
-
-    # Print metrics and update progress bar
-    desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
-    epochs.write(desc)
-    epochs.desc = desc
-
-    # Save metrics
-    # if has_tensorboard and jax.process_index() == 0:
-    #     cur_step = epoch * (len(train_dataset) // train_batch_size)
-    #     write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
-
     output_dir = save_path
     # save checkpoint after each epoch and push checkpoint to the hub
     if jax.process_index() == 0:
@@ -614,7 +473,3 @@ for epoch in epochs:
         model.save_pretrained(output_dir, params=params)
         tokenizer.save_pretrained(output_dir)
-
-
-# %% [markdown]
-# #
diff --git a/t5_jax_prediction.py b/t5_jax_prediction.py
index 217372b..5961448 100644
--- a/t5_jax_prediction.py
+++ b/t5_jax_prediction.py
@@ -39,10 +39,9 @@ from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig

 import datasets
-from datasets import Dataset, load_dataset
+from datasets import Dataset
 import evaluate
 from tqdm import tqdm
-from datasets import load_from_disk

 import nltk # Here to have a nice missing dependency error message early on

@@ -76,9 +75,9 @@ def process_df(df):
         # 'input': f"{row['tag_name']}{row['tag_description']}{row['unit']}",
         # 'input': f"{row['tag_description']}{row['unit']}",
         'output': f"<thing_start>{row['thing']}<thing_end><property_start>{row['property']}<property_end>",
-        'answer': f"{row['thing']} {row['property']}",
-        'answer_thing': row['thing'],
-        'answer_property': row['property'],
+        # 'answer': f"{row['thing']} {row['property']}",
+        # 'answer_thing': row['thing'],
+        # 'answer_property': row['property'],
     } for _, row in df.iterrows()]
     return output_list

@@ -93,14 +92,14 @@ test_dataset = Dataset.from_list(process_df(df))

 # %%
 # load model
-model_name_or_path = "t5_80_1" # Replace with your specific model name
+model_name_or_path = "./t5_80_1" # Replace with your specific model name

 # Load configuration
 config = AutoConfig.from_pretrained(model_name_or_path)

 # Load model
 model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
-    model_name_or_path
+    pretrained_model_name_or_path=model_name_or_path
 )

@@ -124,21 +123,19 @@ shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
 # given a dataset entry, run it through the tokenizer
 # Setting padding="max_length" as we need fixed length inputs for jitted functions
 def 
preprocess_function(example): - input = example['input'] - target = example['output'] + inputs = example['input'] + targets = example['output'] # text_target sets the corresponding label to inputs # there is no need to create a separate 'labels' model_inputs = tokenizer( - input, - text_target=target, + inputs, max_length=max_length, padding="max_length", truncation=True, return_tensors="np" ) labels = tokenizer( - input, - text_target=target, + text_target=targets, max_length=max_length, padding="max_length", truncation=True, @@ -191,7 +188,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf yield batch # %% [markdown] -# # Model Training +# # model generation # %% seed = 117 @@ -205,17 +202,8 @@ eval_batch_size = per_device_eval_batch_size * jax.device_count() steps_per_epoch = len(test_dataset) // train_batch_size total_train_steps = steps_per_epoch * num_epochs -warmup_steps = 0 -learning_rate = 5e-5 - -weight_decay = 0.0 -adam_beta1 = 0.9 -adam_beta2 = 0.999 -adam_epsilon = 1e-8 -label_smoothing_factor = 0.0 - num_beams = 1 -val_max_target_length = None +val_max_target_length = 128 predict_with_generate = True @@ -224,55 +212,6 @@ predict_with_generate = True rng = jax.random.PRNGKey(seed) rng, dropout_rng = jax.random.split(rng) -def create_learning_rate_fn( - train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.ndarray]: - """Returns a linear warmup, linear_decay learning rate function.""" - steps_per_epoch = train_ds_size // train_batch_size - num_train_steps = steps_per_epoch * num_train_epochs - warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) - decay_fn = optax.linear_schedule( - init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps - ) - schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) - return schedule_fn - - -# Create learning rate schedule -linear_decay_lr_schedule_fn = create_learning_rate_fn( - len(test_dataset), - train_batch_size, - num_train_epochs, - warmup_steps, - learning_rate, -) - -# We use Optax's "masking" functionality to not apply weight decay -# to bias and LayerNorm scale parameters. decay_mask_fn returns a -# mask boolean with the same structure as the parameters. -# The mask is True for parameters that should be decayed. 
-def decay_mask_fn(params):
-    flat_params = traverse_util.flatten_dict(params)
-    # find out all LayerNorm parameters
-    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-    layer_norm_named_params = {
-        layer[-2:]
-        for layer_norm_name in layer_norm_candidates
-        for layer in flat_params.keys()
-        if layer_norm_name in "".join(layer).lower()
-    }
-    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
-    return traverse_util.unflatten_dict(flat_mask)
-
-# create adam optimizer
-adamw = optax.adamw(
-    learning_rate=linear_decay_lr_schedule_fn,
-    b1=adam_beta1,
-    b2=adam_beta2,
-    eps=adam_epsilon,
-    weight_decay=weight_decay,
-    mask=decay_mask_fn,
-)


 # %%
@@ -288,23 +227,14 @@ model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
     model_name_or_path
 )

-# Training functions
-class TrainState(train_state.TrainState):
-    dropout_rng: jnp.ndarray
-
-    def replicate(self):
-        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))

 # Ensure model.params is properly initialized (this is just an example)
 # Normally you would get this from a model initialization call with dummy input
 params = model.params
-# Cast parameters to bfloat16 if desired
-params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
-
-
-# Setup train state
-state = TrainState.create(apply_fn=model.__call__, params=params_bf16, tx=adamw, dropout_rng=dropout_rng)
-
+# ensure full size floats
+params_f32 = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
+# we need to replicate model over devices
+replicated_params = jax.device_put_replicated(params_f32, jax.devices())


 # Define generation function
@@ -315,18 +245,14 @@ num_beams = num_beams if num_beams is not None else model.config.num_beams
 gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

 def generate_step(params, batch):
-    model.params = params
-    output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
+    output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], params=params, **gen_kwargs)
     return output_ids.sequences

 # Create parallel version of the train and eval step
 p_generate_step = jax.pmap(generate_step, "batch")

-# Replicate the train state on each device
-state = state.replicate()

-pred_metrics = []
 pred_generations = []
 pred_labels = []

@@ -342,45 +268,92 @@ print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
 print(f" Total test batch size (w. parallel & distributed) = {train_batch_size}")

-for _ in tqdm(range(pred_steps), desc="Predicting...", position=0, leave=False):
+for _ in tqdm(range(pred_steps), desc="Predicting..."):
     # Model forward
     batch = next(pred_loader)
     labels = batch["labels"]

     # generation
-    if predict_with_generate:
-        generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
-        pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
-        pred_labels.extend(labels)
+    generated_ids = pad_shard_unpad(p_generate_step)(replicated_params, batch)
+    pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
+    pred_labels.extend(labels)


-# Print metrics
-# desc = f"Predict Loss: {pred_metrics['loss']})"
-# print(desc)
+# %% [markdown]
+# # process predictions
+

 # %%
-# save predictions to parquet
+# code to get special token ids
+# sentence = ""
+# tokens = tokenizer.tokenize(sentence)
+# print("Tokens:", tokens)
+# # Get the IDs (integer indices) of specific tokens
+# token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens]
+# print("Token IDs:", token_ids)
+
+# %%
+# extract sequence and decode
+def extract_seq(tokens, start_value, end_value):
+    if start_value not in tokens or end_value not in tokens:
+        return None # Or handle this case according to your requirements
+    start_id = np.where(tokens == start_value)[0][0]
+    end_id = np.where(tokens == end_value)[0][0]
+
+    return tokens[start_id+1:end_id]
+
+
+def process_tensor_output(tokens):
+    thing_seq = extract_seq(tokens, 32100, 32101) # 32100 = <thing_start>, 32101 = <thing_end>
+    property_seq = extract_seq(tokens, 32102, 32103) # 32102 = <property_start>, 32103 = <property_end>
+    p_thing = None
+    p_property = None
+    if (thing_seq is not None):
+        p_thing = tokenizer.decode(thing_seq, skip_special_tokens=False) # retain special tokens
+    if (property_seq is not None):
+        p_property = tokenizer.decode(property_seq, skip_special_tokens=False) # retain special tokens
+    return p_thing, p_property
+
+
+# %%
 # decode prediction labels
-def decode_preds(preds):
-    # In case the model returns more than the prediction logits
-    if isinstance(preds, tuple):
-        preds = preds[0]
+def decode_preds(tokens_list):
+    thing_prediction_list = []
+    property_prediction_list = []
+    for tokens in tokens_list:
+        p_thing, p_property = process_tensor_output(tokens)
+        thing_prediction_list.append(p_thing)
+        property_prediction_list.append(p_property)
+    return thing_prediction_list, property_prediction_list

-    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
-
-    decoded_preds = [pred for pred in decoded_preds]
-
-    return decoded_preds
+thing_prediction_list, property_prediction_list = decode_preds(pred_generations)

+# %%
+# add labels too
+thing_actual_list, property_actual_list = decode_preds(pred_labels)

 # Convert the list to a Pandas DataFrame
-df = pd.DataFrame(decode_preds(pred_labels))
-
-# Save the DataFrame as a Parquet file (using pyarrow or fastparquet)
-df.to_parquet("exports/output_file.parquet", engine="pyarrow") # or use engine="fastparquet"
-
+df = pd.DataFrame({'p_thing': thing_prediction_list,
+                   'p_property': property_prediction_list,
+                   'thing': thing_actual_list,
+                   'property' : property_actual_list})
+df['p_thing_correct'] = df['p_thing'] == df['thing']
+df['p_property_correct'] = df['p_property'] == df['property']

 # %%
+print("thing accuracy", sum(df['p_thing_correct'])/len(df))
+print("property accuracy", sum(df['p_property_correct'])/len(df))
+print("total accuracy", sum(df['p_property_correct'] & df['p_thing_correct'])/len(df))
+# %%
+df[~df["p_property_correct"]] + +# %% +df['p_thing'] +# %% +# Save the DataFrame as a Parquet file (using pyarrow or fastparquet) +# df.to_parquet("exports/output_file.parquet", engine="pyarrow") # or use engine="fastparquet" + +
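
A minimal usage sketch of the new extract_seq decode helper, for reference. The ids 32100-32103 line up with the first four special tokens added on top of the t5-base vocabulary (length 32100); the payload ids and the token names in the comments are illustrative assumptions, not values taken from the repository.

import numpy as np

def extract_seq(tokens, start_value, end_value):
    # same logic as the patched helper: keep the ids strictly between the delimiters
    if start_value not in tokens or end_value not in tokens:
        return None
    start_id = np.where(tokens == start_value)[0][0]
    end_id = np.where(tokens == end_value)[0][0]
    return tokens[start_id + 1:end_id]

# e.g. <pad> <thing_start> t1 t2 <thing_end> <property_start> t3 <property_end> </s>
generated = np.array([0, 32100, 1363, 1558, 32101, 32102, 2071, 32103, 1])
print(extract_seq(generated, 32100, 32101))  # [1363 1558]
print(extract_seq(generated, 32102, 32103))  # [2071]

Because np.where keeps only the first occurrence of each delimiter, a sequence that emits the end token before the start token yields an empty span, and a missing delimiter returns None; the accuracy cells then count such rows as incorrect, since None never equals the decoded label.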