Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Centroid sem #94

Merged
merged 2 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 78 additions & 11 deletions exovetter/centroid/centroid.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def compute_diff_image_centroids(
duration_days,
remove_transits,
max_oot_shift_pix=1.5,
starloc_pix = None,
plot=False
):
"""Compute difference image centroid shifts for every transit in a dataset.
Expand Down Expand Up @@ -42,6 +43,9 @@ def compute_diff_image_centroids(
(float) Duration of transit.
remove_transits
(list) List of 0 indexed transit integers to not calculate on.
starloc_pix
(2d array) catalog location of target star for plotting.
Default is None.
max_oot_shift_pix
(float) Passed to `fastpsffit.fastGaussianPrfFit()

Expand Down Expand Up @@ -84,6 +88,7 @@ def compute_diff_image_centroids(
cube,
cin,
max_oot_shift_pix=max_oot_shift_pix,
starloc_pix = starloc_pix,
plot=plot
)

Expand Down Expand Up @@ -140,6 +145,51 @@ def measure_centroid_shift(centroids, kept_transits, plot=False):
fig = covar.diagnostic_plot(dcol, drow, kept_transits, flags)
return offset_pix, signif, fig

def measure_centroid_shift_cat(centroids, kept_transits, starloc_pix, plot=False):
"""Measure the average offset of the DIC centroids from the catalog position.

Inputs
----------
centroids
(2d np array) Output of :func:`compute_diff_image_centroids`

kept_transits
(list) List of 0 indexed transit integers to calculate on.

starloc_pix
(2d np array) col,row expected location of target star

Returns
-----------
offset
(float) Size of offset in pixels (or whatever unit `centroids`
is in)
signif
(float) The statistical significance of the transit. Values
close to 1 mean the transit is likely on the target star.
Values less than ~1e-3 suggest the target is not the
source of the transit.
fig
A figure handle. Is **None** if plot is **False**
"""

# DIC - catalog
# dcol = centroids[:, 5] - centroids[:, 0]
# drow = centroids[:, 4] - centroids[:, 1]
dcol = centroids[:, 4] - starloc_pix[0]
drow = centroids[:, 5] - starloc_pix[1]

flags = centroids[:, -1].astype(bool)

offset_pix, signif = covar.compute_offset_and_signif(
dcol[~flags], drow[~flags])

fig = None
if plot:
fig = covar.diagnostic_plot(dcol, drow, kept_transits, flags)

return offset_pix, signif, fig


def getIngressEgressCadences(time, period_days, epoch_btjd, duration_days):
assert np.all(np.isfinite(time))
Expand All @@ -151,7 +201,7 @@ def getIngressEgressCadences(time, period_days, epoch_btjd, duration_days):
return transits


def measure_centroids(cube, cin, max_oot_shift_pix=0.5, plot=False):
def measure_centroids(cube, cin, max_oot_shift_pix=0.5, starloc_pix = None, plot=False):
"""Private function of :func:`compute_diff_image_centroids`

Computes OOT, ITR and diff images for a single transit event,
Expand Down Expand Up @@ -209,17 +259,31 @@ def measure_centroids(cube, cin, max_oot_shift_pix=0.5, plot=False):
if diffSoln.success:
clr = "green"

fig = plt.gcf()
axlist = fig.axes
#assert len(axlist) == 3, axlist

res = diffSoln.x
disp.plotCentroidLocation(res[0], res[1], marker="^", color=clr,
label="diff")
for ax in axlist:
if ax.get_label() == '<colorbar>':
continue

plt.sca(ax)
disp.plotCentroidLocation(res[0], res[1], marker="^", color=clr,
label="diff")

res = ootSoln.x
disp.plotCentroidLocation(res[0], res[1], marker="o", color=clr,
label="OOT")
res1 = ootSoln.x
disp.plotCentroidLocation(res1[0], res1[1], marker="o", color=clr,
label="OOT")

res = intransSoln.x
disp.plotCentroidLocation(res[0], res[1], marker="+", color=clr,
label="InT")
res2 = intransSoln.x
disp.plotCentroidLocation(res2[0], res2[1], marker="+", color=clr,
label="InT")

if starloc_pix is not None:
disp.plotCentroidLocation(starloc_pix[0], starloc_pix[1], marker="*",
color='red', label="Cat", ms=10)

plt.legend(fontsize=12, framealpha=0.7, facecolor='silver')

out = []
Expand All @@ -233,10 +297,10 @@ def measure_centroids(cube, cin, max_oot_shift_pix=0.5, plot=False):
flag = 2
out.append(flag)

return out, ax
return out, fig


def generateDiffImg(cube, transits, plot=False):
def generateDiffImg(cube, transits, starloc_pix = None, plot=False):
"""Generate a difference image.

Also generates an image for each the $n$ cadedences before
Expand All @@ -249,6 +313,8 @@ def generateDiffImg(cube, transits, plot=False):
(np 3 array) Datacube of postage stamps
transits
(2-tuples) Indices of the first and last cadence
starloc_pix
(np 2 element array) col, row position of star

Optional Inputs
-----------------
Expand Down Expand Up @@ -286,6 +352,7 @@ def generateDiffImg(cube, transits, plot=False):
fig = plt.figure()
fig.set_size_inches(16, 4)
disp.plotTransit(fig, oot, during, diff)

else:
fig = None

Expand Down
6 changes: 3 additions & 3 deletions exovetter/centroid/disp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@

def plotTransit(fig, oot, during, diff, **kwargs):

fig.add_subplot(131)
ax1 = fig.add_subplot(131)
plotImage(oot, **kwargs)
plt.title("OOT")

fig.add_subplot(132)
ax2 = fig.add_subplot(132)
plotImage(during, **kwargs)
plt.title("In-transit")

fig.add_subplot(133)
ax3 = fig.add_subplot(133)
plotDiffImage(diff, **kwargs)
plt.title("Difference")

Expand Down
12 changes: 8 additions & 4 deletions exovetter/vetters.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ def __init__(self, lc_name="flux", diff_plots=False, centroid_plots=False):
self.diff_plots = diff_plots
self.centroid_plots = centroid_plots

def run(self, tce, lk_tpf, plot=False, remove_transits=None):
def run(self, tce, lk_tpf, starloc_pix = None, plot=False, remove_transits=None):
"""Runs cent.compute_diff_image_centroids and cent.measure_centroid_shift
to populate the vetter object.

Expand Down Expand Up @@ -665,12 +665,16 @@ def run(self, tce, lk_tpf, plot=False, remove_transits=None):

if remove_transits is None: # reformat to be a blank list
remove_transits = []

centroids, figs, kept_transits = cent.compute_diff_image_centroids(
time, cube, period_days, epoch, duration_days,
remove_transits, plot=self.diff_plots)
remove_transits, starloc_pix=starloc_pix, plot=self.diff_plots)

offset, signif, fig = cent.measure_centroid_shift(centroids, kept_transits, self.centroid_plots)
if (starloc_pix is not None) and (len(starloc_pix) == 2):
offset, signif, fig = cent.measure_centroid_shift_cat(centroids, kept_transits, starloc_pix, self.centroid_plots)
else:
offset, signif, fig = cent.measure_centroid_shift(centroids, kept_transits, self.centroid_plots)

figs.append(fig)

# TODO: If plot=True, figs is a list of figure handles.
Expand Down
Loading