# %%
# Prepare the dataloader for JAX training.

from datasets import Dataset, DatasetDict, ClassLabel, Value, Sequence, load_from_disk
from transformers import FlaxT5ForConditionalGeneration, T5TokenizerFast
from ml_collections import ConfigDict
import numpy as np
import jax.numpy as jnp
import jax
import math
from typing import Optional, List, Tuple, Callable, cast

# file_path = 'combined_data'
# split_datasets = load_from_disk(file_path)
# training_size = len(split_datasets['train'])

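# Note: the `config` passed to DataPrepare is expected to expose `max_length`,
# `pad_token_id` and `decoder_start_token_id`; see the example ConfigDict in the
# commented-out testing cell at the bottom of this file.
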
# This class takes in a raw dataset and prepares tokenized, padded arrays for training.
class DataPrepare():

    def __init__(self, raw_dataset, config):
        self.raw_dataset: Dataset = raw_dataset
        self.size: int = len(raw_dataset)
        self.config: ConfigDict = config
        self.tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=False)

        # Define additional special tokens
        # additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "<SIG>", "<UNIT>", "<DATA_TYPE>"]
        additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
        # Add the additional special tokens to the tokenizer
        self.tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})

        model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")

        model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
        self.shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")  # noqa: B009

        self.train_dataset = self.preprocess_function(
            self.raw_dataset
        )

    # In Flax, seq2seq models need to be given `decoder_input_ids`: the Flax models
    # do not accept `labels`, so we have to prepare the decoder inputs ourselves.
    # For that we dynamically import the `shift_tokens_right` function from the model file.
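    # Illustration (assumed values, matching the example config at the bottom of this
    # file where pad_token_id=0 and decoder_start_token_id=0): shift_tokens_right turns
    # labels such as [l1, l2, l3, </s>, 0] into decoder_input_ids [0, l1, l2, l3, </s>],
    # i.e. the targets shifted right by one with the decoder start token prepended.
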
    # Given dataset entries, run them through the tokenizer.
    # The whole dataset is tokenized in one call, so padding=True pads every example
    # to a common length, which gives the fixed-shape inputs that jitted functions need.
    def preprocess_function(self, example: Dataset):
        inputs = example['input']
        targets = example['output']
        # `text_target` tells the tokenizer to process `targets` as labels,
        # so there is no need to construct a separate 'labels' field by hand.
        # Tokenize the inputs; decoder_input_ids are built further below from the labels.
        model_inputs = self.tokenizer(
            inputs,
            max_length=self.config.max_length,
            padding=True,
            truncation=True,
            return_tensors="np"
        )
        # We tokenize the targets separately because we also need their attention mask.
        labels = self.tokenizer(
            text_target=targets,
            max_length=self.config.max_length,
            padding=True,
            truncation=True,
            return_tensors="np"
        )
        model_inputs['input_ids'] = np.asarray(model_inputs['input_ids'])
        model_inputs['attention_mask'] = np.asarray(model_inputs['attention_mask'])
        # For loss computation.
        model_inputs["labels"] = labels["input_ids"]
        # Make decoder input ids: these are the target ("model output") tokens shifted right.
        decoder_input_ids = self.shift_tokens_right_fn(
            labels["input_ids"], self.config.pad_token_id, self.config.decoder_start_token_id
        )
        # Required by the model.
        model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
        # decoder_attention_mask = shift_tokens_right_fn(
        #     labels["attention_mask"], self.config.pad_token_id, self.config.decoder_start_token_id
        # )
        # We need decoder_attention_mask so we can ignore pad tokens in the loss.
        model_inputs["decoder_attention_mask"] = np.asarray(labels["attention_mask"])

        return model_inputs

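    # Sketch of what preprocess_function returns (the lengths are whatever the padded
    # input and target sequences end up being; the names below are only illustrative):
    #   {
    #       'input_ids':              (num_examples, input_len)  int array,
    #       'attention_mask':         (num_examples, input_len)  int array,
    #       'labels':                 (num_examples, target_len) int array,
    #       'decoder_input_ids':      (num_examples, target_len) int array,
    #       'decoder_attention_mask': (num_examples, target_len) int array,
    #   }
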
    # Pad an incomplete batch with zero rows so every batch has exactly `target_size` examples.
    def _pad_to_batch_size(self, batch, target_size):
        # Get the current batch size
        input_ids = batch['input_ids']
        current_size = input_ids.shape[0]
        if current_size < target_size:
            # Calculate how much padding is needed
            padding_size = target_size - current_size
            # Create zero padding per array (the arrays are 2D but their sequence
            # lengths may differ, e.g. input_ids vs decoder_input_ids), then
            # concatenate it onto every array in the tree.
            padded_batch = jax.tree.map(
                lambda array: jnp.concatenate(
                    [array, jnp.zeros((padding_size, array.shape[1]), dtype=jnp.int32)],
                    axis=0,
                ),
                batch,
            )
        else:
            padded_batch = batch
        return padded_batch

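    # Example (hedged): with target_size=32 and a final minibatch of 20 examples,
    # twelve all-zero rows are appended to every array, so a jitted train step always
    # sees the same leading batch dimension. The zero rows are pad tokens and should
    # be ignored by the loss via the (also zero-padded) attention masks.
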
    def data_loader(self, rng: jax.random.PRNGKey, batch_size: int, shuffle: bool = False, drop_last=True):
        """
        Returns batches of size `batch_size` from the prepared dataset. If `drop_last` is
        `False`, the final batch may be incomplete and is padded up to `batch_size`;
        batches are shuffled if `shuffle` is `True`.
        """
        dataset: Dataset = Dataset.from_dict(self.train_dataset)

        if shuffle:
            batch_idx = jax.random.permutation(rng, len(dataset))
            batch_idx = np.asarray(batch_idx)
        else:
            batch_idx = np.arange(len(dataset))

        if drop_last:
            steps_per_epoch = len(dataset) // batch_size
            batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
            minibatch_list = batch_idx.reshape((steps_per_epoch, batch_size))
        else:
            steps_per_epoch = math.ceil(len(dataset) / batch_size)
            minibatch_list = np.array_split(batch_idx, steps_per_epoch)

        for minibatch in minibatch_list:
            batch = dataset[minibatch]
            batch = {k: jnp.array(v, dtype=jnp.int32) for k, v in batch.items()}
            batch = self._pad_to_batch_size(batch, batch_size)

            yield batch

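# Hedged usage sketch: iterating full epochs, drawing a fresh PRNG key per epoch so the
# shuffle order differs between epochs. Assumes `dataprep` is built as in the commented-out
# testing cell below, and `num_epochs` is defined by the caller.
#
# rng = jax.random.PRNGKey(117)
# for epoch in range(num_epochs):
#     rng, input_rng = jax.random.split(rng)
#     for batch in dataprep.data_loader(input_rng, batch_size=32, shuffle=True):
#         ...  # feed `batch` to the jitted train step here
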
# # testing out the class
# # %%
# # init object
# # e.g. Config
#
# file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_desc'
# data_config = ConfigDict(
#     dict(
#         max_length=86,
#         pad_token_id=0,
#         decoder_start_token_id=0
#     )
# )
#
# from datasets import load_from_disk
# split_datasets = load_from_disk(file_path)
# dataprep = DataPrepare(split_datasets['train'], data_config)
#
# # %%
# seed = 117
# rng = jax.random.PRNGKey(seed)
# train_loader = dataprep.data_loader(rng, batch_size=32)
#
# # %%
# batch = next(train_loader)
#
# print(batch['input_ids'].shape)
# print(batch['decoder_input_ids'].shape)
#
# # %%