# domain_mapping/esAppMod_train/augmentation/dynamic_train.py
#
# Fine-tunes a sequence-classification model on esAppMod entity mentions,
# re-sampling and re-augmenting the training set at the start of every epoch.

# %%
from torch.utils.data import Dataset, DataLoader
# from datasets import load_from_disk
import os
# Runtime environment for NCCL / CUDA, applied before ``import torch`` below
# so the settings are in place when CUDA and NCCL initialise:
#   * disable NCCL's peer-to-peer and InfiniBand transports;
#   * enumerate GPUs by PCI bus id and expose only devices 0-3.
os.environ.update({
    'NCCL_P2P_DISABLE': '1',
    'NCCL_IB_DISABLE': '1',
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0,1,2,3",
})
import re
import random
import torch
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
DataCollatorWithPadding,
Trainer,
EarlyStoppingCallback,
TrainingArguments,
TrainerCallback
)
import evaluate
import numpy as np
import pandas as pd
import math
from functools import partial
import warnings
# Silence two known-noisy warnings.
# ``message`` is a regular expression matched against the *start* of the
# warning text. The category name ("FutureWarning: ") is not part of that
# text, so it is passed via ``category=`` instead. The literal message is
# escaped so its parentheses and dots are not treated as regex syntax.
warnings.filterwarnings("ignore", message='Was asked to gather along dimension 0')
warnings.filterwarnings(
    "ignore",
    message=re.escape('`torch.cuda.amp.autocast(args...)` is deprecated.'),
    category=FutureWarning,
)
# Allow faster, reduced-precision kernels for float32 matrix multiplication.
torch.set_float32_matmul_precision('high')
def set_seed(seed):
    """Seed every random number generator this training run draws from.

    Covers Python's ``random`` module (used by the shuffle / corruption
    augmentations), NumPy's global RNG, and PyTorch's CPU and CUDA
    generators. It also switches cuDNN to deterministic, non-autotuned
    kernels.

    Args:
        seed (int): seed applied to every generator.
    """
    seeders = (
        random.seed,                 # Python stdlib RNG
        np.random.seed,              # NumPy global RNG
        torch.manual_seed,           # PyTorch CPU generator
        torch.cuda.manual_seed,      # current CUDA device
        torch.cuda.manual_seed_all,  # every visible CUDA device
    )
    for seeder in seeders:
        seeder(seed)
    # Trade cuDNN speed for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


set_seed(42)
# %%
# PARAMETERS
# (history-view timestamp: 2025-01-18 12:14:06 +09:00)
# SAMPLES:        rows drawn per class each time the dataset is regenerated
# SHUFFLES:       random word-order permutations generated per mention
# AMPLIFY_FACTOR: extra copies added for mentions of at most two words
SAMPLES, SHUFFLES, AMPLIFY_FACTOR = 50, 3, 10
# %%
###################################################
# import code
# import training file
data_path = '../../esAppMod_data_import/train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)
# The classification targets are the real entity ids (not a derived pattern).
# Class indices are assigned in sorted order of the distinct ids.
entity_ids = df['entity_id'].to_list()
target_id_list = sorted(set(entity_ids))
id2label = dict(enumerate(target_id_list))
label2id = {label: class_idx for class_idx, label in id2label.items()}
# integer class index for every row; this is the label the model trains on
df["training_id"] = df["entity_id"].map(label2id)
# %%
##############################################################
# augmentation code
# basic preprocessing
def preprocess_text(text):
    """Normalise a mention: lowercase it and collapse whitespace.

    Every run of whitespace (spaces, tabs, newlines, ...) is replaced by a
    single space, and leading/trailing whitespace is removed.

    Args:
        text (str): raw mention text.

    Returns:
        str: the normalised text.
    """
    # 1. Make all lowercase.
    text = text.lower()
    # 2. Standardise spacing: one space between words, none at the ends.
    return re.sub(r'\s+', ' ', text).strip()
def generate_random_shuffles(text, n):
    """Return ``n`` independently shuffled copies of the words in ``text``.

    ``text`` is split on whitespace. Each copy's word order is permuted with
    ``random.shuffle``, and the words are re-joined with single spaces.
    Copies may repeat each other or the original order.
    """
    tokens = text.split()

    def _permuted():
        # shuffle a fresh copy so ``tokens`` keeps its original order
        scratch = list(tokens)
        random.shuffle(scratch)
        return " ".join(scratch)

    return [_permuted() for _ in range(n)]
def shuffle_text(text, n_shuffles=SHUFFLES):
    """Return ``text`` followed by ``n_shuffles`` random word-order permutations.

    The first element is always the unmodified input. The permutations may
    coincide with it, or with each other, by chance.
    """
    return [text, *generate_random_shuffles(text, n_shuffles)]
def corrupt_word(word):
    """Return ``word`` with one random typo-style corruption applied.

    One of two operations is picked uniformly at random:

    * ``"delete"``: drop the character at a random position;
    * ``"swap"``: exchange a random character with its right neighbour.

    Words shorter than two characters are returned unchanged, without
    drawing from the RNG.
    """
    if len(word) < 2:
        return word
    if random.choice(["delete", "swap"]) == "delete":
        pos = random.randint(0, len(word) - 1)
        return word[:pos] + word[pos + 1:]
    # swap: len(word) >= 2 is guaranteed by the guard above
    pos = random.randint(0, len(word) - 2)
    return word[:pos] + word[pos + 1] + word[pos] + word[pos + 2:]
def corrupt_string(sentence, corruption_probability=0.01):
    """Apply ``corrupt_word`` to each word independently with the given probability.

    The sentence is split on whitespace. One ``random.random()`` draw is made
    per word, and the result is re-joined with single spaces.
    """
    pieces = []
    for token in sentence.split():
        should_corrupt = random.random() < corruption_probability
        pieces.append(corrupt_word(token) if should_corrupt else token)
    return " ".join(pieces)
# %%
def create_example(index, mention):
    """Build one training record pairing class index ``index`` with text ``mention``."""
    return dict(training_id=index, mention=mention)
# augment whole dataset
def augment_data(df):
    """Expand every row of ``df`` into a set of augmented training examples.

    For each row, ``mention`` is normalised with ``preprocess_text``. The
    following examples are then emitted, all sharing the row's
    ``training_id``:

    1. the normalised text itself;
    2. ``SHUFFLES`` random word-order permutations, skipping any that
       reproduce the normalised text;
    3. one copy with per-word typo corruption (probability 0.1), if it
       differs from the normalised text;
    4. one copy in which every character that is not a word character or
       whitespace is replaced by a space, if it differs;
    5. for texts of at most two words, ``AMPLIFY_FACTOR`` extra copies of the
       normalised text. Short mentions are rare and shuffling rarely
       produces new variants of them, so they are over-represented on
       purpose.

    Args:
        df (pd.DataFrame): must contain ``training_id`` and ``mention``
            columns.

    Returns:
        pd.DataFrame: columns ``training_id`` and ``mention``. When ``df``
        has no rows, the result is empty but still has those columns, so
        downstream ``["mention"]`` access keeps working.
    """
    output_list = []
    for index, raw_mention in zip(df['training_id'], df['mention']):
        parent_desc = preprocess_text(raw_mention)
        # basic example
        output_list.append(create_example(index, parent_desc))
        # shuffled strings; shuffle_text puts the original first, keep only new variants
        for desc in shuffle_text(parent_desc, n_shuffles=SHUFFLES):
            if desc != parent_desc:
                output_list.append(create_example(index, desc))
        # corrupted string
        corrupted = corrupt_string(parent_desc, corruption_probability=0.1)
        if corrupted != parent_desc:
            output_list.append(create_example(index, corrupted))
        # non-word, non-space characters replaced (not deleted) by spaces
        stripped = re.sub(r'[^\w\s]', ' ', parent_desc)
        if stripped != parent_desc:
            output_list.append(create_example(index, stripped))
        # short sequence amplifier: repeat the normalised mention itself, not
        # whichever augmented variant happened to be built last
        if len(parent_desc.split()) <= 2:
            output_list.extend(
                create_example(index, parent_desc) for _ in range(AMPLIFY_FACTOR)
            )
    return pd.DataFrame(output_list, columns=['training_id', 'mention'])
###############################################################
# regeneration code
# %%
# we want to sample n samples from each class
# sample_size refers to the number of samples per class
def sample_from_df(df, sample_size_per_class=5):
    """Draw up to ``sample_size_per_class`` random rows from each class.

    Rows are grouped by ``training_id``. Classes with fewer rows than
    requested contribute all of their rows. Sampling is without replacement
    and uses pandas' default random state. Only the ``training_id`` and
    ``mention`` columns are kept, and the result is given a fresh 0..n-1
    index.
    """
    def _draw(group):
        return group.sample(n=min(sample_size_per_class, len(group)))

    per_class = df.groupby("training_id")[['training_id', 'mention']]
    return per_class.apply(_draw).reset_index(drop=True)
# %%
class DynamicDataset(Dataset):
    """Map-style dataset whose contents are re-drawn on each ``regenerate_data`` call.

    Each regeneration does three things:

    * samples up to ``sample_size_per_class`` rows per class from ``df``
      (``sample_from_df``);
    * augments them (``augment_data``);
    * tokenizes the resulting mentions.

    Items are dicts of tensors with keys ``input_ids``, ``attention_mask`` and
    ``labels``. Sequences are left unpadded, so a collate function must pad
    each batch.
    """

    def __init__(self, df, sample_size_per_class, tokenizer):
        """
        Args:
            df (pd.DataFrame): source rows; must contain the ``training_id``
                (class index) and ``mention`` (text) columns.
            sample_size_per_class (int): maximum number of rows drawn per
                class on each regeneration.
            tokenizer: Hugging Face-style tokenizer. It is called on a list
                of strings with ``truncation=True`` and is expected to return
                ``input_ids`` and ``attention_mask`` as lists of lists.
        """
        self.df = df
        self.sample_size_per_class = sample_size_per_class
        self.tokenizer = tokenizer
        self.current_data = None
        self.regenerate_data()  # build the initial dataset eagerly

    def regenerate_data(self):
        """Re-sample, re-augment and re-tokenize, replacing ``self.current_data``.

        Intended to be called at the start of every epoch, so each epoch sees
        a fresh random subset of rows and fresh random augmentations.
        """
        sampled_df = sample_from_df(self.df, self.sample_size_per_class)
        sampled_df = augment_data(sampled_df)
        # Tokenize all mentions in one call. No return_tensors="pt": sequences
        # have different lengths and are padded later by the collate function.
        tokenized_batch = self.tokenizer(
            sampled_df["mention"].to_list(),
            truncation=True,
        )
        # Read the label column once rather than materialising a row Series
        # per example with ``iloc`` inside the loop.
        labels = sampled_df["training_id"].to_list()
        self.current_data = [
            {
                "input_ids": torch.tensor(input_ids),
                "attention_mask": torch.tensor(attention_mask),
                "labels": torch.tensor(label),
            }
            for input_ids, attention_mask, label in zip(
                tokenized_batch["input_ids"],
                tokenized_batch["attention_mask"],
                labels,
            )
        ]

    def __len__(self):
        return len(self.current_data)

    def __getitem__(self, idx):
        return self.current_data[idx]
# %%
class RegenerateDatasetCallback(TrainerCallback):
    """Trainer callback that refreshes a dataset at the start of every epoch.

    NOTE(review): the dataset is mutated in place in the trainer's process.
    Two things to confirm:

    * the change still reaches the DataLoader if worker processes are used;
    * a dataset length that changes between epochs does not skew the
      trainer's precomputed step counts.
    """

    def __init__(self, dataset):
        self.dataset = dataset  # object exposing regenerate_data()

    def on_epoch_begin(self, args, state, control, **kwargs):
        """Print an epoch label, computed as ceil(state.epoch + 1), then regenerate the data."""
        upcoming_epoch = int(math.ceil(state.epoch + 1))
        print(f"Epoch {upcoming_epoch}: Regenerating dataset")
        self.dataset.regenerate_data()
# %%
def custom_collate_fn(batch, pad_token_id=0):
    """Collate dataset items into a batch padded to its longest sequence.

    Args:
        batch (list[dict]): items holding 1-D ``input_ids`` and
            ``attention_mask`` tensors and a scalar ``labels`` tensor.
        pad_token_id (int): value used to right-pad ``input_ids``. Defaults
            to 0. Pass ``tokenizer.pad_token_id`` (e.g. via
            ``functools.partial``) for tokenizers whose pad id is not 0.

    Returns:
        dict: ``input_ids`` and ``attention_mask`` of shape
        (batch, max_len), and ``labels`` of shape (batch,).
    """
    input_ids = [item["input_ids"] for item in batch]
    attention_masks = [item["attention_mask"] for item in batch]
    labels = torch.stack([item["labels"] for item in batch])
    input_ids = torch.nn.utils.rnn.pad_sequence(
        input_ids, batch_first=True, padding_value=pad_token_id
    )
    # Padded positions must be masked out, so the mask is always padded with 0
    # regardless of the pad token id.
    attention_masks = torch.nn.utils.rnn.pad_sequence(
        attention_masks, batch_first=True, padding_value=0
    )
    return {
        "input_ids": input_ids,
        "attention_mask": attention_masks,
        "labels": labels,
    }
##########################################################################
# training code
# %%
def train():
    """Fine-tune a sequence classifier on the dynamically re-sampled dataset.

    Reads the module-level ``df``, ``target_id_list``, ``id2label``,
    ``label2id`` and ``SAMPLES``. Checkpoints go to ``checkpoint`` and
    logs to ``tensorboard-log``.
    """
    output_dir = 'checkpoint'

    # Tokenizer. Alternatives tried: 'google-bert/bert-base-cased', 'prajjwal1/bert-small'.
    model_checkpoint = "distilbert/distilbert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, clean_up_tokenization_spaces=True)

    # Dataset (``entity_name`` is not needed for training) and the callback
    # that re-samples / re-augments it at the start of every epoch.
    dataset = DynamicDataset(
        df=df.drop(columns=['entity_name']),
        sample_size_per_class=SAMPLES,
        tokenizer=tokenizer,
    )
    regenerate_cb = RegenerateDatasetCallback(dataset)

    # Accuracy metric. NOTE: eval_strategy is "no" below, so this is
    # currently not exercised during training.
    accuracy = evaluate.load("accuracy")

    def compute_metrics(eval_preds):
        logits, references = eval_preds
        return accuracy.compute(predictions=np.argmax(logits, axis=1), references=references)

    model = AutoModelForSequenceClassification.from_pretrained(
        model_checkpoint,
        num_labels=len(target_id_list),
        id2label=id2label,
        label2id=label2id,
    )
    model.resize_token_embeddings(len(tokenizer))
    # model = torch.compile(model, backend="inductor", dynamic=True)

    training_args = TrainingArguments(
        # output / logging / checkpointing
        output_dir=output_dir,
        logging_dir="tensorboard-log",
        logging_strategy="epoch",
        save_strategy="steps",
        save_steps=500,
        save_total_limit=1,
        push_to_hub=False,
        # evaluation is disabled
        eval_strategy="no",
        load_best_model_at_end=False,
        # optimisation
        learning_rate=5e-5,
        weight_decay=0.01,
        warmup_steps=400,
        num_train_epochs=120,
        per_device_train_batch_size=64,
        auto_find_batch_size=False,
        bf16=True,
        # distributed / data handling
        ddp_find_unused_parameters=False,
        remove_unused_columns=False,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        data_collator=custom_collate_fn,
        compute_metrics=compute_metrics,
        callbacks=[regenerate_cb],
    )
    # To resume instead:
    # trainer.train(resume_from_checkpoint='default_40_1/checkpoint-5600')
    trainer.train()
# Execute training only when this file is run directly (as a script or as an
# interactive cell), not when it is imported by another module.
if __name__ == "__main__":
    train()
# %%