semes_gaf/self_supervised/simple_test.ipynb

389 lines
17 KiB
Jupyter Notebook (JSON)

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torchvision import transforms\n",
"from PIL import Image\n",
"import glob, os\n",
"import numpy as np\n",
"\n",
"import utils\n",
"from moco import SelfSupervisedMethod\n",
"# from model_params import EigRegParams\n",
"from model_params import VICRegParams\n",
"\n",
"from attr import evolve\n",
"\n",
"import pandas as pd\n",
"from sklearn.decomposition import PCA\n",
"\n",
"from sklearn.cluster import KMeans\n",
"from sklearn.metrics import rand_score, normalized_mutual_info_score\n",
"\n",
"# Per-dataset configuration for the UCR archive experiments.\n",
"# Each spec is (resize, batch_size, num_clusters, dataset_dir, checkpoint).\n",
"# NOTE(review): DATA_ROOT is machine-specific; consider reading it from an\n",
"# environment variable or a config file instead of hard-coding it.\n",
"DATA_ROOT = \"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr\"\n",
"DATASET_SPECS = [\n",
"    (471, 30, 5, \"Beef\", \"checkpoint_beef\"),                                     # 0\n",
"    (81, 139, 3, \"DistalPhalanxOutlineAgeGroup\", \"checkpoint_dist_agegroup\"),    # 1\n",
"    (97, 100, 2, \"ECG200\", \"checkpoint_ecg200\"),                                 # 2\n",
"    (137, 23, 2, \"ECGFiveDays\", \"checkpoint_ecg5days\"),                          # 3\n",
"    (449, 60, 3, \"Meat\", \"checkpoint_meat\"),                                     # 4\n",
"    (85, 20, 2, \"MoteStrain\", \"checkpoint_motestrain\"),                          # 5\n",
"    (428, 64, 6, \"OSULeaf\", \"checkpoint_osuleaf\"),         # 6 (batch was 200)\n",
"    (145, 105, 7, \"Plane\", \"checkpoint_plane\"),                                  # 7\n",
"    (81, 205, 3, \"ProximalPhalanxOutlineAgeGroup\", \"checkpoint_prox_agegroup\"),  # 8\n",
"    (81, 100, 6, \"ProximalPhalanxTW\", \"checkpoint_prox_tw\"),  # 9 (batch was 400)\n",
"]\n",
"# Build the same list of per-dataset dicts that the rest of the notebook\n",
"# indexes by position (data_params[selector][...]).\n",
"data_params = [\n",
"    {\n",
"        'resize': resize,\n",
"        'batch_size': batch_size,\n",
"        'num_clusters': num_clusters,\n",
"        'train_path': f\"{DATA_ROOT}/{name}/train\",\n",
"        'test_path': f\"{DATA_ROOT}/{name}/test\",\n",
"        'checkpoint': checkpoint,\n",
"    }\n",
"    for resize, batch_size, num_clusters, name, checkpoint in DATASET_SPECS\n",
"]\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def calculate_mean_std(image_dir):\n",
" # Initialize lists to store the sum and squared sum of pixel values\n",
" mean_1, mean_2 = 0.0, 0.0\n",
" std_1, std_2 = 0.0, 0.0\n",
" num_pixels = 0\n",
"\n",
" # Iterate through all images in the directory\n",
" for dirpath, dirnames, filenames in os.walk(image_dir):\n",
" for filename in filenames:\n",
" # Full path of the file\n",
" file_path = os.path.join(dirpath, filename)\n",
"\n",
" # for img_name in os.listdir(image_dir):\n",
" # img_path = os.path.join(image_dir, img_name)\n",
" if os.path.isfile(file_path) and file_path.endswith(('png', 'jpg', 'jpeg', 'bmp', 'tiff')):\n",
" with Image.open(file_path) as img:\n",
" # img = img.convert('RGB') # Ensure image is in RGB format\n",
" img_np = np.array(img) / 255.0 # Normalize to range [0, 1]\n",
" \n",
" num_pixels += img_np.shape[0] * img_np.shape[1]\n",
" \n",
" mean_1 += np.sum(img_np[:, :, 0])\n",
" mean_2 += np.sum(img_np[:, :, 1])\n",
" \n",
" std_1 += np.sum(img_np[:, :, 0] ** 2)\n",
" std_2 += np.sum(img_np[:, :, 1] ** 2)\n",
"\n",
" # Calculate mean\n",
" mean_1 /= num_pixels\n",
" mean_2 /= num_pixels\n",
"\n",
" # Calculate standard deviation\n",
" std_1 = (std_1 / num_pixels - mean_1 ** 2) ** 0.5\n",
" std_2 = (std_2 / num_pixels - mean_2 ** 2) ** 0.5\n",
"\n",
" return [mean_1, mean_2], [std_1, std_2]\n",
"\n",
"def list_directories(path):\n",
" entries = os.listdir(path)\n",
" directories = [ entry for entry in entries if os.path.isdir(os.path.join(path, entry))]\n",
" return directories\n",
"\n",
"\n",
"def inference(method, classes, transform, path, device='cuda:2', batch_size=32):\n",
"    \"\"\"Embed every image under ``path/<class>/`` with the trained encoder.\n",
"\n",
"    Images are forwarded in mini-batches; the projection vectors (moved to\n",
"    CPU) and the integer class labels are returned in directory-listing\n",
"    order, one entry per image.\n",
"\n",
"    Args:\n",
"        method: SelfSupervisedMethod exposing ``model`` and ``projection_model``.\n",
"        classes: iterable of class subdirectory names, each parseable as int.\n",
"        transform: torchvision transform applied to each PIL image.\n",
"        path: root directory holding one subdirectory per class.\n",
"        device: torch device spec for the forward passes (default keeps the\n",
"            previous hard-coded 'cuda:2'; pass e.g. 'cpu' to override).\n",
"        batch_size: number of images per forward pass.\n",
"\n",
"    Returns:\n",
"        (result, labels): list of 1-D CPU tensors and list of int labels.\n",
"    \"\"\"\n",
"    device = torch.device(device)\n",
"    method.model.to(device)\n",
"    method.projection_model.to(device)\n",
"\n",
"    def _flush(batch):\n",
"        # Forward one batch through encoder + projector; return CPU vectors.\n",
"        batch_tensor = torch.cat(batch).to(device)\n",
"        with torch.no_grad():\n",
"            emb = method.model(batch_tensor)\n",
"            projection = method.projection_model(emb)\n",
"        return list(projection.cpu())\n",
"\n",
"    image_tensors = []\n",
"    result = []\n",
"    labels = []\n",
"\n",
"    for key in classes:\n",
"        image_dir = os.path.join(path, key)\n",
"        for img_name in os.listdir(image_dir):\n",
"            image_path = os.path.join(image_dir, img_name)\n",
"            # Close the file handle promptly instead of leaking it until GC;\n",
"            # the transform reads the pixel data inside the with-block.\n",
"            with Image.open(image_path) as image:\n",
"                input_tensor = transform(image).unsqueeze(0)  # add batch dim\n",
"            image_tensors.append(input_tensor)\n",
"            labels.append(int(key))\n",
"\n",
"            if len(image_tensors) == batch_size:\n",
"                result.extend(_flush(image_tensors))\n",
"                image_tensors = []\n",
"\n",
"    # Flush the final partial batch, if any.\n",
"    if image_tensors:\n",
"        result.extend(_flush(image_tensors))\n",
"\n",
"    return result, labels\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Number of independent training runs to evaluate.\n",
"num_runs = 10\n",
"# Number of datasets evaluated per run (columns of the result matrices).\n",
"num_results_per_run = 10\n",
"# Rows index the run, columns index the dataset (same order as data_params).\n",
"ri_results = np.zeros((num_runs, num_results_per_run))\n",
"nmi_results = np.zeros((num_runs, num_results_per_run))\n",
"\n",
"\n",
"# For every (run, dataset) pair: load the checkpoint saved for that run,\n",
"# embed the test images, cluster them with k-means, and record the Rand\n",
"# index and NMI against the ground-truth labels.\n",
"start = 0\n",
"end = 9\n",
"for run_num in range(num_runs):\n",
"    for selector in range(start, end + 1):\n",
"        params = data_params[selector]\n",
"\n",
"        # kmeans_weight is unused by the model here, so its value is irrelevant.\n",
"        config = evolve(VICRegParams(),\n",
"                        encoder_arch=\"ws_resnet18\",  # resnet18, resnet34, resnet50\n",
"                        dataset_name=\"custom\",\n",
"                        train_path=params['train_path'],\n",
"                        test_path=params['test_path'],\n",
"                        kmeans_weight=0,\n",
"                        num_clusters=params['num_clusters'])\n",
"        method = SelfSupervisedMethod(config)\n",
"\n",
"        # Replace the fresh model with the weights saved at the end of this\n",
"        # run's training ('last-v<run>' checkpoint).\n",
"        checkpoint = params['checkpoint']\n",
"        ckpt_path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/last-v{run_num}.ckpt'\n",
"        method = method.load_from_checkpoint(ckpt_path)\n",
"        # Evaluation mode: disable dropout / use running batch-norm stats.\n",
"        method.eval()\n",
"\n",
"        # Normalization statistics are computed from the evaluation images\n",
"        # themselves (per-dataset channel mean/std).\n",
"        test_path = params['test_path']\n",
"        normalize_means, normalize_stds = calculate_mean_std(test_path)\n",
"        transform = transforms.Compose([\n",
"            transforms.ToTensor(),\n",
"            transforms.Normalize(mean=normalize_means, std=normalize_stds),\n",
"        ])\n",
"\n",
"        # One subdirectory per class; the directory name is the integer label.\n",
"        classes = list_directories(test_path)\n",
"        result, labels = inference(method, classes, transform, test_path)\n",
"\n",
"        data = np.array(result)\n",
"\n",
"        # Cluster the embeddings; fixed seed keeps k-means reproducible\n",
"        # across runs so only the checkpoint varies.\n",
"        kmeans = KMeans(n_clusters=params['num_clusters'], random_state=42, n_init=10)\n",
"        clusters = kmeans.fit_predict(data)\n",
"\n",
"        ri_results[run_num, selector] = rand_score(clusters, labels)\n",
"        nmi_results[run_num, selector] = normalized_mutual_info_score(clusters, labels)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n",
"RI mean: 0.6374712643678162\n",
"RI std: 0.03729597247277404\n",
"NMI mean: 0.2602497793608286\n",
"NMI std: 0.02985710765468422\n",
"1\n",
"RI mean: 0.6458659159628819\n",
"RI std: 0.01396857176725652\n",
"NMI mean: 0.30136574266475413\n",
"NMI std: 0.03170806441677236\n",
"2\n",
"RI mean: 0.6408888888888888\n",
"RI std: 0.021672013185081274\n",
"NMI mean: 0.20592658284184645\n",
"NMI std: 0.048248741710938625\n",
"3\n",
"RI mean: 0.5765183804661967\n",
"RI std: 0.03417303145285537\n",
"NMI mean: 0.11716054993167128\n",
"NMI std: 0.05173307476499848\n",
"4\n",
"RI mean: 0.7636723163841806\n",
"RI std: 0.0838674066635877\n",
"NMI mean: 0.6087294263576666\n",
"NMI std: 0.10910741463199608\n",
"5\n",
"RI mean: 0.6088675385570139\n",
"RI std: 0.041236238284731705\n",
"NMI mean: 0.17859344790373669\n",
"NMI std: 0.08743358257833596\n",
"6\n",
"RI mean: 0.7343060937553582\n",
"RI std: 0.020174409290336055\n",
"NMI mean: 0.22234048756150684\n",
"NMI std: 0.029705953611425088\n",
"7\n",
"RI mean: 0.9384065934065934\n",
"RI std: 0.019608200939834567\n",
"NMI mean: 0.8406540399495203\n",
"NMI std: 0.03016596675386891\n",
"8\n",
"RI mean: 0.7505643232902918\n",
"RI std: 0.022756611198669806\n",
"NMI mean: 0.48619057929963566\n",
"NMI std: 0.01338604938860034\n",
"9\n",
"RI mean: 0.832415112386418\n",
"RI std: 0.019248159640681852\n",
"NMI mean: 0.5497811436968876\n",
"NMI std: 0.023003346601586715\n"
]
}
],
"source": [
"# Per-dataset summary: mean and spread of Rand index / NMI across the\n",
"# independent training runs (columns = datasets, rows = runs).\n",
"for dataset_idx in range(10):\n",
"    ri_col = ri_results[:, dataset_idx]\n",
"    nmi_col = nmi_results[:, dataset_idx]\n",
"    print(dataset_idx)\n",
"    print(\"RI mean: \", ri_col.mean())\n",
"    print(\"RI std: \", ri_col.std())\n",
"    print(\"NMI mean: \", nmi_col.mean())\n",
"    print(\"NMI std: \", nmi_col.std())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}