Feat: increased learning rate for effective large batch size learning
This commit is contained in:
parent
aca80720c8
commit
a817fe16cc
|
@ -1,8 +1,6 @@
|
||||||
*.ipynb
|
*.ipynb
|
||||||
t5_*/
|
|
||||||
model_checkpoints/
|
model_checkpoints/
|
||||||
exports/
|
exports/
|
||||||
modified_t5_model/
|
|
||||||
traces/
|
traces/
|
||||||
ruff.toml
|
ruff.toml
|
||||||
settings.json
|
settings.json
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
__pycache__
|
|
@ -0,0 +1,163 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2020, The T5 Authors and HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""T5 model configuration"""
|
||||||
|
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
from transformers import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class T5Config(PretrainedConfig):
    r"""
    Configuration class holding the hyper-parameters of a [`T5Model`] or a [`TFT5Model`]. The arguments given here
    define the model architecture; constructing the class with its defaults gives a configuration similar to the T5
    [google-t5/t5-small](https://huggingface.co/google-t5/t5-small) architecture.

    The class inherits from [`PretrainedConfig`]; see the documentation of [`PretrainedConfig`] for the options that
    control the model outputs.

    Arguments:
        vocab_size (`int`, *optional*, defaults to 32128):
            Vocabulary size of the T5 model, i.e. the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`T5Model`] or [`TFT5Model`].
        d_model (`int`, *optional*, defaults to 512):
            Size of the encoder layers and the pooler layer.
        d_kv (`int`, *optional*, defaults to 64):
            Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer
            will be defined as `num_heads * d_kv`.
        d_ff (`int`, *optional*, defaults to 2048):
            Size of the intermediate feed forward layer in each `T5Block`.
        num_layers (`int`, *optional*, defaults to 6):
            Number of hidden layers in the Transformer encoder.
        num_decoder_layers (`int`, *optional*):
            Number of hidden layers in the Transformer decoder. Falls back to `num_layers` when not set.
        num_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer encoder.
        relative_attention_num_buckets (`int`, *optional*, defaults to 32):
            The number of buckets to use for each attention layer.
        relative_attention_max_distance (`int`, *optional*, defaults to 128):
            The maximum distance of the longer sequences for the bucket separation.
        dropout_rate (`float`, *optional*, defaults to 0.1):
            The ratio for all dropout layers.
        classifier_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for classifier.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-6):
            The epsilon used by the layer normalization layers.
        initializer_factor (`float`, *optional*, defaults to 1):
            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
            testing).
        feed_forward_proj (`string`, *optional*, defaults to `"relu"`):
            Type of feed forward layer to be used, written as `{ACT_FN}` or `gated-{ACT_FN}`, e.g. `"relu"` or
            `"gated-gelu"`. T5v1.1 uses the `"gated-gelu"` feed forward projection. Original T5 uses `"relu"`.
        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
            Forwarded unchanged to [`PretrainedConfig`].
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        pad_token_id (`int`, *optional*, defaults to 0):
            Forwarded unchanged to [`PretrainedConfig`].
        eos_token_id (`int`, *optional*, defaults to 1):
            Forwarded unchanged to [`PretrainedConfig`].

    Raises:
        ValueError: If `feed_forward_proj` is neither `{ACT_FN}` nor `gated-{ACT_FN}`.
    """

    model_type = "t5"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Generic attribute names mapped onto the T5-specific attribute names.
    attribute_map = {
        "hidden_size": "d_model",
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        vocab_size=32128,
        d_model=512,
        d_kv=64,
        d_ff=2048,
        num_layers=6,
        num_decoder_layers=None,
        num_heads=8,
        relative_attention_num_buckets=32,
        relative_attention_max_distance=128,
        dropout_rate=0.1,
        layer_norm_epsilon=1e-6,
        initializer_factor=1.0,
        feed_forward_proj="relu",
        is_encoder_decoder=True,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        classifier_dropout=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_kv = d_kv
        self.d_ff = d_ff
        self.num_layers = num_layers
        # An unset decoder depth mirrors the encoder depth (symmetric model).
        if num_decoder_layers is None:
            num_decoder_layers = self.num_layers
        self.num_decoder_layers = num_decoder_layers
        self.num_heads = num_heads
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance
        self.dropout_rate = dropout_rate
        self.classifier_dropout = classifier_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_factor = initializer_factor
        self.feed_forward_proj = feed_forward_proj
        self.use_cache = use_cache
        # Hard-coded: not controlled by any constructor argument.
        # NOTE(review): some transformers releases assign `use_bfloat16` inside
        # `PretrainedConfig.__init__`, which would overwrite this value when
        # `super().__init__` runs below - confirm against the installed version.
        self.use_bfloat16 = True

        # The spec is either "<act>" (one part) or "gated-<act>" (two parts).
        spec_parts = self.feed_forward_proj.split("-")
        is_plain = len(spec_parts) == 1
        is_gated = len(spec_parts) == 2 and spec_parts[0] == "gated"
        if not (is_plain or is_gated):
            raise ValueError(
                f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. "
                "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
                "'gated-gelu' or 'relu'"
            )

        # For backwards compatibility "gated-gelu" selects the "gelu_new" activation.
        if feed_forward_proj == "gated-gelu":
            activation_name = "gelu_new"
        else:
            activation_name = spec_parts[-1]
        self.dense_act_fn = activation_name
        self.is_gated_act = spec_parts[0] == "gated"

        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )
|
||||||
|
|
||||||
|
|
||||||
|
# class T5OnnxConfig(OnnxSeq2SeqConfigWithPast):
|
||||||
|
# @property
|
||||||
|
# def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
# common_inputs = {
|
||||||
|
# "input_ids": {0: "batch", 1: "encoder_sequence"},
|
||||||
|
# "attention_mask": {0: "batch", 1: "encoder_sequence"},
|
||||||
|
# }
|
||||||
|
# if self.use_past:
|
||||||
|
# common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
|
||||||
|
# common_inputs["decoder_input_ids"] = {0: "batch"}
|
||||||
|
# common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
|
||||||
|
# else:
|
||||||
|
# common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
|
||||||
|
# common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
|
||||||
|
#
|
||||||
|
# if self.use_past:
|
||||||
|
# self.fill_with_past_key_values_(common_inputs, direction="inputs")
|
||||||
|
#
|
||||||
|
# return common_inputs
|
||||||
|
#
|
||||||
|
# @property
|
||||||
|
# def default_onnx_opset(self) -> int:
|
||||||
|
# return 13
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
175
t5_jax.py
175
t5_jax.py
|
@ -25,6 +25,7 @@ import numpy as np
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
import math
|
import math
|
||||||
|
import flax.linen as nn
|
||||||
|
|
||||||
# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
|
# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
|
||||||
jax.config.update("jax_default_matmul_precision", "bfloat16")
|
jax.config.update("jax_default_matmul_precision", "bfloat16")
|
||||||
|
@ -83,7 +84,7 @@ os.environ.update({
|
||||||
"NCCL_LL128_BUFFSIZE": "-2",
|
"NCCL_LL128_BUFFSIZE": "-2",
|
||||||
"NCCL_LL_BUFFSIZE": "-2",
|
"NCCL_LL_BUFFSIZE": "-2",
|
||||||
"NCCL_PROTO": "SIMPLE,LL,LL128",
|
"NCCL_PROTO": "SIMPLE,LL,LL128",
|
||||||
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.99",
|
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.8",
|
||||||
# "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
|
# "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -103,14 +104,14 @@ except (LookupError, OSError):
|
||||||
# %%
|
# %%
|
||||||
# config options
|
# config options
|
||||||
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
|
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
|
||||||
save_path = 't5_5e_1_pmap'
|
save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/original/'
|
||||||
# file_path = 'combined_data'
|
# file_path = 'combined_data'
|
||||||
split_datasets = load_from_disk(file_path)
|
split_datasets = load_from_disk(file_path)
|
||||||
training_size = len(split_datasets['train'])
|
training_size = len(split_datasets['train'])
|
||||||
# Store some constant
|
# Store some constant
|
||||||
seed = 117
|
seed = 117
|
||||||
num_epochs = 5
|
num_epochs = 40
|
||||||
batch_size = 32 # 384 is the best
|
batch_size = 64 # 384 is the best
|
||||||
num_train_epochs = num_epochs
|
num_train_epochs = num_epochs
|
||||||
per_device_train_batch_size = batch_size
|
per_device_train_batch_size = batch_size
|
||||||
train_batch_size = per_device_train_batch_size * jax.device_count()
|
train_batch_size = per_device_train_batch_size * jax.device_count()
|
||||||
|
@ -120,7 +121,7 @@ steps_per_epoch = training_size // train_batch_size
|
||||||
total_train_steps = steps_per_epoch * num_epochs
|
total_train_steps = steps_per_epoch * num_epochs
|
||||||
|
|
||||||
warmup_steps = 0
|
warmup_steps = 0
|
||||||
learning_rate = 2e-5
|
learning_rate = 2e-4
|
||||||
|
|
||||||
weight_decay = 0.01
|
weight_decay = 0.01
|
||||||
adam_beta1 = 0.9
|
adam_beta1 = 0.9
|
||||||
|
@ -129,7 +130,7 @@ adam_epsilon = 1e-8
|
||||||
label_smoothing_factor = 0.0
|
label_smoothing_factor = 0.0
|
||||||
|
|
||||||
num_beams = 1
|
num_beams = 1
|
||||||
val_max_target_length = 86
|
val_max_target_length = 128
|
||||||
|
|
||||||
predict_with_generate = True
|
predict_with_generate = True
|
||||||
|
|
||||||
|
@ -143,7 +144,7 @@ additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>",
|
||||||
# Add the additional special tokens to the tokenizer
|
# Add the additional special tokens to the tokenizer
|
||||||
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
|
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
|
||||||
|
|
||||||
max_length = 86
|
max_length = 128
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
len(tokenizer)
|
len(tokenizer)
|
||||||
|
@ -176,51 +177,64 @@ from transformers import FlaxT5ForConditionalGeneration
|
||||||
from transformers import T5Config
|
from transformers import T5Config
|
||||||
|
|
||||||
config = T5Config()
|
config = T5Config()
|
||||||
|
model = FlaxT5ForConditionalGeneration.from_pretrained(
|
||||||
|
"t5-base",
|
||||||
# %%
|
dtype=jnp.bfloat16,
|
||||||
# If you want don't want to cast certain parameters (for example layer norm bias and scale)
|
gradient_checkpointing=True
|
||||||
# then pass the mask as follows
|
)
|
||||||
from flax import traverse_util
|
params = model.params
|
||||||
|
|
||||||
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
|
|
||||||
# useful for transformer model
|
|
||||||
# model.enable_gradient_checkpointing()
|
|
||||||
|
|
||||||
# enable bf16 except for layer_norm
|
# enable bf16 except for layer_norm
|
||||||
# flat_params = traverse_util.flatten_dict(model.params)
|
# enable bf16
|
||||||
# mask = {
|
# enable only for dense, some transformer sections, and shared
|
||||||
# path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
|
def create_mask_for_layer_norm(params):
    """Build a boolean mask selecting the parameters to cast to low precision.

    Despite the name (kept so existing callers keep working), the mask does not
    select layer-norm weights. A leaf is marked ``True`` when its last two path
    components name a feed-forward projection (``wi``/``wo``), an attention
    ``q``/``k``/``v`` projection, or the shared embedding. Every other leaf,
    including layer norms and the attention ``o`` projection, is ``False``.

    Args:
        params: Nested parameter dict, flattened with
            ``traverse_util.flatten_dict``.

    Returns:
        A nested dict with the same structure as ``params`` whose leaves are
        booleans.
    """
    # (module name, leaf name) pairs that are eligible for casting.
    # q/k/v are matched on "kernel"; the feed-forward projections were matched
    # only on "weight". Flax dense layers store their matrix as "kernel", so
    # wi/wo are matched on "kernel" here and "weight" is kept for backward
    # compatibility. NOTE(review): confirm the leaf names against the loaded
    # parameter tree.
    castable = {
        ("wi", "kernel"),
        ("wi", "weight"),
        ("wo", "kernel"),
        ("wo", "weight"),
        ("k", "kernel"),
        ("q", "kernel"),
        ("v", "kernel"),
        ("shared", "embedding"),
    }

    flat_params = traverse_util.flatten_dict(params)
    # Slicing (instead of indexing path[-2]) keeps single-component paths from
    # raising; such paths can never match a two-component pair.
    mask = {path: tuple(path[-2:]) in castable for path in flat_params}
    return traverse_util.unflatten_dict(mask)
|
||||||
#
|
|
||||||
# for masked, key in zip(flat_mask, sorted(flat_params.keys())):
|
# borrowed from transformers modeling_flax_utils
|
||||||
# if masked:
|
def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
    """Cast the floating-point leaves of a parameter tree to ``dtype``.

    Args:
        params: Parameter tree whose leaves are inspected.
        dtype: Target dtype for leaves that are floating-point ``jnp.ndarray``s.
        mask: Optional tree of booleans. When given, only leaves whose mask
            entry is truthy are considered for casting. When ``None``, every
            leaf is considered.

    Returns:
        The tree with eligible leaves cast. Leaves that are not floating-point
        ``jnp.ndarray``s are returned unchanged.
    """

    # Approach taken from
    # https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
    def _cast_if_floating(leaf):
        is_float_array = isinstance(leaf, jnp.ndarray) and jnp.issubdtype(leaf.dtype, jnp.floating)
        return leaf.astype(dtype) if is_float_array else leaf

    if mask is None:
        return jax.tree_util.tree_map(_cast_if_floating, params)

    flat_params = traverse_util.flatten_dict(params)
    mask_leaves = jax.tree_util.tree_flatten(mask)[0]

    # Mask leaves are paired positionally with the sorted flattened keys.
    for key, selected in zip(sorted(flat_params.keys()), mask_leaves):
        if selected:
            flat_params[key] = _cast_if_floating(flat_params[key])

    return traverse_util.unflatten_dict(flat_params)
|
||||||
|
|
||||||
|
mask = create_mask_for_layer_norm(params)
|
||||||
|
# override params with bfloat version
|
||||||
|
params= cast_floating_to(params, jnp.bfloat16, mask)
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
@ -307,31 +321,6 @@ token_datasets.set_format(
|
||||||
'labels', 'decoder_input_ids',
|
'labels', 'decoder_input_ids',
|
||||||
'decoder_attention_mask']
|
'decoder_attention_mask']
|
||||||
)
|
)
|
||||||
# %%
|
|
||||||
# check values
|
|
||||||
for name in ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'decoder_attention_mask']:
|
|
||||||
int_array = train_dataset[name]
|
|
||||||
if np.all((int_array >= 0) & (int_array <= 65535)):
|
|
||||||
uint16_array = int_array.astype(np.uint16)
|
|
||||||
else:
|
|
||||||
raise ValueError("Values are out of range for uint16")
|
|
||||||
|
|
||||||
# %%
|
|
||||||
|
|
||||||
from datasets import ClassLabel, Value, Sequence
|
|
||||||
features = train_dataset.features.copy()
|
|
||||||
features['input_ids'] = Sequence(Value('uint16'))
|
|
||||||
features['attention_mask'] = Sequence(Value('bool'))
|
|
||||||
features['labels'] = Sequence(Value('uint16'))
|
|
||||||
features['decoder_input_ids'] = Sequence(Value('uint16'))
|
|
||||||
features['decoder_attention_mask'] = Sequence(Value('bool'))
|
|
||||||
train_dataset = train_dataset.cast(features)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# %%
|
|
||||||
# temp
|
|
||||||
print('data type check: ', train_dataset['decoder_attention_mask'].dtype)
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
|
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
|
||||||
|
@ -355,17 +344,11 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
|
||||||
|
|
||||||
for idx in batch_idx:
|
for idx in batch_idx:
|
||||||
batch = dataset[idx]
|
batch = dataset[idx]
|
||||||
batch = {k: jnp.array(v) for k, v in batch.items()}
|
batch = {k: v for k, v in batch.items()}
|
||||||
|
|
||||||
yield batch
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
# %% [markdown]
|
|
||||||
# # Model
|
|
||||||
#
|
|
||||||
#
|
|
||||||
#
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
|
||||||
# Initialize our training
|
# Initialize our training
|
||||||
|
@ -406,7 +389,7 @@ linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
||||||
def decay_mask_fn(params):
|
def decay_mask_fn(params):
|
||||||
flat_params = traverse_util.flatten_dict(params)
|
flat_params = traverse_util.flatten_dict(params)
|
||||||
# find out all LayerNorm parameters
|
# find out all LayerNorm parameters
|
||||||
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
|
layer_norm_candidates = ["final_layer_norm", "layer_norm"]
|
||||||
layer_norm_named_params = {
|
layer_norm_named_params = {
|
||||||
layer[-2:]
|
layer[-2:]
|
||||||
for layer_norm_name in layer_norm_candidates
|
for layer_norm_name in layer_norm_candidates
|
||||||
|
@ -437,13 +420,10 @@ class TrainState(train_state.TrainState):
|
||||||
def replicate(self):
|
def replicate(self):
|
||||||
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
|
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
|
||||||
|
|
||||||
# set bf16 for model params
|
|
||||||
# model.params = model.to_bf16(model.params)
|
|
||||||
# Cast parameters to bfloat16 if desired
|
|
||||||
# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
|
|
||||||
|
|
||||||
# Setup train state
|
# Setup train state
|
||||||
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
|
# input all the state here
|
||||||
|
state = TrainState.create(apply_fn=model.__call__, params=params, tx=adamw, dropout_rng=dropout_rng)
|
||||||
|
|
||||||
# label smoothed cross entropy
|
# label smoothed cross entropy
|
||||||
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
|
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
|
||||||
|
@ -485,17 +465,17 @@ def train_step(state, batch, label_smoothing_factor=0.0):
|
||||||
num_labels = jax.lax.psum(num_labels, "batch")
|
num_labels = jax.lax.psum(num_labels, "batch")
|
||||||
|
|
||||||
# true loss = total loss / total samples
|
# true loss = total loss / total samples
|
||||||
# loss = jax.lax.psum(loss, "batch")
|
loss = jax.lax.psum(loss, "batch")
|
||||||
# loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
|
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
|
||||||
|
|
||||||
# true grad = total grad / total samples
|
# true grad = total grad / total samples
|
||||||
grad = jax.lax.psum(grad, "batch")
|
grad = jax.lax.psum(grad, "batch")
|
||||||
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
|
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
|
||||||
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
|
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
|
||||||
|
|
||||||
# metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
||||||
# return new_state, metrics
|
return new_state, metrics
|
||||||
return new_state
|
# return new_state
|
||||||
|
|
||||||
# Define generation function
|
# Define generation function
|
||||||
max_length = (
|
max_length = (
|
||||||
|
@ -549,25 +529,24 @@ for epoch in epochs:
|
||||||
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
|
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
|
||||||
batch = next(train_loader)
|
batch = next(train_loader)
|
||||||
batch = shard(batch)
|
batch = shard(batch)
|
||||||
state = p_train_step(state, batch)
|
state, train_metric = p_train_step(state, batch)
|
||||||
# train_metrics.append(train_metric)
|
# train_metrics.append(train_metric)
|
||||||
|
|
||||||
train_time = time.time() - train_start
|
train_time = time.time() - train_start
|
||||||
|
|
||||||
# train_metric = unreplicate(train_metric)
|
train_metric = unreplicate(train_metric)
|
||||||
# train_metric['loss'].block_until_ready()
|
train_metric['loss'].block_until_ready()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
epochs.write(
|
epochs.write(
|
||||||
# f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, "
|
# f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, "
|
||||||
f"Epoch... ({epoch + 1}/{num_epochs} | "
|
f"Epoch... ({epoch + 1}/{num_epochs} | "
|
||||||
# f"Learning Rate:{train_metric['learning_rate']}, "
|
f"Learning Rate:{train_metric['learning_rate']}, "
|
||||||
f"Last train time: {train_time})"
|
f"Last train time: {train_time})"
|
||||||
)
|
)
|
||||||
# jax.profiler.stop_trace()
|
# jax.profiler.stop_trace()
|
||||||
# %%
|
# %%
|
||||||
|
|
||||||
output_dir = save_path
|
output_dir = save_path
|
||||||
# save checkpoint after each epoch and push checkpoint to the hub
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
if jax.process_index() == 0:
|
if jax.process_index() == 0:
|
||||||
|
|
|
@ -66,7 +66,8 @@ import orbax.checkpoint as ocp
|
||||||
|
|
||||||
# data_path = f"../make_data/select_db/data_mapping_filtered.csv"
|
# data_path = f"../make_data/select_db/data_mapping_filtered.csv"
|
||||||
# data_path = f"../make_data_2/select_db/dataset/1/train_all.csv"
|
# data_path = f"../make_data_2/select_db/dataset/1/train_all.csv"
|
||||||
data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv'
|
model_name_or_path = "./model_checkpoints/simple" # Replace with your specific model name
|
||||||
|
data_path = '/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv'
|
||||||
# data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv'
|
# data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv'
|
||||||
|
|
||||||
# Ensure to include 'ships_idx' in the fields list
|
# Ensure to include 'ships_idx' in the fields list
|
||||||
|
@ -97,7 +98,6 @@ test_dataset = Dataset.from_list(process_df(df))
|
||||||
# from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration
|
# from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration
|
||||||
from transformers import FlaxT5ForConditionalGeneration
|
from transformers import FlaxT5ForConditionalGeneration
|
||||||
# model_name_or_path = "./t5_80_1" # Replace with your specific model name
|
# model_name_or_path = "./t5_80_1" # Replace with your specific model name
|
||||||
model_name_or_path = "./model_checkpoints/simple_test" # Replace with your specific model name
|
|
||||||
model = FlaxT5ForConditionalGeneration.from_pretrained(model_name_or_path)
|
model = FlaxT5ForConditionalGeneration.from_pretrained(model_name_or_path)
|
||||||
params = model.params
|
params = model.params
|
||||||
|
|
||||||
|
@ -275,11 +275,12 @@ df['p_property_correct'] = df['p_property'] == df['property']
|
||||||
print("thing accuracy", sum(df['p_thing_correct'])/len(df))
|
print("thing accuracy", sum(df['p_thing_correct'])/len(df))
|
||||||
print("property accuracy", sum(df['p_property_correct'])/len(df))
|
print("property accuracy", sum(df['p_property_correct'])/len(df))
|
||||||
print("total accuracy", sum(df['p_property_correct'] & df['p_thing_correct'])/len(df))
|
print("total accuracy", sum(df['p_property_correct'] & df['p_thing_correct'])/len(df))
|
||||||
# %%
|
|
||||||
df[~df["p_property_correct"]]
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
df['p_thing']
|
# df[~df["p_property_correct"]]
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# df['p_thing']
|
||||||
# %%
|
# %%
|
||||||
# Save the DataFrame as a Parquet file (using pyarrow or fastparquet)
|
# Save the DataFrame as a Parquet file (using pyarrow or fastparquet)
|
||||||
# df.to_parquet("exports/output_file.parquet", engine="pyarrow") # or use engine="fastparquet"
|
# df.to_parquet("exports/output_file.parquet", engine="pyarrow") # or use engine="fastparquet"
|
||||||
|
|
|
@ -0,0 +1,610 @@
|
||||||
|
# %%
|
||||||
|
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = False

if USE_CPU_ONLY:
    extra_flags = [
        "--xla_force_host_platform_device_count=4",  # Simulate 4 devices
    ]
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    os.environ["JAX_PLATFORMS"] = "cpu"
else:
    # GPU flags
    extra_flags = [
        "--xla_gpu_enable_triton_softmax_fusion=true",
        "--xla_gpu_triton_gemm_any=True",
        # "--xla_gpu_enable_async_collectives=true",
        "--xla_gpu_enable_latency_hiding_scheduler=true",
        "--xla_gpu_enable_highest_priority_async_stream=true",
    ]
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

# Keep any XLA flags already exported by the caller in both branches, and add
# each of our flags only once so re-running this cell does not pile up
# duplicates.
merged_flags = os.environ.get("XLA_FLAGS", "").split()
merged_flags += [flag for flag in extra_flags if flag not in merged_flags]
flags = " ".join(merged_flags)

os.environ["XLA_FLAGS"] = flags
os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
    "NCCL_LL128_BUFFSIZE": "-2",
    "NCCL_LL_BUFFSIZE": "-2",
    "NCCL_PROTO": "SIMPLE,LL,LL128",
    "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.90",
    # "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
})
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from functools import partial
|
||||||
|
from pprint import pprint
|
||||||
|
from typing import Any, Dict, Tuple, Callable, Sequence, Dict, Union
|
||||||
|
|
||||||
|
import flax.linen as nn
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
from jax.experimental.shard_map import shard_map
|
||||||
|
from jax.sharding import Mesh, NamedSharding
|
||||||
|
# from jax.experimental.pjit import pjit # superseded by jax.jit
|
||||||
|
from jax.experimental import mesh_utils
|
||||||
|
from jax.sharding import PartitionSpec
|
||||||
|
from ml_collections import ConfigDict
|
||||||
|
import optax
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from datasets import Dataset, load_from_disk
|
||||||
|
|
||||||
|
from flax import jax_utils, traverse_util
|
||||||
|
from flax.training import train_state
|
||||||
|
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||||
|
from flax.core.frozen_dict import freeze, unfreeze, FrozenDict
|
||||||
|
import flax.core
|
||||||
|
|
||||||
|
# model checkpointing and saving utilities
|
||||||
|
from flax import linen as nn
|
||||||
|
from flax.training import checkpoints, train_state
|
||||||
|
from flax import struct, serialization
|
||||||
|
|
||||||
|
from parallel.partitions import set_partitions
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from parallel.dataload import DataPrepare
|
||||||
|
|
||||||
|
# for memory tracking
|
||||||
|
# from jax_smi import initialise_tracking
|
||||||
|
# initialise_tracking()
|
||||||
|
|
||||||
|
|
||||||
|
PyTree = Any
|
||||||
|
Metrics = Dict[str, Tuple[jax.Array, ...]]
|
||||||
|
|
||||||
|
if USE_CPU_ONLY:
|
||||||
|
jax.config.update('jax_platform_name', 'cpu')
|
||||||
|
else:
|
||||||
|
jax.config.update("jax_default_matmul_precision", "bfloat16")
|
||||||
|
|
||||||
|
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
|
||||||
|
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
|
||||||
|
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
## get platform type
|
||||||
|
from jax.extend.backend import get_backend
|
||||||
|
print(get_backend().platform)
|
||||||
|
print(jax.devices())
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# config options
|
||||||
|
# --- dataset and checkpoint locations ---
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/'
# file_path = 'combined_data'
save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/simple_test/'

split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])

# --- run-wide constants ---
seed = 117
num_epochs = 5
num_train_epochs = num_epochs

# --- batch sizing: the per-device batch is multiplied by the device count ---
batch_size = 64
per_device_train_batch_size = batch_size
per_device_eval_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
eval_batch_size = per_device_eval_batch_size * jax.device_count()

# Floor division: a trailing partial batch is not counted as a step.
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs

# --- optimizer and schedule settings ---
warmup_steps = 0
learning_rate = 5e-5
weight_decay = 0.01
adam_beta1, adam_beta2 = 0.9, 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0

# --- generation settings ---
num_beams = 1
val_max_target_length = 128
predict_with_generate = True
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# prepare data
|
||||||
|
# init object
|
||||||
|
# e.g. Config
|
||||||
|
print("preparing data")
|
||||||
|
data_config = ConfigDict(
|
||||||
|
dict(
|
||||||
|
max_length=128,
|
||||||
|
pad_token_id=0,
|
||||||
|
decoder_start_token_id=0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
dataprep = DataPrepare(split_datasets['train'], data_config)
|
||||||
|
# # example usage
|
||||||
|
# # %%
|
||||||
|
seed = 117
|
||||||
|
rng = jax.random.PRNGKey(seed)
|
||||||
|
train_loader = dataprep.data_loader(rng, batch_size=batch_size)
|
||||||
|
batch = next(iter(train_loader))
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom
|
||||||
|
from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model
|
||||||
|
main_model = custom_model.from_pretrained(
|
||||||
|
"t5-base",
|
||||||
|
dtype=jnp.bfloat16,
|
||||||
|
gradient_checkpointing=True,
|
||||||
|
)
|
||||||
|
params = main_model.params
|
||||||
|
# pretrained_params = model.params
|
||||||
|
model = main_model.module
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# # testing config hashability
|
||||||
|
# # some explanation:
|
||||||
|
# # The PreTrainedModel class loads a T5Config model that is not hashable because
|
||||||
|
# # it is a complicated class that pretends to be a dataclass.
|
||||||
|
# # The solution is to extract a dict from it, then make a ConfigDict from
|
||||||
|
# # ml_collections library so that we can get values via the "." operator.
|
||||||
|
# # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to
|
||||||
|
# # modify the config before passing to the next layer
|
||||||
|
# from transformers import T5Config
|
||||||
|
# from t5_model.configuration_t5 import FrozenT5Config
|
||||||
|
# from ml_collections import ConfigDict, FrozenConfigDict
|
||||||
|
#
|
||||||
|
# config = T5Config.from_pretrained("t5-base").to_dict()
|
||||||
|
# config.pop('architectures')
|
||||||
|
# config.pop('id2label')
|
||||||
|
# # test if it works
|
||||||
|
# frozen_config = FrozenConfigDict(config)
|
||||||
|
# # test hash
|
||||||
|
# hash(frozen_config)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# # print model
|
||||||
|
# rng, input_rng = jax.random.split(rng)
|
||||||
|
# model.tabulate(
|
||||||
|
# input_rng,
|
||||||
|
# input_ids=batch['input_ids'],
|
||||||
|
# attention_mask=batch['attention_mask'],
|
||||||
|
# decoder_input_ids=batch['decoder_input_ids'],
|
||||||
|
# decoder_attention_mask=batch['decoder_attention_mask'],
|
||||||
|
# console_kwargs={"force_jupyter": True}
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# create mesh
|
||||||
|
print("creating mesh")
|
||||||
|
device_mesh = mesh_utils.create_device_mesh((2,2))
|
||||||
|
print(device_mesh)
|
||||||
|
|
||||||
|
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
|
||||||
|
print(mesh)
|
||||||
|
|
||||||
|
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
    """Bind a partition spec to the module-level ``mesh``.

    Args:
        pspec: Partition spec describing how array dimensions map to mesh axes.

    Returns:
        A ``NamedSharding`` over ``mesh`` created with ``memory_kind="device"``.
    """
    named = NamedSharding(mesh, pspec, memory_kind="device")
    return named
|
||||||
|
|
||||||
|
x_sharding = mesh_sharding(PartitionSpec('data', None)) # shard the leading (batch) dim across the 'data' axis; second dim is not partitioned
|
||||||
|
model_sharding=mesh_sharding(PartitionSpec(None, 'model'))
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# optimizers
|
||||||
|
|
||||||
|
def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Build a schedule that warms up linearly, then decays linearly to zero.

    Args:
        train_ds_size: Number of training examples.
        train_batch_size: Examples consumed per optimizer step.
        num_train_epochs: Number of passes over the training set.
        num_warmup_steps: Steps spent ramping from 0.0 up to ``learning_rate``.
        learning_rate: Peak learning rate reached at the end of warmup.

    Returns:
        An optax schedule mapping a step count to a learning rate.
    """
    # Integer division: a trailing partial batch does not count as a step.
    total_steps = (train_ds_size // train_batch_size) * num_train_epochs

    ramp_up = optax.linear_schedule(
        init_value=0.0,
        end_value=learning_rate,
        transition_steps=num_warmup_steps,
    )
    ramp_down = optax.linear_schedule(
        init_value=learning_rate,
        end_value=0,
        transition_steps=total_steps - num_warmup_steps,
    )
    # Switch from the warmup schedule to the decay schedule at num_warmup_steps.
    return optax.join_schedules(schedules=[ramp_up, ramp_down], boundaries=[num_warmup_steps])
|
||||||
|
|
||||||
|
|
||||||
|
# Create learning rate schedule
|
||||||
|
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
||||||
|
training_size,
|
||||||
|
train_batch_size,
|
||||||
|
num_train_epochs,
|
||||||
|
warmup_steps,
|
||||||
|
learning_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
# We use Optax's "masking" functionality to not apply weight decay
|
||||||
|
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
||||||
|
# mask boolean with the same structure as the parameters.
|
||||||
|
# The mask is True for parameters that should be decayed.
|
||||||
|
def decay_mask_fn(params):
    """Return a boolean pytree marking which parameters receive weight decay.

    A parameter is decayed (``True``) unless its leaf name is ``"bias"`` or its
    trailing two path components match those of a parameter whose concatenated,
    lower-cased path contains one of the LayerNorm name markers.

    Args:
        params: Nested parameter dict.

    Returns:
        A nested dict with the same structure as ``params`` holding booleans.
    """
    flat = traverse_util.flatten_dict(params)
    norm_markers = ("layernorm", "layer_norm", "ln")

    # Collect the trailing two path components of every LayerNorm-looking parameter.
    norm_suffixes = set()
    for path in flat:
        joined = "".join(path).lower()
        if any(marker in joined for marker in norm_markers):
            norm_suffixes.add(path[-2:])

    decay_flags = {}
    for path in flat:
        is_bias = path[-1] == "bias"
        is_norm = path[-2:] in norm_suffixes
        decay_flags[path] = not is_bias and not is_norm
    return traverse_util.unflatten_dict(decay_flags)
|
||||||
|
|
||||||
|
# create adam optimizer
|
||||||
|
adamw = optax.adamw(
|
||||||
|
learning_rate=linear_decay_lr_schedule_fn,
|
||||||
|
b1=adam_beta1,
|
||||||
|
b2=adam_beta2,
|
||||||
|
eps=adam_epsilon,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
mask=decay_mask_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
print("compile")
|
||||||
|
|
||||||
|
# enable bf16
|
||||||
|
# enable only for dense, some transformer sections, and shared
|
||||||
|
def create_mask_for_layer_norm(params):
    """Return a boolean pytree selecting the parameters to treat specially (e.g. cast to bf16).

    Despite the name, the mask is ``True`` for the ``wi``/``wo`` weights, the
    ``k``/``q``/``v`` kernels and the ``shared`` embedding, and ``False`` for
    every other parameter (layer norms included).

    Args:
        params: Nested parameter dict.

    Returns:
        A nested dict with the same structure as ``params`` holding booleans.
    """
    # (parent module name, parameter name) pairs that are selected.
    selected = {
        ("wi", "weight"),
        ("wo", "weight"),
        ("k", "kernel"),
        ("q", "kernel"),
        ("v", "kernel"),
        ("shared", "embedding"),
    }
    flat = traverse_util.flatten_dict(params)
    flags = {}
    for path in flat:
        flags[path] = (path[-2], path[-1]) in selected
    return traverse_util.unflatten_dict(flags)
|
||||||
|
|
||||||
|
# borrowed from transformers modeling_flax_utils
|
||||||
|
def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
    """Cast the floating-point leaves of a parameter pytree to ``dtype``.

    Borrowed from transformers' ``modeling_flax_utils``.

    Args:
        params: Parameter pytree (nested dict or ``FrozenDict``).
        dtype: Target floating-point dtype.
        mask: Optional pytree of booleans; when given, only leaves whose mask
            entry is truthy are cast. Mask leaves are paired with the sorted
            flattened parameter paths, so the mask is expected to mirror the
            structure of ``params``.

    Returns:
        The cast pytree. With a mask, this is a plain nested dict rebuilt from
        the flattened parameters.
    """

    # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
    def conditional_cast(param):
        # Only floating-point jnp arrays are converted; anything else passes through.
        is_float_array = isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating)
        return param.astype(dtype) if is_float_array else param

    if mask is not None:
        flat = traverse_util.flatten_dict(params)
        mask_leaves, _ = jax.tree_util.tree_flatten(mask)
        for keep, path in zip(mask_leaves, sorted(flat.keys())):
            if keep:
                flat[path] = conditional_cast(flat[path])
        return traverse_util.unflatten_dict(flat)

    return jax.tree_util.tree_map(conditional_cast, params)
|
||||||
|
|
||||||
|
# create init_fn to produce sharded state
|
||||||
|
def init_fn(params, model, optimizer):
    """Create the ``TrainState`` for already-loaded parameters.

    The parameters are used as given; no model init method is called here,
    since imported models might have complicated init methods.

    Args:
        params: Model parameters to train.
        model: Module whose ``apply`` becomes the state's ``apply_fn``.
        optimizer: Optax gradient transformation stored as ``tx``.

    Returns:
        A freshly created ``train_state.TrainState``.
    """
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer,
    )
|
||||||
|
|
||||||
|
|
||||||
|
abstract_variables = jax.eval_shape(
|
||||||
|
functools.partial(init_fn, model=model, optimizer=adamw), params)
|
||||||
|
|
||||||
|
# jax.sharding: describes how a jax.Array is laid out across devices
|
||||||
|
state_sharding = nn.get_sharding(abstract_variables, mesh)
|
||||||
|
# print(state_sharding)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
# replace the params tree with the new modified tree
|
||||||
|
# create partitions for model
|
||||||
|
from parallel.partitions import set_partitions
|
||||||
|
# set_partitions freezes the params on return
|
||||||
|
model_part_spec = set_partitions(unfreeze(params))
|
||||||
|
# p is already a partition spec
|
||||||
|
model_named_sharding = jax.tree.map(lambda p: mesh_sharding(p), model_part_spec)
|
||||||
|
|
||||||
|
# get pspec for opt_state
|
||||||
|
def get_opt_spec(x):
    """Pick the sharding for one node of the abstract train-state sharding tree.

    Dict nodes are replaced wholesale by the model's named shardings; every
    other node gets a sharding built from an empty ``PartitionSpec``.
    """
    if not isinstance(x, dict):
        # return an empty partspec
        return mesh_sharding(PartitionSpec())
    return unfreeze(model_named_sharding)
|
||||||
|
|
||||||
|
# this function replaces the empty model params spec with the 'model_named_shard'
|
||||||
|
state_sharding = jax.tree.map(
|
||||||
|
get_opt_spec, state_sharding, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
jit_init_fn = jax.jit(
|
||||||
|
init_fn,
|
||||||
|
static_argnames=('model', 'optimizer'), # skip model and optimizer
|
||||||
|
in_shardings=mesh_sharding(PartitionSpec()), # we don't shard params explicitly
|
||||||
|
out_shardings=state_sharding # but returned initialized_state is sharded
|
||||||
|
)
|
||||||
|
initialized_state = jit_init_fn(params, model, adamw)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# train step
|
||||||
|
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    """Label-smoothed cross entropy averaged over the non-padded target tokens.

    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104

    Args:
        logits: Unnormalised scores; the last axis is the vocabulary.
        labels: Integer target ids, one per position.
        padding_mask: Mask with the same shape as ``labels``; non-zero marks
            positions that contribute to the loss.
        label_smoothing_factor: Probability mass spread over the wrong classes.

    Returns:
        Scalar loss: the sum of the masked per-token losses divided by the
        number of unmasked tokens.
    """
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing_factor
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    # Entropy of the smoothed target distribution; subtracting it makes a
    # perfect prediction score zero.
    normalizing_constant = -(
        confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    )
    soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
    # Compute the loss in float32 even when the model emits bfloat16 logits.
    logits = jnp.asarray(logits, dtype=jnp.float32)
    soft_labels = soft_labels.astype(jnp.float32)
    loss = optax.softmax_cross_entropy(logits, soft_labels)
    loss = loss - normalizing_constant

    # ignore padded tokens from loss
    token_mask = jnp.asarray(padding_mask, dtype=jnp.float32)
    loss = loss * token_mask
    # Normalise by the number of real tokens, not by every position: a plain
    # .mean() would let padding shrink the loss. The clamp avoids 0/0 when a
    # batch is entirely padding.
    num_tokens = jnp.maximum(token_mask.sum(), 1.0)
    return loss.sum() / num_tokens
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# gradient accumulation
|
||||||
|
def accumulate_gradients_loop(
    state,
    batch,
    minibatch_size: int,
    loss_fn: Callable,
) -> Tuple[PyTree, Metrics]:
    """Calculate gradients and metrics for a batch using gradient accumulation.

    The batch is cut along its leading axis into ``batch_size // minibatch_size``
    consecutive minibatches. Trailing examples that do not fill a whole
    minibatch are not used.

    Args:
        state: Current training state; ``state.params`` is differentiated and
            ``state.step`` is used to look up the learning rate.
        batch: Full training batch (dict of arrays, must contain ``'input_ids'``).
        minibatch_size: Number of examples per gradient-accumulation step.
        loss_fn: Callable ``(params, minibatch) -> scalar loss``.

    Returns:
        Tuple ``(grads, metrics)``: gradients averaged over the minibatches, and
        a dict with the minibatch-averaged ``"loss"`` and the current
        ``"learning_rate"``.

    Raises:
        ValueError: If ``minibatch_size`` is not positive or exceeds the batch size.
    """
    batch_size = batch['input_ids'].shape[0]
    if minibatch_size <= 0:
        raise ValueError(f"minibatch_size must be positive, got {minibatch_size}")
    num_minibatches = batch_size // minibatch_size
    if num_minibatches < 1:
        raise ValueError(
            f"minibatch_size ({minibatch_size}) is larger than the batch size ({batch_size})"
        )

    # Without has_aux, value_and_grad returns (value, gradient), where value is
    # the scalar loss returned by loss_fn.
    grad_fn = jax.value_and_grad(loss_fn, has_aux=False)

    grads = None
    loss_sum = None
    for minibatch_idx in range(num_minibatches):
        with jax.named_scope(f"minibatch_{minibatch_idx}"):
            # Split the batch into minibatches.
            start = minibatch_idx * minibatch_size
            end = start + minibatch_size
            minibatch = jax.tree.map(lambda x: x[start:end], batch)  # noqa: B023
            # loss is the mean loss of this minibatch.
            loss, step_grads = grad_fn(state.params, minibatch)

        # Accumulate gradients and loss across minibatches.
        if grads is None:
            grads = step_grads
            loss_sum = loss
        else:
            grads = jax.tree.map(jnp.add, grads, step_grads)
            loss_sum = loss_sum + loss

    # Average gradients over minibatches.
    grads = jax.tree.map(lambda g: g / num_minibatches, grads)

    with jax.named_scope("sync_metrics"):
        # Average the loss as well, and report the learning rate once: summing
        # it per minibatch would inflate it by num_minibatches.
        metrics = {
            "loss": loss_sum / num_minibatches,
            "learning_rate": linear_decay_lr_schedule_fn(state.step),
        }
    return grads, metrics
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# single device code annotated with jax.jit
|
||||||
|
@functools.partial(
    jax.jit,
    # state is state_sharding initialized from init_fn
    # x_sharding is data sharded explicitly later
    in_shardings=(state_sharding, x_sharding),
    # the updated state keeps the state layout; metrics use an empty PartitionSpec
    out_shardings=(state_sharding, mesh_sharding(PartitionSpec())),
    donate_argnames=("state",),
)
def train_step(state, batch):
    """Run one optimizer update with gradient accumulation.

    Args:
        state: Current ``TrainState``; its buffers are donated to the call.
        batch: Dict with ``input_ids``, ``attention_mask``, ``decoder_input_ids``,
            ``decoder_attention_mask`` and ``labels``.

    Returns:
        Tuple ``(new_state, step_metrics)``: the state after applying the
        accumulated gradients, and the metrics from ``accumulate_gradients_loop``.
    """
    label_smoothing_factor = 0.0

    def compute_loss(params, batch):
        # Pin the layouts inside the traced computation.
        # A frozen dict is not allowed as a sharding object, hence unfreeze.
        params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding))
        batch = jax.lax.with_sharding_constraint(batch, x_sharding)
        outputs = state.apply_fn(
            {'params': params},
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            decoder_input_ids=batch['decoder_input_ids'],
            decoder_attention_mask=batch['decoder_attention_mask'],
        )
        logits = outputs[0]  # the model returns a structure whose first entry is the logits
        return loss_fn(
            logits,
            batch["labels"],
            batch["decoder_attention_mask"],
            label_smoothing_factor,
        )

    # Accumulate over minibatches of at most 32 examples. Clamping to the batch
    # size keeps batches smaller than 32 working: they run as one minibatch
    # instead of producing zero minibatches.
    minibatch_size = min(32, batch["input_ids"].shape[0])
    grads, step_metrics = accumulate_gradients_loop(
        state=state,
        batch=batch,
        minibatch_size=minibatch_size,
        loss_fn=compute_loss,
    )
    new_state = state.apply_gradients(grads=grads)

    return new_state, step_metrics
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# explore data sharding
|
||||||
|
sharded_batch = next(iter(train_loader))
|
||||||
|
sharded_batch = jax.device_put(sharded_batch, x_sharding)
|
||||||
|
# jax.debug.visualize_array_sharding(sharded_batch['input_ids'])
|
||||||
|
# jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding'])
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# # prep 1 step
|
||||||
|
# print("1 step for jit-ting")
|
||||||
|
# with mesh:
|
||||||
|
# state, metrics = train_step(initialized_state, sharded_batch)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# tr
|
||||||
|
print("***** Running training *****")
|
||||||
|
print(f" Num examples = {training_size}")
|
||||||
|
print(f" Num Epochs = {num_epochs}")
|
||||||
|
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
|
||||||
|
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
|
||||||
|
print(f" Total optimization steps = {total_train_steps}")
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# jax.profiler.start_trace("./traces")
|
||||||
|
|
||||||
|
|
||||||
|
print("*" * 10)
print("training start")
train_time = 0
state = initialized_state
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
    train_start = time.time()

    # Create a fresh sampling rng for every epoch. Previously the key was split
    # once before the loop, the split-off key was never used, and the same
    # `rng` was handed to the loader each epoch.
    # NOTE(review): presumably data_loader derives its shuffle order from this key - confirm.
    rng, input_rng = jax.random.split(rng)

    # NOTE(review): steps are counted with train_batch_size while the loader
    # yields batches of batch_size - confirm this is intended.
    steps_per_epoch = training_size // train_batch_size
    # Generate an epoch by shuffling sampling indices from the train dataset
    train_loader = dataprep.data_loader(input_rng, batch_size=batch_size, shuffle=True, drop_last=True)
    for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
        batch = next(train_loader)
        batch = jax.device_put(batch, x_sharding)
        with mesh:
            state, train_metric = train_step(state, batch)

    # For more accurate time stats one could block on the last metric here
    # (train_metric['loss'].block_until_ready()), but that slows down training.
    train_time = time.time() - train_start

    epochs.write(
        f"Epoch... ({epoch + 1}/{num_epochs} | "
        f"Loss: {train_metric['loss']}, "
        f"Learning Rate:{train_metric['learning_rate']}, "
        f"Last train time: {train_time})"
    )
|
||||||
|
# jax.profiler.stop_trace()
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# try out
|
||||||
|
gather_state = jax.device_get(state)
|
||||||
|
gather_batch = jax.device_get(batch)
|
||||||
|
logits = gather_state.apply_fn(
|
||||||
|
{'params': gather_state.params},
|
||||||
|
input_ids=gather_batch['input_ids'],
|
||||||
|
attention_mask=gather_batch['attention_mask'],
|
||||||
|
decoder_input_ids=gather_batch['decoder_input_ids'],
|
||||||
|
decoder_attention_mask=gather_batch['decoder_attention_mask'],
|
||||||
|
)[0] # zero because output is some structure, where first is the logit
|
||||||
|
|
||||||
|
probs = nn.softmax(logits, axis=-1)
|
||||||
|
predicted = jnp.argmax(probs, axis=-1)
|
||||||
|
print("sample output")
|
||||||
|
print(predicted[1])
|
||||||
|
|
||||||
|
# %%
|
||||||
|
main_model = custom_model.from_pretrained('t5-base')
|
||||||
|
output_dir = save_path
|
||||||
|
|
||||||
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
|
if jax.process_index() == 0:
|
||||||
|
params = jax.device_get(state.params)
|
||||||
|
main_model.save_pretrained(output_dir, params=params)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
|
@ -0,0 +1,550 @@
|
||||||
|
# %%
|
||||||
|
import os
|
||||||
|
# Set this to True to run the model on CPU only.
|
||||||
|
USE_CPU_ONLY = False
|
||||||
|
|
||||||
|
flags = os.environ.get("XLA_FLAGS", "")
|
||||||
|
if USE_CPU_ONLY:
|
||||||
|
flags += " --xla_force_host_platform_device_count=4" # Simulate 4 devices
|
||||||
|
# Enforce CPU-only execution
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
||||||
|
os.environ["JAX_PLATFORMS"] = "cpu"
|
||||||
|
else:
|
||||||
|
# GPU flags
|
||||||
|
flags = (
|
||||||
|
'--xla_gpu_enable_triton_softmax_fusion=true '
|
||||||
|
'--xla_gpu_triton_gemm_any=True '
|
||||||
|
# '--xla_gpu_enable_async_collectives=true '
|
||||||
|
'--xla_gpu_enable_latency_hiding_scheduler=true '
|
||||||
|
'--xla_gpu_enable_highest_priority_async_stream=true '
|
||||||
|
)
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
|
||||||
|
|
||||||
|
os.environ["XLA_FLAGS"] = flags
|
||||||
|
os.environ.update({
|
||||||
|
"TOKENIZERS_PARALLELISM" : "false",
|
||||||
|
"CUDA_DEVICE_MAX_CONNECTIONS" : "1",
|
||||||
|
"NCCL_LL128_BUFFSIZE": "-2",
|
||||||
|
"NCCL_LL_BUFFSIZE": "-2",
|
||||||
|
"NCCL_PROTO": "SIMPLE,LL,LL128",
|
||||||
|
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.5",
|
||||||
|
# "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from functools import partial
|
||||||
|
from pprint import pprint
|
||||||
|
from typing import Any, Dict, Tuple, Callable, Sequence, Dict, Union
|
||||||
|
|
||||||
|
import flax.linen as nn
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
from jax.experimental.shard_map import shard_map
|
||||||
|
from jax.sharding import Mesh, NamedSharding
|
||||||
|
# from jax.experimental.pjit import pjit # superseded by jax.jit
|
||||||
|
from jax.experimental import mesh_utils
|
||||||
|
from jax.sharding import PartitionSpec
|
||||||
|
from ml_collections import ConfigDict
|
||||||
|
import optax
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from datasets import Dataset, load_from_disk
|
||||||
|
|
||||||
|
from flax import jax_utils, traverse_util
|
||||||
|
from flax.training import train_state
|
||||||
|
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||||
|
from flax.core.frozen_dict import freeze, unfreeze, FrozenDict
|
||||||
|
import flax.core
|
||||||
|
|
||||||
|
# model checkpointing and saving utilities
|
||||||
|
from flax import linen as nn
|
||||||
|
from flax.training import checkpoints, train_state
|
||||||
|
from flax import struct, serialization
|
||||||
|
|
||||||
|
from parallel.partitions import set_partitions
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from parallel.dataload import DataPrepare
|
||||||
|
|
||||||
|
# for memory tracking
|
||||||
|
# from jax_smi import initialise_tracking
|
||||||
|
# initialise_tracking()
|
||||||
|
|
||||||
|
|
||||||
|
PyTree = Any
|
||||||
|
Metrics = Dict[str, Tuple[jax.Array, ...]]
|
||||||
|
|
||||||
|
if USE_CPU_ONLY:
|
||||||
|
jax.config.update('jax_platform_name', 'cpu')
|
||||||
|
else:
|
||||||
|
jax.config.update("jax_default_matmul_precision", "bfloat16")
|
||||||
|
|
||||||
|
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
|
||||||
|
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
|
||||||
|
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
## get platform type
|
||||||
|
from jax.extend.backend import get_backend
|
||||||
|
print(get_backend().platform)
|
||||||
|
print(jax.devices())
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# config options
|
||||||
|
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/'
|
||||||
|
save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/shmap/'
|
||||||
|
# file_path = 'combined_data'
|
||||||
|
split_datasets = load_from_disk(file_path)
|
||||||
|
training_size = len(split_datasets['train'])
|
||||||
|
# Store some constant
|
||||||
|
seed = 117
|
||||||
|
num_epochs = 40
|
||||||
|
batch_size = 32 # do not go beyond 128, 64 is good
|
||||||
|
num_train_epochs = num_epochs
|
||||||
|
per_device_train_batch_size = batch_size
|
||||||
|
train_batch_size = per_device_train_batch_size * jax.device_count()
|
||||||
|
per_device_eval_batch_size = batch_size
|
||||||
|
eval_batch_size = per_device_eval_batch_size * jax.device_count()
|
||||||
|
steps_per_epoch = training_size // train_batch_size
|
||||||
|
total_train_steps = steps_per_epoch * num_epochs
|
||||||
|
|
||||||
|
warmup_steps = 0
|
||||||
|
learning_rate = 2e-5
|
||||||
|
|
||||||
|
weight_decay = 0.01
|
||||||
|
adam_beta1 = 0.9
|
||||||
|
adam_beta2 = 0.999
|
||||||
|
adam_epsilon = 1e-8
|
||||||
|
label_smoothing_factor = 0.0
|
||||||
|
num_beams = 1
|
||||||
|
val_max_target_length = 128
|
||||||
|
predict_with_generate = True
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# prepare data
|
||||||
|
# init object
|
||||||
|
# e.g. Config
|
||||||
|
print("preparing data")
|
||||||
|
data_config = ConfigDict(
|
||||||
|
dict(
|
||||||
|
max_length=128,
|
||||||
|
pad_token_id=0,
|
||||||
|
decoder_start_token_id=0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
dataprep = DataPrepare(split_datasets['train'], data_config)
|
||||||
|
# # example usage
|
||||||
|
# # %%
|
||||||
|
seed = 117
|
||||||
|
rng = jax.random.PRNGKey(seed)
|
||||||
|
train_loader = dataprep.data_loader(rng, batch_size=batch_size)
|
||||||
|
batch = next(iter(train_loader))
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom
|
||||||
|
from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model
|
||||||
|
main_model = custom_model.from_pretrained(
|
||||||
|
"t5-base",
|
||||||
|
dtype=jnp.bfloat16,
|
||||||
|
gradient_checkpointing=True,
|
||||||
|
)
|
||||||
|
params = main_model.params
|
||||||
|
# pretrained_params = model.params
|
||||||
|
model = main_model.module
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# # testing config hashability
|
||||||
|
# # some explanation:
|
||||||
|
# # The PreTrainedModel class loads a T5Config model that is not hashable because
|
||||||
|
# # it is a complicated class that pretends to be a dataclass.
|
||||||
|
# # The solution is to extract a dict from it, then make a ConfigDict from
|
||||||
|
# # ml_collections library so that we can get values via the "." operator.
|
||||||
|
# # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to
|
||||||
|
# # modify the config before passing to the next layer
|
||||||
|
# from transformers import T5Config
|
||||||
|
# from t5_model.configuration_t5 import FrozenT5Config
|
||||||
|
# from ml_collections import ConfigDict, FrozenConfigDict
|
||||||
|
#
|
||||||
|
# config = T5Config.from_pretrained("t5-base").to_dict()
|
||||||
|
# config.pop('architectures')
|
||||||
|
# config.pop('id2label')
|
||||||
|
# # test if it works
|
||||||
|
# frozen_config = FrozenConfigDict(config)
|
||||||
|
# # test hash
|
||||||
|
# hash(frozen_config)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# # print model
|
||||||
|
# rng, input_rng = jax.random.split(rng)
|
||||||
|
# model.tabulate(
|
||||||
|
# input_rng,
|
||||||
|
# input_ids=batch['input_ids'],
|
||||||
|
# attention_mask=batch['attention_mask'],
|
||||||
|
# decoder_input_ids=batch['decoder_input_ids'],
|
||||||
|
# decoder_attention_mask=batch['decoder_attention_mask'],
|
||||||
|
# console_kwargs={"force_jupyter": True}
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# create mesh
|
||||||
|
print("creating mesh")
|
||||||
|
device_mesh = mesh_utils.create_device_mesh((2,2))
|
||||||
|
print(device_mesh)
|
||||||
|
|
||||||
|
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
|
||||||
|
print(mesh)
|
||||||
|
|
||||||
|
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
    """Bind a partition spec to the module-level ``mesh``.

    The memory kind is ``"unpinned_host"`` when ``USE_CPU_ONLY`` is set and
    ``"device"`` otherwise (GPU).

    Args:
        pspec: Partition spec describing how array dimensions map to mesh axes.

    Returns:
        A ``NamedSharding`` over ``mesh`` for ``pspec``.
    """
    memory_kind = "unpinned_host" if USE_CPU_ONLY else "device"
    return NamedSharding(mesh, pspec, memory_kind=memory_kind)
|
||||||
|
|
||||||
|
x_sharding = mesh_sharding(PartitionSpec('data', None)) # shard the leading (batch) dim across the 'data' axis; second dim is not partitioned
|
||||||
|
model_sharding=mesh_sharding(PartitionSpec(None, 'model'))
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# optimizers
|
||||||
|
|
||||||
|
def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Return a learning-rate schedule with linear warmup and linear decay.

    Args:
        train_ds_size: Number of training examples.
        train_batch_size: Examples consumed per optimizer step.
        num_train_epochs: Number of passes over the training set.
        num_warmup_steps: Steps spent ramping from 0.0 up to ``learning_rate``.
        learning_rate: Peak learning rate reached at the end of warmup.

    Returns:
        An optax schedule mapping a step count to a learning rate.
    """
    updates_per_epoch = train_ds_size // train_batch_size
    decay_steps = updates_per_epoch * num_train_epochs - num_warmup_steps

    # The two phases are joined at num_warmup_steps.
    phases = [
        optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps),
        optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=decay_steps),
    ]
    return optax.join_schedules(schedules=phases, boundaries=[num_warmup_steps])
|
||||||
|
|
||||||
|
|
||||||
|
# Create learning rate schedule
|
||||||
|
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
||||||
|
training_size,
|
||||||
|
train_batch_size,
|
||||||
|
num_train_epochs,
|
||||||
|
warmup_steps,
|
||||||
|
learning_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
# We use Optax's "masking" functionality to not apply weight decay
|
||||||
|
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
||||||
|
# mask boolean with the same structure as the parameters.
|
||||||
|
# The mask is True for parameters that should be decayed.
|
||||||
|
def decay_mask_fn(params):
    """Return a boolean pytree marking which parameters receive weight decay.

    A parameter is decayed (``True``) unless its leaf name is ``"bias"`` or its
    trailing two path components match those of a parameter whose concatenated,
    lower-cased path contains ``"final_layer_norm"`` or ``"layer_norm"``.

    Args:
        params: Nested parameter dict.

    Returns:
        A nested dict with the same structure as ``params`` holding booleans.
    """
    flat = traverse_util.flatten_dict(params)
    norm_markers = ("final_layer_norm", "layer_norm")

    # Collect the trailing two path components of every LayerNorm parameter.
    norm_suffixes = set()
    for path in flat:
        joined = "".join(path).lower()
        if any(marker in joined for marker in norm_markers):
            norm_suffixes.add(path[-2:])

    decay_flags = {}
    for path in flat:
        skip_decay = path[-1] == "bias" or path[-2:] in norm_suffixes
        decay_flags[path] = not skip_decay
    return traverse_util.unflatten_dict(decay_flags)
|
||||||
|
|
||||||
|
# create adam optimizer
|
||||||
|
adamw = optax.adamw(
|
||||||
|
learning_rate=linear_decay_lr_schedule_fn,
|
||||||
|
b1=adam_beta1,
|
||||||
|
b2=adam_beta2,
|
||||||
|
eps=adam_epsilon,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
mask=decay_mask_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    """Label-smoothed cross entropy averaged over the non-padded target tokens.

    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104

    Args:
        logits: Unnormalised scores; the last axis is the vocabulary.
        labels: Integer target ids, one per position.
        padding_mask: Mask with the same shape as ``labels``; non-zero marks
            positions that contribute to the loss.
        label_smoothing_factor: Probability mass spread over the wrong classes.

    Returns:
        Scalar loss: the sum of the masked per-token losses divided by the
        number of unmasked tokens.
    """
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing_factor
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    # Entropy of the smoothed target distribution; subtracting it makes a
    # perfect prediction score zero.
    normalizing_constant = -(
        confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    )
    soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
    # Compute the loss in float32 even when the model emits bfloat16 logits.
    logits = jnp.asarray(logits, dtype=jnp.float32)
    soft_labels = soft_labels.astype(jnp.float32)
    loss = optax.softmax_cross_entropy(logits, soft_labels)
    loss = loss - normalizing_constant

    # ignore padded tokens from loss
    token_mask = jnp.asarray(padding_mask, dtype=jnp.float32)
    loss = loss * token_mask
    # Divide by the count of real tokens rather than all positions, so the
    # amount of padding does not change the loss scale. The clamp guards
    # against 0/0 for an all-padding batch.
    num_tokens = jnp.maximum(token_mask.sum(), 1.0)
    mean_loss = loss.sum() / num_tokens
    return mean_loss
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
################################################################
|
||||||
|
# old jit in_shardings method
|
||||||
|
|
||||||
|
# create init_fn to produce sharded state
|
||||||
|
def init_fn(params, model, optimizer):
    """Wrap already-loaded parameters into a Flax `TrainState`.

    Args:
        params: parameter pytree to train; used as-is.
        model: module whose `apply` method becomes the state's `apply_fn`.
        optimizer: optax gradient transformation stored as the state's `tx`.

    Returns:
        The `TrainState` built from the three arguments.
    """
    # Be careful with model initialisation here: imported models can have
    # complicated init methods, so the parameters are passed in, not created.
    #
    # Optional bfloat16 cast, currently disabled:
    # mask = create_mask_for_layer_norm(params)
    # params = cast_floating_to(params, jnp.bfloat16, mask)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        tx=optimizer,
        params=params,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Evaluate init_fn on shapes only, to learn the structure of the train state.
abstract_variables = jax.eval_shape(functools.partial(init_fn, model=model, optimizer=adamw), params)

# jax.sharding: describes how a jax.Array is laid out across devices
state_sharding = nn.get_sharding(abstract_variables, mesh)

from parallel.partitions import set_partitions

# set_partitions freezes the params on return; every leaf it yields is
# already a PartitionSpec, so it only has to be bound to the mesh.
model_part_spec = set_partitions(unfreeze(params))
model_named_sharding = jax.tree.map(mesh_sharding, model_part_spec)


def get_opt_spec(x):
    """Return the sharding to use for one node of the train-state tree.

    A dict node receives the whole per-parameter sharding tree; any other
    node receives a sharding built from an empty PartitionSpec.
    """
    return unfreeze(model_named_sharding) if isinstance(x, dict) else mesh_sharding(PartitionSpec())


# Replace the parameter-shaped dicts inside the state sharding with
# `model_named_sharding`. Traversal stops at dicts and at optax.EmptyState
# so that get_opt_spec sees those nodes whole.
state_sharding = jax.tree.map(
    get_opt_spec,
    state_sharding,
    is_leaf=lambda node: isinstance(node, (dict, optax.EmptyState)),
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Compile init_fn so that the state it returns is laid out per state_sharding.
jit_init_fn = jax.jit(
    init_fn,
    # the returned initialized state is sharded ...
    out_shardings=state_sharding,
    # ... while the incoming params are not sharded explicitly
    in_shardings=mesh_sharding(PartitionSpec()),
    # model and optimizer are passed as static arguments
    static_argnames=('model', 'optimizer'),
)
initialized_state = jit_init_fn(params, model, adamw)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# train step
|
||||||
|
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    """Compute the label-smoothed cross-entropy and average it over all positions.

    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104

    Args:
        logits: model scores; the last axis is treated as the vocabulary axis.
        labels: target ids, turned into smoothed one-hot targets via `onehot`.
        padding_mask: element-wise multiplier applied to the per-position loss
            (presumably 0 on padded positions; verify against the caller).
        label_smoothing_factor: share of probability moved from the target class
            to the remaining classes; 0.0 means no smoothing.

    Returns:
        The mean of the masked per-position loss.
    """
    vocab = logits.shape[-1]
    target_prob = 1.0 - label_smoothing_factor
    other_prob = (1.0 - target_prob) / (vocab - 1)

    # Loss that a perfect prediction of the smoothed targets would obtain;
    # removed below so the minimum is zero. 1e-20 guards log(0).
    baseline = -(
        target_prob * jnp.log(target_prob) + (vocab - 1) * other_prob * jnp.log(other_prob + 1e-20)
    )

    smoothed = onehot(labels, vocab, on_value=target_prob, off_value=other_prob)
    smoothed = smoothed.astype(jnp.float32)
    logits_f32 = jnp.asarray(logits, dtype=jnp.float32)

    token_loss = optax.softmax_cross_entropy(logits_f32, smoothed)
    token_loss = token_loss - baseline

    # Ignore padded tokens in the loss.
    # NOTE(review): the average is taken over every position, padded ones
    # included, not over the number of unmasked positions.
    masked = token_loss * padding_mask
    return masked.mean()  # , num_labels
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
# single device code annotated with jax.jit
|
||||||
|
@functools.partial(
    jax.jit,
    # `state` follows state_sharding (as produced by init_fn); `batch` follows
    # x_sharding, the layout the data is placed with before the call.
    in_shardings=(state_sharding, x_sharding),
    # The new state keeps state_sharding; the metrics use an empty PartitionSpec.
    out_shardings=(state_sharding, mesh_sharding(PartitionSpec())),
    donate_argnames=('state'),
)
def train_step(state, batch):
    """Run one optimisation step and return ``(new_state, step_metrics)``.

    Args:
        state: train state providing `apply_fn`, `params`, `step` and
            `apply_gradients`. It is donated to the computation.
        batch: mapping with the keys `input_ids`, `attention_mask`,
            `decoder_input_ids`, `decoder_attention_mask` and `labels`.

    Returns:
        The state after applying the gradients, and a dict holding the loss and
        the learning rate evaluated at the step counter of the incoming state.
    """
    smoothing = 0.0

    def compute_loss(params, batch):
        """Forward pass plus loss for one batch; differentiated w.r.t. `params`."""
        # A frozen dict is not accepted as a sharding object, hence unfreeze().
        pinned_params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding))
        pinned_batch = jax.lax.with_sharding_constraint(batch, x_sharding)

        outputs = state.apply_fn(
            {'params': pinned_params},
            input_ids=pinned_batch['input_ids'],
            attention_mask=pinned_batch['attention_mask'],
            decoder_input_ids=pinned_batch['decoder_input_ids'],
            decoder_attention_mask=pinned_batch['decoder_attention_mask'],
        )
        # The model returns a structure whose first entry holds the logits.
        logits = outputs[0]

        # Debug output: report how the logits are sharded.
        print("logits")
        jax.debug.inspect_array_sharding(logits, callback=print)

        loss_value = loss_fn(
            logits,
            pinned_batch["labels"],
            pinned_batch["decoder_attention_mask"],
            smoothing,
        )

        # Debug output: report how the loss is sharded.
        print("loss")
        jax.debug.inspect_array_sharding(loss_value, callback=print)

        return loss_value  # , num_labels

    # Constrain every leaf of the batch to x_sharding before differentiating.
    batch = jax.tree.map(lambda leaf: jax.lax.with_sharding_constraint(leaf, x_sharding), batch)

    # Loss value and its gradient with respect to the parameters.
    loss, grads = jax.value_and_grad(compute_loss, has_aux=False)(state.params, batch)

    # Constrain the whole gradient tree to the empty PartitionSpec before the update.
    grads = jax.lax.with_sharding_constraint(grads, mesh_sharding(PartitionSpec()))

    updated_state = state.apply_gradients(grads=grads)
    with jax.named_scope("sync_metrics"):
        step_metrics = {
            "loss": loss,
            "learning_rate": linear_decay_lr_schedule_fn(state.step),
        }

    return updated_state, step_metrics
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# explore data sharding
|
||||||
|
# Fetch one batch and lay it out according to x_sharding so that its placement
# can be inspected and reused for the warm-up step below.
sharded_batch = next(iter(train_loader))
# sharded_batch = jax.device_put(sharded_batch, x_sharding)
# Map over the batch that was just fetched. The previous code mapped over
# `batch`, a name that is only assigned later by the training loop.
sharded_batch = jax.tree.map(lambda x: jax.lax.with_sharding_constraint(x, x_sharding), sharded_batch)
jax.debug.visualize_array_sharding(sharded_batch['input_ids'])
# jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding'])


# %%
# Run a single step up front so that train_step is traced and compiled.
# NOTE(review): train_step donates its `state` argument, so `initialized_state`
# may no longer be usable after this call; confirm before reusing it.
print("1 step for jit-ting")
with mesh:
    state, metrics = train_step(initialized_state, sharded_batch)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# tr
|
||||||
|
# Print a short summary of the run configuration before training starts.
for banner_line in (
    "***** Running training *****",
    f" Num examples = {training_size}",
    f" Num Epochs = {num_epochs}",
    f" Instantaneous batch size per device = {per_device_train_batch_size}",
    f" Total train batch size (w. parallel & distributed) = {train_batch_size}",
    f" Total optimization steps = {total_train_steps}",
):
    print(banner_line)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# jax.profiler.start_trace("./traces")
|
||||||
|
|
||||||
|
|
||||||
|
print("*" * 20)
print("training start")
train_time = 0
# NOTE(review): if the warm-up call to train_step has already run,
# `initialized_state` was donated there and may be invalid; confirm.
state = initialized_state
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
    train_start = time.time()

    # Create a fresh sampling key for every epoch. Previously the key was split
    # once before the loop and the unsplit `rng` was handed to the loader, so
    # every epoch was given the same key.
    rng, input_rng = jax.random.split(rng)
    train_metrics = []
    steps_per_epoch = training_size // train_batch_size
    # NOTE(review): the loader is built with `batch_size` while steps_per_epoch
    # is derived from `train_batch_size`; confirm the two are meant to differ.
    train_loader = dataprep.data_loader(input_rng, batch_size=batch_size, shuffle=True, drop_last=True)

    # Generate an epoch by shuffling sampling indices from the train dataset
    for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
        batch = next(train_loader)
        # Place the batch with the layout train_step declares for its input.
        batch = jax.device_put(batch, x_sharding)
        with mesh:
            state, train_metric = train_step(state, batch)

        # train_metrics.append(train_metric)

    # this is for more accurate time stats, but slows down training
    # train_metric['loss'].block_until_ready()
    train_time = time.time() - train_start

    # Report the metrics of the last step of the epoch.
    epochs.write(
        f"Epoch... ({epoch + 1}/{num_epochs} | "
        f"Loss: {train_metric['loss']}, "
        f"Learning Rate:{train_metric['learning_rate']}, "
        f"Last train time: {train_time})"
    )
|
||||||
|
# jax.profiler.stop_trace()
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# try out
|
||||||
|
# gather_state = jax.device_get(state)
|
||||||
|
# gather_batch = jax.device_get(batch)
|
||||||
|
# logits = gather_state.apply_fn(
|
||||||
|
# {'params': gather_state.params},
|
||||||
|
# input_ids=gather_batch['input_ids'],
|
||||||
|
# attention_mask=gather_batch['attention_mask'],
|
||||||
|
# decoder_input_ids=gather_batch['decoder_input_ids'],
|
||||||
|
# decoder_attention_mask=gather_batch['decoder_attention_mask'],
|
||||||
|
# )[0] # zero because output is some structure, where first is the logit
|
||||||
|
#
|
||||||
|
# probs = nn.softmax(logits, axis=-1)
|
||||||
|
# predicted = jnp.argmax(probs, axis=-1)
|
||||||
|
# print(predicted[0])
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Load the pretrained wrapper; its save_pretrained method is used below to
# write the trained parameters.
main_model = custom_model.from_pretrained('t5-base')
output_dir = save_path

# Write the final parameters once, from process 0 only.
if jax.process_index() == 0:
    # Fetch the parameters to the host and store every leaf as float32.
    params = jax.tree.map(
        lambda leaf: leaf.astype(jnp.float32),
        jax.device_get(state.params),
    )
    main_model.save_pretrained(output_dir, params=params)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
|
@ -27,7 +27,7 @@ os.environ.update({
|
||||||
"NCCL_LL128_BUFFSIZE": "-2",
|
"NCCL_LL128_BUFFSIZE": "-2",
|
||||||
"NCCL_LL_BUFFSIZE": "-2",
|
"NCCL_LL_BUFFSIZE": "-2",
|
||||||
"NCCL_PROTO": "SIMPLE,LL,LL128",
|
"NCCL_PROTO": "SIMPLE,LL,LL128",
|
||||||
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.90",
|
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.80",
|
||||||
# "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
|
# "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -89,35 +89,50 @@ jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
|
||||||
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
|
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
## get platform type
|
## get platform type
|
||||||
from jax.extend.backend import get_backend
|
from jax.extend.backend import get_backend
|
||||||
print(get_backend().platform)
|
print(get_backend().platform)
|
||||||
print(jax.devices())
|
print(jax.devices())
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# create mesh
|
||||||
|
print("creating mesh")
|
||||||
|
device_mesh = mesh_utils.create_device_mesh((4,1))
|
||||||
|
print(device_mesh)
|
||||||
|
|
||||||
|
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
|
||||||
|
print(mesh)
|
||||||
|
|
||||||
|
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
|
||||||
|
return NamedSharding(mesh, pspec, memory_kind="device")
|
||||||
|
|
||||||
|
x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis
|
||||||
|
model_sharding=mesh_sharding(PartitionSpec(None, 'model'))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# config options
|
# config options
|
||||||
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/'
|
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/'
|
||||||
save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/simple_test/'
|
save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/simple/'
|
||||||
# file_path = 'combined_data'
|
# file_path = 'combined_data'
|
||||||
split_datasets = load_from_disk(file_path)
|
split_datasets = load_from_disk(file_path)
|
||||||
training_size = len(split_datasets['train'])
|
training_size = len(split_datasets['train'])
|
||||||
# Store some constant
|
# Store some constant
|
||||||
seed = 117
|
seed = 117
|
||||||
num_epochs = 5
|
num_epochs = 40
|
||||||
batch_size = 64
|
batch_size = 128
|
||||||
num_train_epochs = num_epochs
|
num_train_epochs = num_epochs
|
||||||
per_device_train_batch_size = batch_size
|
per_device_train_batch_size = batch_size
|
||||||
train_batch_size = per_device_train_batch_size * jax.device_count()
|
train_batch_size = per_device_train_batch_size * 2
|
||||||
per_device_eval_batch_size = batch_size
|
per_device_eval_batch_size = batch_size
|
||||||
eval_batch_size = per_device_eval_batch_size * jax.device_count()
|
eval_batch_size = per_device_eval_batch_size * 2
|
||||||
steps_per_epoch = training_size // train_batch_size
|
steps_per_epoch = training_size // train_batch_size
|
||||||
total_train_steps = steps_per_epoch * num_epochs
|
total_train_steps = steps_per_epoch * num_epochs
|
||||||
|
|
||||||
warmup_steps = 0
|
warmup_steps = 0
|
||||||
learning_rate = 5e-5
|
learning_rate = 2e-3
|
||||||
|
|
||||||
weight_decay = 0.01
|
weight_decay = 0.01
|
||||||
adam_beta1 = 0.9
|
adam_beta1 = 0.9
|
||||||
|
@ -197,21 +212,6 @@ model = main_model.module
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# %%
|
|
||||||
# create mesh
|
|
||||||
print("creating mesh")
|
|
||||||
device_mesh = mesh_utils.create_device_mesh((2,2))
|
|
||||||
print(device_mesh)
|
|
||||||
|
|
||||||
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
|
|
||||||
print(mesh)
|
|
||||||
|
|
||||||
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
|
|
||||||
return NamedSharding(mesh, pspec, memory_kind="device")
|
|
||||||
|
|
||||||
x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis
|
|
||||||
model_sharding=mesh_sharding(PartitionSpec(None, 'model'))
|
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# optimizers
|
# optimizers
|
||||||
|
@ -246,7 +246,7 @@ linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
||||||
def decay_mask_fn(params):
|
def decay_mask_fn(params):
|
||||||
flat_params = traverse_util.flatten_dict(params)
|
flat_params = traverse_util.flatten_dict(params)
|
||||||
# find out all LayerNorm parameters
|
# find out all LayerNorm parameters
|
||||||
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
|
layer_norm_candidates = ["final_layer_norm", "layer_norm"]
|
||||||
layer_norm_named_params = {
|
layer_norm_named_params = {
|
||||||
layer[-2:]
|
layer[-2:]
|
||||||
for layer_norm_name in layer_norm_candidates
|
for layer_norm_name in layer_norm_candidates
|
||||||
|
@ -322,9 +322,9 @@ def init_fn(params, model, optimizer):
|
||||||
# do be careful with the model init
|
# do be careful with the model init
|
||||||
# imported models might have complicated init methods
|
# imported models might have complicated init methods
|
||||||
|
|
||||||
# mask = create_mask_for_layer_norm(params)
|
mask = create_mask_for_layer_norm(params)
|
||||||
# override params with bfloat version
|
# override params with bfloat version
|
||||||
# params= cast_floating_to(params, jnp.bfloat16, mask)
|
params= cast_floating_to(params, jnp.bfloat16, mask)
|
||||||
|
|
||||||
state = train_state.TrainState.create( # Create a `TrainState`.
|
state = train_state.TrainState.create( # Create a `TrainState`.
|
||||||
apply_fn=model.apply,
|
apply_fn=model.apply,
|
||||||
|
@ -449,8 +449,8 @@ def train_step(state, batch):
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# explore data sharding
|
# explore data sharding
|
||||||
sharded_batch = next(iter(train_loader))
|
# sharded_batch = next(iter(train_loader))
|
||||||
sharded_batch = jax.device_put(sharded_batch, x_sharding)
|
# sharded_batch = jax.device_put(sharded_batch, x_sharding)
|
||||||
# jax.debug.visualize_array_sharding(sharded_batch['input_ids'])
|
# jax.debug.visualize_array_sharding(sharded_batch['input_ids'])
|
||||||
# jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding'])
|
# jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding'])
|
||||||
|
|
||||||
|
@ -477,7 +477,7 @@ print(f" Total optimization steps = {total_train_steps}")
|
||||||
# jax.profiler.start_trace("./traces")
|
# jax.profiler.start_trace("./traces")
|
||||||
|
|
||||||
|
|
||||||
print("*" * 10)
|
print("*" * 50)
|
||||||
print("training start")
|
print("training start")
|
||||||
rng, input_rng = jax.random.split(rng)
|
rng, input_rng = jax.random.split(rng)
|
||||||
train_time = 0
|
train_time = 0
|
||||||
|
@ -489,7 +489,7 @@ for epoch in epochs:
|
||||||
# Create sampling rng
|
# Create sampling rng
|
||||||
train_metrics = []
|
train_metrics = []
|
||||||
steps_per_epoch = training_size // train_batch_size
|
steps_per_epoch = training_size // train_batch_size
|
||||||
train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True)
|
train_loader = dataprep.data_loader(rng, batch_size=train_batch_size, shuffle=True, drop_last=True)
|
||||||
# Generate an epoch by shuffling sampling indices from the train dataset
|
# Generate an epoch by shuffling sampling indices from the train dataset
|
||||||
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
|
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
|
||||||
batch = next(train_loader)
|
batch = next(train_loader)
|
||||||
|
@ -514,21 +514,21 @@ for epoch in epochs:
|
||||||
)
|
)
|
||||||
# jax.profiler.stop_trace()
|
# jax.profiler.stop_trace()
|
||||||
|
|
||||||
# %%
|
# # %%
|
||||||
# try out
|
# # try out
|
||||||
gather_state = jax.device_get(state)
|
# gather_state = jax.device_get(state)
|
||||||
gather_batch = jax.device_get(batch)
|
# gather_batch = jax.device_get(batch)
|
||||||
logits = gather_state.apply_fn(
|
# logits = gather_state.apply_fn(
|
||||||
{'params': gather_state.params},
|
# {'params': gather_state.params},
|
||||||
input_ids=gather_batch['input_ids'],
|
# input_ids=gather_batch['input_ids'],
|
||||||
attention_mask=gather_batch['attention_mask'],
|
# attention_mask=gather_batch['attention_mask'],
|
||||||
decoder_input_ids=gather_batch['decoder_input_ids'],
|
# decoder_input_ids=gather_batch['decoder_input_ids'],
|
||||||
decoder_attention_mask=gather_batch['decoder_attention_mask'],
|
# decoder_attention_mask=gather_batch['decoder_attention_mask'],
|
||||||
)[0] # zero because output is some structure, where first is the logit
|
# )[0] # zero because output is some structure, where first is the logit
|
||||||
|
#
|
||||||
probs = nn.softmax(logits, axis=-1)
|
# probs = nn.softmax(logits, axis=-1)
|
||||||
predicted = jnp.argmax(probs, axis=-1)
|
# predicted = jnp.argmax(probs, axis=-1)
|
||||||
predicted[1]
|
# print(predicted[0])
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
main_model = custom_model.from_pretrained('t5-base')
|
main_model = custom_model.from_pretrained('t5-base')
|
||||||
|
@ -537,6 +537,7 @@ output_dir = save_path
|
||||||
# save checkpoint after each epoch and push checkpoint to the hub
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
if jax.process_index() == 0:
|
if jax.process_index() == 0:
|
||||||
params = jax.device_get(state.params)
|
params = jax.device_get(state.params)
|
||||||
|
params = jax.tree.map(lambda x: x.astype(jnp.float32), params)
|
||||||
main_model.save_pretrained(output_dir, params=params)
|
main_model.save_pretrained(output_dir, params=params)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
__pycache__
|
|
@ -0,0 +1,118 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2020, The T5 Authors and HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""T5 model configuration"""
|
||||||
|
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
from transformers import logging
|
||||||
|
|
||||||
|
from etils import edc
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class T5Config(PretrainedConfig):
    """Configuration object holding the hyper-parameters of a T5 model.

    Every constructor argument except `pad_token_id`, `eos_token_id`,
    `is_encoder_decoder` and `**kwargs` is stored as an attribute of the same
    name; those remaining arguments are forwarded to `PretrainedConfig`.

    Derived attributes:
        num_decoder_layers: falls back to `num_layers` when not given.
        dense_act_fn: activation name parsed from `feed_forward_proj`.
        is_gated_act: True when `feed_forward_proj` starts with ``gated-``.
        use_bfloat16: always set to True by this class.

    Raises:
        ValueError: if `feed_forward_proj` is neither ``{ACT_FN}`` nor
            ``gated-{ACT_FN}``.
    """

    model_type = "t5"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}

    def __init__(
        self,
        vocab_size=32128,  # vocab size here
        d_model=512,
        d_kv=64,
        d_ff=2048,
        num_layers=6,
        num_decoder_layers=None,
        num_heads=8,
        relative_attention_num_buckets=32,
        relative_attention_max_distance=128,
        dropout_rate=0.1,
        layer_norm_epsilon=1e-6,
        initializer_factor=1.0,
        feed_forward_proj="relu",
        is_encoder_decoder=True,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        classifier_dropout=0.0,
        **kwargs,
    ):
        # Sizes of the network.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_kv = d_kv
        self.d_ff = d_ff
        self.num_layers = num_layers
        # default = symmetry: the decoder gets as many layers as the encoder
        if num_decoder_layers is None:
            num_decoder_layers = self.num_layers
        self.num_decoder_layers = num_decoder_layers
        self.num_heads = num_heads

        # Relative attention, regularisation and initialisation settings.
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance
        self.dropout_rate = dropout_rate
        self.classifier_dropout = classifier_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_factor = initializer_factor
        self.feed_forward_proj = feed_forward_proj
        self.use_cache = use_cache
        self.use_bfloat16 = True

        # `feed_forward_proj` is either "{ACT_FN}" or "gated-{ACT_FN}".
        proj_parts = self.feed_forward_proj.split("-")
        self.dense_act_fn = proj_parts[-1]
        self.is_gated_act = proj_parts[0] == "gated"

        # Accept exactly one part, or two parts whose first is "gated".
        is_valid_proj = len(proj_parts) == 1 or (self.is_gated_act and len(proj_parts) == 2)
        if not is_valid_proj:
            raise ValueError(
                f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. "
                "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
                "'gated-gelu' or 'relu'"
            )

        # for backwards compatibility
        if feed_forward_proj == "gated-gelu":
            self.dense_act_fn = "gelu_new"

        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )
|
||||||
|
|
||||||
|
|
||||||
|
# class T5OnnxConfig(OnnxSeq2SeqConfigWithPast):
|
||||||
|
# @property
|
||||||
|
# def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
# common_inputs = {
|
||||||
|
# "input_ids": {0: "batch", 1: "encoder_sequence"},
|
||||||
|
# "attention_mask": {0: "batch", 1: "encoder_sequence"},
|
||||||
|
# }
|
||||||
|
# if self.use_past:
|
||||||
|
# common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
|
||||||
|
# common_inputs["decoder_input_ids"] = {0: "batch"}
|
||||||
|
# common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
|
||||||
|
# else:
|
||||||
|
# common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
|
||||||
|
# common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
|
||||||
|
#
|
||||||
|
# if self.use_past:
|
||||||
|
# self.fill_with_past_key_values_(common_inputs, direction="inputs")
|
||||||
|
#
|
||||||
|
# return common_inputs
|
||||||
|
#
|
||||||
|
# @property
|
||||||
|
# def default_onnx_opset(self) -> int:
|
||||||
|
# return 13
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue