# learn_jax/parallel/single_gpu_optimizations.py

# MARK: import
# %% [markdown]
# # Single-GPU optimizations
import os
# Optional XLA GPU performance flags. Uncomment to experiment with them; like
# CUDA_VISIBLE_DEVICES below, they need to be in the environment before JAX
# initializes its GPU backend.
# os.environ["XLA_FLAGS"] = (
#     "--xla_gpu_enable_triton_softmax_fusion=true "
#     "--xla_gpu_triton_gemm_any=false "
#     "--xla_gpu_enable_async_collectives=true "
#     "--xla_gpu_enable_latency_hiding_scheduler=true "
#     "--xla_gpu_enable_highest_priority_async_stream=true "
# )

# Restrict this process to GPU 0. This runs before `import jax`, so JAX only
# ever sees that single device.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# %%
import functools
from pprint import pprint
from typing import Any, Callable, Dict, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.struct import dataclass
from flax.training import train_state
# Type aliases used in the signatures below.
# PyTree: any nested container (dict/list/tuple/dataclass) of arrays, e.g. params or grads.
PyTree = Any
# Metrics: metric name -> tuple of accumulated values. In this file every tuple is a
# (sum, count) pair, so the reported value is sum / count (see print_metrics).
Metrics = Dict[str, Tuple[jax.Array, ...]]
# %% [mardown]
# # bfloat16 mixed precision compute
class MLPClassifier(nn.Module):
    """Two-layer MLP classifier that returns float32 log-probabilities.

    Layout: Dense(hidden_size) -> LayerNorm -> SiLU -> Dropout -> Dense(num_classes)
    -> float32 log_softmax.

    Attributes:
        dtype: Computation dtype of the Dense and LayerNorm layers (e.g. jnp.bfloat16
            for mixed precision). Parameters keep flax's default float32 param dtype.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output classes.
        dropout_rate: Dropout probability; dropout is only active when ``train=True``.
    """

    dtype: Any  # computation dtype; parameters are not affected
    hidden_size: int = 256
    num_classes: int = 100
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x: jax.Array, train: bool) -> jax.Array:
        hidden = nn.Dense(features=self.hidden_size, dtype=self.dtype)(x)
        hidden = nn.silu(nn.LayerNorm(dtype=self.dtype)(hidden))
        hidden = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(hidden)
        logits = nn.Dense(features=self.num_classes, dtype=self.dtype)(hidden)
        # Upcast before normalizing so the softmax runs in full precision even when
        # the rest of the network computes in bfloat16.
        return nn.log_softmax(logits.astype(jnp.float32), axis=-1)
# %%
# Dummy float32 input of 512 examples x 128 features, plus the PRNG keys flax needs:
# "params" for initialization and "dropout" because the model is run with train=True.
x = jnp.ones((512, 128), dtype=jnp.float32)
rngs = {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)}
# Reference model: computation and parameters both in float32.
model_float32 = MLPClassifier(dtype=jnp.float32)
# Per-layer summary table (shown as the cell output when run in a notebook).
model_float32.tabulate(rngs, x, train=True, console_kwargs={"force_jupyter": True})
# %%
# Mixed precision: dtype=bfloat16 sets the *computation* dtype. The float32 input is
# cast to bfloat16 inside the layers and activations (layer outputs) are bfloat16,
# while parameters stay float32. The logits are cast back to float32 before log_softmax.
model_bfloat16 = MLPClassifier(dtype=jnp.bfloat16)
model_bfloat16.tabulate(rngs, x, train=True, console_kwargs={"force_jupyter": True})
# MARK: GRADIENT CHECKPOINT
# %% [markdown]
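# # Gradient checkpointing
#
# In JAX this is implemented with `jax.checkpoint` (also exported as `jax.remat`).
# Normally, intermediate activations are stored during the forward pass for use in the backward pass.
# A checkpointed function instead recomputes them during the backward pass,
# trading extra compute for lower memory use.
#
# Practical notes on remat: https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#practical-notes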
def gelu(x: jax.Array) -> jax.Array:
    """Tanh-approximated GeLU activation (https://arxiv.org/abs/1606.08415).

    gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))

    A message is emitted through ``jax.debug.print`` every time the function is
    executed. This makes recomputation under ``jax.remat`` visible.
    """
    jax.debug.print("Executing GeLU")
    inner = np.sqrt(2 / np.pi) * (x + 0.044715 * jnp.power(x, 3))
    return 0.5 * x * (1 + jnp.tanh(inner))
def loss_fn(x: jax.Array, remat: bool) -> jax.Array:
    """Return mean(gelu(x)), optionally wrapping gelu in ``jax.remat``.

    Args:
        x: Input array.
        remat: If True, gelu is checkpointed, so its forward computation is redone
            during the backward pass instead of storing its intermediates.
    """
    activation = jax.remat(gelu) if remat else gelu
    return jnp.mean(activation(x))
# Compare the backward pass with and without rematerialization. `remat` is a plain
# Python keyword argument: jax.grad only differentiates the first positional argument.
x = jax.random.normal(jax.random.PRNGKey(0), (100,))
grad_fn = jax.grad(loss_fn)
# With remat, GeLU's forward computation is redone during the backward pass, so
# "Executing GeLU" is expected to be printed twice.
_ = grad_fn(x, remat=True)
# Without remat, the forward intermediates are stored and reused, so
# "Executing GeLU" is expected to be printed only once. This must call grad_fn too:
# calling loss_fn would only run the forward pass and compare nothing.
_ = grad_fn(x, remat=False)
#MARK: GRADIENT ACCUMULATION
# %% [markdown]
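# # Gradient accumulation
#
# Split a large batch into several equal-size minibatches and compute the gradient of each one in turn.
# Then average these gradients and apply a single optimizer update.
# For a mean loss, this matches one update on the full batch,
# while peak activation memory is only that of one minibatch.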
# %%
class TrainState(train_state.TrainState):
    """Flax ``TrainState`` that additionally carries a PRNG key.

    ``train_step`` splits this key once per step: one half seeds the step, and the
    other is stored back into the new state.
    """

    # PRNG key carried across training steps.
    rng: jax.Array
@dataclass
class Batch:
    """A batch of examples, registered as a pytree by ``flax.struct.dataclass``.

    Because it is a pytree, ``jax.tree.map`` can slice both fields at once when
    splitting the batch into minibatches.
    """

    # Model inputs; the leading axis indexes examples.
    inputs: jax.Array
    # Integer class labels, one per example.
    labels: jax.Array
# %%
# Standard cross-entropy loss; it also returns (sum, count) metrics for accumulation.
def classification_loss_fn(
    params: PyTree, apply_fn: Any, batch: Batch, rng: jax.Array
) -> Tuple[PyTree, Metrics]:
    """Cross-entropy classification loss with accuracy metrics.

    Args:
        params: Model parameters.
        apply_fn: Flax ``apply`` function of the model.
        batch: Inputs and integer labels.
        rng: PRNG key used for dropout.

    Returns:
        ``(mean_loss, metrics)``. ``mean_loss`` is the scalar that gets
        differentiated. ``metrics`` maps "loss" and "accuracy" to ``(sum, count)``
        pairs, so they can be summed across minibatches and steps and divided later.
    """
    # The model already returns log-probabilities; optax re-applies log_softmax,
    # which leaves log-probabilities unchanged.
    logits = apply_fn(
        {"params": params}, batch.inputs, train=True, rngs={"dropout": rng}
    )
    per_example_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits, batch.labels
    )
    predictions = jnp.argmax(logits, axis=-1)
    is_correct = jnp.equal(predictions, batch.labels)
    num_examples = batch.inputs.shape[0]
    metrics = {
        "loss": (per_example_loss.sum(), num_examples),
        "accuracy": (is_correct.sum(), num_examples),
    }
    return per_example_loss.mean(), metrics
# %%
# Gradient accumulation with a plain Python loop (unrolled when traced under jit).
def accumulate_gradients_loop(
    state: TrainState,
    batch: Batch,
    rng: jax.Array,
    num_minibatches: int,
    loss_fn: Callable,
) -> Tuple[PyTree, Metrics]:
    """Calculate gradients and metrics for a batch using gradient accumulation.

    The batch is split along its leading axis into ``num_minibatches`` equal-size
    minibatches. Because the loop is an ordinary Python loop, tracing under
    ``jax.jit`` unrolls it: the loss/gradient computation is traced once per
    minibatch.

    Args:
        state: Current training state (provides ``params`` and ``apply_fn``).
        batch: Full training batch.
        rng: PRNG key. It is split into one key per minibatch, and each key is
            passed to ``loss_fn``.
        num_minibatches: Number of minibatches to split the batch into, i.e. the
            number of gradient accumulation steps. Must be a concrete Python int.
        loss_fn: Function ``(params, apply_fn, minibatch, rng) -> (loss, metrics)``.

    Returns:
        Tuple ``(grads, metrics)``. ``grads`` is the gradients averaged over the
        minibatches; ``metrics`` is the per-minibatch metrics summed element-wise.

    Raises:
        ValueError: If ``num_minibatches`` is smaller than 1 or does not divide the
            batch size.
    """
    batch_size = batch.inputs.shape[0]
    # Equal-size minibatches are required. Averaging per-minibatch *mean* gradients
    # equals the full-batch gradient only when all minibatches are the same size.
    # Without this check, a remainder would be silently dropped.
    if num_minibatches < 1 or batch_size % num_minibatches != 0:
        raise ValueError(
            f"Batch size {batch_size} must be divisible by num_minibatches "
            f"(got num_minibatches={num_minibatches})."
        )
    minibatch_size = batch_size // num_minibatches
    rngs = jax.random.split(rng, num_minibatches)
    # With has_aux=True, grad_fn returns ((loss, aux), grads); here aux is the metrics.
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

    grads = None
    metrics = None
    for minibatch_idx in range(num_minibatches):
        with jax.named_scope(f"minibatch_{minibatch_idx}"):
            start = minibatch_idx * minibatch_size
            end = start + minibatch_size
            # Slice every leaf of the batch pytree. jax.tree.map calls the lambda
            # immediately, so capturing the loop variables is safe.
            minibatch = jax.tree.map(lambda x: x[start:end], batch)  # noqa: B023
            # The discarded value is the minibatch mean loss; metrics carry the sums.
            (_, step_metrics), step_grads = grad_fn(
                state.params, state.apply_fn, minibatch, rngs[minibatch_idx]
            )
            if grads is None:
                grads, metrics = step_grads, step_metrics
            else:
                grads = jax.tree.map(jnp.add, grads, step_grads)
                metrics = jax.tree.map(jnp.add, metrics, step_metrics)
    # Mean of the equal-size minibatch gradients = full-batch gradient of a mean loss.
    grads = jax.tree.map(lambda g: g / num_minibatches, grads)
    return grads, metrics
# %%
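# jax.lax.scan implementation
#
# pros: the minibatch step is traced and compiled once, so compile time does not
#       grow with num_minibatches
# cons: XLA cannot optimize across iterations the way it can with the unrolled
#       Python loop, so each training step may run somewhat slower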
def accumulate_gradients_scan(
    state: TrainState,
    batch: Batch,
    rng: jax.Array,
    num_minibatches: int,
    loss_fn: Callable,
) -> Tuple[PyTree, Metrics]:
    """Calculate gradients and metrics for a batch using gradient accumulation.

    The minibatch loop runs inside ``jax.lax.scan``, so the per-minibatch step is
    traced once rather than once per minibatch. This keeps compile time independent
    of ``num_minibatches``.

    Args:
        state: Current training state (provides ``params`` and ``apply_fn``).
        batch: Full training batch.
        rng: PRNG key. It is split into one key per minibatch, and each key is
            passed to ``loss_fn``.
        num_minibatches: Number of minibatches to split the batch into, i.e. the
            number of gradient accumulation steps. Must be a concrete Python int.
        loss_fn: Function ``(params, apply_fn, minibatch, rng) -> (loss, metrics)``.

    Returns:
        Tuple ``(grads, metrics)``. ``grads`` is the gradients averaged over the
        minibatches; ``metrics`` is the per-minibatch metrics summed element-wise.

    Raises:
        ValueError: If ``num_minibatches`` is smaller than 1 or does not divide the
            batch size.
    """
    batch_size = batch.inputs.shape[0]
    # Equal-size minibatches are required. Averaging per-minibatch *mean* gradients
    # equals the full-batch gradient only when all minibatches are the same size.
    # Without this check, a remainder would be silently dropped.
    if num_minibatches < 1 or batch_size % num_minibatches != 0:
        raise ValueError(
            f"Batch size {batch_size} must be divisible by num_minibatches "
            f"(got num_minibatches={num_minibatches})."
        )
    minibatch_size = batch_size // num_minibatches
    rngs = jax.random.split(rng, num_minibatches)
    # With has_aux=True, grad_fn returns ((loss, aux), grads); here aux is the metrics.
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

    def _minibatch_step(minibatch_idx: jax.Array | int) -> Tuple[PyTree, Metrics]:
        """Return (grads, metrics) for the minibatch at position ``minibatch_idx``."""
        # Inside scan the index is a tracer, so Python slicing x[a:b] is not allowed.
        # dynamic_slice_in_dim(x, start, size, axis=0) is the traced equivalent of
        # x[start:start + size] along the leading axis.
        minibatch = jax.tree.map(
            lambda x: jax.lax.dynamic_slice_in_dim(
                x,
                start_index=minibatch_idx * minibatch_size,
                slice_size=minibatch_size,
                axis=0,
            ),
            batch,
        )
        (_, step_metrics), step_grads = grad_fn(
            state.params, state.apply_fn, minibatch, rngs[minibatch_idx]
        )
        return step_grads, step_metrics

    def _scan_step(
        carry: Tuple[PyTree, Metrics], minibatch_idx: jax.Array | int
    ) -> Tuple[Tuple[PyTree, Metrics], None]:
        """Scan body: add this minibatch's (grads, metrics) to the running totals."""
        step_grads, step_metrics = _minibatch_step(minibatch_idx)
        carry = jax.tree.map(jnp.add, carry, (step_grads, step_metrics))
        # scan expects (new_carry, per-iteration output); nothing is stacked here.
        return carry, None

    # The carry needs a concrete initial value. eval_shape finds the output
    # shapes/dtypes of one step without doing any FLOPs, and the carry starts as zeros.
    grads_shapes, metrics_shapes = jax.eval_shape(_minibatch_step, 0)
    grads = jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), grads_shapes)
    metrics = jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), metrics_shapes)

    # jax.lax.scan is a functional for-loop with the following Python semantics:
    #   carry = init
    #   for x in xs:
    #       carry, y = f(carry, x)   # the ys would be stacked; here y is None
    #   return carry, stacked_ys
    # It lowers to a single XLA while loop whose body is compiled only once.
    (grads, metrics), _ = jax.lax.scan(
        _scan_step,
        init=(grads, metrics),
        xs=jnp.arange(num_minibatches),
        length=num_minibatches,
    )
    # Mean of the equal-size minibatch gradients = full-batch gradient of a mean loss.
    grads = jax.tree.map(lambda g: g / num_minibatches, grads)
    return grads, metrics
# %%
def accumulate_gradients(*args, use_scan: bool = False, **kwargs) -> Tuple[PyTree, Metrics]:
    """Run gradient accumulation with the scan or the Python-loop implementation.

    All positional and keyword arguments except ``use_scan`` are forwarded unchanged.
    ``use_scan`` has to be a concrete Python bool, because it decides which code is
    traced.
    """
    implementation = accumulate_gradients_scan if use_scan else accumulate_gradients_loop
    return implementation(*args, **kwargs)
# %%
def train_step(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
    num_minibatches: int,
) -> Tuple[TrainState, Metrics]:
    """Run one full training step with gradient accumulation.

    Args:
        state: Current training state.
        metrics: Metrics accumulated over previous steps, or None on the first step.
        batch: Training batch.
        num_minibatches: Number of minibatches to split the batch into, i.e. the
            number of gradient accumulation steps. Must be static under ``jax.jit``.

    Returns:
        Tuple of the updated training state (parameters, optimizer state, PRNG key)
        and the metrics with this step's values added.
    """
    # One key for this step's minibatches; the other becomes the next state's key.
    next_rng, step_rng = jax.random.split(state.rng)
    # use_scan is a hard-coded Python bool. It picks which code path is traced, so it
    # cannot be an ordinary (traced) jit argument; it would have to be marked static.
    grads, step_metrics = accumulate_gradients(
        state,
        batch,
        step_rng,
        num_minibatches,
        loss_fn=classification_loss_fn,
        use_scan=True,
    )
    new_state = state.apply_gradients(grads=grads, rng=next_rng)
    total_metrics = (
        step_metrics
        if metrics is None
        else jax.tree.map(jnp.add, metrics, step_metrics)
    )
    return new_state, total_metrics
# %%
# Synthetic problem used to compare training with and without gradient accumulation.
batch_size = 512
num_inputs = 128
num_classes = 100
rng_seed = 0
rng = jax.random.PRNGKey(rng_seed)
# Independent keys for the inputs, the labels, parameter init, and the TrainState key.
data_input_rng, data_label_rng, model_rng, state_rng = jax.random.split(rng, 4)
# A single random batch: standard-normal inputs and integer labels in [0, num_classes).
batch = Batch(
    inputs=jax.random.normal(data_input_rng, (batch_size, num_inputs)),
    labels=jax.random.randint(data_label_rng, (batch_size,), 0, num_classes),
)
# Zero dropout for checking later equality between training with and without gradient accumulation.
# NOTE: the model uses its own default num_classes=100; it is not wired to the
# num_classes variable above, so keep the two in sync if either changes.
model = MLPClassifier(dtype=jnp.bfloat16, dropout_rate=0.0)
params = model.init(model_rng, batch.inputs, train=False)["params"]
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.adam(1e-3),
    rng=state_rng,
)
# %%
# jax.eval_shape(fun, *args, **kwargs) computes the output shapes/dtypes of `fun`
# abstractly, without running any FLOPs.
#
# Passing num_minibatches positionally, as in
#     jax.eval_shape(train_step, state, None, batch, 4)
# does NOT work. eval_shape traces every argument as an abstract value, so
# num_minibatches becomes a tracer. Anything that needs a concrete int then fails,
# e.g. jax.random.split(rng, num_minibatches) or the minibatch slice size.
# Binding it with functools.partial keeps it a plain Python int.
shape_probe = functools.partial(train_step, num_minibatches=4)
_, metric_shapes = jax.eval_shape(shape_probe, state, None, batch)
print("Metric shapes:")
pprint(metric_shapes)
# %%
# num_minibatches is static: it controls Python-level structure (the loop count, the
# scan length and the slice sizes), so it must be a concrete int while tracing.
# jit caches one compiled executable per distinct value. Switching between, e.g.,
# 1 and 4 compiles each variant once and reuses it afterwards.
train_step_jit = jax.jit(train_step, static_argnames=("num_minibatches",))
# %%
def train_with_minibatches(
    state: TrainState,
    batch: Batch,
    num_minibatches: int,
    num_train_steps: int,
) -> Tuple[TrainState, Metrics]:
    """Run ``num_train_steps`` jitted training steps on the same batch.

    Metrics start from zeros shaped like the module-level ``metric_shapes`` and are
    accumulated across all steps.

    Returns:
        The final training state and the accumulated metrics.
    """
    running_metrics = jax.tree.map(
        lambda spec: jnp.zeros(spec.shape, dtype=spec.dtype), metric_shapes
    )
    for _step in range(num_train_steps):
        state, running_metrics = train_step_jit(
            state, running_metrics, batch, num_minibatches
        )
    return state, running_metrics
# %%
def print_metrics(metrics: Metrics, title: str | None = None) -> None:
    """Print each metric as its accumulated value divided by its count.

    Args:
        metrics: Mapping from metric name to a tuple whose first two entries are
            (accumulated sum, count). Device arrays are fetched to the host first.
        title: Optional heading, centered in a line of ``=`` that is as wide as the
            longest printed line.
    """
    host_metrics = jax.device_get(metrics)
    lines = [f"{name}: {values[0] / values[1]:.6f}" for name, values in host_metrics.items()]
    if title:
        title = f" {title} "
        # default=0 keeps this working when there are no metrics to print
        # (max() of an empty sequence would raise ValueError).
        width = max(len(title), max(map(len, lines), default=0))
        lines = [title.center(width, "="), *lines]
    print("\n".join(lines))
# %%
# Train the same initial state on the same batch for 4 steps, twice. The first run
# uses the full batch (1 minibatch). The second uses 4 accumulated minibatches of
# 512 / 4 = 128 examples each. Dropout is disabled, so the two runs are expected to
# report closely matching metrics (averaged over all 4 steps). Small differences
# are plausible from bfloat16 compute and a different summation order.
state_mini1, metrics_mini1 = train_with_minibatches(
    state, batch, num_minibatches=1, num_train_steps=4
)
state_mini4, metrics_mini4 = train_with_minibatches(
    state, batch, num_minibatches=4, num_train_steps=4
)
print_metrics(metrics_mini1, "Minibatch 1")
print_metrics(metrics_mini4, "Minibatch 4")
# %% [markdown]
# # donate_buffers
# jax perform pass by value due to its functional nature
# we can do pass by reference for certain arguments
# what can be donated?
# we can only do this if we are sure that arguments will not be used
# this is usually true for model parameters and optimizer state
# since we have totally new values, and we won't use the argument values anymore
# after an update (e.g. we will use new_state and new_metrics)
# Same jitted train step, but the buffers of `state` and `metrics` are donated so
# XLA can write the outputs in place. After a call, the passed-in state/metrics
# arrays are invalidated; always continue with the returned values.
train_step_donated = jax.jit(
    train_step,
    static_argnames=("num_minibatches",),
    donate_argnames=("state", "metrics"),
)