# Source listing: 698 lines, 23 KiB, Python
# %%
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = False
# Number of host (CPU) devices XLA is asked to expose when USE_CPU_ONLY is True.
NUM_SIMULATED_CPU_DEVICES = 4

# Start from any XLA_FLAGS already present in the environment; both branches
# below append to it rather than replacing it.
flags = os.environ.get("XLA_FLAGS", "")

if USE_CPU_ONLY:
    # Simulate NUM_SIMULATED_CPU_DEVICES devices on the host CPU
    flags += f" --xla_force_host_platform_device_count={NUM_SIMULATED_CPU_DEVICES}"
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    os.environ["JAX_PLATFORMS"] = "cpu"
else:
    # GPU flags. Appended (not assigned) so that XLA_FLAGS set outside this
    # script are preserved, matching the CPU branch above.
    # NOTE(review): some --xla_gpu_* flags have been removed in newer XLA
    # releases and unknown flags may abort start-up -- confirm these exist in
    # the installed jaxlib.
    gpu_flags = (
        "--xla_gpu_enable_triton_softmax_fusion=true "
        "--xla_gpu_triton_gemm_any=True "
        # "--xla_gpu_enable_async_collectives=true "
        "--xla_gpu_enable_latency_hiding_scheduler=true "
        "--xla_gpu_enable_highest_priority_async_stream=true"
    )
    flags = f"{flags} {gpu_flags}".strip()
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

os.environ["XLA_FLAGS"] = flags

# Runtime environment for tokenizers, CUDA/NCCL and the XLA client allocator.
os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
    "NCCL_LL128_BUFFSIZE": "-2",
    "NCCL_LL_BUFFSIZE": "-2",
    "NCCL_PROTO": "SIMPLE,LL,LL128",
    "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.90",
    # "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
})
import functools
import logging
import time
from functools import partial
from pprint import pprint
from typing import Any, Callable, Dict, Sequence, Tuple, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec
# from jax.experimental.pjit import pjit # superseded by jax.jit
from jax.experimental import mesh_utils
from ml_collections import ConfigDict
import optax
from datasets import Dataset, load_from_disk

from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.core.frozen_dict import freeze, unfreeze, FrozenDict
import flax.core

# model checkpointing and saving utilities
# (`nn` is bound once, by `import flax.linen as nn` above)
from flax.training import checkpoints, train_state
from flax import struct, serialization
import orbax.checkpoint as ocp
from flax.training import orbax_utils

from tqdm import tqdm

# project-local modules
from parallel.partitions import set_partitions
from parallel.dataload import DataPrepare

PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

if USE_CPU_ONLY:
    jax.config.update('jax_platform_name', 'cpu')
else:
    # Lower the default matmul precision setting to "bfloat16" on GPU.
    jax.config.update("jax_default_matmul_precision", "bfloat16")

# Persistent compilation cache in /tmp/jax_cache; the size/time thresholds are
# set so that small and fast-to-compile programs are cached too.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
# %%
|
|
## get platform type
|
|
from jax.extend.backend import get_backend
|
|
print(get_backend().platform)
|
|
print(jax.devices())
|
|
|
|
# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/'
save_path = '/home/richard/Projects/06_research/jax_models/t5_80e_fp32_parallel/'
# file_path = 'combined_data'

# Load the dataset splits; only the size of the training split is used here.
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])

# Store some constant
seed = 117
num_epochs = num_train_epochs = 5
batch_size = 32

# "Per-device" sizes are multiplied by the number of visible JAX devices.
n_devices = jax.device_count()
per_device_train_batch_size = per_device_eval_batch_size = batch_size
train_batch_size = per_device_train_batch_size * n_devices
eval_batch_size = per_device_eval_batch_size * n_devices

# NOTE(review): the data loader later in this script is called with
# `batch_size`, not `train_batch_size`, so with more than one visible device
# steps_per_epoch covers only part of the training set -- confirm intent.
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs

# optimiser / schedule hyper-parameters
warmup_steps = 0
learning_rate = 5e-5
weight_decay = 0.01
adam_beta1, adam_beta2, adam_epsilon = 0.9, 0.999, 1e-8

# loss / generation settings
label_smoothing_factor = 0.0
num_beams = 1
val_max_target_length = 128
predict_with_generate = True
# %%
# prepare data: wrap the training split in a DataPrepare loader
print("preparing data")
data_config = ConfigDict(
    {
        "max_length": 128,
        "pad_token_id": 0,
        "decoder_start_token_id": 0,
    }
)
dataprep = DataPrepare(split_datasets['train'], data_config)

# # example usage
# Pull a single example batch; it is reused later for the first (compiling) train step.
seed = 117
rng = jax.random.PRNGKey(seed)
train_loader = dataprep.data_loader(rng, batch_size=batch_size)
batch = next(iter(train_loader))
# batch
# %%
# model
# Alternatives kept for reference.
#
# (a) pure-Flax module from this repo ("working"):
# from parallel.t5_model.pure_t5 import FlaxT5ForConditionalGenerationModule as model_init
# # from t5_model.pure_t5 import FlaxT5DenseActDense as model_init
# from parallel.t5_model.pure_t5 import make_config
# config = make_config()
# model = model_init(config=config, dtype=jnp.bfloat16, gradient_checkpointing=True)

# %%
# (b) stock HuggingFace Flax T5:
# from transformers import FlaxT5ForConditionalGeneration, T5Config
# model = FlaxT5ForConditionalGeneration.from_pretrained(
#     "t5-base",
#     dtype=jnp.bfloat16,
# )
# # pretrained_params = model.params
# model = model.module

# %%
# Active choice: this repo's Flax T5 wrapper, loaded with pretrained t5-base weights.
# from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom
from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model

main_model = custom_model.from_pretrained(
    "t5-base",
    dtype=jnp.float32,
    # gradient_checkpointing=True,
)

# Keep the pretrained weights and the bare Flax module separately: the
# module's `apply` is what the training state calls.
params = main_model.params
# pretrained_params = model.params
model = main_model.module
# %%
|
|
# # testing config hashability
|
|
# # some explanation:
|
|
# # The PreTrainedModel class loads a T5Config model that is not hashable because
|
|
# # it is a complicated class that pretends to be a dataclass.
|
|
# # The solution is to extract a dict from it, then make a ConfigDict from
|
|
# # ml_collections library so that we can get values via the "." operator.
|
|
# # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to
|
|
# # modify the config before passing to the next layer
|
|
# from transformers import T5Config
|
|
# from t5_model.configuration_t5 import FrozenT5Config
|
|
# from ml_collections import ConfigDict, FrozenConfigDict
|
|
#
|
|
# config = T5Config.from_pretrained("t5-base").to_dict()
|
|
# config.pop('architectures')
|
|
# config.pop('id2label')
|
|
# # test if it works
|
|
# frozen_config = FrozenConfigDict(config)
|
|
# # test hash
|
|
# hash(frozen_config)
|
|
|
|
# %%
|
|
|
|
# %%
|
|
# # print model
|
|
# rng, input_rng = jax.random.split(rng)
|
|
# model.tabulate(
|
|
# input_rng,
|
|
# input_ids=batch['input_ids'],
|
|
# attention_mask=batch['attention_mask'],
|
|
# decoder_input_ids=batch['decoder_input_ids'],
|
|
# decoder_attention_mask=batch['decoder_attention_mask'],
|
|
# console_kwargs={"force_jupyter": True}
|
|
# )
|
|
|
|
# %%
|
|
# print model datatype to verify
|
|
# rng, input_rng = jax.random.split(rng)
|
|
# variables = model.init(
|
|
# input_rng,
|
|
# input_ids=batch['input_ids'],
|
|
# attention_mask=batch['attention_mask'],
|
|
# decoder_input_ids=batch['decoder_input_ids'],
|
|
# decoder_attention_mask=batch['decoder_attention_mask']
|
|
# )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# %%
|
|
# create mesh
|
|
print("creating mesh")
|
|
device_mesh = mesh_utils.create_device_mesh((1,1))
|
|
print(device_mesh)
|
|
|
|
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
|
|
print(mesh)
|
|
|
|
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
|
|
return NamedSharding(mesh, pspec, memory_kind="device")
|
|
|
|
x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis
|
|
model_sharding=mesh_sharding(PartitionSpec(None, 'model'))
|
|
|
|
|
|
# %%
# optimizers


def create_learning_rate_fn(
    train_ds_size: int,
    train_batch_size: int,
    num_train_epochs: int,
    num_warmup_steps: int,
    learning_rate: float,
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear_decay learning rate function.

    Two optax linear schedules are joined at ``num_warmup_steps``:
    0 -> ``learning_rate`` over ``num_warmup_steps`` steps, then
    ``learning_rate`` -> 0 over the remaining
    ``(train_ds_size // train_batch_size) * num_train_epochs - num_warmup_steps``
    steps.
    """
    total_steps = (train_ds_size // train_batch_size) * num_train_epochs
    ramp_up = optax.linear_schedule(
        init_value=0.0,
        end_value=learning_rate,
        transition_steps=num_warmup_steps,
    )
    ramp_down = optax.linear_schedule(
        init_value=learning_rate,
        end_value=0,
        transition_steps=total_steps - num_warmup_steps,
    )
    return optax.join_schedules(schedules=[ramp_up, ramp_down], boundaries=[num_warmup_steps])


# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
    train_ds_size=training_size,
    train_batch_size=train_batch_size,
    num_train_epochs=num_train_epochs,
    num_warmup_steps=warmup_steps,
    learning_rate=learning_rate,
)
# Optax "masking" keeps weight decay off biases and LayerNorm parameters.
def decay_mask_fn(params):
    """Return a bool tree shaped like ``params``: True where weight decay applies.

    A leaf is excluded (False) when its last path component is ``"bias"``, or
    when its last two path components equal those of any parameter whose path
    names -- concatenated without a separator and lower-cased -- contain
    ``"layernorm"``, ``"layer_norm"`` or ``"ln"``.

    NOTE(review): because the names are joined with no separator, ``"ln"`` can
    also match across a component boundary; confirm this is harmless for the
    model's parameter names.
    """
    flat_params = traverse_util.flatten_dict(params)

    # find out all LayerNorm parameters (keyed by their last two path components)
    norm_markers = ("layernorm", "layer_norm", "ln")
    norm_suffixes = set()
    for path in flat_params:
        joined = "".join(path).lower()
        if any(marker in joined for marker in norm_markers):
            norm_suffixes.add(path[-2:])

    decay_mask = {
        path: path[-1] != "bias" and path[-2:] not in norm_suffixes
        for path in flat_params
    }
    return traverse_util.unflatten_dict(decay_mask)


# create adam optimizer: AdamW driven by the warmup/decay schedule, with
# decay_mask_fn selecting which parameters receive weight decay.
adamw_kwargs = {
    "learning_rate": linear_decay_lr_schedule_fn,
    "b1": adam_beta1,
    "b2": adam_beta2,
    "eps": adam_epsilon,
    "weight_decay": weight_decay,
    "mask": decay_mask_fn,
}
adamw = optax.adamw(**adamw_kwargs)

print("compile")
# enable bf16 except for layer_norm
def create_mask_for_layer_norm(params):
    """Build a cast mask with the same tree structure as ``params``.

    A leaf maps to ``False`` (leave its dtype alone) when it is the ``weight``
    of a module whose lower-cased name contains ``"layer_norm"``. That covers
    plain ``layer_norm`` as well as names such as ``final_layer_norm``, which an
    exact ``== "layer_norm"`` test would miss. Every other leaf maps to
    ``True`` (cast). Single-component paths have no parent module and are
    always ``True``.

    Intended as the ``mask`` argument of ``cast_floating_to``.
    """
    flat_params = traverse_util.flatten_dict(params)

    def _is_layer_norm_weight(path):
        # guard: a top-level leaf has no parent module name at path[-2]
        return (
            len(path) >= 2
            and "layer_norm" in str(path[-2]).lower()
            and path[-1] == "weight"
        )

    mask = {path: not _is_layer_norm_weight(path) for path in flat_params}
    return traverse_util.unflatten_dict(mask)


# borrowed from transformers modeling_flax_utils
def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
    """
    Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.

    Args:
        params: nested dict / FrozenDict of parameters.
        dtype: target floating dtype (e.g. ``jnp.bfloat16``).
        mask: optional nested dict with the same paths as ``params``; a leaf is
            cast only where the mask value at the same path is truthy. Paths
            missing from the mask are left unchanged. ``None`` casts every
            floating leaf.

    Returns:
        The tree with floating ``jnp.ndarray`` leaves cast; non-array and
        non-floating leaves are returned unchanged. When a mask is given the
        result is a plain nested ``dict``.
    """

    # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
    def conditional_cast(param):
        if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
            param = param.astype(dtype)
        return param

    if mask is None:
        return jax.tree_util.tree_map(conditional_cast, params)

    flat_params = traverse_util.flatten_dict(params)
    # Look the mask up by path instead of zipping two independently ordered
    # flattenings, so params and mask cannot silently drift out of alignment.
    flat_mask = traverse_util.flatten_dict(mask)
    cast_params = {
        path: conditional_cast(value) if flat_mask.get(path, False) else value
        for path, value in flat_params.items()
    }
    return traverse_util.unflatten_dict(cast_params)


# Cast all parameters to bfloat16 if desired
# params = jax.tree.tree_map(lambda x: x.astype(jnp.bfloat16), params)
# %%
def init_fn(params, model, optimizer):
    """Wrap already-loaded ``params`` in a flax ``TrainState``.

    ``model.init`` is deliberately not called: imported models might have
    complicated init methods, so the existing parameters are used as-is with
    ``model.apply`` as the apply function and ``optimizer`` as ``tx``.
    """
    # Optional bfloat16 variant (layer-norm weights left in their original dtype):
    # mask = create_mask_for_layer_norm(params)
    # params = cast_floating_to(params, jnp.bfloat16, mask)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)


# Alternative that initialises fresh parameters with model.init:
# def init_fn(rng, batch, model, optimizer):
#     # do be careful with the model init
#     # imported models might have complicated init methods
#     variables = model.init(
#         rng,
#         input_ids=batch['input_ids'],
#         attention_mask=batch['attention_mask'],
#         decoder_input_ids=batch['decoder_input_ids'],
#         decoder_attention_mask=batch['decoder_attention_mask']
#     )
#     params = variables['params']
#     mask = create_mask_for_layer_norm(params)
#     # override params with bfloat version
#     params = cast_floating_to(params, jnp.bfloat16, mask)
#
#     state = train_state.TrainState.create(  # Create a `TrainState`.
#         apply_fn=model.apply,
#         params=params,
#         tx=optimizer)
#     return state
# %%
# Shape inference: jax.eval_shape runs init_fn abstractly and returns a PyTree
# whose leaves are jax.ShapeDtypeStruct objects. model and optimizer are bound
# with partial so that only the params pytree is passed through eval_shape.
abstract_variables = jax.eval_shape(
    partial(init_fn, optimizer=adamw, model=model),
    params,
)

# Alternative for the model.init-based init_fn:
# rng, init_rng = jax.random.split(rng)
# abstract_variables = jax.eval_shape(
#     functools.partial(init_fn, model=model, optimizer=adamw), init_rng, batch)
# %%
# `state_sharding` has the same pytree structure as the TrainState returned by
# init_fn. flax.linen.get_sharding turns a PyTree (which may carry Partitioned
# metadata) plus a mesh into a tree of jax shardings, i.e. descriptions of how
# each jax.Array is laid out across devices.
state_sharding = nn.get_sharding(abstract_variables, mesh)
# print(state_sharding)

# warning: do not have singleton None in your nn.partition definitions, it will screw with your sanity

##################################################
# # %%
# # replace the params tree with the new modified tree
# # create partitions for model
# from parallel.partitions import set_partitions
# # set_partitions freezes the params on return
# model_part_spec = set_partitions(unfreeze(params))
# # p is already a partition spec
# model_named_sharding = jax.tree.map(lambda p: mesh_sharding(p), model_part_spec)
#
# # %%
# # get_shapes = jax.tree.map(jnp.shape, params)
# # actually tuple
# # state_shapes = jax.eval_shape(state_sharding, get_shapes)
#
# # %%
# # get pspec for opt_state
# def get_opt_spec(x):
#     if isinstance(x, dict):
#         return unfreeze(model_named_sharding)
#     # return an empty partspec
#     return mesh_sharding((PartitionSpec()))
#
# # this function replaces the empty model params spec with the 'model_named_shard'
# state_sharding = jax.tree.map(
#     get_opt_spec, state_sharding, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
# )
# %%
# Create the TrainState under jit so it comes out laid out as `state_sharding`.
jit_init_fn = jax.jit(
    init_fn,
    static_argnames=('model', 'optimizer'),  # skip model and optimizer (static, not traced)
    # Params go in fully replicated. PartitionSpec() has no entries and so
    # applies to leaves of any rank; PartitionSpec(()) describes one dimension
    # and requires every leaf to have rank >= 1.
    in_shardings=mesh_sharding(PartitionSpec()),
    out_shardings=state_sharding,  # but returned initialized_state is sharded
)

initialized_state = jit_init_fn(params, model, adamw)

# Alternative for the model.init-based init_fn:
# jit_init_fn = jax.jit(
#     init_fn,
#     static_argnames=('model', 'optimizer'),  # skip model and optimizer
#     in_shardings=(mesh_sharding(()), x_sharding),  # for PRNG key and data
#     out_shardings=state_sharding
# )
#
# rng, init_rng = jax.random.split(rng)
# initialized_state = jit_init_fn(rng, batch, model, adamw)
# %%
# train step
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    """
    Label-smoothed cross-entropy, summed over the positions kept by ``padding_mask``.

    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104

    Args:
        logits: model outputs; the last axis indexes the vocabulary.
        labels: integer target ids, presumably shaped like ``logits`` without
            the vocabulary axis -- TODO confirm against the data loader.
        padding_mask: multiplies the per-position loss; expected to be 1 for
            positions that count and 0 for padding.
        label_smoothing_factor: probability mass moved from the target class
            and spread evenly over the other ``vocab_size - 1`` classes;
            0.0 disables smoothing.

    Returns:
        ``(loss, num_labels)``: the *summed* (not averaged) masked loss and
        ``padding_mask.sum()``; divide the first by the second for a per-token mean.
    """
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing_factor
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    # Entropy of the smoothed target distribution; subtracting it makes a
    # perfect prediction score 0. The 1e-20 offsets keep 0 * log(0) finite:
    # low_confidence is 0 without smoothing, and confidence is 0 when
    # label_smoothing_factor == 1.0 (previously this produced NaN).
    normalizing_constant = -(
        confidence * jnp.log(confidence + 1e-20)
        + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    )
    soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)

    loss = optax.softmax_cross_entropy(logits, soft_labels)
    loss = loss - normalizing_constant

    # ignore padded tokens from loss
    loss = loss * padding_mask
    loss = loss.sum()
    num_labels = padding_mask.sum()
    return loss, num_labels


# %%
# sharded_loss_fn = jax.jit(
#     loss_fn,
#     in_shardings=(mesh_sharding('model'), x_sharding),  # params partitioned across 'model' axis
#     out_shardings=(mesh_sharding('model')),  # Loss should be aggregated across 'model'
# )


def gather_and_sum(
    sharded_values,
    in_shardings
):
    """Re-lay ``sharded_values`` out under ``mesh`` and reduce each leaf to its total.

    An identity function is jitted with the caller's ``in_shardings`` and
    ``out_shardings=None`` and called inside the module-level ``mesh``
    context. Every leaf of the result is then summed over all of its elements,
    so the return value keeps the tree structure with scalar leaves.
    """
    identity = jax.jit(lambda x: x, in_shardings=in_shardings, out_shardings=None)
    with mesh:
        # Gather sharded values into a single device
        gathered = identity(sharded_values)
        # Compute the sum of gathered values
        return jax.tree.map(jnp.sum, gathered)


# single device code annotated with jax.jit; the compiler partitions it
# according to the in/out shardings below.
@functools.partial(
    jax.jit,
    # state is laid out as state_sharding (computed from init_fn);
    # the batch as x_sharding (sharded explicitly by shard_batch).
    in_shardings=(state_sharding, x_sharding),
    # return state as state_sharding; metrics are not sharded (replicated).
    out_shardings=(state_sharding, mesh_sharding(PartitionSpec())),
    donate_argnames=("state",),
)
def train_step(state, batch):
    """Run one optimisation step on ``batch``.

    The loss is ``loss_fn``'s summed cross-entropy divided by the number of
    positions counted in ``batch["decoder_attention_mask"]``. Because that
    count does not depend on the parameters, the gradients are likewise
    per-token averages ("true grad = total grad / total samples"). The label
    smoothing factor is the module-level ``label_smoothing_factor`` setting.

    Args:
        state: flax TrainState (donated: its buffers are reused for the output).
        batch: dict containing ``input_ids``, ``attention_mask``,
            ``decoder_input_ids``, ``decoder_attention_mask`` and ``labels``.

    Returns:
        ``(new_state, metrics)`` where ``metrics["loss"]`` is the mean per-token
        loss and ``metrics["learning_rate"]`` the schedule value at ``state.step``.
    """

    def compute_loss(params, batch):
        # check constraints
        # frozen dict not allowed as sharding object
        # params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding))
        # batch = jax.lax.with_sharding_constraint(batch, x_sharding)
        logits = state.apply_fn(
            {"params": params},
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            decoder_input_ids=batch["decoder_input_ids"],
            decoder_attention_mask=batch["decoder_attention_mask"],
        )[0]  # zero because output is some structure, where first is the logit
        loss_sum, num_labels = loss_fn(
            logits,
            batch["labels"],
            batch["decoder_attention_mask"],
            label_smoothing_factor,
        )
        # Normalise by the token count; clamp to 1 so an all-padding batch
        # yields loss 0 instead of NaN.
        num_labels = jnp.maximum(num_labels, 1)
        return loss_sum / num_labels, num_labels

    # compute gradients through computational graph; num_labels passes through as aux
    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, _num_labels), grad = grad_fn(state.params, batch)

    new_state = state.apply_gradients(grads=grad)

    with jax.named_scope("sync_metrics"):
        # single-device view: no psum needed, the compiler handles reductions
        step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}

    return new_state, step_metrics


# %%
# prep 1 step
print("1 step for jit-ting")
# The first call traces and compiles train_step. It is also a real update on
# `batch`: the returned `state` is what the training loop continues from.
# `initialized_state` is donated to train_step and must not be reused.
with mesh:
    state, metrics = train_step(initialized_state, batch)
# %%

# %%
# tr
# Summarise the run configuration before training starts.
run_summary_lines = (
    "***** Running training *****",
    f" Num examples = {training_size}",
    f" Num Epochs = {num_epochs}",
    f" Instantaneous batch size per device = {per_device_train_batch_size}",
    f" Total train batch size (w. parallel & distributed) = {train_batch_size}",
    f" Total optimization steps = {total_train_steps}",
)
for summary_line in run_summary_lines:
    print(summary_line)
# %%
# jax.profiler.start_trace("./traces")


def shard_batch(batch):
    """Put every leaf of the ``batch`` pytree on the mesh with ``x_sharding``.

    ``jax.device_put`` accepts a whole pytree with a single sharding that is
    applied to each leaf, so no explicit tree_map is needed.
    """
    return jax.device_put(batch, x_sharding)


print("*" * 10)
print("training start")
train_time = 0

# Number of optimisation steps per epoch (constant for the whole run).
steps_per_epoch = training_size // train_batch_size
if steps_per_epoch == 0:
    raise ValueError(
        f"training_size={training_size} is smaller than train_batch_size="
        f"{train_batch_size}; an epoch would contain no training steps"
    )

epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
    epochs.set_description(f"Epoch ... ({epoch + 1}/{num_epochs})")
    train_start = time.time()

    # Create sampling rng: split off a fresh key every epoch so that the
    # shuffle is re-keyed each epoch (a single fixed key would presumably give
    # the same order every epoch).
    rng, input_rng = jax.random.split(rng)
    train_metrics = []
    train_loader = dataprep.data_loader(input_rng, batch_size=batch_size, shuffle=True, drop_last=True)
    # iter() works whether data_loader returns an iterator or a re-iterable object
    loader_iter = iter(train_loader)

    # Generate an epoch by shuffling sampling indices from the train dataset
    for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
        batch = next(loader_iter)
        # send to device, laid out with x_sharding
        sharded_batch = shard_batch(batch)
        with mesh:
            state, train_metric = train_step(state, sharded_batch)
        # train_metrics.append(train_metric)

    # Wait for the epoch's last step before reading the clock so train_time
    # includes the queued device work. Done once per epoch, not per step, so
    # it does not slow training down.
    jax.block_until_ready(train_metric)
    train_time = time.time() - train_start

    epochs.write(
        f"Epoch... ({epoch + 1}/{num_epochs} | "
        f"Loss: {train_metric['loss']}, "
        f"Learning Rate:{train_metric['learning_rate']}, "
        f"Last train time: {train_time})"
    )
# jax.profiler.stop_trace()
# %%
# Earlier attempt at gathering the (possibly sharded) params onto one device:
# with mesh:
#     gathered_params = jax.jit(
#         lambda x: x,
#         in_shardings=(unfreeze(model_named_sharding),),
#         out_shardings=mesh_sharding(PartitionSpec())
#     )(state.params)

# Only this model's save_pretrained() is used below.
main_model = custom_model.from_pretrained('t5-base')
output_dir = save_path
# Save the fine-tuned weights once, after training; only process 0 writes.
if jax.process_index() == 0:
    # Take the *trained* parameters from the final TrainState. The
    # module-level `params` still holds the pretrained weights loaded before
    # training. Copy them to host memory as float32 before serialising.
    params = jax.device_get(state.params)
    params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
    main_model.save_pretrained(output_dir, params=params)

# Orbax alternative (stick to defaults):
# options = ocp.CheckpointManagerOptions()
# with ocp.CheckpointManager(
#     ocp.test_utils.erase_and_create_empty(save_path),
#     options=options,
# ) as mngr:
#
#     mngr.save(0, args=ocp.args.StandardSave(state))
#     mngr.wait_until_finished()

# After providing `args` during an initial `save` or `restore` call, the
# `CheckpointManager` instance records the type so that you do not need to
# specify it again. If the `CheckpointManager` instance is not provided with a
# `ocp.args.CheckpointArgs` instance for a particular item on a previous
# occasion it cannot be restored without specifying the argument at restore
# time.

# # In many cases, you can restore exactly as saved without specifying additional
# # arguments.
# mngr.restore(0)
# # If customization of properties like sharding or dtype is desired, just provide
# # the abstract target PyTree, the properties of which will be used to set
# # the properties of the restored arrays.
# mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree))

# %%