Skip to content

Commit 4d50a88

Browse files
committed
compare against mdl-rs
1 parent 8661a41 commit 4d50a88

25 files changed

+2567
-24
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@ env
1919
clean.sh
2020
cancel_jobs.py
2121
.idea
22+
**tmp*

fmri/Linear_Regression_Synthetic.ipynb

+714
Large diffs are not rendered by default.

fmri/analyze_fmri.ipynb

+200-13
Large diffs are not rendered by default.

fmri/run.py

+32-10
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import matplotlib.pyplot as plt
66
import numpy as np
77
from scipy import ndimage as ndi
8+
import sys
89
from skimage import data
910
import pickle as pkl
1011
from skimage.util import img_as_float
@@ -14,7 +15,9 @@
1415
from scipy.io import loadmat
1516
from copy import deepcopy
1617
from skimage.filters import gabor_kernel
17-
from sklearn.linear_model import RidgeCV
18+
from sklearn.linear_model import RidgeCV, ARDRegression
19+
sys.path.append('../lib/pymdlrs')
20+
from src.ulnml.least_square_regression import RidgeULNML
1821
import seaborn as sns
1922
from scipy.io import loadmat
2023
import numpy.linalg as npl
@@ -69,9 +72,9 @@ def get_roi_and_idx(run):
6972

7073
# fit linear models
7174
use_sigmas = False
72-
use_small = True
75+
use_small = False
7376
out_dir = '/scratch/users/vision/data/gallant/vim_2_crcns'
74-
save_dir = oj(out_dir, 'dec2_small_1')
77+
save_dir = oj(out_dir, 'dec13_baselines')
7578
suffix = '_feats' # _feats, '' for pixels
7679
norm = '_norm' # ''
7780
print('saving to', save_dir)
@@ -141,7 +144,20 @@ def get_roi_and_idx(run):
141144

142145
# reg values to try
143146
reg_params = np.logspace(3, 6, 20).round().astype(int)
144-
147+
148+
# fit ard + mdl-rs
149+
baselines = {}
150+
for model_type, model_name in zip([ARDRegression, RidgeULNML], ['ard', 'mdl-rs']):
151+
model = model_type()
152+
model.fit(X_train, y_train)
153+
preds_train = model.predict(X_train)
154+
preds = model.predict(X_test)
155+
baselines[f'{model_name}_mse_train'] = metrics.mean_squared_error(y_train, preds_train)
156+
baselines[f'{model_name}_r2_train'] = metrics.r2_score(y_train, preds_train)
157+
baselines[f'{model_name}_mse'] = metrics.mean_squared_error(y_test, preds)
158+
baselines[f'{model_name}_r2'] = metrics.r2_score(y_test, preds)
159+
baselines[f'{model_name}_corr'] = np.corrcoef(y_test, preds)[0, 1]
160+
145161
# fit ridge cv
146162
m = RidgeCV(alphas=reg_params, store_cv_values=True)
147163
m.fit(X_train, y_train)
@@ -153,6 +169,7 @@ def get_roi_and_idx(run):
153169
r2 = metrics.r2_score(y_test, preds)
154170
corr = np.corrcoef(y_test, preds)[0, 1]
155171
print('RidgeCV corr', corr)
172+
156173

157174
# fit mdl comp
158175
mdl_comp_opt = 1e10
@@ -199,23 +216,28 @@ def get_roi_and_idx(run):
199216
results = {
200217
'roi': roi,
201218
'model': m,
202-
'lambda_opt': lambda_opt,
203-
'theta_opt': theta_opt,
204-
'mdl_comp_opt': mdl_comp_opt,
205-
'mse_test_mdl': mse_test_mdl,
206-
'cv_values': m.cv_values_,
207219
'snr': snr,
208220
'lambda_best': m.alpha_,
209221
'n_train': n_train,
210222
'n_test': num_test,
211223
'd': d,
212224
'y_norm': y_norm,
225+
'idx': i,
226+
227+
# mdl stuff
228+
'lambda_opt': lambda_opt,
229+
'theta_opt': theta_opt,
230+
'mdl_comp_opt': mdl_comp_opt,
231+
'mse_test_mdl': mse_test_mdl,
232+
233+
# cv stuff
234+
'cv_values': m.cv_values_,
213235
'mse_train': mse_train,
214236
'r2_train': r2_train,
215237
'mse_test': mse,
216238
'r2_test': r2,
217239
'corr_test': corr,
218-
'idx': i,
219240
**r,
241+
**baselines,
220242
}
221243
pkl.dump(results, open(oj(save_dir, f'ridge_{i}.pkl'), 'wb'))

lib/pymdlrs/.gitignore

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
env/
12+
build/
13+
develop-eggs/
14+
dist/
15+
downloads/
16+
eggs/
17+
.eggs/
18+
lib/
19+
lib64/
20+
parts/
21+
sdist/
22+
var/
23+
wheels/
24+
*.egg-info/
25+
.installed.cfg
26+
*.egg
27+
28+
# PyInstaller
29+
# Usually these files are written by a python script from a template
30+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
31+
*.manifest
32+
*.spec
33+
34+
# Installer logs
35+
pip-log.txt
36+
pip-delete-this-directory.txt
37+
38+
# Unit test / coverage reports
39+
htmlcov/
40+
.tox/
41+
.coverage
42+
.coverage.*
43+
.cache
44+
nosetests.xml
45+
coverage.xml
46+
*.cover
47+
.hypothesis/
48+
49+
# Translations
50+
*.mo
51+
*.pot
52+
53+
# Django stuff:
54+
*.log
55+
local_settings.py
56+
57+
# Flask stuff:
58+
instance/
59+
.webassets-cache
60+
61+
# Scrapy stuff:
62+
.scrapy
63+
64+
# Sphinx documentation
65+
docs/_build/
66+
67+
# PyBuilder
68+
target/
69+
70+
# Jupyter Notebook
71+
.ipynb_checkpoints
72+
73+
# pyenv
74+
.python-version
75+
76+
# celery beat schedule file
77+
celerybeat-schedule
78+
79+
# SageMath parsed files
80+
*.sage.py
81+
82+
# dotenv
83+
.env
84+
85+
# virtualenv
86+
.venv
87+
venv/
88+
ENV/
89+
90+
# Spyder project settings
91+
.spyderproject
92+
.spyproject
93+
94+
# Rope project settings
95+
.ropeproject
96+
97+
# mkdocs documentation
98+
/site
99+
100+
# mypy
101+
.mypy_cache/
102+
103+
/memory*
104+
/tmp*
105+
/asset
106+
*.pdf
107+
108+
/.idea

lib/pymdlrs/pyglassobind/LICENSE

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2017 koheimiya
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

lib/pymdlrs/pyglassobind/README.md

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
## Installation
2+
* Install Eigen
3+
* Install pybind11 via `pip install pybind11`
4+
* Specify the path to the Eigen's headers and compiler at the top of `setup.py`.
5+
* Execute `pip install .` in the target environment
6+
* Voila!

0 commit comments

Comments
 (0)