diff --git a/.gitignore b/.gitignore
index ef0f135..f2e3f61 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,5 @@ t5_*/
 exports/
 modified_t5_model/
 traces/
+ruff.toml
+settings.json
diff --git a/learn_flax/flax_basics.py b/learn_flax/flax_basics.py
new file mode 100644
index 0000000..b4b6e23
--- /dev/null
+++ b/learn_flax/flax_basics.py
@@ -0,0 +1,332 @@
+# ---
+# jupyter:
+#   jupytext:
+#     formats: ipynb,py:percent
+#     text_representation:
+#       extension: .py
+#       format_name: percent
+#       format_version: '1.3'
+#     jupytext_version: 1.16.4
+# ---
+
+# %% [markdown]
+# # Flax Basics
+
+# %% [markdown]
+# # Linear regression with Flax
+# In this example we perform linear regression with gradient descent. We first
+# define a model, generate data from known parameters (plus a little noise),
+# and then recover those parameters by fitting the model to the generated data.
+
+
+# %%
+# import packages
+import jax
+from typing import Any, Callable, Sequence
+from jax import random, numpy as jnp
+import flax
+from flax import linen as nn
+
+# %%
+# define model
+# notice there is no input size declaration
+model = nn.Dense(features=5)
+
+# model parameters & initialization
+key1, key2 = random.split(random.key(0))
+# dummy input to trigger shape inference
+x = random.normal(key1, (10,))
+params = model.init(key2, x)  # initialization call
+
+# check output shapes
+print(jax.tree_util.tree_map(lambda x: x.shape, params))
+
+model.apply(params, x)
+
+
+# %%
+# example of gradient descent
+
+# set problem dimensions
+n_samples: int = 20
+x_dim: int = 10
+y_dim: int = 5
+
+# generate random ground truth W and b
+key = random.key(0)
+k1, k2 = random.split(key)
+W = random.normal(k1, (x_dim, y_dim))
+b = random.normal(k2, (y_dim,))
+# store parameters in a frozendict pytree
+true_params = flax.core.freeze({"params": {"bias": b, "kernel": W}})
+print(jax.tree_util.tree_map(lambda x: x.shape, true_params))
+
+# Generate samples with additional noise.
+key_sample, key_noise = random.split(k1)
+x_samples = random.normal(key_sample, (n_samples, x_dim))
+# Wx + b + epsilon (noise)
+y_samples = (
+    jnp.dot(x_samples, W) + b + 0.1 *
+    random.normal(key_noise, (n_samples, y_dim))
+)
+print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
+
+
+# %%
+# get model output
+@jax.jit
+def mse(params, x_batched, y_batched):
+    # define squared loss for a single pair
+    def squared_error(x, y):
+        # use model.apply to get model_f(x)
+        pred = model.apply(params, x)
+        return jnp.inner(y - pred, y - pred) / 2.0
+    # vmap the squared_error function over the batch along axis 0
+    return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)
+
+
+# %%
+
+# perform gradient descent
+learning_rate = 0.3  # Gradient step size.
+print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
+# get the gradient function for backprop
+loss_grad_fn = jax.value_and_grad(mse)
+
+
+@jax.jit
+def update_params(params, learning_rate, grads):
+    params = jax.tree_util.tree_map(
+        lambda p, g: p - learning_rate * g, params, grads)
+    return params
+
+
+for i in range(101):
+    # Perform one gradient update.
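+    # loss_grad_fn comes from jax.value_and_grad(mse), so it returns both the
+    # scalar loss and the gradient pytree in a single call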
+    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
+    params = update_params(params, learning_rate, grads)
+    if i % 10 == 0:
+        print(f'Loss step {i}: ', loss_val)
+
+# %% [markdown]
+# # Optimization with Optax
+
+# %%
+# init optax object
+import optax
+tx = optax.adam(learning_rate=learning_rate)
+# we initialize the optimizer state from the model params
+opt_state = tx.init(params)
+loss_grad_fn = jax.value_and_grad(mse)
+
+# %%
+# perform gradient descent with the optax optimizer instead of update_params
+for i in range(101):
+    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
+    updates, opt_state = tx.update(grads, opt_state)
+    params = optax.apply_updates(params, updates)
+    if i % 10 == 0:
+        print('Loss step {}: '.format(i), loss_val)
+
+# %% [markdown]
+# # Serializing the result - aka saving the model
+
+# %%
+from flax import serialization
+bytes_output = serialization.to_bytes(params)
+dict_output = serialization.to_state_dict(params)
+print('Dict output')
+print(dict_output)
+print('Bytes output')
+print(bytes_output)
+
+
+# %%
+# load back from the saved output
+# we first rebuild the param structure to use as a template
+model = nn.Dense(features=5)
+key1, key2 = random.split(random.key(0))
+# dummy input to trigger shape inference
+x = random.normal(key1, (10,))
+params = model.init(key2, x)  # initialization call
+
+# then we restore the saved weights into that template
+# (both functions return the restored params rather than modifying in place)
+# from state_dict
+params = serialization.from_state_dict(params, dict_output)
+print(params)
+# from bytes
+params = serialization.from_bytes(params, bytes_output)
+print(params)
+
+
+# %% [markdown]
+# # Defining your own models
+# A nn.Module subclass is composed of:
+#
+# 1. a collection of data fields (the module's hyperparameters)
+# 2. a setup() method that creates the submodules used in the forward pass
+# 3. a __call__() method that implements the forward pass
+#
+# the model's parameters are not stored on the module itself; they live in a
+# separate params structure, a pytree of parameters
+
+# %%
+# module basics
+class ExplicitMLP(nn.Module):
+    # model arguments
+    # we have a list of feature sizes for the dense layers
+    features: Sequence[int]
+
+    def setup(self):
+        self.layers = [nn.Dense(feat) for feat in self.features]
+
+    def __call__(self, inputs):
+        x = inputs
+        # go through each layer
+        for i, lyr in enumerate(self.layers):
+            x = lyr(x)
+            # apply an activation function after every layer except the last
+            if i != len(self.layers) - 1:
+                x = nn.relu(x)
+        return x
+
+# init the model
+key1, key2 = random.split(random.key(117), 2)
+# any input shape is ok
+x = random.uniform(key1, (4,4))
+print(x)
+
+# instantiate the model
+model = ExplicitMLP(features=[3,4,5])
+# init model with an rng key and a test input
+params = model.init(key2, x)
+y = model.apply(params, x)
+
+print(
+    'initialized parameter shapes:\n',
+    jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)),
+)
+print('output:\n', y)
+
+# %%
+# using nn.Module with @nn.compact
+
+class SimpleMLP(nn.Module):
+    features: Sequence[int]
+
+    # nn.compact style declares submodules inline
+    @nn.compact
+    def __call__(self, inputs):
+        x = inputs
+        for i, feat in enumerate(self.features):
+            x = nn.Dense(feat, name=f'layers_{i}')(x)
+            if i != len(self.features) - 1:
+                x = nn.relu(x)
+            # providing a name is optional though!
+            # the default autonames would be "Dense_0", "Dense_1", ...
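+            # with @nn.compact the parameters are still created lazily at
+            # init/apply time, so their shapes are inferred from the inputs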
+        return x
+
+key1, key2 = random.split(random.key(117), 2)
+x = random.uniform(key1, (4,4))
+
+model = SimpleMLP(features=[3,4,5])
+params = model.init(key2, x)
+y = model.apply(params, x)
+
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
+print('output:\n', y)
+
+
+# %%
+# creating your own layers
+# an example showing how to write your own dense layer
+class SimpleDense(nn.Module):
+    features: int
+    kernel_init: Callable = nn.initializers.lecun_normal()
+    bias_init: Callable = nn.initializers.zeros_init()
+
+    @nn.compact
+    def __call__(self, inputs):
+        # use the nn.Module.param method to declare a parameter
+        # * because params are declared in __call__, shape inference is possible *
+        # the 'W'
+        kernel = self.param(
+            'kernel',  # name
+            self.kernel_init,  # initialization function
+            (inputs.shape[-1], self.features),  # shape
+        )
+        # 'Wx'
+        y = jnp.dot(inputs, kernel)
+        # 'Wx + b'
+        bias = self.param('bias', self.bias_init, (self.features,))
+        y = y + bias
+        return y
+
+key1, key2 = random.split(random.key(0), 2)
+x = random.uniform(key1, (4,4))
+
+model = SimpleDense(features=3)
+# init_fn takes (prng key, *init args, **init kwargs)
+params = model.init(key2, x)
+y = model.apply(params, x)
+
+print('initialized parameters:\n', params)
+print('output:\n', y)
+
+
+# %%
+# handling state in your module
+# jax does not work well with side effects and hidden state:
+# jax functions are supposed to be pure, with outputs fully determined by inputs
+# we introduce "mutable" to separate the function from the state
+
+# note: there are 2 variable collections here:
+# 'params' for trainable parameters, 'batch_stats' for non-trainable state
+class BiasAdderWithRunningMean(nn.Module):
+    decay: float = 0.99
+
+    @nn.compact
+    def __call__(self, x):
+        # easy pattern to detect if we're initializing via an empty variable tree:
+        # check for the variable 'mean' in the 'batch_stats' collection
+        is_initialized = self.has_variable('batch_stats', 'mean')
+        # declare a variable outside the model parameters:
+        # declare the variable 'mean' in the 'batch_stats' collection,
+        # so ra_mean refers to batch_stats -> mean
+        ra_mean = self.variable(
+            'batch_stats',
+            'mean',
+            lambda s: jnp.zeros(s),
+            x.shape[1:]  # shape of the data after the 0th axis, e.g. (10, 5) -> (5,)
+        )
+        bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
+        if is_initialized:
+            ra_mean.value = (
+                self.decay * ra_mean.value
+                + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)
+            )
+
+        return x - ra_mean.value + bias
+
+
+key1, key2 = random.split(random.key(0), 2)
+x = jnp.ones((10,5))
+model = BiasAdderWithRunningMean()
+variables = model.init(key1, x)
+print('initialized variables:\n', variables)
+# passing mutable=[...] means apply() also returns the updated state of those
+# collections, so there is no true hidden state inside the function.
+# this serves as an escape hatch for implementing pure functions with state:
+# the state is kept outside the module
+y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
+print('updated state:\n', updated_state)
+
+
+# %%
+for val in [1.0, 2.0, 3.0]:
+    x = val * jnp.ones((10, 5))
+    y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
+    old_state, params = flax.core.pop(variables, 'params')
+    variables = flax.core.freeze({'params': params, **updated_state})
+    print('updated state:\n', updated_state)  # Shows only the mutable part
+
+# %%
diff --git a/t5_jax.py b/t5_jax.py
index e35e0e5..0fd7c89 100644
--- a/t5_jax.py
+++ b/t5_jax.py
@@ -96,7 +96,7 @@ split_datasets = load_from_disk(file_path)
 training_size = len(split_datasets['train'])
 # Store some constant
 seed = 117
-num_epochs = 80
+num_epochs = 5
 batch_size = 384  # 384 is the best
 num_train_epochs = num_epochs
 per_device_train_batch_size = batch_size
@@ -314,16 +314,16 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
 
     for idx in batch_idx:
         batch = dataset[idx]
-        batch = {k: np.array(v) for k, v in batch.items()}
+        batch = {k: jnp.array(v) for k, v in batch.items()}
 
         yield batch
 
 
 # %% [markdown]
 # # Model
-
-
-
+#
+#
+#
 
 # %%
@@ -442,16 +442,17 @@ def train_step(state, batch, label_smoothing_factor=0.0):
     num_labels = jax.lax.psum(num_labels, "batch")
 
     # true loss = total loss / total samples
-    loss = jax.lax.psum(loss, "batch")
-    loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
+    # loss = jax.lax.psum(loss, "batch")
+    # loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
 
     # true grad = total grad / total samples
     grad = jax.lax.psum(grad, "batch")
     grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
 
     new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
-    metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
-    return new_state, metrics
+    # metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
+    # return new_state, metrics
+    return new_state
 
 # Define generation function
 max_length = (
@@ -505,19 +506,20 @@ for epoch in epochs:
     for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
         batch = next(train_loader)
         batch = shard(batch)
-        state, train_metric = p_train_step(state, batch)
-        train_metrics.append(train_metric)
+        state = p_train_step(state, batch)
+        # train_metrics.append(train_metric)
 
     train_time = time.time() - train_start
 
-    train_metric = unreplicate(train_metric)
-    train_metric['loss'].block_until_ready()
+    # train_metric = unreplicate(train_metric)
+    # train_metric['loss'].block_until_ready()
 
     epochs.write(
-        f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, "
-        f"Learning Rate:{train_metric['learning_rate']}, "
+        # f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, "
+        f"Epoch... ({epoch + 1}/{num_epochs} | "
+        # f"Learning Rate:{train_metric['learning_rate']}, "
         f"Last train time: {train_time})"
     )
 # jax.profiler.stop_trace()
diff --git a/t5_jax_prediction.py b/t5_jax_prediction.py
index 49a751c..1183864 100644
--- a/t5_jax_prediction.py
+++ b/t5_jax_prediction.py
@@ -30,7 +30,7 @@ from typing import Callable, Optional
 import math
 
 # jax.config.update("jax_default_matmul_precision", "tensorfloat32")
-jax.config.update("jax_default_matmul_precision", "high")
+jax.config.update("jax_default_matmul_precision", "bfloat16")
 
 jax.config.update("jax_enable_x64", False)
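Side note on the final hunk: jax.default_matmul_precision can also be used as a context manager, which scopes the lowered matmul precision to a single block instead of flipping the global flag for the whole script. A minimal sketch, not part of this diff; the array shapes here are made up for illustration:

import jax
import jax.numpy as jnp

a = jnp.ones((128, 128))
b = jnp.ones((128, 128))

# global flag, as set in t5_jax_prediction.py: affects every matmul from here on
jax.config.update("jax_default_matmul_precision", "bfloat16")

# context manager: only matmuls inside this block use the requested precision
with jax.default_matmul_precision("bfloat16"):
    c = jnp.dot(a, b)

# outside the block the surrounding default applies again
d = jnp.dot(a, b)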