Feat: fsdp demo

Refactor: pulling dataloader code into dataload.py
This commit is contained in:
Richard Wong 2024-09-15 22:41:00 +09:00
parent 005a1a5735
commit ad5cf7735f
7 changed files with 2136 additions and 1 deletions

1
parallel/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
__pycache__

172
parallel/dataload.py Normal file
View File

@@ -0,0 +1,172 @@
# %%
# Prepare dataloader for jax training
from datasets import Dataset, DatasetDict, Value, Sequence, load_from_disk
from transformers import FlaxT5ForConditionalGeneration
from ml_collections import ConfigDict
import numpy as np
import jax.numpy as jnp
import jax
import math
from typing import Optional, List, Tuple, Callable, cast
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
# file_path = 'combined_data'
# split_datasets = load_from_disk(file_path)
# training_size = len(split_datasets['train'])
from transformers import T5TokenizerFast
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
# Define additional special tokens
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
# Add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right") # noqa: B009
# class takes in a dataset
class DataPrepare():
def __init__(self, raw_dataset, config):
self.raw_dataset: Dataset = raw_dataset
self.train_dataset: Optional[Dataset] = None
self.size: int = len(raw_dataset)
self.config: ConfigDict = config
self.make_dataset()
# In Flax, seq2seq models do not accept `labels` directly, so we have to build
# `decoder_input_ids` ourselves. We do that by dynamically importing the model's
# `shift_tokens_right` function (a commented sketch after the class shows its effect).
# given a dataset entry, run it through the tokenizer
# Setting padding="max_length" as we need fixed length inputs for jitted functions
def preprocess_function(self, example: Dataset):
inputs = example['input']
targets = example['output']
# text_target sets the corresponding label to inputs
# there is no need to create a separate 'labels'
# produce input_ids and decoder_input_ids
model_inputs = tokenizer(
inputs,
max_length=self.config.max_length,
padding="max_length",
truncation=True,
return_tensors="np"
)
labels = tokenizer(
text_target=targets,
max_length=self.config.max_length,
padding="max_length",
truncation=True,
return_tensors="np"
)
# for loss computation
model_inputs["labels"] = labels["input_ids"]
# make decoder input ids
decoder_input_ids = shift_tokens_right_fn(
labels["input_ids"], self.config.pad_token_id, self.config.decoder_start_token_id
)
# required by the model
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
# We need decoder_attention_mask so we can ignore pad tokens from loss
model_inputs["decoder_attention_mask"] = labels["attention_mask"]
return model_inputs
def make_dataset(self):
train_dataset = self.raw_dataset.map(
self.preprocess_function,
batched=True,
num_proc=1,
# if we do not remove, we keep the original data
remove_columns=self.raw_dataset.column_names,)
# set to numpy
train_dataset.set_format(
type='numpy',
columns=[
'input_ids', 'attention_mask',
'labels', 'decoder_input_ids',
'decoder_attention_mask']
)
# check that values fit into uint16 before casting to compact dtypes
for name in ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'decoder_attention_mask']:
int_array: np.ndarray = train_dataset[name]
if np.all((int_array >= 0) & (int_array <= 65535)):
continue
else:
raise ValueError(f"Values in '{name}' are out of range for uint16")
# change to compact datatypes
features = train_dataset.features.copy()
features['input_ids'] = Sequence(Value('uint16'))
features['attention_mask'] = Sequence(Value('bool'))
features['labels'] = Sequence(Value('uint16'))
features['decoder_input_ids'] = Sequence(Value('uint16'))
features['decoder_attention_mask'] = Sequence(Value('bool'))
train_dataset = train_dataset.cast(features)
# assign the dataset to train_dataset
self.train_dataset = train_dataset
def data_loader(self, rng: jax.random.PRNGKey, batch_size: int, shuffle: bool = False, drop_last=True):
"""
Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
"""
assert self.train_dataset is not None
dataset: Dataset = cast(Dataset, self.train_dataset)
if shuffle:
batch_idx = jax.random.permutation(rng, len(dataset))
batch_idx = np.asarray(batch_idx)
else:
batch_idx = np.arange(len(dataset))
if drop_last:
steps_per_epoch = len(dataset) // batch_size
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
else:
steps_per_epoch = math.ceil(len(dataset) / batch_size)
batch_idx = np.array_split(batch_idx, steps_per_epoch)
for idx in batch_idx:
batch = dataset[idx]
batch = {k: jnp.array(v) for k, v in batch.items()}
yield batch
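# %%
# Illustrative sketch (commented out, not part of the pipeline): what the
# dynamically imported `shift_tokens_right` does to a toy label row, assuming
# pad_token_id=0 and decoder_start_token_id=0 as in the config below.
# toy_labels = np.array([[42, 7, 1, 0]])
# shift_tokens_right_fn(toy_labels, 0, 0)
# # -> [[0, 42, 7, 1]]  (start token prepended, last position dropped)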
# testing out the class
# # %%
# # init object
# # e.g. Config
# data_config = ConfigDict(
# dict(
# max_length=86,
# pad_token_id=0,
# decoder_start_token_id=0
# )
# )
#
# dataprep = DataPrepare(split_datasets, data_config)
#
# # %%
# seed = 117
# rng = jax.random.PRNGKey(seed)
# train_loader = dataprep.data_loader(rng, batch_size=32)
#
#
#
# # %%
# batch = next(iter(train_loader))
# batch['input_ids'].shape
# # %%

View File

@@ -0,0 +1,754 @@
# %% [markdown]
# # Fully-Sharded Data Parallelism
# MARK: START
# %%
# let's make 8-device simulator
import os
# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True
flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
flags += " --xla_force_host_platform_device_count=8" # Simulate 8 devices
# Enforce CPU-only execution
os.environ["CUDA_VISIBLE_DEVICES"] = ""
else:
# GPU flags
flags += (
"--xla_gpu_enable_triton_softmax_fusion=true "
"--xla_gpu_triton_gemm_any=false "
"--xla_gpu_enable_async_collectives=true "
"--xla_gpu_enable_latency_hiding_scheduler=true "
"--xla_gpu_enable_highest_priority_async_stream=true "
)
os.environ["XLA_FLAGS"] = flags
import functools
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from ml_collections import ConfigDict
import optax
import logging
import time
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
jax.config.update('jax_platform_name', 'cpu')
# %%
# required functions:
# Batch
# TrainState
# accumulate_gradients
# print_metrics
from single_gpu_optimizations import Batch, TrainState, accumulate_gradients, print_metrics
# %%
# import the fold_rng_over_axis
def fold_rng_over_axis(rng: jax.random.PRNGKey, axis_name: str) -> jax.random.PRNGKey:
"""Folds the random number generator over the given axis.
This is useful for generating a different random number for each device
across a certain axis (e.g. the model axis).
Args:
rng: The random number generator.
axis_name: The axis name to fold the random number generator over.
Returns:
A new random number generator, different for each device index along the axis.
"""
axis_index = jax.lax.axis_index(axis_name)
return jax.random.fold_in(rng, axis_index)
# MARK: DATA PARALLELISM
# %% [markdown]
# # Data Parallelism
# we start with plain data parallelism
#
# using shard_map, we write single-device code and let shard map handle the rest
# %%
# plain data parallel - sharding only data inputs and outputs
class DPClassifier(nn.Module):
# contains the attributes listed in config
# hidden_size
# dropout_rate
# dtype - for computation
# num_classes
# data_axis_name
config: ConfigDict
# note how there is no data_axis_name within the actual __call__
@nn.compact
def __call__(self, x: jax.Array, train: bool) -> jax.Array:
x = nn.Dense(
features=self.config.hidden_size,
dtype=self.config.dtype,
name="input_dense",
)(x)
x = nn.silu(x)
x = nn.Dropout(rate=self.config.dropout_rate, deterministic=not train)(x)
x = nn.Dense(
features=self.config.num_classes,
dtype=self.config.dtype,
name="output_dense",
)(x)
x = x.astype(jnp.float32)
return x
# config
data_config = ConfigDict(
dict(
batch_size=128,
num_classes=10,
input_size=784,
)
)
model_config = ConfigDict(
dict(
hidden_size=512,
dropout_rate=0.1,
dtype=jnp.bfloat16,
num_classes=data_config.num_classes,
data_axis_name="data",
)
)
optimizer_config = ConfigDict(
dict(
learning_rate=1e-3,
num_minibatches=4,
)
)
config = ConfigDict(
dict(
model=model_config,
optimizer=optimizer_config,
data=data_config,
data_axis_name=model_config.data_axis_name,
seed=42,
)
)
# %%
# initialize
model_dp = DPClassifier(config=config.model)
optimizer = optax.adamw(
learning_rate=config.optimizer.learning_rate,
)
# init rng
rng = jax.random.PRNGKey(config.seed)
# init model rng
model_init_rng, data_inputs_rng, data_labels_rng = jax.random.split(rng, 3)
# create synthetic data
batch = Batch(
inputs=jax.random.normal(data_inputs_rng, (config.data.batch_size, config.data.input_size)),
labels=jax.random.randint(
data_labels_rng, (config.data.batch_size,), 0, config.data.num_classes
),
)
# init data_parallel TrainState state
def init_dp(rng: jax.random.PRNGKey, x: jax.Array, model: nn.Module) -> TrainState:
init_rng, rng = jax.random.split(rng)
variables = model.init({"params": init_rng}, x, train=False)
params = variables.pop("params")
state = TrainState.create(
apply_fn=model.apply,
params=params,
tx=optimizer,
rng=rng,
)
return state
# create mesh
device_array = np.array(jax.devices())
mesh = Mesh(device_array, (config.data_axis_name,))
# here the full model is replicated on every device
# no different from flax.jax_utils.replicate
init_dp_fn = jax.jit(
shard_map(
functools.partial(init_dp, model=model_dp),
mesh,
in_specs=(P(), P(config.data_axis_name)),
out_specs=P(),
check_rep=False,
),
)
state_dp = init_dp_fn(model_init_rng, batch.inputs)
print("DP Parameters")
pprint(jax.tree.map(lambda x: (x.shape, x.sharding), state_dp.params))
# MARK: TRAIN STEP
# %%
# train step
def loss_fn(
params: PyTree, apply_fn: Any, batch: Batch, rng: jax.Array
) -> Tuple[jax.Array, Dict[str, Any]]:
# derive a different dropout rng for each device along the data axis
dropout_rng = fold_rng_over_axis(rng, config.data_axis_name)
# Remaining computation is the same as before for single device.
logits = apply_fn(
{"params": params},
batch.inputs,
train=True,
rngs={"dropout": dropout_rng})
loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch.labels)
correct_pred = jnp.equal(jnp.argmax(logits, axis=-1), batch.labels)
batch_size = batch.inputs.shape[0]
step_metrics = {"loss": (loss.sum(), batch_size), "accuracy": (correct_pred.sum(), batch_size)}
loss = loss.mean()
return loss, step_metrics
# train step dp
# simple data parallel has the model on every device
# but each device has different data
def train_step_dp(
state: TrainState,
metrics: Metrics | None,
batch: Batch,
) -> Tuple[TrainState, Metrics]:
rng, step_rng = jax.random.split(state.rng)
# accumulate gradients like before
grads, step_metrics = accumulate_gradients(
state,
batch,
step_rng,
config.optimizer.num_minibatches,
loss_fn=loss_fn,
)
# Update parameters. We need to sync the gradients across devices before updating.
with jax.named_scope("sync_gradients"):
grads = jax.tree.map(
lambda g: jax.lax.pmean(
g, axis_name=config.data_axis_name),
grads)
new_state = state.apply_gradients(grads=grads, rng=rng)
# Sum metrics across replicas. Alternatively, we could keep the metrics separate
# and only synchronize them before logging. For simplicity, we sum them here.
with jax.named_scope("sync_metrics"):
step_metrics = jax.tree.map(
lambda x: jax.lax.psum(x, axis_name=config.data_axis_name), step_metrics
)
if metrics is None:
metrics = step_metrics
else:
# combine all the synced metrics
metrics = jax.tree.map(jnp.add, metrics, step_metrics)
return new_state, metrics
# %%
# we will now wrap the train step with shard_map and jit it
# here we will be sharding input and output data
train_step_dp_fn = jax.jit(
shard_map(
train_step_dp,
mesh,
in_specs=(P(), P(), P(config.data_axis_name)),
out_specs=(P(), P()),
check_rep=False,
),
# state and metrics are consumed here and never reused afterwards,
# so we donate their buffers to XLA instead of copying them
donate_argnames=("state", "metrics"),
)
# %%
# get the metric_shapes so that we can init arrays for accumulation
_, metric_shapes = jax.eval_shape(
train_step_dp_fn,
state_dp,
None,
batch,
)
# init arrays with shape
metrics_dp = jax.tree.map(
lambda x: jnp.zeros(x.shape, dtype=x.dtype),
metric_shapes)
# %%
start_time = time.time()
for _ in range(15):
state_dp, metrics_dp = train_step_dp_fn(state_dp, metrics_dp, batch)
duration = time.time() - start_time
print(duration)
final_metrics_dp = jax.tree.map(
lambda x: jnp.zeros(x.shape, dtype=x.dtype),
metric_shapes)
state_dp, final_metrics_dp = train_step_dp_fn(
state_dp,
final_metrics_dp,
batch)
print_metrics(final_metrics_dp)
# %%
print("DP Parameters")
pprint(jax.tree.map(lambda x: (x.shape, x.sharding), state_dp.params))
print("Metrics")
pprint(jax.tree.map(lambda x: (x.shape, x.sharding), final_metrics_dp))
####################################################################
# everything up to this point works, and is still equivalent to the
# flax replicate-style data parallelism used in the huggingface examples
# MARK: PARAMETER SHARDING
# %% [markdown]
# # parameter sharding
# Basic strategy: init full parameters on each device, then use
# jax.lax.axis_index to split parameters across devices, and keep a shard on
# each device
#
# use nn.Partitioned to annotate sharding spec on parameters
# quite similar to PartitionSpec
#
# parameters are either jax.Array or a flax.linen.Partitioned
# %%
# type annotation
Parameter = jax.Array | nn.Partitioned
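# %%
# Minimal sketch (illustrative only, not used below): an nn.Partitioned leaf simply
# pairs an array with per-axis sharding names, where None means "replicated".
# nn.get_partition_spec turns such leaves into PartitionSpecs, as we do later on.
# p = nn.Partitioned(value=jnp.zeros((8, 4)), names=("data", None))
# nn.get_partition_spec({"kernel": p})   # -> {'kernel': PartitionSpec('data', None)}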
# %%
# function to shard parameters across devices
# it looks for an axis that splits evenly across the number of devices
# since parameters vary in size, we also set a minimum size below which a
# parameter is not worth sharding (a worked example follows the function)
@jax.named_scope("shard_params")
def shard_params(params: PyTree, axis_name: str, min_weight_size: int = 2**18) -> PyTree:
"""Shard parameters across the given mesh axis.
Args:
params: The parameters to shard.
axis_name: The axis to shard parameters across.
min_weight_size: The minimum size of a parameter to shard. Parameters with fewer values will not be sharded.
Returns:
PyTree of same structure as params, but with leaves sharded over new axis if possible.
"""
# axis_index
axis_idx = jax.lax.axis_index(axis_name)
# number of devices along the axis
axis_size = jax.lax.psum(1, axis_name)
# split function: checks whether each parameter is already sharded and,
# if not, slices it along a suitable axis
def _split(x: Parameter) -> Parameter:
# already sharded
if isinstance(x, nn.Partitioned):
value, names = x.value, x.names
# not sharded
else:
value = x
names = (None,) * value.ndim
# logging only runs on the first trace
# the branches below cover the cases where a parameter is left as-is:
# it is already sharded on this mesh axis, or it is too small to shard
if axis_name in names:
logging.warning(
f"Parameter {value.shape} with names {names} already sharded on axis {axis_name}."
)
return x
# check if the parameter is too small to shard
elif value.size <= min_weight_size:
logging.info(
f"Parameter {value.shape} with names {names} too small to shard, size {value.size} < {min_weight_size}."
)
return x
# let's start sharding!
else:
shape = value.shape
idx = np.argsort(shape)[::-1] # Shard along largest possible axis.
for i in idx:
# this technically runs once because of return
# we only shard if we can split evenly across devices
# and if it ain't alreayd sharded
if shape[i] % axis_size == 0 and names[i] is None:
split_size = shape[i] // axis_size
p_sharded = nn.Partitioned(
value=jax.lax.dynamic_slice_in_dim( # Shard to keep on present device.
value,
axis_idx * split_size,
split_size,
axis=i
),
names=names[:i] + (axis_name,) + names[i + 1 :],
)
return p_sharded
logging.warning(
f"Could not shard {value.shape} with names {names} on axis {axis_name}, no suitable axis found."
)
return x
# we apply the _split function across the parameter pytree
return jax.tree.map(
_split,
params,
is_leaf=lambda x: isinstance(
x, nn.Partitioned
), # Consider a nn.Partitioned object as a leaf.
)
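# %%
# Worked example (hedged, matches the spec printed further below): on 8 devices the
# (784, 512) "input_dense" kernel is sharded along its largest evenly divisible axis,
# axis 0, so each device keeps a (98, 512) slice and the leaf becomes
# nn.Partitioned(value=<(98, 512) array>, names=("data", None)),
# i.e. PartitionSpec("data", None).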
# %%
# function to gather parameters back into full arrays on every device
# but first we need to create a custom gradient so the backward pass takes a mean
# jax.lax.all_gather -> retrieve shards and assemble full array in each device
# jax.lax.psum_scatter -> scatter gradients back to respective devices
def gather_array_with_mean_grads(x: jax.Array, axis: int, axis_name: str):
"""Gathering with averaging gradients across replicas."""
axis_size = jax.lax.psum(1, axis_name)
# Define a custom gradient for the gather operation.
@jax.custom_gradient
def f(x):
# adjust the backward pass so the gathered sum behaves like a mean over the axis
def grad_fn(g):
# scatter the full-array gradient back into shards (psum_scatter) and divide
# by the axis size; each device keeps only the gradient slice that matches
# its own parameter shard
return (
jax.lax.psum_scatter(g, axis_name, scatter_dimension=axis, tiled=True) / axis_size
)
# forward pass: assemble the shards into the full array on every device
return jax.lax.all_gather(x, axis_name, axis=axis, tiled=True), grad_fn
return f(x)
# gather params back - e.g. when computing a module forward call
# reverse operation of "shard_params"
# depends on: gather_array_with_mean_grads
@jax.named_scope("gather_params")
def gather_params(params: PyTree, axis_name: str) -> PyTree:
"""Gather parameters from all replicas across the given axis.
Args:
params: The parameters to gather.
axis_name: The axis to gather parameters across.
Returns:
PyTree of same structure as params, but with leaves gathered if they were a nn.Partitioned object.
"""
def _gather(p: Parameter) -> Parameter:
if isinstance(p, nn.Partitioned) and axis_name in p.names:
param_shard = p.names
shard_axis = param_shard.index(axis_name)
value = gather_array_with_mean_grads(p.value, axis=shard_axis, axis_name=axis_name)
# If there are any other axes that are sharded, we need to keep the partitioned structure.
# Otherwise, we can return the value directly.
param_shard = param_shard[:shard_axis] + (None,) + param_shard[shard_axis + 1 :]
if any([name is not None for name in param_shard]):
# keep the Partitioned wrapper for the axes that remain sharded
return nn.Partitioned(value, param_shard)
else:
return value
else:
return p
# we find all the sharded params and gather them, returning a complete parameter
return jax.tree.map(
_gather,
params,
is_leaf=lambda x: isinstance(x, nn.Partitioned))
# %%
# when we call a module, its parameters are gathered back into full arrays on each device
# we wrap the module with the nn.map_variables transform, which lets us
# transform the parameters before and after every module call
# depends on: gather_params, shard_params
def shard_module_params(
target: nn.Module | Callable,
axis_name: str,
min_weight_size: int = 2**18 # 262,144
) -> nn.Module | Callable:
"""Shard parameters of a module across replicas.
Args:
target: The module to shard.
axis_name: The axis name to shard parameters across.
min_weight_size: The minimum size of a parameter to shard. Parameters with fewer values will not be sharded.
Returns:
The module with sharded parameters.
"""
return nn.map_variables(
target,
trans_in_fn=functools.partial(
gather_params, axis_name=axis_name),
trans_out_fn=functools.partial(
shard_params, axis_name=axis_name, min_weight_size=min_weight_size
),
mapped_collections="params",
mutable=True,
)
# %%
# define a new module that applies the parameter sharding
# this forms the template for sharding future modules:
# flax models are composed from elementary modules (e.g. nn.Dense), so wrapping those is enough
class FSDPClassifier(nn.Module):
config: ConfigDict
@nn.compact
def __call__(self, x: jax.Array, train: bool) -> jax.Array:
# create a sharded module
sharded_dense = shard_module_params(
nn.Dense,
axis_name=self.config.data_axis_name, # axes
min_weight_size=self.config.min_weight_size, # min_weight
)
x = sharded_dense(
features=self.config.hidden_size,
dtype=self.config.dtype,
name="input_dense",
)(x)
x = nn.silu(x)
x = nn.Dropout(rate=self.config.dropout_rate, deterministic=not train)(x)
x = sharded_dense(
features=self.config.num_classes,
dtype=self.config.dtype,
name="output_dense",
)(x)
x = x.astype(jnp.float32)
return x
# %%
# initialization
config.model.min_weight_size = 2**4
model_fsdp = FSDPClassifier(config=config.model)
# the earlier init function
def init_dp(rng: jax.random.PRNGKey, x: jax.Array, model: nn.Module) -> TrainState:
init_rng, rng = jax.random.split(rng)
variables = model.init({"params": init_rng}, x, train=False)
params = variables.pop("params")
state = TrainState.create(
apply_fn=model.apply,
params=params,
tx=optimizer,
rng=rng,
)
return state
# initialize our sharded model with the mesh
# the partitioning is only decided inside model init, so we cannot write the
# output specs by hand ahead of time
#
# the workaround: evaluate only the shapes with a placeholder, fully replicated
# output spec, then read the actual PartitionSpecs off the resulting
# parameter shapes
init_fsdp_fn = shard_map(
functools.partial(init_dp, model=model_fsdp),
mesh,
# first P() is for model_init_rng
# second P(config.data_axis_name) is for batch.inputs
in_specs=(P(), P(config.data_axis_name)),
# not partitioned, fully replicated
out_specs=P(),
check_rep=False, # disable checks for replication errors in out_specs
)
state_fsdp_shapes = jax.eval_shape(init_fsdp_fn, model_init_rng, batch.inputs)
state_fsdp_specs = nn.get_partition_spec(state_fsdp_shapes)
# %% [raw]
# TrainState(step=PartitionSpec(), apply_fn=<bound method Module.apply of FSDPClassifier(
# # attributes
# config = data_axis_name: data
# dropout_rate: 0.1
# dtype: !!python/name:jax.numpy.bfloat16 ''
# hidden_size: 512
# min_weight_size: 16
# num_classes: 10
#
# )>,
# params={
# 'input_dense': {'bias': PartitionSpec('data',), 'kernel': PartitionSpec('data', None)},
# 'output_dense': {'bias': PartitionSpec(), 'kernel': PartitionSpec('data', None)}},
# tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x761e8ef00400>,
# update=<function chain.<locals>.update_fn at 0x761e8ef01080>),
# opt_state=(ScaleByAdamState(count=PartitionSpec(),
# mu={'input_dense': {'bias': PartitionSpec('data',), 'kernel': PartitionSpec('data', None)},
# 'output_dense': {'bias': PartitionSpec(), 'kernel': PartitionSpec('data', None)}},
# nu={'input_dense': {'bias': PartitionSpec('data',), 'kernel': PartitionSpec('data', None)},
# 'output_dense': {'bias': PartitionSpec(), 'kernel': PartitionSpec('data', None)}}),
# EmptyState(), EmptyState()), rng=PartitionSpec())
# %%
# from state_fsdp_specs we can now read off the real sharding layout
# this print makes clear why the specs could not be written down earlier:
# which parameters end up sharded is only known after model init
print("RNG", state_fsdp_specs.rng)
print("\nParameters")
pprint(state_fsdp_specs.params)
print("\nOptimizer state")
pprint(state_fsdp_specs.opt_state[0])
# %%
# init again, this time with the inferred specs, so every parameter is created
# directly with its intended sharding
init_fsdp_fn = jax.jit(
shard_map(
functools.partial(init_dp, model=model_fsdp),
mesh,
in_specs=(P(), P(config.data_axis_name)),
out_specs=state_fsdp_specs,
check_rep=False,
)
)
state_fsdp = init_fsdp_fn(model_init_rng, batch.inputs)
# %%
print("FSDP Parameters")
pprint(jax.tree.map(lambda x: x.shape, jax.device_get(state_fsdp.params)))
# %%
# train step
# we need to handle the sync of gradients
# some parameters are sharded, some are not
def sync_gradients(
grads: PyTree,
axis_names: Sequence[str],
) -> PyTree:
"""Synchronize gradients across devices.
Gradients for parameters that are replicated over a given axis are averaged across devices.
Parameters that are partitioned over a given axis are considered to already have a mean of
the gradients on each device, and hence do not need to be altered.
Args:
grads: The gradients to synchronize.
axis_names: The axis names to synchronize gradients across.
Returns:
The gradients averaged over the specified axes if they are replicated.
"""
def sync_grad(g: Parameter) -> Parameter:
if isinstance(g, nn.Partitioned):
# Tree leaves for flattening potentially nested axis (multiple names
# can exist for single array axis).
replication_axis_names = [
name for name in axis_names if name not in jax.tree_util.tree_leaves(g.names)
]
if len(replication_axis_names) == 0:
# Parameters partitioned over all axes.
return g
else:
# Average over remaining replicated axes.
return g.replace(value=jax.lax.pmean(g.value, axis_name=replication_axis_names))
else:
# Parameters are replicated over all axes.
return jax.lax.pmean(g, axis_name=axis_names)
return jax.tree.map(
sync_grad,
grads,
is_leaf=lambda x: isinstance(x, nn.Partitioned))
# %%
def train_step_fsdp(
state: TrainState,
metrics: Metrics,
batch: Batch,
) -> Tuple[TrainState, Metrics]:
rng, step_rng = jax.random.split(state.rng)
# accumulate gradients over the minibatches (forward and backward passes)
grads, step_metrics = accumulate_gradients(
state,
batch,
step_rng,
config.optimizer.num_minibatches,
loss_fn=loss_fn,
)
# Update parameters. We need to sync the gradients across devices before updating.
with jax.named_scope("sync_gradients"):
grads = sync_gradients(grads, (config.data_axis_name,))
# then update model
new_state = state.apply_gradients(grads=grads, rng=rng)
# Sum metrics across replicas. Alternatively, we could keep the metrics separate
# and only synchronize them before logging. For simplicity, we sum them here.
with jax.named_scope("sync_metrics"):
step_metrics = jax.tree.map(
lambda x: jax.lax.psum(x, axis_name=config.data_axis_name), step_metrics
)
if metrics is None:
metrics = step_metrics
else:
metrics = jax.tree.map(jnp.add, metrics, step_metrics)
return new_state, metrics
# %%
# jit the train_step_fsdp
train_step_fsdp_fn = jax.jit(
shard_map(
train_step_fsdp,
mesh,
in_specs=(state_fsdp_specs, P(), P(config.data_axis_name)),
out_specs=(state_fsdp_specs, P()),
check_rep=False,
),
donate_argnames=("state", "metrics"),
)
# get the metric shape to initialize accumulator arrays for metrics
_, metric_shapes = jax.eval_shape(
train_step_fsdp_fn,
state_fsdp,
None,
batch,
)
metrics_fsdp = jax.tree.map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
# %%
# train
start_time = time.time()
for _ in range(15):
state_fsdp, metrics_fsdp = train_step_fsdp_fn(
state_fsdp,
metrics_fsdp, batch)
duration = time.time() - start_time
print(duration)
# get metrics and state
final_metrics_fsdp = jax.tree.map(
lambda x: jnp.zeros(x.shape, dtype=x.dtype),
metric_shapes)
state_fsdp, final_metrics_fsdp = train_step_fsdp_fn(
state_fsdp,
final_metrics_fsdp, batch)
print_metrics(final_metrics_fsdp, "FSDP - Final metrics")
# %%

View File

@@ -0,0 +1,373 @@
# %% [markdown]
# # Distributed computing in JAX
# %%
import os
# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True
flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
flags += " --xla_force_host_platform_device_count=8" # Simulate 8 devices
# Enforce CPU-only execution
os.environ["CUDA_VISIBLE_DEVICES"] = ""
else:
# GPU flags
flags += (
"--xla_gpu_enable_triton_softmax_fusion=true "
"--xla_gpu_triton_gemm_any=false "
"--xla_gpu_enable_async_collectives=true "
"--xla_gpu_enable_latency_hiding_scheduler=true "
"--xla_gpu_enable_highest_priority_async_stream=true "
)
os.environ["XLA_FLAGS"] = flags
# %%
import functools
from typing import Any, Dict, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
jax.config.update('jax_platform_name', 'cpu')
# %%
jax.devices()
# %%
# when we create an array, we can check its location
a = jnp.arange(8)
print("Array", a)
print("Device", a.device)
print("Sharding", a.sharding)
# %% [markdown]
# ## Single-Axis Mesh
# %%
# let's create a Mesh
# multidimensional Numpy array of jax devices
# jax.sharding.Mesh(devices, axis_names)
mesh = Mesh(devices=np.array(jax.devices()), axis_names=("i",))
print(mesh)
# %%
# jax.sharding.NamedSharding(mesh, spec)
# pairs a Mesh of devices with a PartitionSpec
# the PartitionSpec describes how to shard an array across that mesh:
# to shard an array axis over a certain mesh axis, put that mesh axis name
# ("i" here) at the corresponding position in the spec tuple
sharding = NamedSharding(mesh=mesh, spec=PartitionSpec("i",))
# %%
a_sharded = jax.device_put(a, sharding)
print("Sharded array", a_sharded)
print("Device", a_sharded.devices())
print("Sharding", a_sharded.sharding)
# %%
jax.debug.visualize_array_sharding(a_sharded)
# %%
# let's try some computation on the mesh
out = nn.tanh(a_sharded)
print("Output array", out)
jax.debug.visualize_array_sharding(out)
# note how the output array is sharded across the devices
# %% [markdown]
# ## multi-axis mesh
# Why shard across multiple mesh dimensions? Because it lets us combine
# parallelism strategies, e.g. data parallelism along one mesh axis and
# model/tensor parallelism along the other.
# %%
mesh = Mesh(devices=np.array(jax.devices()).reshape(4,2), axis_names=("i", "j"))
# axis i/0 refers to the row-wise axis progressing downwards
# axis j/1 refers to the column-wise axis progressing rightward
mesh # noqa: B018
# %%
# we now illustrate sharded MAC operation
# y = x @ w + b
batch_size = 192
input_dim = 64
output_dim = 128
# input: (batch_size, input_dim)
x = jax.random.normal(jax.random.PRNGKey(0), (batch_size, input_dim))
# w: (input_dim, output_dim)
w = jax.random.normal(jax.random.PRNGKey(1), (input_dim, output_dim))
# b: (output_dim,)
b = jax.random.normal(jax.random.PRNGKey(2), (output_dim,))
# %%
# x is sharded along axis 0 over mesh axis "i", w along axis 1 over "j",
# and b along its only axis over "j"
x_sharded = jax.device_put(x, NamedSharding(mesh, PartitionSpec("i", None)))
w_sharded = jax.device_put(w, NamedSharding(mesh, PartitionSpec(None, "j")))
b_sharded = jax.device_put(b, NamedSharding(mesh, PartitionSpec("j")))
print('x blocks:')
jax.debug.visualize_array_sharding(x_sharded)
print('w blocks:')
jax.debug.visualize_array_sharding(w_sharded)
print('b blocks:')
jax.debug.visualize_array_sharding(b_sharded)
# %%
out = jnp.dot(x_sharded, w_sharded) + b_sharded
print("Output shape", out.shape)
jax.debug.visualize_array_sharding(out)
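# %%
# Sanity check (illustrative): the contracted dimension is not sharded, so each
# device computes its output block locally and the sharded result matches the
# plain computation up to floating-point rounding.
np.testing.assert_allclose(jax.device_get(out), np.asarray(jnp.dot(x, w) + b), rtol=1e-5, atol=1e-5)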
# %% [markdown]
# # Shard Map (shard_map)
#
# previously, we manually placed the shards with device_put so that the block
# matrix computation could run independently and in parallel on each device
#
# This allows us to write code with explicit control over parallelization and
# communication
#
# what is a shard_map?
#
# it is a transformation that takes a function, a mesh, and a sharding
# specification for inputs and outputs
#
# in other words, we write a function that executes on each device only, then
# apply across all the shards
#
# but wait, doesn't pmap do this? The answer is no. pmap doesn't have enough
# information about the shards to efficiently perform sharding for complicated
# meshes.
# %%
def matmul_fn(x: jax.Array, w: jax.Array, b: jax.Array) -> jax.Array:
print("Local x shape", x.shape)
print("Local w shape", w.shape)
print("Local b shape", b.shape)
# so simple!
return jnp.dot(x,w) + b
# %%
matmul_sharded = shard_map(
matmul_fn, # the function for operating on a single device
mesh, # the device topology
# the input mesh partition argument for each input
in_specs=(
PartitionSpec("i", None), # x
PartitionSpec(None, "j"), # w
PartitionSpec("j") # b
),
# the output to read from the mesh
out_specs=PartitionSpec("i", "j")
)
# %%
# y = matmul_sharded(x_sharded, w_sharded, b_sharded)
# there is no need to device_put,
# partitioning is done according to your in_specs
y = matmul_sharded(x, w, b)
print("Output shape", y.shape)
jax.debug.visualize_array_sharding(y)
# %% [markdown]
# # Axis Communication
# %%
# example of computing a mean/sum across devices along one mesh axis
# the following normalizes x using statistics (mean and std) computed across
# the "j" mesh axis, i.e. across the devices holding different column blocks
@functools.partial(
shard_map,
mesh=mesh,
in_specs=PartitionSpec("i", "j"),
out_specs=PartitionSpec("i", "j"))
def parallel_normalize(x: jax.Array) -> jax.Array:
# jax.lax.pmean: compute an all-reduce mean of x over the named mesh axis
# "axis_name"
# get the mean across the "j" axis of the mesh - column wise
mean = jax.lax.pmean(x, axis_name="j")
# get the std across the "j" axis of the mesh - column wise
std = jax.lax.pmean((x - mean) ** 2, axis_name="j") ** 0.5
return (x - mean) / std
# communicated along "j" axis of mesh for row elements
out = parallel_normalize(x)
out = jax.device_get(out)
print(out.shape)
print("Mean", out.mean())
print("Std", out.std())
# %%
# scenario: array is sharded across devices, some values missing per shard
# all-gather: gather values of an array from all devices
@functools.partial(
shard_map,
mesh=mesh,
in_specs=(
PartitionSpec("i", None), # artificially shard across "i"
PartitionSpec("i", None)
),
out_specs=PartitionSpec("i", None))
def matmul_with_weight_gather(x: jax.Array, w: jax.Array) -> jax.Array:
print("Original w shape", w.shape)
# pull the full w matrix values from neighboring devices
w_gathered = jax.lax.all_gather(w, axis_name="i", axis=0, tiled=True)
print("Gathered w shape", w_gathered.shape)
y = jnp.dot(x, w_gathered)
return y
out = matmul_with_weight_gather(x, w)
out = jax.device_get(out)
np.testing.assert_array_equal(out, jnp.dot(x, w))
# %%
# scenario: arrays are sharded across all devices
# psum_scatter: the sum over all devices is computed, but each function instance
# keeps only its own shard of the summed result
# therefore each device ends up with one slice of the full sum
@functools.partial(
shard_map,mesh=mesh,
in_specs=PartitionSpec("i", None),
out_specs=PartitionSpec("i", None))
def scatter_example(x: jax.Array) -> jax.Array:
x_scatter = jax.lax.psum_scatter(x, axis_name="i", scatter_dimension=1)
return x_scatter
x_exmp = np.array(
[
[3, 1, 4, 1],
[5, 9, 2, 6],
[5, 3, 5, 8],
[9, 7, 1, 2],
]
)
out = scatter_example(x_exmp)
print("Output", out)
# %%
# ppermute: communicates an array in a round-robin fashion
#
# this is used to implement pipeline parallelism, where results are passed on
# to the next device, and also shows up in tensor parallelism
#
# notice how the results roll through the devices
#
# in fact, the other lax communication operations can be built from ppermute
# alone (see the sketch after this example)
@functools.partial(
shard_map,
mesh=mesh,
in_specs=PartitionSpec("i"),
out_specs=PartitionSpec("i"))
def ppermute_example(x: jax.Array) -> jax.Array:
axis_size = mesh.shape["i"]
print('BEFORE:\n', x)
x_perm = jax.lax.ppermute(
x,
axis_name="i",
perm=[
# source_index, destination_index pairs
(i, (i + 1) % axis_size) for i in range(axis_size)
]
)
print('AFTER:\n', x_perm)
return x_perm
x_exmp = np.arange(4)
out = ppermute_example(x_exmp)
print("Output", out) # the value is that of each axis 0 device
# %%
# # axis indexing: get the index of a device along an axis
# sometimes our computation needs to change depending on which device it is running on
#
# we use jax.lax.axis_index to return the index of the current device along an axis
#
# this function is jitted and essentially free once compiled
axis_idx_fn = jax.jit(
shard_map(
lambda: jnp.stack(
[
jax.lax.axis_index("i"), # Device index in mesh along the "i" axis
jax.lax.axis_index("j"), # Device index in mesh along the "j" axis
],
axis=-1,
)[None],
mesh,
in_specs=PartitionSpec(),
out_specs=PartitionSpec(
("i", "j"),
),
)
)
out = axis_idx_fn()
out = jax.device_get(out)
for i in range(out.shape[0]):
print(f"Device {i}: i-axis={out[i, 0]}, j-axis={out[i, 1]}")
# %%
# usage 2: fold rng over given axis
# jax.random.fold_in: folds in data to a PRNG key to form a new PRNG key
# from a source RNG key, we generate new RNG keys
def fold_rng_over_axis(rng: jax.random.PRNGKey, axis_name: str) -> jax.random.PRNGKey:
"""Folds the random number generator over the given axis.
This is useful for generating a different random number for each device
across a certain axis (e.g. the model axis).
Args:
rng: The random number generator.
axis_name: The axis name to fold the random number generator over.
Returns:
A new random number generator, different for each device index along the axis.
"""
axis_index = jax.lax.axis_index(axis_name)
return jax.random.fold_in(rng, axis_index)
# %%
# we fold RNG over the i axis only
# same RNG used across j axis
fold_fn = jax.jit(
shard_map(
# fold over for "i" only
functools.partial(fold_rng_over_axis, axis_name="i"),
mesh,
in_specs=PartitionSpec(),
out_specs=PartitionSpec(
("i", "j"),
),
)
)
rng = jax.random.PRNGKey(0)
out = fold_fn(rng)
out = jax.device_get(out)
for i in range(out.shape[0] // 2):
print(f"Device {i}: RNG={out[2*i:2*i+2]}")
# %%

View File

@@ -0,0 +1,441 @@
# MARK: import
# %% [markdown]
# # single gpu optimizations
import os
# os.environ["XLA_FLAGS"] = (
# "--xla_gpu_enable_triton_softmax_fusion=true "
# "--xla_gpu_triton_gemm_any=false "
# "--xla_gpu_enable_async_collectives=true "
# "--xla_gpu_enable_latency_hiding_scheduler=true "
# "--xla_gpu_enable_highest_priority_async_stream=true "
# )
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# %%
import functools
from pprint import pprint
from typing import Any, Callable, Dict, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.struct import dataclass
from flax.training import train_state
# Type aliases
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
# %% [markdown]
# # bfloat16 mixed precision compute
class MLPClassifier(nn.Module):
dtype: Any # we set the dtype here for computation
hidden_size: int = 256
num_classes: int = 100
dropout_rate: float = 0.1
@nn.compact
def __call__(self, x: jax.Array, train: bool) -> jax.Array:
x = nn.Dense(
features=self.hidden_size,
dtype=self.dtype, # Computation in specified dtype, params stay in float32
)(x)
x = nn.LayerNorm(dtype=self.dtype)(x)
x = nn.silu(x)
x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
x = nn.Dense(
features=self.num_classes,
dtype=self.dtype,
)(x)
x = x.astype(jnp.float32)
x = nn.log_softmax(x, axis=-1)
return x
# %%
x = jnp.ones((512, 128), dtype=jnp.float32)
rngs = {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)}
model_float32 = MLPClassifier(dtype=jnp.float32)
model_float32.tabulate(rngs, x, train=True, console_kwargs={"force_jupyter": True})
# %%
# inputs and activations (outputs) in bfloat16
# parameters in float32
model_bfloat16 = MLPClassifier(dtype=jnp.bfloat16)
model_bfloat16.tabulate(rngs, x, train=True, console_kwargs={"force_jupyter": True})
# MARK: GRADIENT CHECKPOINT
# %% [markdown]
# # gradient checkpoint
#
# in jax this is implemented with the remat function
#
# practical notes on remat: https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#practical-notes
def gelu(x: jax.Array) -> jax.Array:
"""GeLU activation function with approximate tanh."""
# This will be printed once every time the function is executed.
jax.debug.print("Executing GeLU")
# See https://arxiv.org/abs/1606.08415 for details.
x3 = jnp.power(x, 3)
tanh_input = np.sqrt(2 / np.pi) * (x + 0.044715 * x3)
return 0.5 * x * (1 + jnp.tanh(tanh_input))
def loss_fn(x: jax.Array, remat: bool) -> jax.Array:
act_fn = gelu
if remat:
act_fn = jax.remat(act_fn)
return jnp.mean(act_fn(x))
x = jax.random.normal(jax.random.PRNGKey(0), (100,))
grad_fn = jax.grad(loss_fn)
# with remat, the activation function is re-executed during the backward pass
_ = grad_fn(x, remat=True)
# without remat, there is no re-execution
_ = loss_fn(x, remat=False)
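# %%
# Sketch (an aside, assuming the module-level variant is wanted): Flax exposes the
# same idea for modules as nn.remat (alias of nn.checkpoint), which recomputes the
# wrapped module's activations in the backward pass instead of storing them.
CheckpointedDense = nn.remat(nn.Dense)
checkpointed_layer = CheckpointedDense(features=256)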
# MARK: GRADIENT ACCUMULATION
# %% [markdown]
# # gradient accumulation
#
# run many mini-batches, and accumulate their gradients to feed into optimizer
# as if there were one large batch
# %%
class TrainState(train_state.TrainState):
rng: jax.Array
@dataclass
class Batch:
inputs: jax.Array
labels: jax.Array
# %%
# nothing special here, just a loss function
def classification_loss_fn(
params: PyTree, apply_fn: Any, batch: Batch, rng: jax.Array
) -> Tuple[PyTree, Metrics]:
"""Classification loss function with cross-entropy."""
logits = apply_fn({"params": params}, batch.inputs, train=True, rngs={"dropout": rng})
loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch.labels)
correct_pred = jnp.equal(jnp.argmax(logits, axis=-1), batch.labels)
batch_size = batch.inputs.shape[0]
# step_metrics contains the loss sum
step_metrics = {"loss": (loss.sum(), batch_size), "accuracy": (correct_pred.sum(), batch_size)}
# loss contains the mean loss
mean_loss = loss.mean() # the mathematical output of function
return mean_loss, step_metrics
# %%
# gradient accumulation training loop
def accumulate_gradients_loop(
state: TrainState,
batch: Batch,
rng: jax.random.PRNGKey,
num_minibatches: int,
loss_fn: Callable,
) -> Tuple[PyTree, Metrics]:
"""Calculate gradients and metrics for a batch using gradient accumulation.
Args:
state: Current training state.
batch: Full training batch.
rng: Random number generator to use.
num_minibatches: Number of minibatches to split the batch into. Equal to the number of gradient accumulation steps.
loss_fn: Loss function to calculate gradients and metrics.
Returns:
Tuple with accumulated gradients and metrics over the minibatches.
"""
batch_size = batch.inputs.shape[0]
minibatch_size = batch_size // num_minibatches
rngs = jax.random.split(rng, num_minibatches)
# Define gradient function for single minibatch.
# If has_aux is True then a tuple of ((value, auxiliary_data), gradient) is returned.
# Otherwise it returns (value, gradient), where value is the actual output
# of the function (hence the name value_and_grad).
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
# Prepare loop variables.
grads = None
metrics = None
for minibatch_idx in range(num_minibatches):
with jax.named_scope(f"minibatch_{minibatch_idx}"):
# Split the batch into minibatches.
start = minibatch_idx * minibatch_size
end = start + minibatch_size
minibatch = jax.tree.map(lambda x: x[start:end], batch) # noqa: B023
# Calculate gradients and metrics for the minibatch.
# missing value is mean loss of batch
(_, step_metrics), step_grads = grad_fn(
state.params, state.apply_fn, minibatch, rngs[minibatch_idx]
)
# Accumulate gradients and metrics across minibatches.
if grads is None:
grads = step_grads
metrics = step_metrics
else:
# accumulation adder
grads = jax.tree.map(jnp.add, grads, step_grads)
metrics = jax.tree.map(jnp.add, metrics, step_metrics)
# Average gradients over minibatches.
grads = jax.tree.map(lambda g: g / num_minibatches, grads)
return grads, metrics
# %%
# jax.lax.scan implementation
#
# pros: much faster compilation
# cons: can be slightly slower at runtime
def accumulate_gradients_scan(
state: TrainState,
batch: Batch,
rng: jax.random.PRNGKey,
num_minibatches: int,
loss_fn: Callable,
) -> Tuple[PyTree, Metrics]:
"""Calculate gradients and metrics for a batch using gradient accumulation.
In this version, we use `jax.lax.scan` to loop over the minibatches. This is more efficient in terms of compilation time.
Args:
state: Current training state.
batch: Full training batch.
rng: Random number generator to use.
num_minibatches: Number of minibatches to split the batch into. Equal to the number of gradient accumulation steps.
loss_fn: Loss function to calculate gradients and metrics.
Returns:
Tuple with accumulated gradients and metrics over the minibatches.
"""
batch_size = batch.inputs.shape[0]
minibatch_size = batch_size // num_minibatches
rngs = jax.random.split(rng, num_minibatches)
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
def _minibatch_step(minibatch_idx: jax.Array | int) -> Tuple[PyTree, Metrics]:
"""Determine gradients and metrics for a single minibatch."""
minibatch = jax.tree.map(
# jax.lax.dynamic_slice_in_dim
# This is roughly equivalent to the following Python indexing syntax
# applied along the specified axis: operand[..., start_index:start_index + slice_size].
# jax.lax.dynamic_slice_in_dim(operand, start_index, slice_size, axis=0)
lambda x: jax.lax.dynamic_slice_in_dim( # Slicing with variable index (jax.Array).
x,
start_index=minibatch_idx * minibatch_size,
slice_size=minibatch_size,
axis=0
),
batch,
)
(_, step_metrics), step_grads = grad_fn(
state.params, state.apply_fn, minibatch, rngs[minibatch_idx]
)
return step_grads, step_metrics
# the function we expect scan to use
def _scan_step(
carry: Tuple[PyTree, Metrics], minibatch_idx: jax.Array | int
) -> Tuple[Tuple[PyTree, Metrics], None]:
"""Scan step function for looping over minibatches."""
step_grads, step_metrics = _minibatch_step(minibatch_idx)
# notice how the carry type is a tuple of pytree and metrics
# carry is literally the accumulator of (step_grads, step_metrics)
carry = jax.tree.map(jnp.add, carry, (step_grads, step_metrics))
# jax.lax.scan expects a carry and a y
# but we have no y
return carry, None
# Determine initial shapes for gradients and metrics.
grads_shapes, metrics_shape = jax.eval_shape(_minibatch_step, 0)
grads = jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), grads_shapes)
metrics = jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), metrics_shape)
# Loop over minibatches to determine gradients and metrics.
# jax.lax.scan
# jax.lax.scan(f, init, xs=None, length=None, reverse=False, unroll=1, _split_transpose=False)
# purpose: Scan a function over leading array axes while carrying along state.
# in other words, a functional for-loop
# why? because the whole loop lowers to a single While primitive, which compiles much faster than an unrolled Python loop
# equivalent python code semantics:
# def scan(f, init, xs, length=None):
# if xs is None:
# xs = [None] * length
# carry = init
# ys = []
# for x in xs:
# carry, y = f(carry, x)
# ys.append(y)
# return carry, np.stack(ys)
# note: usually we expect the ys to be the output and the carry to be hidden state
# y is the pure function output we expect
(grads, metrics), _ = jax.lax.scan(
_scan_step,
init=(grads, metrics),
xs=jnp.arange(num_minibatches),
length=num_minibatches
)
# Average gradients over minibatches.
grads = jax.tree.map(lambda g: g / num_minibatches, grads)
return grads, metrics
# %%
def accumulate_gradients(*args, use_scan: bool = False, **kwargs) -> Tuple[PyTree, Metrics]:
if use_scan:
return accumulate_gradients_scan(*args, **kwargs)
else:
return accumulate_gradients_loop(*args, **kwargs)
# %%
def train_step(
state: TrainState,
metrics: Metrics | None,
batch: Batch,
num_minibatches: int,
) -> Tuple[TrainState, Metrics]:
"""Training step function.
Executes a full training step with gradient accumulation.
Args:
state: Current training state.
metrics: Current metrics, accumulated from previous training steps.
batch: Training batch.
num_minibatches: Number of minibatches to split the batch into. Equal to the number of gradient accumulation steps.
Returns:
Tuple with updated training state (parameters, optimizer state, etc.) and metrics.
"""
# Split the random number generator for the current step.
rng, step_rng = jax.random.split(state.rng)
# Determine gradients and metrics for the full batch.
grads, step_metrics = accumulate_gradients(
# we cannot use a variable to choose use_scan
# cardinal sin of jax: passing boolean into jitted function
state, batch, step_rng, num_minibatches, loss_fn=classification_loss_fn, use_scan=True
)
# Optimizer step.
new_state = state.apply_gradients(grads=grads, rng=rng)
# Accumulate metrics across training steps.
if metrics is None:
metrics = step_metrics
else:
metrics = jax.tree.map(jnp.add, metrics, step_metrics)
return new_state, metrics
# %%
batch_size = 512
num_inputs = 128
num_classes = 100
rng_seed = 0
rng = jax.random.PRNGKey(rng_seed)
data_input_rng, data_label_rng, model_rng, state_rng = jax.random.split(rng, 4)
batch = Batch(
inputs=jax.random.normal(data_input_rng, (batch_size, num_inputs)),
labels=jax.random.randint(data_label_rng, (batch_size,), 0, num_classes),
)
# Zero dropout for checking later equality between training with and without gradient accumulation.
model = MLPClassifier(dtype=jnp.bfloat16, dropout_rate=0.0)
params = model.init(model_rng, batch.inputs, train=False)["params"]
state = TrainState.create(
apply_fn=model.apply,
params=params,
tx=optax.adam(1e-3),
rng=state_rng,
)
# %%
# jax.eval_shape(fun, *args, **kwargs)
# compute shape/dtype of fun without any FLOPs
# the commented call below fails because eval_shape abstracts every argument,
# so num_minibatches becomes a traced value and can no longer drive Python
# control flow and static shapes inside train_step
# _, metric_shapes = jax.eval_shape(
# train_step, # fun
# state, # train state
# None, # metrics
# batch, # batch
# 4, # num_minibatches
# )
_, metric_shapes = jax.eval_shape(
# this thing jitted works
functools.partial(train_step, num_minibatches=4),
state, # train state
None, # metrics
batch, # batch
)
print("Metric shapes:")
pprint(metric_shapes)
# %%
# this is an optimization trick: marking num_minibatches as static makes jax.jit
# compile (and cache) one specialized version per value of num_minibatches,
# instead of failing on Python control flow
train_step_jit = jax.jit(
train_step,
# treat as a static argument
static_argnames="num_minibatches",
)
# %%
def train_with_minibatches(
state: TrainState,
batch: Batch,
num_minibatches: int,
num_train_steps: int,
) -> Tuple[TrainState, Metrics]:
"""Small helper function for training loop."""
train_metrics = jax.tree.map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
for _ in range(num_train_steps):
state, train_metrics = train_step_jit(state, train_metrics, batch, num_minibatches)
return state, train_metrics
# %%
def print_metrics(metrics: Metrics, title: str | None = None) -> None:
"""Prints metrics with an optional title."""
metrics = jax.device_get(metrics)
lines = [f"{k}: {v[0] / v[1]:.6f}" for k, v in metrics.items()]
if title:
title = f" {title} "
max_len = max(len(title), max(map(len, lines)))
lines = [title.center(max_len, "=")] + lines
print("\n".join(lines))
# %%
state_mini1, metrics_mini1 = train_with_minibatches(
state, batch, num_minibatches=1, num_train_steps=4
)
state_mini4, metrics_mini4 = train_with_minibatches(
state, batch, num_minibatches=4, num_train_steps=4
)
print_metrics(metrics_mini1, "Minibatch 1")
print_metrics(metrics_mini4, "Minibatch 4")
# %% [markdown]
# # donate_buffers
# jax is functional: input buffers are normally left untouched, so outputs get fresh buffers
# donating an argument lets XLA reuse its buffer for the outputs instead
# what can be donated?
# only arguments we are sure will not be used again after the call
# this is usually true for model parameters and optimizer state, since we get
# entirely new values back and never reuse the old ones after an update
# (e.g. we only use new_state and new_metrics; see the usage note below)
train_step_donated = jax.jit(
train_step,
static_argnames="num_minibatches",
donate_argnames=(
"state",
"metrics",
),
)
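# %%
# Usage note (illustrative, commented out): after donation the old buffers may be
# invalidated by XLA, so always rebind the results, e.g.
# state, metrics = train_step_donated(state, metrics, batch, num_minibatches=4)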

View File

@@ -0,0 +1,392 @@
# %% [markdown]
# # T5 implementation using jax with pjit
# MARK: START
# %%
# let's make 8-device simulator
import os
# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True
flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
flags += " --xla_force_host_platform_device_count=8" # Simulate 8 devices
# Enforce CPU-only execution
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["JAX_PLATFORMS"] = "cpu"
else:
# GPU flags
flags += (
"--xla_gpu_enable_triton_softmax_fusion=true "
"--xla_gpu_triton_gemm_any=false "
"--xla_gpu_enable_async_collectives=true "
"--xla_gpu_enable_latency_hiding_scheduler=true "
"--xla_gpu_enable_highest_priority_async_stream=true "
)
os.environ["XLA_FLAGS"] = flags
import functools
from functools import partial
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from ml_collections import ConfigDict
import optax
import logging
import time
from datasets import Dataset, load_from_disk
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
import flax.core
from tqdm import tqdm
from dataload import DataPrepare
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
if USE_CPU_ONLY:
jax.config.update('jax_platform_name', 'cpu')
else:
jax.config.update("jax_default_matmul_precision", "bfloat16")
# # %%
# import jax
# import jax.numpy as jnp
# import optax
# import numpy as np
# from functools import partial
# from typing import Callable, Optional
# import math
#
# # jax.config.update("jax_default_matmul_precision", "tensorfloat32")
# jax.config.update("jax_default_matmul_precision", "bfloat16")
# # jax.config.update("jax_enable_x64", False)
# # enable cache
# jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
#
#
# # from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
#
# from flax import jax_utils, traverse_util
# from flax.jax_utils import pad_shard_unpad, unreplicate
# from flax.training import train_state
# from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
# import flax.core
# %%
# get platform type
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
save_path = 't5_80_1_bf16'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])
# Store some constant
seed = 117
num_epochs = 5
batch_size = 384 # 384 is the best
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
warmup_steps = 0
learning_rate = 2e-5
weight_decay = 0.01
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0
num_beams = 1
val_max_target_length = 128
predict_with_generate = True
# %%
# prepare data
# init object
# e.g. Config
data_config = ConfigDict(
dict(
max_length=86,
pad_token_id=0,
decoder_start_token_id=0
)
)
dataprep = DataPrepare(split_datasets['train'], data_config)
# # example usage
# # %%
# seed = 117
# rng = jax.random.PRNGKey(seed)
# train_loader = dataprep.data_loader(rng, batch_size=1)
# %%
# model
from transformers import FlaxT5ForConditionalGeneration
from transformers import T5Config
config = T5Config()
# If you don't want to cast certain parameters (for example layer norm bias and scale)
# then pass the mask as follows
from flax import traverse_util
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
# useful for transformer model
model.enable_gradient_checkpointing()
# enable bf16 except for layer_norm
flat_params = traverse_util.flatten_dict(model.params)
mask = {
path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
}
mask = traverse_util.unflatten_dict(mask)
model.params = model.to_bf16(model.params, mask)
# %% [markdown]
# # Model
#
#
#
# %%
# Initialize our training
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)
# %%
# optimization functions
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
return schedule_fn
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
training_size,
train_batch_size,
num_train_epochs,
warmup_steps,
learning_rate,
)
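# %%
# Quick look (illustrative, commented out): the schedule warms up linearly to
# `learning_rate` over `warmup_steps`, then decays linearly to 0 over the rest.
# [float(linear_decay_lr_schedule_fn(s)) for s in (0, total_train_steps // 2, total_train_steps)]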
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
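# %%
# Illustration (a hedged sketch on a toy parameter tree, not used in training):
# the mask is True only for non-bias, non-LayerNorm leaves, i.e. the decayed weights.
pprint(decay_mask_fn({"dense": {"kernel": 0, "bias": 0}, "layer_norm": {"weight": 0}}))
# expected: {'dense': {'bias': False, 'kernel': True}, 'layer_norm': {'weight': False}}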
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=adam_beta1,
b2=adam_beta2,
eps=adam_epsilon,
weight_decay=weight_decay,
mask=decay_mask_fn,
)
# %%
# Training functions
class TrainState(train_state.TrainState):
dropout_rng: jnp.ndarray
# easy way to achieve data parallelism
# also achieves folding of rng keys
def replicate(self):
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
# set bf16 for model params
# model.params = model.to_bf16(model.params)
# Cast parameters to bfloat16 if desired
# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
# Setup train state
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
# label smoothed cross entropy
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
"""
The label smoothing implementation is adapted from Flax's official example:
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
"""
vocab_size = logits.shape[-1]
confidence = 1.0 - label_smoothing_factor
low_confidence = (1.0 - confidence) / (vocab_size - 1)
normalizing_constant = -(
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
)
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
loss = optax.softmax_cross_entropy(logits, soft_labels)
loss = loss - normalizing_constant
# ignore padded tokens from loss
loss = loss * padding_mask
loss = loss.sum()
num_labels = padding_mask.sum()
return loss, num_labels
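# %%
# Sanity sketch (commented out, illustrative): with label_smoothing_factor=0.0 this
# reduces to the summed token-level cross entropy over non-padded positions.
# toy_logits = jnp.zeros((1, 2, 5))           # uniform predictions over a 5-token vocab
# toy_labels = jnp.array([[1, 2]])
# toy_mask = jnp.array([[1.0, 0.0]])          # second position is padding
# loss_fn(toy_logits, toy_labels, toy_mask)   # -> (approximately log(5) = 1.609, 1.0)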
# Define gradient update step fn
@jax.jit
def train_step(state, batch, label_smoothing_factor=0.0):
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
def compute_loss(params):
labels = batch.pop("labels")
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
return loss, num_labels
# compute gradients through computational graph
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
(loss, num_labels), grad = grad_fn(state.params)
num_labels = jax.lax.psum(num_labels, "batch")
# true loss = total loss / total samples
# loss = jax.lax.psum(loss, "batch")
# loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
# true grad = total grad / total samples
grad = jax.lax.psum(grad, "batch")
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
return new_state, metrics
# Define generation function
max_length = (
val_max_target_length if val_max_target_length is not None else model.config.max_length
)
num_beams = num_beams if num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
# def generate_step(params, batch):
# model.params = params
# output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
# return output_ids.sequences
# Create parallel version of the train and eval step
p_train_step = jax.pmap(
partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,)
)
# p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch")
# p_generate_step = jax.pmap(generate_step, "batch")
# Replicate the train state on each device
state = state.replicate()
# %%
print("***** Running training *****")
print(f" Num examples = {training_size}")
print(f" Num Epochs = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
print(f" Total optimization steps = {total_train_steps}")
# %%
# jax.profiler.start_trace("./traces")
rng, input_rng = jax.random.split(rng)
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
train_start = time.time()
# Create sampling rng
train_metrics = []
rng, data_rng = jax.random.split(rng)
train_loader = dataprep.data_loader(data_rng, batch_size=batch_size)
steps_per_epoch = training_size // train_batch_size
# Generate an epoch by shuffling sampling indices from the train dataset
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
batch = next(train_loader)
batch = shard(batch)
state, train_metric = p_train_step(state, batch)
train_metrics.append(train_metric)
train_time = time.time() - train_start
train_metric = unreplicate(train_metric)
train_metric['loss'].block_until_ready()
epochs.write(
# f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, "
f"Epoch... ({epoch + 1}/{num_epochs} | "
# f"Learning Rate:{train_metric['learning_rate']}, "
f"Last train time: {train_time})"
)
# jax.profiler.stop_trace()
# %%
# output_dir = save_path
# # save checkpoint after each epoch and push checkpoint to the hub
# if jax.process_index() == 0:
# params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
# model.save_pretrained(output_dir, params=params)
# tokenizer.save_pretrained(output_dir)

View File

@@ -196,7 +196,7 @@ model.params = model.to_bf16(model.params, mask)
# %%
model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right") # noqa: B009
@@ -391,6 +391,8 @@ adamw = optax.adamw(
class TrainState(train_state.TrainState):
dropout_rng: jnp.ndarray
# easy way to achieve data parallelism
# also achieves folding of rng keys
def replicate(self):
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))