Chore: removing old files and some experiments

This commit is contained in:
Richard Wong 2024-10-06 23:53:57 +09:00
parent 0762c02b31
commit 7e1f45f466
9 changed files with 72 additions and 2241 deletions

1
.gitignore vendored
View File

@ -4,3 +4,4 @@ exports/
traces/ traces/
ruff.toml ruff.toml
settings.json settings.json
__pycache__/

View File

@ -11,7 +11,6 @@ import math
from typing import Optional, List, Tuple, Callable, cast from typing import Optional, List, Tuple, Callable, cast
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
# file_path = 'combined_data' # file_path = 'combined_data'
# split_datasets = load_from_disk(file_path) # split_datasets = load_from_disk(file_path)
# training_size = len(split_datasets['train']) # training_size = len(split_datasets['train'])
@ -124,6 +123,25 @@ class DataPrepare():
# assign the dataset to train_dataset # assign the dataset to train_dataset
self.train_dataset = train_dataset self.train_dataset = train_dataset
# Example pad function
def _pad_to_batch_size(self, batch, target_size):
# Get the current batch size
input_ids = batch['input_ids']
current_size = input_ids.shape[0]
if current_size < target_size:
# Calculate how much padding is needed
padding_size = target_size - current_size
# Create padding (e.g., zeros or some appropriate value)
padding = jnp.zeros((padding_size, input_ids.shape[1]), dtype=jnp.int32) # Assuming 2D
# Concatenate to create a full batch
# repeat for all arrays in the tree
padded_batch = jax.tree.map(lambda array: jnp.concatenate([array, padding], axis=0, dtype=jnp.int32), batch)
# padded_batch = jnp.concatenate([batch, padding], axis=0)
else:
padded_batch = batch
return padded_batch
def data_loader(self, rng: jax.random.PRNGKey, batch_size: int, shuffle: bool = False, drop_last=True): def data_loader(self, rng: jax.random.PRNGKey, batch_size: int, shuffle: bool = False, drop_last=True):
""" """
Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete, Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
@ -148,7 +166,8 @@ class DataPrepare():
for idx in batch_idx: for idx in batch_idx:
batch = dataset[idx] batch = dataset[idx]
batch = {k: np.array(v) for k, v in batch.items()} batch = {k: jnp.array(v, dtype=jnp.int32) for k, v in batch.items()}
batch = self._pad_to_batch_size(batch, batch_size)
yield batch yield batch
@ -157,6 +176,8 @@ class DataPrepare():
# # %% # # %%
# # init object # # init object
# # e.g. Config # # e.g. Config
#
# file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_desc'
# data_config = ConfigDict( # data_config = ConfigDict(
# dict( # dict(
# max_length=86, # max_length=86,
@ -190,3 +211,5 @@ class DataPrepare():
# #
# #
# # %% # # %%
#
# %%

View File

@ -1,697 +0,0 @@
# %%
import os
# Set this to True to run the model on CPU only.
USE_CPU_ONLY = False
flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices
# Enforce CPU-only execution
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["JAX_PLATFORMS"] = "cpu"
else:
# GPU flags
flags = (
'--xla_gpu_enable_triton_softmax_fusion=true '
'--xla_gpu_triton_gemm_any=True '
# '--xla_gpu_enable_async_collectives=true '
'--xla_gpu_enable_latency_hiding_scheduler=true '
'--xla_gpu_enable_highest_priority_async_stream=true '
)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["XLA_FLAGS"] = flags
os.environ.update({
"TOKENIZERS_PARALLELISM" : "false",
"CUDA_DEVICE_MAX_CONNECTIONS" : "1",
"NCCL_LL128_BUFFSIZE": "-2",
"NCCL_LL_BUFFSIZE": "-2",
"NCCL_PROTO": "SIMPLE,LL,LL128",
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.90",
# "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
})
import functools
from functools import partial
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence, Dict, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
# from jax.experimental.pjit import pjit # superseded by jax.jit
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec
from ml_collections import ConfigDict
import optax
import logging
import time
from datasets import Dataset, load_from_disk
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.core.frozen_dict import freeze, unfreeze, FrozenDict
import flax.core
# model checkpointing and saving utilities
from flax import linen as nn
from flax.training import checkpoints, train_state
from flax import struct, serialization
import orbax.checkpoint as ocp
from flax.training import orbax_utils
from parallel.partitions import set_partitions
from tqdm import tqdm
from parallel.dataload import DataPrepare
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
if USE_CPU_ONLY:
jax.config.update('jax_platform_name', 'cpu')
else:
jax.config.update("jax_default_matmul_precision", "bfloat16")
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
# %%
## get platform type
from jax.extend.backend import get_backend
print(get_backend().platform)
print(jax.devices())
# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/'
save_path = '/home/richard/Projects/06_research/jax_models/t5_80e_fp32_parallel/'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])
# Store some constant
seed = 117
num_epochs = 5
batch_size = 32
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
warmup_steps = 0
learning_rate = 5e-5
weight_decay = 0.01
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0
num_beams = 1
val_max_target_length = 128
predict_with_generate = True
# %%
# prepare data
# init object
# e.g. Config
print("preparing data")
data_config = ConfigDict(
dict(
max_length=128,
pad_token_id=0,
decoder_start_token_id=0
)
)
dataprep = DataPrepare(split_datasets['train'], data_config)
# # example usage
# # %%
seed = 117
rng = jax.random.PRNGKey(seed)
train_loader = dataprep.data_loader(rng, batch_size=batch_size)
batch = next(iter(train_loader))
# batch
# %%
# model
# working
# from parallel.t5_model.pure_t5 import FlaxT5ForConditionalGenerationModule as model_init
# # from t5_model.pure_t5 import FlaxT5DenseActDense as model_init
# from parallel.t5_model.pure_t5 import make_config
# config = make_config()
# model = model_init(config=config, dtype=jnp.bfloat16, gradient_checkpointing=True)
# %%
# from transformers import FlaxT5ForConditionalGeneration, T5Config
# model = FlaxT5ForConditionalGeneration.from_pretrained(
# "t5-base",
# dtype=jnp.bfloat16,
# )
# # pretrained_params = model.params
# model = model.module
# %%
# from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom
from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model
main_model = custom_model.from_pretrained(
"t5-base",
dtype=jnp.float32,
# gradient_checkpointing=True,
)
params = main_model.params
# pretrained_params = model.params
model = main_model.module
# %%
# # testing config hashability
# # some explanation:
# # The PreTrainedModel class loads a T5Config model that is not hashable because
# # it is a complicated class that pretends to be a dataclass.
# # The solution is to extract a dict from it, then make a ConfigDict from
# # ml_collections library so that we can get values via the "." operator.
# # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to
# # modify the config before passing to the next layer
# from transformers import T5Config
# from t5_model.configuration_t5 import FrozenT5Config
# from ml_collections import ConfigDict, FrozenConfigDict
#
# config = T5Config.from_pretrained("t5-base").to_dict()
# config.pop('architectures')
# config.pop('id2label')
# # test if it works
# frozen_config = FrozenConfigDict(config)
# # test hash
# hash(frozen_config)
# %%
# %%
# # print model
# rng, input_rng = jax.random.split(rng)
# model.tabulate(
# input_rng,
# input_ids=batch['input_ids'],
# attention_mask=batch['attention_mask'],
# decoder_input_ids=batch['decoder_input_ids'],
# decoder_attention_mask=batch['decoder_attention_mask'],
# console_kwargs={"force_jupyter": True}
# )
# %%
# print model datatype to verify
# rng, input_rng = jax.random.split(rng)
# variables = model.init(
# input_rng,
# input_ids=batch['input_ids'],
# attention_mask=batch['attention_mask'],
# decoder_input_ids=batch['decoder_input_ids'],
# decoder_attention_mask=batch['decoder_attention_mask']
# )
# %%
# create mesh
print("creating mesh")
device_mesh = mesh_utils.create_device_mesh((1,1))
print(device_mesh)
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
return NamedSharding(mesh, pspec, memory_kind="device")
x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis
model_sharding=mesh_sharding(PartitionSpec(None, 'model'))
# %%
# optimizers
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
return schedule_fn
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
training_size,
train_batch_size,
num_train_epochs,
warmup_steps,
learning_rate,
)
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=adam_beta1,
b2=adam_beta2,
eps=adam_epsilon,
weight_decay=weight_decay,
mask=decay_mask_fn,
)
print("compile")
# enable bf16 except for layer_norm
def create_mask_for_layer_norm(params):
flat_params = traverse_util.flatten_dict(params)
mask = {
path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
}
mask = traverse_util.unflatten_dict(mask)
return mask
# borrowed from transformers modeling_flax_utils
def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
"""
Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
"""
# taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
def conditional_cast(param):
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
param = param.astype(dtype)
return param
if mask is None:
return jax.tree_util.tree_map(conditional_cast, params)
flat_params = traverse_util.flatten_dict(params)
flat_mask, _ = jax.tree_util.tree_flatten(mask)
for masked, key in zip(flat_mask, sorted(flat_params.keys())):
if masked:
flat_params[key] = conditional_cast(flat_params[key])
return traverse_util.unflatten_dict(flat_params)
# Cast all parameters to bfloat16 if desired
# params = jax.tree.tree_map(lambda x: x.astype(jnp.bfloat16), params)
# %%
def init_fn(params, model, optimizer):
# do be careful with the model init
# imported models might have complicated init methods
# mask = create_mask_for_layer_norm(params)
# override params with bfloat version
# params= cast_floating_to(params, jnp.bfloat16, mask)
state = train_state.TrainState.create( # Create a `TrainState`.
apply_fn=model.apply,
params=params,
tx=optimizer)
return state
# def init_fn(rng, batch, model, optimizer):
# # do be careful with the model init
# # imported models might have complicated init methods
# variables = model.init(
# rng,
# input_ids=batch['input_ids'],
# attention_mask=batch['attention_mask'],
# decoder_input_ids=batch['decoder_input_ids'],
# decoder_attention_mask=batch['decoder_attention_mask']
# )
# params = variables['params']
# mask = create_mask_for_layer_norm(params)
# # override params with bfloat version
# params= cast_floating_to(params, jnp.bfloat16, mask)
#
# state = train_state.TrainState.create( # Create a `TrainState`.
# apply_fn=model.apply,
# params=params,
# tx=optimizer)
# return state
# %%
# Create an abstract closure to wrap the function before feeding it in
# because `jax.eval_shape` only takes pytrees as arguments.
# eval_shape(fn, rng_key, x)
# used to perform shape inference
# returns a nested PyTree containing jax.ShapeDtypeStruct objects as leaves
# rng, init_rng = jax.random.split(rng)
abstract_variables = jax.eval_shape(
functools.partial(init_fn, model=model, optimizer=adamw), params)
# rng, init_rng = jax.random.split(rng)
# abstract_variables = jax.eval_shape(
# functools.partial(init_fn, model=model, optimizer=adamw), init_rng, batch)
# %%
# This `state_sharding` has the same pytree structure as `state`, the output
# of the `init_fn`.
# flan.linen.get_sharding
# extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh
# jax.sharding: describes how a jax.Array is laid out across devices
state_sharding = nn.get_sharding(abstract_variables, mesh)
# print(state_sharding)
# warning: do not have singleton None in your nn.partition definitions, it will screw with your sanity
##################################################
# # %%
# # replace the params tree with the new modified tree
# # create partitions for model
# from parallel.partitions import set_partitions
# # set_partitions freezes the params on return
# model_part_spec = set_partitions(unfreeze(params))
# # p is already a partition spec
# model_named_sharding = jax.tree.map(lambda p: mesh_sharding(p), model_part_spec)
#
# # %%
# # get_shapes = jax.tree.map(jnp.shape, params)
# # actually tuple
# # state_shapes = jax.eval_shape(state_sharding, get_shapes)
#
# # %%
# # get pspec for opt_state
# def get_opt_spec(x):
# if isinstance(x, dict):
# return unfreeze(model_named_sharding)
# # return an empty partspec
# return mesh_sharding((PartitionSpec()))
#
# # this function replaces the empty model params spec with the 'model_named_shard'
# state_sharding = jax.tree.map(
# get_opt_spec, state_sharding, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
# )
# %%
jit_init_fn = jax.jit(
init_fn,
static_argnames=('model', 'optimizer'), # skip model and optimizer
in_shardings=mesh_sharding(PartitionSpec(())), # we don't shard params explicitly
out_shardings=state_sharding # but returned initialized_state is sharded
)
initialized_state = jit_init_fn(params, model, adamw)
# jit_init_fn = jax.jit(
# init_fn,
# static_argnames=('model', 'optimizer'), # skip model and optimizer
# in_shardings=(mesh_sharding(()), x_sharding), # for PRNG key and data
# out_shardings=state_sharding
# )
#
#
# rng, init_rng = jax.random.split(rng)
# initialized_state = jit_init_fn(rng, batch, model, adamw)
# %%
# train step
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
"""
The label smoothing implementation is adapted from Flax's official example:
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
"""
vocab_size = logits.shape[-1]
confidence = 1.0 - label_smoothing_factor
low_confidence = (1.0 - confidence) / (vocab_size - 1)
normalizing_constant = -(
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
)
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
loss = optax.softmax_cross_entropy(logits, soft_labels)
loss = loss - normalizing_constant
# ignore padded tokens from loss
loss = loss * padding_mask
loss = loss.sum()
num_labels = padding_mask.sum()
return loss, num_labels
# %%
# sharded_loss_fn = jax.jit(
# loss_fn,
# in_shardings=(mesh_sharding('model'), x_sharding), # params partitioned across 'model' axis
# out_shardings=(mesh_sharding('model')), # Loss should be aggregated across 'model'
# )
def gather_and_sum(
sharded_values,
in_shardings
):
with mesh:
# Gather sharded values into a single device
gathered_values = jax.jit(
lambda x: x, in_shardings=in_shardings, out_shardings=None
)(sharded_values)
# Compute the sum of gathered values
summed_value = jax.tree.map(lambda x: jnp.sum(x), gathered_values)
return summed_value
# single device code annotated with jax.jit
@functools.partial(
jax.jit,
# state is state_sharding initialized from init_fn
# x_sharding is data sharded explicitly later
in_shardings=(state_sharding, x_sharding),
# return state as state_sharding
# we do not shard the metrics
out_shardings=(state_sharding, mesh_sharding(PartitionSpec())),
donate_argnames=('state'),
)
def train_step(state, batch):
label_smoothing_factor=0.0
# dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
def compute_loss(params, batch):
# check constraints
# frozen dict not allowed as sharding object
# params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding))
# batch = jax.lax.with_sharding_constraint(batch, x_sharding)
# labels = batch.pop("decoder_input_ids")
# no use of labels here
logits = state.apply_fn(
{'params': params},
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
decoder_input_ids=batch['decoder_input_ids'],
decoder_attention_mask=batch['decoder_attention_mask'],
)[0] # zero because output is some structure, where first is the logit
# use labels here
loss, num_labels = loss_fn(
logits,
batch["labels"],
batch["decoder_attention_mask"],
label_smoothing_factor)
return loss, num_labels
# compute gradients through computational graph
# allow values to pass through
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
(loss, num_labels), grad = grad_fn(state.params, batch)
# num_labels = jax.lax.psum(num_labels, "batch")
# true grad = total grad / total samples
# needs to be in a singleton tuple for some reason
# gathered_grad = gather_and_sum(grad, (unfreeze(model_named_sharding),))
# gathered_num_labels = gather_and_sum(num_labels, mesh_sharding(PartitionSpec()))
# summed_gradients = jax.tree.map(lambda x: jnp.sum(x)/gathered_num_labels, gathered_grad)
# grad = jax.lax.psum(grad, "batch")
# grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
new_state = state.apply_gradients(grads=grad)
with jax.named_scope("sync_metrics"):
step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
# step_metrics = jax.tree.map(
# # previously needed lax.psum
# # now just write single device code, let compiler handle
# lambda x: jnp.mean(x), step_metrics
# )
# if metrics is None:
# metrics = step_metrics
# else:
# # combine all the synced metrics
# metrics = jax.tree.map(jnp.mean, metrics, step_metrics)
return new_state, step_metrics
# %%
# prep 1 step
print("1 step for jit-ting")
with mesh:
state, metrics = train_step(initialized_state, batch)
# %%
# %%
# tr
print("***** Running training *****")
print(f" Num examples = {training_size}")
print(f" Num Epochs = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
print(f" Total optimization steps = {total_train_steps}")
# %%
# jax.profiler.start_trace("./traces")
# function to shard a batch by treating it as a pytree
def shard_batch(batch):
# Shard each element in the dictionary (i.e., each key-value pair)
return jax.tree_util.tree_map(
lambda x: jax.device_put(x, x_sharding),
batch
)
print("*" * 10)
print("training start")
rng, input_rng = jax.random.split(rng)
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
train_start = time.time()
# Create sampling rng
train_metrics = []
steps_per_epoch = training_size // train_batch_size
train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True)
# Generate an epoch by shuffling sampling indices from the train dataset
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
batch = next(train_loader)
# send to device
# batch = {key: jax.device_put(jnp.array(value, dtype=jnp.uint16), x_sharding) for key, value in batch.items()}
# batch['input_ids']=jax.device_put(jnp.array(batch['input_ids'], dtype=jnp.int32), x_sharding)
# batch['attention_mask']=jax.device_put(jnp.array(batch['attention_mask'], dtype=jnp.int32), x_sharding)
# batch['decoder_input_ids']=jax.device_put(jnp.array(batch['decoder_input_ids'], dtype=jnp.int32), x_sharding)
# batch['decoder_attention_mask']=jax.device_put(jnp.array(batch['decoder_attention_mask'], dtype=jnp.int32), x_sharding)
sharded_batch = shard_batch(batch)
with mesh:
state, train_metric = train_step(state, sharded_batch)
# train_metrics.append(train_metric)
# this is for more accurate time stats, but slows down training
# train_metric['loss'].block_until_ready()
train_time = time.time() - train_start
epochs.write(
f"Epoch... ({epoch + 1}/{num_epochs} | "
f"Loss: {train_metric['loss']}, "
f"Learning Rate:{train_metric['learning_rate']}, "
f"Last train time: {train_time})"
)
# jax.profiler.stop_trace()
# %%
# with mesh:
# gathered_params = jax.jit(
# lambda x: x,
# in_shardings=(unfreeze(model_named_sharding),),
# out_shardings=mesh_sharding(PartitionSpec())
# )(state.params)
main_model = custom_model.from_pretrained('t5-base')
output_dir = save_path
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
main_model.save_pretrained(output_dir, params=params)
# # stick to defaults
# options = ocp.CheckpointManagerOptions()
# with ocp.CheckpointManager(
# ocp.test_utils.erase_and_create_empty(save_path),
# options=options,
# ) as mngr:
#
# mngr.save(0, args=ocp.args.StandardSave(state))
# mngr.wait_until_finished()
# After providing `args` during an initial `save` or `restore` call, the
# `CheckpointManager` instance records the type so that you do not need to
# specify it again. If the `CheckpointManager` instance is not provided with a
# `ocp.args.CheckpointArgs` instance for a particular item on a previous
# occasion it cannot be restored without specifying the argument at restore
# time.
# # In many cases, you can restore exactly as saved without specifying additional
# # arguments.
# mngr.restore(0)
# # If customization of properties like sharding or dtype is desired, just provide
# # the abstract target PyTree, the properties of which will be used to set
# # the properties of the restored arrays.
# mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree))
# %%

View File

@ -118,7 +118,7 @@ predict_with_generate = True
# Initialize our prediction # Initialize our prediction
rng = jax.random.PRNGKey(seed) rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng) # rng, dropout_rng = jax.random.split(rng)
print("preparing data") print("preparing data")
data_config = ConfigDict( data_config = ConfigDict(
@ -130,11 +130,6 @@ data_config = ConfigDict(
) )
dataprep = DataPrepare(test_dataset, data_config) dataprep = DataPrepare(test_dataset, data_config)
# # example usage
# # %%
seed = 117
rng = jax.random.PRNGKey(seed)
# %% # %%
# Ensure model.params is properly initialized (this is just an example) # Ensure model.params is properly initialized (this is just an example)
@ -186,6 +181,8 @@ for _ in tqdm(range(pred_steps), desc="Predicting..."):
# generation # generation
# pad_shard_unpad is useful for calling a pmaped function with inputs that
# arent divisible by the number of devices.
generated_ids = pad_shard_unpad(p_generate_step)(replicated_params, batch) generated_ids = pad_shard_unpad(p_generate_step)(replicated_params, batch)
pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"]))) pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
pred_labels.extend(labels) pred_labels.extend(labels)

View File

@ -1,610 +0,0 @@
# %%
import os
# Set this to True to run the model on CPU only.
USE_CPU_ONLY = False
flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices
# Enforce CPU-only execution
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["JAX_PLATFORMS"] = "cpu"
else:
# GPU flags
flags = (
'--xla_gpu_enable_triton_softmax_fusion=true '
'--xla_gpu_triton_gemm_any=True '
# '--xla_gpu_enable_async_collectives=true '
'--xla_gpu_enable_latency_hiding_scheduler=true '
'--xla_gpu_enable_highest_priority_async_stream=true '
)
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
os.environ["XLA_FLAGS"] = flags
os.environ.update({
"TOKENIZERS_PARALLELISM" : "false",
"CUDA_DEVICE_MAX_CONNECTIONS" : "1",
"NCCL_LL128_BUFFSIZE": "-2",
"NCCL_LL_BUFFSIZE": "-2",
"NCCL_PROTO": "SIMPLE,LL,LL128",
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.90",
# "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
})
import functools
from functools import partial
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence, Dict, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
# from jax.experimental.pjit import pjit # superseded by jax.jit
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec
from ml_collections import ConfigDict
import optax
import logging
import time
from datasets import Dataset, load_from_disk
from flax import jax_utils, traverse_util
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.core.frozen_dict import freeze, unfreeze, FrozenDict
import flax.core
# model checkpointing and saving utilities
from flax import linen as nn
from flax.training import checkpoints, train_state
from flax import struct, serialization
from parallel.partitions import set_partitions
from tqdm import tqdm
from parallel.dataload import DataPrepare
# for memory tracking
# from jax_smi import initialise_tracking
# initialise_tracking()
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
if USE_CPU_ONLY:
jax.config.update('jax_platform_name', 'cpu')
else:
jax.config.update("jax_default_matmul_precision", "bfloat16")
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
# %%
## get platform type
from jax.extend.backend import get_backend
print(get_backend().platform)
print(jax.devices())
# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/'
save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/simple_test/'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])
# Store some constant
seed = 117
num_epochs = 5
batch_size = 64
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
warmup_steps = 0
learning_rate = 5e-5
weight_decay = 0.01
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0
num_beams = 1
val_max_target_length = 128
predict_with_generate = True
# %%
# prepare data
# init object
# e.g. Config
print("preparing data")
data_config = ConfigDict(
dict(
max_length=128,
pad_token_id=0,
decoder_start_token_id=0
)
)
dataprep = DataPrepare(split_datasets['train'], data_config)
# # example usage
# # %%
seed = 117
rng = jax.random.PRNGKey(seed)
train_loader = dataprep.data_loader(rng, batch_size=batch_size)
batch = next(iter(train_loader))
# %%
# from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom
from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model
main_model = custom_model.from_pretrained(
"t5-base",
dtype=jnp.bfloat16,
gradient_checkpointing=True,
)
params = main_model.params
# pretrained_params = model.params
model = main_model.module
# %%
# # testing config hashability
# # some explanation:
# # The PreTrainedModel class loads a T5Config model that is not hashable because
# # it is a complicated class that pretends to be a dataclass.
# # The solution is to extract a dict from it, then make a ConfigDict from
# # ml_collections library so that we can get values via the "." operator.
# # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to
# # modify the config before passing to the next layer
# from transformers import T5Config
# from t5_model.configuration_t5 import FrozenT5Config
# from ml_collections import ConfigDict, FrozenConfigDict
#
# config = T5Config.from_pretrained("t5-base").to_dict()
# config.pop('architectures')
# config.pop('id2label')
# # test if it works
# frozen_config = FrozenConfigDict(config)
# # test hash
# hash(frozen_config)
# %%
# # print model
# rng, input_rng = jax.random.split(rng)
# model.tabulate(
# input_rng,
# input_ids=batch['input_ids'],
# attention_mask=batch['attention_mask'],
# decoder_input_ids=batch['decoder_input_ids'],
# decoder_attention_mask=batch['decoder_attention_mask'],
# console_kwargs={"force_jupyter": True}
# )
# %%
# create mesh
print("creating mesh")
device_mesh = mesh_utils.create_device_mesh((2,2))
print(device_mesh)
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
return NamedSharding(mesh, pspec, memory_kind="device")
x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis
model_sharding=mesh_sharding(PartitionSpec(None, 'model'))
# %%
# optimizers
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
return schedule_fn
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
training_size,
train_batch_size,
num_train_epochs,
warmup_steps,
learning_rate,
)
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=adam_beta1,
b2=adam_beta2,
eps=adam_epsilon,
weight_decay=weight_decay,
mask=decay_mask_fn,
)
# %%
print("compile")
# enable bf16
# enable only for dense, some transformer sections, and shared
def create_mask_for_layer_norm(params):
flat_params = traverse_util.flatten_dict(params)
mask = {
# path: not (
# (path[-2] == "layer_norm" and path[-1] == "weight") or
# (path[-2] == "final_layer_norm" and path[-1] == "weight") or
# (path[-2] == "o" and path[-1] == "kernel")
# )
# for path in flat_params
path: (
(path[-2] == "wi" and path[-1] == "weight") or
(path[-2] == "wo" and path[-1] == "weight") or
(path[-2] == "k" and path[-1] == "kernel") or
(path[-2] == "q" and path[-1] == "kernel") or
(path[-2] == "v" and path[-1] == "kernel") or
(path[-2] == "shared" and path[-1] == "embedding")
) for path in flat_params
}
mask = traverse_util.unflatten_dict(mask)
return mask
# borrowed from transformers modeling_flax_utils
def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
"""
Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
"""
# taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
def conditional_cast(param):
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
param = param.astype(dtype)
return param
if mask is None:
return jax.tree_util.tree_map(conditional_cast, params)
flat_params = traverse_util.flatten_dict(params)
flat_mask, _ = jax.tree_util.tree_flatten(mask)
for masked, key in zip(flat_mask, sorted(flat_params.keys())):
if masked:
flat_params[key] = conditional_cast(flat_params[key])
return traverse_util.unflatten_dict(flat_params)
# create init_fn to produce sharded state
def init_fn(params, model, optimizer):
# do be careful with the model init
# imported models might have complicated init methods
# mask = create_mask_for_layer_norm(params)
# override params with bfloat version
# params= cast_floating_to(params, jnp.bfloat16, mask)
state = train_state.TrainState.create( # Create a `TrainState`.
apply_fn=model.apply,
params=params,
tx=optimizer)
return state
abstract_variables = jax.eval_shape(
functools.partial(init_fn, model=model, optimizer=adamw), params)
# jax.sharding: describes how a jax.Array is laid out across devices
state_sharding = nn.get_sharding(abstract_variables, mesh)
# print(state_sharding)
# %%
# replace the params tree with the new modified tree
# create partitions for model
from parallel.partitions import set_partitions
# set_partitions freezes the params on return
model_part_spec = set_partitions(unfreeze(params))
# p is already a partition spec
model_named_sharding = jax.tree.map(lambda p: mesh_sharding(p), model_part_spec)
# get pspec for opt_state
def get_opt_spec(x):
if isinstance(x, dict):
return unfreeze(model_named_sharding)
# return an empty partspec
return mesh_sharding((PartitionSpec()))
# this function replaces the empty model params spec with the 'model_named_shard'
state_sharding = jax.tree.map(
get_opt_spec, state_sharding, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
)
jit_init_fn = jax.jit(
init_fn,
static_argnames=('model', 'optimizer'), # skip model and optimizer
in_shardings=mesh_sharding(PartitionSpec()), # we don't shard params explicitly
out_shardings=state_sharding # but returned initialized_state is sharded
)
initialized_state = jit_init_fn(params, model, adamw)
# %%
# train step
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
"""
The label smoothing implementation is adapted from Flax's official example:
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
"""
vocab_size = logits.shape[-1]
confidence = 1.0 - label_smoothing_factor
low_confidence = (1.0 - confidence) / (vocab_size - 1)
normalizing_constant = -(
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
)
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
logits = jnp.asarray(logits, dtype=jnp.float32)
logits = logits.astype(jnp.float32)
soft_labels = soft_labels.astype(jnp.float32)
loss = optax.softmax_cross_entropy(logits, soft_labels)
loss = loss - normalizing_constant
# ignore padded tokens from loss
loss = loss * padding_mask
loss = loss.mean()
# num_labels = padding_mask.mean()
return loss # , num_labels
# %%
# gradient accumulation
def accumulate_gradients_loop(
state,
batch,
minibatch_size: int,
loss_fn: Callable,
) -> Tuple[PyTree, Metrics]:
"""Calculate gradients and metrics for a batch using gradient accumulation.
Args:
state: Current training state.
batch: Full training batch.
rng: Random number generator to use.
num_minibatches: Number of minibatches to split the batch into. Equal to the number of gradient accumulation steps.
loss_fn: Loss function to calculate gradients and metrics.
Returns:
Tuple with accumulated gradients and metrics over the minibatches.
"""
batch_size = batch['input_ids'].shape[0]
# minibatch_size = batch_size // num_minibatches
num_minibatches = batch_size // minibatch_size
# Define gradient function for single minibatch.
# If has_aux is True then a tuple of ((value, auxiliary_data), gradient) is returned.
# otherwise it returns (value, gradient), where value is the actual output
# of the function, hence the "value" of the namesake
grad_fn = jax.value_and_grad(loss_fn, has_aux=False)
# Prepare loop variables.
grads = None
metrics = None
for minibatch_idx in range(num_minibatches):
with jax.named_scope(f"minibatch_{minibatch_idx}"):
# Split the batch into minibatches.
start = minibatch_idx * minibatch_size
end = start + minibatch_size
minibatch = jax.tree.map(lambda x: x[start:end], batch) # noqa: B023
# Calculate gradients and metrics for the minibatch.
# missing value is mean loss of batch
loss, step_grads = grad_fn(
state.params, minibatch
)
with jax.named_scope("sync_metrics"):
step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
# Accumulate gradients and metrics across minibatches.
if grads is None:
grads = step_grads
metrics = step_metrics
else:
# accumulation adder
grads = jax.tree.map(jnp.add, grads, step_grads)
metrics = jax.tree.map(jnp.add, metrics, step_metrics)
# Average gradients over minibatches.
grads = jax.tree.map(lambda g: g / num_minibatches, grads)
return grads, metrics
# single device code annotated with jax.jit
@functools.partial(
jax.jit,
# state is state_sharding initialized from init_fn
# x_sharding is data sharded explicitly later
in_shardings=(state_sharding, x_sharding),
out_shardings=(state_sharding, mesh_sharding(PartitionSpec())),
donate_argnames=('state'),
)
def train_step(state, batch):
label_smoothing_factor=0.0
# dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
def compute_loss(params, batch):
# check constraints
# frozen dict not allowed as sharding object
params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding))
batch = jax.lax.with_sharding_constraint(batch, x_sharding)
logits = state.apply_fn(
{'params': params},
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
decoder_input_ids=batch['decoder_input_ids'],
decoder_attention_mask=batch['decoder_attention_mask'],
)[0] # zero because output is some structure, where first is the logit
# use labels here
# loss, num_labels = loss_fn(
loss = loss_fn(
logits,
batch["labels"],
batch["decoder_attention_mask"],
label_smoothing_factor)
return loss # , num_labels
# # compute gradients through computational graph
# # allow values to pass through
# grad_fn = jax.value_and_grad(compute_loss, has_aux=False)
# (loss), grad = grad_fn(state.params, batch)
# # num_labels = jax.lax.psum(num_labels, "batch")
# new_state = state.apply_gradients(grads=grad)
# with jax.named_scope("sync_metrics"):
# step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
# use gradient accumulation
grads, step_metrics = accumulate_gradients_loop(
state=state,
batch=batch,
minibatch_size=32,
loss_fn=compute_loss
)
new_state = state.apply_gradients(grads=grads)
return new_state, step_metrics
# %%
# explore data sharding
sharded_batch = next(iter(train_loader))
sharded_batch = jax.device_put(sharded_batch, x_sharding)
# jax.debug.visualize_array_sharding(sharded_batch['input_ids'])
# jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding'])
# %%
# # prep 1 step
# print("1 step for jit-ting")
# with mesh:
# state, metrics = train_step(initialized_state, sharded_batch)
# %%
# %%
# tr
print("***** Running training *****")
print(f" Num examples = {training_size}")
print(f" Num Epochs = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
print(f" Total optimization steps = {total_train_steps}")
# %%
# jax.profiler.start_trace("./traces")
print("*" * 10)
print("training start")
rng, input_rng = jax.random.split(rng)
train_time = 0
state = initialized_state
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
train_start = time.time()
# Create sampling rng
train_metrics = []
steps_per_epoch = training_size // train_batch_size
train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True)
# Generate an epoch by shuffling sampling indices from the train dataset
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
batch = next(train_loader)
batch = jax.device_put(batch, x_sharding)
with mesh:
state, train_metric = train_step(state, batch)
# train_metrics.append(train_metric)
# this is for more accurate time stats, but slows down training
# train_metric['loss'].block_until_ready()
train_time = time.time() - train_start
epochs.write(
f"Epoch... ({epoch + 1}/{num_epochs} | "
f"Loss: {train_metric['loss']}, "
f"Learning Rate:{train_metric['learning_rate']}, "
f"Last train time: {train_time})"
)
# jax.profiler.stop_trace()
# %%
# try out
gather_state = jax.device_get(state)
gather_batch = jax.device_get(batch)
logits = gather_state.apply_fn(
{'params': gather_state.params},
input_ids=gather_batch['input_ids'],
attention_mask=gather_batch['attention_mask'],
decoder_input_ids=gather_batch['decoder_input_ids'],
decoder_attention_mask=gather_batch['decoder_attention_mask'],
)[0] # zero because output is some structure, where first is the logit
probs = nn.softmax(logits, axis=-1)
predicted = jnp.argmax(probs, axis=-1)
print("sample output")
print(predicted[1])
# %%
main_model = custom_model.from_pretrained('t5-base')
output_dir = save_path
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(state.params)
main_model.save_pretrained(output_dir, params=params)
# %%

View File

@ -1,550 +0,0 @@
# %%
import os
# Set this to True to run the model on CPU only.
USE_CPU_ONLY = False
flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices
# Enforce CPU-only execution
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["JAX_PLATFORMS"] = "cpu"
else:
# GPU flags
flags = (
'--xla_gpu_enable_triton_softmax_fusion=true '
'--xla_gpu_triton_gemm_any=True '
# '--xla_gpu_enable_async_collectives=true '
'--xla_gpu_enable_latency_hiding_scheduler=true '
'--xla_gpu_enable_highest_priority_async_stream=true '
)
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
os.environ["XLA_FLAGS"] = flags
os.environ.update({
"TOKENIZERS_PARALLELISM" : "false",
"CUDA_DEVICE_MAX_CONNECTIONS" : "1",
"NCCL_LL128_BUFFSIZE": "-2",
"NCCL_LL_BUFFSIZE": "-2",
"NCCL_PROTO": "SIMPLE,LL,LL128",
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.5",
# "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
})
import functools
from functools import partial
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence, Dict, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
# from jax.experimental.pjit import pjit # superseded by jax.jit
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec
from ml_collections import ConfigDict
import optax
import logging
import time
from datasets import Dataset, load_from_disk
from flax import jax_utils, traverse_util
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.core.frozen_dict import freeze, unfreeze, FrozenDict
import flax.core
# model checkpointing and saving utilities
from flax import linen as nn
from flax.training import checkpoints, train_state
from flax import struct, serialization
from parallel.partitions import set_partitions
from tqdm import tqdm
from parallel.dataload import DataPrepare
# for memory tracking
# from jax_smi import initialise_tracking
# initialise_tracking()
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
if USE_CPU_ONLY:
jax.config.update('jax_platform_name', 'cpu')
else:
jax.config.update("jax_default_matmul_precision", "bfloat16")
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
# %%
## get platform type
from jax.extend.backend import get_backend
print(get_backend().platform)
print(jax.devices())
# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/'
save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/shmap/'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])
# Store some constant
seed = 117
num_epochs = 40
batch_size = 32 # do not go beyond 128, 64 is good
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
warmup_steps = 0
learning_rate = 2e-5
weight_decay = 0.01
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0
num_beams = 1
val_max_target_length = 128
predict_with_generate = True
# %%
# prepare data
# init object
# e.g. Config
print("preparing data")
data_config = ConfigDict(
dict(
max_length=128,
pad_token_id=0,
decoder_start_token_id=0
)
)
dataprep = DataPrepare(split_datasets['train'], data_config)
# # example usage
# # %%
seed = 117
rng = jax.random.PRNGKey(seed)
train_loader = dataprep.data_loader(rng, batch_size=batch_size)
batch = next(iter(train_loader))
# %%
# from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom
from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model
main_model = custom_model.from_pretrained(
"t5-base",
dtype=jnp.bfloat16,
gradient_checkpointing=True,
)
params = main_model.params
# pretrained_params = model.params
model = main_model.module
# %%
# # testing config hashability
# # some explanation:
# # The PreTrainedModel class loads a T5Config model that is not hashable because
# # it is a complicated class that pretends to be a dataclass.
# # The solution is to extract a dict from it, then make a ConfigDict from
# # ml_collections library so that we can get values via the "." operator.
# # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to
# # modify the config before passing to the next layer
# from transformers import T5Config
# from t5_model.configuration_t5 import FrozenT5Config
# from ml_collections import ConfigDict, FrozenConfigDict
#
# config = T5Config.from_pretrained("t5-base").to_dict()
# config.pop('architectures')
# config.pop('id2label')
# # test if it works
# frozen_config = FrozenConfigDict(config)
# # test hash
# hash(frozen_config)
# %%
# # print model
# rng, input_rng = jax.random.split(rng)
# model.tabulate(
# input_rng,
# input_ids=batch['input_ids'],
# attention_mask=batch['attention_mask'],
# decoder_input_ids=batch['decoder_input_ids'],
# decoder_attention_mask=batch['decoder_attention_mask'],
# console_kwargs={"force_jupyter": True}
# )
# %%
# create mesh
print("creating mesh")
device_mesh = mesh_utils.create_device_mesh((2,2))
print(device_mesh)
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
if USE_CPU_ONLY:
return NamedSharding(mesh, pspec, memory_kind="unpinned_host")
else:
# if gpu
return NamedSharding(mesh, pspec, memory_kind="device")
x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis
model_sharding=mesh_sharding(PartitionSpec(None, 'model'))
# %%
# optimizers
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
return schedule_fn
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
training_size,
train_batch_size,
num_train_epochs,
warmup_steps,
learning_rate,
)
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["final_layer_norm", "layer_norm"]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=adam_beta1,
b2=adam_beta2,
eps=adam_epsilon,
weight_decay=weight_decay,
mask=decay_mask_fn,
)
# %%
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
"""
The label smoothing implementation is adapted from Flax's official example:
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
"""
vocab_size = logits.shape[-1]
confidence = 1.0 - label_smoothing_factor
low_confidence = (1.0 - confidence) / (vocab_size - 1)
normalizing_constant = -(
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
)
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
logits = jnp.asarray(logits, dtype=jnp.float32)
logits = logits.astype(jnp.float32)
soft_labels = soft_labels.astype(jnp.float32)
loss = optax.softmax_cross_entropy(logits, soft_labels)
loss = loss - normalizing_constant
# ignore padded tokens from loss
loss = loss * padding_mask
mean_loss = loss.mean()
# num_labels = padding_mask.mean()
return mean_loss # , num_labels
# %%
################################################################
# old jit in_shardings method
# create init_fn to produce sharded state
def init_fn(params, model, optimizer):
# do be careful with the model init
# imported models might have complicated init methods
# mask = create_mask_for_layer_norm(params)
# override params with bfloat version
# params= cast_floating_to(params, jnp.bfloat16, mask)
state = train_state.TrainState.create( # Create a `TrainState`.
apply_fn=model.apply,
params=params,
tx=optimizer)
return state
abstract_variables = jax.eval_shape(
functools.partial(init_fn, model=model, optimizer=adamw), params)
# jax.sharding: describes how a jax.Array is laid out across devices
state_sharding = nn.get_sharding(abstract_variables, mesh)
from parallel.partitions import set_partitions
# set_partitions freezes the params on return
model_part_spec = set_partitions(unfreeze(params))
# p is already a partition spec
model_named_sharding = jax.tree.map(lambda p: mesh_sharding(p), model_part_spec)
# get pspec for opt_state
def get_opt_spec(x):
if isinstance(x, dict):
return unfreeze(model_named_sharding)
# return an empty partspec
return mesh_sharding((PartitionSpec()))
# this function replaces the empty model params spec with the 'model_named_shard'
state_sharding = jax.tree.map(
get_opt_spec, state_sharding, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
)
# %%
jit_init_fn = jax.jit(
init_fn,
static_argnames=('model', 'optimizer'), # skip model and optimizer
in_shardings=mesh_sharding(PartitionSpec()), # we don't shard params explicitly
out_shardings=state_sharding # but returned initialized_state is sharded
)
initialized_state = jit_init_fn(params, model, adamw)
# %%
# train step
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
"""
The label smoothing implementation is adapted from Flax's official example:
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
"""
vocab_size = logits.shape[-1]
confidence = 1.0 - label_smoothing_factor
low_confidence = (1.0 - confidence) / (vocab_size - 1)
normalizing_constant = -(
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
)
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
logits = jnp.asarray(logits, dtype=jnp.float32)
logits = logits.astype(jnp.float32)
soft_labels = soft_labels.astype(jnp.float32)
loss = optax.softmax_cross_entropy(logits, soft_labels)
loss = loss - normalizing_constant
# ignore padded tokens from loss
loss = loss * padding_mask
loss = loss.mean()
# num_labels = padding_mask.mean()
return loss # , num_labels
# %%
# single device code annotated with jax.jit
@functools.partial(
jax.jit,
# state is state_sharding initialized from init_fn
# x_sharding is data sharded explicitly later
in_shardings=(state_sharding, x_sharding),
out_shardings=(state_sharding, mesh_sharding(PartitionSpec())),
donate_argnames=('state'),
)
def train_step(state, batch):
label_smoothing_factor=0.0
# dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
# computes loss per shard
def compute_loss(params, batch):
# check constraints
# frozen dict not allowed as sharding object
params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding))
batch = jax.lax.with_sharding_constraint(batch, x_sharding)
logits = state.apply_fn(
{'params': params},
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
decoder_input_ids=batch['decoder_input_ids'],
decoder_attention_mask=batch['decoder_attention_mask'],
)[0] # zero because output is some structure, where first is the logit
# logits sharding
# data, None, model
#
print("logits")
jax.debug.inspect_array_sharding(logits, callback=print)
# use labels here
# loss, num_labels = loss_fn(
loss = loss_fn(
logits,
batch["labels"],
batch["decoder_attention_mask"],
label_smoothing_factor)
# loss sharding
# it gives PartitionSpec(), which implies a reduction already happened
print("loss")
jax.debug.inspect_array_sharding(loss, callback=print)
return loss # , num_labels
# compute gradients through computational graph
# allow values to pass through
grad_fn = jax.value_and_grad(compute_loss, has_aux=False)
batch = jax.tree.map(lambda x: jax.lax.with_sharding_constraint(x, x_sharding), batch)
(loss), grads = grad_fn(state.params, batch)
# num_labels = jax.lax.psum(num_labels, "batch")
# so far we have been operating from within each shard
# we need to sync gradients across devices
# we bring all gradients together onto a single device
# jax.debug.inspect_array_sharding(grads, callback=print)
grads = jax.lax.with_sharding_constraint(grads, mesh_sharding(PartitionSpec()))
# grads = jax.lax.with_sharding_constraint(grads, state_sharding)
# jax.debug.visualize_array_sharding(grad)
# jax.debug.inspect_array_sharding(grad, callback=print)
# check the output grad tree from mean
# print(jax.tree.map(jnp.shape, grad))
new_state = state.apply_gradients(grads=grads)
with jax.named_scope("sync_metrics"):
step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
return new_state, step_metrics
# %%
# explore data sharding
sharded_batch = next(iter(train_loader))
# sharded_batch = jax.device_put(sharded_batch, x_sharding)
sharded_batch = jax.tree.map(lambda x: jax.lax.with_sharding_constraint(x, x_sharding), batch)
jax.debug.visualize_array_sharding(sharded_batch['input_ids'])
# jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding'])
# %%
# # prep 1 step
print("1 step for jit-ting")
with mesh:
state, metrics = train_step(initialized_state, sharded_batch)
# %%
# %%
# tr
print("***** Running training *****")
print(f" Num examples = {training_size}")
print(f" Num Epochs = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
print(f" Total optimization steps = {total_train_steps}")
# %%
# jax.profiler.start_trace("./traces")
print("*" * 20)
print("training start")
rng, input_rng = jax.random.split(rng)
train_time = 0
state = initialized_state
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
train_start = time.time()
# Create sampling rng
train_metrics = []
steps_per_epoch = training_size // train_batch_size
train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True)
# Generate an epoch by shuffling sampling indices from the train dataset
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
batch = next(train_loader)
batch = jax.device_put(batch, x_sharding)
with mesh:
state, train_metric = train_step(state, batch)
# train_metrics.append(train_metric)
# this is for more accurate time stats, but slows down training
# train_metric['loss'].block_until_ready()
train_time = time.time() - train_start
epochs.write(
f"Epoch... ({epoch + 1}/{num_epochs} | "
f"Loss: {train_metric['loss']}, "
f"Learning Rate:{train_metric['learning_rate']}, "
f"Last train time: {train_time})"
)
# jax.profiler.stop_trace()
# %%
# try out
# gather_state = jax.device_get(state)
# gather_batch = jax.device_get(batch)
# logits = gather_state.apply_fn(
# {'params': gather_state.params},
# input_ids=gather_batch['input_ids'],
# attention_mask=gather_batch['attention_mask'],
# decoder_input_ids=gather_batch['decoder_input_ids'],
# decoder_attention_mask=gather_batch['decoder_attention_mask'],
# )[0] # zero because output is some structure, where first is the logit
#
# probs = nn.softmax(logits, axis=-1)
# predicted = jnp.argmax(probs, axis=-1)
# print(predicted[0])
# %%
main_model = custom_model.from_pretrained('t5-base')
output_dir = save_path
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(state.params)
params = jax.tree.map(lambda x: x.astype(jnp.float32), params)
main_model.save_pretrained(output_dir, params=params)
# %%

View File

@ -12,15 +12,16 @@ if USE_CPU_ONLY:
else: else:
# GPU flags # GPU flags
flags = ( flags = (
'--xla_gpu_enable_triton_softmax_fusion=true ' # '--xla_gpu_enable_custom_fusions=true '
'--xla_gpu_triton_gemm_any=True ' '--xla_gpu_triton_gemm_any=True '
# '--xla_gpu_enable_async_collectives=true ' # '--xla_gpu_enable_async_collectives=true '
'--xla_gpu_enable_latency_hiding_scheduler=true ' '--xla_gpu_enable_latency_hiding_scheduler=true '
'--xla_gpu_enable_highest_priority_async_stream=true ' '--xla_gpu_enable_highest_priority_async_stream=true '
'--xla_gpu_enable_pipelined_all_reduce=true '
'--xla_gpu_enable_nccl_user_buffers=true '
) )
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
os.environ["XLA_FLAGS"] = flags
os.environ.update({ os.environ.update({
"TOKENIZERS_PARALLELISM" : "false", "TOKENIZERS_PARALLELISM" : "false",
"CUDA_DEVICE_MAX_CONNECTIONS" : "1", "CUDA_DEVICE_MAX_CONNECTIONS" : "1",
@ -28,9 +29,10 @@ os.environ.update({
"NCCL_LL_BUFFSIZE": "-2", "NCCL_LL_BUFFSIZE": "-2",
"NCCL_PROTO": "SIMPLE,LL,LL128", "NCCL_PROTO": "SIMPLE,LL,LL128",
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.80", "XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.80",
"NCCL_NVLS_ENABLE": "1",
# "XLA_PYTHON_CLIENT_PREALLOCATE" : "false" # "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
}) })
os.environ["XLA_FLAGS"] = flags
@ -69,7 +71,7 @@ from parallel.partitions import set_partitions
from tqdm import tqdm from tqdm import tqdm
from parallel.dataload import DataPrepare from dataload import DataPrepare
# for memory tracking # for memory tracking
# from jax_smi import initialise_tracking # from jax_smi import initialise_tracking
@ -114,7 +116,7 @@ model_sharding=mesh_sharding(PartitionSpec(None, 'model'))
# %% # %%
# config options # config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/' file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_simple/'
save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/simple/' save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/simple/'
# file_path = 'combined_data' # file_path = 'combined_data'
split_datasets = load_from_disk(file_path) split_datasets = load_from_disk(file_path)
@ -122,12 +124,12 @@ training_size = len(split_datasets['train'])
# Store some constant # Store some constant
seed = 117 seed = 117
num_epochs = 40 num_epochs = 40
batch_size = 128 batch_size = 64
num_train_epochs = num_epochs num_train_epochs = num_epochs
per_device_train_batch_size = batch_size per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * 2 train_batch_size = per_device_train_batch_size * mesh.shape['data']
per_device_eval_batch_size = batch_size per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * 2 eval_batch_size = per_device_eval_batch_size * mesh.shape['data']
steps_per_epoch = training_size // train_batch_size steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs total_train_steps = steps_per_epoch * num_epochs
@ -394,13 +396,29 @@ def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
# ignore padded tokens from loss # ignore padded tokens from loss
loss = loss * padding_mask loss = loss * padding_mask
loss = jax.lax.with_sharding_constraint(loss, x_sharding)
loss = loss.mean() loss = loss.mean()
# num_labels = padding_mask.mean() # num_labels = padding_mask.mean()
return loss # , num_labels return loss # , num_labels
# %% # %%
# def extract_spec(sharding_obj):
# # Check if the object is a NamedSharding instance
# if isinstance(sharding_obj, NamedSharding):
# return sharding_obj.spec # Return the spec if it is a NamedSharding
# return sharding_obj # Return the object itself if not
#
# state_sharding_spec = jax.tree.map(extract_spec, state_sharding)
# single device code annotated with jax.jit # single device code annotated with jax.jit
# @partial(
# shard_map,
# mesh=mesh,
# in_specs=(state_sharding_spec,
# x_sharding.spec),
# out_specs=(state_sharding_spec, PartitionSpec()),
# check_rep=False,
# )
@functools.partial( @functools.partial(
jax.jit, jax.jit,
# state is state_sharding initialized from init_fn # state is state_sharding initialized from init_fn
@ -418,6 +436,8 @@ def train_step(state, batch):
# frozen dict not allowed as sharding object # frozen dict not allowed as sharding object
params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding)) params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding))
batch = jax.lax.with_sharding_constraint(batch, x_sharding) batch = jax.lax.with_sharding_constraint(batch, x_sharding)
# require decoder_input_ids to simulate auto-regressive output during
# generation time
logits = state.apply_fn( logits = state.apply_fn(
{'params': params}, {'params': params},
input_ids=batch['input_ids'], input_ids=batch['input_ids'],
@ -439,7 +459,7 @@ def train_step(state, batch):
grad_fn = jax.value_and_grad(compute_loss, has_aux=False) grad_fn = jax.value_and_grad(compute_loss, has_aux=False)
(loss), grad = grad_fn(state.params, batch) (loss), grad = grad_fn(state.params, batch)
# num_labels = jax.lax.psum(num_labels, "batch") # num_labels = jax.lax.psum(num_labels, "batch")
# loss, grad = jax.lax.pmean((loss, grad), axis_name="data")
new_state = state.apply_gradients(grads=grad) new_state = state.apply_gradients(grads=grad)
with jax.named_scope("sync_metrics"): with jax.named_scope("sync_metrics"):
@ -447,15 +467,12 @@ def train_step(state, batch):
return new_state, step_metrics return new_state, step_metrics
# %% # # %%
# explore data sharding # # explore data sharding
# sharded_batch = next(iter(train_loader)) # sharded_batch = next(iter(train_loader))
# sharded_batch = jax.device_put(sharded_batch, x_sharding) # sharded_batch = jax.device_put(sharded_batch, x_sharding)
# jax.debug.visualize_array_sharding(sharded_batch['input_ids']) # jax.debug.visualize_array_sharding(sharded_batch['input_ids'])
# jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding']) # jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding'])
# %%
# # prep 1 step # # prep 1 step
# print("1 step for jit-ting") # print("1 step for jit-ting")
# with mesh: # with mesh:
@ -476,6 +493,8 @@ print(f" Total optimization steps = {total_train_steps}")
# %% # %%
# jax.profiler.start_trace("./traces") # jax.profiler.start_trace("./traces")
# jit_train_step = jax.jit(train_step)
print("*" * 50) print("*" * 50)
print("training start") print("training start")

View File

@ -289,6 +289,7 @@ class FlaxT5Attention(nn.Module):
def _merge_heads(self, hidden_states): def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,)) return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,))
# i suspect we are threading state here
@nn.compact @nn.compact
def _concatenate_to_cache(self, key, value, query, attention_mask): def _concatenate_to_cache(self, key, value, query, attention_mask):
""" """
@ -298,10 +299,17 @@ class FlaxT5Attention(nn.Module):
""" """
# detect if we're initializing by absence of existing cache data. # detect if we're initializing by absence of existing cache data.
is_initialized = self.has_variable("cache", "cached_key") is_initialized = self.has_variable("cache", "cached_key")
# Variables are identified by a collection (e.g., "batch_stats") and a name
# (e.g., "moving_mean"). The value property gives access to the variable's
# content and can be assigned to for mutation.
#
# self.variable either 1.) initializes values for the first time
# 2.) retrieves the variable and does not override
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
# only run if initialized before
if is_initialized: if is_initialized:
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
# update key, value caches with our new 1d spatial slices # update key, value caches with our new 1d spatial slices
@ -688,7 +696,7 @@ class FlaxT5BlockCollection(nn.Module):
position_bias = None position_bias = None
encoder_decoder_position_bias = None encoder_decoder_position_bias = None
for i, layer_module in enumerate(self.blocks): for _, layer_module in enumerate(self.blocks):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
@ -933,7 +941,7 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
config_class = T5Config config_class = T5Config
base_model_prefix = "transformer" base_model_prefix = "transformer"
module_class: nn.Module = None module_class: nn.Module = None # to be overriden by subclass
def __init__( def __init__(
self, self,

View File

@ -1,360 +0,0 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.16.4
# ---
# %% [markdown]
# # prediction code
# ## import and process test data
# %%
# import libraries
import pandas as pd
import matplotlib.pyplot as plt
from datasets import Dataset, DatasetDict
import jax
import jax.numpy as jnp
import optax
import numpy as np
from functools import partial
from typing import Callable, Optional
import math
# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "high")
jax.config.update("jax_enable_x64", False)
from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
import datasets
from datasets import Dataset
import evaluate
from tqdm import tqdm
import nltk # Here to have a nice missing dependency error message early on
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
import time
# %%
# data_path = f"../make_data/select_db/data_mapping_filtered.csv"
# data_path = f"../make_data_2/select_db/dataset/1/train_all.csv"
data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv'
# data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv'
# Ensure to include 'ships_idx' in the fields list
fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']
# Load the dataset
df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)
def process_df(df):
output_list = [{
'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC>",
# 'input': f"<DESC>{row['tag_description']}<DESC>",
# 'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>",
# 'input': f"<DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>",
'output': f"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>",
# 'answer': f"{row['thing']} {row['property']}",
# 'answer_thing': row['thing'],
# 'answer_property': row['property'],
} for _, row in df.iterrows()]
return output_list
# takes 1 minute to run without batching
test_dataset = Dataset.from_list(process_df(df))
# %% [markdown]
# ## Load model for attributes
# %%
# load model
model_name_or_path = "./t5_80_1" # Replace with your specific model name
# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path)
# Load model
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
pretrained_model_name_or_path=model_name_or_path
)
# %% [markdown]
# ## Tokenizer
# %%
# prepare tokenizer
from transformers import T5TokenizerFast
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
# Define additional special tokens
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
# Add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
max_length = 86
model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
# given a dataset entry, run it through the tokenizer
# Setting padding="max_length" as we need fixed length inputs for jitted functions
def preprocess_function(example):
inputs = example['input']
targets = example['output']
# text_target sets the corresponding label to inputs
# there is no need to create a separate 'labels'
model_inputs = tokenizer(
inputs,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="np"
)
labels = tokenizer(
text_target=targets,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="np"
)
model_inputs["labels"] = labels["input_ids"]
decoder_input_ids = shift_tokens_right_fn(
labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
)
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
# We need decoder_attention_mask so we can ignore pad tokens from loss
model_inputs["decoder_attention_mask"] = labels["attention_mask"]
return model_inputs
# map maps function to each "row" in the dataset
# aka the data in the immediate nesting
test_dataset = test_dataset.map(
preprocess_function,
batched=True,
num_proc=1,
remove_columns=test_dataset.column_names,
)
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
"""
Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
"""
if shuffle:
batch_idx = jax.random.permutation(rng, len(dataset))
batch_idx = np.asarray(batch_idx)
else:
batch_idx = np.arange(len(dataset))
if drop_last:
steps_per_epoch = len(dataset) // batch_size
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
else:
steps_per_epoch = math.ceil(len(dataset) / batch_size)
batch_idx = np.array_split(batch_idx, steps_per_epoch)
for idx in batch_idx:
batch = dataset[idx]
batch = {k: np.array(v) for k, v in batch.items()}
yield batch
# %% [markdown]
# # model generation
# %%
seed = 117
num_epochs = 80
batch_size = 96
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(test_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
num_beams = 1
val_max_target_length = 128
predict_with_generate = True
# Initialize our training
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)
# %%
# reload model to prevent leakage of variables
# load model
model_name_or_path = "t5_80_1_bf16" # Replace with your specific model name
# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path)
# Load model
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
model_name_or_path
)
# Ensure model.params is properly initialized (this is just an example)
# Normally you would get this from a model initialization call with dummy input
params = model.params
# ensure full size floats
params_f16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
# we need to replicate model over devices
replicated_params = jax.device_put_replicated(params_f16, jax.devices())
# Define generation function
max_length = (
val_max_target_length if val_max_target_length is not None else model.config.max_length
)
num_beams = num_beams if num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
def generate_step(params, batch):
output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], params=params, **gen_kwargs)
return output_ids.sequences
# Create parallel version of the train and eval step
p_generate_step = jax.pmap(generate_step, "batch")
pred_generations = []
pred_labels = []
rng, input_rng = jax.random.split(rng)
pred_loader = data_loader(input_rng, test_dataset, eval_batch_size, drop_last=False)
pred_steps = math.ceil(len(test_dataset) / eval_batch_size)
print("***** Running training *****")
print(f" Num examples = {len(test_dataset)}")
print(f" Num steps = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total test batch size (w. parallel & distributed) = {train_batch_size}")
for _ in tqdm(range(pred_steps), desc="Predicting..."):
# Model forward
batch = next(pred_loader)
labels = batch["labels"]
# generation
generated_ids = pad_shard_unpad(p_generate_step)(replicated_params, batch)
pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
pred_labels.extend(labels)
# %% [markdown]
# # process predictions
# %%
# code to get special token ids
# sentence = "<THING_START><THING_END><PROPERTY_START><PROPERTY_END><NAME><DESC><DESC><UNIT>"
# tokens = tokenizer.tokenize(sentence)
# print("Tokens:", tokens)
# # Get the IDs (integer indices) of specific tokens
# token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens]
# print("Token IDs:", token_ids)
# %%
# extract sequence and decode
def extract_seq(tokens, start_value, end_value):
if start_value not in tokens or end_value not in tokens:
return None # Or handle this case according to your requirements
start_id = np.where(tokens == start_value)[0][0]
end_id = np.where(tokens == end_value)[0][0]
return tokens[start_id+1:end_id]
def process_tensor_output(tokens):
thing_seq = extract_seq(tokens, 32100, 32101) # 32100 = <THING_START>, 32101 = <THING_END>
property_seq = extract_seq(tokens, 32102, 32103) # 32102 = <PROPERTY_START>, 32103 = <PROPERTY_END>
p_thing = None
p_property = None
if (thing_seq is not None):
p_thing = tokenizer.decode(thing_seq, skip_special_tokens=False) # retain <COLLIDE>
if (property_seq is not None):
p_property = tokenizer.decode(property_seq, skip_special_tokens=False) # retain <COLLIDE>
return p_thing, p_property
# %%
# decode prediction labels
def decode_preds(tokens_list):
thing_prediction_list = []
property_prediction_list = []
for tokens in tokens_list:
p_thing, p_property = process_tensor_output(tokens)
thing_prediction_list.append(p_thing)
property_prediction_list.append(p_property)
return thing_prediction_list, property_prediction_list
thing_prediction_list, property_prediction_list = decode_preds(pred_generations)
# %%
# add labels too
thing_actual_list, property_actual_list = decode_preds(pred_labels)
# Convert the list to a Pandas DataFrame
df = pd.DataFrame({'p_thing': thing_prediction_list,
'p_property': property_prediction_list,
'thing': thing_actual_list,
'property' : property_actual_list})
df['p_thing_correct'] = df['p_thing'] == df['thing']
df['p_property_correct'] = df['p_property'] == df['property']
# %%
print("thing accuracy", sum(df['p_thing_correct'])/len(df))
print("property accuracy", sum(df['p_property_correct'])/len(df))
print("total accuracy", sum(df['p_property_correct'] & df['p_thing_correct'])/len(df))
# %%
df[~df["p_property_correct"]]
# %%
df['p_thing']
# %%
# Save the DataFrame as a Parquet file (using pyarrow or fastparquet)
# df.to_parquet("exports/output_file.parquet", engine="pyarrow") # or use engine="fastparquet"