# learn_jax/parallel/dataload.py
#
# 173 lines
# 6.2 KiB
# Python

# %%
# Prepare dataloader for jax training
from datasets import Dataset, DatasetDict, Value, Sequence, load_from_disk
from transformers import FlaxT5ForConditionalGeneration
from datasets import ClassLabel, Value, Sequence
from ml_collections import ConfigDict
import numpy as np
import jax.numpy as jnp
import jax
import math
from typing import Optional, List, Tuple, Callable, cast
# Location of the preprocessed dataset on disk (machine-specific absolute path;
# only referenced by the commented-out loading code below).
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
# file_path = 'combined_data'
# split_datasets = load_from_disk(file_path)
# training_size = len(split_datasets['train'])
from transformers import T5TokenizerFast
# NOTE(review): `return_tensors` here is presumably stored as an init kwarg and
# has no effect on later calls; DataPrepare.preprocess_function passes
# return_tensors="np" explicitly on every call anyway.
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
# Define additional special tokens.
# NOTE(review): "SIG", "UNIT" and "DATA_TYPE" lack the <...> delimiters that the
# other markers use. Confirm this is intentional before changing them: renaming
# would change how existing data is tokenized.
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
# Add the additional special tokens to the tokenizer.
# NOTE(review): the model's embeddings are not resized here; confirm the new
# token ids stay within the model's vocabulary size.
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
# Flax seq2seq models take `decoder_input_ids` instead of `labels`, so fetch
# `shift_tokens_right` from the module that defines the loaded model class.
# A non-empty `fromlist` makes __import__ return that leaf module rather than
# the top-level `transformers` package.
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")  # noqa: B009
# class takes in a dataset
class DataPrepare():
    """Tokenize a raw seq2seq dataset and serve it as fixed-length JAX batches.

    The raw dataset must provide ``'input'`` and ``'output'`` text columns
    (read by ``preprocess_function``). ``config`` must provide ``max_length``,
    ``pad_token_id`` and ``decoder_start_token_id``. Tokenization runs eagerly
    in ``__init__`` via ``make_dataset``.
    """

    # Model-facing columns produced by ``preprocess_function``, mapped to the
    # compact dtype each one is cast to in ``make_dataset``.
    _COLUMN_DTYPES = {
        'input_ids': 'uint16',
        'attention_mask': 'bool',
        'labels': 'uint16',
        'decoder_input_ids': 'uint16',
        'decoder_attention_mask': 'bool',
    }
    # Largest value representable by uint16. Columns are range-checked against
    # it so that an out-of-range id is reported clearly before the narrowing cast.
    _UINT16_MAX = 65535

    def __init__(self, raw_dataset, config):
        self.raw_dataset: Dataset = raw_dataset
        self.train_dataset: Optional[Dataset] = None
        self.size: int = len(raw_dataset)
        self.config: ConfigDict = config
        self.make_dataset()

    def preprocess_function(self, example):
        """Tokenize one batch of examples for a Flax seq2seq model.

        Used with ``Dataset.map(batched=True)``, so ``example`` is a batch
        mapping column names to per-example values. It must contain ``'input'``
        (source text) and ``'output'`` (target text).

        Flax seq2seq models do not accept ``labels``, so ``decoder_input_ids``
        are built here by shifting the tokenized targets right with
        ``shift_tokens_right_fn``. Everything is padded/truncated to
        ``config.max_length`` because jitted functions need fixed shapes.

        Returns:
            The tokenizer output for the inputs (``input_ids``,
            ``attention_mask``) plus ``labels``, ``decoder_input_ids`` and
            ``decoder_attention_mask`` derived from the targets.
        """
        inputs = example['input']
        targets = example['output']
        model_inputs = tokenizer(
            inputs,
            max_length=self.config.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="np"
        )
        # text_target tokenizes the targets as labels (no separate call needed)
        labels = tokenizer(
            text_target=targets,
            max_length=self.config.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="np"
        )
        # Target ids are kept for the loss computation.
        model_inputs["labels"] = labels["input_ids"]
        decoder_input_ids = shift_tokens_right_fn(
            labels["input_ids"], self.config.pad_token_id, self.config.decoder_start_token_id
        )
        # Required by the Flax model in place of `labels`.
        model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
        # Lets the loss ignore target pad positions.
        model_inputs["decoder_attention_mask"] = labels["attention_mask"]
        return model_inputs

    def make_dataset(self):
        """Tokenize ``raw_dataset`` and store the compact result in ``train_dataset``.

        The original columns are dropped and the model columns are exposed in
        numpy format. Each column is checked to fit in uint16, then the columns
        are cast to uint16 (ids) or bool (masks) to save memory.

        Raises:
            ValueError: if any value in a model column is outside [0, 65535].
        """
        columns = list(self._COLUMN_DTYPES)
        train_dataset = self.raw_dataset.map(
            self.preprocess_function,
            batched=True,
            num_proc=1,
            # drop the original columns; otherwise they are kept alongside
            remove_columns=self.raw_dataset.column_names,
        )
        train_dataset.set_format(type='numpy', columns=columns)
        # Validate every column before the narrowing cast below.
        for name in columns:
            values: np.ndarray = train_dataset[name]
            if not np.all((values >= 0) & (values <= self._UINT16_MAX)):
                raise ValueError(f"Values are out of range for uint16 in column '{name}'")
        features = train_dataset.features.copy()
        for name, dtype in self._COLUMN_DTYPES.items():
            features[name] = Sequence(Value(dtype))
        self.train_dataset = train_dataset.cast(features)

    def data_loader(self, rng: jax.random.PRNGKey, batch_size: int, shuffle: bool = False, drop_last=True):
        """Yield batches of ``batch_size`` rows from ``train_dataset``.

        Every batch has exactly ``batch_size`` rows, with one exception: when
        ``drop_last`` is False, the final batch holds the remaining
        1..``batch_size`` rows. When ``drop_last`` is True, the trailing
        incomplete batch is skipped. Rows are permuted with ``rng`` when
        ``shuffle`` is True.

        Args:
            rng: JAX PRNG key; only used when ``shuffle`` is True.
            batch_size: rows per batch; must be at least 1.
            shuffle: whether to permute rows before batching.
            drop_last: whether to skip the final incomplete batch.

        Yields:
            dict mapping each column name to a ``jnp`` array of the batch rows.

        Raises:
            ValueError: if ``batch_size`` is less than 1.
            RuntimeError: if ``train_dataset`` has not been built.
            (Both are raised on first iteration, since this is a generator.)
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        dataset = self.train_dataset
        if dataset is None:
            raise RuntimeError("train_dataset is not prepared; call make_dataset() first")
        num_rows = len(dataset)
        if shuffle:
            order = np.asarray(jax.random.permutation(rng, num_rows))
        else:
            order = np.arange(num_rows)
        if drop_last:
            # Round down to a whole number of batches to skip the incomplete one.
            num_rows = (num_rows // batch_size) * batch_size
        # Slice at multiples of batch_size. np.array_split(order, ceil(n / bs))
        # would spread rows evenly instead (e.g. 10 rows, bs 4 -> 4, 3, 3) and
        # fails outright on an empty dataset.
        for start in range(0, num_rows, batch_size):
            idx = order[start:start + batch_size]
            batch = dataset[idx]
            yield {k: jnp.array(v) for k, v in batch.items()}
# testing out the class
# # %%
# # init object
# # e.g. Config
# data_config = ConfigDict(
# dict(
# max_length=86,
# pad_token_id=0,
# decoder_start_token_id=0
# )
# )
#
# dataprep = DataPrepare(split_datasets, data_config)
#
# # %%
# seed = 117
# rng = jax.random.PRNGKey(seed)
# train_loader = dataprep.data_loader(rng, batch_size=32)
#
#
#
# # %%
# batch = next(iter(train_loader))
# batch['input_ids'].shape
# # %%