# %%
# package imports from equinox BERT example
import functools
from typing import Dict, List, Mapping, Optional, Callable, Tuple

# import einops  # https://github.com/arogozhnikov/einops
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax  # https://github.com/deepmind/optax
from datasets import load_dataset  # https://github.com/huggingface/datasets
from jaxtyping import Array, Float, Int  # https://github.com/google/jaxtyping
from tqdm import notebook as tqdm  # https://github.com/tqdm/tqdm
from transformers import AutoTokenizer  # https://github.com/huggingface/transformers
from ml_collections import ConfigDict, FrozenConfigDict
import flax.linen as nn

# %%
class T5LayerNorm(eqx.Module):
    eps: float = 1e-6
    weight: jax.Array
    # staticmethod stops the initializer from being bound as an instance method
    weight_init: Callable[..., np.ndarray] = staticmethod(jax.nn.initializers.ones)

    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        hidden_size: int,
        # dtype: jnp.dtype = jnp.float32,
    ):
        # self.dtype = dtype
        # self.params = {
        #     'weight': self.weight_init(key, (hidden_size,), dtype)
        # }
        # force the use of float32
        # note that the key argument is ignored by the ones initializer,
        # so key is effectively optional
        self.weight = self.weight_init(key, (hidden_size,), jnp.float32)

    # takes the hidden states as an argument so that it can fall through and
    # remain a pure function
    def __call__(self, hidden_states):
        """
        Apply a layernorm in the T5 style:
        no bias and no subtraction of the mean
        """
        # always compute the variance in float32 for layer norm
        variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True)
        hidden_states = hidden_states / jnp.sqrt(variance + self.eps)
        return self.weight * hidden_states

# # %%
# # testing T5LayerNorm
# key = jax.random.PRNGKey(0)
# hidden_size = 128  # example hidden size
# layer_norm = T5LayerNorm(key=key, hidden_size=hidden_size)
# # create some example input data
# hidden_states = jnp.ones((1, 10, hidden_size))  # batch size of 1, sequence length of 10
# # forward pass
# output = layer_norm(hidden_states)
# print("Output shape:", output.shape)

# %%
class KaimingLinear(eqx.Module):
    dtype: jnp.dtype = jnp.float32
    weights: jax.Array

    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        input_dim: int,
        output_dim: int,
        initializer_factor: float,
        dtype: jnp.dtype = jnp.float32
    ):
        self.dtype = dtype

        # the initialization strategy is to scale the standard deviation by the
        # inverse square root of the fan-in (the input dimension)
        weights_init_std = initializer_factor * (input_dim**-0.5)
        # shape is (input_dim, output_dim)
        self.weights = jax.random.normal(key, (input_dim, output_dim), dtype=self.dtype) * weights_init_std

    def __call__(
        self,
        inputs: Float[Array, " input"],
    ):
        hidden = jnp.dot(inputs, self.weights)
        return hidden
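
# # %%
# # quick sanity check for KaimingLinear (a sketch in the style of the other
# # commented-out test cells; the dimensions below are arbitrary example values)
# key = jax.random.PRNGKey(0)
# linear = KaimingLinear(
#     key=key,
#     input_dim=512,
#     output_dim=2048,
#     initializer_factor=1.0,
# )
# inputs = jnp.ones((4, 512))  # a batch of 4 vectors of dimension 512
# outputs = linear(inputs)     # jnp.dot broadcasts over the leading batch dim
# print(outputs.shape)         # expected: (4, 2048)
# # the weight std should be close to initializer_factor * 512**-0.5
# print(jnp.std(linear.weights))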

# %%
# this module fortunately supports batched operations by default due to broadcasting
class T5DenseActDense(eqx.Module):
    config: FrozenConfigDict
    dtype: jnp.dtype = jnp.float32
    wi: KaimingLinear
    wo: KaimingLinear
    dropout: eqx.nn.Dropout
    act: Callable

    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        config: FrozenConfigDict,
        dtype: jnp.dtype = jnp.float32
    ):
        self.config = config
        self.dtype = dtype

        mlp_key, output_key = jax.random.split(key)
        # the initialization strategy is to scale by the inverse square root of
        # the input dimension (see KaimingLinear)
        # input projection: (config.d_model, config.d_ff)
        # wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
        # self.wi = jax.random.normal(mlp_key, (self.config.d_model, self.config.d_ff)) * wi_init_std
        self.wi = KaimingLinear(
            key=mlp_key,
            input_dim=self.config.d_model,
            output_dim=self.config.d_ff,
            initializer_factor=self.config.initializer_factor,
            dtype=self.dtype
        )
        # output projection: (config.d_ff, config.d_model)
        # wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)
        # self.wo = jax.random.normal(output_key, (self.config.d_ff, self.config.d_model)) * wo_init_std
        self.wo = KaimingLinear(
            key=output_key,  # use the second split so wi and wo are not initialized identically
            input_dim=self.config.d_ff,
            output_dim=self.config.d_model,
            initializer_factor=self.config.initializer_factor,
            dtype=self.dtype
        )
        self.dropout = eqx.nn.Dropout(self.config.dropout_rate)
        # just set to relu for now since the smaller T5's use relu
        self.act = jax.nn.relu

    def __call__(
        self,
        inputs: Float[Array, " d_model"],
        enable_dropout: bool = False,
        dropout_key: Optional[jax.random.PRNGKey] = None,
    ) -> Float[Array, " d_model"]:
        hidden = self.wi(inputs)
        # hidden = jnp.dot(inputs, self.wi)
        hidden = self.act(hidden)
        hidden = self.dropout(hidden, inference=not enable_dropout, key=dropout_key)
        hidden = self.wo(hidden)
        # hidden = jnp.dot(hidden, self.wo)
        return hidden

# # %%
# # test for T5DenseActDense
# # create fake config
# config_dict = {
#     'd_model': 768,
#     'd_ff': 2048,
#     'dropout_rate': 0.1,
#     'initializer_factor': 1.0,
# }
# # create a FrozenConfigDict from the standard dictionary
# frozen_config = FrozenConfigDict(config_dict)
# # initialize model
# key = jax.random.PRNGKey(0)
# dense = T5DenseActDense(
#     key=key,
#     config=frozen_config,
#     dtype=jnp.float32
# )
# input_key, key = jax.random.split(key)
# inputs = jax.random.normal(input_key, (10, frozen_config.d_model))  # generate random normal values
# dropout_key, key = jax.random.split(key)
# output = dense(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
# output.shape

# %%
class T5LayerFF(eqx.Module):
    config: FrozenConfigDict
    dtype: jnp.dtype
    DenseReluDense: T5DenseActDense
    layer_norm: T5LayerNorm
    dropout: eqx.nn.Dropout

    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        config: FrozenConfigDict,
        dtype: jnp.dtype = jnp.float32
    ):
        self.config = config
        self.dtype = dtype

        dense_key, key = jax.random.split(key)
        # args: key, config, dtype
        self.DenseReluDense = T5DenseActDense(
            key=dense_key,
            config=config,
            dtype=dtype
        )
        layer_key, key = jax.random.split(key)
        # args: key, hidden_size
        self.layer_norm = T5LayerNorm(
            key=layer_key,
            hidden_size=self.config.d_model
        )
        # args: dropout_rate
        self.dropout = eqx.nn.Dropout(self.config.dropout_rate)

    def __call__(
        self: eqx.Module,
        inputs: Float[Array, " d_model"],
        enable_dropout: bool = False,
        dropout_key: Optional[jax.random.PRNGKey] = None,
    ):
        forwarded_states = self.layer_norm(inputs)
        # dropout_key may be None when dropout is disabled, so only split a real key
        if dropout_key is not None:
            dense_dropout_key, layer_dropout_key = jax.random.split(dropout_key)
        else:
            dense_dropout_key = layer_dropout_key = None
        forwarded_states = self.DenseReluDense(
            inputs=forwarded_states,
            enable_dropout=enable_dropout,
            dropout_key=dense_dropout_key
        )
        dropout_states = self.dropout(
            x=forwarded_states,
            key=layer_dropout_key,
            inference=not enable_dropout
        )
        hidden = inputs + dropout_states
        return hidden

# # %%
# # test for T5LayerFF
# # create fake config
# config_dict = {
#     'd_model': 768,
#     'd_ff': 2048,
#     'dropout_rate': 0.1,
#     'initializer_factor': 1.0,
# }
# # create a FrozenConfigDict from the standard dictionary
# frozen_config = FrozenConfigDict(config_dict)
# # initialize model
# key = jax.random.PRNGKey(0)
# ff_layer = T5LayerFF(
#     key=key,
#     config=frozen_config,
#     dtype=jnp.float32
# )
# input_key, key = jax.random.split(key)
# inputs = jax.random.normal(input_key, (10, frozen_config.d_model))  # generate random normal values
# dropout_key, key = jax.random.split(key)
# output = ff_layer(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
# output.shape
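
# # %%
# # optional sanity check for T5LayerFF (a sketch, not part of the original tests):
# # with dropout disabled the layer should reduce to a plain residual connection,
# # i.e. output == inputs + DenseReluDense(layer_norm(inputs))
# check_key = jax.random.PRNGKey(1)
# ff_layer = T5LayerFF(key=check_key, config=frozen_config, dtype=jnp.float32)
# input_key, dropout_key = jax.random.split(jax.random.PRNGKey(2))
# inputs = jax.random.normal(input_key, (10, frozen_config.d_model))
# output = ff_layer(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
# manual = inputs + ff_layer.DenseReluDense(ff_layer.layer_norm(inputs), enable_dropout=False)
# print(jnp.allclose(output, manual))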

# %%
class T5Attention(eqx.Module):
    config: FrozenConfigDict
    has_relative_attention_bias: bool = False
    causal: bool = False  # False for encoder, True for decoder
    dtype: jnp.dtype
    # additional terms
    relative_attention_num_buckets: int
    relative_attention_max_distance: int
    d_model: int
    key_value_proj_dim: int
    n_heads: int
    dropout: float
    inner_dim: int
    initializer_factor: float
    # projection and bias submodules
    q: KaimingLinear
    k: KaimingLinear
    v: KaimingLinear
    o: KaimingLinear
    relative_attention_bias: Optional[eqx.nn.Embedding]

    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        config: FrozenConfigDict,
        has_relative_attention_bias: bool = False,
        causal: bool = False,
        dtype: jnp.dtype = jnp.float32
    ):
        self.config = config
        self.dtype = dtype
        self.has_relative_attention_bias = has_relative_attention_bias
        self.causal = causal

        self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
        self.relative_attention_max_distance = self.config.relative_attention_max_distance
        self.d_model = self.config.d_model
        # size of the k,v projection for each head
        self.key_value_proj_dim = self.config.d_kv
        self.n_heads = self.config.num_heads
        self.dropout = self.config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim
        self.initializer_factor = self.config.initializer_factor

        # the q, k, v projections all read vectors of size d_model;
        # note that the reference flax T5 scales q's init by
        # (inner_dim * key_value_proj_dim)**-0.5, which KaimingLinear's
        # fan-in-based scaling does not reproduce exactly
        q_key, key = jax.random.split(key)
        self.q = KaimingLinear(
            key=q_key,
            input_dim=self.d_model,
            output_dim=self.inner_dim,
            initializer_factor=self.initializer_factor,
            dtype=self.dtype
        )
        k_key, key = jax.random.split(key)
        self.k = KaimingLinear(
            key=k_key,
            input_dim=self.d_model,
            output_dim=self.inner_dim,
            initializer_factor=self.initializer_factor,
            dtype=self.dtype
        )
        v_key, key = jax.random.split(key)
        self.v = KaimingLinear(
            key=v_key,
            input_dim=self.d_model,
            output_dim=self.inner_dim,
            initializer_factor=self.initializer_factor,
            dtype=self.dtype
        )
        o_key, key = jax.random.split(key)
        self.o = KaimingLinear(
            key=o_key,
            input_dim=self.inner_dim,
            output_dim=self.d_model,
            initializer_factor=self.initializer_factor,
            dtype=self.dtype
        )

        # 1 bias per head, so the output dimension is n_heads
        # the bias is learned during training
        if self.has_relative_attention_bias:
            input_dim = self.relative_attention_num_buckets
            output_dim = self.n_heads
            initializer_factor = self.initializer_factor
            # we standardize based on the attention output dimension,
            # which is n_heads * key_value_proj_dim in multi-head attention
            weights_init_std = initializer_factor * (self.inner_dim**-0.5)
            # shape is (input_dim, output_dim)
            weights = jax.random.normal(key, (input_dim, output_dim), dtype=self.dtype) * weights_init_std
            self.relative_attention_bias = eqx.nn.Embedding(weight=weights)
        else:
            self.relative_attention_bias = None
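
    # shape walkthrough (illustrative numbers, not taken from any config): with
    # n_heads=8 and d_kv=64, inner_dim = 8 * 64 = 512, so q/k/v map
    # (batch, seq, d_model) -> (batch, seq, 512), and _split_heads below
    # reshapes that to (batch, seq, 8, 64) before attention is applied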

    @staticmethod
    def _relative_position_bucket(
        relative_position, bidirectional=True, num_buckets=32, max_distance=128
    ):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention.
        The relative position is defined as memory_position - query_position, i.e.
        the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are
        invalid. We use smaller buckets for small absolute relative_position and
        larger buckets for larger absolute relative_positions. All relative
        positions >= max_distance map to the same bucket. All relative positions
        <= -max_distance map to the same bucket. This should allow for more
        graceful generalization to longer sequences than the model has been
        trained on.
        """
        relative_buckets = 0
        # bidirectional determines whether positive relative positions are valid
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0) * num_buckets
            relative_position = jnp.abs(relative_position)
        else:
            # restrict the relative position range to [0, inf)
            relative_position = -jnp.clip(relative_position, None, 0)
        # half of the buckets are for exact increments in position
        max_exact = num_buckets // 2
        # boolean used to assign relative buckets later
        is_small = relative_position < max_exact

        # the other half are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            jnp.log(relative_position / max_exact)
            / jnp.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        )
        relative_position_if_large = jnp.clip(relative_position_if_large, None, num_buckets - 1)

        # jnp.where(condition, x, y): true -> x, false -> y
        # in-place cumulative summation
        # yields an array where every element has the correct relative bucket
        # position, whether it is small or large
        relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets.astype("i4")

    # the bias gives a weight based on relative distance, on top of the attention score
    def compute_bias(self, query_length, key_length):
        """
        Compute binned relative position bias
        """
        # arange in the first dim
        context_position = jnp.arange(query_length, dtype="i4")[:, None]
        # arange in the second dim
        memory_position = jnp.arange(key_length, dtype="i4")[None, :]
        # The relative position is defined as memory_position - query_position,
        # i.e. the distance in tokens from the attending position to the
        # attended-to position.
        #
        # 2D array where each entry represents the distance from a query token
        # to a key token
        relative_position = memory_position - context_position
        # now we apply the earlier bucket creation function
        relative_position_bucket = self._relative_position_bucket(
            relative_position=relative_position,
            bidirectional=(not self.causal),  # causal during decode -> not bidirectional
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        # retrieve the bias values by indexing the embedding table directly,
        # since eqx.nn.Embedding.__call__ expects a scalar index
        # shape (query_length, key_length, n_heads)
        values = self.relative_attention_bias.weight[relative_position_bucket]
        # shape (1, n_heads, query_length, key_length)
        # ready for attention
        values = values.transpose((2, 0, 1))[None, :, :, :]
        return values

    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))

    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,))

    def _create_position_bias(
        self, key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift
    ):
        # unlike the flax version, we don't even check for a cache
        key_length = key_states.shape[1]
        query_length = query_states.shape[1]

        if self.has_relative_attention_bias:
            position_bias = self.compute_bias(query_length, key_length)
        elif attention_mask is not None:
            position_bias = jnp.zeros_like(attention_mask)
        else:
            position_bias = jnp.zeros(
                (1, self.n_heads, query_length, key_length), dtype=self.dtype
            )
        return position_bias
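
    # worked example for _relative_position_bucket above (illustrative, using the
    # default num_buckets=32, max_distance=128 with bidirectional=True): the 32
    # buckets split into two halves of 16, with positive relative positions
    # (memory after the query) offset into the upper half; within each half,
    # distances 0..7 each get their own bucket and distances 8..127 share the
    # remaining 8 logarithmically sized buckets, e.g.
    #   relative_position = -3  -> bucket 3
    #   relative_position = +3  -> bucket 16 + 3 = 19
    #   relative_position = +50 -> bucket 16 + 8 + floor(log(50/8) / log(128/8) * 8) = 29
    #   |relative_position| >= 128 -> clipped into the last bucket of its half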

    def __call__(
        self,
        inputs,
        attention_mask=None,
        key_value_states=None,
        position_bias=None,
        output_attentions=False,
        enable_dropout=False
    ):
        batch_size, seq_length = inputs.shape[:2]

        # q, k, v projections
        query_states = self.q(inputs)
        key_states = (
            self.k(inputs) if key_value_states is None else self.k(key_value_states)
        )
        value_states = (
            self.v(inputs) if key_value_states is None else self.v(key_value_states)
        )

        # reshape to (batch_size, seq_length, n_heads, head_dim)
        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # counteract scaling in the dot_product_attention_weights function
        # not sure if this is a good idea in equinox