Skip to content

Commit

Permalink
added test for mctrivers
Browse files Browse the repository at this point in the history
  • Loading branch information
Cinzia Mazzetti committed Dec 10, 2024
1 parent b5838b2 commit 9183937
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 37 deletions.
94 changes: 57 additions & 37 deletions tests/test_mctrivers.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,62 @@
import os
import shutil
import xarray as xr

from lisfloodutilities.compare.nc import NetCDFComparator

from lisfloodutilities.mctrivers.mctrivers import mct_mask


class TestMctMask:

pardir = 'tests/data/mctrivers'

def compare_masks(self, checked_case):
'check if the generated mct mask is identical to the expected ones'
gnr_mask = f'{pardir}/{checked_case}/mctmask.nc'
if os.path.exists(gnr_mask):
os.remove(gnr_mask)

# command = 'python3 src/lisfloodutilities/mctrivers/mctrivers.py'
# command += f' -i {pardir}/{checked_case}/changrad.nc -l {pardir}/{checked_case}/ldd.nc -u {pardir}/{checked_case}/upArea.nc'
# command += f' -m {pardir}/{checked_case}/mask.nc -S 0.001 -N 5 -U 500000000 -O {gnr_mask}'
# os.system(command)
mct_mask(f'{pardir}/{checked_case}/changrad.nc', f'{pardir}/{checked_case}/ldd.nc', f'{pardir}/{checked_case}/upArea.nc',
mask_file=f'{pardir}/{checked_case}/mask.nc', slp_threshold=0.001, nloops=5, minuparea=500000000,
outputfile=gnr_mask)

org_mask = f'{pardir}/{checked_case}/mctmask_original.nc'
original = xr.open_dataset(org_mask)
generated = xr.open_dataset(gnr_mask)

all_equal = original.equals(generated)
if all_equal:
print(f'Test for mct river mask generation for {checked_case} passed. Generated mask is deleted.')
os.remove(gnr_mask)
else:
fail_message = f'Test for mct river mask generation for {checked_case} failed.'
fail_message += f'Please check differences between the generated mask "{gnr_mask}" and the expected mask "{org_mask}".'
assert all_equal, fail_message


def test_mctmask(self):
'check mct masks for all available domains'
available_cases = os.listdir(pardir)
for i_case in available_cases:
self.compare_masks(i_case)
def mk_path_out(p):
path_out = os.path.join(os.path.dirname(__file__), p)
if os.path.exists(path_out):
shutil.rmtree(path_out)
os.mkdir(path_out)
return path_out

class TestMctMask():

case_dir = os.path.join(os.path.dirname(__file__), 'data', 'mctrivers')

def run(self, slp_threshold, nloops, minuparea, coords_names, type):

# setting
self.out_path_ref = os.path.join(self.case_dir, type, 'reference')
self.out_path_run = os.path.join(self.case_dir, type, 'out')
mk_path_out(self.out_path_run)

channels_slope_file = os.path.join(self.case_dir, type, 'changrad.nc')
ldd_file = os.path.join(self.case_dir, type, 'ldd.nc')
uparea_file = os.path.join(self.case_dir, type, 'upArea.nc')
mask_file = os.path.join(self.case_dir, type, 'mask.nc')
outputfile = os.path.join(self.out_path_run, 'chanmct.nc')

# generate the mct river mask
mct_mask(channels_slope_file, ldd_file, uparea_file, mask_file, slp_threshold, nloops, minuparea,outputfile)

# compare results with reference
nc_comparator = NetCDFComparator(mask_file, array_equal=True)
nc_comparator.compare_dirs(self.out_path_run, self.out_path_ref)

def teardown_method(self):
print('Cleaning directories')
out_path = os.path.join(self.case_dir, type, 'out')
if os.path.exists(out_path) and os.path.isdir(out_path):
shutil.rmtree(out_path, ignore_errors=True)


class TestMctrivers(TestMctMask):

# slp_threshold = 0.001
# nloops = 5
# minuparea = 0
# coords_names = 'None'

def test_mctrivers_etrs89(self):
self.run(0.001, 5, 0, 'None', 'LF_ETRS89_UseCase')

def test_mctrivers_latlon(self):
self.run( 0.001, 5, 0, 'x' 'y', 'LF_lat_lon_UseCase')

def cleaning(self,):
self.teardown_method()

0 comments on commit 9183937

Please sign in to comment.