From ca1d7dead749447d6b05b1e73d5459b9e8dd55ef Mon Sep 17 00:00:00 2001 From: Joshua Hoskins Date: Wed, 4 Oct 2023 10:18:00 -0400 Subject: [PATCH 1/6] Add docstring to tests for client, extract_holog and holog. --- tests/unit/test_client.py | 22 +++++-- tests/unit/test_extract_holog.py | 16 ++++++ tests/unit/test_holog.py | 98 +++++++++++++++++++++----------- 3 files changed, 98 insertions(+), 38 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 2c6f30e7..6670046d 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -23,7 +23,10 @@ def teardown_method(self): pass def test_client_spawn(self): - """ Test client """ + """ + Run astrohack_local_client with N cores and with a memory_limit of M GB to create an instance of the + astrohack Dask client. + """ import distributed from astrohack.client import astrohack_local_client @@ -47,7 +50,10 @@ def test_client_spawn(self): client.shutdown() def test_client_dask_dir(self): - """ Test client """ + """ + Run astrohack_local_client with N cores and with a memory_limit of M GB to create an instance of the + astrohack Dask client. Check that temporary files are written to dask_local_dir. + """ import distributed from astrohack.client import astrohack_local_client @@ -56,7 +62,12 @@ def test_client_dask_dir(self): log_parms = {'log_level':'DEBUG'} - client = astrohack_local_client(cores=2, memory_limit='8GB', log_parms=log_parms, dask_local_dir='./dask_test_dir') + client = astrohack_local_client( + cores=2, + memory_limit='8GB', + log_parms=log_parms, + dask_local_dir='./dask_test_dir' + ) try: if os.path.exists('./dask_test_dir') is False: @@ -69,7 +80,10 @@ def test_client_dask_dir(self): client.shutdown() def test_client_logger(self): - """ Test client """ + """ + Run astrohack_local_client with N cores and with a memory_limit of M GB without any errors and the messages + will be logged in the terminal. + """ import os import re import distributed diff --git a/tests/unit/test_extract_holog.py b/tests/unit/test_extract_holog.py index 77572daf..9cc7fb53 100644 --- a/tests/unit/test_extract_holog.py +++ b/tests/unit/test_extract_holog.py @@ -32,6 +32,9 @@ def teardown_method(self): pass def test_extract_holog_obs_dict(self): + ''' + Specify a holography observations dictionary and check that the proper dictionary is created. + ''' # Generate pointing file extract_pointing( @@ -72,6 +75,9 @@ def test_extract_holog_obs_dict(self): assert holog_obs_test_dict == holog_obs_dict def test_extract_holog_ddi(self): + ''' + Specify a ddi value to be process and check that it is the only one processed. + ''' # Generate pointing file extract_pointing( @@ -107,6 +113,9 @@ def test_extract_holog_ddi(self): assert list(holog_mds.keys()) == ['ddi_1'] def test_extract_holog_overwrite(self): + ''' + Specify that the output file should be overwritten if it exists; check that it is overwritten. + ''' # Generate pointing file extract_pointing( @@ -151,6 +160,10 @@ def test_extract_holog_overwrite(self): assert initial_time != final_time def test_extract_holog_baseline_average_distance(self): + ''' + Run extract_holog using the baseline average distance as a filter; check that only the baselines with this + average distance are returned. + ''' # extract pointing data pnt_mds = extract_pointing( @@ -174,6 +187,9 @@ def test_extract_holog_baseline_average_distance(self): assert list(holog_mds['ddi_0']['map_0'].keys()) == ['ant_ea25'] def test_extract_holog_baseline_average_nearest(self): + ''' + Run extract_holog using the nearest baseline as a filter; check that only the nearest baselines are returned. + ''' # extract pointing data pnt_mds = extract_pointing( diff --git a/tests/unit/test_holog.py b/tests/unit/test_holog.py index 97362170..cda4a978 100644 --- a/tests/unit/test_holog.py +++ b/tests/unit/test_holog.py @@ -11,8 +11,10 @@ from astrohack.extract_holog import extract_holog from astrohack.extract_pointing import extract_pointing + def relative_difference(result, expected): - return 2*np.abs(result - expected)/(abs(result) + abs(expected)) + return 2 * np.abs(result - expected) / (abs(result) + abs(expected)) + class TestHolog(): @classmethod @@ -20,7 +22,7 @@ def setup_class(cls): """ setup any state specific to the execution of the given test class such as fetching test data """ astrohack.data.datasets.download(file="ea25_cal_small_after_fixed.split.ms", folder="data/") - + astrohack.data.datasets.download(file='extract_holog_verification.json') astrohack.data.datasets.download(file='holog_numerical_verification.json') @@ -51,7 +53,7 @@ def setup_class(cls): holog( holog_name='data/ea25_cal_small_after_fixed.split.holog.zarr', - image_name='data/ea25_cal_small_after_fixed.split.image.zarr', + image_name='data/ea25_cal_small_after_fixed.split.image.zarr', overwrite=True, parallel=False ) @@ -73,8 +75,11 @@ def teardown_method(self): """ teardown any state that was previously setup for all methods of the given class """ pass - def test_holog_grid_cell_size(self): + """ + Calculate the correct grid and cell size when compared to known values in the test file; known values are + provided by a test json file. + """ tolerance = 2.e-5 @@ -83,28 +88,33 @@ def test_holog_grid_cell_size(self): with open('data/ea25_cal_small_after_fixed.split.image.zarr/.image_attr') as attr_file: image_attr = json.load(attr_file) - + for i, _ in enumerate(image_attr['cell_size']): assert relative_difference( - image_attr['cell_size'][i], + image_attr['cell_size'][i], reference_dict["vla"]['cell_size'][i] ) < tolerance assert relative_difference( - image_attr['grid_size'][i], + image_attr['grid_size'][i], reference_dict["vla"]['grid_size'][i] ) < tolerance - def test_holog_image_name(self): + """ + Test holog image name created correctly. + """ assert os.path.exists('data/ea25_cal_small_after_fixed.split.image.zarr') def test_holog_ant_id(self): + """ + Specify a single antenna to process; check that is the only antenna returned. + """ image_mds = holog( holog_name='data/ea25_cal_small_after_fixed.split.holog.zarr', - image_name='data/ea25_cal_small_after_fixed.split.image.zarr', + image_name='data/ea25_cal_small_after_fixed.split.image.zarr', ant_id=['ea25'], overwrite=True, parallel=False @@ -113,10 +123,13 @@ def test_holog_ant_id(self): assert list(image_mds.keys()) == ['ant_ea25'] def test_holog_ddi(self): + """ + Specify a single ddi to process; check that is the only ddi returned. + """ image_mds = holog( holog_name='data/ea25_cal_small_after_fixed.split.holog.zarr', - image_name='data/ea25_cal_small_after_fixed.split.image.zarr', + image_name='data/ea25_cal_small_after_fixed.split.image.zarr', overwrite=True, ddi=[0], parallel=False @@ -126,14 +139,15 @@ def test_holog_ddi(self): for ddi in image_mds[ant].keys(): assert ddi == "ddi_0" - - def test_holog_padding_factor(self): + """ + Specify a padding factor to use in the image creation; check that image size is created. + """ image_mds = holog( holog_name='data/ea25_cal_small_after_fixed.split.holog.zarr', - image_name='data/ea25_cal_small_after_fixed.split.image.zarr', - padding_factor=50, + image_name='data/ea25_cal_small_after_fixed.split.image.zarr', + padding_factor=50, overwrite=True, parallel=False ) @@ -143,27 +157,33 @@ def test_holog_padding_factor(self): assert image_mds[ant][ddi].APERTURE.shape == (1, 1, 4, 529, 529) def test_holog_chan_average(self): + """ + Check that channel average flag was set holog is run. + """ image_mds = holog( holog_name='data/ea25_cal_small_after_fixed.split.holog.zarr', - image_name='data/ea25_cal_small_after_fixed.split.image.zarr', + image_name='data/ea25_cal_small_after_fixed.split.image.zarr', chan_average=True, overwrite=True, parallel=False ) with open('data/ea25_cal_small_after_fixed.split.image.zarr/.image_attr') as json_attr: - json_file = json.load(json_attr) - + json_file = json.load(json_attr) + assert json_file['chan_average'] == True def test_holog_scan_average(self): + """ + Check that scan average flag was set holog is run. + """ image_mds = holog( holog_name='data/ea25_cal_small_after_fixed.split.holog.zarr', - image_name='data/ea25_cal_small_after_fixed.split.image.zarr', + image_name='data/ea25_cal_small_after_fixed.split.image.zarr', scan_average=False, overwrite=True, parallel=False - ) + ) with open('data/ea25_cal_small_after_fixed.split.image.zarr/.image_attr') as json_attr: json_file = json.load(json_attr) @@ -171,55 +191,67 @@ def test_holog_scan_average(self): assert json_file['scan_average'] == False def test_holog_grid_interpolation(self): + """ + Check that grid interpolation flag was set holog is run. + """ image_mds = holog( holog_name='data/ea25_cal_small_after_fixed.split.holog.zarr', - image_name='data/ea25_cal_small_after_fixed.split.image.zarr', + image_name='data/ea25_cal_small_after_fixed.split.image.zarr', grid_interpolation_mode='nearest', overwrite=True, parallel=False ) with open('data/ea25_cal_small_after_fixed.split.image.zarr/.image_attr') as json_attr: - json_file = json.load(json_attr) + json_file = json.load(json_attr) assert json_file['grid_interpolation_mode'] == 'nearest' def test_holog_chan_tolerance(self): + """ + Check that channel tolerance ir propagated correctly. + """ image_mds = holog( holog_name='data/ea25_cal_small_after_fixed.split.holog.zarr', - image_name='data/ea25_cal_small_after_fixed.split.image.zarr', + image_name='data/ea25_cal_small_after_fixed.split.image.zarr', chan_tolerance_factor=0.0049, overwrite=True, parallel=False ) with open('data/ea25_cal_small_after_fixed.split.image.zarr/.image_attr') as json_attr: - json_file = json.load(json_attr) + json_file = json.load(json_attr) assert json_file['chan_tolerance_factor'] == 0.0049 def test_holog_to_stokes(self): + """ + Check that to_stokes flag was set holog is run. + """ image_mds = holog( holog_name='data/ea25_cal_small_after_fixed.split.holog.zarr', - image_name='data/ea25_cal_small_after_fixed.split.image.zarr', + image_name='data/ea25_cal_small_after_fixed.split.image.zarr', to_stokes=True, overwrite=True, parallel=False ) with open('data/ea25_cal_small_after_fixed.split.image.zarr/.image_attr') as json_attr: - json_file = json.load(json_attr) + json_file = json.load(json_attr) assert json_file['to_stokes'] == True assert (image_mds['ant_ea25']['ddi_0'].pol.values == np.array(['I', 'Q', 'U', 'V'])).all() def test_holog_overwrite(self): + """ + Specify the output file should be overwritten; check that it WAS. + """ initial_time = os.path.getctime('data/ea25_cal_small_after_fixed.split.image.zarr') - + image_mds = holog( holog_name='data/ea25_cal_small_after_fixed.split.holog.zarr', - image_name='data/ea25_cal_small_after_fixed.split.image.zarr', + image_name='data/ea25_cal_small_after_fixed.split.image.zarr', overwrite=True, parallel=False ) @@ -229,12 +261,15 @@ def test_holog_overwrite(self): assert initial_time != modified_time def test_holog_not_overwrite(self): + """ + Specify the output file should be NOT be overwritten; check that it WAS NOT. + """ initial_time = os.path.getctime('data/ea25_cal_small_after_fixed.split.image.zarr') - + try: holog( holog_name='data/ea25_cal_small_after_fixed.split.holog.zarr', - image_name='data/ea25_cal_small_after_fixed.split.image.zarr', + image_name='data/ea25_cal_small_after_fixed.split.image.zarr', overwrite=False, parallel=False ) @@ -246,8 +281,3 @@ def test_holog_not_overwrite(self): modified_time = os.path.getctime('data/ea25_cal_small_after_fixed.split.image.zarr') assert initial_time == modified_time - - - - - \ No newline at end of file From 34c71ad1583d674c5a42010eb60278c2d211ab53 Mon Sep 17 00:00:00 2001 From: Joshua Hoskins Date: Wed, 4 Oct 2023 10:27:46 -0400 Subject: [PATCH 2/6] Add docstrings to panel tests. --- tests/unit/test_panel.py | 48 ++++++++++++++++++++++++++-------------- 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/tests/unit/test_panel.py b/tests/unit/test_panel.py index fcaa58fc..9122b15e 100644 --- a/tests/unit/test_panel.py +++ b/tests/unit/test_panel.py @@ -12,8 +12,10 @@ from astrohack.extract_holog import extract_holog from astrohack.extract_pointing import extract_pointing + def relative_difference(result, expected): - return 2*np.abs(result - expected)/(abs(result) + abs(expected)) + return 2 * np.abs(result - expected) / (abs(result) + abs(expected)) + class TestPanel(): @classmethod @@ -21,10 +23,10 @@ def setup_class(cls): """ setup any state specific to the execution of the given test class such as fetching test data """ astrohack.data.datasets.download(file="ea25_cal_small_after_fixed.split.ms", folder="data/") - + astrohack.data.datasets.download(file='extract_holog_verification.json') astrohack.data.datasets.download(file='holog_numerical_verification.json') - + astrohack.data.datasets.download(file='panel_cutoff_mask') extract_pointing( @@ -54,7 +56,7 @@ def setup_class(cls): holog( holog_name='data/ea25_cal_small_after_fixed.split.holog.zarr', - image_name='data/ea25_cal_small_after_fixed.split.image.zarr', + image_name='data/ea25_cal_small_after_fixed.split.image.zarr', overwrite=True, parallel=False ) @@ -86,12 +88,17 @@ def teardown_method(self): """ teardown any state that was previously setup for all methods of the given class """ pass - def test_panel_name(self): + """ + Check that the panel output name was created correctly. + """ assert os.path.exists('data/ea25_cal_small_after_fixed.split.panel.zarr') def test_panel_ant_id(self): + """ + Specify a single antenna to process; check that only that antenna was processed. + """ panel_mds = panel( image_name='data/ea25_cal_small_after_fixed.split.image.zarr', @@ -108,6 +115,9 @@ def test_panel_ant_id(self): assert list(panel_mds.keys()) == ['ant_ea25'] def test_panel_ddi(self): + """ + Specify a single ddi to process; check that only that ddi was processed. + """ panel_mds = panel( image_name='data/ea25_cal_small_after_fixed.split.image.zarr', @@ -125,8 +135,11 @@ def test_panel_ddi(self): assert ddi == "ddi_0" def test_panel_overwrite(self): + """ + Specify the output file should be overwritten; check that it WAS. + """ initial_time = os.path.getctime('data/ea25_cal_small_after_fixed.split.image.zarr') - + panel_mds = panel( image_name='data/ea25_cal_small_after_fixed.split.image.zarr', panel_name='data/ea25_cal_small_after_fixed.split.panel.zarr', @@ -142,8 +155,11 @@ def test_panel_overwrite(self): assert initial_time != modified_time def test_panel_not_overwrite(self): + """ + Specify the output file should be NOT be overwritten; check that it WAS NOT. + """ initial_time = os.path.getctime('data/ea25_cal_small_after_fixed.split.panel.zarr') - + try: panel_mds = panel( image_name='data/ea25_cal_small_after_fixed.split.image.zarr', @@ -164,7 +180,10 @@ def test_panel_not_overwrite(self): assert initial_time == modified_time def test_panel_mode(self): - panel_list=['3-4', '5-27', '5-37', '5-38'] + """ + Specify panel computation mode and check that the data rms responded as expected. + """ + panel_list = ['3-4', '5-27', '5-37', '5-38'] panel_mds = panel( image_name='data/ea25_cal_small_after_fixed.split.image.zarr', @@ -183,23 +202,20 @@ def test_panel_mode(self): mean_rms = panel_mds["ant_ea25"]["ddi_0"].sel(labels=panel_list).apply(np.std).PANEL_SCREWS.values - assert mean_rms < default_rms + assert mean_rms < default_rms def test_panel_cutoff(self): + """ + Set cutoff=0 and compare results to known truth value array. + """ with open("panel_cutoff_mask.npy", "rb") as array: reference_array = np.load(array) panel_mds = panel( - image_name='data/ea25_cal_small_after_fixed.split.image.zarr', + image_name='data/ea25_cal_small_after_fixed.split.image.zarr', cutoff=0.0, parallel=False, overwrite=True ) assert np.all(panel_mds["ant_ea25"]["ddi_0"].MASK.values == reference_array) - - - - - - \ No newline at end of file From 6e3392d9beff43d9f633f72a2059941aad47c583 Mon Sep 17 00:00:00 2001 From: Joshua Hoskins Date: Wed, 4 Oct 2023 10:53:16 -0400 Subject: [PATCH 3/6] fix current extract_locit tests and add docstings more tests. --- tests/unit/test_extract_locit.py | 108 +++++++++++++++++++++---------- 1 file changed, 75 insertions(+), 33 deletions(-) diff --git a/tests/unit/test_extract_locit.py b/tests/unit/test_extract_locit.py index 3c503108..5218999c 100644 --- a/tests/unit/test_extract_locit.py +++ b/tests/unit/test_extract_locit.py @@ -1,9 +1,11 @@ +import os import pytest import shutil import astrohack from astrohack.extract_locit import extract_locit + class TestExtractLocit(): cal_table = './data/locit-input-pha.cal' locit_name = './data/locit-input-pha.locit.zarr' @@ -13,58 +15,98 @@ def setup_class(cls): """ setup any state specific to the execution of the given test class such as fetching test data """ astrohack.data.datasets.download(file="locit-input-pha.cal", folder="data") - + @classmethod def teardown_class(cls): """ teardown any state that was previously setup with a call to setup_class such as deleting test data """ shutil.rmtree("data") - + def teardown_method(self): shutil.rmtree(self.locit_name) - def test_extract_locit_simple(self): - + def test_extract_locit_creation(self): + """ + Create locit file with a cal-table and check it is created correctly. + """ + # Create locit_mds and check the dictionary structure locit_mds = extract_locit(self.cal_table, locit_name=self.locit_name) - - expected_keys = ['obs_info', 'ant_info', 'ant_ea01', 'ant_ea02', 'ant_ea04', 'ant_ea05', 'ant_ea06', 'ant_ea07', 'ant_ea08', 'ant_ea09', 'ant_ea10', 'ant_ea11', 'ant_ea12', 'ant_ea13', 'ant_ea15', 'ant_ea16', 'ant_ea17', 'ant_ea18', 'ant_ea19', 'ant_ea20', 'ant_ea21', 'ant_ea22', 'ant_ea23', 'ant_ea24', 'ant_ea25', 'ant_ea26', 'ant_ea27', 'ant_ea28'] - + + expected_keys = ['obs_info', 'ant_info', 'ant_ea01', 'ant_ea02', 'ant_ea04', 'ant_ea05', 'ant_ea06', 'ant_ea07', + 'ant_ea08', 'ant_ea09', 'ant_ea10', 'ant_ea11', 'ant_ea12', 'ant_ea13', 'ant_ea15', 'ant_ea16', + 'ant_ea17', 'ant_ea18', 'ant_ea19', 'ant_ea20', 'ant_ea21', 'ant_ea22', 'ant_ea23', 'ant_ea24', + 'ant_ea25', 'ant_ea26', 'ant_ea27', 'ant_ea28'] + for key in locit_mds.keys(): assert key in expected_keys - + + assert os.path.exists(self.locit_name) + def test_extract_locit_antenna_select(self): - - locit_mds = extract_locit(self.cal_table, locit_name=self.locit_name, ant_id='ea17') - + """ + Check that only specified antenna is processed. + """ + + locit_mds = extract_locit( + self.cal_table, + locit_name=self.locit_name, + ant_id='ea17' + ) + # There should only be 1 antenna in the dict named ea17 assert len(locit_mds.keys()) == 3 - assert 'ant_ea17' in locit_mds.keys() - + + # Check that only the specific antenna is in the keys. + assert list(locit_mds.keys()) == ['ant_ea17'] + def test_extract_locit_ddi(self): - + """ + Check that only specified ddi is processed. + """ + locit_mds = extract_locit(self.cal_table, locit_name=self.locit_name, ddi=0) - - # Usually each antenna has a ddi_0 and ddi_1 with all selected. This should just give ddi_0 + + # Check that only the specific ddi is in the keys. assert len(locit_mds['ant_ea01'].keys()) == 1 - assert 'ddi_0' in locit_mds['ant_ea01'].keys() - + assert list(locit_mds['ant_ea01'].keys()) == ["ddi_0"] + def test_extract_locit_overwrite(self): - - locit_mds = extract_locit(self.cal_table, locit_name=self.locit_name) - - # an exception shouldn't be raised with overwrite set to True - try: - locit_mds = extract_locit(self.cal_table, locit_name=self.locit_name, overwrite=True) - except: - assert False, "Failed to overwrite" - + """ + Specify the output file should be overwritten; check that it WAS. + """ + + # To check this properly we need to not only know an exception was not thrown but that the file is ACTUALLY + # overwritten. We do this by checking the modification time. + initial_time = os.path.getctime(self.locit_name) + + extract_locit( + self.cal_table, + locit_name=self.locit_name, + overwrite=True + ) + + modified_time = os.path.getctime(self.locit_name) + + assert initial_time != modified_time + def test_extract_locit_no_overwrite(self): - - locit_mds = extract_locit(self.cal_table, locit_name=self.locit_name, overwrite=False) - - # with overwrite set to False an exception should be raised + """ + Specify the output file should be NOT be overwritten; check that it WAS NOT. + """ + initial_time = os.path.getctime(self.locit_name) + try: - locit_mds = extract_locit(self.cal_table, locit_name=self.locit_name, overwrite=False) - except: + extract_locit( + self.cal_table, + locit_name=self.locit_name, + overwrite=False + ) + + except FileExistsError: pass + + finally: + modified_time = os.path.getctime(self.locit_name) + + assert initial_time == modified_time From d872e17913abcb603c323ced37ce59c7c5620454 Mon Sep 17 00:00:00 2001 From: Joshua Hoskins Date: Wed, 4 Oct 2023 11:02:33 -0400 Subject: [PATCH 4/6] Formatting changes. --- tests/unit/test_astrohack_import_template.py | 1 + tests/unit/test_client.py | 31 +++++++++--------- tests/unit/test_dio.py | 22 ++++++------- tests/unit/test_extract_holog.py | 34 ++++++++++---------- 4 files changed, 45 insertions(+), 43 deletions(-) diff --git a/tests/unit/test_astrohack_import_template.py b/tests/unit/test_astrohack_import_template.py index 8fa50a9d..c2d6b849 100644 --- a/tests/unit/test_astrohack_import_template.py +++ b/tests/unit/test_astrohack_import_template.py @@ -1,5 +1,6 @@ import pytest + class TestAstrohack(): @classmethod def setup_class(cls): diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 6670046d..30bb7511 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1,6 +1,7 @@ import os import pytest + class TestAstrohack(): @classmethod def setup_class(cls): @@ -31,19 +32,19 @@ def test_client_spawn(self): from astrohack.client import astrohack_local_client - DEFAULT_DASK_ADDRESS="127.0.0.1:8786" + DEFAULT_DASK_ADDRESS = "127.0.0.1:8786" - log_parms = {'log_level':'DEBUG'} + log_parms = {'log_level': 'DEBUG'} client = astrohack_local_client(cores=2, memory_limit='8GB', log_parms=log_parms) - + if not distributed.client._get_global_client(): try: distributed.Client(DEFAULT_DASK_ADDRESS, timeout=2) except OSError: assert False - + finally: client.shutdown() @@ -58,9 +59,9 @@ def test_client_dask_dir(self): from astrohack.client import astrohack_local_client - DEFAULT_DASK_ADDRESS="127.0.0.1:8786" + DEFAULT_DASK_ADDRESS = "127.0.0.1:8786" - log_parms = {'log_level':'DEBUG'} + log_parms = {'log_level': 'DEBUG'} client = astrohack_local_client( cores=2, @@ -68,14 +69,14 @@ def test_client_dask_dir(self): log_parms=log_parms, dask_local_dir='./dask_test_dir' ) - + try: if os.path.exists('./dask_test_dir') is False: raise FileNotFoundError except FileNotFoundError: assert False - + finally: client.shutdown() @@ -90,27 +91,27 @@ def test_client_logger(self): from astrohack.client import astrohack_local_client - DEFAULT_DASK_ADDRESS="127.0.0.1:8786" + DEFAULT_DASK_ADDRESS = "127.0.0.1:8786" log_parms = { - 'log_level':'DEBUG', - 'log_to_file':True, + 'log_level': 'DEBUG', + 'log_to_file': True, 'log_file': 'astrohack_log_file' } client = astrohack_local_client(cores=2, memory_limit='8GB', log_parms=log_parms) - + files = os.listdir(".") try: for file in files: if re.match("^astrohack_log_file+[0-9].*log", file) is not None: return - + raise FileNotFoundError except FileNotFoundError: assert False - + finally: - client.shutdown() \ No newline at end of file + client.shutdown() diff --git a/tests/unit/test_dio.py b/tests/unit/test_dio.py index d8225557..6f5d6bcb 100644 --- a/tests/unit/test_dio.py +++ b/tests/unit/test_dio.py @@ -13,6 +13,7 @@ from astrohack.panel import panel from astrohack import holog + class TestAstrohackDio(): datafolder = 'dioData' holog_mds = dict() @@ -21,16 +22,15 @@ class TestAstrohackDio(): @classmethod def setup_class(cls): - astrohack.data.datasets.download(file="ea25_cal_small_after_fixed.split.ms", folder=cls.datafolder) - + extract_pointing( ms_name=cls.datafolder + '/ea25_cal_small_after_fixed.split.ms', point_name=cls.datafolder + '/ea25_cal_small_after_fixed.split.point.zarr', parallel=True, overwrite=True ) - + cls.holog_mds = extract_holog( ms_name=cls.datafolder + '/ea25_cal_small_after_fixed.split.ms', point_name=cls.datafolder + '/ea25_cal_small_after_fixed.split.point.zarr', @@ -58,32 +58,32 @@ def setup_class(cls): @classmethod def teardown_class(cls): shutil.rmtree(cls.datafolder) - + def test_open_holog(self): '''Open a holog file and return a holog data object''' holog_data = open_holog(self.datafolder + '/ea25_cal_small_after_fixed.split.holog.zarr') assert holog_data == self.holog_mds - + def test_open_image(self): '''Open an image file and return an image data object''' image_data = open_image(self.datafolder + '/ea25_cal_small_after_fixed.split.image.zarr') - + assert image_data == self.image_mds - + def test_open_panel(self): '''Open a panel file and return a panel data object''' panel_data = open_panel(self.datafolder + '/ea25_cal_small_after_fixed.split.panel.zarr') - + assert panel_data == self.panel_mds - + def test_open_pointing(self): '''Open a pointing file and return a pointing data object''' pointing_data = open_pointing(self.datafolder + '/ea25_cal_small_after_fixed.split.point.zarr') # check if keys match expected? # How to check xarray content... - + expected_keys = ['point_meta_ds', 'ant_ea25', 'ant_ea04', 'ant_ea06'] - + for key in pointing_data.keys(): assert key in expected_keys diff --git a/tests/unit/test_extract_holog.py b/tests/unit/test_extract_holog.py index 9cc7fb53..c344473e 100644 --- a/tests/unit/test_extract_holog.py +++ b/tests/unit/test_extract_holog.py @@ -9,6 +9,7 @@ from astrohack.extract_pointing import extract_pointing from astrohack.extract_holog import generate_holog_obs_dict + class TestExtractHolog(): @classmethod def setup_class(cls): @@ -24,7 +25,7 @@ def teardown_class(cls): def setup_method(self): """ setup any state specific to all methods of the given class """ - + pass def teardown_method(self): @@ -32,9 +33,9 @@ def teardown_method(self): pass def test_extract_holog_obs_dict(self): - ''' + """ Specify a holography observations dictionary and check that the proper dictionary is created. - ''' + """ # Generate pointing file extract_pointing( @@ -68,16 +69,16 @@ def test_extract_holog_obs_dict(self): # Get holog_obs_dict created by extract_holog with open(".holog_obs_dict.json") as holog_dict_file: holog_obs_test_dict = json.load(holog_dict_file) - + holog_obs_test_dict = json.loads(holog_obs_test_dict) # Check that the holog_obs_dict used in extract_holog matches the input holog_obs_dict assert holog_obs_test_dict == holog_obs_dict def test_extract_holog_ddi(self): - ''' + """ Specify a ddi value to be process and check that it is the only one processed. - ''' + """ # Generate pointing file extract_pointing( @@ -107,15 +108,14 @@ def test_extract_holog_ddi(self): parallel=False, overwrite=True ) - # Check that the holog_obs_dict used in extract_holog matches the input holog_obs_dict assert list(holog_mds.keys()) == ['ddi_1'] def test_extract_holog_overwrite(self): - ''' + """ Specify that the output file should be overwritten if it exists; check that it is overwritten. - ''' + """ # Generate pointing file extract_pointing( @@ -142,7 +142,7 @@ def test_extract_holog_overwrite(self): parallel=False, overwrite=True ) - + initial_time = os.path.getctime('data/ea25_cal_small_after_fixed.split.holog.zarr') # Extract holography data @@ -160,10 +160,10 @@ def test_extract_holog_overwrite(self): assert initial_time != final_time def test_extract_holog_baseline_average_distance(self): - ''' + """ Run extract_holog using the baseline average distance as a filter; check that only the baselines with this average distance are returned. - ''' + """ # extract pointing data pnt_mds = extract_pointing( @@ -182,14 +182,14 @@ def test_extract_holog_baseline_average_distance(self): parallel=False, overwrite=True ) - + # Check that the expected antenna is present. assert list(holog_mds['ddi_0']['map_0'].keys()) == ['ant_ea25'] def test_extract_holog_baseline_average_nearest(self): - ''' + """ Run extract_holog using the nearest baseline as a filter; check that only the nearest baselines are returned. - ''' + """ # extract pointing data pnt_mds = extract_pointing( @@ -208,6 +208,6 @@ def test_extract_holog_baseline_average_nearest(self): parallel=False, overwrite=True ) - + # Check that the expected antenna is present. - assert list(holog_mds['ddi_0']['map_0'].keys()).sort() == ['ant_ea25', 'ant_ea06'].sort() \ No newline at end of file + assert list(holog_mds['ddi_0']['map_0'].keys()).sort() == ['ant_ea25', 'ant_ea06'].sort() From 12f6d61f9ca5d5266d440c2d37d8e853500565c4 Mon Sep 17 00:00:00 2001 From: Joshua Hoskins Date: Wed, 4 Oct 2023 11:41:33 -0400 Subject: [PATCH 5/6] Fix download error. --- src/astrohack/data/_dropbox.py | 169 ++++++++++++++++--------------- tests/unit/test_extract_locit.py | 19 +++- tests/unit/test_panel.py | 6 +- 3 files changed, 103 insertions(+), 91 deletions(-) diff --git a/src/astrohack/data/_dropbox.py b/src/astrohack/data/_dropbox.py index 83eb695d..427dfa9d 100644 --- a/src/astrohack/data/_dropbox.py +++ b/src/astrohack/data/_dropbox.py @@ -7,90 +7,91 @@ from astrohack._utils._logger._astrohack_logger import _get_astrohack_logger FILE_ID = { - 'ea25_cal_small_before_fixed.split.ms': - { - 'file':'ea25_cal_small_before_fixed.split.ms.zip', - 'id':'m2qnd2w6g9fhdxyzi6h7f', - 'rlkey':'d4dgztykxpnnqrei7jhb1cu7m' - }, - 'ea25_cal_small_after_fixed.split.ms': - { - 'file':'ea25_cal_small_after_fixed.split.ms.zip', - 'id':'o3tl05e3qa440s4rk5owf', - 'rlkey':'hoxte3zzeqgkju2ywnif2t7ko' - }, - 'J1924-2914.ms.calibrated.split.SPW3': - { - 'file':'J1924-2914.ms.calibrated.split.SPW3.zip', - 'id':'kyrwc5y6u7lxbmqw7fveh', - 'rlkey':'r23qakcm24bid2x2cojsd96gs' - }, - 'extract_holog_verification.json': - { - 'file': 'extract_holog_verification.json', - 'id':'6pzucjd48a4n0eb74wys9', - 'rlkey':'azuynw358zxvse9i225sbl59s' - }, - 'holog_numerical_verification.json': - { - 'file':'holog_numerical_verification.json', - 'id':'x69700pznt7uktwprdqpk', - 'rlkey':'bxn9me7dgnxrtzvvay7xgicmi' - }, - 'locit-input-pha.cal': - { - 'file':'locit-input-pha.cal.zip', - 'id':'8fftz5my9h8ca2xdlupym', - 'rlkey':'fxfid92953ycorh5wrhfgh78b' - }, - 'panel_cutoff_mask': - { - 'file':'panel_cutoff_mask.npy', - 'id':'8ta02t72vwcv4ketv8rfw', - 'rlkey':'qsmos4hx2duz8upb83hghi6q8' - } - - + 'ea25_cal_small_before_fixed.split.ms': + { + 'file': 'ea25_cal_small_before_fixed.split.ms.zip', + 'id': 'm2qnd2w6g9fhdxyzi6h7f', + 'rlkey': 'd4dgztykxpnnqrei7jhb1cu7m' + }, + 'ea25_cal_small_after_fixed.split.ms': + { + 'file': 'ea25_cal_small_after_fixed.split.ms.zip', + 'id': 'o3tl05e3qa440s4rk5owf', + 'rlkey': 'hoxte3zzeqgkju2ywnif2t7ko' + }, + 'J1924-2914.ms.calibrated.split.SPW3': + { + 'file': 'J1924-2914.ms.calibrated.split.SPW3.zip', + 'id': 'kyrwc5y6u7lxbmqw7fveh', + 'rlkey': 'r23qakcm24bid2x2cojsd96gs' + }, + 'extract_holog_verification.json': + { + 'file': 'extract_holog_verification.json', + 'id': '6pzucjd48a4n0eb74wys9', + 'rlkey': 'azuynw358zxvse9i225sbl59s' + }, + 'holog_numerical_verification.json': + { + 'file': 'holog_numerical_verification.json', + 'id': 'x69700pznt7uktwprdqpk', + 'rlkey': 'bxn9me7dgnxrtzvvay7xgicmi' + }, + 'locit-input-pha.cal': + { + 'file': 'locit-input-pha.cal.zip', + 'id': '8fftz5my9h8ca2xdlupym', + 'rlkey': 'fxfid92953ycorh5wrhfgh78b' + }, + 'panel_cutoff_mask': + { + 'file': 'panel_cutoff_mask.npy', + 'id': '8ta02t72vwcv4ketv8rfw', + 'rlkey': 'qsmos4hx2duz8upb83hghi6q8' + } + } + def download(file, folder='.'): - logger = _get_astrohack_logger() - - if os.path.exists('/'.join((folder, file))): - logger.info("File exists.") - return - - if file not in FILE_ID.keys(): - logger.info("Requested file not found") - - return - - fullname=FILE_ID[file]['file'] - id=FILE_ID[file]['id'] - rlkey=FILE_ID[file]['rlkey'] - - url = 'https://www.dropbox.com/scl/fi/{id}/{file}?rlkey={rlkey}'.format(id=id, file=fullname, rlkey=rlkey) - - headers = {'user-agent': 'Wget/1.16 (linux-gnu)'} - - r = requests.get(url, stream=True, headers=headers) - total = int(r.headers.get('content-length', 0)) - - fullname = '/'.join((folder, fullname)) - - with open(fullname, 'wb') as fd, tqdm( - desc=fullname, - total=total, - unit='iB', - unit_scale=True, - unit_divisor=1024) as bar: - for chunk in r.iter_content(chunk_size=1024): - if chunk: - size=fd.write(chunk) - bar.update(size) - - if zipfile.is_zipfile(fullname): - shutil.unpack_archive(filename=fullname, extract_dir=folder) - - # Let's clean up after ourselves - os.remove(fullname) \ No newline at end of file + logger = _get_astrohack_logger() + + if os.path.exists('/'.join((folder, file))): + logger.info("File exists.") + return + + if file not in FILE_ID.keys(): + logger.info("Requested file not found") + logger.info(FILE_ID.keys()) + + return + + fullname = FILE_ID[file]['file'] + id = FILE_ID[file]['id'] + rlkey = FILE_ID[file]['rlkey'] + + url = 'https://www.dropbox.com/scl/fi/{id}/{file}?rlkey={rlkey}'.format(id=id, file=fullname, rlkey=rlkey) + + headers = {'user-agent': 'Wget/1.16 (linux-gnu)'} + + r = requests.get(url, stream=True, headers=headers) + total = int(r.headers.get('content-length', 0)) + + fullname = '/'.join((folder, fullname)) + + with open(fullname, 'wb') as fd, tqdm( + desc=fullname, + total=total, + unit='iB', + unit_scale=True, + unit_divisor=1024) as bar: + for chunk in r.iter_content(chunk_size=1024): + if chunk: + size = fd.write(chunk) + bar.update(size) + + if zipfile.is_zipfile(fullname): + shutil.unpack_archive(filename=fullname, extract_dir=folder) + + # Let's clean up after ourselves + os.remove(fullname) diff --git a/tests/unit/test_extract_locit.py b/tests/unit/test_extract_locit.py index 5218999c..38224988 100644 --- a/tests/unit/test_extract_locit.py +++ b/tests/unit/test_extract_locit.py @@ -12,8 +12,11 @@ class TestExtractLocit(): @classmethod def setup_class(cls): - """ setup any state specific to the execution of the given test class - such as fetching test data """ + """ + Setup any state specific to the execution of the given test class + such as fetching test data + """ + astrohack.data.datasets.download(file="locit-input-pha.cal", folder="data") @classmethod @@ -58,7 +61,7 @@ def test_extract_locit_antenna_select(self): assert len(locit_mds.keys()) == 3 # Check that only the specific antenna is in the keys. - assert list(locit_mds.keys()) == ['ant_ea17'] + assert list(locit_mds.keys()) == ['obs_info', 'ant_info', 'ant_ea17'] def test_extract_locit_ddi(self): """ @@ -76,6 +79,11 @@ def test_extract_locit_overwrite(self): Specify the output file should be overwritten; check that it WAS. """ + extract_locit( + self.cal_table, + locit_name=self.locit_name + ) + # To check this properly we need to not only know an exception was not thrown but that the file is ACTUALLY # overwritten. We do this by checking the modification time. initial_time = os.path.getctime(self.locit_name) @@ -94,6 +102,11 @@ def test_extract_locit_no_overwrite(self): """ Specify the output file should be NOT be overwritten; check that it WAS NOT. """ + extract_locit( + self.cal_table, + locit_name=self.locit_name + ) + initial_time = os.path.getctime(self.locit_name) try: diff --git a/tests/unit/test_panel.py b/tests/unit/test_panel.py index 9122b15e..6653e971 100644 --- a/tests/unit/test_panel.py +++ b/tests/unit/test_panel.py @@ -1,5 +1,3 @@ -import pytest - import os import json import shutil @@ -27,8 +25,6 @@ def setup_class(cls): astrohack.data.datasets.download(file='extract_holog_verification.json') astrohack.data.datasets.download(file='holog_numerical_verification.json') - astrohack.data.datasets.download(file='panel_cutoff_mask') - extract_pointing( ms_name="data/ea25_cal_small_after_fixed.split.ms", point_name="data/ea25_cal_small_after_fixed.split.point.zarr", @@ -208,6 +204,8 @@ def test_panel_cutoff(self): """ Set cutoff=0 and compare results to known truth value array. """ + astrohack.data.datasets.download(file='panel_cutoff_mask') + with open("panel_cutoff_mask.npy", "rb") as array: reference_array = np.load(array) From 7ffc79537ab491b1a13d16ab29b4664b28d9c213 Mon Sep 17 00:00:00 2001 From: Joshua Hoskins Date: Wed, 4 Oct 2023 12:21:30 -0400 Subject: [PATCH 6/6] Update test_holog.py Fix typo mentioned by Sandra --- tests/unit/test_holog.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_holog.py b/tests/unit/test_holog.py index cda4a978..72ef8764 100644 --- a/tests/unit/test_holog.py +++ b/tests/unit/test_holog.py @@ -209,7 +209,7 @@ def test_holog_grid_interpolation(self): def test_holog_chan_tolerance(self): """ - Check that channel tolerance ir propagated correctly. + Check that channel tolerance is propagated correctly. """ image_mds = holog( holog_name='data/ea25_cal_small_after_fixed.split.holog.zarr',