# t5 training for combined concatenated outputs (thing + property) 

refer to `t5_train_tp.py` and `guide_for_tp.md` for faster training workflow

In [1]:
from datasets import load_from_disk
import json
from transformers import AutoTokenizer
import os
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, EarlyStoppingCallback
import evaluate
import numpy as np

model_name = "google/t5-efficient-tiny"
# google/t5-efficient-tiny
# google/t5-efficient-mini
# t5-small
# t5-base

train_epochs = 80

with open("mode.json", "r") as json_file:
    mode_dict = json.load(json_file)

mode_dict.update({"model": model_name, "train_epochs": train_epochs})
fold_group = mode_dict.get("fold_group")

with open("mode.json", "w") as json_file:
    json.dump(mode_dict, json_file)

mode = mode_dict.get("mode", "default_value")
file_path = f'combined_data/{mode}/{fold_group}'
split_datasets = load_from_disk(file_path)

tokenizer = AutoTokenizer.from_pretrained(model_name)
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", 
                             "<TN_START>", "<TN_END>", "<TD_START>", "<TD_END>", 
                             "<MIN_START>", "<MIN_END>", "<MAX_START>", "<MAX_END>", 
                             "<UNIT_START>", "<UNIT_END>"]
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})

max_length = 64

def preprocess_function(examples):
    inputs = [ex["input"] for ex in examples['translation']]
    targets = [ex["thing_property"] for ex in examples['translation']]
    return tokenizer(inputs, text_target=targets, max_length=max_length, truncation=True)

tokenized_datasets = split_datasets.map(
    preprocess_function,
    batched=True,
    remove_columns=split_datasets["train"].column_names,
)


model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
metric = evaluate.load("sacrebleu")

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [[label.strip()] for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    return {"bleu": result["score"]}

os.environ['NCCL_P2P_DISABLE'] = '1'
os.environ['NCCL_IB_DISABLE'] = '1'

args = Seq2SeqTrainingArguments(
    f"train_{fold_group}_{model_name}_{mode}_{train_epochs}",
    save_strategy="steps",
    learning_rate=1e-3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    auto_find_batch_size=True,
    ddp_find_unused_parameters=False,
    weight_decay=0.01,
    save_total_limit=1,
    num_train_epochs=train_epochs,
    predict_with_generate=True,
    bf16=True,
    push_to_hub=False,
    evaluation_strategy="steps",
    eval_steps=100,
    save_steps=100,    
    logging_steps=100,  
    load_best_model_at_end=True, 
    metric_for_best_model="bleu",
    lr_scheduler_type="linear",
    warmup_steps=100,
)

# Define the EarlyStoppingCallback
early_stopping_callback = EarlyStoppingCallback(
   early_stopping_patience=5,

)

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping_callback]  
)

trainer.train()
os._exit(0)




Step,Training Loss,Validation Loss,Bleu
100,9.0681,1.485702,0.0
200,0.8864,0.219002,20.99997
300,0.3025,0.1001,50.318311
400,0.1684,0.053922,52.052581
500,0.1138,0.046394,53.469249
600,0.0845,0.040225,53.980484
700,0.0669,0.026786,58.959618
800,0.0533,0.025612,52.672595
900,0.0426,0.019917,58.47523
1000,0.0382,0.021234,52.335545




KeyboardInterrupt: 