# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.16.4
# ---

# %% [markdown]
# # prediction code
# ## import and process test data

# %%
# import libraries
import pandas as pd
import matplotlib.pyplot as plt
from datasets import Dataset, DatasetDict
import jax
import jax.numpy as jnp
import optax
import numpy as np
from functools import partial
from typing import Callable, Optional
import math

# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "high")
jax.config.update("jax_enable_x64", False)

from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
import evaluate
from tqdm import tqdm
import nltk  # Here to have a nice missing dependency error message early on

from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
import time

# %%
# data_path = f"../make_data/select_db/data_mapping_filtered.csv"
# data_path = f"../make_data_2/select_db/dataset/1/train_all.csv"
data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv'
# data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv'

# Ensure to include 'ships_idx' in the fields list
fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']

# Load the dataset
df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)


def process_df(df):
    output_list = [{
        'input': f"{row['tag_name']}{row['tag_description']}",
        # 'input': f"{row['tag_description']}",
        # 'input': f"{row['tag_name']}{row['tag_description']}{row['unit']}",
        # 'input': f"{row['tag_description']}{row['unit']}",
        'output': f"{row['thing']}{row['property']}",
        'answer': f"{row['thing']} {row['property']}",
        'answer_thing': row['thing'],
        'answer_property': row['property'],
    } for _, row in df.iterrows()]
    return output_list


# takes about 1 minute to run without batching
test_dataset = Dataset.from_list(process_df(df))

# %% [markdown]
# ## Load model for attributes

# %%
# load model
model_name_or_path = "t5_80_1"  # Replace with your specific model name

# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path)

# Load model
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
    model_name_or_path
)

# %% [markdown]
# ## Tokenizer

# %%
# prepare tokenizer
from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained(
    "t5-base", return_tensors="np", clean_up_tokenization_spaces=True
)

# Define additional special tokens
additional_special_tokens = ["", "", "", "", "", "", "SIG", "UNIT", "DATA_TYPE"]

# Add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})

max_length = 86

model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
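# %%
# Optional sanity check (illustrative, not part of the original pipeline): look at one
# processed record and its token count to confirm the concatenated input fits within
# the fixed `max_length` used for padding/truncation in the preprocessing cell below.
sample = test_dataset[0]
print("input: ", sample["input"])
print("output:", sample["output"])
sample_tokens = tokenizer.tokenize(sample["input"])
print(f"token count: {len(sample_tokens)} (max_length = {max_length})")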
# %%
# given a dataset entry, run it through the tokenizer
# Setting padding="max_length" as we need fixed-length inputs for jitted functions
def preprocess_function(example):
    inputs = example['input']
    targets = example['output']

    # tokenize the encoder inputs
    model_inputs = tokenizer(
        inputs,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np"
    )

    # tokenize the targets separately via text_target so the tokenizer switches into
    # target mode; its input_ids become the labels
    labels = tokenizer(
        text_target=targets,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np"
    )

    model_inputs["labels"] = labels["input_ids"]
    decoder_input_ids = shift_tokens_right_fn(
        labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
    )
    model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)

    # We need decoder_attention_mask so we can ignore pad tokens in the loss
    model_inputs["decoder_attention_mask"] = labels["attention_mask"]

    return model_inputs


# map applies the function to each batch of rows in the dataset
test_dataset = test_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=test_dataset.column_names,
)


def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
    """
    Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be
    incomplete, and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
    """
    if shuffle:
        batch_idx = jax.random.permutation(rng, len(dataset))
        batch_idx = np.asarray(batch_idx)
    else:
        batch_idx = np.arange(len(dataset))

    if drop_last:
        steps_per_epoch = len(dataset) // batch_size
        batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
        batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    else:
        steps_per_epoch = math.ceil(len(dataset) / batch_size)
        batch_idx = np.array_split(batch_idx, steps_per_epoch)

    for idx in batch_idx:
        batch = dataset[idx]
        batch = {k: np.array(v) for k, v in batch.items()}

        yield batch


# %% [markdown]
# # Prediction

# %%
seed = 117
num_epochs = 80
batch_size = 96

num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(test_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs

warmup_steps = 0
learning_rate = 5e-5

weight_decay = 0.0
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0

num_beams = 1
val_max_target_length = None

predict_with_generate = True


# Initialize the PRNG keys
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)


def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear decay learning rate function."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn


# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
    len(test_dataset),
    train_batch_size,
    num_train_epochs,
    warmup_steps,
    learning_rate,
)
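# %%
# Optional shape check (illustrative, not part of the original pipeline): pull a single
# batch from `data_loader` with a throwaway PRNG key to confirm that every field comes
# out as a dense array of shape (eval_batch_size, max_length), which the pmapped
# generation step further below relies on.
_peek_batch = next(data_loader(jax.random.PRNGKey(0), test_dataset, eval_batch_size, drop_last=False))
for k, v in _peek_batch.items():
    print(k, v.shape, v.dtype)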
# %%
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    # find out all LayerNorm parameters
    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
    layer_norm_named_params = {
        layer[-2:]
        for layer_norm_name in layer_norm_candidates
        for layer in flat_params.keys()
        if layer_norm_name in "".join(layer).lower()
    }
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)


# create adam optimizer
adamw = optax.adamw(
    learning_rate=linear_decay_lr_schedule_fn,
    b1=adam_beta1,
    b2=adam_beta2,
    eps=adam_epsilon,
    weight_decay=weight_decay,
    mask=decay_mask_fn,
)

# %%
# reload model to prevent leakage of variables
# load model
model_name_or_path = "t5_80_1"  # Replace with your specific model name

# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path)

# Load model
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
    model_name_or_path
)


# Training functions
class TrainState(train_state.TrainState):
    dropout_rng: jnp.ndarray

    def replicate(self):
        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))


# model.params is already initialized by from_pretrained above
params = model.params

# Cast parameters to bfloat16 if desired
params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)

# Setup train state
state = TrainState.create(apply_fn=model.__call__, params=params_bf16, tx=adamw, dropout_rng=dropout_rng)

# Define generation function
max_length = (
    val_max_target_length if val_max_target_length is not None else model.config.max_length
)
num_beams = num_beams if num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}


def generate_step(params, batch):
    model.params = params
    output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
    return output_ids.sequences


# Create parallel version of the generate step
p_generate_step = jax.pmap(generate_step, "batch")

# Replicate the train state on each device
state = state.replicate()

pred_metrics = []
pred_generations = []
pred_labels = []

rng, input_rng = jax.random.split(rng)

pred_loader = data_loader(input_rng, test_dataset, eval_batch_size, drop_last=False)
pred_steps = math.ceil(len(test_dataset) / eval_batch_size)
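# %%
# Optional smoke test (illustrative, not part of the original pipeline): run
# `model.generate` on a single tokenized example without pmap, to check that decoding
# produces text before launching the full sharded prediction loop below.
_one = test_dataset[0]
_smoke_out = model.generate(
    jnp.asarray([_one["input_ids"]]),
    attention_mask=jnp.asarray([_one["attention_mask"]]),
    **gen_kwargs,
)
print(tokenizer.batch_decode(jax.device_get(_smoke_out.sequences), skip_special_tokens=True))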
# %%
print("***** Running prediction *****")
print(f"  Num examples = {len(test_dataset)}")
print(f"  Num steps = {pred_steps}")
print(f"  Instantaneous batch size per device = {per_device_eval_batch_size}")
print(f"  Total test batch size (w. parallel & distributed) = {eval_batch_size}")

for _ in tqdm(range(pred_steps), desc="Predicting...", position=0, leave=False):
    # Model forward
    batch = next(pred_loader)
    labels = batch["labels"]

    # generation
    if predict_with_generate:
        generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
        pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
        pred_labels.extend(labels)

# Print metrics
# desc = f"Predict Loss: {pred_metrics['loss']})"
# print(desc)

# %%
# save predictions to parquet

# decode prediction tokens into text
def decode_preds(preds):
    # In case the model returns more than the prediction logits
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_preds = [pred.strip() for pred in decoded_preds]
    return decoded_preds


# Convert the decoded predictions to a Pandas DataFrame
df = pd.DataFrame(decode_preds(pred_generations))

# Save the DataFrame as a Parquet file (using pyarrow or fastparquet)
import os
os.makedirs("exports", exist_ok=True)  # make sure the output directory exists
df.to_parquet("exports/output_file.parquet", engine="pyarrow")  # or use engine="fastparquet"

# %%
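# Optional evaluation sketch (not part of the original script; assumes an exact string
# match between the decoded prediction and the decoded label is an acceptable first
# metric for the thing/property mapping task).
decoded_preds = decode_preds(pred_generations)
decoded_labels = decode_preds(pred_labels)
n_match = sum(p == l for p, l in zip(decoded_preds, decoded_labels))
print(f"exact match: {n_match}/{len(decoded_labels)} = {n_match / len(decoded_labels):.4f}")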