Feat: implement working prediction

parent f523560141
commit edd9c3551f
@@ -1,3 +1,4 @@
 *.ipynb
 t5_*/
 exports/
+modified_t5_model/
@@ -0,0 +1,27 @@
+from pathlib import Path
+
+# Define the folder to check
+folder = Path(".")
+
+# Get all .py and .ipynb files in the folder
+py_files = {file.stem: file for file in folder.glob("*.py")}
+ipynb_files = {file.stem: file for file in folder.glob("*.ipynb")}
+
+# Check for linked .py and .ipynb files
+all_newer = True
+
+for stem, py_file in py_files.items():
+    if stem in ipynb_files:
+        ipynb_file = ipynb_files[stem]
+
+        # Compare the modification times
+        if py_file.stat().st_mtime > ipynb_file.stat().st_mtime:
+            print(f"{py_file} is newer than {ipynb_file}.")
+        else:
+            print(f"{py_file} is not newer than {ipynb_file}.")
+            all_newer = False
+
+if all_newer:
+    print("All linked .py files are newer than their corresponding .ipynb files.")
+else:
+    print("Some .py files are not newer than their corresponding .ipynb files.")
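The new helper script above compares each paired .py/.ipynb file by modification time. A minimal sketch of how the same check could gate a CI job (the nonzero exit code is an assumption, not part of this commit):

import sys
from pathlib import Path

folder = Path(".")
py_files = {f.stem: f for f in folder.glob("*.py")}
ipynb_files = {f.stem: f for f in folder.glob("*.ipynb")}

# collect stems whose notebook is newer than (or as new as) its paired script
stale = [stem for stem, py_file in py_files.items()
         if stem in ipynb_files
         and py_file.stat().st_mtime <= ipynb_files[stem].stat().st_mtime]
sys.exit(1 if stale else 0)  # fail the job when any pair is stale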

t5_jax.py
@@ -16,67 +16,6 @@
 # %% [markdown]
 # # T5 implementation using jax
-
-# %% [markdown]
-# ## import
-
-# %% [raw]
-# import json
-# import logging
-# import math
-# import os
-# import sys
-# import time
-# from dataclasses import asdict, dataclass, field
-# from enum import Enum
-# from functools import partial
-# from pathlib import Path
-# from typing import Callable, Optional
-#
-# import datasets
-# import evaluate
-# import jax
-# import jax.numpy as jnp
-# import nltk  # Here to have a nice missing dependency error message early on
-# import numpy as np
-# import optax
-# from datasets import Dataset, load_dataset
-# from filelock import FileLock
-# from flax import jax_utils, traverse_util
-# from flax.jax_utils import pad_shard_unpad, unreplicate
-# from flax.training import train_state
-# from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
-# from tqdm import tqdm
-#
-# import transformers
-# from transformers import (
-#     CONFIG_MAPPING,
-#     FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
-#     AutoConfig,
-#     AutoTokenizer,
-#     FlaxAutoModelForSeq2SeqLM,
-#     HfArgumentParser,
-#     is_tensorboard_available,
-# )
-# from transformers.utils import is_offline_mode, send_example_telemetry
-#
-#
-# logger = logging.getLogger(__name__)
-#
-# try:
-#     nltk.data.find("tokenizers/punkt")
-# except (LookupError, OSError):
-#     if is_offline_mode():
-#         raise LookupError(
-#             "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
-#         )
-#     with FileLock(".lock") as lock:
-#         nltk.download("punkt", quiet=True)
-#
-#
-# MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
-# MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
-

 # %%
 import jax
 import jax.numpy as jnp
@@ -88,8 +27,11 @@ import math

 # jax.config.update("jax_default_matmul_precision", "tensorfloat32")
 jax.config.update("jax_default_matmul_precision", "high")

 jax.config.update("jax_enable_x64", False)
+# enable cache
+jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
+jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
+jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)


 from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
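The three added jax.config lines enable JAX's persistent compilation cache, so XLA programs compiled on one run can be reloaded from /tmp/jax_cache on the next instead of being recompiled. A minimal sketch of checking that the cache is being populated (persistent caching requires a supported backend, typically GPU or TPU):

import os
import time
import jax
import jax.numpy as jnp

jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

@jax.jit
def f(x):
    return (x @ x.T).sum()

t0 = time.perf_counter()
f(jnp.ones((512, 512))).block_until_ready()  # first call triggers compilation
print("first call took", time.perf_counter() - t0, "s")
print("cache entries:", os.listdir("/tmp/jax_cache"))  # non-empty if caching worked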
@@ -108,6 +50,7 @@ from flax import jax_utils, traverse_util
 from flax.jax_utils import pad_shard_unpad, unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
+import flax.core


 import time
@@ -116,14 +59,15 @@ import time

 # %%
 import os
 os.environ['XLA_FLAGS'] = (
-    '--xla_gpu_enable_triton_softmax_fusion=True '
-    '--xla_gpu_triton_gemm_any=True '
+    '--xla_gpu_triton_gemm_any=true --xla_gpu_enable_custom_fusions=true --xla_gpu_enable_address_computation_fusion=true'
 )

 os.environ.update({
+    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
     "NCCL_LL128_BUFFSIZE": "-2",
     "NCCL_LL_BUFFSIZE": "-2",
     "NCCL_PROTO": "SIMPLE,LL,LL128",
+    "XLA_PYTHON_CLIENT_MEM_FRACTION": ".95"
 })

 # %%
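One caveat with the rewritten XLA_FLAGS string: XLA reads this environment variable when the backend is first initialized, so it must be set before the first JAX computation runs. A sketch of the safe ordering (the flag value is just one of those from the diff):

import os
os.environ['XLA_FLAGS'] = '--xla_gpu_triton_gemm_any=true'  # set before backend init

import jax
print(jax.devices())  # backend initializes here, with the flags applied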
@@ -132,17 +76,10 @@ print(xla_bridge.get_backend().platform)


 # %%
-# nltk.download('punkt')
 try:
     nltk.data.find("tokenizers/punkt")
 except (LookupError, OSError):
-    if is_offline_mode():
-        raise LookupError(
-            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
-        )
-    with FileLock(".lock") as lock:
-        nltk.download("punkt", quiet=True)
+    print("error")



 # %% [markdown]
@@ -153,17 +90,8 @@ except (LookupError, OSError):
 model_name_or_path = "t5-small"  # Replace with your specific model name

 # Load configuration
-config = AutoConfig.from_pretrained(model_name_or_path)
-# Load model
-model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
-    model_name_or_path
-)
-
-
-# %%
-model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])
-shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
-
-
+config = AutoConfig.from_pretrained(model_name_or_path,
+                                    force_download=False)

 # %%
@@ -173,7 +101,11 @@ save_path = 't5_80_1'
 # file_path = 'combined_data'
 split_datasets = load_from_disk(file_path)

-# prepare tokenizer
+# %%
+split_datasets['train'][0]
+
+# %%
 from transformers import T5TokenizerFast
 tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
 # Define additional special tokens
@@ -183,6 +115,43 @@ tokenizer.add_special_tokens({"additional_special_tokens": additional_special_to

 max_length = 86

+# %%
+len(tokenizer)
+
+# %%
+# load pytorch model first
+# from transformers import AutoModelForSeq2SeqLM
+# model_checkpoint = "t5-base"
+# model_pt = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
+# # important! after extending tokens vocab
+# model_pt.resize_token_embeddings(len(tokenizer))
+# model_pt.save_pretrained('./modified_t5_model')
+# model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
+#     pretrained_model_name_or_path="modified_t5_model",
+#     dtype=jax.numpy.bfloat16,
+#     from_pt=True
+# )
+
+# %%
+model_path = './t5_80_1'
+# model_path = 't5=base'
+model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
+    pretrained_model_name_or_path=model_path,
+    dtype=jax.numpy.float32
+)
+
+# %%
+model.params_shape_tree['shared']
+
+# %%
+model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])
+shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
+
+
+# %%
 # In Flax, for seq2seq models we need to pass `decoder_input_ids`
 # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
 # for that dynamically import the `shift_tokens_right` function from the model file
@@ -191,21 +160,19 @@ max_length = 86
 # given a dataset entry, run it through the tokenizer
 # Setting padding="max_length" as we need fixed length inputs for jitted functions
 def preprocess_function(example):
-    input = example['input']
-    target = example['output']
+    inputs = example['input']
+    targets = example['output']
     # text_target sets the corresponding label to inputs
     # there is no need to create a separate 'labels'
     model_inputs = tokenizer(
-        input,
-        text_target=target,
+        inputs,
         max_length=max_length,
         padding="max_length",
         truncation=True,
         return_tensors="np"
     )
     labels = tokenizer(
-        input,
-        text_target=target,
+        text_target=targets,
         max_length=max_length,
         padding="max_length",
         truncation=True,
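The edit above splits tokenization into two calls: the inputs are tokenized without text_target, and the labels come from a call that passes only text_target. For reference, a small sketch of what text_target does in recent transformers releases (the example strings are made up):

from transformers import T5TokenizerFast

tok = T5TokenizerFast.from_pretrained("t5-base")
enc = tok("some input text", text_target="some output text",
          max_length=16, padding="max_length", truncation=True,
          return_tensors="np")
# text_target tokenizes the target side and attaches it as 'labels'
print(sorted(enc.keys()))  # ['attention_mask', 'input_ids', 'labels']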
@@ -233,15 +200,18 @@ tokenized_datasets = split_datasets.map(
 )

-
-# %%
-tokenized_datasets
+tokenized_datasets.set_format(type='numpy',
+                              columns=['input_ids', 'attention_mask',
+                                       'labels', 'decoder_input_ids',
+                                       'decoder_attention_mask'])

 # %%
 train_dataset = tokenized_datasets["train"]
 eval_dataset = tokenized_datasets["validation"]

+# %%
+train_dataset[0]
+
 # %%
 def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
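The body of data_loader falls outside this hunk. For readability, a sketch of the usual implementation from the Hugging Face Flax examples that this signature matches (not verbatim from this commit):

import math
import numpy as np
import jax

def data_loader(rng, dataset, batch_size, shuffle=False, drop_last=True):
    if shuffle:
        batch_idx = np.asarray(jax.random.permutation(rng, len(dataset)))
    else:
        batch_idx = np.arange(len(dataset))

    if drop_last:
        steps_per_epoch = len(dataset) // batch_size
        batch_idx = batch_idx[: steps_per_epoch * batch_size]
        batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    else:
        steps_per_epoch = math.ceil(len(dataset) / batch_size)
        batch_idx = np.array_split(batch_idx, steps_per_epoch)

    for idx in batch_idx:
        # fixed-size numpy batches, ready for shard()
        yield {k: np.array(v) for k, v in dataset[idx].items()}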
@@ -270,65 +240,14 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
         yield batch


-
-# %% [markdown]
-# Now we have model inputs in terms of the variable tokenized_datasets
-
-# %%
-# metric
-metric = evaluate.load("sacrebleu")
-
-def postprocess_text(preds, labels):
-    preds = [pred.strip() for pred in preds]
-    labels = [label.strip() for label in labels]
-
-    # rougeLSum expects newline after each sentence
-    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
-    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
-
-    return preds, labels
-
-# def compute_metrics(preds, labels):
-#     decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
-#     decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
-#
-#     # Some simple post-processing
-#     decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
-#
-#     result = metric.compute(predictions=decoded_preds, references=decoded_labels)
-#     result = {k: round(v * 100, 4) for k, v in result.items()}
-#     prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
-#     result["gen_len"] = np.mean(prediction_lens)
-#     return result
-
-def compute_metrics(preds, labels):
-    # In case the model returns more than the prediction logits
-    if isinstance(preds, tuple):
-        preds = preds[0]
-
-    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
-
-    # Replace -100s in the labels as we can't decode them
-    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
-    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
-
-    # Some simple post-processing
-    decoded_preds = [pred.strip() for pred in decoded_preds]
-    decoded_labels = [[label.strip()] for label in decoded_labels]
-
-    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
-    return {"bleu": result["score"]}
-
-

 # %% [markdown]
 # # Model

 # %%
 # Store some constant
 seed = 117
-num_epochs = 80
-batch_size = 96
+num_epochs = 40
+batch_size = 32
 num_train_epochs = num_epochs
 per_device_train_batch_size = batch_size
 train_batch_size = per_device_train_batch_size * jax.device_count()
@@ -338,16 +257,16 @@ steps_per_epoch = len(train_dataset) // train_batch_size
 total_train_steps = steps_per_epoch * num_epochs

 warmup_steps = 0
-learning_rate = 5e-5
+learning_rate = 2e-5

-weight_decay = 0.0
+weight_decay = 0.01
 adam_beta1 = 0.9
 adam_beta2 = 0.999
 adam_epsilon = 1e-8
 label_smoothing_factor = 0.0

 num_beams = 1
-val_max_target_length = None
+val_max_target_length = 128

 predict_with_generate = True
@@ -421,15 +340,14 @@ class TrainState(train_state.TrainState):
     def replicate(self):
         return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))

-# Ensure model.params is properly initialized (this is just an example)
-# Normally you would get this from a model initialization call with dummy input
+# set bf16 for model params
+# model.params = model.to_bf16(model.params)
 params = model.params
 # Cast parameters to bfloat16 if desired
-params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
+# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)


 # Setup train state
-state = TrainState.create(apply_fn=model.__call__, params=params_bf16, tx=adamw, dropout_rng=dropout_rng)
+state = TrainState.create(apply_fn=model.__call__, params=params, tx=adamw, dropout_rng=dropout_rng)

 # label smoothed cross entropy
 def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
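loss_fn's body also falls outside the hunk. As context, a sketch of label-smoothed cross entropy in the style of the HF Flax seq2seq examples (simplified: the HF version also subtracts a normalizing constant), returning the (summed loss, label count) pair that the removed eval_step below unpacks:

import optax
from flax.training.common_utils import onehot

def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing_factor
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    # spread label_smoothing_factor of the probability mass over the vocab
    soft_labels = onehot(labels, vocab_size,
                         on_value=confidence, off_value=low_confidence)
    loss = optax.softmax_cross_entropy(logits, soft_labels)
    loss = (loss * padding_mask).sum()  # ignore padded positions
    num_labels = padding_mask.sum()
    return loss, num_labels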
@@ -481,21 +399,6 @@ def train_step(state, batch, label_smoothing_factor=0.0):
     metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
     return new_state, metrics

-# Define eval fn
-def eval_step(params, batch, label_smoothing_factor=0.0):
-    labels = batch.pop("labels")
-    logits = model(**batch, params=params, train=False)[0]
-
-    loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
-    num_labels = jax.lax.psum(num_labels, "batch")
-
-    # true loss = total loss / total samples
-    loss = jax.lax.psum(loss, "batch")
-    loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
-
-    metrics = {"loss": loss}
-    return metrics
-
 # Define generation function
 max_length = (
     val_max_target_length if val_max_target_length is not None else model.config.max_length
@@ -512,7 +415,7 @@ def generate_step(params, batch):
 p_train_step = jax.pmap(
     partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,)
 )
-p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch")
+# p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch")
 p_generate_step = jax.pmap(generate_step, "batch")

 # Replicate the train state on each device
@@ -563,50 +466,6 @@ for epoch in epochs:
         f" {train_metric['learning_rate']})"
     )

-    # ======================== Evaluating ==============================
-    eval_metrics = []
-    eval_preds = []
-    eval_labels = []
-
-    eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
-    eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
-    for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
-        # Model forward
-        batch = next(eval_loader)
-        labels = batch["labels"]
-
-        metrics = pad_shard_unpad(p_eval_step, static_return=True)(
-            state.params, batch, min_device_batch=per_device_eval_batch_size
-        )
-        eval_metrics.append(metrics)
-
-        # generation
-        if predict_with_generate:
-            generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
-            eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
-            eval_labels.extend(labels)
-
-    # normalize eval metrics
-    eval_metrics = get_metrics(eval_metrics)
-    eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
-
-    # compute metrics
-    rouge_desc = ""
-    if predict_with_generate:
-        rouge_metrics = compute_metrics(eval_preds, eval_labels)
-        eval_metrics.update(rouge_metrics)
-        rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
-
-    # Print metrics and update progress bar
-    desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
-    epochs.write(desc)
-    epochs.desc = desc
-
-    # Save metrics
-    # if has_tensorboard and jax.process_index() == 0:
-    #     cur_step = epoch * (len(train_dataset) // train_batch_size)
-    #     write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
-
     output_dir = save_path
     # save checkpoint after each epoch and push checkpoint to the hub
     if jax.process_index() == 0:
@@ -614,7 +473,3 @@ for epoch in epochs:
         model.save_pretrained(output_dir, params=params)
         tokenizer.save_pretrained(output_dir)
-
-
-# %% [markdown]
-# #

@@ -39,10 +39,9 @@ from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig


 import datasets
-from datasets import Dataset, load_dataset
+from datasets import Dataset
 import evaluate
 from tqdm import tqdm
-from datasets import load_from_disk


 import nltk  # Here to have a nice missing dependency error message early on
@@ -76,9 +75,9 @@ def process_df(df):
         # 'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>",
         # 'input': f"<DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>",
         'output': f"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>",
-        'answer': f"{row['thing']} {row['property']}",
-        'answer_thing': row['thing'],
-        'answer_property': row['property'],
+        # 'answer': f"{row['thing']} {row['property']}",
+        # 'answer_thing': row['thing'],
+        # 'answer_property': row['property'],
     } for _, row in df.iterrows()]

     return output_list
@@ -93,14 +92,14 @@ test_dataset = Dataset.from_list(process_df(df))

 # %%
 # load model
-model_name_or_path = "t5_80_1"  # Replace with your specific model name
+model_name_or_path = "./t5_80_1"  # Replace with your specific model name

 # Load configuration
 config = AutoConfig.from_pretrained(model_name_or_path)

 # Load model
 model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
-    model_name_or_path
+    pretrained_model_name_or_path=model_name_or_path
 )
@@ -124,21 +123,19 @@ shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
 # given a dataset entry, run it through the tokenizer
 # Setting padding="max_length" as we need fixed length inputs for jitted functions
 def preprocess_function(example):
-    input = example['input']
-    target = example['output']
+    inputs = example['input']
+    targets = example['output']
     # text_target sets the corresponding label to inputs
     # there is no need to create a separate 'labels'
     model_inputs = tokenizer(
-        input,
-        text_target=target,
+        inputs,
         max_length=max_length,
         padding="max_length",
         truncation=True,
         return_tensors="np"
     )
     labels = tokenizer(
-        input,
-        text_target=target,
+        text_target=targets,
         max_length=max_length,
         padding="max_length",
         truncation=True,
@@ -191,7 +188,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
         yield batch

 # %% [markdown]
-# # Model Training
+# # model generation

 # %%
 seed = 117
@@ -205,17 +202,8 @@ eval_batch_size = per_device_eval_batch_size * jax.device_count()
 steps_per_epoch = len(test_dataset) // train_batch_size
 total_train_steps = steps_per_epoch * num_epochs

-warmup_steps = 0
-learning_rate = 5e-5
-
-weight_decay = 0.0
-adam_beta1 = 0.9
-adam_beta2 = 0.999
-adam_epsilon = 1e-8
-label_smoothing_factor = 0.0
-
 num_beams = 1
-val_max_target_length = None
+val_max_target_length = 128

 predict_with_generate = True
@@ -224,55 +212,6 @@ predict_with_generate = True
 rng = jax.random.PRNGKey(seed)
 rng, dropout_rng = jax.random.split(rng)

-def create_learning_rate_fn(
-    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.ndarray]:
-    """Returns a linear warmup, linear_decay learning rate function."""
-    steps_per_epoch = train_ds_size // train_batch_size
-    num_train_steps = steps_per_epoch * num_train_epochs
-    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
-    decay_fn = optax.linear_schedule(
-        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
-    )
-    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
-    return schedule_fn
-
-
-# Create learning rate schedule
-linear_decay_lr_schedule_fn = create_learning_rate_fn(
-    len(test_dataset),
-    train_batch_size,
-    num_train_epochs,
-    warmup_steps,
-    learning_rate,
-)
-
-# We use Optax's "masking" functionality to not apply weight decay
-# to bias and LayerNorm scale parameters. decay_mask_fn returns a
-# mask boolean with the same structure as the parameters.
-# The mask is True for parameters that should be decayed.
-def decay_mask_fn(params):
-    flat_params = traverse_util.flatten_dict(params)
-    # find out all LayerNorm parameters
-    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-    layer_norm_named_params = {
-        layer[-2:]
-        for layer_norm_name in layer_norm_candidates
-        for layer in flat_params.keys()
-        if layer_norm_name in "".join(layer).lower()
-    }
-    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
-    return traverse_util.unflatten_dict(flat_mask)
-
-# create adam optimizer
-adamw = optax.adamw(
-    learning_rate=linear_decay_lr_schedule_fn,
-    b1=adam_beta1,
-    b2=adam_beta2,
-    eps=adam_epsilon,
-    weight_decay=weight_decay,
-    mask=decay_mask_fn,
-)

 # %%
@@ -288,23 +227,14 @@ model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
     model_name_or_path
 )

-# Training functions
-class TrainState(train_state.TrainState):
-    dropout_rng: jnp.ndarray
-
-    def replicate(self):
-        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
-
 # Ensure model.params is properly initialized (this is just an example)
 # Normally you would get this from a model initialization call with dummy input
 params = model.params
-# Cast parameters to bfloat16 if desired
-params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
-
-# Setup train state
-state = TrainState.create(apply_fn=model.__call__, params=params_bf16, tx=adamw, dropout_rng=dropout_rng)
+# ensure full size floats
+params_f16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
+# we need to replicate model over devices
+replicated_params = jax.device_put_replicated(params_f16, jax.devices())



 # Define generation function
|
||||||
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
|
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
|
||||||
|
|
||||||
def generate_step(params, batch):
|
def generate_step(params, batch):
|
||||||
model.params = params
|
output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], params=params, **gen_kwargs)
|
||||||
output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
|
|
||||||
return output_ids.sequences
|
return output_ids.sequences
|
||||||
|
|
||||||
# Create parallel version of the train and eval step
|
# Create parallel version of the train and eval step
|
||||||
p_generate_step = jax.pmap(generate_step, "batch")
|
p_generate_step = jax.pmap(generate_step, "batch")
|
||||||
|
|
||||||
# Replicate the train state on each device
|
|
||||||
state = state.replicate()
|
|
||||||
|
|
||||||
|
|
||||||
pred_metrics = []
|
|
||||||
pred_generations = []
|
pred_generations = []
|
||||||
pred_labels = []
|
pred_labels = []
|
||||||
|
|
||||||
|
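The rewritten generate_step passes params into model.generate instead of mutating model.params inside the pmapped function; under jax.pmap each device then receives its own shard of replicated_params as an ordinary argument. A self-contained sketch of that calling pattern with a toy function:

import jax
import jax.numpy as jnp

n = jax.local_device_count()
params = {"w": jnp.ones((4,))}
replicated = jax.device_put_replicated(params, jax.local_devices())  # one copy per device

@jax.pmap
def step(p, x):
    return x * p["w"].sum()

x = jnp.ones((n, 4))               # leading axis must equal the device count
print(step(replicated, x).shape)   # (n, 4)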
@@ -342,45 +268,92 @@ print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
 print(f" Total test batch size (w. parallel & distributed) = {train_batch_size}")


-for _ in tqdm(range(pred_steps), desc="Predicting...", position=0, leave=False):
+for _ in tqdm(range(pred_steps), desc="Predicting..."):
     # Model forward
     batch = next(pred_loader)
     labels = batch["labels"]

     # generation
-    if predict_with_generate:
-        generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
-        pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
-        pred_labels.extend(labels)
+    generated_ids = pad_shard_unpad(p_generate_step)(replicated_params, batch)
+    pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
+    pred_labels.extend(labels)


-# Print metrics
-# desc = f"Predict Loss: {pred_metrics['loss']})"
-# print(desc)
+# %% [markdown]
+# # process predictions

 # %%
-# save predictions to parquet
+# code to get special token ids
+# sentence = "<THING_START><THING_END><PROPERTY_START><PROPERTY_END><NAME><DESC><DESC><UNIT>"
+# tokens = tokenizer.tokenize(sentence)
+# print("Tokens:", tokens)
+# # Get the IDs (integer indices) of specific tokens
+# token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens]
+# print("Token IDs:", token_ids)
+
+
+# %%
+# extract sequence and decode
+def extract_seq(tokens, start_value, end_value):
+    if start_value not in tokens or end_value not in tokens:
+        return None  # Or handle this case according to your requirements
+    start_id = np.where(tokens == start_value)[0][0]
+    end_id = np.where(tokens == end_value)[0][0]
+
+    return tokens[start_id+1:end_id]
+
+
+def process_tensor_output(tokens):
+    thing_seq = extract_seq(tokens, 32100, 32101)  # 32100 = <THING_START>, 32101 = <THING_END>
+    property_seq = extract_seq(tokens, 32102, 32103)  # 32102 = <PROPERTY_START>, 32103 = <PROPERTY_END>
+    p_thing = None
+    p_property = None
+    if (thing_seq is not None):
+        p_thing = tokenizer.decode(thing_seq, skip_special_tokens=False)  # retain <COLLIDE>
+    if (property_seq is not None):
+        p_property = tokenizer.decode(property_seq, skip_special_tokens=False)  # retain <COLLIDE>
+    return p_thing, p_property
+
+
+# %%
 # decode prediction labels
-def decode_preds(preds):
-    # In case the model returns more than the prediction logits
-    if isinstance(preds, tuple):
-        preds = preds[0]
-
-    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
-
-    decoded_preds = [pred for pred in decoded_preds]
-
-    return decoded_preds
+def decode_preds(tokens_list):
+    thing_prediction_list = []
+    property_prediction_list = []
+    for tokens in tokens_list:
+        p_thing, p_property = process_tensor_output(tokens)
+        thing_prediction_list.append(p_thing)
+        property_prediction_list.append(p_property)
+    return thing_prediction_list, property_prediction_list
+
+thing_prediction_list, property_prediction_list = decode_preds(pred_generations)
+
+# %%
+# add labels too
+thing_actual_list, property_actual_list = decode_preds(pred_labels)

 # Convert the list to a Pandas DataFrame
-df = pd.DataFrame(decode_preds(pred_labels))
-
-# Save the DataFrame as a Parquet file (using pyarrow or fastparquet)
-df.to_parquet("exports/output_file.parquet", engine="pyarrow")  # or use engine="fastparquet"
+df = pd.DataFrame({'p_thing': thing_prediction_list,
+                   'p_property': property_prediction_list,
+                   'thing': thing_actual_list,
+                   'property': property_actual_list})
+
+df['p_thing_correct'] = df['p_thing'] == df['thing']
+df['p_property_correct'] = df['p_property'] == df['property']

 # %%
+print("thing accuracy", sum(df['p_thing_correct'])/len(df))
+print("property accuracy", sum(df['p_property_correct'])/len(df))
+print("total accuracy", sum(df['p_property_correct'] & df['p_thing_correct'])/len(df))
+
+# %%
+df[~df["p_property_correct"]]
+
+# %%
+df['p_thing']
+
+# %%
+# Save the DataFrame as a Parquet file (using pyarrow or fastparquet)
+# df.to_parquet("exports/output_file.parquet", engine="pyarrow")  # or use engine="fastparquet"
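A usage sketch for the new extract_seq/process_tensor_output helpers, with a synthetic token sequence (the ids mirror the comments in the diff; the real ids depend on the order the special tokens were added to the tokenizer):

import numpy as np

tokens = np.array([32100, 101, 102, 32101, 32102, 201, 32103, 0])
print(extract_seq(tokens, 32100, 32101))  # [101 102]  -> the "thing" span
print(extract_seq(tokens, 32102, 32103))  # [201]      -> the "property" span
print(extract_seq(tokens, 32104, 32105))  # None when the markers are absent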