
Commit ebc55ea

Author: Richard Michael (committed)

added optimization and predictive unc to KNN

skopt k selection; variance across predictions

1 parent e171033 · commit ebc55ea

File tree

4 files changed: +46 -13 lines changed


__init__.py

Whitespace-only changes.

algorithm_factories.py

+1 -1

@@ -17,7 +17,7 @@ def UncertainRFFactory(representation, alphabet):
     return UncertainRandomForest()


 def KNNFactory(representation, alphabet):
-    return KNN()
+    return KNN(optimize=True)


 optimize = True
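For orientation, here is a minimal usage sketch of the updated factory, assuming the repository modules are importable and using hypothetical toy data (the representation and alphabet arguments are accepted by KNNFactory but not used when constructing the KNN). With optimize=True, the skopt-based k selection runs inside train(); this is a sketch, not the repository's own test code:

import numpy as np
from algorithm_factories import KNNFactory  # assumes the repo root is on PYTHONPATH

# Hypothetical toy data; KNN.train asserts that Y has shape (n, 1).
X = np.random.rand(200, 8)
Y = np.random.rand(200, 1)

model = KNNFactory(representation=None, alphabet=None)  # returns KNN(optimize=True)
model.train(X, Y)             # runs gp_minimize over n_neighbors (opt_budget=100 calls, so this takes a while)
pred, unc = model.predict(X)  # both arrays have shape (200, 1)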

algorithms/KNN.py

+39 -6

@@ -1,26 +1,59 @@
 import numpy as np
+from typing import Tuple
 from sklearn.neighbors import KNeighborsRegressor
+from sklearn.model_selection import cross_val_score
 from algorithms.abstract_algorithm import AbstractAlgorithm
+from skopt.space import Integer
+from skopt.utils import use_named_args
+from skopt import gp_minimize


 class KNN(AbstractAlgorithm):
-    def __init__(self):
+    def __init__(self, optimize: bool=False, k_max: int=100, opt_budget: int=100, seed=42) -> None:
         self.model = None
-        self.optimize = False
+        self.optimize = optimize
+        self.seed = seed
+        if self.optimize:
+            self.k_max = k_max
+            self.opt_budget = opt_budget
+            self.opt_space = [
+                Integer(1, self.k_max, name="n_neighbors"),
+            ]

-    def get_name(self):
+    def get_name(self) -> str:
         return "KNN"

-    def train(self, X, Y):
+    def train(self, X: np.ndarray, Y: np.ndarray) -> None:
         assert(Y.shape[1] == 1)
         self.model = KNeighborsRegressor(n_neighbors=int(np.ceil(0.3*len(X))), n_jobs=-1)  # use all processors
         Y = Y.squeeze() if Y.shape[0] > 1 else Y
+        if self.optimize:
+            self.k_max = int(len(X))  # all data is maximal possible
+            @use_named_args(self.opt_space)
+            def _opt_objective(**params):
+                self.model.set_params(**params)
+                return -np.mean(cross_val_score(self.model, X, Y, cv=5, n_jobs=-1, scoring="neg_mean_absolute_error"))
+            res_gp = gp_minimize(_opt_objective, self.opt_space, n_calls=self.opt_budget, random_state=self.seed)
+            print(f"Score: {res_gp.fun}")
+            print(f"Parameters: k={res_gp.x[0]}")
+            self.model = KNeighborsRegressor(n_neighbors=res_gp.x[0], n_jobs=-1)
         self.model.fit(X, Y)

-    def predict(self, X):
+    def predict(self, X) -> Tuple[np.array, np.array]:
+        """
+        Returns:
+            pred - model predictions
+            unc - model variance as E[(f(x) - E[f(x)])**2]
+        """
         pred = self.model.predict(X).reshape(-1, 1)
-        unc = np.zeros(pred.shape)
+        unc = np.mean(np.square(pred-np.mean(pred)), axis=1).reshape(-1, 1)
+        assert pred.shape == unc.shape
         return pred, unc

     def predict_f(self, X: np.ndarray):
         return self.predict(X)
+
+
+
+
+
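A note on the uncertainty term: since pred has shape (n, 1), the committed unc = np.mean(np.square(pred - np.mean(pred)), axis=1) reduces to each prediction's squared deviation from the mean prediction over the query batch. If a per-query spread across the k nearest neighbors' training targets were wanted instead, a minimal sketch could look like the following (an alternative reading of "variance across predictions", not the committed method; the helper name and toy data are hypothetical):

import numpy as np
from sklearn.neighbors import KNeighborsRegressor

def neighbor_target_variance(model: KNeighborsRegressor, Y_train: np.ndarray, X: np.ndarray) -> np.ndarray:
    # Variance of each query point's k nearest neighbors' targets: E[(y - E[y])**2] across neighbors.
    _, idx = model.kneighbors(X)               # (n_queries, k) indices into the training set
    neighbor_targets = Y_train.squeeze()[idx]  # (n_queries, k) neighbor target values
    return np.var(neighbor_targets, axis=1).reshape(-1, 1)

# Toy usage:
X_train, Y_train = np.random.rand(100, 8), np.random.rand(100, 1)
knn = KNeighborsRegressor(n_neighbors=10).fit(X_train, Y_train.squeeze())
unc = neighbor_target_variance(knn, Y_train, np.random.rand(5, 8))  # shape (5, 1)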

run_experiments.py

+6 -6

@@ -10,20 +10,20 @@


 datasets = ["MTH3", "TIMB", "CALM", "1FQG", "UBQT", "BRCA", "TOXI"] # "MTH3", "TIMB", "CALM", "1FQG", "UBQT", "BRCA", "TOXI"
-representations = [EVE_DENSITY, EVE, TRANSFORMER, ONE_HOT, ESM] # VAE_AUX, VAE_RAND, TRANSFORMER, VAE, ONE_HOT, ESM, EVE, VAE_AUX EXTRA 1D rep: VAE_DENSITY
+# datasets = ["TOXI"] # "MTH3", "TIMB", "CALM", "1FQG", "UBQT", "BRCA", "TOXI"
+representations = [TRANSFORMER, ONE_HOT, ESM, EVE, EVE_DENSITY] # VAE_AUX, VAE_RAND, TRANSFORMER, VAE, ONE_HOT, ESM, EVE, VAE_AUX EXTRA 1D rep: VAE_DENSITY
 MOCK = False
 # Protocols: RandomSplitterFactory, BlockSplitterFactory, PositionalSplitterFactory, BioSplitterFactory, FractionalSplitterFactory
 protocol_factories = [RandomSplitterFactory, PositionalSplitterFactory]
+# protocol_factories = [PositionalSplitterFactory]
 # protocol_factories = [FractionalSplitterFactory]
-# protocol_factories = [WeightedTaskSplitterFactory]
 # protocol_factories = [BioSplitterFactory("TOXI", 1, 2), BioSplitterFactory("TOXI", 2, 2), BioSplitterFactory("TOXI", 2, 3), BioSplitterFactory("TOXI", 3, 3), BioSplitterFactory("TOXI", 3, 4)]
 # [BioSplitterFactory("TOXI", 1, 2), BioSplitterFactory("TOXI", 2, 2), BioSplitterFactory("TOXI", 2, 3), BioSplitterFactory("TOXI", 3, 3), BioSplitterFactory("TOXI", 3, 4)]:

 # Methods: # KNNFactory, RandomForestFactory, UncertainRFFactory, GPSEFactory, GPLinearFactory, GPMaternFactory
 # method_factories = [get_key_for_factory(f) for f in [KNNFactory, RandomForestFactory]]
-method_factories = [get_key_for_factory(f) for f in [KNNFactory, RandomForestFactory, UncertainRFFactory, GPSEFactory, GPLinearFactory, GPMaternFactory]]
+method_factories = [get_key_for_factory(f) for f in [KNNFactory]]

-# TODO: rerun with KNN and RF for sanity check after data-load refactor:
 experiment_iterator = product(datasets, representations, protocol_factories, method_factories)
 def run_experiments():
     for dataset, representation, protocol_factory, factory_key in experiment_iterator:
@@ -76,9 +76,9 @@ def run_augmentation_experiments():


 if __name__ == "__main__":
-    # run_experiments()
+    run_experiments() # TODO: toxi all
     # ABLATION STUDY: (dim-reduction, augmentation, threshold):
     # run_dim_reduction_experiments()
-    run_augmentation_experiments()
+    # run_augmentation_experiments()
     #run_threshold_experiments()
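The driver enumerates the Cartesian product of datasets, representations, protocol factories, and method factories; the loop body is not shown in this diff, but the enumeration pattern behind experiment_iterator is plain itertools.product, as in this sketch with hypothetical stand-in values:

from itertools import product

# Hypothetical stand-ins for the constants and factories defined in run_experiments.py
datasets = ["1FQG", "UBQT"]
representations = ["ONE_HOT", "ESM"]
protocol_factories = ["RandomSplitterFactory"]
method_factories = ["KNNFactory"]

experiment_iterator = product(datasets, representations, protocol_factories, method_factories)
for dataset, representation, protocol_factory, factory_key in experiment_iterator:
    print(dataset, representation, protocol_factory, factory_key)  # one experiment configuration per tuple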
