diff --git a/.gitignore b/.gitignore index f2e3f61..3ca7f8b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ *.ipynb t5_*/ +model_checkpoints/ exports/ modified_t5_model/ traces/ diff --git a/parallel/.gitignore b/parallel/.gitignore index bee8a64..a077105 100644 --- a/parallel/.gitignore +++ b/parallel/.gitignore @@ -1 +1,2 @@ __pycache__ +gpt-neo-125m/ diff --git a/parallel/dataload.py b/parallel/dataload.py index fff0b70..9de9632 100644 --- a/parallel/dataload.py +++ b/parallel/dataload.py @@ -17,17 +17,6 @@ file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retri # training_size = len(split_datasets['train']) from transformers import T5TokenizerFast -tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True) -# Define additional special tokens -additional_special_tokens = ["", "", "", "", "", "", "SIG", "UNIT", "DATA_TYPE"] -# Add the additional special tokens to the tokenizer -tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens}) - -model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base") - -model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"]) -shift_tokens_right_fn = getattr(model_module, "shift_tokens_right") # noqa: B009 - # class takes in a dataset class DataPrepare(): @@ -37,6 +26,19 @@ class DataPrepare(): self.train_dataset: Optional[Dataset] = None self.size: int = len(raw_dataset) self.config: ConfigDict = config + self.tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=False) + # Define additional special tokens + # additional_special_tokens = ["", "", "", "", "", "", "", "", ""] + additional_special_tokens = ["", "", "", "", "", "", "SIG", "UNIT", "DATA_TYPE"] + # Add the additional special tokens to the tokenizer + self.tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens}) + + model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base") + + model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"]) + self.shift_tokens_right_fn = getattr(model_module, "shift_tokens_right") # noqa: B009 + + self.make_dataset() @@ -52,31 +54,37 @@ class DataPrepare(): # text_target sets the corresponding label to inputs # there is no need to create a separate 'labels' # produce input_ids and decoder_input_ids - model_inputs = tokenizer( + model_inputs = self.tokenizer( inputs, max_length=self.config.max_length, padding="max_length", truncation=True, return_tensors="np" ) - labels = tokenizer( + # we separate it out because we need the attention mask + labels = self.tokenizer( text_target=targets, max_length=self.config.max_length, padding="max_length", truncation=True, return_tensors="np" ) - + model_inputs['input_ids'] = np.asarray(model_inputs['input_ids']) + model_inputs['attention_mask'] = np.asarray(model_inputs['attention_mask']) # for loss computation model_inputs["labels"] = labels["input_ids"] # make decoder input ids - decoder_input_ids = shift_tokens_right_fn( + # this is actually "model output" shifted right + decoder_input_ids = self.shift_tokens_right_fn( labels["input_ids"], self.config.pad_token_id, self.config.decoder_start_token_id ) # require by model model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids) - # We need decoder_attention_mask so we can ignore pad tokens from loss - model_inputs["decoder_attention_mask"] = labels["attention_mask"] + # decoder_attention_mask = shift_tokens_right_fn( + # 
labels["attention_mask"], self.config.pad_token_id, self.config.decoder_start_token_id + # ) + # We need decoder_attention_mask so we can ignore pad tokens in loss + model_inputs["decoder_attention_mask"] = np.asarray(labels["attention_mask"]) return model_inputs @@ -89,13 +97,13 @@ class DataPrepare(): remove_columns=self.raw_dataset.column_names,) # set to numpy - train_dataset.set_format( - type='numpy', - columns=[ - 'input_ids', 'attention_mask', - 'labels', 'decoder_input_ids', - 'decoder_attention_mask'] - ) + # train_dataset.set_format( + # type='numpy', + # columns=[ + # 'input_ids', 'attention_mask', 'labels', + # 'decoder_input_ids', + # 'decoder_attention_mask'] + # ) # check that data fits # for name in ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'decoder_attention_mask']: @@ -140,15 +148,15 @@ class DataPrepare(): for idx in batch_idx: batch = dataset[idx] - batch = {k: v for k, v in batch.items()} + batch = {k: np.array(v) for k, v in batch.items()} yield batch -# testing out the class -# %% -# init object -# e.g. Config +# # testing out the class +# # %% +# # init object +# # e.g. Config # data_config = ConfigDict( # dict( # max_length=86, @@ -172,3 +180,13 @@ class DataPrepare(): # batch = next(iter(train_loader)) # batch['input_ids'].shape # # %% +# +# sentence = "" +# tokens = tokenizer.tokenize(sentence) +# print("Tokens:", tokens) +# # Get the IDs (integer indices) of specific tokens +# token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens] +# print("Token IDs:", token_ids) +# +# +# # %% diff --git a/parallel/flax_pjit_tutorial.py b/parallel/flax_pjit_tutorial.py index 1ddf451..3d24701 100644 --- a/parallel/flax_pjit_tutorial.py +++ b/parallel/flax_pjit_tutorial.py @@ -242,7 +242,8 @@ def train_step(state, x): # with mesh: # not strictly necessary in this case # with mesh block is useful for explicit scope for device sharding # but mesh management is automatic via jit sharding annotations -new_state = train_step(initialized_state, x) +with mesh: + new_state = train_step(initialized_state, x) print(f'Sharding of Weight 1:') jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value) diff --git a/parallel/gpt-neo-125m.json b/parallel/gpt-neo-125m.json new file mode 100644 index 0000000..79b035e --- /dev/null +++ b/parallel/gpt-neo-125m.json @@ -0,0 +1,854 @@ +{ + "transformer": { + "h": { + "0": { + "attn": { + "attention": { + "k_proj": { + "kernel": [ + 768, + 768 + ] + }, + "out_proj": { + "bias": [ + 768 + ], + "kernel": [ + 768, + 768 + ] + }, + "q_proj": { + "kernel": [ + 768, + 768 + ] + }, + "v_proj": { + "kernel": [ + 768, + 768 + ] + } + } + }, + "ln_1": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "ln_2": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "mlp": { + "c_fc": { + "bias": [ + 3072 + ], + "kernel": [ + 768, + 3072 + ] + }, + "c_proj": { + "bias": [ + 768 + ], + "kernel": [ + 3072, + 768 + ] + } + } + }, + "1": { + "attn": { + "attention": { + "k_proj": { + "kernel": [ + 768, + 768 + ] + }, + "out_proj": { + "bias": [ + 768 + ], + "kernel": [ + 768, + 768 + ] + }, + "q_proj": { + "kernel": [ + 768, + 768 + ] + }, + "v_proj": { + "kernel": [ + 768, + 768 + ] + } + } + }, + "ln_1": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "ln_2": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "mlp": { + "c_fc": { + "bias": [ + 3072 + ], + "kernel": [ + 768, + 3072 + ] + }, + "c_proj": { + "bias": [ + 768 + ], + "kernel": [ + 3072, + 768 + ] + } + } + 
}, + "10": { + "attn": { + "attention": { + "k_proj": { + "kernel": [ + 768, + 768 + ] + }, + "out_proj": { + "bias": [ + 768 + ], + "kernel": [ + 768, + 768 + ] + }, + "q_proj": { + "kernel": [ + 768, + 768 + ] + }, + "v_proj": { + "kernel": [ + 768, + 768 + ] + } + } + }, + "ln_1": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "ln_2": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "mlp": { + "c_fc": { + "bias": [ + 3072 + ], + "kernel": [ + 768, + 3072 + ] + }, + "c_proj": { + "bias": [ + 768 + ], + "kernel": [ + 3072, + 768 + ] + } + } + }, + "11": { + "attn": { + "attention": { + "k_proj": { + "kernel": [ + 768, + 768 + ] + }, + "out_proj": { + "bias": [ + 768 + ], + "kernel": [ + 768, + 768 + ] + }, + "q_proj": { + "kernel": [ + 768, + 768 + ] + }, + "v_proj": { + "kernel": [ + 768, + 768 + ] + } + } + }, + "ln_1": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "ln_2": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "mlp": { + "c_fc": { + "bias": [ + 3072 + ], + "kernel": [ + 768, + 3072 + ] + }, + "c_proj": { + "bias": [ + 768 + ], + "kernel": [ + 3072, + 768 + ] + } + } + }, + "2": { + "attn": { + "attention": { + "k_proj": { + "kernel": [ + 768, + 768 + ] + }, + "out_proj": { + "bias": [ + 768 + ], + "kernel": [ + 768, + 768 + ] + }, + "q_proj": { + "kernel": [ + 768, + 768 + ] + }, + "v_proj": { + "kernel": [ + 768, + 768 + ] + } + } + }, + "ln_1": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "ln_2": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "mlp": { + "c_fc": { + "bias": [ + 3072 + ], + "kernel": [ + 768, + 3072 + ] + }, + "c_proj": { + "bias": [ + 768 + ], + "kernel": [ + 3072, + 768 + ] + } + } + }, + "3": { + "attn": { + "attention": { + "k_proj": { + "kernel": [ + 768, + 768 + ] + }, + "out_proj": { + "bias": [ + 768 + ], + "kernel": [ + 768, + 768 + ] + }, + "q_proj": { + "kernel": [ + 768, + 768 + ] + }, + "v_proj": { + "kernel": [ + 768, + 768 + ] + } + } + }, + "ln_1": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "ln_2": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "mlp": { + "c_fc": { + "bias": [ + 3072 + ], + "kernel": [ + 768, + 3072 + ] + }, + "c_proj": { + "bias": [ + 768 + ], + "kernel": [ + 3072, + 768 + ] + } + } + }, + "4": { + "attn": { + "attention": { + "k_proj": { + "kernel": [ + 768, + 768 + ] + }, + "out_proj": { + "bias": [ + 768 + ], + "kernel": [ + 768, + 768 + ] + }, + "q_proj": { + "kernel": [ + 768, + 768 + ] + }, + "v_proj": { + "kernel": [ + 768, + 768 + ] + } + } + }, + "ln_1": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "ln_2": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "mlp": { + "c_fc": { + "bias": [ + 3072 + ], + "kernel": [ + 768, + 3072 + ] + }, + "c_proj": { + "bias": [ + 768 + ], + "kernel": [ + 3072, + 768 + ] + } + } + }, + "5": { + "attn": { + "attention": { + "k_proj": { + "kernel": [ + 768, + 768 + ] + }, + "out_proj": { + "bias": [ + 768 + ], + "kernel": [ + 768, + 768 + ] + }, + "q_proj": { + "kernel": [ + 768, + 768 + ] + }, + "v_proj": { + "kernel": [ + 768, + 768 + ] + } + } + }, + "ln_1": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "ln_2": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "mlp": { + "c_fc": { + "bias": [ + 3072 + ], + "kernel": [ + 768, + 3072 + ] + }, + "c_proj": { + "bias": [ + 768 + ], + "kernel": [ + 3072, + 768 + ] + } + } + }, + "6": { + "attn": { + "attention": { + "k_proj": { + "kernel": [ + 768, + 768 + ] + }, + "out_proj": { + "bias": [ + 768 + ], + "kernel": [ + 768, + 768 + ] + }, + 
"q_proj": { + "kernel": [ + 768, + 768 + ] + }, + "v_proj": { + "kernel": [ + 768, + 768 + ] + } + } + }, + "ln_1": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "ln_2": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "mlp": { + "c_fc": { + "bias": [ + 3072 + ], + "kernel": [ + 768, + 3072 + ] + }, + "c_proj": { + "bias": [ + 768 + ], + "kernel": [ + 3072, + 768 + ] + } + } + }, + "7": { + "attn": { + "attention": { + "k_proj": { + "kernel": [ + 768, + 768 + ] + }, + "out_proj": { + "bias": [ + 768 + ], + "kernel": [ + 768, + 768 + ] + }, + "q_proj": { + "kernel": [ + 768, + 768 + ] + }, + "v_proj": { + "kernel": [ + 768, + 768 + ] + } + } + }, + "ln_1": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "ln_2": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "mlp": { + "c_fc": { + "bias": [ + 3072 + ], + "kernel": [ + 768, + 3072 + ] + }, + "c_proj": { + "bias": [ + 768 + ], + "kernel": [ + 3072, + 768 + ] + } + } + }, + "8": { + "attn": { + "attention": { + "k_proj": { + "kernel": [ + 768, + 768 + ] + }, + "out_proj": { + "bias": [ + 768 + ], + "kernel": [ + 768, + 768 + ] + }, + "q_proj": { + "kernel": [ + 768, + 768 + ] + }, + "v_proj": { + "kernel": [ + 768, + 768 + ] + } + } + }, + "ln_1": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "ln_2": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "mlp": { + "c_fc": { + "bias": [ + 3072 + ], + "kernel": [ + 768, + 3072 + ] + }, + "c_proj": { + "bias": [ + 768 + ], + "kernel": [ + 3072, + 768 + ] + } + } + }, + "9": { + "attn": { + "attention": { + "k_proj": { + "kernel": [ + 768, + 768 + ] + }, + "out_proj": { + "bias": [ + 768 + ], + "kernel": [ + 768, + 768 + ] + }, + "q_proj": { + "kernel": [ + 768, + 768 + ] + }, + "v_proj": { + "kernel": [ + 768, + 768 + ] + } + } + }, + "ln_1": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "ln_2": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "mlp": { + "c_fc": { + "bias": [ + 3072 + ], + "kernel": [ + 768, + 3072 + ] + }, + "c_proj": { + "bias": [ + 768 + ], + "kernel": [ + 3072, + 768 + ] + } + } + } + }, + "ln_f": { + "bias": [ + 768 + ], + "scale": [ + 768 + ] + }, + "wpe": { + "embedding": [ + 2048, + 768 + ] + }, + "wte": { + "embedding": [ + 50257, + 768 + ] + } + } +} \ No newline at end of file diff --git a/parallel/gptneo_partition_test.py b/parallel/gptneo_partition_test.py new file mode 100644 index 0000000..a3f0090 --- /dev/null +++ b/parallel/gptneo_partition_test.py @@ -0,0 +1,24 @@ +# %% +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import freeze, unfreeze +from partitions import set_partitions + +from transformers import FlaxAutoModelForCausalLM +# this inits the model directly +model = FlaxAutoModelForCausalLM.from_pretrained( + "gpt-neo-125m", +) +params = model.params + +# %% +import json +shape_dict = jax.tree.map(jnp.shape, params) +# print(json.dumps(shape_dict, sort_keys=True, indent=4)) +with open('gpt-neo-125m.json', 'w') as f: + json.dump(shape_dict, fp=f, sort_keys=True, indent=2) + +# %% +param_spec = set_partitions(unfreeze(params)) + +# %% diff --git a/parallel/partitions.py b/parallel/partitions.py new file mode 100644 index 0000000..40f9371 --- /dev/null +++ b/parallel/partitions.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The Google Research Authors and The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for constructing PyTrees of PartitionSpecs.""" + +# utils adapted from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py + +import re + +from flax.core.frozen_dict import freeze +from flax.traverse_util import flatten_dict, unflatten_dict +from jax.sharding import PartitionSpec as P + + +# Sentinels +_unmatched = object() + +# For specifying empty leaf dict `{}` +empty_dict = object() + + +def _match(qs, ks): + """Return True if regexes in qs match any window of strings in tuple ks.""" + # compile regexes and force complete match + qts = tuple((re.compile(x + "$") for x in qs)) + for i in range(len(ks) - len(qs) + 1): + matches = [x.match(y) for x, y in zip(qts, ks[i:])] + if matches and all(matches): + return True + return False + + +def _replacement_rules(rules): + def replace(key, val): + for rule, replacement in rules: + if _match(rule, key): + return replacement + return val + + return replace + + +# PartitionSpec for GPTNeo +# replicate the hidden dim and shard feed-forward and head dim +# def _get_partition_rules(): +# return [ +# # embeddings +# (("transformer", "wpe", "embedding"), P("mp", None)), +# (("transformer", "wte", "embedding"), P("mp", None)), +# # atention +# (("attention", "(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")), +# (("attention", "out_proj", "kernel"), P("mp", None)), +# (("attention", "out_proj", "bias"), None), +# # mlp +# (("mlp", "c_fc", "kernel"), P(None, "mp")), +# (("mlp", "c_fc", "bias"), P("mp")), +# (("mlp", "c_proj", "kernel"), P("mp", None)), +# (("mlp", "c_proj", "bias"), None), +# # layer norms +# ((r"ln_\d+", "bias"), None), +# ((r"\d+", r"ln_\d+", "scale"), None), +# (("ln_f", "bias"), None), +# (("ln_f", "scale"), None), +# ] + +def _get_partition_rules(): + return [ + # embedding + (("shared", "embedding"), P("model", None)), + # SelfAttention + (("SelfAttention", "(q|k|v)", "kernel"), P(None, "model")), + (("SelfAttention", "o", "kernel"), P("model", None)), + (("SelfAttention", "relative_attention_bias", "embedding"), P(None)), + # EncDecAttention + (("EncDecAttention", "(q|k|v)", "kernel"), P(None, "model")), + (("EncDecAttention", "o", "kernel"), P("model", None)), + # DenseReluDense + (("DenseReluDense", "wi", "kernel"), P(None, "model")), + (("DenseReluDense", "wo", "kernel"), P("model", None)), + # layer norms + (("final_layer_norm", "weight"), P(None)), + (("layer_norm", "weight"), P(None)), + ] + + +def set_partitions(in_dict): + rules = _get_partition_rules() + replace = _replacement_rules(rules) + initd = {k: _unmatched for k in flatten_dict(in_dict)} + result = {k: replace(k, v) for k, v in initd.items()} + assert _unmatched not in result.values(), "Incomplete partition spec." 
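# How the rules above resolve a flattened key (illustrative note, not part of
# the original patch): flatten_dict produces tuple keys such as
#   ("encoder", "block", "0", "layer", "0", "SelfAttention", "q", "kernel");
# _match slides each rule's end-anchored regexes over that tuple, so the rule
#   (("SelfAttention", "(q|k|v)", "kernel"), P(None, "model"))
# maps every q/k/v kernel to PartitionSpec(None, "model"). Any key that no
# rule matches keeps the _unmatched sentinel and trips the assert above.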
+ # for item in result.values(): + # if item: + # print(item) + + return freeze(unflatten_dict(result)) \ No newline at end of file diff --git a/parallel/t5.json b/parallel/t5.json new file mode 100644 index 0000000..362d1af --- /dev/null +++ b/parallel/t5.json @@ -0,0 +1,1826 @@ +{ + "decoder": { + "block": { + "0": { + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "relative_attention_bias": { + "embedding": [ + 32, + 12 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "EncDecAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "2": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + }, + "1": { + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "EncDecAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "2": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + }, + "10": { + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "EncDecAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "2": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + }, + "11": { + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "EncDecAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "2": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + }, + "2": { + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + 
"kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "EncDecAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "2": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + }, + "3": { + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "EncDecAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "2": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + }, + "4": { + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "EncDecAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "2": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + }, + "5": { + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "EncDecAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "2": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + }, + "6": { + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "EncDecAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "2": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + }, + "7": { + "layer": { + "0": { + 
"SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "EncDecAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "2": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + }, + "8": { + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "EncDecAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "2": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + }, + "9": { + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "EncDecAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "2": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + } + }, + "final_layer_norm": { + "weight": [ + 768 + ] + } + }, + "encoder": { + "block": { + "0": { + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "relative_attention_bias": { + "embedding": [ + 32, + 12 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + }, + "1": { + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + }, + "10": { + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { 
+ "weight": [ + 768 + ] + } + }, + "1": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + }, + "11": { + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + }, + "2": { + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + }, + "3": { + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + }, + "4": { + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + }, + "5": { + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + }, + "6": { + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + }, + "7": { + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + }, + "8": 
{ + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + }, + "9": { + "layer": { + "0": { + "SelfAttention": { + "k": { + "kernel": [ + 768, + 768 + ] + }, + "o": { + "kernel": [ + 768, + 768 + ] + }, + "q": { + "kernel": [ + 768, + 768 + ] + }, + "v": { + "kernel": [ + 768, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + }, + "1": { + "DenseReluDense": { + "wi": { + "kernel": [ + 768, + 3072 + ] + }, + "wo": { + "kernel": [ + 3072, + 768 + ] + } + }, + "layer_norm": { + "weight": [ + 768 + ] + } + } + } + } + }, + "final_layer_norm": { + "weight": [ + 768 + ] + } + }, + "shared": { + "embedding": [ + 32128, + 768 + ] + } +} \ No newline at end of file diff --git a/parallel/t5_pjit.py b/parallel/t5_pjit.py new file mode 100644 index 0000000..275bdf2 --- /dev/null +++ b/parallel/t5_pjit.py @@ -0,0 +1,651 @@ +# MARK: START +# %% +# let's make 8-device simulator +import sys +sys.dont_write_bytecode = True +import os + +# Set this to True to run the model on CPU only. +USE_CPU_ONLY = True + +flags = os.environ.get("XLA_FLAGS", "") +if USE_CPU_ONLY: + flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices + # Enforce CPU-only execution + os.environ["CUDA_VISIBLE_DEVICES"] = "" + os.environ["JAX_PLATFORMS"] = "cpu" +else: + # GPU flags + flags += ( + "--xla_gpu_enable_triton_softmax_fusion=true " + "--xla_gpu_triton_gemm_any=false " + "--xla_gpu_enable_async_collectives=true " + "--xla_gpu_enable_latency_hiding_scheduler=true " + "--xla_gpu_enable_highest_priority_async_stream=true " + ) +os.environ["XLA_FLAGS"] = flags + +import functools +from functools import partial +from pprint import pprint +from typing import Any, Dict, Tuple, Callable, Sequence + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, NamedSharding +# from jax.experimental.pjit import pjit # superseded by jax.jit +from jax.experimental import mesh_utils +from jax.sharding import PartitionSpec +from ml_collections import ConfigDict +import optax +import logging +import time +from datasets import Dataset, load_from_disk + +from flax import jax_utils, traverse_util +from flax.jax_utils import pad_shard_unpad, unreplicate +from flax.training import train_state +from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key +from flax.core.frozen_dict import freeze, unfreeze +import flax.core + +from partitions import set_partitions + +from tqdm import tqdm + +from dataload import DataPrepare + + +PyTree = Any +Metrics = Dict[str, Tuple[jax.Array, ...]] + +if USE_CPU_ONLY: + jax.config.update('jax_platform_name', 'cpu') +else: + jax.config.update("jax_default_matmul_precision", "bfloat16") + + +# %% +# get platform type +from jax.lib import xla_bridge +print(xla_bridge.get_backend().platform) + +# %% +# config options +file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval' +save_path = 't5_80_1_bf16' +# file_path = 'combined_data' +split_datasets = load_from_disk(file_path) +training_size = len(split_datasets['train']) +# Store some 
constant +seed = 117 +num_epochs = 5 +batch_size = 2 # 384 is the best +num_train_epochs = num_epochs +per_device_train_batch_size = batch_size +train_batch_size = per_device_train_batch_size * jax.device_count() +per_device_eval_batch_size = batch_size +eval_batch_size = per_device_eval_batch_size * jax.device_count() +steps_per_epoch = training_size // train_batch_size +total_train_steps = steps_per_epoch * num_epochs + +warmup_steps = 0 +learning_rate = 2e-5 + +weight_decay = 0.01 +adam_beta1 = 0.9 +adam_beta2 = 0.999 +adam_epsilon = 1e-8 +label_smoothing_factor = 0.0 + +num_beams = 1 +val_max_target_length = 128 + +predict_with_generate = True + + +# %% +# prepare data +# init object +# e.g. Config +data_config = ConfigDict( + dict( + max_length=86, + pad_token_id=0, + decoder_start_token_id=0 + ) +) + +dataprep = DataPrepare(split_datasets['train'], data_config) +# # example usage +# # %% +seed = 117 +rng = jax.random.PRNGKey(seed) +train_loader = dataprep.data_loader(rng, batch_size=batch_size) +batch = next(iter(train_loader)) +# batch + +# %% +# model + + +from t5_model.pure_t5 import FlaxT5ForConditionalGenerationModule as model_init +# from t5_model.pure_t5 import FlaxT5DenseActDense as model_init +from t5_model.pure_t5 import make_config +config = make_config() +model = model_init(config) + +# %% +from transformers import FlaxT5ForConditionalGeneration +from transformers import T5Config +model, params = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=False) + + + +# useful for transformer model +# model.enable_gradient_checkpointing() + +# enable bf16 except for layer_norm +# from flax import traverse_util +# flat_params = traverse_util.flatten_dict(model.params) +# mask = { +# path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params +# } +# mask = traverse_util.unflatten_dict(mask) +# model.params = model.to_bf16(model.params, mask) + + +################################################################## +# set partition on model + +# %% +# # let's output the model parameters to a json file for study +# import json +# shape_dict = jax.tree.map(jnp.shape, params) +# # print(json.dumps(shape_dict, sort_keys=True, indent=4)) +# with open('t5.json', 'w') as f: +# json.dump(shape_dict, fp=f, sort_keys=True, indent=2) + +# MARK: setup mesh +# %% +device_mesh = mesh_utils.create_device_mesh((2,2)) +print(device_mesh) + +mesh = Mesh(devices=device_mesh, axis_names=('data', 'model')) +print(mesh) + +def mesh_sharding(pspec: PartitionSpec) -> NamedSharding: + return NamedSharding(mesh, pspec) + + +################################################## +# optimizers +# %% + +def create_learning_rate_fn( + train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.ndarray]: + """Returns a linear warmup, linear_decay learning rate function.""" + steps_per_epoch = train_ds_size // train_batch_size + num_train_steps = steps_per_epoch * num_train_epochs + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +# Create learning rate schedule +linear_decay_lr_schedule_fn = create_learning_rate_fn( + training_size, + train_batch_size, + num_train_epochs, + warmup_steps, 
+ learning_rate, +) + +# We use Optax's "masking" functionality to not apply weight decay +# to bias and LayerNorm scale parameters. decay_mask_fn returns a +# mask boolean with the same structure as the parameters. +# The mask is True for parameters that should be decayed. +def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + # find out all LayerNorm parameters + layer_norm_candidates = ["layernorm", "layer_norm", "ln"] + layer_norm_named_params = { + layer[-2:] + for layer_norm_name in layer_norm_candidates + for layer in flat_params.keys() + if layer_norm_name in "".join(layer).lower() + } + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) + +# create adam optimizer +adamw = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=adam_beta1, + b2=adam_beta2, + eps=adam_epsilon, + weight_decay=weight_decay, + mask=decay_mask_fn, +) + + +# %% +# specify sharding + +# shard data +x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis +batch = {key: jax.device_put(jnp.array(value), x_sharding) for key, value in batch.items()} +# Defining the required dimensions for the self-attention layer input +# batch_size = 2 +# seq_length = 768 +# n_heads = 12 +# head_dim = 768 +# %% + +# Create a large array with the shape (batch_size, seq_length, n_heads, head_dim) +# large_input = np.random.rand(2,768,768) +# batch = jax.device_put(large_input, x_sharding) + +# %% +# jax.debug.visualize_array_sharding(batch['input_ids']) + +# %% +# shard output +# we will shard state by tracking its output upon jax.eval_shape after init +# define an init function to return a TrainState +# def init_fn(rng, batch, model, optimizer): +# # do be careful with the model init +# # imported models might have complicated init methods +# variables = model.init(rng, +# input_ids=batch['input_ids'], +# attention_mask=batch['attention_mask'], +# decoder_input_ids=batch['decoder_attention_mask'], +# decoder_attention_mask=batch['decoder_attention_mask'] +# ) +# state = train_state.TrainState.create( # Create a `TrainState`. +# apply_fn=model.apply, +# params=variables['params'], +# tx=optimizer) +# return state + + +def init_fn(rng, batch, model, optimizer): + # do be careful with the model init + # imported models might have complicated init methods + variables = model.init( + rng, + input_ids=batch['input_ids'], + attention_mask=batch['attention_mask'], + decoder_input_ids=batch['decoder_attention_mask'], + decoder_attention_mask=batch['decoder_attention_mask'] + ) + state = train_state.TrainState.create( # Create a `TrainState`. + apply_fn=model.apply, + params=variables['params'], + tx=optimizer) + return state + +# %% +# alternative +# def init_fn(rng, batch, model, optimizer): +# # do be careful with the model init +# # imported models might have complicated init methods +# variables = model.init( +# rng, batch +# ) +# state = train_state.TrainState.create( # Create a `TrainState`. +# apply_fn=model.apply, +# params=variables['params'], +# tx=optimizer) +# return state + + +# %% +# Create an abstract closure to wrap the function before feeding it in +# because `jax.eval_shape` only takes pytrees as arguments. 
+# eval_shape(fn, rng_key, x) +# used to perform shape inference +# returns a nested PyTree containing jax.ShapeDtypeStruct objects as leaves +rng, init_rng = jax.random.split(rng) +abstract_variables = jax.eval_shape( + functools.partial(init_fn, model=model, optimizer=adamw), init_rng, batch) + + +# %% +# This `state_sharding` has the same pytree structure as `state`, the output +# of the `init_fn`. +# flan.linen.get_sharding +# extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh +# jax.sharding: describes how a jax.Array is laid out across devices +state_sharding = nn.get_sharding(abstract_variables, mesh) +print(state_sharding) + +# warning: do not have singleton None in your nn.partition definitions, it will screw with your sanity + + +# %% +jit_init_fn = jax.jit( + init_fn, + static_argnames=('model', 'optimizer'), # skip model and optimizer + in_shardings=(mesh_sharding(()), x_sharding), # for PRNG key and data + out_shardings=state_sharding +) + + +rng, init_rng = jax.random.split(rng) +initialized_state = jit_init_fn(rng, batch, model, adamw) + +# %% +# we can analyze the params structure +# for weight, partitioned in initialized_state.params['decoder'].items(): +# print(f'Sharding of {weight}: {partitioned}') +# jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value) +# jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value) +jax.tree.map(jnp.shape, initialized_state.params['decoder']) + + +# %% +print(initialized_state.params['decoder']['block']['0']['layer']['0']['SelfAttention']['k']['kernel'].value.sharding) +print(initialized_state.step) +print(initialized_state.step.sharding) + + +# %% +# train step +def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): + """ + The label smoothing implementation is adapted from Flax's official example: + https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104 + """ + vocab_size = logits.shape[-1] + confidence = 1.0 - label_smoothing_factor + low_confidence = (1.0 - confidence) / (vocab_size - 1) + normalizing_constant = -( + confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) + ) + soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence) + + loss = optax.softmax_cross_entropy(logits, soft_labels) + loss = loss - normalizing_constant + + # ignore padded tokens from loss + loss = loss * padding_mask + loss = loss.sum() + num_labels = padding_mask.sum() + return loss, num_labels + +# %% + +# single device code annotated with jax.jit +@functools.partial( + jax.jit, + # in_shardings=(state_sharding, x_sharding), + out_shardings=state_sharding +) +def train_step(state, batch): + label_smoothing_factor=0.0 + # dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) + + def compute_loss(params): + labels = batch.pop("labels") + logits = state.apply_fn( + {'params': params}, + input_ids=batch['input_ids'], + attention_mask=batch['attention_mask'], + decoder_input_ids=batch['decoder_attention_mask'], + decoder_attention_mask=batch['decoder_attention_mask'], + )[0] + loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) + return loss, num_labels + + # compute gradients through computational graph + # allow values to pass through + grad_fn = jax.value_and_grad(compute_loss, has_aux=True) + (loss, num_labels), grad = grad_fn(state.params) + # 
num_labels = jax.lax.psum(num_labels, "batch") + + + # true grad = total grad / total samples + # grad = jax.lax.psum(grad, "batch") + # grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad) + new_state = state.apply_gradients(grads=grad) + + # metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + return new_state + + + +# %% + +# variables = model.init( +# rng, +# input_ids=batch['input_ids'], +# attention_mask=batch['attention_mask'], +# decoder_input_ids=batch['decoder_attention_mask'], +# decoder_attention_mask=batch['decoder_attention_mask'] +# ) +# x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis +# batch = {key: jax.device_put(jnp.array(value), x_sharding) for key, value in batch.items()} + +with mesh: + new_state = train_step(initialized_state, batch) + + + +# %% + + + + + +# # %% +# ############################################################# +# # we cannot integrate our model pspec with train_state +# # we just shard separately +# # update: we also cannot use the method of modifying a partitionspec tree +# # we have to do it the RIGHT way, following flax_pjit_tutorial to the letter +# +# # %% +# def get_optim_initial_state(params): +# params = params +# state = adamw.init(params) +# return tuple((state)), params +# +# # %% +# # create partitions for model +# from partitions import set_partitions +# # set_partitions freezes the params on return +# model_param_spec = set_partitions(unfreeze(params)) +# +# # %% +# params_shapes = jax.tree.map(lambda x: x.shape, params) +# # actually tuple +# optim_state_shapes = jax.eval_shape(get_optim_initial_state, params_shapes) +# +# # %% +# # get pspec for opt_state +# def get_opt_spec(x): +# if isinstance(x, dict): +# return unfreeze(model_param_spec) +# return PartitionSpec() +# +# # this function replaces the empty model params spec with the 'model_param_spec' +# opt_state_spec, param_spec = jax.tree.map( +# get_opt_spec, optim_state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)) +# ) +# +# # %% +# +# model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=True) +# # store on cpu +# model.params = jax.tree_util.tree_map(lambda x: np.asarray(x), model.params) +# +# # %% +# device_mesh = mesh_utils.create_device_mesh((2,2)) +# print(device_mesh) +# +# mesh = Mesh(devices=device_mesh, axis_names=('data', 'model')) +# print(mesh) +# +# def mesh_sharding(pspec: PartitionSpec) -> NamedSharding: +# return NamedSharding(mesh, pspec) +# +# +# # opt_state_sharding = mesh_sharding(opt_state_spec) +# # param_sharding = mesh_sharding(param_spec) +# +# +# # %% +# opt_state_sharding = nn.get_sharding(opt_state_spec, mesh) +# param_sharding = nn.get_sharding(param_spec, mesh) +# +# # %% +# # jit the get_initial_state function to shard params and init optimizer state in +# # a sharded way +# from jax.experimental.pjit import pjit +# +# with mesh: +# p_get_initial_state = pjit( +# get_optim_initial_state, +# in_shardings=None, +# out_shardings=(opt_state_spec, param_spec), +# ) +# +# # Convert your PartitionSpec to NamedSharding for model params +# param_sharding = NamedSharding(mesh, freeze(param_spec)) +# # Use device_put with sharding to move params onto the mesh +# sharded_params = jax.device_put(freeze(params), param_sharding) +# +# with mesh: +# # params is already frozen +# sharded_opt_state, sharded_params = p_get_initial_state(unfreeze(sharded_params)) +# +# # %% +# +# # give up this section +# 
############################################################# +# # create train state +# +# # %% +# # Initialize random key and input for initialization +# rng = jax.random.PRNGKey(seed) +# loader_rng, rng = jax.random.split(rng) +# train_loader = dataprep.data_loader(rng, batch_size=2) +# batch = next(iter(train_loader)) +# +# # use the T5 base model to do this +# from transformers import FlaxAutoModel +# model, params = FlaxAutoModel.from_pretrained( +# 't5-base', +# _do_init=False +# ) +# t5_module = model.module +# +# # %% +# init_rng, rng = jax.random.split(rng) +# variables = t5_module.init(init_rng, +# input_ids=batch['input_ids'], +# attention_mask=batch['attention_mask'], +# decoder_input_ids=batch['decoder_attention_mask'], +# decoder_attention_mask=batch['decoder_attention_mask'] +# ) +# params = variables['params'] +# +# # create an init function +# # %% +# # we will shard state by tracking its output upon jax.eval_shape after init +# # define an init function to return a TrainState +# def init_fn(rng: jax.random.PRNGKey, batch=batch, model=t5_module, optimizer=adamw) -> train_state.TrainState: +# init_rng, rng = jax.random.split(rng) +# variables = model.init( +# init_rng, +# input_ids=batch['input_ids'], +# attention_mask=batch['attention_mask'], +# decoder_input_ids=batch['decoder_attention_mask'], +# decoder_attention_mask=batch['decoder_attention_mask'] +# ) +# params = variables.pop("params") +# state = train_state.TrainState.create( +# apply_fn=model.__call__, +# params=params, +# tx=optimizer, +# ) +# return state +# +# # model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=True) +# # Create an abstract closure to wrap the function before feeding it in +# # because `jax.eval_shape` only takes pytrees as arguments. 
+# # eval_shape(fn, rng_key, x) +# # used to perform shape inference +# # returns a nested PyTree containing jax.ShapeDtypeStruct objects as leaves +# init_rng, rng = jax.random.split(rng) +# abstract_variables = jax.eval_shape( +# functools.partial(init_fn, model=t5_module, optimizer=adamw), +# init_rng, +# batch +# ) +# +# # %% +# # let's make our mesh +# +# device_mesh = mesh_utils.create_device_mesh((2,2)) +# print(device_mesh) +# +# mesh = Mesh(devices=device_mesh, axis_names=('data', 'model')) +# print(mesh) +# +# def mesh_sharding(pspec: PartitionSpec) -> NamedSharding: +# return NamedSharding(mesh, pspec) +# +# # %% +# # making jax compatible batch +# +# # %% +# x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis +# # batch = jax.device_put(batch), x_sharding) +# # jax.debug.visualize_array_sharding(batch) +# +# # %% +# state_sharding = nn.get_sharding(abstract_variables, mesh) +# print(state_sharding) +# +# # %% +# # integrate model_param_specs and state_out_specs +# +# # %% +# # i want to make a Sharding object +# # model_sharding = mesh_sharding(model_param_spec) +# +# # %% +# jit_init_fn = jax.jit( +# init_fn, # rng, batch, model, optimizer +# static_argnames=('model', 'optimizer'), # skip model and optimizer +# in_shardings=(mesh_sharding(()), x_sharding), # mesh_sharding(()), mesh_sharding(())), # for PRNG key and data +# out_shardings=state_sharding +# ) +# +# # %% +# +# init_rng, rng = jax.random.split(rng) +# initialized_state = jit_init_fn( +# init_rng, +# batch, +# t5_module, +# adamw) +# +# # jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value) +# # jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value) +# +# +# # %% +# +# # %% +# +# %% diff --git a/t5_jax.py b/t5_jax.py index 765e1cf..def53cb 100644 --- a/t5_jax.py +++ b/t5_jax.py @@ -46,11 +46,13 @@ from datasets import load_from_disk import nltk # Here to have a nice missing dependency error message early on +from typing import Dict, Any, Union from flax import jax_utils, traverse_util from flax.jax_utils import pad_shard_unpad, unreplicate from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key +from flax.core.frozen_dict import FrozenDict, unfreeze import flax.core @@ -60,9 +62,20 @@ import time # %% import os -os.environ['XLA_FLAGS'] = ( - '--xla_gpu_triton_gemm_any=true --xla_gpu_enable_custom_fusions=true --xla_gpu_enable_address_computation_fusion=true' +# os.environ['XLA_FLAGS'] = ( + # '--xla_gpu_triton_gemm_any=true ' + # '--xla_gpu_enable_custom_fusions=true ' + # '--xla_gpu_enable_address_computation_fusion=true' +# ) +flags = ( + '--xla_gpu_enable_triton_softmax_fusion=true ' + '--xla_gpu_triton_gemm_any=True ' + # '--xla_gpu_enable_async_collectives=true ' + '--xla_gpu_enable_latency_hiding_scheduler=true ' + '--xla_gpu_enable_highest_priority_async_stream=true ' ) +os.environ["XLA_FLAGS"] = flags + os.environ.update({ "TOKENIZERS_PARALLELISM" : "false", @@ -76,8 +89,8 @@ os.environ.update({ # %% # get platform type -from jax.lib import xla_bridge -print(xla_bridge.get_backend().platform) +from jax.extend.backend import get_backend +print(get_backend().platform) # %% @@ -90,14 +103,14 @@ except (LookupError, OSError): # %% # config options file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval' -save_path = 't5_80_1_bf16' +save_path = 't5_5e_1_pmap' # file_path = 'combined_data' split_datasets = 
load_from_disk(file_path) training_size = len(split_datasets['train']) # Store some constant seed = 117 num_epochs = 5 -batch_size = 384 # 384 is the best +batch_size = 32 # 384 is the best num_train_epochs = num_epochs per_device_train_batch_size = batch_size train_batch_size = per_device_train_batch_size * jax.device_count() @@ -116,7 +129,7 @@ adam_epsilon = 1e-8 label_smoothing_factor = 0.0 num_beams = 1 -val_max_target_length = 128 +val_max_target_length = 86 predict_with_generate = True @@ -126,7 +139,7 @@ predict_with_generate = True from transformers import T5TokenizerFast tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True) # Define additional special tokens -additional_special_tokens = ["", "", "", "", "", "", "SIG", "UNIT", "DATA_TYPE"] +additional_special_tokens = ["", "", "", "", "", "", "", "", ""] # Add the additional special tokens to the tokenizer tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens}) @@ -172,15 +185,43 @@ from flax import traverse_util model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base") # useful for transformer model -model.enable_gradient_checkpointing() +# model.enable_gradient_checkpointing() # enable bf16 except for layer_norm -flat_params = traverse_util.flatten_dict(model.params) -mask = { - path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params -} -mask = traverse_util.unflatten_dict(mask) -model.params = model.to_bf16(model.params, mask) +# flat_params = traverse_util.flatten_dict(model.params) +# mask = { +# path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params +# } +# mask = traverse_util.unflatten_dict(mask) +# # borrowed from transformers modeling_flax_utils +# def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: +# """ +# Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. 
+# """ +# +# # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 +# def conditional_cast(param): +# if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): +# param = param.astype(dtype) +# return param +# +# if mask is None: +# return jax.tree_util.tree_map(conditional_cast, params) +# +# flat_params = traverse_util.flatten_dict(params) +# flat_mask, _ = jax.tree_util.tree_flatten(mask) +# +# for masked, key in zip(flat_mask, sorted(flat_params.keys())): +# if masked: +# flat_params[key] = conditional_cast(flat_params[key]) +# +# return traverse_util.unflatten_dict(flat_params) +# +# # Cast parameters to bfloat16 if desired +# # params = jax.tree.tree_map(lambda x: x.astype(jnp.bfloat16), params) +# # instead of casting the whole thing, we cast only certain parts of the tree +# params = cast_floating_to(model.params, jnp.bfloat16, mask) + # %% # # Function to extract shape and dtype without showing values @@ -527,10 +568,12 @@ for epoch in epochs: # jax.profiler.stop_trace() # %% -# output_dir = save_path -# # save checkpoint after each epoch and push checkpoint to the hub -# if jax.process_index() == 0: -# params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) -# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params) -# model.save_pretrained(output_dir, params=params) -# tokenizer.save_pretrained(output_dir) +output_dir = save_path +# save checkpoint after each epoch and push checkpoint to the hub +if jax.process_index() == 0: + params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) + params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params) + model.save_pretrained(output_dir, params=params) + tokenizer.save_pretrained(output_dir) + +# %% diff --git a/t5_jax_parallel.py b/t5_jax_parallel.py new file mode 100644 index 0000000..fb92073 --- /dev/null +++ b/t5_jax_parallel.py @@ -0,0 +1,697 @@ +# %% +import os +# Set this to True to run the model on CPU only. 
+USE_CPU_ONLY = False + +flags = os.environ.get("XLA_FLAGS", "") +if USE_CPU_ONLY: + flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices + # Enforce CPU-only execution + os.environ["CUDA_VISIBLE_DEVICES"] = "" + os.environ["JAX_PLATFORMS"] = "cpu" +else: + # GPU flags + flags = ( + '--xla_gpu_enable_triton_softmax_fusion=true ' + '--xla_gpu_triton_gemm_any=True ' + # '--xla_gpu_enable_async_collectives=true ' + '--xla_gpu_enable_latency_hiding_scheduler=true ' + '--xla_gpu_enable_highest_priority_async_stream=true ' + ) + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +os.environ["XLA_FLAGS"] = flags +os.environ.update({ + "TOKENIZERS_PARALLELISM" : "false", + "CUDA_DEVICE_MAX_CONNECTIONS" : "1", + "NCCL_LL128_BUFFSIZE": "-2", + "NCCL_LL_BUFFSIZE": "-2", + "NCCL_PROTO": "SIMPLE,LL,LL128", + "XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.90", + # "XLA_PYTHON_CLIENT_PREALLOCATE" : "false" + }) + + + + +import functools +from functools import partial +from pprint import pprint +from typing import Any, Dict, Tuple, Callable, Sequence, Dict, Union + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, NamedSharding +# from jax.experimental.pjit import pjit # superseded by jax.jit +from jax.experimental import mesh_utils +from jax.sharding import PartitionSpec +from ml_collections import ConfigDict +import optax +import logging +import time +from datasets import Dataset, load_from_disk + +from flax import jax_utils, traverse_util +from flax.jax_utils import pad_shard_unpad, unreplicate +from flax.training import train_state +from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key +from flax.core.frozen_dict import freeze, unfreeze, FrozenDict +import flax.core + +# model checkpointing and saving utilities +from flax import linen as nn +from flax.training import checkpoints, train_state +from flax import struct, serialization +import orbax.checkpoint as ocp +from flax.training import orbax_utils + +from parallel.partitions import set_partitions + +from tqdm import tqdm + +from parallel.dataload import DataPrepare + + +PyTree = Any +Metrics = Dict[str, Tuple[jax.Array, ...]] + +if USE_CPU_ONLY: + jax.config.update('jax_platform_name', 'cpu') +else: + jax.config.update("jax_default_matmul_precision", "bfloat16") + +jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") +jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) +jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) + + + + +# %% +## get platform type +from jax.extend.backend import get_backend +print(get_backend().platform) +print(jax.devices()) + +# %% +# config options +file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/' +save_path = '/home/richard/Projects/06_research/jax_models/t5_80e_fp32_parallel/' +# file_path = 'combined_data' +split_datasets = load_from_disk(file_path) +training_size = len(split_datasets['train']) +# Store some constant +seed = 117 +num_epochs = 5 +batch_size = 32 +num_train_epochs = num_epochs +per_device_train_batch_size = batch_size +train_batch_size = per_device_train_batch_size * jax.device_count() +per_device_eval_batch_size = batch_size +eval_batch_size = per_device_eval_batch_size * jax.device_count() +steps_per_epoch = training_size // train_batch_size +total_train_steps = steps_per_epoch * num_epochs + +warmup_steps = 0 +learning_rate = 5e-5 + +weight_decay = 0.01 +adam_beta1 = 0.9 
+adam_beta2 = 0.999 +adam_epsilon = 1e-8 +label_smoothing_factor = 0.0 +num_beams = 1 +val_max_target_length = 128 +predict_with_generate = True + + +# %% +# prepare data +# init object +# e.g. Config +print("preparing data") +data_config = ConfigDict( + dict( + max_length=128, + pad_token_id=0, + decoder_start_token_id=0 + ) +) + +dataprep = DataPrepare(split_datasets['train'], data_config) +# # example usage +# # %% +seed = 117 +rng = jax.random.PRNGKey(seed) +train_loader = dataprep.data_loader(rng, batch_size=batch_size) +batch = next(iter(train_loader)) +# batch + +# %% +# model + +# working +# from parallel.t5_model.pure_t5 import FlaxT5ForConditionalGenerationModule as model_init +# # from t5_model.pure_t5 import FlaxT5DenseActDense as model_init +# from parallel.t5_model.pure_t5 import make_config +# config = make_config() +# model = model_init(config=config, dtype=jnp.bfloat16, gradient_checkpointing=True) + + +# %% +# from transformers import FlaxT5ForConditionalGeneration, T5Config +# model = FlaxT5ForConditionalGeneration.from_pretrained( +# "t5-base", +# dtype=jnp.bfloat16, +# ) +# # pretrained_params = model.params +# model = model.module + +# %% +# from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom +from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model +main_model = custom_model.from_pretrained( + "t5-base", + dtype=jnp.float32, + # gradient_checkpointing=True, +) +params = main_model.params +# pretrained_params = model.params +model = main_model.module + +# %% +# # testing config hashability +# # some explanation: +# # The PreTrainedModel class loads a T5Config model that is not hashable because +# # it is a complicated class that pretends to be a dataclass. +# # The solution is to extract a dict from it, then make a ConfigDict from +# # ml_collections library so that we can get values via the "." operator. 
+# # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to +# # modify the config before passing to the next layer +# from transformers import T5Config +# from t5_model.configuration_t5 import FrozenT5Config +# from ml_collections import ConfigDict, FrozenConfigDict +# +# config = T5Config.from_pretrained("t5-base").to_dict() +# config.pop('architectures') +# config.pop('id2label') +# # test if it works +# frozen_config = FrozenConfigDict(config) +# # test hash +# hash(frozen_config) + +# %% + +# %% +# # print model +# rng, input_rng = jax.random.split(rng) +# model.tabulate( +# input_rng, +# input_ids=batch['input_ids'], +# attention_mask=batch['attention_mask'], +# decoder_input_ids=batch['decoder_input_ids'], +# decoder_attention_mask=batch['decoder_attention_mask'], +# console_kwargs={"force_jupyter": True} +# ) + +# %% +# print model datatype to verify +# rng, input_rng = jax.random.split(rng) +# variables = model.init( +# input_rng, +# input_ids=batch['input_ids'], +# attention_mask=batch['attention_mask'], +# decoder_input_ids=batch['decoder_input_ids'], +# decoder_attention_mask=batch['decoder_attention_mask'] +# ) + + + + + + +# %% +# create mesh +print("creating mesh") +device_mesh = mesh_utils.create_device_mesh((1,1)) +print(device_mesh) + +mesh = Mesh(devices=device_mesh, axis_names=('data', 'model')) +print(mesh) + +def mesh_sharding(pspec: PartitionSpec) -> NamedSharding: + return NamedSharding(mesh, pspec, memory_kind="device") + +x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis +model_sharding=mesh_sharding(PartitionSpec(None, 'model')) + + +# %% +# optimizers + +def create_learning_rate_fn( + train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.ndarray]: + """Returns a linear warmup, linear_decay learning rate function.""" + steps_per_epoch = train_ds_size // train_batch_size + num_train_steps = steps_per_epoch * num_train_epochs + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +# Create learning rate schedule +linear_decay_lr_schedule_fn = create_learning_rate_fn( + training_size, + train_batch_size, + num_train_epochs, + warmup_steps, + learning_rate, +) + +# We use Optax's "masking" functionality to not apply weight decay +# to bias and LayerNorm scale parameters. decay_mask_fn returns a +# mask boolean with the same structure as the parameters. +# The mask is True for parameters that should be decayed. 
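+# A toy, self-contained sketch of how such a mask plugs into optax.adamw
+# (the parameter names below are made up purely for illustration):
+#
+# import optax, jax.numpy as jnp
+# toy_params = {
+#     "dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.ones((2,))},
+#     "layer_norm": {"weight": jnp.ones((2,))},
+# }
+# toy_mask = {
+#     "dense": {"kernel": True, "bias": False},   # decay kernels only
+#     "layer_norm": {"weight": False},            # never decay norm scales
+# }
+# tx = optax.adamw(learning_rate=1e-3, weight_decay=0.01, mask=toy_mask)
+# opt_state = tx.init(toy_params)   # mask must mirror the params tree structure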
+def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + # find out all LayerNorm parameters + layer_norm_candidates = ["layernorm", "layer_norm", "ln"] + layer_norm_named_params = { + layer[-2:] + for layer_norm_name in layer_norm_candidates + for layer in flat_params.keys() + if layer_norm_name in "".join(layer).lower() + } + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) + +# create adam optimizer +adamw = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=adam_beta1, + b2=adam_beta2, + eps=adam_epsilon, + weight_decay=weight_decay, + mask=decay_mask_fn, +) + + +print("compile") + + +# enable bf16 except for layer_norm +def create_mask_for_layer_norm(params): + flat_params = traverse_util.flatten_dict(params) + mask = { + path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params + } + mask = traverse_util.unflatten_dict(mask) + return mask + +# borrowed from transformers modeling_flax_utils +def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: + """ + Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. + """ + + # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 + def conditional_cast(param): + if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): + param = param.astype(dtype) + return param + + if mask is None: + return jax.tree_util.tree_map(conditional_cast, params) + + flat_params = traverse_util.flatten_dict(params) + flat_mask, _ = jax.tree_util.tree_flatten(mask) + + for masked, key in zip(flat_mask, sorted(flat_params.keys())): + if masked: + flat_params[key] = conditional_cast(flat_params[key]) + + return traverse_util.unflatten_dict(flat_params) + +# Cast all parameters to bfloat16 if desired +# params = jax.tree.tree_map(lambda x: x.astype(jnp.bfloat16), params) + +# %% +def init_fn(params, model, optimizer): + # do be careful with the model init + # imported models might have complicated init methods + # mask = create_mask_for_layer_norm(params) + # override params with bfloat version + # params= cast_floating_to(params, jnp.bfloat16, mask) + + state = train_state.TrainState.create( # Create a `TrainState`. + apply_fn=model.apply, + params=params, + tx=optimizer) + return state + + +# def init_fn(rng, batch, model, optimizer): +# # do be careful with the model init +# # imported models might have complicated init methods +# variables = model.init( +# rng, +# input_ids=batch['input_ids'], +# attention_mask=batch['attention_mask'], +# decoder_input_ids=batch['decoder_input_ids'], +# decoder_attention_mask=batch['decoder_attention_mask'] +# ) +# params = variables['params'] +# mask = create_mask_for_layer_norm(params) +# # override params with bfloat version +# params= cast_floating_to(params, jnp.bfloat16, mask) +# +# state = train_state.TrainState.create( # Create a `TrainState`. +# apply_fn=model.apply, +# params=params, +# tx=optimizer) +# return state + + + +# %% +# Create an abstract closure to wrap the function before feeding it in +# because `jax.eval_shape` only takes pytrees as arguments. 
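+# A tiny illustration of eval_shape on a made-up function (a sketch, not part
+# of the training setup): only shapes/dtypes are propagated, no FLOPs are spent.
+#
+# import jax, jax.numpy as jnp
+# def toy_fn(w, x):
+#     return {"out": x @ w}
+# out_shapes = jax.eval_shape(
+#     toy_fn,
+#     jax.ShapeDtypeStruct((4, 8), jnp.float32),
+#     jax.ShapeDtypeStruct((2, 4), jnp.float32),
+# )
+# # out_shapes == {"out": ShapeDtypeStruct(shape=(2, 8), dtype=float32)}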
+# eval_shape(fn, rng_key, x) +# used to perform shape inference +# returns a nested PyTree containing jax.ShapeDtypeStruct objects as leaves +# rng, init_rng = jax.random.split(rng) +abstract_variables = jax.eval_shape( + functools.partial(init_fn, model=model, optimizer=adamw), params) + +# rng, init_rng = jax.random.split(rng) +# abstract_variables = jax.eval_shape( +# functools.partial(init_fn, model=model, optimizer=adamw), init_rng, batch) + + +# %% +# This `state_sharding` has the same pytree structure as `state`, the output +# of the `init_fn`. +# flan.linen.get_sharding +# extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh +# jax.sharding: describes how a jax.Array is laid out across devices +state_sharding = nn.get_sharding(abstract_variables, mesh) +# print(state_sharding) + +# warning: do not have singleton None in your nn.partition definitions, it will screw with your sanity + +################################################## +# # %% +# # replace the params tree with the new modified tree +# # create partitions for model +# from parallel.partitions import set_partitions +# # set_partitions freezes the params on return +# model_part_spec = set_partitions(unfreeze(params)) +# # p is already a partition spec +# model_named_sharding = jax.tree.map(lambda p: mesh_sharding(p), model_part_spec) +# +# # %% +# # get_shapes = jax.tree.map(jnp.shape, params) +# # actually tuple +# # state_shapes = jax.eval_shape(state_sharding, get_shapes) +# +# # %% +# # get pspec for opt_state +# def get_opt_spec(x): +# if isinstance(x, dict): +# return unfreeze(model_named_sharding) +# # return an empty partspec +# return mesh_sharding((PartitionSpec())) +# +# # this function replaces the empty model params spec with the 'model_named_shard' +# state_sharding = jax.tree.map( +# get_opt_spec, state_sharding, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)) +# ) + + + +# %% +jit_init_fn = jax.jit( + init_fn, + static_argnames=('model', 'optimizer'), # skip model and optimizer + in_shardings=mesh_sharding(PartitionSpec(())), # we don't shard params explicitly + out_shardings=state_sharding # but returned initialized_state is sharded +) +initialized_state = jit_init_fn(params, model, adamw) + + +# jit_init_fn = jax.jit( +# init_fn, +# static_argnames=('model', 'optimizer'), # skip model and optimizer +# in_shardings=(mesh_sharding(()), x_sharding), # for PRNG key and data +# out_shardings=state_sharding +# ) +# +# +# rng, init_rng = jax.random.split(rng) +# initialized_state = jit_init_fn(rng, batch, model, adamw) + + +# %% +# train step +def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): + """ + The label smoothing implementation is adapted from Flax's official example: + https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104 + """ + vocab_size = logits.shape[-1] + confidence = 1.0 - label_smoothing_factor + low_confidence = (1.0 - confidence) / (vocab_size - 1) + normalizing_constant = -( + confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) + ) + soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence) + + loss = optax.softmax_cross_entropy(logits, soft_labels) + loss = loss - normalizing_constant + + # ignore padded tokens from loss + loss = loss * padding_mask + loss = loss.sum() + num_labels = padding_mask.sum() + return loss, num_labels + +# %% + +# sharded_loss_fn = jax.jit( +# loss_fn, +# 
in_shardings=(mesh_sharding('model'), x_sharding), # params partitioned across 'model' axis +# out_shardings=(mesh_sharding('model')), # Loss should be aggregated across 'model' +# ) + +def gather_and_sum( + sharded_values, + in_shardings +): + with mesh: + # Gather sharded values into a single device + gathered_values = jax.jit( + lambda x: x, in_shardings=in_shardings, out_shardings=None + )(sharded_values) + + # Compute the sum of gathered values + summed_value = jax.tree.map(lambda x: jnp.sum(x), gathered_values) + return summed_value + + +# single device code annotated with jax.jit +@functools.partial( + jax.jit, + # state is state_sharding initialized from init_fn + # x_sharding is data sharded explicitly later + in_shardings=(state_sharding, x_sharding), + # return state as state_sharding + # we do not shard the metrics + out_shardings=(state_sharding, mesh_sharding(PartitionSpec())), + donate_argnames=('state'), +) +def train_step(state, batch): + + label_smoothing_factor=0.0 + # dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) + def compute_loss(params, batch): + # check constraints + # frozen dict not allowed as sharding object + # params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding)) + # batch = jax.lax.with_sharding_constraint(batch, x_sharding) + # labels = batch.pop("decoder_input_ids") + # no use of labels here + logits = state.apply_fn( + {'params': params}, + input_ids=batch['input_ids'], + attention_mask=batch['attention_mask'], + decoder_input_ids=batch['decoder_input_ids'], + decoder_attention_mask=batch['decoder_attention_mask'], + )[0] # zero because output is some structure, where first is the logit + # use labels here + loss, num_labels = loss_fn( + logits, + batch["labels"], + batch["decoder_attention_mask"], + label_smoothing_factor) + return loss, num_labels + + # compute gradients through computational graph + # allow values to pass through + grad_fn = jax.value_and_grad(compute_loss, has_aux=True) + (loss, num_labels), grad = grad_fn(state.params, batch) + # num_labels = jax.lax.psum(num_labels, "batch") + + + # true grad = total grad / total samples + # needs to be in a singleton tuple for some reason + # gathered_grad = gather_and_sum(grad, (unfreeze(model_named_sharding),)) + + # gathered_num_labels = gather_and_sum(num_labels, mesh_sharding(PartitionSpec())) + + # summed_gradients = jax.tree.map(lambda x: jnp.sum(x)/gathered_num_labels, gathered_grad) + # grad = jax.lax.psum(grad, "batch") + # grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad) + + new_state = state.apply_gradients(grads=grad) + with jax.named_scope("sync_metrics"): + step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + # step_metrics = jax.tree.map( + # # previously needed lax.psum + # # now just write single device code, let compiler handle + # lambda x: jnp.mean(x), step_metrics + # ) + + # if metrics is None: + # metrics = step_metrics + # else: + # # combine all the synced metrics + # metrics = jax.tree.map(jnp.mean, metrics, step_metrics) + + + return new_state, step_metrics + + + + +# %% +# prep 1 step +print("1 step for jit-ting") + + +with mesh: + state, metrics = train_step(initialized_state, batch) + + +# %% + +# %% +# tr +print("***** Running training *****") +print(f" Num examples = {training_size}") +print(f" Num Epochs = {num_epochs}") +print(f" Instantaneous batch size per device = {per_device_train_batch_size}") +print(f" Total train batch size (w. 
parallel & distributed) = {train_batch_size}") +print(f" Total optimization steps = {total_train_steps}") + + +# %% +# jax.profiler.start_trace("./traces") + +# function to shard a batch by treating it as a pytree +def shard_batch(batch): + # Shard each element in the dictionary (i.e., each key-value pair) + return jax.tree_util.tree_map( + lambda x: jax.device_put(x, x_sharding), + batch + ) + + +print("*" * 10) +print("training start") +rng, input_rng = jax.random.split(rng) +train_time = 0 +epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) +for epoch in epochs: + train_start = time.time() + + # Create sampling rng + train_metrics = [] + steps_per_epoch = training_size // train_batch_size + train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True) + # Generate an epoch by shuffling sampling indices from the train dataset + for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): + batch = next(train_loader) + # send to device + # batch = {key: jax.device_put(jnp.array(value, dtype=jnp.uint16), x_sharding) for key, value in batch.items()} + # batch['input_ids']=jax.device_put(jnp.array(batch['input_ids'], dtype=jnp.int32), x_sharding) + # batch['attention_mask']=jax.device_put(jnp.array(batch['attention_mask'], dtype=jnp.int32), x_sharding) + # batch['decoder_input_ids']=jax.device_put(jnp.array(batch['decoder_input_ids'], dtype=jnp.int32), x_sharding) + # batch['decoder_attention_mask']=jax.device_put(jnp.array(batch['decoder_attention_mask'], dtype=jnp.int32), x_sharding) + sharded_batch = shard_batch(batch) + with mesh: + state, train_metric = train_step(state, sharded_batch) + + # train_metrics.append(train_metric) + + + # this is for more accurate time stats, but slows down training + # train_metric['loss'].block_until_ready() + train_time = time.time() - train_start + + + + epochs.write( + f"Epoch... ({epoch + 1}/{num_epochs} | " + f"Loss: {train_metric['loss']}, " + f"Learning Rate:{train_metric['learning_rate']}, " + f"Last train time: {train_time})" + ) +# jax.profiler.stop_trace() +# %% +# with mesh: +# gathered_params = jax.jit( +# lambda x: x, +# in_shardings=(unfreeze(model_named_sharding),), +# out_shardings=mesh_sharding(PartitionSpec()) +# )(state.params) + +main_model = custom_model.from_pretrained('t5-base') +output_dir = save_path +# save checkpoint after each epoch and push checkpoint to the hub +if jax.process_index() == 0: + params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params) + main_model.save_pretrained(output_dir, params=params) + +# # stick to defaults +# options = ocp.CheckpointManagerOptions() +# with ocp.CheckpointManager( +# ocp.test_utils.erase_and_create_empty(save_path), +# options=options, +# ) as mngr: +# +# mngr.save(0, args=ocp.args.StandardSave(state)) +# mngr.wait_until_finished() + + # After providing `args` during an initial `save` or `restore` call, the + # `CheckpointManager` instance records the type so that you do not need to + # specify it again. If the `CheckpointManager` instance is not provided with a + # `ocp.args.CheckpointArgs` instance for a particular item on a previous + # occasion it cannot be restored without specifying the argument at restore + # time. + + # # In many cases, you can restore exactly as saved without specifying additional + # # arguments. 
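+    # A compact save/restore round trip on a plain pytree (a hedged sketch;
+    # the /tmp path and toy_state below are illustrative, not this script's
+    # TrainState):
+    #
+    # import orbax.checkpoint as ocp
+    # import jax.numpy as jnp
+    # toy_state = {"step": jnp.array(0), "w": jnp.ones((2, 2))}
+    # path = ocp.test_utils.erase_and_create_empty("/tmp/toy_ckpt")
+    # with ocp.CheckpointManager(path) as mngr:
+    #     mngr.save(0, args=ocp.args.StandardSave(toy_state))
+    #     mngr.wait_until_finished()
+    #     restored = mngr.restore(0)   # same structure and values as toy_state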
+ # mngr.restore(0) + # # If customization of properties like sharding or dtype is desired, just provide + # # the abstract target PyTree, the properties of which will be used to set + # # the properties of the restored arrays. + # mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree)) + +# %% diff --git a/t5_jax_prediction.py b/t5_jax_prediction.py index 1183864..89b1193 100644 --- a/t5_jax_prediction.py +++ b/t5_jax_prediction.py @@ -13,9 +13,12 @@ # # prediction code # ## import and process test data - # %% # import libraries +import os + +os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" + import pandas as pd import matplotlib.pyplot as plt @@ -35,7 +38,7 @@ jax.config.update("jax_default_matmul_precision", "bfloat16") jax.config.update("jax_enable_x64", False) -from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig +# from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig import datasets @@ -51,9 +54,13 @@ from flax.jax_utils import pad_shard_unpad, unreplicate from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key +from ml_collections import ConfigDict import time +from parallel.dataload import DataPrepare +import orbax.checkpoint as ocp + # %% @@ -86,121 +93,22 @@ def process_df(df): # takes 1 minute to run without batching test_dataset = Dataset.from_list(process_df(df)) - -# %% [markdown] -# ## Load model for attributes - # %% -# load model -model_name_or_path = "./t5_80_1_bf16" # Replace with your specific model name - -# Load configuration -config = AutoConfig.from_pretrained(model_name_or_path) - -# Load model -model = FlaxAutoModelForSeq2SeqLM.from_pretrained( - pretrained_model_name_or_path=model_name_or_path -) - - -# %% [markdown] -# ## Tokenizer - -# %% -# prepare tokenizer -from transformers import T5TokenizerFast -tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True) -# Define additional special tokens -additional_special_tokens = ["", "", "", "", "", "", "SIG", "UNIT", "DATA_TYPE"] -# Add the additional special tokens to the tokenizer -tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens}) - -max_length = 86 - -model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"]) -shift_tokens_right_fn = getattr(model_module, "shift_tokens_right") - -# given a dataset entry, run it through the tokenizer -# Setting padding="max_length" as we need fixed length inputs for jitted functions -def preprocess_function(example): - inputs = example['input'] - targets = example['output'] - # text_target sets the corresponding label to inputs - # there is no need to create a separate 'labels' - model_inputs = tokenizer( - inputs, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="np" - ) - labels = tokenizer( - text_target=targets, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="np" - ) - - model_inputs["labels"] = labels["input_ids"] - decoder_input_ids = shift_tokens_right_fn( - labels["input_ids"], config.pad_token_id, config.decoder_start_token_id - ) - model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids) - - # We need decoder_attention_mask so we can ignore pad tokens from loss - model_inputs["decoder_attention_mask"] = labels["attention_mask"] - - return model_inputs - -# map maps function to each "row" in the dataset -# aka the data in the immediate nesting -test_dataset = test_dataset.map( - 
preprocess_function, - batched=True, - num_proc=1, - remove_columns=test_dataset.column_names, -) - -def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True): - """ - Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete, - and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`. - """ - if shuffle: - batch_idx = jax.random.permutation(rng, len(dataset)) - batch_idx = np.asarray(batch_idx) - else: - batch_idx = np.arange(len(dataset)) - - if drop_last: - steps_per_epoch = len(dataset) // batch_size - batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch. - batch_idx = batch_idx.reshape((steps_per_epoch, batch_size)) - else: - steps_per_epoch = math.ceil(len(dataset) / batch_size) - batch_idx = np.array_split(batch_idx, steps_per_epoch) - - for idx in batch_idx: - batch = dataset[idx] - batch = {k: np.array(v) for k, v in batch.items()} - - yield batch - -# %% [markdown] -# # model generation +# from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration +from transformers import FlaxT5ForConditionalGeneration +# model_name_or_path = "./t5_80_1" # Replace with your specific model name +model_name_or_path = "./model_checkpoints/simple_test" # Replace with your specific model name +model = FlaxT5ForConditionalGeneration.from_pretrained(model_name_or_path) +params = model.params # %% seed = 117 -num_epochs = 80 -batch_size = 96 -num_train_epochs = num_epochs +batch_size = 128 per_device_train_batch_size = batch_size train_batch_size = per_device_train_batch_size * jax.device_count() per_device_eval_batch_size = batch_size eval_batch_size = per_device_eval_batch_size * jax.device_count() steps_per_epoch = len(test_dataset) // train_batch_size -total_train_steps = steps_per_epoch * num_epochs num_beams = 1 val_max_target_length = 128 @@ -208,33 +116,31 @@ val_max_target_length = 128 predict_with_generate = True -# Initialize our training +# Initialize our prediction rng = jax.random.PRNGKey(seed) rng, dropout_rng = jax.random.split(rng) - -# %% - -# reload model to prevent leakage of variables -# load model -model_name_or_path = "t5_80_1" # Replace with your specific model name - -# Load configuration -config = AutoConfig.from_pretrained(model_name_or_path) - -# Load model -model = FlaxAutoModelForSeq2SeqLM.from_pretrained( - model_name_or_path +print("preparing data") +data_config = ConfigDict( + dict( + max_length=128, + pad_token_id=0, + decoder_start_token_id=0 + ) ) +dataprep = DataPrepare(test_dataset, data_config) +# # example usage +# # %% +seed = 117 +rng = jax.random.PRNGKey(seed) + +# %% # Ensure model.params is properly initialized (this is just an example) # Normally you would get this from a model initialization call with dummy input -params = model.params -# ensure full size floats -params_f16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params) # we need to replicate model over devices -replicated_params = jax.device_put_replicated(params_f16, jax.devices()) +replicated_params = jax.device_put_replicated(params, jax.devices()) # Define generation function @@ -245,25 +151,29 @@ num_beams = num_beams if num_beams is not None else model.config.num_beams gen_kwargs = {"max_length": max_length, "num_beams": num_beams} def generate_step(params, batch): - output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], params=params, **gen_kwargs) + output_ids = 
model.generate( + batch["input_ids"], + attention_mask=batch["attention_mask"], + params=params, + **gen_kwargs) return output_ids.sequences # Create parallel version of the train and eval step p_generate_step = jax.pmap(generate_step, "batch") - - pred_generations = [] pred_labels = [] +decoder_input_list = [] rng, input_rng = jax.random.split(rng) -pred_loader = data_loader(input_rng, test_dataset, eval_batch_size, drop_last=False) + +pred_loader = dataprep.data_loader(input_rng, batch_size=batch_size, shuffle=False, drop_last=False) pred_steps = math.ceil(len(test_dataset) / eval_batch_size) print("***** Running training *****") print(f" Num examples = {len(test_dataset)}") -print(f" Num steps = {num_epochs}") +# print(f" Num steps = {num_epochs}") print(f" Instantaneous batch size per device = {per_device_train_batch_size}") print(f" Total test batch size (w. parallel & distributed) = {train_batch_size}") @@ -272,26 +182,38 @@ for _ in tqdm(range(pred_steps), desc="Predicting..."): # Model forward batch = next(pred_loader) labels = batch["labels"] + decoder_input = batch["decoder_input_ids"] + # generation generated_ids = pad_shard_unpad(p_generate_step)(replicated_params, batch) pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"]))) pred_labels.extend(labels) + decoder_input_list.extend(decoder_input) +# %% + # %% [markdown] # # process predictions +from transformers import T5TokenizerFast +tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=False) +# Define additional special tokens +additional_special_tokens = ["", "", "", "", "", "", "", "", ""] +# Add the additional special tokens to the tokenizer +tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens}) + # %% # code to get special token ids -# sentence = "" -# tokens = tokenizer.tokenize(sentence) -# print("Tokens:", tokens) -# # Get the IDs (integer indices) of specific tokens -# token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens] -# print("Token IDs:", token_ids) +sentence = "" +tokens = tokenizer.tokenize(sentence) +print("Tokens:", tokens) +# Get the IDs (integer indices) of specific tokens +token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens] +print("Token IDs:", token_ids) # %% @@ -304,7 +226,13 @@ def extract_seq(tokens, start_value, end_value): return tokens[start_id+1:end_id] +# %% +i = 2 +print(pred_generations[i]) +print(extract_seq(pred_generations[i], 32100, 32101)) + +# %% def process_tensor_output(tokens): thing_seq = extract_seq(tokens, 32100, 32101) # 32100 = , 32101 = property_seq = extract_seq(tokens, 32102, 32103) # 32102 = , 32103 = diff --git a/t5_jax_simple_parallel.py b/t5_jax_simple_parallel.py new file mode 100644 index 0000000..47b3d9b --- /dev/null +++ b/t5_jax_simple_parallel.py @@ -0,0 +1,543 @@ +# %% +import os +# Set this to True to run the model on CPU only. 
+USE_CPU_ONLY = False + +flags = os.environ.get("XLA_FLAGS", "") +if USE_CPU_ONLY: + flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices + # Enforce CPU-only execution + os.environ["CUDA_VISIBLE_DEVICES"] = "" + os.environ["JAX_PLATFORMS"] = "cpu" +else: + # GPU flags + flags = ( + '--xla_gpu_enable_triton_softmax_fusion=true ' + '--xla_gpu_triton_gemm_any=True ' + # '--xla_gpu_enable_async_collectives=true ' + '--xla_gpu_enable_latency_hiding_scheduler=true ' + '--xla_gpu_enable_highest_priority_async_stream=true ' + ) + os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" + +os.environ["XLA_FLAGS"] = flags +os.environ.update({ + "TOKENIZERS_PARALLELISM" : "false", + "CUDA_DEVICE_MAX_CONNECTIONS" : "1", + "NCCL_LL128_BUFFSIZE": "-2", + "NCCL_LL_BUFFSIZE": "-2", + "NCCL_PROTO": "SIMPLE,LL,LL128", + "XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.90", + # "XLA_PYTHON_CLIENT_PREALLOCATE" : "false" + }) + + + + +import functools +from functools import partial +from pprint import pprint +from typing import Any, Dict, Tuple, Callable, Sequence, Dict, Union + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, NamedSharding +# from jax.experimental.pjit import pjit # superseded by jax.jit +from jax.experimental import mesh_utils +from jax.sharding import PartitionSpec +from ml_collections import ConfigDict +import optax +import logging +import time +from datasets import Dataset, load_from_disk + +from flax import jax_utils, traverse_util +from flax.training import train_state +from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key +from flax.core.frozen_dict import freeze, unfreeze, FrozenDict +import flax.core + +# model checkpointing and saving utilities +from flax import linen as nn +from flax.training import checkpoints, train_state +from flax import struct, serialization + +from parallel.partitions import set_partitions + +from tqdm import tqdm + +from parallel.dataload import DataPrepare + +# for memory tracking +# from jax_smi import initialise_tracking +# initialise_tracking() + + +PyTree = Any +Metrics = Dict[str, Tuple[jax.Array, ...]] + +if USE_CPU_ONLY: + jax.config.update('jax_platform_name', 'cpu') +else: + jax.config.update("jax_default_matmul_precision", "bfloat16") + +jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") +jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) +jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) + + + + +# %% +## get platform type +from jax.extend.backend import get_backend +print(get_backend().platform) +print(jax.devices()) + +# %% +# config options +file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/' +save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/simple_test/' +# file_path = 'combined_data' +split_datasets = load_from_disk(file_path) +training_size = len(split_datasets['train']) +# Store some constant +seed = 117 +num_epochs = 5 +batch_size = 64 +num_train_epochs = num_epochs +per_device_train_batch_size = batch_size +train_batch_size = per_device_train_batch_size * jax.device_count() +per_device_eval_batch_size = batch_size +eval_batch_size = per_device_eval_batch_size * jax.device_count() +steps_per_epoch = training_size // train_batch_size +total_train_steps = steps_per_epoch * num_epochs + +warmup_steps = 0 +learning_rate = 5e-5 + +weight_decay = 0.01 +adam_beta1 = 0.9 +adam_beta2 = 0.999 
+adam_epsilon = 1e-8 +label_smoothing_factor = 0.0 +num_beams = 1 +val_max_target_length = 128 +predict_with_generate = True + + +# %% +# prepare data +# init object +# e.g. Config +print("preparing data") +data_config = ConfigDict( + dict( + max_length=128, + pad_token_id=0, + decoder_start_token_id=0 + ) +) + +dataprep = DataPrepare(split_datasets['train'], data_config) +# # example usage +# # %% +seed = 117 +rng = jax.random.PRNGKey(seed) +train_loader = dataprep.data_loader(rng, batch_size=batch_size) +batch = next(iter(train_loader)) + +# %% +# from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom +from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model +main_model = custom_model.from_pretrained( + "t5-base", + dtype=jnp.bfloat16, + gradient_checkpointing=True, +) +params = main_model.params +# pretrained_params = model.params +model = main_model.module + +# %% +# # testing config hashability +# # some explanation: +# # The PreTrainedModel class loads a T5Config model that is not hashable because +# # it is a complicated class that pretends to be a dataclass. +# # The solution is to extract a dict from it, then make a ConfigDict from +# # ml_collections library so that we can get values via the "." operator. +# # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to +# # modify the config before passing to the next layer +# from transformers import T5Config +# from t5_model.configuration_t5 import FrozenT5Config +# from ml_collections import ConfigDict, FrozenConfigDict +# +# config = T5Config.from_pretrained("t5-base").to_dict() +# config.pop('architectures') +# config.pop('id2label') +# # test if it works +# frozen_config = FrozenConfigDict(config) +# # test hash +# hash(frozen_config) + +# %% +# # print model +# rng, input_rng = jax.random.split(rng) +# model.tabulate( +# input_rng, +# input_ids=batch['input_ids'], +# attention_mask=batch['attention_mask'], +# decoder_input_ids=batch['decoder_input_ids'], +# decoder_attention_mask=batch['decoder_attention_mask'], +# console_kwargs={"force_jupyter": True} +# ) + + + +# %% +# create mesh +print("creating mesh") +device_mesh = mesh_utils.create_device_mesh((2,2)) +print(device_mesh) + +mesh = Mesh(devices=device_mesh, axis_names=('data', 'model')) +print(mesh) + +def mesh_sharding(pspec: PartitionSpec) -> NamedSharding: + return NamedSharding(mesh, pspec, memory_kind="device") + +x_sharding = mesh_sharding(PartitionSpec('data', None)) # replicate across data axis +model_sharding=mesh_sharding(PartitionSpec(None, 'model')) + + +# %% +# optimizers + +def create_learning_rate_fn( + train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.ndarray]: + """Returns a linear warmup, linear_decay learning rate function.""" + steps_per_epoch = train_ds_size // train_batch_size + num_train_steps = steps_per_epoch * num_train_epochs + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +# Create learning rate schedule +linear_decay_lr_schedule_fn = create_learning_rate_fn( + training_size, + train_batch_size, + num_train_epochs, + warmup_steps, + learning_rate, +) + +# We use Optax's 
"masking" functionality to not apply weight decay +# to bias and LayerNorm scale parameters. decay_mask_fn returns a +# mask boolean with the same structure as the parameters. +# The mask is True for parameters that should be decayed. +def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + # find out all LayerNorm parameters + layer_norm_candidates = ["layernorm", "layer_norm", "ln"] + layer_norm_named_params = { + layer[-2:] + for layer_norm_name in layer_norm_candidates + for layer in flat_params.keys() + if layer_norm_name in "".join(layer).lower() + } + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) + +# create adam optimizer +adamw = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=adam_beta1, + b2=adam_beta2, + eps=adam_epsilon, + weight_decay=weight_decay, + mask=decay_mask_fn, +) + +# %% + +print("compile") + +# enable bf16 +# enable only for dense, some transformer sections, and shared +def create_mask_for_layer_norm(params): + flat_params = traverse_util.flatten_dict(params) + mask = { + # path: not ( + # (path[-2] == "layer_norm" and path[-1] == "weight") or + # (path[-2] == "final_layer_norm" and path[-1] == "weight") or + # (path[-2] == "o" and path[-1] == "kernel") + # ) + # for path in flat_params + path: ( + (path[-2] == "wi" and path[-1] == "weight") or + (path[-2] == "wo" and path[-1] == "weight") or + (path[-2] == "k" and path[-1] == "kernel") or + (path[-2] == "q" and path[-1] == "kernel") or + (path[-2] == "v" and path[-1] == "kernel") or + (path[-2] == "shared" and path[-1] == "embedding") + ) for path in flat_params + } + mask = traverse_util.unflatten_dict(mask) + return mask + +# borrowed from transformers modeling_flax_utils +def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: + """ + Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. + """ + + # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 + def conditional_cast(param): + if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): + param = param.astype(dtype) + return param + + if mask is None: + return jax.tree_util.tree_map(conditional_cast, params) + + flat_params = traverse_util.flatten_dict(params) + flat_mask, _ = jax.tree_util.tree_flatten(mask) + + for masked, key in zip(flat_mask, sorted(flat_params.keys())): + if masked: + flat_params[key] = conditional_cast(flat_params[key]) + + return traverse_util.unflatten_dict(flat_params) + +# create init_fn to produce sharded state +def init_fn(params, model, optimizer): + # do be careful with the model init + # imported models might have complicated init methods + + # mask = create_mask_for_layer_norm(params) + # override params with bfloat version + # params= cast_floating_to(params, jnp.bfloat16, mask) + + state = train_state.TrainState.create( # Create a `TrainState`. 
+ apply_fn=model.apply, + params=params, + tx=optimizer) + return state + + +abstract_variables = jax.eval_shape( + functools.partial(init_fn, model=model, optimizer=adamw), params) + +# jax.sharding: describes how a jax.Array is laid out across devices +state_sharding = nn.get_sharding(abstract_variables, mesh) +# print(state_sharding) + +# %% + +# replace the params tree with the new modified tree +# create partitions for model +from parallel.partitions import set_partitions +# set_partitions freezes the params on return +model_part_spec = set_partitions(unfreeze(params)) +# p is already a partition spec +model_named_sharding = jax.tree.map(lambda p: mesh_sharding(p), model_part_spec) + +# get pspec for opt_state +def get_opt_spec(x): + if isinstance(x, dict): + return unfreeze(model_named_sharding) + # return an empty partspec + return mesh_sharding((PartitionSpec())) + +# this function replaces the empty model params spec with the 'model_named_shard' +state_sharding = jax.tree.map( + get_opt_spec, state_sharding, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)) +) + + + +jit_init_fn = jax.jit( + init_fn, + static_argnames=('model', 'optimizer'), # skip model and optimizer + in_shardings=mesh_sharding(PartitionSpec()), # we don't shard params explicitly + out_shardings=state_sharding # but returned initialized_state is sharded +) +initialized_state = jit_init_fn(params, model, adamw) + +# %% +# train step +def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): + """ + The label smoothing implementation is adapted from Flax's official example: + https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104 + """ + vocab_size = logits.shape[-1] + confidence = 1.0 - label_smoothing_factor + low_confidence = (1.0 - confidence) / (vocab_size - 1) + normalizing_constant = -( + confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) + ) + soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence) + logits = jnp.asarray(logits, dtype=jnp.float32) + logits = logits.astype(jnp.float32) + soft_labels = soft_labels.astype(jnp.float32) + loss = optax.softmax_cross_entropy(logits, soft_labels) + loss = loss - normalizing_constant + + # ignore padded tokens from loss + loss = loss * padding_mask + loss = loss.mean() + # num_labels = padding_mask.mean() + return loss # , num_labels + +# %% + +# single device code annotated with jax.jit +@functools.partial( + jax.jit, + # state is state_sharding initialized from init_fn + # x_sharding is data sharded explicitly later + in_shardings=(state_sharding, x_sharding), + out_shardings=(state_sharding, mesh_sharding(PartitionSpec())), + donate_argnames=('state'), +) +def train_step(state, batch): + + label_smoothing_factor=0.0 + # dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) + def compute_loss(params, batch): + # check constraints + # frozen dict not allowed as sharding object + params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding)) + batch = jax.lax.with_sharding_constraint(batch, x_sharding) + logits = state.apply_fn( + {'params': params}, + input_ids=batch['input_ids'], + attention_mask=batch['attention_mask'], + decoder_input_ids=batch['decoder_input_ids'], + decoder_attention_mask=batch['decoder_attention_mask'], + )[0] # zero because output is some structure, where first is the logit + # use labels here + # loss, num_labels = loss_fn( + loss = loss_fn( + logits, + 
batch["labels"], + batch["decoder_attention_mask"], + label_smoothing_factor) + return loss # , num_labels + + # compute gradients through computational graph + # allow values to pass through + grad_fn = jax.value_and_grad(compute_loss, has_aux=False) + (loss), grad = grad_fn(state.params, batch) + # num_labels = jax.lax.psum(num_labels, "batch") + + + new_state = state.apply_gradients(grads=grad) + with jax.named_scope("sync_metrics"): + step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + + return new_state, step_metrics + +# %% +# explore data sharding +sharded_batch = next(iter(train_loader)) +sharded_batch = jax.device_put(sharded_batch, x_sharding) +# jax.debug.visualize_array_sharding(sharded_batch['input_ids']) +# jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding']) + + +# %% +# # prep 1 step +# print("1 step for jit-ting") +# with mesh: +# state, metrics = train_step(initialized_state, sharded_batch) + +# %% + +# %% +# tr +print("***** Running training *****") +print(f" Num examples = {training_size}") +print(f" Num Epochs = {num_epochs}") +print(f" Instantaneous batch size per device = {per_device_train_batch_size}") +print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}") +print(f" Total optimization steps = {total_train_steps}") + + +# %% +# jax.profiler.start_trace("./traces") + + +print("*" * 10) +print("training start") +rng, input_rng = jax.random.split(rng) +train_time = 0 +state = initialized_state +epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) +for epoch in epochs: + train_start = time.time() + + # Create sampling rng + train_metrics = [] + steps_per_epoch = training_size // train_batch_size + train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True) + # Generate an epoch by shuffling sampling indices from the train dataset + for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): + batch = next(train_loader) + batch = jax.device_put(batch, x_sharding) + with mesh: + state, train_metric = train_step(state, batch) + + # train_metrics.append(train_metric) + + + # this is for more accurate time stats, but slows down training + # train_metric['loss'].block_until_ready() + train_time = time.time() - train_start + + + + epochs.write( + f"Epoch... 
({epoch + 1}/{num_epochs} | " + f"Loss: {train_metric['loss']}, " + f"Learning Rate:{train_metric['learning_rate']}, " + f"Last train time: {train_time})" + ) +# jax.profiler.stop_trace() + +# %% +# try out +gather_state = jax.device_get(state) +gather_batch = jax.device_get(batch) +logits = gather_state.apply_fn( + {'params': gather_state.params}, + input_ids=gather_batch['input_ids'], + attention_mask=gather_batch['attention_mask'], + decoder_input_ids=gather_batch['decoder_input_ids'], + decoder_attention_mask=gather_batch['decoder_attention_mask'], +)[0] # zero because output is some structure, where first is the logit + +probs = nn.softmax(logits, axis=-1) +predicted = jnp.argmax(probs, axis=-1) +predicted[1] + +# %% +main_model = custom_model.from_pretrained('t5-base') +output_dir = save_path + +# save checkpoint after each epoch and push checkpoint to the hub +if jax.process_index() == 0: + params = jax.device_get(state.params) + main_model.save_pretrained(output_dir, params=params) + + +# %% diff --git a/t5_prediction_old.py b/t5_prediction_old.py new file mode 100644 index 0000000..87edd6c --- /dev/null +++ b/t5_prediction_old.py @@ -0,0 +1,360 @@ + +# --- +# jupyter: +# jupytext: +# formats: ipynb,py:percent +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.16.4 +# --- + +# %% [markdown] +# # prediction code +# ## import and process test data + + +# %% +# import libraries +import pandas as pd +import matplotlib.pyplot as plt + +from datasets import Dataset, DatasetDict + +import jax +import jax.numpy as jnp +import optax +import numpy as np +from functools import partial +from typing import Callable, Optional +import math + +# jax.config.update("jax_default_matmul_precision", "tensorfloat32") +jax.config.update("jax_default_matmul_precision", "high") + +jax.config.update("jax_enable_x64", False) + + +from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig + + +import datasets +from datasets import Dataset +import evaluate +from tqdm import tqdm + + +import nltk # Here to have a nice missing dependency error message early on + +from flax import jax_utils, traverse_util +from flax.jax_utils import pad_shard_unpad, unreplicate +from flax.training import train_state +from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key + + +import time + + +# %% + +# data_path = f"../make_data/select_db/data_mapping_filtered.csv" +# data_path = f"../make_data_2/select_db/dataset/1/train_all.csv" +data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv' +# data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv' + +# Ensure to include 'ships_idx' in the fields list +fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit'] + +# Load the dataset +df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields) + +def process_df(df): + output_list = [{ + 'input': f"{row['tag_name']}{row['tag_description']}", + # 'input': f"{row['tag_description']}", + # 'input': f"{row['tag_name']}{row['tag_description']}{row['unit']}", + # 'input': f"{row['tag_description']}{row['unit']}", + 'output': f"{row['thing']}{row['property']}", + # 'answer': f"{row['thing']} {row['property']}", + # 'answer_thing': row['thing'], + # 'answer_property': row['property'], + } for _, row in df.iterrows()] + + return output_list + + +# takes 1 minute to run without batching +test_dataset = 
Dataset.from_list(process_df(df)) + + +# %% [markdown] +# ## Load model for attributes + +# %% +# load model +model_name_or_path = "./t5_80_1" # Replace with your specific model name + +# Load configuration +config = AutoConfig.from_pretrained(model_name_or_path) + +# Load model +model = FlaxAutoModelForSeq2SeqLM.from_pretrained( + pretrained_model_name_or_path=model_name_or_path +) + + +# %% [markdown] +# ## Tokenizer + +# %% +# prepare tokenizer +from transformers import T5TokenizerFast +tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True) +# Define additional special tokens +additional_special_tokens = ["", "", "", "", "", "", "SIG", "UNIT", "DATA_TYPE"] +# Add the additional special tokens to the tokenizer +tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens}) + +max_length = 86 + +model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"]) +shift_tokens_right_fn = getattr(model_module, "shift_tokens_right") + +# given a dataset entry, run it through the tokenizer +# Setting padding="max_length" as we need fixed length inputs for jitted functions +def preprocess_function(example): + inputs = example['input'] + targets = example['output'] + # text_target sets the corresponding label to inputs + # there is no need to create a separate 'labels' + model_inputs = tokenizer( + inputs, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="np" + ) + labels = tokenizer( + text_target=targets, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="np" + ) + + model_inputs["labels"] = labels["input_ids"] + decoder_input_ids = shift_tokens_right_fn( + labels["input_ids"], config.pad_token_id, config.decoder_start_token_id + ) + model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids) + + # We need decoder_attention_mask so we can ignore pad tokens from loss + model_inputs["decoder_attention_mask"] = labels["attention_mask"] + + return model_inputs + +# map maps function to each "row" in the dataset +# aka the data in the immediate nesting +test_dataset = test_dataset.map( + preprocess_function, + batched=True, + num_proc=1, + remove_columns=test_dataset.column_names, +) + +def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True): + """ + Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete, + and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`. + """ + if shuffle: + batch_idx = jax.random.permutation(rng, len(dataset)) + batch_idx = np.asarray(batch_idx) + else: + batch_idx = np.arange(len(dataset)) + + if drop_last: + steps_per_epoch = len(dataset) // batch_size + batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch. 
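+        # e.g. 10 samples with batch_size 4: keep indices [:8], then reshape to
+        # (2, 4) so each row of batch_idx indexes one full batch of the dataset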
+ batch_idx = batch_idx.reshape((steps_per_epoch, batch_size)) + else: + steps_per_epoch = math.ceil(len(dataset) / batch_size) + batch_idx = np.array_split(batch_idx, steps_per_epoch) + + for idx in batch_idx: + batch = dataset[idx] + batch = {k: np.array(v) for k, v in batch.items()} + + yield batch + +# %% [markdown] +# # model generation + +# %% +seed = 117 +num_epochs = 80 +batch_size = 96 +num_train_epochs = num_epochs +per_device_train_batch_size = batch_size +train_batch_size = per_device_train_batch_size * jax.device_count() +per_device_eval_batch_size = batch_size +eval_batch_size = per_device_eval_batch_size * jax.device_count() +steps_per_epoch = len(test_dataset) // train_batch_size +total_train_steps = steps_per_epoch * num_epochs + +num_beams = 1 +val_max_target_length = 128 + +predict_with_generate = True + + +# Initialize our training +rng = jax.random.PRNGKey(seed) +rng, dropout_rng = jax.random.split(rng) + + +# %% + +# reload model to prevent leakage of variables +# load model +model_name_or_path = "t5_80_1_bf16" # Replace with your specific model name + +# Load configuration +config = AutoConfig.from_pretrained(model_name_or_path) + +# Load model +model = FlaxAutoModelForSeq2SeqLM.from_pretrained( + model_name_or_path +) + + +# Ensure model.params is properly initialized (this is just an example) +# Normally you would get this from a model initialization call with dummy input +params = model.params +# ensure full size floats +params_f16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params) +# we need to replicate model over devices +replicated_params = jax.device_put_replicated(params_f16, jax.devices()) + + +# Define generation function +max_length = ( + val_max_target_length if val_max_target_length is not None else model.config.max_length +) +num_beams = num_beams if num_beams is not None else model.config.num_beams +gen_kwargs = {"max_length": max_length, "num_beams": num_beams} + +def generate_step(params, batch): + output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], params=params, **gen_kwargs) + return output_ids.sequences + +# Create parallel version of the train and eval step +p_generate_step = jax.pmap(generate_step, "batch") + + + +pred_generations = [] +pred_labels = [] + +rng, input_rng = jax.random.split(rng) + +pred_loader = data_loader(input_rng, test_dataset, eval_batch_size, drop_last=False) +pred_steps = math.ceil(len(test_dataset) / eval_batch_size) + +print("***** Running training *****") +print(f" Num examples = {len(test_dataset)}") +print(f" Num steps = {num_epochs}") +print(f" Instantaneous batch size per device = {per_device_train_batch_size}") +print(f" Total test batch size (w. 
parallel & distributed) = {train_batch_size}") + + +for _ in tqdm(range(pred_steps), desc="Predicting..."): + # Model forward + batch = next(pred_loader) + labels = batch["labels"] + + # generation + generated_ids = pad_shard_unpad(p_generate_step)(replicated_params, batch) + pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"]))) + pred_labels.extend(labels) + + + +# %% [markdown] +# # process predictions + + +# %% +# code to get special token ids +# sentence = "" +# tokens = tokenizer.tokenize(sentence) +# print("Tokens:", tokens) +# # Get the IDs (integer indices) of specific tokens +# token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens] +# print("Token IDs:", token_ids) + + +# %% +# extract sequence and decode +def extract_seq(tokens, start_value, end_value): + if start_value not in tokens or end_value not in tokens: + return None # Or handle this case according to your requirements + start_id = np.where(tokens == start_value)[0][0] + end_id = np.where(tokens == end_value)[0][0] + + return tokens[start_id+1:end_id] + + +def process_tensor_output(tokens): + thing_seq = extract_seq(tokens, 32100, 32101) # 32100 = , 32101 = + property_seq = extract_seq(tokens, 32102, 32103) # 32102 = , 32103 = + p_thing = None + p_property = None + if (thing_seq is not None): + p_thing = tokenizer.decode(thing_seq, skip_special_tokens=False) # retain + if (property_seq is not None): + p_property = tokenizer.decode(property_seq, skip_special_tokens=False) # retain + return p_thing, p_property + + +# %% +# decode prediction labels +def decode_preds(tokens_list): + thing_prediction_list = [] + property_prediction_list = [] + for tokens in tokens_list: + p_thing, p_property = process_tensor_output(tokens) + thing_prediction_list.append(p_thing) + property_prediction_list.append(p_property) + return thing_prediction_list, property_prediction_list + +thing_prediction_list, property_prediction_list = decode_preds(pred_generations) + +# %% +# add labels too +thing_actual_list, property_actual_list = decode_preds(pred_labels) + +# Convert the list to a Pandas DataFrame +df = pd.DataFrame({'p_thing': thing_prediction_list, + 'p_property': property_prediction_list, + 'thing': thing_actual_list, + 'property' : property_actual_list}) + +df['p_thing_correct'] = df['p_thing'] == df['thing'] +df['p_property_correct'] = df['p_property'] == df['property'] + +# %% +print("thing accuracy", sum(df['p_thing_correct'])/len(df)) +print("property accuracy", sum(df['p_property_correct'])/len(df)) +print("total accuracy", sum(df['p_property_correct'] & df['p_thing_correct'])/len(df)) +# %% +df[~df["p_property_correct"]] + +# %% +df['p_thing'] +# %% +# Save the DataFrame as a Parquet file (using pyarrow or fastparquet) +# df.to_parquet("exports/output_file.parquet", engine="pyarrow") # or use engine="fastparquet" + +
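+# %%
+# A toy check of the decoding helpers above (a sketch; the payload ids 1820 and
+# 2077 are made up, while 32100-32103 are the special-token ids used in this file):
+#
+# import numpy as np
+# toy = np.array([0, 32100, 1820, 32101, 32102, 2077, 32103, 1])
+# extract_seq(toy, 32100, 32101)   # -> array([1820]), the "thing" span
+# extract_seq(toy, 32102, 32103)   # -> array([2077]), the "property" span
+# process_tensor_output(toy)       # decodes both spans with the tokenizer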