342 lines
19 KiB
Plaintext
342 lines
19 KiB
Plaintext
{
|
|
"cells": [
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"import os\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from sklearn.preprocessing import LabelEncoder\n",
"from torch.utils.data import DataLoader, Dataset\n",
"from transformers import AlbertForSequenceClassification, AlbertTokenizer, get_linear_schedule_with_warmup, set_seed\n",
"\n",
"# ---- Configuration ----\n",
"group_number = 4\n",
"train_path = f'../../data_preprocess/dataset/{group_number}/train.csv'\n",
"valid_path = f'../../data_preprocess/dataset/{group_number}/valid.csv'\n",
"test_path = f'../../translation/0.result/{group_number}/test_p.csv'\n",
"output_path = f'0.class_document/albert/{group_number}/test_p_c.csv'\n",
"\n",
"SEED = 42\n",
"EPOCHS = 80  # upper bound; training may stop earlier (see PATIENCE)\n",
"LEARNING_RATE = 5e-5\n",
"# Without clipping or LR decay, training was observed to diverge late (loss jumped from ~0.09 to ~5.9).\n",
"MAX_GRAD_NORM = 1.0\n",
"WARMUP_RATIO = 0.1  # fraction of optimizer steps spent in linear LR warm-up\n",
"PATIENCE = 10  # stop after this many epochs without a validation-accuracy improvement\n",
"UNKNOWN_LABEL = 'unknown_label'\n",
"\n",
"set_seed(SEED)  # seeds Python `random`, NumPy and PyTorch\n",
"\n",
"# ---- Load data and build thing_property labels ----\n",
"train_data = pd.read_csv(train_path)\n",
"valid_data = pd.read_csv(valid_path)\n",
"test_data = pd.read_csv(test_path)\n",
"\n",
"for frame in (train_data, valid_data, test_data):\n",
"    frame['thing_property'] = frame['thing'] + '_' + frame['property']\n",
"\n",
"tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')\n",
"label_encoder = LabelEncoder()\n",
"label_encoder.fit(train_data['thing_property'])\n",
"# The classifier only gets outputs for labels seen in training. UNKNOWN_LABEL is appended\n",
"# after this count, so its index can never be predicted and such rows always score as wrong.\n",
"num_labels = len(label_encoder.classes_)\n",
"\n",
"known_labels = set(label_encoder.classes_)\n",
"for frame in (valid_data, test_data):\n",
"    frame['thing_property'] = frame['thing_property'].apply(\n",
"        lambda x: x if x in known_labels else UNKNOWN_LABEL)\n",
"\n",
"label_encoder.classes_ = np.append(label_encoder.classes_, UNKNOWN_LABEL)\n",
"\n",
"for frame in (train_data, valid_data, test_data):\n",
"    frame['label'] = label_encoder.transform(frame['thing_property'])\n",
"\n",
"\n",
"# ---- Tokenize and build data loaders ----\n",
"class CustomDataset(Dataset):\n",
"    \"\"\"Pairs pre-tokenized encodings with one label per example.\"\"\"\n",
"\n",
"    def __init__(self, encodings, labels):\n",
"        self.encodings = encodings\n",
"        self.labels = labels\n",
"\n",
"    def __getitem__(self, idx):\n",
"        item = {key: val[idx] for key, val in self.encodings.items()}\n",
"        item['labels'] = self.labels[idx]\n",
"        return item\n",
"\n",
"    def __len__(self):\n",
"        return len(self.labels)\n",
"\n",
"\n",
"def make_dataset(texts, labels):\n",
"    \"\"\"Tokenizes `texts` (truncated, padded to the longest) and pairs them with integer `labels`.\"\"\"\n",
"    encodings = tokenizer(list(texts), truncation=True, padding=True, return_tensors='pt')\n",
"    return CustomDataset(encodings, torch.as_tensor(labels, dtype=torch.long))\n",
"\n",
"\n",
"train_dataset = make_dataset(train_data['tag_description'], train_data['label'].to_numpy())\n",
"valid_dataset = make_dataset(valid_data['tag_description'], valid_data['label'].to_numpy())\n",
"# Test labels are not used for inference; zeros are placeholders.\n",
"test_dataset = make_dataset(test_data['tag_description'], np.zeros(len(test_data), dtype=np.int64))\n",
"\n",
"train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
"valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)\n",
"test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)\n",
"\n",
"# ---- Model, optimizer and LR schedule ----\n",
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
"model = AlbertForSequenceClassification.from_pretrained('albert-base-v2', num_labels=num_labels)\n",
"model.to(device)\n",
"\n",
"# torch.optim.AdamW replaces the deprecated transformers.AdamW; eps/weight_decay match that class's defaults.\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, eps=1e-6, weight_decay=0.0)\n",
"total_steps = EPOCHS * len(train_loader)\n",
"scheduler = get_linear_schedule_with_warmup(\n",
"    optimizer, num_warmup_steps=int(WARMUP_RATIO * total_steps), num_training_steps=total_steps)\n",
"\n",
"\n",
"def evaluate_accuracy(loader):\n",
"    \"\"\"Returns the fraction of examples in `loader` whose argmax prediction equals the label (0.0 if empty).\"\"\"\n",
"    model.eval()\n",
"    correct, total = 0, 0\n",
"    with torch.no_grad():\n",
"        for batch in loader:\n",
"            input_ids = batch['input_ids'].to(device)\n",
"            attention_mask = batch['attention_mask'].to(device)\n",
"            labels = batch['labels'].to(device)\n",
"            logits = model(input_ids, attention_mask=attention_mask).logits\n",
"            correct += (logits.argmax(dim=-1) == labels).sum().item()\n",
"            total += labels.size(0)\n",
"    return correct / total if total else 0.0\n",
"\n",
"\n",
"# ---- Train, keeping the weights with the best validation accuracy ----\n",
"best_accuracy, best_state, epochs_without_improvement = -1.0, None, 0\n",
"for epoch in range(EPOCHS):\n",
"    model.train()\n",
"    epoch_loss = 0.0\n",
"    for batch in train_loader:\n",
"        optimizer.zero_grad()\n",
"        input_ids = batch['input_ids'].to(device)\n",
"        attention_mask = batch['attention_mask'].to(device)\n",
"        labels = batch['labels'].to(device)\n",
"        loss = model(input_ids, attention_mask=attention_mask, labels=labels).loss\n",
"        loss.backward()\n",
"        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)\n",
"        optimizer.step()\n",
"        scheduler.step()\n",
"        epoch_loss += loss.item()\n",
"\n",
"    valid_accuracy = evaluate_accuracy(valid_loader)\n",
"    # Report the mean loss over all batches, not just the last batch's loss.\n",
"    mean_loss = epoch_loss / max(len(train_loader), 1)\n",
"    print(f'Epoch {epoch + 1} completed. Loss: {mean_loss} Validation Accuracy: {valid_accuracy * 100:.2f}%')\n",
"\n",
"    if valid_accuracy > best_accuracy:\n",
"        best_accuracy, epochs_without_improvement = valid_accuracy, 0\n",
"        best_state = copy.deepcopy(model.state_dict())\n",
"    else:\n",
"        epochs_without_improvement += 1\n",
"        if epochs_without_improvement >= PATIENCE:\n",
"            print(f'Early stopping after epoch {epoch + 1}.')\n",
"            break\n",
"\n",
"if best_state is not None:\n",
"    model.load_state_dict(best_state)\n",
"else:\n",
"    best_accuracy = evaluate_accuracy(valid_loader)\n",
"print(f'Validation Accuracy: {best_accuracy * 100:.2f}%')\n",
"\n",
"# ---- Predict c_thing / c_property / c_score for the test set ----\n",
"model.eval()\n",
"predicted_indices, predicted_scores = [], []\n",
"with torch.no_grad():\n",
"    for batch in test_loader:\n",
"        input_ids = batch['input_ids'].to(device)\n",
"        attention_mask = batch['attention_mask'].to(device)\n",
"        probabilities = F.softmax(model(input_ids, attention_mask=attention_mask).logits, dim=-1)\n",
"        top_scores, top_indices = probabilities.max(dim=-1)\n",
"        predicted_indices.extend(top_indices.cpu().numpy())\n",
"        predicted_scores.extend(top_scores.cpu().numpy())\n",
"\n",
"predicted_labels = label_encoder.inverse_transform(predicted_indices)\n",
"\n",
"# Look up (thing, property) for each predicted label from the training rows instead of\n",
"# splitting on '_', which returns the wrong parts whenever a thing or property contains '_'.\n",
"label_to_pair = (train_data.drop_duplicates('thing_property')\n",
"                 .set_index('thing_property')[['thing', 'property']])\n",
"predicted_pairs = label_to_pair.reindex(predicted_labels)\n",
"test_data['c_thing'] = predicted_pairs['thing'].to_numpy()\n",
"test_data['c_property'] = predicted_pairs['property'].to_numpy()\n",
"test_data['c_score'] = predicted_scores\n",
"\n",
"test_data['cthing_correct'] = test_data['thing'] == test_data['c_thing']\n",
"test_data['cproperty_correct'] = test_data['property'] == test_data['c_property']\n",
"test_data['ctp_correct'] = test_data['cthing_correct'] & test_data['cproperty_correct']\n",
"\n",
"# Accuracy over MDM=True rows only: numerator and denominator use the same subset.\n",
"mdm_mask = test_data['MDM'].eq(True)\n",
"mdm_true_count = int(mdm_mask.sum())\n",
"mdm_correct_count = int(test_data.loc[mdm_mask, 'ctp_correct'].sum())\n",
"accuracy = mdm_correct_count / mdm_true_count * 100 if mdm_true_count else 0.0\n",
"print(f'Accuracy (MDM=True) for Group {group_number}: {accuracy:.2f}%')\n",
"\n",
"os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
"test_data.to_csv(output_path, index=False)\n",
"print(f'Results saved to {output_path}')"
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"from sklearn.manifold import TSNE\n",
"\n",
"# Relies on `tokenizer`, `model` and `device` from the training cell above.\n",
"\n",
"# Read 'filtered_data_plot.csv'\n",
"filtered_data = pd.read_csv('filtered_data_plot.csv')\n",
"\n",
"# Tokenize the descriptions\n",
"filtered_encodings = tokenizer(list(filtered_data['tag_description']), truncation=True, padding=True, return_tensors='pt')\n",
"\n",
"\n",
"def get_bert_embeddings(model, encodings, device, batch_size=64):\n",
"    \"\"\"Returns one mean-pooled encoder embedding per sentence as a NumPy array.\n",
"\n",
"    Uses `model.base_model` (the ALBERT encoder inside AlbertForSequenceClassification;\n",
"    ALBERT models have no `bert` attribute). Padding positions are excluded from the\n",
"    mean via the attention mask, and sentences are encoded `batch_size` at a time.\n",
"    \"\"\"\n",
"    model.eval()\n",
"    chunks = []\n",
"    with torch.no_grad():\n",
"        for start in range(0, encodings['input_ids'].size(0), batch_size):\n",
"            input_ids = encodings['input_ids'][start:start + batch_size].to(device)\n",
"            attention_mask = encodings['attention_mask'][start:start + batch_size].to(device)\n",
"            hidden = model.base_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state\n",
"            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)\n",
"            # Masked mean: padding tokens contribute nothing; clamp avoids division by zero.\n",
"            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)\n",
"            chunks.append(pooled.cpu().numpy())\n",
"    return np.concatenate(chunks) if chunks else np.empty((0, model.config.hidden_size))\n",
"\n",
"\n",
"# Compute sentence embeddings with the fine-tuned encoder\n",
"bert_embeddings = get_bert_embeddings(model, filtered_encodings, device)\n",
"\n",
"# t-SNE dimensionality reduction (perplexity must be smaller than the number of samples)\n",
"tsne = TSNE(n_components=2, perplexity=min(30.0, len(filtered_data) - 1), random_state=42)\n",
"tsne_results = tsne.fit_transform(bert_embeddings)\n",
"\n",
"# Prepare one color per pattern\n",
"unique_patterns = filtered_data['pattern'].unique()\n",
"color_map = plt.get_cmap('tab20', len(unique_patterns))\n",
"pattern_to_color = {pattern: idx for idx, pattern in enumerate(unique_patterns)}\n",
"\n",
"plt.figure(figsize=(14, 7))\n",
"\n",
"# Plot each pattern in its own color\n",
"for pattern, color_idx in pattern_to_color.items():\n",
"    pattern_indices = (filtered_data['pattern'] == pattern).to_numpy()\n",
"    plt.scatter(tsne_results[pattern_indices, 0], tsne_results[pattern_indices, 1],\n",
"                color=color_map(color_idx), marker='o', s=100, alpha=0.6, edgecolor='k', linewidth=1.2)\n",
"\n",
"# Figure styling\n",
"plt.xticks(fontsize=24)\n",
"plt.yticks(fontsize=24)\n",
"plt.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.6)\n",
"plt.tight_layout()\n",
"plt.show()"
]
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "torch",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.14"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|