Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements for the secondary y axis functionality #723

Merged
merged 13 commits into from
Nov 29, 2023
2 changes: 1 addition & 1 deletion act/plotting/contourdisplay.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class ContourDisplay(Display):
"""

def __init__(self, ds, subplot_shape=(1,), ds_name=None, **kwargs):
super().__init__(ds, subplot_shape, ds_name, **kwargs)
super().__init__(ds, subplot_shape, ds_name, secondary_y_allowed=False, **kwargs)

def create_contour(
self,
Expand Down
133 changes: 66 additions & 67 deletions act/plotting/distributiondisplay.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class DistributionDisplay(Display):
"""

def __init__(self, ds, subplot_shape=(1,), ds_name=None, **kwargs):
super().__init__(ds, subplot_shape, ds_name, **kwargs)
super().__init__(ds, subplot_shape, ds_name, secondary_y_allowed=True, **kwargs)

def set_xrng(self, xrng, subplot_index=(0,)):
"""
Expand All @@ -55,7 +55,7 @@ def set_xrng(self, xrng, subplot_index=(0,)):
elif not hasattr(self, 'xrng') and len(self.axes.shape) == 1:
self.xrng = np.zeros((self.axes.shape[0], 2), dtype='datetime64[D]')

self.axes[subplot_index].set_xlim(xrng)
self.axes[subplot_index][0].set_xlim(xrng)
self.xrng[subplot_index, :] = np.array(xrng)

def set_yrng(self, yrng, subplot_index=(0,)):
Expand All @@ -81,7 +81,7 @@ def set_yrng(self, yrng, subplot_index=(0,)):
if yrng[0] == yrng[1]:
yrng[1] = yrng[1] + 1

self.axes[subplot_index].set_ylim(yrng)
self.axes[subplot_index][0].set_ylim(yrng)
self.yrng[subplot_index, :] = yrng

def _get_data(self, dsname, fields):
Expand Down Expand Up @@ -163,13 +163,13 @@ def plot_stacked_bar_graph(
# We will defaut the y direction to have the same # of bins as x
sortby_bins = np.linspace(ydata.values.min(), ydata.values.max(), len(bins))

# Get the current plotting axis, add day/night background and plot data
# Get the current plotting axis
if self.fig is None:
self.fig = plt.figure()

if self.axes is None:
self.axes = np.array([plt.axes()])
self.fig.add_axes(self.axes[0])
self.axes = np.array([[plt.axes(), plt.axes().twinx()]])
for a in self.axes[0]:
self.fig.add_axes(a)

if sortby_field is not None:
if 'units' in ydata.attrs:
Expand All @@ -189,26 +189,26 @@ def plot_stacked_bar_graph(
bins=[bins, sortby_bins],
**hist_kwargs)
x_inds = (x_bins[:-1] + x_bins[1:]) / 2.0
self.axes[subplot_index].bar(
self.axes[subplot_index][0].bar(
x_inds,
my_hist[:, 0].flatten(),
label=(str(y_bins[0]) + ' to ' + str(y_bins[1])),
**kwargs,
)
for i in range(1, len(y_bins) - 1):
self.axes[subplot_index].bar(
self.axes[subplot_index][0].bar(
x_inds,
my_hist[:, i].flatten(),
bottom=my_hist[:, i - 1],
label=(str(y_bins[i]) + ' to ' + str(y_bins[i + 1])),
**kwargs,
)
self.axes[subplot_index].legend()
self.axes[subplot_index][0].legend()
else:
my_hist, bins = np.histogram(xdata.values.flatten(), bins=bins,
density=density, **hist_kwargs)
x_inds = (bins[:-1] + bins[1:]) / 2.0
self.axes[subplot_index].bar(x_inds, my_hist)
self.axes[subplot_index][0].bar(x_inds, my_hist)

# Set Title
if set_title is None:
Expand All @@ -220,9 +220,9 @@ def plot_stacked_bar_graph(
dt_utils.numpy_to_arm_date(self._ds[dsname].time.values[0]),
]
)
self.axes[subplot_index].set_title(set_title)
self.axes[subplot_index].set_ylabel('count')
self.axes[subplot_index].set_xlabel(xtitle)
self.axes[subplot_index][0].set_title(set_title)
self.axes[subplot_index][0].set_ylabel('count')
self.axes[subplot_index][0].set_xlabel(xtitle)

return_dict = {}
return_dict['plot_handle'] = self.axes[subplot_index]
Expand Down Expand Up @@ -306,13 +306,13 @@ def plot_size_distribution(
+ 'length is equal to the field length!'
)

# Get the current plotting axis, add day/night background and plot data
# Get the current plotting axis
if self.fig is None:
self.fig = plt.figure()

if self.axes is None:
self.axes = np.array([plt.axes()])
self.fig.add_axes(self.axes[0])
self.axes = np.array([[plt.axes(), plt.axes().twinx()]])
for a in self.axes[0]:
self.fig.add_axes(a)

# Set Title
if set_title is None:
Expand All @@ -327,10 +327,10 @@ def plot_size_distribution(
if time is not None:
t = pd.Timestamp(time)
set_title += ''.join([' at ', ':'.join([str(t.hour), str(t.minute), str(t.second)])])
self.axes[subplot_index].set_title(set_title)
self.axes[subplot_index].step(bins.values, xdata.values, **kwargs)
self.axes[subplot_index].set_xlabel(xtitle)
self.axes[subplot_index].set_ylabel(ytitle)
self.axes[subplot_index][0].set_title(set_title)
self.axes[subplot_index][0].step(bins.values, xdata.values, **kwargs)
self.axes[subplot_index][0].set_xlabel(xtitle)
self.axes[subplot_index][0].set_ylabel(ytitle)

return self.axes[subplot_index]

Expand Down Expand Up @@ -412,8 +412,9 @@ def plot_stairstep_graph(
self.fig = plt.figure()

if self.axes is None:
self.axes = np.array([plt.axes()])
self.fig.add_axes(self.axes[0])
self.axes = np.array([[plt.axes(), plt.axes().twinx()]])
for a in self.axes[0]:
self.fig.add_axes(a)

if sortby_field is not None:
if 'units' in ydata.attrs:
Expand All @@ -433,26 +434,26 @@ def plot_stairstep_graph(
**hist_kwargs
)
x_inds = (x_bins[:-1] + x_bins[1:]) / 2.0
self.axes[subplot_index].step(
self.axes[subplot_index][0].step(
x_inds,
my_hist[:, 0].flatten(),
label=(str(y_bins[0]) + ' to ' + str(y_bins[1])),
**kwargs,
)
for i in range(1, len(y_bins) - 1):
self.axes[subplot_index].step(
self.axes[subplot_index][0].step(
x_inds,
my_hist[:, i].flatten(),
label=(str(y_bins[i]) + ' to ' + str(y_bins[i + 1])),
**kwargs,
)
self.axes[subplot_index].legend()
self.axes[subplot_index][0].legend()
else:
my_hist, bins = np.histogram(xdata.values.flatten(), bins=bins,
density=density, **hist_kwargs)

x_inds = (bins[:-1] + bins[1:]) / 2.0
self.axes[subplot_index].step(x_inds, my_hist, **kwargs)
self.axes[subplot_index][0].step(x_inds, my_hist, **kwargs)

# Set Title
if set_title is None:
Expand All @@ -464,9 +465,9 @@ def plot_stairstep_graph(
dt_utils.numpy_to_arm_date(self._ds[dsname].time.values[0]),
]
)
self.axes[subplot_index].set_title(set_title)
self.axes[subplot_index].set_ylabel('count')
self.axes[subplot_index].set_xlabel(xtitle)
self.axes[subplot_index][0].set_title(set_title)
self.axes[subplot_index][0].set_ylabel('count')
self.axes[subplot_index][0].set_xlabel(xtitle)

return_dict = {}
return_dict['plot_handle'] = self.axes[subplot_index]
Expand Down Expand Up @@ -568,10 +569,10 @@ def plot_heatmap(
# Get the current plotting axis, add day/night background and plot data
if self.fig is None:
self.fig = plt.figure()

if self.axes is None:
self.axes = np.array([plt.axes()])
self.fig.add_axes(self.axes[0])
self.axes = np.array([[plt.axes(), plt.axes().twinx()]])
for a in self.axes[0]:
self.fig.add_axes(a)

if 'units' in ydata.attrs:
ytitle = ''.join(['(', ydata.attrs['units'], ')'])
Expand All @@ -597,7 +598,7 @@ def plot_heatmap(
x_inds = (x_bins[:-1] + x_bins[1:]) / 2.0
y_inds = (y_bins[:-1] + y_bins[1:]) / 2.0
xi, yi = np.meshgrid(x_inds, y_inds, indexing='ij')
mesh = self.axes[subplot_index].pcolormesh(xi, yi, my_hist, shading=set_shading, **kwargs)
mesh = self.axes[subplot_index][0].pcolormesh(xi, yi, my_hist, shading=set_shading, **kwargs)

# Set Title
if set_title is None:
Expand All @@ -608,13 +609,13 @@ def plot_heatmap(
dt_utils.numpy_to_arm_date(self._ds[dsname].time.values[0]),
]
)
self.axes[subplot_index].set_title(set_title)
self.axes[subplot_index].set_ylabel(ytitle)
self.axes[subplot_index].set_xlabel(xtitle)
self.axes[subplot_index][0].set_title(set_title)
self.axes[subplot_index][0].set_ylabel(ytitle)
self.axes[subplot_index][0].set_xlabel(xtitle)
self.add_colorbar(mesh, title='count', subplot_index=subplot_index)

return_dict = {}
return_dict['plot_handle'] = self.axes[subplot_index]
return_dict['plot_handle'] = self.axes[subplot_index][0]
return_dict['x_bins'] = x_bins
return_dict['y_bins'] = y_bins
return_dict['histogram'] = my_hist
Expand All @@ -634,9 +635,9 @@ def set_ratio_line(self, subplot_index=(0, )):
if self.axes is None:
raise RuntimeError('set_ratio_line requires the plot to be displayed.')
# Define the xticks of the figure
xlims = self.axes[subplot_index].get_xticks()
xlims = self.axes[subplot_index][0].get_xticks()
ratio = np.linspace(xlims[0], xlims[-1])
self.axes[subplot_index].plot(ratio, ratio, 'k--')
self.axes[subplot_index][0].plot(ratio, ratio, 'k--')

def plot_scatter(self,
x_field,
Expand Down Expand Up @@ -713,15 +714,12 @@ def plot_scatter(self,

# Define the axes for the figure
if self.axes is None:
self.axes = np.array([plt.axes()])
self.fig.add_axes(self.axes[0])
self.axes = np.array([[plt.axes(), plt.axes().twinx()]])
for a in self.axes[0]:
self.fig.add_axes(a)

# Display the scatter plot, pass keyword args for unspecified attributes
scc = self.axes[subplot_index].scatter(xdata,
ydata,
c=mdata,
**kwargs
)
scc = self.axes[subplot_index][0].scatter(xdata, ydata, c=mdata, **kwargs)

# Set Title
if set_title is None:
Expand All @@ -748,9 +746,9 @@ def plot_scatter(self,
cbar.ax.set_ylabel(ztitle)

# Define the axe title, x-axis label, y-axis label
self.axes[subplot_index].set_title(set_title)
self.axes[subplot_index].set_ylabel(ytitle)
self.axes[subplot_index].set_xlabel(xtitle)
self.axes[subplot_index][0].set_title(set_title)
self.axes[subplot_index][0].set_ylabel(ytitle)
self.axes[subplot_index][0].set_xlabel(xtitle)

return self.axes[subplot_index]

Expand Down Expand Up @@ -818,8 +816,9 @@ def plot_violin(self,

# Define the axes for the figure
if self.axes is None:
self.axes = np.array([plt.axes()])
self.fig.add_axes(self.axes[0])
self.axes = np.array([[plt.axes(), plt.axes().twinx()]])
for a in self.axes[0]:
self.fig.add_axes(a)

# Define the axe label. If units are avaiable, plot.
if 'units' in ndata.attrs:
Expand All @@ -828,14 +827,14 @@ def plot_violin(self,
axtitle = field

# Display the scatter plot, pass keyword args for unspecified attributes
scc = self.axes[subplot_index].violinplot(ndata,
positions=positions,
vert=vert,
showmeans=showmeans,
showmedians=showmedians,
showextrema=showextrema,
**kwargs
)
scc = self.axes[subplot_index][0].violinplot(ndata,
positions=positions,
vert=vert,
showmeans=showmeans,
showmedians=showmedians,
showextrema=showextrema,
**kwargs
)
if showmeans is True:
scc['cmeans'].set_edgecolor('red')
scc['cmeans'].set_label('mean')
Expand All @@ -853,14 +852,14 @@ def plot_violin(self,
)

# Define the axe title, x-axis label, y-axis label
self.axes[subplot_index].set_title(set_title)
self.axes[subplot_index][0].set_title(set_title)
if vert is True:
self.axes[subplot_index].set_ylabel(axtitle)
self.axes[subplot_index][0].set_ylabel(axtitle)
if positions is None:
self.axes[subplot_index].set_xticks([])
self.axes[subplot_index][0].set_xticks([])
else:
self.axes[subplot_index].set_xlabel(axtitle)
self.axes[subplot_index][0].set_xlabel(axtitle)
if positions is None:
self.axes[subplot_index].set_yticks([])
self.axes[subplot_index][0].set_yticks([])

return self.axes[subplot_index]
return self.axes[subplot_index][0]
2 changes: 1 addition & 1 deletion act/plotting/geodisplay.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, ds, ds_name=None, **kwargs):
raise ImportError(
'Cartopy needs to be installed on your ' 'system to make geographic display plots.'
)
super().__init__(ds, ds_name, **kwargs)
super().__init__(ds, ds_name, secondary_y_allowed=False, **kwargs)
if self.fig is None:
self.fig = plt.figure(**kwargs)

Expand Down
Loading