Skip to content

Commit d93f74d

Browse files
committed
init
1 parent 0c58821 commit d93f74d

File tree

1 file changed

+90
-0
lines changed

1 file changed

+90
-0
lines changed

basic-ml.py

+90
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import streamlit as st
2+
from sklearn import datasets
3+
import numpy as np
4+
5+
from sklearn.model_selection import train_test_split
6+
from sklearn.neighbors import KNeighborsClassifier
7+
from sklearn.svm import SVC
8+
from sklearn.ensemble import RandomForestClassifier
9+
from sklearn.metrics import accuracy_score
10+
from sklearn.decomposition import PCA
11+
from matplotlib import pyplot as plt
12+
13+
# frontend ui section
14+
st.title('Basic ML Visualizer')
15+
16+
st.write("""
17+
## Exploring different classifiers
18+
""")
19+
20+
d_name = st.sidebar.selectbox("Select dataset", ("iris", "breast cancer", "wine"))
21+
c_name = st.sidebar.selectbox("Select classifer", ("KNN", "SVM", "RF"))
22+
23+
# classifier building section
24+
def get_dataset(d_name):
25+
if d_name == "iris":
26+
data = datasets.load_iris()
27+
elif d_name == "breast cancer":
28+
data = datasets.load_breast_cancer()
29+
else:
30+
data = datasets.load_wine()
31+
X = data.data
32+
y = data.target
33+
return X,y
34+
35+
X,y = get_dataset(d_name)
36+
st.write("dataset shape: ", X.shape)
37+
st.write("no. of classes: ", len(np.unique(y)))
38+
39+
def add_parameters_ui(c_name):
40+
params = dict()
41+
if c_name == "KNN":
42+
K = st.sidebar.slider("K", 1, 15)
43+
params["K"] = K
44+
elif c_name == "SVM":
45+
C = st.sidebar.slider("C", 0.01, 10.0)
46+
params["C"] = C
47+
else:
48+
max_depth = st.sidebar.slider("max_depth", 2, 15)
49+
n_estimators = st.sidebar.slider("n_estimators", 1, 100)
50+
params["max_depth"] = max_depth
51+
params["n_estimators"] = n_estimators
52+
return params
53+
54+
params = add_parameters_ui(c_name)
55+
56+
def get_classifier(c_name, params):
57+
if c_name == "KNN":
58+
clf = KNeighborsClassifier(n_neighbors=params["K"])
59+
elif c_name == "SVM":
60+
clf = SVC(C=params["C"])
61+
else:
62+
clf = RandomForestClassifier(n_estimators=params["n_estimators"], max_depth=params["max_depth"], random_state=42)
63+
return clf
64+
65+
clf = get_classifier(c_name, params)
66+
67+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
68+
69+
clf.fit(X_train, y_train)
70+
y_pred = clf.predict(X_test)
71+
72+
acc = accuracy_score(y_test, y_pred)
73+
st.write(f"Classifier = {c_name}")
74+
st.write(f"Accuracy = {acc}")
75+
76+
# adding 2 PCA components
77+
pca = PCA(2)
78+
X_projected = pca.fit_transform(X)
79+
80+
x1 = X_projected[:, 0]
81+
x2 = X_projected[:, 1]
82+
83+
# plotting section
84+
fig = plt.figure()
85+
plt.scatter(x1, x2, c=y, alpha=0.7, cmap="viridis")
86+
plt.xlabel("Principle component 1")
87+
plt.ylabel("Principle component 2")
88+
plt.colorbar()
89+
90+
st.pyplot(fig)

0 commit comments

Comments
 (0)