# MARK: import

# %% [markdown]
# # single gpu optimizations

import os

# os.environ["XLA_FLAGS"] = (
#     "--xla_gpu_enable_triton_softmax_fusion=true "
#     "--xla_gpu_triton_gemm_any=false "
#     "--xla_gpu_enable_async_collectives=true "
#     "--xla_gpu_enable_latency_hiding_scheduler=true "
#     "--xla_gpu_enable_highest_priority_async_stream=true "
# )
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# %%
import functools
from pprint import pprint
from typing import Any, Callable, Dict, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.struct import dataclass
from flax.training import train_state

# Type aliases
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
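# each metric is stored as a (sum, count) pair, e.g. {"loss": (loss_sum, batch_size)},
# so values can be summed across minibatches and only averaged when printed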

# %% [markdown]
# # bfloat16 mixed precision compute

# %%
class MLPClassifier(nn.Module):
    dtype: Any  # we set the dtype here for computation; params stay in float32
    hidden_size: int = 256
    num_classes: int = 100
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x: jax.Array, train: bool) -> jax.Array:
        x = nn.Dense(
            features=self.hidden_size,
            dtype=self.dtype,  # computation in specified dtype, params stay in float32
        )(x)
        x = nn.LayerNorm(dtype=self.dtype)(x)
        x = nn.silu(x)
        x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
        x = nn.Dense(
            features=self.num_classes,
            dtype=self.dtype,
        )(x)
        # cast back to float32 before the numerically sensitive softmax
        x = x.astype(jnp.float32)
        x = nn.log_softmax(x, axis=-1)
        return x

# %%
x = jnp.ones((512, 128), dtype=jnp.float32)
rngs = {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)}
model_float32 = MLPClassifier(dtype=jnp.float32)
model_float32.tabulate(rngs, x, train=True, console_kwargs={"force_jupyter": True})


# %%
# inputs and activations (outputs) in bfloat16
# parameters in float32
model_bfloat16 = MLPClassifier(dtype=jnp.bfloat16)
model_bfloat16.tabulate(rngs, x, train=True, console_kwargs={"force_jupyter": True})
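
# %%
# quick sanity check (illustrative sketch, not from the original notebook):
# parameters are initialized in float32 even for the bfloat16 model, while the
# output of `apply` is float32 only because of the explicit cast before log_softmax.
_check_params = model_bfloat16.init(rngs, x, train=False)["params"]
pprint(jax.tree.map(lambda p: p.dtype, _check_params))
_logits = model_bfloat16.apply({"params": _check_params}, x, train=False)
print("output dtype:", _logits.dtype)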

# MARK: GRADIENT CHECKPOINT
# %% [markdown]
# # gradient checkpoint
#
# in jax this is implemented with jax.remat (an alias of jax.checkpoint)
#
# practical notes on remat: https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#practical-notes

# %%
def gelu(x: jax.Array) -> jax.Array:
    """GeLU activation function with approximate tanh."""
    # This will be printed every time the function body is executed (including recomputation).
    jax.debug.print("Executing GeLU")
    # See https://arxiv.org/abs/1606.08415 for details.
    x3 = jnp.power(x, 3)
    tanh_input = np.sqrt(2 / np.pi) * (x + 0.044715 * x3)
    return 0.5 * x * (1 + jnp.tanh(tanh_input))


def loss_fn(x: jax.Array, remat: bool) -> jax.Array:
    act_fn = gelu
    if remat:
        act_fn = jax.remat(act_fn)
    return jnp.mean(act_fn(x))


x = jax.random.normal(jax.random.PRNGKey(0), (100,))
grad_fn = jax.grad(loss_fn)
# with remat, the activation is recomputed during the backward pass,
# so "Executing GeLU" is printed twice (forward pass + recomputation)
_ = grad_fn(x, remat=True)

# a plain forward pass without remat involves no recomputation: it prints once
_ = loss_fn(x, remat=False)
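
# %%
# sketch (assumption: illustrating the policy argument, which the notebook does
# not use): jax.remat also accepts a rematerialization policy, so you can choose
# what gets saved vs. recomputed instead of recomputing everything.
def policy_loss_fn(x: jax.Array) -> jax.Array:
    act_fn = jax.remat(
        gelu, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable
    )
    return jnp.mean(act_fn(x))

# GeLU is still recomputed here, since it contains no dot products to save
_ = jax.grad(policy_loss_fn)(x)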

# MARK: GRADIENT ACCUMULATION
# %% [markdown]
# # gradient accumulation
#
# run many minibatches and accumulate their gradients before feeding them to the
# optimizer, as if they had come from one large batch
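
# %%
# quick numerical sanity check (illustrative sketch, not part of the training code):
# for a mean loss over equally sized minibatches, averaging the per-minibatch
# gradients reproduces the gradient of the full batch.
_w = jnp.ones((4,))
_xs = jax.random.normal(jax.random.PRNGKey(42), (8, 4))
_full_grad = jax.grad(lambda w: jnp.mean(_xs @ w))(_w)
_mini_grads = [jax.grad(lambda w: jnp.mean(x_mb @ w))(_w) for x_mb in jnp.split(_xs, 4)]
print(jnp.allclose(_full_grad, sum(_mini_grads) / 4))  # True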

# %%
class TrainState(train_state.TrainState):
    rng: jax.Array


@dataclass
class Batch:
    inputs: jax.Array
    labels: jax.Array

# %%
# nothing special here, just a loss function
def classification_loss_fn(
    params: PyTree, apply_fn: Any, batch: Batch, rng: jax.Array
) -> Tuple[PyTree, Metrics]:
    """Classification loss function with cross-entropy."""
    logits = apply_fn({"params": params}, batch.inputs, train=True, rngs={"dropout": rng})
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch.labels)
    correct_pred = jnp.equal(jnp.argmax(logits, axis=-1), batch.labels)
    batch_size = batch.inputs.shape[0]
    # step_metrics stores (sum, count) pairs so they can be accumulated across minibatches
    step_metrics = {"loss": (loss.sum(), batch_size), "accuracy": (correct_pred.sum(), batch_size)}
    # the differentiated output of the function is the mean loss
    mean_loss = loss.mean()
    return mean_loss, step_metrics

# %%
# gradient accumulation training loop
def accumulate_gradients_loop(
    state: TrainState,
    batch: Batch,
    rng: jax.random.PRNGKey,
    num_minibatches: int,
    loss_fn: Callable,
) -> Tuple[PyTree, Metrics]:
    """Calculate gradients and metrics for a batch using gradient accumulation.

    Args:
        state: Current training state.
        batch: Full training batch.
        rng: Random number generator to use.
        num_minibatches: Number of minibatches to split the batch into. Equal to the number of gradient accumulation steps.
        loss_fn: Loss function to calculate gradients and metrics.

    Returns:
        Tuple with accumulated gradients and metrics over the minibatches.
    """
    batch_size = batch.inputs.shape[0]
    minibatch_size = batch_size // num_minibatches
    rngs = jax.random.split(rng, num_minibatches)
    # Define gradient function for a single minibatch.
    # With has_aux=True, a tuple of ((value, auxiliary_data), gradient) is returned;
    # otherwise it returns (value, gradient), where "value" is the actual output
    # of the function, hence the name value_and_grad.
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    # Prepare loop variables.
    grads = None
    metrics = None
    for minibatch_idx in range(num_minibatches):
        with jax.named_scope(f"minibatch_{minibatch_idx}"):
            # Split the batch into minibatches.
            start = minibatch_idx * minibatch_size
            end = start + minibatch_size
            minibatch = jax.tree.map(lambda x: x[start:end], batch)  # noqa: B023
            # Calculate gradients and metrics for the minibatch.
            # The discarded value is the mean loss of the minibatch.
            (_, step_metrics), step_grads = grad_fn(
                state.params, state.apply_fn, minibatch, rngs[minibatch_idx]
            )
            # Accumulate gradients and metrics across minibatches.
            if grads is None:
                grads = step_grads
                metrics = step_metrics
            else:
                # element-wise accumulation
                grads = jax.tree.map(jnp.add, grads, step_grads)
                metrics = jax.tree.map(jnp.add, metrics, step_metrics)
    # Average gradients over minibatches.
    grads = jax.tree.map(lambda g: g / num_minibatches, grads)
    return grads, metrics

# %%
# jax.lax.scan implementation
#
# pros: faster compilation
# cons: can be slightly slower at runtime than the unrolled loop
def accumulate_gradients_scan(
    state: TrainState,
    batch: Batch,
    rng: jax.random.PRNGKey,
    num_minibatches: int,
    loss_fn: Callable,
) -> Tuple[PyTree, Metrics]:
    """Calculate gradients and metrics for a batch using gradient accumulation.

    In this version, we use `jax.lax.scan` to loop over the minibatches. This is more efficient in terms of compilation time.

    Args:
        state: Current training state.
        batch: Full training batch.
        rng: Random number generator to use.
        num_minibatches: Number of minibatches to split the batch into. Equal to the number of gradient accumulation steps.
        loss_fn: Loss function to calculate gradients and metrics.

    Returns:
        Tuple with accumulated gradients and metrics over the minibatches.
    """
    batch_size = batch.inputs.shape[0]
    minibatch_size = batch_size // num_minibatches
    rngs = jax.random.split(rng, num_minibatches)
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

    def _minibatch_step(minibatch_idx: jax.Array | int) -> Tuple[PyTree, Metrics]:
        """Determine gradients and metrics for a single minibatch."""
        minibatch = jax.tree.map(
            # jax.lax.dynamic_slice_in_dim(operand, start_index, slice_size, axis=0)
            # is roughly equivalent to the Python indexing syntax
            # operand[..., start_index:start_index + slice_size] applied along the
            # specified axis, but supports a variable (traced) start index.
            lambda x: jax.lax.dynamic_slice_in_dim(
                x,
                start_index=minibatch_idx * minibatch_size,
                slice_size=minibatch_size,
                axis=0,
            ),
            batch,
        )
        (_, step_metrics), step_grads = grad_fn(
            state.params, state.apply_fn, minibatch, rngs[minibatch_idx]
        )
        return step_grads, step_metrics

    # the function we hand to scan
    def _scan_step(
        carry: Tuple[PyTree, Metrics], minibatch_idx: jax.Array | int
    ) -> Tuple[Tuple[PyTree, Metrics], None]:
        """Scan step function for looping over minibatches."""
        step_grads, step_metrics = _minibatch_step(minibatch_idx)
        # the carry is a (grads, metrics) tuple and acts as the accumulator
        carry = jax.tree.map(jnp.add, carry, (step_grads, step_metrics))
        # jax.lax.scan expects a (carry, y) pair, but we have no per-step output y
        return carry, None

    # Determine initial shapes for gradients and metrics.
    grads_shapes, metrics_shape = jax.eval_shape(_minibatch_step, 0)
    grads = jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), grads_shapes)
    metrics = jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), metrics_shape)
    # Loop over minibatches to determine gradients and metrics.
    # jax.lax.scan(f, init, xs=None, length=None, reverse=False, unroll=1, _split_transpose=False)
    # purpose: scan a function over leading array axes while carrying along state;
    # in other words, a functional for-loop.
    # why? the whole loop lowers to a single XLA While op, which compiles much
    # faster than an unrolled Python loop.
    # equivalent Python semantics:
    #
    #   def scan(f, init, xs, length=None):
    #       if xs is None:
    #           xs = [None] * length
    #       carry = init
    #       ys = []
    #       for x in xs:
    #           carry, y = f(carry, x)
    #           ys.append(y)
    #       return carry, np.stack(ys)
    #
    # note: usually the stacked ys are the output we care about and the carry is
    # hidden state; here it is the other way around.
    (grads, metrics), _ = jax.lax.scan(
        _scan_step,
        init=(grads, metrics),
        xs=jnp.arange(num_minibatches),
        length=num_minibatches,
    )
    # Average gradients over minibatches.
    grads = jax.tree.map(lambda g: g / num_minibatches, grads)
    return grads, metrics
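
# %%
# tiny standalone illustration of jax.lax.scan (sketch, not used elsewhere):
# the carry holds a running sum, and the stacked ys collect each step's output.
_carry, _ys = jax.lax.scan(lambda c, x_i: (c + x_i, c + x_i), init=0.0, xs=jnp.arange(5.0))
print(_carry, _ys)  # 10.0 [ 0.  1.  3.  6. 10.]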

# %%
def accumulate_gradients(*args, use_scan: bool = False, **kwargs) -> Tuple[PyTree, Metrics]:
    if use_scan:
        return accumulate_gradients_scan(*args, **kwargs)
    else:
        return accumulate_gradients_loop(*args, **kwargs)

# %%
def train_step(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
    num_minibatches: int,
) -> Tuple[TrainState, Metrics]:
    """Training step function.

    Executes a full training step with gradient accumulation.

    Args:
        state: Current training state.
        metrics: Current metrics, accumulated from previous training steps.
        batch: Training batch.
        num_minibatches: Number of minibatches to split the batch into. Equal to the number of gradient accumulation steps.

    Returns:
        Tuple with updated training state (parameters, optimizer state, etc.) and metrics.
    """
    # Split the random number generator for the current step.
    rng, step_rng = jax.random.split(state.rng)
    # Determine gradients and metrics for the full batch.
    grads, step_metrics = accumulate_gradients(
        # use_scan is a plain Python bool baked in at trace time; passing it as a
        # traced argument into a jitted function and branching on it would fail
        state, batch, step_rng, num_minibatches, loss_fn=classification_loss_fn, use_scan=True
    )
    # Optimizer step.
    new_state = state.apply_gradients(grads=grads, rng=rng)
    # Accumulate metrics across training steps.
    if metrics is None:
        metrics = step_metrics
    else:
        metrics = jax.tree.map(jnp.add, metrics, step_metrics)
    return new_state, metrics
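
# %%
# tiny illustration (sketch) of the comment above: Python `if` on a traced
# boolean raises inside jit, which is why use_scan is hard-coded rather than
# passed through as a jitted argument.
try:
    jax.jit(lambda flag: 1.0 if flag else 0.0)(True)
except jax.errors.ConcretizationTypeError as e:
    print("cannot branch on a traced value:", type(e).__name__)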

# %%
batch_size = 512
num_inputs = 128
num_classes = 100
rng_seed = 0

rng = jax.random.PRNGKey(rng_seed)
data_input_rng, data_label_rng, model_rng, state_rng = jax.random.split(rng, 4)
batch = Batch(
    inputs=jax.random.normal(data_input_rng, (batch_size, num_inputs)),
    labels=jax.random.randint(data_label_rng, (batch_size,), 0, num_classes),
)

# Zero dropout for checking later equality between training with and without gradient accumulation.
model = MLPClassifier(dtype=jnp.bfloat16, dropout_rate=0.0)
params = model.init(model_rng, batch.inputs, train=False)["params"]
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.adam(1e-3),
    rng=state_rng,
)

# %%
# jax.eval_shape(fun, *args, **kwargs)
# computes the output shapes/dtypes of fun without spending any FLOPs

# this fails because eval_shape treats every argument as abstract, so
# num_minibatches is no longer a concrete Python int and the Python-level
# control flow (loop bounds, rng split counts) cannot be resolved
# _, metric_shapes = jax.eval_shape(
#     train_step,  # fun
#     state,  # train state
#     None,  # metrics
#     batch,  # batch
#     4,  # num_minibatches
# )

_, metric_shapes = jax.eval_shape(
    # binding num_minibatches as a static Python value works
    functools.partial(train_step, num_minibatches=4),
    state,  # train state
    None,  # metrics
    batch,  # batch
)

print("Metric shapes:")
pprint(metric_shapes)
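
# %%
# small illustration (sketch) of jax.eval_shape on its own: only
# ShapeDtypeStruct objects go in and come out, no computation is executed.
print(jax.eval_shape(
    lambda a, b: a @ b,
    jax.ShapeDtypeStruct((512, 128), jnp.float32),
    jax.ShapeDtypeStruct((128, 100), jnp.float32),
))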

# %%
# compilation-cache note: marking num_minibatches as static means jax keeps one
# compiled executable per value of num_minibatches and recompiles only when the
# value changes
train_step_jit = jax.jit(
    train_step,
    # treat num_minibatches as a static argument
    static_argnames="num_minibatches",
)
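
# %%
# rough timing sketch (illustrative, numbers depend on hardware): the first call
# for a given num_minibatches pays the compilation cost, a repeated call reuses
# the cached executable, and a new static value (2) triggers a fresh compile.
import time

for nm in (4, 4, 2):
    start = time.perf_counter()
    _s, _m = train_step_jit(state, None, batch, nm)
    jax.block_until_ready(_m)
    print(f"num_minibatches={nm}: {time.perf_counter() - start:.3f}s")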

# %%
def train_with_minibatches(
    state: TrainState,
    batch: Batch,
    num_minibatches: int,
    num_train_steps: int,
) -> Tuple[TrainState, Metrics]:
    """Small helper function for training loop."""
    train_metrics = jax.tree.map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
    for _ in range(num_train_steps):
        state, train_metrics = train_step_jit(state, train_metrics, batch, num_minibatches)
    return state, train_metrics

# %%
def print_metrics(metrics: Metrics, title: str | None = None) -> None:
    """Prints metrics with an optional title."""
    metrics = jax.device_get(metrics)
    lines = [f"{k}: {v[0] / v[1]:.6f}" for k, v in metrics.items()]
    if title:
        title = f" {title} "
        max_len = max(len(title), max(map(len, lines)))
        lines = [title.center(max_len, "=")] + lines
    print("\n".join(lines))

# %%
state_mini1, metrics_mini1 = train_with_minibatches(
    state, batch, num_minibatches=1, num_train_steps=4
)
state_mini4, metrics_mini4 = train_with_minibatches(
    state, batch, num_minibatches=4, num_train_steps=4
)
print_metrics(metrics_mini1, "Minibatch 1")
print_metrics(metrics_mini4, "Minibatch 4")
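
# %%
# optional sanity check (sketch, motivated by the zero-dropout comment above):
# with dropout disabled, training with 1 vs 4 minibatches should produce nearly
# identical parameters, up to bfloat16 rounding noise.
param_diffs = jax.tree.map(
    lambda a, b: jnp.max(jnp.abs(a - b)), state_mini1.params, state_mini4.params
)
pprint(jax.device_get(param_diffs))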

# %% [markdown]
# # donate_buffers
# jax performs pass-by-value due to its functional nature;
# for certain arguments we can instead donate the input buffers (pass by reference)
#
# what can be donated?
# only arguments whose old values we are sure will never be used again.
# this is usually true for model parameters and optimizer state: after an update
# we only ever touch the new values (e.g. new_state and new_metrics), so the old
# buffers can be reused in place

# %%
train_step_donated = jax.jit(
    train_step,
    static_argnames="num_minibatches",
    donate_argnames=(
        "state",
        "metrics",
    ),
)
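
# %%
# usage sketch (assumption: we simply keep training from the last state): after
# a donated call the old state/metrics buffers may be invalidated, so always
# rebind the names to the returned values.
metrics = jax.tree.map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
for _ in range(4):
    state, metrics = train_step_donated(state, metrics, batch, 4)
print_metrics(metrics, "donated buffers")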