757 lines
24 KiB
Python
757 lines
24 KiB
Python
|
|
# %% [markdown]
|
|
# # Fully-Sharded Data Parallelism
|
|
|
|
|
|
# MARK: START
|
|
# %%
|
|
# let's make an 8-device simulator (or configure XLA for real GPUs)
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True


def build_xla_flags(existing_flags: str, cpu_only: bool, num_cpu_devices: int = 8) -> str:
    """Return ``existing_flags`` extended with the XLA flags for the chosen backend.

    Every appended flag is prefixed with a single space, so it can never fuse
    with the last flag already present in ``existing_flags``.

    Args:
        existing_flags: Current value of the ``XLA_FLAGS`` environment variable
            (may be empty).
        cpu_only: If True, ask XLA to expose ``num_cpu_devices`` host devices;
            otherwise append GPU performance flags.
        num_cpu_devices: Number of simulated host devices (CPU mode only).

    Returns:
        The combined flag string.
    """
    if cpu_only:
        extra_flags = [f"--xla_force_host_platform_device_count={num_cpu_devices}"]
    else:
        extra_flags = [
            "--xla_gpu_enable_triton_softmax_fusion=true",
            "--xla_gpu_triton_gemm_any=false",
            "--xla_gpu_enable_async_collectives=true",
            "--xla_gpu_enable_latency_hiding_scheduler=true",
            "--xla_gpu_enable_highest_priority_async_stream=true",
        ]
    return existing_flags + "".join(f" {flag}" for flag in extra_flags)


flags = build_xla_flags(os.environ.get("XLA_FLAGS", ""), USE_CPU_ONLY)
if USE_CPU_ONLY:
    # Enforce CPU-only execution by hiding all CUDA devices.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Set before `jax` is imported below so the flags are in place when JAX starts.
os.environ["XLA_FLAGS"] = flags
|
|
|
|
import functools
|
|
from pprint import pprint
|
|
from typing import Any, Dict, Tuple, Callable, Sequence
|
|
|
|
import flax.linen as nn
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import numpy as np
|
|
from jax.experimental.shard_map import shard_map
|
|
from jax.sharding import Mesh, NamedSharding
|
|
from jax.sharding import PartitionSpec as P
|
|
from ml_collections import ConfigDict
|
|
import optax
|
|
import logging
|
|
import time
|
|
|
|
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

# The XLA flags set above only create simulated *host* devices, so pin JAX to
# the CPU backend in CPU-only mode; otherwise let JAX choose the accelerator.
if USE_CPU_ONLY:
    jax.config.update('jax_platform_name', 'cpu')
|
|
|
|
# %%
|
|
# required functions:
|
|
# Batch
|
|
# TrainState
|
|
# accumulate_gradients
|
|
# print_metrics
|
|
from single_gpu_optimizations import Batch, TrainState, accumulate_gradients, print_metrics
|
|
# %%
|
|
# define fold_rng_over_axis locally (it is not imported from another module)
|
|
|
|
def fold_rng_over_axis(rng: jax.random.PRNGKey, axis_name: str) -> jax.random.PRNGKey:
    """Derive a per-device RNG key by mixing in the device's index along an axis.

    Every device along ``axis_name`` receives a distinct key derived from the
    same ``rng`` (e.g. so dropout masks differ between data-parallel shards).
    Must be called where ``axis_name`` is bound, such as inside ``shard_map``.

    Args:
        rng: Key shared by all devices.
        axis_name: Name of the mapped axis whose index is folded into ``rng``.

    Returns:
        ``jax.random.fold_in(rng, index)``, where ``index`` is this device's
        position along ``axis_name``.
    """
    return jax.random.fold_in(rng, jax.lax.axis_index(axis_name))
|
|
|
|
# MARK: DATA PARALLELISM
|
|
# %% [markdown]
|
|
# # Data Parallelism
|
|
# we start with plain data parallelism
|
|
#
|
|
# using shard_map, we write single-device code and let shard map handle the rest
|
|
|
|
# %%
|
|
# plain data parallel - sharding only data inputs and outputs
class DPClassifier(nn.Module):
    """Two-layer MLP classifier used for plain data parallelism.

    The module contains no collective ops: under ``shard_map`` every device runs
    this single-device code on its own slice of the batch.

    Fields read from ``config``: ``hidden_size``, ``dropout_rate``, ``dtype``
    (compute dtype of the Dense layers) and ``num_classes``.
    """

    config: ConfigDict

    @nn.compact
    def __call__(self, x: jax.Array, train: bool) -> jax.Array:
        cfg = self.config
        hidden = nn.Dense(features=cfg.hidden_size, dtype=cfg.dtype, name="input_dense")(x)
        # Dropout is only active when train=True.
        hidden = nn.Dropout(rate=cfg.dropout_rate, deterministic=not train)(nn.silu(hidden))
        logits = nn.Dense(features=cfg.num_classes, dtype=cfg.dtype, name="output_dense")(hidden)
        # Hand back float32 logits regardless of the compute dtype.
        return logits.astype(jnp.float32)
|
|
|
|
# config
data_config = ConfigDict(
    {
        "batch_size": 128,
        "num_classes": 10,
        "input_size": 784,
    }
)
model_config = ConfigDict(
    {
        "hidden_size": 512,
        "dropout_rate": 0.1,
        "dtype": jnp.bfloat16,  # compute dtype of the Dense layers
        "num_classes": data_config.num_classes,
        "data_axis_name": "data",
    }
)
optimizer_config = ConfigDict(
    {
        "learning_rate": 1e-3,
        "num_minibatches": 4,  # passed to accumulate_gradients in the train steps
    }
)
config = ConfigDict(
    {
        "model": model_config,
        "optimizer": optimizer_config,
        "data": data_config,
        # mesh axis along which the batch (and later the parameters) is split
        "data_axis_name": model_config.data_axis_name,
        "seed": 42,
    }
)
|
|
|
|
# %%
|
|
# initialize the data-parallel model and the optimizer
model_dp = DPClassifier(config=config.model)
optimizer = optax.adamw(learning_rate=config.optimizer.learning_rate)

# One root key split three ways: model init, synthetic inputs, synthetic labels.
rng = jax.random.PRNGKey(config.seed)
model_init_rng, data_inputs_rng, data_labels_rng = jax.random.split(rng, 3)

# Synthetic batch: standard-normal inputs and integer labels in [0, num_classes).
batch = Batch(
    inputs=jax.random.normal(
        data_inputs_rng, (config.data.batch_size, config.data.input_size)
    ),
    labels=jax.random.randint(
        data_labels_rng,
        (config.data.batch_size,),
        minval=0,
        maxval=config.data.num_classes,
    ),
)
|
|
|
|
# init data_parallel TrainState state
def init_dp(rng: jax.random.PRNGKey, x: jax.Array, model: nn.Module) -> TrainState:
    """Initialise a ``TrainState`` for ``model`` from an example input.

    Args:
        rng: Key that is split into a parameter-init key and the key stored in
            the returned state.
        x: Example input used to trace the parameter shapes.
        model: Flax module to initialise.

    Returns:
        A ``TrainState`` holding the parameters, the module-level ``optimizer``
        and the remaining rng.
    """
    init_rng, rng = jax.random.split(rng)
    variables = model.init({"params": init_rng}, x, train=False)
    # Index rather than ``variables.pop("params")``: ``pop`` returns the value on
    # a plain dict but a ``(rest, value)`` tuple on a flax FrozenDict, while
    # indexing works for whichever container ``model.init`` returns.
    params = variables["params"]
    return TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer,
        rng=rng,
    )
|
|
|
|
# create mesh: a 1-D mesh over all available devices, named by the data axis
device_array = np.array(jax.devices())
mesh = Mesh(devices=device_array, axis_names=(config.data_axis_name,))

# Plain DP keeps a full copy of the model on every device (like flax's
# `replicate`): the rng is replicated (P()), the example inputs are split along
# their leading axis, and the returned state is declared replicated
# (out_specs=P()). check_rep=False disables shard_map's replication check.
init_dp_fn = jax.jit(
    shard_map(
        functools.partial(init_dp, model=model_dp),
        mesh=mesh,
        in_specs=(P(), P(config.data_axis_name)),
        out_specs=P(),
        check_rep=False,
    ),
)

state_dp = init_dp_fn(model_init_rng, batch.inputs)
print("DP Parameters")
pprint(jax.tree.map(lambda leaf: (leaf.shape, leaf.sharding), state_dp.params))
|
|
|
|
# MARK: TRAIN STEP
|
|
# %%
|
|
# train step
def loss_fn(
    params: PyTree, apply_fn: Any, batch: Batch, rng: jax.Array
) -> Tuple[jax.Array, Dict[str, Any]]:
    """Mean cross-entropy on the local batch, plus summed metrics.

    Runs per device inside shard_map. The dropout key is folded with the
    device's index on ``config.data_axis_name`` so each device draws a
    different dropout mask.

    Returns:
        ``(mean_loss, metrics)``. ``metrics`` maps ``"loss"`` and ``"accuracy"``
        to ``(sum over local examples, local batch size)`` pairs.
    """
    logits = apply_fn(
        {"params": params},
        batch.inputs,
        train=True,
        rngs={"dropout": fold_rng_over_axis(rng, config.data_axis_name)},
    )
    per_example_loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch.labels)
    is_correct = jnp.argmax(logits, axis=-1) == batch.labels
    local_batch_size = batch.inputs.shape[0]
    metrics = {
        "loss": (per_example_loss.sum(), local_batch_size),
        "accuracy": (is_correct.sum(), local_batch_size),
    }
    return per_example_loss.mean(), metrics
|
|
|
|
# train step dp
# plain data parallelism: every device holds the full model but a different
# slice of the batch
def train_step_dp(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
) -> Tuple[TrainState, Metrics]:
    """One data-parallel training step, run per device inside shard_map.

    Args:
        state: Replicated train state.
        metrics: Running metric accumulator, or ``None`` to start a new one.
        batch: This device's slice of the batch.

    Returns:
        ``(new_state, metrics)``: the updated state and the accumulator with
        this step's device-summed metrics added.
    """
    next_rng, step_rng = jax.random.split(state.rng)
    grads, step_metrics = accumulate_gradients(
        state, batch, step_rng, config.optimizer.num_minibatches, loss_fn=loss_fn
    )

    # Each device saw different data, so average the gradients before updating.
    with jax.named_scope("sync_gradients"):
        grads = jax.tree.map(
            functools.partial(jax.lax.pmean, axis_name=config.data_axis_name), grads
        )
    new_state = state.apply_gradients(grads=grads, rng=next_rng)

    # Sum metrics over devices now; they could instead be kept per device and
    # only synchronised before logging.
    with jax.named_scope("sync_metrics"):
        step_metrics = jax.tree.map(
            lambda leaf: jax.lax.psum(leaf, axis_name=config.data_axis_name),
            step_metrics,
        )

    merged = step_metrics if metrics is None else jax.tree.map(jnp.add, metrics, step_metrics)
    return new_state, merged
|
|
|
|
# %%
|
|
# Wrap the per-device train step with shard_map and jit it: state and metrics
# are replicated (P()), and the batch is split along its leading axis.
train_step_dp_fn = jax.jit(
    shard_map(
        train_step_dp,
        mesh=mesh,
        in_specs=(P(), P(), P(config.data_axis_name)),
        out_specs=(P(), P()),
        check_rep=False,
    ),
    # The caller replaces state and metrics with the returned values, so their
    # input buffers are donated and XLA may reuse them for the outputs.
    donate_argnames=("state", "metrics"),
)
|
|
|
|
# %%
|
|
# Trace one step abstractly with jax.eval_shape (nothing is executed) to learn
# the metric shapes/dtypes, then build a zero-filled accumulator of that shape.
_, metric_shapes = jax.eval_shape(train_step_dp_fn, state_dp, None, batch)
metrics_dp = jax.tree.map(
    lambda spec: jnp.zeros(spec.shape, dtype=spec.dtype), metric_shapes
)
|
|
|
|
|
|
# %%
|
|
start_time = time.perf_counter()
for _ in range(15):
    state_dp, metrics_dp = train_step_dp_fn(state_dp, metrics_dp, batch)
# JAX dispatches work asynchronously: wait for the last step to finish so the
# measured time covers the computation, not just the dispatch. (The first call
# also includes jit compilation.)
jax.block_until_ready((state_dp, metrics_dp))
duration = time.perf_counter() - start_time
print(duration)

# Run one extra step into a fresh zeroed accumulator so the printed metrics
# cover only that single step.
final_metrics_dp = jax.tree.map(
    lambda x: jnp.zeros(x.shape, dtype=x.dtype),
    metric_shapes)
state_dp, final_metrics_dp = train_step_dp_fn(
    state_dp,
    final_metrics_dp,
    batch)
print_metrics(final_metrics_dp)
|
|
|
|
# %%
|
|
# Show the global shape and sharding of every parameter and metric leaf.
print("DP Parameters")
pprint(jax.tree.map(lambda leaf: (leaf.shape, leaf.sharding), state_dp.params))
print("Metrics")
pprint(jax.tree.map(lambda leaf: (leaf.shape, leaf.sharding), final_metrics_dp))
|
|
|
|
####################################################################
|
|
# Everything up to this point works.
# So far this is equivalent to the Flax "replicate"-style data parallelism
# used in the Hugging Face Flax examples.
|
|
|
|
|
|
# MARK: PARAMETER SHARDING
|
|
# %% [markdown]
|
|
# # parameter sharding
|
|
# Basic strategy: init full parameters on each device, then use
|
|
# jax.lax.axis_index to split parameters across devices, and keep a shard on
|
|
# each device
|
|
#
|
|
# use nn.Partitioned to annotate sharding spec on parameters
|
|
# quite similar to PartitionSpec
|
|
#
|
|
# parameters are either jax.Array or a flax.linen.Partitioned
|
|
|
|
# %%
|
|
# type annotation: a parameter leaf is either a plain array or an
# nn.Partitioned box whose `names` record the mesh axis (if any) each
# dimension is split over
Parameter = nn.Partitioned | jax.Array
|
|
|
|
# %%
|
|
# Shard parameters across devices: every large-enough parameter is split evenly
# along one of its axes, and each device keeps only its own slice, wrapped in
# nn.Partitioned so the sharded axis is recorded. Small parameters stay
# replicated.
@jax.named_scope("shard_params")
def shard_params(params: PyTree, axis_name: str, min_weight_size: int = 2**18) -> PyTree:
    """Shard parameters across the given mesh axis.

    Must run where ``axis_name`` is bound (e.g. inside ``shard_map``), because
    it uses ``jax.lax.axis_index`` and ``jax.lax.psum`` on that axis.

    Args:
        params: The parameters to shard; leaves are arrays or ``nn.Partitioned``.
        axis_name: The axis to shard parameters across.
        min_weight_size: Parameters with at most this many values are not sharded.

    Returns:
        PyTree of same structure as params. Each shardable leaf is replaced by
        an ``nn.Partitioned`` holding this device's slice, with ``axis_name``
        recorded in ``names`` for the split dimension. An existing
        ``nn.Partitioned`` keeps its other fields.
    """
    axis_idx = jax.lax.axis_index(axis_name)
    # psum of a Python 1 over the axis gives the number of devices on it.
    axis_size = jax.lax.psum(1, axis_name)

    def _split(x: Parameter) -> Parameter:
        if isinstance(x, nn.Partitioned):
            value, names = x.value, x.names
        else:
            value = x
            names = (None,) * value.ndim

        # These Python-level logs fire while tracing, not on every executed step.
        if axis_name in names:
            # Already sharded on this very axis: nothing to do.
            logging.warning(
                f"Parameter {value.shape} with names {names} already sharded on axis {axis_name}."
            )
            return x
        if value.size <= min_weight_size:
            logging.info(
                f"Parameter {value.shape} with names {names} too small to shard, size {value.size} <= {min_weight_size}."
            )
            return x

        shape = value.shape
        # Try dimensions from largest to smallest; split along the first one that
        # divides evenly across devices and is not sharded on another axis yet.
        for i in map(int, np.argsort(shape)[::-1]):
            if shape[i] % axis_size == 0 and names[i] is None:
                split_size = shape[i] // axis_size
                # Keep only the slice that belongs to this device.
                local_slice = jax.lax.dynamic_slice_in_dim(
                    value, axis_idx * split_size, split_size, axis=i
                )
                new_names = names[:i] + (axis_name,) + names[i + 1 :]
                if isinstance(x, nn.Partitioned):
                    # Preserve the remaining metadata (e.g. mesh) of the box.
                    return x.replace(value=local_slice, names=new_names)
                return nn.Partitioned(value=local_slice, names=new_names)

        logging.warning(
            f"Could not shard {value.shape} with names {names} on axis {axis_name}, no suitable axis found."
        )
        return x

    # Treat nn.Partitioned boxes as leaves so they are handled as a whole.
    return jax.tree.map(
        _split,
        params,
        is_leaf=lambda x: isinstance(x, nn.Partitioned),
    )
|
|
|
|
# %%
|
|
# function to gather parameters back to a single device
|
|
|
|
# but first we need create a custom function for mean gradient computation
|
|
# jax.lax.all_gather -> retrieve shards and assemble full array in each device
|
|
# jax.lax.psum_scatter -> scatter gradients back to respective devices
|
|
def gather_array_with_mean_grads(x: jax.Array, axis: int, axis_name: str):
|
|
"""Gathering with averaging gradients across replicas."""
|
|
axis_size = jax.lax.psum(1, axis_name)
|
|
|
|
# Define a custom gradient for the gather operation.
|
|
@jax.custom_gradient
|
|
def f(x):
|
|
# adjust backward to turn sum into mean of axis
|
|
def grad_fn(g):
|
|
# pmean_scatter from psum_scatter
|
|
# after computing from full gradient array, our shard only has a
|
|
# portion of the parameters, we only get the gradients associated
|
|
# with parameters of our shard
|
|
return (
|
|
jax.lax.psum_scatter(g, axis_name, scatter_dimension=axis, tiled=True) / axis_size
|
|
)
|
|
|
|
# assemble shards to form full gradient array
|
|
return jax.lax.all_gather(x, axis_name, axis=axis, tiled=True), grad_fn
|
|
|
|
return f(x)
|
|
|
|
# gather params back - e.g. right before a module's forward call
# reverse operation of "shard_params"
# depends on: gather_array_with_mean_grads
@jax.named_scope("gather_params")
def gather_params(params: PyTree, axis_name: str) -> PyTree:
    """Gather parameters from all replicas across the given axis.

    Inverse of ``shard_params``. Each ``nn.Partitioned`` leaf sharded on
    ``axis_name`` is all-gathered back to full size via
    ``gather_array_with_mean_grads``, so its gradients are averaged across
    replicas. Must run where ``axis_name`` is bound (e.g. inside ``shard_map``).

    Args:
        params: The parameters to gather.
        axis_name: The axis to gather parameters across.

    Returns:
        PyTree of same structure as params. Gathered leaves become plain arrays
        unless they are still sharded over another axis. In that case they stay
        ``nn.Partitioned`` (keeping their other fields), with this axis cleared
        from ``names``. All other leaves are returned unchanged.
    """

    def _gather(p: Parameter) -> Parameter:
        if not (isinstance(p, nn.Partitioned) and axis_name in p.names):
            return p
        shard_axis = p.names.index(axis_name)
        value = gather_array_with_mean_grads(p.value, axis=shard_axis, axis_name=axis_name)
        remaining_names = p.names[:shard_axis] + (None,) + p.names[shard_axis + 1 :]
        if any(name is not None for name in remaining_names):
            # Still sharded over some other axis: keep the box (and its metadata).
            return p.replace(value=value, names=remaining_names)
        return value

    # Treat nn.Partitioned boxes as leaves so they are handled as a whole.
    return jax.tree.map(
        _gather,
        params,
        is_leaf=lambda x: isinstance(x, nn.Partitioned))
|
|
|
|
# %%
|
|
# Wrap a module with nn.map_variables so its "params" collection is stored
# sharded, but the module itself always sees full (gathered) parameters.
# depends on: gather_params, shard_params
def shard_module_params(
    target: nn.Module | Callable,
    axis_name: str,
    min_weight_size: int = 2**18,  # 262,144
) -> nn.Module | Callable:
    """Shard parameters of a module across replicas.

    ``gather_params`` is installed as the read transform and ``shard_params``
    as the write transform on the ``"params"`` collection.

    Args:
        target: The module (class) or callable to wrap.
        axis_name: The axis name to shard parameters across.
        min_weight_size: Parameters with at most this many values are left unsharded.

    Returns:
        The wrapped module with sharded parameters.
    """
    gather_fn = functools.partial(gather_params, axis_name=axis_name)
    shard_fn = functools.partial(
        shard_params, axis_name=axis_name, min_weight_size=min_weight_size
    )
    return nn.map_variables(
        target,
        mapped_collections="params",
        trans_in_fn=gather_fn,
        trans_out_fn=shard_fn,
        mutable=True,
    )
|
|
|
|
# %%
|
|
# FSDP variant of the classifier: same architecture, but both Dense layers keep
# their parameters sharded over the data axis. This is the template for
# sharding other modules.
class FSDPClassifier(nn.Module):
    """Two-layer MLP whose Dense parameters are sharded over ``config.data_axis_name``.

    Fields read from ``config``: ``hidden_size``, ``dropout_rate``, ``dtype``,
    ``num_classes``, ``data_axis_name`` and ``min_weight_size``.
    """

    config: ConfigDict

    @nn.compact
    def __call__(self, x: jax.Array, train: bool) -> jax.Array:
        cfg = self.config
        # Dense variant wrapped with shard_module_params (gather on read, shard on write).
        ShardedDense = shard_module_params(
            nn.Dense,
            axis_name=cfg.data_axis_name,
            min_weight_size=cfg.min_weight_size,
        )
        hidden = ShardedDense(features=cfg.hidden_size, dtype=cfg.dtype, name="input_dense")(x)
        # Dropout is only active when train=True.
        hidden = nn.Dropout(rate=cfg.dropout_rate, deterministic=not train)(nn.silu(hidden))
        logits = ShardedDense(features=cfg.num_classes, dtype=cfg.dtype, name="output_dense")(hidden)
        return logits.astype(jnp.float32)
|
|
|
|
# %%
|
|
# initialization
# Lower the sharding threshold far below the 2**18 default of
# shard_module_params so even this small MLP gets its parameters sharded:
# only parameters with more than 16 values are split.
config.model.min_weight_size = 2**4
model_fsdp = FSDPClassifier(config=config.model)
|
|
|
|
# the earlier init function, re-declared here so this section runs on its own
def init_dp(rng: jax.random.PRNGKey, x: jax.Array, model: nn.Module) -> TrainState:
    """Initialise a ``TrainState`` for ``model`` from an example input.

    Args:
        rng: Key that is split into a parameter-init key and the key stored in
            the returned state.
        x: Example input used to trace the parameter shapes.
        model: Flax module to initialise.

    Returns:
        A ``TrainState`` holding the parameters, the module-level ``optimizer``
        and the remaining rng.
    """
    init_rng, rng = jax.random.split(rng)
    variables = model.init({"params": init_rng}, x, train=False)
    # Index rather than ``variables.pop("params")``: ``pop`` returns the value on
    # a plain dict but a ``(rest, value)`` tuple on a flax FrozenDict, while
    # indexing works for whichever container ``model.init`` returns.
    params = variables["params"]
    return TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer,
        rng=rng,
    )
|
|
|
|
|
|
# initialize our sharded model with mesh
# Which parameters end up sharded is only decided inside model.init (by
# shard_params), so the output PartitionSpecs cannot be written by hand.
# Work around this in two steps:
#   1. trace the init abstractly with jax.eval_shape, provisionally declaring
#      every output replicated (out_specs=P());
#   2. convert the nn.Partitioned names in the abstract state into
#      PartitionSpecs with nn.get_partition_spec.
init_fsdp_fn = shard_map(
    functools.partial(init_dp, model=model_fsdp),
    mesh=mesh,
    # rng is replicated; the example inputs are split along the batch axis
    in_specs=(P(), P(config.data_axis_name)),
    # provisional: treat every output as replicated
    out_specs=P(),
    check_rep=False,  # skip shard_map's replication check on the outputs
)
state_fsdp_shapes = jax.eval_shape(init_fsdp_fn, model_init_rng, batch.inputs)
state_fsdp_specs = nn.get_partition_spec(state_fsdp_shapes)
|
|
# %% [raw]
|
|
# TrainState(step=PartitionSpec(), apply_fn=<bound method Module.apply of FSDPClassifier(
|
|
# # attributes
|
|
# config = data_axis_name: data
|
|
# dropout_rate: 0.1
|
|
# dtype: !!python/name:jax.numpy.bfloat16 ''
|
|
# hidden_size: 512
|
|
# min_weight_size: 16
|
|
# num_classes: 10
|
|
#
|
|
# )>,
|
|
# params={
|
|
# 'input_dense': {'bias': PartitionSpec('data',), 'kernel': PartitionSpec('data', None)},
|
|
# 'output_dense': {'bias': PartitionSpec(), 'kernel': PartitionSpec('data', None)}},
|
|
# tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x761e8ef00400>,
|
|
# update=<function chain.<locals>.update_fn at 0x761e8ef01080>),
|
|
# opt_state=(ScaleByAdamState(count=PartitionSpec(),
|
|
# mu={'input_dense': {'bias': PartitionSpec('data',), 'kernel': PartitionSpec('data', None)},
|
|
# 'output_dense': {'bias': PartitionSpec(), 'kernel': PartitionSpec('data', None)}},
|
|
# nu={'input_dense': {'bias': PartitionSpec('data',), 'kernel': PartitionSpec('data', None)},
|
|
# 'output_dense': {'bias': PartitionSpec(), 'kernel': PartitionSpec('data', None)}}),
|
|
# EmptyState(), EmptyState()), rng=PartitionSpec())
|
|
|
|
# %%
|
|
# state_fsdp_specs now tells us which leaves were sharded and along which axis.
# This is only decided during model init, which is why it could not be
# specified up front.
print("RNG", state_fsdp_specs.rng)
print("\nParameters")
pprint(state_fsdp_specs.params)
print("\nOptimizer state")
# first element of the optimizer-state tuple
pprint(state_fsdp_specs.opt_state[0])
|
|
|
|
# %%
|
|
# Initialise again, this time with the real output specs, so every sharded
# parameter (and its optimizer state) stays split across devices.
init_fsdp_fn = jax.jit(
    shard_map(
        functools.partial(init_dp, model=model_fsdp),
        mesh=mesh,
        in_specs=(P(), P(config.data_axis_name)),
        out_specs=state_fsdp_specs,
        check_rep=False,
    )
)
state_fsdp = init_fsdp_fn(model_init_rng, batch.inputs)
|
|
|
|
# %%
|
|
# Copy the parameters to host memory and show the global shape of each leaf.
print("FSDP Parameters")
host_shapes = jax.tree.map(lambda leaf: leaf.shape, jax.device_get(state_fsdp.params))
pprint(host_shapes)
|
|
|
|
# %%
|
|
# train step
# Gradients need different synchronisation depending on whether the parameter
# is sharded or replicated.
def sync_gradients(
    grads: PyTree,
    axis_names: Sequence[str],
) -> PyTree:
    """Synchronize gradients across devices.

    Gradients for parameters that are replicated over a given axis are averaged
    across devices. Parameters that are partitioned over a given axis are
    considered to already have a mean of the gradients on each device, and hence
    do not need to be altered.

    Args:
        grads: The gradients to synchronize.
        axis_names: The axis names to synchronize gradients across. A single
            axis name may also be passed as a plain string.

    Returns:
        The gradients averaged over the specified axes if they are replicated.
    """
    # A str is itself a Sequence[str]; treat it as one axis name instead of
    # iterating over its characters.
    if isinstance(axis_names, str):
        axis_names = (axis_names,)
    axis_names = tuple(axis_names)

    def sync_grad(g: Parameter) -> Parameter:
        if not isinstance(g, nn.Partitioned):
            # Parameters are replicated over all axes.
            return jax.lax.pmean(g, axis_name=axis_names)
        # Flatten names: one array dimension may be sharded over several axes.
        sharded_axes = jax.tree_util.tree_leaves(g.names)
        replication_axis_names = tuple(name for name in axis_names if name not in sharded_axes)
        if not replication_axis_names:
            # Parameters partitioned over all axes.
            return g
        # Average over the remaining replicated axes.
        return g.replace(value=jax.lax.pmean(g.value, axis_name=replication_axis_names))

    return jax.tree.map(
        sync_grad,
        grads,
        is_leaf=lambda x: isinstance(x, nn.Partitioned))
|
|
|
|
# %%
|
|
def train_step_fsdp(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
) -> Tuple[TrainState, Metrics]:
    """One FSDP training step, run per device inside shard_map.

    Args:
        state: Train state whose large parameters are ``nn.Partitioned`` shards.
        metrics: Running metric accumulator, or ``None`` to start a new one.
        batch: This device's slice of the batch.

    Returns:
        ``(new_state, metrics)``: the updated state and the accumulator with
        this step's device-summed metrics added.
    """
    rng, step_rng = jax.random.split(state.rng)
    # Compute gradients, accumulated over config.optimizer.num_minibatches
    # minibatches (see accumulate_gradients).
    grads, step_metrics = accumulate_gradients(
        state,
        batch,
        step_rng,
        config.optimizer.num_minibatches,
        loss_fn=loss_fn,
    )
    # Sync gradients before updating: sharded params already received averaged
    # gradients through gather_array_with_mean_grads; replicated ones are pmean'd.
    with jax.named_scope("sync_gradients"):
        grads = sync_gradients(grads, (config.data_axis_name,))
    new_state = state.apply_gradients(grads=grads, rng=rng)

    # Sum metrics across replicas. Alternatively, we could keep the metrics
    # separate and only synchronize them before logging.
    with jax.named_scope("sync_metrics"):
        step_metrics = jax.tree.map(
            lambda x: jax.lax.psum(x, axis_name=config.data_axis_name), step_metrics
        )
    if metrics is None:
        metrics = step_metrics
    else:
        metrics = jax.tree.map(jnp.add, metrics, step_metrics)
    return new_state, metrics
|
|
|
|
# %%
|
|
# jit the FSDP train step: the state uses the FSDP specs (sharded params and
# optimizer state), metrics are replicated, and the batch is split across
# devices.
train_step_fsdp_fn = jax.jit(
    shard_map(
        train_step_fsdp,
        mesh=mesh,
        in_specs=(state_fsdp_specs, P(), P(config.data_axis_name)),
        out_specs=(state_fsdp_specs, P()),
        check_rep=False,
    ),
    # old state and metrics are replaced by the outputs, so donate their buffers
    donate_argnames=("state", "metrics"),
)

# Trace one step abstractly to get the metric shapes/dtypes, then build a
# zero-filled accumulator with the same structure.
_, metric_shapes = jax.eval_shape(train_step_fsdp_fn, state_fsdp, None, batch)
metrics_fsdp = jax.tree.map(
    lambda spec: jnp.zeros(spec.shape, dtype=spec.dtype), metric_shapes
)
|
|
# %%
|
|
# train
start_time = time.perf_counter()
for _ in range(15):
    state_fsdp, metrics_fsdp = train_step_fsdp_fn(
        state_fsdp,
        metrics_fsdp,
        batch)
# JAX dispatches work asynchronously: wait for the last step to finish so the
# measured time covers the computation, not just the dispatch. (The first call
# also includes jit compilation.)
jax.block_until_ready((state_fsdp, metrics_fsdp))
duration = time.perf_counter() - start_time
print(duration)

# get metrics and state: run one extra step into a fresh zeroed accumulator so
# the printed metrics cover only that single step
final_metrics_fsdp = jax.tree.map(
    lambda x: jnp.zeros(x.shape, dtype=x.dtype),
    metric_shapes)
state_fsdp, final_metrics_fsdp = train_step_fsdp_fn(
    state_fsdp,
    final_metrics_fsdp,
    batch)
print_metrics(final_metrics_fsdp, "FSDP - Final metrics")
|
|
|
|
|
|
# %%
|