|
1 | | -from classy import Class |
2 | | - |
3 | 1 | import sys |
4 | 2 | sys.path.append('../') |
5 | | -# sys.path.append('../JaxCMB') |
6 | | -# print(sys.path) |
7 | | - |
8 | 3 | import os |
9 | | -# os.environ.setdefault("JAX_PLATFORM_NAME", "cpu") |
10 | | -# print(os.getcwd()) |
11 | | - |
12 | | -import sys |
13 | 4 |
|
14 | | -# assert "jax" not in sys.modules, "jax already imported: you must restart your runtime" |
15 | | -# os.environ['XLA_FLAGS'] = "--xla_force_host_platform_device_count=8" |
16 | 5 |
|
17 | 6 | import jax |
18 | 7 | print(jax.devices()) |
19 | | -# import jax |
20 | 8 | jax.config.update("jax_enable_x64", True) |
21 | | -jax.config.update("jax_debug_nans", True) |
22 | 9 | from ABCMB.main import Model |
23 | 10 | import ABCMB.spectrum as spectrum |
24 | 11 | from scipy.interpolate import interp1d |
25 | 12 | import jax.numpy as jnp |
26 | 13 | import numpy as np |
27 | 14 | import matplotlib.pyplot as plt |
28 | | -import pytest |
29 | | -import numpy as np |
30 | | -np.seterr(all='raise') |
| 15 | + |
31 | 16 |
|
32 | 17 | import time |
33 | 18 |
|
34 | 19 | h = 0.6762 |
35 | 20 |
|
| 21 | +model = Model(bbn_type='LINX') |
| 22 | +# model = Model() |
| 23 | + |
36 | 24 | for i in range(2): |
37 | 25 | start=time.time() |
38 | 26 | # ABCMB: |
39 | | - params = { |
40 | | - 'h': jnp.asarray(h), |
41 | | - 'omega_cdm': jnp.asarray(0.1193), |
42 | | - 'omega_b': jnp.asarray(0.0225), |
43 | | - 'A_s': jnp.asarray(2.12424e-9), |
44 | | - 'n_s': jnp.asarray(0.9709), |
45 | | - #'Neff': 3.044, |
46 | | - #'Delta_Neff_init': jnp.asarray(0.), |
47 | | - #'YHe': jnp.asarray(0.245), |
48 | | - 'TCMB0': jnp.asarray(2.34865418e-4), |
49 | | - #'T_nu': jnp.asarray((4. / 11.)**(1. / 3.) * 2.34865418e-4), |
50 | | - 'N_ncdm': jnp.asarray(0.), |
51 | | - 'T_ncdm': jnp.asarray(0.71611), |
52 | | - 'm_ncdm': jnp.asarray(0.06) |
53 | | - } |
54 | | - |
55 | | - model = Model(bbn_type='LINX') |
56 | | - #model = Model() |
57 | | - |
| 27 | + # params = { |
| 28 | + # 'h': jnp.asarray(h), |
| 29 | + # 'omega_cdm': jnp.asarray(0.1193), |
| 30 | + # 'omega_b': jnp.asarray(0.0225), |
| 31 | + # 'A_s': jnp.asarray(2.12424e-9), |
| 32 | + # 'n_s': jnp.asarray(0.9709), |
| 33 | + # #'Neff': 3.044, |
| 34 | + # #'Delta_Neff_init': jnp.asarray(0.), |
| 35 | + # #'YHe': jnp.asarray(0.245), |
| 36 | + # 'TCMB0': jnp.asarray(2.34865418e-4), |
| 37 | + # #'T_nu': jnp.asarray((4. / 11.)**(1. / 3.) * 2.34865418e-4), |
| 38 | + # # 'N_ncdm': jnp.asarray(0.), |
| 39 | + # # 'T_ncdm': jnp.asarray(0.71611), |
| 40 | + # # 'm_ncdm': jnp.asarray(0.06) |
| 41 | + # } |
| 42 | + params = {} |
| 43 | + |
58 | 44 | ell, ABC_Cl = model.run_cosmology(params) |
59 | 45 |
|
60 | 46 | print(ABC_Cl[0]) |
|
0 commit comments