Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimise gridsearch #6

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion ProFSea-environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ dependencies:
- matplotlib
- libwebp>=1.3.2
- cartopy
- iris
- iris
- tqdm
52 changes: 50 additions & 2 deletions profsea/slr_pkg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import cartopy.crs as ccrs

from profsea.config import settings
from profsea.slr_pkg import cmip, cubeplot, cubeutils, cubedata, process
from profsea.slr_pkg import cmip, cubeplot, cubeutils, cubedata, process, whichbox
from profsea.directories import makefolder, read_dir


Expand Down Expand Up @@ -244,6 +244,7 @@ def plot_ij(cube, model, location, idx, lat, lon, save_map=True, rad=5):
if targetlon > 180:
targetlon -= 360

fig = plt.figure()
ax = cubeplot.block(cube, land=False, region=region, cmin=-1, cmax=1,
plotcbar=True, nlevels=25, cent_lon=targetlon,
title='{} (1x1 grid) - SSH above geoid'.format(model))
Expand All @@ -265,7 +266,7 @@ def plot_ij(cube, model, location, idx, lat, lon, save_map=True, rad=5):
SRC_CRS, orig_lons, orig_lats).T

# Plot symbols showing the ocean point and the tide gauge
ax.plot(new_lons[0], new_lats[0], 'ok')
pred, = ax.plot(new_lons[0], new_lats[0], 'ok')
ax.plot(new_lons[1], new_lats[1], 'xr')

if save_map:
Expand All @@ -282,7 +283,54 @@ def plot_ij(cube, model, location, idx, lat, lon, save_map=True, rad=5):
plt.savefig(figfile)
plt.close()
else:
selected_lat = cube.coord('latitude').points[j]
selected_lon = cube.coord('longitude').points[i]

def onclick(event):
nonlocal selected_lat, selected_lon, pred
nonlocal i, j
selected_lon, selected_lat = event.xdata, event.ydata

# Transform onto original projection
MAP_CRS = ccrs.PlateCarree(central_longitude=targetlon)
SRC_CRS = ccrs.PlateCarree()

selected_lon, selected_lat = SRC_CRS.transform_point(
selected_lon, selected_lat, MAP_CRS)

(i, j), = whichbox.find_gridbox_indicies(cube,[(selected_lon, selected_lat)])

selected_lon = cube.coord('longitude').points[i]
selected_lat = cube.coord('latitude').points[j]

MAP_CRS = ccrs.PlateCarree(central_longitude=targetlon)
SRC_CRS = ccrs.PlateCarree()

plot_lon, plot_lat = MAP_CRS.transform_point(
selected_lon, selected_lat, SRC_CRS)

pred.remove()
pred, = ax.plot(plot_lon, plot_lat, 'ok')
fig.canvas.draw()

cid = fig.canvas.mpl_connect('button_press_event', onclick)
plt.show()


# Create the output file directory location
out_mapdir = read_dir()[1]
makefolder(out_mapdir)

# Abbreviate the site location name suitable to use as a filename
loc_abbrev = abbreviate_location_name(location)
figfile = os.path.join(out_mapdir,
f'{loc_abbrev}_{model}_ij_figure.png')

# Save the CMIP grid box selection map to file
fig.savefig(figfile)
plt.close()

return i, j, selected_lon, selected_lat


def read_ar5_component(datadir, rcp, var, value='mid'):
Expand Down
257 changes: 120 additions & 137 deletions profsea/step1_extract_cmip.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,68 +6,18 @@
import os
import numpy as np
import pandas as pd
import warnings
from tqdm import tqdm

from scipy.spatial import cKDTree

from profsea.config import settings
from profsea.directories import read_dir, makefolder
from profsea.slr_pkg import abbreviate_location_name, plot_ij # found in __init.py__
from profsea.slr_pkg import cmip, cubeutils, models, whichbox
from profsea.tide_gauge_locations import extract_site_info


def accept_reject_cmip(cube, model, site_loc, cmip_i, cmip_j, site_lat,
site_lon, unit_test=False):
"""
Accept or reject selected CMIP grid box based on a user input.
If CMIP grid box is rejected, search neighbouring grid boxes until a
suitable one is found.
:param cube: cube containing zos field from CMIP models
:param model: CMIP model name
:param site_loc: name of the site location
:param cmip_i: CMIP coord of site location's latitude
:param cmip_j: CMIP coord of site location's longitude
:param site_lat: latitude of the site location
:param site_lon: longitude of the site location
:param unit_test: flag to disable plotting for unit testing purposes
:return: Selected CMIP coords or None if user doesn't confirm grid box
selection
"""
# Plot a map of the site location and selected CMIP grid box for user
# validation. At this stage don't save the figure to file.
if not unit_test:
plot_ij(cube, model, site_loc, [cmip_i, cmip_j], site_lat, site_lon,
save_map=False)

# Ask user to accept cmip grid box or re-select from neighbouring cells
decision = input(f'Is selected CMIP grid box {cmip_i, cmip_j} '
f'appropriate for {model}? Y or N: ')
if decision == 'Y' or decision == 'y':
# Save map to file and return CMIP grid box coords
if not unit_test:
plot_ij(cube, model, site_loc, [cmip_i, cmip_j],
site_lat, site_lon)
return cmip_i, cmip_j
elif decision == 'N' or decision == 'n':
print('Selecting another CMIP grid box')
return None, None
else:
raise TypeError('Response needs to be Y or N')


def calc_radius_range(radius):
"""
Calculate the maximum distance to search for ocean grid point.
:param radius: Maximum range to search for ocean point
:return: x_radius_range, y_radius_range
"""
# if radius > 1:
# rm1 = radius - 1
# else:
# rm1 = radius

x_radius_range = list(range(-radius, radius + 1))
y_radius_range = list(range(-radius, radius + 1))

return x_radius_range, y_radius_range
warnings.filterwarnings("ignore")


def check_cube_mask(cube):
Expand All @@ -91,7 +41,6 @@ def check_cube_mask(cube):

if apply_mask:
new_mask = (cube.data == 0.0)
print(f'Scalar mask: Re-masking the cube to mask cells = 0')
cube = cube.copy(data=np.ma.array(cube.data, mask=new_mask))

return cube
Expand Down Expand Up @@ -163,44 +112,121 @@ def find_ocean_pt(zos_cube_in, model, site_loc, site_lat, site_lon):
:return: model grid box indices
"""
# Find grid box indices of location
(i, j), = whichbox.find_gridbox_indicies(zos_cube_in,
[(site_lon, site_lat)])
grid_lons = zos_cube_in.coord('longitude').points
grid_lats = zos_cube_in.coord('latitude').points

# Check to see if the cube has a scalar mask, and add mask where cmip
zos_cube = check_cube_mask(zos_cube_in)

# If the CMIP grid box of the exact site location is an ocean point
# Get the user to check if it's appropriate and if so return the indices
if not zos_cube.data.mask[j, i]:
print('Checking CMIP grid box at site location')
i_out, j_out = accept_reject_cmip(zos_cube, model, site_loc, i, j,
site_lat, site_lon)
if i_out is not None:
pt_lon = grid_lons[i_out]
pt_lat = grid_lats[j_out]
return i_out, j_out, pt_lon, pt_lat

# If no indices are returned then the CMIP grid box is not appropriate
# for use. Check the CMIP grid boxes surrounding the site location
# until an appropriate one is found.
else:
i_out, j_out, pt_lon, pt_lat = search_for_next_cmip(i, j,
zos_cube,
model,
site_loc,
site_lat,
site_lon)

# If the CMIP grid box of the exact site location is masked, start by
# checking the next set of cmip grid boxes

if settings["auto_site_selection"]:
search_distance = 2
else:
i_out, j_out, pt_lon, pt_lat = search_for_next_cmip(i, j, zos_cube,
model, site_loc,
site_lat, site_lon)

return i_out, j_out, pt_lon, pt_lat
search_distance = 7
best_i, best_j, best_lon, best_lat = find_best_gridcell(
zos_cube, site_lat, site_lon, max_distance=search_distance)

if settings["auto_site_selection"]:
plot_ij(zos_cube, model, site_loc, [best_i, best_j],
site_lat, site_lon, save_map=True)
return best_i, best_j, best_lon, best_lat

best_i, best_j, best_lon, best_lat = plot_ij(
zos_cube, model, site_loc, [best_i, best_j],
site_lat, site_lon, save_map=False)

return best_i, best_j, best_lon, best_lat


def find_best_gridcell(
cube, target_lat, target_lon,
max_distance=2, distance_weight=3,
difference_weight=0.003):
"""
Find the best grid cell in the CMIP model for the target latitude and
longitude. The best grid cell is the one that minimizes the weighted
score, which is computed as the distance to the target point minus a
grid cell difference parameter.
:param cube: iris.cube.Cube containing zos field from CMIP models
:param target_lat: Latitude of the target point
:param target_lon: Longitude of the target point
:param max_distance: Maximum distance to search for the best grid cell
:param distance_weight: Weight for the distance parameter
:param difference_weight: Weight for the difference parameter
:return: Best grid cell indices and coordinates
"""
lon_grid = cube.coord('longitude').points
lat_grid = cube.coord('latitude').points
data_grid = cube.data

# Get lat/lons onto grids, flatten and get mask
lat_mesh, lon_mesh = np.meshgrid(lat_grid, lon_grid, indexing='ij')
points = np.vstack([lat_mesh.ravel(), lon_mesh.ravel()]).T
masked_points = data_grid.mask.ravel()

tree = cKDTree(points[~masked_points])
dists, indices = tree.query([target_lat, target_lon], k=49) # 7x7 grid

best_i = None
best_j = None
best_lon = None
best_lat = None
min_weighted_score = float('inf')

def compute_weighted_score(lat_idx, lon_idx, dist):
value = data_grid[lat_idx, lon_idx]

# Pull out surrounding grid cell indices and values
surrounding_points = tree.query_ball_point(
[lat_mesh[lat_idx, lon_idx], lon_mesh[lat_idx, lon_idx]], 1)
surrounding_indices = np.where(~masked_points)[0][surrounding_points]
surrounding_lat_idx, surrounding_lon_idx = np.unravel_index(
surrounding_indices, data_grid.shape)
surrounding_values = data_grid[surrounding_lat_idx,
surrounding_lon_idx]

# Compute difference parameter
avg_surrounding_diffs = np.mean(np.abs(np.diff(surrounding_values)))
difference = np.mean(np.abs(surrounding_values - value))
diff_param = abs(avg_surrounding_diffs - difference)

# Compute weighted score
weighted_score = float((dist / distance_weight ) -
(difference_weight / diff_param))
return weighted_score

def check_and_update_best(lat_idx, lon_idx, dist):
nonlocal best_i, best_j, best_lat, best_lon
nonlocal compute_weighted_score, min_weighted_score

weighted_score = compute_weighted_score(lat_idx, lon_idx, dist)

if weighted_score < min_weighted_score:
min_weighted_score = weighted_score
best_lat = lat_mesh[lat_idx, lon_idx]
best_lon = lon_mesh[lat_idx, lon_idx]
best_i = nearest_lon_idx
best_j = nearest_lat_idx

candidate_points = sorted(zip(dists, indices), key=lambda x: x[0])
for dist, idx in candidate_points:
if dist > max_distance:
break

flat_idx = np.where(~masked_points)[0][idx]
nearest_lat_idx, nearest_lon_idx = np.unravel_index(
flat_idx, data_grid.shape)

check_and_update_best(nearest_lat_idx, nearest_lon_idx, dist)

if best_i is None or best_j is None:
if settings["auto_site_selection"]:
raise ValueError("Could not find a suitable grid cell due to the "
"model mask. This region might be complex - "
"please re-run with "
"\033[1mauto_site_selection: False \033[0m")
else:
raise ValueError("No valid points in the vicinity of the site. "
"Have you selected a lat/lon near the coast?")

return best_i, best_j, best_lon, best_lat


def ocean_point_wrapper(df, model_names, cubes):
Expand All @@ -222,60 +248,17 @@ def ocean_point_wrapper(df, model_names, cubes):
# Setup empty 2D list to store results for each model
# [name, i and j coords, lat and lon value]
result = []
for n, zos_cube in enumerate(cubes):
model = model_names[n]
i, j, pt_lon, pt_lat = find_ocean_pt(zos_cube, model, site_loc,
print(f'\nExtracting grid cells for {site_loc}')
for i in tqdm(range(len(cubes))):
model = model_names[i]
i, j, pt_lon, pt_lat = find_ocean_pt(cubes[i], model, site_loc,
lat, lon)
result.append([model, i, j, pt_lon, pt_lat])

# Write the data to a file
write_i_j(site_loc, result, lat, lon_orig)


def search_for_next_cmip(cmip_i, cmip_j, cube, model, site_loc, site_lat,
site_lon, unit_test=False):
"""
Iteratively check the CMIP grid boxes surrounding the site location
until a suitable option is found.
:param cmip_i: CMIP coord of site location's latitude
:param cmip_j: CMIP coord of site location's longitude
:param cube: cube containing zos field from CMIP models
:param model: CMIP model name
:param site_loc: name of the site location
:param site_lat: latitude of the site location
:param site_lon: longitude of the site location
:param unit_test: flag to disable plotting for unit testing purposes
:return: Selected CMIP coords
"""
grid_lons = cube.coord('longitude').points
grid_lats = cube.coord('latitude').points

# The radius limit of 7 is arbitrary but should be large enough.
for radius in range(1, 8): # grid boxes
print(f'Checking CMIP grid boxes {radius} box removed ' +
'from site location')
x_radius_range, y_radius_range = calc_radius_range(radius)
for ix in x_radius_range:
for iy in y_radius_range:
# Search the nearest grid cells. If the new mask is False,
# that grid cell is an ocean point
limit_lo = radius * radius
dd = ix * ix + iy * iy
if dd >= limit_lo:
# modulus for when grid cell is close to 0deg.
i_try = (cmip_i + ix) % len(grid_lons)
j_try = cmip_j + iy

if not cube.data.mask[j_try, i_try]:
i_out, j_out = accept_reject_cmip(
cube, model, site_loc, i_try, j_try, site_lat,
site_lon, unit_test)
if i_out is not None:
pt_lon = grid_lons[i_out]
pt_lat = grid_lats[j_out]
return i_out, j_out, pt_lon, pt_lat


def write_i_j(site_loc, result, site_lat, lon_orig):
"""
Convert the grid indices to a data frame and writes to file.
Expand Down
Loading