diff --git a/conda_package/mpas_tools/ocean/transects.py b/conda_package/mpas_tools/ocean/transects.py new file mode 100644 index 000000000..5bab28629 --- /dev/null +++ b/conda_package/mpas_tools/ocean/transects.py @@ -0,0 +1,269 @@ +import xarray +import numpy + + +def find_transect_levels_and_weights(dsTransect, layerThickness, bottomDepth, + maxLevelCell): + dsTransectCells = _get_transect_cells_and_nodes( + dsTransect, layerThickness, bottomDepth, maxLevelCell) + + dsTransectTriangles = _transect_cells_to_triangles(dsTransectCells) + + return dsTransectTriangles + + +def _get_transect_cells_and_nodes(dsTransect, layerThickness, bottomDepth, + maxLevelCell): + + if 'Time' in layerThickness.dims: + raise ValueError('Please select a single time level in layerThickness.') + + dsTransect = dsTransect.rename({'nBounds': 'nHorizBounds'}) + + zTop, zMid, zBot, ssh, zSeaFloor, interpCellIndices, interpCellWeights = \ + _get_vertical_coordinate(dsTransect, layerThickness, bottomDepth, + maxLevelCell) + + dNode = dsTransect.dNode + nVertLevels = layerThickness.sizes['nVertLevels'] + + levelIndices = xarray.DataArray(data=numpy.arange(2*nVertLevels)//2, + dims='nHalfLevels') + + cellMask = (levelIndices <= maxLevelCell).transpose('nCells', 'nHalfLevels') + + dsTransectCells = _add_valid_cells_and_levels( + dsTransect, dsTransect.cellIndices.values, levelIndices.values, + cellMask.values) + + # transect cells are made up of half-levels, and each half-level has a top + # and bottom interface, so we need 4 interfaces per MPAS level + + interpCellIndices, interpLevelIndices, interpCellWeights = \ + _get_interp_indices_and_weights(layerThickness, maxLevelCell, + interpCellIndices, interpCellWeights) + + levelIndex, tempIndex = numpy.meshgrid(numpy.arange(nVertLevels), + numpy.arange(2), indexing='ij') + levelIndex = xarray.DataArray(data=levelIndex.ravel(), dims='nHalfLevels') + tempIndex = xarray.DataArray(data=tempIndex.ravel(), dims='nHalfLevels') + zTop = xarray.concat((zTop, zMid), dim='nTemp') + zTop = zTop.isel(nVertLevels=levelIndex, nTemp=tempIndex) + zBot = xarray.concat((zMid, zBot), dim='nTemp') + zBot = zBot.isel(nVertLevels=levelIndex, nTemp=tempIndex) + + zInterface = xarray.concat((zTop, zBot), dim='nVertBounds') + + segmentIndices = dsTransectCells.segmentIndices + halfLevelIndices = dsTransectCells.halfLevelIndices + + dsTransectCells['interpCellIndices'] = interpCellIndices.isel( + nSegments=segmentIndices, nHalfLevels=halfLevelIndices) + dsTransectCells['interpLevelIndices'] = interpLevelIndices.isel( + nSegments=segmentIndices, nHalfLevels=halfLevelIndices) + dsTransectCells['interpCellWeights'] = interpCellWeights.isel( + nSegments=segmentIndices, nHalfLevels=halfLevelIndices) + dsTransectCells['zInterface'] = zInterface.isel( + nSegments=segmentIndices, nHalfLevels=halfLevelIndices) + + dsTransectCells['ssh'] = ssh + dsTransectCells['zSeaFloor'] = zSeaFloor + + dims = ['nSegments', 'nTransectCells', 'nHorizBounds', 'nVertBounds', + 'nWeights'] + for dim in dsTransectCells.dims: + if dim not in dims: + dims.insert(0, dim) + dsTransectCells = dsTransectCells.transpose(*dims) + + return dsTransectCells + + +def _get_vertical_coordinate(dsTransect, layerThickness, bottomDepth, + maxLevelCell): + nVertLevels = layerThickness.sizes['nVertLevels'] + levelIndices = xarray.DataArray(data=numpy.arange(nVertLevels), + dims='nVertLevels') + cellMask = (levelIndices <= maxLevelCell).transpose('nCells', 'nVertLevels') + + ssh = -bottomDepth + layerThickness.sum(dim='nVertLevels') + + interpCellIndices = dsTransect.interpCellIndices + interpCellWeights = dsTransect.interpCellWeights + + interpMask = numpy.logical_and(interpCellIndices > 0, + cellMask.isel(nCells=interpCellIndices)) + + interpCellWeights = interpMask*interpCellWeights + weightSum = interpCellWeights.sum(dim='nWeights') + + cellIndices = dsTransect.cellIndices + + validCells = cellMask.isel(nCells=cellIndices) + + _, validWeights = xarray.broadcast(interpCellWeights, validCells) + interpCellWeights = (interpCellWeights/weightSum).where(validWeights) + + layerThicknessTransect = layerThickness.isel(nCells=interpCellIndices) + layerThicknessTransect = (layerThicknessTransect*interpCellWeights).sum( + dim='nWeights') + + sshTransect = ssh.isel(nCells=interpCellIndices) + sshTransect = (sshTransect*dsTransect.interpCellWeights).sum(dim='nWeights') + + zBot = sshTransect - layerThicknessTransect.cumsum(dim='nVertLevels') + zTop = zBot + layerThicknessTransect + zMid = 0.5*(zTop + zBot) + + zSeaFloor = sshTransect - layerThicknessTransect.sum(dim='nVertLevels') + + return zTop, zMid, zBot, sshTransect, zSeaFloor, interpCellIndices, \ + interpCellWeights + + +def _add_valid_cells_and_levels(dsTransect, cellIndices, levelIndices, + cellMask): + + dims = ('nTransectCells',) + CellIndices, LevelIndices = numpy.meshgrid(cellIndices, levelIndices, + indexing='ij') + mask = numpy.logical_and(CellIndices >= 0, cellMask[cellIndices, :]) + + SegmentIndices, HalfLevelIndices = \ + numpy.meshgrid(numpy.arange(len(cellIndices)), + numpy.arange(len(levelIndices)), indexing='ij') + + segmentIndices = xarray.DataArray(data=SegmentIndices[mask], dims=dims) + dropVars = [] + for var in dsTransect.data_vars: + if 'nWeights' in dsTransect[var].dims: + dropVars.append(var) + if len(dropVars) > 0: + dsTransect = dsTransect.drop_vars(dropVars) + + dsTransectCells = dsTransect + dsTransectCells['cellIndices'] = (dims, CellIndices[mask]) + dsTransectCells['levelIndices'] = (dims, LevelIndices[mask]) + dsTransectCells['segmentIndices'] = segmentIndices + dsTransectCells['halfLevelIndices'] = (dims, HalfLevelIndices[mask]) + + return dsTransectCells + + +def _get_interp_indices_and_weights(layerThickness, maxLevelCell, + interpCellIndices, interpCellWeights): + nVertLevels = layerThickness.sizes['nVertLevels'] + nHalfLevels = 2*nVertLevels + nVertBounds = 2 + + interpMaxLevelCell = maxLevelCell.isel(nCells=interpCellIndices) + + levelIndices = xarray.DataArray( + data=numpy.arange(nHalfLevels)//2, dims='nHalfLevels') + valid = levelIndices <= interpMaxLevelCell + + topLevelIndices = -1*numpy.ones((nHalfLevels, nVertBounds), int) + topLevelIndices[1:, 0] = numpy.arange(nHalfLevels-1)//2 + topLevelIndices[:, 1] = numpy.arange(nHalfLevels)//2 + topLevelIndices = xarray.DataArray( + data=topLevelIndices, dims=('nHalfLevels', 'nVertBounds')) + interpCellIndices, topLevelIndices = \ + xarray.broadcast(interpCellIndices, topLevelIndices) + topLevelIndices = topLevelIndices.where(valid, -1) + + botLevelIndices = numpy.zeros((nHalfLevels, nVertBounds), int) + botLevelIndices[:, 0] = numpy.arange(nHalfLevels)//2 + botLevelIndices[:, 1] = numpy.arange(1, nHalfLevels+1)//2 + botLevelIndices = xarray.DataArray( + data=botLevelIndices, dims=('nHalfLevels', 'nVertBounds')) + _, botLevelIndices = xarray.broadcast(interpCellIndices, botLevelIndices) + botLevelIndices = botLevelIndices.where(valid, -1) + botLevelIndices = numpy.minimum(botLevelIndices, interpMaxLevelCell) + + interpLevelIndices = xarray.concat((topLevelIndices, botLevelIndices), + dim='nWeights') + + topHalfLevelThickness = 0.5*layerThickness.isel( + nCells=interpCellIndices, nVertLevels=topLevelIndices) + botHalfLevelThickness = 0.5*layerThickness.isel( + nCells=interpCellIndices, nVertLevels=botLevelIndices) + + # vertical weights are proportional to the half-level thickness + interpCellWeights = xarray.concat( + (topHalfLevelThickness*interpCellWeights.isel(nVertLevels=topLevelIndices), + botHalfLevelThickness*interpCellWeights.isel(nVertLevels=botLevelIndices)), + dim='nWeights') + + weightSum = interpCellWeights.sum(dim='nWeights') + _, outMask = xarray.broadcast(interpCellWeights, weightSum > 0.) + interpCellWeights = (interpCellWeights/weightSum).where(outMask) + + interpCellIndices = xarray.concat((interpCellIndices, interpCellIndices), + dim='nWeights') + + return interpCellIndices, interpLevelIndices, interpCellWeights + + +def _transect_cells_to_triangles(dsTransectCells): + + nTransectCells = dsTransectCells.sizes['nTransectCells'] + nTransectTriangles = 2*nTransectCells + triangleTransectCellIndices = numpy.arange(nTransectTriangles)//2 + nodeTransectCellIndices = numpy.zeros((nTransectTriangles, 3), int) + nodeHorizBoundsIndices = numpy.zeros((nTransectTriangles, 3), int) + nodeVertBoundsIndices = numpy.zeros((nTransectTriangles, 3), int) + + for index in range(3): + nodeTransectCellIndices[:, index] = triangleTransectCellIndices + + # the upper triangle + nodeHorizBoundsIndices[0::2, 0] = 0 + nodeVertBoundsIndices[0::2, 0] = 0 + nodeHorizBoundsIndices[0::2, 1] = 1 + nodeVertBoundsIndices[0::2, 1] = 0 + nodeHorizBoundsIndices[0::2, 2] = 0 + nodeVertBoundsIndices[0::2, 2] = 1 + + # the lower triangle + nodeHorizBoundsIndices[1::2, 0] = 0 + nodeVertBoundsIndices[1::2, 0] = 1 + nodeHorizBoundsIndices[1::2, 1] = 1 + nodeVertBoundsIndices[1::2, 1] = 0 + nodeHorizBoundsIndices[1::2, 2] = 1 + nodeVertBoundsIndices[1::2, 2] = 1 + + triangleTransectCellIndices = xarray.DataArray( + data=triangleTransectCellIndices, dims='nTransectTriangles') + nodeTransectCellIndices = xarray.DataArray( + data=nodeTransectCellIndices, + dims=('nTransectTriangles', 'nTriangleNodes')) + nodeHorizBoundsIndices = xarray.DataArray( + data=nodeHorizBoundsIndices, + dims=('nTransectTriangles', 'nTriangleNodes')) + nodeVertBoundsIndices = xarray.DataArray( + data=nodeVertBoundsIndices, + dims=('nTransectTriangles', 'nTriangleNodes')) + + dsTransectTriangles = xarray.Dataset() + dsTransectTriangles['nodeHorizBoundsIndices'] = \ + nodeHorizBoundsIndices + for var_name in dsTransectCells.data_vars: + var = dsTransectCells[var_name] + if 'nTransectCells' in var.dims: + if 'nVertBounds' in var.dims: + assert 'nHorizBounds' in var.dims + dsTransectTriangles[var_name] = var.isel( + nTransectCells=nodeTransectCellIndices, + nHorizBounds=nodeHorizBoundsIndices, + nVertBounds=nodeVertBoundsIndices) + elif 'nHorizBounds' in var.dims: + dsTransectTriangles[var_name] = var.isel( + nTransectCells=nodeTransectCellIndices, + nHorizBounds=nodeHorizBoundsIndices) + else: + dsTransectTriangles[var_name] = var.isel( + nTransectCells=triangleTransectCellIndices) + else: + dsTransectTriangles[var_name] = var + + return dsTransectTriangles