Skip to content

Commit

Permalink
Add a module for computing geometry of ocean transects
Browse files Browse the repository at this point in the history
  • Loading branch information
xylar committed Jul 20, 2020
1 parent f571841 commit 0f4b1a6
Showing 1 changed file with 269 additions and 0 deletions.
269 changes: 269 additions & 0 deletions conda_package/mpas_tools/ocean/transects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
import xarray
import numpy


def find_transect_levels_and_weights(dsTransect, layerThickness, bottomDepth,
maxLevelCell):
dsTransectCells = _get_transect_cells_and_nodes(
dsTransect, layerThickness, bottomDepth, maxLevelCell)

dsTransectTriangles = _transect_cells_to_triangles(dsTransectCells)

return dsTransectTriangles


def _get_transect_cells_and_nodes(dsTransect, layerThickness, bottomDepth,
maxLevelCell):

if 'Time' in layerThickness.dims:
raise ValueError('Please select a single time level in layerThickness.')

dsTransect = dsTransect.rename({'nBounds': 'nHorizBounds'})

zTop, zMid, zBot, ssh, zSeaFloor, interpCellIndices, interpCellWeights = \
_get_vertical_coordinate(dsTransect, layerThickness, bottomDepth,
maxLevelCell)

dNode = dsTransect.dNode
nVertLevels = layerThickness.sizes['nVertLevels']

levelIndices = xarray.DataArray(data=numpy.arange(2*nVertLevels)//2,
dims='nHalfLevels')

cellMask = (levelIndices <= maxLevelCell).transpose('nCells', 'nHalfLevels')

dsTransectCells = _add_valid_cells_and_levels(
dsTransect, dsTransect.cellIndices.values, levelIndices.values,
cellMask.values)

# transect cells are made up of half-levels, and each half-level has a top
# and bottom interface, so we need 4 interfaces per MPAS level

interpCellIndices, interpLevelIndices, interpCellWeights = \
_get_interp_indices_and_weights(layerThickness, maxLevelCell,
interpCellIndices, interpCellWeights)

levelIndex, tempIndex = numpy.meshgrid(numpy.arange(nVertLevels),
numpy.arange(2), indexing='ij')
levelIndex = xarray.DataArray(data=levelIndex.ravel(), dims='nHalfLevels')
tempIndex = xarray.DataArray(data=tempIndex.ravel(), dims='nHalfLevels')
zTop = xarray.concat((zTop, zMid), dim='nTemp')
zTop = zTop.isel(nVertLevels=levelIndex, nTemp=tempIndex)
zBot = xarray.concat((zMid, zBot), dim='nTemp')
zBot = zBot.isel(nVertLevels=levelIndex, nTemp=tempIndex)

zInterface = xarray.concat((zTop, zBot), dim='nVertBounds')

segmentIndices = dsTransectCells.segmentIndices
halfLevelIndices = dsTransectCells.halfLevelIndices

dsTransectCells['interpCellIndices'] = interpCellIndices.isel(
nSegments=segmentIndices, nHalfLevels=halfLevelIndices)
dsTransectCells['interpLevelIndices'] = interpLevelIndices.isel(
nSegments=segmentIndices, nHalfLevels=halfLevelIndices)
dsTransectCells['interpCellWeights'] = interpCellWeights.isel(
nSegments=segmentIndices, nHalfLevels=halfLevelIndices)
dsTransectCells['zInterface'] = zInterface.isel(
nSegments=segmentIndices, nHalfLevels=halfLevelIndices)

dsTransectCells['ssh'] = ssh
dsTransectCells['zSeaFloor'] = zSeaFloor

dims = ['nSegments', 'nTransectCells', 'nHorizBounds', 'nVertBounds',
'nWeights']
for dim in dsTransectCells.dims:
if dim not in dims:
dims.insert(0, dim)
dsTransectCells = dsTransectCells.transpose(*dims)

return dsTransectCells


def _get_vertical_coordinate(dsTransect, layerThickness, bottomDepth,
maxLevelCell):
nVertLevels = layerThickness.sizes['nVertLevels']
levelIndices = xarray.DataArray(data=numpy.arange(nVertLevels),
dims='nVertLevels')
cellMask = (levelIndices <= maxLevelCell).transpose('nCells', 'nVertLevels')

ssh = -bottomDepth + layerThickness.sum(dim='nVertLevels')

interpCellIndices = dsTransect.interpCellIndices
interpCellWeights = dsTransect.interpCellWeights

interpMask = numpy.logical_and(interpCellIndices > 0,
cellMask.isel(nCells=interpCellIndices))

interpCellWeights = interpMask*interpCellWeights
weightSum = interpCellWeights.sum(dim='nWeights')

cellIndices = dsTransect.cellIndices

validCells = cellMask.isel(nCells=cellIndices)

_, validWeights = xarray.broadcast(interpCellWeights, validCells)
interpCellWeights = (interpCellWeights/weightSum).where(validWeights)

layerThicknessTransect = layerThickness.isel(nCells=interpCellIndices)
layerThicknessTransect = (layerThicknessTransect*interpCellWeights).sum(
dim='nWeights')

sshTransect = ssh.isel(nCells=interpCellIndices)
sshTransect = (sshTransect*dsTransect.interpCellWeights).sum(dim='nWeights')

zBot = sshTransect - layerThicknessTransect.cumsum(dim='nVertLevels')
zTop = zBot + layerThicknessTransect
zMid = 0.5*(zTop + zBot)

zSeaFloor = sshTransect - layerThicknessTransect.sum(dim='nVertLevels')

return zTop, zMid, zBot, sshTransect, zSeaFloor, interpCellIndices, \
interpCellWeights


def _add_valid_cells_and_levels(dsTransect, cellIndices, levelIndices,
cellMask):

dims = ('nTransectCells',)
CellIndices, LevelIndices = numpy.meshgrid(cellIndices, levelIndices,
indexing='ij')
mask = numpy.logical_and(CellIndices >= 0, cellMask[cellIndices, :])

SegmentIndices, HalfLevelIndices = \
numpy.meshgrid(numpy.arange(len(cellIndices)),
numpy.arange(len(levelIndices)), indexing='ij')

segmentIndices = xarray.DataArray(data=SegmentIndices[mask], dims=dims)
dropVars = []
for var in dsTransect.data_vars:
if 'nWeights' in dsTransect[var].dims:
dropVars.append(var)
if len(dropVars) > 0:
dsTransect = dsTransect.drop_vars(dropVars)

dsTransectCells = dsTransect
dsTransectCells['cellIndices'] = (dims, CellIndices[mask])
dsTransectCells['levelIndices'] = (dims, LevelIndices[mask])
dsTransectCells['segmentIndices'] = segmentIndices
dsTransectCells['halfLevelIndices'] = (dims, HalfLevelIndices[mask])

return dsTransectCells


def _get_interp_indices_and_weights(layerThickness, maxLevelCell,
interpCellIndices, interpCellWeights):
nVertLevels = layerThickness.sizes['nVertLevels']
nHalfLevels = 2*nVertLevels
nVertBounds = 2

interpMaxLevelCell = maxLevelCell.isel(nCells=interpCellIndices)

levelIndices = xarray.DataArray(
data=numpy.arange(nHalfLevels)//2, dims='nHalfLevels')
valid = levelIndices <= interpMaxLevelCell

topLevelIndices = -1*numpy.ones((nHalfLevels, nVertBounds), int)
topLevelIndices[1:, 0] = numpy.arange(nHalfLevels-1)//2
topLevelIndices[:, 1] = numpy.arange(nHalfLevels)//2
topLevelIndices = xarray.DataArray(
data=topLevelIndices, dims=('nHalfLevels', 'nVertBounds'))
interpCellIndices, topLevelIndices = \
xarray.broadcast(interpCellIndices, topLevelIndices)
topLevelIndices = topLevelIndices.where(valid, -1)

botLevelIndices = numpy.zeros((nHalfLevels, nVertBounds), int)
botLevelIndices[:, 0] = numpy.arange(nHalfLevels)//2
botLevelIndices[:, 1] = numpy.arange(1, nHalfLevels+1)//2
botLevelIndices = xarray.DataArray(
data=botLevelIndices, dims=('nHalfLevels', 'nVertBounds'))
_, botLevelIndices = xarray.broadcast(interpCellIndices, botLevelIndices)
botLevelIndices = botLevelIndices.where(valid, -1)
botLevelIndices = numpy.minimum(botLevelIndices, interpMaxLevelCell)

interpLevelIndices = xarray.concat((topLevelIndices, botLevelIndices),
dim='nWeights')

topHalfLevelThickness = 0.5*layerThickness.isel(
nCells=interpCellIndices, nVertLevels=topLevelIndices)
botHalfLevelThickness = 0.5*layerThickness.isel(
nCells=interpCellIndices, nVertLevels=botLevelIndices)

# vertical weights are proportional to the half-level thickness
interpCellWeights = xarray.concat(
(topHalfLevelThickness*interpCellWeights.isel(nVertLevels=topLevelIndices),
botHalfLevelThickness*interpCellWeights.isel(nVertLevels=botLevelIndices)),
dim='nWeights')

weightSum = interpCellWeights.sum(dim='nWeights')
_, outMask = xarray.broadcast(interpCellWeights, weightSum > 0.)
interpCellWeights = (interpCellWeights/weightSum).where(outMask)

interpCellIndices = xarray.concat((interpCellIndices, interpCellIndices),
dim='nWeights')

return interpCellIndices, interpLevelIndices, interpCellWeights


def _transect_cells_to_triangles(dsTransectCells):

nTransectCells = dsTransectCells.sizes['nTransectCells']
nTransectTriangles = 2*nTransectCells
triangleTransectCellIndices = numpy.arange(nTransectTriangles)//2
nodeTransectCellIndices = numpy.zeros((nTransectTriangles, 3), int)
nodeHorizBoundsIndices = numpy.zeros((nTransectTriangles, 3), int)
nodeVertBoundsIndices = numpy.zeros((nTransectTriangles, 3), int)

for index in range(3):
nodeTransectCellIndices[:, index] = triangleTransectCellIndices

# the upper triangle
nodeHorizBoundsIndices[0::2, 0] = 0
nodeVertBoundsIndices[0::2, 0] = 0
nodeHorizBoundsIndices[0::2, 1] = 1
nodeVertBoundsIndices[0::2, 1] = 0
nodeHorizBoundsIndices[0::2, 2] = 0
nodeVertBoundsIndices[0::2, 2] = 1

# the lower triangle
nodeHorizBoundsIndices[1::2, 0] = 0
nodeVertBoundsIndices[1::2, 0] = 1
nodeHorizBoundsIndices[1::2, 1] = 1
nodeVertBoundsIndices[1::2, 1] = 0
nodeHorizBoundsIndices[1::2, 2] = 1
nodeVertBoundsIndices[1::2, 2] = 1

triangleTransectCellIndices = xarray.DataArray(
data=triangleTransectCellIndices, dims='nTransectTriangles')
nodeTransectCellIndices = xarray.DataArray(
data=nodeTransectCellIndices,
dims=('nTransectTriangles', 'nTriangleNodes'))
nodeHorizBoundsIndices = xarray.DataArray(
data=nodeHorizBoundsIndices,
dims=('nTransectTriangles', 'nTriangleNodes'))
nodeVertBoundsIndices = xarray.DataArray(
data=nodeVertBoundsIndices,
dims=('nTransectTriangles', 'nTriangleNodes'))

dsTransectTriangles = xarray.Dataset()
dsTransectTriangles['nodeHorizBoundsIndices'] = \
nodeHorizBoundsIndices
for var_name in dsTransectCells.data_vars:
var = dsTransectCells[var_name]
if 'nTransectCells' in var.dims:
if 'nVertBounds' in var.dims:
assert 'nHorizBounds' in var.dims
dsTransectTriangles[var_name] = var.isel(
nTransectCells=nodeTransectCellIndices,
nHorizBounds=nodeHorizBoundsIndices,
nVertBounds=nodeVertBoundsIndices)
elif 'nHorizBounds' in var.dims:
dsTransectTriangles[var_name] = var.isel(
nTransectCells=nodeTransectCellIndices,
nHorizBounds=nodeHorizBoundsIndices)
else:
dsTransectTriangles[var_name] = var.isel(
nTransectCells=triangleTransectCellIndices)
else:
dsTransectTriangles[var_name] = var

return dsTransectTriangles

0 comments on commit 0f4b1a6

Please sign in to comment.