# %%
# package imports from equinox BERT example
import functools
from typing import Callable, Dict, List, Mapping, Optional, Tuple

# import einops  # https://github.com/arogozhnikov/einops
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax  # https://github.com/deepmind/optax
from datasets import load_dataset  # https://github.com/huggingface/datasets
from jaxtyping import Array, Float, Int  # https://github.com/google/jaxtyping
from tqdm import notebook as tqdm  # https://github.com/tqdm/tqdm
from transformers import AutoTokenizer  # https://github.com/huggingface/transformers
from ml_collections import ConfigDict, FrozenConfigDict

# helper functions for attention computation
# they are implemented with jax w/o flax
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
import flax.linen as nn


# %%
class T5LayerNorm(eqx.Module):
    """
    Layer norm in the T5 style: no bias and no subtraction of the mean (RMS norm).
    """
    weight: jax.Array
    eps: float = 1e-6
    # staticmethod keeps the initializer from being bound as an instance method
    weight_init: Callable[..., np.ndarray] = staticmethod(jax.nn.initializers.ones)

    def __init__(
        self: eqx.Module,
        hidden_size: int,
        key: jax.random.PRNGKey,
        # dtype: jnp.dtype = jnp.float32,
    ):
        # self.dtype = dtype
        # self.params = {
        #     'weight': self.weight_init(key, (hidden_size,), dtype)
        # }
        # force the use of float32
        # note that the key argument is ignored by the ones initializer,
        # so key is effectively optional
        self.weight = self.weight_init(key, (hidden_size,), jnp.float32)

    # takes the hidden states as an argument so that it can fall through and remain
    # a pure function
    def __call__(self, hidden_states):
        # always compute the variance in float32 for numerical stability
        variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True)
        hidden_states = hidden_states / jnp.sqrt(variance + self.eps)
        return self.weight * hidden_states


# # %%
# # testing T5LayerNorm
# key = jax.random.PRNGKey(0)
# hidden_size = 128  # example hidden size
# layer_norm = T5LayerNorm(key=key, hidden_size=hidden_size)
# # create some example input data
# hidden_states = jnp.ones((1, 10, hidden_size))  # batch size of 1, sequence length of 10
# # forward pass
# output = layer_norm(hidden_states)
# print("Output shape:", output.shape)


# %%
class KaimingLinear(eqx.Module):
    """
    Linear layer (no bias) whose weights are drawn from a normal distribution
    scaled by a caller-supplied standard deviation.
    """
    weights: jax.Array
    dtype: jnp.dtype = jnp.float32

    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        input_dim: int,
        output_dim: int,
        weights_init_std: float,
        dtype: jnp.dtype = jnp.float32,
    ):
        self.dtype = dtype

        # the initialization strategy is to standardize on the output dimension
        # shape: (input_dim, output_dim)
        self.weights = jax.random.normal(key, (input_dim, output_dim)) * weights_init_std

    def __call__(
        self,
        inputs: Float[Array, " input"],
    ):
        hidden = jnp.dot(inputs, self.weights)
        return hidden
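
# # %%
# # quick shape check for KaimingLinear (sketch added for illustration; mirrors the
# # commented-out tests used for the other modules, the numbers here are arbitrary)
# key = jax.random.PRNGKey(0)
# linear = KaimingLinear(
#     key=key,
#     input_dim=768,
#     output_dim=2048,
#     weights_init_std=768**-0.5,
# )
# inputs = jax.random.normal(jax.random.PRNGKey(1), (10, 768))
# print(linear(inputs).shape)  # expected: (10, 2048)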


# %%
# this module fortunately supports batched operations by default due to broadcasting
class T5DenseActDense(eqx.Module):
    config: FrozenConfigDict
    wi: KaimingLinear
    wo: KaimingLinear
    dropout: eqx.nn.Dropout
    act: Callable
    dtype: jnp.dtype = jnp.float32

    def __init__(
        self: eqx.Module,
        config: FrozenConfigDict,
        dtype: jnp.dtype,
        key: jax.random.PRNGKey,
    ):
        self.config = config
        self.dtype = dtype

        mlp_key, output_key = jax.random.split(key)

        # the initialization strategy is to standardize on the output dimension
        # input projection
        wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
        # shape: (config.d_model, config.d_ff)
        # self.wi = jax.random.normal(mlp_key, (self.config.d_model, self.config.d_ff)) * wi_init_std
        self.wi = KaimingLinear(
            key=mlp_key,
            input_dim=self.config.d_model,
            output_dim=self.config.d_ff,
            weights_init_std=wi_init_std,
            dtype=self.dtype,
        )

        # output projection
        wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)
        # shape: (config.d_ff, config.d_model)
        # self.wo = jax.random.normal(output_key, (self.config.d_ff, self.config.d_model)) * wo_init_std
        self.wo = KaimingLinear(
            key=output_key,
            input_dim=self.config.d_ff,
            output_dim=self.config.d_model,
            weights_init_std=wo_init_std,
            dtype=self.dtype,
        )

        self.dropout = eqx.nn.Dropout(self.config.dropout_rate)
        # just set to relu for now since the smaller T5 variants use relu
        self.act = jax.nn.relu

    def __call__(
        self,
        inputs: Float[Array, " d_model"],
        enable_dropout: bool = False,
        dropout_key: Optional[jax.random.PRNGKey] = None,
    ) -> Float[Array, " d_model"]:
        hidden = self.wi(inputs)
        # hidden = jnp.dot(inputs, self.wi)
        hidden = self.act(hidden)
        hidden = self.dropout(hidden, inference=not enable_dropout, key=dropout_key)
        hidden = self.wo(hidden)
        # hidden = jnp.dot(hidden, self.wo)
        return hidden


# # %%
# # test for T5DenseActDense
# # create fake config
# config_dict = {
#     'd_model': 768,
#     'd_ff': 2048,
#     'dropout_rate': 0.1,
#     'initializer_factor': 1.0,
# }
# # create a FrozenConfigDict from the standard dictionary
# frozen_config = FrozenConfigDict(config_dict)
# # initialize model
# key = jax.random.PRNGKey(0)
# dense = T5DenseActDense(
#     key=key,
#     config=frozen_config,
#     dtype=jnp.float32
# )
# input_key, key = jax.random.split(key)
# inputs = jax.random.normal(input_key, (10, frozen_config.d_model))  # generate random normal values
# dropout_key, key = jax.random.split(key)
# output = dense(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
# output.shape


# %%
class T5LayerFF(eqx.Module):
    config: FrozenConfigDict
    dtype: jnp.dtype
    DenseReluDense: T5DenseActDense
    layer_norm: T5LayerNorm
    dropout: eqx.nn.Dropout

    def __init__(
        self: eqx.Module,
        key: jax.random.PRNGKey,
        config: FrozenConfigDict,
        dtype: jnp.dtype = jnp.float32,
    ):
        self.config = config
        self.dtype = dtype

        dense_key, key = jax.random.split(key)
        # args: key, config, dtype
        self.DenseReluDense = T5DenseActDense(
            key=dense_key,
            config=config,
            dtype=dtype,
        )
        layer_key, key = jax.random.split(key)
        # args: key, hidden_size
        self.layer_norm = T5LayerNorm(
            key=layer_key,
            hidden_size=self.config.d_model,
        )
        # args: dropout_rate
        self.dropout = eqx.nn.Dropout(self.config.dropout_rate)

    def __call__(
        self: eqx.Module,
        inputs: Float[Array, " d_model"],
        enable_dropout: bool = False,
        dropout_key: Optional[jax.random.PRNGKey] = None,
    ):
        forwarded_states = self.layer_norm(inputs)
        # split the dropout key between the dense block and the residual dropout;
        # dropout_key may be None when dropout is disabled
        if dropout_key is not None:
            dense_key, residual_key = jax.random.split(dropout_key)
        else:
            dense_key, residual_key = None, None
        forwarded_states = self.DenseReluDense(
            inputs=forwarded_states,
            enable_dropout=enable_dropout,
            dropout_key=dense_key,
        )
        dropout_states = self.dropout(
            x=forwarded_states,
            inference=not enable_dropout,
            key=residual_key,
        )
        # residual connection
        hidden = inputs + dropout_states
        return hidden


# # %%
# # test for T5LayerFF
# # create fake config
# config_dict = {
#     'd_model': 768,
#     'd_ff': 2048,
#     'dropout_rate': 0.1,
#     'initializer_factor': 1.0,
# }
# # create a FrozenConfigDict from the standard dictionary
# frozen_config = FrozenConfigDict(config_dict)
# # initialize model
# key = jax.random.PRNGKey(0)
# ff_layer = T5LayerFF(
#     key=key,
#     config=frozen_config,
#     dtype=jnp.float32
# )
# input_key, key = jax.random.split(key)
# inputs = jax.random.normal(input_key, (10, frozen_config.d_model))  # generate random normal values
# dropout_key, key = jax.random.split(key)
# output = ff_layer(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
# output.shape
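
# # %%
# # usage sketch (added for illustration; reuses ff_layer and inputs from the test above):
# # dropout needs a PRNG key only when it is enabled; with enable_dropout=False the key
# # argument can be omitted entirely
# train_key = jax.random.PRNGKey(42)
# train_output = ff_layer(inputs=inputs, enable_dropout=True, dropout_key=train_key)
# eval_output = ff_layer(inputs=inputs, enable_dropout=False)
# print(train_output.shape, eval_output.shape)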


# %%
class T5Attention(eqx.Module):
    config: FrozenConfigDict
    dtype: jnp.dtype
    # parameters
    q: KaimingLinear
    k: KaimingLinear
    v: KaimingLinear
    o: KaimingLinear
    # additional terms
    relative_attention_num_buckets: int
    relative_attention_max_distance: int
    d_model: int
    key_value_proj_dim: int
    n_heads: int
    dropout: float
    inner_dim: int
    initializer_factor: float
    has_relative_attention_bias: bool = False
    causal: bool = False  # False for encoder, True for decoder

    def __init__(
        self: eqx.Module,
        config: FrozenConfigDict,
        dtype: jnp.dtype,
        key: jax.random.PRNGKey,
    ):
        self.config = config
        self.dtype = dtype

        self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
        self.relative_attention_max_distance = self.config.relative_attention_max_distance
        self.d_model = self.config.d_model
        # size of the k,v projection for each head
        self.key_value_proj_dim = self.config.d_kv
        self.n_heads = self.config.num_heads
        self.dropout = self.config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim
        self.initializer_factor = self.config.initializer_factor

        q_init_std = self.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
        kv_init_std = self.initializer_factor * (self.inner_dim**-0.5)
        o_init_std = self.initializer_factor * (self.inner_dim**-0.5)

        # the q,k,v projections read inputs of size d_model and emit inner_dim
        # (inner_dim == d_model for the standard T5 configs)
        q_key, key = jax.random.split(key)
        self.q = KaimingLinear(
            key=q_key,
            input_dim=self.d_model,
            output_dim=self.inner_dim,
            weights_init_std=q_init_std,
            dtype=self.dtype,
        )
        k_key, key = jax.random.split(key)
        self.k = KaimingLinear(
            key=k_key,
            input_dim=self.d_model,
            output_dim=self.inner_dim,
            weights_init_std=kv_init_std,
            dtype=self.dtype,
        )
        v_key, key = jax.random.split(key)
        self.v = KaimingLinear(
            key=v_key,
            input_dim=self.d_model,
            output_dim=self.inner_dim,
            weights_init_std=kv_init_std,
            dtype=self.dtype,
        )
        o_key, key = jax.random.split(key)
        self.o = KaimingLinear(
            key=o_key,
            input_dim=self.inner_dim,
            output_dim=self.d_model,
            weights_init_std=o_init_std,
            dtype=self.dtype,
        )
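
    # worked example of the bucketing below (defaults: bidirectional=True,
    # num_buckets=32, max_distance=128): the 32 buckets split into 16 for keys at or
    # before the query (relative_position <= 0) and 16 for keys after it; within each
    # half, offsets 0..7 get their own exact bucket and offsets 8..127 share 8
    # logarithmically spaced buckets, with everything beyond max_distance clipped into
    # the last one. e.g. relative_position = -3 -> bucket 3, +3 -> bucket 19,
    # -50 -> bucket 13, -200 -> bucket 15 (illustrative values, derived by hand).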

    @staticmethod
    def _relative_position_bucket(
        relative_position, bidirectional=True, num_buckets=32, max_distance=128
    ):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative
        position is defined as memory_position - query_position, i.e. the distance in tokens
        from the attending position to the attended-to position. If bidirectional=False, then
        positive relative positions are invalid. We use smaller buckets for small absolute
        relative_position and larger buckets for larger absolute relative_positions. All
        relative positions >= max_distance map to the same bucket. All relative positions
        <= -max_distance map to the same bucket. This should allow for more graceful
        generalization to longer sequences than the model has been trained on.
        """
        relative_buckets = 0

        # bidirectional determines whether positive relative positions are valid
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0) * num_buckets
            relative_position = jnp.abs(relative_position)
        else:
            # relative position range of [0, inf]
            relative_position = -jnp.clip(relative_position, None, 0)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        # boolean used to assign relative buckets later
        is_small = relative_position < max_exact

        # the other half are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            jnp.log(relative_position / max_exact)
            / jnp.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        )
        relative_position_if_large = jnp.clip(relative_position_if_large, None, num_buckets - 1)

        # jnp.where(condition, x, y): true -> x, false -> y
        # accumulate into relative_buckets so every element ends up with the correct
        # bucket index, whether it falls in the small (exact) or large (log-spaced) regime
        relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets.astype("i4")

    # the bias gives a weight based on relative distance, on top of the attention score
    def compute_bias(self, query_length, key_length):
        """
        Compute binned relative position bias
        """
        # arange in the first dim
        context_position = jnp.arange(query_length, dtype="i4")[:, None]
        # arange in the second dim
        memory_position = jnp.arange(key_length, dtype="i4")[None, :]
        # The relative position is defined as memory_position - query_position,
        # i.e. the distance in tokens from the attending position to the
        # attended-to position.
        #
        # 2D array where each entry represents the distance from a query token
        # to a key token
        relative_position = memory_position - context_position
        # now we apply the earlier bucket creation function
        relative_position_bucket = self._relative_position_bucket(
            relative_position=relative_position,
            bidirectional=(not self.causal),  # causal during decode -> not bidirectional
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        # retrieve the bias values from the relative_attention_bias embedding table
        # (present in the HF flax reference when has_relative_attention_bias=True;
        # this module does not define it yet, so this branch is currently unreachable)
        # shape (query_length, key_length, n_heads)
        values = self.relative_attention_bias(relative_position_bucket)
        # shape (1, n_heads, query_length, key_length)
        # ready for attention
        values = values.transpose((2, 0, 1))[None, :, :, :]
        return values

    # from (batch_size, seq_length, inner_dim) to
    # (batch_size, seq_length, n_heads, head_dim)
    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))

    # from (batch_size, seq_length, n_heads, head_dim) to
    # (batch_size, seq_length, inner_dim)
    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,))

    def _create_position_bias(
        self,
        key_states,
        query_states,
        attention_mask,
    ):
        # unlike the flax version, we don't even check for a cache
        key_length = key_states.shape[1]
        query_length = query_states.shape[1]

        if self.has_relative_attention_bias:
            position_bias = self.compute_bias(query_length, key_length)
        elif attention_mask is not None:
            position_bias = jnp.zeros_like(attention_mask)
        else:
            position_bias = jnp.zeros(
                (1, self.n_heads, query_length, key_length), dtype=self.dtype
            )

        return position_bias

    def __call__(
        self,
        inputs,
        attention_mask=None,
        key_value_states=None,
        position_bias=None,
        output_attentions=False,
        enable_dropout=False,
        dropout_key: Optional[jax.random.PRNGKey] = None,
    ):
        # expected input shape: (batch_size, seq_len, d_model)
        # expected output: a tuple (attn_output, position_bias), where attn_output has
        # the same shape as the input
        batch_size, seq_length = inputs.shape[:2]

        # q,k,v projections, each of shape (batch_size, seq_length, inner_dim)
        query_states = self.q(inputs)
        key_states = (
            self.k(inputs) if key_value_states is None else self.k(key_value_states)
        )
        value_states = (
            self.v(inputs) if key_value_states is None else self.v(key_value_states)
        )

        # reshape to (batch_size, seq_length, n_heads, head_dim)
        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # counteract the scaling applied inside dot_product_attention_weights
        query_states *= jnp.sqrt(query_states.shape[-1])

        # create causal attention_mask
        if self.causal:
            causal_attention_mask = make_causal_mask(attention_mask, dtype="bool")

            # broadcast causal attention mask & attention mask to fit for merge
            causal_attention_mask = jnp.broadcast_to(
                causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:]
            )
            attention_mask = jnp.broadcast_to(
                jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape
            )
            attention_mask = combine_masks(attention_mask, causal_attention_mask)
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # replace masked positions with a large negative bias
        if attention_mask is not None:
            mask_value = jnp.finfo(self.dtype).min
            attention_mask = jax.lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, mask_value).astype(self.dtype),
            )

        if position_bias is None:
            # compute position bias (only for the first layer)
            position_bias = self._create_position_bias(
                key_states, query_states, attention_mask
            )

            if attention_mask is not None:
                position_bias = position_bias + attention_mask

        # Softmax(QK^T + position_bias); the 1/sqrt(d) inside
        # dot_product_attention_weights is cancelled by the pre-scaling above
        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=position_bias,
            dropout_rng=dropout_key,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=not enable_dropout,
            dtype=self.dtype,
        )

        # multiply with the value states
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)

        # bring back to (batch_size, seq_length, inner_dim)
        attn_output = self._merge_heads(attn_output)

        # apply the output matrix
        attn_output = self.o(attn_output)

        outputs = (attn_output, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)

        return outputs
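
# # %%
# # sanity-check sketch for the relative position bucketing (added for illustration;
# # expected values were computed by hand from the defaults)
# offsets = jnp.array([-200, -50, -3, 0, 3, 50])
# buckets = T5Attention._relative_position_bucket(offsets)
# print(buckets)  # roughly expected: [15 13  3  0 19 29]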

# # %%
# # test for T5Attention
# # create fake config
# config_dict = {
#     'relative_attention_num_buckets': 32,
#     'relative_attention_max_distance': 128,
#     'd_model': 768,  # 64 * 12
#     'd_kv': 64,
#     'num_heads': 12,
#     'dropout_rate': 0.1,
#     'initializer_factor': 1.0,
# }
# # create a FrozenConfigDict from the standard dictionary
# frozen_config = FrozenConfigDict(config_dict)
# # initialize model
# key = jax.random.PRNGKey(0)
# attn_layer = T5Attention(
#     key=key,
#     config=frozen_config,
#     dtype=jnp.float32
# )
# input_key, key = jax.random.split(key)
# # inputs = jax.random.normal(input_key, (10, frozen_config.d_model))  # generate random normal values
# batch_size = 1
# seq_length = 10
# inputs = jnp.ones((batch_size, seq_length, frozen_config.d_model))
# dropout_key, key = jax.random.split(key)
# output = attn_layer(inputs=inputs, enable_dropout=False, dropout_key=dropout_key)
# print(len(output))
# print(output[0].shape)

# %%
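
# # %%
# # usage sketch (added for illustration; reuses attn_layer and inputs from the test
# # above): pass a (batch_size, seq_length) padding mask with 1 for real tokens and
# # 0 for padding; masked positions receive a large negative bias before the softmax
# attention_mask = jnp.ones((batch_size, seq_length)).at[:, -2:].set(0)  # pad last 2 tokens
# masked_output = attn_layer(
#     inputs=inputs,
#     attention_mask=attention_mask,
#     enable_dropout=False,
# )
# print(masked_output[0].shape)  # same shape as inputs: (1, 10, 768)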