# MARK: START
# %%
# let's make a 4-device simulator (the mesh built below is 2 x 2)
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True

# XLA reads XLA_FLAGS when the backend starts, so the environment has to be
# prepared before `jax` is imported further down.
flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    flags += " --xla_force_host_platform_device_count=4"  # Simulate 4 devices
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    os.environ["JAX_PLATFORMS"] = "cpu"
else:
    # GPU flags
    # The leading space keeps the first flag from fusing with whatever
    # XLA_FLAGS already contained (the CPU branch above does the same).
    flags += (
        " --xla_gpu_enable_triton_softmax_fusion=true "
        "--xla_gpu_triton_gemm_any=false "
        "--xla_gpu_enable_async_collectives=true "
        "--xla_gpu_enable_latency_hiding_scheduler=true "
        "--xla_gpu_enable_highest_priority_async_stream=true "
    )
os.environ["XLA_FLAGS"] = flags

import functools
from functools import partial
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
# from jax.experimental.pjit import pjit  # superseded by jax.jit
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec
from ml_collections import ConfigDict
import optax
import logging
import time
from datasets import Dataset, load_from_disk
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
import flax.core
from tqdm import tqdm

from dataload import DataPrepare

PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

if USE_CPU_ONLY:
    jax.config.update('jax_platform_name', 'cpu')
else:
    jax.config.update("jax_default_matmul_precision", "bfloat16")

# MARK: sharding example
# %%
# 2 x 2 device grid: first axis is named 'data', second axis 'model'.
device_mesh = mesh_utils.create_device_mesh((2, 2))
print(device_mesh)

mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)


def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
    """Bind `pspec` to the module-level `mesh` as a `NamedSharding`."""
    return NamedSharding(mesh, pspec)


# %%
# define a layer
class DotReluDot(nn.Module):
    """Dense -> ReLU -> matmul block with mesh-axis sharding annotations.

    Attributes:
        depth: output width of the first projection.
        dense_init: initializer used for both weight matrices.
    """
    depth: int
    dense_init: Callable = nn.initializers.xavier_normal()

    @nn.compact
    def __call__(self, x):
        """Apply the block to `x`; returns `(z, None)`."""
        # The Dense kernel has shape (x.shape[-1], self.depth).
        # Its partitioning annotation is (None, 'model'): the second dim is
        # split over the 'model' mesh axis, the first is left unpartitioned.
        y = nn.Dense(
            self.depth,
            kernel_init=nn.with_partitioning(self.dense_init, (None, 'model')),
            use_bias=False,  # or overwrite with `bias_init`
        )(x)

        y = jax.nn.relu(y)
        # Force a local sharding annotation.
        # annotate intermediate variables to force a particular sharding pattern
        # when ideal constraint is known
        y = jax.lax.with_sharding_constraint(
            y, mesh_sharding(PartitionSpec('data', 'model')))

        # Second weight maps back to the input width; annotated ('model', None).
        W2 = self.param(
            'W2',
            nn.with_partitioning(self.dense_init, ('model', None)),
            (self.depth, x.shape[-1]))

        z = jnp.dot(y, W2)
        # Force a local sharding annotation.
        z = jax.lax.with_sharding_constraint(
            z, mesh_sharding(PartitionSpec('data', None)))

        # Return a tuple to conform with the API `flax.linen.scan` as shown in the cell below.
        return z, None


# %%
class MLP(nn.Module):
    """Stack of `num_layers` `DotReluDot` blocks.

    Attributes:
        num_layers: number of blocks applied in sequence.
        depth: forwarded to every `DotReluDot`.
        use_scan: if True, build the stack with `nn.scan`; otherwise unroll
            it with a Python loop.
    """
    num_layers: int
    depth: int
    use_scan: bool

    @nn.compact
    def __call__(self, x):
        """Run `x` through all blocks and return the final activations."""
        if self.use_scan:
            x, _ = nn.scan(
                DotReluDot,
                length=self.num_layers,
                variable_axes={"params": 0},
                split_rngs={"params": True},
                metadata_params={nn.PARTITION_NAME: None},
            )(self.depth)(x)
        else:
            for _ in range(self.num_layers):
                x, _ = DotReluDot(self.depth)(x)
        return x


# %%
# MLP hyperparameters.
BATCH, LAYERS, DEPTH, USE_SCAN = 8, 4, 1024, False
# Create fake inputs.
x = jnp.ones((BATCH, DEPTH))
# Initialize a PRNG key.
k = jax.random.key(117)

# Create an Optax optimizer.
optimizer = optax.adam(learning_rate=0.001)
# Instantiate the model.
model = MLP(LAYERS, DEPTH, USE_SCAN)

# %%
# specify sharding
# shard data
# PartitionSpec('data', None): dim 0 of the input is split over the 'data'
# mesh axis; dim 1 carries no partitioning.
x_sharding = mesh_sharding(PartitionSpec('data', None))
x = jax.device_put(x, x_sharding)
jax.debug.visualize_array_sharding(x)

# shard output
# we will shard state by tracking its output upon jax.eval_shape after init


# define an init function to return a TrainState
def init_fn(key, x, model, optimizer):
    """Initialize `model` on `x` and wrap the result in a `TrainState`.

    Args:
        key: PRNG key handed to `model.init`.
        x: example input handed to `model.init`.
        model: module providing `init` and `apply`.
        optimizer: Optax transformation stored as the state's `tx`.

    Returns:
        A `train_state.TrainState` holding the 'params' collection.
    """
    # do be careful with the model init
    # imported models might have complicated init methods
    init_vars = model.init(key, x)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=init_vars['params'],
        tx=optimizer,
    )


# Create an abstract closure to wrap the function before feeding it in
# because `jax.eval_shape` only takes pytrees as arguments.
# eval_shape(fn, rng_key, x)
#   used to perform shape inference
#   returns a nested PyTree containing jax.ShapeDtypeStruct objects as leaves
abstract_variables = jax.eval_shape(
    partial(init_fn, model=model, optimizer=optimizer),
    k,
    x,
)

# This `state_sharding` has the same pytree structure as `state`, the output
# of the `init_fn`.
# flax.linen.get_sharding
#   extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh
# jax.sharding: describes how a jax.Array is laid out across devices
state_sharding = nn.get_sharding(abstract_variables, mesh)
print(state_sharding)

# %%
jit_init_fn = jax.jit(
    init_fn,
    static_argnames=('model', 'optimizer'),  # skip model and optimizer
    # An empty PartitionSpec names no mesh axis for the PRNG key; `mesh_sharding`
    # is annotated to take a PartitionSpec, so pass one instead of a bare tuple.
    in_shardings=(mesh_sharding(PartitionSpec()), x_sharding),  # for PRNG key and data
    out_shardings=state_sharding,
)

initialized_state = jit_init_fn(k, x, model, optimizer)

# for weight, partitioned in initialized_state.params['DotReluDot_0'].items():
#     print(f'Sharding of {weight}: {partitioned.names}')
jax.debug.visualize_array_sharding(
    initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
jax.debug.visualize_array_sharding(
    initialized_state.params['DotReluDot_0']['W2'].value)

# %%
# inspect module output
# the params are actually linen.Partitioned objects
# the Partition objects are actually wrapping jax.Array
print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel']))
print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value))
print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].names)
print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.shape)

# %%
# Say for some unknown reason you want to make the whole param tree all-zero
unboxed_params = nn.meta.unbox(initialized_state.params)
all_zero = jax.tree.map(jnp.zeros_like, unboxed_params)
all_zero_params = nn.meta.replace_boxed(initialized_state.params, all_zero)
assert jnp.sum(nn.meta.unbox(all_zero_params['DotReluDot_0']['Dense_0']['kernel'])) == 0

# %%
# check the jax.sharding of each parameter
print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.sharding)
print(initialized_state.step)
print(initialized_state.step.sharding)

# %%
# example of computation on pytree
diff = jax.tree.map(
    lambda a, b: a - b,
    initialized_state.params['DotReluDot_0'],
    initialized_state.params['DotReluDot_0'],
)
print(jax.tree.map(jnp.shape, diff))
diff_array = diff['Dense_0']['kernel'].value
print(type(diff_array))
print(diff_array.shape)


# %%
# compile train step
@functools.partial(
    jax.jit,
    in_shardings=(state_sharding, x_sharding),
    out_shardings=state_sharding,
)
def train_step(state, x):
    """Run one gradient step and return the updated state.

    The loss is the sum of the model output on `x`; the forward pass uses the
    module-level `model`.

    Args:
        state: current `TrainState`, laid out as `state_sharding`.
        x: input batch, laid out as `x_sharding`.

    Returns:
        The state after `apply_gradients`, laid out as `state_sharding`.
    """
    def loss_unrolled(params):
        y = model.apply({'params': params}, x)
        return y.sum()

    grad_fn = jax.grad(loss_unrolled)
    grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state


# this trains for 1 step
# with mesh: # not strictly necessary in this case
# with mesh block is useful for explicit scope for device sharding
# but mesh management is automatic via jit sharding annotations
with mesh:
    new_state = train_step(initialized_state, x)

# Inspect the *updated* state so the printout reflects train_step's
# out_shardings rather than the pre-update parameters.
print('Sharding of Weight 1:')
jax.debug.visualize_array_sharding(
    new_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
print('Sharding of Weight 2:')
jax.debug.visualize_array_sharding(
    new_state.params['DotReluDot_0']['W2'].value)


# %%
# compile inference step
@functools.partial(
    jax.jit,
    in_shardings=(state_sharding, x_sharding),
    out_shardings=x_sharding,
)
def apply_fn(state, x):
    """Forward pass: call `state.apply_fn` with `state.params` on `x`."""
    return state.apply_fn({'params': state.params}, x)


# this infers for 1 step
with mesh:
    y = apply_fn(new_state, x)
print(type(y))
print(y.dtype)
print(y.shape)
jax.debug.visualize_array_sharding(y)

# %%
# profiling
# measure performance
import timeit


def block_all(xs):
    """Wait until every array leaf of the pytree `xs` is ready; return `xs`."""
    return jax.block_until_ready(xs)


# with mesh:
# Total wall time of 10 already-compiled train steps (train_step was traced
# and compiled by the call above).
t = timeit.timeit(
    lambda: block_all(train_step(initialized_state, x)),
    number=10,
)
print(t)

# MARK: logical sharding
# %%
# logical axis annotation
# why?
# rather than just be fixed with 'data' and 'model', we can annotate with logical names
# then map these logical names back to 'data' and 'model',
# or be more flexible with more axes
#
# we will substitute with the following:
#   flax.linen.with_partitioning -> flax.linen.with_logical_partitioning
#   jax.lax.with_sharding_constraint -> flax.linen.with_logical_constraint
class LogicalDotReluDot(nn.Module):
    """Dense -> ReLU -> matmul block annotated with logical axis names only.

    The logical names used are 'batch', 'embed' and 'hidden'; they are mapped
    to mesh axes by whatever logical axis rules the caller supplies.

    Attributes:
        depth: output width of the first projection.
        dense_init: initializer used for both weight matrices.
    """
    depth: int
    dense_init: Callable = nn.initializers.xavier_normal()

    @nn.compact
    def __call__(self, x):
        """Apply the block to `x`; returns `(z, None)`."""
        y = nn.Dense(
            self.depth,
            # use of logical partitioning here
            # kernel_init is the initializer function for the weight matrix
            kernel_init=nn.with_logical_partitioning(self.dense_init, ('embed', 'hidden')),
            use_bias=False,  # or overwrite with `bias_init`
        )(x)

        y = jax.nn.relu(y)
        # Force a local sharding annotation.
        # Logical names instead of the mesh axes ('data', 'model'), so this
        # module no longer depends on one particular mesh layout.
        y = nn.with_logical_constraint(y, ('batch', 'hidden'))

        W2 = self.param(
            'W2',
            nn.with_logical_partitioning(self.dense_init, ('hidden', 'embed')),
            (self.depth, x.shape[-1]))

        z = jnp.dot(y, W2)
        # Force a local sharding annotation.
        z = nn.with_logical_constraint(z, ('batch', 'embed'))
        return z, None


class LogicalMLP(nn.Module):
    """Stack of `num_layers` `LogicalDotReluDot` blocks.

    Attributes:
        num_layers: number of blocks applied in sequence.
        depth: forwarded to every `LogicalDotReluDot`.
        use_scan: if True, build the stack with `nn.scan` (stacked parameter
            axis named 'layer'); otherwise unroll it with a Python loop.
    """
    num_layers: int
    depth: int
    use_scan: bool

    @nn.compact
    def __call__(self, x):
        """Run `x` through all blocks and return the final activations."""
        if self.use_scan:
            x, _ = nn.scan(
                LogicalDotReluDot,
                length=self.num_layers,
                variable_axes={"params": 0},
                split_rngs={"params": True},
                metadata_params={nn.PARTITION_NAME: 'layer'},
            )(self.depth)(x)
        else:
            for _ in range(self.num_layers):
                x, _ = LogicalDotReluDot(self.depth)(x)
        return x


# %%
# we initialize the model with the eval_shape method
# but we need to perform a rule to replace logical axes with real axes
# ('embed' has no rule here, so it is not mapped to any mesh axis)
rules = (
    ('batch', 'data'),
    ('hidden', 'model'),
)

logical_model = LogicalMLP(LAYERS, DEPTH, USE_SCAN)

logical_abstract_variables = jax.eval_shape(
    functools.partial(init_fn, model=logical_model, optimizer=optimizer),
    k,
    x,
)

# linen.get_partition_spec
#   extracts a partitionspec tree from a pytree containing partitioned values
# linen.Partitioned
#   wrapper for partitioning metadata
logical_state_spec = nn.get_partition_spec(logical_abstract_variables)
print(
    "annotations are logical, not mesh specific",
    logical_state_spec.params['LogicalDotReluDot_0']['Dense_0']['kernel'],
)

# we convert our logical_state_spec to a logical_state_sharding
# with defined rules
logical_state_sharding = nn.logical_to_mesh_sharding(
    logical_state_spec,
    mesh,
    rules,
)
print('sharding annotations are mesh-specific: ',
      logical_state_sharding.params['LogicalDotReluDot_0']['Dense_0']['kernel'].spec)

# %%
# with a working state_sharding object, we can init
logical_jit_init_fn = jax.jit(
    init_fn,
    static_argnums=(2, 3),
    # An empty PartitionSpec names no mesh axis for the PRNG key.
    in_shardings=(mesh_sharding(PartitionSpec()), x_sharding),  # PRNG key and x
    out_shardings=logical_state_sharding,
)

# Trace under the mesh and the logical axis rules so the
# `with_logical_constraint` calls inside the model can resolve
# 'batch'/'hidden' to mesh axes.
with mesh, nn.logical_axis_rules(rules):
    logical_initialized_state = logical_jit_init_fn(k, x, logical_model, optimizer)

# %%
# MARK: saving checkpoint
# https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#multi-host-multi-process-checkpointing
# let us save the model
# since we already have a mesh, we will skip
# the making of the mesh
from typing import Optional, Any
import shutil

import numpy as np
import jax
from jax import random, numpy as jnp

import flax
from flax import linen as nn
from flax.training import checkpoints, train_state
from flax import struct, serialization
import orbax.checkpoint

import optax

ckpt_dir = '/tmp/flax_ckpt'

# Remove any existing checkpoints from the last notebook run.
if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)

# %%
# make up some stuff
# A simple model with one linear layer.
key1, key2 = random.split(random.key(0))
x1 = random.normal(key1, (5,))  # A simple JAX array.
model = nn.Dense(features=3)
variables = model.init(key2, x1)

# Flax's TrainState is a pytree dataclass and is supported in checkpointing.
# Define your class with `@flax.struct.dataclass` decorator to make it compatible.
tx = optax.sgd(learning_rate=0.001)  # An Optax SGD optimizer.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx,
)
# Perform a simple gradient update similar to the one during a normal training workflow.
# The "gradients" here are all-ones arrays shaped like the parameters.
state = state.apply_gradients(
    grads=jax.tree.map(jnp.ones_like, state.params),
)

# Some arbitrary nested pytree with a dictionary and a NumPy array.
config = dict(dimensions=np.array([5, 3]))

# Bundle everything together.
ckpt = dict(model=state, config=config, data=[x1])

# %%
# single host save with orbax
from flax.training import orbax_utils

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save(
    '/tmp/flax_ckpt/orbax/single_save',
    ckpt,
    save_args=save_args,
)

# %%
# multi-process checkpointing
# aka checkpointing for sharding
# The reference doesn't need to be as large as your checkpoint!
# Just make sure it has the `.sharding` you want.
# https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html
# https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html
import orbax.checkpoint as ocp
from etils import epath

path = epath.Path('/tmp/async_checkpoint')
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
# Save the `TrainState` instance `state`. The name `train_state` in this file
# is the imported `flax.training.train_state` module, not a pytree of arrays.
# `force=True` lets a notebook re-run overwrite the checkpoint left by the
# previous run instead of failing on the existing directory.
ckptr.save(path, args=ocp.args.StandardSave(state), force=True)
### Do some other work...
# The save runs in the background; block here before relying on the files.
ckptr.wait_until_finished()
# Release the checkpointer's background resources once the save is done.
ckptr.close()