{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/optimization.py:521: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 completed. Loss: 6.255228042602539\n",
      "Validation Accuracy after Epoch 1: 1.14%\n",
      "Epoch 2 completed. Loss: 5.604339599609375\n",
      "Validation Accuracy after Epoch 2: 4.41%\n",
      "Epoch 3 completed. Loss: 5.204141616821289\n",
      "Validation Accuracy after Epoch 3: 8.45%\n",
      "Epoch 4 completed. Loss: 5.063010215759277\n",
      "Validation Accuracy after Epoch 4: 11.15%\n",
      "Epoch 5 completed. Loss: 4.573624134063721\n",
      "Validation Accuracy after Epoch 5: 14.19%\n",
      "Epoch 6 completed. Loss: 4.300177574157715\n",
      "Validation Accuracy after Epoch 6: 17.04%\n",
      "Epoch 7 completed. Loss: 3.6726627349853516\n",
      "Validation Accuracy after Epoch 7: 20.74%\n",
      "Epoch 8 completed. Loss: 3.872858762741089\n",
      "Validation Accuracy after Epoch 8: 24.35%\n",
      "Epoch 9 completed. Loss: 3.4192371368408203\n",
      "Validation Accuracy after Epoch 9: 26.34%\n",
      "Epoch 10 completed. Loss: 3.1378841400146484\n",
      "Validation Accuracy after Epoch 10: 28.81%\n",
      "Epoch 11 completed. Loss: 2.6661603450775146\n",
      "Validation Accuracy after Epoch 11: 30.66%\n",
      "Epoch 12 completed. Loss: 2.873978614807129\n",
      "Validation Accuracy after Epoch 12: 31.89%\n",
      "Epoch 13 completed. Loss: 2.7727530002593994\n",
      "Validation Accuracy after Epoch 13: 35.22%\n",
      "Epoch 14 completed. Loss: 2.4830515384674072\n",
      "Validation Accuracy after Epoch 14: 37.97%\n",
      "Epoch 15 completed. Loss: 2.2535340785980225\n",
      "Validation Accuracy after Epoch 15: 40.15%\n",
      "Epoch 16 completed. Loss: 2.2358787059783936\n",
      "Validation Accuracy after Epoch 16: 43.71%\n",
      "Epoch 17 completed. Loss: 2.270059585571289\n",
      "Validation Accuracy after Epoch 17: 46.13%\n",
      "Epoch 18 completed. Loss: 1.9263427257537842\n",
      "Validation Accuracy after Epoch 18: 49.83%\n",
      "Epoch 19 completed. Loss: 1.8839179277420044\n",
      "Validation Accuracy after Epoch 19: 53.77%\n",
      "Epoch 20 completed. Loss: 1.7309540510177612\n",
      "Validation Accuracy after Epoch 20: 59.23%\n",
      "Epoch 21 completed. Loss: 1.6431639194488525\n",
      "Validation Accuracy after Epoch 21: 65.50%\n",
      "Epoch 22 completed. Loss: 1.4129509925842285\n",
      "Validation Accuracy after Epoch 22: 68.77%\n",
      "Epoch 23 completed. Loss: 1.6112335920333862\n",
      "Validation Accuracy after Epoch 23: 72.05%\n",
      "Epoch 24 completed. Loss: 1.3653665781021118\n",
      "Validation Accuracy after Epoch 24: 74.23%\n",
      "Epoch 25 completed. Loss: 1.2029541730880737\n",
      "Validation Accuracy after Epoch 25: 76.46%\n",
      "Epoch 26 completed. Loss: 1.1179462671279907\n",
      "Validation Accuracy after Epoch 26: 78.64%\n",
      "Epoch 27 completed. Loss: 1.1831905841827393\n",
      "Validation Accuracy after Epoch 27: 81.02%\n",
      "Epoch 28 completed. Loss: 0.8559420704841614\n",
      "Validation Accuracy after Epoch 28: 82.72%\n",
      "Epoch 29 completed. Loss: 0.8667900562286377\n",
      "Validation Accuracy after Epoch 29: 82.58%\n",
      "Epoch 30 completed. Loss: 1.1078470945358276\n",
      "Validation Accuracy after Epoch 30: 83.72%\n",
      "Epoch 31 completed. Loss: 0.8486237525939941\n",
      "Validation Accuracy after Epoch 31: 84.34%\n",
      "Epoch 32 completed. Loss: 0.804058313369751\n",
      "Validation Accuracy after Epoch 32: 85.00%\n",
      "Epoch 33 completed. Loss: 0.6297520399093628\n",
      "Validation Accuracy after Epoch 33: 85.67%\n",
      "Epoch 34 completed. Loss: 0.6711896657943726\n",
      "Validation Accuracy after Epoch 34: 85.57%\n",
      "Epoch 35 completed. Loss: 0.7203101515769958\n",
      "Validation Accuracy after Epoch 35: 86.43%\n",
      "Epoch 36 completed. Loss: 0.7537139654159546\n",
      "Validation Accuracy after Epoch 36: 86.57%\n",
      "Epoch 37 completed. Loss: 0.49183693528175354\n",
      "Validation Accuracy after Epoch 37: 86.66%\n",
      "Epoch 38 completed. Loss: 0.5906791090965271\n",
      "Validation Accuracy after Epoch 38: 87.00%\n",
      "Epoch 39 completed. Loss: 0.4300324320793152\n",
      "Validation Accuracy after Epoch 39: 87.52%\n",
      "Epoch 40 completed. Loss: 0.4216059744358063\n",
      "Validation Accuracy after Epoch 40: 87.38%\n",
      "Epoch 41 completed. Loss: 0.5085476636886597\n",
      "Validation Accuracy after Epoch 41: 88.09%\n",
      "Epoch 42 completed. Loss: 0.5296332836151123\n",
      "Validation Accuracy after Epoch 42: 87.90%\n",
      "Epoch 43 completed. Loss: 0.37904512882232666\n",
      "Validation Accuracy after Epoch 43: 88.70%\n",
      "Epoch 44 completed. Loss: 0.41481274366378784\n",
      "Validation Accuracy after Epoch 44: 88.18%\n",
      "Epoch 45 completed. Loss: 0.4976593255996704\n",
      "Validation Accuracy after Epoch 45: 88.61%\n",
      "Epoch 46 completed. Loss: 0.5229529142379761\n",
      "Validation Accuracy after Epoch 46: 88.61%\n",
      "Epoch 47 completed. Loss: 0.36945009231567383\n",
      "Validation Accuracy after Epoch 47: 89.23%\n",
      "Epoch 48 completed. Loss: 0.23448766767978668\n",
      "Validation Accuracy after Epoch 48: 89.32%\n",
      "Epoch 49 completed. Loss: 0.1870148777961731\n",
      "Validation Accuracy after Epoch 49: 89.51%\n",
      "Epoch 50 completed. Loss: 0.3627645969390869\n",
      "Validation Accuracy after Epoch 50: 90.08%\n",
      "Epoch 51 completed. Loss: 0.2712886929512024\n",
      "Validation Accuracy after Epoch 51: 90.22%\n",
      "Epoch 52 completed. Loss: 0.30932098627090454\n",
      "Validation Accuracy after Epoch 52: 90.75%\n",
      "Epoch 53 completed. Loss: 0.4048871099948883\n",
      "Validation Accuracy after Epoch 53: 91.12%\n",
      "Epoch 54 completed. Loss: 0.28516653180122375\n",
      "Validation Accuracy after Epoch 54: 90.93%\n",
      "Epoch 55 completed. Loss: 0.14647549390792847\n",
      "Validation Accuracy after Epoch 55: 91.03%\n",
      "Epoch 56 completed. Loss: 0.17482930421829224\n",
      "Validation Accuracy after Epoch 56: 90.84%\n",
      "Epoch 57 completed. Loss: 0.2837833762168884\n",
      "Validation Accuracy after Epoch 57: 91.27%\n",
      "Epoch 58 completed. Loss: 0.2879948914051056\n",
      "Validation Accuracy after Epoch 58: 91.50%\n",
      "Epoch 59 completed. Loss: 0.2823488712310791\n",
      "Validation Accuracy after Epoch 59: 91.50%\n",
      "Epoch 60 completed. Loss: 0.25875282287597656\n",
      "Validation Accuracy after Epoch 60: 91.65%\n",
      "Epoch 61 completed. Loss: 0.3561888337135315\n",
      "Validation Accuracy after Epoch 61: 91.69%\n",
      "Epoch 62 completed. Loss: 0.14592915773391724\n",
      "Validation Accuracy after Epoch 62: 91.98%\n",
      "Epoch 63 completed. Loss: 0.20252785086631775\n",
      "Validation Accuracy after Epoch 63: 91.74%\n",
      "Epoch 64 completed. Loss: 0.13009151816368103\n",
      "Validation Accuracy after Epoch 64: 91.93%\n",
      "Epoch 65 completed. Loss: 0.2165553867816925\n",
      "Validation Accuracy after Epoch 65: 91.74%\n",
      "Epoch 66 completed. Loss: 0.21152348816394806\n",
      "Validation Accuracy after Epoch 66: 91.84%\n",
      "Epoch 67 completed. Loss: 0.12813371419906616\n",
      "Validation Accuracy after Epoch 67: 92.31%\n",
      "Epoch 68 completed. Loss: 0.15637439489364624\n",
      "Validation Accuracy after Epoch 68: 92.55%\n",
      "Epoch 69 completed. Loss: 0.23416577279567719\n",
      "Validation Accuracy after Epoch 69: 92.26%\n",
      "Epoch 70 completed. Loss: 0.1982598602771759\n",
      "Validation Accuracy after Epoch 70: 92.55%\n",
      "Epoch 71 completed. Loss: 0.07098475098609924\n",
      "Validation Accuracy after Epoch 71: 92.36%\n",
      "Epoch 72 completed. Loss: 0.1463148593902588\n",
      "Validation Accuracy after Epoch 72: 92.88%\n",
      "Epoch 73 completed. Loss: 0.13348183035850525\n",
      "Validation Accuracy after Epoch 73: 92.93%\n",
      "Epoch 74 completed. Loss: 0.11992514878511429\n",
      "Validation Accuracy after Epoch 74: 92.83%\n",
      "Epoch 75 completed. Loss: 0.17647001147270203\n",
      "Validation Accuracy after Epoch 75: 92.74%\n",
      "Epoch 76 completed. Loss: 0.1956612467765808\n",
      "Validation Accuracy after Epoch 76: 92.93%\n",
      "Epoch 77 completed. Loss: 0.1292801946401596\n",
      "Validation Accuracy after Epoch 77: 92.83%\n",
      "Epoch 78 completed. Loss: 0.07164446264505386\n",
      "Validation Accuracy after Epoch 78: 93.12%\n",
      "Epoch 79 completed. Loss: 0.20230534672737122\n",
      "Validation Accuracy after Epoch 79: 93.40%\n",
      "Epoch 80 completed. Loss: 0.11642713099718094\n",
      "Validation Accuracy after Epoch 80: 92.74%\n",
      "Epoch 81 completed. Loss: 0.06227307766675949\n",
      "Validation Accuracy after Epoch 81: 93.59%\n",
      "Epoch 82 completed. Loss: 0.07498838752508163\n",
      "Validation Accuracy after Epoch 82: 93.45%\n",
      "Epoch 83 completed. Loss: 0.11042595654726028\n",
      "Validation Accuracy after Epoch 83: 93.31%\n",
      "Epoch 84 completed. Loss: 0.05229116976261139\n",
      "Validation Accuracy after Epoch 84: 93.45%\n",
      "Epoch 85 completed. Loss: 0.14967505633831024\n",
      "Validation Accuracy after Epoch 85: 93.36%\n",
      "Epoch 86 completed. Loss: 0.09601894021034241\n",
      "Validation Accuracy after Epoch 86: 93.45%\n",
      "Epoch 87 completed. Loss: 0.1715390384197235\n",
      "Validation Accuracy after Epoch 87: 93.64%\n",
      "Epoch 88 completed. Loss: 0.05024575814604759\n",
      "Validation Accuracy after Epoch 88: 93.74%\n",
      "Epoch 89 completed. Loss: 0.09373823553323746\n",
      "Validation Accuracy after Epoch 89: 93.64%\n",
      "Epoch 90 completed. Loss: 0.07261866331100464\n",
      "Validation Accuracy after Epoch 90: 93.31%\n",
      "Epoch 91 completed. Loss: 0.07679086923599243\n",
      "Validation Accuracy after Epoch 91: 93.55%\n",
      "Epoch 92 completed. Loss: 0.11550895869731903\n",
      "Validation Accuracy after Epoch 92: 93.40%\n",
      "Epoch 93 completed. Loss: 0.053604159504175186\n",
      "Validation Accuracy after Epoch 93: 93.55%\n",
      "Epoch 94 completed. Loss: 0.151311457157135\n",
      "Validation Accuracy after Epoch 94: 93.36%\n",
      "Epoch 95 completed. Loss: 0.12411662191152573\n",
      "Validation Accuracy after Epoch 95: 93.07%\n",
      "Epoch 96 completed. Loss: 0.0956164076924324\n",
      "Validation Accuracy after Epoch 96: 93.97%\n",
      "Epoch 97 completed. Loss: 0.1610800176858902\n",
      "Validation Accuracy after Epoch 97: 93.74%\n",
      "Epoch 98 completed. Loss: 0.0801430493593216\n",
      "Validation Accuracy after Epoch 98: 93.83%\n",
      "Epoch 99 completed. Loss: 0.09204142540693283\n",
      "Validation Accuracy after Epoch 99: 93.78%\n",
      "Epoch 100 completed. Loss: 0.02089226432144642\n",
      "Validation Accuracy after Epoch 100: 93.74%\n",
      "Accuracy (MDM=True) for Group 4: 93.48%\n",
      "Results saved to 0.class_document/bert-tiny/4/test_p_c.csv\n"
     ]
    }
   ],
"source": [
|
||
|
"import pandas as pd\n",
|
||
|
"from transformers import BertTokenizer, BertForSequenceClassification, AdamW\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import numpy as np\n",
    "import torch.nn.functional as F\n",
    "import os\n",
    "\n",
    "group_number = 4\n",
    "train_path = f'../../data_preprocess/dataset/{group_number}/train.csv'\n",
    "valid_path = f'../../data_preprocess/dataset/{group_number}/valid.csv'\n",
    "test_path = f'../../translation/0.result/{group_number}/test_p.csv'\n",
    "output_path = f'0.class_document/bert-tiny/{group_number}/test_p_c.csv'  # changed output directory\n",
    "\n",
    "# Read the data\n",
    "train_data = pd.read_csv(train_path)\n",
    "valid_data = pd.read_csv(valid_path)\n",
    "test_data = pd.read_csv(test_path)\n",
    "\n",
    "# Add the thing_property field\n",
    "train_data['thing_property'] = train_data['thing'] + '_' + train_data['property']\n",
    "valid_data['thing_property'] = valid_data['thing'] + '_' + valid_data['property']\n",
    "test_data['thing_property'] = test_data['thing'] + '_' + test_data['property']\n",
    "\n",
    "# Set up the tokenizer and label encoder\n",
    "tokenizer = BertTokenizer.from_pretrained('prajjwal1/bert-tiny')  # changed model\n",
    "label_encoder = LabelEncoder()\n",
    "label_encoder.fit(train_data['thing_property'])\n",
    "\n",
    "# Map labels unseen during training to a placeholder\n",
    "valid_data['thing_property'] = valid_data['thing_property'].apply(\n",
    "    lambda x: x if x in label_encoder.classes_ else 'unknown_label')\n",
    "test_data['thing_property'] = test_data['thing_property'].apply(\n",
    "    lambda x: x if x in label_encoder.classes_ else 'unknown_label')\n",
    "\n",
    "# Append 'unknown_label' to the known classes\n",
    "label_encoder.classes_ = np.append(label_encoder.classes_, 'unknown_label')\n",
    "\n",
    "# Encode the labels\n",
    "train_data['label'] = label_encoder.transform(train_data['thing_property'])\n",
    "valid_data['label'] = label_encoder.transform(valid_data['thing_property'])\n",
    "test_data['label'] = label_encoder.transform(test_data['thing_property'])\n",
    "\n",
    "# Prepare texts and labels\n",
    "train_texts, train_labels = train_data['tag_description'], train_data['label']\n",
    "valid_texts, valid_labels = valid_data['tag_description'], valid_data['label']\n",
    "\n",
    "# Encode the texts\n",
    "train_encodings = tokenizer(list(train_texts), truncation=True, padding=True, return_tensors='pt')\n",
    "valid_encodings = tokenizer(list(valid_texts), truncation=True, padding=True, return_tensors='pt')\n",
    "\n",
    "# Convert labels to tensors\n",
    "train_labels = torch.tensor(train_labels.values)\n",
    "valid_labels = torch.tensor(valid_labels.values)\n",
    "\n",
    "# Define a custom dataset\n",
    "class CustomDataset(Dataset):\n",
    "    def __init__(self, encodings, labels):\n",
    "        self.encodings = encodings\n",
    "        self.labels = labels\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        item = {key: val[idx] for key, val in self.encodings.items()}\n",
    "        item['labels'] = self.labels[idx]\n",
    "        return item\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.labels)\n",
    "\n",
    "# Create the datasets\n",
    "train_dataset = CustomDataset(train_encodings, train_labels)\n",
    "valid_dataset = CustomDataset(valid_encodings, valid_labels)\n",
    "\n",
    "# Create the dataloaders\n",
    "train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
    "valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)\n",
    "\n",
    "# Set up the model and optimizer\n",
    "model = BertForSequenceClassification.from_pretrained(\n",
    "    'prajjwal1/bert-tiny',  # changed model\n",
    "    num_labels=len(train_data['thing_property'].unique())\n",
    ")\n",
    "optimizer = AdamW(model.parameters(), lr=2e-4)\n",
    "\n",
    "# Device setup (use GPU 1)\n",
    "device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')\n",
    "model.to(device)\n",
    "\n",
    "epochs = 100\n",
    "for epoch in range(epochs):\n",
    "    model.train()\n",
    "    for batch in train_loader:\n",
    "        optimizer.zero_grad()\n",
    "        input_ids = batch['input_ids'].to(device)\n",
    "        attention_mask = batch['attention_mask'].to(device)\n",
    "        labels = batch['labels'].to(device)\n",
    "        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
    "        loss = outputs.loss\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    print(f\"Epoch {epoch + 1} completed. Loss: {loss.item()}\")\n",
    "\n",
    "    # Validation loop\n",
    "    model.eval()\n",
    "    correct, total = 0, 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch in valid_loader:\n",
    "            input_ids = batch['input_ids'].to(device)\n",
    "            attention_mask = batch['attention_mask'].to(device)\n",
    "            labels = batch['labels'].to(device)\n",
    "            outputs = model(input_ids, attention_mask=attention_mask)\n",
    "            predictions = torch.argmax(outputs.logits, dim=-1)\n",
    "            correct += (predictions == labels).sum().item()\n",
    "            total += labels.size(0)\n",
    "\n",
    "    valid_accuracy = correct / total\n",
    "    print(f'Validation Accuracy after Epoch {epoch + 1}: {valid_accuracy * 100:.2f}%')\n",
    "\n",
    "# Predict on the test data and add c_thing / c_property columns\n",
    "test_encodings = tokenizer(list(test_data['tag_description']), truncation=True, padding=True, return_tensors='pt')\n",
    "test_dataset = CustomDataset(test_encodings, torch.zeros(len(test_data)))  # labels are unused here, so dummy zeros\n",
    "\n",
    "test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)\n",
    "\n",
    "model.eval()\n",
    "predicted_thing_properties = []\n",
    "predicted_scores = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for batch in test_loader:\n",
    "        input_ids = batch['input_ids'].to(device)\n",
    "        attention_mask = batch['attention_mask'].to(device)\n",
    "        outputs = model(input_ids, attention_mask=attention_mask)\n",
    "        softmax_scores = F.softmax(outputs.logits, dim=-1)\n",
    "        predictions = torch.argmax(softmax_scores, dim=-1)\n",
    "        predicted_thing_properties.extend(predictions.cpu().numpy())\n",
    "        predicted_scores.extend(softmax_scores[range(len(predictions)), predictions].cpu().numpy())\n",
    "\n",
    "# Decode the predicted indices back to thing_property labels\n",
    "predicted_thing_property_labels = label_encoder.inverse_transform(predicted_thing_properties)\n",
    "\n",
    "# Split thing_property into thing and property\n",
    "test_data['c_thing'] = [x.split('_')[0] for x in predicted_thing_property_labels]\n",
    "test_data['c_property'] = [x.split('_')[1] for x in predicted_thing_property_labels]\n",
    "test_data['c_score'] = predicted_scores\n",
    "\n",
    "test_data['cthing_correct'] = test_data['thing'] == test_data['c_thing']\n",
    "test_data['cproperty_correct'] = test_data['property'] == test_data['c_property']\n",
    "test_data['ctp_correct'] = test_data['cthing_correct'] & test_data['cproperty_correct']\n",
    "\n",
    "mdm_true_count = len(test_data[test_data['MDM'] == True])\n",
    "accuracy = (test_data['ctp_correct'].sum() / mdm_true_count) * 100\n",
    "\n",
    "print(f\"Accuracy (MDM=True) for Group {group_number}: {accuracy:.2f}%\")\n",
    "\n",
    "# Make sure the output directory exists before saving; create it if missing\n",
    "os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
    "\n",
    "test_data.to_csv(output_path, index=False)\n",
    "print(f'Results saved to {output_path}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "<base64 PNG truncated in this copy: t-SNE scatter plot of BERT embeddings colored by pattern, 1400x700, Matplotlib 3.9.0>",
      "text/plain": [
       "<Figure size 1400x700 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from sklearn.manifold import TSNE\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Read 'filtered_data_plot.csv'\n",
    "filtered_data = pd.read_csv('filtered_data_plot.csv')\n",
    "\n",
    "# Tokenize the data\n",
    "filtered_encodings = tokenizer(list(filtered_data['tag_description']), truncation=True, padding=True, return_tensors='pt')\n",
    "\n",
    "# Function to compute BERT embeddings\n",
    "def get_bert_embeddings(model, encodings, device):\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        input_ids = encodings['input_ids'].to(device)\n",
    "        attention_mask = encodings['attention_mask'].to(device)\n",
    "        outputs = model.bert(input_ids=input_ids, attention_mask=attention_mask)\n",
    "        return outputs.last_hidden_state.mean(dim=1).cpu().numpy()  # mean embedding for each sentence\n",
    "\n",
    "# Compute embeddings with the BERT model\n",
    "bert_embeddings = get_bert_embeddings(model, filtered_encodings, device)\n",
    "\n",
    "# t-SNE dimensionality reduction\n",
    "tsne = TSNE(n_components=2, random_state=42)\n",
    "tsne_results = tsne.fit_transform(bert_embeddings)\n",
    "\n",
    "# Prepare for visualization\n",
    "unique_patterns = filtered_data['pattern'].unique()\n",
    "color_map = plt.get_cmap('tab20', len(unique_patterns))\n",
    "pattern_to_color = {pattern: idx for idx, pattern in enumerate(unique_patterns)}\n",
    "\n",
    "plt.figure(figsize=(14, 7))\n",
    "\n",
    "# Plot each pattern separately\n",
    "for pattern, color_idx in pattern_to_color.items():\n",
    "    pattern_indices = filtered_data['pattern'] == pattern\n",
    "    plt.scatter(tsne_results[pattern_indices, 0], tsne_results[pattern_indices, 1],\n",
    "                color=color_map(color_idx), marker='o', s=100, alpha=0.6, edgecolor='k', linewidth=1.2)\n",
    "\n",
    "# Figure settings\n",
    "plt.xticks(fontsize=24)\n",
    "plt.yticks(fontsize=24)\n",
    "plt.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.6)\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}