diff --git a/polaris/viz/planar.py b/polaris/viz/planar.py index b25e6796d..28b41de52 100644 --- a/polaris/viz/planar.py +++ b/polaris/viz/planar.py @@ -174,7 +174,6 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 patches, patch_mask = _compute_cell_patches(ds_mesh, patch_mask) elif 'nEdges' in field.dims: patch_mask = _edge_mask_from_cell_mask(ds_mesh, cell_mask) - patch_mask = _remove_boundary_edges_from_mask(ds_mesh, patch_mask) patches, patch_mask = _compute_edge_patches(ds_mesh, patch_mask) else: raise ValueError('Cannot plot a field without dim nCells or ' @@ -250,50 +249,26 @@ def _edge_mask_from_cell_mask(ds, cell_mask): return edge_mask -def _remove_boundary_edges_from_mask(ds, mask): - area_cell = ds.areaCell.values - mean_area_cell = np.mean(area_cell) - cells_on_edge = ds.cellsOnEdge.values - 1 - vertices_on_edge = ds.verticesOnEdge.values - 1 - x_cell = ds.xCell.values - y_cell = ds.yCell.values - boundary_vertex = ds.boundaryVertex.values - x_vertex = ds.xVertex.values - y_vertex = ds.yVertex.values - for edge_index in range(ds.sizes['nEdges']): - if not mask[edge_index]: - continue - cell_indices = cells_on_edge[edge_index] - vertex_indices = vertices_on_edge[edge_index, :] - if any(boundary_vertex[vertex_indices]): - mask[edge_index] = 0 - continue - vertices = np.zeros((4, 2)) - vertices[0, 0] = x_vertex[vertex_indices[0]] - vertices[0, 1] = y_vertex[vertex_indices[0]] - vertices[1, 0] = x_cell[cell_indices[0]] - vertices[1, 1] = y_cell[cell_indices[0]] - vertices[2, 0] = x_vertex[vertex_indices[1]] - vertices[2, 1] = y_vertex[vertex_indices[1]] - vertices[3, 0] = x_cell[cell_indices[1]] - vertices[3, 1] = y_cell[cell_indices[1]] - - # Remove edges that span the periodic boundaries - dx = max(vertices[:, 0]) - min(vertices[:, 0]) - dy = max(vertices[:, 1]) - min(vertices[:, 1]) - if dx * dy / 10 > mean_area_cell: - mask[edge_index] = 0 - - return mask - - def _compute_cell_patches(ds, mask): patches = [] num_vertices_on_cell = ds.nEdgesOnCell.values vertices_on_cell = ds.verticesOnCell.values - 1 + x_cell = ds.xCell.values + y_cell = ds.yCell.values x_vertex = ds.xVertex.values y_vertex = ds.yVertex.values - area_cell = ds.areaCell.values + + is_periodic = ds.attrs['is_periodic'].strip() == 'YES' + is_x_periodic = False + is_y_periodic = False + if is_periodic: + x_period = ds.attrs['x_period'] + if x_period > 0.: + is_x_periodic = True + y_period = ds.attrs['y_period'] + if y_period > 0.: + is_y_periodic = True + for cell_index in range(ds.sizes['nCells']): if not mask[cell_index]: continue @@ -303,14 +278,24 @@ def _compute_cell_patches(ds, mask): vertices[:, 0] = 1e-3 * x_vertex[vertex_indices] vertices[:, 1] = 1e-3 * y_vertex[vertex_indices] - # Remove cells that span the periodic boundaries - dx = max(x_vertex[vertex_indices]) - min(x_vertex[vertex_indices]) - dy = max(y_vertex[vertex_indices]) - min(y_vertex[vertex_indices]) - if dx * dy / 10 > area_cell[cell_index]: - mask[cell_index] = False - else: - polygon = Polygon(vertices, closed=True) - patches.append(polygon) + if is_x_periodic: + # Fix cells that span the periodic boundaries + for count, vertex_index in enumerate(vertex_indices): + vertices = _fix_vertices(vertices, + loc_center=x_cell[cell_index] * 1e-3, + index=count, + period=x_period * 1e-3, + period_index=0) + if is_y_periodic: + # Fix cells that span the periodic boundaries + for count, vertex_index in enumerate(vertex_indices): + vertices = _fix_vertices(vertices, + loc_center=y_cell[cell_index] * 1e-3, + index=count, + period=y_period * 1e-3, + period_index=1) + polygon = Polygon(vertices, closed=True) + patches.append(polygon) return patches, mask @@ -321,13 +306,32 @@ def _compute_edge_patches(ds, mask): vertices_on_edge = ds.verticesOnEdge.values - 1 x_cell = ds.xCell.values y_cell = ds.yCell.values + x_edge = ds.xEdge.values + y_edge = ds.yEdge.values x_vertex = ds.xVertex.values y_vertex = ds.yVertex.values + boundary_vertex = ds.boundaryVertex.values + + is_periodic = ds.attrs['is_periodic'].strip() == 'YES' + is_x_periodic = False + is_y_periodic = False + if is_periodic: + x_period = ds.attrs['x_period'] + if x_period > 0.: + is_x_periodic = True + y_period = ds.attrs['y_period'] + if y_period > 0.: + is_y_periodic = True + for edge_index in range(ds.sizes['nEdges']): if not mask[edge_index]: continue cell_indices = cells_on_edge[edge_index] vertex_indices = vertices_on_edge[edge_index, :] + # Remove edges on boundaries because they are always invalid + if any(boundary_vertex[vertex_indices]): + mask[edge_index] = 0 + continue vertices = np.zeros((4, 2)) vertices[0, 0] = 1e-3 * x_vertex[vertex_indices[0]] vertices[0, 1] = 1e-3 * y_vertex[vertex_indices[0]] @@ -337,8 +341,46 @@ def _compute_edge_patches(ds, mask): vertices[2, 1] = 1e-3 * y_vertex[vertex_indices[1]] vertices[3, 0] = 1e-3 * x_cell[cell_indices[1]] vertices[3, 1] = 1e-3 * y_cell[cell_indices[1]] - + if is_x_periodic: + # Fix cells that span the periodic boundaries + for count, vertex_index in enumerate(vertex_indices): + new_index = np.where(count == 0, 0, 2) + vertices = _fix_kite(vertices, + loc_center=x_edge[edge_index] * 1e-3, + index=new_index, + period=x_period * 1e-3, + period_index=0) + if is_y_periodic: + # Fix cells that span the periodic boundaries + for count, vertex_index in enumerate(vertex_indices): + new_index = np.where(count == 0, 0, 2) + vertices = _fix_kite(vertices, + loc_center=y_edge[edge_index] * 1e-3, + index=new_index, + period=y_period * 1e-3, + period_index=1) polygon = Polygon(vertices, closed=True) patches.append(polygon) return patches, mask + + +def _fix_vertices(vertices, loc_center, index, period, period_index): + if vertices[index, period_index] - loc_center > 0.5 * period: + vertices[index, period_index] += -period + elif vertices[index, period_index] - loc_center < -0.5 * period: + vertices[index, period_index] += period + return vertices + + +def _fix_kite(vertices, loc_center, index, period, period_index): + if vertices[index, period_index] - loc_center > 0.5 * period: + vertices[index, period_index] += -period + elif vertices[index, period_index] - loc_center < -0.5 * period: + vertices[index, period_index] += period + # We need to check the cell node of the kite as well + if vertices[index + 1, period_index] - loc_center > 0.5 * period: + vertices[index + 1, period_index] += -period + elif vertices[index + 1, period_index] - loc_center < -0.5 * period: + vertices[index + 1, period_index] += period + return vertices