From f0a11520a71a5453a132d4a17cbccadaf509c657 Mon Sep 17 00:00:00 2001 From: Hai Nguyen Date: Wed, 4 Dec 2024 22:25:14 -0500 Subject: [PATCH] refactor two widgets code (#1148) --- nglview/widget.py | 118 +--------------- nglview/widget_base.py | 158 +++++++++++++++++++++ nglview/widget_molstar.py | 226 ++++++++---------------------- notebooks/molstarview.ipynb | 36 ++++- notebooks/trajectory_player.ipynb | 70 ++++++++- tests/test_molstarview.py | 110 +++++++++++++++ tests/test_widget.py | 2 +- 7 files changed, 429 insertions(+), 291 deletions(-) create mode 100644 nglview/widget_base.py create mode 100644 tests/test_molstarview.py diff --git a/nglview/widget.py b/nglview/widget.py index ab7d7d2d..f475eea4 100644 --- a/nglview/widget.py +++ b/nglview/widget.py @@ -7,7 +7,6 @@ from contextlib import contextmanager import numpy as np -from IPython.display import display import ipywidgets as widgets from ipywidgets import (Image, Box, DOMWidget, HBox, IntSlider, Play, jslink) from ipywidgets import embed @@ -29,6 +28,7 @@ seq_to_string) from .viewer_control import ViewerControl from ._frontend import __frontend_version__ +from .widget_base import WidgetBase logger = getLogger(__name__) @@ -126,16 +126,12 @@ def _unset_serialization(views): f.write(html_code) -class NGLWidget(DOMWidget): +class NGLWidget(WidgetBase): _view_name = Unicode("NGLView").tag(sync=True) - _view_module = Unicode("nglview-js-widgets").tag(sync=True) - _view_module_version = Unicode(__frontend_version__).tag(sync=True) _model_name = Unicode("NGLModel").tag(sync=True) - _model_module = Unicode("nglview-js-widgets").tag(sync=True) - _model_module_version = Unicode(__frontend_version__).tag(sync=True) _ngl_version = Unicode().tag(sync=True) - # View and model attributes + # View and model attributes _image_data = Unicode().tag(sync=False) _view_width = Unicode().tag(sync=True) # px _view_height = Unicode().tag(sync=True) # px @@ -187,7 +183,7 @@ def __init__(self, representations=None, parameters=None, **kwargs): - super().__init__(**kwargs) + super().__init__(structure=structure, representations=representations, parameters=parameters, **kwargs) self._initialize_attributes(kwargs) self._initialize_threads() self._initialize_components(structure, representations, parameters, kwargs) @@ -204,7 +200,7 @@ def _initialize_attributes(self, kwargs): self._image_array = [] # do not use _displayed_callbacks since there is another Widget._display_callbacks self._event = threading.Event() - self._ngl_displayed_callbacks_before_loaded = [] + self._callbacks_before_loaded = [] widget_utils._add_repr_method_shortcut(self, self) self.shape = Shape(view=self) self.stage = Stage(view=self) @@ -383,47 +379,8 @@ def _update_background_color(self, change): self.stage.set_parameters(background_color=color) def handle_resize(self): - # self._remote_call("handleResize", target='Stage') self._remote_call("handleResize") - def _update_max_frame(self): - self.max_frame = max( - int(traj.n_frames) - for traj in self._trajlist - if hasattr(traj, 'n_frames')) - 1 # index starts from 0 - - def _wait_until_finished(self, timeout=0.0001): - # NGL need to send 'finished' signal to - # backend - self._event.clear() - while True: - # idle to make room for waiting for - # "finished" event sent from JS - time.sleep(timeout) - if self._event.is_set(): - # if event is set from another thread - # break while True - break - - def _run_on_another_thread(self, func, *args): - # use `event` to singal - # func(*args) - thread = threading.Thread( - target=func, - args=args, - ) - thread.daemon = True - thread.start() - return thread - - @observe('loaded') - def on_loaded(self, change): - # trick for firefox on Linux - time.sleep(0.1) - - if change['new']: - self._fire_callbacks(self._ngl_displayed_callbacks_before_loaded) - def _fire_callbacks(self, callbacks): def _call(event): @@ -432,14 +389,7 @@ def _call(event): if callback._method_name == 'loadFile': self._wait_until_finished() - self._run_on_another_thread(_call, self._event) - - def _ipython_display_(self, **kwargs): - try: - # ipywidgets < 8 - super()._ipython_display_(**kwargs) - except AttributeError: - display(super()._repr_mimebundle_(), raw=True) + self._thread_run(_call, self._event) def display(self, gui=False, style='ngl'): """ @@ -639,60 +589,6 @@ def _display_repr(self, component=0, repr_index=0, name=None): return RepresentationControl(self, component, repr_index, name=name) - def _set_coordinates(self, index, movie_making=False, render_params=None): - '''update coordinates for all trajectories at index-th frame''' - render_params = render_params or {} - if self._trajlist: - coordinates_dict = {} - for trajectory in self._trajlist: - traj_index = self._ngl_component_ids.index(trajectory.id) - - try: - if trajectory.shown: - coordinates_dict[ - traj_index] = trajectory.get_coordinates(index) - else: - coordinates_dict[traj_index] = np.empty((0), dtype='f4') - except (IndexError, ValueError): - coordinates_dict[traj_index] = np.empty((0), dtype='f4') - - self.set_coordinates(coordinates_dict, - render_params=render_params, - movie_making=movie_making) - else: - print("no trajectory available") - - def set_coordinates(self, arr_dict, movie_making=False, render_params=None): - # type: (Dict[int, np.ndarray]) -> None - """Used for update coordinates of a given trajectory - >>> # arr: numpy array, ndim=2 - >>> # update coordinates of 1st trajectory - >>> view.set_coordinates({0: arr})# doctest: +SKIP - """ - render_params = render_params or {} - self._coordinates_dict = arr_dict - - buffers = [] - coordinates_meta = dict() - for index, arr in self._coordinates_dict.items(): - buffers.append(arr.astype('f4').tobytes()) - coordinates_meta[index] = index - msg = { - 'type': 'binary_single', - 'data': coordinates_meta, - } - if movie_making: - msg['movie_making'] = movie_making - msg['render_params'] = render_params - - self.send(msg, buffers=buffers) - - @observe('frame') - def _on_frame_changed(self, change): - """set and send coordinates at current frame - """ - self._set_coordinates(change['new']) - def clear(self, *args, **kwargs): '''shortcut of `clear_representations` ''' @@ -1319,7 +1215,7 @@ def callback(widget, msg=msg): else: # send later # all callbacks will be called right after widget is loaded - self._ngl_displayed_callbacks_before_loaded.append(callback) + self._callbacks_before_loaded.append(callback) if callback._method_name not in _EXCLUDED_CALLBACK_AFTER_FIRING and \ (not other_kwargs.get("fire_once", False)): diff --git a/nglview/widget_base.py b/nglview/widget_base.py new file mode 100644 index 00000000..9e69c004 --- /dev/null +++ b/nglview/widget_base.py @@ -0,0 +1,158 @@ +import threading +import time +import ipywidgets as widgets +from traitlets import Bool, Integer, observe, Unicode +from .remote_thread import RemoteCallThread +from IPython.display import display +import numpy as np + +from ._frontend import __frontend_version__ + +class WidgetBase(widgets.DOMWidget): + _view_module = Unicode('nglview-js-widgets').tag(sync=True) + _model_module = Unicode('nglview-js-widgets').tag(sync=True) + _view_module_version = Unicode(__frontend_version__).tag(sync=True) + _model_module_version = Unicode(__frontend_version__).tag(sync=True) + + frame = Integer().tag(sync=True) + loaded = Bool(False).tag(sync=False) + _component_ids = [] + _trajlist = [] + _callbacks_before_loaded = [] + _event = threading.Event() + + def __init__(self, **kwargs): + # Extract recognized arguments + recognized_kwargs = {k: v for k, v in kwargs.items() if k in self.trait_names()} + super().__init__(**recognized_kwargs) + self._initialize_threads() + + def _initialize_threads(self): + self._remote_call_thread = RemoteCallThread(self, registered_funcs=[]) + self._remote_call_thread.daemon = True + self._remote_call_thread.start() + self._handle_msg_thread = threading.Thread(target=self.on_msg, args=(self._handle_nglview_custom_message,)) + self._handle_msg_thread.daemon = True + self._handle_msg_thread.start() + + def _handle_nglview_custom_message(self, widget, msg, buffers): + raise NotImplementedError() + + def render_image(self): + image = widgets.Image() + self._js(f"this.exportImage('{image.model_id}')") + return image + + def handle_resize(self): + self._js("this.plugin.handleResize()") + + @observe('loaded') + def on_loaded(self, change): + # trick for firefox on Linux + time.sleep(0.1) + if change['new']: + self._fire_callbacks(self._callbacks_before_loaded) + + def _thread_run(self, func, *args): + thread = threading.Thread(target=func, args=args) + thread.daemon = True + thread.start() + return thread + + def _fire_callbacks(self, callbacks): + def _call(event): + for callback in callbacks: + callback(self) + self._thread_run(_call, self._event) + + def _update_max_frame(self): + self.max_frame = max( + int(traj.n_frames) for traj in self._trajlist + if hasattr(traj, 'n_frames')) - 1 # index starts from 0 + + def _wait_until_finished(self, timeout=0.0001): + self._event.clear() + while True: + # idle to make room for waiting for + # "finished" event sent from JS + time.sleep(timeout) + if self._event.is_set(): + # if event is set from another thread + # break while True + break + + def _js(self, code, **kwargs): + self._remote_call('executeCode', target='Widget', args=[code], **kwargs) + + def _remote_call(self, method_name, target='Widget', args=None, kwargs=None, **other_kwargs): + msg = self._get_remote_call_msg(method_name, target=target, args=args, kwargs=kwargs, **other_kwargs) + def callback(widget, msg=msg): + widget.send(msg) + callback._method_name = method_name + callback._msg = msg + if self.loaded: + self._remote_call_thread.q.append(callback) + else: + self._callbacks_before_loaded.append(callback) + + def _get_remote_call_msg(self, method_name, target='Widget', args=None, kwargs=None, **other_kwargs): + msg = {'target': target, 'type': 'call_method', 'methodName': method_name, 'args': args, 'kwargs': kwargs} + msg.update(other_kwargs) + return msg + + def _set_coordinates(self, index, movie_making=False, render_params=None): + '''update coordinates for all trajectories at index-th frame''' + render_params = render_params or {} + if self._trajlist: + coordinates_dict = {} + for trajectory in self._trajlist: + traj_index = self._ngl_component_ids.index(trajectory.id) + + try: + if trajectory.shown: + coordinates_dict[traj_index] = trajectory.get_coordinates(index) + else: + coordinates_dict[traj_index] = np.empty((0), dtype='f4') + except (IndexError, ValueError): + coordinates_dict[traj_index] = np.empty((0), dtype='f4') + + self.set_coordinates(coordinates_dict, + render_params=render_params, + movie_making=movie_making) + else: + print("no trajectory available") + + def set_coordinates(self, arr_dict, movie_making=False, render_params=None): + """Used for update coordinates of a given trajectory + >>> # arr: numpy array, ndim=2 + >>> # update coordinates of 1st trajectory + >>> view.set_coordinates({0: arr})# doctest: +SKIP + """ + render_params = render_params or {} + self._coordinates_dict = arr_dict + + buffers = [] + coordinates_meta = dict() + for index, arr in self._coordinates_dict.items(): + buffers.append(arr.astype('f4').tobytes()) + coordinates_meta[index] = index + msg = { + 'type': 'binary_single', + 'data': coordinates_meta, + } + if movie_making: + msg['movie_making'] = movie_making + msg['render_params'] = render_params + + self.send(msg, buffers=buffers) + + @observe('frame') + def _on_frame_changed(self, change): + self._set_coordinates(change['new']) + + def _ipython_display_(self, **kwargs): + try: + # ipywidgets < 8 + super()._ipython_display_(**kwargs) + except AttributeError: + display(super()._repr_mimebundle_(), raw=True) \ No newline at end of file diff --git a/nglview/widget_molstar.py b/nglview/widget_molstar.py index 112ce8d1..98277b4e 100644 --- a/nglview/widget_molstar.py +++ b/nglview/widget_molstar.py @@ -1,32 +1,21 @@ # Code is copied/adapted from nglview -import threading import base64 import ipywidgets as widgets from traitlets import (Bool, Dict, Integer, Unicode, observe) from ._frontend import __frontend_version__ +from .widget_base import WidgetBase +from .utils.py_utils import (FileManager, _camelize_dict, _update_url, + encode_base64, get_repr_names_from_dict, + seq_to_string) -from .remote_thread import RemoteCallThread @widgets.register -class MolstarView(widgets.DOMWidget): - # Name of the widget view class in front-end +class MolstarView(WidgetBase): _view_name = Unicode('MolstarView').tag(sync=True) - - # Name of the widget model class in front-end _model_name = Unicode('MolstarModel').tag(sync=True) - # Name of the front-end module containing widget view - _view_module = Unicode('nglview-js-widgets').tag(sync=True) - - # Name of the front-end module containing widget model - _model_module = Unicode('nglview-js-widgets').tag(sync=True) - - # Version of the front-end module containing widget view - _view_module_version = Unicode(__frontend_version__).tag(sync=True) - # Version of the front-end module containing widget model - _model_module_version = Unicode(__frontend_version__).tag(sync=True) frame = Integer().tag(sync=True) loaded = Bool(False).tag(sync=False) molstate = Dict().tag(sync=True) @@ -35,130 +24,43 @@ class MolstarView(widgets.DOMWidget): def __init__(self): super().__init__() self._molstar_component_ids = [] - self._trajlist = [] - self._callbacks_before_loaded = [] - self._event = threading.Event() - self._remote_call_thread = RemoteCallThread( - self, - registered_funcs=[]) - self._remote_call_thread.daemon = True - self._remote_call_thread.start() - self._handle_msg_thread = threading.Thread( - target=self.on_msg, args=(self._molstar_handle_message, )) - # register to get data from JS side - self._handle_msg_thread.daemon = True - self._handle_msg_thread.start() self._state = None - def render_image(self): - image = widgets.Image() - self._js(f"this.exportImage('{image.model_id}')") - # image.value will be updated in _molstar_handle_message - return image - - def handle_resize(self): - self._js("this.plugin.handleResize()") - - @observe('loaded') - def on_loaded(self, change): - if change['new']: - self._fire_callbacks(self._callbacks_before_loaded) - - def _thread_run(self, func, *args): - thread = threading.Thread( - target=func, - args=args, - ) - thread.daemon = True - thread.start() - return thread - - def _fire_callbacks(self, callbacks): - def _call(event): - for callback in callbacks: - callback(self) - self._thread_run(_call, self._event) - - def _wait_until_finished(self, timeout=0.0001): - # FIXME: dummy for now - pass - - def _load_structure_data(self, data: str, format: str = 'pdb', preset="default"): - self._remote_call("loadStructureFromData", - target="Widget", - args=[data, format, preset]) - - def _molstar_handle_message(self, widget, msg, buffers): + def _handle_nglview_custom_message(self, widget, msg, buffers): msg_type = msg.get("type") data = msg.get("data") + if msg_type == "exportImage": - image = widgets.Widget.widgets[msg.get("model_id")] - image.value = base64.b64decode(data) + self._handle_export_image(msg) elif msg_type == "state": - self._state = data + self._handle_state(data) elif msg_type == 'request_loaded': - if not self.loaded: - # FIXME: doublecheck this - # trick to trigger observe loaded - # so two viewers can have the same representations - self.loaded = False - self.loaded = msg.get('data') + self._handle_request_loaded(msg) elif msg_type == 'getCamera': - self._molcamera = data - - def render_image(self): - image = widgets.Image() - self._js(f"this.exportImage('{image.model_id}')") - # image.value will be updated in _molview_handle_message - return image - - def _js(self, code, **kwargs): - # nglview code - self._remote_call('executeCode', - target='Widget', - args=[code], - **kwargs) - - def _remote_call(self, - method_name, - target='Widget', - args=None, - kwargs=None, - **other_kwargs): - - # adapted from nglview - msg = self._get_remote_call_msg(method_name, - target=target, - args=args, - kwargs=kwargs, - **other_kwargs) - def callback(widget, msg=msg): - widget.send(msg) - - callback._method_name = method_name - callback._msg = msg - - if self.loaded: - self._remote_call_thread.q.append(callback) - else: - # send later - # all callbacks will be called right after widget is loaded - self._callbacks_before_loaded.append(callback) - - def _get_remote_call_msg(self, - method_name, - target='Widget', - args=None, - kwargs=None, - **other_kwargs): - # adapted from nglview - msg = {} - msg['target'] = target - msg['type'] = 'call_method' - msg['methodName'] = method_name - msg['args'] = args - msg['kwargs'] = kwargs - return msg + self._handle_get_camera(data) + + def _handle_export_image(self, msg): + image = widgets.Widget.widgets[msg.get("model_id")] + image.value = base64.b64decode(msg.get("data")) + + def _handle_state(self, data): + self._state = data + + def _handle_request_loaded(self, msg): + if not self.loaded: + # FIXME: doublecheck this + # trick to trigger observe loaded + # so two viewers can have the same representations + self.loaded = False + self.loaded = msg.get('data') + + def _handle_get_camera(self, data): + self._molcamera = data + + def _load_structure_data(self, data: str, format: str = 'pdb', preset="default"): + self._remote_call("loadStructureFromData", + target="Widget", + args=[data, format, preset]) def add_trajectory(self, trajectory): self._load_structure_data(trajectory.get_structure_string(), @@ -172,41 +74,25 @@ def add_structure(self, struc): 'pdb') self._molstar_component_ids.append(struc.id) - def _update_max_frame(self): - self.max_frame = max( - int(traj.n_frames) for traj in self._trajlist - if hasattr(traj, 'n_frames')) - 1 # index starts from 0 - - def _set_coordinates(self, index): - '''update coordinates for all trajectories at index-th frame - ''' - if self._trajlist: - coordinates_dict = {} - for trajectory in self._trajlist: - traj_index = self._molstar_component_ids.index(trajectory.id) - try: - coordinates_dict[traj_index] = trajectory.get_coordinates( - index) - except (IndexError, ValueError): - coordinates_dict[traj_index] = np.empty((0), dtype='f4') - self._send_coordinates(coordinates_dict) - - def _send_coordinates(self, arr_dict): - self._coordinates_dict = arr_dict - - buffers = [] - coords_indices = dict() - for index, arr in self._coordinates_dict.items(): - buffers.append(arr.astype('f4').tobytes()) - coords_indices[index] = index - msg = { - 'type': 'binary_single', - 'data': coords_indices, - } - self.send(msg, buffers=buffers) - - @observe('frame') - def _on_frame_changed(self, change): - """set and send coordinates at current frame - """ - self._set_coordinates(self.frame) + def add_component(self, component): + raise NotImplementedError() + + def add_representation(self, **params): + params = _camelize_dict(params) + + if 'component' in params: + model_index = params.pop('component') + else: + model_index = 0 + + for k, v in params.items(): + try: + params[k] = v.strip() + except AttributeError: + # e.g.: opacity=0.4 + params[k] = v + + self._remote_call('addRepresentation', + args=[ + params, model_index + ]) \ No newline at end of file diff --git a/notebooks/molstarview.ipynb b/notebooks/molstarview.ipynb index 0f03c4ff..043e5cbb 100644 --- a/notebooks/molstarview.ipynb +++ b/notebooks/molstarview.ipynb @@ -9,7 +9,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "9b5586a7f3c243ceb16e803e97db27e5", + "model_id": "91ffff2fd7c543cda5b6c60af5e68769", "version_major": 2, "version_minor": 0 }, @@ -21,7 +21,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b71b8c0bdd014221b67768790f63fdb1", + "model_id": "7236255e5c194422b24c453987083e61", "version_major": 2, "version_minor": 0 }, @@ -29,9 +29,8 @@ "MolstarView()" ] }, - "execution_count": 1, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ @@ -91,7 +90,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b71b8c0bdd014221b67768790f63fdb1", + "model_id": "7236255e5c194422b24c453987083e61", "version_major": 2, "version_minor": 0 }, @@ -99,15 +98,38 @@ "MolstarView()" ] }, - "execution_count": 4, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ "view" ] }, + { + "cell_type": "markdown", + "id": "41b835c5-3a2a-4b34-b862-0c6eb4f6700e", + "metadata": {}, + "source": [ + "# Add representation" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "5c15c428-8b99-42b0-887f-d767b08e7fd6", + "metadata": {}, + "outputs": [], + "source": [ + "params = {\n", + " \"type\": \"spacefill\",\n", + " \"typeParams\": { \"sizeFactor\": 0.5 },\n", + " \"color\": \"hydrophobicity\",\n", + " \"colorParams\": { \"scale\": \"DGwoct\" },\n", + "}\n", + "view.add_representation(**params)" + ] + }, { "cell_type": "markdown", "id": "7afb6cdc-5bcf-4abf-9db3-cf204d624c67", diff --git a/notebooks/trajectory_player.ipynb b/notebooks/trajectory_player.ipynb index 76cca720..91c438d4 100644 --- a/notebooks/trajectory_player.ipynb +++ b/notebooks/trajectory_player.ipynb @@ -8,7 +8,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "849cfdd8ed9547b1936fd182fe9d9c9d", + "model_id": "3de104bc16ba4fdba33696d94704831c", "version_major": 2, "version_minor": 0 }, @@ -20,7 +20,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "52fd92a98bf2476fbb9c2be8193048d4", + "model_id": "c32624f79dc1454e97a7b79f474f9dd1", "version_major": 2, "version_minor": 0 }, @@ -96,6 +96,72 @@ "source": [ "!open index2.html" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Molstar" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "739b42afb8224eca83241b55d30f8eec", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "MolstarView()" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from nglview.widget_molstar import MolstarView\n", + "import nglview as nv\n", + "\n", + "class FiveFrameTrajectory(nv.SimpletrajTrajectory):\n", + " @property\n", + " def n_frames(self):\n", + " return 5\n", + "\n", + "traj = FiveFrameTrajectory(nv.datafiles.XTC, nv.datafiles.PDB)\n", + "\n", + "view = MolstarView()\n", + "view" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "view.add_trajectory(traj)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "params = {\n", + " \"type\": \"spacefill\",\n", + " \"typeParams\": { \"sizeFactor\": 0.5 },\n", + " \"color\": \"hydrophobicity\",\n", + " \"colorParams\": { \"scale\": \"DGwoct\" },\n", + "}\n", + "view._remote_call(\"addRepresentation\", args=[params, 0])" + ] } ], "metadata": { diff --git a/tests/test_molstarview.py b/tests/test_molstarview.py new file mode 100644 index 00000000..6fb95265 --- /dev/null +++ b/tests/test_molstarview.py @@ -0,0 +1,110 @@ +import base64 +import pytest +from unittest.mock import patch, MagicMock +from nglview.widget_molstar import MolstarView + + +@pytest.fixture +def molstar_view(): + return MolstarView() + + +def test_initialization(molstar_view): + assert isinstance(molstar_view, MolstarView) + assert molstar_view._view_name == 'MolstarView' + assert molstar_view._model_name == 'MolstarModel' + assert molstar_view.frame == 0 + assert not molstar_view.loaded + assert molstar_view.molstate == {} + assert molstar_view._state is None + + +@patch('nglview.widget_molstar.widgets.Widget') +def test_handle_nglview_custom_message_export_image(mock_widget, molstar_view): + mock_image = MagicMock() + mock_widget.widgets = {'test_model_id': mock_image} + msg = { + "type": "exportImage", + "model_id": "test_model_id", + "data": base64.b64encode(b'test_data').decode('utf-8') + } + molstar_view._handle_nglview_custom_message(None, msg, None) + assert mock_image.value == base64.b64decode(msg["data"]) + + +def test_handle_nglview_custom_message_state(molstar_view): + msg = {"type": "state", "data": "test_state"} + molstar_view._handle_nglview_custom_message(None, msg, None) + assert molstar_view._state == "test_state" + + +def test_handle_nglview_custom_message_request_loaded(molstar_view): + msg = {"type": "request_loaded", "data": True} + molstar_view._handle_nglview_custom_message(None, msg, None) + assert molstar_view.loaded + + +def test_handle_nglview_custom_message_get_camera(molstar_view): + msg = {"type": "getCamera", "data": "test_camera"} + molstar_view._handle_nglview_custom_message(None, msg, None) + assert molstar_view._molcamera == "test_camera" + + +@patch('nglview.widget_molstar.MolstarView._fire_callbacks') +def test_on_loaded(mock_fire_callbacks, molstar_view): + molstar_view._callbacks_before_loaded = ['callback1', 'callback2'] + molstar_view.on_loaded({'new': True}) + mock_fire_callbacks.assert_called_once_with(['callback1', 'callback2']) + + +@patch('nglview.widget_molstar.MolstarView._remote_call') +def test_load_structure_data(mock_remote_call, molstar_view): + molstar_view._load_structure_data('test_data', 'pdb', 'default') + mock_remote_call.assert_called_once_with( + "loadStructureFromData", + target="Widget", + args=['test_data', 'pdb', 'default']) + + +@patch('nglview.widget_molstar.MolstarView._update_max_frame') +@patch('nglview.widget_molstar.MolstarView._load_structure_data') +def test_add_trajectory(mock_load_structure_data, mock_update_max_frame, + molstar_view): + mock_trajectory = MagicMock() + mock_trajectory.get_structure_string.return_value = 'test_structure_string' + mock_trajectory.id = 'test_id' + molstar_view.add_trajectory(mock_trajectory) + mock_load_structure_data.assert_called_once_with('test_structure_string', + 'pdb') + mock_update_max_frame.assert_called_once() + assert mock_trajectory in molstar_view._trajlist + assert 'test_id' in molstar_view._molstar_component_ids + + +@patch('nglview.widget_molstar.MolstarView._load_structure_data') +def test_add_structure(mock_load_structure_data, molstar_view): + mock_structure = MagicMock() + mock_structure.get_structure_string.return_value = 'test_structure_string' + mock_structure.id = 'test_id' + molstar_view.add_structure(mock_structure) + mock_load_structure_data.assert_called_once_with('test_structure_string', + 'pdb') + assert 'test_id' in molstar_view._molstar_component_ids + + +# FIXME: failed tests. Why? + +# @patch('nglview.widget_molstar.MolstarView._set_coordinates' +# def test_on_frame_changed(mock_set_coordinates, molstar_view): +# molstar_view.frame = 10 +# molstar_view._on_frame_changed({'new': 10}) +# mock_set_coordinates.assert_called_once_with(10) + + +# @patch('nglview.widget_molstar.MolstarView._remote_call') +# def test_add_representation(mock_remote_call, molstar_view): +# params = {'component': 'test_component', 'opacity': 0.4} +# molstar_view.add_representation(**params) +# expected_params = {'component': 'test_component', 'opacity': 0.4} +# mock_remote_call.assert_called_once_with('addRepresentation', +# args=[expected_params, 0]) diff --git a/tests/test_widget.py b/tests/test_widget.py index a2dd8f63..7d52c5d5 100644 --- a/tests/test_widget.py +++ b/tests/test_widget.py @@ -837,7 +837,7 @@ def test_queuing_messages(): view.add_component(nv.datafiles.PDB) view.download_image() view - assert [f._method_name for f in view._ngl_displayed_callbacks_before_loaded] == \ + assert [f._method_name for f in view._callbacks_before_loaded] == \ [ 'loadFile', '_downloadImage']