@@ -88,29 +88,31 @@ def __init__(self,params, species_list, RM):
8888 # self.params = params
8989 self .species_list = species_list
9090
91- self .tau_tab = self ._tabulate_conformal_time (params )
92- self .tau0 = self .tau (0. )
91+ self .tau_tab = jax . device_put ( self ._tabulate_conformal_time (params ), jax . devices ( 'gpu' )[ 0 ] )
92+ self .tau0 = jax . device_put ( self .tau (0. ), jax . devices ( 'gpu' )[ 0 ] )
9393
9494 ### RECOMBINATION RELATED ###
9595
9696 # Run hyrex to tabulate recombination output
97+ # TODO: get this running on CPU. Will require refactorization.
9798 self .xe_tab , self .lna_xe_tab , self .Tm_tab , self .lna_Tm_tab = RM ((self ,params ),z_reion = params ["z_reion" ],
9899 Delta_z_reion = params ["Delta_z_reion" ],
99100 z_reion_He = params ["z_reion_He" ],
100101 Delta_z_reion_He = params ["Delta_z_reion_He" ])
101- self .kappa_func = self ._tabulate_optical_depth (params )
102+
103+ self .kappa_func = jax .device_put (self ._tabulate_optical_depth (params ),jax .devices ('gpu' )[0 ])
102104
103105 # Find approximate maximum of visibility function.
104106 lna_vals = jnp .linspace (- 8.0 , - 4.0 , 1500 ) # Decoupling should have happened at some time in this interval.
105107 vis_vals = vmap (self .visibility ,in_axes = [0 ,None ])(lna_vals , params )
106- self .lna_rec = lna_vals [jnp .argmax (vis_vals )]
107- self .lna_visibility_stop = lna_vals [jnp .argmin ((vis_vals - 1.e-3 )** 2 )]
108- self .rA_rec = self .tau0 - self .tau (self .lna_rec )
108+ self .lna_rec = jax . device_put ( lna_vals [jnp .argmax (vis_vals )], jax . devices ( 'gpu' )[ 0 ])
109+ self .lna_visibility_stop = jax . device_put ( lna_vals [jnp .argmin ((vis_vals - 1.e-3 )** 2 )], jax . devices ( 'gpu' )[ 0 ])
110+ self .rA_rec = jax . device_put ( self .tau0 - self .tau (self .lna_rec ), jax . devices ( 'gpu' )[ 0 ] )
109111
110112 # Find approximate early time when aH x tau_c = 0.008
111113 lna_vals = jnp .linspace (- 15.0 , - 6.0 , 5000 )
112114 aH_tau_c_vals = vmap (self .aH ,in_axes = [0 ,None ])(lna_vals ,params )* self .tau_c (lna_vals ,params )
113- self .lna_transfer_start = lna_vals [jnp .argmin ((aH_tau_c_vals - 0.008 )** 2 )]
115+ self .lna_transfer_start = jax . device_put ( lna_vals [jnp .argmin ((aH_tau_c_vals - 0.008 )** 2 )], jax . devices ( 'gpu' )[ 0 ])
114116
115117
116118 def rho_tot (self , lna , params ):
0 commit comments