# %%
import pandas as pd
import numpy as np
from typing import List
from tqdm import tqdm
from utils import Retriever, cosine_similarity_chunked
import glob
import os

# import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import math

# %%
class Embedder():
    input_df: pd.DataFrame
    fold: int

    def __init__(self, input_df):
        self.input_df = input_df

    def make_embedding(self, checkpoint_path):

        def generate_input_list(df):
            input_list = []
            for _, row in df.iterrows():
                # name = f"<NAME>{row['tag_name']}<NAME>"
                desc = f"<DESC>{row['tag_description']}<DESC>"
                # element = f"{name}{desc}"
                element = f"{desc}"
                input_list.append(element)
            return input_list

        # prepare reference embeddings from the tag descriptions
        train_data = list(generate_input_list(self.input_df))
        retriever_train = Retriever(train_data, checkpoint_path)
        retriever_train.make_mean_embedding(batch_size=64)
        return retriever_train.embeddings.to('cpu')

# %%
# input data
fold = 1
data_path = f"../../data_preprocess/exports/dataset/group_{fold}/train.csv"
train_df = pd.read_csv(data_path, skipinitialspace=True)

# %%
checkpoint_directory = "../../train/baseline"
directory = os.path.join(checkpoint_directory, f'checkpoint_fold_{fold}')
# Use glob to find the matching checkpoint path
# the path is usually checkpoint_fold_1/checkpoint-<step number>
# we are guaranteed to save only 1 checkpoint from training
pattern = 'checkpoint-*'
checkpoint_path = glob.glob(os.path.join(directory, pattern))[0]

train_embedder = Embedder(input_df=train_df)
train_embeds = train_embedder.make_embedding(checkpoint_path)

# %%
train_embeds.shape

# %%
# now we need to generate the class labels
data_path = '../../data_import/exports/data_mapping_mdm.csv'
full_df = pd.read_csv(data_path, skipinitialspace=True)
mdm_list = sorted(list(set(full_df['pattern'])))

# %%
# based on mdm_list, assign each row an integer class label
# (-1 if the row's pattern is not found in the list)
def generate_labels(df, mdm_list):
    label_list = []
    for _, row in df.iterrows():
        pattern = row['pattern']
        try:
            index = mdm_list.index(pattern)
            label_list.append(index)
        except ValueError:
            label_list.append(-1)

    return label_list

# %%
label_list = generate_labels(train_df, mdm_list)

# # %%
# from collections import Counter
#
# frequency = Counter(label_list)
# frequency

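# %%
# optional sanity check (my addition, not part of the original pipeline):
# count how many training rows fall outside mdm_list and received the -1 label
ood_count = sum(1 for label in label_list if label == -1)
print(f"rows not matched to mdm_list: {ood_count} / {len(label_list)}")
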
####################################################
# %%
# we can start contrastive learning on a re-projection layer for the embedding
#################################################
# MARK: start contrastive learning

# we need to create a batch where half are positive examples and the other
# half are negative examples

# we first need to work out how to gather the embeddings belonging to each class

# %%
label_tensor = torch.asarray(label_list)

def create_pairs(all_embeddings, labels, batch_size):
    positive_pairs = []
    negative_pairs = []

    # find the unique class labels
    unique_labels = torch.unique(labels)

    embeddings_by_label = {}
    for label in unique_labels:
        embeddings_by_label[label.item()] = all_embeddings[labels == label]

    # create positive pairs (from the same class)
    for _ in range(batch_size // 2):
        label = random.choice(unique_labels)
        label_embeddings = embeddings_by_label[label.item()]

        # randomly select 2 embeddings from the same class
        if len(label_embeddings) >= 2:  # ensure that we can choose
            emb1, emb2 = random.sample(list(label_embeddings), 2)
            positive_pairs.append((emb1, emb2, torch.tensor(1.0)))

    # create negative pairs (from different classes)
    for _ in range(batch_size // 2):
        label1, label2 = random.sample(list(unique_labels), 2)

        # select one embedding from each class
        emb1 = random.choice(embeddings_by_label[label1.item()])
        emb2 = random.choice(embeddings_by_label[label2.item()])

        negative_pairs.append((emb1, emb2, torch.tensor(0.0)))

    pairs = positive_pairs + negative_pairs

    # separate embeddings and labels for the batch
    emb1_batch = torch.stack([pair[0] for pair in pairs])
    emb2_batch = torch.stack([pair[1] for pair in pairs])
    labels_batch = torch.stack([pair[2] for pair in pairs])

    return emb1_batch, emb2_batch, labels_batch

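# %%
# quick shape check (my addition): a batch of 256 should give roughly half
# positive and half negative pairs; positives can be fewer when a sampled
# class has only one member
emb1_batch, emb2_batch, labels_batch = create_pairs(train_embeds, label_tensor, 256)
print(emb1_batch.shape, emb2_batch.shape, labels_batch.shape)
print('positive pairs:', int(labels_batch.sum().item()))
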
# %%
# create model

class linear_map(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(linear_map, self).__init__()
        self.linear_1 = nn.Linear(input_dim, output_dim)
        # self.linear_2 = nn.Linear(512, output_dim)
        # self.relu = nn.ReLU()  # Non-linearity

    def forward(self, x):
        x = self.linear_1(x)
        # x = self.relu(x)
        # x = self.linear_2(x)
        return x

# %%
# the contrastive loss
# def contrastive_loss(embedding1, embedding2, label, margin=1.0):
#     # calculate euclidean distance
#     distance = F.pairwise_distance(embedding1, embedding2)
#
#     # loss for positive pairs
#     # the label selects only the positive examples
#     positive_loss = label * torch.pow(distance, 2)
#
#     # loss for negative pairs
#     negative_loss = (1 - label) * torch.pow(torch.clamp(margin - distance, min=0), 2)
#
#     loss = torch.mean(positive_loss + negative_loss)
#     return loss


def contrastive_loss_cosine(embeddings1, embeddings2, label, margin=0.5):
    """
    Compute the contrastive loss using cosine similarity.

    Args:
    - embeddings1: Tensor of embeddings for one set of pairs, shape (batch_size, embedding_size)
    - embeddings2: Tensor of embeddings for the other set of pairs, shape (batch_size, embedding_size)
    - label: Tensor of labels, 1 for positive pairs (same class), 0 for negative pairs (different class)
    - margin: Margin for negative pairs (default 0.5)

    Returns:
    - loss: Contrastive loss based on cosine similarity
    """
    # Cosine similarity between the two sets of embeddings
    cosine_sim = F.cosine_similarity(embeddings1, embeddings2)

    # For positive pairs, we want the cosine similarity to be close to 1
    positive_loss = label * (1 - cosine_sim)

    # For negative pairs, we only penalise similarity above the margin
    negative_loss = (1 - label) * F.relu(cosine_sim - margin)

    # Combine the two losses
    loss = positive_loss + negative_loss

    # Return the average loss across the batch
    return loss.mean()

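# %%
# worked check (my addition): with identical vectors, a positive pair should
# give a loss of 0 and a negative pair a loss of 1 - margin (0.3 for margin=0.7)
_v = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
print(contrastive_loss_cosine(_v, _v, torch.tensor([1.0, 1.0]), margin=0.7))  # ~0.0
print(contrastive_loss_cosine(_v, _v, torch.tensor([0.0, 0.0]), margin=0.7))  # ~0.3
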
# %%
# training setup for the contrastive re-projection
num_epochs = 50
batch_size = 256
train_data_size = train_embeds.shape[0]
output_dim = 512
learning_rate = 2e-6
steps_per_epoch = math.ceil(train_data_size / batch_size)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.set_float32_matmul_precision('high')

linear_model = linear_map(
    input_dim=train_embeds.shape[-1],
    output_dim=output_dim)

linear_model = torch.compile(linear_model)
linear_model.to(device)

optimizer = torch.optim.Adam(linear_model.parameters(), lr=learning_rate)

# Define the lambda function for linear decay
# It should return the multiplier for the learning rate (starts at 1.0 and goes to 0)
def linear_decay(epoch):
    return 1 - epoch / num_epochs

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=linear_decay)

# %%

for epoch in tqdm(range(num_epochs)):
    with tqdm(total=steps_per_epoch, desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
        for _ in range(steps_per_epoch):
            emb1_batch, emb2_batch, labels_batch = create_pairs(
                train_embeds,
                label_tensor,
                batch_size
            )
            output1 = linear_model(emb1_batch.to(device))
            output2 = linear_model(emb2_batch.to(device))

            loss = contrastive_loss_cosine(output1, output2, labels_batch.to(device), margin=0.7)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # if epoch % 10 == 0:
            #     print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}")
            pbar.set_postfix({'loss': loss.item()})
            pbar.update(1)

        # step the scheduler once per epoch so linear_decay(epoch) matches its
        # definition in terms of num_epochs
        scheduler.step()

# %%
# apply the re-projection layer to achieve better classification
# we have to transform the previous embeddings into re-mapped embeddings
def predict_batch(embeds, model, batch_size):
    output_list = []
    with torch.no_grad():
        for i in range(0, len(embeds), batch_size):
            batch_embed = embeds[i:i+batch_size]
            output = model(batch_embed.to(device))
            output_list.append(output)

    all_embeddings = torch.cat(output_list, dim=0)
    return all_embeddings

train_remap_embeds = predict_batch(train_embeds, linear_model, 32)

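# %%
# optional sanity check (my addition, not part of the original pipeline):
# compare the mean cosine similarity of positive vs negative pairs before and
# after the re-projection; the gap should widen if the contrastive training helped
def pair_similarity_gap(embeds, labels, n_pairs=512):
    emb1, emb2, pair_labels = create_pairs(embeds.to('cpu'), labels, n_pairs)
    sims = F.cosine_similarity(emb1, emb2)
    return sims[pair_labels == 1.0].mean().item(), sims[pair_labels == 0.0].mean().item()

print('original embeds (pos, neg):', pair_similarity_gap(train_embeds, label_tensor))
print('remapped embeds (pos, neg):', pair_similarity_gap(train_remap_embeds, label_tensor))
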
####################################################
# %%
# we can start classifying

# %%
import torch
import torch.nn as nn
import torch.optim as optim


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the neural network with non-linearity
class NeuralNet(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(input_dim, 512)  # First layer (input to hidden)
        self.relu = nn.ReLU()                 # Non-linearity
        self.fc2 = nn.Linear(512, 256)        # Hidden layer
        self.fc3 = nn.Linear(256, output_dim) # Output layer

    def forward(self, x):
        out = self.fc1(x)     # Input to hidden
        out = self.relu(out)  # Apply non-linearity
        out = self.fc2(out)   # Hidden to hidden
        out = self.relu(out)
        out = self.fc3(out)   # Hidden to output
        return out

# model dimensions
input_dim = 512   # must match the re-projected embedding size (output_dim above)
output_dim = 202  # number of MDM classes (should equal len(mdm_list))

model = NeuralNet(input_dim, output_dim)
model = torch.compile(model)
model = model.to(device)
torch.set_float32_matmul_precision('high')

# %%
from torch.utils.data import DataLoader, TensorDataset

# we use the re-projected embeds
mean_embeddings = train_remap_embeds
# mean_embeddings = train_embeds

train_labels = generate_labels(train_df, mdm_list)
labels = torch.tensor(train_labels)

# Create a dataset and DataLoader
dataset = TensorDataset(mean_embeddings, labels)
dataloader = DataLoader(dataset, batch_size=256, shuffle=True)

# %%
# Define loss function and optimizer
# criterion = nn.BCELoss() # Binary cross entropy loss
# criterion = nn.BCEWithLogitsLoss()
criterion = nn.CrossEntropyLoss()
learning_rate = 1e-3
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
num_epochs = 800  # Adjust as needed

# Define the scheduler with the lambda function for linear decay
# It should return the multiplier for the learning rate (starts at 1.0 and goes to 0)
def linear_decay(epoch):
    return 1 - epoch / num_epochs

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=linear_decay)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, targets in dataloader:
        # Forward pass
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = model(inputs)
        # loss = criterion(outputs.squeeze(), targets.float().squeeze())  # Ensure the target is float
        loss = criterion(outputs, targets)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    scheduler.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(dataloader)}")

# %%
# prepare the test set: keep only rows that belong to MDM
data_path = f"../../data_preprocess/exports/dataset/group_{fold}/test_all.csv"
test_df = pd.read_csv(data_path, skipinitialspace=True)
test_df = test_df[test_df['MDM']].reset_index(drop=True)

checkpoint_directory = "../../train/baseline"
directory = os.path.join(checkpoint_directory, f'checkpoint_fold_{fold}')
# Use glob to find the matching checkpoint path
# the path is usually checkpoint_fold_1/checkpoint-<step number>
# we are guaranteed to save only 1 checkpoint from training
pattern = 'checkpoint-*'
checkpoint_path = glob.glob(os.path.join(directory, pattern))[0]

test_embedder = Embedder(input_df=test_df)
test_embeds = test_embedder.make_embedding(checkpoint_path)

test_remap_embeds = predict_batch(test_embeds, linear_model, 32)

test_labels = generate_labels(test_df, mdm_list)

# %%
# mean_embeddings = test_embeds
mean_embeddings = test_remap_embeds

labels = torch.tensor(test_labels)
dataset = TensorDataset(mean_embeddings, labels)
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)

model.eval()
output_classes = []
output_probs = []
for inputs, _ in dataloader:
    with torch.no_grad():
        inputs = inputs.to(device)
        logits = model(inputs)
        probabilities = torch.softmax(logits, dim=1)
        # predicted_classes = torch.argmax(probabilities, dim=1)
        max_probabilities, predicted_classes = torch.max(probabilities, dim=1)
        output_classes.extend(predicted_classes.to('cpu').numpy())
        output_probs.extend(max_probabilities.to('cpu').numpy())

# %%
# evaluation
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix

y_true = test_labels
y_pred = output_classes

# Compute metrics
accuracy = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred, average='macro')
precision = precision_score(y_true, y_pred, average='macro')
recall = recall_score(y_true, y_pred, average='macro')

# Print the results
print(f'Accuracy: {accuracy:.2f}')
print(f'F1 Score: {f1:.2f}')
print(f'Precision: {precision:.2f}')
print(f'Recall: {recall:.2f}')

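# %%
# confusion_matrix is imported above but unused; as an optional check (my
# addition), sum its off-diagonal entries to count misclassifications
cm = confusion_matrix(y_true, y_pred)
off_diag = cm.copy()
np.fill_diagonal(off_diag, 0)
print('total misclassifications:', off_diag.sum(), 'out of', cm.sum())
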
# %%