# coding=utf-8
# Copyright 2021 T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flax T5 model."""

import copy
from typing import Callable, Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey

from transformers.modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPastAndCrossAttentions,
    FlaxCausalLMOutputWithCrossAttentions,
    FlaxSeq2SeqLMOutput,
    FlaxSeq2SeqModelOutput,
)
from transformers.modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
    append_call_sample_docstring,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from transformers import T5Config

# from dataclasses import dataclass, replace
from ml_collections import ConfigDict, FrozenConfigDict


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "google-t5/t5-small"
_CONFIG_FOR_DOC = "FrozenConfigDict"

remat = nn_partitioning.remat


# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = jnp.zeros_like(input_ids)
    shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
    shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)

    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids
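

# Illustrative sketch (added for documentation, not part of the original file): how
# `shift_tokens_right` turns labels into decoder inputs. The token ids are made up, and the
# pad/start ids (both 0 for the standard T5 checkpoints) are assumptions of this example.
def _example_shift_tokens_right() -> jnp.ndarray:
    labels = jnp.array([[37, 423, -100, -100]])  # -100 marks positions ignored by the loss
    decoder_input_ids = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=0)
    # The start token is prepended, the ids are shifted right, and any shifted -100 entries
    # become the pad token: [[0, 37, 423, 0]].
    return decoder_input_ids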
|
|
|
|
|
|
class FlaxT5LayerNorm(nn.Module):
    """
    Layer norm in the T5 style: no bias and no subtraction of the mean (i.e. an RMSNorm).
    """

    hidden_size: int
    dtype: jnp.dtype = jnp.float32
    eps: float = 1e-6
    weight_init: Callable[..., np.ndarray] = jax.nn.initializers.ones

    def setup(self):
        self.weight = self.param("weight", self.weight_init, (self.hidden_size,))

    def __call__(self, hidden_states):
        # layer norm should always be calculated in float32
        variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True)
        hidden_states = hidden_states / jnp.sqrt(variance + self.eps)

        return self.weight * hidden_states
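

# Illustrative sketch (added for documentation): FlaxT5LayerNorm rescales each feature vector by
# its root-mean-square (the mean square is computed in float32) and applies a learned weight,
# with no mean subtraction and no bias. Shapes below are arbitrary example values.
def _example_t5_layer_norm() -> jnp.ndarray:
    layer_norm = FlaxT5LayerNorm(hidden_size=8)
    x = jnp.ones((1, 3, 8))
    params = layer_norm.init(jax.random.PRNGKey(0), x)
    # With all-ones inputs the mean square is 1, so the output is x / sqrt(1 + eps) elementwise
    # (the weight is initialized to ones).
    return layer_norm.apply(params, x)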
|
|
|
|
|
|
class FlaxT5DenseActDense(nn.Module):
|
|
config: FrozenConfigDict
|
|
dtype: jnp.dtype = jnp.float32
|
|
|
|
def setup(self):
|
|
wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
|
|
wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)
|
|
|
|
self.wi = nn.Dense(
|
|
self.config.d_ff,
|
|
use_bias=False,
|
|
kernel_init=jax.nn.initializers.normal(wi_init_std),
|
|
dtype=self.dtype,
|
|
)
|
|
self.wo = nn.Dense(
|
|
self.config.d_model,
|
|
use_bias=False,
|
|
kernel_init=jax.nn.initializers.normal(wo_init_std),
|
|
dtype=self.dtype,
|
|
)
|
|
self.dropout = nn.Dropout(self.config.dropout_rate)
|
|
self.act = ACT2FN[self.config.dense_act_fn]
|
|
|
|
def __call__(self, hidden_states, deterministic=True):
|
|
hidden_states = self.wi(hidden_states)
|
|
hidden_states = self.act(hidden_states)
|
|
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
|
hidden_states = self.wo(hidden_states)
|
|
return hidden_states
|
|
|
|
|
|
class FlaxT5DenseGatedActDense(nn.Module):
|
|
config: FrozenConfigDict
|
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
|
|
|
def setup(self):
|
|
wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
|
|
wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)
|
|
|
|
self.wi_0 = nn.Dense(
|
|
self.config.d_ff,
|
|
use_bias=False,
|
|
kernel_init=jax.nn.initializers.normal(wi_init_std),
|
|
dtype=self.dtype,
|
|
)
|
|
self.wi_1 = nn.Dense(
|
|
self.config.d_ff,
|
|
use_bias=False,
|
|
kernel_init=jax.nn.initializers.normal(wi_init_std),
|
|
dtype=self.dtype,
|
|
)
|
|
self.wo = nn.Dense(
|
|
self.config.d_model,
|
|
use_bias=False,
|
|
kernel_init=jax.nn.initializers.normal(wo_init_std),
|
|
dtype=self.dtype,
|
|
)
|
|
self.dropout = nn.Dropout(self.config.dropout_rate)
|
|
self.act = ACT2FN[self.config.dense_act_fn]
|
|
|
|
def __call__(self, hidden_states, deterministic):
|
|
hidden_gelu = self.act(self.wi_0(hidden_states))
|
|
hidden_linear = self.wi_1(hidden_states)
|
|
hidden_states = hidden_gelu * hidden_linear
|
|
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
|
hidden_states = self.wo(hidden_states)
|
|
return hidden_states
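

# Illustrative sketch (added for documentation): the gated module above computes
# wo(act(wi_0(x)) * wi_1(x)), the GEGLU-style feed-forward used by the newer (v1.1-style) T5
# checkpoints. The plain-jnp version below uses random stand-in weights and GELU; the real module
# reads the sizes and activation from the config and applies dropout before `wo`.
def _example_gated_feed_forward(x: jnp.ndarray, d_model: int = 512, d_ff: int = 2048) -> jnp.ndarray:
    k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
    wi_0 = jax.random.normal(k0, (d_model, d_ff)) * d_model**-0.5
    wi_1 = jax.random.normal(k1, (d_model, d_ff)) * d_model**-0.5
    wo = jax.random.normal(k2, (d_ff, d_model)) * d_ff**-0.5
    return (jax.nn.gelu(x @ wi_0) * (x @ wi_1)) @ wo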
|
|
|
|
|
|
class FlaxT5LayerFF(nn.Module):
|
|
config: FrozenConfigDict
|
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
|
|
|
def setup(self):
|
|
if self.config.is_gated_act:
|
|
self.DenseReluDense = FlaxT5DenseGatedActDense(self.config, dtype=self.dtype)
|
|
else:
|
|
self.DenseReluDense = FlaxT5DenseActDense(self.config, dtype=self.dtype)
|
|
|
|
self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype)
|
|
self.dropout = nn.Dropout(self.config.dropout_rate)
|
|
|
|
def __call__(self, hidden_states, deterministic=True):
|
|
forwarded_states = self.layer_norm(hidden_states)
|
|
forwarded_states = self.DenseReluDense(forwarded_states, deterministic=deterministic)
|
|
hidden_states = hidden_states + self.dropout(forwarded_states, deterministic=deterministic)
|
|
return hidden_states
|
|
|
|
|
|
class FlaxT5Attention(nn.Module):
|
|
config: FrozenConfigDict
|
|
has_relative_attention_bias: bool = False
|
|
causal: bool = False
|
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
|
|
|
def setup(self):
|
|
self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
|
|
self.relative_attention_max_distance = self.config.relative_attention_max_distance
|
|
self.d_model = self.config.d_model
|
|
self.key_value_proj_dim = self.config.d_kv
|
|
self.n_heads = self.config.num_heads
|
|
self.dropout = self.config.dropout_rate
|
|
self.inner_dim = self.n_heads * self.key_value_proj_dim
|
|
|
|
q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
|
|
kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
|
|
o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
|
|
|
|
self.q = nn.Dense(
|
|
self.inner_dim,
|
|
use_bias=False,
|
|
kernel_init=jax.nn.initializers.normal(q_init_std),
|
|
dtype=self.dtype,
|
|
)
|
|
self.k = nn.Dense(
|
|
self.inner_dim,
|
|
use_bias=False,
|
|
kernel_init=jax.nn.initializers.normal(kv_init_std),
|
|
dtype=self.dtype,
|
|
)
|
|
self.v = nn.Dense(
|
|
self.inner_dim,
|
|
use_bias=False,
|
|
kernel_init=jax.nn.initializers.normal(kv_init_std),
|
|
dtype=self.dtype,
|
|
)
|
|
self.o = nn.Dense(
|
|
self.d_model,
|
|
use_bias=False,
|
|
kernel_init=jax.nn.initializers.normal(o_init_std),
|
|
dtype=self.dtype,
|
|
)
|
|
|
|
if self.has_relative_attention_bias:
|
|
self.relative_attention_bias = nn.Embed(
|
|
self.relative_attention_num_buckets,
|
|
self.n_heads,
|
|
embedding_init=jax.nn.initializers.normal(kv_init_std),
|
|
dtype=self.dtype,
|
|
)
|
|
|
|
@staticmethod
|
|
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
|
|
"""
|
|
Adapted from Mesh Tensorflow:
|
|
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
|
|
|
|
Translate relative position to a bucket number for relative attention. The relative position is defined as
|
|
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
|
|
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
|
|
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
|
|
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
|
|
This should allow for more graceful generalization to longer sequences than the model has been trained on
|
|
"""
|
|
relative_buckets = 0
|
|
if bidirectional:
|
|
num_buckets //= 2
|
|
relative_buckets += (relative_position > 0) * num_buckets
|
|
relative_position = jnp.abs(relative_position)
|
|
else:
|
|
relative_position = -jnp.clip(relative_position, a_max=0)
|
|
# now relative_position is in the range [0, inf)
|
|
|
|
# half of the buckets are for exact increments in positions
|
|
max_exact = num_buckets // 2
|
|
is_small = relative_position < max_exact
|
|
|
|
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
|
relative_position_if_large = max_exact + (
|
|
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
|
|
)
|
|
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
|
|
|
|
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
|
|
|
|
return relative_buckets.astype("i4")
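
    # Illustrative sketch (added for documentation): bucketing a handful of relative positions
    # with the default T5 settings. Small offsets get their own buckets, larger offsets share
    # logarithmically sized buckets, and everything beyond max_distance falls into the last
    # bucket of its half (negative/positive offsets use separate halves when bidirectional).
    @staticmethod
    def _example_relative_position_buckets() -> jnp.ndarray:
        relative_position = jnp.array([[-200, -4, -1, 0, 1, 4, 200]])
        return FlaxT5Attention._relative_position_bucket(
            relative_position, bidirectional=True, num_buckets=32, max_distance=128
        )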
|
|
|
|
def compute_bias(self, query_length, key_length):
|
|
"""Compute binned relative position bias"""
|
|
context_position = jnp.arange(query_length, dtype="i4")[:, None]
|
|
memory_position = jnp.arange(key_length, dtype="i4")[None, :]
|
|
|
|
relative_position = memory_position - context_position
|
|
relative_position_bucket = self._relative_position_bucket(
|
|
relative_position,
|
|
bidirectional=(not self.causal),
|
|
num_buckets=self.relative_attention_num_buckets,
|
|
max_distance=self.relative_attention_max_distance,
|
|
)
|
|
|
|
values = self.relative_attention_bias(relative_position_bucket)
|
|
values = values.transpose((2, 0, 1))[None, :, :, :]
|
|
return values
|
|
|
|
def _split_heads(self, hidden_states):
|
|
return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))
|
|
|
|
def _merge_heads(self, hidden_states):
|
|
return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,))
|
|
|
|
    # `_concatenate_to_cache` threads mutable state: it reads and updates the "cache" variable
    # collection (cached keys/values and the running index), which is why it uses @nn.compact.
|
|
@nn.compact
|
|
def _concatenate_to_cache(self, key, value, query, attention_mask):
|
|
"""
|
|
This function takes projected key, value states from a single input token and concatenates the states to cached
|
|
states from previous steps. This function is slighly adapted from the official Flax repository:
|
|
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
|
|
"""
|
|
# detect if we're initializing by absence of existing cache data.
|
|
is_initialized = self.has_variable("cache", "cached_key")
|
|
# Variables are identified by a collection (e.g., "batch_stats") and a name
|
|
# (e.g., "moving_mean"). The value property gives access to the variable's
|
|
# content and can be assigned to for mutation.
|
|
#
|
|
# self.variable either 1.) initializes values for the first time
|
|
# 2.) retrieves the variable and does not override
|
|
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
|
|
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
|
|
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
|
|
|
|
# only run if initialized before
|
|
if is_initialized:
|
|
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
|
|
# update key, value caches with our new 1d spatial slices
|
|
cur_index = cache_index.value
|
|
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
|
|
key = jax.lax.dynamic_update_slice(cached_key.value, key, indices)
|
|
value = jax.lax.dynamic_update_slice(cached_value.value, value, indices)
|
|
cached_key.value = key
|
|
cached_value.value = value
|
|
num_updated_cache_vectors = query.shape[1]
|
|
cache_index.value = cache_index.value + num_updated_cache_vectors
|
|
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions
|
|
# that have already been generated and cached, not the remaining zero elements.
|
|
pad_mask = jnp.broadcast_to(
|
|
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
|
|
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
|
|
)
|
|
attention_mask = combine_masks(pad_mask, attention_mask)
|
|
return key, value, attention_mask
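
    # Illustrative note (added for documentation): during incremental decoding the cache holds,
    # per attention layer, `cached_key` and `cached_value` of shape
    # (batch, max_decoder_length, n_heads, head_dim) plus a scalar `cache_index`; each call
    # writes the new key/value slice at `cache_index` and advances it by the number of query
    # positions processed in that step.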
|
|
|
|
def _create_position_bias(
|
|
self, key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift
|
|
):
|
|
cache_is_filled = self.causal and self.has_variable("cache", "cached_key") and (not init_cache)
|
|
key_length = key_states.shape[1]
|
|
query_length = key_length if cache_is_filled else query_states.shape[1]
|
|
|
|
if self.has_relative_attention_bias:
|
|
position_bias = self.compute_bias(query_length, key_length)
|
|
elif attention_mask is not None:
|
|
position_bias = jnp.zeros_like(attention_mask)
|
|
else:
|
|
position_bias = jnp.zeros((1, self.n_heads, query_length, key_length), dtype=self.dtype)
|
|
|
|
# if key and values are already calculated, only the last query position bias should be taken
|
|
if cache_is_filled:
|
|
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
|
|
position_bias = jax.lax.dynamic_slice(
|
|
position_bias,
|
|
(0, 0, causal_attention_mask_shift, 0),
|
|
(1, self.n_heads, seq_length, max_decoder_length),
|
|
)
|
|
return position_bias
|
|
|
|
def __call__(
|
|
self,
|
|
hidden_states,
|
|
attention_mask=None,
|
|
key_value_states=None,
|
|
position_bias=None,
|
|
use_cache=False,
|
|
output_attentions=False,
|
|
deterministic=True,
|
|
init_cache=False,
|
|
):
|
|
"""
|
|
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
|
|
"""
|
|
batch_size, seq_length = hidden_states.shape[:2]
|
|
|
|
# q, k, v projections
|
|
        query_states = self.q(hidden_states)  # (batch_size, seq_length, n_heads * dim_per_head)
|
|
key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states)
|
|
value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states)
|
|
|
|
# reshape to (batch_size, seq_length, n_heads, head_dim)
|
|
query_states = self._split_heads(query_states)
|
|
key_states = self._split_heads(key_states)
|
|
value_states = self._split_heads(value_states)
|
|
|
|
# counter-act scaling in dot_product_attention_weights function
|
|
query_states *= jnp.sqrt(query_states.shape[-1])
|
|
|
|
# for fast decoding causal attention mask should be shifted
|
|
causal_attention_mask_shift = (
|
|
self.variables["cache"]["cache_index"] if (self.has_variable("cache", "cached_key") and self.causal) else 0
|
|
)
|
|
# create causal attention_mask; attention_mask has to be defined when model is causal
|
|
if self.causal:
|
|
causal_attention_mask = make_causal_mask(attention_mask, dtype="bool")
|
|
|
|
# fast decoding for generate requires special attention_mask
|
|
if self.has_variable("cache", "cached_key"):
|
|
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
|
|
causal_attention_mask = jax.lax.dynamic_slice(
|
|
causal_attention_mask,
|
|
(0, 0, causal_attention_mask_shift, 0),
|
|
(1, 1, seq_length, max_decoder_length),
|
|
)
|
|
|
|
# broadcast causal attention mask & attention mask to fit for merge
|
|
causal_attention_mask = jnp.broadcast_to(
|
|
causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:]
|
|
)
|
|
attention_mask = jnp.broadcast_to(
|
|
jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape
|
|
)
|
|
attention_mask = combine_masks(attention_mask, causal_attention_mask)
|
|
elif attention_mask is not None:
|
|
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
|
|
|
# During fast autoregressive decoding, we feed one position at a time,
|
|
# and cache the keys and values step by step.
|
|
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
|
|
key_states, value_states, attention_mask = self._concatenate_to_cache(
|
|
key_states, value_states, query_states, attention_mask
|
|
)
|
|
|
|
        # replace masked positions with a large negative bias (the most negative value of the dtype)
|
|
if attention_mask is not None:
|
|
mask_value = jnp.finfo(self.dtype).min
|
|
attention_mask = jax.lax.select(
|
|
attention_mask > 0,
|
|
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
|
jnp.full(attention_mask.shape, mask_value).astype(self.dtype),
|
|
)
|
|
|
|
if position_bias is None:
|
|
# compute position bias (only for first layer)
|
|
position_bias = self._create_position_bias(
|
|
key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift
|
|
)
|
|
|
|
if attention_mask is not None:
|
|
position_bias = position_bias + attention_mask
|
|
|
|
# create dropout rng
|
|
dropout_rng = None
|
|
if not deterministic and self.dropout > 0.0:
|
|
dropout_rng = self.make_rng("dropout")
|
|
|
|
# Softmax(QK^T)
|
|
attn_weights = dot_product_attention_weights(
|
|
query_states,
|
|
key_states,
|
|
bias=position_bias,
|
|
dropout_rng=dropout_rng,
|
|
dropout_rate=self.dropout,
|
|
broadcast_dropout=True,
|
|
deterministic=deterministic,
|
|
dtype=self.dtype,
|
|
)
|
|
|
|
# multiply with value states
|
|
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
|
|
|
# bring back to (batch_size, seq_length, d_model)
|
|
attn_output = self._merge_heads(attn_output)
|
|
|
|
# apply output matrix
|
|
attn_output = self.o(attn_output)
|
|
|
|
outputs = (attn_output, position_bias)
|
|
|
|
if output_attentions:
|
|
outputs = outputs + (attn_weights,)
|
|
|
|
return outputs
|
|
|
|
|
|
class FlaxT5LayerSelfAttention(nn.Module):
|
|
config: FrozenConfigDict
|
|
has_relative_attention_bias: bool = False
|
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
|
|
|
def setup(self):
|
|
self.SelfAttention = FlaxT5Attention(
|
|
self.config,
|
|
has_relative_attention_bias=self.has_relative_attention_bias,
|
|
causal=self.config.causal,
|
|
dtype=self.dtype,
|
|
)
|
|
self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype)
|
|
self.dropout = nn.Dropout(self.config.dropout_rate)
|
|
|
|
def __call__(
|
|
self,
|
|
hidden_states,
|
|
attention_mask=None,
|
|
position_bias=None,
|
|
output_attentions=False,
|
|
deterministic=True,
|
|
init_cache=False,
|
|
):
|
|
normed_hidden_states = self.layer_norm(hidden_states)
|
|
attention_output = self.SelfAttention(
|
|
normed_hidden_states,
|
|
attention_mask=attention_mask,
|
|
position_bias=position_bias,
|
|
output_attentions=output_attentions,
|
|
deterministic=deterministic,
|
|
init_cache=init_cache,
|
|
)
|
|
hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
|
|
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
|
|
return outputs
|
|
|
|
|
|
class FlaxT5LayerCrossAttention(nn.Module):
|
|
config: FrozenConfigDict
|
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
|
|
|
def setup(self):
|
|
self.EncDecAttention = FlaxT5Attention(
|
|
self.config, has_relative_attention_bias=False, causal=False, dtype=self.dtype
|
|
)
|
|
self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype)
|
|
self.dropout = nn.Dropout(self.config.dropout_rate)
|
|
|
|
def __call__(
|
|
self,
|
|
hidden_states,
|
|
key_value_states,
|
|
attention_mask=None,
|
|
position_bias=None,
|
|
output_attentions=False,
|
|
deterministic=True,
|
|
):
|
|
normed_hidden_states = self.layer_norm(hidden_states)
|
|
attention_output = self.EncDecAttention(
|
|
normed_hidden_states,
|
|
attention_mask=attention_mask,
|
|
key_value_states=key_value_states,
|
|
position_bias=position_bias,
|
|
output_attentions=output_attentions,
|
|
)
|
|
hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
|
|
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
|
|
return outputs
|
|
|
|
|
|
class FlaxT5Block(nn.Module):
|
|
config: FrozenConfigDict
|
|
has_relative_attention_bias: bool = False
|
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
|
|
|
def setup(self):
|
|
self.causal = self.config.causal
|
|
self.layer = (
|
|
FlaxT5LayerSelfAttention(
|
|
self.config,
|
|
has_relative_attention_bias=self.has_relative_attention_bias,
|
|
name=str(0),
|
|
dtype=self.dtype,
|
|
),
|
|
)
|
|
feed_forward_index = 1
|
|
if self.causal:
|
|
self.layer += (FlaxT5LayerCrossAttention(self.config, name=str(1), dtype=self.dtype),)
|
|
feed_forward_index += 1
|
|
|
|
self.layer += (FlaxT5LayerFF(self.config, name=str(feed_forward_index), dtype=self.dtype),)
|
|
|
|
def __call__(
|
|
self,
|
|
hidden_states,
|
|
attention_mask=None,
|
|
position_bias=None,
|
|
encoder_hidden_states=None,
|
|
encoder_attention_mask=None,
|
|
encoder_decoder_position_bias=None,
|
|
output_attentions=False,
|
|
return_dict=True,
|
|
deterministic=True,
|
|
init_cache=False,
|
|
):
|
|
self_attention_outputs = self.layer[0](
|
|
hidden_states,
|
|
attention_mask=attention_mask,
|
|
position_bias=position_bias,
|
|
output_attentions=output_attentions,
|
|
deterministic=deterministic,
|
|
init_cache=init_cache,
|
|
)
|
|
hidden_states = self_attention_outputs[0]
|
|
attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights
|
|
|
|
do_cross_attention = self.causal and encoder_hidden_states is not None
|
|
if do_cross_attention:
|
|
cross_attention_outputs = self.layer[1](
|
|
hidden_states,
|
|
key_value_states=encoder_hidden_states,
|
|
attention_mask=encoder_attention_mask,
|
|
position_bias=encoder_decoder_position_bias,
|
|
output_attentions=output_attentions,
|
|
deterministic=deterministic,
|
|
)
|
|
hidden_states = cross_attention_outputs[0]
|
|
|
|
# Keep cross-attention outputs and relative position weights
|
|
attention_outputs = attention_outputs + cross_attention_outputs[1:]
|
|
|
|
# Apply Feed Forward layer
|
|
hidden_states = self.layer[-1](hidden_states, deterministic=deterministic)
|
|
|
|
outputs = (hidden_states,)
|
|
|
|
outputs = outputs + attention_outputs
|
|
|
|
# returns hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights),
|
|
# (cross-attention position bias), (cross-attention weights)
|
|
return outputs
|
|
|
|
|
|
class FlaxT5LayerCollection(nn.Module):
|
|
config: FrozenConfigDict
|
|
has_relative_attention_bias: bool
|
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
|
|
|
def setup(self):
|
|
self.layer = FlaxT5Block(
|
|
self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype
|
|
)
|
|
|
|
def __call__(
|
|
self,
|
|
hidden_states,
|
|
attention_mask=None,
|
|
position_bias=None,
|
|
encoder_hidden_states=None,
|
|
encoder_attention_mask=None,
|
|
encoder_decoder_position_bias=None,
|
|
output_attentions=False,
|
|
deterministic=True,
|
|
init_cache=False,
|
|
):
|
|
return self.layer(
|
|
hidden_states,
|
|
attention_mask=attention_mask,
|
|
position_bias=position_bias,
|
|
encoder_hidden_states=encoder_hidden_states,
|
|
encoder_attention_mask=encoder_attention_mask,
|
|
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
|
output_attentions=output_attentions,
|
|
deterministic=deterministic,
|
|
init_cache=init_cache,
|
|
)
|
|
|
|
|
|
class FlaxT5BlockCollection(nn.Module):
|
|
config: FrozenConfigDict
|
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
|
gradient_checkpointing: bool = False
|
|
|
|
def setup(self):
|
|
self.causal = self.config.causal
|
|
if self.gradient_checkpointing:
|
|
FlaxT5CheckpointLayer = remat(FlaxT5LayerCollection, static_argnums=(6, 7, 8))
|
|
self.blocks = [
|
|
FlaxT5CheckpointLayer(
|
|
self.config,
|
|
has_relative_attention_bias=(i == 0),
|
|
dtype=self.dtype,
|
|
name=str(i),
|
|
)
|
|
for i in range(self.config.num_layers)
|
|
]
|
|
else:
|
|
self.blocks = [
|
|
FlaxT5LayerCollection(
|
|
self.config,
|
|
has_relative_attention_bias=(i == 0),
|
|
dtype=self.dtype,
|
|
name=str(i),
|
|
)
|
|
for i in range(self.config.num_layers)
|
|
]
|
|
|
|
def __call__(
|
|
self,
|
|
hidden_states=None,
|
|
attention_mask=None,
|
|
encoder_hidden_states=None,
|
|
encoder_attention_mask=None,
|
|
output_attentions: bool = False,
|
|
output_hidden_states: bool = False,
|
|
deterministic: bool = True,
|
|
init_cache: bool = False,
|
|
):
|
|
# Prepare head mask if needed
|
|
all_hidden_states = () if output_hidden_states else None
|
|
all_attentions = () if output_attentions else None
|
|
all_cross_attentions = () if (output_attentions and self.causal) else None
|
|
position_bias = None
|
|
encoder_decoder_position_bias = None
|
|
|
|
for _, layer_module in enumerate(self.blocks):
|
|
if output_hidden_states:
|
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
|
|
layer_outputs = layer_module(
|
|
hidden_states,
|
|
attention_mask,
|
|
position_bias,
|
|
encoder_hidden_states,
|
|
encoder_attention_mask,
|
|
encoder_decoder_position_bias,
|
|
output_attentions,
|
|
deterministic,
|
|
init_cache,
|
|
)
|
|
|
|
hidden_states = layer_outputs[0]
|
|
|
|
            # We share the position biases between the layers - the first layer stores them
|
|
# layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
|
|
# (cross-attention position bias), (cross-attention weights)
|
|
position_bias = layer_outputs[1]
|
|
|
|
if self.causal and encoder_hidden_states is not None:
|
|
encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]
|
|
|
|
if output_attentions:
|
|
all_attentions = all_attentions + (layer_outputs[2],)
|
|
if self.causal:
|
|
all_cross_attentions = all_cross_attentions + (layer_outputs[4],)
|
|
|
|
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
|
last_hidden_state=hidden_states,
|
|
hidden_states=all_hidden_states,
|
|
attentions=all_attentions,
|
|
cross_attentions=all_cross_attentions,
|
|
)
|
|
|
|
|
|
class FlaxT5Stack(nn.Module):
|
|
config: FrozenConfigDict
|
|
embed_tokens: nn.Embed
|
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
|
gradient_checkpointing: bool = False
|
|
|
|
def setup(self):
|
|
self.causal = self.config.causal
|
|
|
|
self.block = FlaxT5BlockCollection(
|
|
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
|
|
)
|
|
self.final_layer_norm = FlaxT5LayerNorm(
|
|
self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
|
|
)
|
|
self.dropout = nn.Dropout(self.config.dropout_rate)
|
|
|
|
def __call__(
|
|
self,
|
|
input_ids=None,
|
|
attention_mask=None,
|
|
encoder_hidden_states=None,
|
|
encoder_attention_mask=None,
|
|
output_attentions: bool = False,
|
|
output_hidden_states: bool = False,
|
|
return_dict: bool = True,
|
|
deterministic: bool = True,
|
|
init_cache: bool = False,
|
|
):
|
|
hidden_states = self.embed_tokens(input_ids)
|
|
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
|
|
|
outputs = self.block(
|
|
hidden_states,
|
|
attention_mask=attention_mask,
|
|
encoder_hidden_states=encoder_hidden_states,
|
|
encoder_attention_mask=encoder_attention_mask,
|
|
output_attentions=output_attentions,
|
|
output_hidden_states=output_hidden_states,
|
|
deterministic=deterministic,
|
|
init_cache=init_cache,
|
|
)
|
|
|
|
hidden_states = outputs[0]
|
|
|
|
hidden_states = self.final_layer_norm(hidden_states)
|
|
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
|
|
|
# Add last layer
|
|
all_hidden_states = None
|
|
|
|
if output_hidden_states:
|
|
all_hidden_states = outputs.hidden_states
|
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
|
|
if not return_dict:
|
|
if output_hidden_states:
|
|
return (
|
|
hidden_states,
|
|
all_hidden_states,
|
|
) + outputs[2:]
|
|
return (hidden_states,) + outputs[1:]
|
|
|
|
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
|
last_hidden_state=hidden_states,
|
|
hidden_states=all_hidden_states,
|
|
attentions=outputs.attentions,
|
|
cross_attentions=outputs.cross_attentions,
|
|
)
|
|
|
|
|
|
T5_ENCODE_INPUTS_DOCSTRING = r"""
|
|
Args:
|
|
input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
|
|
Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
|
|
should be able to pad the inputs on both the right and the left.
|
|
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
|
[`PreTrainedTokenizer.__call__`] for detail.
|
|
|
|
            To know more on how to prepare `input_ids` for pretraining take a look at [T5 Training](./t5#training).
|
|
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
|
|
|
- 1 for tokens that are **not masked**,
|
|
- 0 for tokens that are **masked**.
|
|
|
|
[What are attention masks?](../glossary#attention-mask)
|
|
output_attentions (`bool`, *optional*):
|
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
|
tensors for more detail.
|
|
output_hidden_states (`bool`, *optional*):
|
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
|
more detail.
|
|
return_dict (`bool`, *optional*):
|
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
|
"""
|
|
|
|
T5_DECODE_INPUTS_DOCSTRING = r"""
|
|
Args:
|
|
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`):
|
|
Indices of decoder input sequence tokens in the vocabulary.
|
|
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
|
[`PreTrainedTokenizer.__call__`] for details.
|
|
|
|
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
|
|
|
For training, `decoder_input_ids` should be provided.
|
|
encoder_outputs (`tuple(tuple(jnp.ndarray)`):
|
|
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
|
|
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
|
|
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
|
encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
|
|
|
- 1 for tokens that are **not masked**,
|
|
- 0 for tokens that are **masked**.
|
|
|
|
[What are attention masks?](../glossary#attention-mask)
|
|
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
|
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
|
be used by default.
|
|
|
|
If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the
|
|
paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
|
|
past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
|
|
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
|
|
auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
|
|
output_attentions (`bool`, *optional*):
|
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
|
tensors for more detail.
|
|
output_hidden_states (`bool`, *optional*):
|
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
|
more detail.
|
|
return_dict (`bool`, *optional*):
|
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
|
"""
|
|
|
|
|
|
T5_INPUTS_DOCSTRING = r"""
|
|
Args:
|
|
input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
|
|
Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
|
|
should be able to pad the inputs on both the right and the left.
|
|
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
|
[`PreTrainedTokenizer.__call__`] for detail.
|
|
|
|
[What are input IDs?](../glossary#input-ids)
|
|
|
|
            To know more on how to prepare `input_ids` for pretraining take a look at [T5 Training](./t5#training).
|
|
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
|
|
|
- 1 for tokens that are **not masked**,
|
|
- 0 for tokens that are **masked**.
|
|
|
|
[What are attention masks?](../glossary#attention-mask)
|
|
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
|
Indices of decoder input sequence tokens in the vocabulary.
|
|
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
|
[`PreTrainedTokenizer.__call__`] for details.
|
|
|
|
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
|
|
|
T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
|
|
is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
|
|
|
|
To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
|
|
Training](./t5#training).
|
|
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
|
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
|
be used by default.
|
|
encoder_outputs (`tuple(tuple(jnp.ndarray)`, *optional*):
|
|
Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
|
|
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
|
|
the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
|
past_key_values (`tuple(tuple(jnp.ndarray))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
|
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
|
|
|
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
|
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
|
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
|
|
|
|
|
output_attentions (`bool`, *optional*):
|
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
|
tensors for more detail.
|
|
output_hidden_states (`bool`, *optional*):
|
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
|
more detail.
|
|
return_dict (`bool`, *optional*):
|
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
|
"""
|
|
|
|
|
|
class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
|
|
"""
|
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
|
models.
|
|
"""
|
|
|
|
config_class = T5Config
|
|
base_model_prefix = "transformer"
|
|
module_class: nn.Module = None # to be overriden by subclass
|
|
|
|
def __init__(
|
|
self,
|
|
config: T5Config,
|
|
input_shape: Tuple[int] = (1, 1),
|
|
seed: int = 0,
|
|
dtype: jnp.dtype = jnp.bfloat16,
|
|
_do_init: bool = True,
|
|
gradient_checkpointing: bool = False,
|
|
**kwargs,
|
|
):
|
|
        # Convert the mutable T5Config object into an immutable FrozenConfigDict for the Flax modules.
        # `architectures` and `id2label` are dropped first: these two entries proved troublesome when
        # building a ConfigDict out of the config (found via manual testing).
        config_dict = copy.deepcopy(config.to_dict())
        config_dict.pop("architectures")
        config_dict.pop("id2label")
        config_dict = FrozenConfigDict(config_dict)
        # all inner modules take a FrozenConfigDict ...
        module = self.module_class(
            config=config_dict, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs
        )
        # ... but FlaxPreTrainedModel is still constructed with the original T5Config object
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
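
    # Illustrative note (added for documentation, not part of the original code): the inner Flax
    # module sees the FrozenConfigDict, while this wrapper is still driven by the T5Config, e.g.
    #
    #     config = T5Config.from_pretrained("google-t5/t5-small")
    #     model = FlaxT5Model(config)
    #     type(model.config)         # transformers T5Config (used by FlaxPreTrainedModel)
    #     type(model.module.config)  # ml_collections FrozenConfigDict (used by the nn.Modules)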
|
|
|
|
def enable_gradient_checkpointing(self):
|
|
self._module = self.module_class(
|
|
config=self.config,
|
|
dtype=self.dtype,
|
|
gradient_checkpointing=True,
|
|
)
|
|
|
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
|
# init input tensors
|
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
|
|
|
attention_mask = jnp.ones_like(input_ids)
|
|
args = [input_ids, attention_mask]
|
|
if self.module_class not in [FlaxT5EncoderModule]:
|
|
decoder_input_ids = jnp.ones_like(input_ids)
|
|
decoder_attention_mask = jnp.ones_like(input_ids)
|
|
args.extend([decoder_input_ids, decoder_attention_mask])
|
|
|
|
params_rng, dropout_rng = jax.random.split(rng)
|
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
|
|
|
random_params = self.module.init(
|
|
rngs,
|
|
*args,
|
|
)["params"]
|
|
|
|
if params is not None:
|
|
random_params = flatten_dict(unfreeze(random_params))
|
|
params = flatten_dict(unfreeze(params))
|
|
for missing_key in self._missing_keys:
|
|
params[missing_key] = random_params[missing_key]
|
|
self._missing_keys = set()
|
|
return freeze(unflatten_dict(params))
|
|
else:
|
|
return random_params
|
|
|
|
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
|
|
def __call__(
|
|
self,
|
|
input_ids: jnp.ndarray,
|
|
attention_mask: Optional[jnp.ndarray] = None,
|
|
decoder_input_ids: jnp.ndarray = None,
|
|
decoder_attention_mask: Optional[jnp.ndarray] = None,
|
|
output_attentions: Optional[bool] = None,
|
|
output_hidden_states: Optional[bool] = None,
|
|
return_dict: Optional[bool] = None,
|
|
train: bool = False,
|
|
params: dict = None,
|
|
dropout_rng: PRNGKey = None,
|
|
):
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
output_hidden_states = (
|
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
)
|
|
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
|
|
|
if decoder_input_ids is None:
|
|
raise ValueError(
|
|
"Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed"
|
|
" here."
|
|
)
|
|
|
|
# prepare encoder inputs
|
|
if attention_mask is None:
|
|
attention_mask = jnp.ones_like(input_ids)
|
|
|
|
# prepare decoder inputs
|
|
if decoder_attention_mask is None:
|
|
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
|
|
|
|
# Handle any PRNG if needed
|
|
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
|
|
|
|
return self.module.apply(
|
|
{"params": params or self.params},
|
|
input_ids=jnp.array(input_ids, dtype="i4"),
|
|
attention_mask=jnp.array(attention_mask, dtype="i4"),
|
|
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
|
|
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
|
|
output_attentions=output_attentions,
|
|
output_hidden_states=output_hidden_states,
|
|
return_dict=return_dict,
|
|
deterministic=not train,
|
|
rngs=rngs,
|
|
)
|
|
|
|
def init_cache(self, batch_size, max_length, encoder_outputs):
|
|
r"""
|
|
Args:
|
|
batch_size (`int`):
|
|
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
|
|
max_length (`int`):
|
|
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
|
|
cache.
|
|
encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
|
|
`encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
|
|
`attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
|
|
is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
|
|
cross-attention of the decoder.
|
|
"""
|
|
# init input variables to retrieve cache
|
|
decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
|
|
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
|
|
|
|
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
|
|
decoder_module = module._get_decoder_module()
|
|
return decoder_module(
|
|
decoder_input_ids,
|
|
decoder_attention_mask,
|
|
**kwargs,
|
|
)
|
|
|
|
init_variables = self.module.init(
|
|
jax.random.PRNGKey(0),
|
|
decoder_input_ids=decoder_input_ids,
|
|
decoder_attention_mask=decoder_attention_mask,
|
|
encoder_hidden_states=encoder_outputs[0],
|
|
init_cache=True,
|
|
method=_decoder_forward, # we only need to call the decoder to init the cache
|
|
)
|
|
return unfreeze(init_variables["cache"])
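
    # Illustrative sketch (added for documentation): typical fast autoregressive decoding setup,
    # mirroring the `decode` example further below; `inputs`, `max_length` and `decoder_input_ids`
    # are assumed to come from the surrounding usage.
    #
    #     encoder_outputs = model.encode(**inputs)
    #     past_key_values = model.init_cache(inputs.input_ids.shape[0], max_length, encoder_outputs)
    #     outputs = model.decode(
    #         decoder_input_ids, encoder_outputs, past_key_values=past_key_values
    #     )
    #     next_past_key_values = outputs.past_key_values  # feed back in on the next step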
|
|
|
|
@add_start_docstrings(T5_ENCODE_INPUTS_DOCSTRING)
|
|
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=FrozenConfigDict)
|
|
def encode(
|
|
self,
|
|
input_ids: jnp.ndarray,
|
|
attention_mask: Optional[jnp.ndarray] = None,
|
|
output_attentions: Optional[bool] = None,
|
|
output_hidden_states: Optional[bool] = None,
|
|
return_dict: Optional[bool] = None,
|
|
train: bool = False,
|
|
params: dict = None,
|
|
dropout_rng: PRNGKey = None,
|
|
):
|
|
r"""
|
|
Returns:
|
|
|
|
Example:
|
|
|
|
```python
|
|
>>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration
|
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
|
|
>>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
|
|
|
|
>>> text = "My friends are cool but they eat too many carbs."
|
|
>>> inputs = tokenizer(text, return_tensors="np")
|
|
>>> encoder_outputs = model.encode(**inputs)
|
|
```"""
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
output_hidden_states = (
|
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
)
|
|
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
|
|
|
if attention_mask is None:
|
|
attention_mask = jnp.ones_like(input_ids)
|
|
|
|
# Handle any PRNG if needed
|
|
rngs = {}
|
|
if dropout_rng is not None:
|
|
rngs["dropout"] = dropout_rng
|
|
|
|
def _encoder_forward(module, input_ids, attention_mask, **kwargs):
|
|
encode_module = module._get_encoder_module()
|
|
return encode_module(input_ids, attention_mask, **kwargs)
|
|
|
|
return self.module.apply(
|
|
{"params": params or self.params},
|
|
input_ids=jnp.array(input_ids, dtype="i4"),
|
|
attention_mask=jnp.array(attention_mask, dtype="i4"),
|
|
output_attentions=output_attentions,
|
|
output_hidden_states=output_hidden_states,
|
|
return_dict=return_dict,
|
|
deterministic=not train,
|
|
rngs=rngs,
|
|
method=_encoder_forward,
|
|
)
|
|
|
|
@add_start_docstrings(T5_DECODE_INPUTS_DOCSTRING)
|
|
@replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=FrozenConfigDict)
|
|
def decode(
|
|
self,
|
|
decoder_input_ids,
|
|
encoder_outputs,
|
|
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
|
decoder_attention_mask: Optional[jnp.ndarray] = None,
|
|
past_key_values: dict = None,
|
|
output_attentions: Optional[bool] = None,
|
|
output_hidden_states: Optional[bool] = None,
|
|
return_dict: Optional[bool] = None,
|
|
train: bool = False,
|
|
params: dict = None,
|
|
dropout_rng: PRNGKey = None,
|
|
):
|
|
r"""
|
|
Returns:
|
|
|
|
Example:
|
|
|
|
```python
|
|
>>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration
|
|
>>> import jax.numpy as jnp
|
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
|
|
>>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
|
|
|
|
>>> text = "My friends are cool but they eat too many carbs."
|
|
>>> inputs = tokenizer(text, return_tensors="np")
|
|
>>> encoder_outputs = model.encode(**inputs)
|
|
|
|
>>> decoder_start_token_id = model.config.decoder_start_token_id
|
|
>>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
|
|
|
|
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
|
|
>>> logits = outputs.logits
|
|
```"""
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
output_hidden_states = (
|
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
)
|
|
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
|
|
|
encoder_hidden_states = encoder_outputs[0]
|
|
if encoder_attention_mask is None:
|
|
batch_size, sequence_length = encoder_hidden_states.shape[:2]
|
|
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
|
|
|
|
batch_size, sequence_length = decoder_input_ids.shape
|
|
if decoder_attention_mask is None:
|
|
decoder_attention_mask = jnp.ones((batch_size, sequence_length))
|
|
|
|
# Handle any PRNG if needed
|
|
rngs = {}
|
|
if dropout_rng is not None:
|
|
rngs["dropout"] = dropout_rng
|
|
|
|
inputs = {"params": params or self.params}
|
|
|
|
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be
|
|
# passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
|
|
# it can be changed by FlaxT5Attention module
|
|
if past_key_values:
|
|
inputs["cache"] = past_key_values
|
|
mutable = ["cache"]
|
|
else:
|
|
mutable = False
|
|
|
|
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
|
|
decoder_module = module._get_decoder_module()
|
|
return decoder_module(
|
|
decoder_input_ids,
|
|
decoder_attention_mask,
|
|
**kwargs,
|
|
)
|
|
|
|
outputs = self.module.apply(
|
|
inputs,
|
|
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
|
|
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
|
|
encoder_hidden_states=encoder_hidden_states,
|
|
encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
|
|
output_attentions=output_attentions,
|
|
output_hidden_states=output_hidden_states,
|
|
return_dict=return_dict,
|
|
deterministic=not train,
|
|
rngs=rngs,
|
|
mutable=mutable,
|
|
method=_decoder_forward,
|
|
)
|
|
|
|
# add updated cache to model output
|
|
if past_key_values is not None and return_dict:
|
|
outputs, past = outputs
|
|
outputs["past_key_values"] = unfreeze(past["cache"])
|
|
return outputs
|
|
elif past_key_values is not None and not return_dict:
|
|
outputs, past = outputs
|
|
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
|
|
|
|
return outputs
|
|
|
|
|
|
T5_START_DOCSTRING = r"""
|
|
The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
|
|
Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
|
|
Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a
|
|
text-to-text denoising generative setting.
|
|
|
|
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
|
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
|
etc.)
|
|
|
|
This model is also a Flax Linen
|
|
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
|
|
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
|
|
|
|
Finally, this model supports inherent JAX features such as:
|
|
|
|
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
|
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
|
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
|
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
|
|
|
Parameters:
|
|
config ([`FrozenConfigDict`]): Model configuration class with all the parameters of the model.
|
|
Initializing with a config file does not load the weights associated with the model, only the
|
|
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
|
|
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
|
|
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
|
|
`jax.numpy.bfloat16` (on TPUs).
|
|
|
|
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
|
specified all the computation will be performed with the given `dtype`.
|
|
|
|
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
|
parameters.**
|
|
|
|
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
|
|
[`~FlaxPreTrainedModel.to_bf16`].
|
|
"""
|
|
|
|
|
|
@add_start_docstrings(
|
|
"The bare T5 Model transformer outputting raw hidden-stateswithout any specific head on top.",
|
|
T5_START_DOCSTRING,
|
|
)
|
|
class FlaxT5Module(nn.Module):
|
|
config: FrozenConfigDict
|
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
|
gradient_checkpointing: bool = False
|
|
|
|
def _get_encoder_module(self):
|
|
return self.encoder
|
|
|
|
def _get_decoder_module(self):
|
|
return self.decoder
|
|
|
|
def setup(self):
|
|
self.shared = nn.Embed(
|
|
self.config.vocab_size,
|
|
self.config.d_model,
|
|
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
|
|
dtype=self.dtype,
|
|
)
|
|
|
|
        encoder_config = copy.deepcopy(self.config)
        # unfreeze
        encoder_config = ConfigDict(encoder_config)
        encoder_config.causal = False
        # freeze
        encoder_config = FrozenConfigDict(encoder_config)
        self.encoder = FlaxT5Stack(
            encoder_config,
            embed_tokens=self.shared,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )

        decoder_config = copy.deepcopy(self.config)
        # unfreeze
        decoder_config = ConfigDict(decoder_config)
        decoder_config.causal = True
        decoder_config.num_layers = self.config.num_decoder_layers
        # freeze
        decoder_config = FrozenConfigDict(decoder_config)
        self.decoder = FlaxT5Stack(
            decoder_config,
            embed_tokens=self.shared,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
|
|
|
|
def __call__(
|
|
self,
|
|
input_ids=None,
|
|
attention_mask=None,
|
|
decoder_input_ids=None,
|
|
decoder_attention_mask=None,
|
|
encoder_outputs=None,
|
|
output_attentions=None,
|
|
output_hidden_states=None,
|
|
return_dict=None,
|
|
deterministic: bool = True,
|
|
):
|
|
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
|
|
|
# Encode if needed (training, first prediction pass)
|
|
encoder_outputs = self.encoder(
|
|
input_ids=input_ids,
|
|
attention_mask=attention_mask,
|
|
output_attentions=output_attentions,
|
|
output_hidden_states=output_hidden_states,
|
|
return_dict=return_dict,
|
|
deterministic=deterministic,
|
|
)
|
|
|
|
# Decode
|
|
decoder_outputs = self.decoder(
|
|
input_ids=decoder_input_ids,
|
|
attention_mask=decoder_attention_mask,
|
|
encoder_hidden_states=encoder_outputs[0],
|
|
encoder_attention_mask=attention_mask,
|
|
output_attentions=output_attentions,
|
|
output_hidden_states=output_hidden_states,
|
|
return_dict=return_dict,
|
|
deterministic=deterministic,
|
|
)
|
|
|
|
if not return_dict:
|
|
return decoder_outputs + encoder_outputs
|
|
|
|
return FlaxSeq2SeqModelOutput(
|
|
last_hidden_state=decoder_outputs.last_hidden_state,
|
|
past_key_values=decoder_outputs.past_key_values,
|
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
|
decoder_attentions=decoder_outputs.attentions,
|
|
cross_attentions=decoder_outputs.cross_attentions,
|
|
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
|
encoder_hidden_states=encoder_outputs.hidden_states,
|
|
encoder_attentions=encoder_outputs.attentions,
|
|
)
|
|
|
|
|
|
class FlaxT5Model(FlaxT5PreTrainedModel):
|
|
module_class = FlaxT5Module
|
|
|
|
|
|
append_call_sample_docstring(FlaxT5Model, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
|
|
|
|
FLAX_T5_MODEL_DOCSTRING = """
|
|
Returns:
|
|
|
|
Example:
|
|
|
|
```python
|
|
>>> from transformers import AutoTokenizer, FlaxT5Model
|
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
|
|
>>> model = FlaxT5Model.from_pretrained("google-t5/t5-small")
|
|
|
|
>>> input_ids = tokenizer(
|
|
... "Studies have been shown that owning a dog is good for you", return_tensors="np"
|
|
... ).input_ids
|
|
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids
|
|
|
|
>>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
|
|
>>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.
|
|
>>> decoder_input_ids = model._shift_right(decoder_input_ids)
|
|
|
|
>>> # forward pass
|
|
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
|
>>> last_hidden_states = outputs.last_hidden_state
|
|
```
|
|
"""
|
|
|
|
|
|
overwrite_call_docstring(FlaxT5Model, T5_INPUTS_DOCSTRING + FLAX_T5_MODEL_DOCSTRING)
|
|
append_replace_return_docstrings(FlaxT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
|
|
|
|
|
@add_start_docstrings(
|
|
"The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
|
|
T5_START_DOCSTRING,
|
|
)
|
|
class FlaxT5EncoderModule(nn.Module):
|
|
config: FrozenConfigDict
|
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
|
gradient_checkpointing: bool = False
|
|
|
|
def setup(self):
|
|
self.shared = nn.Embed(
|
|
self.config.vocab_size,
|
|
self.config.d_model,
|
|
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
|
|
dtype=self.dtype,
|
|
)
|
|
|
|
        encoder_config = copy.deepcopy(self.config)
        # unfreeze
        encoder_config = ConfigDict(encoder_config)
        encoder_config.is_decoder = False
        encoder_config.is_encoder_decoder = False
        encoder_config.causal = False
        # freeze
        encoder_config = FrozenConfigDict(encoder_config)

        self.encoder = FlaxT5Stack(
            encoder_config,
            embed_tokens=self.shared,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )

    def __call__(
        self,
        input_ids=None,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        # Encode if needed (training, first prediction pass)
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        return encoder_outputs


class FlaxT5EncoderModel(FlaxT5PreTrainedModel):
    module_class = FlaxT5EncoderModule

    @add_start_docstrings_to_model_forward(T5_ENCODE_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # prepare encoder inputs
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )

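# A minimal usage sketch for the encoder-only model (illustrative; it assumes the same
# "google-t5/t5-small" checkpoint used by the examples in this file):
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
#     model = FlaxT5EncoderModel.from_pretrained("google-t5/t5-small")
#     inputs = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="np")
#     outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
#     last_hidden_state = outputs.last_hidden_state
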
@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
|
|
class FlaxT5ForConditionalGenerationModule(nn.Module):
|
|
config: FrozenConfigDict
|
|
dtype: jnp.dtype = jnp.bfloat16 # the dtype of the computation
|
|
gradient_checkpointing: bool = False
|
|
|
|
def _get_encoder_module(self):
|
|
return self.encoder
|
|
|
|
def _get_decoder_module(self):
|
|
return self.decoder
|
|
|
|
def setup(self):
|
|
self.model_dim = self.config.d_model
|
|
|
|
self.shared = nn.Embed(
|
|
self.config.vocab_size,
|
|
self.config.d_model,
|
|
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor),
|
|
dtype=self.dtype,
|
|
)
|
|
|
|
encoder_config = copy.deepcopy(self.config)
|
|
# unfreeze
|
|
encoder_config = ConfigDict(encoder_config)
|
|
encoder_config.causal = False
|
|
encoder_config.use_cache = False
|
|
encoder_config.is_encoder_decoder = False
|
|
# freeze
|
|
encoder_config = FrozenConfigDict(encoder_config)
|
|
|
|
self.encoder = FlaxT5Stack(
|
|
encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
|
|
)
|
|
|
|
decoder_config = copy.deepcopy(self.config)
|
|
# unfreeze
|
|
decoder_config = ConfigDict(decoder_config)
|
|
decoder_config.causal = True
|
|
decoder_config.is_encoder_decoder = False
|
|
decoder_config.num_layers = self.config.num_decoder_layers
|
|
# freeze
|
|
decoder_config = FrozenConfigDict(decoder_config)
|
|
|
|
self.decoder = FlaxT5Stack(
|
|
decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
|
|
)
|
|
|
|
self.lm_head = nn.Dense(
|
|
self.config.vocab_size,
|
|
use_bias=False,
|
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_factor),
|
|
dtype=self.dtype,
|
|
)
|
|
|
|
def __call__(
|
|
self,
|
|
input_ids=None,
|
|
attention_mask=None,
|
|
decoder_input_ids=None,
|
|
decoder_attention_mask=None,
|
|
encoder_outputs=None,
|
|
output_attentions=None,
|
|
output_hidden_states=None,
|
|
return_dict=None,
|
|
deterministic: bool = True,
|
|
):
|
|
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
|
|
|
# Encode
|
|
encoder_outputs = self.encoder(
|
|
input_ids=input_ids,
|
|
attention_mask=attention_mask,
|
|
output_attentions=output_attentions,
|
|
output_hidden_states=output_hidden_states,
|
|
return_dict=return_dict,
|
|
deterministic=deterministic,
|
|
)
|
|
|
|
hidden_states = encoder_outputs[0]
|
|
|
|
# Decode
|
|
decoder_outputs = self.decoder(
|
|
input_ids=decoder_input_ids,
|
|
attention_mask=decoder_attention_mask,
|
|
encoder_hidden_states=hidden_states,
|
|
encoder_attention_mask=attention_mask,
|
|
output_attentions=output_attentions,
|
|
output_hidden_states=output_hidden_states,
|
|
return_dict=return_dict,
|
|
deterministic=deterministic,
|
|
)
|
|
|
|
sequence_output = decoder_outputs[0]
|
|
|
|
if self.config.tie_word_embeddings:
|
|
# Rescale output before projecting on vocab
|
|
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
|
|
sequence_output = sequence_output * (self.model_dim**-0.5)
|
|
|
|
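        # When embeddings are tied, the transposed shared embedding matrix is reused as the
        # kernel of `lm_head`, so the input embedding and the output projection share weights.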
        if self.config.tie_word_embeddings:
            shared_embedding = self.shared.variables["params"]["embedding"]
            lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
        else:
            lm_logits = self.lm_head(sequence_output)

        if not return_dict:
            return (lm_logits,) + decoder_outputs[1:] + encoder_outputs

        return FlaxSeq2SeqLMOutput(
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel):
    module_class = FlaxT5ForConditionalGenerationModule

    @add_start_docstrings(T5_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration
        >>> import jax.numpy as jnp

        >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
        >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

        >>> text = "summarize: My friends are cool but they eat too many carbs."
        >>> inputs = tokenizer(text, return_tensors="np")
        >>> encoder_outputs = model.encode(**inputs)

        >>> decoder_start_token_id = model.config.decoder_start_token_id
        >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

        >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
        >>> logits = outputs.logits
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        encoder_hidden_states = encoder_outputs[0]
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed, the cache is already initialized and a private flag, init_cache,
        # has to be passed down to ensure the cache is used. The cache must also be marked as mutable
        # so that it can be updated by the FlaxT5Attention module.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

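        # `_decoder_forward` is applied via `method=` below so that cached autoregressive steps
        # run only the decoder stack and the LM head, skipping the encoder entirely.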
        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
            decoder_module = module._get_decoder_module()
            decoder_outputs = decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                **kwargs,
            )

            sequence_output = decoder_outputs[0]

            if self.config.tie_word_embeddings:
                # Rescale output before projecting on vocab
                # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
                sequence_output = sequence_output * (self.config.d_model**-0.5)

            if self.config.tie_word_embeddings:
                shared_embedding = module.shared.variables["params"]["embedding"]
                lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
            else:
                lm_logits = module.lm_head(sequence_output)

            return lm_logits, decoder_outputs

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        if past_key_values is None:
            lm_logits, decoder_outputs = outputs
        else:
            (lm_logits, decoder_outputs), past = outputs

        if return_dict:
            outputs = FlaxCausalLMOutputWithCrossAttentions(
                logits=lm_logits,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
            )
        else:
            outputs = (lm_logits,) + decoder_outputs[1:]

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jax.Array] = None,
        decoder_attention_mask: Optional[jax.Array] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        # initializing the cache
        batch_size, seq_length = decoder_input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since the decoder uses a causal mask, those positions are masked anyway.
        # Thus we can create a single static attention_mask here, which is more efficient for compilation.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            extended_attention_mask = jax.lax.dynamic_update_slice(
                extended_attention_mask, decoder_attention_mask, (0, 0)
            )

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        return model_kwargs

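# A minimal cached greedy-decoding sketch for `decode` (illustrative only; the variable names
# are hypothetical and `model.generate`, shown in the docstring below, is the supported path):
#
#     encoder_outputs = model.encode(**inputs)
#     batch_size = inputs["input_ids"].shape[0]
#     max_length = 32
#     past = model.init_cache(batch_size, max_length, encoder_outputs)
#     token = jnp.full((batch_size, 1), model.config.decoder_start_token_id, dtype="i4")
#     mask = jnp.ones((batch_size, max_length), dtype="i4")
#     for _ in range(max_length - 1):
#         out = model.decode(token, encoder_outputs, decoder_attention_mask=mask, past_key_values=past)
#         token = jnp.argmax(out.logits[:, -1:], axis=-1).astype("i4")
#         past = out.past_key_values
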
FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration

    >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
    >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

    >>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs."
    >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], return_tensors="np")

    >>> # Generate Summary
    >>> summary_ids = model.generate(inputs["input_ids"]).sequences
    >>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False))
    ```
"""


overwrite_call_docstring(
    FlaxT5ForConditionalGeneration, T5_INPUTS_DOCSTRING + FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING
)
append_replace_return_docstrings(
    FlaxT5ForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
)