# learn_jax/equinox/t5_train_model.py
# %%
# package imports from equinox BERT example
import functools
from typing import Callable, Dict, List, Mapping, Optional, Tuple
# import einops # https://github.com/arogozhnikov/einops
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax # https://github.com/deepmind/optax
from datasets import load_dataset # https://github.com/huggingface/datasets
from jaxtyping import Array, Float, Int # https://github.com/google/jaxtyping
from tqdm import notebook as tqdm # https://github.com/tqdm/tqdm
from transformers import AutoTokenizer # https://github.com/huggingface/transformers
from ml_collections import ConfigDict, FrozenConfigDict
import flax.linen as nn
# %%
class T5LayerNorm(eqx.Module):
    """
    Layer norm in the T5 style: no bias and no subtraction of the mean
    (i.e. an RMS norm).
    """
    eps: float = 1e-6
    weight: jax.Array
    # staticmethod keeps the initializer from being bound as an instance method
    weight_init: Callable[..., np.ndarray] = staticmethod(jax.nn.initializers.ones)

    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        hidden_size: int,
        # dtype: jnp.dtype = jnp.float32,
    ):
        # self.dtype = dtype
        # self.params = {
        #     'weight': self.weight_init(key, (hidden_size,), dtype)
        # }
        # force the use of float32
        # note that the ones initializer ignores the key, so key is effectively optional
        self.weight = self.weight_init(key, (hidden_size,), jnp.float32)

    # takes the hidden states as an argument so the call remains a pure function
    def __call__(self, hidden_states):
        # always compute the variance in float32 for numerical stability
        variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True)
        hidden_states = hidden_states / jnp.sqrt(variance + self.eps)
        return self.weight * hidden_states
# # %%
# # testing T5LayerNorm
# key = jax.random.PRNGKey(0)
# hidden_size = 128 # Example hidden size
# layer_norm = T5LayerNorm(key=key, hidden_size=hidden_size)
# # Create some example input data
# hidden_states = jnp.ones((1, 10, hidden_size)) # Batch size of 1, sequence length of 10
# # Forward pass
# output = layer_norm(hidden_states)
# print("Output shape:", output.shape)
# %%
class KaimingLinear(eqx.Module):
    dtype: jnp.dtype = jnp.float32
    weights: jax.Array

    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        input_dim: int,
        output_dim: int,
        initializer_factor: float,
        dtype: jnp.dtype = jnp.float32
    ):
        self.dtype = dtype
        # the initialization strategy scales the standard deviation by the
        # inverse square root of the input (fan-in) dimension
        weights_init_std = initializer_factor * (input_dim**-0.5)
        # shape: (input_dim, output_dim)
        self.weights = jax.random.normal(key, (input_dim, output_dim), dtype=dtype) * weights_init_std

    def __call__(
        self,
        inputs: Float[Array, " input"],
    ):
        hidden = jnp.dot(inputs, self.weights)
        return hidden
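# # %%
# # quick sanity check for KaimingLinear (a sketch, not part of the original file):
# # the weights should have std close to initializer_factor * input_dim**-0.5
# # (fan-in scaling), and the output dimension should match output_dim
# check_key = jax.random.PRNGKey(0)
# linear = KaimingLinear(key=check_key, input_dim=512, output_dim=1024, initializer_factor=1.0)
# print(linear.weights.shape)             # (512, 1024)
# print(float(linear.weights.std()))      # roughly 512**-0.5 ~= 0.044
# print(linear(jnp.ones((512,))).shape)   # (1024,)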
# %%
# this module supports batched operations by default, because jnp.dot
# broadcasts over leading dimensions
class T5DenseActDense(eqx.Module):
    config: FrozenConfigDict
    dtype: jnp.dtype = jnp.float32
    wi: KaimingLinear
    wo: KaimingLinear
    dropout: eqx.nn.Dropout
    act: Callable

    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        config: FrozenConfigDict,
        dtype: jnp.dtype = jnp.float32
    ):
        self.config = config
        self.dtype = dtype
        mlp_key, output_key = jax.random.split(key)
        # input projection: shape (d_model, d_ff), init std scaled by d_model**-0.5
        self.wi = KaimingLinear(
            key=mlp_key,
            input_dim=self.config.d_model,
            output_dim=self.config.d_ff,
            initializer_factor=self.config.initializer_factor,
            dtype=self.dtype
        )
        # output projection: shape (d_ff, d_model), init std scaled by d_ff**-0.5
        self.wo = KaimingLinear(
            key=output_key,
            input_dim=self.config.d_ff,
            output_dim=self.config.d_model,
            initializer_factor=self.config.initializer_factor,
            dtype=self.dtype
        )
        self.dropout = eqx.nn.Dropout(self.config.dropout_rate)
        # use relu for now since the smaller T5 variants use relu
        self.act = jax.nn.relu

    def __call__(
        self,
        inputs: Float[Array, " d_model"],
        enable_dropout: bool = False,
        dropout_key: Optional[jax.random.PRNGKey] = None,
    ) -> Float[Array, " d_model"]:
        hidden = self.wi(inputs)
        hidden = self.act(hidden)
        hidden = self.dropout(hidden, inference=not enable_dropout, key=dropout_key)
        hidden = self.wo(hidden)
        return hidden
# # %%
# # test for T5DenseActDense
# # create fake config
# config_dict = {
# 'd_model': 768,
# 'd_ff': 2048,
# 'dropout_rate': 0.1,
# 'initializer_factor': 1.0,
# }
# # Create a FrozenDict from the standard dictionary
# frozen_config = FrozenConfigDict(config_dict)
# # initialize model
# key = jax.random.PRNGKey(0)
# dense = T5DenseActDense(
# key=key,
# config=frozen_config,
# dtype=jnp.float32
# )
# input_key, key = jax.random.split(key)
# inputs = jax.random.normal(input_key, (10, frozen_config.d_model)) # Generate random normal values
# dropout_key, key = jax.random.split(key)
# output = dense(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
# output.shape
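# # %%
# # a small sketch (not in the original file) showing that T5DenseActDense handles
# # a batch dimension purely through broadcasting in jnp.dot; the config values
# # below are the same illustrative ones used in the test above
# frozen_config = FrozenConfigDict({
#     'd_model': 768,
#     'd_ff': 2048,
#     'dropout_rate': 0.1,
#     'initializer_factor': 1.0,
# })
# dense = T5DenseActDense(key=jax.random.PRNGKey(0), config=frozen_config)
# single = dense(inputs=jnp.ones((frozen_config.d_model,)))          # (768,)
# batched = dense(inputs=jnp.ones((4, 10, frozen_config.d_model)))   # (4, 10, 768)
# print(single.shape, batched.shape)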
# %%
class T5LayerFF(eqx.Module):
    config: FrozenConfigDict
    dtype: jnp.dtype
    DenseReluDense: T5DenseActDense
    layer_norm: T5LayerNorm
    dropout: eqx.nn.Dropout

    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        config: FrozenConfigDict,
        dtype: jnp.dtype = jnp.float32
    ):
        self.config = config
        self.dtype = dtype
        dense_key, key = jax.random.split(key)
        # args: key, config, dtype
        self.DenseReluDense = T5DenseActDense(
            key=dense_key,
            config=config,
            dtype=dtype
        )
        layer_key, key = jax.random.split(key)
        # args: key, hidden_size
        self.layer_norm = T5LayerNorm(
            key=layer_key,
            hidden_size=self.config.d_model
        )
        # args: dropout_rate
        self.dropout = eqx.nn.Dropout(self.config.dropout_rate)

    def __call__(
        self: eqx.Module,
        inputs: Float[Array, " d_model"],
        enable_dropout: bool = False,
        dropout_key: Optional[jax.random.PRNGKey] = None,
    ):
        # pre-norm residual block: inputs + Dropout(FF(LayerNorm(inputs)))
        forwarded_states = self.layer_norm(inputs)
        # note: a dropout_key must be supplied even when enable_dropout is False,
        # because the key is split unconditionally
        dropout_key, key = jax.random.split(dropout_key)
        forwarded_states = self.DenseReluDense(
            inputs=forwarded_states,
            enable_dropout=enable_dropout,
            dropout_key=dropout_key
        )
        dropout_key, key = jax.random.split(key)
        dropout_states = self.dropout(
            x=forwarded_states,
            key=dropout_key,
            inference=not enable_dropout
        )
        hidden = inputs + dropout_states
        return hidden
# # %%
# # test for T5LayerFF
# # create fake config
# config_dict = {
# 'd_model': 768,
# 'd_ff': 2048,
# 'dropout_rate': 0.1,
# 'initializer_factor': 1.0,
# }
# # Create a FrozenDict from the standard dictionary
# frozen_config = FrozenConfigDict(config_dict)
# # initialize model
# key = jax.random.PRNGKey(0)
# ff_layer = T5LayerFF(
# key=key,
# config=frozen_config,
# dtype=jnp.float32
# )
# input_key, key = jax.random.split(key)
# inputs = jax.random.normal(input_key, (10, frozen_config.d_model)) # Generate random normal values
# dropout_key, key = jax.random.split(key)
# output = ff_layer(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
# output.shape
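# # %%
# # a sketch (not in the original file) of batching T5LayerFF with jax.vmap and
# # per-example dropout keys; the config values are the same illustrative ones as above
# frozen_config = FrozenConfigDict({
#     'd_model': 768,
#     'd_ff': 2048,
#     'dropout_rate': 0.1,
#     'initializer_factor': 1.0,
# })
# ff_layer = T5LayerFF(key=jax.random.PRNGKey(0), config=frozen_config)
# batch = jnp.ones((8, frozen_config.d_model))
# dropout_keys = jax.random.split(jax.random.PRNGKey(1), 8)
# batched_ff = jax.vmap(
#     lambda x, k: ff_layer(inputs=x, enable_dropout=True, dropout_key=k)
# )
# print(batched_ff(batch, dropout_keys).shape)  # (8, 768)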
# %%
class T5Attention(eqx.Module):
    config: FrozenConfigDict
    has_relative_attention_bias: bool = False
    causal: bool = False  # False for encoder, True for decoder
    dtype: jnp.dtype
    # additional terms
    relative_attention_num_buckets: int
    relative_attention_max_distance: int
    d_model: int
    key_value_proj_dim: int
    n_heads: int
    dropout: float
    inner_dim: int
    initializer_factor: float
    # projection layers and the (optional) learned relative position bias;
    # declared as fields so they are registered as part of the module's pytree
    q: KaimingLinear
    k: KaimingLinear
    v: KaimingLinear
    o: KaimingLinear
    relative_attention_bias: Optional[eqx.nn.Embedding]
    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        config: FrozenConfigDict,
        has_relative_attention_bias: bool = False,
        causal: bool = False,
        dtype: jnp.dtype = jnp.float32
    ):
        # config and dtype have to be stored before they are read below
        self.config = config
        self.dtype = dtype
        self.has_relative_attention_bias = has_relative_attention_bias
        self.causal = causal
        self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
        self.relative_attention_max_distance = self.config.relative_attention_max_distance
        self.d_model = self.config.d_model
        # size of the k,v projection for each head
        self.key_value_proj_dim = self.config.d_kv
        self.n_heads = self.config.num_heads
        self.dropout = self.config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim
        self.initializer_factor = self.config.initializer_factor
        # q, k, v all take d_model-sized inputs, so the weight shapes are
        # (d_model, inner_dim); note that KaimingLinear scales the init std by the
        # fan-in, which for q differs slightly from the HF Flax T5 init of
        # initializer_factor * ((d_model * d_kv) ** -0.5)
        q_key, key = jax.random.split(key)
        self.q = KaimingLinear(
            key=q_key,
            input_dim=self.d_model,
            output_dim=self.inner_dim,
            initializer_factor=self.initializer_factor,
            dtype=self.dtype
        )
        k_key, key = jax.random.split(key)
        self.k = KaimingLinear(
            key=k_key,
            input_dim=self.d_model,
            output_dim=self.inner_dim,
            initializer_factor=self.initializer_factor,
            dtype=self.dtype
        )
        v_key, key = jax.random.split(key)
        self.v = KaimingLinear(
            key=v_key,
            input_dim=self.d_model,
            output_dim=self.inner_dim,
            initializer_factor=self.initializer_factor,
            dtype=self.dtype
        )
        o_key, key = jax.random.split(key)
        self.o = KaimingLinear(
            key=o_key,
            input_dim=self.inner_dim,
            output_dim=self.d_model,
            initializer_factor=self.initializer_factor,
            dtype=self.dtype
        )
        # 1 bias per head, so the output dimension is n_heads;
        # the bias is learned during training
        if self.has_relative_attention_bias:
            input_dim = self.relative_attention_num_buckets
            output_dim = self.n_heads
            # we standardize based on the inner dimension,
            # which is n_heads * key_value_proj_dim during multi-head attention
            weights_init_std = self.initializer_factor * (self.inner_dim**-0.5)
            # shape: (num_buckets, n_heads)
            weights = jax.random.normal(key, (input_dim, output_dim), dtype=self.dtype) * weights_init_std
            # eqx.nn.Embedding takes a pre-built table via the `weight` argument
            self.relative_attention_bias = eqx.nn.Embedding(
                weight=weights
            )
        else:
            self.relative_attention_bias = None
    @staticmethod
    def _relative_position_bucket(
        relative_position,
        bidirectional=True,
        num_buckets=32,
        max_distance=128
    ):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >= max_distance map to the same bucket. All relative positions <= -max_distance map to the same
        bucket. This should allow for more graceful generalization to longer sequences than the model has been
        trained on.
        """
        relative_buckets = 0
        # bidirectional determines whether positive relative positions are valid
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0) * num_buckets
            relative_position = jnp.abs(relative_position)
        else:
            # restrict relative positions to the range [0, inf)
            relative_position = -jnp.clip(relative_position, None, 0)
        # half of the buckets are for exact increments in position
        max_exact = num_buckets // 2
        # boolean mask used to assign relative buckets below
        is_small = relative_position < max_exact
        # the other half are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
        )
        relative_position_if_large = jnp.clip(relative_position_if_large, None, num_buckets - 1)
        # accumulate into relative_buckets: jnp.where(condition, x, y) picks the exact
        # bucket when the position is small and the logarithmic bucket otherwise,
        # so every element ends up with its correct bucket index
        relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets.astype("i4")
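    # a small illustration (not in the original file) of the bucketing scheme with the
    # defaults bidirectional=True, num_buckets=32, max_distance=128:
    #   T5Attention._relative_position_bucket(jnp.array([-1, 1, 2, 40, 45, 200]))
    # positions -1, 1 and 2 fall in the "exact" half of the buckets (buckets 1, 17, 18;
    # the offset of 16 encodes the sign), 40 and 45 share a logarithmic bucket (28),
    # and anything >= max_distance saturates at the last bucket (31)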
    # the bias adds a weight based on relative distance on top of the attention score
    def compute_bias(self, query_length, key_length):
        """
        Compute binned relative position bias
        """
        # arange in the first dim
        context_position = jnp.arange(query_length, dtype="i4")[:, None]
        # arange in the second dim
        memory_position = jnp.arange(key_length, dtype="i4")[None, :]
        # The relative position is defined as memory_position - query_position,
        # i.e. the distance in tokens from the attending position to the
        # attended-to position.
        #
        # 2D array where each entry represents the distance from a query token
        # to a key token
        relative_position = memory_position - context_position
        # now we apply the earlier bucket creation function
        relative_position_bucket = self._relative_position_bucket(
            relative_position=relative_position,
            bidirectional=(not self.causal),  # causal during decode -> not bidirectional
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        # look up the bias values; eqx.nn.Embedding.__call__ only accepts a scalar
        # index, so index the weight table directly with the 2D bucket array
        # shape (query_length, key_length, n_heads)
        values = self.relative_attention_bias.weight[relative_position_bucket]
        # shape (1, n_heads, query_length, key_length), ready for attention
        values = values.transpose((2, 0, 1))[None, :, :, :]
        return values
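    # for example (not in the original file), with query_length = key_length = 3 the
    # relative_position matrix above is
    #   [[ 0,  1,  2],
    #    [-1,  0,  1],
    #    [-2, -1,  0]]
    # i.e. entry (i, j) = j - i, the signed distance from query token i to key token j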
    # reshape (batch, seq, inner_dim) -> (batch, seq, n_heads, head_dim)
    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))

    # reshape (batch, seq, n_heads, head_dim) -> (batch, seq, inner_dim)
    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,))

    def _create_position_bias(
        self,
        key_states,
        query_states,
        attention_mask,
        init_cache,
        seq_length,
        causal_attention_mask_shift
    ):
        # unlike the flax version, we do not check for a decoding cache
        key_length = key_states.shape[1]
        query_length = query_states.shape[1]
        if self.has_relative_attention_bias:
            position_bias = self.compute_bias(query_length, key_length)
        elif attention_mask is not None:
            position_bias = jnp.zeros_like(attention_mask)
        else:
            position_bias = jnp.zeros(
                (1, self.n_heads, query_length, key_length),
                dtype=self.dtype
            )
        return position_bias

    def __call__(
        self,
        inputs,
        attention_mask=None,
        key_value_states=None,
        position_bias=None,
        output_attentions=False,
        enable_dropout=False
    ):
        batch_size, seq_length = inputs.shape[:2]
        # q, k, v projections; key_value_states is provided for cross-attention
        query_states = self.q(inputs)
        key_states = (
            self.k(inputs) if key_value_states is None else self.k(key_value_states)
        )
        value_states = (
            self.v(inputs) if key_value_states is None else self.v(key_value_states)
        )
        # reshape to (batch_size, seq_length, n_heads, head_dim)
        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)
        # counteract scaling in dot_product_attention_weights function
        # not sure if this is a good idea in equinox