Skip to content

Commit

Permalink
Merge pull request #124 from xylar/fix-for-pyremap-1.1.0
Browse files Browse the repository at this point in the history
Updates for pyremap >= 1.1.0
  • Loading branch information
xylar authored Sep 27, 2023
2 parents 0cb07fc + 1c5dd25 commit 157b13f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 54 deletions.
2 changes: 1 addition & 1 deletion deploy/conda-dev-spec.template
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ otps={{ otps }}
progressbar2
pyamg>=4.2.2
pyproj
pyremap>=1.0.0,<2.0.0
pyremap>=1.1.0,<2.0.0
ruamel.yaml
requests
scipy>=1.8.0
Expand Down
70 changes: 17 additions & 53 deletions polaris/remap/mapping_file_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from pyremap import (
LatLon2DGridDescriptor,
LatLonGridDescriptor,
MpasCellMeshDescriptor,
MpasEdgeMeshDescriptor,
MpasMeshDescriptor,
MpasVertexMeshDescriptor,
PointCollectionDescriptor,
ProjectionGridDescriptor,
Remapper,
Expand Down Expand Up @@ -376,10 +377,10 @@ def _build_mapping_file_args(remapper, method, src_mesh_filename='src_mesh.nc',
_check_remapper(remapper, method)

src_descriptor = remapper.sourceDescriptor
src_loc = _write_mesh_and_get_location(src_descriptor, src_mesh_filename)
src_descriptor.to_scrip(src_mesh_filename)

dst_descriptor = remapper.destinationDescriptor
dst_loc = _write_mesh_and_get_location(dst_descriptor, dst_mesh_filename)
dst_descriptor.to_scrip(dst_mesh_filename)

args = ['ESMF_RegridWeightGen',
'--source', src_mesh_filename,
Expand All @@ -389,11 +390,6 @@ def _build_mapping_file_args(remapper, method, src_mesh_filename='src_mesh.nc',
'--netcdf4',
'--no_log']

if src_loc is not None:
args.extend(['--src_loc', src_loc])
if dst_loc is not None:
args.extend(['--dst_loc', dst_loc])

if src_descriptor.regional:
args.append('--src_regional')

Expand All @@ -416,38 +412,6 @@ def _check_remapper(remapper, method):
raise ValueError(f'method {method} not supported for destination '
'grid of type PointCollectionDescriptor.')

if isinstance(remapper.sourceDescriptor, MpasMeshDescriptor) and \
remapper.sourceDescriptor.vertices:
if 'conserve' in method:
raise ValueError('Can\'t remap from MPAS vertices with '
'conservative methods')

if isinstance(remapper.destinationDescriptor, MpasMeshDescriptor) and \
remapper.destinationDescriptor.vertices:
if 'conserve' in method:
raise ValueError('Can\'t remap to MPAS vertices with '
'conservative methods')


def _write_mesh_and_get_location(descriptor, mesh_filename):
if isinstance(descriptor,
(MpasMeshDescriptor, MpasEdgeMeshDescriptor)):
file_format = 'esmf'
descriptor.to_esmf(mesh_filename)
else:
file_format = 'scrip'
descriptor.to_scrip(mesh_filename)

if file_format == 'esmf':
if isinstance(descriptor, MpasMeshDescriptor) and descriptor.vertices:
location = 'corner'
else:
location = 'center'
else:
location = None

return location


def _get_descriptor(info):
""" Get a mesh descriptor from the mesh info """
Expand All @@ -466,17 +430,17 @@ def _get_descriptor(info):


def _get_mpas_descriptor(info):
""" Get an MpasMeshDescriptor from the given info """
""" Get an MPAS mesh descriptor from the given info """
mesh_type = info['mpas_mesh_type']
filename = info['filename']
mesh_name = info['name']

if mesh_type == 'cell':
descriptor = MpasMeshDescriptor(fileName=filename, meshName=mesh_name,
vertices=False)
descriptor = MpasCellMeshDescriptor(fileName=filename,
meshName=mesh_name)
elif mesh_type == 'vertex':
descriptor = MpasMeshDescriptor(fileName=filename, meshName=mesh_name,
vertices=True)
descriptor = MpasVertexMeshDescriptor(fileName=filename,
meshName=mesh_name)
elif mesh_type == 'edge':
descriptor = MpasEdgeMeshDescriptor(fileName=filename,
meshName=mesh_name)
Expand All @@ -498,8 +462,8 @@ def _get_lon_lat_descriptor(info):
lonMax=lon_max)
else:
filename = info['filename']
lon = info['lon_var']
lat = info['lat_var']
lon = info['lon']
lat = info['lat']
with xr.open_dataset(filename) as ds:
lon_lat_1d = len(ds[lon].dims) == 1 and len(ds[lat].dims) == 1
lon_lat_2d = len(ds[lon].dims) == 2 and len(ds[lat].dims) == 2
Expand Down Expand Up @@ -527,9 +491,9 @@ def _get_lon_lat_descriptor(info):
def _get_proj_descriptor(info):
""" Get a ProjectionGridDescriptor from the given info """
filename = info['filename']
grid_name = info['names']
x = info['x_var']
y = info['y_var']
grid_name = info['name']
x = info['x']
y = info['y']
if 'proj_attr' in info:
with xr.open_dataset(filename) as ds:
proj_str = ds.attrs[info['proj_attr']]
Expand All @@ -550,9 +514,9 @@ def _get_proj_descriptor(info):
def _get_points_descriptor(info):
""" Get a PointCollectionDescriptor from the given info """
filename = info['filename']
collection_name = info['names']
lon_var = info['lon_var']
lat_var = info['lat_var']
collection_name = info['name']
lon_var = info['lon']
lat_var = info['lat']
with xr.open_dataset(filename) as ds:
lon = ds[lon_var].value
lat = ds[lat_var].values
Expand Down

0 comments on commit 157b13f

Please sign in to comment.