-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathInteractive Visualization of Decision Trees
81 lines (64 loc) · 2.02 KB
/
Interactive Visualization of Decision Trees
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
## https://towardsdatascience.com/interactive-visualization-of-decision-trees-with-jupyter-widgets-ca15dd312084
## Interactive Visualization of Decision Trees with Jupyter Widgets
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn import tree
from sklearn.datasets import load_wine
from IPython.display import SVG
from graphviz import Source
from IPython.display import display
# Load the wine dataset via scikit-learn's load_wine().
data = load_wine()
# Feature matrix.
X = data.data
# Target vector.
y = data.target
# Feature (column) names -- these are NOT class labels; they are passed to
# export_graphviz as feature_names below.
labels = data.feature_names
# Print the dataset description.
print(data.DESCR)

# Fit a decision tree with default hyper-parameters on the full dataset.
estimator = DecisionTreeClassifier()
estimator.fit(X, y)

# Export the fitted tree to DOT source and render it inline as SVG.
# Class names are derived from the fitted estimator (sorted class values)
# instead of being hard-coded, so the cell keeps working if the dataset
# is swapped for one with a different number of classes.
graph = Source(
    tree.export_graphviz(
        estimator,
        out_file=None,
        feature_names=labels,
        class_names=[str(c) for c in estimator.classes_],
        filled=True,
    )
)
display(SVG(graph.pipe(format="svg")))
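# Setup for the interactive version below (run these in a shell, not in Python):
#   pip install ipywidgets
#   jupyter nbextension enable --py widgetsnbextension
# Conda users can install the package with:
#   conda install -c conda-forge ipywidgets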
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn import tree
from sklearn.datasets import load_wine
from IPython.display import SVG
from graphviz import Source
from IPython.display import display
from ipywidgets import interactive
# Fetch the wine dataset once and pull out the pieces used further down.
data = load_wine()
# X is the feature matrix, y the target vector.
X, y = data.data, data.target
# Names of the feature columns (used to label the tree's split nodes).
labels = data.feature_names
def plot_tree(crit, split, depth, min_split, min_leaf=0.2):
    """Fit a decision tree on the module-level ``X``/``y`` and display it.

    The fitted tree is exported to Graphviz DOT, rendered to SVG and shown
    with ``IPython.display.display``.

    Args:
        crit: Forwarded to ``DecisionTreeClassifier`` as ``criterion``.
        split: Forwarded as ``splitter``.
        depth: Forwarded as ``max_depth``.
        min_split: Forwarded as ``min_samples_split``.
        min_leaf: Forwarded as ``min_samples_leaf``.

    Returns:
        The fitted ``DecisionTreeClassifier``.
    """
    estimator = DecisionTreeClassifier(
        random_state=0,  # fixed seed: identical settings redraw the same tree
        criterion=crit,
        splitter=split,
        max_depth=depth,
        min_samples_split=min_split,
        min_samples_leaf=min_leaf,
    )
    estimator.fit(X, y)

    # Derive class names from the fitted estimator rather than hard-coding
    # three labels, so the plot stays correct for any number of classes.
    class_names = [str(c) for c in estimator.classes_]

    graph = Source(
        tree.export_graphviz(
            estimator,
            out_file=None,
            feature_names=labels,
            class_names=class_names,
            filled=True,
        )
    )
    display(SVG(graph.pipe(format="svg")))
    return estimator
# Build the interactive control panel for plot_tree. Each keyword supplies the
# set or range of values offered for the matching plot_tree parameter;
# ipywidgets chooses the concrete widget type from each abbreviation.
inter = interactive(
    plot_tree,
    crit=["gini", "entropy"],
    split=["best", "random"],
    depth=[1, 2, 3, 4],
    min_split=(0.1, 1),
    min_leaf=(0.1, 0.5),
)
# Show the widgets together with the rendered tree output.
display(inter)