# %% [markdown]
# # Fully-Sharded Data Parallelism

# MARK: START
# %%
# let's make an 8-device simulator
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True

flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    flags += " --xla_force_host_platform_device_count=8"  # Simulate 8 devices
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
else:
    # GPU flags
    flags += (
        "--xla_gpu_enable_triton_softmax_fusion=true "
        "--xla_gpu_triton_gemm_any=false "
        "--xla_gpu_enable_async_collectives=true "
        "--xla_gpu_enable_latency_hiding_scheduler=true "
        "--xla_gpu_enable_highest_priority_async_stream=true "
    )
os.environ["XLA_FLAGS"] = flags

import functools
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from ml_collections import ConfigDict
import optax
import logging
import time

PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

if USE_CPU_ONLY:
    # force JAX onto the CPU platform when simulating devices
    jax.config.update('jax_platform_name', 'cpu')

# %%
# required functions:
# Batch
# TrainState
# accumulate_gradients
# print_metrics
from single_gpu_optimizations import Batch, TrainState, accumulate_gradients, print_metrics


# %%
# define the fold_rng_over_axis helper
def fold_rng_over_axis(rng: jax.random.PRNGKey, axis_name: str) -> jax.random.PRNGKey:
    """Folds the random number generator over the given axis.

    This is useful for generating a different random number for each device
    across a certain axis (e.g. the model axis).

    Args:
        rng: The random number generator.
        axis_name: The axis name to fold the random number generator over.

    Returns:
        A new random number generator, different for each device index along the axis.
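
    Example:
        Inside ``shard_map`` over a mesh axis named ``"data"``, every device holds the
        same ``rng`` but a different ``jax.lax.axis_index``, so folding the two together
        yields a distinct key per device. Below we use this to give each data shard its
        own dropout mask while keeping the parameters in sync.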
""" axis_index = jax.lax.axis_index(axis_name) return jax.random.fold_in(rng, axis_index) # MARK: DATA PARALLELISM # %% [markdown] # # Data Parallelism # we start with plain data parallelism # # using shard_map, we write single-device code and let shard map handle the rest # %% # plain data parallel - sharding only data inputs and outputs class DPClassifier(nn.Module): # contains the attributes listed in config # hidden_size # dropout_rate # dtype - for computation # num_classes # data_axis_name config: ConfigDict # note how there is no data_axis_name within the actual __call__ @nn.compact def __call__(self, x: jax.Array, train: bool) -> jax.Array: x = nn.Dense( features=self.config.hidden_size, dtype=self.config.dtype, name="input_dense", )(x) x = nn.silu(x) x = nn.Dropout(rate=self.config.dropout_rate, deterministic=not train)(x) x = nn.Dense( features=self.config.num_classes, dtype=self.config.dtype, name="output_dense", )(x) x = x.astype(jnp.float32) return x # config data_config = ConfigDict( dict( batch_size=128, num_classes=10, input_size=784, ) ) model_config = ConfigDict( dict( hidden_size=512, dropout_rate=0.1, dtype=jnp.bfloat16, num_classes=data_config.num_classes, data_axis_name="data", ) ) optimizer_config = ConfigDict( dict( learning_rate=1e-3, num_minibatches=4, ) ) config = ConfigDict( dict( model=model_config, optimizer=optimizer_config, data=data_config, data_axis_name=model_config.data_axis_name, seed=42, ) ) # %% # initialize model_dp = DPClassifier(config=config.model) optimizer = optax.adamw( learning_rate=config.optimizer.learning_rate, ) # init rng rng = jax.random.PRNGKey(config.seed) # init model rng model_init_rng, data_inputs_rng, data_labels_rng = jax.random.split(rng, 3) # create synthetic data batch = Batch( inputs=jax.random.normal(data_inputs_rng, (config.data.batch_size, config.data.input_size)), labels=jax.random.randint( data_labels_rng, (config.data.batch_size,), 0, config.data.num_classes ), ) # init data_parallel TrainState state def init_dp(rng: jax.random.PRNGKey, x: jax.Array, model: nn.Module) -> TrainState: init_rng, rng = jax.random.split(rng) variables = model.init({"params": init_rng}, x, train=False) params = variables.pop("params") state = TrainState.create( apply_fn=model.apply, params=params, tx=optimizer, rng=rng, ) return state # create mesh device_array = np.array(jax.devices()) mesh = Mesh(device_array, (config.data_axis_name,)) # we are just sharding the same model across devices # no different from a flax replicate init_dp_fn = jax.jit( shard_map( functools.partial(init_dp, model=model_dp), mesh, in_specs=(P(), P(config.data_axis_name)), out_specs=P(), check_rep=False, ), ) state_dp = init_dp_fn(model_init_rng, batch.inputs) print("DP Parameters") pprint(jax.tree.map(lambda x: (x.shape, x.sharding), state_dp.params)) # MARK: TRAIN STEP # %% # train step def loss_fn( params: PyTree, apply_fn: Any, batch: Batch, rng: jax.Array ) -> Tuple[jax.Array, Dict[str, Any]]: # set different rng over various devices dropout_rng = fold_rng_over_axis(rng, config.data_axis_name) # Remaining computation is the same as before for single device. 
    logits = apply_fn(
        {"params": params},
        batch.inputs,
        train=True,
        rngs={"dropout": dropout_rng})
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch.labels)
    correct_pred = jnp.equal(jnp.argmax(logits, axis=-1), batch.labels)
    batch_size = batch.inputs.shape[0]
    step_metrics = {"loss": (loss.sum(), batch_size), "accuracy": (correct_pred.sum(), batch_size)}
    loss = loss.mean()
    return loss, step_metrics


# train step dp
# simple data parallelism keeps the full model on every device,
# but each device sees a different slice of the data
def train_step_dp(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
) -> Tuple[TrainState, Metrics]:
    rng, step_rng = jax.random.split(state.rng)
    # accumulate gradients like before
    grads, step_metrics = accumulate_gradients(
        state,
        batch,
        step_rng,
        config.optimizer.num_minibatches,
        loss_fn=loss_fn,
    )
    # Update parameters. We need to sync the gradients across devices before updating.
    with jax.named_scope("sync_gradients"):
        grads = jax.tree.map(
            lambda g: jax.lax.pmean(
                g, axis_name=config.data_axis_name),
            grads)
    new_state = state.apply_gradients(grads=grads, rng=rng)
    # Sum metrics across replicas. Alternatively, we could keep the metrics separate
    # and only synchronize them before logging. For simplicity, we sum them here.
    with jax.named_scope("sync_metrics"):
        step_metrics = jax.tree.map(
            lambda x: jax.lax.psum(x, axis_name=config.data_axis_name), step_metrics
        )
    if metrics is None:
        metrics = step_metrics
    else:
        # combine all the synced metrics
        metrics = jax.tree.map(jnp.add, metrics, step_metrics)
    return new_state, metrics


# %%
# we will now wrap the train step with shard_map and jit it
# only the data batch is sharded across devices; state and metrics stay replicated
train_step_dp_fn = jax.jit(
    shard_map(
        train_step_dp,
        mesh,
        in_specs=(P(), P(), P(config.data_axis_name)),
        out_specs=(P(), P()),
        check_rep=False,
    ),
    # state and metrics are consumed and not re-used,
    # so we donate their buffers to the function
    donate_argnames=("state", "metrics"),
)

# %%
# get the metric_shapes so that we can init arrays for accumulation
_, metric_shapes = jax.eval_shape(
    train_step_dp_fn,
    state_dp,
    None,
    batch,
)
# init arrays with those shapes
metrics_dp = jax.tree.map(
    lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)

# %%
start_time = time.time()
for _ in range(15):
    state_dp, metrics_dp = train_step_dp_fn(state_dp, metrics_dp, batch)
duration = time.time() - start_time
print(duration)

final_metrics_dp = jax.tree.map(
    lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
state_dp, final_metrics_dp = train_step_dp_fn(
    state_dp, final_metrics_dp, batch)
print_metrics(final_metrics_dp, "DP - Final metrics")

# %%
print("DP Parameters")
pprint(jax.tree.map(lambda x: (x.shape, x.sharding), state_dp.params))
print("Metrics")
pprint(jax.tree.map(lambda x: (x.shape, x.sharding), final_metrics_dp))

####################################################################
# everything up to this point is still the replicate-style data
# parallelism familiar from the HuggingFace Flax examples

# MARK: PARAMETER SHARDING
# %% [markdown]
# # Parameter Sharding
# Basic strategy: init full parameters on each device, then use
# jax.lax.axis_index to split parameters across devices, and keep a shard on
# each device.
#
# Use nn.Partitioned to annotate the sharding spec on parameters,
# quite similar to PartitionSpec.
#
# Parameters are either a jax.Array or a flax.linen.Partitioned.

# %%
# type annotation
Parameter = jax.Array | nn.Partitioned


# %%
# function to shard parameters across devices
# look for an axis that splits equally across the number of devices
# we can specify which parameters to shard,
# since they vary in size
# we set a floor on the size below which a parameter is not sharded
@jax.named_scope("shard_params")
def shard_params(params: PyTree, axis_name: str, min_weight_size: int = 2**18) -> PyTree:
    """Shard parameters across the given mesh axis.

    Args:
        params: The parameters to shard.
        axis_name: The axis to shard parameters across.
        min_weight_size: The minimum size of a parameter to shard. Parameters with fewer values will not be sharded.

    Returns:
        PyTree of same structure as params, but with leaves sharded over new axis if possible.
    """
    # device index along the axis
    axis_idx = jax.lax.axis_index(axis_name)
    # number of devices along the axis
    axis_size = jax.lax.psum(1, axis_name)

    # split function
    # check whether each parameter has already been sharded
    def _split(x: Parameter) -> Parameter:
        # already annotated
        if isinstance(x, nn.Partitioned):
            value, names = x.value, x.names
        # not annotated
        else:
            value = x
            names = (None,) * value.ndim
        # logging only runs on the first jit trace
        # the branches below record why a parameter is not sharded here
        # case 1: already sharded on this axis
        if axis_name in names:
            logging.warning(
                f"Parameter {value.shape} with names {names} already sharded on axis {axis_name}."
            )
            return x
        # case 2: parameter too small to be worth sharding
        elif value.size <= min_weight_size:
            logging.info(
                f"Parameter {value.shape} with names {names} too small to shard, size {value.size} < {min_weight_size}."
            )
            return x
        # let's start sharding!
        else:
            shape = value.shape
            idx = np.argsort(shape)[::-1]  # Shard along largest possible axis.
            for i in idx:
                # this loop body runs at most once because of the return
                # we only shard if the axis splits evenly across devices
                # and if it isn't already sharded
                if shape[i] % axis_size == 0 and names[i] is None:
                    split_size = shape[i] // axis_size
                    p_sharded = nn.Partitioned(
                        value=jax.lax.dynamic_slice_in_dim(  # Shard to keep on present device.
                            value, axis_idx * split_size, split_size, axis=i
                        ),
                        names=names[:i] + (axis_name,) + names[i + 1 :],
                    )
                    return p_sharded
            logging.warning(
                f"Could not shard {value.shape} with names {names} on axis {axis_name}, no suitable axis found."
            )
            return x

    # we apply the _split function across the parameter pytree
    return jax.tree.map(
        _split,
        params,
        is_leaf=lambda x: isinstance(
            x, nn.Partitioned
        ),  # Consider a nn.Partitioned object as a leaf.
    )


# %%
# next, a function to gather the parameters back to their full size
# but first we need a custom function so that gradients are averaged, not summed
# jax.lax.all_gather -> retrieve shards and assemble the full array on each device
# jax.lax.psum_scatter -> scatter (reduce) gradients back to their respective devices
def gather_array_with_mean_grads(x: jax.Array, axis: int, axis_name: str):
    """Gathering with averaging gradients across replicas."""
    axis_size = jax.lax.psum(1, axis_name)

    # Define a custom gradient for the gather operation.
    @jax.custom_gradient
    def f(x):
        # adjust the backward pass to turn the sum over the axis into a mean
        def grad_fn(g):
            # a "pmean_scatter" built from psum_scatter:
            # the full gradient is reduced across devices, and each device keeps
            # only the slice that corresponds to its own parameter shard
            return (
                jax.lax.psum_scatter(g, axis_name, scatter_dimension=axis, tiled=True) / axis_size
            )

        # assemble shards to form the full gradient array
        return jax.lax.all_gather(x, axis_name, axis=axis, tiled=True), grad_fn

    return f(x)


# gather params back - e.g.
# when computing a module's forward call
# reverse operation of "shard_params"
# depends on: gather_array_with_mean_grads
@jax.named_scope("gather_params")
def gather_params(params: PyTree, axis_name: str) -> PyTree:
    """Gather parameters from all replicas across the given axis.

    Args:
        params: The parameters to gather.
        axis_name: The axis to gather parameters across.

    Returns:
        PyTree of same structure as params, but with leaves gathered if they were a nn.Partitioned object.
    """

    def _gather(p: Parameter) -> Parameter:
        if isinstance(p, nn.Partitioned) and axis_name in p.names:
            param_shard = p.names
            shard_axis = param_shard.index(axis_name)
            value = gather_array_with_mean_grads(p.value, axis=shard_axis, axis_name=axis_name)
            # If there are any other axes that are sharded, we need to keep the partitioned structure.
            # Otherwise, we can return the value directly.
            param_shard = param_shard[:shard_axis] + (None,) + param_shard[shard_axis + 1 :]
            if any([name is not None for name in param_shard]):
                # keep the annotation for the still-sharded axes
                return nn.Partitioned(value, param_shard)
            else:
                return value
        else:
            return p

    # we find all the sharded params and gather them, returning complete parameters
    return jax.tree.map(
        _gather, params, is_leaf=lambda x: isinstance(x, nn.Partitioned))


# %%
# when we call a module, we gather the full parameters on every device,
# and shard them again afterwards
# we wrap the module with the nn.map_variables transform, which allows
# transforms on the parameters before and after each module call
# depends on: gather_params, shard_params
def shard_module_params(
    target: nn.Module | Callable,
    axis_name: str,
    min_weight_size: int = 2**18  # 262,144
) -> nn.Module | Callable:
    """Shard parameters of a module across replicas.

    Args:
        target: The module to shard.
        axis_name: The axis name to shard parameters across.
        min_weight_size: The minimum size of a parameter to shard. Parameters with fewer values will not be sharded.

    Returns:
        The module with sharded parameters.
""" return nn.map_variables( target, trans_in_fn=functools.partial( gather_params, axis_name=axis_name), trans_out_fn=functools.partial( shard_params, axis_name=axis_name, min_weight_size=min_weight_size ), mapped_collections="params", mutable=True, ) # %% # define new function with axes constraints # this forms the template for sharding future modules # remember, flax modules are subclassed from elementary flax modules class FSDPClassifier(nn.Module): config: ConfigDict @nn.compact def __call__(self, x: jax.Array, train: bool) -> jax.Array: # create a sharded module sharded_dense = shard_module_params( nn.Dense, axis_name=self.config.data_axis_name, # axes min_weight_size=self.config.min_weight_size, # min_weight ) x = sharded_dense( features=self.config.hidden_size, dtype=self.config.dtype, name="input_dense", )(x) x = nn.silu(x) x = nn.Dropout(rate=self.config.dropout_rate, deterministic=not train)(x) x = sharded_dense( features=self.config.num_classes, dtype=self.config.dtype, name="output_dense", )(x) x = x.astype(jnp.float32) return x # %% # initialization config.model.min_weight_size = 2**4 model_fsdp = FSDPClassifier(config=config.model) # the earlier init function def init_dp(rng: jax.random.PRNGKey, x: jax.Array, model: nn.Module) -> TrainState: init_rng, rng = jax.random.split(rng) variables = model.init({"params": init_rng}, x, train=False) params = variables.pop("params") state = TrainState.create( apply_fn=model.apply, params=params, tx=optimizer, rng=rng, ) return state # initialize our sharded model with mesh # we need to adjust the shard map since partitioning is determined within the # model init, hence we cannot manually specify it # # we do a hack where we just try and let it evaluate the shapes # we set an unknown output specification - aka fully replicate # # we then get the partition_spec of the shapes of the parameters init_fsdp_fn = shard_map( functools.partial(init_dp, model=model_fsdp), mesh, # first P() is for model_init_rng # second P(config.data_axis_name) is for batch.inputs in_specs=(P(), P(config.data_axis_name)), # not partitioned, fully replicated out_specs=P(), check_rep=False, # disable checks for replication errors in out_specs ) state_fsdp_shapes = jax.eval_shape(init_fsdp_fn, model_init_rng, batch.inputs) state_fsdp_specs = nn.get_partition_spec(state_fsdp_shapes) # %% [raw] # TrainState(step=PartitionSpec(), apply_fn=, # params={ # 'input_dense': {'bias': PartitionSpec('data',), 'kernel': PartitionSpec('data', None)}, # 'output_dense': {'bias': PartitionSpec(), 'kernel': PartitionSpec('data', None)}}, # tx=GradientTransformationExtraArgs(init=.init_fn at 0x761e8ef00400>, # update=.update_fn at 0x761e8ef01080>), # opt_state=(ScaleByAdamState(count=PartitionSpec(), # mu={'input_dense': {'bias': PartitionSpec('data',), 'kernel': PartitionSpec('data', None)}, # 'output_dense': {'bias': PartitionSpec(), 'kernel': PartitionSpec('data', None)}}, # nu={'input_dense': {'bias': PartitionSpec('data',), 'kernel': PartitionSpec('data', None)}, # 'output_dense': {'bias': PartitionSpec(), 'kernel': PartitionSpec('data', None)}}), # EmptyState(), EmptyState()), rng=PartitionSpec()) # %% # then from the state_fsdp_specs, we obtain our config # this print clarifies everything -> the reason why earlier we do not know the # partitionspec is because we only know which parameters gets to be sharded at # model init print("RNG", state_fsdp_specs.rng) print("\nParameters") pprint(state_fsdp_specs.params) print("\nOptimizer state") pprint(state_fsdp_specs.opt_state[0]) # %% 
# init again, this time with the specs and the knowledge of what is
# and what is not sharded
init_fsdp_fn = jax.jit(
    shard_map(
        functools.partial(init_dp, model=model_fsdp),
        mesh,
        in_specs=(P(), P(config.data_axis_name)),
        out_specs=state_fsdp_specs,
        check_rep=False,
    )
)
state_fsdp = init_fsdp_fn(model_init_rng, batch.inputs)

# %%
print("FSDP Parameters")
pprint(jax.tree.map(lambda x: x.shape, jax.device_get(state_fsdp.params)))


# %%
# train step
# we need to handle the sync of gradients:
# some parameters are sharded, some are not
def sync_gradients(
    grads: PyTree,
    axis_names: Sequence[str],
) -> PyTree:
    """Synchronize gradients across devices.

    Gradients for parameters that are replicated over a given axis are averaged across devices.
    Parameters that are partitioned over a given axis are considered to already have a mean of
    the gradients on each device, and hence do not need to be altered.

    Args:
        grads: The gradients to synchronize.
        axis_names: The axis names to synchronize gradients across.

    Returns:
        The gradients averaged over the specified axes if they are replicated.
    """

    def sync_grad(g: Parameter) -> Parameter:
        if isinstance(g, nn.Partitioned):
            # Tree leaves for flattening potentially nested axis (multiple names
            # can exist for single array axis).
            replication_axis_names = [
                name for name in axis_names if name not in jax.tree_util.tree_leaves(g.names)
            ]
            if len(replication_axis_names) == 0:
                # Parameters partitioned over all axes.
                return g
            else:
                # Average over remaining replicated axes.
                return g.replace(value=jax.lax.pmean(g.value, axis_name=replication_axis_names))
        else:
            # Parameters are replicated over all axes.
            return jax.lax.pmean(g, axis_name=axis_names)

    return jax.tree.map(
        sync_grad, grads, is_leaf=lambda x: isinstance(x, nn.Partitioned))


# %%
def train_step_fsdp(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
) -> Tuple[TrainState, Metrics]:
    rng, step_rng = jax.random.split(state.rng)
    # accumulate gradients over minibatches, as before
    grads, step_metrics = accumulate_gradients(
        state,
        batch,
        step_rng,
        config.optimizer.num_minibatches,
        loss_fn=loss_fn,
    )
    # Update parameters. We need to sync the gradients across devices before updating.
    with jax.named_scope("sync_gradients"):
        grads = sync_gradients(grads, (config.data_axis_name,))
    # then update the model
    new_state = state.apply_gradients(grads=grads, rng=rng)
    # Sum metrics across replicas. Alternatively, we could keep the metrics separate
    # and only synchronize them before logging. For simplicity, we sum them here.
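    # Note: each metric is a (sum, count) pair, so psum-ing both entries keeps
    # the global averages exact when they are later computed as sum / count.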
with jax.named_scope("sync_metrics"): step_metrics = jax.tree.map( lambda x: jax.lax.psum(x, axis_name=config.data_axis_name), step_metrics ) if metrics is None: metrics = step_metrics else: metrics = jax.tree.map(jnp.add, metrics, step_metrics) return new_state, metrics # %% # jit the train_step_fsdp train_step_fsdp_fn = jax.jit( shard_map( train_step_fsdp, mesh, in_specs=(state_fsdp_specs, P(), P(config.data_axis_name)), out_specs=(state_fsdp_specs, P()), check_rep=False, ), donate_argnames=("state", "metrics"), ) # get the metric shape to initialize accumulator arrays for metrics _, metric_shapes = jax.eval_shape( train_step_fsdp_fn, state_fsdp, None, batch, ) metrics_fsdp = jax.tree.map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes) # %% # train start_time = time.time() for _ in range(15): state_fsdp, metrics_fsdp = train_step_fsdp_fn( state_fsdp, metrics_fsdp, batch) duration = time.time() - start_time print(duration) # get metrics and state final_metrics_fsdp = jax.tree.map( lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes) state_fsdp, final_metrics_fsdp = train_step_fsdp_fn( state_fsdp, final_metrics_fsdp, batch) print_metrics(final_metrics_fsdp, "FSDP - Final metrics") # %%