# hipom_data_mapping/production/run.py

# %%
import pandas as pd
import os
import glob
from end_to_end.mapper import Mapper
from end_to_end.preprocess import Abbreviator
from end_to_end.deduplication import run_deduplication
from end_to_end.rule_based_correction import Corrector
# Global configuration shared by every stage below.
BATCH_SIZE = 512  # batch size handed to the mapper and to de-duplication
SHIPS_LIST = [1000, 1001, 1003, 1004]  # 'ships_idx' values kept from the raw data
# Single-ship alternative: SHIPS_LIST = [1000]
# %%
# START: import the raw data CSV and keep only a few ships from it, to
# simulate the incoming JSON payload.
data_path = 'raw_data.csv'
full_df = pd.read_csv(data_path, skipinitialspace=True)
# Keep only rows whose 'ships_idx' is listed in SHIPS_LIST, with a fresh
# 0..n-1 index.
df = full_df[full_df['ships_idx'].isin(SHIPS_LIST)].reset_index(drop=True)
# Test parameter: set ROW_LIMIT to an int to process only the first
# ROW_LIMIT rows; None processes every row. (A hard-coded len(df) - 1 here
# would silently drop the final row on every run.)
ROW_LIMIT = None
if ROW_LIMIT is not None:
    df = df[:ROW_LIMIT]
num_rows = len(df)
print(num_rows)
# Pre-process the data (abbreviation handling via the project's Abbreviator).
abbreviator = Abbreviator(df)
df = abbreviator.run()
# %%
##########################################
# run mapping
# Load the mapping model from its checkpoint directory and predict a
# (thing, property) pair for the pre-processed rows.
checkpoint_path = 'models/mapping_model'
mapper = Mapper(checkpoint_path)
mapper.prepare_dataloader(df, batch_size=BATCH_SIZE, max_length=128)
# Assumes generate() yields one prediction per row of df, in row order
# (TODO confirm against Mapper.generate).
thing_prediction_list, property_prediction_list = mapper.generate()
# Fail loudly instead of letting pd.concat silently pad with NaN when the
# prediction count does not match the number of input rows.
if not (len(thing_prediction_list) == len(property_prediction_list) == len(df)):
    raise ValueError(
        f"prediction count mismatch: {len(thing_prediction_list)} thing and "
        f"{len(property_prediction_list)} property predictions for "
        f"{len(df)} input rows"
    )
df_out = pd.DataFrame({
    'p_thing': thing_prediction_list,
    'p_property': property_prediction_list,
})
# pd.concat(axis=1) aligns on index labels, so reset df's index first to
# attach predictions positionally regardless of what pre-processing left.
df = pd.concat([df.reset_index(drop=True), df_out], axis=1)
# %%
####################################
# Rule-based correction: wrap the prediction frame in a Corrector and keep
# the frame returned by its correction pass.
corrector = Corrector(df)
df = corrector.run_correction()
# %%
####################################
# De-duplication with thresholding against the training mappings.
# NOTE: data_path is rebound here from the raw-data CSV to the training CSV.
data_path = "train_all.csv"
# Training mappings: 'thing' and 'property' joined by a single space.
train_df = pd.read_csv(data_path, skipinitialspace=True).assign(
    mapping=lambda frame: frame['thing'] + " " + frame['property'],
)
# NOTE(review): threshold/diagnostic meanings are defined by
# run_deduplication (presumably a similarity cutoff and extra reporting) -
# confirm there.
df = run_deduplication(
    train_df=train_df,
    test_df=df,
    threshold=0.9,
    batch_size=BATCH_SIZE,
    diagnostic=True,
)
# %%