# %%
# package imports from equinox BERT example
import functools
from typing import Dict, List, Mapping, Optional, Callable, Tuple

# import einops  # https://github.com/arogozhnikov/einops
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax  # https://github.com/deepmind/optax
from datasets import load_dataset  # https://github.com/huggingface/datasets
from jaxtyping import Array, Float, Int  # https://github.com/google/jaxtyping
from tqdm import notebook as tqdm  # https://github.com/tqdm/tqdm
from transformers import AutoTokenizer  # https://github.com/huggingface/transformers
from ml_collections import ConfigDict, FrozenConfigDict

# helper functions for attention computation
# (the mask helpers and attention-weight function below are borrowed from flax.linen;
# the modules themselves are implemented with jax/equinox, not flax)
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights

import flax.linen as nn

# %%
class T5LayerNorm(eqx.Module):
    eps: float = 1e-6
    weight: jax.Array
    # staticmethod keeps the initializer from being bound as an instance method
    weight_init: Callable[..., np.ndarray] = staticmethod(jax.nn.initializers.ones)

    def __init__(
        self: eqx.Module,
        hidden_size: int,
        key: jax.random.PRNGKey,
        # dtype: jnp.dtype = jnp.float32,
    ):
        # self.dtype = dtype
        # self.params = {
        #     'weight': self.weight_init(key, (hidden_size,), dtype)
        # }
        # force the use of float32
        # note that the ones initializer ignores the key, so key is effectively optional
        self.weight = self.weight_init(key, (hidden_size,), jnp.float32)

    # takes the hidden states as an argument so that it can fall through and remain
    # a pure function
    def __call__(self, hidden_states):
        """
        Apply a layernorm in the T5 style:
        no bias and no subtraction of the mean.
        """
        # always compute the variance in float32 for layer norm
        variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True)
        hidden_states = hidden_states / jnp.sqrt(variance + self.eps)

        return self.weight * hidden_states


# # %%
# # testing T5LayerNorm
# key = jax.random.PRNGKey(0)
# hidden_size = 128  # Example hidden size
# layer_norm = T5LayerNorm(key=key, hidden_size=hidden_size)

# # Create some example input data
# hidden_states = jnp.ones((1, 10, hidden_size))  # Batch size of 1, sequence length of 10

# # Forward pass
# output = layer_norm(hidden_states)
# print("Output shape:", output.shape)

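# # %%
# # optional sanity check (sketch, not from the original notebook): the T5-style layer norm
# # should match weight * x / sqrt(mean(x**2) + eps), i.e. no mean subtraction and no bias
# x = jax.random.normal(jax.random.PRNGKey(1), (4, hidden_size))
# expected = layer_norm.weight * x / jnp.sqrt(jnp.mean(x**2, axis=-1, keepdims=True) + layer_norm.eps)
# print(jnp.allclose(layer_norm(x), expected, atol=1e-5))
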
# %%
class KaimingLinear(eqx.Module):
    dtype: jnp.dtype = jnp.float32
    weights: jax.Array

    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        input_dim: int,
        output_dim: int,
        weights_init_std: float,
        dtype: jnp.dtype = jnp.float32,
    ):
        self.dtype = dtype

        # the caller supplies weights_init_std; T5 scales by the inverse square root of
        # the input (fan-in) dimension so that outputs have roughly unit variance
        # shape: (input_dim, output_dim)
        self.weights = jax.random.normal(key, (input_dim, output_dim)) * weights_init_std

    def __call__(
        self,
        inputs: Float[Array, " input"],
    ):
        hidden = jnp.dot(inputs, self.weights)
        return hidden


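# # %%
# # test for KaimingLinear (sketch, not from the original notebook; dimensions are arbitrary)
# key = jax.random.PRNGKey(0)
# linear = KaimingLinear(
#     key=key,
#     input_dim=768,
#     output_dim=2048,
#     weights_init_std=768**-0.5,
# )
# inputs = jnp.ones((10, 768))
# output = linear(inputs)
# print(output.shape)  # expected (10, 2048)
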
# %%
# this module fortunately supports batched operations by default due to broadcasting
class T5DenseActDense(eqx.Module):
    config: FrozenConfigDict
    dtype: jnp.dtype = jnp.float32
    wi: KaimingLinear
    wo: KaimingLinear
    dropout: eqx.nn.Dropout
    act: Callable

    def __init__(
        self: eqx.Module,
        config: FrozenConfigDict,
        dtype: jnp.dtype,
        key: jax.random.PRNGKey,
    ):
        self.config = config
        self.dtype = dtype

        mlp_key, output_key = jax.random.split(key)
        # the initialization strategy scales each projection by the inverse square root
        # of its input (fan-in) dimension
        # input projection
        wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
        # shape: (config.d_model, config.d_ff)
        # self.wi = jax.random.normal(mlp_key, (self.config.d_model, self.config.d_ff)) * wi_init_std
        self.wi = KaimingLinear(
            key=mlp_key,
            input_dim=self.config.d_model,
            output_dim=self.config.d_ff,
            weights_init_std=wi_init_std,
            dtype=self.dtype,
        )
        # output projection
        wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)
        # shape: (config.d_ff, config.d_model)
        # self.wo = jax.random.normal(output_key, (self.config.d_ff, self.config.d_model)) * wo_init_std
        self.wo = KaimingLinear(
            key=output_key,
            input_dim=self.config.d_ff,
            output_dim=self.config.d_model,
            weights_init_std=wo_init_std,
            dtype=self.dtype,
        )

        self.dropout = eqx.nn.Dropout(self.config.dropout_rate)
        # just set to relu for now since the smaller T5's use relu
        self.act = jax.nn.relu

    def __call__(
        self,
        inputs: Float[Array, " d_model"],
        enable_dropout: bool = False,
        dropout_key: Optional[jax.random.PRNGKey] = None,
    ) -> Float[Array, " d_model"]:
        hidden = self.wi(inputs)
        # hidden = jnp.dot(inputs, self.wi)
        hidden = self.act(hidden)
        hidden = self.dropout(hidden, inference=not enable_dropout, key=dropout_key)
        hidden = self.wo(hidden)
        # hidden = jnp.dot(hidden, self.wo)
        return hidden


# # %%
# # test for T5DenseActDense
# # create fake config
# config_dict = {
#     'd_model': 768,
#     'd_ff': 2048,
#     'dropout_rate': 0.1,
#     'initializer_factor': 1.0,
# }
# # Create a FrozenConfigDict from the standard dictionary
# frozen_config = FrozenConfigDict(config_dict)
# # initialize model
# key = jax.random.PRNGKey(0)
# dense = T5DenseActDense(
#     key=key,
#     config=frozen_config,
#     dtype=jnp.float32
# )
# input_key, key = jax.random.split(key)
# inputs = jax.random.normal(input_key, (10, frozen_config.d_model))  # Generate random normal values
# dropout_key, key = jax.random.split(key)
# output = dense(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
# output.shape

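# # %%
# # usage sketch (not from the original notebook): when enable_dropout=True a PRNG key
# # must be supplied, since eqx.nn.Dropout needs randomness outside of inference mode
# train_key, key = jax.random.split(key)
# train_output = dense(inputs=inputs, enable_dropout=True, dropout_key=train_key)
# train_output.shape
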
# %%
class T5LayerFF(eqx.Module):
    config: FrozenConfigDict
    dtype: jnp.dtype
    DenseReluDense: T5DenseActDense
    layer_norm: T5LayerNorm
    dropout: eqx.nn.Dropout

    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        config: FrozenConfigDict,
        dtype: jnp.dtype = jnp.float32,
    ):
        self.config = config
        self.dtype = dtype

        dense_key, key = jax.random.split(key)
        # args: key, config, dtype
        self.DenseReluDense = T5DenseActDense(
            key=dense_key,
            config=config,
            dtype=dtype,
        )
        layer_key, key = jax.random.split(key)
        # args: key, hidden_size
        self.layer_norm = T5LayerNorm(
            key=layer_key,
            hidden_size=self.config.d_model,
        )
        # args: dropout_rate
        self.dropout = eqx.nn.Dropout(self.config.dropout_rate)

    def __call__(
        self: eqx.Module,
        inputs: Float[Array, " d_model"],
        enable_dropout: bool = False,
        dropout_key: Optional[jax.random.PRNGKey] = None,
    ):
        forwarded_states = self.layer_norm(inputs)

        # only split the dropout key when one is provided; dropout_key defaults to None
        # and jax.random.split(None) would fail
        if dropout_key is not None:
            dense_key, residual_key = jax.random.split(dropout_key)
        else:
            dense_key = residual_key = None

        forwarded_states = self.DenseReluDense(
            inputs=forwarded_states,
            enable_dropout=enable_dropout,
            dropout_key=dense_key,
        )
        dropout_states = self.dropout(
            x=forwarded_states,
            inference=not enable_dropout,
            key=residual_key,
        )
        # residual connection around the feed-forward block
        hidden = inputs + dropout_states
        return hidden


# # %%
# # test for T5LayerFF
# # create fake config
# config_dict = {
#     'd_model': 768,
#     'd_ff': 2048,
#     'dropout_rate': 0.1,
#     'initializer_factor': 1.0,
# }
# # Create a FrozenConfigDict from the standard dictionary
# frozen_config = FrozenConfigDict(config_dict)
# # initialize model
# key = jax.random.PRNGKey(0)
# ff_layer = T5LayerFF(
#     key=key,
#     config=frozen_config,
#     dtype=jnp.float32
# )
# input_key, key = jax.random.split(key)
# inputs = jax.random.normal(input_key, (10, frozen_config.d_model))  # Generate random normal values
# dropout_key, key = jax.random.split(key)
# output = ff_layer(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
# output.shape

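# # %%
# # sanity-check sketch (not from the original notebook): with dropout disabled,
# # T5LayerFF should equal inputs + DenseReluDense(layer_norm(inputs))
# manual = inputs + ff_layer.DenseReluDense(ff_layer.layer_norm(inputs))
# print(jnp.allclose(output, manual, atol=1e-5))
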
# %%
class T5Attention(eqx.Module):
    config: FrozenConfigDict
    has_relative_attention_bias: bool = False
    causal: bool = False  # False for encoder, True for decoder
    dtype: jnp.dtype

    # parameters
    q: KaimingLinear
    k: KaimingLinear
    v: KaimingLinear
    o: KaimingLinear

    # additional terms
    relative_attention_num_buckets: int
    relative_attention_max_distance: int
    d_model: int
    key_value_proj_dim: int
    n_heads: int
    dropout: float
    inner_dim: int
    initializer_factor: float

    def __init__(
        self: eqx.Module,
        config: FrozenConfigDict,
        dtype: jnp.dtype,
        key: jax.random.PRNGKey,
    ):
        self.config = config
        self.dtype = dtype
        self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
        self.relative_attention_max_distance = self.config.relative_attention_max_distance
        self.d_model = self.config.d_model
        # size of the k,v projection for each head
        self.key_value_proj_dim = self.config.d_kv
        self.n_heads = self.config.num_heads
        self.dropout = self.config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim
        self.initializer_factor = self.config.initializer_factor

        q_init_std = self.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
        kv_init_std = self.initializer_factor * (self.inner_dim**-0.5)
        o_init_std = self.initializer_factor * (self.inner_dim**-0.5)

        # q, k, v project the hidden states from d_model to inner_dim (= n_heads * d_kv)
        q_key, key = jax.random.split(key)
        self.q = KaimingLinear(
            key=q_key,
            input_dim=self.d_model,
            output_dim=self.inner_dim,
            weights_init_std=q_init_std,
            dtype=self.dtype,
        )

        k_key, key = jax.random.split(key)
        self.k = KaimingLinear(
            key=k_key,
            input_dim=self.d_model,
            output_dim=self.inner_dim,
            weights_init_std=kv_init_std,
            dtype=self.dtype,
        )

        v_key, key = jax.random.split(key)
        self.v = KaimingLinear(
            key=v_key,
            input_dim=self.d_model,
            output_dim=self.inner_dim,
            weights_init_std=kv_init_std,
            dtype=self.dtype,
        )

        # the output projection maps inner_dim back to d_model
        o_key, key = jax.random.split(key)
        self.o = KaimingLinear(
            key=o_key,
            input_dim=self.inner_dim,
            output_dim=self.d_model,
            weights_init_std=o_init_std,
            dtype=self.dtype,
        )

    @staticmethod
    def _relative_position_bucket(
        relative_position,
        bidirectional=True,
        num_buckets=32,
        max_distance=128,
    ):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >= max_distance map to the same bucket. All relative positions <= -max_distance map to the same
        bucket. This should allow for more graceful generalization to longer sequences than the model was trained on.
        """
        relative_buckets = 0

        # bidirectional determines whether positive relative positions are valid
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0) * num_buckets
            relative_position = jnp.abs(relative_position)
        else:
            # clamp positive positions to 0 and negate, so the range becomes [0, inf)
            relative_position = -jnp.clip(relative_position, a_max=0)

        # half of the buckets are for exact increments in position
        max_exact = num_buckets // 2
        # boolean used to assign relative buckets later
        is_small = relative_position < max_exact

        # the other half are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
        )

        relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)

        # jnp.where(condition, x, y): condition true -> x, false -> y
        # accumulate into relative_buckets so that every element ends up with the
        # correct bucket, whether its distance is small or large
        relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)

        return relative_buckets.astype("i4")

    # the bias gives each head a weight based on relative distance, added on top of the attention scores
    def compute_bias(self, query_length, key_length):
        """
        Compute binned relative position bias.
        """
        # note: this relies on self.relative_attention_bias (an embedding over the
        # relative-position buckets), which is only expected to exist when
        # has_relative_attention_bias is True; it is not created in this module

        # arange over the first (query) dim
        context_position = jnp.arange(query_length, dtype="i4")[:, None]
        # arange over the second (key) dim
        memory_position = jnp.arange(key_length, dtype="i4")[None, :]

        # The relative position is defined as memory_position - query_position,
        # i.e. the distance in tokens from the attending position to the
        # attended-to position.
        #
        # 2D array where each entry represents the distance from a query token
        # to a key token
        relative_position = memory_position - context_position
        # now we apply the earlier bucket creation function
        relative_position_bucket = self._relative_position_bucket(
            relative_position=relative_position,
            bidirectional=(not self.causal),  # causal during decode -> not bidirectional
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        # retrieve the bias values
        # shape (query_length, key_length, n_heads)
        values = self.relative_attention_bias(relative_position_bucket)
        # shape (1, n_heads, query_length, key_length)
        # ready for attention
        values = values.transpose((2, 0, 1))[None, :, :, :]
        return values

    # from (batch_size, seq_length, d_model) to
    # (batch_size, seq_length, n_heads, head_dim)
    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))

    # from (batch_size, seq_length, n_heads, head_dim) to
    # (batch_size, seq_length, d_model)
    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,))

    def _create_position_bias(
        self,
        key_states,
        query_states,
        attention_mask,
    ):
        # unlike the flax version, we don't even check for cache
        key_length = key_states.shape[1]
        query_length = query_states.shape[1]

        if self.has_relative_attention_bias:
            position_bias = self.compute_bias(query_length, key_length)
        elif attention_mask is not None:
            position_bias = jnp.zeros_like(attention_mask)
        else:
            position_bias = jnp.zeros(
                (1, self.n_heads, query_length, key_length),
                dtype=self.dtype,
            )

        return position_bias

    def __call__(
        self,
        inputs,
        attention_mask=None,
        key_value_states=None,
        position_bias=None,
        output_attentions=False,
        enable_dropout=False,
        dropout_key: Optional[jax.random.PRNGKey] = None,
    ):
        # expected input shape: (batch_size, seq_len, d_model)
        # expected output: a tuple whose first element has the same shape as the input
        # (attn_output, position_bias)
        batch_size, seq_length = inputs.shape[:2]

        # q,k,v projections
        # shape after projection: (batch_size, seq_length, inner_dim)
        query_states = self.q(inputs)
        key_states = (
            self.k(inputs) if key_value_states is None else self.k(key_value_states)
        )
        value_states = (
            self.v(inputs) if key_value_states is None else self.v(key_value_states)
        )

        # reshape to (batch_size, seq_length, n_heads, head_dim)
        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # counteract the 1/sqrt(d) scaling inside dot_product_attention_weights
        # (T5 uses unscaled dot-product attention)
        query_states *= jnp.sqrt(query_states.shape[-1])

        # create causal attention_mask
        if self.causal:
            causal_attention_mask = make_causal_mask(attention_mask, dtype="bool")

            # broadcast causal attention mask & attention mask to fit for merge
            causal_attention_mask = jnp.broadcast_to(
                causal_attention_mask,
                (batch_size,) + causal_attention_mask.shape[1:],
            )
            attention_mask = jnp.broadcast_to(
                jnp.expand_dims(attention_mask, axis=(-3, -2)),
                causal_attention_mask.shape,
            )
            attention_mask = combine_masks(attention_mask, causal_attention_mask)
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # convert the boolean mask to an additive mask: 0 where attended,
        # the dtype's most negative value where masked
        if attention_mask is not None:
            mask_value = jnp.finfo(self.dtype).min
            attention_mask = jax.lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, mask_value).astype(self.dtype),
            )

        if position_bias is None:
            # compute position bias (only for the first layer)
            position_bias = self._create_position_bias(
                key_states, query_states, attention_mask
            )

            if attention_mask is not None:
                position_bias = position_bias + attention_mask

        # softmax(QK^T + position_bias), with optional attention dropout
        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=position_bias,
            dropout_rng=dropout_key,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=not enable_dropout,
            dtype=self.dtype,
        )

        # multiply with the value states
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)

        # bring back to (batch_size, seq_length, d_model)
        attn_output = self._merge_heads(attn_output)

        # apply the output projection
        attn_output = self.o(attn_output)

        outputs = (attn_output, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)

        return outputs


# # %%
# # test for T5Attention
# # create fake config
# config_dict = {
#     'relative_attention_num_buckets': 32,
#     'relative_attention_max_distance': 128,
#     'd_model': 768,  # 64 * 12
#     'd_kv': 64,
#     'num_heads': 12,
#     'dropout_rate': 0.1,
#     'initializer_factor': 1.0,
# }
# # Create a FrozenConfigDict from the standard dictionary
# frozen_config = FrozenConfigDict(config_dict)
# # initialize model
# key = jax.random.PRNGKey(0)
# attn_layer = T5Attention(
#     key=key,
#     config=frozen_config,
#     dtype=jnp.float32
# )
# input_key, key = jax.random.split(key)
# # inputs = jax.random.normal(input_key, (10, frozen_config.d_model))  # Generate random normal values
# batch_size = 1
# seq_length = 10
# inputs = jnp.ones((batch_size, seq_length, frozen_config.d_model))
# dropout_key, key = jax.random.split(key)
# output = attn_layer(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
# print(len(output))
# print(output[0].shape)

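# # %%
# # illustration sketch (not from the original notebook): how relative distances map to buckets;
# # nearby positions get their own bucket, distant ones share logarithmically sized buckets
# rel_pos = jnp.array([[-64, -8, -1, 0, 1, 8, 64]])
# buckets = T5Attention._relative_position_bucket(
#     rel_pos,
#     bidirectional=True,
#     num_buckets=32,
#     max_distance=128,
# )
# print(buckets)
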
# %%