# learn_jax/parallel/flax_pjit_tutorial.py
#
# 457 lines
# 15 KiB
# Python

# MARK: START
# %%
# Configure XLA through environment variables. This runs before `jax` is
# imported further down in this file.
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True

flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    # Simulate 4 host devices; this matches the 2x2 ('data', 'model') mesh
    # created in the sharding example.
    flags += " --xla_force_host_platform_device_count=4"
    # Enforce CPU-only execution.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    os.environ["JAX_PLATFORMS"] = "cpu"
else:
    # GPU performance flags. Each flag starts with a space so the first one
    # is not fused onto whatever XLA_FLAGS already contained.
    # NOTE(review): some of these flags may not exist in every XLA release;
    # verify them against the installed jaxlib.
    flags += (
        " --xla_gpu_enable_triton_softmax_fusion=true"
        " --xla_gpu_triton_gemm_any=false"
        " --xla_gpu_enable_async_collectives=true"
        " --xla_gpu_enable_latency_hiding_scheduler=true"
        " --xla_gpu_enable_highest_priority_async_stream=true"
    )
os.environ["XLA_FLAGS"] = flags
import functools
from functools import partial
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
# from jax.experimental.pjit import pjit # superseded by jax.jit
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec
from ml_collections import ConfigDict
import optax
import logging
import time
from datasets import Dataset, load_from_disk
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
import flax.core
from tqdm import tqdm
from dataload import DataPrepare
# Type aliases used for annotations in this tutorial.
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

# Backend-specific JAX configuration: pin the platform on CPU, otherwise
# lower the default matmul precision for speed.
if not USE_CPU_ONLY:
    jax.config.update("jax_default_matmul_precision", "bfloat16")
else:
    jax.config.update("jax_platform_name", "cpu")
# MARK: sharding example
# %%
# Arrange the devices into a 2x2 grid and name its two axes: 'data' (used to
# split the batch) and 'model' (used to split the weights).
device_mesh = mesh_utils.create_device_mesh((2, 2))
print(device_mesh)
mesh = Mesh(device_mesh, ('data', 'model'))
print(mesh)
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
    """Bind a PartitionSpec to the module-level `mesh`.

    Args:
        pspec: Mapping from each array dimension to a mesh axis (or None).

    Returns:
        A `NamedSharding` describing that layout on `mesh`.
    """
    sharding = NamedSharding(mesh, pspec)
    return sharding
# %%
# Define a layer.
class DotReluDot(nn.Module):
    """Dense -> ReLU -> matmul block with explicit mesh-axis sharding.

    The first kernel is split along its output dimension over the 'model'
    mesh axis and the second weight along its input dimension, so the hidden
    activation can be sharded over both 'data' and 'model'.

    Attributes:
        depth: Width of the hidden layer.
        dense_init: Initializer used for both weight matrices.
    """
    depth: int
    dense_init: Callable = nn.initializers.xavier_normal()

    @nn.compact
    def __call__(self, x):
        # The kernel has shape (x.shape[-1], depth): it is not partitioned on
        # its first dimension and is split over 'model' on its second.
        dense = nn.Dense(
            self.depth,
            kernel_init=nn.with_partitioning(self.dense_init, (None, 'model')),
            use_bias=False,  # or overwrite with `bias_init`
        )
        hidden = jax.nn.relu(dense(x))
        # Annotate the intermediate activation to force a known-good layout.
        hidden = jax.lax.with_sharding_constraint(
            hidden, mesh_sharding(PartitionSpec('data', 'model')))

        # Second weight, shape (depth, x.shape[-1]), split over 'model' on its
        # first dimension.
        w2 = self.param(
            'W2',
            nn.with_partitioning(self.dense_init, ('model', None)),
            (self.depth, x.shape[-1]),
        )
        out = jnp.dot(hidden, w2)
        # Output: first dimension over 'data', second not partitioned.
        out = jax.lax.with_sharding_constraint(
            out, mesh_sharding(PartitionSpec('data', None)))
        # A (carry, output) tuple, as required by `flax.linen.scan` in MLP.
        return out, None
# %%
class MLP(nn.Module):
    """A stack of `num_layers` DotReluDot blocks.

    Attributes:
        num_layers: Number of DotReluDot blocks.
        depth: Hidden width passed to every block.
        use_scan: If True, build the stack with `nn.scan` (parameters are
            stacked along a new leading axis); otherwise use a Python loop.
    """
    num_layers: int
    depth: int
    use_scan: bool

    @nn.compact
    def __call__(self, x):
        if not self.use_scan:
            for _layer in range(self.num_layers):
                x, _ = DotReluDot(self.depth)(x)
            return x
        scanned_block = nn.scan(
            DotReluDot,
            length=self.num_layers,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            # The stacked layer axis gets no partition name.
            metadata_params={nn.PARTITION_NAME: None},
        )
        x, _ = scanned_block(self.depth)(x)
        return x
# %%
# MLP hyperparameters.
BATCH = 8
LAYERS = 4
DEPTH = 1024
USE_SCAN = False

# Fake all-ones inputs of shape (BATCH, DEPTH).
x = jnp.ones((BATCH, DEPTH))
# PRNG key for parameter initialisation.
k = jax.random.key(117)
# Optax Adam optimizer.
optimizer = optax.adam(learning_rate=0.001)
# The model under test.
model = MLP(LAYERS, DEPTH, USE_SCAN)
# %%
# Specify shardings.
# Input data: the batch (first) dimension is split across the 'data' mesh
# axis. The feature dimension is not partitioned, so each batch shard is
# replicated across the 'model' axis.
x_sharding = mesh_sharding(PartitionSpec('data', None))
x = jax.device_put(x, x_sharding)
jax.debug.visualize_array_sharding(x)
# Output (state) sharding: obtained below by tracing this init function with
# `jax.eval_shape` and reading the partitioning metadata off its result.
def init_fn(key, x, model, optimizer):
    """Initialise `model` and wrap its parameters in a TrainState.

    Args:
        key: PRNG key passed to `model.init`.
        x: Example input used by `model.init`.
        model: Flax module to initialise. Be careful with imported models:
            their init may be more complicated.
        optimizer: Optax transformation stored as the state's `tx`.

    Returns:
        A `train_state.TrainState` holding `model.apply`, the initialised
        'params' collection and `optimizer`.
    """
    params = model.init(key, x)['params']
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer,
    )
# `jax.eval_shape` performs shape inference only, and returns a pytree of
# `jax.ShapeDtypeStruct` leaves. Its traced arguments must be pytrees, so the
# non-array `model` and `optimizer` are bound with `functools.partial` first.
abstract_variables = jax.eval_shape(
    functools.partial(init_fn, model=model, optimizer=optimizer),
    k,
    x,
)
# `flax.linen.get_sharding` converts a pytree containing `Partitioned`
# metadata into a matching pytree of `jax.sharding` objects on `mesh`.
# The result therefore has the same structure as `init_fn`'s output.
state_sharding = nn.get_sharding(abstract_variables, mesh)
print(state_sharding)
# %%
# Compile the initialiser with explicit input/output shardings so the
# resulting state is laid out according to `state_sharding`.
jit_init_fn = jax.jit(
    init_fn,
    static_argnames=('model', 'optimizer'),  # non-array, hashable arguments
    # Remaining positional args: the PRNG key is fully replicated (explicit
    # empty PartitionSpec, as NamedSharding expects a PartitionSpec), and the
    # data uses `x_sharding`.
    in_shardings=(mesh_sharding(PartitionSpec()), x_sharding),
    out_shardings=state_sharding,
)
initialized_state = jit_init_fn(k, x, model, optimizer)
# Show how the first layer's two weights are laid out on the mesh.
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)
# %%
# Inspect the module output. Each parameter is a flax `Partitioned` box
# wrapping a `jax.Array` (`.value`) together with its partition `.names`.
print(
    type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel']),
    type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value),
    initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].names,
    initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.shape,
    sep='\n',
)
# %%
# Example: set every parameter to zero while keeping the partition metadata.
# Strip the boxes, transform the raw arrays, then put them back in the boxes.
unboxed_params = nn.meta.unbox(initialized_state.params)
all_zero = jax.tree.map(jnp.zeros_like, unboxed_params)
all_zero_params = nn.meta.replace_boxed(initialized_state.params, all_zero)
assert nn.meta.unbox(all_zero_params['DotReluDot_0']['Dense_0']['kernel']).sum() == 0
# %%
# Print a parameter's `jax.sharding`, then the step counter and its sharding.
print(
    initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.sharding,
    initialized_state.step,
    initialized_state.step.sharding,
    sep='\n',
)
# %%
# Example of leaf-wise computation over a pytree: subtract the first layer's
# parameters from themselves. The result keeps the boxed structure, so the
# raw array is reached through `.value`.
diff = jax.tree.map(
    lambda lhs, rhs: lhs - rhs,
    initialized_state.params['DotReluDot_0'],
    initialized_state.params['DotReluDot_0'],
)
print(jax.tree.map(jnp.shape, diff))
diff_array = diff['Dense_0']['kernel'].value
print(type(diff_array), diff_array.shape, sep='\n')
# %%
# Compile the train step with the same shardings used for initialisation.
@functools.partial(
    jax.jit,
    in_shardings=(state_sharding, x_sharding),
    out_shardings=state_sharding,
)
def train_step(state, x):
    """Take one optimizer step on a dummy loss (sum of the model output).

    Args:
        state: TrainState laid out according to `state_sharding`.
        x: Input batch laid out according to `x_sharding`.

    Returns:
        The updated TrainState, laid out according to `state_sharding`.
    """
    def loss_fn(params):
        # Use the apply function stored in the state rather than the global
        # `model`, which is rebound later in this file (to an `nn.Dense`) and
        # would silently change the traced model on any re-trace.
        y = state.apply_fn({'params': params}, x)
        return y.sum()

    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)
# Train for one step. A `with mesh:` block is not strictly necessary here:
# it gives an explicit device-sharding scope, but the sharding annotations
# passed to `jax.jit` already determine the layout.
new_state = train_step(initialized_state, x)
# Show the layout of the *updated* weights (requested via `out_shardings`).
print('Sharding of Weight 1:')
jax.debug.visualize_array_sharding(new_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
print('Sharding of Weight 2:')
jax.debug.visualize_array_sharding(new_state.params['DotReluDot_0']['W2'].value)
# %%
# Compile the inference step: the state keeps its training layout and the
# output uses the same layout as the input batch.
@functools.partial(
    jax.jit,
    in_shardings=(state_sharding, x_sharding),
    out_shardings=x_sharding,
)
def apply_fn(state, x):
    """Run the model stored in `state` on `x` and return its output."""
    variables = {'params': state.params}
    return state.apply_fn(variables, x)


# Run one inference step inside the mesh context.
with mesh:
    y = apply_fn(new_state, x)
print(type(y), y.dtype, y.shape, sep='\n')
jax.debug.visualize_array_sharding(y)
# %%
# Profiling: measure the performance of the compiled train step.
import timeit


def block_all(xs):
    """Wait on every array leaf of the pytree `xs`, then return `xs` unchanged.

    Blocking makes wall-clock timings include the actual device computation.
    """
    for leaf in jax.tree_util.tree_leaves(xs):
        leaf.block_until_ready()
    return xs
# Time 10 calls of the train step (already compiled by the call above) and
# print the total wall-clock time in seconds. No `with mesh:` is needed.
t = timeit.timeit(
    "block_all(train_step(initialized_state, x))",
    number=10,
    globals=globals(),
)
print(t)
# MARK: logical sharding
# %%
# Logical axis annotation
# Why? Rather than hard-coding the physical mesh axes 'data' and 'model',
# we annotate parameters and activations with logical names (e.g. 'batch',
# 'embed', 'hidden'). Those names are later mapped onto mesh axes by a set of
# rules, which decouples the model code from the mesh layout and makes it easy
# to remap names or add more axes.
#
# We substitute:
# flax.linen.with_partitioning -> flax.linen.with_logical_partitioning
# jax.lax.with_sharding_constraint -> flax.linen.with_logical_constraint
class LogicalDotReluDot(nn.Module):
    """DotReluDot variant whose weights carry logical axis names.

    Both weights are annotated with the logical names 'embed' and 'hidden'.
    The hidden activation is still pinned with a physical ('data', 'model')
    constraint, while the output uses a logical ('batch', 'embed') one.

    Attributes:
        depth: Width of the hidden layer.
        dense_init: Initializer used for both weight matrices.
    """
    depth: int
    dense_init: Callable = nn.initializers.xavier_normal()

    @nn.compact
    def __call__(self, x):
        # `kernel_init` initialises the (x.shape[-1], depth) weight matrix,
        # whose axes are named ('embed', 'hidden').
        dense = nn.Dense(
            self.depth,
            kernel_init=nn.with_logical_partitioning(self.dense_init, ('embed', 'hidden')),
            use_bias=False,  # or overwrite with `bias_init`
        )
        hidden = jax.nn.relu(dense(x))
        # Physical (mesh-axis) constraint on the intermediate activation.
        hidden = jax.lax.with_sharding_constraint(
            hidden, mesh_sharding(PartitionSpec('data', 'model')))

        w2 = self.param(
            'W2',
            nn.with_logical_partitioning(self.dense_init, ('hidden', 'embed')),
            (self.depth, x.shape[-1]),
        )
        out = jnp.dot(hidden, w2)
        # Logical constraint on the output.
        # NOTE(review): presumably only resolved to mesh axes when
        # logical-axis rules are active (e.g. `nn.logical_axis_rules`) —
        # confirm.
        out = nn.with_logical_constraint(out, ('batch', 'embed'))
        return out, None
class LogicalMLP(nn.Module):
    """A stack of `num_layers` LogicalDotReluDot blocks.

    Attributes:
        num_layers: Number of blocks.
        depth: Hidden width passed to every block.
        use_scan: If True, stack the blocks with `nn.scan`; the new leading
            parameter axis gets the logical name 'layer'.
    """
    num_layers: int
    depth: int
    use_scan: bool

    @nn.compact
    def __call__(self, x):
        if not self.use_scan:
            for _layer in range(self.num_layers):
                x, _ = LogicalDotReluDot(self.depth)(x)
            return x
        scanned_block = nn.scan(
            LogicalDotReluDot,
            length=self.num_layers,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            metadata_params={nn.PARTITION_NAME: 'layer'},
        )
        x, _ = scanned_block(self.depth)(x)
        return x
# %%
# Initialise abstractly with `jax.eval_shape`, as before. The logical axis
# names must then be translated into mesh axes with a set of rules, each
# mapping one logical name to one mesh axis. Only 'batch' and 'hidden' have
# rules here.
rules = (
    ('batch', 'data'),
    ('hidden', 'model'),
)
logical_model = LogicalMLP(LAYERS, DEPTH, USE_SCAN)
logical_abstract_variables = jax.eval_shape(
    functools.partial(init_fn, model=logical_model, optimizer=optimizer),
    k,
    x,
)

# `nn.get_partition_spec` extracts a PartitionSpec tree from a pytree that
# contains partitioning metadata boxes; at this point the specs still hold
# logical names.
logical_state_spec = nn.get_partition_spec(logical_abstract_variables)
print(
    "annotations are logical, not mesh specific",
    logical_state_spec.params['LogicalDotReluDot_0']['Dense_0']['kernel'],
)

# Resolve the logical specs into mesh-specific shardings using `rules`.
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, rules)
print(
    'sharding annotations are mesh-specific: ',
    logical_state_sharding.params['LogicalDotReluDot_0']['Dense_0']['kernel'].spec,
)
# %%
# With a mesh-specific `logical_state_sharding`, run the real initialisation.
# NOTE(review): the model's `nn.with_logical_constraint` may need
# `nn.logical_axis_rules(rules)` to be active to resolve its names — confirm
# whether this call should run under that context.
logical_jit_init_fn = jax.jit(
    init_fn,
    static_argnames=('model', 'optimizer'),  # same as the physical version
    # PRNG key fully replicated (explicit empty PartitionSpec); data uses
    # `x_sharding`.
    in_shardings=(mesh_sharding(PartitionSpec()), x_sharding),
    out_shardings=logical_state_sharding,
)
logical_initialized_state = logical_jit_init_fn(k, x, logical_model, optimizer)
# %%
# MARK: saving checkpoint
# https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#multi-host-multi-process-checkpointing
# let us save the model
# since we already have a mesh, we will skip the making of the mesh
from typing import Optional, Any
import shutil
import numpy as np
import jax
from jax import random, numpy as jnp
import flax
from flax import linen as nn
from flax.training import checkpoints, train_state
from flax import struct, serialization
import orbax.checkpoint
import optax
ckpt_dir = '/tmp/flax_ckpt'
# Remove any existing checkpoints from the last notebook run.
if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)

# %%
# Build something small to checkpoint: a model with one linear layer.
# NOTE: this rebinds the module-level name `model` (previously the MLP).
key1, key2 = random.split(random.key(0))
x1 = random.normal(key1, (5,))  # a simple JAX array
model = nn.Dense(features=3)
variables = model.init(key2, x1)

# Flax's TrainState is a pytree dataclass and is supported in checkpointing.
# Custom classes need the `@flax.struct.dataclass` decorator to be compatible.
tx = optax.sgd(learning_rate=0.001)  # an Optax SGD optimizer
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx,
)
# One update with all-ones gradients, mimicking a normal training step.
state = state.apply_gradients(grads=jax.tree.map(jnp.ones_like, state.params))

# An arbitrary nested pytree with a dictionary and a NumPy array.
config = {'dimensions': np.array([5, 3])}
# Bundle everything into one checkpoint pytree.
ckpt = {'model': state, 'config': config, 'data': [x1]}
# %%
# Single-host save with Orbax.
from flax.training import orbax_utils

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
# Derive the save arguments from the checkpoint pytree itself.
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save(
    '/tmp/flax_ckpt/orbax/single_save',
    ckpt,
    save_args=save_args,
)
# %%
# multi-process checkpointing
# aka checkpointing for sharding
# The reference doesn't need to be as large as your checkpoint!
# Just make sure it has the `.sharding` you want.
# https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html
# https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html
import orbax.checkpoint as ocp
from etils import epath
path = epath.Path('/tmp/async_checkpoint')
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
# Save the TrainState instance `state` built above — a pytree of arrays. The
# name `train_state` refers to the `flax.training.train_state` module, which
# is not checkpointable. `force=True` overwrites a checkpoint left behind by
# a previous run instead of failing because the directory already exists.
ckptr.save(path, args=ocp.args.StandardSave(state), force=True)
# ... other work can happen here while the save proceeds ...
ckptr.wait_until_finished()