# MARK: import
# %% [markdown]
# # single gpu optimizations

# %%
import os

# os.environ["XLA_FLAGS"] = (
#     "--xla_gpu_enable_triton_softmax_fusion=true "
#     "--xla_gpu_triton_gemm_any=false "
#     "--xla_gpu_enable_async_collectives=true "
#     "--xla_gpu_enable_latency_hiding_scheduler=true "
#     "--xla_gpu_enable_highest_priority_async_stream=true "
# )
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# %%
import functools
from pprint import pprint
from typing import Any, Callable, Dict, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.struct import dataclass
from flax.training import train_state

# Type aliases
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

# %% [markdown]
# # bfloat16 mixed precision compute

# %%
class MLPClassifier(nn.Module):
    dtype: Any  # we set the dtype here for computation
    hidden_size: int = 256
    num_classes: int = 100
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x: jax.Array, train: bool) -> jax.Array:
        x = nn.Dense(
            features=self.hidden_size,
            dtype=self.dtype,  # Computation in specified dtype, params stay in float32
        )(x)
        x = nn.LayerNorm(dtype=self.dtype)(x)
        x = nn.silu(x)
        x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
        x = nn.Dense(
            features=self.num_classes,
            dtype=self.dtype,
        )(x)
        x = x.astype(jnp.float32)
        x = nn.log_softmax(x, axis=-1)
        return x


# %%
x = jnp.ones((512, 128), dtype=jnp.float32)
rngs = {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)}
model_float32 = MLPClassifier(dtype=jnp.float32)
print(model_float32.tabulate(rngs, x, train=True, console_kwargs={"force_jupyter": True}))

# %%
# inputs and activations (outputs) in bfloat16
# parameters in float32
model_bfloat16 = MLPClassifier(dtype=jnp.bfloat16)
print(model_bfloat16.tabulate(rngs, x, train=True, console_kwargs={"force_jupyter": True}))
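# %% [markdown]
# a small sanity check, not part of the original notebook: with `dtype=jnp.bfloat16` only the
# *computation* runs in bfloat16, while the parameters are still created in float32
# (flax's `param_dtype` defaults to float32), and the explicit `astype` brings the output
# back to float32

# %%
variables = model_bfloat16.init(rngs, x, train=False)
# expect float32 for every kernel/bias/scale leaf
pprint(jax.tree.map(lambda p: p.dtype, variables["params"]))
# output is float32 because the model casts back before log_softmax
print(model_bfloat16.apply(variables, x, train=False).dtype)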
# MARK: GRADIENT CHECKPOINT
# %% [markdown]
# # gradient checkpoint
#
# in jax this is implemented with the remat function
#
# practical notes on remat: https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#practical-notes

# %%
def gelu(x: jax.Array) -> jax.Array:
    """GeLU activation function with approximate tanh."""
    # This will be printed once every time the function is executed.
    jax.debug.print("Executing GeLU")
    # See https://arxiv.org/abs/1606.08415 for details.
    x3 = jnp.power(x, 3)
    tanh_input = np.sqrt(2 / np.pi) * (x + 0.044715 * x3)
    return 0.5 * x * (1 + jnp.tanh(tanh_input))


def loss_fn(x: jax.Array, remat: bool) -> jax.Array:
    act_fn = gelu
    if remat:
        act_fn = jax.remat(act_fn)
    return jnp.mean(act_fn(x))


x = jax.random.normal(jax.random.PRNGKey(0), (100,))
grad_fn = jax.grad(loss_fn)
# with remat, the activation function is recomputed during the backward pass
_ = grad_fn(x, remat=True)
# forward-only call without remat: gelu runs just once, nothing is recomputed
_ = loss_fn(x, remat=False)
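# %% [markdown]
# a minimal sketch of selective checkpointing (an addition, not from the original demo):
# `jax.remat` accepts a `policy` that controls which intermediates are saved; here only the
# results of matmuls without batch dimensions are kept, and everything else (e.g. the GeLU
# output) is recomputed on the backward pass; `mlp_block` and the `toy_*` names are made up
# for this illustration

# %%
def mlp_block(params: Dict[str, jax.Array], inputs: jax.Array) -> jax.Array:
    # two matmuls with the GeLU from above in between
    hidden = gelu(inputs @ params["w1"])
    return jnp.mean(hidden @ params["w2"])


mlp_block_remat = jax.remat(
    mlp_block, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable
)
toy_params = {
    "w1": jax.random.normal(jax.random.PRNGKey(2), (128, 256)),
    "w2": jax.random.normal(jax.random.PRNGKey(3), (256, 1)),
}
toy_x = jax.random.normal(jax.random.PRNGKey(4), (32, 128))
_ = jax.grad(mlp_block_remat)(toy_params, toy_x)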
# MARK: GRADIENT ACCUMULATION
# %% [markdown]
# # gradient accumulation
#
# run many minibatches and accumulate their gradients before feeding them to the optimizer,
# as if there were one large batch

# %%
class TrainState(train_state.TrainState):
    rng: jax.Array


@dataclass
class Batch:
    inputs: jax.Array
    labels: jax.Array


# %%
# nothing special here, just a loss function
def classification_loss_fn(
    params: PyTree, apply_fn: Any, batch: Batch, rng: jax.Array
) -> Tuple[PyTree, Metrics]:
    """Classification loss function with cross-entropy."""
    logits = apply_fn({"params": params}, batch.inputs, train=True, rngs={"dropout": rng})
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch.labels)
    correct_pred = jnp.equal(jnp.argmax(logits, axis=-1), batch.labels)
    batch_size = batch.inputs.shape[0]
    # step_metrics stores (sum, count) pairs so metrics can be averaged later
    step_metrics = {"loss": (loss.sum(), batch_size), "accuracy": (correct_pred.sum(), batch_size)}
    # the mean loss is the mathematical output of the function, i.e. the value being differentiated
    mean_loss = loss.mean()
    return mean_loss, step_metrics


# %%
# gradient accumulation with a python for-loop
def accumulate_gradients_loop(
    state: TrainState,
    batch: Batch,
    rng: jax.random.PRNGKey,
    num_minibatches: int,
    loss_fn: Callable,
) -> Tuple[PyTree, Metrics]:
    """Calculate gradients and metrics for a batch using gradient accumulation.

    Args:
        state: Current training state.
        batch: Full training batch.
        rng: Random number generator to use.
        num_minibatches: Number of minibatches to split the batch into. Equal to the number
            of gradient accumulation steps.
        loss_fn: Loss function to calculate gradients and metrics.

    Returns:
        Tuple with accumulated gradients and metrics over the minibatches.
    """
    batch_size = batch.inputs.shape[0]
    minibatch_size = batch_size // num_minibatches
    rngs = jax.random.split(rng, num_minibatches)
    # Define the gradient function for a single minibatch.
    # With has_aux=True, a tuple of ((value, auxiliary_data), gradient) is returned;
    # otherwise it returns (value, gradient), where value is the actual output
    # of the function, hence the name value_and_grad.
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    # Prepare loop variables.
    grads = None
    metrics = None
    for minibatch_idx in range(num_minibatches):
        with jax.named_scope(f"minibatch_{minibatch_idx}"):
            # Split the batch into minibatches.
            start = minibatch_idx * minibatch_size
            end = start + minibatch_size
            minibatch = jax.tree.map(lambda x: x[start:end], batch)  # noqa: B023
            # Calculate gradients and metrics for the minibatch.
            # The ignored value is the mean loss of the minibatch.
            (_, step_metrics), step_grads = grad_fn(
                state.params, state.apply_fn, minibatch, rngs[minibatch_idx]
            )
            # Accumulate gradients and metrics across minibatches.
            if grads is None:
                grads = step_grads
                metrics = step_metrics
            else:
                grads = jax.tree.map(jnp.add, grads, step_grads)
                metrics = jax.tree.map(jnp.add, metrics, step_metrics)
    # Average gradients over minibatches.
    grads = jax.tree.map(lambda g: g / num_minibatches, grads)
    return grads, metrics


# %%
# jax.lax.scan implementation
#
# pros: much faster to compile than the unrolled python loop
# cons: execution can be slightly slower, and minibatches must be sliced with dynamic_slice
def accumulate_gradients_scan(
    state: TrainState,
    batch: Batch,
    rng: jax.random.PRNGKey,
    num_minibatches: int,
    loss_fn: Callable,
) -> Tuple[PyTree, Metrics]:
    """Calculate gradients and metrics for a batch using gradient accumulation.

    In this version, we use `jax.lax.scan` to loop over the minibatches. This is more efficient
    in terms of compilation time.

    Args:
        state: Current training state.
        batch: Full training batch.
        rng: Random number generator to use.
        num_minibatches: Number of minibatches to split the batch into. Equal to the number
            of gradient accumulation steps.
        loss_fn: Loss function to calculate gradients and metrics.

    Returns:
        Tuple with accumulated gradients and metrics over the minibatches.
    """
    batch_size = batch.inputs.shape[0]
    minibatch_size = batch_size // num_minibatches
    rngs = jax.random.split(rng, num_minibatches)
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

    def _minibatch_step(minibatch_idx: jax.Array | int) -> Tuple[PyTree, Metrics]:
        """Determine gradients and metrics for a single minibatch."""
        minibatch = jax.tree.map(
            # jax.lax.dynamic_slice_in_dim(operand, start_index, slice_size, axis=0)
            # is roughly equivalent to the following Python indexing syntax applied
            # along the specified axis: operand[..., start_index:start_index + slice_size].
            lambda x: jax.lax.dynamic_slice_in_dim(  # Slicing with a variable (traced) index.
                x, start_index=minibatch_idx * minibatch_size, slice_size=minibatch_size, axis=0
            ),
            batch,
        )
        (_, step_metrics), step_grads = grad_fn(
            state.params, state.apply_fn, minibatch, rngs[minibatch_idx]
        )
        return step_grads, step_metrics

    # the step function we hand to scan
    def _scan_step(
        carry: Tuple[PyTree, Metrics], minibatch_idx: jax.Array | int
    ) -> Tuple[Tuple[PyTree, Metrics], None]:
        """Scan step function for looping over minibatches."""
        step_grads, step_metrics = _minibatch_step(minibatch_idx)
        # the carry is a tuple of (grads pytree, metrics) and acts as the running accumulator
        carry = jax.tree.map(jnp.add, carry, (step_grads, step_metrics))
        # jax.lax.scan expects (carry, y); we have no per-step output y
        return carry, None

    # Determine initial shapes for gradients and metrics.
    grads_shapes, metrics_shape = jax.eval_shape(_minibatch_step, 0)
    grads = jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), grads_shapes)
    metrics = jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), metrics_shape)
    # Loop over minibatches to determine gradients and metrics.
    # jax.lax.scan(f, init, xs=None, length=None, reverse=False, unroll=1, _split_transpose=False)
    # purpose: scan a function over leading array axes while carrying along state,
    # in other words, a functional for-loop
    # why? the whole loop is lowered to a single XLA While op instead of being unrolled,
    # which keeps compilation time low
    # equivalent python code semantics:
    # def scan(f, init, xs, length=None):
    #     if xs is None:
    #         xs = [None] * length
    #     carry = init
    #     ys = []
    #     for x in xs:
    #         carry, y = f(carry, x)
    #         ys.append(y)
    #     return carry, np.stack(ys)
    # note: usually the stacked ys are the interesting output and the carry is hidden state;
    # here it is the other way around, the carry holds the accumulated (grads, metrics)
    (grads, metrics), _ = jax.lax.scan(
        _scan_step, init=(grads, metrics), xs=jnp.arange(num_minibatches), length=num_minibatches
    )
    # Average gradients over minibatches.
    grads = jax.tree.map(lambda g: g / num_minibatches, grads)
    return grads, metrics
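# %% [markdown]
# a tiny, self-contained illustration (added here, not in the original) of the `jax.lax.scan`
# carry semantics used above: the carry is the running accumulator and the per-step output
# `y` is `None`, exactly like in `_scan_step`

# %%
def _running_sum(carry: jax.Array, x_i: jax.Array) -> Tuple[jax.Array, None]:
    return carry + x_i, None  # (new_carry, y)


final_sum, _ = jax.lax.scan(
    _running_sum, init=jnp.zeros(()), xs=jnp.arange(4, dtype=jnp.float32)
)
print(final_sum)  # 0 + 1 + 2 + 3 = 6.0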
# %%
def accumulate_gradients(*args, use_scan: bool = False, **kwargs) -> Tuple[PyTree, Metrics]:
    if use_scan:
        return accumulate_gradients_scan(*args, **kwargs)
    else:
        return accumulate_gradients_loop(*args, **kwargs)


# %%
def train_step(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
    num_minibatches: int,
) -> Tuple[TrainState, Metrics]:
    """Training step function.

    Executes a full training step with gradient accumulation.

    Args:
        state: Current training state.
        metrics: Current metrics, accumulated from previous training steps.
        batch: Training batch.
        num_minibatches: Number of minibatches to split the batch into. Equal to the number
            of gradient accumulation steps.

    Returns:
        Tuple with updated training state (parameters, optimizer state, etc.) and metrics.
    """
    # Split the random number generator for the current step.
    rng, step_rng = jax.random.split(state.rng)
    # Determine gradients and metrics for the full batch.
    grads, step_metrics = accumulate_gradients(
        # use_scan is hard-coded here: a traced boolean cannot drive Python control flow
        # inside a jitted function; it would have to be a static argument instead
        # (a configurable variant is sketched after this cell)
        state, batch, step_rng, num_minibatches, loss_fn=classification_loss_fn, use_scan=True
    )
    # Optimizer step.
    new_state = state.apply_gradients(grads=grads, rng=rng)
    # Accumulate metrics across training steps.
    if metrics is None:
        metrics = step_metrics
    else:
        metrics = jax.tree.map(jnp.add, metrics, step_metrics)
    return new_state, metrics
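# %% [markdown]
# sketch of the configurable variant referenced above (assumption: not part of the original
# notebook): if `use_scan` should remain selectable, it can be exposed as a *static* argument
# of the jitted step, so each boolean value gets its own compiled version instead of being traced

# %%
def train_step_configurable(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
    num_minibatches: int,
    use_scan: bool,
) -> Tuple[TrainState, Metrics]:
    """Same as train_step, but with use_scan exposed as a static argument."""
    rng, step_rng = jax.random.split(state.rng)
    grads, step_metrics = accumulate_gradients(
        state, batch, step_rng, num_minibatches,
        loss_fn=classification_loss_fn, use_scan=use_scan,
    )
    new_state = state.apply_gradients(grads=grads, rng=rng)
    metrics = step_metrics if metrics is None else jax.tree.map(jnp.add, metrics, step_metrics)
    return new_state, metrics


train_step_configurable_jit = jax.jit(
    train_step_configurable, static_argnames=("num_minibatches", "use_scan")
)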
# %%
batch_size = 512
num_inputs = 128
num_classes = 100
rng_seed = 0
rng = jax.random.PRNGKey(rng_seed)
data_input_rng, data_label_rng, model_rng, state_rng = jax.random.split(rng, 4)
batch = Batch(
    inputs=jax.random.normal(data_input_rng, (batch_size, num_inputs)),
    labels=jax.random.randint(data_label_rng, (batch_size,), 0, num_classes),
)

# %%
# Zero dropout for checking later equality between training with and without gradient accumulation.
model = MLPClassifier(dtype=jnp.bfloat16, dropout_rate=0.0)
params = model.init(model_rng, batch.inputs, train=False)["params"]
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.adam(1e-3),
    rng=state_rng,
)

# %%
# jax.eval_shape(fun, *args, **kwargs)
# computes the shape/dtype of fun's output without spending any FLOPs
# the commented-out version fails because eval_shape abstracts num_minibatches into a
# traced value, while the accumulation code needs it as a concrete Python int
# _, metric_shapes = jax.eval_shape(
#     train_step,  # fun
#     state,  # train state
#     None,  # metrics
#     batch,  # batch
#     4,  # num_minibatches
# )
_, metric_shapes = jax.eval_shape(
    # binding num_minibatches via functools.partial keeps it static, so this works
    functools.partial(train_step, num_minibatches=4),
    state,  # train state
    None,  # metrics
    batch,  # batch
)
print("Metric shapes:")
pprint(metric_shapes)

# %%
# num_minibatches must be a static argument: the compiled function is cached per value,
# and every new value triggers a recompile
train_step_jit = jax.jit(
    train_step,
    static_argnames="num_minibatches",
)


# %%
def train_with_minibatches(
    state: TrainState,
    batch: Batch,
    num_minibatches: int,
    num_train_steps: int,
) -> Tuple[TrainState, Metrics]:
    """Small helper function for training loop."""
    train_metrics = jax.tree.map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
    for _ in range(num_train_steps):
        state, train_metrics = train_step_jit(state, train_metrics, batch, num_minibatches)
    return state, train_metrics


# %%
def print_metrics(metrics: Metrics, title: str | None = None) -> None:
    """Prints metrics with an optional title."""
    metrics = jax.device_get(metrics)
    lines = [f"{k}: {v[0] / v[1]:.6f}" for k, v in metrics.items()]
    if title:
        title = f" {title} "
        max_len = max(len(title), max(map(len, lines)))
        lines = [title.center(max_len, "=")] + lines
    print("\n".join(lines))


# %%
state_mini1, metrics_mini1 = train_with_minibatches(
    state, batch, num_minibatches=1, num_train_steps=4
)
state_mini4, metrics_mini4 = train_with_minibatches(
    state, batch, num_minibatches=4, num_train_steps=4
)
print_metrics(metrics_mini1, "Minibatch 1")
print_metrics(metrics_mini4, "Minibatch 4")

# %% [markdown]
# # donate_buffers
# jax is functional, so jitted functions are pass-by-value: input buffers are never mutated
# buffer donation lets XLA reuse the memory of selected inputs for the outputs,
# which is as close to pass-by-reference as we get
# what can be donated? only arguments whose old values we will never read again
# this is usually true for model parameters, optimizer state, and accumulated metrics,
# since after an update we only use the returned new_state and new_metrics

# %%
train_step_donated = jax.jit(
    train_step,
    static_argnames="num_minibatches",
    donate_argnames=(
        "state",
        "metrics",
    ),
)
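# %% [markdown]
# hedged usage sketch (added, not in the original): after calling the donated version, the
# buffers of the passed-in `state` and `metrics` may be invalidated, so only the returned
# values should be used from then on; backends that do not support donation simply ignore it
# with a warning

# %%
donated_metrics = jax.tree.map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
state, donated_metrics = train_step_donated(state, donated_metrics, batch, num_minibatches=4)
print_metrics(donated_metrics, "Donated step")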