diff --git a/stytra/stimulation/__init__.py b/stytra/stimulation/__init__.py index 4749c334..2a7318d9 100644 --- a/stytra/stimulation/__init__.py +++ b/stytra/stimulation/__init__.py @@ -1,14 +1,16 @@ +from dataclasses import dataclass import datetime from copy import deepcopy +import warnings from PyQt5.QtCore import pyqtSignal, QTimer, QObject -from stytra.stimulation.stimuli import Pause, DynamicStimulus +from stytra.stimulation.stimuli import Pause, DynamicStimulus, EnvironmentState from stytra.collectors.accumulators import DynamicLog, FramerateAccumulator from stytra.utilities import FramerateRecorder from lightparam.param_qt import ParametrizedQt, Param import logging - + class ProtocolRunner(QObject): """Class for managing and running stimulation Protocols. @@ -91,7 +93,9 @@ def __init__(self, experiment=None, target_dt=0, log_print=True): self.current_stimulus = None # current stimulus object self.past_stimuli_elapsed = None # time elapsed in previous stimuli self.dynamic_log = None # dynamic log for stimuli - + + self.environment_state = EnvironmentState(calibrator = self.experiment.calibrator,) + self.update_protocol() self.protocol.sig_param_changed.connect(self.update_protocol) @@ -109,11 +113,31 @@ def update_protocol(self): self.stimuli = self.protocol._get_stimulus_list() self.current_stimulus = self.stimuli[0] - + + #populate environment_state class + if hasattr(self.experiment, 'estimator'): + self.environment_state.estimator = self.experiment.estimator + if hasattr(self.experiment, 'arduino_board'): + self.environment_state.arduino_board = self.experiment.arduino_board + if hasattr(self.experiment, 'asset_dir'): + self.environment_state.asset_dir = self.experiment.asset_dir + if hasattr(self.experiment, 'logger'): + self.environment_state.logger = self.experiment.logger + if hasattr(self.experiment, 'trigger'): + self.environment_state.trigger = self.experiment.trigger + # pass experiment to stimuli for calibrator and asset folders: for stimulus in self.stimuli: - stimulus.initialise_external(self.experiment) - + try: + stimulus.initialise_external(self.experiment, self.environment_state,) + except TypeError as e: + print("Error: {}".format(e)) + stimulus.initialise_external(self.experiment) + msg = "Warning: self._experiment is deprecated use self._environment_state instead, self._experiment will be unavailable from version 1.0!" + warnings.warn(msg, FutureWarning) + warnings.warn(msg, DeprecationWarning) + + if self.dynamic_log is None: self.dynamic_log = DynamicLog(self.stimuli, experiment=self.experiment) else: diff --git a/stytra/stimulation/stimuli/arduino.py b/stytra/stimulation/stimuli/arduino.py index 7700383c..161054ca 100644 --- a/stytra/stimulation/stimuli/arduino.py +++ b/stytra/stimulation/stimuli/arduino.py @@ -22,7 +22,7 @@ def __init__(self, pin_values_dict, *args, **kwargs): def start(self): super().start() - self._experiment.arduino_board.write_multiple(self.pin_values) + self._environment_state.arduino_board.write_multiple(self.pin_values) class ContinuousWriteArduinoPin(InterpolatedStimulus): @@ -45,10 +45,10 @@ def __init__(self, pin, *args, **kwargs): def update(self): super().update() - self._experiment.arduino_board.write(self.pin, self.pin_value) + self._environment_state.arduino_board.write(self.pin, self.pin_value) def stop(self): super().update() - self._experiment.arduino_board.write(self.pin, 0) + self._environment_state.arduino_board.write(self.pin, 0) diff --git a/stytra/stimulation/stimuli/closed_loop.py b/stytra/stimulation/stimuli/closed_loop.py index 73b5e25e..9dff11b4 100644 --- a/stytra/stimulation/stimuli/closed_loop.py +++ b/stytra/stimulation/stimuli/closed_loop.py @@ -69,7 +69,7 @@ def get_fish_vel(self): """ Function that update estimated fish velocty. Change to add lag or shunting. """ - self.fish_vel = self._experiment.estimator.get_velocity() + self.fish_vel = self._environment_state.estimator.get_velocity() def bout_started(self): """ Function called on bout start. @@ -87,7 +87,7 @@ def bout_ended(self): def update(self): if self.max_interbout_time is not None: if self._elapsed - self.prev_bout_t > self.max_interbout_time: - self._experiment.logger.info( + self._environment_state.logger.info( "Experiment aborted! {} seconds without bouts".format( self._elapsed - self.prev_bout_t ) @@ -177,7 +177,7 @@ def __init__( def bout_started(self): super().bout_started() - self.est_gain = self._experiment.estimator.base_gain + self.est_gain = self._environment_state.estimator.base_gain def bout_occurring(self): self.bout_vig.append(self.fish_vel / self.est_gain) @@ -196,7 +196,7 @@ def bout_ended(self): self.median_calib = self.median_vig * self.est_gain self.est_gain = self.target_avg_fish_vel / self.median_vig - self._experiment.estimator.base_gain = self.est_gain + self._environment_state.estimator.base_gain = self.est_gain self.bout_vel = [] @@ -208,14 +208,14 @@ def stop(self): ): self.abort_experiment() - self._experiment.logger.info( + self._environment_state.logger.info( "Experiment aborted! N bouts: {}; gain: {}".format( len(self.bouts_vig_list), self.est_gain ) ) if len(self.bouts_vig_list) > self.calibrate_after: - self._experiment.logger.info( + self._environment_state.logger.info( "Calibrated! Calculated gain {} with {} bouts".format( self.est_gain, len(self.bouts_vig_list) ) @@ -246,7 +246,7 @@ def __init__(self, newgain=1): self.newgain = newgain def start(self): - self._experiment.estimator.base_gain = self.newgain + self._environment_state.estimator.base_gain = self.newgain class GainLagClosedLoop1D(Basic_CL_1D): @@ -277,8 +277,7 @@ def get_fish_vel(self): shunting. """ super(GainLagClosedLoop1D, self).get_fish_vel() - self.lag_vel = self._experiment.estimator.get_velocity(self.lag) - + self.lag_vel = self._environment_state.estimator.get_velocity(self.lag) def calculate_final_vel(self): subtract_to_base = self.gain * self.lag_vel @@ -329,7 +328,7 @@ def bout_started(self): # print("set: {} gain and {} lag".format(self.gain, self.lag)) # refresh lag if it was changed: - self.lag_vel = self._experiment.estimator.get_velocity(self.lag) + self.lag_vel = self._environment_state.estimator.get_velocity(self.lag) class PerpendicularMotion(BackgroundStimulus, InterpolatedStimulus): @@ -338,7 +337,7 @@ class PerpendicularMotion(BackgroundStimulus, InterpolatedStimulus): """ def update(self): - y, x, theta = self._experiment.estimator.get_position() + y, x, theta = self._environment_state.estimator.get_position() if np.isfinite(theta): self.theta = theta super().update() @@ -352,7 +351,7 @@ def __init__(self, *args, **kwargs): def update(self): if self.is_tracking: - y, x, theta = self._experiment.estimator.get_position() + y, x, theta = self._environment_state.estimator.get_position() if np.isfinite(theta): self.x = x self.y = y @@ -362,7 +361,7 @@ def update(self): class FishRelativeStimulus(BackgroundStimulus): def get_transform(self, w, h, x, y): - y_fish, x_fish, theta_fish = self._experiment.estimator.get_position() + y_fish, x_fish, theta_fish = self._environment_state.estimator.get_position() if np.isnan(y_fish): return super().get_transform(w, h, x, y) rot_fish = (theta_fish - np.pi / 2) * 180 / np.pi diff --git a/stytra/stimulation/stimuli/conditional.py b/stytra/stimulation/stimuli/conditional.py index 1dd6f69f..831c6c13 100644 --- a/stytra/stimulation/stimuli/conditional.py +++ b/stytra/stimulation/stimuli/conditional.py @@ -36,9 +36,9 @@ def get_dynamic_state(self): state.update(self.active.get_dynamic_state()) return state - def initialise_external(self, experiment): - super().initialise_external(experiment) - self.active.initialise_external(experiment) + def initialise_external(self, experiment, environment_state): + super().initialise_external(experiment, environment_state) + self.active.initialise_external(experiment, environment_state) def get_state(self): state = super().get_state() @@ -50,7 +50,7 @@ def start(self): self.active.start() def check_condition(self): - y, x, theta = self._experiment.estimator.get_position() + y, x, theta = self._environment_state.estimator.get_position() return not np.isnan(y) def update(self): @@ -157,10 +157,10 @@ def get_dynamic_state(self): state.update(self._stim_on.get_dynamic_state()) return state - def initialise_external(self, experiment): - super().initialise_external(experiment) - self._stim_on.initialise_external(experiment) - self._stim_off.initialise_external(experiment) + def initialise_external(self, experiment, environment_state): + super().initialise_external(experiment, environment_state) + self._stim_on.initialise_external(experiment, environment_state) + self._stim_off.initialise_external(experiment, environment_state) def get_state(self): state = super().get_state() @@ -270,8 +270,8 @@ def __init__(self, stimulus, *args, centering_stimulus=None, margin=45, **kwargs self.yc = 240 def check_condition_on(self): - y, x, theta = self._experiment.estimator.get_position() - scale = self._experiment.calibrator.mm_px ** 2 + y, x, theta = self._environment_state.estimator.get_position() + scale = self._environment_state.calibrator.mm_px ** 2 return ( x > 0 and ((x - self.xc) ** 2 + (y - self.yc) ** 2) <= self.margin / scale ) @@ -323,15 +323,15 @@ def __init__( self.yc = 240 def check_condition_on(self): - y, x, theta = self._experiment.estimator.get_position() - scale = self._experiment.calibrator.mm_px ** 2 + y, x, theta = self._environment_state.estimator.get_position() + scale = self._environment_state.calibrator.mm_px ** 2 return (not np.isnan(x)) and ( (x - self.xc) ** 2 + (y - self.yc) ** 2 <= self.margin_in / scale ) def check_condition_off(self): - y, x, theta = self._experiment.estimator.get_position() - scale = self._experiment.calibrator.mm_px ** 2 + y, x, theta = self._environment_state.estimator.get_position() + scale = self._environment_state.calibrator.mm_px ** 2 return np.isnan(x) or ( (x - self.xc) ** 2 + (y - self.yc) ** 2 > self.margin_out / scale ) diff --git a/stytra/stimulation/stimuli/external.py b/stytra/stimulation/stimuli/external.py index 42319cb5..c926c468 100644 --- a/stytra/stimulation/stimuli/external.py +++ b/stytra/stimulation/stimuli/external.py @@ -42,7 +42,7 @@ def __init__( pulse_dur_str = str(pulse_dur_ms).zfill(3) self.mex = str("shock" + amp_dac + pulse_dur_str) - def initialise_external(self, experiment): + def initialise_external(self, experiment, environment_state): """ Parameters diff --git a/stytra/stimulation/stimuli/generic_stimuli.py b/stytra/stimulation/stimuli/generic_stimuli.py index 4b28d218..9857bcea 100644 --- a/stytra/stimulation/stimuli/generic_stimuli.py +++ b/stytra/stimulation/stimuli/generic_stimuli.py @@ -1,6 +1,29 @@ import numpy as np import datetime - +import warnings +from dataclasses import dataclass + +@dataclass +class EnvironmentState: + def __init__(self, calibrator = None, + estimator = None, + arduino_board = None, + asset_dir = None, + logger = None, + trigger = None, + height:int = 600, + width:int = 800): + """ + Holds Environment variables to pass from the protocol runner to the stimulus + """ + self.calibrator = calibrator + self.estimator = estimator + self.arduino_board = arduino_board + self.trigger = trigger + self.asset_dir = asset_dir + self.logger = logger + self.height = height + self.width = width class Stimulus: """ Abstract class for a Stimulus. @@ -66,6 +89,7 @@ def __init__(self, duration=0.0): self._elapsed = 0.0 # time from the beginning of the stimulus self.name = "undefined" self._experiment = None + self._environment_state = None self.real_time_start = None self.real_time_stop = None @@ -111,7 +135,7 @@ def stop(self): """ pass - def initialise_external(self, experiment): + def initialise_external(self, experiment, environment_state: EnvironmentState = None): """ Make a reference to the Experiment class inside the Stimulus. This is required to access from inside the Stimulus class to the Calibrator, the Pyboard, the asset directories with movies or the motor @@ -130,7 +154,19 @@ def initialise_external(self, experiment): None """ + + if isinstance(environment_state, EnvironmentState): + self._environment_state = environment_state + else: + self._environment_state = experiment + msg = "Warning: self._experiment is deprecated use self._environment_state instead, self._experiment will be unavailable from version 1.0!" + warnings.warn(msg, FutureWarning) + warnings.warn(msg, DeprecationWarning) + + self._experiment = experiment + + class DynamicStimulus(Stimulus): @@ -251,7 +287,7 @@ def start(self): def update(self): # If trigger is set, make it end: - if self._experiment.trigger.start_event.is_set(): + if self._environment_state.trigger.start_event.is_set(): self.duration = self._elapsed @@ -289,10 +325,10 @@ def update(self): s.update() s._elapsed = self._elapsed - def initialise_external(self, experiment): - super().initialise_external(experiment) + def initialise_external(self, experiment, environment_state): + super().initialise_external(experiment, environment_state) for s in self._stim_list: - s.initialise_external(experiment) + s.initialise_external(experiment, environment_state) @property def dynamic_parameter_names(self): diff --git a/stytra/stimulation/stimuli/kinematograms.py b/stytra/stimulation/stimuli/kinematograms.py index 9f69355b..fe19abfc 100644 --- a/stytra/stimulation/stimuli/kinematograms.py +++ b/stytra/stimulation/stimuli/kinematograms.py @@ -75,8 +75,8 @@ def get_dimensions(self): ------- number of dots to display and the displacement amount in pixel coordinates """ - if self._experiment.calibrator is not None: - mm_px = self._experiment.calibrator.mm_px + if self._environment_state.calibrator is not None: + mm_px = self._environment_state.calibrator.mm_px else: mm_px = 1 diff --git a/stytra/stimulation/stimuli/visual.py b/stytra/stimulation/stimuli/visual.py index 17bb8429..311ee29d 100644 --- a/stytra/stimulation/stimuli/visual.py +++ b/stytra/stimulation/stimuli/visual.py @@ -206,7 +206,7 @@ def __init__(self, *args, video_path, framerate=None, duration=None, **kwargs): def initialise_external(self, *args, **kwargs): super().initialise_external(*args, **kwargs) - self._video_seq = pims.Video(self._experiment.asset_dir + "/" + self.video_path) + self._video_seq = pims.Video(self._environment_state.asset_dir + "/" + self.video_path) self._current_frame = self._video_seq.get_frame(self.i_frame) try: @@ -313,8 +313,8 @@ def get_tile_ranges(self, imw, imh, w, h, tr: QTransform): return range(x_start, x_end + 1), range(y_start, y_end + 1) def paint(self, p, w, h): - if self._experiment.calibrator is not None: - mm_px = self._experiment.calibrator.mm_px + if self._environment_state.calibrator is not None: + mm_px = self._environment_state.calibrator.mm_px else: mm_px = 1 @@ -396,14 +396,14 @@ def __init__(self, *args, background, background_name=None, **kwargs): self.background_name = "array {}x{}".format(*self._background.shape) self._qbackground = None - def initialise_external(self, experiment): - super().initialise_external(experiment) + def initialise_external(self, experiment, environment_state): + super().initialise_external(experiment, environment_state) # Get background image from folder: if isinstance(self._background, str): self._qbackground = qimage2ndarray.array2qimage( existing_file_background( - self._experiment.asset_dir + "/" + self._background + self._environment_state.asset_dir + "/" + self._background ) ) elif isinstance(self._background, Path): @@ -468,7 +468,7 @@ def __init__( def create_pattern(self): l = max( 2, - int(self.grating_period / (max(self._experiment.calibrator.mm_px, 0.0001))), + int(self.grating_period / (max(self._environment_state.calibrator.mm_px, 0.0001))), ) if self.wave_shape == "square": self._pattern = np.ones((l, 3), np.uint8) * self.color_1 @@ -483,8 +483,8 @@ def create_pattern(self): + (1 - w[:, None]) * np.array(self.color_2)[None, :] ).astype(np.uint8) - def initialise_external(self, experiment): - super().initialise_external(experiment) + def initialise_external(self, experiment, environment_state): + super().initialise_external(experiment, environment_state) self.create_pattern() # Get background image from folder: self._qbackground = qimage2ndarray.array2qimage(self._pattern[None, :, :]) @@ -532,7 +532,7 @@ def get_unit_dims(self, w, h): #TODO what does this thing define? """ return ( - int(self.grating_period / (max(self._experiment.calibrator.mm_px, 0.0001))), + int(self.grating_period / (max(self._environment_state.calibrator.mm_px, 0.0001))), self.barheight, ) @@ -547,7 +547,7 @@ def draw_block(self, p, point, w, h): point.y(), int( self.grating_period - / (2 * max(self._experiment.calibrator.mm_px, 0.0001)) + / (2 * max(self._environment_state.calibrator.mm_px, 0.0001)) ), self.barheight, ) @@ -654,7 +654,7 @@ def update(self): def paint(self, p, w, h): x, y = ( - (np.arange(d) - d / 2) * self._experiment.calibrator.mm_px for d in (w, h) + (np.arange(d) - d / 2) * self._environment_state.calibrator.mm_px for d in (w, h) ) self.image = np.round( np.sin( @@ -751,8 +751,8 @@ def create_pattern(self, side_len=500): self._pattern = W * self.color_1 + (1 - W) * self.color_2 self._qbackground = qimage2ndarray.array2qimage(self._pattern) - def initialise_external(self, experiment): - super().initialise_external(experiment) + def initialise_external(self, experiment, environment_state): + super().initialise_external(experiment, environment_state) self.create_pattern() def draw_block(self, p, point, w, h): @@ -933,8 +933,8 @@ def __init__( def paint(self, p, w, h): super().paint(p, w, h) - if self._experiment.calibrator is not None: - mm_px = self._experiment.calibrator.mm_px + if self._environment_state.calibrator is not None: + mm_px = self._environment_state.calibrator.mm_px else: mm_px = 1