From da7da26710a5c78ff368d39956cff6081314c6f5 Mon Sep 17 00:00:00 2001 From: blazma Date: Mon, 21 Nov 2022 18:01:57 +0100 Subject: [PATCH] Serialize PSP Attenuation test --- hippounit/tests/test_PSPAttenuationTest.py | 53 +++++++++++++++------- 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/hippounit/tests/test_PSPAttenuationTest.py b/hippounit/tests/test_PSPAttenuationTest.py index 46e6c5e..cb7a3dc 100644 --- a/hippounit/tests/test_PSPAttenuationTest.py +++ b/hippounit/tests/test_PSPAttenuationTest.py @@ -104,6 +104,8 @@ class PSPAttenuationTest(Test): random seed for random dendritic location selection trunk_origin : list first element : name of the section from which the trunk originates, second element : position on section (E.g. ['soma[5]', 1]). If not set by the user, the end of the default soma section is used. + serialized : boolean + if True, the simulation is not parallelized """ def __init__(self, config = {}, @@ -116,7 +118,8 @@ def __init__(self, config = {}, num_of_dend_locations = 15, random_seed = 1, save_all = True, - trunk_origin = None): + trunk_origin = None, + serialized=False): observation = self.format_data(observation) @@ -144,6 +147,7 @@ def __init__(self, config = {}, self.num_of_dend_locations = num_of_dend_locations self.random_seed = random_seed + self.serialized = serialized description = "Tests how much synaptic potential attenuates from the dendrite (different distances) to the soma." @@ -480,7 +484,14 @@ def generate_prediction(self, model, verbose=False): tau2 = self.config['tau_decay'] EPSC_amp = self.config['EPSC_amplitude'] - locations, locations_distances = model.get_random_locations_multiproc(self.num_of_dend_locations, self.random_seed, dist_range, self.trunk_origin) # number of random locations , seed + if self.serialized: + locations, locations_distances = model.get_random_locations(self.num_of_dend_locations, + self.random_seed, dist_range, + self.trunk_origin) + else: + locations, locations_distances = model.get_random_locations_multiproc(self.num_of_dend_locations, + self.random_seed, dist_range, + self.trunk_origin) # number of random locations , seed #print dend_locations, actual_distances print('Dendritic locations to be tested (with their actual distances):', locations_distances) @@ -493,25 +504,35 @@ def generate_prediction(self, model, verbose=False): #print locations_weights """ run model without an input""" - pool = multiprocessing.Pool(self.npool, maxtasksperchild=1) - run_stimulus_ = functools.partial(self.run_stimulus, model, tau1 = tau1, tau2 = tau2) - traces_no_input = pool.map(run_stimulus_, locations_weights, chunksize=1) - - pool.terminate() - pool.join() - del pool + if self.serialized: + traces_no_input = [] + for locations_weight in locations_weights: + trace = self.run_stimulus(model, locations_weight, tau1=tau1, tau2=tau2) + traces_no_input.append(trace) + else: + pool = multiprocessing.Pool(self.npool, maxtasksperchild=1) + run_stimulus_ = functools.partial(self.run_stimulus, model, tau1 = tau1, tau2 = tau2) + traces_no_input = pool.map(run_stimulus_, locations_weights, chunksize=1) + pool.terminate() + pool.join() + del pool traces_dict_no_input = dict(list(i.items())[0] for i in traces_no_input) # merge list of dicts into single dict locations_weights = self.calculate_weights(traces_dict_no_input, EPSC_amp) """run model with inputs""" - pool = multiprocessing.Pool(self.npool, maxtasksperchild=1) - run_stimulus_ = functools.partial(self.run_stimulus, model, tau1 = tau1, tau2 = tau2) - traces = pool.map(run_stimulus_, locations_weights, chunksize=1) - - pool.terminate() - pool.join() - del pool + if self.serialized: + traces = [] + for locations_weight in locations_weights: + trace = self.run_stimulus(model, locations_weight, tau1=tau1, tau2=tau2) + traces.append(trace) + else: + pool = multiprocessing.Pool(self.npool, maxtasksperchild=1) + run_stimulus_ = functools.partial(self.run_stimulus, model, tau1 = tau1, tau2 = tau2) + traces = pool.map(run_stimulus_, locations_weights, chunksize=1) + pool.terminate() + pool.join() + del pool traces_dict = dict(list(i.items())[0] for i in traces) # merge list of dicts into single dict