Feat: flax pjit example

Richard Wong 2024-09-16 12:19:07 +09:00
parent ad5cf7735f
commit 429e1742ab
5 changed files with 1078 additions and 63 deletions

View File

@@ -98,21 +98,21 @@ class DataPrepare():
         )
         # check that data fits
-        for name in ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'decoder_attention_mask']:
-            int_array: np.array = train_dataset[name]
-            if np.all((int_array >= 0) & (int_array <= 65535)):
-                continue
-            else:
-                raise ValueError("Values are out of range for uint16")
+        # for name in ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'decoder_attention_mask']:
+        #     int_array: np.array = train_dataset[name]
+        #     if np.all((int_array >= 0) & (int_array <= 65535)):
+        #         continue
+        #     else:
+        #         raise ValueError("Values are out of range for uint16")
         # change to compact datatypes
-        features = train_dataset.features.copy()
-        features['input_ids'] = Sequence(Value('uint16'))
-        features['attention_mask'] = Sequence(Value('bool'))
-        features['labels'] = Sequence(Value('uint16'))
-        features['decoder_input_ids'] = Sequence(Value('uint16'))
-        features['decoder_attention_mask'] = Sequence(Value('bool'))
-        train_dataset = train_dataset.cast(features)
+        # features = train_dataset.features.copy()
+        # features['input_ids'] = Sequence(Value('uint16'))
+        # features['attention_mask'] = Sequence(Value('uint16'))
+        # features['labels'] = Sequence(Value('uint16'))
+        # features['decoder_input_ids'] = Sequence(Value('uint16'))
+        # features['decoder_attention_mask'] = Sequence(Value('uint16'))
+        # train_dataset = train_dataset.cast(features)
         # assign the dataset to train_dataset
         self.train_dataset = train_dataset
@@ -140,15 +140,15 @@ class DataPrepare():
         for idx in batch_idx:
             batch = dataset[idx]
-            batch = {k: jnp.array(v) for k, v in batch.items()}
+            batch = {k: v for k, v in batch.items()}
             yield batch

 # testing out the class
-# # %%
-# # init object
-# # e.g. Config
+# %%
+# init object
+# e.g. Config
 # data_config = ConfigDict(
 #     dict(
 #         max_length=86,
@@ -157,7 +157,9 @@ class DataPrepare():
 #     )
 # )
 #
-# dataprep = DataPrepare(split_datasets, data_config)
+# from datasets import load_from_disk
+# split_datasets = load_from_disk(file_path)
+# dataprep = DataPrepare(split_datasets['train'], data_config)
 #
 # # %%
 # seed = 117

View File

@ -0,0 +1,456 @@
# MARK: START
# %%
# let's make a 4-device simulator
import os
# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True
flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices
# Enforce CPU-only execution
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["JAX_PLATFORMS"] = "cpu"
else:
# GPU flags
flags += (
"--xla_gpu_enable_triton_softmax_fusion=true "
"--xla_gpu_triton_gemm_any=false "
"--xla_gpu_enable_async_collectives=true "
"--xla_gpu_enable_latency_hiding_scheduler=true "
"--xla_gpu_enable_highest_priority_async_stream=true "
)
os.environ["XLA_FLAGS"] = flags
import functools
from functools import partial
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
# from jax.experimental.pjit import pjit # superseded by jax.jit
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec
from ml_collections import ConfigDict
import optax
import logging
import time
from datasets import Dataset, load_from_disk
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
import flax.core
from tqdm import tqdm
from dataload import DataPrepare
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
if USE_CPU_ONLY:
jax.config.update('jax_platform_name', 'cpu')
else:
jax.config.update("jax_default_matmul_precision", "bfloat16")
# MARK: sharding example
# %%
device_mesh = mesh_utils.create_device_mesh((2,2))
print(device_mesh)
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
return NamedSharding(mesh, pspec)
# %%
# define a layer
class DotReluDot(nn.Module):
depth: int
dense_init: Callable = nn.initializers.xavier_normal()
@nn.compact
def __call__(self, x):
# y has shape (x.shape[-1], self.depth)
# we replicate data across devices
# but we shard the layer across the model axes
y = nn.Dense(self.depth,
kernel_init=nn.with_partitioning(self.dense_init, (None, 'model')),
use_bias=False, # or overwrite with `bias_init`
)(x)
y = jax.nn.relu(y)
# Force a local sharding annotation.
# annotate intermediate variables to force a particular sharding pattern
# when ideal constraint is known
y = jax.lax.with_sharding_constraint(y, mesh_sharding(PartitionSpec('data', 'model')))
W2 = self.param(
'W2',
nn.with_partitioning(self.dense_init, ('model', None)),
(self.depth, x.shape[-1]))
z = jnp.dot(y, W2)
# Force a local sharding annotation.
z = jax.lax.with_sharding_constraint(z, mesh_sharding(PartitionSpec('data', None)))
# Return a tuple to conform with the API `flax.linen.scan` as shown in the cell below.
return z, None
# %%
class MLP(nn.Module):
num_layers: int
depth: int
use_scan: bool
@nn.compact
def __call__(self, x):
if self.use_scan:
x, _ = nn.scan(
DotReluDot,
length=self.num_layers,
variable_axes={"params": 0},
split_rngs={"params": True},
metadata_params={nn.PARTITION_NAME: None}
)(self.depth)(x)
else:
for _ in range(self.num_layers):
x, _ = DotReluDot(self.depth)(x)
return x
# %%
# MLP hyperparameters.
BATCH, LAYERS, DEPTH, USE_SCAN = 8, 4, 1024, False
# Create fake inputs.
x = jnp.ones((BATCH, DEPTH))
# Initialize a PRNG key.
k = jax.random.key(117)
# Create an Optax optimizer.
optimizer = optax.adam(learning_rate=0.001)
# Instantiate the model.
model = MLP(LAYERS, DEPTH, USE_SCAN)
# %%
# specify sharding
# shard data
x_sharding = mesh_sharding(PartitionSpec('data', None))  # shard the batch dim across the 'data' axis, replicate the feature dim
x = jax.device_put(x, x_sharding)
jax.debug.visualize_array_sharding(x)
# shard the output (the TrainState)
# we derive the state sharding from the abstract output of jax.eval_shape over the init function
# define an init function that returns a TrainState
def init_fn(key, x, model, optimizer):
# do be careful with the model init
# imported models might have complicated init methods
variables = model.init(key, x) # Initialize the model.
state = train_state.TrainState.create( # Create a `TrainState`.
apply_fn=model.apply,
params=variables['params'],
tx=optimizer)
return state
# Create an abstract closure to wrap the function before feeding it in,
# because `jax.eval_shape` only takes pytree arguments (the model and optimizer
# are closed over via functools.partial rather than passed in directly)
# eval_shape(fn, rng_key, x)
# used to perform shape inference
# returns a nested PyTree containing jax.ShapeDtypeStruct objects as leaves
abstract_variables = jax.eval_shape(
functools.partial(init_fn, model=model, optimizer=optimizer), k, x)
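# (Illustrative check, not in the original notebook) eval_shape allocates no device memory,
# so we can inspect the abstract shapes of the would-be TrainState directly.
print(jax.tree.map(lambda leaf: leaf.shape, abstract_variables))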
# This `state_sharding` has the same pytree structure as `state`, the output
# of the `init_fn`.
# flax.linen.get_sharding
# extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh
# a jax.sharding.Sharding describes how a jax.Array is laid out across devices
state_sharding = nn.get_sharding(abstract_variables, mesh)
print(state_sharding)
# %%
jit_init_fn = jax.jit(
init_fn,
static_argnames=('model', 'optimizer'), # skip model and optimizer
in_shardings=(mesh_sharding(()), x_sharding), # for PRNG key and data
out_shardings=state_sharding
)
initialized_state = jit_init_fn(k, x, model, optimizer)
# for weight, partitioned in initialized_state.params['DotReluDot_0'].items():
# print(f'Sharding of {weight}: {partitioned.names}')
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)
# %%
# inspect module output
# the params are actually flax.linen.Partitioned objects
# each Partitioned object wraps the underlying jax.Array
print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel']))
print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value))
print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].names)
print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.shape)
# %%
# Say for some unknown reason you want to make the whole param tree all-zero
unboxed_params = nn.meta.unbox(initialized_state.params)
all_zero = jax.tree.map(jnp.zeros_like, unboxed_params)
all_zero_params = nn.meta.replace_boxed(initialized_state.params, all_zero)
assert jnp.sum(nn.meta.unbox(all_zero_params['DotReluDot_0']['Dense_0']['kernel'])) == 0
# %%
# check the jax.sharding of each parameter
print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.sharding)
print(initialized_state.step)
print(initialized_state.step.sharding)
# %%
# example of computation on pytree
diff = jax.tree.map(
lambda a, b: a - b,
initialized_state.params['DotReluDot_0'],
initialized_state.params['DotReluDot_0']
)
print(jax.tree.map(jnp.shape, diff))
diff_array = diff['Dense_0']['kernel'].value
print(type(diff_array))
print(diff_array.shape)
# %%
# compile train step
@functools.partial(
jax.jit,
in_shardings=(state_sharding, x_sharding),
out_shardings=state_sharding
)
def train_step(state, x):
def loss_unrolled(params):
y = model.apply({'params': params}, x)
return y.sum()
grad_fn = jax.grad(loss_unrolled)
grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
return state
# this trains for 1 step
# a `with mesh:` block is not strictly necessary here:
# it gives an explicit scope for device sharding,
# but mesh management is already handled by the jit sharding annotations
new_state = train_step(initialized_state, x)
print(f'Sharding of Weight 1:')
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
print(f'Sharding of Weight 2:')
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)
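# (Illustrative) the updated state keeps the sharding requested via out_shardings;
# check one parameter's jax.sharding directly.
print(new_state.params['DotReluDot_0']['W2'].value.sharding)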
# %%
# compile inference step
@functools.partial(
jax.jit,
in_shardings=(state_sharding, x_sharding),
out_shardings=x_sharding
)
def apply_fn(state, x):
return state.apply_fn({'params': state.params}, x)
# this infers for 1 step
with mesh:
y = apply_fn(new_state, x)
print(type(y))
print(y.dtype)
print(y.shape)
jax.debug.visualize_array_sharding(y)
# %%
# profiling
# measure performance
import timeit
def block_all(xs):
jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)
return xs
# with mesh:
t = timeit.timeit("block_all(train_step(initialized_state, x))", globals=globals(), number=10)
print(t)
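# (Optional sketch, not in the original notebook) the same timing pattern works for the
# jitted inference step, which helps separate forward-only cost from forward+backward+update cost.
with mesh:
    t_infer = timeit.timeit("block_all(apply_fn(new_state, x))", globals=globals(), number=10)
print(t_infer)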
# MARK: logical sharding
# %%
# logical axis annotation
# why?
# rather than being fixed to 'data' and 'model', we can annotate with logical names,
# then map these logical names back to 'data' and 'model',
# or be more flexible with more axes
#
# we will substitute the following:
# flax.linen.with_partitioning     -> flax.linen.with_logical_partitioning
# jax.lax.with_sharding_constraint -> flax.linen.with_logical_constraint
class LogicalDotReluDot(nn.Module):
depth: int
dense_init: Callable = nn.initializers.xavier_normal()
@nn.compact
def __call__(self, x):
y = nn.Dense(
self.depth,
# use of logical partitioning here
# kernel_init is the initializer function for the weight matrix
kernel_init=nn.with_logical_partitioning(self.dense_init, ('embed', 'hidden')),
use_bias=False, # or overwrite with `bias_init`
)(x)
y = jax.nn.relu(y)
# Force a local sharding annotation.
y = jax.lax.with_sharding_constraint(y, mesh_sharding(PartitionSpec('data', 'model')))
W2 = self.param(
'W2',
nn.with_logical_partitioning(self.dense_init, ('hidden', 'embed')),
(self.depth, x.shape[-1]))
z = jnp.dot(y, W2)
# Force a local sharding annotation.
z = nn.with_logical_constraint(z, ('batch', 'embed'))
return z, None
class LogicalMLP(nn.Module):
num_layers: int
depth: int
use_scan: bool
@nn.compact
def __call__(self, x):
if self.use_scan:
x, _ = nn.scan(LogicalDotReluDot, length=self.num_layers,
variable_axes={"params": 0},
split_rngs={"params": True},
metadata_params={nn.PARTITION_NAME: 'layer'}
)(self.depth)(x)
else:
for _ in range(self.num_layers):
x, _ = LogicalDotReluDot(self.depth)(x)
return x
# %%
# we initialize the model with the eval_shape method,
# but we need rules that map the logical axis names onto the real mesh axes
# (a logical axis without a rule, e.g. 'embed', is left unsharded)
rules = (
    ('batch', 'data'),
    ('hidden', 'model'),
)
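# (Illustrative) a single logical spec maps through the rules like this; 'embed' has no
# rule, so that dimension is left unsharded:
print(nn.logical_to_mesh_axes(('batch', 'embed'), rules))  # PartitionSpec('data', None)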
logical_model = LogicalMLP(LAYERS, DEPTH, USE_SCAN)
logical_abstract_variables = jax.eval_shape(
functools.partial(init_fn, model=logical_model, optimizer=optimizer),
k,
x,
)
# linen.get_partition_spec
# extracts a partitionspec tree from a pytree containing partitioned values
# linen.Partitioned
# wrapper for partitioning metadata
logical_state_spec = nn.get_partition_spec(logical_abstract_variables)
print(
"annotations are logical, not mesh specific",
logical_state_spec.params['LogicalDotReluDot_0']['Dense_0']['kernel']
)
# we convert our logical_state_spec to a logical_state_sharding
# with defined rules
logical_state_sharding = nn.logical_to_mesh_sharding(
logical_state_spec,
mesh,
rules
)
print('sharding annotations are mesh-specific: ',
logical_state_sharding.params['LogicalDotReluDot_0']['Dense_0']['kernel'].spec)
# %%
# with a working state_sharding object, we can init
logical_jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3),
in_shardings=(mesh_sharding(()), x_sharding), # PRNG key and x
out_shardings=logical_state_sharding)
logical_initialized_state = logical_jit_init_fn(k, x, logical_model, optimizer)
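# (Illustrative) confirm that the logical rules produced the same physical layout as the
# explicit 'data'/'model' annotations earlier.
jax.debug.visualize_array_sharding(
    logical_initialized_state.params['LogicalDotReluDot_0']['Dense_0']['kernel'].value)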
# %%
# MARK: saving checkpoint
# https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#multi-host-multi-process-checkpointing
# let us save the model
# since we already have a mesh, we will skip the making of the mesh
from typing import Optional, Any
import shutil
import numpy as np
import jax
from jax import random, numpy as jnp
import flax
from flax import linen as nn
from flax.training import checkpoints, train_state
from flax import struct, serialization
import orbax.checkpoint
import optax
ckpt_dir = '/tmp/flax_ckpt'
if os.path.exists(ckpt_dir):
shutil.rmtree(ckpt_dir) # Remove any existing checkpoints from the last notebook run.
# %%
# make up some stuff
# A simple model with one linear layer.
key1, key2 = random.split(random.key(0))
x1 = random.normal(key1, (5,)) # A simple JAX array.
model = nn.Dense(features=3)
variables = model.init(key2, x1)
# Flax's TrainState is a pytree dataclass and is supported in checkpointing.
# Define your class with the `@flax.struct.dataclass` decorator to make it compatible (a sketch follows at the end of this cell).
tx = optax.sgd(learning_rate=0.001) # An Optax SGD optimizer.
state = train_state.TrainState.create(
apply_fn=model.apply,
params=variables['params'],
tx=tx)
# Perform a simple gradient update similar to the one during a normal training workflow.
state = state.apply_gradients(grads=jax.tree_util.tree_map(jnp.ones_like, state.params))
# Some arbitrary nested pytree with a dictionary and a NumPy array.
config = {'dimensions': np.array([5, 3])}
# Bundle everything together.
ckpt = {'model': state, 'config': config, 'data': [x1]}
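# (Hedged sketch, not in the original commit) if you need to checkpoint your own container
# alongside the TrainState, declaring it with @flax.struct.dataclass turns it into a pytree
# that Flax/Orbax can serialize; `CustomExtras` is a made-up name used only for illustration.
@struct.dataclass
class CustomExtras:
    epoch: int
    best_loss: float

extras = CustomExtras(epoch=0, best_loss=float('inf'))
print(extras)  # a pytree; it could be added to the `ckpt` bundle above if desired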
# %%
# single host save with orbax
from flax.training import orbax_utils
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save('/tmp/flax_ckpt/orbax/single_save', ckpt, save_args=save_args)
# %%
# multi-process checkpointing
# aka checkpointing for sharding
# The reference doesn't need to be as large as your checkpoint!
# Just make sure it has the `.sharding` you want.
# https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html
# https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html
import orbax.checkpoint as ocp
from etils import epath
path = epath.Path('/tmp/async_checkpoint')
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
ckptr.save(path, args=ocp.args.StandardSave(state))  # save the TrainState created above
### Do some other work...
ckptr.wait_until_finished()
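# (Hedged sketch, not in the original commit) restoring later: StandardRestore accepts an
# abstract reference pytree, so the restored arrays adopt the shapes (and shardings, if the
# reference carries them) that you want; here the reference is built with jax.eval_shape.
abstract_state = jax.eval_shape(lambda: state)
restored_state = ckptr.restore(path, args=ocp.args.StandardRestore(abstract_state))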

View File

@@ -737,7 +737,8 @@ start_time = time.time()
 for _ in range(15):
     state_fsdp, metrics_fsdp = train_step_fsdp_fn(
         state_fsdp,
-        metrics_fsdp, batch)
+        metrics_fsdp,
+        batch)
 duration = time.time() - start_time
 print(duration)
@@ -747,7 +748,8 @@ final_metrics_fsdp = jax.tree.map(
     metric_shapes)
 state_fsdp, final_metrics_fsdp = train_step_fsdp_fn(
     state_fsdp,
-    final_metrics_fsdp, batch)
+    final_metrics_fsdp,
+    batch)
 print_metrics(final_metrics_fsdp, "FSDP - Final metrics")

View File

@@ -8,7 +8,7 @@
 import os
 # Set this to True to run the model on CPU only.
-USE_CPU_ONLY = True
+USE_CPU_ONLY = False
 flags = os.environ.get("XLA_FLAGS", "")
 if USE_CPU_ONLY:
@@ -20,10 +20,10 @@ else:
     # GPU flags
     flags += (
         "--xla_gpu_enable_triton_softmax_fusion=true "
-        "--xla_gpu_triton_gemm_any=false "
-        "--xla_gpu_enable_async_collectives=true "
-        "--xla_gpu_enable_latency_hiding_scheduler=true "
-        "--xla_gpu_enable_highest_priority_async_stream=true "
+        # "--xla_gpu_triton_gemm_any=false "
+        # "--xla_gpu_enable_async_collectives=true "
+        # "--xla_gpu_enable_latency_hiding_scheduler=true "
+        # "--xla_gpu_enable_highest_priority_async_stream=true "
     )
 os.environ["XLA_FLAGS"] = flags
@@ -37,7 +37,8 @@ import jax
 import jax.numpy as jnp
 import numpy as np
 from jax.experimental.shard_map import shard_map
-from jax.sharding import Mesh, NamedSharding
+from jax.sharding import Mesh
+from jax.experimental.pjit import pjit
 from jax.sharding import PartitionSpec as P
 from ml_collections import ConfigDict
 import optax
@@ -148,7 +149,7 @@ dataprep = DataPrepare(split_datasets['train'], data_config)
 # seed = 117
 # rng = jax.random.PRNGKey(seed)
 # train_loader = dataprep.data_loader(rng, batch_size=1)
+# batch = next(iter(train_loader))

 # %%
 # model
@@ -161,8 +162,8 @@ config = T5Config()
 # If you don't want to cast certain parameters (for example layer norm bias and scale)
 # then pass the mask as follows
 from flax import traverse_util
 model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")

 # useful for transformer model
 model.enable_gradient_checkpointing()
@@ -174,6 +175,21 @@ mask = {
 mask = traverse_util.unflatten_dict(mask)
 model.params = model.to_bf16(model.params, mask)

+# %%
+# %%
+from jax.sharding import Mesh, NamedSharding
+from jax.experimental import mesh_utils
+from jax.sharding import PartitionSpec as P
+from pjit_partition import set_partitions
+
+devices = np.asarray(jax.devices())
+mesh_axis_names = ('data')
+mesh = Mesh(devices, 'batch')
+sharding = NamedSharding(mesh, P(mesh_axis_names))
+replicated_sharding = NamedSharding(mesh, P())
+
 # %% [markdown]
 # # Model
@@ -243,22 +259,9 @@ adamw = optax.adamw(
 # %%
-# Training functions
-class TrainState(train_state.TrainState):
-    dropout_rng: jnp.ndarray
-
-    # easy way to achieve data parallelism
-    # also achieves folding of rng keys
-    def replicate(self):
-        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
-
-# set bf16 for model params
-# model.params = model.to_bf16(model.params)
-# Cast parameters to bfloat16 if desired
-# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
-
-# Setup train state
-state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
+# state will serve as our "params"
+state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)

 # label smoothed cross entropy
 def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
@@ -283,9 +286,10 @@ def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
     num_labels = padding_mask.sum()
     return loss, num_labels

+# MARK: train_step
 # Define gradient update step fn
-@jax.jit
-def train_step(state, batch, label_smoothing_factor=0.0):
+def train_step(state, batch):
+    label_smoothing_factor = 0.0
     dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

     def compute_loss(params):
@@ -311,27 +315,22 @@ def train_step(state, batch, label_smoothing_factor=0.0):
     metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
     return new_state, metrics

-# Define generation function
-max_length = (
-    val_max_target_length if val_max_target_length is not None else model.config.max_length
-)
-num_beams = num_beams if num_beams is not None else model.config.num_beams
-gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
-
-# def generate_step(params, batch):
-#     model.params = params
-#     output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
-#     return output_ids.sequences
+# max_length = (
+#     val_max_target_length if val_max_target_length is not None else model.config.max_length
+# )
+# num_beams = num_beams if num_beams is not None else model.config.num_beams
+# gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

 # Create parallel version of the train and eval step
-p_train_step = jax.pmap(
-    partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,)
-)
-# p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch")
-# p_generate_step = jax.pmap(generate_step, "batch")
-
-# Replicate the train state on each device
-state = state.replicate()
+# only state and batch
+p_train_step = jax.jit(
+    train_step,
+    # state for first, batch for second
+    in_shardings=(P("data"), P("data")),
+    out_shardings=(P("data"), P("data")),
+    donate_argnames=("state"),
+)
@@ -349,6 +348,21 @@ print(f"  Total optimization steps = {total_train_steps}")
 # %%
 # jax.profiler.start_trace("./traces")

+# Example batch (sharded across devices)
+sharded_batch = {
+    'input_ids': jax.device_put_sharded(batch['input_ids'], devices),
+    'attention_mask': jax.device_put_sharded(batch['attention_mask'], devices),
+    'labels': jax.device_put_sharded(batch['labels'], devices),
+    'decoder_input_ids': jax.device_put_sharded(batch['decoder_input_ids'], devices),
+    'decoder_attention_mask': jax.device_put_sharded(batch['decoder_attention_mask'], devices),
+}
+
+# Initial TrainState (pjit-ted TrainState)
+sharded_state = jax.device_put_replicated(train_state, devices)
+
+# %%
 rng, input_rng = jax.random.split(rng)
 train_time = 0
 epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
@@ -363,7 +377,7 @@ for epoch in epochs:
     # Generate an epoch by shuffling sampling indices from the train dataset
     for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
         batch = next(train_loader)
-        batch = shard(batch)
+        # batch = shard(batch)
         state, train_metric = p_train_step(state, batch)
         train_metrics.append(train_metric)

View File

@ -0,0 +1,541 @@
# %% [markdown]
# # T5 implementation using jax with pjit
# MARK: START
# %%
# let's make 8-device simulator
import os
# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True
flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
flags += " --xla_force_host_platform_device_count=8" # Simulate 8 devices
# Enforce CPU-only execution
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["JAX_PLATFORMS"] = "cpu"
else:
# GPU flags
flags += (
"--xla_gpu_enable_triton_softmax_fusion=true "
"--xla_gpu_triton_gemm_any=false "
"--xla_gpu_enable_async_collectives=true "
"--xla_gpu_enable_latency_hiding_scheduler=true "
"--xla_gpu_enable_highest_priority_async_stream=true "
)
os.environ["XLA_FLAGS"] = flags
import functools
from functools import partial
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from ml_collections import ConfigDict
import optax
import logging
import time
from datasets import Dataset, load_from_disk
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
import flax.core
from tqdm import tqdm
from dataload import DataPrepare
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
if USE_CPU_ONLY:
jax.config.update('jax_platform_name', 'cpu')
else:
jax.config.update("jax_default_matmul_precision", "bfloat16")
# # %%
# import jax
# import jax.numpy as jnp
# import optax
# import numpy as np
# from functools import partial
# from typing import Callable, Optional
# import math
#
# # jax.config.update("jax_default_matmul_precision", "tensorfloat32")
# jax.config.update("jax_default_matmul_precision", "bfloat16")
# # jax.config.update("jax_enable_x64", False)
# # enable cache
# jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
#
#
# # from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
#
# from flax import jax_utils, traverse_util
# from flax.jax_utils import pad_shard_unpad, unreplicate
# from flax.training import train_state
# from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
# import flax.core
# %%
# get platform type
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
save_path = 't5_80_1_bf16'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])
# Store some constants
seed = 117
num_epochs = 5
batch_size = 384 # 384 is the best
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
warmup_steps = 0
learning_rate = 2e-5
weight_decay = 0.01
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0
num_beams = 1
val_max_target_length = 128
predict_with_generate = True
# %%
# prepare data
# init object
# e.g. Config
data_config = ConfigDict(
dict(
max_length=86,
pad_token_id=0,
decoder_start_token_id=0
)
)
dataprep = DataPrepare(split_datasets['train'], data_config)
# # example usage
# %%
seed = 117
rng = jax.random.PRNGKey(seed)
train_loader = dataprep.data_loader(rng, batch_size=1)
batch = next(iter(train_loader))
# %%
batch
# %%
# model
from transformers import FlaxT5ForConditionalGeneration
from transformers import T5Config
config = T5Config()
# If you don't want to cast certain parameters (for example layer norm bias and scale)
# then pass the mask as follows
from flax import traverse_util
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=False)
# useful for transformer model
model.enable_gradient_checkpointing()
# enable bf16 except for layer_norm
flat_params = traverse_util.flatten_dict(model.params)
mask = {
path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
}
mask = traverse_util.unflatten_dict(mask)
model.params = model.to_bf16(model.params, mask)
# %%
model, params = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=False)
t5_module = model.module
# %%
jax.tree.map(jnp.shape, model.params)
# %%
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec
from pjit_partition import set_partitions
params = model.params
data_partition_specs = PartitionSpec()
extra_param_keys = list(model._missing_keys)
initial_partition_specs = set_partitions(params)
# this is the partition spec we will use
filled_param_partition_specs = set_partitions(params, extra_keys=extra_param_keys)
# %%
# let us see the param_partition_spec
filled_param_partition_specs
# %% let us set up the mesh
from jax.sharding import Mesh
devices = np.asarray(jax.devices())
# %%
# mp: model/tensor parallelism
# dp: data parallelism
# we just use 'data' as a common axis for data and model params
mesh_axis_names = ("data",)  # a 1-tuple of axis names (note the trailing comma)
print("Logical mesh:", devices)
mesh = Mesh(devices, mesh_axis_names)
# it is technically possible to use pjit_partition to set special partition rules
# e.g. by param size
# but for now just move on
# %% [markdown]
# # Model
#
#
#
# %%
# Initialize our training
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)
# %%
# optimization functions
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
return schedule_fn
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
training_size,
train_batch_size,
num_train_epochs,
warmup_steps,
learning_rate,
)
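# (Illustrative) sanity-check the schedule at its endpoints: it warms up to `learning_rate`
# over `warmup_steps` and decays linearly towards zero by `total_train_steps`.
print(linear_decay_lr_schedule_fn(0), linear_decay_lr_schedule_fn(total_train_steps))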
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
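# (Illustrative, assumes the `params` loaded above) peek at a few mask entries to confirm
# that biases and LayerNorm scales come out False, i.e. excluded from weight decay.
flat_decay_mask = traverse_util.flatten_dict(decay_mask_fn(params))
print(list(flat_decay_mask.items())[:5])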
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=adam_beta1,
b2=adam_beta2,
eps=adam_epsilon,
weight_decay=weight_decay,
mask=decay_mask_fn,
)
# %%
# Training functions
# class TrainState(train_state.TrainState):
# dropout_rng: jnp.ndarray
#
# # easy way to achieve data parallelism
# # also achieves folding of rng keys
# def replicate(self):
# return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
# set bf16 for model params
# model.params = model.to_bf16(model.params)
# Cast parameters to bfloat16 if desired
# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
# state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
# %%
# see if we can init the model
# %%
from transformers import FlaxT5Model, T5Config
config = T5Config.from_pretrained('t5-base')
model = FlaxT5Model(config, _do_init=True).module
# %%
# Initialize random key and input for initialization
rng = jax.random.PRNGKey(0)
train_loader = dataprep.data_loader(rng, batch_size=1)
batch = next(iter(train_loader))
# %%
# Initialize model parameters
# init of FlaxT5Module.__call__
variables = model.init(rng,
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
decoder_input_ids=batch['decoder_input_ids'],  # use the decoder inputs, not the attention mask
decoder_attention_mask=batch['decoder_attention_mask']
)
params = variables['params']
# %%
# create an init_fn
def init_fn(rng: jax.random.PRNGKey, batch, model) -> train_state.TrainState:
init_rng, rng = jax.random.split(rng)
variables = model.init(
init_rng,
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
decoder_input_ids=batch['decoder_input_ids'],  # use the decoder inputs, not the attention mask
decoder_attention_mask=batch['decoder_attention_mask']
)
params = variables.pop("params")
state = train_state.TrainState.create(
apply_fn=model.__call__,
params=params,
tx=adamw,
)
return state
# %%
# we do not know the output PartitionSpec ahead of time,
# so we trace the initialization (without running it) just to discover the out spec
init_fn_try = shard_map(
functools.partial(init_fn, model=model),
mesh,
# in_specs: first entry is for the rng, second is for the batch (the model is bound via functools.partial)
in_specs=(P(), P("data")),
out_specs=P(),
check_rep=False
)
# %%
rng, model_init_rng = jax.random.split(rng)
train_loader = dataprep.data_loader(model_init_rng, batch_size=batch_size)
batch = next(iter(train_loader))
state_fsdp_shapes = jax.eval_shape(init_fn_try, model_init_rng, batch)
state_fsdp_specs = nn.get_partition_spec(state_fsdp_shapes)
# print("RNG", state_fsdp_specs.rng)
print("\nParameters")
pprint(state_fsdp_specs.params)
print("\nOptimizer state")
pprint(state_fsdp_specs.opt_state[0])
# note: state_fsdp_specs is now ready to be used as pjit outspec
# %%
# Setup train state
# state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
jit_init_fn = jax.jit(
    functools.partial(init_fn, model=model),  # model is closed over, not traced
    in_shardings=(P(), P("data")),  # PRNG key replicated, batch sharded along 'data'
    out_shardings=state_fsdp_specs,
)
with mesh:
    state = jit_init_fn(model_init_rng, batch)
# %%
# label smoothed cross entropy
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
"""
The label smoothing implementation is adapted from Flax's official example:
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
"""
vocab_size = logits.shape[-1]
confidence = 1.0 - label_smoothing_factor
low_confidence = (1.0 - confidence) / (vocab_size - 1)
normalizing_constant = -(
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
)
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
loss = optax.softmax_cross_entropy(logits, soft_labels)
loss = loss - normalizing_constant
# ignore padded tokens from loss
loss = loss * padding_mask
loss = loss.sum()
num_labels = padding_mask.sum()
return loss, num_labels
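# (Illustrative) tiny sanity check: with no smoothing and uniform logits over a vocab of 4,
# each non-padded position contributes log(4) ~= 1.386 to the summed loss.
_l, _n = loss_fn(jnp.zeros((1, 2, 4)), jnp.array([[1, 2]]), jnp.array([[1.0, 1.0]]))
print(_l, _n)  # ~2.77 and 2.0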
# MARK: train_step
# Define gradient update step fn
def train_step(state, batch):
label_smoothing_factor=0.0
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
def compute_loss(params):
labels = batch.pop("labels")
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
return loss, num_labels
# compute gradients through computational graph
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
(loss, num_labels), grad = grad_fn(state.params)
num_labels = jax.lax.psum(num_labels, "batch")
# true loss = total loss / total samples
# loss = jax.lax.psum(loss, "batch")
# loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
# true grad = total grad / total samples
grad = jax.lax.psum(grad, "batch")
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
return new_state, metrics
# max_length = (
# val_max_target_length if val_max_target_length is not None else model.config.max_length
# )
# num_beams = num_beams if num_beams is not None else model.config.num_beams
# gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
# Create parallel version of the train and eval step
# only state and batch
p_train_step = jax.jit(
train_step,
# state for first, batch for second
in_shardings=(P("data"), P("data")),
out_shardings=(P("data"), P("data")),
donate_argnames=("state"),
)
# %%
print("***** Running training *****")
print(f" Num examples = {training_size}")
print(f" Num Epochs = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
print(f" Total optimization steps = {total_train_steps}")
# %%
# jax.profiler.start_trace("./traces")
# Example batch (sharded across devices)
sharded_batch = {
'input_ids': jax.device_put_sharded(batch['input_ids'], devices),
'attention_mask': jax.device_put_sharded(batch['attention_mask'], devices),
'labels': jax.device_put_sharded(batch['labels'], devices),
'decoder_input_ids': jax.device_put_sharded(batch['decoder_input_ids'], devices),
'decoder_attention_mask': jax.device_put_sharded(batch['decoder_attention_mask'], devices),
}
# Initial TrainState (pjit-ted TrainState)
sharded_state = jax.device_put_replicated(train_state, devices)
# %%
rng, input_rng = jax.random.split(rng)
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
train_start = time.time()
# Create sampling rng
train_metrics = []
rng, data_rng = jax.random.split(rng)
train_loader = dataprep.data_loader(data_rng, batch_size=batch_size)
steps_per_epoch = training_size // train_batch_size
# Generate an epoch by shuffling sampling indices from the train dataset
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
batch = next(train_loader)
# batch = shard(batch)
state, train_metric = p_train_step(state, batch)
train_metrics.append(train_metric)
train_time = time.time() - train_start
train_metric = unreplicate(train_metric)
train_metric['loss'].block_until_ready()
epochs.write(
# f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, "
f"Epoch... ({epoch + 1}/{num_epochs} | "
# f"Learning Rate:{train_metric['learning_rate']}, "
f"Last train time: {train_time})"
)
# jax.profiler.stop_trace()
# %%
# output_dir = save_path
# # save checkpoint after each epoch and push checkpoint to the hub
# if jax.process_index() == 0:
# params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
# model.save_pretrained(output_dir, params=params)
# tokenizer.save_pretrained(output_dir)