389 lines
17 KiB
Plaintext
389 lines
17 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
import torch
from torchvision import transforms
from PIL import Image
import glob, os
import numpy as np

import utils
from moco import SelfSupervisedMethod
# from model_params import EigRegParams
from model_params import VICRegParams

from attr import evolve

import pandas as pd
from sklearn.decomposition import PCA

from sklearn.cluster import KMeans
from sklearn.metrics import rand_score, normalized_mutual_info_score

# Root directory holding one sub-folder per UCR dataset (each with train/ and test/).
# NOTE(review): absolute local path — consider making this an environment-driven setting.
UCR_ROOT = "/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr"


def _dataset_params(dataset_dir, resize, batch_size, num_clusters, checkpoint):
    """Build one per-dataset configuration dict.

    dataset_dir:  folder name under UCR_ROOT
    resize:       image side length used for this dataset's GAF images
    batch_size:   batch size associated with this dataset's checkpoint
    num_clusters: number of ground-truth classes (k for KMeans in later cells)
    checkpoint:   checkpoint directory name for this dataset
    """
    return {
        'resize': resize,
        'batch_size': batch_size,
        'num_clusters': num_clusters,
        'train_path': f"{UCR_ROOT}/{dataset_dir}/train",
        'test_path': f"{UCR_ROOT}/{dataset_dir}/test",
        'checkpoint': checkpoint,
    }


# Data parameters — the index in this list is the `selector` used by later cells.
data_params = [
    # 0 Beef
    _dataset_params("Beef", 471, 30, 5, 'checkpoint_beef'),
    # 1 dist.phal.outl.agegroup
    _dataset_params("DistalPhalanxOutlineAgeGroup", 81, 139, 3, 'checkpoint_dist_agegroup'),
    # 2 ECG200
    _dataset_params("ECG200", 97, 100, 2, 'checkpoint_ecg200'),
    # 3 ECGFiveDays
    _dataset_params("ECGFiveDays", 137, 23, 2, 'checkpoint_ecg5days'),
    # 4 Meat
    _dataset_params("Meat", 449, 60, 3, 'checkpoint_meat'),
    # 5 mote strain
    _dataset_params("MoteStrain", 85, 20, 2, 'checkpoint_motestrain'),
    # 6 osuleaf
    _dataset_params("OSULeaf", 428, 64, 6, 'checkpoint_osuleaf'),  # batch size was 200 previously
    # 7 plane
    _dataset_params("Plane", 145, 105, 7, 'checkpoint_plane'),
    # 8 proximal_agegroup
    _dataset_params("ProximalPhalanxOutlineAgeGroup", 81, 205, 3, 'checkpoint_prox_agegroup'),
    # 9 proximal_tw
    _dataset_params("ProximalPhalanxTW", 81, 100, 6, 'checkpoint_prox_tw'),  # batch size was 400 previously
]
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def calculate_mean_std(image_dir):
    """Compute per-channel mean and std over the first two channels of every
    image found (recursively) under ``image_dir``.

    Pixel values are scaled to [0, 1] before accumulation; statistics use the
    one-pass sum / sum-of-squares formula.

    NOTE(review): only channels 0 and 1 are read — this assumes every image
    has at least two channels (the GAF images in this project appear to be
    multi-channel); confirm if 1-channel inputs are ever possible.

    Returns:
        ([mean_ch0, mean_ch1], [std_ch0, std_ch1])

    Raises:
        ValueError: if no image files are found under ``image_dir``
            (previously this fell through to a ZeroDivisionError).
    """
    mean_1, mean_2 = 0.0, 0.0
    std_1, std_2 = 0.0, 0.0
    num_pixels = 0

    # Accepted extensions; match case-insensitively so e.g. ".PNG" or ".JPG"
    # files are not silently skipped (the original check was case-sensitive).
    valid_ext = ('png', 'jpg', 'jpeg', 'bmp', 'tiff')

    # Iterate through all images in the directory tree
    for dirpath, dirnames, filenames in os.walk(image_dir):
        for filename in filenames:
            file_path = os.path.join(dirpath, filename)
            if os.path.isfile(file_path) and file_path.lower().endswith(valid_ext):
                with Image.open(file_path) as img:
                    img_np = np.array(img) / 255.0  # normalize to range [0, 1]

                num_pixels += img_np.shape[0] * img_np.shape[1]

                mean_1 += np.sum(img_np[:, :, 0])
                mean_2 += np.sum(img_np[:, :, 1])

                std_1 += np.sum(img_np[:, :, 0] ** 2)
                std_2 += np.sum(img_np[:, :, 1] ** 2)

    if num_pixels == 0:
        raise ValueError(f"no image files found under {image_dir!r}")

    # Calculate mean
    mean_1 /= num_pixels
    mean_2 /= num_pixels

    # Calculate standard deviation. E[x^2] - E[x]^2 can come out a tiny
    # negative number due to floating-point rounding; clamp at 0 so the
    # square root never produces NaN.
    std_1 = max(std_1 / num_pixels - mean_1 ** 2, 0.0) ** 0.5
    std_2 = max(std_2 / num_pixels - mean_2 ** 2, 0.0) ** 0.5

    return [mean_1, mean_2], [std_1, std_2]
|
|
"\n",
|
|
def list_directories(path):
    """Return the names of the immediate subdirectories of ``path``.

    Plain files and other non-directory entries are ignored; order follows
    ``os.listdir`` (arbitrary, filesystem-dependent).
    """
    return [
        name
        for name in os.listdir(path)
        if os.path.isdir(os.path.join(path, name))
    ]
|
|
"\n",
|
|
"\n",
|
|
def inference(method, classes, transform, path, batch_size=32, device=None):
    """Embed every image under ``path/<class>`` with the trained encoder.

    Args:
        method:     SelfSupervisedMethod exposing ``.model`` (encoder) and
                    ``.projection_model`` (projection head).
        classes:    iterable of class sub-folder names; each must parse as an
                    int — it doubles as the ground-truth label.
        transform:  torchvision transform applied to each PIL image.
        path:       directory containing one sub-folder per class.
        batch_size: images per forward pass (default 32, matching the value
                    previously hard-coded inside the function).
        device:     torch device to run on. Defaults to CUDA when available,
                    else CPU — the previous hard-coded ``'cuda:2'`` crashed on
                    any machine without a third GPU.

    Returns:
        (result, labels): a list of projection tensors moved to CPU and the
        matching list of int class labels, one per image, in the same order.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    method.model.to(device)
    method.projection_model.to(device)

    image_tensors = []
    result = []
    labels = []

    def _flush_batch():
        """Run the pending image tensors through encoder + projection head."""
        batch_tensor = torch.cat(image_tensors).to(device)
        with torch.no_grad():
            emb = method.model(batch_tensor)
            projection = method.projection_model(emb)
        result.extend(projection.cpu())
        image_tensors.clear()  # reset for the next batch

    for key in classes:
        image_dir = path + '/' + key
        for img_name in os.listdir(image_dir):
            image_path = os.path.join(image_dir, img_name)
            # Context manager closes the file handle promptly
            # (the original leaked one handle per image).
            with Image.open(image_path) as image:
                input_tensor = transform(image).unsqueeze(0)  # add batch dimension
            image_tensors.append(input_tensor)

            # Perform batching once enough images are queued.
            if len(image_tensors) == batch_size:
                _flush_batch()

            labels.append(int(key))

    # Embed whatever is left over after the last full batch.
    if len(image_tensors) > 0:
        _flush_batch()

    return result, labels
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Evaluate each trained checkpoint by clustering its embeddings and scoring
# against the ground-truth labels with Rand Index and NMI.

# Number of independent training runs (one checkpoint file per run).
num_runs = 10
# Number of datasets evaluated per run — equals len(data_params).
# NOTE(review): the name suggests "metrics per run" but it is really the
# dataset count; columns of the result arrays are indexed by `selector`.
num_results_per_run = 10
# 2D arrays: rows = runs, columns = datasets.
ri_results = np.zeros((num_runs, num_results_per_run))
nmi_results = np.zeros((num_runs, num_results_per_run))


# Dataset index range to evaluate (inclusive on both ends).
start = 0
end = 9
for run_num in range(num_runs):
    for selector in range(start,end+1):

        # Build the model config for this dataset. kmeans_weight is a dummy
        # value — per the original author it is not used by this model.
        config = evolve(VICRegParams(), 
                        encoder_arch = "ws_resnet18", # resnet18, resnet34, resnet50
                        dataset_name="custom", 
                        train_path=data_params[selector]['train_path'],
                        test_path=data_params[selector]['test_path'],
                        kmeans_weight=0, # it doesn't matter since this is not used in the model
                        num_clusters=data_params[selector]['num_clusters'])
        method = SelfSupervisedMethod(config)
        # Initialize your ResNet model
        checkpoint = data_params[selector]['checkpoint']
        # Alternative fixed-epoch checkpoints kept for reference:
        # path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/epoch=49-step=50.ckpt'
        # path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/epoch=99-step=100.ckpt'
        # path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/epoch=149-step=150.ckpt'
        # path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/epoch=199-step=200.ckpt'
        # path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/epoch=299-step=300.ckpt'
        # path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/epoch=399-step=400.ckpt'
        # path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/epoch=499-step=500.ckpt'
        # "last-v{run_num}" pairs each run with its own final checkpoint.
        path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/last-v{run_num}.ckpt'
        method = method.load_from_checkpoint(path)
        # Set the model to evaluation mode
        method.eval()



        # Define transform. `path` is reused here: it now points at the test
        # set, not the checkpoint file — evaluation stats come from test data.
        path = data_params[selector]['test_path']
        normalize_means, normalize_stds = calculate_mean_std(path)
        # image_size = data_params[selector]['resize']
        # crop_size = int(0.4 * image_size)
        transform = transforms.Compose([
            # transforms.Resize((image_size, image_size)),

            # transforms.CenterCrop(size=(crop_size, crop_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=normalize_means, std=normalize_stds),
        ])



        # get all the classes (one sub-folder per class under the test path)
        classes = list_directories(path)

        result, labels = inference(method, classes, transform, path)

        # Stack the per-image projection tensors into one (n_images, dim) array.
        data = np.array(result)
        # pca = PCA(n_components=2)
        # reduced_data = pca.fit_transform(data)

        # Cluster with k = number of ground-truth classes; fixed random_state
        # keeps KMeans deterministic across notebook re-runs.
        kmeans = KMeans(n_clusters=data_params[selector]['num_clusters'], random_state=42, n_init=10)
        clusters = kmeans.fit_predict(data)


        # print(data_params[selector]['checkpoint'])
        # print("Rand Index: ", rand_score(clusters, labels))
        # print("NMI: ", normalized_mutual_info_score(clusters, labels))
        # Both metrics are symmetric in their arguments, so the
        # (clusters, labels) order does not matter.
        rand_index = rand_score(clusters, labels)
        nmi = normalized_mutual_info_score(clusters, labels)

        ri_results[run_num,selector] = rand_index
        nmi_results[run_num, selector] = nmi
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"0\n",
|
|
"RI mean: 0.6374712643678162\n",
|
|
"RI std: 0.03729597247277404\n",
|
|
"NMI mean: 0.2602497793608286\n",
|
|
"NMI std: 0.02985710765468422\n",
|
|
"1\n",
|
|
"RI mean: 0.6458659159628819\n",
|
|
"RI std: 0.01396857176725652\n",
|
|
"NMI mean: 0.30136574266475413\n",
|
|
"NMI std: 0.03170806441677236\n",
|
|
"2\n",
|
|
"RI mean: 0.6408888888888888\n",
|
|
"RI std: 0.021672013185081274\n",
|
|
"NMI mean: 0.20592658284184645\n",
|
|
"NMI std: 0.048248741710938625\n",
|
|
"3\n",
|
|
"RI mean: 0.5765183804661967\n",
|
|
"RI std: 0.03417303145285537\n",
|
|
"NMI mean: 0.11716054993167128\n",
|
|
"NMI std: 0.05173307476499848\n",
|
|
"4\n",
|
|
"RI mean: 0.7636723163841806\n",
|
|
"RI std: 0.0838674066635877\n",
|
|
"NMI mean: 0.6087294263576666\n",
|
|
"NMI std: 0.10910741463199608\n",
|
|
"5\n",
|
|
"RI mean: 0.6088675385570139\n",
|
|
"RI std: 0.041236238284731705\n",
|
|
"NMI mean: 0.17859344790373669\n",
|
|
"NMI std: 0.08743358257833596\n",
|
|
"6\n",
|
|
"RI mean: 0.7343060937553582\n",
|
|
"RI std: 0.020174409290336055\n",
|
|
"NMI mean: 0.22234048756150684\n",
|
|
"NMI std: 0.029705953611425088\n",
|
|
"7\n",
|
|
"RI mean: 0.9384065934065934\n",
|
|
"RI std: 0.019608200939834567\n",
|
|
"NMI mean: 0.8406540399495203\n",
|
|
"NMI std: 0.03016596675386891\n",
|
|
"8\n",
|
|
"RI mean: 0.7505643232902918\n",
|
|
"RI std: 0.022756611198669806\n",
|
|
"NMI mean: 0.48619057929963566\n",
|
|
"NMI std: 0.01338604938860034\n",
|
|
"9\n",
|
|
"RI mean: 0.832415112386418\n",
|
|
"RI std: 0.019248159640681852\n",
|
|
"NMI mean: 0.5497811436968876\n",
|
|
"NMI std: 0.023003346601586715\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
# Summarize clustering quality per dataset: mean and std of Rand Index and
# NMI across the 10 runs (columns of the result arrays are datasets).
for dataset_idx in range(10):
    ri_col = ri_results[:, dataset_idx]
    nmi_col = nmi_results[:, dataset_idx]
    print(dataset_idx)
    print("RI mean: ", np.mean(ri_col))
    print("RI std: ", np.std(ri_col))
    print("NMI mean: ", np.mean(nmi_col))
    print("NMI std: ", np.std(nmi_col))
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.13"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|