import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
)
import torch.nn.functional as F


class Retriever:
    def __init__(self, input_texts, model_checkpoint):
        # we need to generate the embeddings from a list of input strings
        self.embeddings = []
        self.inputs = input_texts
        self.model_checkpoint = model_checkpoint
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_checkpoint, clean_up_tokenization_spaces=True
        )

        model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # device = "cpu"
        model.to(self.device)
        self.model = model.eval()

    def make_embedding(self, batch_size=64):
        all_embeddings = self.embeddings
        input_texts = self.inputs

        for i in range(0, len(input_texts), batch_size):
            batch_texts = input_texts[i:i + batch_size]

            # Tokenize the batch of input texts
            inputs = self.tokenizer(
                batch_texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=64,
            )
            input_ids = inputs.input_ids.to(self.device)
            attention_mask = inputs.attention_mask.to(self.device)

            # Pass the input through the encoder and retrieve the embeddings
            with torch.no_grad():
                encoder_outputs = self.model(
                    input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                )
            # get the last hidden layer
            embeddings = encoder_outputs.hidden_states[-1]
            # get the [CLS] token embedding
            cls_embeddings = embeddings[:, 0, :]  # Shape: (batch_size, hidden_size)
            all_embeddings.append(cls_embeddings)

        # collapse the list of per-batch tensors into a single tensor; dim=0 concatenates row-wise
        all_embeddings = torch.cat(all_embeddings, dim=0)

        self.embeddings = all_embeddings


def cosine_similarity_chunked(batch1, batch2, chunk_size=1024):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch1_size = batch1.size(0)
    batch2_size = batch2.size(0)
    batch2 = batch2.to(device)

    # Prepare an empty tensor to store results
    cos_sim = torch.empty(batch1_size, batch2_size, device=device)

    # batch2 norms are the same for every chunk, so compute them once
    batch2_norms = batch2.norm(dim=1, keepdim=True)

    # Process batch1 in chunks
    for i in range(0, batch1_size, chunk_size):
        batch1_chunk = batch1[i:i + chunk_size].to(device)  # Get chunk of batch1

        # Alternative kept for reference: expand both batches and use F.cosine_similarity
        # (allocates a large intermediate tensor for big batches)
        # batch1_chunk_exp = batch1_chunk.unsqueeze(1)  # Shape: (chunk_size, 1, hidden_size)
        # batch2_exp = batch2.unsqueeze(0)              # Shape: (1, batch2_size, hidden_size)
        # cos_sim[i:i + chunk_size] = F.cosine_similarity(batch1_chunk_exp, batch2_exp, dim=-1)

        # Compute cosine similarity by matrix multiplication and normalization
        sim_chunk = torch.mm(batch1_chunk, batch2.T) / (
            batch1_chunk.norm(dim=1, keepdim=True) * batch2_norms.T + 1e-8
        )

        # Store the results in the appropriate part of the final tensor
        cos_sim[i:i + chunk_size] = sim_chunk

    return cos_sim
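

# A minimal usage sketch, not part of the original module: the checkpoint name
# "bert-base-uncased" and the example sentences below are placeholder assumptions
# chosen only to show how Retriever and cosine_similarity_chunked fit together.
if __name__ == "__main__":
    corpus = [
        "The quick brown fox jumps over the lazy dog.",
        "A fast auburn fox leaps over a sleepy hound.",
        "The stock market closed higher today.",
    ]
    query = ["A quick fox jumping over a dog."]

    # Embed the corpus and the query with the same encoder
    corpus_retriever = Retriever(corpus, "bert-base-uncased")
    corpus_retriever.make_embedding(batch_size=64)

    query_retriever = Retriever(query, "bert-base-uncased")
    query_retriever.make_embedding(batch_size=64)

    # Score every query embedding against every corpus embedding
    scores = cosine_similarity_chunked(
        query_retriever.embeddings, corpus_retriever.embeddings, chunk_size=1024
    )
    best_match = scores.argmax(dim=1)  # index of the closest corpus entry per query
    print(best_match)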