diff --git a/.gitignore b/.gitignore index 3ca7f8b..d39181d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,6 @@ *.ipynb -t5_*/ model_checkpoints/ exports/ -modified_t5_model/ traces/ ruff.toml settings.json diff --git a/parallel/t5_model/.gitignore b/parallel/t5_model/.gitignore new file mode 100644 index 0000000..bee8a64 --- /dev/null +++ b/parallel/t5_model/.gitignore @@ -0,0 +1 @@ +__pycache__ diff --git a/parallel/t5_model/configuration_t5.py b/parallel/t5_model/configuration_t5.py new file mode 100644 index 0000000..f6d904b --- /dev/null +++ b/parallel/t5_model/configuration_t5.py @@ -0,0 +1,163 @@ +# coding=utf-8 +# Copyright 2020, The T5 Authors and HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""T5 model configuration""" + +from typing import Mapping + +from transformers import PretrainedConfig +from transformers import logging + + +logger = logging.get_logger(__name__) + + +class T5Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`T5Model`] or a [`TFT5Model`]. It is used to + instantiate a T5 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the T5 + [google-t5/t5-small](https://huggingface.co/google-t5/t5-small) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Arguments: + vocab_size (`int`, *optional*, defaults to 32128): + Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`T5Model`] or [`TFT5Model`]. + d_model (`int`, *optional*, defaults to 512): + Size of the encoder layers and the pooler layer. + d_kv (`int`, *optional*, defaults to 64): + Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will + be defined as `num_heads * d_kv`. + d_ff (`int`, *optional*, defaults to 2048): + Size of the intermediate feed forward layer in each `T5Block`. + num_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer encoder. + num_decoder_layers (`int`, *optional*): + Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set. + num_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + relative_attention_num_buckets (`int`, *optional*, defaults to 32): + The number of buckets to use for each attention layer. + relative_attention_max_distance (`int`, *optional*, defaults to 128): + The maximum distance of the longer sequences for the bucket separation. + dropout_rate (`float`, *optional*, defaults to 0.1): + The ratio for all dropout layers. + classifier_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for classifier. 
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6): + The epsilon used by the layer normalization layers. + initializer_factor (`float`, *optional*, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + feed_forward_proj (`string`, *optional*, defaults to `"relu"`): + Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. T5v1.1 uses the + `"gated-gelu"` feed forward projection. Original T5 uses `"relu"`. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + """ + + model_type = "t5" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + + def __init__( + self, + vocab_size=32128, # vocab size here + d_model=512, + d_kv=64, + d_ff=2048, + num_layers=6, + num_decoder_layers=None, + num_heads=8, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + dropout_rate=0.1, + layer_norm_epsilon=1e-6, + initializer_factor=1.0, + feed_forward_proj="relu", + is_encoder_decoder=True, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + classifier_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.d_model = d_model + self.d_kv = d_kv + self.d_ff = d_ff + self.num_layers = num_layers + self.num_decoder_layers = ( + num_decoder_layers if num_decoder_layers is not None else self.num_layers + ) # default = symmetry + self.num_heads = num_heads + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.dropout_rate = dropout_rate + self.classifier_dropout = classifier_dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_factor = initializer_factor + self.feed_forward_proj = feed_forward_proj + self.use_cache = use_cache + self.use_bfloat16 = True + + act_info = self.feed_forward_proj.split("-") + self.dense_act_fn = act_info[-1] + self.is_gated_act = act_info[0] == "gated" + + if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2: + raise ValueError( + f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. " + "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. 
" + "'gated-gelu' or 'relu'" + ) + + # for backwards compatibility + if feed_forward_proj == "gated-gelu": + self.dense_act_fn = "gelu_new" + + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + + +# class T5OnnxConfig(OnnxSeq2SeqConfigWithPast): +# @property +# def inputs(self) -> Mapping[str, Mapping[int, str]]: +# common_inputs = { +# "input_ids": {0: "batch", 1: "encoder_sequence"}, +# "attention_mask": {0: "batch", 1: "encoder_sequence"}, +# } +# if self.use_past: +# common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence" +# common_inputs["decoder_input_ids"] = {0: "batch"} +# common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} +# else: +# common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} +# common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} +# +# if self.use_past: +# self.fill_with_past_key_values_(common_inputs, direction="inputs") +# +# return common_inputs +# +# @property +# def default_onnx_opset(self) -> int: +# return 13 \ No newline at end of file diff --git a/parallel/t5_model/modeling_t5_flax.py b/parallel/t5_model/modeling_t5_flax.py new file mode 100644 index 0000000..ac1e72d --- /dev/null +++ b/parallel/t5_model/modeling_t5_flax.py @@ -0,0 +1,2056 @@ +# coding=utf-8 +# Copyright 2021 T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Flax T5 model.""" + +import copy +from typing import Callable, Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen import partitioning as nn_partitioning +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax.random import PRNGKey + +from transformers.modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, + FlaxSeq2SeqLMOutput, + FlaxSeq2SeqModelOutput, +) +from transformers.modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from transformers import T5Config + +from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P +from ml_collections import ConfigDict + +import functools + +from typing import Any, Dict, Tuple, Callable, Sequence + +PyTree = Any +Metrics = Dict[str, Tuple[jax.Array, ...]] + + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google-t5/t5-small" +_CONFIG_FOR_DOC = "T5Config" + +remat = nn_partitioning.remat + + +# MARK: PARAMETER SHARDING +# %% [markdown] +# # parameter sharding +# Basic strategy: init full parameters on each device, then use +# jax.lax.axis_index to split parameters across devices, and keep a shard on +# each device +# +# use nn.Partitioned to annotate sharding spec on parameters +# quite similar to PartitionSpec +# +# parameters are either jax.Array or a flax.linen.Partitioned + +# %% +# type annotation +Parameter = jax.Array | nn.Partitioned + +# %% +# function to shard parameters across devices +# look for an axis to equally split across the number of devices +# we can specify which parameters to shard, since they vary in size +# we set a floor on the size for sharding +@jax.named_scope("shard_params") +def shard_params(params: PyTree, axis_name: str, min_weight_size: int = 2**18) -> PyTree: + """Shard parameters across the given mesh axis. + + Args: + params: The parameters to shard. + axis_name: The axis to shard parameters across. + min_weight_size: The minimum size of a parameter to shard. Parameters with fewer values will not be sharded. + + Returns: + PyTree of same structure as params, but with leaves sharded over new axis if possible. + """ + # axis_index + axis_idx = jax.lax.axis_index(axis_name) + # number of units in the axis + axis_size = jax.lax.psum(1, axis_name) + + # split function + # check each parameter if it had been sharded + def _split(x: Parameter) -> Parameter: + + # already sharded + if isinstance(x, nn.Partitioned): + value, names = x.value, x.names + # not sharded + else: + value = x + names = (None,) * value.ndim + + # logging only runs on first jit + # this section checks for why a parameter is not already sharded on the axis + # check for sharded parameters despite being sharded + # (that means its on a different axis) + if axis_name in names: + logging.warning( + f"Parameter {value.shape} with names {names} already sharded on axis {axis_name}." 
+ ) + return x + # check if parameter is to small + elif value.size <= min_weight_size: + logging.info( + f"Parameter {value.shape} with names {names} too small to shard, size {value.size} < {min_weight_size}." + ) + return x + # let's start sharding! + else: + shape = value.shape + idx = np.argsort(shape)[::-1] # Shard along largest possible axis. + for i in idx: + # this technically runs once because of return + # we only shard if we can split evenly across devices + # and if it ain't alreayd sharded + if shape[i] % axis_size == 0 and names[i] is None: + split_size = shape[i] // axis_size + p_sharded = nn.Partitioned( + value=jax.lax.dynamic_slice_in_dim( # Shard to keep on present device. + value, + axis_idx * split_size, + split_size, + axis=i + ), + names=names[:i] + (axis_name,) + names[i + 1 :], + ) + return p_sharded + + logging.warning( + f"Could not shard {value.shape} with names {names} on axis {axis_name}, no suitable axis found." + ) + return x + + # we apply the _split function across the parameter pytree + return jax.tree.map( + _split, + params, + is_leaf=lambda x: isinstance( + x, nn.Partitioned + ), # Consider a nn.Partitioned object as a leaf. + ) + +# %% +# function to gather parameters back to a single device + +# but first we need create a custom function for mean gradient computation +# jax.lax.all_gather -> retrieve shards and assemble full array in each device +# jax.lax.psum_scatter -> scatter gradients back to respective devices +def gather_array_with_mean_grads(x: jax.Array, axis: int, axis_name: str): + """Gathering with averaging gradients across replicas.""" + axis_size = jax.lax.psum(1, axis_name) + + # Define a custom gradient for the gather operation. + @jax.custom_gradient + def f(x): + # adjust backward to turn sum into mean of axis + def grad_fn(g): + # pmean_scatter from psum_scatter + # after computing from full gradient array, our shard only has a + # portion of the parameters, we only get the gradients associated + # with parameters of our shard + return ( + jax.lax.psum_scatter(g, axis_name, scatter_dimension=axis, tiled=True) / axis_size + ) + + # assemble shards to form full gradient array + return jax.lax.all_gather(x, axis_name, axis=axis, tiled=True), grad_fn + + return f(x) + +# gather params back - e.g. when computing a module forward call +# reverse operation of "shard_params" +# depends on: gather_array_with_mean_grads +@jax.named_scope("gather_params") +def gather_params(params: PyTree, axis_name: str) -> PyTree: + """Gather parameters from all replicas across the given axis. + + Args: + params: The parameters to gather. + axis_name: The axis to gather parameters across. + + Returns: + PyTree of same structure as params, but with leaves gathered if they were a nn.Partitioned object. + """ + + def _gather(p: Parameter) -> Parameter: + if isinstance(p, nn.Partitioned) and axis_name in p.names: + param_shard = p.names + shard_axis = param_shard.index(axis_name) + value = gather_array_with_mean_grads(p.value, axis=shard_axis, axis_name=axis_name) + + # If there are any other axes that are sharded, we need to keep the partitioned structure. + # Otherwise, we can return the value directly. 
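The `_split` helper above commits to a single axis per parameter: the largest axis that divides evenly by the device count. A standalone sketch of that selection rule (the shape and device count are illustrative, not taken from the diff):

```python
import numpy as np

shape, axis_size = (512, 2048), 8              # e.g. a d_model x d_ff kernel on 8 devices
for i in np.argsort(shape)[::-1]:              # try the largest axis first
    if shape[i] % axis_size == 0:
        local = tuple(s // axis_size if j == i else s for j, s in enumerate(shape))
        print(f"shard along axis {i}; local shard shape {local}")  # axis 1 -> (512, 256)
        break
```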
+ param_shard = param_shard[:shard_axis] + (None,) + param_shard[shard_axis + 1 :] + if any([name is not None for name in param_shard]): + # we return the still-sharded axes shard + return nn.Partitioned(value, param_shard) + else: + return value + else: + return p + + # we find all the sharded params and gather them, returning a complete parameter + return jax.tree.map( + _gather, + params, + is_leaf=lambda x: isinstance(x, nn.Partitioned)) + +# %% +# when we call a module, we gather the parameters back to a single device +# wrap a module into a nn.map_variables transform +# allows for transforms on the parameter before and after a module call +# depends on: gather_params, shard_params +def shard_module_params( + target: nn.Module | Callable, + axis_name: str, + min_weight_size: int = 2**18 # 262,144 +) -> nn.Module | Callable: + """Shard parameters of a module across replicas. + + Args: + target: The module to shard. + axis_name: The axis name to shard parameters across. + min_weight_size: The minimum size of a parameter to shard. Parameters with fewer values will not be sharded. + + Returns: + The module with sharded parameters. + """ + return nn.map_variables( + target, + trans_in_fn=functools.partial( + gather_params, axis_name=axis_name), + trans_out_fn=functools.partial( + shard_params, axis_name=axis_name, min_weight_size=min_weight_size + ), + mapped_collections="params", + mutable=True, + ) + + + + +# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: + """ + Shift input ids one token to the right. + """ + shifted_input_ids = jnp.zeros_like(input_ids) + shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1]) + shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id) + + shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) + return shifted_input_ids + + +class FlaxT5LayerNorm(nn.Module): + hidden_size: int + dtype: jnp.dtype = jnp.float32 + eps: float = 1e-6 + weight_init: Callable[..., np.ndarray] = jax.nn.initializers.ones + + def setup(self): + self.weight = self.param("weight", self.weight_init, (self.hidden_size,)) + + def __call__(self, hidden_states): + """ + Construct a layernorm module in the T5 style; No bias and no subtraction of mean. 
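With `shard_module_params` defined above, any Flax module can be wrapped so its weights stay sharded along the "model" mesh axis and are gathered (with mean-gradient semantics) only around each forward call. A minimal sketch; `ShardedDense` and `TinyBlock` are illustrative names, and the surrounding `shard_map`/mesh setup that supplies the named axis is omitted:

```python
import flax.linen as nn

# Wrap nn.Dense so its kernel is stored sharded across the "model" axis.
ShardedDense = shard_module_params(nn.Dense, axis_name="model", min_weight_size=2**18)

class TinyBlock(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Only valid inside a shard_map/pmap context that defines the "model" axis.
        return ShardedDense(features=4096, use_bias=False)(x)
```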
+ """ + # layer norm should always be calculated in float32 + variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True) + hidden_states = hidden_states / jnp.sqrt(variance + self.eps) + + return self.weight * hidden_states + + +class FlaxT5DenseActDense(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) + wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) + + self.wi = nn.Dense( + self.config.d_ff, + use_bias=False, + kernel_init=nn.with_partitioning( + jax.nn.initializers.normal(wi_init_std), + (None, 'model') + ), + dtype=self.dtype, + ) + self.wo = nn.Dense( + self.config.d_model, + use_bias=False, + kernel_init=nn.with_partitioning( + jax.nn.initializers.normal(wo_init_std), + ("model", None) + ), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + self.act = ACT2FN[self.config.dense_act_fn] + + def __call__(self, hidden_states, deterministic=True): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class FlaxT5DenseGatedActDense(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) + wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) + + self.wi_0 = nn.Dense( + self.config.d_ff, + use_bias=False, + kernel_init=nn.with_partitioning( + jax.nn.initializers.normal(wi_init_std), + (None, "model") + ), + dtype=self.dtype, + ) + self.wi_1 = nn.Dense( + self.config.d_ff, + use_bias=False, + kernel_init=nn.with_partitioning( + jax.nn.initializers.normal(wi_init_std), + (None, "model") + ), + dtype=self.dtype, + ) + self.wo = nn.Dense( + self.config.d_model, + use_bias=False, + kernel_init=nn.with_partitioning( + jax.nn.initializers.normal(wo_init_std), + ("model", None) + ), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + self.act = ACT2FN[self.config.dense_act_fn] + + def __call__(self, hidden_states, deterministic): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class FlaxT5LayerFF(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + if self.config.is_gated_act: + self.DenseReluDense = FlaxT5DenseGatedActDense(self.config, dtype=self.dtype) + else: + self.DenseReluDense = FlaxT5DenseActDense(self.config, dtype=self.dtype) + + self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__(self, hidden_states, deterministic=True): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states, deterministic=deterministic) + hidden_states = hidden_states + self.dropout(forwarded_states, deterministic=deterministic) + return hidden_states + + +class FlaxT5Attention(nn.Module): + config: T5Config + has_relative_attention_bias: bool = False + causal: bool = False + dtype: jnp.dtype = jnp.float32 # 
the dtype of the computation + + def setup(self): + self.relative_attention_num_buckets = self.config.relative_attention_num_buckets + self.relative_attention_max_distance = self.config.relative_attention_max_distance + self.d_model = self.config.d_model + self.key_value_proj_dim = self.config.d_kv + self.n_heads = self.config.num_heads + self.dropout = self.config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5) + kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) + o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) + + self.q = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=nn.with_partitioning( + jax.nn.initializers.normal(q_init_std), + (None, "model") + ), + dtype=self.dtype, + ) + self.k = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=nn.with_partitioning( + jax.nn.initializers.normal(kv_init_std), + (None, "model") + ), + dtype=self.dtype, + ) + self.v = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=nn.with_partitioning( + jax.nn.initializers.normal(kv_init_std), + (None, "model") + ), + dtype=self.dtype, + ) + self.o = nn.Dense( + self.d_model, + use_bias=False, + kernel_init=nn.with_partitioning( + jax.nn.initializers.normal(o_init_std), + ("model", None) + ), + dtype=self.dtype, + ) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embed( + self.relative_attention_num_buckets, + self.n_heads, + embedding_init=jax.nn.initializers.normal(kv_init_std), + dtype=self.dtype, + ) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0) * num_buckets + relative_position = jnp.abs(relative_position) + else: + relative_position = -jnp.clip(relative_position, a_max=0) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact) + ) + relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1) + + relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large) + + return relative_buckets.astype("i4") + + def compute_bias(self, query_length, key_length): + """Compute binned relative position bias""" + context_position = jnp.arange(query_length, dtype="i4")[:, None] + memory_position = jnp.arange(key_length, dtype="i4")[None, :] + + relative_position = memory_position - context_position + relative_position_bucket = self._relative_position_bucket( + relative_position, + bidirectional=(not self.causal), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + + values = self.relative_attention_bias(relative_position_bucket) + values = values.transpose((2, 0, 1))[None, :, :, :] + return values + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slighly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = jax.lax.dynamic_update_slice(cached_key.value, key, indices) + value = jax.lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions + # that have already been generated and cached, not the remaining zero elements. 
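The `_relative_position_bucket` static method defined above in this class is a pure function and can be probed on its own. A minimal sketch, assuming this module is importable:

```python
import jax.numpy as jnp

offsets = jnp.array([-130, -8, -1, 0, 1, 8, 130])   # memory_position - query_position
buckets = FlaxT5Attention._relative_position_bucket(
    offsets, bidirectional=True, num_buckets=32, max_distance=128
)
# Small offsets map to exact buckets; |offset| >= max_distance collapses into the
# outermost bucket of each direction (15 backward, 31 forward for 32 buckets).
```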
+ pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def _create_position_bias( + self, key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift + ): + cache_is_filled = self.causal and self.has_variable("cache", "cached_key") and (not init_cache) + key_length = key_states.shape[1] + query_length = key_length if cache_is_filled else query_states.shape[1] + + if self.has_relative_attention_bias: + position_bias = self.compute_bias(query_length, key_length) + elif attention_mask is not None: + position_bias = jnp.zeros_like(attention_mask) + else: + position_bias = jnp.zeros((1, self.n_heads, query_length, key_length), dtype=self.dtype) + + # if key and values are already calculated, only the last query position bias should be taken + if cache_is_filled: + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + position_bias = jax.lax.dynamic_slice( + position_bias, + (0, 0, causal_attention_mask_shift, 0), + (1, self.n_heads, seq_length, max_decoder_length), + ) + return position_bias + + def __call__( + self, + hidden_states, + attention_mask=None, + key_value_states=None, + position_bias=None, + use_cache=False, + output_attentions=False, + deterministic=True, + init_cache=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + batch_size, seq_length = hidden_states.shape[:2] + + # q, k, v projections + query_states = self.q(hidden_states) # (batch_size, n_heads, seq_length, dim_per_head) + key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states) + value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states) + + # reshape to (batch_size, seq_length, n_heads, head_dim) + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # counter-act scaling in dot_product_attention_weights function + query_states *= jnp.sqrt(query_states.shape[-1]) + + # for fast decoding causal attention mask should be shifted + causal_attention_mask_shift = ( + self.variables["cache"]["cache_index"] if (self.has_variable("cache", "cached_key") and self.causal) else 0 + ) + # create causal attention_mask; attention_mask has to be defined when model is causal + if self.causal: + causal_attention_mask = make_causal_mask(attention_mask, dtype="bool") + + # fast decoding for generate requires special attention_mask + if self.has_variable("cache", "cached_key"): + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_attention_mask = jax.lax.dynamic_slice( + causal_attention_mask, + (0, 0, causal_attention_mask_shift, 0), + (1, 1, seq_length, max_decoder_length), + ) + + # broadcast causal attention mask & attention mask to fit for merge + causal_attention_mask = jnp.broadcast_to( + causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:] + ) + attention_mask = jnp.broadcast_to( + jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape + ) + attention_mask = combine_masks(attention_mask, causal_attention_mask) + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one 
position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # replace masked positions with -10_000 + if attention_mask is not None: + mask_value = jnp.finfo(self.dtype).min + attention_mask = jax.lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, mask_value).astype(self.dtype), + ) + + if position_bias is None: + # compute position bias (only for first layer) + position_bias = self._create_position_bias( + key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift + ) + + if attention_mask is not None: + position_bias = position_bias + attention_mask + + # create dropout rng + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + # Softmax(QK^T) + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=position_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + ) + + # multiply with value states + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + + # bring back to (batch_size, seq_length, d_model) + attn_output = self._merge_heads(attn_output) + + # apply output matrix + attn_output = self.o(attn_output) + + outputs = (attn_output, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + + return outputs + + +class FlaxT5LayerSelfAttention(nn.Module): + config: T5Config + has_relative_attention_bias: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.SelfAttention = FlaxT5Attention( + self.config, + has_relative_attention_bias=self.has_relative_attention_bias, + causal=self.config.causal, + dtype=self.dtype, + ) + self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + output_attentions=False, + deterministic=True, + init_cache=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + init_cache=init_cache, + ) + hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class FlaxT5LayerCrossAttention(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.EncDecAttention = FlaxT5Attention( + self.config, has_relative_attention_bias=False, causal=False, dtype=self.dtype + ) + self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + output_attentions=False, + deterministic=True, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = 
self.EncDecAttention( + normed_hidden_states, + attention_mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class FlaxT5Block(nn.Module): + config: T5Config + has_relative_attention_bias: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.causal = self.config.causal + self.layer = ( + FlaxT5LayerSelfAttention( + self.config, + has_relative_attention_bias=self.has_relative_attention_bias, + name=str(0), + dtype=self.dtype, + ), + ) + feed_forward_index = 1 + if self.causal: + self.layer += (FlaxT5LayerCrossAttention(self.config, name=str(1), dtype=self.dtype),) + feed_forward_index += 1 + + self.layer += (FlaxT5LayerFF(self.config, name=str(feed_forward_index), dtype=self.dtype),) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + output_attentions=False, + return_dict=True, + deterministic=True, + init_cache=False, + ): + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + init_cache=init_cache, + ) + hidden_states = self_attention_outputs[0] + attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights + + do_cross_attention = self.causal and encoder_hidden_states is not None + if do_cross_attention: + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + ) + hidden_states = cross_attention_outputs[0] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[1:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states, deterministic=deterministic) + + outputs = (hidden_states,) + + outputs = outputs + attention_outputs + + # returns hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + return outputs + + +class FlaxT5LayerCollection(nn.Module): + config: T5Config + has_relative_attention_bias: bool + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layer = FlaxT5Block( + self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype + ) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + output_attentions=False, + deterministic=True, + init_cache=False, + ): + return self.layer( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + init_cache=init_cache, + ) + + +class 
FlaxT5BlockCollection(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.causal = self.config.causal + if self.gradient_checkpointing: + FlaxT5CheckpointLayer = remat(FlaxT5LayerCollection, static_argnums=(6, 7, 8)) + self.blocks = [ + FlaxT5CheckpointLayer( + self.config, + has_relative_attention_bias=(i == 0), + dtype=self.dtype, + name=str(i), + ) + for i in range(self.config.num_layers) + ] + else: + self.blocks = [ + FlaxT5LayerCollection( + self.config, + has_relative_attention_bias=(i == 0), + dtype=self.dtype, + name=str(i), + ) + for i in range(self.config.num_layers) + ] + + def __call__( + self, + hidden_states=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions: bool = False, + output_hidden_states: bool = False, + deterministic: bool = True, + init_cache: bool = False, + ): + # Prepare head mask if needed + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.causal) else None + position_bias = None + encoder_decoder_position_bias = None + + for i, layer_module in enumerate(self.blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + attention_mask, + position_bias, + encoder_hidden_states, + encoder_attention_mask, + encoder_decoder_position_bias, + output_attentions, + deterministic, + init_cache, + ) + + hidden_states = layer_outputs[0] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[1] + + if self.causal and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[2],) + if self.causal: + all_cross_attentions = all_cross_attentions + (layer_outputs[4],) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +class FlaxT5Stack(nn.Module): + config: T5Config + embed_tokens: nn.Embed + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.causal = self.config.causal + + self.block = FlaxT5BlockCollection( + self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + self.final_layer_norm = FlaxT5LayerNorm( + self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + init_cache: bool = False, + ): + hidden_states = self.embed_tokens(input_ids) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + + outputs = self.block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + 
encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + deterministic=deterministic, + init_cache=init_cache, + ) + + hidden_states = outputs[0] + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + + # Add last layer + all_hidden_states = None + + if output_hidden_states: + all_hidden_states = outputs.hidden_states + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + if output_hidden_states: + return ( + hidden_states, + all_hidden_states, + ) + outputs[2:] + return (hidden_states,) + outputs[1:] + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +T5_ENCODE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +T5_DECODE_INPUTS_DOCSTRING = r""" + Args: + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + For training, `decoder_input_ids` should be provided. + encoder_outputs (`tuple(tuple(jnp.ndarray)`): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the + paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. + past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +T5_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 + Training](./t5#training). + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + encoder_outputs (`tuple(tuple(jnp.ndarray)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder. 
+ past_key_values (`tuple(tuple(jnp.ndarray))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class FlaxT5PreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = T5Config + base_model_prefix = "transformer" + module_class: nn.Module = None + + def __init__( + self, + config: T5Config, + input_shape: Tuple[int] = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.bfloat16, + _do_init: bool = True, + gradient_checkpointing: bool = False, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + + attention_mask = jnp.ones_like(input_ids) + args = [input_ids, attention_mask] + if self.module_class not in [FlaxT5EncoderModule]: + decoder_input_ids = jnp.ones_like(input_ids) + decoder_attention_mask = jnp.ones_like(input_ids) + args.extend([decoder_input_ids, decoder_attention_mask]) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, + *args, + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + decoder_input_ids: jnp.ndarray = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not 
None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if decoder_input_ids is None: + raise ValueError( + "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed" + " here." + ) + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # prepare decoder inputs + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + def init_cache(self, batch_size, max_length, encoder_outputs): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): + `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: + `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) + is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. + """ + # init input variables to retrieve cache + decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + init_cache=True, + method=_decoder_forward, # we only need to call the decoder to init the cache + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings(T5_ENCODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=T5Config) + def encode( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + + >>> text = "My friends are cool but they eat too many carbs." 
+ >>> inputs = tokenizer(text, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _encoder_forward(module, input_ids, attention_mask, **kwargs): + encode_module = module._get_encoder_module() + return encode_module(input_ids, attention_mask, **kwargs) + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + method=_encoder_forward, + ) + + @add_start_docstrings(T5_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=T5Config) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration + >>> import jax.numpy as jnp + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + + >>> text = "My friends are cool but they eat too many carbs." + >>> inputs = tokenizer(text, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxT5Attention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ) + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past = outputs + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past = outputs + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + +T5_START_DOCSTRING = r""" + The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text + Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan + Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a + text-to-text denoising generative setting. + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`T5Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. 
+ + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + + +@add_start_docstrings( + "The bare T5 Model transformer outputting raw hidden-stateswithout any specific head on top.", + T5_START_DOCSTRING, +) +class FlaxT5Module(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def _get_encoder_module(self): + return self.encoder + + def _get_decoder_module(self): + return self.decoder + + def setup(self): + # sharded module + # with manual shard_map + # sharded_embed = shard_module_params( + # nn.Embed, + # axis_name="model", + # ) + # self.shared = sharded_embed( + # self.config.vocab_size, + # self.config.d_model, + # embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0), + # dtype=self.dtype, + # ) + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0), + dtype=self.dtype, + ) + + encoder_config = copy.deepcopy(self.config) + encoder_config.causal = False + self.encoder = FlaxT5Stack( + encoder_config, + embed_tokens=self.shared, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + + decoder_config = copy.deepcopy(self.config) + decoder_config.causal = True + decoder_config.num_layers = self.config.num_decoder_layers + self.decoder = FlaxT5Stack( + decoder_config, + embed_tokens=self.shared, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + + def __call__( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + deterministic: bool = True, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Encode if needed (training, first prediction pass) + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return FlaxSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class FlaxT5Model(FlaxT5PreTrainedModel): + module_class = FlaxT5Module + + +append_call_sample_docstring(FlaxT5Model, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC) + +FLAX_T5_MODEL_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers 
import AutoTokenizer, FlaxT5Model + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = FlaxT5Model.from_pretrained("google-t5/t5-small") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="np" + ... ).input_ids + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids + + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model. + >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg. + >>> decoder_input_ids = model._shift_right(decoder_input_ids) + + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ``` +""" + + +overwrite_call_docstring(FlaxT5Model, T5_INPUTS_DOCSTRING + FLAX_T5_MODEL_DOCSTRING) +append_replace_return_docstrings(FlaxT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + + +@add_start_docstrings( + "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.", + T5_START_DOCSTRING, +) +class FlaxT5EncoderModule(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0), + dtype=self.dtype, + ) + + encoder_config = copy.deepcopy(self.config) + encoder_config.is_decoder = False + encoder_config.is_encoder_decoder = False + encoder_config.causal = False + self.encoder = FlaxT5Stack( + encoder_config, + embed_tokens=self.shared, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + + def __call__( + self, + input_ids=None, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict: bool = True, + deterministic: bool = True, + ): + # Encode if needed (training, first prediction pass) + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + return encoder_outputs + + +class FlaxT5EncoderModel(FlaxT5PreTrainedModel): + module_class = FlaxT5EncoderModule + + @add_start_docstrings_to_model_forward(T5_ENCODE_INPUTS_DOCSTRING) + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + 
output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + +@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING) +class FlaxT5ForConditionalGenerationModule(nn.Module): + config: T5Config + dtype: jnp.dtype = jnp.bfloat16 # the dtype of the computation + gradient_checkpointing: bool = False + + def _get_encoder_module(self): + return self.encoder + + def _get_decoder_module(self): + return self.decoder + + def setup(self): + self.model_dim = self.config.d_model + + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.initializer_factor), + dtype=self.dtype, + ) + + encoder_config = copy.deepcopy(self.config) + encoder_config.causal = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = FlaxT5Stack( + encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + + decoder_config = copy.deepcopy(self.config) + decoder_config.causal = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = self.config.num_decoder_layers + self.decoder = FlaxT5Stack( + decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + + # handles outgoing predictions + # think of it like a wo of a dense + self.lm_head = nn.Dense( + self.config.vocab_size, + use_bias=False, + kernel_init=nn.with_partitioning( + jax.nn.initializers.normal(self.config.initializer_factor), + ("model", None) + ), + dtype=self.dtype, + ) + + def __call__( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + deterministic: bool = True, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Encode + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + hidden_states = encoder_outputs[0] + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + sequence_output = decoder_outputs[0] + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + if self.config.tie_word_embeddings: + # shared_embedding is a partitioned tensor + shared_embedding = self.shared.variables["params"]["embedding"] + lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output) + else: + lm_logits = self.lm_head(sequence_output) + + if not return_dict: + return (lm_logits,) + decoder_outputs[1:] + encoder_outputs + + return FlaxSeq2SeqLMOutput( + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + 
decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel): + module_class = FlaxT5ForConditionalGenerationModule + + @add_start_docstrings(T5_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=T5Config) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration + >>> import jax.numpy as jnp + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + + >>> text = "summarize: My friends are cool but they eat too many carbs." + >>> inputs = tokenizer(text, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxT5Attention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): + decoder_module = module._get_decoder_module() + decoder_outputs = decoder_module( + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ) + + sequence_output = decoder_outputs[0] + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.config.d_model**-0.5) + + if self.config.tie_word_embeddings: + shared_embedding = module.shared.variables["params"]["embedding"] + lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output) + else: + lm_logits = module.lm_head(sequence_output) + + return lm_logits, decoder_outputs + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + if past_key_values is None: + lm_logits, decoder_outputs = outputs + else: + (lm_logits, decoder_outputs), past = outputs + + if return_dict: + outputs = FlaxCausalLMOutputWithCrossAttentions( + logits=lm_logits, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + ) + else: + outputs = (lm_logits,) + decoder_outputs[1:] + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + max_length, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, + encoder_outputs=None, + **kwargs, + ): + # initializing the cache + batch_size, seq_length = decoder_input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyways. 
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if decoder_attention_mask is not None: + extended_attention_mask = jax.lax.dynamic_update_slice( + extended_attention_mask, decoder_attention_mask, (0, 0) + ) + + return { + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "encoder_attention_mask": attention_mask, + "decoder_attention_mask": extended_attention_mask, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + return model_kwargs + + +FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + + >>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs." + >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], return_tensors="np") + + >>> # Generate Summary + >>> summary_ids = model.generate(inputs["input_ids"]).sequences + >>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)) + ``` +""" + + +overwrite_call_docstring( + FlaxT5ForConditionalGeneration, T5_INPUTS_DOCSTRING + FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING +) +append_replace_return_docstrings( + FlaxT5ForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC +) \ No newline at end of file diff --git a/parallel/t5_model/pure_t5.py b/parallel/t5_model/pure_t5.py new file mode 100644 index 0000000..7cc1510 --- /dev/null +++ b/parallel/t5_model/pure_t5.py @@ -0,0 +1,1717 @@ +# coding=utf-8 +# Copyright 2021 T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
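+#
+# Note on this file: it is a self-contained Flax/JAX re-implementation of the T5 stack.
+# Instead of `transformers.T5Config` it is configured through an
+# `ml_collections.FrozenConfigDict` built by `make_config()` below, and the dense kernels
+# are annotated with `nn.with_partitioning(..., (None, "model"))` / `(..., ("model", None))`
+# so that, under a device mesh with a "model" axis, the parameters are created as
+# `flax.linen.Partitioned` values and can be sharded for tensor parallelism.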
+ +import copy +from typing import Callable, Optional, Tuple, Dict +from collections import OrderedDict, UserDict +from dataclasses import fields, is_dataclass + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen import partitioning as nn_partitioning +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax.random import PRNGKey +from ml_collections import ConfigDict + +from transformers.modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, + FlaxSeq2SeqLMOutput, + FlaxSeq2SeqModelOutput, +) +import flax + + +from transformers.modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, +) +# from transformers import T5Config + +from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P +from ml_collections import FrozenConfigDict + +import functools + +from typing import Any, Dict, Tuple, Callable, Sequence + +PyTree = Any +Metrics = Dict[str, Tuple[jax.Array, ...]] + + + + + +remat = nn_partitioning.remat + + +# MARK: PARAMETER SHARDING +# %% [markdown] +# # parameter sharding +# Basic strategy: init full parameters on each device, then use +# jax.lax.axis_index to split parameters across devices, and keep a shard on +# each device +# +# use nn.Partitioned to annotate sharding spec on parameters +# quite similar to PartitionSpec +# +# parameters are either jax.Array or a flax.linen.Partitioned + +# t5-base +# https://huggingface.co/google-t5/t5-base/blob/main/config.json +def make_config(): + model_type="t5" + # keys_to_ignore_at_inference = ["past_key_values"] + # attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + d_ff=3072 + d_kv=64 + d_model=768 + vocab_size=32128 + num_layers=12 + num_heads=12 + num_decoder_layers=None + num_decoder_layers = ( + num_decoder_layers if num_decoder_layers is not None else num_layers + ) # default = symmetry + relative_attention_num_buckets=32 + relative_attention_max_distance=128 + dropout_rate=0.1 + layer_norm_epsilon=1e-6 + initializer_factor=1.0 + feed_forward_proj="relu" + is_encoder_decoder=True + use_cache=True + pad_token_id=0 + eos_token_id=1 + decoder_start_token_id=0 + classifier_dropout=0.0 + act_info = feed_forward_proj.split("-") + dense_act_fn = act_info[-1] + is_gated_act = act_info[0] == "gated" + causal=False + use_return_dict=False + if feed_forward_proj == "gated-gelu": + dense_act_fn = "gelu_new" + tie_word_embeddings=False + return FrozenConfigDict( + dict( + model_type=model_type, + vocab_size=vocab_size, + d_model=d_model, + d_kv=d_kv, + d_ff=d_ff, + num_layers=num_layers, + num_decoder_layers = num_decoder_layers, + num_heads=num_heads, + relative_attention_num_buckets=relative_attention_num_buckets, + relative_attention_max_distance=relative_attention_max_distance, + dropout_rate=dropout_rate, + layer_norm_epsilon=layer_norm_epsilon, + initializer_factor=initializer_factor, + feed_forward_proj=feed_forward_proj, + is_encoder_decoder=is_encoder_decoder, + use_cache=use_cache, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + decoder_start_token_id=decoder_start_token_id, + classifier_dropout=classifier_dropout, + is_gated_act = is_gated_act, + 
dense_act_fn = dense_act_fn, + causal=causal, + use_return_dict=use_return_dict, + tie_word_embeddings=tie_word_embeddings + ) + ) +T5config = make_config() + + + +# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: + """ + Shift input ids one token to the right. + """ + shifted_input_ids = jnp.zeros_like(input_ids) + shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1]) + shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id) + + shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) + return shifted_input_ids + + +class FlaxT5LayerNorm(nn.Module): + hidden_size: int + dtype: jnp.dtype = jnp.float32 + eps: float = 1e-6 + weight_init: Callable[..., np.ndarray] = jax.nn.initializers.ones + + def setup(self): + self.weight = self.param("weight", self.weight_init, (self.hidden_size,)) + + def __call__(self, hidden_states): + """ + Construct a layernorm module in the T5 style; No bias and no subtraction of mean. + """ + # layer norm should always be calculated in float32 + variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True) + hidden_states = hidden_states / jnp.sqrt(variance + self.eps) + + return self.weight * hidden_states + + +class FlaxT5DenseActDense(nn.Module): + config: ConfigDict + dtype: jnp.dtype = jnp.float32 + + def setup(self): + wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) + wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) + + self.wi = nn.Dense( + self.config.d_ff, + use_bias=False, + kernel_init=nn.with_partitioning( + jax.nn.initializers.normal(wi_init_std), + (None, 'model') + ), + dtype=self.dtype, + ) + self.wo = nn.Dense( + self.config.d_model, + use_bias=False, + kernel_init=nn.with_partitioning( + jax.nn.initializers.normal(wo_init_std), + ("model", None) + ), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + self.act = ACT2FN[self.config.dense_act_fn] + + def __call__(self, hidden_states, deterministic=True): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class FlaxT5DenseGatedActDense(nn.Module): + config: ConfigDict + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) + wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) + + self.wi_0 = nn.Dense( + self.config.d_ff, + use_bias=False, + kernel_init=nn.with_partitioning( + jax.nn.initializers.normal(wi_init_std), + (None, "model") + ), + dtype=self.dtype, + ) + self.wi_1 = nn.Dense( + self.config.d_ff, + use_bias=False, + kernel_init=nn.with_partitioning( + jax.nn.initializers.normal(wi_init_std), + (None, "model") + ), + dtype=self.dtype, + ) + self.wo = nn.Dense( + self.config.d_model, + use_bias=False, + kernel_init=nn.with_partitioning( + jax.nn.initializers.normal(wo_init_std), + ("model", None) + ), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + self.act = ACT2FN[self.config.dense_act_fn] + + def __call__(self, hidden_states, deterministic): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu 
* hidden_linear + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class FlaxT5LayerFF(nn.Module): + config: ConfigDict + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + if self.config.is_gated_act: + self.DenseReluDense = FlaxT5DenseGatedActDense(self.config, dtype=self.dtype) + else: + self.DenseReluDense = FlaxT5DenseActDense(self.config, dtype=self.dtype) + + self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__(self, hidden_states, deterministic=True): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states, deterministic=deterministic) + hidden_states = hidden_states + self.dropout(forwarded_states, deterministic=deterministic) + return hidden_states + + +class FlaxT5Attention(nn.Module): + config: ConfigDict + has_relative_attention_bias: bool = False + causal: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.relative_attention_num_buckets = self.config.relative_attention_num_buckets + self.relative_attention_max_distance = self.config.relative_attention_max_distance + self.d_model = self.config.d_model + self.key_value_proj_dim = self.config.d_kv + self.n_heads = self.config.num_heads + self.dropout = self.config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5) + kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) + o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) + + self.q = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=nn.with_partitioning( + jax.nn.initializers.normal(q_init_std), + (None, "model") + ), + dtype=self.dtype, + ) + self.k = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=nn.with_partitioning( + jax.nn.initializers.normal(kv_init_std), + (None, "model") + ), + dtype=self.dtype, + ) + self.v = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=nn.with_partitioning( + jax.nn.initializers.normal(kv_init_std), + (None, "model") + ), + dtype=self.dtype, + ) + self.o = nn.Dense( + self.d_model, + use_bias=False, + kernel_init=nn.with_partitioning( + jax.nn.initializers.normal(o_init_std), + ("model", None) + ), + dtype=self.dtype, + ) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embed( + self.relative_attention_num_buckets, + self.n_heads, + embedding_init=jax.nn.initializers.normal(kv_init_std), + dtype=self.dtype, + ) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. 
All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0) * num_buckets + relative_position = jnp.abs(relative_position) + else: + relative_position = -jnp.clip(relative_position, a_max=0) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact) + ) + relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1) + + relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large) + + return relative_buckets.astype("i4") + + def compute_bias(self, query_length, key_length): + """Compute binned relative position bias""" + context_position = jnp.arange(query_length, dtype="i4")[:, None] + memory_position = jnp.arange(key_length, dtype="i4")[None, :] + + relative_position = memory_position - context_position + relative_position_bucket = self._relative_position_bucket( + relative_position, + bidirectional=(not self.causal), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + + values = self.relative_attention_bias(relative_position_bucket) + values = values.transpose((2, 0, 1))[None, :, :, :] + return values + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slighly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. 
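+        # The cache consists of a full-length key buffer, a full-length value buffer and an
+        # integer `cache_index` recording how many positions have been written so far; each
+        # decoding step writes its new key/value slice at `cache_index` via
+        # `jax.lax.dynamic_update_slice` and then advances the index by the number of query
+        # positions processed in that step.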
+ is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = jax.lax.dynamic_update_slice(cached_key.value, key, indices) + value = jax.lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions + # that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def _create_position_bias( + self, key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift + ): + cache_is_filled = self.causal and self.has_variable("cache", "cached_key") and (not init_cache) + key_length = key_states.shape[1] + query_length = key_length if cache_is_filled else query_states.shape[1] + + if self.has_relative_attention_bias: + position_bias = self.compute_bias(query_length, key_length) + elif attention_mask is not None: + position_bias = jnp.zeros_like(attention_mask) + else: + position_bias = jnp.zeros((1, self.n_heads, query_length, key_length), dtype=self.dtype) + + # if key and values are already calculated, only the last query position bias should be taken + if cache_is_filled: + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + position_bias = jax.lax.dynamic_slice( + position_bias, + (0, 0, causal_attention_mask_shift, 0), + (1, self.n_heads, seq_length, max_decoder_length), + ) + return position_bias + + def __call__( + self, + hidden_states, + attention_mask=None, + key_value_states=None, + position_bias=None, + use_cache=False, + output_attentions=False, + deterministic=True, + init_cache=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
+ """ + batch_size, seq_length = hidden_states.shape[:2] + + # q, k, v projections + query_states = self.q(hidden_states) # (batch_size, n_heads, seq_length, dim_per_head) + key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states) + value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states) + + # reshape to (batch_size, seq_length, n_heads, head_dim) + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # counter-act scaling in dot_product_attention_weights function + query_states *= jnp.sqrt(query_states.shape[-1]) + + # for fast decoding causal attention mask should be shifted + causal_attention_mask_shift = ( + self.variables["cache"]["cache_index"] if (self.has_variable("cache", "cached_key") and self.causal) else 0 + ) + # create causal attention_mask; attention_mask has to be defined when model is causal + if self.causal: + causal_attention_mask = make_causal_mask(attention_mask, dtype="bool") + + # fast decoding for generate requires special attention_mask + if self.has_variable("cache", "cached_key"): + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_attention_mask = jax.lax.dynamic_slice( + causal_attention_mask, + (0, 0, causal_attention_mask_shift, 0), + (1, 1, seq_length, max_decoder_length), + ) + + # broadcast causal attention mask & attention mask to fit for merge + causal_attention_mask = jnp.broadcast_to( + causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:] + ) + attention_mask = jnp.broadcast_to( + jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape + ) + attention_mask = combine_masks(attention_mask, causal_attention_mask) + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. 
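+        # On the `init_cache` pass the cache variables are created; on subsequent calls the
+        # presence of "cached_key" in the cache collection routes us into the incremental
+        # update performed by `_concatenate_to_cache`.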
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # replace masked positions with -10_000 + if attention_mask is not None: + mask_value = jnp.finfo(self.dtype).min + attention_mask = jax.lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, mask_value).astype(self.dtype), + ) + + if position_bias is None: + # compute position bias (only for first layer) + position_bias = self._create_position_bias( + key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift + ) + + if attention_mask is not None: + position_bias = position_bias + attention_mask + + # create dropout rng + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + # Softmax(QK^T) + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=position_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + ) + + # multiply with value states + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + + # bring back to (batch_size, seq_length, d_model) + attn_output = self._merge_heads(attn_output) + + # apply output matrix + attn_output = self.o(attn_output) + + outputs = (attn_output, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + + return outputs + + +class FlaxT5LayerSelfAttention(nn.Module): + config: ConfigDict + has_relative_attention_bias: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.SelfAttention = FlaxT5Attention( + config=self.config, + has_relative_attention_bias=self.has_relative_attention_bias, + causal=self.config.causal, + dtype=self.dtype, + ) + self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + output_attentions=False, + deterministic=True, + init_cache=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + init_cache=init_cache, + ) + hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class FlaxT5LayerCrossAttention(nn.Module): + config: ConfigDict + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.EncDecAttention = FlaxT5Attention( + config=self.config, has_relative_attention_bias=False, causal=False, dtype=self.dtype + ) + self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + output_attentions=False, + deterministic=True, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + 
attention_mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class FlaxT5Block(nn.Module): + config: ConfigDict + has_relative_attention_bias: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.causal = self.config.causal + self.layer = ( + FlaxT5LayerSelfAttention( + config=self.config, + has_relative_attention_bias=self.has_relative_attention_bias, + name=str(0), + dtype=self.dtype, + ), + ) + feed_forward_index = 1 + if self.causal: + self.layer += (FlaxT5LayerCrossAttention(self.config, name=str(1), dtype=self.dtype),) + feed_forward_index += 1 + + self.layer += (FlaxT5LayerFF(self.config, name=str(feed_forward_index), dtype=self.dtype),) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + output_attentions=False, + return_dict=True, + deterministic=True, + init_cache=False, + ): + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + init_cache=init_cache, + ) + hidden_states = self_attention_outputs[0] + attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights + + do_cross_attention = self.causal and encoder_hidden_states is not None + if do_cross_attention: + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + ) + hidden_states = cross_attention_outputs[0] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[1:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states, deterministic=deterministic) + + outputs = (hidden_states,) + + outputs = outputs + attention_outputs + + # returns hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + return outputs + + +class FlaxT5LayerCollection(nn.Module): + config: ConfigDict + has_relative_attention_bias: bool + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layer = FlaxT5Block( + self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype + ) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + output_attentions=False, + deterministic=True, + init_cache=False, + ): + return self.layer( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + init_cache=init_cache, + ) + + +class FlaxT5BlockCollection(nn.Module): + config: 
ConfigDict + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.causal = self.config.causal + if self.gradient_checkpointing: + FlaxT5CheckpointLayer = remat(FlaxT5LayerCollection, static_argnums=(6, 7, 8)) + self.blocks = [ + FlaxT5CheckpointLayer( + self.config, + has_relative_attention_bias=(i == 0), + dtype=self.dtype, + name=str(i), + ) + for i in range(self.config.num_layers) + ] + else: + self.blocks = [ + FlaxT5LayerCollection( + self.config, + has_relative_attention_bias=(i == 0), + dtype=self.dtype, + name=str(i), + ) + for i in range(self.config.num_layers) + ] + + def __call__( + self, + hidden_states=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions: bool = False, + output_hidden_states: bool = False, + deterministic: bool = True, + init_cache: bool = False, + ): + # Prepare head mask if needed + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.causal) else None + position_bias = None + encoder_decoder_position_bias = None + + for i, layer_module in enumerate(self.blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + attention_mask, + position_bias, + encoder_hidden_states, + encoder_attention_mask, + encoder_decoder_position_bias, + output_attentions, + deterministic, + init_cache, + ) + + hidden_states = layer_outputs[0] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[1] + + if self.causal and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[2],) + if self.causal: + all_cross_attentions = all_cross_attentions + (layer_outputs[4],) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +class FlaxT5Stack(nn.Module): + config: ConfigDict + embed_tokens: nn.Embed + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.causal = self.config.causal + + self.block = FlaxT5BlockCollection( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + self.final_layer_norm = FlaxT5LayerNorm( + self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + init_cache: bool = False, + ): + hidden_states = self.embed_tokens(input_ids) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + + outputs = self.block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + 
output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + deterministic=deterministic, + init_cache=init_cache, + ) + + hidden_states = outputs[0] + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + + # Add last layer + all_hidden_states = None + + if output_hidden_states: + all_hidden_states = outputs.hidden_states + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + if output_hidden_states: + return ( + hidden_states, + all_hidden_states, + ) + outputs[2:] + return (hidden_states,) + outputs[1:] + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + + +class FlaxT5PreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ConfigDict + base_model_prefix = "transformer" + module_class: nn.Module = None + + def __init__( + self, + config: ConfigDict, + input_shape: Tuple[int] = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.bfloat16, + _do_init: bool = True, + gradient_checkpointing: bool = False, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + + attention_mask = jnp.ones_like(input_ids) + args = [input_ids, attention_mask] + if self.module_class not in [FlaxT5EncoderModule]: + decoder_input_ids = jnp.ones_like(input_ids) + decoder_attention_mask = jnp.ones_like(input_ids) + args.extend([decoder_input_ids, decoder_attention_mask]) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, + *args, + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + decoder_input_ids: jnp.ndarray = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if decoder_input_ids is None: + raise ValueError( + "Make sure to provide both `input_ids` and 
`decoder_input_ids`. `decoder_input_ids` is not passed" + " here." + ) + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # prepare decoder inputs + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + def init_cache(self, batch_size, max_length, encoder_outputs): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): + `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: + `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) + is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. + """ + # init input variables to retrieve cache + decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + init_cache=True, + method=_decoder_forward, # we only need to call the decoder to init the cache + ) + return unfreeze(init_variables["cache"]) + + def encode( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + + >>> text = "My friends are cool but they eat too many carbs." 
+ >>> inputs = tokenizer(text, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _encoder_forward(module, input_ids, attention_mask, **kwargs): + encode_module = module._get_encoder_module() + return encode_module(input_ids, attention_mask, **kwargs) + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + method=_encoder_forward, + ) + + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration + >>> import jax.numpy as jnp + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + + >>> text = "My friends are cool but they eat too many carbs." + >>> inputs = tokenizer(text, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. 
It also has to be made sure that the cache is marked as mutable so that
+        # it can be updated by the FlaxT5Attention module
+        if past_key_values:
+            inputs["cache"] = past_key_values
+            mutable = ["cache"]
+        else:
+            mutable = False
+
+        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
+            decoder_module = module._get_decoder_module()
+            return decoder_module(
+                decoder_input_ids,
+                decoder_attention_mask,
+                **kwargs,
+            )
+
+        outputs = self.module.apply(
+            inputs,
+            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=not train,
+            rngs=rngs,
+            mutable=mutable,
+            method=_decoder_forward,
+        )
+
+        # add the updated cache to the model output
+        if past_key_values is not None and return_dict:
+            outputs, past = outputs
+            outputs["past_key_values"] = unfreeze(past["cache"])
+            return outputs
+        elif past_key_values is not None and not return_dict:
+            outputs, past = outputs
+            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+
+        return outputs
+
+
+class FlaxT5Module(nn.Module):
+    config: ConfigDict
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
+
+    def _get_encoder_module(self):
+        return self.encoder
+
+    def _get_decoder_module(self):
+        return self.decoder
+
+    def setup(self):
+        # sharded module variant (manual shard_map)
+        # sharded_embed = shard_module_params(
+        #     nn.Embed,
+        #     axis_name="model",
+        # )
+        # self.shared = sharded_embed(
+        #     self.config.vocab_size,
+        #     self.config.d_model,
+        #     embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
+        #     dtype=self.dtype,
+        # )
+        self.shared = nn.Embed(
+            self.config.vocab_size,
+            self.config.d_model,
+            embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
+            dtype=self.dtype,
+        )
+
+        encoder_config = copy.deepcopy(self.config)
+        # unfreeze
+        encoder_config = ConfigDict(encoder_config)
+        encoder_config.causal = False
+        # freeze
+        encoder_config = FrozenConfigDict(encoder_config)
+        self.encoder = FlaxT5Stack(
+            encoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
+
+        decoder_config = copy.deepcopy(self.config)
+        # unfreeze
+        decoder_config = ConfigDict(decoder_config)
+        decoder_config.causal = True
+        decoder_config.num_layers = self.config.num_decoder_layers
+        # freeze
+        decoder_config = FrozenConfigDict(decoder_config)
+
+        self.decoder = FlaxT5Stack(
+            decoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
+
+    def __call__(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        decoder_input_ids=None,
+        decoder_attention_mask=None,
+        encoder_outputs=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        deterministic: bool = True,
+    ):
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # Encode if needed (training, first prediction pass)
+        encoder_outputs = self.encoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=deterministic,
+        )
+
+        # Decode
+        decoder_outputs =
self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + # if not return_dict: + return decoder_outputs + encoder_outputs + + # return FlaxSeq2SeqModelOutput( + # last_hidden_state=decoder_outputs.last_hidden_state, + # past_key_values=decoder_outputs.past_key_values, + # decoder_hidden_states=decoder_outputs.hidden_states, + # decoder_attentions=decoder_outputs.attentions, + # cross_attentions=decoder_outputs.cross_attentions, + # encoder_last_hidden_state=encoder_outputs.last_hidden_state, + # encoder_hidden_states=encoder_outputs.hidden_states, + # encoder_attentions=encoder_outputs.attentions, + # ) + + +class FlaxT5Model(FlaxT5PreTrainedModel): + module_class = FlaxT5Module + + +class FlaxT5EncoderModule(nn.Module): + config: ConfigDict + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0), + dtype=self.dtype, + ) + + encoder_config = copy.deepcopy(self.config) + encoder_config.is_decoder = False + encoder_config.is_encoder_decoder = False + encoder_config.causal = False + self.encoder = FlaxT5Stack( + encoder_config, + embed_tokens=self.shared, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + + def __call__( + self, + input_ids=None, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict: bool = True, + deterministic: bool = True, + ): + # Encode if needed (training, first prediction pass) + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + return encoder_outputs + + +class FlaxT5EncoderModel(FlaxT5PreTrainedModel): + module_class = FlaxT5EncoderModule + + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + +class FlaxT5ForConditionalGenerationModule(nn.Module): + config: ConfigDict + dtype: jnp.dtype = jnp.bfloat16 # the dtype of the computation + 
gradient_checkpointing: bool = False
+
+    def _get_encoder_module(self):
+        return self.encoder
+
+    def _get_decoder_module(self):
+        return self.decoder
+
+    def setup(self):
+        self.model_dim = self.config.d_model
+
+        self.shared = nn.Embed(
+            self.config.vocab_size,
+            self.config.d_model,
+            embedding_init=jax.nn.initializers.normal(self.config.initializer_factor),
+            dtype=self.dtype,
+        )
+
+        encoder_config = copy.deepcopy(self.config)
+        # unfreeze
+        encoder_config = ConfigDict(encoder_config)
+        encoder_config.causal = False
+        encoder_config.use_cache = False
+        encoder_config.is_encoder_decoder = False
+        # freeze
+        encoder_config = FrozenConfigDict(encoder_config)
+        self.encoder = FlaxT5Stack(
+            config=encoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing
+        )
+        decoder_config = copy.deepcopy(self.config)
+        # unfreeze
+        decoder_config = ConfigDict(decoder_config)
+        decoder_config.causal = True
+        decoder_config.is_encoder_decoder = False
+        decoder_config.num_layers = self.config.num_decoder_layers
+        # freeze
+        decoder_config = FrozenConfigDict(decoder_config)
+        self.decoder = FlaxT5Stack(
+            config=decoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing
+        )
+
+        # handles outgoing predictions; think of it as the `wo` of a dense layer
+        self.lm_head = nn.Dense(
+            self.config.vocab_size,
+            use_bias=False,
+            kernel_init=nn.with_partitioning(
+                jax.nn.initializers.normal(self.config.initializer_factor),
+                ("model", None)
+            ),
+            dtype=self.dtype,
+        )
+
+    def __call__(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        decoder_input_ids=None,
+        decoder_attention_mask=None,
+        encoder_outputs=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        deterministic: bool = True,
+    ):
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # Encode
+        encoder_outputs = self.encoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=deterministic,
+        )
+
+        hidden_states = encoder_outputs[0]
+
+        # Decode
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=hidden_states,
+            encoder_attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=deterministic,
+        )
+
+        sequence_output = decoder_outputs[0]
+
+        if self.config.tie_word_embeddings:
+            # Rescale output before projecting on vocab
+            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+            sequence_output = sequence_output * (self.model_dim**-0.5)
+
+        if self.config.tie_word_embeddings:
+            # shared_embedding is a partitioned tensor
+            shared_embedding = self.shared.variables["params"]["embedding"]
+            lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
+        else:
+            lm_logits = self.lm_head(sequence_output)
+
+        if not return_dict:
+            return (lm_logits,) + decoder_outputs[1:] + encoder_outputs
+
+        return FlaxSeq2SeqLMOutput(
+            logits=lm_logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel): + module_class = FlaxT5ForConditionalGenerationModule + + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration + >>> import jax.numpy as jnp + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + + >>> text = "summarize: My friends are cool but they eat too many carbs." + >>> inputs = tokenizer(text, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxT5Attention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): + decoder_module = module._get_decoder_module() + decoder_outputs = decoder_module( + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ) + + sequence_output = decoder_outputs[0] + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.config.d_model**-0.5) + + if self.config.tie_word_embeddings: + shared_embedding = module.shared.variables["params"]["embedding"] + lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output) + else: + lm_logits = module.lm_head(sequence_output) + + return lm_logits, decoder_outputs + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + if past_key_values is None: + lm_logits, decoder_outputs = outputs + else: + (lm_logits, decoder_outputs), past = outputs + + if return_dict: + outputs = FlaxCausalLMOutputWithCrossAttentions( + logits=lm_logits, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + ) + else: + outputs = (lm_logits,) + decoder_outputs[1:] + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + max_length, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, + encoder_outputs=None, + **kwargs, + ): + # initializing the cache + batch_size, seq_length = decoder_input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyways. 
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if decoder_attention_mask is not None: + extended_attention_mask = jax.lax.dynamic_update_slice( + extended_attention_mask, decoder_attention_mask, (0, 0) + ) + + return { + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "encoder_attention_mask": attention_mask, + "decoder_attention_mask": extended_attention_mask, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + return model_kwargs + + diff --git a/t5_jax.py b/t5_jax.py index def53cb..a35b4d5 100644 --- a/t5_jax.py +++ b/t5_jax.py @@ -25,6 +25,7 @@ import numpy as np from functools import partial from typing import Callable, Optional import math +import flax.linen as nn # jax.config.update("jax_default_matmul_precision", "tensorfloat32") jax.config.update("jax_default_matmul_precision", "bfloat16") @@ -83,7 +84,7 @@ os.environ.update({ "NCCL_LL128_BUFFSIZE": "-2", "NCCL_LL_BUFFSIZE": "-2", "NCCL_PROTO": "SIMPLE,LL,LL128", - "XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.99", + "XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.8", # "XLA_PYTHON_CLIENT_PREALLOCATE" : "false" }) @@ -103,14 +104,14 @@ except (LookupError, OSError): # %% # config options file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval' -save_path = 't5_5e_1_pmap' +save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/original/' # file_path = 'combined_data' split_datasets = load_from_disk(file_path) training_size = len(split_datasets['train']) # Store some constant seed = 117 -num_epochs = 5 -batch_size = 32 # 384 is the best +num_epochs = 40 +batch_size = 64 # 384 is the best num_train_epochs = num_epochs per_device_train_batch_size = batch_size train_batch_size = per_device_train_batch_size * jax.device_count() @@ -120,7 +121,7 @@ steps_per_epoch = training_size // train_batch_size total_train_steps = steps_per_epoch * num_epochs warmup_steps = 0 -learning_rate = 2e-5 +learning_rate = 2e-4 weight_decay = 0.01 adam_beta1 = 0.9 @@ -129,7 +130,7 @@ adam_epsilon = 1e-8 label_smoothing_factor = 0.0 num_beams = 1 -val_max_target_length = 86 +val_max_target_length = 128 predict_with_generate = True @@ -143,7 +144,7 @@ additional_special_tokens = ["", "", "", # Add the additional special tokens to the tokenizer tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens}) -max_length = 86 +max_length = 128 # %% len(tokenizer) @@ -176,51 +177,64 @@ from transformers import FlaxT5ForConditionalGeneration from transformers import T5Config config = T5Config() - - -# %% -# If you want don't want to cast certain parameters (for example layer norm bias and scale) -# then pass the mask as follows -from flax import traverse_util - -model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base") -# useful for transformer model -# model.enable_gradient_checkpointing() +model = FlaxT5ForConditionalGeneration.from_pretrained( + "t5-base", + dtype=jnp.bfloat16, + gradient_checkpointing=True +) +params = model.params # enable bf16 except for layer_norm -# flat_params = traverse_util.flatten_dict(model.params) -# mask = { -# path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params -# } -# mask = traverse_util.unflatten_dict(mask) -# # borrowed from transformers modeling_flax_utils -# def 
cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
-#     """
-#     Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
-#     """
-#
-#     # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
-#     def conditional_cast(param):
-#         if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
-#             param = param.astype(dtype)
-#         return param
-#
-#     if mask is None:
-#         return jax.tree_util.tree_map(conditional_cast, params)
-#
-#     flat_params = traverse_util.flatten_dict(params)
-#     flat_mask, _ = jax.tree_util.tree_flatten(mask)
-#
-#     for masked, key in zip(flat_mask, sorted(flat_params.keys())):
-#         if masked:
-#             flat_params[key] = conditional_cast(flat_params[key])
-#
-#     return traverse_util.unflatten_dict(flat_params)
-#
-# # Cast parameters to bfloat16 if desired
-# # params = jax.tree.tree_map(lambda x: x.astype(jnp.bfloat16), params)
-# # instead of casting the whole thing, we cast only certain parts of the tree
-# params = cast_floating_to(model.params, jnp.bfloat16, mask)
+# enable bf16
+# enable only for the feed-forward projections (wi/wo), the attention projections (q/k/v),
+# and the shared embedding; layer norms and the output projection stay in fp32
+def create_mask_for_layer_norm(params):
+    flat_params = traverse_util.flatten_dict(params)
+    mask = {
+        # path: not (
+        #     (path[-2] == "layer_norm" and path[-1] == "weight") or
+        #     (path[-2] == "final_layer_norm" and path[-1] == "weight") or
+        #     (path[-2] == "o" and path[-1] == "kernel")
+        # )
+        # for path in flat_params
+        path: (
+            (path[-2] == "wi" and path[-1] == "kernel") or
+            (path[-2] == "wo" and path[-1] == "kernel") or
+            (path[-2] == "k" and path[-1] == "kernel") or
+            (path[-2] == "q" and path[-1] == "kernel") or
+            (path[-2] == "v" and path[-1] == "kernel") or
+            (path[-2] == "shared" and path[-1] == "embedding")
+        ) for path in flat_params
+    }
+    mask = traverse_util.unflatten_dict(mask)
+    return mask
+
+# borrowed from transformers modeling_flax_utils
+def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
+    """
+    Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
+ """ + + # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 + def conditional_cast(param): + if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): + param = param.astype(dtype) + return param + + if mask is None: + return jax.tree_util.tree_map(conditional_cast, params) + + flat_params = traverse_util.flatten_dict(params) + flat_mask, _ = jax.tree_util.tree_flatten(mask) + + for masked, key in zip(flat_mask, sorted(flat_params.keys())): + if masked: + flat_params[key] = conditional_cast(flat_params[key]) + + return traverse_util.unflatten_dict(flat_params) + +mask = create_mask_for_layer_norm(params) +# override params with bfloat version +params= cast_floating_to(params, jnp.bfloat16, mask) # %% @@ -307,31 +321,6 @@ token_datasets.set_format( 'labels', 'decoder_input_ids', 'decoder_attention_mask'] ) -# %% -# check values -for name in ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'decoder_attention_mask']: - int_array = train_dataset[name] - if np.all((int_array >= 0) & (int_array <= 65535)): - uint16_array = int_array.astype(np.uint16) - else: - raise ValueError("Values are out of range for uint16") - -# %% - -from datasets import ClassLabel, Value, Sequence -features = train_dataset.features.copy() -features['input_ids'] = Sequence(Value('uint16')) -features['attention_mask'] = Sequence(Value('bool')) -features['labels'] = Sequence(Value('uint16')) -features['decoder_input_ids'] = Sequence(Value('uint16')) -features['decoder_attention_mask'] = Sequence(Value('bool')) -train_dataset = train_dataset.cast(features) - - - -# %% -# temp -print('data type check: ', train_dataset['decoder_attention_mask'].dtype) # %% def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True): @@ -355,17 +344,11 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf for idx in batch_idx: batch = dataset[idx] - batch = {k: jnp.array(v) for k, v in batch.items()} + batch = {k: v for k, v in batch.items()} yield batch -# %% [markdown] -# # Model -# -# -# - # %% # Initialize our training @@ -406,7 +389,7 @@ linear_decay_lr_schedule_fn = create_learning_rate_fn( def decay_mask_fn(params): flat_params = traverse_util.flatten_dict(params) # find out all LayerNorm parameters - layer_norm_candidates = ["layernorm", "layer_norm", "ln"] + layer_norm_candidates = ["final_layer_norm", "layer_norm"] layer_norm_named_params = { layer[-2:] for layer_norm_name in layer_norm_candidates @@ -437,13 +420,10 @@ class TrainState(train_state.TrainState): def replicate(self): return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) -# set bf16 for model params -# model.params = model.to_bf16(model.params) -# Cast parameters to bfloat16 if desired -# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params) # Setup train state -state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng) +# input all the state here +state = TrainState.create(apply_fn=model.__call__, params=params, tx=adamw, dropout_rng=dropout_rng) # label smoothed cross entropy def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): @@ -485,17 +465,17 @@ def train_step(state, batch, label_smoothing_factor=0.0): num_labels = jax.lax.psum(num_labels, "batch") # true loss = total loss / total samples - # loss = jax.lax.psum(loss, "batch") - # loss = 
jax.tree_util.tree_map(lambda x: x / num_labels, loss) + loss = jax.lax.psum(loss, "batch") + loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) # true grad = total grad / total samples grad = jax.lax.psum(grad, "batch") grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad) new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng) - # metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} - # return new_state, metrics - return new_state + metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + return new_state, metrics + # return new_state # Define generation function max_length = ( @@ -549,25 +529,24 @@ for epoch in epochs: for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): batch = next(train_loader) batch = shard(batch) - state = p_train_step(state, batch) + state, train_metric = p_train_step(state, batch) # train_metrics.append(train_metric) train_time = time.time() - train_start - # train_metric = unreplicate(train_metric) - # train_metric['loss'].block_until_ready() + train_metric = unreplicate(train_metric) + train_metric['loss'].block_until_ready() epochs.write( # f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, " f"Epoch... ({epoch + 1}/{num_epochs} | " - # f"Learning Rate:{train_metric['learning_rate']}, " + f"Learning Rate:{train_metric['learning_rate']}, " f"Last train time: {train_time})" ) # jax.profiler.stop_trace() # %% - output_dir = save_path # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: diff --git a/t5_jax_prediction.py b/t5_jax_prediction.py index 89b1193..f8dc1bd 100644 --- a/t5_jax_prediction.py +++ b/t5_jax_prediction.py @@ -66,7 +66,8 @@ import orbax.checkpoint as ocp # data_path = f"../make_data/select_db/data_mapping_filtered.csv" # data_path = f"../make_data_2/select_db/dataset/1/train_all.csv" -data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv' +model_name_or_path = "./model_checkpoints/simple" # Replace with your specific model name +data_path = '/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv' # data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv' # Ensure to include 'ships_idx' in the fields list @@ -97,7 +98,6 @@ test_dataset = Dataset.from_list(process_df(df)) # from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration from transformers import FlaxT5ForConditionalGeneration # model_name_or_path = "./t5_80_1" # Replace with your specific model name -model_name_or_path = "./model_checkpoints/simple_test" # Replace with your specific model name model = FlaxT5ForConditionalGeneration.from_pretrained(model_name_or_path) params = model.params @@ -275,11 +275,12 @@ df['p_property_correct'] = df['p_property'] == df['property'] print("thing accuracy", sum(df['p_thing_correct'])/len(df)) print("property accuracy", sum(df['p_property_correct'])/len(df)) print("total accuracy", sum(df['p_property_correct'] & df['p_thing_correct'])/len(df)) -# %% -df[~df["p_property_correct"]] # %% -df['p_thing'] +# df[~df["p_property_correct"]] + +# %% +# df['p_thing'] # %% # Save the DataFrame as a Parquet file (using pyarrow or fastparquet) # df.to_parquet("exports/output_file.parquet", engine="pyarrow") # or use engine="fastparquet" diff --git a/t5_jax_sfp_grad_accumulate.py b/t5_jax_sfp_grad_accumulate.py new file mode 100644 index 
0000000..52abe0a --- /dev/null +++ b/t5_jax_sfp_grad_accumulate.py @@ -0,0 +1,610 @@ +# %% +import os +# Set this to True to run the model on CPU only. +USE_CPU_ONLY = False + +flags = os.environ.get("XLA_FLAGS", "") +if USE_CPU_ONLY: + flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices + # Enforce CPU-only execution + os.environ["CUDA_VISIBLE_DEVICES"] = "" + os.environ["JAX_PLATFORMS"] = "cpu" +else: + # GPU flags + flags = ( + '--xla_gpu_enable_triton_softmax_fusion=true ' + '--xla_gpu_triton_gemm_any=True ' + # '--xla_gpu_enable_async_collectives=true ' + '--xla_gpu_enable_latency_hiding_scheduler=true ' + '--xla_gpu_enable_highest_priority_async_stream=true ' + ) + os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" + +os.environ["XLA_FLAGS"] = flags +os.environ.update({ + "TOKENIZERS_PARALLELISM" : "false", + "CUDA_DEVICE_MAX_CONNECTIONS" : "1", + "NCCL_LL128_BUFFSIZE": "-2", + "NCCL_LL_BUFFSIZE": "-2", + "NCCL_PROTO": "SIMPLE,LL,LL128", + "XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.90", + # "XLA_PYTHON_CLIENT_PREALLOCATE" : "false" + }) + + + + +import functools +from functools import partial +from pprint import pprint +from typing import Any, Dict, Tuple, Callable, Sequence, Dict, Union + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, NamedSharding +# from jax.experimental.pjit import pjit # superseded by jax.jit +from jax.experimental import mesh_utils +from jax.sharding import PartitionSpec +from ml_collections import ConfigDict +import optax +import logging +import time +from datasets import Dataset, load_from_disk + +from flax import jax_utils, traverse_util +from flax.training import train_state +from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key +from flax.core.frozen_dict import freeze, unfreeze, FrozenDict +import flax.core + +# model checkpointing and saving utilities +from flax import linen as nn +from flax.training import checkpoints, train_state +from flax import struct, serialization + +from parallel.partitions import set_partitions + +from tqdm import tqdm + +from parallel.dataload import DataPrepare + +# for memory tracking +# from jax_smi import initialise_tracking +# initialise_tracking() + + +PyTree = Any +Metrics = Dict[str, Tuple[jax.Array, ...]] + +if USE_CPU_ONLY: + jax.config.update('jax_platform_name', 'cpu') +else: + jax.config.update("jax_default_matmul_precision", "bfloat16") + +jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") +jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) +jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) + + + + +# %% +## get platform type +from jax.extend.backend import get_backend +print(get_backend().platform) +print(jax.devices()) + +# %% +# config options +file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/' +save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/simple_test/' +# file_path = 'combined_data' +split_datasets = load_from_disk(file_path) +training_size = len(split_datasets['train']) +# Store some constant +seed = 117 +num_epochs = 5 +batch_size = 64 +num_train_epochs = num_epochs +per_device_train_batch_size = batch_size +train_batch_size = per_device_train_batch_size * jax.device_count() +per_device_eval_batch_size = batch_size +eval_batch_size = per_device_eval_batch_size * jax.device_count() +steps_per_epoch = training_size // train_batch_size 
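+
+# Illustrative sanity-check sketch: `assumed_minibatch_size` is a stand-in name for the 32
+# that `train_step` further below hard-codes when calling `accumulate_gradients_loop`.
+# Each global batch of `batch_size` examples is split into
+# `batch_size // assumed_minibatch_size` gradient-accumulation steps per optimizer update.
+assumed_minibatch_size = 32
+assert batch_size % assumed_minibatch_size == 0, "batch must split evenly into accumulation minibatches"
+print(f"gradient accumulation steps per update: {batch_size // assumed_minibatch_size}")
+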
+total_train_steps = steps_per_epoch * num_epochs + +warmup_steps = 0 +learning_rate = 5e-5 + +weight_decay = 0.01 +adam_beta1 = 0.9 +adam_beta2 = 0.999 +adam_epsilon = 1e-8 +label_smoothing_factor = 0.0 +num_beams = 1 +val_max_target_length = 128 +predict_with_generate = True + + +# %% +# prepare data +# init object +# e.g. Config +print("preparing data") +data_config = ConfigDict( + dict( + max_length=128, + pad_token_id=0, + decoder_start_token_id=0 + ) +) + +dataprep = DataPrepare(split_datasets['train'], data_config) +# # example usage +# # %% +seed = 117 +rng = jax.random.PRNGKey(seed) +train_loader = dataprep.data_loader(rng, batch_size=batch_size) +batch = next(iter(train_loader)) + +# %% +# from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom +from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model +main_model = custom_model.from_pretrained( + "t5-base", + dtype=jnp.bfloat16, + gradient_checkpointing=True, +) +params = main_model.params +# pretrained_params = model.params +model = main_model.module + +# %% +# # testing config hashability +# # some explanation: +# # The PreTrainedModel class loads a T5Config model that is not hashable because +# # it is a complicated class that pretends to be a dataclass. +# # The solution is to extract a dict from it, then make a ConfigDict from +# # ml_collections library so that we can get values via the "." operator. +# # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to +# # modify the config before passing to the next layer +# from transformers import T5Config +# from t5_model.configuration_t5 import FrozenT5Config +# from ml_collections import ConfigDict, FrozenConfigDict +# +# config = T5Config.from_pretrained("t5-base").to_dict() +# config.pop('architectures') +# config.pop('id2label') +# # test if it works +# frozen_config = FrozenConfigDict(config) +# # test hash +# hash(frozen_config) + +# %% +# # print model +# rng, input_rng = jax.random.split(rng) +# model.tabulate( +# input_rng, +# input_ids=batch['input_ids'], +# attention_mask=batch['attention_mask'], +# decoder_input_ids=batch['decoder_input_ids'], +# decoder_attention_mask=batch['decoder_attention_mask'], +# console_kwargs={"force_jupyter": True} +# ) + + + +# %% +# create mesh +print("creating mesh") +device_mesh = mesh_utils.create_device_mesh((2,2)) +print(device_mesh) + +mesh = Mesh(devices=device_mesh, axis_names=('data', 'model')) +print(mesh) + +def mesh_sharding(pspec: PartitionSpec) -> NamedSharding: + return NamedSharding(mesh, pspec, memory_kind="device") + +x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis +model_sharding=mesh_sharding(PartitionSpec(None, 'model')) + + +# %% +# optimizers + +def create_learning_rate_fn( + train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.ndarray]: + """Returns a linear warmup, linear_decay learning rate function.""" + steps_per_epoch = train_ds_size // train_batch_size + num_train_steps = steps_per_epoch * num_train_epochs + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +# Create learning rate schedule 
+linear_decay_lr_schedule_fn = create_learning_rate_fn( + training_size, + train_batch_size, + num_train_epochs, + warmup_steps, + learning_rate, +) + +# We use Optax's "masking" functionality to not apply weight decay +# to bias and LayerNorm scale parameters. decay_mask_fn returns a +# mask boolean with the same structure as the parameters. +# The mask is True for parameters that should be decayed. +def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + # find out all LayerNorm parameters + layer_norm_candidates = ["layernorm", "layer_norm", "ln"] + layer_norm_named_params = { + layer[-2:] + for layer_norm_name in layer_norm_candidates + for layer in flat_params.keys() + if layer_norm_name in "".join(layer).lower() + } + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) + +# create adam optimizer +adamw = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=adam_beta1, + b2=adam_beta2, + eps=adam_epsilon, + weight_decay=weight_decay, + mask=decay_mask_fn, +) + +# %% + +print("compile") + +# enable bf16 +# enable only for dense, some transformer sections, and shared +def create_mask_for_layer_norm(params): + flat_params = traverse_util.flatten_dict(params) + mask = { + # path: not ( + # (path[-2] == "layer_norm" and path[-1] == "weight") or + # (path[-2] == "final_layer_norm" and path[-1] == "weight") or + # (path[-2] == "o" and path[-1] == "kernel") + # ) + # for path in flat_params + path: ( + (path[-2] == "wi" and path[-1] == "weight") or + (path[-2] == "wo" and path[-1] == "weight") or + (path[-2] == "k" and path[-1] == "kernel") or + (path[-2] == "q" and path[-1] == "kernel") or + (path[-2] == "v" and path[-1] == "kernel") or + (path[-2] == "shared" and path[-1] == "embedding") + ) for path in flat_params + } + mask = traverse_util.unflatten_dict(mask) + return mask + +# borrowed from transformers modeling_flax_utils +def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: + """ + Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. + """ + + # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 + def conditional_cast(param): + if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): + param = param.astype(dtype) + return param + + if mask is None: + return jax.tree_util.tree_map(conditional_cast, params) + + flat_params = traverse_util.flatten_dict(params) + flat_mask, _ = jax.tree_util.tree_flatten(mask) + + for masked, key in zip(flat_mask, sorted(flat_params.keys())): + if masked: + flat_params[key] = conditional_cast(flat_params[key]) + + return traverse_util.unflatten_dict(flat_params) + +# create init_fn to produce sharded state +def init_fn(params, model, optimizer): + # do be careful with the model init + # imported models might have complicated init methods + + # mask = create_mask_for_layer_norm(params) + # override params with bfloat version + # params= cast_floating_to(params, jnp.bfloat16, mask) + + state = train_state.TrainState.create( # Create a `TrainState`. 
+ apply_fn=model.apply, + params=params, + tx=optimizer) + return state + + +abstract_variables = jax.eval_shape( + functools.partial(init_fn, model=model, optimizer=adamw), params) + +# jax.sharding: describes how a jax.Array is laid out across devices +state_sharding = nn.get_sharding(abstract_variables, mesh) +# print(state_sharding) + +# %% + +# replace the params tree with the new modified tree +# create partitions for model +from parallel.partitions import set_partitions +# set_partitions freezes the params on return +model_part_spec = set_partitions(unfreeze(params)) +# p is already a partition spec +model_named_sharding = jax.tree.map(lambda p: mesh_sharding(p), model_part_spec) + +# get pspec for opt_state +def get_opt_spec(x): + if isinstance(x, dict): + return unfreeze(model_named_sharding) + # return an empty partspec + return mesh_sharding((PartitionSpec())) + +# this function replaces the empty model params spec with the 'model_named_shard' +state_sharding = jax.tree.map( + get_opt_spec, state_sharding, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)) +) + + + +jit_init_fn = jax.jit( + init_fn, + static_argnames=('model', 'optimizer'), # skip model and optimizer + in_shardings=mesh_sharding(PartitionSpec()), # we don't shard params explicitly + out_shardings=state_sharding # but returned initialized_state is sharded +) +initialized_state = jit_init_fn(params, model, adamw) + +# %% +# train step +def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): + """ + The label smoothing implementation is adapted from Flax's official example: + https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104 + """ + vocab_size = logits.shape[-1] + confidence = 1.0 - label_smoothing_factor + low_confidence = (1.0 - confidence) / (vocab_size - 1) + normalizing_constant = -( + confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) + ) + soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence) + logits = jnp.asarray(logits, dtype=jnp.float32) + logits = logits.astype(jnp.float32) + soft_labels = soft_labels.astype(jnp.float32) + loss = optax.softmax_cross_entropy(logits, soft_labels) + loss = loss - normalizing_constant + + # ignore padded tokens from loss + loss = loss * padding_mask + loss = loss.mean() + # num_labels = padding_mask.mean() + return loss # , num_labels + +# %% +# gradient accumulation +def accumulate_gradients_loop( + state, + batch, + minibatch_size: int, + loss_fn: Callable, +) -> Tuple[PyTree, Metrics]: + """Calculate gradients and metrics for a batch using gradient accumulation. + + Args: + state: Current training state. + batch: Full training batch. + rng: Random number generator to use. + num_minibatches: Number of minibatches to split the batch into. Equal to the number of gradient accumulation steps. + loss_fn: Loss function to calculate gradients and metrics. + + Returns: + Tuple with accumulated gradients and metrics over the minibatches. + """ + batch_size = batch['input_ids'].shape[0] + # minibatch_size = batch_size // num_minibatches + num_minibatches = batch_size // minibatch_size + # Define gradient function for single minibatch. + # If has_aux is True then a tuple of ((value, auxiliary_data), gradient) is returned. 
+ # otherwise it returns (value, gradient), where value is the actual output + # of the function, hence the "value" of the namesake + grad_fn = jax.value_and_grad(loss_fn, has_aux=False) + # Prepare loop variables. + grads = None + metrics = None + for minibatch_idx in range(num_minibatches): + with jax.named_scope(f"minibatch_{minibatch_idx}"): + # Split the batch into minibatches. + start = minibatch_idx * minibatch_size + end = start + minibatch_size + minibatch = jax.tree.map(lambda x: x[start:end], batch) # noqa: B023 + # Calculate gradients and metrics for the minibatch. + # missing value is mean loss of batch + loss, step_grads = grad_fn( + state.params, minibatch + ) + with jax.named_scope("sync_metrics"): + step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + + # Accumulate gradients and metrics across minibatches. + if grads is None: + grads = step_grads + metrics = step_metrics + else: + # accumulation adder + grads = jax.tree.map(jnp.add, grads, step_grads) + metrics = jax.tree.map(jnp.add, metrics, step_metrics) + # Average gradients over minibatches. + grads = jax.tree.map(lambda g: g / num_minibatches, grads) + return grads, metrics + + + +# single device code annotated with jax.jit +@functools.partial( + jax.jit, + # state is state_sharding initialized from init_fn + # x_sharding is data sharded explicitly later + in_shardings=(state_sharding, x_sharding), + out_shardings=(state_sharding, mesh_sharding(PartitionSpec())), + donate_argnames=('state'), +) +def train_step(state, batch): + + label_smoothing_factor=0.0 + # dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) + def compute_loss(params, batch): + # check constraints + # frozen dict not allowed as sharding object + params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding)) + batch = jax.lax.with_sharding_constraint(batch, x_sharding) + logits = state.apply_fn( + {'params': params}, + input_ids=batch['input_ids'], + attention_mask=batch['attention_mask'], + decoder_input_ids=batch['decoder_input_ids'], + decoder_attention_mask=batch['decoder_attention_mask'], + )[0] # zero because output is some structure, where first is the logit + # use labels here + # loss, num_labels = loss_fn( + loss = loss_fn( + logits, + batch["labels"], + batch["decoder_attention_mask"], + label_smoothing_factor) + return loss # , num_labels + + # # compute gradients through computational graph + # # allow values to pass through + # grad_fn = jax.value_and_grad(compute_loss, has_aux=False) + # (loss), grad = grad_fn(state.params, batch) + # # num_labels = jax.lax.psum(num_labels, "batch") + + + # new_state = state.apply_gradients(grads=grad) + # with jax.named_scope("sync_metrics"): + # step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + + # use gradient accumulation + grads, step_metrics = accumulate_gradients_loop( + state=state, + batch=batch, + minibatch_size=32, + loss_fn=compute_loss + ) + new_state = state.apply_gradients(grads=grads) + + return new_state, step_metrics + +# %% +# explore data sharding +sharded_batch = next(iter(train_loader)) +sharded_batch = jax.device_put(sharded_batch, x_sharding) +# jax.debug.visualize_array_sharding(sharded_batch['input_ids']) +# jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding']) + + +# %% +# # prep 1 step +# print("1 step for jit-ting") +# with mesh: +# state, metrics = train_step(initialized_state, sharded_batch) + +# %% + +# %% +# tr +print("***** 
Running training *****") +print(f" Num examples = {training_size}") +print(f" Num Epochs = {num_epochs}") +print(f" Instantaneous batch size per device = {per_device_train_batch_size}") +print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}") +print(f" Total optimization steps = {total_train_steps}") + + +# %% +# jax.profiler.start_trace("./traces") + + +print("*" * 10) +print("training start") +rng, input_rng = jax.random.split(rng) +train_time = 0 +state = initialized_state +epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) +for epoch in epochs: + train_start = time.time() + + # Create sampling rng + train_metrics = [] + steps_per_epoch = training_size // train_batch_size + train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True) + # Generate an epoch by shuffling sampling indices from the train dataset + for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): + batch = next(train_loader) + batch = jax.device_put(batch, x_sharding) + with mesh: + state, train_metric = train_step(state, batch) + + # train_metrics.append(train_metric) + + + # this is for more accurate time stats, but slows down training + # train_metric['loss'].block_until_ready() + train_time = time.time() - train_start + + + + epochs.write( + f"Epoch... ({epoch + 1}/{num_epochs} | " + f"Loss: {train_metric['loss']}, " + f"Learning Rate:{train_metric['learning_rate']}, " + f"Last train time: {train_time})" + ) +# jax.profiler.stop_trace() + +# %% +# try out +gather_state = jax.device_get(state) +gather_batch = jax.device_get(batch) +logits = gather_state.apply_fn( + {'params': gather_state.params}, + input_ids=gather_batch['input_ids'], + attention_mask=gather_batch['attention_mask'], + decoder_input_ids=gather_batch['decoder_input_ids'], + decoder_attention_mask=gather_batch['decoder_attention_mask'], +)[0] # zero because output is some structure, where first is the logit + +probs = nn.softmax(logits, axis=-1) +predicted = jnp.argmax(probs, axis=-1) +print("sample output") +print(predicted[1]) + +# %% +main_model = custom_model.from_pretrained('t5-base') +output_dir = save_path + +# save checkpoint after each epoch and push checkpoint to the hub +if jax.process_index() == 0: + params = jax.device_get(state.params) + main_model.save_pretrained(output_dir, params=params) + + +# %% diff --git a/t5_jax_shmap.py b/t5_jax_shmap.py new file mode 100644 index 0000000..9d47b77 --- /dev/null +++ b/t5_jax_shmap.py @@ -0,0 +1,550 @@ +# %% +import os +# Set this to True to run the model on CPU only. 
+USE_CPU_ONLY = False + +flags = os.environ.get("XLA_FLAGS", "") +if USE_CPU_ONLY: + flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices + # Enforce CPU-only execution + os.environ["CUDA_VISIBLE_DEVICES"] = "" + os.environ["JAX_PLATFORMS"] = "cpu" +else: + # GPU flags + flags = ( + '--xla_gpu_enable_triton_softmax_fusion=true ' + '--xla_gpu_triton_gemm_any=True ' + # '--xla_gpu_enable_async_collectives=true ' + '--xla_gpu_enable_latency_hiding_scheduler=true ' + '--xla_gpu_enable_highest_priority_async_stream=true ' + ) + os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" + +os.environ["XLA_FLAGS"] = flags +os.environ.update({ + "TOKENIZERS_PARALLELISM" : "false", + "CUDA_DEVICE_MAX_CONNECTIONS" : "1", + "NCCL_LL128_BUFFSIZE": "-2", + "NCCL_LL_BUFFSIZE": "-2", + "NCCL_PROTO": "SIMPLE,LL,LL128", + "XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.5", + # "XLA_PYTHON_CLIENT_PREALLOCATE" : "false" + }) + + + + +import functools +from functools import partial +from pprint import pprint +from typing import Any, Dict, Tuple, Callable, Sequence, Dict, Union + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, NamedSharding +# from jax.experimental.pjit import pjit # superseded by jax.jit +from jax.experimental import mesh_utils +from jax.sharding import PartitionSpec +from ml_collections import ConfigDict +import optax +import logging +import time +from datasets import Dataset, load_from_disk + +from flax import jax_utils, traverse_util +from flax.training import train_state +from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key +from flax.core.frozen_dict import freeze, unfreeze, FrozenDict +import flax.core + +# model checkpointing and saving utilities +from flax import linen as nn +from flax.training import checkpoints, train_state +from flax import struct, serialization + +from parallel.partitions import set_partitions + +from tqdm import tqdm + +from parallel.dataload import DataPrepare + +# for memory tracking +# from jax_smi import initialise_tracking +# initialise_tracking() + + +PyTree = Any +Metrics = Dict[str, Tuple[jax.Array, ...]] + +if USE_CPU_ONLY: + jax.config.update('jax_platform_name', 'cpu') +else: + jax.config.update("jax_default_matmul_precision", "bfloat16") + +jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") +jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) +jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) + + + + +# %% +## get platform type +from jax.extend.backend import get_backend +print(get_backend().platform) +print(jax.devices()) + +# %% +# config options +file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/' +save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/shmap/' +# file_path = 'combined_data' +split_datasets = load_from_disk(file_path) +training_size = len(split_datasets['train']) +# Store some constant +seed = 117 +num_epochs = 40 +batch_size = 32 # do not go beyond 128, 64 is good +num_train_epochs = num_epochs +per_device_train_batch_size = batch_size +train_batch_size = per_device_train_batch_size * jax.device_count() +per_device_eval_batch_size = batch_size +eval_batch_size = per_device_eval_batch_size * jax.device_count() +steps_per_epoch = training_size // train_batch_size +total_train_steps = steps_per_epoch * num_epochs + +warmup_steps = 0 +learning_rate = 2e-5 + +weight_decay = 0.01 +adam_beta1 = 
0.9 +adam_beta2 = 0.999 +adam_epsilon = 1e-8 +label_smoothing_factor = 0.0 +num_beams = 1 +val_max_target_length = 128 +predict_with_generate = True + + +# %% +# prepare data +# init object +# e.g. Config +print("preparing data") +data_config = ConfigDict( + dict( + max_length=128, + pad_token_id=0, + decoder_start_token_id=0 + ) +) + +dataprep = DataPrepare(split_datasets['train'], data_config) +# # example usage +# # %% +seed = 117 +rng = jax.random.PRNGKey(seed) +train_loader = dataprep.data_loader(rng, batch_size=batch_size) +batch = next(iter(train_loader)) + +# %% +# from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom +from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model +main_model = custom_model.from_pretrained( + "t5-base", + dtype=jnp.bfloat16, + gradient_checkpointing=True, +) +params = main_model.params +# pretrained_params = model.params +model = main_model.module + +# %% +# # testing config hashability +# # some explanation: +# # The PreTrainedModel class loads a T5Config model that is not hashable because +# # it is a complicated class that pretends to be a dataclass. +# # The solution is to extract a dict from it, then make a ConfigDict from +# # ml_collections library so that we can get values via the "." operator. +# # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to +# # modify the config before passing to the next layer +# from transformers import T5Config +# from t5_model.configuration_t5 import FrozenT5Config +# from ml_collections import ConfigDict, FrozenConfigDict +# +# config = T5Config.from_pretrained("t5-base").to_dict() +# config.pop('architectures') +# config.pop('id2label') +# # test if it works +# frozen_config = FrozenConfigDict(config) +# # test hash +# hash(frozen_config) + +# %% +# # print model +# rng, input_rng = jax.random.split(rng) +# model.tabulate( +# input_rng, +# input_ids=batch['input_ids'], +# attention_mask=batch['attention_mask'], +# decoder_input_ids=batch['decoder_input_ids'], +# decoder_attention_mask=batch['decoder_attention_mask'], +# console_kwargs={"force_jupyter": True} +# ) + + + +# %% +# create mesh +print("creating mesh") +device_mesh = mesh_utils.create_device_mesh((2,2)) +print(device_mesh) + +mesh = Mesh(devices=device_mesh, axis_names=('data', 'model')) +print(mesh) + +def mesh_sharding(pspec: PartitionSpec) -> NamedSharding: + if USE_CPU_ONLY: + return NamedSharding(mesh, pspec, memory_kind="unpinned_host") + else: + # if gpu + return NamedSharding(mesh, pspec, memory_kind="device") + +x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis +model_sharding=mesh_sharding(PartitionSpec(None, 'model')) + + +# %% +# optimizers + +def create_learning_rate_fn( + train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.ndarray]: + """Returns a linear warmup, linear_decay learning rate function.""" + steps_per_epoch = train_ds_size // train_batch_size + num_train_steps = steps_per_epoch * num_train_epochs + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +# Create learning rate schedule +linear_decay_lr_schedule_fn = 
create_learning_rate_fn( + training_size, + train_batch_size, + num_train_epochs, + warmup_steps, + learning_rate, +) + +# We use Optax's "masking" functionality to not apply weight decay +# to bias and LayerNorm scale parameters. decay_mask_fn returns a +# mask boolean with the same structure as the parameters. +# The mask is True for parameters that should be decayed. +def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + # find out all LayerNorm parameters + layer_norm_candidates = ["final_layer_norm", "layer_norm"] + layer_norm_named_params = { + layer[-2:] + for layer_norm_name in layer_norm_candidates + for layer in flat_params.keys() + if layer_norm_name in "".join(layer).lower() + } + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) + +# create adam optimizer +adamw = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=adam_beta1, + b2=adam_beta2, + eps=adam_epsilon, + weight_decay=weight_decay, + mask=decay_mask_fn, +) + +# %% +def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): + """ + The label smoothing implementation is adapted from Flax's official example: + https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104 + """ + vocab_size = logits.shape[-1] + confidence = 1.0 - label_smoothing_factor + low_confidence = (1.0 - confidence) / (vocab_size - 1) + normalizing_constant = -( + confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) + ) + soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence) + logits = jnp.asarray(logits, dtype=jnp.float32) + logits = logits.astype(jnp.float32) + soft_labels = soft_labels.astype(jnp.float32) + loss = optax.softmax_cross_entropy(logits, soft_labels) + loss = loss - normalizing_constant + + # ignore padded tokens from loss + loss = loss * padding_mask + mean_loss = loss.mean() + # num_labels = padding_mask.mean() + return mean_loss # , num_labels + + +# %% +################################################################ +# old jit in_shardings method + +# create init_fn to produce sharded state +def init_fn(params, model, optimizer): + # do be careful with the model init + # imported models might have complicated init methods + + # mask = create_mask_for_layer_norm(params) + # override params with bfloat version + # params= cast_floating_to(params, jnp.bfloat16, mask) + + state = train_state.TrainState.create( # Create a `TrainState`. 
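+ # (Note, added for clarity: flax's TrainState bundles step, apply_fn, params, the optax
+ # transform and its opt_state into one pytree, so the whole training state can be sharded
+ # and passed through jit as a single object.)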
+ apply_fn=model.apply, + params=params, + tx=optimizer) + return state + + +abstract_variables = jax.eval_shape( + functools.partial(init_fn, model=model, optimizer=adamw), params) + +# jax.sharding: describes how a jax.Array is laid out across devices +state_sharding = nn.get_sharding(abstract_variables, mesh) + +from parallel.partitions import set_partitions +# set_partitions freezes the params on return +model_part_spec = set_partitions(unfreeze(params)) +# p is already a partition spec +model_named_sharding = jax.tree.map(lambda p: mesh_sharding(p), model_part_spec) + +# get pspec for opt_state +def get_opt_spec(x): + if isinstance(x, dict): + return unfreeze(model_named_sharding) + # return an empty partspec + return mesh_sharding((PartitionSpec())) + +# this function replaces the empty model params spec with the 'model_named_shard' +state_sharding = jax.tree.map( + get_opt_spec, state_sharding, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)) +) + + + +# %% +jit_init_fn = jax.jit( + init_fn, + static_argnames=('model', 'optimizer'), # skip model and optimizer + in_shardings=mesh_sharding(PartitionSpec()), # we don't shard params explicitly + out_shardings=state_sharding # but returned initialized_state is sharded +) +initialized_state = jit_init_fn(params, model, adamw) + +# %% +# train step +def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): + """ + The label smoothing implementation is adapted from Flax's official example: + https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104 + """ + vocab_size = logits.shape[-1] + confidence = 1.0 - label_smoothing_factor + low_confidence = (1.0 - confidence) / (vocab_size - 1) + normalizing_constant = -( + confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) + ) + soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence) + logits = jnp.asarray(logits, dtype=jnp.float32) + logits = logits.astype(jnp.float32) + soft_labels = soft_labels.astype(jnp.float32) + loss = optax.softmax_cross_entropy(logits, soft_labels) + loss = loss - normalizing_constant + + # ignore padded tokens from loss + loss = loss * padding_mask + loss = loss.mean() + # num_labels = padding_mask.mean() + return loss # , num_labels + +# %% + +# single device code annotated with jax.jit +@functools.partial( + jax.jit, + # state is state_sharding initialized from init_fn + # x_sharding is data sharded explicitly later + in_shardings=(state_sharding, x_sharding), + out_shardings=(state_sharding, mesh_sharding(PartitionSpec())), + donate_argnames=('state'), +) +def train_step(state, batch): + + label_smoothing_factor=0.0 + # dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) + # computes loss per shard + def compute_loss(params, batch): + # check constraints + # frozen dict not allowed as sharding object + params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding)) + batch = jax.lax.with_sharding_constraint(batch, x_sharding) + logits = state.apply_fn( + {'params': params}, + input_ids=batch['input_ids'], + attention_mask=batch['attention_mask'], + decoder_input_ids=batch['decoder_input_ids'], + decoder_attention_mask=batch['decoder_attention_mask'], + )[0] # zero because output is some structure, where first is the logit + + # logits sharding + # data, None, model + # + print("logits") + jax.debug.inspect_array_sharding(logits, callback=print) + # use labels here + # loss, num_labels = 
loss_fn( + loss = loss_fn( + logits, + batch["labels"], + batch["decoder_attention_mask"], + label_smoothing_factor) + # loss sharding + # it gives PartitionSpec(), which implies a reduction already happened + print("loss") + jax.debug.inspect_array_sharding(loss, callback=print) + + return loss # , num_labels + + # compute gradients through computational graph + # allow values to pass through + grad_fn = jax.value_and_grad(compute_loss, has_aux=False) + batch = jax.tree.map(lambda x: jax.lax.with_sharding_constraint(x, x_sharding), batch) + (loss), grads = grad_fn(state.params, batch) + # num_labels = jax.lax.psum(num_labels, "batch") + + # so far we have been operating from within each shard + # we need to sync gradients across devices + # we bring all gradients together onto a single device + # jax.debug.inspect_array_sharding(grads, callback=print) + grads = jax.lax.with_sharding_constraint(grads, mesh_sharding(PartitionSpec())) + # grads = jax.lax.with_sharding_constraint(grads, state_sharding) + # jax.debug.visualize_array_sharding(grad) + # jax.debug.inspect_array_sharding(grad, callback=print) + # check the output grad tree from mean + # print(jax.tree.map(jnp.shape, grad)) + + + + new_state = state.apply_gradients(grads=grads) + with jax.named_scope("sync_metrics"): + step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + + return new_state, step_metrics + +# %% +# explore data sharding +sharded_batch = next(iter(train_loader)) +# sharded_batch = jax.device_put(sharded_batch, x_sharding) +sharded_batch = jax.tree.map(lambda x: jax.lax.with_sharding_constraint(x, x_sharding), batch) +jax.debug.visualize_array_sharding(sharded_batch['input_ids']) +# jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding']) + + + +# %% +# # prep 1 step +print("1 step for jit-ting") +with mesh: + state, metrics = train_step(initialized_state, sharded_batch) + +# %% + +# %% +# tr +print("***** Running training *****") +print(f" Num examples = {training_size}") +print(f" Num Epochs = {num_epochs}") +print(f" Instantaneous batch size per device = {per_device_train_batch_size}") +print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}") +print(f" Total optimization steps = {total_train_steps}") + + +# %% +# jax.profiler.start_trace("./traces") + + +print("*" * 20) +print("training start") +rng, input_rng = jax.random.split(rng) +train_time = 0 +state = initialized_state +epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) +for epoch in epochs: + train_start = time.time() + + # Create sampling rng + train_metrics = [] + steps_per_epoch = training_size // train_batch_size + train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True) + # Generate an epoch by shuffling sampling indices from the train dataset + for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): + batch = next(train_loader) + batch = jax.device_put(batch, x_sharding) + with mesh: + state, train_metric = train_step(state, batch) + + # train_metrics.append(train_metric) + + + # this is for more accurate time stats, but slows down training + # train_metric['loss'].block_until_ready() + train_time = time.time() - train_start + + + + epochs.write( + f"Epoch... 
({epoch + 1}/{num_epochs} | " + f"Loss: {train_metric['loss']}, " + f"Learning Rate:{train_metric['learning_rate']}, " + f"Last train time: {train_time})" + ) +# jax.profiler.stop_trace() + +# %% +# try out +# gather_state = jax.device_get(state) +# gather_batch = jax.device_get(batch) +# logits = gather_state.apply_fn( +# {'params': gather_state.params}, +# input_ids=gather_batch['input_ids'], +# attention_mask=gather_batch['attention_mask'], +# decoder_input_ids=gather_batch['decoder_input_ids'], +# decoder_attention_mask=gather_batch['decoder_attention_mask'], +# )[0] # zero because output is some structure, where first is the logit +# +# probs = nn.softmax(logits, axis=-1) +# predicted = jnp.argmax(probs, axis=-1) +# print(predicted[0]) + +# %% +main_model = custom_model.from_pretrained('t5-base') +output_dir = save_path + +# save checkpoint after each epoch and push checkpoint to the hub +if jax.process_index() == 0: + params = jax.device_get(state.params) + params = jax.tree.map(lambda x: x.astype(jnp.float32), params) + main_model.save_pretrained(output_dir, params=params) + + +# %% diff --git a/t5_jax_simple_parallel.py b/t5_jax_simple_parallel.py index 47b3d9b..a832222 100644 --- a/t5_jax_simple_parallel.py +++ b/t5_jax_simple_parallel.py @@ -27,7 +27,7 @@ os.environ.update({ "NCCL_LL128_BUFFSIZE": "-2", "NCCL_LL_BUFFSIZE": "-2", "NCCL_PROTO": "SIMPLE,LL,LL128", - "XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.90", + "XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.80", # "XLA_PYTHON_CLIENT_PREALLOCATE" : "false" }) @@ -89,35 +89,50 @@ jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) - - # %% ## get platform type from jax.extend.backend import get_backend print(get_backend().platform) print(jax.devices()) + +# %% +# create mesh +print("creating mesh") +device_mesh = mesh_utils.create_device_mesh((4,1)) +print(device_mesh) + +mesh = Mesh(devices=device_mesh, axis_names=('data', 'model')) +print(mesh) + +def mesh_sharding(pspec: PartitionSpec) -> NamedSharding: + return NamedSharding(mesh, pspec, memory_kind="device") + +x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis +model_sharding=mesh_sharding(PartitionSpec(None, 'model')) + + # %% # config options file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/' -save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/simple_test/' +save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/simple/' # file_path = 'combined_data' split_datasets = load_from_disk(file_path) training_size = len(split_datasets['train']) # Store some constant seed = 117 -num_epochs = 5 -batch_size = 64 +num_epochs = 40 +batch_size = 128 num_train_epochs = num_epochs per_device_train_batch_size = batch_size -train_batch_size = per_device_train_batch_size * jax.device_count() +train_batch_size = per_device_train_batch_size * 2 per_device_eval_batch_size = batch_size -eval_batch_size = per_device_eval_batch_size * jax.device_count() +eval_batch_size = per_device_eval_batch_size * 2 steps_per_epoch = training_size // train_batch_size total_train_steps = steps_per_epoch * num_epochs warmup_steps = 0 -learning_rate = 5e-5 +learning_rate = 2e-3 weight_decay = 0.01 adam_beta1 = 0.9 @@ -197,21 +212,6 @@ model = main_model.module -# %% -# create mesh -print("creating mesh") -device_mesh = mesh_utils.create_device_mesh((2,2)) -print(device_mesh) - -mesh = Mesh(devices=device_mesh, axis_names=('data', 
'model')) -print(mesh) - -def mesh_sharding(pspec: PartitionSpec) -> NamedSharding: - return NamedSharding(mesh, pspec, memory_kind="device") - -x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis -model_sharding=mesh_sharding(PartitionSpec(None, 'model')) - # %% # optimizers @@ -246,7 +246,7 @@ linear_decay_lr_schedule_fn = create_learning_rate_fn( def decay_mask_fn(params): flat_params = traverse_util.flatten_dict(params) # find out all LayerNorm parameters - layer_norm_candidates = ["layernorm", "layer_norm", "ln"] + layer_norm_candidates = ["final_layer_norm", "layer_norm"] layer_norm_named_params = { layer[-2:] for layer_norm_name in layer_norm_candidates @@ -322,9 +322,9 @@ def init_fn(params, model, optimizer): # do be careful with the model init # imported models might have complicated init methods - # mask = create_mask_for_layer_norm(params) + mask = create_mask_for_layer_norm(params) # override params with bfloat version - # params= cast_floating_to(params, jnp.bfloat16, mask) + params= cast_floating_to(params, jnp.bfloat16, mask) state = train_state.TrainState.create( # Create a `TrainState`. apply_fn=model.apply, @@ -449,8 +449,8 @@ def train_step(state, batch): # %% # explore data sharding -sharded_batch = next(iter(train_loader)) -sharded_batch = jax.device_put(sharded_batch, x_sharding) +# sharded_batch = next(iter(train_loader)) +# sharded_batch = jax.device_put(sharded_batch, x_sharding) # jax.debug.visualize_array_sharding(sharded_batch['input_ids']) # jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding']) @@ -477,7 +477,7 @@ print(f" Total optimization steps = {total_train_steps}") # jax.profiler.start_trace("./traces") -print("*" * 10) +print("*" * 50) print("training start") rng, input_rng = jax.random.split(rng) train_time = 0 @@ -489,7 +489,7 @@ for epoch in epochs: # Create sampling rng train_metrics = [] steps_per_epoch = training_size // train_batch_size - train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True) + train_loader = dataprep.data_loader(rng, batch_size=train_batch_size, shuffle=True, drop_last=True) # Generate an epoch by shuffling sampling indices from the train dataset for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): batch = next(train_loader) @@ -514,21 +514,21 @@ for epoch in epochs: ) # jax.profiler.stop_trace() -# %% -# try out -gather_state = jax.device_get(state) -gather_batch = jax.device_get(batch) -logits = gather_state.apply_fn( - {'params': gather_state.params}, - input_ids=gather_batch['input_ids'], - attention_mask=gather_batch['attention_mask'], - decoder_input_ids=gather_batch['decoder_input_ids'], - decoder_attention_mask=gather_batch['decoder_attention_mask'], -)[0] # zero because output is some structure, where first is the logit - -probs = nn.softmax(logits, axis=-1) -predicted = jnp.argmax(probs, axis=-1) -predicted[1] +# # %% +# # try out +# gather_state = jax.device_get(state) +# gather_batch = jax.device_get(batch) +# logits = gather_state.apply_fn( +# {'params': gather_state.params}, +# input_ids=gather_batch['input_ids'], +# attention_mask=gather_batch['attention_mask'], +# decoder_input_ids=gather_batch['decoder_input_ids'], +# decoder_attention_mask=gather_batch['decoder_attention_mask'], +# )[0] # zero because output is some structure, where first is the logit +# +# probs = nn.softmax(logits, axis=-1) +# predicted = jnp.argmax(probs, axis=-1) +# print(predicted[0]) # %% main_model = 
custom_model.from_pretrained('t5-base') @@ -537,6 +537,7 @@ output_dir = save_path # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: params = jax.device_get(state.params) + params = jax.tree.map(lambda x: x.astype(jnp.float32), params) main_model.save_pretrained(output_dir, params=params) diff --git a/t5_model/.gitignore b/t5_model/.gitignore new file mode 100644 index 0000000..bee8a64 --- /dev/null +++ b/t5_model/.gitignore @@ -0,0 +1 @@ +__pycache__ diff --git a/t5_model/configuration_t5.py b/t5_model/configuration_t5.py new file mode 100644 index 0000000..3975716 --- /dev/null +++ b/t5_model/configuration_t5.py @@ -0,0 +1,118 @@ +# coding=utf-8 +# Copyright 2020, The T5 Authors and HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""T5 model configuration""" + +from typing import Mapping + +from transformers import PretrainedConfig +from transformers import logging + +from etils import edc + +logger = logging.get_logger(__name__) + + +class T5Config(PretrainedConfig): + model_type = "t5" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + + def __init__( + self, + vocab_size=32128, # vocab size here + d_model=512, + d_kv=64, + d_ff=2048, + num_layers=6, + num_decoder_layers=None, + num_heads=8, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + dropout_rate=0.1, + layer_norm_epsilon=1e-6, + initializer_factor=1.0, + feed_forward_proj="relu", + is_encoder_decoder=True, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + classifier_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.d_model = d_model + self.d_kv = d_kv + self.d_ff = d_ff + self.num_layers = num_layers + self.num_decoder_layers = ( + num_decoder_layers if num_decoder_layers is not None else self.num_layers + ) # default = symmetry + self.num_heads = num_heads + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.dropout_rate = dropout_rate + self.classifier_dropout = classifier_dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_factor = initializer_factor + self.feed_forward_proj = feed_forward_proj + self.use_cache = use_cache + self.use_bfloat16 = True + + act_info = self.feed_forward_proj.split("-") + self.dense_act_fn = act_info[-1] + self.is_gated_act = act_info[0] == "gated" + + if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2: + raise ValueError( + f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. " + "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. 
" + "'gated-gelu' or 'relu'" + ) + + # for backwards compatibility + if feed_forward_proj == "gated-gelu": + self.dense_act_fn = "gelu_new" + + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + + +# class T5OnnxConfig(OnnxSeq2SeqConfigWithPast): +# @property +# def inputs(self) -> Mapping[str, Mapping[int, str]]: +# common_inputs = { +# "input_ids": {0: "batch", 1: "encoder_sequence"}, +# "attention_mask": {0: "batch", 1: "encoder_sequence"}, +# } +# if self.use_past: +# common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence" +# common_inputs["decoder_input_ids"] = {0: "batch"} +# common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} +# else: +# common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} +# common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} +# +# if self.use_past: +# self.fill_with_past_key_values_(common_inputs, direction="inputs") +# +# return common_inputs +# +# @property +# def default_onnx_opset(self) -> int: +# return 13 \ No newline at end of file diff --git a/t5_model/modeling_t5_flax.py b/t5_model/modeling_t5_flax.py new file mode 100644 index 0000000..4884050 --- /dev/null +++ b/t5_model/modeling_t5_flax.py @@ -0,0 +1,1832 @@ +# coding=utf-8 +# Copyright 2021 T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Flax T5 model.""" + +import copy +from typing import Callable, Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen import partitioning as nn_partitioning +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax.random import PRNGKey + +from transformers.modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, + FlaxSeq2SeqLMOutput, + FlaxSeq2SeqModelOutput, +) +from transformers.modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from transformers import T5Config +# from dataclasses import dataclass, replace +from ml_collections import ConfigDict, FrozenConfigDict + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google-t5/t5-small" +_CONFIG_FOR_DOC = "FrozenConfigDict" + +remat = nn_partitioning.remat + + +# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: + """ + Shift input ids one token to the right. 
+ """ + shifted_input_ids = jnp.zeros_like(input_ids) + shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1]) + shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id) + + shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) + return shifted_input_ids + + +class FlaxT5LayerNorm(nn.Module): + hidden_size: int + dtype: jnp.dtype = jnp.float32 + eps: float = 1e-6 + weight_init: Callable[..., np.ndarray] = jax.nn.initializers.ones + + def setup(self): + self.weight = self.param("weight", self.weight_init, (self.hidden_size,)) + + def __call__(self, hidden_states): + """ + Construct a layernorm module in the T5 style; No bias and no subtraction of mean. + """ + # layer norm should always be calculated in float32 + variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True) + hidden_states = hidden_states / jnp.sqrt(variance + self.eps) + + return self.weight * hidden_states + + +class FlaxT5DenseActDense(nn.Module): + config: FrozenConfigDict + dtype: jnp.dtype = jnp.float32 + + def setup(self): + wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) + wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) + + self.wi = nn.Dense( + self.config.d_ff, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wi_init_std), + dtype=self.dtype, + ) + self.wo = nn.Dense( + self.config.d_model, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wo_init_std), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + self.act = ACT2FN[self.config.dense_act_fn] + + def __call__(self, hidden_states, deterministic=True): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class FlaxT5DenseGatedActDense(nn.Module): + config: FrozenConfigDict + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) + wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) + + self.wi_0 = nn.Dense( + self.config.d_ff, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wi_init_std), + dtype=self.dtype, + ) + self.wi_1 = nn.Dense( + self.config.d_ff, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wi_init_std), + dtype=self.dtype, + ) + self.wo = nn.Dense( + self.config.d_model, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wo_init_std), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + self.act = ACT2FN[self.config.dense_act_fn] + + def __call__(self, hidden_states, deterministic): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class FlaxT5LayerFF(nn.Module): + config: FrozenConfigDict + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + if self.config.is_gated_act: + self.DenseReluDense = FlaxT5DenseGatedActDense(self.config, dtype=self.dtype) + else: + self.DenseReluDense = FlaxT5DenseActDense(self.config, dtype=self.dtype) + + self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) + 
self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__(self, hidden_states, deterministic=True): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states, deterministic=deterministic) + hidden_states = hidden_states + self.dropout(forwarded_states, deterministic=deterministic) + return hidden_states + + +class FlaxT5Attention(nn.Module): + config: FrozenConfigDict + has_relative_attention_bias: bool = False + causal: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.relative_attention_num_buckets = self.config.relative_attention_num_buckets + self.relative_attention_max_distance = self.config.relative_attention_max_distance + self.d_model = self.config.d_model + self.key_value_proj_dim = self.config.d_kv + self.n_heads = self.config.num_heads + self.dropout = self.config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5) + kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) + o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) + + self.q = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=jax.nn.initializers.normal(q_init_std), + dtype=self.dtype, + ) + self.k = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=jax.nn.initializers.normal(kv_init_std), + dtype=self.dtype, + ) + self.v = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=jax.nn.initializers.normal(kv_init_std), + dtype=self.dtype, + ) + self.o = nn.Dense( + self.d_model, + use_bias=False, + kernel_init=jax.nn.initializers.normal(o_init_std), + dtype=self.dtype, + ) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embed( + self.relative_attention_num_buckets, + self.n_heads, + embedding_init=jax.nn.initializers.normal(kv_init_std), + dtype=self.dtype, + ) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
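+ For example (illustrative, using the config defaults `num_buckets=32`, `max_distance=128` and
+ `bidirectional=True`), half the buckets cover positive offsets and half cover non-positive ones; within
+ each half, absolute offsets 0-7 each get their own bucket and larger offsets are binned logarithmically,
+ so relative_position=+3 falls in bucket 19, -3 in bucket 3, and -200 in bucket 15.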
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0) * num_buckets + relative_position = jnp.abs(relative_position) + else: + relative_position = -jnp.clip(relative_position, a_max=0) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact) + ) + relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1) + + relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large) + + return relative_buckets.astype("i4") + + def compute_bias(self, query_length, key_length): + """Compute binned relative position bias""" + context_position = jnp.arange(query_length, dtype="i4")[:, None] + memory_position = jnp.arange(key_length, dtype="i4")[None, :] + + relative_position = memory_position - context_position + relative_position_bucket = self._relative_position_bucket( + relative_position, + bidirectional=(not self.causal), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + + values = self.relative_attention_bias(relative_position_bucket) + values = values.transpose((2, 0, 1))[None, :, :, :] + return values + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slighly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = jax.lax.dynamic_update_slice(cached_key.value, key, indices) + value = jax.lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions + # that have already been generated and cached, not the remaining zero elements. 
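+ # (Illustrative: with max_length=8, cur_index=2 and one newly appended token, the pad_mask
+ # below keeps key positions 0-2 and masks out the still-empty cache slots 3-7.)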
+ pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def _create_position_bias( + self, key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift + ): + cache_is_filled = self.causal and self.has_variable("cache", "cached_key") and (not init_cache) + key_length = key_states.shape[1] + query_length = key_length if cache_is_filled else query_states.shape[1] + + if self.has_relative_attention_bias: + position_bias = self.compute_bias(query_length, key_length) + elif attention_mask is not None: + position_bias = jnp.zeros_like(attention_mask) + else: + position_bias = jnp.zeros((1, self.n_heads, query_length, key_length), dtype=self.dtype) + + # if key and values are already calculated, only the last query position bias should be taken + if cache_is_filled: + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + position_bias = jax.lax.dynamic_slice( + position_bias, + (0, 0, causal_attention_mask_shift, 0), + (1, self.n_heads, seq_length, max_decoder_length), + ) + return position_bias + + def __call__( + self, + hidden_states, + attention_mask=None, + key_value_states=None, + position_bias=None, + use_cache=False, + output_attentions=False, + deterministic=True, + init_cache=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + batch_size, seq_length = hidden_states.shape[:2] + + # q, k, v projections + query_states = self.q(hidden_states) # (batch_size, n_heads, seq_length, dim_per_head) + key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states) + value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states) + + # reshape to (batch_size, seq_length, n_heads, head_dim) + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # counter-act scaling in dot_product_attention_weights function + query_states *= jnp.sqrt(query_states.shape[-1]) + + # for fast decoding causal attention mask should be shifted + causal_attention_mask_shift = ( + self.variables["cache"]["cache_index"] if (self.has_variable("cache", "cached_key") and self.causal) else 0 + ) + # create causal attention_mask; attention_mask has to be defined when model is causal + if self.causal: + causal_attention_mask = make_causal_mask(attention_mask, dtype="bool") + + # fast decoding for generate requires special attention_mask + if self.has_variable("cache", "cached_key"): + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_attention_mask = jax.lax.dynamic_slice( + causal_attention_mask, + (0, 0, causal_attention_mask_shift, 0), + (1, 1, seq_length, max_decoder_length), + ) + + # broadcast causal attention mask & attention mask to fit for merge + causal_attention_mask = jnp.broadcast_to( + causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:] + ) + attention_mask = jnp.broadcast_to( + jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape + ) + attention_mask = combine_masks(attention_mask, causal_attention_mask) + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one 
position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # replace masked positions with -10_000 + if attention_mask is not None: + mask_value = jnp.finfo(self.dtype).min + attention_mask = jax.lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, mask_value).astype(self.dtype), + ) + + if position_bias is None: + # compute position bias (only for first layer) + position_bias = self._create_position_bias( + key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift + ) + + if attention_mask is not None: + position_bias = position_bias + attention_mask + + # create dropout rng + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + # Softmax(QK^T) + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=position_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + ) + + # multiply with value states + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + + # bring back to (batch_size, seq_length, d_model) + attn_output = self._merge_heads(attn_output) + + # apply output matrix + attn_output = self.o(attn_output) + + outputs = (attn_output, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + + return outputs + + +class FlaxT5LayerSelfAttention(nn.Module): + config: FrozenConfigDict + has_relative_attention_bias: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.SelfAttention = FlaxT5Attention( + self.config, + has_relative_attention_bias=self.has_relative_attention_bias, + causal=self.config.causal, + dtype=self.dtype, + ) + self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + output_attentions=False, + deterministic=True, + init_cache=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + init_cache=init_cache, + ) + hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class FlaxT5LayerCrossAttention(nn.Module): + config: FrozenConfigDict + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.EncDecAttention = FlaxT5Attention( + self.config, has_relative_attention_bias=False, causal=False, dtype=self.dtype + ) + self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + output_attentions=False, + deterministic=True, + ): + normed_hidden_states = self.layer_norm(hidden_states) + 
attention_output = self.EncDecAttention( + normed_hidden_states, + attention_mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class FlaxT5Block(nn.Module): + config: FrozenConfigDict + has_relative_attention_bias: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.causal = self.config.causal + self.layer = ( + FlaxT5LayerSelfAttention( + self.config, + has_relative_attention_bias=self.has_relative_attention_bias, + name=str(0), + dtype=self.dtype, + ), + ) + feed_forward_index = 1 + if self.causal: + self.layer += (FlaxT5LayerCrossAttention(self.config, name=str(1), dtype=self.dtype),) + feed_forward_index += 1 + + self.layer += (FlaxT5LayerFF(self.config, name=str(feed_forward_index), dtype=self.dtype),) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + output_attentions=False, + return_dict=True, + deterministic=True, + init_cache=False, + ): + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + init_cache=init_cache, + ) + hidden_states = self_attention_outputs[0] + attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights + + do_cross_attention = self.causal and encoder_hidden_states is not None + if do_cross_attention: + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + ) + hidden_states = cross_attention_outputs[0] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[1:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states, deterministic=deterministic) + + outputs = (hidden_states,) + + outputs = outputs + attention_outputs + + # returns hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + return outputs + + +class FlaxT5LayerCollection(nn.Module): + config: FrozenConfigDict + has_relative_attention_bias: bool + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layer = FlaxT5Block( + self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype + ) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + output_attentions=False, + deterministic=True, + init_cache=False, + ): + return self.layer( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + 
init_cache=init_cache, + ) + + +class FlaxT5BlockCollection(nn.Module): + config: FrozenConfigDict + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.causal = self.config.causal + if self.gradient_checkpointing: + FlaxT5CheckpointLayer = remat(FlaxT5LayerCollection, static_argnums=(6, 7, 8)) + self.blocks = [ + FlaxT5CheckpointLayer( + self.config, + has_relative_attention_bias=(i == 0), + dtype=self.dtype, + name=str(i), + ) + for i in range(self.config.num_layers) + ] + else: + self.blocks = [ + FlaxT5LayerCollection( + self.config, + has_relative_attention_bias=(i == 0), + dtype=self.dtype, + name=str(i), + ) + for i in range(self.config.num_layers) + ] + + def __call__( + self, + hidden_states=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions: bool = False, + output_hidden_states: bool = False, + deterministic: bool = True, + init_cache: bool = False, + ): + # Prepare head mask if needed + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.causal) else None + position_bias = None + encoder_decoder_position_bias = None + + for i, layer_module in enumerate(self.blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + attention_mask, + position_bias, + encoder_hidden_states, + encoder_attention_mask, + encoder_decoder_position_bias, + output_attentions, + deterministic, + init_cache, + ) + + hidden_states = layer_outputs[0] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[1] + + if self.causal and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[2],) + if self.causal: + all_cross_attentions = all_cross_attentions + (layer_outputs[4],) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +class FlaxT5Stack(nn.Module): + config: FrozenConfigDict + embed_tokens: nn.Embed + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.causal = self.config.causal + + self.block = FlaxT5BlockCollection( + self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + self.final_layer_norm = FlaxT5LayerNorm( + self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + init_cache: bool = False, + ): + hidden_states = self.embed_tokens(input_ids) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + + outputs = self.block( + hidden_states, + attention_mask=attention_mask, + 
encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + deterministic=deterministic, + init_cache=init_cache, + ) + + hidden_states = outputs[0] + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + + # Add last layer + all_hidden_states = None + + if output_hidden_states: + all_hidden_states = outputs.hidden_states + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + if output_hidden_states: + return ( + hidden_states, + all_hidden_states, + ) + outputs[2:] + return (hidden_states,) + outputs[1:] + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +T5_ENCODE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +T5_DECODE_INPUTS_DOCSTRING = r""" + Args: + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + For training, `decoder_input_ids` should be provided. + encoder_outputs (`tuple(tuple(jnp.ndarray)`): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the + paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. + past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +T5_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 + Training](./t5#training). + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + encoder_outputs (`tuple(tuple(jnp.ndarray)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder. 
+ past_key_values (`tuple(tuple(jnp.ndarray))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class FlaxT5PreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = T5Config + base_model_prefix = "transformer" + module_class: nn.Module = None + + def __init__( + self, + config: T5Config, + input_shape: Tuple[int] = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.bfloat16, + _do_init: bool = True, + gradient_checkpointing: bool = False, + **kwargs, + ): + # we convert the T5Config mutable object into a FrozenConfigDict + # we pop these 2 things because we know they are troublesome + # from manual testing making a ConfigDict out of the config + config_dict = copy.deepcopy(config.to_dict()) + config_dict.pop('architectures') + config_dict.pop('id2label') + config_dict = FrozenConfigDict(config_dict) + # all modules take in a FrozenConfigDict + module = self.module_class(config=config_dict, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) + # but FlaxPreTrainedModel still takes the T5Config object + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + + attention_mask = jnp.ones_like(input_ids) + args = [input_ids, attention_mask] + if self.module_class not in [FlaxT5EncoderModule]: + decoder_input_ids = jnp.ones_like(input_ids) + decoder_attention_mask = jnp.ones_like(input_ids) + args.extend([decoder_input_ids, decoder_attention_mask]) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, + *args, + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + decoder_input_ids: jnp.ndarray = 
None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if decoder_input_ids is None: + raise ValueError( + "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed" + " here." + ) + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # prepare decoder inputs + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + def init_cache(self, batch_size, max_length, encoder_outputs): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): + `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: + `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) + is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. 
+ """ + # init input variables to retrieve cache + decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + init_cache=True, + method=_decoder_forward, # we only need to call the decoder to init the cache + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings(T5_ENCODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=FrozenConfigDict) + def encode( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + + >>> text = "My friends are cool but they eat too many carbs." + >>> inputs = tokenizer(text, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _encoder_forward(module, input_ids, attention_mask, **kwargs): + encode_module = module._get_encoder_module() + return encode_module(input_ids, attention_mask, **kwargs) + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + method=_encoder_forward, + ) + + @add_start_docstrings(T5_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=FrozenConfigDict) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration + >>> import jax.numpy as jnp + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = 
FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + + >>> text = "My friends are cool but they eat too many carbs." + >>> inputs = tokenizer(text, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxT5Attention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ) + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past = outputs + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past = outputs + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + +T5_START_DOCSTRING = r""" + The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text + Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan + Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a + text-to-text denoising generative setting. + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) 
+
+    This model is also a Flax Linen
+    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
+    regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
+
+    Finally, this model supports inherent JAX features such as:
+
+    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+    Parameters:
+        config ([`FrozenConfigDict`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+            `jax.numpy.bfloat16` (on TPUs).
+
+            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+            specified all the computation will be performed with the given `dtype`.
+
+            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+            parameters.**
+
+            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+            [`~FlaxPreTrainedModel.to_bf16`].
+"""
+
+
+@add_start_docstrings(
+    "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
+    T5_START_DOCSTRING,
+)
+class FlaxT5Module(nn.Module):
+    config: FrozenConfigDict
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
+
+    def _get_encoder_module(self):
+        return self.encoder
+
+    def _get_decoder_module(self):
+        return self.decoder
+
+    def setup(self):
+        self.shared = nn.Embed(
+            self.config.vocab_size,
+            self.config.d_model,
+            embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
+            dtype=self.dtype,
+        )
+
+        encoder_config = copy.deepcopy(self.config)
+        # unfreeze
+        encoder_config = ConfigDict(encoder_config)
+        encoder_config.causal = False
+        # freeze
+        encoder_config = FrozenConfigDict(encoder_config)
+        self.encoder = FlaxT5Stack(
+            encoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
+
+        decoder_config = copy.deepcopy(self.config)
+        # unfreeze
+        decoder_config = ConfigDict(decoder_config)
+        decoder_config.causal = True
+        decoder_config.num_layers = self.config.num_decoder_layers
+        # freeze
+        decoder_config = FrozenConfigDict(decoder_config)
+        self.decoder = FlaxT5Stack(
+            decoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
+
+    def __call__(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        decoder_input_ids=None,
+        decoder_attention_mask=None,
+        encoder_outputs=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        deterministic: bool = True,
+    ):
+        return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+        # Encode if needed (training, first prediction pass)
+        encoder_outputs = self.encoder(
+            input_ids=input_ids,
attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return FlaxSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class FlaxT5Model(FlaxT5PreTrainedModel): + module_class = FlaxT5Module + + +append_call_sample_docstring(FlaxT5Model, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC) + +FLAX_T5_MODEL_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxT5Model + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = FlaxT5Model.from_pretrained("google-t5/t5-small") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="np" + ... ).input_ids + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids + + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model. + >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg. 
+ >>> decoder_input_ids = model._shift_right(decoder_input_ids) + + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ``` +""" + + +overwrite_call_docstring(FlaxT5Model, T5_INPUTS_DOCSTRING + FLAX_T5_MODEL_DOCSTRING) +append_replace_return_docstrings(FlaxT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + + +@add_start_docstrings( + "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.", + T5_START_DOCSTRING, +) +class FlaxT5EncoderModule(nn.Module): + config: FrozenConfigDict + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0), + dtype=self.dtype, + ) + + encoder_config = copy.deepcopy(self.config) + # unfreeze + encoder_config = ConfigDict(encoder_config) + encoder_config.is_decoder = False + encoder_config.is_encoder_decoder = False + encoder_config.causal = False + # freeze + encoder_config = FrozenConfigDict(encoder_config) + + self.encoder = FlaxT5Stack( + encoder_config, + embed_tokens=self.shared, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + + def __call__( + self, + input_ids=None, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict: bool = True, + deterministic: bool = True, + ): + # Encode if needed (training, first prediction pass) + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + return encoder_outputs + + +class FlaxT5EncoderModel(FlaxT5PreTrainedModel): + module_class = FlaxT5EncoderModule + + @add_start_docstrings_to_model_forward(T5_ENCODE_INPUTS_DOCSTRING) + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + +@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING) +class FlaxT5ForConditionalGenerationModule(nn.Module): + config: FrozenConfigDict + dtype: jnp.dtype = jnp.bfloat16 # the dtype of the computation + gradient_checkpointing: bool = False + + def _get_encoder_module(self): + 
return self.encoder + + def _get_decoder_module(self): + return self.decoder + + def setup(self): + self.model_dim = self.config.d_model + + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.initializer_factor), + dtype=self.dtype, + ) + + encoder_config = copy.deepcopy(self.config) + # unfreeze + encoder_config = ConfigDict(encoder_config) + encoder_config.causal = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + # freeze + encoder_config = FrozenConfigDict(encoder_config) + + self.encoder = FlaxT5Stack( + encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + + decoder_config = copy.deepcopy(self.config) + # unfreeze + decoder_config = ConfigDict(decoder_config) + decoder_config.causal = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = self.config.num_decoder_layers + # freeze + decoder_config = FrozenConfigDict(decoder_config) + + self.decoder = FlaxT5Stack( + decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + + self.lm_head = nn.Dense( + self.config.vocab_size, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_factor), + dtype=self.dtype, + ) + + def __call__( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + deterministic: bool = True, + ): + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # Encode + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + hidden_states = encoder_outputs[0] + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + sequence_output = decoder_outputs[0] + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + if self.config.tie_word_embeddings: + shared_embedding = self.shared.variables["params"]["embedding"] + lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output) + else: + lm_logits = self.lm_head(sequence_output) + + if not return_dict: + return (lm_logits,) + decoder_outputs[1:] + encoder_outputs + + return FlaxSeq2SeqLMOutput( + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel): + module_class = FlaxT5ForConditionalGenerationModule + + 
@add_start_docstrings(T5_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=FrozenConfigDict) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration + >>> import jax.numpy as jnp + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + + >>> text = "summarize: My friends are cool but they eat too many carbs." + >>> inputs = tokenizer(text, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxT5Attention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): + decoder_module = module._get_decoder_module() + decoder_outputs = decoder_module( + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ) + + sequence_output = decoder_outputs[0] + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.config.d_model**-0.5) + + if self.config.tie_word_embeddings: + shared_embedding = module.shared.variables["params"]["embedding"] + lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output) + else: + lm_logits = module.lm_head(sequence_output) + + return lm_logits, decoder_outputs + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + if past_key_values is None: + lm_logits, decoder_outputs = outputs + else: + (lm_logits, decoder_outputs), past = outputs + + if return_dict: + outputs = FlaxCausalLMOutputWithCrossAttentions( + logits=lm_logits, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + ) + else: + outputs = (lm_logits,) + decoder_outputs[1:] + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + max_length, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, + encoder_outputs=None, + **kwargs, + ): + # initializing the cache + batch_size, seq_length = decoder_input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyways. 
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if decoder_attention_mask is not None: + extended_attention_mask = jax.lax.dynamic_update_slice( + extended_attention_mask, decoder_attention_mask, (0, 0) + ) + + return { + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "encoder_attention_mask": attention_mask, + "decoder_attention_mask": extended_attention_mask, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + return model_kwargs + + +FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + >>> model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small") + + >>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs." + >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], return_tensors="np") + + >>> # Generate Summary + >>> summary_ids = model.generate(inputs["input_ids"]).sequences + >>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)) + ``` +""" + + +overwrite_call_docstring( + FlaxT5ForConditionalGeneration, T5_INPUTS_DOCSTRING + FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING +) +append_replace_return_docstrings( + FlaxT5ForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC +) \ No newline at end of file
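
The `encode` / `init_cache` / `decode` split implemented above is what `generate` drives through `prepare_inputs_for_generation` and `update_inputs_for_generation`. The sketch below walks the same path by hand with a greedy loop so the role of `past_key_values` is visible. It is a minimal illustration only: the import path `parallel.t5_model.modeling_t5` is assumed from this repository's layout, the checkpoint is the `google-t5/t5-small` one used in the docstring examples, and `max_length=32` is arbitrary.

```python
import jax.numpy as jnp
from transformers import AutoTokenizer

# Assumed import path for the modeling file added in this diff.
from parallel.t5_model.modeling_t5 import FlaxT5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

inputs = tokenizer("summarize: My friends are cool but they eat too many carbs.", return_tensors="np")

# Run the encoder once; its hidden states are reused at every decoding step.
encoder_outputs = model.encode(inputs["input_ids"], attention_mask=inputs["attention_mask"])

# Initialise the decoder cache, then decode greedily one token at a time,
# feeding back only the newest token together with the updated cache.
batch_size = inputs["input_ids"].shape[0]
max_length = 32
past_key_values = model.init_cache(batch_size, max_length, encoder_outputs)

decoder_input_ids = jnp.full((batch_size, 1), model.config.decoder_start_token_id, dtype="i4")
generated = decoder_input_ids
for _ in range(max_length - 1):  # runs to max_length; no early stop at EOS in this sketch
    outputs = model.decode(
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask=inputs["attention_mask"],
        past_key_values=past_key_values,
    )
    next_token = jnp.argmax(outputs.logits[:, -1, :], axis=-1)[:, None]
    generated = jnp.concatenate([generated, next_token], axis=-1)
    past_key_values = outputs.past_key_values
    decoder_input_ids = next_token

print(tokenizer.decode(generated[0].tolist(), skip_special_tokens=True))
```

In practice `model.generate(...)`, as in `FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING`, performs these steps with jit-compiled decoding; the loop here is only meant to make the cache handling explicit.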
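For completeness, here is a standalone sketch of the config handling that `FlaxT5PreTrainedModel.__init__` performs before instantiating a module, together with the thaw/override/refreeze pattern the `setup` methods use for per-stack settings such as `causal`. It assumes `ConfigDict` and `FrozenConfigDict` come from `ml_collections` (their imports sit outside this hunk), and the remark about `id2label` is an inference from its integer keys, not something the diff states.

```python
import copy

from ml_collections import ConfigDict, FrozenConfigDict  # assumed source of these names

from parallel.t5_model.configuration_t5 import T5Config

config = T5Config()  # the mutable HF-style config

# Mirror of __init__: flatten to a plain dict, drop the two entries that do not
# survive the ConfigDict conversion (`id2label` presumably because of its
# integer keys), then freeze.
config_dict = copy.deepcopy(config.to_dict())
config_dict.pop("architectures")
config_dict.pop("id2label")
frozen = FrozenConfigDict(config_dict)

# Modules read attributes the same way they would from a T5Config.
print(frozen.d_model, frozen.num_layers)

# Mirror of FlaxT5Module.setup: thaw a copy, add or override per-stack fields,
# then freeze again before handing it to the stack.
decoder_config = ConfigDict(copy.deepcopy(frozen))
decoder_config.causal = True
decoder_config.num_layers = frozen.num_decoder_layers
decoder_config = FrozenConfigDict(decoder_config)
```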
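Finally, note the precision defaults this file chooses: `FlaxT5PreTrainedModel.__init__` defaults to `dtype=jnp.bfloat16` and `FlaxT5ForConditionalGenerationModule` computes in bfloat16, while the docstring above stresses that `dtype` only affects the computation, not the parameters. A small sketch (same assumed import path) of pairing the computation dtype with an explicit bfloat16 parameter cast:

```python
import jax.numpy as jnp

from parallel.t5_model.modeling_t5 import FlaxT5ForConditionalGeneration  # assumed path

# `dtype` sets the computation precision only; parameters are still loaded in float32 ...
model = FlaxT5ForConditionalGeneration.from_pretrained("google-t5/t5-small", dtype=jnp.bfloat16)

# ... so cast them explicitly if bfloat16 weights are wanted as well.
model.params = model.to_bf16(model.params)
```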