625 lines
20 KiB
Python
625 lines
20 KiB
Python
# ---
|
|
# jupyter:
|
|
# jupytext:
|
|
# formats: ipynb,py:percent
|
|
# text_representation:
|
|
# extension: .py
|
|
# format_name: percent
|
|
# format_version: '1.3'
|
|
# jupytext_version: 1.16.4
|
|
# kernelspec:
|
|
# display_name: jax
|
|
# language: python
|
|
# name: python3
|
|
# ---
|
|
|
|
# %% [markdown]
|
|
# # T5 implementation using jax
|
|
|
|
# %% [markdown]
|
|
# ## import
|
|
|
|
# %% [raw]
|
|
# import json
|
|
# import logging
|
|
# import math
|
|
# import os
|
|
# import sys
|
|
# import time
|
|
# from dataclasses import asdict, dataclass, field
|
|
# from enum import Enum
|
|
# from functools import partial
|
|
# from pathlib import Path
|
|
# from typing import Callable, Optional
|
|
#
|
|
# import datasets
|
|
# import evaluate
|
|
# import jax
|
|
# import jax.numpy as jnp
|
|
# import nltk # Here to have a nice missing dependency error message early on
|
|
# import numpy as np
|
|
# import optax
|
|
# from datasets import Dataset, load_dataset
|
|
# from filelock import FileLock
|
|
# from flax import jax_utils, traverse_util
|
|
# from flax.jax_utils import pad_shard_unpad, unreplicate
|
|
# from flax.training import train_state
|
|
# from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
|
# from tqdm import tqdm
|
|
#
|
|
# import transformers
|
|
# from transformers import (
|
|
# CONFIG_MAPPING,
|
|
# FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
|
# AutoConfig,
|
|
# AutoTokenizer,
|
|
# FlaxAutoModelForSeq2SeqLM,
|
|
# HfArgumentParser,
|
|
# is_tensorboard_available,
|
|
# )
|
|
# from transformers.utils import is_offline_mode, send_example_telemetry
|
|
#
|
|
#
|
|
# logger = logging.getLogger(__name__)
|
|
#
|
|
# try:
|
|
# nltk.data.find("tokenizers/punkt")
|
|
# except (LookupError, OSError):
|
|
# if is_offline_mode():
|
|
# raise LookupError(
|
|
# "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
|
|
# )
|
|
# with FileLock(".lock") as lock:
|
|
# nltk.download("punkt", quiet=True)
|
|
#
|
|
#
|
|
# MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
|
|
# MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
|
|
|
|
|
# %%
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import optax
|
|
import numpy as np
|
|
from functools import partial
|
|
from typing import Callable, Optional
|
|
import math
|
|
|
|
# Global JAX numerics settings.
# Keep the default 32-bit dtypes (no float64/int64 promotion).
jax.config.update("jax_enable_x64", False)
# Matmul precision for float32 inputs ("tensorfloat32" was tried previously).
jax.config.update("jax_default_matmul_precision", "high")
|
|
|
|
|
|
from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
|
|
|
|
|
|
import datasets
|
|
from datasets import Dataset, load_dataset
|
|
import evaluate
|
|
|
|
|
|
import nltk # Here to have a nice missing dependency error message early on
|
|
|
|
from flax import jax_utils, traverse_util
|
|
from flax.jax_utils import pad_shard_unpad, unreplicate
|
|
from flax.training import train_state
|
|
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
|
|
|
|
|
import time
|
|
|
|
|
|
# %%
|
|
import os

# NOTE(review): XLA_FLAGS and CUDA_VISIBLE_DEVICES presumably only take effect if they
# are set before the JAX backend initializes; confirm no JAX computation ran earlier.
# This assignment replaces any XLA_FLAGS already present in the environment.
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_enable_triton_softmax_fusion=True "
    "--xla_gpu_triton_gemm_any=True "
)

# Visible GPUs and NCCL collective-communication tuning.
os.environ.update(
    CUDA_VISIBLE_DEVICES="0, 1, 2, 3",
    NCCL_LL128_BUFFSIZE="-2",
    NCCL_LL_BUFFSIZE="-2",
    NCCL_PROTO="SIMPLE,LL,LL128",
)
|
|
|
|
# %%
|
|
# Report the platform of the default JAX backend (e.g. "cpu", "gpu", "tpu").
# `jax.default_backend()` is the public replacement for the deprecated
# `jax.lib.xla_bridge.get_backend().platform`.
print(jax.default_backend())
|
|
|
|
|
|
# %%
|
|
# Make sure the NLTK "punkt" sentence-tokenizer data is available (used by
# `nltk.sent_tokenize` in `postprocess_text`). In offline mode it cannot be
# downloaded, so fail with an actionable message instead.
# `is_offline_mode` must be imported explicitly: the original imports only live in a
# raw (non-executed) cell, so this branch used to die with a NameError.
from transformers.utils import is_offline_mode

try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    if is_offline_mode():
        raise LookupError(
            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
        )
    # No cross-process file lock: presumably a single notebook process performs this
    # download (the previous `FileLock` was never imported either).
    nltk.download("punkt", quiet=True)
|
|
|
|
|
|
|
|
# %% [markdown]
|
|
# ## Prepare datasets
|
|
|
|
# %%
|
|
# Base checkpoint to fine-tune (replace with your specific model name).
model_name_or_path = "t5-small"

# Model configuration; its pad / decoder-start token ids are used when building
# decoder inputs during preprocessing.
config = AutoConfig.from_pretrained(model_name_or_path)

# Flax seq2seq model together with its pretrained weights.
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
|
|
|
|
|
|
# %%
|
|
# Flax seq2seq models take `decoder_input_ids` instead of `labels`, so reuse the
# model's own `shift_tokens_right` helper from the module that defines its class.
# (Replaces `__import__(..., fromlist=["shift_tokens_tight"])`, whose fromlist was a typo.)
import importlib

model_module = importlib.import_module(model.__module__)
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
|
|
|
|
|
|
# %%
|
|
from datasets import load_from_disk
from tqdm import tqdm
from transformers import T5TokenizerFast

# Dataset previously saved to disk; it is indexed by "train" and "validation" later.
# `save_path` is the directory where checkpoints are written after each epoch.
file_path = '/home/richard/Projects/learn_t5/retrieval/combined_data_t5'
# file_path = 'combined_data'
save_path = 't5_80_1_retrieval'
split_datasets = load_from_disk(file_path)

# Tokenizer.
# NOTE(review): the tokenizer comes from "t5-base" while the model is loaded from
# `model_name_or_path`; this relies on both using the same vocabulary — confirm.
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)

# Structural markers that should be kept as single, unsplittable tokens.
# additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
additional_special_tokens = [
    "<THING_START>", "<THING_END>",
    "<PROPERTY_START>", "<PROPERTY_END>",
    "<NAME>", "<DESC>",
    "<CONTEXT>", "<EXAMPLE>", "<INPUT>", "<OUTPUT>",
]
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
# NOTE(review): the model's embeddings are not resized here; confirm
# len(tokenizer) <= config.vocab_size so the new token ids have embedding rows.

# Padding/truncation length for both sources and targets.
max_length = 300
|
|
|
|
# Flax seq2seq models do not accept `labels`; they need `decoder_input_ids`, i.e. the
# target ids shifted right (built with the model's `shift_tokens_right`, looked up
# dynamically from the model module). Everything is padded to `max_length` because
# the jitted/pmapped step functions need fixed-shape inputs.
def preprocess_function(example):
    """Tokenize a batch of source/target text pairs into model-ready arrays.

    Args:
        example: a batch from ``Dataset.map(batched=True)``: a mapping with an
            ``"input"`` column (source texts) and an ``"output"`` column (target texts).

    Returns:
        The tokenizer encoding of the source texts (``input_ids``, ``attention_mask``)
        extended with:
          * ``labels``: target token ids;
          * ``decoder_input_ids``: the target ids shifted right by the model's
            ``shift_tokens_right`` (presumably prepending
            ``config.decoder_start_token_id``);
          * ``decoder_attention_mask``: 1 on real target tokens and 0 on padding,
            used to drop padding from the loss.
    """
    sources = example["input"]
    targets = example["output"]

    # Encoder side: source ids and mask.
    model_inputs = tokenizer(
        sources,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )

    # Decoder side: tokenize the targets on their own, so `input_ids`/`attention_mask`
    # of this encoding describe the TARGET text. (Passing the source text as well makes
    # `input_ids` the source ids, which previously trained the decoder on the inputs.)
    target_encoding = tokenizer(
        text_target=targets,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )
    target_ids = target_encoding["input_ids"]

    model_inputs["labels"] = target_ids
    decoder_input_ids = shift_tokens_right_fn(
        target_ids, config.pad_token_id, config.decoder_start_token_id
    )
    model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
    # Lets the loss ignore target padding.
    model_inputs["decoder_attention_mask"] = target_encoding["attention_mask"]

    return model_inputs
|
|
|
|
# Tokenize every split batch-wise. The raw text columns are removed so only model
# inputs remain; this assumes every split has the same columns as "train".
tokenized_datasets = split_datasets.map(
    function=preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=split_datasets["train"].column_names,
)
|
|
|
|
|
|
|
|
|
|
# %%
|
|
# Show the tokenized splits as this cell's output.
tokenized_datasets

# %%
# Splits used for training and for evaluation.
train_dataset, eval_dataset = tokenized_datasets["train"], tokenized_datasets["validation"]
|
|
|
|
|
|
# %%
|
|
def data_loader(rng: jax.Array, dataset: "Dataset", batch_size: int, shuffle: bool = False, drop_last=True):
    """Yield batches of examples from ``dataset`` as dicts of NumPy arrays.

    Args:
        rng: JAX PRNG key used to permute the examples; only read when ``shuffle``.
        dataset: a sized container (e.g. a Hugging Face ``Dataset``) where
            ``dataset[indices]`` returns a mapping of column name -> values.
        batch_size: examples per batch.
        shuffle: visit the examples in a random order drawn from ``rng``.
        drop_last: if True (default), yield ``len(dataset) // batch_size`` batches of
            exactly ``batch_size`` examples and skip the remainder. If False, yield
            every example, split into ``ceil(len(dataset) / batch_size)`` nearly equal
            batches (``np.array_split``: sizes differ by at most one and never exceed
            ``batch_size``). An empty dataset yields nothing.

    Yields:
        dict mapping column name -> ``np.ndarray`` for one batch.
    """
    num_examples = len(dataset)
    if shuffle:
        order = np.asarray(jax.random.permutation(rng, num_examples))
    else:
        order = np.arange(num_examples)

    if drop_last:
        steps_per_epoch = num_examples // batch_size
        # Skip the incomplete tail, then view the rest as (steps, batch_size).
        batch_idx = order[: steps_per_epoch * batch_size].reshape((steps_per_epoch, batch_size))
    else:
        if num_examples == 0:
            # np.array_split rejects 0 sections; there is simply nothing to yield.
            return
        steps_per_epoch = math.ceil(num_examples / batch_size)
        batch_idx = np.array_split(order, steps_per_epoch)

    for idx in batch_idx:
        batch = dataset[idx]
        yield {k: np.array(v) for k, v in batch.items()}
|
|
|
|
|
|
|
|
# %% [markdown]
|
|
# The tokenized model inputs are now stored in `tokenized_datasets`.
|
|
|
|
# %%
|
|
# sacreBLEU scorer used by `compute_metrics`.
metric = evaluate.load("sacrebleu")


def postprocess_text(preds, labels):
    """Strip surrounding whitespace and put each sentence on its own line.

    Newline-separated sentences are the format ROUGE-Lsum scoring expects; sentence
    splitting uses ``nltk.sent_tokenize``.

    Returns:
        ``(preds, labels)`` as two new lists of processed strings.
    """
    def _one_sentence_per_line(text):
        return "\n".join(nltk.sent_tokenize(text.strip()))

    processed_preds = [_one_sentence_per_line(pred) for pred in preds]
    processed_refs = [_one_sentence_per_line(ref) for ref in labels]
    return processed_preds, processed_refs
|
|
|
|
# def compute_metrics(preds, labels):
|
|
# decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
|
# decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
|
#
|
|
# # Some simple post-processing
|
|
# decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
|
|
#
|
|
# result = metric.compute(predictions=decoded_preds, references=decoded_labels)
|
|
# result = {k: round(v * 100, 4) for k, v in result.items()}
|
|
# prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
|
|
# result["gen_len"] = np.mean(prediction_lens)
|
|
# return result
|
|
|
|
def compute_metrics(preds, labels):
    """Decode predicted and reference token ids and score them with sacreBLEU.

    Args:
        preds: generated token-id sequences; if a tuple is given, only its first
            element is used (models may return extra outputs).
        labels: reference token-id sequences; entries equal to -100 are treated as
            padding.

    Returns:
        ``{"bleu": score}``, the ``"score"`` field of the sacreBLEU result.
    """
    if isinstance(preds, tuple):
        preds = preds[0]

    pad_id = tokenizer.pad_token_id
    hypotheses = [text.strip() for text in tokenizer.batch_decode(preds, skip_special_tokens=True)]

    # -100 cannot be decoded; map it to the pad id, which skip_special_tokens drops.
    decodable_labels = np.where(labels != -100, labels, pad_id)
    # sacreBLEU takes a list of reference lists (one reference per example here).
    references = [
        [text.strip()] for text in tokenizer.batch_decode(decodable_labels, skip_special_tokens=True)
    ]

    scores = metric.compute(predictions=hypotheses, references=references)
    return {"bleu": scores["score"]}
|
|
|
|
|
|
|
|
# %% [markdown]
|
|
# # Model
|
|
|
|
# %%
|
|
# ---- Run hyper-parameters ----------------------------------------------------
seed = 117
num_epochs = 80
num_train_epochs = num_epochs  # same value, name used by the LR schedule
batch_size = 36

# Batch sizes per device and in total across all devices (pmap data parallelism).
per_device_train_batch_size = batch_size
per_device_eval_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
eval_batch_size = per_device_eval_batch_size * jax.device_count()
# Training drops the last incomplete batch, hence the floor division.
steps_per_epoch = len(train_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs

# ---- Optimizer (AdamW + linear schedule) -------------------------------------
warmup_steps = 0
learning_rate = 5e-5
weight_decay = 0.0
adam_beta1, adam_beta2, adam_epsilon = 0.9, 0.999, 1e-8
label_smoothing_factor = 0.0

# ---- Generation / evaluation -------------------------------------------------
num_beams = 1
val_max_target_length = None  # None -> fall back to the model config
predict_with_generate = True
|
|
|
|
|
|
# %%
|
|
|
|
# Root PRNG key for the run: one sub-key is reserved for dropout, and `rng` keeps
# the other for later splits (e.g. per-epoch shuffling).
rng, dropout_rng = jax.random.split(jax.random.PRNGKey(seed))
|
|
|
|
|
|
# %%
|
|
# optimization functions
|
|
|
|
def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Build a step -> learning-rate schedule: linear warm-up, then linear decay.

    The rate climbs from 0 to ``learning_rate`` over ``num_warmup_steps`` steps and
    then decreases linearly to 0 over the remaining training steps, where the total
    is ``(train_ds_size // train_batch_size) * num_train_epochs``.
    """
    total_steps = (train_ds_size // train_batch_size) * num_train_epochs
    decay_steps = total_steps - num_warmup_steps

    warmup = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=decay_steps)
    # Switch from warm-up to decay exactly at step `num_warmup_steps`.
    return optax.join_schedules(schedules=[warmup, decay], boundaries=[num_warmup_steps])
|
|
|
|
|
|
# Learning-rate schedule shared by the optimizer and the logged training metrics.
linear_decay_lr_schedule_fn = create_learning_rate_fn(
    train_ds_size=len(train_dataset),
    train_batch_size=train_batch_size,
    num_train_epochs=num_train_epochs,
    num_warmup_steps=warmup_steps,
    learning_rate=learning_rate,
)
|
|
|
|
# We use Optax's "masking" functionality to not apply weight decay
|
|
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
|
# mask boolean with the same structure as the parameters.
|
|
# The mask is True for parameters that should be decayed.
|
|
def decay_mask_fn(params):
    """Return a boolean pytree shaped like ``params``: True where weight decay applies.

    A leaf is excluded (False) when its last path component is ``"bias"``, or when its
    last two path components equal those of any parameter whose joined, lower-cased
    path contains "layernorm", "layer_norm" or "ln".
    """
    flat_params = traverse_util.flatten_dict(params)

    norm_markers = ("layernorm", "layer_norm", "ln")
    norm_suffixes = set()
    for path in flat_params:
        joined = "".join(path).lower()
        if any(marker in joined for marker in norm_markers):
            norm_suffixes.add(path[-2:])

    decay_mask = {}
    for path in flat_params:
        decay_mask[path] = path[-1] != "bias" and path[-2:] not in norm_suffixes
    return traverse_util.unflatten_dict(decay_mask)
|
|
|
|
# AdamW optimizer following the linear schedule; `decay_mask_fn` selects which
# parameters receive weight decay.
adamw = optax.adamw(
    linear_decay_lr_schedule_fn,
    weight_decay=weight_decay,
    mask=decay_mask_fn,
    b1=adam_beta1,
    b2=adam_beta2,
    eps=adam_epsilon,
)
|
|
|
|
|
|
# %%
|
|
# Training functions
class TrainState(train_state.TrainState):
    """Flax ``TrainState`` that also carries the dropout PRNG key."""

    dropout_rng: jnp.ndarray

    def replicate(self):
        """Replicate the state across local devices, giving each device its own dropout key."""
        replicated = jax_utils.replicate(self)
        return replicated.replace(dropout_rng=shard_prng_key(self.dropout_rng))


# Pretrained weights (loaded by `from_pretrained`), cast to bfloat16 for training.
# NOTE(review): the optimizer updates these bf16 weights directly. bf16 keeps only ~8
# significant bits, so AdamW steps on the order of `learning_rate` (5e-5) may round
# away for larger-magnitude weights. Consider float32 master weights; confirm the
# training loss actually decreases.
params = model.params
params_bf16 = jax.tree_util.tree_map(lambda leaf: leaf.astype(jnp.bfloat16), params)

# Initial (single-device) train state.
state = TrainState.create(apply_fn=model.__call__, params=params_bf16, tx=adamw, dropout_rng=dropout_rng)
|
|
|
|
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    """Label-smoothed cross-entropy, summed over non-padding positions.

    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104

    Args:
        logits: model outputs with the vocabulary on the last axis.
        labels: integer target ids (``logits`` without the vocabulary axis).
        padding_mask: 1 where a position counts towards the loss, 0 for padding.
        label_smoothing_factor: probability mass spread uniformly over the
            non-target tokens; 0.0 gives plain cross-entropy.

    Returns:
        ``(loss_sum, num_labels)``: the masked loss summed over all positions and the
        number of counted positions (callers divide the first by the second).
    """
    vocab_size = logits.shape[-1]
    on_value = 1.0 - label_smoothing_factor
    off_value = (1.0 - on_value) / (vocab_size - 1)
    # Entropy of the smoothed target distribution; subtracting it makes a perfect
    # prediction score 0. The 1e-20 keeps log() finite when off_value is 0.
    target_entropy = -(
        on_value * jnp.log(on_value) + (vocab_size - 1) * off_value * jnp.log(off_value + 1e-20)
    )
    smoothed_targets = onehot(labels, vocab_size, on_value=on_value, off_value=off_value)

    per_token = optax.softmax_cross_entropy(logits, smoothed_targets) - target_entropy
    masked = per_token * padding_mask  # padding positions contribute nothing
    return masked.sum(), padding_mask.sum()
|
|
|
|
def train_step(state, batch, label_smoothing_factor=0.0):
    """Run one data-parallel optimizer step under ``jax.pmap`` (axis name ``"batch"``).

    Each device's summed loss and gradients are summed across devices. They are then
    divided by the global number of non-padding target tokens before the update.

    Returns:
        ``(new_state, metrics)``; ``metrics`` holds the normalized ``"loss"`` and the
        ``"learning_rate"`` at the pre-update step.
    """
    step_rng, next_dropout_rng = jax.random.split(state.dropout_rng)

    def summed_loss(params):
        # `labels` is not a model input, so pop it before the forward pass.
        targets = batch.pop("labels")
        logits = state.apply_fn(**batch, params=params, dropout_rng=step_rng, train=True)[0]
        return loss_fn(logits, targets, batch["decoder_attention_mask"], label_smoothing_factor)

    (local_loss, local_count), local_grad = jax.value_and_grad(summed_loss, has_aux=True)(state.params)

    # Global token count, then loss and gradients normalized by it.
    total_count = jax.lax.psum(local_count, "batch")
    loss = jax.tree_util.tree_map(lambda x: x / total_count, jax.lax.psum(local_loss, "batch"))
    grad = jax.tree_util.tree_map(lambda x: x / total_count, jax.lax.psum(local_grad, "batch"))

    new_state = state.apply_gradients(grads=grad, dropout_rng=next_dropout_rng)
    metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
    return new_state, metrics
|
|
|
|
def eval_step(params, batch, label_smoothing_factor=0.0):
    """Compute the evaluation loss for one batch under ``jax.pmap`` (axis name ``"batch"``).

    Runs the model with ``train=False``. The summed loss is divided by the number of
    non-padding target tokens across all devices. ``"labels"`` is popped from ``batch``.

    Returns:
        ``{"loss": normalized_loss}``.
    """
    targets = batch.pop("labels")
    logits = model(**batch, params=params, train=False)[0]
    local_loss, local_count = loss_fn(logits, targets, batch["decoder_attention_mask"], label_smoothing_factor)

    total_count = jax.lax.psum(local_count, "batch")
    total_loss = jax.lax.psum(local_loss, "batch")
    return {"loss": jax.tree_util.tree_map(lambda x: x / total_count, total_loss)}
|
|
|
|
# Generation settings. The length cap gets its own name so the module-level
# `max_length` (the tokenization padding length used by `preprocess_function`) is
# not overwritten; re-running preprocessing after this cell previously truncated to
# the generation cap instead.
# NOTE(review): model.config.max_length may be much shorter than the padded training
# targets; consider setting `val_max_target_length` explicitly.
gen_max_length = (
    val_max_target_length if val_max_target_length is not None else model.config.max_length
)
num_beams = num_beams if num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": gen_max_length, "num_beams": num_beams}
|
|
|
|
def generate_step(params, batch):
    """Generate output token ids for one batch with the given parameters.

    Intended to run under ``jax.pmap``. ``params`` is handed to ``generate`` directly
    instead of being assigned to ``model.params``: inside ``pmap`` the parameters are
    tracers, and storing them on the shared ``model`` object would leak them out of
    the transformation.

    Returns:
        The generated sequences (``output.sequences``).
    """
    output = model.generate(
        batch["input_ids"],
        attention_mask=batch["attention_mask"],
        params=params,
        **gen_kwargs,
    )
    return output.sequences
|
|
|
|
# Data-parallel versions of the step functions; the collective axis is named "batch".
# donate_argnums=(0,) allows XLA to reuse the incoming train state's buffers.
p_train_step = jax.pmap(
    partial(train_step, label_smoothing_factor=label_smoothing_factor),
    axis_name="batch",
    donate_argnums=(0,),
)
p_eval_step = jax.pmap(
    partial(eval_step, label_smoothing_factor=label_smoothing_factor),
    axis_name="batch",
)
p_generate_step = jax.pmap(generate_step, axis_name="batch")

# One copy of the train state per device (each with its own dropout key).
state = state.replicate()
|
|
|
|
|
|
|
|
# %%
|
|
|
|
|
|
# Summary of the training run, one item per line.
print(
    "***** Running training *****",
    f" Num examples = {len(train_dataset)}",
    f" Num Epochs = {num_epochs}",
    f" Instantaneous batch size per device = {per_device_train_batch_size}",
    f" Total train batch size (w. parallel & distributed) = {train_batch_size}",
    f" Total optimization steps = {total_train_steps}",
    sep="\n",
)
|
|
|
|
|
|
# %%
|
|
|
|
# Main training loop: one shuffled pass over the training set per epoch, then a
# checkpoint of the first device's parameters (overwriting `output_dir` each epoch).
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
# epochs = range(num_epochs)

# Training batches use drop_last=True, so each epoch has exactly this many steps.
steps_per_epoch = len(train_dataset) // train_batch_size
if steps_per_epoch == 0:
    # Otherwise every epoch runs zero steps and then reports an undefined (or stale,
    # from an earlier notebook run) `train_metric`.
    raise ValueError(
        f"train_dataset has {len(train_dataset)} examples, fewer than one global batch "
        f"of {train_batch_size}; reduce batch_size or add data."
    )

output_dir = save_path

for epoch in epochs:
    # ======================== Training ================================
    train_start = time.perf_counter()

    # New sampling key each epoch -> a different shuffle each epoch.
    rng, input_rng = jax.random.split(rng)
    train_metrics = []

    train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
    for batch in tqdm(train_loader, total=steps_per_epoch, desc="Training...", position=1, leave=False):
        # `shard` splits the global batch along a new leading device axis for pmap.
        state, train_metric = p_train_step(state, shard(batch))
        train_metrics.append(train_metric)

    train_time += time.perf_counter() - train_start

    # Metrics of the epoch's last step, taken from the first device.
    train_metric = unreplicate(train_metric)

    epochs.write(
        f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
        f" {train_metric['learning_rate']})"
    )

    # ======================== Evaluating ==============================
    # eval_metrics = []
    # eval_preds = []
    # eval_labels = []
    #
    # eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
    # eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
    # for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
    #     # Model forward
    #     batch = next(eval_loader)
    #     labels = batch["labels"]
    #
    #     metrics = pad_shard_unpad(p_eval_step, static_return=True)(
    #         state.params, batch, min_device_batch=per_device_eval_batch_size
    #     )
    #     eval_metrics.append(metrics)
    #
    #     # generation
    #     if predict_with_generate:
    #         generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
    #         eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
    #         eval_labels.extend(labels)
    #
    # # normalize eval metrics
    # eval_metrics = get_metrics(eval_metrics)
    # eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
    #
    # # compute metrics
    # rouge_desc = ""
    # if predict_with_generate:
    #     rouge_metrics = compute_metrics(eval_preds, eval_labels)
    #     eval_metrics.update(rouge_metrics)
    #     rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
    #
    # # Print metrics and update progress bar
    # desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
    # epochs.write(desc)
    # epochs.desc = desc
    #
    # # Save metrics
    # if has_tensorboard and jax.process_index() == 0:
    #     cur_step = epoch * (len(train_dataset) // train_batch_size)
    #     write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)

    # Save a local checkpoint after each epoch (only from process 0). The replicated
    # params are reduced to the first device's copy before saving.
    if jax.process_index() == 0:
        params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
        model.save_pretrained(output_dir, params=params)
        tokenizer.save_pretrained(output_dir)
|
|
|
|
|
|
|
|
# %% [markdown]
|
|
# #
|