# %%
import pandas as pd

from end_to_end.mapper import Mapper
from end_to_end.preprocess import Abbreviator
from end_to_end.deduplication import run_deduplication
from end_to_end.rule_based_correction import Corrector

# global config
BATCH_SIZE = 512
SHIPS_LIST = [1000, 1001, 1003, 1004]
# SHIPS_LIST = [1000]

# %%
# START: load the raw data CSV and pull out a few ships to simulate an incoming JSON payload
data_path = 'raw_data.csv'
full_df = pd.read_csv(data_path, skipinitialspace=True)

# keep only the ships listed in SHIPS_LIST
df = full_df[full_df['ships_idx'].isin(SHIPS_LIST)].reset_index(drop=True)
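
# optional sanity check (illustrative, not part of the original script):
# warn if any requested ship id has no rows in the raw data
missing = set(SHIPS_LIST) - set(df['ships_idx'].unique())
if missing:
    print(f"warning: no rows found for ships {sorted(missing)}")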

# test parameters: drop the last row (e.g. to exercise a partial final batch)
num_rows = len(df) - 1
df = df[:num_rows]
print(len(df))

# pre-process data: run the Abbreviator over the raw text fields
abbreviator = Abbreviator(df)
df = abbreviator.run()
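
# optional (illustrative): eyeball a few pre-processed rows
print(df.head())
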
# %%
##########################################
# run mapping
# load the fine-tuned mapping model from a local checkpoint
checkpoint_path = 'models/mapping_model'
mapper = Mapper(checkpoint_path)
mapper.prepare_dataloader(df, batch_size=BATCH_SIZE, max_length=128)
thing_prediction_list, property_prediction_list = mapper.generate()
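
# optional diagnostic (illustrative): the predictions are concatenated to df by position
# below, so check that the mapper returned one prediction per input row
assert len(thing_prediction_list) == len(df), "mapper output should align 1:1 with input rows"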

# add labels too
# thing_actual_list, property_actual_list = decode_preds(pred_labels)
# collect the predictions into a DataFrame
df_out = pd.DataFrame({
    'p_thing': thing_prediction_list,
    'p_property': property_prediction_list
})
# df_out['p_thing_correct'] = df_out['p_thing'] == df_out['thing']
# df_out['p_property_correct'] = df_out['p_property'] == df_out['property']
# pd.concat(axis=1) aligns on the index; df and df_out are assumed to share a clean RangeIndex
df = pd.concat([df, df_out], axis=1)

# %%
###################################
# run rule-based correction
corrector = Corrector(df)
df = corrector.run_correction()
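
# optional diagnostic (illustrative): distribution of corrected 'thing' predictions
print(df['p_thing'].value_counts().head(10))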

# %%
####################################
# run de-duplication with thresholding
# build the reference set of known mappings from the training data
data_path = "train_all.csv"
train_df = pd.read_csv(data_path, skipinitialspace=True)
# 'mapping' is the combined label: 'thing' and 'property' joined by a space
train_df['mapping'] = train_df['thing'] + " " + train_df['property']

df = run_deduplication(
    test_df=df,
    train_df=train_df,
    batch_size=BATCH_SIZE,
    threshold=0.9,
    diagnostic=True)
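
# optional (illustrative): persist the final pipeline output for inspection
# (the filename is hypothetical)
df.to_csv('end_to_end_output.csv', index=False)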

# %%