Skip to content

Commit 618e886

Browse files
committed
Merge branch 'CPU_compile_easy'
2 parents 6a7b3d3 + e185196 commit 618e886

11 files changed

Lines changed: 794 additions & 252 deletions

File tree

ABCMB/cosmology.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,29 +88,31 @@ def __init__(self,params, species_list, RM):
8888
# self.params = params
8989
self.species_list = species_list
9090

91-
self.tau_tab = self._tabulate_conformal_time(params)
92-
self.tau0 = self.tau(0.)
91+
self.tau_tab = jax.device_put(self._tabulate_conformal_time(params),jax.devices('gpu')[0])
92+
self.tau0 = jax.device_put(self.tau(0.),jax.devices('gpu')[0])
9393

9494
### RECOMBINATION RELATED ###
9595

9696
# Run hyrex to tabulate recombination output
97+
# TODO: get this running on CPU. Will require refactorization.
9798
self.xe_tab, self.lna_xe_tab, self.Tm_tab, self.lna_Tm_tab = RM((self,params),z_reion = params["z_reion"],
9899
Delta_z_reion = params["Delta_z_reion"],
99100
z_reion_He = params["z_reion_He"],
100101
Delta_z_reion_He = params["Delta_z_reion_He"])
101-
self.kappa_func = self._tabulate_optical_depth(params)
102+
103+
self.kappa_func = jax.device_put(self._tabulate_optical_depth(params),jax.devices('gpu')[0])
102104

103105
# Find approximate maximum of visibility function.
104106
lna_vals = jnp.linspace(-8.0, -4.0, 1500) # Decoupling should have happened at some time in this interval.
105107
vis_vals = vmap(self.visibility,in_axes=[0,None])(lna_vals, params)
106-
self.lna_rec = lna_vals[jnp.argmax(vis_vals)]
107-
self.lna_visibility_stop = lna_vals[jnp.argmin((vis_vals - 1.e-3)**2)]
108-
self.rA_rec = self.tau0 - self.tau(self.lna_rec)
108+
self.lna_rec = jax.device_put(lna_vals[jnp.argmax(vis_vals)],jax.devices('gpu')[0])
109+
self.lna_visibility_stop = jax.device_put(lna_vals[jnp.argmin((vis_vals - 1.e-3)**2)],jax.devices('gpu')[0])
110+
self.rA_rec = jax.device_put(self.tau0 - self.tau(self.lna_rec),jax.devices('gpu')[0])
109111

110112
# Find approximate early time when aH x tau_c = 0.008
111113
lna_vals = jnp.linspace(-15.0, -6.0, 5000)
112114
aH_tau_c_vals = vmap(self.aH,in_axes=[0,None])(lna_vals,params)*self.tau_c(lna_vals,params)
113-
self.lna_transfer_start = lna_vals[jnp.argmin((aH_tau_c_vals-0.008)**2)]
115+
self.lna_transfer_start = jax.device_put(lna_vals[jnp.argmin((aH_tau_c_vals-0.008)**2)],jax.devices('gpu')[0])
114116

115117

116118
def rho_tot(self, lna, params):

ABCMB/linx/abundances.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ class AbundanceModel(eqx.Module):
4141
species_mass : list
4242
Mass of each species.
4343
"""
44+
4445
nuclear_net : nucl.NuclearRates
4546
weak_rates : wr.WeakRates
46-
species_dict : dict
47+
species_dict : dict #= eqx.field(static=True)
4748
species_Z : list
4849
species_N : list
4950
species_A : list

ABCMB/linx/nuclear.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,19 @@ class NuclearRates(eqx.Module):
3434
bkwrd_rate : dict of callable
3535
Dictionary of backward rate parameter for each reaction.
3636
"""
37-
37+
3838
max_i_species : int
39-
interp_type : str
39+
interp_type : str #= eqx.field(static=True)
4040
reactions : list
41-
reactions_names: list
41+
reactions_names: list #= eqx.field(static=True)
4242
in_states : dict
4343
out_states : dict
4444
frwrd_symmetry_fac : dict
4545
bkwrd_symmetry_fac : dict
46-
frwrd_rate_param : dict
47-
bkwrd_rate_param : dict
48-
frwrd_reaction_by_particle : dict
49-
bkwrd_reaction_by_particle : dict
46+
frwrd_rate_param : dict #= eqx.field(static=True)
47+
bkwrd_rate_param : dict #= eqx.field(static=True)
48+
frwrd_reaction_by_particle : dict #= eqx.field(static=True)
49+
bkwrd_reaction_by_particle : dict #= eqx.field(static=True)
5050

5151
def __init__(
5252
self, reactions=None, nuclear_net=None, interp_type='linear',
@@ -81,17 +81,17 @@ def __init__(
8181

8282
if reactions is not None:
8383

84-
self.reactions = reactions
84+
self.reactions = tuple(reactions)
8585

8686
elif nuclear_net == 'np_only':
8787
# No nuclear reactions. n<->p rates are always included.
8888

89-
self.reactions = []
89+
self.reactions = ()
9090
self.max_i_species = 2
9191

9292
else:
9393

94-
self.reactions = self.populate(nuclear_net)
94+
self.reactions = tuple(self.populate(nuclear_net))
9595

9696
if nuclear_net[:3] == 'key':
9797

@@ -111,7 +111,7 @@ def __init__(
111111
self.bkwrd_symmetry_fac = {}
112112
self.frwrd_rate_param = {}
113113
self.bkwrd_rate_param = {}
114-
self.reactions_names = []
114+
self.reactions_names = ()
115115

116116
self.frwrd_reaction_by_particle = {
117117
i:[] for i in range(self.max_i_species)
@@ -135,7 +135,7 @@ def __init__(
135135
if i in self.out_states[rxn.name]:
136136
self.bkwrd_reaction_by_particle[i].append(rxn.name)
137137

138-
self.reactions_names.append(rxn.name)
138+
self.reactions_names = self.reactions_names + (rxn.name,)
139139

140140

141141
@eqx.filter_jit

ABCMB/linx/reactions.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,18 @@ class Reaction(eqx.Module):
5151
5252
"""
5353

54-
name : str
55-
in_states : tuple
56-
out_states : tuple
57-
frwrd_symmetry_fac : float
58-
bkwrd_symmetry_fac : float
54+
name : str #= eqx.field(static=True)
55+
in_states : tuple #= eqx.field(static=True)
56+
out_states : tuple #= eqx.field(static=True)
57+
frwrd_symmetry_fac : float#= eqx.field(static=True)
58+
bkwrd_symmetry_fac : float #= eqx.field(static=True)
5959
alpha : float
6060
beta : float
6161
gamma : float
6262
T9_vec : list
6363
mu_median_vec : list
6464
expsigma_vec : list
65-
interp_type : str
65+
interp_type : str #= eqx.field(static=True)
6666
frwrd_rate_param_func : callable
6767

6868
def __init__(
@@ -140,17 +140,17 @@ def __init__(
140140
file_dir+'/data/nuclear_rates/'+spline_data,
141141
unpack=True
142142
)
143-
try:
144-
gpus = jax.devices('gpu')
145-
self.T9_vec = jax.device_put(self.T9_vec, device=gpus[0])
146-
self.mu_median_vec = jax.device_put(
147-
self.mu_media_vec, device=gpus[0]
148-
)
149-
self.expsigma_vec = jax.device_put(
150-
self.expsigma_vec, device=gpus[0]
151-
)
152-
except:
153-
pass
143+
# try:
144+
# gpus = jax.devices('gpu')
145+
# self.T9_vec = jax.device_put(self.T9_vec, device=gpus[0])
146+
# self.mu_median_vec = jax.device_put(
147+
# self.mu_median_vec, device=gpus[0]
148+
# )
149+
# self.expsigma_vec = jax.device_put(
150+
# self.expsigma_vec, device=gpus[0]
151+
# )
152+
# except:
153+
# pass
154154

155155
elif frwrd_rate_param_func is not None:
156156

ABCMB/linx/thermo.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -646,35 +646,6 @@ def p_massive_MB(T, mu, m, g):
646646
f_nue_ann_tab = np.loadtxt(file_dir+"/data/background/"+"nue_ann.txt")
647647
f_numu_ann_tab = np.loadtxt(file_dir+"/data/background/"+"numu_ann.txt")
648648

649-
try:
650-
gpus = devices('gpu')
651-
P_QED_tab = device_put(
652-
P_QED_tab, device=gpus[0]
653-
)
654-
dPdT_QED_tab = device_put(
655-
dPdT_QED_tab, device=gpus[0]
656-
)
657-
d2PdT2_QED_tab = device_put(
658-
d2PdT2_QED_tab , device=gpus[0]
659-
)
660-
661-
f_nue_scat_tab = device_put(
662-
f_nue_scat_tab, device=gpus[0]
663-
)
664-
f_numu_scat_tab = device_put(
665-
f_numu_scat_tab, device=gpus[0]
666-
)
667-
668-
f_nue_ann_tab = device_put(
669-
f_nue_ann_tab, device=gpus[0]
670-
)
671-
f_numu_ann_tab = device_put(
672-
f_numu_ann_tab, device=gpus[0]
673-
)
674-
except:
675-
pass
676-
677-
678649
######################
679650
# Standard EM Sector #
680651
######################

ABCMB/linx/weak_rates.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -93,33 +93,33 @@ def __init__(self,
9393
unpack = True
9494
)
9595

96-
try:
97-
gpus = jax.devices('gpu')
98-
self.T_nTOp_thermal_interval = jax.device_put(
99-
self.T_nTOp_thermal_interval, device=gpus[0]
100-
)
101-
self.L_nTOpCCRTh_res = jax.device_put(
102-
self.L_nTOpCCRTh_res, device=gpus[0]
103-
)
104-
except:
105-
pass
96+
# try:
97+
# gpus = jax.devices('gpu')
98+
# self.T_nTOp_thermal_interval = jax.device_put(
99+
# self.T_nTOp_thermal_interval, device=gpus[0]
100+
# )
101+
# self.L_nTOpCCRTh_res = jax.device_put(
102+
# self.L_nTOpCCRTh_res, device=gpus[0]
103+
# )
104+
# except:
105+
# pass
106106

107107
self.T_pTOn_thermal_interval, self.L_pTOnCCRTh_res = np.loadtxt(
108108
file_dir+"/data/weak_thermal_corrections/"
109109
+"pTOn_thermal_corrections_SBBN.txt",
110110
unpack = True
111111
)
112112

113-
try:
114-
gpus = jax.devices('gpu')
115-
self.T_pTOn_thermal_interval = jax.device_put(
116-
self.T_pTOn_thermal_interval, device=gpus[0]
117-
)
118-
self.L_pTOnCCRTh_res = jax.device_put(
119-
self.L_pTOnCCRTh_res, device=gpus[0]
120-
)
121-
except:
122-
pass
113+
# try:
114+
# gpus = jax.devices('gpu')
115+
# self.T_pTOn_thermal_interval = jax.device_put(
116+
# self.T_pTOn_thermal_interval, device=gpus[0]
117+
# )
118+
# self.L_pTOnCCRTh_res = jax.device_put(
119+
# self.L_pTOnCCRTh_res, device=gpus[0]
120+
# )
121+
# except:
122+
# pass
123123

124124
else:
125125

0 commit comments

Comments
 (0)