-
Notifications
You must be signed in to change notification settings - Fork 0
/
logistic_regression.py
48 lines (32 loc) · 1.09 KB
/
logistic_regression.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from sklearn.linear_model import LogisticRegression
import seaborn as sns
import matplotlib.pyplot as plt
from data_processing import process_data
from time import perf_counter
# --- Load the dataset and obtain the train/test partitions -----------------
print("Preparing data...")
X_train, X_test, y_train, y_test, categories_mapping = process_data(
    "star_classification.csv"
)

# --- Fit the classifier; only the call to fit() is inside the timed span ---
print("Training model...")
model = LogisticRegression(max_iter=1500)
fit_started_at = perf_counter()
model.fit(X_train, y_train)
fit_seconds = perf_counter() - fit_started_at
print(f"Model took {fit_seconds:.2f} seconds to train.")

# --- Evaluate on the held-out partition -------------------------------------
y_pred = model.predict(X_test)
test_accuracy = model.score(X_test, y_test)
print(f"Test Accuracy: {test_accuracy * 100}%")

# Attach the true labels for plotting. This happens only after predict() and
# score() have run, so the extra column is never passed to the model.
# NOTE(review): assumes X_test supports column assignment (presumably a pandas
# DataFrame) -- confirm against process_data.
X_test["Class"] = y_test
def plot_data(x_ax, y_ax):
    """Scatter actual and predicted values against one column of the test set.

    Reads the module-level ``X_test``, ``y_test`` and ``y_pred`` and displays
    the figure via ``plt.show()``.

    Args:
        x_ax: Key of the ``X_test`` column drawn on the horizontal axis.
        y_ax: Key of the ``X_test`` column drawn on the vertical axis for the
            first series; also used as the y-axis label.
    """
    horizontal = X_test[x_ax]
    vertical = X_test[y_ax]

    # First series: the test data itself, coloured by the true labels and
    # outlined in black.
    plt.scatter(horizontal, vertical, c=y_test, edgecolor="k")

    # Second series: red crosses placed at the predicted values. Note that
    # y_pred is always used for the vertical position, whatever y_ax is.
    plt.scatter(horizontal, y_pred, c="red", marker="x")

    plt.xlabel(x_ax)
    plt.ylabel(y_ax)
    plt.show()
# Draw one actual-vs-predicted scatter per feature column, each against the
# "Class" column that was attached to X_test.
for feature in ("u", "g", "r", "i", "z", "redshift"):
    plot_data(feature, "Class")

# Pairwise scatter matrix of the test set, coloured by "Class".
sns.pairplot(X_test, hue="Class")
# pairplot only builds the figure; when this file is run as a plain script the
# figure is never displayed unless it is shown explicitly.
plt.show()