hipom_data_mapping/train/classification_all_with_con.../utils.py

76 lines
3.2 KiB
Python
Raw Normal View History

import torch
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
import torch.nn.functional as F
class Retriever:
    """Embed a fixed list of strings with the encoder of a seq2seq checkpoint.

    The tokenizer is always loaded from ``"t5-base"`` and extended with the
    project's structural special tokens; the model weights come from
    ``model_checkpoint``. Call ``make_mean_embedding`` to replace
    ``self.embeddings`` with one mean-pooled encoder vector per input string.
    """

    def __init__(self, input_texts, model_checkpoint, device=None):
        """Load tokenizer and model, and remember the texts to embed.

        Args:
            input_texts: Strings to embed. They are sliced in batches, so a
                list or tuple is expected.
            model_checkpoint: Name or path given to
                ``AutoModelForSeq2SeqLM.from_pretrained``.
            device: Optional device (string or ``torch.device``) for the model.
                When omitted: ``cuda:1`` if at least two GPUs are visible,
                ``cuda:0`` if only one is, otherwise CPU.
        """
        # Replaced by a single tensor once make_mean_embedding() has run.
        self.embeddings = []
        self.inputs = input_texts
        self.tokenizer = AutoTokenizer.from_pretrained(
            "t5-base", clean_up_tokenization_spaces=True
        )
        # Structural markers added on top of the t5-base vocabulary.
        # NOTE(review): presumably these must match the tokens used when the
        # checkpoint was fine-tuned — confirm against the training code.
        additional_special_tokens = [
            "<thing_start>", "<thing_end>", "<property_start>", "<property_end>",
            "<name>", "<desc>", "<sig>", "<unit>", "<data_type>",
        ]
        self.tokenizer.add_special_tokens(
            {"additional_special_tokens": additional_special_tokens}
        )
        self.device = self._resolve_device(device)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
        model.to(self.device)
        # eval() returns the module itself; inference-only usage below.
        self.model = model.eval()

    @staticmethod
    def _resolve_device(device=None):
        """Return the torch.device to run on (explicit choice or GPU/CPU fallback)."""
        if device is not None:
            return torch.device(device)
        if not torch.cuda.is_available():
            return torch.device("cpu")
        # Historically this always used the second GPU; fall back to the first
        # one on single-GPU machines instead of failing later in `.to()`.
        return torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")

    @staticmethod
    def _masked_mean(hidden_states, attention_mask):
        """Average ``hidden_states`` over positions where ``attention_mask`` is non-zero.

        ``hidden_states`` is (batch, seq_len, dim) and ``attention_mask`` is
        (batch, seq_len); the result is (batch, dim). A row whose mask is all
        zeros yields zeros rather than NaN.
        """
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        summed = (hidden_states * mask).sum(dim=1)
        # clamp avoids 0/0 for a fully padded row
        counts = mask.sum(dim=1).clamp(min=1)
        return summed / counts

    def make_mean_embedding(self, batch_size=32):
        """Encode ``self.inputs`` and store their mean-pooled embeddings.

        Texts are tokenized ``batch_size`` at a time (padded, truncated to
        128 tokens) and passed through ``self.model.encoder`` without
        gradients. The encoder's last hidden state is then averaged over
        non-padding tokens. The per-batch results are concatenated row-wise
        into ``self.embeddings`` (one row per input text, on ``self.device``).
        Calling this method again recomputes the embeddings from scratch.

        Args:
            batch_size: Number of texts per forward pass; must be >= 1.

        Raises:
            ValueError: if ``batch_size`` < 1.
            RuntimeError: from ``torch.cat`` if there are no input texts.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        # Local accumulator: appending to self.embeddings broke repeated calls
        # because it is a tensor after the first run.
        batch_embeddings = []
        for start in range(0, len(self.inputs), batch_size):
            batch_texts = self.inputs[start:start + batch_size]
            encoded = self.tokenizer(
                batch_texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=128,
            )
            input_ids = encoded.input_ids.to(self.device)
            attention_mask = encoded.attention_mask.to(self.device)
            with torch.no_grad():
                encoder_outputs = self.model.encoder(input_ids, attention_mask=attention_mask)
                batch_embeddings.append(
                    self._masked_mean(encoder_outputs.last_hidden_state, attention_mask)
                )
        # Join the per-batch tensors into one tensor, stacked along dim 0.
        self.embeddings = torch.cat(batch_embeddings, dim=0)
def cosine_similarity_chunked(batch1, batch2, chunk_size=16):
batch1_size = batch1.size(0)
batch2_size = batch2.size(0)
# Prepare an empty tensor to store results
cos_sim = torch.empty(batch1_size, batch2_size, device=batch1.device)
# Process batch1 in chunks
for i in range(0, batch1_size, chunk_size):
batch1_chunk = batch1[i:i + chunk_size] # Get chunk of batch1
# Expand batch1 chunk and entire batch2 for comparison
batch1_chunk_exp = batch1_chunk.unsqueeze(1) # Shape: (chunk_size, 1, seq_len)
batch2_exp = batch2.unsqueeze(0) # Shape: (1, batch2_size, seq_len)
# Compute cosine similarity for the chunk and store it in the final tensor
cos_sim[i:i + chunk_size] = F.cosine_similarity(batch1_chunk_exp, batch2_exp, dim=-1)
return cos_sim