Skip to content

Commit 3a602d6

Browse files
committed
merge
2 parents 410bacf + fe580f4 commit 3a602d6

7 files changed

Lines changed: 276 additions & 361 deletions

File tree

abcmb/main.py

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from jax import jit, config, lax
1+
from jax import jit, config, lax, tree_util
22
import jax.numpy as jnp
33
from jaxtyping import Array
44
import numpy as np
@@ -59,7 +59,7 @@ class Model(eqx.Module):
5959
6060
Methods:
6161
--------
62-
run_cosmology : Compute CMB angular power spectra
62+
__call__ : Compute CMB angular power spectra
6363
get_PTBG : Get perturbation table and background cosmology
6464
get_BG : Get background cosmology
6565
add_derived_parameters : Compute derived parameters
@@ -148,7 +148,7 @@ def __init__(self,
148148

149149
# need this outside of the jit context
150150
# since we want LINX to run on CPU
151-
def run_cosmology(self, params : dict = {}):
151+
def __call__(self, params : dict = {}):
152152
"""
153153
Compute CMB angular power spectra for given parameters.
154154
@@ -168,8 +168,7 @@ def run_cosmology(self, params : dict = {}):
168168

169169

170170
full_params = self.add_derived_parameters(params)
171-
output, aux = self.run_cosmology_abbr(full_params)
172-
return output, aux
171+
return self.run_cosmology_abbr(full_params)
173172

174173
### JITTED OR JITTABLE FUNCTIONS ###
175174

@@ -205,30 +204,24 @@ def run_cosmology_abbr(self, params : dict):
205204
print('\\_____/ ')
206205
print("")
207206

207+
# Compute background and linear perturbations
208208
PT, BG = self.get_PTBG(params)
209-
output = ()
210-
aux = ()
211-
212-
if self.specs["output_Cl"]:
213-
Cls = self.SS.get_Cl(PT, BG, params)
214-
ells = self.SS.ells
215-
output += Cls
216-
aux += (ells,)
217-
218-
if self.specs["output_Pk"]:
219-
Pk = self.SS.Pk_lin(self.SS.k_axis_Pk_output, 0., PT, params)
220-
output += (Pk,)
221-
aux += (self.SS.k_axis_Pk_output,)
222-
223-
aux += (params,)
224-
225-
if self.specs["output_perturbations"]:
226-
aux += (PT,)
227209

228-
if self.specs["output_background"]:
229-
aux += (BG,)
210+
# Compute CMB power spectra
211+
Cls = self.SS.get_Cl(PT, BG, params)
212+
l = self.SS.ells
213+
214+
# Compute linear matter power spectrum
215+
Pk = self.SS.Pk_lin(self.SS.k_axis_Pk_output, 0., PT, params)
216+
k = self.SS.k_axis_Pk_output
217+
218+
# Package
219+
output = Output(
220+
Cls[0], Cls[1], Cls[2], Pk,
221+
l, k, BG, PT, params
222+
)
230223

231-
return output, aux
224+
return output
232225

233226
@eqx.filter_jit
234227
def get_PTBG(self, params : dict):
@@ -528,4 +521,24 @@ def add_derived_parameters(self, param_in : dict) -> dict:
528521
if key not in expected_keys:
529522
params[key] = jnp.array(value)
530523

531-
return params
524+
return params
525+
526+
class Output(eqx.Module):
527+
"""
528+
Object containing final and intermediate results from one cosmological simulation.
529+
Contains the power spectra (CMB & P(k)) as well as auxillary fields including
530+
the multipoles l for the Cls, wavenumbers k for P(k), background BG, perturbations PT, and
531+
a full list of parameters (input + derived) in the params dictionary.
532+
"""
533+
534+
# Power spectra
535+
ClTT : jnp.array
536+
ClTE : jnp.array
537+
ClEE : jnp.array
538+
Pk : jnp.array
539+
540+
l : jnp.array
541+
k : jnp.array
542+
BG : background.Background
543+
PT : perturbations.PerturbationTable
544+
params : dict

abcmb/model_specs.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,15 @@ def load_specs(input_specs):
1818
specs["input_tau_reion"] = input_specs.get("input_tau_reion", True)
1919

2020
### OUTPUT RELATED specs PARAMS ###
21-
specs["output_Cl"] = input_specs.get("output_Cl", True)
2221
specs["l_min"] = input_specs.get("l_min", 2)
2322
specs["l_max"] = input_specs.get("l_max", 2500)
2423
specs["lensing"] = input_specs.get("lensing", False)
25-
26-
specs["output_Pk"] = input_specs.get("output_Pk", True)
27-
specs["output_k_max"] = input_specs.get("output_k_max", 0.5)
28-
29-
specs["output_background"] = input_specs.get("output_background", False)
30-
specs["output_perturbations"] = input_specs.get("output_perturbations", False)
24+
specs["k_max"] = input_specs.get("k_max", 0.5)
3125

3226
### BBN ###
3327
specs["bbn_type"] = input_specs.get("bbn_type", "")
3428
specs["linx_reaction_net"] = input_specs.get("linx_reaction_net", "key_PRIMAT_2023")
3529

36-
### TODO: HYREX RELATED specs PARAMS ###
37-
3830
### Boltzmann Hierarchy Cutoffs ###
3931
specs["l_max_g"] = input_specs.get("l_max_g", 12)
4032
specs["l_max_pol_g"] = input_specs.get("l_max_pol_g", 10)
@@ -52,7 +44,6 @@ def load_specs(input_specs):
5244
specs["tau0_fid"] = input_specs.get("tau0_fid",1.418668e+04)
5345
specs["rs_rec_fid"] = input_specs.get("rs_rec_fid", 1.446279e+02)
5446

55-
5647
### Transfer integration k-grid resolution ###
5748
specs["k_transfer_linstep"] = input_specs.get("k_transfer_linstep", 4.5e-1)
5849
specs["k_transfer_logstep"] = input_specs.get("k_transfer_logstep", 170.)
@@ -167,9 +158,9 @@ def get_k_axis_perturbations(specs):
167158
i += 1
168159
ks[i] = k
169160

170-
# If the user wants P(k) and specified a k_max above the current, we should add these as well.
171-
if specs["output_Pk"] and k < specs["output_k_max"]:
172-
k_max = specs["output_k_max"]
161+
# If the user specified a k_max above the current, we should add these as well.
162+
if k < specs["k_max"]:
163+
k_max = specs["k_max"]
173164

174165
while k < k_max:
175166
step = 0.005
@@ -179,7 +170,7 @@ def get_k_axis_perturbations(specs):
179170
ks[i] = k
180171

181172
ks = ks[np.where(ks>0)]
182-
k_axis_Pk_output = ks[np.where(ks<=specs["output_k_max"])]
173+
k_axis_Pk_output = ks[np.where(ks<=specs["k_max"])]
183174

184175
return jnp.array(ks), jnp.array(k_axis_Pk_output)
185176

example_notebooks/ABCMB_Fluids.ipynb

Lines changed: 88 additions & 83 deletions
Large diffs are not rendered by default.

example_notebooks/ABCMB_basics.ipynb

Lines changed: 99 additions & 181 deletions
Large diffs are not rendered by default.

example_notebooks/ABCMB_with_LINX.ipynb

Lines changed: 36 additions & 48 deletions
Large diffs are not rendered by default.

pytests/accuracy_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,12 @@ def test_accuracy_checker(h = 0.6762):
110110

111111
# ABCMB
112112

113-
data, label = model.run_cosmology(params)
114-
ells = label[0]
113+
output = model(params)
114+
ells = output.l
115115

116-
ABC_tt = data[0]
117-
ABC_te = data[1]
118-
ABC_ee = data[2]
116+
ABC_tt = output.ClTT
117+
ABC_te = output.ClTE
118+
ABC_ee = output.ClEE
119119

120120
# Compare Cltt
121121
err_tt = abs(cltt-ABC_tt)/cltt
@@ -126,8 +126,8 @@ def test_accuracy_checker(h = 0.6762):
126126
print(err_ee.max())
127127

128128
# Compare P(k)
129-
ABC_Pk = data[3]
130-
ABC_k = label[1]
129+
ABC_Pk = output.Pk
130+
ABC_k = output.k
131131
CLA_Pk = np.vectorize(CLASS_Model.pk)(ABC_k, 0.)
132132
err_Pk = abs(CLA_Pk-ABC_Pk)/CLA_Pk
133133
print(err_Pk.max())

time_tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
# }
4646
params = {}
4747

48-
out, aux = model.run_cosmology(params)
48+
out, aux = model(params)
4949

5050
print(out[0])
5151
print(time.time()-start)

0 commit comments

Comments
 (0)