{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Brain Stroke Prediction - Decision Tree"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Context:\n",
"\n",
"A stroke is a medical condition in which poor blood flow to the brain causes cell death. There are two main types of stroke: ischemic, due to lack of blood flow, and hemorrhagic, due to bleeding. Both cause parts of the brain to stop functioning properly. Signs and symptoms of a stroke may include an inability to move or feel on one side of the body, problems understanding or speaking, dizziness, or loss of vision to one side. Signs and symptoms often appear soon after the stroke has occurred. If symptoms last less than one or two hours, the stroke is a transient ischemic attack (TIA), also called a mini-stroke. A hemorrhagic stroke may also be associated with a severe headache. The symptoms of a stroke can be permanent. Long-term complications may include pneumonia and loss of bladder control.\n",
"\n",
"The main risk factor for stroke is high blood pressure. Other risk factors include high blood cholesterol, tobacco smoking, obesity, diabetes mellitus, a previous TIA, end-stage kidney disease, and atrial fibrillation. An ischemic stroke is typically caused by blockage of a blood vessel, though there are also less common causes. A hemorrhagic stroke is caused by either bleeding directly into the brain or into the space between the brain's membranes. Bleeding may occur due to a ruptured brain aneurysm. Diagnosis is typically based on a physical exam and supported by medical imaging such as a CT scan or MRI scan. A CT scan can rule out bleeding, but may not necessarily rule out ischemia, which early on typically does not show up on a CT scan. Other tests such as an electrocardiogram (ECG) and blood tests are done to determine risk factors and rule out other possible causes. Low blood sugar may cause similar symptoms.\n",
"\n",
"Prevention includes decreasing risk factors, surgery to open up the arteries to the brain in those with problematic carotid narrowing, and warfarin in people with atrial fibrillation. Aspirin or statins may be recommended by physicians for prevention. A stroke or TIA often requires emergency care. An ischemic stroke, if detected within three to four and a half hours, may be treatable with a medication that can break down the clot. Some hemorrhagic strokes benefit from surgery. Treatment to attempt recovery of lost function is called stroke rehabilitation, and ideally takes place in a stroke unit; however, these are not available in much of the world.\n",
"\n",
"### Attribute Information:\n",
"\n",
"1) `gender`: \"Male\", \"Female\" or \"Other\"\n",
"2) `age`: age of the patient\n",
"3) `hypertension`: 0 if the patient doesn't have hypertension, 1 if the patient has hypertension\n",
"4) `heart_disease`: 0 if the patient doesn't have any heart disease, 1 if the patient has a heart disease\n",
"5) `ever_married`: \"No\" or \"Yes\"\n",
"6) `work_type`: \"children\", \"Govt_job\", \"Never_worked\", \"Private\" or \"Self-employed\"\n",
"7) `Residence_type`: \"Rural\" or \"Urban\"\n",
"8) `avg_glucose_level`: average glucose level in blood\n",
"9) `bmi`: body mass index\n",
"10) `smoking_status`: \"formerly smoked\", \"never smoked\", \"smokes\" or \"Unknown\"*\n",
"11) `stroke`: 1 if the patient had a stroke or 0 if not\n",
"\n",
"*Note: \"Unknown\" in `smoking_status` means that the information is unavailable for this patient. In the file loaded below, these entries appear as missing values (NaN)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## IMPORTING LIBRARIES AND LOADING DATA"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import os\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.metrics import matthews_corrcoef\n",
"from sklearn.metrics import f1_score\n",
"\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.svm import SVC\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"from sklearn.neural_network import MLPClassifier\n",
"from sklearn.ensemble import StackingClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"import joblib"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" id gender age hypertension heart_disease ever_married \\\n",
"0 30669 Male 3.0 0 0 No \n",
"1 30468 Male 58.0 1 0 Yes \n",
"2 16523 Female 8.0 0 0 No \n",
"3 56543 Female 70.0 0 0 Yes \n",
"4 46136 Male 14.0 0 0 No \n",
"5 32257 Female 47.0 0 0 Yes \n",
"6 52800 Female 52.0 0 0 Yes \n",
"7 41413 Female 75.0 0 1 Yes \n",
"8 15266 Female 32.0 0 0 Yes \n",
"9 28674 Female 74.0 1 0 Yes \n",
"\n",
" work_type Residence_type avg_glucose_level bmi smoking_status \\\n",
"0 children Rural 95.12 18.0 NaN \n",
"1 Private Urban 87.96 39.2 never smoked \n",
"2 Private Urban 110.89 17.6 NaN \n",
"3 Private Rural 69.04 35.9 formerly smoked \n",
"4 Never_worked Rural 161.28 19.1 NaN \n",
"5 Private Urban 210.95 50.1 NaN \n",
"6 Private Urban 77.59 17.7 formerly smoked \n",
"7 Self-employed Rural 243.53 27.0 never smoked \n",
"8 Private Rural 77.67 32.3 smokes \n",
"9 Self-employed Urban 205.84 54.6 never smoked \n",
"\n",
" stroke \n",
"0 0 \n",
"1 0 \n",
"2 0 \n",
"3 0 \n",
"4 0 \n",
"5 0 \n",
"6 0 \n",
"7 0 \n",
"8 0 \n",
"9 0 "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_csv('trainFile.csv')\n",
"df.head(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## DATA EXPLORATION"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 43400 entries, 0 to 43399\n",
"Data columns (total 12 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 id 43400 non-null int64 \n",
" 1 gender 43400 non-null object \n",
" 2 age 43400 non-null float64\n",
" 3 hypertension 43400 non-null int64 \n",
" 4 heart_disease 43400 non-null int64 \n",
" 5 ever_married 43400 non-null object \n",
" 6 work_type 43400 non-null object \n",
" 7 Residence_type 43400 non-null object \n",
" 8 avg_glucose_level 43400 non-null float64\n",
" 9 bmi 41938 non-null float64\n",
" 10 smoking_status 30108 non-null object \n",
" 11 stroke 43400 non-null int64 \n",
"dtypes: float64(3), int64(4), object(5)\n",
"memory usage: 4.0+ MB\n"
]
}
],
"source": [
"df.info()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" id age hypertension heart_disease \\\n",
"count 43400.000000 43400.000000 43400.000000 43400.000000 \n",
"mean 36326.142350 42.217894 0.093571 0.047512 \n",
"std 21072.134879 22.519649 0.291235 0.212733 \n",
"min 1.000000 0.080000 0.000000 0.000000 \n",
"25% 18038.500000 24.000000 0.000000 0.000000 \n",
"50% 36351.500000 44.000000 0.000000 0.000000 \n",
"75% 54514.250000 60.000000 0.000000 0.000000 \n",
"max 72943.000000 82.000000 1.000000 1.000000 \n",
"\n",
" avg_glucose_level bmi stroke \n",
"count 43400.000000 41938.000000 43400.000000 \n",
"mean 104.482750 28.605038 0.018041 \n",
"std 43.111751 7.770020 0.133103 \n",
"min 55.000000 10.100000 0.000000 \n",
"25% 77.540000 23.200000 0.000000 \n",
"50% 91.580000 27.700000 0.000000 \n",
"75% 112.070000 32.900000 0.000000 \n",
"max 291.050000 97.600000 1.000000 "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.describe()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['Male' 'Female' 'Other']\n",
"['children' 'Private' 'Never_worked' 'Self-employed' 'Govt_job']\n",
"['Rural' 'Urban']\n",
"[nan 'never smoked' 'formerly smoked' 'smokes']\n",
"['No' 'Yes']\n"
]
}
],
"source": [
"print(df['gender'].unique())\n",
"print(df['work_type'].unique())\n",
"print(df['Residence_type'].unique())\n",
"print(df['smoking_status'].unique())\n",
"print(df['ever_married'].unique())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## DATA PREPROCESSING"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" id gender age hypertension heart_disease ever_married \\\n",
"0 30669 0 3.0 0 0 0 \n",
"1 30468 0 58.0 1 0 1 \n",
"2 16523 1 8.0 0 0 0 \n",
"3 56543 1 70.0 0 0 1 \n",
"4 46136 0 14.0 0 0 0 \n",
"\n",
" work_type Residence_type avg_glucose_level bmi smoking_status \\\n",
"0 children Rural 95.12 18.0 NaN \n",
"1 Private Urban 87.96 39.2 never smoked \n",
"2 Private Urban 110.89 17.6 NaN \n",
"3 Private Rural 69.04 35.9 formerly smoked \n",
"4 Never_worked Rural 161.28 19.1 NaN \n",
"\n",
" stroke \n",
"0 0 \n",
"1 0 \n",
"2 0 \n",
"3 0 \n",
"4 0 "
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
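"# Binary-encode the two-valued columns: 'Yes' -> 1, anything else -> 0\n",
"df['ever_married'] = (df['ever_married'] == 'Yes').astype(int)\n",
"# 'Female' -> 1; 'Male' and 'Other' both map to 0\n",
"df['gender'] = (df['gender'] == 'Female').astype(int)\n",
"df.head(5)"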
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" id gender age hypertension heart_disease ever_married \\\n",
"34642 71716 0 60.0 0 1 1 \n",
"18850 66277 0 48.0 0 0 1 \n",
"17110 56403 0 8.0 0 0 0 \n",
"6769 49943 1 62.0 0 0 1 \n",
"38044 19764 1 36.0 0 0 1 \n",
"\n",
" avg_glucose_level bmi stroke work_type_Govt_job \\\n",
"34642 97.47 29.5 0 0 \n",
"18850 70.00 36.4 0 0 \n",
"17110 135.14 22.8 0 0 \n",
"6769 82.88 41.3 0 0 \n",
"38044 107.91 24.5 0 1 \n",
"\n",
" work_type_Never_worked work_type_Private work_type_Self-employed \\\n",
"34642 0 1 0 \n",
"18850 0 1 0 \n",
"17110 0 0 0 \n",
"6769 0 1 0 \n",
"38044 0 0 0 \n",
"\n",
" work_type_children Residence_type_Rural Residence_type_Urban \\\n",
"34642 0 0 1 \n",
"18850 0 1 0 \n",
"17110 1 0 1 \n",
"6769 0 0 1 \n",
"38044 0 0 1 \n",
"\n",
" smoking_status_formerly smoked smoking_status_never smoked \\\n",
"34642 0 1 \n",
"18850 0 0 \n",
"17110 0 0 \n",
"6769 0 0 \n",
"38044 0 0 \n",
"\n",
" smoking_status_smokes \n",
"34642 0 \n",
"18850 0 \n",
"17110 0 \n",
"6769 0 \n",
"38044 0 "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
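"# One-hot encode the multi-valued categorical columns; rows with a missing\n",
"# smoking_status get 0 in all three smoking_status_* columns\n",
"df = pd.get_dummies(df, columns=['work_type', 'Residence_type', 'smoking_status'])\n",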
"df.sample(5)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"id 0\n",
"gender 0\n",
"age 0\n",
"hypertension 0\n",
"heart_disease 0\n",
"ever_married 0\n",
"avg_glucose_level 0\n",
"bmi 1462\n",
"stroke 0\n",
"work_type_Govt_job 0\n",
"work_type_Never_worked 0\n",
"work_type_Private 0\n",
"work_type_Self-employed 0\n",
"work_type_children 0\n",
"Residence_type_Rural 0\n",
"Residence_type_Urban 0\n",
"smoking_status_formerly smoked 0\n",
"smoking_status_never smoked 0\n",
"smoking_status_smokes 0\n",
"dtype: int64"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.isnull().sum()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Drop the rows with a missing value; after encoding, only bmi has any (1462 rows)\n",
"df = df.dropna(how='any', axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Features: every column except the target and the row identifier\n",
"X = df.drop(['stroke', 'id'], axis=1)\n",
"y = df['stroke']"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" gender age hypertension heart_disease ever_married avg_glucose_level \\\n",
"0 0 3.0 0 0 0 95.12 \n",
"1 0 58.0 1 0 1 87.96 \n",
"2 1 8.0 0 0 0 110.89 \n",
"3 1 70.0 0 0 1 69.04 \n",
"4 0 14.0 0 0 0 161.28 \n",
"5 1 47.0 0 0 1 210.95 \n",
"6 1 52.0 0 0 1 77.59 \n",
"7 1 75.0 0 1 1 243.53 \n",
"8 1 32.0 0 0 1 77.67 \n",
"9 1 74.0 1 0 1 205.84 \n",
"\n",
" bmi work_type_Govt_job work_type_Never_worked work_type_Private \\\n",
"0 18.0 0 0 0 \n",
"1 39.2 0 0 1 \n",
"2 17.6 0 0 1 \n",
"3 35.9 0 0 1 \n",
"4 19.1 0 1 0 \n",
"5 50.1 0 0 1 \n",
"6 17.7 0 0 1 \n",
"7 27.0 0 0 0 \n",
"8 32.3 0 0 1 \n",
"9 54.6 0 0 0 \n",
"\n",
" work_type_Self-employed work_type_children Residence_type_Rural \\\n",
"0 0 1 1 \n",
"1 0 0 0 \n",
"2 0 0 0 \n",
"3 0 0 1 \n",
"4 0 0 1 \n",
"5 0 0 0 \n",
"6 0 0 0 \n",
"7 1 0 1 \n",
"8 0 0 1 \n",
"9 1 0 0 \n",
"\n",
" Residence_type_Urban smoking_status_formerly smoked \\\n",
"0 0 0 \n",
"1 1 0 \n",
"2 1 0 \n",
"3 0 1 \n",
"4 0 0 \n",
"5 1 0 \n",
"6 1 1 \n",
"7 0 0 \n",
"8 0 0 \n",
"9 1 0 \n",
"\n",
" smoking_status_never smoked smoking_status_smokes \n",
"0 0 0 \n",
"1 1 0 \n",
"2 0 0 \n",
"3 0 0 \n",
"4 0 0 \n",
"5 0 0 \n",
"6 0 0 \n",
"7 1 0 \n",
"8 0 1 \n",
"9 1 0 "
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.head(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Target and Feature values / Train Test Split"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((33550, 17), (8388, 17))"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
"X_train.shape, X_test.shape"
]
},
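{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `describe()` output above reports a mean of about 0.018 for `stroke`, so fewer than 2% of the rows are positive. With a purely random split, the share of positives in the training and test parts can drift apart. The next cell is a minimal sketch, not part of the original workflow, of a stratified split using the `stratify` argument of `train_test_split`. It runs on synthetic data (`X_demo` and `y_demo` are hypothetical names, and the 2% positive rate is an assumption) so that it can be executed on its own. Passing `X`, `y` and `stratify=y` to the split above would apply the same idea to the real data, but it would also change the split and therefore every score reported below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Synthetic stand-in for the stroke data (assumption: 2% positives, illustration only)\n",
"rng = np.random.default_rng(42)\n",
"X_demo = rng.normal(size=(5000, 3))\n",
"y_demo = np.zeros(5000, dtype=int)\n",
"y_demo[:100] = 1  # 100 positives out of 5000 samples\n",
"\n",
"# stratify=y_demo keeps the class ratio the same in both parts of the split\n",
"X_tr_demo, X_te_demo, y_tr_demo, y_te_demo = train_test_split(\n",
"    X_demo, y_demo, test_size=0.2, random_state=42, stratify=y_demo\n",
")\n",
"\n",
"print('Positive rate, train:', y_tr_demo.mean())\n",
"print('Positive rate, test: ', y_te_demo.mean())"
]
},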
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## MODEL BUILDING"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### KNeighborsClassifier\n",
"\n",
"Classifier implementing the k-nearest neighbors vote.\n",
"\n",
"`class sklearn.neighbors.KNeighborsClassifier(n_neighbors=5, *, weights='uniform', algorithm='auto', leaf_size=30, p=2, metric='minkowski', metric_params=None, n_jobs=None)`"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model performance for Training set\n",
"- Accuracy: 0.9855141579731743\n",
"- MCC: 0.2714217088023517\n",
"- F1 score: 0.9803223813632442\n",
"----------------------------------\n",
"Model performance for Test set\n",
"- Accuracy: 0.9834287076776347\n",
"- MCC: -0.004865438031361363\n",
"- F1 score: 0.9767491714651221\n"
]
}
],
"source": [
"knn = KNeighborsClassifier(n_neighbors=3) # Define classifier (k = 3)\n",
"knn.fit(X_train, y_train) # Train model\n",
"\n",
"# Make predictions\n",
"y_train_pred = knn.predict(X_train)\n",
"y_test_pred = knn.predict(X_test)\n",
"\n",
"# Training set performance\n",
"knn_train_accuracy = accuracy_score(y_train, y_train_pred) # Calculate Accuracy\n",
"knn_train_mcc = matthews_corrcoef(y_train, y_train_pred) # Calculate MCC\n",
"knn_train_f1 = f1_score(y_train, y_train_pred, average='weighted') # Calculate F1-score\n",
"\n",
"# Test set performance\n",
"knn_test_accuracy = accuracy_score(y_test, y_test_pred) # Calculate Accuracy\n",
"knn_test_mcc = matthews_corrcoef(y_test, y_test_pred) # Calculate MCC\n",
"knn_test_f1 = f1_score(y_test, y_test_pred, average='weighted') # Calculate F1-score\n",
"\n",
"print('Model performance for Training set')\n",
"print('- Accuracy: %s' % knn_train_accuracy)\n",
"print('- MCC: %s' % knn_train_mcc)\n",
"print('- F1 score: %s' % knn_train_f1)\n",
"print('----------------------------------')\n",
"print('Model performance for Test set')\n",
"print('- Accuracy: %s' % knn_test_accuracy)\n",
"print('- MCC: %s' % knn_test_mcc)\n",
"print('- F1 score: %s' % knn_test_f1)"
]
},
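{
"cell_type": "markdown",
"metadata": {},
"source": [
"The evaluation code in the cell above is repeated almost verbatim for every classifier in this notebook. As a sketch, the same three metrics can be collected by one helper. The name `evaluate_model` and the synthetic data from `make_classification` are assumptions made for illustration; neither is part of the original notebook. With the notebook's own variables, a call such as `evaluate_model(knn, X_train, y_train, X_test, y_test)` should reproduce the numbers printed above, since it fits and scores the model in the same way."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import make_classification\n",
"from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"\n",
"\n",
"def evaluate_model(model, X_tr, y_tr, X_te, y_te):\n",
"    # Fit the model, then score it on the training split and on the test split\n",
"    model.fit(X_tr, y_tr)\n",
"    scores = {}\n",
"    for split, X_part, y_part in [('train', X_tr, y_tr), ('test', X_te, y_te)]:\n",
"        y_pred = model.predict(X_part)\n",
"        scores[split] = {\n",
"            'accuracy': accuracy_score(y_part, y_pred),\n",
"            'mcc': matthews_corrcoef(y_part, y_pred),\n",
"            'f1_weighted': f1_score(y_part, y_pred, average='weighted'),\n",
"        }\n",
"    return scores\n",
"\n",
"\n",
"# Synthetic, imbalanced demo data (hypothetical stand-in for X and y)\n",
"X_demo, y_demo = make_classification(n_samples=1000, weights=[0.9, 0.1], random_state=0)\n",
"X_tr_demo, X_te_demo, y_tr_demo, y_te_demo = train_test_split(\n",
"    X_demo, y_demo, test_size=0.2, random_state=0\n",
")\n",
"\n",
"demo_scores = evaluate_model(\n",
"    KNeighborsClassifier(n_neighbors=3), X_tr_demo, y_tr_demo, X_te_demo, y_te_demo\n",
")\n",
"for split, metrics in demo_scores.items():\n",
"    print(split, metrics)"
]
},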
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Support Vector Classification (SVC)\n",
"\n",
"C-Support Vector Classification.\n",
"SVC, NuSVC and LinearSVC are classes capable of performing binary and multi-class classification on a dataset.\n",
"\n",
"The implementation is based on libsvm. The fit time scales at least quadratically with the number of samples and may be impractical beyond tens of thousands of samples.\n",
"\n",
"`class sklearn.svm.SVC(*, C=1.0, kernel='rbf', degree=3, gamma='scale', coef0=0.0, shrinking=True, probability=False, tol=0.001, cache_size=200, class_weight=None, verbose=False, max_iter=-1, decision_function_shape='ovr', break_ties=False, random_state=None)`"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model performance for Training set\n",
"- Accuracy: 0.9845901639344262\n",
"- MCC: 0.0\n",
"- F1 score: 0.9769450726235194\n",
"----------------------------------\n",
"Model performance for Test set\n",
"- Accuracy: 0.9849785407725322\n",
"- MCC: 0.0\n",
"- F1 score: 0.9775246491126318\n"
]
}
],
"source": [
"svm_rbf = SVC(gamma='auto', C=1) # Define classifier (RBF kernel is the default)\n",
"svm_rbf.fit(X_train, y_train) # Train model\n",
"\n",
"# Make predictions\n",
"y_train_pred = svm_rbf.predict(X_train)\n",
"y_test_pred = svm_rbf.predict(X_test)\n",
"\n",
"# Training set performance\n",
"svm_rbf_train_accuracy = accuracy_score(y_train, y_train_pred) # Calculate Accuracy\n",
"svm_rbf_train_mcc = matthews_corrcoef(y_train, y_train_pred) # Calculate MCC\n",
"svm_rbf_train_f1 = f1_score(y_train, y_train_pred, average='weighted') # Calculate F1-score\n",
"\n",
"# Test set performance\n",
"svm_rbf_test_accuracy = accuracy_score(y_test, y_test_pred) # Calculate Accuracy\n",
"svm_rbf_test_mcc = matthews_corrcoef(y_test, y_test_pred) # Calculate MCC\n",
"svm_rbf_test_f1 = f1_score(y_test, y_test_pred, average='weighted') # Calculate F1-score\n",
"\n",
"print('Model performance for Training set')\n",
"print('- Accuracy: %s' % svm_rbf_train_accuracy)\n",
"print('- MCC: %s' % svm_rbf_train_mcc)\n",
"print('- F1 score: %s' % svm_rbf_train_f1)\n",
"print('----------------------------------')\n",
"print('Model performance for Test set')\n",
"print('- Accuracy: %s' % svm_rbf_test_accuracy)\n",
"print('- MCC: %s' % svm_rbf_test_mcc)\n",
"print('- F1 score: %s' % svm_rbf_test_f1)"
]
},
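{
"cell_type": "markdown",
"metadata": {},
"source": [
"An MCC of exactly 0.0 next to an accuracy above 98% is the pattern a model produces when it assigns every sample to the majority class. The following sketch reproduces that pattern with scikit-learn's `DummyClassifier` on synthetic labels (`X_demo` and `y_demo` are hypothetical names, and the 2% positive rate is an assumption). It illustrates the metric behaviour only; running `np.unique(y_test_pred)` on the SVC predictions above would show whether the fitted model really predicts a single class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.dummy import DummyClassifier\n",
"from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef\n",
"\n",
"# Synthetic labels: 20 positives among 1000 samples (assumed rate, illustration only)\n",
"y_demo = np.zeros(1000, dtype=int)\n",
"y_demo[:20] = 1\n",
"X_demo = np.zeros((1000, 1))  # the dummy model ignores the features\n",
"\n",
"# Baseline that always predicts the most frequent class (here: 0)\n",
"baseline = DummyClassifier(strategy='most_frequent').fit(X_demo, y_demo)\n",
"y_demo_pred = baseline.predict(X_demo)\n",
"\n",
"baseline_accuracy = accuracy_score(y_demo, y_demo_pred)\n",
"baseline_mcc = matthews_corrcoef(y_demo, y_demo_pred)\n",
"baseline_f1 = f1_score(y_demo, y_demo_pred, average='weighted', zero_division=0)\n",
"\n",
"print('- Accuracy: %s' % baseline_accuracy)\n",
"print('- MCC: %s' % baseline_mcc)\n",
"print('- F1 score: %s' % baseline_f1)"
]
},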
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Decision Tree Classifier\n",
"\n",
"A decision tree classifier.\n",
"Decision Trees (DTs) are a non-parametric supervised learning method used for classification and regression. The goal is to create a model that predicts the value of a target variable by learning simple decision rules inferred from the data features. A tree can be seen as a piecewise constant approximation.\n",
"\n",
"`class sklearn.tree.DecisionTreeClassifier(*, criterion='gini', splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, class_weight=None, ccp_alpha=0.0)`"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model performance for Training set\n",
"- Accuracy: 0.9848286140089418\n",
"- MCC: 0.12344663420900479\n",
"- F1 score: 0.9775321005492623\n",
"----------------------------------\n",
"Model performance for Test set\n",
"- Accuracy: 0.9845016690510253\n",
"- MCC: -0.002697410465010962\n",
"- F1 score: 0.9772861696142702\n"
]
}
],
"source": [
"dt = DecisionTreeClassifier(max_depth=5) # Define classifier\n",
"dt.fit(X_train, y_train) # Train model\n",
"\n",
"# Make predictions\n",
"y_train_pred = dt.predict(X_train)\n",
"y_test_pred = dt.predict(X_test)\n",
"\n",
"# Training set performance\n",
"dt_train_accuracy = accuracy_score(y_train, y_train_pred) # Calculate Accuracy\n",
"dt_train_mcc = matthews_corrcoef(y_train, y_train_pred) # Calculate MCC\n",
"dt_train_f1 = f1_score(y_train, y_train_pred, average='weighted') # Calculate F1-score\n",
"\n",
"# Test set performance\n",
"dt_test_accuracy = accuracy_score(y_test, y_test_pred) # Calculate Accuracy\n",
"dt_test_mcc = matthews_corrcoef(y_test, y_test_pred) # Calculate MCC\n",
"dt_test_f1 = f1_score(y_test, y_test_pred, average='weighted') # Calculate F1-score\n",
"\n",
"print('Model performance for Training set')\n",
"print('- Accuracy: %s' % dt_train_accuracy)\n",
"print('- MCC: %s' % dt_train_mcc)\n",
"print('- F1 score: %s' % dt_train_f1)\n",
"print('----------------------------------')\n",
"print('Model performance for Test set')\n",
"print('- Accuracy: %s' % dt_test_accuracy)\n",
"print('- MCC: %s' % dt_test_mcc)\n",
"print('- F1 score: %s' % dt_test_f1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Random Forest Classifier\n",
"\n",
"A random forest classifier.\n",
"\n",
"A random forest is a meta estimator that fits a number of decision tree classifiers on various sub-samples of the dataset and uses averaging to improve the predictive accuracy and control over-fitting. The sub-sample size is controlled with the max_samples parameter if bootstrap=True (default), otherwise the whole dataset is used to build each tree.\n",
"\n",
"`class sklearn.ensemble.RandomForestClassifier(n_estimators=100, *, criterion='gini', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features='sqrt', max_leaf_nodes=None, min_impurity_decrease=0.0, bootstrap=True, oob_score=False, n_jobs=None, random_state=None, verbose=0, warm_start=False, class_weight=None, ccp_alpha=0.0, max_samples=None)`"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model performance for Training set\n",
"- Accuracy: 0.996602086438152\n",
"- MCC: 0.8813717901811132\n",
"- F1 score: 0.9963944901128202\n",
"----------------------------------\n",
"Model performance for Test set\n",
"- Accuracy: 0.9843824511206486\n",
"- MCC: -0.003015976451863439\n",
"- F1 score: 0.9772265318304354\n"
]
}
],
"source": [
"rf = RandomForestClassifier(n_estimators=10) # Define classifier\n",
"rf.fit(X_train, y_train) # Train model\n",
"\n",
"# Make predictions\n",
"y_train_pred = rf.predict(X_train)\n",
"y_test_pred = rf.predict(X_test)\n",
"\n",
"# Training set performance\n",
"rf_train_accuracy = accuracy_score(y_train, y_train_pred) # Calculate Accuracy\n",
"rf_train_mcc = matthews_corrcoef(y_train, y_train_pred) # Calculate MCC\n",
"rf_train_f1 = f1_score(y_train, y_train_pred, average='weighted') # Calculate F1-score\n",
"\n",
"# Test set performance\n",
"rf_test_accuracy = accuracy_score(y_test, y_test_pred) # Calculate Accuracy\n",
"rf_test_mcc = matthews_corrcoef(y_test, y_test_pred) # Calculate MCC\n",
"rf_test_f1 = f1_score(y_test, y_test_pred, average='weighted') # Calculate F1-score\n",
"\n",
"print('Model performance for Training set')\n",
"print('- Accuracy: %s' % rf_train_accuracy)\n",
"print('- MCC: %s' % rf_train_mcc)\n",
"print('- F1 score: %s' % rf_train_f1)\n",
"print('----------------------------------')\n",
"print('Model performance for Test set')\n",
"print('- Accuracy: %s' % rf_test_accuracy)\n",
"print('- MCC: %s' % rf_test_mcc)\n",
"print('- F1 score: %s' % rf_test_f1)"
]
},
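{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Random Forest test results above pair an accuracy of about 98% with an MCC close to zero. That combination usually appears when one class is rare and the model predicts the majority class for nearly every sample. The sketch below reproduces the effect on synthetic labels, not on the notebook's data: `y_true` and `y_pred` are made-up arrays, and the 985/15 class split is an assumption chosen only to resemble the accuracy reported above.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.metrics import accuracy_score, confusion_matrix, matthews_corrcoef\n",
"\n",
"# Synthetic, illustrative labels: 985 negatives and 15 positives (assumed ratio).\n",
"y_true = np.array([0] * 985 + [1] * 15)\n",
"# A degenerate predictor that always outputs the majority class.\n",
"y_pred = np.zeros_like(y_true)\n",
"\n",
"acc = accuracy_score(y_true, y_pred)\n",
"mcc = matthews_corrcoef(y_true, y_pred)\n",
"cm = confusion_matrix(y_true, y_pred)\n",
"\n",
"print('Accuracy:', acc)\n",
"print('MCC:', mcc)\n",
"print('Confusion matrix (rows = true class, columns = predicted class):')\n",
"print(cm)\n",
"```\n",
"\n",
"Running `confusion_matrix(y_test, y_test_pred)` on the notebook's own predictions would show whether the fitted models behave the same way."
]
},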
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Multi-layer Perceptron classifier\n",
"\n",
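"This model optimizes the log-loss function using LBFGS or stochastic gradient descent. The cell below keeps the default `solver='adam'`, a stochastic gradient-based optimizer, and overrides `alpha` (L2 regularization strength) and `max_iter`.\n",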
"\n",
"class sklearn.neural_network.MLPClassifier(hidden_layer_sizes=(100,), activation='relu', *, solver='adam', alpha=0.0001, batch_size='auto', learning_rate='constant', learning_rate_init=0.001, power_t=0.5, max_iter=200, shuffle=True, random_state=None, tol=0.0001, verbose=False, warm_start=False, momentum=0.9, nesterovs_momentum=True, early_stopping=False, validation_fraction=0.1, beta_1=0.9, beta_2=0.999, epsilon=1e-08, n_iter_no_change=10, max_fun=15000)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model performance for Training set\n",
"- Accuracy: 0.9845901639344262\n",
"- MCC: 0.0\n",
"- F1 score: 0.9769450726235194\n",
"----------------------------------\n",
"Model performance for Test set\n",
"- Accuracy: 0.9849785407725322\n",
"- MCC: 0.0\n",
"- F1 score: 0.9775246491126318\n"
]
}
],
"source": [
"mlp = MLPClassifier(alpha=1, max_iter=1000) # Define classifier\n",
"mlp.fit(X_train, y_train) # Train model\n",
"\n",
"# Make predictions\n",
"y_train_pred = mlp.predict(X_train)\n",
"y_test_pred = mlp.predict(X_test)\n",
"\n",
"# Training set performance\n",
"mlp_train_accuracy = accuracy_score(y_train, y_train_pred) # Calculate Accuracy\n",
"mlp_train_mcc = matthews_corrcoef(y_train, y_train_pred) # Calculate MCC\n",
"mlp_train_f1 = f1_score(y_train, y_train_pred, average='weighted') # Calculate F1-score\n",
"\n",
"# Test set performance\n",
"mlp_test_accuracy = accuracy_score(y_test, y_test_pred) # Calculate Accuracy\n",
"mlp_test_mcc = matthews_corrcoef(y_test, y_test_pred) # Calculate MCC\n",
"mlp_test_f1 = f1_score(y_test, y_test_pred, average='weighted') # Calculate F1-score\n",
"\n",
"print('Model performance for Training set')\n",
"print('- Accuracy: %s' % mlp_train_accuracy)\n",
"print('- MCC: %s' % mlp_train_mcc)\n",
"print('- F1 score: %s' % mlp_train_f1)\n",
"print('----------------------------------')\n",
"print('Model performance for Test set')\n",
"print('- Accuracy: %s' % mlp_test_accuracy)\n",
"print('- MCC: %s' % mlp_test_mcc)\n",
"print('- F1 score: %s' % mlp_test_f1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Stacking Classifier\n",
"\n",
"Stack of estimators with a final classifier.\n",
"\n",
"Stacked generalization consists of stacking the outputs of the individual estimators and using a classifier to compute the final prediction. Stacking makes it possible to exploit the strength of each individual estimator by using their outputs as the input of a final estimator.\n",
"\n",
"Note that `estimators_` are fitted on the full `X`, while `final_estimator_` is trained on cross-validated predictions of the base estimators obtained with `cross_val_predict`.\n",
"\n",
"class sklearn.ensemble.StackingClassifier(estimators, final_estimator=None, *, cv=None, stack_method='auto', n_jobs=None, passthrough=False, verbose=0)"
]
},
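{
"cell_type": "markdown",
"metadata": {},
"source": [
"The note above about `estimators_` and `final_estimator_` can be illustrated with a small sketch. It runs on synthetic data from `make_classification` rather than the stroke dataset, and the names `X_demo`, `y_demo` and `demo_stack`, as well as the choice of two base estimators, are assumptions made for illustration only.\n",
"\n",
"```python\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.ensemble import StackingClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"# Synthetic data for illustration only (not the stroke dataset).\n",
"X_demo, y_demo = make_classification(n_samples=300, n_features=8, random_state=0)\n",
"\n",
"demo_stack = StackingClassifier(\n",
"    estimators=[\n",
"        ('dt', DecisionTreeClassifier(max_depth=3, random_state=0)),\n",
"        ('knn', KNeighborsClassifier(n_neighbors=5)),\n",
"    ],\n",
"    final_estimator=LogisticRegression(),\n",
"    cv=5,  # folds used to build out-of-fold predictions for the final estimator\n",
")\n",
"demo_stack.fit(X_demo, y_demo)\n",
"\n",
"# Clones of the base estimators, refitted on all of X_demo.\n",
"print([type(est).__name__ for est in demo_stack.estimators_])\n",
"# Meta-features passed to the final estimator: one row per sample.\n",
"print(demo_stack.transform(X_demo).shape)\n",
"```\n",
"\n",
"Because the base estimators are cloned and refitted inside `fit`, passing already-trained models (as the next cell does) does not reuse their earlier training."
]
},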
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model performance for Training set\n",
"- Accuracy: 0.9848286140089418\n",
"- MCC: 0.12344663420900479\n",
"- F1 score: 0.9775321005492623\n",
"----------------------------------\n",
"Model performance for Test set\n",
"- Accuracy: 0.9849785407725322\n",
"- MCC: 0.0\n",
"- F1 score: 0.9775246491126318\n"
]
}
],
"source": [
"# Base estimators trained in the cells above\n",
"estimator_list = [\n",
"    ('knn', knn),\n",
"    ('svm_rbf', svm_rbf),\n",
"    ('dt', dt),\n",
"    ('rf', rf),\n",
"    ('mlp', mlp),\n",
"]\n",
"\n",
"# Build stack model\n",
"stack_model = StackingClassifier(\n",
"    estimators=estimator_list, final_estimator=LogisticRegression()\n",
")\n",
"\n",
"# Train stacked model\n",
"stack_model.fit(X_train, y_train)\n",
"\n",
"# Make predictions\n",
"y_train_pred = stack_model.predict(X_train)\n",
"y_test_pred = stack_model.predict(X_test)\n",
"\n",
"# Training set model performance\n",
"stack_model_train_accuracy = accuracy_score(y_train, y_train_pred) # Calculate Accuracy\n",
"stack_model_train_mcc = matthews_corrcoef(y_train, y_train_pred) # Calculate MCC\n",
"stack_model_train_f1 = f1_score(y_train, y_train_pred, average='weighted') # Calculate F1-score\n",
"\n",
"# Test set model performance\n",
"stack_model_test_accuracy = accuracy_score(y_test, y_test_pred) # Calculate Accuracy\n",
"stack_model_test_mcc = matthews_corrcoef(y_test, y_test_pred) # Calculate MCC\n",
"stack_model_test_f1 = f1_score(y_test, y_test_pred, average='weighted') # Calculate F1-score\n",
"\n",
"print('Model performance for Training set')\n",
"print('- Accuracy: %s' % stack_model_train_accuracy)\n",
"print('- MCC: %s' % stack_model_train_mcc)\n",
"print('- F1 score: %s' % stack_model_train_f1)\n",
"print('----------------------------------')\n",
"print('Model performance for Test set')\n",
"print('- Accuracy: %s' % stack_model_test_accuracy)\n",
"print('- MCC: %s' % stack_model_test_mcc)\n",
"print('- F1 score: %s' % stack_model_test_f1)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# Uncomment the next line only when the model has changed, to overwrite the saved file.\n",
"# joblib.dump(stack_model, 'stroke-prediction-model.joblib')\n",
"\n",
"model = joblib.load('stroke-prediction-model.joblib') # Model File"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## PREDICTION TESTING"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"######################################################### PROGRAM STARTED #########################################################\n",
"\n",
"Array 0 = [ 1 39 1 1 1 43 39 0 1 1 1 0 0 0 0 1 1]\n",
"Model predicts NO STROKE = [0]\n",
"###################################################################################################################################\n",
"\n",
"Array 1 = [ 1 37 1 1 0 201 4 1 0 1 1 0 1 1 0 0 1]\n",
"Model predicts NO STROKE = [0]\n",
"###################################################################################################################################\n",
"\n",
"Array 2 = [ 0 71 0 1 1 6 17 1 0 1 1 1 0 1 1 1 1]\n",
"Model predicts NO STROKE = [0]\n",
"###################################################################################################################################\n",
"\n",
"Array 3 = [ 1 4 1 0 0 129 30 1 0 0 0 1 0 1 1 0 0]\n",
"Model predicts NO STROKE = [0]\n",
"###################################################################################################################################\n",
"\n",
"Array 4 = [ 1 70 1 1 1 234 26 0 0 0 0 1 0 0 0 0 1]\n",
"Model predicts NO STROKE = [0]\n",
"###################################################################################################################################\n",
"\n",
"Array 5 = [ 0 7 1 0 0 220 17 0 1 1 1 1 1 0 0 0 1]\n",
"Model predicts NO STROKE = [0]\n",
"###################################################################################################################################\n",
"\n",
"Array 6 = [ 1 52 1 0 1 139 12 0 1 0 1 0 1 0 0 0 0]\n",
"Model predicts NO STROKE = [0]\n",
"###################################################################################################################################\n",
"\n",
"Array 7 = [ 0 55 1 0 0 45 13 0 0 0 0 1 0 1 0 1 0]\n",
"Model predicts NO STROKE = [0]\n",
"###################################################################################################################################\n",
"\n",
"Array 8 = [ 1 61 1 1 0 38 12 1 1 0 1 1 1 1 1 1 1]\n",
"Model predicts NO STROKE = [0]\n",
"###################################################################################################################################\n",
"\n",
"Array 9 = [ 1 7 0 1 0 75 25 1 1 1 0 0 1 1 0 0 1]\n",
"Model predicts NO STROKE = [0]\n",
"###################################################################################################################################\n",
"\n",
"Array 10 = [ 1 50 1 0 0 65 2 0 0 0 0 0 1 0 0 1 1]\n",
"Model predicts NO STROKE = [0]\n",
"###################################################################################################################################\n",
"\n",
"Array 11 = [ 0 6 1 1 1 243 4 0 1 0 0 0 0 0 0 1 1]\n",
"Model predicts NO STROKE = [0]\n",
"###################################################################################################################################\n",
"\n",
"Array 12 = [ 0 50 1 0 0 191 31 0 0 0 0 0 0 1 1 0 1]\n",
"Model predicts NO STROKE = [0]\n",
"###################################################################################################################################\n",
"\n",
"Array 13 = [ 1 1 1 1 0 205 36 1 0 0 0 0 0 1 0 0 0]\n",
"Model predicts NO STROKE = [0]\n",
"###################################################################################################################################\n",
"\n",
"Array 14 = [ 1 40 1 1 0 229 1 0 1 0 1 0 1 1 0 1 0]\n",
"Model predicts NO STROKE = [0]\n",
"###################################################################################################################################\n",
"\n",
"Array 15 = [ 0 75 1 0 1 159 34 0 0 0 1 0 1 0 0 0 0]\n",
"Model predicts NO STROKE = [0]\n",
"###################################################################################################################################\n",
"\n",
"Array 16 = [ 0 22 1 1 0 6 40 0 0 0 1 1 0 1 1 0 0]\n",
"Model predicts NO STROKE = [0]\n",
"###################################################################################################################################\n",
"\n",
"Array 17 = [ 1 77 1 0 0 228 18 0 1 0 1 1 1 0 0 0 0]\n",
"Model predicts STROKE = [1]\n",
"######################################################### PROGRAM FINISHED #########################################################\n"
]
}
],
"source": [
"count = 0\n",
"predictionOutcome = [0]\n",
"\n",
"print(\"######################################################### PROGRAM STARTED #########################################################\")\n",
"\n",
"# Keep drawing random 17-value feature vectors until the model predicts a stroke\n",
"while predictionOutcome[0] == 0:\n",
"    arr1 = np.random.randint(2, size=1)    # 1 binary value\n",
"    arr2 = np.random.randint(83, size=1)   # 1 integer in [0, 83)\n",
"    arr3 = np.random.randint(2, size=3)    # 3 binary values\n",
"    arr4 = np.random.randint(272, size=1)  # 1 integer in [0, 272)\n",
"    arr5 = np.random.randint(49, size=1)   # 1 integer in [0, 49)\n",
"    arr6 = np.random.randint(2, size=10)   # 10 binary values\n",
"\n",
"    array = np.concatenate((arr1, arr2, arr3, arr4, arr5, arr6), axis=0)\n",
"\n",
"    print(\"\\nArray %d =\" % count, array)\n",
"\n",
"    predictionOutcome = model.predict([array])\n",
"\n",
"    if predictionOutcome[0] == 0:\n",
"        print(\"Model predicts NO STROKE = \", predictionOutcome)\n",
"        print(\"###################################################################################################################################\")\n",
"    else:\n",
"        print(\"Model predicts STROKE = \", predictionOutcome)\n",
"\n",
"    count += 1\n",
"\n",
"# The loop ends once a stroke has been predicted\n",
"print(\"######################################################### PROGRAM FINISHED #########################################################\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.13 ('StrokePredictionModel')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "6d6bab66b583e7661b89cead2220317a23c391a40fb8c52f2c1bcd3c04f3fbda"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}