{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GAF-VICReg: k-means clustering of UCR time-series embeddings\n",
    "\n",
    "For each of 10 UCR datasets, load a self-supervised checkpoint (one per run),\n",
    "embed the GAF test images, cluster the projections with k-means, and report the\n",
    "Rand Index (RI) and Normalized Mutual Information (NMI) across `num_runs` runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torchvision import transforms\n",
    "from PIL import Image\n",
    "import glob, os\n",
    "import numpy as np\n",
    "\n",
    "import utils\n",
    "from moco import SelfSupervisedMethod\n",
    "from model_params import VICRegParams\n",
    "\n",
    "from attr import evolve\n",
    "\n",
    "import pandas as pd\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.metrics import rand_score, normalized_mutual_info_score\n",
    "\n",
    "# Per-dataset parameters, one tuple per UCR dataset:\n",
    "#   (resize, batch_size, num_clusters, UCR directory name, checkpoint directory)\n",
    "# NOTE(review): 'resize' is currently unused downstream (the Resize transform is\n",
    "# commented out in the evaluation loop) -- kept for reference.\n",
    "DATA_ROOT = '/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr'\n",
    "_DATASETS = [\n",
    "    (471, 30, 5, 'Beef', 'checkpoint_beef'),                                   # 0\n",
    "    (81, 139, 3, 'DistalPhalanxOutlineAgeGroup', 'checkpoint_dist_agegroup'),  # 1\n",
    "    (97, 100, 2, 'ECG200', 'checkpoint_ecg200'),                               # 2\n",
    "    (137, 23, 2, 'ECGFiveDays', 'checkpoint_ecg5days'),                        # 3\n",
    "    (449, 60, 3, 'Meat', 'checkpoint_meat'),                                   # 4\n",
    "    (85, 20, 2, 'MoteStrain', 'checkpoint_motestrain'),                        # 5\n",
    "    (428, 64, 6, 'OSULeaf', 'checkpoint_osuleaf'),                             # 6 (batch was 200)\n",
    "    (145, 105, 7, 'Plane', 'checkpoint_plane'),                                # 7\n",
    "    (81, 205, 3, 'ProximalPhalanxOutlineAgeGroup', 'checkpoint_prox_agegroup'),# 8\n",
    "    (81, 100, 6, 'ProximalPhalanxTW', 'checkpoint_prox_tw'),                   # 9 (batch was 400)\n",
    "]\n",
    "data_params = [\n",
    "    {\n",
    "        'resize': resize,\n",
    "        'batch_size': batch_size,\n",
    "        'num_clusters': num_clusters,\n",
    "        'train_path': f'{DATA_ROOT}/{name}/train',\n",
    "        'test_path': f'{DATA_ROOT}/{name}/test',\n",
    "        'checkpoint': checkpoint,\n",
    "    }\n",
    "    for resize, batch_size, num_clusters, name, checkpoint in _DATASETS\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_mean_std(image_dir):\n",
    "    \"\"\"Compute per-channel mean and std over all images found under ``image_dir``.\n",
    "\n",
    "    Walks the directory tree, scales pixel values to [0, 1], and accumulates\n",
    "    statistics for the first two channels only (the stored GAF images are\n",
    "    read as multi-channel arrays; only channels 0 and 1 are normalized --\n",
    "    TODO confirm channel count against the image generation step).\n",
    "\n",
    "    Returns:\n",
    "        ([mean_1, mean_2], [std_1, std_2]) as plain floats.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no image files are found under ``image_dir``.\n",
    "    \"\"\"\n",
    "    mean_1, mean_2 = 0.0, 0.0\n",
    "    std_1, std_2 = 0.0, 0.0\n",
    "    num_pixels = 0\n",
    "\n",
    "    for dirpath, dirnames, filenames in os.walk(image_dir):\n",
    "        for filename in filenames:\n",
    "            file_path = os.path.join(dirpath, filename)\n",
    "            if os.path.isfile(file_path) and file_path.endswith(('png', 'jpg', 'jpeg', 'bmp', 'tiff')):\n",
    "                with Image.open(file_path) as img:\n",
    "                    img_np = np.array(img) / 255.0  # normalize to [0, 1]\n",
    "\n",
    "                num_pixels += img_np.shape[0] * img_np.shape[1]\n",
    "\n",
    "                mean_1 += np.sum(img_np[:, :, 0])\n",
    "                mean_2 += np.sum(img_np[:, :, 1])\n",
    "\n",
    "                std_1 += np.sum(img_np[:, :, 0] ** 2)\n",
    "                std_2 += np.sum(img_np[:, :, 1] ** 2)\n",
    "\n",
    "    if num_pixels == 0:\n",
    "        # Fail loudly instead of raising an opaque ZeroDivisionError below.\n",
    "        raise ValueError(f'No image files found under {image_dir!r}')\n",
    "\n",
    "    mean_1 /= num_pixels\n",
    "    mean_2 /= num_pixels\n",
    "\n",
    "    # std = sqrt(E[x^2] - E[x]^2)\n",
    "    std_1 = (std_1 / num_pixels - mean_1 ** 2) ** 0.5\n",
    "    std_2 = (std_2 / num_pixels - mean_2 ** 2) ** 0.5\n",
    "\n",
    "    return [mean_1, mean_2], [std_1, std_2]\n",
    "\n",
    "\n",
    "def list_directories(path):\n",
    "    \"\"\"Return the names (not full paths) of the immediate subdirectories of ``path``.\"\"\"\n",
    "    entries = os.listdir(path)\n",
    "    return [entry for entry in entries if os.path.isdir(os.path.join(path, entry))]\n",
    "\n",
    "\n",
    "def inference(method, classes, transform, path):\n",
    "    \"\"\"Embed every image under ``path/<class>/`` with the encoder + projection head.\n",
    "\n",
    "    Images are batched (batch_size=32) purely for throughput; the batch may\n",
    "    span class boundaries, which is safe because ``labels`` is appended per\n",
    "    image in the same order that results are produced.\n",
    "\n",
    "    Args:\n",
    "        method: a SelfSupervisedMethod with ``.model`` and ``.projection_model``.\n",
    "        classes: subdirectory names; each name must parse as an int class label.\n",
    "        transform: torchvision transform applied to each PIL image.\n",
    "        path: dataset split directory containing one subdirectory per class.\n",
    "\n",
    "    Returns:\n",
    "        (result, labels): list of per-image projection tensors (on CPU) and\n",
    "        the matching list of integer labels.\n",
    "    \"\"\"\n",
    "    batch_size = 32\n",
    "    image_tensors = []\n",
    "    result = []\n",
    "    labels = []\n",
    "\n",
    "    device = torch.device('cuda:2')  # hard-coded GPU; edit here to change device\n",
    "    method.model.to(device)\n",
    "    method.projection_model.to(device)\n",
    "\n",
    "    def _flush(tensors):\n",
    "        \"\"\"Run one batch through encoder + projector; return projections on CPU.\"\"\"\n",
    "        batch_tensor = torch.cat(tensors).to(device)\n",
    "        with torch.no_grad():\n",
    "            emb = method.model(batch_tensor)\n",
    "            projection = method.projection_model(emb)\n",
    "        return projection.cpu()\n",
    "\n",
    "    for key in classes:\n",
    "        image_dir = path + '/' + key\n",
    "        for img_name in os.listdir(image_dir):\n",
    "            image_path = os.path.join(image_dir, img_name)\n",
    "            # Context manager closes the file handle (the original leaked one\n",
    "            # handle per image); transform() runs inside so lazy loading works.\n",
    "            with Image.open(image_path) as image:\n",
    "                input_tensor = transform(image).unsqueeze(0)  # add batch dimension\n",
    "            image_tensors.append(input_tensor)\n",
    "\n",
    "            if len(image_tensors) == batch_size:\n",
    "                result.extend(_flush(image_tensors))\n",
    "                image_tensors = []\n",
    "\n",
    "            labels.append(int(key))  # directory name doubles as the class label\n",
    "\n",
    "    # Flush the final partial batch, if any.\n",
    "    if image_tensors:\n",
    "        result.extend(_flush(image_tensors))\n",
    "\n",
    "    return result, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate every dataset under every run's checkpoint.\n",
    "num_runs = 10             # number of independent training runs (checkpoints last-v0 .. last-v9)\n",
    "num_results_per_run = 10  # number of datasets\n",
    "ri_results = np.zeros((num_runs, num_results_per_run))\n",
    "nmi_results = np.zeros((num_runs, num_results_per_run))\n",
    "\n",
    "start = 0\n",
    "end = 9\n",
    "for run_num in range(num_runs):\n",
    "    for selector in range(start, end + 1):\n",
    "        params = data_params[selector]\n",
    "\n",
    "        config = evolve(VICRegParams(),\n",
    "                        encoder_arch='ws_resnet18',  # resnet18, resnet34, resnet50\n",
    "                        dataset_name='custom',\n",
    "                        train_path=params['train_path'],\n",
    "                        test_path=params['test_path'],\n",
    "                        kmeans_weight=0,  # unused during inference\n",
    "                        num_clusters=params['num_clusters'])\n",
    "        method = SelfSupervisedMethod(config)\n",
    "\n",
    "        # One checkpoint per run: last-v0.ckpt .. last-v9.ckpt (other epoch\n",
    "        # checkpoints, e.g. epoch=49 .. epoch=499, exist in the same directory).\n",
    "        checkpoint = params['checkpoint']\n",
    "        ckpt_path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/last-v{run_num}.ckpt'\n",
    "        method = method.load_from_checkpoint(ckpt_path)\n",
    "        method.eval()\n",
    "\n",
    "        # Normalization statistics are computed from the test images themselves.\n",
    "        test_path = params['test_path']\n",
    "        normalize_means, normalize_stds = calculate_mean_std(test_path)\n",
    "        transform = transforms.Compose([\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize(mean=normalize_means, std=normalize_stds),\n",
    "        ])\n",
    "\n",
    "        classes = list_directories(test_path)\n",
    "        result, labels = inference(method, classes, transform, test_path)\n",
    "\n",
    "        # Stack the per-image projection tensors into one (n_images, dim) array.\n",
    "        data = torch.stack(result).numpy()\n",
    "\n",
    "        kmeans = KMeans(n_clusters=params['num_clusters'], random_state=42, n_init=10)\n",
    "        clusters = kmeans.fit_predict(data)\n",
    "\n",
    "        ri_results[run_num, selector] = rand_score(clusters, labels)\n",
    "        nmi_results[run_num, selector] = normalized_mutual_info_score(clusters, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "RI mean: 0.6374712643678162\n",
      "RI std: 0.03729597247277404\n",
      "NMI mean: 0.2602497793608286\n",
      "NMI std: 0.02985710765468422\n",
      "1\n",
      "RI mean: 0.6458659159628819\n",
      "RI std: 0.01396857176725652\n",
      "NMI mean: 0.30136574266475413\n",
      "NMI std: 0.03170806441677236\n",
      "2\n",
      "RI mean: 0.6408888888888888\n",
      "RI std: 0.021672013185081274\n",
      "NMI mean: 0.20592658284184645\n",
      "NMI std: 0.048248741710938625\n",
      "3\n",
      "RI mean: 0.5765183804661967\n",
      "RI std: 0.03417303145285537\n",
      "NMI mean: 0.11716054993167128\n",
      "NMI std: 0.05173307476499848\n",
      "4\n",
      "RI mean: 0.7636723163841806\n",
      "RI std: 0.0838674066635877\n",
      "NMI mean: 0.6087294263576666\n",
      "NMI std: 0.10910741463199608\n",
      "5\n",
      "RI mean: 0.6088675385570139\n",
      "RI std: 0.041236238284731705\n",
      "NMI mean: 0.17859344790373669\n",
      "NMI std: 0.08743358257833596\n",
      "6\n",
      "RI mean: 0.7343060937553582\n",
      "RI std: 0.020174409290336055\n",
      "NMI mean: 0.22234048756150684\n",
      "NMI std: 0.029705953611425088\n",
      "7\n",
      "RI mean: 0.9384065934065934\n",
      "RI std: 0.019608200939834567\n",
      "NMI mean: 0.8406540399495203\n",
      "NMI std: 0.03016596675386891\n",
      "8\n",
      "RI mean: 0.7505643232902918\n",
      "RI std: 0.022756611198669806\n",
      "NMI mean: 0.48619057929963566\n",
      "NMI std: 0.01338604938860034\n",
      "9\n",
      "RI mean: 0.832415112386418\n",
      "RI std: 0.019248159640681852\n",
      "NMI mean: 0.5497811436968876\n",
      "NMI std: 0.023003346601586715\n"
     ]
    }
   ],
   "source": [
    "# Per-dataset mean/std of RI and NMI across the runs.\n",
    "for data_select in range(10):\n",
    "    print(data_select)\n",
    "    print(\"RI mean: \", np.mean(ri_results[:, data_select]))\n",
    "    print(\"RI std: \", np.std(ri_results[:, data_select]))\n",
    "    print(\"NMI mean: \", np.mean(nmi_results[:, data_select]))\n",
    "    print(\"NMI std: \", np.std(nmi_results[:, data_select]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}