import torch
from torch.utils.data import DataLoader
from transformers import (
    T5TokenizerFast,
    AutoModelForSeq2SeqLM,
)
import os
from tqdm import tqdm
from datasets import Dataset
import numpy as np

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

class Inference():
    tokenizer: T5TokenizerFast
    model: torch.nn.Module
    dataloader: DataLoader

    def __init__(self, checkpoint_path):
        self._create_tokenizer()
        self._load_model(checkpoint_path)

    def _create_tokenizer(self):
        # load tokenizer
        self.tokenizer = T5TokenizerFast.from_pretrained(
            "t5-small",
            return_tensors="pt",
            clean_up_tokenization_spaces=True,
        )
        # Define additional special tokens
        # additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
        # Add the additional special tokens to the tokenizer
        # self.tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})

    def _load_model(self, checkpoint_path: str):
        # load the fine-tuned seq2seq model from the checkpoint directory
        model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path)
        model = torch.compile(model)
        # set model to eval
        self.model = model.eval()

    def prepare_dataloader(self, input_df, batch_size, max_length):
        """
        *arguments*
        - input_df: input dataframe containing the fields 'mention' and 'entity_seq'
        - batch_size: the batch size of the dataloader output
        - max_length: maximum token length of the tokenizer output
        """
        print("preparing dataloader")

        # convert each dataframe row into a dictionary
        # outputs a list of dictionaries
        def _process_df(df):
            output_list = []
            for _, row in df.iterrows():
                desc = row['mention']
                label = row['entity_seq']
                element = {
                    'input': desc,
                    'output': f'{label}',
                }
                output_list.append(element)
            return output_list

        def _preprocess_function(example):
            input_text = example['input']
            target_text = example['output']
            # text_target sets the corresponding label for the inputs,
            # so there is no need to create a separate 'labels' field
            model_inputs = self.tokenizer(
                input_text,
                text_target=target_text,
                max_length=max_length,
                return_tensors="pt",
                padding='max_length',
                truncation=True,
            )
            return model_inputs

        test_dataset = Dataset.from_list(_process_df(input_df))

        # map applies the preprocessing function to each batch of rows in the dataset
        datasets = test_dataset.map(
            _preprocess_function,
            batched=True,
            num_proc=1,
            remove_columns=test_dataset.column_names,
        )
        # datasets = _preprocess_function(test_dataset)
        datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

        # create dataloader
        self.dataloader = DataLoader(datasets, batch_size=batch_size)

    def generate(self):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        MAX_GENERATE_LENGTH = 128

        # move the model to the device once, before iterating over batches
        self.model.to(device)

        pred_generations = []
        pred_labels = []

        print("start generation")
        for batch in tqdm(self.dataloader):
            # Inference in batches
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            # save labels too
            pred_labels.extend(batch['labels'])

            # Move inputs to GPU if available
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

            # Perform inference
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_length=MAX_GENERATE_LENGTH,
                )

            # Collect the generated token ids for decoding later
            pred_generations.extend(outputs.to("cpu"))

        # decode a tensor of generated token ids into a text sequence
        def process_tensor_output(tokens):
            predictions = self.tokenizer.decode(tokens, skip_special_tokens=True)
            return predictions

        # decode all generated predictions
        def decode_preds(tokens_list):
            prediction_list = []
            for tokens in tokens_list:
                predicted_seq = process_tensor_output(tokens)
                prediction_list.append(predicted_seq)
            return prediction_list

        prediction_list = decode_preds(pred_generations)
        return prediction_list
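

# --- Illustrative usage sketch ---
# Minimal example of how the Inference class might be driven end to end: it
# assumes a fine-tuned seq2seq checkpoint directory (the path below is
# hypothetical) and a pandas DataFrame with the 'mention' and 'entity_seq'
# columns expected by prepare_dataloader; the toy row is for illustration only.
if __name__ == "__main__":
    import pandas as pd

    checkpoint_path = "checkpoints/t5_entity_mapper"  # hypothetical checkpoint path
    df = pd.DataFrame({
        'mention': ["pump discharge pressure"],  # toy example data
        'entity_seq': ["<THING_START> pump <THING_END> <PROPERTY_START> pressure <PROPERTY_END>"],
    })

    inference = Inference(checkpoint_path)
    inference.prepare_dataloader(df, batch_size=1, max_length=128)
    predictions = inference.generate()
    print(predictions)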