# %%
# Prepare dataloader for jax training
from datasets import ClassLabel, Dataset, DatasetDict, Sequence, Value, load_from_disk
from transformers import FlaxT5ForConditionalGeneration, T5TokenizerFast
from ml_collections import ConfigDict
import numpy as np
import jax.numpy as jnp
import jax
import math
from typing import Optional, List, Tuple, Callable, cast

# file_path = 'combined_data'
# split_datasets = load_from_disk(file_path)
# training_size = len(split_datasets['train'])


# class takes in a dataset split and a config
class DataPrepare():

    def __init__(self, raw_dataset, config):
        self.raw_dataset: Dataset = raw_dataset
        self.size: int = len(raw_dataset)
        self.config: ConfigDict = config
        self.tokenizer = T5TokenizerFast.from_pretrained(
            "t5-base", return_tensors="np", clean_up_tokenization_spaces=False
        )
        # Define additional special tokens
        # additional_special_tokens = ["", "", "", "", "", "", "", "", ""]
        additional_special_tokens = ["", "", "", "", "", "", "SIG", "UNIT", "DATA_TYPE"]
        # Add the additional special tokens to the tokenizer
        self.tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})

        model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
        model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
        self.shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")  # noqa: B009

        self.train_dataset = self.preprocess_function(
            self.raw_dataset
        )

    # In Flax, seq2seq models need `decoder_input_ids` to be passed explicitly;
    # since the Flax models don't accept `labels`, we prepare `decoder_input_ids` here.
    # For that, `shift_tokens_right` is dynamically imported from the model file in __init__.
    # Given a dataset split, run it through the tokenizer.
    # Setting padding="max_length" as we need fixed-length inputs for jitted functions.
    def preprocess_function(self, example: Dataset):
        inputs = example['input']
        targets = example['output']
        # passing `text_target` makes the tokenizer produce the label ids directly,
        # so there is no need for a separate 'labels' preprocessing step
        # produce input_ids and decoder_input_ids
        model_inputs = self.tokenizer(
            inputs,
            max_length=self.config.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="np"
        )
        # we tokenize the targets separately because we also need their attention mask
        labels = self.tokenizer(
            text_target=targets,
            max_length=self.config.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="np"
        )
        model_inputs['input_ids'] = np.asarray(model_inputs['input_ids'])
        model_inputs['attention_mask'] = np.asarray(model_inputs['attention_mask'])
        # for loss computation
        model_inputs["labels"] = labels["input_ids"]
        # make decoder input ids:
        # this is the "model output" (labels) shifted right by one position
        decoder_input_ids = self.shift_tokens_right_fn(
            labels["input_ids"], self.config.pad_token_id, self.config.decoder_start_token_id
        )
        # required by the model
        model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)

        # decoder_attention_mask = shift_tokens_right_fn(
        #     labels["attention_mask"], self.config.pad_token_id, self.config.decoder_start_token_id
        # )
        # We need decoder_attention_mask so we can ignore pad tokens in the loss
        model_inputs["decoder_attention_mask"] = np.asarray(labels["attention_mask"])

        return model_inputs

    # Pad a partial batch up to `target_size` along the batch dimension
    def _pad_to_batch_size(self, batch, target_size):
        # Get the current batch size
        input_ids = batch['input_ids']
        current_size = input_ids.shape[0]
        if current_size < target_size:
            # Calculate how much padding is needed
            padding_size = target_size - current_size
            # Create zero padding matching each array's trailing dimensions,
            # then concatenate it onto every array in the batch tree
            padded_batch = jax.tree.map(
                lambda array: jnp.concatenate(
                    [array, jnp.zeros((padding_size,) + array.shape[1:], dtype=array.dtype)],
                    axis=0,
                ),
                batch,
            )
        else:
            padded_batch = batch
        return padded_batch

    def data_loader(self, rng: jax.random.PRNGKey, batch_size: int, shuffle: bool = False, drop_last=True):
        """
        Returns batches of size `batch_size` from the preprocessed dataset. If `drop_last` is
        `False`, the final batch may be incomplete and range in size from 1 to `batch_size`
        (it is then zero-padded back up to `batch_size`). Shuffles the data if `shuffle` is `True`.
        """
        dataset: Dataset = Dataset.from_dict(self.train_dataset)

        if shuffle:
            batch_idx = jax.random.permutation(rng, len(dataset))
            batch_idx = np.asarray(batch_idx)
        else:
            batch_idx = np.arange(len(dataset))

        if drop_last:
            steps_per_epoch = len(dataset) // batch_size
            batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
            minibatch_list = batch_idx.reshape((steps_per_epoch, batch_size))
        else:
            steps_per_epoch = math.ceil(len(dataset) / batch_size)
            minibatch_list = np.array_split(batch_idx, steps_per_epoch)

        for minibatch in minibatch_list:
            batch = dataset[minibatch]
            batch = {k: jnp.array(v, dtype=jnp.int32) for k, v in batch.items()}
            batch = self._pad_to_batch_size(batch, batch_size)
            yield batch


# # testing out the class
# # %%
# # init object
# # e.g. Config
# # file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_desc'
# data_config = ConfigDict(
#     dict(
#         max_length=86,
#         pad_token_id=0,
#         decoder_start_token_id=0
#     )
# )
#
# from datasets import load_from_disk
# split_datasets = load_from_disk(file_path)
# dataprep = DataPrepare(split_datasets['train'], data_config)
#
# # %%
# seed = 117
# rng = jax.random.PRNGKey(seed)
# train_loader = dataprep.data_loader(rng, batch_size=32)
#
#
# # %%
# batch = next(train_loader)
#
# print(batch['input_ids'].shape)
# print(batch['decoder_input_ids'].shape)
#
# # %%
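# %%
# Illustration only: a minimal sketch of what the dynamically imported
# `shift_tokens_right` helper does (modeled on the Hugging Face Flax T5
# implementation). `decoder_input_ids` are the labels shifted one position to
# the right, with `decoder_start_token_id` inserted at position 0 and any -100
# placeholders replaced by `pad_token_id`. The function below is a stand-in for
# reference; DataPrepare still calls the function imported from the model module.
def shift_tokens_right_sketch(label_ids: np.ndarray, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    shifted = np.zeros_like(label_ids)
    shifted[:, 1:] = label_ids[:, :-1]      # move every token one step to the right
    shifted[:, 0] = decoder_start_token_id  # decoder always starts from this token
    return np.where(shifted == -100, pad_token_id, shifted)  # replace -100 markers with pad id

# # %%
# # toy labels: two sequences of token ids, already padded with 0
# toy_labels = np.array([[100, 200, 300, 1, 0], [400, 500, 1, 0, 0]])
# print(shift_tokens_right_sketch(toy_labels, pad_token_id=0, decoder_start_token_id=0))
# # [[  0 100 200 300   1]
# #  [  0 400 500   1   0]]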