Feat: fsdp demo
Refactor: pulling dataloader code into dataload.py
parent 005a1a5735
commit ad5cf7735f
@@ -0,0 +1 @@
__pycache__
@@ -0,0 +1,172 @@
# %%
|
||||||
|
# Prepare dataloader for jax training
|
||||||
|
from datasets import Dataset, DatasetDict, Value, Sequence, load_from_disk
|
||||||
|
from transformers import FlaxT5ForConditionalGeneration
|
||||||
|
from datasets import ClassLabel, Value, Sequence
|
||||||
|
from ml_collections import ConfigDict
|
||||||
|
import numpy as np
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jax
|
||||||
|
import math
|
||||||
|
from typing import Optional, List, Tuple, Callable, cast
|
||||||
|
|
||||||
|
|
||||||
|
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
|
||||||
|
# file_path = 'combined_data'
|
||||||
|
# split_datasets = load_from_disk(file_path)
|
||||||
|
# training_size = len(split_datasets['train'])
|
||||||
|
|
||||||
|
from transformers import T5TokenizerFast
|
||||||
|
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
|
||||||
|
# Define additional special tokens
|
||||||
|
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
|
||||||
|
# Add the additional special tokens to the tokenizer
|
||||||
|
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
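# quick check (optional sketch): the added special tokens should now map to their own
# ids at the end of the vocabulary
print("vocab size after adding special tokens:", len(tokenizer))
print("<THING_START> id:", tokenizer.convert_tokens_to_ids("<THING_START>"))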
|
||||||
|
|
||||||
|
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
|
||||||
|
|
||||||
|
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
|
||||||
|
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right") # noqa: B009
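# For reference, a minimal sketch of the behaviour expected from `shift_tokens_right`
# for T5-style models (this mirrors what the dynamically imported function does; it is
# not the HuggingFace implementation itself): prepend `decoder_start_token_id`, drop
# the last token, and replace any -100 label padding with `pad_token_id`.
def _shift_tokens_right_reference(input_ids: np.ndarray, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    shifted = np.zeros_like(input_ids)
    shifted[:, 1:] = input_ids[:, :-1]
    shifted[:, 0] = decoder_start_token_id
    return np.where(shifted == -100, pad_token_id, shifted)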
|
||||||
|
|
||||||
|
|
||||||
|
# DataPrepare takes in a raw dataset and a config, and prepares tokenized training data
class DataPrepare:
|
||||||
|
|
||||||
|
def __init__(self, raw_dataset, config):
|
||||||
|
self.raw_dataset: Dataset = raw_dataset
|
||||||
|
self.train_dataset: Optional[Dataset] = None
|
||||||
|
self.size: int = len(raw_dataset)
|
||||||
|
self.config: ConfigDict = config
|
||||||
|
|
||||||
|
self.make_dataset()
|
||||||
|
|
||||||
|
# In Flax, seq2seq models need `decoder_input_ids` to be passed explicitly:
# the Flax models don't accept `labels`, so we prepare the decoder_input_ids here.
# For that, the `shift_tokens_right` function was dynamically imported from the model file above.
|
||||||
|
|
||||||
|
# given a dataset entry, run it through the tokenizer
|
||||||
|
# Setting padding="max_length" as we need fixed length inputs for jitted functions
|
||||||
|
def preprocess_function(self, example: Dataset):
|
||||||
|
inputs = example['input']
|
||||||
|
targets = example['output']
|
||||||
|
# `text_target` tokenizes the targets so they can be used as labels directly;
# there is no need to build a separate 'labels' field by hand.
# Below we produce input_ids, labels, and decoder_input_ids.
|
||||||
|
model_inputs = tokenizer(
|
||||||
|
inputs,
|
||||||
|
max_length=self.config.max_length,
|
||||||
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="np"
|
||||||
|
)
|
||||||
|
labels = tokenizer(
|
||||||
|
text_target=targets,
|
||||||
|
max_length=self.config.max_length,
|
||||||
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="np"
|
||||||
|
)
|
||||||
|
|
||||||
|
# for loss computation
|
||||||
|
model_inputs["labels"] = labels["input_ids"]
|
||||||
|
# make decoder input ids
|
||||||
|
decoder_input_ids = shift_tokens_right_fn(
|
||||||
|
labels["input_ids"], self.config.pad_token_id, self.config.decoder_start_token_id
|
||||||
|
)
|
||||||
|
# required by the model
|
||||||
|
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
|
||||||
|
# We need decoder_attention_mask so we can ignore pad tokens from loss
|
||||||
|
model_inputs["decoder_attention_mask"] = labels["attention_mask"]
|
||||||
|
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
def make_dataset(self):
|
||||||
|
train_dataset = self.raw_dataset.map(
|
||||||
|
self.preprocess_function,
|
||||||
|
batched=True,
|
||||||
|
num_proc=1,
|
||||||
|
# if we do not remove, we keep the original data
|
||||||
|
remove_columns=self.raw_dataset.column_names,)
|
||||||
|
|
||||||
|
# set to numpy
|
||||||
|
train_dataset.set_format(
|
||||||
|
type='numpy',
|
||||||
|
columns=[
|
||||||
|
'input_ids', 'attention_mask',
|
||||||
|
'labels', 'decoder_input_ids',
|
||||||
|
'decoder_attention_mask']
|
||||||
|
)
|
||||||
|
|
||||||
|
# check that data fits
|
||||||
|
for name in ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'decoder_attention_mask']:
|
||||||
|
int_array: np.ndarray = train_dataset[name]
|
||||||
|
if np.all((int_array >= 0) & (int_array <= 65535)):
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
raise ValueError("Values are out of range for uint16")
|
||||||
|
|
||||||
|
# change to compact datatypes
|
||||||
|
features = train_dataset.features.copy()
|
||||||
|
features['input_ids'] = Sequence(Value('uint16'))
|
||||||
|
features['attention_mask'] = Sequence(Value('bool'))
|
||||||
|
features['labels'] = Sequence(Value('uint16'))
|
||||||
|
features['decoder_input_ids'] = Sequence(Value('uint16'))
|
||||||
|
features['decoder_attention_mask'] = Sequence(Value('bool'))
|
||||||
|
train_dataset = train_dataset.cast(features)
|
||||||
|
# assign the dataset to train_dataset
|
||||||
|
self.train_dataset = train_dataset
|
||||||
|
|
||||||
|
def data_loader(self, rng: jax.random.PRNGKey, batch_size: int, shuffle: bool = False, drop_last=True):
|
||||||
|
"""
|
||||||
|
Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
|
||||||
|
and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
|
||||||
|
"""
|
||||||
|
assert self.train_dataset is not None
|
||||||
|
dataset: Dataset = cast(Dataset, self.train_dataset)
|
||||||
|
|
||||||
|
if shuffle:
|
||||||
|
batch_idx = jax.random.permutation(rng, len(dataset))
|
||||||
|
batch_idx = np.asarray(batch_idx)
|
||||||
|
else:
|
||||||
|
batch_idx = np.arange(len(dataset))
|
||||||
|
|
||||||
|
if drop_last:
|
||||||
|
steps_per_epoch = len(dataset) // batch_size
|
||||||
|
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
|
||||||
|
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
||||||
|
else:
|
||||||
|
steps_per_epoch = math.ceil(len(dataset) / batch_size)
|
||||||
|
batch_idx = np.array_split(batch_idx, steps_per_epoch)
|
||||||
|
|
||||||
|
for idx in batch_idx:
|
||||||
|
batch = dataset[idx]
|
||||||
|
batch = {k: jnp.array(v) for k, v in batch.items()}
|
||||||
|
|
||||||
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
|
# testing out the class
|
||||||
|
# # %%
|
||||||
|
# # init object
|
||||||
|
# # e.g. Config
|
||||||
|
# data_config = ConfigDict(
|
||||||
|
# dict(
|
||||||
|
# max_length=86,
|
||||||
|
# pad_token_id=0,
|
||||||
|
# decoder_start_token_id=0
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# dataprep = DataPrepare(split_datasets, data_config)
|
||||||
|
#
|
||||||
|
# # %%
|
||||||
|
# seed = 117
|
||||||
|
# rng = jax.random.PRNGKey(seed)
|
||||||
|
# train_loader = dataprep.data_loader(rng, batch_size=32)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# # %%
|
||||||
|
# batch = next(iter(train_loader))
|
||||||
|
# batch['input_ids'].shape
|
||||||
|
# # %%
|
|
@@ -0,0 +1,754 @@
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# # Fully-Sharded Data Parallelism
|
||||||
|
|
||||||
|
|
||||||
|
# MARK: START
|
||||||
|
# %%
|
||||||
|
# let's make 8-device simulator
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Set this to True to run the model on CPU only.
|
||||||
|
USE_CPU_ONLY = True
|
||||||
|
|
||||||
|
flags = os.environ.get("XLA_FLAGS", "")
|
||||||
|
if USE_CPU_ONLY:
|
||||||
|
flags += " --xla_force_host_platform_device_count=8" # Simulate 8 devices
|
||||||
|
# Enforce CPU-only execution
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
||||||
|
else:
|
||||||
|
# GPU flags
|
||||||
|
flags += (
|
||||||
|
"--xla_gpu_enable_triton_softmax_fusion=true "
|
||||||
|
"--xla_gpu_triton_gemm_any=false "
|
||||||
|
"--xla_gpu_enable_async_collectives=true "
|
||||||
|
"--xla_gpu_enable_latency_hiding_scheduler=true "
|
||||||
|
"--xla_gpu_enable_highest_priority_async_stream=true "
|
||||||
|
)
|
||||||
|
os.environ["XLA_FLAGS"] = flags
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from pprint import pprint
|
||||||
|
from typing import Any, Dict, Tuple, Callable, Sequence
|
||||||
|
|
||||||
|
import flax.linen as nn
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
from jax.experimental.shard_map import shard_map
|
||||||
|
from jax.sharding import Mesh, NamedSharding
|
||||||
|
from jax.sharding import PartitionSpec as P
|
||||||
|
from ml_collections import ConfigDict
|
||||||
|
import optax
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
PyTree = Any
|
||||||
|
Metrics = Dict[str, Tuple[jax.Array, ...]]
|
||||||
|
jax.config.update('jax_platform_name', 'cpu')
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# required functions:
|
||||||
|
# Batch
|
||||||
|
# TrainState
|
||||||
|
# accumulate_gradients
|
||||||
|
# print_metrics
|
||||||
|
from single_gpu_optimizations import Batch, TrainState, accumulate_gradients, print_metrics
|
||||||
|
# %%
|
||||||
|
# the fold_rng_over_axis helper (defined here rather than imported)
|
||||||
|
|
||||||
|
def fold_rng_over_axis(rng: jax.random.PRNGKey, axis_name: str) -> jax.random.PRNGKey:
|
||||||
|
"""Folds the random number generator over the given axis.
|
||||||
|
|
||||||
|
This is useful for generating a different random number for each device
|
||||||
|
across a certain axis (e.g. the model axis).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rng: The random number generator.
|
||||||
|
axis_name: The axis name to fold the random number generator over.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new random number generator, different for each device index along the axis.
|
||||||
|
"""
|
||||||
|
axis_index = jax.lax.axis_index(axis_name)
|
||||||
|
return jax.random.fold_in(rng, axis_index)
|
||||||
|
|
||||||
|
# MARK: DATA PARALLELISM
|
||||||
|
# %% [markdown]
|
||||||
|
# # Data Parallelism
|
||||||
|
# we start with plain data parallelism
|
||||||
|
#
|
||||||
|
# using shard_map, we write single-device code and let shard map handle the rest
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# plain data parallel - sharding only data inputs and outputs
|
||||||
|
class DPClassifier(nn.Module):
|
||||||
|
# contains the attributes listed in config
|
||||||
|
# hidden_size
|
||||||
|
# dropout_rate
|
||||||
|
# dtype - for computation
|
||||||
|
# num_classes
|
||||||
|
# data_axis_name
|
||||||
|
config: ConfigDict
|
||||||
|
|
||||||
|
# note how there is no data_axis_name within the actual __call__
|
||||||
|
@nn.compact
|
||||||
|
def __call__(self, x: jax.Array, train: bool) -> jax.Array:
|
||||||
|
x = nn.Dense(
|
||||||
|
features=self.config.hidden_size,
|
||||||
|
dtype=self.config.dtype,
|
||||||
|
name="input_dense",
|
||||||
|
)(x)
|
||||||
|
x = nn.silu(x)
|
||||||
|
x = nn.Dropout(rate=self.config.dropout_rate, deterministic=not train)(x)
|
||||||
|
x = nn.Dense(
|
||||||
|
features=self.config.num_classes,
|
||||||
|
dtype=self.config.dtype,
|
||||||
|
name="output_dense",
|
||||||
|
)(x)
|
||||||
|
x = x.astype(jnp.float32)
|
||||||
|
return x
|
||||||
|
|
||||||
|
# config
|
||||||
|
data_config = ConfigDict(
|
||||||
|
dict(
|
||||||
|
batch_size=128,
|
||||||
|
num_classes=10,
|
||||||
|
input_size=784,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
model_config = ConfigDict(
|
||||||
|
dict(
|
||||||
|
hidden_size=512,
|
||||||
|
dropout_rate=0.1,
|
||||||
|
dtype=jnp.bfloat16,
|
||||||
|
num_classes=data_config.num_classes,
|
||||||
|
data_axis_name="data",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
optimizer_config = ConfigDict(
|
||||||
|
dict(
|
||||||
|
learning_rate=1e-3,
|
||||||
|
num_minibatches=4,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
config = ConfigDict(
|
||||||
|
dict(
|
||||||
|
model=model_config,
|
||||||
|
optimizer=optimizer_config,
|
||||||
|
data=data_config,
|
||||||
|
data_axis_name=model_config.data_axis_name,
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# initialize
|
||||||
|
model_dp = DPClassifier(config=config.model)
|
||||||
|
optimizer = optax.adamw(
|
||||||
|
learning_rate=config.optimizer.learning_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
# init rng
|
||||||
|
rng = jax.random.PRNGKey(config.seed)
|
||||||
|
# init model rng
|
||||||
|
model_init_rng, data_inputs_rng, data_labels_rng = jax.random.split(rng, 3)
|
||||||
|
# create synthetic data
|
||||||
|
batch = Batch(
|
||||||
|
inputs=jax.random.normal(data_inputs_rng, (config.data.batch_size, config.data.input_size)),
|
||||||
|
labels=jax.random.randint(
|
||||||
|
data_labels_rng, (config.data.batch_size,), 0, config.data.num_classes
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# init data_parallel TrainState state
|
||||||
|
def init_dp(rng: jax.random.PRNGKey, x: jax.Array, model: nn.Module) -> TrainState:
|
||||||
|
init_rng, rng = jax.random.split(rng)
|
||||||
|
variables = model.init({"params": init_rng}, x, train=False)
|
||||||
|
params = variables.pop("params")
|
||||||
|
state = TrainState.create(
|
||||||
|
apply_fn=model.apply,
|
||||||
|
params=params,
|
||||||
|
tx=optimizer,
|
||||||
|
rng=rng,
|
||||||
|
)
|
||||||
|
return state
|
||||||
|
|
||||||
|
# create mesh
|
||||||
|
device_array = np.array(jax.devices())
|
||||||
|
mesh = Mesh(device_array, (config.data_axis_name,))
|
||||||
|
|
||||||
|
# we are just sharding the same model across devices
|
||||||
|
# no different from a flax replicate
|
||||||
|
init_dp_fn = jax.jit(
|
||||||
|
shard_map(
|
||||||
|
functools.partial(init_dp, model=model_dp),
|
||||||
|
mesh,
|
||||||
|
in_specs=(P(), P(config.data_axis_name)),
|
||||||
|
out_specs=P(),
|
||||||
|
check_rep=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
state_dp = init_dp_fn(model_init_rng, batch.inputs)
|
||||||
|
print("DP Parameters")
|
||||||
|
pprint(jax.tree.map(lambda x: (x.shape, x.sharding), state_dp.params))
|
||||||
|
|
||||||
|
# MARK: TRAIN STEP
|
||||||
|
# %%
|
||||||
|
# train step
|
||||||
|
def loss_fn(
|
||||||
|
params: PyTree, apply_fn: Any, batch: Batch, rng: jax.Array
|
||||||
|
) -> Tuple[jax.Array, Dict[str, Any]]:
|
||||||
|
|
||||||
|
# set different rng over various devices
|
||||||
|
dropout_rng = fold_rng_over_axis(rng, config.data_axis_name)
|
||||||
|
|
||||||
|
# Remaining computation is the same as before for single device.
|
||||||
|
logits = apply_fn(
|
||||||
|
{"params": params},
|
||||||
|
batch.inputs,
|
||||||
|
train=True,
|
||||||
|
rngs={"dropout": dropout_rng})
|
||||||
|
loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch.labels)
|
||||||
|
correct_pred = jnp.equal(jnp.argmax(logits, axis=-1), batch.labels)
|
||||||
|
batch_size = batch.inputs.shape[0]
|
||||||
|
step_metrics = {"loss": (loss.sum(), batch_size), "accuracy": (correct_pred.sum(), batch_size)}
|
||||||
|
loss = loss.mean()
|
||||||
|
return loss, step_metrics
|
||||||
|
|
||||||
|
# train step dp
|
||||||
|
# simple data parallel has the model on every device
|
||||||
|
# but each device has different data
|
||||||
|
def train_step_dp(
|
||||||
|
state: TrainState,
|
||||||
|
metrics: Metrics | None,
|
||||||
|
batch: Batch,
|
||||||
|
) -> Tuple[TrainState, Metrics]:
|
||||||
|
rng, step_rng = jax.random.split(state.rng)
|
||||||
|
# accumulate gradients like before
|
||||||
|
grads, step_metrics = accumulate_gradients(
|
||||||
|
state,
|
||||||
|
batch,
|
||||||
|
step_rng,
|
||||||
|
config.optimizer.num_minibatches,
|
||||||
|
loss_fn=loss_fn,
|
||||||
|
)
|
||||||
|
# Update parameters. We need to sync the gradients across devices before updating.
|
||||||
|
with jax.named_scope("sync_gradients"):
|
||||||
|
grads = jax.tree.map(
|
||||||
|
lambda g: jax.lax.pmean(
|
||||||
|
g, axis_name=config.data_axis_name),
|
||||||
|
grads)
|
||||||
|
new_state = state.apply_gradients(grads=grads, rng=rng)
|
||||||
|
|
||||||
|
# Sum metrics across replicas. Alternatively, we could keep the metrics separate
|
||||||
|
# and only synchronize them before logging. For simplicity, we sum them here.
|
||||||
|
with jax.named_scope("sync_metrics"):
|
||||||
|
step_metrics = jax.tree.map(
|
||||||
|
lambda x: jax.lax.psum(x, axis_name=config.data_axis_name), step_metrics
|
||||||
|
)
|
||||||
|
|
||||||
|
if metrics is None:
|
||||||
|
metrics = step_metrics
|
||||||
|
else:
|
||||||
|
# combine all the synced metrics
|
||||||
|
metrics = jax.tree.map(jnp.add, metrics, step_metrics)
|
||||||
|
|
||||||
|
return new_state, metrics
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# we will now wrap the train step with shard_map and jit it
|
||||||
|
# here we will be sharding input and output data
|
||||||
|
train_step_dp_fn = jax.jit(
|
||||||
|
shard_map(
|
||||||
|
train_step_dp,
|
||||||
|
mesh,
|
||||||
|
in_specs=(P(), P(), P(config.data_axis_name)),
|
||||||
|
out_specs=(P(), P()),
|
||||||
|
check_rep=False,
|
||||||
|
),
|
||||||
|
# state and metrics are consumed by this step and not re-used afterwards,
# so we donate their buffers to XLA to save memory
|
||||||
|
donate_argnames=("state", "metrics"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# get the metric_shapes so that we can init arrays for accumulation
|
||||||
|
_, metric_shapes = jax.eval_shape(
|
||||||
|
train_step_dp_fn,
|
||||||
|
state_dp,
|
||||||
|
None,
|
||||||
|
batch,
|
||||||
|
)
|
||||||
|
# init arrays with shape
|
||||||
|
metrics_dp = jax.tree.map(
|
||||||
|
lambda x: jnp.zeros(x.shape, dtype=x.dtype),
|
||||||
|
metric_shapes)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
start_time = time.time()
|
||||||
|
for _ in range(15):
|
||||||
|
state_dp, metrics_dp = train_step_dp_fn(state_dp, metrics_dp, batch)
|
||||||
|
duration = time.time() - start_time
|
||||||
|
print(duration)
|
||||||
|
|
||||||
|
final_metrics_dp = jax.tree.map(
|
||||||
|
lambda x: jnp.zeros(x.shape, dtype=x.dtype),
|
||||||
|
metric_shapes)
|
||||||
|
state_dp, final_metrics_dp = train_step_dp_fn(
|
||||||
|
state_dp,
|
||||||
|
final_metrics_dp,
|
||||||
|
batch)
|
||||||
|
print_metrics(final_metrics_dp)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
print("DP Parameters")
|
||||||
|
pprint(jax.tree.map(lambda x: (x.shape, x.sharding), state_dp.params))
|
||||||
|
print("Metrics")
|
||||||
|
pprint(jax.tree.map(lambda x: (x.shape, x.sharding), final_metrics_dp))
|
||||||
|
|
||||||
|
####################################################################
|
||||||
|
# everything up to this point works and is still equivalent to the
# flax replicate-style data parallelism used in HuggingFace examples
|
||||||
|
|
||||||
|
|
||||||
|
# MARK: PARAMETER SHARDING
|
||||||
|
# %% [markdown]
|
||||||
|
# # parameter sharding
|
||||||
|
# Basic strategy: init full parameters on each device, then use
|
||||||
|
# jax.lax.axis_index to split parameters across devices, and keep a shard on
|
||||||
|
# each device
|
||||||
|
#
|
||||||
|
# use nn.Partitioned to annotate sharding spec on parameters
|
||||||
|
# quite similar to PartitionSpec
|
||||||
|
#
|
||||||
|
# parameters are either jax.Array or a flax.linen.Partitioned
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# type annotation
|
||||||
|
Parameter = jax.Array | nn.Partitioned
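# %%
# a minimal illustration (sketch, runs outside the mesh): an nn.Partitioned leaf wraps
# a value together with one axis name per dimension, where None means "not sharded
# along that dimension". The tree maps below use is_leaf=... so the whole wrapper is
# treated as a single leaf instead of being traversed.
_example_tree = {
    "kernel": nn.Partitioned(value=jnp.zeros((8, 4)), names=("data", None)),
    "bias": jnp.zeros((4,)),
}
pprint(
    jax.tree.map(
        lambda p: p.names if isinstance(p, nn.Partitioned) else "replicated",
        _example_tree,
        is_leaf=lambda x: isinstance(x, nn.Partitioned),
    )
)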
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# function to shard parameters across devices
|
||||||
|
# look for an axis to equally split across the number of devices
|
||||||
|
# we can specify which parameters to shard, since they vary in size
|
||||||
|
# we set a floor on the size for sharding
|
||||||
|
@jax.named_scope("shard_params")
|
||||||
|
def shard_params(params: PyTree, axis_name: str, min_weight_size: int = 2**18) -> PyTree:
|
||||||
|
"""Shard parameters across the given mesh axis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: The parameters to shard.
|
||||||
|
axis_name: The axis to shard parameters across.
|
||||||
|
min_weight_size: The minimum size of a parameter to shard. Parameters with fewer values will not be sharded.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PyTree of same structure as params, but with leaves sharded over new axis if possible.
|
||||||
|
"""
|
||||||
|
# axis_index
|
||||||
|
axis_idx = jax.lax.axis_index(axis_name)
|
||||||
|
# number of units in the axis
|
||||||
|
axis_size = jax.lax.psum(1, axis_name)
|
||||||
|
|
||||||
|
# split function
|
||||||
|
# check each parameter if it had been sharded
|
||||||
|
def _split(x: Parameter) -> Parameter:
|
||||||
|
|
||||||
|
# already sharded
|
||||||
|
if isinstance(x, nn.Partitioned):
|
||||||
|
value, names = x.value, x.names
|
||||||
|
# not sharded
|
||||||
|
else:
|
||||||
|
value = x
|
||||||
|
names = (None,) * value.ndim
|
||||||
|
|
||||||
|
# logging only runs on first jit
|
||||||
|
# the checks below explain why a parameter may be left unsharded on this axis:
# it is already partitioned over this mesh axis,
# or it is too small to be worth sharding
|
||||||
|
if axis_name in names:
|
||||||
|
logging.warning(
|
||||||
|
f"Parameter {value.shape} with names {names} already sharded on axis {axis_name}."
|
||||||
|
)
|
||||||
|
return x
|
||||||
|
# check if the parameter is too small to shard
|
||||||
|
elif value.size <= min_weight_size:
|
||||||
|
logging.info(
|
||||||
|
f"Parameter {value.shape} with names {names} too small to shard, size {value.size} < {min_weight_size}."
|
||||||
|
)
|
||||||
|
return x
|
||||||
|
# let's start sharding!
|
||||||
|
else:
|
||||||
|
shape = value.shape
|
||||||
|
idx = np.argsort(shape)[::-1] # Shard along largest possible axis.
|
||||||
|
for i in idx:
|
||||||
|
# this loop effectively runs once because of the return below;
# we only shard an axis if it splits evenly across devices
# and is not already sharded
|
||||||
|
if shape[i] % axis_size == 0 and names[i] is None:
|
||||||
|
split_size = shape[i] // axis_size
|
||||||
|
p_sharded = nn.Partitioned(
|
||||||
|
value=jax.lax.dynamic_slice_in_dim( # Shard to keep on present device.
|
||||||
|
value,
|
||||||
|
axis_idx * split_size,
|
||||||
|
split_size,
|
||||||
|
axis=i
|
||||||
|
),
|
||||||
|
names=names[:i] + (axis_name,) + names[i + 1 :],
|
||||||
|
)
|
||||||
|
return p_sharded
|
||||||
|
|
||||||
|
logging.warning(
|
||||||
|
f"Could not shard {value.shape} with names {names} on axis {axis_name}, no suitable axis found."
|
||||||
|
)
|
||||||
|
return x
|
||||||
|
|
||||||
|
# we apply the _split function across the parameter pytree
|
||||||
|
return jax.tree.map(
|
||||||
|
_split,
|
||||||
|
params,
|
||||||
|
is_leaf=lambda x: isinstance(
|
||||||
|
x, nn.Partitioned
|
||||||
|
), # Consider a nn.Partitioned object as a leaf.
|
||||||
|
)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# function to gather parameters back to a single device
|
||||||
|
|
||||||
|
# but first we need create a custom function for mean gradient computation
|
||||||
|
# jax.lax.all_gather -> retrieve shards and assemble full array in each device
|
||||||
|
# jax.lax.psum_scatter -> scatter gradients back to respective devices
|
||||||
|
def gather_array_with_mean_grads(x: jax.Array, axis: int, axis_name: str):
|
||||||
|
"""Gathering with averaging gradients across replicas."""
|
||||||
|
axis_size = jax.lax.psum(1, axis_name)
|
||||||
|
|
||||||
|
# Define a custom gradient for the gather operation.
|
||||||
|
@jax.custom_gradient
|
||||||
|
def f(x):
|
||||||
|
# adjust backward to turn sum into mean of axis
|
||||||
|
def grad_fn(g):
|
||||||
|
# pmean_scatter from psum_scatter
|
||||||
|
# after computing from full gradient array, our shard only has a
|
||||||
|
# portion of the parameters, we only get the gradients associated
|
||||||
|
# with parameters of our shard
|
||||||
|
return (
|
||||||
|
jax.lax.psum_scatter(g, axis_name, scatter_dimension=axis, tiled=True) / axis_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# assemble shards to form full gradient array
|
||||||
|
return jax.lax.all_gather(x, axis_name, axis=axis, tiled=True), grad_fn
|
||||||
|
|
||||||
|
return f(x)
|
||||||
|
|
||||||
|
# gather params back - e.g. when computing a module forward call
|
||||||
|
# reverse operation of "shard_params"
|
||||||
|
# depends on: gather_array_with_mean_grads
|
||||||
|
@jax.named_scope("gather_params")
|
||||||
|
def gather_params(params: PyTree, axis_name: str) -> PyTree:
|
||||||
|
"""Gather parameters from all replicas across the given axis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: The parameters to gather.
|
||||||
|
axis_name: The axis to gather parameters across.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PyTree of same structure as params, but with leaves gathered if they were a nn.Partitioned object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _gather(p: Parameter) -> Parameter:
|
||||||
|
if isinstance(p, nn.Partitioned) and axis_name in p.names:
|
||||||
|
param_shard = p.names
|
||||||
|
shard_axis = param_shard.index(axis_name)
|
||||||
|
value = gather_array_with_mean_grads(p.value, axis=shard_axis, axis_name=axis_name)
|
||||||
|
|
||||||
|
# If there are any other axes that are sharded, we need to keep the partitioned structure.
|
||||||
|
# Otherwise, we can return the value directly.
|
||||||
|
param_shard = param_shard[:shard_axis] + (None,) + param_shard[shard_axis + 1 :]
|
||||||
|
if any([name is not None for name in param_shard]):
|
||||||
|
# we return the still-sharded axes shard
|
||||||
|
return nn.Partitioned(value, param_shard)
|
||||||
|
else:
|
||||||
|
return value
|
||||||
|
else:
|
||||||
|
return p
|
||||||
|
|
||||||
|
# we find all the sharded params and gather them, returning a complete parameter
|
||||||
|
return jax.tree.map(
|
||||||
|
_gather,
|
||||||
|
params,
|
||||||
|
is_leaf=lambda x: isinstance(x, nn.Partitioned))
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# when we call a module, we gather the parameters back to a single device
|
||||||
|
# wrap a module into a nn.map_variables transform
|
||||||
|
# allows for transforms on the parameter before and after a module call
|
||||||
|
# depends on: gather_params, shard_params
|
||||||
|
def shard_module_params(
|
||||||
|
target: nn.Module | Callable,
|
||||||
|
axis_name: str,
|
||||||
|
min_weight_size: int = 2**18 # 262,144
|
||||||
|
) -> nn.Module | Callable:
|
||||||
|
"""Shard parameters of a module across replicas.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target: The module to shard.
|
||||||
|
axis_name: The axis name to shard parameters across.
|
||||||
|
min_weight_size: The minimum size of a parameter to shard. Parameters with fewer values will not be sharded.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The module with sharded parameters.
|
||||||
|
"""
|
||||||
|
return nn.map_variables(
|
||||||
|
target,
|
||||||
|
trans_in_fn=functools.partial(
|
||||||
|
gather_params, axis_name=axis_name),
|
||||||
|
trans_out_fn=functools.partial(
|
||||||
|
shard_params, axis_name=axis_name, min_weight_size=min_weight_size
|
||||||
|
),
|
||||||
|
mapped_collections="params",
|
||||||
|
mutable=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# define new function with axes constraints
|
||||||
|
# this forms the template for sharding future modules
|
||||||
|
# remember, flax modules are subclassed from elementary flax modules
|
||||||
|
class FSDPClassifier(nn.Module):
|
||||||
|
config: ConfigDict
|
||||||
|
|
||||||
|
@nn.compact
|
||||||
|
def __call__(self, x: jax.Array, train: bool) -> jax.Array:
|
||||||
|
# create a sharded module
|
||||||
|
sharded_dense = shard_module_params(
|
||||||
|
nn.Dense,
|
||||||
|
axis_name=self.config.data_axis_name, # axes
|
||||||
|
min_weight_size=self.config.min_weight_size, # min_weight
|
||||||
|
)
|
||||||
|
x = sharded_dense(
|
||||||
|
features=self.config.hidden_size,
|
||||||
|
dtype=self.config.dtype,
|
||||||
|
name="input_dense",
|
||||||
|
)(x)
|
||||||
|
x = nn.silu(x)
|
||||||
|
x = nn.Dropout(rate=self.config.dropout_rate, deterministic=not train)(x)
|
||||||
|
x = sharded_dense(
|
||||||
|
features=self.config.num_classes,
|
||||||
|
dtype=self.config.dtype,
|
||||||
|
name="output_dense",
|
||||||
|
)(x)
|
||||||
|
x = x.astype(jnp.float32)
|
||||||
|
return x
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# initialization
|
||||||
|
config.model.min_weight_size = 2**4
|
||||||
|
model_fsdp = FSDPClassifier(config=config.model)
|
||||||
|
|
||||||
|
# the earlier init function
|
||||||
|
def init_dp(rng: jax.random.PRNGKey, x: jax.Array, model: nn.Module) -> TrainState:
|
||||||
|
init_rng, rng = jax.random.split(rng)
|
||||||
|
variables = model.init({"params": init_rng}, x, train=False)
|
||||||
|
params = variables.pop("params")
|
||||||
|
state = TrainState.create(
|
||||||
|
apply_fn=model.apply,
|
||||||
|
params=params,
|
||||||
|
tx=optimizer,
|
||||||
|
rng=rng,
|
||||||
|
)
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
# initialize our sharded model with mesh
|
||||||
|
# we need to adjust the shard map since partitioning is determined within the
|
||||||
|
# model init, hence we cannot manually specify it
|
||||||
|
#
|
||||||
|
# we do a hack where we just try and let it evaluate the shapes
|
||||||
|
# we set an unknown output specification - aka fully replicate
|
||||||
|
#
|
||||||
|
# we then get the partition_spec of the shapes of the parameters
|
||||||
|
init_fsdp_fn = shard_map(
|
||||||
|
functools.partial(init_dp, model=model_fsdp),
|
||||||
|
mesh,
|
||||||
|
# first P() is for model_init_rng
|
||||||
|
# second P(config.data_axis_name) is for batch.inputs
|
||||||
|
in_specs=(P(), P(config.data_axis_name)),
|
||||||
|
# not partitioned, fully replicated
|
||||||
|
out_specs=P(),
|
||||||
|
check_rep=False, # disable checks for replication errors in out_specs
|
||||||
|
)
|
||||||
|
state_fsdp_shapes = jax.eval_shape(init_fsdp_fn, model_init_rng, batch.inputs)
|
||||||
|
state_fsdp_specs = nn.get_partition_spec(state_fsdp_shapes)
|
||||||
|
# %% [raw]
|
||||||
|
# TrainState(step=PartitionSpec(), apply_fn=<bound method Module.apply of FSDPClassifier(
|
||||||
|
# # attributes
|
||||||
|
# config = data_axis_name: data
|
||||||
|
# dropout_rate: 0.1
|
||||||
|
# dtype: !!python/name:jax.numpy.bfloat16 ''
|
||||||
|
# hidden_size: 512
|
||||||
|
# min_weight_size: 16
|
||||||
|
# num_classes: 10
|
||||||
|
#
|
||||||
|
# )>,
|
||||||
|
# params={
|
||||||
|
# 'input_dense': {'bias': PartitionSpec('data',), 'kernel': PartitionSpec('data', None)},
|
||||||
|
# 'output_dense': {'bias': PartitionSpec(), 'kernel': PartitionSpec('data', None)}},
|
||||||
|
# tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x761e8ef00400>,
|
||||||
|
# update=<function chain.<locals>.update_fn at 0x761e8ef01080>),
|
||||||
|
# opt_state=(ScaleByAdamState(count=PartitionSpec(),
|
||||||
|
# mu={'input_dense': {'bias': PartitionSpec('data',), 'kernel': PartitionSpec('data', None)},
|
||||||
|
# 'output_dense': {'bias': PartitionSpec(), 'kernel': PartitionSpec('data', None)}},
|
||||||
|
# nu={'input_dense': {'bias': PartitionSpec('data',), 'kernel': PartitionSpec('data', None)},
|
||||||
|
# 'output_dense': {'bias': PartitionSpec(), 'kernel': PartitionSpec('data', None)}}),
|
||||||
|
# EmptyState(), EmptyState()), rng=PartitionSpec())
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# then from the state_fsdp_specs, we obtain our config
|
||||||
|
# this print clarifies why we could not specify the PartitionSpec earlier:
# which parameters get sharded is only decided at model init
|
||||||
|
print("RNG", state_fsdp_specs.rng)
|
||||||
|
print("\nParameters")
|
||||||
|
pprint(state_fsdp_specs.params)
|
||||||
|
print("\nOptimizer state")
|
||||||
|
pprint(state_fsdp_specs.opt_state[0])
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# init again, this time with the specs, knowing what should and should not
# be sharded
|
||||||
|
init_fsdp_fn = jax.jit(
|
||||||
|
shard_map(
|
||||||
|
functools.partial(init_dp, model=model_fsdp),
|
||||||
|
mesh,
|
||||||
|
in_specs=(P(), P(config.data_axis_name)),
|
||||||
|
out_specs=state_fsdp_specs,
|
||||||
|
check_rep=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
state_fsdp = init_fsdp_fn(model_init_rng, batch.inputs)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
print("FSDP Parameters")
|
||||||
|
pprint(jax.tree.map(lambda x: x.shape, jax.device_get(state_fsdp.params)))
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# train step
|
||||||
|
|
||||||
|
# we need to handle the sync of gradients
|
||||||
|
# some parameters are sharded, some are not
|
||||||
|
def sync_gradients(
|
||||||
|
grads: PyTree,
|
||||||
|
axis_names: Sequence[str],
|
||||||
|
) -> PyTree:
|
||||||
|
"""Synchronize gradients across devices.
|
||||||
|
|
||||||
|
Gradients for parameters that are replicated over a given axis are averaged across devices.
|
||||||
|
Parameters that are partitioned over a given axis are considered to already have a mean of
|
||||||
|
the gradients on each device, and hence do not need to be altered.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grads: The gradients to synchronize.
|
||||||
|
axis_names: The axis names to synchronize gradients across.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The gradients averaged over the specified axes if they are replicated.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def sync_grad(g: Parameter) -> Parameter:
|
||||||
|
if isinstance(g, nn.Partitioned):
|
||||||
|
# Tree leaves for flattening potentially nested axis (multiple names
|
||||||
|
# can exist for single array axis).
|
||||||
|
replication_axis_names = [
|
||||||
|
name for name in axis_names if name not in jax.tree_util.tree_leaves(g.names)
|
||||||
|
]
|
||||||
|
if len(replication_axis_names) == 0:
|
||||||
|
# Parameters partitioned over all axes.
|
||||||
|
return g
|
||||||
|
else:
|
||||||
|
# Average over remaining replicated axes.
|
||||||
|
return g.replace(value=jax.lax.pmean(g.value, axis_name=replication_axis_names))
|
||||||
|
else:
|
||||||
|
# Parameters are replicated over all axes.
|
||||||
|
return jax.lax.pmean(g, axis_name=axis_names)
|
||||||
|
|
||||||
|
return jax.tree.map(
|
||||||
|
sync_grad,
|
||||||
|
grads,
|
||||||
|
is_leaf=lambda x: isinstance(x, nn.Partitioned))
|
||||||
|
|
||||||
|
# %%
|
||||||
|
def train_step_fsdp(
|
||||||
|
state: TrainState,
|
||||||
|
metrics: Metrics,
|
||||||
|
batch: Batch,
|
||||||
|
) -> Tuple[TrainState, Metrics]:
|
||||||
|
rng, step_rng = jax.random.split(state.rng)
|
||||||
|
# accumulate gradients over the minibatches (forward and backward passes), as before
|
||||||
|
grads, step_metrics = accumulate_gradients(
|
||||||
|
state,
|
||||||
|
batch,
|
||||||
|
step_rng,
|
||||||
|
config.optimizer.num_minibatches,
|
||||||
|
loss_fn=loss_fn,
|
||||||
|
)
|
||||||
|
# Update parameters. We need to sync the gradients across devices before updating.
|
||||||
|
with jax.named_scope("sync_gradients"):
|
||||||
|
grads = sync_gradients(grads, (config.data_axis_name,))
|
||||||
|
# then update model
|
||||||
|
new_state = state.apply_gradients(grads=grads, rng=rng)
|
||||||
|
|
||||||
|
# Sum metrics across replicas. Alternatively, we could keep the metrics separate
|
||||||
|
# and only synchronize them before logging. For simplicity, we sum them here.
|
||||||
|
with jax.named_scope("sync_metrics"):
|
||||||
|
step_metrics = jax.tree.map(
|
||||||
|
lambda x: jax.lax.psum(x, axis_name=config.data_axis_name), step_metrics
|
||||||
|
)
|
||||||
|
if metrics is None:
|
||||||
|
metrics = step_metrics
|
||||||
|
else:
|
||||||
|
metrics = jax.tree.map(jnp.add, metrics, step_metrics)
|
||||||
|
return new_state, metrics
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# jit the train_step_fsdp
|
||||||
|
train_step_fsdp_fn = jax.jit(
|
||||||
|
shard_map(
|
||||||
|
train_step_fsdp,
|
||||||
|
mesh,
|
||||||
|
in_specs=(state_fsdp_specs, P(), P(config.data_axis_name)),
|
||||||
|
out_specs=(state_fsdp_specs, P()),
|
||||||
|
check_rep=False,
|
||||||
|
),
|
||||||
|
donate_argnames=("state", "metrics"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# get the metric shape to initialize accumulator arrays for metrics
|
||||||
|
_, metric_shapes = jax.eval_shape(
|
||||||
|
train_step_fsdp_fn,
|
||||||
|
state_fsdp,
|
||||||
|
None,
|
||||||
|
batch,
|
||||||
|
)
|
||||||
|
metrics_fsdp = jax.tree.map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
|
||||||
|
# %%
|
||||||
|
# train
|
||||||
|
start_time = time.time()
|
||||||
|
for _ in range(15):
|
||||||
|
state_fsdp, metrics_fsdp = train_step_fsdp_fn(
|
||||||
|
state_fsdp,
|
||||||
|
metrics_fsdp, batch)
|
||||||
|
duration = time.time() - start_time
|
||||||
|
print(duration)
|
||||||
|
|
||||||
|
# get metrics and state
|
||||||
|
final_metrics_fsdp = jax.tree.map(
|
||||||
|
lambda x: jnp.zeros(x.shape, dtype=x.dtype),
|
||||||
|
metric_shapes)
|
||||||
|
state_fsdp, final_metrics_fsdp = train_step_fsdp_fn(
|
||||||
|
state_fsdp,
|
||||||
|
final_metrics_fsdp, batch)
|
||||||
|
print_metrics(final_metrics_fsdp, "FSDP - Final metrics")
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
|
@@ -0,0 +1,373 @@
|
||||||
|
# %% [markdown]
|
||||||
|
# # Distributed computing in JAX
|
||||||
|
|
||||||
|
# %%
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Set this to True to run the model on CPU only.
|
||||||
|
USE_CPU_ONLY = True
|
||||||
|
|
||||||
|
flags = os.environ.get("XLA_FLAGS", "")
|
||||||
|
if USE_CPU_ONLY:
|
||||||
|
flags += " --xla_force_host_platform_device_count=8" # Simulate 8 devices
|
||||||
|
# Enforce CPU-only execution
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
||||||
|
else:
|
||||||
|
# GPU flags
|
||||||
|
flags += (
|
||||||
|
"--xla_gpu_enable_triton_softmax_fusion=true "
|
||||||
|
"--xla_gpu_triton_gemm_any=false "
|
||||||
|
"--xla_gpu_enable_async_collectives=true "
|
||||||
|
"--xla_gpu_enable_latency_hiding_scheduler=true "
|
||||||
|
"--xla_gpu_enable_highest_priority_async_stream=true "
|
||||||
|
)
|
||||||
|
os.environ["XLA_FLAGS"] = flags
|
||||||
|
|
||||||
|
# %%
|
||||||
|
import functools
|
||||||
|
from typing import Any, Dict, Tuple
|
||||||
|
|
||||||
|
import flax.linen as nn
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
from jax.experimental.shard_map import shard_map
|
||||||
|
from jax.sharding import Mesh, NamedSharding
|
||||||
|
from jax.sharding import PartitionSpec
|
||||||
|
|
||||||
|
PyTree = Any
|
||||||
|
Metrics = Dict[str, Tuple[jax.Array, ...]]
|
||||||
|
jax.config.update('jax_platform_name', 'cpu')
|
||||||
|
|
||||||
|
# %%
|
||||||
|
jax.devices()
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# when we create array, we can check the location
|
||||||
|
a = jnp.arange(8)
|
||||||
|
print("Array", a)
|
||||||
|
print("Device", a.device)
|
||||||
|
print("Sharding", a.sharding)
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ## Single-Axis Mesh
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# let's create a Mesh
|
||||||
|
# multidimensional Numpy array of jax devices
|
||||||
|
# jax.sharding.Mesh(devices, axis_names)
|
||||||
|
mesh = Mesh(devices=np.array(jax.devices()), axis_names=("i",))
|
||||||
|
print(mesh)
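# the mesh exposes its axis sizes and the underlying device array; with the 8 simulated
# CPU devices above this should report a single axis "i" of size 8
print(mesh.shape)
print(mesh.devices.shape)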
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# jax.sharding.NamedSharding(mesh, spec)
|
||||||
|
# pair of a Mesh of devices and PartitionSpec
|
||||||
|
# a PartitionSpec describes how to shard an array across that mesh;
# each entry corresponds to one array dimension, so "i" here shards the
# array's first (and only) dimension over the mesh axis named "i"
|
||||||
|
# to shard an array axis over a certain mesh axis, add the axis name at the
|
||||||
|
# corresponding position in the tuple
|
||||||
|
sharding = NamedSharding(mesh=mesh, spec=PartitionSpec("i",))
|
||||||
|
|
||||||
|
# %%
|
||||||
|
a_sharded = jax.device_put(a, sharding)
|
||||||
|
print("Sharded array", a_sharded)
|
||||||
|
print("Device", a_sharded.devices())
|
||||||
|
print("Sharding", a_sharded.sharding)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
jax.debug.visualize_array_sharding(a_sharded)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# let's try some computation on the mesh
|
||||||
|
out = nn.tanh(a_sharded)
|
||||||
|
print("Output array", out)
|
||||||
|
jax.debug.visualize_array_sharding(out)
|
||||||
|
# note how the output array is sharded across the devices
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ## multi-axis mesh
|
||||||
|
# Why would you shard across multiple dimensions?
#
# a common reason is to combine different kinds of parallelism, e.g. data
# parallelism along one mesh axis and tensor or pipeline parallelism along the other
|
||||||
|
|
||||||
|
# %%
|
||||||
|
mesh = Mesh(devices=np.array(jax.devices()).reshape(4,2), axis_names=("i", "j"))
|
||||||
|
# axis i/0 refers to the row-wise axis progressing downwards
|
||||||
|
# axis j/1 refers to the column-wise axis progressing rightward
|
||||||
|
mesh # noqa: B018
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# we now illustrate sharded MAC operation
|
||||||
|
# y = x @ w + b
|
||||||
|
batch_size = 192
|
||||||
|
input_dim = 64
|
||||||
|
output_dim = 128
|
||||||
|
# input: (batch_size, input_dim)
|
||||||
|
x = jax.random.normal(jax.random.PRNGKey(0), (batch_size, input_dim))
|
||||||
|
# w: (input_dim, output_dim)
|
||||||
|
w = jax.random.normal(jax.random.PRNGKey(1), (input_dim, output_dim))
|
||||||
|
# b: (output_dim,)
|
||||||
|
b = jax.random.normal(jax.random.PRNGKey(2), (output_dim,))
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# x sharded along 0 axis (partition)
|
||||||
|
#
|
||||||
|
x_sharded = jax.device_put(x, NamedSharding(mesh, PartitionSpec("i", None)))
|
||||||
|
w_sharded = jax.device_put(w, NamedSharding(mesh, PartitionSpec(None, "j")))
|
||||||
|
b_sharded = jax.device_put(b, NamedSharding(mesh, PartitionSpec("j")))
|
||||||
|
|
||||||
|
print('x blocks:')
|
||||||
|
jax.debug.visualize_array_sharding(x_sharded)
|
||||||
|
print('w blocks:')
|
||||||
|
jax.debug.visualize_array_sharding(w_sharded)
|
||||||
|
print('b blocks:')
|
||||||
|
jax.debug.visualize_array_sharding(b_sharded)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
out = jnp.dot(x_sharded, w_sharded) + b_sharded
|
||||||
|
print("Output shape", out.shape)
|
||||||
|
jax.debug.visualize_array_sharding(out)
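# a quick sanity check (optional): the block-wise sharded computation should match the
# plain unsharded matmul up to floating point tolerance
np.testing.assert_allclose(out, jnp.dot(x, w) + b, rtol=1e-4, atol=1e-4)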
|
||||||
|
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# # Shard Map (shard_map / "shmap")
|
||||||
|
#
|
||||||
|
# beforehand, we manually assign the sharding partition to assign the exact
|
||||||
|
# partitions to achieve independent, parallel block matrix computation
|
||||||
|
#
|
||||||
|
# This allows us to write code with explicit control over parallelization and
|
||||||
|
# communication
|
||||||
|
#
|
||||||
|
# what is a shard_map?
|
||||||
|
#
|
||||||
|
# it is a transformation that takes a function, a mesh, and a sharding
|
||||||
|
# specification for inputs and outputs
|
||||||
|
#
|
||||||
|
# in other words, we write a function that executes on each device only, then
|
||||||
|
# apply across all the shards
|
||||||
|
#
|
||||||
|
# but wait, doesn't pmap do this? The answer is no. pmap doesn't have enough
|
||||||
|
# information about the shards to efficiently perform sharding for complicated
|
||||||
|
# meshes.
|
||||||
|
|
||||||
|
# %%
|
||||||
|
def matmul_fn(x: jax.Array, w: jax.Array, b: jax.Array) -> jax.Array:
|
||||||
|
print("Local x shape", x.shape)
|
||||||
|
print("Local w shape", w.shape)
|
||||||
|
print("Local b shape", b.shape)
|
||||||
|
# so simple!
|
||||||
|
return jnp.dot(x,w) + b
|
||||||
|
|
||||||
|
# %%
|
||||||
|
matmul_sharded = shard_map(
|
||||||
|
matmul_fn, # the function for operating on a single device
|
||||||
|
mesh, # the device topology
|
||||||
|
# the input mesh partition argument for each input
|
||||||
|
in_specs=(
|
||||||
|
PartitionSpec("i", None), # x
|
||||||
|
PartitionSpec(None, "j"), # w
|
||||||
|
PartitionSpec("j") # b
|
||||||
|
),
|
||||||
|
# the output to read from the mesh
|
||||||
|
out_specs=PartitionSpec("i", "j")
|
||||||
|
)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# y = matmul_sharded(x_sharded, w_sharded, b_sharded)
|
||||||
|
# there is no need to device_put,
|
||||||
|
# partitioning is done according to your in_specs
|
||||||
|
y = matmul_sharded(x, w, b)
|
||||||
|
print("Output shape", y.shape)
|
||||||
|
jax.debug.visualize_array_sharding(y)
|
||||||
|
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# # Axis Communication
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# example of mean/sum across devices per shard
|
||||||
|
|
||||||
|
# the following wants to find the statistics of x
|
||||||
|
# we compute the normalized x according to each row statistics (mean and std)
|
||||||
|
@functools.partial(
|
||||||
|
shard_map,
|
||||||
|
mesh=mesh,
|
||||||
|
in_specs=PartitionSpec("i", "j"),
|
||||||
|
out_specs=PartitionSpec("i", "j"))
|
||||||
|
def parallel_normalize(x: jax.Array) -> jax.Array:
|
||||||
|
# jax.lax.pmean: compute an all-reduce mean of x over the mapped axis "axis_name"
|
||||||
|
# get the mean across the "j" axis of the mesh - column wise
|
||||||
|
mean = jax.lax.pmean(x, axis_name="j")
|
||||||
|
# get the std across the "j" axis of the mesh - column wise
|
||||||
|
std = jax.lax.pmean((x - mean) ** 2, axis_name="j") ** 0.5
|
||||||
|
return (x - mean) / std
|
||||||
|
|
||||||
|
# communicated along "j" axis of mesh for row elements
|
||||||
|
|
||||||
|
|
||||||
|
out = parallel_normalize(x)
|
||||||
|
out = jax.device_get(out)
|
||||||
|
print(out.shape)
|
||||||
|
print("Mean", out.mean())
|
||||||
|
print("Std", out.std())
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# scenario: array is sharded across devices, some values missing per shard
|
||||||
|
# all-gather: gather values of an array from all devices
|
||||||
|
@functools.partial(
|
||||||
|
shard_map,
|
||||||
|
mesh=mesh,
|
||||||
|
in_specs=(
|
||||||
|
PartitionSpec("i", None), # artificially shard across "i"
|
||||||
|
PartitionSpec("i", None)
|
||||||
|
),
|
||||||
|
out_specs=PartitionSpec("i", None))
|
||||||
|
def matmul_with_weight_gather(x: jax.Array, w: jax.Array) -> jax.Array:
|
||||||
|
print("Original w shape", w.shape)
|
||||||
|
# pull the full w matrix values from neighboring devices
|
||||||
|
w_gathered = jax.lax.all_gather(w, axis_name="i", axis=0, tiled=True)
|
||||||
|
print("Gathered w shape", w_gathered.shape)
|
||||||
|
y = jnp.dot(x, w_gathered)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
out = matmul_with_weight_gather(x, w)
|
||||||
|
out = jax.device_get(out)
|
||||||
|
np.testing.assert_array_equal(out, jnp.dot(x, w))
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# scenario: arrays are sharded across all devices
|
||||||
|
# scatter sum: each function instance of each device gets only one shard of the result
|
||||||
|
#
|
||||||
|
# therefore each device ends up with the sum over devices for its own slice of the array
|
||||||
|
|
||||||
|
@functools.partial(
|
||||||
|
shard_map,mesh=mesh,
|
||||||
|
in_specs=PartitionSpec("i", None),
|
||||||
|
out_specs=PartitionSpec("i", None))
|
||||||
|
def scatter_example(x: jax.Array) -> jax.Array:
|
||||||
|
x_scatter = jax.lax.psum_scatter(x, axis_name="i", scatter_dimension=1)
|
||||||
|
return x_scatter
|
||||||
|
|
||||||
|
|
||||||
|
x_exmp = np.array(
|
||||||
|
[
|
||||||
|
[3, 1, 4, 1],
|
||||||
|
[5, 9, 2, 6],
|
||||||
|
[5, 3, 5, 8],
|
||||||
|
[9, 7, 1, 2],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
out = scatter_example(x_exmp)
|
||||||
|
print("Output", out)
|
||||||
|
# %%
|
||||||
|
# ppermute: communicates an array in a round robin fashion
|
||||||
|
#
|
||||||
|
# this is used in implementing pipeline parallelism where results are passed to another device
|
||||||
|
# used in tensor parallelism
|
||||||
|
#
|
||||||
|
# notice how the results roll through the devices
|
||||||
|
#
|
||||||
|
# this can actually implement all other lax communication operations
|
||||||
|
|
||||||
|
@functools.partial(
|
||||||
|
shard_map,
|
||||||
|
mesh=mesh,
|
||||||
|
in_specs=PartitionSpec("i"),
|
||||||
|
out_specs=PartitionSpec("i"))
|
||||||
|
def ppermute_example(x: jax.Array) -> jax.Array:
|
||||||
|
axis_size = mesh.shape["i"]
|
||||||
|
print('BEFORE:\n', x)
|
||||||
|
x_perm = jax.lax.ppermute(
|
||||||
|
x,
|
||||||
|
axis_name="i",
|
||||||
|
perm=[
|
||||||
|
# source_index, destination_index pairs
|
||||||
|
(i, (i + 1) % axis_size) for i in range(axis_size)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
print('AFTER:\n', x_perm)
|
||||||
|
return x_perm
|
||||||
|
|
||||||
|
|
||||||
|
x_exmp = np.arange(4)
|
||||||
|
out = ppermute_example(x_exmp)
|
||||||
|
print("Output", out) # the value is that of each axis 0 device
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# # axis indexing: get the index of device along axis
|
||||||
|
# sometimes our computations need adjustment depending on the device its being ran on
|
||||||
|
#
|
||||||
|
# we will use jax.lax.axis_index to return the index of the current device along an axis
|
||||||
|
#
|
||||||
|
# this function will be jitted and will be almost 0 cost
|
||||||
|
|
||||||
|
axis_idx_fn = jax.jit(
|
||||||
|
shard_map(
|
||||||
|
lambda: jnp.stack(
|
||||||
|
[
|
||||||
|
jax.lax.axis_index("i"), # Device index in mesh along the "i" axis
|
||||||
|
jax.lax.axis_index("j"), # Device index in mesh along the "j" axis
|
||||||
|
],
|
||||||
|
axis=-1,
|
||||||
|
)[None],
|
||||||
|
mesh,
|
||||||
|
in_specs=PartitionSpec(),
|
||||||
|
out_specs=PartitionSpec(
|
||||||
|
("i", "j"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
out = axis_idx_fn()
|
||||||
|
out = jax.device_get(out)
|
||||||
|
for i in range(out.shape[0]):
|
||||||
|
print(f"Device {i}: i-axis={out[i, 0]}, j-axis={out[i, 1]}")
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# usage 2: fold rng over given axis
|
||||||
|
# jax.random.fold_in: folds in data to a PRNG key to form a new PRNG key
|
||||||
|
# from a source RNG key, we generate new RNG keys
|
||||||
|
def fold_rng_over_axis(rng: jax.random.PRNGKey, axis_name: str) -> jax.random.PRNGKey:
|
||||||
|
"""Folds the random number generator over the given axis.
|
||||||
|
|
||||||
|
This is useful for generating a different random number for each device
|
||||||
|
across a certain axis (e.g. the model axis).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rng: The random number generator.
|
||||||
|
axis_name: The axis name to fold the random number generator over.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new random number generator, different for each device index along the axis.
|
||||||
|
"""
|
||||||
|
axis_index = jax.lax.axis_index(axis_name)
|
||||||
|
return jax.random.fold_in(rng, axis_index)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# we fold RNG over the i axis only
|
||||||
|
# same RNG used across j axis
|
||||||
|
fold_fn = jax.jit(
|
||||||
|
shard_map(
|
||||||
|
# fold over for "i" only
|
||||||
|
functools.partial(fold_rng_over_axis, axis_name="i"),
|
||||||
|
mesh,
|
||||||
|
in_specs=PartitionSpec(),
|
||||||
|
out_specs=PartitionSpec(
|
||||||
|
("i", "j"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
rng = jax.random.PRNGKey(0)
|
||||||
|
out = fold_fn(rng)
|
||||||
|
out = jax.device_get(out)
|
||||||
|
for i in range(out.shape[0] // 2):
|
||||||
|
print(f"Device {i}: RNG={out[2*i:2*i+2]}")
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
|
@@ -0,0 +1,441 @@
|
||||||
|
|
||||||
|
# MARK: import
|
||||||
|
# %% [markdown]
|
||||||
|
# # single gpu optimizations
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
# os.environ["XLA_FLAGS"] = (
|
||||||
|
# "--xla_gpu_enable_triton_softmax_fusion=true "
|
||||||
|
# "--xla_gpu_triton_gemm_any=false "
|
||||||
|
# "--xla_gpu_enable_async_collectives=true "
|
||||||
|
# "--xla_gpu_enable_latency_hiding_scheduler=true "
|
||||||
|
# "--xla_gpu_enable_highest_priority_async_stream=true "
|
||||||
|
# )
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
import functools
|
||||||
|
from pprint import pprint
|
||||||
|
from typing import Any, Callable, Dict, Tuple
|
||||||
|
|
||||||
|
import flax.linen as nn
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
import optax
|
||||||
|
from flax.struct import dataclass
|
||||||
|
from flax.training import train_state
|
||||||
|
|
||||||
|
# Type aliases
|
||||||
|
PyTree = Any
|
||||||
|
Metrics = Dict[str, Tuple[jax.Array, ...]]
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# # bfloat16 mixed precision compute
|
||||||
|
class MLPClassifier(nn.Module):
|
||||||
|
dtype: Any # we set the dtype here for computation
|
||||||
|
hidden_size: int = 256
|
||||||
|
num_classes: int = 100
|
||||||
|
dropout_rate: float = 0.1
|
||||||
|
|
||||||
|
@nn.compact
|
||||||
|
def __call__(self, x: jax.Array, train: bool) -> jax.Array:
|
||||||
|
x = nn.Dense(
|
||||||
|
features=self.hidden_size,
|
||||||
|
dtype=self.dtype, # Computation in specified dtype, params stay in float32
|
||||||
|
)(x)
|
||||||
|
x = nn.LayerNorm(dtype=self.dtype)(x)
|
||||||
|
x = nn.silu(x)
|
||||||
|
x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
|
||||||
|
x = nn.Dense(
|
||||||
|
features=self.num_classes,
|
||||||
|
dtype=self.dtype,
|
||||||
|
)(x)
|
||||||
|
x = x.astype(jnp.float32)
|
||||||
|
x = nn.log_softmax(x, axis=-1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
x = jnp.ones((512, 128), dtype=jnp.float32)
|
||||||
|
rngs = {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)}
|
||||||
|
model_float32 = MLPClassifier(dtype=jnp.float32)
|
||||||
|
model_float32.tabulate(rngs, x, train=True, console_kwargs={"force_jupyter": True})
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# inputs and activations (outputs) in bfloat16
|
||||||
|
# parameters in float32
|
||||||
|
model_bfloat16 = MLPClassifier(dtype=jnp.bfloat16)
|
||||||
|
model_bfloat16.tabulate(rngs, x, train=True, console_kwargs={"force_jupyter": True})
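# %%
# a small check (sketch): even with dtype=jnp.bfloat16 for computation, the stored
# parameters remain float32; only activations are computed in bfloat16
_variables = model_bfloat16.init(rngs, x, train=False)
pprint(jax.tree.map(lambda p: p.dtype, _variables["params"]))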
|
||||||
|
|
||||||
|
# MARK: GRADIENT CHECKPOINT
|
||||||
|
# %% [markdown]
|
||||||
|
# # gradient checkpoint
|
||||||
|
#
|
||||||
|
# in jax this is implemented with the remat function
|
||||||
|
#
|
||||||
|
# practical notes on remat: https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#practical-notes
|
||||||
|
def gelu(x: jax.Array) -> jax.Array:
|
||||||
|
"""GeLU activation function with approximate tanh."""
|
||||||
|
# This will be printed once every time the function is executed.
|
||||||
|
jax.debug.print("Executing GeLU")
|
||||||
|
# See https://arxiv.org/abs/1606.08415 for details.
|
||||||
|
x3 = jnp.power(x, 3)
|
||||||
|
tanh_input = np.sqrt(2 / np.pi) * (x + 0.044715 * x3)
|
||||||
|
return 0.5 * x * (1 + jnp.tanh(tanh_input))
|
||||||
|
|
||||||
|
def loss_fn(x: jax.Array, remat: bool) -> jax.Array:
|
||||||
|
act_fn = gelu
|
||||||
|
if remat:
|
||||||
|
act_fn = jax.remat(act_fn)
|
||||||
|
return jnp.mean(act_fn(x))
|
||||||
|
|
||||||
|
x = jax.random.normal(jax.random.PRNGKey(0), (100,))
|
||||||
|
grad_fn = jax.grad(loss_fn)
|
||||||
|
# with remat, the activation function is re-executed during the backward pass
|
||||||
|
_ = grad_fn(x, remat=True)
|
||||||
|
|
||||||
|
# without remat, a plain forward pass executes the function only once
|
||||||
|
_ = loss_fn(x, remat=False)
|
||||||
|
|
||||||
|
# MARK: GRADIENT ACCUMULATION
|
||||||
|
# %% [markdown]
|
||||||
|
# # gradient accumulation
|
||||||
|
#
|
||||||
|
# run many mini-batches, and accumulate their gradients to feed into optimizer
|
||||||
|
# as if there were one large batch
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
class TrainState(train_state.TrainState):
|
||||||
|
rng: jax.Array
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Batch:
|
||||||
|
inputs: jax.Array
|
||||||
|
labels: jax.Array
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# nothing special here, just a loss function
|
||||||
|
def classification_loss_fn(
|
||||||
|
params: PyTree, apply_fn: Any, batch: Batch, rng: jax.Array
|
||||||
|
) -> Tuple[PyTree, Metrics]:
|
||||||
|
"""Classification loss function with cross-entropy."""
|
||||||
|
logits = apply_fn({"params": params}, batch.inputs, train=True, rngs={"dropout": rng})
|
||||||
|
loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch.labels)
|
||||||
|
correct_pred = jnp.equal(jnp.argmax(logits, axis=-1), batch.labels)
|
||||||
|
batch_size = batch.inputs.shape[0]
|
||||||
|
# step_metrics contains the loss sum
|
||||||
|
step_metrics = {"loss": (loss.sum(), batch_size), "accuracy": (correct_pred.sum(), batch_size)}
|
||||||
|
# loss contains the mean loss
|
||||||
|
mean_loss = loss.mean() # the mathematical output of function
|
||||||
|
return mean_loss, step_metrics
|
||||||
|
|
||||||
|
# %%
# gradient accumulation training loop
def accumulate_gradients_loop(
    state: TrainState,
    batch: Batch,
    rng: jax.random.PRNGKey,
    num_minibatches: int,
    loss_fn: Callable,
) -> Tuple[PyTree, Metrics]:
    """Calculate gradients and metrics for a batch using gradient accumulation.

    Args:
        state: Current training state.
        batch: Full training batch.
        rng: Random number generator to use.
        num_minibatches: Number of minibatches to split the batch into. Equal to the number of gradient accumulation steps.
        loss_fn: Loss function to calculate gradients and metrics.

    Returns:
        Tuple with accumulated gradients and metrics over the minibatches.
    """
    batch_size = batch.inputs.shape[0]
    minibatch_size = batch_size // num_minibatches
    rngs = jax.random.split(rng, num_minibatches)
    # Define gradient function for a single minibatch.
    # If has_aux is True then a tuple of ((value, auxiliary_data), gradient) is returned;
    # otherwise it returns (value, gradient), where value is the actual output of the
    # function, hence the name "value_and_grad".
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    # Prepare loop variables.
    grads = None
    metrics = None
    for minibatch_idx in range(num_minibatches):
        with jax.named_scope(f"minibatch_{minibatch_idx}"):
            # Split the batch into minibatches.
            start = minibatch_idx * minibatch_size
            end = start + minibatch_size
            minibatch = jax.tree.map(lambda x: x[start:end], batch)  # noqa: B023
            # Calculate gradients and metrics for the minibatch.
            # The discarded value is the mean loss of the minibatch.
            (_, step_metrics), step_grads = grad_fn(
                state.params, state.apply_fn, minibatch, rngs[minibatch_idx]
            )
            # Accumulate gradients and metrics across minibatches.
            if grads is None:
                grads = step_grads
                metrics = step_metrics
            else:
                # accumulation adder
                grads = jax.tree.map(jnp.add, grads, step_grads)
                metrics = jax.tree.map(jnp.add, metrics, step_metrics)
    # Average gradients over minibatches.
    grads = jax.tree.map(lambda g: g / num_minibatches, grads)
    return grads, metrics

# %%
# jax.lax.scan implementation
#
# pros: much faster compilation
# cons: slower execution than the unrolled loop
def accumulate_gradients_scan(
    state: TrainState,
    batch: Batch,
    rng: jax.random.PRNGKey,
    num_minibatches: int,
    loss_fn: Callable,
) -> Tuple[PyTree, Metrics]:
    """Calculate gradients and metrics for a batch using gradient accumulation.

    In this version, we use `jax.lax.scan` to loop over the minibatches. This is more efficient in terms of compilation time.

    Args:
        state: Current training state.
        batch: Full training batch.
        rng: Random number generator to use.
        num_minibatches: Number of minibatches to split the batch into. Equal to the number of gradient accumulation steps.
        loss_fn: Loss function to calculate gradients and metrics.

    Returns:
        Tuple with accumulated gradients and metrics over the minibatches.
    """
    batch_size = batch.inputs.shape[0]
    minibatch_size = batch_size // num_minibatches
    rngs = jax.random.split(rng, num_minibatches)
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

    def _minibatch_step(minibatch_idx: jax.Array | int) -> Tuple[PyTree, Metrics]:
        """Determine gradients and metrics for a single minibatch."""
        minibatch = jax.tree.map(
            # jax.lax.dynamic_slice_in_dim(operand, start_index, slice_size, axis=0)
            # is roughly equivalent to the following Python indexing syntax
            # applied along the specified axis: operand[..., start_index:start_index + slice_size].
            lambda x: jax.lax.dynamic_slice_in_dim(  # Slicing with variable index (jax.Array).
                x,
                start_index=minibatch_idx * minibatch_size,
                slice_size=minibatch_size,
                axis=0,
            ),
            batch,
        )
        (_, step_metrics), step_grads = grad_fn(
            state.params, state.apply_fn, minibatch, rngs[minibatch_idx]
        )
        return step_grads, step_metrics

    # the function we expect scan to use
    def _scan_step(
        carry: Tuple[PyTree, Metrics], minibatch_idx: jax.Array | int
    ) -> Tuple[Tuple[PyTree, Metrics], None]:
        """Scan step function for looping over minibatches."""
        step_grads, step_metrics = _minibatch_step(minibatch_idx)
        # notice how the carry type is a tuple of pytree and metrics:
        # the carry is literally the accumulator of (step_grads, step_metrics)
        carry = jax.tree.map(jnp.add, carry, (step_grads, step_metrics))
        # jax.lax.scan expects a carry and a y, but we have no y
        return carry, None

    # Determine initial shapes for gradients and metrics.
    grads_shapes, metrics_shape = jax.eval_shape(_minibatch_step, 0)
    grads = jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), grads_shapes)
    metrics = jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), metrics_shape)
    # Loop over minibatches to determine gradients and metrics.
    #
    # jax.lax.scan(f, init, xs=None, length=None, reverse=False, unroll=1, _split_transpose=False)
    # purpose: scan a function over leading array axes while carrying along state;
    # in other words, a functional for-loop.
    # why? the whole loop lowers to a single XLA While op, which keeps compilation fast
    # compared to unrolling the Python loop.
    # equivalent Python semantics:
    #
    # def scan(f, init, xs, length=None):
    #     if xs is None:
    #         xs = [None] * length
    #     carry = init
    #     ys = []
    #     for x in xs:
    #         carry, y = f(carry, x)
    #         ys.append(y)
    #     return carry, np.stack(ys)
    #
    # note: usually we expect the ys to be the output and the carry to be hidden state;
    # y is the pure function output we expect
    (grads, metrics), _ = jax.lax.scan(
        _scan_step,
        init=(grads, metrics),
        xs=jnp.arange(num_minibatches),
        length=num_minibatches,
    )
    # Average gradients over minibatches.
    grads = jax.tree.map(lambda g: g / num_minibatches, grads)
    return grads, metrics

# %%
def accumulate_gradients(*args, use_scan: bool = False, **kwargs) -> Tuple[PyTree, Metrics]:
    if use_scan:
        return accumulate_gradients_scan(*args, **kwargs)
    else:
        return accumulate_gradients_loop(*args, **kwargs)

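
# %%
# Hedged alternative (added sketch, not used in this notebook): optax ships its own
# gradient-accumulation wrapper, optax.MultiSteps, which applies the inner optimizer only
# every k minibatches. This only illustrates the API; the hand-rolled loops above are what
# the rest of the notebook uses, since they also accumulate custom metrics.
accum_tx = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=4)
# e.g. (hypothetical, using names defined further below):
# state_alt = TrainState.create(apply_fn=model.apply, params=params, tx=accum_tx, rng=state_rng)
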
# %%
def train_step(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
    num_minibatches: int,
) -> Tuple[TrainState, Metrics]:
    """Training step function.

    Executes a full training step with gradient accumulation.

    Args:
        state: Current training state.
        metrics: Current metrics, accumulated from previous training steps.
        batch: Training batch.
        num_minibatches: Number of minibatches to split the batch into. Equal to the number of gradient accumulation steps.

    Returns:
        Tuple with updated training state (parameters, optimizer state, etc.) and metrics.
    """
    # Split the random number generator for the current step.
    rng, step_rng = jax.random.split(state.rng)
    # Determine gradients and metrics for the full batch.
    grads, step_metrics = accumulate_gradients(
        # use_scan has to be a Python constant here: passing a traced boolean into a
        # jitted function to switch code paths is a cardinal sin of jax
        state, batch, step_rng, num_minibatches, loss_fn=classification_loss_fn, use_scan=True
    )
    # Optimizer step.
    new_state = state.apply_gradients(grads=grads, rng=rng)
    # Accumulate metrics across training steps.
    if metrics is None:
        metrics = step_metrics
    else:
        metrics = jax.tree.map(jnp.add, metrics, step_metrics)
    return new_state, metrics

# %%
batch_size = 512
num_inputs = 128
num_classes = 100
rng_seed = 0

rng = jax.random.PRNGKey(rng_seed)
data_input_rng, data_label_rng, model_rng, state_rng = jax.random.split(rng, 4)
batch = Batch(
    inputs=jax.random.normal(data_input_rng, (batch_size, num_inputs)),
    labels=jax.random.randint(data_label_rng, (batch_size,), 0, num_classes),
)

# Zero dropout for checking later equality between training with and without gradient accumulation.
model = MLPClassifier(dtype=jnp.bfloat16, dropout_rate=0.0)
params = model.init(model_rng, batch.inputs, train=False)["params"]
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.adam(1e-3),
    rng=state_rng,
)

# %%
# jax.eval_shape(fun, *args, **kwargs)
# computes the shape/dtype of fun's output without spending any FLOPs

# This fails because num_minibatches would be traced rather than treated as a static value,
# which makes the shape inference fail:
# _, metric_shapes = jax.eval_shape(
#     train_step,  # fun
#     state,  # train state
#     None,  # metrics
#     batch,  # batch
#     4,  # num_minibatches
# )

_, metric_shapes = jax.eval_shape(
    # binding num_minibatches first works
    functools.partial(train_step, num_minibatches=4),
    state,  # train state
    None,  # metrics
    batch,  # batch
)

print("Metric shapes:")
pprint(metric_shapes)

# %%
# this is an optimization trick:
# cache the jitted function per value of num_minibatches,
# otherwise it would re-compile every time that value changes
train_step_jit = jax.jit(
    train_step,
    # treat num_minibatches as a static argument
    static_argnames="num_minibatches",
)

# %%
def train_with_minibatches(
    state: TrainState,
    batch: Batch,
    num_minibatches: int,
    num_train_steps: int,
) -> Tuple[TrainState, Metrics]:
    """Small helper function for the training loop."""
    train_metrics = jax.tree.map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
    for _ in range(num_train_steps):
        state, train_metrics = train_step_jit(state, train_metrics, batch, num_minibatches)
    return state, train_metrics

# %%
def print_metrics(metrics: Metrics, title: str | None = None) -> None:
    """Prints metrics with an optional title."""
    metrics = jax.device_get(metrics)
    lines = [f"{k}: {v[0] / v[1]:.6f}" for k, v in metrics.items()]
    if title:
        title = f" {title} "
        max_len = max(len(title), max(map(len, lines)))
        lines = [title.center(max_len, "=")] + lines
    print("\n".join(lines))

# %%
state_mini1, metrics_mini1 = train_with_minibatches(
    state, batch, num_minibatches=1, num_train_steps=4
)
state_mini4, metrics_mini4 = train_with_minibatches(
    state, batch, num_minibatches=4, num_train_steps=4
)
print_metrics(metrics_mini1, "Minibatch 1")
print_metrics(metrics_mini4, "Minibatch 4")


# %% [markdown]
# # donate_buffers
# jax passes arguments by value due to its functional nature;
# buffer donation lets us hand over an argument's memory, similar to pass by reference.
# what can be donated?
# we can only donate an argument if we are sure it will not be used again afterwards.
# this is usually true for model parameters and optimizer state,
# since we get totally new values and won't use the old argument values anymore
# after an update (e.g. we will use new_state and new_metrics)
train_step_donated = jax.jit(
    train_step,
    static_argnames="num_minibatches",
    donate_argnames=(
        "state",
        "metrics",
    ),
)
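
# %%
# Minimal sketch of what donation means (added illustration; _add_one is hypothetical).
# After the call, XLA may reuse the donated input buffer for the output, so the caller must
# not touch the donated argument again. Donation is a no-op on CPU backends, where XLA
# simply emits a warning.
@functools.partial(jax.jit, donate_argnums=(0,))
def _add_one(v):
    return v + 1

_ = _add_one(jnp.zeros((4,)))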
@ -0,0 +1,392 @@
# %% [markdown]
# # T5 implementation using jax with pjit


# MARK: START
# %%
# let's make an 8-device simulator
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True

flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    flags += " --xla_force_host_platform_device_count=8"  # Simulate 8 devices
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    os.environ["JAX_PLATFORMS"] = "cpu"
else:
    # GPU flags
    flags += (
        "--xla_gpu_enable_triton_softmax_fusion=true "
        "--xla_gpu_triton_gemm_any=false "
        "--xla_gpu_enable_async_collectives=true "
        "--xla_gpu_enable_latency_hiding_scheduler=true "
        "--xla_gpu_enable_highest_priority_async_stream=true "
    )
os.environ["XLA_FLAGS"] = flags

import functools
from functools import partial
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from ml_collections import ConfigDict
import optax
import logging
import time
from datasets import Dataset, load_from_disk

from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
import flax.core

from tqdm import tqdm

from dataload import DataPrepare

PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

if USE_CPU_ONLY:
    jax.config.update('jax_platform_name', 'cpu')
else:
    jax.config.update("jax_default_matmul_precision", "bfloat16")

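
# %%
# sanity check (added sketch): with the host-platform flag above, jax should report 8 devices
print(f"device count: {jax.device_count()}")  # expected: 8 when USE_CPU_ONLY is True
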
# # %%
# import jax
# import jax.numpy as jnp
# import optax
# import numpy as np
# from functools import partial
# from typing import Callable, Optional
# import math
#
# # jax.config.update("jax_default_matmul_precision", "tensorfloat32")
# jax.config.update("jax_default_matmul_precision", "bfloat16")
# # jax.config.update("jax_enable_x64", False)
# # enable cache
# jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
#
#
# # from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
#
# from flax import jax_utils, traverse_util
# from flax.jax_utils import pad_shard_unpad, unreplicate
# from flax.training import train_state
# from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
# import flax.core


# %%
# get platform type
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
save_path = 't5_80_1_bf16'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])
# Store some constants
seed = 117
num_epochs = 5
batch_size = 384  # 384 is the best
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs

warmup_steps = 0
learning_rate = 2e-5

weight_decay = 0.01
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0

num_beams = 1
val_max_target_length = 128

predict_with_generate = True


# %%
# prepare data
# initialize the DataPrepare object with its ConfigDict
data_config = ConfigDict(
    dict(
        max_length=86,
        pad_token_id=0,
        decoder_start_token_id=0
    )
)

dataprep = DataPrepare(split_datasets['train'], data_config)
# # example usage
# # %%
# seed = 117
# rng = jax.random.PRNGKey(seed)
# train_loader = dataprep.data_loader(rng, batch_size=1)


# %%
# model

from transformers import FlaxT5ForConditionalGeneration
from transformers import T5Config

config = T5Config()

# If you don't want to cast certain parameters (for example layer norm bias and scale)
# then pass a mask as follows
from flax import traverse_util

model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
# useful for transformer models
model.enable_gradient_checkpointing()

# enable bf16 except for layer_norm
flat_params = traverse_util.flatten_dict(model.params)
mask = {
    path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
}
mask = traverse_util.unflatten_dict(mask)
model.params = model.to_bf16(model.params, mask)

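
# %%
# quick check (added sketch): after the masked cast, most leaves should be bfloat16 while the
# layer-norm weights stay float32; Counter is just an illustrative way to see the split
from collections import Counter
print(Counter(str(p.dtype) for p in jax.tree_util.tree_leaves(model.params)))
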
# %% [markdown]
# # Model
#
#
#

# %%

# Initialize our training rng
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)


# %%
# optimization functions

def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear decay learning rate function."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn


# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
    training_size,
    train_batch_size,
    num_train_epochs,
    warmup_steps,
    learning_rate,
)

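
# probe the schedule (added sketch): with warmup_steps = 0 the first value is simply
# learning_rate and it decays linearly towards 0 by the final step
print([float(linear_decay_lr_schedule_fn(s)) for s in (0, total_train_steps // 2, total_train_steps)])
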
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# boolean mask with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    # find all LayerNorm parameters
    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
    layer_norm_named_params = {
        layer[-2:]
        for layer_norm_name in layer_norm_candidates
        for layer in flat_params.keys()
        if layer_norm_name in "".join(layer).lower()
    }
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)

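
# worked example (added sketch, toy parameter names that are not T5's real ones):
# decay_mask_fn flags only the dense kernel for weight decay
_toy_params = {"encoder": {"dense": {"kernel": 1, "bias": 1}, "layer_norm": {"weight": 1}}}
# expected: {'encoder': {'dense': {'kernel': True, 'bias': False}, 'layer_norm': {'weight': False}}}
print(decay_mask_fn(_toy_params))
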
# create adam optimizer
adamw = optax.adamw(
    learning_rate=linear_decay_lr_schedule_fn,
    b1=adam_beta1,
    b2=adam_beta2,
    eps=adam_epsilon,
    weight_decay=weight_decay,
    mask=decay_mask_fn,
)


# %%
# Training functions
class TrainState(train_state.TrainState):
    dropout_rng: jnp.ndarray

    # easy way to achieve data parallelism
    # also shards (folds) the rng key across devices
    def replicate(self):
        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))


# set bf16 for model params
# model.params = model.to_bf16(model.params)
# Cast parameters to bfloat16 if desired
# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)

# Setup train state
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)

# label smoothed cross entropy
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    """
    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
    """
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing_factor
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    normalizing_constant = -(
        confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    )
    soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)

    loss = optax.softmax_cross_entropy(logits, soft_labels)
    loss = loss - normalizing_constant

    # ignore padded tokens in the loss
    loss = loss * padding_mask
    loss = loss.sum()
    num_labels = padding_mask.sum()
    return loss, num_labels

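
# worked numbers (added illustration): with label_smoothing_factor = 0.1 and a toy vocab of 5,
# confidence = 0.9 and low_confidence = (1 - 0.9) / (5 - 1) = 0.025, so the soft label for a
# true class of 2 is [0.025, 0.025, 0.9, 0.025, 0.025]; the normalizing constant is the entropy
# of that distribution and only shifts the loss by a constant.
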
# Define gradient update step fn
@jax.jit
def train_step(state, batch, label_smoothing_factor=0.0):
    dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

    def compute_loss(params):
        labels = batch.pop("labels")
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
        return loss, num_labels

    # compute gradients through the computational graph
    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, num_labels), grad = grad_fn(state.params)
    num_labels = jax.lax.psum(num_labels, "batch")

    # true loss = total loss / total samples
    # loss = jax.lax.psum(loss, "batch")
    # loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)

    # true grad = total grad / total samples
    grad = jax.lax.psum(grad, "batch")
    grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
    new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

    metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
    return new_state, metrics

# Define generation function
max_length = (
    val_max_target_length if val_max_target_length is not None else model.config.max_length
)
num_beams = num_beams if num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

# def generate_step(params, batch):
#     model.params = params
#     output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
#     return output_ids.sequences

# Create parallel version of the train and eval step
p_train_step = jax.pmap(
    partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,)
)
# p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch")
# p_generate_step = jax.pmap(generate_step, "batch")

# Replicate the train state on each device
state = state.replicate()


# %%

print("***** Running training *****")
print(f"  Num examples = {training_size}")
print(f"  Num Epochs = {num_epochs}")
print(f"  Instantaneous batch size per device = {per_device_train_batch_size}")
print(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}")
print(f"  Total optimization steps = {total_train_steps}")


# %%
# jax.profiler.start_trace("./traces")

rng, input_rng = jax.random.split(rng)
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
    train_start = time.time()

    # Create sampling rng
    train_metrics = []
    rng, data_rng = jax.random.split(rng)
    train_loader = dataprep.data_loader(data_rng, batch_size=batch_size)
    steps_per_epoch = training_size // train_batch_size
    # Generate an epoch by shuffling sampling indices from the train dataset
    for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
        batch = next(train_loader)
        batch = shard(batch)
        state, train_metric = p_train_step(state, batch)
        train_metrics.append(train_metric)

    train_time = time.time() - train_start

    train_metric = unreplicate(train_metric)
    train_metric['loss'].block_until_ready()

    epochs.write(
        # f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, "
        f"Epoch... ({epoch + 1}/{num_epochs} | "
        # f"Learning Rate: {train_metric['learning_rate']}, "
        f"Last train time: {train_time})"
    )
# jax.profiler.stop_trace()

# %%
# output_dir = save_path
# # save checkpoint after each epoch and push checkpoint to the hub
# if jax.process_index() == 0:
#     params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
#     params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
#     model.save_pretrained(output_dir, params=params)
#     tokenizer.save_pretrained(output_dir)
@ -196,7 +196,7 @@ model.params = model.to_bf16(model.params, mask)

# %%
model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")  # noqa: B009


@ -391,6 +391,8 @@ adamw = optax.adamw(
class TrainState(train_state.TrainState):
    dropout_rng: jnp.ndarray

    # easy way to achieve data parallelism
    # also shards (folds) the rng key across devices
    def replicate(self):
        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))