From 0762c02b31924606abac572d21eddcaee7a17988 Mon Sep 17 00:00:00 2001 From: Richard Wong Date: Sun, 6 Oct 2024 23:52:42 +0900 Subject: [PATCH] Feat: implemented attention layer in equinox --- dataload.py | 172 +++++++++ equinox/.gitignore | 1 + equinox/handling_state_equinox.py | 54 +++ equinox/handling_state_flax.py | 106 ++++++ equinox/mnist.py | 252 +++++++++++++ equinox/t5_simple_train_model.py | 591 ++++++++++++++++++++++++++++++ equinox/t5_train_model.py | 495 +++++++++++++++++++++++++ make_context_data.py | 279 ++++++++++++++ nnx/.gitignore | 1 + nnx/mnist.py | 168 +++++++++ 10 files changed, 2119 insertions(+) create mode 100644 dataload.py create mode 100644 equinox/.gitignore create mode 100644 equinox/handling_state_equinox.py create mode 100644 equinox/handling_state_flax.py create mode 100644 equinox/mnist.py create mode 100644 equinox/t5_simple_train_model.py create mode 100644 equinox/t5_train_model.py create mode 100644 make_context_data.py create mode 100644 nnx/.gitignore create mode 100644 nnx/mnist.py diff --git a/dataload.py b/dataload.py new file mode 100644 index 0000000..d80cf53 --- /dev/null +++ b/dataload.py @@ -0,0 +1,172 @@ +# %% +# Prepare dataloader for jax training +from datasets import Dataset, DatasetDict, Value, Sequence, load_from_disk +from transformers import FlaxT5ForConditionalGeneration +from datasets import ClassLabel, Value, Sequence +from ml_collections import ConfigDict +import numpy as np +import jax.numpy as jnp +import jax +import math +from typing import Optional, List, Tuple, Callable, cast + + +# file_path = 'combined_data' +# split_datasets = load_from_disk(file_path) +# training_size = len(split_datasets['train']) + +from transformers import T5TokenizerFast + +# class takes in a dataset +class DataPrepare(): + + def __init__(self, raw_dataset, config): + self.raw_dataset: Dataset = raw_dataset + self.size: int = len(raw_dataset) + self.config: ConfigDict = config + self.tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=False) + # Define additional special tokens + # additional_special_tokens = ["", "", "", "", "", "", "", "", ""] + additional_special_tokens = ["", "", "", "", "", "", "SIG", "UNIT", "DATA_TYPE"] + # Add the additional special tokens to the tokenizer + self.tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens}) + + model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base") + + model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"]) + self.shift_tokens_right_fn = getattr(model_module, "shift_tokens_right") # noqa: B009 + + + self.train_dataset = self.preprocess_function( + self.raw_dataset + ) + + + + # In Flax, for seq2seq models we need to pass `decoder_input_ids` + # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here + # for that dynamically import the `shift_tokens_right` function from the model file + + # given a dataset entry, run it through the tokenizer + # Setting padding="max_length" as we need fixed length inputs for jitted functions + def preprocess_function(self, example: Dataset): + inputs = example['input'] + targets = example['output'] + # text_target sets the corresponding label to inputs + # there is no need to create a separate 'labels' + # produce input_ids and decoder_input_ids + model_inputs = self.tokenizer( + inputs, + max_length=self.config.max_length, + padding=True, + truncation=True, + return_tensors="np" + ) + # we separate it out because we need 
the attention mask + labels = self.tokenizer( + text_target=targets, + max_length=self.config.max_length, + padding=True, + truncation=True, + return_tensors="np" + ) + model_inputs['input_ids'] = np.asarray(model_inputs['input_ids']) + model_inputs['attention_mask'] = np.asarray(model_inputs['attention_mask']) + # for loss computation + model_inputs["labels"] = labels["input_ids"] + # make decoder input ids + # this is actually "model output" shifted right + decoder_input_ids = self.shift_tokens_right_fn( + labels["input_ids"], self.config.pad_token_id, self.config.decoder_start_token_id + ) + # require by model + model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids) + # decoder_attention_mask = shift_tokens_right_fn( + # labels["attention_mask"], self.config.pad_token_id, self.config.decoder_start_token_id + # ) + # We need decoder_attention_mask so we can ignore pad tokens in loss + model_inputs["decoder_attention_mask"] = np.asarray(labels["attention_mask"]) + + return model_inputs + + + # Example pad function + def _pad_to_batch_size(self, batch, target_size): + # Get the current batch size + input_ids = batch['input_ids'] + current_size = input_ids.shape[0] + if current_size < target_size: + # Calculate how much padding is needed + padding_size = target_size - current_size + # Create padding (e.g., zeros or some appropriate value) + padding = jnp.zeros((padding_size, input_ids.shape[1]), dtype=jnp.int32) # Assuming 2D + # Concatenate to create a full batch + # repeat for all arrays in the tree + padded_batch = jax.tree.map(lambda array: jnp.concatenate([array, padding], axis=0, dtype=jnp.int32), batch) + # padded_batch = jnp.concatenate([batch, padding], axis=0) + + else: + padded_batch = batch + return padded_batch + + def data_loader(self, rng: jax.random.PRNGKey, batch_size: int, shuffle: bool = False, drop_last=True): + """ + Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete, + and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`. + """ + dataset: Dataset = Dataset.from_dict(self.train_dataset) + + if shuffle: + batch_idx = jax.random.permutation(rng, len(dataset)) + batch_idx = np.asarray(batch_idx) + else: + batch_idx = np.arange(len(dataset)) + + if drop_last: + steps_per_epoch = len(dataset) // batch_size + batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch. + minibatch_list = batch_idx.reshape((steps_per_epoch, batch_size)) + else: + steps_per_epoch = math.ceil(len(dataset) / batch_size) + minibatch_list = np.array_split(batch_idx, steps_per_epoch) + + for minibatch in minibatch_list: + batch = dataset[minibatch] + batch = {k: jnp.array(v, dtype=jnp.int32) for k, v in batch.items()} + batch = self._pad_to_batch_size(batch, batch_size) + + yield batch + + +# # testing out the class +# # %% +# # init object +# # e.g. 
Config +# +# file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_desc' +# data_config = ConfigDict( +# dict( +# max_length=86, +# pad_token_id=0, +# decoder_start_token_id=0 +# ) +# ) +# +# from datasets import load_from_disk +# split_datasets = load_from_disk(file_path) +# dataprep = DataPrepare(split_datasets['train'], data_config) +# +# # %% +# seed = 117 +# rng = jax.random.PRNGKey(seed) +# train_loader = dataprep.data_loader(rng, batch_size=32) +# +# +# +# # %% +# batch = next(train_loader) +# +# print(batch['input_ids'].shape) +# print(batch['decoder_input_ids'].shape) +# +# # %% diff --git a/equinox/.gitignore b/equinox/.gitignore new file mode 100644 index 0000000..c18dd8d --- /dev/null +++ b/equinox/.gitignore @@ -0,0 +1 @@ +__pycache__/ diff --git a/equinox/handling_state_equinox.py b/equinox/handling_state_equinox.py new file mode 100644 index 0000000..9f02f61 --- /dev/null +++ b/equinox/handling_state_equinox.py @@ -0,0 +1,54 @@ +# an example of stateful operations +# %% +import equinox as eqx +import jax +import jax.numpy as jnp +import jax.random as jr +import optax # https://github.com/deepmind/optax +from equinox.nn import State, StateIndex, StatefulLayer +from jaxtyping import Array + + +# %% +class Counter(eqx.Module): + # This wraps together (a) a unique dictionary key used for looking up a + # stateful value, and (b) how that stateful value should be initialised. + index: eqx.nn.StateIndex + + def __init__(self): + init_state = jnp.array(0) + self.index = eqx.nn.StateIndex(init_state) + + # eqx.nn.State stores the state of the model + # This is essentially a dictionary mapping from equinox.nn.StateIndexs to PyTrees of arrays. + # This class should be initialised via equinox.nn.make_with_state. + # + # Basically just a dictionary which (a) works only with StateIndex-s, and which (b) + # works around a JAX bug that prevents flattening dicts with `object()` keys, and which + # (c) does error-checking that you're using the most up-to-date version of it. + def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]: + value = state.get(self.index) + new_x = x + value + + # Sets a new value for an [`equinox.nn.StateIndex`][], and returns the + # updated state. 
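+        # Note that eqx.nn.State is never mutated in place: `set` returns a
+        # fresh State object, which is why the updated state is handed back to
+        # the caller and threaded into the next call.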
+ new_state = state.set(self.index, value + 1) + return new_x, new_state + +# make_with_state is the recommended way to start a stateful model +counter, state = eqx.nn.make_with_state(Counter)() +x = jnp.array(2.3) + +num_calls = state.get(counter.index) +print(f"Called {num_calls} times.") # 0 + +_, state = counter(x, state) +num_calls = state.get(counter.index) +print(f"Called {num_calls} times.") # 1 + +_, state = counter(x, state) +num_calls = state.get(counter.index) +print(f"Called {num_calls} times.") # 2 + + +# %% diff --git a/equinox/handling_state_flax.py b/equinox/handling_state_flax.py new file mode 100644 index 0000000..694bb23 --- /dev/null +++ b/equinox/handling_state_flax.py @@ -0,0 +1,106 @@ +# %% +# introduction to how flax does stateful operations +import flax.linen as nn +import jax.numpy as jnp +import jax +import flax +from jaxtyping import Array + +# %% + +class BiasAdderWithRunningMean(nn.Module): + momentum: float = 0.9 + + @nn.compact + def __call__(self, x): + is_initialized = self.has_variable('hehe', 'mean') + print(is_initialized) + mean = self.variable('hehe', 'mean', jnp.zeros, x.shape[1:]) + bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:]) + if is_initialized: + print(mean.value) # notice that value retains after first call + mean.value = self.momentum * mean.value + (1.0 - self.momentum) * jnp.mean( + x, axis=0, keepdims=True + ) + print(mean.value) + return mean.value + bias + + +# %% + +input_key = jax.random.PRNGKey(0) +model = BiasAdderWithRunningMean() +inputs = jax.random.normal(input_key, (10, 5)) # Generate random normal values +variables = model.init(input_key, inputs) +# Split state and params (which are updated by optimizer). +state, params = flax.core.pop(variables, 'params') +print(f"first init: {state}") +# %% +for i in range(5): + new_inputs = jax.random.normal(jax.random.PRNGKey(i + 1), (10,5)) # New random inputs + # notice how we are threading the state + # perform argument unpacking on state dictionary + output, state = model.apply({'params': params, **state}, + new_inputs, mutable=list(state.keys())) + + # mean_state = variables['batch_stats']['mean'] # Access the updated mean state + print(f"updated state {state}") + print(f"Output after input {i + 1}: {output}") + # print(f"Updated running mean state: {mean_state}") +# %% +########################################################### +# example 2 +from flax.linen.initializers import lecun_normal, variance_scaling, zeros, normal +import jax.random as random +class Foo(nn.Module): + features: int + @nn.compact + def __call__(self): + key = self.make_rng('spectral_norm_stats') + print(key) + u0_variable = self.variable('spectral_norm_stats', 'u0', normal(), key, (1, self.features)) + return u0_variable.value + +foovars = Foo(3).init({'params': random.PRNGKey(0), 'spectral_norm_stats': random.PRNGKey(1)}) +Foo(3).apply(foovars, rngs={'spectral_norm_stats': random.PRNGKey(1)}) +# --> DeviceArray([[0.00711277, 0.0107195 , 0.019903 ]], dtype=float32) + +# %% +model = Foo(3) + +# %% +# state is kept in self.variable, tied to the layer +output = model.apply(foovars, rngs={'spectral_norm_stats': random.PRNGKey(1)}) + +# %% +output, state = model.apply( + foovars, + mutable=list(foovars.keys()), + rngs={'spectral_norm_stats': random.PRNGKey(1)} +) +print(output, state) +# %% +output, state = model.apply( + state, + mutable=list(foovars.keys()), + rngs={'spectral_norm_stats': random.PRNGKey(1)} +) +# no change because input state is the same +print(output, state) +# %% 
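+# The next cell patches the stored `u0` value directly: the state returned by
+# `apply` is an ordinary nested dict of arrays, so it can be edited like any
+# other pytree before being passed back in.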
+state_array = state['spectral_norm_stats']['u0'] +modified_array = jax.lax.dynamic_update_slice(state_array, jnp.array([[0.9]]), (0,0)) +state['spectral_norm_stats']['u0'] = modified_array +# %% +# %% +output, state = model.apply( + state, + mutable=list(foovars.keys()), + rngs={'spectral_norm_stats': random.PRNGKey(1)} +) +# state takes from given state +# note the modified 0.9 value +# note how the state is not re-initialized +print(output, state) + +# %% diff --git a/equinox/mnist.py b/equinox/mnist.py new file mode 100644 index 0000000..a5887ca --- /dev/null +++ b/equinox/mnist.py @@ -0,0 +1,252 @@ +# %% +import equinox as eqx +import jax +import jax.numpy as jnp +import optax # https://github.com/deepmind/optax +import torch # https://pytorch.org +import torchvision # https://pytorch.org +from jaxtyping import Array, Float, Int, PyTree # https://github.com/google/jaxtyping +# %% +# Hyperparameters + +BATCH_SIZE = 64 +LEARNING_RATE = 3e-4 +STEPS = 300 +PRINT_EVERY = 30 +SEED = 5678 + +key = jax.random.PRNGKey(SEED) + +# %% +normalise_data = torchvision.transforms.Compose( + [ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize((0.5,), (0.5,)), + ] +) +train_dataset = torchvision.datasets.MNIST( + "MNIST", + train=True, + download=True, + transform=normalise_data, +) +test_dataset = torchvision.datasets.MNIST( + "MNIST", + train=False, + download=True, + transform=normalise_data, +) +trainloader = torch.utils.data.DataLoader( + train_dataset, batch_size=BATCH_SIZE, shuffle=True +) +testloader = torch.utils.data.DataLoader( + test_dataset, batch_size=BATCH_SIZE, shuffle=True +) + +# %% +# Checking our data a bit (by now, everyone knows what the MNIST dataset looks like) +dummy_x, dummy_y = next(iter(trainloader)) +dummy_x = dummy_x.numpy() +dummy_y = dummy_y.numpy() +print(dummy_x.shape) # 64x1x28x28 +print(dummy_y.shape) # 64 +print(dummy_y) + +# %% +class CNN(eqx.Module): + layers: list + + def __init__(self, key): + key1, key2, key3, key4 = jax.random.split(key, 4) + # Standard CNN setup: convolutional layer, followed by flattening, + # with a small MLP on top. + self.layers = [ + eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1), + eqx.nn.MaxPool2d(kernel_size=2), + jax.nn.relu, # jax functions!!! + jnp.ravel, + eqx.nn.Linear(1728, 512, key=key2), + jax.nn.sigmoid, + eqx.nn.Linear(512, 64, key=key3), + jax.nn.relu, + eqx.nn.Linear(64, 10, key=key4), + jax.nn.log_softmax, + ] + + def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]: + for layer in self.layers: + x = layer(x) + return x + + +key, subkey = jax.random.split(key, 2) +model = CNN(subkey) +# %% +print(model) +# %% +# print the first layer: Conv2d +print(model.layers[0]) +# %% +# illustrated inference +def loss( + model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"] +) -> Float[Array, ""]: + # Our input has the shape (BATCH_SIZE, 1, 28, 28), but our model operations on + # a single input input image of shape (1, 28, 28). + # + # Therefore, we have to use jax.vmap, which in this case maps our model over the + # leading (batch) axis. + # + # This is an example of writing function for one input, then letting jax + # automatically vectorize over the batch dimension + pred_y = jax.vmap(model)(x) + return cross_entropy(y, pred_y) + + +def cross_entropy( + y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"] +) -> Float[Array, ""]: + # y are the true targets, and should be integers 0-9. + # pred_y are the log-softmax'd predictions. 
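+    # For log-softmax outputs, the per-example cross-entropy reduces to the
+    # negated log-probability assigned to the true class, which is exactly
+    # what the lines below extract.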
+ + # take_along_axis: take from pred_y along axis 1 according to 2nd argument + # expand_dims to axis 1 makes it of shape (y_dim, 1) + # since we take along axis 1, each y (in 2nd arg) therefore takes the + # corresponding entry of each row in pred_y + pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1) + return -jnp.mean(pred_y) # negative mean of relevant logits + + +# Example loss +loss_value = loss(model, dummy_x, dummy_y) +print(loss_value) +print(loss_value.shape) # scalar loss +# Example inference +output = jax.vmap(model)(dummy_x) +print(output.shape) # batch of predictions + +# %% +# This is an error! +# the reason is that model has to be parameters, but model has non-parameters +jax.value_and_grad(loss)(model, dummy_x, dummy_y) + + +# %% +# we separate out things that are params from other things +# since params are things that are arrays +# partition is doing filter(...) and filter(..., inverse=True) +params, static = eqx.partition(model, eqx.is_array) + +# %% +# lets compare the same object in both terms +print(static.layers[0]) +print(params.layers[0]) + + +# %% +# in the loss, we recombine both to form back our model +def loss2(params, static, x, y): + model = eqx.combine(params, static) + return loss(model, x, y) + + +# Now this will work! +# since the grad only looks at the first argument, this works out +loss_value, grads = jax.value_and_grad(loss2)(params, static, dummy_x, dummy_y) +print(loss_value) + +# %% +# This will work too! +# this works the same as the previous +value, grads = eqx.filter_value_and_grad(loss)(model, dummy_x, dummy_y) +print(value) + +# %% +# evaluation +loss = eqx.filter_jit(loss) # JIT our loss function from earlier! + + +@eqx.filter_jit +def compute_accuracy( + model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"] +) -> Float[Array, ""]: + """This function takes as input the current model + and computes the average accuracy on a batch. + """ + pred_y = jax.vmap(model)(x) + pred_y = jnp.argmax(pred_y, axis=1) + return jnp.mean(y == pred_y) + +# %% +def evaluate(model: CNN, testloader: torch.utils.data.DataLoader): + """This function evaluates the model on the test dataset, + computing both the average loss and the average accuracy. + """ + avg_loss = 0 + avg_acc = 0 + for x, y in testloader: + x = x.numpy() + y = y.numpy() + # Note that all the JAX operations happen inside `loss` and `compute_accuracy`, + # and both have JIT wrappers, so this is fast. + avg_loss += loss(model, x, y) + avg_acc += compute_accuracy(model, x, y) + return avg_loss / len(testloader), avg_acc / len(testloader) + +# %% +evaluate(model, testloader) +# %% +# training +optim = optax.adamw(LEARNING_RATE) + +def train( + model: CNN, + trainloader: torch.utils.data.DataLoader, + testloader: torch.utils.data.DataLoader, + optim: optax.GradientTransformation, + steps: int, + print_every: int, +) -> CNN: + # Just like earlier: It only makes sense to train the arrays in our model, + # so filter out everything else. + opt_state = optim.init(eqx.filter(model, eqx.is_array)) + + # Always wrap everything -- computing gradients, running the optimiser, updating + # the model -- into a single JIT region. This ensures things run as fast as + # possible. 
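+    # eqx.filter_jit traces the array leaves of its arguments and treats
+    # everything else (activation functions, static fields) as static, so the
+    # whole model and opt_state PyTrees can be passed in directly.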
+ @eqx.filter_jit + def make_step( + model: CNN, + opt_state: PyTree, + x: Float[Array, "batch 1 28 28"], + y: Int[Array, " batch"], + ): + loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y) + updates, opt_state = optim.update(grads, opt_state, model) + model = eqx.apply_updates(model, updates) + return model, opt_state, loss_value + + # Loop over our training dataset as many times as we need. + def infinite_trainloader(): + while True: + yield from trainloader + + for step, (x, y) in zip(range(steps), infinite_trainloader()): + # PyTorch dataloaders give PyTorch tensors by default, + # so convert them to NumPy arrays. + x = x.numpy() + y = y.numpy() + model, opt_state, train_loss = make_step(model, opt_state, x, y) + if (step % print_every) == 0 or (step == steps - 1): + test_loss, test_accuracy = evaluate(model, testloader) + print( + f"{step=}, train_loss={train_loss.item()}, " + f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}" + ) + return model + + +# %% +model = train(model, trainloader, testloader, optim, STEPS, PRINT_EVERY) + +# %% diff --git a/equinox/t5_simple_train_model.py b/equinox/t5_simple_train_model.py new file mode 100644 index 0000000..887f281 --- /dev/null +++ b/equinox/t5_simple_train_model.py @@ -0,0 +1,591 @@ +# %% +# package imports from equinox BERT example +import functools +from typing import Dict, List, Mapping, Optional, Callable, Optional, Tuple + +# import einops # https://github.com/arogozhnikov/einops +import equinox as eqx +import jax +import jax.numpy as jnp +import numpy as np +import optax # https://github.com/deepmind/optax +from datasets import load_dataset # https://github.com/huggingface/datasets +from jaxtyping import Array, Float, Int # https://github.com/google/jaxtyping +from tqdm import notebook as tqdm # https://github.com/tqdm/tqdm +from transformers import AutoTokenizer # https://github.com/huggingface/transformers +from ml_collections import ConfigDict, FrozenConfigDict + +# helper functions for attention computation +# they are implemented with jax w/o flax +from flax.linen import combine_masks, make_causal_mask +from flax.linen.attention import dot_product_attention_weights + +import flax.linen as nn +# %% +class T5LayerNorm(eqx.Module): + eps: float = 1e-6 + weight: jax.Array + # staticmethod forces the method to be by itself + weight_init: Callable[..., np.ndarray] = staticmethod(jax.nn.initializers.ones) + + def __init__( + self: eqx.Module, + hidden_size: int, + key: jax.random.PRNGKey, + # dtype: jnp.dtype = jnp.float32, + ): + # self.dtype = dtype + # self.params = { + # 'weight': self.weight_init(key, (hidden_size,), dtype) + # } + # force the use of float32 + # note that the key argument is ignored, so key is actually optional + self.weight = self.weight_init(key, (hidden_size,), jnp.float32) + + # takes in argument for hidden states so that it can fall through and remain + # a pure function + def __call__(self, hidden_states): + """ + Construct a layernorm module in the T5 style; + No bias and no subtraction of mean + """ + # always compute in float32 for layer norm + variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True) + hidden_states = hidden_states / jnp.sqrt(variance + self.eps) + + return self.weight * hidden_states + +# # %% +# # testing T5LayerNorm +# key = jax.random.PRNGKey(0) +# hidden_size = 128 # Example hidden size +# layer_norm = T5LayerNorm(key=key, hidden_size=hidden_size) +# # Create some example input data +# hidden_states = jnp.ones((1, 10, 
hidden_size)) # Batch size of 1, sequence length of 10 +# # Forward pass +# output = layer_norm(hidden_states) +# print("Output shape:", output.shape) + +# %% +class KaimingLinear(eqx.Module): + dtype: jnp.dtype = jnp.float32 + weights: jax.Array + + + def __init__( + self: eqx.Module, + key: jax.random.PRNGKey, + input_dim: int, + output_dim: int, + weights_init_std: float, + dtype: jnp.dtype = jnp.float32 + ): + self.dtype = dtype + + # the initialization strategy is to standardize on output dimension + # shapes are: (input_dim, output_dim) + self.weights= jax.random.normal(key, (input_dim, output_dim)) * weights_init_std + + def __call__( + self, + inputs: Float[Array, " input"], + ): + hidden = jnp.dot(inputs, self.weights) + return hidden + + + +# %% +# this function fortunately supports batched operations by default due to broadcasting +class T5DenseActDense(eqx.Module): + config: FrozenConfigDict + dtype: jnp.dtype = jnp.float32 + wi: jax.Array + wo: jax.Array + dropout: eqx.nn.Dropout + act: jax.nn.relu + + def __init__( + self: eqx.Module, + config: FrozenConfigDict, + dtype: jnp.dtype, + key: jax.random.PRNGKey + ): + self.config = config + self.dtype = dtype + + mlp_key, output_key = jax.random.split(key) + # the initialization strategy is to standardize on output dimension + # input + wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) + # shapes are: (config.d_model, config.d_ff) + # self.wi = jax.random.normal(mlp_key, (self.config.d_model, self.config.d_ff)) * wi_init_std + self.wi = KaimingLinear( + key=mlp_key, + input_dim=self.config.d_model, + output_dim=self.config.d_ff, + weights_init_std=wi_init_std, + dtype=self.dtype + ) + # output + wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) + # shapes are: (config.d_ff, config.d_model) + # self.wo = jax.random.normal(output_key, (self.config.d_ff, self.config.d_model)) * wo_init_std + self.wo = KaimingLinear( + key=mlp_key, + input_dim=self.config.d_ff, + output_dim=self.config.d_model, + weights_init_std=wo_init_std, + dtype=self.dtype + ) + + + self.dropout = eqx.nn.Dropout(self.config.dropout_rate) + # just set to relu for now since the smaller T5's use relu + self.act = jax.nn.relu + + def __call__( + self, + inputs: Float[Array, " d_model"], + enable_dropout: bool = False, + dropout_key: Optional[jax.random.PRNGKey] = None, + ) -> Float[Array, " d_model"]: + hidden = self.wi(inputs) + # hidden = jnp.dot(inputs, self.wi) + hidden = self.act(hidden) + hidden = self.dropout(hidden, inference=not enable_dropout, key=dropout_key) + hidden = self.wo(hidden) + # hidden = jnp.dot(hidden, self.wo) + return hidden + + +# # %% +# # test for T5DenseActDense +# # create fake config +# config_dict = { +# 'd_model': 768, +# 'd_ff': 2048, +# 'dropout_rate': 0.1, +# 'initializer_factor': 1.0, +# } +# # Create a FrozenDict from the standard dictionary +# frozen_config = FrozenConfigDict(config_dict) +# # initialize model +# key = jax.random.PRNGKey(0) +# dense = T5DenseActDense( +# key=key, +# config=frozen_config, +# dtype=jnp.float32 +# ) +# input_key, key = jax.random.split(key) +# inputs = jax.random.normal(input_key, (10, frozen_config.d_model)) # Generate random normal values +# dropout_key, key = jax.random.split(key) +# output = dense(inputs=inputs, enable_dropout=False, dropout_key=dropout_key) +# output.shape + +# %% +class T5LayerFF(eqx.Module): + config: FrozenConfigDict + dtype: jnp.dtype + DenseReluDense: T5DenseActDense + layer_norm: T5LayerNorm + dropout: eqx.nn.Dropout + + 
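+    # Equinox modules are dataclass-like pytrees: every sub-module used below
+    # must be declared as a field annotation above and assigned in __init__;
+    # the fields are frozen after construction.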
def __init__( + self: eqx.Module, + key: jax.random.PRNGKey, + config: FrozenConfigDict, + dtype: jnp.dtype = jnp.float32 + ): + + self.config = config + self.dtype = dtype + + dense_key, key = jax.random.split(key) + # args: key, config, dtype + self.DenseReluDense = T5DenseActDense( + key=dense_key, + config=config, + dtype=dtype + ) + layer_key, key = jax.random.split(key) + # args: key, hidden_size + self.layer_norm = T5LayerNorm( + key=layer_key, + hidden_size=self.config.d_model + ) + # args: dropout_rate + self.dropout = eqx.nn.Dropout(self.config.dropout_rate) + + def __call__( + self: eqx.Module, + inputs: Float[Array, " d_model"], + enable_dropout: bool =False, + dropout_key: Optional[jax.random.PRNGKey] = None, + ): + forwarded_states = self.layer_norm(inputs) + dropout_key, key = jax.random.split(dropout_key) + forwarded_states = self.DenseReluDense( + inputs=forwarded_states, + enable_dropout=enable_dropout, + dropout_key=dropout_key + ) + dropout_key, key = jax.random.split(key) + dropout_states = self.dropout( + x = forwarded_states, + inference=not enable_dropout, + key = dropout_key, + ) + hidden = inputs + dropout_states + return hidden + +# # %% +# # test for T5DenseActDense +# # create fake config +# config_dict = { +# 'd_model': 768, +# 'd_ff': 2048, +# 'dropout_rate': 0.1, +# 'initializer_factor': 1.0, +# } +# # Create a FrozenDict from the standard dictionary +# frozen_config = FrozenConfigDict(config_dict) +# # initialize model +# key = jax.random.PRNGKey(0) +# ff_layer = T5LayerFF( +# key=key, +# config=frozen_config, +# dtype=jnp.float32 +# ) +# input_key, key = jax.random.split(key) +# inputs = jax.random.normal(input_key, (10, frozen_config.d_model)) # Generate random normal values +# dropout_key, key = jax.random.split(key) +# output = ff_layer(inputs=inputs, enable_dropout=False, dropout_key=dropout_key) +# output.shape + +# %% +class T5Attention(eqx.Module): + config: FrozenConfigDict + has_relative_attention_bias: bool = False + causal: bool = False # False for encoder, True for decoder + dtype: jnp.dtype + + # parameters + q: jax.Array + k: jax.Array + v: jax.Array + o: jax.Array + + + # additional terms + relative_attention_num_buckets: int + relative_attention_max_distance: int + d_model: int + key_value_proj_dim: int + n_heads: int + dropout: float + inner_dim: int + initializer_factor: float + + + def __init__( + self: eqx.Module, + config: FrozenConfigDict, + dtype: jnp.dtype, + key: jax.random.PRNGKey, + ): + self.config = config + self.dtype = dtype + self.relative_attention_num_buckets = self.config.relative_attention_num_buckets + self.relative_attention_max_distance = self.config.relative_attention_max_distance + self.d_model = self.config.d_model + # size of k,v projection for each head + self.key_value_proj_dim = self.config.d_kv + self.n_heads = self.config.num_heads + self.dropout = self.config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + self.initializer_factor = self.config.initializer_factor + + q_init_std = self.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5) + kv_init_std = self.initializer_factor * (self.inner_dim**-0.5) + o_init_std = self.initializer_factor * (self.inner_dim**-0.5) + + q_key, key = jax.random.split(key) + self.q = KaimingLinear( + key=q_key, + input_dim=(self.inner_dim), + output_dim=self.inner_dim, + weights_init_std=q_init_std, + dtype=self.dtype + ) + + k_key, key = jax.random.split(key) + self.k = KaimingLinear( + key=k_key, + input_dim=self.inner_dim, + 
output_dim=self.inner_dim, + weights_init_std=kv_init_std, + dtype=self.dtype + ) + + v_key, key = jax.random.split(key) + self.v = KaimingLinear( + key=v_key, + input_dim=self.inner_dim, + output_dim=self.inner_dim, + weights_init_std=kv_init_std, + dtype=self.dtype + ) + + o_key, key = jax.random.split(key) + self.o = KaimingLinear( + key=o_key, + input_dim=self.inner_dim, + output_dim=self.d_model, + weights_init_std=o_init_std, + dtype=self.dtype + ) + + @staticmethod + def _relative_position_bucket( + relative_position, + bidirectional=True, + num_buckets=32, + max_distance=128 + ): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + """ + relative_buckets = 0 + + # bidirection determines if positive relative positions are valid + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0) * num_buckets + relative_position = jnp.abs(relative_position) + else: + # relative position range of [0, inf] + relative_position = -jnp.clip(relative_position, a_max=0) + + # half of buckets are for exact increments in positions + max_exact = num_buckets // 2 + # boolean to assign relative buckets later + is_small = relative_position < max_exact + + # other half are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact) + ) + + relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1) + + # jnp.where(condition, x, y), true->x, false->y + # in-place cumulative summation + # yields a list where every element has the correct relative bucket position + # whether its small or large + relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large) + + return relative_buckets.astype("i4") + + # bias gives weight based on relative distance aside from attention score + def compute_bias(self, query_length, key_length): + """ + Compute binned relative position bias + """ + # arange in the first dim + context_position = jnp.arange(query_length, dtype="i4")[:, None] + # arange in the second dim + memory_position = jnp.arange(key_length, dtype="i4")[None, :] + + # The relative position is defined as memory_position - query_position, + # i.e. the distance in tokens from the attending position to the + # attended-to position. 
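+        # e.g. for query_length = key_length = 3:
+        #   relative_position = [[ 0,  1,  2],
+        #                        [-1,  0,  1],
+        #                        [-2, -1,  0]]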
+ # + # 2D array where each entry represents the distance from a query token + # to a key token + relative_position = memory_position - context_position + # now we apply the earlier bucket creation function + relative_position_bucket = self._relative_position_bucket( + relative_position=relative_position, + bidirectional=(not self.causal), # causal during decode -> not bi + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + # retrieve the bias values + # shape (query_length, key_length, n_heads) + values = self.relative_attention_bias(relative_position_bucket) + # shape (1, n_heads, query_length, key_length) + # ready for attention + values = values.transpose((2, 0, 1))[None, :, :, :] + return values + + + # from (batch_size, seq_length, d_model) to + # (batch_size, seq_length, n_heads, head_dim) + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim)) + + # from (batch_size, seq_length, n_heads, head_dim) to + # (batch_size, seq_length, d_model) + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,)) + + + def _create_position_bias( + self, + key_states, + query_states, + attention_mask, + ): + # unlike the flax version, we don't even check for cache + key_length = key_states.shape[1] + query_length = query_states.shape[1] + + if self.has_relative_attention_bias: + position_bias = self.compute_bias(query_length, key_length) + elif attention_mask is not None: + position_bias = jnp.zeros_like(attention_mask) + else: + position_bias = jnp.zeros( + (1, self.n_heads, query_length, key_length), + dtype=self.dtype + ) + + return position_bias + + def __call__( + self, + inputs, + attention_mask=None, + key_value_states=None, + position_bias=None, + output_attentions=False, + enable_dropout=False, + dropout_key: Optional[jax.random.PRNGKey] = None, + ): + # expected input shape: (batch_size, seq_len, d_model) + # expected output: tuple of 2 arrays same shape as input + # (attn, position_bias) + batch_size, seq_length = inputs.shape[:2] + + # q,k,v projections + # (batch_size, n_heads, seq_length, dim_per_head) + query_states = self.q(inputs) + key_states = ( + self.k(inputs) if key_value_states is None else self.k(key_value_states) + ) + value_states = ( + self.v(inputs) if key_value_states is None else self.v(key_value_states) + ) + + # reshape to (batch_size, seq_length, n_heads, head_dim) + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # counteract scaling in dot_product_attention_weights function + query_states *= jnp.sqrt(query_states.shape[-1]) + + # create causal attention_mask + if self.causal: + causal_attention_mask = make_causal_mask(attention_mask, dtype="bool") + + # broadcast causal attention mask & attention mask to fit for merge + causal_attention_mask = jnp.broadcast_to( + causal_attention_mask, + (batch_size,) + causal_attention_mask.shape[1:] + ) + attention_mask = jnp.broadcast_to( + jnp.expand_dims(attention_mask, axis=(-3, -2)), + causal_attention_mask.shape + ) + attention_mask = combine_masks(attention_mask, causal_attention_mask) + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # replace masked positions with -10_000 + if attention_mask is not None: + mask_value = jnp.finfo(self.dtype).min + attention_mask = jax.lax.select( 
+ attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, mask_value).astype(self.dtype), + ) + + if position_bias is None: + # compute position bias (only for first layer) + position_bias = self._create_position_bias( + key_states, query_states, attention_mask + ) + + if attention_mask is not None: + position_bias = position_bias + attention_mask + + # Softmax(QK^T) + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=position_bias, + dropout_rng=dropout_key, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=not enable_dropout, + dtype=self.dtype, + ) + + # multiply with value states + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + + # bring back to (batch_size, seq_length, d_model) + attn_output = self._merge_heads(attn_output) + + # apply output matrix + attn_output = self.o(attn_output) + + outputs = (attn_output, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + + return outputs + +# # %% +# # test for T5Attention +# # create fake config +# config_dict = { +# 'relative_attention_num_buckets': 32, +# 'relative_attention_max_distance': 128, +# 'd_model': 768, # 64 * 12 +# 'd_kv': 64, +# 'num_heads': 12, +# 'dropout_rate': 0.1, +# 'initializer_factor': 1.0, +# } +# # Create a FrozenDict from the standard dictionary +# frozen_config = FrozenConfigDict(config_dict) +# # initialize model +# key = jax.random.PRNGKey(0) +# attn_layer = T5Attention( +# key=key, +# config=frozen_config, +# dtype=jnp.float32 +# ) +# input_key, key = jax.random.split(key) +# # inputs = jax.random.normal(input_key, (10, frozen_config.d_model)) # Generate random normal values +# batch_size = 1 +# seq_length = 10 +# inputs = jnp.ones((batch_size, seq_length, frozen_config.d_model)) +# dropout_key, key = jax.random.split(key) +# output = attn_layer(inputs=inputs, enable_dropout=False, dropout_key=dropout_key) +# print(len(output)) +# print(output[0].shape) + +# %% diff --git a/equinox/t5_train_model.py b/equinox/t5_train_model.py new file mode 100644 index 0000000..0210940 --- /dev/null +++ b/equinox/t5_train_model.py @@ -0,0 +1,495 @@ +# %% +# package imports from equinox BERT example +import functools +from typing import Dict, List, Mapping, Optional, Callable, Optional, Tuple + +# import einops # https://github.com/arogozhnikov/einops +import equinox as eqx +import jax +import jax.numpy as jnp +import numpy as np +import optax # https://github.com/deepmind/optax +from datasets import load_dataset # https://github.com/huggingface/datasets +from jaxtyping import Array, Float, Int # https://github.com/google/jaxtyping +from tqdm import notebook as tqdm # https://github.com/tqdm/tqdm +from transformers import AutoTokenizer # https://github.com/huggingface/transformers +from ml_collections import ConfigDict, FrozenConfigDict + +import flax.linen as nn +# %% +class T5LayerNorm(eqx.Module): + eps: float = 1e-6 + weight: jax.Array + # staticmethod forces the method to be by itself + weight_init: Callable[..., np.ndarray] = staticmethod(jax.nn.initializers.ones) + + def __init__( + self: eqx.Module, + key: jax.random.PRNGKey, + hidden_size: int, + # dtype: jnp.dtype = jnp.float32, + ): + # self.dtype = dtype + # self.params = { + # 'weight': self.weight_init(key, (hidden_size,), dtype) + # } + # force the use of float32 + # note that the key argument is ignored, so key is actually optional + self.weight = self.weight_init(key, (hidden_size,), 
jnp.float32) + + # takes in argument for hidden states so that it can fall through and remain + # a pure function + def __call__(self, hidden_states): + """ + Construct a layernorm module in the T5 style; + No bias and no subtraction of mean + """ + # always compute in float32 for layer norm + variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True) + hidden_states = hidden_states / jnp.sqrt(variance + self.eps) + + return self.weight * hidden_states + +# # %% +# # testing T5LayerNorm +# key = jax.random.PRNGKey(0) +# hidden_size = 128 # Example hidden size +# layer_norm = T5LayerNorm(key=key, hidden_size=hidden_size) +# # Create some example input data +# hidden_states = jnp.ones((1, 10, hidden_size)) # Batch size of 1, sequence length of 10 +# # Forward pass +# output = layer_norm(hidden_states) +# print("Output shape:", output.shape) + +# %% +class KaimingLinear(eqx.Module): + dtype: jnp.dtype = jnp.float32 + weights: jax.Array + + + def __init__( + self: eqx.Module, + key: jax.random.PRNGKey, + input_dim: int, + output_dim: int, + initializer_factor: float, + dtype: jnp.dtype = jnp.float32 + ): + self.dtype = dtype + + # the initialization strategy is to standardize on output dimension + # input + weights_init_std = initializer_factor * (input_dim**-0.5) + # shapes are: (input_dim, output_dim) + self.weights= jax.random.normal(key, (input_dim, output_dim)) * weights_init_std + + def __call__( + self, + inputs: Float[Array, " input"], + ): + hidden = jnp.dot(inputs, self.weights) + return hidden + + + +# %% +# this function fortunately supports batched operations by default due to broadcasting +class T5DenseActDense(eqx.Module): + config: FrozenConfigDict + dtype: jnp.dtype = jnp.float32 + wi: jax.Array + wo: jax.Array + dropout: eqx.nn.Dropout + act: jax.nn.relu + + def __init__( + self: eqx.Module, + key: jax.random.PRNGKey, + config: FrozenConfigDict, + dtype: jnp.dtype = jnp.float32 + ): + self.config = config + self.dtype = dtype + + mlp_key, output_key = jax.random.split(key) + # the initialization strategy is to standardize on output dimension + # input + # wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) + # shapes are: (config.d_model, config.d_ff) + # self.wi = jax.random.normal(mlp_key, (self.config.d_model, self.config.d_ff)) * wi_init_std + self.wi = KaimingLinear( + key=mlp_key, + input_dim=self.config.d_model, + output_dim=self.config.d_ff, + initializer_factor=self.config.initializer_factor, + dtype=self.dtype + ) + # output + # wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) + # shapes are: (config.d_ff, config.d_model) + # self.wo = jax.random.normal(output_key, (self.config.d_ff, self.config.d_model)) * wo_init_std + self.wo = KaimingLinear( + key=mlp_key, + input_dim=self.config.d_ff, + output_dim=self.config.d_model, + initializer_factor=self.config.initializer_factor, + dtype=self.dtype + ) + + + self.dropout = eqx.nn.Dropout(self.config.dropout_rate) + # just set to relu for now since the smaller T5's use relu + self.act = jax.nn.relu + + def __call__( + self, + inputs: Float[Array, " d_model"], + enable_dropout: bool = False, + dropout_key: Optional[jax.random.PRNGKey] = None, + ) -> Float[Array, " d_model"]: + hidden = self.wi(inputs) + # hidden = jnp.dot(inputs, self.wi) + hidden = self.act(hidden) + hidden = self.dropout(hidden, inference=not enable_dropout, key=dropout_key) + hidden = self.wo(hidden) + # hidden = jnp.dot(hidden, self.wo) + return hidden + + + + + + +# # %% +# # test for 
T5DenseActDense +# # create fake config +# config_dict = { +# 'd_model': 768, +# 'd_ff': 2048, +# 'dropout_rate': 0.1, +# 'initializer_factor': 1.0, +# } +# # Create a FrozenDict from the standard dictionary +# frozen_config = FrozenConfigDict(config_dict) +# # initialize model +# key = jax.random.PRNGKey(0) +# dense = T5DenseActDense( +# key=key, +# config=frozen_config, +# dtype=jnp.float32 +# ) +# input_key, key = jax.random.split(key) +# inputs = jax.random.normal(input_key, (10, frozen_config.d_model)) # Generate random normal values +# dropout_key, key = jax.random.split(key) +# output = dense(inputs=inputs, enable_dropout=False, key=dropout_key) +# output.shape + +# %% +class T5LayerFF(eqx.Module): + config: FrozenConfigDict + dtype: jnp.dtype + DenseReluDense: T5DenseActDense + layer_norm: T5LayerNorm + dropout: eqx.nn.Dropout + + def __init__( + self: eqx.Module, + key: jax.random.PRNGKey, + config: FrozenConfigDict, + dtype: jnp.dtype = jnp.float32 + ): + + self.config = config + self.dtype = dtype + + dense_key, key = jax.random.split(key) + # args: key, config, dtype + self.DenseReluDense = T5DenseActDense( + key=dense_key, + config=config, + dtype=dtype + ) + layer_key, key = jax.random.split(key) + # args: key, hidden_size + self.layer_norm = T5LayerNorm( + key=layer_key, + hidden_size=self.config.d_model + ) + # args: dropout_rate + self.dropout = eqx.nn.Dropout(self.config.dropout_rate) + + def __call__( + self: eqx.Module, + inputs: Float[Array, " d_model"], + enable_dropout: bool =False, + dropout_key: Optional[jax.random.PRNGKey] = None, + ): + forwarded_states = self.layer_norm(inputs) + dropout_key, key = jax.random.split(dropout_key) + forwarded_states = self.DenseReluDense( + inputs=forwarded_states, + enable_dropout=enable_dropout, + dropout_key=dropout_key + ) + dropout_key, key = jax.random.split(key) + dropout_states = self.dropout( + x = forwarded_states, + key = dropout_key, + inference=not enable_dropout + ) + hidden = inputs + dropout_states + return hidden + +# # %% +# # test for T5DenseActDense +# # create fake config +# config_dict = { +# 'd_model': 768, +# 'd_ff': 2048, +# 'dropout_rate': 0.1, +# 'initializer_factor': 1.0, +# } +# # Create a FrozenDict from the standard dictionary +# frozen_config = FrozenConfigDict(config_dict) +# # initialize model +# key = jax.random.PRNGKey(0) +# ff_layer = T5LayerFF( +# key=key, +# config=frozen_config, +# dtype=jnp.float32 +# ) +# input_key, key = jax.random.split(key) +# inputs = jax.random.normal(input_key, (10, frozen_config.d_model)) # Generate random normal values +# dropout_key, key = jax.random.split(key) +# output = ff_layer(inputs=inputs, enable_dropout=False, dropout_key=dropout_key) +# output.shape + +# %% +class T5Attention(eqx.Module): + config: FrozenConfigDict + has_relative_attention_bias: bool = False + causal: bool = False # False for encoder, True for decoder + dtype: jnp.dtype + + # additional terms + relative_attention_num_buckets: int + relative_attention_max_distance: int + d_model: int + key_value_proj_dim: int + n_heads: int + dropout: float + inner_dim: int + initializer_factor: float + + + def __init__( + self: eqx.Module, + key: jax.random.PRNGKey, + config: FrozenConfigDict, + dtype: jnp.dtype = jnp.float32 + + ): + self.relative_attention_num_buckets = self.config.relative_attention_num_buckets + self.relative_attention_max_distance = self.config.relative_attention_max_distance + self.d_model = self.config.d_model + # size of k,v projection for each head + self.key_value_proj_dim = 
self.config.d_kv + self.n_heads = self.config.num_heads + self.dropout = self.config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + self.initializer_factor = self.config.initializer_factor + + + q_key, key = jax.random.split(key) + self.q = KaimingLinear( + key=q_key, + input_dim=(self.inner_dim * self.key_value_proj_dim), + output_dim=self.inner_dim, + initializer_factor=self.initializer_factor, + dtype=self.dtype + ) + + k_key, key = jax.random.split(key) + self.k = KaimingLinear( + key=k_key, + input_dim=self.inner_dim, + output_dim=self.inner_dim, + initializer_factor=self.initializer_factor, + dtype=self.dtype + ) + + v_key, key = jax.random.split(key) + self.v = KaimingLinear( + key=v_key, + input_dim=self.inner_dim, + output_dim=self.inner_dim, + initializer_factor=self.initializer_factor, + dtype=self.dtype + ) + + o_key, key = jax.random.split(key) + self.o = KaimingLinear( + key=o_key, + input_dim=self.inner_dim, + output_dim=self.d_model, + initializer_factor=self.initializer_factor, + dtype=self.dtype + ) + + # 1 bias per head, so output is n_heads + # bias is learned during training + if self.has_relative_attention_bias: + input_dim = self.relative_attention_num_buckets + output_dim = self.n_heads + initializer_factor=self.initializer_factor + # we standardize based on the output dimension, + # which is n_head * kv_proj_dim - during multi head attention + weights_init_std = initializer_factor * (self.inner_dim**-0.5) + # shapes are: (input_dim, output_dim) + weights= jax.random.normal(key, (input_dim, output_dim), dtype=self.dtype) * weights_init_std + + self.relative_attention_bias = eqx.nn.Embedding( + weights=weights + ) + + @staticmethod + def _relative_position_bucket( + relative_position, + bidirectional=True, + num_buckets=32, + max_distance=128 + ): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
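+        (For example, with bidirectional=True and num_buckets=32, each direction gets 16 buckets:
+        offsets 0-7 keep their own exact bucket, while offsets from 8 up to max_distance share the
+        remaining 8 logarithmically sized buckets.)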
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + """ + relative_buckets = 0 + + # bidirection determines if positive relative positions are valid + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0) * num_buckets + relative_position = jnp.abs(relative_position) + else: + # relative position range of [0, inf] + relative_position = -jnp.clip(relative_position, a_max=0) + + # half of buckets are for exact increments in positions + max_exact = num_buckets // 2 + # boolean to assign relative buckets later + is_small = relative_position < max_exact + + # other half are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact) + ) + + relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1) + + # jnp.where(condition, x, y), true->x, false->y + # in-place cumulative summation + # yields a list where every element has the correct relative bucket position + # whether its small or large + relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large) + + return relative_buckets.astype("i4") + + # bias gives weight based on relative distance aside from attention score + def compute_bias(self, query_length, key_length): + """ + Compute binned relative position bias + """ + # arange in the first dim + context_position = jnp.arange(query_length, dtype="i4")[:, None] + # arange in the second dim + memory_position = jnp.arange(key_length, dtype="i4")[None, :] + + # The relative position is defined as memory_position - query_position, + # i.e. the distance in tokens from the attending position to the + # attended-to position. 
+ # + # 2D array where each entry represents the distance from a query token + # to a key token + relative_position = memory_position - context_position + # now we apply the earlier bucket creation function + relative_position_bucket = self._relative_position_bucket( + relative_position=relative_position, + bidirectional=(not self.causal), # causal during decode -> not bi + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + # retrieve the bias values + # shape (query_length, key_length, n_heads) + values = self.relative_attention_bias(relative_position_bucket) + # shape (1, n_heads, query_length, key_length) + # ready for attention + values = values.transpose((2, 0, 1))[None, :, :, :] + return values + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,)) + + + def _create_position_bias( + self, + key_states, + query_states, + attention_mask, + init_cache, + seq_length, + causal_attention_mask_shift + ): + # unlike the flax version, we don't even check for cache + key_length = key_states.shape[1] + query_length = query_states.shape[1] + + if self.has_relative_attention_bias: + position_bias = self.compute_bias(query_length, key_length) + elif attention_mask is not None: + position_bias = jnp.zeros_like(attention_mask) + else: + position_bias = jnp.zeros( + (1, self.n_heads, query_length, key_length), + dtype=self.dtype + ) + + return position_bias + + def __call__( + self, + inputs, + attention_mask=None, + key_value_states=None, + position_bias=None, + output_attentions=False, + enable_dropout=False + ): + batch_size, seq_length = inputs.shape[:2] + + # q,k,v projections + query_states = self.q(inputs) + key_states = ( + self.k(inputs) if key_value_states is None else self.k(key_value_states) + ) + value_states = ( + self.v(inputs) if key_value_states is None else self.v(key_value_states) + ) + + # reshape to (batch_size, seq_length, n_heads, head_dim) + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # counteract scaling in dot_product_attention_weights function + # not sure if this is a good idea in equinox + diff --git a/make_context_data.py b/make_context_data.py new file mode 100644 index 0000000..d8aac4f --- /dev/null +++ b/make_context_data.py @@ -0,0 +1,279 @@ +# %% +import os + +# Set this to True to run the model on CPU only. 
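+# Note: in the GPU branch below, `flags` is reassigned from scratch, so any
+# pre-existing XLA_FLAGS value is discarded; CUDA_VISIBLE_DEVICES exposes GPUs 0-3.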
+USE_CPU_ONLY = False + +flags = os.environ.get("XLA_FLAGS", "") +if USE_CPU_ONLY: + flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices + # Enforce CPU-only execution + os.environ["CUDA_VISIBLE_DEVICES"] = "" + os.environ["JAX_PLATFORMS"] = "cpu" +else: + # GPU flags + flags = ( + '--xla_gpu_enable_triton_softmax_fusion=true ' + '--xla_gpu_triton_gemm_any=True ' + # '--xla_gpu_enable_async_collectives=true ' + '--xla_gpu_enable_latency_hiding_scheduler=true ' + '--xla_gpu_enable_highest_priority_async_stream=true ' + ) + os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" + +os.environ["XLA_FLAGS"] = flags +os.environ.update({ + "TOKENIZERS_PARALLELISM" : "false", + "CUDA_DEVICE_MAX_CONNECTIONS" : "1", + "NCCL_LL128_BUFFSIZE": "-2", + "NCCL_LL_BUFFSIZE": "-2", + "NCCL_PROTO": "SIMPLE,LL,LL128", + "XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.80", + # "XLA_PYTHON_CLIENT_PREALLOCATE" : "false" + }) + + + +import pandas as pd +import matplotlib.pyplot as plt + +from datasets import Dataset + +import jax +import jax.numpy as jnp +import optax +import numpy as np +import functools +from typing import Callable, Optional +import math +from jax.sharding import Mesh, NamedSharding +from jax.experimental import mesh_utils + +from jax.sharding import PartitionSpec + +# set cache +jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") +jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) +jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) + +# jax.config.update("jax_default_matmul_precision", "tensorfloat32") +jax.config.update("jax_default_matmul_precision", "bfloat16") + +jax.config.update("jax_enable_x64", False) + + +# from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig + + +import datasets +from datasets import Dataset +import evaluate +from tqdm import tqdm + + +import nltk # Here to have a nice missing dependency error message early on + +from flax import jax_utils, traverse_util +from flax.jax_utils import pad_shard_unpad, unreplicate +from flax.training import train_state +from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key + +from ml_collections import ConfigDict + +import time + +from parallel.dataload import DataPrepare + + +# %% +# import data + +# load training +data_path = "/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv" +# Ensure to include 'ships_idx' in the fields list +fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit'] +# Load the dataset +train_df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields) + +# # load valid +# data_path = "/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/valid.csv" +# # Ensure to include 'ships_idx' in the fields list +# fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit'] +# # Load the dataset +# validation_df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields) +def process_df(df): + output_list = [{ + 'input': f"{row['tag_name']}{row['tag_description']}", + 'output': f"{row['thing']}{row['property']}", + } for _, row in df.iterrows()] + return output_list + + +# takes 1 minute to run without batching +train_dataset = Dataset.from_list(process_df(train_df)) + + +print("preparing data") +data_config = ConfigDict( + dict( + max_length=128, + pad_token_id=0, + decoder_start_token_id=0 + ) +) + +dataprep = DataPrepare(train_dataset, data_config) + +# %% +# load model +model_name_or_path = 
"./model_checkpoints/simple" # Replace with your specific model name +from transformers import FlaxT5ForConditionalGeneration +model = FlaxT5ForConditionalGeneration.from_pretrained( + model_name_or_path +) +params = model.params + +# %% +# load data +SEED = 117 +batch_size = 256 # per device batch_size +# test_batch_size multiplies by 4 because we shard by 4 later +train_batch_size = batch_size * jax.device_count() +rng = jax.random.PRNGKey(SEED) + + +# %% +# setup sharding +print("creating mesh") +device_mesh = mesh_utils.create_device_mesh((4,1)) +print(device_mesh) + +mesh = Mesh(devices=device_mesh, axis_names=('data', 'model')) +print(mesh) + +def mesh_sharding(pspec: PartitionSpec) -> NamedSharding: + return NamedSharding(mesh, pspec, memory_kind="device") + +data_sharding = mesh_sharding(PartitionSpec('data')) # replicate across data axis +# model_sharding=mesh_sharding(PartitionSpec('model')) +replicate_sharding=mesh_sharding(PartitionSpec()) + + +# %% +# define function to get encodings + +def get_encodings(batch, params): + input_ids=batch['input_ids'] + attention_mask=batch['attention_mask'] + input_ids = jnp.reshape(input_ids, (input_ids.shape[-2], input_ids.shape[-1])) + attention_mask = jnp.reshape(attention_mask, (input_ids.shape[-2], input_ids.shape[-1])) + encoder_outputs = model.encode( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=False, + output_hidden_states=False, + train=False, + params=params, + dropout_rng=None + ) + # encoder_outputs gives 'last_hidden_state' of shape: (batch, seq_len, embed) + # the embedding is not the full embedding size, but the self-attention embed + # size + + # parallelize by multiply encoder outputs with attention mask + # shape (batch, embed) -> (batch, embed, 1) + # this helps it to have the same shape as encoder_outputs + expanded_attention_mask = jnp.expand_dims(attention_mask, 2) # (batch, 128, 1) + # here is an element-wise multiply + embeddings = encoder_outputs['last_hidden_state'] * expanded_attention_mask # (batch, 128, 768) + # summing embeddings in axis 1 will sum column-wise into a (batch, 768) + # summing attention_mask in axis 1 will sum column-wise to get the total + # unmasked token count for data entry + mean_embeddings = (embeddings).sum(axis=1) / expanded_attention_mask.sum(axis=1) + # the shape of mean_embeddings is (batch, embed), we are ready to return + return mean_embeddings + +get_encodings_jit = jax.jit( + functools.partial(get_encodings, params=params), + # rng, batch + in_shardings=(data_sharding), + out_shardings=replicate_sharding, +) +# # %% +# # test the get_encodings function +# +# rng, input_rng = jax.random.split(rng) +# train_loader = dataprep.data_loader(input_rng, batch_size=train_batch_size, shuffle=False, drop_last=False) +# batch = next(train_loader) +# encodings = get_encodings(batch, params) +# # function works! + + +# %% +# perform actual prediction +encoding_list = [] +# note: train_batch_size is batch_size * 4 +# we have 4 devices +pred_steps = math.ceil(len(train_dataset) / train_batch_size) +print("***** Running prediction *****") +print(f" Num examples = {len(train_dataset)}") +print(f" Num steps = {pred_steps}") +print(f" Instantaneous batch size per device = {batch_size}") +print(f" Total test batch size (w. 
parallel & distributed) = {train_batch_size}") + +rng, input_rng = jax.random.split(rng) +train_loader = dataprep.data_loader(input_rng, batch_size=train_batch_size, shuffle=False, drop_last=False) + +for _ in tqdm(range(pred_steps), desc="Predicting..."): + batch = next(train_loader) + batch = jax.device_put(batch, data_sharding) + encodings = get_encodings_jit(batch) + encoding_list.extend(jax.device_get(encodings)) + + +# %% +encoding_list = jnp.vstack(encoding_list) +# slice up to the previously defined list to unpad +encoding_list = encoding_list[:len(train_dataset)] +print(encoding_list.shape) + + + +# %% +# getting top-k +def top_k_cosine_similarity(M, a, k): + """ + Find the top-k rows in matrix M that are most cosine similar to array a. + + Args: + M (jnp.ndarray): Matrix of shape (n, d), where each row is a d-dimensional vector. + a (jnp.ndarray): Array of shape (d,), the vector to compare to each row of M. + k (int): Number of top cosine similarities to retrieve. + + Returns: + values (jnp.ndarray): Top-k cosine similarity values. + indices (jnp.ndarray): Indices of the top-k most similar rows in M. + """ + # Step 1: Normalize both M and a + M_norm = M / jnp.linalg.norm(M, axis=1, keepdims=True) # Shape: (n, d) + a_norm = a / jnp.linalg.norm(a) # Shape: (d,) + + # Step 2: Compute cosine similarity via dot product + cosine_similarities = jnp.dot(M_norm, a_norm) # Shape: (n,) + + # Step 3: Get the top-k values and their indices using jax.lax.top_k + values, indices = jax.lax.top_k(cosine_similarities, k) + + return values, indices + +# Example usage: +M = jax.random.normal(jax.random.PRNGKey(0), (100, 128)) # Random matrix with 100 rows of 128 dimensions +a = jax.random.normal(jax.random.PRNGKey(1), (128,)) # Random query vector of 128 dimensions + +# Find top 5 most similar rows +top_k_values, top_k_indices = top_k_cosine_similarity(M, a, k=5) + +print("Top-k cosine similarity values:", top_k_values) +print("Indices of top-k similar rows:", top_k_indices) + +# %% diff --git a/nnx/.gitignore b/nnx/.gitignore new file mode 100644 index 0000000..278f62f --- /dev/null +++ b/nnx/.gitignore @@ -0,0 +1 @@ +MNIST diff --git a/nnx/mnist.py b/nnx/mnist.py new file mode 100644 index 0000000..9b3f222 --- /dev/null +++ b/nnx/mnist.py @@ -0,0 +1,168 @@ +# %% +import tensorflow_datasets as tfds # TFDS for MNIST +import tensorflow as tf # TensorFlow operations + +tf.random.set_seed(0) # set random seed for reproducibility + +num_epochs = 10 +batch_size = 32 + +train_ds: tf.data.Dataset = tfds.load('mnist', split='train') +test_ds: tf.data.Dataset = tfds.load('mnist', split='test') + +train_ds = train_ds.map( + lambda sample: { + 'image': tf.cast(sample['image'], tf.float32) / 255, + 'label': sample['label'], + } +) # normalize train set +test_ds = test_ds.map( + lambda sample: { + 'image': tf.cast(sample['image'], tf.float32) / 255, + 'label': sample['label'], + } +) # normalize test set + +# create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from +train_ds = train_ds.repeat(num_epochs).shuffle(1024) +# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency +train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1) +# create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from +test_ds = test_ds.shuffle(1024) +# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency +test_ds = test_ds.batch(batch_size, 
drop_remainder=True).prefetch(1) +# %% +from flax import nnx # NNX API +from functools import partial + +class CNN(nnx.Module): + """A simple CNN model.""" + + def __init__(self, *, rngs: nnx.Rngs): + self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs) + self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs) + self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2)) + self.linear1 = nnx.Linear(3136, 256, rngs=rngs) + self.linear2 = nnx.Linear(256, 10, rngs=rngs) + + def __call__(self, x): + x = self.avg_pool(nnx.relu(self.conv1(x))) + x = self.avg_pool(nnx.relu(self.conv2(x))) + x = x.reshape(x.shape[0], -1) # flatten + x = nnx.relu(self.linear1(x)) + x = self.linear2(x) + return x + + +model = CNN(rngs=nnx.Rngs(0)) + +# %% +nnx.display(model) +# %% +# test the model by feeding an example input +import jax.numpy as jnp # JAX NumPy + +y = model(jnp.ones((1, 28, 28, 1))) +nnx.display(y) +# %% +import optax + +learning_rate = 0.005 +momentum = 0.9 + +optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum)) +metrics = nnx.MultiMetric( + accuracy=nnx.metrics.Accuracy(), + loss=nnx.metrics.Average('loss'), +) + +nnx.display(optimizer) + +# %% + +def loss_fn(model: CNN, batch): + logits = model(batch['image']) + loss = optax.softmax_cross_entropy_with_integer_labels( + logits=logits, labels=batch['label'] + ).mean() + return loss, logits + +# %% +@nnx.jit +def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch): + """Train for a single step.""" + grad_fn = nnx.value_and_grad(loss_fn, has_aux=True) + (loss, logits), grads = grad_fn(model, batch) + metrics.update(loss=loss, logits=logits, labels=batch['label']) + optimizer.update(grads) + +# %% +# evaluation step +@nnx.jit +def eval_step(model: CNN, metrics: nnx.MultiMetric, batch): + loss, logits = loss_fn(model, batch) + metrics.update(loss=loss, logits=logits, labels=batch['label']) + + +# %% +# for dataset seed random generation +tf.random.set_seed(0) + +# %% +num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs + +metrics_history = { + 'train_loss': [], + 'train_accuracy': [], + 'test_loss': [], + 'test_accuracy': [], +} + +for step, batch in enumerate(train_ds.as_numpy_iterator()): + # Run the optimization for one step and make a stateful update to the following: + # - the train state's model parameters + # - the optimizer state + # - the training loss and accuracy batch metrics + train_step(model, optimizer, metrics, batch) + + if (step + 1) % num_steps_per_epoch == 0: # one training epoch has passed + # Log training metrics + for metric, value in metrics.compute().items(): # compute metrics + metrics_history[f'train_{metric}'].append(value) # record metrics + metrics.reset() # reset metrics for test set + + # Compute metrics on the test set after each training epoch + for test_batch in test_ds.as_numpy_iterator(): + eval_step(model, metrics, test_batch) + + # Log test metrics + for metric, value in metrics.compute().items(): + metrics_history[f'test_{metric}'].append(value) + metrics.reset() # reset metrics for next training epoch + + print( + f"train epoch: {(step+1) // num_steps_per_epoch}, " + f"loss: {metrics_history['train_loss'][-1]}, " + f"accuracy: {metrics_history['train_accuracy'][-1] * 100}" + ) + print( + f"test epoch: {(step+1) // num_steps_per_epoch}, " + f"loss: {metrics_history['test_loss'][-1]}, " + f"accuracy: {metrics_history['test_accuracy'][-1] * 100}" + ) +# %% +# visualize metrics +import matplotlib.pyplot as plt # Visualization 
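+
+# Minimal sketch (assumes only the metrics_history dict filled in by the loop above):
+# print the last logged epoch's numbers before plotting the full curves.
+final_epoch = len(metrics_history['train_loss'])
+print(f"after {final_epoch} epochs:")
+print(f"  train loss {float(metrics_history['train_loss'][-1]):.4f}, "
+      f"accuracy {float(metrics_history['train_accuracy'][-1]) * 100:.2f}%")
+print(f"  test loss {float(metrics_history['test_loss'][-1]):.4f}, "
+      f"accuracy {float(metrics_history['test_accuracy'][-1]) * 100:.2f}%")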
+ +# Plot loss and accuracy in subplots +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5)) +ax1.set_title('Loss') +ax2.set_title('Accuracy') +for dataset in ('train', 'test'): + ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss') + ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy') +ax1.legend() +ax2.legend() +plt.show() + +# %%
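+# A minimal inference sketch: run the trained model on one test batch and compare
+# its predictions with the true labels (assumes `model` and `test_ds` from above).
+@nnx.jit
+def pred_step(model: CNN, batch):
+    logits = model(batch['image'])
+    return logits.argmax(axis=1)
+
+test_batch = next(test_ds.as_numpy_iterator())
+preds = pred_step(model, test_batch)
+print("predicted labels:", preds[:10])
+print("true labels:     ", test_batch['label'][:10])
+
+# %%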