{"cells":[{"attachments":{},"cell_type":"markdown","metadata":{"_cell_guid":"b5d748ec-ba93-440b-9020-647f44fe6e0f","_uuid":"aa2ef3305175ad6f2f289d18aa9d1e2e570bccd4"},"source":["# **Predictors of mental health illness**\n","\n","`Last update: 16/12/2022`\n","\n","The process is as follows:\n","1. [Library and data loading](#Library_and_data_loading)\n","2. [Data cleaning](#Data_cleaning)\n","3. [Encoding data](#Encoding_data)\n","4. [Covariance Matrix. Variability comparison between categories of variables](#Covariance_Matrix)\n","5. [Some charts to see data relationship](#Some_charts_to_see_data_relationship)\n","6. [Scaling and fitting](#Scaling_and_fitting)\n","7. [Tuning](#Tuning)\n","8. [Evaluating models](#Evaluating_models) \n","    1. [Logistic Regression](#Logistic_regressio)\n","    2. [KNeighbors Classifier](#KNeighborsClassifier)\n","    3. [Decision Tree Classifier](#Decision_Tree_classifier)\n","    4. [Random Forests](#Random_Forests)\n","    5. [Bagging](#Bagging)\n","    6. [Boosting](#Boosting)\n","    7. [Stacking](#Stacking)\n","9. [Predicting with Neural Network](#Predicting_with_Neural_Network)\n","10. [Success method plot](#Success_method_plot)\n","11. [Creating predictions on test set](#Creating_predictions_on_test_set)\n","12. [Submission](#Submission)\n","13. [Conclusions](#Conclusions)"]},{"cell_type":"markdown","metadata":{"_cell_guid":"dc7cd7b3-32e0-4284-b669-87d4fb6dbaf8","_uuid":"dfeb8a6c8cc31996e69537a9a25102c42ccf3e6d"},"source":["\n","## **1. 
Library and data loading** ##"]},{"cell_type":"code","execution_count":90,"metadata":{"_cell_guid":"507667c6-d01e-44e4-b8d7-a2d61d3eb02a","_uuid":"d59a9a91a5da35354233aaf9fc1f0dd66686349b","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["(1259, 27)\n"," Age\n","count 1.259000e+03\n","mean 7.942815e+07\n","std 2.818299e+09\n","min -1.726000e+03\n","25% 2.700000e+01\n","50% 3.100000e+01\n","75% 3.600000e+01\n","max 1.000000e+11\n","\n","RangeIndex: 1259 entries, 0 to 1258\n","Data columns (total 27 columns):\n"," # Column Non-Null Count Dtype \n","--- ------ -------------- ----- \n"," 0 Timestamp 1259 non-null object\n"," 1 Age 1259 non-null int64 \n"," 2 Gender 1259 non-null object\n"," 3 Country 1259 non-null object\n"," 4 state 744 non-null object\n"," 5 self_employed 1241 non-null object\n"," 6 family_history 1259 non-null object\n"," 7 treatment 1259 non-null object\n"," 8 work_interfere 995 non-null object\n"," 9 no_employees 1259 non-null object\n"," 10 remote_work 1259 non-null object\n"," 11 tech_company 1259 non-null object\n"," 12 benefits 1259 non-null object\n"," 13 care_options 1259 non-null object\n"," 14 wellness_program 1259 non-null object\n"," 15 seek_help 1259 non-null object\n"," 16 anonymity 1259 non-null object\n"," 17 leave 1259 non-null object\n"," 18 mental_health_consequence 1259 non-null object\n"," 19 phys_health_consequence 1259 non-null object\n"," 20 coworkers 1259 non-null object\n"," 21 supervisor 1259 non-null object\n"," 22 mental_health_interview 1259 non-null object\n"," 23 phys_health_interview 1259 non-null object\n"," 24 mental_vs_physical 1259 non-null object\n"," 25 obs_consequence 1259 non-null object\n"," 26 comments 164 non-null object\n","dtypes: int64(1), object(26)\n","memory usage: 265.7+ KB\n","None\n"]}],"source":["import numpy as np # linear algebra\n","import pandas as pd # data processing, CSV file I/O (e.g. 
pd.read_csv)\n","import matplotlib.pyplot as plt\n","import seaborn as sns\n","\n","from scipy import stats\n","from scipy.stats import randint\n","\n","# prep\n","from sklearn.model_selection import train_test_split\n","from sklearn import preprocessing\n","from sklearn.datasets import make_classification\n","from sklearn.preprocessing import binarize, LabelEncoder, MinMaxScaler\n","\n","# models\n","from sklearn.linear_model import LogisticRegression\n","from sklearn.tree import DecisionTreeClassifier\n","from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier\n","\n","# Validation libraries\n","from sklearn import metrics\n","from sklearn.metrics import accuracy_score, mean_squared_error, precision_recall_curve\n","from sklearn.model_selection import cross_val_score\n","\n","# Neural network\n","from sklearn.neural_network import MLPClassifier\n","from sklearn.model_selection import RandomizedSearchCV\n","\n","# Bagging\n","from sklearn.ensemble import BaggingClassifier, AdaBoostClassifier\n","from sklearn.neighbors import KNeighborsClassifier\n","\n","# Naive Bayes\n","from sklearn.naive_bayes import GaussianNB\n","\n","# Stacking\n","from mlxtend.classifier import StackingClassifier\n","\n","\n","# Any results you write to the current directory are saved as output.\n","\n","# Read the survey CSV from a file path (forward slashes work on every OS)\n","train_df = pd.read_csv('raw_datasets/OSMI Mental Health in Tech Survey 2014.csv')\n","\n","\n","# Pandas: how many rows and columns are there?\n","print(train_df.shape)\n","\n","# Pandas: what is the distribution of the data?\n","print(train_df.describe())\n","\n","# Pandas: what types of data do I have?\n","print(train_df.info())\n"]},{"cell_type":"markdown","metadata":{"_cell_guid":"9a8c2272-ba6c-477a-9710-b82d57a1804f","_uuid":"e4eef2cb6628af4e719fdbd434ccbacc1846e487"},"source":["\n","## **2. 
Data cleaning** ##"]},{"cell_type":"code","execution_count":91,"metadata":{"_cell_guid":"a8e060e1-18fa-4214-a6fb-f924f74af108","_kg_hide-input":true,"_kg_hide-output":true,"_uuid":"6088e9b05d062fbed2c2272924e9d8aa0e23e5b7","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":[" Total Percent\n","comments 1095 0.869738\n","state 515 0.409055\n","work_interfere 264 0.209690\n","self_employed 18 0.014297\n","seek_help 0 0.000000\n","obs_consequence 0 0.000000\n","mental_vs_physical 0 0.000000\n","phys_health_interview 0 0.000000\n","mental_health_interview 0 0.000000\n","supervisor 0 0.000000\n","coworkers 0 0.000000\n","phys_health_consequence 0 0.000000\n","mental_health_consequence 0 0.000000\n","leave 0 0.000000\n","anonymity 0 0.000000\n","Timestamp 0 0.000000\n","wellness_program 0 0.000000\n","Age 0 0.000000\n","benefits 0 0.000000\n","tech_company 0 0.000000\n","remote_work 0 0.000000\n","no_employees 0 0.000000\n","treatment 0 0.000000\n","family_history 0 0.000000\n","Country 0 0.000000\n","Gender 0 0.000000\n","care_options 0 0.000000\n"]}],"source":["#missing data\n","total = train_df.isnull().sum().sort_values(ascending=False)\n","percent = (train_df.isnull().sum()/train_df.isnull().count()).sort_values(ascending=False)\n","missing_data = pd.concat([total, percent], axis=1, keys=['Total', 'Percent'])\n","missing_data.head(20)\n","print(missing_data)\n"]},{"cell_type":"code","execution_count":92,"metadata":{"_cell_guid":"d21825da-92d6-48e2-9ab9-c63fbbbbd2b7","_uuid":"50bdd9973655b16c61711d519645df6afb2ca214","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"data":{"text/html":["
<table>\n",
"<thead><tr><th></th><th>Age</th><th>Gender</th><th>Country</th><th>self_employed</th><th>family_history</th><th>treatment</th><th>work_interfere</th><th>no_employees</th><th>remote_work</th><th>tech_company</th><th>...</th><th>anonymity</th><th>leave</th><th>mental_health_consequence</th><th>phys_health_consequence</th><th>coworkers</th><th>supervisor</th><th>mental_health_interview</th><th>phys_health_interview</th><th>mental_vs_physical</th><th>obs_consequence</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>0</th><td>37</td><td>Female</td><td>United States</td><td>NaN</td><td>No</td><td>Yes</td><td>Often</td><td>6-25</td><td>No</td><td>Yes</td><td>...</td><td>Yes</td><td>Somewhat easy</td><td>No</td><td>No</td><td>Some of them</td><td>Yes</td><td>No</td><td>Maybe</td><td>Yes</td><td>No</td></tr>\n",
"<tr><th>1</th><td>44</td><td>M</td><td>United States</td><td>NaN</td><td>No</td><td>No</td><td>Rarely</td><td>More than 1000</td><td>No</td><td>No</td><td>...</td><td>Don't know</td><td>Don't know</td><td>Maybe</td><td>No</td><td>No</td><td>No</td><td>No</td><td>No</td><td>Don't know</td><td>No</td></tr>\n",
"<tr><th>2</th><td>32</td><td>Male</td><td>Canada</td><td>NaN</td><td>No</td><td>No</td><td>Rarely</td><td>6-25</td><td>No</td><td>Yes</td><td>...</td><td>Don't know</td><td>Somewhat difficult</td><td>No</td><td>No</td><td>Yes</td><td>Yes</td><td>Yes</td><td>Yes</td><td>No</td><td>No</td></tr>\n",
"<tr><th>3</th><td>31</td><td>Male</td><td>United Kingdom</td><td>NaN</td><td>Yes</td><td>Yes</td><td>Often</td><td>26-100</td><td>No</td><td>Yes</td><td>...</td><td>No</td><td>Somewhat difficult</td><td>Yes</td><td>Yes</td><td>Some of them</td><td>No</td><td>Maybe</td><td>Maybe</td><td>No</td><td>Yes</td></tr>\n",
"<tr><th>4</th><td>31</td><td>Male</td><td>United States</td><td>NaN</td><td>No</td><td>No</td><td>Never</td><td>100-500</td><td>Yes</td><td>Yes</td><td>...</td><td>Don't know</td><td>Don't know</td><td>No</td><td>No</td><td>Some of them</td><td>Yes</td><td>Yes</td><td>Yes</td><td>Don't know</td><td>No</td></tr>\n",
"</tbody>\n",
"</table>\n",
"<p>5 rows × 24 columns</p>"],"text/plain":[" Age Gender Country self_employed family_history treatment \\\n","0 37 Female United States NaN No Yes \n","1 44 M United States NaN No No \n","2 32 Male Canada NaN No No \n","3 31 Male United Kingdom NaN Yes Yes \n","4 31 Male United States NaN No No \n","\n"," work_interfere no_employees remote_work tech_company ... anonymity \\\n","0 Often 6-25 No Yes ... Yes \n","1 Rarely More than 1000 No No ... Don't know \n","2 Rarely 6-25 No Yes ... Don't know \n","3 Often 26-100 No Yes ... No \n","4 Never 100-500 Yes Yes ... Don't know \n","\n"," leave mental_health_consequence phys_health_consequence \\\n","0 Somewhat easy No No \n","1 Don't know Maybe No \n","2 Somewhat difficult No No \n","3 Somewhat difficult Yes Yes \n","4 Don't know No No \n","\n"," coworkers supervisor mental_health_interview phys_health_interview \\\n","0 Some of them Yes No Maybe \n","1 No No No No \n","2 Yes Yes Yes Yes \n","3 Some of them No Maybe Maybe \n","4 Some of them Yes Yes Yes \n","\n"," mental_vs_physical obs_consequence \n","0 Yes No \n","1 Don't know No \n","2 No No \n","3 No Yes \n","4 Don't know No \n","\n","[5 rows x 24 columns]"]},"execution_count":92,"metadata":{},"output_type":"execute_result"}],"source":["# Dealing with missing data\n","# Let’s drop the variables \"Timestamp\", “comments” and “state” just to make our lives easier.\n","train_df = train_df.drop(['comments'], axis= 1)\n","train_df = train_df.drop(['state'], axis= 1)\n","train_df = train_df.drop(['Timestamp'], axis= 1)\n","\n","train_df.isnull().sum().max() # largest remaining per-column NaN count (not displayed; self_employed and work_interfere still contain NaN)\n","train_df.head(5)"]},{"cell_type":"markdown","metadata":{"_cell_guid":"871f195e-d25d-426a-84f5-997b017b892c","_uuid":"7a54d86c30dab2e270e7374455178f8abcc15a7c"},"source":["**Cleaning 
NaN**"]},{"cell_type":"code","execution_count":93,"metadata":{"_cell_guid":"2bf70e26-afd9-4766-95ae-68d8ff9b21db","_uuid":"08b863f7d666b7e68cef056ba2a0ae033cbdeb9d","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"data":{"text/html":["
<table>\n",
"<thead><tr><th></th><th>Age</th><th>Gender</th><th>Country</th><th>self_employed</th><th>family_history</th><th>treatment</th><th>work_interfere</th><th>no_employees</th><th>remote_work</th><th>tech_company</th><th>...</th><th>anonymity</th><th>leave</th><th>mental_health_consequence</th><th>phys_health_consequence</th><th>coworkers</th><th>supervisor</th><th>mental_health_interview</th><th>phys_health_interview</th><th>mental_vs_physical</th><th>obs_consequence</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>0</th><td>37</td><td>Female</td><td>United States</td><td>NaN</td><td>No</td><td>Yes</td><td>Often</td><td>6-25</td><td>No</td><td>Yes</td><td>...</td><td>Yes</td><td>Somewhat easy</td><td>No</td><td>No</td><td>Some of them</td><td>Yes</td><td>No</td><td>Maybe</td><td>Yes</td><td>No</td></tr>\n",
"<tr><th>1</th><td>44</td><td>M</td><td>United States</td><td>NaN</td><td>No</td><td>No</td><td>Rarely</td><td>More than 1000</td><td>No</td><td>No</td><td>...</td><td>Don't know</td><td>Don't know</td><td>Maybe</td><td>No</td><td>No</td><td>No</td><td>No</td><td>No</td><td>Don't know</td><td>No</td></tr>\n",
"<tr><th>2</th><td>32</td><td>Male</td><td>Canada</td><td>NaN</td><td>No</td><td>No</td><td>Rarely</td><td>6-25</td><td>No</td><td>Yes</td><td>...</td><td>Don't know</td><td>Somewhat difficult</td><td>No</td><td>No</td><td>Yes</td><td>Yes</td><td>Yes</td><td>Yes</td><td>No</td><td>No</td></tr>\n",
"<tr><th>3</th><td>31</td><td>Male</td><td>United Kingdom</td><td>NaN</td><td>Yes</td><td>Yes</td><td>Often</td><td>26-100</td><td>No</td><td>Yes</td><td>...</td><td>No</td><td>Somewhat difficult</td><td>Yes</td><td>Yes</td><td>Some of them</td><td>No</td><td>Maybe</td><td>Maybe</td><td>No</td><td>Yes</td></tr>\n",
"<tr><th>4</th><td>31</td><td>Male</td><td>United States</td><td>NaN</td><td>No</td><td>No</td><td>Never</td><td>100-500</td><td>Yes</td><td>Yes</td><td>...</td><td>Don't know</td><td>Don't know</td><td>No</td><td>No</td><td>Some of them</td><td>Yes</td><td>Yes</td><td>Yes</td><td>Don't know</td><td>No</td></tr>\n",
"</tbody>\n",
"</table>\n",
"<p>5 rows × 24 columns</p>"],"text/plain":[" Age Gender Country self_employed family_history treatment \\\n","0 37 Female United States NaN No Yes \n","1 44 M United States NaN No No \n","2 32 Male Canada NaN No No \n","3 31 Male United Kingdom NaN Yes Yes \n","4 31 Male United States NaN No No \n","\n"," work_interfere no_employees remote_work tech_company ... anonymity \\\n","0 Often 6-25 No Yes ... Yes \n","1 Rarely More than 1000 No No ... Don't know \n","2 Rarely 6-25 No Yes ... Don't know \n","3 Often 26-100 No Yes ... No \n","4 Never 100-500 Yes Yes ... Don't know \n","\n"," leave mental_health_consequence phys_health_consequence \\\n","0 Somewhat easy No No \n","1 Don't know Maybe No \n","2 Somewhat difficult No No \n","3 Somewhat difficult Yes Yes \n","4 Don't know No No \n","\n"," coworkers supervisor mental_health_interview phys_health_interview \\\n","0 Some of them Yes No Maybe \n","1 No No No No \n","2 Yes Yes Yes Yes \n","3 Some of them No Maybe Maybe \n","4 Some of them Yes Yes Yes \n","\n"," mental_vs_physical obs_consequence \n","0 Yes No \n","1 Don't know No \n","2 No No \n","3 No Yes \n","4 Don't know No \n","\n","[5 rows x 24 columns]"]},"execution_count":93,"metadata":{},"output_type":"execute_result"}],"source":["# Assign default values for each data type\n","defaultInt = 0\n","defaultString = 'NaN'\n","defaultFloat = 0.0\n","\n","# Create lists by data type\n","intFeatures = ['Age']\n","stringFeatures = ['Gender', 'Country', 'self_employed', 'family_history', 'treatment', 'work_interfere',\n","                  'no_employees', 'remote_work', 'tech_company', 'anonymity', 'leave', 'mental_health_consequence',\n","                  'phys_health_consequence', 'coworkers', 'supervisor', 'mental_health_interview', 'phys_health_interview',\n","                  'mental_vs_physical', 'obs_consequence', 'benefits', 'care_options', 'wellness_program',\n","                  'seek_help']\n","floatFeatures = []\n","\n","# Clean the NaNs\n","for feature in train_df:\n","    if feature in intFeatures:\n","        train_df[feature] = 
train_df[feature].fillna(defaultInt)\n","    elif feature in stringFeatures:\n","        train_df[feature] = train_df[feature].fillna(defaultString)\n","    elif feature in floatFeatures:\n","        train_df[feature] = train_df[feature].fillna(defaultFloat)\n","    else:\n","        print('Error: Feature %s not recognized.' % feature)\n","train_df.head(5) "]},{"cell_type":"code","execution_count":94,"metadata":{"_cell_guid":"4ec04bf5-c88f-4839-a79b-2d3ad6962599","_uuid":"4507b7c43c9663f90a9cf915dd73850e4e0f17ab","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["['female' 'male' 'trans']\n"]}],"source":["# Clean 'Gender'\n","# Lower-case every element of the column\n","gender = train_df['Gender'].str.lower()\n","#print(gender)\n","\n","# Select the unique (lower-cased) elements, for inspection only\n","gender = gender.unique()\n","\n","# Define the gender groups\n","male_str = [\"male\", \"m\", \"male-ish\", \"maile\", \"mal\", \"male (cis)\", \"make\", \"male \", \"man\",\"msle\", \"mail\", \"malr\",\"cis man\", \"Cis Male\", \"cis male\"]\n","trans_str = [\"trans-female\", \"something kinda male?\", \"queer/she/they\", \"non-binary\",\"nah\", \"all\", \"enby\", \"fluid\", \"genderqueer\", \"androgyne\", \"agender\", \"male leaning androgynous\", \"guy (-ish) ^_^\", \"trans woman\", \"neuter\", \"female (trans)\", \"queer\", \"ostensibly male, unsure what that really means\"] \n","female_str = [\"cis female\", \"f\", \"female\", \"woman\", \"femake\", \"female \",\"cis-female/femme\", \"female (cis)\", \"femail\"]\n","\n","# Map every raw Gender value onto one of the three groups\n","for (row, col) in train_df.iterrows():\n","\n","    if str.lower(col.Gender) in male_str:\n","        train_df['Gender'].replace(to_replace=col.Gender, value='male', inplace=True)\n","\n","    if str.lower(col.Gender) in female_str:\n","        train_df['Gender'].replace(to_replace=col.Gender, value='female', inplace=True)\n","\n","    if str.lower(col.Gender) in trans_str:\n","        train_df['Gender'].replace(to_replace=col.Gender, value='trans', 
inplace=True)\n","\n","# Drop rows whose Gender entry is not a usable answer\n","stk_list = ['A little about you', 'p']\n","train_df = train_df[~train_df['Gender'].isin(stk_list)]\n","\n","print(train_df['Gender'].unique())"]},{"cell_type":"code","execution_count":95,"metadata":{"_cell_guid":"0e211fa8-a69e-42c1-adda-375d82916db4","_uuid":"818c6a88caf07379d5012f15919d6eb46ae6d98c","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[],"source":["# Fill missing ages with the median\n","train_df['Age'].fillna(train_df['Age'].median(), inplace = True)\n","\n","# Replace ages < 18 or > 120 with the median\n","s = pd.Series(train_df['Age'])\n","s[s<18] = train_df['Age'].median()\n","train_df['Age'] = s\n","s = pd.Series(train_df['Age'])\n","s[s>120] = train_df['Age'].median()\n","train_df['Age'] = s\n","\n","# Age ranges\n","train_df['age_range'] = pd.cut(train_df['Age'], [0,20,30,65,100], labels=[\"0-20\", \"21-30\", \"31-65\", \"66-100\"], include_lowest=True)\n","\n"]},{"cell_type":"code","execution_count":96,"metadata":{"_cell_guid":"ffa02f53-30cd-4f7e-bcd9-f74b3c0826c4","_uuid":"55b6621c02ca8d8603d815da74b15ad621b46629","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["['No' 'Yes']\n"]}],"source":["# Only about 1.4% of self_employed values are missing, so let's change NaN to NOT self_employed\n","# Replace the \"NaN\" string set from defaultString\n","train_df['self_employed'] = train_df['self_employed'].replace([defaultString], 'No')\n","print(train_df['self_employed'].unique())"]},{"cell_type":"code","execution_count":97,"metadata":{"_cell_guid":"d910581d-4b92-4919-ab5a-85b018158745","_uuid":"b934d97a22a46c36e107d2e0fdeb9f0ee74dc532","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["['Often' 'Rarely' 'Never' 'Sometimes' \"Don't know\"]\n"]}],"source":["# About 21% of work_interfere values are missing, so let's change NaN to \"Don't know\"\n","#Replace 
\"NaN\" string from defaultString\n","\n","train_df['work_interfere'] = train_df['work_interfere'].replace([defaultString], 'Don\\'t know' )\n","print(train_df['work_interfere'].unique())"]},{"cell_type":"markdown","metadata":{"_cell_guid":"b77019bf-b900-4c2b-9676-792860d89825","_uuid":"01bdb3e74278bd61fdbdda9b1f9a3085c873999b"},"source":["\n","## **3. Encoding data**"]},{"cell_type":"code","execution_count":98,"metadata":{"_cell_guid":"05a6455e-187d-4db2-bb71-4fbf2ba6d967","_uuid":"af23578fce1520cbe5fd970b13c7ddd71f442018","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["label_Age [18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 53, 54, 55, 56, 57, 58, 60, 61, 62, 65, 72]\n","label_Gender ['female', 'male', 'trans']\n","label_Country ['Australia', 'Austria', 'Belgium', 'Bosnia and Herzegovina', 'Brazil', 'Bulgaria', 'Canada', 'China', 'Colombia', 'Costa Rica', 'Croatia', 'Czech Republic', 'Denmark', 'Finland', 'France', 'Georgia', 'Germany', 'Greece', 'Hungary', 'India', 'Ireland', 'Israel', 'Italy', 'Japan', 'Latvia', 'Mexico', 'Moldova', 'Netherlands', 'New Zealand', 'Nigeria', 'Norway', 'Philippines', 'Poland', 'Portugal', 'Romania', 'Russia', 'Singapore', 'Slovenia', 'South Africa', 'Spain', 'Sweden', 'Switzerland', 'Thailand', 'United Kingdom', 'United States', 'Uruguay', 'Zimbabwe']\n","label_self_employed ['No', 'Yes']\n","label_family_history ['No', 'Yes']\n","label_treatment ['No', 'Yes']\n","label_work_interfere [\"Don't know\", 'Never', 'Often', 'Rarely', 'Sometimes']\n","label_no_employees ['1-5', '100-500', '26-100', '500-1000', '6-25', 'More than 1000']\n","label_remote_work ['No', 'Yes']\n","label_tech_company ['No', 'Yes']\n","label_benefits [\"Don't know\", 'No', 'Yes']\n","label_care_options ['No', 'Not sure', 'Yes']\n","label_wellness_program [\"Don't know\", 'No', 
'Yes']\n","label_seek_help [\"Don't know\", 'No', 'Yes']\n","label_anonymity [\"Don't know\", 'No', 'Yes']\n","label_leave [\"Don't know\", 'Somewhat difficult', 'Somewhat easy', 'Very difficult', 'Very easy']\n","label_mental_health_consequence ['Maybe', 'No', 'Yes']\n","label_phys_health_consequence ['Maybe', 'No', 'Yes']\n","label_coworkers ['No', 'Some of them', 'Yes']\n","label_supervisor ['No', 'Some of them', 'Yes']\n","label_mental_health_interview ['Maybe', 'No', 'Yes']\n","label_phys_health_interview ['Maybe', 'No', 'Yes']\n","label_mental_vs_physical [\"Don't know\", 'No', 'Yes']\n","label_obs_consequence ['No', 'Yes']\n","label_age_range ['0-20', '21-30', '31-65', '66-100']\n"]},{"data":{"text/html":["
"]},"metadata":{},"output_type":"display_data"}],"source":["# Correlation matrix\n","corrmat = train_df.corr()\n","f, ax = plt.subplots(figsize=(12, 9))\n","sns.heatmap(corrmat, vmax=.8, square=True);\n","plt.show()\n","\n","# Treatment correlation matrix\n","k = 10 # number of variables for heatmap\n","cols = corrmat.nlargest(k, 'treatment')['treatment'].index\n","cm = np.corrcoef(train_df[cols].values.T)\n","sns.set(font_scale=1.25)\n","hm = sns.heatmap(cm, cbar=True, annot=True, square=True, fmt='.2f', annot_kws={'size': 10}, yticklabels=cols.values, xticklabels=cols.values)\n","plt.show()\n","\n"]},{"cell_type":"markdown","metadata":{"_cell_guid":"f71d0c3d-8554-4478-b1d5-7aa461fcb7bb","_uuid":"22adfaf19848e1103804bf8737cb5656fcac5388"},"source":["\n","## **5. Some charts to see data relationship** "]},{"cell_type":"markdown","metadata":{"_cell_guid":"2af9fb8b-79ac-4444-95b0-210c704afef9","_uuid":"39488f9878f2ab937c4ce888eb82eba133fb175c"},"source":["Distribution and density by Age"]},{"cell_type":"code","execution_count":101,"metadata":{"_cell_guid":"89387b66-3a36-4b3f-98a4-254253452056","_uuid":"e7288897925d3f340ba4c2e73d2ef61c6ca28c4f","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"name":"stderr","output_type":"stream","text":["C:\\Users\\puran\\AppData\\Local\\Temp\\ipykernel_20412\\1394260443.py:3: UserWarning: \n","\n","`distplot` is a deprecated function and will be removed in seaborn v0.14.0.\n","\n","Please adapt your code to use either `displot` (a figure-level function with\n","similar flexibility) or `histplot` (an axes-level function for histograms).\n","\n","For a guide to updating your code to use the new functions, please see\n","https://gist.github.com/mwaskom/de44147ed2974457ad6372750bbe5751\n","\n"," sns.distplot(train_df[\"Age\"], bins=24)\n"]},{"data":{"text/plain":["Text(0.5, 0, 'Age')"]},"execution_count":101,"metadata":{},"output_type":"execute_result"},{"data":{"image/png":"","text/plain":["
"]},"metadata":{},"output_type":"display_data"}],"source":["# Distribution and density by Age\n","plt.figure(figsize=(12,8))\n","sns.distplot(train_df[\"Age\"], bins=24)\n","plt.title(\"Distribution and density by Age\")\n","plt.xlabel(\"Age\")\n","\n"]},{"cell_type":"markdown","metadata":{"_cell_guid":"a8880fbd-1545-48dd-a850-1ccc6214bcdc","_uuid":"9c24e2674ba7fe8cfecc0b911e6773512b7099c5"},"source":["Separate by treatment"]},{"cell_type":"code","execution_count":102,"metadata":{"_cell_guid":"b464a345-00b2-47a2-9d3e-c7074cfaf790","_uuid":"2cc9d4abe3ca092a5bf9b7a9cd9f0b4d443c0471","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"name":"stderr","output_type":"stream","text":["y:\\Anaconda\\envs\\StrokePredictionModel\\lib\\site-packages\\seaborn\\axisgrid.py:848: UserWarning: \n","\n","`distplot` is a deprecated function and will be removed in seaborn v0.14.0.\n","\n","Please adapt your code to use either `displot` (a figure-level function with\n","similar flexibility) or `histplot` (an axes-level function for histograms).\n","\n","For a guide to updating your code to use the new functions, please see\n","https://gist.github.com/mwaskom/de44147ed2974457ad6372750bbe5751\n","\n"," func(*plot_args, **plot_kwargs)\n","y:\\Anaconda\\envs\\StrokePredictionModel\\lib\\site-packages\\seaborn\\axisgrid.py:848: UserWarning: \n","\n","`distplot` is a deprecated function and will be removed in seaborn v0.14.0.\n","\n","Please adapt your code to use either `displot` (a figure-level function with\n","similar flexibility) or `histplot` (an axes-level function for histograms).\n","\n","For a guide to updating your code to use the new functions, please see\n","https://gist.github.com/mwaskom/de44147ed2974457ad6372750bbe5751\n","\n"," func(*plot_args, **plot_kwargs)\n"]},{"data":{"image/png":"","text/plain":["
"]},"metadata":{},"output_type":"display_data"}],"source":["# Separate by treatment or not\n","\n","g = sns.FacetGrid(train_df, col='treatment')\n","g = g.map(sns.distplot, \"Age\")"]},{"cell_type":"markdown","metadata":{"_cell_guid":"20b446a7-fe87-478d-9328-de7ea7afea5b","_uuid":"1bfd6affb0501138a98d4dbc9c70ad491b35962c"},"source":["How many people have been treated?"]},{"cell_type":"code","execution_count":103,"metadata":{"_cell_guid":"81f1eb2a-c3ac-433d-a674-24aa6a33d287","_uuid":"ac35d322a601a3ff9fdde64fa5e33b28dbdd10e5","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"data":{"text/plain":["Text(0.5, 1.0, 'Total Distribution by treated or not')"]},"execution_count":103,"metadata":{},"output_type":"execute_result"},{"data":{"image/png":"","text/plain":["
"]},"metadata":{},"output_type":"display_data"}],"source":["# Let's see how many people have been treated\n","plt.figure(figsize=(12,8))\n","labels = labelDict['label_treatment'] # tick labels must come from the plotted column\n","g = sns.countplot(x=\"treatment\", data=train_df)\n","g.set_xticklabels(labels)\n","\n","plt.title('Total Distribution by treated or not')"]},{"cell_type":"markdown","metadata":{"_cell_guid":"d7a9a45c-0df3-48a4-9ab2-e5aaae9abe7f","_uuid":"a781bf88a14f5bb937ede09d81365f16bdcbbe3f"},"source":["Draw a nested barplot to show treatment probabilities by age range and gender"]},{"cell_type":"code","execution_count":104,"metadata":{"_cell_guid":"210c02cf-9296-4373-8725-bd13b819a9e4","_uuid":"2b931335e19a15bf897fa8a6eabcd26cb263503a","collapsed":true,"jupyter":{"outputs_hidden":true},"scrolled":true,"trusted":true},"outputs":[{"data":{"image/png":"","text/plain":["
"]},"metadata":{},"output_type":"display_data"}],"source":["o = labelDict['label_age_range']\n","\n","g = sns.barplot(x=\"age_range\", y=\"treatment\", hue=\"Gender\", data=train_df)\n","g.set_xticklabels(o)\n","\n","plt.title('Probability of mental health condition')\n","plt.ylabel('Probability x 100')\n","plt.xlabel('Age')\n","\n","plt.show()"]},{"cell_type":"markdown","metadata":{"_cell_guid":"6517af95-3b02-4b50-b246-ea56f61f6113","_uuid":"4ed2e62093d643cf105829ff2ae9a9917b110996"},"source":["Barplot to show probabilities for family history"]},{"cell_type":"code","execution_count":105,"metadata":{"_cell_guid":"b93d6150-b23a-4b71-8cf8-a10b811152a8","_uuid":"605bf86a497f0b7989b57a3458088131f564e2a8","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"data":{"image/png":"","text/plain":["
"]},"metadata":{},"output_type":"display_data"}],"source":["o = labelDict['label_family_history']\n","g = sns.barplot(x=\"family_history\", y=\"treatment\", hue=\"Gender\", data=train_df)\n","g.set_xticklabels(o)\n","plt.title('Probability of mental health condition')\n","plt.ylabel('Probability x 100')\n","plt.xlabel('Family History')\n","\n","plt.show()"]},{"cell_type":"markdown","metadata":{"_cell_guid":"beecc2fe-4cd8-489c-bc0e-bb5ad633931c","_uuid":"7c7baa6c000ab81edb071071d45bac309f796d80"},"source":["Barplot to show probabilities for care options"]},{"cell_type":"code","execution_count":106,"metadata":{"_cell_guid":"c77da62e-f71f-49fb-9c13-fe2fea3d474e","_uuid":"bc61dc5b0c4c0a8204453524438006a84f6168d1","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"data":{"image/png":"","text/plain":["
"]},"metadata":{},"output_type":"display_data"}],"source":["o = labelDict['label_care_options']\n","g = sns.barplot(x=\"care_options\", y=\"treatment\", hue=\"Gender\", data=train_df)\n","g.set_xticklabels(o)\n","plt.title('Probability of mental health condition')\n","plt.ylabel('Probability x 100')\n","plt.xlabel('Care options')\n","\n","plt.show()"]},{"cell_type":"markdown","metadata":{"_cell_guid":"56de0fc1-8ee1-41b8-868e-133db7635c64","_uuid":"5f9e29712de3b44df01755c481b898a38e508f07"},"source":["Barplot to show probabilities for benefits"]},{"cell_type":"code","execution_count":107,"metadata":{"_cell_guid":"4fab65ea-f3f5-4831-9f3f-7560fb908d3e","_uuid":"25c2da49fd8c83fc2d42355c0e147f439c00a7c0","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"data":{"image/png":"","text/plain":["
"]},"metadata":{},"output_type":"display_data"}],"source":["o = labelDict['label_benefits']\n","g = sns.barplot(x=\"benefits\", y=\"treatment\", hue=\"Gender\", data=train_df)\n","g.set_xticklabels(o)\n","plt.title('Probability of mental health condition')\n","plt.ylabel('Probability x 100')\n","plt.xlabel('Benefits')\n","\n","plt.show()"]},{"cell_type":"markdown","metadata":{"_cell_guid":"bebe13ce-94d5-487c-88b8-7a634f361223","_uuid":"9cfc643a94d39e2c06ed870c306deee6a9219b60"},"source":["Barplot to show probabilities for work interfere"]},{"cell_type":"code","execution_count":108,"metadata":{"_cell_guid":"1606646f-0db7-41f9-b4bc-f8f2c982087c","_uuid":"a7f3daeded334645d4cf5bd202e3116811159481","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"data":{"image/png":"","text/plain":["
"]},"metadata":{},"output_type":"display_data"}],"source":["o = labelDict['label_work_interfere']\n","g = sns.barplot(x=\"work_interfere\", y=\"treatment\", hue=\"Gender\", data=train_df)\n","g.set_xticklabels(o)\n","plt.title('Probability of mental health condition')\n","plt.ylabel('Probability x 100')\n","plt.xlabel('Work interfere')\n","\n","plt.show()"]},{"cell_type":"markdown","metadata":{"_cell_guid":"4ea1fe2d-e6bd-434f-807c-9936f17be784","_uuid":"ebb8c1dc1fdd5dcfa5f9d1b05db3d23323b21adc"},"source":["\n","## **6. Scaling and fitting** ##\n","\n"]},{"cell_type":"markdown","metadata":{"_cell_guid":"a95cf3ba-36d4-4df5-9797-2ccf1f018f5a","_uuid":"5bdd6122fff56b5957a59210160269e7a32af869"},"source":["Feature scaling\n","We're going to scale Age, because its range is very different from that of the other features."]},{"cell_type":"code","execution_count":109,"metadata":{"_cell_guid":"6ae3cc24-d8cf-4ab2-915d-091007ff2457","_uuid":"d8dcf5e62e990fb6747f5695cbce9919f5cdec4b","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"data":{"text/html":["
"]},"metadata":{},"output_type":"display_data"}],"source":["# Build a forest and compute the feature importances\n","forest = ExtraTreesClassifier(n_estimators=250,\n","                              random_state=0)\n","\n","forest.fit(X, y)\n","importances = forest.feature_importances_\n","std = np.std([tree.feature_importances_ for tree in forest.estimators_],\n","             axis=0)\n","indices = np.argsort(importances)[::-1]\n","\n","# Order the labels like the bars (most important feature first)\n","labels = []\n","for f in range(X.shape[1]):\n","    labels.append(feature_cols[indices[f]])\n","\n","# Plot the feature importances of the forest\n","plt.figure(figsize=(12,8))\n","plt.title(\"Feature importances\")\n","plt.bar(range(X.shape[1]), importances[indices],\n","        color=\"r\", yerr=std[indices], align=\"center\")\n","plt.xticks(range(X.shape[1]), labels, rotation='vertical')\n","plt.xlim([-1, X.shape[1]])\n","plt.show()"]},{"cell_type":"markdown","metadata":{"_cell_guid":"b07c9bf0-3323-4cb3-9589-fe64b1fd854a","_uuid":"a5a4b72ee4b193749ff0f59acb689c5900d98ec1"},"source":["\n","## **7. Tuning** \n","### **Evaluating a Classification Model.** \n","This function will evaluate: \n","* **Classification accuracy:** percentage of correct predictions\n","* **Null accuracy:** accuracy that could be achieved by always predicting the most frequent class\n","* **Percentage of ones** \n","* **Percentage of zeros** \n","* **Confusion matrix:** table that describes the performance of a classification model\n"," True Positives (TP): we correctly predicted that the respondent sought treatment\n"," True Negatives (TN): we correctly predicted that the respondent did not seek treatment\n"," False Positives (FP): we incorrectly predicted that the respondent sought treatment (a \"Type I error\")\n"," Falsely predict positive\n"," False Negatives (FN): we incorrectly predicted that the respondent did not seek treatment (a \"Type II error\")\n"," Falsely predict negative\n","\n","* **False Positive Rate** \n","* **Precision of Positive value** \n","* **AUC:** the fraction of the ROC plot area that lies underneath the curve\n"," .90-1 = excellent 
(A)\n","    .80-.90 = good (B)\n","    .70-.80 = fair (C)\n","    .60-.70 = poor (D)\n","    .50-.60 = fail (F)\n","\n","The function also prints some other values that are useful during tuning.\n","More information: [http://www.ritchieng.com/machine-learning-evaluate-classification-model/](http://www.ritchieng.com/machine-learning-evaluate-classification-model/)\n"]},{"cell_type":"markdown","metadata":{"_cell_guid":"ae969199-93b6-4566-a677-69560ab1760f","_uuid":"8f672b391aa354c77032b1f0c174a59aaa75cea9"},"source":[]},{"cell_type":"code","execution_count":112,"metadata":{"_cell_guid":"0c78481e-0e93-4369-932f-d2b21f1dea0c","_uuid":"dec481c2407ad73187e7b678fbdb9812f6a636a8","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[],"source":["def evalClassModel(model, y_test, y_pred_class, plot=False):\n","    #Classification accuracy: percentage of correct predictions\n","    # calculate accuracy\n","    print('Accuracy:', metrics.accuracy_score(y_test, y_pred_class))\n","\n","    #Null accuracy: accuracy that could be achieved by always predicting the most frequent class\n","    # examine the class distribution of the testing set (using a Pandas Series method)\n","    print('Null accuracy:\\n', y_test.value_counts())\n","\n","    # calculate the percentage of ones\n","    print('Percentage of ones:', y_test.mean())\n","\n","    # calculate the percentage of zeros\n","    print('Percentage of zeros:',1 - y_test.mean())\n","\n","    #Comparing the true and predicted response values\n","    print('True:', y_test.values[0:25])\n","    print('Pred:', y_pred_class[0:25])\n","\n","    #Conclusion:\n","    #Classification accuracy is the easiest classification metric to understand\n","    #But, it does not tell you the underlying distribution of response values\n","    #And, it does not tell you what \"types\" of errors your classifier is making\n","\n","    #Confusion matrix\n","    # save confusion matrix and slice into four pieces\n","    confusion = metrics.confusion_matrix(y_test, y_pred_class)\n","    #[row, column]\n","    TP = confusion[1, 1]\n","    TN = confusion[0, 0]\n","    FP = confusion[0, 1]\n","    FN = confusion[1, 
0]\n","\n","    # visualize Confusion Matrix\n","    sns.heatmap(confusion,annot=True,fmt=\"d\")\n","    plt.title('Confusion Matrix')\n","    plt.xlabel('Predicted')\n","    plt.ylabel('Actual')\n","    plt.show()\n","\n","    #Metrics computed from a confusion matrix\n","    #Classification Accuracy: Overall, how often is the classifier correct?\n","    accuracy = metrics.accuracy_score(y_test, y_pred_class)\n","    print('Classification Accuracy:', accuracy)\n","\n","    #Classification Error: Overall, how often is the classifier incorrect?\n","    print('Classification Error:', 1 - metrics.accuracy_score(y_test, y_pred_class))\n","\n","    #False Positive Rate: When the actual value is negative, how often is the prediction incorrect?\n","    false_positive_rate = FP / float(TN + FP)\n","    print('False Positive Rate:', false_positive_rate)\n","\n","    #Precision: When a positive value is predicted, how often is the prediction correct?\n","    print('Precision:', metrics.precision_score(y_test, y_pred_class))\n","\n","\n","    # NOTE: this AUC is computed from the hard class predictions;\n","    # the probability-based AUC used for the ROC curve is computed further down\n","    print('AUC Score:', metrics.roc_auc_score(y_test, y_pred_class))\n","\n","    # calculate cross-validated AUC\n","    print('Cross-validated AUC:', cross_val_score(model, X, y, cv=10, scoring='roc_auc').mean())\n","\n","    ##########################################\n","    #Adjusting the classification threshold\n","    ##########################################\n","    # print the first 10 predicted responses\n","    # 1D array (vector) of binary values (0, 1)\n","    print('First 10 predicted responses:\\n', model.predict(X_test)[0:10])\n","\n","    # print the first 10 predicted probabilities of class membership\n","    print('First 10 predicted probabilities of class members:\\n', model.predict_proba(X_test)[0:10])\n","\n","    # the first 10 predicted probabilities for class 1 (evaluated but not printed)\n","    model.predict_proba(X_test)[0:10, 1]\n","\n","    # store the predicted probabilities for class 1\n","    y_pred_prob = 
model.predict_proba(X_test)[:, 1]\n","\n","    if plot == True:\n","        # histogram of predicted probabilities\n","        # adjust the font size\n","        plt.rcParams['font.size'] = 12\n","        # 8 bins\n","        plt.hist(y_pred_prob, bins=8)\n","\n","        # x-axis limit from 0 to 1\n","        plt.xlim(0,1)\n","        plt.title('Histogram of predicted probabilities')\n","        plt.xlabel('Predicted probability of treatment')\n","        plt.ylabel('Frequency')\n","\n","\n","    # predict treatment if the predicted probability is greater than 0.3\n","    # it will return 1 for all values above 0.3 and 0 otherwise\n","    # results are 2D so we slice out the first column\n","    y_pred_prob = y_pred_prob.reshape(-1,1)\n","    y_pred_class = binarize(y_pred_prob, threshold=0.3)[:, 0]\n","\n","    # print the first 10 predicted probabilities\n","    print('First 10 predicted probabilities:\\n', y_pred_prob[0:10])\n","\n","    ##########################################\n","    #ROC Curves and Area Under the Curve (AUC)\n","    ##########################################\n","\n","    #Question: Wouldn't it be nice if we could see how sensitivity and specificity are affected by various thresholds, without actually changing the threshold?\n","    #Answer: Plot the ROC curve!\n","\n","\n","    #AUC is the percentage of the ROC plot that is underneath the curve\n","    #Higher value = better classifier\n","    roc_auc = metrics.roc_auc_score(y_test, y_pred_prob)\n","\n","\n","\n","    # IMPORTANT: first argument is true values, second argument is predicted probabilities\n","    # we pass y_test and y_pred_prob\n","    # we do not use y_pred_class, because it will give incorrect results without generating an error\n","    # roc_curve returns 3 objects fpr, tpr, thresholds\n","    # fpr: false positive rate\n","    # tpr: true positive rate\n","    fpr, tpr, thresholds = metrics.roc_curve(y_test, y_pred_prob)\n","    if plot == True:\n","        plt.figure()\n","\n","        plt.plot(fpr, tpr, color='darkorange', label='ROC curve (area = %0.2f)' % roc_auc)\n","        plt.plot([0, 1], [0, 1], color='navy', 
linestyle='--')\n","        plt.xlim([0.0, 1.0])\n","        plt.ylim([0.0, 1.0])\n","        plt.rcParams['font.size'] = 12\n","        plt.title('ROC curve for treatment classifier')\n","        plt.xlabel('False Positive Rate (1 - Specificity)')\n","        plt.ylabel('True Positive Rate (Sensitivity)')\n","        plt.legend(loc=\"lower right\")\n","        plt.show()\n","\n","    # define a function that accepts a threshold and prints sensitivity and specificity\n","    def evaluate_threshold(threshold):\n","        #Sensitivity: When the actual value is positive, how often is the prediction correct?\n","        #Specificity: When the actual value is negative, how often is the prediction correct?\n","        print('Sensitivity for ' + str(threshold) + ' :', tpr[thresholds > threshold][-1])\n","        print('Specificity for ' + str(threshold) + ' :', 1 - fpr[thresholds > threshold][-1])\n","\n","    # One way of setting threshold\n","    predict_mine = np.where(y_pred_prob > 0.50, 1, 0)\n","    confusion = metrics.confusion_matrix(y_test, predict_mine)\n","    print(confusion)\n","\n","\n","\n","    return accuracy"]},{"cell_type":"markdown","metadata":{"_cell_guid":"3e4552da-c5cc-45af-81ca-05a090922bb0","_uuid":"4e5d57cfad5eee9dd3a34ed9da5fdccce8dc7c3a"},"source":["### **Tuning with cross validation score**"]},{"cell_type":"code","execution_count":113,"metadata":{"_cell_guid":"ff07e090-5af7-4a06-b553-c86a7a75cd64","_uuid":"8ac94690bafb3277fec2d35bd196c317a33c9555","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[],"source":["##########################################\n","# Tuning with cross validation score\n","##########################################\n","def tuningCV(knn):\n","\n","    # search for an optimal value of K for KNN\n","    k_range = list(range(1, 31))\n","    k_scores = []\n","    for k in k_range:\n","        knn = KNeighborsClassifier(n_neighbors=k)\n","        scores = cross_val_score(knn, X, y, cv=10, scoring='accuracy')\n","        k_scores.append(scores.mean())\n","    print(k_scores)\n","    # plot the value of K for KNN (x-axis) versus the 
cross-validated accuracy (y-axis)\n","    plt.plot(k_range, k_scores)\n","    plt.xlabel('Value of K for KNN')\n","    plt.ylabel('Cross-Validated Accuracy')\n","    plt.show()\n"]},{"cell_type":"markdown","metadata":{"_cell_guid":"55a2a701-b435-4143-bc7b-c56b4a24491f","_uuid":"1cae2be4f65fc14d55eb1331110f69f016df2dbc"},"source":["### **Tuning with GridSearchCV** ###"]},{"cell_type":"code","execution_count":114,"metadata":{"_cell_guid":"736e766c-b1c8-4d7e-b1b0-1bdab402ee4f","_uuid":"3f64ad65f27cd4a92e5593ed6359245f34ba0e8e","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[],"source":["def tuningGridSerach(knn):\n","    #More efficient parameter tuning using GridSearchCV\n","    # define the parameter values that should be searched\n","    k_range = list(range(1, 31))\n","    print(k_range)\n","\n","    # create a parameter grid: map the parameter names to the values that should be searched\n","    param_grid = dict(n_neighbors=k_range)\n","    print(param_grid)\n","\n","    # instantiate the grid\n","    grid = GridSearchCV(knn, param_grid, cv=10, scoring='accuracy')\n","\n","    # fit the grid with data\n","    grid.fit(X, y)\n","\n","    # view the complete results (a dict of arrays, one entry per parameter setting)\n","    grid.cv_results_\n","\n","    # examine the first parameter setting: its parameters, score spread and mean score\n","    print(grid.cv_results_['params'][0])\n","    print(grid.cv_results_['std_test_score'][0])\n","    print(grid.cv_results_['mean_test_score'][0])\n","\n","    # create a list of the mean scores only\n","    grid_mean_scores = list(grid.cv_results_['mean_test_score'])\n","    print(grid_mean_scores)\n","\n","    # plot the results\n","    plt.plot(k_range, grid_mean_scores)\n","    plt.xlabel('Value of K for KNN')\n","    plt.ylabel('Cross-Validated Accuracy')\n","    plt.show()\n","\n","    # examine the best model\n","    print('GridSearch best score', grid.best_score_)\n","    print('GridSearch best params', grid.best_params_)\n","    print('GridSearch best estimator', 
grid.best_estimator_)\n"]},{"cell_type":"markdown","metadata":{"_cell_guid":"9ff8088d-53b8-4724-ab5a-3c23e8728b61","_uuid":"2ab0cb978d8869389d97fb4d29f4cab1c3a33ba7"},"source":["### **Tuning with RandomizedSearchCV** ###"]},{"cell_type":"code","execution_count":115,"metadata":{"_cell_guid":"446bfe3e-b42f-4821-a20a-2cb07fd07c50","_uuid":"b8d3212c33baad0e847e7a36272feb2ef31c4d5e","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[],"source":["def tuningRandomizedSearchCV(model, param_dist):\n","    #Searching multiple parameters simultaneously\n","    # n_iter controls the number of searches\n","    rand = RandomizedSearchCV(model, param_dist, cv=10, scoring='accuracy', n_iter=10, random_state=5)\n","    rand.fit(X, y)\n","    rand.cv_results_\n","\n","    # examine the best model\n","    print('Rand. Best Score: ', rand.best_score_)\n","    print('Rand. Best Params: ', rand.best_params_)\n","\n","    # run RandomizedSearchCV 20 times (with n_iter=10) and record the best score\n","    best_scores = []\n","    for _ in range(20):\n","        rand = RandomizedSearchCV(model, param_dist, cv=10, scoring='accuracy', n_iter=10)\n","        rand.fit(X, y)\n","        best_scores.append(round(rand.best_score_, 3))\n","    print(best_scores)"]},{"cell_type":"markdown","metadata":{"_cell_guid":"70d61d3d-4b78-43cb-bccf-d7588dc61762","_uuid":"695a98ccaa5870275bc2fe91ef9527dbb6833368"},"source":["### **Tuning by searching multiple parameters simultaneously** ###"]},{"cell_type":"code","execution_count":116,"metadata":{"_cell_guid":"114d4734-9647-4f09-b1bd-20c022421011","_uuid":"0f4aad3e3f129aadfda0828345ce6418d46c5cf2","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[],"source":["def tuningMultParam(knn):\n","\n","    #Searching multiple parameters simultaneously\n","    # define the parameter values that should be searched\n","    k_range = list(range(1, 31))\n","    weight_options = ['uniform', 'distance']\n","\n","    # create a parameter grid: map the parameter names to the values that 
should be searched\n","    param_grid = dict(n_neighbors=k_range, weights=weight_options)\n","    print(param_grid)\n","\n","    # instantiate and fit the grid\n","    grid = GridSearchCV(knn, param_grid, cv=10, scoring='accuracy')\n","    grid.fit(X, y)\n","\n","    # view the complete results\n","    print(grid.cv_results_)\n","\n","    # examine the best model\n","    print('Multiparam. Best Score: ', grid.best_score_)\n","    print('Multiparam. Best Params: ', grid.best_params_)"]},{"cell_type":"markdown","metadata":{"_uuid":"2e05b3654a69f79cce46c2f5f007933de7dd91dc"},"source":["\n","## **8. Evaluating models**
\n","\n","### Logistic Regression"]},{"cell_type":"code","execution_count":117,"metadata":{"_cell_guid":"8613beea-11e3-426b-91f6-39df3626eaf9","_uuid":"89f5e2c8ec51637568ac22982649205ca1c340e2","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[],"source":["def logisticRegression():\n"," # train a logistic regression model on the training set\n"," logreg = LogisticRegression()\n"," logreg.fit(X_train, y_train)\n"," \n"," # make class predictions for the testing set\n"," y_pred_class = logreg.predict(X_test)\n"," \n"," print('########### Logistic Regression ###############')\n"," \n"," accuracy_score = evalClassModel(logreg, y_test, y_pred_class, True)\n"," \n"," #Data for final graph\n"," methodDict['Log. Regres.'] = accuracy_score * 100"]},{"cell_type":"markdown","metadata":{"_cell_guid":"de99d1dd-eb98-478c-9ed0-3cb518eac4b2","_uuid":"1f71d045fa89ece846fe9e44559f057293e8bdf7"},"source":["\n"]},{"cell_type":"code","execution_count":118,"metadata":{"_cell_guid":"2c090f3c-ecc9-433d-84a9-0d8f56cfd75d","_uuid":"dd4c840a2b9a7ed7ad737730c4f780d11d603699","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["########### Logistic Regression ###############\n","Accuracy: 0.7962962962962963\n","Null accuracy:\n"," 0 191\n","1 187\n","Name: treatment, dtype: int64\n","Percentage of ones: 0.4947089947089947\n","Percentage of zeros: 0.5052910052910053\n","True: [0 0 0 0 0 0 0 0 1 1 0 1 1 0 1 1 0 1 0 0 0 1 1 0 0]\n","Pred: [1 0 0 0 1 1 0 1 0 1 0 1 1 0 1 1 1 1 0 0 0 0 1 0 0]\n"]},{"data":{"image/png":"","text/plain":["
"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["[[151 40]\n"," [ 28 159]]\n"]}],"source":["stacking()"]},{"cell_type":"markdown","metadata":{"_uuid":"87e8181f247c4dd00dd8b2caf65ebe5b76367719"},"source":["\n","## **9. Predicting with Neural Network**\n"]},{"cell_type":"markdown","metadata":{"_uuid":"ee05c058a7e73956696bead61f457ee04f39a004"},"source":["### Create input functions\n","You must create input functions to supply data for training, evaluating, and prediction."]},{"cell_type":"code","execution_count":132,"metadata":{"_uuid":"3c532804f31ae661e5cf820406e4eee58679dcc6","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[],"source":["import tensorflow as tf\n","import argparse\n","\n","\n","batch_size = 100\n","train_steps = 1000\n","\n","X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30, random_state=0)\n","\n","def train_input_fn(features, labels, batch_size):\n"," \"\"\"An input function for training\"\"\"\n"," # Convert the inputs to a Dataset.\n"," dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))\n","\n"," # Shuffle, repeat, and batch the examples.\n"," return dataset.shuffle(1000).repeat().batch(batch_size)\n","\n","def eval_input_fn(features, labels, batch_size):\n"," \"\"\"An input function for evaluation or prediction\"\"\"\n"," features=dict(features)\n"," if labels is None:\n"," # No labels, use only features.\n"," inputs = features\n"," else:\n"," inputs = (features, labels)\n","\n"," # Convert the inputs to a Dataset.\n"," dataset = tf.data.Dataset.from_tensor_slices(inputs)\n","\n"," # Batch the examples\n"," assert batch_size is not None, \"batch_size must not be None\"\n"," dataset = dataset.batch(batch_size)\n","\n"," # Return the dataset.\n"," return dataset\n","\n"]},{"cell_type":"markdown","metadata":{"_uuid":"7540da5ccf242c31fae61193422a50c4132fb7cd"},"source":["### Define the feature columns\n","A feature column is an 
object describing how the model should use raw input data from the features dictionary. When you build an Estimator model, you pass it a list of feature columns that describes each of the features you want the model to use."]},{"cell_type":"code","execution_count":133,"metadata":{"_uuid":"4b51ea10c5fc254910a46910db0fb76b27150b24","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["WARNING:tensorflow:From C:\\Users\\puran\\AppData\\Local\\Temp\\ipykernel_20412\\3225071575.py:2: numeric_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n","Instructions for updating:\n","Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.\n"]}],"source":["# Define Tensorflow feature columns\n","age = tf.feature_column.numeric_column(\"Age\")\n","gender = tf.feature_column.numeric_column(\"Gender\")\n","family_history = tf.feature_column.numeric_column(\"family_history\")\n","benefits = tf.feature_column.numeric_column(\"benefits\")\n","care_options = tf.feature_column.numeric_column(\"care_options\")\n","anonymity = tf.feature_column.numeric_column(\"anonymity\")\n","leave = tf.feature_column.numeric_column(\"leave\")\n","work_interfere = tf.feature_column.numeric_column(\"work_interfere\")\n","feature_columns = [age, gender, family_history, benefits, care_options, anonymity, leave, work_interfere]\n"]},{"cell_type":"markdown","metadata":{"_uuid":"0ec5458a421476ca5b77b77d67e5d7dbd406090f"},"source":["### Instantiate an Estimator\n","Our problem is a classic classification problem. We want to predict whether a patient has to be treated or not. 
We'll use tf.estimator.DNNClassifier for deep models that perform multi-class classification."]},{"cell_type":"code","execution_count":147,"metadata":{"_uuid":"605d76f9ecf1e3a87332633a0b84c68b04024494","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["INFO:tensorflow:Using default config.\n","WARNING:tensorflow:Using temporary folder as model directory: C:\\Users\\puran\\AppData\\Local\\Temp\\tmpzxdwhcif\n","INFO:tensorflow:Using config: {'_model_dir': 'C:\\\\Users\\\\puran\\\\AppData\\\\Local\\\\Temp\\\\tmpzxdwhcif', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n","graph_options {\n"," rewrite_options {\n"," meta_optimizer_iterations: ONE\n"," }\n","}\n",", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n"]}],"source":["# Build a DNN with 2 hidden layers and 10 nodes in each hidden layer.\n","model = tf.estimator.DNNClassifier(feature_columns=feature_columns,\n"," hidden_units=[10, 10],\n"," optimizer=tf.keras.optimizers.legacy.Adagrad(learning_rate=0.1, decay=0.001))"]},{"cell_type":"markdown","metadata":{"_uuid":"294ab3369fd6ddee67b62b402ebe707001010466"},"source":["### Train, Evaluate, and Predict\n","Now that we have an Estimator object, we can call methods to do the following:\n","\n","* Train the model.\n","* Evaluate the trained model.\n","* Use the 
trained model to make predictions.\n"]},{"cell_type":"markdown","metadata":{"_uuid":"4fcb1a081143098bce426cc496167f024c432350"},"source":["#### Train the model\n","The steps argument tells the method to stop training after a number of training steps."]},{"cell_type":"code","execution_count":148,"metadata":{"_uuid":"805f28d819c59eee02123b11eff66d812d392188","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["INFO:tensorflow:Calling model_fn.\n","WARNING:tensorflow:From y:\\Anaconda\\envs\\StrokePredictionModel\\lib\\site-packages\\keras\\optimizers\\legacy\\adagrad.py:93: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n","Instructions for updating:\n","Call initializer instance with the dtype argument instead of passing it to the constructor\n","WARNING:tensorflow:From y:\\Anaconda\\envs\\StrokePredictionModel\\lib\\site-packages\\tensorflow_estimator\\python\\estimator\\model_fn.py:250: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version.\n","Instructions for updating:\n","Use tf.keras instead.\n","INFO:tensorflow:Done calling model_fn.\n","WARNING:tensorflow:From y:\\Anaconda\\envs\\StrokePredictionModel\\lib\\site-packages\\tensorflow_estimator\\python\\estimator\\estimator.py:1414: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n","Instructions for updating:\n","Use tf.keras instead.\n","WARNING:tensorflow:From y:\\Anaconda\\envs\\StrokePredictionModel\\lib\\site-packages\\tensorflow_estimator\\python\\estimator\\estimator.py:1417: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n","Instructions for updating:\n","Use tf.keras 
instead.\n","WARNING:tensorflow:From y:\\Anaconda\\envs\\StrokePredictionModel\\lib\\site-packages\\tensorflow\\python\\training\\basic_session_run_hooks.py:232: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n","Instructions for updating:\n","Use tf.keras instead.\n","WARNING:tensorflow:From y:\\Anaconda\\envs\\StrokePredictionModel\\lib\\site-packages\\tensorflow_estimator\\python\\estimator\\estimator.py:1454: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n","Instructions for updating:\n","Use tf.keras instead.\n","INFO:tensorflow:Create CheckpointSaverHook.\n","WARNING:tensorflow:From y:\\Anaconda\\envs\\StrokePredictionModel\\lib\\site-packages\\tensorflow\\python\\training\\monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n","Instructions for updating:\n","Use tf.keras instead.\n","WARNING:tensorflow:From y:\\Anaconda\\envs\\StrokePredictionModel\\lib\\site-packages\\tensorflow\\python\\training\\monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n","Instructions for updating:\n","Use tf.keras instead.\n","INFO:tensorflow:Graph was finalized.\n","INFO:tensorflow:Running local_init_op.\n","INFO:tensorflow:Done running local_init_op.\n","INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...\n","INFO:tensorflow:Saving checkpoints for 0 into C:\\Users\\puran\\AppData\\Local\\Temp\\tmpzxdwhcif\\model.ckpt.\n","INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...\n","WARNING:tensorflow:From y:\\Anaconda\\envs\\StrokePredictionModel\\lib\\site-packages\\tensorflow\\python\\training\\monitored_session.py:1455: 
SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n","Instructions for updating:\n","Use tf.keras instead.\n","WARNING:tensorflow:From y:\\Anaconda\\envs\\StrokePredictionModel\\lib\\site-packages\\tensorflow\\python\\training\\monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n","Instructions for updating:\n","Use tf.keras instead.\n","WARNING:tensorflow:From y:\\Anaconda\\envs\\StrokePredictionModel\\lib\\site-packages\\tensorflow\\python\\training\\monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n","Instructions for updating:\n","Use tf.keras instead.\n","INFO:tensorflow:loss = 0.7750711, step = 0\n","INFO:tensorflow:global_step/sec: 127.227\n","INFO:tensorflow:loss = 0.40610826, step = 100 (0.789 sec)\n","INFO:tensorflow:global_step/sec: 328.947\n","INFO:tensorflow:loss = 0.34490243, step = 200 (0.303 sec)\n","INFO:tensorflow:global_step/sec: 357.148\n","INFO:tensorflow:loss = 0.39927828, step = 300 (0.278 sec)\n","INFO:tensorflow:global_step/sec: 354.609\n","INFO:tensorflow:loss = 0.33172143, step = 400 (0.290 sec)\n","INFO:tensorflow:global_step/sec: 346.023\n","INFO:tensorflow:loss = 0.3397673, step = 500 (0.281 sec)\n","INFO:tensorflow:global_step/sec: 456.62\n","INFO:tensorflow:loss = 0.42048493, step = 600 (0.222 sec)\n","INFO:tensorflow:global_step/sec: 411.523\n","INFO:tensorflow:loss = 0.32625598, step = 700 (0.242 sec)\n","INFO:tensorflow:global_step/sec: 440.531\n","INFO:tensorflow:loss = 0.37452662, step = 800 (0.226 sec)\n","INFO:tensorflow:global_step/sec: 425.526\n","INFO:tensorflow:loss = 0.4125634, step = 900 (0.235 sec)\n","INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1000...\n","INFO:tensorflow:Saving checkpoints for 1000 
into C:\\Users\\puran\\AppData\\Local\\Temp\\tmpzxdwhcif\\model.ckpt.\n","INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1000...\n","INFO:tensorflow:Loss for final step: 0.34743762.\n"]},{"data":{"text/plain":[""]},"execution_count":148,"metadata":{},"output_type":"execute_result"}],"source":["model.train(input_fn=lambda:train_input_fn(X_train, y_train, batch_size), steps=train_steps)"]},{"cell_type":"markdown","metadata":{"_uuid":"a4e82a5b76e6d0a89d21187413ed1574788fbe5a"},"source":["### Evaluate the trained model\n","Now that the model has been trained, we can get some statistics on its performance. The following code block evaluates the accuracy of the trained model on the test data."]},{"cell_type":"code","execution_count":149,"metadata":{"_uuid":"2b99faa167137ef3e90c7ebb9cd093837e8d4ef5","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["INFO:tensorflow:Calling model_fn.\n","INFO:tensorflow:Done calling model_fn.\n","INFO:tensorflow:Starting evaluation at 2022-12-16T18:37:47\n","WARNING:tensorflow:From y:\\Anaconda\\envs\\StrokePredictionModel\\lib\\site-packages\\tensorflow\\python\\training\\evaluation.py:260: FinalOpsHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n","Instructions for updating:\n","Use tf.keras instead.\n","INFO:tensorflow:Graph was finalized.\n","INFO:tensorflow:Restoring parameters from C:\\Users\\puran\\AppData\\Local\\Temp\\tmpzxdwhcif\\model.ckpt-1000\n","INFO:tensorflow:Running local_init_op.\n","INFO:tensorflow:Done running local_init_op.\n","INFO:tensorflow:Inference Time : 6.38700s\n","INFO:tensorflow:Finished evaluation at 2022-12-16-18:37:54\n","INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.8068783, accuracy_baseline = 0.505291, auc = 0.8851108, auc_precision_recall = 0.850176, average_loss = 0.4337394, global_step = 1000, label/mean = 
0.49470899, loss = 0.43456596, precision = 0.75, prediction/mean = 0.50710773, recall = 0.9144385\n","INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1000: C:\\Users\\puran\\AppData\\Local\\Temp\\tmpzxdwhcif\\model.ckpt-1000\n","\n","Test set accuracy: 0.81\n","\n"]}],"source":["# Evaluate the model.\n","eval_result = model.evaluate(\n"," input_fn=lambda:eval_input_fn(X_test, y_test, batch_size))\n","\n","print('\\nTest set accuracy: {accuracy:0.2f}\\n'.format(**eval_result))\n","\n","#Data for final graph\n","accuracy = eval_result['accuracy'] * 100\n","methodDict['NN DNNClasif.'] = accuracy"]},{"cell_type":"markdown","metadata":{"_uuid":"de375081922a83052d997ce836c31884ce44efd7"},"source":["### Making predictions (inferring) from the trained model\n","We now have a trained model that produces good evaluation results. We can now use the trained model to predict whether a patient needs treatment or not."]},{"cell_type":"code","execution_count":150,"metadata":{"_uuid":"a72f2f39e610a12fea7f542fa4cccec9ca17544a","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["INFO:tensorflow:Calling model_fn.\n","WARNING:tensorflow:From y:\\Anaconda\\envs\\StrokePredictionModel\\lib\\site-packages\\tensorflow_estimator\\python\\estimator\\head\\base_head.py:786: ClassificationOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version.\n","Instructions for updating:\n","Use tf.keras instead.\n","WARNING:tensorflow:From y:\\Anaconda\\envs\\StrokePredictionModel\\lib\\site-packages\\tensorflow_estimator\\python\\estimator\\head\\binary_class_head.py:561: RegressionOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version.\n","Instructions for updating:\n","Use tf.keras instead.\n","WARNING:tensorflow:From 
y:\\Anaconda\\envs\\StrokePredictionModel\\lib\\site-packages\\tensorflow_estimator\\python\\estimator\\head\\binary_class_head.py:563: PredictOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version.\n","Instructions for updating:\n","Use tf.keras instead.\n","INFO:tensorflow:Done calling model_fn.\n","INFO:tensorflow:Graph was finalized.\n","INFO:tensorflow:Restoring parameters from C:\\Users\\puran\\AppData\\Local\\Temp\\tmpzxdwhcif\\model.ckpt-1000\n","INFO:tensorflow:Running local_init_op.\n","INFO:tensorflow:Done running local_init_op.\n"]}],"source":["predictions = list(model.predict(input_fn=lambda:eval_input_fn(X_train, y_train, batch_size=batch_size)))"]},{"cell_type":"code","execution_count":151,"metadata":{"_uuid":"e078018202e2a75c3cace605d815be5435798682","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"data":{"text/html":["
"],"text/plain":[" index prediction expected\n","0 929 0 0\n","1 901 1 1\n","2 579 1 1\n","3 367 1 1\n","4 615 1 1"]},"execution_count":151,"metadata":{},"output_type":"execute_result"}],"source":["# Generate predictions from the model\n","template = ('\\nIndex: \"{}\", Prediction is \"{}\" ({:.1f}%), expected \"{}\"')\n","\n","# Columns for the results dataframe\n","col1 = []\n","col2 = []\n","col3 = []\n","\n","\n","for idx, expected, p in zip(X_train.index, y_train, predictions):\n","    class_id = p['class_ids'][0]  # Predicted class\n","    probability = p['probabilities'][class_id]  # Probability of the predicted class\n","\n","    # Adding to dataframe\n","    col1.append(idx)  # Index\n","    col2.append(class_id)  # Prediction\n","    col3.append(expected)  # Expected\n","\n","\n","    #print(template.format(idx, class_id, 100 * probability, expected))\n","\n","\n","results = pd.DataFrame({'index':col1, 'prediction':col2, 'expected':col3})\n","results.head()"]},{"cell_type":"markdown","metadata":{"_cell_guid":"83d68dc1-1929-4069-8b9f-2b1b0768840d","_uuid":"289d8b5896d20f5859307e07c509c92151b4e942"},"source":["\n","## **10. 
Success method plot**"]},{"cell_type":"code","execution_count":152,"metadata":{"_cell_guid":"ff9279ed-3a53-47ef-8d69-b9c013c73ba0","_uuid":"67df920b185fcf2a8941369bb91c8975d0c8ce10","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[],"source":["def plotSuccess():\n","    # Bar chart of the values stored in methodDict, highest first\n","    s = pd.Series(methodDict)\n","    s = s.sort_values(ascending=False)\n","    plt.figure(figsize=(12,8))\n","    ax = s.plot(kind='bar')\n","    # Label each bar with its value, rounded to two decimals\n","    for p in ax.patches:\n","        ax.annotate(str(round(p.get_height(),2)), (p.get_x() * 1.005, p.get_height() * 1.005))\n","    plt.ylim([70.0, 90.0])\n","    plt.xlabel('Method')\n","    plt.ylabel('Percentage')\n","    plt.title('Success of methods')\n","\n","    plt.show()"]},{"cell_type":"code","execution_count":153,"metadata":{"_cell_guid":"3672320b-e060-45e4-b481-029ad98feb58","_uuid":"26d24346015d93adf6c4fa9d384f480b7b391584","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"data":{"image/png":"","text/plain":["
"]},"metadata":{},"output_type":"display_data"}],"source":["plotSuccess()"]},{"cell_type":"markdown","metadata":{"_cell_guid":"05e05c13-6e1f-4900-b147-844cf49cdb41","_uuid":"97b824045c3668ad8eab81bdfe35d07ab0d18c76"},"source":["\n","## **11. Creating predictions on test set**"]},{"cell_type":"code","execution_count":160,"metadata":{"_cell_guid":"cd416a0e-a234-45c1-a502-e46d7b381dcd","_uuid":"ddf3291a57a8ed1ebd10c81779abc44ff9f5ce72","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[{"data":{"text/html":["
<div>\n","<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr><th></th><th>Index</th><th>Treatment</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>5</td><td>1</td></tr>\n",
"    <tr><th>1</th><td>494</td><td>0</td></tr>\n",
"    <tr><th>2</th><td>52</td><td>0</td></tr>\n",
"    <tr><th>3</th><td>984</td><td>0</td></tr>\n",
"    <tr><th>4</th><td>186</td><td>0</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"],"text/plain":["   Index  Treatment\n","0      5          1\n","1    494          0\n","2     52          0\n","3    984          0\n","4    186          0"]},"execution_count":160,"metadata":{},"output_type":"execute_result"}],"source":["# Generate predictions with the best method\n","clf = AdaBoostClassifier()\n","clf.fit(X, y)\n","dfTestPredictions = clf.predict(X_test)\n","\n","# Write predictions to csv file\n","# There is no meaningful ID field, so we save the index\n","results = pd.DataFrame({'Index': X_test.index, 'Treatment': dfTestPredictions})\n","# Save to file\n","# This file will be visible after publishing in the output section\n","results.to_csv('preprocessed_datasets/2014/results.csv', index=False)\n","results.head()"]},{"cell_type":"markdown","metadata":{"_uuid":"35d306ec4cb719e14f1b2600dff7a3ffa89f1b47"},"source":["\n","## **12. Submission**"]},{"cell_type":"markdown","metadata":{"_uuid":"8fdd6b0a10dd1efb08cfac45026ad83d3fd770eb"},"source":["### Prepare Submission File\n","Submissions are made as CSV files. A submission usually has two columns: an ID column and a prediction column. The ID field comes from the test data (keeping whatever name the ID field had in that data, which for our data is the index). The prediction column uses the name of the target field.\n","\n","We will create a DataFrame with this data, and then use the DataFrame's to_csv method to write our submission file. 
Explicitly pass index=False so that pandas does not write the DataFrame index as an extra column in the CSV file."]},{"cell_type":"code","execution_count":161,"metadata":{"_uuid":"224fe370c6c43359faa76d067a3ef39ed3b13402","collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"outputs":[],"source":["# Write predictions to csv file\n","# There is no meaningful ID field, so we save the index\n","results = pd.DataFrame({'Index': X_test.index, 'Treatment': dfTestPredictions})\n","# Save to file\n","# This file will be visible after publishing in the output section\n","results.to_csv('preprocessed_datasets/2014/submission.csv', index=False)\n"]},{"cell_type":"markdown","metadata":{"_uuid":"e56961d7c4f00c35537023aad9c6cbf09cd8286a"},"source":["### Make Submission\n","Hit the blue Publish button at the top of your notebook screen. It will take some time for your kernel to run. When it has finished, the navigation bar at the top of the screen will have a tab for Output. This tab only shows up if you have written an output file (like we did in the Prepare Submission File step)."]},{"cell_type":"markdown","metadata":{"_cell_guid":"2c05dd56-2528-40b1-9cd0-368300adc2c3","_uuid":"d5972622da305ae627019fc0476a769a22a9f3fc"},"source":["\n","## **13. Conclusions**\n","\n","As a beginner, I can't say whether these results are the best achievable. Still, a success rate above 80% for most of the methods seems good, given that the goal is to predict whether a patient needs treatment or not.\n","\n","What remains is a way to persist the model for future use without having to retrain it. 
That will be done in another kernel.\n","\n","\n","Thanks for reading. If you liked this work or have any advice, feel free to share it.\n","\n"]}],"metadata":{"kernelspec":{"display_name":"StrokePredictionModel","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.8.13 (default, Mar 28 2022, 06:59:08) [MSC v.1916 64 bit (AMD64)]"},"vscode":{"interpreter":{"hash":"6d6bab66b583e7661b89cead2220317a23c391a40fb8c52f2c1bcd3c04f3fbda"}}},"nbformat":4,"nbformat_minor":4}