Skip to content

Commit

Permalink
Faster phase terms (#165)
Browse files Browse the repository at this point in the history
* Make estimates compatible with freq_interval: compute one estimate per solution interval in both time and frequency.

* Begin implementing the even better phase-only approach. This mitigates the slowdown introduced by recent changes.

* Rename residual phase to amplocked residual.

* Apply changes to all relevant terms.

* Remove some defunct code.

* Enable nested parallelism in residual computation.
  • Loading branch information
JSKenyon authored Jun 15, 2022
1 parent 4e6c87e commit cd1e15f
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 183 deletions.
28 changes: 5 additions & 23 deletions quartical/gains/crosshand_phase/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from numba import prange, generated_jit
from quartical.utils.numba import coerce_literal
from quartical.gains.general.generics import (solver_intermediaries,
compute_residual_phase,
compute_amplocked_residual,
per_array_jhj_jhr)
from quartical.gains.general.flagging import (flag_intermediaries,
update_gain_flags,
Expand Down Expand Up @@ -88,9 +88,9 @@ def impl(base_args, term_args, meta_args, corr_mode):

for loop_idx in range(max_iter):

compute_residual_phase(base_args,
solver_imdry,
corr_mode)
compute_amplocked_residual(base_args,
solver_imdry,
corr_mode)

compute_jhj_jhr(base_args,
term_args,
Expand Down Expand Up @@ -245,8 +245,6 @@ def impl(base_args, term_args, meta_args, solver_imdry, corr_mode):
lop_qp_arr = valloc(complex_dtype, leading_dims=(n_gdir,))
rop_qp_arr = valloc(complex_dtype, leading_dims=(n_gdir,))

norm_factors = valloc(complex_dtype)

tmp_kprod = np.zeros((4, 4), dtype=complex_dtype)
tmp_jhr = jhr[ti, fi]
tmp_jhj = jhj[ti, fi]
Expand Down Expand Up @@ -334,7 +332,6 @@ def impl(base_args, term_args, meta_args, solver_imdry, corr_mode):
compute_jhwj_jhwr_elem(lop_pq_d,
rop_pq_d,
w,
norm_factors,
gains_p[active_term],
tmp_kprod,
r_pq,
Expand All @@ -347,7 +344,6 @@ def impl(base_args, term_args, meta_args, solver_imdry, corr_mode):
compute_jhwj_jhwr_elem(lop_qp_d,
rop_qp_d,
w,
norm_factors,
gains_q[active_term],
tmp_kprod,
r_qp,
Expand Down Expand Up @@ -458,21 +454,14 @@ def impl(params, gain):
def compute_jhwj_jhwr_elem_factory(corr_mode):

v1_imul_v2 = factories.v1_imul_v2_factory(corr_mode)
imul = factories.imul_factory(corr_mode)
a_kron_bt = factories.a_kron_bt_factory(corr_mode)
unpack = factories.unpack_factory(corr_mode)
unpackc = factories.unpackc_factory(corr_mode)
iabsdiv = factories.iabsdiv_factory(corr_mode)

if corr_mode.literal_value == 4:
def impl(lop, rop, w, normf, gain, tmp_kprod, res, jhr, jhj):

# Compute normalization factor.
v1_imul_v2(lop, rop, normf)
iabsdiv(normf)
def impl(lop, rop, w, gain, tmp_kprod, res, jhr, jhj):

# Accumulate an element of jhwr.
imul(res, normf) # Apply normalization factor to r.
v1_imul_v2(res, rop, res)
v1_imul_v2(lop, res, res)

Expand All @@ -497,13 +486,6 @@ def impl(lop, rop, w, normf, gain, tmp_kprod, res, jhr, jhj):
jhr[0] += upd_00

w_0, w_1, w_2, w_3 = unpack(w) # NOTE: XX, XY, YX, YY
n_0, n_1, n_2, n_3 = unpack(normf)

# Apply normalisation factors by scaling w.
w_0 = n_0 * w_0 * n_0
w_1 = n_1 * w_1 * n_1
w_2 = n_2 * w_2 * n_2
w_3 = n_3 * w_3 * n_3

jh_0, jh_1, jh_2, jh_3 = unpack(tmp_kprod[0])
j_0, j_1, j_2, j_3 = unpackc(tmp_kprod[0])
Expand Down
78 changes: 44 additions & 34 deletions quartical/gains/delay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,8 @@ def init_term(
a2 = kwargs["a2"]
chan_freq = kwargs["chan_freqs"]
t_map = kwargs["t_map_arr"][0, :, term_ind] # time -> solint
f_map = kwargs["f_map_arr"][1, :, term_ind] # freq -> solint
_, n_chan, n_ant, n_dir, n_corr = gain.shape
# TODO: Make controllable/automate. Check with Landman.
pad_factor = int(np.ceil(2 ** 15 / n_chan))

# We only need the baselines which include the ref_ant.
sel = np.where((a1 == ref_ant) | (a2 == ref_ant))
Expand All @@ -111,12 +110,14 @@ def init_term(
data[flags != 0] = 0

utint = np.unique(t_map)
ufint = np.unique(f_map)

for ut in utint:
sel = np.where((t_map == ut) & (a1 != a2))
ant_map_pq = np.where(a1[sel] == ref_ant, a2[sel], 0)
ant_map_qp = np.where(a2[sel] == ref_ant, a1[sel], 0)
ant_map = ant_map_pq + ant_map_qp

ref_data = np.zeros((n_ant, n_chan, n_corr), dtype=np.complex128)
counts = np.zeros((n_ant, n_chan), dtype=int)
np.add.at(
Expand All @@ -136,35 +137,44 @@ def init_term(
out=ref_data
)

fft_data = np.abs(
np.fft.fft(ref_data, n=n_chan*pad_factor, axis=1)
)
fft_data = np.fft.fftshift(fft_data, axes=1)

delta_freq = chan_freq[1] - chan_freq[0]
fft_freq = np.fft.fftfreq(n_chan*pad_factor, delta_freq)
fft_freq = np.fft.fftshift(fft_freq)

delay_est_ind_00 = np.argmax(fft_data[..., 0], axis=1)
delay_est_00 = fft_freq[delay_est_ind_00]

if n_corr > 1:
delay_est_ind_11 = np.argmax(fft_data[..., -1], axis=1)
delay_est_11 = fft_freq[delay_est_ind_11]

for t, p, q in zip(t_map[sel], a1[sel], a2[sel]):
if p == ref_ant:
param[t, 0, q, 0, 1] = -delay_est_00[q]
if n_corr > 1:
param[t, 0, q, 0, 3] = -delay_est_11[q]
else:
param[t, 0, p, 0, 1] = delay_est_00[p]
if n_corr > 1:
param[t, 0, p, 0, 3] = delay_est_11[p]

coeffs00 = param[..., 1]*kwargs["chan_freqs"][None, :, None, None]
gain[..., 0] = np.exp(2j*np.pi*coeffs00)

if n_corr > 1:
coeffs11 = param[..., 3]*kwargs["chan_freqs"][None, :, None, None]
gain[..., -1] = np.exp(2j*np.pi*coeffs11)
for uf in ufint:

fsel = np.where(f_map == uf)[0]
sel_n_chan = fsel.size
n = int(np.ceil(2 ** 15 / sel_n_chan)) * sel_n_chan

fft_data = np.abs(
np.fft.fft(ref_data[:, fsel], n=n, axis=1)
)
fft_data = np.fft.fftshift(fft_data, axes=1)

delta_freq = chan_freq[1] - chan_freq[0]
fft_freq = np.fft.fftfreq(n, delta_freq)
fft_freq = np.fft.fftshift(fft_freq)

delay_est_ind_00 = np.argmax(fft_data[..., 0], axis=1)
delay_est_00 = fft_freq[delay_est_ind_00]

if n_corr > 1:
delay_est_ind_11 = np.argmax(fft_data[..., -1], axis=1)
delay_est_11 = fft_freq[delay_est_ind_11]

for t, p, q in zip(t_map[sel], a1[sel], a2[sel]):
if p == ref_ant:
param[t, uf, q, 0, 1] = -delay_est_00[q]
if n_corr > 1:
param[t, uf, q, 0, 3] = -delay_est_11[q]
else:
param[t, uf, p, 0, 1] = delay_est_00[p]
if n_corr > 1:
param[t, uf, p, 0, 3] = delay_est_11[p]

for ut in utint:
for f in range(n_chan):
fm = f_map[f]
cf = 2j * np.pi * chan_freq[f]

gain[ut, f, :, :, 0] = np.exp(cf * param[ut, fm, :, :, 1])

if n_corr > 1:
gain[ut, f, :, :, -1] = np.exp(cf * param[ut, fm, :, :, 3])
52 changes: 13 additions & 39 deletions quartical/gains/delay/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from numba import prange, generated_jit
from quartical.utils.numba import coerce_literal
from quartical.gains.general.generics import (solver_intermediaries,
compute_residual_phase,
compute_amplocked_residual,
per_array_jhj_jhr)
from quartical.gains.general.flagging import (flag_intermediaries,
update_gain_flags,
Expand Down Expand Up @@ -97,9 +97,9 @@ def impl(base_args, term_args, meta_args, corr_mode):

for loop_idx in range(max_iter):

compute_residual_phase(base_args,
solver_imdry,
corr_mode)
compute_amplocked_residual(base_args,
solver_imdry,
corr_mode)

compute_jhj_jhr(base_args,
term_args,
Expand Down Expand Up @@ -261,8 +261,6 @@ def impl(base_args, term_args, meta_args, solver_imdry, scaled_cf,
lop_qp_arr = valloc(complex_dtype, leading_dims=(n_gdir,))
rop_qp_arr = valloc(complex_dtype, leading_dims=(n_gdir,))

norm_factors = valloc(complex_dtype)

tmp_kprod = np.zeros((4, 4), dtype=complex_dtype)
tmp_jhr = jhr[ti, fi]
tmp_jhj = jhj[ti, fi]
Expand Down Expand Up @@ -353,7 +351,6 @@ def impl(base_args, term_args, meta_args, solver_imdry, scaled_cf,
rop_pq_d,
w,
nu,
norm_factors,
gains_p[active_term],
tmp_kprod,
r_pq,
Expand All @@ -367,7 +364,6 @@ def impl(base_args, term_args, meta_args, solver_imdry, scaled_cf,
rop_qp_d,
w,
nu,
norm_factors,
gains_q[active_term],
tmp_kprod,
r_qp,
Expand Down Expand Up @@ -493,27 +489,19 @@ def impl(params, chanfreq, gain):
def compute_jhwj_jhwr_elem_factory(corr_mode):

v1_imul_v2 = factories.v1_imul_v2_factory(corr_mode)
imul = factories.imul_factory(corr_mode)
a_kron_bt = factories.a_kron_bt_factory(corr_mode)
iunpack = factories.iunpack_factory(corr_mode)
unpack = factories.unpack_factory(corr_mode)
unpackc = factories.unpackc_factory(corr_mode)
iabsdiv = factories.iabsdiv_factory(corr_mode)

if corr_mode.literal_value == 4:
def impl(lop, rop, w, nu, normf, gain, tmp_kprod, res, jhr, jhj):
def impl(lop, rop, w, nu, gain, tmp_kprod, res, jhr, jhj):

# Effectively apply zero weight to off-diagonal terms.
# TODO: Can be tidied but requires moving other weighting code.
res[1] = 0
res[2] = 0

# Compute normalization factor.
v1_imul_v2(lop, rop, normf)
iabsdiv(normf)

# Accumulate an element of jhwr.
imul(res, normf) # Apply normalization factor to r.
v1_imul_v2(res, rop, res)
v1_imul_v2(lop, res, res)

Expand Down Expand Up @@ -544,14 +532,12 @@ def impl(lop, rop, w, nu, normf, gain, tmp_kprod, res, jhr, jhj):
jhr[3] += nu*upd_11

w_0, w_1, w_2, w_3 = unpack(w) # NOTE: XX, XY, YX, YY
n_0, n_1, n_2, n_3 = unpack(normf)

# Apply normalisation factors by scaling w. Neglect (set weight
# to zero) off diagonal terms.
w_0 = n_0 * w_0 * n_0
# Neglect (set weight to zero) off diagonal terms.
w_0 = w_0
w_1 = 0
w_2 = 0
w_3 = n_3 * w_3 * n_3
w_3 = w_3

jh_0, jh_1, jh_2, jh_3 = unpack(tmp_kprod[0])
j_0, j_1, j_2, j_3 = unpackc(tmp_kprod[0])
Expand Down Expand Up @@ -591,14 +577,9 @@ def impl(lop, rop, w, nu, normf, gain, tmp_kprod, res, jhr, jhj):
jhj[3, 3] += tmp_2*nusq

elif corr_mode.literal_value == 2:
def impl(lop, rop, w, nu, normf, gain, tmp_kprod, res, jhr, jhj):

# Compute normalization factor.
iunpack(normf, rop)
iabsdiv(normf)
def impl(lop, rop, w, nu, gain, tmp_kprod, res, jhr, jhj):

# Accumulate an element of jhwr.
imul(res, normf)
v1_imul_v2(res, rop, res)

r_0, r_1 = unpack(res)
Expand All @@ -619,31 +600,25 @@ def impl(lop, rop, w, nu, normf, gain, tmp_kprod, res, jhr, jhj):
jh_00, jh_11 = unpack(rop)
j_00, j_11 = unpackc(rop)
w_00, w_11 = unpack(w)
n_00, n_11 = unpack(normf)

nusq = nu*nu

tmp = (jh_00*n_00*w_00*n_00*j_00).real
tmp = (jh_00*w_00*j_00).real
jhj[0, 0] += tmp
jhj[0, 1] += tmp*nu
jhj[1, 0] += tmp*nu
jhj[1, 1] += tmp*nusq

tmp = (jh_11*n_11*w_11*n_11*j_11).real
tmp = (jh_11*w_11*j_11).real
jhj[2, 2] += tmp
jhj[2, 3] += tmp*nu
jhj[3, 2] += tmp*nu
jhj[3, 3] += tmp*nusq

elif corr_mode.literal_value == 1:
def impl(lop, rop, w, nu, normf, gain, tmp_kprod, res, jhr, jhj):

# Compute normalization factor.
iunpack(normf, rop)
iabsdiv(normf)
def impl(lop, rop, w, nu, gain, tmp_kprod, res, jhr, jhj):

# Accumulate an element of jhwr.
imul(res, normf)
v1_imul_v2(res, rop, res)

r_0 = unpack(res)
Expand All @@ -660,11 +635,10 @@ def impl(lop, rop, w, nu, normf, gain, tmp_kprod, res, jhr, jhj):
jh_00 = unpack(rop)
j_00 = unpackc(rop)
w_00 = unpack(w)
n_00 = unpack(normf)

nusq = nu*nu

tmp = (jh_00*n_00*w_00*n_00*j_00).real
tmp = (jh_00*w_00*j_00).real
jhj[0, 0] += tmp
jhj[0, 1] += tmp*nu
jhj[1, 0] += tmp*nu
Expand Down
Loading

0 comments on commit cd1e15f

Please sign in to comment.