Feat: implemented attention layer in equinox
parent a817fe16cc
commit 0762c02b31

@@ -0,0 +1,172 @@
# %%
# Prepare dataloader for jax training
from datasets import Dataset, DatasetDict, Value, Sequence, load_from_disk
from transformers import FlaxT5ForConditionalGeneration
from datasets import ClassLabel, Value, Sequence
from ml_collections import ConfigDict
import numpy as np
import jax.numpy as jnp
import jax
import math
from typing import Optional, List, Tuple, Callable, cast


# file_path = 'combined_data'
# split_datasets = load_from_disk(file_path)
# training_size = len(split_datasets['train'])

from transformers import T5TokenizerFast


# class takes in a dataset
class DataPrepare():

    def __init__(self, raw_dataset, config):
        self.raw_dataset: Dataset = raw_dataset
        self.size: int = len(raw_dataset)
        self.config: ConfigDict = config
        self.tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=False)
        # Define additional special tokens
        # additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "<SIG>", "<UNIT>", "<DATA_TYPE>"]
        additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
        # Add the additional special tokens to the tokenizer
        self.tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})

        model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")

        model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
        self.shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")  # noqa: B009

        self.train_dataset = self.preprocess_function(
            self.raw_dataset
        )

    # In Flax, for seq2seq models we need to pass `decoder_input_ids`
    # as the Flax models don't accept `labels`; we need to prepare the decoder_input_ids here
    # for that, dynamically import the `shift_tokens_right` function from the model file

    # given a dataset entry, run it through the tokenizer
    # padding and truncation give fixed-length inputs, which jitted functions require
    def preprocess_function(self, example: Dataset):
        inputs = example['input']
        targets = example['output']
        # text_target sets the corresponding label to inputs
        # there is no need to create a separate 'labels'
        # produce input_ids and decoder_input_ids
        model_inputs = self.tokenizer(
            inputs,
            max_length=self.config.max_length,
            padding=True,
            truncation=True,
            return_tensors="np"
        )
        # we separate it out because we need the attention mask
        labels = self.tokenizer(
            text_target=targets,
            max_length=self.config.max_length,
            padding=True,
            truncation=True,
            return_tensors="np"
        )
        model_inputs['input_ids'] = np.asarray(model_inputs['input_ids'])
        model_inputs['attention_mask'] = np.asarray(model_inputs['attention_mask'])
        # for loss computation
        model_inputs["labels"] = labels["input_ids"]
        # make decoder input ids
        # this is actually "model output" shifted right
        decoder_input_ids = self.shift_tokens_right_fn(
            labels["input_ids"], self.config.pad_token_id, self.config.decoder_start_token_id
        )
        # required by the model
        model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
        # decoder_attention_mask = shift_tokens_right_fn(
        #     labels["attention_mask"], self.config.pad_token_id, self.config.decoder_start_token_id
        # )
        # We need decoder_attention_mask so we can ignore pad tokens in loss
        model_inputs["decoder_attention_mask"] = np.asarray(labels["attention_mask"])

        return model_inputs

    # Example pad function
    def _pad_to_batch_size(self, batch, target_size):
        # Get the current batch size
        input_ids = batch['input_ids']
        current_size = input_ids.shape[0]
        if current_size < target_size:
            # Calculate how much padding is needed
            padding_size = target_size - current_size
            # Create padding (e.g., zeros or some appropriate value)
            padding = jnp.zeros((padding_size, input_ids.shape[1]), dtype=jnp.int32)  # Assuming 2D
            # Concatenate to create a full batch
            # repeat for all arrays in the tree
            padded_batch = jax.tree.map(lambda array: jnp.concatenate([array, padding], axis=0, dtype=jnp.int32), batch)
            # padded_batch = jnp.concatenate([batch, padding], axis=0)
        else:
            padded_batch = batch
        return padded_batch

    def data_loader(self, rng: jax.random.PRNGKey, batch_size: int, shuffle: bool = False, drop_last=True):
        """
        Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
        and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
        """
        dataset: Dataset = Dataset.from_dict(self.train_dataset)

        if shuffle:
            batch_idx = jax.random.permutation(rng, len(dataset))
            batch_idx = np.asarray(batch_idx)
        else:
            batch_idx = np.arange(len(dataset))

        if drop_last:
            steps_per_epoch = len(dataset) // batch_size
            batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
            minibatch_list = batch_idx.reshape((steps_per_epoch, batch_size))
        else:
            steps_per_epoch = math.ceil(len(dataset) / batch_size)
            minibatch_list = np.array_split(batch_idx, steps_per_epoch)

        for minibatch in minibatch_list:
            batch = dataset[minibatch]
            batch = {k: jnp.array(v, dtype=jnp.int32) for k, v in batch.items()}
            batch = self._pad_to_batch_size(batch, batch_size)

            yield batch
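
# (added) To make the `decoder_input_ids` step above concrete, here is a minimal
# standalone sketch of what the dynamically imported `shift_tokens_right` does,
# mirroring the NumPy helper used in Hugging Face's Flax seq2seq examples: prepend the
# decoder start token, drop the last label token, and replace any -100 ignore markers
# with the pad id. This is only an illustration; the class above uses the function
# pulled from the installed `transformers` model module.
def _shift_tokens_right_sketch(label_ids: np.ndarray, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    shifted = np.zeros_like(label_ids)
    shifted[:, 1:] = label_ids[:, :-1]       # shift every label one position to the right
    shifted[:, 0] = decoder_start_token_id   # the decoder always starts from the start token
    shifted = np.where(shifted == -100, pad_token_id, shifted)  # ignore-index -> pad
    return shifted

# quick check: labels [[5, 6, 7]] with start id 0 become decoder inputs [[0, 5, 6]]
# print(_shift_tokens_right_sketch(np.array([[5, 6, 7]]), pad_token_id=0, decoder_start_token_id=0))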

# # testing out the class
# # %%
# # init object
# # e.g. Config
#
# file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_desc'
# data_config = ConfigDict(
#     dict(
#         max_length=86,
#         pad_token_id=0,
#         decoder_start_token_id=0
#     )
# )
#
# from datasets import load_from_disk
# split_datasets = load_from_disk(file_path)
# dataprep = DataPrepare(split_datasets['train'], data_config)
#
# # %%
# seed = 117
# rng = jax.random.PRNGKey(seed)
# train_loader = dataprep.data_loader(rng, batch_size=32)
#
#
#
# # %%
# batch = next(train_loader)
#
# print(batch['input_ids'].shape)
# print(batch['decoder_input_ids'].shape)
#
# # %%

@@ -0,0 +1 @@
__pycache__/

@@ -0,0 +1,54 @@
# an example of stateful operations
# %%
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import optax  # https://github.com/deepmind/optax
from equinox.nn import State, StateIndex, StatefulLayer
from jaxtyping import Array


# %%
class Counter(eqx.Module):
    # This wraps together (a) a unique dictionary key used for looking up a
    # stateful value, and (b) how that stateful value should be initialised.
    index: eqx.nn.StateIndex

    def __init__(self):
        init_state = jnp.array(0)
        self.index = eqx.nn.StateIndex(init_state)

    # eqx.nn.State stores the state of the model
    # This is essentially a dictionary mapping from equinox.nn.StateIndexs to PyTrees of arrays.
    # This class should be initialised via equinox.nn.make_with_state.
    #
    # Basically just a dictionary which (a) works only with StateIndex-s, and which (b)
    # works around a JAX bug that prevents flattening dicts with `object()` keys, and which
    # (c) does error-checking that you're using the most up-to-date version of it.
    def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]:
        value = state.get(self.index)
        new_x = x + value

        # Sets a new value for an [`equinox.nn.StateIndex`][], and returns the
        # updated state.
        new_state = state.set(self.index, value + 1)
        return new_x, new_state


# make_with_state is the recommended way to start a stateful model
counter, state = eqx.nn.make_with_state(Counter)()
x = jnp.array(2.3)

num_calls = state.get(counter.index)
print(f"Called {num_calls} times.")  # 0

_, state = counter(x, state)
num_calls = state.get(counter.index)
print(f"Called {num_calls} times.")  # 1

_, state = counter(x, state)
num_calls = state.get(counter.index)
print(f"Called {num_calls} times.")  # 2


# %%

@@ -0,0 +1,106 @@
# %%
# introduction to how flax does stateful operations
import flax.linen as nn
import jax.numpy as jnp
import jax
import flax
from jaxtyping import Array

# %%

class BiasAdderWithRunningMean(nn.Module):
    momentum: float = 0.9

    @nn.compact
    def __call__(self, x):
        is_initialized = self.has_variable('hehe', 'mean')
        print(is_initialized)
        mean = self.variable('hehe', 'mean', jnp.zeros, x.shape[1:])
        bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
        if is_initialized:
            print(mean.value)  # notice that the value is retained after the first call
            mean.value = self.momentum * mean.value + (1.0 - self.momentum) * jnp.mean(
                x, axis=0, keepdims=True
            )
            print(mean.value)
        return mean.value + bias


# %%

input_key = jax.random.PRNGKey(0)
model = BiasAdderWithRunningMean()
inputs = jax.random.normal(input_key, (10, 5))  # Generate random normal values
variables = model.init(input_key, inputs)
# Split state and params (which are updated by optimizer).
state, params = flax.core.pop(variables, 'params')
print(f"first init: {state}")
# %%
for i in range(5):
    new_inputs = jax.random.normal(jax.random.PRNGKey(i + 1), (10, 5))  # New random inputs
    # notice how we are threading the state
    # perform argument unpacking on the state dictionary
    output, state = model.apply({'params': params, **state},
                                new_inputs, mutable=list(state.keys()))

    # mean_state = variables['batch_stats']['mean']  # Access the updated mean state
    print(f"updated state {state}")
    print(f"Output after input {i + 1}: {output}")
    # print(f"Updated running mean state: {mean_state}")
# %%
###########################################################
# example 2
from flax.linen.initializers import lecun_normal, variance_scaling, zeros, normal
import jax.random as random

class Foo(nn.Module):
    features: int

    @nn.compact
    def __call__(self):
        key = self.make_rng('spectral_norm_stats')
        print(key)
        u0_variable = self.variable('spectral_norm_stats', 'u0', normal(), key, (1, self.features))
        return u0_variable.value


foovars = Foo(3).init({'params': random.PRNGKey(0), 'spectral_norm_stats': random.PRNGKey(1)})
Foo(3).apply(foovars, rngs={'spectral_norm_stats': random.PRNGKey(1)})
# --> DeviceArray([[0.00711277, 0.0107195 , 0.019903  ]], dtype=float32)

# %%
model = Foo(3)

# %%
# state is kept in self.variable, tied to the layer
output = model.apply(foovars, rngs={'spectral_norm_stats': random.PRNGKey(1)})

# %%
output, state = model.apply(
    foovars,
    mutable=list(foovars.keys()),
    rngs={'spectral_norm_stats': random.PRNGKey(1)}
)
print(output, state)
# %%
output, state = model.apply(
    state,
    mutable=list(foovars.keys()),
    rngs={'spectral_norm_stats': random.PRNGKey(1)}
)
# no change because the input state is the same
print(output, state)
# %%
state_array = state['spectral_norm_stats']['u0']
modified_array = jax.lax.dynamic_update_slice(state_array, jnp.array([[0.9]]), (0, 0))
state['spectral_norm_stats']['u0'] = modified_array
# %%
# %%
output, state = model.apply(
    state,
    mutable=list(foovars.keys()),
    rngs={'spectral_norm_stats': random.PRNGKey(1)}
)
# the state is taken from the given state
# note the modified 0.9 value
# note how the state is not re-initialized
print(output, state)

# %%

@@ -0,0 +1,252 @@
# %%
import equinox as eqx
import jax
import jax.numpy as jnp
import optax  # https://github.com/deepmind/optax
import torch  # https://pytorch.org
import torchvision  # https://pytorch.org
from jaxtyping import Array, Float, Int, PyTree  # https://github.com/google/jaxtyping
# %%
# Hyperparameters

BATCH_SIZE = 64
LEARNING_RATE = 3e-4
STEPS = 300
PRINT_EVERY = 30
SEED = 5678

key = jax.random.PRNGKey(SEED)

# %%
normalise_data = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)
train_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=True,
    download=True,
    transform=normalise_data,
)
test_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=False,
    download=True,
    transform=normalise_data,
)
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=True
)

# %%
# Checking our data a bit (by now, everyone knows what the MNIST dataset looks like)
dummy_x, dummy_y = next(iter(trainloader))
dummy_x = dummy_x.numpy()
dummy_y = dummy_y.numpy()
print(dummy_x.shape)  # 64x1x28x28
print(dummy_y.shape)  # 64
print(dummy_y)

# %%
class CNN(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top.
        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,  # jax functions!!!
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=key2),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            jax.nn.log_softmax,
        ]

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        for layer in self.layers:
            x = layer(x)
        return x


key, subkey = jax.random.split(key, 2)
model = CNN(subkey)
# %%
print(model)
# %%
# print the first layer: Conv2d
print(model.layers[0])
# %%
# illustrated inference
def loss(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    # Our input has the shape (BATCH_SIZE, 1, 28, 28), but our model operates on
    # a single input image of shape (1, 28, 28).
    #
    # Therefore, we have to use jax.vmap, which in this case maps our model over the
    # leading (batch) axis.
    #
    # This is an example of writing a function for one input, then letting jax
    # automatically vectorize over the batch dimension
    pred_y = jax.vmap(model)(x)
    return cross_entropy(y, pred_y)


def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    # y are the true targets, and should be integers 0-9.
    # pred_y are the log-softmax'd predictions.

    # take_along_axis: take from pred_y along axis 1 according to 2nd argument
    # expand_dims to axis 1 makes it of shape (y_dim, 1)
    # since we take along axis 1, each y (in 2nd arg) therefore takes the
    # corresponding entry of each row in pred_y
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y)  # negative mean of relevant logits
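
# (added) a tiny worked example of the take_along_axis trick above: with two samples
# and three classes, each row keeps only the log-probability of its label, and the loss
# is the negative mean of those picked entries. The names are purely illustrative.
_example_logp = jnp.log(jnp.array([[0.7, 0.2, 0.1],
                                   [0.1, 0.1, 0.8]]))
_example_labels = jnp.array([0, 2])
_picked = jnp.take_along_axis(_example_logp, jnp.expand_dims(_example_labels, 1), axis=1)
print(_picked)             # [[log 0.7], [log 0.8]]
print(-jnp.mean(_picked))  # same value that cross_entropy would return for this batch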


# Example loss
loss_value = loss(model, dummy_x, dummy_y)
print(loss_value)
print(loss_value.shape)  # scalar loss
# Example inference
output = jax.vmap(model)(dummy_x)
print(output.shape)  # batch of predictions

# %%
# This is an error!
# the reason is that jax.value_and_grad expects its first argument to be a pytree of
# parameter arrays, but `model` also contains non-array members (activation functions etc.)
jax.value_and_grad(loss)(model, dummy_x, dummy_y)


# %%
# we separate out things that are params from other things
# since params are things that are arrays
# partition is doing filter(...) and filter(..., inverse=True)
params, static = eqx.partition(model, eqx.is_array)

# %%
# let's compare the same layer in the two halves
print(static.layers[0])
print(params.layers[0])


# %%
# in the loss, we recombine both to form back our model
def loss2(params, static, x, y):
    model = eqx.combine(params, static)
    return loss(model, x, y)


# Now this will work!
# since the grad only looks at the first argument, this works out
loss_value, grads = jax.value_and_grad(loss2)(params, static, dummy_x, dummy_y)
print(loss_value)

# %%
# This will work too!
# this works the same as the previous
value, grads = eqx.filter_value_and_grad(loss)(model, dummy_x, dummy_y)
print(value)

# %%
# evaluation
loss = eqx.filter_jit(loss)  # JIT our loss function from earlier!


@eqx.filter_jit
def compute_accuracy(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """This function takes as input the current model
    and computes the average accuracy on a batch.
    """
    pred_y = jax.vmap(model)(x)
    pred_y = jnp.argmax(pred_y, axis=1)
    return jnp.mean(y == pred_y)

# %%
def evaluate(model: CNN, testloader: torch.utils.data.DataLoader):
    """This function evaluates the model on the test dataset,
    computing both the average loss and the average accuracy.
    """
    avg_loss = 0
    avg_acc = 0
    for x, y in testloader:
        x = x.numpy()
        y = y.numpy()
        # Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
        # and both have JIT wrappers, so this is fast.
        avg_loss += loss(model, x, y)
        avg_acc += compute_accuracy(model, x, y)
    return avg_loss / len(testloader), avg_acc / len(testloader)

# %%
evaluate(model, testloader)
# %%
# training
optim = optax.adamw(LEARNING_RATE)

def train(
    model: CNN,
    trainloader: torch.utils.data.DataLoader,
    testloader: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
) -> CNN:
    # Just like earlier: It only makes sense to train the arrays in our model,
    # so filter out everything else.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # Always wrap everything -- computing gradients, running the optimiser, updating
    # the model -- into a single JIT region. This ensures things run as fast as
    # possible.
    @eqx.filter_jit
    def make_step(
        model: CNN,
        opt_state: PyTree,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        updates, opt_state = optim.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    # Loop over our training dataset as many times as we need.
    def infinite_trainloader():
        while True:
            yield from trainloader

    for step, (x, y) in zip(range(steps), infinite_trainloader()):
        # PyTorch dataloaders give PyTorch tensors by default,
        # so convert them to NumPy arrays.
        x = x.numpy()
        y = y.numpy()
        model, opt_state, train_loss = make_step(model, opt_state, x, y)
        if (step % print_every) == 0 or (step == steps - 1):
            test_loss, test_accuracy = evaluate(model, testloader)
            print(
                f"{step=}, train_loss={train_loss.item()}, "
                f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
            )
    return model


# %%
model = train(model, trainloader, testloader, optim, STEPS, PRINT_EVERY)

# %%

@@ -0,0 +1,591 @@
# %%
# package imports from equinox BERT example
import functools
from typing import Dict, List, Mapping, Optional, Callable, Tuple

# import einops  # https://github.com/arogozhnikov/einops
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax  # https://github.com/deepmind/optax
from datasets import load_dataset  # https://github.com/huggingface/datasets
from jaxtyping import Array, Float, Int  # https://github.com/google/jaxtyping
from tqdm import notebook as tqdm  # https://github.com/tqdm/tqdm
from transformers import AutoTokenizer  # https://github.com/huggingface/transformers
from ml_collections import ConfigDict, FrozenConfigDict

# helper functions for attention computation
# they are implemented with jax w/o flax
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights

import flax.linen as nn
# %%
class T5LayerNorm(eqx.Module):
    eps: float = 1e-6
    weight: jax.Array
    # staticmethod stores the initialiser as a plain function instead of binding it as a method
    weight_init: Callable[..., np.ndarray] = staticmethod(jax.nn.initializers.ones)

    def __init__(
        self: eqx.Module,
        hidden_size: int,
        key: jax.random.PRNGKey,
        # dtype: jnp.dtype = jnp.float32,
    ):
        # self.dtype = dtype
        # self.params = {
        #     'weight': self.weight_init(key, (hidden_size,), dtype)
        # }
        # force the use of float32
        # note that the key argument is ignored, so key is actually optional
        self.weight = self.weight_init(key, (hidden_size,), jnp.float32)

    # takes in an argument for hidden states so that it can fall through and remain
    # a pure function
    def __call__(self, hidden_states):
        """
        Construct a layernorm module in the T5 style;
        No bias and no subtraction of mean
        """
        # always compute in float32 for layer norm
        variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True)
        hidden_states = hidden_states / jnp.sqrt(variance + self.eps)

        return self.weight * hidden_states
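
# (added) quick numerical sanity check of the RMS-style normalisation above, using an
# assumed hidden size of 4: the layer divides by the root mean square and does not
# subtract the mean, so with the default all-ones weight it matches the formula directly.
_ln = T5LayerNorm(hidden_size=4, key=jax.random.PRNGKey(0))
_h = jnp.array([1.0, 2.0, 3.0, 4.0])
print(jnp.allclose(_ln(_h), _h / jnp.sqrt(jnp.mean(_h**2) + 1e-6)))  # True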

# # %%
# # testing T5LayerNorm
# key = jax.random.PRNGKey(0)
# hidden_size = 128  # Example hidden size
# layer_norm = T5LayerNorm(key=key, hidden_size=hidden_size)
# # Create some example input data
# hidden_states = jnp.ones((1, 10, hidden_size))  # Batch size of 1, sequence length of 10
# # Forward pass
# output = layer_norm(hidden_states)
# print("Output shape:", output.shape)

# %%
class KaimingLinear(eqx.Module):
    dtype: jnp.dtype = jnp.float32
    weights: jax.Array


    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        input_dim: int,
        output_dim: int,
        weights_init_std: float,
        dtype: jnp.dtype = jnp.float32
    ):
        self.dtype = dtype

        # the initialization strategy is to standardize on output dimension
        # shapes are: (input_dim, output_dim)
        self.weights = jax.random.normal(key, (input_dim, output_dim)) * weights_init_std

    def __call__(
        self,
        inputs: Float[Array, " input"],
    ):
        hidden = jnp.dot(inputs, self.weights)
        return hidden



# %%
# this function fortunately supports batched operations by default due to broadcasting
class T5DenseActDense(eqx.Module):
    config: FrozenConfigDict
    dtype: jnp.dtype = jnp.float32
    wi: KaimingLinear
    wo: KaimingLinear
    dropout: eqx.nn.Dropout
    act: Callable

    def __init__(
        self: eqx.Module,
        config: FrozenConfigDict,
        dtype: jnp.dtype,
        key: jax.random.PRNGKey
    ):
        self.config = config
        self.dtype = dtype

        mlp_key, output_key = jax.random.split(key)
        # the initialization strategy is to standardize on output dimension
        # input
        wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
        # shapes are: (config.d_model, config.d_ff)
        # self.wi = jax.random.normal(mlp_key, (self.config.d_model, self.config.d_ff)) * wi_init_std
        self.wi = KaimingLinear(
            key=mlp_key,
            input_dim=self.config.d_model,
            output_dim=self.config.d_ff,
            weights_init_std=wi_init_std,
            dtype=self.dtype
        )
        # output
        wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)
        # shapes are: (config.d_ff, config.d_model)
        # self.wo = jax.random.normal(output_key, (self.config.d_ff, self.config.d_model)) * wo_init_std
        self.wo = KaimingLinear(
            key=output_key,
            input_dim=self.config.d_ff,
            output_dim=self.config.d_model,
            weights_init_std=wo_init_std,
            dtype=self.dtype
        )

        self.dropout = eqx.nn.Dropout(self.config.dropout_rate)
        # just set to relu for now since the smaller T5's use relu
        self.act = jax.nn.relu

    def __call__(
        self,
        inputs: Float[Array, " d_model"],
        enable_dropout: bool = False,
        dropout_key: Optional[jax.random.PRNGKey] = None,
    ) -> Float[Array, " d_model"]:
        hidden = self.wi(inputs)
        # hidden = jnp.dot(inputs, self.wi)
        hidden = self.act(hidden)
        hidden = self.dropout(hidden, inference=not enable_dropout, key=dropout_key)
        hidden = self.wo(hidden)
        # hidden = jnp.dot(hidden, self.wo)
        return hidden


# # %%
# # test for T5DenseActDense
# # create fake config
# config_dict = {
#     'd_model': 768,
#     'd_ff': 2048,
#     'dropout_rate': 0.1,
#     'initializer_factor': 1.0,
# }
# # Create a FrozenDict from the standard dictionary
# frozen_config = FrozenConfigDict(config_dict)
# # initialize model
# key = jax.random.PRNGKey(0)
# dense = T5DenseActDense(
#     key=key,
#     config=frozen_config,
#     dtype=jnp.float32
# )
# input_key, key = jax.random.split(key)
# inputs = jax.random.normal(input_key, (10, frozen_config.d_model))  # Generate random normal values
# dropout_key, key = jax.random.split(key)
# output = dense(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
# output.shape

# %%
class T5LayerFF(eqx.Module):
    config: FrozenConfigDict
    dtype: jnp.dtype
    DenseReluDense: T5DenseActDense
    layer_norm: T5LayerNorm
    dropout: eqx.nn.Dropout

    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        config: FrozenConfigDict,
        dtype: jnp.dtype = jnp.float32
    ):

        self.config = config
        self.dtype = dtype

        dense_key, key = jax.random.split(key)
        # args: key, config, dtype
        self.DenseReluDense = T5DenseActDense(
            key=dense_key,
            config=config,
            dtype=dtype
        )
        layer_key, key = jax.random.split(key)
        # args: key, hidden_size
        self.layer_norm = T5LayerNorm(
            key=layer_key,
            hidden_size=self.config.d_model
        )
        # args: dropout_rate
        self.dropout = eqx.nn.Dropout(self.config.dropout_rate)

    def __call__(
        self: eqx.Module,
        inputs: Float[Array, " d_model"],
        enable_dropout: bool = False,
        dropout_key: Optional[jax.random.PRNGKey] = None,
    ):
        forwarded_states = self.layer_norm(inputs)
        dropout_key, key = jax.random.split(dropout_key)
        forwarded_states = self.DenseReluDense(
            inputs=forwarded_states,
            enable_dropout=enable_dropout,
            dropout_key=dropout_key
        )
        dropout_key, key = jax.random.split(key)
        dropout_states = self.dropout(
            x=forwarded_states,
            inference=not enable_dropout,
            key=dropout_key,
        )
        hidden = inputs + dropout_states
        return hidden

# # %%
# # test for T5LayerFF
# # create fake config
# config_dict = {
#     'd_model': 768,
#     'd_ff': 2048,
#     'dropout_rate': 0.1,
#     'initializer_factor': 1.0,
# }
# # Create a FrozenDict from the standard dictionary
# frozen_config = FrozenConfigDict(config_dict)
# # initialize model
# key = jax.random.PRNGKey(0)
# ff_layer = T5LayerFF(
#     key=key,
#     config=frozen_config,
#     dtype=jnp.float32
# )
# input_key, key = jax.random.split(key)
# inputs = jax.random.normal(input_key, (10, frozen_config.d_model))  # Generate random normal values
# dropout_key, key = jax.random.split(key)
# output = ff_layer(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
# output.shape

# %%
class T5Attention(eqx.Module):
    config: FrozenConfigDict
    has_relative_attention_bias: bool = False
    causal: bool = False  # False for encoder, True for decoder
    dtype: jnp.dtype

    # parameters
    q: KaimingLinear
    k: KaimingLinear
    v: KaimingLinear
    o: KaimingLinear


    # additional terms
    relative_attention_num_buckets: int
    relative_attention_max_distance: int
    d_model: int
    key_value_proj_dim: int
    n_heads: int
    dropout: float
    inner_dim: int
    initializer_factor: float


    def __init__(
        self: eqx.Module,
        config: FrozenConfigDict,
        dtype: jnp.dtype,
        key: jax.random.PRNGKey,
    ):
        self.config = config
        self.dtype = dtype
        self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
        self.relative_attention_max_distance = self.config.relative_attention_max_distance
        self.d_model = self.config.d_model
        # size of the k,v projection for each head
        self.key_value_proj_dim = self.config.d_kv
        self.n_heads = self.config.num_heads
        self.dropout = self.config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim
        self.initializer_factor = self.config.initializer_factor

        q_init_std = self.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
        kv_init_std = self.initializer_factor * (self.inner_dim**-0.5)
        o_init_std = self.initializer_factor * (self.inner_dim**-0.5)

        q_key, key = jax.random.split(key)
        self.q = KaimingLinear(
            key=q_key,
            input_dim=self.inner_dim,
            output_dim=self.inner_dim,
            weights_init_std=q_init_std,
            dtype=self.dtype
        )

        k_key, key = jax.random.split(key)
        self.k = KaimingLinear(
            key=k_key,
            input_dim=self.inner_dim,
            output_dim=self.inner_dim,
            weights_init_std=kv_init_std,
            dtype=self.dtype
        )

        v_key, key = jax.random.split(key)
        self.v = KaimingLinear(
            key=v_key,
            input_dim=self.inner_dim,
            output_dim=self.inner_dim,
            weights_init_std=kv_init_std,
            dtype=self.dtype
        )

        o_key, key = jax.random.split(key)
        self.o = KaimingLinear(
            key=o_key,
            input_dim=self.inner_dim,
            output_dim=self.d_model,
            weights_init_std=o_init_std,
            dtype=self.dtype
        )

    @staticmethod
    def _relative_position_bucket(
        relative_position,
        bidirectional=True,
        num_buckets=32,
        max_distance=128
    ):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on
        """
        relative_buckets = 0

        # bidirectional determines whether positive relative positions are valid
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0) * num_buckets
            relative_position = jnp.abs(relative_position)
        else:
            # relative position range of [0, inf]
            relative_position = -jnp.clip(relative_position, a_max=0)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        # boolean to assign relative buckets later
        is_small = relative_position < max_exact

        # the other half are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
        )

        relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)

        # jnp.where(condition, x, y), true->x, false->y
        # in-place cumulative summation
        # yields an array where every element has the correct relative bucket position,
        # whether it is small or large
        relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)

        return relative_buckets.astype("i4")

    # bias gives weight based on relative distance aside from attention score
    def compute_bias(self, query_length, key_length):
        """
        Compute binned relative position bias
        """
        # arange in the first dim
        context_position = jnp.arange(query_length, dtype="i4")[:, None]
        # arange in the second dim
        memory_position = jnp.arange(key_length, dtype="i4")[None, :]

        # The relative position is defined as memory_position - query_position,
        # i.e. the distance in tokens from the attending position to the
        # attended-to position.
        #
        # 2D array where each entry represents the distance from a query token
        # to a key token
        relative_position = memory_position - context_position
        # now we apply the earlier bucket creation function
        relative_position_bucket = self._relative_position_bucket(
            relative_position=relative_position,
            bidirectional=(not self.causal),  # causal during decode -> not bi
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        # retrieve the bias values
        # shape (query_length, key_length, n_heads)
        # note: `relative_attention_bias` (a per-bucket, per-head embedding) is not defined
        # on this module yet, so this branch is only reachable once that field is added;
        # with has_relative_attention_bias=False it is never taken
        values = self.relative_attention_bias(relative_position_bucket)
        # shape (1, n_heads, query_length, key_length)
        # ready for attention
        values = values.transpose((2, 0, 1))[None, :, :, :]
        return values


    # from (batch_size, seq_length, d_model) to
    # (batch_size, seq_length, n_heads, head_dim)
    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))

    # from (batch_size, seq_length, n_heads, head_dim) to
    # (batch_size, seq_length, d_model)
    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,))


    def _create_position_bias(
        self,
        key_states,
        query_states,
        attention_mask,
    ):
        # unlike the flax version, we don't even check for cache
        key_length = key_states.shape[1]
        query_length = query_states.shape[1]

        if self.has_relative_attention_bias:
            position_bias = self.compute_bias(query_length, key_length)
        elif attention_mask is not None:
            position_bias = jnp.zeros_like(attention_mask)
        else:
            position_bias = jnp.zeros(
                (1, self.n_heads, query_length, key_length),
                dtype=self.dtype
            )

        return position_bias

    def __call__(
        self,
        inputs,
        attention_mask=None,
        key_value_states=None,
        position_bias=None,
        output_attentions=False,
        enable_dropout=False,
        dropout_key: Optional[jax.random.PRNGKey] = None,
    ):
        # expected input shape: (batch_size, seq_len, d_model)
        # expected output: tuple of 2 arrays same shape as input
        # (attn, position_bias)
        batch_size, seq_length = inputs.shape[:2]

        # q,k,v projections
        # (batch_size, seq_length, inner_dim)
        query_states = self.q(inputs)
        key_states = (
            self.k(inputs) if key_value_states is None else self.k(key_value_states)
        )
        value_states = (
            self.v(inputs) if key_value_states is None else self.v(key_value_states)
        )

        # reshape to (batch_size, seq_length, n_heads, head_dim)
        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # counteract the scaling in the dot_product_attention_weights function
        query_states *= jnp.sqrt(query_states.shape[-1])

        # create causal attention_mask
        if self.causal:
            causal_attention_mask = make_causal_mask(attention_mask, dtype="bool")

            # broadcast causal attention mask & attention mask to fit for merge
            causal_attention_mask = jnp.broadcast_to(
                causal_attention_mask,
                (batch_size,) + causal_attention_mask.shape[1:]
            )
            attention_mask = jnp.broadcast_to(
                jnp.expand_dims(attention_mask, axis=(-3, -2)),
                causal_attention_mask.shape
            )
            attention_mask = combine_masks(attention_mask, causal_attention_mask)
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # replace masked positions with a large negative value (the dtype minimum)
        if attention_mask is not None:
            mask_value = jnp.finfo(self.dtype).min
            attention_mask = jax.lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, mask_value).astype(self.dtype),
            )

        if position_bias is None:
            # compute position bias (only for first layer)
            position_bias = self._create_position_bias(
                key_states, query_states, attention_mask
            )

            if attention_mask is not None:
                position_bias = position_bias + attention_mask

        # Softmax(QK^T)
        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=position_bias,
            dropout_rng=dropout_key,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=not enable_dropout,
            dtype=self.dtype,
        )

        # multiply with value states
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)

        # bring back to (batch_size, seq_length, d_model)
        attn_output = self._merge_heads(attn_output)

        # apply output matrix
        attn_output = self.o(attn_output)

        outputs = (attn_output, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)

        return outputs

# # %%
# # test for T5Attention
# # create fake config
# config_dict = {
#     'relative_attention_num_buckets': 32,
#     'relative_attention_max_distance': 128,
#     'd_model': 768,  # 64 * 12
#     'd_kv': 64,
#     'num_heads': 12,
#     'dropout_rate': 0.1,
#     'initializer_factor': 1.0,
# }
# # Create a FrozenDict from the standard dictionary
# frozen_config = FrozenConfigDict(config_dict)
# # initialize model
# key = jax.random.PRNGKey(0)
# attn_layer = T5Attention(
#     key=key,
#     config=frozen_config,
#     dtype=jnp.float32
# )
# input_key, key = jax.random.split(key)
# # inputs = jax.random.normal(input_key, (10, frozen_config.d_model))  # Generate random normal values
# batch_size = 1
# seq_length = 10
# inputs = jnp.ones((batch_size, seq_length, frozen_config.d_model))
# dropout_key, key = jax.random.split(key)
# output = attn_layer(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
# print(len(output))
# print(output[0].shape)

# %%
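
# (added) a small illustration of the bucketing scheme implemented by
# T5Attention._relative_position_bucket above, using its default settings
# (bidirectional=True, 32 buckets, max_distance=128): offsets near zero each get their
# own bucket, larger offsets share logarithmically sized buckets, and the sign of the
# offset selects the lower or upper half of the bucket range. The values below are
# purely illustrative inputs.
_example_offsets = jnp.array([-64, -8, -1, 0, 1, 8, 64, 200])
print(T5Attention._relative_position_bucket(_example_offsets))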
|
|
@ -0,0 +1,495 @@
|
||||||
|
# %%
|
||||||
|
# package imports from equinox BERT example
|
||||||
|
import functools
|
||||||
|
from typing import Dict, List, Mapping, Optional, Callable, Optional, Tuple
|
||||||
|
|
||||||
|
# import einops # https://github.com/arogozhnikov/einops
|
||||||
|
import equinox as eqx
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
import optax # https://github.com/deepmind/optax
|
||||||
|
from datasets import load_dataset # https://github.com/huggingface/datasets
|
||||||
|
from jaxtyping import Array, Float, Int # https://github.com/google/jaxtyping
|
||||||
|
from tqdm import notebook as tqdm # https://github.com/tqdm/tqdm
|
||||||
|
from transformers import AutoTokenizer # https://github.com/huggingface/transformers
|
||||||
|
from ml_collections import ConfigDict, FrozenConfigDict
|
||||||
|
|
||||||
|
import flax.linen as nn
|
||||||
|
# %%
|
||||||
|
class T5LayerNorm(eqx.Module):
|
||||||
|
eps: float = 1e-6
|
||||||
|
weight: jax.Array
|
||||||
|
# staticmethod forces the method to be by itself
|
||||||
|
weight_init: Callable[..., np.ndarray] = staticmethod(jax.nn.initializers.ones)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self: eqx.Module,
|
||||||
|
key: jax.random.PRNGKey,
|
||||||
|
hidden_size: int,
|
||||||
|
# dtype: jnp.dtype = jnp.float32,
|
||||||
|
):
|
||||||
|
# self.dtype = dtype
|
||||||
|
# self.params = {
|
||||||
|
# 'weight': self.weight_init(key, (hidden_size,), dtype)
|
||||||
|
# }
|
||||||
|
# force the use of float32
|
||||||
|
# note that the key argument is ignored, so key is actually optional
|
||||||
|
self.weight = self.weight_init(key, (hidden_size,), jnp.float32)
|
||||||
|
|
||||||
|
# takes in argument for hidden states so that it can fall through and remain
|
||||||
|
# a pure function
|
||||||
|
def __call__(self, hidden_states):
|
||||||
|
"""
|
||||||
|
Construct a layernorm module in the T5 style;
|
||||||
|
No bias and no subtraction of mean
|
||||||
|
"""
|
||||||
|
# always compute in float32 for layer norm
|
||||||
|
variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True)
|
||||||
|
hidden_states = hidden_states / jnp.sqrt(variance + self.eps)
|
||||||
|
|
||||||
|
return self.weight * hidden_states
|
||||||
|
|
||||||
|
# # %%
|
||||||
|
# # testing T5LayerNorm
|
||||||
|
# key = jax.random.PRNGKey(0)
|
||||||
|
# hidden_size = 128 # Example hidden size
|
||||||
|
# layer_norm = T5LayerNorm(key=key, hidden_size=hidden_size)
|
||||||
|
# # Create some example input data
|
||||||
|
# hidden_states = jnp.ones((1, 10, hidden_size)) # Batch size of 1, sequence length of 10
|
||||||
|
# # Forward pass
|
||||||
|
# output = layer_norm(hidden_states)
|
||||||
|
# print("Output shape:", output.shape)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
class KaimingLinear(eqx.Module):
|
||||||
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
weights: jax.Array
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self: eqx.Module,
|
||||||
|
key: jax.random.PRNGKey,
|
||||||
|
input_dim: int,
|
||||||
|
output_dim: int,
|
||||||
|
initializer_factor: float,
|
||||||
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
):
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
# the initialization strategy is to standardize on output dimension
|
||||||
|
# input
|
||||||
|
weights_init_std = initializer_factor * (input_dim**-0.5)
|
||||||
|
# shapes are: (input_dim, output_dim)
|
||||||
|
self.weights= jax.random.normal(key, (input_dim, output_dim)) * weights_init_std
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: Float[Array, " input"],
|
||||||
|
):
|
||||||
|
hidden = jnp.dot(inputs, self.weights)
|
||||||
|
return hidden
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# this function fortunately supports batched operations by default due to broadcasting
|
||||||
|
class T5DenseActDense(eqx.Module):
|
||||||
|
config: FrozenConfigDict
|
||||||
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
wi: jax.Array
|
||||||
|
wo: jax.Array
|
||||||
|
dropout: eqx.nn.Dropout
|
||||||
|
act: jax.nn.relu
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self: eqx.Module,
|
||||||
|
key: jax.random.PRNGKey,
|
||||||
|
config: FrozenConfigDict,
|
||||||
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
):
|
||||||
|
self.config = config
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
mlp_key, output_key = jax.random.split(key)
|
||||||
|
# the initialization strategy is to standardize on output dimension
|
||||||
|
# input
|
||||||
|
# wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
|
||||||
|
# shapes are: (config.d_model, config.d_ff)
|
||||||
|
# self.wi = jax.random.normal(mlp_key, (self.config.d_model, self.config.d_ff)) * wi_init_std
|
||||||
|
self.wi = KaimingLinear(
|
||||||
|
key=mlp_key,
|
||||||
|
input_dim=self.config.d_model,
|
||||||
|
output_dim=self.config.d_ff,
|
||||||
|
initializer_factor=self.config.initializer_factor,
|
||||||
|
dtype=self.dtype
|
||||||
|
)
|
||||||
|
# output
|
||||||
|
# wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)
|
||||||
|
# shapes are: (config.d_ff, config.d_model)
|
||||||
|
# self.wo = jax.random.normal(output_key, (self.config.d_ff, self.config.d_model)) * wo_init_std
|
||||||
|
self.wo = KaimingLinear(
|
||||||
|
key=mlp_key,
|
||||||
|
input_dim=self.config.d_ff,
|
||||||
|
output_dim=self.config.d_model,
|
||||||
|
initializer_factor=self.config.initializer_factor,
|
||||||
|
dtype=self.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
self.dropout = eqx.nn.Dropout(self.config.dropout_rate)
|
||||||
|
# just set to relu for now since the smaller T5's use relu
|
||||||
|
self.act = jax.nn.relu
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: Float[Array, " d_model"],
|
||||||
|
enable_dropout: bool = False,
|
||||||
|
dropout_key: Optional[jax.random.PRNGKey] = None,
|
||||||
|
) -> Float[Array, " d_model"]:
|
||||||
|
hidden = self.wi(inputs)
|
||||||
|
# hidden = jnp.dot(inputs, self.wi)
|
||||||
|
hidden = self.act(hidden)
|
||||||
|
hidden = self.dropout(hidden, inference=not enable_dropout, key=dropout_key)
|
||||||
|
hidden = self.wo(hidden)
|
||||||
|
# hidden = jnp.dot(hidden, self.wo)
|
||||||
|
return hidden
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# # %%
|
||||||
|
# # test for T5DenseActDense
|
||||||
|
# # create fake config
|
||||||
|
# config_dict = {
|
||||||
|
# 'd_model': 768,
|
||||||
|
# 'd_ff': 2048,
|
||||||
|
# 'dropout_rate': 0.1,
|
||||||
|
# 'initializer_factor': 1.0,
|
||||||
|
# }
|
||||||
|
# # Create a FrozenDict from the standard dictionary
|
||||||
|
# frozen_config = FrozenConfigDict(config_dict)
|
||||||
|
# # initialize model
|
||||||
|
# key = jax.random.PRNGKey(0)
|
||||||
|
# dense = T5DenseActDense(
|
||||||
|
# key=key,
|
||||||
|
# config=frozen_config,
|
||||||
|
# dtype=jnp.float32
|
||||||
|
# )
|
||||||
|
# input_key, key = jax.random.split(key)
|
||||||
|
# inputs = jax.random.normal(input_key, (10, frozen_config.d_model)) # Generate random normal values
|
||||||
|
# dropout_key, key = jax.random.split(key)
|
||||||
|
# output = dense(inputs=inputs, enable_dropout=False, key=dropout_key)
|
||||||
|
# output.shape
|
||||||
|
|
||||||
|
# %%
|
||||||
|
class T5LayerFF(eqx.Module):
|
||||||
|
config: FrozenConfigDict
|
||||||
|
dtype: jnp.dtype
|
||||||
|
DenseReluDense: T5DenseActDense
|
||||||
|
layer_norm: T5LayerNorm
|
||||||
|
dropout: eqx.nn.Dropout
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self: eqx.Module,
|
||||||
|
key: jax.random.PRNGKey,
|
||||||
|
config: FrozenConfigDict,
|
||||||
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
):
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
dense_key, key = jax.random.split(key)
|
||||||
|
# args: key, config, dtype
|
||||||
|
self.DenseReluDense = T5DenseActDense(
|
||||||
|
key=dense_key,
|
||||||
|
config=config,
|
||||||
|
dtype=dtype
|
||||||
|
)
|
||||||
|
layer_key, key = jax.random.split(key)
|
||||||
|
# args: key, hidden_size
|
||||||
|
self.layer_norm = T5LayerNorm(
|
||||||
|
key=layer_key,
|
||||||
|
hidden_size=self.config.d_model
|
||||||
|
)
|
||||||
|
# args: dropout_rate
|
||||||
|
self.dropout = eqx.nn.Dropout(self.config.dropout_rate)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self: eqx.Module,
|
||||||
|
inputs: Float[Array, " d_model"],
|
||||||
|
enable_dropout: bool =False,
|
||||||
|
dropout_key: Optional[jax.random.PRNGKey] = None,
|
||||||
|
):
|
||||||
|
forwarded_states = self.layer_norm(inputs)
|
||||||
|
dropout_key, key = jax.random.split(dropout_key)
|
||||||
|
forwarded_states = self.DenseReluDense(
|
||||||
|
inputs=forwarded_states,
|
||||||
|
enable_dropout=enable_dropout,
|
||||||
|
dropout_key=dropout_key
|
||||||
|
)
|
||||||
|
dropout_key, key = jax.random.split(key)
|
||||||
|
dropout_states = self.dropout(
|
||||||
|
x = forwarded_states,
|
||||||
|
key = dropout_key,
|
||||||
|
inference=not enable_dropout
|
||||||
|
)
|
||||||
|
hidden = inputs + dropout_states
|
||||||
|
return hidden
|
||||||
|
|
||||||
|
# # %%
# # test for T5LayerFF
# # create fake config
# config_dict = {
#     'd_model': 768,
#     'd_ff': 2048,
#     'dropout_rate': 0.1,
#     'initializer_factor': 1.0,
# }
# # Create a FrozenConfigDict from the standard dictionary
# frozen_config = FrozenConfigDict(config_dict)
# # initialize model
# key = jax.random.PRNGKey(0)
# ff_layer = T5LayerFF(
#     key=key,
#     config=frozen_config,
#     dtype=jnp.float32
# )
# input_key, key = jax.random.split(key)
# inputs = jax.random.normal(input_key, (10, frozen_config.d_model))  # generate random normal values
# dropout_key, key = jax.random.split(key)
# output = ff_layer(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
# output.shape

# %%
class T5Attention(eqx.Module):
    config: FrozenConfigDict
    dtype: jnp.dtype

    # additional terms
    relative_attention_num_buckets: int
    relative_attention_max_distance: int
    d_model: int
    key_value_proj_dim: int
    n_heads: int
    dropout: float
    inner_dim: int
    initializer_factor: float

    # projection layers (declared here so they can be assigned in __init__)
    q: KaimingLinear
    k: KaimingLinear
    v: KaimingLinear
    o: KaimingLinear

    # fields with defaults go last
    has_relative_attention_bias: bool = False
    causal: bool = False  # False for encoder, True for decoder
    relative_attention_bias: Optional[eqx.nn.Embedding] = None

    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        config: FrozenConfigDict,
        dtype: jnp.dtype = jnp.float32
    ):
        self.config = config
        self.dtype = dtype

        self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
        self.relative_attention_max_distance = self.config.relative_attention_max_distance
        self.d_model = self.config.d_model
        # size of the k,v projection for each head
        self.key_value_proj_dim = self.config.d_kv
        self.n_heads = self.config.num_heads
        self.dropout = self.config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim
        self.initializer_factor = self.config.initializer_factor

        q_key, key = jax.random.split(key)
        self.q = KaimingLinear(
            key=q_key,
            input_dim=(self.inner_dim * self.key_value_proj_dim),
            output_dim=self.inner_dim,
            initializer_factor=self.initializer_factor,
            dtype=self.dtype
        )

        k_key, key = jax.random.split(key)
        self.k = KaimingLinear(
            key=k_key,
            input_dim=self.inner_dim,
            output_dim=self.inner_dim,
            initializer_factor=self.initializer_factor,
            dtype=self.dtype
        )

        v_key, key = jax.random.split(key)
        self.v = KaimingLinear(
            key=v_key,
            input_dim=self.inner_dim,
            output_dim=self.inner_dim,
            initializer_factor=self.initializer_factor,
            dtype=self.dtype
        )

        o_key, key = jax.random.split(key)
        self.o = KaimingLinear(
            key=o_key,
            input_dim=self.inner_dim,
            output_dim=self.d_model,
            initializer_factor=self.initializer_factor,
            dtype=self.dtype
        )

        # 1 bias per head, so the output dimension is n_heads
        # the bias is learned during training
        if self.has_relative_attention_bias:
            input_dim = self.relative_attention_num_buckets
            output_dim = self.n_heads
            initializer_factor = self.initializer_factor
            # we standardize based on the output dimension,
            # which is n_heads * key_value_proj_dim during multi-head attention
            weights_init_std = initializer_factor * (self.inner_dim**-0.5)
            # shape is (input_dim, output_dim)
            weights = jax.random.normal(key, (input_dim, output_dim), dtype=self.dtype) * weights_init_std

            self.relative_attention_bias = eqx.nn.Embedding(
                weight=weights
            )

    @staticmethod
    def _relative_position_bucket(
        relative_position,
        bidirectional=True,
        num_buckets=32,
        max_distance=128
    ):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on
        """
        relative_buckets = 0

        # bidirectional determines if positive relative positions are valid
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0) * num_buckets
            relative_position = jnp.abs(relative_position)
        else:
            # relative position range of [0, inf]
            relative_position = -jnp.clip(relative_position, a_max=0)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        # boolean mask used to assign relative buckets later
        is_small = relative_position < max_exact

        # the other half are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
        )

        relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)

        # jnp.where(condition, x, y): true -> x, false -> y
        # every element ends up with the correct relative bucket,
        # whether it is small (exact) or large (logarithmic)
        relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)

        return relative_buckets.astype("i4")
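
    # worked example with the T5 defaults (num_buckets=32, bidirectional=True, so
    # 16 buckets per direction and max_exact=8):
    #   relative_position  0 -> bucket 0
    #   relative_position -1 -> bucket 1         (key one token before the query)
    #   relative_position +1 -> bucket 17 = 16+1 (key one token after the query)
    #   |relative_position| >= 8 falls into the log-spaced buckets, capped at max_distance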

    # bias gives weight based on relative distance aside from attention score
    def compute_bias(self, query_length, key_length):
        """
        Compute binned relative position bias
        """
        # arange in the first dim
        context_position = jnp.arange(query_length, dtype="i4")[:, None]
        # arange in the second dim
        memory_position = jnp.arange(key_length, dtype="i4")[None, :]

        # The relative position is defined as memory_position - query_position,
        # i.e. the distance in tokens from the attending position to the
        # attended-to position.
        #
        # 2D array where each entry represents the distance from a query token
        # to a key token
        relative_position = memory_position - context_position
        # now we apply the earlier bucket creation function
        relative_position_bucket = self._relative_position_bucket(
            relative_position=relative_position,
            bidirectional=(not self.causal),  # causal during decode -> not bidirectional
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        # retrieve the bias values
        # shape (query_length, key_length, n_heads)
        values = self.relative_attention_bias(relative_position_bucket)
        # shape (1, n_heads, query_length, key_length)
        # ready for attention
        values = values.transpose((2, 0, 1))[None, :, :, :]
        return values

    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))

    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,))
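
    # shape round trip (illustrative, with batch=2, seq=128, n_heads=12, d_kv=64):
    #   _split_heads: (2, 128, 768) -> (2, 128, 12, 64)
    #   _merge_heads: (2, 128, 12, 64) -> (2, 128, 768)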

    def _create_position_bias(
        self,
        key_states,
        query_states,
        attention_mask,
        init_cache,
        seq_length,
        causal_attention_mask_shift
    ):
        # unlike the flax version, we don't even check for cache
        key_length = key_states.shape[1]
        query_length = query_states.shape[1]

        if self.has_relative_attention_bias:
            position_bias = self.compute_bias(query_length, key_length)
        elif attention_mask is not None:
            position_bias = jnp.zeros_like(attention_mask)
        else:
            position_bias = jnp.zeros(
                (1, self.n_heads, query_length, key_length),
                dtype=self.dtype
            )

        return position_bias

    def __call__(
        self,
        inputs,
        attention_mask=None,
        key_value_states=None,
        position_bias=None,
        output_attentions=False,
        enable_dropout=False
    ):
        batch_size, seq_length = inputs.shape[:2]

        # q,k,v projections
        query_states = self.q(inputs)
        key_states = (
            self.k(inputs) if key_value_states is None else self.k(key_value_states)
        )
        value_states = (
            self.v(inputs) if key_value_states is None else self.v(key_value_states)
        )

        # reshape to (batch_size, seq_length, n_heads, head_dim)
        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # counteract scaling in dot_product_attention_weights function
        # not sure if this is a good idea in equinox

@ -0,0 +1,279 @@

# %%
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = False

flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    flags += " --xla_force_host_platform_device_count=4"  # simulate 4 devices
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    os.environ["JAX_PLATFORMS"] = "cpu"
else:
    # GPU flags
    flags = (
        '--xla_gpu_enable_triton_softmax_fusion=true '
        '--xla_gpu_triton_gemm_any=True '
        # '--xla_gpu_enable_async_collectives=true '
        '--xla_gpu_enable_latency_hiding_scheduler=true '
        '--xla_gpu_enable_highest_priority_async_stream=true '
    )
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

os.environ["XLA_FLAGS"] = flags
os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
    "NCCL_LL128_BUFFSIZE": "-2",
    "NCCL_LL_BUFFSIZE": "-2",
    "NCCL_PROTO": "SIMPLE,LL,LL128",
    "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.80",
    # "XLA_PYTHON_CLIENT_PREALLOCATE": "false"
})


import pandas as pd
import matplotlib.pyplot as plt

from datasets import Dataset

import jax
import jax.numpy as jnp
import optax
import numpy as np
import functools
from typing import Callable, Optional
import math
from jax.sharding import Mesh, NamedSharding
from jax.experimental import mesh_utils

from jax.sharding import PartitionSpec

# set cache
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "bfloat16")

jax.config.update("jax_enable_x64", False)


# from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig


import datasets
from datasets import Dataset
import evaluate
from tqdm import tqdm


import nltk  # Here to have a nice missing dependency error message early on

from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key

from ml_collections import ConfigDict

import time

from parallel.dataload import DataPrepare


# %%
# import data

# load training
data_path = "/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv"
# Ensure to include 'ships_idx' in the fields list
fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']
# Load the dataset
train_df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)

# # load valid
# data_path = "/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/valid.csv"
# # Ensure to include 'ships_idx' in the fields list
# fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']
# # Load the dataset
# validation_df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)

def process_df(df):
    output_list = [{
        'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC>",
        'output': f"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>",
    } for _, row in df.iterrows()]
    return output_list
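
# each row becomes one {'input': ..., 'output': ...} pair, e.g. (hypothetical values):
#   {'input': '<NAME>FO_PP_SEL<NAME><DESC>FUEL OIL PUMP SELECT<DESC>',
#    'output': '<THING_START>fuel_oil_pump<THING_END><PROPERTY_START>select<PROPERTY_END>'}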

# takes 1 minute to run without batching
train_dataset = Dataset.from_list(process_df(train_df))


print("preparing data")
data_config = ConfigDict(
    dict(
        max_length=128,
        pad_token_id=0,
        decoder_start_token_id=0
    )
)

dataprep = DataPrepare(train_dataset, data_config)

# %%
# load model
model_name_or_path = "./model_checkpoints/simple"  # Replace with your specific model name
from transformers import FlaxT5ForConditionalGeneration
model = FlaxT5ForConditionalGeneration.from_pretrained(
    model_name_or_path
)
params = model.params

# %%
# load data
SEED = 117
batch_size = 256  # per-device batch size
# train_batch_size scales by the device count because we shard across 4 devices later
train_batch_size = batch_size * jax.device_count()
rng = jax.random.PRNGKey(SEED)


# %%
# setup sharding
print("creating mesh")
device_mesh = mesh_utils.create_device_mesh((4, 1))
print(device_mesh)

mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)

def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
    return NamedSharding(mesh, pspec, memory_kind="device")

data_sharding = mesh_sharding(PartitionSpec('data'))  # shard batches across the data axis
# model_sharding = mesh_sharding(PartitionSpec('model'))
replicate_sharding = mesh_sharding(PartitionSpec())
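
# (sketch) a quick way to check how a batch would be laid out across the mesh;
# the (train_batch_size, 128) int32 array stands in for the tokenized input_ids
# example = jnp.zeros((train_batch_size, 128), dtype=jnp.int32)
# example = jax.device_put(example, data_sharding)
# jax.debug.visualize_array_sharding(example)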

# %%
# define function to get encodings

def get_encodings(batch, params):
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    input_ids = jnp.reshape(input_ids, (input_ids.shape[-2], input_ids.shape[-1]))
    attention_mask = jnp.reshape(attention_mask, (input_ids.shape[-2], input_ids.shape[-1]))
    encoder_outputs = model.encode(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_attentions=False,
        output_hidden_states=False,
        train=False,
        params=params,
        dropout_rng=None
    )
    # encoder_outputs gives 'last_hidden_state' of shape: (batch, seq_len, embed)
    # the embedding is not the full embedding size, but the self-attention embed
    # size

    # mask out padding by multiplying the encoder outputs with the attention mask
    # shape (batch, seq_len) -> (batch, seq_len, 1)
    # this lets it broadcast against encoder_outputs
    expanded_attention_mask = jnp.expand_dims(attention_mask, 2)  # (batch, 128, 1)
    # element-wise multiply
    embeddings = encoder_outputs['last_hidden_state'] * expanded_attention_mask  # (batch, 128, 768)
    # summing embeddings over axis 1 collapses the sequence into a (batch, 768)
    # summing the attention mask over axis 1 gives the total
    # unmasked token count for each data entry
    mean_embeddings = (embeddings).sum(axis=1) / expanded_attention_mask.sum(axis=1)
    # the shape of mean_embeddings is (batch, embed), we are ready to return
    return mean_embeddings

get_encodings_jit = jax.jit(
    functools.partial(get_encodings, params=params),
    # batch
    in_shardings=(data_sharding),
    out_shardings=replicate_sharding,
)

# # %%
# # test the get_encodings function
#
# rng, input_rng = jax.random.split(rng)
# train_loader = dataprep.data_loader(input_rng, batch_size=train_batch_size, shuffle=False, drop_last=False)
# batch = next(train_loader)
# encodings = get_encodings(batch, params)
# # function works!

# %%
# perform actual prediction
encoding_list = []
# note: train_batch_size is batch_size * 4
# since we have 4 devices
pred_steps = math.ceil(len(train_dataset) / train_batch_size)
print("***** Running prediction *****")
print(f" Num examples = {len(train_dataset)}")
print(f" Num steps = {pred_steps}")
print(f" Instantaneous batch size per device = {batch_size}")
print(f" Total test batch size (w. parallel & distributed) = {train_batch_size}")

rng, input_rng = jax.random.split(rng)
train_loader = dataprep.data_loader(input_rng, batch_size=train_batch_size, shuffle=False, drop_last=False)

for _ in tqdm(range(pred_steps), desc="Predicting..."):
    batch = next(train_loader)
    batch = jax.device_put(batch, data_sharding)
    encodings = get_encodings_jit(batch)
    encoding_list.extend(jax.device_get(encodings))


# %%
encoding_list = jnp.vstack(encoding_list)
# slice back down to the dataset length to drop any padded examples
encoding_list = encoding_list[:len(train_dataset)]
print(encoding_list.shape)
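
# # %%
# # (sketch) persist the pooled encodings for later retrieval experiments;
# # the filename is just an example
# np.save("train_encodings.npy", np.asarray(encoding_list))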

# %%
# getting top-k
def top_k_cosine_similarity(M, a, k):
    """
    Find the top-k rows in matrix M that are most cosine similar to array a.

    Args:
        M (jnp.ndarray): Matrix of shape (n, d), where each row is a d-dimensional vector.
        a (jnp.ndarray): Array of shape (d,), the vector to compare to each row of M.
        k (int): Number of top cosine similarities to retrieve.

    Returns:
        values (jnp.ndarray): Top-k cosine similarity values.
        indices (jnp.ndarray): Indices of the top-k most similar rows in M.
    """
    # Step 1: Normalize both M and a
    M_norm = M / jnp.linalg.norm(M, axis=1, keepdims=True)  # Shape: (n, d)
    a_norm = a / jnp.linalg.norm(a)  # Shape: (d,)

    # Step 2: Compute cosine similarity via dot product
    cosine_similarities = jnp.dot(M_norm, a_norm)  # Shape: (n,)

    # Step 3: Get the top-k values and their indices using jax.lax.top_k
    values, indices = jax.lax.top_k(cosine_similarities, k)

    return values, indices

# Example usage:
M = jax.random.normal(jax.random.PRNGKey(0), (100, 128))  # Random matrix with 100 rows of 128 dimensions
a = jax.random.normal(jax.random.PRNGKey(1), (128,))  # Random query vector of 128 dimensions

# Find top 5 most similar rows
top_k_values, top_k_indices = top_k_cosine_similarity(M, a, k=5)

print("Top-k cosine similarity values:", top_k_values)
print("Indices of top-k similar rows:", top_k_indices)

# %%

@ -0,0 +1 @@

MNIST

@ -0,0 +1,168 @@

# %%
import tensorflow_datasets as tfds  # TFDS for MNIST
import tensorflow as tf  # TensorFlow operations

tf.random.set_seed(0)  # set random seed for reproducibility

num_epochs = 10
batch_size = 32

train_ds: tf.data.Dataset = tfds.load('mnist', split='train')
test_ds: tf.data.Dataset = tfds.load('mnist', split='test')

train_ds = train_ds.map(
    lambda sample: {
        'image': tf.cast(sample['image'], tf.float32) / 255,
        'label': sample['label'],
    }
)  # normalize train set
test_ds = test_ds.map(
    lambda sample: {
        'image': tf.cast(sample['image'], tf.float32) / 255,
        'label': sample['label'],
    }
)  # normalize test set

# create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from
train_ds = train_ds.repeat(num_epochs).shuffle(1024)
# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency
train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1)
# create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from
test_ds = test_ds.shuffle(1024)
# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency
test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)
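
# once iterated with .as_numpy_iterator() (as in the training loop below),
# each batch is a dict of numpy arrays:
#   batch['image'] -> (32, 28, 28, 1) float32 in [0, 1]
#   batch['label'] -> (32,) integer class ids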

# %%
from flax import nnx  # NNX API
from functools import partial

class CNN(nnx.Module):
    """A simple CNN model."""

    def __init__(self, *, rngs: nnx.Rngs):
        self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
        self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
        self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
        self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
        self.linear2 = nnx.Linear(256, 10, rngs=rngs)

    def __call__(self, x):
        x = self.avg_pool(nnx.relu(self.conv1(x)))
        x = self.avg_pool(nnx.relu(self.conv2(x)))
        x = x.reshape(x.shape[0], -1)  # flatten
        x = nnx.relu(self.linear1(x))
        x = self.linear2(x)
        return x


model = CNN(rngs=nnx.Rngs(0))

# %%
nnx.display(model)

# %%
# test the model by feeding an example input
import jax.numpy as jnp  # JAX NumPy

y = model(jnp.ones((1, 28, 28, 1)))
nnx.display(y)

# %%
import optax

learning_rate = 0.005
momentum = 0.9

optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))
metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average('loss'),
)

nnx.display(optimizer)

# %%
def loss_fn(model: CNN, batch):
    logits = model(batch['image'])
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['label']
    ).mean()
    return loss, logits

# %%
@nnx.jit
def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
    """Train for a single step."""
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, batch)
    metrics.update(loss=loss, logits=logits, labels=batch['label'])
    optimizer.update(grads)

# %%
# evaluation step
@nnx.jit
def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):
    loss, logits = loss_fn(model, batch)
    metrics.update(loss=loss, logits=logits, labels=batch['label'])
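
# note: under @nnx.jit the model, optimizer and metrics arguments are stateful
# objects that are updated in place (their new state is propagated back out of
# the jitted call), which is why train_step and eval_step return nothing.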

# %%
# for dataset seed random generation
tf.random.set_seed(0)

# %%
num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs

metrics_history = {
    'train_loss': [],
    'train_accuracy': [],
    'test_loss': [],
    'test_accuracy': [],
}

for step, batch in enumerate(train_ds.as_numpy_iterator()):
    # Run the optimization for one step and make a stateful update to the following:
    # - the train state's model parameters
    # - the optimizer state
    # - the training loss and accuracy batch metrics
    train_step(model, optimizer, metrics, batch)

    if (step + 1) % num_steps_per_epoch == 0:  # one training epoch has passed
        # Log training metrics
        for metric, value in metrics.compute().items():  # compute metrics
            metrics_history[f'train_{metric}'].append(value)  # record metrics
        metrics.reset()  # reset metrics for test set

        # Compute metrics on the test set after each training epoch
        for test_batch in test_ds.as_numpy_iterator():
            eval_step(model, metrics, test_batch)

        # Log test metrics
        for metric, value in metrics.compute().items():
            metrics_history[f'test_{metric}'].append(value)
        metrics.reset()  # reset metrics for next training epoch

        print(
            f"train epoch: {(step+1) // num_steps_per_epoch}, "
            f"loss: {metrics_history['train_loss'][-1]}, "
            f"accuracy: {metrics_history['train_accuracy'][-1] * 100}"
        )
        print(
            f"test epoch: {(step+1) // num_steps_per_epoch}, "
            f"loss: {metrics_history['test_loss'][-1]}, "
            f"accuracy: {metrics_history['test_accuracy'][-1] * 100}"
        )

# %%
# visualize metrics
import matplotlib.pyplot as plt  # Visualization

# Plot loss and accuracy in subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
ax1.set_title('Loss')
ax2.set_title('Accuracy')
for dataset in ('train', 'test'):
    ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
    ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
ax1.legend()
ax2.legend()
plt.show()

# %%