diff --git a/LWA/ConsumerThread.py b/LWA/ConsumerThread.py new file mode 100644 index 00000000..f4dec4c5 --- /dev/null +++ b/LWA/ConsumerThread.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- + +from __future__ import print_function + +import threading + +class ConsumerThread(threading.Thread): + class StopObject(object): pass + STOP = StopObject + def __init__(self, input_queue=None): + threading.Thread.__init__(self) + self.input_queue = input_queue + def request_stop(self): + # Note: If multiple consumers for one queue, calling this on each + # should still work to stop them all, just not in any order. + # Must be careful though to call all request_stop()'s before + # calling join()'s. + self.input_queue.put(ConsumerThread.STOP) + def run(self): + while True: + task = self.input_queue.get() + if task is ConsumerThread.STOP: + self.input_queue.task_done() + break + try: + self.process(task) + except Exception as e: + print("ERROR: Uncaught exception in %s: %s" % (self, e)) + self.input_queue.task_done() + self.shutdown() + def process(self, task): + # Implement this in the subclass + pass + def shutdown(self): + # Implement this in the subclass + pass + # Handy method for testing + def put(self, task, timeout=None): + self.input_queue.put(task, timeout) + +if __name__ == "__main__": + from Queue import Queue + import time + + class MyConsumer(ConsumerThread): + def process(self, task): + #time.sleep(5) + print(task) + + q = Queue() + m = MyConsumer(q) + m.daemon = True + m.start() + q.put('Hello') + time.sleep(0.2) + q.put('world!') + time.sleep(0.2) + m.request_stop() diff --git a/LWA/LWA_bifrost.py b/LWA/LWA_bifrost.py index 11653c6a..cf2e8564 100755 --- a/LWA/LWA_bifrost.py +++ b/LWA/LWA_bifrost.py @@ -11,16 +11,21 @@ import sys import json import time -import numpy import signal import logging -import threading import argparse +import threading +import numpy as np from collections import deque from scipy.fftpack import fft + +from astropy.io import fits +from astropy.time import Time, TimeDelta +from astropy.coordinates import SkyCoord, FK5 from astropy.constants import c as speed_of_light import datetime +import MCS2 as MCS import ctypes # Profiling Includes @@ -32,7 +37,7 @@ import bifrost.affinity from bifrost.address import Address as BF_Address from bifrost.udp_socket import UDPSocket as BF_UDPSocket -from bifrost.udp_capture import UDPCapture as BF_UDPCapture +from bifrost.packet_capture import PacketCaptureCallback, UDPCapture from bifrost.ring import Ring from bifrost.unpack import unpack as Unpack from bifrost.quantize import quantize as Quantize @@ -48,7 +53,12 @@ # LWA Software Library Includes from lsl.reader.ldp import TBNFile, TBFFile -from lsl.common.stations import lwasv +from lsl.common.stations import lwa1, lwasv + +#Optimized Bifrost blocks for EPIC +from bifrost.VGrid import VGrid +from bifrost.XGrid import XGrid +from bifrost.aCorr import aCorr # some py2/3 compatibility if sys.version_info.major < 3: @@ -58,6 +68,23 @@ TRIGGER_ACTIVE = threading.Event() +DATE_FORMAT = "%Y_%m_%dT%H_%M_%S" + + +def get_utc_start(): + got_utc_start = False + while not got_utc_start: + try: + with MCS.Communicator() as adp_control: + utc_start = adp_control.report('UTC_START') + # Check for valid timestamp + utc_start_dt = datetime.datetime.strptime(utc_start, DATE_FORMAT) + got_utc_start = True + except Exception as ex: + print(ex) + time.sleep(0.1) + return utc_start_dt + # Profiling def enable_thread_profiling(): @@ -120,28 +147,21 @@ def form_dft_matrix(lmn_vector, antenna_location, antenna_phases, nchan, npol, n """ # lm_matrix, shape = [...,2] , where the last dimension is an l/m pair. - lmn_vector[:, 2] = 1.0 - numpy.sqrt( + lmn_vector[:, 2] = 1.0 - np.sqrt( 1.0 - lmn_vector[:, 0] ** 2 - lmn_vector[:, 1] ** 2 ) - dft_matrix = numpy.zeros( - (nchan, npol, lmn_vector.shape[0], nstand), dtype=numpy.complex64 + dft_matrix = np.zeros( + (nchan, npol, lmn_vector.shape[0], nstand), dtype=np.complex64 ) # DFT phase factors - for i in numpy.arange(antenna_location.shape[3]): - ant_uvw = antenna_location[0, 0, :, i] - - # Both polarisations are at the same physical location, only phases differ. - dft_matrix[:, :, :, i] = numpy.exp( - 2j * numpy.pi * (numpy.dot(lmn_vector, ant_uvw)) - ) - + # Both polarisations are at the same physical location, only phases differ. + dft_matrix[:, :] = np.exp( + 2j * np.pi * (np.dot(lmn_vector, antenna_location[0, 0])) + ) # Can put the antenna phases in as well because maths - for i in numpy.arange(dft_matrix.shape[2]): - for p in numpy.arange(npol): - for c in numpy.arange(nchan): - dft_matrix[c, p, i, :] *= antenna_phases[c, :, p] + dft_matrix *= antenna_phases.transpose([0, 2, 1])[:, :, np.newaxis, :] / nstand - return dft_matrix / antenna_location.shape[3] + return dft_matrix # Frequency-Dependent Locations @@ -151,9 +171,9 @@ def Generate_DFT_Locations(lsl_locs, frequencies, ntime, nchan, npol): Parameters ---------- - lsl_locs : numpy.ndarray + lsl_locs : np.ndarray Array of stand locations. Has shape (3, nstand) - frequencies : numpy.ndarray + frequencies : np.ndarray Array of frequencies in the observation. ntime : int Number of times. @@ -170,10 +190,10 @@ def Generate_DFT_Locations(lsl_locs, frequencies, ntime, nchan, npol): lsl_locs = lsl_locs.T lsl_locs = lsl_locs.copy() chan_wavelengths = speed_of_light.value / frequencies - dft_locs = numpy.zeros(shape=(nchan, npol, 3, lsl_locs.shape[1])) - for j in numpy.arange(npol): - for i in numpy.arange(nchan): - dft_locs[i, j, :, :] = lsl_locs / chan_wavelengths[i] + + dft_locs = lsl_locs[np.newaxis, np.newaxis, :, :] / chan_wavelengths[:, np.newaxis, np.newaxis, np.newaxis] + + dft_locs = np.broadcast_to(dft_locs, (nchan, npol, 3, lsl_locs.shape[1])).copy() return dft_locs @@ -185,9 +205,9 @@ def GenerateLocations( Parameters ---------- - lsl_locs : numpy.ndarray + lsl_locs : np.ndarray Array of stand locations. Has shape (3, nstand) - frequencies : numpy.ndarray + frequencies : np.ndarray Array of frequencies in the observation. ntime : int Number of times. @@ -210,37 +230,20 @@ def GenerateLocations( The sampling length of the DFT or resolution in image space. """ - delta = (2 * grid_size * numpy.sin(numpy.pi * grid_resolution / 360)) ** -1 + delta = (2 * grid_size * np.sin(np.pi * grid_resolution / 360)) ** -1 chan_wavelengths = speed_of_light.value / frequencies sample_grid = chan_wavelengths * delta sll = sample_grid[0] / chan_wavelengths[0] lsl_locs = lsl_locs.T - lsl_locs = lsl_locs.copy() - lsl_locsf = numpy.zeros(shape=(3, npol, nchan, lsl_locs.shape[1])) - for l in numpy.arange(3): - for i in numpy.arange(nchan): - lsl_locsf[l, :, i, :] = lsl_locs[l, :] / sample_grid[i] - - # I'm sure there's a more numpy way of doing this. - for p in numpy.arange(npol): - lsl_locsf[l, p, i, :] -= numpy.min(lsl_locsf[l, p, i, :]) + lsl_locsf = lsl_locs[:, np.newaxis, np.newaxis, :] / sample_grid[np.newaxis, np.newaxis, :, np.newaxis] + lsl_locsf -= np.min(lsl_locsf, axis=3, keepdims=True) # Centre locations slightly - for l in numpy.arange(3): - for i in numpy.arange(nchan): - for p in numpy.arange(npol): - lsl_locsf[l, p, i, :] += ( - grid_size - numpy.max(lsl_locsf[l, p, i, :]) - ) / 2 - - # Tile them for ntime... - locx = numpy.tile(lsl_locsf[0, ...], (ntime, 1, 1, 1)) - locy = numpy.tile(lsl_locsf[1, ...], (ntime, 1, 1, 1)) - locz = numpy.tile(lsl_locsf[2, ...], (ntime, 1, 1, 1)) - # .. and then stick them all into one large array - locc = numpy.concatenate([[locx, ], [locy, ], [locz, ]]).transpose(0, 1, 3, 4, 2).copy() + lsl_locsf += (grid_size - np.max(lsl_locsf, axis=3, keepdims=True)) / 2. + # add ntime axis + locc = np.broadcast_to(lsl_locsf, (ntime, 3, npol, nchan, lsl_locs.shape[1])).transpose(1, 0, 3, 4, 2).copy() return delta, locc, sll @@ -330,7 +333,7 @@ def main(self): # Setup and load idata = data - odata = ospan.data_view(numpy.complex64).reshape(oshape) + odata = ospan.data_view(np.complex64).reshape(oshape) # Transpose and reshape to time by stand by pol idata = idata.transpose((1, 0)) idata = idata.reshape((ntime, nstand, npol)) @@ -356,6 +359,7 @@ def main(self): sys.exit() break print("TBNFillerOp - Done") + os.kill(os.getpid(), signal.SIGTERM) class FDomainOp(object): @@ -412,7 +416,7 @@ def main(self): with self.oring.begin_writing() as oring: for iseq in self.iring.read(guarantee=True): - ihdr = json.loads(iseq.header.tostring()) + ihdr = json.loads(iseq.header.tobytes()) self.sequence_proclog.update(ihdr) print("FDomainOp: Config - %s" % ihdr) @@ -457,24 +461,24 @@ def main(self): prev_time = curr_time # Setup and load - idata = ispan.data_view(numpy.complex64).reshape(ishape) + idata = ispan.data_view(np.complex64).reshape(ishape) - odata = ospan.data_view(numpy.int8).reshape(oshape) + odata = ospan.data_view(np.int8).reshape(oshape) # FFT, shift, and phase fdata = fft(idata, axis=1) - fdata = numpy.fft.fftshift(fdata, axes=1) + fdata = np.fft.fftshift(fdata, axes=1) fdata = bifrost.ndarray(fdata, space="system") # Quantization try: - Quantize(fdata, qdata, scale=1. / numpy.sqrt(nchan)) + Quantize(fdata, qdata, scale=1. / np.sqrt(nchan)) except NameError: qdata = bifrost.ndarray(shape=fdata.shape, native=False, dtype="ci4") - Quantize(fdata, qdata, scale=1. / numpy.sqrt(nchan)) + Quantize(fdata, qdata, scale=1. / np.sqrt(nchan)) # Save - odata[...] = qdata.copy(space="cuda_host").view(numpy.int8).reshape(oshape) + odata[...] = qdata.copy(space="cuda_host").view(np.int8).reshape(oshape) if self.profile: spani += 1 @@ -526,7 +530,7 @@ def main(self): idf = TBFFile(self.filename) srate = idf.get_info("sample_rate") - chans = numpy.round(idf.get_info("freq1") / srate).astype(numpy.int32) + chans = np.round(idf.get_info("freq1") / srate).astype(np.int32) chan0 = int(chans[0]) nchan = len(chans) tInt, tStart, data = idf.read(0.1, time_in_samples=True) @@ -583,7 +587,7 @@ def main(self): # Setup and load idata = data - odata = ospan.data_view(numpy.int8).reshape(oshape) + odata = ospan.data_view(np.int8).reshape(oshape) # Transpose and reshape to time by channel by stand by pol idata = idata.transpose((2, 1, 0)) @@ -592,13 +596,13 @@ def main(self): # Quantization try: - Quantize(idata, qdata, scale=1. / numpy.sqrt(nchan)) + Quantize(idata, qdata, scale=1. / np.sqrt(nchan)) except NameError: qdata = bifrost.ndarray(shape=idata.shape, native=False, dtype="ci4") Quantize(idata, qdata, scale=1.0) # Save - odata[...] = qdata.copy(space="cuda_host").view(numpy.int8).reshape(oshape) + odata[...] = qdata.copy(space="cuda_host").view(np.int8).reshape(oshape) data = next_data @@ -618,6 +622,7 @@ def main(self): sys.exit() break print("TBFFillerOp - Done") + os.kill(os.getpid(), signal.SIGTERM) # For when we don't need to care about doing the F-Engine ourself. @@ -665,7 +670,7 @@ def seq_callback( "axes": "time,chan,stand,pol", } print("******** CFREQ:", hdr["cfreq"]) - hdr_str = json.dumps(hdr) + hdr_str = json.dumps(hdr).encode() # TODO: Can't pad with NULL because returned as C-string # hdr_str = json.dumps(hdr).ljust(4096, '\0') # hdr_str = json.dumps(hdr).ljust(4096, ' ') @@ -675,12 +680,14 @@ def seq_callback( return 0 def main(self): - seq_callback = bf.BFudpcapture_sequence_callback(self.seq_callback) - with BF_UDPCapture( - *self.args, sequence_callback=seq_callback, **self.kwargs - ) as capture: + seq_callback = PacketCaptureCallback() + + seq_callback.set_chips(self.seq_callback) + with UDPCapture(*self.args, + sequence_callback=seq_callback, + **self.kwargs) as capture: while not self.shutdown_event.is_set(): - capture.recv() + status = capture.recv() del capture @@ -725,7 +732,7 @@ def main(self): with self.oring.begin_writing() as oring: for iseq in self.iring.read(guarantee=self.guarantee): - ihdr = json.loads(iseq.header.tostring()) + ihdr = json.loads(iseq.header.tobytes()) self.sequence_proclog.update(ihdr) @@ -740,8 +747,8 @@ def main(self): ishape = (self.ntime_gulp, nchan, nstand, npol) ogulp_size = self.ntime_gulp * self.nchan_out * nstand * self.npol_out * 1 # ci4 oshape = (self.ntime_gulp, self.nchan_out, nstand, self.npol_out) - self.iring.resize(igulp_size) - self.oring.resize(ogulp_size) # , obuf_size) + self.iring.resize(igulp_size, buffer_factor= 8) + self.oring.resize(ogulp_size, buffer_factor= 128) # , obuf_size) ohdr = ihdr.copy() ohdr["nchan"] = self.nchan_out @@ -764,8 +771,8 @@ def main(self): reserve_time = curr_time - prev_time prev_time = curr_time - idata = ispan.data_view(numpy.uint8).reshape(ishape) - odata = ospan.data_view(numpy.uint8).reshape(oshape) + idata = ispan.data_view(np.uint8).reshape(ishape) + odata = ospan.data_view(np.uint8).reshape(oshape) sdata = idata[:, :self.nchan_out, :, :] if self.npol_out != npol: @@ -796,7 +803,7 @@ def __init__( log, iring, oring, - antennas, + station, grid_size, grid_resolution, ntime_gulp=2500, @@ -815,12 +822,12 @@ def __init__( self.ntime_gulp = ntime_gulp self.accumulation_time = accumulation_time - self.antennas = antennas - locations = numpy.empty(shape=(0, 3)) - for ant in self.antennas: - locations = numpy.vstack((locations, [ant.stand[0], ant.stand[1], ant.stand[2]])) - locations = numpy.delete(locations, list(range(0, locations.shape[0], 2)), axis=0) - locations[255, :] = 0.0 + self.station = station + locations = np.array([(ant.stand.x, ant.stand.y, ant.stand.z) for ant in self.station.antennas[::2]]) + if self.station == lwasv: + locations[[i for i, a in enumerate(self.station.antennas[::2]) if a.stand.id == 256], :] = 0.0 + elif self.station == lwa1: + locations[[i for i, a in enumerate(self.station.antennas[::2]) if a.stand.id in (35, 257, 258, 259, 260)], :] = 0.0 self.locations = locations self.grid_size = grid_size @@ -869,7 +876,7 @@ def main(self): accum = 0 with self.oring.begin_writing() as oring: for iseq in self.iring.read(guarantee=True): - ihdr = json.loads(iseq.header.tostring()) + ihdr = json.loads(iseq.header.tobytes()) self.sequence_proclog.update(ihdr) self.log.info("MOFFCorrelatorOp: Config - %s" % ihdr) chan0 = ihdr["chan0"] @@ -882,7 +889,7 @@ def main(self): igulp_size = self.ntime_gulp * nchan * nstand * npol * 1 # ci4 itshape = (self.ntime_gulp, nchan, nstand, npol) - freq = (chan0 + numpy.arange(nchan)) * CHAN_BW + freq = (chan0 + np.arange(nchan)) * CHAN_BW sampling_length, locs, sll = GenerateLocations( self.locations, freq, @@ -893,15 +900,15 @@ def main(self): grid_resolution=self.grid_resolution, ) try: - copy_array(self.locs, bifrost.ndarray(locs.astype(numpy.int32))) + copy_array(self.locs, bifrost.ndarray(locs.astype(np.int32))) except AttributeError: - self.locs = bifrost.ndarray(locs.astype(numpy.int32), space="cuda") + self.locs = bifrost.ndarray(locs.astype(np.int32), space="cuda") ohdr = ihdr.copy() ohdr["nbit"] = 64 ms_per_gulp = 1e3 * self.ntime_gulp / CHAN_BW - new_accumulation_time = numpy.ceil(self.accumulation_time / ms_per_gulp) * ms_per_gulp + new_accumulation_time = np.ceil(self.accumulation_time / ms_per_gulp) * ms_per_gulp if new_accumulation_time != self.accumulation_time: self.log.warning( "Adjusting accumulation time from %.3f ms to %.3f ms", @@ -918,9 +925,9 @@ def main(self): ohdr["sampling_length_y"] = sampling_length ohdr["accumulation_time"] = self.accumulation_time ohdr["FS"] = FS - ohdr["latitude"] = lwasv.lat * 180. / numpy.pi - ohdr["longitude"] = lwasv.lon * 180. / numpy.pi - ohdr["telescope"] = "LWA-SV" + ohdr["latitude"] = self.station.lat * 180. / np.pi + ohdr["longitude"] = self.station.lon * 180. / np.pi + ohdr["telescope"] = self.station.name.upper() ohdr["data_units"] = "UNCALIB" if ohdr["npol"] == 1: ohdr["pols"] = ["xx"] @@ -937,29 +944,35 @@ def main(self): # Setup the kernels to include phasing terms for zenith # Phases are Ntime x Nchan x Nstand x Npol x extent x extent freq.shape += (1, 1) - phases = numpy.zeros( + phases = np.zeros( (self.ntime_gulp, nchan, nstand, npol, self.ant_extent, self.ant_extent), - dtype=numpy.complex64 + dtype=np.complex64 ) for i in range(nstand): # X - a = self.antennas[2 * i + 0] + a = self.station.antennas[2 * i + 0] delay = a.cable.delay(freq) - a.stand.z / speed_of_light.value - phases[:, :, i, 0, :, :] = numpy.exp(2j * numpy.pi * freq * delay) - phases[:, :, i, 0, :, :] /= numpy.sqrt(a.cable.gain(freq)) + phases[:, :, i, 0, :, :] = np.exp(2j * np.pi * freq * delay) + phases[:, :, i, 0, :, :] /= np.sqrt(a.cable.gain(freq)) if npol == 2: # Y - a = self.antennas[2 * i + 1] + a = self.station.antennas[2 * i + 1] delay = a.cable.delay(freq) - a.stand.z / speed_of_light.value - phases[:, :, i, 1, :, :] = numpy.exp(2j * numpy.pi * freq * delay) - phases[:, :, i, 1, :, :] /= numpy.sqrt(a.cable.gain(freq)) + phases[:, :, i, 1, :, :] = np.exp(2j * np.pi * freq * delay) + phases[:, :, i, 1, :, :] /= np.sqrt(a.cable.gain(freq)) # Explicit bad and suspect antenna masking - this will # mask an entire stand if either pol is bad - if self.antennas[2 * i + 0].combined_status < 33 or self.antennas[2 * i + 1].combined_status < 33: + if ( + self.station.antennas[2 * i + 0].combined_status < 33 + or self.station.antennas[2 * i + 1].combined_status < 33 + ): phases[:, :, i, :, :, :] = 0.0 # Explicit outrigger masking - we probably want to do # away with this at some point - if a.stand.id == 256: + if ( + (self.station == lwasv and a.stand.id == 256) + or (self.station == lwa1 and a.stand.id in (35, 257, 258, 259, 260)) + ): phases[:, :, i, :, :, :] = 0.0 phases = phases.conj() phases = bifrost.ndarray(phases) @@ -970,8 +983,8 @@ def main(self): oshape = (1, nchan, npol ** 2, self.grid_size, self.grid_size) ogulp_size = nchan * npol ** 2 * self.grid_size * self.grid_size * 8 - self.iring.resize(igulp_size) - self.oring.resize(ogulp_size, buffer_factor=5) + self.iring.resize(igulp_size, buffer_factor=128) + self.oring.resize(ogulp_size, buffer_factor=256) prev_time = time.time() with oring.begin_sequence(time_tag=iseq.time_tag, header=ohdr_str) as oseq: iseq_spans = iseq.read(igulp_size) @@ -993,7 +1006,7 @@ def main(self): # Correlator # Setup and load - idata = ispan.data_view(numpy.uint8).reshape(itshape) + idata = ispan.data_view(np.uint8).reshape(itshape) # Fix the type udata = bifrost.ndarray( shape=itshape, @@ -1020,25 +1033,31 @@ def main(self): except NameError: gdata = bifrost.zeros( shape=(self.ntime_gulp, nchan, npol, self.grid_size, self.grid_size), - dtype=numpy.complex64, + dtype=np.complex64, space="cuda", ) # Grid the Antennas if self.benchmark is True: timeg1 = time.time() - try: - bf_romein.execute(udata, gdata) + bf_vgrid.execute(udata, gdata) except NameError: - bf_romein = Romein() - bf_romein.init(self.locs, gphases, self.grid_size, polmajor=False) - bf_romein.execute(udata, gdata) + bf_vgrid = VGrid() + bf_vgrid.init(self.locs, gphases, self.grid_size, polmajor=False) + bf_vgrid.execute(udata, gdata) + + #try: + # bf_romein.execute(udata, gdata) + #except NameError: + # bf_romein = Romein() + # bf_romein.init(self.locs, gphases, self.grid_size, polmajor=False) + # bf_romein.execute(udata, gdata) gdata = gdata.reshape(self.ntime_gulp * nchan * npol, self.grid_size, self.grid_size) # gdata = self.LinAlgObj.matmul(1.0, udata, bfantgridmap, 0.0, gdata) if self.benchmark is True: timeg2 = time.time() - print(" Romein time: %f" % (timeg2 - timeg1)) + print(" Grid time: %f" % (timeg2 - timeg1)) # Inverse transform @@ -1074,12 +1093,12 @@ def main(self): except NameError: crosspol = bifrost.zeros( shape=(self.ntime_gulp, nchan, npol ** 2, self.grid_size, self.grid_size), - dtype=numpy.complex64, + dtype=np.complex64, space="cuda", ) accumulated_image = bifrost.zeros( shape=(1, nchan, npol ** 2, self.grid_size, self.grid_size), - dtype=numpy.complex64, + dtype=np.complex64, space="cuda", ) self.newflag = False @@ -1098,48 +1117,64 @@ def main(self): except NameError: autocorrs = bifrost.ndarray( shape=(self.ntime_gulp, nchan, nstand, npol ** 2), - dtype=numpy.complex64, + dtype=np.complex64, space="cuda", ) autocorrs_av = bifrost.zeros( shape=(1, nchan, nstand, npol ** 2), - dtype=numpy.complex64, + dtype=np.complex64, space="cuda", ) autocorr_g = bifrost.zeros( shape=(1, nchan, npol ** 2, self.grid_size, self.grid_size), - dtype=numpy.complex64, + dtype=np.complex64, space="cuda", ) autocorr_lo = bifrost.ndarray( - numpy.ones( + np.ones( shape=(3, 1, nchan, nstand, npol ** 2), - dtype=numpy.int32 + dtype=np.int32 ) * self.grid_size / 2, space="cuda", ) autocorr_il = bifrost.ndarray( - numpy.ones( + np.ones( shape=(1, nchan, nstand, npol ** 2, self.ant_extent, self.ant_extent), - dtype=numpy.complex64 + dtype=np.complex64 ), space="cuda", ) # Cross multiply to calculate autocorrs - bifrost.map( - "a(i,j,k,l) += (b(i,j,k,l/2) * b(i,j,k,l%2).conj())", - {"a": autocorrs, "b": udata, "t": self.ntime_gulp}, - axis_names=("i", "j", "k", "l"), - shape=(self.ntime_gulp, nchan, nstand, npol ** 2), + #bifrost.map( + # "a(i,j,k,l) += (b(i,j,k,l/2) * b(i,j,k,l%2).conj())", + # {"a": autocorrs, "b": udata, "t": self.ntime_gulp}, + # axis_names=("i", "j", "k", "l"), + # shape=(self.ntime_gulp, nchan, nstand, npol ** 2), + #) + try: + bf_auto.execute(udata, autocorrs) + except NameError: + bf_auto = aCorr() + bf_auto.init(self.locs, polmajor=False) + bf_auto.execute(udata, autocorrs) + autocorrs = autocorrs.reshape( + self.ntime_gulp, nchan, nstand, npol ** 2 ) - bifrost.map( - "a(i,j,p,k,l) += b(0,i,j,p/2,k,l)*b(0,i,j,p%2,k,l).conj()", - {"a": crosspol, "b": gdata}, - axis_names=("i", "j", "p", "k", "l"), - shape=(self.ntime_gulp, nchan, npol ** 2, self.grid_size, self.grid_size), - ) + + #bifrost.map( + # "a(i,j,p,k,l) += b(0,i,j,p/2,k,l)*b(0,i,j,p%2,k,l).conj()", + # {"a": crosspol, "b": gdata}, + # axis_names=("i", "j", "p", "k", "l"), + # shape=(self.ntime_gulp, nchan, npol ** 2, self.grid_size, self.grid_size), + #) + try: + bf_gmul.execute(gdata, crosspol) + except NameError: + bf_gmul = XGrid() + bf_gmul.init(self.grid_size, polmajor=False) + bf_gmul.execute(gdata, crosspol) crosspol = crosspol.reshape( self.ntime_gulp, nchan, npol ** 2, self.grid_size, self.grid_size ) @@ -1160,14 +1195,22 @@ def main(self): autocorr_g = autocorr_g.reshape( 1, nchan, npol ** 2, self.grid_size, self.grid_size ) + #try: + # bf_romein_autocorr.execute(autocorrs_av, autocorr_g) + #except NameError: + # bf_romein_autocorr = Romein() + # bf_romein_autocorr.init( + # autocorr_lo, autocorr_il, self.grid_size, polmajor=False + # ) + # bf_romein_autocorr.execute(autocorrs_av, autocorr_g) try: - bf_romein_autocorr.execute(autocorrs_av, autocorr_g) + bf_vgrid_autocorr.execute(autocorrs_av, autocorr_g) except NameError: - bf_romein_autocorr = Romein() - bf_romein_autocorr.init( + bf_vgrid_autocorr = VGrid() + bf_vgrid_autocorr.init( autocorr_lo, autocorr_il, self.grid_size, polmajor=False ) - bf_romein_autocorr.execute(autocorrs_av, autocorr_g) + bf_vgrid_autocorr.execute(autocorrs_av, autocorr_g) autocorr_g = autocorr_g.reshape(1 * nchan * npol ** 2, self.grid_size, self.grid_size) # autocorr_g = romein_float(autocorrs_av,autocorr_g,autocorr_il,autocorr_lx,autocorr_ly,autocorr_lz,self.ant_extent,self.grid_size,nstand,nchan*npol**2) # Inverse FFT @@ -1192,7 +1235,7 @@ def main(self): prev_time = curr_time with oseq.reserve(ogulp_size) as ospan: - odata = ospan.data_view(numpy.complex64).reshape(oshape) + odata = ospan.data_view(np.complex64).reshape(oshape) accumulated_image = accumulated_image.reshape(oshape) odata[...] = accumulated_image bifrost.device.stream_synchronize() @@ -1258,7 +1301,7 @@ def __init__( log, iring, oring, - antennas, + station, skymodes=64, ntime_gulp=2500, accumulation_time=10000, @@ -1276,12 +1319,12 @@ def __init__( self.accumulation_time = accumulation_time # Setup Antennas - self.antennas = antennas - locations = numpy.empty(shape=(0, 3)) - for ant in self.antennas: - locations = numpy.vstack((locations, [ant.stand[0], ant.stand[1], ant.stand[2]])) - locations = numpy.delete(locations, list(range(0, locations.shape[0], 2)), axis=0) - # locations[255,:] = 0.0 + self.station = station + locations = np.array([(ant.stand.x, ant.stand.y, ant.stand.z) for ant in self.station.antennas[::2]]) + #if self.station == lwasv: + # locations[[i for i, a in enumerate(self.station.antennas[::2]) if a.stand.id == 256], :] = 0.0 + #elif self.station == lwa1: + # locations[[i for i, a in enumerate(self.station.antennas[::2]) if a.stand.id in (35, 257, 258, 259, 260)], :] = 0.0 self.locations = locations # LinAlg @@ -1334,7 +1377,7 @@ def main(self): accum = 0 with self.oring.begin_writing() as oring: for iseq in self.iring.read(guarantee=True): - ihdr = json.loads(iseq.header.tostring()) + ihdr = json.loads(iseq.header.tobytes()) self.sequence_proclog.update(ihdr) self.log.info("MOFFCorrelatorOp: Config - %s" % ihdr) chan0 = ihdr["chan0"] @@ -1348,21 +1391,21 @@ def main(self): itshape = (self.ntime_gulp, nchan, nstand, npol) # Sample locations at right u/v/w values - freq = (chan0 + numpy.arange(nchan)) * CHAN_BW + freq = (chan0 + np.arange(nchan)) * CHAN_BW locs = Generate_DFT_Locations( self.locations, freq, self.ntime_gulp, nchan, npol ) try: - copy_array(self.locs, bifrost.ndarray(locs.astype(numpy.int32))) + copy_array(self.locs, bifrost.ndarray(locs.astype(np.int32))) except AttributeError: - self.locs = bifrost.ndarray(locs.astype(numpy.int32), space="cuda") + self.locs = bifrost.ndarray(locs.astype(np.int32), space="cuda") ohdr = ihdr.copy() ohdr["nbit"] = 64 ms_per_gulp = 1e3 * self.ntime_gulp / CHAN_BW - new_accumulation_time = numpy.ceil(self.accumulation_time / ms_per_gulp) * ms_per_gulp + new_accumulation_time = np.ceil(self.accumulation_time / ms_per_gulp) * ms_per_gulp if new_accumulation_time != self.accumulation_time: self.log.warning("Adjusting accumulation time from %.3f ms to %.3f ms", self.accumulation_time, new_accumulation_time) @@ -1375,9 +1418,9 @@ def main(self): ohdr["axes"] = "time,chan,pol,gridy,gridx" ohdr["accumulation_time"] = self.accumulation_time ohdr["FS"] = FS - ohdr["latitude"] = lwasv.lat * 180. / numpy.pi - ohdr["longitude"] = lwasv.lon * 180. / numpy.pi - ohdr["telescope"] = "LWA-SV" + ohdr["latitude"] = self.station.lat * 180. / np.pi + ohdr["longitude"] = self.station.lon * 180. / np.pi + ohdr["telescope"] = self.station.name.upper() ohdr["data_units"] = "UNCALIB" if ohdr["npol"] == 1: ohdr["pols"] = ["xx"] @@ -1393,42 +1436,51 @@ def main(self): # Phases are Nchan x Nstand x Npol # freq.shape += (1,) - phases = numpy.zeros((nchan, nstand, npol), dtype=numpy.complex64) + phases = np.zeros((nchan, nstand, npol), dtype=np.complex64) for i in range(nstand): # X - a = self.antennas[2 * i + 0] + a = self.station.antennas[2 * i + 0] delay = a.cable.delay(freq) - a.stand.z / speed_of_light.value - phases[:, i, 0] = numpy.exp(2j * numpy.pi * freq * delay) - phases[:, i, 0] /= numpy.sqrt(a.cable.gain(freq)) + phases[:, i, 0] = np.exp(2j * np.pi * freq * delay) + phases[:, i, 0] /= np.sqrt(a.cable.gain(freq)) if npol == 2: # Y - a = self.antennas[2 * i + 1] + a = self.station.antennas[2 * i + 1] delay = a.cable.delay(freq) - a.stand.z / speed_of_light.value - phases[:, i, 1] = numpy.exp(2j * numpy.pi * freq * delay) - phases[:, i, 1] /= numpy.sqrt(a.cable.gain(freq)) - # Explicit outrigger masking - we probably want to do - # away with this at some point - # if a.stand.id == 256: - # phases[:,i] = 0.0 - # nj() + phases[:, i, 1] = np.exp(2j * np.pi * freq * delay) + phases[:, i, 1] /= np.sqrt(a.cable.gain(freq)) + # Explicit bad and suspect antenna masking - this will + # mask an entire stand if either pol is bad + if ( + self.station.antennas[2 * i + 0].combined_status < 33 + or self.station.antennas[2 * i + 1].combined_status < 33 + ): + phases[:, i, :] = 0.0 + # Explicit outrigger masking - we probably want to do + # away with this at some point + #if ( + # (self.station == lwasv and a.stand.id == 256) + # or (self.station == lwa1 and a.stand.id in (35, 257, 258, 259, 260)) + #): + # phases[:, i, :] = 0.0 phases = bifrost.ndarray(phases) # Setup DFT Transform Matrix - lm_matrix = numpy.zeros(shape=(self.skymodes1d, self.skymodes1d, 3)) + lm_matrix = np.zeros(shape=(self.skymodes1d, self.skymodes1d, 3)) lm_step = 2.0 / self.skymodes1d - i, j = numpy.meshgrid(numpy.arange(self.skymodes1d), numpy.arange(self.skymodes1d)) + i, j = np.meshgrid(np.arange(self.skymodes1d), np.arange(self.skymodes1d)) # this builds a 3 x 64 x 64 matrix, need to transpose axes to [2, 1, 0] to get correct # 64 x 64 x 3 shape - lm_matrix = numpy.asarray([i * lm_step - 1.0, j * lm_step - 1.0, numpy.zeros_like(j)]) - lm_matrix = numpy.fft.fftshift(lm_matrix, axes=(1,2)) + lm_matrix = np.asarray([i * lm_step - 1.0, j * lm_step - 1.0, np.zeros_like(j)]) + lm_matrix = np.fft.fftshift(lm_matrix, axes=(1,2)) lm_vector = lm_matrix.transpose([1, 2, 0]).reshape((self.skymodes, 3)) self.dftm = bifrost.ndarray( form_dft_matrix(lm_vector, locs, phases, nchan, npol, nstand) ) - # self.dftm = bifrost.ndarray(numpy.tile(self.dftm[numpy.newaxis,:],(nchan,1,1,1))) + # self.dftm = bifrost.ndarray(np.tile(self.dftm[np.newaxis,:],(nchan,1,1,1))) dftm_cu = self.dftm.copy(space="cuda") # sys.exit(1) @@ -1457,7 +1509,7 @@ def main(self): # Correlator # Setup and load - idata = ispan.data_view(numpy.uint8).reshape(itshape) + idata = ispan.data_view(np.uint8).reshape(itshape) # Fix the type tdata = bifrost.ndarray(shape=itshape, dtype="ci4", native=False, buffer=idata.ctypes.data) @@ -1474,7 +1526,7 @@ def main(self): udata = udata.reshape(*tdata.shape) Unpack(tdata, udata) except NameError: - udata = bifrost.ndarray(shape=tdata.shape, dtype=numpy.complex64, space="cuda") + udata = bifrost.ndarray(shape=tdata.shape, dtype=np.complex64, space="cuda") Unpack(tdata, udata) # Phase # bifrost.map('a(i,j,k,l) *= b(j,k,l)', @@ -1490,7 +1542,7 @@ def main(self): except NameError: gdata = bifrost.zeros( shape=(nchan * npol, self.skymodes, self.ntime_gulp), - dtype=numpy.complex64, + dtype=np.complex64, space="cuda" ) memset_array(gdata, 0) @@ -1509,12 +1561,12 @@ def main(self): except NameError: gdatas = bifrost.zeros( shape=(nchan, npol ** 2, self.skymodes, self.ntime_gulp), - dtype=numpy.complex64, + dtype=np.complex64, space="cuda" ) accumulated_image = bifrost.zeros( shape=(nchan, npol ** 2, self.skymodes, 1), - dtype=numpy.complex64, + dtype=np.complex64, space="cuda" ) self.newflag = False @@ -1538,7 +1590,7 @@ def main(self): prev_time = curr_time with oseq.reserve(ogulp_size) as ospan: - odata = ospan.data_view(numpy.complex64).reshape(oshape) + odata = ospan.data_view(np.complex64).reshape(oshape) accumulated_image = accumulated_image.reshape(oshape) # gdatass = gdatass.reshape(oshape) # odata[...] = gdatass @@ -1603,7 +1655,7 @@ def __init__( self.iring = iring self.ints_per_file = ints_per_analysis self.threshold = threshold - self.elevation_limit = elevation_limit * numpy.pi / 180.0 + self.elevation_limit = elevation_limit * np.pi / 180.0 self.core = core self.gpu = gpu @@ -1641,7 +1693,7 @@ def main(self): ) for iseq in self.iring.read(guarantee=True): - ihdr = json.loads(iseq.header.tostring()) + ihdr = json.loads(iseq.header.tobytes()) fileid = 0 self.sequence_proclog.update(ihdr) @@ -1660,11 +1712,11 @@ def main(self): % (nchan, npol, grid_size, sampling_length) ) - x, y = numpy.arange(grid_size_x), numpy.arange(grid_size_y) - x, y = numpy.meshgrid(x, y) - rho = numpy.sqrt((x - grid_size_x / 2) ** 2 + (y - grid_size_y / 2) ** 2) - mask = numpy.where( - rho <= grid_size * sampling_length * numpy.cos(self.elevation_limit), + x, y = np.arange(grid_size_x), np.arange(grid_size_y) + x, y = np.meshgrid(x, y) + rho = np.sqrt((x - grid_size_x / 2) ** 2 + (y - grid_size_y / 2) ** 2) + mask = np.where( + rho <= grid_size * sampling_length * np.cos(self.elevation_limit), False, True ) @@ -1685,12 +1737,12 @@ def main(self): acquire_time = curr_time - prev_time prev_time = curr_time - idata = ispan.data_view(numpy.complex64).reshape(ishape) + idata = ispan.data_view(np.complex64).reshape(ishape) itemp = idata.copy(space="cuda_host") image.append(itemp) nints += 1 if nints >= self.ints_per_file: - image = numpy.fft.fftshift(image, axes=(3, 4)) + image = np.fft.fftshift(image, axes=(3, 4)) image = image[:, :, :, ::-1, :] # NOTE: This just uses the first polarization (XX) for now. # In the future we probably want to use Stokes I (if @@ -1706,8 +1758,8 @@ def main(self): # current image (image) with a moving average of the last # N images (image_background). This is roughly like what # is done at LWA1/LWA-SV to find events in the LASI images. - image_background = numpy.median(image_history, axis=0) - image_diff = numpy.ma.array(image - image_background, mask=mask) + image_background = np.median(image_history, axis=0) + image_diff = np.ma.array(image - image_background, mask=mask) peak, mid, rms = image_diff.max(), image_diff.mean(), image_diff.std() print("-->", peak, mid, rms, "@", (peak - mid) / rms) if (peak - mid) > self.threshold * rms: @@ -1771,6 +1823,7 @@ def main(self): MAX_HISTORY = 5 + if self.core != -1: bifrost.affinity.set_core(self.core) if self.gpu != -1: @@ -1787,7 +1840,7 @@ def main(self): image_history = deque([], MAX_HISTORY) for iseq in self.iring.read(guarantee=True): - ihdr = json.loads(iseq.header.tostring()) + ihdr = json.loads(iseq.header.tobytes()) fileid = 0 self.sequence_proclog.update(ihdr) @@ -1814,6 +1867,51 @@ def main(self): dump_counter = 0 + # some constant header information + primary_hdu = fits.PrimaryHDU() + + primary_hdu.header["TELESCOP"] = ihdr["telescope"] + # grab the time from the 0th file for the primary header + # before dumping + primary_hdu.header["DATE-OBS"] = Time( + ihdr["time_tag"] / ihdr["FS"] + 1e-3 * ihdr["accumulation_time"] / 2.0, + format="unix", + precision=6, + ).isot + + primary_hdu.header["BUNIT"] = ihdr["data_units"] + primary_hdu.header["BSCALE"] = 1e0 + primary_hdu.header["BZERO"] = 0e0 + primary_hdu.header["EQUINOX"] = "J2000" + primary_hdu.header["EXTNAME"] = "PRIMARY" + primary_hdu.header["GRIDDIMX"] = ihdr["grid_size_x"] + primary_hdu.header["GRIDDIMY"] = ihdr["grid_size_y"] + primary_hdu.header["DGRIDX"] = ihdr["sampling_length_x"] + primary_hdu.header["DGRIDY"] = ihdr["sampling_length_y"] + primary_hdu.header["INTTIM"] = ihdr["accumulation_time"] * 1e-3 + primary_hdu.header["INTTIMU"] = "SECONDS" + primary_hdu.header["CFREQ"] = ihdr["cfreq"] + primary_hdu.header["CFREQU"] = "HZ" + + pol_dict = {"xx": -5, "yy": -6, "xy": -7, "yx": -8} + pol_nums = [pol_dict[p] for p in ihdr["pols"]] + pol_order = np.argsort(pol_nums)[::-1] + + dt = TimeDelta(1e-3 * ihdr["accumulation_time"], format="sec") + + dtheta_x = 2 * np.arcsin(0.5 / (ihdr["grid_size_x"] * ihdr["sampling_length_x"])) + dtheta_y = 2 * np.arcsin(0.5 / (ihdr["grid_size_y"] * ihdr["sampling_length_y"])) + + crit_pix_x = float(ihdr["grid_size_x"] / 2 + 1) + # Need to correct for shift in center pixel when we flipped dec dimension + # when writing npz, Only applies for even dimension size + crit_pix_y = float(ihdr["grid_size_y"] / 2 + 1) - (ihdr["grid_size_x"] + 1) % 2 + + delta_x = -dtheta_x * 180.0 / np.pi + delta_y = dtheta_y * 180.0 / np.pi + delta_f = ihdr["bw"] / ihdr["nchan"] + crit_pix_f = (ihdr["nchan"] - 1) * 0.5 + 1 # +1 for FITS numbering + if self.profile: spani = 0 @@ -1824,29 +1922,109 @@ def main(self): acquire_time = curr_time - prev_time prev_time = curr_time - idata = ispan.data_view(numpy.complex64).reshape(ishape) + idata = ispan.data_view(np.complex64).reshape(ishape) itemp = idata.copy(space="cuda_host") image.append(itemp) nints += 1 + if nints >= self.ints_per_file: - image = numpy.fft.fftshift(image, axes=(3, 4)) + image = np.fft.fftshift(image, axes=(3, 4)) image = image[:, :, :, ::-1, :] + + # Restructure data in preparation to stuff into fits + # Now (Ntimes, Npol, Nfreq, y, x) + image = image.transpose(0, 2, 1, 3, 4) + + # Reorder pol for fits convention + image = image[:, pol_order, :, :, :] + # Break up real/imaginary + image = image[ + :, np.newaxis, :, :, :, : + ] # Now (Ntimes, 2 (complex), Npol, Nfreq, y, x) + image = np.concatenate([image.real, image.imag], axis=1) + unix_time = ( ihdr["time_tag"] / FS + ihdr["accumulation_time"] * 1e-3 * fileid * self.ints_per_file ) - image_nums = numpy.arange(fileid * self.ints_per_file, (fileid + 1) * self.ints_per_file) - filename = os.path.join(self.out_dir, "EPIC_{0:3f}_{1:0.3f}MHz.npz".format(unix_time, cfreq / 1e6)) - image_history.append((filename, image, ihdr, image_nums)) + t0 = Time( + unix_time, + format="unix", + precision=6, + location=(ihdr["longitude"], ihdr["latitude"]) + ) + + time_array = t0 + np.arange(nints) * dt + + lsts = time_array.sidereal_time("apparent") + coords = SkyCoord( + lsts.deg, ihdr["latitude"], obstime=time_array, unit="deg" + ).transform_to(FK5(equinox="J2000")) + + hdul = [] + for im_num, d in enumerate(image): + hdu = fits.ImageHDU(data=d) + # Time + t = time_array[im_num] + lst = lsts[im_num] + hdu.header["DATETIME"] = t.isot + hdu.header["LST"] = lst.hour + # Coordinates - sky + + hdu.header["EQUINOX"] = "J2000" + + hdu.header["CTYPE1"] = "RA---SIN" + hdu.header["CRPIX1"] = crit_pix_x + hdu.header["CDELT1"] = delta_x + hdu.header["CRVAL1"] = coords[im_num].ra.deg + hdu.header["CUNIT1"] = "deg" + hdu.header["CTYPE2"] = "DEC--SIN" + + hdu.header["CRPIX2"] = crit_pix_y + + hdu.header["CDELT2"] = delta_y + hdu.header["CRVAL2"] = coords[im_num].dec.deg + hdu.header["CUNIT2"] = "deg" + # Coordinates - Freq + hdu.header["CTYPE3"] = "FREQ" + hdu.header["CRPIX3"] = crit_pix_f + hdu.header["CDELT3"] = delta_f + hdu.header["CRVAL3"] = ihdr["cfreq"] + hdu.header["CUNIT3"] = "Hz" + # Coordinates - Stokes parameters + hdu.header["CTYPE4"] = "STOKES" + hdu.header["CRPIX4"] = 1 + hdu.header["CDELT4"] = -1 + hdu.header["CRVAL4"] = pol_nums[pol_order[0]] + # Coordinates - Complex + hdu.header["CTYPE5"] = "COMPLEX" + hdu.header["CRVAL5"] = 1.0 + hdu.header["CRPIX5"] = 1.0 + hdu.header["CDELT5"] = 1.0 + + hdul.append(hdu) + + filename = os.path.join( + self.out_dir, + "EPIC_{0:3f}_{1:0.3f}MHz.fits".format(unix_time, cfreq / 1e6), + ) + + image_history.append((filename, hdul)) if TRIGGER_ACTIVE.is_set() or not self.triggering: if dump_counter == 0: dump_counter = 20 + MAX_HISTORY elif dump_counter == 1: TRIGGER_ACTIVE.clear() - cfilename, cimage, chdr, cimage_nums = image_history.popleft() - numpy.savez(cfilename, image=cimage, hdr=chdr, image_nums=cimage_nums) + + + + cfilename, hdus = image_history.popleft() + hdulist = fits.HDUList([primary_hdu, *hdus]) + hdulist.writeto(cfilename, overwrite=True) + + # np.savez(cfilename, image=cimage, hdr=chdr, image_nums=cimage_nums) print("SaveOp - Image Saved") dump_counter -= 1 @@ -1883,7 +2061,7 @@ def main(self): for iseq in self.iring.read(guarantee=True): - ihdr = json.loads(iseq.header.tostring()) + ihdr = json.loads(iseq.header.tobytes()) nchan = ihdr["nchan"] nstand = ihdr["nstand"] npol = ihdr["npol"] @@ -1899,21 +2077,17 @@ def main(self): if ispan.size < igulp_size: continue - idata = ispan.data_view(numpy.int8) + idata = ispan.data_view(np.int8) idata = idata.reshape(ishape) idata = bifrost.ndarray(shape=ishape, dtype="ci4", native=False, buffer=idata.ctypes.data) - print(numpy.shape(idata)) - numpy.savez(self.filename + "asdasd.npy", data=idata) + print(np.shape(idata)) + np.savez(self.filename + "asdasd.npy", data=idata) print("Wrote to disk") break print("Save F-Engine Spectra.. done") - -def main(): - - # Main Input: UDP Broadcast RX from F-Engine? - +def gen_args(return_parser=False): parser = argparse.ArgumentParser( description="EPIC Correlator", formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -1929,7 +2103,7 @@ def main(): group1.add_argument( "--utcstart", type=str, - default="1970_1_1T0_0_0", + default=None, help="F-Engine UDP Stream Start Time", ) @@ -1938,6 +2112,9 @@ def main(): "--offline", action="store_true", help="Load TBN data from Disk" ) group2.add_argument("--tbnfile", type=str, help="TBN Data Path") + group2.add_argument( + "--lwa1", action="store_true", help="TBN data is from LWA1, not LWA-SV" + ) group2.add_argument("--tbffile", type=str, help="TBF Data Path") group3 = parser.add_argument_group("Processing Options") @@ -1954,6 +2131,13 @@ def main(): default=1000, help="How many milliseconds to accumulate an image over", ) + group3.add_argument( + "--duration", + type=int, + default=3600, + help="Duration of EPIC (seconds)", + ) + group4 = parser.add_argument_group("Correlation Options") group4.add_argument( @@ -2023,10 +2207,17 @@ def main(): print("Output directory does not exist. Defaulting to current directory.") args.out_dir = "." - if args.removeautocorrs: - raise NotImplementedError( - "Removing autocorrelations is not yet properly implemented." - ) + #if args.removeautocorrs: + # raise NotImplementedError( + # "Removing autocorrelations is not yet properly implemented." + # ) + + if return_parser: + return args, parser + else: + return args + +def main(args, parser): log = logging.getLogger(__name__) logFormat = logging.Formatter( @@ -2040,8 +2231,8 @@ def main(): log.setLevel(logging.DEBUG) # Setup the cores and GPUs to use - cores = [0, 2, 3, 4, 5, 6, 7] - gpus = [0, 0, 0, 0, 0, 0, 0] + cores = [3, 4, 5, 6, 7] + gpus = [0, 0, 0, 0, 0] # Setup the signal handling ops = [] @@ -2056,6 +2247,12 @@ def handle_signal_terminate(signum, frame): log.warning("Received signal %i %s", signum, SIGNAL_NAMES[signum]) try: ops[0].shutdown() + if SIGNAL_NAMES[signum] == "SIGINT": + print("****Observation is Interrupted****") + os._exit(0) + if SIGNAL_NAMES[signum] == "SIGALRM": + print("****Observation is Complete****") + os._exit(0) except IndexError: pass shutdown_event.set() @@ -2066,6 +2263,7 @@ def handle_signal_terminate(signum, frame): signal.SIGQUIT, signal.SIGTERM, signal.SIGTSTP, + signal.SIGALRM ]: signal.signal(sig, handle_signal_terminate) @@ -2075,10 +2273,11 @@ def handle_signal_terminate(signum, frame): fdomain_ring = Ring(name="fengine", space="cuda_host") gridandfft_ring = Ring(name="gridandfft", space="cuda") - # Setup Antennas - # TODO: Some sort of switch for other stations? + # Setup the station - lwasv_antennas = lwasv.antennas + lwa_station = lwasv + if args.lwa1: + lwa_station = lwa1 # Setup threads @@ -2132,8 +2331,10 @@ def handle_signal_terminate(signum, frame): "--offline set but no file provided via --tbnfile or --tbffile" ) else: - # It would be great is we could pull this from ADP MCS... - utc_start_dt = datetime.datetime.strptime(args.utcstart, "%Y_%m_%dT%H_%M_%S") + if args.utcstart is None: + utc_start_dt = get_utc_start() + else: + utc_start_dt = datetime.datetime.strptime(args.utcstart, DATE_FORMAT) # Note: Capture uses Bifrost address+socket objects, while output uses # plain Python address+socket objects. @@ -2175,7 +2376,7 @@ def handle_signal_terminate(signum, frame): log, fdomain_ring, gridandfft_ring, - lwasv_antennas, + lwa_station, skymodes=args.dft_skymodes_1D, ntime_gulp=args.nts, accumulation_time=args.accumulate, @@ -2191,7 +2392,7 @@ def handle_signal_terminate(signum, frame): log, fdomain_ring, gridandfft_ring, - lwasv_antennas, + lwa_station, args.imagesize, args.imageres, ntime_gulp=args.nts, @@ -2238,6 +2439,9 @@ def handle_signal_terminate(signum, frame): for thread in threads: thread.daemon = False thread.start() + + signal.alarm(args.duration) + while not shutdown_event.is_set(): # Keep threads alive -- if reader is still alive, prevent timeout signal from executing @@ -2260,4 +2464,5 @@ def handle_signal_terminate(signum, frame): if __name__ == "__main__": - main() + args, parser = gen_args(return_parser=True) + main(args, parser) diff --git a/LWA/MCS2.py b/LWA/MCS2.py new file mode 100644 index 00000000..b0878ba1 --- /dev/null +++ b/LWA/MCS2.py @@ -0,0 +1,535 @@ +# -*- coding: utf-8 -*- + +from __future__ import print_function, division, absolute_import +try: + range = xrange + def data_to_hex(data): + return data.encode('hex') +except NameError: + def data_to_hex(data): + try: + return data.hex() + except TypeError: + return data.encode().hex() + +try: + import queue +except ImportError: + import Queue as queue +import time +from datetime import datetime +import socket +from ConsumerThread import ConsumerThread +from SocketThread import UDPRecvThread +import string +import struct + +import socket +from threading import Thread, Event, Semaphore + +# Maximum number of bytes to receive from MCS +MCS_RCV_BYTES = 16*1024 + +# Note: Unless otherwise noted, slots are referenced to Unix time +def get_current_slot(): + # Returns current slot in Unix time + return int(time.time()) +def get_current_mpm(): + # Returns current milliseconds past midnight as an integer + dt = datetime.utcnow() + ms = int(dt.microsecond / 1000.) + return ((dt.hour*60 + dt.minute)*60 + dt.second)*1000 + ms +def slot2utc(slot=None): + if slot is None: + slot = get_current_slot() + return time.gmtime(slot) +# TODO: What is 'station time'? Is it UTC? +def slot2dayslot(slot=None): + utc = slot2utc(slot) + dayslot = (utc.tm_hour*60 + utc.tm_min)*60 + utc.tm_sec + return dayslot +def slot2mpm(slot=None): + return slot2dayslot(slot) * 1000 +def slot2mjd(slot=None): + tt = slot2utc(slot) + # Source: SkyField + janfeb = tt.tm_mon < 3 + jd_day = tt.tm_mday + jd_day += 1461 * (tt.tm_year + 4800 - janfeb) // 4 + jd_day += 367 * (tt.tm_mon - 2 + janfeb * 12) // 12 + jd_day -= 3 * ((tt.tm_year + 4900 - janfeb) // 100) // 4 + jd_day -= 32075 + mjd = tt.tm_sec + mjd = mjd*(1./60) + tt.tm_min + mjd = mjd*(1./60) + tt.tm_hour + mjd = mjd*(1./24) + (jd_day - 2400000.5) + mjd -= 0.5 + return mjd + +def mib_parse_label(data): + """Splits an MIB label into a list of arguments + E.g., "ANT71_TEMP_MAX" --> ['ANT', 71, 'TEMP', 'MAX'] + """ + args = [] + arg = '' + mode = None + for c in data: + if c in string.ascii_uppercase: + if mode == 'i': + args.append(int(arg)) + arg = '' + arg += c + mode = 's' + elif c in string.digits: + if mode == 's': + args.append(arg) + arg = '' + arg += c + mode = 'i' + elif c == '_': + if mode is not None: + args.append(int(arg) if mode == 'i' else arg) + arg = '' + mode = None + args.append(int(arg) if mode == 'i' else arg) + key = mib_args2key(args) + return key, args +def mib_args2key(args): + """Merges an MIB label arg list back into a label suitable for + use as a lookup key (all indexes are removed).""" + return '_'.join([arg for arg in args if not isinstance(arg, int)]) + +class Msg(object): + count = 0 + # Note: MsgSender will automatically set src + def __init__(self, illegal_argument=None, + src=None, dst=None, cmd=None, ref=None, data='', dst_ip=None, + pkt=None, src_ip=None): + assert(illegal_argument is None) # Ensure named args only + self.dst = dst + self.src = src + self.cmd = cmd + self.ref = ref + if self.ref is None: + self.ref = Msg.count % 10**9 + Msg.count += 1 + self.mjd = None + self.mpm = None + self.data = data + self.dst_ip = dst_ip + self.slot = None # For convenience, not part of encoded pkt + if pkt is not None: + self.decode(pkt) + self.src_ip = src_ip + def __str__(self): + if self.slot is None: + return ("" % + (self.ref, self.cmd, self.src, self.dst, + self.data, data_to_hex(self.data))) + else: + return (("") % + (self.ref, self.cmd, self.src, self.src_ip, + self.dst, self.data, data_to_hex(self.data), + self.slot)) + def decode(self, pkt): + hdr = pkt[:38] + try: + hdr = hdr.decode() + except Exception as e: + # Python2 catch/binary data catch + print('hdr error:', str(e), '@', hdr) + pass + + self.slot = get_current_slot() + self.dst = hdr[:3] + self.src = hdr[3:6] + self.cmd = hdr[6:9] + self.ref = int(hdr[9:18]) + datalen = int(hdr[18:22]) + self.mjd = int(hdr[22:28]) + self.mpm = int(hdr[28:37]) + space = hdr[37] + self.data = pkt[38:38+datalen] + # WAR for DATALEN parameter being wrong for BAM commands (FST too?) + broken_commands = ['BAM']#, 'FST'] + if self.cmd in broken_commands: + self.data = pkt[38:] + + def create_reply(self, accept, status, data=''): + msg = Msg(#src=self.dst, + dst=self.src, + cmd=self.cmd, + ref=self.ref, + dst_ip=self.src_ip) + #msg.mjd, msg.mpm = getTime() + response = 'A' if accept else 'R' + msg.data = response + str(status).rjust(7) + try: + msg.data = msg.data.encode() + except AttributeError: + # Python2 catch + pass + try: + data = data.encode() + except AttributeError: + # Python2 catch + pass + msg.data = msg.data+data + return msg + def is_valid(self): + return (self.dst is not None and len(self.dst) <= 3 and + self.src is not None and len(self.src) <= 3 and + self.cmd is not None and len(self.cmd) <= 3 and + self.ref is not None and (0 <= self.ref < 10**9) and + self.mjd is not None and (0 <= self.mjd < 10**6) and + self.mpm is not None and (0 <= self.mpm < 10**9) and + len(self.data) < 10**4) + def encode(self): + self.mjd = int(slot2mjd()) + self.mpm = get_current_mpm() + assert( self.is_valid() ) + pkt = (self.dst.ljust(3) + + self.src.ljust(3) + + self.cmd.ljust(3) + + str(self.ref ).rjust(9) + + str(len(self.data)).rjust(4) + + str(self.mjd ).rjust(6) + + str(self.mpm ).rjust(9) + + ' ') + try: + pkt = pkt.encode() + self.data = self.data.encode() + except (AttributeError, UnicdoeDecodeError): + # Python2 catch + pass + return pkt+self.data + +class MsgReceiver(UDPRecvThread): + def __init__(self, address, subsystem='ALL'): + UDPRecvThread.__init__(self, address) + self.subsystem = subsystem + self.msg_queue = queue.Queue() + self.name = 'MCS.MsgReceiver' + def process(self, pkt, src_ip): + if len(pkt): + msg = Msg(pkt=pkt, src_ip=src_ip) + if ( self.subsystem == 'ALL' or + msg.dst == 'ALL' or + self.subsystem == msg.dst ): + self.msg_queue.put(msg) + def shutdown(self): + self.msg_queue.put(ConsumerThread.STOP) + #print(self.name, "shutdown") + def get(self, timeout=None): + try: + return self.msg_queue.get(True, timeout) + except queue.Empty: + return None + +class MsgSender(ConsumerThread): + def __init__(self, dst_addr, subsystem, + max_attempts=5): + ConsumerThread.__init__(self) + self.subsystem = subsystem + self.max_attempts = max_attempts + self.socket = socket.socket(socket.AF_INET, + socket.SOCK_DGRAM) + #self.socket.connect(address) + self.dst_ip = dst_addr[0] + self.dst_port = dst_addr[1] + self.name = 'MCS.MsgSender' + def process(self, msg): + msg.src = self.subsystem + try: + pkt = msg.encode() + except UnicodeDecodeError: + pkt = msg.data + dst_ip = msg.dst_ip if msg.dst_ip is not None else self.dst_ip + dst_addr = (dst_ip, self.dst_port) + #print("Sending msg to", dst_addr) + for attempt in range(self.max_attempts-1): + try: + #self.socket.send(pkt) + self.socket.sendto(pkt, dst_addr) + except socket.error: + time.sleep(0.001) + else: + return + #self.socket.send(pkt) + self.socket.sendto(pkt, dst_addr) + def shutdown(self): + #print(self.name, "shutdown") + pass + +# Simple interface for communicating with adp-control service +class Communicator(object): + def __init__(self, subsystem='MCS'): + # TODO: Load port numbers etc. from config + #sender = MsgSender(("localhost",1742), subsystem=subsystem) + self.sender = MsgSender(("adp",1742), subsystem=subsystem) + self.receiver = MsgReceiver(("0.0.0.0",1743)) + self.sender.input_queue = queue.Queue() + self.sender.daemon = True + self.receiver.daemon = True + self.sender.start() + self.receiver.start() + def __enter__(self): + return self + def __exit__(self, type, value, tb): + self.sender.request_stop() + self.receiver.request_stop() + self.sender.join() + self.receiver.join() + def _get_reply(self, timeout): + reply = self.receiver.get(timeout=timeout) + if reply is None: + raise RuntimeError("MCS request timed out") + # Parse the data section of the reply + response, status, data = reply.data[:1], reply.data[1:8], reply.data[8:] + try: + response = response.decode() + status = status.decode() + except AttributeError: + # Python2 catch + pass + if response != 'A': + raise ValueError("Message not accepted: response=%r, status=%r, data=%r" % (response, status, data)) + return status, data + def _send(self, msg, timeout): + self.sender.put(msg) + self.status, data = self._get_reply(timeout) + return data + def report(self, data, fmt=None, timeout=5.): + msg = Msg(dst='ADP', cmd='RPT', data=data) + data = self._send(msg, timeout) + if fmt is None or fmt == 's': + try: + data = data.decode() + except AttributeError: + # Python2 catch + pass + return data + else: + return struct.unpack('>'+fmt, data) + def command(self, cmd, data='', timeout=5.): + msg = Msg(dst='ADP', cmd=cmd, data=data) + self._send(msg, timeout) + +class SafeSocket(socket.socket): + def __init__(self, *args, **kwargs): + socket.socket.__init__(self, *args, **kwargs) + l_onoff = 1 + l_linger = 0 + self.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, + struct.pack('ii', l_onoff, l_linger)) + +class Synchronizer(object): + def __init__(self, group, addr=("adp",23820)): + self.group = group + self.addr = addr + self.socket = SafeSocket(socket.AF_INET, + socket.SOCK_STREAM) + self.socket.connect(addr) + self.socket.settimeout(10) # Prevent recv() from blocking indefinitely + msg = 'GROUP:'+str(group) + try: + msg = msg.encode() + except AttributeError: + # Python2 catch + pass + self.socket.send(msg) + def __call__(self, tag=None): + msg = 'TAG:'+str(tag) + try: + msg = msg.encode() + except AttributeError: + # Python2 catch + pass + self.socket.send(msg) + reply = self.socket.recv(4096) + try: + reply = reply.decode() + except AttributeError: + # Python2 catch + pass + expected_reply = 'GROUP:'+str(self.group) + ',TAG:'+str(tag) + if reply != expected_reply: + raise ValueError("Unexpected reply '%s', expected '%s'" % + (reply, expected_reply)) + +class SynchronizerGroup(object): + def __init__(self, group): + self.group = group + self.socks = [] + self.pending_lock = Semaphore() + self.shutdown_event = Event() + self.run_thread = Thread(target=self.run) + #self.run_thread.daemon = True + self.tStart = time.time() + self.run_thread.start() + #def __del__(self): + # self.shutdown() + def log(self, value): + print("[%.3f] %s" % (time.time()-self.tStart, value)) + def shutdown(self): + self.shutdown_event.set() + self.log("SynchronizerGroup "+self.group+": run joining") + self.run_thread.join() + self.log("SynchronizerGroup "+self.group+": run thread joined") + def add(self, sock, addr): + try: + self.pending_lock.acquire() + self.pending.append(sock) + except AttributeError: + self.pending = [sock,] + finally: + self.pending_lock.release() + def run(self): + while not self.shutdown_event.is_set(): + try: + self.pending_lock.acquire() + self.socks.extend(self.pending) + self.log("SynchronizerGroup "+self.group+": added "+str(len(self.pending))+" clients") + del self.pending + except AttributeError: + pass + finally: + self.pending_lock.release() + + if len(self.socks) == 0: + # Avoid spinning after all socks close + time.sleep(0.1) + continue + + # Find out where everyone is + tags = [] + i = 0 + while i < len(self.socks) and not self.shutdown_event.is_set(): + sock = self.socks[i] + try: + tag_msg = sock.recv(4096) + except socket.timeout as e: + self.log("WARNING: Synchronizer (1a): socket.timeout %s client %i: %s" % (self.group, i, e)) + self.socks[i].close() + del self.socks[i] + continue + except socket.error as e: + self.log("WARNING: Synchronizer (1b): socket.error %s client %i: %s" % (self.group, i, e)) + self.socks[i].close() + del self.socks[i] + continue + if tag_msg[:4] != 'TAG:': + e = tag_msg + self.log("WARNING: Synchronizer (1c): Unexpected message %s client %i: %s" % (self.group, i, e)) + self.socks[i].close() + del self.socks[i] + continue + tags.append( int(tag_msg[4:22], 10) ) + i += 1 + + # Elect tag0, the reference time tag + try: + tag0 = max(tags) + #print("ELECTED %i as tag0 for %s" % (tag0, self.group)) + except ValueError: + continue + + # Speed up the slow ones a little bit + slow = [i for i,tag in enumerate(tags) if tag < tag0] + if len(slow) > 0: + j = 0 + #slowFactors = {} + while slow and j < 5 and not self.shutdown_event.is_set(): + ## Deal with each slow client in turn + for i in slow: + ### Send - ignoring errors + sock, tag = self.socks[i], tags[i] + try: + sock.send('GROUP:'+self.group+',TAG:'+str(tag0)) + #try: + # slowFactors[i] += 1 + #except KeyError: + # slowFactors[i] = 1 + except socket.error as e: + self.log("WARNING: Synchronizer (2a): socket.error %s client %i: %s" % (self.group, i, e)) + + ### Receive - ignoring errors + try: + tag_msg = sock.recv(4096) + except socket.timeout as e: + self.log("WARNING: Synchronizer (2b): socket.timeout %s client %i: %s" % (self.group, i, e)) + except socket.error as e: + self.log("WARNING: Synchronizer (2c): socket.error %s client %i: %s" % (self.group, i, e)) + if not tag_msg.startswith('TAG:'): + e = tag_msg + self.log("WARNING: Synchronizer (2d): Unexpected message %s client %i: %s" % (self.group, i, e)) + continue + tags[i] = int(tag_msg[4:22], 10) + #print("Updated %s client %i tag to %i (tag0 is %i; delta is now %i" % (self.group, i, tags[i], tag0, tags[i]-tag0)) + + ## Evaluate the latest batch of timetags + slow = [i for i,tag in enumerate(tags) if tag < tag0] + + ## Update the iteration variable + j += 1 + + ### Report on what we've done + #for i,v in slowFactors.items(): + # print("WARNING: Synchronizer (2e): slipped %s client %i forward by %s" % (self.group, i, v)) + + # Send to everyone regardless to make sure the fast ones don't falter + i = 0 + while i < len(self.socks): + sock, tag = self.socks[i], tags[i] + if tag != tag0: + self.log("WARNING: Synchronizer (3a): Tag mismatch: "+str(tag)+" != "+str(tag0)+" from "+self.group+" client "+str(i)+" (delta is "+str((tag-tag0)/196e6*1000)+" ms)") + + try: + sock.send('GROUP:'+self.group+',TAG:'+str(tag0)) + except socket.error as e: + self.log("WARNING: Synchronizer (3b): socket.error: %s client %i: %s" % (self.group, i, e)) + self.socks[i].close() + del self.socks[i] + del tags[i] + continue + i += 1 + + # Done with the iteration + ##print("SYNCED "+str(len(self.socks))+" clients in "+self.group) + + self.log("SynchronizerGroup "+self.group+": shut down") + +class SynchronizerServer(object): + def __init__(self, addr=("0.0.0.0",23820)): + self.addr = addr + self.sock = SafeSocket(socket.AF_INET, + socket.SOCK_STREAM) + self.sock.settimeout(5) # Prevent accept() from blocking indefinitely + self.sock.bind(addr) + self.sock.listen(32) + self.groups = {}#defaultdict(SynchronizerGroup) + self.shutdown_event = Event() + def shutdown(self): + self.shutdown_event.set() + def run(self): + while not self.shutdown_event.is_set(): + try: + sock, addr = self.sock.accept() + sock.settimeout(10) + except socket.timeout: + continue + group_msg = sock.recv(4096) + if not group_msg.startswith('GROUP:'): + #raise ValueError("Unexpected message: "+group_msg) + print("WARNING: Synchronizer: Unexpected message: "+group_msg) + group = group_msg[len('GROUP:'):] + if group not in self.groups: + self.groups[group] = SynchronizerGroup(group) + self.groups[group].add(sock, addr) + for group in self.groups.values(): + group.shutdown() + # Note: This seems to be necessary to avoid 'address already in use' + self.sock.shutdown(socket.SHUT_RDWR) + self.sock.close() diff --git a/LWA/SocketThread.py b/LWA/SocketThread.py new file mode 100644 index 00000000..0c617e8c --- /dev/null +++ b/LWA/SocketThread.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- + +from __future__ import print_function + +import threading +import socket +try: + import queue +except NameError: + import Queue as queue + +class UDPRecvThread(threading.Thread): + #STOP = '__UDPRecvThread_STOP__' + def __init__(self, address, bufsize=16384): + threading.Thread.__init__(self) + self._addr = address + self._bufsize = bufsize + self._msg_queue = queue.Queue() # For default behaviour + self.socket = socket.socket(socket.AF_INET, + socket.SOCK_DGRAM) + self.socket.bind(address) + self.stop_requested = threading.Event() + def request_stop(self): + """ + sendsock = socket.socket(socket.AF_INET, + socket.SOCK_DGRAM) + sendsock.connect(self._addr) + sendsock.send(UDPRecvThread.STOP) + """ + self.stop_requested.set() + # WAR for "107: Transpose endpoint is not connected" in socket.shutdown + self.socket.connect(("0.0.0.0", 0)) + self.socket.shutdown(socket.SHUT_RD) + self.socket.close() + def run(self): + while True:#not self.stop_requested.is_set(): + #pkt = self.socket.recv(self._bufsize) + pkt, src_addr = self.socket.recvfrom(self._bufsize) + if self.stop_requested.is_set(): + break + #if pkt == UDPRecvThread.STOP: + # break + src_ip = src_addr[0] + self.process(pkt, src_ip) + self.shutdown() + def process(self, pkt, src_ip): + """Overide this in subclass""" + self._msg_queue.put((pkt,src_ip)) # Default behaviour + def shutdown(self): + """Overide this in subclass""" + pass + def get(self, timeout=None): + try: + return self._msg_queue.get(True, timeout) + except queue.Empty: + return None + +if __name__ == '__main__': + port = 8321 + rcv = UDPRecvThread(("localhost", port)) + #rcv.daemon = True + rcv.start() + print("Waiting for packet on port", port) + pkt,src_ip = rcv.get(timeout=5.) + if pkt is not None: + print("Received packet:", pkt,src_ip) + else: + print("Timed out waiting for packet") + rcv.request_stop() + rcv.join()