Imported libraries¶
In [ ]:
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
Load data and define feature columns and target variable¶
In [ ]:
# Read data
data = pd.read_csv("datasets/heartdisease.csv")
# Define your feature columns and target variable
X = data.drop('HeartDisease', axis=1)
y = data['HeartDisease']
Feature engineering: encode categorical data¶
In [ ]:
# Feature engineering of the categorical data
label_encoders = {}
for col in X.columns:
le = LabelEncoder()
X.loc[:, col] = le.fit_transform(X[col])
label_encoders[col] = le
In [ ]:
# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1, shuffle=False)
# Fit logistic regression model
lr_model = LogisticRegression()
lr_model.fit(X_train, y_train)
# Make predictions
y_pred = lr_model.predict(X_test)
C:\Users\robkr\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\sklearn\linear_model\_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1): STOP: TOTAL NO. of ITERATIONS REACHED LIMIT. Increase the number of iterations (max_iter) or scale the data as shown in: https://scikit-learn.org/stable/modules/preprocessing.html Please also refer to the documentation for alternative solver options: https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression n_iter_i = _check_optimize_result(
Results¶
In [ ]:
# Generate confusion matrix
class_labels = le.classes_
# Display confusion matrix using a heatmap
plt.figure(figsize=(8, 6))
sns.heatmap(confusion_matrix(y_test, y_pred), annot=True, fmt='d', cmap='Blues', xticklabels=class_labels, yticklabels=class_labels)
plt.xlabel('Predicted Labels')
plt.ylabel('Actual Labels')
plt.title('Confusion Matrix - AAPL Dataset')
plt.show()