Feat: increased learning rate for effective large batch size learning

This commit is contained in:
Richard Wong 2024-09-22 22:28:41 +09:00
parent aca80720c8
commit a817fe16cc
13 changed files with 7178 additions and 151 deletions

2
.gitignore vendored
View File

@ -1,8 +1,6 @@
*.ipynb
t5_*/
model_checkpoints/
exports/
modified_t5_model/
traces/
ruff.toml
settings.json

1
parallel/t5_model/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
__pycache__

View File

@ -0,0 +1,163 @@
# coding=utf-8
# Copyright 2020, The T5 Authors and HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""T5 model configuration"""
from typing import Mapping
from transformers import PretrainedConfig
from transformers import logging
logger = logging.get_logger(__name__)
class T5Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`T5Model`] or a [`TFT5Model`]. It is used to
instantiate a T5 model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the T5
[google-t5/t5-small](https://huggingface.co/google-t5/t5-small) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Arguments:
vocab_size (`int`, *optional*, defaults to 32128):
Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`T5Model`] or [`TFT5Model`].
d_model (`int`, *optional*, defaults to 512):
Size of the encoder layers and the pooler layer.
d_kv (`int`, *optional*, defaults to 64):
Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will
be defined as `num_heads * d_kv`.
d_ff (`int`, *optional*, defaults to 2048):
Size of the intermediate feed forward layer in each `T5Block`.
num_layers (`int`, *optional*, defaults to 6):
Number of hidden layers in the Transformer encoder.
num_decoder_layers (`int`, *optional*):
Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
num_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer encoder.
relative_attention_num_buckets (`int`, *optional*, defaults to 32):
The number of buckets to use for each attention layer.
relative_attention_max_distance (`int`, *optional*, defaults to 128):
The maximum distance of the longer sequences for the bucket separation.
dropout_rate (`float`, *optional*, defaults to 0.1):
The ratio for all dropout layers.
classifier_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for classifier.
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
The epsilon used by the layer normalization layers.
initializer_factor (`float`, *optional*, defaults to 1):
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
testing).
feed_forward_proj (`string`, *optional*, defaults to `"relu"`):
Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. T5v1.1 uses the
`"gated-gelu"` feed forward projection. Original T5 uses `"relu"`.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
"""
model_type = "t5"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
def __init__(
self,
vocab_size=32128, # vocab size here
d_model=512,
d_kv=64,
d_ff=2048,
num_layers=6,
num_decoder_layers=None,
num_heads=8,
relative_attention_num_buckets=32,
relative_attention_max_distance=128,
dropout_rate=0.1,
layer_norm_epsilon=1e-6,
initializer_factor=1.0,
feed_forward_proj="relu",
is_encoder_decoder=True,
use_cache=True,
pad_token_id=0,
eos_token_id=1,
classifier_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.d_model = d_model
self.d_kv = d_kv
self.d_ff = d_ff
self.num_layers = num_layers
self.num_decoder_layers = (
num_decoder_layers if num_decoder_layers is not None else self.num_layers
) # default = symmetry
self.num_heads = num_heads
self.relative_attention_num_buckets = relative_attention_num_buckets
self.relative_attention_max_distance = relative_attention_max_distance
self.dropout_rate = dropout_rate
self.classifier_dropout = classifier_dropout
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_factor = initializer_factor
self.feed_forward_proj = feed_forward_proj
self.use_cache = use_cache
self.use_bfloat16 = True
act_info = self.feed_forward_proj.split("-")
self.dense_act_fn = act_info[-1]
self.is_gated_act = act_info[0] == "gated"
if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
raise ValueError(
f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. "
"Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
"'gated-gelu' or 'relu'"
)
# for backwards compatibility
if feed_forward_proj == "gated-gelu":
self.dense_act_fn = "gelu_new"
super().__init__(
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
**kwargs,
)
# class T5OnnxConfig(OnnxSeq2SeqConfigWithPast):
# @property
# def inputs(self) -> Mapping[str, Mapping[int, str]]:
# common_inputs = {
# "input_ids": {0: "batch", 1: "encoder_sequence"},
# "attention_mask": {0: "batch", 1: "encoder_sequence"},
# }
# if self.use_past:
# common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
# common_inputs["decoder_input_ids"] = {0: "batch"}
# common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
# else:
# common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
# common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
#
# if self.use_past:
# self.fill_with_past_key_values_(common_inputs, direction="inputs")
#
# return common_inputs
#
# @property
# def default_onnx_opset(self) -> int:
# return 13

File diff suppressed because it is too large Load Diff

1717
parallel/t5_model/pure_t5.py Normal file

File diff suppressed because it is too large Load Diff

175
t5_jax.py
View File

@ -25,6 +25,7 @@ import numpy as np
from functools import partial
from typing import Callable, Optional
import math
import flax.linen as nn
# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "bfloat16")
@ -83,7 +84,7 @@ os.environ.update({
"NCCL_LL128_BUFFSIZE": "-2",
"NCCL_LL_BUFFSIZE": "-2",
"NCCL_PROTO": "SIMPLE,LL,LL128",
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.99",
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.8",
# "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
})
@ -103,14 +104,14 @@ except (LookupError, OSError):
# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
save_path = 't5_5e_1_pmap'
save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/original/'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])
# Store some constant
seed = 117
num_epochs = 5
batch_size = 32 # 384 is the best
num_epochs = 40
batch_size = 64 # 384 is the best
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
@ -120,7 +121,7 @@ steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
warmup_steps = 0
learning_rate = 2e-5
learning_rate = 2e-4
weight_decay = 0.01
adam_beta1 = 0.9
@ -129,7 +130,7 @@ adam_epsilon = 1e-8
label_smoothing_factor = 0.0
num_beams = 1
val_max_target_length = 86
val_max_target_length = 128
predict_with_generate = True
@ -143,7 +144,7 @@ additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>",
# Add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
max_length = 86
max_length = 128
# %%
len(tokenizer)
@ -176,51 +177,64 @@ from transformers import FlaxT5ForConditionalGeneration
from transformers import T5Config
config = T5Config()
# %%
# If you want don't want to cast certain parameters (for example layer norm bias and scale)
# then pass the mask as follows
from flax import traverse_util
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
# useful for transformer model
# model.enable_gradient_checkpointing()
model = FlaxT5ForConditionalGeneration.from_pretrained(
"t5-base",
dtype=jnp.bfloat16,
gradient_checkpointing=True
)
params = model.params
# enable bf16 except for layer_norm
# flat_params = traverse_util.flatten_dict(model.params)
# mask = {
# path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
# }
# mask = traverse_util.unflatten_dict(mask)
# # borrowed from transformers modeling_flax_utils
# def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
# """
# Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
# """
#
# # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
# def conditional_cast(param):
# if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
# param = param.astype(dtype)
# return param
#
# if mask is None:
# return jax.tree_util.tree_map(conditional_cast, params)
#
# flat_params = traverse_util.flatten_dict(params)
# flat_mask, _ = jax.tree_util.tree_flatten(mask)
#
# for masked, key in zip(flat_mask, sorted(flat_params.keys())):
# if masked:
# flat_params[key] = conditional_cast(flat_params[key])
#
# return traverse_util.unflatten_dict(flat_params)
#
# # Cast parameters to bfloat16 if desired
# # params = jax.tree.tree_map(lambda x: x.astype(jnp.bfloat16), params)
# # instead of casting the whole thing, we cast only certain parts of the tree
# params = cast_floating_to(model.params, jnp.bfloat16, mask)
# enable bf16
# enable only for dense, some transformer sections, and shared
def create_mask_for_layer_norm(params):
flat_params = traverse_util.flatten_dict(params)
mask = {
# path: not (
# (path[-2] == "layer_norm" and path[-1] == "weight") or
# (path[-2] == "final_layer_norm" and path[-1] == "weight") or
# (path[-2] == "o" and path[-1] == "kernel")
# )
# for path in flat_params
path: (
(path[-2] == "wi" and path[-1] == "weight") or
(path[-2] == "wo" and path[-1] == "weight") or
(path[-2] == "k" and path[-1] == "kernel") or
(path[-2] == "q" and path[-1] == "kernel") or
(path[-2] == "v" and path[-1] == "kernel") or
(path[-2] == "shared" and path[-1] == "embedding")
) for path in flat_params
}
mask = traverse_util.unflatten_dict(mask)
return mask
# borrowed from transformers modeling_flax_utils
def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
"""
Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
"""
# taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
def conditional_cast(param):
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
param = param.astype(dtype)
return param
if mask is None:
return jax.tree_util.tree_map(conditional_cast, params)
flat_params = traverse_util.flatten_dict(params)
flat_mask, _ = jax.tree_util.tree_flatten(mask)
for masked, key in zip(flat_mask, sorted(flat_params.keys())):
if masked:
flat_params[key] = conditional_cast(flat_params[key])
return traverse_util.unflatten_dict(flat_params)
mask = create_mask_for_layer_norm(params)
# override params with bfloat version
params= cast_floating_to(params, jnp.bfloat16, mask)
# %%
@ -307,31 +321,6 @@ token_datasets.set_format(
'labels', 'decoder_input_ids',
'decoder_attention_mask']
)
# %%
# check values
for name in ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'decoder_attention_mask']:
int_array = train_dataset[name]
if np.all((int_array >= 0) & (int_array <= 65535)):
uint16_array = int_array.astype(np.uint16)
else:
raise ValueError("Values are out of range for uint16")
# %%
from datasets import ClassLabel, Value, Sequence
features = train_dataset.features.copy()
features['input_ids'] = Sequence(Value('uint16'))
features['attention_mask'] = Sequence(Value('bool'))
features['labels'] = Sequence(Value('uint16'))
features['decoder_input_ids'] = Sequence(Value('uint16'))
features['decoder_attention_mask'] = Sequence(Value('bool'))
train_dataset = train_dataset.cast(features)
# %%
# temp
print('data type check: ', train_dataset['decoder_attention_mask'].dtype)
# %%
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
@ -355,17 +344,11 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
for idx in batch_idx:
batch = dataset[idx]
batch = {k: jnp.array(v) for k, v in batch.items()}
batch = {k: v for k, v in batch.items()}
yield batch
# %% [markdown]
# # Model
#
#
#
# %%
# Initialize our training
@ -406,7 +389,7 @@ linear_decay_lr_schedule_fn = create_learning_rate_fn(
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_candidates = ["final_layer_norm", "layer_norm"]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
@ -437,13 +420,10 @@ class TrainState(train_state.TrainState):
def replicate(self):
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
# set bf16 for model params
# model.params = model.to_bf16(model.params)
# Cast parameters to bfloat16 if desired
# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
# Setup train state
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
# input all the state here
state = TrainState.create(apply_fn=model.__call__, params=params, tx=adamw, dropout_rng=dropout_rng)
# label smoothed cross entropy
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
@ -485,17 +465,17 @@ def train_step(state, batch, label_smoothing_factor=0.0):
num_labels = jax.lax.psum(num_labels, "batch")
# true loss = total loss / total samples
# loss = jax.lax.psum(loss, "batch")
# loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
loss = jax.lax.psum(loss, "batch")
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
# true grad = total grad / total samples
grad = jax.lax.psum(grad, "batch")
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
# metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
# return new_state, metrics
return new_state
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
return new_state, metrics
# return new_state
# Define generation function
max_length = (
@ -549,25 +529,24 @@ for epoch in epochs:
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
batch = next(train_loader)
batch = shard(batch)
state = p_train_step(state, batch)
state, train_metric = p_train_step(state, batch)
# train_metrics.append(train_metric)
train_time = time.time() - train_start
# train_metric = unreplicate(train_metric)
# train_metric['loss'].block_until_ready()
train_metric = unreplicate(train_metric)
train_metric['loss'].block_until_ready()
epochs.write(
# f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, "
f"Epoch... ({epoch + 1}/{num_epochs} | "
# f"Learning Rate:{train_metric['learning_rate']}, "
f"Learning Rate:{train_metric['learning_rate']}, "
f"Last train time: {train_time})"
)
# jax.profiler.stop_trace()
# %%
output_dir = save_path
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:

View File

@ -66,7 +66,8 @@ import orbax.checkpoint as ocp
# data_path = f"../make_data/select_db/data_mapping_filtered.csv"
# data_path = f"../make_data_2/select_db/dataset/1/train_all.csv"
data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv'
model_name_or_path = "./model_checkpoints/simple" # Replace with your specific model name
data_path = '/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv'
# data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv'
# Ensure to include 'ships_idx' in the fields list
@ -97,7 +98,6 @@ test_dataset = Dataset.from_list(process_df(df))
# from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration
from transformers import FlaxT5ForConditionalGeneration
# model_name_or_path = "./t5_80_1" # Replace with your specific model name
model_name_or_path = "./model_checkpoints/simple_test" # Replace with your specific model name
model = FlaxT5ForConditionalGeneration.from_pretrained(model_name_or_path)
params = model.params
@ -275,11 +275,12 @@ df['p_property_correct'] = df['p_property'] == df['property']
print("thing accuracy", sum(df['p_thing_correct'])/len(df))
print("property accuracy", sum(df['p_property_correct'])/len(df))
print("total accuracy", sum(df['p_property_correct'] & df['p_thing_correct'])/len(df))
# %%
df[~df["p_property_correct"]]
# %%
df['p_thing']
# df[~df["p_property_correct"]]
# %%
# df['p_thing']
# %%
# Save the DataFrame as a Parquet file (using pyarrow or fastparquet)
# df.to_parquet("exports/output_file.parquet", engine="pyarrow") # or use engine="fastparquet"

View File

@ -0,0 +1,610 @@
# %%
import os
# Set this to True to run the model on CPU only.
USE_CPU_ONLY = False
flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices
# Enforce CPU-only execution
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["JAX_PLATFORMS"] = "cpu"
else:
# GPU flags
flags = (
'--xla_gpu_enable_triton_softmax_fusion=true '
'--xla_gpu_triton_gemm_any=True '
# '--xla_gpu_enable_async_collectives=true '
'--xla_gpu_enable_latency_hiding_scheduler=true '
'--xla_gpu_enable_highest_priority_async_stream=true '
)
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
os.environ["XLA_FLAGS"] = flags
os.environ.update({
"TOKENIZERS_PARALLELISM" : "false",
"CUDA_DEVICE_MAX_CONNECTIONS" : "1",
"NCCL_LL128_BUFFSIZE": "-2",
"NCCL_LL_BUFFSIZE": "-2",
"NCCL_PROTO": "SIMPLE,LL,LL128",
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.90",
# "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
})
import functools
from functools import partial
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence, Dict, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
# from jax.experimental.pjit import pjit # superseded by jax.jit
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec
from ml_collections import ConfigDict
import optax
import logging
import time
from datasets import Dataset, load_from_disk
from flax import jax_utils, traverse_util
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.core.frozen_dict import freeze, unfreeze, FrozenDict
import flax.core
# model checkpointing and saving utilities
from flax import linen as nn
from flax.training import checkpoints, train_state
from flax import struct, serialization
from parallel.partitions import set_partitions
from tqdm import tqdm
from parallel.dataload import DataPrepare
# for memory tracking
# from jax_smi import initialise_tracking
# initialise_tracking()
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
if USE_CPU_ONLY:
jax.config.update('jax_platform_name', 'cpu')
else:
jax.config.update("jax_default_matmul_precision", "bfloat16")
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
# %%
## get platform type
from jax.extend.backend import get_backend
print(get_backend().platform)
print(jax.devices())
# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/'
save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/simple_test/'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])
# Store some constant
seed = 117
num_epochs = 5
batch_size = 64
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
warmup_steps = 0
learning_rate = 5e-5
weight_decay = 0.01
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0
num_beams = 1
val_max_target_length = 128
predict_with_generate = True
# %%
# prepare data
# init object
# e.g. Config
print("preparing data")
data_config = ConfigDict(
dict(
max_length=128,
pad_token_id=0,
decoder_start_token_id=0
)
)
dataprep = DataPrepare(split_datasets['train'], data_config)
# # example usage
# # %%
seed = 117
rng = jax.random.PRNGKey(seed)
train_loader = dataprep.data_loader(rng, batch_size=batch_size)
batch = next(iter(train_loader))
# %%
# from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom
from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model
main_model = custom_model.from_pretrained(
"t5-base",
dtype=jnp.bfloat16,
gradient_checkpointing=True,
)
params = main_model.params
# pretrained_params = model.params
model = main_model.module
# %%
# # testing config hashability
# # some explanation:
# # The PreTrainedModel class loads a T5Config model that is not hashable because
# # it is a complicated class that pretends to be a dataclass.
# # The solution is to extract a dict from it, then make a ConfigDict from
# # ml_collections library so that we can get values via the "." operator.
# # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to
# # modify the config before passing to the next layer
# from transformers import T5Config
# from t5_model.configuration_t5 import FrozenT5Config
# from ml_collections import ConfigDict, FrozenConfigDict
#
# config = T5Config.from_pretrained("t5-base").to_dict()
# config.pop('architectures')
# config.pop('id2label')
# # test if it works
# frozen_config = FrozenConfigDict(config)
# # test hash
# hash(frozen_config)
# %%
# # print model
# rng, input_rng = jax.random.split(rng)
# model.tabulate(
# input_rng,
# input_ids=batch['input_ids'],
# attention_mask=batch['attention_mask'],
# decoder_input_ids=batch['decoder_input_ids'],
# decoder_attention_mask=batch['decoder_attention_mask'],
# console_kwargs={"force_jupyter": True}
# )
# %%
# create mesh
print("creating mesh")
device_mesh = mesh_utils.create_device_mesh((2,2))
print(device_mesh)
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
return NamedSharding(mesh, pspec, memory_kind="device")
x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis
model_sharding=mesh_sharding(PartitionSpec(None, 'model'))
# %%
# optimizers
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
return schedule_fn
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
training_size,
train_batch_size,
num_train_epochs,
warmup_steps,
learning_rate,
)
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=adam_beta1,
b2=adam_beta2,
eps=adam_epsilon,
weight_decay=weight_decay,
mask=decay_mask_fn,
)
# %%
print("compile")
# enable bf16
# enable only for dense, some transformer sections, and shared
def create_mask_for_layer_norm(params):
flat_params = traverse_util.flatten_dict(params)
mask = {
# path: not (
# (path[-2] == "layer_norm" and path[-1] == "weight") or
# (path[-2] == "final_layer_norm" and path[-1] == "weight") or
# (path[-2] == "o" and path[-1] == "kernel")
# )
# for path in flat_params
path: (
(path[-2] == "wi" and path[-1] == "weight") or
(path[-2] == "wo" and path[-1] == "weight") or
(path[-2] == "k" and path[-1] == "kernel") or
(path[-2] == "q" and path[-1] == "kernel") or
(path[-2] == "v" and path[-1] == "kernel") or
(path[-2] == "shared" and path[-1] == "embedding")
) for path in flat_params
}
mask = traverse_util.unflatten_dict(mask)
return mask
# borrowed from transformers modeling_flax_utils
def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
"""
Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
"""
# taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
def conditional_cast(param):
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
param = param.astype(dtype)
return param
if mask is None:
return jax.tree_util.tree_map(conditional_cast, params)
flat_params = traverse_util.flatten_dict(params)
flat_mask, _ = jax.tree_util.tree_flatten(mask)
for masked, key in zip(flat_mask, sorted(flat_params.keys())):
if masked:
flat_params[key] = conditional_cast(flat_params[key])
return traverse_util.unflatten_dict(flat_params)
# create init_fn to produce sharded state
def init_fn(params, model, optimizer):
# do be careful with the model init
# imported models might have complicated init methods
# mask = create_mask_for_layer_norm(params)
# override params with bfloat version
# params= cast_floating_to(params, jnp.bfloat16, mask)
state = train_state.TrainState.create( # Create a `TrainState`.
apply_fn=model.apply,
params=params,
tx=optimizer)
return state
abstract_variables = jax.eval_shape(
functools.partial(init_fn, model=model, optimizer=adamw), params)
# jax.sharding: describes how a jax.Array is laid out across devices
state_sharding = nn.get_sharding(abstract_variables, mesh)
# print(state_sharding)
# %%
# replace the params tree with the new modified tree
# create partitions for model
from parallel.partitions import set_partitions
# set_partitions freezes the params on return
model_part_spec = set_partitions(unfreeze(params))
# p is already a partition spec
model_named_sharding = jax.tree.map(lambda p: mesh_sharding(p), model_part_spec)
# get pspec for opt_state
def get_opt_spec(x):
if isinstance(x, dict):
return unfreeze(model_named_sharding)
# return an empty partspec
return mesh_sharding((PartitionSpec()))
# this function replaces the empty model params spec with the 'model_named_shard'
state_sharding = jax.tree.map(
get_opt_spec, state_sharding, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
)
jit_init_fn = jax.jit(
init_fn,
static_argnames=('model', 'optimizer'), # skip model and optimizer
in_shardings=mesh_sharding(PartitionSpec()), # we don't shard params explicitly
out_shardings=state_sharding # but returned initialized_state is sharded
)
initialized_state = jit_init_fn(params, model, adamw)
# %%
# train step
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
"""
The label smoothing implementation is adapted from Flax's official example:
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
"""
vocab_size = logits.shape[-1]
confidence = 1.0 - label_smoothing_factor
low_confidence = (1.0 - confidence) / (vocab_size - 1)
normalizing_constant = -(
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
)
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
logits = jnp.asarray(logits, dtype=jnp.float32)
logits = logits.astype(jnp.float32)
soft_labels = soft_labels.astype(jnp.float32)
loss = optax.softmax_cross_entropy(logits, soft_labels)
loss = loss - normalizing_constant
# ignore padded tokens from loss
loss = loss * padding_mask
loss = loss.mean()
# num_labels = padding_mask.mean()
return loss # , num_labels
# %%
# gradient accumulation
def accumulate_gradients_loop(
state,
batch,
minibatch_size: int,
loss_fn: Callable,
) -> Tuple[PyTree, Metrics]:
"""Calculate gradients and metrics for a batch using gradient accumulation.
Args:
state: Current training state.
batch: Full training batch.
rng: Random number generator to use.
num_minibatches: Number of minibatches to split the batch into. Equal to the number of gradient accumulation steps.
loss_fn: Loss function to calculate gradients and metrics.
Returns:
Tuple with accumulated gradients and metrics over the minibatches.
"""
batch_size = batch['input_ids'].shape[0]
# minibatch_size = batch_size // num_minibatches
num_minibatches = batch_size // minibatch_size
# Define gradient function for single minibatch.
# If has_aux is True then a tuple of ((value, auxiliary_data), gradient) is returned.
# otherwise it returns (value, gradient), where value is the actual output
# of the function, hence the "value" of the namesake
grad_fn = jax.value_and_grad(loss_fn, has_aux=False)
# Prepare loop variables.
grads = None
metrics = None
for minibatch_idx in range(num_minibatches):
with jax.named_scope(f"minibatch_{minibatch_idx}"):
# Split the batch into minibatches.
start = minibatch_idx * minibatch_size
end = start + minibatch_size
minibatch = jax.tree.map(lambda x: x[start:end], batch) # noqa: B023
# Calculate gradients and metrics for the minibatch.
# missing value is mean loss of batch
loss, step_grads = grad_fn(
state.params, minibatch
)
with jax.named_scope("sync_metrics"):
step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
# Accumulate gradients and metrics across minibatches.
if grads is None:
grads = step_grads
metrics = step_metrics
else:
# accumulation adder
grads = jax.tree.map(jnp.add, grads, step_grads)
metrics = jax.tree.map(jnp.add, metrics, step_metrics)
# Average gradients over minibatches.
grads = jax.tree.map(lambda g: g / num_minibatches, grads)
return grads, metrics
# single device code annotated with jax.jit
@functools.partial(
jax.jit,
# state is state_sharding initialized from init_fn
# x_sharding is data sharded explicitly later
in_shardings=(state_sharding, x_sharding),
out_shardings=(state_sharding, mesh_sharding(PartitionSpec())),
donate_argnames=('state'),
)
def train_step(state, batch):
label_smoothing_factor=0.0
# dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
def compute_loss(params, batch):
# check constraints
# frozen dict not allowed as sharding object
params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding))
batch = jax.lax.with_sharding_constraint(batch, x_sharding)
logits = state.apply_fn(
{'params': params},
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
decoder_input_ids=batch['decoder_input_ids'],
decoder_attention_mask=batch['decoder_attention_mask'],
)[0] # zero because output is some structure, where first is the logit
# use labels here
# loss, num_labels = loss_fn(
loss = loss_fn(
logits,
batch["labels"],
batch["decoder_attention_mask"],
label_smoothing_factor)
return loss # , num_labels
# # compute gradients through computational graph
# # allow values to pass through
# grad_fn = jax.value_and_grad(compute_loss, has_aux=False)
# (loss), grad = grad_fn(state.params, batch)
# # num_labels = jax.lax.psum(num_labels, "batch")
# new_state = state.apply_gradients(grads=grad)
# with jax.named_scope("sync_metrics"):
# step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
# use gradient accumulation
grads, step_metrics = accumulate_gradients_loop(
state=state,
batch=batch,
minibatch_size=32,
loss_fn=compute_loss
)
new_state = state.apply_gradients(grads=grads)
return new_state, step_metrics
# %%
# explore data sharding
sharded_batch = next(iter(train_loader))
sharded_batch = jax.device_put(sharded_batch, x_sharding)
# jax.debug.visualize_array_sharding(sharded_batch['input_ids'])
# jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding'])
# %%
# # prep 1 step
# print("1 step for jit-ting")
# with mesh:
# state, metrics = train_step(initialized_state, sharded_batch)
# %%
# %%
# tr
print("***** Running training *****")
print(f" Num examples = {training_size}")
print(f" Num Epochs = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
print(f" Total optimization steps = {total_train_steps}")
# %%
# jax.profiler.start_trace("./traces")
print("*" * 10)
print("training start")
rng, input_rng = jax.random.split(rng)
train_time = 0
state = initialized_state
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
train_start = time.time()
# Create sampling rng
train_metrics = []
steps_per_epoch = training_size // train_batch_size
train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True)
# Generate an epoch by shuffling sampling indices from the train dataset
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
batch = next(train_loader)
batch = jax.device_put(batch, x_sharding)
with mesh:
state, train_metric = train_step(state, batch)
# train_metrics.append(train_metric)
# this is for more accurate time stats, but slows down training
# train_metric['loss'].block_until_ready()
train_time = time.time() - train_start
epochs.write(
f"Epoch... ({epoch + 1}/{num_epochs} | "
f"Loss: {train_metric['loss']}, "
f"Learning Rate:{train_metric['learning_rate']}, "
f"Last train time: {train_time})"
)
# jax.profiler.stop_trace()
# %%
# try out
gather_state = jax.device_get(state)
gather_batch = jax.device_get(batch)
logits = gather_state.apply_fn(
{'params': gather_state.params},
input_ids=gather_batch['input_ids'],
attention_mask=gather_batch['attention_mask'],
decoder_input_ids=gather_batch['decoder_input_ids'],
decoder_attention_mask=gather_batch['decoder_attention_mask'],
)[0] # zero because output is some structure, where first is the logit
probs = nn.softmax(logits, axis=-1)
predicted = jnp.argmax(probs, axis=-1)
print("sample output")
print(predicted[1])
# %%
main_model = custom_model.from_pretrained('t5-base')
output_dir = save_path
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(state.params)
main_model.save_pretrained(output_dir, params=params)
# %%

550
t5_jax_shmap.py Normal file
View File

@ -0,0 +1,550 @@
# %%
import os
# Set this to True to run the model on CPU only.
USE_CPU_ONLY = False
flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices
# Enforce CPU-only execution
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["JAX_PLATFORMS"] = "cpu"
else:
# GPU flags
flags = (
'--xla_gpu_enable_triton_softmax_fusion=true '
'--xla_gpu_triton_gemm_any=True '
# '--xla_gpu_enable_async_collectives=true '
'--xla_gpu_enable_latency_hiding_scheduler=true '
'--xla_gpu_enable_highest_priority_async_stream=true '
)
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
os.environ["XLA_FLAGS"] = flags
os.environ.update({
"TOKENIZERS_PARALLELISM" : "false",
"CUDA_DEVICE_MAX_CONNECTIONS" : "1",
"NCCL_LL128_BUFFSIZE": "-2",
"NCCL_LL_BUFFSIZE": "-2",
"NCCL_PROTO": "SIMPLE,LL,LL128",
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.5",
# "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
})
import functools
from functools import partial
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence, Dict, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
# from jax.experimental.pjit import pjit # superseded by jax.jit
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec
from ml_collections import ConfigDict
import optax
import logging
import time
from datasets import Dataset, load_from_disk
from flax import jax_utils, traverse_util
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.core.frozen_dict import freeze, unfreeze, FrozenDict
import flax.core
# model checkpointing and saving utilities
from flax import linen as nn
from flax.training import checkpoints, train_state
from flax import struct, serialization
from parallel.partitions import set_partitions
from tqdm import tqdm
from parallel.dataload import DataPrepare
# for memory tracking
# from jax_smi import initialise_tracking
# initialise_tracking()
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
if USE_CPU_ONLY:
jax.config.update('jax_platform_name', 'cpu')
else:
jax.config.update("jax_default_matmul_precision", "bfloat16")
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
# %%
## get platform type
from jax.extend.backend import get_backend
print(get_backend().platform)
print(jax.devices())
# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/'
save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/shmap/'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])
# Store some constant
seed = 117
num_epochs = 40
batch_size = 32 # do not go beyond 128, 64 is good
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
warmup_steps = 0
learning_rate = 2e-5
weight_decay = 0.01
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0
num_beams = 1
val_max_target_length = 128
predict_with_generate = True
# %%
# prepare data
# init object
# e.g. Config
print("preparing data")
data_config = ConfigDict(
dict(
max_length=128,
pad_token_id=0,
decoder_start_token_id=0
)
)
dataprep = DataPrepare(split_datasets['train'], data_config)
# # example usage
# # %%
seed = 117
rng = jax.random.PRNGKey(seed)
train_loader = dataprep.data_loader(rng, batch_size=batch_size)
batch = next(iter(train_loader))
# %%
# from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom
from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model
main_model = custom_model.from_pretrained(
"t5-base",
dtype=jnp.bfloat16,
gradient_checkpointing=True,
)
params = main_model.params
# pretrained_params = model.params
model = main_model.module
# %%
# # testing config hashability
# # some explanation:
# # The PreTrainedModel class loads a T5Config model that is not hashable because
# # it is a complicated class that pretends to be a dataclass.
# # The solution is to extract a dict from it, then make a ConfigDict from
# # ml_collections library so that we can get values via the "." operator.
# # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to
# # modify the config before passing to the next layer
# from transformers import T5Config
# from t5_model.configuration_t5 import FrozenT5Config
# from ml_collections import ConfigDict, FrozenConfigDict
#
# config = T5Config.from_pretrained("t5-base").to_dict()
# config.pop('architectures')
# config.pop('id2label')
# # test if it works
# frozen_config = FrozenConfigDict(config)
# # test hash
# hash(frozen_config)
# %%
# # print model
# rng, input_rng = jax.random.split(rng)
# model.tabulate(
# input_rng,
# input_ids=batch['input_ids'],
# attention_mask=batch['attention_mask'],
# decoder_input_ids=batch['decoder_input_ids'],
# decoder_attention_mask=batch['decoder_attention_mask'],
# console_kwargs={"force_jupyter": True}
# )
# %%
# create mesh
print("creating mesh")
device_mesh = mesh_utils.create_device_mesh((2,2))
print(device_mesh)
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
if USE_CPU_ONLY:
return NamedSharding(mesh, pspec, memory_kind="unpinned_host")
else:
# if gpu
return NamedSharding(mesh, pspec, memory_kind="device")
x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis
model_sharding=mesh_sharding(PartitionSpec(None, 'model'))
# %%
# optimizers
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
return schedule_fn
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
training_size,
train_batch_size,
num_train_epochs,
warmup_steps,
learning_rate,
)
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["final_layer_norm", "layer_norm"]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=adam_beta1,
b2=adam_beta2,
eps=adam_epsilon,
weight_decay=weight_decay,
mask=decay_mask_fn,
)
# %%
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
"""
The label smoothing implementation is adapted from Flax's official example:
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
"""
vocab_size = logits.shape[-1]
confidence = 1.0 - label_smoothing_factor
low_confidence = (1.0 - confidence) / (vocab_size - 1)
normalizing_constant = -(
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
)
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
logits = jnp.asarray(logits, dtype=jnp.float32)
logits = logits.astype(jnp.float32)
soft_labels = soft_labels.astype(jnp.float32)
loss = optax.softmax_cross_entropy(logits, soft_labels)
loss = loss - normalizing_constant
# ignore padded tokens from loss
loss = loss * padding_mask
mean_loss = loss.mean()
# num_labels = padding_mask.mean()
return mean_loss # , num_labels
# %%
################################################################
# old jit in_shardings method
# create init_fn to produce sharded state
def init_fn(params, model, optimizer):
# do be careful with the model init
# imported models might have complicated init methods
# mask = create_mask_for_layer_norm(params)
# override params with bfloat version
# params= cast_floating_to(params, jnp.bfloat16, mask)
state = train_state.TrainState.create( # Create a `TrainState`.
apply_fn=model.apply,
params=params,
tx=optimizer)
return state
abstract_variables = jax.eval_shape(
functools.partial(init_fn, model=model, optimizer=adamw), params)
# jax.sharding: describes how a jax.Array is laid out across devices
state_sharding = nn.get_sharding(abstract_variables, mesh)
from parallel.partitions import set_partitions
# set_partitions freezes the params on return
model_part_spec = set_partitions(unfreeze(params))
# p is already a partition spec
model_named_sharding = jax.tree.map(lambda p: mesh_sharding(p), model_part_spec)
# get pspec for opt_state
def get_opt_spec(x):
if isinstance(x, dict):
return unfreeze(model_named_sharding)
# return an empty partspec
return mesh_sharding((PartitionSpec()))
# this function replaces the empty model params spec with the 'model_named_shard'
state_sharding = jax.tree.map(
get_opt_spec, state_sharding, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
)
# %%
jit_init_fn = jax.jit(
init_fn,
static_argnames=('model', 'optimizer'), # skip model and optimizer
in_shardings=mesh_sharding(PartitionSpec()), # we don't shard params explicitly
out_shardings=state_sharding # but returned initialized_state is sharded
)
initialized_state = jit_init_fn(params, model, adamw)
# %%
# train step
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
"""
The label smoothing implementation is adapted from Flax's official example:
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
"""
vocab_size = logits.shape[-1]
confidence = 1.0 - label_smoothing_factor
low_confidence = (1.0 - confidence) / (vocab_size - 1)
normalizing_constant = -(
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
)
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
logits = jnp.asarray(logits, dtype=jnp.float32)
logits = logits.astype(jnp.float32)
soft_labels = soft_labels.astype(jnp.float32)
loss = optax.softmax_cross_entropy(logits, soft_labels)
loss = loss - normalizing_constant
# ignore padded tokens from loss
loss = loss * padding_mask
loss = loss.mean()
# num_labels = padding_mask.mean()
return loss # , num_labels
# %%
# single device code annotated with jax.jit
@functools.partial(
jax.jit,
# state is state_sharding initialized from init_fn
# x_sharding is data sharded explicitly later
in_shardings=(state_sharding, x_sharding),
out_shardings=(state_sharding, mesh_sharding(PartitionSpec())),
donate_argnames=('state'),
)
def train_step(state, batch):
label_smoothing_factor=0.0
# dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
# computes loss per shard
def compute_loss(params, batch):
# check constraints
# frozen dict not allowed as sharding object
params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding))
batch = jax.lax.with_sharding_constraint(batch, x_sharding)
logits = state.apply_fn(
{'params': params},
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
decoder_input_ids=batch['decoder_input_ids'],
decoder_attention_mask=batch['decoder_attention_mask'],
)[0] # zero because output is some structure, where first is the logit
# logits sharding
# data, None, model
#
print("logits")
jax.debug.inspect_array_sharding(logits, callback=print)
# use labels here
# loss, num_labels = loss_fn(
loss = loss_fn(
logits,
batch["labels"],
batch["decoder_attention_mask"],
label_smoothing_factor)
# loss sharding
# it gives PartitionSpec(), which implies a reduction already happened
print("loss")
jax.debug.inspect_array_sharding(loss, callback=print)
return loss # , num_labels
# compute gradients through computational graph
# allow values to pass through
grad_fn = jax.value_and_grad(compute_loss, has_aux=False)
batch = jax.tree.map(lambda x: jax.lax.with_sharding_constraint(x, x_sharding), batch)
(loss), grads = grad_fn(state.params, batch)
# num_labels = jax.lax.psum(num_labels, "batch")
# so far we have been operating from within each shard
# we need to sync gradients across devices
# we bring all gradients together onto a single device
# jax.debug.inspect_array_sharding(grads, callback=print)
grads = jax.lax.with_sharding_constraint(grads, mesh_sharding(PartitionSpec()))
# grads = jax.lax.with_sharding_constraint(grads, state_sharding)
# jax.debug.visualize_array_sharding(grad)
# jax.debug.inspect_array_sharding(grad, callback=print)
# check the output grad tree from mean
# print(jax.tree.map(jnp.shape, grad))
new_state = state.apply_gradients(grads=grads)
with jax.named_scope("sync_metrics"):
step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
return new_state, step_metrics
# %%
# explore data sharding
sharded_batch = next(iter(train_loader))
# sharded_batch = jax.device_put(sharded_batch, x_sharding)
sharded_batch = jax.tree.map(lambda x: jax.lax.with_sharding_constraint(x, x_sharding), batch)
jax.debug.visualize_array_sharding(sharded_batch['input_ids'])
# jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding'])
# %%
# # prep 1 step
print("1 step for jit-ting")
with mesh:
state, metrics = train_step(initialized_state, sharded_batch)
# %%
# %%
# tr
print("***** Running training *****")
print(f" Num examples = {training_size}")
print(f" Num Epochs = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
print(f" Total optimization steps = {total_train_steps}")
# %%
# jax.profiler.start_trace("./traces")
print("*" * 20)
print("training start")
rng, input_rng = jax.random.split(rng)
train_time = 0
state = initialized_state
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
train_start = time.time()
# Create sampling rng
train_metrics = []
steps_per_epoch = training_size // train_batch_size
train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True)
# Generate an epoch by shuffling sampling indices from the train dataset
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
batch = next(train_loader)
batch = jax.device_put(batch, x_sharding)
with mesh:
state, train_metric = train_step(state, batch)
# train_metrics.append(train_metric)
# this is for more accurate time stats, but slows down training
# train_metric['loss'].block_until_ready()
train_time = time.time() - train_start
epochs.write(
f"Epoch... ({epoch + 1}/{num_epochs} | "
f"Loss: {train_metric['loss']}, "
f"Learning Rate:{train_metric['learning_rate']}, "
f"Last train time: {train_time})"
)
# jax.profiler.stop_trace()
# %%
# try out
# gather_state = jax.device_get(state)
# gather_batch = jax.device_get(batch)
# logits = gather_state.apply_fn(
# {'params': gather_state.params},
# input_ids=gather_batch['input_ids'],
# attention_mask=gather_batch['attention_mask'],
# decoder_input_ids=gather_batch['decoder_input_ids'],
# decoder_attention_mask=gather_batch['decoder_attention_mask'],
# )[0] # zero because output is some structure, where first is the logit
#
# probs = nn.softmax(logits, axis=-1)
# predicted = jnp.argmax(probs, axis=-1)
# print(predicted[0])
# %%
main_model = custom_model.from_pretrained('t5-base')
output_dir = save_path
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(state.params)
params = jax.tree.map(lambda x: x.astype(jnp.float32), params)
main_model.save_pretrained(output_dir, params=params)
# %%

View File

@ -27,7 +27,7 @@ os.environ.update({
"NCCL_LL128_BUFFSIZE": "-2",
"NCCL_LL_BUFFSIZE": "-2",
"NCCL_PROTO": "SIMPLE,LL,LL128",
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.90",
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.80",
# "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
})
@ -89,35 +89,50 @@ jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
# %%
## get platform type
from jax.extend.backend import get_backend
print(get_backend().platform)
print(jax.devices())
# %%
# create mesh
print("creating mesh")
device_mesh = mesh_utils.create_device_mesh((4,1))
print(device_mesh)
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
return NamedSharding(mesh, pspec, memory_kind="device")
x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis
model_sharding=mesh_sharding(PartitionSpec(None, 'model'))
# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/'
save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/simple_test/'
save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/simple/'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])
# Store some constant
seed = 117
num_epochs = 5
batch_size = 64
num_epochs = 40
batch_size = 128
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
train_batch_size = per_device_train_batch_size * 2
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
eval_batch_size = per_device_eval_batch_size * 2
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
warmup_steps = 0
learning_rate = 5e-5
learning_rate = 2e-3
weight_decay = 0.01
adam_beta1 = 0.9
@ -197,21 +212,6 @@ model = main_model.module
# %%
# create mesh
print("creating mesh")
device_mesh = mesh_utils.create_device_mesh((2,2))
print(device_mesh)
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
return NamedSharding(mesh, pspec, memory_kind="device")
x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis
model_sharding=mesh_sharding(PartitionSpec(None, 'model'))
# %%
# optimizers
@ -246,7 +246,7 @@ linear_decay_lr_schedule_fn = create_learning_rate_fn(
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_candidates = ["final_layer_norm", "layer_norm"]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
@ -322,9 +322,9 @@ def init_fn(params, model, optimizer):
# do be careful with the model init
# imported models might have complicated init methods
# mask = create_mask_for_layer_norm(params)
mask = create_mask_for_layer_norm(params)
# override params with bfloat version
# params= cast_floating_to(params, jnp.bfloat16, mask)
params= cast_floating_to(params, jnp.bfloat16, mask)
state = train_state.TrainState.create( # Create a `TrainState`.
apply_fn=model.apply,
@ -449,8 +449,8 @@ def train_step(state, batch):
# %%
# explore data sharding
sharded_batch = next(iter(train_loader))
sharded_batch = jax.device_put(sharded_batch, x_sharding)
# sharded_batch = next(iter(train_loader))
# sharded_batch = jax.device_put(sharded_batch, x_sharding)
# jax.debug.visualize_array_sharding(sharded_batch['input_ids'])
# jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding'])
@ -477,7 +477,7 @@ print(f" Total optimization steps = {total_train_steps}")
# jax.profiler.start_trace("./traces")
print("*" * 10)
print("*" * 50)
print("training start")
rng, input_rng = jax.random.split(rng)
train_time = 0
@ -489,7 +489,7 @@ for epoch in epochs:
# Create sampling rng
train_metrics = []
steps_per_epoch = training_size // train_batch_size
train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True)
train_loader = dataprep.data_loader(rng, batch_size=train_batch_size, shuffle=True, drop_last=True)
# Generate an epoch by shuffling sampling indices from the train dataset
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
batch = next(train_loader)
@ -514,21 +514,21 @@ for epoch in epochs:
)
# jax.profiler.stop_trace()
# %%
# try out
gather_state = jax.device_get(state)
gather_batch = jax.device_get(batch)
logits = gather_state.apply_fn(
{'params': gather_state.params},
input_ids=gather_batch['input_ids'],
attention_mask=gather_batch['attention_mask'],
decoder_input_ids=gather_batch['decoder_input_ids'],
decoder_attention_mask=gather_batch['decoder_attention_mask'],
)[0] # zero because output is some structure, where first is the logit
probs = nn.softmax(logits, axis=-1)
predicted = jnp.argmax(probs, axis=-1)
predicted[1]
# # %%
# # try out
# gather_state = jax.device_get(state)
# gather_batch = jax.device_get(batch)
# logits = gather_state.apply_fn(
# {'params': gather_state.params},
# input_ids=gather_batch['input_ids'],
# attention_mask=gather_batch['attention_mask'],
# decoder_input_ids=gather_batch['decoder_input_ids'],
# decoder_attention_mask=gather_batch['decoder_attention_mask'],
# )[0] # zero because output is some structure, where first is the logit
#
# probs = nn.softmax(logits, axis=-1)
# predicted = jnp.argmax(probs, axis=-1)
# print(predicted[0])
# %%
main_model = custom_model.from_pretrained('t5-base')
@ -537,6 +537,7 @@ output_dir = save_path
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(state.params)
params = jax.tree.map(lambda x: x.astype(jnp.float32), params)
main_model.save_pretrained(output_dir, params=params)

1
t5_model/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
__pycache__

View File

@ -0,0 +1,118 @@
# coding=utf-8
# Copyright 2020, The T5 Authors and HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""T5 model configuration"""
from typing import Mapping
from transformers import PretrainedConfig
from transformers import logging
from etils import edc
logger = logging.get_logger(__name__)
class T5Config(PretrainedConfig):
model_type = "t5"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
def __init__(
self,
vocab_size=32128, # vocab size here
d_model=512,
d_kv=64,
d_ff=2048,
num_layers=6,
num_decoder_layers=None,
num_heads=8,
relative_attention_num_buckets=32,
relative_attention_max_distance=128,
dropout_rate=0.1,
layer_norm_epsilon=1e-6,
initializer_factor=1.0,
feed_forward_proj="relu",
is_encoder_decoder=True,
use_cache=True,
pad_token_id=0,
eos_token_id=1,
classifier_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.d_model = d_model
self.d_kv = d_kv
self.d_ff = d_ff
self.num_layers = num_layers
self.num_decoder_layers = (
num_decoder_layers if num_decoder_layers is not None else self.num_layers
) # default = symmetry
self.num_heads = num_heads
self.relative_attention_num_buckets = relative_attention_num_buckets
self.relative_attention_max_distance = relative_attention_max_distance
self.dropout_rate = dropout_rate
self.classifier_dropout = classifier_dropout
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_factor = initializer_factor
self.feed_forward_proj = feed_forward_proj
self.use_cache = use_cache
self.use_bfloat16 = True
act_info = self.feed_forward_proj.split("-")
self.dense_act_fn = act_info[-1]
self.is_gated_act = act_info[0] == "gated"
if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
raise ValueError(
f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. "
"Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
"'gated-gelu' or 'relu'"
)
# for backwards compatibility
if feed_forward_proj == "gated-gelu":
self.dense_act_fn = "gelu_new"
super().__init__(
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
**kwargs,
)
# class T5OnnxConfig(OnnxSeq2SeqConfigWithPast):
# @property
# def inputs(self) -> Mapping[str, Mapping[int, str]]:
# common_inputs = {
# "input_ids": {0: "batch", 1: "encoder_sequence"},
# "attention_mask": {0: "batch", 1: "encoder_sequence"},
# }
# if self.use_past:
# common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
# common_inputs["decoder_input_ids"] = {0: "batch"}
# common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
# else:
# common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
# common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
#
# if self.use_past:
# self.fill_with_past_key_values_(common_inputs, direction="inputs")
#
# return common_inputs
#
# @property
# def default_onnx_opset(self) -> int:
# return 13

1832
t5_model/modeling_t5_flax.py Normal file

File diff suppressed because it is too large Load Diff