# MARK: START
# %%
# let's simulate a small multi-device setup (4 CPU devices for the 2x2 mesh below)

import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True

flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    flags += " --xla_force_host_platform_device_count=4"  # simulate 4 devices
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    os.environ["JAX_PLATFORMS"] = "cpu"
else:
    # GPU flags
    flags += (
        "--xla_gpu_enable_triton_softmax_fusion=true "
        "--xla_gpu_triton_gemm_any=false "
        "--xla_gpu_enable_async_collectives=true "
        "--xla_gpu_enable_latency_hiding_scheduler=true "
        "--xla_gpu_enable_highest_priority_async_stream=true "
    )
os.environ["XLA_FLAGS"] = flags

import functools
from functools import partial
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
# from jax.experimental.pjit import pjit  # superseded by jax.jit
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec
from ml_collections import ConfigDict
import optax
import logging
import time
from datasets import Dataset, load_from_disk

from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
import flax.core

from tqdm import tqdm

from dataload import DataPrepare

PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

if USE_CPU_ONLY:
    jax.config.update('jax_platform_name', 'cpu')
else:
    jax.config.update("jax_default_matmul_precision", "bfloat16")

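# %%
# sanity check (not in the original script): confirm how many devices JAX sees
# (4 simulated CPU devices when USE_CPU_ONLY is True)
print(jax.devices())
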
# MARK: sharding example
# %%
device_mesh = mesh_utils.create_device_mesh((2, 2))
print(device_mesh)

mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)


def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
    return NamedSharding(mesh, pspec)

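# %%
# quick illustration (not in the original script): place a small throwaway array
# on the mesh and inspect how its rows/columns are split across the devices
example = jax.device_put(jnp.arange(16).reshape(4, 4),
                         mesh_sharding(PartitionSpec('data', 'model')))
jax.debug.visualize_array_sharding(example)
print(example.sharding)
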
# %%
# define a layer
class DotReluDot(nn.Module):
    depth: int
    dense_init: Callable = nn.initializers.xavier_normal()

    @nn.compact
    def __call__(self, x):
        # The Dense kernel has shape (x.shape[-1], self.depth):
        # its first axis is replicated across devices and its second axis
        # is sharded across the 'model' mesh axis.
        y = nn.Dense(self.depth,
                     kernel_init=nn.with_partitioning(self.dense_init, (None, 'model')),
                     use_bias=False,  # or overwrite with `bias_init`
                     )(x)

        y = jax.nn.relu(y)
        # Force a local sharding annotation:
        # annotate intermediate variables to force a particular sharding pattern
        # when the ideal constraint is known.
        y = jax.lax.with_sharding_constraint(y, mesh_sharding(PartitionSpec('data', 'model')))

        W2 = self.param(
            'W2',
            nn.with_partitioning(self.dense_init, ('model', None)),
            (self.depth, x.shape[-1]))

        z = jnp.dot(y, W2)
        # Force a local sharding annotation.
        z = jax.lax.with_sharding_constraint(z, mesh_sharding(PartitionSpec('data', None)))

        # Return a tuple to conform with the `flax.linen.scan` API used in the cell below.
        return z, None


# %%
class MLP(nn.Module):
    num_layers: int
    depth: int
    use_scan: bool

    @nn.compact
    def __call__(self, x):
        if self.use_scan:
            x, _ = nn.scan(
                DotReluDot,
                length=self.num_layers,
                variable_axes={"params": 0},
                split_rngs={"params": True},
                metadata_params={nn.PARTITION_NAME: None}
            )(self.depth)(x)
        else:
            for _ in range(self.num_layers):
                x, _ = DotReluDot(self.depth)(x)
        return x

# %%
# MLP hyperparameters.
BATCH, LAYERS, DEPTH, USE_SCAN = 8, 4, 1024, False
# Create fake inputs.
x = jnp.ones((BATCH, DEPTH))
# Initialize a PRNG key.
k = jax.random.key(117)

# Create an Optax optimizer.
optimizer = optax.adam(learning_rate=0.001)
# Instantiate the model.
model = MLP(LAYERS, DEPTH, USE_SCAN)

# %%
# specify sharding

# shard data
x_sharding = mesh_sharding(PartitionSpec('data', None))  # shard the batch dim across 'data', replicate across 'model'
x = jax.device_put(x, x_sharding)
jax.debug.visualize_array_sharding(x)

# shard output
# we will shard the state by inspecting the output of jax.eval_shape on init
# define an init function to return a TrainState
def init_fn(key, x, model, optimizer):
    # be careful with model init:
    # imported models might have complicated init methods
    variables = model.init(key, x)  # Initialize the model.
    state = train_state.TrainState.create(  # Create a `TrainState`.
        apply_fn=model.apply,
        params=variables['params'],
        tx=optimizer)
    return state


# Create an abstract closure to wrap the function before feeding it in,
# because `jax.eval_shape` only takes pytrees as arguments.
# eval_shape(fn, rng_key, x)
# performs shape inference without allocating any memory and
# returns a nested PyTree with jax.ShapeDtypeStruct objects as leaves.
abstract_variables = jax.eval_shape(
    functools.partial(init_fn, model=model, optimizer=optimizer), k, x)

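# a quick look at the abstract output (illustrative, not in the original script):
# nothing has been materialized yet, only shapes and dtypes are known
print(abstract_variables.step)  # e.g. ShapeDtypeStruct(shape=(), dtype=int32)
print(nn.meta.unbox(abstract_variables.params['DotReluDot_0']['Dense_0']['kernel']))
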
# This `state_sharding` has the same pytree structure as `state`, the output
# of the `init_fn`.
# flax.linen.get_sharding
# extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh
# jax.sharding: describes how a jax.Array is laid out across devices
state_sharding = nn.get_sharding(abstract_variables, mesh)
print(state_sharding)

# %%
jit_init_fn = jax.jit(
    init_fn,
    static_argnames=('model', 'optimizer'),  # treat model and optimizer as static
    in_shardings=(mesh_sharding(()), x_sharding),  # for the PRNG key and the data
    out_shardings=state_sharding
)

initialized_state = jit_init_fn(k, x, model, optimizer)
# for weight, partitioned in initialized_state.params['DotReluDot_0'].items():
#     print(f'Sharding of {weight}: {partitioned.names}')
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)

# %%
# inspect module output
# the params are actually linen.Partitioned objects
# these Partitioned objects wrap the underlying jax.Array
print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel']))
print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value))
print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].names)
print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.shape)

# %%
# Say for some reason you want to make the whole param tree all-zero.
unboxed_params = nn.meta.unbox(initialized_state.params)
all_zero = jax.tree.map(jnp.zeros_like, unboxed_params)
all_zero_params = nn.meta.replace_boxed(initialized_state.params, all_zero)
assert jnp.sum(nn.meta.unbox(all_zero_params['DotReluDot_0']['Dense_0']['kernel'])) == 0

# %%
# check the jax.sharding of each parameter
print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.sharding)
print(initialized_state.step)
print(initialized_state.step.sharding)

# %%
# example of computation on a pytree
diff = jax.tree.map(
    lambda a, b: a - b,
    initialized_state.params['DotReluDot_0'],
    initialized_state.params['DotReluDot_0']
)
print(jax.tree.map(jnp.shape, diff))
diff_array = diff['Dense_0']['kernel'].value
print(type(diff_array))
print(diff_array.shape)

# %%
# compile train step
@functools.partial(
    jax.jit,
    in_shardings=(state_sharding, x_sharding),
    out_shardings=state_sharding
)
def train_step(state, x):
    def loss_unrolled(params):
        y = model.apply({'params': params}, x)
        return y.sum()
    grad_fn = jax.grad(loss_unrolled)
    grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state


# this trains for 1 step
# with mesh:  # not strictly necessary in this case
#     a `with mesh:` block gives an explicit scope for device sharding,
#     but mesh management is automatic via the jit sharding annotations
new_state = train_step(initialized_state, x)

print('Sharding of Weight 1:')
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
print('Sharding of Weight 2:')
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)


# %%
# compile inference step
@functools.partial(
    jax.jit,
    in_shardings=(state_sharding, x_sharding),
    out_shardings=x_sharding
)
def apply_fn(state, x):
    return state.apply_fn({'params': state.params}, x)


# this infers for 1 step
with mesh:
    y = apply_fn(new_state, x)
print(type(y))
print(y.dtype)
print(y.shape)
jax.debug.visualize_array_sharding(y)

# %%
# profiling
# measure performance
import timeit


def block_all(xs):
    jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)
    return xs


# with mesh:
t = timeit.timeit("block_all(train_step(initialized_state, x))", globals=globals(), number=10)
print(t)

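# %%
# optional illustration (not in the original script): capture a profiler trace that
# can be inspected in TensorBoard/Perfetto; the log directory is an arbitrary choice
with jax.profiler.trace("/tmp/jax-trace"):
    block_all(train_step(initialized_state, x))
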
# MARK: logical sharding
# %%
# logical axis annotation
# why?
# rather than being fixed to 'data' and 'model', we can annotate with logical names
# and then map these logical names back to 'data' and 'model',
# or be more flexible with more axes
#
# we substitute the following:
# flax.linen.with_partitioning      -> flax.linen.with_logical_partitioning
# jax.lax.with_sharding_constraint  -> flax.linen.with_logical_constraint

class LogicalDotReluDot(nn.Module):
    depth: int
    dense_init: Callable = nn.initializers.xavier_normal()

    @nn.compact
    def __call__(self, x):
        y = nn.Dense(
            self.depth,
            # use of logical partitioning here
            # kernel_init is the initializer function for the weight matrix
            kernel_init=nn.with_logical_partitioning(self.dense_init, ('embed', 'hidden')),
            use_bias=False,  # or overwrite with `bias_init`
        )(x)

        y = jax.nn.relu(y)
        # Force a local sharding annotation, now in terms of logical axis names.
        y = nn.with_logical_constraint(y, ('batch', 'hidden'))

        W2 = self.param(
            'W2',
            nn.with_logical_partitioning(self.dense_init, ('hidden', 'embed')),
            (self.depth, x.shape[-1]))

        z = jnp.dot(y, W2)
        # Force a local sharding annotation.
        z = nn.with_logical_constraint(z, ('batch', 'embed'))
        return z, None


class LogicalMLP(nn.Module):
    num_layers: int
    depth: int
    use_scan: bool

    @nn.compact
    def __call__(self, x):
        if self.use_scan:
            x, _ = nn.scan(LogicalDotReluDot, length=self.num_layers,
                           variable_axes={"params": 0},
                           split_rngs={"params": True},
                           metadata_params={nn.PARTITION_NAME: 'layer'}
                           )(self.depth)(x)
        else:
            for _ in range(self.num_layers):
                x, _ = LogicalDotReluDot(self.depth)(x)
        return x


# %%
# we initialize the model with the eval_shape method,
# but we need rules that map logical axis names to real mesh axes
rules = (
    ('batch', 'data'),
    ('hidden', 'model'),
)

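# quick check (illustrative, not in the original script): see how the rules map a
# logical spec onto mesh axes; 'embed' has no rule, so it maps to None (replicated)
print(nn.logical_to_mesh_axes(('batch', 'hidden'), rules))  # PartitionSpec('data', 'model')
print(nn.logical_to_mesh_axes(('embed', 'hidden'), rules))  # PartitionSpec(None, 'model')
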
logical_model = LogicalMLP(LAYERS, DEPTH, USE_SCAN)

logical_abstract_variables = jax.eval_shape(
    functools.partial(init_fn, model=logical_model, optimizer=optimizer),
    k,
    x,
)

# linen.get_partition_spec
# extracts a PartitionSpec tree from a pytree containing Partitioned values
# linen.Partitioned
# wrapper for partitioning metadata
logical_state_spec = nn.get_partition_spec(logical_abstract_variables)
print(
    "annotations are logical, not mesh-specific:",
    logical_state_spec.params['LogicalDotReluDot_0']['Dense_0']['kernel']
)

# we convert our logical_state_spec to a logical_state_sharding
# with the defined rules
logical_state_sharding = nn.logical_to_mesh_sharding(
    logical_state_spec,
    mesh,
    rules
)
print('sharding annotations are mesh-specific: ',
      logical_state_sharding.params['LogicalDotReluDot_0']['Dense_0']['kernel'].spec)

# %%
# with a working state_sharding object, we can init
logical_jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3),
                              in_shardings=(mesh_sharding(()), x_sharding),  # PRNG key and x
                              out_shardings=logical_state_sharding)

logical_initialized_state = logical_jit_init_fn(k, x, logical_model, optimizer)

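# %%
# sketch (not in the original script): `nn.with_logical_constraint` inside the module
# resolves logical names through axis rules, so put the rules in scope (here via the
# `nn.logical_axis_rules` context manager) when running a forward pass
logical_apply = jax.jit(
    lambda s, xb: s.apply_fn({'params': s.params}, xb),
    in_shardings=(logical_state_sharding, x_sharding),
    out_shardings=x_sharding,
)
with mesh, nn.logical_axis_rules(rules):
    logical_y = logical_apply(logical_initialized_state, x)
print(logical_y.shape)
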
# %%
# MARK: saving checkpoint
# https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#multi-host-multi-process-checkpointing
# let us save the model
# since we already have a mesh, we will skip the making of the mesh

from typing import Optional, Any
import shutil

import numpy as np
import jax
from jax import random, numpy as jnp

import flax
from flax import linen as nn
from flax.training import checkpoints, train_state
from flax import struct, serialization
import orbax.checkpoint

import optax

ckpt_dir = '/tmp/flax_ckpt'

if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)  # Remove any existing checkpoints from the last notebook run.

# %%
# make up some toy objects to checkpoint
# A simple model with one linear layer.
key1, key2 = random.split(random.key(0))
x1 = random.normal(key1, (5,))  # A simple JAX array.
model = nn.Dense(features=3)
variables = model.init(key2, x1)

# Flax's TrainState is a pytree dataclass and is supported in checkpointing.
# Define your own class with the `@flax.struct.dataclass` decorator to make it compatible.
tx = optax.sgd(learning_rate=0.001)  # An Optax SGD optimizer.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx)
# Perform a simple gradient update similar to the one during a normal training workflow.
state = state.apply_gradients(grads=jax.tree_util.tree_map(jnp.ones_like, state.params))

# Some arbitrary nested pytree with a dictionary and a NumPy array.
config = {'dimensions': np.array([5, 3])}

# Bundle everything together.
ckpt = {'model': state, 'config': config, 'data': [x1]}

# %%
# single-host save with Orbax
from flax.training import orbax_utils

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save('/tmp/flax_ckpt/orbax/single_save', ckpt, save_args=save_args)

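# %%
# restore sketch (not in the original script): reading the checkpoint back returns
# a nested dict; pass `item=ckpt` if you want Orbax to rebuild the original structure
raw_restored = orbax_checkpointer.restore('/tmp/flax_ckpt/orbax/single_save')
print(raw_restored['config'])
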
# %%
# multi-process checkpointing
# aka checkpointing for sharded state
# The reference doesn't need to be as large as your checkpoint!
# Just make sure it has the `.sharding` you want.
# https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html
# https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html

import orbax.checkpoint as ocp
from etils import epath

path = epath.Path('/tmp/async_checkpoint')
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
# save the TrainState created above (`state`), not the `train_state` module
ckptr.save(path, args=ocp.args.StandardSave(state))
### Do some other work...
ckptr.wait_until_finished()
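
# %%
# restore sketch (not in the original script): pass the existing `state` as the
# reference target so Orbax knows the tree structure, dtypes, and shardings to restore into
restored_state = ckptr.restore(path, args=ocp.args.StandardRestore(state))
print(jax.tree_util.tree_structure(restored_state))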