# learn_jax/t5_jax_parallel.py
# %%
import os
# Set this to True to run the model on CPU only.
USE_CPU_ONLY = False
flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices
# Enforce CPU-only execution
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["JAX_PLATFORMS"] = "cpu"
else:
# GPU flags
flags = (
'--xla_gpu_enable_triton_softmax_fusion=true '
'--xla_gpu_triton_gemm_any=True '
# '--xla_gpu_enable_async_collectives=true '
'--xla_gpu_enable_latency_hiding_scheduler=true '
'--xla_gpu_enable_highest_priority_async_stream=true '
)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["XLA_FLAGS"] = flags
os.environ.update({
"TOKENIZERS_PARALLELISM" : "false",
"CUDA_DEVICE_MAX_CONNECTIONS" : "1",
"NCCL_LL128_BUFFSIZE": "-2",
"NCCL_LL_BUFFSIZE": "-2",
"NCCL_PROTO": "SIMPLE,LL,LL128",
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.90",
# "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
})
import functools
from functools import partial
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
# from jax.experimental.pjit import pjit # superseded by jax.jit
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec
from ml_collections import ConfigDict
import optax
import logging
import time
from datasets import Dataset, load_from_disk
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.core.frozen_dict import freeze, unfreeze, FrozenDict
import flax.core
# model checkpointing and saving utilities
from flax.training import checkpoints
from flax import struct, serialization
import orbax.checkpoint as ocp
from flax.training import orbax_utils
from parallel.partitions import set_partitions
from tqdm import tqdm
from parallel.dataload import DataPrepare
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
if USE_CPU_ONLY:
    jax.config.update('jax_platform_name', 'cpu')
else:
    jax.config.update("jax_default_matmul_precision", "bfloat16")
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
# %%
## get platform type
from jax.extend.backend import get_backend
print(get_backend().platform)
print(jax.devices())
# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/'
save_path = '/home/richard/Projects/06_research/jax_models/t5_80e_fp32_parallel/'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])
# Store some constants
seed = 117
num_epochs = 5
batch_size = 32
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
warmup_steps = 0
learning_rate = 5e-5
weight_decay = 0.01
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0
num_beams = 1
val_max_target_length = 128
predict_with_generate = True
# %%
# prepare data
# init object
# e.g. Config
print("preparing data")
data_config = ConfigDict(
    dict(
        max_length=128,
        pad_token_id=0,
        decoder_start_token_id=0,
    )
)
dataprep = DataPrepare(split_datasets['train'], data_config)
# # example usage
# # %%
seed = 117
rng = jax.random.PRNGKey(seed)
train_loader = dataprep.data_loader(rng, batch_size=batch_size)
batch = next(iter(train_loader))
# batch
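# %%
# optional sanity check (illustrative): print the shape and dtype of each array in the
# batch, assuming the loader yields a dict of numpy/jax arrays, to confirm the expected
# keys and (batch_size, max_length) shapes before any sharding is applied.
# pprint(jax.tree_util.tree_map(lambda x: (x.shape, x.dtype), batch))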
# %%
# model
# working
# from parallel.t5_model.pure_t5 import FlaxT5ForConditionalGenerationModule as model_init
# # from t5_model.pure_t5 import FlaxT5DenseActDense as model_init
# from parallel.t5_model.pure_t5 import make_config
# config = make_config()
# model = model_init(config=config, dtype=jnp.bfloat16, gradient_checkpointing=True)
# %%
# from transformers import FlaxT5ForConditionalGeneration, T5Config
# model = FlaxT5ForConditionalGeneration.from_pretrained(
# "t5-base",
# dtype=jnp.bfloat16,
# )
# # pretrained_params = model.params
# model = model.module
# %%
# from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom
from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model
main_model = custom_model.from_pretrained(
    "t5-base",
    dtype=jnp.float32,
    # gradient_checkpointing=True,
)
params = main_model.params
# pretrained_params = model.params
model = main_model.module
# %%
# # testing config hashability
# # some explanation:
# # The PreTrainedModel class loads a T5Config model that is not hashable because
# # it is a complicated class that pretends to be a dataclass.
# # The solution is to extract a dict from it, then make a ConfigDict from
# # ml_collections library so that we can get values via the "." operator.
# # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to
# # modify the config before passing to the next layer
# from transformers import T5Config
# from t5_model.configuration_t5 import FrozenT5Config
# from ml_collections import ConfigDict, FrozenConfigDict
#
# config = T5Config.from_pretrained("t5-base").to_dict()
# config.pop('architectures')
# config.pop('id2label')
# # test if it works
# frozen_config = FrozenConfigDict(config)
# # test hash
# hash(frozen_config)
# %%
# %%
# # print model
# rng, input_rng = jax.random.split(rng)
# model.tabulate(
# input_rng,
# input_ids=batch['input_ids'],
# attention_mask=batch['attention_mask'],
# decoder_input_ids=batch['decoder_input_ids'],
# decoder_attention_mask=batch['decoder_attention_mask'],
# console_kwargs={"force_jupyter": True}
# )
# %%
# print model datatype to verify
# rng, input_rng = jax.random.split(rng)
# variables = model.init(
# input_rng,
# input_ids=batch['input_ids'],
# attention_mask=batch['attention_mask'],
# decoder_input_ids=batch['decoder_input_ids'],
# decoder_attention_mask=batch['decoder_attention_mask']
# )
# %%
# create mesh
print("creating mesh")
device_mesh = mesh_utils.create_device_mesh((1,1))
print(device_mesh)
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
    return NamedSharding(mesh, pspec, memory_kind="device")

x_sharding = mesh_sharding(PartitionSpec('data', None))  # shard the batch (first) dimension across the 'data' axis
model_sharding = mesh_sharding(PartitionSpec(None, 'model'))
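# %%
# optional: visualize how a batch-shaped array is laid out under x_sharding. purely an
# illustrative sketch; it assumes batch['input_ids'] is a 2D (batch_size, max_length)
# array and that the optional `rich` dependency used by
# jax.debug.visualize_array_sharding is available.
# _example_array = jax.device_put(jnp.asarray(batch['input_ids']), x_sharding)
# jax.debug.visualize_array_sharding(_example_array)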
# %%
# optimizers
def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear decay learning rate function."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
    training_size,
    train_batch_size,
    num_train_epochs,
    warmup_steps,
    learning_rate,
)
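# quick illustrative sanity check of the schedule: with warmup_steps = 0 it should start
# at learning_rate and decay linearly towards 0 by total_train_steps.
print("lr at start/mid/end:",
      [float(linear_decay_lr_schedule_fn(step))
       for step in (0, total_train_steps // 2, total_train_steps)])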
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    # find out all LayerNorm parameters
    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
    layer_norm_named_params = {
        layer[-2:]
        for layer_norm_name in layer_norm_candidates
        for layer in flat_params.keys()
        if layer_norm_name in "".join(layer).lower()
    }
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)
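# illustrative check of decay_mask_fn on a toy pytree (made-up key names, not real T5
# parameter paths): kernels should map to True (decayed), biases and layer-norm weights
# to False (not decayed).
_toy_params = {
    "dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.ones((2,))},
    "layer_norm": {"weight": jnp.ones((2,))},
}
pprint(decay_mask_fn(_toy_params))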
# create adam optimizer
adamw = optax.adamw(
    learning_rate=linear_decay_lr_schedule_fn,
    b1=adam_beta1,
    b2=adam_beta2,
    eps=adam_epsilon,
    weight_decay=weight_decay,
    mask=decay_mask_fn,
)
print("compile")
# enable bf16 except for layer_norm
def create_mask_for_layer_norm(params):
    flat_params = traverse_util.flatten_dict(params)
    mask = {
        path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
    }
    mask = traverse_util.unflatten_dict(mask)
    return mask
# borrowed from transformers modeling_flax_utils
def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
    """
    Helper method to cast floating-point values of a given parameter `PyTree` to the given `dtype`.
    """
    # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
    def conditional_cast(param):
        if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
            param = param.astype(dtype)
        return param

    if mask is None:
        return jax.tree_util.tree_map(conditional_cast, params)

    flat_params = traverse_util.flatten_dict(params)
    flat_mask, _ = jax.tree_util.tree_flatten(mask)
    for masked, key in zip(flat_mask, sorted(flat_params.keys())):
        if masked:
            flat_params[key] = conditional_cast(flat_params[key])
    return traverse_util.unflatten_dict(flat_params)
# Cast all parameters to bfloat16 if desired
# params = jax.tree.map(lambda x: x.astype(jnp.bfloat16), params)
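# %%
# illustrative use of cast_floating_to on a toy pytree with a hand-written mask (the
# names below are made up): only leaves whose mask entry is True are cast to bfloat16,
# the rest stay in float32.
_toy_tree = {"attn": {"kernel": jnp.ones((2, 2), jnp.float32)},
             "norm": {"weight": jnp.ones((2,), jnp.float32)}}
_toy_mask = {"attn": {"kernel": True}, "norm": {"weight": False}}
pprint(jax.tree_util.tree_map(lambda x: x.dtype, cast_floating_to(_toy_tree, jnp.bfloat16, _toy_mask)))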
# %%
def init_fn(params, model, optimizer):
    # be careful with the model init:
    # imported models might have complicated init methods
    # mask = create_mask_for_layer_norm(params)
    # override params with the bfloat16 version
    # params = cast_floating_to(params, jnp.bfloat16, mask)
    state = train_state.TrainState.create(  # Create a `TrainState`.
        apply_fn=model.apply,
        params=params,
        tx=optimizer)
    return state
# def init_fn(rng, batch, model, optimizer):
# # do be careful with the model init
# # imported models might have complicated init methods
# variables = model.init(
# rng,
# input_ids=batch['input_ids'],
# attention_mask=batch['attention_mask'],
# decoder_input_ids=batch['decoder_input_ids'],
# decoder_attention_mask=batch['decoder_attention_mask']
# )
# params = variables['params']
# mask = create_mask_for_layer_norm(params)
# # override params with bfloat version
# params= cast_floating_to(params, jnp.bfloat16, mask)
#
# state = train_state.TrainState.create( # Create a `TrainState`.
# apply_fn=model.apply,
# params=params,
# tx=optimizer)
# return state
# %%
# Create an abstract closure to wrap the function before feeding it in
# because `jax.eval_shape` only takes pytrees as arguments.
# eval_shape(fn, rng_key, x)
# used to perform shape inference
# returns a nested PyTree containing jax.ShapeDtypeStruct objects as leaves
# rng, init_rng = jax.random.split(rng)
abstract_variables = jax.eval_shape(
    functools.partial(init_fn, model=model, optimizer=adamw), params)
# rng, init_rng = jax.random.split(rng)
# abstract_variables = jax.eval_shape(
# functools.partial(init_fn, model=model, optimizer=adamw), init_rng, batch)
# %%
# This `state_sharding` has the same pytree structure as `state`, the output
# of the `init_fn`.
# flax.linen.get_sharding
# extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh
# jax.sharding: describes how a jax.Array is laid out across devices
state_sharding = nn.get_sharding(abstract_variables, mesh)
# print(state_sharding)
# warning: do not have singleton None in your nn.partition definitions, it will screw with your sanity
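# optional peek at the derived sharding tree (illustrative): each leaf of state_sharding
# is expected to be a jax.sharding.NamedSharding on the ('data', 'model') mesh; with the
# (1, 1) mesh above everything ends up on a single device.
# _first_sharding = jax.tree_util.tree_leaves(state_sharding)[0]
# print(type(_first_sharding), getattr(_first_sharding, "spec", None))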
##################################################
# # %%
# # replace the params tree with the new modified tree
# # create partitions for model
# from parallel.partitions import set_partitions
# # set_partitions freezes the params on return
# model_part_spec = set_partitions(unfreeze(params))
# # p is already a partition spec
# model_named_sharding = jax.tree.map(lambda p: mesh_sharding(p), model_part_spec)
#
# # %%
# # get_shapes = jax.tree.map(jnp.shape, params)
# # actually tuple
# # state_shapes = jax.eval_shape(state_sharding, get_shapes)
#
# # %%
# # get pspec for opt_state
# def get_opt_spec(x):
# if isinstance(x, dict):
# return unfreeze(model_named_sharding)
# # return an empty partspec
# return mesh_sharding((PartitionSpec()))
#
# # this function replaces the empty model params spec with the 'model_named_shard'
# state_sharding = jax.tree.map(
# get_opt_spec, state_sharding, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
# )
# %%
jit_init_fn = jax.jit(
    init_fn,
    static_argnames=('model', 'optimizer'),  # skip model and optimizer
    in_shardings=mesh_sharding(PartitionSpec(())),  # we don't shard params explicitly
    out_shardings=state_sharding,  # but the returned initialized state is sharded
)
initialized_state = jit_init_fn(params, model, adamw)
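# verify (illustrative) that the initialized state carries the requested shardings:
# each parameter leaf is a jax.Array whose .sharding should correspond to an entry of
# state_sharding (a single-device placement on the (1, 1) mesh used here).
_first_param = jax.tree_util.tree_leaves(initialized_state.params)[0]
print(_first_param.shape, _first_param.dtype, _first_param.sharding)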
# jit_init_fn = jax.jit(
# init_fn,
# static_argnames=('model', 'optimizer'), # skip model and optimizer
# in_shardings=(mesh_sharding(()), x_sharding), # for PRNG key and data
# out_shardings=state_sharding
# )
#
#
# rng, init_rng = jax.random.split(rng)
# initialized_state = jit_init_fn(rng, batch, model, adamw)
# %%
# train step
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    """
    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
    """
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing_factor
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    normalizing_constant = -(
        confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    )
    soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
    loss = optax.softmax_cross_entropy(logits, soft_labels)
    loss = loss - normalizing_constant
    # ignore padded tokens when computing the loss
    loss = loss * padding_mask
    loss = loss.sum()
    num_labels = padding_mask.sum()
    return loss, num_labels
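# %%
# tiny illustrative check of loss_fn with made-up numbers: with all-zero logits over a
# vocabulary of 5 and label_smoothing_factor=0.0, each unmasked position contributes
# log(5) ~= 1.609, so the masked sum should be ~3.22 with num_labels = 2.
_toy_logits = jnp.zeros((1, 3, 5))      # (batch, seq_len, vocab)
_toy_labels = jnp.array([[1, 2, 0]])    # last position is padding
_toy_pad_mask = jnp.array([[1, 1, 0]])
print(loss_fn(_toy_logits, _toy_labels, _toy_pad_mask, 0.0))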
# %%
# sharded_loss_fn = jax.jit(
# loss_fn,
# in_shardings=(mesh_sharding('model'), x_sharding), # params partitioned across 'model' axis
# out_shardings=(mesh_sharding('model')), # Loss should be aggregated across 'model'
# )
def gather_and_sum(
    sharded_values,
    in_shardings,
):
    with mesh:
        # gather the sharded values onto a single device
        gathered_values = jax.jit(
            lambda x: x, in_shardings=in_shardings, out_shardings=None
        )(sharded_values)
    # compute the sum of the gathered values
    summed_value = jax.tree.map(lambda x: jnp.sum(x), gathered_values)
    return summed_value
# single device code annotated with jax.jit
@functools.partial(
    jax.jit,
    # state is state_sharding initialized from init_fn
    # x_sharding is data sharded explicitly later
    in_shardings=(state_sharding, x_sharding),
    # return state as state_sharding
    # we do not shard the metrics
    out_shardings=(state_sharding, mesh_sharding(PartitionSpec())),
    donate_argnames=('state',),
)
def train_step(state, batch):
    label_smoothing_factor = 0.0
    # dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

    def compute_loss(params, batch):
        # check constraints
        # frozen dict not allowed as a sharding object
        # params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding))
        # batch = jax.lax.with_sharding_constraint(batch, x_sharding)
        # labels = batch.pop("decoder_input_ids")
        # labels are not used here
        logits = state.apply_fn(
            {'params': params},
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            decoder_input_ids=batch['decoder_input_ids'],
            decoder_attention_mask=batch['decoder_attention_mask'],
        )[0]  # index zero because the output is a structure whose first element is the logits
        # labels are used here
        loss, num_labels = loss_fn(
            logits,
            batch["labels"],
            batch["decoder_attention_mask"],
            label_smoothing_factor)
        return loss, num_labels

    # compute gradients through the computational graph
    # allow auxiliary values (num_labels) to pass through
    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, num_labels), grad = grad_fn(state.params, batch)
    # num_labels = jax.lax.psum(num_labels, "batch")
    # true grad = total grad / total samples
    # needs to be in a singleton tuple for some reason
    # gathered_grad = gather_and_sum(grad, (unfreeze(model_named_sharding),))
    # gathered_num_labels = gather_and_sum(num_labels, mesh_sharding(PartitionSpec()))
    # summed_gradients = jax.tree.map(lambda x: jnp.sum(x) / gathered_num_labels, gathered_grad)
    # grad = jax.lax.psum(grad, "batch")
    # grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
    new_state = state.apply_gradients(grads=grad)
    with jax.named_scope("sync_metrics"):
        step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
        # step_metrics = jax.tree.map(
        #     # previously needed lax.psum
        #     # now just write single-device code and let the compiler handle it
        #     lambda x: jnp.mean(x), step_metrics
        # )
        # if metrics is None:
        #     metrics = step_metrics
        # else:
        #     # combine all the synced metrics
        #     metrics = jax.tree.map(jnp.mean, metrics, step_metrics)
    return new_state, step_metrics
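# %%
# optional (illustrative; exact API details can vary between JAX versions): lower and
# compile train_step ahead of time to inspect the compiled computation, e.g. its cost
# analysis, before running the first real step. assumes `batch` from above.
# _example_batch = jax.tree_util.tree_map(lambda x: jax.device_put(x, x_sharding), batch)
# _lowered = train_step.lower(initialized_state, _example_batch)
# print(_lowered.compile().cost_analysis())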
# %%
# prep 1 step
print("1 step for jit-ting")
with mesh:
    state, metrics = train_step(initialized_state, batch)
# %%
# train
print("***** Running training *****")
print(f" Num examples = {training_size}")
print(f" Num Epochs = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
print(f" Total optimization steps = {total_train_steps}")
# %%
# jax.profiler.start_trace("./traces")
# function to shard a batch by treating it as a pytree
def shard_batch(batch):
    # shard each element in the dictionary (i.e., each key-value pair)
    return jax.tree_util.tree_map(
        lambda x: jax.device_put(x, x_sharding),
        batch,
    )
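# optional (illustrative) check that shard_batch places every array with x_sharding:
# _placed = shard_batch(batch)
# pprint(jax.tree_util.tree_map(lambda x: x.sharding, _placed))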
print("*" * 10)
print("training start")
rng, input_rng = jax.random.split(rng)
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
    train_start = time.time()
    # Create sampling rng
    train_metrics = []
    steps_per_epoch = training_size // train_batch_size
    train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True)
    # Generate an epoch by shuffling sampling indices from the train dataset
    for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
        batch = next(train_loader)
        # send the batch to the devices with the data sharding
        # batch = {key: jax.device_put(jnp.array(value, dtype=jnp.uint16), x_sharding) for key, value in batch.items()}
        # batch['input_ids'] = jax.device_put(jnp.array(batch['input_ids'], dtype=jnp.int32), x_sharding)
        # batch['attention_mask'] = jax.device_put(jnp.array(batch['attention_mask'], dtype=jnp.int32), x_sharding)
        # batch['decoder_input_ids'] = jax.device_put(jnp.array(batch['decoder_input_ids'], dtype=jnp.int32), x_sharding)
        # batch['decoder_attention_mask'] = jax.device_put(jnp.array(batch['decoder_attention_mask'], dtype=jnp.int32), x_sharding)
        sharded_batch = shard_batch(batch)
        with mesh:
            state, train_metric = train_step(state, sharded_batch)
        # train_metrics.append(train_metric)
        # this gives more accurate time stats, but slows down training
        # train_metric['loss'].block_until_ready()
    train_time = time.time() - train_start
    epochs.write(
        f"Epoch... ({epoch + 1}/{num_epochs} | "
        f"Loss: {train_metric['loss']}, "
        f"Learning Rate: {train_metric['learning_rate']}, "
        f"Last train time: {train_time})"
    )
# jax.profiler.stop_trace()
# %%
# with mesh:
# gathered_params = jax.jit(
# lambda x: x,
# in_shardings=(unfreeze(model_named_sharding),),
# out_shardings=mesh_sharding(PartitionSpec())
# )(state.params)
main_model = custom_model.from_pretrained('t5-base')
output_dir = save_path
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
    # save the trained parameters from the train state (not the original pretrained
    # `params` loaded earlier), cast back to float32 for saving
    params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), jax.device_get(state.params))
    main_model.save_pretrained(output_dir, params=params)
# # stick to defaults
# options = ocp.CheckpointManagerOptions()
# with ocp.CheckpointManager(
# ocp.test_utils.erase_and_create_empty(save_path),
# options=options,
# ) as mngr:
#
# mngr.save(0, args=ocp.args.StandardSave(state))
# mngr.wait_until_finished()
# After providing `args` during an initial `save` or `restore` call, the
# `CheckpointManager` instance records the type so that you do not need to
# specify it again. If the `CheckpointManager` instance is not provided with a
# `ocp.args.CheckpointArgs` instance for a particular item on a previous
# occasion it cannot be restored without specifying the argument at restore
# time.
# # In many cases, you can restore exactly as saved without specifying additional
# # arguments.
# mngr.restore(0)
# # If customization of properties like sharding or dtype is desired, just provide
# # the abstract target PyTree, the properties of which will be used to set
# # the properties of the restored arrays.
# mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree))
# %%