battery-anomaly-detection/notebooks/simple_anomaly_detection.ipynb

1240 lines
1.3 MiB
Plaintext
Raw Normal View History

2023-08-28 18:24:41 +09:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Simple Univariate Time Series Anomaly Detection"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup (imports and configuration)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"import copy\n",
"import numpy as np\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"from pylab import rcParams\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib import rc\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"from torch import nn, optim\n",
"\n",
"import torch.nn.functional as F\n",
"import random\n",
"# from arff2pandas import a2p"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7f3c54205290>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%matplotlib inline\n",
"%config InlineBackend.figure_format='retina'\n",
"\n",
"sns.set(style='whitegrid', palette='muted', font_scale=1.2)\n",
"\n",
"HAPPY_COLORS_PALETTE = [\"#01BEFE\", \"#FFDD00\", \"#FF7D00\", \"#FF006D\", \"#ADFF02\", \"#8F00FF\"]\n",
"\n",
"sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))\n",
"\n",
"# The saved single-axes figures in this notebook are 1200x800 px, i.e. they were rendered\n",
"# with this default size active; set it explicitly so Restart & Run All reproduces them.\n",
"rcParams['figure.figsize'] = 12, 8\n",
"\n",
"# Seed Python, NumPy and PyTorch RNGs so runs are reproducible.\n",
"RANDOM_SEED = 42\n",
"random.seed(RANDOM_SEED)\n",
"np.random.seed(RANDOM_SEED)\n",
"# Trailing ';' suppresses the torch.Generator repr in the cell output.\n",
"torch.manual_seed(RANDOM_SEED);"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import polars as pl\n",
"from io import StringIO\n",
"import math\n",
"# Load the raw battery log; later cells use its PACK1_CRIDATA_BATT_VOL and BATT_PACK_1_FAULT columns.\n",
"df = pl.read_csv('../data/battery_1.csv')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We only model the pack voltage column `PACK1_CRIDATA_BATT_VOL`; the `BATT_PACK_1_FAULT` column is used only to show where faults occur."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualize fault and non-fault regions"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Keep only rows whose pack voltage is non-zero; voltage values are cast to float32.\n",
"filter_condition = df['PACK1_CRIDATA_BATT_VOL'].cast(pl.Float32) != 0\n",
"voltage_data = df['PACK1_CRIDATA_BATT_VOL'].filter(filter_condition).cast(pl.Float32)\n",
"\n",
"\n",
"def convert_values(values):\n",
"    \"\"\"Encode fault flags as ints: 'False' -> 0, 'True' -> 1, anything else -> -1.\"\"\"\n",
"    codes = {'False': 0, 'True': 1}\n",
"    return [codes.get(value, -1) for value in values]\n",
"\n",
"\n",
"# Fault flags for the same rows as voltage_data (identical filter).\n",
"fault_data = convert_values(df['BATT_PACK_1_FAULT'].filter(filter_condition))\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 1.0, 'fault incidents')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABqkAAAQyCAYAAAA7jhX/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdeXhU1f3H8c9M9oVAWBK2gCAkKIpsLoiArSjUveJWrRZarSBaa4vV/qhaXFqrtVZFi7WlFbW1dUOpdQG3ilAVWasmLLIvCTvZl5n7++Mkk5nkTjIJyczcyfv1PDzMuTn3zJk75y5zv/ec47IsyxIAAAAAAAAAAAAQRu5IVwAAAAAAAAAAAAAdD0EqAAAAAAAAAAAAhB1BKgAAAAAAAAAAAIQdQSoAAAAAAAAAAACEHUEqAAAAAAAAAAAAhB1BKgAAAAAAAAAAAIQdQSoAAAAAAAAAAACEHUEqAAAAAAAAAAAAhB1BKgAAAAAAAAAAAIQdQSoAAAAAAAAAAACEHUEqAAAAAAAAAAAAhB1BKgAAAAAAAAAAAIQdQSoAAAAAAAAAAACEHUEqAAAAAAAAAAAAhB1BKgAAAAAAAAAAAIQdQSoAAAAAAAAAAACEHUEqAAAAAAAAAAAAhB1BKgAAAAAAAAAAAIQdQSoAAAAAAAAAAACEHUEqAAAAAHC4HTt2KC8vz/fv8ccfj3SVAAAAAKBZBKkAAAAAAAAAAAAQdgSpAAAAAKCD8e91dccdd0S6OgAAAAA6KIJUAAAAAAAAAAAACDuCVAAAAAAAAAAAAAg7glQAAAAAAAAAAAAIO4JUAAAAAAAAAAAACDuCVAAAAAAAAAAAAAi7+EhXAAAAAAA6ourqaq1atUrbt2/XgQMHFB8fr65du2rw4ME6/vjjI129kJWVlWnDhg3avHmzDh48qIqKCnXq1Eldu3bVCSecoH79+h31e1iWpXXr1mnDhg3av3+/OnXqpF69emn06NFKT09vg09Rr7y8XCtXrtTu3bt14MABJSYmqlu3bjrhhBM0YMCANn0vAAAAoKMjSAUAAACgQ7v99tu1cOFCX/q1117TkCFDWlTGb3/7Wz399NO+9J/+9CeNGzfONm9hYaEef/xxvfnmmyopKbHNk5WVpSuvvFLf//73lZKS0qK6BHPHHXfo1VdfbbT81VdftV1ep6CgoNGyXbt26V//+pfef/99rVu3TtXV1UHX79Onj6699lpdeeWVSk5ObnG9X3rpJT3xxBPatWtXo7+lpKTo3HPP1W233abMzEy98sor+vnPf+77+4IFC3TqqaeG9D5r167Vk08+qY8//lhVVVW2eY455hj98Ic/1Le//W253QxMAgAAABwtrqoBAAAAdGjf/va3A9L+AatQeL1eLVq0yJfOysrS6aefbpv3zTff1KRJk/Tiiy8GDVBJUlFRkR577DF961vfsg0SRdo3v/lNPfzww1q5cmWTASpJ2rlzp37961/riiuu0M6dO0N+j6qqKs2YMUOzZ8+2DVBJptfTyy+/rClTprR6O1VXV+vOO+/UZZddpvfffz9ogEqStmzZov/7v//T9773PR05cqRV7wcAAACgHkEqAAAAAB3aqaeeqj59+vjSixYtksfjCXn95cuXa8+ePb70hRdeqLi4uEb5Xn31Vf3kJz9ReXl5wPKhQ4dq0qRJOuuss9S/f/+Av+3evVvf/e539eWXX4Zcn3CwLMv32uVyKScnR2eccYbOPfdcnXfeeTr99NPVtWvXgHXy8/M1bdq0JoNz/m699Va99957ActSU1M1ZswYnX/++TrjjDOUkZEhyQTCbrzxRhUXF7foc1RWVuqHP/yh/vnPfwYsT09P16mnnqrzzjtPZ599to499tiAv3/66af67ne/2+i7BAAAANAyDPcHAAAAoENzuVy66KKL9OSTT0qS9u3bp6VLl2rChAkhrd+w59Ull1zSKM+mTZv0y1/+Ul6v17fsjDPO0N1339
1ozqbPPvtMd955pzZv3ixJOnLkiG699VYtXLjwqIb++9nPfqabbrpJknTWWWf5lk+aNEk/+9nPWlRWXFycJk6cqMmTJ2vcuHHq1KlTozyWZWnZsmV68MEHlZ+fL0naunWrHn74Yd19991Nlv/iiy9qyZIlvnR8fLxuvPFG/eAHPwgYMrC6ulovvviiHnroIe3YscP3HYbq/vvv17Jly3zp3r17a9asWZo0aZLi4wN/Lufn52vOnDlauXKlJDMM4q9+9Svde++9LXpPAAAAAPXoSQUAAACgw2s45F9TczT5Ky0t1eLFi33pE088sVGvG0m69957VVFR4Ut/61vf0tNPP90oQCVJJ598sv7+978HlLNly5aAOa9ao2vXrurbt6/69u0bsDw1NdW33O6fnSVLluixxx7Tueeeaxugkkzwb+zYsXrhhRc0YsQI3/JXXnlFhw4dClrPiooK/fa3vw1Y9sADD2jmzJmN5rRKSEjQVVddpXnz5ikhIaHJchv66KOP9I9//MOXPv7447Vw4UKdd955jQJUkjRkyBA988wzOuOMM3zL/vnPf2r9+vUhvycAAACAQASpAAAAAHR4/fr106hRo3zpd999N6Q5h956662AId8aBrskacOGDVq+fLkvnZ2drfvvv19ud/CfY5mZmXrooYcC8rzwwgtNzpcUTr179w45b0pKiu655x5fuqKiotEwfv7+/e9/BwSbzjvvPF1wwQVNvsepp56qqVOnhlwnSXrqqacC6vjkk0+qc+fOTa6TmJioBx98MKBH27PPPtui9wUAAABQjyAVAAAAACgwwFRVVaU333yz2XX8h/pLTEzUeeed1yjPokWLAtLf//73lZaW1mzZQ4cODRiWb//+/fr444+bXS8a5ebmBvTKWrNmTdC8b7/9dkD6uuuuC+k9pk2bZtsDys7GjRv12Wef+dKXX365evXqFdK63bp107nnnutLf/DBByGtBwAAAKAxglQAAAAAIDMEn/9wcs0N+bdz586AQMc3vvENdenSpVG+VatWBaT9AxzNOf/885ssK9pUVlZq//792rlzp3bs2BHwz3/bfP3110HL8A9g9enTR8cff3xI792tWzeNHDkypLyffPJJQHrSpEkhrVdn9OjRvtdFRUXasWNHi9YHAAAAYIT2mBkAAAAAxLj09HRNnDhR//rXvySZgNDWrVvVv39/2/wLFy6UZVm+tN1Qf5L0v//9z/e6d+/eysrKCrlOw4cPD0h/8cUXIa8bDlu2bNG//vUvffLJJ1q/fn3Ic0IFG0qxqKhIBw8e9KWHDh3aovocd9xx+vTTT5vNt3LlyoB0p06dWhRocrlcAekdO3YEnb8LAAAAQHAEqQAAAACg1iWXXOILUkkmEHXLLbfY5n399dd9r7t3765x48Y1ylNVVaWysjJful+/fi2qT8+ePZWcnKyKigpJ0oEDB1q0fns5cuSIfvOb3+jll18OCNSFqqSkxHZ5wyBXdnZ2i8oNdci+PXv2BKSbm/OqOYcPHz6q9QEAAICOiuH+AAAAAKDWmDFj1LNnT1/6tddesw3CrFy5Ulu2bPGlL7jgAtv5kBoGL9LT01tcp06dOgUtLxIOHz6s733ve3rppZdaFaCSFHS94uLigHQoc3f5C3X7tvV29A9EAgAAAAgdPakAAAAAoJbb7daFF16oP/7xj5LMvFOffvqpTj311IB8CxcuDEgHG+qvoYbDxIWitYGg9vLAAw/oyy+/9KUTExM1efJkjR07Vrm5ucrOzlZqaqqSkpLkdtc/F3nNNdc0OxRfQkJCQLqmpqZFdauqqgopX3V1dYvKbU60fUcAAACAUxCkAgAAAAA/3/72t31BKskEpPyDVFVVVXrzzTd96eOPP155eXm2ZXXu3Dkg3bCnUCj8h8ZrWF647d69W6+++qov3aNHD/31r3/VoEGDml23tLS02TwNP1+wuauCCTV/ly5dfK87deqkFStWtOh9AAAAALQNhvsDAAAAAD
8DBw7USSed5Eu//fbbKi8v96WXLFkSEAy5+OKLg5aVmJio1NRUX3rbtm0tqsuePXt881FJUteuXVu0flv78MMPA3o
"text/plain": [
"<Figure size 1000x600 with 2 Axes>"
]
},
"metadata": {
"image/png": {
"height": 537,
"width": 852
}
},
"output_type": "display_data"
}
],
"source": [
"# Full series: voltage (top) and encoded fault flags (bottom) share the same row index.\n",
"fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(10, 3 * 2), sharex=True)\n",
"\n",
"axs[0].plot(voltage_data)\n",
"axs[0].set(title='voltage', ylabel='PACK1_CRIDATA_BATT_VOL')\n",
"axs[1].scatter(range(len(fault_data)), fault_data)\n",
"axs[1].set(title='fault incidents', xlabel='sample index (after zero-voltage filter)',\n",
"           ylabel='fault (1=True, 0=False, -1=other)')\n",
"fig.tight_layout();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
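"Split regions (index ranges into the zero-filtered `voltage_data` series):\n",
"\n",
"- Train: 145000-end\n",
"- Validation: 120000-135000\n",
"- Test (normal): 85000-100000\n",
"- Test (fault / anomaly): 45000-60000"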
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"# Contiguous slices of the zero-filtered voltage series:\n",
"#   train -> fits the autoencoder, val -> selects the best epoch in train_model,\n",
"#   test  -> normal hold-out (test_normal_dataset), fault -> anomaly window (test_anomaly_dataset).\n",
"train_voltage_data = voltage_data[145000:]\n",
"test_voltage_data = voltage_data[85000:100000]\n",
"val_voltage_data = voltage_data[120000:135000]\n",
"fault_voltage_data = voltage_data[45000:60000]"
]
},
{
"cell_type": "code",
"execution_count": 128,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 1.0, 'fault incidents')"
]
},
"execution_count": 128,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABqkAAAQyCAYAAAA7jhX/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd3xUVf7/8fdkkkkhCSSRhN4JSBMEKwIqKKx1V7E3bKuIyqLousvasOFaWBtiWb923cWC4E9FUdRFUERAUCF0CAESWgLpk8z9/TEyyfSZZDKTybyejwcP5tw599wzyZ2S+5nP55gMwzAEAAAAAAAAAAAAhFFcpCcAAAAAAAAAAACA2EOQCgAAAAAAAAAAAGFHkAoAAAAAAAAAAABhR5AKAAAAAAAAAAAAYUeQCgAAAAAAAAAAAGFHkAoAAAAAAAAAAABhR5AKAAAAAAAAAAAAYUeQCgAAAAAAAAAAAGFHkAoAAAAAAAAAAABhR5AKAAAAAAAAAAAAYUeQCgAAAAAAAAAAAGFHkAoAAAAAAAAAAABhR5AKAAAAAAAAAAAAYUeQCgAAAAAAAAAAAGFHkAoAAAAAAAAAAABhR5AKAAAAAAAAAAAAYUeQCgAAAAAAAAAAAGFHkAoAAAAAAAAAAABhR5AKAAAAAAAAAAAAYUeQCgAAAACi3I4dO9SnTx/Hv2eeeSbSUwIAAAAAvwhSAQAAAAAAAAAAIOwIUgEAAABAjKmfdXXXXXdFejoAAAAAYhRBKgAAAAAAAAAAAIQdQSoAAAAAAAAAAACEHUEqAAAAAAAAAAAAhB1BKgAAAAAAAAAAAIQdQSoAAAAAAAAAAACEXXykJwAAAAAAschqtWrlypXKz8/X/v37FR8fr8zMTPXu3Vv9+vWL9PQCVl5erg0bNmjLli06cOCAKisrlZaWpszMTA0YMEBdunRp9DEMw9CaNWu0YcMG7du3T2lpaWrfvr2GDRum1NTUEDyKOhUVFVqxYoV27dql/fv3y2KxKCsrSwMGDFD37t1DeiwAAAAg1hGkAgAAABDT/vrXv2ru3LmO9kcffaS+ffsGNcbjjz+ul156ydF++eWXNWLECI99CwsL9cwzz+jTTz9VaWmpxz7Z2dm6+OKLdc011yg5OTmouXhz11136cMPP3Tb/uGHH3rcflheXp7btp07d+rjjz/WokWLtGbNGlmtVq/7d+zYUVdeeaUuvvhiJSUlBT3v9957T88995x27tzpdl9ycrLOOOMM3XHHHcrIyNAHH3ygv/3tb477X3/9dR133HEBHWf16tWaNWuWvvvuO1VXV3vs061bN/35z3/Wn/70J8XFUZgEAAAAaCw+VQMAAACIaX/605+c2vUDVoGw2WyaP3++o52dna0TTzzRY99PP/1UY8eO1Zw5c7wGqCSpqKhITz/9tP7whz94DBJF2qmnnqonnnhCK1as8BmgkqSCggI98sgjuuiii1RQUBDwMaqrqzVx4kRNmzbNY4BKsmc9vf/++zr//PMb/HOyWq26++67dcEFF2jRokVeA1SStHXrVv3973/XVVddpYMHDzboeAAAAADqEKQCAAAAENOOO+44dezY0dGeP3++amtrA95/6dKl2r17t6N9zjnnyGw2u/X78MMPddttt6miosJpe//+/TV27FiNHj1aXbt2dbpv165duvzyy/Xbb78FPJ9wMAzDcdtkMqlz58466aSTdMYZZ+jMM8/UiSeeqMzMTKd91q1bp6uvvtpncK6+KVOm6KuvvnLalpKSohNOOEFnnXWWTjrpJKWnp0uyB8JuuukmHTp0KKjHUVVVpT//+c/673//67Q9NTVVxx13nM4880yddtpp6tmzp9P9y5Yt0+WXX+72uwQAAAAQHMr9AQAAAIhpJpNJ5557rmbNmiVJ2rt3rxYvXqxRo0YFtL9r5tV5553n1mfTpk267777ZLPZHNtOOukk3XvvvW5rNv3444+6++67tWXLFknSwY
MHNWXKFM2dO7dRpf/uvPNO3XzzzZKk0aNHO7aPHTtWd955Z1Bjmc1mjRkzRuPGjdOIESOUlpbm1scwDC1ZskT//Oc/tW7dOknStm3b9MQTT+jee+/1Of6cOXO0cOFCRzs+Pl433XSTrr32WqeSgVarVXPmzNFjjz2mHTt2OH6HgXrooYe0ZMkSR7tDhw6aOnWqxo4dq/h45z+X161bp/vvv18rVqyQZC+D+PDDD+uBBx4I6pgAAAAA6pBJBQAAACDmuZb887VGU31lZWX64osvHO2BAwe6Zd1I0gMPPKDKykpH+w9/+INeeukltwCVJB1zzDF65513nMbZunWr05pXDZGZmalOnTqpU6dOTttTUlIc2z3982ThwoV6+umndcYZZ3gMUEn24N/w4cP17rvvasiQIY7tH3zwgYqLi73Os7KyUo8//rjTthkzZmjSpElua1olJCTo0ksv1ezZs5WQkOBzXFf/+9//9J///MfR7tevn+bOnaszzzzTLUAlSX379tVrr72mk046ybHtv//9r9avXx/wMQEAAAA4I0gFAAAAIOZ16dJFQ4cOdbS//PLLgNYc+uyzz5xKvrkGuyRpw4YNWrp0qaOdk5Ojhx56SHFx3v8cy8jI0GOPPebU59133/W5XlI4dejQIeC+ycnJmj59uqNdWVnpVsavvk8++cQp2HTmmWfq7LPP9nmM4447ThMmTAh4TpL0wgsvOM1x1qxZat26tc99LBaL/vnPfzpltL3xxhtBHRcAAABAHYJUAAAAACDnAFN1dbU+/fRTv/vUL/VnsVh05plnuvWZP3++U/uaa65Rq1at/I7dv39/p7J8+/bt03fffed3v+YoNzfXKSvr559/9tp3wYIFTu3rrrsuoGNcffXVHjOgPNm4caN+/PFHR/vCCy9U+/btA9o3KytLZ5xxhqP99ddfB7QfAAAAAHcEqQAAAABA9hJ89cvJ+Sv5V1BQ4BToOOWUU9SmTRu3fitXrnRq1w9w+HPWWWf5HKu5qaqq0r59+1RQUKAdO3Y4/av/s9m8ebPXMeoHsDp27Kh+/foFdOysrCwdffTRAfX94YcfnNpjx44NaL/Dhg0b5rhdVFSkHTt2BLU/AAAAALvAvmYGAAAAAC1camqqxowZo48//liSPSC0bds2de3a1WP/uXPnyjAMR9tTqT9J+uWXXxy3O3TooOzs7IDnNHjwYKf2r7/+GvC+4bB161Z9/PHH+uGHH7R+/fqA14TyVkqxqKhIBw4ccLT79+8f1HyOPPJILVu2zG+/FStWOLXT0tKCCjSZTCan9o4dO7yu3wUAAADAO4JUAAAAAPC78847zxGkkuyBqMmTJ3vsO2/ePMftI444QiNGjHDrU11drfLycke7S5cuQc2nXbt2SkpKUmVlpSRp//79Qe3fVA4ePKhHH31U77//vlOgLlClpaUet7sGuXJycoIaN9CSfbt373Zq+1vzyp+SkpJG7Q8AAADEKsr9AQAAAMDvTjjhBLVr187R/uijjzwGYVasWKGtW7c62meffbbH9ZBcgxepqalBzyktLc3reJFQUlKiq666Su+9916DAlSSvO536NAhp3Yga3fVF+jPN9Q/x/qBSAAAAACBI5MKAAAAAH4XFxenc845Ry+++KIk+7pTy5Yt03HHHefUb+7cuU5tb6X+XLmWiQtEQwNBTWXGjBn67bffHG2LxaJx48Zp+PDhys3NVU5OjlJSUpSYmKi4uLrvRV5xxRV+S/ElJCQ4tWtqaoKaW3V1dUD9rFZrUOP609x+RwAAAEC0IEgFAAAAAPX86U9/cgSpJHtAqn6Qqrq6Wp9++qmj3a9fP/Xp08fjWK1bt3Zqu2YKBaJ+aTzX8cJt165d+vDDDx3ttm3b6tVXX1WvXr387ltWVua3j+vj87Z2lTeB9m/Tpo3jdlpampYvXx7UcQAAAACEBuX+AAAAAKCeHj166KijjnK0FyxYoIqKCkd74cKFTsGQP/7xj17HslgsSk
lJcbS3b98e1Fx2797tWI9KkjIzM4PaP9S++eYbp6yhO+64I6AAlSTt2bPHb5/s7GyZzWZHe8OGDUHNb+PGjQH1q/9
"text/plain": [
"<Figure size 1000x600 with 2 Axes>"
]
},
"metadata": {
"image/png": {
"height": 537,
"width": 852
}
},
"output_type": "display_data"
}
],
"source": [
"# fault region\n",
"# NOTE: the 45000:60000 slice below must stay in sync with the slice used for fault_voltage_data.\n",
"fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(10,3 * 2))\n",
"\n",
"axs[0].plot(fault_voltage_data)\n",
"axs[0].set_title(\"voltage\")\n",
"fault_incidents = fault_data[45000:60000]\n",
"axs[1].scatter(range(len(fault_incidents)), fault_incidents)\n",
"axs[1].set_title(\"fault incidents\")"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f410d197490>]"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAB8wAAAVCCAYAAABgmwN+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdeZwcdZ3/8XfPPZPJHXInJIEkJCThBhERFVB3vVYWXVcU8ZafiuuFuigq3teyoiIoqyuuIuuNK6ioHHLJGcKZg5CQhFwkmVxzd9fvj8rMdFd/6+w6erpfz8eDB+me6v5+u+pb3/rW9/P9fitnWZYlAAAAAAAAAAAAAADqTEPWGQAAAAAAAAAAAAAAIAsEzAEAAAAAAAAAAAAAdYmAOQAAAAAAAAAAAACgLhEwBwAAAAAAAAAAAADUJQLmAAAAAAAAAAAAAIC6RMAcAAAAAAAAAAAAAFCXCJgDAAAAAAAAAAAAAOoSAXMAAAAAAAAAAAAAQF0iYA4AAAAAAAAAAAAAqEsEzAEAAAAAAAAAAAAAdYmAOQAAAAAAAAAAAACgLhEwBwAAAAAAAAAAAADUJQLmAAAAAAAAAAAAAIC6RMAcAAAAAAAAAAAAAFCXCJgDAAAAAAAAAAAAAOoSAXMAAAAAAAAAAAAAQF0iYA4AAAAAAAAAAAAAqEsEzAEAAAAAAAAAAAAAdakp7QQ3bdqklStXateuXRoYGNDUqVO1YMECLV++PO2shLJjxw499NBD2r59u/r7+zVt2jQdddRRWrhwYdZZAwAAAAAAAAAAAABEkFrA/MYbb9SVV16ptWvXGv8+d+5cXXDBBXrjG9+oXC5XcXrr1q3TP/3TP2lgYGD4vZNPPlk//vGPQ33P2rVr9eUvf1l333238vl82d+POuoovf/979dZZ51VcZ4BAAAAAAAAAAAAAOnJWZZlJZlAb2+vPvGJT+jGG28MtP1pp52mb33rWxozZkzkNC3L0hvf+EY9+OCDJe+HDZj/5je/0aWXXqq+vj7fbc877zxdeumlofMKAAAAAAAAAAAAAMhGojPMLcvShz/8Yf35z38efq+jo0Mve9nLtGzZMrW0tGjjxo266aabtGXLFknSnXfeqQ996EO68sor1djYGCnd66+/vixYHtZdd92lSy65RIODg5KkhoYGnXXWWTrhhBPU3Nys1atX63e/+526u7slST/5yU80adIkve9976soXQAAAAAAAAAAAABAOhKdYf6Tn/xEl1122fDrpUuX6rvf/a6mT59esl1/f7++/OUv6yc/+cnwex/72Mf0tre9LXSaO3fu1D/+4z9q3759mjhxoizLUldXl6TgM8wPHjyos88+W7t27ZIkjRs3Tt/97nd14oknlmy3fft2veMd79CaNWskSblcTv/7v/+rFStWhM43AAAAAAAAAAAAACBdDUl9cX9/v6666qrh15MmTdI111xTFiyXpJaWFl166aU688wzh9+7+uqrtX///tDpfv7zn9e+ffskSRdffHGkpd1/9KMfDQfLJekLX/hCWbBckqZNm6arrrpKHR0dkuwZ9Zdffnno9AAAAAAAAAAAAAAA6UssYH733Xdrx44dw6/f/va3a/LkyZ6f+chHPjL8766uLv3qV78KleZtt92mP/zhD5Kkk046Seecc06oz0tSoVAomel+/PHH66Uvfanr9rNmzdL5558//Pquu+7SunXrQqcLAAAAAAAAAAAAAEhXYgHze++9t+T1y172Mt/PLFiwQIsWLRp+/ac//Slwet3d3frsZz8rSWpubtanP/3pwJ8ttnLlSj333HPDr1/3utf5fubcc88tef2Xv/wlUtoAAAAAAAAAAAAAgPQkFjDfsmXL8L87Ojo0Z86cQJ8rDpg/+OCD2rt3b6DPXXHFFcNpXnDBBVq4cGGI3I647bbbSl6fdtppvp+ZM2eO5s6d6/odAAAAAAAAAAAAAIDqk1jAfOg54pI0duzYwJ8bN2
7c8L8LhYLWrl3r+5nHH39c1157rSR7ifT3vve9IXJaas2aNcP/njFjhqZNmxboc8cee6zxOwAAAAAAAAAAAAAA1SmxgHlra+vwv/v6+gJ/rre3t+T1U0895bl9Pp/XJz/5SeXzeUnSJZdcovb29hA5LbV+/frhfwedFS+pZIb5/v37S57fDgAAAAAAAAAAAACoPk1JffGkSZOG/713717t378/0EzzzZs3l7zetGmT5/Y//vGP9dhjj0mSzjzzTJ155pkRcmtOf+bMmYE/N2PGjJLXmzZt0tSpUyvKSxj9/f3q6uoaft3a2qrGxsbU0gcAAAAAAAAAAACApOTz+ZKJ2hMmTFBLS0vF35tYwHzp0qXD/7YsS/fcc4/OPvtsz88cPHhQjz76aMl7Bw4ccN3+2Wef1Te/+U1JUnt7uz75yU9WkGN7dvvg4ODw6+Ll4f04tz148GBFeQmrq6vLd3ABAAAAAAAAAAAAANSKOCYwJ7Yk+2mnnaZcLjf8+r//+79lWZbnZ6677jp1d3eXvOcVeL7sssuGt3/ve98baka4iTPt4mXl/Ti3dX4XAAAAAAAAAAAAAKC6JBYwnzdvns4444zh1/fff7+uuOIK1+3vv/9+fetb3yp73+355zfddJNuueUWSdLChQt1wQUXVJZhQ1rNzc2BP+uc7u98FjsAAAAAAAAAAAAAoLoktiS7JF188cW69957h2dbX3nllXriiSd0wQUX6Oijj1ZLS4s2btyoG264QT/60Y/U39+vlpYWFQqF4aXROzo6yr53//79+sIXviBJyuVy+sxnPhMquO3GOUt8YGAg8Gf7+/tLXre1tVWcnzCceW9oaCiZ4Y9o8vn88L95JjyQPs5BIHuch0C2OAeB7HEeAtniHASyx3kIZItzEBhhWZYKhcLw6zCrhXtJNGB+xBFH6Gtf+5o+9KEPDc/evuWWW4Znhjvlcjl97nOf0yc+8Ynh98aOHVu23de//nXt3LlTkvTa175WJ554Yiz5dQbn3Wa3mzi3NQX6k+SsJBcuXKjOzs5U81CLVq1apYGBATU3N2vFihVZZweoO5yDQPY4D4FscQ4C2eM8BLLFOQhkj/MQyBbnIDDiwIEDWr169fDruAaRJLYk+5CzzjpL//M//6MlS5Z4bjd58mR997vf1amnnloyMmDixIkl2z344IO6/vrrJUkTJkzQRz/60djy2tbWpqamkTEE+/btC/zZvXv3lrweM2ZMbPkCAAAAAAAAAAAAAMQv0RnmQ1asWKFf//rXuuOOO3Tbbbdp9erV6urqUnNzs2bNmqUzzjhDL3/5y9XZ2am777675LPOQPtll10my7IkSR/5yEc0adKkWPM6e/ZsbdiwQZL07LPPBv7c1q1bS17PmTMnzmwBAAAAAAAAAAAAAGKWSsBcspdbP/3003X66ad7bvfwww+XvF6+fHnJ682bNw//++qrr9b3vvc9z+/bvn17yXefffbZw6/POeccXXjhhSXbz58/fzhg/swzz3h+d7HibTs7OzV16tTAnwUAAAAAAAAAAAAApC+1gHlQt99++/C/Fy5cqClTprhuu2nTplDf3dfXVxLY3r17d9k2ixcvHn7G+rZt27R9+3ZNmzbN97tXrlw5/O9FixaFyhcAAAAAAAAAAAAAIH2JP8M8jI0bN+qBBx4Yfn3uueemnocXvvCFJa/vvPNO389s2rSpJHh/xhlnxJ4vAAAAAAAAAAAAAEC8qmqG+de//vXhf7e3t+vVr3512Tb3339/qO98yUteoi1btkiSTj75ZP34xz/23P64447T5MmTtWvXLknSz3/+c51zzjmen/nFL35R8vrMM88MlUcAAAAAAAAAAAAAQPqqZob5z372M/3pT38afn3RRRdp0qRJqeejoaFBb3zjG4dfP/jgg7r55ptdt9+yZYuuvfba4dfPe97ztHDhwkTzCAAAAAAAAAAAAACoXKIB84GBAV1xxRXatm2b6za9vb26/PLL9ZnPfG
b4vRUrVugtb3lLklnzdMEFF2jixInDry+55JKSpeKHbN++XRdeeKG6u7uH3/vgBz+YSh4BAAAAAAAAAAAAAJVJdEn
"text/plain": [
"<Figure size 1200x800 with 1 Axes>"
]
},
"metadata": {
"image/png": {
"height": 673,
"width": 998
}
},
"output_type": "display_data"
}
],
"source": [
"# test region (normal hold-out)\n",
"fig, ax = plt.subplots()\n",
"ax.plot(test_voltage_data)\n",
"ax.set(title='test region (normal)', xlabel='sample index within slice', ylabel='voltage');"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f410cf0fdc0>]"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAB8wAAAU3CAYAAAA7dJ8MAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd7wcVf3/8ffefm9ueq+kN5JQpYiASBG7IlhQJIKIfFEQEMvPjti+6pevheYXEVTs7QtfGxaKNJGaBEIKgfSe3CS3t/n9Mdl7d2fPtN0pe3Nfz8eDB9m9u3POzpw5c875zDmTsSzLEgAAAAAAAAAAAAAAg0xF2hkAAAAAAAAAAAAAACANBMwBAAAAAAAAAAAAAIMSAXMAAAAAAAAAAAAAwKBEwBwAAAAAAAAAAAAAMCgRMAcAAAAAAAAAAAAADEoEzAEAAAAAAAAAAAAAgxIBcwAAAAAAAAAAAADAoETAHAAAAAAAAAAAAAAwKBEwBwAAAAAAAAAAAAAMSgTMAQAAAAAAAAAAAACDEgFzAAAAAAAAAAAAAMCgRMAcAAAAAAAAAAAAADAoETAHAAAAAAAAAAAAAAxKBMwBAAAAAAAAAAAAAIMSAXMAAAAAAAAAAAAAwKBEwBwAAAAAAAAAAAAAMCgRMAcAAAAAAAAAAAAADEoEzAEAAAAAAAAAAAAAg1JV0glu3LhRzzzzjHbv3q2uri6NGzdOM2fO1OLFi5POSig7duzQ008/re3bt6uzs1Pjx4/X/PnzNWfOnLSzJknq7OxUU1NT3+va2lpVVlamlyEAAAAAAAAAAAAAiEhPT486Ojr6Xo8YMUI1NTUlbzexgPkf//hH3XTTTVqzZo3x79OmTdPSpUt1/vnnK5PJlJze2rVr9da3vlVdXV197x133HH68Y9/HGo7a9as0de+9jU9+uij6unpKfj7/Pnz9ZGPfERnnHFGyXkuRVNTkzZu3JhqHgAAAAAAAAAAAAAgKePGjSt5G7Evyd7e3q6rrrpKV111lWuwXJI2bNig6667ThdffLFaWlpKStOyLH32s5/NC5YX4/e//73e/va366GHHjIGyyXphRde0OWXX67rrruupLQAAAAAAAAAAAAAAMmKdYa5ZVm65ppr9Le//a3vvYaGBr32ta/VokWLVFNTo/Xr1+tPf/qTNm/eLEl6+OGHdfXVV+umm24qeknxX/ziF3rqqadKyvsjjzyiT3/60+ru7pYkVVRU6IwzztAxxxyj6upqrVq1Svfcc49aW1slSXfddZdGjRqlD3/4wyWlCwAAAAAAAAAAAABIRsayLCuujd911115M68XLlyom2++WRMmTMj7XGdnp772ta/prrvu6nvvE5/4hC666KLQae7cuVOvf/3rtX//fo0cOVKWZfU92zvokuwtLS0688wztXv3bknSsGHDdPPNN+vYY4/N+9z27dv1gQ98QKtXr5YkZTIZ/fKXv9SSJUtC57tU+/bt09q1a/teT506VQ0NDYnn41Czdu1a9fT0qLKyUrNnz047O8CgwzkIpI/zEEgX5yCQPs5DIF2cg0D6OA+BdHEOAv1aW1vzHlE9e/ZsDR8+vOTtxjbDvLOzU7fcckvf61GjRum2227T6NGjCz5bU1Ojz33uc9q2bZv+/ve/S5JuvfVWnXfeeRo6dGiodK+//nrt379fkvTxj39c3/ve9/oC5kHdeeedfcFySfryl79cECyXpPHjx+uWW27RG9/4RrW2tsqyLN1www364Q9/GCq9KDhn4zc0NKixsTHxfBxqKioq1NPTo4qKCvYnkALOQSB9nIdAujgHgfRxHgLp4hwE0sd5CKSLcxBwV+xq5U6xPcP80Ucf1Y4dO/peX3zxxcZgea6Pfexjff9uamrSb3/721BpPvDAA/rzn/8sSXrFK16hc845J9T3Jam3tzdvpvvRRx+ts846y/XzkydP1vve976+14888kjeTG8AAAAAAAAAAAAAQHmKLW
D++OOP571+7Wtf6/udmTNnau7cuX2v77333sDptba26otf/KIkqbq6Wp///OcDfzfXM888o127dvW9Pu+883y/c+655+a9zs6SBwAAAAAAAAAAAACUr9gC5ps3b+77d0NDg6ZOnRroe7kB86eeekr79u0L9L3vfOc7fWkuXbpUc+bMCZHbfg888EDe65NOOsn3O1OnTtW0adNctwEAAAAAAAAAAAAAKD+xBcyzzxGXFOo55MOGDev7d29vr9asWeP7neeff14/+tGPJNlLpF9++eUhcppv9erVff+eOHGixo8fH+h7Rx55pHEbAAAAAAAAAAAAAIDyFFvAvLa2tu/fHR0dgb/X3t6e9/rFF1/0/HxPT48+85nPqKenR5L06U9/WvX19SFymm/dunV9/w46K15S3gzzAwcO5D2/HQAAAAAAAAAAAABQfmILmI8aNarv3/v27dOBAwcCfW/Tpk15rzdu3Oj5+R//+Md67rnnJEmnn366Tj/99JA5dU9/0qRJgb83ceLEvNd++QYAAAAAAAAAAAAApCu2gPnChQv7/m1Zlh577DHf77S0tGjFihV57zU3N7t+fsuWLfr2t78tSaqvr9dnPvOZInNra29vV3d3d9/r3OXh/Tg/29LSUlJeAAAAAAAAAAAAAADxqoprwyeddJIymYwsy5Ik3XHHHTrjjDOUyWRcv/Ozn/1Mra2tee95BZ6vu+66vs9ffvnloWaEmzjTzl1W3o/zs85tJW3t2rWqqIjtfohBo6urq+//y5YtSzk3wODDOQikj/MQSBfnIJA+zkMgXZyDQPo4D4F0cQ4C/Xp7e2PZbmwB8+nTp+vUU0/V/fffL0l64okn9J3vfEdXXnml8fNPPPGEvvvd7xa87/b88z/96U+67777JElz5szR0qVLS86zM63q6urA362pqcl77XwWe9J6enr6nuuOaGQvSgDSwTkIpI/zEEgX5yCQPs5DIF2cg0D6OA+BdHEOAvGILWAuSR//+Mf1+OOP9822vummm7Ry5UotXbpUhx9+uGpqarR+/XrdfffduvPOO9XZ2amamhr19vb2LY3e0NBQsN0DBw7oy1/+siQpk8noC1/4QqjgthvnLPEwFU9nZ2fe67q6upLzU4rKykpmmEcgtwxEUcYAhMM5CKSP8xBIF+cgkD7OQyBdnINA+jgPgXRxDgL9ent7Y5kwHGvAfNasWfrGN76hq6++um/29n333dc3M9wpk8noS1/6kj71qU/1vTd06NCCz33zm9/Uzp07JUlve9vbdOyxx0aSX2dw3m12u4nzs6ZAf5Jmz56txsbGVPNwKFi2bJm6urpUXV2tJUuWpJ0dYNDhHATSx3kIpItzEEgf5yGQLs5BIH2ch0C6OAeBfs3NzVq1alXk2419CvIZZ5yhn/zkJ1qwYIHn50aPHq2bb75ZJ554Yt768yNHjsz73FNPPaVf/OIXkqQRI0bo2muvjSyvdXV1qqrqv4dg//79gb+7b9++vNdDhgyJLF8AAAAAAAAAAAAAgOjFOsM8a8mSJfrd736nhx56SA888IBWrVqlpqYmVVdXa/LkyTr11FN19tlnq7GxUY8++mjed52B9uuuu06WZUmSPvaxj2nUqFGR5nXKlCl6+eWXJUlbtmwJ/L2tW7fmvZ46dWqU2QIAAAAAAAAAAAAARCyRgLlkL7d+8skn6+STT/b83LPPPpv3evHixXmvN23a1PfvW2+9Vd///vc9t7d9+/a8bZ955pl9r8855xxddtlleZ+fMWNGX8B8w4YNntvOlfvZxsZGjRs3LvB3AQAAAAAAAAAAAADJSyxgHtSDDz7Y9+85c+ZozJgxrp/duHFjqG13dHTkBbb37NlT8Jl58+b1PWN927Zt2r59u8aPH++77Weeeabv33Pnzg2VLwAAAAAAAAAAAABA8mJ/hnkY69ev15NPPtn3+txzz008D6ecckre64cfftj3Oxs3bswL3p966qmR5wsAAAAAAAAAAAAAEK
2ymmH+zW9+s+/f9fX1evOb31zwmSeeeCLUNl/zmtdo8+bNkqTjjjtOP/7xjz0/f9RRR2n06NHavXu3JOlXv/qVzjn
"text/plain": [
"<Figure size 1200x800 with 1 Axes>"
]
},
"metadata": {
"image/png": {
"height": 667,
"width": 998
}
},
"output_type": "display_data"
}
],
"source": [
"# validation region\n",
"fig, ax = plt.subplots()\n",
"ax.plot(val_voltage_data)\n",
"ax.set(title='validation region', xlabel='sample index within slice', ylabel='voltage');"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f410d197670>]"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAB8wAAAU2CAYAAADwKEypAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzddZgk1aH38V+Prjtr7C7LwuIs7gEiJCF+k0uUCNE3TtxDEuI37kaEJMRdiAd3WxZdYVl319mRrvePmpqurj5W3T0zPcz38zw87HSfPnXKTh2vQhRFkQAAAAAAAAAAAAAAGGaaBjsBAAAAAAAAAAAAAAAMBjrMAQAAAAAAAAAAAADDEh3mAAAAAAAAAAAAAIBhiQ5zAAAAAAAAAAAAAMCwRIc5AAAAAAAAAAAAAGBYosMcAAAAAAAAAAAAADAs0WEOAAAAAAAAAAAAABiW6DAHAAAAAAAAAAAAAAxLdJgDAAAAAAAAAAAAAIYlOswBAAAAAAAAAAAAAMMSHeYAAAAAAAAAAAAAgGGJDnMAAAAAAAAAAAAAwLBEhzkAAAAAAAAAAAAAYFiiwxwAAAAAAAAAAAAAMCzRYQ4AAAAAAAAAAAAAGJboMAcAAAAAAAAAAAAADEt0mAMAAAAAAAAAAAAAhqWWwU4A6qOzs1M7duzo+7u9vV3Nzc2DlyAAAAAAAAAAAAAAqJOenh4dOHCg7+8JEyaora2t5njpMH+M2LFjh1avXj3YyQAAAAAAAAAAAACAATF16tSa42BJdgAAAAAAAAAAAADAsESHOQAAAAAAAAAAAABgWGJJ9seI9vb2sr9nz56tUaNGDVJqHjuWLVumnp4eNTc36/DDDx/s5ABAEPIuAEMReReAoYi8C8BQRN4FYKgi/wKwb9++sldUZ/tHq0WH+WNEc3Nz2d+jRo3SmDFjBik1jx1NTU3q6elRU1MTxxPAkEHeBWAoIu8CMBSRdwEYisi7AAxV5F8AsrL9o9ViSXYAAAAAAAAAAAAAwLBEhzkAAAAAAAAAAAAAYFiiwxwAAAAAAAAAAAAAMCzRYQ4AAAAAAAAAAAAAGJboMAcAAAAAAAAAAAAADEt0mAMAAAAAAAAAAAAAhiU6zAEAAAAAAAAAAAAAwxId5gAAAAAAAAAAAACAYYkOcwAAAAAAAAAAAADAsESHOQAAAAAAAAAAAABgWKLDHAAAAAAAAAAAAAAwLNFhDgAAAAAAAAAAAAAYlugwBwAAAAAAAAAAAAAMS3SYAwAAAAAAAAAAAACGJTrMAQAAAAAAAAAAAADDEh3mAAAAAAAAAAAAAIBhiQ5zAAAAAAAAAAAAAMCwRIc5AAAAAAAAAAAAAGBYosMcAAAAAAAAAAAAADAs0WEOAAAAAAAAAAAAABiW6DAHAAAAAAAAAAAAAAxLdJgDAAAAAAAAAAAAAIYlOswBAAAAAAAAAAAAAMMSHeYAAAAAAAAAAAAAgGGJDnMAAAAAAAAAAAAAwLBEhzkAAAAAAAAAAAAAYFiiwxwAAAAAAAAAAAAAMCzRYQ4AAAAAAAAAAAAAGJboMAcAAAAAAAAAAAAADEt0mAMAAAAAAAAAAAAAhiU6zAEAAAAAAAAAAAAAwxId5gAAAAAAAAAAAACAYYkOcwAAAAAAAAAAAADAsESHOQAAAAAAAAAAAABgWKLDHAAAAAAAAAAAAAAwLNFhDgAAAAAAAAAAAAAYlugwBwAAAAAAAAAAAAAMS3SYAwAAAAAAAAAAAACGJTrMAQAAAAAAAAAAAADDEh3mAAAAAAAAAAAAAIBhiQ5zAAAAAAAAAAAAAMCwRIc5AAAAAAAAAAAAAGBYosMcAAAAAAAAAAAAADAs0WEOAAAAAAAAAAAAABiW6DAHAAAAAAAAAAAAAAxLdJgDAAAAAAAAAAAAAIYlOswBAAAAAAAAAAAAAMMSHeYAAAAAAAAAAAAAgGGJDnMAAAAAAAAAAAAAwLBEhzmAPqs6Il25PtK9e6LBTg
oAAAAAAAAAAADQ71oGOwEAGsOK/ZFOvlPa0S21FqSrF0R60qTCYCcLAAAAAAAAAAAA6DfMMAcgSfrYirizXJK6IunVDw9qcgAAAAAAAAAAAIB+R4c5AEnSTzaU/73qwOCkAwAAAAAAAAAAABgodJgDAAAAAAAAAAAAAIYlOswBAAAAAAAAAAAAAMMSHeYAAAAAAAAAAAAAgGGJDnMAAAAAAAAAAAAAwLBEhzkAAAAAAAAAAAAAYFiiwxwAAAAAAAAAAAAAMCzRYQ4AAAAAAAAAAAAAGJboMAcgSYoGOwEAAAAAAAAAAADAAKPDHAAAAAAAAAAAAAAwLNFhDkCSVBjsBAAAAAAAAAAAAAADjA5zAAAAAAAAAAAAAMCwRIc5AAAAAAAAAAAAAGBYosMcAAAAAAAAAAAAADAs0WEOAAAAAAAAAAAAABiW6DAHAAAAAAAAAAAAAAxLdJgDAAAAAAAAAAAAAIYlOswBAAAAAAAAAEF+vCHSobdEOvWOSPfuiQY7OQAgSfrj5kiH3xLp+Nsj3bSDvAlAPnSYAwAAAAAAAAC8dnRFeu3D0soO6e490tuXDnaKAEDqLEa65GFpeYf0wF7p9UsGO0UAhho6zAEAAAAAAAAAXldukLpSEzev3TFoSQGAPtdsl3Z2l/5+YK+0v4dZ5gDC0WEOAAAAAAAAAPDqpP8JQAPqJm8CUCM6zAEAAAAAAAAAAAAAwxId5gAAAAAAAAAAr4hZnAAaEFkTgFrRYQ4AAAAAAAAAAIDHjMJgJwDAkEKHOQAAAAAAAAAAAABgWKLDHAAAAAAAAAAAAEMSS7IDqBUd5gAAAAAAAAAwjO3qjrR4X6TOorvbiU4pAADwWESHOQAAAAAAAAAMU/ftiXT0bdLRt0nn3C3t6KJbHMDQwvvKAdSKDnMAAAAAAAAAGKY+tFxa3xn/+67d0g/W28PSKQWgETHMB0Ct6DAHAAAAAAAAgGHqz1vL//7oikFJBgAAwKChwxyAJEbhAQAAAAAAwI32IwBDRYElMQDkQIc5AAAAAAAAAAAAHjMiRvgAyIEOcwCSeAcVAAAAAAAAAGDooW8cQK3oMAcgiUIFAAAAAAAA3G1EtB8BGCpYkh1AHnSYAwAAAAAAAAAAYEiibxxAregwByCJQgUAAAAAAABoIwIw9LD6BYBatQx2AgA0BgoVAAAAjaUYRfrFJml/j3TxNGlEM83XAACg/7naiCiNAACAxyI6zAEAAACgAf2/xdL318f//ulG6ZqTBjc9AAAATLgAMFREZFgAcmBJdgCSGCEMAADQSHqiqK+zXJKu2yHdt4cWHwAAAADIonMcQK3oMAcgiRHCAAAAjaTbUDhbvn/g0wEAAAAAAPBYR4c5AAAAAAwBDHAEAAAAgEoFlk8FUCM6zAFIYkl2AACARkLZDAAANCIG8AFoRCzJDqBWdJgDAAAAAAAAAAAAAIYlOswBAAAAAAAAAADwmMGkcwB50GEOAAAAAEMADT4AAAAAAAD1R4c5AAAAAAAAAEAS7wIGMPSQbQGoFR3mAAAAAAAAGLZ+siHSRfdH+tyqSD30FAIAAADDTstgJwAAAAAAAAAYDNdtj/SKh+J//26zNLFFes3MwU0TMNgKhcFOAQDkQ7YFoFbMMAcAAACAIYA5jwBQf/9vcfnfr1tsDgcgxiIMABoRWROAWtFhDgAAAAANhgYfABgYS/YPdgoAAEB/oE4FIA86zAEAAAAAAAAAkphFDgAAhh86zAEAAAAAAAAAXvSlA2hE5E0AakWHOQBJFCoAAAAaCTO7AADAYCkUBjsFAAAAA4sOcwAAAAAAAACAJAbuAQCA4YcOcwCSJAYPAwAANDYarwEAAAAAAOqPDnMAkliSHQAAAAAAAADw2EB7N4A86DAHAAAAgAZD4w4AAGhElFEAAMBjER3mACSxJDsAAAAAAAAAAACGHzrMAUhihDAAAECjo7wGAAAAAJWoKwGoFR3mAAAAANBgaPABAACNiDIKAAB4LK
LDHIAklmQHAAAAAAAAAADA8EOHOQBJjBAGAAAAAACAGxMuAAwVEQ3eAHKgwxwAAAAAhgDaewAAwEBwlTkojwAAgMc
"text/plain": [
"<Figure size 1200x800 with 1 Axes>"
]
},
"metadata": {
"image/png": {
"height": 667,
"width": 998
}
},
"output_type": "display_data"
}
],
"source": [
"# train region\n",
"fig, ax = plt.subplots()\n",
"ax.plot(train_voltage_data)\n",
"ax.set(title='train region', xlabel='sample index within slice', ylabel='voltage');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data Processing"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(55774,)"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Number of samples in the training split.\n",
"train_voltage_data.to_numpy().shape"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"111.548"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# How many 500-sample windows fit in the training split (computed, not copied from an output).\n",
"len(train_voltage_data) / 500"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(25000,)"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# NOTE(review): the saved output (25000,) predates the current 85000:100000 slice, which holds 15000 samples.\n",
"test_voltage_data.to_numpy().shape"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"50.0"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# How many 500-sample windows fit in the test split; the previously hard-coded 25000\n",
"# no longer matches the 85000:100000 slice, so derive it from the data.\n",
"len(test_voltage_data) / 500"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"# Exploratory: cut the raw (un-normalised) training series into consecutive, non-overlapping\n",
"# windows of segment_size samples, each a float tensor shaped (segment_size, 1).\n",
"# create_dataset further down repeats this with normalisation.\n",
"data = train_voltage_data.to_numpy()\n",
"segment_size = 100\n",
"segments = [ torch.tensor(data[i:i + segment_size]).unsqueeze(1).float() for i in range(0, len(data), segment_size) ]\n"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"# Drop the trailing partial window so every segment has exactly segment_size rows.\n",
"if (segments[-1].shape[0] != segment_size):\n",
" segments.pop()"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([557, 100, 1])"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# (n_segments, segment_size, n_features)\n",
"torch.stack(segments).shape"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"# Normalisation statistics come from the training split only and are reused for every split.\n",
"mean = np.mean(train_voltage_data.to_numpy())\n",
"stdev = np.std(train_voltage_data.to_numpy())\n",
"\n",
"def create_dataset(series, segment_size, mean, stdev):\n",
"    \"\"\"Z-normalise a series and cut it into non-overlapping windows.\n",
"\n",
"    Args:\n",
"        series: object with ``.to_numpy()`` (a polars Series here) holding raw values.\n",
"        segment_size: number of samples per window.\n",
"        mean, stdev: normalisation statistics (from the training split).\n",
"\n",
"    Returns:\n",
"        (segments, seq_len, n_features): list of float32 tensors shaped\n",
"        (segment_size, 1), with seq_len == segment_size and n_features == 1.\n",
"        A trailing partial window is dropped.\n",
"\n",
"    Raises:\n",
"        ValueError: if segment_size < 1 or the series is shorter than one window.\n",
"    \"\"\"\n",
"    if segment_size < 1:\n",
"        raise ValueError(f'segment_size must be >= 1, got {segment_size}')\n",
"    data = (series.to_numpy() - mean) / stdev\n",
"    # Only start windows that fit completely, so no partial segment is ever built.\n",
"    segments = [\n",
"        torch.tensor(data[i:i + segment_size]).unsqueeze(1).float()\n",
"        for i in range(0, len(data) - segment_size + 1, segment_size)\n",
"    ]\n",
"    if not segments:\n",
"        raise ValueError(\n",
"            f'series of length {len(data)} is shorter than segment_size={segment_size}'\n",
"        )\n",
"    # Every segment has the same shape by construction; no need to stack them all to read it.\n",
"    seq_len, n_features = segments[0].shape\n",
"    return segments, seq_len, n_features\n"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"# Every split is normalised with the training mean/stdev; seq_len and n_features come from train.\n",
"segment_size = 100\n",
"train_dataset, seq_len, n_features = create_dataset(train_voltage_data, segment_size, mean, stdev)\n",
"val_dataset, _, _ = create_dataset(val_voltage_data, segment_size, mean, stdev)\n",
"test_normal_dataset, _, _ = create_dataset(test_voltage_data, segment_size, mean, stdev)\n",
"test_anomaly_dataset, _, _ = create_dataset(fault_voltage_data, segment_size, mean, stdev)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## LSTM encoder-decoder (recurrent autoencoder)"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"# Use the GPU when CUDA is available, otherwise the CPU.\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"class Encoder(nn.Module):\n",
"    \"\"\"LSTM encoder that compresses one sequence into a single embedding vector.\n",
"\n",
"    Input: one sequence shaped (seq_len, n_features); it is reshaped to a batch of 1.\n",
"    Output: the final hidden state of ``rnn2``, shaped (1, embedding_dim).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, seq_len, n_features, embedding_dim=64):\n",
"        super(Encoder, self).__init__()\n",
"\n",
"        self.seq_len, self.n_features = seq_len, n_features\n",
"        self.embedding_dim, self.hidden_dim = embedding_dim, 2 * embedding_dim\n",
"\n",
"        # Two stacked LSTMs: n_features -> 2 * embedding_dim -> embedding_dim.\n",
"        self.rnn1 = nn.LSTM(\n",
"            input_size=n_features,\n",
"            hidden_size=self.hidden_dim,\n",
"            num_layers=1,\n",
"            batch_first=True\n",
"        )\n",
"\n",
"        self.rnn2 = nn.LSTM(\n",
"            input_size=self.hidden_dim,\n",
"            hidden_size=embedding_dim,\n",
"            num_layers=1,\n",
"            batch_first=True\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        # Treat the input as a batch containing a single sequence.\n",
"        x = x.reshape((1, self.seq_len, self.n_features))\n",
"\n",
"        x, (_, _) = self.rnn1(x)\n",
"        x, (hidden_n, _) = self.rnn2(x)\n",
"\n",
"        # hidden_n is (num_layers=1, batch=1, embedding_dim). Its leading size is the batch,\n",
"        # not n_features, so reshaping to (n_features, embedding_dim) only worked when\n",
"        # n_features == 1; (1, embedding_dim) is correct for any number of features.\n",
"        return hidden_n.reshape((1, self.embedding_dim))"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"class Decoder(nn.Module):\n",
"    \"\"\"LSTM decoder that expands one embedding back into a (seq_len, n_features) sequence.\n",
"\n",
"    Input: an embedding shaped (1, input_dim), e.g. the output of ``Encoder``.\n",
"    Output: reconstruction shaped (seq_len, n_features).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, seq_len, input_dim=64, n_features=1):\n",
"        super(Decoder, self).__init__()\n",
"\n",
"        self.seq_len, self.input_dim = seq_len, input_dim\n",
"        self.hidden_dim, self.n_features = 2 * input_dim, n_features\n",
"\n",
"        # Two stacked LSTMs: input_dim -> input_dim -> 2 * input_dim, then a linear head.\n",
"        self.rnn1 = nn.LSTM(\n",
"            input_size=input_dim,\n",
"            hidden_size=input_dim,\n",
"            num_layers=1,\n",
"            batch_first=True\n",
"        )\n",
"\n",
"        self.rnn2 = nn.LSTM(\n",
"            input_size=input_dim,\n",
"            hidden_size=self.hidden_dim,\n",
"            num_layers=1,\n",
"            batch_first=True\n",
"        )\n",
"\n",
"        self.output_layer = nn.Linear(self.hidden_dim, n_features)\n",
"\n",
"    def forward(self, x):\n",
"        # Feed the same embedding at every time step: (1, input_dim) -> (1, seq_len, input_dim).\n",
"        # Repeating along the feature axis as well (the old repeat(seq_len, n_features)) built an\n",
"        # n_features-sized batch and broke the reshape below whenever n_features > 1.\n",
"        x = x.repeat(self.seq_len, 1)\n",
"        x = x.reshape((1, self.seq_len, self.input_dim))\n",
"\n",
"        x, _ = self.rnn1(x)\n",
"        x, _ = self.rnn2(x)\n",
"        x = x.reshape((self.seq_len, self.hidden_dim))\n",
"\n",
"        # Per-time-step projection back to feature space: (seq_len, n_features).\n",
"        return self.output_layer(x)"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"class RecurrentAutoencoder(nn.Module):\n",
"    \"\"\"Sequence autoencoder: Encoder -> single embedding -> Decoder -> reconstruction.\"\"\"\n",
"\n",
"    def __init__(self, seq_len, n_features, embedding_dim=64):\n",
"        super().__init__()\n",
"        # NOTE: relies on the notebook-level `device` defined in an earlier cell.\n",
"        self.encoder = Encoder(seq_len, n_features, embedding_dim).to(device)\n",
"        self.decoder = Decoder(seq_len, embedding_dim, n_features).to(device)\n",
"\n",
"    def forward(self, x):\n",
"        embedding = self.encoder(x)\n",
"        return self.decoder(embedding)"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"model = RecurrentAutoencoder(seq_len, n_features, 128)\n",
"model = model.to(device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"def train_model(model, train_dataset, val_dataset, n_epochs):\n",
" optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
" criterion = nn.L1Loss(reduction='sum').to(device)\n",
" history = dict(train=[], val=[])\n",
"\n",
" best_model_wts = copy.deepcopy(model.state_dict())\n",
" best_loss = 10000.0\n",
" \n",
" for epoch in range(1, n_epochs + 1):\n",
" model = model.train()\n",
"\n",
" train_losses = []\n",
" for seq_true in train_dataset:\n",
" optimizer.zero_grad()\n",
"\n",
" seq_true = seq_true.to(device)\n",
" seq_pred = model(seq_true)\n",
"\n",
" loss = criterion(seq_pred, seq_true)\n",
"\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" train_losses.append(loss.item())\n",
"\n",
" val_losses = []\n",
" model = model.eval()\n",
" with torch.no_grad():\n",
" for seq_true in val_dataset:\n",
"\n",
" seq_true = seq_true.to(device)\n",
" seq_pred = model(seq_true)\n",
"\n",
" loss = criterion(seq_pred, seq_true)\n",
" val_losses.append(loss.item())\n",
"\n",
" train_loss = np.mean(train_losses)\n",
" val_loss = np.mean(val_losses)\n",
"\n",
" history['train'].append(train_loss)\n",
" history['val'].append(val_loss)\n",
"\n",
" if val_loss < best_loss:\n",
" best_loss = val_loss\n",
" best_model_wts = copy.deepcopy(model.state_dict())\n",
"\n",
" print(f'Epoch {epoch}: train loss {train_loss} val loss {val_loss}')\n",
"\n",
" model.load_state_dict(best_model_wts)\n",
" return model.eval(), history"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1: train loss 83876.04766045781 val loss 76590.19979166667\n",
"Epoch 2: train loss 69409.47180066202 val loss 62160.938072916666\n",
"Epoch 3: train loss 55019.510905240124 val loss 47822.004479166666\n",
"Epoch 4: train loss 40701.095496241025 val loss 33518.067044270836\n",
"Epoch 5: train loss 26405.11348462747 val loss 19228.23103515625\n",
"Epoch 6: train loss 12124.039359676559 val loss 5028.084357096354\n",
"Epoch 7: train loss 1190.5766426237096 val loss 489.6685514322917\n",
"Epoch 8: train loss 445.39513634180037 val loss 488.57960042317706\n",
"Epoch 9: train loss 445.98794736468386 val loss 489.6590458170573\n",
"Epoch 10: train loss 443.86239300767227 val loss 489.81301350911457\n"
]
}
],
"source": [
"model, history = train_model(\n",
" model, \n",
" train_dataset, \n",
" val_dataset, \n",
" n_epochs=10\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAACC0AAAWZCAYAAABTun0bAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd3hU1dr+8Xtn0khCSAJJgNA7SAlSFLvSFBsiduEoFkRBAcVysOKxcOyKIliOgGABkapItaBI772XEFII6X1m3j/yMmanTiDJJJnv57q8frOerFnzZEh23t/Z96xl2O12uwAAAAAAAAAAAAAAACqZh6sbAAAAAAAAAAAAAAAA7onQAgAAAAAAAAAAAAAAcAlCCwAAAAAAAAAAAAAAwCUILQAAAAAAAAAAAAAAAJcgtAAAAAAAAAAAAAAAAFyC0AIAAAAAAAAAAAAAAHAJQgsAAAAAAAAAAAAAAMAlCC0AAAAAAAAAAAAAAACXILQAAAAAAAAAAAAAAABcgtACAAAAAAAAAAAAAABwCUILAAAAAAAAAAAAAADAJQgtAAAAAAAAAAAAAAAAlyC0AAAAAAAAAAAAAAAAXILQAgAAAAAAAAAAAAAAcAlCCwAAAAAAAAAAAAAAwCUILQAAAAAAAAAAAAAAAJcgtAAAAAAAAAAAAAAAAFyC0AIAAAAAAAAAAAAAAHAJQgsAAAAAAAAAAAAAAMAlCC0AAAAAAADUUG3btnX89+yzz7q6HX300Uemnk6cOOHqloAizZ071/SzunbtWle3BAAAANRYnq5uAAAAAABQvk6cOKHevXs7xhEREVq5cqULOwIAAAAAAACKxk4LAAAAAADALQ0ZMsTxKeprrrnG1e0AAAAAAOCWCC0AAAAAAAAAAAAAAACXILQAAAAAAAAAAAAAAABcwtPVDQAAAAAAAKBi7N2719UtmIwaNUqjRo1ydRsAAAAAgCqEnRYAAAAAAAAAAAAAAIBLEFoAAAAAAAAAAAAAAAAuwfEQAAAAAIBzlpmZqQ0bNig6OlpnzpyRj4+P6tatqw4dOqhFixbntfahQ4e0Z88excXFKT09XRaLRf7+/qpfv76aNWumFi1ayDAMp9ez2Wzat2+f9u3bp4SEBKWnp8vLy0sBAQFq2LChmjdvriZNmpxXz2VRke9dZYmPj9eWLVsUFxenpKQk1a5dW6GhoerevbtCQkLK7XWsVqs2b96sqKgoxcXFSZI6d+6snj17lttrlLdz6TkhIUH79u3T0aNHlZycLJvNpsDAQIWFhalr167l+p6ei+3bt+vQoUOKiYmRr6+vwsPD1bNnTwUHB7ukn6SkJG3cuFHR0dFKS0tTcHCw2rZtq06dOpXp2lBQdna21q1bp+PHjyslJUV169ZVo0aN1K1bN3l6uv5/Sjt7bTx9+rTS09MVHByshg0bqnv37vL19S231zlz5ow2bdqkU6dOKS0tTaGhoWrXrp3at29/3mtXxvXParVq586dOnLkiBISEpSRkSF/f39FRESoXbt2ioiIOK/1MzIytH79ekVHRysxMVFBQUFq2bKlIiMjz/nnJCEhQbt27dKxY8eUmpoqq9WqWrVqOX4G27ZtKz8/v/PqGwAAAKiKXP//0wIAAAAAVDuHDx/W+++/r19//VWZmZlFzmncuLHuu+8+3XnnnU7fwMnJydGMGTP0zTff6NixYyXOrV27tnr16qU777xTl156abHz0tLS9Nlnn+mHH35QbGxsiWuGhITosssu05AhQ9S5c2enei6r8nzvnnnmGc2bN88xnj9/vtq1a1emft5++2199tlnjvHnn3+uyy+/vNj5drtdP//8s7744gvt3LlTdru90BzDMNStWzeNGTNG3bt3L7WHEydOqHfv3o7xyJEjNWrUKGVlZenjjz/W3LlzHTf+z+rdu/c5hRbmzp2r5557rlA9KipKbdu2LfZ506dP10UXXVTuPe/YsUOLFy/WH3/8of3795fYe+
fOnfXwww+rb9++Jc7LL//3dMstt+jNN98s89x58+Zp6tSpOnjwYKHneHh46LrrrtO4cePUoEGDUvv56KOPNGnSJMd4xYoVatSoUZnmxsbG6q233tIvv/yirKysQs+rX7++xowZo4EDB5baT34ZGRmaNGmSvvvuO6WkpBT6emhoqO6++24NHz5cFotFzz77rH788UfH1/fu3Vum1yuLzMxMTZs2Td9//71OnDhR5BwfHx/17dtXY8aMKfY9za/g78LZn/GjR4/q/fff1/Lly5WdnV3oec2bN9fYsWPVr1+/Mn8fFfW3I7+jR49q8uTJWr58eZH/jmc1bdpU1157re666y6nfnbPSk1N1bvvvqv58+crNTW10NeDgoI0YsQIDRkyRBaLxak1169fr8mTJ2vNmjWy2WzFzrNYLOrQoYP69eunYcOGVYkQDQAAAFAeOB4CAAAAAFAm06ZN04033qglS5YUe9NJko4fP65XX31Vt9xyi6Kjo0tdNyEhQbfffrsmTpxYamBBklJSUrR06VLNnDmz2DlHjx7VjTfeqMmTJ5caWDjbw4IFC7Rw4cJS556L8n7vbrnlFtM4f4DBGTabzfS9hoWF6ZJLLil2flxcnO666y6NGTNGO3bsKDKwIOUFGzZs2KB77rlHr7/+erHzShIVFaXbbrtNU6ZMKXTzv6oqa8/r16/Xrbfeqi+//LLUwIIkbdu2TSNHjtTYsWNL/PkpL9nZ2Xr66af1zDPPFBlYkPJ+hhYvXqzbbrtNBw4cqPCe1q1bp4EDB2rBggVFBhYk6dSpU3rmmWf02muvOb1udHS0Bg4cqM8//7zYG91xcXH64IMPNGzYMCUnJ59T/+di06ZN6tu3r959991iAwuSlJWVpUWLFum6667TokWLzum11qxZo4EDB+qnn34qMrAg5QUPRo0apfHjx5d4g72givrbcZbdbtf777+vAQMG6McffywxsCDl/X2YMmWKvvzyS6dfY9++fbrllls0c+bMIgMLkpSYmKg33nhDY8aMUW5ubqlrvvfee7r33nv1559/lvp+Wq1Wbd++Xe+8847S09Od7hsAAACo6ojjAgAAAACc9sknn+iDDz4w1SwWizp16qSGDRsqIyNDu3fv1qlTpxxf37dvn+666y7NmjVLDRs2LHJdu92ukSNHateuXaZ6vXr11KZNGwUHB8swDKWmpurYsWM6duxYqTeDsrKy9NBDDykqKspUb9iwoVq2bKk6derIZrMpJSVFR48e1fHjx8/p5rqzKuK9u+iiixQREeH4HhcuXKhx48Y5/eneNWvWmF7vpptuKva5x44d03333Vfo/QwLC1O7du0UGBio1NRU7dy503TDftq0aUpLSyvTDeSsrCyNHDnS8cl1Hx8fdenSRaGhoUpLS6uUm+NldS49F7xB6eXlpRYtWqhBgwYKCAhQTk6OYmNjtXfvXtMNysWLF8tut+u9996r0O9pwoQJmj9/vqO3jh07qn79+srNzXUcY3FWXFycnnjiCf3444/y9vaukH4OHjyosWPHOm4W169fX+3bt5e/v7/i4uK0ZcsWU5Bh+vTp6tixo26++eYS101ISNDQoUMLhaVCQ0PVoUMHBQQEKCYmRlu3blVOTo7+/vtvjR8/Xv7+/uX/TRawcuVKjR492vR9GYahFi1aqGnTpvLz89Pp06e1detWx89Idna2nnrqKeXm5pZpt4lDhw7prbfecqwTFBSkTp06KTAwUHFxcdq8ebNycnIc8+fMmSOLxaIJEyaUunZF/e04y2q1asyYMfrll18Kfa158+Zq2rSpAgIClJaWpiNHjujIkSNlvt7Hx8fr6aefdvQYEhKiCy64QHXq1FFiYqI2b96stLQ0x/xffvlFU6dO1aOPPlrsmrNnz9ann35qqnl7e6t9+/aqX7++fHx8lJ6ertOnT2v//v3FBiUAAACA6o7QAgAAAADAKX///bc+/PBDU+2mm27S008/rdDQUEfNbrdrxYoVeu
WVVxy7G0RHR+vpp5/W9OnT5eFReNO/3377TRs3bnSMmzZtqldeeUUXX3xxkWfTp6en66+//tLixYtltVqL7HfOnDm
"text/plain": [
"<Figure size 1200x800 with 1 Axes>"
]
},
"metadata": {
"image/png": {
"height": 716,
"width": 1046
}
},
"output_type": "display_data"
}
],
"source": [
"ax = plt.figure().gca()\n",
"\n",
"ax.plot(history['train'])\n",
"ax.plot(history['val'])\n",
"plt.ylabel('Loss')\n",
"plt.xlabel('Epoch')\n",
"plt.legend(['train', 'test'])\n",
"plt.title('Loss over training epochs')\n",
"plt.show();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save the model\n"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [],
"source": [
"MODEL_PATH = 'model.pth'\n",
"\n",
"torch.save(model, MODEL_PATH)"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [],
"source": [
    "# Reload the saved model (uncomment to use). torch.save(model, MODEL_PATH)\n",
    "# above pickled the whole module, so torch.load returns the nn.Module itself;\n",
    "# pass weights_only=False on PyTorch versions where that flag defaults to True.\n",
    "# model = torch.load(MODEL_PATH, weights_only=False)\n",
    "# model = model.to(device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Check reconstruction error"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
"def predict(model, dataset):\n",
" predictions, losses = [], []\n",
" criterion = nn.L1Loss(reduction='sum').to(device)\n",
" with torch.no_grad():\n",
" model = model.eval()\n",
" for seq_true in dataset:\n",
" seq_true = seq_true.to(device)\n",
" seq_pred = model(seq_true)\n",
"\n",
" loss = criterion(seq_pred, seq_true)\n",
"\n",
" predictions.append(seq_pred.cpu().numpy().flatten())\n",
" losses.append(loss.item())\n",
" return predictions, losses"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[478.3055419921875,\n",
" 324.8333740234375,\n",
" 323.153564453125,\n",
" 426.354248046875,\n",
" 400.283935546875,\n",
" 310.18212890625,\n",
" 423.1041259765625,\n",
" 409.283935546875,\n",
" 288.3319091796875,\n",
" 6717.36279296875,\n",
" 769.3988037109375,\n",
" 375.0321044921875,\n",
" 436.6552734375,\n",
" 294.88232421875,\n",
" 364.153564453125,\n",
" 446.8544921875,\n",
" 384.1339111328125,\n",
" 318.0609130859375,\n",
" 423.754150390625,\n",
" 378.783935546875,\n",
" 289.281982421875,\n",
" 6582.95556640625,\n",
" 1024.248779296875,\n",
" 359.783935546875,\n",
" 336.283935546875,\n",
" 332.783935546875,\n",
" 347.783935546875,\n",
" 496.783935546875,\n",
" 614.783935546875,\n",
" 614.783935546875,\n",
" 614.783935546875,\n",
" 546.283935546875,\n",
" 355.183349609375,\n",
" 420.35400390625,\n",
" 350.0330810546875,\n",
" 383.153564453125,\n",
" 415.7130126953125,\n",
" 353.574462890625,\n",
" 374.153564453125,\n",
" 469.783935546875,\n",
" 330.7822265625,\n",
" 459.854736328125,\n",
" 371.48388671875,\n",
" 331.074462890625,\n",
" 466.655517578125,\n",
" 308.3328857421875,\n",
" 414.7042236328125,\n",
" 387.2337646484375,\n",
" 332.153564453125,\n",
" 489.4053955078125,\n",
" 336.7330322265625,\n",
" 274.153564453125,\n",
" 274.153564453125,\n",
" 274.153564453125,\n",
" 447.9044189453125,\n",
" 523.783935546875,\n",
" 374.783935546875,\n",
" 319.8818359375,\n",
" 459.3052978515625,\n",
" 345.68359375,\n",
" 334.153564453125,\n",
" 462.3057861328125,\n",
" 302.182373046875,\n",
" 420.954345703125,\n",
" 355.8333740234375,\n",
" 376.553955078125,\n",
" 428.283935546875,\n",
" 307.8818359375,\n",
" 424.5042724609375,\n",
" 370.28369140625,\n",
" 330.153564453125,\n",
" 459.85498046875,\n",
" 345.3336181640625,\n",
" 336.153564453125,\n",
" 449.6048583984375,\n",
" 327.6331787109375,\n",
" 357.153564453125,\n",
" 453.45556640625,\n",
" 308.7325439453125,\n",
" 425.954345703125,\n",
" 410.283935546875,\n",
" 309.48193359375,\n",
" 427.754150390625,\n",
" 435.783935546875,\n",
" 306.8828125,\n",
" 387.153564453125,\n",
" 445.9556884765625,\n",
" 313.2822265625,\n",
" 429.8544921875,\n",
" 369.6336669921875,\n",
" 337.153564453125,\n",
" 475.6055908203125,\n",
" 310.482421875,\n",
" 430.354248046875,\n",
" 397.283935546875,\n",
" 297.5819091796875,\n",
" 434.7042236328125,\n",
" 402.783935546875,\n",
" 292.5321044921875,\n",
" 400.3038330078125]"
]
},
"execution_count": 82,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"_, losses = predict(model, test_normal_dataset)\n",
"losses[:100]"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Axes: ylabel='Density'>"
]
},
"execution_count": 79,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAACDkAAAU2CAYAAACVtf+yAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzde3RddZ03/s/JadI2vaQptOkVSq3U4ab+hswICPMIzEJuxREdxRlxZOSypKAyFHSNl0eXjqDUC+jAIMgaGAbXAyMiygMCDmgpDxcrFnSmnVpaSugN0qZNb7md3x+snubsc5KeJCc5Ocnr9Q/57uz9PZ99Aiestd/5fFKZTCYTAAAAAAAAAADDXFW5CwAAAAAAAAAAKIaQAwAAAAAAAABQEYQcAAAAAAAAAICKIOQAAAAAAAAAAFQEIQcAAAAAAAAAoCIIOQAAAAAAAAAAFUHIAQAAAAAAAACoCEIOAAAAAAAAAEBFEHIAAAAAAAAAACqCkAMAAAAAAAAAUBGEHAAAAAAAAACAiiDkAAAAAAAAAABUBCEHAAAAAAAAAKAiCDkAAAAAAAAAABVByAEAAAAAAAAAqAhCDgAAAAAAAABARRByAAAAAAAAAAAqwphyFwCjVVtbW2zfvj27Hjt2bKTT6fIVBAAAAAAAAFAinZ2dsW/fvux6ypQpUVNTM+B9hRygTLZv3x4bNmwodxkAAAAAAAAAQ2L69OkD3sO4CgAAAAAAAACgIgg5AAAAAAAAAAAVwbgKKJOxY8fmrOfOnRu1tbVlqmbkWrNmTXR2dkY6nY4FCxaUuxyAUcfnMED5+SwGKD+fxQDl57MYYOjt3r07NmzYkF0nn4/2l5ADlEk6nc5Z19bWxsSJE8tUzchVVVUVnZ2dUVVV5f0FKAOfwwDl57MYoPx8FgOUn89igPJLPh/tL+MqAAAAAAAAAICKIOQAAAAAAAAAAFQEIQcAAAAAAAAAoCIIOQAAAAAAAAAAFUHIAQAAAAAAAACoCEIOAAAAAAAAAEBFEHIAAAAAAAAAACqCkAMAAAAAAAAAUBGEHAAAAAAAAACAiiDkAAAAAAAAAABUBCEHAAAAAAAAAKAiCDkAAAAAAAAAABVByAEAAAAAAAAAqAhCDgAAAAAAAABARRByAAAAAAAAAAAqgpADAAAAAAAAAFARhBwAAAAAAAAAgIog5AAAAAAAAAAAVAQhBwAAAAAAAACgIgg5AAAAAAAAAAAVQcgBAAAAAAAAAKgIQg4AAAAAAAAAQEUQcgAAAAAAAAAAKoKQAwAAAAAAAABQEYQcAAAAAAAAAICKIOQAAAAAAAAAAFQEIQcAAAAAAAAAoCIIOQAAAAAAAAAAFUHIAQAAAAAAAACoCEIOAAAAAAAAAEBFEHIAAAAAAAAAACqCkAMAAAAAAAAAUBGEHAAAAAAAAACAiiDkAAAAAAAAAABUBCEHAAAAAAAAAKAiCDkAAAAAAAAAABVByAEAAAAAAAAAqAhCDgAAAAAAAABARRByAAAAAAAAAAAqgpADAAAAAAAAAFARhBwAAAAAAAAAgIog5AAAAAAAAAAAVAQhBwAAAAAAAACgIgg5AAAAAAAAAAAVQcgBAAAAAAAAAKgIQg4AAAAAAAAAQEUQcgAAAAAAAAAAKoKQAwAAAAAAAABQEYQcAAAAAAAAAICKIOQAAAAAAAAAAFQEIQcAAAAAAAAAoCIIOQAAAAAAAAAAFUHIAQAiYlt7Jj6/NhNn/i4Td2/KlLscAAAAAAAAChhT7gIAoNyebsnER/4QsX7vm+tfNEe8tTYTfzY5Vd7CAAAAAAAAyKGTAwCjVlcmE19fn4lTfnsg4BARkYmIezaXrSwAAAAAAAB6oJMDAKPSpn2ZuPC/Ih7bVvj7y1uGth4AAAAAAAAOTsgBgFHnkTcy8bH/itjS3vM5v22N2N2Zidq0kRUAAAAAAADDhXEVAIwabV2ZuGZNJs5cmR9wqIqI7nGGjkzEszuGsj
oAAAAAAAAORsgBgFFh3Z5MnLIi4oYN+d+bPTbil++M+P8m5R5fZmQFAAAAAADAsCLkAMCI19aVidNeiHh2Z/73zj0k4oXGiFOmpOLEutzvLRdyAAAAAAAAGFaEHAAY8Z7ZEfHy3txjNamI77w14ifHRhxS/eagipMSIYend0R0ZTJDVCUAAAAAAAAHM6bcBQDAYHtpV+56enXE/317xDsnpXKOJ0MOLR0Rv98VcezEQS4QAAAAAACAoujkAMCIlww5nD41P+AQETF7bCrmjcs9tszICgAAAAAAgGFDyAGAEe/3rbnroyf0fG6ym8NyIQcAAAAAAIBhQ8gBgBEtk8nkdXI4ppeQw4mJkMNTQg4AAAAAAADDhpADACPapraI5o7cY33p5LBub8Rr+zKlLwwAAAAAAIA+E3IAYET7faKLQ21VxLxxPZ9/9ISIyencY7o5AAAAAAAADA9CDgCMaMlRFUdPiKhKpXo8P51K5Y2sWCbkAAAAAAAAMCwIOQAwohUKORxMMuSwXMgBAAAAAABgWBByAGBES46rKCbkcFIi5PBCa0RrR6Z0RQEAAAAAANAvQg4AjFiZTCYv5HBMESGHP5scke420aIzE/HsztLWBgAAAAAAQN8JOQAwYr2yL6K1M/fYMRMPft2EdCremThv2faSlQUAAAAAAEA/CTkAMGK91Jq7njImYlZNcdcmR1YsbylNTQAAAAAAAPTfmHIXUMm6urpixYoV8corr8Trr78ekydPjpkzZ0ZjY2PU1tYOWR1tbW3x/PPPR1NTUzQ3N8fUqVNj9uzZcfzxx0dNTZFP83qxcuXKWLt2bWzZsiUmTJgQDQ0N0djYGHV1dQe/uICOjo5Yv359rF+/PjZu3Bi7du2Krq6umDRpUsycOTOOPvroaGhoGHDdLS0t8dxzz8XmzZtj165dMX369Jg/f34cd9xxA94bqAwvFRhVkUqlCp+ccFJdxHdfPbB+ekdEZyYT6SKvBwAAAAAAoPSEHPqhs7Mzbr/99rjrrrtiy5Yted+vra2Ns88+O5YsWdLvIEAx9u7dGzfeeGP8x3/8R2zfvj3v+1OmTInzzz8/rrzyyhg3blyf97/33nvjBz/4Qaxfvz7ve9XV1XHaaafF5z73uZgxY8ZB92pra4tvfetbsWLFiviv//qvaGtr6/X84447Li688MI499xz+1z3xo0b47rrrovHH3882tvb875/+OGHx8UXXxwf/OAH+7w3UFl+nwg5HDWh+GuTnRx2dr4Zmnh7EeMuAAAAAAAAGBzGVfTRjh074m//9m9j6dKlBQMOERG7d++Oe++9NxYtWhR/+MMfBqWOpqamOP/88+P2228vGHCIiNi+fXvcfvvtcf7550dTU1PRe7e1tcXixYvj85//fMGAQ0REe3t7PPzww7Fo0aJ46qmnDrrn3r1744477ojf/e53Bw04RLzZPeLqq6+Oiy66KFpbWw96/n7Lli2LRYsWxcMPP1ww4BARsX79+vj85z8fixcvLqoWoHIlQw7H9CHkMHNsKo5I5MOeMrICAAAAAACgrHRy6IOOjo741Kc+FStWrMgemzVrVixatChmz54dzc3N8dhjj8WLL74YERGbNm2KSy+9NO67776SjF/Yr7W1NS677LJYs2ZN9thb3vKWOOuss6KhoSE2bdoUDz30UKxduzYiItasWROXXXZZ3HPPPTFx4sH/BPmLX/xiPProo9l1fX19nHfeeTF//vxoaWmJ5cuXx9NPPx0Rb46EuOKKK+Kee+6JhQsXFlX/+PHj49hjj40jjzwy5s6dG5MmTYqOjo7YsmVLPP/88/Hss89GV1dXREQ89dRTcfHFF8e//du/RTqd7nXfVatWxZVXXhm7dh14qnnSSSfFCSecEJMmTYq1a9fGAw88kA2FPProo/GlL30pvv71rxdVN1BZOjOZ+MPu3GN9CTlEvNnN4eW9B9ZPbY/45OwBlwYAAAAAAEA/CTn0wR133BHLly/Prs8555
z4+te/HjU1Ndljl112Wdx5553xT//0T5HJZGLLli3xhS98IW699daS1XHDDTfE6tWrs+u///u/jyVLluTMmV+8eHF
"text/plain": [
"<Figure size 1200x800 with 1 Axes>"
]
},
"metadata": {
"image/png": {
"height": 667,
"width": 1052
}
},
"output_type": "display_data"
}
],
"source": [
"_, losses = predict(model, train_dataset)\n",
"sns.kdeplot(losses)"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Axes: ylabel='Density'>"
]
},
"execution_count": 80,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAACEwAAAU2CAYAAACF1bFuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdeZTcdZkv/qd6T2dlSyedhRDAOAoI12RGURzXi6Lggsp1FhxXuHcQlxFHzzgy1zMOo4I6qAPHbRQHuf6Cg4jihijKMiMIGHRGIAaSkBUSkpCtu9P9/f3BpEh9ujqp7q7uby2v1zlzpj+f/tS3n6pqqT7n+87zFLIsywIAAAAAAAAAoIm05F0AAAAAAAAAAMBkE5gAAAAAAAAAAJqOwAQAAAAAAAAA0HQEJgAAAAAAAACApiMwAQAAAAAAAAA0HYEJAAAAAAAAAKDpCEwAAAAAAAAAAE1HYAIAAAAAAAAAaDoCEwAAAAAAAABA0xGYAAAAAAAAAACajsAEAAAAAAAAANB0BCYAAAAAAAAAgKYjMAEAAAAAAAAANB2BCQAAAAAAAACg6QhMAAAAAAAAAABNR2ACAAAAAAAAAGg6AhMAAAAAAAAAQNNpy7sAaFb9/f2xbdu24rqzszNaW1vzKwgAAAAAAACgSgYHB6Ovr6+4njVrVnR0dORY0XACE5CTbdu2xdq1a/MuAwAAAAAAAGBSzJ49O+8SShjJAQAAAAAAAAA0HYEJAAAAAAAAAKDpGMkBOens7CxZL1iwILq7u3OqBoiIWLlyZQwODkZra2scd9xxeZcDQB3xGQLAWPkMAWCsfIYAMFaT9Rmye/fuWLt2bXGd3h+tBQITkJPW1taSdXd3d0ybNi2naoCIiJaWlhgcHIyWlhb/ewRgVHyGADBWPkMAGCufIQCMVV6fIen90VpgJAcAAAAAAAAA0HQEJgAAAAAAAACApiMwAQAAAAAAAAA0HYEJAAAAAAAAAKDpCEwAAAAAAAAAAE1HYAIAAAAAAAAAaDoCEwAAAAAAAABA0xGYAAAAAAAAAACajsAEAAAAAAAAANB0BCYAAAAAAAAAgKYjMAEAAAAAAAAANB2BCQAAAAAAAACg6QhMAAAAAAAAAABNR2ACAAAAAAAAAGg6AhMAAAAAAAAAQNMRmAAAAAAAAAAAmo7ABAAAAAAAAADQdAQmAAAAAAAAAICmIzABAAAAAAAAADQdgQkAAAAAAAAAoOkITAAAAAAAAAAATUdgAgAAAAAAAABoOgITAAAAAAAAAEDTEZgAAAAAAAAAAJqOwAQAAAAAAAAA0HQEJgAAAAAAAACApiMwAQAAAAAAAAA0HYEJAAAAAAAAAKDpCEwAAAAAAAAAAE1HYAIAAAAAAAAAaDoCEwAAAAAAAABA0xGYAAAAAAAAAACajsAEAAAAAAAAANB0BCYAAAAAAAAAgKYjMAEAAAAAAAAANB2BCQAAAAAAAACg6QhMAAAAAAAAAABNR2ACAAAAAAAAAGg6AhMAAAAAAAAAQNMRmAAAAAAAAAAAmo7ABAAAAAAAAADQdAQmAAAAAAAAAICmIzABAAAAAAAAADQdgQkAAAAAAAAAoOkITAAAAAAAAAAATUdgAgAAAAAAAABoOgITAAAAAAAAAEDTEZgAAAAAAAAAAJqOwAQAAAAAAAAA0HQEJgAAAAAAAACApiMwAQAAAAAAAAA0HYEJAAAAAAAAAKDpCEwAAGPy821ZvOrXWbzlv7J4eE+WdzkAAAAAAACj0pZ3AQBA/bl9exYvuzdi4L9zEjduifj2iVk8d2Yh17oAAAAAAAAqpcMEADAqG/qyeP1vngpLREQ8OhDx4nsjvrlJpwkAAAAAAKA+CEwAABXrH8riDb+J2Ng//Ht9QxFv+s+IS1ZnkWWCEwAAAAAAQG0TmAAAKvbelRG37zj4mb9ZFfH2+yMGhoQmAAAAAACA2iUwAQBU5KsbsrhiXeneoq6It84dfvZfNkScsS
Ji24DQBAAAAAAAUJsEJgCAQ7prRxb/+4HSva6WiG+dEPHFJRGfPDaikDzmJ49HPO/uiIf2CE0AAAAAAAC1R2ACADioR/uzOPs3EX1DpftfWBJxyvRCFAqF+KuFhbj2hIgpyV8W/7U74jm/ivj37UITAAAAAABAbRGYAABGtG8oi3N+G7G2r3T/XfMj/mxOaU+J1x5ViJ+dEtHTUXr20YGIV6yIWNcnNAEAAAAAANQOgQkAYER/vSriZ9tK914wM+LSY8ufXzajEP/+7IhnTi3d374v4op1E1IiAAAAAADAmAhMAABlXbMpi0+vLd2b1xnxzRMi2lsK5R8UEUd3FeLW/xHxwlml+//f5ogs02UCAAAAAACoDQITAMAwv92Vxdt/V7rXUYi49pkRPR0jhyX2m9lWiI8nXShW7om4e2cViwQAAAAAABgHgQkAYJj/+1DEnqHSvc89LeKPZh46LLHf0ukRi7tK9765qQrFAQAAAAAAVIHABABQYu9gFjduKd17R2/E23srD0tERBQKhXjj7NI9YzkAAAAAAIBaITABAJT46baI3Qd0l2iJiI8dM7ZrndNTul7TF/HvO8ZaGQAAAAAAQPUITAAAJW54rHT93JkRR3aMrrvEfidNjVjSXbr3zc1jLAwAAAAAAKCKBCYAgKIsy+J7yTiOVx0x9usVCoU4JxnLsXxzxKCxHAAAAAAAQM4EJgCAol/vjFjbV7p35pHju2YamNjQH3HrtvFdEwAAAAAAYLwEJgCAohuS7hLHdEX8QXf5s5X6g6mFOGlq6d7/M5YDAAAAAADImcAEAFD03cdK16868smxGuP1xqTLxLcejdg3ZCwHAAAAAACQH4EJACAiIjb0ZXHnE6V7Zx5RnWuf01O6fmwg4uZt1bk2AAAAAADAWAhMAAAREXFjMo5jRmvEC2ZV59rHTinE0umle980lgMAAAAAAMiRwAQAEBER300CE6cfHtHRMv5xHPudk4zluO7RiH5jOQAAAAAAgJwITAAAsXcwix9vLd171ZHV/RlvTAIT2/ZF/Ghr+bMAAAAAAAATTWACAIibt0XsHnpq3RIRZxxR3Z+xoKsQp84o3fv/jOUAAAAAAAByIjABAMQNj5WuT50ZcUR79cZx7HdOT+n6249F7Bk0lgMAAAAAAJh8AhMA0OSyLIvvbSnde2WVu0vs9/qjIg6MYewcjPi+sRwAAAAAAEAOBCYAoMnduzPikb7SvTOPnJifNbezEC+cVbpnLAcAAAAAAJAHgQkAaHLfTbpLLO6K+IPuift5b5xdur7hsYid+4zlAAAAAAAAJpfABAA0ue8+Vrp+1ZERhUKh/OEqOPuoiNYDLr9naHhoAwAAAAAAYKIJTABAE9vQl8WdT5TunXnExP7MIzsK8dLDSve+aSwHAAAAAAAwyQQmAKCJfS/p7DCjNeK0WRP/c89JxnJ8f0vEdmM5AAAAAACASSQwAQBNLB2F8fIjIjpaJm4cx36vOTKi44Af059FXP/YyOcBAAAAAACqTWACAJrUnsEsfry1dO9VEzyOY79Z7YU4/fDSvW9umpyfDQAAAAAAECEwAQBN6+bHI/YMPbVuiYhXTFJgIiLinJ7S9c+2RewbMpYDAAAAAACYHAITANCkbkjGcZw6M+KI9okfx7Hf/zysdL1nKOI/d0/ajwcAAAAAAJqcwAQANKEsy+J7SWBissZx7HdkRyEWdZXu3bljcmsAAAAAAACal8AEADShe3dGrOsr3TvzyMmvY9n00vVdT0x+DQAAAAAAQHMSmACAJnTDY6XrY6dEPL178ut4tsAEAAAAAACQE4EJAGhC3y0zjqNQKEx6HctmlK5X7IzoG8omvQ4AAAAAAKD5CEwAQJPZ0JcN6+SQxziOiOEdJgayiF/vzKcWAAAAAACguQhMAECTuWNH6XpGa8RpM/OpZUZbIZYko0CM5QAAAAAAACaDwAQANJl7k0DC0ukR7S2TP45jv2VJl4m7dpQ/BwAAAAAAUE
0CEwDQZO5NRl6cPL38ucmydEbp+k4dJgAAAAAAgEkgMAEATeaeJDBxyrR86thvaRLY+K9dETv3ZfkUAwAAAAAANA2
"text/plain": [
"<Figure size 1200x800 with 1 Axes>"
]
},
"metadata": {
"image/png": {
"height": 667,
"width": 1062
}
},
"output_type": "display_data"
}
],
"source": [
"# _, losses = predict(model, train_dataset)\n",
"_, losses = predict(model, test_normal_dataset)\n",
"sns.kdeplot(losses)"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Axes: ylabel='Density'>"
]
},
"execution_count": 81,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAACDkAAAU2CAYAAACVtf+yAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzde5zWdZ03/tc1w3AYjh5gQBARFTXFPEC/tNsy7ahFB/Peu5NtbZu2S9rJsnvb3HY7mt7tmm5taW12l3unm9rBsrSyErdUNMhMQOQgcgw5w8wwc/3+YBn5XgwwMwxcc3g+H48eXZ/3fK4P7wG+ODov3p9SuVwuBwAAAAAAAACgh6updgMAAAAAAAAAAB0h5AAAAAAAAAAA9ApCDgAAAAAAAABAryDkAAAAAAAAAAD0CkIOAAAAAAAAAECvIOQAAAAAAAAAAPQKQg4AAAAAAAAAQK8g5AAAAAAAAAAA9ApCDgAAAAAAAABAryDkAAAAAAAAAAD0CkIOAAAAAAAAAECvIOQAAAAAAAAAAPQKQg4AAAAAAAAAQK8g5AAAAAAAAAAA9ApCDgAAAAAAAABAryDkAAAAAAAAAAD0CkIOAAAAAAAAAECvMKDaDUB/1dTUlHXr1rWtBw0alNra2uo1BAAAAAAAANBNWlpa0tjY2LYeNWpUBg4cuN/nCjlAlaxbty5Lly6tdhsAAAAAAAAAB8WYMWP2+wzXVQAAAAAAAAAAvYKQAwAAAAAAAADQK7iuAqpk0KBBhXVNTU2OO+64KnUDvdOCBQvS0tKS2traHHvssdVuB3odzxB0necHus7zA13n+YH94xmCrvP8QNd5fujPtmzZkqVLl7atK78/2lVCDlAltbW1hXWpVMqwYcOq1A30TjU1NWlpaUlNTY3nB7rAMwRd5/mBrvP8QNd5fmD/eIag6zw/0HWeH3hO5fdHu8p1FQAAAAAAAABAryDkAAAAAAAAAAD0CkIOAAAAAAAAAECvIOQAAAAAAAAAAPQKQg4AAAAAAAAAQK8g5AAAAAAAAAAA9ApCDgAAAAAAAABAryDkAAAAAAAAAAD0CkIOAAAAAAAAAECvIOQAAAAAAAAAAPQKQg4AAAAAAAAAQK8g5AAAAAAAAAAA9ApCDgAAAAAAAABAryDkAAAAAAAAAAD0CgOq3UBv1tramtmzZ2fJkiVZs2ZNRowYkXHjxmX69Ompr68/aH00NTXloYceyrJly7J27doceuihGT9+fKZNm5aBAwfu9/lz5szJwoULs2rVqgwdOjQNDQ2ZPn16Ro4cud9nL126NHPnzs2qVatSLpfT0NCQqVOn5sgjj9zvs3davHhxnnjiiaxatSpbtmzJoYcemsMPPzyTJ0/OxIkTu+3HAQAAAAAAAODAEnLogpaWltx000351re+lVWrVu328fr6+lxwwQW54ooruiUIsCfbtm3Lddddl//8z//MunXrdvv4qFGjcuGFF+ayyy7L4MGDO33+rbfemq997WtZvHjxbh+rq6vLeeedl4997GMZO3Zsp89+6KGHcs011+SRRx5p9+OnnXZaPvzhD2fatGmdPjtJyuVybr311txyyy354x//uMd9o0ePzmte85pceeWVXfpxAAAAAAAAADh4XFfRSRs2bMjb3va2XHvtte0GHJJky5YtufXWWzNjxoy9foN9fyxbtiwXXnhhbrrppnYDDkmybt263HTTTbnwwguzbNmyDp/d1NSUmTNn5uMf/3i7AYckaW5uzk9+8pPMmDEj999/f6d6/+pXv5qLL754jwGHJHnkkUdy8cUX56tf/Wqnzk52TId485vfnL//+7/f58//6tWrc9ddd3X6xwAAAAAAAADg4DPJoRO2b9+eyy+/PLNnz26rHXHEEZkxY0bGjx+ftWvX5p577sncuXOTJCtWrMgll1yS2267LQ0NDd3Wx6ZNm3LppZdmwYIFbbVjjjkm559/fhoaGrJixYrcddddWbhwYZJkwY
IFufTSS3PLLbdk2LBh+zz/E5/4RH72s5+1rQ855JC87nWvy+TJk7N+/frMmjUrDzzwQJJk/fr1ed/73pdbbrklxx9//D7P/t73vpdrr722bV1XV5cLLrggU6dOTWtra+bOnZsf//jHaW5uTktLS6699tqMHj06b3jDGzr0c7N06dK8/e1vz/Lly9tqDQ0NOeecc3L00UdnxIgR2bRpU5588snMnj078+fP79C5AAAAAAAAAFSfkEMnfOMb38isWbPa1q95zWvy2c9+NgMHDmyrXXrppbn55pvzmc98JuVyOatWrcrf//3fd2kiwZ5cc801mTdvXtv6r/7qr3LFFVekVCq11WbOnJmrr746X//615Mk8+bNy7XXXpurrrpqr2ffdddduf3229vWL3zhC3PDDTcUwhHvec978pOf/CRXXHFFmpqasnnz5nz4wx/OnXfemZqaPQ8HWb58eeHHHzduXG666aYcc8wxhX2XXHJJ3v3ud7cFFT7xiU/khS98YcaNG7fX3jdv3px3vvOdbe8bNGhQPvjBD+btb397amtr233PokWLcvfdd+/1XAAAAAAAAAB6BtdVdNCmTZty4403tq2f97zn5fOf/3wh4LDTxRdfnLe+9a1t6/vuuy8PP/xwt/SxdOnS3HbbbW3rl770pfnIRz5SCDgkSalUykc/+tG89KUvbavdeuutWbp06R7PbmlpyXXXXde2Hjt27G4Bh51e9apX5QMf+EDbet68efnhD3+4196vv/76NDU1JUlqa2tz3XXX7RZwSJJjjz021113XVswoampKTfccMNez06SL37xi22fX11dXa6//vr85V/+5R4DDkkyadKkXHLJJfs8GwAAAAAAAIDqE3LooDvvvDPr1q1rW19xxRUZMGDPgzDe//73Z8iQIW3rm2++uVv6uOWWW9Lc3JxkR5Dhyiuv3Ov+XT/e3NycW265ZY97f/Ob3+Spp55qW8+cOXOv11u84x3vyBFHHNG23tvnuGHDhtx5551t6/PPPz+nnHLKHvefcsopOf/889vWd9xxRzZu3LjH/QsWLMi3v/3ttvW73vWuvPjFL97jfgAAAAAAAAB6HyGHDrr33nvbXo8fPz5nnnnmXvcPHz48r3zlK9vWv/71r9umGHRXH9OnT8+kSZP2un/SpEmZPn16u++vdM8997S9rq+vL4QM2lNbW5s3vOENbes//OEPWblyZbt777vvvrZwRpJcdNFFez07Sd70pje1vW5ubs599923x73/8R//kdbW1iTJsGHDcumll+7zfAAAAAAAAAB6FyGHDti2bVt+97vfta3POuus3a6HaM9ZZ53V9nrz5s37fWXF4sWLs2jRonbP72gfixYtypIlS9rdt2uI4NRTT83QoUM7dXa5XM6vfvWrfZ49ePDgnH766fs8+4wzzsjgwYPbPWNXTU1N+cEPftC2fvWrX536+vp9ng8AAAAAAABA7yLk0AELFy4sTCE49dRTO/S+0047rbB+4okn9quPefPmFdZd7aPynCRZv359YQpDR8+eOnVq4dqO9s6urJ988smpq6vb59l1dXU56aST9nn2448/XrhKxDUVAAAAAAAAAH3TgH1v4cknnyysJ06c2KH3jR8/PrW1tWlpaUmyIyxRjT6OPPLI3c552cte1i1nDxo0KA0NDVm2bFmS9j/H1tbWwgSKjp69c+/OCRhPPfVUWltbU1NTzObMmTOnsH7+85+fJJk/f35uvfXW3H///Vm+fHnK5XIOPfTQTJ06Neeee27OP//8QkADAAAAAAAAgJ7Nd3g74Omnny6sx40b16H31dbWZsyYMVm+fHmSZOnSpd3WR01NTRoaGjr0voaGhtTU1KS1tXWPfXT1c9y5d2fIob2zV69encbGxi6fvVNjY2NWr1692+f9pz/9qe11fX19Dj/88PzLv/xL/u3f/q0tYLLTli1b8vTTT+fHP/5xbrjhhlx99dVtoQgAAAAAAAAAejbXVXTApk2bCu
uRI0d2+L3Dhw9ve7158+Zu62Po0KEdnkJQV1eXIUOG7LWPA/k5Vp49YsSIDp9dubfyrCRZu3Zt2+sxY8bkU5/6VP7
"text/plain": [
"<Figure size 1200x800 with 1 Axes>"
]
},
"metadata": {
"image/png": {
"height": 667,
"width": 1052
}
},
"output_type": "display_data"
}
],
"source": [
"# _, losses = predict(model, train_dataset)\n",
"_, losses = predict(model, test_anomaly_dataset)\n",
"sns.kdeplot(losses)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Predictions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "Compute the anomaly `THRESHOLD` from the training-set reconstruction losses: their mean plus one standard deviation."
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {},
"outputs": [],
"source": [
"predictions, losses = predict(model, train_dataset)"
]
},
{
"cell_type": "code",
"execution_count": 106,
"metadata": {},
"outputs": [],
"source": [
"loss_array = np.array(losses)"
]
},
{
"cell_type": "code",
"execution_count": 107,
"metadata": {},
"outputs": [],
"source": [
"stdev = np.std(loss_array)\n",
"mean = np.mean(loss_array)\n",
"THRESHOLD = mean + stdev "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "Check on the held-out normal windows (`test_normal_dataset`): every window whose loss exceeds `THRESHOLD` is a false positive."
]
},
{
"cell_type": "code",
"execution_count": 111,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"number of intervals exceeding 1 std dev loss: 8/150\n"
]
}
],
"source": [
"_, losses = predict(model, test_normal_dataset)\n",
"exceed_count = sum(l > THRESHOLD for l in losses)\n",
"print(f'number of intervals exceeding 1 std dev loss: {exceed_count}/{len(test_normal_dataset)}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
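    "Check on the anomalous windows (`test_anomaly_dataset`): every window whose loss exceeds `THRESHOLD` is a correctly flagged anomaly."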
]
},
{
"cell_type": "code",
"execution_count": 112,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"number of intervals exceeding 1 std dev loss: 118/150\n"
]
}
],
"source": [
"_, losses = predict(model, test_anomaly_dataset)\n",
"exceed_count = sum(l > THRESHOLD for l in losses)\n",
"print(f'number of intervals exceeding 1 std dev loss: {exceed_count}/{len(test_normal_dataset)}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "## Plot reconstructions vs. original sequences"
]
},
{
"cell_type": "code",
"execution_count": 113,
"metadata": {},
"outputs": [],
"source": [
"def plot_prediction(data, model, title, ax):\n",
" predictions, pred_losses = predict(model, [data])\n",
"\n",
" ax.plot(data, label='true')\n",
" ax.plot(predictions[0], label='reconstructed')\n",
" ax.set_title(f'{title} (loss: {np.around(pred_losses[0], 2)})')\n",
" ax.legend()"
]
},
{
"cell_type": "code",
"execution_count": 123,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAEPsAAAYSCAYAAABQH1kdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd3gU1dvG8Xs3PYGQhN4RhCBdRUBRUUFQQRC7WBBRbCDSRH4qYgVUREQFO6KgrygqooKiWBBERATpHRI6BBIgbZOd94+VJZO6m2zJZr+f68pFzmTK2TBJ7jl75hmLYRiGAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAHid1d8dAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIIFxT4AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAH6HYBwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAOAjFPsAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAfIRiHwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAICPUOwDAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA8BGKfQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA+QrEPAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAwEco9gEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAD4CMU+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB+h2AcAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADgIxT7AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAHyEYh8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACAj1DsAwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAPARin0AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAPkKxDwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAMBHKPYBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA+AjFPgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAfodgHAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA4CMU+wAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB8JNTfHQAAFO/222/Xn3/+KUmqW7eufvrpJ68fc8CAAVq6dKkkaezYsbr11lsLXW/u3LkaM2aMsz1z5kx17NjR6/2D+7KysrRt2zYlJyfr4MGDSk9Pl91uV+XKlVWjRg21bNlSderU8Xc3XZKamqpt27Zp7969Onz4sDIyMhQaGqrY2FjVrVtXrVq1UlxcnM/7ZbPZtGPHDm3dulWHDx/WyZMnFR0drfj4eDVv3lxNmzaVxWLxyLFyc3O1ZcsWbdy4UceOHVNGRoYqVaqkatWqqVWrVqpfv75b+5s8ebKmT58uSerTp49eeOEFj/QTAABPIhfDEypSLga5GAAAV5GlESgMw1BycrJ27dqlffv2KS0tTVlZWYqOjlZcXJyaNm2qZs2aKSwszN9dDRjZ2dm68sorlZycLEmaMWOGzj//fD/3CgAAzyDnwhN8MWbsr5ybnZ2tNWvWaPv27Tp27JjzddWtW1fNmjVjLNwFjBkDAAIBuRiBJCUlRZs3b9bu3buVmpoqwzBUpUoV1a5dW+3atVNsbKy/u+h1KSkpWrlypQ4ePKgTJ04oISFBNWrU0LnnnqtKlSr5u3uFIhcDAAIBuRgoXnkZL+beO6BwFPsAypHk5GR17dq1wPInn3xS/fr1c2tfl112mfbs2SNJqlatmn7//XeP9BEV3w8//OC82Khfv75uvPFGP/cIpfXbb79p4cKFWrVqlXbs2KHc3Nxi12/UqJFuvPFG3XrrrYqMjHT7eFlZWdqwYYPWrFmjf//9V2vWrNGuXbtkGIZznR9//FH16tVze9+zZ8/W8uXLtWbNGu3du7fE9c877zz169dPV111ldvHcsfBgwe1YMEC/fbbb/rrr7+Unp5e5LpxcXG67r
rr1L9/f9WsWbNUxzt06JDeffddffHFFzp27FiR6zVq1Ei33nqrbrnlFpcmBN1999365JNPdOzYMc2bN0/9+vVTu3btStVHAPAEcjHKA3JxxeGrXJz39427br31Vo0dO7bE9aZOnarXXnutVMeIjo7WqlWrSrWtO/bu3eu8Hvj333+1du1anTx50vn1vn37asKECWU6BrkYAIpGlkZ5QJYOLjabTddee602b95sWj5+/Hhde+21JW6/b98+ffLJJ/r777+1bt06U3YsTHR0tHr27KkBAwaoSZMmperzrl27TOPYGzZsUGZmpvPrgwcP1pAhQ0q1b29IS0vTlVdeqcOHD5uWuzIJLDw8XEOGDNHo0aMlSc8//7y+/PJLhYSEeK2/AOAN5FyUB+TcisMXY8b+yLmn7Ny5U2+++aYWLFhQ7ByGmjVr6uKLL9bDDz+satWqlemYxfHmmHFiYmKp+/XII49o4MCBxa7DmDGA8oZcjPKAXBxcyjr+m5ubqz///NN53uzYsaPIdS0Wi9q3b6/+/fvr8ssvd6ufgTCXYtmyZZo6dar+/vtv07zqUyIiInTppZfqoYcecvuagFwMINiQi1EekIsrjuXLl+uOO+4o9fafffaZWrdu7ZG+vPnmm3r55ZdNyzp06KAPP/yw1PssD+PF3HsHlMzq7w4AKNm0adOUlZXl724gCNhsNk2cONHZfvDBB3lCXgD7+OOPNWfOHG3durXEySmSI8C/8MILuvrqq7VmzRqXj/Paa6/p2muv1bnnnqubbrpJzz33nObNm6edO3cWOiBdGi+//LIWLFjgUqEPSVqxYoWGDRumO++8U0eOHPFIH/J74YUX1KVLFz333HP69ddfi73YkKRjx47p3XffVc+ePTVv3jy3j7d48WL16tVL77//frEXG5Lj//K5557TjTfeqKSkpBL3XblyZd11112SHE8Vev75593uHwD4ArkYvkIurlh8lYuD2aZNm3Tfffepc+fOzskn77zzjpYvX17iJHZ3kYsBoHTI0vAVsnTwee+99wpM9HbHunXrNH36dP35558uZcf09HTNmTNHffr00VtvveXycZYvX66BAweqQ4cO6t69u0aOHKkPPvhAq1atMhX6KI9eeumlAoU+3NG7d281btxYkrR582Z9+umnnuoaAPgdORe+Qs6tWHwxZuyrnJuX3W7XG2+8oV69emnu3LklzmE4cOCA5syZo+Tk5FIdrzi+HDP2JsaMAQQKcjF8hVwcfMoy/rtnzx5dfPHFuvPOOzVr1qxiC31Ijry1YsUKDR48WIMGDfLa3F9fs9lsGjNmjO68806tXLmyyHnVWVlZWrBgga699lp98sknPu5l8cjFAAIFuRi+Qi6GNyQlJemNN97w2P7Ky3gx994BrqHYBxAADh48qNmzZ/u7GwgCn3/+uTOc1K5dW7169fJzj+BJkZGRatasmS699FL16tVLV1xxhc477zxVrlzZtN7u3bs1YMAAlyep/PDDD1q3bp1sNps3ul2oKlWqqG3bturatat69+6t7t27q1WrVgUukJctW6b+/fvr6NGjHu9DcnKy7HZ7geUNGjRQ586d1atXL1166aWqXbu26evHjx/XqFGj3BqM/+233zRkyJACFxr16tVT165d1atXL1144YWqUqWK6evr16/XnXfe6dJE8FtuuUUxMTGSpNWrV2vx4sUu9w8AfIVcDF8hF1ds3srFwSwpKUmLFy8u0w2IriAXA0DpkaXhK2Tp4OLpySaSZLVaVbduXV1wwQW64oordPXVV+uSSy5RvXr1TOvZbDZNmjRJL730kkv73bBhg5YsWaLU1FSP9tfb/v777zIX57Barc4JJ5JjomN2dnZZuwYA5QI5F75Czq3YfDFm7K2ce0pubq5GjhypKVOmmOZuhISEqGXLlrrsssvUq1cvXXTRRapTp47b/XeXr8aMfYExYwCBgFwMXyEXB5eyjv+ePHmy0DwYFRWldu3aqVu3brrqqq
t0zjnnFJj7+8svv6h///4l3pxX3tntdj344IOaO3euaXl0dLRznnHnzp0VHR3t/FpmZqaefPLJAtv4G7kYQCAgF8N
"text/plain": [
"<Figure size 2200x800 with 12 Axes>"
]
},
"metadata": {
"image/png": {
"height": 777,
"width": 2173
}
},
"output_type": "display_data"
}
],
"source": [
"fig, axs = plt.subplots(\n",
" nrows=2,\n",
" ncols=6,\n",
" sharey=True,\n",
" sharex=True,\n",
" figsize=(22, 8)\n",
")\n",
"\n",
"sample_size = 6\n",
"sample_indices = random.sample(range(0,len(test_normal_dataset)), sample_size)\n",
"\n",
"sampled_test_normal_dataset = [test_normal_dataset[i] for i in sample_indices]\n",
"sampled_test_anomaly_dataset = [test_anomaly_dataset[i] for i in sample_indices]\n",
"\n",
"for i, data in enumerate(sampled_test_normal_dataset):\n",
" plot_prediction(data, model, title='Normal', ax=axs[0, i])\n",
"\n",
"for i, data in enumerate(sampled_test_anomaly_dataset):\n",
" plot_prediction(data, model, title='Anomaly', ax=axs[1, i])\n",
"\n",
"fig.tight_layout();"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}