add outputs and read degrid chan mapping from dds by default
landmanbester committed Oct 29, 2024
1 parent 699d5a5 commit 064cad6
Showing 13 changed files with 153 additions and 65 deletions.
13 changes: 9 additions & 4 deletions pfb/parser/degrid.yaml
@@ -5,12 +5,17 @@ inputs:
abbreviation: ms
info:
Path to measurement set.
suffix:
dtype: str
default: 'main'
info:
Can be used to specify a custom name for the image space data products
mds:
dtype: str
required: true
abbreviation: mds
info:
Path to the mds that needs to be degridded
Optional path to mds to use for degridding.
By default mds is inferred from output-filename.
model-column:
dtype: str
default: MODEL_DATA
@@ -38,10 +38,10 @@
channels-per-image:
dtype: int
abbreviation: cpi
default: -1
info:
Number of channels per image.
Default (-1, 0, None) -> dataset per spw.
Default (None) -> read mapping from dds.
(-1, 0) -> one band per SPW.
accumulate:
dtype: bool
default: false
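To illustrate the new default (a sketch only, not code from this commit): when channels-per-image is left unset, the degrid worker takes the band width from the existing dds. The helper and variable names below are hypothetical.

def resolve_cpi(channels_per_image, dds):
    # None (the new default): take the channel mapping from the dds,
    # i.e. use the widest band found in the existing image datasets.
    if channels_per_image is None:
        return max(ds.chan.size for ds in dds)
    # -1 or 0: one band per SPW (handled further downstream).
    return channels_per_image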
5 changes: 4 additions & 1 deletion pfb/parser/fluxmop.yaml
@@ -70,7 +70,10 @@ inputs:
- (.)out.yml

outputs:
{}
dds-out:
implicit: '{current.output-filename}_{current.product}_{current.suffix}.dds'
dtype: Directory
must_exist: false

policies:
pass_missing_as_none: true
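The same implicit dds-out output is added to several parsers below. As a rough illustration (the values are made-up examples, not schema defaults), the template resolves to a path of this shape:

output_filename = 'mydata'   # --output-filename
product = 'I'                # --product
suffix = 'main'              # --suffix
dds_out = f'{output_filename}_{product}_{suffix}.dds'
# -> 'mydata_I_main.dds'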
5 changes: 4 additions & 1 deletion pfb/parser/grid.yaml
@@ -122,7 +122,10 @@ inputs:
- (.)out.yml

outputs:
{}
dds-out:
implicit: '{current.output-filename}_{current.product}_{current.suffix}.dds'
dtype: Directory
must_exist: false

policies:
pass_missing_as_none: true
5 changes: 4 additions & 1 deletion pfb/parser/hci.yaml
@@ -176,7 +176,10 @@ inputs:
- (.)cgopts.yml

outputs:
{}
dir-out:
implicit: '{current.output-filename}_{current.product}.fds'
dtype: Directory
must_exist: false

policies:
pass_missing_as_none: true
5 changes: 4 additions & 1 deletion pfb/parser/init.yaml
@@ -132,7 +132,10 @@ inputs:
- (.)dist.yml

outputs:
{}
xds-out:
implicit: '{current.output-filename}_{current.product}.xds'
dtype: Directory
must_exist: false

policies:
pass_missing_as_none: true
5 changes: 4 additions & 1 deletion pfb/parser/klean.yaml
@@ -133,7 +133,10 @@ inputs:
- (.)out.yml

outputs:
{}
dds-out:
implicit: '{current.output-filename}_{current.product}_{current.suffix}.dds'
dtype: Directory
must_exist: false

policies:
pass_missing_as_none: true
5 changes: 4 additions & 1 deletion pfb/parser/sara.yaml
@@ -120,7 +120,10 @@ inputs:
Reduce the regularisation strength by this fraction at the outset.

outputs:
{}
dds-out:
implicit: '{current.output-filename}_{current.product}_{current.suffix}.dds'
dtype: Directory
must_exist: false

policies:
pass_missing_as_none: true
27 changes: 21 additions & 6 deletions pfb/utils/misc.py
@@ -399,7 +399,6 @@ def construct_mappings(ms_name,
if not idx.any():
continue
idx0 = np.argmax(idx) # returns index of first True element
# np.searchsorted here?
try:
# returns zero if not idx.any()
assert idx[idx0]
@@ -446,11 +445,27 @@
row_mapping[ms][idt]['start_indices'] = ridx
row_mapping[ms][idt]['counts'] = rcounts

nfreq_chunks = nchan_in // cpit
freq_chunks = (cpit,)*nfreq_chunks
rem = nchan_in - nfreq_chunks * cpit
if rem:
freq_chunks += (rem,)
freq_idx0 = freq_mapping[ms][idt]['start_indices'][0]
if freq_idx0 != 0:
freq_chunks = (freq_idx0,) + tuple(freq_mapping[ms][idt]['counts'])
else:
freq_chunks = tuple(freq_mapping[ms][idt]['counts'])
freq_idxf = np.sum(freq_chunks)
if freq_idxf != nchan_in:
freq_chunkf = nchan_in - freq_idxf
freq_chunks += (freq_chunkf,)

try:
assert np.sum(freq_chunks) == nchan_in
except Exception as e:
raise RuntimeError("Something went wrong constructing the "
"frequency mapping. sum(fchunks != nchan)")

# nfreq_chunks = nchan_in // cpit
# freq_chunks = (cpit,)*nfreq_chunks
# rem = nchan_in - nfreq_chunks * cpit
# if rem:
# freq_chunks += (rem,)

ms_chunks[ms].append({'row': row_chunks,
'chan': freq_chunks})
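A small, self-contained sketch of the new chunk construction above, using made-up channel counts rather than values from a real measurement set:

import numpy as np

nchan_in = 64
start_indices = np.array([8, 24, 40])   # assumed start index of each selected band
counts = np.array([16, 16, 16])         # assumed number of channels in each band

freq_chunks = tuple(counts)
if start_indices[0] != 0:
    # pad a leading chunk for unselected channels before the first band
    freq_chunks = (int(start_indices[0]),) + freq_chunks
covered = int(np.sum(freq_chunks))
if covered != nchan_in:
    # pad a trailing chunk so the chunks tile the full channel axis
    freq_chunks += (nchan_in - covered,)

assert sum(freq_chunks) == nchan_in
print(freq_chunks)  # -> (8, 16, 16, 16, 8)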
4 changes: 4 additions & 0 deletions pfb/utils/stokes2vis.py
@@ -34,6 +34,8 @@ def single_stokes(
utime=None,
tbin_idx=None,
tbin_counts=None,
chan_low=None,
chan_high=None,
radec=None,
antpos=None,
poltype=None,
@@ -315,6 +317,8 @@
'freq_out': freq_out,
'freq_min': freq_min,
'freq_max': freq_max,
'chan_low': chan_low,
'chan_high': chan_high,
'bandid': bandid,
'time_out': time_out,
'time_min': utime.min(),
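For context (a sketch with an abridged, assumed signature, not the actual function), the two new arguments are simply carried through into the per-dataset attributes:

def make_attrs(freq_out, freq_min, freq_max, chan_low=None, chan_high=None):
    # chan_low/chan_high record the channel-selection bounds of this dataset
    # (new in this commit); their exact semantics come from the caller.
    return {
        'freq_out': freq_out,
        'freq_min': freq_min,
        'freq_max': freq_max,
        'chan_low': chan_low,
        'chan_high': chan_high,
    }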
123 changes: 78 additions & 45 deletions pfb/workers/degrid.py
@@ -18,6 +18,9 @@
def degrid(**kw):
'''
Predict model visibilities to measurement sets.
The default behaviour is to read the frequency mapping from the dds and
degrid one image per band.
If channels-per-image is provided, the model is evaluated from the mds.
'''
opts = OmegaConf.create(kw)

@@ -41,17 +44,28 @@ def degrid(**kw):
msnames.append(*list(map(msstore.fs.unstrip_protocol, mslist)))
except:
raise ValueError(f"No MS at {ms}")
if len(opts.ms) > 1:
raise ValueError(f"There must be a single MS at {opts.ms}")
opts.ms = msnames
modelstore = DaskMSStore(opts.mds.rstrip('/'))
try:
assert modelstore.exists()
except Exception as e:
raise ValueError(f"There must be a model at "
f"to {opts.mds}")
opts.mds = modelstore.url

basename = opts.output_filename

if opts.mds is None:
mds_store = DaskMSStore(f'{basename}_{opts.suffix}_model.mds')
else:
mds_store = DaskMSStore(opts.mds)
try:
assert mds_store.exists()
except Exception as e:
raise ValueError(f"No mds at {opts.mds}")
opts.mds = mds_store.url

dds_store = DaskMSStore(f'{basename}_{opts.suffix}.dds')
if opts.channels_per_image is None and not mds_store.exists():
try:
assert dds_store.exists()
except Exception as e:
raise ValueError(f"There must be a dds at {dds_store.url}. "
"Specify mds and channels-per-image to degrid from mds.")
opts.dds = dds_store.url
OmegaConf.set_struct(opts, True)

if opts.product.upper() not in ["I"]:
@@ -96,14 +110,15 @@ def _degrid(**kw):
from daskms import xds_from_storage_ms as xds_from_ms
from daskms import xds_from_storage_table as xds_from_table
from daskms import xds_to_storage_table as xds_to_table
from daskms.fsspec_store import DaskMSStore
import dask.array as da
from africanus.constants import c as lightspeed
from africanus.gridding.wgridder.dask import model as im2vis
from pfb.operators.gridder import comps2vis, _comps2vis_impl
from pfb.utils.fits import load_fits, data_from_header, set_wcs
from regions import Regions
from astropy.io import fits
from pfb.utils.misc import compute_context
from pfb.utils.naming import xds_from_url
import xarray as xr
import sympy as sm
from sympy.utilities.lambdify import lambdify
@@ -112,29 +127,18 @@ def _degrid(**kw):
resize_thread_pool(opts.nthreads)

client = get_client()
mds = xr.open_zarr(opts.mds)
foo = client.scatter(mds, broadcast=True)
wait(foo)

# grid spec
cell_rad = mds.cell_rad_x
cell_deg = np.rad2deg(cell_rad)
nx = mds.npix_x
ny = mds.npix_y
x0 = mds.center_x
y0 = mds.center_y
radec = (mds.ra, mds.dec)

# model func
params = sm.symbols(('t','f'))
params += sm.symbols(tuple(mds.params.values))
symexpr = parse_expr(mds.parametrisation)
modelf = lambdify(params, symexpr)
texpr = parse_expr(mds.texpr)
tfunc = lambdify(params[0], texpr)
fexpr = parse_expr(mds.fexpr)
ffunc = lambdify(params[1], fexpr)
dds_store = DaskMSStore(opts.dds)
mds_store = DaskMSStore(opts.mds)

if opts.channels_per_image is None:
if dds_store.exists():
dds, dds_list = xds_from_url(dds_store.url)
cpi = 0
for ds in dds:
cpi = np.maximum(ds.chan.size, cpi)
else:
cpi = opts.channels_per_image

if opts.freq_range is not None and len(opts.freq_range):
fmin, fmax = opts.freq_range.strip(' ').split(':')
@@ -160,7 +164,32 @@
construct_mappings(opts.ms,
None,
ipi=opts.integrations_per_image,
cpi=opts.channels_per_image)
cpi=cpi,
freq_min=freq_min,
freq_max=freq_max)

mds = xr.open_zarr(opts.mds)
foo = client.scatter(mds, broadcast=True)
wait(foo)

# grid spec
cell_rad = mds.cell_rad_x
cell_deg = np.rad2deg(cell_rad)
nx = mds.npix_x
ny = mds.npix_y
x0 = mds.center_x
y0 = mds.center_y
radec = (mds.ra, mds.dec)

# model func
params = sm.symbols(('t','f'))
params += sm.symbols(tuple(mds.params.values))
symexpr = parse_expr(mds.parametrisation)
modelf = lambdify(params, symexpr)
texpr = parse_expr(mds.texpr)
tfunc = lambdify(params[0], texpr)
fexpr = parse_expr(mds.fexpr)
ffunc = lambdify(params[1], fexpr)

# load region file if given
masks = []
@@ -226,7 +255,7 @@ def _degrid(**kw):
for i, mask in enumerate(masks):
out_data = []
columns = []
for ds in xds:
for k, ds in enumerate(xds):
if i == 0:
column_name = opts.model_column
else:
@@ -240,33 +269,37 @@

# time <-> row mapping
utime = da.from_array(utimes[ms][idt],
chunks=opts.integrations_per_image)
chunks=opts.integrations_per_image)
tidx = da.from_array(time_mapping[ms][idt]['start_indices'],
chunks=1)
chunks=1)
tcnts = da.from_array(time_mapping[ms][idt]['counts'],
chunks=1)
chunks=1)

ridx = da.from_array(row_mapping[ms][idt]['start_indices'],
chunks=opts.integrations_per_image)
chunks=opts.integrations_per_image)
rcnts = da.from_array(row_mapping[ms][idt]['counts'],
chunks=opts.integrations_per_image)
chunks=opts.integrations_per_image)

# freq <-> band mapping
# freq <-> band mapping (entire freq axis)
freq = da.from_array(freqs[ms][idt],
chunks=opts.channels_per_image)
fidx = da.from_array(freq_mapping[ms][idt]['start_indices'],
chunks=1)
fcnts = da.from_array(freq_mapping[ms][idt]['counts'],
chunks=1)
chunks=ms_chunks[ms][k]['chan'])
fcnts = np.array(ms_chunks[ms][k]['chan'])
fidx = np.concatenate((np.array([0]), np.cumsum(fcnts)))[0:-1]

fidx = da.from_array(fidx,
chunks=1)
fcnts = da.from_array(fcnts,
chunks=1)

# number of chunks need to math in mapping and coord
# number of chunks need to match in mapping and coord
ntime_out = len(tidx.chunks[0])
assert len(utime.chunks[0]) == ntime_out
nfreq_out = len(fidx.chunks[0])
assert len(freq.chunks[0]) == nfreq_out
# and they need to match the number of row chunks
uvw = clone(ds.UVW.data)
assert len(uvw.chunks[0]) == len(tidx.chunks[0])

vis = comps2vis(uvw,
utime,
freq,
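The start-index computation in the new freq <-> band mapping is compact; here is a standalone sketch with assumed chunk sizes:

import numpy as np

fcnts = np.array([16, 16, 16, 16])   # assumed channels per band
fidx = np.concatenate((np.array([0]), np.cumsum(fcnts)))[0:-1]
# fidx -> array([ 0, 16, 32, 48]): each band starts where the previous one ends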
10 changes: 9 additions & 1 deletion pfb/workers/grid.py
@@ -321,6 +321,8 @@ def _grid(**kw):
xds_dct[tbid]['radec'] = (ds.ra, ds.dec)
xds_dct[tbid]['time_out'] = times_out[0]
xds_dct[tbid]['freq_out'] = freqs_out[b]
xds_dct[tbid]['chan_low'] = ds.chan_low
xds_dct[tbid]['chan_high'] = ds.chan_high
else:
ntime = ntime_in
times_out = times_in
@@ -335,6 +337,8 @@
xds_dct[tbid]['radec'] = (ds.ra, ds.dec)
xds_dct[tbid]['time_out'] = times_out[t]
xds_dct[tbid]['freq_out'] = freqs_out[b]
xds_dct[tbid]['chan_low'] = ds.chan_low
xds_dct[tbid]['chan_high'] = ds.chan_high

if opts.dirty:
print(f"Image size = (ntime={ntime}, nband={nband}, "
@@ -369,6 +373,8 @@
dsl = ds_dct['dsl']
time_out = ds_dct['time_out']
freq_out = ds_dct['freq_out']
chan_low = ds_dct['chan_low']
chan_high = ds_dct['chan_high']
iter0 = 0
if from_cache:
out_ds_name = f'{dds_store.url}/time{timeid}_band{bandid}.zarr'
@@ -418,7 +424,9 @@
'super_resolution_factor': opts.super_resolution_factor,
'field_of_view': opts.field_of_view,
'product': opts.product,
'niters': iter0
'niters': iter0,
'chan_low': chan_low,
'chan_high': chan_high,
}

# get the model
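Downstream, the new chan_low/chan_high attributes are what let degrid recover the channel mapping from the dds. A hedged sketch of reading them back (the helper and path list are hypothetical):

import xarray as xr

def band_channel_ranges(dds_paths):
    # dds_paths: assumed list of per-band zarr dataset paths inside the dds
    ranges = []
    for p in dds_paths:
        ds = xr.open_zarr(p)
        ranges.append((int(ds.attrs['chan_low']), int(ds.attrs['chan_high'])))
    return ranges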