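# Per-fold inference script (original listing: 72 lines, 2.6 KiB, Python).
# For each fold it runs the fold's saved checkpoint on that fold's test set and
# appends the in-MDM mapping accuracy to output.txt.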
import pandas as pd
|
||
|
import os
|
||
|
import glob
|
||
|
from inference import Inference
|
||
|
|
||
|
# Root of the training run. Per-fold checkpoints are looked up under
# <checkpoint_directory>/checkpoint_fold_<fold>/checkpoint-<step>.
checkpoint_directory = "../../train/baseline"
def _find_checkpoint(directory):
    """Return the ``checkpoint-<step>`` path in *directory* with the largest step.

    Training is expected to leave exactly one such checkpoint per fold. Picking
    the largest numeric step keeps the choice deterministic if more than one
    exists, because ``glob`` returns matches in arbitrary order.

    Raises:
        FileNotFoundError: If *directory* contains no ``checkpoint-*`` entry.
    """
    candidates = sorted(glob.glob(os.path.join(directory, 'checkpoint-*')))
    if not candidates:
        raise FileNotFoundError(f"No 'checkpoint-*' found in {directory!r}")

    def _step(path):
        # Non-numeric suffixes rank below every real step number.
        suffix = os.path.basename(path).rpartition('-')[2]
        return int(suffix) if suffix.isdigit() else -1

    return max(candidates, key=_step)


def infer_and_select(
    fold,
    *,
    checkpoint_dir=None,
    data_root="../../data_preprocess/exports/dataset",
    batch_size=256,
    max_length=128,
    output_path="output.txt",
):
    """Run inference for one fold and append its in-MDM accuracy to a file.

    Reads ``<data_root>/group_<fold>/test_all.csv`` and generates thing and
    property predictions with the checkpoint saved for this fold. It then
    appends ``"Accuracy for fold <fold>: <accuracy>"`` to *output_path*.

    Accuracy counts a row as correct when both the predicted thing and the
    predicted property equal the ``thing`` and ``property`` columns. Only rows
    selected by the ``MDM`` column are counted (assumes ``MDM`` is boolean —
    TODO confirm against the dataset export).

    Args:
        fold: Fold number. Selects ``group_<fold>`` data and
            ``checkpoint_fold_<fold>``.
        checkpoint_dir: Root holding the ``checkpoint_fold_<fold>`` folders.
            Defaults to the module-level ``checkpoint_directory``.
        data_root: Directory containing the ``group_<fold>`` dataset folders.
        batch_size: Passed to ``Inference.prepare_dataloader``.
        max_length: Passed to ``Inference.prepare_dataloader``.
        output_path: Text file the accuracy line is appended to.

    Returns:
        The accuracy as a float, or NaN if no row is in MDM.

    Raises:
        FileNotFoundError: If the fold has no ``checkpoint-*`` directory.
        ValueError: If the prediction count differs from the test row count.
    """
    print(f"Inference for fold {fold}")

    # Test data for this fold.
    data_path = f"{data_root}/group_{fold}/test_all.csv"
    df = pd.read_csv(data_path, skipinitialspace=True)

    # Locate the checkpoint: <root>/checkpoint_fold_<fold>/checkpoint-<step>.
    root = checkpoint_directory if checkpoint_dir is None else checkpoint_dir
    checkpoint_path = _find_checkpoint(
        os.path.join(root, f'checkpoint_fold_{fold}')
    )

    infer = Inference(checkpoint_path)
    infer.prepare_dataloader(df, batch_size=batch_size, max_length=max_length)
    thing_predictions, property_predictions = infer.generate()

    # Fail loudly instead of silently NaN-padding misaligned predictions.
    if not (len(thing_predictions) == len(property_predictions) == len(df)):
        raise ValueError(
            f"Fold {fold}: got {len(thing_predictions)} thing / "
            f"{len(property_predictions)} property predictions "
            f"for {len(df)} test rows"
        )
    df = df.assign(
        p_thing=list(thing_predictions),
        p_property=list(property_predictions),
    )

    # Mapping accuracy restricted to rows that are in MDM.
    in_mdm = df['MDM']
    both_correct = (df['p_thing'] == df['thing']) & (df['p_property'] == df['property'])
    n_in_mdm = int(in_mdm.sum())
    n_correct = int((both_correct & in_mdm).sum())
    # Avoid ZeroDivisionError / numpy warnings when the fold has no MDM rows.
    accuracy = n_correct / n_in_mdm if n_in_mdm else float('nan')

    with open(output_path, "a", encoding="utf-8") as f:
        print(f'Accuracy for fold {fold}: {accuracy}', file=f)

    return accuracy
###########################################
# Execute for all folds

# Start every run with an empty results file. Opening in "w" mode truncates it,
# and nothing is written here, so the file does not begin with a stray blank
# line. infer_and_select() then appends one accuracy line per fold.
with open("output.txt", "w", encoding="utf-8"):
    pass

for fold in [1, 2, 3, 4, 5]:
    infer_and_select(fold)