2024-10-31 15:58:20 +09:00
|
|
|
|
|
|
|
import pandas as pd
|
|
|
|
import os
|
|
|
|
import glob
|
|
|
|
from inference import Inference
|
|
|
|
|
2024-11-05 16:49:18 +09:00
|
|
|
# Root folder under which the per-fold ``checkpoint_fold_<n>`` directories live.
checkpoint_directory = "../"
|
2024-10-31 15:58:20 +09:00
|
|
|
|
|
|
|
def infer_and_select(fold):
    """Run inference for one fold, export predictions and log MDM accuracy.

    Reads ``test_all.csv`` for ``fold`` and generates ``thing``/``property``
    predictions with the checkpoint found under
    ``<checkpoint_directory>/checkpoint_fold_<fold>``. It then writes the test
    rows plus ``p_thing``/``p_property`` columns to
    ``exports/result_group_<fold>.csv`` and appends the accuracy line to
    ``output.txt``.

    Accuracy is the fraction of rows whose ``MDM`` flag is set and whose
    predicted thing AND property both equal the ground-truth columns.

    Args:
        fold: Fold number; selects both the dataset group and checkpoint dir.

    Returns:
        The accuracy as a float, or ``nan`` when no row has ``MDM`` set
        (previously an empty test set raised ``ZeroDivisionError``).

    Raises:
        FileNotFoundError: if no ``checkpoint-*`` entry exists for the fold.
    """
    print(f"Inference for fold {fold}")

    # Import test data.
    data_path = f"../../../data_preprocess/exports/dataset/group_{fold}/test_all.csv"
    df = pd.read_csv(data_path, skipinitialspace=True)

    # Run inference with the checkpoint saved for this fold.
    checkpoint_path = _find_checkpoint(fold)
    infer = Inference(checkpoint_path)
    infer.prepare_dataloader(df, batch_size=256, max_length=128)
    thing_prediction_list, property_prediction_list = infer.generate()

    # NOTE(review): assumes generate() yields one prediction per input row,
    # in input order -- concat below aligns purely by position/RangeIndex.
    df_out = pd.DataFrame({
        'p_thing': thing_prediction_list,
        'p_property': property_prediction_list,
    })
    df = pd.concat([df, df_out], axis=1)

    # Save the generation output; make sure the export dir exists first.
    os.makedirs("exports", exist_ok=True)
    df.to_csv(f"exports/result_group_{fold}.csv")

    # Evaluate mapping accuracy restricted to rows flagged as in MDM.
    pred_correct_proportion = _mdm_accuracy(df)

    with open("output.txt", "a") as f:
        print(f'Accuracy for fold {fold}: {pred_correct_proportion}', file=f)

    return pred_correct_proportion


def _find_checkpoint(fold):
    """Return the checkpoint path saved for ``fold`` (highest step if several)."""
    directory = os.path.join(checkpoint_directory, f'checkpoint_fold_{fold}')
    # Layout is usually checkpoint_fold_<fold>/checkpoint-<step number>.
    # Training should keep only one, but glob order is arbitrary, so choose
    # deterministically instead of taking whatever glob lists first.
    matches = glob.glob(os.path.join(directory, 'checkpoint-*'))
    if not matches:
        raise FileNotFoundError(f"No 'checkpoint-*' found in {directory!r}")
    return max(matches, key=_checkpoint_step)


def _checkpoint_step(path):
    """Return the numeric suffix of a ``checkpoint-<step>`` path, or -1."""
    suffix = os.path.basename(path).rpartition('-')[2]
    return int(suffix) if suffix.isdigit() else -1


def _mdm_accuracy(df):
    """Return the share of MDM rows whose predicted thing and property match."""
    in_mdm = df['MDM']
    correct = (df['p_thing'] == df['thing']) & (df['p_property'] == df['property'])
    total = int(in_mdm.sum())
    if total == 0:
        # Avoid 0/0 when the fold has no MDM rows (or no rows at all).
        return float('nan')
    return int((correct & in_mdm).sum()) / total
|
|
|
|
|
|
|
|
###########################################
|
|
|
|
# Execute for all folds when run as a script; importing this module has no side effects.
if __name__ == "__main__":
    # Truncate output.txt so accuracies from a previous run are discarded
    # (infer_and_select appends one line per fold).
    with open("output.txt", "w"):
        pass

    for fold in range(1, 6):
        infer_and_select(fold)
|