281 lines
8.3 KiB
Python
281 lines
8.3 KiB
Python
# %%
|
|
from torch.utils.data import Dataset, DataLoader
|
|
|
|
# from datasets import load_from_disk
|
|
import os
|
|
|
|
# Process-wide NCCL / CUDA settings. They are written to the environment
# before `torch` is imported further down the file.
os.environ.update(
    {
        'NCCL_P2P_DISABLE': '1',
        'NCCL_IB_DISABLE': '1',
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "0,1,2,3",
    }
)
|
|
|
|
import re
|
|
import random
|
|
|
|
import torch
|
|
from transformers import (
|
|
AutoTokenizer,
|
|
AutoModelForSequenceClassification,
|
|
DataCollatorWithPadding,
|
|
Trainer,
|
|
EarlyStoppingCallback,
|
|
TrainingArguments,
|
|
TrainerCallback
|
|
)
|
|
import evaluate
|
|
import numpy as np
|
|
import pandas as pd
|
|
from functools import partial
|
|
import warnings
|
|
|
|
def install_warning_filters():
    """Silence the two known-noisy warnings emitted during training.

    `warnings.filterwarnings` treats `message` as a regular expression that
    must match the *start* of the warning text, so:

    * the pattern must not include the ``FutureWarning: `` prefix -- that is
      only how the category is rendered when the warning is printed, it is
      not part of the message itself;
    * literal punctuation such as ``(``, ``)`` and ``.`` has to be escaped,
      otherwise it is interpreted as regex syntax.
    """
    warnings.filterwarnings(
        "ignore",
        message='Was asked to gather along dimension 0',
    )
    warnings.filterwarnings(
        "ignore",
        message=re.escape('`torch.cuda.amp.autocast(args...)` is deprecated'),
        category=FutureWarning,
    )


install_warning_filters()

# import matplotlib.pyplot as plt

torch.set_float32_matmul_precision('high')
|
|
|
|
def set_seed(seed):
    """
    Seed every random number generator this script relies on.

    The same integer is handed to Python's `random`, NumPy, and PyTorch
    (CPU, current CUDA device, and all CUDA devices). cuDNN is switched to
    deterministic mode with benchmarking turned off.
    """
    # cuDNN switches: deterministic kernels on, benchmark autotuner off.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    seeders = (
        random.seed,                 # Python random module
        np.random.seed,              # NumPy global generator
        torch.manual_seed,           # PyTorch CPU
        torch.cuda.manual_seed,      # PyTorch current GPU
        torch.cuda.manual_seed_all,  # every visible GPU
    )
    for apply_seed in seeders:
        apply_seed(seed)


set_seed(42)
|
|
|
|
# %%
|
|
# PARAMETERS
# Rows drawn per class each time the training set is (re)generated; passed
# to DynamicDataset as `sample_size_per_class` inside train().
SAMPLES = 20
|
|
|
|
# %%
|
|
###################################################
|
|
# import code
|
|
# import training file
|
|
# Load the training CSV and derive the label vocabulary from it.
data_path = '../../../biomedical_data_import/bc2gm_train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)

# The label set is taken directly from the values in the `entity_id`
# column; sorting makes the id assignment independent of row order.
entity_ids = df['entity_id'].to_list()
target_id_list = sorted(set(entity_ids))

# Bidirectional lookup between contiguous training ids and entity ids.
id2label = dict(enumerate(target_id_list))
label2id = {label: index for index, label in id2label.items()}

# Integer class id for every row, used as the classification target.
df["training_id"] = df["entity_id"].map(label2id)
|
|
|
|
###############################################################
|
|
# regeneration code
|
|
# %%
|
|
# we want to sample n samples from each class
|
|
# sample_size refers to the number of samples per class
|
|
def sample_from_df(df, sample_size_per_class=5):
    """
    Draw up to `sample_size_per_class` random rows from every class.

    Args:
        df (pd.DataFrame): frame with at least the columns `training_id`
            (the class key) and `mention`.
        sample_size_per_class (int): upper bound on rows taken per class;
            classes with fewer rows contribute all of their rows.

    Returns:
        pd.DataFrame: the sampled rows, restricted to the columns
        `training_id` and `mention`, with a fresh 0..n-1 index.
    """

    def _draw(group):
        # Never request more rows than the class actually has.
        take = min(sample_size_per_class, len(group))
        return group.sample(n=take)

    # Columns are selected explicitly (grouping key included) before apply.
    per_class = df.groupby("training_id")[['training_id', 'mention']]
    drawn = per_class.apply(_draw)
    return drawn.reset_index(drop=True)
|
|
|
|
|
|
# %%
|
|
# augment whole dataset
|
|
# for now, we just return the same df
|
|
def augment_data(df):
    """
    Augmentation hook applied to a freshly sampled frame.

    Currently a placeholder: the input object is handed back unchanged.
    """
    augmented = df
    return augmented
|
|
|
|
# %%
|
|
class DynamicDataset(Dataset):
    """Map-style dataset whose items can be re-drawn from a DataFrame.

    The items served by `__getitem__` live in `self.current_data`, which is
    rebuilt from scratch every time `regenerate_data` is called.
    """

    def __init__(self, df, sample_size_per_class, tokenizer):
        """
        Args:
            df (pd.DataFrame): Original DataFrame with class (id) and data
                columns; it is passed to `sample_from_df`, and the sampled
                result is read through its `mention` and `training_id`
                columns.
            sample_size_per_class (int): Number of samples to draw per class
                each time the data is regenerated.
            tokenizer: Callable invoked as
                `tokenizer(list_of_str, truncation=True)`; its result is
                indexed with `"input_ids"` and `"attention_mask"`.
        """
        self.df = df
        self.sample_size_per_class = sample_size_per_class
        self.tokenizer = tokenizer
        self.current_data = None
        self.regenerate_data()  # Generate the initial dataset

    def regenerate_data(self):
        """
        Build a new sampled, tokenized dataset and store it in
        `self.current_data`.

        Each call:
        - re-samples the source frame (`sample_from_df`),
        - runs the augmentation hook (`augment_data`),
        - tokenizes all sampled mentions in one batch.

        Calling it again (e.g. from a trainer callback) replaces the served
        items with a fresh draw.
        """
        # Sample `sample_size_per_class` rows per class
        sampled_df = sample_from_df(self.df, self.sample_size_per_class)

        # perform future edits here
        sampled_df = augment_data(sampled_df)

        # Batch tokenize the entire column at once. No `return_tensors`:
        # sequences have different lengths and are padded later, per batch.
        tokenized_batch = self.tokenizer(
            sampled_df["mention"].to_list(),
            truncation=True,
        )

        # Pull the label column out once instead of building a row Series
        # via `.iloc[i]` for every sample.
        labels = sampled_df["training_id"].to_list()

        # Store the tokenized data with labels
        self.current_data = [
            {
                "input_ids": torch.tensor(ids),
                "attention_mask": torch.tensor(mask),
                "labels": torch.tensor(label),
            }
            for ids, mask, label in zip(
                tokenized_batch["input_ids"],
                tokenized_batch["attention_mask"],
                labels,
            )
        ]

    def __len__(self):
        """Return the number of items in the current sample."""
        return len(self.current_data)

    def __getitem__(self, idx):
        """Return the dict of tensors stored at position `idx`."""
        return self.current_data[idx]
|
|
|
|
# %%
|
|
class RegenerateDatasetCallback(TrainerCallback):
    """Trainer callback that re-draws the training dataset periodically."""

    def __init__(self, dataset, every_n_epochs=2):
        """
        Args:
            dataset: The dataset instance that supports regeneration, i.e.
                exposes a `regenerate_data()` method.
            every_n_epochs (int): Regenerate at the start of every epoch
                whose 1-based number is a multiple of this value.
        """
        self.dataset = dataset
        self.every_n_epochs = every_n_epochs

    def on_epoch_begin(self, args, state, control, **kwargs):
        """Regenerate the dataset when the starting epoch number is due."""
        # `state.epoch` is a float whose decimal part is the progress inside
        # the current epoch, and it may be None before training starts.
        # Floor it to get the index of the epoch that is about to run, so a
        # fractional value (e.g. after resuming from a mid-epoch checkpoint)
        # does not silently disable regeneration.
        epoch_number = int(state.epoch or 0) + 1

        if epoch_number % self.every_n_epochs == 0:
            print(f"Epoch {epoch_number}: Regenerating dataset...")
            self.dataset.regenerate_data()
|
|
|
|
|
|
# %%
|
|
def custom_collate_fn(batch, pad_token_id=0):
    """
    Collate variable-length samples into one right-padded batch.

    Args:
        batch: sequence of dicts, each holding 1-D tensors under
            `"input_ids"` and `"attention_mask"` and a tensor under
            `"labels"`.
        pad_token_id (int): value used to pad `input_ids`. Defaults to 0,
            which is what `pad_sequence` uses when no padding value is
            given. It should equal the tokenizer's pad token id; bind
            another value with `functools.partial` when the tokenizer
            differs.

    Returns:
        dict with `"input_ids"` and `"attention_mask"` padded to the longest
        sequence in the batch (batch dimension first) and `"labels"` stacked
        into a single tensor.
    """
    pad = torch.nn.utils.rnn.pad_sequence

    labels = torch.stack([item["labels"] for item in batch])

    # Pad token ids with the tokenizer's pad id ...
    input_ids = pad(
        [item["input_ids"] for item in batch],
        batch_first=True,
        padding_value=pad_token_id,
    )
    # ... but always pad the mask with 0 so padded positions are ignored.
    attention_masks = pad(
        [item["attention_mask"] for item in batch],
        batch_first=True,
        padding_value=0,
    )

    return {
        "input_ids": input_ids,
        "attention_mask": attention_masks,
        "labels": labels,
    }
|
|
|
|
|
|
##########################################################################
|
|
# training code
|
|
# %%
|
|
def train():
    """
    Fine-tune a sequence classifier on the dynamically sampled dataset.

    Steps, in order: load the tokenizer, build the `DynamicDataset` from the
    module-level `df` with `SAMPLES` rows per class, create the regeneration
    callback, load the accuracy metric, load the classification model sized
    to `target_id_list`, then run `Trainer.train()`.

    Evaluation is disabled (`eval_strategy="no"`); `compute_metrics` is
    still handed to the Trainer.
    """
    output_dir = 'checkpoint'

    # Alternative checkpoints that were tried:
    #   'google-bert/bert-base-cased'
    #   'prajjwal1/bert-small'
    model_checkpoint = "distilbert/distilbert-base-uncased"

    # prepare tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_checkpoint,
        clean_up_tokenization_spaces=True,
    )

    # make the dataset and the callback that re-samples it every 2 epochs
    dynamic_dataset = DynamicDataset(
        df=df,
        sample_size_per_class=SAMPLES,
        tokenizer=tokenizer,
    )
    regeneration_callback = RegenerateDatasetCallback(
        dynamic_dataset,
        every_n_epochs=2,
    )

    # compute metrics
    metric = evaluate.load("accuracy")

    def compute_metrics(eval_preds):
        # Reduce per-class scores to the arg-max class before scoring.
        scores, references = eval_preds
        predicted = np.argmax(scores, axis=1)
        return metric.compute(predictions=predicted, references=references)

    model = AutoModelForSequenceClassification.from_pretrained(
        model_checkpoint,
        num_labels=len(target_id_list),
        id2label=id2label,
        label2id=label2id,
    )
    model.resize_token_embeddings(len(tokenizer))

    # model = torch.compile(model, backend="inductor", dynamic=True)

    # Trainer configuration
    trainer_settings = {
        "output_dir": output_dir,
        "eval_strategy": "no",
        "logging_dir": "tensorboard-log",
        "logging_strategy": "epoch",
        "load_best_model_at_end": False,
        "learning_rate": 1e-4,
        "per_device_train_batch_size": 256,
        "auto_find_batch_size": False,
        "ddp_find_unused_parameters": False,
        "weight_decay": 0.01,
        "save_total_limit": 1,
        "num_train_epochs": 40,
        "warmup_steps": 200,
        "bf16": True,
        "push_to_hub": False,
        # The dataset already yields exactly the keys the collator expects.
        "remove_unused_columns": False,
    }
    training_args = TrainingArguments(**trainer_settings)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dynamic_dataset,
        tokenizer=tokenizer,
        data_collator=custom_collate_fn,
        compute_metrics=compute_metrics,
        callbacks=[regeneration_callback],
    )

    # To resume from a checkpoint instead, call e.g.:
    #   trainer.train(resume_from_checkpoint='default_40_1/checkpoint-5600')
    trainer.train()
|
|
|
|
# execute training
# Guarded so that the training run starts only when this file is executed
# directly, not when it is imported by another module.
if __name__ == "__main__":
    train()
|
|
|
|
|
|
# %%
|