-
Notifications
You must be signed in to change notification settings - Fork 18
/
p01-feature-splits.py
95 lines (73 loc) · 2.53 KB
/
p01-feature-splits.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Decision Trees: Feature Splits
#%%
# Python typing introduced in 3.5: https://docs.python.org/3/library/typing.html
from typing import List
# As of Python 3.7, this exists! https://www.python.org/dev/peps/pep-0557/
from dataclasses import dataclass
# My python file (very limited for now, but we will build up shared functions)
from shared import TODO
#%%
# Let's define a really simple class with two fields:
@dataclass
class DataPoint:
    """A single labelled observation: a temperature and whether water was frozen."""

    # Temperature reading; the module's sample data is in degrees Fahrenheit.
    temperature: float
    # Observed label: True when the water was frozen at this temperature.
    frozen: bool

    def secret_answer(self) -> bool:
        """Return the label the hidden rule predicts (frozen at or below 32)."""
        freezing_point = 32
        return self.temperature <= freezing_point

    def clone(self) -> "DataPoint":
        """Return a new, independent DataPoint carrying the same field values."""
        return DataPoint(temperature=self.temperature, frozen=self.frozen)
# Sample observations, in degrees Fahrenheit (sorry).
data = [
    # Vermont temperatures, all at or below freezing: frozen=True.
    *(DataPoint(temp, True) for temp in (0, -2, 10, 11, 6, 28, 31)),
    # Warm temperatures, all above freezing: frozen=False.
    *(DataPoint(temp, False) for temp in (33, 45, 76, 60, 34, 98.6)),
]
def is_water_frozen(temperature: float) -> bool:
    """
    Return True when water at ``temperature`` (degrees F) is frozen.

    This is how we **should** implement it: the ground-truth rule the
    hand-labelled data is checked against.
    """
    freezing_point_f = 32
    return temperature <= freezing_point_f
# Make sure the data I invented is actually correct: every hand-written label
# must agree with the ground-truth rule. The message names the offending
# point so a typo in the data is easy to find.
for d in data:
    assert d.frozen == is_water_frozen(d.temperature), f"mislabeled data point: {d!r}"
def find_candidate_splits(data: List[DataPoint]) -> List[float]:
    """
    Return candidate temperature thresholds for splitting ``data``.

    Each candidate is the midpoint between two adjacent distinct temperatures
    after sorting. Every candidate therefore lies strictly between two observed
    values and separates the points into two non-empty groups. Duplicates are
    collapsed first because a threshold between equal values separates nothing.

    Returns an empty list when ``data`` has fewer than two distinct temperatures.
    """
    temperatures = sorted({point.temperature for point in data})
    # Pair each value with its successor: (t0, t1), (t1, t2), ...
    midpoints = [
        (low + high) / 2.0 for low, high in zip(temperatures, temperatures[1:])
    ]
    return midpoints
def gini_impurity(points: List[DataPoint]) -> float:
    """
    Return the Gini impurity of ``points`` with respect to the ``frozen`` label.

    The standard version of Gini impurity sums over the classes::

        sum over classes k of  p_k * (1 - p_k)

    For binary Gini impurity (just two classes) we can simplify, because
    ``1 - p_ice == p_water``::

        p_ice * p_water + p_water * p_ice  ==  2 * p_ice * p_water

    This is not really a huge difference, so the general form is kept below.

    The result is 0.0 for a pure group (all one class). For two classes the
    maximum is 0.5, at an even mix. An empty group is treated as pure and
    returns 0.0 instead of dividing by zero.
    """
    if not points:
        # No points means no mixing; also avoids ZeroDivisionError below.
        return 0.0
    p_ice = sum(1 for x in points if x.frozen) / len(points)
    p_water = 1.0 - p_ice
    return p_ice * (1 - p_ice) + p_water * (1 - p_water)
def impurity_of_split(points: List[DataPoint], split: float) -> float:
    """
    Score splitting ``points`` at the temperature threshold ``split``.

    Points with ``temperature < split`` go to the "smaller" side and all others
    go to the "bigger" side. The score is the sum of the two sides' Gini
    impurities, so 0.0 means the split separates the classes perfectly. An
    empty side contributes nothing.

    NOTE(review): this is an unweighted sum. Standard CART weights each side's
    impurity by its share of the points. Consider that if splits isolating a
    single point start scoring suspiciously well.
    """
    smaller: List[DataPoint] = []
    bigger: List[DataPoint] = []
    for point in points:
        if point.temperature < split:
            smaller.append(point)
        else:
            bigger.append(point)
    # Skip empty sides so a threshold outside the data range cannot divide by zero.
    return sum((gini_impurity(side) for side in (smaller, bigger) if side), 0.0)
if __name__ == "__main__":
    # Baseline impurity of the whole data set, then of a subset known to be pure.
    print("Initial Impurity: ", gini_impurity(data))
    print("Impurity of first-six (all True): ", gini_impurity(data[:6]))
    print("")

    # Try each candidate threshold in order; stop at the first perfect split.
    for split in find_candidate_splits(data):
        score = impurity_of_split(data, split)
        print(f"splitting at {split} gives us impurity {score}")
        if score == 0.0:
            break