Feat: implemented attention layer in equinox

This commit is contained in:
Richard Wong 2024-10-06 23:52:42 +09:00
parent a817fe16cc
commit 0762c02b31
10 changed files with 2119 additions and 0 deletions

172
dataload.py Normal file
View File

@ -0,0 +1,172 @@
# %%
# Prepare dataloader for jax training
from datasets import Dataset, DatasetDict, Value, Sequence, load_from_disk
from transformers import FlaxT5ForConditionalGeneration
from datasets import ClassLabel, Value, Sequence
from ml_collections import ConfigDict
import numpy as np
import jax.numpy as jnp
import jax
import math
from typing import Optional, List, Tuple, Callable, cast
# file_path = 'combined_data'
# split_datasets = load_from_disk(file_path)
# training_size = len(split_datasets['train'])
from transformers import T5TokenizerFast
# class takes in a dataset
# class takes in a dataset
class DataPrepare():
    """Tokenize a raw dataset ('input'/'output' text fields) and serve
    fixed-size int32 batches for JAX training.

    `config` must provide `max_length`, `pad_token_id` and
    `decoder_start_token_id`.
    """

    def __init__(self, raw_dataset, config):
        self.raw_dataset: Dataset = raw_dataset
        self.size: int = len(raw_dataset)
        self.config: ConfigDict = config
        self.tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=False)
        # Define additional special tokens
        # NOTE(review): SIG/UNIT/DATA_TYPE lack the angle brackets used in the
        # commented-out variant below — confirm the asymmetry is intentional.
        # additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "<SIG>", "<UNIT>", "<DATA_TYPE>"]
        additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
        # Add the additional special tokens to the tokenizer
        self.tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
        model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
        # In Flax, seq2seq models don't accept `labels`; decoder_input_ids are
        # built by dynamically importing `shift_tokens_right` from the model's
        # module. (Fixed: fromlist previously named a non-existent
        # "shift_tokens_tight" — harmless to __import__, but misleading.)
        model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
        self.shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")  # noqa: B009
        self.train_dataset = self.preprocess_function(
            self.raw_dataset
        )

    # given a dataset entry, run it through the tokenizer
    # padding=True pads to the longest sequence in each tokenizer call (inputs
    # and targets may therefore have different widths); switch to
    # padding="max_length" if jitted consumers need fully static shapes
    def preprocess_function(self, example: Dataset):
        """Tokenize inputs/targets and derive decoder ids and masks."""
        inputs = example['input']
        targets = example['output']
        # produce input_ids and attention_mask for the encoder side
        model_inputs = self.tokenizer(
            inputs,
            max_length=self.config.max_length,
            padding=True,
            truncation=True,
            return_tensors="np"
        )
        # tokenized separately because we also need the target attention mask
        labels = self.tokenizer(
            text_target=targets,
            max_length=self.config.max_length,
            padding=True,
            truncation=True,
            return_tensors="np"
        )
        model_inputs['input_ids'] = np.asarray(model_inputs['input_ids'])
        model_inputs['attention_mask'] = np.asarray(model_inputs['attention_mask'])
        # for loss computation
        model_inputs["labels"] = labels["input_ids"]
        # decoder input ids are the "model output" (labels) shifted right
        decoder_input_ids = self.shift_tokens_right_fn(
            labels["input_ids"], self.config.pad_token_id, self.config.decoder_start_token_id
        )
        # required by the model
        model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
        # decoder_attention_mask lets the loss ignore pad tokens
        model_inputs["decoder_attention_mask"] = np.asarray(labels["attention_mask"])
        return model_inputs

    def _pad_to_batch_size(self, batch, target_size):
        """Zero-pad every array in `batch` along axis 0 up to `target_size` rows."""
        current_size = batch['input_ids'].shape[0]
        if current_size >= target_size:
            return batch
        padding_size = target_size - current_size
        # Fixed: pad each array with zeros matching its *own* width — labels /
        # decoder ids can differ in sequence length from input_ids because
        # padding=True pads per tokenizer call, so a single padding block
        # sized to input_ids could mismatch shapes.
        return jax.tree.map(
            lambda array: jnp.concatenate(
                [array, jnp.zeros((padding_size, array.shape[1]), dtype=array.dtype)],
                axis=0,
            ),
            batch,
        )

    def data_loader(self, rng: jax.random.PRNGKey, batch_size: int, shuffle: bool = False, drop_last=True):
        """
        Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
        and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
        """
        dataset: Dataset = Dataset.from_dict(self.train_dataset)
        if shuffle:
            batch_idx = jax.random.permutation(rng, len(dataset))
            batch_idx = np.asarray(batch_idx)
        else:
            batch_idx = np.arange(len(dataset))
        if drop_last:
            steps_per_epoch = len(dataset) // batch_size
            batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
            minibatch_list = batch_idx.reshape((steps_per_epoch, batch_size))
        else:
            steps_per_epoch = math.ceil(len(dataset) / batch_size)
            minibatch_list = np.array_split(batch_idx, steps_per_epoch)
        for minibatch in minibatch_list:
            batch = dataset[minibatch]
            batch = {k: jnp.array(v, dtype=jnp.int32) for k, v in batch.items()}
            # incomplete trailing batches (drop_last=False) are zero-padded so
            # downstream jitted code always sees `batch_size` rows
            batch = self._pad_to_batch_size(batch, batch_size)
            yield batch
# # testing out the class
# # %%
# # init object
# # e.g. Config
#
# file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_desc'
# data_config = ConfigDict(
# dict(
# max_length=86,
# pad_token_id=0,
# decoder_start_token_id=0
# )
# )
#
# from datasets import load_from_disk
# split_datasets = load_from_disk(file_path)
# dataprep = DataPrepare(split_datasets['train'], data_config)
#
# # %%
# seed = 117
# rng = jax.random.PRNGKey(seed)
# train_loader = dataprep.data_loader(rng, batch_size=32)
#
#
#
# # %%
# batch = next(train_loader)
#
# print(batch['input_ids'].shape)
# print(batch['decoder_input_ids'].shape)
#
# # %%

1
equinox/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
__pycache__/

View File

@ -0,0 +1,54 @@
# an example of stateful operations
# %%
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import optax # https://github.com/deepmind/optax
from equinox.nn import State, StateIndex, StatefulLayer
from jaxtyping import Array
# %%
class Counter(eqx.Module):
    """Stateful module that adds its running call-count to the input.

    The count itself lives in an external `eqx.nn.State`; `index` wraps the
    unique lookup key together with the initial value for that state entry.
    """
    index: eqx.nn.StateIndex

    def __init__(self):
        # Count starts at zero; StateIndex records the key and init value.
        self.index = eqx.nn.StateIndex(jnp.array(0))

    def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]:
        """Return (x + count, state with count incremented)."""
        count = state.get(self.index)
        shifted = x + count
        # State is immutable: `set` returns an updated copy rather than mutating.
        bumped_state = state.set(self.index, count + 1)
        return shifted, bumped_state
# make_with_state is the recommended way to start a stateful model
# make_with_state is the recommended way to start a stateful model:
# it returns the model plus its freshly-initialised State.
counter, state = eqx.nn.make_with_state(Counter)()
x = jnp.array(2.3)
# Each __call__ both transforms x and returns a state with the count bumped;
# we must thread the returned state into the next call ourselves.
num_calls = state.get(counter.index)
print(f"Called {num_calls} times.")  # 0
_, state = counter(x, state)
num_calls = state.get(counter.index)
print(f"Called {num_calls} times.")  # 1
_, state = counter(x, state)
num_calls = state.get(counter.index)
print(f"Called {num_calls} times.")  # 2
# %%

View File

@ -0,0 +1,106 @@
# %%
# introduction to how flax does stateful operations
import flax.linen as nn
import jax.numpy as jnp
import jax
import flax
from jaxtyping import Array
# %%
class BiasAdderWithRunningMean(nn.Module):
    # Demonstrates flax mutable state: keeps an exponential moving average of
    # the per-batch mean in a custom 'hehe' variable collection, plus a
    # learned zero-initialised bias parameter.
    momentum: float = 0.9
    @nn.compact
    def __call__(self, x):
        # True only once init has created the 'mean' variable (False on init)
        is_initialized = self.has_variable('hehe', 'mean')
        print(is_initialized)
        # declares (or fetches) mutable state 'mean' in collection 'hehe'
        mean = self.variable('hehe', 'mean', jnp.zeros, x.shape[1:])
        # learned parameter, initialised to zeros (rng arg is ignored)
        bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
        if is_initialized:
            print(mean.value) # notice that value retains after first call
            # exponential moving average over the batch mean of x
            mean.value = self.momentum * mean.value + (1.0 - self.momentum) * jnp.mean(
                x, axis=0, keepdims=True
            )
            print(mean.value)
        return mean.value + bias
# %%
input_key = jax.random.PRNGKey(0)
model = BiasAdderWithRunningMean()
inputs = jax.random.normal(input_key, (10, 5)) # Generate random normal values
variables = model.init(input_key, inputs)
# Split state and params (which are updated by optimizer).
state, params = flax.core.pop(variables, 'params')
print(f"first init: {state}")
# %%
for i in range(5):
new_inputs = jax.random.normal(jax.random.PRNGKey(i + 1), (10,5)) # New random inputs
# notice how we are threading the state
# perform argument unpacking on state dictionary
output, state = model.apply({'params': params, **state},
new_inputs, mutable=list(state.keys()))
# mean_state = variables['batch_stats']['mean'] # Access the updated mean state
print(f"updated state {state}")
print(f"Output after input {i + 1}: {output}")
# print(f"Updated running mean state: {mean_state}")
# %%
###########################################################
# example 2
from flax.linen.initializers import lecun_normal, variance_scaling, zeros, normal
import jax.random as random
class Foo(nn.Module):
    # Minimal example of RNG-seeded mutable state in a custom collection.
    features: int
    @nn.compact
    def __call__(self):
        # pull a PRNG key from the 'spectral_norm_stats' stream
        key = self.make_rng('spectral_norm_stats')
        print(key)
        # u0 is initialised once (normal() applied to key/shape); subsequent
        # applies reuse the stored value rather than re-initialising
        u0_variable = self.variable('spectral_norm_stats', 'u0', normal(), key, (1, self.features))
        return u0_variable.value
foovars = Foo(3).init({'params': random.PRNGKey(0), 'spectral_norm_stats': random.PRNGKey(1)})
Foo(3).apply(foovars, rngs={'spectral_norm_stats': random.PRNGKey(1)})
# --> DeviceArray([[0.00711277, 0.0107195 , 0.019903 ]], dtype=float32)
# %%
model = Foo(3)
# %%
# state is kept in self.variable, tied to the layer
output = model.apply(foovars, rngs={'spectral_norm_stats': random.PRNGKey(1)})
# %%
output, state = model.apply(
foovars,
mutable=list(foovars.keys()),
rngs={'spectral_norm_stats': random.PRNGKey(1)}
)
print(output, state)
# %%
output, state = model.apply(
state,
mutable=list(foovars.keys()),
rngs={'spectral_norm_stats': random.PRNGKey(1)}
)
# no change because input state is the same
print(output, state)
# %%
state_array = state['spectral_norm_stats']['u0']
modified_array = jax.lax.dynamic_update_slice(state_array, jnp.array([[0.9]]), (0,0))
state['spectral_norm_stats']['u0'] = modified_array
# %%
# %%
output, state = model.apply(
state,
mutable=list(foovars.keys()),
rngs={'spectral_norm_stats': random.PRNGKey(1)}
)
# state takes from given state
# note the modified 0.9 value
# note how the state is not re-initialized
print(output, state)
# %%

252
equinox/mnist.py Normal file
View File

@ -0,0 +1,252 @@
# %%
import equinox as eqx
import jax
import jax.numpy as jnp
import optax # https://github.com/deepmind/optax
import torch # https://pytorch.org
import torchvision # https://pytorch.org
from jaxtyping import Array, Float, Int, PyTree # https://github.com/google/jaxtyping
# %%
# Hyperparameters
BATCH_SIZE = 64
LEARNING_RATE = 3e-4
STEPS = 300
PRINT_EVERY = 30
SEED = 5678
key = jax.random.PRNGKey(SEED)
# %%
normalise_data = torchvision.transforms.Compose(
[
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5,), (0.5,)),
]
)
train_dataset = torchvision.datasets.MNIST(
"MNIST",
train=True,
download=True,
transform=normalise_data,
)
test_dataset = torchvision.datasets.MNIST(
"MNIST",
train=False,
download=True,
transform=normalise_data,
)
trainloader = torch.utils.data.DataLoader(
train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
testloader = torch.utils.data.DataLoader(
test_dataset, batch_size=BATCH_SIZE, shuffle=True
)
# %%
# Checking our data a bit (by now, everyone knows what the MNIST dataset looks like)
dummy_x, dummy_y = next(iter(trainloader))
dummy_x = dummy_x.numpy()
dummy_y = dummy_y.numpy()
print(dummy_x.shape) # 64x1x28x28
print(dummy_y.shape) # 64
print(dummy_y)
# %%
class CNN(eqx.Module):
    """Small MNIST classifier: Conv2d + max-pool front-end, then a 3-layer MLP
    ending in log-softmax so the loss can gather log-probabilities directly."""
    layers: list

    def __init__(self, key):
        k_conv, k_fc1, k_fc2, k_out = jax.random.split(key, 4)
        # Plain JAX functions (relu, ravel, ...) are valid layers alongside
        # equinox modules.
        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=k_conv),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=k_fc1),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=k_fc2),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=k_out),
            jax.nn.log_softmax,
        ]

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        """Thread a single image through every layer in order."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out
key, subkey = jax.random.split(key, 2)
model = CNN(subkey)
# %%
print(model)
# %%
# print the first layer: Conv2d
print(model.layers[0])
# %%
# illustrated inference
def loss(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """Cross-entropy of the model's predictions on one batch.

    The model operates on a single (1, 28, 28) image, while the input here is
    (BATCH_SIZE, 1, 28, 28) — so we vmap the model over the leading batch axis
    instead of hand-writing batched code.
    """
    batched_model = jax.vmap(model)
    log_probs = batched_model(x)
    return cross_entropy(y, log_probs)
def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    """Negative mean log-likelihood of the true classes.

    `y` holds integer targets 0-9 and `pred_y` holds log-softmax outputs; for
    each row we pick the log-probability at index `y` and average with a sign
    flip.
    """
    # y[:, None] has shape (batch, 1); take_along_axis then selects exactly
    # one log-probability per row along axis 1.
    picked = jnp.take_along_axis(pred_y, y[:, None], axis=1)
    return -picked.mean()
# Example loss
loss_value = loss(model, dummy_x, dummy_y)
print(loss_value)
print(loss_value.shape) # scalar loss
# Example inference
output = jax.vmap(model)(dummy_x)
print(output.shape) # batch of predictions
# %%
# This is an error!
# the reason is that model has to be parameters, but model has non-parameters
jax.value_and_grad(loss)(model, dummy_x, dummy_y)
# %%
# we separate out things that are params from other things
# since params are things that are arrays
# partition is doing filter(...) and filter(..., inverse=True)
params, static = eqx.partition(model, eqx.is_array)
# %%
# lets compare the same object in both terms
print(static.layers[0])
print(params.layers[0])
# %%
# in the loss, we recombine both to form back our model
# in the loss, we recombine both to form back our model
def loss2(params, static, x, y):
    """Loss over a (params, static) split so grad applies to params only."""
    # Reassemble the full model, then defer to the original loss.
    return loss(eqx.combine(params, static), x, y)
# Now this will work!
# since the grad only looks at the first argument, this works out
loss_value, grads = jax.value_and_grad(loss2)(params, static, dummy_x, dummy_y)
print(loss_value)
# %%
# This will work too!
# this works the same as the previous
value, grads = eqx.filter_value_and_grad(loss)(model, dummy_x, dummy_y)
print(value)
# %%
# evaluation
loss = eqx.filter_jit(loss) # JIT our loss function from earlier!
@eqx.filter_jit
def compute_accuracy(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """Average accuracy of `model` on one batch (argmax over log-probs)."""
    log_probs = jax.vmap(model)(x)
    predictions = jnp.argmax(log_probs, axis=1)
    return jnp.mean(predictions == y)
# %%
def evaluate(model: CNN, testloader: torch.utils.data.DataLoader):
    """Mean loss and mean accuracy of `model` over the whole test loader."""
    total_loss = 0
    total_acc = 0
    for x, y in testloader:
        # All JAX work happens inside the JIT-wrapped `loss` and
        # `compute_accuracy`; the torch tensors just need NumPy conversion.
        x, y = x.numpy(), y.numpy()
        total_loss += loss(model, x, y)
        total_acc += compute_accuracy(model, x, y)
    n_batches = len(testloader)
    return total_loss / n_batches, total_acc / n_batches
# %%
evaluate(model, testloader)
# %%
# training
optim = optax.adamw(LEARNING_RATE)
def train(
    model: CNN,
    trainloader: torch.utils.data.DataLoader,
    testloader: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
) -> CNN:
    """Run `steps` optimisation steps over `trainloader`, printing test
    metrics every `print_every` steps, and return the trained model."""
    # Only the array leaves of the model are trainable parameters.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # One fused JIT region: gradients, optimiser update, parameter application.
    @eqx.filter_jit
    def make_step(
        model: CNN,
        opt_state: PyTree,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        updates, opt_state = optim.update(grads, opt_state, model)
        updated_model = eqx.apply_updates(model, updates)
        return updated_model, opt_state, loss_value

    # Recycle the dataloader indefinitely so `steps` may exceed one epoch.
    def infinite_trainloader():
        while True:
            yield from trainloader

    for step, (x, y) in zip(range(steps), infinite_trainloader()):
        # PyTorch dataloaders yield torch tensors; convert to NumPy for JAX.
        model, opt_state, train_loss = make_step(model, opt_state, x.numpy(), y.numpy())
        if step % print_every == 0 or step == steps - 1:
            test_loss, test_accuracy = evaluate(model, testloader)
            print(
                f"{step=}, train_loss={train_loss.item()}, "
                f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
            )
    return model
# %%
model = train(model, trainloader, testloader, optim, STEPS, PRINT_EVERY)
# %%

View File

@ -0,0 +1,591 @@
# %%
# package imports from equinox BERT example
import functools
from typing import Dict, List, Mapping, Optional, Callable, Optional, Tuple
# import einops # https://github.com/arogozhnikov/einops
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax # https://github.com/deepmind/optax
from datasets import load_dataset # https://github.com/huggingface/datasets
from jaxtyping import Array, Float, Int # https://github.com/google/jaxtyping
from tqdm import notebook as tqdm # https://github.com/tqdm/tqdm
from transformers import AutoTokenizer # https://github.com/huggingface/transformers
from ml_collections import ConfigDict, FrozenConfigDict
# helper functions for attention computation
# they are implemented with jax w/o flax
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
import flax.linen as nn
# %%
class T5LayerNorm(eqx.Module):
    """T5-style RMS layer norm: no bias and no subtraction of the mean."""
    eps: float = 1e-6
    weight: jax.Array
    # staticmethod keeps the initializer from being bound as an instance method
    weight_init: Callable[..., np.ndarray] = staticmethod(jax.nn.initializers.ones)

    def __init__(
        self: eqx.Module,
        hidden_size: int,
        key: jax.random.PRNGKey,
    ):
        # Scale parameter is all-ones, forced to float32. The ones-initializer
        # ignores `key`, so the key is effectively optional.
        self.weight = self.weight_init(key, (hidden_size,), jnp.float32)

    def __call__(self, hidden_states):
        """Scale `hidden_states` by 1/RMS (computed in float32) and `weight`."""
        # Mean of squares along the feature axis, always in float32 for
        # numerical stability.
        mean_square = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True)
        normalised = hidden_states / jnp.sqrt(mean_square + self.eps)
        return self.weight * normalised
# # %%
# # testing T5LayerNorm
# key = jax.random.PRNGKey(0)
# hidden_size = 128 # Example hidden size
# layer_norm = T5LayerNorm(key=key, hidden_size=hidden_size)
# # Create some example input data
# hidden_states = jnp.ones((1, 10, hidden_size)) # Batch size of 1, sequence length of 10
# # Forward pass
# output = layer_norm(hidden_states)
# print("Output shape:", output.shape)
# %%
class KaimingLinear(eqx.Module):
    """Bias-free linear layer whose weights are N(0, 1) * weights_init_std."""
    dtype: jnp.dtype = jnp.float32
    weights: jax.Array

    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        input_dim: int,
        output_dim: int,
        weights_init_std: float,
        dtype: jnp.dtype = jnp.float32
    ):
        self.dtype = dtype
        # Weight shape (input_dim, output_dim); the caller chooses the std,
        # typically standardised on the output dimension.
        # NOTE(review): `dtype` is stored but the weights are not cast to it —
        # confirm whether that is intended.
        standard_normal = jax.random.normal(key, (input_dim, output_dim))
        self.weights = standard_normal * weights_init_std

    def __call__(
        self,
        inputs: Float[Array, " input"],
    ):
        """Project `inputs` with a plain matmul (no bias term)."""
        return jnp.dot(inputs, self.weights)
# %%
# this function fortunately supports batched operations by default due to broadcasting
class T5DenseActDense(eqx.Module):
config: FrozenConfigDict
dtype: jnp.dtype = jnp.float32
wi: jax.Array
wo: jax.Array
dropout: eqx.nn.Dropout
act: jax.nn.relu
def __init__(
self: eqx.Module,
config: FrozenConfigDict,
dtype: jnp.dtype,
key: jax.random.PRNGKey
):
self.config = config
self.dtype = dtype
mlp_key, output_key = jax.random.split(key)
# the initialization strategy is to standardize on output dimension
# input
wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
# shapes are: (config.d_model, config.d_ff)
# self.wi = jax.random.normal(mlp_key, (self.config.d_model, self.config.d_ff)) * wi_init_std
self.wi = KaimingLinear(
key=mlp_key,
input_dim=self.config.d_model,
output_dim=self.config.d_ff,
weights_init_std=wi_init_std,
dtype=self.dtype
)
# output
wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)
# shapes are: (config.d_ff, config.d_model)
# self.wo = jax.random.normal(output_key, (self.config.d_ff, self.config.d_model)) * wo_init_std
self.wo = KaimingLinear(
key=mlp_key,
input_dim=self.config.d_ff,
output_dim=self.config.d_model,
weights_init_std=wo_init_std,
dtype=self.dtype
)
self.dropout = eqx.nn.Dropout(self.config.dropout_rate)
# just set to relu for now since the smaller T5's use relu
self.act = jax.nn.relu
def __call__(
self,
inputs: Float[Array, " d_model"],
enable_dropout: bool = False,
dropout_key: Optional[jax.random.PRNGKey] = None,
) -> Float[Array, " d_model"]:
hidden = self.wi(inputs)
# hidden = jnp.dot(inputs, self.wi)
hidden = self.act(hidden)
hidden = self.dropout(hidden, inference=not enable_dropout, key=dropout_key)
hidden = self.wo(hidden)
# hidden = jnp.dot(hidden, self.wo)
return hidden
# # %%
# # test for T5DenseActDense
# # create fake config
# config_dict = {
# 'd_model': 768,
# 'd_ff': 2048,
# 'dropout_rate': 0.1,
# 'initializer_factor': 1.0,
# }
# # Create a FrozenDict from the standard dictionary
# frozen_config = FrozenConfigDict(config_dict)
# # initialize model
# key = jax.random.PRNGKey(0)
# dense = T5DenseActDense(
# key=key,
# config=frozen_config,
# dtype=jnp.float32
# )
# input_key, key = jax.random.split(key)
# inputs = jax.random.normal(input_key, (10, frozen_config.d_model)) # Generate random normal values
# dropout_key, key = jax.random.split(key)
# output = dense(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
# output.shape
# %%
class T5LayerFF(eqx.Module):
    """Pre-norm residual feed-forward block: x + Dropout(FF(LayerNorm(x)))."""
    config: FrozenConfigDict
    dtype: jnp.dtype
    DenseReluDense: T5DenseActDense
    layer_norm: T5LayerNorm
    dropout: eqx.nn.Dropout

    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        config: FrozenConfigDict,
        dtype: jnp.dtype = jnp.float32
    ):
        self.config = config
        self.dtype = dtype
        dense_key, key = jax.random.split(key)
        # args: key, config, dtype
        self.DenseReluDense = T5DenseActDense(
            key=dense_key,
            config=config,
            dtype=dtype
        )
        layer_key, key = jax.random.split(key)
        # args: key, hidden_size
        self.layer_norm = T5LayerNorm(
            key=layer_key,
            hidden_size=self.config.d_model
        )
        # args: dropout_rate
        self.dropout = eqx.nn.Dropout(self.config.dropout_rate)

    def __call__(
        self: eqx.Module,
        inputs: Float[Array, " d_model"],
        enable_dropout: bool =False,
        dropout_key: Optional[jax.random.PRNGKey] = None,
    ):
        """Return `inputs + dropout(DenseReluDense(layer_norm(inputs)))`."""
        forwarded_states = self.layer_norm(inputs)
        # Fixed: previously `jax.random.split(dropout_key)` ran unconditionally
        # and crashed when dropout_key was None (the default inference path).
        # Dropout ignores its key when inference=True, so None is fine then.
        if dropout_key is not None:
            dense_key, residual_key = jax.random.split(dropout_key)
        else:
            dense_key = residual_key = None
        forwarded_states = self.DenseReluDense(
            inputs=forwarded_states,
            enable_dropout=enable_dropout,
            dropout_key=dense_key
        )
        dropout_states = self.dropout(
            x = forwarded_states,
            inference=not enable_dropout,
            key = residual_key,
        )
        hidden = inputs + dropout_states
        return hidden
# # %%
# # test for T5DenseActDense
# # create fake config
# config_dict = {
# 'd_model': 768,
# 'd_ff': 2048,
# 'dropout_rate': 0.1,
# 'initializer_factor': 1.0,
# }
# # Create a FrozenDict from the standard dictionary
# frozen_config = FrozenConfigDict(config_dict)
# # initialize model
# key = jax.random.PRNGKey(0)
# ff_layer = T5LayerFF(
# key=key,
# config=frozen_config,
# dtype=jnp.float32
# )
# input_key, key = jax.random.split(key)
# inputs = jax.random.normal(input_key, (10, frozen_config.d_model)) # Generate random normal values
# dropout_key, key = jax.random.split(key)
# output = ff_layer(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
# output.shape
# %%
class T5Attention(eqx.Module):
config: FrozenConfigDict
has_relative_attention_bias: bool = False
causal: bool = False # False for encoder, True for decoder
dtype: jnp.dtype
# parameters
q: jax.Array
k: jax.Array
v: jax.Array
o: jax.Array
# additional terms
relative_attention_num_buckets: int
relative_attention_max_distance: int
d_model: int
key_value_proj_dim: int
n_heads: int
dropout: float
inner_dim: int
initializer_factor: float
def __init__(
    self: eqx.Module,
    config: FrozenConfigDict,
    dtype: jnp.dtype,
    key: jax.random.PRNGKey,
):
    """Build the q/k/v/o projections, mirroring transformers' FlaxT5Attention."""
    self.config = config
    self.dtype = dtype
    self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
    self.relative_attention_max_distance = self.config.relative_attention_max_distance
    self.d_model = self.config.d_model
    # size of k,v projection for each head
    self.key_value_proj_dim = self.config.d_kv
    self.n_heads = self.config.num_heads
    self.dropout = self.config.dropout_rate
    self.inner_dim = self.n_heads * self.key_value_proj_dim
    self.initializer_factor = self.config.initializer_factor
    # init std factors follow the HF FlaxT5Attention scheme
    q_init_std = self.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
    kv_init_std = self.initializer_factor * (self.inner_dim**-0.5)
    o_init_std = self.initializer_factor * (self.inner_dim**-0.5)
    # NOTE(review): `has_relative_attention_bias` defaults to False and no
    # relative_attention_bias embedding table is created here, yet
    # compute_bias dereferences self.relative_attention_bias — it will raise
    # if that flag is ever True. TODO: add the embedding field.
    q_key, key = jax.random.split(key)
    # Fixed: q/k/v project from the model dimension (d_model), not from
    # inner_dim; the two only coincide when d_model == num_heads * d_kv.
    self.q = KaimingLinear(
        key=q_key,
        input_dim=self.d_model,
        output_dim=self.inner_dim,
        weights_init_std=q_init_std,
        dtype=self.dtype
    )
    k_key, key = jax.random.split(key)
    self.k = KaimingLinear(
        key=k_key,
        input_dim=self.d_model,
        output_dim=self.inner_dim,
        weights_init_std=kv_init_std,
        dtype=self.dtype
    )
    v_key, key = jax.random.split(key)
    self.v = KaimingLinear(
        key=v_key,
        input_dim=self.d_model,
        output_dim=self.inner_dim,
        weights_init_std=kv_init_std,
        dtype=self.dtype
    )
    o_key, key = jax.random.split(key)
    # output projection maps the concatenated heads back to d_model
    self.o = KaimingLinear(
        key=o_key,
        input_dim=self.inner_dim,
        output_dim=self.d_model,
        weights_init_std=o_init_std,
        dtype=self.dtype
    )
@staticmethod
def _relative_position_bucket(
    relative_position,
    bidirectional=True,
    num_buckets=32,
    max_distance=128
):
    """
    Adapted from Mesh Tensorflow:
    https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
    Translate relative position to a bucket number for relative attention. The relative position is defined as
    memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
    position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
    small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
    positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
    This should allow for more graceful generalization to longer sequences than the model has been trained on
    """
    relative_buckets = 0
    # bidirectional determines if positive relative positions are valid
    if bidirectional:
        num_buckets //= 2
        relative_buckets += (relative_position > 0) * num_buckets
        relative_position = jnp.abs(relative_position)
    else:
        # clamp to (-inf, 0] then negate -> range [0, inf)
        # Fixed: pass the clip bounds positionally — the a_min/a_max keywords
        # are deprecated and removed in recent JAX releases.
        relative_position = -jnp.clip(relative_position, None, 0)
    # half of buckets are for exact increments in positions
    max_exact = num_buckets // 2
    # boolean to assign relative buckets later
    is_small = relative_position < max_exact
    # other half are for logarithmically bigger bins in positions up to max_distance
    relative_position_if_large = max_exact + (
        jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
    )
    relative_position_if_large = jnp.clip(relative_position_if_large, None, num_buckets - 1)
    # jnp.where(condition, x, y): true->x, false->y; every element ends up
    # with its correct bucket whether it is in the small or large regime
    relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets.astype("i4")
# bias gives weight based on relative distance aside from attention score
def compute_bias(self, query_length, key_length):
    """
    Compute binned relative position bias
    """
    # NOTE(review): self.relative_attention_bias is never declared or assigned
    # in this class (__init__ builds only q/k/v/o), so this call will raise
    # AttributeError — an embedding table over the buckets still needs to be
    # added before has_relative_attention_bias=True can work. TODO confirm.
    # arange in the first dim
    context_position = jnp.arange(query_length, dtype="i4")[:, None]
    # arange in the second dim
    memory_position = jnp.arange(key_length, dtype="i4")[None, :]
    # The relative position is defined as memory_position - query_position,
    # i.e. the distance in tokens from the attending position to the
    # attended-to position.
    #
    # 2D array where each entry represents the distance from a query token
    # to a key token
    relative_position = memory_position - context_position
    # now we apply the earlier bucket creation function
    relative_position_bucket = self._relative_position_bucket(
        relative_position=relative_position,
        bidirectional=(not self.causal), # causal during decode -> not bi
        num_buckets=self.relative_attention_num_buckets,
        max_distance=self.relative_attention_max_distance,
    )
    # retrieve the bias values
    # shape (query_length, key_length, n_heads)
    values = self.relative_attention_bias(relative_position_bucket)
    # shape (1, n_heads, query_length, key_length)
    # ready for attention
    values = values.transpose((2, 0, 1))[None, :, :, :]
    return values
# from (batch_size, seq_length, d_model) to
# (batch_size, seq_length, n_heads, head_dim)
def _split_heads(self, hidden_states):
    """Split the trailing feature axis into (n_heads, head_dim)."""
    batch_and_seq = hidden_states.shape[:2]
    return hidden_states.reshape(batch_and_seq + (self.n_heads, self.key_value_proj_dim))
# from (batch_size, seq_length, n_heads, head_dim) to
# (batch_size, seq_length, d_model)
def _merge_heads(self, hidden_states):
    """Collapse the (n_heads, head_dim) axes back into a single inner_dim axis."""
    batch_and_seq = hidden_states.shape[:2]
    return hidden_states.reshape(batch_and_seq + (self.inner_dim,))
def _create_position_bias(
    self,
    key_states,
    query_states,
    attention_mask,
):
    """Initial position-bias tensor: learned buckets when enabled, else zeros."""
    # unlike the flax version, there is no cache handling here
    key_length = key_states.shape[1]
    query_length = query_states.shape[1]
    if self.has_relative_attention_bias:
        return self.compute_bias(query_length, key_length)
    if attention_mask is not None:
        # zeros shaped like the (already broadcast) attention mask
        return jnp.zeros_like(attention_mask)
    return jnp.zeros(
        (1, self.n_heads, query_length, key_length),
        dtype=self.dtype
    )
    def __call__(
        self,
        inputs,
        attention_mask=None,
        key_value_states=None,
        position_bias=None,
        output_attentions=False,
        enable_dropout=False,
        dropout_key: Optional[jax.random.PRNGKey] = None,
    ):
        """Multi-head attention (self-attention, or cross-attention when
        key_value_states is given).

        Args:
            inputs: hidden states, shape (batch_size, seq_len, d_model).
            attention_mask: optional padding mask over key positions.
            key_value_states: when given, k/v are projected from these states
                (cross-attention); otherwise from `inputs` (self-attention).
            position_bias: precomputed additive bias; computed here when None
                (normally only in the first layer of the stack).
            output_attentions: when True, also return the attention weights.
            enable_dropout: enable attention dropout (training mode).
            dropout_key: PRNG key for dropout; needed when enable_dropout.

        Returns:
            (attn_output, position_bias), plus attn_weights when
            output_attentions is True.
        """
        batch_size, seq_length = inputs.shape[:2]
        # q,k,v projections -> (batch_size, seq_length, inner_dim)
        query_states = self.q(inputs)
        key_states = (
            self.k(inputs) if key_value_states is None else self.k(key_value_states)
        )
        value_states = (
            self.v(inputs) if key_value_states is None else self.v(key_value_states)
        )
        # reshape to (batch_size, seq_length, n_heads, head_dim)
        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)
        # counteract the 1/sqrt(head_dim) scaling applied inside
        # dot_product_attention_weights — T5 uses unscaled dot products
        query_states *= jnp.sqrt(query_states.shape[-1])
        # create causal attention_mask
        if self.causal:
            # NOTE(review): make_causal_mask dereferences attention_mask, so a
            # causal call with attention_mask=None would fail — confirm callers
            # always pass a mask in the decoder.
            causal_attention_mask = make_causal_mask(attention_mask, dtype="bool")
            # broadcast causal attention mask & attention mask to fit for merge
            causal_attention_mask = jnp.broadcast_to(
                causal_attention_mask,
                (batch_size,) + causal_attention_mask.shape[1:]
            )
            attention_mask = jnp.broadcast_to(
                jnp.expand_dims(attention_mask, axis=(-3, -2)),
                causal_attention_mask.shape
            )
            attention_mask = combine_masks(attention_mask, causal_attention_mask)
        elif attention_mask is not None:
            # padding mask only: add broadcast axes for heads and query length
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
        # convert the boolean mask into an additive float mask:
        # 0 where attended, dtype-min where masked
        if attention_mask is not None:
            mask_value = jnp.finfo(self.dtype).min
            attention_mask = jax.lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, mask_value).astype(self.dtype),
            )
        if position_bias is None:
            # compute position bias (only for first layer)
            position_bias = self._create_position_bias(
                key_states, query_states, attention_mask
            )
            # fold the additive mask into the bias so masking takes effect
            # inside the softmax
            if attention_mask is not None:
                position_bias = position_bias + attention_mask
        # Softmax(QK^T)
        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=position_bias,
            dropout_rng=dropout_key,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=not enable_dropout,
            dtype=self.dtype,
        )
        # weighted sum over values -> (batch, seq, n_heads, head_dim)
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        # bring back to (batch_size, seq_length, d_model)
        attn_output = self._merge_heads(attn_output)
        # apply output matrix
        attn_output = self.o(attn_output)
        outputs = (attn_output, position_bias)
        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
# # %%
# # test for T5Attention
# # create fake config
# config_dict = {
# 'relative_attention_num_buckets': 32,
# 'relative_attention_max_distance': 128,
# 'd_model': 768, # 64 * 12
# 'd_kv': 64,
# 'num_heads': 12,
# 'dropout_rate': 0.1,
# 'initializer_factor': 1.0,
# }
# # Create a FrozenDict from the standard dictionary
# frozen_config = FrozenConfigDict(config_dict)
# # initialize model
# key = jax.random.PRNGKey(0)
# attn_layer = T5Attention(
# key=key,
# config=frozen_config,
# dtype=jnp.float32
# )
# input_key, key = jax.random.split(key)
# # inputs = jax.random.normal(input_key, (10, frozen_config.d_model)) # Generate random normal values
# batch_size = 1
# seq_length = 10
# inputs = jnp.ones((batch_size, seq_length, frozen_config.d_model))
# dropout_key, key = jax.random.split(key)
# output = attn_layer(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
# print(len(output))
# print(output[0].shape)
# %%

495
equinox/t5_train_model.py Normal file
View File

@ -0,0 +1,495 @@
# %%
# package imports from equinox BERT example
import functools
from typing import Dict, List, Mapping, Optional, Callable, Optional, Tuple
# import einops # https://github.com/arogozhnikov/einops
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax # https://github.com/deepmind/optax
from datasets import load_dataset # https://github.com/huggingface/datasets
from jaxtyping import Array, Float, Int # https://github.com/google/jaxtyping
from tqdm import notebook as tqdm # https://github.com/tqdm/tqdm
from transformers import AutoTokenizer # https://github.com/huggingface/transformers
from ml_collections import ConfigDict, FrozenConfigDict
import flax.linen as nn
# %%
class T5LayerNorm(eqx.Module):
    """T5-style RMS layer norm: scale-only, no bias, no mean subtraction."""
    eps: float = 1e-6
    weight: jax.Array
    # staticmethod keeps the initializer stored as a plain function rather
    # than being bound as an instance method
    weight_init: Callable[..., np.ndarray] = staticmethod(jax.nn.initializers.ones)

    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        hidden_size: int,
    ):
        """Create a scale-only norm over the trailing axis of size `hidden_size`.

        The ones-initializer ignores `key`, so the key argument is effectively
        optional; the scale is always created in float32.
        """
        self.weight = self.weight_init(key, (hidden_size,), jnp.float32)

    def __call__(self, hidden_states):
        """Normalize by root-mean-square and apply the learned scale."""
        # statistics are always taken in float32 for numerical stability
        mean_square = jnp.square(hidden_states.astype("f4")).mean(axis=-1, keepdims=True)
        normalized = hidden_states / jnp.sqrt(mean_square + self.eps)
        return self.weight * normalized
# # %%
# # testing T5LayerNorm
# key = jax.random.PRNGKey(0)
# hidden_size = 128 # Example hidden size
# layer_norm = T5LayerNorm(key=key, hidden_size=hidden_size)
# # Create some example input data
# hidden_states = jnp.ones((1, 10, hidden_size)) # Batch size of 1, sequence length of 10
# # Forward pass
# output = layer_norm(hidden_states)
# print("Output shape:", output.shape)
# %%
class KaimingLinear(eqx.Module):
    """Bias-free linear layer with T5-style fan-in-scaled normal init."""
    dtype: jnp.dtype = jnp.float32
    weights: jax.Array

    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        input_dim: int,
        output_dim: int,
        initializer_factor: float,
        dtype: jnp.dtype = jnp.float32
    ):
        """Initialize weights of shape (input_dim, output_dim).

        Args:
            key: PRNG key for the normal draw.
            input_dim: fan-in; the init std scales with input_dim**-0.5.
            output_dim: fan-out.
            initializer_factor: global multiplier on the init std (T5 config).
            dtype: dtype of the created weights.
        """
        self.dtype = dtype
        # std standardized on fan-in, times the global factor
        weights_init_std = initializer_factor * (input_dim**-0.5)
        # BUG FIX: dtype was previously stored but never used, so weights were
        # always float32 regardless of the requested dtype.
        self.weights = jax.random.normal(key, (input_dim, output_dim), dtype=dtype) * weights_init_std

    def __call__(
        self,
        inputs: Float[Array, " input"],
    ):
        """Return inputs @ weights (no bias term)."""
        return jnp.dot(inputs, self.weights)
# %%
# this function fortunately supports batched operations by default due to broadcasting
class T5DenseActDense(eqx.Module):
    """T5 feed-forward core: wi -> activation -> dropout -> wo.

    Batched inputs are supported automatically via broadcasting in the
    underlying matmuls.
    """
    config: FrozenConfigDict
    dtype: jnp.dtype = jnp.float32
    # both projections are bias-free linear layers (T5 convention)
    wi: KaimingLinear
    wo: KaimingLinear
    dropout: eqx.nn.Dropout
    act: Callable

    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        config: FrozenConfigDict,
        dtype: jnp.dtype = jnp.float32
    ):
        """Build the two projections, each with its own key split from `key`."""
        self.config = config
        self.dtype = dtype
        mlp_key, output_key = jax.random.split(key)
        # input projection: (d_model) -> (d_ff)
        self.wi = KaimingLinear(
            key=mlp_key,
            input_dim=self.config.d_model,
            output_dim=self.config.d_ff,
            initializer_factor=self.config.initializer_factor,
            dtype=self.dtype
        )
        # output projection: (d_ff) -> (d_model)
        # BUG FIX: previously initialized with mlp_key, leaving output_key
        # unused and correlating wi/wo initial weights.
        self.wo = KaimingLinear(
            key=output_key,
            input_dim=self.config.d_ff,
            output_dim=self.config.d_model,
            initializer_factor=self.config.initializer_factor,
            dtype=self.dtype
        )
        self.dropout = eqx.nn.Dropout(self.config.dropout_rate)
        # just set to relu for now since the smaller T5's use relu
        self.act = jax.nn.relu

    def __call__(
        self,
        inputs: Float[Array, " d_model"],
        enable_dropout: bool = False,
        dropout_key: Optional[jax.random.PRNGKey] = None,
    ) -> Float[Array, " d_model"]:
        """Apply wi -> act -> dropout -> wo; dropout is active only when
        enable_dropout is True (then dropout_key must be provided)."""
        hidden = self.wi(inputs)
        hidden = self.act(hidden)
        hidden = self.dropout(hidden, inference=not enable_dropout, key=dropout_key)
        hidden = self.wo(hidden)
        return hidden
# # %%
# # test for T5DenseActDense
# # create fake config
# config_dict = {
# 'd_model': 768,
# 'd_ff': 2048,
# 'dropout_rate': 0.1,
# 'initializer_factor': 1.0,
# }
# # Create a FrozenDict from the standard dictionary
# frozen_config = FrozenConfigDict(config_dict)
# # initialize model
# key = jax.random.PRNGKey(0)
# dense = T5DenseActDense(
# key=key,
# config=frozen_config,
# dtype=jnp.float32
# )
# input_key, key = jax.random.split(key)
# inputs = jax.random.normal(input_key, (10, frozen_config.d_model)) # Generate random normal values
# dropout_key, key = jax.random.split(key)
# output = dense(inputs=inputs, enable_dropout=False, key=dropout_key)
# output.shape
# %%
class T5LayerFF(eqx.Module):
    """Pre-norm T5 feed-forward block with residual connection."""
    config: FrozenConfigDict
    dtype: jnp.dtype
    DenseReluDense: T5DenseActDense
    layer_norm: T5LayerNorm
    dropout: eqx.nn.Dropout

    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        config: FrozenConfigDict,
        dtype: jnp.dtype = jnp.float32
    ):
        """Construct the sub-modules, splitting `key` once per sub-module."""
        self.config = config
        self.dtype = dtype
        dense_key, key = jax.random.split(key)
        # args: key, config, dtype
        self.DenseReluDense = T5DenseActDense(
            key=dense_key,
            config=config,
            dtype=dtype
        )
        layer_key, key = jax.random.split(key)
        # args: key, hidden_size
        self.layer_norm = T5LayerNorm(
            key=layer_key,
            hidden_size=self.config.d_model
        )
        # args: dropout_rate
        self.dropout = eqx.nn.Dropout(self.config.dropout_rate)

    def __call__(
        self: eqx.Module,
        inputs: Float[Array, " d_model"],
        enable_dropout: bool = False,
        dropout_key: Optional[jax.random.PRNGKey] = None,
    ):
        """layer_norm -> DenseReluDense -> dropout, then residual add.

        BUG FIX: jax.random.split(dropout_key) was previously called
        unconditionally and crashed when dropout_key was None (the default,
        i.e. the inference path); keys are now derived only when provided.
        """
        forwarded_states = self.layer_norm(inputs)
        if dropout_key is None:
            dense_key = final_dropout_key = None
        else:
            # preserve the original key-derivation order for reproducibility
            dense_key, rest_key = jax.random.split(dropout_key)
            final_dropout_key, _ = jax.random.split(rest_key)
        forwarded_states = self.DenseReluDense(
            inputs=forwarded_states,
            enable_dropout=enable_dropout,
            dropout_key=dense_key
        )
        dropout_states = self.dropout(
            x=forwarded_states,
            key=final_dropout_key,
            inference=not enable_dropout
        )
        return inputs + dropout_states
# # %%
# # test for T5DenseActDense
# # create fake config
# config_dict = {
# 'd_model': 768,
# 'd_ff': 2048,
# 'dropout_rate': 0.1,
# 'initializer_factor': 1.0,
# }
# # Create a FrozenDict from the standard dictionary
# frozen_config = FrozenConfigDict(config_dict)
# # initialize model
# key = jax.random.PRNGKey(0)
# ff_layer = T5LayerFF(
# key=key,
# config=frozen_config,
# dtype=jnp.float32
# )
# input_key, key = jax.random.split(key)
# inputs = jax.random.normal(input_key, (10, frozen_config.d_model)) # Generate random normal values
# dropout_key, key = jax.random.split(key)
# output = ff_layer(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
# output.shape
# %%
class T5Attention(eqx.Module):
config: FrozenConfigDict
has_relative_attention_bias: bool = False
causal: bool = False # False for encoder, True for decoder
dtype: jnp.dtype
# additional terms
relative_attention_num_buckets: int
relative_attention_max_distance: int
d_model: int
key_value_proj_dim: int
n_heads: int
dropout: float
inner_dim: int
initializer_factor: float
def __init__(
self: eqx.Module,
key: jax.random.PRNGKey,
config: FrozenConfigDict,
dtype: jnp.dtype = jnp.float32
):
self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
self.relative_attention_max_distance = self.config.relative_attention_max_distance
self.d_model = self.config.d_model
# size of k,v projection for each head
self.key_value_proj_dim = self.config.d_kv
self.n_heads = self.config.num_heads
self.dropout = self.config.dropout_rate
self.inner_dim = self.n_heads * self.key_value_proj_dim
self.initializer_factor = self.config.initializer_factor
q_key, key = jax.random.split(key)
self.q = KaimingLinear(
key=q_key,
input_dim=(self.inner_dim * self.key_value_proj_dim),
output_dim=self.inner_dim,
initializer_factor=self.initializer_factor,
dtype=self.dtype
)
k_key, key = jax.random.split(key)
self.k = KaimingLinear(
key=k_key,
input_dim=self.inner_dim,
output_dim=self.inner_dim,
initializer_factor=self.initializer_factor,
dtype=self.dtype
)
v_key, key = jax.random.split(key)
self.v = KaimingLinear(
key=v_key,
input_dim=self.inner_dim,
output_dim=self.inner_dim,
initializer_factor=self.initializer_factor,
dtype=self.dtype
)
o_key, key = jax.random.split(key)
self.o = KaimingLinear(
key=o_key,
input_dim=self.inner_dim,
output_dim=self.d_model,
initializer_factor=self.initializer_factor,
dtype=self.dtype
)
# 1 bias per head, so output is n_heads
# bias is learned during training
if self.has_relative_attention_bias:
input_dim = self.relative_attention_num_buckets
output_dim = self.n_heads
initializer_factor=self.initializer_factor
# we standardize based on the output dimension,
# which is n_head * kv_proj_dim - during multi head attention
weights_init_std = initializer_factor * (self.inner_dim**-0.5)
# shapes are: (input_dim, output_dim)
weights= jax.random.normal(key, (input_dim, output_dim), dtype=self.dtype) * weights_init_std
self.relative_attention_bias = eqx.nn.Embedding(
weights=weights
)
@staticmethod
def _relative_position_bucket(
relative_position,
bidirectional=True,
num_buckets=32,
max_distance=128
):
"""
Adapted from Mesh Tensorflow:
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
Translate relative position to a bucket number for relative attention. The relative position is defined as
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the model has been trained on
"""
relative_buckets = 0
# bidirection determines if positive relative positions are valid
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0) * num_buckets
relative_position = jnp.abs(relative_position)
else:
# relative position range of [0, inf]
relative_position = -jnp.clip(relative_position, a_max=0)
# half of buckets are for exact increments in positions
max_exact = num_buckets // 2
# boolean to assign relative buckets later
is_small = relative_position < max_exact
# other half are for logarithmically bigger bins in positions up to max_distance
relative_position_if_large = max_exact + (
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
)
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
# jnp.where(condition, x, y), true->x, false->y
# in-place cumulative summation
# yields a list where every element has the correct relative bucket position
# whether its small or large
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
return relative_buckets.astype("i4")
# bias gives weight based on relative distance aside from attention score
def compute_bias(self, query_length, key_length):
"""
Compute binned relative position bias
"""
# arange in the first dim
context_position = jnp.arange(query_length, dtype="i4")[:, None]
# arange in the second dim
memory_position = jnp.arange(key_length, dtype="i4")[None, :]
# The relative position is defined as memory_position - query_position,
# i.e. the distance in tokens from the attending position to the
# attended-to position.
#
# 2D array where each entry represents the distance from a query token
# to a key token
relative_position = memory_position - context_position
# now we apply the earlier bucket creation function
relative_position_bucket = self._relative_position_bucket(
relative_position=relative_position,
bidirectional=(not self.causal), # causal during decode -> not bi
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
# retrieve the bias values
# shape (query_length, key_length, n_heads)
values = self.relative_attention_bias(relative_position_bucket)
# shape (1, n_heads, query_length, key_length)
# ready for attention
values = values.transpose((2, 0, 1))[None, :, :, :]
return values
def _split_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,))
def _create_position_bias(
self,
key_states,
query_states,
attention_mask,
init_cache,
seq_length,
causal_attention_mask_shift
):
# unlike the flax version, we don't even check for cache
key_length = key_states.shape[1]
query_length = query_states.shape[1]
if self.has_relative_attention_bias:
position_bias = self.compute_bias(query_length, key_length)
elif attention_mask is not None:
position_bias = jnp.zeros_like(attention_mask)
else:
position_bias = jnp.zeros(
(1, self.n_heads, query_length, key_length),
dtype=self.dtype
)
return position_bias
def __call__(
self,
inputs,
attention_mask=None,
key_value_states=None,
position_bias=None,
output_attentions=False,
enable_dropout=False
):
batch_size, seq_length = inputs.shape[:2]
# q,k,v projections
query_states = self.q(inputs)
key_states = (
self.k(inputs) if key_value_states is None else self.k(key_value_states)
)
value_states = (
self.v(inputs) if key_value_states is None else self.v(key_value_states)
)
# reshape to (batch_size, seq_length, n_heads, head_dim)
query_states = self._split_heads(query_states)
key_states = self._split_heads(key_states)
value_states = self._split_heads(value_states)
# counteract scaling in dot_product_attention_weights function
# not sure if this is a good idea in equinox

279
make_context_data.py Normal file
View File

@ -0,0 +1,279 @@
# %%
import os
# Set this to True to run the model on CPU only.
USE_CPU_ONLY = False

# start from any XLA_FLAGS already present so we extend rather than clobber
flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    flags += " --xla_force_host_platform_device_count=4"  # simulate 4 host devices
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    os.environ["JAX_PLATFORMS"] = "cpu"
else:
    # GPU flags
    # BUG FIX: this branch previously overwrote the flags fetched above;
    # append instead, mirroring the CPU branch.
    flags += (
        ' --xla_gpu_enable_triton_softmax_fusion=true'
        ' --xla_gpu_triton_gemm_any=True'
        # ' --xla_gpu_enable_async_collectives=true'
        ' --xla_gpu_enable_latency_hiding_scheduler=true'
        ' --xla_gpu_enable_highest_priority_async_stream=true'
    )
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
os.environ["XLA_FLAGS"] = flags
os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",        # avoid tokenizer fork warnings
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",       # serialize kernel launches for overlap
    "NCCL_LL128_BUFFSIZE": "-2",
    "NCCL_LL_BUFFSIZE": "-2",
    "NCCL_PROTO": "SIMPLE,LL,LL128",
    "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.80", # cap preallocation at 80% of GPU memory
    # "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
})
import pandas as pd
import matplotlib.pyplot as plt
from datasets import Dataset
import jax
import jax.numpy as jnp
import optax
import numpy as np
import functools
from typing import Callable, Optional
import math
from jax.sharding import Mesh, NamedSharding
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec
# set cache
# persist compiled executables to disk so repeated runs skip XLA compilation
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)  # cache entries of any size
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)  # even instantly-compiled ones
# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
# trade matmul precision for throughput
jax.config.update("jax_default_matmul_precision", "bfloat16")
jax.config.update("jax_enable_x64", False)  # keep default 32-bit arithmetic
# from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
import datasets
from datasets import Dataset
import evaluate
from tqdm import tqdm
import nltk # Here to have a nice missing dependency error message early on
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from ml_collections import ConfigDict
import time
from parallel.dataload import DataPrepare
# %%
# import data
# load training
# absolute path to the raw training split for this machine
data_path = "/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv"
# Ensure to include 'ships_idx' in the fields list
fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']
# Load the dataset
# skipinitialspace strips whitespace after delimiters; usecols limits memory
train_df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)
# # load valid
# data_path = "/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/valid.csv"
# # Ensure to include 'ships_idx' in the fields list
# fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']
# # Load the dataset
# validation_df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)
def process_df(df):
    """Convert each dataframe row into a tagged input/output text pair."""
    def to_example(row):
        return {
            'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC>",
            'output': f"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>",
        }
    return [to_example(row) for _, row in df.iterrows()]
# takes 1 minute to run without batching
train_dataset = Dataset.from_list(process_df(train_df))
print("preparing data")
# tokenization/packing settings consumed by DataPrepare
data_config = ConfigDict(
    dict(
        max_length=128,           # max token length for inputs and targets
        pad_token_id=0,           # pad token id fed to the shift function
        decoder_start_token_id=0  # id used to start decoder sequences
    )
)
dataprep = DataPrepare(train_dataset, data_config)
# %%
# load model
model_name_or_path = "./model_checkpoints/simple" # Replace with your specific model name
from transformers import FlaxT5ForConditionalGeneration
# load the fine-tuned Flax T5 checkpoint; params are kept separately so they
# can be passed explicitly into jitted functions
model = FlaxT5ForConditionalGeneration.from_pretrained(
    model_name_or_path
)
params = model.params
# %%
# load data
SEED = 117
batch_size = 256 # per device batch_size
# the global batch is per-device batch times the number of devices
train_batch_size = batch_size * jax.device_count()
rng = jax.random.PRNGKey(SEED)
# %%
# setup sharding
print("creating mesh")
# 4-way data parallelism x 1-way model parallelism
device_mesh = mesh_utils.create_device_mesh((4,1))
print(device_mesh)
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
    """Bind a PartitionSpec to the global mesh as a device-memory sharding."""
    return NamedSharding(mesh, pspec, memory_kind="device")
data_sharding = mesh_sharding(PartitionSpec('data')) # shard the batch across the data axis
# model_sharding=mesh_sharding(PartitionSpec('model'))
replicate_sharding=mesh_sharding(PartitionSpec())  # fully replicated
# %%
# define function to get encodings
def get_encodings(batch, params):
    """Mean-pool the T5 encoder's last hidden state into one vector per example.

    Args:
        batch: dict with 'input_ids' and 'attention_mask'; any leading shard
            axis is flattened away below.
        params: Flax parameter pytree for the global `model`.

    Returns:
        (batch, embed) array of mask-weighted mean token embeddings.
    """
    input_ids=batch['input_ids']
    attention_mask=batch['attention_mask']
    # collapse any leading shard dimension down to (batch, seq_len)
    input_ids = jnp.reshape(input_ids, (input_ids.shape[-2], input_ids.shape[-1]))
    attention_mask = jnp.reshape(attention_mask, (input_ids.shape[-2], input_ids.shape[-1]))
    encoder_outputs = model.encode(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_attentions=False,
        output_hidden_states=False,
        train=False,     # inference mode: no dropout
        params=params,
        dropout_rng=None
    )
    # encoder_outputs['last_hidden_state'] has shape (batch, seq_len, embed);
    # the embed size is the encoder's hidden size, not the vocab size
    # (batch, seq_len) -> (batch, seq_len, 1) so the mask broadcasts over embed
    expanded_attention_mask = jnp.expand_dims(attention_mask, 2) # (batch, 128, 1)
    # element-wise multiply zeroes out embeddings at padded positions
    embeddings = encoder_outputs['last_hidden_state'] * expanded_attention_mask # (batch, 128, 768)
    # sum over the sequence axis, divide by the unmasked token count per entry
    mean_embeddings = (embeddings).sum(axis=1) / expanded_attention_mask.sum(axis=1)
    # the shape of mean_embeddings is (batch, embed), we are ready to return
    return mean_embeddings
# jit-compile with params pre-bound; only the batch remains as an argument
get_encodings_jit = jax.jit(
    functools.partial(get_encodings, params=params),
    # the batch arrives sharded along the data axis
    in_shardings=(data_sharding),
    out_shardings=replicate_sharding,  # replicate results across devices
)
# # %%
# # test the get_encodings function
#
# rng, input_rng = jax.random.split(rng)
# train_loader = dataprep.data_loader(input_rng, batch_size=train_batch_size, shuffle=False, drop_last=False)
# batch = next(train_loader)
# encodings = get_encodings(batch, params)
# # function works!
# %%
# perform actual prediction
encoding_list = []
# note: train_batch_size is batch_size * 4
# we have 4 devices
# round up so the final partial batch is still processed
pred_steps = math.ceil(len(train_dataset) / train_batch_size)
print("***** Running prediction *****")
print(f" Num examples = {len(train_dataset)}")
print(f" Num steps = {pred_steps}")
print(f" Instantaneous batch size per device = {batch_size}")
print(f" Total test batch size (w. parallel & distributed) = {train_batch_size}")
rng, input_rng = jax.random.split(rng)
# shuffle=False keeps row order aligned with train_dataset; drop_last=False
# keeps the (padded) tail batch
train_loader = dataprep.data_loader(input_rng, batch_size=train_batch_size, shuffle=False, drop_last=False)
for _ in tqdm(range(pred_steps), desc="Predicting..."):
    batch = next(train_loader)
    # place the host batch onto devices, sharded along the data axis
    batch = jax.device_put(batch, data_sharding)
    encodings = get_encodings_jit(batch)
    # pull results back to host memory before accumulating
    encoding_list.extend(jax.device_get(encodings))
# %%
encoding_list = jnp.vstack(encoding_list)
# slice to the original dataset length to drop tail-batch padding
encoding_list = encoding_list[:len(train_dataset)]
print(encoding_list.shape)
# %%
# getting top-k
def top_k_cosine_similarity(M, a, k):
    """Return the top-k cosine similarities between the rows of M and vector a.

    Args:
        M: (n, d) matrix; each row is a candidate d-dimensional vector.
        a: (d,) query vector.
        k: number of best matches to retrieve.

    Returns:
        (values, indices): the k largest cosine similarities and the row
        indices in M achieving them, both sorted in descending order.
    """
    # normalize the rows of M and the query to unit length
    unit_rows = M / jnp.linalg.norm(M, axis=1, keepdims=True)   # (n, d)
    unit_query = a / jnp.linalg.norm(a)                         # (d,)
    # cosine similarity of unit vectors reduces to a dot product
    similarities = jnp.dot(unit_rows, unit_query)               # (n,)
    # jax.lax.top_k returns values and indices sorted descending
    return jax.lax.top_k(similarities, k)
# Example usage (smoke test of top_k_cosine_similarity):
M = jax.random.normal(jax.random.PRNGKey(0), (100, 128)) # Random matrix with 100 rows of 128 dimensions
a = jax.random.normal(jax.random.PRNGKey(1), (128,)) # Random query vector of 128 dimensions
# Find top 5 most similar rows
top_k_values, top_k_indices = top_k_cosine_similarity(M, a, k=5)
print("Top-k cosine similarity values:", top_k_values)
print("Indices of top-k similar rows:", top_k_indices)
# %%

1
nnx/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
MNIST

168
nnx/mnist.py Normal file
View File

@ -0,0 +1,168 @@
# %%
import tensorflow_datasets as tfds # TFDS for MNIST
import tensorflow as tf # TensorFlow operations
tf.random.set_seed(0) # set random seed for reproducibility
# input-pipeline hyperparameters
num_epochs = 10
batch_size = 32
train_ds: tf.data.Dataset = tfds.load('mnist', split='train')
test_ds: tf.data.Dataset = tfds.load('mnist', split='test')
train_ds = train_ds.map(
    lambda sample: {
        'image': tf.cast(sample['image'], tf.float32) / 255,
        'label': sample['label'],
    }
) # normalize train set pixels to [0, 1]
test_ds = test_ds.map(
    lambda sample: {
        'image': tf.cast(sample['image'], tf.float32) / 255,
        'label': sample['label'],
    }
) # normalize test set pixels to [0, 1]
# repeat for all epochs up front so one iterator drives the whole training
# run; shuffle by drawing from a 1024-element buffer
train_ds = train_ds.repeat(num_epochs).shuffle(1024)
# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency
train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1)
# create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from
test_ds = test_ds.shuffle(1024)
# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency
test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)
# %%
from flax import nnx # NNX API
from functools import partial
class CNN(nnx.Module):
    """A simple two-conv, two-linear CNN for MNIST classification."""

    def __init__(self, *, rngs: nnx.Rngs):
        self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
        self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
        self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
        self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
        self.linear2 = nnx.Linear(256, 10, rngs=rngs)

    def __call__(self, x):
        """Return class logits for a batch of (28, 28, 1) images."""
        h = self.avg_pool(nnx.relu(self.conv1(x)))
        h = self.avg_pool(nnx.relu(self.conv2(h)))
        h = h.reshape(h.shape[0], -1)  # flatten to (batch, 3136)
        h = nnx.relu(self.linear1(h))
        return self.linear2(h)
model = CNN(rngs=nnx.Rngs(0))
# %%
nnx.display(model)
# %%
# test the model by feeding an example input
import jax.numpy as jnp # JAX NumPy
y = model(jnp.ones((1, 28, 28, 1)))
nnx.display(y)
# %%
import optax
learning_rate = 0.005
momentum = 0.9
# NOTE(review): optax.adamw's second positional argument is b1, not a
# momentum term — confirm this is intended rather than
# optax.sgd(learning_rate, momentum).
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))
metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average('loss'),
)
nnx.display(optimizer)
# %%
def loss_fn(model: CNN, batch):
    """Return (mean cross-entropy loss, logits) for one batch."""
    logits = model(batch['image'])
    per_example = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['label']
    )
    return per_example.mean(), logits
# %%
@nnx.jit
def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
    """Train for a single step.

    Statefully updates the model parameters (via the optimizer), the
    optimizer state, and the running batch metrics.
    """
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, batch)
    metrics.update(loss=loss, logits=logits, labels=batch['label'])  # in-place metric update
    optimizer.update(grads)  # in-place parameter/optimizer-state update
# %%
# evaluation step
@nnx.jit
def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):
    """Accumulate loss/accuracy metrics on one eval batch (no param update)."""
    loss, logits = loss_fn(model, batch)
    metrics.update(loss=loss, logits=logits, labels=batch['label'])  # in-place
# %%
# for dataset seed random generation
tf.random.set_seed(0)
# %%
# train_ds was repeated num_epochs times above, so cardinality() is
# num_epochs * steps-per-epoch; divide to recover steps per epoch
num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs
# per-epoch metric history, consumed by the plotting cell below
metrics_history = {
    'train_loss': [],
    'train_accuracy': [],
    'test_loss': [],
    'test_accuracy': [],
}
# single loop over the repeated dataset covers all epochs; the same metrics
# object is reused for train and test, reset between the two phases
for step, batch in enumerate(train_ds.as_numpy_iterator()):
    # Run the optimization for one step and make a stateful update to the following:
    # - the train state's model parameters
    # - the optimizer state
    # - the training loss and accuracy batch metrics
    train_step(model, optimizer, metrics, batch)
    if (step + 1) % num_steps_per_epoch == 0: # one training epoch has passed
        # Log training metrics
        for metric, value in metrics.compute().items(): # compute metrics
            metrics_history[f'train_{metric}'].append(value) # record metrics
        metrics.reset() # reset metrics for test set
        # Compute metrics on the test set after each training epoch
        for test_batch in test_ds.as_numpy_iterator():
            eval_step(model, metrics, test_batch)
        # Log test metrics
        for metric, value in metrics.compute().items():
            metrics_history[f'test_{metric}'].append(value)
        metrics.reset() # reset metrics for next training epoch
        print(
            f"train epoch: {(step+1) // num_steps_per_epoch}, "
            f"loss: {metrics_history['train_loss'][-1]}, "
            f"accuracy: {metrics_history['train_accuracy'][-1] * 100}"
        )
        print(
            f"test epoch: {(step+1) // num_steps_per_epoch}, "
            f"loss: {metrics_history['test_loss'][-1]}, "
            f"accuracy: {metrics_history['test_accuracy'][-1] * 100}"
        )
# %%
# visualize metrics
import matplotlib.pyplot as plt # Visualization
# Plot loss and accuracy in subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
ax1.set_title('Loss')
ax2.set_title('Accuracy')
# one curve per split (train/test) on each axis
for dataset in ('train', 'test'):
    ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
    ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
ax1.legend()
ax2.legend()
plt.show()
# %%