A train-accident dataset was obtained from the internet; it contains variables such as survival status, sex, age, class travelled and a few others. Below, a logistic regression classifier is fitted to estimate survival probabilities.
Import Required Libraries
#data libraries
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from pandas.plotting import scatter_matrix
%matplotlib inline
import seaborn as sns
sns.set(style="white",color_codes=True)
sns.set(font_scale=1.5)
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import (confusion_matrix, classification_report,
                             accuracy_score, precision_score,
                             recall_score, f1_score)
from sklearn.metrics import r2_score, mean_squared_error
from sklearn import metrics
from math import sqrt
Check the shape of the imported data: 619 rows × 10 columns.
df_train = pd.read_csv("/trainaccidata.csv")
df_train.shape
(619, 10)
df_train.describe(include="all")
PassengerId | Survived | Class | Sex | Age | Cost | Stations | AA | BB | CC | |
---|---|---|---|---|---|---|---|---|---|---|
count | 619.000000 | 619.000000 | 619.000000 | 619 | 619.000000 | 619.000000 | 618 | 0.0 | 0.0 | 0.0 |
unique | NaN | NaN | NaN | 2 | NaN | NaN | 3 | NaN | NaN | NaN |
top | NaN | NaN | NaN | male | NaN | NaN | A | NaN | NaN | NaN |
freq | NaN | NaN | NaN | 388 | NaN | NaN | 443 | NaN | NaN | NaN |
mean | 310.000000 | 0.394184 | 2.310178 | NaN | 34.940226 | 2192.970921 | NaN | NaN | NaN | NaN |
std | 178.834188 | 0.489070 | 0.834574 | NaN | 17.663780 | 3169.612906 | NaN | NaN | NaN | NaN |
min | 1.000000 | 0.000000 | 1.000000 | NaN | 0.750000 | 0.000000 | NaN | NaN | NaN | NaN |
25% | 155.500000 | 0.000000 | 2.000000 | NaN | 22.000000 | 547.000000 | NaN | NaN | NaN | NaN |
50% | 310.000000 | 0.000000 | 3.000000 | NaN | 32.000000 | 1035.000000 | NaN | NaN | NaN | NaN |
75% | 464.500000 | 1.000000 | 3.000000 | NaN | 47.000000 | 2111.500000 | NaN | NaN | NaN | NaN |
max | 619.000000 | 1.000000 | 3.000000 | NaN | 71.000000 | 35351.000000 | NaN | NaN | NaN | NaN |
Now check for nulls in the data.
df_train.isnull().sum()
PassengerId      0
Survived         0
Class            0
Sex              0
Age              0
Cost             0
Stations         1
AA             619
BB             619
CC             619
dtype: int64
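The AA, BB and CC columns are entirely null (619 nulls out of 619 rows), so they carry no information. One way to drop every all-empty column in a single step (a sketch on a toy frame, since the original CSV isn't reproduced here):

```python
import numpy as np
import pandas as pd

# Toy frame mimicking the structure: two populated columns, one all-null column
df = pd.DataFrame({
    "Survived": [1, 0, 1],
    "Age": [0.75, 22.0, 47.0],
    "AA": [np.nan, np.nan, np.nan],  # entirely empty, like AA/BB/CC above
})

# axis=1 with how='all' drops only columns where every value is null
df_clean = df.dropna(axis=1, how="all")
print(list(df_clean.columns))  # AA is gone; populated columns remain
```

Partially-null columns (like Stations, with a single missing value) are untouched by `how='all'`.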
Remove the unnecessary data
df_train=df_train.drop(['PassengerId'],axis=1)
df_train.head()
Survived | Class | Sex | Age | Cost | Stations | |
---|---|---|---|---|---|---|
0 | 1 | 3 | female | 0.75 | 1329 | C |
1 | 1 | 2 | male | 0.83 | 2001 | A |
2 | 1 | 1 | male | 0.92 | 10457 | A |
3 | 0 | 3 | male | 1.00 | 2738 | A |
4 | 1 | 3 | female | 1.00 | 768 | A |
Check for nulls again
df_train.isnull().sum()
PassengerId    0
Survived       0
Class          0
Sex            0
Age            0
Cost           0
Stations       1
dtype: int64
There is one null in Stations; it can either be dropped or imputed. I am dropping it.
df_train.dropna(inplace=True)
df_train.isnull().sum()
PassengerId    0
Survived       0
Class          0
Sex            0
Age            0
Cost           0
Stations       0
dtype: int64
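Since only one row is affected, dropping costs almost nothing, but the alternative mentioned above is to impute the missing station with the most frequent value (A). A sketch of that option on a toy Stations column:

```python
import numpy as np
import pandas as pd

# Toy Stations column with a single missing value, mirroring the one null above
stations = pd.Series(["A", "A", "C", "B", np.nan, "A"])

# Fill the gap with the mode (most frequent value) instead of dropping the row
filled = stations.fillna(stations.mode()[0])
print(filled.isnull().sum())  # no nulls remain
```

`Series.mode()` ignores NaN by default, so the most frequent observed station fills the gap.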
Check the column datatypes
df_train.dtypes
PassengerId      int64
Survived         int64
Class            int64
Sex             object
Age            float64
Cost             int64
Stations        object
dtype: object
Convert Sex and Stations into indicator (dummy) variables
df_train_dummied =pd.get_dummies(df_train,columns=["Sex"])
df_train_dummied =pd.get_dummies(df_train_dummied,columns=["Stations"])
df_train_dummied.dtypes
PassengerId       int64
Survived          int64
Class             int64
Age             float64
Cost              int64
Sex_female        uint8
Sex_male          uint8
Stations_A        uint8
Stations_B        uint8
Stations_C        uint8
dtype: object
Preview the indicator variables
df_train_dummied.head()
PassengerId | Survived | Class | Age | Cost | Sex_female | Sex_male | Stations_A | Stations_B | Stations_C | |
---|---|---|---|---|---|---|---|---|---|---|
0 | 470 | 1 | 3 | 0.75 | 1329 | 1 | 0 | 0 | 0 | 1 |
1 | 79 | 1 | 2 | 0.83 | 2001 | 0 | 1 | 1 | 0 | 0 |
2 | 306 | 1 | 1 | 0.92 | 10457 | 0 | 1 | 1 | 0 | 0 |
3 | 165 | 0 | 3 | 1.00 | 2738 | 0 | 1 | 1 | 0 | 0 |
4 | 173 | 1 | 3 | 1.00 | 768 | 1 | 0 | 1 | 0 | 0 |
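Note that Sex_female and Sex_male are perfectly collinear (each is always 1 minus the other), and the three Stations columns always sum to 1. If that redundancy is a concern, `pd.get_dummies` can keep k−1 indicators per k-level category; a sketch:

```python
import pandas as pd

# Toy frame with the same two categorical columns as above
df = pd.DataFrame({"Sex": ["male", "female", "male"],
                   "Stations": ["A", "C", "B"]})

# drop_first=True drops the first (alphabetical) level of each category,
# leaving k-1 indicator columns and avoiding perfect collinearity
dummied = pd.get_dummies(df, columns=["Sex", "Stations"], drop_first=True)
print(list(dummied.columns))  # Sex_female and Stations_A are dropped
```

The dropped level becomes the baseline that the remaining coefficients are measured against.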
The counts of male vs female passengers in the coach
df_train.Sex.value_counts()
male      388
female    231
Name: Sex, dtype: int64
Counts by boarding station
df_train.Stations.value_counts()
A    443
C    119
B     56
Name: Stations, dtype: int64
Let's plot a correlation heatmap.
plt.figure(figsize=(5,5))
sns.heatmap(df_train_dummied.corr())
PassengerId doesn't play any role in survival; better to remove it.
df_train_dummied=df_train_dummied.drop(['PassengerId'],axis=1)
df_train_dummied.head()
Survived | Class | Age | Cost | Sex_female | Sex_male | Stations_A | Stations_B | Stations_C | |
---|---|---|---|---|---|---|---|---|---|
0 | 1 | 3 | 0.75 | 1329 | 1 | 0 | 0 | 0 | 1 |
1 | 1 | 2 | 0.83 | 2001 | 0 | 1 | 1 | 0 | 0 |
2 | 1 | 1 | 0.92 | 10457 | 0 | 1 | 1 | 0 | 0 |
3 | 0 | 3 | 1.00 | 2738 | 0 | 1 | 1 | 0 | 0 |
4 | 1 | 3 | 1.00 | 768 | 1 | 0 | 1 | 0 | 0 |
Now plot the heatmap again.
plt.figure(figsize=(5,5))
sns.heatmap(df_train_dummied.corr())
Sex correlates strongly with Survived. Now let's build a classifier, with independent variables x and dependent variable y.
used_features = ["Class","Age","Sex_female","Sex_male","Stations_A","Stations_B","Stations_C"]
x = df_train_dummied[used_features].values
y = df_train_dummied["Survived"]
Now split the dataset into training and test sets.
x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.3,random_state=1)
Check the shapes of the training and test sets
print(x_train.shape)
print(x_test.shape)
print(y_train.shape)
print(y_test.shape)
(432, 7)
(186, 7)
(432,)
(186,)
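The 432/186 split follows from `test_size=0.3` on the 618 rows that remain after dropping the one null Stations row: the test set takes ⌈618 × 0.3⌉ = 186 rows, leaving 432 for training. A quick check on dummy data of the same size:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# 618 dummy rows with 7 features, mirroring the shapes printed above
x = np.zeros((618, 7))
y = np.zeros(618)

x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.3, random_state=1)
print(x_train.shape, x_test.shape)  # (432, 7) (186, 7)
```

For a binary target like Survived, passing `stratify=y` would also keep the class balance similar in both halves.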
Apply logistic regression
LogReg = LogisticRegression()
Fit on x_train & y_train
LogReg.fit(x_train,y_train)
LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True, intercept_scaling=1, l1_ratio=None, max_iter=100, multi_class='auto', n_jobs=None, penalty='l2', random_state=None, solver='lbfgs', tol=0.0001, verbose=0, warm_start=False)
Store the predictions in y_pred.
y_pred = LogReg.predict(x_test)
Compute the confusion matrix to evaluate the accuracy of the classifier
metrics.confusion_matrix(y_test,y_pred)
array([[95, 12],
       [24, 55]])
Check the accuracy score
metrics.accuracy_score(y_test,y_pred)
0.8064516129032258
The model has an accuracy of about 81%: out of every 100 passengers, roughly 81 are correctly classified as survived or not survived.
len(x_test)
186
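The accuracy can be recovered directly from the confusion matrix: correct predictions sit on the diagonal, so accuracy = (95 + 55) / 186 ≈ 0.806, matching `accuracy_score` on the 186 test rows.

```python
import numpy as np

# Confusion matrix from above: rows = actual class, cols = predicted class
cm = np.array([[95, 12],
               [24, 55]])

# Diagonal entries are the correct predictions
accuracy = np.trace(cm) / cm.sum()   # (95 + 55) / 186
print(round(accuracy, 4))
```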
Let's print the other metrics derived from the confusion matrix
print(classification_report(y_test,y_pred))
              precision    recall  f1-score   support

           0       0.80      0.89      0.84       107
           1       0.82      0.70      0.75        79

    accuracy                           0.81       186
   macro avg       0.81      0.79      0.80       186
weighted avg       0.81      0.81      0.80       186
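The class-1 (survived) numbers in the report follow directly from the same matrix: precision = 55 / (55 + 12) ≈ 0.82 (of those predicted to survive, how many actually did) and recall = 55 / (55 + 24) ≈ 0.70 (of those who actually survived, how many the model caught).

```python
# Entries from the confusion matrix above: rows = actual, cols = predicted
tn, fp, fn, tp = 95, 12, 24, 55

precision = tp / (tp + fp)   # 55 / 67
recall = tp / (tp + fn)      # 55 / 79
f1 = 2 * precision * recall / (precision + recall)  # harmonic mean
print(round(precision, 2), round(recall, 2), round(f1, 2))
```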
The model has decent precision and recall. The fitted coefficients and intercept can be inspected.
LogReg.coef_
array([[-0.93651253, -0.02470958, 1.2435027 , -1.24361931, -0.56096349, 0.40195066, 0.15889623]])
LogReg.intercept_
array([2.97946363])
The intercept is a sensible value; now check the predicted probabilities, keeping the feature order in mind.
df_train_dummied[used_features].columns
Index(['Class', 'Age', 'Sex_female', 'Sex_male', 'Stations_A', 'Stations_B', 'Stations_C'], dtype='object')
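One common way to read these coefficients is to exponentiate them into odds ratios. For example, exp(1.2435) ≈ 3.47 for Sex_female: holding the other features fixed, being female multiplies the odds of survival by roughly 3.5. A sketch using the coefficients printed above:

```python
import numpy as np

# Coefficients printed above, in the used_features order
features = ["Class", "Age", "Sex_female", "Sex_male",
            "Stations_A", "Stations_B", "Stations_C"]
coef = np.array([-0.93651253, -0.02470958, 1.2435027, -1.24361931,
                 -0.56096349, 0.40195066, 0.15889623])

# exp(coefficient) = multiplicative change in the odds of survival
# for a one-unit increase in that feature, all else held fixed
odds_ratios = np.exp(coef)
for name, orat in zip(features, odds_ratios):
    print(f"{name}: {orat:.2f}")
```

Ratios above 1 push the odds toward survival, ratios below 1 away from it (e.g. each step down in Class multiplies the odds by about 0.39).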
In the output below, the first column is the probability of the not-survived label (0) and the second column the probability of the survived label (1).
LogReg.predict_proba(x_test)
array([[0.82321308, 0.17678692],
       [0.40373951, 0.59626049],
       [0.60558477, 0.39441523],
       ...,
       [0.68915076, 0.31084924],
       [0.88391211, 0.11608789],
       [0.23219685, 0.76780315]])
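`predict` is just `predict_proba` thresholded at 0.5: an entry is labelled survived when its second column (probability of class 1) reaches 0.5. A minimal sketch on synthetic data, since the original CSV isn't reproduced here:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Tiny synthetic binary problem, just to show the relationship
rng = np.random.RandomState(0)
x = rng.randn(100, 3)
y = (x[:, 0] + x[:, 1] > 0).astype(int)

model = LogisticRegression().fit(x, y)
proba = model.predict_proba(x)

# Thresholding the class-1 column reproduces predict()
thresholded = (proba[:, 1] >= 0.5).astype(int)
print(np.array_equal(model.predict(x), thresholded))
```

Moving the threshold away from 0.5 trades precision against recall, which is useful when one kind of error is costlier than the other.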
plt.figure(figsize=(5,5))
sns.barplot(x="Sex_female",y="Survived",data=df_train_dummied)
The plot above shows that a larger share of females survived; the one below shows that most males did not survive.
sns.barplot(x="Sex_male",y="Survived",data=df_train_dummied)
As a cautionary note, below I plot Class vs Survival, but Sex acts as a lurking variable here, so this plot should not be read at face value.
sns.barplot(x="Class",y="Survived",data=df_train_dummied)
sns.barplot(x="Class",y="Sex_male",data=df_train_dummied)
Similarly for Class vs Cost:
sns.barplot(x="Class",y="Cost",data=df_train_dummied)