87 lines
2.5 KiB
Plaintext
87 lines
2.5 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Group 1 Recall: 0.947941\n",
|
||
|
"Group 2 Recall: 0.902804\n",
|
||
|
"Group 3 Recall: 0.970884\n",
|
||
|
"Group 4 Recall: 0.965271\n",
|
||
|
"Group 5 Recall: 0.949611\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"import pandas as pd\n",
|
||
|
"\n",
|
||
|
"# mode, model_name, fold_group 설정\n",
|
||
|
"mode = 'td_unit' # 원하는 모드를 설정하세요\n",
|
||
|
"model_name = 'google/t5-efficient-tiny' # 모델 이름을 설정하세요\n",
|
||
|
"recall_by_group = {}\n",
|
||
|
"\n",
|
||
|
"# 그룹 1부터 5까지 처리\n",
|
||
|
"for group in range(1, 6):\n",
|
||
|
" # CSV 파일 경로 설정 (model_name 포함)\n",
|
||
|
" debug_output_path = f\"0.dresult/{mode}/{model_name}/{group}/test_p.csv\"\n",
|
||
|
" \n",
|
||
|
" # CSV 파일 로드\n",
|
||
|
" try:\n",
|
||
|
" df = pd.read_csv(debug_output_path)\n",
|
||
|
" except FileNotFoundError:\n",
|
||
|
" print(f\"File not found: {debug_output_path}\")\n",
|
||
|
" continue\n",
|
||
|
"\n",
|
||
|
" # 1. MDM이 True인 항목만 필터\n",
|
||
|
" filtered_df = df[df['MDM'] == True].copy()\n",
|
||
|
"\n",
|
||
|
" # 2. p_thing과 p_property가 thing과 property와 같으면 TP로 설정 (loc 사용)\n",
|
||
|
" filtered_df.loc[:, 'TP'] = (filtered_df['p_thing'] == filtered_df['thing']) & (filtered_df['p_property'] == filtered_df['property'])\n",
|
||
|
"\n",
|
||
|
" # 3. TP 갯수와 전체 MDM 갯수로 Recall 계산\n",
|
||
|
" tp_count = filtered_df['TP'].sum()\n",
|
||
|
" total_count = len(filtered_df)\n",
|
||
|
"\n",
|
||
|
" # Recall 계산\n",
|
||
|
" if total_count > 0:\n",
|
||
|
" recall = tp_count / total_count\n",
|
||
|
" else:\n",
|
||
|
" recall = 0\n",
|
||
|
"\n",
|
||
|
" # 그룹별 Recall 저장\n",
|
||
|
" recall_by_group[group] = recall\n",
|
||
|
"\n",
|
||
|
"# Recall 출력\n",
|
||
|
"for group, recall in recall_by_group.items():\n",
|
||
|
" print(f\"Group {group} Recall: {recall:.6f}\")\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "torch",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.10.14"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 2
|
||
|
}
|