# (file listing metadata: 542 lines, 16 KiB, Python)
# %% [markdown]
|
|
# # T5 implementation using jax with pjit
|
|
|
|
|
|
# MARK: START
|
|
# %%
# let's make 8-device simulator
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True


def build_xla_flags(existing: str, use_cpu_only: bool) -> str:
    """Return the XLA_FLAGS value with this script's flags appended.

    Every flag is joined with a single space, so the new flags can never be
    glued onto flags that were already present in ``existing``.

    Args:
        existing: current value of the ``XLA_FLAGS`` environment variable
            (may be empty).
        use_cpu_only: if True, add the 8-device host simulation flag;
            otherwise add the GPU performance flags.

    Returns:
        A single space-separated flag string.
    """
    if use_cpu_only:
        # Simulate 8 devices on the host CPU.
        extra = ["--xla_force_host_platform_device_count=8"]
    else:
        # GPU flags.
        # NOTE(review): unknown flags in XLA_FLAGS can abort startup, and some of
        # these may not exist in newer XLA releases — confirm against the
        # installed jaxlib.
        extra = [
            "--xla_gpu_enable_triton_softmax_fusion=true",
            "--xla_gpu_triton_gemm_any=false",
            "--xla_gpu_enable_async_collectives=true",
            "--xla_gpu_enable_latency_hiding_scheduler=true",
            "--xla_gpu_enable_highest_priority_async_stream=true",
        ]
    return " ".join(part for part in [existing.strip(), *extra] if part)


if USE_CPU_ONLY:
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    os.environ["JAX_PLATFORMS"] = "cpu"

# Set before `import jax` below so the backend sees these flags.
flags = build_xla_flags(os.environ.get("XLA_FLAGS", ""), USE_CPU_ONLY)
os.environ["XLA_FLAGS"] = flags
|
|
|
|
import functools
|
|
from functools import partial
|
|
from pprint import pprint
|
|
from typing import Any, Dict, Tuple, Callable, Sequence
|
|
|
|
import flax.linen as nn
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import numpy as np
|
|
from jax.experimental.shard_map import shard_map
|
|
from jax.sharding import Mesh
|
|
from jax.experimental.pjit import pjit
|
|
from jax.sharding import PartitionSpec as P
|
|
from ml_collections import ConfigDict
|
|
import optax
|
|
import logging
|
|
import time
|
|
from datasets import Dataset, load_from_disk
|
|
|
|
from flax import jax_utils, traverse_util
|
|
from flax.jax_utils import pad_shard_unpad, unreplicate
|
|
from flax.training import train_state
|
|
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
|
import flax.core
|
|
|
|
from tqdm import tqdm
|
|
|
|
from dataload import DataPrepare
|
|
|
|
# Aliases used in type annotations.
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

# Backend-specific JAX settings: pin the CPU platform when simulating devices,
# otherwise run default-precision matmuls in bfloat16 on the accelerator.
if not USE_CPU_ONLY:
    jax.config.update("jax_default_matmul_precision", "bfloat16")
else:
    jax.config.update("jax_platform_name", "cpu")
|
|
|
|
|
|
# # %%
|
|
# import jax
|
|
# import jax.numpy as jnp
|
|
# import optax
|
|
# import numpy as np
|
|
# from functools import partial
|
|
# from typing import Callable, Optional
|
|
# import math
|
|
#
|
|
# # jax.config.update("jax_default_matmul_precision", "tensorfloat32")
|
|
# jax.config.update("jax_default_matmul_precision", "bfloat16")
|
|
# # jax.config.update("jax_enable_x64", False)
|
|
# # enable cache
|
|
# jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
|
|
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
|
|
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
|
|
#
|
|
#
|
|
# # from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
|
|
#
|
|
# from flax import jax_utils, traverse_util
|
|
# from flax.jax_utils import pad_shard_unpad, unreplicate
|
|
# from flax.training import train_state
|
|
# from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
|
# import flax.core
|
|
|
|
|
|
# %%
# get platform type
# jax.default_backend() returns the platform name of the default backend
# (e.g. "cpu"). It replaces the deprecated jax.lib.xla_bridge.get_backend().platform.
print(jax.default_backend())
|
|
|
|
# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
save_path = 't5_80_1_bf16'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])

# Run-wide constants
seed = 117
num_epochs = num_train_epochs = 5
batch_size = 384  # 384 is the best
per_device_train_batch_size = per_device_eval_batch_size = batch_size

# Global batch sizes span every device taking part in the run.
n_devices = jax.device_count()
train_batch_size = per_device_train_batch_size * n_devices
eval_batch_size = per_device_eval_batch_size * n_devices
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs

# Optimizer hyper-parameters
warmup_steps = 0
learning_rate = 2e-5
weight_decay = 0.01
adam_beta1, adam_beta2, adam_epsilon = 0.9, 0.999, 1e-8
label_smoothing_factor = 0.0

# Generation settings
num_beams = 1
val_max_target_length = 128
predict_with_generate = True
|
|
|
|
|
|
# %%
# Prepare data: wrap the train split in the project's DataPrepare helper,
# configured by a ConfigDict.
data_config = ConfigDict({
    "max_length": 86,
    "pad_token_id": 0,
    "decoder_start_token_id": 0,
})
dataprep = DataPrepare(split_datasets['train'], data_config)

# %%
# Draw a single example batch to inspect.
seed = 117
rng = jax.random.PRNGKey(seed)
train_loader = dataprep.data_loader(rng, batch_size=1)
batch = next(iter(train_loader))

# %%
batch
|
|
|
|
# %%
# model
from transformers import FlaxT5ForConditionalGeneration
from transformers import T5Config
from flax import traverse_util

# With _do_init=False, from_pretrained returns a (model, params) tuple and the
# model object holds no weights (accessing `model.params` raises), so the
# parameters are carried separately from here on. Load them only once.
model, params = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=False)

# useful for transformer model: recompute activations in backward to save memory
model.enable_gradient_checkpointing()

# Enable bf16 for everything except layer-norm weights. T5 has both
# per-sub-layer "layer_norm" and encoder/decoder "final_layer_norm" modules;
# both are kept in full precision.
flat_params = traverse_util.flatten_dict(params)
mask = {
    path: not (path[-2] in ("layer_norm", "final_layer_norm") and path[-1] == "weight")
    for path in flat_params
}
mask = traverse_util.unflatten_dict(mask)
params = model.to_bf16(params, mask)

# %%
t5_module = model.module

# %%
jax.tree.map(jnp.shape, params)
|
|
|
|
# %%
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec
from pjit_partition import set_partitions

# `params` is the pytree returned next to `model` by from_pretrained: the model
# was loaded with _do_init=False, so reading `model.params` would raise.
data_partition_specs = PartitionSpec()
extra_param_keys = list(model._missing_keys)
initial_partition_specs = set_partitions(params)
# this is the partition spec we will use
filled_param_partition_specs = set_partitions(params, extra_keys=extra_param_keys)

# %%
# let us see the param_partition_spec
filled_param_partition_specs
|
|
|
|
# %% let us set up the mesh
from jax.sharding import Mesh

devices = np.asarray(jax.devices())

# %%
# mp: model/tensor parallelism
# dp: data parallelism
# A single 'data' axis is shared by the batch and the model params.
# The trailing comma matters: ("data") is just the string "data", not a tuple.
mesh_axis_names = ("data",)
print("Logical mesh:", devices)

mesh = Mesh(devices, mesh_axis_names)

# it is technically possible to use pjit_partition to set special partition rules
# e.g. by param size
# but for now just move on
|
|
|
|
# %% [markdown]
|
|
# # Model
|
|
#
|
|
#
|
|
#
|
|
|
|
# %%

# Initialize our training: derive the carry key and a dropout key from the seed.
rng, dropout_rng = jax.random.split(jax.random.PRNGKey(seed))
|
|
|
|
|
|
# %%
# optimization functions

def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Build a linear warm-up followed by a linear decay to zero.

    The schedule rises from 0 to ``learning_rate`` over ``num_warmup_steps``
    steps, then falls linearly to 0 over the remaining optimization steps,
    where the total is ``(train_ds_size // train_batch_size) * num_train_epochs``.
    """
    total_steps = (train_ds_size // train_batch_size) * num_train_epochs
    decay_steps = total_steps - num_warmup_steps
    return optax.join_schedules(
        schedules=[
            optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps),
            optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=decay_steps),
        ],
        boundaries=[num_warmup_steps],
    )


# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
    train_ds_size=training_size,
    train_batch_size=train_batch_size,
    num_train_epochs=num_train_epochs,
    num_warmup_steps=warmup_steps,
    learning_rate=learning_rate,
)
|
|
|
|
# Weight decay is restricted through Optax's `mask` argument: decay_mask_fn
# returns a boolean pytree shaped like the parameters, True where the
# parameter should be decayed (biases and LayerNorm parameters are excluded).
def decay_mask_fn(params):
    """Return a pytree of booleans, True for parameters that receive weight decay.

    A parameter is excluded when its last key is ``"bias"``, or when its last
    two keys equal the last two keys of any parameter whose concatenated,
    lower-cased path contains ``layernorm``, ``layer_norm`` or ``ln``.
    """
    flat_params = traverse_util.flatten_dict(params)
    norm_markers = ("layernorm", "layer_norm", "ln")
    # NOTE(review): keys are concatenated without a separator, so "ln" can also
    # match across a key boundary — confirm this is acceptable for the model.
    norm_suffixes = set()
    for path in flat_params:
        joined_path = "".join(path).lower()
        if any(marker in joined_path for marker in norm_markers):
            norm_suffixes.add(path[-2:])

    flat_mask = {}
    for path in flat_params:
        is_bias = path[-1] == "bias"
        flat_mask[path] = not is_bias and path[-2:] not in norm_suffixes
    return traverse_util.unflatten_dict(flat_mask)


# create adam optimizer
adamw_kwargs = dict(
    learning_rate=linear_decay_lr_schedule_fn,
    b1=adam_beta1,
    b2=adam_beta2,
    eps=adam_epsilon,
    weight_decay=weight_decay,
    mask=decay_mask_fn,
)
adamw = optax.adamw(**adamw_kwargs)
|
|
|
|
|
|
# %%
|
|
# Training functions
|
|
# class TrainState(train_state.TrainState):
|
|
# dropout_rng: jnp.ndarray
|
|
#
|
|
# # easy way to achieve data parallelism
|
|
# # also achieves folding of rng keys
|
|
# def replicate(self):
|
|
# return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
|
|
|
|
# set bf16 for model params
|
|
# model.params = model.to_bf16(model.params)
|
|
# Cast parameters to bfloat16 if desired
|
|
# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
|
|
|
|
# state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
|
|
|
|
# %%
# see if we can init the model

# %%
from transformers import FlaxT5ForConditionalGeneration, T5Config

config = T5Config.from_pretrained('t5-base')
# Training needs the LM head: train_step treats output[0] as vocabulary logits,
# but FlaxT5Model's first output is the decoder hidden state. Use the
# conditional-generation module instead. _do_init=False skips allocating a
# throw-away set of random weights, since only `.module` is kept here.
model = FlaxT5ForConditionalGeneration(config, _do_init=False).module
|
|
|
|
# %%
# Initialize random key and input for initialization
rng = jax.random.PRNGKey(0)
train_loader = dataprep.data_loader(rng, batch_size=1)
batch = next(iter(train_loader))

# %%

# Initialize model parameters by running the module's __call__ once on the
# example batch.
variables = model.init(
    rng,
    input_ids=batch['input_ids'],
    attention_mask=batch['attention_mask'],
    # decoder token ids, not the decoder attention mask
    decoder_input_ids=batch['decoder_input_ids'],
    decoder_attention_mask=batch['decoder_attention_mask'],
)
params = variables['params']
|
|
|
|
|
|
# %%
# create an init_fn
def init_fn(rng: jax.Array, batch, model) -> train_state.TrainState:
    """Initialize `model` on `batch` and wrap the parameters in a TrainState.

    Args:
        rng: PRNG key; it is split and the first half seeds ``model.init``.
        batch: dict providing ``input_ids``, ``attention_mask``,
            ``decoder_input_ids`` and ``decoder_attention_mask``.
        model: a flax linen module; callers bind it with ``functools.partial``.

    Returns:
        A ``TrainState`` holding the fresh parameters, the module-level
        ``adamw`` optimizer, and ``model.apply`` as ``apply_fn``.
    """
    init_rng, _ = jax.random.split(rng)
    variables = model.init(
        init_rng,
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        # decoder token ids, not the decoder attention mask
        decoder_input_ids=batch['decoder_input_ids'],
        decoder_attention_mask=batch['decoder_attention_mask'],
    )
    # Index rather than `.pop`: FrozenDict.pop returns a (rest, value) tuple,
    # dict.pop returns the value, and which one `init` returns depends on the
    # flax version.
    params = variables["params"]
    return train_state.TrainState.create(
        # A linen module is run through `apply`; calling `model.__call__`
        # directly on an unbound module does not work.
        apply_fn=model.apply,
        params=params,
        tx=adamw,
    )
|
|
|
|
|
|
|
|
|
|
# %%
# The output PartitionSpecs of init_fn are not known in advance, so wrap it in
# shard_map and trace it abstractly (jax.eval_shape in the next cell) to find them.
# in_specs: the PRNG key (1st argument) is replicated, and the batch (2nd
# argument) is split along the "data" mesh axis. The model is not an argument;
# it is bound through partial.
init_with_model = partial(init_fn, model=model)
init_fn_try = shard_map(
    init_with_model,
    mesh,
    in_specs=(P(), P("data")),
    out_specs=P(),
    check_rep=False,
)
|
|
|
|
# %%
rng, model_init_rng = jax.random.split(rng)
train_loader = dataprep.data_loader(model_init_rng, batch_size=batch_size)
batch = next(iter(train_loader))

# Trace init_fn_try abstractly: eval_shape yields the TrainState's shapes and
# dtypes without running the computation.
state_fsdp_shapes = jax.eval_shape(init_fn_try, model_init_rng, batch)
# NOTE(review): nn.get_partition_spec gives a non-trivial spec only for leaves
# boxed as nn.Partitioned; plain arrays map to PartitionSpec() (replicated).
state_fsdp_specs = nn.get_partition_spec(state_fsdp_shapes)


def _print_section(title, tree):
    """Print a heading followed by a pretty-printed pytree."""
    print(title)
    pprint(tree)


# print("RNG", state_fsdp_specs.rng)
_print_section("\nParameters", state_fsdp_specs.params)
_print_section("\nOptimizer state", state_fsdp_specs.opt_state[0])

# note: state_fsdp_specs is now ready to be used as the jit out_shardings spec
|
|
|
|
|
|
|
|
# %%
# Setup train state
# state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
# jax.jit only builds a compiled function; it must still be called to get a
# TrainState. The model is bound with partial because a linen Module is not a
# valid jit argument. PartitionSpecs are attached to `mesh` as NamedShardings,
# so no mesh context manager is required.
state_shardings = jax.tree.map(
    lambda spec: NamedSharding(mesh, spec),
    state_fsdp_specs,
    is_leaf=lambda x: isinstance(x, P),
)
p_init_fn = jax.jit(
    functools.partial(init_fn, model=model),
    in_shardings=(NamedSharding(mesh, P()), NamedSharding(mesh, P("data"))),
    out_shardings=state_shardings,
)
state = p_init_fn(model_init_rng, batch)
|
|
|
|
|
|
|
|
# %%

# label smoothed cross entropy
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    """Summed label-smoothed cross-entropy over non-padded positions.

    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104

    Returns:
        ``(loss, num_labels)``: the masked loss summed over all positions and
        ``padding_mask.sum()``. Divide the first by the second for the mean.
    """
    num_classes = logits.shape[-1]
    on_value = 1.0 - label_smoothing_factor
    off_value = (1.0 - on_value) / (num_classes - 1)
    # Entropy of the smoothed target distribution. Subtracting it turns the
    # cross-entropy into a KL divergence, which is 0 for a perfect prediction.
    target_entropy = -(
        on_value * jnp.log(on_value) + (num_classes - 1) * off_value * jnp.log(off_value + 1e-20)
    )
    targets = onehot(labels, num_classes, on_value=on_value, off_value=off_value)

    per_position = optax.softmax_cross_entropy(logits, targets) - target_entropy
    # padded positions contribute nothing
    masked = per_position * padding_mask
    return masked.sum(), padding_mask.sum()
|
|
|
|
# MARK: train_step
# Define gradient update step fn
def train_step(state, batch):
    """Run one optimization step on a global batch.

    Meant to be wrapped in ``jax.jit`` with shardings. Under jit the function
    sees the whole global batch and XLA inserts any cross-device reductions the
    shardings require, so no explicit ``psum`` is used. A ``psum`` over a named
    axis is only valid inside ``pmap``/``shard_map``.

    Args:
        state: a ``flax.training.train_state.TrainState`` whose ``apply_fn`` is
            a linen ``Module.apply`` (as created by ``init_fn``). If the state
            also carries a ``dropout_rng`` field, that key is split and advanced.
            Otherwise a per-step key is derived from the module-level
            ``dropout_rng``.
        batch: dict containing ``labels`` and the model inputs (including
            ``decoder_attention_mask``, used as the loss padding mask).
            It is not modified.

    Returns:
        ``(new_state, metrics)``: ``metrics["loss"]`` is the token-averaged loss
        and ``metrics["learning_rate"]`` the schedule value at ``state.step``.
    """
    label_smoothing_factor = 0.0

    if hasattr(state, "dropout_rng"):
        step_dropout_rng, next_dropout_rng = jax.random.split(state.dropout_rng)
        extra_state = {"dropout_rng": next_dropout_rng}
    else:
        # A plain TrainState has no RNG field: fold the step counter into the
        # module-level dropout key so every step gets a fresh key.
        step_dropout_rng = jax.random.fold_in(dropout_rng, state.step)
        extra_state = {}

    labels = batch["labels"]
    # Feed every entry except the labels to the model, without mutating `batch`.
    model_inputs = {key: value for key, value in batch.items() if key != "labels"}

    def compute_loss(params):
        logits = state.apply_fn(
            {"params": params},
            **model_inputs,
            deterministic=False,
            rngs={"dropout": step_dropout_rng},
        )[0]
        return loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)

    # compute gradients through computational graph
    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, num_labels), grad = grad_fn(state.params)

    # Token-level mean: true loss/grad = total / number of non-padded tokens.
    # Clamp to 1 so an all-padding batch does not divide by zero.
    denom = jnp.maximum(num_labels, 1)
    grad = jax.tree_util.tree_map(lambda g: g / denom, grad)
    new_state = state.apply_gradients(grads=grad, **extra_state)

    metrics = {"loss": loss / denom, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
    return new_state, metrics
|
|
|
|
# max_length = (
#     val_max_target_length if val_max_target_length is not None else model.config.max_length
# )
# num_beams = num_beams if num_beams is not None else model.config.num_beams
# gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

# Create the sharded (jitted) version of the train step.
# Shardings are bound to `mesh` explicitly, since bare PartitionSpecs need an
# active mesh context.
# - state: laid out like the TrainState produced at init (state_fsdp_specs).
#   Sharding every leaf along "data" would fail for rank-0 leaves such as `step`.
# - batch: split along its leading dimension over "data".
# - metrics: scalars, so replicated.
_state_shardings = jax.tree.map(
    lambda spec: NamedSharding(mesh, spec),
    state_fsdp_specs,
    is_leaf=lambda x: isinstance(x, P),
)
p_train_step = jax.jit(
    train_step,
    in_shardings=(_state_shardings, NamedSharding(mesh, P("data"))),
    out_shardings=(_state_shardings, NamedSharding(mesh, P())),
    # a real 1-tuple; ("state") is just the string "state"
    donate_argnames=("state",),
)
|
|
|
|
|
|
|
|
|
|
# %%

# Summarize the run configuration before training starts.
_summary_lines = (
    "***** Running training *****",
    f" Num examples = {training_size}",
    f" Num Epochs = {num_epochs}",
    f" Instantaneous batch size per device = {per_device_train_batch_size}",
    f" Total train batch size (w. parallel & distributed) = {train_batch_size}",
    f" Total optimization steps = {total_train_steps}",
)
for _line in _summary_lines:
    print(_line)
|
|
|
|
|
|
# %%
# jax.profiler.start_trace("./traces")

# Example batch sharded across devices along the "data" mesh axis.
# jax.device_put_sharded expects one pre-split shard per device, so it cannot
# take a whole batch array. Place the global arrays with a NamedSharding instead.
batch_sharding = NamedSharding(mesh, P("data"))
sharded_batch = {
    key: jax.device_put(batch[key], batch_sharding)
    for key in ('input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'decoder_attention_mask')
}

# No separate replication of the TrainState: `train_state` is the flax module,
# not a state object. The state's placement is governed by the out_shardings
# given to the jitted init function.
|
|
|
|
# %%

rng, input_rng = jax.random.split(rng)
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
    train_start = time.time()

    # Create sampling rng
    train_metrics = []
    rng, data_rng = jax.random.split(rng)
    # p_train_step splits the *global* batch over the "data" axis, so request
    # train_batch_size (per-device size x device count) examples per step.
    # steps_per_epoch (training_size // train_batch_size) assumes that size.
    train_loader = dataprep.data_loader(data_rng, batch_size=train_batch_size)
    # iter() works whether data_loader returns an iterator or a re-iterable.
    train_iter = iter(train_loader)
    # Generate an epoch by shuffling sampling indices from the train dataset
    for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
        batch = next(train_iter)
        state, train_metric = p_train_step(state, batch)
        train_metrics.append(train_metric)

    # JAX dispatches asynchronously: wait for the last step before stopping the
    # clock, otherwise train_time only measures dispatch. Outputs of jax.jit
    # have no leading device axis, so there is nothing to unreplicate.
    if train_metrics:
        train_metrics[-1]['loss'].block_until_ready()
    train_time = time.time() - train_start

    epochs.write(
        # f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, "
        f"Epoch... ({epoch + 1}/{num_epochs} | "
        # f"Learning Rate:{train_metric['learning_rate']}, "
        f"Last train time: {train_time})"
    )
# jax.profiler.stop_trace()
|
|
# %%
|
|
|
|
# output_dir = save_path
|
|
# # save checkpoint after each epoch and push checkpoint to the hub
|
|
# if jax.process_index() == 0:
|
|
# params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
|
|
# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
|
|
# model.save_pretrained(output_dir, params=params)
|
|
# tokenizer.save_pretrained(output_dir)
|