diff --git a/.gitignore b/.gitignore
index d39181d..4d15f00 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,3 +4,4 @@ exports/
 traces/
 ruff.toml
 settings.json
+__pycache__/
diff --git a/parallel/dataload.py b/parallel/dataload.py
index 9de9632..f80adac 100644
--- a/parallel/dataload.py
+++ b/parallel/dataload.py
@@ -11,7 +11,6 @@ import math
 
 from typing import Optional, List, Tuple, Callable, cast
 
-file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
 # file_path = 'combined_data'
 # split_datasets = load_from_disk(file_path)
 # training_size = len(split_datasets['train'])
@@ -124,6 +123,25 @@ class DataPrepare():
         # assign the dataset to train_dataset
         self.train_dataset = train_dataset
 
+    # Pad an incomplete final batch up to `target_size` rows
+    def _pad_to_batch_size(self, batch, target_size):
+        # Get the current batch size
+        input_ids = batch['input_ids']
+        current_size = input_ids.shape[0]
+        if current_size < target_size:
+            # Calculate how much padding is needed
+            padding_size = target_size - current_size
+            # Create zero padding rows (zero is also the pad token id)
+            padding = jnp.zeros((padding_size, input_ids.shape[1]), dtype=jnp.int32) # assumes all arrays are 2D with the same sequence length
+            # Concatenate to create a full batch
+            # repeat for all arrays in the tree
+            padded_batch = jax.tree.map(lambda array: jnp.concatenate([array, padding], axis=0, dtype=jnp.int32), batch)
+            # padded_batch = jnp.concatenate([batch, padding], axis=0)
+
+        else:
+            padded_batch = batch
+        return padded_batch
+
     def data_loader(self, rng: jax.random.PRNGKey, batch_size: int, shuffle: bool = False, drop_last=True):
         """
         Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
@@ -148,7 +166,8 @@ class DataPrepare():
 
         for idx in batch_idx:
             batch = dataset[idx]
-            batch = {k: np.array(v) for k, v in batch.items()}
+            batch = {k: jnp.array(v, dtype=jnp.int32) for k, v in batch.items()}
+            batch = self._pad_to_batch_size(batch, batch_size)
 
             yield batch
 
@@ -157,6 +176,8 @@ class DataPrepare():
 # # %%
 # # init object
 # # e.g. Config
+#
+# file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_desc'
 # data_config = ConfigDict(
 #     dict(
 #         max_length=86,
@@ -190,3 +211,5 @@ class DataPrepare():
 #
 #
 # # %%
+#
+# %%
diff --git a/t5_jax_parallel.py b/t5_jax_parallel.py
deleted file mode 100644
index fb92073..0000000
--- a/t5_jax_parallel.py
+++ /dev/null
@@ -1,697 +0,0 @@
-# %%
-import os
-# Set this to True to run the model on CPU only.
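# Illustrative usage sketch for the _pad_to_batch_size change in
# parallel/dataload.py above. It assumes the DataPrepare class shown there and a
# dataset split loaded as in its commented example (split_datasets['train']);
# with padding applied, every yielded batch keeps a fixed leading dimension, so
# a jit-compiled train step is not retraced on a short final batch when
# drop_last=False.
import jax
from ml_collections import ConfigDict

data_config = ConfigDict(dict(max_length=86, pad_token_id=0, decoder_start_token_id=0))
dataprep = DataPrepare(split_datasets['train'], data_config)
loader = dataprep.data_loader(jax.random.PRNGKey(117), batch_size=32, drop_last=False)
for batch in loader:
    # every array now has a leading dimension of 32; padded rows are all zeros
    assert all(v.shape[0] == 32 for v in batch.values())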
-USE_CPU_ONLY = False - -flags = os.environ.get("XLA_FLAGS", "") -if USE_CPU_ONLY: - flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices - # Enforce CPU-only execution - os.environ["CUDA_VISIBLE_DEVICES"] = "" - os.environ["JAX_PLATFORMS"] = "cpu" -else: - # GPU flags - flags = ( - '--xla_gpu_enable_triton_softmax_fusion=true ' - '--xla_gpu_triton_gemm_any=True ' - # '--xla_gpu_enable_async_collectives=true ' - '--xla_gpu_enable_latency_hiding_scheduler=true ' - '--xla_gpu_enable_highest_priority_async_stream=true ' - ) - os.environ["CUDA_VISIBLE_DEVICES"] = "0" - -os.environ["XLA_FLAGS"] = flags -os.environ.update({ - "TOKENIZERS_PARALLELISM" : "false", - "CUDA_DEVICE_MAX_CONNECTIONS" : "1", - "NCCL_LL128_BUFFSIZE": "-2", - "NCCL_LL_BUFFSIZE": "-2", - "NCCL_PROTO": "SIMPLE,LL,LL128", - "XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.90", - # "XLA_PYTHON_CLIENT_PREALLOCATE" : "false" - }) - - - - -import functools -from functools import partial -from pprint import pprint -from typing import Any, Dict, Tuple, Callable, Sequence, Dict, Union - -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -from jax.experimental.shard_map import shard_map -from jax.sharding import Mesh, NamedSharding -# from jax.experimental.pjit import pjit # superseded by jax.jit -from jax.experimental import mesh_utils -from jax.sharding import PartitionSpec -from ml_collections import ConfigDict -import optax -import logging -import time -from datasets import Dataset, load_from_disk - -from flax import jax_utils, traverse_util -from flax.jax_utils import pad_shard_unpad, unreplicate -from flax.training import train_state -from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key -from flax.core.frozen_dict import freeze, unfreeze, FrozenDict -import flax.core - -# model checkpointing and saving utilities -from flax import linen as nn -from flax.training import checkpoints, train_state -from flax import struct, serialization -import orbax.checkpoint as ocp -from flax.training import orbax_utils - -from parallel.partitions import set_partitions - -from tqdm import tqdm - -from parallel.dataload import DataPrepare - - -PyTree = Any -Metrics = Dict[str, Tuple[jax.Array, ...]] - -if USE_CPU_ONLY: - jax.config.update('jax_platform_name', 'cpu') -else: - jax.config.update("jax_default_matmul_precision", "bfloat16") - -jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") -jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) -jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) - - - - -# %% -## get platform type -from jax.extend.backend import get_backend -print(get_backend().platform) -print(jax.devices()) - -# %% -# config options -file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/' -save_path = '/home/richard/Projects/06_research/jax_models/t5_80e_fp32_parallel/' -# file_path = 'combined_data' -split_datasets = load_from_disk(file_path) -training_size = len(split_datasets['train']) -# Store some constant -seed = 117 -num_epochs = 5 -batch_size = 32 -num_train_epochs = num_epochs -per_device_train_batch_size = batch_size -train_batch_size = per_device_train_batch_size * jax.device_count() -per_device_eval_batch_size = batch_size -eval_batch_size = per_device_eval_batch_size * jax.device_count() -steps_per_epoch = training_size // train_batch_size -total_train_steps = steps_per_epoch * num_epochs - -warmup_steps = 0 -learning_rate = 5e-5 - -weight_decay = 0.01 -adam_beta1 = 0.9 
-adam_beta2 = 0.999 -adam_epsilon = 1e-8 -label_smoothing_factor = 0.0 -num_beams = 1 -val_max_target_length = 128 -predict_with_generate = True - - -# %% -# prepare data -# init object -# e.g. Config -print("preparing data") -data_config = ConfigDict( - dict( - max_length=128, - pad_token_id=0, - decoder_start_token_id=0 - ) -) - -dataprep = DataPrepare(split_datasets['train'], data_config) -# # example usage -# # %% -seed = 117 -rng = jax.random.PRNGKey(seed) -train_loader = dataprep.data_loader(rng, batch_size=batch_size) -batch = next(iter(train_loader)) -# batch - -# %% -# model - -# working -# from parallel.t5_model.pure_t5 import FlaxT5ForConditionalGenerationModule as model_init -# # from t5_model.pure_t5 import FlaxT5DenseActDense as model_init -# from parallel.t5_model.pure_t5 import make_config -# config = make_config() -# model = model_init(config=config, dtype=jnp.bfloat16, gradient_checkpointing=True) - - -# %% -# from transformers import FlaxT5ForConditionalGeneration, T5Config -# model = FlaxT5ForConditionalGeneration.from_pretrained( -# "t5-base", -# dtype=jnp.bfloat16, -# ) -# # pretrained_params = model.params -# model = model.module - -# %% -# from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom -from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model -main_model = custom_model.from_pretrained( - "t5-base", - dtype=jnp.float32, - # gradient_checkpointing=True, -) -params = main_model.params -# pretrained_params = model.params -model = main_model.module - -# %% -# # testing config hashability -# # some explanation: -# # The PreTrainedModel class loads a T5Config model that is not hashable because -# # it is a complicated class that pretends to be a dataclass. -# # The solution is to extract a dict from it, then make a ConfigDict from -# # ml_collections library so that we can get values via the "." operator. 
-# # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to -# # modify the config before passing to the next layer -# from transformers import T5Config -# from t5_model.configuration_t5 import FrozenT5Config -# from ml_collections import ConfigDict, FrozenConfigDict -# -# config = T5Config.from_pretrained("t5-base").to_dict() -# config.pop('architectures') -# config.pop('id2label') -# # test if it works -# frozen_config = FrozenConfigDict(config) -# # test hash -# hash(frozen_config) - -# %% - -# %% -# # print model -# rng, input_rng = jax.random.split(rng) -# model.tabulate( -# input_rng, -# input_ids=batch['input_ids'], -# attention_mask=batch['attention_mask'], -# decoder_input_ids=batch['decoder_input_ids'], -# decoder_attention_mask=batch['decoder_attention_mask'], -# console_kwargs={"force_jupyter": True} -# ) - -# %% -# print model datatype to verify -# rng, input_rng = jax.random.split(rng) -# variables = model.init( -# input_rng, -# input_ids=batch['input_ids'], -# attention_mask=batch['attention_mask'], -# decoder_input_ids=batch['decoder_input_ids'], -# decoder_attention_mask=batch['decoder_attention_mask'] -# ) - - - - - - -# %% -# create mesh -print("creating mesh") -device_mesh = mesh_utils.create_device_mesh((1,1)) -print(device_mesh) - -mesh = Mesh(devices=device_mesh, axis_names=('data', 'model')) -print(mesh) - -def mesh_sharding(pspec: PartitionSpec) -> NamedSharding: - return NamedSharding(mesh, pspec, memory_kind="device") - -x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis -model_sharding=mesh_sharding(PartitionSpec(None, 'model')) - - -# %% -# optimizers - -def create_learning_rate_fn( - train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.ndarray]: - """Returns a linear warmup, linear_decay learning rate function.""" - steps_per_epoch = train_ds_size // train_batch_size - num_train_steps = steps_per_epoch * num_train_epochs - warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) - decay_fn = optax.linear_schedule( - init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps - ) - schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) - return schedule_fn - - -# Create learning rate schedule -linear_decay_lr_schedule_fn = create_learning_rate_fn( - training_size, - train_batch_size, - num_train_epochs, - warmup_steps, - learning_rate, -) - -# We use Optax's "masking" functionality to not apply weight decay -# to bias and LayerNorm scale parameters. decay_mask_fn returns a -# mask boolean with the same structure as the parameters. -# The mask is True for parameters that should be decayed. 
-def decay_mask_fn(params): - flat_params = traverse_util.flatten_dict(params) - # find out all LayerNorm parameters - layer_norm_candidates = ["layernorm", "layer_norm", "ln"] - layer_norm_named_params = { - layer[-2:] - for layer_norm_name in layer_norm_candidates - for layer in flat_params.keys() - if layer_norm_name in "".join(layer).lower() - } - flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} - return traverse_util.unflatten_dict(flat_mask) - -# create adam optimizer -adamw = optax.adamw( - learning_rate=linear_decay_lr_schedule_fn, - b1=adam_beta1, - b2=adam_beta2, - eps=adam_epsilon, - weight_decay=weight_decay, - mask=decay_mask_fn, -) - - -print("compile") - - -# enable bf16 except for layer_norm -def create_mask_for_layer_norm(params): - flat_params = traverse_util.flatten_dict(params) - mask = { - path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params - } - mask = traverse_util.unflatten_dict(mask) - return mask - -# borrowed from transformers modeling_flax_utils -def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: - """ - Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. - """ - - # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 - def conditional_cast(param): - if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): - param = param.astype(dtype) - return param - - if mask is None: - return jax.tree_util.tree_map(conditional_cast, params) - - flat_params = traverse_util.flatten_dict(params) - flat_mask, _ = jax.tree_util.tree_flatten(mask) - - for masked, key in zip(flat_mask, sorted(flat_params.keys())): - if masked: - flat_params[key] = conditional_cast(flat_params[key]) - - return traverse_util.unflatten_dict(flat_params) - -# Cast all parameters to bfloat16 if desired -# params = jax.tree.tree_map(lambda x: x.astype(jnp.bfloat16), params) - -# %% -def init_fn(params, model, optimizer): - # do be careful with the model init - # imported models might have complicated init methods - # mask = create_mask_for_layer_norm(params) - # override params with bfloat version - # params= cast_floating_to(params, jnp.bfloat16, mask) - - state = train_state.TrainState.create( # Create a `TrainState`. - apply_fn=model.apply, - params=params, - tx=optimizer) - return state - - -# def init_fn(rng, batch, model, optimizer): -# # do be careful with the model init -# # imported models might have complicated init methods -# variables = model.init( -# rng, -# input_ids=batch['input_ids'], -# attention_mask=batch['attention_mask'], -# decoder_input_ids=batch['decoder_input_ids'], -# decoder_attention_mask=batch['decoder_attention_mask'] -# ) -# params = variables['params'] -# mask = create_mask_for_layer_norm(params) -# # override params with bfloat version -# params= cast_floating_to(params, jnp.bfloat16, mask) -# -# state = train_state.TrainState.create( # Create a `TrainState`. -# apply_fn=model.apply, -# params=params, -# tx=optimizer) -# return state - - - -# %% -# Create an abstract closure to wrap the function before feeding it in -# because `jax.eval_shape` only takes pytrees as arguments. 
-# eval_shape(fn, rng_key, x) -# used to perform shape inference -# returns a nested PyTree containing jax.ShapeDtypeStruct objects as leaves -# rng, init_rng = jax.random.split(rng) -abstract_variables = jax.eval_shape( - functools.partial(init_fn, model=model, optimizer=adamw), params) - -# rng, init_rng = jax.random.split(rng) -# abstract_variables = jax.eval_shape( -# functools.partial(init_fn, model=model, optimizer=adamw), init_rng, batch) - - -# %% -# This `state_sharding` has the same pytree structure as `state`, the output -# of the `init_fn`. -# flan.linen.get_sharding -# extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh -# jax.sharding: describes how a jax.Array is laid out across devices -state_sharding = nn.get_sharding(abstract_variables, mesh) -# print(state_sharding) - -# warning: do not have singleton None in your nn.partition definitions, it will screw with your sanity - -################################################## -# # %% -# # replace the params tree with the new modified tree -# # create partitions for model -# from parallel.partitions import set_partitions -# # set_partitions freezes the params on return -# model_part_spec = set_partitions(unfreeze(params)) -# # p is already a partition spec -# model_named_sharding = jax.tree.map(lambda p: mesh_sharding(p), model_part_spec) -# -# # %% -# # get_shapes = jax.tree.map(jnp.shape, params) -# # actually tuple -# # state_shapes = jax.eval_shape(state_sharding, get_shapes) -# -# # %% -# # get pspec for opt_state -# def get_opt_spec(x): -# if isinstance(x, dict): -# return unfreeze(model_named_sharding) -# # return an empty partspec -# return mesh_sharding((PartitionSpec())) -# -# # this function replaces the empty model params spec with the 'model_named_shard' -# state_sharding = jax.tree.map( -# get_opt_spec, state_sharding, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)) -# ) - - - -# %% -jit_init_fn = jax.jit( - init_fn, - static_argnames=('model', 'optimizer'), # skip model and optimizer - in_shardings=mesh_sharding(PartitionSpec(())), # we don't shard params explicitly - out_shardings=state_sharding # but returned initialized_state is sharded -) -initialized_state = jit_init_fn(params, model, adamw) - - -# jit_init_fn = jax.jit( -# init_fn, -# static_argnames=('model', 'optimizer'), # skip model and optimizer -# in_shardings=(mesh_sharding(()), x_sharding), # for PRNG key and data -# out_shardings=state_sharding -# ) -# -# -# rng, init_rng = jax.random.split(rng) -# initialized_state = jit_init_fn(rng, batch, model, adamw) - - -# %% -# train step -def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): - """ - The label smoothing implementation is adapted from Flax's official example: - https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104 - """ - vocab_size = logits.shape[-1] - confidence = 1.0 - label_smoothing_factor - low_confidence = (1.0 - confidence) / (vocab_size - 1) - normalizing_constant = -( - confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) - ) - soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence) - - loss = optax.softmax_cross_entropy(logits, soft_labels) - loss = loss - normalizing_constant - - # ignore padded tokens from loss - loss = loss * padding_mask - loss = loss.sum() - num_labels = padding_mask.sum() - return loss, num_labels - -# %% - -# sharded_loss_fn = jax.jit( -# loss_fn, -# 
in_shardings=(mesh_sharding('model'), x_sharding), # params partitioned across 'model' axis -# out_shardings=(mesh_sharding('model')), # Loss should be aggregated across 'model' -# ) - -def gather_and_sum( - sharded_values, - in_shardings -): - with mesh: - # Gather sharded values into a single device - gathered_values = jax.jit( - lambda x: x, in_shardings=in_shardings, out_shardings=None - )(sharded_values) - - # Compute the sum of gathered values - summed_value = jax.tree.map(lambda x: jnp.sum(x), gathered_values) - return summed_value - - -# single device code annotated with jax.jit -@functools.partial( - jax.jit, - # state is state_sharding initialized from init_fn - # x_sharding is data sharded explicitly later - in_shardings=(state_sharding, x_sharding), - # return state as state_sharding - # we do not shard the metrics - out_shardings=(state_sharding, mesh_sharding(PartitionSpec())), - donate_argnames=('state'), -) -def train_step(state, batch): - - label_smoothing_factor=0.0 - # dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) - def compute_loss(params, batch): - # check constraints - # frozen dict not allowed as sharding object - # params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding)) - # batch = jax.lax.with_sharding_constraint(batch, x_sharding) - # labels = batch.pop("decoder_input_ids") - # no use of labels here - logits = state.apply_fn( - {'params': params}, - input_ids=batch['input_ids'], - attention_mask=batch['attention_mask'], - decoder_input_ids=batch['decoder_input_ids'], - decoder_attention_mask=batch['decoder_attention_mask'], - )[0] # zero because output is some structure, where first is the logit - # use labels here - loss, num_labels = loss_fn( - logits, - batch["labels"], - batch["decoder_attention_mask"], - label_smoothing_factor) - return loss, num_labels - - # compute gradients through computational graph - # allow values to pass through - grad_fn = jax.value_and_grad(compute_loss, has_aux=True) - (loss, num_labels), grad = grad_fn(state.params, batch) - # num_labels = jax.lax.psum(num_labels, "batch") - - - # true grad = total grad / total samples - # needs to be in a singleton tuple for some reason - # gathered_grad = gather_and_sum(grad, (unfreeze(model_named_sharding),)) - - # gathered_num_labels = gather_and_sum(num_labels, mesh_sharding(PartitionSpec())) - - # summed_gradients = jax.tree.map(lambda x: jnp.sum(x)/gathered_num_labels, gathered_grad) - # grad = jax.lax.psum(grad, "batch") - # grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad) - - new_state = state.apply_gradients(grads=grad) - with jax.named_scope("sync_metrics"): - step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} - # step_metrics = jax.tree.map( - # # previously needed lax.psum - # # now just write single device code, let compiler handle - # lambda x: jnp.mean(x), step_metrics - # ) - - # if metrics is None: - # metrics = step_metrics - # else: - # # combine all the synced metrics - # metrics = jax.tree.map(jnp.mean, metrics, step_metrics) - - - return new_state, step_metrics - - - - -# %% -# prep 1 step -print("1 step for jit-ting") - - -with mesh: - state, metrics = train_step(initialized_state, batch) - - -# %% - -# %% -# tr -print("***** Running training *****") -print(f" Num examples = {training_size}") -print(f" Num Epochs = {num_epochs}") -print(f" Instantaneous batch size per device = {per_device_train_batch_size}") -print(f" Total train batch size (w. 
parallel & distributed) = {train_batch_size}") -print(f" Total optimization steps = {total_train_steps}") - - -# %% -# jax.profiler.start_trace("./traces") - -# function to shard a batch by treating it as a pytree -def shard_batch(batch): - # Shard each element in the dictionary (i.e., each key-value pair) - return jax.tree_util.tree_map( - lambda x: jax.device_put(x, x_sharding), - batch - ) - - -print("*" * 10) -print("training start") -rng, input_rng = jax.random.split(rng) -train_time = 0 -epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) -for epoch in epochs: - train_start = time.time() - - # Create sampling rng - train_metrics = [] - steps_per_epoch = training_size // train_batch_size - train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True) - # Generate an epoch by shuffling sampling indices from the train dataset - for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): - batch = next(train_loader) - # send to device - # batch = {key: jax.device_put(jnp.array(value, dtype=jnp.uint16), x_sharding) for key, value in batch.items()} - # batch['input_ids']=jax.device_put(jnp.array(batch['input_ids'], dtype=jnp.int32), x_sharding) - # batch['attention_mask']=jax.device_put(jnp.array(batch['attention_mask'], dtype=jnp.int32), x_sharding) - # batch['decoder_input_ids']=jax.device_put(jnp.array(batch['decoder_input_ids'], dtype=jnp.int32), x_sharding) - # batch['decoder_attention_mask']=jax.device_put(jnp.array(batch['decoder_attention_mask'], dtype=jnp.int32), x_sharding) - sharded_batch = shard_batch(batch) - with mesh: - state, train_metric = train_step(state, sharded_batch) - - # train_metrics.append(train_metric) - - - # this is for more accurate time stats, but slows down training - # train_metric['loss'].block_until_ready() - train_time = time.time() - train_start - - - - epochs.write( - f"Epoch... ({epoch + 1}/{num_epochs} | " - f"Loss: {train_metric['loss']}, " - f"Learning Rate:{train_metric['learning_rate']}, " - f"Last train time: {train_time})" - ) -# jax.profiler.stop_trace() -# %% -# with mesh: -# gathered_params = jax.jit( -# lambda x: x, -# in_shardings=(unfreeze(model_named_sharding),), -# out_shardings=mesh_sharding(PartitionSpec()) -# )(state.params) - -main_model = custom_model.from_pretrained('t5-base') -output_dir = save_path -# save checkpoint after each epoch and push checkpoint to the hub -if jax.process_index() == 0: - params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params) - main_model.save_pretrained(output_dir, params=params) - -# # stick to defaults -# options = ocp.CheckpointManagerOptions() -# with ocp.CheckpointManager( -# ocp.test_utils.erase_and_create_empty(save_path), -# options=options, -# ) as mngr: -# -# mngr.save(0, args=ocp.args.StandardSave(state)) -# mngr.wait_until_finished() - - # After providing `args` during an initial `save` or `restore` call, the - # `CheckpointManager` instance records the type so that you do not need to - # specify it again. If the `CheckpointManager` instance is not provided with a - # `ocp.args.CheckpointArgs` instance for a particular item on a previous - # occasion it cannot be restored without specifying the argument at restore - # time. - - # # In many cases, you can restore exactly as saved without specifying additional - # # arguments. 
- # mngr.restore(0) - # # If customization of properties like sharding or dtype is desired, just provide - # # the abstract target PyTree, the properties of which will be used to set - # # the properties of the restored arrays. - # mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree)) - -# %% diff --git a/t5_jax_prediction.py b/t5_jax_prediction.py index f8dc1bd..9eeb63c 100644 --- a/t5_jax_prediction.py +++ b/t5_jax_prediction.py @@ -118,7 +118,7 @@ predict_with_generate = True # Initialize our prediction rng = jax.random.PRNGKey(seed) -rng, dropout_rng = jax.random.split(rng) +# rng, dropout_rng = jax.random.split(rng) print("preparing data") data_config = ConfigDict( @@ -130,11 +130,6 @@ data_config = ConfigDict( ) dataprep = DataPrepare(test_dataset, data_config) -# # example usage -# # %% -seed = 117 -rng = jax.random.PRNGKey(seed) - # %% # Ensure model.params is properly initialized (this is just an example) @@ -186,6 +181,8 @@ for _ in tqdm(range(pred_steps), desc="Predicting..."): # generation + # pad_shard_unpad is useful for calling a pmap’ed function with inputs that + # aren’t divisible by the number of devices. generated_ids = pad_shard_unpad(p_generate_step)(replicated_params, batch) pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"]))) pred_labels.extend(labels) diff --git a/t5_jax_sfp_grad_accumulate.py b/t5_jax_sfp_grad_accumulate.py deleted file mode 100644 index 52abe0a..0000000 --- a/t5_jax_sfp_grad_accumulate.py +++ /dev/null @@ -1,610 +0,0 @@ -# %% -import os -# Set this to True to run the model on CPU only. -USE_CPU_ONLY = False - -flags = os.environ.get("XLA_FLAGS", "") -if USE_CPU_ONLY: - flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices - # Enforce CPU-only execution - os.environ["CUDA_VISIBLE_DEVICES"] = "" - os.environ["JAX_PLATFORMS"] = "cpu" -else: - # GPU flags - flags = ( - '--xla_gpu_enable_triton_softmax_fusion=true ' - '--xla_gpu_triton_gemm_any=True ' - # '--xla_gpu_enable_async_collectives=true ' - '--xla_gpu_enable_latency_hiding_scheduler=true ' - '--xla_gpu_enable_highest_priority_async_stream=true ' - ) - os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" - -os.environ["XLA_FLAGS"] = flags -os.environ.update({ - "TOKENIZERS_PARALLELISM" : "false", - "CUDA_DEVICE_MAX_CONNECTIONS" : "1", - "NCCL_LL128_BUFFSIZE": "-2", - "NCCL_LL_BUFFSIZE": "-2", - "NCCL_PROTO": "SIMPLE,LL,LL128", - "XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.90", - # "XLA_PYTHON_CLIENT_PREALLOCATE" : "false" - }) - - - - -import functools -from functools import partial -from pprint import pprint -from typing import Any, Dict, Tuple, Callable, Sequence, Dict, Union - -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -from jax.experimental.shard_map import shard_map -from jax.sharding import Mesh, NamedSharding -# from jax.experimental.pjit import pjit # superseded by jax.jit -from jax.experimental import mesh_utils -from jax.sharding import PartitionSpec -from ml_collections import ConfigDict -import optax -import logging -import time -from datasets import Dataset, load_from_disk - -from flax import jax_utils, traverse_util -from flax.training import train_state -from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key -from flax.core.frozen_dict import freeze, unfreeze, FrozenDict -import flax.core - -# model checkpointing and saving utilities -from flax import linen as nn -from flax.training import checkpoints, train_state -from flax import struct, serialization - 
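# Illustrative sketch of the flax.jax_utils.pad_shard_unpad behaviour noted in
# the t5_jax_prediction.py hunk above: by default the wrapper treats the first
# argument (already-replicated params) as static, pads the second along its
# leading axis to a multiple of the local device count, shards it for the
# pmapped function, then un-shards and un-pads the result. The toy per-device
# function and names below are assumptions for demonstration only.
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate, pad_shard_unpad

p_shift = jax.pmap(lambda params, batch: batch + params['offset'])

toy_params = replicate({'offset': jnp.array(1, dtype=jnp.int32)})
toy_batch = jnp.zeros((5, 8), dtype=jnp.int32)          # 5 rows, generally not divisible by the device count
out = pad_shard_unpad(p_shift)(toy_params, toy_batch)   # out.shape == (5, 8); padded rows are removed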
-from parallel.partitions import set_partitions - -from tqdm import tqdm - -from parallel.dataload import DataPrepare - -# for memory tracking -# from jax_smi import initialise_tracking -# initialise_tracking() - - -PyTree = Any -Metrics = Dict[str, Tuple[jax.Array, ...]] - -if USE_CPU_ONLY: - jax.config.update('jax_platform_name', 'cpu') -else: - jax.config.update("jax_default_matmul_precision", "bfloat16") - -jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") -jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) -jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) - - - - -# %% -## get platform type -from jax.extend.backend import get_backend -print(get_backend().platform) -print(jax.devices()) - -# %% -# config options -file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/' -save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/simple_test/' -# file_path = 'combined_data' -split_datasets = load_from_disk(file_path) -training_size = len(split_datasets['train']) -# Store some constant -seed = 117 -num_epochs = 5 -batch_size = 64 -num_train_epochs = num_epochs -per_device_train_batch_size = batch_size -train_batch_size = per_device_train_batch_size * jax.device_count() -per_device_eval_batch_size = batch_size -eval_batch_size = per_device_eval_batch_size * jax.device_count() -steps_per_epoch = training_size // train_batch_size -total_train_steps = steps_per_epoch * num_epochs - -warmup_steps = 0 -learning_rate = 5e-5 - -weight_decay = 0.01 -adam_beta1 = 0.9 -adam_beta2 = 0.999 -adam_epsilon = 1e-8 -label_smoothing_factor = 0.0 -num_beams = 1 -val_max_target_length = 128 -predict_with_generate = True - - -# %% -# prepare data -# init object -# e.g. Config -print("preparing data") -data_config = ConfigDict( - dict( - max_length=128, - pad_token_id=0, - decoder_start_token_id=0 - ) -) - -dataprep = DataPrepare(split_datasets['train'], data_config) -# # example usage -# # %% -seed = 117 -rng = jax.random.PRNGKey(seed) -train_loader = dataprep.data_loader(rng, batch_size=batch_size) -batch = next(iter(train_loader)) - -# %% -# from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom -from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model -main_model = custom_model.from_pretrained( - "t5-base", - dtype=jnp.bfloat16, - gradient_checkpointing=True, -) -params = main_model.params -# pretrained_params = model.params -model = main_model.module - -# %% -# # testing config hashability -# # some explanation: -# # The PreTrainedModel class loads a T5Config model that is not hashable because -# # it is a complicated class that pretends to be a dataclass. -# # The solution is to extract a dict from it, then make a ConfigDict from -# # ml_collections library so that we can get values via the "." operator. 
-# # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to -# # modify the config before passing to the next layer -# from transformers import T5Config -# from t5_model.configuration_t5 import FrozenT5Config -# from ml_collections import ConfigDict, FrozenConfigDict -# -# config = T5Config.from_pretrained("t5-base").to_dict() -# config.pop('architectures') -# config.pop('id2label') -# # test if it works -# frozen_config = FrozenConfigDict(config) -# # test hash -# hash(frozen_config) - -# %% -# # print model -# rng, input_rng = jax.random.split(rng) -# model.tabulate( -# input_rng, -# input_ids=batch['input_ids'], -# attention_mask=batch['attention_mask'], -# decoder_input_ids=batch['decoder_input_ids'], -# decoder_attention_mask=batch['decoder_attention_mask'], -# console_kwargs={"force_jupyter": True} -# ) - - - -# %% -# create mesh -print("creating mesh") -device_mesh = mesh_utils.create_device_mesh((2,2)) -print(device_mesh) - -mesh = Mesh(devices=device_mesh, axis_names=('data', 'model')) -print(mesh) - -def mesh_sharding(pspec: PartitionSpec) -> NamedSharding: - return NamedSharding(mesh, pspec, memory_kind="device") - -x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis -model_sharding=mesh_sharding(PartitionSpec(None, 'model')) - - -# %% -# optimizers - -def create_learning_rate_fn( - train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.ndarray]: - """Returns a linear warmup, linear_decay learning rate function.""" - steps_per_epoch = train_ds_size // train_batch_size - num_train_steps = steps_per_epoch * num_train_epochs - warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) - decay_fn = optax.linear_schedule( - init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps - ) - schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) - return schedule_fn - - -# Create learning rate schedule -linear_decay_lr_schedule_fn = create_learning_rate_fn( - training_size, - train_batch_size, - num_train_epochs, - warmup_steps, - learning_rate, -) - -# We use Optax's "masking" functionality to not apply weight decay -# to bias and LayerNorm scale parameters. decay_mask_fn returns a -# mask boolean with the same structure as the parameters. -# The mask is True for parameters that should be decayed. 
-def decay_mask_fn(params): - flat_params = traverse_util.flatten_dict(params) - # find out all LayerNorm parameters - layer_norm_candidates = ["layernorm", "layer_norm", "ln"] - layer_norm_named_params = { - layer[-2:] - for layer_norm_name in layer_norm_candidates - for layer in flat_params.keys() - if layer_norm_name in "".join(layer).lower() - } - flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} - return traverse_util.unflatten_dict(flat_mask) - -# create adam optimizer -adamw = optax.adamw( - learning_rate=linear_decay_lr_schedule_fn, - b1=adam_beta1, - b2=adam_beta2, - eps=adam_epsilon, - weight_decay=weight_decay, - mask=decay_mask_fn, -) - -# %% - -print("compile") - -# enable bf16 -# enable only for dense, some transformer sections, and shared -def create_mask_for_layer_norm(params): - flat_params = traverse_util.flatten_dict(params) - mask = { - # path: not ( - # (path[-2] == "layer_norm" and path[-1] == "weight") or - # (path[-2] == "final_layer_norm" and path[-1] == "weight") or - # (path[-2] == "o" and path[-1] == "kernel") - # ) - # for path in flat_params - path: ( - (path[-2] == "wi" and path[-1] == "weight") or - (path[-2] == "wo" and path[-1] == "weight") or - (path[-2] == "k" and path[-1] == "kernel") or - (path[-2] == "q" and path[-1] == "kernel") or - (path[-2] == "v" and path[-1] == "kernel") or - (path[-2] == "shared" and path[-1] == "embedding") - ) for path in flat_params - } - mask = traverse_util.unflatten_dict(mask) - return mask - -# borrowed from transformers modeling_flax_utils -def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: - """ - Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. - """ - - # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 - def conditional_cast(param): - if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): - param = param.astype(dtype) - return param - - if mask is None: - return jax.tree_util.tree_map(conditional_cast, params) - - flat_params = traverse_util.flatten_dict(params) - flat_mask, _ = jax.tree_util.tree_flatten(mask) - - for masked, key in zip(flat_mask, sorted(flat_params.keys())): - if masked: - flat_params[key] = conditional_cast(flat_params[key]) - - return traverse_util.unflatten_dict(flat_params) - -# create init_fn to produce sharded state -def init_fn(params, model, optimizer): - # do be careful with the model init - # imported models might have complicated init methods - - # mask = create_mask_for_layer_norm(params) - # override params with bfloat version - # params= cast_floating_to(params, jnp.bfloat16, mask) - - state = train_state.TrainState.create( # Create a `TrainState`. 
- apply_fn=model.apply, - params=params, - tx=optimizer) - return state - - -abstract_variables = jax.eval_shape( - functools.partial(init_fn, model=model, optimizer=adamw), params) - -# jax.sharding: describes how a jax.Array is laid out across devices -state_sharding = nn.get_sharding(abstract_variables, mesh) -# print(state_sharding) - -# %% - -# replace the params tree with the new modified tree -# create partitions for model -from parallel.partitions import set_partitions -# set_partitions freezes the params on return -model_part_spec = set_partitions(unfreeze(params)) -# p is already a partition spec -model_named_sharding = jax.tree.map(lambda p: mesh_sharding(p), model_part_spec) - -# get pspec for opt_state -def get_opt_spec(x): - if isinstance(x, dict): - return unfreeze(model_named_sharding) - # return an empty partspec - return mesh_sharding((PartitionSpec())) - -# this function replaces the empty model params spec with the 'model_named_shard' -state_sharding = jax.tree.map( - get_opt_spec, state_sharding, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)) -) - - - -jit_init_fn = jax.jit( - init_fn, - static_argnames=('model', 'optimizer'), # skip model and optimizer - in_shardings=mesh_sharding(PartitionSpec()), # we don't shard params explicitly - out_shardings=state_sharding # but returned initialized_state is sharded -) -initialized_state = jit_init_fn(params, model, adamw) - -# %% -# train step -def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): - """ - The label smoothing implementation is adapted from Flax's official example: - https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104 - """ - vocab_size = logits.shape[-1] - confidence = 1.0 - label_smoothing_factor - low_confidence = (1.0 - confidence) / (vocab_size - 1) - normalizing_constant = -( - confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) - ) - soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence) - logits = jnp.asarray(logits, dtype=jnp.float32) - logits = logits.astype(jnp.float32) - soft_labels = soft_labels.astype(jnp.float32) - loss = optax.softmax_cross_entropy(logits, soft_labels) - loss = loss - normalizing_constant - - # ignore padded tokens from loss - loss = loss * padding_mask - loss = loss.mean() - # num_labels = padding_mask.mean() - return loss # , num_labels - -# %% -# gradient accumulation -def accumulate_gradients_loop( - state, - batch, - minibatch_size: int, - loss_fn: Callable, -) -> Tuple[PyTree, Metrics]: - """Calculate gradients and metrics for a batch using gradient accumulation. - - Args: - state: Current training state. - batch: Full training batch. - rng: Random number generator to use. - num_minibatches: Number of minibatches to split the batch into. Equal to the number of gradient accumulation steps. - loss_fn: Loss function to calculate gradients and metrics. - - Returns: - Tuple with accumulated gradients and metrics over the minibatches. - """ - batch_size = batch['input_ids'].shape[0] - # minibatch_size = batch_size // num_minibatches - num_minibatches = batch_size // minibatch_size - # Define gradient function for single minibatch. - # If has_aux is True then a tuple of ((value, auxiliary_data), gradient) is returned. 
- # otherwise it returns (value, gradient), where value is the actual output - # of the function, hence the "value" of the namesake - grad_fn = jax.value_and_grad(loss_fn, has_aux=False) - # Prepare loop variables. - grads = None - metrics = None - for minibatch_idx in range(num_minibatches): - with jax.named_scope(f"minibatch_{minibatch_idx}"): - # Split the batch into minibatches. - start = minibatch_idx * minibatch_size - end = start + minibatch_size - minibatch = jax.tree.map(lambda x: x[start:end], batch) # noqa: B023 - # Calculate gradients and metrics for the minibatch. - # missing value is mean loss of batch - loss, step_grads = grad_fn( - state.params, minibatch - ) - with jax.named_scope("sync_metrics"): - step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} - - # Accumulate gradients and metrics across minibatches. - if grads is None: - grads = step_grads - metrics = step_metrics - else: - # accumulation adder - grads = jax.tree.map(jnp.add, grads, step_grads) - metrics = jax.tree.map(jnp.add, metrics, step_metrics) - # Average gradients over minibatches. - grads = jax.tree.map(lambda g: g / num_minibatches, grads) - return grads, metrics - - - -# single device code annotated with jax.jit -@functools.partial( - jax.jit, - # state is state_sharding initialized from init_fn - # x_sharding is data sharded explicitly later - in_shardings=(state_sharding, x_sharding), - out_shardings=(state_sharding, mesh_sharding(PartitionSpec())), - donate_argnames=('state'), -) -def train_step(state, batch): - - label_smoothing_factor=0.0 - # dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) - def compute_loss(params, batch): - # check constraints - # frozen dict not allowed as sharding object - params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding)) - batch = jax.lax.with_sharding_constraint(batch, x_sharding) - logits = state.apply_fn( - {'params': params}, - input_ids=batch['input_ids'], - attention_mask=batch['attention_mask'], - decoder_input_ids=batch['decoder_input_ids'], - decoder_attention_mask=batch['decoder_attention_mask'], - )[0] # zero because output is some structure, where first is the logit - # use labels here - # loss, num_labels = loss_fn( - loss = loss_fn( - logits, - batch["labels"], - batch["decoder_attention_mask"], - label_smoothing_factor) - return loss # , num_labels - - # # compute gradients through computational graph - # # allow values to pass through - # grad_fn = jax.value_and_grad(compute_loss, has_aux=False) - # (loss), grad = grad_fn(state.params, batch) - # # num_labels = jax.lax.psum(num_labels, "batch") - - - # new_state = state.apply_gradients(grads=grad) - # with jax.named_scope("sync_metrics"): - # step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} - - # use gradient accumulation - grads, step_metrics = accumulate_gradients_loop( - state=state, - batch=batch, - minibatch_size=32, - loss_fn=compute_loss - ) - new_state = state.apply_gradients(grads=grads) - - return new_state, step_metrics - -# %% -# explore data sharding -sharded_batch = next(iter(train_loader)) -sharded_batch = jax.device_put(sharded_batch, x_sharding) -# jax.debug.visualize_array_sharding(sharded_batch['input_ids']) -# jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding']) - - -# %% -# # prep 1 step -# print("1 step for jit-ting") -# with mesh: -# state, metrics = train_step(initialized_state, sharded_batch) - -# %% - -# %% -# tr -print("***** 
Running training *****") -print(f" Num examples = {training_size}") -print(f" Num Epochs = {num_epochs}") -print(f" Instantaneous batch size per device = {per_device_train_batch_size}") -print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}") -print(f" Total optimization steps = {total_train_steps}") - - -# %% -# jax.profiler.start_trace("./traces") - - -print("*" * 10) -print("training start") -rng, input_rng = jax.random.split(rng) -train_time = 0 -state = initialized_state -epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) -for epoch in epochs: - train_start = time.time() - - # Create sampling rng - train_metrics = [] - steps_per_epoch = training_size // train_batch_size - train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True) - # Generate an epoch by shuffling sampling indices from the train dataset - for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): - batch = next(train_loader) - batch = jax.device_put(batch, x_sharding) - with mesh: - state, train_metric = train_step(state, batch) - - # train_metrics.append(train_metric) - - - # this is for more accurate time stats, but slows down training - # train_metric['loss'].block_until_ready() - train_time = time.time() - train_start - - - - epochs.write( - f"Epoch... ({epoch + 1}/{num_epochs} | " - f"Loss: {train_metric['loss']}, " - f"Learning Rate:{train_metric['learning_rate']}, " - f"Last train time: {train_time})" - ) -# jax.profiler.stop_trace() - -# %% -# try out -gather_state = jax.device_get(state) -gather_batch = jax.device_get(batch) -logits = gather_state.apply_fn( - {'params': gather_state.params}, - input_ids=gather_batch['input_ids'], - attention_mask=gather_batch['attention_mask'], - decoder_input_ids=gather_batch['decoder_input_ids'], - decoder_attention_mask=gather_batch['decoder_attention_mask'], -)[0] # zero because output is some structure, where first is the logit - -probs = nn.softmax(logits, axis=-1) -predicted = jnp.argmax(probs, axis=-1) -print("sample output") -print(predicted[1]) - -# %% -main_model = custom_model.from_pretrained('t5-base') -output_dir = save_path - -# save checkpoint after each epoch and push checkpoint to the hub -if jax.process_index() == 0: - params = jax.device_get(state.params) - main_model.save_pretrained(output_dir, params=params) - - -# %% diff --git a/t5_jax_shmap.py b/t5_jax_shmap.py deleted file mode 100644 index 9d47b77..0000000 --- a/t5_jax_shmap.py +++ /dev/null @@ -1,550 +0,0 @@ -# %% -import os -# Set this to True to run the model on CPU only. 
-USE_CPU_ONLY = False - -flags = os.environ.get("XLA_FLAGS", "") -if USE_CPU_ONLY: - flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices - # Enforce CPU-only execution - os.environ["CUDA_VISIBLE_DEVICES"] = "" - os.environ["JAX_PLATFORMS"] = "cpu" -else: - # GPU flags - flags = ( - '--xla_gpu_enable_triton_softmax_fusion=true ' - '--xla_gpu_triton_gemm_any=True ' - # '--xla_gpu_enable_async_collectives=true ' - '--xla_gpu_enable_latency_hiding_scheduler=true ' - '--xla_gpu_enable_highest_priority_async_stream=true ' - ) - os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" - -os.environ["XLA_FLAGS"] = flags -os.environ.update({ - "TOKENIZERS_PARALLELISM" : "false", - "CUDA_DEVICE_MAX_CONNECTIONS" : "1", - "NCCL_LL128_BUFFSIZE": "-2", - "NCCL_LL_BUFFSIZE": "-2", - "NCCL_PROTO": "SIMPLE,LL,LL128", - "XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.5", - # "XLA_PYTHON_CLIENT_PREALLOCATE" : "false" - }) - - - - -import functools -from functools import partial -from pprint import pprint -from typing import Any, Dict, Tuple, Callable, Sequence, Dict, Union - -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -from jax.experimental.shard_map import shard_map -from jax.sharding import Mesh, NamedSharding -# from jax.experimental.pjit import pjit # superseded by jax.jit -from jax.experimental import mesh_utils -from jax.sharding import PartitionSpec -from ml_collections import ConfigDict -import optax -import logging -import time -from datasets import Dataset, load_from_disk - -from flax import jax_utils, traverse_util -from flax.training import train_state -from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key -from flax.core.frozen_dict import freeze, unfreeze, FrozenDict -import flax.core - -# model checkpointing and saving utilities -from flax import linen as nn -from flax.training import checkpoints, train_state -from flax import struct, serialization - -from parallel.partitions import set_partitions - -from tqdm import tqdm - -from parallel.dataload import DataPrepare - -# for memory tracking -# from jax_smi import initialise_tracking -# initialise_tracking() - - -PyTree = Any -Metrics = Dict[str, Tuple[jax.Array, ...]] - -if USE_CPU_ONLY: - jax.config.update('jax_platform_name', 'cpu') -else: - jax.config.update("jax_default_matmul_precision", "bfloat16") - -jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") -jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) -jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) - - - - -# %% -## get platform type -from jax.extend.backend import get_backend -print(get_backend().platform) -print(jax.devices()) - -# %% -# config options -file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/' -save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/shmap/' -# file_path = 'combined_data' -split_datasets = load_from_disk(file_path) -training_size = len(split_datasets['train']) -# Store some constant -seed = 117 -num_epochs = 40 -batch_size = 32 # do not go beyond 128, 64 is good -num_train_epochs = num_epochs -per_device_train_batch_size = batch_size -train_batch_size = per_device_train_batch_size * jax.device_count() -per_device_eval_batch_size = batch_size -eval_batch_size = per_device_eval_batch_size * jax.device_count() -steps_per_epoch = training_size // train_batch_size -total_train_steps = steps_per_epoch * num_epochs - -warmup_steps = 0 -learning_rate = 2e-5 - -weight_decay = 0.01 -adam_beta1 = 
0.9 -adam_beta2 = 0.999 -adam_epsilon = 1e-8 -label_smoothing_factor = 0.0 -num_beams = 1 -val_max_target_length = 128 -predict_with_generate = True - - -# %% -# prepare data -# init object -# e.g. Config -print("preparing data") -data_config = ConfigDict( - dict( - max_length=128, - pad_token_id=0, - decoder_start_token_id=0 - ) -) - -dataprep = DataPrepare(split_datasets['train'], data_config) -# # example usage -# # %% -seed = 117 -rng = jax.random.PRNGKey(seed) -train_loader = dataprep.data_loader(rng, batch_size=batch_size) -batch = next(iter(train_loader)) - -# %% -# from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom -from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model -main_model = custom_model.from_pretrained( - "t5-base", - dtype=jnp.bfloat16, - gradient_checkpointing=True, -) -params = main_model.params -# pretrained_params = model.params -model = main_model.module - -# %% -# # testing config hashability -# # some explanation: -# # The PreTrainedModel class loads a T5Config model that is not hashable because -# # it is a complicated class that pretends to be a dataclass. -# # The solution is to extract a dict from it, then make a ConfigDict from -# # ml_collections library so that we can get values via the "." operator. -# # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to -# # modify the config before passing to the next layer -# from transformers import T5Config -# from t5_model.configuration_t5 import FrozenT5Config -# from ml_collections import ConfigDict, FrozenConfigDict -# -# config = T5Config.from_pretrained("t5-base").to_dict() -# config.pop('architectures') -# config.pop('id2label') -# # test if it works -# frozen_config = FrozenConfigDict(config) -# # test hash -# hash(frozen_config) - -# %% -# # print model -# rng, input_rng = jax.random.split(rng) -# model.tabulate( -# input_rng, -# input_ids=batch['input_ids'], -# attention_mask=batch['attention_mask'], -# decoder_input_ids=batch['decoder_input_ids'], -# decoder_attention_mask=batch['decoder_attention_mask'], -# console_kwargs={"force_jupyter": True} -# ) - - - -# %% -# create mesh -print("creating mesh") -device_mesh = mesh_utils.create_device_mesh((2,2)) -print(device_mesh) - -mesh = Mesh(devices=device_mesh, axis_names=('data', 'model')) -print(mesh) - -def mesh_sharding(pspec: PartitionSpec) -> NamedSharding: - if USE_CPU_ONLY: - return NamedSharding(mesh, pspec, memory_kind="unpinned_host") - else: - # if gpu - return NamedSharding(mesh, pspec, memory_kind="device") - -x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis -model_sharding=mesh_sharding(PartitionSpec(None, 'model')) - - -# %% -# optimizers - -def create_learning_rate_fn( - train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.ndarray]: - """Returns a linear warmup, linear_decay learning rate function.""" - steps_per_epoch = train_ds_size // train_batch_size - num_train_steps = steps_per_epoch * num_train_epochs - warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) - decay_fn = optax.linear_schedule( - init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps - ) - schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) - return schedule_fn - - -# Create learning rate schedule -linear_decay_lr_schedule_fn = 
create_learning_rate_fn( - training_size, - train_batch_size, - num_train_epochs, - warmup_steps, - learning_rate, -) - -# We use Optax's "masking" functionality to not apply weight decay -# to bias and LayerNorm scale parameters. decay_mask_fn returns a -# mask boolean with the same structure as the parameters. -# The mask is True for parameters that should be decayed. -def decay_mask_fn(params): - flat_params = traverse_util.flatten_dict(params) - # find out all LayerNorm parameters - layer_norm_candidates = ["final_layer_norm", "layer_norm"] - layer_norm_named_params = { - layer[-2:] - for layer_norm_name in layer_norm_candidates - for layer in flat_params.keys() - if layer_norm_name in "".join(layer).lower() - } - flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} - return traverse_util.unflatten_dict(flat_mask) - -# create adam optimizer -adamw = optax.adamw( - learning_rate=linear_decay_lr_schedule_fn, - b1=adam_beta1, - b2=adam_beta2, - eps=adam_epsilon, - weight_decay=weight_decay, - mask=decay_mask_fn, -) - -# %% -def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): - """ - The label smoothing implementation is adapted from Flax's official example: - https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104 - """ - vocab_size = logits.shape[-1] - confidence = 1.0 - label_smoothing_factor - low_confidence = (1.0 - confidence) / (vocab_size - 1) - normalizing_constant = -( - confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) - ) - soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence) - logits = jnp.asarray(logits, dtype=jnp.float32) - logits = logits.astype(jnp.float32) - soft_labels = soft_labels.astype(jnp.float32) - loss = optax.softmax_cross_entropy(logits, soft_labels) - loss = loss - normalizing_constant - - # ignore padded tokens from loss - loss = loss * padding_mask - mean_loss = loss.mean() - # num_labels = padding_mask.mean() - return mean_loss # , num_labels - - -# %% -################################################################ -# old jit in_shardings method - -# create init_fn to produce sharded state -def init_fn(params, model, optimizer): - # do be careful with the model init - # imported models might have complicated init methods - - # mask = create_mask_for_layer_norm(params) - # override params with bfloat version - # params= cast_floating_to(params, jnp.bfloat16, mask) - - state = train_state.TrainState.create( # Create a `TrainState`. 
- apply_fn=model.apply, - params=params, - tx=optimizer) - return state - - -abstract_variables = jax.eval_shape( - functools.partial(init_fn, model=model, optimizer=adamw), params) - -# jax.sharding: describes how a jax.Array is laid out across devices -state_sharding = nn.get_sharding(abstract_variables, mesh) - -from parallel.partitions import set_partitions -# set_partitions freezes the params on return -model_part_spec = set_partitions(unfreeze(params)) -# p is already a partition spec -model_named_sharding = jax.tree.map(lambda p: mesh_sharding(p), model_part_spec) - -# get pspec for opt_state -def get_opt_spec(x): - if isinstance(x, dict): - return unfreeze(model_named_sharding) - # return an empty partspec - return mesh_sharding((PartitionSpec())) - -# this function replaces the empty model params spec with the 'model_named_shard' -state_sharding = jax.tree.map( - get_opt_spec, state_sharding, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)) -) - - - -# %% -jit_init_fn = jax.jit( - init_fn, - static_argnames=('model', 'optimizer'), # skip model and optimizer - in_shardings=mesh_sharding(PartitionSpec()), # we don't shard params explicitly - out_shardings=state_sharding # but returned initialized_state is sharded -) -initialized_state = jit_init_fn(params, model, adamw) - -# %% -# train step -def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): - """ - The label smoothing implementation is adapted from Flax's official example: - https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104 - """ - vocab_size = logits.shape[-1] - confidence = 1.0 - label_smoothing_factor - low_confidence = (1.0 - confidence) / (vocab_size - 1) - normalizing_constant = -( - confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) - ) - soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence) - logits = jnp.asarray(logits, dtype=jnp.float32) - logits = logits.astype(jnp.float32) - soft_labels = soft_labels.astype(jnp.float32) - loss = optax.softmax_cross_entropy(logits, soft_labels) - loss = loss - normalizing_constant - - # ignore padded tokens from loss - loss = loss * padding_mask - loss = loss.mean() - # num_labels = padding_mask.mean() - return loss # , num_labels - -# %% - -# single device code annotated with jax.jit -@functools.partial( - jax.jit, - # state is state_sharding initialized from init_fn - # x_sharding is data sharded explicitly later - in_shardings=(state_sharding, x_sharding), - out_shardings=(state_sharding, mesh_sharding(PartitionSpec())), - donate_argnames=('state'), -) -def train_step(state, batch): - - label_smoothing_factor=0.0 - # dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) - # computes loss per shard - def compute_loss(params, batch): - # check constraints - # frozen dict not allowed as sharding object - params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding)) - batch = jax.lax.with_sharding_constraint(batch, x_sharding) - logits = state.apply_fn( - {'params': params}, - input_ids=batch['input_ids'], - attention_mask=batch['attention_mask'], - decoder_input_ids=batch['decoder_input_ids'], - decoder_attention_mask=batch['decoder_attention_mask'], - )[0] # zero because output is some structure, where first is the logit - - # logits sharding - # data, None, model - # - print("logits") - jax.debug.inspect_array_sharding(logits, callback=print) - # use labels here - # loss, num_labels = 
loss_fn( - loss = loss_fn( - logits, - batch["labels"], - batch["decoder_attention_mask"], - label_smoothing_factor) - # loss sharding - # it gives PartitionSpec(), which implies a reduction already happened - print("loss") - jax.debug.inspect_array_sharding(loss, callback=print) - - return loss # , num_labels - - # compute gradients through computational graph - # allow values to pass through - grad_fn = jax.value_and_grad(compute_loss, has_aux=False) - batch = jax.tree.map(lambda x: jax.lax.with_sharding_constraint(x, x_sharding), batch) - (loss), grads = grad_fn(state.params, batch) - # num_labels = jax.lax.psum(num_labels, "batch") - - # so far we have been operating from within each shard - # we need to sync gradients across devices - # we bring all gradients together onto a single device - # jax.debug.inspect_array_sharding(grads, callback=print) - grads = jax.lax.with_sharding_constraint(grads, mesh_sharding(PartitionSpec())) - # grads = jax.lax.with_sharding_constraint(grads, state_sharding) - # jax.debug.visualize_array_sharding(grad) - # jax.debug.inspect_array_sharding(grad, callback=print) - # check the output grad tree from mean - # print(jax.tree.map(jnp.shape, grad)) - - - - new_state = state.apply_gradients(grads=grads) - with jax.named_scope("sync_metrics"): - step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} - - return new_state, step_metrics - -# %% -# explore data sharding -sharded_batch = next(iter(train_loader)) -# sharded_batch = jax.device_put(sharded_batch, x_sharding) -sharded_batch = jax.tree.map(lambda x: jax.lax.with_sharding_constraint(x, x_sharding), batch) -jax.debug.visualize_array_sharding(sharded_batch['input_ids']) -# jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding']) - - - -# %% -# # prep 1 step -print("1 step for jit-ting") -with mesh: - state, metrics = train_step(initialized_state, sharded_batch) - -# %% - -# %% -# tr -print("***** Running training *****") -print(f" Num examples = {training_size}") -print(f" Num Epochs = {num_epochs}") -print(f" Instantaneous batch size per device = {per_device_train_batch_size}") -print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}") -print(f" Total optimization steps = {total_train_steps}") - - -# %% -# jax.profiler.start_trace("./traces") - - -print("*" * 20) -print("training start") -rng, input_rng = jax.random.split(rng) -train_time = 0 -state = initialized_state -epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) -for epoch in epochs: - train_start = time.time() - - # Create sampling rng - train_metrics = [] - steps_per_epoch = training_size // train_batch_size - train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True) - # Generate an epoch by shuffling sampling indices from the train dataset - for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): - batch = next(train_loader) - batch = jax.device_put(batch, x_sharding) - with mesh: - state, train_metric = train_step(state, batch) - - # train_metrics.append(train_metric) - - - # this is for more accurate time stats, but slows down training - # train_metric['loss'].block_until_ready() - train_time = time.time() - train_start - - - - epochs.write( - f"Epoch... 
({epoch + 1}/{num_epochs} | " - f"Loss: {train_metric['loss']}, " - f"Learning Rate:{train_metric['learning_rate']}, " - f"Last train time: {train_time})" - ) -# jax.profiler.stop_trace() - -# %% -# try out -# gather_state = jax.device_get(state) -# gather_batch = jax.device_get(batch) -# logits = gather_state.apply_fn( -# {'params': gather_state.params}, -# input_ids=gather_batch['input_ids'], -# attention_mask=gather_batch['attention_mask'], -# decoder_input_ids=gather_batch['decoder_input_ids'], -# decoder_attention_mask=gather_batch['decoder_attention_mask'], -# )[0] # zero because output is some structure, where first is the logit -# -# probs = nn.softmax(logits, axis=-1) -# predicted = jnp.argmax(probs, axis=-1) -# print(predicted[0]) - -# %% -main_model = custom_model.from_pretrained('t5-base') -output_dir = save_path - -# save checkpoint after each epoch and push checkpoint to the hub -if jax.process_index() == 0: - params = jax.device_get(state.params) - params = jax.tree.map(lambda x: x.astype(jnp.float32), params) - main_model.save_pretrained(output_dir, params=params) - - -# %% diff --git a/t5_jax_simple_parallel.py b/t5_jax_simple_parallel.py index a832222..54687c1 100644 --- a/t5_jax_simple_parallel.py +++ b/t5_jax_simple_parallel.py @@ -12,15 +12,16 @@ if USE_CPU_ONLY: else: # GPU flags flags = ( - '--xla_gpu_enable_triton_softmax_fusion=true ' + # '--xla_gpu_enable_custom_fusions=true ' '--xla_gpu_triton_gemm_any=True ' # '--xla_gpu_enable_async_collectives=true ' '--xla_gpu_enable_latency_hiding_scheduler=true ' '--xla_gpu_enable_highest_priority_async_stream=true ' + '--xla_gpu_enable_pipelined_all_reduce=true ' + '--xla_gpu_enable_nccl_user_buffers=true ' ) os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" -os.environ["XLA_FLAGS"] = flags os.environ.update({ "TOKENIZERS_PARALLELISM" : "false", "CUDA_DEVICE_MAX_CONNECTIONS" : "1", @@ -28,9 +29,10 @@ os.environ.update({ "NCCL_LL_BUFFSIZE": "-2", "NCCL_PROTO": "SIMPLE,LL,LL128", "XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.80", + "NCCL_NVLS_ENABLE": "1", # "XLA_PYTHON_CLIENT_PREALLOCATE" : "false" }) - +os.environ["XLA_FLAGS"] = flags @@ -69,7 +71,7 @@ from parallel.partitions import set_partitions from tqdm import tqdm -from parallel.dataload import DataPrepare +from dataload import DataPrepare # for memory tracking # from jax_smi import initialise_tracking @@ -114,7 +116,7 @@ model_sharding=mesh_sharding(PartitionSpec(None, 'model')) # %% # config options -file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/' +file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_simple/' save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/simple/' # file_path = 'combined_data' split_datasets = load_from_disk(file_path) @@ -122,12 +124,12 @@ training_size = len(split_datasets['train']) # Store some constant seed = 117 num_epochs = 40 -batch_size = 128 +batch_size = 64 num_train_epochs = num_epochs per_device_train_batch_size = batch_size -train_batch_size = per_device_train_batch_size * 2 +train_batch_size = per_device_train_batch_size * mesh.shape['data'] per_device_eval_batch_size = batch_size -eval_batch_size = per_device_eval_batch_size * 2 +eval_batch_size = per_device_eval_batch_size * mesh.shape['data'] steps_per_epoch = training_size // train_batch_size total_train_steps = steps_per_epoch * num_epochs @@ -394,13 +396,29 @@ def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): # ignore padded tokens from loss loss = loss * padding_mask + loss = 
jax.lax.with_sharding_constraint(loss, x_sharding) loss = loss.mean() # num_labels = padding_mask.mean() return loss # , num_labels # %% +# def extract_spec(sharding_obj): +# # Check if the object is a NamedSharding instance +# if isinstance(sharding_obj, NamedSharding): +# return sharding_obj.spec # Return the spec if it is a NamedSharding +# return sharding_obj # Return the object itself if not +# +# state_sharding_spec = jax.tree.map(extract_spec, state_sharding) # single device code annotated with jax.jit +# @partial( +# shard_map, +# mesh=mesh, +# in_specs=(state_sharding_spec, +# x_sharding.spec), +# out_specs=(state_sharding_spec, PartitionSpec()), +# check_rep=False, +# ) @functools.partial( jax.jit, # state is state_sharding initialized from init_fn @@ -418,6 +436,8 @@ def train_step(state, batch): # frozen dict not allowed as sharding object params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding)) batch = jax.lax.with_sharding_constraint(batch, x_sharding) + # require decoder_input_ids to simulate auto-regressive output during + # generation time logits = state.apply_fn( {'params': params}, input_ids=batch['input_ids'], @@ -439,7 +459,7 @@ def train_step(state, batch): grad_fn = jax.value_and_grad(compute_loss, has_aux=False) (loss), grad = grad_fn(state.params, batch) # num_labels = jax.lax.psum(num_labels, "batch") - + # loss, grad = jax.lax.pmean((loss, grad), axis_name="data") new_state = state.apply_gradients(grads=grad) with jax.named_scope("sync_metrics"): @@ -447,15 +467,12 @@ def train_step(state, batch): return new_state, step_metrics -# %% -# explore data sharding +# # %% +# # explore data sharding # sharded_batch = next(iter(train_loader)) # sharded_batch = jax.device_put(sharded_batch, x_sharding) # jax.debug.visualize_array_sharding(sharded_batch['input_ids']) # jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding']) - - -# %% # # prep 1 step # print("1 step for jit-ting") # with mesh: @@ -476,6 +493,8 @@ print(f" Total optimization steps = {total_train_steps}") # %% # jax.profiler.start_trace("./traces") +# jit_train_step = jax.jit(train_step) + print("*" * 50) print("training start") diff --git a/t5_model/modeling_t5_flax.py b/t5_model/modeling_t5_flax.py index 4884050..059bfde 100644 --- a/t5_model/modeling_t5_flax.py +++ b/t5_model/modeling_t5_flax.py @@ -289,6 +289,7 @@ class FlaxT5Attention(nn.Module): def _merge_heads(self, hidden_states): return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,)) + # i suspect we are threading state here @nn.compact def _concatenate_to_cache(self, key, value, query, attention_mask): """ @@ -298,10 +299,17 @@ class FlaxT5Attention(nn.Module): """ # detect if we're initializing by absence of existing cache data. is_initialized = self.has_variable("cache", "cached_key") + # Variables are identified by a collection (e.g., "batch_stats") and a name + # (e.g., "moving_mean"). The value property gives access to the variable's + # content and can be assigned to for mutation. + # + # self.variable either 1.) initializes values for the first time + # 2.) 
retrieves the variable and does not override cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + # only run if initialized before if is_initialized: *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape # update key, value caches with our new 1d spatial slices @@ -688,7 +696,7 @@ class FlaxT5BlockCollection(nn.Module): position_bias = None encoder_decoder_position_bias = None - for i, layer_module in enumerate(self.blocks): + for _, layer_module in enumerate(self.blocks): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -933,7 +941,7 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel): config_class = T5Config base_model_prefix = "transformer" - module_class: nn.Module = None + module_class: nn.Module = None # to be overriden by subclass def __init__( self, diff --git a/t5_prediction_old.py b/t5_prediction_old.py deleted file mode 100644 index 87edd6c..0000000 --- a/t5_prediction_old.py +++ /dev/null @@ -1,360 +0,0 @@ - -# --- -# jupyter: -# jupytext: -# formats: ipynb,py:percent -# text_representation: -# extension: .py -# format_name: percent -# format_version: '1.3' -# jupytext_version: 1.16.4 -# --- - -# %% [markdown] -# # prediction code -# ## import and process test data - - -# %% -# import libraries -import pandas as pd -import matplotlib.pyplot as plt - -from datasets import Dataset, DatasetDict - -import jax -import jax.numpy as jnp -import optax -import numpy as np -from functools import partial -from typing import Callable, Optional -import math - -# jax.config.update("jax_default_matmul_precision", "tensorfloat32") -jax.config.update("jax_default_matmul_precision", "high") - -jax.config.update("jax_enable_x64", False) - - -from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig - - -import datasets -from datasets import Dataset -import evaluate -from tqdm import tqdm - - -import nltk # Here to have a nice missing dependency error message early on - -from flax import jax_utils, traverse_util -from flax.jax_utils import pad_shard_unpad, unreplicate -from flax.training import train_state -from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key - - -import time - - -# %% - -# data_path = f"../make_data/select_db/data_mapping_filtered.csv" -# data_path = f"../make_data_2/select_db/dataset/1/train_all.csv" -data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv' -# data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv' - -# Ensure to include 'ships_idx' in the fields list -fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit'] - -# Load the dataset -df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields) - -def process_df(df): - output_list = [{ - 'input': f"{row['tag_name']}{row['tag_description']}", - # 'input': f"{row['tag_description']}", - # 'input': f"{row['tag_name']}{row['tag_description']}{row['unit']}", - # 'input': f"{row['tag_description']}{row['unit']}", - 'output': f"{row['thing']}{row['property']}", - # 'answer': f"{row['thing']} {row['property']}", - # 'answer_thing': row['thing'], - # 'answer_property': row['property'], - } for _, row in df.iterrows()] - - return output_list - - -# takes 1 minute to run without batching 
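# process_df() walks the DataFrame with iterrows() and emits one dict per row,
# pairing an 'input' string built from tag_name and tag_description with an
# 'output' string built from thing and property; Dataset.from_list() below then
# wraps that list of dicts as a datasets.Dataset with 'input'/'output' columns.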
-test_dataset = Dataset.from_list(process_df(df)) - - -# %% [markdown] -# ## Load model for attributes - -# %% -# load model -model_name_or_path = "./t5_80_1" # Replace with your specific model name - -# Load configuration -config = AutoConfig.from_pretrained(model_name_or_path) - -# Load model -model = FlaxAutoModelForSeq2SeqLM.from_pretrained( - pretrained_model_name_or_path=model_name_or_path -) - - -# %% [markdown] -# ## Tokenizer - -# %% -# prepare tokenizer -from transformers import T5TokenizerFast -tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True) -# Define additional special tokens -additional_special_tokens = ["", "", "", "", "", "", "SIG", "UNIT", "DATA_TYPE"] -# Add the additional special tokens to the tokenizer -tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens}) - -max_length = 86 - -model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"]) -shift_tokens_right_fn = getattr(model_module, "shift_tokens_right") - -# given a dataset entry, run it through the tokenizer -# Setting padding="max_length" as we need fixed length inputs for jitted functions -def preprocess_function(example): - inputs = example['input'] - targets = example['output'] - # text_target sets the corresponding label to inputs - # there is no need to create a separate 'labels' - model_inputs = tokenizer( - inputs, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="np" - ) - labels = tokenizer( - text_target=targets, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="np" - ) - - model_inputs["labels"] = labels["input_ids"] - decoder_input_ids = shift_tokens_right_fn( - labels["input_ids"], config.pad_token_id, config.decoder_start_token_id - ) - model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids) - - # We need decoder_attention_mask so we can ignore pad tokens from loss - model_inputs["decoder_attention_mask"] = labels["attention_mask"] - - return model_inputs - -# map maps function to each "row" in the dataset -# aka the data in the immediate nesting -test_dataset = test_dataset.map( - preprocess_function, - batched=True, - num_proc=1, - remove_columns=test_dataset.column_names, -) - -def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True): - """ - Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete, - and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`. - """ - if shuffle: - batch_idx = jax.random.permutation(rng, len(dataset)) - batch_idx = np.asarray(batch_idx) - else: - batch_idx = np.arange(len(dataset)) - - if drop_last: - steps_per_epoch = len(dataset) // batch_size - batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch. 
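        # With drop_last=True, the index array is truncated to a multiple of
        # batch_size and reshaped below into (steps_per_epoch, batch_size);
        # e.g. 10 examples with batch_size=4 keep 8 indices and give 2 full batches.
        # With drop_last=False, np.array_split() instead yields ceil(len/batch_size)
        # nearly equal chunks (4, 3 and 3 for the same 10 examples), so batches
        # other than the last can also fall short of batch_size.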
- batch_idx = batch_idx.reshape((steps_per_epoch, batch_size)) - else: - steps_per_epoch = math.ceil(len(dataset) / batch_size) - batch_idx = np.array_split(batch_idx, steps_per_epoch) - - for idx in batch_idx: - batch = dataset[idx] - batch = {k: np.array(v) for k, v in batch.items()} - - yield batch - -# %% [markdown] -# # model generation - -# %% -seed = 117 -num_epochs = 80 -batch_size = 96 -num_train_epochs = num_epochs -per_device_train_batch_size = batch_size -train_batch_size = per_device_train_batch_size * jax.device_count() -per_device_eval_batch_size = batch_size -eval_batch_size = per_device_eval_batch_size * jax.device_count() -steps_per_epoch = len(test_dataset) // train_batch_size -total_train_steps = steps_per_epoch * num_epochs - -num_beams = 1 -val_max_target_length = 128 - -predict_with_generate = True - - -# Initialize our training -rng = jax.random.PRNGKey(seed) -rng, dropout_rng = jax.random.split(rng) - - -# %% - -# reload model to prevent leakage of variables -# load model -model_name_or_path = "t5_80_1_bf16" # Replace with your specific model name - -# Load configuration -config = AutoConfig.from_pretrained(model_name_or_path) - -# Load model -model = FlaxAutoModelForSeq2SeqLM.from_pretrained( - model_name_or_path -) - - -# Ensure model.params is properly initialized (this is just an example) -# Normally you would get this from a model initialization call with dummy input -params = model.params -# ensure full size floats -params_f16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params) -# we need to replicate model over devices -replicated_params = jax.device_put_replicated(params_f16, jax.devices()) - - -# Define generation function -max_length = ( - val_max_target_length if val_max_target_length is not None else model.config.max_length -) -num_beams = num_beams if num_beams is not None else model.config.num_beams -gen_kwargs = {"max_length": max_length, "num_beams": num_beams} - -def generate_step(params, batch): - output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], params=params, **gen_kwargs) - return output_ids.sequences - -# Create parallel version of the train and eval step -p_generate_step = jax.pmap(generate_step, "batch") - - - -pred_generations = [] -pred_labels = [] - -rng, input_rng = jax.random.split(rng) - -pred_loader = data_loader(input_rng, test_dataset, eval_batch_size, drop_last=False) -pred_steps = math.ceil(len(test_dataset) / eval_batch_size) - -print("***** Running training *****") -print(f" Num examples = {len(test_dataset)}") -print(f" Num steps = {num_epochs}") -print(f" Instantaneous batch size per device = {per_device_train_batch_size}") -print(f" Total test batch size (w. 
parallel & distributed) = {train_batch_size}") - - -for _ in tqdm(range(pred_steps), desc="Predicting..."): - # Model forward - batch = next(pred_loader) - labels = batch["labels"] - - # generation - generated_ids = pad_shard_unpad(p_generate_step)(replicated_params, batch) - pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"]))) - pred_labels.extend(labels) - - - -# %% [markdown] -# # process predictions - - -# %% -# code to get special token ids -# sentence = "" -# tokens = tokenizer.tokenize(sentence) -# print("Tokens:", tokens) -# # Get the IDs (integer indices) of specific tokens -# token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens] -# print("Token IDs:", token_ids) - - -# %% -# extract sequence and decode -def extract_seq(tokens, start_value, end_value): - if start_value not in tokens or end_value not in tokens: - return None # Or handle this case according to your requirements - start_id = np.where(tokens == start_value)[0][0] - end_id = np.where(tokens == end_value)[0][0] - - return tokens[start_id+1:end_id] - - -def process_tensor_output(tokens): - thing_seq = extract_seq(tokens, 32100, 32101) # 32100 = , 32101 = - property_seq = extract_seq(tokens, 32102, 32103) # 32102 = , 32103 = - p_thing = None - p_property = None - if (thing_seq is not None): - p_thing = tokenizer.decode(thing_seq, skip_special_tokens=False) # retain - if (property_seq is not None): - p_property = tokenizer.decode(property_seq, skip_special_tokens=False) # retain - return p_thing, p_property - - -# %% -# decode prediction labels -def decode_preds(tokens_list): - thing_prediction_list = [] - property_prediction_list = [] - for tokens in tokens_list: - p_thing, p_property = process_tensor_output(tokens) - thing_prediction_list.append(p_thing) - property_prediction_list.append(p_property) - return thing_prediction_list, property_prediction_list - -thing_prediction_list, property_prediction_list = decode_preds(pred_generations) - -# %% -# add labels too -thing_actual_list, property_actual_list = decode_preds(pred_labels) - -# Convert the list to a Pandas DataFrame -df = pd.DataFrame({'p_thing': thing_prediction_list, - 'p_property': property_prediction_list, - 'thing': thing_actual_list, - 'property' : property_actual_list}) - -df['p_thing_correct'] = df['p_thing'] == df['thing'] -df['p_property_correct'] = df['p_property'] == df['property'] - -# %% -print("thing accuracy", sum(df['p_thing_correct'])/len(df)) -print("property accuracy", sum(df['p_property_correct'])/len(df)) -print("total accuracy", sum(df['p_property_correct'] & df['p_thing_correct'])/len(df)) -# %% -df[~df["p_property_correct"]] - -# %% -df['p_thing'] -# %% -# Save the DataFrame as a Parquet file (using pyarrow or fastparquet) -# df.to_parquet("exports/output_file.parquet", engine="pyarrow") # or use engine="fastparquet" - -
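
# %%
# A minimal, self-contained sketch (not part of the original script) of the
# span-extraction logic used above: extract_seq() returns the token IDs strictly
# between a start marker and an end marker, and the decoded spans become the
# predicted thing/property. The marker IDs 32100-32103 follow the comments in
# process_tensor_output(); the other token IDs below are made-up toy values.
import numpy as np

def extract_span(tokens: np.ndarray, start_id: int, end_id: int):
    # Mirror extract_seq(): None when either marker is missing, otherwise the
    # IDs strictly between the first occurrence of each marker.
    if start_id not in tokens or end_id not in tokens:
        return None
    start = np.where(tokens == start_id)[0][0]
    end = np.where(tokens == end_id)[0][0]
    return tokens[start + 1:end]

# Toy generated sequence: a thing span marked by 32100/32101 and a property
# span marked by 32102/32103, surrounded by padding and EOS.
toy = np.array([0, 32100, 811, 1406, 32101, 32102, 2334, 32103, 1, 0])
print(extract_span(toy, 32100, 32101))  # [ 811 1406]
print(extract_span(toy, 32102, 32103))  # [2334]
print(extract_span(toy, 32104, 32105))  # None (markers absent)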