# coding=utf-8 # Copyright 2021 T5 Authors and HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy from typing import Callable, Optional, Tuple, Dict from collections import OrderedDict, UserDict from dataclasses import fields, is_dataclass import flax.linen as nn import jax import jax.numpy as jnp import numpy as np from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask from flax.linen import partitioning as nn_partitioning from flax.linen.attention import dot_product_attention_weights from flax.traverse_util import flatten_dict, unflatten_dict from jax.random import PRNGKey from ml_collections import ConfigDict from transformers.modeling_flax_outputs import ( FlaxBaseModelOutput, FlaxBaseModelOutputWithPastAndCrossAttentions, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput, FlaxSeq2SeqModelOutput, ) import flax from transformers.modeling_flax_utils import ( ACT2FN, FlaxPreTrainedModel, ) # from transformers import T5Config from jax.experimental.shard_map import shard_map from jax.sharding import Mesh, NamedSharding from jax.sharding import PartitionSpec as P from ml_collections import FrozenConfigDict import functools from typing import Any, Dict, Tuple, Callable, Sequence PyTree = Any Metrics = Dict[str, Tuple[jax.Array, ...]] remat = nn_partitioning.remat # MARK: PARAMETER SHARDING # %% [markdown] # # parameter sharding # Basic strategy: init full parameters on each device, then use # jax.lax.axis_index to split parameters across devices, and keep a shard on # each device # # use nn.Partitioned to annotate sharding spec on parameters # quite similar to PartitionSpec # # parameters are either jax.Array or a flax.linen.Partitioned # t5-base # https://huggingface.co/google-t5/t5-base/blob/main/config.json def make_config(): model_type="t5" # keys_to_ignore_at_inference = ["past_key_values"] # attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} d_ff=3072 d_kv=64 d_model=768 vocab_size=32128 num_layers=12 num_heads=12 num_decoder_layers=None num_decoder_layers = ( num_decoder_layers if num_decoder_layers is not None else num_layers ) # default = symmetry relative_attention_num_buckets=32 relative_attention_max_distance=128 dropout_rate=0.1 layer_norm_epsilon=1e-6 initializer_factor=1.0 feed_forward_proj="relu" is_encoder_decoder=True use_cache=True pad_token_id=0 eos_token_id=1 decoder_start_token_id=0 classifier_dropout=0.0 act_info = feed_forward_proj.split("-") dense_act_fn = act_info[-1] is_gated_act = act_info[0] == "gated" causal=False use_return_dict=False if feed_forward_proj == "gated-gelu": dense_act_fn = "gelu_new" tie_word_embeddings=False return FrozenConfigDict( dict( model_type=model_type, vocab_size=vocab_size, d_model=d_model, d_kv=d_kv, d_ff=d_ff, num_layers=num_layers, num_decoder_layers = num_decoder_layers, num_heads=num_heads, relative_attention_num_buckets=relative_attention_num_buckets, relative_attention_max_distance=relative_attention_max_distance, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon, initializer_factor=initializer_factor, feed_forward_proj=feed_forward_proj, is_encoder_decoder=is_encoder_decoder, use_cache=use_cache, pad_token_id=pad_token_id, eos_token_id=eos_token_id, decoder_start_token_id=decoder_start_token_id, classifier_dropout=classifier_dropout, is_gated_act = is_gated_act, dense_act_fn = dense_act_fn, causal=causal, use_return_dict=use_return_dict, tie_word_embeddings=tie_word_embeddings ) ) T5config = make_config() # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ shifted_input_ids = jnp.zeros_like(input_ids) shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1]) shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id) shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) return shifted_input_ids class FlaxT5LayerNorm(nn.Module): hidden_size: int dtype: jnp.dtype = jnp.float32 eps: float = 1e-6 weight_init: Callable[..., np.ndarray] = jax.nn.initializers.ones def setup(self): self.weight = self.param("weight", self.weight_init, (self.hidden_size,)) def __call__(self, hidden_states): """ Construct a layernorm module in the T5 style; No bias and no subtraction of mean. """ # layer norm should always be calculated in float32 variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True) hidden_states = hidden_states / jnp.sqrt(variance + self.eps) return self.weight * hidden_states class FlaxT5DenseActDense(nn.Module): config: ConfigDict dtype: jnp.dtype = jnp.float32 def setup(self): wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) self.wi = nn.Dense( self.config.d_ff, use_bias=False, kernel_init=nn.with_partitioning( jax.nn.initializers.normal(wi_init_std), (None, 'model') ), dtype=self.dtype, ) self.wo = nn.Dense( self.config.d_model, use_bias=False, kernel_init=nn.with_partitioning( jax.nn.initializers.normal(wo_init_std), ("model", None) ), dtype=self.dtype, ) self.dropout = nn.Dropout(self.config.dropout_rate) self.act = ACT2FN[self.config.dense_act_fn] def __call__(self, hidden_states, deterministic=True): hidden_states = self.wi(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.dropout(hidden_states, deterministic=deterministic) hidden_states = self.wo(hidden_states) return hidden_states class FlaxT5DenseGatedActDense(nn.Module): config: ConfigDict dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) self.wi_0 = nn.Dense( self.config.d_ff, use_bias=False, kernel_init=nn.with_partitioning( jax.nn.initializers.normal(wi_init_std), (None, "model") ), dtype=self.dtype, ) self.wi_1 = nn.Dense( self.config.d_ff, use_bias=False, kernel_init=nn.with_partitioning( jax.nn.initializers.normal(wi_init_std), (None, "model") ), dtype=self.dtype, ) self.wo = nn.Dense( self.config.d_model, use_bias=False, kernel_init=nn.with_partitioning( jax.nn.initializers.normal(wo_init_std), ("model", None) ), dtype=self.dtype, ) self.dropout = nn.Dropout(self.config.dropout_rate) self.act = ACT2FN[self.config.dense_act_fn] def __call__(self, hidden_states, deterministic): hidden_gelu = self.act(self.wi_0(hidden_states)) hidden_linear = self.wi_1(hidden_states) hidden_states = hidden_gelu * hidden_linear hidden_states = self.dropout(hidden_states, deterministic=deterministic) hidden_states = self.wo(hidden_states) return hidden_states class FlaxT5LayerFF(nn.Module): config: ConfigDict dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): if self.config.is_gated_act: self.DenseReluDense = FlaxT5DenseGatedActDense(self.config, dtype=self.dtype) else: self.DenseReluDense = FlaxT5DenseActDense(self.config, dtype=self.dtype) self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) self.dropout = nn.Dropout(self.config.dropout_rate) def __call__(self, hidden_states, deterministic=True): forwarded_states = self.layer_norm(hidden_states) forwarded_states = self.DenseReluDense(forwarded_states, deterministic=deterministic) hidden_states = hidden_states + self.dropout(forwarded_states, deterministic=deterministic) return hidden_states class FlaxT5Attention(nn.Module): config: ConfigDict has_relative_attention_bias: bool = False causal: bool = False dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): self.relative_attention_num_buckets = self.config.relative_attention_num_buckets self.relative_attention_max_distance = self.config.relative_attention_max_distance self.d_model = self.config.d_model self.key_value_proj_dim = self.config.d_kv self.n_heads = self.config.num_heads self.dropout = self.config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5) kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) self.q = nn.Dense( self.inner_dim, use_bias=False, kernel_init=nn.with_partitioning( jax.nn.initializers.normal(q_init_std), (None, "model") ), dtype=self.dtype, ) self.k = nn.Dense( self.inner_dim, use_bias=False, kernel_init=nn.with_partitioning( jax.nn.initializers.normal(kv_init_std), (None, "model") ), dtype=self.dtype, ) self.v = nn.Dense( self.inner_dim, use_bias=False, kernel_init=nn.with_partitioning( jax.nn.initializers.normal(kv_init_std), (None, "model") ), dtype=self.dtype, ) self.o = nn.Dense( self.d_model, use_bias=False, kernel_init=nn.with_partitioning( jax.nn.initializers.normal(o_init_std), ("model", None) ), dtype=self.dtype, ) if self.has_relative_attention_bias: self.relative_attention_bias = nn.Embed( self.relative_attention_num_buckets, self.n_heads, embedding_init=jax.nn.initializers.normal(kv_init_std), dtype=self.dtype, ) @staticmethod def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): """ Adapted from Mesh Tensorflow: https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 Translate relative position to a bucket number for relative attention. The relative position is defined as memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for small absolute relative_position and larger buckets for larger absolute relative_positions. All relative positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. This should allow for more graceful generalization to longer sequences than the model has been trained on """ relative_buckets = 0 if bidirectional: num_buckets //= 2 relative_buckets += (relative_position > 0) * num_buckets relative_position = jnp.abs(relative_position) else: relative_position = -jnp.clip(relative_position, a_max=0) # now relative_position is in the range [0, inf) # half of the buckets are for exact increments in positions max_exact = num_buckets // 2 is_small = relative_position < max_exact # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance relative_position_if_large = max_exact + ( jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact) ) relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1) relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large) return relative_buckets.astype("i4") def compute_bias(self, query_length, key_length): """Compute binned relative position bias""" context_position = jnp.arange(query_length, dtype="i4")[:, None] memory_position = jnp.arange(key_length, dtype="i4")[None, :] relative_position = memory_position - context_position relative_position_bucket = self._relative_position_bucket( relative_position, bidirectional=(not self.causal), num_buckets=self.relative_attention_num_buckets, max_distance=self.relative_attention_max_distance, ) values = self.relative_attention_bias(relative_position_bucket) values = values.transpose((2, 0, 1))[None, :, :, :] return values def _split_heads(self, hidden_states): return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim)) def _merge_heads(self, hidden_states): return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,)) @nn.compact def _concatenate_to_cache(self, key, value, query, attention_mask): """ This function takes projected key, value states from a single input token and concatenates the states to cached states from previous steps. This function is slighly adapted from the official Flax repository: https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 """ # detect if we're initializing by absence of existing cache data. is_initialized = self.has_variable("cache", "cached_key") cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) if is_initialized: *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape # update key, value caches with our new 1d spatial slices cur_index = cache_index.value indices = (0,) * len(batch_dims) + (cur_index, 0, 0) key = jax.lax.dynamic_update_slice(cached_key.value, key, indices) value = jax.lax.dynamic_update_slice(cached_value.value, value, indices) cached_key.value = key cached_value.value = value num_updated_cache_vectors = query.shape[1] cache_index.value = cache_index.value + num_updated_cache_vectors # causal mask for cached decoder self-attention: our single query position should only attend to those key positions # that have already been generated and cached, not the remaining zero elements. pad_mask = jnp.broadcast_to( jnp.arange(max_length) < cur_index + num_updated_cache_vectors, tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), ) attention_mask = combine_masks(pad_mask, attention_mask) return key, value, attention_mask def _create_position_bias( self, key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift ): cache_is_filled = self.causal and self.has_variable("cache", "cached_key") and (not init_cache) key_length = key_states.shape[1] query_length = key_length if cache_is_filled else query_states.shape[1] if self.has_relative_attention_bias: position_bias = self.compute_bias(query_length, key_length) elif attention_mask is not None: position_bias = jnp.zeros_like(attention_mask) else: position_bias = jnp.zeros((1, self.n_heads, query_length, key_length), dtype=self.dtype) # if key and values are already calculated, only the last query position bias should be taken if cache_is_filled: max_decoder_length = self.variables["cache"]["cached_key"].shape[1] position_bias = jax.lax.dynamic_slice( position_bias, (0, 0, causal_attention_mask_shift, 0), (1, self.n_heads, seq_length, max_decoder_length), ) return position_bias def __call__( self, hidden_states, attention_mask=None, key_value_states=None, position_bias=None, use_cache=False, output_attentions=False, deterministic=True, init_cache=False, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ batch_size, seq_length = hidden_states.shape[:2] # q, k, v projections query_states = self.q(hidden_states) # (batch_size, n_heads, seq_length, dim_per_head) key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states) value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states) # reshape to (batch_size, seq_length, n_heads, head_dim) query_states = self._split_heads(query_states) key_states = self._split_heads(key_states) value_states = self._split_heads(value_states) # counter-act scaling in dot_product_attention_weights function query_states *= jnp.sqrt(query_states.shape[-1]) # for fast decoding causal attention mask should be shifted causal_attention_mask_shift = ( self.variables["cache"]["cache_index"] if (self.has_variable("cache", "cached_key") and self.causal) else 0 ) # create causal attention_mask; attention_mask has to be defined when model is causal if self.causal: causal_attention_mask = make_causal_mask(attention_mask, dtype="bool") # fast decoding for generate requires special attention_mask if self.has_variable("cache", "cached_key"): max_decoder_length = self.variables["cache"]["cached_key"].shape[1] causal_attention_mask = jax.lax.dynamic_slice( causal_attention_mask, (0, 0, causal_attention_mask_shift, 0), (1, 1, seq_length, max_decoder_length), ) # broadcast causal attention mask & attention mask to fit for merge causal_attention_mask = jnp.broadcast_to( causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:] ) attention_mask = jnp.broadcast_to( jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape ) attention_mask = combine_masks(attention_mask, causal_attention_mask) elif attention_mask is not None: attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) # During fast autoregressive decoding, we feed one position at a time, # and cache the keys and values step by step. if self.causal and (self.has_variable("cache", "cached_key") or init_cache): key_states, value_states, attention_mask = self._concatenate_to_cache( key_states, value_states, query_states, attention_mask ) # replace masked positions with -10_000 if attention_mask is not None: mask_value = jnp.finfo(self.dtype).min attention_mask = jax.lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), jnp.full(attention_mask.shape, mask_value).astype(self.dtype), ) if position_bias is None: # compute position bias (only for first layer) position_bias = self._create_position_bias( key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift ) if attention_mask is not None: position_bias = position_bias + attention_mask # create dropout rng dropout_rng = None if not deterministic and self.dropout > 0.0: dropout_rng = self.make_rng("dropout") # Softmax(QK^T) attn_weights = dot_product_attention_weights( query_states, key_states, bias=position_bias, dropout_rng=dropout_rng, dropout_rate=self.dropout, broadcast_dropout=True, deterministic=deterministic, dtype=self.dtype, ) # multiply with value states attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) # bring back to (batch_size, seq_length, d_model) attn_output = self._merge_heads(attn_output) # apply output matrix attn_output = self.o(attn_output) outputs = (attn_output, position_bias) if output_attentions: outputs = outputs + (attn_weights,) return outputs class FlaxT5LayerSelfAttention(nn.Module): config: ConfigDict has_relative_attention_bias: bool = False dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): self.SelfAttention = FlaxT5Attention( config=self.config, has_relative_attention_bias=self.has_relative_attention_bias, causal=self.config.causal, dtype=self.dtype, ) self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) self.dropout = nn.Dropout(self.config.dropout_rate) def __call__( self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, deterministic=True, init_cache=False, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( normed_hidden_states, attention_mask=attention_mask, position_bias=position_bias, output_attentions=output_attentions, deterministic=deterministic, init_cache=init_cache, ) hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them return outputs class FlaxT5LayerCrossAttention(nn.Module): config: ConfigDict dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): self.EncDecAttention = FlaxT5Attention( config=self.config, has_relative_attention_bias=False, causal=False, dtype=self.dtype ) self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) self.dropout = nn.Dropout(self.config.dropout_rate) def __call__( self, hidden_states, key_value_states, attention_mask=None, position_bias=None, output_attentions=False, deterministic=True, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( normed_hidden_states, attention_mask=attention_mask, key_value_states=key_value_states, position_bias=position_bias, output_attentions=output_attentions, ) hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them return outputs class FlaxT5Block(nn.Module): config: ConfigDict has_relative_attention_bias: bool = False dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): self.causal = self.config.causal self.layer = ( FlaxT5LayerSelfAttention( config=self.config, has_relative_attention_bias=self.has_relative_attention_bias, name=str(0), dtype=self.dtype, ), ) feed_forward_index = 1 if self.causal: self.layer += (FlaxT5LayerCrossAttention(self.config, name=str(1), dtype=self.dtype),) feed_forward_index += 1 self.layer += (FlaxT5LayerFF(self.config, name=str(feed_forward_index), dtype=self.dtype),) def __call__( self, hidden_states, attention_mask=None, position_bias=None, encoder_hidden_states=None, encoder_attention_mask=None, encoder_decoder_position_bias=None, output_attentions=False, return_dict=True, deterministic=True, init_cache=False, ): self_attention_outputs = self.layer[0]( hidden_states, attention_mask=attention_mask, position_bias=position_bias, output_attentions=output_attentions, deterministic=deterministic, init_cache=init_cache, ) hidden_states = self_attention_outputs[0] attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights do_cross_attention = self.causal and encoder_hidden_states is not None if do_cross_attention: cross_attention_outputs = self.layer[1]( hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, output_attentions=output_attentions, deterministic=deterministic, ) hidden_states = cross_attention_outputs[0] # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[1:] # Apply Feed Forward layer hidden_states = self.layer[-1](hidden_states, deterministic=deterministic) outputs = (hidden_states,) outputs = outputs + attention_outputs # returns hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), # (cross-attention position bias), (cross-attention weights) return outputs class FlaxT5LayerCollection(nn.Module): config: ConfigDict has_relative_attention_bias: bool dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): self.layer = FlaxT5Block( self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype ) def __call__( self, hidden_states, attention_mask=None, position_bias=None, encoder_hidden_states=None, encoder_attention_mask=None, encoder_decoder_position_bias=None, output_attentions=False, deterministic=True, init_cache=False, ): return self.layer( hidden_states, attention_mask=attention_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, output_attentions=output_attentions, deterministic=deterministic, init_cache=init_cache, ) class FlaxT5BlockCollection(nn.Module): config: ConfigDict dtype: jnp.dtype = jnp.float32 # the dtype of the computation gradient_checkpointing: bool = False def setup(self): self.causal = self.config.causal if self.gradient_checkpointing: FlaxT5CheckpointLayer = remat(FlaxT5LayerCollection, static_argnums=(6, 7, 8)) self.blocks = [ FlaxT5CheckpointLayer( self.config, has_relative_attention_bias=(i == 0), dtype=self.dtype, name=str(i), ) for i in range(self.config.num_layers) ] else: self.blocks = [ FlaxT5LayerCollection( self.config, has_relative_attention_bias=(i == 0), dtype=self.dtype, name=str(i), ) for i in range(self.config.num_layers) ] def __call__( self, hidden_states=None, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, output_attentions: bool = False, output_hidden_states: bool = False, deterministic: bool = True, init_cache: bool = False, ): # Prepare head mask if needed all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and self.causal) else None position_bias = None encoder_decoder_position_bias = None for i, layer_module in enumerate(self.blocks): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_outputs = layer_module( hidden_states, attention_mask, position_bias, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias, output_attentions, deterministic, init_cache, ) hidden_states = layer_outputs[0] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), # (cross-attention position bias), (cross-attention weights) position_bias = layer_outputs[1] if self.causal and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] if output_attentions: all_attentions = all_attentions + (layer_outputs[2],) if self.causal: all_cross_attentions = all_cross_attentions + (layer_outputs[4],) return FlaxBaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) class FlaxT5Stack(nn.Module): config: ConfigDict embed_tokens: nn.Embed dtype: jnp.dtype = jnp.float32 # the dtype of the computation gradient_checkpointing: bool = False def setup(self): self.causal = self.config.causal self.block = FlaxT5BlockCollection( config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing ) self.final_layer_norm = FlaxT5LayerNorm( self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype ) self.dropout = nn.Dropout(self.config.dropout_rate) def __call__( self, input_ids=None, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, deterministic: bool = True, init_cache: bool = False, ): hidden_states = self.embed_tokens(input_ids) hidden_states = self.dropout(hidden_states, deterministic=deterministic) outputs = self.block( hidden_states, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, deterministic=deterministic, init_cache=init_cache, ) hidden_states = outputs[0] hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states, deterministic=deterministic) # Add last layer all_hidden_states = None if output_hidden_states: all_hidden_states = outputs.hidden_states all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: if output_hidden_states: return ( hidden_states, all_hidden_states, ) + outputs[2:] return (hidden_states,) + outputs[1:] return FlaxBaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=outputs.attentions, cross_attentions=outputs.cross_attentions, ) class FlaxT5PreTrainedModel(FlaxPreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = ConfigDict base_model_prefix = "transformer" module_class: nn.Module = None def __init__( self, config: ConfigDict, input_shape: Tuple[int] = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.bfloat16, _do_init: bool = True, gradient_checkpointing: bool = False, **kwargs, ): module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) def enable_gradient_checkpointing(self): self._module = self.module_class( config=self.config, dtype=self.dtype, gradient_checkpointing=True, ) def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) args = [input_ids, attention_mask] if self.module_class not in [FlaxT5EncoderModule]: decoder_input_ids = jnp.ones_like(input_ids) decoder_attention_mask = jnp.ones_like(input_ids) args.extend([decoder_input_ids, decoder_attention_mask]) params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} random_params = self.module.init( rngs, *args, )["params"] if params is not None: random_params = flatten_dict(unfreeze(random_params)) params = flatten_dict(unfreeze(params)) for missing_key in self._missing_keys: params[missing_key] = random_params[missing_key] self._missing_keys = set() return freeze(unflatten_dict(params)) else: return random_params def __call__( self, input_ids: jnp.ndarray, attention_mask: Optional[jnp.ndarray] = None, decoder_input_ids: jnp.ndarray = None, decoder_attention_mask: Optional[jnp.ndarray] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, train: bool = False, params: dict = None, dropout_rng: PRNGKey = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.return_dict if decoder_input_ids is None: raise ValueError( "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed" " here." ) # prepare encoder inputs if attention_mask is None: attention_mask = jnp.ones_like(input_ids) # prepare decoder inputs if decoder_attention_mask is None: decoder_attention_mask = jnp.ones_like(decoder_input_ids) # Handle any PRNG if needed rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} return self.module.apply( {"params": params or self.params}, input_ids=jnp.array(input_ids, dtype="i4"), attention_mask=jnp.array(attention_mask, dtype="i4"), decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=not train, rngs=rngs, ) def init_cache(self, batch_size, max_length, encoder_outputs): r""" Args: batch_size (`int`): batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. max_length (`int`): maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized cache. encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. """ # init input variables to retrieve cache decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") decoder_attention_mask = jnp.ones_like(decoder_input_ids) def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): decoder_module = module._get_decoder_module() return decoder_module( decoder_input_ids, decoder_attention_mask, **kwargs, ) init_variables = self.module.init( jax.random.PRNGKey(0), decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, encoder_hidden_states=encoder_outputs[0], init_cache=True, method=_decoder_forward, # we only need to call the decoder to init the cache ) return unfreeze(init_variables["cache"]) def encode( self, input_ids: jnp.ndarray, attention_mask: Optional[jnp.ndarray] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, train: bool = False, params: dict = None, dropout_rng: PRNGKey = None, ): r""" Returns: Example: ```python >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") >>> text = "My friends are cool but they eat too many carbs." >>> inputs = tokenizer(text, return_tensors="np") >>> encoder_outputs = model.encode(**inputs) ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.return_dict if attention_mask is None: attention_mask = jnp.ones_like(input_ids) # Handle any PRNG if needed rngs = {} if dropout_rng is not None: rngs["dropout"] = dropout_rng def _encoder_forward(module, input_ids, attention_mask, **kwargs): encode_module = module._get_encoder_module() return encode_module(input_ids, attention_mask, **kwargs) return self.module.apply( {"params": params or self.params}, input_ids=jnp.array(input_ids, dtype="i4"), attention_mask=jnp.array(attention_mask, dtype="i4"), output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=not train, rngs=rngs, method=_encoder_forward, ) def decode( self, decoder_input_ids, encoder_outputs, encoder_attention_mask: Optional[jnp.ndarray] = None, decoder_attention_mask: Optional[jnp.ndarray] = None, past_key_values: dict = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, train: bool = False, params: dict = None, dropout_rng: PRNGKey = None, ): r""" Returns: Example: ```python >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration >>> import jax.numpy as jnp >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") >>> text = "My friends are cool but they eat too many carbs." >>> inputs = tokenizer(text, return_tensors="np") >>> encoder_outputs = model.encode(**inputs) >>> decoder_start_token_id = model.config.decoder_start_token_id >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id >>> outputs = model.decode(decoder_input_ids, encoder_outputs) >>> logits = outputs.logits ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.return_dict encoder_hidden_states = encoder_outputs[0] if encoder_attention_mask is None: batch_size, sequence_length = encoder_hidden_states.shape[:2] encoder_attention_mask = jnp.ones((batch_size, sequence_length)) batch_size, sequence_length = decoder_input_ids.shape if decoder_attention_mask is None: decoder_attention_mask = jnp.ones((batch_size, sequence_length)) # Handle any PRNG if needed rngs = {} if dropout_rng is not None: rngs["dropout"] = dropout_rng inputs = {"params": params or self.params} # if past_key_values are passed then cache is already initialized a private flag init_cache has to be # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that # it can be changed by FlaxT5Attention module if past_key_values: inputs["cache"] = past_key_values mutable = ["cache"] else: mutable = False def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): decoder_module = module._get_decoder_module() return decoder_module( decoder_input_ids, decoder_attention_mask, **kwargs, ) outputs = self.module.apply( inputs, decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=not train, rngs=rngs, mutable=mutable, method=_decoder_forward, ) # add updated cache to model output if past_key_values is not None and return_dict: outputs, past = outputs outputs["past_key_values"] = unfreeze(past["cache"]) return outputs elif past_key_values is not None and not return_dict: outputs, past = outputs outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] return outputs class FlaxT5Module(nn.Module): config: ConfigDict dtype: jnp.dtype = jnp.float32 # the dtype of the computation gradient_checkpointing: bool = False def _get_encoder_module(self): return self.encoder def _get_decoder_module(self): return self.decoder def setup(self): # sharded module # with manual shard_map # sharded_embed = shard_module_params( # nn.Embed, # axis_name="model", # ) # self.shared = sharded_embed( # self.config.vocab_size, # self.config.d_model, # embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0), # dtype=self.dtype, # ) self.shared = nn.Embed( self.config.vocab_size, self.config.d_model, embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0), dtype=self.dtype, ) encoder_config = copy.deepcopy(self.config) # unfreeze encoder_config = ConfigDict(encoder_config) encoder_config.causal = False encoder_config = FrozenConfigDict(encoder_config) # freeze self.encoder = FlaxT5Stack( encoder_config, embed_tokens=self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing, ) decoder_config = copy.deepcopy(self.config) # unfreeze decoder_config = ConfigDict(encoder_config) decoder_config.causal = True decoder_config.num_layers = self.config.num_decoder_layers # freeze decoder_config = FrozenConfigDict(encoder_config) self.decoder = FlaxT5Stack( decoder_config, embed_tokens=self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing, ) def __call__( self, input_ids=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, encoder_outputs=None, output_attentions=None, output_hidden_states=None, return_dict=None, deterministic: bool = True, ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Encode if needed (training, first prediction pass) encoder_outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=deterministic, ) # Decode decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=deterministic, ) # if not return_dict: return decoder_outputs + encoder_outputs # return FlaxSeq2SeqModelOutput( # last_hidden_state=decoder_outputs.last_hidden_state, # past_key_values=decoder_outputs.past_key_values, # decoder_hidden_states=decoder_outputs.hidden_states, # decoder_attentions=decoder_outputs.attentions, # cross_attentions=decoder_outputs.cross_attentions, # encoder_last_hidden_state=encoder_outputs.last_hidden_state, # encoder_hidden_states=encoder_outputs.hidden_states, # encoder_attentions=encoder_outputs.attentions, # ) class FlaxT5Model(FlaxT5PreTrainedModel): module_class = FlaxT5Module class FlaxT5EncoderModule(nn.Module): config: ConfigDict dtype: jnp.dtype = jnp.float32 # the dtype of the computation gradient_checkpointing: bool = False def setup(self): self.shared = nn.Embed( self.config.vocab_size, self.config.d_model, embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0), dtype=self.dtype, ) encoder_config = copy.deepcopy(self.config) encoder_config.is_decoder = False encoder_config.is_encoder_decoder = False encoder_config.causal = False self.encoder = FlaxT5Stack( encoder_config, embed_tokens=self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing, ) def __call__( self, input_ids=None, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict: bool = True, deterministic: bool = True, ): # Encode if needed (training, first prediction pass) encoder_outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=deterministic, ) return encoder_outputs class FlaxT5EncoderModel(FlaxT5PreTrainedModel): module_class = FlaxT5EncoderModule def __call__( self, input_ids: jnp.ndarray, attention_mask: Optional[jnp.ndarray] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, train: bool = False, params: dict = None, dropout_rng: PRNGKey = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.return_dict # prepare encoder inputs if attention_mask is None: attention_mask = jnp.ones_like(input_ids) # Handle any PRNG if needed rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} return self.module.apply( {"params": params or self.params}, input_ids=jnp.array(input_ids, dtype="i4"), attention_mask=jnp.array(attention_mask, dtype="i4"), output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=not train, rngs=rngs, ) class FlaxT5ForConditionalGenerationModule(nn.Module): config: ConfigDict dtype: jnp.dtype = jnp.bfloat16 # the dtype of the computation gradient_checkpointing: bool = False def _get_encoder_module(self): return self.encoder def _get_decoder_module(self): return self.decoder def setup(self): self.model_dim = self.config.d_model self.shared = nn.Embed( self.config.vocab_size, self.config.d_model, embedding_init=jax.nn.initializers.normal(self.config.initializer_factor), dtype=self.dtype, ) encoder_config = copy.deepcopy(self.config) # unfreeze encoder_config = ConfigDict(encoder_config) encoder_config.causal = False encoder_config.use_cache = False encoder_config.is_encoder_decoder = False # freeze encoder_config = FrozenConfigDict(encoder_config) self.encoder = FlaxT5Stack( config=encoder_config, embed_tokens=self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing ) decoder_config = copy.deepcopy(self.config) # unfreeze decoder_config = ConfigDict(encoder_config) decoder_config.causal = True decoder_config.is_encoder_decoder = False decoder_config.num_layers = self.config.num_decoder_layers # freeze encoder_config = FrozenConfigDict(decoder_config) self.decoder = FlaxT5Stack( config=decoder_config, embed_tokens=self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing ) # handles outgoing predictions # think of it like a wo of a dense self.lm_head = nn.Dense( self.config.vocab_size, use_bias=False, kernel_init=nn.with_partitioning( jax.nn.initializers.normal(self.config.initializer_factor), ("model", None) ), dtype=self.dtype, ) def __call__( self, input_ids=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, encoder_outputs=None, output_attentions=None, output_hidden_states=None, return_dict=None, deterministic: bool = True, ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Encode encoder_outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=deterministic, ) hidden_states = encoder_outputs[0] # Decode decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, encoder_hidden_states=hidden_states, encoder_attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=deterministic, ) sequence_output = decoder_outputs[0] if self.config.tie_word_embeddings: # Rescale output before projecting on vocab # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 sequence_output = sequence_output * (self.model_dim**-0.5) if self.config.tie_word_embeddings: # shared_embedding is a partitioned tensor shared_embedding = self.shared.variables["params"]["embedding"] lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output) else: lm_logits = self.lm_head(sequence_output) if not return_dict: return (lm_logits,) + decoder_outputs[1:] + encoder_outputs # return lm_logits return FlaxSeq2SeqLMOutput( logits=lm_logits, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, ) class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel): module_class = FlaxT5ForConditionalGenerationModule def decode( self, decoder_input_ids, encoder_outputs, encoder_attention_mask: Optional[jnp.ndarray] = None, decoder_attention_mask: Optional[jnp.ndarray] = None, past_key_values: dict = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, train: bool = False, params: dict = None, dropout_rng: PRNGKey = None, ): r""" Returns: Example: ```python >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration >>> import jax.numpy as jnp >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") >>> text = "summarize: My friends are cool but they eat too many carbs." >>> inputs = tokenizer(text, return_tensors="np") >>> encoder_outputs = model.encode(**inputs) >>> decoder_start_token_id = model.config.decoder_start_token_id >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id >>> outputs = model.decode(decoder_input_ids, encoder_outputs) >>> logits = outputs.logits ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.return_dict encoder_hidden_states = encoder_outputs[0] if encoder_attention_mask is None: batch_size, sequence_length = encoder_hidden_states.shape[:2] encoder_attention_mask = jnp.ones((batch_size, sequence_length)) batch_size, sequence_length = decoder_input_ids.shape if decoder_attention_mask is None: decoder_attention_mask = jnp.ones((batch_size, sequence_length)) # Handle any PRNG if needed rngs = {} if dropout_rng is not None: rngs["dropout"] = dropout_rng inputs = {"params": params or self.params} # if past_key_values are passed then cache is already initialized a private flag init_cache has to be # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that # it can be changed by FlaxT5Attention module if past_key_values: inputs["cache"] = past_key_values mutable = ["cache"] else: mutable = False def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): decoder_module = module._get_decoder_module() decoder_outputs = decoder_module( decoder_input_ids, decoder_attention_mask, **kwargs, ) sequence_output = decoder_outputs[0] if self.config.tie_word_embeddings: # Rescale output before projecting on vocab # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 sequence_output = sequence_output * (self.config.d_model**-0.5) if self.config.tie_word_embeddings: shared_embedding = module.shared.variables["params"]["embedding"] lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output) else: lm_logits = module.lm_head(sequence_output) return lm_logits, decoder_outputs outputs = self.module.apply( inputs, decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=not train, rngs=rngs, mutable=mutable, method=_decoder_forward, ) if past_key_values is None: lm_logits, decoder_outputs = outputs else: (lm_logits, decoder_outputs), past = outputs if return_dict: outputs = FlaxCausalLMOutputWithCrossAttentions( logits=lm_logits, hidden_states=decoder_outputs.hidden_states, attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, ) else: outputs = (lm_logits,) + decoder_outputs[1:] # add updated cache to model output if past_key_values is not None and return_dict: outputs["past_key_values"] = unfreeze(past["cache"]) return outputs elif past_key_values is not None and not return_dict: outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] return outputs def prepare_inputs_for_generation( self, decoder_input_ids, max_length, attention_mask: Optional[jax.Array] = None, decoder_attention_mask: Optional[jax.Array] = None, encoder_outputs=None, **kwargs, ): # initializing the cache batch_size, seq_length = decoder_input_ids.shape past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. # But since the decoder uses a causal mask, those positions are masked anyways. # Thus we can create a single static attention_mask here, which is more efficient for compilation extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if decoder_attention_mask is not None: extended_attention_mask = jax.lax.dynamic_update_slice( extended_attention_mask, decoder_attention_mask, (0, 0) ) return { "past_key_values": past_key_values, "encoder_outputs": encoder_outputs, "encoder_attention_mask": attention_mask, "decoder_attention_mask": extended_attention_mask, } def update_inputs_for_generation(self, model_outputs, model_kwargs): model_kwargs["past_key_values"] = model_outputs.past_key_values return model_kwargs