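"""Evaluate the trained checkpoint on the test split.

Reads ``test_seq.csv``, generates one prediction per row with ``Inference``,
writes the predictions to ``exports/result.csv`` and appends the exact-match
accuracy against the ``entity_seq`` column to ``output.txt``.
"""
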
import glob
import os
from pathlib import Path

import pandas as pd

from inference import Inference


# Root directory that contains the ``checkpoint/`` folder written by training.
# Relative paths are resolved against the current working directory.
checkpoint_directory: str = "../"

# Batch size handed to ``Inference.prepare_dataloader``.
BATCH_SIZE: int = 512


def _find_checkpoint(checkpoint_dir):
    """Return the path of the ``checkpoint-*`` entry under ``<checkpoint_dir>/checkpoint``.

    Training is expected to leave exactly one checkpoint there. ``glob``
    returns matches in arbitrary order, so if several exist the one with the
    highest numeric step suffix (``checkpoint-<step>``) is chosen to keep runs
    reproducible.

    Raises:
        FileNotFoundError: if no ``checkpoint-*`` entry exists (instead of the
            bare ``IndexError`` that indexing an empty glob result would give).
    """
    directory = os.path.join(checkpoint_dir, 'checkpoint')
    matches = glob.glob(os.path.join(directory, 'checkpoint-*'))
    if not matches:
        raise FileNotFoundError(
            f"no 'checkpoint-*' entry found in {directory!r}"
        )

    def step_of(path):
        # Non-numeric suffixes rank below any real step number.
        suffix = os.path.basename(path).rpartition('-')[2]
        return int(suffix) if suffix.isdigit() else -1

    # Tie-break on the path itself so the choice never depends on glob order.
    return max(matches, key=lambda path: (step_of(path), path))


def infer(data_path=None, checkpoint_dir=None, batch_size=None,
          max_length=128, output_csv="exports/result.csv",
          log_path="output.txt"):
    """Generate predictions for the test split and log exact-match accuracy.

    Every argument is optional; ``infer()`` with no arguments uses the same
    paths and settings the script has always used.

    Args:
        data_path: test CSV; must contain an ``entity_seq`` column holding the
            expected output for each row. Defaults to
            ``'../../../esAppMod_data_import/test_seq.csv'``.
        checkpoint_dir: directory containing ``checkpoint/checkpoint-*``.
            Defaults to the module-level ``checkpoint_directory``.
        batch_size: batch size for ``Inference.prepare_dataloader``.
            Defaults to the module-level ``BATCH_SIZE``.
        max_length: ``max_length`` for ``Inference.prepare_dataloader``.
        output_csv: CSV the predictions are written to (a single
            ``class_prediction`` column); missing parent directories are
            created.
        log_path: text file the accuracy line is appended to.

    Returns:
        The fraction of rows whose prediction equals ``entity_seq``
        (NaN when the test CSV has no rows).

    Raises:
        FileNotFoundError: if no checkpoint is found.
        ValueError: if the number of predictions differs from the number of
            test rows, which would make the row-wise comparison meaningless.
    """
    # Resolve defaults at call time so module-level settings are read when
    # the function actually runs.
    if data_path is None:
        data_path = '../../../esAppMod_data_import/test_seq.csv'
    if checkpoint_dir is None:
        checkpoint_dir = checkpoint_directory
    if batch_size is None:
        batch_size = BATCH_SIZE

    print("Inference for data")

    df = pd.read_csv(data_path, skipinitialspace=True)

    # Named ``inference`` rather than ``infer`` so the local does not shadow
    # this function.
    inference = Inference(_find_checkpoint(checkpoint_dir))
    inference.prepare_dataloader(df, batch_size=batch_size, max_length=max_length)
    prediction_list = list(inference.generate())

    if len(prediction_list) != len(df):
        raise ValueError(
            f"got {len(prediction_list)} predictions for {len(df)} test rows"
        )

    df_out = pd.DataFrame({'class_prediction': prediction_list})

    # Save the raw generation output; create e.g. ``exports/`` if it is
    # missing so to_csv does not fail on a fresh checkout.
    Path(output_csv).parent.mkdir(parents=True, exist_ok=True)
    df_out.to_csv(output_csv, index=False)

    # Both frames carry the default 0..n-1 index, so ``==`` compares row by row.
    condition_correct = df_out['class_prediction'] == df['entity_seq']
    # mean() of a boolean Series is the proportion of True values, and is NaN
    # (rather than a ZeroDivisionError) for an empty Series.
    pred_correct_proportion = float(condition_correct.mean())

    with open(log_path, "a") as f:
        print(f'Accuracy for fold: {pred_correct_proportion}', file=f)

    return pred_correct_proportion


###########################################
# execute

# Guarded so that importing this module (e.g. to reuse ``infer``) does not
# start a full inference run or truncate output.txt as a side effect.
if __name__ == "__main__":
    # Reset output.txt (it is left holding a single blank line) before
    # infer() appends its accuracy line to it.
    with open("output.txt", "w") as f:
        print('', file=f)

    infer()