diff --git a/parallel/dataload.py b/parallel/dataload.py index eea2570..fff0b70 100644 --- a/parallel/dataload.py +++ b/parallel/dataload.py @@ -98,21 +98,21 @@ class DataPrepare(): ) # check that data fits - for name in ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'decoder_attention_mask']: - int_array: np.array = train_dataset[name] - if np.all((int_array >= 0) & (int_array <= 65535)): - continue - else: - raise ValueError("Values are out of range for uint16") + # for name in ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'decoder_attention_mask']: + # int_array: np.array = train_dataset[name] + # if np.all((int_array >= 0) & (int_array <= 65535)): + # continue + # else: + # raise ValueError("Values are out of range for uint16") # change to compact datatypes - features = train_dataset.features.copy() - features['input_ids'] = Sequence(Value('uint16')) - features['attention_mask'] = Sequence(Value('bool')) - features['labels'] = Sequence(Value('uint16')) - features['decoder_input_ids'] = Sequence(Value('uint16')) - features['decoder_attention_mask'] = Sequence(Value('bool')) - train_dataset = train_dataset.cast(features) + # features = train_dataset.features.copy() + # features['input_ids'] = Sequence(Value('uint16')) + # features['attention_mask'] = Sequence(Value('uint16')) + # features['labels'] = Sequence(Value('uint16')) + # features['decoder_input_ids'] = Sequence(Value('uint16')) + # features['decoder_attention_mask'] = Sequence(Value('uint16')) + # train_dataset = train_dataset.cast(features) # assign the dataset to train_dataset self.train_dataset = train_dataset @@ -140,15 +140,15 @@ class DataPrepare(): for idx in batch_idx: batch = dataset[idx] - batch = {k: jnp.array(v) for k, v in batch.items()} + batch = {k: v for k, v in batch.items()} yield batch # testing out the class -# # %% -# # init object -# # e.g. Config +# %% +# init object +# e.g. Config # data_config = ConfigDict( # dict( # max_length=86, @@ -157,7 +157,9 @@ class DataPrepare(): # ) # ) # -# dataprep = DataPrepare(split_datasets, data_config) +# from datasets import load_from_disk +# split_datasets = load_from_disk(file_path) +# dataprep = DataPrepare(split_datasets['train'], data_config) # # # %% # seed = 117 diff --git a/parallel/flax_pjit_tutorial.py b/parallel/flax_pjit_tutorial.py new file mode 100644 index 0000000..1ddf451 --- /dev/null +++ b/parallel/flax_pjit_tutorial.py @@ -0,0 +1,456 @@ +# MARK: START +# %% +# let's make 8-device simulator +import os + +# Set this to True to run the model on CPU only. 
+USE_CPU_ONLY = True + +flags = os.environ.get("XLA_FLAGS", "") +if USE_CPU_ONLY: + flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices + # Enforce CPU-only execution + os.environ["CUDA_VISIBLE_DEVICES"] = "" + os.environ["JAX_PLATFORMS"] = "cpu" +else: + # GPU flags + flags += ( + "--xla_gpu_enable_triton_softmax_fusion=true " + "--xla_gpu_triton_gemm_any=false " + "--xla_gpu_enable_async_collectives=true " + "--xla_gpu_enable_latency_hiding_scheduler=true " + "--xla_gpu_enable_highest_priority_async_stream=true " + ) +os.environ["XLA_FLAGS"] = flags + +import functools +from functools import partial +from pprint import pprint +from typing import Any, Dict, Tuple, Callable, Sequence + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, NamedSharding +# from jax.experimental.pjit import pjit # superseded by jax.jit +from jax.experimental import mesh_utils +from jax.sharding import PartitionSpec +from ml_collections import ConfigDict +import optax +import logging +import time +from datasets import Dataset, load_from_disk + +from flax import jax_utils, traverse_util +from flax.jax_utils import pad_shard_unpad, unreplicate +from flax.training import train_state +from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key +import flax.core + +from tqdm import tqdm + +from dataload import DataPrepare + +PyTree = Any +Metrics = Dict[str, Tuple[jax.Array, ...]] + +if USE_CPU_ONLY: + jax.config.update('jax_platform_name', 'cpu') +else: + jax.config.update("jax_default_matmul_precision", "bfloat16") + + +# MARK: sharding example +# %% +device_mesh = mesh_utils.create_device_mesh((2,2)) +print(device_mesh) + +mesh = Mesh(devices=device_mesh, axis_names=('data', 'model')) +print(mesh) + +def mesh_sharding(pspec: PartitionSpec) -> NamedSharding: + return NamedSharding(mesh, pspec) + +# %% +# define a layer +class DotReluDot(nn.Module): + depth: int + dense_init: Callable = nn.initializers.xavier_normal() + @nn.compact + def __call__(self, x): + + # y has shape (x.shape[-1], self.depth) + # we replicate data across devices + # but we shard the layer across the model axes + y = nn.Dense(self.depth, + kernel_init=nn.with_partitioning(self.dense_init, (None, 'model')), + use_bias=False, # or overwrite with `bias_init` + )(x) + + y = jax.nn.relu(y) + # Force a local sharding annotation. + # annotate intermediate variables to force a particular sharding pattern + # when ideal constraint is known + y = jax.lax.with_sharding_constraint(y, mesh_sharding(PartitionSpec('data', 'model'))) + + W2 = self.param( + 'W2', + nn.with_partitioning(self.dense_init, ('model', None)), + (self.depth, x.shape[-1])) + + z = jnp.dot(y, W2) + # Force a local sharding annotation. + z = jax.lax.with_sharding_constraint(z, mesh_sharding(PartitionSpec('data', None))) + + # Return a tuple to conform with the API `flax.linen.scan` as shown in the cell below. + return z, None + +# %% +class MLP(nn.Module): + num_layers: int + depth: int + use_scan: bool + @nn.compact + def __call__(self, x): + if self.use_scan: + x, _ = nn.scan( + DotReluDot, + length=self.num_layers, + variable_axes={"params": 0}, + split_rngs={"params": True}, + metadata_params={nn.PARTITION_NAME: None} + )(self.depth)(x) + else: + for _ in range(self.num_layers): + x, _ = DotReluDot(self.depth)(x) + return x + +# %% +# MLP hyperparameters. +BATCH, LAYERS, DEPTH, USE_SCAN = 8, 4, 1024, False +# Create fake inputs. 
+x = jnp.ones((BATCH, DEPTH)) +# Initialize a PRNG key. +k = jax.random.key(117) + +# Create an Optax optimizer. +optimizer = optax.adam(learning_rate=0.001) +# Instantiate the model. +model = MLP(LAYERS, DEPTH, USE_SCAN) + +# %% +# specify sharding + +# shard data +x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis +x = jax.device_put(x, x_sharding) +jax.debug.visualize_array_sharding(x) + +# shard output +# we will shard state by tracking its output upon jax.eval_shape after init +# define an init function to return a TrainState +def init_fn(key, x, model, optimizer): + # do be careful with the model init + # imported models might have complicated init methods + variables = model.init(key, x) # Initialize the model. + state = train_state.TrainState.create( # Create a `TrainState`. + apply_fn=model.apply, + params=variables['params'], + tx=optimizer) + return state + +# Create an abstract closure to wrap the function before feeding it in +# because `jax.eval_shape` only takes pytrees as arguments. +# eval_shape(fn, rng_key, x) +# used to perform shape inference +# returns a nested PyTree containing jax.ShapeDtypeStruct objects as leaves +abstract_variables = jax.eval_shape( + functools.partial(init_fn, model=model, optimizer=optimizer), k, x) + + +# This `state_sharding` has the same pytree structure as `state`, the output +# of the `init_fn`. +# flan.linen.get_sharding +# extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh +# jax.sharding: describes how a jax.Array is laid out across devices +state_sharding = nn.get_sharding(abstract_variables, mesh) +print(state_sharding) + +# %% +jit_init_fn = jax.jit( + init_fn, + static_argnames=('model', 'optimizer'), # skip model and optimizer + in_shardings=(mesh_sharding(()), x_sharding), # for PRNG key and data + out_shardings=state_sharding +) + +initialized_state = jit_init_fn(k, x, model, optimizer) +# for weight, partitioned in initialized_state.params['DotReluDot_0'].items(): +# print(f'Sharding of {weight}: {partitioned.names}') +jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value) +jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value) + +# %% +# inspect module output +# the params are actually linen.Partitioned objects +# the Partition objects are actually wrapping jax.Array +print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'])) +print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)) +print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].names) +print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.shape) +# %% +# Say for some unknown reason you want to make the whole param tree all-zero +unboxed_params = nn.meta.unbox(initialized_state.params) +all_zero = jax.tree.map(jnp.zeros_like, unboxed_params) +all_zero_params = nn.meta.replace_boxed(initialized_state.params, all_zero) +assert jnp.sum(nn.meta.unbox(all_zero_params['DotReluDot_0']['Dense_0']['kernel'])) == 0 +# %% +# check the jax.sharding of each parameter +print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.sharding) +print(initialized_state.step) +print(initialized_state.step.sharding) +# %% +# example of computation on pytree +diff = jax.tree.map( + lambda a, b: a - b, + initialized_state.params['DotReluDot_0'], + initialized_state.params['DotReluDot_0'] +) +print(jax.tree.map(jnp.shape, diff)) +diff_array = 
diff['Dense_0']['kernel'].value +print(type(diff_array)) +print(diff_array.shape) +# %% +# compile train step +@functools.partial( + jax.jit, + in_shardings=(state_sharding, x_sharding), + out_shardings=state_sharding +) +def train_step(state, x): + def loss_unrolled(params): + y = model.apply({'params': params}, x) + return y.sum() + grad_fn = jax.grad(loss_unrolled) + grads = grad_fn(state.params) + state = state.apply_gradients(grads=grads) + return state + +# this trains for 1 step +# with mesh: # not strictly necessary in this case +# with mesh block is useful for explicit scope for device sharding +# but mesh management is automatic via jit sharding annotations +new_state = train_step(initialized_state, x) + +print(f'Sharding of Weight 1:') +jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value) +print(f'Sharding of Weight 2:') +jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value) + + +# %% +# compile inference step +@functools.partial( + jax.jit, + in_shardings=(state_sharding, x_sharding), + out_shardings=x_sharding +) +def apply_fn(state, x): + return state.apply_fn({'params': state.params}, x) + +# this infers for 1 step +with mesh: + y = apply_fn(new_state, x) +print(type(y)) +print(y.dtype) +print(y.shape) +jax.debug.visualize_array_sharding(y) +# %% +# profiling +# measure performance +import timeit +def block_all(xs): + jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs) + return xs + +# with mesh: +t = timeit.timeit("block_all(train_step(initialized_state, x))", globals=globals(), number=10) +print(t) + +# MARK: logical sharding +# %% +# logical axis annotation +# why? +# rather than just be fixed with 'data' and 'model', we can annotate with logical names +# then map these logical names back to 'data' and 'model', +# or be more flexible with more axes +# +# we will substitute with the following: +# flax.linen.with_partitioning -> flax.linen.with_logical_partitioning +# flax.lax.with_sharding_constraint -> flax.linen.with_logical_constraint + +class LogicalDotReluDot(nn.Module): + depth: int + dense_init: Callable = nn.initializers.xavier_normal() + @nn.compact + def __call__(self, x): + y = nn.Dense( + self.depth, + # use of logical partitioning here + # kernel_init is the initializer function for the weight matrix + kernel_init=nn.with_logical_partitioning(self.dense_init, ('embed', 'hidden')), + use_bias=False, # or overwrite with `bias_init` + )(x) + + y = jax.nn.relu(y) + # Force a local sharding annotation. + y = jax.lax.with_sharding_constraint(y, mesh_sharding(PartitionSpec('data', 'model'))) + + W2 = self.param( + 'W2', + nn.with_logical_partitioning(self.dense_init, ('hidden', 'embed')), + (self.depth, x.shape[-1])) + + z = jnp.dot(y, W2) + # Force a local sharding annotation. 
+ z = nn.with_logical_constraint(z, ('batch', 'embed')) + return z, None + +class LogicalMLP(nn.Module): + num_layers: int + depth: int + use_scan: bool + @nn.compact + def __call__(self, x): + if self.use_scan: + x, _ = nn.scan(LogicalDotReluDot, length=self.num_layers, + variable_axes={"params": 0}, + split_rngs={"params": True}, + metadata_params={nn.PARTITION_NAME: 'layer'} + )(self.depth)(x) + else: + for _ in range(self.num_layers): + x, _ = LogicalDotReluDot(self.depth)(x) + return x + +# %% +# we initialize the model with the eval_shape method +# but we need to perform a rule to replace logical axes with real axes +rules =( + ('batch', 'data'), + ('hidden', 'model') +) + +logical_model = LogicalMLP(LAYERS, DEPTH, USE_SCAN) + +logical_abstract_variables = jax.eval_shape( + functools.partial(init_fn, model=logical_model, optimizer=optimizer), + k, + x, +) + +# linen.get_partition_spec +# extracts a partitionspec tree from a pytree containing partitioned values +# linen.Partitioned +# wrapper for partitioning metadata +logical_state_spec = nn.get_partition_spec(logical_abstract_variables) +print( + "annotations are logical, not mesh specific", + logical_state_spec.params['LogicalDotReluDot_0']['Dense_0']['kernel'] +) + +# we convert our logical_state_spec to a logical_state_sharding +# with defined rules +logical_state_sharding = nn.logical_to_mesh_sharding( + logical_state_spec, + mesh, + rules +) +print('sharding annotations are mesh-specific: ', + logical_state_sharding.params['LogicalDotReluDot_0']['Dense_0']['kernel'].spec) + +# %% +# with a working state_sharding object, we can init +logical_jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3), + in_shardings=(mesh_sharding(()), x_sharding), # PRNG key and x + out_shardings=logical_state_sharding) + +logical_initialized_state = logical_jit_init_fn(k, x, logical_model, optimizer) + +# %% +# MARK: saving checkpoint +# https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#multi-host-multi-process-checkpointing +# let us save the model +# since we already have a mesh, we will skip the making of the mesh + +from typing import Optional, Any +import shutil + +import numpy as np +import jax +from jax import random, numpy as jnp + +import flax +from flax import linen as nn +from flax.training import checkpoints, train_state +from flax import struct, serialization +import orbax.checkpoint + +import optax + +ckpt_dir = '/tmp/flax_ckpt' + +if os.path.exists(ckpt_dir): + shutil.rmtree(ckpt_dir) # Remove any existing checkpoints from the last notebook run. + +# %% +# make up some stuff +# A simple model with one linear layer. +key1, key2 = random.split(random.key(0)) +x1 = random.normal(key1, (5,)) # A simple JAX array. +model = nn.Dense(features=3) +variables = model.init(key2, x1) + +# Flax's TrainState is a pytree dataclass and is supported in checkpointing. +# Define your class with `@flax.struct.dataclass` decorator to make it compatible. +tx = optax.sgd(learning_rate=0.001) # An Optax SGD optimizer. +state = train_state.TrainState.create( + apply_fn=model.apply, + params=variables['params'], + tx=tx) +# Perform a simple gradient update similar to the one during a normal training workflow. +state = state.apply_gradients(grads=jax.tree_util.tree_map(jnp.ones_like, state.params)) + +# Some arbitrary nested pytree with a dictionary and a NumPy array. +config = {'dimensions': np.array([5, 3])} + +# Bundle everything together. 
+ckpt = {'model': state, 'config': config, 'data': [x1]}
+
+# %%
+# single host save with orbax
+from flax.training import orbax_utils
+
+orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
+save_args = orbax_utils.save_args_from_target(ckpt)
+orbax_checkpointer.save('/tmp/flax_ckpt/orbax/single_save', ckpt, save_args=save_args)
+
+
+# %%
+# multi-process checkpointing
+# aka checkpointing for sharding
+# The reference doesn't need to be as large as your checkpoint!
+# Just make sure it has the `.sharding` you want.
+# https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html
+# https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html
+
+import orbax.checkpoint as ocp
+from etils import epath
+
+path = epath.Path('/tmp/async_checkpoint')
+ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
+# save the TrainState instance created above, not the `train_state` module
+ckptr.save(path, args=ocp.args.StandardSave(state))
+### Do some other work...
+ckptr.wait_until_finished()
diff --git a/parallel/fully_sharded_data_parallelism.py b/parallel/fully_sharded_data_parallelism.py
index 098a389..0df8e58 100644
--- a/parallel/fully_sharded_data_parallelism.py
+++ b/parallel/fully_sharded_data_parallelism.py
@@ -737,7 +737,8 @@ start_time = time.time()
 for _ in range(15):
     state_fsdp, metrics_fsdp = train_step_fsdp_fn(
         state_fsdp,
-        metrics_fsdp, batch)
+        metrics_fsdp,
+        batch)
     duration = time.time() - start_time
     print(duration)
 
@@ -747,7 +748,8 @@ final_metrics_fsdp = jax.tree.map(
     metric_shapes)
 state_fsdp, final_metrics_fsdp = train_step_fsdp_fn(
     state_fsdp,
-    final_metrics_fsdp, batch)
+    final_metrics_fsdp,
+    batch)
 print_metrics(final_metrics_fsdp, "FSDP - Final metrics")
 
diff --git a/parallel/t5_jax_train_pjit.py b/parallel/t5_jax_train_2.py
similarity index 84%
rename from parallel/t5_jax_train_pjit.py
rename to parallel/t5_jax_train_2.py
index c1dc6b3..7b41398 100644
--- a/parallel/t5_jax_train_pjit.py
+++ b/parallel/t5_jax_train_2.py
@@ -8,7 +8,7 @@ import os
 
 # Set this to True to run the model on CPU only. 
-USE_CPU_ONLY = True +USE_CPU_ONLY = False flags = os.environ.get("XLA_FLAGS", "") if USE_CPU_ONLY: @@ -20,10 +20,10 @@ else: # GPU flags flags += ( "--xla_gpu_enable_triton_softmax_fusion=true " - "--xla_gpu_triton_gemm_any=false " - "--xla_gpu_enable_async_collectives=true " - "--xla_gpu_enable_latency_hiding_scheduler=true " - "--xla_gpu_enable_highest_priority_async_stream=true " + # "--xla_gpu_triton_gemm_any=false " + # "--xla_gpu_enable_async_collectives=true " + # "--xla_gpu_enable_latency_hiding_scheduler=true " + # "--xla_gpu_enable_highest_priority_async_stream=true " ) os.environ["XLA_FLAGS"] = flags @@ -37,7 +37,8 @@ import jax import jax.numpy as jnp import numpy as np from jax.experimental.shard_map import shard_map -from jax.sharding import Mesh, NamedSharding +from jax.sharding import Mesh +from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as P from ml_collections import ConfigDict import optax @@ -148,7 +149,7 @@ dataprep = DataPrepare(split_datasets['train'], data_config) # seed = 117 # rng = jax.random.PRNGKey(seed) # train_loader = dataprep.data_loader(rng, batch_size=1) - +# batch = next(iter(train_loader)) # %% # model @@ -161,8 +162,8 @@ config = T5Config() # If you want don't want to cast certain parameters (for example layer norm bias and scale) # then pass the mask as follows from flax import traverse_util - model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base") + # useful for transformer model model.enable_gradient_checkpointing() @@ -174,6 +175,21 @@ mask = { mask = traverse_util.unflatten_dict(mask) model.params = model.to_bf16(model.params, mask) +# %% + + +# %% +from jax.sharding import Mesh, NamedSharding +from jax.experimental import mesh_utils +from jax.sharding import PartitionSpec as P +from pjit_partition import set_partitions + +devices = np.asarray(jax.devices()) +mesh_axis_names = ('data') +mesh = Mesh(devices, 'batch') +sharding = NamedSharding(mesh, P(mesh_axis_names)) +replicated_sharding = NamedSharding(mesh, P()) + # %% [markdown] # # Model @@ -243,22 +259,9 @@ adamw = optax.adamw( # %% -# Training functions -class TrainState(train_state.TrainState): - dropout_rng: jnp.ndarray +# state will serve as our "params" +state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng) - # easy way to achieve data parallelism - # also achieves folding of rng keys - def replicate(self): - return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) - -# set bf16 for model params -# model.params = model.to_bf16(model.params) -# Cast parameters to bfloat16 if desired -# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params) - -# Setup train state -state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng) # label smoothed cross entropy def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): @@ -283,9 +286,10 @@ def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): num_labels = padding_mask.sum() return loss, num_labels +# MARK: train_step # Define gradient update step fn -@jax.jit -def train_step(state, batch, label_smoothing_factor=0.0): +def train_step(state, batch): + label_smoothing_factor=0.0 dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) def compute_loss(params): @@ -311,27 +315,22 @@ def train_step(state, batch, label_smoothing_factor=0.0): metrics = {"loss": loss, "learning_rate": 
linear_decay_lr_schedule_fn(state.step)} return new_state, metrics -# Define generation function -max_length = ( - val_max_target_length if val_max_target_length is not None else model.config.max_length -) -num_beams = num_beams if num_beams is not None else model.config.num_beams -gen_kwargs = {"max_length": max_length, "num_beams": num_beams} - -# def generate_step(params, batch): -# model.params = params -# output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs) -# return output_ids.sequences +# max_length = ( +# val_max_target_length if val_max_target_length is not None else model.config.max_length +# ) +# num_beams = num_beams if num_beams is not None else model.config.num_beams +# gen_kwargs = {"max_length": max_length, "num_beams": num_beams} # Create parallel version of the train and eval step -p_train_step = jax.pmap( - partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,) +# only state and batch +p_train_step = jax.jit( + train_step, + # state for first, batch for second + in_shardings=(P("data"), P("data")), + out_shardings=(P("data"), P("data")), + donate_argnames=("state"), ) -# p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch") -# p_generate_step = jax.pmap(generate_step, "batch") -# Replicate the train state on each device -state = state.replicate() @@ -349,6 +348,21 @@ print(f" Total optimization steps = {total_train_steps}") # %% # jax.profiler.start_trace("./traces") +# Example batch (sharded across devices) +sharded_batch = { + 'input_ids': jax.device_put_sharded(batch['input_ids'], devices), + 'attention_mask': jax.device_put_sharded(batch['attention_mask'], devices), + 'labels': jax.device_put_sharded(batch['labels'], devices), + 'decoder_input_ids': jax.device_put_sharded(batch['decoder_input_ids'], devices), + 'decoder_attention_mask': jax.device_put_sharded(batch['decoder_attention_mask'], devices), +} + +# Initial TrainState (pjit-ted TrainState) +sharded_state = jax.device_put_replicated(train_state, devices) + +# %% + + rng, input_rng = jax.random.split(rng) train_time = 0 epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) @@ -363,7 +377,7 @@ for epoch in epochs: # Generate an epoch by shuffling sampling indices from the train dataset for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): batch = next(train_loader) - batch = shard(batch) + # batch = shard(batch) state, train_metric = p_train_step(state, batch) train_metrics.append(train_metric) diff --git a/parallel/t5_jax_train_fail.py b/parallel/t5_jax_train_fail.py new file mode 100644 index 0000000..4867379 --- /dev/null +++ b/parallel/t5_jax_train_fail.py @@ -0,0 +1,541 @@ +# %% [markdown] +# # T5 implementation using jax with pjit + + +# MARK: START +# %% +# let's make 8-device simulator +import os + +# Set this to True to run the model on CPU only. 
+USE_CPU_ONLY = True + +flags = os.environ.get("XLA_FLAGS", "") +if USE_CPU_ONLY: + flags += " --xla_force_host_platform_device_count=8" # Simulate 8 devices + # Enforce CPU-only execution + os.environ["CUDA_VISIBLE_DEVICES"] = "" + os.environ["JAX_PLATFORMS"] = "cpu" +else: + # GPU flags + flags += ( + "--xla_gpu_enable_triton_softmax_fusion=true " + "--xla_gpu_triton_gemm_any=false " + "--xla_gpu_enable_async_collectives=true " + "--xla_gpu_enable_latency_hiding_scheduler=true " + "--xla_gpu_enable_highest_priority_async_stream=true " + ) +os.environ["XLA_FLAGS"] = flags + +import functools +from functools import partial +from pprint import pprint +from typing import Any, Dict, Tuple, Callable, Sequence + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh +from jax.experimental.pjit import pjit +from jax.sharding import PartitionSpec as P +from ml_collections import ConfigDict +import optax +import logging +import time +from datasets import Dataset, load_from_disk + +from flax import jax_utils, traverse_util +from flax.jax_utils import pad_shard_unpad, unreplicate +from flax.training import train_state +from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key +import flax.core + +from tqdm import tqdm + +from dataload import DataPrepare + +PyTree = Any +Metrics = Dict[str, Tuple[jax.Array, ...]] + +if USE_CPU_ONLY: + jax.config.update('jax_platform_name', 'cpu') +else: + jax.config.update("jax_default_matmul_precision", "bfloat16") + + +# # %% +# import jax +# import jax.numpy as jnp +# import optax +# import numpy as np +# from functools import partial +# from typing import Callable, Optional +# import math +# +# # jax.config.update("jax_default_matmul_precision", "tensorfloat32") +# jax.config.update("jax_default_matmul_precision", "bfloat16") +# # jax.config.update("jax_enable_x64", False) +# # enable cache +# jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") +# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) +# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) +# +# +# # from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig +# +# from flax import jax_utils, traverse_util +# from flax.jax_utils import pad_shard_unpad, unreplicate +# from flax.training import train_state +# from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key +# import flax.core + + +# %% +# get platform type +from jax.lib import xla_bridge +print(xla_bridge.get_backend().platform) + +# %% +# config options +file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval' +save_path = 't5_80_1_bf16' +# file_path = 'combined_data' +split_datasets = load_from_disk(file_path) +training_size = len(split_datasets['train']) +# Store some constant +seed = 117 +num_epochs = 5 +batch_size = 384 # 384 is the best +num_train_epochs = num_epochs +per_device_train_batch_size = batch_size +train_batch_size = per_device_train_batch_size * jax.device_count() +per_device_eval_batch_size = batch_size +eval_batch_size = per_device_eval_batch_size * jax.device_count() +steps_per_epoch = training_size // train_batch_size +total_train_steps = steps_per_epoch * num_epochs + +warmup_steps = 0 +learning_rate = 2e-5 + +weight_decay = 0.01 +adam_beta1 = 0.9 +adam_beta2 = 0.999 +adam_epsilon = 1e-8 +label_smoothing_factor = 0.0 + +num_beams = 1 +val_max_target_length = 128 + +predict_with_generate = True 
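+
+# %%
+# quick sanity check of the batch/step arithmetic above
+# (a minimal sketch: with the 8 simulated CPU devices, train_batch_size is
+# per_device_train_batch_size * jax.device_count() = 384 * 8 = 3072, and
+# steps_per_epoch = training_size // train_batch_size)
+print(f"device count: {jax.device_count()}")
+print(f"train batch size: {train_batch_size}, steps per epoch: {steps_per_epoch}")
+print(f"total optimization steps: {total_train_steps}")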
+ + +# %% +# prepare data +# init object +# e.g. Config +data_config = ConfigDict( + dict( + max_length=86, + pad_token_id=0, + decoder_start_token_id=0 + ) +) + +dataprep = DataPrepare(split_datasets['train'], data_config) +# # example usage +# %% +seed = 117 +rng = jax.random.PRNGKey(seed) +train_loader = dataprep.data_loader(rng, batch_size=1) +batch = next(iter(train_loader)) + +# %% +batch + +# %% +# model + +from transformers import FlaxT5ForConditionalGeneration +from transformers import T5Config + +config = T5Config() + +# If you want don't want to cast certain parameters (for example layer norm bias and scale) +# then pass the mask as follows +from flax import traverse_util +model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=False) + +# useful for transformer model +model.enable_gradient_checkpointing() + +# enable bf16 except for layer_norm +flat_params = traverse_util.flatten_dict(model.params) +mask = { + path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params +} +mask = traverse_util.unflatten_dict(mask) +model.params = model.to_bf16(model.params, mask) + +# %% + +model, params = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=False) +t5_module = model.module + +# %% +jax.tree.map(jnp.shape, model.params) + +# %% +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec +from pjit_partition import set_partitions + +params = model.params +data_partition_specs = PartitionSpec() +extra_param_keys = list(model._missing_keys) +initial_partition_specs = set_partitions(params) +# this is the partition spec we will use +filled_param_partition_specs = set_partitions(params, extra_keys=extra_param_keys) + +# %% +# let us see the param_partition_spec +filled_param_partition_specs + +# %% let us set up the mesh + +from jax.sharding import Mesh +devices = np.asarray(jax.devices()) + +# %% + +# mp: model/tensor parallelism +# dp: data parallelism +# we just use 'data' as a common axis for data and model params +mesh_axis_names = ("data") +print("Logical mesh:", devices) + +mesh = Mesh(devices, mesh_axis_names) + +# it is technically possible to use pjit_partition to set special partition rules +# e.g. by param size +# but for now just move on + +# %% [markdown] +# # Model +# +# +# + +# %% + +# Initialize our training +rng = jax.random.PRNGKey(seed) +rng, dropout_rng = jax.random.split(rng) + + +# %% +# optimization functions + +def create_learning_rate_fn( + train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.ndarray]: + """Returns a linear warmup, linear_decay learning rate function.""" + steps_per_epoch = train_ds_size // train_batch_size + num_train_steps = steps_per_epoch * num_train_epochs + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +# Create learning rate schedule +linear_decay_lr_schedule_fn = create_learning_rate_fn( + training_size, + train_batch_size, + num_train_epochs, + warmup_steps, + learning_rate, +) + +# We use Optax's "masking" functionality to not apply weight decay +# to bias and LayerNorm scale parameters. 
decay_mask_fn returns a +# mask boolean with the same structure as the parameters. +# The mask is True for parameters that should be decayed. +def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + # find out all LayerNorm parameters + layer_norm_candidates = ["layernorm", "layer_norm", "ln"] + layer_norm_named_params = { + layer[-2:] + for layer_norm_name in layer_norm_candidates + for layer in flat_params.keys() + if layer_norm_name in "".join(layer).lower() + } + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) + +# create adam optimizer +adamw = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=adam_beta1, + b2=adam_beta2, + eps=adam_epsilon, + weight_decay=weight_decay, + mask=decay_mask_fn, +) + + +# %% +# Training functions +# class TrainState(train_state.TrainState): +# dropout_rng: jnp.ndarray +# +# # easy way to achieve data parallelism +# # also achieves folding of rng keys +# def replicate(self): +# return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) + +# set bf16 for model params +# model.params = model.to_bf16(model.params) +# Cast parameters to bfloat16 if desired +# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params) + +# state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng) + +# %% +# see if we can init the model + +# %% +from transformers import FlaxT5Model, T5Config +config = T5Config.from_pretrained('t5-base') +model = FlaxT5Model(config, _do_init=True).module + +# %% +# Initialize random key and input for initialization +rng = jax.random.PRNGKey(0) +train_loader = dataprep.data_loader(rng, batch_size=1) +batch = next(iter(train_loader)) + +# %% + +# Initialize model parameters +# init of FlaxT5Module.__call__ +variables = model.init(rng, + input_ids=batch['input_ids'], + attention_mask=batch['attention_mask'], + decoder_input_ids=batch['decoder_attention_mask'], + decoder_attention_mask=batch['decoder_attention_mask'] + ) +params = variables['params'] + + +# %% +# create an init_fn +def init_fn(rng: jax.random.PRNGKey, batch, model) -> train_state.TrainState: + init_rng, rng = jax.random.split(rng) + variables = model.init( + init_rng, + input_ids=batch['input_ids'], + attention_mask=batch['attention_mask'], + decoder_input_ids=batch['decoder_attention_mask'], + decoder_attention_mask=batch['decoder_attention_mask'] + ) + params = variables.pop("params") + state = train_state.TrainState.create( + apply_fn=model.__call__, + params=params, + tx=adamw, + ) + return state + + + + +# %% +# we do not know the output PartitionSpec +# we perform the hack where we just initialize it just to find the outspec +init_fn_try = shard_map( + functools.partial(init_fn, model=model), + mesh, + # 2nd argument is for the model + in_specs=(P(), P("data")), + out_specs=P(), + check_rep=False +) + +# %% +rng, model_init_rng = jax.random.split(rng) +train_loader = dataprep.data_loader(model_init_rng, batch_size=batch_size) +batch = next(iter(train_loader)) + + +state_fsdp_shapes = jax.eval_shape(init_fn_try, model_init_rng, batch) +state_fsdp_specs = nn.get_partition_spec(state_fsdp_shapes) + +# print("RNG", state_fsdp_specs.rng) +print("\nParameters") +pprint(state_fsdp_specs.params) +print("\nOptimizer state") +pprint(state_fsdp_specs.opt_state[0]) + +# note: state_fsdp_specs is now ready to be used as pjit outspec + + + +# %% +# Setup 
train state +# state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng) +state = jax.jit( + init_fn, + in_shardings=(P(), P("data")), + out_shardings=state_fsdp_specs, +) + + + +# %% + +# label smoothed cross entropy +def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): + """ + The label smoothing implementation is adapted from Flax's official example: + https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104 + """ + vocab_size = logits.shape[-1] + confidence = 1.0 - label_smoothing_factor + low_confidence = (1.0 - confidence) / (vocab_size - 1) + normalizing_constant = -( + confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) + ) + soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence) + + loss = optax.softmax_cross_entropy(logits, soft_labels) + loss = loss - normalizing_constant + + # ignore padded tokens from loss + loss = loss * padding_mask + loss = loss.sum() + num_labels = padding_mask.sum() + return loss, num_labels + +# MARK: train_step +# Define gradient update step fn +def train_step(state, batch): + label_smoothing_factor=0.0 + dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) + + def compute_loss(params): + labels = batch.pop("labels") + logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] + loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) + return loss, num_labels + + # compute gradients through computational graph + grad_fn = jax.value_and_grad(compute_loss, has_aux=True) + (loss, num_labels), grad = grad_fn(state.params) + num_labels = jax.lax.psum(num_labels, "batch") + + # true loss = total loss / total samples + # loss = jax.lax.psum(loss, "batch") + # loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) + + # true grad = total grad / total samples + grad = jax.lax.psum(grad, "batch") + grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad) + new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng) + + metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + return new_state, metrics + +# max_length = ( +# val_max_target_length if val_max_target_length is not None else model.config.max_length +# ) +# num_beams = num_beams if num_beams is not None else model.config.num_beams +# gen_kwargs = {"max_length": max_length, "num_beams": num_beams} + +# Create parallel version of the train and eval step +# only state and batch +p_train_step = jax.jit( + train_step, + # state for first, batch for second + in_shardings=(P("data"), P("data")), + out_shardings=(P("data"), P("data")), + donate_argnames=("state"), +) + + + + +# %% + + +print("***** Running training *****") +print(f" Num examples = {training_size}") +print(f" Num Epochs = {num_epochs}") +print(f" Instantaneous batch size per device = {per_device_train_batch_size}") +print(f" Total train batch size (w. 
parallel & distributed) = {train_batch_size}") +print(f" Total optimization steps = {total_train_steps}") + + +# %% +# jax.profiler.start_trace("./traces") + +# Example batch (sharded across devices) +sharded_batch = { + 'input_ids': jax.device_put_sharded(batch['input_ids'], devices), + 'attention_mask': jax.device_put_sharded(batch['attention_mask'], devices), + 'labels': jax.device_put_sharded(batch['labels'], devices), + 'decoder_input_ids': jax.device_put_sharded(batch['decoder_input_ids'], devices), + 'decoder_attention_mask': jax.device_put_sharded(batch['decoder_attention_mask'], devices), +} + +# Initial TrainState (pjit-ted TrainState) +sharded_state = jax.device_put_replicated(train_state, devices) + +# %% + + +rng, input_rng = jax.random.split(rng) +train_time = 0 +epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) +for epoch in epochs: + train_start = time.time() + + # Create sampling rng + train_metrics = [] + rng, data_rng = jax.random.split(rng) + train_loader = dataprep.data_loader(data_rng, batch_size=batch_size) + steps_per_epoch = training_size // train_batch_size + # Generate an epoch by shuffling sampling indices from the train dataset + for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): + batch = next(train_loader) + # batch = shard(batch) + state, train_metric = p_train_step(state, batch) + train_metrics.append(train_metric) + + train_time = time.time() - train_start + + train_metric = unreplicate(train_metric) + train_metric['loss'].block_until_ready() + + + + epochs.write( + # f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, " + f"Epoch... ({epoch + 1}/{num_epochs} | " + # f"Learning Rate:{train_metric['learning_rate']}, " + f"Last train time: {train_time})" + ) +# jax.profiler.stop_trace() +# %% + +# output_dir = save_path +# # save checkpoint after each epoch and push checkpoint to the hub +# if jax.process_index() == 0: +# params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) +# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params) +# model.save_pretrained(output_dir, params=params) +# tokenizer.save_pretrained(output_dir)
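+
+# %%
+# a minimal sketch of checkpointing the (sharded) TrainState with orbax instead of
+# `save_pretrained`, mirroring the async-checkpoint example in flax_pjit_tutorial.py;
+# the checkpoint path and names below are placeholders
+# import orbax.checkpoint as ocp
+# from etils import epath
+#
+# ckpt_path = epath.Path('/tmp/t5_sharded_ckpt')
+# async_ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
+# async_ckptr.save(ckpt_path, args=ocp.args.StandardSave(state))
+# async_ckptr.wait_until_finished()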