2024-11-28 11:02:22 +09:00
|
|
|
# %%
|
2024-10-31 15:58:20 +09:00
|
|
|
import pandas as pd
|
|
|
|
import os
|
|
|
|
import glob
|
|
|
|
|
|
|
|
# directory for checkpoints
|
2024-11-28 11:02:22 +09:00
|
|
|
# Root directory holding one `checkpoint_fold_<n>` sub-directory per fold.
checkpoint_directory = "../../train/mapping_t5_complete_desc_unit_name"
|
2024-10-31 15:58:20 +09:00
|
|
|
|
2024-10-31 16:35:28 +09:00
|
|
|
def _digit_pattern(series):
    """Return *series* with every digit replaced by ``#``.

    This makes identifiers that differ only in their numbering compare equal
    when checked against the master table.
    """
    return series.str.replace(r'\d', '#', regex=True)


def _add_mdm_columns(df, prefix, master_patterns):
    """Add digit-masked pattern columns and an MDM membership flag to *df*.

    Reads ``{prefix}thing`` and ``{prefix}property``. Adds these columns,
    in this order:

    - ``{prefix}thing_pattern``
    - ``{prefix}property_pattern``
    - ``{prefix}pattern`` (the previous two joined by a space)
    - ``{prefix}MDM`` (True when ``{prefix}pattern`` is in *master_patterns*)

    *df* is mutated in place.
    """
    df[f'{prefix}thing_pattern'] = _digit_pattern(df[f'{prefix}thing'])
    df[f'{prefix}property_pattern'] = _digit_pattern(df[f'{prefix}property'])
    df[f'{prefix}pattern'] = (
        df[f'{prefix}thing_pattern'] + " " + df[f'{prefix}property_pattern']
    )
    df[f'{prefix}MDM'] = df[f'{prefix}pattern'].apply(lambda x: x in master_patterns)


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero.

    This keeps a degenerate fold (e.g. no positive predictions) from aborting
    the run with ZeroDivisionError halfway through writing the report.
    """
    return numerator / denominator if denominator else float('nan')


def select(fold):
    """Cull and select mapping predictions for one fold, then log metrics.

    Steps:

    1. Load the fold's prediction CSV, its training CSV and the data-model
       master export.
    2. Flag each row with whether its ground-truth pattern (``MDM``) and its
       predicted pattern (``p_MDM``) occur in the master table.
    3. Run ``selection.Selector`` with the fold's saved checkpoint.

    The four counts returned by ``run_selection`` (tp, tn, fp, fn) and the
    accuracy / F1 / precision / recall derived from them are appended to
    ``output.txt``. A metric whose denominator is zero is reported as ``nan``.

    Args:
        fold: Fold number. It is interpolated into every input path and into
            the checkpoint sub-directory name ``checkpoint_fold_{fold}``.

    Raises:
        FileNotFoundError: If no ``checkpoint-*`` directory exists for the
            fold, or if one of the input CSV files is missing.
    """
    # Mapping predictions for this fold.
    df = pd.read_csv(
        f"../../train/mapping_t5_complete_desc_unit_name/mapping_prediction/exports/result_group_{fold}.csv",
        skipinitialspace=True,
    )

    # Target (in-distribution) data for this fold, plus a combined key to
    # help with selection later.
    train_df = pd.read_csv(
        f"../../data_preprocess/exports/dataset/group_{fold}/train_all.csv",
        skipinitialspace=True,
    )
    train_df['thing_property'] = train_df['thing'] + " " + train_df['property']

    # Cull predictions using the data-model master table.
    df_master = pd.read_csv(
        "../../data_import/exports/data_model_master_export.csv",
        skipinitialspace=True,
    )
    # Master entries are joined verbatim (not digit-masked). They are
    # presumably already stored with '#' placeholders -- confirm.
    df_master['master_pattern'] = df_master['thing'] + " " + df_master['property']
    master_patterns = set(df_master['master_pattern'])  # O(1) membership tests

    # Ground-truth columns first, then prediction columns, so the column order
    # handed to the selector is unchanged.
    _add_mdm_columns(df, '', master_patterns)
    _add_mdm_columns(df, 'p_', master_patterns)

    # Selection is restricted to predictions of one ship class at a time and
    # uses similarity to in-distribution data: per-class, not global.
    import selection  # project-local module (not visible here)

    selector = selection.Selector(input_df=df, reference_df=train_df, fold=fold)

    # Locate the checkpoint: <checkpoint_directory>/checkpoint_fold_<fold>/checkpoint-<step>.
    # Training is expected to keep exactly one. Sorting only makes the pick
    # deterministic if that assumption is ever violated (glob order is arbitrary).
    directory = os.path.join(checkpoint_directory, f'checkpoint_fold_{fold}')
    matches = sorted(glob.glob(os.path.join(directory, 'checkpoint-*')))
    if not matches:
        raise FileNotFoundError(f"no 'checkpoint-*' directory found in {directory}")
    checkpoint_path = matches[0]

    tp, tn, fp, fn = selector.run_selection(checkpoint_path=checkpoint_path)

    # Append this fold's statistics to output.txt.
    with open("output.txt", "a") as f:
        print(80 * '*', file=f)
        print(f'Statistics for fold {fold}', file=f)
        print(f"tp: {tp}", file=f)
        print(f"tn: {tn}", file=f)
        print(f"fp: {fp}", file=f)
        print(f"fn: {fn}", file=f)
        print(f"fold: {fold}", file=f)
        print("accuracy: ", _safe_ratio(tp + tn, tp + tn + fp + fn), file=f)
        print("f1_score: ", _safe_ratio(2 * tp, 2 * tp + fp + fn), file=f)
        print("precision: ", _safe_ratio(tp, tp + fp), file=f)
        print("recall: ", _safe_ratio(tp, tp + fn), file=f)
|
|
|
|
|
|
|
|
###########################################
|
|
|
|
# Execute for all folds
|
|
|
|
|
|
|
|
# Truncate output.txt so each run starts fresh; it then holds a single
# newline. Each call to select() appends one fold's statistics block.
with open("output.txt", "w") as f:
    f.write("\n")

# Evaluate folds 1 through 5 in order.
for fold in range(1, 6):
    select(fold)
|
2024-11-28 11:02:22 +09:00
|
|
|
|
|
|
|
# %%
|