Feat: implemented attention layer in equinox
parent a817fe16cc
commit 0762c02b31

@ -0,0 +1,172 @@
# %%
|
||||
# Prepare dataloader for jax training
|
||||
from datasets import Dataset, DatasetDict, Value, Sequence, load_from_disk
|
||||
from transformers import FlaxT5ForConditionalGeneration
|
||||
from datasets import ClassLabel, Value, Sequence
|
||||
from ml_collections import ConfigDict
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
import jax
|
||||
import math
|
||||
from typing import Optional, List, Tuple, Callable, cast
|
||||
|
||||
|
||||
# file_path = 'combined_data'
|
||||
# split_datasets = load_from_disk(file_path)
|
||||
# training_size = len(split_datasets['train'])
|
||||
|
||||
from transformers import T5TokenizerFast
|
||||
|
||||
# class takes in a dataset
|
||||
class DataPrepare():
|
||||
|
||||
def __init__(self, raw_dataset, config):
|
||||
self.raw_dataset: Dataset = raw_dataset
|
||||
self.size: int = len(raw_dataset)
|
||||
self.config: ConfigDict = config
|
||||
self.tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=False)
|
||||
# Define additional special tokens
|
||||
# additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "<SIG>", "<UNIT>", "<DATA_TYPE>"]
|
||||
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
|
||||
# Add the additional special tokens to the tokenizer
|
||||
self.tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
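# Note (added caution, not in the original): adding special tokens grows the tokenizer
# vocabulary; if any of the new token ids are ever fed to the model, the model's
# embedding matrix must be resized to the new vocabulary size or the lookup goes out of range.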
|
||||
|
||||
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
|
||||
|
||||
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
|
||||
self.shift_tokens_right_fn = getattr(model_module, "shift_tokens_right") # noqa: B009
|
||||
|
||||
|
||||
self.train_dataset = self.preprocess_function(
|
||||
self.raw_dataset
|
||||
)
|
||||
|
||||
|
||||
|
||||
# In Flax, for seq2seq models we need to pass `decoder_input_ids`
|
||||
# as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
|
||||
# for that dynamically import the `shift_tokens_right` function from the model file
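# A minimal sketch of what the dynamically imported `shift_tokens_right` does
# (assuming the usual HuggingFace Flax behaviour): prepend the decoder start
# token, drop the last position, and replace ignored (-100) labels with pad.
# def _shift_tokens_right_sketch(labels, pad_token_id, decoder_start_token_id):
#     shifted = np.zeros_like(labels)
#     shifted[:, 0] = decoder_start_token_id
#     shifted[:, 1:] = labels[:, :-1]
#     return np.where(shifted == -100, pad_token_id, shifted)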
|
||||
|
||||
# given a dataset entry, run it through the tokenizer
|
||||
# Note: jitted functions need fixed-length inputs; padding="max_length" would give fully
# static shapes, while padding=True (used below) only pads to the longest sequence per call
|
||||
def preprocess_function(self, example: Dataset):
|
||||
inputs = example['input']
|
||||
targets = example['output']
|
||||
# text_target sets the corresponding label to inputs
|
||||
# there is no need to create a separate 'labels'
|
||||
# produce input_ids and decoder_input_ids
|
||||
model_inputs = self.tokenizer(
|
||||
inputs,
|
||||
max_length=self.config.max_length,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="np"
|
||||
)
|
||||
# we separate it out because we need the attention mask
|
||||
labels = self.tokenizer(
|
||||
text_target=targets,
|
||||
max_length=self.config.max_length,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="np"
|
||||
)
|
||||
model_inputs['input_ids'] = np.asarray(model_inputs['input_ids'])
|
||||
model_inputs['attention_mask'] = np.asarray(model_inputs['attention_mask'])
|
||||
# for loss computation
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
# make decoder input ids
|
||||
# this is actually "model output" shifted right
|
||||
decoder_input_ids = self.shift_tokens_right_fn(
|
||||
labels["input_ids"], self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
# required by the model
|
||||
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
|
||||
# decoder_attention_mask = shift_tokens_right_fn(
|
||||
# labels["attention_mask"], self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
# )
|
||||
# We need decoder_attention_mask so we can ignore pad tokens in loss
|
||||
model_inputs["decoder_attention_mask"] = np.asarray(labels["attention_mask"])
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
# Example pad function
|
||||
def _pad_to_batch_size(self, batch, target_size):
|
||||
# Get the current batch size
|
||||
input_ids = batch['input_ids']
|
||||
current_size = input_ids.shape[0]
|
||||
if current_size < target_size:
|
||||
# Calculate how much padding is needed
|
||||
padding_size = target_size - current_size
|
||||
# Create padding (zeros here); build it per array so every array in the batch is padded with a matching trailing shape
# repeat for all arrays in the tree
padded_batch = jax.tree.map(lambda array: jnp.concatenate([array, jnp.zeros((padding_size,) + array.shape[1:], dtype=jnp.int32)], axis=0), batch)
|
||||
# padded_batch = jnp.concatenate([batch, padding], axis=0)
|
||||
|
||||
else:
|
||||
padded_batch = batch
|
||||
return padded_batch
|
||||
|
||||
def data_loader(self, rng: jax.random.PRNGKey, batch_size: int, shuffle: bool = False, drop_last=True):
|
||||
"""
|
||||
Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
|
||||
and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
|
||||
"""
|
||||
dataset: Dataset = Dataset.from_dict(self.train_dataset)
|
||||
|
||||
if shuffle:
|
||||
batch_idx = jax.random.permutation(rng, len(dataset))
|
||||
batch_idx = np.asarray(batch_idx)
|
||||
else:
|
||||
batch_idx = np.arange(len(dataset))
|
||||
|
||||
if drop_last:
|
||||
steps_per_epoch = len(dataset) // batch_size
|
||||
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
|
||||
minibatch_list = batch_idx.reshape((steps_per_epoch, batch_size))
|
||||
else:
|
||||
steps_per_epoch = math.ceil(len(dataset) / batch_size)
|
||||
minibatch_list = np.array_split(batch_idx, steps_per_epoch)
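# e.g. 10 examples with batch_size=4:
#   drop_last=True  -> [[0,1,2,3], [4,5,6,7]]           (the last 2 examples are skipped)
#   drop_last=False -> [[0,1,2,3], [4,5,6], [7,8,9]]    (np.array_split balances the sizes)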
|
||||
|
||||
for minibatch in minibatch_list:
|
||||
batch = dataset[minibatch]
|
||||
batch = {k: jnp.array(v, dtype=jnp.int32) for k, v in batch.items()}
|
||||
batch = self._pad_to_batch_size(batch, batch_size)
|
||||
|
||||
yield batch
|
||||
|
||||
|
||||
# # testing out the class
|
||||
# # %%
|
||||
# # init object
|
||||
# # e.g. Config
|
||||
#
|
||||
# file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_desc'
|
||||
# data_config = ConfigDict(
|
||||
# dict(
|
||||
# max_length=86,
|
||||
# pad_token_id=0,
|
||||
# decoder_start_token_id=0
|
||||
# )
|
||||
# )
|
||||
#
|
||||
# from datasets import load_from_disk
|
||||
# split_datasets = load_from_disk(file_path)
|
||||
# dataprep = DataPrepare(split_datasets['train'], data_config)
|
||||
#
|
||||
# # %%
|
||||
# seed = 117
|
||||
# rng = jax.random.PRNGKey(seed)
|
||||
# train_loader = dataprep.data_loader(rng, batch_size=32)
|
||||
#
|
||||
#
|
||||
#
|
||||
# # %%
|
||||
# batch = next(train_loader)
|
||||
#
|
||||
# print(batch['input_ids'].shape)
|
||||
# print(batch['decoder_input_ids'].shape)
|
||||
#
|
||||
# # %%
@ -0,0 +1 @@
__pycache__/
@ -0,0 +1,54 @@
# an example of stateful operations
# %%
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import optax  # https://github.com/deepmind/optax
from equinox.nn import State, StateIndex, StatefulLayer
from jaxtyping import Array


# %%
class Counter(eqx.Module):
    # This wraps together (a) a unique dictionary key used for looking up a
    # stateful value, and (b) how that stateful value should be initialised.
    index: eqx.nn.StateIndex

    def __init__(self):
        init_state = jnp.array(0)
        self.index = eqx.nn.StateIndex(init_state)

    # eqx.nn.State stores the state of the model.
    # It is essentially a dictionary mapping from equinox.nn.StateIndex instances
    # to PyTrees of arrays, and should be initialised via equinox.nn.make_with_state.
    #
    # Basically just a dictionary which (a) works only with StateIndex-s, (b)
    # works around a JAX bug that prevents flattening dicts with `object()` keys,
    # and (c) does error-checking that you're using the most up-to-date version of it.
    def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]:
        value = state.get(self.index)
        new_x = x + value

        # Sets a new value for an `equinox.nn.StateIndex`, and returns the
        # updated state.
        new_state = state.set(self.index, value + 1)
        return new_x, new_state


# make_with_state is the recommended way to start a stateful model
counter, state = eqx.nn.make_with_state(Counter)()
x = jnp.array(2.3)

num_calls = state.get(counter.index)
print(f"Called {num_calls} times.")  # 0

_, state = counter(x, state)
num_calls = state.get(counter.index)
print(f"Called {num_calls} times.")  # 1

_, state = counter(x, state)
num_calls = state.get(counter.index)
print(f"Called {num_calls} times.")  # 2


# %%
@ -0,0 +1,106 @@
|
|||
# %%
|
||||
# introduction to how flax does stateful operations
|
||||
import flax.linen as nn
|
||||
import jax.numpy as jnp
|
||||
import jax
|
||||
import flax
|
||||
from jaxtyping import Array
|
||||
|
||||
# %%
|
||||
|
||||
class BiasAdderWithRunningMean(nn.Module):
|
||||
momentum: float = 0.9
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x):
|
||||
is_initialized = self.has_variable('hehe', 'mean')
|
||||
print(is_initialized)
|
||||
mean = self.variable('hehe', 'mean', jnp.zeros, x.shape[1:])
|
||||
bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
|
||||
if is_initialized:
|
||||
print(mean.value) # notice that value retains after first call
|
||||
mean.value = self.momentum * mean.value + (1.0 - self.momentum) * jnp.mean(
|
||||
x, axis=0, keepdims=True
|
||||
)
|
||||
print(mean.value)
|
||||
return mean.value + bias
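# After `model.init(...)` the returned variables hold two collections, roughly:
#   {'params': {'bias': ...}, 'hehe': {'mean': ...}}
# 'params' is what the optimizer updates; 'hehe' is the mutable state collection
# that must be threaded through `apply(..., mutable=...)` as done below.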
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
input_key = jax.random.PRNGKey(0)
|
||||
model = BiasAdderWithRunningMean()
|
||||
inputs = jax.random.normal(input_key, (10, 5)) # Generate random normal values
|
||||
variables = model.init(input_key, inputs)
|
||||
# Split state and params (which are updated by optimizer).
|
||||
state, params = flax.core.pop(variables, 'params')
|
||||
print(f"first init: {state}")
|
||||
# %%
|
||||
for i in range(5):
|
||||
new_inputs = jax.random.normal(jax.random.PRNGKey(i + 1), (10,5)) # New random inputs
|
||||
# notice how we are threading the state
|
||||
# perform argument unpacking on state dictionary
|
||||
output, state = model.apply({'params': params, **state},
|
||||
new_inputs, mutable=list(state.keys()))
|
||||
|
||||
# mean_state = variables['batch_stats']['mean'] # Access the updated mean state
|
||||
print(f"updated state {state}")
|
||||
print(f"Output after input {i + 1}: {output}")
|
||||
# print(f"Updated running mean state: {mean_state}")
|
||||
# %%
|
||||
###########################################################
|
||||
# example 2
|
||||
from flax.linen.initializers import lecun_normal, variance_scaling, zeros, normal
|
||||
import jax.random as random
|
||||
class Foo(nn.Module):
|
||||
features: int
|
||||
@nn.compact
|
||||
def __call__(self):
|
||||
key = self.make_rng('spectral_norm_stats')
|
||||
print(key)
|
||||
u0_variable = self.variable('spectral_norm_stats', 'u0', normal(), key, (1, self.features))
|
||||
return u0_variable.value
|
||||
|
||||
foovars = Foo(3).init({'params': random.PRNGKey(0), 'spectral_norm_stats': random.PRNGKey(1)})
|
||||
Foo(3).apply(foovars, rngs={'spectral_norm_stats': random.PRNGKey(1)})
|
||||
# --> DeviceArray([[0.00711277, 0.0107195 , 0.019903 ]], dtype=float32)
|
||||
|
||||
# %%
|
||||
model = Foo(3)
|
||||
|
||||
# %%
|
||||
# state is kept in self.variable, tied to the layer
|
||||
output = model.apply(foovars, rngs={'spectral_norm_stats': random.PRNGKey(1)})
|
||||
|
||||
# %%
|
||||
output, state = model.apply(
|
||||
foovars,
|
||||
mutable=list(foovars.keys()),
|
||||
rngs={'spectral_norm_stats': random.PRNGKey(1)}
|
||||
)
|
||||
print(output, state)
|
||||
# %%
|
||||
output, state = model.apply(
|
||||
state,
|
||||
mutable=list(foovars.keys()),
|
||||
rngs={'spectral_norm_stats': random.PRNGKey(1)}
|
||||
)
|
||||
# no change because input state is the same
|
||||
print(output, state)
|
||||
# %%
|
||||
state_array = state['spectral_norm_stats']['u0']
|
||||
modified_array = jax.lax.dynamic_update_slice(state_array, jnp.array([[0.9]]), (0,0))
|
||||
state['spectral_norm_stats']['u0'] = modified_array
|
||||
# %%
|
||||
# %%
|
||||
output, state = model.apply(
|
||||
state,
|
||||
mutable=list(foovars.keys()),
|
||||
rngs={'spectral_norm_stats': random.PRNGKey(1)}
|
||||
)
|
||||
# state takes from given state
|
||||
# note the modified 0.9 value
|
||||
# note how the state is not re-initialized
|
||||
print(output, state)
|
||||
|
||||
# %%
|
|
@ -0,0 +1,252 @@
|
|||
# %%
|
||||
import equinox as eqx
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import optax # https://github.com/deepmind/optax
|
||||
import torch # https://pytorch.org
|
||||
import torchvision # https://pytorch.org
|
||||
from jaxtyping import Array, Float, Int, PyTree # https://github.com/google/jaxtyping
|
||||
# %%
|
||||
# Hyperparameters
|
||||
|
||||
BATCH_SIZE = 64
|
||||
LEARNING_RATE = 3e-4
|
||||
STEPS = 300
|
||||
PRINT_EVERY = 30
|
||||
SEED = 5678
|
||||
|
||||
key = jax.random.PRNGKey(SEED)
|
||||
|
||||
# %%
|
||||
normalise_data = torchvision.transforms.Compose(
|
||||
[
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize((0.5,), (0.5,)),
|
||||
]
|
||||
)
|
||||
train_dataset = torchvision.datasets.MNIST(
|
||||
"MNIST",
|
||||
train=True,
|
||||
download=True,
|
||||
transform=normalise_data,
|
||||
)
|
||||
test_dataset = torchvision.datasets.MNIST(
|
||||
"MNIST",
|
||||
train=False,
|
||||
download=True,
|
||||
transform=normalise_data,
|
||||
)
|
||||
trainloader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=BATCH_SIZE, shuffle=True
|
||||
)
|
||||
testloader = torch.utils.data.DataLoader(
|
||||
test_dataset, batch_size=BATCH_SIZE, shuffle=True
|
||||
)
|
||||
|
||||
# %%
|
||||
# Checking our data a bit (by now, everyone knows what the MNIST dataset looks like)
|
||||
dummy_x, dummy_y = next(iter(trainloader))
|
||||
dummy_x = dummy_x.numpy()
|
||||
dummy_y = dummy_y.numpy()
|
||||
print(dummy_x.shape) # 64x1x28x28
|
||||
print(dummy_y.shape) # 64
|
||||
print(dummy_y)
|
||||
|
||||
# %%
|
||||
class CNN(eqx.Module):
|
||||
layers: list
|
||||
|
||||
def __init__(self, key):
|
||||
key1, key2, key3, key4 = jax.random.split(key, 4)
|
||||
# Standard CNN setup: convolutional layer, followed by flattening,
|
||||
# with a small MLP on top.
|
||||
self.layers = [
|
||||
eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
|
||||
eqx.nn.MaxPool2d(kernel_size=2),
|
||||
jax.nn.relu, # jax functions!!!
|
||||
jnp.ravel,
|
||||
eqx.nn.Linear(1728, 512, key=key2),
|
||||
jax.nn.sigmoid,
|
||||
eqx.nn.Linear(512, 64, key=key3),
|
||||
jax.nn.relu,
|
||||
eqx.nn.Linear(64, 10, key=key4),
|
||||
jax.nn.log_softmax,
|
||||
]
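# Where the 1728 comes from (sanity check, assuming equinox's default pool stride of 1):
# 28x28 input -> Conv2d(kernel 4): 25x25 with 3 channels -> MaxPool2d(kernel 2): 24x24,
# and 3 * 24 * 24 = 1728 after jnp.ravel.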
|
||||
|
||||
def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
key, subkey = jax.random.split(key, 2)
|
||||
model = CNN(subkey)
|
||||
# %%
|
||||
print(model)
|
||||
# %%
|
||||
# print the first layer: Conv2d
|
||||
print(model.layers[0])
|
||||
# %%
|
||||
# illustrated inference
|
||||
def loss(
|
||||
model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
|
||||
) -> Float[Array, ""]:
|
||||
# Our input has the shape (BATCH_SIZE, 1, 28, 28), but our model operates on
# a single input image of shape (1, 28, 28).
|
||||
#
|
||||
# Therefore, we have to use jax.vmap, which in this case maps our model over the
|
||||
# leading (batch) axis.
|
||||
#
|
||||
# This is an example of writing function for one input, then letting jax
|
||||
# automatically vectorize over the batch dimension
|
||||
pred_y = jax.vmap(model)(x)
|
||||
return cross_entropy(y, pred_y)
|
||||
|
||||
|
||||
def cross_entropy(
|
||||
y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
|
||||
) -> Float[Array, ""]:
|
||||
# y are the true targets, and should be integers 0-9.
|
||||
# pred_y are the log-softmax'd predictions.
|
||||
|
||||
# take_along_axis: take from pred_y along axis 1 according to 2nd argument
|
||||
# expand_dims to axis 1 makes it of shape (y_dim, 1)
|
||||
# since we take along axis 1, each y (in 2nd arg) therefore takes the
|
||||
# corresponding entry of each row in pred_y
|
||||
pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
|
||||
return -jnp.mean(pred_y) # negative mean of relevant logits
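# Tiny worked illustration of the gather above (made-up values, just to show the indexing):
#   pred_y = [[-2.3, -0.1, -4.0],
#             [-0.7, -1.2, -3.1]]
#   y      = [1, 0]
#   jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1) -> [[-0.1], [-0.7]]
# i.e. the log-probability assigned to the true class of each row, then negated and averaged.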
|
||||
|
||||
|
||||
# Example loss
|
||||
loss_value = loss(model, dummy_x, dummy_y)
|
||||
print(loss_value)
|
||||
print(loss_value.shape) # scalar loss
|
||||
# Example inference
|
||||
output = jax.vmap(model)(dummy_x)
|
||||
print(output.shape) # batch of predictions
|
||||
|
||||
# %%
|
||||
# This is an error!
|
||||
# the reason is that jax.grad expects its first argument to be a PyTree of arrays
# (the parameters), but `model` also contains non-array leaves (the activation functions)
|
||||
jax.value_and_grad(loss)(model, dummy_x, dummy_y)
|
||||
|
||||
|
||||
# %%
|
||||
# we separate out things that are params from other things
|
||||
# since params are things that are arrays
|
||||
# partition is doing filter(...) and filter(..., inverse=True)
|
||||
params, static = eqx.partition(model, eqx.is_array)
|
||||
|
||||
# %%
|
||||
# lets compare the same object in both terms
|
||||
print(static.layers[0])
|
||||
print(params.layers[0])
|
||||
|
||||
|
||||
# %%
|
||||
# in the loss, we recombine both to form back our model
|
||||
def loss2(params, static, x, y):
|
||||
model = eqx.combine(params, static)
|
||||
return loss(model, x, y)
|
||||
|
||||
|
||||
# Now this will work!
|
||||
# since the grad only looks at the first argument, this works out
|
||||
loss_value, grads = jax.value_and_grad(loss2)(params, static, dummy_x, dummy_y)
|
||||
print(loss_value)
|
||||
|
||||
# %%
|
||||
# This will work too!
|
||||
# this works the same as the previous
|
||||
value, grads = eqx.filter_value_and_grad(loss)(model, dummy_x, dummy_y)
|
||||
print(value)
|
||||
|
||||
# %%
|
||||
# evaluation
|
||||
loss = eqx.filter_jit(loss) # JIT our loss function from earlier!
|
||||
|
||||
|
||||
@eqx.filter_jit
|
||||
def compute_accuracy(
|
||||
model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
|
||||
) -> Float[Array, ""]:
|
||||
"""This function takes as input the current model
|
||||
and computes the average accuracy on a batch.
|
||||
"""
|
||||
pred_y = jax.vmap(model)(x)
|
||||
pred_y = jnp.argmax(pred_y, axis=1)
|
||||
return jnp.mean(y == pred_y)
|
||||
|
||||
# %%
|
||||
def evaluate(model: CNN, testloader: torch.utils.data.DataLoader):
|
||||
"""This function evaluates the model on the test dataset,
|
||||
computing both the average loss and the average accuracy.
|
||||
"""
|
||||
avg_loss = 0
|
||||
avg_acc = 0
|
||||
for x, y in testloader:
|
||||
x = x.numpy()
|
||||
y = y.numpy()
|
||||
# Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
|
||||
# and both have JIT wrappers, so this is fast.
|
||||
avg_loss += loss(model, x, y)
|
||||
avg_acc += compute_accuracy(model, x, y)
|
||||
return avg_loss / len(testloader), avg_acc / len(testloader)
|
||||
|
||||
# %%
|
||||
evaluate(model, testloader)
|
||||
# %%
|
||||
# training
|
||||
optim = optax.adamw(LEARNING_RATE)
|
||||
|
||||
def train(
|
||||
model: CNN,
|
||||
trainloader: torch.utils.data.DataLoader,
|
||||
testloader: torch.utils.data.DataLoader,
|
||||
optim: optax.GradientTransformation,
|
||||
steps: int,
|
||||
print_every: int,
|
||||
) -> CNN:
|
||||
# Just like earlier: It only makes sense to train the arrays in our model,
|
||||
# so filter out everything else.
|
||||
opt_state = optim.init(eqx.filter(model, eqx.is_array))
|
||||
|
||||
# Always wrap everything -- computing gradients, running the optimiser, updating
|
||||
# the model -- into a single JIT region. This ensures things run as fast as
|
||||
# possible.
|
||||
@eqx.filter_jit
|
||||
def make_step(
|
||||
model: CNN,
|
||||
opt_state: PyTree,
|
||||
x: Float[Array, "batch 1 28 28"],
|
||||
y: Int[Array, " batch"],
|
||||
):
|
||||
loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
|
||||
updates, opt_state = optim.update(grads, opt_state, model)
|
||||
model = eqx.apply_updates(model, updates)
|
||||
return model, opt_state, loss_value
|
||||
|
||||
# Loop over our training dataset as many times as we need.
|
||||
def infinite_trainloader():
|
||||
while True:
|
||||
yield from trainloader
|
||||
|
||||
for step, (x, y) in zip(range(steps), infinite_trainloader()):
|
||||
# PyTorch dataloaders give PyTorch tensors by default,
|
||||
# so convert them to NumPy arrays.
|
||||
x = x.numpy()
|
||||
y = y.numpy()
|
||||
model, opt_state, train_loss = make_step(model, opt_state, x, y)
|
||||
if (step % print_every) == 0 or (step == steps - 1):
|
||||
test_loss, test_accuracy = evaluate(model, testloader)
|
||||
print(
|
||||
f"{step=}, train_loss={train_loss.item()}, "
|
||||
f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
# %%
|
||||
model = train(model, trainloader, testloader, optim, STEPS, PRINT_EVERY)
|
||||
|
||||
# %%
|
|
@ -0,0 +1,591 @@
|
|||
# %%
|
||||
# package imports from equinox BERT example
|
||||
import functools
|
||||
from typing import Dict, List, Mapping, Optional, Callable, Tuple
|
||||
|
||||
# import einops # https://github.com/arogozhnikov/einops
|
||||
import equinox as eqx
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import optax # https://github.com/deepmind/optax
|
||||
from datasets import load_dataset # https://github.com/huggingface/datasets
|
||||
from jaxtyping import Array, Float, Int # https://github.com/google/jaxtyping
|
||||
from tqdm import notebook as tqdm # https://github.com/tqdm/tqdm
|
||||
from transformers import AutoTokenizer # https://github.com/huggingface/transformers
|
||||
from ml_collections import ConfigDict, FrozenConfigDict
|
||||
|
||||
# helper functions for attention computation
|
||||
# they are implemented with jax w/o flax
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
|
||||
import flax.linen as nn
|
||||
# %%
|
||||
class T5LayerNorm(eqx.Module):
|
||||
eps: float = 1e-6
|
||||
weight: jax.Array
|
||||
# staticmethod forces the method to be by itself
|
||||
weight_init: Callable[..., np.ndarray] = staticmethod(jax.nn.initializers.ones)
|
||||
|
||||
def __init__(
|
||||
self: eqx.Module,
|
||||
hidden_size: int,
|
||||
key: jax.random.PRNGKey,
|
||||
# dtype: jnp.dtype = jnp.float32,
|
||||
):
|
||||
# self.dtype = dtype
|
||||
# self.params = {
|
||||
# 'weight': self.weight_init(key, (hidden_size,), dtype)
|
||||
# }
|
||||
# force the use of float32
|
||||
# note that the key argument is ignored, so key is actually optional
|
||||
self.weight = self.weight_init(key, (hidden_size,), jnp.float32)
|
||||
|
||||
# takes in argument for hidden states so that it can fall through and remain
|
||||
# a pure function
|
||||
def __call__(self, hidden_states):
|
||||
"""
|
||||
Construct a layernorm module in the T5 style;
|
||||
No bias and no subtraction of mean
|
||||
"""
|
||||
# always compute in float32 for layer norm
|
||||
variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True)
|
||||
hidden_states = hidden_states / jnp.sqrt(variance + self.eps)
|
||||
|
||||
return self.weight * hidden_states
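# In other words this is RMSNorm: scale by 1 / sqrt(mean(x**2) + eps) and multiply by a
# learned weight, with no mean subtraction and no bias (hence the "T5 style").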
|
||||
|
||||
# # %%
|
||||
# # testing T5LayerNorm
|
||||
# key = jax.random.PRNGKey(0)
|
||||
# hidden_size = 128 # Example hidden size
|
||||
# layer_norm = T5LayerNorm(key=key, hidden_size=hidden_size)
|
||||
# # Create some example input data
|
||||
# hidden_states = jnp.ones((1, 10, hidden_size)) # Batch size of 1, sequence length of 10
|
||||
# # Forward pass
|
||||
# output = layer_norm(hidden_states)
|
||||
# print("Output shape:", output.shape)
|
||||
|
||||
# %%
|
||||
class KaimingLinear(eqx.Module):
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
weights: jax.Array
|
||||
|
||||
|
||||
def __init__(
|
||||
self: eqx.Module,
|
||||
key: jax.random.PRNGKey,
|
||||
input_dim: int,
|
||||
output_dim: int,
|
||||
weights_init_std: float,
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
):
|
||||
self.dtype = dtype
|
||||
|
||||
# the initialization strategy is to standardize on output dimension
|
||||
# shapes are: (input_dim, output_dim)
|
||||
self.weights = jax.random.normal(key, (input_dim, output_dim), dtype=dtype) * weights_init_std
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: Float[Array, " input"],
|
||||
):
|
||||
hidden = jnp.dot(inputs, self.weights)
|
||||
return hidden
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
# this function fortunately supports batched operations by default due to broadcasting
|
||||
class T5DenseActDense(eqx.Module):
|
||||
config: FrozenConfigDict
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
wi: KaimingLinear
wo: KaimingLinear
dropout: eqx.nn.Dropout
act: Callable
|
||||
|
||||
def __init__(
|
||||
self: eqx.Module,
|
||||
config: FrozenConfigDict,
|
||||
dtype: jnp.dtype,
|
||||
key: jax.random.PRNGKey
|
||||
):
|
||||
self.config = config
|
||||
self.dtype = dtype
|
||||
|
||||
mlp_key, output_key = jax.random.split(key)
|
||||
# the initialization strategy is to standardize on output dimension
|
||||
# input
|
||||
wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
|
||||
# shapes are: (config.d_model, config.d_ff)
|
||||
# self.wi = jax.random.normal(mlp_key, (self.config.d_model, self.config.d_ff)) * wi_init_std
|
||||
self.wi = KaimingLinear(
|
||||
key=mlp_key,
|
||||
input_dim=self.config.d_model,
|
||||
output_dim=self.config.d_ff,
|
||||
weights_init_std=wi_init_std,
|
||||
dtype=self.dtype
|
||||
)
|
||||
# output
|
||||
wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)
|
||||
# shapes are: (config.d_ff, config.d_model)
|
||||
# self.wo = jax.random.normal(output_key, (self.config.d_ff, self.config.d_model)) * wo_init_std
|
||||
self.wo = KaimingLinear(
|
||||
key=mlp_key,
|
||||
input_dim=self.config.d_ff,
|
||||
output_dim=self.config.d_model,
|
||||
weights_init_std=wo_init_std,
|
||||
dtype=self.dtype
|
||||
)
|
||||
|
||||
|
||||
self.dropout = eqx.nn.Dropout(self.config.dropout_rate)
|
||||
# just set to relu for now since the smaller T5's use relu
|
||||
self.act = jax.nn.relu
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: Float[Array, " d_model"],
|
||||
enable_dropout: bool = False,
|
||||
dropout_key: Optional[jax.random.PRNGKey] = None,
|
||||
) -> Float[Array, " d_model"]:
|
||||
hidden = self.wi(inputs)
|
||||
# hidden = jnp.dot(inputs, self.wi)
|
||||
hidden = self.act(hidden)
|
||||
hidden = self.dropout(hidden, inference=not enable_dropout, key=dropout_key)
|
||||
hidden = self.wo(hidden)
|
||||
# hidden = jnp.dot(hidden, self.wo)
|
||||
return hidden
|
||||
|
||||
|
||||
# # %%
|
||||
# # test for T5DenseActDense
|
||||
# # create fake config
|
||||
# config_dict = {
|
||||
# 'd_model': 768,
|
||||
# 'd_ff': 2048,
|
||||
# 'dropout_rate': 0.1,
|
||||
# 'initializer_factor': 1.0,
|
||||
# }
|
||||
# # Create a FrozenDict from the standard dictionary
|
||||
# frozen_config = FrozenConfigDict(config_dict)
|
||||
# # initialize model
|
||||
# key = jax.random.PRNGKey(0)
|
||||
# dense = T5DenseActDense(
|
||||
# key=key,
|
||||
# config=frozen_config,
|
||||
# dtype=jnp.float32
|
||||
# )
|
||||
# input_key, key = jax.random.split(key)
|
||||
# inputs = jax.random.normal(input_key, (10, frozen_config.d_model)) # Generate random normal values
|
||||
# dropout_key, key = jax.random.split(key)
|
||||
# output = dense(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
|
||||
# output.shape
|
||||
|
||||
# %%
|
||||
class T5LayerFF(eqx.Module):
|
||||
config: FrozenConfigDict
|
||||
dtype: jnp.dtype
|
||||
DenseReluDense: T5DenseActDense
|
||||
layer_norm: T5LayerNorm
|
||||
dropout: eqx.nn.Dropout
|
||||
|
||||
def __init__(
|
||||
self: eqx.Module,
|
||||
key: jax.random.PRNGKey,
|
||||
config: FrozenConfigDict,
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
):
|
||||
|
||||
self.config = config
|
||||
self.dtype = dtype
|
||||
|
||||
dense_key, key = jax.random.split(key)
|
||||
# args: key, config, dtype
|
||||
self.DenseReluDense = T5DenseActDense(
|
||||
key=dense_key,
|
||||
config=config,
|
||||
dtype=dtype
|
||||
)
|
||||
layer_key, key = jax.random.split(key)
|
||||
# args: key, hidden_size
|
||||
self.layer_norm = T5LayerNorm(
|
||||
key=layer_key,
|
||||
hidden_size=self.config.d_model
|
||||
)
|
||||
# args: dropout_rate
|
||||
self.dropout = eqx.nn.Dropout(self.config.dropout_rate)
|
||||
|
||||
def __call__(
|
||||
self: eqx.Module,
|
||||
inputs: Float[Array, " d_model"],
|
||||
enable_dropout: bool =False,
|
||||
dropout_key: Optional[jax.random.PRNGKey] = None,
|
||||
):
|
||||
forwarded_states = self.layer_norm(inputs)
|
||||
dropout_key, key = jax.random.split(dropout_key)
|
||||
forwarded_states = self.DenseReluDense(
|
||||
inputs=forwarded_states,
|
||||
enable_dropout=enable_dropout,
|
||||
dropout_key=dropout_key
|
||||
)
|
||||
dropout_key, key = jax.random.split(key)
|
||||
dropout_states = self.dropout(
|
||||
x = forwarded_states,
|
||||
inference=not enable_dropout,
|
||||
key = dropout_key,
|
||||
)
|
||||
hidden = inputs + dropout_states
|
||||
return hidden
|
||||
|
||||
# # %%
|
||||
# # test for T5DenseActDense
|
||||
# # create fake config
|
||||
# config_dict = {
|
||||
# 'd_model': 768,
|
||||
# 'd_ff': 2048,
|
||||
# 'dropout_rate': 0.1,
|
||||
# 'initializer_factor': 1.0,
|
||||
# }
|
||||
# # Create a FrozenDict from the standard dictionary
|
||||
# frozen_config = FrozenConfigDict(config_dict)
|
||||
# # initialize model
|
||||
# key = jax.random.PRNGKey(0)
|
||||
# ff_layer = T5LayerFF(
|
||||
# key=key,
|
||||
# config=frozen_config,
|
||||
# dtype=jnp.float32
|
||||
# )
|
||||
# input_key, key = jax.random.split(key)
|
||||
# inputs = jax.random.normal(input_key, (10, frozen_config.d_model)) # Generate random normal values
|
||||
# dropout_key, key = jax.random.split(key)
|
||||
# output = ff_layer(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
|
||||
# output.shape
|
||||
|
||||
# %%
|
||||
class T5Attention(eqx.Module):
|
||||
config: FrozenConfigDict
|
||||
has_relative_attention_bias: bool = False
|
||||
causal: bool = False # False for encoder, True for decoder
|
||||
dtype: jnp.dtype
|
||||
|
||||
# parameters
|
||||
q: KaimingLinear
k: KaimingLinear
v: KaimingLinear
o: KaimingLinear
|
||||
|
||||
|
||||
# additional terms
|
||||
relative_attention_num_buckets: int
|
||||
relative_attention_max_distance: int
|
||||
d_model: int
|
||||
key_value_proj_dim: int
|
||||
n_heads: int
|
||||
dropout: float
|
||||
inner_dim: int
|
||||
initializer_factor: float
|
||||
|
||||
|
||||
def __init__(
|
||||
self: eqx.Module,
|
||||
config: FrozenConfigDict,
|
||||
dtype: jnp.dtype,
|
||||
key: jax.random.PRNGKey,
|
||||
):
|
||||
self.config = config
|
||||
self.dtype = dtype
|
||||
self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
|
||||
self.relative_attention_max_distance = self.config.relative_attention_max_distance
|
||||
self.d_model = self.config.d_model
|
||||
# size of k,v projection for each head
|
||||
self.key_value_proj_dim = self.config.d_kv
|
||||
self.n_heads = self.config.num_heads
|
||||
self.dropout = self.config.dropout_rate
|
||||
self.inner_dim = self.n_heads * self.key_value_proj_dim
|
||||
self.initializer_factor = self.config.initializer_factor
|
||||
|
||||
q_init_std = self.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
|
||||
kv_init_std = self.initializer_factor * (self.inner_dim**-0.5)
|
||||
o_init_std = self.initializer_factor * (self.inner_dim**-0.5)
|
||||
|
||||
q_key, key = jax.random.split(key)
|
||||
self.q = KaimingLinear(
|
||||
key=q_key,
|
||||
input_dim=(self.inner_dim),
|
||||
output_dim=self.inner_dim,
|
||||
weights_init_std=q_init_std,
|
||||
dtype=self.dtype
|
||||
)
|
||||
|
||||
k_key, key = jax.random.split(key)
|
||||
self.k = KaimingLinear(
|
||||
key=k_key,
|
||||
input_dim=self.inner_dim,
|
||||
output_dim=self.inner_dim,
|
||||
weights_init_std=kv_init_std,
|
||||
dtype=self.dtype
|
||||
)
|
||||
|
||||
v_key, key = jax.random.split(key)
|
||||
self.v = KaimingLinear(
|
||||
key=v_key,
|
||||
input_dim=self.inner_dim,
|
||||
output_dim=self.inner_dim,
|
||||
weights_init_std=kv_init_std,
|
||||
dtype=self.dtype
|
||||
)
|
||||
|
||||
o_key, key = jax.random.split(key)
|
||||
self.o = KaimingLinear(
|
||||
key=o_key,
|
||||
input_dim=self.inner_dim,
|
||||
output_dim=self.d_model,
|
||||
weights_init_std=o_init_std,
|
||||
dtype=self.dtype
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _relative_position_bucket(
|
||||
relative_position,
|
||||
bidirectional=True,
|
||||
num_buckets=32,
|
||||
max_distance=128
|
||||
):
|
||||
"""
|
||||
Adapted from Mesh Tensorflow:
|
||||
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
|
||||
|
||||
Translate relative position to a bucket number for relative attention. The relative position is defined as
|
||||
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
|
||||
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
|
||||
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
|
||||
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
|
||||
This should allow for more graceful generalization to longer sequences than the model has been trained on
|
||||
"""
|
||||
relative_buckets = 0
|
||||
|
||||
# bidirection determines if positive relative positions are valid
|
||||
if bidirectional:
|
||||
num_buckets //= 2
|
||||
relative_buckets += (relative_position > 0) * num_buckets
|
||||
relative_position = jnp.abs(relative_position)
|
||||
else:
|
||||
# relative position range of [0, inf]
|
||||
relative_position = -jnp.clip(relative_position, a_max=0)
|
||||
|
||||
# half of buckets are for exact increments in positions
|
||||
max_exact = num_buckets // 2
|
||||
# boolean to assign relative buckets later
|
||||
is_small = relative_position < max_exact
|
||||
|
||||
# other half are for logarithmically bigger bins in positions up to max_distance
|
||||
relative_position_if_large = max_exact + (
|
||||
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
|
||||
)
|
||||
|
||||
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
|
||||
|
||||
# jnp.where(condition, x, y), true->x, false->y
|
||||
# in-place cumulative summation
|
||||
# yields a list where every element has the correct relative bucket position
|
||||
# whether its small or large
|
||||
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
|
||||
|
||||
return relative_buckets.astype("i4")
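# Rough illustration (not from the original, for intuition only): for a short sequence,
#   rel = jnp.arange(6)[None, :] - jnp.arange(6)[:, None]
#   buckets = T5Attention._relative_position_bucket(rel, bidirectional=True, num_buckets=32)
# nearby offsets each get their own bucket, distant offsets share log-spaced buckets,
# and with bidirectional=True positive and negative offsets land in disjoint halves.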
|
||||
|
||||
# bias gives weight based on relative distance aside from attention score
|
||||
def compute_bias(self, query_length, key_length):
|
||||
"""
|
||||
Compute binned relative position bias
|
||||
"""
|
||||
# arange in the first dim
|
||||
context_position = jnp.arange(query_length, dtype="i4")[:, None]
|
||||
# arange in the second dim
|
||||
memory_position = jnp.arange(key_length, dtype="i4")[None, :]
|
||||
|
||||
# The relative position is defined as memory_position - query_position,
|
||||
# i.e. the distance in tokens from the attending position to the
|
||||
# attended-to position.
|
||||
#
|
||||
# 2D array where each entry represents the distance from a query token
|
||||
# to a key token
|
||||
relative_position = memory_position - context_position
|
||||
# now we apply the earlier bucket creation function
|
||||
relative_position_bucket = self._relative_position_bucket(
|
||||
relative_position=relative_position,
|
||||
bidirectional=(not self.causal), # causal during decode -> not bi
|
||||
num_buckets=self.relative_attention_num_buckets,
|
||||
max_distance=self.relative_attention_max_distance,
|
||||
)
|
||||
# retrieve the bias values
|
||||
# shape (query_length, key_length, n_heads)
|
||||
values = self.relative_attention_bias(relative_position_bucket)
|
||||
# shape (1, n_heads, query_length, key_length)
|
||||
# ready for attention
|
||||
values = values.transpose((2, 0, 1))[None, :, :, :]
|
||||
return values
|
||||
|
||||
|
||||
# from (batch_size, seq_length, d_model) to
|
||||
# (batch_size, seq_length, n_heads, head_dim)
|
||||
def _split_heads(self, hidden_states):
|
||||
return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))
|
||||
|
||||
# from (batch_size, seq_length, n_heads, head_dim) to
|
||||
# (batch_size, seq_length, d_model)
|
||||
def _merge_heads(self, hidden_states):
|
||||
return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,))
|
||||
|
||||
|
||||
def _create_position_bias(
|
||||
self,
|
||||
key_states,
|
||||
query_states,
|
||||
attention_mask,
|
||||
):
|
||||
# unlike the flax version, we don't even check for cache
|
||||
key_length = key_states.shape[1]
|
||||
query_length = query_states.shape[1]
|
||||
|
||||
if self.has_relative_attention_bias:
|
||||
position_bias = self.compute_bias(query_length, key_length)
|
||||
elif attention_mask is not None:
|
||||
position_bias = jnp.zeros_like(attention_mask)
|
||||
else:
|
||||
position_bias = jnp.zeros(
|
||||
(1, self.n_heads, query_length, key_length),
|
||||
dtype=self.dtype
|
||||
)
|
||||
|
||||
return position_bias
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs,
|
||||
attention_mask=None,
|
||||
key_value_states=None,
|
||||
position_bias=None,
|
||||
output_attentions=False,
|
||||
enable_dropout=False,
|
||||
dropout_key: Optional[jax.random.PRNGKey] = None,
|
||||
):
|
||||
# expected input shape: (batch_size, seq_len, d_model)
|
||||
# expected output: tuple of 2 arrays same shape as input
|
||||
# (attn, position_bias)
|
||||
batch_size, seq_length = inputs.shape[:2]
|
||||
|
||||
# q,k,v projections
|
||||
# (batch_size, n_heads, seq_length, dim_per_head)
|
||||
query_states = self.q(inputs)
|
||||
key_states = (
|
||||
self.k(inputs) if key_value_states is None else self.k(key_value_states)
|
||||
)
|
||||
value_states = (
|
||||
self.v(inputs) if key_value_states is None else self.v(key_value_states)
|
||||
)
|
||||
|
||||
# reshape to (batch_size, seq_length, n_heads, head_dim)
|
||||
query_states = self._split_heads(query_states)
|
||||
key_states = self._split_heads(key_states)
|
||||
value_states = self._split_heads(value_states)
|
||||
|
||||
# counteract scaling in dot_product_attention_weights function
|
||||
query_states *= jnp.sqrt(query_states.shape[-1])
|
||||
|
||||
# create causal attention_mask
|
||||
if self.causal:
|
||||
causal_attention_mask = make_causal_mask(attention_mask, dtype="bool")
|
||||
|
||||
# broadcast causal attention mask & attention mask to fit for merge
|
||||
causal_attention_mask = jnp.broadcast_to(
|
||||
causal_attention_mask,
|
||||
(batch_size,) + causal_attention_mask.shape[1:]
|
||||
)
|
||||
attention_mask = jnp.broadcast_to(
|
||||
jnp.expand_dims(attention_mask, axis=(-3, -2)),
|
||||
causal_attention_mask.shape
|
||||
)
|
||||
attention_mask = combine_masks(attention_mask, causal_attention_mask)
|
||||
elif attention_mask is not None:
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
|
||||
# replace masked positions with -10_000
|
||||
if attention_mask is not None:
|
||||
mask_value = jnp.finfo(self.dtype).min
|
||||
attention_mask = jax.lax.select(
|
||||
attention_mask > 0,
|
||||
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
||||
jnp.full(attention_mask.shape, mask_value).astype(self.dtype),
|
||||
)
|
||||
|
||||
if position_bias is None:
|
||||
# compute position bias (only for first layer)
|
||||
position_bias = self._create_position_bias(
|
||||
key_states, query_states, attention_mask
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
position_bias = position_bias + attention_mask
|
||||
|
||||
# Softmax(QK^T)
|
||||
attn_weights = dot_product_attention_weights(
|
||||
query_states,
|
||||
key_states,
|
||||
bias=position_bias,
|
||||
dropout_rng=dropout_key,
|
||||
dropout_rate=self.dropout,
|
||||
broadcast_dropout=True,
|
||||
deterministic=not enable_dropout,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
# multiply with value states
|
||||
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
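# shapes: attn_weights (..., n_heads, q_len, k_len) x value_states (..., k_len, n_heads, head_dim)
#         -> attn_output (..., q_len, n_heads, head_dim)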
|
||||
|
||||
# bring back to (batch_size, seq_length, d_model)
|
||||
attn_output = self._merge_heads(attn_output)
|
||||
|
||||
# apply output matrix
|
||||
attn_output = self.o(attn_output)
|
||||
|
||||
outputs = (attn_output, position_bias)
|
||||
|
||||
if output_attentions:
|
||||
outputs = outputs + (attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
# # %%
|
||||
# # test for T5Attention
|
||||
# # create fake config
|
||||
# config_dict = {
|
||||
# 'relative_attention_num_buckets': 32,
|
||||
# 'relative_attention_max_distance': 128,
|
||||
# 'd_model': 768, # 64 * 12
|
||||
# 'd_kv': 64,
|
||||
# 'num_heads': 12,
|
||||
# 'dropout_rate': 0.1,
|
||||
# 'initializer_factor': 1.0,
|
||||
# }
|
||||
# # Create a FrozenDict from the standard dictionary
|
||||
# frozen_config = FrozenConfigDict(config_dict)
|
||||
# # initialize model
|
||||
# key = jax.random.PRNGKey(0)
|
||||
# attn_layer = T5Attention(
|
||||
# key=key,
|
||||
# config=frozen_config,
|
||||
# dtype=jnp.float32
|
||||
# )
|
||||
# input_key, key = jax.random.split(key)
|
||||
# # inputs = jax.random.normal(input_key, (10, frozen_config.d_model)) # Generate random normal values
|
||||
# batch_size = 1
|
||||
# seq_length = 10
|
||||
# inputs = jnp.ones((batch_size, seq_length, frozen_config.d_model))
|
||||
# dropout_key, key = jax.random.split(key)
|
||||
# output = attn_layer(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
|
||||
# print(len(output))
|
||||
# print(output[0].shape)
|
||||
|
||||
# %%
|
|
@ -0,0 +1,495 @@
|
|||
# %%
|
||||
# package imports from equinox BERT example
|
||||
import functools
|
||||
from typing import Dict, List, Mapping, Optional, Callable, Tuple
|
||||
|
||||
# import einops # https://github.com/arogozhnikov/einops
|
||||
import equinox as eqx
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import optax # https://github.com/deepmind/optax
|
||||
from datasets import load_dataset # https://github.com/huggingface/datasets
|
||||
from jaxtyping import Array, Float, Int # https://github.com/google/jaxtyping
|
||||
from tqdm import notebook as tqdm # https://github.com/tqdm/tqdm
|
||||
from transformers import AutoTokenizer # https://github.com/huggingface/transformers
|
||||
from ml_collections import ConfigDict, FrozenConfigDict
|
||||
|
||||
import flax.linen as nn
|
||||
# %%
|
||||
class T5LayerNorm(eqx.Module):
|
||||
eps: float = 1e-6
|
||||
weight: jax.Array
|
||||
# staticmethod forces the method to be by itself
|
||||
weight_init: Callable[..., np.ndarray] = staticmethod(jax.nn.initializers.ones)
|
||||
|
||||
def __init__(
|
||||
self: eqx.Module,
|
||||
key: jax.random.PRNGKey,
|
||||
hidden_size: int,
|
||||
# dtype: jnp.dtype = jnp.float32,
|
||||
):
|
||||
# self.dtype = dtype
|
||||
# self.params = {
|
||||
# 'weight': self.weight_init(key, (hidden_size,), dtype)
|
||||
# }
|
||||
# force the use of float32
|
||||
# note that the key argument is ignored, so key is actually optional
|
||||
self.weight = self.weight_init(key, (hidden_size,), jnp.float32)
|
||||
|
||||
# takes in argument for hidden states so that it can fall through and remain
|
||||
# a pure function
|
||||
def __call__(self, hidden_states):
|
||||
"""
|
||||
Construct a layernorm module in the T5 style;
|
||||
No bias and no subtraction of mean
|
||||
"""
|
||||
# always compute in float32 for layer norm
|
||||
variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True)
|
||||
hidden_states = hidden_states / jnp.sqrt(variance + self.eps)
|
||||
|
||||
return self.weight * hidden_states
|
||||
|
||||
# # %%
|
||||
# # testing T5LayerNorm
|
||||
# key = jax.random.PRNGKey(0)
|
||||
# hidden_size = 128 # Example hidden size
|
||||
# layer_norm = T5LayerNorm(key=key, hidden_size=hidden_size)
|
||||
# # Create some example input data
|
||||
# hidden_states = jnp.ones((1, 10, hidden_size)) # Batch size of 1, sequence length of 10
|
||||
# # Forward pass
|
||||
# output = layer_norm(hidden_states)
|
||||
# print("Output shape:", output.shape)
|
||||
|
||||
# %%
|
||||
class KaimingLinear(eqx.Module):
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
weights: jax.Array
|
||||
|
||||
|
||||
def __init__(
|
||||
self: eqx.Module,
|
||||
key: jax.random.PRNGKey,
|
||||
input_dim: int,
|
||||
output_dim: int,
|
||||
initializer_factor: float,
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
):
|
||||
self.dtype = dtype
|
||||
|
||||
# the initialization strategy is to standardize on output dimension
|
||||
# input
|
||||
weights_init_std = initializer_factor * (input_dim**-0.5)
|
||||
# shapes are: (input_dim, output_dim)
|
||||
self.weights = jax.random.normal(key, (input_dim, output_dim), dtype=dtype) * weights_init_std
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: Float[Array, " input"],
|
||||
):
|
||||
hidden = jnp.dot(inputs, self.weights)
|
||||
return hidden
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
# this function fortunately supports batched operations by default due to broadcasting
|
||||
class T5DenseActDense(eqx.Module):
|
||||
config: FrozenConfigDict
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
wi: KaimingLinear
wo: KaimingLinear
dropout: eqx.nn.Dropout
act: Callable
|
||||
|
||||
def __init__(
|
||||
self: eqx.Module,
|
||||
key: jax.random.PRNGKey,
|
||||
config: FrozenConfigDict,
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
):
|
||||
self.config = config
|
||||
self.dtype = dtype
|
||||
|
||||
mlp_key, output_key = jax.random.split(key)
|
||||
# the initialization strategy is to standardize on output dimension
|
||||
# input
|
||||
# wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
|
||||
# shapes are: (config.d_model, config.d_ff)
|
||||
# self.wi = jax.random.normal(mlp_key, (self.config.d_model, self.config.d_ff)) * wi_init_std
|
||||
self.wi = KaimingLinear(
|
||||
key=mlp_key,
|
||||
input_dim=self.config.d_model,
|
||||
output_dim=self.config.d_ff,
|
||||
initializer_factor=self.config.initializer_factor,
|
||||
dtype=self.dtype
|
||||
)
|
||||
# output
|
||||
# wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)
|
||||
# shapes are: (config.d_ff, config.d_model)
|
||||
# self.wo = jax.random.normal(output_key, (self.config.d_ff, self.config.d_model)) * wo_init_std
|
||||
self.wo = KaimingLinear(
|
||||
key=mlp_key,
|
||||
input_dim=self.config.d_ff,
|
||||
output_dim=self.config.d_model,
|
||||
initializer_factor=self.config.initializer_factor,
|
||||
dtype=self.dtype
|
||||
)
|
||||
|
||||
|
||||
self.dropout = eqx.nn.Dropout(self.config.dropout_rate)
|
||||
# just set to relu for now since the smaller T5's use relu
|
||||
self.act = jax.nn.relu
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: Float[Array, " d_model"],
|
||||
enable_dropout: bool = False,
|
||||
dropout_key: Optional[jax.random.PRNGKey] = None,
|
||||
) -> Float[Array, " d_model"]:
|
||||
hidden = self.wi(inputs)
|
||||
# hidden = jnp.dot(inputs, self.wi)
|
||||
hidden = self.act(hidden)
|
||||
hidden = self.dropout(hidden, inference=not enable_dropout, key=dropout_key)
|
||||
hidden = self.wo(hidden)
|
||||
# hidden = jnp.dot(hidden, self.wo)
|
||||
return hidden
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# # %%
|
||||
# # test for T5DenseActDense
|
||||
# # create fake config
|
||||
# config_dict = {
|
||||
# 'd_model': 768,
|
||||
# 'd_ff': 2048,
|
||||
# 'dropout_rate': 0.1,
|
||||
# 'initializer_factor': 1.0,
|
||||
# }
|
||||
# # Create a FrozenDict from the standard dictionary
|
||||
# frozen_config = FrozenConfigDict(config_dict)
|
||||
# # initialize model
|
||||
# key = jax.random.PRNGKey(0)
|
||||
# dense = T5DenseActDense(
|
||||
# key=key,
|
||||
# config=frozen_config,
|
||||
# dtype=jnp.float32
|
||||
# )
|
||||
# input_key, key = jax.random.split(key)
|
||||
# inputs = jax.random.normal(input_key, (10, frozen_config.d_model)) # Generate random normal values
|
||||
# dropout_key, key = jax.random.split(key)
|
||||
# output = dense(inputs=inputs, enable_dropout=False, key=dropout_key)
|
||||
# output.shape
|
||||
|
||||
# %%
|
||||
class T5LayerFF(eqx.Module):
|
||||
config: FrozenConfigDict
|
||||
dtype: jnp.dtype
|
||||
DenseReluDense: T5DenseActDense
|
||||
layer_norm: T5LayerNorm
|
||||
dropout: eqx.nn.Dropout
|
||||
|
||||
def __init__(
|
||||
self: eqx.Module,
|
||||
key: jax.random.PRNGKey,
|
||||
config: FrozenConfigDict,
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
):
|
||||
|
||||
self.config = config
|
||||
self.dtype = dtype
|
||||
|
||||
dense_key, key = jax.random.split(key)
|
||||
# args: key, config, dtype
|
||||
self.DenseReluDense = T5DenseActDense(
|
||||
key=dense_key,
|
||||
config=config,
|
||||
dtype=dtype
|
||||
)
|
||||
layer_key, key = jax.random.split(key)
|
||||
# args: key, hidden_size
|
||||
self.layer_norm = T5LayerNorm(
|
||||
key=layer_key,
|
||||
hidden_size=self.config.d_model
|
||||
)
|
||||
# args: dropout_rate
|
||||
self.dropout = eqx.nn.Dropout(self.config.dropout_rate)
|
||||
|
||||
def __call__(
|
||||
self: eqx.Module,
|
||||
inputs: Float[Array, " d_model"],
|
||||
enable_dropout: bool =False,
|
||||
dropout_key: Optional[jax.random.PRNGKey] = None,
|
||||
):
|
||||
forwarded_states = self.layer_norm(inputs)
|
||||
dropout_key, key = jax.random.split(dropout_key)
|
||||
forwarded_states = self.DenseReluDense(
|
||||
inputs=forwarded_states,
|
||||
enable_dropout=enable_dropout,
|
||||
dropout_key=dropout_key
|
||||
)
|
||||
dropout_key, key = jax.random.split(key)
|
||||
dropout_states = self.dropout(
|
||||
x = forwarded_states,
|
||||
key = dropout_key,
|
||||
inference=not enable_dropout
|
||||
)
|
||||
hidden = inputs + dropout_states
|
||||
return hidden
|
||||
|
||||
# # %%
|
||||
# # test for T5DenseActDense
|
||||
# # create fake config
|
||||
# config_dict = {
|
||||
# 'd_model': 768,
|
||||
# 'd_ff': 2048,
|
||||
# 'dropout_rate': 0.1,
|
||||
# 'initializer_factor': 1.0,
|
||||
# }
|
||||
# # Create a FrozenDict from the standard dictionary
|
||||
# frozen_config = FrozenConfigDict(config_dict)
|
||||
# # initialize model
|
||||
# key = jax.random.PRNGKey(0)
|
||||
# ff_layer = T5LayerFF(
|
||||
# key=key,
|
||||
# config=frozen_config,
|
||||
# dtype=jnp.float32
|
||||
# )
|
||||
# input_key, key = jax.random.split(key)
|
||||
# inputs = jax.random.normal(input_key, (10, frozen_config.d_model)) # Generate random normal values
|
||||
# dropout_key, key = jax.random.split(key)
|
||||
# output = ff_layer(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
|
||||
# output.shape
|
||||
|
||||
# %%
|
||||
class T5Attention(eqx.Module):
|
||||
config: FrozenConfigDict
|
||||
has_relative_attention_bias: bool = False
|
||||
causal: bool = False # False for encoder, True for decoder
|
||||
dtype: jnp.dtype
|
||||
|
||||
# additional terms
|
||||
relative_attention_num_buckets: int
|
||||
relative_attention_max_distance: int
|
||||
d_model: int
|
||||
key_value_proj_dim: int
|
||||
n_heads: int
|
||||
dropout: float
|
||||
inner_dim: int
|
||||
initializer_factor: float

# parameter sub-modules (must be declared as fields so that eqx.Module
# allows the assignments made in __init__ below)
q: KaimingLinear
k: KaimingLinear
v: KaimingLinear
o: KaimingLinear
relative_attention_bias: Optional[eqx.nn.Embedding]
|
||||
|
||||
|
||||
def __init__(
|
||||
self: eqx.Module,
|
||||
key: jax.random.PRNGKey,
|
||||
config: FrozenConfigDict,
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
):
|
||||
self.config = config
self.dtype = dtype
self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
|
||||
self.relative_attention_max_distance = self.config.relative_attention_max_distance
|
||||
self.d_model = self.config.d_model
|
||||
# size of k,v projection for each head
|
||||
self.key_value_proj_dim = self.config.d_kv
|
||||
self.n_heads = self.config.num_heads
|
||||
self.dropout = self.config.dropout_rate
|
||||
self.inner_dim = self.n_heads * self.key_value_proj_dim
|
||||
self.initializer_factor = self.config.initializer_factor
|
||||
|
||||
|
||||
q_key, key = jax.random.split(key)
|
||||
self.q = KaimingLinear(
|
||||
key=q_key,
|
||||
# the query projects from d_model features; (inner_dim * key_value_proj_dim) only
# matches HF's init-std convention, not the actual input width, and breaks the shapes
input_dim=self.d_model,
|
||||
output_dim=self.inner_dim,
|
||||
initializer_factor=self.initializer_factor,
|
||||
dtype=self.dtype
|
||||
)
|
||||
|
||||
k_key, key = jax.random.split(key)
|
||||
self.k = KaimingLinear(
|
||||
key=k_key,
|
||||
input_dim=self.inner_dim,
|
||||
output_dim=self.inner_dim,
|
||||
initializer_factor=self.initializer_factor,
|
||||
dtype=self.dtype
|
||||
)
|
||||
|
||||
v_key, key = jax.random.split(key)
|
||||
self.v = KaimingLinear(
|
||||
key=v_key,
|
||||
input_dim=self.inner_dim,
|
||||
output_dim=self.inner_dim,
|
||||
initializer_factor=self.initializer_factor,
|
||||
dtype=self.dtype
|
||||
)
|
||||
|
||||
o_key, key = jax.random.split(key)
|
||||
self.o = KaimingLinear(
|
||||
key=o_key,
|
||||
input_dim=self.inner_dim,
|
||||
output_dim=self.d_model,
|
||||
initializer_factor=self.initializer_factor,
|
||||
dtype=self.dtype
|
||||
)
|
||||
|
||||
# 1 bias per head, so output is n_heads
|
||||
# bias is learned during training
|
||||
if self.has_relative_attention_bias:
|
||||
input_dim = self.relative_attention_num_buckets
|
||||
output_dim = self.n_heads
|
||||
initializer_factor=self.initializer_factor
|
||||
# we standardize based on the output dimension,
|
||||
# which is n_head * kv_proj_dim - during multi head attention
|
||||
weights_init_std = initializer_factor * (self.inner_dim**-0.5)
|
||||
# shapes are: (input_dim, output_dim)
|
||||
weights= jax.random.normal(key, (input_dim, output_dim), dtype=self.dtype) * weights_init_std
|
||||
|
||||
self.relative_attention_bias = eqx.nn.Embedding(
weight=weights
)
else:
# the field must still be assigned when no relative bias is used
self.relative_attention_bias = None
|
||||
|
||||
    @staticmethod
    def _relative_position_bucket(
        relative_position,
        bidirectional=True,
        num_buckets=32,
        max_distance=128
    ):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >= max_distance map to the same bucket. All relative positions <= -max_distance map to the same
        bucket. This should allow for more graceful generalization to longer sequences than the model has been
        trained on.
        """
        relative_buckets = 0

        # bidirectional determines whether positive relative positions are valid
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0) * num_buckets
            relative_position = jnp.abs(relative_position)
        else:
            # restrict relative positions to the range [0, inf)
            relative_position = -jnp.clip(relative_position, a_max=0)

        # half of the buckets are for exact increments in position
        max_exact = num_buckets // 2
        # boolean mask used to select the bucket assignment later
        is_small = relative_position < max_exact

        # the other half of the buckets are logarithmically bigger bins covering positions up to max_distance
        relative_position_if_large = max_exact + (
            jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
        )

        relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)

        # jnp.where(condition, x, y): condition true -> x, false -> y
        # accumulate into relative_buckets so every element ends up with its
        # correct bucket index, whether its distance is small or large
        relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)

        return relative_buckets.astype("i4")

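    # Rough illustration of the bucketing above (not part of the original file,
    # values assume the defaults bidirectional=True, num_buckets=32, max_distance=128):
    # half of the 32 buckets are reserved for positive (future) offsets, distances
    # 0..7 get their own exact buckets, and larger distances share log-spaced buckets.
    #
    #   _relative_position_bucket(jnp.array(0))    # -> 0 (same position)
    #   _relative_position_bucket(jnp.array(-3))   # -> 3, a small past offset keeps an exact bucket
    #   _relative_position_bucket(jnp.array(200))  # clipped into the last "future" bucket
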
    # bias gives weight based on relative distance aside from attention score
    def compute_bias(self, query_length, key_length):
        """
        Compute binned relative position bias
        """
        # arange in the first dim
        context_position = jnp.arange(query_length, dtype="i4")[:, None]
        # arange in the second dim
        memory_position = jnp.arange(key_length, dtype="i4")[None, :]

        # The relative position is defined as memory_position - query_position,
        # i.e. the distance in tokens from the attending position to the
        # attended-to position.
        #
        # 2D array where each entry represents the distance from a query token
        # to a key token
        relative_position = memory_position - context_position
        # now we apply the earlier bucket creation function
        relative_position_bucket = self._relative_position_bucket(
            relative_position=relative_position,
            bidirectional=(not self.causal),  # causal during decode -> not bi
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        # retrieve the bias values
        # shape (query_length, key_length, n_heads)
        values = self.relative_attention_bias(relative_position_bucket)
        # shape (1, n_heads, query_length, key_length)
        # ready for attention
        values = values.transpose((2, 0, 1))[None, :, :, :]
        return values

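    # Shape walk-through (illustration only, following the comments above): for
    # query_length=4 and key_length=5 the bucket matrix is (4, 5), the bias
    # lookup gives (4, 5, n_heads), and the final transpose/expand yields
    # (1, n_heads, 4, 5), which can be added directly to the attention scores.
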
    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))

    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,))

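    # Example of the two reshapes above (illustration only, assuming t5-base-like
    # sizes with n_heads=12 and key_value_proj_dim=64): _split_heads maps
    # (batch, seq, 768) to (batch, seq, 12, 64), and _merge_heads maps it back
    # to (batch, seq, 768).
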
    def _create_position_bias(
        self,
        key_states,
        query_states,
        attention_mask,
        init_cache,
        seq_length,
        causal_attention_mask_shift
    ):
        # unlike the flax version, we don't even check for cache
        key_length = key_states.shape[1]
        query_length = query_states.shape[1]

        if self.has_relative_attention_bias:
            position_bias = self.compute_bias(query_length, key_length)
        elif attention_mask is not None:
            position_bias = jnp.zeros_like(attention_mask)
        else:
            position_bias = jnp.zeros(
                (1, self.n_heads, query_length, key_length),
                dtype=self.dtype
            )

        return position_bias

    def __call__(
        self,
        inputs,
        attention_mask=None,
        key_value_states=None,
        position_bias=None,
        output_attentions=False,
        enable_dropout=False
    ):
        batch_size, seq_length = inputs.shape[:2]

        # q, k, v projections
        query_states = self.q(inputs)
        key_states = (
            self.k(inputs) if key_value_states is None else self.k(key_value_states)
        )
        value_states = (
            self.v(inputs) if key_value_states is None else self.v(key_value_states)
        )

        # reshape to (batch_size, seq_length, n_heads, head_dim)
        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # counteract scaling in dot_product_attention_weights function
        # not sure if this is a good idea in equinox
@@ -0,0 +1,279 @@
# %%
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = False

flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    flags += " --xla_force_host_platform_device_count=4"  # Simulate 4 CPU devices
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    os.environ["JAX_PLATFORMS"] = "cpu"
else:
    # GPU flags
    flags = (
        '--xla_gpu_enable_triton_softmax_fusion=true '
        '--xla_gpu_triton_gemm_any=True '
        # '--xla_gpu_enable_async_collectives=true '
        '--xla_gpu_enable_latency_hiding_scheduler=true '
        '--xla_gpu_enable_highest_priority_async_stream=true '
    )
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

os.environ["XLA_FLAGS"] = flags
os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
    "NCCL_LL128_BUFFSIZE": "-2",
    "NCCL_LL_BUFFSIZE": "-2",
    "NCCL_PROTO": "SIMPLE,LL,LL128",
    "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.80",
    # "XLA_PYTHON_CLIENT_PREALLOCATE": "false"
})

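# Optional sanity check (not part of the original commit): once jax is imported
# below, jax.devices() should list either the CPU devices created by
# --xla_force_host_platform_device_count or the GPUs left visible through
# CUDA_VISIBLE_DEVICES, e.g.:
#
#   import jax
#   print(jax.devices())
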
import pandas as pd
import matplotlib.pyplot as plt

from datasets import Dataset

import jax
import jax.numpy as jnp
import optax
import numpy as np
import functools
from typing import Callable, Optional
import math
from jax.sharding import Mesh, NamedSharding
from jax.experimental import mesh_utils

from jax.sharding import PartitionSpec

# set cache
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "bfloat16")

jax.config.update("jax_enable_x64", False)

# from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig

import datasets
from datasets import Dataset
import evaluate
from tqdm import tqdm

import nltk  # Here to have a nice missing dependency error message early on

from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key

from ml_collections import ConfigDict

import time

from parallel.dataload import DataPrepare

# %%
# import data

# load training
data_path = "/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv"
# Ensure to include 'ships_idx' in the fields list
fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']
# Load the dataset
train_df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)

# # load valid
# data_path = "/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/valid.csv"
# # Ensure to include 'ships_idx' in the fields list
# fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']
# # Load the dataset
# validation_df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)

def process_df(df):
    output_list = [{
        'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC>",
        'output': f"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>",
    } for _, row in df.iterrows()]
    return output_list


# takes 1 minute to run without batching
train_dataset = Dataset.from_list(process_df(train_df))


print("preparing data")
data_config = ConfigDict(
    dict(
        max_length=128,
        pad_token_id=0,
        decoder_start_token_id=0
    )
)

dataprep = DataPrepare(train_dataset, data_config)

# %%
# load model
model_name_or_path = "./model_checkpoints/simple"  # Replace with your specific model name
from transformers import FlaxT5ForConditionalGeneration
model = FlaxT5ForConditionalGeneration.from_pretrained(
    model_name_or_path
)
params = model.params

# %%
# load data
SEED = 117
batch_size = 256  # per-device batch size
# train_batch_size multiplies by the device count because we shard across devices later
train_batch_size = batch_size * jax.device_count()
rng = jax.random.PRNGKey(SEED)

# %%
# setup sharding
print("creating mesh")
device_mesh = mesh_utils.create_device_mesh((4, 1))
print(device_mesh)

mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)

def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
    return NamedSharding(mesh, pspec, memory_kind="device")

data_sharding = mesh_sharding(PartitionSpec('data'))  # shard the batch across the 'data' axis
# model_sharding = mesh_sharding(PartitionSpec('model'))
replicate_sharding = mesh_sharding(PartitionSpec())  # fully replicated

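# Hypothetical check (not in the original script): a batch placed with
# jax.device_put(batch, data_sharding) has its leading (batch) dimension split
# across the 4 devices on the 'data' axis, while PartitionSpec() keeps arrays
# replicated on every device. The layout can be inspected with:
#
#   x = jax.device_put(jnp.zeros((train_batch_size, 128)), data_sharding)
#   jax.debug.visualize_array_sharding(x)
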
# %%
# define function to get encodings

def get_encodings(batch, params):
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    input_ids = jnp.reshape(input_ids, (input_ids.shape[-2], input_ids.shape[-1]))
    attention_mask = jnp.reshape(attention_mask, (input_ids.shape[-2], input_ids.shape[-1]))
    encoder_outputs = model.encode(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_attentions=False,
        output_hidden_states=False,
        train=False,
        params=params,
        dropout_rng=None
    )
    # encoder_outputs gives 'last_hidden_state' of shape (batch, seq_len, embed)
    # the embedding is not the full vocabulary embedding size, but the
    # self-attention embed size

    # mask out padded positions by multiplying the encoder outputs with the attention mask
    # attention_mask: (batch, seq_len) -> (batch, seq_len, 1)
    # this lets it broadcast against encoder_outputs
    expanded_attention_mask = jnp.expand_dims(attention_mask, 2)  # (batch, 128, 1)
    # here is an element-wise multiply
    embeddings = encoder_outputs['last_hidden_state'] * expanded_attention_mask  # (batch, 128, 768)
    # summing embeddings in axis 1 sums over the sequence into a (batch, 768)
    # summing the attention mask in axis 1 gives the unmasked token count per entry
    mean_embeddings = embeddings.sum(axis=1) / expanded_attention_mask.sum(axis=1)
    # the shape of mean_embeddings is (batch, embed), we are ready to return
    return mean_embeddings

get_encodings_jit = jax.jit(
    functools.partial(get_encodings, params=params),
    # batch
    in_shardings=(data_sharding),
    out_shardings=replicate_sharding,
)

# # %%
# # test the get_encodings function
#
# rng, input_rng = jax.random.split(rng)
# train_loader = dataprep.data_loader(input_rng, batch_size=train_batch_size, shuffle=False, drop_last=False)
# batch = next(train_loader)
# encodings = get_encodings(batch, params)
# # function works!

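# A minimal sanity check of the masked mean pooling above (illustration only,
# assuming a toy 1-entry batch with seq_len=3 and embed=2): the padded third
# position contributes nothing, so the mean is taken over the 2 real tokens.
#
#   h = jnp.array([[[1., 2.], [3., 4.], [100., 100.]]])   # (1, 3, 2)
#   m = jnp.array([[1, 1, 0]])                            # (1, 3)
#   m_exp = jnp.expand_dims(m, 2)                         # (1, 3, 1)
#   (h * m_exp).sum(axis=1) / m_exp.sum(axis=1)           # -> [[2., 3.]]
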
# %%
# perform actual prediction
encoding_list = []
# note: train_batch_size is batch_size * 4
# we have 4 devices
pred_steps = math.ceil(len(train_dataset) / train_batch_size)
print("***** Running prediction *****")
print(f" Num examples = {len(train_dataset)}")
print(f" Num steps = {pred_steps}")
print(f" Instantaneous batch size per device = {batch_size}")
print(f" Total test batch size (w. parallel & distributed) = {train_batch_size}")

rng, input_rng = jax.random.split(rng)
train_loader = dataprep.data_loader(input_rng, batch_size=train_batch_size, shuffle=False, drop_last=False)

for _ in tqdm(range(pred_steps), desc="Predicting..."):
    batch = next(train_loader)
    batch = jax.device_put(batch, data_sharding)
    encodings = get_encodings_jit(batch)
    encoding_list.extend(jax.device_get(encodings))


# %%
encoding_list = jnp.vstack(encoding_list)
# slice back down to the dataset length to drop entries from the padded final batch
encoding_list = encoding_list[:len(train_dataset)]
print(encoding_list.shape)

# %%
# getting top-k
def top_k_cosine_similarity(M, a, k):
    """
    Find the top-k rows in matrix M that are most cosine similar to array a.

    Args:
        M (jnp.ndarray): Matrix of shape (n, d), where each row is a d-dimensional vector.
        a (jnp.ndarray): Array of shape (d,), the vector to compare to each row of M.
        k (int): Number of top cosine similarities to retrieve.

    Returns:
        values (jnp.ndarray): Top-k cosine similarity values.
        indices (jnp.ndarray): Indices of the top-k most similar rows in M.
    """
    # Step 1: Normalize both M and a
    M_norm = M / jnp.linalg.norm(M, axis=1, keepdims=True)  # Shape: (n, d)
    a_norm = a / jnp.linalg.norm(a)  # Shape: (d,)

    # Step 2: Compute cosine similarity via dot product
    cosine_similarities = jnp.dot(M_norm, a_norm)  # Shape: (n,)

    # Step 3: Get the top-k values and their indices using jax.lax.top_k
    values, indices = jax.lax.top_k(cosine_similarities, k)

    return values, indices

# Example usage:
M = jax.random.normal(jax.random.PRNGKey(0), (100, 128))  # Random matrix with 100 rows of 128 dimensions
a = jax.random.normal(jax.random.PRNGKey(1), (128,))  # Random query vector of 128 dimensions

# Find top 5 most similar rows
top_k_values, top_k_indices = top_k_cosine_similarity(M, a, k=5)

print("Top-k cosine similarity values:", top_k_values)
print("Indices of top-k similar rows:", top_k_indices)

# %%
@@ -0,0 +1 @@
MNIST
@@ -0,0 +1,168 @@
# %%
import tensorflow_datasets as tfds  # TFDS for MNIST
import tensorflow as tf  # TensorFlow operations

tf.random.set_seed(0)  # set random seed for reproducibility

num_epochs = 10
batch_size = 32

train_ds: tf.data.Dataset = tfds.load('mnist', split='train')
test_ds: tf.data.Dataset = tfds.load('mnist', split='test')

train_ds = train_ds.map(
    lambda sample: {
        'image': tf.cast(sample['image'], tf.float32) / 255,
        'label': sample['label'],
    }
)  # normalize train set
test_ds = test_ds.map(
    lambda sample: {
        'image': tf.cast(sample['image'], tf.float32) / 255,
        'label': sample['label'],
    }
)  # normalize test set

# create a shuffled dataset by allocating a buffer of 1024 elements to randomly draw from
train_ds = train_ds.repeat(num_epochs).shuffle(1024)
# group into batches of batch_size, skip the incomplete batch, and prefetch the next batch to improve latency
train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1)
# create a shuffled dataset by allocating a buffer of 1024 elements to randomly draw from
test_ds = test_ds.shuffle(1024)
# group into batches of batch_size, skip the incomplete batch, and prefetch the next batch to improve latency
test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)
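# Quick shape check (illustration only, not in the original file): each batch
# drawn from the pipeline is a dict of arrays once as_numpy_iterator() is used
# below, e.g.
#
#   batch = next(iter(train_ds.as_numpy_iterator()))
#   batch['image'].shape   # (32, 28, 28, 1)
#   batch['label'].shape   # (32,)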
# %%
from flax import nnx  # NNX API
from functools import partial

class CNN(nnx.Module):
    """A simple CNN model."""

    def __init__(self, *, rngs: nnx.Rngs):
        self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
        self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
        self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
        # 3136 = 64 channels * 7 * 7 spatial positions after two 2x2 poolings of a 28x28 input
        self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
        self.linear2 = nnx.Linear(256, 10, rngs=rngs)

    def __call__(self, x):
        x = self.avg_pool(nnx.relu(self.conv1(x)))
        x = self.avg_pool(nnx.relu(self.conv2(x)))
        x = x.reshape(x.shape[0], -1)  # flatten
        x = nnx.relu(self.linear1(x))
        x = self.linear2(x)
        return x


model = CNN(rngs=nnx.Rngs(0))

# %%
nnx.display(model)
# %%
# test the model by feeding an example input
import jax.numpy as jnp  # JAX NumPy

y = model(jnp.ones((1, 28, 28, 1)))
nnx.display(y)
# %%
import optax

learning_rate = 0.005
momentum = 0.9

optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))
metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average('loss'),
)

nnx.display(optimizer)

# %%

def loss_fn(model: CNN, batch):
    logits = model(batch['image'])
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['label']
    ).mean()
    return loss, logits

# %%
@nnx.jit
def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
    """Train for a single step."""
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, batch)
    metrics.update(loss=loss, logits=logits, labels=batch['label'])
    optimizer.update(grads)

# %%
# evaluation step
@nnx.jit
def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):
    loss, logits = loss_fn(model, batch)
    metrics.update(loss=loss, logits=logits, labels=batch['label'])


# %%
# for dataset seed random generation
tf.random.set_seed(0)

# %%
num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs

metrics_history = {
    'train_loss': [],
    'train_accuracy': [],
    'test_loss': [],
    'test_accuracy': [],
}

for step, batch in enumerate(train_ds.as_numpy_iterator()):
    # Run the optimization for one step and make a stateful update to the following:
    # - the train state's model parameters
    # - the optimizer state
    # - the training loss and accuracy batch metrics
    train_step(model, optimizer, metrics, batch)

    if (step + 1) % num_steps_per_epoch == 0:  # one training epoch has passed
        # Log training metrics
        for metric, value in metrics.compute().items():  # compute metrics
            metrics_history[f'train_{metric}'].append(value)  # record metrics
        metrics.reset()  # reset metrics for test set

        # Compute metrics on the test set after each training epoch
        for test_batch in test_ds.as_numpy_iterator():
            eval_step(model, metrics, test_batch)

        # Log test metrics
        for metric, value in metrics.compute().items():
            metrics_history[f'test_{metric}'].append(value)
        metrics.reset()  # reset metrics for next training epoch

        print(
            f"train epoch: {(step+1) // num_steps_per_epoch}, "
            f"loss: {metrics_history['train_loss'][-1]}, "
            f"accuracy: {metrics_history['train_accuracy'][-1] * 100}"
        )
        print(
            f"test epoch: {(step+1) // num_steps_per_epoch}, "
            f"loss: {metrics_history['test_loss'][-1]}, "
            f"accuracy: {metrics_history['test_accuracy'][-1] * 100}"
        )
# %%
# visualize metrics
import matplotlib.pyplot as plt  # Visualization

# Plot loss and accuracy in subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
ax1.set_title('Loss')
ax2.set_title('Accuracy')
for dataset in ('train', 'test'):
    ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
    ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
ax1.legend()
ax2.legend()
plt.show()

# %%