# learn_jax/parallel/fully_sharded_data_parallel... (Python, 755 lines, 24 KiB)

# %% [markdown]
# # Fully-Sharded Data Parallelism
# MARK: START
# %%
# Configure XLA before `jax` is imported: XLA reads XLA_FLAGS when its backend
# is initialised, so these settings must be in the environment first.
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True

# Simulate 8 devices on the host CPU.
_CPU_XLA_FLAGS = ("--xla_force_host_platform_device_count=8",)
# GPU performance flags.
_GPU_XLA_FLAGS = (
    "--xla_gpu_enable_triton_softmax_fusion=true",
    "--xla_gpu_triton_gemm_any=false",
    "--xla_gpu_enable_async_collectives=true",
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_highest_priority_async_stream=true",
)


def build_xla_flags(existing: str, cpu_only: bool) -> str:
    """Return ``existing`` XLA flags extended with the CPU- or GPU-specific flags.

    Flags are always separated by a single space. Flags already present in the
    environment therefore never get glued onto the first appended flag, which
    the previous GPU branch did because it had no leading space.

    Args:
        existing: Current value of the ``XLA_FLAGS`` environment variable
            (may be empty).
        cpu_only: If True, append the host-device simulation flag. Otherwise
            append the GPU flags.

    Returns:
        The combined, space-separated flag string.
    """
    parts = [existing.strip()] if existing.strip() else []
    parts.extend(_CPU_XLA_FLAGS if cpu_only else _GPU_XLA_FLAGS)
    return " ".join(parts)


flags = build_xla_flags(os.environ.get("XLA_FLAGS", ""), USE_CPU_ONLY)
if USE_CPU_ONLY:
    # Enforce CPU-only execution by hiding all CUDA devices.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["XLA_FLAGS"] = flags
import functools
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from ml_collections import ConfigDict
import optax
import logging
import time
# Any pytree of arrays (params, grads, optimizer state, ...).
PyTree = Any
# Metric name -> tuple of arrays (e.g. a (sum, count) pair).
Metrics = Dict[str, Tuple[jax.Array, ...]]
# Pin JAX to the CPU backend only when simulating devices on the host.
# Doing it unconditionally would silently ignore the GPUs configured above.
if USE_CPU_ONLY:
    jax.config.update('jax_platform_name', 'cpu')
# %%
# required functions:
# Batch
# TrainState
# accumulate_gradients
# print_metrics
from single_gpu_optimizations import Batch, TrainState, accumulate_gradients, print_metrics
# %%
# import the fold_rng_over_axis
def fold_rng_over_axis(rng: jax.random.PRNGKey, axis_name: str) -> jax.random.PRNGKey:
    """Derive a per-device key by folding in the device's index along an axis.

    Every device along ``axis_name`` gets a distinct key, for example different
    dropout masks per data-parallel replica. The result is still fully
    determined by ``rng``. Must be called where ``axis_name`` is bound, e.g.
    inside ``shard_map`` or ``vmap``.

    Args:
        rng: The key to derive from.
        axis_name: Name of the mapped axis to fold over (e.g. the model axis).

    Returns:
        ``jax.random.fold_in(rng, i)``, where ``i`` is this device's index
        along ``axis_name``.
    """
    return jax.random.fold_in(rng, jax.lax.axis_index(axis_name))
# MARK: DATA PARALLELISM
# %% [markdown]
# # Data Parallelism
# we start with plain data parallelism
#
# using shard_map, we write single-device code and let shard map handle the rest
# %%
# plain data parallel - sharding only data inputs and outputs
class DPClassifier(nn.Module):
    """Two-layer MLP classifier used for plain data parallelism.

    Architecture: Dense(hidden_size) -> SiLU -> Dropout -> Dense(num_classes).
    The Dense layers compute in ``config.dtype``. The logits are returned as
    float32. Nothing here refers to the data axis: under ``shard_map`` each
    device runs this single-device code on its own slice of the batch.

    Attributes:
        config: Provides ``hidden_size``, ``dropout_rate``, ``dtype`` and
            ``num_classes``. ``data_axis_name`` may also be present but is not
            read by this module.
    """

    config: ConfigDict

    @nn.compact
    def __call__(self, x: jax.Array, train: bool) -> jax.Array:
        cfg = self.config
        hidden = nn.Dense(
            features=cfg.hidden_size, dtype=cfg.dtype, name="input_dense"
        )(x)
        hidden = nn.silu(hidden)
        # Dropout is a no-op unless we are training.
        hidden = nn.Dropout(rate=cfg.dropout_rate, deterministic=not train)(hidden)
        logits = nn.Dense(
            features=cfg.num_classes, dtype=cfg.dtype, name="output_dense"
        )(hidden)
        return logits.astype(jnp.float32)
# Hyper-parameters, grouped by concern and then bundled into one `config`.
data_config = ConfigDict(
    {
        "batch_size": 128,
        "num_classes": 10,
        "input_size": 784,
    }
)
model_config = ConfigDict(
    {
        "hidden_size": 512,
        "dropout_rate": 0.1,
        "dtype": jnp.bfloat16,  # dtype of the Dense computations
        "num_classes": data_config.num_classes,
        "data_axis_name": "data",  # mesh axis name used for data parallelism
    }
)
optimizer_config = ConfigDict(
    {
        "learning_rate": 1e-3,
        "num_minibatches": 4,  # passed to accumulate_gradients in the train step
    }
)
config = ConfigDict(
    {
        "model": model_config,
        "optimizer": optimizer_config,
        "data": data_config,
        # Duplicated at the top level for convenient access in the train step.
        "data_axis_name": model_config.data_axis_name,
        "seed": 42,
    }
)
# %%
# initialize
# Model and optimizer for plain data parallelism.
model_dp = DPClassifier(config=config.model)
optimizer = optax.adamw(learning_rate=config.optimizer.learning_rate)

# One root key, split into independent keys for parameter init, the synthetic
# inputs and the synthetic labels.
rng = jax.random.PRNGKey(config.seed)
model_init_rng, data_inputs_rng, data_labels_rng = jax.random.split(rng, 3)

# Synthetic batch: standard-normal features and integer labels drawn from
# [0, num_classes).
batch = Batch(
    inputs=jax.random.normal(
        data_inputs_rng,
        shape=(config.data.batch_size, config.data.input_size),
    ),
    labels=jax.random.randint(
        data_labels_rng,
        shape=(config.data.batch_size,),
        minval=0,
        maxval=config.data.num_classes,
    ),
)
# Initialise a TrainState for a model (used for both the DP and FSDP models).
def init_dp(rng: jax.random.PRNGKey, x: jax.Array, model: nn.Module) -> TrainState:
    """Initialise ``model`` on ``x`` and wrap its parameters in a ``TrainState``.

    ``rng`` is split in two. The first key initialises the parameters. The
    second is stored on the state, where the train step later splits it for
    per-step randomness such as dropout.

    Args:
        rng: Root key for this initialisation.
        x: Example input used to trace the model's shapes.
        model: The Flax module to initialise.

    Returns:
        A ``TrainState`` holding ``model.apply``, the parameters, the global
        ``optimizer`` and the stored key.
    """
    param_rng, state_rng = jax.random.split(rng)
    variables = model.init({"params": param_rng}, x, train=False)
    return TrainState.create(
        apply_fn=model.apply,
        params=variables.pop("params"),
        tx=optimizer,
        rng=state_rng,
    )
# A 1-D mesh over every available device, with a single data axis.
device_array = np.array(jax.devices())
mesh = Mesh(device_array, (config.data_axis_name,))

# Plain data parallelism: the key is replicated (P()), the inputs are split
# along the data axis, and the resulting TrainState is *replicated* on every
# device (out_specs=P()). This is equivalent to Flax's `replicate`.
# Replication is not verified (check_rep=False).
init_dp_fn = jax.jit(
    shard_map(
        functools.partial(init_dp, model=model_dp),
        mesh,
        in_specs=(P(), P(config.data_axis_name)),
        out_specs=P(),
        check_rep=False,
    ),
)
state_dp = init_dp_fn(model_init_rng, batch.inputs)
print("DP Parameters")
pprint(jax.tree.map(lambda leaf: (leaf.shape, leaf.sharding), state_dp.params))
# MARK: TRAIN STEP
# %%
# train step
def loss_fn(
    params: PyTree, apply_fn: Any, batch: Batch, rng: jax.Array
) -> Tuple[jax.Array, Dict[str, Any]]:
    """Mean cross-entropy on this device's batch, plus (sum, count) metrics.

    Args:
        params: Model parameters.
        apply_fn: The model's apply function.
        batch: Inputs and integer labels for this device.
        rng: Key for dropout. The device index along the data axis is folded
            in, so replicas draw different masks.

    Returns:
        ``(mean_loss, metrics)``, where ``metrics`` maps ``"loss"`` and
        ``"accuracy"`` to ``(sum over examples, number of examples)``. Sums
        stay exact when combined across devices or micro-batches.
    """
    logits = apply_fn(
        {"params": params},
        batch.inputs,
        train=True,
        rngs={"dropout": fold_rng_over_axis(rng, config.data_axis_name)},
    )
    per_example_loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch.labels)
    hits = jnp.argmax(logits, axis=-1) == batch.labels
    num_examples = batch.inputs.shape[0]
    metrics = {
        "loss": (per_example_loss.sum(), num_examples),
        "accuracy": (hits.sum(), num_examples),
    }
    return per_example_loss.mean(), metrics
# Plain data-parallel train step: every device holds the full model but sees a
# different slice of the data.
def train_step_dp(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
) -> Tuple[TrainState, Metrics]:
    """One data-parallel optimisation step (runs per device inside shard_map).

    Args:
        state: The replicated TrainState.
        metrics: Accumulated metrics, or ``None`` on the very first call (e.g.
            under ``jax.eval_shape``). With ``None``, this step's metrics are
            returned as-is.
        batch: This device's slice of the global batch.

    Returns:
        The updated state and the accumulated, replica-summed metrics.
    """
    next_rng, step_rng = jax.random.split(state.rng)
    grads, step_metrics = accumulate_gradients(
        state,
        batch,
        step_rng,
        config.optimizer.num_minibatches,
        loss_fn=loss_fn,
    )

    # Each replica computed gradients on different data; average them so all
    # replicas apply the same update.
    with jax.named_scope("sync_gradients"):
        grads = jax.tree.map(
            lambda grad: jax.lax.pmean(grad, axis_name=config.data_axis_name), grads
        )
    new_state = state.apply_gradients(grads=grads, rng=next_rng)

    # Sum metrics over replicas now. They could instead be kept per device and
    # synced only before logging.
    with jax.named_scope("sync_metrics"):
        step_metrics = jax.tree.map(
            lambda value: jax.lax.psum(value, axis_name=config.data_axis_name),
            step_metrics,
        )

    if metrics is not None:
        step_metrics = jax.tree.map(jnp.add, metrics, step_metrics)
    return new_state, step_metrics
# %%
# Wrap the per-device step in shard_map and jit it. State and metrics are
# replicated (P()); the batch is split along the data axis.
train_step_dp_fn = jax.jit(
    shard_map(
        train_step_dp,
        mesh,
        in_specs=(P(), P(), P(config.data_axis_name)),
        out_specs=(P(), P()),
        check_rep=False,
    ),
    # The incoming state/metrics are never reused after the call, so their
    # buffers can be donated to the outputs instead of allocating new ones.
    donate_argnames=("state", "metrics"),
)
# %%
# Trace one step with `metrics=None` (shape evaluation only, nothing runs) to
# learn the structure, shapes and dtypes of the metrics pytree...
_, metric_shapes = jax.eval_shape(train_step_dp_fn, state_dp, None, batch)
# ...then allocate a zero-filled accumulator with exactly that layout.
metrics_dp = jax.tree.map(
    lambda spec: jnp.zeros(spec.shape, dtype=spec.dtype), metric_shapes
)
# %%
# Run 15 data-parallel steps and time them.
# JAX dispatches work asynchronously, so block on the outputs before reading
# the clock; otherwise only the dispatch time is measured. The first call also
# triggers compilation of the step, so the timing includes it.
start_time = time.perf_counter()
for _ in range(15):
    state_dp, metrics_dp = train_step_dp_fn(state_dp, metrics_dp, batch)
jax.block_until_ready((state_dp, metrics_dp))
duration = time.perf_counter() - start_time
print(duration)
# One extra step starting from zeroed metrics, so the printed metrics cover
# only that single step.
final_metrics_dp = jax.tree.map(
    lambda x: jnp.zeros(x.shape, dtype=x.dtype),
    metric_shapes)
state_dp, final_metrics_dp = train_step_dp_fn(
    state_dp,
    final_metrics_dp,
    batch)
print_metrics(final_metrics_dp)
# %%
# With plain DP both parameters and metrics come out replicated (out_specs=P()).
for title, tree in (("DP Parameters", state_dp.params), ("Metrics", final_metrics_dp)):
    print(title)
    pprint(jax.tree.map(lambda leaf: (leaf.shape, leaf.sharding), tree))
####################################################################
# Everything up to this point works.
# So far this is equivalent to the Flax `replicate`-style data parallelism
# used in the Hugging Face Flax examples.
# MARK: PARAMETER SHARDING
# %% [markdown]
# # Parameter sharding
# Basic strategy: initialize the full parameters on each device, then use
# `jax.lax.axis_index` to pick out this device's slice, so that each device
# keeps only its own shard.
#
# `nn.Partitioned` annotates each parameter with its sharding: for every
# dimension it names the mesh axis that dimension is split over (or None),
# much like a `PartitionSpec`.
#
# Each parameter is therefore either a plain `jax.Array` or a
# `flax.linen.Partitioned`.
# %%
# Type alias for a single parameter leaf: either a plain (unsharded) array or
# an nn.Partitioned box holding the array (`.value`) plus one axis name (or
# None) per dimension (`.names`).
Parameter = jax.Array | nn.Partitioned
# %%
# function to shard parameters across devices
# look for an axis to equally split across the number of devices
# we can specify which parameters to shard, since they vary in size
# we set a floor on the size for sharding
@jax.named_scope("shard_params")
def shard_params(params: PyTree, axis_name: str, min_weight_size: int = 2**18) -> PyTree:
    """Shard parameters across the given mesh axis.

    Must be called where ``axis_name`` is bound (e.g. inside ``shard_map``),
    because it uses ``jax.lax.axis_index`` and ``jax.lax.psum`` over that axis.
    Each device keeps only its own slice of every parameter that gets sharded.
    A parameter is sharded along its largest dimension that (a) divides evenly
    by the number of devices on the axis and (b) is not already sharded over
    another axis.

    Args:
        params: The parameters to shard (plain arrays and/or ``nn.Partitioned``).
        axis_name: The axis to shard parameters across.
        min_weight_size: Parameters with at most this many values are not
            sharded.

    Returns:
        PyTree of same structure as params. Each sharded leaf becomes an
        ``nn.Partitioned`` whose ``names`` has ``axis_name`` at the split
        dimension. All other leaves are returned unchanged.
    """
    axis_idx = jax.lax.axis_index(axis_name)  # this device's position on the axis
    axis_size = jax.lax.psum(1, axis_name)  # number of devices along the axis

    def _split(x: Parameter) -> Parameter:
        if isinstance(x, nn.Partitioned):
            value, names = x.value, x.names
        else:
            value = x
            names = (None,) * value.ndim

        # The logging below is Python-side, so it only fires while tracing
        # (i.e. on the first jit call).
        if axis_name in names:
            # Already sharded over this very axis: nothing to do.
            logging.warning(
                "Parameter %s with names %s already sharded on axis %s.",
                value.shape, names, axis_name,
            )
            return x
        if value.size <= min_weight_size:
            logging.info(
                "Parameter %s with names %s too small to shard, size %s <= %s.",
                value.shape, names, value.size, min_weight_size,
            )
            return x

        shape = value.shape
        # Try dimensions from largest to smallest and shard the first one that
        # splits evenly across devices and is not already sharded.
        for i in np.argsort(shape)[::-1]:
            if shape[i] % axis_size == 0 and names[i] is None:
                split_size = shape[i] // axis_size
                return nn.Partitioned(
                    # Keep only the slice that belongs to the present device.
                    value=jax.lax.dynamic_slice_in_dim(
                        value, axis_idx * split_size, split_size, axis=i
                    ),
                    names=names[:i] + (axis_name,) + names[i + 1 :],
                )
        logging.warning(
            "Could not shard %s with names %s on axis %s, no suitable axis found.",
            value.shape, names, axis_name,
        )
        return x

    # Treat nn.Partitioned boxes as leaves so _split sees their axis names.
    return jax.tree.map(
        _split,
        params,
        is_leaf=lambda x: isinstance(x, nn.Partitioned),
    )
# %%
# Gather a sharded array back to full size on every device, with a custom
# gradient:
#   forward:  jax.lax.all_gather  -> every device assembles the full array
#   backward: jax.lax.psum_scatter -> each device receives the summed gradient
#             for its own shard, which is then divided by the device count (mean)
def gather_array_with_mean_grads(x: jax.Array, axis: int, axis_name: str):
    """Gathering with averaging gradients across replicas.

    Args:
        x: This device's shard.
        axis: Dimension along which the shards are concatenated.
        axis_name: Mapped axis the shards live on.

    Returns:
        The full array (shards tiled along ``axis``). Its gradient w.r.t.
        ``x`` is the mean, over devices on ``axis_name``, of each device's
        gradient for this shard.
    """
    num_shards = jax.lax.psum(1, axis_name)

    def _mean_scatter(cotangent):
        # Sum the full-size gradients of all devices, keep only our shard's
        # slice, and turn the sum into a mean.
        summed_shard = jax.lax.psum_scatter(
            cotangent, axis_name, scatter_dimension=axis, tiled=True
        )
        return summed_shard / num_shards

    @jax.custom_gradient
    def _gather(shard):
        full = jax.lax.all_gather(shard, axis_name, axis=axis, tiled=True)
        return full, _mean_scatter

    return _gather(x)
# Reverse of `shard_params`: rebuild full parameters, e.g. before a module's
# forward call. Gradients flow back through gather_array_with_mean_grads.
@jax.named_scope("gather_params")
def gather_params(params: PyTree, axis_name: str) -> PyTree:
    """Gather parameters from all replicas across the given axis.

    Args:
        params: The parameters to gather.
        axis_name: The axis to gather parameters across.

    Returns:
        PyTree of same structure as params. Every ``nn.Partitioned`` sharded
        over ``axis_name`` is gathered to full size along that dimension. The
        result stays an ``nn.Partitioned`` only if some other dimension is
        still sharded; otherwise it is a plain array. All other leaves are
        returned unchanged.
    """

    def _gather(p: Parameter) -> Parameter:
        # Plain arrays and boxes not sharded over this axis pass through.
        if not (isinstance(p, nn.Partitioned) and axis_name in p.names):
            return p
        names = p.names
        dim = names.index(axis_name)
        full_value = gather_array_with_mean_grads(p.value, axis=dim, axis_name=axis_name)
        # The gathered dimension is no longer sharded.
        remaining = (*names[:dim], None, *names[dim + 1 :])
        if all(name is None for name in remaining):
            return full_value
        return nn.Partitioned(full_value, remaining)

    return jax.tree.map(
        _gather,
        params,
        is_leaf=lambda node: isinstance(node, nn.Partitioned),
    )
# %%
# Wrap a module with nn.map_variables so that its "params" are transformed on
# the way in (gathered to full size before the module uses them) and on the way
# out (sharded again when written, e.g. at initialisation).
def shard_module_params(
    target: nn.Module | Callable,
    axis_name: str,
    min_weight_size: int = 2**18,  # 262,144
) -> nn.Module | Callable:
    """Shard parameters of a module across replicas.

    Args:
        target: The module (class) to shard.
        axis_name: The axis name to shard parameters across.
        min_weight_size: Parameters with at most this many values are not
            sharded (see ``shard_params``).

    Returns:
        The module wrapped so its parameters are stored sharded and gathered
        on use.
    """
    gather_fn = functools.partial(gather_params, axis_name=axis_name)
    shard_fn = functools.partial(
        shard_params, axis_name=axis_name, min_weight_size=min_weight_size
    )
    return nn.map_variables(
        target,
        mapped_collections="params",
        trans_in_fn=gather_fn,
        trans_out_fn=shard_fn,
        mutable=True,
    )
# %%
# FSDP version of the classifier: the same MLP, but each Dense layer is wrapped
# with shard_module_params. This is the template for sharding other modules.
class FSDPClassifier(nn.Module):
    """Same MLP as ``DPClassifier``, with FSDP-sharded Dense layers.

    Both Dense layers are wrapped with ``shard_module_params``. Their
    parameters are stored sharded over ``config.data_axis_name`` (subject to
    ``config.min_weight_size`` and even divisibility) and gathered before use.

    Attributes:
        config: Provides ``hidden_size``, ``dropout_rate``, ``dtype``,
            ``num_classes``, ``data_axis_name`` and ``min_weight_size``.
    """

    config: ConfigDict

    @nn.compact
    def __call__(self, x: jax.Array, train: bool) -> jax.Array:
        cfg = self.config
        sharded_dense_cls = shard_module_params(
            nn.Dense,
            axis_name=cfg.data_axis_name,
            min_weight_size=cfg.min_weight_size,
        )
        hidden = sharded_dense_cls(
            features=cfg.hidden_size, dtype=cfg.dtype, name="input_dense"
        )(x)
        hidden = nn.silu(hidden)
        hidden = nn.Dropout(rate=cfg.dropout_rate, deterministic=not train)(hidden)
        logits = sharded_dense_cls(
            features=cfg.num_classes, dtype=cfg.dtype, name="output_dense"
        )(hidden)
        return logits.astype(jnp.float32)
# %%
# initialization
# A low threshold so that, even in this toy model, all but the tiniest
# parameters get sharded.
config.model.min_weight_size = 2**4
model_fsdp = FSDPClassifier(config=config.model)
# `init_dp` (defined earlier for plain DP) is reused as-is. It only calls
# `model.init` and builds a TrainState; the sharding decisions are made inside
# the FSDP model itself, so no separate FSDP init function is needed.
# Which parameters end up sharded (and along which dimension) is decided only
# inside `model.init`, by shard_params. The out_specs therefore cannot be
# written by hand. Instead:
#   1. wrap init in shard_map with a placeholder out_specs=P() (replication
#      unchecked) and only *trace* it with jax.eval_shape, so nothing runs;
#   2. read the PartitionSpecs off the nn.Partitioned boxes in the traced state.
init_fsdp_fn = shard_map(
    functools.partial(init_dp, model=model_fsdp),
    mesh,
    # (model_init_rng replicated, batch.inputs split along the data axis)
    in_specs=(P(), P(config.data_axis_name)),
    out_specs=P(),
    check_rep=False,  # the placeholder out_specs would fail replication checks
)
state_fsdp_shapes = jax.eval_shape(init_fsdp_fn, model_init_rng, batch.inputs)
state_fsdp_specs = nn.get_partition_spec(state_fsdp_shapes)
# %% [raw]
# TrainState(step=PartitionSpec(), apply_fn=<bound method Module.apply of FSDPClassifier(
# # attributes
# config = data_axis_name: data
# dropout_rate: 0.1
# dtype: !!python/name:jax.numpy.bfloat16 ''
# hidden_size: 512
# min_weight_size: 16
# num_classes: 10
#
# )>,
# params={
# 'input_dense': {'bias': PartitionSpec('data',), 'kernel': PartitionSpec('data', None)},
# 'output_dense': {'bias': PartitionSpec(), 'kernel': PartitionSpec('data', None)}},
# tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x761e8ef00400>,
# update=<function chain.<locals>.update_fn at 0x761e8ef01080>),
# opt_state=(ScaleByAdamState(count=PartitionSpec(),
# mu={'input_dense': {'bias': PartitionSpec('data',), 'kernel': PartitionSpec('data', None)},
# 'output_dense': {'bias': PartitionSpec(), 'kernel': PartitionSpec('data', None)}},
# nu={'input_dense': {'bias': PartitionSpec('data',), 'kernel': PartitionSpec('data', None)},
# 'output_dense': {'bias': PartitionSpec(), 'kernel': PartitionSpec('data', None)}}),
# EmptyState(), EmptyState()), rng=PartitionSpec())
# %%
# These specs show why out_specs could not be written in advance: which
# parameters (and optimizer moments) get sharded is only known after model init.
print("RNG", state_fsdp_specs.rng)
for heading, specs in (
    ("\nParameters", state_fsdp_specs.params),
    ("\nOptimizer state", state_fsdp_specs.opt_state[0]),
):
    print(heading)
    pprint(specs)
# %%
# Now that the PartitionSpecs are known, build the real (jitted) initializer.
# Its outputs are laid out according to `state_fsdp_specs`, so sharded
# parameters and optimizer-state entries keep only their local shard per device.
init_fsdp_fn = jax.jit(
    shard_map(
        functools.partial(init_dp, model=model_fsdp),
        mesh,
        in_specs=(P(), P(config.data_axis_name)),
        out_specs=state_fsdp_specs,
        check_rep=False,
    )
)
state_fsdp = init_fsdp_fn(model_init_rng, batch.inputs)
# %%
print("FSDP Parameters")
# Copy the parameters to the host and show only their shapes.
pprint(jax.tree.map(lambda leaf: leaf.shape, jax.device_get(state_fsdp.params)))
# %%
# Train step support: with FSDP some parameters are sharded and others are
# replicated, so their gradients need different synchronisation.
def sync_gradients(
    grads: PyTree,
    axis_names: Sequence[str],
) -> PyTree:
    """Synchronize gradients across devices.

    Gradients of parameters replicated over an axis are averaged (``pmean``)
    over that axis. Parameters partitioned over an axis are assumed to already
    hold the mean gradient of their shard on each device, so they are not
    averaged again over that axis.

    Args:
        grads: Gradient pytree; leaves are arrays or ``nn.Partitioned`` boxes.
        axis_names: The axis names to synchronize gradients across.

    Returns:
        Gradients of the same structure, averaged over every axis in
        ``axis_names`` along which the corresponding parameter is replicated.
    """

    def _is_partitioned(node: Any) -> bool:
        return isinstance(node, nn.Partitioned)

    def _sync_leaf(g: Parameter) -> Parameter:
        if not _is_partitioned(g):
            # Plain arrays are replicated over all axes.
            return jax.lax.pmean(g, axis_name=axis_names)
        # One array dimension may be split over several (nested) axis names;
        # flattening yields every axis g is sharded over (None entries vanish).
        sharded_axes = jax.tree_util.tree_leaves(g.names)
        replicated_axes = [axis for axis in axis_names if axis not in sharded_axes]
        if not replicated_axes:
            # Partitioned over every requested axis: already synchronised.
            return g
        return g.replace(value=jax.lax.pmean(g.value, axis_name=replicated_axes))

    return jax.tree.map(_sync_leaf, grads, is_leaf=_is_partitioned)
# %%
def train_step_fsdp(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
) -> Tuple[TrainState, Metrics]:
    """One FSDP optimisation step (runs per device inside shard_map).

    Args:
        state: This device's view of the TrainState. Sharded leaves hold only
            the local shard.
        metrics: Accumulated metrics, or ``None`` on the very first call (e.g.
            under ``jax.eval_shape``, as done when building the accumulator).
            With ``None``, this step's metrics are returned as-is.
        batch: This device's slice of the global batch.

    Returns:
        The updated state and the accumulated, replica-summed metrics.
    """
    rng, step_rng = jax.random.split(state.rng)
    # Forward/backward over `num_minibatches` micro-batches.
    grads, step_metrics = accumulate_gradients(
        state,
        batch,
        step_rng,
        config.optimizer.num_minibatches,
        loss_fn=loss_fn,
    )
    # Sync gradients before updating. Replicated parameters are averaged over
    # the data axis. Sharded ones already got the mean from the custom
    # backward pass of the parameter gather.
    with jax.named_scope("sync_gradients"):
        grads = sync_gradients(grads, (config.data_axis_name,))
    new_state = state.apply_gradients(grads=grads, rng=rng)
    # Sum metrics across replicas now. They could instead be kept separate and
    # synced only before logging.
    with jax.named_scope("sync_metrics"):
        step_metrics = jax.tree.map(
            lambda x: jax.lax.psum(x, axis_name=config.data_axis_name), step_metrics
        )
    if metrics is None:
        metrics = step_metrics
    else:
        metrics = jax.tree.map(jnp.add, metrics, step_metrics)
    return new_state, metrics
# %%
# jit the FSDP step. The state keeps the PartitionSpecs discovered at init
# (sharded leaves stay sharded), metrics are replicated, and the batch is split
# along the data axis. The old state/metrics buffers are donated to the outputs.
train_step_fsdp_fn = jax.jit(
    shard_map(
        train_step_fsdp,
        mesh,
        in_specs=(state_fsdp_specs, P(), P(config.data_axis_name)),
        out_specs=(state_fsdp_specs, P()),
        check_rep=False,
    ),
    donate_argnames=("state", "metrics"),
)
# Trace one step from `metrics=None` to learn the metrics' shapes and dtypes,
# then allocate a zero-filled accumulator with that layout.
_, metric_shapes = jax.eval_shape(train_step_fsdp_fn, state_fsdp, None, batch)
metrics_fsdp = jax.tree.map(
    lambda spec: jnp.zeros(spec.shape, dtype=spec.dtype), metric_shapes
)
# %%
# train
# JAX dispatches asynchronously, so block on the outputs before reading the
# clock; otherwise only the dispatch time would be measured. The first call
# also compiles the step, so the timing includes compilation.
start_time = time.perf_counter()
for _ in range(15):
    state_fsdp, metrics_fsdp = train_step_fsdp_fn(
        state_fsdp,
        metrics_fsdp, batch)
jax.block_until_ready((state_fsdp, metrics_fsdp))
duration = time.perf_counter() - start_time
print(duration)
# One extra step from zeroed metrics, so the printed metrics cover that step only.
final_metrics_fsdp = jax.tree.map(
    lambda x: jnp.zeros(x.shape, dtype=x.dtype),
    metric_shapes)
state_fsdp, final_metrics_fsdp = train_step_fsdp_fn(
    state_fsdp,
    final_metrics_fsdp, batch)
print_metrics(final_metrics_fsdp, "FSDP - Final metrics")
# %%