hipom_data_mapping/post_process/tfidf_class/2a.classifier_bert.ipynb

415 lines
556 KiB
Plaintext
Raw Normal View History

2024-09-25 08:52:30 +09:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/optimization.py:521: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1 completed. Loss: 5.812545299530029\n",
"Validation Accuracy after Epoch 1: 2.20%\n",
"Epoch 2 completed. Loss: 5.4337921142578125\n",
"Validation Accuracy after Epoch 2: 12.22%\n",
"Epoch 3 completed. Loss: 4.7191081047058105\n",
"Validation Accuracy after Epoch 3: 18.78%\n",
"Epoch 4 completed. Loss: 3.5866851806640625\n",
"Validation Accuracy after Epoch 4: 27.73%\n",
"Epoch 5 completed. Loss: 2.891603469848633\n",
"Validation Accuracy after Epoch 5: 40.61%\n",
"Epoch 6 completed. Loss: 3.5778417587280273\n",
"Validation Accuracy after Epoch 6: 50.07%\n",
"Epoch 7 completed. Loss: 2.8838517665863037\n",
"Validation Accuracy after Epoch 7: 64.50%\n",
"Epoch 8 completed. Loss: 1.2843190431594849\n",
"Validation Accuracy after Epoch 8: 69.37%\n",
"Epoch 9 completed. Loss: 2.803881883621216\n",
"Validation Accuracy after Epoch 9: 76.16%\n",
"Epoch 10 completed. Loss: 1.6859067678451538\n",
"Validation Accuracy after Epoch 10: 77.75%\n",
"Epoch 11 completed. Loss: 1.5909239053726196\n",
"Validation Accuracy after Epoch 11: 80.19%\n",
"Epoch 12 completed. Loss: 3.397331953048706\n",
"Validation Accuracy after Epoch 12: 81.31%\n",
"Epoch 13 completed. Loss: 2.6174156665802\n",
"Validation Accuracy after Epoch 13: 83.14%\n",
"Epoch 14 completed. Loss: 2.170588493347168\n",
"Validation Accuracy after Epoch 14: 85.34%\n",
"Epoch 15 completed. Loss: 0.20281337201595306\n",
"Validation Accuracy after Epoch 15: 86.04%\n",
"Epoch 16 completed. Loss: 0.688520610332489\n",
"Validation Accuracy after Epoch 16: 86.74%\n",
"Epoch 17 completed. Loss: 2.2658097743988037\n",
"Validation Accuracy after Epoch 17: 87.73%\n",
"Epoch 18 completed. Loss: 1.1048349142074585\n",
"Validation Accuracy after Epoch 18: 88.06%\n",
"Epoch 19 completed. Loss: 0.42507076263427734\n",
"Validation Accuracy after Epoch 19: 88.24%\n",
"Epoch 20 completed. Loss: 0.2606792747974396\n",
"Validation Accuracy after Epoch 20: 88.71%\n",
"Epoch 21 completed. Loss: 0.31851115822792053\n",
"Validation Accuracy after Epoch 21: 89.18%\n",
"Epoch 22 completed. Loss: 0.09924610704183578\n",
"Validation Accuracy after Epoch 22: 89.46%\n",
"Epoch 23 completed. Loss: 0.09280075877904892\n",
"Validation Accuracy after Epoch 23: 90.26%\n",
"Epoch 24 completed. Loss: 0.12750865519046783\n",
"Validation Accuracy after Epoch 24: 89.93%\n",
"Epoch 25 completed. Loss: 0.06864642351865768\n",
"Validation Accuracy after Epoch 25: 90.40%\n",
"Epoch 26 completed. Loss: 1.2031394243240356\n",
"Validation Accuracy after Epoch 26: 90.87%\n",
"Epoch 27 completed. Loss: 0.049047697335481644\n",
"Validation Accuracy after Epoch 27: 90.82%\n",
"Epoch 28 completed. Loss: 0.2666439712047577\n",
"Validation Accuracy after Epoch 28: 91.10%\n",
"Epoch 29 completed. Loss: 1.1274741888046265\n",
"Validation Accuracy after Epoch 29: 91.38%\n",
"Epoch 30 completed. Loss: 0.040213990956544876\n",
"Validation Accuracy after Epoch 30: 91.66%\n",
"Epoch 31 completed. Loss: 0.04501065984368324\n",
"Validation Accuracy after Epoch 31: 91.71%\n",
"Epoch 32 completed. Loss: 0.031954798847436905\n",
"Validation Accuracy after Epoch 32: 92.37%\n",
"Epoch 33 completed. Loss: 0.13622191548347473\n",
"Validation Accuracy after Epoch 33: 92.32%\n",
"Epoch 34 completed. Loss: 0.22619958221912384\n",
"Validation Accuracy after Epoch 34: 92.18%\n",
"Epoch 35 completed. Loss: 0.03372799605131149\n",
"Validation Accuracy after Epoch 35: 92.37%\n",
"Epoch 36 completed. Loss: 0.04779297485947609\n",
"Validation Accuracy after Epoch 36: 92.41%\n",
"Epoch 37 completed. Loss: 0.02235586754977703\n",
"Validation Accuracy after Epoch 37: 92.88%\n",
"Epoch 38 completed. Loss: 0.07272133976221085\n",
"Validation Accuracy after Epoch 38: 92.97%\n",
"Epoch 39 completed. Loss: 0.19073285162448883\n",
"Validation Accuracy after Epoch 39: 92.93%\n",
"Epoch 40 completed. Loss: 0.030236737802624702\n",
"Validation Accuracy after Epoch 40: 93.16%\n",
"Epoch 41 completed. Loss: 0.048501163721084595\n",
"Validation Accuracy after Epoch 41: 93.58%\n",
"Epoch 42 completed. Loss: 0.3853774964809418\n",
"Validation Accuracy after Epoch 42: 93.77%\n",
"Epoch 43 completed. Loss: 0.03214043006300926\n",
"Validation Accuracy after Epoch 43: 94.05%\n",
"Epoch 44 completed. Loss: 0.03621528670191765\n",
"Validation Accuracy after Epoch 44: 93.72%\n",
"Epoch 45 completed. Loss: 0.12950848042964935\n",
"Validation Accuracy after Epoch 45: 94.00%\n",
"Epoch 46 completed. Loss: 0.9027665257453918\n",
"Validation Accuracy after Epoch 46: 94.05%\n",
"Epoch 47 completed. Loss: 0.014634504914283752\n",
"Validation Accuracy after Epoch 47: 93.86%\n",
"Epoch 48 completed. Loss: 0.019594205543398857\n",
"Validation Accuracy after Epoch 48: 94.15%\n",
"Epoch 49 completed. Loss: 0.05953751504421234\n",
"Validation Accuracy after Epoch 49: 94.15%\n",
"Epoch 50 completed. Loss: 0.01590300165116787\n",
"Validation Accuracy after Epoch 50: 94.38%\n",
"Epoch 51 completed. Loss: 0.015979576855897903\n",
"Validation Accuracy after Epoch 51: 94.05%\n",
"Epoch 52 completed. Loss: 0.10404151678085327\n",
"Validation Accuracy after Epoch 52: 94.19%\n",
"Epoch 53 completed. Loss: 0.03786793723702431\n",
"Validation Accuracy after Epoch 53: 94.19%\n",
"Epoch 54 completed. Loss: 0.06969229131937027\n",
"Validation Accuracy after Epoch 54: 94.33%\n",
"Epoch 55 completed. Loss: 0.028161363676190376\n",
"Validation Accuracy after Epoch 55: 93.96%\n",
"Epoch 56 completed. Loss: 0.008608573116362095\n",
"Validation Accuracy after Epoch 56: 94.05%\n",
"Epoch 57 completed. Loss: 0.04131948947906494\n",
"Validation Accuracy after Epoch 57: 94.38%\n",
"Epoch 58 completed. Loss: 0.057713668793439865\n",
"Validation Accuracy after Epoch 58: 94.29%\n",
"Epoch 59 completed. Loss: 0.006651934236288071\n",
"Validation Accuracy after Epoch 59: 94.05%\n",
"Epoch 60 completed. Loss: 0.009415321983397007\n",
"Validation Accuracy after Epoch 60: 94.29%\n",
"Epoch 61 completed. Loss: 0.07291022688150406\n",
"Validation Accuracy after Epoch 61: 94.33%\n",
"Epoch 62 completed. Loss: 0.019646339118480682\n",
"Validation Accuracy after Epoch 62: 94.47%\n",
"Epoch 63 completed. Loss: 0.005233598407357931\n",
"Validation Accuracy after Epoch 63: 94.33%\n",
"Epoch 64 completed. Loss: 0.006535904016345739\n",
"Validation Accuracy after Epoch 64: 94.38%\n",
"Epoch 65 completed. Loss: 0.04618072509765625\n",
"Validation Accuracy after Epoch 65: 94.52%\n",
"Epoch 66 completed. Loss: 0.003822903148829937\n",
"Validation Accuracy after Epoch 66: 94.52%\n",
"Epoch 67 completed. Loss: 0.007317937910556793\n",
"Validation Accuracy after Epoch 67: 94.38%\n",
"Epoch 68 completed. Loss: 0.043759193271398544\n",
"Validation Accuracy after Epoch 68: 94.33%\n",
"Epoch 69 completed. Loss: 0.005311332643032074\n",
"Validation Accuracy after Epoch 69: 94.24%\n",
"Epoch 70 completed. Loss: 0.014933265745639801\n",
"Validation Accuracy after Epoch 70: 94.15%\n",
"Epoch 71 completed. Loss: 0.017947230488061905\n",
"Validation Accuracy after Epoch 71: 94.43%\n",
"Epoch 72 completed. Loss: 0.03381676599383354\n",
"Validation Accuracy after Epoch 72: 94.38%\n",
"Epoch 73 completed. Loss: 0.0038632701616734266\n",
"Validation Accuracy after Epoch 73: 94.52%\n",
"Epoch 74 completed. Loss: 0.008432925678789616\n",
"Validation Accuracy after Epoch 74: 94.47%\n",
"Epoch 75 completed. Loss: 0.009978505782783031\n",
"Validation Accuracy after Epoch 75: 94.33%\n",
"Epoch 76 completed. Loss: 0.0022784313187003136\n",
"Validation Accuracy after Epoch 76: 94.43%\n",
"Epoch 77 completed. Loss: 0.002989798551425338\n",
"Validation Accuracy after Epoch 77: 94.61%\n",
"Epoch 78 completed. Loss: 0.008676871657371521\n",
"Validation Accuracy after Epoch 78: 94.66%\n",
"Epoch 79 completed. Loss: 0.009547003544867039\n",
"Validation Accuracy after Epoch 79: 94.10%\n",
"Epoch 80 completed. Loss: 0.016957035288214684\n",
"Validation Accuracy after Epoch 80: 93.54%\n",
"Accuracy (MDM=True) for Group 3: 94.33%\n",
"Results saved to 0.class_document/3/test_p_c.csv\n"
]
}
],
"source": [
"import pandas as pd\n",
"from transformers import BertTokenizer, BertForSequenceClassification, AdamW\n",
"from sklearn.preprocessing import LabelEncoder\n",
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"import numpy as np\n",
"import torch.nn.functional as F\n",
"\n",
"group_number = 3\n",
"train_path = f'../../data_preprocess/dataset/{group_number}/train.csv'\n",
"valid_path = f'../../data_preprocess/dataset/{group_number}/valid.csv'\n",
"test_path = f'../../translation/0.result/{group_number}/test_p.csv'\n",
"output_path = f'0.class_document/bert/{group_number}/test_p_c.csv'\n",
"\n",
"train_data = pd.read_csv(train_path)\n",
"valid_data = pd.read_csv(valid_path)\n",
"test_data = pd.read_csv(test_path)\n",
"\n",
"train_data['thing_property'] = train_data['thing'] + '_' + train_data['property']\n",
"valid_data['thing_property'] = valid_data['thing'] + '_' + valid_data['property']\n",
"test_data['thing_property'] = test_data['thing'] + '_' + test_data['property']\n",
"\n",
"\n",
"tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
"label_encoder = LabelEncoder()\n",
"label_encoder.fit(train_data['thing_property'])\n",
"\n",
"valid_data['thing_property'] = valid_data['thing_property'].apply(\n",
" lambda x: x if x in label_encoder.classes_ else 'unknown_label')\n",
"test_data['thing_property'] = test_data['thing_property'].apply(\n",
" lambda x: x if x in label_encoder.classes_ else 'unknown_label')\n",
"\n",
"label_encoder.classes_ = np.append(label_encoder.classes_, 'unknown_label')\n",
"\n",
"train_data['label'] = label_encoder.transform(train_data['thing_property'])\n",
"valid_data['label'] = label_encoder.transform(valid_data['thing_property'])\n",
"test_data['label'] = label_encoder.transform(test_data['thing_property'])\n",
"\n",
"train_texts, train_labels = train_data['tag_description'], train_data['label']\n",
"valid_texts, valid_labels = valid_data['tag_description'], valid_data['label']\n",
"\n",
"train_encodings = tokenizer(list(train_texts), truncation=True, padding=True, return_tensors='pt')\n",
"valid_encodings = tokenizer(list(valid_texts), truncation=True, padding=True, return_tensors='pt')\n",
"\n",
"train_labels = torch.tensor(train_labels.values)\n",
"valid_labels = torch.tensor(valid_labels.values)\n",
"\n",
"class CustomDataset(Dataset):\n",
" def __init__(self, encodings, labels):\n",
" self.encodings = encodings\n",
" self.labels = labels\n",
"\n",
" def __getitem__(self, idx):\n",
" item = {key: val[idx] for key, val in self.encodings.items()}\n",
" item['labels'] = self.labels[idx]\n",
" return item\n",
"\n",
" def __len__(self):\n",
" return len(self.labels)\n",
"\n",
"train_dataset = CustomDataset(train_encodings, train_labels)\n",
"valid_dataset = CustomDataset(valid_encodings, valid_labels)\n",
"\n",
"train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
"valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)\n",
"\n",
"model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(train_data['thing_property'].unique()))\n",
"optimizer = AdamW(model.parameters(), lr=3e-5)\n",
"\n",
"device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')\n",
"model.to(device)\n",
"\n",
"epochs = 80\n",
"for epoch in range(epochs):\n",
" model.train()\n",
" for batch in train_loader:\n",
" optimizer.zero_grad()\n",
" input_ids = batch['input_ids'].to(device)\n",
" attention_mask = batch['attention_mask'].to(device)\n",
" labels = batch['labels'].to(device)\n",
" outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
" loss = outputs.loss\n",
" loss.backward()\n",
" optimizer.step()\n",
" print(f\"Epoch {epoch + 1} completed. Loss: {loss.item()}\")\n",
"\n",
" # 검증 루프\n",
" model.eval()\n",
" correct, total = 0, 0\n",
"\n",
" with torch.no_grad():\n",
" for batch in valid_loader:\n",
" input_ids = batch['input_ids'].to(device)\n",
" attention_mask = batch['attention_mask'].to(device)\n",
" labels = batch['labels'].to(device)\n",
" outputs = model(input_ids, attention_mask=attention_mask)\n",
" predictions = torch.argmax(outputs.logits, dim=-1)\n",
" correct += (predictions == labels).sum().item()\n",
" total += labels.size(0)\n",
"\n",
" valid_accuracy = correct / total\n",
" print(f'Validation Accuracy after Epoch {epoch + 1}: {valid_accuracy * 100:.2f}%')\n",
"\n",
"# Test 데이터 예측 및 c_thing, c_property 추가\n",
"test_encodings = tokenizer(list(test_data['tag_description']), truncation=True, padding=True, return_tensors='pt')\n",
"test_dataset = CustomDataset(test_encodings, torch.zeros(len(test_data))) \n",
"\n",
"test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)\n",
"\n",
"model.eval()\n",
"predicted_thing_properties = []\n",
"predicted_scores = []\n",
"\n",
"with torch.no_grad():\n",
" for batch in test_loader:\n",
" input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)\n",
" outputs = model(input_ids, attention_mask=attention_mask)\n",
" softmax_scores = F.softmax(outputs.logits, dim=-1)\n",
" predictions = torch.argmax(softmax_scores, dim=-1)\n",
" predicted_thing_properties.extend(predictions.cpu().numpy())\n",
" predicted_scores.extend(softmax_scores[range(len(predictions)), predictions].cpu().numpy())\n",
"\n",
"predicted_thing_property_labels = label_encoder.inverse_transform(predicted_thing_properties)\n",
"\n",
"test_data['c_thing'] = [x.split('_')[0] for x in predicted_thing_property_labels]\n",
"test_data['c_property'] = [x.split('_')[1] for x in predicted_thing_property_labels]\n",
"test_data['c_score'] = predicted_scores\n",
"\n",
"test_data['cthing_correct'] = test_data['thing'] == test_data['c_thing']\n",
"test_data['cproperty_correct'] = test_data['property'] == test_data['c_property']\n",
"test_data['ctp_correct'] = test_data['cthing_correct'] & test_data['cproperty_correct']\n",
"\n",
"mdm_true_count = len(test_data[test_data['MDM'] == True])\n",
"accuracy = (test_data['ctp_correct'].sum() / mdm_true_count) * 100\n",
"\n",
"print(f\"Accuracy (MDM=True) for Group {group_number}: {accuracy:.2f}%\")\n",
"\n",
"test_data.to_csv(output_path, index=False)\n",
"print(f'Results saved to {output_path}')\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABW4AAAKyCAYAAABFb0fEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3Rc1dX38e+dXjSjLlm92HKRu4w7uGEwmGpTXOghzhNIe0IS3vQESAVCevIkBAjNFAMmFNNkY7n33lSsrlFvI2n63Pv+IVuxsQGDAHns/VmLtSzp3jtnjuUfd/ac2UfRNE1DCCGEEEIIIYQQQgghxFlDN9ADEEIIIYQQQgghhBBCCHEyKdwKIYQQQgghhBBCCCHEWUYKt0IIIYQQQgghhBBCCHGWkcKtEEIIIYQQQgghhBBCnGWkcCuEEEIIIYQQQgghhBBnGSncCiGEEEIIIYQQQgghxFlGCrdCCCGEEEIIIYQQQghxlpHCrRBCCCGEEEIIIYQQQpxlDAM9APHpqKqKy+XC4XCgKMpAD0cIIYQQQgghhBBCCPExNE2jq6uL1NRUdLqPXlMrhdsI5XK5yMjIGOhhCCGEEEIIIYQQQgghPqGamhrS09M/8hgp3EYoh8MB9P4lO53OAR7N2cvlcpGamjrQwxBCRDDJESFEf0mOCCH6S3JECNFfkiNnD7fbTUZGRl9t76NI4TZCHW+P4HQ6pXD7Eex2O3q9fqCHIYSIYJIjQoj+khwRQvSX5IgQor8kR84+Z9L6VDYnE+e0qqqqgR6CECLCSY4IIfpLckQI0V+SI0KI/pIciUxSuBVCCCGEEEIIIYQQQoizjBRuxTktNjZ2oIcghIhwkiNCiP6SHBFC9JfkiBCivyRHIpMUbsU5Tfq3CCH6S3JECNFfkiNCiP6SHBFC9JfkSGSSwq04p7W0tAz0EIQQEU5yRAjRX5IjQoj+khwRQvSX5EhkksKtEEIIIYQQQgghhBBCnGWkcCvOaenp6QM9BCFEhJMcEUL0l+SIEKK/JEeEEP0lORKZpHArzmltbW0DPQQhRISTHBFC9JfkiBCivyRHhBD9JTkSmaRwK85pHo9noIcghIhwkiNCiP6SHBFC9JfkiBCivyRHIpMUbsU5zWg0DvQQhBARTnJECNFfkiNCiP6SHBFC9JfkSGSSwq04p6WlpQ30EIQQEU5yRAjRX5IjQoj+khwRQvSX5EhkksKtOKdVVlYO9BCEEBFOckQI0V+SI0KI/pIcEUL0l+RIZJLCrRBCCCGEEEIIIYQQQpxlpHArzmkxMTEDPQQhRISTHBFC9JfkiBCivyRHhBD9JTkSmaRwK85pJpNpoIcghIhwkiNCiP6SHBFC9JfkiBCivyRHIpMUbsU5rampaaCHIISIcJIjQoj+khwRQvSX5IgQor8kRyKTFG6FEEIIIYQQQgghhBDiLCOFW3FOS0tLG+ghCCEinOSIEKK/JEeEEP0lOSKE6C/JkchkGOgBCPF56ujoIDk5eaCHIYSIYJIjQpy//H4/LpcLn8+HxWIhNTUVs9l80jHNzc2sWbOGtrY24uLimDNnDomJiScdIzkihOgvyREhRH9JjkQmKdyKc1pPT89AD0EIEeEkR4Q4/7hcLoqKitiwcRPdHi+aBooCDruN6dOmMnPmTPbv388jjzzC3n37MVns6HR6VDVMwOdh7JhR3HPPPcybNw+QHBFC9J/kiBCivyRHIpMUbsU5zWCQX3EhRP9IjghxfiksLOSJJ5+iutaF2epgyNB8ho8ah0Gn49D+3bz+ViH33f8Ara3txA9KZ/TE2QwbVYDVHoW3p5sjB3ZSVVnKbXfcyVVXXM6jjz4qOSKE6DfJESFEf0mORCZF0zRtoAchPjm32010dDSdnZ04nc6BHo4QQgghRMT71a9+xd///n/0eP1ExyagNxjQVJVQMMCI0eO4cemXePM/L7Jp/RqGjpzA5dfdRv6YiSi6/24boakqZUf28caLj1G8fzvz513Mo48+OoDPSgghhBBCnE0+SU1PCrcRSgq3Z6a8vJzc3NyBHoYQIoJJjghxfli6dClvvfMeSalZZA3OZ9joAqy2KLyebor376Suqoy25gZ83h4Kpl3M9bd/C4vFRmxcb4H3g7o7O3j09z/h0O6NPPboP7jiiisG4FkJIc4Vcj8ihOgvyZGzxyep6ck6aSGEEEIIcV5btmwZa4o2MmrChVx48TWkZuZisliISxiE0WRmzvzFVJQc4JWn/8LRI3tJzRhMYnIaXZ3t+Lwe7I5Tb7ijomO48sY7cdWUs2bNmjMq3J7JZmhCCCGEEOL8IYVbcU6Ljo4e6CEIISKc5IgQ57Z33nmHV197g4RBmRiMJnZuXsPOzWsAMJnNDBs1gZHjp5I1ZDgXXnINBpOJkoO7aKqvJcoRjc/nwx7lAEVh/67NrHzmb3S5O3A4Y7j2pq+Snp3H1m3baW5uJjEx8bRjOHEzNHdXD16fD1VVcditXDxnNnPnziU1NfWLnBYhxFlG7keEEP0lORKZpFVChJJWCWemp6cHu90+0MMQQkQwyREhzm0XXHABNa4mhuSPJzNnKEPyx2O22PD7PJQd3ktrUz2KojC6YCpGs5VgMMDaVS8ydNQErl78Fbyebp79x2/Y8N6rgEJ0XAJ6nYGwGqKzrQU1HCI1NY3f/vbXLFq06JTHLyws5PkXVtDu9mC0RhGbkILOYCTg91NbWUprUx1Om5kv33kHixcv/sLnRwhxdpD7ESFEf0mOnD2kVYIQxzQ2NkoPFyFEv0iOCHHueumll6iqcTFqwoVcs/Ru0rIGYzCa+n4+fsocqo8eZvUbz7Fm1Qpyh40mZ+goLFY7q1Y8Rm1lKbs2r0FRFDJyh5OZO5y8/PF9vXFLDu6ipqKYUcNyWbx48SmF28LCQp5Z/gIBzcSICTOIjkvGYrUTGxeP3mAgFAxSeuQAm99/k1/95iGqq6u59957v+hpEkKcBeR+RAjRX5IjkUkKt0IIIYQQ4rzjcrn4xz8fZUj+eK644csMSs9CUZSTjuloa8ZksXLJNbdQ+NqzNNRVUjD1YkZfcCGN9VXs2rwas9XG6ILpzL/hSwwePg6dTtd3/iXX3EzZ4T24Dm8gLjEFRVE4/mG3iooKHv3XY3T7YeSEC4lLTCM3bzj2qJNXXQxKzSB/9Hj+/bff8uTTy8nMzJSVt0IIIYQQ5wndxx8iRORKSUkZ6CEIISKc5IgQ56aioiKCqp5R46cd21xMIRwO9/28vbUJv8+D2WwlKTWTyxbejsVqp762gpSMHKJjE9DrDYwumM7Sr36frMH5JxVtARRFIS9/PAUXL2L0hOlEOWPJzc3lueee466772bfwWJaWlvZXPQeRe+8wrb1hbQ2NZwy1vjEQdx429ewOeP5xz8fxeVyfd7TI4Q4y8j9iBCivyRHIpMUbsU5raura6CHIISIcJIjQpx7/H4/GzZuIiU9G2dsAuFQCINej6apqGoYn7cHn7cHk8lCdFwiZrOFQenZJCSnUnpoF91dnTTWVRGflML8G+7E4YwjFAoSDod6H0DT+O82EhoOm5n5N3yJQelZVFRU8J8338VgS2T63GuYc+USZl12HVExCezevoGn//kQu7YUnTLm9KzB5I0YQ0NjM4WFhV/cZAkhzgpyPyKE6C/JkcgkhVtxTuvu7h7oIQghIpzkiBDnHpfLRbfHy8SpF6KpITraW9Ab9ACEgkG63R3o9Qbszhj0+t7vK4rC4OHjCAYC7N1aRFtLE9lDRjJ4+FiMJhMKCqFgAE3T+ODOv0rIS+6wsWTkDMNssZI5ZDSzL7+OMRNnkD92CpNmzGPRl77Nglu+TnRsEkXv/eeU4q2iKIwcNxmLzUnh6tX4/f4vYqqEEGcJuR8RQvSX5EhkksKtOKcdf7ElhBCfluSIEOcen8+HpoEzOpb4+Di6uzpwd3ZgMBoJh0MEAwEMRhMmk7nvHE3TsFitBPw+aipLMZnNDBt9AYqioNPp0el0hMPhE1baAmiAAooOnU7H0JEFRMcm8vbrK/AHAoCCydz7GIqikJKezYJb7iYmLpl1ha+d0jbBYrHiiI6hs6t
"text/plain": [
"<Figure size 1400x700 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.manifold import TSNE\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# 'filtered_data_plot.csv' 읽기\n",
"filtered_data = pd.read_csv('filtered_data_plot.csv')\n",
"\n",
"# 데이터 토큰화\n",
"filtered_encodings = tokenizer(list(filtered_data['tag_description']), truncation=True, padding=True, return_tensors='pt')\n",
"\n",
"# BERT 임베딩 계산 함수\n",
"def get_bert_embeddings(model, encodings, device):\n",
" model.eval()\n",
" with torch.no_grad():\n",
" input_ids = encodings['input_ids'].to(device)\n",
" attention_mask = encodings['attention_mask'].to(device)\n",
" outputs = model.bert(input_ids=input_ids, attention_mask=attention_mask)\n",
" return outputs.last_hidden_state.mean(dim=1).cpu().numpy() # 각 문장의 평균 임베딩 추출\n",
"\n",
"# BERT 모델로 임베딩 계산\n",
"bert_embeddings = get_bert_embeddings(model, filtered_encodings, device)\n",
"\n",
"# t-SNE 차원 축소\n",
"tsne = TSNE(n_components=2, random_state=42)\n",
"tsne_results = tsne.fit_transform(bert_embeddings)\n",
"\n",
"# 시각화를 위한 준비\n",
"unique_patterns = filtered_data['pattern'].unique()\n",
"color_map = plt.get_cmap('tab20', len(unique_patterns))\n",
"pattern_to_color = {pattern: idx for idx, pattern in enumerate(unique_patterns)}\n",
"\n",
"plt.figure(figsize=(14, 7))\n",
"\n",
"# 각 패턴별로 시각화\n",
"for pattern, color_idx in pattern_to_color.items():\n",
" pattern_indices = filtered_data['pattern'] == pattern\n",
" plt.scatter(tsne_results[pattern_indices, 0], tsne_results[pattern_indices, 1], \n",
" color=color_map(color_idx), marker='o', s=100, alpha=0.6, edgecolor='k', linewidth=1.2)\n",
"\n",
"# 그래프 설정\n",
"plt.xticks(fontsize=24)\n",
"plt.yticks(fontsize=24)\n",
"plt.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.6)\n",
"plt.tight_layout()\n",
"plt.show()\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "torch",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}