Skip to content

Commit 0a4a214

Browse files
author
Cara Giovanetti
committed
remove debug options for faster compile
1 parent 6e36a5e commit 0a4a214

2 files changed

Lines changed: 23 additions & 37 deletions

File tree

ABCMB/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def run_cosmology_abbr(self, params : dict):
194194
"""
195195

196196
# let the user know the code is compiling
197-
print("\n")
197+
print("")
198198
print(" __")
199199
print(" / \\")
200200
print(" / \\")
@@ -207,7 +207,7 @@ def run_cosmology_abbr(self, params : dict):
207207
print("______/"+'\033[1m' +"/ / \\ \\| _ \\\\ \\___ ||\\/||| - )"+"\033[0m"+"/\\ ")
208208
print(" "+'\033[1m' +"/_/ \\_\\___/ \\____||| |||_-_)"+"\033[0m"+" \\/\\ is compiling...")
209209
print(" \\/\\")
210-
print("\n")
210+
print("")
211211

212212
PT, BG = self.get_PTBG(params)
213213
output = ()

time_tests.py

Lines changed: 21 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,46 @@
1-
from classy import Class
2-
31
import sys
42
sys.path.append('../')
5-
# sys.path.append('../JaxCMB')
6-
# print(sys.path)
7-
83
import os
9-
# os.environ.setdefault("JAX_PLATFORM_NAME", "cpu")
10-
# print(os.getcwd())
11-
12-
import sys
134

14-
# assert "jax" not in sys.modules, "jax already imported: you must restart your runtime"
15-
# os.environ['XLA_FLAGS'] = "--xla_force_host_platform_device_count=8"
165

176
import jax
187
print(jax.devices())
19-
# import jax
208
jax.config.update("jax_enable_x64", True)
21-
jax.config.update("jax_debug_nans", True)
229
from ABCMB.main import Model
2310
import ABCMB.spectrum as spectrum
2411
from scipy.interpolate import interp1d
2512
import jax.numpy as jnp
2613
import numpy as np
2714
import matplotlib.pyplot as plt
28-
import pytest
29-
import numpy as np
30-
np.seterr(all='raise')
15+
3116

3217
import time
3318

3419
h = 0.6762
3520

21+
model = Model(bbn_type='LINX')
22+
# model = Model()
23+
3624
for i in range(2):
3725
start=time.time()
3826
# ABCMB:
39-
params = {
40-
'h': jnp.asarray(h),
41-
'omega_cdm': jnp.asarray(0.1193),
42-
'omega_b': jnp.asarray(0.0225),
43-
'A_s': jnp.asarray(2.12424e-9),
44-
'n_s': jnp.asarray(0.9709),
45-
#'Neff': 3.044,
46-
#'Delta_Neff_init': jnp.asarray(0.),
47-
#'YHe': jnp.asarray(0.245),
48-
'TCMB0': jnp.asarray(2.34865418e-4),
49-
#'T_nu': jnp.asarray((4. / 11.)**(1. / 3.) * 2.34865418e-4),
50-
'N_ncdm': jnp.asarray(0.),
51-
'T_ncdm': jnp.asarray(0.71611),
52-
'm_ncdm': jnp.asarray(0.06)
53-
}
54-
55-
model = Model(bbn_type='LINX')
56-
#model = Model()
57-
27+
# params = {
28+
# 'h': jnp.asarray(h),
29+
# 'omega_cdm': jnp.asarray(0.1193),
30+
# 'omega_b': jnp.asarray(0.0225),
31+
# 'A_s': jnp.asarray(2.12424e-9),
32+
# 'n_s': jnp.asarray(0.9709),
33+
# #'Neff': 3.044,
34+
# #'Delta_Neff_init': jnp.asarray(0.),
35+
# #'YHe': jnp.asarray(0.245),
36+
# 'TCMB0': jnp.asarray(2.34865418e-4),
37+
# #'T_nu': jnp.asarray((4. / 11.)**(1. / 3.) * 2.34865418e-4),
38+
# # 'N_ncdm': jnp.asarray(0.),
39+
# # 'T_ncdm': jnp.asarray(0.71611),
40+
# # 'm_ncdm': jnp.asarray(0.06)
41+
# }
42+
params = {}
43+
5844
ell, ABC_Cl = model.run_cosmology(params)
5945

6046
print(ABC_Cl[0])

0 commit comments

Comments
 (0)