# this code performs dataloading with text augmentation

# %%
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import torch
from transformers import (
    AutoTokenizer,
)
from functools import partial
import re
import random

# %%
# PARAMETERS
SAMPLES = 5

# %%
###################################################
# import code
# import training file
data_path = '../esAppMod_data_import/train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)
# rather than use a pattern, build the label mapping from the actual entity_id values
entity_ids = df['entity_id'].to_list()
target_id_list = sorted(list(set(entity_ids)))

id2label = {}
label2id = {}
for idx, val in enumerate(target_id_list):
    id2label[idx] = val
    label2id[val] = idx

df["training_id"] = df["entity_id"].map(label2id)

# %%
##############################################################
# augmentation code

# basic preprocessing
def preprocess_text(text):
    # 1. Make all lowercase
    text = text.lower()

    # 2. Standardize spacing
    text = re.sub(r'\s+', ' ', text).strip()

    return text


def shuffle_text(text, prob=0.2):
    if random.random() < prob:
        words = text.split()  # Split the input into words
        shuffled = words[:]  # Copy the word list to avoid in-place modification
        random.shuffle(shuffled)  # Randomly shuffle the words
        shuffled_text = " ".join(shuffled)  # Join the words back into a string
    else:
        shuffled_text = text

    return shuffled_text


def corrupt_word(word):
    """Corrupt a single word using random corruption techniques."""
    if len(word) <= 1:  # Skip corruption for single-character words
        return word

    corruption_type = random.choice(["delete", "swap"])

    if corruption_type == "delete":
        # Randomly delete a character
        idx = random.randint(0, len(word) - 1)
        word = word[:idx] + word[idx + 1:]

    elif corruption_type == "swap":
        # Swap two adjacent characters
        if len(word) > 1:
            idx = random.randint(0, len(word) - 2)
            word = word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]

    return word


def corrupt_text(sentence, prob=0.05):
    """Corrupt each word in the string with a given probability."""
    words = sentence.split()
    corrupted_words = [
        corrupt_word(word) if random.random() < prob else word
        for word in words
    ]
    return " ".join(corrupted_words)


def strip_nonalphanumerics(desc, prob=0.5):
    # apply the stripping only with probability `prob` so that some samples keep their punctuation
    if random.random() < prob:
        desc = re.sub(r'[^\w\s]', ' ', desc)  # Retain only alphanumerics and spaces
    return desc


# %%
def augment(row):
    """
    Augment the "mention" string of a row.
    Returns the string with slight random variation.
    """
    desc = row['mention']
    # we always apply preprocessing
    desc = preprocess_text(desc)

    desc = shuffle_text(desc, prob=1.0)
    desc = corrupt_text(desc, prob=1.0)
    desc = strip_nonalphanumerics(desc, prob=0.5)

    return desc


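# %%
# Quick sanity check of the augmentation pipeline (illustrative only).
# The row below is a made-up example, not taken from train.csv; the output
# differs between calls because the augmentations are random.
_example_row = {'mention': 'Microsoft SQL-Server 2016'}
print(augment(_example_row))
print(augment(_example_row))
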
# %%
# custom dataset
# we want to sample n samples from each class
# sample_size_per_class refers to the number of samples per class
def sample_from_df(df, sample_size_per_class=5):
    sampled_df = (
        df.groupby("training_id")[['training_id', 'mention']]  # explicitly select the columns to keep
        .apply(lambda x: x.sample(n=min(sample_size_per_class, len(x))))
        .reset_index(drop=True)
    )

    return sampled_df


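# %%
# Sanity check (illustrative): after sampling, no class should hold more than
# SAMPLES rows; classes with fewer rows than SAMPLES keep everything they have.
_sampled = sample_from_df(df, sample_size_per_class=SAMPLES)
print(_sampled['training_id'].value_counts().max() <= SAMPLES)
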
# %%
class DynamicDataset(Dataset):
    def __init__(self, df, sample_size_per_class):
        """
        Args:
            df (pd.DataFrame): Original DataFrame with class (id) and data columns.
            sample_size_per_class (int): Number of samples to draw per class for each epoch.
        """
        self.df = df
        self.sample_size_per_class = sample_size_per_class
        self.current_data = None
        self.regenerate_data()  # Generate the initial dataset

    def regenerate_data(self):
        """
        Generate a new sampled dataset for the current epoch.

        Each call replaces self.current_data, which lets us:

        - re-sample the dataframe for a new set of n samples per class
        - combined with the random augmentation in __getitem__, this effectively
          produces fresh training examples

        This allows us to re-sample and re-augment at the start of each epoch.
        """
        # Sample `sample_size_per_class` rows per class
        sampled_df = sample_from_df(self.df, self.sample_size_per_class)

        # Store the sampled rows; tokenization happens later in the collate function
        self.current_data = sampled_df

    def __len__(self):
        return len(self.current_data)

    def __getitem__(self, idx):
        # do the transform here
        row = self.current_data.iloc[idx].to_dict()

        # perform text augmentation here
        # two independent augment() calls yield two different random views of the same mention
        mention_0 = augment(row)
        mention_1 = augment(row)
        return {
            'training_id': row['training_id'],
            'mention_0': mention_0,
            'mention_1': mention_1,
        }


# %%
dataset = DynamicDataset(df, sample_size_per_class=SAMPLES)
dataset[0]


# %%
def custom_collate_fn(batch, tokenizer):
    # batch is just a list of dictionaries
    label_list = [item['training_id'] for item in batch]
    mention_0_list = [item['mention_0'] for item in batch]
    mention_1_list = [item['mention_1'] for item in batch]

    # we can do the tokenization here
    tokenized_batch_0 = tokenizer(
        mention_0_list,
        truncation=True,
        padding=True,
        return_tensors='pt'
    )

    tokenized_batch_1 = tokenizer(
        mention_1_list,
        truncation=True,
        padding=True,
        return_tensors='pt'
    )

    label_list = torch.tensor(label_list)

    return {
        'input_ids_0': tokenized_batch_0['input_ids'],
        'attention_mask_0': tokenized_batch_0['attention_mask'],
        'input_ids_1': tokenized_batch_1['input_ids'],
        'attention_mask_1': tokenized_batch_1['attention_mask'],
        'labels': label_list,
    }


tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", clean_up_tokenization_spaces=False)
custom_collate_fn_with_tokenizer = partial(custom_collate_fn, tokenizer=tokenizer)
dataloader = DataLoader(
    dataset,
    batch_size=8,
    collate_fn=custom_collate_fn_with_tokenizer
)


# %%
next(iter(dataloader))

# %%
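# Illustrative usage sketch, not part of the original pipeline: the epoch count
# is arbitrary and the training step is a placeholder. The point is that
# regenerate_data() is called at the start of each epoch so every epoch sees a
# freshly sampled (and, via __getitem__, freshly augmented) set of mentions.
for epoch in range(2):
    dataset.regenerate_data()  # new sample of SAMPLES rows per class
    for batch in dataloader:
        # batch['input_ids_0'] / batch['input_ids_1'] hold two augmented views
        # of the same mentions; batch['labels'] holds the training ids
        pass  # forward/backward pass would go here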