# domain_mapping/esAppMod_train/seq2seq_t5_simple/prediction/predict.py
#
# (source listing metadata: 64 lines, 2.0 KiB, Python)

import pandas as pd
import os
import glob
from inference import Inference
# Directory that contains the ``checkpoint/`` folder written during training.
# It is a relative path, so it is resolved against the current working directory.
checkpoint_directory = '../'
# Batch size passed to Inference.prepare_dataloader() when running predictions.
BATCH_SIZE = 512
def _find_checkpoint(directory):
    """Return the path of the ``checkpoint-<step>`` entry inside *directory*.

    Training is expected to save exactly one checkpoint. If several are
    present, the lexicographically first one is returned, so the choice is
    at least deterministic (raw ``glob`` order depends on the filesystem).

    Raises:
        FileNotFoundError: If *directory* contains no ``checkpoint-*`` entry.
    """
    matches = sorted(glob.glob(os.path.join(directory, 'checkpoint-*')))
    if not matches:
        raise FileNotFoundError(
            f"No 'checkpoint-*' directory found under {directory!r}"
        )
    return matches[0]


def infer(data_path='../../../esAppMod_data_import/test_seq.csv'):
    """Run the trained seq2seq model over the test set and log its accuracy.

    Loads the checkpoint found under ``<checkpoint_directory>/checkpoint``.
    It generates one prediction per row of the test CSV and writes the raw
    generations to ``exports/result.csv``. The exact-match accuracy against
    the ``entity_seq`` column is appended to ``output.txt``.

    Args:
        data_path: Path of the test CSV. It must contain an ``entity_seq``
            column with the reference sequences. Defaults to the shared
            esAppMod test split.

    Returns:
        Fraction of rows whose prediction equals ``entity_seq`` exactly.
        The value is NaN when the test set is empty.

    Raises:
        FileNotFoundError: If no ``checkpoint-*`` directory exists.
        ValueError: If the number of predictions differs from the number
            of test rows.
    """
    print("Inference for data")
    df = pd.read_csv(data_path, skipinitialspace=True)

    checkpoint_path = _find_checkpoint(
        os.path.join(checkpoint_directory, 'checkpoint')
    )

    # Named ``model`` (not ``infer``) so it does not shadow this function.
    model = Inference(checkpoint_path)
    model.prepare_dataloader(df, batch_size=BATCH_SIZE, max_length=128)
    prediction_list = model.generate()

    df_out = pd.DataFrame({'class_prediction': prediction_list})

    # Save the raw generations first so they survive a failed evaluation.
    os.makedirs("exports", exist_ok=True)
    df_out.to_csv("exports/result.csv", index=False)

    # The element-wise comparison below needs one prediction per test row.
    if len(df_out) != len(df):
        raise ValueError(
            f"Got {len(df_out)} predictions for {len(df)} test rows"
        )

    # Exact string match over all rows; mean() of a boolean Series is the
    # proportion of True values (NaN for an empty Series).
    condition_correct = df_out['class_prediction'] == df['entity_seq']
    pred_correct_proportion = condition_correct.mean()

    # Append, because the caller truncates output.txt once per run.
    with open("output.txt", "a", encoding="utf-8") as f:
        print(f'Accuracy for fold: {pred_correct_proportion}', file=f)
    return pred_correct_proportion
###########################################
# execute
if __name__ == "__main__":
    # Start each run with an empty output.txt; infer() appends its results.
    # Opening in "w" mode truncates the file, so nothing needs to be written.
    with open("output.txt", "w", encoding="utf-8"):
        pass
    infer()