diff --git a/nglview/widget.py b/nglview/widget.py index d625278a..031024ba 100644 --- a/nglview/widget.py +++ b/nglview/widget.py @@ -52,7 +52,7 @@ _TRACKED_WIDGETS = {} -def _deprecated(msg): +def _deprecated(msg: str): def wrap_1(func): @@ -65,8 +65,8 @@ def wrap_2(*args, **kwargs): return wrap_1 -def write_html(fp, views, frame_range=None): - # type: (str, List[NGLWidget]) -> None +def write_html(fp: str, views: list, frame_range: tuple[int, int] | None = None): + # type: (str, List[NGLWidget]) """EXPERIMENTAL. Likely will be changed. Make html file to display a list of views. For further options, please @@ -291,7 +291,7 @@ def on_change_layout(change): self.layout.observe(on_change_layout, ['width', 'height']) - def _set_serialization(self, frame_range=None): + def _set_serialization(self, frame_range: tuple[int, int] | None = None): self._ngl_serialize = True resource = self._ngl_coordinate_resource if frame_range is not None: @@ -354,7 +354,7 @@ def camera(self, value): target='Stage', kwargs=dict(cameraType=self._camera_str)) - def _set_camera_orientation(self, arr): + def _set_camera_orientation(self, arr: list): self._remote_call('set_camera_orientation', target='Widget', args=[ @@ -365,7 +365,7 @@ def _request_stage_parameters(self): self._remote_call('requestUpdateStageParameters', target='Widget') @validate('gui_style') - def _validate_gui_style(self, proposal): + def _validate_gui_style(self, proposal: dict) -> str: val = proposal['value'] if val == 'ngl': if self._widget_theme is None: @@ -377,7 +377,7 @@ def _validate_gui_style(self, proposal): return val @observe("_gui_theme") - def _on_theme_changed(self, change): + def _on_theme_changed(self, change: dict): # EXPERIMENTAL from nglview.theme import theme if change.new == 'dark': @@ -386,7 +386,7 @@ def _on_theme_changed(self, change): self._widget_theme.light() @observe('background') - def _update_background_color(self, change): + def _update_background_color(self, change: dict): color = change['new'] self.stage.set_parameters(background_color=color) @@ -400,7 +400,7 @@ def _update_max_frame(self): for traj in self._trajlist if hasattr(traj, 'n_frames')) - 1 # index starts from 0 - def _wait_until_finished(self, timeout=0.0001): + def _wait_until_finished(self, timeout: float = 0.0001): # NGL need to send 'finished' signal to # backend self._event.clear() @@ -413,7 +413,7 @@ def _wait_until_finished(self, timeout=0.0001): # break while True break - def _run_on_another_thread(self, func, *args): + def _run_on_another_thread(self, func: callable, *args) -> threading.Thread: # use `event` to singal # func(*args) thread = threading.Thread( @@ -425,14 +425,14 @@ def _run_on_another_thread(self, func, *args): return thread @observe('loaded') - def on_loaded(self, change): + def on_loaded(self, change: dict): # trick for firefox on Linux time.sleep(0.1) if change['new']: self._fire_callbacks(self._ngl_displayed_callbacks_before_loaded) - def _fire_callbacks(self, callbacks): + def _fire_callbacks(self, callbacks: list): def _call(event): for callback in callbacks: @@ -449,7 +449,7 @@ def _ipython_display_(self, **kwargs): except AttributeError: display(super()._repr_mimebundle_(), raw=True) - def display(self, gui=False, style='ngl'): + def display(self, gui: bool = False, style: str = 'ngl') -> 'NGLWidget': """ Parameters @@ -472,7 +472,7 @@ def display(self, gui=False, style='ngl'): else: return self - def _set_size(self, w, h): + def _set_size(self, w: float | str, h: float | str): ''' Parameters @@ -488,7 +488,7 @@ def _set_size(self, w, h): ''' self._remote_call('setSize', target='Widget', args=[w, h]) - def _set_sync_repr(self, other_views): + def _set_sync_repr(self, other_views: list): model_ids = {v._model_id for v in other_views} self._synced_repr_model_ids = sorted( set(self._synced_repr_model_ids) | model_ids) @@ -496,7 +496,7 @@ def _set_sync_repr(self, other_views): target="Widget", args=[self._synced_repr_model_ids]) - def _set_unsync_repr(self, other_views): + def _set_unsync_repr(self, other_views: list): model_ids = {v._model_id for v in other_views} self._synced_repr_model_ids = list( set(self._synced_repr_model_ids) - model_ids) @@ -504,31 +504,31 @@ def _set_unsync_repr(self, other_views): target="Widget", args=[self._synced_repr_model_ids]) - def _set_sync_camera(self, other_views): + def _set_sync_camera(self, other_views: list): model_ids = {v._model_id for v in other_views} self._synced_model_ids = sorted(set(self._synced_model_ids) | model_ids) self._remote_call("setSyncCamera", target="Widget", args=[self._synced_model_ids]) - def _set_unsync_camera(self, other_views): + def _set_unsync_camera(self, other_views: list): model_ids = {v._model_id for v in other_views} self._synced_model_ids = list(set(self._synced_model_ids) - model_ids) self._remote_call("setSyncCamera", target="Widget", args=[self._synced_model_ids]) - def _set_spin(self, axis, angle): + def _set_spin(self, axis: list, angle: float): self._remote_call('setSpin', target='Stage', args=[axis, angle]) - def _set_selection(self, selection, component=0, repr_index=0): + def _set_selection(self, selection: str, component: int = 0, repr_index: int = 0): self._remote_call("setSelection", target='Representation', args=[selection], kwargs=dict(component_index=component, repr_index=repr_index)) - def color_by(self, color_scheme, component=0): + def color_by(self, color_scheme: str, component: int = 0): '''update color for all representations of given component Notes @@ -571,7 +571,7 @@ def representations(self, reps): for index in range(len(self._ngl_component_ids)): self.set_representations(reps) - def update_representation(self, component=0, repr_index=0, **parameters): + def update_representation(self, component: int = 0, repr_index: int = 0, **parameters): """ Parameters @@ -599,7 +599,7 @@ def _update_repr_dict(self): """ self._remote_call('request_repr_dict', target='Widget') - def set_representations(self, representations, component=0): + def set_representations(self, representations: list, component: int = 0): """ Parameters @@ -619,24 +619,24 @@ def set_representations(self, representations, component=0): ], kwargs=kwargs) - def _remove_representation(self, component=0, repr_index=0): + def _remove_representation(self, component: int = 0, repr_index: int = 0): self._remote_call('removeRepresentation', target='Widget', args=[component, repr_index]) - def _remove_representations_by_name(self, repr_name, component=0): + def _remove_representations_by_name(self, repr_name: str, component: int = 0): self._remote_call('removeRepresentationsByName', target='Widget', args=[repr_name, component]) - def _update_representations_by_name(self, repr_name, component=0, **kwargs): + def _update_representations_by_name(self, repr_name: str, component: int = 0, **kwargs): kwargs = _camelize_dict(kwargs) self._remote_call('updateRepresentationsByName', target='Widget', args=[repr_name, component], kwargs=kwargs) - def _display_repr(self, component=0, repr_index=0, name=None): + def _display_repr(self, component: int = 0, repr_index: int = 0, name: str | None = None) -> 'RepresentationControl': c = 'c' + str(component) r = str(repr_index) @@ -647,7 +647,7 @@ def _display_repr(self, component=0, repr_index=0, name=None): return RepresentationControl(self, component, repr_index, name=name) - def _set_coordinates(self, index, movie_making=False, render_params=None): + def _set_coordinates(self, index: int, movie_making: bool = False, render_params: dict | None = None): '''update coordinates for all trajectories at index-th frame''' render_params = render_params or {} if self._trajlist: @@ -670,7 +670,7 @@ def _set_coordinates(self, index, movie_making=False, render_params=None): else: print("no trajectory available") - def set_coordinates(self, arr_dict, movie_making=False, render_params=None): + def set_coordinates(self, arr_dict: dict, movie_making: bool = False, render_params: dict | None = None): """Used for update coordinates of a given trajectory >>> # arr: numpy array, ndim=2 >>> # update coordinates of 1st trajectory @@ -695,7 +695,7 @@ def set_coordinates(self, arr_dict, movie_making=False, render_params=None): self.send(msg, buffers=buffers) @observe('frame') - def _on_frame_changed(self, change): + def _on_frame_changed(self, change: dict): """set and send coordinates at current frame """ self._set_coordinates(change['new']) @@ -706,7 +706,7 @@ def clear(self, *args, **kwargs): self.clear_representations(*args, **kwargs) - def clear_representations(self, component=0): + def clear_representations(self, component: int = 0): '''clear all representations for given component Parameters @@ -719,7 +719,7 @@ def clear_representations(self, component=0): kwargs={'component_index': component}) @_update_url - def _add_shape(self, shapes, name='shape'): + def _add_shape(self, shapes: list, name: str = 'shape') -> 'ComponentViewer': """add shape objects TODO: update doc, caseless shape keyword @@ -764,7 +764,7 @@ def _add_shape(self, shapes, name='shape'): return ComponentViewer(self, cid) @_update_url - def add_representation(self, repr_type, selection='all', **kwargs): + def add_representation(self, repr_type: str, selection: str = 'all', **kwargs): '''Add structure representation (cartoon, licorice, ...) for given atom selection. Parameters @@ -833,7 +833,7 @@ def center_view(self, *args, **kwargs): """ self.center(*args, **kwargs) - def center(self, selection='*', duration=0, component=0, **kwargs): + def center(self, selection: str = '*', duration: int = 0, component: int = 0, **kwargs): """center view for given atom selection Examples @@ -847,7 +847,7 @@ def center(self, selection='*', duration=0, component=0, **kwargs): **kwargs) @observe('_image_data') - def _on_render_image(self, change): + def _on_render_image(self, change: dict): '''update image data to widget_image Notes @@ -857,11 +857,11 @@ def _on_render_image(self, change): self._widget_image._b64value = change['new'] def render_image(self, - frame=None, - factor=4, - antialias=True, - trim=False, - transparent=False): + frame: int | None = None, + factor: int = 4, + antialias: bool = True, + trim: bool = False, + transparent: bool = False) -> Image: """render and get image as ipywidgets.widget_image.Image Parameters @@ -904,11 +904,11 @@ def render_image(self, return iw def download_image(self, - filename='screenshot.png', - factor=4, - antialias=True, - trim=False, - transparent=False): + filename: str = 'screenshot.png', + factor: int = 4, + antialias: bool = True, + trim: bool = False, + transparent: bool = False): """render and download scene at current frame Parameters @@ -965,7 +965,7 @@ def _handle_image_data(self): _TRACKED_WIDGETS[self._ngl_msg.get('ID')].value = base64.b64decode( self._image_data) - def _handle_nglview_custom_msg(self, _, msg, buffers): + def _handle_nglview_custom_msg(self, _, msg: dict, buffers: list): self._ngl_msg = msg msg_type = self._ngl_msg.get('type') @@ -988,13 +988,13 @@ def _handle_nglview_custom_msg(self, _, msg, buffers): elif msg_type == 'image_data': self._handle_image_data() - def _request_repr_parameters(self, component=0, repr_index=0): + def _request_repr_parameters(self, component: int = 0, repr_index: int = 0): if self.n_components > 0: self._remote_call('requestReprParameters', target='Widget', args=[component, repr_index]) - def add_structure(self, structure, **kwargs): + def add_structure(self, structure: 'Structure', **kwargs) -> 'ComponentViewer': '''add structure to view Parameters @@ -1021,7 +1021,7 @@ def add_structure(self, structure, **kwargs): self._update_component_auto_completion() return self[-1] - def add_trajectory(self, trajectory, **kwargs): + def add_trajectory(self, trajectory: 'Trajectory', **kwargs) -> 'ComponentViewer': '''add new trajectory to `view` Parameters @@ -1058,7 +1058,7 @@ def add_trajectory(self, trajectory, **kwargs): self._update_component_auto_completion() return self[-1] - def add_pdbid(self, pdbid, **kwargs): + def add_pdbid(self, pdbid: str, **kwargs) -> 'ComponentViewer': '''add new Structure view by fetching pdb id from rcsb Examples @@ -1071,7 +1071,7 @@ def add_pdbid(self, pdbid, **kwargs): ''' return self.add_component(f'rcsb://{pdbid}.pdb', **kwargs) - def add_component(self, filename, **kwargs): + def add_component(self, filename: str | 'Trajectory' | 'Structure', **kwargs) -> 'ComponentViewer': '''add component from file/trajectory/struture Parameters @@ -1108,7 +1108,7 @@ def add_component(self, filename, **kwargs): self._update_component_auto_completion() return self[-1] - def _load_data(self, obj, **kwargs): + def _load_data(self, obj: 'Structure' | str, **kwargs): ''' Parameters @@ -1162,7 +1162,7 @@ def _load_data(self, obj, **kwargs): self._ngl_component_names.append(name) self._remote_call("loadFile", target='Stage', args=args, kwargs=kwargs2) - def remove_component(self, c): + def remove_component(self, c: int | 'ComponentViewer'): """remove component by its uuid. If isinstance(c, ComponentViewer), `c` won't be associated with `self` @@ -1201,15 +1201,15 @@ def remove_component(self, c): self._update_component_auto_completion() - def _dry_run(self, func, *args, **kwargs): + def _dry_run(self, func: callable, *args, **kwargs): return _dry_run(self, func, *args, **kwargs) def _get_remote_call_msg(self, - method_name, - target='Widget', - args=None, - kwargs=None, - **other_kwargs): + method_name: str, + target: str = 'Widget', + args: list | None = None, + kwargs: dict | None = None, + **other_kwargs) -> dict: """call NGL's methods from Python. Parameters @@ -1272,7 +1272,7 @@ def _get_remote_call_msg(self, msg.update(other_kwargs) return msg - def _trim_message(self, messages): + def _trim_message(self, messages: list) -> list: messages = messages[:] remove_comps = [(index, msg['args'][0]) @@ -1294,10 +1294,10 @@ def _trim_message(self, messages): return [msg for i, msg in enumerate(messages) if i not in messages_rm] def _remote_call(self, - method_name, - target='Widget', - args=None, - kwargs=None, + method_name: str, + target: str = 'Widget', + args: list | None = None, + kwargs: dict | None = None, **other_kwargs): msg = self._get_remote_call_msg(method_name, @@ -1325,7 +1325,7 @@ def callback(widget, msg=msg): archive.append(msg) self._ngl_msg_archive = self._trim_message(archive) - def _get_traj_by_id(self, itsid): + def _get_traj_by_id(self, itsid: str) -> 'Trajectory' | None: """return nglview.Trajectory or its derived class object """ for traj in self._trajlist: @@ -1333,7 +1333,7 @@ def _get_traj_by_id(self, itsid): return traj return None - def hide(self, indices): + def hide(self, indices: list): """set invisibility for given component/struture/trajectory (by their indices) """ traj_ids = {traj.id for traj in self._trajlist} @@ -1355,7 +1355,7 @@ def show(self, **kwargs): """ self.show_only(**kwargs) - def show_only(self, indices='all', **kwargs): + def show_only(self, indices: str | list = 'all', **kwargs): """set visibility for given components (by their indices) Parameters @@ -1411,10 +1411,10 @@ def _clear_component_auto_completion(self): name = 'component_' + str(index) delattr(self, name) - def _js(self, code, **kwargs): + def _js(self, code: str, **kwargs): self._execute_js_code(code, **kwargs) - def _execute_js_code(self, code, **kwargs): + def _execute_js_code(self, code: str, **kwargs): self._remote_call('executeCode', target='Widget', args=[code], **kwargs) def _update_component_auto_completion(self): @@ -1429,14 +1429,14 @@ def _update_component_auto_completion(self): traj_name = 'trajectory_' + str(trajids.index(cid)) setattr(self, traj_name, comp) - def __getitem__(self, index): + def __getitem__(self, index: int) -> 'ComponentViewer': """return ComponentViewer """ postive_index = py_utils.get_positive_index( index, len(self._ngl_component_ids)) return ComponentViewer(self, self._ngl_component_ids[postive_index]) - def __iter__(self): + def __iter__(self) -> 'ComponentViewer': """return ComponentViewer """ for i, _ in enumerate(self._ngl_component_ids): @@ -1455,7 +1455,7 @@ class Fullscreen(DOMWidget): _is_fullscreen = Bool().tag(sync=True) - def __init__(self, target, views): + def __init__(self, target: 'NGLWidget', views: list): super().__init__() self._target = target self._views = views @@ -1463,12 +1463,12 @@ def __init__(self, target, views): def fullscreen(self): self._js("this.fullscreen('%s')" % self._target.model_id) - def _js(self, code): + def _js(self, code: str): msg = {"executeCode": code} self.send(msg) @observe('_is_fullscreen') - def _fullscreen_changed(self, change): + def _fullscreen_changed(self, change: dict): if not change.new: self._target.layout.height = '300px' self.handle_resize()