415 lines
556 KiB
Plaintext
415 lines
556 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
|
||
|
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
|
||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/optimization.py:521: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
|
||
|
" warnings.warn(\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Epoch 1 completed. Loss: 5.812545299530029\n",
|
||
|
"Validation Accuracy after Epoch 1: 2.20%\n",
|
||
|
"Epoch 2 completed. Loss: 5.4337921142578125\n",
|
||
|
"Validation Accuracy after Epoch 2: 12.22%\n",
|
||
|
"Epoch 3 completed. Loss: 4.7191081047058105\n",
|
||
|
"Validation Accuracy after Epoch 3: 18.78%\n",
|
||
|
"Epoch 4 completed. Loss: 3.5866851806640625\n",
|
||
|
"Validation Accuracy after Epoch 4: 27.73%\n",
|
||
|
"Epoch 5 completed. Loss: 2.891603469848633\n",
|
||
|
"Validation Accuracy after Epoch 5: 40.61%\n",
|
||
|
"Epoch 6 completed. Loss: 3.5778417587280273\n",
|
||
|
"Validation Accuracy after Epoch 6: 50.07%\n",
|
||
|
"Epoch 7 completed. Loss: 2.8838517665863037\n",
|
||
|
"Validation Accuracy after Epoch 7: 64.50%\n",
|
||
|
"Epoch 8 completed. Loss: 1.2843190431594849\n",
|
||
|
"Validation Accuracy after Epoch 8: 69.37%\n",
|
||
|
"Epoch 9 completed. Loss: 2.803881883621216\n",
|
||
|
"Validation Accuracy after Epoch 9: 76.16%\n",
|
||
|
"Epoch 10 completed. Loss: 1.6859067678451538\n",
|
||
|
"Validation Accuracy after Epoch 10: 77.75%\n",
|
||
|
"Epoch 11 completed. Loss: 1.5909239053726196\n",
|
||
|
"Validation Accuracy after Epoch 11: 80.19%\n",
|
||
|
"Epoch 12 completed. Loss: 3.397331953048706\n",
|
||
|
"Validation Accuracy after Epoch 12: 81.31%\n",
|
||
|
"Epoch 13 completed. Loss: 2.6174156665802\n",
|
||
|
"Validation Accuracy after Epoch 13: 83.14%\n",
|
||
|
"Epoch 14 completed. Loss: 2.170588493347168\n",
|
||
|
"Validation Accuracy after Epoch 14: 85.34%\n",
|
||
|
"Epoch 15 completed. Loss: 0.20281337201595306\n",
|
||
|
"Validation Accuracy after Epoch 15: 86.04%\n",
|
||
|
"Epoch 16 completed. Loss: 0.688520610332489\n",
|
||
|
"Validation Accuracy after Epoch 16: 86.74%\n",
|
||
|
"Epoch 17 completed. Loss: 2.2658097743988037\n",
|
||
|
"Validation Accuracy after Epoch 17: 87.73%\n",
|
||
|
"Epoch 18 completed. Loss: 1.1048349142074585\n",
|
||
|
"Validation Accuracy after Epoch 18: 88.06%\n",
|
||
|
"Epoch 19 completed. Loss: 0.42507076263427734\n",
|
||
|
"Validation Accuracy after Epoch 19: 88.24%\n",
|
||
|
"Epoch 20 completed. Loss: 0.2606792747974396\n",
|
||
|
"Validation Accuracy after Epoch 20: 88.71%\n",
|
||
|
"Epoch 21 completed. Loss: 0.31851115822792053\n",
|
||
|
"Validation Accuracy after Epoch 21: 89.18%\n",
|
||
|
"Epoch 22 completed. Loss: 0.09924610704183578\n",
|
||
|
"Validation Accuracy after Epoch 22: 89.46%\n",
|
||
|
"Epoch 23 completed. Loss: 0.09280075877904892\n",
|
||
|
"Validation Accuracy after Epoch 23: 90.26%\n",
|
||
|
"Epoch 24 completed. Loss: 0.12750865519046783\n",
|
||
|
"Validation Accuracy after Epoch 24: 89.93%\n",
|
||
|
"Epoch 25 completed. Loss: 0.06864642351865768\n",
|
||
|
"Validation Accuracy after Epoch 25: 90.40%\n",
|
||
|
"Epoch 26 completed. Loss: 1.2031394243240356\n",
|
||
|
"Validation Accuracy after Epoch 26: 90.87%\n",
|
||
|
"Epoch 27 completed. Loss: 0.049047697335481644\n",
|
||
|
"Validation Accuracy after Epoch 27: 90.82%\n",
|
||
|
"Epoch 28 completed. Loss: 0.2666439712047577\n",
|
||
|
"Validation Accuracy after Epoch 28: 91.10%\n",
|
||
|
"Epoch 29 completed. Loss: 1.1274741888046265\n",
|
||
|
"Validation Accuracy after Epoch 29: 91.38%\n",
|
||
|
"Epoch 30 completed. Loss: 0.040213990956544876\n",
|
||
|
"Validation Accuracy after Epoch 30: 91.66%\n",
|
||
|
"Epoch 31 completed. Loss: 0.04501065984368324\n",
|
||
|
"Validation Accuracy after Epoch 31: 91.71%\n",
|
||
|
"Epoch 32 completed. Loss: 0.031954798847436905\n",
|
||
|
"Validation Accuracy after Epoch 32: 92.37%\n",
|
||
|
"Epoch 33 completed. Loss: 0.13622191548347473\n",
|
||
|
"Validation Accuracy after Epoch 33: 92.32%\n",
|
||
|
"Epoch 34 completed. Loss: 0.22619958221912384\n",
|
||
|
"Validation Accuracy after Epoch 34: 92.18%\n",
|
||
|
"Epoch 35 completed. Loss: 0.03372799605131149\n",
|
||
|
"Validation Accuracy after Epoch 35: 92.37%\n",
|
||
|
"Epoch 36 completed. Loss: 0.04779297485947609\n",
|
||
|
"Validation Accuracy after Epoch 36: 92.41%\n",
|
||
|
"Epoch 37 completed. Loss: 0.02235586754977703\n",
|
||
|
"Validation Accuracy after Epoch 37: 92.88%\n",
|
||
|
"Epoch 38 completed. Loss: 0.07272133976221085\n",
|
||
|
"Validation Accuracy after Epoch 38: 92.97%\n",
|
||
|
"Epoch 39 completed. Loss: 0.19073285162448883\n",
|
||
|
"Validation Accuracy after Epoch 39: 92.93%\n",
|
||
|
"Epoch 40 completed. Loss: 0.030236737802624702\n",
|
||
|
"Validation Accuracy after Epoch 40: 93.16%\n",
|
||
|
"Epoch 41 completed. Loss: 0.048501163721084595\n",
|
||
|
"Validation Accuracy after Epoch 41: 93.58%\n",
|
||
|
"Epoch 42 completed. Loss: 0.3853774964809418\n",
|
||
|
"Validation Accuracy after Epoch 42: 93.77%\n",
|
||
|
"Epoch 43 completed. Loss: 0.03214043006300926\n",
|
||
|
"Validation Accuracy after Epoch 43: 94.05%\n",
|
||
|
"Epoch 44 completed. Loss: 0.03621528670191765\n",
|
||
|
"Validation Accuracy after Epoch 44: 93.72%\n",
|
||
|
"Epoch 45 completed. Loss: 0.12950848042964935\n",
|
||
|
"Validation Accuracy after Epoch 45: 94.00%\n",
|
||
|
"Epoch 46 completed. Loss: 0.9027665257453918\n",
|
||
|
"Validation Accuracy after Epoch 46: 94.05%\n",
|
||
|
"Epoch 47 completed. Loss: 0.014634504914283752\n",
|
||
|
"Validation Accuracy after Epoch 47: 93.86%\n",
|
||
|
"Epoch 48 completed. Loss: 0.019594205543398857\n",
|
||
|
"Validation Accuracy after Epoch 48: 94.15%\n",
|
||
|
"Epoch 49 completed. Loss: 0.05953751504421234\n",
|
||
|
"Validation Accuracy after Epoch 49: 94.15%\n",
|
||
|
"Epoch 50 completed. Loss: 0.01590300165116787\n",
|
||
|
"Validation Accuracy after Epoch 50: 94.38%\n",
|
||
|
"Epoch 51 completed. Loss: 0.015979576855897903\n",
|
||
|
"Validation Accuracy after Epoch 51: 94.05%\n",
|
||
|
"Epoch 52 completed. Loss: 0.10404151678085327\n",
|
||
|
"Validation Accuracy after Epoch 52: 94.19%\n",
|
||
|
"Epoch 53 completed. Loss: 0.03786793723702431\n",
|
||
|
"Validation Accuracy after Epoch 53: 94.19%\n",
|
||
|
"Epoch 54 completed. Loss: 0.06969229131937027\n",
|
||
|
"Validation Accuracy after Epoch 54: 94.33%\n",
|
||
|
"Epoch 55 completed. Loss: 0.028161363676190376\n",
|
||
|
"Validation Accuracy after Epoch 55: 93.96%\n",
|
||
|
"Epoch 56 completed. Loss: 0.008608573116362095\n",
|
||
|
"Validation Accuracy after Epoch 56: 94.05%\n",
|
||
|
"Epoch 57 completed. Loss: 0.04131948947906494\n",
|
||
|
"Validation Accuracy after Epoch 57: 94.38%\n",
|
||
|
"Epoch 58 completed. Loss: 0.057713668793439865\n",
|
||
|
"Validation Accuracy after Epoch 58: 94.29%\n",
|
||
|
"Epoch 59 completed. Loss: 0.006651934236288071\n",
|
||
|
"Validation Accuracy after Epoch 59: 94.05%\n",
|
||
|
"Epoch 60 completed. Loss: 0.009415321983397007\n",
|
||
|
"Validation Accuracy after Epoch 60: 94.29%\n",
|
||
|
"Epoch 61 completed. Loss: 0.07291022688150406\n",
|
||
|
"Validation Accuracy after Epoch 61: 94.33%\n",
|
||
|
"Epoch 62 completed. Loss: 0.019646339118480682\n",
|
||
|
"Validation Accuracy after Epoch 62: 94.47%\n",
|
||
|
"Epoch 63 completed. Loss: 0.005233598407357931\n",
|
||
|
"Validation Accuracy after Epoch 63: 94.33%\n",
|
||
|
"Epoch 64 completed. Loss: 0.006535904016345739\n",
|
||
|
"Validation Accuracy after Epoch 64: 94.38%\n",
|
||
|
"Epoch 65 completed. Loss: 0.04618072509765625\n",
|
||
|
"Validation Accuracy after Epoch 65: 94.52%\n",
|
||
|
"Epoch 66 completed. Loss: 0.003822903148829937\n",
|
||
|
"Validation Accuracy after Epoch 66: 94.52%\n",
|
||
|
"Epoch 67 completed. Loss: 0.007317937910556793\n",
|
||
|
"Validation Accuracy after Epoch 67: 94.38%\n",
|
||
|
"Epoch 68 completed. Loss: 0.043759193271398544\n",
|
||
|
"Validation Accuracy after Epoch 68: 94.33%\n",
|
||
|
"Epoch 69 completed. Loss: 0.005311332643032074\n",
|
||
|
"Validation Accuracy after Epoch 69: 94.24%\n",
|
||
|
"Epoch 70 completed. Loss: 0.014933265745639801\n",
|
||
|
"Validation Accuracy after Epoch 70: 94.15%\n",
|
||
|
"Epoch 71 completed. Loss: 0.017947230488061905\n",
|
||
|
"Validation Accuracy after Epoch 71: 94.43%\n",
|
||
|
"Epoch 72 completed. Loss: 0.03381676599383354\n",
|
||
|
"Validation Accuracy after Epoch 72: 94.38%\n",
|
||
|
"Epoch 73 completed. Loss: 0.0038632701616734266\n",
|
||
|
"Validation Accuracy after Epoch 73: 94.52%\n",
|
||
|
"Epoch 74 completed. Loss: 0.008432925678789616\n",
|
||
|
"Validation Accuracy after Epoch 74: 94.47%\n",
|
||
|
"Epoch 75 completed. Loss: 0.009978505782783031\n",
|
||
|
"Validation Accuracy after Epoch 75: 94.33%\n",
|
||
|
"Epoch 76 completed. Loss: 0.0022784313187003136\n",
|
||
|
"Validation Accuracy after Epoch 76: 94.43%\n",
|
||
|
"Epoch 77 completed. Loss: 0.002989798551425338\n",
|
||
|
"Validation Accuracy after Epoch 77: 94.61%\n",
|
||
|
"Epoch 78 completed. Loss: 0.008676871657371521\n",
|
||
|
"Validation Accuracy after Epoch 78: 94.66%\n",
|
||
|
"Epoch 79 completed. Loss: 0.009547003544867039\n",
|
||
|
"Validation Accuracy after Epoch 79: 94.10%\n",
|
||
|
"Epoch 80 completed. Loss: 0.016957035288214684\n",
|
||
|
"Validation Accuracy after Epoch 80: 93.54%\n",
|
||
|
"Accuracy (MDM=True) for Group 3: 94.33%\n",
|
||
|
"Results saved to 0.class_document/3/test_p_c.csv\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"import pandas as pd\n",
|
||
|
"from transformers import BertTokenizer, BertForSequenceClassification, AdamW\n",
|
||
|
"from sklearn.preprocessing import LabelEncoder\n",
|
||
|
"import torch\n",
|
||
|
"from torch.utils.data import Dataset, DataLoader\n",
|
||
|
"import numpy as np\n",
|
||
|
"import torch.nn.functional as F\n",
|
||
|
"\n",
|
||
|
"group_number = 3\n",
|
||
|
"train_path = f'../../data_preprocess/dataset/{group_number}/train.csv'\n",
|
||
|
"valid_path = f'../../data_preprocess/dataset/{group_number}/valid.csv'\n",
|
||
|
"test_path = f'../../translation/0.result/{group_number}/test_p.csv'\n",
|
||
|
"output_path = f'0.class_document/bert/{group_number}/test_p_c.csv'\n",
|
||
|
"\n",
|
||
|
"train_data = pd.read_csv(train_path)\n",
|
||
|
"valid_data = pd.read_csv(valid_path)\n",
|
||
|
"test_data = pd.read_csv(test_path)\n",
|
||
|
"\n",
|
||
|
"train_data['thing_property'] = train_data['thing'] + '_' + train_data['property']\n",
|
||
|
"valid_data['thing_property'] = valid_data['thing'] + '_' + valid_data['property']\n",
|
||
|
"test_data['thing_property'] = test_data['thing'] + '_' + test_data['property']\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
|
||
|
"label_encoder = LabelEncoder()\n",
|
||
|
"label_encoder.fit(train_data['thing_property'])\n",
|
||
|
"\n",
|
||
|
"valid_data['thing_property'] = valid_data['thing_property'].apply(\n",
|
||
|
" lambda x: x if x in label_encoder.classes_ else 'unknown_label')\n",
|
||
|
"test_data['thing_property'] = test_data['thing_property'].apply(\n",
|
||
|
" lambda x: x if x in label_encoder.classes_ else 'unknown_label')\n",
|
||
|
"\n",
|
||
|
"label_encoder.classes_ = np.append(label_encoder.classes_, 'unknown_label')\n",
|
||
|
"\n",
|
||
|
"train_data['label'] = label_encoder.transform(train_data['thing_property'])\n",
|
||
|
"valid_data['label'] = label_encoder.transform(valid_data['thing_property'])\n",
|
||
|
"test_data['label'] = label_encoder.transform(test_data['thing_property'])\n",
|
||
|
"\n",
|
||
|
"train_texts, train_labels = train_data['tag_description'], train_data['label']\n",
|
||
|
"valid_texts, valid_labels = valid_data['tag_description'], valid_data['label']\n",
|
||
|
"\n",
|
||
|
"train_encodings = tokenizer(list(train_texts), truncation=True, padding=True, return_tensors='pt')\n",
|
||
|
"valid_encodings = tokenizer(list(valid_texts), truncation=True, padding=True, return_tensors='pt')\n",
|
||
|
"\n",
|
||
|
"train_labels = torch.tensor(train_labels.values)\n",
|
||
|
"valid_labels = torch.tensor(valid_labels.values)\n",
|
||
|
"\n",
|
||
|
"class CustomDataset(Dataset):\n",
|
||
|
"    def __init__(self, encodings, labels):\n",
|
||
|
"        self.encodings = encodings\n",
|
||
|
"        self.labels = labels\n",
|
||
|
"\n",
|
||
|
"    def __getitem__(self, idx):\n",
|
||
|
"        item = {key: val[idx] for key, val in self.encodings.items()}\n",
|
||
|
"        item['labels'] = self.labels[idx]\n",
|
||
|
"        return item\n",
|
||
|
"\n",
|
||
|
"    def __len__(self):\n",
|
||
|
"        return len(self.labels)\n",
|
||
|
"\n",
|
||
|
"train_dataset = CustomDataset(train_encodings, train_labels)\n",
|
||
|
"valid_dataset = CustomDataset(valid_encodings, valid_labels)\n",
|
||
|
"\n",
|
||
|
"train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
|
||
|
"valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)\n",
|
||
|
"\n",
|
||
|
"model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(train_data['thing_property'].unique()))\n",
|
||
|
"optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, eps=1e-6, weight_decay=0.0)\n",
|
||
|
"\n",
|
||
|
"device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda' if torch.cuda.is_available() else 'cpu')\n",
|
||
|
"model.to(device)\n",
|
||
|
"\n",
|
||
|
"epochs = 80\n",
|
||
|
"for epoch in range(epochs):\n",
|
||
|
"    model.train()\n",
|
||
|
"    for batch in train_loader:\n",
|
||
|
"        optimizer.zero_grad()\n",
|
||
|
"        input_ids = batch['input_ids'].to(device)\n",
|
||
|
"        attention_mask = batch['attention_mask'].to(device)\n",
|
||
|
"        labels = batch['labels'].to(device)\n",
|
||
|
"        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
|
||
|
"        loss = outputs.loss\n",
|
||
|
"        loss.backward()\n",
|
||
|
"        optimizer.step()\n",
|
||
|
"    print(f\"Epoch {epoch + 1} completed. Loss: {loss.item()}\")  # loss of the final mini-batch, not an epoch average\n",
|
||
|
"\n",
|
||
|
"    # Validation loop\n",
|
||
|
"    model.eval()\n",
|
||
|
"    correct, total = 0, 0\n",
|
||
|
"\n",
|
||
|
"    with torch.no_grad():\n",
|
||
|
"        for batch in valid_loader:\n",
|
||
|
"            input_ids = batch['input_ids'].to(device)\n",
|
||
|
"            attention_mask = batch['attention_mask'].to(device)\n",
|
||
|
"            labels = batch['labels'].to(device)\n",
|
||
|
"            outputs = model(input_ids, attention_mask=attention_mask)\n",
|
||
|
"            predictions = torch.argmax(outputs.logits, dim=-1)\n",
|
||
|
"            correct += (predictions == labels).sum().item()\n",
|
||
|
"            total += labels.size(0)\n",
|
||
|
"\n",
|
||
|
"    valid_accuracy = correct / total\n",
|
||
|
"    print(f'Validation Accuracy after Epoch {epoch + 1}: {valid_accuracy * 100:.2f}%')\n",
|
||
|
"\n",
|
||
|
"# Predict on the test data and add the c_thing / c_property columns\n",
|
||
|
"test_encodings = tokenizer(list(test_data['tag_description']), truncation=True, padding=True, return_tensors='pt')\n",
|
||
|
"test_dataset = CustomDataset(test_encodings, torch.zeros(len(test_data))) \n",
|
||
|
"\n",
|
||
|
"test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)\n",
|
||
|
"\n",
|
||
|
"model.eval()\n",
|
||
|
"predicted_thing_properties = []\n",
|
||
|
"predicted_scores = []\n",
|
||
|
"\n",
|
||
|
"with torch.no_grad():\n",
|
||
|
"    for batch in test_loader:\n",
|
||
|
"        input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)\n",
|
||
|
"        outputs = model(input_ids, attention_mask=attention_mask)\n",
|
||
|
"        softmax_scores = F.softmax(outputs.logits, dim=-1)\n",
|
||
|
"        predictions = torch.argmax(softmax_scores, dim=-1)\n",
|
||
|
"        predicted_thing_properties.extend(predictions.cpu().numpy())\n",
|
||
|
"        predicted_scores.extend(softmax_scores[range(len(predictions)), predictions].cpu().numpy())\n",
|
||
|
"\n",
|
||
|
"predicted_thing_property_labels = label_encoder.inverse_transform(predicted_thing_properties)\n",
|
||
|
"\n",
|
||
|
"test_data['c_thing'] = [x.split('_')[0] for x in predicted_thing_property_labels]\n",
|
||
|
"test_data['c_property'] = [x.split('_')[1] for x in predicted_thing_property_labels]\n",
|
||
|
"test_data['c_score'] = predicted_scores\n",
|
||
|
"\n",
|
||
|
"test_data['cthing_correct'] = test_data['thing'] == test_data['c_thing']\n",
|
||
|
"test_data['cproperty_correct'] = test_data['property'] == test_data['c_property']\n",
|
||
|
"test_data['ctp_correct'] = test_data['cthing_correct'] & test_data['cproperty_correct']\n",
|
||
|
"\n",
|
||
|
"mdm_true_count = len(test_data[test_data['MDM'] == True])\n",
|
||
|
"accuracy = (test_data.loc[test_data['MDM'] == True, 'ctp_correct'].sum() / mdm_true_count) * 100\n",
|
||
|
"\n",
|
||
|
"print(f\"Accuracy (MDM=True) for Group {group_number}: {accuracy:.2f}%\")\n",
|
||
|
"\n",
|
||
|
"test_data.to_csv(output_path, index=False)\n",
|
||
|
"print(f'Results saved to {output_path}')\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAABW4AAAKyCAYAAABFb0fEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3Rc1dX38e+dXjSjLlm92HKRu4w7uGEwmGpTXOghzhNIe0IS3vQESAVCevIkBAjNFAMmFNNkY7n33lSsrlFvI2n63Pv+IVuxsQGDAHns/VmLtSzp3jtnjuUfd/ac2UfRNE1DCCGEEEIIIYQQQgghxFlDN9ADEEIIIYQQQgghhBBCCHEyKdwKIYQQQgghhBBCCCHEWUYKt0IIIYQQQgghhBBCCHGWkcKtEEIIIYQQQgghhBBCnGWkcCuEEEIIIYQQQgghhBBnGSncCiGEEEIIIYQQQgghxFlGCrdCCCGEEEIIIYQQQghxlpHCrRBCCCGEEEIIIYQQQpxlDAM9APHpqKqKy+XC4XCgKMpAD0cIIYQQQgghhBBCCPExNE2jq6uL1NRUdLqPXlMrhdsI5XK5yMjIGOhhCCGEEEIIIYQQQgghPqGamhrS09M/8hgp3EYoh8MB9P4lO53OAR7N2cvlcpGamjrQwxBCRDDJESFEf0mOCCH6S3JECNFfkiNnD7fbTUZGRl9t76NI4TZCHW+P4HQ6pXD7Eex2O3q9fqCHIYSIYJIjQoj+khwRQvSX5IgQor8kR84+Z9L6VDYnE+e0qqqqgR6CECLCSY4IIfpLckQI0V+SI0KI/pIciUxSuBVCCCGEEEIIIYQQQoizjBRuxTktNjZ2oIcghIhwkiNCiP6SHBFC9JfkiBCivyRHIpMUbsU5Tfq3CCH6S3JECNFfkiNCiP6SHBFC9JfkSGSSwq04p7W0tAz0EIQQEU5yRAjRX5IjQoj+khwRQvSX5EhkksKtEEIIIYQQQgghhBBCnGWkcCvOaenp6QM9BCFEhJMcEUL0l+SIEKK/JEeEEP0lORKZpHArzmltbW0DPQQhRISTHBFC9JfkiBCivyRHhBD9JTkSmaRwK85pHo9noIcghIhwkiNCiP6SHBFC9JfkiBCivyRHIpMUbsU5zWg0DvQQhBARTnJECNFfkiNCiP6SHBFC9JfkSGSSwq04p6WlpQ30EIQQEU5yRAjRX5IjQoj+khwRQvSX5EhkksKtOKdVVlYO9BCEEBFOckQI0V+SI0KI/pIcEUL0l+RIZJLCrRBCCCGEEEIIIYQQQpxlpHArzmkxMTEDPQQhRISTHBFC9JfkiBCivyRHhBD9JTkSmaRwK85pJpNpoIcghIhwkiNCiP6SHBFC9JfkiBCivyRHIpMUbsU5rampaaCHIISIcJIjQoj+khwRQvSX5IgQor8kRyKTFG6FEEIIIYQQQgghhBDiLCOFW3FOS0tLG+ghCCEinOSIEKK/JEeEEP0lOSKE6C/JkchkGOgBCPF56ujoIDk5eaCHIYSIYJIjQpy//H4/LpcLn8+HxWIhNTUVs9l80jHNzc2sWbOGtrY24uLimDNnDomJiScdIzkihOgvyREhRH9JjkQmKdyKc1pPT89AD0EIEeEkR4Q4/7hcLoqKitiwcRPdHi+aBooCDruN6dOmMnPmTPbv388jjzzC3n37MVns6HR6VDVMwOdh7JhR3HPPPcybNw+QHBFC9J/kiBCivyRHIpMUbsU5zWCQX3EhRP9IjghxfiksLOSJJ5+iutaF2epgyNB8ho8ah0Gn49D+3bz+ViH33f8Ara3txA9KZ/TE2QwbVYDVHoW3p5sjB3ZSVVnKbXfcyVVXXM6jjz4qOSKE6DfJESFEf0mORCZF0zRtoAchPjm32010dDSdnZ04nc6BHo4QQgghRMT71a9+xd///n/0eP1ExyagNxjQVJVQMMCI0eO4cemXePM/L7Jp/RqGjpzA5dfdRv6YiSi6/24boakqZUf28c
aLj1G8fzvz513Mo48+OoDPSgghhBBCnE0+SU1PCrcRSgq3Z6a8vJzc3NyBHoYQIoJJjghxfli6dClvvfMeSalZZA3OZ9joAqy2KLyebor376Suqoy25gZ83h4Kpl3M9bd/C4vFRmxcb4H3g7o7O3j09z/h0O6NPPboP7jiiisG4FkJIc4Vcj8ihOgvyZGzxyep6ck6aSGEEEIIcV5btmwZa4o2MmrChVx48TWkZuZisliISxiE0WRmzvzFVJQc4JWn/8LRI3tJzRhMYnIaXZ3t+Lwe7I5Tb7ijomO48sY7cdWUs2bNmjMq3J7JZmhCCCGEEOL8IYVbcU6Ljo4e6CEIISKc5IgQ57Z33nmHV197g4RBmRiMJnZuXsPOzWsAMJnNDBs1gZHjp5I1ZDgXXnINBpOJkoO7aKqvJcoRjc/nwx7lAEVh/67NrHzmb3S5O3A4Y7j2pq+Snp3H1m3baW5uJjEx8bRjOHEzNHdXD16fD1VVcditXDxnNnPnziU1NfWLnBYhxFlG7keEEP0lORKZpFVChJJWCWemp6cHu90+0MMQQkQwyREhzm0XXHABNa4mhuSPJzNnKEPyx2O22PD7PJQd3ktrUz2KojC6YCpGs5VgMMDaVS8ydNQErl78Fbyebp79x2/Y8N6rgEJ0XAJ6nYGwGqKzrQU1HCI1NY3f/vbXLFq06JTHLyws5PkXVtDu9mC0RhGbkILOYCTg91NbWUprUx1Om5kv33kHixcv/sLnRwhxdpD7ESFEf0mOnD2kVYIQxzQ2NkoPFyFEv0iOCHHueumll6iqcTFqwoVcs/Ru0rIGYzCa+n4+fsocqo8eZvUbz7Fm1Qpyh40mZ+goLFY7q1Y8Rm1lKbs2r0FRFDJyh5OZO5y8/PF9vXFLDu6ipqKYUcNyWbx48SmF28LCQp5Z/gIBzcSICTOIjkvGYrUTGxeP3mAgFAxSeuQAm99/k1/95iGqq6u59957v+hpEkKcBeR+RAjRX5IjkUkKt0IIIYQQ4rzjcrn4xz8fZUj+eK644csMSs9CUZSTjuloa8ZksXLJNbdQ+NqzNNRVUjD1YkZfcCGN9VXs2rwas9XG6ILpzL/hSwwePg6dTtd3/iXX3EzZ4T24Dm8gLjEFRVE4/mG3iooKHv3XY3T7YeSEC4lLTCM3bzj2qJNXXQxKzSB/9Hj+/bff8uTTy8nMzJSVt0IIIYQQ5wndxx8iRORKSUkZ6CEIISKc5IgQ56aioiKCqp5R46cd21xMIRwO9/28vbUJv8+D2WwlKTWTyxbejsVqp762gpSMHKJjE9DrDYwumM7Sr36frMH5JxVtARRFIS9/PAUXL2L0hOlEOWPJzc3lueee466772bfwWJaWlvZXPQeRe+8wrb1hbQ2NZwy1vjEQdx429ewOeP5xz8fxeVyfd7TI4Q4y8j9iBCivyRHIpMUbsU5raura6CHIISIcJIjQpx7/H4/GzZuIiU9G2dsAuFQCINej6apqGoYn7cHn7cHk8lCdFwiZrOFQenZJCSnUnpoF91dnTTWVRGflML8G+7E4YwjFAoSDod6H0DT+O82EhoOm5n5N3yJQelZVFRU8J8338VgS2T63GuYc+USZl12HVExCezevoGn//kQu7YUnTLm9KzB5I0YQ0NjM4WFhV/cZAkhzgpyPyKE6C/JkcgkhVtxTuvu7h7oIQghIpzkiBDnHpfLRbfHy8SpF6KpITraW9Ab9ACEgkG63R3o9Qbszhj0+t7vK4rC4OHjCAYC7N1aRFtLE9lDRjJ4+FiMJhMKCqFgAE3T+ODOv0rIS+6wsWTkDMNssZI5ZDSzL7+OMRNnkD92CpNmzGPRl77Nglu+TnRsEkXv/eeU4q2iKIwcNxmLzUnh6tX4/f4vYqqEEGcJuR8RQvSX5EhkksKtOKcdf7ElhBCfluSIEOcen8+HpoEzOpb4+Di6uzpwd3ZgMBoJh0MEAwEMRhMmk7nvHE3TsFitBPw+aipLMZnNDBt9AYqioN
Pp0el0hMPhE1baAmiAAooOnU7H0JEFRMcm8vbrK/AHAoCCydz7GIqikJKezYJb7iYmLpl1ha+d0jbBYrHiiI6hs6t
|
||
|
"text/plain": [
|
||
|
"<Figure size 1400x700 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"from sklearn.manifold import TSNE\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"\n",
|
||
|
"# Read 'filtered_data_plot.csv'\n",
|
||
|
"filtered_data = pd.read_csv('filtered_data_plot.csv')\n",
|
||
|
"\n",
|
||
|
"# Tokenize the data\n",
|
||
|
"filtered_encodings = tokenizer(list(filtered_data['tag_description']), truncation=True, padding=True, return_tensors='pt')\n",
|
||
|
"\n",
|
||
|
"# Function that computes BERT embeddings\n",
|
||
|
"def get_bert_embeddings(model, encodings, device):\n",
|
||
|
"    model.eval()\n",
|
||
|
"    with torch.no_grad():\n",
|
||
|
"        input_ids = encodings['input_ids'].to(device)\n",
|
||
|
"        attention_mask = encodings['attention_mask'].to(device)\n",
|
||
|
"        outputs = model.bert(input_ids=input_ids, attention_mask=attention_mask)\n",
|
||
|
"    return outputs.last_hidden_state.mean(dim=1).cpu().numpy()  # mean embedding of each sentence (averaged over all token positions, padding included)\n",
|
||
|
"\n",
|
||
|
"# Compute embeddings with the BERT model\n",
|
||
|
"bert_embeddings = get_bert_embeddings(model, filtered_encodings, device)\n",
|
||
|
"\n",
|
||
|
"# t-SNE dimensionality reduction\n",
|
||
|
"tsne = TSNE(n_components=2, random_state=42)\n",
|
||
|
"tsne_results = tsne.fit_transform(bert_embeddings)\n",
|
||
|
"\n",
|
||
|
"# Prepare for visualization\n",
|
||
|
"unique_patterns = filtered_data['pattern'].unique()\n",
|
||
|
"color_map = plt.get_cmap('tab20', len(unique_patterns))\n",
|
||
|
"pattern_to_color = {pattern: idx for idx, pattern in enumerate(unique_patterns)}\n",
|
||
|
"\n",
|
||
|
"plt.figure(figsize=(14, 7))\n",
|
||
|
"\n",
|
||
|
"# Plot each pattern separately\n",
|
||
|
"for pattern, color_idx in pattern_to_color.items():\n",
|
||
|
"    pattern_indices = filtered_data['pattern'] == pattern\n",
|
||
|
"    plt.scatter(tsne_results[pattern_indices, 0], tsne_results[pattern_indices, 1], \n",
|
||
|
"                color=color_map(color_idx), marker='o', s=100, alpha=0.6, edgecolor='k', linewidth=1.2)\n",
|
||
|
"\n",
|
||
|
"# Figure settings\n",
|
||
|
"plt.xticks(fontsize=24)\n",
|
||
|
"plt.yticks(fontsize=24)\n",
|
||
|
"plt.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.6)\n",
|
||
|
"plt.tight_layout()\n",
|
||
|
"plt.show()\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "torch",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.10.14"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 2
|
||
|
}
|