Feat: flax pjit example
This commit is contained in:
parent
ad5cf7735f
commit
429e1742ab
|
@ -98,21 +98,21 @@ class DataPrepare():
|
||||||
)
|
)
|
||||||
|
|
||||||
# check that data fits
|
# check that data fits
|
||||||
for name in ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'decoder_attention_mask']:
|
# for name in ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'decoder_attention_mask']:
|
||||||
int_array: np.array = train_dataset[name]
|
# int_array: np.array = train_dataset[name]
|
||||||
if np.all((int_array >= 0) & (int_array <= 65535)):
|
# if np.all((int_array >= 0) & (int_array <= 65535)):
|
||||||
continue
|
# continue
|
||||||
else:
|
# else:
|
||||||
raise ValueError("Values are out of range for uint16")
|
# raise ValueError("Values are out of range for uint16")
|
||||||
|
|
||||||
# change to compact datatypes
|
# change to compact datatypes
|
||||||
features = train_dataset.features.copy()
|
# features = train_dataset.features.copy()
|
||||||
features['input_ids'] = Sequence(Value('uint16'))
|
# features['input_ids'] = Sequence(Value('uint16'))
|
||||||
features['attention_mask'] = Sequence(Value('bool'))
|
# features['attention_mask'] = Sequence(Value('uint16'))
|
||||||
features['labels'] = Sequence(Value('uint16'))
|
# features['labels'] = Sequence(Value('uint16'))
|
||||||
features['decoder_input_ids'] = Sequence(Value('uint16'))
|
# features['decoder_input_ids'] = Sequence(Value('uint16'))
|
||||||
features['decoder_attention_mask'] = Sequence(Value('bool'))
|
# features['decoder_attention_mask'] = Sequence(Value('uint16'))
|
||||||
train_dataset = train_dataset.cast(features)
|
# train_dataset = train_dataset.cast(features)
|
||||||
# assign the dataset to train_dataset
|
# assign the dataset to train_dataset
|
||||||
self.train_dataset = train_dataset
|
self.train_dataset = train_dataset
|
||||||
|
|
||||||
|
@ -140,15 +140,15 @@ class DataPrepare():
|
||||||
|
|
||||||
for idx in batch_idx:
|
for idx in batch_idx:
|
||||||
batch = dataset[idx]
|
batch = dataset[idx]
|
||||||
batch = {k: jnp.array(v) for k, v in batch.items()}
|
batch = {k: v for k, v in batch.items()}
|
||||||
|
|
||||||
yield batch
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
# testing out the class
|
# testing out the class
|
||||||
# # %%
|
# %%
|
||||||
# # init object
|
# init object
|
||||||
# # e.g. Config
|
# e.g. Config
|
||||||
# data_config = ConfigDict(
|
# data_config = ConfigDict(
|
||||||
# dict(
|
# dict(
|
||||||
# max_length=86,
|
# max_length=86,
|
||||||
|
@ -157,7 +157,9 @@ class DataPrepare():
|
||||||
# )
|
# )
|
||||||
# )
|
# )
|
||||||
#
|
#
|
||||||
# dataprep = DataPrepare(split_datasets, data_config)
|
# from datasets import load_from_disk
|
||||||
|
# split_datasets = load_from_disk(file_path)
|
||||||
|
# dataprep = DataPrepare(split_datasets['train'], data_config)
|
||||||
#
|
#
|
||||||
# # %%
|
# # %%
|
||||||
# seed = 117
|
# seed = 117
|
||||||
|
|
|
@ -0,0 +1,456 @@
|
||||||
|
# MARK: START
|
||||||
|
# %%
|
||||||
|
# let's make 8-device simulator
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Set this to True to run the model on CPU only.
|
||||||
|
USE_CPU_ONLY = True
|
||||||
|
|
||||||
|
flags = os.environ.get("XLA_FLAGS", "")
|
||||||
|
if USE_CPU_ONLY:
|
||||||
|
flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices
|
||||||
|
# Enforce CPU-only execution
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
||||||
|
os.environ["JAX_PLATFORMS"] = "cpu"
|
||||||
|
else:
|
||||||
|
# GPU flags
|
||||||
|
flags += (
|
||||||
|
"--xla_gpu_enable_triton_softmax_fusion=true "
|
||||||
|
"--xla_gpu_triton_gemm_any=false "
|
||||||
|
"--xla_gpu_enable_async_collectives=true "
|
||||||
|
"--xla_gpu_enable_latency_hiding_scheduler=true "
|
||||||
|
"--xla_gpu_enable_highest_priority_async_stream=true "
|
||||||
|
)
|
||||||
|
os.environ["XLA_FLAGS"] = flags
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from functools import partial
|
||||||
|
from pprint import pprint
|
||||||
|
from typing import Any, Dict, Tuple, Callable, Sequence
|
||||||
|
|
||||||
|
import flax.linen as nn
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
from jax.experimental.shard_map import shard_map
|
||||||
|
from jax.sharding import Mesh, NamedSharding
|
||||||
|
# from jax.experimental.pjit import pjit # superseded by jax.jit
|
||||||
|
from jax.experimental import mesh_utils
|
||||||
|
from jax.sharding import PartitionSpec
|
||||||
|
from ml_collections import ConfigDict
|
||||||
|
import optax
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from datasets import Dataset, load_from_disk
|
||||||
|
|
||||||
|
from flax import jax_utils, traverse_util
|
||||||
|
from flax.jax_utils import pad_shard_unpad, unreplicate
|
||||||
|
from flax.training import train_state
|
||||||
|
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||||
|
import flax.core
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from dataload import DataPrepare
|
||||||
|
|
||||||
|
PyTree = Any
|
||||||
|
Metrics = Dict[str, Tuple[jax.Array, ...]]
|
||||||
|
|
||||||
|
if USE_CPU_ONLY:
|
||||||
|
jax.config.update('jax_platform_name', 'cpu')
|
||||||
|
else:
|
||||||
|
jax.config.update("jax_default_matmul_precision", "bfloat16")
|
||||||
|
|
||||||
|
|
||||||
|
# MARK: sharding example
|
||||||
|
# %%
|
||||||
|
device_mesh = mesh_utils.create_device_mesh((2,2))
|
||||||
|
print(device_mesh)
|
||||||
|
|
||||||
|
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
|
||||||
|
print(mesh)
|
||||||
|
|
||||||
|
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
|
||||||
|
return NamedSharding(mesh, pspec)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# define a layer
|
||||||
|
class DotReluDot(nn.Module):
|
||||||
|
depth: int
|
||||||
|
dense_init: Callable = nn.initializers.xavier_normal()
|
||||||
|
@nn.compact
|
||||||
|
def __call__(self, x):
|
||||||
|
|
||||||
|
# y has shape (x.shape[-1], self.depth)
|
||||||
|
# we replicate data across devices
|
||||||
|
# but we shard the layer across the model axes
|
||||||
|
y = nn.Dense(self.depth,
|
||||||
|
kernel_init=nn.with_partitioning(self.dense_init, (None, 'model')),
|
||||||
|
use_bias=False, # or overwrite with `bias_init`
|
||||||
|
)(x)
|
||||||
|
|
||||||
|
y = jax.nn.relu(y)
|
||||||
|
# Force a local sharding annotation.
|
||||||
|
# annotate intermediate variables to force a particular sharding pattern
|
||||||
|
# when ideal constraint is known
|
||||||
|
y = jax.lax.with_sharding_constraint(y, mesh_sharding(PartitionSpec('data', 'model')))
|
||||||
|
|
||||||
|
W2 = self.param(
|
||||||
|
'W2',
|
||||||
|
nn.with_partitioning(self.dense_init, ('model', None)),
|
||||||
|
(self.depth, x.shape[-1]))
|
||||||
|
|
||||||
|
z = jnp.dot(y, W2)
|
||||||
|
# Force a local sharding annotation.
|
||||||
|
z = jax.lax.with_sharding_constraint(z, mesh_sharding(PartitionSpec('data', None)))
|
||||||
|
|
||||||
|
# Return a tuple to conform with the API `flax.linen.scan` as shown in the cell below.
|
||||||
|
return z, None
|
||||||
|
|
||||||
|
# %%
|
||||||
|
class MLP(nn.Module):
|
||||||
|
num_layers: int
|
||||||
|
depth: int
|
||||||
|
use_scan: bool
|
||||||
|
@nn.compact
|
||||||
|
def __call__(self, x):
|
||||||
|
if self.use_scan:
|
||||||
|
x, _ = nn.scan(
|
||||||
|
DotReluDot,
|
||||||
|
length=self.num_layers,
|
||||||
|
variable_axes={"params": 0},
|
||||||
|
split_rngs={"params": True},
|
||||||
|
metadata_params={nn.PARTITION_NAME: None}
|
||||||
|
)(self.depth)(x)
|
||||||
|
else:
|
||||||
|
for _ in range(self.num_layers):
|
||||||
|
x, _ = DotReluDot(self.depth)(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# MLP hyperparameters.
|
||||||
|
BATCH, LAYERS, DEPTH, USE_SCAN = 8, 4, 1024, False
|
||||||
|
# Create fake inputs.
|
||||||
|
x = jnp.ones((BATCH, DEPTH))
|
||||||
|
# Initialize a PRNG key.
|
||||||
|
k = jax.random.key(117)
|
||||||
|
|
||||||
|
# Create an Optax optimizer.
|
||||||
|
optimizer = optax.adam(learning_rate=0.001)
|
||||||
|
# Instantiate the model.
|
||||||
|
model = MLP(LAYERS, DEPTH, USE_SCAN)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# specify sharding
|
||||||
|
|
||||||
|
# shard data
|
||||||
|
x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis
|
||||||
|
x = jax.device_put(x, x_sharding)
|
||||||
|
jax.debug.visualize_array_sharding(x)
|
||||||
|
|
||||||
|
# shard output
|
||||||
|
# we will shard state by tracking its output upon jax.eval_shape after init
|
||||||
|
# define an init function to return a TrainState
|
||||||
|
def init_fn(key, x, model, optimizer):
|
||||||
|
# do be careful with the model init
|
||||||
|
# imported models might have complicated init methods
|
||||||
|
variables = model.init(key, x) # Initialize the model.
|
||||||
|
state = train_state.TrainState.create( # Create a `TrainState`.
|
||||||
|
apply_fn=model.apply,
|
||||||
|
params=variables['params'],
|
||||||
|
tx=optimizer)
|
||||||
|
return state
|
||||||
|
|
||||||
|
# Create an abstract closure to wrap the function before feeding it in
|
||||||
|
# because `jax.eval_shape` only takes pytrees as arguments.
|
||||||
|
# eval_shape(fn, rng_key, x)
|
||||||
|
# used to perform shape inference
|
||||||
|
# returns a nested PyTree containing jax.ShapeDtypeStruct objects as leaves
|
||||||
|
abstract_variables = jax.eval_shape(
|
||||||
|
functools.partial(init_fn, model=model, optimizer=optimizer), k, x)
|
||||||
|
|
||||||
|
|
||||||
|
# This `state_sharding` has the same pytree structure as `state`, the output
|
||||||
|
# of the `init_fn`.
|
||||||
|
# flan.linen.get_sharding
|
||||||
|
# extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh
|
||||||
|
# jax.sharding: describes how a jax.Array is laid out across devices
|
||||||
|
state_sharding = nn.get_sharding(abstract_variables, mesh)
|
||||||
|
print(state_sharding)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
jit_init_fn = jax.jit(
|
||||||
|
init_fn,
|
||||||
|
static_argnames=('model', 'optimizer'), # skip model and optimizer
|
||||||
|
in_shardings=(mesh_sharding(()), x_sharding), # for PRNG key and data
|
||||||
|
out_shardings=state_sharding
|
||||||
|
)
|
||||||
|
|
||||||
|
initialized_state = jit_init_fn(k, x, model, optimizer)
|
||||||
|
# for weight, partitioned in initialized_state.params['DotReluDot_0'].items():
|
||||||
|
# print(f'Sharding of {weight}: {partitioned.names}')
|
||||||
|
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
|
||||||
|
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# inspect module output
|
||||||
|
# the params are actually linen.Partitioned objects
|
||||||
|
# the Partition objects are actually wrapping jax.Array
|
||||||
|
print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel']))
|
||||||
|
print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value))
|
||||||
|
print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].names)
|
||||||
|
print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.shape)
|
||||||
|
# %%
|
||||||
|
# Say for some unknown reason you want to make the whole param tree all-zero
|
||||||
|
unboxed_params = nn.meta.unbox(initialized_state.params)
|
||||||
|
all_zero = jax.tree.map(jnp.zeros_like, unboxed_params)
|
||||||
|
all_zero_params = nn.meta.replace_boxed(initialized_state.params, all_zero)
|
||||||
|
assert jnp.sum(nn.meta.unbox(all_zero_params['DotReluDot_0']['Dense_0']['kernel'])) == 0
|
||||||
|
# %%
|
||||||
|
# check the jax.sharding of each parameter
|
||||||
|
print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.sharding)
|
||||||
|
print(initialized_state.step)
|
||||||
|
print(initialized_state.step.sharding)
|
||||||
|
# %%
|
||||||
|
# example of computation on pytree
|
||||||
|
diff = jax.tree.map(
|
||||||
|
lambda a, b: a - b,
|
||||||
|
initialized_state.params['DotReluDot_0'],
|
||||||
|
initialized_state.params['DotReluDot_0']
|
||||||
|
)
|
||||||
|
print(jax.tree.map(jnp.shape, diff))
|
||||||
|
diff_array = diff['Dense_0']['kernel'].value
|
||||||
|
print(type(diff_array))
|
||||||
|
print(diff_array.shape)
|
||||||
|
# %%
|
||||||
|
# compile train step
|
||||||
|
@functools.partial(
|
||||||
|
jax.jit,
|
||||||
|
in_shardings=(state_sharding, x_sharding),
|
||||||
|
out_shardings=state_sharding
|
||||||
|
)
|
||||||
|
def train_step(state, x):
|
||||||
|
def loss_unrolled(params):
|
||||||
|
y = model.apply({'params': params}, x)
|
||||||
|
return y.sum()
|
||||||
|
grad_fn = jax.grad(loss_unrolled)
|
||||||
|
grads = grad_fn(state.params)
|
||||||
|
state = state.apply_gradients(grads=grads)
|
||||||
|
return state
|
||||||
|
|
||||||
|
# this trains for 1 step
|
||||||
|
# with mesh: # not strictly necessary in this case
|
||||||
|
# with mesh block is useful for explicit scope for device sharding
|
||||||
|
# but mesh management is automatic via jit sharding annotations
|
||||||
|
new_state = train_step(initialized_state, x)
|
||||||
|
|
||||||
|
print(f'Sharding of Weight 1:')
|
||||||
|
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
|
||||||
|
print(f'Sharding of Weight 2:')
|
||||||
|
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# compile inference step
|
||||||
|
@functools.partial(
|
||||||
|
jax.jit,
|
||||||
|
in_shardings=(state_sharding, x_sharding),
|
||||||
|
out_shardings=x_sharding
|
||||||
|
)
|
||||||
|
def apply_fn(state, x):
|
||||||
|
return state.apply_fn({'params': state.params}, x)
|
||||||
|
|
||||||
|
# this infers for 1 step
|
||||||
|
with mesh:
|
||||||
|
y = apply_fn(new_state, x)
|
||||||
|
print(type(y))
|
||||||
|
print(y.dtype)
|
||||||
|
print(y.shape)
|
||||||
|
jax.debug.visualize_array_sharding(y)
|
||||||
|
# %%
|
||||||
|
# profiling
|
||||||
|
# measure performance
|
||||||
|
import timeit
|
||||||
|
def block_all(xs):
|
||||||
|
jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)
|
||||||
|
return xs
|
||||||
|
|
||||||
|
# with mesh:
|
||||||
|
t = timeit.timeit("block_all(train_step(initialized_state, x))", globals=globals(), number=10)
|
||||||
|
print(t)
|
||||||
|
|
||||||
|
# MARK: logical sharding
|
||||||
|
# %%
|
||||||
|
# logical axis annotation
|
||||||
|
# why?
|
||||||
|
# rather than just be fixed with 'data' and 'model', we can annotate with logical names
|
||||||
|
# then map these logical names back to 'data' and 'model',
|
||||||
|
# or be more flexible with more axes
|
||||||
|
#
|
||||||
|
# we will substitute with the following:
|
||||||
|
# flax.linen.with_partitioning -> flax.linen.with_logical_partitioning
|
||||||
|
# flax.lax.with_sharding_constraint -> flax.linen.with_logical_constraint
|
||||||
|
|
||||||
|
class LogicalDotReluDot(nn.Module):
|
||||||
|
depth: int
|
||||||
|
dense_init: Callable = nn.initializers.xavier_normal()
|
||||||
|
@nn.compact
|
||||||
|
def __call__(self, x):
|
||||||
|
y = nn.Dense(
|
||||||
|
self.depth,
|
||||||
|
# use of logical partitioning here
|
||||||
|
# kernel_init is the initializer function for the weight matrix
|
||||||
|
kernel_init=nn.with_logical_partitioning(self.dense_init, ('embed', 'hidden')),
|
||||||
|
use_bias=False, # or overwrite with `bias_init`
|
||||||
|
)(x)
|
||||||
|
|
||||||
|
y = jax.nn.relu(y)
|
||||||
|
# Force a local sharding annotation.
|
||||||
|
y = jax.lax.with_sharding_constraint(y, mesh_sharding(PartitionSpec('data', 'model')))
|
||||||
|
|
||||||
|
W2 = self.param(
|
||||||
|
'W2',
|
||||||
|
nn.with_logical_partitioning(self.dense_init, ('hidden', 'embed')),
|
||||||
|
(self.depth, x.shape[-1]))
|
||||||
|
|
||||||
|
z = jnp.dot(y, W2)
|
||||||
|
# Force a local sharding annotation.
|
||||||
|
z = nn.with_logical_constraint(z, ('batch', 'embed'))
|
||||||
|
return z, None
|
||||||
|
|
||||||
|
class LogicalMLP(nn.Module):
|
||||||
|
num_layers: int
|
||||||
|
depth: int
|
||||||
|
use_scan: bool
|
||||||
|
@nn.compact
|
||||||
|
def __call__(self, x):
|
||||||
|
if self.use_scan:
|
||||||
|
x, _ = nn.scan(LogicalDotReluDot, length=self.num_layers,
|
||||||
|
variable_axes={"params": 0},
|
||||||
|
split_rngs={"params": True},
|
||||||
|
metadata_params={nn.PARTITION_NAME: 'layer'}
|
||||||
|
)(self.depth)(x)
|
||||||
|
else:
|
||||||
|
for _ in range(self.num_layers):
|
||||||
|
x, _ = LogicalDotReluDot(self.depth)(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# we initialize the model with the eval_shape method
|
||||||
|
# but we need to perform a rule to replace logical axes with real axes
|
||||||
|
rules =(
|
||||||
|
('batch', 'data'),
|
||||||
|
('hidden', 'model')
|
||||||
|
)
|
||||||
|
|
||||||
|
logical_model = LogicalMLP(LAYERS, DEPTH, USE_SCAN)
|
||||||
|
|
||||||
|
logical_abstract_variables = jax.eval_shape(
|
||||||
|
functools.partial(init_fn, model=logical_model, optimizer=optimizer),
|
||||||
|
k,
|
||||||
|
x,
|
||||||
|
)
|
||||||
|
|
||||||
|
# linen.get_partition_spec
|
||||||
|
# extracts a partitionspec tree from a pytree containing partitioned values
|
||||||
|
# linen.Partitioned
|
||||||
|
# wrapper for partitioning metadata
|
||||||
|
logical_state_spec = nn.get_partition_spec(logical_abstract_variables)
|
||||||
|
print(
|
||||||
|
"annotations are logical, not mesh specific",
|
||||||
|
logical_state_spec.params['LogicalDotReluDot_0']['Dense_0']['kernel']
|
||||||
|
)
|
||||||
|
|
||||||
|
# we convert our logical_state_spec to a logical_state_sharding
|
||||||
|
# with defined rules
|
||||||
|
logical_state_sharding = nn.logical_to_mesh_sharding(
|
||||||
|
logical_state_spec,
|
||||||
|
mesh,
|
||||||
|
rules
|
||||||
|
)
|
||||||
|
print('sharding annotations are mesh-specific: ',
|
||||||
|
logical_state_sharding.params['LogicalDotReluDot_0']['Dense_0']['kernel'].spec)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# with a working state_sharding object, we can init
|
||||||
|
logical_jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3),
|
||||||
|
in_shardings=(mesh_sharding(()), x_sharding), # PRNG key and x
|
||||||
|
out_shardings=logical_state_sharding)
|
||||||
|
|
||||||
|
logical_initialized_state = logical_jit_init_fn(k, x, logical_model, optimizer)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# MARK: saving checkpoint
|
||||||
|
# https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#multi-host-multi-process-checkpointing
|
||||||
|
# let us save the model
|
||||||
|
# since we already have a mesh, we will skip the making of the mesh
|
||||||
|
|
||||||
|
from typing import Optional, Any
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import jax
|
||||||
|
from jax import random, numpy as jnp
|
||||||
|
|
||||||
|
import flax
|
||||||
|
from flax import linen as nn
|
||||||
|
from flax.training import checkpoints, train_state
|
||||||
|
from flax import struct, serialization
|
||||||
|
import orbax.checkpoint
|
||||||
|
|
||||||
|
import optax
|
||||||
|
|
||||||
|
ckpt_dir = '/tmp/flax_ckpt'
|
||||||
|
|
||||||
|
if os.path.exists(ckpt_dir):
|
||||||
|
shutil.rmtree(ckpt_dir) # Remove any existing checkpoints from the last notebook run.
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# make up some stuff
|
||||||
|
# A simple model with one linear layer.
|
||||||
|
key1, key2 = random.split(random.key(0))
|
||||||
|
x1 = random.normal(key1, (5,)) # A simple JAX array.
|
||||||
|
model = nn.Dense(features=3)
|
||||||
|
variables = model.init(key2, x1)
|
||||||
|
|
||||||
|
# Flax's TrainState is a pytree dataclass and is supported in checkpointing.
|
||||||
|
# Define your class with `@flax.struct.dataclass` decorator to make it compatible.
|
||||||
|
tx = optax.sgd(learning_rate=0.001) # An Optax SGD optimizer.
|
||||||
|
state = train_state.TrainState.create(
|
||||||
|
apply_fn=model.apply,
|
||||||
|
params=variables['params'],
|
||||||
|
tx=tx)
|
||||||
|
# Perform a simple gradient update similar to the one during a normal training workflow.
|
||||||
|
state = state.apply_gradients(grads=jax.tree_util.tree_map(jnp.ones_like, state.params))
|
||||||
|
|
||||||
|
# Some arbitrary nested pytree with a dictionary and a NumPy array.
|
||||||
|
config = {'dimensions': np.array([5, 3])}
|
||||||
|
|
||||||
|
# Bundle everything together.
|
||||||
|
ckpt = {'model': state, 'config': config, 'data': [x1]}
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# single host save with orbax
|
||||||
|
from flax.training import orbax_utils
|
||||||
|
|
||||||
|
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
|
||||||
|
save_args = orbax_utils.save_args_from_target(ckpt)
|
||||||
|
orbax_checkpointer.save('/tmp/flax_ckpt/orbax/single_save', ckpt, save_args=save_args)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# multi-process checkpointing
|
||||||
|
# aka checkpointing for sharding
|
||||||
|
# The reference doesn't need to be as large as your checkpoint!
|
||||||
|
# Just make sure it has the `.sharding` you want.
|
||||||
|
# https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html
|
||||||
|
# https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html
|
||||||
|
|
||||||
|
import orbax.checkpoint as ocp
|
||||||
|
from etils import epath
|
||||||
|
|
||||||
|
path = epath.Path('/tmp/async_checkpoint')
|
||||||
|
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
|
||||||
|
ckptr.save(path, args=ocp.args.StandardSave(train_state))
|
||||||
|
### Do some other work...
|
||||||
|
ckptr.wait_until_finished()
|
|
@ -737,7 +737,8 @@ start_time = time.time()
|
||||||
for _ in range(15):
|
for _ in range(15):
|
||||||
state_fsdp, metrics_fsdp = train_step_fsdp_fn(
|
state_fsdp, metrics_fsdp = train_step_fsdp_fn(
|
||||||
state_fsdp,
|
state_fsdp,
|
||||||
metrics_fsdp, batch)
|
metrics_fsdp,
|
||||||
|
batch)
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
print(duration)
|
print(duration)
|
||||||
|
|
||||||
|
@ -747,7 +748,8 @@ final_metrics_fsdp = jax.tree.map(
|
||||||
metric_shapes)
|
metric_shapes)
|
||||||
state_fsdp, final_metrics_fsdp = train_step_fsdp_fn(
|
state_fsdp, final_metrics_fsdp = train_step_fsdp_fn(
|
||||||
state_fsdp,
|
state_fsdp,
|
||||||
final_metrics_fsdp, batch)
|
final_metrics_fsdp,
|
||||||
|
batch)
|
||||||
print_metrics(final_metrics_fsdp, "FSDP - Final metrics")
|
print_metrics(final_metrics_fsdp, "FSDP - Final metrics")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# Set this to True to run the model on CPU only.
|
# Set this to True to run the model on CPU only.
|
||||||
USE_CPU_ONLY = True
|
USE_CPU_ONLY = False
|
||||||
|
|
||||||
flags = os.environ.get("XLA_FLAGS", "")
|
flags = os.environ.get("XLA_FLAGS", "")
|
||||||
if USE_CPU_ONLY:
|
if USE_CPU_ONLY:
|
||||||
|
@ -20,10 +20,10 @@ else:
|
||||||
# GPU flags
|
# GPU flags
|
||||||
flags += (
|
flags += (
|
||||||
"--xla_gpu_enable_triton_softmax_fusion=true "
|
"--xla_gpu_enable_triton_softmax_fusion=true "
|
||||||
"--xla_gpu_triton_gemm_any=false "
|
# "--xla_gpu_triton_gemm_any=false "
|
||||||
"--xla_gpu_enable_async_collectives=true "
|
# "--xla_gpu_enable_async_collectives=true "
|
||||||
"--xla_gpu_enable_latency_hiding_scheduler=true "
|
# "--xla_gpu_enable_latency_hiding_scheduler=true "
|
||||||
"--xla_gpu_enable_highest_priority_async_stream=true "
|
# "--xla_gpu_enable_highest_priority_async_stream=true "
|
||||||
)
|
)
|
||||||
os.environ["XLA_FLAGS"] = flags
|
os.environ["XLA_FLAGS"] = flags
|
||||||
|
|
||||||
|
@ -37,7 +37,8 @@ import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from jax.experimental.shard_map import shard_map
|
from jax.experimental.shard_map import shard_map
|
||||||
from jax.sharding import Mesh, NamedSharding
|
from jax.sharding import Mesh
|
||||||
|
from jax.experimental.pjit import pjit
|
||||||
from jax.sharding import PartitionSpec as P
|
from jax.sharding import PartitionSpec as P
|
||||||
from ml_collections import ConfigDict
|
from ml_collections import ConfigDict
|
||||||
import optax
|
import optax
|
||||||
|
@ -148,7 +149,7 @@ dataprep = DataPrepare(split_datasets['train'], data_config)
|
||||||
# seed = 117
|
# seed = 117
|
||||||
# rng = jax.random.PRNGKey(seed)
|
# rng = jax.random.PRNGKey(seed)
|
||||||
# train_loader = dataprep.data_loader(rng, batch_size=1)
|
# train_loader = dataprep.data_loader(rng, batch_size=1)
|
||||||
|
# batch = next(iter(train_loader))
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# model
|
# model
|
||||||
|
@ -161,8 +162,8 @@ config = T5Config()
|
||||||
# If you want don't want to cast certain parameters (for example layer norm bias and scale)
|
# If you want don't want to cast certain parameters (for example layer norm bias and scale)
|
||||||
# then pass the mask as follows
|
# then pass the mask as follows
|
||||||
from flax import traverse_util
|
from flax import traverse_util
|
||||||
|
|
||||||
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
|
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
|
||||||
|
|
||||||
# useful for transformer model
|
# useful for transformer model
|
||||||
model.enable_gradient_checkpointing()
|
model.enable_gradient_checkpointing()
|
||||||
|
|
||||||
|
@ -174,6 +175,21 @@ mask = {
|
||||||
mask = traverse_util.unflatten_dict(mask)
|
mask = traverse_util.unflatten_dict(mask)
|
||||||
model.params = model.to_bf16(model.params, mask)
|
model.params = model.to_bf16(model.params, mask)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
from jax.sharding import Mesh, NamedSharding
|
||||||
|
from jax.experimental import mesh_utils
|
||||||
|
from jax.sharding import PartitionSpec as P
|
||||||
|
from pjit_partition import set_partitions
|
||||||
|
|
||||||
|
devices = np.asarray(jax.devices())
|
||||||
|
mesh_axis_names = ('data')
|
||||||
|
mesh = Mesh(devices, 'batch')
|
||||||
|
sharding = NamedSharding(mesh, P(mesh_axis_names))
|
||||||
|
replicated_sharding = NamedSharding(mesh, P())
|
||||||
|
|
||||||
|
|
||||||
# %% [markdown]
|
# %% [markdown]
|
||||||
# # Model
|
# # Model
|
||||||
|
@ -243,22 +259,9 @@ adamw = optax.adamw(
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# Training functions
|
# state will serve as our "params"
|
||||||
class TrainState(train_state.TrainState):
|
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
|
||||||
dropout_rng: jnp.ndarray
|
|
||||||
|
|
||||||
# easy way to achieve data parallelism
|
|
||||||
# also achieves folding of rng keys
|
|
||||||
def replicate(self):
|
|
||||||
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
|
|
||||||
|
|
||||||
# set bf16 for model params
|
|
||||||
# model.params = model.to_bf16(model.params)
|
|
||||||
# Cast parameters to bfloat16 if desired
|
|
||||||
# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
|
|
||||||
|
|
||||||
# Setup train state
|
|
||||||
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
|
|
||||||
|
|
||||||
# label smoothed cross entropy
|
# label smoothed cross entropy
|
||||||
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
|
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
|
||||||
|
@ -283,9 +286,10 @@ def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
|
||||||
num_labels = padding_mask.sum()
|
num_labels = padding_mask.sum()
|
||||||
return loss, num_labels
|
return loss, num_labels
|
||||||
|
|
||||||
|
# MARK: train_step
|
||||||
# Define gradient update step fn
|
# Define gradient update step fn
|
||||||
@jax.jit
|
def train_step(state, batch):
|
||||||
def train_step(state, batch, label_smoothing_factor=0.0):
|
label_smoothing_factor=0.0
|
||||||
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
||||||
|
|
||||||
def compute_loss(params):
|
def compute_loss(params):
|
||||||
|
@ -311,27 +315,22 @@ def train_step(state, batch, label_smoothing_factor=0.0):
|
||||||
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
||||||
return new_state, metrics
|
return new_state, metrics
|
||||||
|
|
||||||
# Define generation function
|
# max_length = (
|
||||||
max_length = (
|
# val_max_target_length if val_max_target_length is not None else model.config.max_length
|
||||||
val_max_target_length if val_max_target_length is not None else model.config.max_length
|
# )
|
||||||
)
|
# num_beams = num_beams if num_beams is not None else model.config.num_beams
|
||||||
num_beams = num_beams if num_beams is not None else model.config.num_beams
|
# gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
|
||||||
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
|
|
||||||
|
|
||||||
# def generate_step(params, batch):
|
|
||||||
# model.params = params
|
|
||||||
# output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
|
|
||||||
# return output_ids.sequences
|
|
||||||
|
|
||||||
# Create parallel version of the train and eval step
|
# Create parallel version of the train and eval step
|
||||||
p_train_step = jax.pmap(
|
# only state and batch
|
||||||
partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,)
|
p_train_step = jax.jit(
|
||||||
|
train_step,
|
||||||
|
# state for first, batch for second
|
||||||
|
in_shardings=(P("data"), P("data")),
|
||||||
|
out_shardings=(P("data"), P("data")),
|
||||||
|
donate_argnames=("state"),
|
||||||
)
|
)
|
||||||
# p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch")
|
|
||||||
# p_generate_step = jax.pmap(generate_step, "batch")
|
|
||||||
|
|
||||||
# Replicate the train state on each device
|
|
||||||
state = state.replicate()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -349,6 +348,21 @@ print(f" Total optimization steps = {total_train_steps}")
|
||||||
# %%
|
# %%
|
||||||
# jax.profiler.start_trace("./traces")
|
# jax.profiler.start_trace("./traces")
|
||||||
|
|
||||||
|
# Example batch (sharded across devices)
|
||||||
|
sharded_batch = {
|
||||||
|
'input_ids': jax.device_put_sharded(batch['input_ids'], devices),
|
||||||
|
'attention_mask': jax.device_put_sharded(batch['attention_mask'], devices),
|
||||||
|
'labels': jax.device_put_sharded(batch['labels'], devices),
|
||||||
|
'decoder_input_ids': jax.device_put_sharded(batch['decoder_input_ids'], devices),
|
||||||
|
'decoder_attention_mask': jax.device_put_sharded(batch['decoder_attention_mask'], devices),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Initial TrainState (pjit-ted TrainState)
|
||||||
|
sharded_state = jax.device_put_replicated(train_state, devices)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
|
||||||
rng, input_rng = jax.random.split(rng)
|
rng, input_rng = jax.random.split(rng)
|
||||||
train_time = 0
|
train_time = 0
|
||||||
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
||||||
|
@ -363,7 +377,7 @@ for epoch in epochs:
|
||||||
# Generate an epoch by shuffling sampling indices from the train dataset
|
# Generate an epoch by shuffling sampling indices from the train dataset
|
||||||
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
|
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
|
||||||
batch = next(train_loader)
|
batch = next(train_loader)
|
||||||
batch = shard(batch)
|
# batch = shard(batch)
|
||||||
state, train_metric = p_train_step(state, batch)
|
state, train_metric = p_train_step(state, batch)
|
||||||
train_metrics.append(train_metric)
|
train_metrics.append(train_metric)
|
||||||
|
|
|
@ -0,0 +1,541 @@
|
||||||
|
# %% [markdown]
|
||||||
|
# # T5 implementation using jax with pjit
|
||||||
|
|
||||||
|
|
||||||
|
# MARK: START
|
||||||
|
# %%
|
||||||
|
# let's make 8-device simulator
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Set this to True to run the model on CPU only.
|
||||||
|
USE_CPU_ONLY = True
|
||||||
|
|
||||||
|
flags = os.environ.get("XLA_FLAGS", "")
|
||||||
|
if USE_CPU_ONLY:
|
||||||
|
flags += " --xla_force_host_platform_device_count=8" # Simulate 8 devices
|
||||||
|
# Enforce CPU-only execution
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
||||||
|
os.environ["JAX_PLATFORMS"] = "cpu"
|
||||||
|
else:
|
||||||
|
# GPU flags
|
||||||
|
flags += (
|
||||||
|
"--xla_gpu_enable_triton_softmax_fusion=true "
|
||||||
|
"--xla_gpu_triton_gemm_any=false "
|
||||||
|
"--xla_gpu_enable_async_collectives=true "
|
||||||
|
"--xla_gpu_enable_latency_hiding_scheduler=true "
|
||||||
|
"--xla_gpu_enable_highest_priority_async_stream=true "
|
||||||
|
)
|
||||||
|
os.environ["XLA_FLAGS"] = flags
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from functools import partial
|
||||||
|
from pprint import pprint
|
||||||
|
from typing import Any, Dict, Tuple, Callable, Sequence
|
||||||
|
|
||||||
|
import flax.linen as nn
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
from jax.experimental.shard_map import shard_map
|
||||||
|
from jax.sharding import Mesh
|
||||||
|
from jax.experimental.pjit import pjit
|
||||||
|
from jax.sharding import PartitionSpec as P
|
||||||
|
from ml_collections import ConfigDict
|
||||||
|
import optax
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from datasets import Dataset, load_from_disk
|
||||||
|
|
||||||
|
from flax import jax_utils, traverse_util
|
||||||
|
from flax.jax_utils import pad_shard_unpad, unreplicate
|
||||||
|
from flax.training import train_state
|
||||||
|
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||||
|
import flax.core
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from dataload import DataPrepare
|
||||||
|
|
||||||
|
PyTree = Any
|
||||||
|
Metrics = Dict[str, Tuple[jax.Array, ...]]
|
||||||
|
|
||||||
|
if USE_CPU_ONLY:
|
||||||
|
jax.config.update('jax_platform_name', 'cpu')
|
||||||
|
else:
|
||||||
|
jax.config.update("jax_default_matmul_precision", "bfloat16")
|
||||||
|
|
||||||
|
|
||||||
|
# # %%
|
||||||
|
# import jax
|
||||||
|
# import jax.numpy as jnp
|
||||||
|
# import optax
|
||||||
|
# import numpy as np
|
||||||
|
# from functools import partial
|
||||||
|
# from typing import Callable, Optional
|
||||||
|
# import math
|
||||||
|
#
|
||||||
|
# # jax.config.update("jax_default_matmul_precision", "tensorfloat32")
|
||||||
|
# jax.config.update("jax_default_matmul_precision", "bfloat16")
|
||||||
|
# # jax.config.update("jax_enable_x64", False)
|
||||||
|
# # enable cache
|
||||||
|
# jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
|
||||||
|
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
|
||||||
|
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# # from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
|
||||||
|
#
|
||||||
|
# from flax import jax_utils, traverse_util
|
||||||
|
# from flax.jax_utils import pad_shard_unpad, unreplicate
|
||||||
|
# from flax.training import train_state
|
||||||
|
# from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||||
|
# import flax.core
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# get platform type
|
||||||
|
from jax.lib import xla_bridge
|
||||||
|
print(xla_bridge.get_backend().platform)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# config options
|
||||||
|
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
|
||||||
|
save_path = 't5_80_1_bf16'
|
||||||
|
# file_path = 'combined_data'
|
||||||
|
split_datasets = load_from_disk(file_path)
|
||||||
|
training_size = len(split_datasets['train'])
|
||||||
|
# Store some constant
|
||||||
|
seed = 117
|
||||||
|
num_epochs = 5
|
||||||
|
batch_size = 384 # 384 is the best
|
||||||
|
num_train_epochs = num_epochs
|
||||||
|
per_device_train_batch_size = batch_size
|
||||||
|
train_batch_size = per_device_train_batch_size * jax.device_count()
|
||||||
|
per_device_eval_batch_size = batch_size
|
||||||
|
eval_batch_size = per_device_eval_batch_size * jax.device_count()
|
||||||
|
steps_per_epoch = training_size // train_batch_size
|
||||||
|
total_train_steps = steps_per_epoch * num_epochs
|
||||||
|
|
||||||
|
warmup_steps = 0
|
||||||
|
learning_rate = 2e-5
|
||||||
|
|
||||||
|
weight_decay = 0.01
|
||||||
|
adam_beta1 = 0.9
|
||||||
|
adam_beta2 = 0.999
|
||||||
|
adam_epsilon = 1e-8
|
||||||
|
label_smoothing_factor = 0.0
|
||||||
|
|
||||||
|
num_beams = 1
|
||||||
|
val_max_target_length = 128
|
||||||
|
|
||||||
|
predict_with_generate = True
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# prepare data
|
||||||
|
# init object
|
||||||
|
# e.g. Config
|
||||||
|
data_config = ConfigDict(
|
||||||
|
dict(
|
||||||
|
max_length=86,
|
||||||
|
pad_token_id=0,
|
||||||
|
decoder_start_token_id=0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
dataprep = DataPrepare(split_datasets['train'], data_config)
|
||||||
|
# # example usage
|
||||||
|
# %%
|
||||||
|
seed = 117
|
||||||
|
rng = jax.random.PRNGKey(seed)
|
||||||
|
train_loader = dataprep.data_loader(rng, batch_size=1)
|
||||||
|
batch = next(iter(train_loader))
|
||||||
|
|
||||||
|
# %%
|
||||||
|
batch
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# model
|
||||||
|
|
||||||
|
from transformers import FlaxT5ForConditionalGeneration
|
||||||
|
from transformers import T5Config
|
||||||
|
|
||||||
|
config = T5Config()
|
||||||
|
|
||||||
|
# If you want don't want to cast certain parameters (for example layer norm bias and scale)
|
||||||
|
# then pass the mask as follows
|
||||||
|
from flax import traverse_util
|
||||||
|
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=False)
|
||||||
|
|
||||||
|
# useful for transformer model
|
||||||
|
model.enable_gradient_checkpointing()
|
||||||
|
|
||||||
|
# enable bf16 except for layer_norm
|
||||||
|
flat_params = traverse_util.flatten_dict(model.params)
|
||||||
|
mask = {
|
||||||
|
path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
|
||||||
|
}
|
||||||
|
mask = traverse_util.unflatten_dict(mask)
|
||||||
|
model.params = model.to_bf16(model.params, mask)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
model, params = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=False)
|
||||||
|
t5_module = model.module
|
||||||
|
|
||||||
|
# %%
|
||||||
|
jax.tree.map(jnp.shape, model.params)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
from jax.sharding import Mesh, NamedSharding
|
||||||
|
from jax.sharding import PartitionSpec
|
||||||
|
from pjit_partition import set_partitions
|
||||||
|
|
||||||
|
params = model.params
|
||||||
|
data_partition_specs = PartitionSpec()
|
||||||
|
extra_param_keys = list(model._missing_keys)
|
||||||
|
initial_partition_specs = set_partitions(params)
|
||||||
|
# this is the partition spec we will use
|
||||||
|
filled_param_partition_specs = set_partitions(params, extra_keys=extra_param_keys)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# let us see the param_partition_spec
|
||||||
|
filled_param_partition_specs
|
||||||
|
|
||||||
|
# %% let us set up the mesh
|
||||||
|
|
||||||
|
from jax.sharding import Mesh
|
||||||
|
devices = np.asarray(jax.devices())
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
# mp: model/tensor parallelism
|
||||||
|
# dp: data parallelism
|
||||||
|
# we just use 'data' as a common axis for data and model params
|
||||||
|
mesh_axis_names = ("data")
|
||||||
|
print("Logical mesh:", devices)
|
||||||
|
|
||||||
|
mesh = Mesh(devices, mesh_axis_names)
|
||||||
|
|
||||||
|
# it is technically possible to use pjit_partition to set special partition rules
|
||||||
|
# e.g. by param size
|
||||||
|
# but for now just move on
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# # Model
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
# Initialize our training
|
||||||
|
rng = jax.random.PRNGKey(seed)
|
||||||
|
rng, dropout_rng = jax.random.split(rng)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# optimization functions
|
||||||
|
|
||||||
|
def create_learning_rate_fn(
|
||||||
|
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
|
||||||
|
) -> Callable[[int], jnp.ndarray]:
|
||||||
|
"""Returns a linear warmup, linear_decay learning rate function."""
|
||||||
|
steps_per_epoch = train_ds_size // train_batch_size
|
||||||
|
num_train_steps = steps_per_epoch * num_train_epochs
|
||||||
|
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
|
||||||
|
decay_fn = optax.linear_schedule(
|
||||||
|
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
|
||||||
|
)
|
||||||
|
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
|
||||||
|
return schedule_fn
|
||||||
|
|
||||||
|
|
||||||
|
# Create learning rate schedule
|
||||||
|
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
||||||
|
training_size,
|
||||||
|
train_batch_size,
|
||||||
|
num_train_epochs,
|
||||||
|
warmup_steps,
|
||||||
|
learning_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
# We use Optax's "masking" functionality to not apply weight decay
|
||||||
|
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
||||||
|
# mask boolean with the same structure as the parameters.
|
||||||
|
# The mask is True for parameters that should be decayed.
|
||||||
|
def decay_mask_fn(params):
|
||||||
|
flat_params = traverse_util.flatten_dict(params)
|
||||||
|
# find out all LayerNorm parameters
|
||||||
|
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
|
||||||
|
layer_norm_named_params = {
|
||||||
|
layer[-2:]
|
||||||
|
for layer_norm_name in layer_norm_candidates
|
||||||
|
for layer in flat_params.keys()
|
||||||
|
if layer_norm_name in "".join(layer).lower()
|
||||||
|
}
|
||||||
|
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
|
||||||
|
return traverse_util.unflatten_dict(flat_mask)
|
||||||
|
|
||||||
|
# create adam optimizer
|
||||||
|
adamw = optax.adamw(
|
||||||
|
learning_rate=linear_decay_lr_schedule_fn,
|
||||||
|
b1=adam_beta1,
|
||||||
|
b2=adam_beta2,
|
||||||
|
eps=adam_epsilon,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
mask=decay_mask_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Training functions
|
||||||
|
# class TrainState(train_state.TrainState):
|
||||||
|
# dropout_rng: jnp.ndarray
|
||||||
|
#
|
||||||
|
# # easy way to achieve data parallelism
|
||||||
|
# # also achieves folding of rng keys
|
||||||
|
# def replicate(self):
|
||||||
|
# return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
|
||||||
|
|
||||||
|
# set bf16 for model params
|
||||||
|
# model.params = model.to_bf16(model.params)
|
||||||
|
# Cast parameters to bfloat16 if desired
|
||||||
|
# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
|
||||||
|
|
||||||
|
# state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# see if we can init the model
|
||||||
|
|
||||||
|
# %%
|
||||||
|
from transformers import FlaxT5Model, T5Config
|
||||||
|
config = T5Config.from_pretrained('t5-base')
|
||||||
|
model = FlaxT5Model(config, _do_init=True).module
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Initialize random key and input for initialization
|
||||||
|
rng = jax.random.PRNGKey(0)
|
||||||
|
train_loader = dataprep.data_loader(rng, batch_size=1)
|
||||||
|
batch = next(iter(train_loader))
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
# Initialize model parameters
|
||||||
|
# init of FlaxT5Module.__call__
|
||||||
|
variables = model.init(rng,
|
||||||
|
input_ids=batch['input_ids'],
|
||||||
|
attention_mask=batch['attention_mask'],
|
||||||
|
decoder_input_ids=batch['decoder_attention_mask'],
|
||||||
|
decoder_attention_mask=batch['decoder_attention_mask']
|
||||||
|
)
|
||||||
|
params = variables['params']
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# create an init_fn
|
||||||
|
def init_fn(rng: jax.random.PRNGKey, batch, model) -> train_state.TrainState:
|
||||||
|
init_rng, rng = jax.random.split(rng)
|
||||||
|
variables = model.init(
|
||||||
|
init_rng,
|
||||||
|
input_ids=batch['input_ids'],
|
||||||
|
attention_mask=batch['attention_mask'],
|
||||||
|
decoder_input_ids=batch['decoder_attention_mask'],
|
||||||
|
decoder_attention_mask=batch['decoder_attention_mask']
|
||||||
|
)
|
||||||
|
params = variables.pop("params")
|
||||||
|
state = train_state.TrainState.create(
|
||||||
|
apply_fn=model.__call__,
|
||||||
|
params=params,
|
||||||
|
tx=adamw,
|
||||||
|
)
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# we do not know the output PartitionSpec
|
||||||
|
# we perform the hack where we just initialize it just to find the outspec
|
||||||
|
init_fn_try = shard_map(
|
||||||
|
functools.partial(init_fn, model=model),
|
||||||
|
mesh,
|
||||||
|
# 2nd argument is for the model
|
||||||
|
in_specs=(P(), P("data")),
|
||||||
|
out_specs=P(),
|
||||||
|
check_rep=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
rng, model_init_rng = jax.random.split(rng)
|
||||||
|
train_loader = dataprep.data_loader(model_init_rng, batch_size=batch_size)
|
||||||
|
batch = next(iter(train_loader))
|
||||||
|
|
||||||
|
|
||||||
|
state_fsdp_shapes = jax.eval_shape(init_fn_try, model_init_rng, batch)
|
||||||
|
state_fsdp_specs = nn.get_partition_spec(state_fsdp_shapes)
|
||||||
|
|
||||||
|
# print("RNG", state_fsdp_specs.rng)
|
||||||
|
print("\nParameters")
|
||||||
|
pprint(state_fsdp_specs.params)
|
||||||
|
print("\nOptimizer state")
|
||||||
|
pprint(state_fsdp_specs.opt_state[0])
|
||||||
|
|
||||||
|
# note: state_fsdp_specs is now ready to be used as pjit outspec
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Setup train state
|
||||||
|
# state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
|
||||||
|
state = jax.jit(
|
||||||
|
init_fn,
|
||||||
|
in_shardings=(P(), P("data")),
|
||||||
|
out_shardings=state_fsdp_specs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
# label smoothed cross entropy
|
||||||
|
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
|
||||||
|
"""
|
||||||
|
The label smoothing implementation is adapted from Flax's official example:
|
||||||
|
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
|
||||||
|
"""
|
||||||
|
vocab_size = logits.shape[-1]
|
||||||
|
confidence = 1.0 - label_smoothing_factor
|
||||||
|
low_confidence = (1.0 - confidence) / (vocab_size - 1)
|
||||||
|
normalizing_constant = -(
|
||||||
|
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
|
||||||
|
)
|
||||||
|
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
|
||||||
|
|
||||||
|
loss = optax.softmax_cross_entropy(logits, soft_labels)
|
||||||
|
loss = loss - normalizing_constant
|
||||||
|
|
||||||
|
# ignore padded tokens from loss
|
||||||
|
loss = loss * padding_mask
|
||||||
|
loss = loss.sum()
|
||||||
|
num_labels = padding_mask.sum()
|
||||||
|
return loss, num_labels
|
||||||
|
|
||||||
|
# MARK: train_step
|
||||||
|
# Define gradient update step fn
|
||||||
|
def train_step(state, batch):
|
||||||
|
label_smoothing_factor=0.0
|
||||||
|
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
||||||
|
|
||||||
|
def compute_loss(params):
|
||||||
|
labels = batch.pop("labels")
|
||||||
|
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
||||||
|
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
||||||
|
return loss, num_labels
|
||||||
|
|
||||||
|
# compute gradients through computational graph
|
||||||
|
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
|
||||||
|
(loss, num_labels), grad = grad_fn(state.params)
|
||||||
|
num_labels = jax.lax.psum(num_labels, "batch")
|
||||||
|
|
||||||
|
# true loss = total loss / total samples
|
||||||
|
# loss = jax.lax.psum(loss, "batch")
|
||||||
|
# loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
|
||||||
|
|
||||||
|
# true grad = total grad / total samples
|
||||||
|
grad = jax.lax.psum(grad, "batch")
|
||||||
|
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
|
||||||
|
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
|
||||||
|
|
||||||
|
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
||||||
|
return new_state, metrics
|
||||||
|
|
||||||
|
# max_length = (
|
||||||
|
# val_max_target_length if val_max_target_length is not None else model.config.max_length
|
||||||
|
# )
|
||||||
|
# num_beams = num_beams if num_beams is not None else model.config.num_beams
|
||||||
|
# gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
|
||||||
|
|
||||||
|
# Create parallel version of the train and eval step
|
||||||
|
# only state and batch
|
||||||
|
p_train_step = jax.jit(
|
||||||
|
train_step,
|
||||||
|
# state for first, batch for second
|
||||||
|
in_shardings=(P("data"), P("data")),
|
||||||
|
out_shardings=(P("data"), P("data")),
|
||||||
|
donate_argnames=("state"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
|
||||||
|
print("***** Running training *****")
|
||||||
|
print(f" Num examples = {training_size}")
|
||||||
|
print(f" Num Epochs = {num_epochs}")
|
||||||
|
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
|
||||||
|
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
|
||||||
|
print(f" Total optimization steps = {total_train_steps}")
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# jax.profiler.start_trace("./traces")
|
||||||
|
|
||||||
|
# Example batch (sharded across devices)
|
||||||
|
sharded_batch = {
|
||||||
|
'input_ids': jax.device_put_sharded(batch['input_ids'], devices),
|
||||||
|
'attention_mask': jax.device_put_sharded(batch['attention_mask'], devices),
|
||||||
|
'labels': jax.device_put_sharded(batch['labels'], devices),
|
||||||
|
'decoder_input_ids': jax.device_put_sharded(batch['decoder_input_ids'], devices),
|
||||||
|
'decoder_attention_mask': jax.device_put_sharded(batch['decoder_attention_mask'], devices),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Initial TrainState (pjit-ted TrainState)
|
||||||
|
sharded_state = jax.device_put_replicated(train_state, devices)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
|
||||||
|
rng, input_rng = jax.random.split(rng)
|
||||||
|
train_time = 0
|
||||||
|
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
||||||
|
for epoch in epochs:
|
||||||
|
train_start = time.time()
|
||||||
|
|
||||||
|
# Create sampling rng
|
||||||
|
train_metrics = []
|
||||||
|
rng, data_rng = jax.random.split(rng)
|
||||||
|
train_loader = dataprep.data_loader(data_rng, batch_size=batch_size)
|
||||||
|
steps_per_epoch = training_size // train_batch_size
|
||||||
|
# Generate an epoch by shuffling sampling indices from the train dataset
|
||||||
|
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
|
||||||
|
batch = next(train_loader)
|
||||||
|
# batch = shard(batch)
|
||||||
|
state, train_metric = p_train_step(state, batch)
|
||||||
|
train_metrics.append(train_metric)
|
||||||
|
|
||||||
|
train_time = time.time() - train_start
|
||||||
|
|
||||||
|
train_metric = unreplicate(train_metric)
|
||||||
|
train_metric['loss'].block_until_ready()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
epochs.write(
|
||||||
|
# f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, "
|
||||||
|
f"Epoch... ({epoch + 1}/{num_epochs} | "
|
||||||
|
# f"Learning Rate:{train_metric['learning_rate']}, "
|
||||||
|
f"Last train time: {train_time})"
|
||||||
|
)
|
||||||
|
# jax.profiler.stop_trace()
|
||||||
|
# %%
|
||||||
|
|
||||||
|
# output_dir = save_path
|
||||||
|
# # save checkpoint after each epoch and push checkpoint to the hub
|
||||||
|
# if jax.process_index() == 0:
|
||||||
|
# params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
|
||||||
|
# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
|
||||||
|
# model.save_pretrained(output_dir, params=params)
|
||||||
|
# tokenizer.save_pretrained(output_dir)
|
Loading…
Reference in New Issue