# T5-style seq2seq training (model set by `model_name` below) for combined concatenated outputs (thing + property)

Refer to `t5_train_tp.py` and `guide_for_tp.md` for a faster training workflow.

In [1]:
from datasets import load_from_disk
import json
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, EarlyStoppingCallback
import evaluate
import numpy as np
import os

model_name = "facebook/bart-base"
train_epochs = 80

# Load the run configuration (mode / fold group) from mode.json.
with open("mode.json", "r", encoding="utf-8") as json_file:
    mode_dict = json.load(json_file)

# Fail fast on an incomplete config. Previously a missing "mode" silently became
# the placeholder "default_value" and a missing "fold_group" became None, so the
# run only failed later inside load_from_disk on a nonsensical path -- and only
# after mode.json had already been rewritten below.
missing_keys = [key for key in ("mode", "fold_group") if key not in mode_dict]
if missing_keys:
    raise KeyError(f"mode.json is missing required key(s): {', '.join(missing_keys)}")

fold_group = mode_dict["fold_group"]
mode = mode_dict["mode"]

# Record the model and epoch budget of this run back into mode.json.
# NOTE(review): presumably other pipeline steps read these keys -- confirm.
mode_dict.update({"model": model_name, "train_epochs": train_epochs})
with open("mode.json", "w", encoding="utf-8") as json_file:
    json.dump(mode_dict, json_file)

# Pre-split dataset saved per mode and fold group.
file_path = f'combined_data/{mode}/{fold_group}'
split_datasets = load_from_disk(file_path)

# Tokenizer for the selected checkpoint, extended with paired <X_START>/<X_END>
# marker tokens (presumably delimiting the fields of the target text -- confirm).
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Token order matters (it fixes the ids the new tokens receive): for each field,
# START comes before END, and fields are added in this order.
_marker_fields = ("THING", "PROPERTY", "TN", "TD", "MIN", "MAX", "UNIT")
additional_special_tokens = [
    f"<{field}_{edge}>"
    for field in _marker_fields
    for edge in ("START", "END")
]
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})

# Preprocess function for tokenization
def preprocess_function(examples, max_length=64, max_target_length=None):
    """Tokenize a batch of translation records into model inputs and labels.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        examples: batch whose ``"translation"`` column is a list of dicts, each
            holding the source text under ``"input"`` and the target text under
            ``"thing_property"``.
        max_length: truncation length (in tokens) for the source text.
        max_target_length: truncation length for the target text. Defaults to
            ``max_length``; with both defaults, the result is identical to a
            single ``tokenizer(inputs, text_target=targets, max_length=64,
            truncation=True)`` call. Raise it if the tagged targets are longer
            than the source, so closing markers are not cut off.

    Returns:
        The tokenizer's batch encoding for the sources, with ``labels`` set to
        the target token ids.
    """
    if max_target_length is None:
        max_target_length = max_length

    records = examples["translation"]
    inputs = [record["input"] for record in records]
    targets = [record["thing_property"] for record in records]

    # Source and target are tokenized separately so each can have its own
    # truncation length.
    model_inputs = tokenizer(inputs, max_length=max_length, truncation=True)
    target_encoding = tokenizer(
        text_target=targets, max_length=max_target_length, truncation=True
    )
    model_inputs["labels"] = target_encoding["input_ids"]
    return model_inputs

# Tokenize every split. The raw columns (taken from the train split) are dropped
# so only the tokenizer outputs remain.
raw_columns = split_datasets["train"].column_names
tokenized_datasets = split_datasets.map(
    preprocess_function,
    batched=True,
    remove_columns=raw_columns,
)

# Pretrained model, with its embedding matrix resized to the tokenizer's new
# vocabulary size (it now includes the added special tokens).
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model.resize_token_embeddings(len(tokenizer))

# Dynamic per-batch padding of inputs and labels.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# sacreBLEU metric used by compute_metrics.
metric = evaluate.load("sacrebleu")

# Compute metrics function
def compute_metrics(eval_preds):
    """Compute corpus-level sacreBLEU of generated predictions against references.

    Args:
        eval_preds: ``(predictions, label_ids)`` pair from the trainer's
            evaluation loop. With ``predict_with_generate=True`` the predictions
            are generated token ids. If they arrive as a tuple, only the first
            element is used.

    Returns:
        ``{"bleu": score}``, where ``score`` is sacreBLEU's ``"score"`` field.
    """
    preds, labels = eval_preds
    preds = preds[0] if isinstance(preds, tuple) else preds

    # Both arrays can contain -100:
    # - labels use it as the loss-ignore padding value;
    # - the Trainer also pads generated sequences of different lengths with -100
    #   when concatenating eval batches.
    # -100 is not a vocabulary id and fast tokenizers fail to decode it, so map it
    # to the pad token (which skip_special_tokens then drops) before decoding.
    pad_id = tokenizer.pad_token_id
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    # NOTE(review): the <..._START>/<..._END> markers were registered as
    # additional special tokens, so skip_special_tokens=True strips them too.
    # BLEU therefore scores the content words only, not the marker structure.
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # sacreBLEU expects one list of reference strings per prediction.
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [[label.strip()] for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    return {"bleu": result["score"]}

# Run-specific checkpoint/log directory.
# NOTE(review): model_name contains "/" ("facebook/bart-base"), so this path
# is nested one directory deeper than the name suggests.
output_dir = f"train_{fold_group}_{model_name}_{mode}_{train_epochs}"

# Evaluation, checkpointing and logging all share one step cadence.
eval_interval = 200

# NOTE(review): load_best_model_at_end is set without metric_for_best_model.
# By the Trainer default, best-checkpoint selection and early stopping then
# track eval loss, not the BLEU computed in compute_metrics -- confirm intended.
args = Seq2SeqTrainingArguments(
    output_dir,
    # optimisation
    learning_rate=1e-5,
    weight_decay=0.01,
    lr_scheduler_type="linear",
    warmup_steps=100,
    num_train_epochs=train_epochs,
    # batching / hardware
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    auto_find_batch_size=True,
    ddp_find_unused_parameters=False,
    bf16=True,
    # evaluation / checkpointing / logging
    evaluation_strategy="steps",
    eval_steps=eval_interval,
    save_strategy="steps",
    save_steps=eval_interval,
    logging_steps=eval_interval,
    save_total_limit=1,
    load_best_model_at_end=True,
    predict_with_generate=True,
    push_to_hub=False,
)

# Early stopping with a patience of 2 evaluation rounds.
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=2)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping_callback],
)

trainer.train()

# os._exit terminates the process immediately: no cleanup handlers run and stdio
# buffers are not flushed. In a notebook this kills the kernel.
# NOTE(review): presumably done to release GPU memory between runs -- confirm.
os._exit(0)



Map:   0%|          | 0/6260 [00:00<?, ? examples/s]

Map:   0%|          | 0/12969 [00:00<?, ? examples/s]

Map:   0%|          | 0/2087 [00:00<?, ? examples/s]



Step,Training Loss,Validation Loss,Bleu
200,2.6543,0.11238,26.397731
400,0.1066,0.035335,87.137364
600,0.0446,0.022964,89.884682
800,0.0263,0.01822,86.274312
1000,0.0173,0.016252,86.389477
1200,0.0124,0.015651,94.416285
1400,0.0115,0.014833,91.596509
1600,0.0088,0.015168,91.629519
1800,0.0069,0.015042,95.375351


Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}
Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}
Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}
Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}
Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}
Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}
Non-default generation parameters: {'early_stopping': True, 'num_beams

: 