Files
Machine-Learning-Models/ShipsSatelliteImageClassification/main.ipynb

804 lines
54 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Ship Detection using Faster R-CNN"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Importing libraries\n",
"\n",
"The necessary libraries required for implementing the Faster R-CNN model are imported in the code block below."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Checking System\n",
"\n",
"Tensorflow version : 2.11.0-dev20220812\n",
"Number of replicas: 1\n",
"2.11.0-dev20220812\n",
"\n",
"System check done in 0.0104 seconds\n",
"System Checked!\n"
]
}
],
"source": [
"import time\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"import os, random, cv2, pickle, json, itertools\n",
"import imgaug.augmenters as iaa\n",
"import imgaug.imgaug\n",
"\n",
"from IPython.display import SVG\n",
"# from tensorflow.python import keras\n",
"from keras.utils import plot_model, model_to_dot\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import confusion_matrix\n",
"from collections import Counter\n",
"from sklearn.utils import class_weight\n",
"from tqdm import tqdm\n",
"from sklearn.preprocessing import LabelBinarizer\n",
"\n",
"from keras.utils import to_categorical\n",
"from keras.models import Sequential, Model\n",
"from keras.layers import (Add, Input, Conv2D, Dropout, Activation, BatchNormalization, MaxPooling2D, ZeroPadding2D, AveragePooling2D, Flatten, Dense)\n",
"from keras.optimizers import Adam, SGD\n",
"from keras.callbacks import TensorBoard, ModelCheckpoint, Callback\n",
"from keras.preprocessing.image import ImageDataGenerator\n",
"from keras.initializers import *\n",
"\n",
"\n",
"start_time = time.perf_counter()\n",
"print(\"Checking System\\n\")\n",
"\n",
"SEED = 1337\n",
"print('Tensorflow version : {}'.format(tf.__version__))\n",
"\n",
"try:\n",
" tpu = tf.distribute.cluster_resolver.TPUClusterResolver()\n",
" tf.config.experimental_connect_to_cluster(tpu)\n",
" tf.tpu.experimental.initialize_tpu_system(tpu)\n",
" strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
"except ValueError:\n",
" strategy = tf.distribute.get_strategy() # for CPU and single GPU\n",
" print('Number of replicas:', strategy.num_replicas_in_sync)\n",
" \n",
"print(tf.__version__)\n",
"\n",
"end_time = time.perf_counter()\n",
"\n",
"print(f\"\\nSystem check done in {end_time - start_time:0.4f} seconds\")\n",
"print(\"System Checked!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Utility functions\n",
"\n",
"1. show_final_history - For plotting the loss and accuracy of the training and validation datasets\n",
"2. plot_confusion_matrix - For plotting the percentage of true positives per class for a better feel of how the model predicted the data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def show_final_history(history):\n",
" \n",
" plt.style.use(\"ggplot\")\n",
" fig, ax = plt.subplots(1,2,figsize=(15,5))\n",
" ax[0].set_title('Loss')\n",
" ax[1].set_title('Accuracy')\n",
" ax[0].plot(history.history['loss'],label='Train Loss')\n",
" ax[0].plot(history.history['val_loss'],label='Validation Loss')\n",
" ax[1].plot(history.history['accuracy'],label='Train Accuracy')\n",
" ax[1].plot(history.history['val_accuracy'],label='Validation Accuracy')\n",
" \n",
" ax[0].legend(loc='upper right')\n",
" ax[1].legend(loc='lower right')\n",
" plt.show();\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def plot_confusion_matrix(cm,classes,title='Confusion Matrix',cmap=plt.cm.Blues):\n",
" \n",
"# np.seterr(divide='ignore',invalid='ignore')\n",
" cm = cm.astype('float')/cm.sum(axis=1)[:,np.newaxis]\n",
" plt.figure(figsize=(10,10))\n",
" plt.imshow(cm,interpolation='nearest',cmap=cmap)\n",
" plt.title(title)\n",
" plt.colorbar()\n",
" tick_marks = np.arange(len(classes))\n",
" plt.xticks(tick_marks, classes,rotation=45)\n",
" plt.yticks(tick_marks, classes)\n",
" \n",
" fmt = '.2f'\n",
" thresh = cm.max()/2.\n",
" for i,j in itertools.product(range(cm.shape[0]),range(cm.shape[1])):\n",
" plt.text(j,i,format(cm[i,j],fmt),\n",
" horizontalalignment=\"center\",\n",
" color=\"white\" if cm[i,j] > thresh else \"black\")\n",
" pass\n",
" \n",
" plt.ylabel('True Label')\n",
" plt.xlabel('Predicted Label')\n",
" pass"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading the images\n",
"\n",
"The images from the SatelliteImageryofShips is loaded into numpy arrays, with labels [0,1] corresponding to the classes no-ship and ship. The data was loaded into numpy arrays as data augmentation and upsampling/downsampling is easier to perform."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'no-ship': 0, 'ship': 1}"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"datasets = ['./SatelliteImageryofShips/']\n",
"\n",
"class_names = [\"no-ship\",\"ship\"]\n",
"\n",
"class_name_labels = {class_name:i for i,class_name in enumerate(class_names)}\n",
"\n",
"num_classes = len(class_names)\n",
"class_name_labels"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def load_data():\n",
" images, labels = [], []\n",
" \n",
" for dataset in datasets:\n",
" \n",
" for folder in os.listdir(dataset):\n",
" label = class_name_labels[folder]\n",
" \n",
" for file in tqdm(os.listdir(os.path.join(dataset,folder))):\n",
" \n",
" img_path = os.path.join(dataset,folder,file)\n",
" \n",
" img = cv2.imread(img_path)\n",
" img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)\n",
" img = cv2.resize(img, (48,48))\n",
" \n",
" images.append(img)\n",
" labels.append(label)\n",
" pass\n",
" pass\n",
" \n",
" images = np.array(images,dtype=np.float32)/255.0\n",
" labels = np.array(labels,dtype=np.float32)\n",
" pass\n",
" \n",
" return (images, labels)\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 3000/3000 [00:05<00:00, 572.93it/s]\n",
"100%|██████████| 1000/1000 [00:01<00:00, 566.53it/s]\n"
]
},
{
"data": {
"text/plain": [
"((4000, 48, 48, 3), (4000,))"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(images, labels) = load_data()\n",
"images.shape, labels.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## EDA of dataset\n",
"\n",
"1. bar-plot - Bar plot is made to find the count of images per class\n",
"2. pie-plot - Pie plot is drawn to find the percentage of class distribution in the dataset"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Count</th>\n",
" </tr>\n",
" <tr>\n",
" <th>Class-Label</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>no-ship</th>\n",
" <td>3000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ship</th>\n",
" <td>1000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Count\n",
"Class-Label \n",
"no-ship 3000\n",
"ship 1000"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"n_labels = labels.shape[0]\n",
"\n",
"_, count = np.unique(labels, return_counts=True)\n",
"\n",
"df = pd.DataFrame(data = count)\n",
"df['Class Label'] = class_names\n",
"df.columns = ['Count','Class-Label']\n",
"df.set_index('Class-Label',inplace=True)\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAEWCAYAAACKSkfIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAdYElEQVR4nO3dfZxWdZ3/8ddbQNDAREBCQEEdS0EdZSQz71bbpNL1ZtOwNukWdfVXtmZpj8em1VJWbrpa6moa2o2IqQvrTWqspLbegIQBIklBMoLceROksICf3x/nO3Ycrpm5hhlmGL7v5+NxPa5zfc/3fM/3zFzzvs58z7nOUURgZmZ52KGzO2BmZh3HoW9mlhGHvplZRhz6ZmYZceibmWXEoW9mlhGH/nZI0kRJ/5amj5K0oB3bvl/SuDT9KUmPtWPbn5D0YHu114r1vl/S85LWSjqlwvx5ko7t6H51deX3oW07und2B2zriohHgXe3VE/SZcC+EfFPLbT3ofbol6RhwCKgR0RsTG3/HPh5e7TfSt8EfhgR/1FpZkSM6OD+mG013tO3qqiwvb5f9gLmdXYntnWSvJO4Hdhe/4izIukQSbMkrZF0O9CrNO9YSfWl11+V9GKqu0DS8ZLGAF8DPpaGOJ5JdadLmiDpt8DrwN6p7HNvX72ukfSapOckHV+asVjSB0qvL5P0s/TykfT8alrn+xoPF0k6QtKM1PYMSUeU5k2X9C1Jv03b8qCk/s38jD4vaaGklyVNlbRHKv8jsDfw36kfPSss+9Z2pG24Q9LP0nrnSNpP0iWSVkhaIumDpWU/LWl+qvsnSWc3avsrkpZJWirpc5JC0r5pXk9JV0h6QdJySddL2inN6y/pHkmvpm16tKkP5dTmF9L6V0n6frmupM+kPr4i6QFJezVa9jxJzwPPN9H+kZL+N/VliaRPVajTN/V3ZVrPPZKGlOZ/KvVvjaRFkj6RyveV9Jv0HliV3t/WBg79Lk7SjsB/AT8FdgPuAP6xibrvBs4HDouIPsAJwOKI+BXwbeD2iOgdEQeXFvskMB7oA/y5QrPvBf4E9AcuBe6StFsVXT86Pe+a1vl4o77uBtwLXA30A34A3CupX6nax4FPA7sDOwJfbmK7jwO+A5wBDErbMQkgIvYBXgBOSv1YX0XfT6L4efcFfgc8QPG3NJhiqOg/S3VXACcCu6S+Xinp0NSvMcC/AB8A9gWOabSe7wL7AbVp/mDg62nehUA9MAAYSPGh3dw1VU4F6oBDgZOBz6Q+nJKWPS219ShwW6NlT6H4PR/QuFFJewL3A9ek5WuB2RXWvwPwE4r/qvYE3gB+mNp4B8Xv+UPpfXlEqY1vAQ9S/KyHpPVYGzj0u77DgR7AVRGxISJ+Ccxoou4moCdwgKQeEbE4Iv7YQvsTI2JeRGyMiA0V5q8orft2YAHwkS3clrKPAM9HxE/Tum8DnqMI3AY/iYg/RMQbwGSKwKnkE8DNETErhfolwPtUHFfYEo9GxAPpWMQdFGF3efr5TAKGSdoVICLujYg/RuE3FAF2VGrnjLQN8yLideAbDSuQJODzwJci4uWIWEPxwTw2VdlA8QG2V/rZPxrNX0jru6mdF4CrgDNT+dnAdyJiftqebwO15b39NP/l9HNu7BPAryPittSP1RExu3GlVH5nRLyetmUCb/+QexMYKWmniFgWEQ3DbRsoPij2iIh1EdFuJw7kyqHf9e0BvNjoD77SHjkRsRC4ALgMWCFpUsMwRzOWtDC/0rpbarMae7D5dvyZYm+3wUul6deB3tW0FRFrgdWN2mqN5aXpN4BVEbGp9JqGvkj6kKQn0hDMq8CHKf4rauhX+edbnh4A7Aw8nYZNXgV+lcoBvg8sBB5MwyIXt9Dnctvl39FewH+U1vEyIN7+s2nuPTAUaGnHAUk7S/pPSX+W9BeK4b1dJXWLiL8CHwPOAZZJulfSe9KiX0n9eUrFWVSfaWld1jyHfte3DBic9g
wb7NlU5Yj4RUQcSfHHHhRDCND00EBLl2GttO6lafqvFMHV4F2taHdp6mPZnsCLLSzXYltpOKHfFrZVtXR84E7gCmBgROwK3EcRYlD87oaUFhlaml5F8QEyIiJ2TY93RkRvgIhYExEXRsTeFP/9/ItKx1MqKLdd/h0tAc4urWPXiNgpIv63VL+539USYJ9m5je4kOIssvdGxC78bXhPaXseiIi/p/jv5TngxlT+UkR8PiL2oPiv5NqGYx62ZRz6Xd/jwEbgC5K6SzoNGF2poqR3SzouhdE6ilBp2ENdTjEs0dr3xO5p3T0knQ7sTxFsUIzLjk3z6oCPlpZbSfEv/d5NtHsfsJ+kj6ft+hjFmPI9rewfwC+AT0uqTdv+beDJiFi8BW21xo4Uw2krgY2SPgR8sDR/curX/pJ25m/j9UTEmxTBd6Wk3QEkDZZ0Qpo+MR3kFPAXit/jJpp2UTqYOhT4ItBwQPR64BJJI1K770y/x2r9HPiApDPS76mfpNoK9fpQvN9eTcdrLm2YIWmgpH9IH8brgbUN2yLp9NIB31coPoCa205rgUO/i4uI/6M4CPcpij+KjwF3NVG9J3A5xV7kSxSB/bU07470vFrSrFZ04UmgJrU5AfhoRKxO8/6VYi/wFYrx6l+U+v16qv/bNLRweKPtWk1xAPRCiqGYrwAnRsSqVvStoa1pqS93Uuxd78Pfxsa3mjR2/QWKcH+F4sDz1NL8+ykOYD5MMVTTcDC74WDyV1P5E2lI5Nf87TsXNen12rTctRExvZnuTAGepvggvhe4KfXhbor/9ialdcwFqv4uRjpG8GGK39PLqf2DK1S9CtiJ4n3yBMVQVYMd0vJLUxvHAP+c5h0GPClpLcXP7osRsaja/tnm5JuomG0bJO1PEbo9G76w1k7tBlCTjulY5rynb9aJJJ0qaUdJfSn2uP+7PQPfrDGHvlnnOptizP+PFGPV53Zud2x75+EdM7OMeE/fzCwj2/wFlPr37x/Dhg3r7G6YmXUpTz/99KqIGNC4fJsP/WHDhjFz5szO7oaZWZciqeI38z28Y2aWEYe+mVlGHPpmZhnZ5sf0zcyasmHDBurr61m3bl1nd6XT9OrViyFDhtCjR4+q6jv0zazLqq+vp0+fPgwbNoy3X+w1DxHB6tWrqa+vZ/jw4VUt4+EdM+uy1q1bR79+/bIMfABJ9OvXr1X/6bQY+pJ6SXpK0jPpJgbfSOW7SXpI0vPpuW9pmUtU3I90QcOlYFP5KBX3FF0o6Wrl+psys3aTe4y0dvur2dNfDxyX7ptaC4xJl8G9GJgWETXAtPQaSQdQXLZ2BDCG4qYH3VJb11Hcb7UmPca0qrdmZtYmLY7pp1vhrU0ve6RHUNxc+dhUfgswneL63ycDk9K9SBdJWgiMlrQY2KXhBtiSbqW44fL97bMpZpa7YRff267tLb685ds9v/TSS1xwwQXMmDGDnj17MmzYMK666ir222+/dunD9OnT2XHHHTniiCPapb2qDuSmPfWngX2BH0XEk5IGRsQygIhY1nB3H4p7az5RWrw+lW1I043LK61vPMV/BOy5Z5N3/tumtPebLWfV/KGZbQsiglNPPZVx48YxadIkAGbPns3y5cvbNfR79+7dbqFf1YHciNgUEbUU9/McLWlkM9UrDTBFM+WV1ndDRNRFRN2AAZtdOsLMbJvw8MMP06NHD84555y3ympraznyyCO56KKLGDlyJAceeCC3317cnXL69OmceOKJb9U9//zzmThxIlBccubSSy/l0EMP5cADD+S5555j8eLFXH/99Vx55ZXU1tby6KOPtrnPrTplMyJelTSdYix+uaRBaS9/ELAiVavn7TdhHkJxG7R63n4T6IZyM7Muae7cuYwaNWqz8rvuuovZs2fzzDPPsGrVKg477DCOPvroCi28Xf/+/Zk1axbXXnstV1xxBT/+8Y8555xz6N27N1/+8pfbpc/VnL0zQNKuaXon4AMUd6ufCoxL1cZR3I
OTVD5WUk9JwykO2D6VhoLWSDo8nbVzVmkZM7PtxmOPPcaZZ55Jt27dGDhwIMcccwwzZsxocbnTTjsNgFGjRrF48eKt0rdq9vQHAbekcf0dgMkRcY+kx4HJkj4LvACcDhAR8yRNBp4FNgLnRUTD3evPBSZS3CD5fnwQ18y6sBEjRvDLX/5ys/Kmbk7VvXt33nzzzbdeNz6/vmfPngB069aNjRu3zl0zW9zTj4jfR8QhEXFQRIyMiG+m8tURcXxE1KTnl0vLTIiIfSLi3RFxf6l8Zmpjn4g4P3zbLjPrwo477jjWr1/PjTfe+FbZjBkz6Nu3L7fffjubNm1i5cqVPPLII4wePZq99tqLZ599lvXr1/Paa68xbdq0FtfRp08f1qxZ02599mUYzGy70dFnfkni7rvv5oILLuDyyy+nV69eb52yuXbtWg4++GAk8b3vfY93vetdAJxxxhkcdNBB1NTUcMghh7S4jpNOOomPfvSjTJkyhWuuuYajjjqqbX3e1ne26+rqoivcRMWnbLYfn7Jp1Zo/fz77779/Z3ej01X6OUh6OiLqGtf1tXfMzDLi0Dczy4hD38y6tG19iHpra+32O/TNrMvq1asXq1evzjb4G66n36tXr6qX8dk7ZtZlDRkyhPr6elauXNnZXek0DXfOqpZD38y6rB49elR9xygreHjHzCwjDn0zs4w49M3MMuLQNzPLiEPfzCwjDn0zs4w49M3MMuLQNzPLiEPfzCwjDn0zs4w49M3MMuLQNzPLiEPfzCwjDn0zs4w49M3MMuLQNzPLiEPfzCwjLYa+pKGSHpY0X9I8SV9M5ZdJelHS7PT4cGmZSyQtlLRA0gml8lGS5qR5V0vS1tksMzOrpJrbJW4ELoyIWZL6AE9LeijNuzIirihXlnQAMBYYAewB/FrSfhGxCbgOGA88AdwHjAHub59NMTOzlrS4px8RyyJiVppeA8wHBjezyMnApIhYHxGLgIXAaEmDgF0i4vEobl1/K3BKWzfAzMyq16oxfUnDgEOAJ1PR+ZJ+L+lmSX1T2WBgSWmx+lQ2OE03Lq+0nvGSZkqamfNd7s3M2lvVoS+pN3AncEFE/IViqGYfoBZYBvx7Q9UKi0cz5ZsXRtwQEXURUTdgwIBqu2hmZi2oKvQl9aAI/J9HxF0AEbE8IjZFxJvAjcDoVL0eGFpafAiwNJUPqVBuZmYdpJqzdwTcBMyPiB+UygeVqp0KzE3TU4GxknpKGg7UAE9FxDJgjaTDU5tnAVPaaTvMzKwK1Zy9837gk8AcSbNT2deAMyXVUgzRLAbOBoiIeZImA89SnPlzXjpzB+BcYCKwE8VZOz5zx8ysA7UY+hHxGJXH4+9rZpkJwIQK5TOBka3poJmZtR9/I9fMLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjLQY+pKGSnpY0nxJ8yR9MZXvJukhSc+n576lZS6RtFDSAkknlMpHSZqT5l0tSVtns8zMrJJq9vQ3AhdGxP7A4cB5kg4ALgamRUQNMC29Js0bC4wAxgDXSuqW2roOGA/UpMeYdtwWMzNrQYuhHxHLImJWml4DzAcGAycDt6RqtwCnpOmTgUkRsT4iFgELgdGSBgG7RMTjERHAraVlzMysA7RqTF/SMOAQ4ElgYEQsg+KDAdg9VRsMLCktVp/KBqfpxuWV1jNe0kxJM1euXNmaLpqZWTOqDn1JvYE7gQsi4i/NVa1QFs2Ub14YcUNE1EVE3YABA6rtopmZtaCq0JfUgyLwfx4Rd6Xi5WnIhvS8IpXXA0NLiw8BlqbyIRXKzcysg1Rz9o6Am4D5EfGD0qypwLg0PQ6YUiofK6mnpOEUB2yfSkNAayQdnt
o8q7SMmZl1gO5V1Hk/8ElgjqTZqexrwOXAZEmfBV4ATgeIiHmSJgPPUpz5c15EbErLnQtMBHYC7k8PMzPrIC2GfkQ8RuXxeIDjm1hmAjChQvlMYGRrOmhmZu3H38g1M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8tIi6Ev6WZJKyTNLZVdJulFSbPT48OleZdIWihpgaQTSuWjJM1J866WpPbfHDMza041e/oTgTEVyq+MiNr0uA9A0gHAWGBEWuZaSd1S/euA8UBNelRq08zMtqIWQz8iHgFerrK9k4FJEbE+IhYBC4HRkgYBu0TE4xERwK3AKVvYZzMz20JtGdM/X9Lv0/BP31Q2GFhSqlOfygan6cblFUkaL2mmpJkrV65sQxfNzKxsS0P/OmAfoBZYBvx7Kq80Th/NlFcUETdERF1E1A0YMGALu2hmZo1tUehHxPKI2BQRbwI3AqPTrHpgaKnqEGBpKh9SodzMzDrQFoV+GqNvcCrQcGbPVGCspJ6ShlMcsH0qIpYBayQdns7aOQuY0oZ+m5nZFujeUgVJtwHHAv0l1QOXAsdKqqUYolkMnA0QEfMkTQaeBTYC50XEptTUuRRnAu0E3J8eZmbWgVoM/Yg4s0LxTc3UnwBMqFA+ExjZqt6ZmVm78jdyzcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8tIi6Ev6WZJKyTNLZXtJukhSc+n576leZdIWihpgaQTSuWjJM1J866WpPbfHDMza041e/oTgTGNyi4GpkVEDTAtvUbSAcBYYERa5lpJ3dIy1wHjgZr0aNymmZltZS2GfkQ8ArzcqPhk4JY0fQtwSql8UkSsj4hFwEJgtKRBwC4R8XhEBHBraRkzM+sg3bdwuYERsQwgIpZJ2j2VDwaeKNWrT2Ub0nTj8ookjaf4r4A999xzC7toZgDDLr63s7uwXVl8+Uc6uwtt0t4HciuN00cz5RVFxA0RURcRdQMGDGi3zpmZ5W5LQ395GrIhPa9I5fXA0FK9IcDSVD6kQrmZmXWgLQ39qcC4ND0OmFIqHyupp6ThFAdsn0pDQWskHZ7O2jmrtIyZmXWQFsf0Jd0GHAv0l1QPXApcDkyW9FngBeB0gIiYJ2ky8CywETgvIjalps6lOBNoJ+D+9DAzsw7UYuhHxJlNzDq+ifoTgAkVymcCI1vVOzMza1f+Rq6ZWUYc+mZmGXHom5llxKFvZpYRh76ZWUYc+mZmGXHom5llxKFvZpYRh76ZWUYc+mZmGXHom5llxKFvZpYRh76ZWUYc+mZmGXHom5llxKFvZpYRh76ZWUYc+mZmGXHom5llxKFvZpYRh76ZWUYc+mZmGXHom5llxKFvZpYRh76ZWUbaFPqSFkuaI2m2pJmpbDdJD0l6Pj33LdW/RNJCSQskndDWzpuZWeu0x57+30VEbUTUpdcXA9MiogaYll4j6QBgLDACGANcK6lbO6zfzMyqtDWGd04GbknTtwCnlMonRcT6iFgELARGb4X1m5lZE9oa+gE8KOlpSeNT2cCIWAaQnndP5YOBJaVl61PZZiSNlzRT0syVK1e2sYtmZtagexuXf39ELJW0O/CQpOeaqasKZVGpYkTcANwAUFdXV7GOmZm1Xpv29CNiaXpeAdxNMVyzXNIggPS8IlWvB4aWFh8CLG3L+s3MrHW2OPQlvUNSn4Zp4IPAXGAqMC5VGwdMSdNTgbGSekoaDt
QAT23p+s3MrPXaMrwzELhbUkM7v4iIX0maAUyW9FngBeB0gIiYJ2ky8CywETgvIja1qfdmZtYqWxz6EfEn4OAK5auB45tYZgIwYUvXaWZmbeNv5JqZZcShb2aWEYe+mVlGHPpmZhlx6JuZZcShb2aWEYe+mVlGHPpmZhlx6JuZZcShb2aWEYe+mVlGHPpmZhlx6JuZZcShb2aWEYe+mVlGHPpmZhlx6JuZZcShb2aWEYe+mVlGHPpmZhlx6JuZZcShb2aWEYe+mVlGHPpmZhlx6JuZZcShb2aWkQ4PfUljJC2QtFDSxR29fjOznHVo6EvqBvwI+BBwAHCmpAM6sg9mZjnr6D390cDCiPhTRPwfMAk4uYP7YGaWre4dvL7BwJLS63rgvY0rSRoPjE8v10pa0AF9y0F/YFVnd6Il+m5n98A6id+f7WuvSoUdHfqqUBabFUTcANyw9buTF0kzI6Kus/thVonfnx2jo4d36oGhpddDgKUd3Aczs2x1dOjPAGokDZe0IzAWmNrBfTAzy1aHDu9ExEZJ5wMPAN2AmyNiXkf2IXMeMrNtmd+fHUARmw2pm5nZdsrfyDUzy4hD38wsIw79jEmaLmmzU+Qk1Um6ujP6ZHmQtFhS/wrl/+DLs2xdHX2evnUBETETmNnZ/bD8RMRUfEbfVuU9/S5C0jBJ8yXdKGmepAcl7SSpVtITkn4v6W5JfSss203SRElzJc2R9KXS7NMlPSXpD5KOSvWPlXRPmr5M0k8l/Y+k5yV9voM22bYTkt4h6V5Jz6T34MfSrP8naVZ6T74n1f2UpB+m6YmSrpf0aHp/nthpG7Edceh3LTXAjyJiBPAq8I/ArcBXI+IgYA5waYXlaoHBETEyIg4EflKa1z0iRgMXNLEswEHAR4D3AV+XtEfbN8UyMgZYGhEHR8RI4FepfFVEHApcB3y5iWWHAcdQvP+ul9Rra3d2e+fQ71oWRcTsNP00sA+wa0T8JpXdAhxdYbk/AXtLukbSGOAvpXl3ldob1sR6p0TEGxGxCniY4sJ5ZtWaA3xA0nclHRURr6Xyat57kyPizYh4nuJ9/J6t29Xtn0O/a1lfmt4E7FqpUhrOmZ0e34yIV4CDgenAecCPK7S5iaaP8TT+Moe/3GFVi4g/AKMowv87kr6eZvm91wkc+l3ba8ArDWPxwCeB30TEpoioTY+vp7MkdoiIO4F/BQ5t5XpOltRLUj/gWIrLaZhVJQ0Hvh4RPwOuoHXvv9Ml7SBpH2BvwFfcbSOfvdP1jaMY69yZ4t/fT1eoMxj4iaSGD/lLWrmOp4B7gT2Bb0WEL5JnrXEg8H1JbwIbgHOBX1a57ALgN8BA4JyIWLd1upgPX4bBmiXpMmBtRFzR2X2xvEiaCNwTEdV+QFgVPLxjZpYR7+mbmWXEe/pmZhlx6JuZZcShb2aWEYe+bVckvUvSJEl/lPSspPsk7Sdp7lZc59pW1L1MUlOXHGhz+2Yt8Xn6tt2QJOBu4JaIGJvKainO8TYzvKdv25e/AzZExPUNBelaRUsaXqerlT6aru44S9IRqXyQpEfSpSvmSjqqhauTNkvSSZKelPQ7Sb+WVP7gObjSVUslXSRpRrpi6jfa8oMwa4r39G17MpLi4l3NWQH8fUSsk1QD3AbUAR8HHoiICZK6ATtTujopgKRdW9GXx4DDIyIkfQ74CnBhmncQcDjwDuB3ku5Nfa+huJidgKmSjo6IR1qxTrMWOfQtNz2AH6Zhn03Afql8BnCzpB7Af0XEbElvXZ2U4jIUD7ZiPUOA2yUNAnYEFpXmTYmIN4A3JDVctfRI4IPA71Kd3hQfAg59a1ce3rHtyTyKqzk250vAcoqrjtZRBDJpj/po4EXgp5LOqnR1UklDS1cwPaeZ9VwD/DDdv+BsoHwd+EpXjhTwndKF8vaNiJta3mSz1nHo2/bkf4CejcbJDwP2KtV5J7AsIt6kuCppt1RvL2BFRNwI3AQcWunqpBGxpBTM19O0d1J8gE
BxUbyySlctfQD4jKTeqT+DJe2+BT8Ds2Z5eMe2G2n8/FTgKhU3114HLKa4K1iDa4E7JZ1OcUOYv6byY4GLJG0A1gJnUf3VSXeWVF96/QPgMuAOSS8CTwDDS/MrXbV0qaT9gceLk5BYC/wTxTEIs3bja++YmWXEwztmZhlx6JuZZcShb2aWEYe+mVlGHPpmZhlx6JuZZcShb2aWkf8PzecdfaIpyuEAAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"df.plot.bar(rot=0)\n",
"plt.title(\"distribution of images per class\");"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAADnCAYAAAC9roUQAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAZzUlEQVR4nO3deZgV1YH38e/pvt0s3ewIgkYruKAjAhoXzLglxmVShjhx0BgX4jZxxsRoJpmUmTfJNTpavG9MJu5xeRQ10XGLGMsoeTQQoxLUBBQ1osGKioDIcumGbnqr94+6bNIduFudu/w+z3Ofhqar7+9C94/Tp06dMlEUISIiyaizHUBEpJaodEVEEqTSFRFJkEpXRCRBKl0RkQSpdEVEEqTSFRFJkEpXRCRBKl0RkQSpdEVEEqTSFRFJkEpXRCRBKl0RkQSpdEVEEqTSFRFJkEpXRCRBKl0RkQSlbAcQ2cTxghHAntmHA4wBBgHNvbwdAHQCHcDG7NtNv24DPgSW9fJYHvpuZ1KvSeTjjG7XI0lyvCAF/AMwCZgMjCcu2D2Jy7TUeoAlwKLs49Xs28Wh73Yl8PxS41S6UlKOF3wCOBo4DDicuGj72czUhw7gDeAFYA4wJ/TdFVYTSVVS6UpROV4wEDgGODH72M9uooL8hWwBoxKWIlHpSsEcL9gb+GfgBOAoynMkW6gImA88BDwU+m5oN45UKpWu5MXxglHAl4EziacOas1LwIPAg6HvvmM7jFQOla7stOzUwSnAWcDxaPXLJi8DdwC/CH13ne0wUt5UurJDjhfsA3wTmE4yKwwq1Xrgf4EbQ9/9k+0wUp5UutInxwuOBS4DTkYX0uTqeeA64GEtRZOtqXRlG44XNBDP1V4GHGQ5TjX4G/DfwF26KENApStZjhfUAecAVwB7WI5TjULgKmCmRr61TaUrOF4wFbgaOMB2lhqwhLh871H51iaVbg1zvOBIwAf+0XaWGvQ28O3Qd2fZDiLJUunWIMcLPgn8DPiC7SxCAHwz9N2/2g4iyVDp1pDsZjPfAn4IDLQcR7bYCPxf4JrQd9tsh5HSUunWCMcLDgFuI95wRsrTO8Cloe8+ZjuIlI5Kt8o5XtBMfOLm60C95Tiyc+4D/j303bW2g0jxqXSrWPZE2b3Ee9VKZXkXOCf03bm2g0hxqXSLzBgzB/h2FEUvfez9hwDnRFF0SakzZNfcXk685laj28rVA1wL/J/Qdztsh5Hi0KWdCYmi6KWECnc08BTxlIIKt7LVAd8B5jlesL/tMFIcNV26xhjHGPOGMeY2Y8xrxpjZxpgBxpjJxph5xphXjDG/MsYM6+XYemPMXcaYRcaYV40xl231x9OMMfONMYuNMUdlP/5YY8zj2V+njTH3GGOeMca8ZYy5sBivx/GC44CFwOeK8fmkbBwEvOx4wRm2g0jharp0s/YBboyi6ABgLXAqcDfw3SiKJhLfQ+uHvRw3GdgtiqIJURQdCNy51Z+loig6DLi0j2MBJgIucATwA2PM2HxfgOMFxvGCK4DZwOh8P4+UtQHALx0vuCY7fSQVSv948E4URQuyv34Z2AsYGkXRphMYM4nv8fVxS4BxxpjrjTEnAVvvo/rIVp/P6eN5Z0VR1BZF0UfA78hzI3DHC/oD9wM/QP+etcADHnW8YJDtIJIffZPGC9M36QaG9vZB2emEBdnHj6IoWkN8R9s5wMXA7b18zm763uj742cwcz6j6XjBSOBp4LRcj5WK9gXgBccLxtkOIrlT6W4vA6zZNBcLnA3MjaKoO4qiydnHD4wxI4G6KIoeBr4PHJzj83zRGNPfGDMCOBZ4MZeDHS8YD8wDPp3j80p1OAB40fGCY2wHkdyodHs3Hfh/xphXiOduf9TLx+wGzDHGLADuIl6ilYv5xNfdzwOujKLog5090PGCo4k3yd4rx+eU6jIceNLxgpNtB5Gdp3W6Fhhj0kBrFEU/zvVYxwtOBX4JNBY7l1SsTu
Cs0HcfsB1Edkwj3QrieMGXiU+aqXBlaw3EKxu+ajuI7JhGuhXC8YKziZel6YIH6UsEXBL67g22g0jfNNKtAI4XnEk8b6zClb/HANc7XvAd20GkbxrpljnHC6YR7zqlwpVc/Gvou7fZDiHbU+mWMccLXOBR+l7rK9KXbuBfQt991HYQ2ZZKt0w5XnAQ8CzQZDuLVKx24ITQd5+1HUS2UOmWIccLdgf+COS9H4NI1lrg6NB3X7UdRGI6kVZmstfUB6hwpTiGEl9A4VjOIVkq3TKSvXHkA8Q7kIkUy1hgluMFuhlpGVDplpfrgZNsh5CqNBG41XYIUemWDccLzgUusp1DqtqZjhd83XaIWqcTaWUgeyuWlwD9+Cel1gkcE/ruC7aD1CqVrmWOFwwg3nFsgu0sUjOWAp8KfXeF7SC1SNML9l2HCleStRtwv277Y4f+0i3K7hp2ge0cUpOOJb7TsCRM0wuWZG+1sgDQva7Elg7gEF04kSyNdO25FRWu2NUI3ON4QYPtILVEpWuB4wXTgeNs5xAhvrlqrreakgJoeiFh2Tv4/gUYYTuLSFYHcHDou6/ZDlILNNJN3k9R4Up5aQTu0GqGZOgvOUGOFxwPnGU7h0gvDgfOsR2iFmh6ISGOF/QHXgPG2c4i0odlwD6h7663HaSaaaSbnG+gwpXyNgb4ru0Q1U4j3QQ4XjAUWAIMsxxFZEfagPGh775nO0i10kg3Gd9FhSuVYQBwje0Q1Uwj3RJzvGAs8DbxF3PJda56n5WPzdj8+661yxl65Fn0bGyldeFT1A0cAsCwo89hwF6Hbnd825KXWf30rdDTQ/OkExgyZRoA3W0tfDRrBl3rVpAaPJqRp3jU92+m/f3XWT37Jkx9AyOnfoeGYWPpaW9l5awZjDrtRxhjknjZUlwRMCX03fm2g1Qj3WW29H5IQoUL0DBid8aeez0AUU837980nYH7HkHrq79l0CGnMOTwL/V5bNTTzerf3syo068iNWgEy2ZexoC9D6dx5B6sm/cg/Z1JDJkyjcy8B1k370GGHXsu6178FbuccjldmQ9p+fMTDP/sBax9/n6GHHGaCrdyGeBK4ETbQaqRphdKyPGCfYHzbD1/+98W0jB0DKkho3bq4zuWLSY1dAwNQ3fF1DfQtP/RtL01D4ANb/+RpgnxRXRNE45jQ/b9pi5F1NVB1LURU5eic80yultW0X+PA0vzoiQpJzhecLDtENVII93S+h4W/47Xv/F7Bu5/9Obft/zpcda/9gyNu+7NsM9eQH3/5m0+vqtlFanBu2z+ff2gkXQsexOA7vVrSTUPByDVPJye9WsBGDJlGquevAHT0MhI9z9Y87s7GHqUliJXicuBabZDVBuNdEvE8YIxwBm2nj/q7qTt7fk07XckAIMO+jy7fe02xpx7HfXNw1nzzO07+Zn+/hRB4+hxjDnnWnY94xq6Msupzxbzylkz+OjXP6Z7/ZpCXobY9aXsT2tSRCrd0rmY+PJKK9qWvEzj6L2ob4oXTdQ3DcPU1WNMHYMmnUjHssXbHZMaNIKudSs3/7675aPNJVrfNJSu1tUAdLWupq5p6DbHRlFE5vn/Zcg/nsHa537J0CO/QtMBn2Hdy78u0SuUBNQB/2k7RLVR6ZZA9hY8Vm8yuf71uTRtNbWwqTABNix+gYaRe253TOOYfela8wGda5cTdXey/o3fM2DvwwEYuPfhrF/0dPy5Fz3NwOz7Nz/foqcZsNch1PdvJurcCKYOjIl/LZXsbMcLdrMdoppoTrc0zsHipjY9ne20hwsYcdKWG7+unXMnHSuWgDGkhoxi+Inxn3W1rGLVk9cxetoVmLp6hh9/ER8+8AOIemg+8Hgad4nLefCUf+GjWT6tr8wmNXgXRn7x8m2er3XR04w+7cr4Yw89hZW/uhpTn2LkVA2UKlwjcCm6y0TRaJ1ukTleYIDXgf1sZxEpkpXAbqHvdtoOUg00vVB8J6LCleqyCzDVdohqodItvnNtBx
ApgfNtB6gWml4oIscLBgPLSfAKNJGE9AB7hr77vu0glU4j3eL6EipcqU51wHTbIaqBSre4dCmWVLPzsieKpQAq3SLJ7ib2Gds5REpoHHCE7RCVTqVbPGegv0+pflrFUCCVRPF82XYAkQSodAuk1QtF4HjBLsAKdrQ7jEh12Cf03bdth6hUGukWx/GocKV2fMF2gEqm0i0O7bAvtURTDAXQ9EKBsktoPgB2tZ1FJCFdwKjQd7VZch400i3cRFS4UltSwLG2Q1QqlW7hNLUgtejTtgNUKpVu4T5nO4CIBbpIIk+a0y1Adj53DTDEdhaRhLUDQ0Lf7bAdpNJopFuY8ahwpTb1Bw6yHaISqXQLc5jtACIWaV43DyrdwhxiO4CIRZrXzYNKtzCTbAcQsehg2wEqkUq3MCpdqWWO4wUNtkNUGpVunhwv2AOdRJPaVg/sZTtEpVHp5m+c7QAiZWBf2wEqjUo3f3vaDiBSBlS6OVLp5k+lKxKvVZccqHTz59gOIFIGNNLNkUo3fxrpimjwkTOVbv4c2wFEysAI2wEqjTa8yUN2o5t2oNF2FpEy0D/03Y22Q1QKjXTz04QKV2ST4bYDVBKVbn6abQcQKSOaYsiBSjc/g2wHECkjKt0cqHTzo5GuyBYq3RyodPOjka7IFkNtB6gkKt38aKQrsoV2GsuBSjc/Kl2RLeptB6gkKt38GNsBRMqISjcHKdsBKpTugFoGfpS6c+7Z9b/VLWMs66auJ74ptuwMlW5+dPVNGUjRjTG6SMW2FD09tjNUEk0v5EcjXZEtumwHqCQq3fyodEW26LQdoJKodPOj6QWRLTTSzYFKNz8a6YpsobNoOVDp5met7QAiZeQD2wEqiUo3P8tsBxApI/p+yIFKNw+h77aj0a4IQDfwoe0QlUSlmz/97y4CK0hntE43Byrd/Kl0RfR9kDOVbv70xSaik2g5U+nmT19sIhp85Eylm7+/2Q4gUgY0+MiRSjd/i2wHECkD79sOUGlUuvlT6YrAn20HqDQq3TyFvrsKWG47h4hFHWjwkTOVbmFetR1AxKJXSWe0D0mOVLqF0f/yUsteth2gEql0C6PSlVr2ku0AlUilW5g/2Q4gYpFGunlQ6RbmFbSXqNQmnUTLk0q3AKHv9gDP2s4hYoFOouVJpVu4ObYDiFig+dw8qXQLN8d2ABELnrYdoFKpdAu3EM3rSm3pAJ60HaJSqXQLpHldqUFzSWdabIeoVCrd4phtO4BIgh6zHaCSqXSL4xFAtyyRWqHSLYBKtwhC310GPG87h0gCFpLOvGs7RCVT6RbPg7YDiCRAo9wCqXSL52Egsh1CpMRUugVS6RZJ6LtLgRds5xApoQ/QfgsFU+kW10O2A4iU0H2kM/pprkAq3eK6H+iyHUKkBCLgZtshqoFKt4iyqxgC2zlESmA26cxfbYeoBird4vu57QAiJXCT7QDVQqVbfE8Boe0QUrj3Mj18ZuZ69r+xlQNuauVn8zYCkJ7Tzm4/aWHyLa1MvqWVJ97q7PX4J9/uYvwNrex9XQv+HzZufv/qtojj71nPPte3cvw961nTFk+TPvduFxNvbuXQ21p5e3V8rc3a9ogT711PFFmdSv0b8LjNANVEpVtk2b0YNCqoAqk6uPaE/rxxcTPzzm/ixhc7eX1lNwCXTWlkwUXNLLiomc/v07Ddsd09ERc/0cZvzhzI6xc3c9+iLcf6f9jIcZ9M8dY3mjnuk6nNhXztCx08fNoArv5sf25+Md6q9sq5G/nekf0wxiT0qnt1K+mMrrgsEpVuadwObLAdQgozZlAdB4+pB2BQP8P+u9SxdN3OjTjnL+1m7+F1jBtWR2O94csHNDDrL/E51llvdjF9UlzU0yc18Oib8fsb6qGtCzZ0RjTUw19X97C0pYdjnFQJXt1O6yD+epYiUemWQOi7a4B7beeQ4gnX9vDnZd0cvntcwjfM72Diza2cN6tt8/TA1pa2RHxi8JZvr90HG5a2xIPFFa09jB
kU/9mYQXV8uD5+/+VH9uNff93O//yxg68f1sh/PdPOlZ/pV+qXtiMPk858aDtENVHpls4MtHysKrR2RJz6wAb+56T+DO5n+LdDGvnrJc0suKiJMc2G/5jdvt0xvU3B7miCYPKu9cy7oInfTW9iyZoexg6qIwJOf2gDZz3SxopWKz/h32jjSauZSrdEQt9dAtxtO4cUprM7LtwzD2zgS/vHUwKjm+uorzPUGcOFn2pk/tLu7Y7bfbDhvXVbSvL9dRFjs6Pb0c11LMuOepe19DCqadtvwyiKuOr3G/n+0f24Yu5Grji2H2dNbOC6PyZ+S7Lfkc48l/STVjuVbmldhUa7FSuKIs5/rJ39R9bzrSO2/Ji/qTABfvVGJxNGbf9tdOhu9by1qod31vTQ0R1x/2udTB0fz81O3TfFzIXxioeZCzv54vht52xnLuzE3SfFsAGGDZ1QZ+LHht4XSZSSl/gz1gCrM/TVLvTddxwvmAmcbzuL5O6597q555VODhxVx+RbWgG4+rh+3LeoiwXLuzGAM7SOn5/cH4APWnq44LF2njhzIKk6ww2f78+J926gO4o4b3IjB4yK54O9Ixs57aE27vhzJ3sMMTw4beDm59zQGTFzYSezz4rf960pjZz6QBuN9XDfqQOSfPmPkM7MT/IJa4WxvP6v6jle4ACLge3XFUlBrk7dPvcrqWeOsZ2jCnUDE0hn/mI7SDXS9EKJhb4bAndZjiGSi7tUuKWj0k3GlWjdrlSGdiBtO0Q1U+kmIPTd94D/tp1DZCfcQDrzvu0Q1Uylm5wfE8/tipSrDHCN7RDVTqWbkNB3O4BLbOcQ+TuuIJ1ZbTtEtVPpJij03aeIb9cuUm6eB35mO0QtUOkm7zJ0Uk3KSxtwrnYSS4ZKN2Gh776Lzg5Lefk+6YzONyREpWvHtcDvbYcQIZ5W+KntELVEpWtBdqPzs4nPFovYomkFC1S6lmSnGS62nUNqmqYVLFDpWhT67i+A+2znkJqkaQVLVLr2/Tvwru0QUlPWAtM1rWCHStey0HfXAmcCye+WKrWoGziddOZt20FqlUq3DIS++wfg67ZzSE34T9KZ2bZD1DKVbpkIffdWdD8qKa2ZpDM/sR2i1ql0y8ulwDO2Q0hVmgd8zXYIUemWldB3u4BpwBLbWaSqLAX+mXRmo+0gotItO6HvrgamAi22s0hVaAdOIZ1ZbjuIxFS6ZSj03deIi7fNdhapaBFwPunMS7aDyBYq3TIV+u4c4FSgw3IUqVyXkM780nYI2ZZKt4yFvvsb4CvEaytFcvFt0pkbbIeQ7al0y1zouw8D5xH/qCiyM75POnOt7RDSO5VuBQh99260OY7snDTpzFW2Q0jfVLoVIvTdm4GLAF0vL33xSGeusB1C/j6VbgUJfffnxOt4td5SPu5S0pkZtkPIjql0K0zou48AJ6IN0CXWDXyNdCbvm0oaY0JjzMhe3j/VGOMVlE62o9KtQKHvzgWOAbTgvbatAf6JdObWUnzyKIoei6LIL8XnrmUq3QoV+u5C4NPAW7aziBVvAIeRzvw2l4OMMU3GmMAYs9AYs8gYc3r2j75hjPmTMeZVY8x+2Y/9qjHmhuyv7zLG3GKMedYYs9gYc3JxX07tUOlWsNB33wEOB35jO4sk6nFgSp574p4EfBBF0aQoiiYAT2bf/1EURQcDNwPf7uNYh/gnLBe4xRjTP4/nr3kq3QoX+u4a4m+CNFrZUAtmAF8knVmX5/GvAp8zxswwxhwVRdGmcwOPZN++TFyuvXkgiqKeKIreIt6Uab88M9S0lO0AUrjQdyPgCscL5gO/AIZZjiTF1wZcUOhlvVEULTbGfAr4PHCNMWbThuabVsR003cvfPwCHV2wkweNdKtI9rLhTwELLEeR4noPOLoY+ygYY8YCG6Iouhf4MXBwDodPM8bUGWP2AsYBbxaapxapdKtMdp73COAW21mkKG4HJhRxp7ADgfnGmAXAfwG5XL32JjCX+BzCRVEUtR
cpU00xUaSfEKqV4wUnAXcAY21nKYWrU7fP/UrqmWNs5yiR94ALSWeesh0E4tULwONRFD1kO0ul00i3ioW++yQwAbjHdhbJyabRbVkUrhSXTqRVuezqhnMcL7ifeMrhE5YjSd/KanS7tSiKvmo7Q7XQSLdGhL77BHAA8DOg03Ic2Z5GtzVCpVtDQt9tCX33UuIph19bjiOx54CjSGcuLGDtrVQQTS/UoNB3FwNTHS84DvgJMNFypFq0CPge6Yz+86sxGunWsNB3nwYOAi4EVliOUytCYDowSYVbm1S6NS703Z7Qd28nXuz+TeKTOVJ8K4FLgfGkM3eTzuiS7Rql6QUBIPTdDcB1jhfcDJwJfBddW18MK4CbgJ+SzrTYDiP2qXRlG6HvdgJ3OV5wN3AK4AGHWg1VmZ4lLtuHSWe0WkQ2U+lKr0Lf7SHeeeoRxwumAOcDpwODrAYrby3AvcBNpDOLbIeR8qTSlR0KfXceMM/xgkuB04hvCX+k1VDl5TXiUe09mkKQHVHpyk4LfXc9cCdwp+MF44FzgVOBva0Gs+NN4DHgUdKZ522Hkcqh0pW8hL77JvF8r+d4wf7AVOALxDucVeOqmG7geeKifYx0ZrHlPFKhVLpSsNB33yC+Z9cMxwt2Id4g+2TgKGC0zWwFagGeIi7aJ0hnVlnOI1VApStFFfruSmBm9oHjBeOIR7+fzr6dCNRbC9i3dmAh8e1qXsq+fZ10pstqKqk6Kl0pqdB3lxDfT+sXAI4XNAGHAP8A7LvVwyG5r8c1wGJUsGKBSlcSlT0ZNzf72Mzxggbiq+L2AXYFRvTxGEg8Uk51U9cKLAO6so9O4hHr8uz7lwEffOztMtIZ3fFArNGdI0REElSNZ5lFRMqWSldEJEEqXRGRBKl0RUQSpNIVEUmQSldEJEEqXRGRBKl0RUQSpNIVEUmQSldEJEEqXRGRBKl0RUQSpNIVEUmQSldEJEEqXRGRBKl0RUQS9P8Be9gvbnlkJxwAAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.pie(count,\n",
" explode=(0,0),\n",
" labels=class_names,\n",
" autopct=\"%1.2f%%\")\n",
"plt.axis('equal');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Due to imbalance in dataset, upsampling is done on the minority class, by randomly duplicating images until the 2 classes have comparable distribution in the dataset. After this is done, the dataset will be split into the training, testing and validation sets by randomly shuffling them and then splitting.\n",
"\n",
"Another way is to introduce class weights for each specific class. Each class is penalised with the specific class weight. Higher the class weight, greater the penalty. Classes with lower percentage have a higher penalty. This allows for the model to penalise itself heavily if class detected is incorrect."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Augmenting Images of Minority Class\n",
"\n",
"The images present in the ship class are augmented and then stored in the dataset, so that there is an equal representation of the classes. The current ratio of classes is 1:3, meaning that for every image present in the ship class there are 3 images present in the no-ship class. This will be countered by producing 2 augmented images per original image of the ship class. This will make the dataset balanced.\n",
"\n",
"If augmentation of dataset is required then set AUGMENTATION to True. This will balance the dataset via augmentation of minority classes. To train via class weights, then set AUGMENTATION to False."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"AUGMENTATION = True"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"def augment_add(images, seq, labels):\n",
" \n",
" augmented_images, augmented_labels = [],[]\n",
" for idx,img in tqdm(enumerate(images)):\n",
" \n",
" if labels[idx] == 1:\n",
" image_aug_1 = seq.augment_image(image=img)\n",
" image_aug_2 = seq.augment_image(image=img)\n",
" augmented_images.append(image_aug_1)\n",
" augmented_images.append(image_aug_2)\n",
" augmented_labels.append(labels[idx])\n",
" augmented_labels.append(labels[idx])\n",
" pass\n",
" \n",
" augmented_images = np.array(augmented_images, dtype=np.float32)\n",
" augmented_labels = np.array(augmented_labels, dtype=np.float32)\n",
" \n",
" return (augmented_images, augmented_labels)\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"seq = iaa.Sequential([\n",
" iaa.Fliplr(0.5),\n",
" iaa.Crop(percent=(0,0.1)),\n",
" iaa.LinearContrast((0.75,1.5)),\n",
" iaa.Multiply((0.8,1.2), per_channel=0.2),\n",
" iaa.Affine(\n",
" scale={'x':(0.8,1.2), \"y\":(0.8,1.2)},\n",
" translate_percent={\"x\":(-0.2,0.2),\"y\":(-0.2,0.2)},\n",
" rotate=(-25,25),\n",
" shear=(-8,8)\n",
" )\n",
"], random_order=True)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"4000it [00:16, 235.65it/s] \n"
]
}
],
"source": [
"if AUGMENTATION:\n",
" (aug_images, aug_labels) = augment_add(images, seq, labels)\n",
" images = np.concatenate([images, aug_images])\n",
" labels = np.concatenate([labels, aug_labels])"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((6000, 48, 48, 3), (6000,))"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"images.shape, labels.shape"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAADxCAYAAABoIWSWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAYEUlEQVR4nO3deZwcZYHG8d87PZNrknQuyAFoy62EQ4gohxxyrgNBhSTKDQIrQVZ0PRpYsbgbV1xEJVwrIHKfAZqNUY4AcoREAwQQAqQFkkDuziSTmcxM1/5RnUxgJkf1UW9V1/P9fOYzQ6Yr83Toevqdt6reMq7rIiIiwaizHUBEJE5UuiIiAVLpiogESKUrIhIgla6ISIBUuiIiAVLpiqzHGPO0MWZMD38+xhhzrY1MUlvqbQcQiQLXdWcAM2znkOjTSFciyxiTMsa8aYy5yRjzujFmqjGmrzFmD2PMi8aYV40xDxljBvewbcIYc6sxZrYx5jVjzA/X+/Y4Y8x0Y8zbxpivFh9/kDHmseLXjjHmdmPMk8aYOcaYMwN6ylIDVLoSdTsAv3dddxdgOXAs8EfgZ67r7ga8Bvyih+32ALZyXXe067q7Ares971613X3Bs7bwLYAuwFNwD7ARcaYUeU/FYkDla5E3VzXdWcVv54JbAcMcl13WvHPbgMO6GG794BtjTG/NcYcCaxY73sPrvf3pTbwcye7rrvadd3FwFPA3qU/BYkTla5EXdt6X3cCg3p6UHE6YVbx4xLXdZcBuwNPA+cAN/fwd3ay4eMen160RIuYyGZR6UqtyQPL1s7FAicB01zX7XRdd4/ix0XGmGFAneu6DwA/B/b0+XOOMcb0McYMBQ4CXq7UE5DaprMXpBadAlxvjOmHN41wWg+P2Qq4xRizduBxvs+fMR3IAp8BLnVdd36pYSVejJZ2FPHHGOMAK13X/ZXtLBI9ml4QEQmQRroiIgHSSFdEJEAqXRGRAKl0RUQCpFPGJDRS6WwSGAWM/NTnEUA/vNdrPdAAJIAC0A50FD9agYXAfGDBpz4vyWWadABDrNOBNAlUKp3tg7duwV7Fj53wynUkXrFWyxrgI7wSfg/vEt+ZwN9zmaYVG9tQpJJUulI16xXsGLpKdhfC9RuWC7xDVwnPQEUsVaTSlYpJpbMGr2DHAl/HK9wwFezmcoE5wJ+BR4BpuUxTu91IUitUulKW4mj2UOBo4Ci8Odhak6ergB/PZZqWWc4jEabSFd9S6eyWeAU7FjiM6s7Fhk0H8BxeAT+SyzS9azmPRIxKVzZLcergcGAi3uLdCbuJQuMF4DrgvlymqW1TDxZR6cpGpdLZwcDpwPeA7S3HCbNFwB+A63OZppzlLBJiKl3pUSqdHYO3uPcEoK/lOFFSAB7HG/1O0bnB8mkqXVknlc4mgOOBc4EvWY5TC94DJuGNflfaDiPhoNIVAFLp7LHA5XgXK0hlLQIuwyvfNbbDiF0q3ZhLpbNfAzJoZBuEuXh3F74jl2kq2A4jdqh0YyqVzu4JXIl3RoIE61XgglymKWs7iARPpRszqXR2e7xfdccDxnKcuHsWSOcyTc/bDiLBUenGRPHKsUuA8/BW6ZLwuA/4fi7TtNB2EKk+lW4MpNLZfYBb0EGyMFuMV7z32A4i1aXSrWHF0e2lwI/QgvVR8QAwUaPe2qXSrVEa3UbaYuDcXKbpbttBpPJUujWmOLq9DPghGt1G3YPA2Rr11haVbg1JpbN7AXeg0W0tWQKclcs0PWg7iFSGSrdGpNLZE4CbgT62s0hVXA78XGs5RJ9KN+JS6Wwd3kUOP7WdRaruYeAkreMQbSrdCEulswOBO/HWt5V4mA2MzWWa5toOIqVR6UZU8cqyR4DP284igVsCHJfLND1tO4j4p6PbEZRKZw8FpqPCjauhwF9S6exE20HEP410IyaVzp4L/A+6XY54bsC7kq3DdhDZPCrdCEmls5cAP7edQ0LnQeA7Wqs3GlS6EZ
FKZ38J/MR2DgmtLHCsbo4ZfirdkCvehfc3eLfQEdmYvwLH5DJNLbaDyIbpQFr4/Q4VrmyeQ4FsKp3VjURDTKUbYql09teAjlCLHwcBD6fS2d62g0jPVLohlUpnr8BbtEbEr8OB+1LprBarDyGVbgil0tkLgfNt55BIOxq4s3iZuISI/oeETCqd/Tbe0owi5ToOb10OCRGdvRAixTv0PgfoQIhU0om5TNMdtkOIR6UbEql0djgwA9jadhapOa3AAblM08u2g4hKNxRS6Wwv4ClgX9tZgvDhpNOp69UX6uowdQlGnnINnaubWTz5KjpWfEz9wOEM+0aaRJ/+3bZd/d5Mlj5xIxQK9N/9cJJfGQewwe1bP3yDpVOvwyQaGDb2JzQMHkWhdSWLJl/FluMvwZjY3IV+PjAml2laYDtI3GlONxyuJyaFu9bw71zBqNN+y8hTrgFgxYv30Se1O1uddRN9Uruz4sX7um3jFjpZ+pdJbDnuYkadcR2r3pjGmsXvb3T7FS8/xBbfOJ9BB5xM8z8eB2D583eT3Gd8nAoXYBTeqWRa5N4yla5lqXT2POA02zlsa3nnJRpHHwJA4+hDaJnzYrfHrFnwNvWDRtIwaAQm0UDj5w9gdfFxG9re1NXjdqzB7WjD1NXTvmwBnc1L6POZXQN6ZqGyN3Cj7RBxp9K1KJXOHgb8ynaOwBnDwnsvYsGtP6B51hQAOlctp77/EADq+w+hsGp5t806mpdQP3CLdf+dGDCMzpVLNrp98ivjWDLld6yYMZkBex7F8mf+yKCvnljFJxd6J6XS2R/bDhFn9bYDxFUqnR0F3E0Ml2gcccIvqR8wlM5Vy/n4nv+iYWg5xw43PkXQa/i2jDz5agBaP5hNoljMiyZfhalLMPhr3yXROLiMnx9JV6XS2em5TNMztoPEkUa69twADLEdwob6AUMBSDQOot+O+9A2/20SjYPoWLkUgI6VS6lrHNTjdh0rFq37787mxetKdFPbu65L/vl7SO73HZb/7U4G7X88jbsczIqZj1bhGYZeHfCHVDrbz3aQOFLpWpBKZ08GjrKdw4bCmlYKbS3rvm6d+w96bfFZ+m3/ZVbNfgKAVbOfoN/2X+62ba+RO9KxbD7tyz/C7Wxn1ZvP0Lf4uE1tv2r2E/TdbgyJPv1x29vA1IEx3tfxtB26cMIKnTIWsOK0wuvAIMtRrGhf/hGLHixecFco0PiFA0nuO4HO1StYPDlDx4pF1A/cgmHHnE+i7wA6mpewZMq1DB93MQCr332ZpU/cBG6B/rseRnLfCQAb3B6g0N7KwvsvZvj4SzGJelo/mM3SqZMwiXqGjf0pDUO2svJvEQIucJCmGYKl0g1YKp19lJiOciWU3gV20xq8wdH0QoDiPK0goaVphoBppBuQuE8rSKi5wIG5TNOztoPEgUa6wbkBFa6EkwFu0dkMwVDpBiCVzo5H0woSbtsBv7AdIg40vVBlxdX738R7UYuEWSuwQy7T9KHtILVMI93qOxMVrkRDH8CxHaLWaaRbRal0thF4BxhhO4vIZuoERucyTf+0HaRWaaRbXeehwpVoSQCX2w5RyzTSrZJUOjsUeA8YaDuLSAm+nMs0TbcdohZppFs956PClejK2A5QqzTSrYJUOrsNMAfobTuLSBmOzGWa/mw7RK3RSLc6HFS4En1XptLZWN3TKAgq3Qor3tX3JNs5RCrgi8AhtkPUGpVu5Z0JNNgOIVIhE20HqDWa062gVDqbAOYC29jOIlIhncBnc5mmebaD1AqNdCvraFS4UlsSwL/bDlFLVLqVdY7tACJVcGZxDRGpAJVuhaTS2R3RQQepTSOAb9kOUStUupVzNpu6H7hIdOmAWoXoQFoFFBd/nocWKZfaNjqXaXrddoio00i3MiagwpXad7btALVApVsZ420HEAnAsbpCrXwq3TKl0tn+wMG2c4gEYASwt+0QUafSLd8RaJ0FiY+xtgNEnUq3fHoRSpzo9V4mnb1QhuJlvx
8Bw2xnEQnQtrlM01zbIaJKI93y7IsKV+JHo90yqHTLc7TtACIW6HVfBpVuefSOL3F0QCqdTdoOEVUq3RKl0tkdgJ1s5xCxoAH4N9shokqlWzqdmytxdpDtAFGl0i3dXrYDiFik13+JVLql04tO4mxXrbFbGpVuCVLpbC9gV9s5RCzqDYy2HSKKVLqlGQ30sh1CxDL9tlcClW5p9GIT0X5QEpVuacbYDiASAtoPSqDSLY3e4UV0MK0kKl2fdBBNZB0dTCuBSte/ndFBNJG19rAdIGpUuv5tYzuASIhsbTtA1Kh0/RtlO4BIiGh/8Eml699I2wFEQkT7g08qXf/0zi7SRfuDTypd//TOLtJF+4NPKl3/9M4u0mVEKp01tkNEiUrXP72zi3SpB7awHSJKVLo+pNLZOmC47RwiIaPf/nxQ6fqzBd47u4h00W9/Pqh0/RliO4BICGm/8EGl648W9xDpTvuFDypdfzS1INKd9gsfVLr+6B1dpDvtFz6odP3RO7pId9ovfNA/lg+v9f5uoZHWpbZziITJanp1wELbMSJDpevDALO6gI7UinxCI222I0SKphf86bAdQCSEtF/4oNL1Ry8uke60X/ig0vWn3XYAkRDSfuGDStefvO0AIiGk/cIHla4/HwMF2yFEQmaB7QBRotL1w8l3AItsxxAJmfm2A0SJStc/vauLdCng/QYom0ml65/e1UW6LMTJd9oOESUqXf800hXpov3BJ5WufxrpinTR/uCTStc/vbOLdNH+4JNK1z+9s4t00f7gk0rXv3m2A4iEiErXJ5Wuf28AOlor4nnVdoCoUen65eRbgH/ajiESAp3AK7ZDRI1KtzQzbQcQCYE3i4MQ8UGlWxqVroj2g5KodEszw3YAkRDQflAClW5pZqGDaSIa6ZZApVsKHUwT0UG0Eql0S6d3eYkzHUQrkUq3dCpdiTO9/kuk0i3dc7YDiFik13+JVLqlcvJ/R5cESzy5wGO2Q0SVSrc8j9oOIGLByzj5j2yHiCqVbnkesR1AxAK97sug0i3Pk8BK2yFEAqbSLYNKtxxOvg2YajuGSIByOPnXbIeIMpVu+TSvK3Gi13uZVLrlewzvNtQicaCphTKpdMvl5BcDL9iOIRKAPDDNdoioU+lWxgO2A4gE4FGcfLvtEFGn0q2M24BW2yFEquwG2wFqgUq3Epz8UuAe2zFEquhVnLwu/a0AlW7lXGc7gEgVTbIdoFaodCvFyU9HK+lLbVoB/Ml2iFqh0q0sjXalFv0RJ68rLytEpVtZdwNLbYcQqTANJipIpVtJTn41cKvtGCIV9DRO/k3bIWqJSrfyJuGtNypSCzTKrTCVbqU5+XfQAs9SG3LAQ7ZD1BqVbnVciNZjkOi7CCffYTtErVHpVoO39N0dtmOIlEGv4SpR6VbPRcAa2yFESnQBTl6/rVWBSrdanHwOuN52DJESPIeT13GJKlHpVtdlQLPtECI+/cx2gFqm0q0mJ78I+LXtGCI+PIqTf952iFqm0q2+q4FFtkOIbIYCcIHtELVOpVttTr4ZuNR2DJHNcDtOfrbtELVOpRuM69AKZBJui4Gf2g4RByrdIDj5TuBUoM1yEpEN+T5OfqHtEHGg0g2Kk38duMR2DJEePICT151PAqLSDdZVaJpBwmUxMNF2iDgxrqsFsQLlJHcBZgK9bUexJXVNMwN6GxIG6utgxln9WbraZcL9LeSWu6QGGe49rh+D+5pu2055p4MfTGmls+Byxp69SO/v/TNuaPu/vd/B2dlWetfDXcf2Y/shdSxv9R475YR+GNP9Z8TMtzXKDZZGukHTNAMAT53Sj1nf68+Ms/oDkHmujUM+V8+cc/tzyOfqyTzXffq7s+ByzuOr+b8T+vHGOf25a3Y7byzq3Oj2V7+whgfG9+WKr/Vh0sveVdmXTmvjgv17q3A1rWCFStcOTTN8yuS3Ojhl9wYATtm9gYff6r641fR5nWw/pI5tB9fRK2H49i4NTP5nx0a3b0
jA6g5oaXdpSMC7SwvMay5wYKo+oGcWWppWsESla0PMz2YwBg6/vYW9blzJjTO90efHKwuMHOC9HEcOqGPhqu5rrcxrdtlmYNdLduuBhnnNhY1uf/7+vTnr0VaueWkN39+7Fxc+2cqlB8d2Zmd9OlvBkti/3Vvj5F/HSU4E/td2lKD97fRGRhWL8bDbW9h52Oa99/d0+GFTEwR7jEjw4hmNADzzrw5GDajDBSbc30JDneHqw3szvH/sxh6TNK1gT+xebaHi5P8AXGs7RtBGFUekWzbW8c2d65k+r5Ph/etYUBy1LmgusGVj95fm1gMNH6zoGgF/uMJd93dtanvXdbnsmTZ+fkBvLp7WxsUH9ebE3Rq49qXYrb75FPAftkPEmUrXvh8Bf7UdIiir1rg0t7nrvp76biejt0wwdsd6bnulHYDbXmnnmJ26/xL2pa0SzFlSYO6yAms6Xe5+vZ2xxcdtavvbXmmnaYd6Bvc1tLRDnfE+Wtqr+WxDZy4wTneDsEunjIWBkxwCvARsbztKtb23rMA372kBoKMAx49u4MIDerOkpcD4+1fzft7lM0nDfeP6MaSvYX5zgTMeaeXxE/oB8Picds6b0kan63L6Hr248ABvfnZD24N3EK3pzhamntiPhoTh2X91MPHxVnol4K5j+7Lj0ISdf4xgrQT20doK9ql0w8JJfgF4ARhoO4rUHBf4Fk7+YdtBRNML4eHk3wBOQDe0lMr7hQo3PFS6YeLdIuW/bMeQmnIvTl5Li4aISjdsnPyVwG22Y0hNmA6cZjuEfJJKN5y+C9xrO4RE2j+AI3DyLbaDyCfpQFpYOcl64H7gGNtRJHJmAwfj5BfbDiLdqXTDzEn2AiYDR9qOIpHxFnAgTv5j20GkZ5peCDMnvwb4JpC1HUUi4U28Ea4KN8RUumHn5FuBbwEP2Y4iofYq3gh3ge0gsnEq3SjwRrzjgbttR5FQmoE3wl1kO4hsmko3Krzr5U8ArrcdRULlSeBQnPxS20Fk8+hAWhR5S0L+Bi3NGXe/BX6kBWyiRaUbVU7yILxTyoZaTiLBWwOcg5O/2XYQ8U+lG2VO8nN4p5TtajuKBGYhcCxO/jnbQaQ0mtONMic/F9gXeNhyEgnGLOBLKtxoU+lGnZNfiXdK2WW2o0hV3Q/sh5N/33YQKY+mF2qJkzwOuAkYZDmJVE474ABX4uS1s9YAlW6tcZKjgBuBJttRpGyvAKfi5GfZDiKVo9KtVU7yFOAaNOqNonbgCuBynHy87uIWAyrdWqZRbxRpdFvjVLpx4CRPxruYYpDlJLJhGt3GhEo3LrxR7w3AUbajSDezgNM0uo0HlW7cOMkjgCuBL9qOInyAd2bCbTj5TstZJCAq3ThykgaYgHdu73aW08TRUryphN8Xl+6UGFHpxpmTbADOBC4ChltOEwcteGeU/BInn7ecRSxR6Qo4yUbgh8BPgIGW09SiDryLVi7ByX9kO4zYpdKVLk5yKPCfeKPfYZbT1ILVwF14V5O9YzuMhINKV7pzkr3x7lQxEfiK5TRRNAeYBNyKk19mO4yEi0pXNs5JfhGvfI8H+llOE2adwGPA74G/ap0E2RCVrmweJ5kETgXOBnayGyZUPgZuBm7AyX9gO4yEn0pX/HOS+wHHAGOJZwEvAB4FHgGm6goy8UOlK+Vxkjvile/RwH5Awm6gqnkVr2QfAWZo+kBKpdKVynGSQ/AW1xkLHAEMsBuoLGuAaawtWi0eLhWi0pXqcJIJYGdgL2BM8fMehPNgXDvwOjATmFH8/CpOvs1qKqlJKl0JzieLeG0Z74h3R2MTUIo8MBcVrFii0hX7nGQvYAQwEhjVw+cReCPk+uJHA97ccQFvlNpR/GjDO5tgPt7Bru6fnfzqoJ6WSE9UuiIiAdLdgEVEAqTSFREJkEpXRCRAKl0RkQCpdEVEAqTSFWuMMTljTLd1e40xY40xaRuZRKpNp4yJNcaYHDDGdd3FtrOIBEUjXQmEMabRGJM1xr
xijJltjJlQ/Na5xpi/G2NeM8bsXHzsqcaY3xW/vtUYc70x5lljzNvGGN1CXiJNpStBORKY77ru7q7rjgamFP98seu6e+LdaeHHG9g2BRyIt5jO9caYPtUOK1ItKl0JymvAocaYq4wxX3Vdd+3dcB8sfp6JV649udd13YLrunOA9/DWbxCJpHrbASQeXNd92xizF/B14EpjzNTit9YuNNPJhl+Pnz7woAMRElka6UogjDGjgBbXdf8E/ArY08fm44wxdcaY7YBtgbeqkVEkCBrpSlB2Bf7bGLN2ZbCzgfs3c9u38BYUHw58z3Xd1upEFKk+nTImoWaMuRV4zHXdzS1okVDT9IKISIA00hURCZBGuiIiAVLpiogESKUrIhIgla6ISIBUuiIiAfp/PJwd78YowPYAAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Class balance after augmentation: one pie wedge per class.\n",
"if AUGMENTATION:\n",
"    _, count = np.unique(labels, return_counts=True)\n",
"\n",
"    fig, ax = plt.subplots()\n",
"    # No explode offsets are passed, so this works for any number of classes.\n",
"    ax.pie(count,\n",
"           labels=class_names,\n",
"           autopct=\"%1.2f%%\")\n",
"    ax.axis('equal')\n",
"    ax.set_title('Images per class')\n",
"    plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## One Hot Encoding Variables\n",
"\n",
"The labels numpy array is one-hot encoded using `to_categorical` from Keras. This removes the unnecessary bias that raw integer class indices would introduce (they imply an ordering between classes), keeping every class on an equal footing with respect to the labels."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"labels = to_categorical(labels)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training, Validation and Testing\n",
"\n",
"Instead of using `train_test_split`, the images and labels arrays are randomly shuffled using the same seed value (42). Because both arrays receive the same permutation, each image stays linked to its corresponding label after shuffling.\n",
"\n",
"The shuffled arrays are then sliced into all three datasets. The training set is used to fit the model and the validation set to monitor it during training, while the testing set is held out to evaluate the model on unseen data. Because the model never sees this data during training, evaluating on it simulates real-world prediction and shows how robust the model is."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"# Draw one seeded permutation and apply it in place to both arrays so every\n",
"# image keeps its label (same RNG draws as np.random.shuffle seeded with 42).\n",
"np.random.seed(42)\n",
"shuffle_order = np.random.permutation(len(images))\n",
"images[...] = images[shuffle_order]\n",
"labels[...] = labels[shuffle_order]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Splitting of data\n",
"\n",
"1. 70% - Training\n",
"2. 20% - Validation\n",
"3. 10% - Testing"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((4200, 48, 48, 3), (1200, 48, 48, 3), (600, 48, 48, 3))"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"total_count = len(images)\n",
"\n",
"train = int(0.7*total_count)\n",
"val = int(0.2*total_count)\n",
"# The test split takes the remainder: no samples are lost to int() truncation,\n",
"# and it avoids images[-0:] returning the whole array when the split is empty.\n",
"test = total_count - train - val\n",
"\n",
"train_images, train_labels = images[:train], labels[:train]\n",
"val_images, val_labels = images[train:train + val], labels[train:train + val]\n",
"test_images, test_labels = images[train + val:], labels[train + val:]\n",
"\n",
"assert len(train_images) + len(val_images) + len(test_images) == total_count\n",
"\n",
"train_images.shape, val_images.shape, test_images.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If AUGMENTATION is set to True, the number of images per class is already balanced. If AUGMENTATION is set to False, compute the class weights below and pass them to the Keras `fit` function through its `class_weight` argument when training."
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"if not AUGMENTATION:\n",
"    # Weight each class by (largest class count / its own count), intended for\n",
"    # model.fit(class_weight=classWeight) so minority classes are up-weighted.\n",
"    classTotals = train_labels.sum(axis=0)\n",
"    classWeight = {i: float(classTotals.max() / class_total)\n",
"                   for i, class_total in enumerate(classTotals)}\n",
"    print(classWeight)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Creation of model\n",
"\n",
"1. conv_block - This function groups a convolutional layer, a batch normalization layer and an activation layer. The number of filters, the kernel size and the strides are passed in by the developer. This lets a developer build the model without repeating the same lines many times; factoring repeated layers into a reusable function is recommended over writing them out inline each time.\n",
"\n",
"2. basic_model - This function creates the model using the aforementioned function, max pooling layers and dropouts. After the specified number of convolutional layers, a flatten layer is introduced, along with dense layers so that the image can be classified. The flatten layer converts the feature map produced by the convolutional layers into a single column for classification.\n",
"\n",
"For object detection using Faster R-CNN, the feature map produced by the convolutional layers is used: the model drops the layers after the final convolutional block and passes the generated feature map to a Region Proposal Network. The proposed regions can then be classified either by a vanilla network of fully connected layers or by a classical classifier such as Logistic Regression, Support Vector Machines or Random Forests. The vanilla fully connected network is advised, as it uses less CPU and memory, is slightly faster, and can produce multiple outputs if required."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Explanation of features\n",
"\n",
"1. Conv2D - This is a 2 dimensional convolutional layer, the number of filters decide what the convolutional layer learns. Greater the number of filters, greater the amount of information obtained.\n",
"2. MaxPooling2D - This downsamples the feature map produced by the convolutional layer by keeping only the maximum value in each window, reducing its spatial dimensions while retaining the strongest activations. This makes the model slightly more robust to small shifts in the input.\n",
"3. Dropout - During training, this randomly zeroes a user-defined fraction of a layer's outputs, which regularises the model and makes it more robust. It can be used after both convolutional layers and fully connected layers.\n",
"4. BatchNormalization - This layer normalises the activations of a hidden layer using batch statistics. This is similar to the standard scaling applied to inputs in classical machine learning algorithms.\n",
"5. Padding- This pads the feature map/input image with zeros allowing border features to stay."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def conv_block(X, k, filters, stage, block, s=2):\n",
"    \"\"\"Conv2D -> BatchNormalization -> ReLU building block.\n",
"\n",
"    Args:\n",
"        X: input tensor of the Keras functional graph.\n",
"        k: square kernel size (k x k).\n",
"        filters: number of convolution filters.\n",
"        stage, block: only used to build unique layer names,\n",
"            e.g. 'conv_2A_branch2a' and 'bn_2A_branch2a'.\n",
"        s: stride in both spatial dimensions (default 2).\n",
"\n",
"    Returns:\n",
"        The ReLU-activated output tensor; with 'same' padding its spatial\n",
"        size is ceil(input_size / s).\n",
"    \"\"\"\n",
"    conv_base_name = f'conv_{stage}{block}_branch'\n",
"    bn_base_name = f'bn_{stage}{block}_branch'\n",
"\n",
"    X = Conv2D(filters=filters, kernel_size=(k, k), strides=(s, s),\n",
"               padding='same', name=conv_base_name + '2a')(X)\n",
"    X = BatchNormalization(name=bn_base_name + '2a')(X)\n",
"    X = Activation('relu')(X)\n",
"    return X"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Creation of Model\n",
"\n",
"### Block 1\n",
"\n",
"1. An input layer is initialised using the Keras `Input` layer; this defines the shape of the input tensor (a single image).\n",
"2. ZeroPadding is applied to the input image, so that boundary features are not lost.\n",
"\n",
"### Block 2\n",
"\n",
"The first convolutional layer uses 16 filters, a (3,3) kernel and (2,2) strides, followed by batch normalization. 'same' padding is used, so apart from the stride-2 downsampling the spatial size is preserved; further reduction happens in the next blocks, where MaxPooling occurs.\n",
"\n",
"### Block 3 - 4\n",
"\n",
"Similar structure in both with a convolutional layer followed by a MaxPooling and Dropout layers.\n",
"\n",
"### Output Block\n",
"\n",
"The feature map produced by the previous convolutional layers is converted into a single vector using a Flatten layer, passed through two hidden Dense layers, and then classified using a final Dense (output) layer with one unit per class in the dataset and softmax as the activation function."
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"def basic_model(input_shape, classes):\n",
"    \"\"\"Build the CNN feature extractor plus fully connected classifier.\n",
"\n",
"    Args:\n",
"        input_shape: shape of a single image, e.g. (48, 48, 3).\n",
"        classes: number of output classes (units in the softmax layer).\n",
"\n",
"    Returns:\n",
"        An uncompiled keras Model named 'Feature_Extraction_and_FC'.\n",
"    \"\"\"\n",
"    # Input is imported from keras.layers at the top of the notebook; `tf` is\n",
"    # not imported, so tf.keras.Input raised NameError here.\n",
"    X_input = Input(input_shape)\n",
"\n",
"    # Stage 1: zero padding so border features are not lost, then a strided conv.\n",
"    X = ZeroPadding2D((5, 5))(X_input)\n",
"    X = Conv2D(16, (3, 3), strides=(2, 2), name='conv1', padding=\"same\")(X)\n",
"    X = BatchNormalization(name='bn_conv1')(X)\n",
"    # NOTE(review): no activation follows bn_conv1 - confirm this is intended.\n",
"\n",
"    # Stage 2\n",
"    X = conv_block(X, 3, 32, 2, block='A', s=1)\n",
"    X = MaxPooling2D((2, 2))(X)\n",
"    X = Dropout(0.25)(X)\n",
"\n",
"    # Stage 3\n",
"    X = conv_block(X, 5, 32, 3, block='A', s=2)\n",
"    X = MaxPooling2D((2, 2))(X)\n",
"    X = Dropout(0.25)(X)\n",
"\n",
"    # Stage 4\n",
"    X = conv_block(X, 3, 64, 4, block='A', s=1)\n",
"    X = MaxPooling2D((2, 2))(X)\n",
"    X = Dropout(0.25)(X)\n",
"\n",
"    # Classifier head\n",
"    X = Flatten()(X)\n",
"    # NOTE(review): Dense(64) has no activation, so it composes linearly with\n",
"    # Dense(128) up to the ReLU below - confirm a non-linearity is not missing.\n",
"    X = Dense(64)(X)\n",
"    X = Dropout(0.5)(X)\n",
"    X = Dense(128)(X)\n",
"    X = Activation(\"relu\")(X)\n",
"\n",
"    X = Dense(classes, activation=\"softmax\", name=\"fc\" + str(classes))(X)\n",
"\n",
"    return Model(inputs=X_input, outputs=X, name='Feature_Extraction_and_FC')"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'tf' is not defined",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32mf:\\Machine Learning\\ML Project\\ShipsSatelliteImageClassification\\main.ipynb Cell 36\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[1;34m()\u001b[0m\n\u001b[1;32m----> <a href='vscode-notebook-cell:/f%3A/Machine%20Learning/ML%20Project/ShipsSatelliteImageClassification/main.ipynb#X50sZmlsZQ%3D%3D?line=0'>1</a>\u001b[0m model \u001b[39m=\u001b[39m basic_model(input_shape\u001b[39m=\u001b[39;49m(\u001b[39m48\u001b[39;49m,\u001b[39m48\u001b[39;49m,\u001b[39m3\u001b[39;49m),classes\u001b[39m=\u001b[39;49m\u001b[39m2\u001b[39;49m)\n",
"\u001b[1;32mf:\\Machine Learning\\ML Project\\ShipsSatelliteImageClassification\\main.ipynb Cell 36\u001b[0m in \u001b[0;36mbasic_model\u001b[1;34m(input_shape, classes)\u001b[0m\n\u001b[0;32m <a href='vscode-notebook-cell:/f%3A/Machine%20Learning/ML%20Project/ShipsSatelliteImageClassification/main.ipynb#X50sZmlsZQ%3D%3D?line=0'>1</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mbasic_model\u001b[39m(input_shape,classes):\n\u001b[1;32m----> <a href='vscode-notebook-cell:/f%3A/Machine%20Learning/ML%20Project/ShipsSatelliteImageClassification/main.ipynb#X50sZmlsZQ%3D%3D?line=2'>3</a>\u001b[0m X_input \u001b[39m=\u001b[39m tf\u001b[39m.\u001b[39mkeras\u001b[39m.\u001b[39mInput(input_shape)\n\u001b[0;32m <a href='vscode-notebook-cell:/f%3A/Machine%20Learning/ML%20Project/ShipsSatelliteImageClassification/main.ipynb#X50sZmlsZQ%3D%3D?line=4'>5</a>\u001b[0m X \u001b[39m=\u001b[39m ZeroPadding2D((\u001b[39m5\u001b[39m,\u001b[39m5\u001b[39m))(X_input)\n\u001b[0;32m <a href='vscode-notebook-cell:/f%3A/Machine%20Learning/ML%20Project/ShipsSatelliteImageClassification/main.ipynb#X50sZmlsZQ%3D%3D?line=6'>7</a>\u001b[0m X \u001b[39m=\u001b[39m Conv2D(\u001b[39m16\u001b[39m,(\u001b[39m3\u001b[39m,\u001b[39m3\u001b[39m),strides\u001b[39m=\u001b[39m(\u001b[39m2\u001b[39m,\u001b[39m2\u001b[39m),name\u001b[39m=\u001b[39m\u001b[39m'\u001b[39m\u001b[39mconv1\u001b[39m\u001b[39m'\u001b[39m,padding\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39msame\u001b[39m\u001b[39m\"\u001b[39m)(X)\n",
"\u001b[1;31mNameError\u001b[0m: name 'tf' is not defined"
]
}
],
"source": [
"model = basic_model(input_shape=(48,48,3),classes=2)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.13 ('base')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "5819c1eaf6d552792a1bbc5e8998e6c2149ab26a1973a0d78107c0d9954e5ba0"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}