diff --git a/parallel/.gitignore b/parallel/.gitignore
new file mode 100644
index 0000000..bee8a64
--- /dev/null
+++ b/parallel/.gitignore
@@ -0,0 +1 @@
+__pycache__
diff --git a/parallel/dataload.py b/parallel/dataload.py
new file mode 100644
index 0000000..eea2570
--- /dev/null
+++ b/parallel/dataload.py
@@ -0,0 +1,172 @@
+# %%
+# Prepare dataloader for jax training
+from datasets import Dataset, DatasetDict, Value, Sequence, load_from_disk
+from transformers import FlaxT5ForConditionalGeneration
+from datasets import ClassLabel, Value, Sequence
+from ml_collections import ConfigDict
+import numpy as np
+import jax.numpy as jnp
+import jax
+import math
+from typing import Optional, List, Tuple, Callable, cast
+
+
+file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
+# file_path = 'combined_data'
+# split_datasets = load_from_disk(file_path)
+# training_size = len(split_datasets['train'])
+
+from transformers import T5TokenizerFast
+tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
+# Define additional special tokens
+additional_special_tokens = ["", "", "", "", "", "", "SIG", "UNIT", "DATA_TYPE"]
+# Add the additional special tokens to the tokenizer
+tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
+
+model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
+
+model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
+shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")  # noqa: B009
+
+
+# class takes in a dataset
+class DataPrepare():
+
+    def __init__(self, raw_dataset, config):
+        self.raw_dataset: Dataset = raw_dataset
+        self.train_dataset: Optional[Dataset] = None
+        self.size: int = len(raw_dataset)
+        self.config: ConfigDict = config
+
+        self.make_dataset()
+
+    # In Flax, for seq2seq models we need to pass `decoder_input_ids`.
+    # As the Flax models don't accept `labels`, we prepare the decoder_input_ids here;
+    # for that we dynamically import the `shift_tokens_right` function from the model file.
+
+    # given a dataset entry, run it through the tokenizer
+    # Setting padding="max_length" as we need fixed length inputs for jitted functions
+    def preprocess_function(self, example: Dataset):
+        inputs = example['input']
+        targets = example['output']
+        # text_target sets the corresponding label to inputs
+        # there is no need to create a separate 'labels'
+        # produce input_ids and decoder_input_ids
+        model_inputs = tokenizer(
+            inputs,
+            max_length=self.config.max_length,
+            padding="max_length",
+            truncation=True,
+            return_tensors="np"
+        )
+        labels = tokenizer(
+            text_target=targets,
+            max_length=self.config.max_length,
+            padding="max_length",
+            truncation=True,
+            return_tensors="np"
+        )
+
+        # for loss computation
+        model_inputs["labels"] = labels["input_ids"]
+        # make decoder input ids
+        decoder_input_ids = shift_tokens_right_fn(
+            labels["input_ids"], self.config.pad_token_id, self.config.decoder_start_token_id
+        )
+        # required by model
+        model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
+        # We need decoder_attention_mask so we can ignore pad tokens from loss
+        model_inputs["decoder_attention_mask"] = labels["attention_mask"]
+
+        return model_inputs
+
+    def make_dataset(self):
+        train_dataset = self.raw_dataset.map(
+            self.preprocess_function,
+            batched=True,
+            num_proc=1,
+            # if we do not remove, we keep the original data
+            remove_columns=self.raw_dataset.column_names,)
+
+        # set to numpy
+        train_dataset.set_format(
+            type='numpy',
+            columns=[
+                'input_ids', 'attention_mask',
+                'labels', 'decoder_input_ids',
+                'decoder_attention_mask']
+        )
+
+        # check that data fits
+        for name in ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'decoder_attention_mask']:
+            int_array: np.ndarray = train_dataset[name]
+            if np.all((int_array >= 0) & (int_array <= 65535)):
+                continue
+            else:
+                raise ValueError("Values are out of range for uint16")
+
+        # change to compact datatypes
+        features = train_dataset.features.copy()
+        features['input_ids'] = Sequence(Value('uint16'))
+        features['attention_mask'] = Sequence(Value('bool'))
+        features['labels'] = Sequence(Value('uint16'))
+        features['decoder_input_ids'] = Sequence(Value('uint16'))
+        features['decoder_attention_mask'] = Sequence(Value('bool'))
+        train_dataset = train_dataset.cast(features)
+        # assign the dataset to train_dataset
+        self.train_dataset = train_dataset
+
+    def data_loader(self, rng: jax.random.PRNGKey, batch_size: int, shuffle: bool = False, drop_last=True):
+        """
+        Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
+        and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
+        """
+        assert self.train_dataset is not None
+        dataset: Dataset = cast(Dataset, self.train_dataset)
+
+        if shuffle:
+            batch_idx = jax.random.permutation(rng, len(dataset))
+            batch_idx = np.asarray(batch_idx)
+        else:
+            batch_idx = np.arange(len(dataset))
+
+        if drop_last:
+            steps_per_epoch = len(dataset) // batch_size
+            batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
+            batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
+        else:
+            steps_per_epoch = math.ceil(len(dataset) / batch_size)
+            batch_idx = np.array_split(batch_idx, steps_per_epoch)
+
+        for idx in batch_idx:
+            batch = dataset[idx]
+            batch = {k: jnp.array(v) for k, v in batch.items()}
+
+            yield batch
+
+
+# testing out the class
+# # %%
+# # init object
+# # e.g. Config
+# data_config = ConfigDict(
+#     dict(
+#         max_length=86,
+#         pad_token_id=0,
+#         decoder_start_token_id=0
+#     )
+# )
+#
+# dataprep = DataPrepare(split_datasets, data_config)
+#
+# # %%
+# seed = 117
+# rng = jax.random.PRNGKey(seed)
+# train_loader = dataprep.data_loader(rng, batch_size=32)
+#
+#
+#
+# # %%
+# batch = next(iter(train_loader))
+# batch['input_ids'].shape
+# # %%
diff --git a/parallel/fully_sharded_data_parallelism.py b/parallel/fully_sharded_data_parallelism.py
new file mode 100644
index 0000000..098a389
--- /dev/null
+++ b/parallel/fully_sharded_data_parallelism.py
@@ -0,0 +1,754 @@
+
+# %% [markdown]
+# # Fully-Sharded Data Parallelism
+
+
+# MARK: START
+# %%
+# let's make 8-device simulator
+import os
+
+# Set this to True to run the model on CPU only.
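+# NOTE: the flags assembled below are written to os.environ before `import jax`
+# further down, so that the CPU backend can expose 8 simulated devices for
+# exercising the sharding code on a single machine.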
+USE_CPU_ONLY = True + +flags = os.environ.get("XLA_FLAGS", "") +if USE_CPU_ONLY: + flags += " --xla_force_host_platform_device_count=8" # Simulate 8 devices + # Enforce CPU-only execution + os.environ["CUDA_VISIBLE_DEVICES"] = "" +else: + # GPU flags + flags += ( + "--xla_gpu_enable_triton_softmax_fusion=true " + "--xla_gpu_triton_gemm_any=false " + "--xla_gpu_enable_async_collectives=true " + "--xla_gpu_enable_latency_hiding_scheduler=true " + "--xla_gpu_enable_highest_priority_async_stream=true " + ) +os.environ["XLA_FLAGS"] = flags + +import functools +from pprint import pprint +from typing import Any, Dict, Tuple, Callable, Sequence + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P +from ml_collections import ConfigDict +import optax +import logging +import time + +PyTree = Any +Metrics = Dict[str, Tuple[jax.Array, ...]] +jax.config.update('jax_platform_name', 'cpu') + +# %% +# required functions: +# Batch +# TrainState +# accumulate_gradients +# print_metrics +from single_gpu_optimizations import Batch, TrainState, accumulate_gradients, print_metrics +# %% +# import the fold_rng_over_axis + +def fold_rng_over_axis(rng: jax.random.PRNGKey, axis_name: str) -> jax.random.PRNGKey: + """Folds the random number generator over the given axis. + + This is useful for generating a different random number for each device + across a certain axis (e.g. the model axis). + + Args: + rng: The random number generator. + axis_name: The axis name to fold the random number generator over. + + Returns: + A new random number generator, different for each device index along the axis. + """ + axis_index = jax.lax.axis_index(axis_name) + return jax.random.fold_in(rng, axis_index) + +# MARK: DATA PARALLELISM +# %% [markdown] +# # Data Parallelism +# we start with plain data parallelism +# +# using shard_map, we write single-device code and let shard map handle the rest + +# %% +# plain data parallel - sharding only data inputs and outputs +class DPClassifier(nn.Module): + # contains the attributes listed in config + # hidden_size + # dropout_rate + # dtype - for computation + # num_classes + # data_axis_name + config: ConfigDict + + # note how there is no data_axis_name within the actual __call__ + @nn.compact + def __call__(self, x: jax.Array, train: bool) -> jax.Array: + x = nn.Dense( + features=self.config.hidden_size, + dtype=self.config.dtype, + name="input_dense", + )(x) + x = nn.silu(x) + x = nn.Dropout(rate=self.config.dropout_rate, deterministic=not train)(x) + x = nn.Dense( + features=self.config.num_classes, + dtype=self.config.dtype, + name="output_dense", + )(x) + x = x.astype(jnp.float32) + return x + +# config +data_config = ConfigDict( + dict( + batch_size=128, + num_classes=10, + input_size=784, + ) +) +model_config = ConfigDict( + dict( + hidden_size=512, + dropout_rate=0.1, + dtype=jnp.bfloat16, + num_classes=data_config.num_classes, + data_axis_name="data", + ) +) +optimizer_config = ConfigDict( + dict( + learning_rate=1e-3, + num_minibatches=4, + ) +) +config = ConfigDict( + dict( + model=model_config, + optimizer=optimizer_config, + data=data_config, + data_axis_name=model_config.data_axis_name, + seed=42, + ) +) + +# %% +# initialize +model_dp = DPClassifier(config=config.model) +optimizer = optax.adamw( + learning_rate=config.optimizer.learning_rate, +) + +# init rng +rng = jax.random.PRNGKey(config.seed) +# init 
model rng +model_init_rng, data_inputs_rng, data_labels_rng = jax.random.split(rng, 3) +# create synthetic data +batch = Batch( + inputs=jax.random.normal(data_inputs_rng, (config.data.batch_size, config.data.input_size)), + labels=jax.random.randint( + data_labels_rng, (config.data.batch_size,), 0, config.data.num_classes + ), +) + +# init data_parallel TrainState state +def init_dp(rng: jax.random.PRNGKey, x: jax.Array, model: nn.Module) -> TrainState: + init_rng, rng = jax.random.split(rng) + variables = model.init({"params": init_rng}, x, train=False) + params = variables.pop("params") + state = TrainState.create( + apply_fn=model.apply, + params=params, + tx=optimizer, + rng=rng, + ) + return state + +# create mesh +device_array = np.array(jax.devices()) +mesh = Mesh(device_array, (config.data_axis_name,)) + +# we are just sharding the same model across devices +# no different from a flax replicate +init_dp_fn = jax.jit( + shard_map( + functools.partial(init_dp, model=model_dp), + mesh, + in_specs=(P(), P(config.data_axis_name)), + out_specs=P(), + check_rep=False, + ), +) + +state_dp = init_dp_fn(model_init_rng, batch.inputs) +print("DP Parameters") +pprint(jax.tree.map(lambda x: (x.shape, x.sharding), state_dp.params)) + +# MARK: TRAIN STEP +# %% +# train step +def loss_fn( + params: PyTree, apply_fn: Any, batch: Batch, rng: jax.Array +) -> Tuple[jax.Array, Dict[str, Any]]: + + # set different rng over various devices + dropout_rng = fold_rng_over_axis(rng, config.data_axis_name) + + # Remaining computation is the same as before for single device. + logits = apply_fn( + {"params": params}, + batch.inputs, + train=True, + rngs={"dropout": dropout_rng}) + loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch.labels) + correct_pred = jnp.equal(jnp.argmax(logits, axis=-1), batch.labels) + batch_size = batch.inputs.shape[0] + step_metrics = {"loss": (loss.sum(), batch_size), "accuracy": (correct_pred.sum(), batch_size)} + loss = loss.mean() + return loss, step_metrics + +# train step dp +# simple data parallel has the model on every device +# but each device has different data +def train_step_dp( + state: TrainState, + metrics: Metrics | None, + batch: Batch, +) -> Tuple[TrainState, Metrics]: + rng, step_rng = jax.random.split(state.rng) + # accumulate gradients like before + grads, step_metrics = accumulate_gradients( + state, + batch, + step_rng, + config.optimizer.num_minibatches, + loss_fn=loss_fn, + ) + # Update parameters. We need to sync the gradients across devices before updating. + with jax.named_scope("sync_gradients"): + grads = jax.tree.map( + lambda g: jax.lax.pmean( + g, axis_name=config.data_axis_name), + grads) + new_state = state.apply_gradients(grads=grads, rng=rng) + + # Sum metrics across replicas. Alternatively, we could keep the metrics separate + # and only synchronize them before logging. For simplicity, we sum them here. 
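+    # Each entry in step_metrics is a (sum, count) pair; psum over the data axis
+    # turns the per-device pairs into global totals, so print_metrics can later
+    # report sum / count without further cross-device communication.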
+ with jax.named_scope("sync_metrics"): + step_metrics = jax.tree.map( + lambda x: jax.lax.psum(x, axis_name=config.data_axis_name), step_metrics + ) + + if metrics is None: + metrics = step_metrics + else: + # combine all the synced metrics + metrics = jax.tree.map(jnp.add, metrics, step_metrics) + + return new_state, metrics + +# %% +# we will now wrap the train step with shard_map and jit it +# here we will be sharding input and output data +train_step_dp_fn = jax.jit( + shard_map( + train_step_dp, + mesh, + in_specs=(P(), P(), P(config.data_axis_name)), + out_specs=(P(), P()), + check_rep=False, + ), + # state and metrics change and won't be re-used + # pass by reference and throw away with function + donate_argnames=("state", "metrics"), +) + +# %% +# get the metric_shapes so that we can init arrays for accumulation +_, metric_shapes = jax.eval_shape( + train_step_dp_fn, + state_dp, + None, + batch, +) +# init arrays with shape +metrics_dp = jax.tree.map( + lambda x: jnp.zeros(x.shape, dtype=x.dtype), + metric_shapes) + + +# %% +start_time = time.time() +for _ in range(15): + state_dp, metrics_dp = train_step_dp_fn(state_dp, metrics_dp, batch) +duration = time.time() - start_time +print(duration) + +final_metrics_dp = jax.tree.map( + lambda x: jnp.zeros(x.shape, dtype=x.dtype), + metric_shapes) +state_dp, final_metrics_dp = train_step_dp_fn( + state_dp, + final_metrics_dp, + batch) +print_metrics(final_metrics_dp) + +# %% +print("DP Parameters") +pprint(jax.tree.map(lambda x: (x.shape, x.sharding), state_dp.params)) +print("Metrics") +pprint(jax.tree.map(lambda x: (x.shape, x.sharding), final_metrics_dp)) + +#################################################################### +# stuff works until here +# it is still same as flax replicate style in huggingface + + +# MARK: PARAMETER SHARDING +# %% [markdown] +# # parameter sharding +# Basic strategy: init full parameters on each device, then use +# jax.lax.axis_index to split parameters across devices, and keep a shard on +# each device +# +# use nn.Partitioned to annotate sharding spec on parameters +# quite similar to PartitionSpec +# +# parameters are either jax.Array or a flax.linen.Partitioned + +# %% +# type annotation +Parameter = jax.Array | nn.Partitioned + +# %% +# function to shard parameters across devices +# look for an axis to equally split across the number of devices +# we can specify which parameters to shard, since they vary in size +# we set a floor on the size for sharding +@jax.named_scope("shard_params") +def shard_params(params: PyTree, axis_name: str, min_weight_size: int = 2**18) -> PyTree: + """Shard parameters across the given mesh axis. + + Args: + params: The parameters to shard. + axis_name: The axis to shard parameters across. + min_weight_size: The minimum size of a parameter to shard. Parameters with fewer values will not be sharded. + + Returns: + PyTree of same structure as params, but with leaves sharded over new axis if possible. 
+ """ + # axis_index + axis_idx = jax.lax.axis_index(axis_name) + # number of units in the axis + axis_size = jax.lax.psum(1, axis_name) + + # split function + # check each parameter if it had been sharded + def _split(x: Parameter) -> Parameter: + + # already sharded + if isinstance(x, nn.Partitioned): + value, names = x.value, x.names + # not sharded + else: + value = x + names = (None,) * value.ndim + + # logging only runs on first jit + # this section checks for why a parameter is not already sharded on the axis + # check for sharded parameters despite being sharded + # (that means its on a different axis) + if axis_name in names: + logging.warning( + f"Parameter {value.shape} with names {names} already sharded on axis {axis_name}." + ) + return x + # check if parameter is to small + elif value.size <= min_weight_size: + logging.info( + f"Parameter {value.shape} with names {names} too small to shard, size {value.size} < {min_weight_size}." + ) + return x + # let's start sharding! + else: + shape = value.shape + idx = np.argsort(shape)[::-1] # Shard along largest possible axis. + for i in idx: + # this technically runs once because of return + # we only shard if we can split evenly across devices + # and if it ain't alreayd sharded + if shape[i] % axis_size == 0 and names[i] is None: + split_size = shape[i] // axis_size + p_sharded = nn.Partitioned( + value=jax.lax.dynamic_slice_in_dim( # Shard to keep on present device. + value, + axis_idx * split_size, + split_size, + axis=i + ), + names=names[:i] + (axis_name,) + names[i + 1 :], + ) + return p_sharded + + logging.warning( + f"Could not shard {value.shape} with names {names} on axis {axis_name}, no suitable axis found." + ) + return x + + # we apply the _split function across the parameter pytree + return jax.tree.map( + _split, + params, + is_leaf=lambda x: isinstance( + x, nn.Partitioned + ), # Consider a nn.Partitioned object as a leaf. + ) + +# %% +# function to gather parameters back to a single device + +# but first we need create a custom function for mean gradient computation +# jax.lax.all_gather -> retrieve shards and assemble full array in each device +# jax.lax.psum_scatter -> scatter gradients back to respective devices +def gather_array_with_mean_grads(x: jax.Array, axis: int, axis_name: str): + """Gathering with averaging gradients across replicas.""" + axis_size = jax.lax.psum(1, axis_name) + + # Define a custom gradient for the gather operation. + @jax.custom_gradient + def f(x): + # adjust backward to turn sum into mean of axis + def grad_fn(g): + # pmean_scatter from psum_scatter + # after computing from full gradient array, our shard only has a + # portion of the parameters, we only get the gradients associated + # with parameters of our shard + return ( + jax.lax.psum_scatter(g, axis_name, scatter_dimension=axis, tiled=True) / axis_size + ) + + # assemble shards to form full gradient array + return jax.lax.all_gather(x, axis_name, axis=axis, tiled=True), grad_fn + + return f(x) + +# gather params back - e.g. when computing a module forward call +# reverse operation of "shard_params" +# depends on: gather_array_with_mean_grads +@jax.named_scope("gather_params") +def gather_params(params: PyTree, axis_name: str) -> PyTree: + """Gather parameters from all replicas across the given axis. + + Args: + params: The parameters to gather. + axis_name: The axis to gather parameters across. + + Returns: + PyTree of same structure as params, but with leaves gathered if they were a nn.Partitioned object. 
+ """ + + def _gather(p: Parameter) -> Parameter: + if isinstance(p, nn.Partitioned) and axis_name in p.names: + param_shard = p.names + shard_axis = param_shard.index(axis_name) + value = gather_array_with_mean_grads(p.value, axis=shard_axis, axis_name=axis_name) + + # If there are any other axes that are sharded, we need to keep the partitioned structure. + # Otherwise, we can return the value directly. + param_shard = param_shard[:shard_axis] + (None,) + param_shard[shard_axis + 1 :] + if any([name is not None for name in param_shard]): + # we return the still-sharded axes shard + return nn.Partitioned(value, param_shard) + else: + return value + else: + return p + + # we find all the sharded params and gather them, returning a complete parameter + return jax.tree.map( + _gather, + params, + is_leaf=lambda x: isinstance(x, nn.Partitioned)) + +# %% +# when we call a module, we gather the parameters back to a single device +# wrap a module into a nn.map_variables transform +# allows for transforms on the parameter before and after a module call +# depends on: gather_params, shard_params +def shard_module_params( + target: nn.Module | Callable, + axis_name: str, + min_weight_size: int = 2**18 # 262,144 +) -> nn.Module | Callable: + """Shard parameters of a module across replicas. + + Args: + target: The module to shard. + axis_name: The axis name to shard parameters across. + min_weight_size: The minimum size of a parameter to shard. Parameters with fewer values will not be sharded. + + Returns: + The module with sharded parameters. + """ + return nn.map_variables( + target, + trans_in_fn=functools.partial( + gather_params, axis_name=axis_name), + trans_out_fn=functools.partial( + shard_params, axis_name=axis_name, min_weight_size=min_weight_size + ), + mapped_collections="params", + mutable=True, + ) + +# %% +# define new function with axes constraints +# this forms the template for sharding future modules +# remember, flax modules are subclassed from elementary flax modules +class FSDPClassifier(nn.Module): + config: ConfigDict + + @nn.compact + def __call__(self, x: jax.Array, train: bool) -> jax.Array: + # create a sharded module + sharded_dense = shard_module_params( + nn.Dense, + axis_name=self.config.data_axis_name, # axes + min_weight_size=self.config.min_weight_size, # min_weight + ) + x = sharded_dense( + features=self.config.hidden_size, + dtype=self.config.dtype, + name="input_dense", + )(x) + x = nn.silu(x) + x = nn.Dropout(rate=self.config.dropout_rate, deterministic=not train)(x) + x = sharded_dense( + features=self.config.num_classes, + dtype=self.config.dtype, + name="output_dense", + )(x) + x = x.astype(jnp.float32) + return x + +# %% +# initialization +config.model.min_weight_size = 2**4 +model_fsdp = FSDPClassifier(config=config.model) + +# the earlier init function +def init_dp(rng: jax.random.PRNGKey, x: jax.Array, model: nn.Module) -> TrainState: + init_rng, rng = jax.random.split(rng) + variables = model.init({"params": init_rng}, x, train=False) + params = variables.pop("params") + state = TrainState.create( + apply_fn=model.apply, + params=params, + tx=optimizer, + rng=rng, + ) + return state + + +# initialize our sharded model with mesh +# we need to adjust the shard map since partitioning is determined within the +# model init, hence we cannot manually specify it +# +# we do a hack where we just try and let it evaluate the shapes +# we set an unknown output specification - aka fully replicate +# +# we then get the partition_spec of the shapes of the parameters 
+init_fsdp_fn = shard_map( + functools.partial(init_dp, model=model_fsdp), + mesh, + # first P() is for model_init_rng + # second P(config.data_axis_name) is for batch.inputs + in_specs=(P(), P(config.data_axis_name)), + # not partitioned, fully replicated + out_specs=P(), + check_rep=False, # disable checks for replication errors in out_specs +) +state_fsdp_shapes = jax.eval_shape(init_fsdp_fn, model_init_rng, batch.inputs) +state_fsdp_specs = nn.get_partition_spec(state_fsdp_shapes) +# %% [raw] +# TrainState(step=PartitionSpec(), apply_fn=, +# params={ +# 'input_dense': {'bias': PartitionSpec('data',), 'kernel': PartitionSpec('data', None)}, +# 'output_dense': {'bias': PartitionSpec(), 'kernel': PartitionSpec('data', None)}}, +# tx=GradientTransformationExtraArgs(init=.init_fn at 0x761e8ef00400>, +# update=.update_fn at 0x761e8ef01080>), +# opt_state=(ScaleByAdamState(count=PartitionSpec(), +# mu={'input_dense': {'bias': PartitionSpec('data',), 'kernel': PartitionSpec('data', None)}, +# 'output_dense': {'bias': PartitionSpec(), 'kernel': PartitionSpec('data', None)}}, +# nu={'input_dense': {'bias': PartitionSpec('data',), 'kernel': PartitionSpec('data', None)}, +# 'output_dense': {'bias': PartitionSpec(), 'kernel': PartitionSpec('data', None)}}), +# EmptyState(), EmptyState()), rng=PartitionSpec()) + +# %% +# then from the state_fsdp_specs, we obtain our config +# this print clarifies everything -> the reason why earlier we do not know the +# partitionspec is because we only know which parameters gets to be sharded at +# model init +print("RNG", state_fsdp_specs.rng) +print("\nParameters") +pprint(state_fsdp_specs.params) +print("\nOptimizer state") +pprint(state_fsdp_specs.opt_state[0]) + +# %% +# init again, this time with the specs and knowledge of what is and should not +# be sharded +init_fsdp_fn = jax.jit( + shard_map( + functools.partial(init_dp, model=model_fsdp), + mesh, + in_specs=(P(), P(config.data_axis_name)), + out_specs=state_fsdp_specs, + check_rep=False, + ) +) +state_fsdp = init_fsdp_fn(model_init_rng, batch.inputs) + +# %% +print("FSDP Parameters") +pprint(jax.tree.map(lambda x: x.shape, jax.device_get(state_fsdp.params))) + +# %% +# train step + +# we need to handle the sync of gradients +# some parameters are sharded, some are not +def sync_gradients( + grads: PyTree, + axis_names: Sequence[str], +) -> PyTree: + """Synchronize gradients across devices. + + Gradients for parameters that are replicated over a given axis are averaged across devices. + Parameters that are partitioned over a given axis are considered to already have a mean of + the gradients on each device, and hence do not need to be altered. + + Args: + grads: The gradients to synchronize. + axis_names: The axis names to synchronize gradients across. + + Returns: + The gradients averaged over the specified axes if they are replicated. + """ + + def sync_grad(g: Parameter) -> Parameter: + if isinstance(g, nn.Partitioned): + # Tree leaves for flattening potentially nested axis (multiple names + # can exist for single array axis). + replication_axis_names = [ + name for name in axis_names if name not in jax.tree_util.tree_leaves(g.names) + ] + if len(replication_axis_names) == 0: + # Parameters partitioned over all axes. + return g + else: + # Average over remaining replicated axes. + return g.replace(value=jax.lax.pmean(g.value, axis_name=replication_axis_names)) + else: + # Parameters are replicated over all axes. 
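+            # jax.lax.pmean accepts a sequence of axis names and averages over
+            # every device in those axes, i.e. a plain all-reduce mean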
+ return jax.lax.pmean(g, axis_name=axis_names) + + return jax.tree.map( + sync_grad, + grads, + is_leaf=lambda x: isinstance(x, nn.Partitioned)) + +# %% +def train_step_fsdp( + state: TrainState, + metrics: Metrics, + batch: Batch, +) -> Tuple[TrainState, Metrics]: + rng, step_rng = jax.random.split(state.rng) + # perform one forward pass + grads, step_metrics = accumulate_gradients( + state, + batch, + step_rng, + config.optimizer.num_minibatches, + loss_fn=loss_fn, + ) + # Update parameters. We need to sync the gradients across devices before updating. + with jax.named_scope("sync_gradients"): + grads = sync_gradients(grads, (config.data_axis_name,)) + # then update model + new_state = state.apply_gradients(grads=grads, rng=rng) + + # Sum metrics across replicas. Alternatively, we could keep the metrics separate + # and only synchronize them before logging. For simplicity, we sum them here. + with jax.named_scope("sync_metrics"): + step_metrics = jax.tree.map( + lambda x: jax.lax.psum(x, axis_name=config.data_axis_name), step_metrics + ) + if metrics is None: + metrics = step_metrics + else: + metrics = jax.tree.map(jnp.add, metrics, step_metrics) + return new_state, metrics + +# %% +# jit the train_step_fsdp +train_step_fsdp_fn = jax.jit( + shard_map( + train_step_fsdp, + mesh, + in_specs=(state_fsdp_specs, P(), P(config.data_axis_name)), + out_specs=(state_fsdp_specs, P()), + check_rep=False, + ), + donate_argnames=("state", "metrics"), +) + +# get the metric shape to initialize accumulator arrays for metrics +_, metric_shapes = jax.eval_shape( + train_step_fsdp_fn, + state_fsdp, + None, + batch, +) +metrics_fsdp = jax.tree.map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes) +# %% +# train +start_time = time.time() +for _ in range(15): + state_fsdp, metrics_fsdp = train_step_fsdp_fn( + state_fsdp, + metrics_fsdp, batch) +duration = time.time() - start_time +print(duration) + +# get metrics and state +final_metrics_fsdp = jax.tree.map( + lambda x: jnp.zeros(x.shape, dtype=x.dtype), + metric_shapes) +state_fsdp, final_metrics_fsdp = train_step_fsdp_fn( + state_fsdp, + final_metrics_fsdp, batch) +print_metrics(final_metrics_fsdp, "FSDP - Final metrics") + + +# %% diff --git a/parallel/intro_to_distributed.py b/parallel/intro_to_distributed.py new file mode 100644 index 0000000..c1f3dd6 --- /dev/null +++ b/parallel/intro_to_distributed.py @@ -0,0 +1,373 @@ +# %% [markdown] +# # Distribute computin in JAX + +# %% +import os + +# Set this to True to run the model on CPU only. 
+USE_CPU_ONLY = True + +flags = os.environ.get("XLA_FLAGS", "") +if USE_CPU_ONLY: + flags += " --xla_force_host_platform_device_count=8" # Simulate 8 devices + # Enforce CPU-only execution + os.environ["CUDA_VISIBLE_DEVICES"] = "" +else: + # GPU flags + flags += ( + "--xla_gpu_enable_triton_softmax_fusion=true " + "--xla_gpu_triton_gemm_any=false " + "--xla_gpu_enable_async_collectives=true " + "--xla_gpu_enable_latency_hiding_scheduler=true " + "--xla_gpu_enable_highest_priority_async_stream=true " + ) +os.environ["XLA_FLAGS"] = flags + +# %% +import functools +from typing import Any, Dict, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec + +PyTree = Any +Metrics = Dict[str, Tuple[jax.Array, ...]] +jax.config.update('jax_platform_name', 'cpu') + +# %% +jax.devices() + + +# %% +# when we create array, we can check the location +a = jnp.arange(8) +print("Array", a) +print("Device", a.device) +print("Sharding", a.sharding) + +# %% [markdown] +# ## Single-Axis Mesh + +# %% +# let's create a Mesh +# multidimensional Numpy array of jax devices +# jax.sharding.Mesh(devices, axis_names) +mesh = Mesh(devices=np.array(jax.devices()), axis_names=("i",)) +print(mesh) + +# %% +# jax.sharding.NamedSharding(mesh, spec) +# pair of a Mesh of devices and PartitionSpec +# PartitionSpec describes how to share an array across that mesh +# "i" is the value of the dimension of the array +# to shard an array axis over a certain mesh axis, add the axis name at the +# corresponding position in the tuple +sharding = NamedSharding(mesh=mesh, spec=PartitionSpec("i",)) + +# %% +a_sharded = jax.device_put(a, sharding) +print("Sharded array", a_sharded) +print("Device", a_sharded.devices()) +print("Sharding", a_sharded.sharding) + +# %% +jax.debug.visualize_array_sharding(a_sharded) + +# %% +# let's try some computation on the mesh +out = nn.tanh(a_sharded) +print("Output array", out) +jax.debug.visualize_array_sharding(out) +# note how the output array is sharded across the devices + +# %% [markdown] +# ## multi-axis mesh +# Why would you shard across multiple dimensions? 
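+# A common reason is combining parallelism styles: one mesh axis shards the
+# batch (data parallelism) while another shards parameters or activations
+# (model/tensor parallelism), so each device holds one batch shard and one
+# parameter shard at the same time.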
+# +# + +# %% +mesh = Mesh(devices=np.array(jax.devices()).reshape(4,2), axis_names=("i", "j")) +# axis i/0 refers to the row-wise axis progressing downwards +# axis j/1 refers to the column-wise axis progressing rightward +mesh # noqa: B018 + +# %% +# we now illustrate sharded MAC operation +# y = x @ w + b +batch_size = 192 +input_dim = 64 +output_dim = 128 +# input: (batch_size, input_dim) +x = jax.random.normal(jax.random.PRNGKey(0), (batch_size, input_dim)) +# w: (input_dim, output_dim) +w = jax.random.normal(jax.random.PRNGKey(1), (input_dim, output_dim)) +# b: (output_dim,) +b = jax.random.normal(jax.random.PRNGKey(2), (output_dim,)) + + +# %% +# x sharded along 0 axis (partition) +# +x_sharded = jax.device_put(x, NamedSharding(mesh, PartitionSpec("i", None))) +w_sharded = jax.device_put(w, NamedSharding(mesh, PartitionSpec(None, "j"))) +b_sharded = jax.device_put(b, NamedSharding(mesh, PartitionSpec("j"))) + +print('x blocks:') +jax.debug.visualize_array_sharding(x_sharded) +print('w blocks:') +jax.debug.visualize_array_sharding(w_sharded) +print('b blocks:') +jax.debug.visualize_array_sharding(b_sharded) + + +# %% +out = jnp.dot(x_sharded, w_sharded) + b_sharded +print("Output shape", out.shape) +jax.debug.visualize_array_sharding(out) + + +# %% [markdown] +# # Shard Map -shmap +# +# beforehand, we manually assign the sharding partition to assign the exact +# partitions to achieve independent, parallel block matrix computation +# +# This allows us to write code with explicit control over parallelization and +# communication +# +# what is a shard_map? +# +# it is a transformation that takes a function, a mesh, and a sharding +# specification for inputs and outputs +# +# in other words, we write a function that executes on each device only, then +# apply across all the shards +# +# but wait, doesn't pmap do this? The answer is no. pmap doesn't have enough +# information about the shards to efficiently perform sharding for complicated +# meshes. + +# %% +def matmul_fn(x: jax.Array, w: jax.Array, b: jax.Array) -> jax.Array: + print("Local x shape", x.shape) + print("Local w shape", w.shape) + print("Local b shape", b.shape) + # so simple! 
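+    # each device only sees its local blocks: x is (batch/4, input_dim),
+    # w is (input_dim, output_dim/2) and b is (output_dim/2,), so the local
+    # dot produces exactly this device's block of y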
+ return jnp.dot(x,w) + b + +# %% +matmul_sharded = shard_map( + matmul_fn, # the function for operating on a single device + mesh, # the device topology + # the input mesh partition argument for each input + in_specs=( + PartitionSpec("i", None), # x + PartitionSpec(None, "j"), # w + PartitionSpec("j") # b + ), + # the output to read from the mesh + out_specs=PartitionSpec("i", "j") +) + +# %% +# y = matmul_sharded(x_sharded, w_sharded, b_sharded) +# there is no need to device_put, +# partitioning is done according to your in_specs +y = matmul_sharded(x, w, b) +print("Output shape", y.shape) +jax.debug.visualize_array_sharding(y) + + +# %% [markdown] +# # Axis Communication + +# %% +# example of mean/sum across devices per shard + +# the following wants to find the statistics of x +# we compute the normalized x according to each row statistics (mean and std) +@functools.partial( + shard_map, + mesh=mesh, + in_specs=PartitionSpec("i", "j"), + out_specs=PartitionSpec("i", "j")) +def parallel_normalize(x: jax.Array) -> jax.Array: + # jax.lax.pmean: compute an all-reduce sum on x over the pmapped axis + # "axis_name" + # get the mean across the "j" axis of the mesh - column wise + mean = jax.lax.pmean(x, axis_name="j") + # get the std across the "j" axis of the mesh - column wise + std = jax.lax.pmean((x - mean) ** 2, axis_name="j") ** 0.5 + return (x - mean) / std + +# communicated along "j" axis of mesh for row elements + + +out = parallel_normalize(x) +out = jax.device_get(out) +print(out.shape) +print("Mean", out.mean()) +print("Std", out.std()) + + +# %% +# scenario: array is sharded across devices, some values missing per shard +# all-gather: gather values of an array from all devices +@functools.partial( + shard_map, + mesh=mesh, + in_specs=( + PartitionSpec("i", None), # artificially shard across "i" + PartitionSpec("i", None) + ), + out_specs=PartitionSpec("i", None)) +def matmul_with_weight_gather(x: jax.Array, w: jax.Array) -> jax.Array: + print("Original w shape", w.shape) + # pull the full w matrix values from neighboring devices + w_gathered = jax.lax.all_gather(w, axis_name="i", axis=0, tiled=True) + print("Gathered w shape", w_gathered.shape) + y = jnp.dot(x, w_gathered) + return y + + +out = matmul_with_weight_gather(x, w) +out = jax.device_get(out) +np.testing.assert_array_equal(out, jnp.dot(x, w)) + +# %% +# scenario: arrays are sharded across all devices +# scatter sum: each function instance of each device gets only one shard of the result +# +# therefore each device gets the sum of some(or one) array(s) + +@functools.partial( + shard_map,mesh=mesh, + in_specs=PartitionSpec("i", None), + out_specs=PartitionSpec("i", None)) +def scatter_example(x: jax.Array) -> jax.Array: + x_scatter = jax.lax.psum_scatter(x, axis_name="i", scatter_dimension=1) + return x_scatter + + +x_exmp = np.array( + [ + [3, 1, 4, 1], + [5, 9, 2, 6], + [5, 3, 5, 8], + [9, 7, 1, 2], + ] +) +out = scatter_example(x_exmp) +print("Output", out) +# %% +# ppermute: communicates an array in a round robin fashion +# +# this is used in implementing pipeline parallelism where results are passed to another device +# used in tensor parallelism +# +# notice how the results roll through the devices +# +# this can actually implement all other lax communication operations + +@functools.partial( + shard_map, + mesh=mesh, + in_specs=PartitionSpec("i"), + out_specs=PartitionSpec("i")) +def ppermute_example(x: jax.Array) -> jax.Array: + axis_size = mesh.shape["i"] + print('BEFORE:\n', x) + x_perm = jax.lax.ppermute( + x, 
+ axis_name="i", + perm=[ + # source_index, destination_index pairs + (i, (i + 1) % axis_size) for i in range(axis_size) + ] + ) + print('AFTER:\n', x_perm) + return x_perm + + +x_exmp = np.arange(4) +out = ppermute_example(x_exmp) +print("Output", out) # the value is that of each axis 0 device + + +# %% +# # axis indexing: get the index of device along axis +# sometimes our computations need adjustment depending on the device its being ran on +# +# we will use jax.lax.axis_index to return the index of the current device along an axis +# +# this function will be jitted and will be almost 0 cost + +axis_idx_fn = jax.jit( + shard_map( + lambda: jnp.stack( + [ + jax.lax.axis_index("i"), # Device index in mesh along the "i" axis + jax.lax.axis_index("j"), # Device index in mesh along the "j" axis + ], + axis=-1, + )[None], + mesh, + in_specs=PartitionSpec(), + out_specs=PartitionSpec( + ("i", "j"), + ), + ) +) +out = axis_idx_fn() +out = jax.device_get(out) +for i in range(out.shape[0]): + print(f"Device {i}: i-axis={out[i, 0]}, j-axis={out[i, 1]}") + +# %% +# usage 2: fold rng over given axis +# jax.random.fold_in: folds in data to a PRNG key to form a new PRNG key +# from a source RNG key, we generate new RNG keys +def fold_rng_over_axis(rng: jax.random.PRNGKey, axis_name: str) -> jax.random.PRNGKey: + """Folds the random number generator over the given axis. + + This is useful for generating a different random number for each device + across a certain axis (e.g. the model axis). + + Args: + rng: The random number generator. + axis_name: The axis name to fold the random number generator over. + + Returns: + A new random number generator, different for each device index along the axis. + """ + axis_index = jax.lax.axis_index(axis_name) + return jax.random.fold_in(rng, axis_index) + +# %% +# we fold RNG over the i axis only +# same RNG used across j axis +fold_fn = jax.jit( + shard_map( + # fold over for "i" only + functools.partial(fold_rng_over_axis, axis_name="i"), + mesh, + in_specs=PartitionSpec(), + out_specs=PartitionSpec( + ("i", "j"), + ), + ) +) +rng = jax.random.PRNGKey(0) +out = fold_fn(rng) +out = jax.device_get(out) +for i in range(out.shape[0] // 2): + print(f"Device {i}: RNG={out[2*i:2*i+2]}") + + +# %% diff --git a/parallel/single_gpu_optimizations.py b/parallel/single_gpu_optimizations.py new file mode 100644 index 0000000..49efd67 --- /dev/null +++ b/parallel/single_gpu_optimizations.py @@ -0,0 +1,441 @@ + +# MARK: import +# %% [markdown] +# # single gpu optimizaitons + +import os + +# os.environ["XLA_FLAGS"] = ( +# "--xla_gpu_enable_triton_softmax_fusion=true " +# "--xla_gpu_triton_gemm_any=false " +# "--xla_gpu_enable_async_collectives=true " +# "--xla_gpu_enable_latency_hiding_scheduler=true " +# "--xla_gpu_enable_highest_priority_async_stream=true " +# ) +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + +# %% +import functools +from pprint import pprint +from typing import Any, Callable, Dict, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +import optax +from flax.struct import dataclass +from flax.training import train_state + +# Type aliases +PyTree = Any +Metrics = Dict[str, Tuple[jax.Array, ...]] + +# %% [mardown] +# # bfloat16 mixed precision compute +class MLPClassifier(nn.Module): + dtype: Any # we set the dtype here for computation + hidden_size: int = 256 + num_classes: int = 100 + dropout_rate: float = 0.1 + + @nn.compact + def __call__(self, x: jax.Array, train: bool) -> jax.Array: + x = nn.Dense( + 
features=self.hidden_size, + dtype=self.dtype, # Computation in specified dtype, params stay in float32 + )(x) + x = nn.LayerNorm(dtype=self.dtype)(x) + x = nn.silu(x) + x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) + x = nn.Dense( + features=self.num_classes, + dtype=self.dtype, + )(x) + x = x.astype(jnp.float32) + x = nn.log_softmax(x, axis=-1) + return x + + +# %% +x = jnp.ones((512, 128), dtype=jnp.float32) +rngs = {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)} +model_float32 = MLPClassifier(dtype=jnp.float32) +model_float32.tabulate(rngs, x, train=True, console_kwargs={"force_jupyter": True}) + + +# %% +# inputs and activations (outputs) in bfloat16 +# parameters in float32 +model_bfloat16 = MLPClassifier(dtype=jnp.bfloat16) +model_bfloat16.tabulate(rngs, x, train=True, console_kwargs={"force_jupyter": True}) + +# MARK: GRADIENT CHECKPOINT +# %% [markdown] +# # gradient checkpoint +# +# in jax this is implemented with the remat function +# +# practical notes on remat: https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#practical-notes +def gelu(x: jax.Array) -> jax.Array: + """GeLU activation function with approximate tanh.""" + # This will be printed once every time the function is executed. + jax.debug.print("Executing GeLU") + # See https://arxiv.org/abs/1606.08415 for details. + x3 = jnp.power(x, 3) + tanh_input = np.sqrt(2 / np.pi) * (x + 0.044715 * x3) + return 0.5 * x * (1 + jnp.tanh(tanh_input)) + +def loss_fn(x: jax.Array, remat: bool) -> jax.Array: + act_fn = gelu + if remat: + act_fn = jax.remat(act_fn) + return jnp.mean(act_fn(x)) + +x = jax.random.normal(jax.random.PRNGKey(0), (100,)) +grad_fn = jax.grad(loss_fn) +# regenerate function on backward +_ = grad_fn(x, remat=True) + +# no remat, no function regeneration +_ = loss_fn(x, remat=False) + +#MARK: GRADIENT ACCUMULATION +# %% [markdown] +# # gradient accumulation +# +# run many mini-batches, and accumulate their gradients to feed into optimizer +# as if there were one large batch + + +# %% +class TrainState(train_state.TrainState): + rng: jax.Array + +@dataclass +class Batch: + inputs: jax.Array + labels: jax.Array + +# %% +# nothing special here, just a loss function +def classification_loss_fn( + params: PyTree, apply_fn: Any, batch: Batch, rng: jax.Array +) -> Tuple[PyTree, Metrics]: + """Classification loss function with cross-entropy.""" + logits = apply_fn({"params": params}, batch.inputs, train=True, rngs={"dropout": rng}) + loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch.labels) + correct_pred = jnp.equal(jnp.argmax(logits, axis=-1), batch.labels) + batch_size = batch.inputs.shape[0] + # step_metrics contains the loss sum + step_metrics = {"loss": (loss.sum(), batch_size), "accuracy": (correct_pred.sum(), batch_size)} + # loss contains the mean loss + mean_loss = loss.mean() # the mathematical output of function + return mean_loss, step_metrics + +# %% +# gradient accumulation training loop +def accumulate_gradients_loop( + state: TrainState, + batch: Batch, + rng: jax.random.PRNGKey, + num_minibatches: int, + loss_fn: Callable, +) -> Tuple[PyTree, Metrics]: + """Calculate gradients and metrics for a batch using gradient accumulation. + + Args: + state: Current training state. + batch: Full training batch. + rng: Random number generator to use. + num_minibatches: Number of minibatches to split the batch into. Equal to the number of gradient accumulation steps. + loss_fn: Loss function to calculate gradients and metrics. 
+ + Returns: + Tuple with accumulated gradients and metrics over the minibatches. + """ + batch_size = batch.inputs.shape[0] + minibatch_size = batch_size // num_minibatches + rngs = jax.random.split(rng, num_minibatches) + # Define gradient function for single minibatch. + # If has_aux is True then a tuple of ((value, auxiliary_data), gradient) is returned. + # otherwise it returns (value, gradient), where value is the actual output + # of the function, hence the "value" of the namesake + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + # Prepare loop variables. + grads = None + metrics = None + for minibatch_idx in range(num_minibatches): + with jax.named_scope(f"minibatch_{minibatch_idx}"): + # Split the batch into minibatches. + start = minibatch_idx * minibatch_size + end = start + minibatch_size + minibatch = jax.tree.map(lambda x: x[start:end], batch) # noqa: B023 + # Calculate gradients and metrics for the minibatch. + # missing value is mean loss of batch + (_, step_metrics), step_grads = grad_fn( + state.params, state.apply_fn, minibatch, rngs[minibatch_idx] + ) + # Accumulate gradients and metrics across minibatches. + if grads is None: + grads = step_grads + metrics = step_metrics + else: + # accumulation adder + grads = jax.tree.map(jnp.add, grads, step_grads) + metrics = jax.tree.map(jnp.add, metrics, step_metrics) + # Average gradients over minibatches. + grads = jax.tree.map(lambda g: g / num_minibatches, grads) + return grads, metrics + + +# %% +# jax.scan implementation +# +# pros: faster compile +# cons: slower inference +def accumulate_gradients_scan( + state: TrainState, + batch: Batch, + rng: jax.random.PRNGKey, + num_minibatches: int, + loss_fn: Callable, +) -> Tuple[PyTree, Metrics]: + """Calculate gradients and metrics for a batch using gradient accumulation. + + In this version, we use `jax.lax.scan` to loop over the minibatches. This is more efficient in terms of compilation time. + + Args: + state: Current training state. + batch: Full training batch. + rng: Random number generator to use. + num_minibatches: Number of minibatches to split the batch into. Equal to the number of gradient accumulation steps. + loss_fn: Loss function to calculate gradients and metrics. + + Returns: + Tuple with accumulated gradients and metrics over the minibatches. + """ + batch_size = batch.inputs.shape[0] + minibatch_size = batch_size // num_minibatches + rngs = jax.random.split(rng, num_minibatches) + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + + def _minibatch_step(minibatch_idx: jax.Array | int) -> Tuple[PyTree, Metrics]: + """Determine gradients and metrics for a single minibatch.""" + minibatch = jax.tree.map( + # jax.lax.dynamic_slice_in_dim + # This is roughly equivalent to the following Python indexing syntax + # applied along the specified axis: operand[..., start_index:start_index + slice_size]. + # jax.lax.dynamic_slice_in_dim(operand, start_index, slice_size, axis=0) + lambda x: jax.lax.dynamic_slice_in_dim( # Slicing with variable index (jax.Array). 
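+                # a Python slice x[start:end] would need static indices; inside
+                # jax.lax.scan the minibatch index is a traced jax.Array, so we
+                # must use dynamic_slice_in_dim instead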
+ x, + start_index=minibatch_idx * minibatch_size, + slice_size=minibatch_size, + axis=0 + ), + batch, + ) + (_, step_metrics), step_grads = grad_fn( + state.params, state.apply_fn, minibatch, rngs[minibatch_idx] + ) + return step_grads, step_metrics + + # the function we expect scan to use + def _scan_step( + carry: Tuple[PyTree, Metrics], minibatch_idx: jax.Array | int + ) -> Tuple[Tuple[PyTree, Metrics], None]: + """Scan step function for looping over minibatches.""" + step_grads, step_metrics = _minibatch_step(minibatch_idx) + # notice how the carry type is a tuple of pytree and metrics + # carry is literally the accumulator of (step_grads, step_metrics) + carry = jax.tree.map(jnp.add, carry, (step_grads, step_metrics)) + # jax.lax.scan expects a carry and a y + # but we have no y + return carry, None + + # Determine initial shapes for gradients and metrics. + grads_shapes, metrics_shape = jax.eval_shape(_minibatch_step, 0) + grads = jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), grads_shapes) + metrics = jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), metrics_shape) + # Loop over minibatches to determine gradients and metrics. + # jax.lax.scan + # jax.lax.scan(f, init, xs=None, length=None, reverse=False, unroll=1, _split_transpose=False) + # purpose: Scan a function over leading array axes while carrying along state. + # in other words, a functional for-loop + # why? because the for-loop is a single WhileOp in JAX primitive, making it faster + # equivalent python code semantics: + # def scan(f, init, xs, length=None): + # if xs is None: + # xs = [None] * length + # carry = init + # ys = [] + # for x in xs: + # carry, y = f(carry, x) + # ys.append(y) + # return carry, np.stack(ys) + # note: usually we expect the ys to be the output and the carry to be hidden state + # y is the pure function output we expect + (grads, metrics), _ = jax.lax.scan( + _scan_step, + init=(grads, metrics), + xs=jnp.arange(num_minibatches), + length=num_minibatches + ) + # Average gradients over minibatches. + grads = jax.tree.map(lambda g: g / num_minibatches, grads) + return grads, metrics + +# %% +def accumulate_gradients(*args, use_scan: bool = False, **kwargs) -> Tuple[PyTree, Metrics]: + if use_scan: + return accumulate_gradients_scan(*args, **kwargs) + else: + return accumulate_gradients_loop(*args, **kwargs) + +# %% +def train_step( + state: TrainState, + metrics: Metrics | None, + batch: Batch, + num_minibatches: int, +) -> Tuple[TrainState, Metrics]: + """Training step function. + + Executes a full training step with gradient accumulation. + + Args: + state: Current training state. + metrics: Current metrics, accumulated from previous training steps. + batch: Training batch. + num_minibatches: Number of minibatches to split the batch into. Equal to the number of gradient accumulation steps. + + Returns: + Tuple with updated training state (parameters, optimizer state, etc.) and metrics. + """ + # Split the random number generator for the current step. + rng, step_rng = jax.random.split(state.rng) + # Determine gradients and metrics for the full batch. + grads, step_metrics = accumulate_gradients( + # we cannot use a variable to choose use_scan + # cardinal sin of jax: passing boolean into jitted function + state, batch, step_rng, num_minibatches, loss_fn=classification_loss_fn, use_scan=True + ) + # Optimizer step. + new_state = state.apply_gradients(grads=grads, rng=rng) + # Accumulate metrics across training steps. 
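+    # `metrics` is None only for the jax.eval_shape call below (shape inference);
+    # during real training it carries running (sum, count) pairs that we add to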
+ if metrics is None: + metrics = step_metrics + else: + metrics = jax.tree.map(jnp.add, metrics, step_metrics) + return new_state, metrics + + +# %% +batch_size = 512 +num_inputs = 128 +num_classes = 100 +rng_seed = 0 + +rng = jax.random.PRNGKey(rng_seed) +data_input_rng, data_label_rng, model_rng, state_rng = jax.random.split(rng, 4) +batch = Batch( + inputs=jax.random.normal(data_input_rng, (batch_size, num_inputs)), + labels=jax.random.randint(data_label_rng, (batch_size,), 0, num_classes), +) + +# Zero dropout for checking later equality between training with and without gradient accumulation. +model = MLPClassifier(dtype=jnp.bfloat16, dropout_rate=0.0) +params = model.init(model_rng, batch.inputs, train=False)["params"] +state = TrainState.create( + apply_fn=model.apply, + params=params, + tx=optax.adam(1e-3), + rng=state_rng, +) +# %% +# jax.eval_shape(fun, *args, **kwargs) +# compute shape/dtype of fun without any FLOPs + +# this fails because it jits train_step without minibatch number +# thus causing the shape inference to fail +# _, metric_shapes = jax.eval_shape( +# train_step, # fun +# state, # train state +# None, # metrics +# batch, # batch +# 4, # num_minibatches +# ) + +_, metric_shapes = jax.eval_shape( + # this thing jitted works + functools.partial(train_step, num_minibatches=4), + state, # train state + None, # metrics + batch, # batch +) + +print("Metric shapes:") +pprint(metric_shapes) + +# %% +# this is an optimization trick +# cache this every time num_minibatches change +# otherwise re-compile every time +train_step_jit = jax.jit( + train_step, + # treat as a static argument + static_argnames="num_minibatches", +) + +# %% +def train_with_minibatches( + state: TrainState, + batch: Batch, + num_minibatches: int, + num_train_steps: int, +) -> Tuple[TrainState, Metrics]: + """Small helper function for training loop.""" + train_metrics = jax.tree.map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes) + for _ in range(num_train_steps): + state, train_metrics = train_step_jit(state, train_metrics, batch, num_minibatches) + return state, train_metrics +# %% +def print_metrics(metrics: Metrics, title: str | None = None) -> None: + """Prints metrics with an optional title.""" + metrics = jax.device_get(metrics) + lines = [f"{k}: {v[0] / v[1]:.6f}" for k, v in metrics.items()] + if title: + title = f" {title} " + max_len = max(len(title), max(map(len, lines))) + lines = [title.center(max_len, "=")] + lines + print("\n".join(lines)) + +# %% +state_mini1, metrics_mini1 = train_with_minibatches( + state, batch, num_minibatches=1, num_train_steps=4 +) +state_mini4, metrics_mini4 = train_with_minibatches( + state, batch, num_minibatches=4, num_train_steps=4 +) +print_metrics(metrics_mini1, "Minibatch 1") +print_metrics(metrics_mini4, "Minibatch 4") + + +# %% [markdown] +# # donate_buffers +# jax perform pass by value due to its functional nature +# we can do pass by reference for certain arguments +# what can be donated? +# we can only do this if we are sure that arguments will not be used +# this is usually true for model parameters and optimizer state +# since we have totally new values, and we won't use the argument values anymore +# after an update (e.g. 
we will use new_state and new_metrics) +train_step_donated = jax.jit( + train_step, + static_argnames="num_minibatches", + donate_argnames=( + "state", + "metrics", + ), +) diff --git a/parallel/t5_jax_train_pjit.py b/parallel/t5_jax_train_pjit.py new file mode 100644 index 0000000..c1dc6b3 --- /dev/null +++ b/parallel/t5_jax_train_pjit.py @@ -0,0 +1,392 @@ +# %% [markdown] +# # T5 implementation using jax with pjit + + +# MARK: START +# %% +# let's make 8-device simulator +import os + +# Set this to True to run the model on CPU only. +USE_CPU_ONLY = True + +flags = os.environ.get("XLA_FLAGS", "") +if USE_CPU_ONLY: + flags += " --xla_force_host_platform_device_count=8" # Simulate 8 devices + # Enforce CPU-only execution + os.environ["CUDA_VISIBLE_DEVICES"] = "" + os.environ["JAX_PLATFORMS"] = "cpu" +else: + # GPU flags + flags += ( + "--xla_gpu_enable_triton_softmax_fusion=true " + "--xla_gpu_triton_gemm_any=false " + "--xla_gpu_enable_async_collectives=true " + "--xla_gpu_enable_latency_hiding_scheduler=true " + "--xla_gpu_enable_highest_priority_async_stream=true " + ) +os.environ["XLA_FLAGS"] = flags + +import functools +from functools import partial +from pprint import pprint +from typing import Any, Dict, Tuple, Callable, Sequence + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P +from ml_collections import ConfigDict +import optax +import logging +import time +from datasets import Dataset, load_from_disk + +from flax import jax_utils, traverse_util +from flax.jax_utils import pad_shard_unpad, unreplicate +from flax.training import train_state +from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key +import flax.core + +from tqdm import tqdm + +from dataload import DataPrepare + +PyTree = Any +Metrics = Dict[str, Tuple[jax.Array, ...]] + +if USE_CPU_ONLY: + jax.config.update('jax_platform_name', 'cpu') +else: + jax.config.update("jax_default_matmul_precision", "bfloat16") + + +# # %% +# import jax +# import jax.numpy as jnp +# import optax +# import numpy as np +# from functools import partial +# from typing import Callable, Optional +# import math +# +# # jax.config.update("jax_default_matmul_precision", "tensorfloat32") +# jax.config.update("jax_default_matmul_precision", "bfloat16") +# # jax.config.update("jax_enable_x64", False) +# # enable cache +# jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") +# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) +# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) +# +# +# # from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig +# +# from flax import jax_utils, traverse_util +# from flax.jax_utils import pad_shard_unpad, unreplicate +# from flax.training import train_state +# from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key +# import flax.core + + +# %% +# get platform type +from jax.lib import xla_bridge +print(xla_bridge.get_backend().platform) + +# %% +# config options +file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval' +save_path = 't5_80_1_bf16' +# file_path = 'combined_data' +split_datasets = load_from_disk(file_path) +training_size = len(split_datasets['train']) +# Store some constant +seed = 117 +num_epochs = 5 +batch_size = 384 # 384 is the best +num_train_epochs = num_epochs +per_device_train_batch_size = 
batch_size +train_batch_size = per_device_train_batch_size * jax.device_count() +per_device_eval_batch_size = batch_size +eval_batch_size = per_device_eval_batch_size * jax.device_count() +steps_per_epoch = training_size // train_batch_size +total_train_steps = steps_per_epoch * num_epochs + +warmup_steps = 0 +learning_rate = 2e-5 + +weight_decay = 0.01 +adam_beta1 = 0.9 +adam_beta2 = 0.999 +adam_epsilon = 1e-8 +label_smoothing_factor = 0.0 + +num_beams = 1 +val_max_target_length = 128 + +predict_with_generate = True + + +# %% +# prepare data +# init the DataPrepare object with its ConfigDict +data_config = ConfigDict( + dict( + max_length=86, + pad_token_id=0, + decoder_start_token_id=0 + ) +) + +dataprep = DataPrepare(split_datasets['train'], data_config) +# # example usage +# # %% +# seed = 117 +# rng = jax.random.PRNGKey(seed) +# train_loader = dataprep.data_loader(rng, batch_size=1) + + +# %% +# model + +from transformers import FlaxT5ForConditionalGeneration +from transformers import T5Config + +config = T5Config() + +# If you don't want to cast certain parameters (for example layer norm bias and scale), +# pass a mask as follows +from flax import traverse_util + +model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base") +# gradient checkpointing trades recomputation for memory, useful for large transformer models +model.enable_gradient_checkpointing() + +# enable bf16 for everything except layer_norm weights +flat_params = traverse_util.flatten_dict(model.params) +mask = { + path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params +} +mask = traverse_util.unflatten_dict(mask) +model.params = model.to_bf16(model.params, mask) + + +# %% [markdown] +# # Model +# +# +# + +# %% + +# Initialize our training +rng = jax.random.PRNGKey(seed) +rng, dropout_rng = jax.random.split(rng) + + +# %% +# optimization functions + +def create_learning_rate_fn( + train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.ndarray]: + """Returns a linear warmup, linear_decay learning rate function.""" + steps_per_epoch = train_ds_size // train_batch_size + num_train_steps = steps_per_epoch * num_train_epochs + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +# Create learning rate schedule +linear_decay_lr_schedule_fn = create_learning_rate_fn( + training_size, + train_batch_size, + num_train_epochs, + warmup_steps, + learning_rate, +) + +# We use Optax's "masking" functionality to not apply weight decay +# to bias and LayerNorm scale parameters. decay_mask_fn returns a +# mask boolean with the same structure as the parameters. +# The mask is True for parameters that should be decayed.
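# %%
# A small, self-contained illustration of how such a boolean mask pytree
# interacts with optax.adamw: leaves marked True receive weight decay, leaves
# marked False are left alone. The toy parameter names ("dense", "layer_norm")
# are hypothetical and not the real T5 pytree; the actual decay_mask_fn used
# for the T5 parameters is defined just below.
import jax
import jax.numpy as jnp
import optax
from flax import traverse_util

toy_params = {
    "dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros((2,))},
    "layer_norm": {"weight": jnp.ones((2,))},
}

def toy_decay_mask(params):
    flat = traverse_util.flatten_dict(params)
    # decay kernels only; skip biases and anything under a layer_norm module
    mask = {path: path[-1] != "bias" and "layer_norm" not in path for path in flat}
    return traverse_util.unflatten_dict(mask)

toy_tx = optax.adamw(learning_rate=1e-3, weight_decay=0.01, mask=toy_decay_mask)
toy_opt_state = toy_tx.init(toy_params)

# with zero gradients, only the masked-in (True) leaves move, and only because
# of weight decay; the bias and layer_norm leaves stay exactly where they were
zero_grads = jax.tree_util.tree_map(jnp.zeros_like, toy_params)
updates, _ = toy_tx.update(zero_grads, toy_opt_state, toy_params)
print(jax.tree_util.tree_map(lambda u: float(jnp.abs(u).max()), updates))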
+def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + # find out all LayerNorm parameters + layer_norm_candidates = ["layernorm", "layer_norm", "ln"] + layer_norm_named_params = { + layer[-2:] + for layer_norm_name in layer_norm_candidates + for layer in flat_params.keys() + if layer_norm_name in "".join(layer).lower() + } + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) + +# create adam optimizer +adamw = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=adam_beta1, + b2=adam_beta2, + eps=adam_epsilon, + weight_decay=weight_decay, + mask=decay_mask_fn, +) + + +# %% +# Training functions +class TrainState(train_state.TrainState): + dropout_rng: jnp.ndarray + + # easy way to achieve data parallelism + # also achieves folding of rng keys + def replicate(self): + return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) + +# set bf16 for model params +# model.params = model.to_bf16(model.params) +# Cast parameters to bfloat16 if desired +# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params) + +# Setup train state +state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng) + +# label smoothed cross entropy +def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): + """ + The label smoothing implementation is adapted from Flax's official example: + https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104 + """ + vocab_size = logits.shape[-1] + confidence = 1.0 - label_smoothing_factor + low_confidence = (1.0 - confidence) / (vocab_size - 1) + normalizing_constant = -( + confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) + ) + soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence) + + loss = optax.softmax_cross_entropy(logits, soft_labels) + loss = loss - normalizing_constant + + # ignore padded tokens from loss + loss = loss * padding_mask + loss = loss.sum() + num_labels = padding_mask.sum() + return loss, num_labels + +# Define gradient update step fn +@jax.jit +def train_step(state, batch, label_smoothing_factor=0.0): + dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) + + def compute_loss(params): + labels = batch.pop("labels") + logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] + loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) + return loss, num_labels + + # compute gradients through computational graph + grad_fn = jax.value_and_grad(compute_loss, has_aux=True) + (loss, num_labels), grad = grad_fn(state.params) + num_labels = jax.lax.psum(num_labels, "batch") + + # true loss = total loss / total samples + # loss = jax.lax.psum(loss, "batch") + # loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) + + # true grad = total grad / total samples + grad = jax.lax.psum(grad, "batch") + grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad) + new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng) + + metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + return new_state, metrics + +# Define generation function +max_length = ( + val_max_target_length if val_max_target_length is not None else model.config.max_length +) +num_beams = 
num_beams if num_beams is not None else model.config.num_beams +gen_kwargs = {"max_length": max_length, "num_beams": num_beams} + +# def generate_step(params, batch): +# model.params = params +# output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs) +# return output_ids.sequences + +# Create parallel version of the train and eval step +p_train_step = jax.pmap( + partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,) +) +# p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch") +# p_generate_step = jax.pmap(generate_step, "batch") + +# Replicate the train state on each device +state = state.replicate() + + + +# %% + + +print("***** Running training *****") +print(f" Num examples = {training_size}") +print(f" Num Epochs = {num_epochs}") +print(f" Instantaneous batch size per device = {per_device_train_batch_size}") +print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}") +print(f" Total optimization steps = {total_train_steps}") + + +# %% +# jax.profiler.start_trace("./traces") + +rng, input_rng = jax.random.split(rng) +train_time = 0 +epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) +for epoch in epochs: + train_start = time.time() + + # Create sampling rng + train_metrics = [] + rng, data_rng = jax.random.split(rng) + train_loader = dataprep.data_loader(data_rng, batch_size=batch_size) + steps_per_epoch = training_size // train_batch_size + # Generate an epoch by shuffling sampling indices from the train dataset + for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): + batch = next(train_loader) + batch = shard(batch) + state, train_metric = p_train_step(state, batch) + train_metrics.append(train_metric) + + train_time = time.time() - train_start + + train_metric = unreplicate(train_metric) + train_metric['loss'].block_until_ready() + + + + epochs.write( + # f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, " + f"Epoch... ({epoch + 1}/{num_epochs} | " + # f"Learning Rate:{train_metric['learning_rate']}, " + f"Last train time: {train_time})" + ) +# jax.profiler.stop_trace() +# %% + +# output_dir = save_path +# # save checkpoint after each epoch and push checkpoint to the hub +# if jax.process_index() == 0: +# params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) +# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params) +# model.save_pretrained(output_dir, params=params) +# tokenizer.save_pretrained(output_dir) diff --git a/t5_jax.py b/t5_jax.py index 0fd7c89..765e1cf 100644 --- a/t5_jax.py +++ b/t5_jax.py @@ -196,7 +196,7 @@ model.params = model.to_bf16(model.params, mask) # %% model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"]) -shift_tokens_right_fn = getattr(model_module, "shift_tokens_right") +shift_tokens_right_fn = getattr(model_module, "shift_tokens_right") # noqa: B009 @@ -391,6 +391,8 @@ adamw = optax.adamw( class TrainState(train_state.TrainState): dropout_rng: jnp.ndarray + # easy way to achieve data parallelism + # also achieves folding of rng keys def replicate(self): return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
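# %%
# A minimal, self-contained sketch of the pmap data-parallel pattern used in the
# training scripts above: replicate the state across devices, shard the global
# batch along its leading axis, and all-reduce (pmean) inside the step. The toy
# names (toy_params, toy_loss, the synthetic batch) are illustrative only and
# not part of the original T5 training code.
import jax
import jax.numpy as jnp
import numpy as np
from flax import jax_utils
from flax.training.common_utils import shard

n_dev = jax.local_device_count()
toy_params = {"w": jnp.ones((4,))}

def toy_loss(params, batch):
    pred = batch["x"] @ params["w"]
    return jnp.mean((pred - batch["y"]) ** 2)

def toy_step(params, batch):
    loss, grads = jax.value_and_grad(toy_loss)(params, batch)
    # average across devices, analogous to the psum normalisation in train_step
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return new_params, loss

p_toy_step = jax.pmap(toy_step, axis_name="batch")

# copy the params to every device, exactly like state.replicate() above
r_params = jax_utils.replicate(toy_params)

# the global batch handed to shard() has leading dim = n_dev * per-device batch;
# shard() reshapes it to (n_dev, per_device_batch, ...)
global_batch = {
    "x": np.ones((n_dev * 2, 4), dtype=np.float32),
    "y": np.zeros((n_dev * 2,), dtype=np.float32),
}
sharded_batch = shard(global_batch)

r_params, loss = p_toy_step(r_params, sharded_batch)
print(jax_utils.unreplicate(loss))  # scalar loss, identical on every device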