# %% [markdown]
# # Distributed Computing in JAX

# %%
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True

flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    flags += " --xla_force_host_platform_device_count=8"  # Simulate 8 devices
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
else:
    # GPU flags
    flags += (
        "--xla_gpu_enable_triton_softmax_fusion=true "
        "--xla_gpu_triton_gemm_any=false "
        "--xla_gpu_enable_async_collectives=true "
        "--xla_gpu_enable_latency_hiding_scheduler=true "
        "--xla_gpu_enable_highest_priority_async_stream=true "
    )
os.environ["XLA_FLAGS"] = flags

# %%
import functools
from typing import Any, Dict, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec

PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
jax.config.update('jax_platform_name', 'cpu')

# %%
jax.devices()

# %%
# when we create an array, we can check its location
a = jnp.arange(8)
print("Array", a)
print("Device", a.device)
print("Sharding", a.sharding)

# %% [markdown]
# ## Single-Axis Mesh

# %%
# let's create a Mesh: a multidimensional NumPy array of JAX devices,
# constructed with jax.sharding.Mesh(devices, axis_names)
mesh = Mesh(devices=np.array(jax.devices()), axis_names=("i",))
print(mesh)

# %%
# jax.sharding.NamedSharding(mesh, spec)
# a NamedSharding pairs a Mesh of devices with a PartitionSpec
# the PartitionSpec describes how to shard an array across that mesh:
# to shard an array axis over a certain mesh axis, put that mesh axis name
# ("i" here) at the corresponding position in the tuple
sharding = NamedSharding(mesh=mesh, spec=PartitionSpec("i",))

# %%
a_sharded = jax.device_put(a, sharding)
print("Sharded array", a_sharded)
print("Device", a_sharded.devices())
print("Sharding", a_sharded.sharding)

# %%
jax.debug.visualize_array_sharding(a_sharded)

# %%
# let's try some computation on the mesh
out = nn.tanh(a_sharded)
print("Output array", out)
jax.debug.visualize_array_sharding(out)
# note how the output array is sharded across the devices

# %% [markdown]
# ## Multi-Axis Mesh
# Why would you shard across multiple dimensions?
#
# Different mesh axes can carry different kinds of parallelism, e.g. data
# parallelism along one axis and model/tensor parallelism along the other.

# %%
mesh = Mesh(devices=np.array(jax.devices()).reshape(4, 2), axis_names=("i", "j"))
# axis i/0 refers to the row-wise axis progressing downwards
# axis j/1 refers to the column-wise axis progressing rightward
mesh  # noqa: B018

# %%
# we now illustrate a sharded MAC (multiply-accumulate) operation
# y = x @ w + b
batch_size = 192
input_dim = 64
output_dim = 128
# x: (batch_size, input_dim)
x = jax.random.normal(jax.random.PRNGKey(0), (batch_size, input_dim))
# w: (input_dim, output_dim)
w = jax.random.normal(jax.random.PRNGKey(1), (input_dim, output_dim))
# b: (output_dim,)
b = jax.random.normal(jax.random.PRNGKey(2), (output_dim,))

# %%
# x is sharded along axis 0 (the batch dimension) over mesh axis "i",
# w along axis 1 over mesh axis "j", and b over "j" as well
x_sharded = jax.device_put(x, NamedSharding(mesh, PartitionSpec("i", None)))
w_sharded = jax.device_put(w, NamedSharding(mesh, PartitionSpec(None, "j")))
b_sharded = jax.device_put(b, NamedSharding(mesh, PartitionSpec("j")))

print('x blocks:')
jax.debug.visualize_array_sharding(x_sharded)
print('w blocks:')
jax.debug.visualize_array_sharding(w_sharded)
print('b blocks:')
jax.debug.visualize_array_sharding(b_sharded)

# %%
out = jnp.dot(x_sharded, w_sharded) + b_sharded
print("Output shape", out.shape)
jax.debug.visualize_array_sharding(out)

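# %%
# sanity check (a quick sketch): the block-sharded computation matches an
# unsharded reference of y = x @ w + b
np.testing.assert_allclose(
    jax.device_get(out), jax.device_get(jnp.dot(x, w) + b), rtol=1e-5, atol=1e-5
)
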
# %% [markdown]
# # Shard Map (shard_map)
#
# So far, we manually assigned the sharding of each array to obtain exactly the
# partitions needed for independent, parallel block matrix computation.
#
# shard_map lets us write such code with explicit control over parallelization
# and communication.
#
# What is a shard_map?
#
# It is a transformation that takes a function, a mesh, and a sharding
# specification for the inputs and outputs.
#
# In other words, we write a function that executes on a single device, and
# shard_map applies it across all the shards.
#
# But wait, doesn't pmap do this? The answer is no: pmap doesn't have enough
# information about the shards to handle complicated meshes efficiently.

# %%
def matmul_fn(x: jax.Array, w: jax.Array, b: jax.Array) -> jax.Array:
    print("Local x shape", x.shape)
    print("Local w shape", w.shape)
    print("Local b shape", b.shape)
    # so simple!
    return jnp.dot(x, w) + b

# %%
matmul_sharded = shard_map(
    matmul_fn,  # the function to run on each device
    mesh,  # the device topology
    # the mesh partitioning of each input argument
    in_specs=(
        PartitionSpec("i", None),  # x
        PartitionSpec(None, "j"),  # w
        PartitionSpec("j"),  # b
    ),
    # how the per-device outputs are assembled back into a global array
    out_specs=PartitionSpec("i", "j"),
)

# %%
# y = matmul_sharded(x_sharded, w_sharded, b_sharded)
# there is no need to device_put,
# partitioning is done according to your in_specs
y = matmul_sharded(x, w, b)
print("Output shape", y.shape)
jax.debug.visualize_array_sharding(y)

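# %%
# sanity check (a quick sketch): shard_map produces the same values as the
# explicit device_put + jnp.dot version above (`out` still holds that result)
np.testing.assert_allclose(jax.device_get(y), jax.device_get(out), rtol=1e-5, atol=1e-5)
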
# %% [markdown]
# # Axis Communication

# %%
# example of a mean/sum across devices per shard

# the following normalizes x with statistics (mean and std) computed across
# the "j" axis of the mesh
@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=PartitionSpec("i", "j"),
    out_specs=PartitionSpec("i", "j"))
def parallel_normalize(x: jax.Array) -> jax.Array:
    # jax.lax.pmean: compute an all-reduce mean on x over the mapped axis
    # "axis_name"
    # get the mean across the "j" axis of the mesh - column-wise
    mean = jax.lax.pmean(x, axis_name="j")
    # get the std across the "j" axis of the mesh - column-wise
    std = jax.lax.pmean((x - mean) ** 2, axis_name="j") ** 0.5
    return (x - mean) / std


# communication happens along the "j" axis of the mesh for each row block
out = parallel_normalize(x)
out = jax.device_get(out)
print(out.shape)
print("Mean", out.mean())
print("Std", out.std())

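# %%
# the same pattern with jax.lax.psum, an all-reduce *sum* over a named mesh
# axis (a small sketch on integers so the result is easy to verify by hand)
@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=PartitionSpec("i", None),
    out_specs=PartitionSpec("i", None))
def sum_over_i(x: jax.Array) -> jax.Array:
    # every device along "i" receives the sum over all "i" shards
    return jax.lax.psum(x, axis_name="i")


vals = np.arange(8).reshape(4, 2)  # "i" shards: [0, 1], [2, 3], [4, 5], [6, 7]
print(sum_over_i(vals))  # every row becomes the column-wise total [12, 16]
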
# %%
# scenario: an array is sharded across devices, so each shard is missing some values
# all-gather: gather the values of an array from all devices
@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=(
        PartitionSpec("i", None),  # x
        PartitionSpec("i", None)   # w, artificially sharded across "i"
    ),
    out_specs=PartitionSpec("i", None))
def matmul_with_weight_gather(x: jax.Array, w: jax.Array) -> jax.Array:
    print("Original w shape", w.shape)
    # pull the full w matrix values from the neighboring devices
    w_gathered = jax.lax.all_gather(w, axis_name="i", axis=0, tiled=True)
    print("Gathered w shape", w_gathered.shape)
    y = jnp.dot(x, w_gathered)
    return y


out = matmul_with_weight_gather(x, w)
out = jax.device_get(out)
np.testing.assert_array_equal(out, jnp.dot(x, w))

# %%
# scenario: arrays are sharded across all devices
# scatter sum: each function instance on each device gets only one shard of the result
#
# therefore each device ends up with the sum for its own shard of the result
@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=PartitionSpec("i", None),
    out_specs=PartitionSpec("i", None))
def scatter_example(x: jax.Array) -> jax.Array:
    x_scatter = jax.lax.psum_scatter(x, axis_name="i", scatter_dimension=1)
    return x_scatter


x_exmp = np.array(
    [
        [3, 1, 4, 1],
        [5, 9, 2, 6],
        [5, 3, 5, 8],
        [9, 7, 1, 2],
    ]
)
out = scatter_example(x_exmp)
print("Output", out)

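# %%
# quick check (a sketch): each device along "i" ends up with one column sum of
# the full array, i.e. [22, 20, 12, 17]
np.testing.assert_array_equal(np.ravel(jax.device_get(out)), x_exmp.sum(axis=0))
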
# %%
# ppermute: communicates an array in a round-robin fashion between devices
#
# this is used to implement pipeline parallelism, where results are passed on to
# the next device, and it also appears in tensor parallelism
#
# notice how the results roll through the devices
#
# ppermute can actually be used to implement all the other lax communication operations
@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=PartitionSpec("i"),
    out_specs=PartitionSpec("i"))
def ppermute_example(x: jax.Array) -> jax.Array:
    axis_size = mesh.shape["i"]
    print('BEFORE:\n', x)
    x_perm = jax.lax.ppermute(
        x,
        axis_name="i",
        perm=[
            # (source_index, destination_index) pairs
            (i, (i + 1) % axis_size) for i in range(axis_size)
        ]
    )
    print('AFTER:\n', x_perm)
    return x_perm


x_exmp = np.arange(4)
out = ppermute_example(x_exmp)
print("Output", out)  # each device now holds the value of its neighbor along axis "i"

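# %%
# quick check (a sketch): with the (i, (i + 1) % axis_size) permutation every
# device receives its left neighbour's value, i.e. the input rolled by one position
np.testing.assert_array_equal(np.ravel(jax.device_get(out)), np.roll(x_exmp, 1))
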
# %%
# axis indexing: get the index of a device along a mesh axis
# sometimes our computation needs to be adjusted depending on the device it is running on
#
# we use jax.lax.axis_index to return the index of the current device along an axis
#
# this function will be jitted and is almost zero cost
axis_idx_fn = jax.jit(
    shard_map(
        lambda: jnp.stack(
            [
                jax.lax.axis_index("i"),  # device index in the mesh along the "i" axis
                jax.lax.axis_index("j"),  # device index in the mesh along the "j" axis
            ],
            axis=-1,
        )[None],
        mesh,
        in_specs=PartitionSpec(),
        out_specs=PartitionSpec(
            ("i", "j"),
        ),
    )
)
out = axis_idx_fn()
out = jax.device_get(out)
for i in range(out.shape[0]):
    print(f"Device {i}: i-axis={out[i, 0]}, j-axis={out[i, 1]}")

# %%
# usage 2: fold the rng over a given axis
# jax.random.fold_in: folds in data to a PRNG key to form a new PRNG key
# from a source RNG key, we generate new RNG keys
def fold_rng_over_axis(rng: jax.random.PRNGKey, axis_name: str) -> jax.random.PRNGKey:
    """Folds the random number generator over the given axis.

    This is useful for generating a different random number for each device
    across a certain axis (e.g. the model axis).

    Args:
        rng: The random number generator.
        axis_name: The axis name to fold the random number generator over.

    Returns:
        A new random number generator, different for each device index along the axis.
    """
    axis_index = jax.lax.axis_index(axis_name)
    return jax.random.fold_in(rng, axis_index)

# %%
# we fold the RNG over the "i" axis only
# the same RNG is used across the "j" axis
fold_fn = jax.jit(
    shard_map(
        # fold over "i" only
        functools.partial(fold_rng_over_axis, axis_name="i"),
        mesh,
        in_specs=PartitionSpec(),
        out_specs=PartitionSpec(
            ("i", "j"),
        ),
    )
)
rng = jax.random.PRNGKey(0)
out = fold_fn(rng)
out = jax.device_get(out)
for i in range(out.shape[0] // 2):
    print(f"Device {i}: RNG={out[2*i:2*i+2]}")

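# %%
# quick check (a sketch, assuming the default uint32 key format of shape (2,)):
# devices that differ only in their "j" coordinate share the same folded key,
# since we folded over "i" only
keys = out.reshape(4, 2, -1)  # (i, j, key components)
np.testing.assert_array_equal(keys[:, 0], keys[:, 1])
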