From 04cf5edfe98f73adb7d2303dee3017af814e9480 Mon Sep 17 00:00:00 2001 From: Xylar Asay-Davis Date: Mon, 20 Jul 2020 17:12:42 +0200 Subject: [PATCH] Add a module for computing geometry of ocean transects --- conda_package/mpas_tools/ocean/transects.py | 571 ++++++++++++++++++++ 1 file changed, 571 insertions(+) create mode 100644 conda_package/mpas_tools/ocean/transects.py diff --git a/conda_package/mpas_tools/ocean/transects.py b/conda_package/mpas_tools/ocean/transects.py new file mode 100644 index 000000000..74056fa0f --- /dev/null +++ b/conda_package/mpas_tools/ocean/transects.py @@ -0,0 +1,571 @@ +import xarray +import numpy + + +def find_transect_levels_and_weights(dsTransect, layerThickness, bottomDepth, + maxLevelCell, zTransect=None): + """ + Construct a vertical coordinate for a transect produced by + ``mpas_tools.viz.transects.find_transect_cells_and_weights()``, then break + each resulting quadrilateral into 2 triangles that can later be visualized + with functions like ``tripcolor`` and ``tricontourf``. Also, compute + interpolation weights such that observations at points on the original + transect and with vertical cooridinate ``transectZ`` can be bilinearly + interpolated to the nodes of the transect. + Parameters + ---------- + dsTransect : xarray.Dataset + A dataset that defines nodes of the transect, the results of calling + ``find_transect_cells_and_weights()`` + + layerThickness : xarray.DataArray + layer thicknesses on the MPAS mesh + + bottomDepth : xarray.DataArray + the (positive down) depth of the seafloor on the MPAS mesh + + maxLevelCell : xarray.DataArray + the vertical zero-based index of the bathymetry on the MPAS mesh + + zTransect : xarray.DataArray, optional + the z coordinate of the transect (1D or 2D). If 2D, it must have the + same along-transect dimension as the lon and lat passed to + ``find_transect_cells_and_weights()`` + + Returns + ------- + dsTransectTriangles : xarray.Dataset + A dataset that contains nodes and triangles that make up a 2D transect. + For convenience in visualization, the quadrilaterals of each cell making + up the transect have been divided into an upper and lower triangle. The + nodes of the triangles are completely independent of one another, + allowing for potential jumps in fields values between nodes of different + triangles that are at the same location. This is convenient, for + example, when visualizing data with constant values within each MPAS + cell. + + There are ``nTransectTriangles = 2*nTransectCells`` triangles, each with + ``nTriangleNodes = 3`` nodes, where ``nTransectCells`` is the number of + valid transect cells (quadrilaterals) that are above the MPAS-Ocean + bathymetry. + + In addition to the variables and coordinates in ``dsTransect``, the + output dataset contains: + + - nodeHorizBoundsIndices: which of the ``nHorizBounds = 2`` + bounds of a horizontal transect segment a given node is on + - segmentIndices: the transect segment of a given triangle + - cellIndices: the MPAS-Ocean cell of a given triangle + - levelIndices: the MPAS-Ocean vertical level of a given triangle + + - zTransectNode: the vertical height of each triangle node + - ssh, zSeaFloor: the sea-surface height and sea-floor height at + each node of each transect segment + + - interpCellIndices, interpLevelIndices: the MPAS-Ocean cells and + levels from which the value at a given triangle node are + interpolated. This can involve up to `nWeights = 12` different + cells and levels. + - interpCellWeights: the weight to multiply each field value by + to perform interpolation to a triangle node. + + - transectInterpVertIndices, transectInterpVertWeights - if + ``zTransect`` is not ``None``, vertical indices and weights for + interpolating from the original transect grid to MPAS-Ocean + transect nodes. + + Interpolation of a DataArray from MPAS cells and levels to transect + triangle nodes can be performed with + ``interp_mpas_to_transect_triangle_nodes()``. Similarly, interpolation of a + DataArray (e.g. an observation) from the original transect grid to + transect triangle nodes can be performed with + ``interp_transect_grid_to_transect_triangle_nodes()`` + + To visualize constant values on MPAS cells, a field can be sampled + at indices ``nCells=cellIndices`` and ``nVertLevels=levelIndices``. + If a smoother visualization is desired, bilinear interpolation can be + performed by first sampling the field at indices + ``nCells=interpCellIndices`` and ``nVertLevels=interpLevelIndices`` and + then multiplying by ``interpCellWeights`` and summing over + ``nWeights``. + """ + + dsTransectCells = _get_transect_cells_and_nodes( + dsTransect, layerThickness, bottomDepth, maxLevelCell) + + dsTransectTriangles = _transect_cells_to_triangles(dsTransectCells) + + if zTransect is not None: + dsTransectTriangles = _add_vertical_interpolation_of_transect_points( + dsTransectTriangles, zTransect) + + return dsTransectTriangles + + +def interp_mpas_to_transect_triangles(dsTransectTriangles, da): + """ + Interpolate a 3D (``nCells`` by ``nVertLevels``) MPAS-Ocean DataArray + to transect nodes with constant values in each MPAS cell + + Parameters + ---------- + dsTransectTriangles : xarray.Dataset + A dataset that defines triangles making up an MPAS-Ocean transect, the + results of calling ``find_transect_levels_and_weights()`` + + da : xarray.DataArray + An MPAS-Ocean 3D field with dimensions `nCells`` and ``nVertLevels`` + (possibly among others) + + Returns + ------- + daNodes : xarray.DataArray + The data array interpolated to transect nodes with dimensions + ``nTransectTriangles`` and ``nTriangleNodes`` (in addition to whatever + dimensions were in ``da`` besides ``nCells`` and ``nVertLevels``) + """ + + cellIndices = dsTransectTriangles.cellIndices + levelIndices = dsTransectTriangles.levelIndices + + daNodes = da.isel(nCells=cellIndices, nVertLevels=levelIndices) + + return daNodes + + +def interp_mpas_to_transect_triangle_nodes(dsTransectTriangles, da): + """ + Interpolate a 3D (``nCells`` by ``nVertLevels``) MPAS-Ocean DataArray + to transect nodes, linearly interpolating fields between the closest + neighboring cells + + Parameters + ---------- + dsTransectTriangles : xarray.Dataset + A dataset that defines triangles making up an MPAS-Ocean transect, the + results of calling ``find_transect_levels_and_weights()`` + + da : xarray.DataArray + An MPAS-Ocean 3D field with dimensions `nCells`` and ``nVertLevels`` + (possibly among others) + + Returns + ------- + daNodes : xarray.DataArray + The data array interpolated to transect nodes with dimensions + ``nTransectTriangles`` and ``nTriangleNodes`` (in addition to whatever + dimensions were in ``da`` besides ``nCells`` and ``nVertLevels``) + """ + + interpCellIndices = dsTransectTriangles.interpCellIndices + interpLevelIndices = dsTransectTriangles.interpLevelIndices + interpCellWeights = dsTransectTriangles.interpCellWeights + + da = da.isel(nCells=interpCellIndices, nVertLevels=interpLevelIndices) + + daNodes = (da*interpCellWeights).sum(dim='nWeights') + + return daNodes + + +def interp_transect_grid_to_transect_triangle_nodes(dsTransectTriangles, da): + """ + Interpolate a DataArray on the original transect grid to triangle nodes on + the MPAS-Ocean transect. + + Parameters + ---------- + dsTransectTriangles : xarray.Dataset + A dataset that defines triangles making up an MPAS-Ocean transect, the + results of calling ``find_transect_levels_and_weights()`` + + da : xarray.DataArray + An field on the original triangle mesh + + Returns + ------- + daNodes : xarray.DataArray + The data array interpolated to transect nodes with dimensions + ``nTransectTriangles`` and ``nTriangleNodes`` + """ + + horizDim = dsTransectTriangles.dTransect.dims[0] + zTransect = dsTransectTriangles.zTransect + vertDim = None + for dim in zTransect.dims: + if dim != horizDim: + vertDim = dim + + horizIndices = dsTransectTriangles.transectIndicesOnHorizNode + horizWeights = dsTransectTriangles.transectWeightsOnHorizNode + + segmentIndices = dsTransectTriangles.segmentIndices + nodeHorizBoundsIndices = dsTransectTriangles.nodeHorizBoundsIndices + + horizIndices = horizIndices.isel(nSegments=segmentIndices, + nHorizBounds=nodeHorizBoundsIndices) + horizWeights = horizWeights.isel(nSegments=segmentIndices, + nHorizBounds=nodeHorizBoundsIndices) + + vertIndices = dsTransectTriangles.transectInterpVertIndices + vertWeights = dsTransectTriangles.transectInterpVertWeights + + kwargs00 = {horizDim: horizIndices, vertDim: vertIndices} + kwargs01 = {horizDim: horizIndices, vertDim: vertIndices+1} + kwargs10 = {horizDim: horizIndices+1, vertDim: vertIndices} + kwargs11 = {horizDim: horizIndices+1, vertDim: vertIndices+1} + + daNodes = (horizWeights * vertWeights * da.isel(**kwargs00) + + horizWeights * (1.0 - vertWeights) * da.isel(**kwargs01) + + (1.0 - horizWeights) * vertWeights * da.isel(**kwargs10) + + (1.0 - horizWeights) * (1.0 - vertWeights) * da.isel(**kwargs11)) + + + mask = numpy.logical_and(horizIndices != -1, vertIndices != -1) + + daNodes = daNodes.where(mask) + + return daNodes + + +def get_outline_segments(dsTransectTriangles): + """Get a set of line segments that outline the given transect""" + + dSeaFloor = dsTransectTriangles.dNode.values + zSeaFloor = dsTransectTriangles.zSeaFloor.values + ssh = dsTransectTriangles.ssh.values + + dJump = numpy.zeros((dSeaFloor.shape[0]-1, 2)) + zJump = numpy.zeros(dJump.shape) + dJump[:, 0] = dSeaFloor[0:-1, 1] + dJump[:, 1] = dSeaFloor[1:, 0] + zJump[:, 0] = zSeaFloor[0:-1, 1] + zJump[:, 1] = zSeaFloor[1:, 0] + mask = numpy.abs(dJump[:, 0] - dJump[:, 1]) < 1e-3 + dJump = dJump[mask, :] + zJump = zJump[mask, :] + + d = numpy.append(numpy.append(dSeaFloor, dSeaFloor, axis=0), + dJump, axis=0).T + z = numpy.append(numpy.append(ssh, zSeaFloor, axis=0), zJump, axis=0).T + + return d, z + + +def _get_transect_cells_and_nodes(dsTransect, layerThickness, bottomDepth, + maxLevelCell): + + if 'Time' in layerThickness.dims: + raise ValueError('Please select a single time level in layerThickness.') + + dsTransect = dsTransect.rename({'nBounds': 'nHorizBounds'}) + + zTop, zMid, zBot, ssh, zSeaFloor, interpCellIndices, interpCellWeights = \ + _get_vertical_coordinate(dsTransect, layerThickness, bottomDepth, + maxLevelCell) + + nVertLevels = layerThickness.sizes['nVertLevels'] + + levelIndices = xarray.DataArray(data=numpy.arange(2*nVertLevels)//2, + dims='nHalfLevels') + + cellMask = (levelIndices <= maxLevelCell).transpose('nCells', 'nHalfLevels') + + dsTransectCells = _add_valid_cells_and_levels( + dsTransect, dsTransect.cellIndices.values, levelIndices.values, + cellMask.values) + + # transect cells are made up of half-levels, and each half-level has a top + # and bottom interface, so we need 4 interfaces per MPAS level + + interpCellIndices, interpLevelIndices, interpCellWeights = \ + _get_interp_indices_and_weights(layerThickness, maxLevelCell, + interpCellIndices, interpCellWeights) + + levelIndex, tempIndex = numpy.meshgrid(numpy.arange(nVertLevels), + numpy.arange(2), indexing='ij') + levelIndex = xarray.DataArray(data=levelIndex.ravel(), dims='nHalfLevels') + tempIndex = xarray.DataArray(data=tempIndex.ravel(), dims='nHalfLevels') + zTop = xarray.concat((zTop, zMid), dim='nTemp') + zTop = zTop.isel(nVertLevels=levelIndex, nTemp=tempIndex) + zBot = xarray.concat((zMid, zBot), dim='nTemp') + zBot = zBot.isel(nVertLevels=levelIndex, nTemp=tempIndex) + + zInterface = xarray.concat((zTop, zBot), dim='nVertBounds') + + segmentIndices = dsTransectCells.segmentIndices + halfLevelIndices = dsTransectCells.halfLevelIndices + + dsTransectCells['interpCellIndices'] = interpCellIndices.isel( + nSegments=segmentIndices, nHalfLevels=halfLevelIndices) + dsTransectCells['interpLevelIndices'] = interpLevelIndices.isel( + nSegments=segmentIndices, nHalfLevels=halfLevelIndices) + dsTransectCells['interpCellWeights'] = interpCellWeights.isel( + nSegments=segmentIndices, nHalfLevels=halfLevelIndices) + dsTransectCells['zTransectNode'] = zInterface.isel( + nSegments=segmentIndices, nHalfLevels=halfLevelIndices) + + dsTransectCells['ssh'] = ssh + dsTransectCells['zSeaFloor'] = zSeaFloor + + dims = ['nSegments', 'nTransectCells', 'nHorizBounds', 'nVertBounds', + 'nHorizWeights', 'nWeights'] + for dim in dsTransectCells.dims: + if dim not in dims: + dims.insert(0, dim) + dsTransectCells = dsTransectCells.transpose(*dims) + + return dsTransectCells + + +def _get_vertical_coordinate(dsTransect, layerThickness, bottomDepth, + maxLevelCell): + nVertLevels = layerThickness.sizes['nVertLevels'] + levelIndices = xarray.DataArray(data=numpy.arange(nVertLevels), + dims='nVertLevels') + cellMask = (levelIndices <= maxLevelCell).transpose('nCells', 'nVertLevels') + + ssh = -bottomDepth + layerThickness.sum(dim='nVertLevels') + + interpCellIndices = dsTransect.interpHorizCellIndices + interpCellWeights = dsTransect.interpHorizCellWeights + + interpMask = numpy.logical_and(interpCellIndices > 0, + cellMask.isel(nCells=interpCellIndices)) + + interpCellWeights = interpMask*interpCellWeights + weightSum = interpCellWeights.sum(dim='nHorizWeights') + + cellIndices = dsTransect.cellIndices + + validCells = cellMask.isel(nCells=cellIndices) + + _, validWeights = xarray.broadcast(interpCellWeights, validCells) + interpCellWeights = (interpCellWeights/weightSum).where(validWeights) + + layerThicknessTransect = layerThickness.isel(nCells=interpCellIndices) + layerThicknessTransect = (layerThicknessTransect*interpCellWeights).sum( + dim='nHorizWeights') + + sshTransect = ssh.isel(nCells=interpCellIndices) + sshTransect = (sshTransect*dsTransect.interpHorizCellWeights).sum( + dim='nHorizWeights') + + zBot = sshTransect - layerThicknessTransect.cumsum(dim='nVertLevels') + zTop = zBot + layerThicknessTransect + zMid = 0.5*(zTop + zBot) + + zSeaFloor = sshTransect - layerThicknessTransect.sum(dim='nVertLevels') + + return zTop, zMid, zBot, sshTransect, zSeaFloor, interpCellIndices, \ + interpCellWeights + + +def _add_valid_cells_and_levels(dsTransect, cellIndices, levelIndices, + cellMask): + + dims = ('nTransectCells',) + CellIndices, LevelIndices = numpy.meshgrid(cellIndices, levelIndices, + indexing='ij') + mask = numpy.logical_and(CellIndices >= 0, cellMask[cellIndices, :]) + + SegmentIndices, HalfLevelIndices = \ + numpy.meshgrid(numpy.arange(len(cellIndices)), + numpy.arange(len(levelIndices)), indexing='ij') + + segmentIndices = xarray.DataArray(data=SegmentIndices[mask], dims=dims) + + dsTransectCells = dsTransect + dsTransectCells['cellIndices'] = (dims, CellIndices[mask]) + dsTransectCells['levelIndices'] = (dims, LevelIndices[mask]) + dsTransectCells['segmentIndices'] = segmentIndices + dsTransectCells['halfLevelIndices'] = (dims, HalfLevelIndices[mask]) + + return dsTransectCells + + +def _get_interp_indices_and_weights(layerThickness, maxLevelCell, + interpCellIndices, interpCellWeights): + interpCellIndices = interpCellIndices.rename({'nHorizWeights': 'nWeights'}) + interpCellWeights = interpCellWeights.rename({'nHorizWeights': 'nWeights'}) + nVertLevels = layerThickness.sizes['nVertLevels'] + nHalfLevels = 2*nVertLevels + nVertBounds = 2 + + interpMaxLevelCell = maxLevelCell.isel(nCells=interpCellIndices) + + levelIndices = xarray.DataArray( + data=numpy.arange(nHalfLevels)//2, dims='nHalfLevels') + valid = levelIndices <= interpMaxLevelCell + + topLevelIndices = -1*numpy.ones((nHalfLevels, nVertBounds), int) + topLevelIndices[1:, 0] = numpy.arange(nHalfLevels-1)//2 + topLevelIndices[:, 1] = numpy.arange(nHalfLevels)//2 + topLevelIndices = xarray.DataArray( + data=topLevelIndices, dims=('nHalfLevels', 'nVertBounds')) + interpCellIndices, topLevelIndices = \ + xarray.broadcast(interpCellIndices, topLevelIndices) + topLevelIndices = topLevelIndices.where(valid, -1) + + botLevelIndices = numpy.zeros((nHalfLevels, nVertBounds), int) + botLevelIndices[:, 0] = numpy.arange(nHalfLevels)//2 + botLevelIndices[:, 1] = numpy.arange(1, nHalfLevels+1)//2 + botLevelIndices = xarray.DataArray( + data=botLevelIndices, dims=('nHalfLevels', 'nVertBounds')) + _, botLevelIndices = xarray.broadcast(interpCellIndices, botLevelIndices) + botLevelIndices = botLevelIndices.where(valid, -1) + botLevelIndices = numpy.minimum(botLevelIndices, interpMaxLevelCell) + + interpLevelIndices = xarray.concat((topLevelIndices, botLevelIndices), + dim='nWeights') + + topHalfLevelThickness = 0.5*layerThickness.isel( + nCells=interpCellIndices, nVertLevels=topLevelIndices) + botHalfLevelThickness = 0.5*layerThickness.isel( + nCells=interpCellIndices, nVertLevels=botLevelIndices) + + # vertical weights are proportional to the half-level thickness + interpCellWeights = xarray.concat( + (topHalfLevelThickness*interpCellWeights.isel(nVertLevels=topLevelIndices), + botHalfLevelThickness*interpCellWeights.isel(nVertLevels=botLevelIndices)), + dim='nWeights') + + weightSum = interpCellWeights.sum(dim='nWeights') + _, outMask = xarray.broadcast(interpCellWeights, weightSum > 0.) + interpCellWeights = (interpCellWeights/weightSum).where(outMask) + + interpCellIndices = xarray.concat((interpCellIndices, interpCellIndices), + dim='nWeights') + + return interpCellIndices, interpLevelIndices, interpCellWeights + + +def _transect_cells_to_triangles(dsTransectCells): + + nTransectCells = dsTransectCells.sizes['nTransectCells'] + nTransectTriangles = 2*nTransectCells + triangleTransectCellIndices = numpy.arange(nTransectTriangles)//2 + nodeTransectCellIndices = numpy.zeros((nTransectTriangles, 3), int) + nodeHorizBoundsIndices = numpy.zeros((nTransectTriangles, 3), int) + nodeVertBoundsIndices = numpy.zeros((nTransectTriangles, 3), int) + + for index in range(3): + nodeTransectCellIndices[:, index] = triangleTransectCellIndices + + # the upper triangle + nodeHorizBoundsIndices[0::2, 0] = 0 + nodeVertBoundsIndices[0::2, 0] = 0 + nodeHorizBoundsIndices[0::2, 1] = 1 + nodeVertBoundsIndices[0::2, 1] = 0 + nodeHorizBoundsIndices[0::2, 2] = 0 + nodeVertBoundsIndices[0::2, 2] = 1 + + # the lower triangle + nodeHorizBoundsIndices[1::2, 0] = 0 + nodeVertBoundsIndices[1::2, 0] = 1 + nodeHorizBoundsIndices[1::2, 1] = 1 + nodeVertBoundsIndices[1::2, 1] = 0 + nodeHorizBoundsIndices[1::2, 2] = 1 + nodeVertBoundsIndices[1::2, 2] = 1 + + triangleTransectCellIndices = xarray.DataArray( + data=triangleTransectCellIndices, dims='nTransectTriangles') + nodeTransectCellIndices = xarray.DataArray( + data=nodeTransectCellIndices, + dims=('nTransectTriangles', 'nTriangleNodes')) + nodeHorizBoundsIndices = xarray.DataArray( + data=nodeHorizBoundsIndices, + dims=('nTransectTriangles', 'nTriangleNodes')) + nodeVertBoundsIndices = xarray.DataArray( + data=nodeVertBoundsIndices, + dims=('nTransectTriangles', 'nTriangleNodes')) + + dsTransectTriangles = xarray.Dataset() + dsTransectTriangles['nodeHorizBoundsIndices'] = \ + nodeHorizBoundsIndices + for var_name in dsTransectCells.data_vars: + var = dsTransectCells[var_name] + if 'nTransectCells' in var.dims: + if 'nVertBounds' in var.dims: + assert 'nHorizBounds' in var.dims + dsTransectTriangles[var_name] = var.isel( + nTransectCells=nodeTransectCellIndices, + nHorizBounds=nodeHorizBoundsIndices, + nVertBounds=nodeVertBoundsIndices) + elif 'nHorizBounds' in var.dims: + dsTransectTriangles[var_name] = var.isel( + nTransectCells=nodeTransectCellIndices, + nHorizBounds=nodeHorizBoundsIndices) + else: + dsTransectTriangles[var_name] = var.isel( + nTransectCells=triangleTransectCellIndices) + else: + dsTransectTriangles[var_name] = var + + dsTransectTriangles = dsTransectTriangles.drop_vars('halfLevelIndices') + + return dsTransectTriangles + + +def _add_vertical_interpolation_of_transect_points(dsTransectTriangles, + zTransect): + + dTransect = dsTransectTriangles.dTransect + # make sure zTransect is 2D + zTransect, _ = xarray.broadcast(zTransect, dTransect) + + assert len(zTransect.dims) == 2 + + horizDim = dTransect.dims[0] + vertDim = None + for dim in zTransect.dims: + if dim != horizDim: + vertDim = dim + + assert vertDim is not None + + nzTransect = zTransect.sizes[vertDim] + + horizIndices = dsTransectTriangles.transectIndicesOnHorizNode + horizWeights = dsTransectTriangles.transectWeightsOnHorizNode + kwargs0 = {horizDim: horizIndices} + kwargs1 = {horizDim: horizIndices+1} + zTransectAtHorizNodes = horizWeights*zTransect.isel(**kwargs0) + \ + (1.0 - horizWeights)*zTransect.isel(**kwargs1) + + zTriangleNode = dsTransectTriangles.zTransectNode + + segmentIndices = dsTransectTriangles.segmentIndices + nodeHorizBoundsIndices = dsTransectTriangles.nodeHorizBoundsIndices + + nTransectTriangles = dsTransectTriangles.sizes['nTransectTriangles'] + nTriangleNodes = dsTransectTriangles.sizes['nTriangleNodes'] + transectInterpVertIndices = -1*numpy.ones( + (nTransectTriangles, nTriangleNodes), int) + transectInterpVertWeights = numpy.zeros( + (nTransectTriangles, nTriangleNodes)) + + kwargs = {vertDim: 0, 'nSegments': segmentIndices, + 'nHorizBounds': nodeHorizBoundsIndices} + z0 = zTransectAtHorizNodes.isel(**kwargs) + for zIndex in range(nzTransect-1): + kwargs = {vertDim: zIndex+1, 'nSegments': segmentIndices, + 'nHorizBounds': nodeHorizBoundsIndices} + z1 = zTransectAtHorizNodes.isel(**kwargs) + mask = numpy.logical_and(zTriangleNode <= z0, zTriangleNode > z1) + mask = mask.values + weights = (z1 - zTriangleNode)/(z1 - z0) + + transectInterpVertIndices[mask] = zIndex + transectInterpVertWeights[mask] = weights.values[mask] + z0 = z1 + + dsTransectTriangles['transectInterpVertIndices'] = ( + ('nTransectTriangles', 'nTriangleNodes'), transectInterpVertIndices) + + dsTransectTriangles['transectInterpVertWeights'] = ( + ('nTransectTriangles', 'nTriangleNodes'), transectInterpVertWeights) + + dsTransectTriangles['zTransect'] = zTransect + + return dsTransectTriangles