# %% [markdown]
# # Distributed Computing in JAX

# %%
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True

flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    flags += " --xla_force_host_platform_device_count=8"  # Simulate 8 devices
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
else:
    # GPU flags
    flags += (
        " --xla_gpu_enable_triton_softmax_fusion=true"
        " --xla_gpu_triton_gemm_any=false"
        " --xla_gpu_enable_async_collectives=true"
        " --xla_gpu_enable_latency_hiding_scheduler=true"
        " --xla_gpu_enable_highest_priority_async_stream=true"
    )
os.environ["XLA_FLAGS"] = flags

# %%
import functools
from typing import Any, Dict, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec

PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

jax.config.update("jax_platform_name", "cpu")

# %%
jax.devices()

# %%
# when we create an array, we can check its location
a = jnp.arange(8)
print("Array", a)
print("Device", a.device)
print("Sharding", a.sharding)

# %% [markdown]
# ## Single-Axis Mesh

# %%
# let's create a Mesh
# a Mesh wraps a multidimensional NumPy array of JAX devices
# jax.sharding.Mesh(devices, axis_names)
mesh = Mesh(devices=np.array(jax.devices()), axis_names=("i",))
print(mesh)

# %%
# jax.sharding.NamedSharding(mesh, spec)
# pairs a Mesh of devices with a PartitionSpec
# the PartitionSpec describes how to shard an array across that mesh:
# to shard an array axis over a certain mesh axis, put the mesh axis name
# ("i" here) at the corresponding position in the spec tuple
sharding = NamedSharding(mesh=mesh, spec=PartitionSpec("i"))

# %%
a_sharded = jax.device_put(a, sharding)
print("Sharded array", a_sharded)
print("Device", a_sharded.devices())
print("Sharding", a_sharded.sharding)

# %%
jax.debug.visualize_array_sharding(a_sharded)

# %%
# let's try some computation on the mesh
out = nn.tanh(a_sharded)
print("Output array", out)
jax.debug.visualize_array_sharding(out)
# note how the output array is sharded across the devices
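# %%
# Quick check (an illustrative sketch, not part of the original walkthrough):
# a sharded jax.Array exposes its per-device blocks via `addressable_shards`,
# so we can print exactly which slice each device holds.
for shard in a_sharded.addressable_shards:
    print(f"Device {shard.device}: block {shard.data}")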
# %% [markdown]
# ## Multi-Axis Mesh
# Why would you shard across multiple dimensions?
#
# Different mesh axes can carry different parallelism strategies, e.g. data
# parallelism along one axis and model/tensor parallelism along the other.

# %%
mesh = Mesh(devices=np.array(jax.devices()).reshape(4, 2), axis_names=("i", "j"))
# axis i/0 refers to the row-wise axis progressing downwards
# axis j/1 refers to the column-wise axis progressing rightwards
mesh  # noqa: B018

# %%
# we now illustrate a sharded MAC operation
# y = x @ w + b
batch_size = 192
input_dim = 64
output_dim = 128

# input: (batch_size, input_dim)
x = jax.random.normal(jax.random.PRNGKey(0), (batch_size, input_dim))
# w: (input_dim, output_dim)
w = jax.random.normal(jax.random.PRNGKey(1), (input_dim, output_dim))
# b: (output_dim,)
b = jax.random.normal(jax.random.PRNGKey(2), (output_dim,))

# %%
# x is sharded along its 0th (batch) axis, w and b along the output axis
x_sharded = jax.device_put(x, NamedSharding(mesh, PartitionSpec("i", None)))
w_sharded = jax.device_put(w, NamedSharding(mesh, PartitionSpec(None, "j")))
b_sharded = jax.device_put(b, NamedSharding(mesh, PartitionSpec("j")))

print("x blocks:")
jax.debug.visualize_array_sharding(x_sharded)
print("w blocks:")
jax.debug.visualize_array_sharding(w_sharded)
print("b blocks:")
jax.debug.visualize_array_sharding(b_sharded)

# %%
out = jnp.dot(x_sharded, w_sharded) + b_sharded
print("Output shape", out.shape)
jax.debug.visualize_array_sharding(out)
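# %%
# Sanity check (an illustrative sketch, not part of the original walkthrough):
# the contraction axis is not sharded, so every output block is computed fully
# locally and should match the unsharded computation up to float tolerance.
np.testing.assert_allclose(out, jnp.dot(x, w) + b, atol=1e-5)
print("Sharded and unsharded results match")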
# %% [markdown]
# # Shard Map (shard_map)
#
# so far, we manually assigned shardings to the inputs so that the block
# matrix computation ran independently and in parallel on each device
#
# shard_map lets us write code with explicit control over parallelization and
# communication
#
# what is a shard_map?
#
# it is a transformation that takes a function, a mesh, and a sharding
# specification for inputs and outputs
#
# in other words, we write a function that operates on a single device's
# shard, and shard_map applies it across all the shards
#
# but wait, doesn't pmap do this? The answer is no. pmap doesn't have enough
# information about the shards to efficiently handle sharding over
# complicated meshes.

# %%
def matmul_fn(x: jax.Array, w: jax.Array, b: jax.Array) -> jax.Array:
    print("Local x shape", x.shape)
    print("Local w shape", w.shape)
    print("Local b shape", b.shape)
    # so simple!
    return jnp.dot(x, w) + b

# %%
matmul_sharded = shard_map(
    matmul_fn,  # the function operating on a single device
    mesh,  # the device topology
    # the mesh partition for each input
    in_specs=(
        PartitionSpec("i", None),  # x
        PartitionSpec(None, "j"),  # w
        PartitionSpec("j"),  # b
    ),
    # how the output is assembled from the mesh
    out_specs=PartitionSpec("i", "j"),
)

# %%
# y = matmul_sharded(x_sharded, w_sharded, b_sharded)
# there is no need to device_put:
# partitioning is done according to your in_specs
y = matmul_sharded(x, w, b)
print("Output shape", y.shape)
jax.debug.visualize_array_sharding(y)

# %% [markdown]
# # Axis Communication

# %%
# example of a mean/sum across devices per shard
# the following computes statistics of x and normalizes each row
# according to its row statistics (mean and std)
@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=PartitionSpec("i", "j"),
    out_specs=PartitionSpec("i", "j"),
)
def parallel_normalize(x: jax.Array) -> jax.Array:
    # jax.lax.pmean: compute an all-reduce mean on x over the mesh axis "axis_name"
    # get the mean across the "j" axis of the mesh - column wise
    mean = jax.lax.pmean(x, axis_name="j")
    # get the std across the "j" axis of the mesh - column wise
    std = jax.lax.pmean((x - mean) ** 2, axis_name="j") ** 0.5
    # communicated along the "j" axis of the mesh for each row's elements
    return (x - mean) / std


out = parallel_normalize(x)
out = jax.device_get(out)
print(out.shape)
print("Mean", out.mean())
print("Std", out.std())

# %%
# scenario: an array is sharded across devices, so values are missing per shard
# all-gather: gather the values of an array from all devices
@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=(
        PartitionSpec("i", None),  # x
        PartitionSpec("i", None),  # artificially shard w across "i"
    ),
    out_specs=PartitionSpec("i", None),
)
def matmul_with_weight_gather(x: jax.Array, w: jax.Array) -> jax.Array:
    print("Original w shape", w.shape)
    # pull the full w matrix from the neighboring devices
    w_gathered = jax.lax.all_gather(w, axis_name="i", axis=0, tiled=True)
    print("Gathered w shape", w_gathered.shape)
    y = jnp.dot(x, w_gathered)
    return y


out = matmul_with_weight_gather(x, w)
out = jax.device_get(out)
np.testing.assert_array_equal(out, jnp.dot(x, w))

# %%
# scenario: arrays are sharded across all devices
# scatter sum: each function instance on each device gets only one shard of the
# summed result
#
# therefore each device ends up with the sum of one slice of the arrays
@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=PartitionSpec("i", None),
    out_specs=PartitionSpec("i", None),
)
def scatter_example(x: jax.Array) -> jax.Array:
    x_scatter = jax.lax.psum_scatter(x, axis_name="i", scatter_dimension=1)
    return x_scatter


x_exmp = np.array(
    [
        [3, 1, 4, 1],
        [5, 9, 2, 6],
        [5, 3, 5, 8],
        [9, 7, 1, 2],
    ]
)
out = scatter_example(x_exmp)
print("Output", out)

# %%
# ppermute: communicates an array in a round-robin fashion
#
# this is used to implement pipeline parallelism, where results are passed to
# the next device, and also in tensor parallelism
#
# notice how the results roll through the devices
#
# ppermute can actually implement all the other lax communication operations
@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=PartitionSpec("i"),
    out_specs=PartitionSpec("i"),
)
def ppermute_example(x: jax.Array) -> jax.Array:
    axis_size = mesh.shape["i"]
    print("BEFORE:\n", x)
    x_perm = jax.lax.ppermute(
        x,
        axis_name="i",
        # (source_index, destination_index) pairs
        perm=[(i, (i + 1) % axis_size) for i in range(axis_size)],
    )
    print("AFTER:\n", x_perm)
    return x_perm


x_exmp = np.arange(4)
out = ppermute_example(x_exmp)
print("Output", out)
# each block now holds the value that started on the previous device along axis "i"
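# %%
# To make the last claim concrete, here is a minimal sketch (an illustration,
# not part of the original walkthrough) that rebuilds a psum over the "i" axis
# using only ppermute: every block travels once around the ring while each
# device adds the block it receives into a running total.
@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=PartitionSpec("i"),
    out_specs=PartitionSpec("i"),
)
def psum_via_ppermute(x: jax.Array) -> jax.Array:
    axis_size = mesh.shape["i"]
    perm = [(i, (i + 1) % axis_size) for i in range(axis_size)]
    block = x
    total = x
    for _ in range(axis_size - 1):
        # pass the block to the next device and accumulate what we receive
        block = jax.lax.ppermute(block, axis_name="i", perm=perm)
        total = total + block
    return total


out = psum_via_ppermute(np.arange(4))
print("Output", out)  # every block should now hold the full sum (6)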
# %%
# axis indexing: get the index of the device along a mesh axis
# sometimes our computations need to be adjusted depending on the device they
# are running on
#
# we use jax.lax.axis_index to return the index of the current device along an axis
#
# this function can be jitted and is almost zero cost
axis_idx_fn = jax.jit(
    shard_map(
        lambda: jnp.stack(
            [
                jax.lax.axis_index("i"),  # device index in the mesh along the "i" axis
                jax.lax.axis_index("j"),  # device index in the mesh along the "j" axis
            ],
            axis=-1,
        )[None],
        mesh,
        in_specs=PartitionSpec(),
        out_specs=PartitionSpec(("i", "j")),
    )
)
out = axis_idx_fn()
out = jax.device_get(out)
for i in range(out.shape[0]):
    print(f"Device {i}: i-axis={out[i, 0]}, j-axis={out[i, 1]}")

# %%
# a second use of axis_index: fold the RNG over a given axis
# jax.random.fold_in: folds data into a PRNG key to form a new PRNG key
# from a source RNG key, we generate new RNG keys
def fold_rng_over_axis(rng: jax.random.PRNGKey, axis_name: str) -> jax.random.PRNGKey:
    """Folds the random number generator over the given axis.

    This is useful for generating a different random number for each device
    across a certain axis (e.g. the model axis).

    Args:
        rng: The random number generator.
        axis_name: The axis name to fold the random number generator over.

    Returns:
        A new random number generator, different for each device index along the axis.
    """
    axis_index = jax.lax.axis_index(axis_name)
    return jax.random.fold_in(rng, axis_index)

# %%
# we fold the RNG over the "i" axis only
# the same RNG is used across the "j" axis
fold_fn = jax.jit(
    shard_map(
        # fold over "i" only
        functools.partial(fold_rng_over_axis, axis_name="i"),
        mesh,
        in_specs=PartitionSpec(),
        out_specs=PartitionSpec(("i", "j")),
    )
)
rng = jax.random.PRNGKey(0)
out = fold_fn(rng)
out = jax.device_get(out)
for i in range(out.shape[0] // 2):
    print(f"Device {i}: RNG={out[2*i:2*i+2]}")

# %%
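# a usage sketch (an illustration, not part of the original walkthrough):
# draw one random sample per device with the folded key. Devices along "i"
# draw different values, while devices that only differ along "j" reuse the
# same folded key and therefore draw identical values.
@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=PartitionSpec(),
    out_specs=PartitionSpec(("i", "j")),
)
def sample_per_device(rng: jax.random.PRNGKey) -> jax.Array:
    rng = fold_rng_over_axis(rng, axis_name="i")
    # one sample per device, stacked along the sharded output axis
    return jax.random.normal(rng, (1,))


out = jax.device_get(sample_per_device(jax.random.PRNGKey(0)))
print("Per-device samples", out)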