Skip to content

Commit

Permalink
Merge pull request #109 from Carifio24/update-masks
Browse files Browse the repository at this point in the history
Update masks for scatter layers
  • Loading branch information
Carifio24 authored Dec 1, 2024
2 parents 43dbf66 + d537bc7 commit fb4b691
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 15 deletions.
5 changes: 1 addition & 4 deletions glue_plotly/common/base_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,7 @@ def bbox_mask(viewer_state, x, y, z):
(z >= viewer_state.z_min) & (z <= viewer_state.z_max)


def clipped_data(viewer_state, layer_state):
x = layer_state.layer[viewer_state.x_att]
y = layer_state.layer[viewer_state.y_att]
z = layer_state.layer[viewer_state.z_att]
def clipped_data(viewer_state, x, y, z):

# Plotly doesn't show anything outside the bounding box
mask = bbox_mask(viewer_state, x, y, z)
Expand Down
3 changes: 2 additions & 1 deletion glue_plotly/common/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def rgb_colors(layer_state, mask, cmap_att):
return rgba_strs


def color_info(layer_state, mask=None,
def color_info(layer_state,
mask=None,
mode_att="cmap_mode",
cmap_att="cmap_att"):
if getattr(layer_state, mode_att, "Fixed") == "Fixed":
Expand Down
11 changes: 10 additions & 1 deletion glue_plotly/common/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,16 @@ def traces_for_nonpixel_subset_layer(viewer_state, layer_state, full_view, trans
def traces_for_scatter_layer(viewer_state, layer_state, hover_data=None, add_data_label=True):
x = layer_state.layer[viewer_state.x_att].copy()
y = layer_state.layer[viewer_state.y_att].copy()
mask, (x, y) = sanitize(x, y)
arrs = [x, y]
if layer_state.cmap_mode == "Linear":
cvals = layer_state.layer[layer_state.cmap_att].copy()
arrs.append(cvals)
if layer_state.size_mode == "Linear":
svals = layer_state.layer[layer_state.size_att].copy()
arrs.append(svals)

mask, sanitized = sanitize(*arrs)
x, y = sanitized[:2]

marker = dict(color=color_info(layer_state),
opacity=layer_state.alpha,
Expand Down
11 changes: 10 additions & 1 deletion glue_plotly/common/scatter2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,16 @@ def trace_data_for_layer(viewer, layer_state, hover_data=None, add_data_label=Tr

x = layer_state.layer[viewer.state.x_att].copy()
y = layer_state.layer[viewer.state.y_att].copy()
mask, (x, y) = sanitize(x, y)
arrs = [x, y]
if layer_state.cmap_mode == "Linear":
cvals = layer_state.layer[layer_state.cmap_att].copy()
arrs.append(cvals)
if layer_state.size_mode == "Linear":
svals = layer_state.layer[layer_state.size_att].copy()
arrs.append(svals)

mask, sanitized = sanitize(*arrs)
x, y = sanitized[:2]

legend_group = uuid4().hex

Expand Down
41 changes: 33 additions & 8 deletions glue_plotly/common/scatter3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,24 @@
from plotly.graph_objs import Cone, Scatter3d
from uuid import uuid4

from glue_plotly.common import color_info
from glue_plotly.common.base_3d import clipped_data
from glue_plotly.common import color_info, sanitize
from glue_plotly.common.base_3d import bbox_mask

try:
from glue_vispy_viewers.scatter.layer_state import ScatterLayerState
except ImportError:
ScatterLayerState = type(None)

def size_info(layer_state, mask):

def size_info(layer_state, mask, size_att="size_attribute"):

# set all points to be the same size, with set scaling
if layer_state.size_mode == 'Fixed':
return layer_state.size_scaling * layer_state.size

# scale size of points by set size scaling
else:
s = ensure_numerical(layer_state.layer[layer_state.size_attribute][mask].ravel())
s = ensure_numerical(layer_state.layer[getattr(layer_state, size_att)][mask].ravel())
s = ((s - layer_state.size_vmin) /
(layer_state.size_vmax - layer_state.size_vmin))
# The following ensures that the sizes are in the
Expand Down Expand Up @@ -110,11 +115,31 @@ def symbol_for_geometry(geometry: str) -> str:

def traces_for_layer(viewer_state, layer_state, hover_data=None, add_data_label=True):

x, y, z, mask = clipped_data(viewer_state, layer_state)
x = layer_state.layer[viewer_state.x_att]
y = layer_state.layer[viewer_state.y_att]
z = layer_state.layer[viewer_state.z_att]

vispy_layer_state = isinstance(layer_state, ScatterLayerState)
cmap_mode_attr = "color_mode" if vispy_layer_state else "cmap_mode"
cmap_attr = "cmap_attribute" if vispy_layer_state else "cmap_att"
size_attr = "size_attribute" if vispy_layer_state else "size_att"
arrs = [x, y, z]
if getattr(layer_state, cmap_mode_attr) == "Linear":
cvals = layer_state.layer[getattr(layer_state, cmap_attr)].copy()
arrs.append(cvals)
if layer_state.size_mode == "Linear":
svals = layer_state.layer[getattr(layer_state, size_attr)].copy()
arrs.append(svals)

mask, _ = sanitize(*arrs)
bounds_mask = bbox_mask(viewer_state, x, y, z)
mask &= bounds_mask
x, y, z = x[mask], y[mask], z[mask]

marker = dict(color=color_info(layer_state, mask=mask,
mode_att="color_mode",
cmap_att="cmap_attribute"),
size=size_info(layer_state, mask),
mode_att=cmap_mode_attr,
cmap_att=cmap_attr),
size=size_info(layer_state, mask, size_att=size_attr),
opacity=layer_state.alpha,
line=dict(width=0))

Expand Down

0 comments on commit fb4b691

Please sign in to comment.