175 lines
6.3 KiB
Python
175 lines
6.3 KiB
Python
# %%
|
|
# Prepare dataloader for jax training
|
|
from datasets import Dataset, DatasetDict, Value, Sequence, load_from_disk
|
|
from transformers import FlaxT5ForConditionalGeneration
|
|
from datasets import ClassLabel, Value, Sequence
|
|
from ml_collections import ConfigDict
|
|
import numpy as np
|
|
import jax.numpy as jnp
|
|
import jax
|
|
import math
|
|
from typing import Optional, List, Tuple, Callable, cast
|
|
|
|
|
|
import os

# Directory of the split dataset read with `load_from_disk` (see the
# commented-out lines below). The default is the original author's local
# path; set the T5_RETRIEVAL_DATA environment variable to point elsewhere.
file_path = os.environ.get(
    "T5_RETRIEVAL_DATA",
    '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval',
)

# file_path = 'combined_data'

# split_datasets = load_from_disk(file_path)

# training_size = len(split_datasets['train'])
|
|
|
|
import importlib

from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)

# Define additional special tokens.
# NOTE(review): "SIG", "UNIT" and "DATA_TYPE" lack the angle brackets the other
# markers use; confirm this matches the text in the dataset before changing it,
# since renaming tokens changes how existing data is tokenized.
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]

# Add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})

# NOTE(review): the model's embeddings are not resized after adding tokens;
# this presumably relies on t5-base having spare embedding rows — verify that
# len(tokenizer) <= model.config.vocab_size.
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")

# Flax seq2seq models take `decoder_input_ids` instead of `labels`, so fetch
# the `shift_tokens_right` helper from the module defining the loaded model
# class. (The previous `__import__(..., fromlist=["shift_tokens_tight"])` had a
# typo in the fromlist that was only silently ignored.)
model_module = importlib.import_module(model.__module__)
shift_tokens_right_fn = model_module.shift_tokens_right
|
|
|
|
|
|
# class takes in a dataset
class DataPrepare:
    """Tokenize a raw seq2seq dataset and serve numpy batches from it.

    The raw dataset must provide ``'input'`` and ``'output'`` text columns
    (read in :meth:`preprocess_function`). Tokenization runs eagerly in
    ``__init__`` via :meth:`make_dataset`; batches are then drawn with
    :meth:`data_loader`.

    Relies on the module-level ``tokenizer`` and ``shift_tokens_right_fn``.

    Attributes:
        raw_dataset: The untokenized dataset passed to the constructor.
        train_dataset: The tokenized dataset (numpy format), or ``None``
            before :meth:`make_dataset` has run.
        size: Number of rows in ``raw_dataset``.
        config: Settings read here: ``max_length``, ``pad_token_id`` and
            ``decoder_start_token_id``.
    """

    def __init__(self, raw_dataset, config):
        self.raw_dataset: Dataset = raw_dataset
        self.train_dataset: Optional[Dataset] = None
        self.size: int = len(raw_dataset)
        self.config: ConfigDict = config

        self.make_dataset()

    def preprocess_function(self, example):
        """Tokenize one batch of examples for Flax seq2seq training.

        Flax seq2seq models do not accept ``labels``, so ``decoder_input_ids``
        are prepared here by shifting the label ids right with the model's
        own ``shift_tokens_right`` (loaded dynamically at module level).
        ``padding="max_length"`` yields fixed-length arrays, which jitted
        functions need.

        Args:
            example: A batch as passed by ``Dataset.map(batched=True)``; its
                ``'input'`` and ``'output'`` entries hold the source and
                target strings.

        Returns:
            The tokenizer output for the inputs, extended with ``labels``,
            ``decoder_input_ids`` and ``decoder_attention_mask``.
        """
        max_length = self.config.max_length

        # produce input_ids and attention_mask for the encoder
        model_inputs = tokenizer(
            example['input'],
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="np",
        )
        # text_target tokenizes the strings as targets, so no separate
        # labels tokenizer call is needed.
        labels = tokenizer(
            text_target=example['output'],
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="np",
        )

        # used for loss computation
        model_inputs["labels"] = labels["input_ids"]
        # required by the model in place of `labels`
        model_inputs["decoder_input_ids"] = np.asarray(
            shift_tokens_right_fn(
                labels["input_ids"],
                self.config.pad_token_id,
                self.config.decoder_start_token_id,
            )
        )
        # lets the loss ignore pad tokens in the targets
        model_inputs["decoder_attention_mask"] = labels["attention_mask"]

        return model_inputs

    def make_dataset(self):
        """Tokenize ``raw_dataset`` into ``self.train_dataset`` (numpy format)."""
        train_dataset = self.raw_dataset.map(
            self.preprocess_function,
            batched=True,
            num_proc=1,
            # drop the raw text columns; otherwise they are kept alongside the ids
            remove_columns=self.raw_dataset.column_names,
        )

        train_dataset.set_format(
            type='numpy',
            columns=[
                'input_ids', 'attention_mask',
                'labels', 'decoder_input_ids',
                'decoder_attention_mask',
            ],
        )

        # NOTE: casting these columns to Sequence(Value('uint16')) to save
        # memory was tried (after checking every value is in [0, 65535]) and
        # is currently disabled.

        self.train_dataset = train_dataset

    def data_loader(self, rng: jax.random.PRNGKey, batch_size: int, shuffle: bool = False, drop_last: bool = True):
        """Yield batches of ``batch_size`` rows from ``self.train_dataset``.

        Args:
            rng: JAX PRNG key; only used when ``shuffle`` is true.
            batch_size: Rows per batch; must be positive.
            shuffle: Visit rows in a random permutation drawn from ``rng``.
            drop_last: If true, a trailing partial batch is skipped, so every
                batch has exactly ``batch_size`` rows. If false, every batch
                is full except possibly the last, which holds the remaining
                1 to ``batch_size`` rows.

        Yields:
            A dict mapping column name to the selected rows.

        Raises:
            ValueError: If ``batch_size`` is not positive.
            RuntimeError: If the tokenized dataset has not been built.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if self.train_dataset is None:
            raise RuntimeError("train_dataset is not built; call make_dataset() first")
        dataset = self.train_dataset
        num_rows = len(dataset)

        if shuffle:
            order = np.asarray(jax.random.permutation(rng, num_rows))
        else:
            order = np.arange(num_rows)

        # Fixed-size slicing (rather than np.array_split into ceil(n / bs)
        # sections) keeps every batch full except possibly the final one.
        if drop_last:
            stop = (num_rows // batch_size) * batch_size  # skip incomplete batch
        else:
            stop = num_rows

        for start in range(0, stop, batch_size):
            idx = order[start:start + batch_size]
            yield dict(dataset[idx])
|
|
|
|
|
|
# testing out the class
|
|
# %%
|
|
# init object
|
|
# e.g. Config
|
|
# data_config = ConfigDict(
|
|
# dict(
|
|
# max_length=86,
|
|
# pad_token_id=0,
|
|
# decoder_start_token_id=0
|
|
# )
|
|
# )
|
|
#
|
|
# from datasets import load_from_disk
|
|
# split_datasets = load_from_disk(file_path)
|
|
# dataprep = DataPrepare(split_datasets['train'], data_config)
|
|
#
|
|
# # %%
|
|
# seed = 117
|
|
# rng = jax.random.PRNGKey(seed)
|
|
# train_loader = dataprep.data_loader(rng, batch_size=32)
|
|
#
|
|
#
|
|
#
|
|
# # %%
|
|
# batch = next(iter(train_loader))
|
|
# batch['input_ids'].shape
|
|
# # %%
|