# learn_jax/parallel/intro_to_distributed.py

# %% [markdown]
# # Distributed Computing in JAX
# %%
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True

flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    flags += " --xla_force_host_platform_device_count=8"  # Simulate 8 devices
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
else:
    # GPU flags
    flags += (
        " --xla_gpu_enable_triton_softmax_fusion=true "
        "--xla_gpu_triton_gemm_any=false "
        "--xla_gpu_enable_async_collectives=true "
        "--xla_gpu_enable_latency_hiding_scheduler=true "
        "--xla_gpu_enable_highest_priority_async_stream=true "
    )
os.environ["XLA_FLAGS"] = flags
# %%
import functools
from typing import Any, Dict, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
if USE_CPU_ONLY:
    jax.config.update('jax_platform_name', 'cpu')
# %%
jax.devices()
# %%
# when we create an array, we can check which device(s) it lives on and how it is sharded
a = jnp.arange(8)
print("Array", a)
print("Device", a.device)
print("Sharding", a.sharding)
# %% [markdown]
# ## Single-Axis Mesh
# %%
# let's create a Mesh: a multidimensional NumPy array of JAX devices with
# named axes
# jax.sharding.Mesh(devices, axis_names)
mesh = Mesh(devices=np.array(jax.devices()), axis_names=("i",))
print(mesh)
# %%
# jax.sharding.NamedSharding(mesh, spec)
# a pair of a Mesh of devices and a PartitionSpec
# PartitionSpec describes how to shard an array across that mesh:
# each entry corresponds to one dimension of the array; to shard an array
# axis over a certain mesh axis, put that mesh axis name ("i" here) at the
# corresponding position in the tuple
sharding = NamedSharding(mesh=mesh, spec=PartitionSpec("i",))
# %%
a_sharded = jax.device_put(a, sharding)
print("Sharded array", a_sharded)
print("Device", a_sharded.devices())
print("Sharding", a_sharded.sharding)
# %%
jax.debug.visualize_array_sharding(a_sharded)
# %%
# let's try some computation on the mesh
out = nn.tanh(a_sharded)
print("Output array", out)
jax.debug.visualize_array_sharding(out)
# note how the output array is sharded across the devices
# %% [markdown]
# ## Multi-Axis Mesh
# Why would you shard across multiple dimensions?
#
# A multi-axis mesh lets us combine parallelism strategies, e.g. data
# parallelism along one mesh axis and tensor/model parallelism along the
# other, and gives explicit control over which array axes are split or
# replicated over which devices. A small naming sketch follows in the next
# cell.
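# %%
# A minimal naming sketch (illustrative only; the axis names "data" and
# "model" are assumptions and are not used in the rest of this notebook):
# mesh axes are often named after the parallelism strategy mapped onto them.
example_mesh = Mesh(
    devices=np.array(jax.devices()).reshape(4, 2),
    axis_names=("data", "model"),
)
print(example_mesh)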
# %%
mesh = Mesh(devices=np.array(jax.devices()).reshape(4,2), axis_names=("i", "j"))
# axis i/0 refers to the row-wise axis progressing downwards
# axis j/1 refers to the column-wise axis progressing rightward
mesh # noqa: B018
# %%
# we now illustrate a sharded multiply-accumulate (MAC) operation
# y = x @ w + b
batch_size = 192
input_dim = 64
output_dim = 128
# input: (batch_size, input_dim)
x = jax.random.normal(jax.random.PRNGKey(0), (batch_size, input_dim))
# w: (input_dim, output_dim)
w = jax.random.normal(jax.random.PRNGKey(1), (input_dim, output_dim))
# b: (output_dim,)
b = jax.random.normal(jax.random.PRNGKey(2), (output_dim,))
# %%
# x is sharded along its batch (0) axis over mesh axis "i",
# w along its output (1) axis over mesh axis "j", and b over "j"
x_sharded = jax.device_put(x, NamedSharding(mesh, PartitionSpec("i", None)))
w_sharded = jax.device_put(w, NamedSharding(mesh, PartitionSpec(None, "j")))
b_sharded = jax.device_put(b, NamedSharding(mesh, PartitionSpec("j")))
print('x blocks:')
jax.debug.visualize_array_sharding(x_sharded)
print('w blocks:')
jax.debug.visualize_array_sharding(w_sharded)
print('b blocks:')
jax.debug.visualize_array_sharding(b_sharded)
# %%
out = jnp.dot(x_sharded, w_sharded) + b_sharded
print("Output shape", out.shape)
jax.debug.visualize_array_sharding(out)
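# %%
# The output sharding is chosen by the partitioner: rows of the result are
# split over "i" and columns over "j" (the block structure of the block
# matmul). We can confirm it from the array itself.
print("Output sharding:", out.sharding)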
# %% [markdown]
# # Shard Map (shard_map)
#
# Above, we manually assigned shardings to the inputs so that the compiler
# could turn the matmul into independent, parallel block computations.
#
# shard_map instead lets us write code with explicit control over
# parallelization and communication.
#
# What is a shard_map?
#
# It is a transformation that takes a function, a mesh, and a sharding
# specification for inputs and outputs.
#
# In other words, we write a function that operates on a single shard (per
# device), and shard_map applies it across all shards of the mesh.
#
# But wait, doesn't pmap do this? The answer is no. pmap doesn't have enough
# information about the shards to efficiently handle complicated,
# multi-dimensional meshes.
# %%
def matmul_fn(x: jax.Array, w: jax.Array, b: jax.Array) -> jax.Array:
    print("Local x shape", x.shape)
    print("Local w shape", w.shape)
    print("Local b shape", b.shape)
    # so simple!
    return jnp.dot(x, w) + b
# %%
matmul_sharded = shard_map(
    matmul_fn,  # the function operating on a single device
    mesh,  # the device topology
    # the mesh partition spec for each input argument
    in_specs=(
        PartitionSpec("i", None),  # x
        PartitionSpec(None, "j"),  # w
        PartitionSpec("j"),        # b
    ),
    # how the per-device outputs are assembled back over the mesh
    out_specs=PartitionSpec("i", "j"),
)
# %%
# y = matmul_sharded(x_sharded, w_sharded, b_sharded)
# there is no need to device_put,
# partitioning is done according to your in_specs
y = matmul_sharded(x, w, b)
print("Output shape", y.shape)
jax.debug.visualize_array_sharding(y)
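# %%
# Sanity check (a quick sketch, not part of the original example): the
# shard_map result should match the plain, unsharded computation.
np.testing.assert_allclose(jax.device_get(y), np.asarray(jnp.dot(x, w) + b), atol=1e-5)
print("shard_map matmul matches the dense computation")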
# %% [markdown]
# # Axis Communication
# %%
# example of computing statistics across devices with jax.lax.pmean
# the following normalizes x using the mean/std taken across the "j" axis of
# the mesh, i.e. across the corresponding column blocks held by different devices
@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=PartitionSpec("i", "j"),
    out_specs=PartitionSpec("i", "j"))
def parallel_normalize(x: jax.Array) -> jax.Array:
    # jax.lax.pmean: compute an all-reduce mean of x over the mesh axis
    # "axis_name"
    # get the mean across the "j" axis of the mesh - column wise
    mean = jax.lax.pmean(x, axis_name="j")
    # get the std across the "j" axis of the mesh - column wise
    std = jax.lax.pmean((x - mean) ** 2, axis_name="j") ** 0.5
    return (x - mean) / std
# communication happens along the "j" axis of the mesh
out = parallel_normalize(x)
out = jax.device_get(out)
print(out.shape)
print("Mean", out.mean())
print("Std", out.std())
# %%
# scenario: each device holds only a shard of an array but needs the full array
# all-gather: collect the shards of an array from all devices along a mesh axis
@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=(
        PartitionSpec("i", None),  # artificially shard across "i"
        PartitionSpec("i", None)
    ),
    out_specs=PartitionSpec("i", None))
def matmul_with_weight_gather(x: jax.Array, w: jax.Array) -> jax.Array:
    print("Original w shape", w.shape)
    # pull the full w matrix values from neighboring devices
    w_gathered = jax.lax.all_gather(w, axis_name="i", axis=0, tiled=True)
    print("Gathered w shape", w_gathered.shape)
    y = jnp.dot(x, w_gathered)
    return y
out = matmul_with_weight_gather(x, w)
out = jax.device_get(out)
np.testing.assert_array_equal(out, jnp.dot(x, w))
# %%
# scenario: arrays are sharded across all devices
# psum_scatter: computes the sum across devices, but each device keeps only
# its own shard of the summed result (instead of the full array as with psum)
@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=PartitionSpec("i", None),
    out_specs=PartitionSpec("i", None))
def scatter_example(x: jax.Array) -> jax.Array:
    x_scatter = jax.lax.psum_scatter(x, axis_name="i", scatter_dimension=1)
    return x_scatter
x_exmp = np.array(
[
[3, 1, 4, 1],
[5, 9, 2, 6],
[5, 3, 5, 8],
[9, 7, 1, 2],
]
)
out = scatter_example(x_exmp)
print("Output", out)
# %%
# ppermute: communicates an array across devices according to a permutation,
# here in a round-robin fashion
#
# this is used to implement pipeline parallelism, where results are passed to
# the next device, and also appears in tensor parallelism
#
# notice how the results roll through the devices
#
# in principle, the other collective operations can be expressed with ppermute
@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=PartitionSpec("i"),
    out_specs=PartitionSpec("i"))
def ppermute_example(x: jax.Array) -> jax.Array:
    axis_size = mesh.shape["i"]
    print('BEFORE:\n', x)
    x_perm = jax.lax.ppermute(
        x,
        axis_name="i",
        perm=[
            # (source_index, destination_index) pairs
            (i, (i + 1) % axis_size) for i in range(axis_size)
        ]
    )
    print('AFTER:\n', x_perm)
    return x_perm
x_exmp = np.arange(4)
out = ppermute_example(x_exmp)
print("Output", out) # the value is that of each axis 0 device
# %%
# axis indexing: get the index of the device along a mesh axis
# sometimes our computations need adjustment depending on the device they run on
#
# we will use jax.lax.axis_index to return the index of the current device
# along an axis
#
# this function will be jitted, so it is almost free after compilation
axis_idx_fn = jax.jit(
    shard_map(
        lambda: jnp.stack(
            [
                jax.lax.axis_index("i"),  # device index in mesh along the "i" axis
                jax.lax.axis_index("j"),  # device index in mesh along the "j" axis
            ],
            axis=-1,
        )[None],
        mesh,
        in_specs=PartitionSpec(),
        out_specs=PartitionSpec(
            ("i", "j"),
        ),
    )
)
out = axis_idx_fn()
out = jax.device_get(out)
for i in range(out.shape[0]):
    print(f"Device {i}: i-axis={out[i, 0]}, j-axis={out[i, 1]}")
# %%
# another use of axis_index: folding an RNG key over a given axis
# jax.random.fold_in: folds data into a PRNG key to form a new PRNG key
# from a source RNG key, we deterministically derive new per-device RNG keys
def fold_rng_over_axis(rng: jax.random.PRNGKey, axis_name: str) -> jax.random.PRNGKey:
    """Folds the random number generator over the given axis.

    This is useful for generating a different random number for each device
    across a certain axis (e.g. the model axis).

    Args:
        rng: The random number generator.
        axis_name: The axis name to fold the random number generator over.

    Returns:
        A new random number generator, different for each device index along the axis.
    """
    axis_index = jax.lax.axis_index(axis_name)
    return jax.random.fold_in(rng, axis_index)
# %%
# we fold RNG over the i axis only
# same RNG used across j axis
fold_fn = jax.jit(
    shard_map(
        # fold over "i" only
        functools.partial(fold_rng_over_axis, axis_name="i"),
        mesh,
        in_specs=PartitionSpec(),
        out_specs=PartitionSpec(
            ("i", "j"),
        ),
    )
)
rng = jax.random.PRNGKey(0)
out = fold_fn(rng)
out = jax.device_get(out)
for i in range(out.shape[0] // 2):
    print(f"Device {i}: RNG={out[2*i:2*i+2]}")
# %%