# learn_jax/parallel/t5_pjit.py
# MARK: START
# %%
# let's make a 4-device CPU simulator (to match the 2x2 mesh used below)
import sys
sys.dont_write_bytecode = True
import os
# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True
flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    flags += " --xla_force_host_platform_device_count=4"  # simulate 4 CPU devices
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    os.environ["JAX_PLATFORMS"] = "cpu"
else:
    # GPU flags
    flags += (
        "--xla_gpu_enable_triton_softmax_fusion=true "
        "--xla_gpu_triton_gemm_any=false "
        "--xla_gpu_enable_async_collectives=true "
        "--xla_gpu_enable_latency_hiding_scheduler=true "
        "--xla_gpu_enable_highest_priority_async_stream=true "
    )
os.environ["XLA_FLAGS"] = flags
import functools
from functools import partial
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
# from jax.experimental.pjit import pjit # superseded by jax.jit
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec
from ml_collections import ConfigDict
import optax
import logging
import time
from datasets import Dataset, load_from_disk
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.core.frozen_dict import freeze, unfreeze
import flax.core
from partitions import set_partitions
from tqdm import tqdm
from dataload import DataPrepare
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
if USE_CPU_ONLY:
    jax.config.update('jax_platform_name', 'cpu')
else:
    jax.config.update("jax_default_matmul_precision", "bfloat16")
# %%
# get platform type
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
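# list the devices JAX sees; with USE_CPU_ONLY=True this should show the four
# simulated CPU devices created by --xla_force_host_platform_device_count
print(jax.devices())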
# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
save_path = 't5_80_1_bf16'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])
# Store some constant
seed = 117
num_epochs = 5
batch_size = 2 # 384 is the best
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
warmup_steps = 0
learning_rate = 2e-5
weight_decay = 0.01
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0
num_beams = 1
val_max_target_length = 128
predict_with_generate = True
# %%
# prepare data
# init object
# e.g. Config
data_config = ConfigDict(
    dict(
        max_length=86,
        pad_token_id=0,
        decoder_start_token_id=0,
    )
)
dataprep = DataPrepare(split_datasets['train'], data_config)
# # example usage
# # %%
seed = 117
rng = jax.random.PRNGKey(seed)
train_loader = dataprep.data_loader(rng, batch_size=batch_size)
batch = next(iter(train_loader))
# batch
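# %%
# optional: peek at the batch structure (keys, shapes and dtypes only);
# assumes the loader yields a dict-like mapping of name -> array
pprint({key: (np.asarray(value).shape, np.asarray(value).dtype) for key, value in batch.items()})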
# %%
# model
from t5_model.pure_t5 import FlaxT5ForConditionalGenerationModule as model_init
# from t5_model.pure_t5 import FlaxT5DenseActDense as model_init
from t5_model.pure_t5 import make_config
config = make_config()
model = model_init(config)
# %%
from transformers import FlaxT5ForConditionalGeneration
from transformers import T5Config
# load the pretrained weights under a separate name: `model` must stay the pure_t5
# nn.Module from above, because init_fn/train_step below rely on nn.Module-style
# model.init / model.apply, which the HF FlaxT5ForConditionalGeneration wrapper
# does not expose
hf_model, params = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=False)
# useful for transformer model
# hf_model.enable_gradient_checkpointing()
# enable bf16 except for layer_norm
# from flax import traverse_util
# flat_params = traverse_util.flatten_dict(model.params)
# mask = {
# path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
# }
# mask = traverse_util.unflatten_dict(mask)
# model.params = model.to_bf16(model.params, mask)
##################################################################
# set partition on model
# %%
# # let's output the model parameters to a json file for study
# import json
# shape_dict = jax.tree.map(jnp.shape, params)
# # print(json.dumps(shape_dict, sort_keys=True, indent=4))
# with open('t5.json', 'w') as f:
# json.dump(shape_dict, fp=f, sort_keys=True, indent=2)
# MARK: setup mesh
# %%
device_mesh = mesh_utils.create_device_mesh((2,2))
print(device_mesh)
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
    return NamedSharding(mesh, pspec)
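# %%
# quick sanity check (a sketch, not required for training): put a dummy
# (batch, feature) array on the mesh, sharding dim 0 over 'data' and
# replicating over 'model'
_demo = jax.device_put(jnp.zeros((8, 16)), mesh_sharding(PartitionSpec('data', None)))
print(_demo.sharding)
print(len(_demo.addressable_shards), "shards, each of shape", _demo.addressable_shards[0].data.shape)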
##################################################
# optimizers
# %%
def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear decay learning rate function."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
    training_size,
    train_batch_size,
    num_train_epochs,
    warmup_steps,
    learning_rate,
)
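# %%
# optional: probe the schedule at a few steps to confirm its shape
# (with warmup_steps=0 this is a pure linear decay from learning_rate down to 0)
for step in (0, total_train_steps // 2, total_train_steps):
    print(step, float(linear_decay_lr_schedule_fn(step)))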
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    # find out all LayerNorm parameters
    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
    layer_norm_named_params = {
        layer[-2:]
        for layer_norm_name in layer_norm_candidates
        for layer in flat_params.keys()
        if layer_norm_name in "".join(layer).lower()
    }
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)
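# %%
# illustration only: decay_mask_fn on a toy (hypothetical) parameter tree;
# kernels get weight decay (True), biases and layer-norm scales do not (False)
_toy_params = {
    'dense': {'kernel': jnp.ones((2, 2)), 'bias': jnp.ones((2,))},
    'final_layer_norm': {'weight': jnp.ones((2,))},
}
pprint(decay_mask_fn(_toy_params))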
# create adam optimizer
adamw = optax.adamw(
    learning_rate=linear_decay_lr_schedule_fn,
    b1=adam_beta1,
    b2=adam_beta2,
    eps=adam_epsilon,
    weight_decay=weight_decay,
    mask=decay_mask_fn,
)
# %%
# specify sharding
# shard data
x_sharding = mesh_sharding(PartitionSpec('data', None))  # shard the batch dim across the 'data' axis, replicate across 'model'
batch = {key: jax.device_put(jnp.array(value), x_sharding) for key, value in batch.items()}
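# %%
# each batch entry is now a jax.Array whose leading (batch) dimension is split
# across the 'data' axis of the mesh; quick inspection of shapes and shardings
for key, value in batch.items():
    print(key, value.shape, value.sharding)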
# Defining the required dimensions for the self-attention layer input
# batch_size = 2
# seq_length = 768
# n_heads = 12
# head_dim = 768
# %%
# Create a large array with the shape (batch_size, seq_length, n_heads, head_dim)
# large_input = np.random.rand(2,768,768)
# batch = jax.device_put(large_input, x_sharding)
# %%
# jax.debug.visualize_array_sharding(batch['input_ids'])
# %%
# shard output
# we will shard state by tracking its output upon jax.eval_shape after init
# define an init function to return a TrainState
def init_fn(rng, batch, model, optimizer):
    # do be careful with the model init
    # imported models might have complicated init methods
    # (decoder_attention_mask is passed as decoder_input_ids here purely because
    # it has the right shape; only the input shapes matter for parameter creation)
    variables = model.init(
        rng,
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        decoder_input_ids=batch['decoder_attention_mask'],
        decoder_attention_mask=batch['decoder_attention_mask'],
    )
    state = train_state.TrainState.create(  # Create a `TrainState`.
        apply_fn=model.apply,
        params=variables['params'],
        tx=optimizer,
    )
    return state
# %%
# alternative
# def init_fn(rng, batch, model, optimizer):
# # do be careful with the model init
# # imported models might have complicated init methods
# variables = model.init(
# rng, batch
# )
# state = train_state.TrainState.create( # Create a `TrainState`.
# apply_fn=model.apply,
# params=variables['params'],
# tx=optimizer)
# return state
# %%
# Create an abstract closure to wrap the function before feeding it in
# because `jax.eval_shape` only takes pytrees as arguments.
# eval_shape(fn, rng_key, x)
# used to perform shape inference
# returns a nested PyTree containing jax.ShapeDtypeStruct objects as leaves
rng, init_rng = jax.random.split(rng)
abstract_variables = jax.eval_shape(
    functools.partial(init_fn, model=model, optimizer=adamw), init_rng, batch)
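# %%
# optional: abstract_variables mirrors the TrainState pytree, but every leaf is a
# jax.ShapeDtypeStruct (shape + dtype, no data); nn.get_sharding below extracts
# the partitioning annotations from this tree
_leaves = jax.tree_util.tree_leaves(abstract_variables)
print(len(_leaves), "leaves, e.g.", _leaves[0])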
# %%
# This `state_sharding` has the same pytree structure as `state`, the output
# of the `init_fn`.
# flax.linen.get_sharding
# extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh
# jax.sharding: describes how a jax.Array is laid out across devices
state_sharding = nn.get_sharding(abstract_variables, mesh)
print(state_sharding)
# warning: do not have singleton None in your nn.partition definitions, it will screw with your sanity
# %%
jit_init_fn = jax.jit(
    init_fn,
    static_argnames=('model', 'optimizer'),  # skip model and optimizer
    in_shardings=(mesh_sharding(()), x_sharding),  # for PRNG key and data
    out_shardings=state_sharding,
)
rng, init_rng = jax.random.split(rng)
initialized_state = jit_init_fn(init_rng, batch, model, adamw)
# %%
# we can analyze the params structure
# for weight, partitioned in initialized_state.params['decoder'].items():
# print(f'Sharding of {weight}: {partitioned}')
# jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
# jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)
jax.tree.map(jnp.shape, initialized_state.params['decoder'])
# %%
print(initialized_state.params['decoder']['block']['0']['layer']['0']['SelfAttention']['k']['kernel'].value.sharding)
print(initialized_state.step)
print(initialized_state.step.sharding)
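# %%
# optional: visualize how that same attention kernel is laid out across the mesh
# (jax.debug.visualize_array_sharding needs the `rich` package installed)
jax.debug.visualize_array_sharding(
    initialized_state.params['decoder']['block']['0']['layer']['0']['SelfAttention']['k']['kernel'].value
)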
# %%
# train step
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    """
    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
    """
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing_factor
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    normalizing_constant = -(
        confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    )
    soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
    loss = optax.softmax_cross_entropy(logits, soft_labels)
    loss = loss - normalizing_constant
    # ignore padded tokens from loss
    loss = loss * padding_mask
    loss = loss.sum()
    num_labels = padding_mask.sum()
    return loss, num_labels
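# %%
# quick numeric sanity check of loss_fn on toy data (illustrative sketch only):
# one sequence of two tokens, the second masked out, no label smoothing
_toy_logits = jnp.array([[[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]]])  # (batch=1, len=2, vocab=3)
_toy_labels = jnp.array([[0, 1]])
_toy_mask = jnp.array([[1, 0]])  # only the first token contributes
_toy_loss, _toy_num = loss_fn(_toy_logits, _toy_labels, _toy_mask, label_smoothing_factor=0.0)
print(_toy_loss / _toy_num)  # per-token cross entropy of the first position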
# %%
# single device code annotated with jax.jit
@functools.partial(
    jax.jit,
    # in_shardings=(state_sharding, x_sharding),
    out_shardings=state_sharding,
)
def train_step(state, batch):
    label_smoothing_factor = 0.0
    # dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

    def compute_loss(params):
        labels = batch.pop("labels")
        logits = state.apply_fn(
            {'params': params},
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            # NOTE: this feeds the decoder attention mask as decoder_input_ids,
            # mirroring the init call above; if the data loader provides proper
            # shifted targets (e.g. a 'decoder_input_ids' key), use those instead
            decoder_input_ids=batch['decoder_attention_mask'],
            decoder_attention_mask=batch['decoder_attention_mask'],
        )[0]
        loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
        return loss, num_labels

    # compute gradients through computational graph
    # allow values to pass through
    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, num_labels), grad = grad_fn(state.params)
    # num_labels = jax.lax.psum(num_labels, "batch")
    # true grad = total grad / total samples
    # grad = jax.lax.psum(grad, "batch")
    # grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
    new_state = state.apply_gradients(grads=grad)
    # metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
    return new_state
# %%
with mesh:
    new_state = train_step(initialized_state, batch)
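# %%
# a minimal sketch of how the pieces above combine into a full training loop.
# Assumptions not verified here: dataprep.data_loader() reshuffles when handed a
# fresh rng, and every batch carries the same keys and shapes as the init batch.
state = initialized_state
for epoch in range(num_epochs):
    rng, loader_rng = jax.random.split(rng)
    train_loader = dataprep.data_loader(loader_rng, batch_size=batch_size)
    for raw_batch in tqdm(train_loader, total=training_size // batch_size, desc=f"epoch {epoch}"):
        sharded_batch = {
            key: jax.device_put(jnp.array(value), x_sharding)
            for key, value in raw_batch.items()
        }
        with mesh:
            state = train_step(state, sharded_batch)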
# %%
# # %%
# #############################################################
# # we cannot integrate our model pspec with train_state
# # we just shard separately
# # update: we also cannot use the method of modifying a partitionspec tree
# # we have to do it the RIGHT way, following flax_pjit_tutorial to the letter
#
# # %%
# def get_optim_initial_state(params):
#     params = params
#     state = adamw.init(params)
#     return tuple((state)), params
#
# # %%
# # create partitions for model
# from partitions import set_partitions
# # set_partitions freezes the params on return
# model_param_spec = set_partitions(unfreeze(params))
#
# # %%
# params_shapes = jax.tree.map(lambda x: x.shape, params)
# # actually tuple
# optim_state_shapes = jax.eval_shape(get_optim_initial_state, params_shapes)
#
# # %%
# # get pspec for opt_state
# def get_opt_spec(x):
#     if isinstance(x, dict):
#         return unfreeze(model_param_spec)
#     return PartitionSpec()
#
# # this function replaces the empty model params spec with the 'model_param_spec'
# opt_state_spec, param_spec = jax.tree.map(
#     get_opt_spec, optim_state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
# )
#
# # %%
#
# model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=True)
# # store on cpu
# model.params = jax.tree_util.tree_map(lambda x: np.asarray(x), model.params)
#
# # %%
# device_mesh = mesh_utils.create_device_mesh((2,2))
# print(device_mesh)
#
# mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
# print(mesh)
#
# def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
#     return NamedSharding(mesh, pspec)
#
#
# # opt_state_sharding = mesh_sharding(opt_state_spec)
# # param_sharding = mesh_sharding(param_spec)
#
#
# # %%
# opt_state_sharding = nn.get_sharding(opt_state_spec, mesh)
# param_sharding = nn.get_sharding(param_spec, mesh)
#
# # %%
# # jit the get_initial_state function to shard params and init optimizer state in
# # a sharded way
# from jax.experimental.pjit import pjit
#
# with mesh:
#     p_get_initial_state = pjit(
#         get_optim_initial_state,
#         in_shardings=None,
#         out_shardings=(opt_state_spec, param_spec),
#     )
#
# # Convert your PartitionSpec to NamedSharding for model params
# param_sharding = NamedSharding(mesh, freeze(param_spec))
# # Use device_put with sharding to move params onto the mesh
# sharded_params = jax.device_put(freeze(params), param_sharding)
#
# with mesh:
#     # params is already frozen
#     sharded_opt_state, sharded_params = p_get_initial_state(unfreeze(sharded_params))
#
# # %%
#
# # give up this section
# #############################################################
# # create train state
#
# # %%
# # Initialize random key and input for initialization
# rng = jax.random.PRNGKey(seed)
# loader_rng, rng = jax.random.split(rng)
# train_loader = dataprep.data_loader(rng, batch_size=2)
# batch = next(iter(train_loader))
#
# # use the T5 base model to do this
# from transformers import FlaxAutoModel
# model, params = FlaxAutoModel.from_pretrained(
#     't5-base',
#     _do_init=False
# )
# t5_module = model.module
#
# # %%
# init_rng, rng = jax.random.split(rng)
# variables = t5_module.init(init_rng,
#     input_ids=batch['input_ids'],
#     attention_mask=batch['attention_mask'],
#     decoder_input_ids=batch['decoder_attention_mask'],
#     decoder_attention_mask=batch['decoder_attention_mask']
# )
# params = variables['params']
#
# # create an init function
# # %%
# # we will shard state by tracking its output upon jax.eval_shape after init
# # define an init function to return a TrainState
# def init_fn(rng: jax.random.PRNGKey, batch=batch, model=t5_module, optimizer=adamw) -> train_state.TrainState:
#     init_rng, rng = jax.random.split(rng)
#     variables = model.init(
#         init_rng,
#         input_ids=batch['input_ids'],
#         attention_mask=batch['attention_mask'],
#         decoder_input_ids=batch['decoder_attention_mask'],
#         decoder_attention_mask=batch['decoder_attention_mask']
#     )
#     params = variables.pop("params")
#     state = train_state.TrainState.create(
#         apply_fn=model.__call__,
#         params=params,
#         tx=optimizer,
#     )
#     return state
#
# # model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=True)
# # Create an abstract closure to wrap the function before feeding it in
# # because `jax.eval_shape` only takes pytrees as arguments.
# # eval_shape(fn, rng_key, x)
# # used to perform shape inference
# # returns a nested PyTree containing jax.ShapeDtypeStruct objects as leaves
# init_rng, rng = jax.random.split(rng)
# abstract_variables = jax.eval_shape(
#     functools.partial(init_fn, model=t5_module, optimizer=adamw),
#     init_rng,
#     batch
# )
#
# # %%
# # let's make our mesh
#
# device_mesh = mesh_utils.create_device_mesh((2,2))
# print(device_mesh)
#
# mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
# print(mesh)
#
# def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
#     return NamedSharding(mesh, pspec)
#
# # %%
# # making jax compatible batch
#
# # %%
# x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis
# # batch = jax.device_put(batch), x_sharding)
# # jax.debug.visualize_array_sharding(batch)
#
# # %%
# state_sharding = nn.get_sharding(abstract_variables, mesh)
# print(state_sharding)
#
# # %%
# # integrate model_param_specs and state_out_specs
#
# # %%
# # i want to make a Sharding object
# # model_sharding = mesh_sharding(model_param_spec)
#
# # %%
# jit_init_fn = jax.jit(
#     init_fn,  # rng, batch, model, optimizer
#     static_argnames=('model', 'optimizer'),  # skip model and optimizer
#     in_shardings=(mesh_sharding(()), x_sharding), # mesh_sharding(()), mesh_sharding(())), # for PRNG key and data
#     out_shardings=state_sharding
# )
#
# # %%
#
# init_rng, rng = jax.random.split(rng)
# initialized_state = jit_init_fn(
#     init_rng,
#     batch,
#     t5_module,
#     adamw)
#
# # jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
# # jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)
#
#
# # %%
#
# # %%
#
# %%