Skip to content

Commit

Permalink
refactor two widgets code (#1148)
Browse files Browse the repository at this point in the history
  • Loading branch information
hainm authored Dec 5, 2024
1 parent 249467b commit f0a1152
Show file tree
Hide file tree
Showing 7 changed files with 429 additions and 291 deletions.
118 changes: 7 additions & 111 deletions nglview/widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from contextlib import contextmanager

import numpy as np
from IPython.display import display
import ipywidgets as widgets
from ipywidgets import (Image, Box, DOMWidget, HBox, IntSlider, Play, jslink)
from ipywidgets import embed
Expand All @@ -29,6 +28,7 @@
seq_to_string)
from .viewer_control import ViewerControl
from ._frontend import __frontend_version__
from .widget_base import WidgetBase

logger = getLogger(__name__)

Expand Down Expand Up @@ -126,16 +126,12 @@ def _unset_serialization(views):
f.write(html_code)


class NGLWidget(DOMWidget):
class NGLWidget(WidgetBase):
_view_name = Unicode("NGLView").tag(sync=True)
_view_module = Unicode("nglview-js-widgets").tag(sync=True)
_view_module_version = Unicode(__frontend_version__).tag(sync=True)
_model_name = Unicode("NGLModel").tag(sync=True)
_model_module = Unicode("nglview-js-widgets").tag(sync=True)
_model_module_version = Unicode(__frontend_version__).tag(sync=True)
_ngl_version = Unicode().tag(sync=True)

# View and model attributes
# View and model attributes
_image_data = Unicode().tag(sync=False)
_view_width = Unicode().tag(sync=True) # px
_view_height = Unicode().tag(sync=True) # px
Expand Down Expand Up @@ -187,7 +183,7 @@ def __init__(self,
representations=None,
parameters=None,
**kwargs):
super().__init__(**kwargs)
super().__init__(structure=structure, representations=representations, parameters=parameters, **kwargs)
self._initialize_attributes(kwargs)
self._initialize_threads()
self._initialize_components(structure, representations, parameters, kwargs)
Expand All @@ -204,7 +200,7 @@ def _initialize_attributes(self, kwargs):
self._image_array = []
# do not use _displayed_callbacks since there is another Widget._display_callbacks
self._event = threading.Event()
self._ngl_displayed_callbacks_before_loaded = []
self._callbacks_before_loaded = []
widget_utils._add_repr_method_shortcut(self, self)
self.shape = Shape(view=self)
self.stage = Stage(view=self)
Expand Down Expand Up @@ -383,47 +379,8 @@ def _update_background_color(self, change):
self.stage.set_parameters(background_color=color)

def handle_resize(self):
# self._remote_call("handleResize", target='Stage')
self._remote_call("handleResize")

def _update_max_frame(self):
self.max_frame = max(
int(traj.n_frames)
for traj in self._trajlist
if hasattr(traj, 'n_frames')) - 1 # index starts from 0

def _wait_until_finished(self, timeout=0.0001):
# NGL need to send 'finished' signal to
# backend
self._event.clear()
while True:
# idle to make room for waiting for
# "finished" event sent from JS
time.sleep(timeout)
if self._event.is_set():
# if event is set from another thread
# break while True
break

def _run_on_another_thread(self, func, *args):
# use `event` to singal
# func(*args)
thread = threading.Thread(
target=func,
args=args,
)
thread.daemon = True
thread.start()
return thread

@observe('loaded')
def on_loaded(self, change):
# trick for firefox on Linux
time.sleep(0.1)

if change['new']:
self._fire_callbacks(self._ngl_displayed_callbacks_before_loaded)

def _fire_callbacks(self, callbacks):

def _call(event):
Expand All @@ -432,14 +389,7 @@ def _call(event):
if callback._method_name == 'loadFile':
self._wait_until_finished()

self._run_on_another_thread(_call, self._event)

def _ipython_display_(self, **kwargs):
try:
# ipywidgets < 8
super()._ipython_display_(**kwargs)
except AttributeError:
display(super()._repr_mimebundle_(), raw=True)
self._thread_run(_call, self._event)

def display(self, gui=False, style='ngl'):
"""
Expand Down Expand Up @@ -639,60 +589,6 @@ def _display_repr(self, component=0, repr_index=0, name=None):

return RepresentationControl(self, component, repr_index, name=name)

def _set_coordinates(self, index, movie_making=False, render_params=None):
'''update coordinates for all trajectories at index-th frame'''
render_params = render_params or {}
if self._trajlist:
coordinates_dict = {}
for trajectory in self._trajlist:
traj_index = self._ngl_component_ids.index(trajectory.id)

try:
if trajectory.shown:
coordinates_dict[
traj_index] = trajectory.get_coordinates(index)
else:
coordinates_dict[traj_index] = np.empty((0), dtype='f4')
except (IndexError, ValueError):
coordinates_dict[traj_index] = np.empty((0), dtype='f4')

self.set_coordinates(coordinates_dict,
render_params=render_params,
movie_making=movie_making)
else:
print("no trajectory available")

def set_coordinates(self, arr_dict, movie_making=False, render_params=None):
# type: (Dict[int, np.ndarray]) -> None
"""Used for update coordinates of a given trajectory
>>> # arr: numpy array, ndim=2
>>> # update coordinates of 1st trajectory
>>> view.set_coordinates({0: arr})# doctest: +SKIP
"""
render_params = render_params or {}
self._coordinates_dict = arr_dict

buffers = []
coordinates_meta = dict()
for index, arr in self._coordinates_dict.items():
buffers.append(arr.astype('f4').tobytes())
coordinates_meta[index] = index
msg = {
'type': 'binary_single',
'data': coordinates_meta,
}
if movie_making:
msg['movie_making'] = movie_making
msg['render_params'] = render_params

self.send(msg, buffers=buffers)

@observe('frame')
def _on_frame_changed(self, change):
"""set and send coordinates at current frame
"""
self._set_coordinates(change['new'])

def clear(self, *args, **kwargs):
'''shortcut of `clear_representations`
'''
Expand Down Expand Up @@ -1319,7 +1215,7 @@ def callback(widget, msg=msg):
else:
# send later
# all callbacks will be called right after widget is loaded
self._ngl_displayed_callbacks_before_loaded.append(callback)
self._callbacks_before_loaded.append(callback)

if callback._method_name not in _EXCLUDED_CALLBACK_AFTER_FIRING and \
(not other_kwargs.get("fire_once", False)):
Expand Down
158 changes: 158 additions & 0 deletions nglview/widget_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import threading
import time
import ipywidgets as widgets
from traitlets import Bool, Integer, observe, Unicode
from .remote_thread import RemoteCallThread
from IPython.display import display
import numpy as np

from ._frontend import __frontend_version__

class WidgetBase(widgets.DOMWidget):
_view_module = Unicode('nglview-js-widgets').tag(sync=True)
_model_module = Unicode('nglview-js-widgets').tag(sync=True)
_view_module_version = Unicode(__frontend_version__).tag(sync=True)
_model_module_version = Unicode(__frontend_version__).tag(sync=True)

frame = Integer().tag(sync=True)
loaded = Bool(False).tag(sync=False)
_component_ids = []
_trajlist = []
_callbacks_before_loaded = []
_event = threading.Event()

def __init__(self, **kwargs):
# Extract recognized arguments
recognized_kwargs = {k: v for k, v in kwargs.items() if k in self.trait_names()}
super().__init__(**recognized_kwargs)
self._initialize_threads()

def _initialize_threads(self):
self._remote_call_thread = RemoteCallThread(self, registered_funcs=[])
self._remote_call_thread.daemon = True
self._remote_call_thread.start()
self._handle_msg_thread = threading.Thread(target=self.on_msg, args=(self._handle_nglview_custom_message,))
self._handle_msg_thread.daemon = True
self._handle_msg_thread.start()

def _handle_nglview_custom_message(self, widget, msg, buffers):
raise NotImplementedError()

def render_image(self):
image = widgets.Image()
self._js(f"this.exportImage('{image.model_id}')")
return image

def handle_resize(self):
self._js("this.plugin.handleResize()")

@observe('loaded')
def on_loaded(self, change):
# trick for firefox on Linux
time.sleep(0.1)
if change['new']:
self._fire_callbacks(self._callbacks_before_loaded)

def _thread_run(self, func, *args):
thread = threading.Thread(target=func, args=args)
thread.daemon = True
thread.start()
return thread

def _fire_callbacks(self, callbacks):
def _call(event):
for callback in callbacks:
callback(self)
self._thread_run(_call, self._event)

def _update_max_frame(self):
self.max_frame = max(
int(traj.n_frames) for traj in self._trajlist
if hasattr(traj, 'n_frames')) - 1 # index starts from 0

def _wait_until_finished(self, timeout=0.0001):
self._event.clear()
while True:
# idle to make room for waiting for
# "finished" event sent from JS
time.sleep(timeout)
if self._event.is_set():
# if event is set from another thread
# break while True
break

def _js(self, code, **kwargs):
self._remote_call('executeCode', target='Widget', args=[code], **kwargs)

def _remote_call(self, method_name, target='Widget', args=None, kwargs=None, **other_kwargs):
msg = self._get_remote_call_msg(method_name, target=target, args=args, kwargs=kwargs, **other_kwargs)
def callback(widget, msg=msg):
widget.send(msg)
callback._method_name = method_name
callback._msg = msg
if self.loaded:
self._remote_call_thread.q.append(callback)
else:
self._callbacks_before_loaded.append(callback)

def _get_remote_call_msg(self, method_name, target='Widget', args=None, kwargs=None, **other_kwargs):
msg = {'target': target, 'type': 'call_method', 'methodName': method_name, 'args': args, 'kwargs': kwargs}
msg.update(other_kwargs)
return msg

def _set_coordinates(self, index, movie_making=False, render_params=None):
'''update coordinates for all trajectories at index-th frame'''
render_params = render_params or {}
if self._trajlist:
coordinates_dict = {}
for trajectory in self._trajlist:
traj_index = self._ngl_component_ids.index(trajectory.id)

try:
if trajectory.shown:
coordinates_dict[traj_index] = trajectory.get_coordinates(index)
else:
coordinates_dict[traj_index] = np.empty((0), dtype='f4')
except (IndexError, ValueError):
coordinates_dict[traj_index] = np.empty((0), dtype='f4')

self.set_coordinates(coordinates_dict,
render_params=render_params,
movie_making=movie_making)
else:
print("no trajectory available")

def set_coordinates(self, arr_dict, movie_making=False, render_params=None):
"""Used for update coordinates of a given trajectory
>>> # arr: numpy array, ndim=2
>>> # update coordinates of 1st trajectory
>>> view.set_coordinates({0: arr})# doctest: +SKIP
"""
render_params = render_params or {}
self._coordinates_dict = arr_dict

buffers = []
coordinates_meta = dict()
for index, arr in self._coordinates_dict.items():
buffers.append(arr.astype('f4').tobytes())
coordinates_meta[index] = index
msg = {
'type': 'binary_single',
'data': coordinates_meta,
}
if movie_making:
msg['movie_making'] = movie_making
msg['render_params'] = render_params

self.send(msg, buffers=buffers)

@observe('frame')
def _on_frame_changed(self, change):
self._set_coordinates(change['new'])

def _ipython_display_(self, **kwargs):
try:
# ipywidgets < 8
super()._ipython_display_(**kwargs)
except AttributeError:
display(super()._repr_mimebundle_(), raw=True)
Loading

0 comments on commit f0a1152

Please sign in to comment.