Skip to content

Commit

Permalink
#68 new alternate constructors for front-end delay experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
drbenvincent committed Jan 5, 2019
1 parent cd396c5 commit 2d7b7a4
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions darc/designs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,14 @@ class BayesianAdaptiveDesignGeneratorDARC(DesignGeneratorABC):
'''

def __init__(self, DA=[0.], DB=DEFAULT_DB, RA=list(), RB=[100.],
RA_over_RB=list(), PA=[1.], PB=[1.],
RA_over_RB=list(), inter_reward_interval=list(),
PA=[1.], PB=[1.],
max_trials=20,
NO_REPEATS=False):
super().__init__()

self._input_type_validation(RA, DA, PA, RB, DB, PB, RA_over_RB)
self._input_value_validation(PA, PB, DA, DB, RA_over_RB)
self._input_type_validation(RA, DA, PA, RB, DB, PB, RA_over_RB, inter_reward_interval)
self._input_value_validation(PA, PB, DA, DB, RA_over_RB, inter_reward_interval)

self._DA = DA
self._DB = DB
Expand All @@ -68,6 +69,7 @@ def __init__(self, DA=[0.], DB=DEFAULT_DB, RA=list(), RB=[100.],
self._PA = PA
self._PB = PB
self._RA_over_RB = RA_over_RB
self._inter_reward_interval = inter_reward_interval
self.max_trials = max_trials
self.NO_REPEATS = NO_REPEATS

Expand Down Expand Up @@ -101,8 +103,7 @@ def get_next_design(self, model):
f'get_next_design() took: {time.time()-start_time:1.3f} seconds')
return chosen_design_named_tuple


def _input_type_validation(self, RA, DA, PA, RB, DB, PB, RA_over_RB):
def _input_type_validation(self, RA, DA, PA, RB, DB, PB, RA_over_RB, inter_reward_interval):
# NOTE: possibly not very Pythonic
assert isinstance(RA, list), "RA should be a list"
assert isinstance(DA, list), "DA should be a list"
Expand All @@ -111,6 +112,8 @@ def _input_type_validation(self, RA, DA, PA, RB, DB, PB, RA_over_RB):
assert isinstance(DB, list), "DB should be a list"
assert isinstance(PB, list), "PB should be a list"
assert isinstance(RA_over_RB, list), "RA_over_RB should be a list"
assert isinstance(inter_reward_interval,
list), "inter_reward_interval should be a list"

# we expect EITHER values in RA OR values in RA_over_RB
# assert (not RA) ^ (not RA_over_RB), "Expecting EITHER RA OR RA_over_RB as an"
Expand All @@ -120,8 +123,7 @@ def _input_type_validation(self, RA, DA, PA, RB, DB, PB, RA_over_RB):
if not RA_over_RB:
assert not RA is False, "If not providing list for RA_over_RB, we expect a list for RA"


def _input_value_validation(self, PA, PB, DA, DB, RA_over_RB):
def _input_value_validation(self, PA, PB, DA, DB, RA_over_RB, inter_reward_interval):
'''Confirm values of provided design space specs are valid'''
if np.any((np.array(PA) < 0) | (np.array(PA) > 1)):
raise ValueError('Expect all values of PA to be between 0-1')
Expand All @@ -135,6 +137,10 @@ def _input_value_validation(self, PA, PB, DA, DB, RA_over_RB):
if np.any(np.array(DB) < 0):
raise ValueError('Expecting all values of DB to be >= 0')

if np.any(np.array(inter_reward_interval) < 0):
raise ValueError(
'Expecting all values of inter_reward_interval to be >= 0')

if np.any((np.array(RA_over_RB) < 0) | (np.array(RA_over_RB) > 1)):
raise ValueError('Expect all values of RA_over_RB to be between 0-1')

Expand Down Expand Up @@ -175,6 +181,8 @@ def _generate_all_possible_designs(self, assume_discounting=True):
logging.debug(f'provided DB = {self._DB}')
logging.debug(f'provided PB = {self._PB}')
logging.debug(f'provided RA_over_RB = {self._RA_over_RB}')
logging.debug(
f'provided inter_reward_interval = {self._inter_reward_interval}')

if not self._RA_over_RB:
'''assuming we are not doing magnitude effect, as this is
Expand Down Expand Up @@ -246,6 +254,16 @@ def delayed(cls, max_trials):
return cls(max_trials=max_trials,
RA=list(100*np.linspace(0.05, 0.95, 91)))

@classmethod
def delayed_frontend_delay(cls, max_trials):
'''Defaults for a front-end delay experiment. These typically use a
fixed reward ratio.
- inter_reward_interval = RA+RB'''
return cls(max_trials=max_trials,
RA=[50.], RB=[100.],
DA=[0., 7, 30, 30*3, 30*6, 365, 365*5],
inter_reward_interval=[1, 7, 14, 30, 30*3, 30*6, 365])

@classmethod
def risky(cls, max_trials):
prob_list = [0.1, 0.25, 0.5, 0.75, 0.8, 0.9]
Expand Down

0 comments on commit 2d7b7a4

Please sign in to comment.