From 064cad66af0d9b6709005a652c56ca272ad64070 Mon Sep 17 00:00:00 2001 From: landmanbester Date: Tue, 29 Oct 2024 16:00:35 +0200 Subject: [PATCH] add outputs and read degrid chan mapping from dds by default --- pfb/parser/degrid.yaml | 13 +++-- pfb/parser/fluxmop.yaml | 5 +- pfb/parser/grid.yaml | 5 +- pfb/parser/hci.yaml | 5 +- pfb/parser/init.yaml | 5 +- pfb/parser/klean.yaml | 5 +- pfb/parser/sara.yaml | 5 +- pfb/utils/misc.py | 27 +++++++-- pfb/utils/stokes2vis.py | 4 ++ pfb/workers/degrid.py | 123 +++++++++++++++++++++++++--------------- pfb/workers/grid.py | 10 +++- pfb/workers/init.py | 10 +++- tests/test_sara.py | 1 + 13 files changed, 153 insertions(+), 65 deletions(-) diff --git a/pfb/parser/degrid.yaml b/pfb/parser/degrid.yaml index 9b814077..b1578ff7 100644 --- a/pfb/parser/degrid.yaml +++ b/pfb/parser/degrid.yaml @@ -5,12 +5,17 @@ inputs: abbreviation: ms info: Path to measurement set. + suffix: + dtype: str + default: 'main' + info: + Can be used to specify a custom name for the image space data products mds: dtype: str - required: true abbreviation: mds info: - Path to the mds that needs to be degridded + Optional path to mds to use for degridding. + By default mds is inferred from output-filename. model-column: dtype: str default: MODEL_DATA @@ -38,10 +43,10 @@ inputs: channels-per-image: dtype: int abbreviation: cpi - default: -1 info: Number of channels per image. - Default (-1, 0, None) -> dataset per spw. + Default (None) -> read mapping from dds. + (-1, 0) -> one band per SPW. accumulate: dtype: bool default: false diff --git a/pfb/parser/fluxmop.yaml b/pfb/parser/fluxmop.yaml index 7031df46..965998bf 100644 --- a/pfb/parser/fluxmop.yaml +++ b/pfb/parser/fluxmop.yaml @@ -70,7 +70,10 @@ inputs: - (.)out.yml outputs: - {} + dds-out: + implicit: '{current.output-filename}_{current.product}_{current.suffix}.dds' + dtype: Directory + must_exist: false policies: pass_missing_as_none: true diff --git a/pfb/parser/grid.yaml b/pfb/parser/grid.yaml index da54e6f9..39755c3f 100644 --- a/pfb/parser/grid.yaml +++ b/pfb/parser/grid.yaml @@ -122,7 +122,10 @@ inputs: - (.)out.yml outputs: - {} + dds-out: + implicit: '{current.output-filename}_{current.product}_{current.suffix}.dds' + dtype: Directory + must_exist: false policies: pass_missing_as_none: true diff --git a/pfb/parser/hci.yaml b/pfb/parser/hci.yaml index 92181612..35bf1d04 100644 --- a/pfb/parser/hci.yaml +++ b/pfb/parser/hci.yaml @@ -176,7 +176,10 @@ inputs: - (.)cgopts.yml outputs: - {} + dir-out: + implicit: '{current.output-filename}_{current.product}.fds' + dtype: Directory + must_exist: false policies: pass_missing_as_none: true diff --git a/pfb/parser/init.yaml b/pfb/parser/init.yaml index 2ff1ed91..9acaa602 100644 --- a/pfb/parser/init.yaml +++ b/pfb/parser/init.yaml @@ -132,7 +132,10 @@ inputs: - (.)dist.yml outputs: - {} + xds-out: + implicit: '{current.output-filename}_{current.product}.xds' + dtype: Directory + must_exist: false policies: pass_missing_as_none: true diff --git a/pfb/parser/klean.yaml b/pfb/parser/klean.yaml index f3cdae2b..5683ae0d 100644 --- a/pfb/parser/klean.yaml +++ b/pfb/parser/klean.yaml @@ -133,7 +133,10 @@ inputs: - (.)out.yml outputs: - {} + dds-out: + implicit: '{current.output-filename}_{current.product}_{current.suffix}.dds' + dtype: Directory + must_exist: false policies: pass_missing_as_none: true diff --git a/pfb/parser/sara.yaml b/pfb/parser/sara.yaml index 8bdc77ab..706f58af 100644 --- a/pfb/parser/sara.yaml +++ b/pfb/parser/sara.yaml @@ -120,7 +120,10 @@ inputs: Reduce the regularisation strength by this fraction at the outset. outputs: - {} + dds-out: + implicit: '{current.output-filename}_{current.product}_{current.suffix}.dds' + dtype: Directory + must_exist: false policies: pass_missing_as_none: true diff --git a/pfb/utils/misc.py b/pfb/utils/misc.py index c26d5908..eef8bf80 100644 --- a/pfb/utils/misc.py +++ b/pfb/utils/misc.py @@ -399,7 +399,6 @@ def construct_mappings(ms_name, if not idx.any(): continue idx0 = np.argmax(idx) # returns index of first True element - # np.searchsorted here? try: # returns zero if not idx.any() assert idx[idx0] @@ -446,11 +445,27 @@ def construct_mappings(ms_name, row_mapping[ms][idt]['start_indices'] = ridx row_mapping[ms][idt]['counts'] = rcounts - nfreq_chunks = nchan_in // cpit - freq_chunks = (cpit,)*nfreq_chunks - rem = nchan_in - nfreq_chunks * cpit - if rem: - freq_chunks += (rem,) + freq_idx0 = freq_mapping[ms][idt]['start_indices'][0] + if freq_idx0 != 0: + freq_chunks = (freq_idx0,) + tuple(freq_mapping[ms][idt]['counts']) + else: + freq_chunks = tuple(freq_mapping[ms][idt]['counts']) + freq_idxf = np.sum(freq_chunks) + if freq_idxf != nchan_in: + freq_chunkf = nchan_in - freq_idxf + freq_chunks += (freq_chunkf,) + + try: + assert np.sum(freq_chunks) == nchan_in + except Exception as e: + raise RuntimeError("Something went wrong constructing the " + "frequency mapping. sum(fchunks != nchan)") + + # nfreq_chunks = nchan_in // cpit + # freq_chunks = (cpit,)*nfreq_chunks + # rem = nchan_in - nfreq_chunks * cpit + # if rem: + # freq_chunks += (rem,) ms_chunks[ms].append({'row': row_chunks, 'chan': freq_chunks}) diff --git a/pfb/utils/stokes2vis.py b/pfb/utils/stokes2vis.py index 06f2e80c..63f8f054 100644 --- a/pfb/utils/stokes2vis.py +++ b/pfb/utils/stokes2vis.py @@ -34,6 +34,8 @@ def single_stokes( utime=None, tbin_idx=None, tbin_counts=None, + chan_low=None, + chan_high=None, radec=None, antpos=None, poltype=None, @@ -315,6 +317,8 @@ def single_stokes( 'freq_out': freq_out, 'freq_min': freq_min, 'freq_max': freq_max, + 'chan_low': chan_low, + 'chan_high': chan_high, 'bandid': bandid, 'time_out': time_out, 'time_min': utime.min(), diff --git a/pfb/workers/degrid.py b/pfb/workers/degrid.py index f81cd581..431652a3 100644 --- a/pfb/workers/degrid.py +++ b/pfb/workers/degrid.py @@ -18,6 +18,9 @@ def degrid(**kw): ''' Predict model visibilities to measurement sets. + The default behaviour is to read the frequency mapping from the dds and + degrid one image per band. + If channels-per-image is provided, the model is evaluated from the mds. ''' opts = OmegaConf.create(kw) @@ -41,17 +44,28 @@ def degrid(**kw): msnames.append(*list(map(msstore.fs.unstrip_protocol, mslist))) except: raise ValueError(f"No MS at {ms}") - if len(opts.ms) > 1: - raise ValueError(f"There must be a single MS at {opts.ms}") opts.ms = msnames - modelstore = DaskMSStore(opts.mds.rstrip('/')) - try: - assert modelstore.exists() - except Exception as e: - raise ValueError(f"There must be a model at " - f"to {opts.mds}") - opts.mds = modelstore.url + basename = opts.output_filename + + if opts.mds is None: + mds_store = DaskMSStore(f'{basename}_{opts.suffix}_model.mds') + else: + mds_store = DaskMSStore(opts.mds) + try: + assert mds_store.exists() + except Exception as e: + raise ValueError(f"No mds at {opts.mds}") + opts.mds = mds_store.url + + dds_store = DaskMSStore(f'{basename}_{opts.suffix}.dds') + if opts.channels_per_image is None and not mds_store.exists(): + try: + assert dds_store.exists() + except Exception as e: + raise ValueError(f"There must be a dds at {dds_store.url}. " + "Specify mds and channels-per-image to degrid from mds.") + opts.dds = dds_store.url OmegaConf.set_struct(opts, True) if opts.product.upper() not in ["I"]: @@ -96,6 +110,7 @@ def _degrid(**kw): from daskms import xds_from_storage_ms as xds_from_ms from daskms import xds_from_storage_table as xds_from_table from daskms import xds_to_storage_table as xds_to_table + from daskms.fsspec_store import DaskMSStore import dask.array as da from africanus.constants import c as lightspeed from africanus.gridding.wgridder.dask import model as im2vis @@ -103,7 +118,7 @@ def _degrid(**kw): from pfb.utils.fits import load_fits, data_from_header, set_wcs from regions import Regions from astropy.io import fits - from pfb.utils.misc import compute_context + from pfb.utils.naming import xds_from_url import xarray as xr import sympy as sm from sympy.utilities.lambdify import lambdify @@ -112,29 +127,18 @@ def _degrid(**kw): resize_thread_pool(opts.nthreads) client = get_client() - mds = xr.open_zarr(opts.mds) - foo = client.scatter(mds, broadcast=True) - wait(foo) - # grid spec - cell_rad = mds.cell_rad_x - cell_deg = np.rad2deg(cell_rad) - nx = mds.npix_x - ny = mds.npix_y - x0 = mds.center_x - y0 = mds.center_y - radec = (mds.ra, mds.dec) - - # model func - params = sm.symbols(('t','f')) - params += sm.symbols(tuple(mds.params.values)) - symexpr = parse_expr(mds.parametrisation) - modelf = lambdify(params, symexpr) - texpr = parse_expr(mds.texpr) - tfunc = lambdify(params[0], texpr) - fexpr = parse_expr(mds.fexpr) - ffunc = lambdify(params[1], fexpr) + dds_store = DaskMSStore(opts.dds) + mds_store = DaskMSStore(opts.mds) + if opts.channels_per_image is None: + if dds_store.exists(): + dds, dds_list = xds_from_url(dds_store.url) + cpi = 0 + for ds in dds: + cpi = np.maximum(ds.chan.size, cpi) + else: + cpi = opts.channels_per_image if opts.freq_range is not None and len(opts.freq_range): fmin, fmax = opts.freq_range.strip(' ').split(':') @@ -160,7 +164,32 @@ def _degrid(**kw): construct_mappings(opts.ms, None, ipi=opts.integrations_per_image, - cpi=opts.channels_per_image) + cpi=cpi, + freq_min=freq_min, + freq_max=freq_max) + + mds = xr.open_zarr(opts.mds) + foo = client.scatter(mds, broadcast=True) + wait(foo) + + # grid spec + cell_rad = mds.cell_rad_x + cell_deg = np.rad2deg(cell_rad) + nx = mds.npix_x + ny = mds.npix_y + x0 = mds.center_x + y0 = mds.center_y + radec = (mds.ra, mds.dec) + + # model func + params = sm.symbols(('t','f')) + params += sm.symbols(tuple(mds.params.values)) + symexpr = parse_expr(mds.parametrisation) + modelf = lambdify(params, symexpr) + texpr = parse_expr(mds.texpr) + tfunc = lambdify(params[0], texpr) + fexpr = parse_expr(mds.fexpr) + ffunc = lambdify(params[1], fexpr) # load region file if given masks = [] @@ -226,7 +255,7 @@ def _degrid(**kw): for i, mask in enumerate(masks): out_data = [] columns = [] - for ds in xds: + for k, ds in enumerate(xds): if i == 0: column_name = opts.model_column else: @@ -240,26 +269,29 @@ def _degrid(**kw): # time <-> row mapping utime = da.from_array(utimes[ms][idt], - chunks=opts.integrations_per_image) + chunks=opts.integrations_per_image) tidx = da.from_array(time_mapping[ms][idt]['start_indices'], - chunks=1) + chunks=1) tcnts = da.from_array(time_mapping[ms][idt]['counts'], - chunks=1) + chunks=1) ridx = da.from_array(row_mapping[ms][idt]['start_indices'], - chunks=opts.integrations_per_image) + chunks=opts.integrations_per_image) rcnts = da.from_array(row_mapping[ms][idt]['counts'], - chunks=opts.integrations_per_image) + chunks=opts.integrations_per_image) - # freq <-> band mapping + # freq <-> band mapping (entire freq axis) freq = da.from_array(freqs[ms][idt], - chunks=opts.channels_per_image) - fidx = da.from_array(freq_mapping[ms][idt]['start_indices'], - chunks=1) - fcnts = da.from_array(freq_mapping[ms][idt]['counts'], - chunks=1) + chunks=ms_chunks[ms][k]['chan']) + fcnts = np.array(ms_chunks[ms][k]['chan']) + fidx = np.concatenate((np.array([0]), np.cumsum(fcnts)))[0:-1] + + fidx = da.from_array(fidx, + chunks=1) + fcnts = da.from_array(fcnts, + chunks=1) - # number of chunks need to math in mapping and coord + # number of chunks need to match in mapping and coord ntime_out = len(tidx.chunks[0]) assert len(utime.chunks[0]) == ntime_out nfreq_out = len(fidx.chunks[0]) @@ -267,6 +299,7 @@ def _degrid(**kw): # and they need to match the number of row chunks uvw = clone(ds.UVW.data) assert len(uvw.chunks[0]) == len(tidx.chunks[0]) + vis = comps2vis(uvw, utime, freq, diff --git a/pfb/workers/grid.py b/pfb/workers/grid.py index b434faa5..4e43cdf4 100644 --- a/pfb/workers/grid.py +++ b/pfb/workers/grid.py @@ -321,6 +321,8 @@ def _grid(**kw): xds_dct[tbid]['radec'] = (ds.ra, ds.dec) xds_dct[tbid]['time_out'] = times_out[0] xds_dct[tbid]['freq_out'] = freqs_out[b] + xds_dct[tbid]['chan_low'] = ds.chan_low + xds_dct[tbid]['chan_high'] = ds.chan_high else: ntime = ntime_in times_out = times_in @@ -335,6 +337,8 @@ def _grid(**kw): xds_dct[tbid]['radec'] = (ds.ra, ds.dec) xds_dct[tbid]['time_out'] = times_out[t] xds_dct[tbid]['freq_out'] = freqs_out[b] + xds_dct[tbid]['chan_low'] = ds.chan_low + xds_dct[tbid]['chan_high'] = ds.chan_high if opts.dirty: print(f"Image size = (ntime={ntime}, nband={nband}, " @@ -369,6 +373,8 @@ def _grid(**kw): dsl = ds_dct['dsl'] time_out = ds_dct['time_out'] freq_out = ds_dct['freq_out'] + chan_low = ds_dct['chan_low'] + chan_high = ds_dct['chan_high'] iter0 = 0 if from_cache: out_ds_name = f'{dds_store.url}/time{timeid}_band{bandid}.zarr' @@ -418,7 +424,9 @@ def _grid(**kw): 'super_resolution_factor': opts.super_resolution_factor, 'field_of_view': opts.field_of_view, 'product': opts.product, - 'niters': iter0 + 'niters': iter0, + 'chan_low': chan_low, + 'chan_high': chan_high, } # get the model diff --git a/pfb/workers/init.py b/pfb/workers/init.py index 8abdda44..07dfee75 100644 --- a/pfb/workers/init.py +++ b/pfb/workers/init.py @@ -276,7 +276,7 @@ def _init(**kw): utimes[ms][idt][It], ridx, rcnts, radecs[ms][idt], - fi, ti, ims, ms]) + fi, ti, ims, ms, flow, flow+fcounts]) futures = [] associated_workers = {} @@ -286,7 +286,7 @@ def _init(**kw): while idle_workers and len(datasets): # Seed each worker with a task. # pop so len(datasets) -> 0 (subds, jones, freqsi, chan_widthi, utimesi, ridx, rcnts, - radeci, fi, ti, ims, ms) = datasets.pop(0) + radeci, fi, ti, ims, ms, chan_low, chan_high) = datasets.pop(0) worker = idle_workers.pop() future = client.submit(single_stokes, @@ -301,6 +301,8 @@ def _init(**kw): utime=utimesi, tbin_idx=ridx, tbin_counts=rcnts, + chan_low=chan_low, + chan_high=chan_high, radec=radeci, antpos=antpos[ms], poltype=poltype[ms], @@ -345,7 +347,7 @@ def _init(**kw): # pop so len(datasets) -> 0 if len(datasets): (subds, jones, freqsi, chan_widthi, utimesi, ridx, rcnts, - radeci, fi, ti, ims, ms) = datasets.pop(0) + radeci, fi, ti, ims, ms, chan_low, chan_high) = datasets.pop(0) future = client.submit(single_stokes, dc1=dc1, @@ -359,6 +361,8 @@ def _init(**kw): utime=utimesi, tbin_idx=ridx, tbin_counts=rcnts, + chan_low=chan_low, + chan_high=chan_high, radec=radeci, antpos=antpos[ms], poltype=poltype[ms], diff --git a/tests/test_sara.py b/tests/test_sara.py index 405319f3..4f6feabb 100644 --- a/tests/test_sara.py +++ b/tests/test_sara.py @@ -235,6 +235,7 @@ def test_sara(ms_name): degrid_args[key.replace("-", "_")] = schema.degrid["inputs"][key]["default"] degrid_args["ms"] = [str(test_dir / 'test_ascii_1h60.0s.MS')] degrid_args["mds"] = f'{outname}_main_model.mds' + degrid_args["dds"] = f'{outname}_main.dds' degrid_args["channels_per_image"] = 1 degrid_args["nthreads"] = 8 degrid_args["do_wgridding"] = do_wgridding