Feat: t5_jax_simple_parallel implements a working example of fsdp

This commit is contained in:
Richard Wong 2024-09-20 23:42:51 +09:00
parent 429e1742ab
commit aca80720c8
14 changed files with 5244 additions and 189 deletions

1
.gitignore vendored
View File

@ -1,5 +1,6 @@
*.ipynb
t5_*/
model_checkpoints/
exports/
modified_t5_model/
traces/

1
parallel/.gitignore vendored
View File

@ -1 +1,2 @@
__pycache__
gpt-neo-125m/

View File

@ -17,17 +17,6 @@ file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retri
# training_size = len(split_datasets['train'])
from transformers import T5TokenizerFast
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
# Define additional special tokens
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
# Add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right") # noqa: B009
# class takes in a dataset
class DataPrepare():
@ -37,6 +26,19 @@ class DataPrepare():
self.train_dataset: Optional[Dataset] = None
self.size: int = len(raw_dataset)
self.config: ConfigDict = config
self.tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=False)
# Define additional special tokens
# additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "<SIG>", "<UNIT>", "<DATA_TYPE>"]
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
# Add the additional special tokens to the tokenizer
self.tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])
self.shift_tokens_right_fn = getattr(model_module, "shift_tokens_right") # noqa: B009
self.make_dataset()
@ -52,31 +54,37 @@ class DataPrepare():
# text_target sets the corresponding label to inputs
# there is no need to create a separate 'labels'
# produce input_ids and decoder_input_ids
model_inputs = tokenizer(
model_inputs = self.tokenizer(
inputs,
max_length=self.config.max_length,
padding="max_length",
truncation=True,
return_tensors="np"
)
labels = tokenizer(
# we separate it out because we need the attention mask
labels = self.tokenizer(
text_target=targets,
max_length=self.config.max_length,
padding="max_length",
truncation=True,
return_tensors="np"
)
model_inputs['input_ids'] = np.asarray(model_inputs['input_ids'])
model_inputs['attention_mask'] = np.asarray(model_inputs['attention_mask'])
# for loss computation
model_inputs["labels"] = labels["input_ids"]
# make decoder input ids
decoder_input_ids = shift_tokens_right_fn(
# this is actually "model output" shifted right
decoder_input_ids = self.shift_tokens_right_fn(
labels["input_ids"], self.config.pad_token_id, self.config.decoder_start_token_id
)
# require by model
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
# We need decoder_attention_mask so we can ignore pad tokens from loss
model_inputs["decoder_attention_mask"] = labels["attention_mask"]
# decoder_attention_mask = shift_tokens_right_fn(
# labels["attention_mask"], self.config.pad_token_id, self.config.decoder_start_token_id
# )
# We need decoder_attention_mask so we can ignore pad tokens in loss
model_inputs["decoder_attention_mask"] = np.asarray(labels["attention_mask"])
return model_inputs
@ -89,13 +97,13 @@ class DataPrepare():
remove_columns=self.raw_dataset.column_names,)
# set to numpy
train_dataset.set_format(
type='numpy',
columns=[
'input_ids', 'attention_mask',
'labels', 'decoder_input_ids',
'decoder_attention_mask']
)
# train_dataset.set_format(
# type='numpy',
# columns=[
# 'input_ids', 'attention_mask', 'labels',
# 'decoder_input_ids',
# 'decoder_attention_mask']
# )
# check that data fits
# for name in ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'decoder_attention_mask']:
@ -140,15 +148,15 @@ class DataPrepare():
for idx in batch_idx:
batch = dataset[idx]
batch = {k: v for k, v in batch.items()}
batch = {k: np.array(v) for k, v in batch.items()}
yield batch
# testing out the class
# %%
# init object
# e.g. Config
# # testing out the class
# # %%
# # init object
# # e.g. Config
# data_config = ConfigDict(
# dict(
# max_length=86,
@ -172,3 +180,13 @@ class DataPrepare():
# batch = next(iter(train_loader))
# batch['input_ids'].shape
# # %%
#
# sentence = "<THING_START><THING_END><PROPERTY_START><PROPERTY_END><NAME><DESC><DESC><UNIT>"
# tokens = tokenizer.tokenize(sentence)
# print("Tokens:", tokens)
# # Get the IDs (integer indices) of specific tokens
# token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens]
# print("Token IDs:", token_ids)
#
#
# # %%

View File

@ -242,7 +242,8 @@ def train_step(state, x):
# with mesh: # not strictly necessary in this case
# with mesh block is useful for explicit scope for device sharding
# but mesh management is automatic via jit sharding annotations
new_state = train_step(initialized_state, x)
with mesh:
new_state = train_step(initialized_state, x)
print(f'Sharding of Weight 1:')
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)

854
parallel/gpt-neo-125m.json Normal file
View File

@ -0,0 +1,854 @@
{
"transformer": {
"h": {
"0": {
"attn": {
"attention": {
"k_proj": {
"kernel": [
768,
768
]
},
"out_proj": {
"bias": [
768
],
"kernel": [
768,
768
]
},
"q_proj": {
"kernel": [
768,
768
]
},
"v_proj": {
"kernel": [
768,
768
]
}
}
},
"ln_1": {
"bias": [
768
],
"scale": [
768
]
},
"ln_2": {
"bias": [
768
],
"scale": [
768
]
},
"mlp": {
"c_fc": {
"bias": [
3072
],
"kernel": [
768,
3072
]
},
"c_proj": {
"bias": [
768
],
"kernel": [
3072,
768
]
}
}
},
"1": {
"attn": {
"attention": {
"k_proj": {
"kernel": [
768,
768
]
},
"out_proj": {
"bias": [
768
],
"kernel": [
768,
768
]
},
"q_proj": {
"kernel": [
768,
768
]
},
"v_proj": {
"kernel": [
768,
768
]
}
}
},
"ln_1": {
"bias": [
768
],
"scale": [
768
]
},
"ln_2": {
"bias": [
768
],
"scale": [
768
]
},
"mlp": {
"c_fc": {
"bias": [
3072
],
"kernel": [
768,
3072
]
},
"c_proj": {
"bias": [
768
],
"kernel": [
3072,
768
]
}
}
},
"10": {
"attn": {
"attention": {
"k_proj": {
"kernel": [
768,
768
]
},
"out_proj": {
"bias": [
768
],
"kernel": [
768,
768
]
},
"q_proj": {
"kernel": [
768,
768
]
},
"v_proj": {
"kernel": [
768,
768
]
}
}
},
"ln_1": {
"bias": [
768
],
"scale": [
768
]
},
"ln_2": {
"bias": [
768
],
"scale": [
768
]
},
"mlp": {
"c_fc": {
"bias": [
3072
],
"kernel": [
768,
3072
]
},
"c_proj": {
"bias": [
768
],
"kernel": [
3072,
768
]
}
}
},
"11": {
"attn": {
"attention": {
"k_proj": {
"kernel": [
768,
768
]
},
"out_proj": {
"bias": [
768
],
"kernel": [
768,
768
]
},
"q_proj": {
"kernel": [
768,
768
]
},
"v_proj": {
"kernel": [
768,
768
]
}
}
},
"ln_1": {
"bias": [
768
],
"scale": [
768
]
},
"ln_2": {
"bias": [
768
],
"scale": [
768
]
},
"mlp": {
"c_fc": {
"bias": [
3072
],
"kernel": [
768,
3072
]
},
"c_proj": {
"bias": [
768
],
"kernel": [
3072,
768
]
}
}
},
"2": {
"attn": {
"attention": {
"k_proj": {
"kernel": [
768,
768
]
},
"out_proj": {
"bias": [
768
],
"kernel": [
768,
768
]
},
"q_proj": {
"kernel": [
768,
768
]
},
"v_proj": {
"kernel": [
768,
768
]
}
}
},
"ln_1": {
"bias": [
768
],
"scale": [
768
]
},
"ln_2": {
"bias": [
768
],
"scale": [
768
]
},
"mlp": {
"c_fc": {
"bias": [
3072
],
"kernel": [
768,
3072
]
},
"c_proj": {
"bias": [
768
],
"kernel": [
3072,
768
]
}
}
},
"3": {
"attn": {
"attention": {
"k_proj": {
"kernel": [
768,
768
]
},
"out_proj": {
"bias": [
768
],
"kernel": [
768,
768
]
},
"q_proj": {
"kernel": [
768,
768
]
},
"v_proj": {
"kernel": [
768,
768
]
}
}
},
"ln_1": {
"bias": [
768
],
"scale": [
768
]
},
"ln_2": {
"bias": [
768
],
"scale": [
768
]
},
"mlp": {
"c_fc": {
"bias": [
3072
],
"kernel": [
768,
3072
]
},
"c_proj": {
"bias": [
768
],
"kernel": [
3072,
768
]
}
}
},
"4": {
"attn": {
"attention": {
"k_proj": {
"kernel": [
768,
768
]
},
"out_proj": {
"bias": [
768
],
"kernel": [
768,
768
]
},
"q_proj": {
"kernel": [
768,
768
]
},
"v_proj": {
"kernel": [
768,
768
]
}
}
},
"ln_1": {
"bias": [
768
],
"scale": [
768
]
},
"ln_2": {
"bias": [
768
],
"scale": [
768
]
},
"mlp": {
"c_fc": {
"bias": [
3072
],
"kernel": [
768,
3072
]
},
"c_proj": {
"bias": [
768
],
"kernel": [
3072,
768
]
}
}
},
"5": {
"attn": {
"attention": {
"k_proj": {
"kernel": [
768,
768
]
},
"out_proj": {
"bias": [
768
],
"kernel": [
768,
768
]
},
"q_proj": {
"kernel": [
768,
768
]
},
"v_proj": {
"kernel": [
768,
768
]
}
}
},
"ln_1": {
"bias": [
768
],
"scale": [
768
]
},
"ln_2": {
"bias": [
768
],
"scale": [
768
]
},
"mlp": {
"c_fc": {
"bias": [
3072
],
"kernel": [
768,
3072
]
},
"c_proj": {
"bias": [
768
],
"kernel": [
3072,
768
]
}
}
},
"6": {
"attn": {
"attention": {
"k_proj": {
"kernel": [
768,
768
]
},
"out_proj": {
"bias": [
768
],
"kernel": [
768,
768
]
},
"q_proj": {
"kernel": [
768,
768
]
},
"v_proj": {
"kernel": [
768,
768
]
}
}
},
"ln_1": {
"bias": [
768
],
"scale": [
768
]
},
"ln_2": {
"bias": [
768
],
"scale": [
768
]
},
"mlp": {
"c_fc": {
"bias": [
3072
],
"kernel": [
768,
3072
]
},
"c_proj": {
"bias": [
768
],
"kernel": [
3072,
768
]
}
}
},
"7": {
"attn": {
"attention": {
"k_proj": {
"kernel": [
768,
768
]
},
"out_proj": {
"bias": [
768
],
"kernel": [
768,
768
]
},
"q_proj": {
"kernel": [
768,
768
]
},
"v_proj": {
"kernel": [
768,
768
]
}
}
},
"ln_1": {
"bias": [
768
],
"scale": [
768
]
},
"ln_2": {
"bias": [
768
],
"scale": [
768
]
},
"mlp": {
"c_fc": {
"bias": [
3072
],
"kernel": [
768,
3072
]
},
"c_proj": {
"bias": [
768
],
"kernel": [
3072,
768
]
}
}
},
"8": {
"attn": {
"attention": {
"k_proj": {
"kernel": [
768,
768
]
},
"out_proj": {
"bias": [
768
],
"kernel": [
768,
768
]
},
"q_proj": {
"kernel": [
768,
768
]
},
"v_proj": {
"kernel": [
768,
768
]
}
}
},
"ln_1": {
"bias": [
768
],
"scale": [
768
]
},
"ln_2": {
"bias": [
768
],
"scale": [
768
]
},
"mlp": {
"c_fc": {
"bias": [
3072
],
"kernel": [
768,
3072
]
},
"c_proj": {
"bias": [
768
],
"kernel": [
3072,
768
]
}
}
},
"9": {
"attn": {
"attention": {
"k_proj": {
"kernel": [
768,
768
]
},
"out_proj": {
"bias": [
768
],
"kernel": [
768,
768
]
},
"q_proj": {
"kernel": [
768,
768
]
},
"v_proj": {
"kernel": [
768,
768
]
}
}
},
"ln_1": {
"bias": [
768
],
"scale": [
768
]
},
"ln_2": {
"bias": [
768
],
"scale": [
768
]
},
"mlp": {
"c_fc": {
"bias": [
3072
],
"kernel": [
768,
3072
]
},
"c_proj": {
"bias": [
768
],
"kernel": [
3072,
768
]
}
}
}
},
"ln_f": {
"bias": [
768
],
"scale": [
768
]
},
"wpe": {
"embedding": [
2048,
768
]
},
"wte": {
"embedding": [
50257,
768
]
}
}
}

View File

@ -0,0 +1,24 @@
# %%
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import freeze, unfreeze
from partitions import set_partitions
from transformers import FlaxAutoModelForCausalLM
# this inits the model directly
model = FlaxAutoModelForCausalLM.from_pretrained(
"gpt-neo-125m",
)
params = model.params
# %%
import json
shape_dict = jax.tree.map(jnp.shape, params)
# print(json.dumps(shape_dict, sort_keys=True, indent=4))
with open('gpt-neo-125m.json', 'w') as f:
json.dump(shape_dict, fp=f, sort_keys=True, indent=2)
# %%
param_spec = set_partitions(unfreeze(params))
# %%

108
parallel/partitions.py Normal file
View File

@ -0,0 +1,108 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The Google Research Authors and The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for constructing PyTrees of PartitionSpecs."""
# utils adapted from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py
import re
from flax.core.frozen_dict import freeze
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.sharding import PartitionSpec as P
# Sentinels
_unmatched = object()
# For specifying empty leaf dict `{}`
empty_dict = object()
def _match(qs, ks):
"""Return True if regexes in qs match any window of strings in tuple ks."""
# compile regexes and force complete match
qts = tuple((re.compile(x + "$") for x in qs))
for i in range(len(ks) - len(qs) + 1):
matches = [x.match(y) for x, y in zip(qts, ks[i:])]
if matches and all(matches):
return True
return False
def _replacement_rules(rules):
def replace(key, val):
for rule, replacement in rules:
if _match(rule, key):
return replacement
return val
return replace
# PartitionSpec for GPTNeo
# replicate the hidden dim and shard feed-forward and head dim
# def _get_partition_rules():
# return [
# # embeddings
# (("transformer", "wpe", "embedding"), P("mp", None)),
# (("transformer", "wte", "embedding"), P("mp", None)),
# # atention
# (("attention", "(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
# (("attention", "out_proj", "kernel"), P("mp", None)),
# (("attention", "out_proj", "bias"), None),
# # mlp
# (("mlp", "c_fc", "kernel"), P(None, "mp")),
# (("mlp", "c_fc", "bias"), P("mp")),
# (("mlp", "c_proj", "kernel"), P("mp", None)),
# (("mlp", "c_proj", "bias"), None),
# # layer norms
# ((r"ln_\d+", "bias"), None),
# ((r"\d+", r"ln_\d+", "scale"), None),
# (("ln_f", "bias"), None),
# (("ln_f", "scale"), None),
# ]
def _get_partition_rules():
return [
# embedding
(("shared", "embedding"), P("model", None)),
# SelfAttention
(("SelfAttention", "(q|k|v)", "kernel"), P(None, "model")),
(("SelfAttention", "o", "kernel"), P("model", None)),
(("SelfAttention", "relative_attention_bias", "embedding"), P(None)),
# EncDecAttention
(("EncDecAttention", "(q|k|v)", "kernel"), P(None, "model")),
(("EncDecAttention", "o", "kernel"), P("model", None)),
# DenseReluDense
(("DenseReluDense", "wi", "kernel"), P(None, "model")),
(("DenseReluDense", "wo", "kernel"), P("model", None)),
# layer norms
(("final_layer_norm", "weight"), P(None)),
(("layer_norm", "weight"), P(None)),
]
def set_partitions(in_dict):
rules = _get_partition_rules()
replace = _replacement_rules(rules)
initd = {k: _unmatched for k in flatten_dict(in_dict)}
result = {k: replace(k, v) for k, v in initd.items()}
assert _unmatched not in result.values(), "Incomplete partition spec."
# for item in result.values():
# if item:
# print(item)
return freeze(unflatten_dict(result))

1826
parallel/t5.json Normal file

File diff suppressed because it is too large Load Diff

651
parallel/t5_pjit.py Normal file
View File

@ -0,0 +1,651 @@
# MARK: START
# %%
# let's make 8-device simulator
import sys
sys.dont_write_bytecode = True
import os
# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True
flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices
# Enforce CPU-only execution
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["JAX_PLATFORMS"] = "cpu"
else:
# GPU flags
flags += (
"--xla_gpu_enable_triton_softmax_fusion=true "
"--xla_gpu_triton_gemm_any=false "
"--xla_gpu_enable_async_collectives=true "
"--xla_gpu_enable_latency_hiding_scheduler=true "
"--xla_gpu_enable_highest_priority_async_stream=true "
)
os.environ["XLA_FLAGS"] = flags
import functools
from functools import partial
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
# from jax.experimental.pjit import pjit # superseded by jax.jit
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec
from ml_collections import ConfigDict
import optax
import logging
import time
from datasets import Dataset, load_from_disk
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.core.frozen_dict import freeze, unfreeze
import flax.core
from partitions import set_partitions
from tqdm import tqdm
from dataload import DataPrepare
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
if USE_CPU_ONLY:
jax.config.update('jax_platform_name', 'cpu')
else:
jax.config.update("jax_default_matmul_precision", "bfloat16")
# %%
# get platform type
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
save_path = 't5_80_1_bf16'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])
# Store some constant
seed = 117
num_epochs = 5
batch_size = 2 # 384 is the best
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
warmup_steps = 0
learning_rate = 2e-5
weight_decay = 0.01
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0
num_beams = 1
val_max_target_length = 128
predict_with_generate = True
# %%
# prepare data
# init object
# e.g. Config
data_config = ConfigDict(
dict(
max_length=86,
pad_token_id=0,
decoder_start_token_id=0
)
)
dataprep = DataPrepare(split_datasets['train'], data_config)
# # example usage
# # %%
seed = 117
rng = jax.random.PRNGKey(seed)
train_loader = dataprep.data_loader(rng, batch_size=batch_size)
batch = next(iter(train_loader))
# batch
# %%
# model
from t5_model.pure_t5 import FlaxT5ForConditionalGenerationModule as model_init
# from t5_model.pure_t5 import FlaxT5DenseActDense as model_init
from t5_model.pure_t5 import make_config
config = make_config()
model = model_init(config)
# %%
from transformers import FlaxT5ForConditionalGeneration
from transformers import T5Config
model, params = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=False)
# useful for transformer model
# model.enable_gradient_checkpointing()
# enable bf16 except for layer_norm
# from flax import traverse_util
# flat_params = traverse_util.flatten_dict(model.params)
# mask = {
# path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
# }
# mask = traverse_util.unflatten_dict(mask)
# model.params = model.to_bf16(model.params, mask)
##################################################################
# set partition on model
# %%
# # let's output the model parameters to a json file for study
# import json
# shape_dict = jax.tree.map(jnp.shape, params)
# # print(json.dumps(shape_dict, sort_keys=True, indent=4))
# with open('t5.json', 'w') as f:
# json.dump(shape_dict, fp=f, sort_keys=True, indent=2)
# MARK: setup mesh
# %%
device_mesh = mesh_utils.create_device_mesh((2,2))
print(device_mesh)
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
return NamedSharding(mesh, pspec)
##################################################
# optimizers
# %%
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
return schedule_fn
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
training_size,
train_batch_size,
num_train_epochs,
warmup_steps,
learning_rate,
)
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=adam_beta1,
b2=adam_beta2,
eps=adam_epsilon,
weight_decay=weight_decay,
mask=decay_mask_fn,
)
# %%
# specify sharding
# shard data
x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis
batch = {key: jax.device_put(jnp.array(value), x_sharding) for key, value in batch.items()}
# Defining the required dimensions for the self-attention layer input
# batch_size = 2
# seq_length = 768
# n_heads = 12
# head_dim = 768
# %%
# Create a large array with the shape (batch_size, seq_length, n_heads, head_dim)
# large_input = np.random.rand(2,768,768)
# batch = jax.device_put(large_input, x_sharding)
# %%
# jax.debug.visualize_array_sharding(batch['input_ids'])
# %%
# shard output
# we will shard state by tracking its output upon jax.eval_shape after init
# define an init function to return a TrainState
# def init_fn(rng, batch, model, optimizer):
# # do be careful with the model init
# # imported models might have complicated init methods
# variables = model.init(rng,
# input_ids=batch['input_ids'],
# attention_mask=batch['attention_mask'],
# decoder_input_ids=batch['decoder_attention_mask'],
# decoder_attention_mask=batch['decoder_attention_mask']
# )
# state = train_state.TrainState.create( # Create a `TrainState`.
# apply_fn=model.apply,
# params=variables['params'],
# tx=optimizer)
# return state
def init_fn(rng, batch, model, optimizer):
# do be careful with the model init
# imported models might have complicated init methods
variables = model.init(
rng,
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
decoder_input_ids=batch['decoder_attention_mask'],
decoder_attention_mask=batch['decoder_attention_mask']
)
state = train_state.TrainState.create( # Create a `TrainState`.
apply_fn=model.apply,
params=variables['params'],
tx=optimizer)
return state
# %%
# alternative
# def init_fn(rng, batch, model, optimizer):
# # do be careful with the model init
# # imported models might have complicated init methods
# variables = model.init(
# rng, batch
# )
# state = train_state.TrainState.create( # Create a `TrainState`.
# apply_fn=model.apply,
# params=variables['params'],
# tx=optimizer)
# return state
# %%
# Create an abstract closure to wrap the function before feeding it in
# because `jax.eval_shape` only takes pytrees as arguments.
# eval_shape(fn, rng_key, x)
# used to perform shape inference
# returns a nested PyTree containing jax.ShapeDtypeStruct objects as leaves
rng, init_rng = jax.random.split(rng)
abstract_variables = jax.eval_shape(
functools.partial(init_fn, model=model, optimizer=adamw), init_rng, batch)
# %%
# This `state_sharding` has the same pytree structure as `state`, the output
# of the `init_fn`.
# flan.linen.get_sharding
# extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh
# jax.sharding: describes how a jax.Array is laid out across devices
state_sharding = nn.get_sharding(abstract_variables, mesh)
print(state_sharding)
# warning: do not have singleton None in your nn.partition definitions, it will screw with your sanity
# %%
jit_init_fn = jax.jit(
init_fn,
static_argnames=('model', 'optimizer'), # skip model and optimizer
in_shardings=(mesh_sharding(()), x_sharding), # for PRNG key and data
out_shardings=state_sharding
)
rng, init_rng = jax.random.split(rng)
initialized_state = jit_init_fn(rng, batch, model, adamw)
# %%
# we can analyze the params structure
# for weight, partitioned in initialized_state.params['decoder'].items():
# print(f'Sharding of {weight}: {partitioned}')
# jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
# jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)
jax.tree.map(jnp.shape, initialized_state.params['decoder'])
# %%
print(initialized_state.params['decoder']['block']['0']['layer']['0']['SelfAttention']['k']['kernel'].value.sharding)
print(initialized_state.step)
print(initialized_state.step.sharding)
# %%
# train step
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
"""
The label smoothing implementation is adapted from Flax's official example:
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
"""
vocab_size = logits.shape[-1]
confidence = 1.0 - label_smoothing_factor
low_confidence = (1.0 - confidence) / (vocab_size - 1)
normalizing_constant = -(
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
)
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
loss = optax.softmax_cross_entropy(logits, soft_labels)
loss = loss - normalizing_constant
# ignore padded tokens from loss
loss = loss * padding_mask
loss = loss.sum()
num_labels = padding_mask.sum()
return loss, num_labels
# %%
# single device code annotated with jax.jit
@functools.partial(
jax.jit,
# in_shardings=(state_sharding, x_sharding),
out_shardings=state_sharding
)
def train_step(state, batch):
label_smoothing_factor=0.0
# dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
def compute_loss(params):
labels = batch.pop("labels")
logits = state.apply_fn(
{'params': params},
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
decoder_input_ids=batch['decoder_attention_mask'],
decoder_attention_mask=batch['decoder_attention_mask'],
)[0]
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
return loss, num_labels
# compute gradients through computational graph
# allow values to pass through
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
(loss, num_labels), grad = grad_fn(state.params)
# num_labels = jax.lax.psum(num_labels, "batch")
# true grad = total grad / total samples
# grad = jax.lax.psum(grad, "batch")
# grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
new_state = state.apply_gradients(grads=grad)
# metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
return new_state
# %%
# variables = model.init(
# rng,
# input_ids=batch['input_ids'],
# attention_mask=batch['attention_mask'],
# decoder_input_ids=batch['decoder_attention_mask'],
# decoder_attention_mask=batch['decoder_attention_mask']
# )
# x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis
# batch = {key: jax.device_put(jnp.array(value), x_sharding) for key, value in batch.items()}
with mesh:
new_state = train_step(initialized_state, batch)
# %%
# # %%
# #############################################################
# # we cannot integrate our model pspec with train_state
# # we just shard separately
# # update: we also cannot use the method of modifying a partitionspec tree
# # we have to do it the RIGHT way, following flax_pjit_tutorial to the letter
#
# # %%
# def get_optim_initial_state(params):
# params = params
# state = adamw.init(params)
# return tuple((state)), params
#
# # %%
# # create partitions for model
# from partitions import set_partitions
# # set_partitions freezes the params on return
# model_param_spec = set_partitions(unfreeze(params))
#
# # %%
# params_shapes = jax.tree.map(lambda x: x.shape, params)
# # actually tuple
# optim_state_shapes = jax.eval_shape(get_optim_initial_state, params_shapes)
#
# # %%
# # get pspec for opt_state
# def get_opt_spec(x):
# if isinstance(x, dict):
# return unfreeze(model_param_spec)
# return PartitionSpec()
#
# # this function replaces the empty model params spec with the 'model_param_spec'
# opt_state_spec, param_spec = jax.tree.map(
# get_opt_spec, optim_state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
# )
#
# # %%
#
# model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=True)
# # store on cpu
# model.params = jax.tree_util.tree_map(lambda x: np.asarray(x), model.params)
#
# # %%
# device_mesh = mesh_utils.create_device_mesh((2,2))
# print(device_mesh)
#
# mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
# print(mesh)
#
# def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
# return NamedSharding(mesh, pspec)
#
#
# # opt_state_sharding = mesh_sharding(opt_state_spec)
# # param_sharding = mesh_sharding(param_spec)
#
#
# # %%
# opt_state_sharding = nn.get_sharding(opt_state_spec, mesh)
# param_sharding = nn.get_sharding(param_spec, mesh)
#
# # %%
# # jit the get_initial_state function to shard params and init optimizer state in
# # a sharded way
# from jax.experimental.pjit import pjit
#
# with mesh:
# p_get_initial_state = pjit(
# get_optim_initial_state,
# in_shardings=None,
# out_shardings=(opt_state_spec, param_spec),
# )
#
# # Convert your PartitionSpec to NamedSharding for model params
# param_sharding = NamedSharding(mesh, freeze(param_spec))
# # Use device_put with sharding to move params onto the mesh
# sharded_params = jax.device_put(freeze(params), param_sharding)
#
# with mesh:
# # params is already frozen
# sharded_opt_state, sharded_params = p_get_initial_state(unfreeze(sharded_params))
#
# # %%
#
# # give up this section
# #############################################################
# # create train state
#
# # %%
# # Initialize random key and input for initialization
# rng = jax.random.PRNGKey(seed)
# loader_rng, rng = jax.random.split(rng)
# train_loader = dataprep.data_loader(rng, batch_size=2)
# batch = next(iter(train_loader))
#
# # use the T5 base model to do this
# from transformers import FlaxAutoModel
# model, params = FlaxAutoModel.from_pretrained(
# 't5-base',
# _do_init=False
# )
# t5_module = model.module
#
# # %%
# init_rng, rng = jax.random.split(rng)
# variables = t5_module.init(init_rng,
# input_ids=batch['input_ids'],
# attention_mask=batch['attention_mask'],
# decoder_input_ids=batch['decoder_attention_mask'],
# decoder_attention_mask=batch['decoder_attention_mask']
# )
# params = variables['params']
#
# # create an init function
# # %%
# # we will shard state by tracking its output upon jax.eval_shape after init
# # define an init function to return a TrainState
# def init_fn(rng: jax.random.PRNGKey, batch=batch, model=t5_module, optimizer=adamw) -> train_state.TrainState:
# init_rng, rng = jax.random.split(rng)
# variables = model.init(
# init_rng,
# input_ids=batch['input_ids'],
# attention_mask=batch['attention_mask'],
# decoder_input_ids=batch['decoder_attention_mask'],
# decoder_attention_mask=batch['decoder_attention_mask']
# )
# params = variables.pop("params")
# state = train_state.TrainState.create(
# apply_fn=model.__call__,
# params=params,
# tx=optimizer,
# )
# return state
#
# # model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=True)
# # Create an abstract closure to wrap the function before feeding it in
# # because `jax.eval_shape` only takes pytrees as arguments.
# # eval_shape(fn, rng_key, x)
# # used to perform shape inference
# # returns a nested PyTree containing jax.ShapeDtypeStruct objects as leaves
# init_rng, rng = jax.random.split(rng)
# abstract_variables = jax.eval_shape(
# functools.partial(init_fn, model=t5_module, optimizer=adamw),
# init_rng,
# batch
# )
#
# # %%
# # let's make our mesh
#
# device_mesh = mesh_utils.create_device_mesh((2,2))
# print(device_mesh)
#
# mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
# print(mesh)
#
# def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
# return NamedSharding(mesh, pspec)
#
# # %%
# # making jax compatible batch
#
# # %%
# x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis
# # batch = jax.device_put(batch), x_sharding)
# # jax.debug.visualize_array_sharding(batch)
#
# # %%
# state_sharding = nn.get_sharding(abstract_variables, mesh)
# print(state_sharding)
#
# # %%
# # integrate model_param_specs and state_out_specs
#
# # %%
# # i want to make a Sharding object
# # model_sharding = mesh_sharding(model_param_spec)
#
# # %%
# jit_init_fn = jax.jit(
# init_fn, # rng, batch, model, optimizer
# static_argnames=('model', 'optimizer'), # skip model and optimizer
# in_shardings=(mesh_sharding(()), x_sharding), # mesh_sharding(()), mesh_sharding(())), # for PRNG key and data
# out_shardings=state_sharding
# )
#
# # %%
#
# init_rng, rng = jax.random.split(rng)
# initialized_state = jit_init_fn(
# init_rng,
# batch,
# t5_module,
# adamw)
#
# # jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
# # jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)
#
#
# # %%
#
# # %%
#
# %%

View File

@ -46,11 +46,13 @@ from datasets import load_from_disk
import nltk # Here to have a nice missing dependency error message early on
from typing import Dict, Any, Union
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.core.frozen_dict import FrozenDict, unfreeze
import flax.core
@ -60,9 +62,20 @@ import time
# %%
import os
os.environ['XLA_FLAGS'] = (
'--xla_gpu_triton_gemm_any=true --xla_gpu_enable_custom_fusions=true --xla_gpu_enable_address_computation_fusion=true'
# os.environ['XLA_FLAGS'] = (
# '--xla_gpu_triton_gemm_any=true '
# '--xla_gpu_enable_custom_fusions=true '
# '--xla_gpu_enable_address_computation_fusion=true'
# )
flags = (
'--xla_gpu_enable_triton_softmax_fusion=true '
'--xla_gpu_triton_gemm_any=True '
# '--xla_gpu_enable_async_collectives=true '
'--xla_gpu_enable_latency_hiding_scheduler=true '
'--xla_gpu_enable_highest_priority_async_stream=true '
)
os.environ["XLA_FLAGS"] = flags
os.environ.update({
"TOKENIZERS_PARALLELISM" : "false",
@ -76,8 +89,8 @@ os.environ.update({
# %%
# get platform type
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
from jax.extend.backend import get_backend
print(get_backend().platform)
# %%
@ -90,14 +103,14 @@ except (LookupError, OSError):
# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
save_path = 't5_80_1_bf16'
save_path = 't5_5e_1_pmap'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])
# Store some constant
seed = 117
num_epochs = 5
batch_size = 384 # 384 is the best
batch_size = 32 # 384 is the best
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
@ -116,7 +129,7 @@ adam_epsilon = 1e-8
label_smoothing_factor = 0.0
num_beams = 1
val_max_target_length = 128
val_max_target_length = 86
predict_with_generate = True
@ -126,7 +139,7 @@ predict_with_generate = True
from transformers import T5TokenizerFast
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
# Define additional special tokens
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "<SIG>", "<UNIT>", "<DATA_TYPE>"]
# Add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
@ -172,15 +185,43 @@ from flax import traverse_util
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
# useful for transformer model
model.enable_gradient_checkpointing()
# model.enable_gradient_checkpointing()
# enable bf16 except for layer_norm
flat_params = traverse_util.flatten_dict(model.params)
mask = {
path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
}
mask = traverse_util.unflatten_dict(mask)
model.params = model.to_bf16(model.params, mask)
# flat_params = traverse_util.flatten_dict(model.params)
# mask = {
# path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
# }
# mask = traverse_util.unflatten_dict(mask)
# # borrowed from transformers modeling_flax_utils
# def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
# """
# Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
# """
#
# # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
# def conditional_cast(param):
# if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
# param = param.astype(dtype)
# return param
#
# if mask is None:
# return jax.tree_util.tree_map(conditional_cast, params)
#
# flat_params = traverse_util.flatten_dict(params)
# flat_mask, _ = jax.tree_util.tree_flatten(mask)
#
# for masked, key in zip(flat_mask, sorted(flat_params.keys())):
# if masked:
# flat_params[key] = conditional_cast(flat_params[key])
#
# return traverse_util.unflatten_dict(flat_params)
#
# # Cast parameters to bfloat16 if desired
# # params = jax.tree.tree_map(lambda x: x.astype(jnp.bfloat16), params)
# # instead of casting the whole thing, we cast only certain parts of the tree
# params = cast_floating_to(model.params, jnp.bfloat16, mask)
# %%
# # Function to extract shape and dtype without showing values
@ -527,10 +568,12 @@ for epoch in epochs:
# jax.profiler.stop_trace()
# %%
# output_dir = save_path
# # save checkpoint after each epoch and push checkpoint to the hub
# if jax.process_index() == 0:
# params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
# model.save_pretrained(output_dir, params=params)
# tokenizer.save_pretrained(output_dir)
output_dir = save_path
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
model.save_pretrained(output_dir, params=params)
tokenizer.save_pretrained(output_dir)
# %%

697
t5_jax_parallel.py Normal file
View File

@ -0,0 +1,697 @@
# %%
import os
# Set this to True to run the model on CPU only.
USE_CPU_ONLY = False
flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices
# Enforce CPU-only execution
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["JAX_PLATFORMS"] = "cpu"
else:
# GPU flags
flags = (
'--xla_gpu_enable_triton_softmax_fusion=true '
'--xla_gpu_triton_gemm_any=True '
# '--xla_gpu_enable_async_collectives=true '
'--xla_gpu_enable_latency_hiding_scheduler=true '
'--xla_gpu_enable_highest_priority_async_stream=true '
)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["XLA_FLAGS"] = flags
os.environ.update({
"TOKENIZERS_PARALLELISM" : "false",
"CUDA_DEVICE_MAX_CONNECTIONS" : "1",
"NCCL_LL128_BUFFSIZE": "-2",
"NCCL_LL_BUFFSIZE": "-2",
"NCCL_PROTO": "SIMPLE,LL,LL128",
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.90",
# "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
})
import functools
from functools import partial
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence, Dict, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
# from jax.experimental.pjit import pjit # superseded by jax.jit
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec
from ml_collections import ConfigDict
import optax
import logging
import time
from datasets import Dataset, load_from_disk
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.core.frozen_dict import freeze, unfreeze, FrozenDict
import flax.core
# model checkpointing and saving utilities
from flax import linen as nn
from flax.training import checkpoints, train_state
from flax import struct, serialization
import orbax.checkpoint as ocp
from flax.training import orbax_utils
from parallel.partitions import set_partitions
from tqdm import tqdm
from parallel.dataload import DataPrepare
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
if USE_CPU_ONLY:
jax.config.update('jax_platform_name', 'cpu')
else:
jax.config.update("jax_default_matmul_precision", "bfloat16")
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
# %%
## get platform type
from jax.extend.backend import get_backend
print(get_backend().platform)
print(jax.devices())
# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/'
save_path = '/home/richard/Projects/06_research/jax_models/t5_80e_fp32_parallel/'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])
# Store some constant
seed = 117
num_epochs = 5
batch_size = 32
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
warmup_steps = 0
learning_rate = 5e-5
weight_decay = 0.01
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0
num_beams = 1
val_max_target_length = 128
predict_with_generate = True
# %%
# prepare data
# init object
# e.g. Config
print("preparing data")
data_config = ConfigDict(
dict(
max_length=128,
pad_token_id=0,
decoder_start_token_id=0
)
)
dataprep = DataPrepare(split_datasets['train'], data_config)
# # example usage
# # %%
seed = 117
rng = jax.random.PRNGKey(seed)
train_loader = dataprep.data_loader(rng, batch_size=batch_size)
batch = next(iter(train_loader))
# batch
# %%
# model
# working
# from parallel.t5_model.pure_t5 import FlaxT5ForConditionalGenerationModule as model_init
# # from t5_model.pure_t5 import FlaxT5DenseActDense as model_init
# from parallel.t5_model.pure_t5 import make_config
# config = make_config()
# model = model_init(config=config, dtype=jnp.bfloat16, gradient_checkpointing=True)
# %%
# from transformers import FlaxT5ForConditionalGeneration, T5Config
# model = FlaxT5ForConditionalGeneration.from_pretrained(
# "t5-base",
# dtype=jnp.bfloat16,
# )
# # pretrained_params = model.params
# model = model.module
# %%
# from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom
from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model
main_model = custom_model.from_pretrained(
"t5-base",
dtype=jnp.float32,
# gradient_checkpointing=True,
)
params = main_model.params
# pretrained_params = model.params
model = main_model.module
# %%
# # testing config hashability
# # some explanation:
# # The PreTrainedModel class loads a T5Config model that is not hashable because
# # it is a complicated class that pretends to be a dataclass.
# # The solution is to extract a dict from it, then make a ConfigDict from
# # ml_collections library so that we can get values via the "." operator.
# # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to
# # modify the config before passing to the next layer
# from transformers import T5Config
# from t5_model.configuration_t5 import FrozenT5Config
# from ml_collections import ConfigDict, FrozenConfigDict
#
# config = T5Config.from_pretrained("t5-base").to_dict()
# config.pop('architectures')
# config.pop('id2label')
# # test if it works
# frozen_config = FrozenConfigDict(config)
# # test hash
# hash(frozen_config)
# %%
# %%
# # print model
# rng, input_rng = jax.random.split(rng)
# model.tabulate(
# input_rng,
# input_ids=batch['input_ids'],
# attention_mask=batch['attention_mask'],
# decoder_input_ids=batch['decoder_input_ids'],
# decoder_attention_mask=batch['decoder_attention_mask'],
# console_kwargs={"force_jupyter": True}
# )
# %%
# print model datatype to verify
# rng, input_rng = jax.random.split(rng)
# variables = model.init(
# input_rng,
# input_ids=batch['input_ids'],
# attention_mask=batch['attention_mask'],
# decoder_input_ids=batch['decoder_input_ids'],
# decoder_attention_mask=batch['decoder_attention_mask']
# )
# %%
# create mesh
print("creating mesh")
device_mesh = mesh_utils.create_device_mesh((1,1))
print(device_mesh)
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
return NamedSharding(mesh, pspec, memory_kind="device")
x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis
model_sharding=mesh_sharding(PartitionSpec(None, 'model'))
# %%
# optimizers
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
return schedule_fn
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
training_size,
train_batch_size,
num_train_epochs,
warmup_steps,
learning_rate,
)
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=adam_beta1,
b2=adam_beta2,
eps=adam_epsilon,
weight_decay=weight_decay,
mask=decay_mask_fn,
)
print("compile")
# enable bf16 except for layer_norm
def create_mask_for_layer_norm(params):
flat_params = traverse_util.flatten_dict(params)
mask = {
path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
}
mask = traverse_util.unflatten_dict(mask)
return mask
# borrowed from transformers modeling_flax_utils
def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
"""
Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
"""
# taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
def conditional_cast(param):
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
param = param.astype(dtype)
return param
if mask is None:
return jax.tree_util.tree_map(conditional_cast, params)
flat_params = traverse_util.flatten_dict(params)
flat_mask, _ = jax.tree_util.tree_flatten(mask)
for masked, key in zip(flat_mask, sorted(flat_params.keys())):
if masked:
flat_params[key] = conditional_cast(flat_params[key])
return traverse_util.unflatten_dict(flat_params)
# Cast all parameters to bfloat16 if desired
# params = jax.tree.tree_map(lambda x: x.astype(jnp.bfloat16), params)
# %%
def init_fn(params, model, optimizer):
# do be careful with the model init
# imported models might have complicated init methods
# mask = create_mask_for_layer_norm(params)
# override params with bfloat version
# params= cast_floating_to(params, jnp.bfloat16, mask)
state = train_state.TrainState.create( # Create a `TrainState`.
apply_fn=model.apply,
params=params,
tx=optimizer)
return state
# def init_fn(rng, batch, model, optimizer):
# # do be careful with the model init
# # imported models might have complicated init methods
# variables = model.init(
# rng,
# input_ids=batch['input_ids'],
# attention_mask=batch['attention_mask'],
# decoder_input_ids=batch['decoder_input_ids'],
# decoder_attention_mask=batch['decoder_attention_mask']
# )
# params = variables['params']
# mask = create_mask_for_layer_norm(params)
# # override params with bfloat version
# params= cast_floating_to(params, jnp.bfloat16, mask)
#
# state = train_state.TrainState.create( # Create a `TrainState`.
# apply_fn=model.apply,
# params=params,
# tx=optimizer)
# return state
# %%
# Create an abstract closure to wrap the function before feeding it in
# because `jax.eval_shape` only takes pytrees as arguments.
# eval_shape(fn, rng_key, x)
# used to perform shape inference
# returns a nested PyTree containing jax.ShapeDtypeStruct objects as leaves
# rng, init_rng = jax.random.split(rng)
abstract_variables = jax.eval_shape(
functools.partial(init_fn, model=model, optimizer=adamw), params)
# rng, init_rng = jax.random.split(rng)
# abstract_variables = jax.eval_shape(
# functools.partial(init_fn, model=model, optimizer=adamw), init_rng, batch)
# %%
# This `state_sharding` has the same pytree structure as `state`, the output
# of the `init_fn`.
# flan.linen.get_sharding
# extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh
# jax.sharding: describes how a jax.Array is laid out across devices
state_sharding = nn.get_sharding(abstract_variables, mesh)
# print(state_sharding)
# warning: do not have singleton None in your nn.partition definitions, it will screw with your sanity
##################################################
# # %%
# # replace the params tree with the new modified tree
# # create partitions for model
# from parallel.partitions import set_partitions
# # set_partitions freezes the params on return
# model_part_spec = set_partitions(unfreeze(params))
# # p is already a partition spec
# model_named_sharding = jax.tree.map(lambda p: mesh_sharding(p), model_part_spec)
#
# # %%
# # get_shapes = jax.tree.map(jnp.shape, params)
# # actually tuple
# # state_shapes = jax.eval_shape(state_sharding, get_shapes)
#
# # %%
# # get pspec for opt_state
# def get_opt_spec(x):
# if isinstance(x, dict):
# return unfreeze(model_named_sharding)
# # return an empty partspec
# return mesh_sharding((PartitionSpec()))
#
# # this function replaces the empty model params spec with the 'model_named_shard'
# state_sharding = jax.tree.map(
# get_opt_spec, state_sharding, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
# )
# %%
jit_init_fn = jax.jit(
init_fn,
static_argnames=('model', 'optimizer'), # skip model and optimizer
in_shardings=mesh_sharding(PartitionSpec(())), # we don't shard params explicitly
out_shardings=state_sharding # but returned initialized_state is sharded
)
initialized_state = jit_init_fn(params, model, adamw)
# jit_init_fn = jax.jit(
# init_fn,
# static_argnames=('model', 'optimizer'), # skip model and optimizer
# in_shardings=(mesh_sharding(()), x_sharding), # for PRNG key and data
# out_shardings=state_sharding
# )
#
#
# rng, init_rng = jax.random.split(rng)
# initialized_state = jit_init_fn(rng, batch, model, adamw)
# %%
# train step
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
"""
The label smoothing implementation is adapted from Flax's official example:
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
"""
vocab_size = logits.shape[-1]
confidence = 1.0 - label_smoothing_factor
low_confidence = (1.0 - confidence) / (vocab_size - 1)
normalizing_constant = -(
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
)
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
loss = optax.softmax_cross_entropy(logits, soft_labels)
loss = loss - normalizing_constant
# ignore padded tokens from loss
loss = loss * padding_mask
loss = loss.sum()
num_labels = padding_mask.sum()
return loss, num_labels
# %%
# sharded_loss_fn = jax.jit(
# loss_fn,
# in_shardings=(mesh_sharding('model'), x_sharding), # params partitioned across 'model' axis
# out_shardings=(mesh_sharding('model')), # Loss should be aggregated across 'model'
# )
def gather_and_sum(
sharded_values,
in_shardings
):
with mesh:
# Gather sharded values into a single device
gathered_values = jax.jit(
lambda x: x, in_shardings=in_shardings, out_shardings=None
)(sharded_values)
# Compute the sum of gathered values
summed_value = jax.tree.map(lambda x: jnp.sum(x), gathered_values)
return summed_value
# single device code annotated with jax.jit
@functools.partial(
jax.jit,
# state is state_sharding initialized from init_fn
# x_sharding is data sharded explicitly later
in_shardings=(state_sharding, x_sharding),
# return state as state_sharding
# we do not shard the metrics
out_shardings=(state_sharding, mesh_sharding(PartitionSpec())),
donate_argnames=('state'),
)
def train_step(state, batch):
label_smoothing_factor=0.0
# dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
def compute_loss(params, batch):
# check constraints
# frozen dict not allowed as sharding object
# params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding))
# batch = jax.lax.with_sharding_constraint(batch, x_sharding)
# labels = batch.pop("decoder_input_ids")
# no use of labels here
logits = state.apply_fn(
{'params': params},
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
decoder_input_ids=batch['decoder_input_ids'],
decoder_attention_mask=batch['decoder_attention_mask'],
)[0] # zero because output is some structure, where first is the logit
# use labels here
loss, num_labels = loss_fn(
logits,
batch["labels"],
batch["decoder_attention_mask"],
label_smoothing_factor)
return loss, num_labels
# compute gradients through computational graph
# allow values to pass through
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
(loss, num_labels), grad = grad_fn(state.params, batch)
# num_labels = jax.lax.psum(num_labels, "batch")
# true grad = total grad / total samples
# needs to be in a singleton tuple for some reason
# gathered_grad = gather_and_sum(grad, (unfreeze(model_named_sharding),))
# gathered_num_labels = gather_and_sum(num_labels, mesh_sharding(PartitionSpec()))
# summed_gradients = jax.tree.map(lambda x: jnp.sum(x)/gathered_num_labels, gathered_grad)
# grad = jax.lax.psum(grad, "batch")
# grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
new_state = state.apply_gradients(grads=grad)
with jax.named_scope("sync_metrics"):
step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
# step_metrics = jax.tree.map(
# # previously needed lax.psum
# # now just write single device code, let compiler handle
# lambda x: jnp.mean(x), step_metrics
# )
# if metrics is None:
# metrics = step_metrics
# else:
# # combine all the synced metrics
# metrics = jax.tree.map(jnp.mean, metrics, step_metrics)
return new_state, step_metrics
# %%
# prep 1 step
print("1 step for jit-ting")
with mesh:
state, metrics = train_step(initialized_state, batch)
# %%
# %%
# tr
print("***** Running training *****")
print(f" Num examples = {training_size}")
print(f" Num Epochs = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
print(f" Total optimization steps = {total_train_steps}")
# %%
# jax.profiler.start_trace("./traces")
# function to shard a batch by treating it as a pytree
def shard_batch(batch):
# Shard each element in the dictionary (i.e., each key-value pair)
return jax.tree_util.tree_map(
lambda x: jax.device_put(x, x_sharding),
batch
)
print("*" * 10)
print("training start")
rng, input_rng = jax.random.split(rng)
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
train_start = time.time()
# Create sampling rng
train_metrics = []
steps_per_epoch = training_size // train_batch_size
train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True)
# Generate an epoch by shuffling sampling indices from the train dataset
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
batch = next(train_loader)
# send to device
# batch = {key: jax.device_put(jnp.array(value, dtype=jnp.uint16), x_sharding) for key, value in batch.items()}
# batch['input_ids']=jax.device_put(jnp.array(batch['input_ids'], dtype=jnp.int32), x_sharding)
# batch['attention_mask']=jax.device_put(jnp.array(batch['attention_mask'], dtype=jnp.int32), x_sharding)
# batch['decoder_input_ids']=jax.device_put(jnp.array(batch['decoder_input_ids'], dtype=jnp.int32), x_sharding)
# batch['decoder_attention_mask']=jax.device_put(jnp.array(batch['decoder_attention_mask'], dtype=jnp.int32), x_sharding)
sharded_batch = shard_batch(batch)
with mesh:
state, train_metric = train_step(state, sharded_batch)
# train_metrics.append(train_metric)
# this is for more accurate time stats, but slows down training
# train_metric['loss'].block_until_ready()
train_time = time.time() - train_start
epochs.write(
f"Epoch... ({epoch + 1}/{num_epochs} | "
f"Loss: {train_metric['loss']}, "
f"Learning Rate:{train_metric['learning_rate']}, "
f"Last train time: {train_time})"
)
# jax.profiler.stop_trace()
# %%
# with mesh:
# gathered_params = jax.jit(
# lambda x: x,
# in_shardings=(unfreeze(model_named_sharding),),
# out_shardings=mesh_sharding(PartitionSpec())
# )(state.params)
main_model = custom_model.from_pretrained('t5-base')
output_dir = save_path
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
main_model.save_pretrained(output_dir, params=params)
# # stick to defaults
# options = ocp.CheckpointManagerOptions()
# with ocp.CheckpointManager(
# ocp.test_utils.erase_and_create_empty(save_path),
# options=options,
# ) as mngr:
#
# mngr.save(0, args=ocp.args.StandardSave(state))
# mngr.wait_until_finished()
# After providing `args` during an initial `save` or `restore` call, the
# `CheckpointManager` instance records the type so that you do not need to
# specify it again. If the `CheckpointManager` instance is not provided with a
# `ocp.args.CheckpointArgs` instance for a particular item on a previous
# occasion it cannot be restored without specifying the argument at restore
# time.
# # In many cases, you can restore exactly as saved without specifying additional
# # arguments.
# mngr.restore(0)
# # If customization of properties like sharding or dtype is desired, just provide
# # the abstract target PyTree, the properties of which will be used to set
# # the properties of the restored arrays.
# mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree))
# %%

View File

@ -13,9 +13,12 @@
# # prediction code
# ## import and process test data
# %%
# import libraries
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
import pandas as pd
import matplotlib.pyplot as plt
@ -35,7 +38,7 @@ jax.config.update("jax_default_matmul_precision", "bfloat16")
jax.config.update("jax_enable_x64", False)
from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
# from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
import datasets
@ -51,9 +54,13 @@ from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from ml_collections import ConfigDict
import time
from parallel.dataload import DataPrepare
import orbax.checkpoint as ocp
# %%
@ -86,121 +93,22 @@ def process_df(df):
# takes 1 minute to run without batching
test_dataset = Dataset.from_list(process_df(df))
# %% [markdown]
# ## Load model for attributes
# %%
# load model
model_name_or_path = "./t5_80_1_bf16" # Replace with your specific model name
# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path)
# Load model
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
pretrained_model_name_or_path=model_name_or_path
)
# %% [markdown]
# ## Tokenizer
# %%
# prepare tokenizer
from transformers import T5TokenizerFast
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
# Define additional special tokens
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
# Add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
max_length = 86
model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
# given a dataset entry, run it through the tokenizer
# Setting padding="max_length" as we need fixed length inputs for jitted functions
def preprocess_function(example):
inputs = example['input']
targets = example['output']
# text_target sets the corresponding label to inputs
# there is no need to create a separate 'labels'
model_inputs = tokenizer(
inputs,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="np"
)
labels = tokenizer(
text_target=targets,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="np"
)
model_inputs["labels"] = labels["input_ids"]
decoder_input_ids = shift_tokens_right_fn(
labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
)
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
# We need decoder_attention_mask so we can ignore pad tokens from loss
model_inputs["decoder_attention_mask"] = labels["attention_mask"]
return model_inputs
# map maps function to each "row" in the dataset
# aka the data in the immediate nesting
test_dataset = test_dataset.map(
preprocess_function,
batched=True,
num_proc=1,
remove_columns=test_dataset.column_names,
)
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
"""
Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
"""
if shuffle:
batch_idx = jax.random.permutation(rng, len(dataset))
batch_idx = np.asarray(batch_idx)
else:
batch_idx = np.arange(len(dataset))
if drop_last:
steps_per_epoch = len(dataset) // batch_size
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
else:
steps_per_epoch = math.ceil(len(dataset) / batch_size)
batch_idx = np.array_split(batch_idx, steps_per_epoch)
for idx in batch_idx:
batch = dataset[idx]
batch = {k: np.array(v) for k, v in batch.items()}
yield batch
# %% [markdown]
# # model generation
# from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration
from transformers import FlaxT5ForConditionalGeneration
# model_name_or_path = "./t5_80_1" # Replace with your specific model name
model_name_or_path = "./model_checkpoints/simple_test" # Replace with your specific model name
model = FlaxT5ForConditionalGeneration.from_pretrained(model_name_or_path)
params = model.params
# %%
seed = 117
num_epochs = 80
batch_size = 96
num_train_epochs = num_epochs
batch_size = 128
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(test_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
num_beams = 1
val_max_target_length = 128
@ -208,33 +116,31 @@ val_max_target_length = 128
predict_with_generate = True
# Initialize our training
# Initialize our prediction
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)
# %%
# reload model to prevent leakage of variables
# load model
model_name_or_path = "t5_80_1" # Replace with your specific model name
# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path)
# Load model
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
model_name_or_path
print("preparing data")
data_config = ConfigDict(
dict(
max_length=128,
pad_token_id=0,
decoder_start_token_id=0
)
)
dataprep = DataPrepare(test_dataset, data_config)
# # example usage
# # %%
seed = 117
rng = jax.random.PRNGKey(seed)
# %%
# Ensure model.params is properly initialized (this is just an example)
# Normally you would get this from a model initialization call with dummy input
params = model.params
# ensure full size floats
params_f16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
# we need to replicate model over devices
replicated_params = jax.device_put_replicated(params_f16, jax.devices())
replicated_params = jax.device_put_replicated(params, jax.devices())
# Define generation function
@ -245,25 +151,29 @@ num_beams = num_beams if num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
def generate_step(params, batch):
output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], params=params, **gen_kwargs)
output_ids = model.generate(
batch["input_ids"],
attention_mask=batch["attention_mask"],
params=params,
**gen_kwargs)
return output_ids.sequences
# Create parallel version of the train and eval step
p_generate_step = jax.pmap(generate_step, "batch")
pred_generations = []
pred_labels = []
decoder_input_list = []
rng, input_rng = jax.random.split(rng)
pred_loader = data_loader(input_rng, test_dataset, eval_batch_size, drop_last=False)
pred_loader = dataprep.data_loader(input_rng, batch_size=batch_size, shuffle=False, drop_last=False)
pred_steps = math.ceil(len(test_dataset) / eval_batch_size)
print("***** Running training *****")
print(f" Num examples = {len(test_dataset)}")
print(f" Num steps = {num_epochs}")
# print(f" Num steps = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total test batch size (w. parallel & distributed) = {train_batch_size}")
@ -272,26 +182,38 @@ for _ in tqdm(range(pred_steps), desc="Predicting..."):
# Model forward
batch = next(pred_loader)
labels = batch["labels"]
decoder_input = batch["decoder_input_ids"]
# generation
generated_ids = pad_shard_unpad(p_generate_step)(replicated_params, batch)
pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
pred_labels.extend(labels)
decoder_input_list.extend(decoder_input)
# %%
# %% [markdown]
# # process predictions
from transformers import T5TokenizerFast
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=False)
# Define additional special tokens
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "<SIG>", "<UNIT>", "<DATA_TYPE>"]
# Add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
# %%
# code to get special token ids
# sentence = "<THING_START><THING_END><PROPERTY_START><PROPERTY_END><NAME><DESC><DESC><UNIT>"
# tokens = tokenizer.tokenize(sentence)
# print("Tokens:", tokens)
# # Get the IDs (integer indices) of specific tokens
# token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens]
# print("Token IDs:", token_ids)
sentence = "<THING_START><THING_END><PROPERTY_START><PROPERTY_END><NAME><DESC><DESC><UNIT><SIG><UNIT><DATA_TYPE>"
tokens = tokenizer.tokenize(sentence)
print("Tokens:", tokens)
# Get the IDs (integer indices) of specific tokens
token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens]
print("Token IDs:", token_ids)
# %%
@ -304,7 +226,13 @@ def extract_seq(tokens, start_value, end_value):
return tokens[start_id+1:end_id]
# %%
i = 2
print(pred_generations[i])
print(extract_seq(pred_generations[i], 32100, 32101))
# %%
def process_tensor_output(tokens):
thing_seq = extract_seq(tokens, 32100, 32101) # 32100 = <THING_START>, 32101 = <THING_END>
property_seq = extract_seq(tokens, 32102, 32103) # 32102 = <PROPERTY_START>, 32103 = <PROPERTY_END>

543
t5_jax_simple_parallel.py Normal file
View File

@ -0,0 +1,543 @@
# %%
import os
# Set this to True to run the model on CPU only.
USE_CPU_ONLY = False
flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices
# Enforce CPU-only execution
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["JAX_PLATFORMS"] = "cpu"
else:
# GPU flags
flags = (
'--xla_gpu_enable_triton_softmax_fusion=true '
'--xla_gpu_triton_gemm_any=True '
# '--xla_gpu_enable_async_collectives=true '
'--xla_gpu_enable_latency_hiding_scheduler=true '
'--xla_gpu_enable_highest_priority_async_stream=true '
)
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
os.environ["XLA_FLAGS"] = flags
os.environ.update({
"TOKENIZERS_PARALLELISM" : "false",
"CUDA_DEVICE_MAX_CONNECTIONS" : "1",
"NCCL_LL128_BUFFSIZE": "-2",
"NCCL_LL_BUFFSIZE": "-2",
"NCCL_PROTO": "SIMPLE,LL,LL128",
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.90",
# "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
})
import functools
from functools import partial
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence, Dict, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
# from jax.experimental.pjit import pjit # superseded by jax.jit
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec
from ml_collections import ConfigDict
import optax
import logging
import time
from datasets import Dataset, load_from_disk
from flax import jax_utils, traverse_util
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.core.frozen_dict import freeze, unfreeze, FrozenDict
import flax.core
# model checkpointing and saving utilities
from flax import linen as nn
from flax.training import checkpoints, train_state
from flax import struct, serialization
from parallel.partitions import set_partitions
from tqdm import tqdm
from parallel.dataload import DataPrepare
# for memory tracking
# from jax_smi import initialise_tracking
# initialise_tracking()
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
if USE_CPU_ONLY:
jax.config.update('jax_platform_name', 'cpu')
else:
jax.config.update("jax_default_matmul_precision", "bfloat16")
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
# %%
## get platform type
from jax.extend.backend import get_backend
print(get_backend().platform)
print(jax.devices())
# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/'
save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/simple_test/'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])
# Store some constant
seed = 117
num_epochs = 5
batch_size = 64
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
warmup_steps = 0
learning_rate = 5e-5
weight_decay = 0.01
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0
num_beams = 1
val_max_target_length = 128
predict_with_generate = True
# %%
# prepare data
# init object
# e.g. Config
print("preparing data")
data_config = ConfigDict(
dict(
max_length=128,
pad_token_id=0,
decoder_start_token_id=0
)
)
dataprep = DataPrepare(split_datasets['train'], data_config)
# # example usage
# # %%
seed = 117
rng = jax.random.PRNGKey(seed)
train_loader = dataprep.data_loader(rng, batch_size=batch_size)
batch = next(iter(train_loader))
# %%
# from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom
from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model
main_model = custom_model.from_pretrained(
"t5-base",
dtype=jnp.bfloat16,
gradient_checkpointing=True,
)
params = main_model.params
# pretrained_params = model.params
model = main_model.module
# %%
# # testing config hashability
# # some explanation:
# # The PreTrainedModel class loads a T5Config model that is not hashable because
# # it is a complicated class that pretends to be a dataclass.
# # The solution is to extract a dict from it, then make a ConfigDict from
# # ml_collections library so that we can get values via the "." operator.
# # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to
# # modify the config before passing to the next layer
# from transformers import T5Config
# from t5_model.configuration_t5 import FrozenT5Config
# from ml_collections import ConfigDict, FrozenConfigDict
#
# config = T5Config.from_pretrained("t5-base").to_dict()
# config.pop('architectures')
# config.pop('id2label')
# # test if it works
# frozen_config = FrozenConfigDict(config)
# # test hash
# hash(frozen_config)
# %%
# # print model
# rng, input_rng = jax.random.split(rng)
# model.tabulate(
# input_rng,
# input_ids=batch['input_ids'],
# attention_mask=batch['attention_mask'],
# decoder_input_ids=batch['decoder_input_ids'],
# decoder_attention_mask=batch['decoder_attention_mask'],
# console_kwargs={"force_jupyter": True}
# )
# %%
# create mesh
print("creating mesh")
device_mesh = mesh_utils.create_device_mesh((2,2))
print(device_mesh)
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
return NamedSharding(mesh, pspec, memory_kind="device")
x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis
model_sharding=mesh_sharding(PartitionSpec(None, 'model'))
# %%
# optimizers
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
return schedule_fn
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
training_size,
train_batch_size,
num_train_epochs,
warmup_steps,
learning_rate,
)
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=adam_beta1,
b2=adam_beta2,
eps=adam_epsilon,
weight_decay=weight_decay,
mask=decay_mask_fn,
)
# %%
print("compile")
# enable bf16
# enable only for dense, some transformer sections, and shared
def create_mask_for_layer_norm(params):
flat_params = traverse_util.flatten_dict(params)
mask = {
# path: not (
# (path[-2] == "layer_norm" and path[-1] == "weight") or
# (path[-2] == "final_layer_norm" and path[-1] == "weight") or
# (path[-2] == "o" and path[-1] == "kernel")
# )
# for path in flat_params
path: (
(path[-2] == "wi" and path[-1] == "weight") or
(path[-2] == "wo" and path[-1] == "weight") or
(path[-2] == "k" and path[-1] == "kernel") or
(path[-2] == "q" and path[-1] == "kernel") or
(path[-2] == "v" and path[-1] == "kernel") or
(path[-2] == "shared" and path[-1] == "embedding")
) for path in flat_params
}
mask = traverse_util.unflatten_dict(mask)
return mask
# borrowed from transformers modeling_flax_utils
def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
"""
Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
"""
# taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
def conditional_cast(param):
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
param = param.astype(dtype)
return param
if mask is None:
return jax.tree_util.tree_map(conditional_cast, params)
flat_params = traverse_util.flatten_dict(params)
flat_mask, _ = jax.tree_util.tree_flatten(mask)
for masked, key in zip(flat_mask, sorted(flat_params.keys())):
if masked:
flat_params[key] = conditional_cast(flat_params[key])
return traverse_util.unflatten_dict(flat_params)
# create init_fn to produce sharded state
def init_fn(params, model, optimizer):
# do be careful with the model init
# imported models might have complicated init methods
# mask = create_mask_for_layer_norm(params)
# override params with bfloat version
# params= cast_floating_to(params, jnp.bfloat16, mask)
state = train_state.TrainState.create( # Create a `TrainState`.
apply_fn=model.apply,
params=params,
tx=optimizer)
return state
abstract_variables = jax.eval_shape(
functools.partial(init_fn, model=model, optimizer=adamw), params)
# jax.sharding: describes how a jax.Array is laid out across devices
state_sharding = nn.get_sharding(abstract_variables, mesh)
# print(state_sharding)
# %%
# replace the params tree with the new modified tree
# create partitions for model
from parallel.partitions import set_partitions
# set_partitions freezes the params on return
model_part_spec = set_partitions(unfreeze(params))
# p is already a partition spec
model_named_sharding = jax.tree.map(lambda p: mesh_sharding(p), model_part_spec)
# get pspec for opt_state
def get_opt_spec(x):
if isinstance(x, dict):
return unfreeze(model_named_sharding)
# return an empty partspec
return mesh_sharding((PartitionSpec()))
# this function replaces the empty model params spec with the 'model_named_shard'
state_sharding = jax.tree.map(
get_opt_spec, state_sharding, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
)
jit_init_fn = jax.jit(
init_fn,
static_argnames=('model', 'optimizer'), # skip model and optimizer
in_shardings=mesh_sharding(PartitionSpec()), # we don't shard params explicitly
out_shardings=state_sharding # but returned initialized_state is sharded
)
initialized_state = jit_init_fn(params, model, adamw)
# %%
# train step
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
"""
The label smoothing implementation is adapted from Flax's official example:
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
"""
vocab_size = logits.shape[-1]
confidence = 1.0 - label_smoothing_factor
low_confidence = (1.0 - confidence) / (vocab_size - 1)
normalizing_constant = -(
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
)
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
logits = jnp.asarray(logits, dtype=jnp.float32)
logits = logits.astype(jnp.float32)
soft_labels = soft_labels.astype(jnp.float32)
loss = optax.softmax_cross_entropy(logits, soft_labels)
loss = loss - normalizing_constant
# ignore padded tokens from loss
loss = loss * padding_mask
loss = loss.mean()
# num_labels = padding_mask.mean()
return loss # , num_labels
# %%
# single device code annotated with jax.jit
@functools.partial(
jax.jit,
# state is state_sharding initialized from init_fn
# x_sharding is data sharded explicitly later
in_shardings=(state_sharding, x_sharding),
out_shardings=(state_sharding, mesh_sharding(PartitionSpec())),
donate_argnames=('state'),
)
def train_step(state, batch):
label_smoothing_factor=0.0
# dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
def compute_loss(params, batch):
# check constraints
# frozen dict not allowed as sharding object
params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding))
batch = jax.lax.with_sharding_constraint(batch, x_sharding)
logits = state.apply_fn(
{'params': params},
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
decoder_input_ids=batch['decoder_input_ids'],
decoder_attention_mask=batch['decoder_attention_mask'],
)[0] # zero because output is some structure, where first is the logit
# use labels here
# loss, num_labels = loss_fn(
loss = loss_fn(
logits,
batch["labels"],
batch["decoder_attention_mask"],
label_smoothing_factor)
return loss # , num_labels
# compute gradients through computational graph
# allow values to pass through
grad_fn = jax.value_and_grad(compute_loss, has_aux=False)
(loss), grad = grad_fn(state.params, batch)
# num_labels = jax.lax.psum(num_labels, "batch")
new_state = state.apply_gradients(grads=grad)
with jax.named_scope("sync_metrics"):
step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
return new_state, step_metrics
# %%
# explore data sharding
sharded_batch = next(iter(train_loader))
sharded_batch = jax.device_put(sharded_batch, x_sharding)
# jax.debug.visualize_array_sharding(sharded_batch['input_ids'])
# jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding'])
# %%
# # prep 1 step
# print("1 step for jit-ting")
# with mesh:
# state, metrics = train_step(initialized_state, sharded_batch)
# %%
# %%
# tr
print("***** Running training *****")
print(f" Num examples = {training_size}")
print(f" Num Epochs = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
print(f" Total optimization steps = {total_train_steps}")
# %%
# jax.profiler.start_trace("./traces")
print("*" * 10)
print("training start")
rng, input_rng = jax.random.split(rng)
train_time = 0
state = initialized_state
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
train_start = time.time()
# Create sampling rng
train_metrics = []
steps_per_epoch = training_size // train_batch_size
train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True)
# Generate an epoch by shuffling sampling indices from the train dataset
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
batch = next(train_loader)
batch = jax.device_put(batch, x_sharding)
with mesh:
state, train_metric = train_step(state, batch)
# train_metrics.append(train_metric)
# this is for more accurate time stats, but slows down training
# train_metric['loss'].block_until_ready()
train_time = time.time() - train_start
epochs.write(
f"Epoch... ({epoch + 1}/{num_epochs} | "
f"Loss: {train_metric['loss']}, "
f"Learning Rate:{train_metric['learning_rate']}, "
f"Last train time: {train_time})"
)
# jax.profiler.stop_trace()
# %%
# try out
gather_state = jax.device_get(state)
gather_batch = jax.device_get(batch)
logits = gather_state.apply_fn(
{'params': gather_state.params},
input_ids=gather_batch['input_ids'],
attention_mask=gather_batch['attention_mask'],
decoder_input_ids=gather_batch['decoder_input_ids'],
decoder_attention_mask=gather_batch['decoder_attention_mask'],
)[0] # zero because output is some structure, where first is the logit
probs = nn.softmax(logits, axis=-1)
predicted = jnp.argmax(probs, axis=-1)
predicted[1]
# %%
main_model = custom_model.from_pretrained('t5-base')
output_dir = save_path
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(state.params)
main_model.save_pretrained(output_dir, params=params)
# %%

360
t5_prediction_old.py Normal file
View File

@ -0,0 +1,360 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.16.4
# ---
# %% [markdown]
# # prediction code
# ## import and process test data
# %%
# import libraries
import pandas as pd
import matplotlib.pyplot as plt
from datasets import Dataset, DatasetDict
import jax
import jax.numpy as jnp
import optax
import numpy as np
from functools import partial
from typing import Callable, Optional
import math
# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "high")
jax.config.update("jax_enable_x64", False)
from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
import datasets
from datasets import Dataset
import evaluate
from tqdm import tqdm
import nltk # Here to have a nice missing dependency error message early on
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
import time
# %%
# data_path = f"../make_data/select_db/data_mapping_filtered.csv"
# data_path = f"../make_data_2/select_db/dataset/1/train_all.csv"
data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv'
# data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv'
# Ensure to include 'ships_idx' in the fields list
fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']
# Load the dataset
df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)
def process_df(df):
output_list = [{
'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC>",
# 'input': f"<DESC>{row['tag_description']}<DESC>",
# 'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>",
# 'input': f"<DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>",
'output': f"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>",
# 'answer': f"{row['thing']} {row['property']}",
# 'answer_thing': row['thing'],
# 'answer_property': row['property'],
} for _, row in df.iterrows()]
return output_list
# takes 1 minute to run without batching
test_dataset = Dataset.from_list(process_df(df))
# %% [markdown]
# ## Load model for attributes
# %%
# load model
model_name_or_path = "./t5_80_1" # Replace with your specific model name
# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path)
# Load model
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
pretrained_model_name_or_path=model_name_or_path
)
# %% [markdown]
# ## Tokenizer
# %%
# prepare tokenizer
from transformers import T5TokenizerFast
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
# Define additional special tokens
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
# Add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
max_length = 86
model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
# given a dataset entry, run it through the tokenizer
# Setting padding="max_length" as we need fixed length inputs for jitted functions
def preprocess_function(example):
inputs = example['input']
targets = example['output']
# text_target sets the corresponding label to inputs
# there is no need to create a separate 'labels'
model_inputs = tokenizer(
inputs,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="np"
)
labels = tokenizer(
text_target=targets,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="np"
)
model_inputs["labels"] = labels["input_ids"]
decoder_input_ids = shift_tokens_right_fn(
labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
)
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
# We need decoder_attention_mask so we can ignore pad tokens from loss
model_inputs["decoder_attention_mask"] = labels["attention_mask"]
return model_inputs
# map maps function to each "row" in the dataset
# aka the data in the immediate nesting
test_dataset = test_dataset.map(
preprocess_function,
batched=True,
num_proc=1,
remove_columns=test_dataset.column_names,
)
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
"""
Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
"""
if shuffle:
batch_idx = jax.random.permutation(rng, len(dataset))
batch_idx = np.asarray(batch_idx)
else:
batch_idx = np.arange(len(dataset))
if drop_last:
steps_per_epoch = len(dataset) // batch_size
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
else:
steps_per_epoch = math.ceil(len(dataset) / batch_size)
batch_idx = np.array_split(batch_idx, steps_per_epoch)
for idx in batch_idx:
batch = dataset[idx]
batch = {k: np.array(v) for k, v in batch.items()}
yield batch
# %% [markdown]
# # model generation
# %%
seed = 117
num_epochs = 80
batch_size = 96
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(test_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
num_beams = 1
val_max_target_length = 128
predict_with_generate = True
# Initialize our training
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)
# %%
# reload model to prevent leakage of variables
# load model
model_name_or_path = "t5_80_1_bf16" # Replace with your specific model name
# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path)
# Load model
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
model_name_or_path
)
# Ensure model.params is properly initialized (this is just an example)
# Normally you would get this from a model initialization call with dummy input
params = model.params
# ensure full size floats
params_f16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
# we need to replicate model over devices
replicated_params = jax.device_put_replicated(params_f16, jax.devices())
# Define generation function
max_length = (
val_max_target_length if val_max_target_length is not None else model.config.max_length
)
num_beams = num_beams if num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
def generate_step(params, batch):
output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], params=params, **gen_kwargs)
return output_ids.sequences
# Create parallel version of the train and eval step
p_generate_step = jax.pmap(generate_step, "batch")
pred_generations = []
pred_labels = []
rng, input_rng = jax.random.split(rng)
pred_loader = data_loader(input_rng, test_dataset, eval_batch_size, drop_last=False)
pred_steps = math.ceil(len(test_dataset) / eval_batch_size)
print("***** Running training *****")
print(f" Num examples = {len(test_dataset)}")
print(f" Num steps = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total test batch size (w. parallel & distributed) = {train_batch_size}")
for _ in tqdm(range(pred_steps), desc="Predicting..."):
# Model forward
batch = next(pred_loader)
labels = batch["labels"]
# generation
generated_ids = pad_shard_unpad(p_generate_step)(replicated_params, batch)
pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
pred_labels.extend(labels)
# %% [markdown]
# # process predictions
# %%
# code to get special token ids
# sentence = "<THING_START><THING_END><PROPERTY_START><PROPERTY_END><NAME><DESC><DESC><UNIT>"
# tokens = tokenizer.tokenize(sentence)
# print("Tokens:", tokens)
# # Get the IDs (integer indices) of specific tokens
# token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens]
# print("Token IDs:", token_ids)
# %%
# extract sequence and decode
def extract_seq(tokens, start_value, end_value):
if start_value not in tokens or end_value not in tokens:
return None # Or handle this case according to your requirements
start_id = np.where(tokens == start_value)[0][0]
end_id = np.where(tokens == end_value)[0][0]
return tokens[start_id+1:end_id]
def process_tensor_output(tokens):
thing_seq = extract_seq(tokens, 32100, 32101) # 32100 = <THING_START>, 32101 = <THING_END>
property_seq = extract_seq(tokens, 32102, 32103) # 32102 = <PROPERTY_START>, 32103 = <PROPERTY_END>
p_thing = None
p_property = None
if (thing_seq is not None):
p_thing = tokenizer.decode(thing_seq, skip_special_tokens=False) # retain <COLLIDE>
if (property_seq is not None):
p_property = tokenizer.decode(property_seq, skip_special_tokens=False) # retain <COLLIDE>
return p_thing, p_property
# %%
# decode prediction labels
def decode_preds(tokens_list):
thing_prediction_list = []
property_prediction_list = []
for tokens in tokens_list:
p_thing, p_property = process_tensor_output(tokens)
thing_prediction_list.append(p_thing)
property_prediction_list.append(p_property)
return thing_prediction_list, property_prediction_list
thing_prediction_list, property_prediction_list = decode_preds(pred_generations)
# %%
# add labels too
thing_actual_list, property_actual_list = decode_preds(pred_labels)
# Convert the list to a Pandas DataFrame
df = pd.DataFrame({'p_thing': thing_prediction_list,
'p_property': property_prediction_list,
'thing': thing_actual_list,
'property' : property_actual_list})
df['p_thing_correct'] = df['p_thing'] == df['thing']
df['p_property_correct'] = df['p_property'] == df['property']
# %%
print("thing accuracy", sum(df['p_thing_correct'])/len(df))
print("property accuracy", sum(df['p_property_correct'])/len(df))
print("total accuracy", sum(df['p_property_correct'] & df['p_thing_correct'])/len(df))
# %%
df[~df["p_property_correct"]]
# %%
df['p_thing']
# %%
# Save the DataFrame as a Parquet file (using pyarrow or fastparquet)
# df.to_parquet("exports/output_file.parquet", engine="pyarrow") # or use engine="fastparquet"