diff --git a/README.md b/README.md
index 9db67aa..bdcf84e 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,31 @@
-# SpikingNeuralProgramForagingInsect-PNAS
-A spiking neural program for sensorimotor control during foraging in flying insects.
+# A spiking neural program for sensorimotor control during foraging in flying insects
+
+This repository contains all accompanying code and allows you to re-generate the results for the following paper:
+
+*H. Rapp, MP. Nawrot, A spiking neural program for sensorimotor control during foraging in flying insects.*
+
+If you use any part of this code for your own work, please acknowledge us by citing the above paper.
+
+If you have questions or encounter any problems while using this code base, feel free to file a GitHub issue here and we'll be in touch!
+
+# project layout
+This project uses a mixed code base of Python and MATLAB scripts. Python and BRIAN2 are used for the spiking neural network (SNN) models and their simulation. The simulation results (spike trains) are dumped as numpy pickled files (NPZ) and MATLAB (MAT) files.
+
+MATLAB is used for all learning and memory experiments, i.e. to train the multi-spike tempotron readout neuron on the dumped spike trains from the model simulations, and for most of the data analysis and figures.
+
+All script files are commented and/or self-explanatory.
+
+* `./` the root folder contains all Python and BASH scripts to run the SNN simulations that re-generate the data used for the paper (note: this requires several TB of disk space and a large amount of RAM!)
+* `olnet/models/` contains the BRIAN2 mushroom body SNN model definitions
+* `olnet/plotting/` contains matplotlib scripts to plot SNN network activity
+* `matlab/` contains all the MATLAB code for fitting the readout neuron, data analysis and figures
+
+
+# Using the model
+The BRIAN2 model definition is located in `olnet/droso_mushroombody.py` for the model without the APL neuron and in `olnet/droso_mushroombody_apl.py` for the model with the APL neuron. If you want to use our model for your own study, simply import the model definitions from these files.
+Both models use the same interface, so you can easily swap out the implementations by modifying your `import` statement (a minimal sketch is given at the end of this README). For an example of how to use it, see the `run_model` function in `mkDataSet_DrosoCustomProtocol.py`.
+
+
+# usage
+A detailed description of how to use the individual script files will be published here shortly.
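+
+# example: swapping the model implementation
+A minimal sketch of the import swap described under "Using the model". Only the two module paths are taken from this repository; the alias `mb` and anything past the import are illustrative placeholders, so consult the `run_model` function in `mkDataSet_DrosoCustomProtocol.py` for the actual interface.
+
+```python
+# Minimal sketch (illustrative, not the repository's verbatim usage):
+# both modules expose the same interface, so choosing between the model
+# with and without the APL neuron is a one-line change of the import.
+
+# import olnet.droso_mushroombody as mb      # model without the APL neuron
+import olnet.droso_mushroombody_apl as mb    # model with the APL neuron
+
+# `mb` now provides the model definitions and can be used wherever the
+# model is constructed (cf. run_model in mkDataSet_DrosoCustomProtocol.py).
+```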
\ No newline at end of file diff --git a/figures.mplstyle b/figures.mplstyle new file mode 100644 index 0000000..8e675a4 --- /dev/null +++ b/figures.mplstyle @@ -0,0 +1,9 @@ +axes.linewidth : 1 +xtick.labelsize : 8 +ytick.labelsize : 8 +axes.labelsize : 8 +lines.linewidth : 1 +lines.markersize : 2 +legend.frameon : False +legend.fontsize : 8 +axes.prop_cycle : cycler(color=['e41a1c', '377eb8', '4daf4a', '984ea3', 'c51b7d', '4d9221', '542788', '8c510a', 'b2182b', '2166ac', '01665e']) \ No newline at end of file diff --git a/make_labdcond.sh b/make_labdcond.sh new file mode 100755 index 0000000..0c04312 --- /dev/null +++ b/make_labdcond.sh @@ -0,0 +1,30 @@ +#!/bin/bash +PYTHON_BIN="python3.6" +N_CPU=12 +# default sparsity condition +$PYTHON_BIN mkDataSet_DrosoLabCondition.py --name LabCondConnectivityHighSparsityAPL_0-15-3sec \ +-N 30 --n_cpu $N_CPU -T 3 --max_pulse_duration 0.5 --min_pulse_duration 0.1 \ +--stim_noise_scale 0.004 --bg_noise_scale 0.0055 \ +--odor_ids 0 --odor_ids 15 \ +-o data/LabCondConnectivityHighSparsityAPL_0-15-3sec.mat + +$PYTHON_BIN mkDataSet_DrosoLabCondition.py --name LabCondConnectivityMediumSparsityAPL_0-15-3sec \ +-N 30 --n_cpu $N_CPU -T 3 --max_pulse_duration 0.5 --min_pulse_duration 0.1 \ +--stim_noise_scale 0.004 --bg_noise_scale 0.0055 \ +--odor_ids 0 --odor_ids 15 \ +--modelParams PNperKC=8.1 \ +-o data/LabCondConnectivityMediumSparsityAPL_0-15-3sec.mat + +$PYTHON_BIN mkDataSet_DrosoLabCondition.py --name LabCondWeightMediumSparsityAPL_0-15-3sec \ +-N 30 --n_cpu $N_CPU -T 3 --max_pulse_duration 0.5 --min_pulse_duration 0.1 \ +--stim_noise_scale 0.004 --bg_noise_scale 0.0055 \ +--odor_ids 0 --odor_ids 15 \ +--modelParams wPNKC=20 \ +-o data/LabCondWeightMediumSparsityAPL_0-15-3sec.mat + +$PYTHON_BIN mkDataSet_DrosoLabCondition.py --name LabCondConnectivityLowSparsityAPL_0-15-3sec \ +-N 30 --n_cpu $N_CPU -T 3 --max_pulse_duration 0.5 --min_pulse_duration 0.1 \ +--stim_noise_scale 0.004 --bg_noise_scale 0.0055 \ +--odor_ids 0 --odor_ids 15 \ +--modelParams PNperKC=12 \ +-o data/LabCondConnectivityLowSparsityAPL_0-15-3sec.mat diff --git a/make_paper_figures.py b/make_paper_figures.py new file mode 100644 index 0000000..11028da --- /dev/null +++ b/make_paper_figures.py @@ -0,0 +1,30 @@ +from olnet.plotting.figures import figure1 +import numpy as np +import matplotlib.pyplot as plt +fileType = "png" + +# plot LabConditioning single-trial +file = 'cache/LabCond_0-3-5-8-15-3sec/sim-odor-0-0-58.npz' +mstMATFile = 'matlab/model_cache/predictions/msp_classicalLabCond-0-15.odor-0.1-sp.1/LabCond_0-3-5-8-15-3sec.mat' +data = np.load(file)['data'][()] +figure_1 = figure1(data, t_min=1.0, t_max=1.3, orn_range=[620,680], pn_range=[0,35], cmap='seismic', + mstMatFile=mstMATFile, mstOdorIdx=0, mstTrialIdx=0, fig_size=(3.5, 6)) +figure_1.savefig("figures/system_response_labcond.{}".format(fileType), dpi=300) + +file = 'cache/PoisonPulse_0-3-5-8-15-10sec/sim-12-90.npz' +mstMATFile = 'matlab/model_cache/predictions/msp_classicalLabCond-0-15.odor-0.1-sp.1/PoisonPulse_0-3-5-8-15-10sec.mat' +data = np.load(file)['data'][()] +figure_2 = figure1(data, t_max=8, orn_range=[640,665], pn_range=[0,25], cmap='seismic',mstMatFile=mstMATFile, mstOdorIdx=0, mstTrialIdx=12) +figure_2.savefig("figures/system_response_poisson.{}".format(fileType), dpi=300) + +figure_2_alt = figure1(data, t_max=8, orn_range=-1, pn_range=[0,45], cmap='seismic',mstMatFile=mstMATFile, mstOdorIdx=0, mstTrialIdx=12) +figure_2_alt.savefig("figures/system_response_poisson_noORN.{}".format(fileType), dpi=300, 
fig_size=(4.25,8)) + +file = 'cache/GaussianCone_15-0-3-15_10sec/sim-13-27.npz' +mstMATFile = 'matlab/model_cache/predictions/msp_classicalLabCond-0-15.odor-15.1-sp.1/Gaussian_15-0-3-15_10sec.mat' +data = np.load(file)['data'][()] +figure_3 = figure1(data, t_max=10, orn_range=-1, pn_range=[0,35], cmap='seismic',mstMatFile=mstMATFile, mstOdorIdx=15, mstTrialIdx=13) +figure_3.savefig("figures/system_response_gaussian.{}".format(fileType), dpi=300) + +figure_3_alt = figure1(data, t_max=10, orn_range=-1, pn_range=[0,35], cmap='seismic') +figure_3_alt.savefig("figures/system_response_gaussian_noMST.{}".format(fileType), dpi=300) \ No newline at end of file diff --git a/make_sequences.sh b/make_sequences.sh new file mode 100755 index 0000000..6c7022d --- /dev/null +++ b/make_sequences.sh @@ -0,0 +1,148 @@ +#!/bin/bash +PYTHON_BIN="python3.6" +N_CPU=12 + +function defaultSparsity { +# 2% sparsity - default condition +$PYTHON_BIN mkDataSet_DrosoArtificialStim.py --name 'GaussianConnectivityHighSparsityAPL_15-0-3-15_10sec' \ +-N 50 --odor_ids 0 --odor_ids 3 --odor_ids 15 --n_cpu $N_CPU -T 10 --stimulus_dt 5 --bg_noise_scale 0.0055 --pulse_rate 14 \ +--gaussian 1 --gauss_mean 5 --gauss_std 1.5 --gauss_primary_odor_id 15 --gauss_rate_other 5 \ +-o data/GaussianConnectivityHighSparsityAPL_15-0-3-15_10sec.mat + +$PYTHON_BIN mkDataSet_DrosoArtificialStim.py --name PoisonPulseConnectivityHighSparsityAPL_0-3-15-10sec \ +-N 50 --n_cpu $N_CPU --pulse_rate 8 -T 10 --stim_noise_scale 0.004 --bg_noise_scale 0.0055 --min_pulse_duration 0.001 --max_pulse_duration 0.2 \ +--odor_ids 0 --odor_ids 3 --odor_ids 15 \ +--gaussian 0 --gauss_mean 5 --gauss_std 1.5 --gauss_primary_odor_id 15 --gauss_rate_other 5 \ +-o data/PoisonPulseConnectivityHighSparsityAPL_0-3-15-10sec.mat + +$PYTHON_BIN mkDataSet_DrosoArtificialStim.py --name PoisonPulseConnectivityHighSparsityAPL_0-3-5-15-10sec \ +-N 50 --n_cpu $N_CPU --pulse_rate 8 -T 10 --stim_noise_scale 0.004 --bg_noise_scale 0.0055 --min_pulse_duration 0.001 --max_pulse_duration 0.2 \ +--odor_ids 0 --odor_ids 3 --odor_ids 5 --odor_ids 15 \ +--gaussian 0 --gauss_mean 5 --gauss_std 1.5 --gauss_primary_odor_id 15 --gauss_rate_other 5 \ +-o data/PoisonPulseConnectivityHighSparsityAPL_0-3-5-15-10sec.mat + +$PYTHON_BIN mkDataSet_DrosoArtificialStim.py --name PoisonPulseConnectivityHighSparsityAPL_0-3-5-8-15-10sec \ +-N 50 --n_cpu $N_CPU --pulse_rate 8 -T 10 --stim_noise_scale 0.004 --bg_noise_scale 0.0055 --min_pulse_duration 0.001 --max_pulse_duration 0.2 \ +--odor_ids 0 --odor_ids 3 --odor_ids 5 --odor_ids 8 --odor_ids 15 \ +--gaussian 0 --gauss_mean 5 --gauss_std 1.5 --gauss_primary_odor_id 15 --gauss_rate_other 5 \ +-o data/PoisonPulseConnectivityHighSparsityAPL_0-3-5-8-15-10sec.mat + +$PYTHON_BIN mkDataSet_DrosoArtificialStim.py --name PoisonPulseConnectivityHighSparsityAPL_0-3-8-15-10sec \ +-N 50 --n_cpu $N_CPU --pulse_rate 8 -T 10 --stim_noise_scale 0.004 --bg_noise_scale 0.0055 --min_pulse_duration 0.001 --max_pulse_duration 0.2 \ +--odor_ids 0 --odor_ids 3 --odor_ids 8 --odor_ids 15 \ +--gaussian 0 --gauss_mean 5 --gauss_std 1.5 --gauss_primary_odor_id 15 --gauss_rate_other 5 \ +-o data/PoisonPulseConnectivityHighSparsityAPL_0-3-8-15-10sec.mat +} + +function connectivityMediumSparsity { +# 5% sparsity +$PYTHON_BIN mkDataSet_DrosoArtificialStim.py --name 'GaussianConnectivityMediumSparsityAPL_15-0-3-15_10sec' \ +-N 50 --odor_ids 0 --odor_ids 3 --odor_ids 15 --n_cpu $N_CPU -T 10 --stimulus_dt 5 --bg_noise_scale 0.0055 --pulse_rate 14 \ +--gaussian 1 --gauss_mean 5 --gauss_std 
1.5 --gauss_primary_odor_id 15 --gauss_rate_other 5 \ +--modelParams PNperKC=8.1 \ +-o data/GaussianConnectivityMediumSparsityAPL_15-0-3-15_10sec.mat + +$PYTHON_BIN mkDataSet_DrosoArtificialStim.py --name PoisonPulseConnectivityMediumSparsityAPL_0-3-15-10sec \ +-N 50 --n_cpu $N_CPU --pulse_rate 8 -T 10 --stim_noise_scale 0.004 --bg_noise_scale 0.0055 --min_pulse_duration 0.001 --max_pulse_duration 0.2 \ +--odor_ids 0 --odor_ids 3 --odor_ids 15 \ +--gaussian 0 --gauss_mean 5 --gauss_std 1.5 --gauss_primary_odor_id 15 --gauss_rate_other 5 \ +--modelParams PNperKC=8.1 \ +-o data/PoisonPulseConnectivityMediumSparsityAPL_0-3-15-10sec.mat + +$PYTHON_BIN mkDataSet_DrosoArtificialStim.py --name PoisonPulseConnectivityMediumSparsityAPL_0-3-5-15-10sec \ +-N 50 --n_cpu $N_CPU --pulse_rate 8 -T 10 --stim_noise_scale 0.004 --bg_noise_scale 0.0055 --min_pulse_duration 0.001 --max_pulse_duration 0.2 \ +--odor_ids 0 --odor_ids 3 --odor_ids 5 --odor_ids 15 \ +--gaussian 0 --gauss_mean 5 --gauss_std 1.5 --gauss_primary_odor_id 15 --gauss_rate_other 5 \ +--modelParams PNperKC=8.1 \ +-o data/PoisonPulseConnectivityMediumSparsityAPL_0-3-5-15-10sec.mat + +$PYTHON_BIN mkDataSet_DrosoArtificialStim.py --name PoisonPulseConnectivityMediumSparsityAPL_0-3-5-8-15-10sec \ +-N 50 --n_cpu $N_CPU --pulse_rate 8 -T 10 --stim_noise_scale 0.004 --bg_noise_scale 0.0055 --min_pulse_duration 0.001 --max_pulse_duration 0.2 \ +--odor_ids 0 --odor_ids 3 --odor_ids 5 --odor_ids 8 --odor_ids 15 \ +--gaussian 0 --gauss_mean 5 --gauss_std 1.5 --gauss_primary_odor_id 15 --gauss_rate_other 5 \ +--modelParams PNperKC=8.1 \ +-o data/PoisonPulseConnectivityMediumSparsityAPL_0-3-5-8-15-10sec.mat + +$PYTHON_BIN mkDataSet_DrosoArtificialStim.py --name PoisonPulseConnectivityMediumSparsityAPL_0-3-8-15-10sec \ +-N 50 --n_cpu $N_CPU --pulse_rate 8 -T 10 --stim_noise_scale 0.004 --bg_noise_scale 0.0055 --min_pulse_duration 0.001 --max_pulse_duration 0.2 \ +--odor_ids 0 --odor_ids 3 --odor_ids 8 --odor_ids 15 \ +--gaussian 0 --gauss_mean 5 --gauss_std 1.5 --gauss_primary_odor_id 15 --gauss_rate_other 5 \ +--modelParams PNperKC=8.1 \ +-o data/PoisonPulseConnectivityMediumSparsityAPL_0-3-8-15-10sec.mat +} + + +function connectivityLowSparsity { +# ~10% sparsity +$PYTHON_BIN mkDataSet_DrosoArtificialStim.py --name 'GaussianConnectivityLowSparsityAPL_15-0-3-15_10sec' \ +-N 50 --odor_ids 0 --odor_ids 3 --odor_ids 15 --n_cpu $N_CPU -T 10 --stimulus_dt 5 --bg_noise_scale 0.0055 --pulse_rate 14 \ +--gaussian 1 --gauss_mean 5 --gauss_std 1.5 --gauss_primary_odor_id 15 --gauss_rate_other 5 \ +--modelParams PNperKC=12 \ +-o data/GaussianConnectivityLowSparsityAPL_15-0-3-15_10sec.mat + +$PYTHON_BIN mkDataSet_DrosoArtificialStim.py --name PoisonPulseConnectivityLowSparsityAPL_0-3-15-10sec \ +-N 50 --n_cpu $N_CPU --pulse_rate 8 -T 10 --stim_noise_scale 0.004 --bg_noise_scale 0.0055 --min_pulse_duration 0.001 --max_pulse_duration 0.2 \ +--odor_ids 0 --odor_ids 3 --odor_ids 15 \ +--gaussian 0 --gauss_mean 5 --gauss_std 1.5 --gauss_primary_odor_id 15 --gauss_rate_other 5 \ +--modelParams PNperKC=12 \ +-o data/PoisonPulseConnectivityLowSparsityAPL_0-3-15-10sec.mat + +$PYTHON_BIN mkDataSet_DrosoArtificialStim.py --name PoisonPulseConnectivityLowSparsityAPL_0-3-5-15-10sec \ +-N 50 --n_cpu $N_CPU --pulse_rate 8 -T 10 --stim_noise_scale 0.004 --bg_noise_scale 0.0055 --min_pulse_duration 0.001 --max_pulse_duration 0.2 \ +--odor_ids 0 --odor_ids 3 --odor_ids 5 --odor_ids 15 \ +--gaussian 0 --gauss_mean 5 --gauss_std 1.5 --gauss_primary_odor_id 15 --gauss_rate_other 5 \ 
+--modelParams PNperKC=12 \ +-o data/PoisonPulseConnectivityLowSparsityAPL_0-3-5-15-10sec.mat + +$PYTHON_BIN mkDataSet_DrosoArtificialStim.py --name PoisonPulseConnectivityLowSparsityAPL_0-3-5-8-15-10sec \ +-N 50 --n_cpu $N_CPU --pulse_rate 8 -T 10 --stim_noise_scale 0.004 --bg_noise_scale 0.0055 --min_pulse_duration 0.001 --max_pulse_duration 0.2 \ +--odor_ids 0 --odor_ids 3 --odor_ids 5 --odor_ids 8 --odor_ids 15 \ +--gaussian 0 --gauss_mean 5 --gauss_std 1.5 --gauss_primary_odor_id 15 --gauss_rate_other 5 \ +--modelParams PNperKC=12 \ +-o data/PoisonPulseConnectivityLowSparsityAPL_0-3-5-8-15-10sec.mat + +$PYTHON_BIN mkDataSet_DrosoArtificialStim.py --name PoisonPulseConnectivityLowSparsityAPL_0-3-8-15-10sec \ +-N 50 --n_cpu $N_CPU --pulse_rate 8 -T 10 --stim_noise_scale 0.004 --bg_noise_scale 0.0055 --min_pulse_duration 0.001 --max_pulse_duration 0.2 \ +--odor_ids 0 --odor_ids 3 --odor_ids 8 --odor_ids 15 \ +--gaussian 0 --gauss_mean 5 --gauss_std 1.5 --gauss_primary_odor_id 15 --gauss_rate_other 5 \ +--modelParams PNperKC=12 \ +-o data/PoisonPulseConnectivityLowSparsityAPL_0-3-8-15-10sec.mat +} + + +function weightMediumSparsity { +# 5% sparsity +$PYTHON_BIN mkDataSet_DrosoArtificialStim.py --name 'GaussianWeightMediumSparsityAPL_15-0-3-15_10sec' \ +-N 50 --odor_ids 0 --odor_ids 3 --odor_ids 15 --n_cpu $N_CPU -T 10 --stimulus_dt 5 --bg_noise_scale 0.0055 --pulse_rate 14 \ +--gaussian 1 --gauss_mean 5 --gauss_std 1.5 --gauss_primary_odor_id 15 --gauss_rate_other 5 \ +--modelParams wPNKC=20 \ +-o data/GaussianWeightMediumSparsityAPL_15-0-3-15_10sec.mat + +$PYTHON_BIN mkDataSet_DrosoArtificialStim.py --name PoisonPulseWeightMediumSparsityAPL_0-3-15-10sec \ +-N 50 --n_cpu $N_CPU --pulse_rate 8 -T 10 --stim_noise_scale 0.004 --bg_noise_scale 0.0055 --min_pulse_duration 0.001 --max_pulse_duration 0.2 \ +--odor_ids 0 --odor_ids 3 --odor_ids 15 \ +--gaussian 0 --gauss_mean 5 --gauss_std 1.5 --gauss_primary_odor_id 15 --gauss_rate_other 5 \ +--modelParams wPNKC=20 \ +-o data/PoisonPulseWeightMediumSparsityAPL_0-3-15-10sec.mat + +$PYTHON_BIN mkDataSet_DrosoArtificialStim.py --name PoisonPulseWeightMediumSparsityAPL_0-3-5-15-10sec \ +-N 50 --n_cpu $N_CPU --pulse_rate 8 -T 10 --stim_noise_scale 0.004 --bg_noise_scale 0.0055 --min_pulse_duration 0.001 --max_pulse_duration 0.2 \ +--odor_ids 0 --odor_ids 3 --odor_ids 5 --odor_ids 15 \ +--gaussian 0 --gauss_mean 5 --gauss_std 1.5 --gauss_primary_odor_id 15 --gauss_rate_other 5 \ +--modelParams wPNKC=20 \ +-o data/PoisonPulseWeightMediumSparsityAPL_0-3-5-15-10sec.mat + +$PYTHON_BIN mkDataSet_DrosoArtificialStim.py --name PoisonPulseWeightMediumSparsityAPL_0-3-5-8-15-10sec \ +-N 50 --n_cpu $N_CPU --pulse_rate 8 -T 10 --stim_noise_scale 0.004 --bg_noise_scale 0.0055 --min_pulse_duration 0.001 --max_pulse_duration 0.2 \ +--odor_ids 0 --odor_ids 3 --odor_ids 5 --odor_ids 8 --odor_ids 15 \ +--gaussian 0 --gauss_mean 5 --gauss_std 1.5 --gauss_primary_odor_id 15 --gauss_rate_other 5 \ +--modelParams wPNKC=20 \ +-o data/PoisonPulseWeightMediumSparsityAPL_0-3-5-8-15-10sec.mat + +$PYTHON_BIN mkDataSet_DrosoArtificialStim.py --name PoisonPulseWeightMediumSparsityAPL_0-3-8-15-10sec \ +-N 50 --n_cpu $N_CPU --pulse_rate 8 -T 10 --stim_noise_scale 0.004 --bg_noise_scale 0.0055 --min_pulse_duration 0.001 --max_pulse_duration 0.2 \ +--odor_ids 0 --odor_ids 3 --odor_ids 8 --odor_ids 15 \ +--gaussian 0 --gauss_mean 5 --gauss_std 1.5 --gauss_primary_odor_id 15 --gauss_rate_other 5 \ +--modelParams wPNKC=20 \ +-o data/PoisonPulseWeightMediumSparsityAPL_0-3-8-15-10sec.mat 
+}
\ No newline at end of file
diff --git a/matlab/MSPTempotron.m b/matlab/MSPTempotron.m
new file mode 100644
index 0000000..d7f785a
--- /dev/null
+++ b/matlab/MSPTempotron.m
@@ -0,0 +1,91 @@
+% MSPTEMPOTRON(exp_fn, ts, t_i, w, V_thresh, V_rest, tau_m, tau_s) - multi-spike tempotron neuron model
+% exp_fn: handle to the exponential function (callers pass a memoized @exp, see fit_msp_tempotron.m)
+% ts: time vector
+% t_i: input pattern as cell array of spike times for each synapse
+% w: synaptic efficiencies / weights
+% V_thresh: spike threshold potential
+% V_rest: resting potential
+% tau_m: membrane time constant
+% tau_s: synapse time constant
+
+function [v_t,t_out,t_out_idx,v_unreset,V_thresh, V_rest, V_0, tau_m, tau_s] = MSPTempotron(exp_fn, ts, t_i, w, V_thresh, V_rest, tau_m, tau_s)
+
+    % MATLAB vararg parsing boilerplate; exp_fn is the first argument,
+    % so tau_m and tau_s are arguments 7 and 8
+    if nargin < 7
+        tau_m = 0.020;
+    end
+
+    if nargin < 8
+        tau_s = 0.005;
+    end
+
+    eta = tau_m/tau_s;
+    V_0 = eta^(eta/(eta-1)) / (eta - 1); % normalizing constant for syn. currents
+
+    v_t = zeros(1, length(ts)) + V_rest; % init V(t) with resting membrane potential
+    t_out = []; % output spike times
+    t_out_idx = []; % indices of ts vector where output spikes occur
+    t_sp_idx = 1;
+
+    v_t(t_sp_idx:end) = (0 .* v_t(t_sp_idx:end)); % membrane potential V(t)
+
+    % simulate neuron: superimpose the weighted PSP kernels of all synapses
+    for i=1:length(w)
+        v_sub = msp_tempotron_kernel(exp_fn, ts, t_i{i}, tau_m, tau_s, V_0);
+        v_t = v_t + (w(i).*v_sub);
+    end
+
+    v_unreset = v_t; % save the unreset membrane potential
+
+    % determine output spike times & perform soft-reset of V(t)
+    while (~isempty(t_sp_idx))
+        % V_thresh reached: soft-reset & emit spike time
+        above = v_t > V_thresh; % 1 and 0 for event / non-event
+        crossings = diff(above);
+        idx = find(crossings > 0);
+        t_sp_idx = idx(idx ~= t_sp_idx);
+
+        if (~isempty(t_sp_idx))
+            t_sp_idx = t_sp_idx(1);
+            t_out = unique([t_out ts(t_sp_idx)]);
+            t_out_idx = unique([t_out_idx t_sp_idx]);
+            v_reset = (V_thresh .* exp(-(ts(t_sp_idx:end) - ts(t_sp_idx))/tau_m));% + (v_t(t_sp_idx) - V_thresh);
+            v_t(t_sp_idx:end) = v_t(t_sp_idx:end) - v_reset;
+        end
+    end
+end
+
+
+% kernel to compute the synaptic input current
+% (exp_fn is kept for signature compatibility; the built-in exp is called
+% directly via the anonymous function below)
+function [v] = msp_tempotron_kernel(exp_fn, t, t_i, tau_m, tau_s, V_0)
+    exp_kernel = @(x, tau) exp(-(x)/tau);
+    step_fn = @(x) (sign(x) + 1) / 2;
+    v = zeros(1, length(t));
+    for i=1:length(t_i)
+        tmp = step_fn(t-t_i(i)) .* (t-t_i(i));
+        v = v + (exp_kernel(tmp, tau_m) - exp_kernel(tmp, tau_s));
+    end
+
+    v = v .* V_0;
+end
+
+% vectorized variant of msp_tempotron_kernel (currently unused)
+function [v] = msp_tempotron_kernel_opt(exp_fn, t, t_i, tau_m, tau_s, V_0)
+    exp_kernel = @(x, tau) exp(-(x)/tau);
+    step_fn = @(x) (sign(x) + 1) / 2;
+    v = zeros(1, length(t));
+    if ~isempty(t_i)
+        tmp = zeros(length(t_i), length(t));
+        for i=1:length(t_i)
+            %tmp = heaviside(t-t_i(i)) .* (t-t_i(i));
+            tmp(i, :) = step_fn(t-t_i(i)) .* (t-t_i(i));
+        end
+
+        %tmp_idx = find(tmp > 0);
+        %tmp_2 = tmp;
+        %tmp_2(tmp_idx) = exp_kernel(tmp_2(tmp_idx), tau_s);
+        %tmp(tmp_idx) = exp_kernel(tmp(tmp_idx), tau_m);
+        v = v + sum(exp_kernel(tmp, tau_m) - exp_kernel(tmp, tau_s));
+        %v = v + sum(tmp - tmp_2);
+    end
+    v = v .* V_0;
+end
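+
+% ----------------------------------------------------------------------
+% Usage sketch (illustrative only; the constants mirror those used by the
+% fitting code in msp_fit_mbon_labcond.m / fit_msp_tempotron.m, while the
+% random input pattern below is made up for the example):
+%
+%   dt = 1/1000; ts = 0:dt:1;               % 1 s trial at 1 ms resolution
+%   N_syn = 10;                             % number of input synapses
+%   t_i = arrayfun(@(k) sort(rand(1, 5)), 1:N_syn, 'UniformOutput', false);
+%   w = normrnd(0, 1/N_syn, 1, N_syn);      % weight init as in the fitting code
+%   memo_exp = memoize(@exp);               % memoized exp, as in fit_msp_tempotron
+%   [v_t, t_out] = MSPTempotron(memo_exp, ts, t_i, w, 1, 0, 0.015, 0.005);
+%   plot(ts, v_t); hold on;                 % membrane trace ...
+%   plot(t_out, ones(size(t_out)), 'rv');   % ... and output spike times
+% ----------------------------------------------------------------------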
diff --git a/matlab/fig_differential_conditioning.m b/matlab/fig_differential_conditioning.m
new file mode 100644
index 0000000..f7dd3bc
--- /dev/null
+++ b/matlab/fig_differential_conditioning.m
@@ -0,0 +1,247 @@
+close all;
+clear all;
+odorId = 15;
+
+if exist('sparsityLevel', 'var') == 0
+    sparsityLevel = 'ConnectivityMediumSparsity';
+end
+
+if exist('dataSetName', 'var') == 0
+    %dataSetName = 'msp_classicalLabCondConnectivityLowSparsityAPL-0-15.odor-15.1-sp';
+    dataSetName = sprintf('msp_classicalLabCond%s-0-15.odor-%d.1-sp', sparsityLevel, odorId);
+end
+
+filePattern = ['model_cache/', dataSetName, '.*.mat'];
+f = dir(filePattern);
+files = {f.name};
+files = natsort(files)';
+fprintf('models: %d\n', length(files));
+col_cs_minus = [0 158 227] / 255;
+col_cs_plus = [243 146 0] / 255;
+X_train = [];
+X_test = [];
+X_behavior_resp = {}; % col 1&2: cs+/cs- response on TrainSet, col 3&4: cs+/cs- on TestSet
+N_models = length(files);
+
+batch_size = 4; % batch size over which the accuracy score is computed
+eval_behavior_at_samples = 20;
+
+eval_beh_at_epoch = floor(40 / batch_size);
+
+
+for j=1:N_models
+    data = load(sprintf('model_cache/%s', files{j}));
+    X_train = [X_train data.train_accuracy(:,1)];
+    X_test = [X_test data.validation_accuracy(:,1)];
+
+    if (length(find(data.train_accuracy == 0)) > 3)
+        warning(sprintf('outlier (prob. perfectly converged model): %s', files{j}));
+        continue;
+    end
+
+    % compute behavior response:
+    % correct CS+ response: 1 or more spikes
+    % correct CS- response: exactly 0 spikes
+    beh_resp = cell(length(data.predictions),6);
+    for i=1:length(data.predictions)
+        % compute TrainSet behavior response to CS+ and CS-
+        idx_cs = find(data.predictions{i,2} == 1);
+        idx_us = find(data.predictions{i,2} == 0);
+        idx_cs_behav = find(data.predictions{i,1} >= data.predictions{i,2});
+        idx_us_behav = find(data.predictions{i,1} == data.predictions{i,2});
+        res_vec = zeros(1, length(data.predictions{i,2}));
+        res_vec(intersect(idx_cs_behav, idx_cs)) = 1;
+        res_vec(intersect(idx_us_behav, idx_us)) = 1;
+        beh_resp{i,1} = res_vec; % correctness vector over all train samples
+        beh_resp{i,2} = data.predictions{i,1};
+        beh_resp{i,3} = data.predictions{i,2};
+
+        % compute TestSet behavior response to CS+ and CS-
+        idx_cs = find(data.predictions{i,4} == 1);
+        idx_us = find(data.predictions{i,4} == 0);
+        idx_cs_behav = find(data.predictions{i,3} >= data.predictions{i,4});
+        idx_us_behav = find(data.predictions{i,3} == data.predictions{i,4});
+        res_vec = zeros(1, length(data.predictions{i,4}));
+        res_vec(intersect(idx_cs_behav, idx_cs)) = 1;
+        res_vec(intersect(idx_us_behav, idx_us)) = 1;
+        beh_resp{i,4} = res_vec; % correctness vector over all test samples
+        beh_resp{i,5} = data.predictions{i,3}; % predictions (TEST)
+        beh_resp{i,6} = data.predictions{i,4}; % ground truth (TEST)
+    end
+    X_behavior_resp{end+1} = beh_resp;
+end
+
+N_models = length(X_behavior_resp);
+xs = 1:size(X_train,1);
+xs = xs * batch_size;
+
+fig = figure();
+% plot learning & test accuracy as a function of the training samples presented
+%subplot(2,1,1);
+hold on;
+sem = std(X_train, [], 2) / sqrt(size(X_train,2)); % s.e.m. across models
+mu_plus_sem = mean(X_train, 2) + sem;
+mu_minus_sem = mean(X_train, 2) - sem;
+plot(0, 50, 'sk', 'MarkerFaceColor', [0 0 0]);
+h1 = plot(xs, mean(X_train, 2), 'k', 'LineWidth', 1.5);
+h2 = plot(xs, mu_plus_sem, 'Color', [.3 .3 .3]);
+h3 = plot(xs, mu_minus_sem, 'Color', [.3 .3 .3]);
+
+%fill([xs; fliplr(xs)]', [mu_minus_sem, fliplr(mu_plus_sem)], 'g'); % fill area defined by x & yy in blue
+
+%fill(xs, [mean(X_train, 2) - sem, mean(X_train, 2) + sem], 0, ...
+% 'FaceColor', [.3 .3 .3]); +%plot(xs, mean(X_train, 2) - sem, 'Color', [.3 .3 .3]); + +%sem = std(X_test, [], 1); +%plot(xs, mean(X_test, 2), 'b'); +%plot([eval_behavior_at_samples eval_behavior_at_samples], [0 100], '-.k'); +%legend({'train', 'test', 'behavior response evaluation'},'Location','SouthEast'); +%legend({'train', 'test'},'Location','SouthEast'); +leg = legend([h1, h2],{'mean', 's.e.m.'}, 'Location', 'best'); +leg.ItemTokenSize = [10,5]; + +%title(sprintf('training (N=%d models)', size(X_train,2))); +%plot(xs, median(X_train, 2) + sem, 'k'); +%plot(xs, median(X_train, 2) - sem, 'k'); +xlabel('# trials'); +ylabel('accuracy [%]'); +ylim([0 100]); +xlim([0 100]); +xticks([0 20 50 100]); +xticklabels(floor([0 20 50 100] / 2)); +ax1 = gca(); + +% compute behav. score +B_test = zeros(2, length(X_behavior_resp{1, end}{eval_beh_at_epoch, 4})); +B_train = zeros(2, length(X_behavior_resp{1, end}{eval_beh_at_epoch, 2})); +B_norm_train = zeros(2, length(X_behavior_resp{1, end}{eval_beh_at_epoch, 2})); +B_norm_test = zeros(2, length(X_behavior_resp{1, end}{eval_beh_at_epoch, 4})); +fprintf("samples in epoch %d: %d (train) | %d (test)\n", eval_beh_at_epoch, size(B_train,2), size(B_test, 2)); +fprintf("eval behavior at epoch: %d (%d samples)\n", eval_beh_at_epoch, eval_beh_at_epoch * batch_size); +for i=1:length(X_behavior_resp) + % compute behavior learning curve on train set + v_beh = X_behavior_resp{1, i}{eval_beh_at_epoch, 1}; + v_true = X_behavior_resp{1, i}{eval_beh_at_epoch, 3}; + for j=1:length(v_true) + if (v_true(j) == 1 && v_beh(j) == 1) + B_train(1,j) = B_train(1,j) + 1; + end + + if (v_true(j) == 1) + B_norm_train(1,j) = B_norm_train(1,j) + 1; + end + + if (v_true(j) == 0 && v_beh(j) == 1) + B_train(2,j) = B_train(2,j) + 1; + end + + if (v_true(j) == 0) + B_norm_train(2,j) = B_norm_train(2,j) + 1; + end + end + % compute behavior learning curve on test set + v_beh = X_behavior_resp{1, i}{eval_beh_at_epoch, 4}; + v_true = X_behavior_resp{1, i}{eval_beh_at_epoch, 6}; + for j=1:length(v_true) + if (v_true(j) == 1 && v_beh(j) == 1) + B_test(1,j) = B_test(1,j) + 1; + end + + if (v_true(j) == 1) + B_norm_test(1,j) = B_norm_test(1,j) + 1; + end + + if (v_true(j) == 0 && v_beh(j) == 1) + B_test(2,j) = B_test(2,j) + 1; + end + + if (v_true(j) == 0) + B_norm_test(2,j) = B_norm_test(2,j) + 1; + end + end +end +%B_train = B_train / length(X_behavior_resp); +%B_test = B_test / length(X_behavior_resp); +B_train = B_train ./ B_norm_train; +B_test = B_test ./ B_norm_test; + + +fig2 = figure(); +%subplot(2,1,2); +ax = gca; +ax2 = ax; +hold on; +%title('behavior response (train)'); +plot(1, 0, '-s', 'MarkerFaceColor', col_cs_plus, 'MarkerEdgeColor', col_cs_plus); +plot(1, 100, '-s', 'MarkerFaceColor', col_cs_minus, 'MarkerEdgeColor', col_cs_minus); + +h1 = plot(2:size(B_train,2), B_train(1,2:end) * 100, '-o', ... + 'Color', col_cs_plus, 'MarkerFaceColor', col_cs_plus, ... + 'MarkerEdgeColor', [1 1 1]); +h2 = plot(2:size(B_train,2), B_train(2,2:end) * 100, '-d', ... + 'Color', col_cs_minus, 'MarkerFaceColor', col_cs_minus, ... 
+ 'MarkerEdgeColor', [1 1 1]); + +legend([h1,h2], {'CS+', 'CS-'},'Location','SouthEast'); +xlabel('trial'); +ylabel('correct responders [%]'); +ylim([0 100]); +xlim([0, eval_behavior_at_samples]); +xticks([1 5 10 15 20]); +ax.XRuler.MinorTick = 'on'; +ax.XRuler.MinorTickValues = 1:1:20; +ax.XRuler.MinorTickValuesMode = 'manual'; + +%axp = get(gca,'position'); +%axp(1) = 1.1 * axp(1); +%set(gca, 'Position', axp); + +%yh = get(gca,'ylabel'); % handle to the label object +%p = get(yh,'position'); % get the current position property +%p(1) = 0.9*p(1) ; % double the distance, + % negative values put the label below the axis +%set(yh,'position',p); % set the new position + +%subplot(1,3,3); +%hold on; +%title('behavior response (test)'); +%plot(1:size(B_test,2), B_test(1,:) * 100, '-or'); +%plot(1:size(B_test,2), B_test(2,:) * 100, '-db'); +%legend({'CS+', 'CS-'},'Location','West'); +%xlabel('trial'); +%ylabel('% correct'); +%ylim([0 100]); + + +%single-col figure: 85 mm + +% model fitting learning curve +set(ax1,'Units','centimeters','Position',[1 1 3 9]); +fig.Units = 'centimeters'; +fig.Position(3) = 4.5; % 4.5 ; %10.5; +fig.Position(4) = 10.5; +set(fig.Children, ... + 'FontName', 'Arial', ... + 'FontSize', 8); +set(fig, 'DefaultFigureRenderer', 'painters'); +fig.PaperPositionMode = 'auto'; +set(fig, 'PaperUnits', 'centimeters', 'Units', 'centimeters'); +set(fig, 'PaperSize', fig.Position(3:4), 'Units', 'centimeters'); +mkdir('../figures/', dataSetName); +print(fig, '-dpdf', ['../figures/', dataSetName, '/fig_learning.pdf']); + +% behavioral learning curve +set(ax2,'Units','centimeters','Position',[1 1 3 5]); +fig2.Units = 'centimeters'; +fig2.Position(3) = 4.5; % 4.5 ; %10.5; +fig2.Position(4) = 6.5; +set(fig2.Children, ... + 'FontName', 'Arial', ... 
+ 'FontSize', 8); +set(fig2, 'DefaultFigureRenderer', 'painters'); +fig2.PaperPositionMode = 'auto'; +set(fig2, 'PaperUnits', 'centimeters', 'Units', 'centimeters'); +set(fig2, 'PaperSize', fig.Position(3:4), 'Units', 'centimeters'); +mkdir('../figures/', dataSetName); +print(fig2, '-dpdf', ['../figures/', dataSetName, '/fig_differential_conditioning.pdf']); diff --git a/matlab/fig_evidence_trace.m b/matlab/fig_evidence_trace.m new file mode 100644 index 0000000..7157f91 --- /dev/null +++ b/matlab/fig_evidence_trace.m @@ -0,0 +1,156 @@ +plumeModelName = 'PoisonPulse_0-3-5-15-10sec'; +plumeModelName = 'GaussianLowSparsity_15-0-3-15_10sec'; +odorId = 15; +col_cs_minus = [0 158 227]; +col_cs_plus = [243 146 0]; + +predictions = load(sprintf('model_cache/predictions/msp_classicalLabCondLowSparsity-0-15.odor-%d.1-sp.5/%s.mat', odorId, plumeModelName)); +data = load(sprintf('../data/%s.mat', plumeModelName)); +rng(1365); +N_trials = 3; + +trialIdx = randsample(data.data.trial_ids, N_trials) + 1; +%trialIdx = data.data.trial_ids(1:N_trials)+1; +bgOdorIds = [3 5 15]; +colors = [[0 0 0]; col_cs_plus; [192 192 192]] / 255; +odorIdx = odorId + 1; +stim_times = data.data.stimulus_times(:,odorIdx); +T_trial = data.data.T_trial; +sp_times = predictions.sp_times'; +T_center = T_trial / 2; +tau_avg = 1.5; +tau = tau_avg; +idx_cues_model = 1; +idx_cues_true = 2; +idx_cues_bg = 3; +dt = 1/1000; +t = 1:ceil(T_trial / dt); +t = t .* dt; + +t_all = 1:ceil((T_trial * N_trials) / dt); +t_all = t_all .* dt; + +% bin spike times +X = zeros(3,N_trials,length(t)); +X_sp = {}; +for k=1:N_trials + % model prediction + sp = sp_times{trialIdx(k)}; + for i=1:length(sp) + idx = round(sp(i) / dt); + X(idx_cues_model,k,idx) = 1; + end + % true sensory cues + stim = stim_times{trialIdx(k)}; + for i=1:length(stim) + idx = round(stim(i) / dt); + X(idx_cues_true,k,idx) = 1; + end + % background / distractor cues + bg_stim = data.data.stimulus_times{trialIdx(k),bgOdorIds+1}; + for i=1:length(bg_stim) + idx = round(bg_stim(i) / dt); + X(idx_cues_bg,k,idx) = 1; + end +end + +f = figure(); +f.Renderer='Painters'; +%set(gca,'ydir','reverse') +%set(gca,'vis','on'); +%set(gca,'xtick',[],'ytick',[]); +%set(gca,'box','off'); +%set(gca,'ycolor',[.7 .7 .7],'xcolor',[.7 .7 .7]); +%set(gca,'xlim', [0 T_trial]); +%set(gca,'ylim', [0 2.5]); + +% plot trials +subplot(3,1,1); +hold on; +ax = gca; +ax.YAxis.TickLength = [0 0]; +%ax.YAxis.LineWidth = 0.0; + +offsets = [[0 0.5] ; [-0.5 0] ; [-0.5 0]]; +hndl = cell(size(X,1), 1); +for k=1:N_trials + plot([0 T_trial], [k k] + max(max(offsets)), 'Color', [0 0 0 0.2], 'LineWidth', 1); + for i=1:size(X,2) + for j=1:size(X,1) + sp_pos = t(squeeze(X(j,k,:)) == 1); + hndl{j,1} = plot([sp_pos; sp_pos], [(ones(size(sp_pos))*k) + offsets(j,1); (ones(size(sp_pos))*k) + offsets(j,2)], ... 
+ 'Color', colors(j,:), 'linewidth', 2); + end + end +end +xlim([0 T_trial]); +xticks([]); +%xlabel('time [sec]'); +ylim([max(max(offsets)) N_trials+1]); +yticks([1 N_trials]); +ylabel('casting iteration'); +try +legend([hndl{1}(1) hndl{2}(1) hndl{3}(1)], {'model output', 'sensory cue', 'background cue'}, 'Location', 'northoutside', 'NumColumns', 3); +catch +legend([hndl{1}(1) hndl{2}(1) hndl{3}(1)], {'model output', 'sensory cue', 'background cue'}, 'Location', 'northoutside'); +end + + + +% plot smoothed data +subplot(3,1,2); +hold on; +ax.YAxis.TickLength = [0 0]; +for k=1:N_trials + %plot([0 T_trial], [k k] + 1, 'Color', [0 0 0 0.2], 'LineWidth', 1); + for i=1:size(X,2) + for j =1:size(X,1)-1 + %% convolute + kernel = gausswin(round(tau / dt)); + evidence = conv(squeeze(X(j,k,:)), kernel, 'same'); + lineStyle = '-'; + if mod(j,2) == 0 + lineStyle = '--'; % make true sensory rate dashed + end + + plot(t, ((evidence / max(evidence))*0.8) + k,'color',[colors(j,:) .6],'linewidth', 1,'LineStyle', lineStyle); + end + end +end +xlim([0 T_trial]); +xticks([]); +%xlabel('time [sec]'); +ylim([max(max(offsets)) N_trials+1]); +yticks([1 N_trials]); +ylabel('accum. evidence'); + +% plot averaged density +subplot(3,1,3); +hold on; +ax.YAxis.TickLength = [0 0]; + +%% convolute +kernel = gausswin(round(tau_avg / dt)); +%kernel = kernel./sum(kernel); +avg_evidence = sum(squeeze(X(1,:,:)), 1); +evidence = conv(avg_evidence, kernel, 'same'); +plot(t, ((evidence / max(evidence))*0.9),'color',[colors(1,:) .6],'linewidth', 1,'LineStyle', '-'); + +avg_evidence = sum(squeeze(X(2,:,:)), 1); +evidence = conv(avg_evidence, kernel, 'same'); +plot(t, ((evidence / max(evidence))*0.9),'color',[colors(2,:) .6],'linewidth', 1,'LineStyle', '--'); + +plot([T_center T_center], [0 1], 'LineWidth', 1, 'LineStyle', '--', 'Color', [1 0 1 0.8]); +legend({'model', 'true evidence', 'plume center'}); + +xlim([0 T_trial]); +xticks(0:2:T_trial); +yticks([ 0 1]); +xlabel('time [sec]'); +%ylim([max(max(offsets)) N_trials+1]); +ylabel('avg. evidence'); + +%lbwh = get(gca, 'position'); +%lbwh(end) = lbwh(end)*0.5; +%set(gca, 'position', lbwh); +print(sprintf('../figures/fig_evidence_trace_%s.pdf', plumeModelName),'-dpdf','-fillpage'); diff --git a/matlab/fig_mbon_transfer_learning.m b/matlab/fig_mbon_transfer_learning.m new file mode 100644 index 0000000..50606e9 --- /dev/null +++ b/matlab/fig_mbon_transfer_learning.m @@ -0,0 +1,86 @@ +close all; +clear all; +odorId = 15; + +if exist('sparsityLevel', 'var') == 0 + sparsityLevel = 'ConnectivityMediumSparsity'; +end + +if exist('dataSetName', 'var') == 0 + %dataSetName = 'msp_classicalLabCondConnectivityLowSparsityAPL-0-15.odor-15.1-sp'; + dataSetName = sprintf('msp_classicalLabCond%s-0-15.odor-%d.1-sp', sparsityLevel, odorId); +end + + +filePattern = ['model_cache/predictions/', dataSetName, '.*']; +f = dir(filePattern); +files = {f.name}; +files = natsort(files)'; +fprintf('models: %d\n', length(files)); +N_models = length(files); + +% eval on those data-sets +% label, dataSet +dataSets = { + {'CS+/CS-/1 bg. odor', ['PoisonPulse', sparsityLevel, '_0-3-15-10sec.mat'], '#1'}, + {'CS+/CS-/2 bg. odors (distinct)', ['PoisonPulse', sparsityLevel, '_0-3-5-15-10sec.mat'], '#2'}, + {'CS+/CS-/2 bg. odors (distinct & CS+ similar)', ['PoisonPulse', sparsityLevel, '_0-3-8-15-10sec.mat'], '#3'}, + {'CS+/CS-/3 bg. 
odors (2 distinct & CS+ similar)', ['PoisonPulse', sparsityLevel, '_0-3-5-8-15-10sec.mat'], '#4'} +}; + +results = cell(N_models, length(dataSets)); +x_accu = []; +x_accu_std = []; +err_low = []; +err_high = []; +labels = cell(1,length(dataSets)); +colors = colormap(lines(length(dataSets))); + +for i=1:N_models + for j=1:length(dataSets) + fileName = sprintf('model_cache/predictions/%s/%s', files{i}, dataSets{j}{2}); + data = load(fileName); + results{i,j} = data.accu; + labels{j} = dataSets{j}{3}; + end +end + + +c = categorical(labels); +c = reordercats(c, labels); +for j=1:length(dataSets) + x_accu = [x_accu mean([results{:,j}])]; + x_accu_std = [x_accu_std std([results{:,j}])]; + err_low = [err_low min([results{:,j}])]; + err_high = [err_high max([results{:,j}])]; + + x = 1:length(dataSets); + %b = bar(c(j),x_accu(j),'FaceColor',colors(j,:)); + b = bar(c(j), double(x_accu(j)), 'FaceColor', [.4 .4 .4]); + hold on; +end + +%title('transfer learning: seq. task'); +er = errorbar(x,x_accu,x_accu_std,[]); +er.Color = [0 0 0]; +er.LineStyle = 'none'; +ylim([0 100]); +ylabel('avg. accuracy'); +yticks([0 10 25 50 80 100]); +set(gca,'box','off'); + +set(gca,'Units','centimeters','Position',[1 1 2 3]); + +fig = gcf; +fig.Units = 'centimeters'; +fig.Position(3) = 3; %10.5; +fig.Position(4) = 4; +set(fig.Children, ... + 'FontName', 'Arial', ... + 'FontSize', 8); +set(fig, 'DefaultFigureRenderer', 'painters'); +fig.PaperPositionMode = 'auto'; +set(fig, 'PaperUnits', 'centimeters', 'Units', 'centimeters'); +set(fig, 'PaperSize', fig.Position(3:4), 'Units', 'centimeters'); +mkdir('../figures/', dataSetName); +print(gcf, '-dpdf', ['../figures/', dataSetName, '/fig_transfer_learning.pdf']); \ No newline at end of file diff --git a/matlab/fit_all_revision.m b/matlab/fit_all_revision.m new file mode 100644 index 0000000..87a708e --- /dev/null +++ b/matlab/fit_all_revision.m @@ -0,0 +1,39 @@ +target_odor_id = 15; +n_models = 50; +n_train_samples = -1; + +train_instances = { + 'ConnectivityMediumSparsityAPL', + 'ConnectivityLowSparsityAPL', + 'ConnectivityHighSparsityAPL' + %'WeightMediumSparsityAPL' +}; + +data_sets = { + 'Gaussian%s_15-0-3-15_10sec.mat', + 'PoisonPulse%s_0-3-15-10sec.mat', + 'PoisonPulse%s_0-3-5-15-10sec.mat', + 'PoisonPulse%s_0-3-5-8-15-10sec.mat', + 'PoisonPulse%s_0-3-8-15-10sec.mat' +}; + +for i=1:length(train_instances) + variant = train_instances{i}; + [w, train_loss, test_loss, w_init] = msp_fit_mbon_labcond(['classicalLabCond', variant, '-0-15'], ... + ['../data/LabCond', variant, '_0-15-3sec.mat'], ... + 'n_samples', n_train_samples, 'n_epochs', 1, 'optimizer', 'rmsprop', 'split', 0.25, 'learn_rate', 0.0005, ... + 'n_models', n_models, 'spikes_per_reward', 1, 'target_odor_id', target_odor_id, 'odor_ids', [0,15]); +end + + +for i=1:length(train_instances) + variant = train_instances{i}; + for j=1:length(data_sets) + dataSet = data_sets{j}; + [w, train_loss, test_loss, w_init] = msp_fit_mbon_labcond(['classicalLabCond', variant, '-0-15'], ... + ['../data/LabCond', variant, '_0-15-3sec.mat'], ... + 'n_epochs', 1, 'optimizer', 'rmsprop', 'split', 0.25, 'learn_rate', 0.0005, ... + 'n_models', n_models, 'spikes_per_reward', 1, 'target_odor_id', target_odor_id, 'odor_ids', [0,15], ... 
+ 'predictDataSet', sprintf(['../data/', dataSet], variant), 'predictWeightIdx', -1); + end +end \ No newline at end of file diff --git a/matlab/fit_msp_tempotron.m b/matlab/fit_msp_tempotron.m new file mode 100644 index 0000000..b9ece5b --- /dev/null +++ b/matlab/fit_msp_tempotron.m @@ -0,0 +1,146 @@ +% FIT_MSP_TEMPOTRON(ts, trials, labels, w, V_thresh, V_rest, tau_m, tau_s, lr, n_iter, optimizer, fn_target) +% train multi-spike tempotron on given trials and labels +% ts: time vector +% trials: cell array of trials. Each entry is a cell array of input spike times +% labels: labels (cumulative number of output spikes) for each trial +% w: synaptic efficiencies / weights +% V_thresh: spiking threshold of neuron model (see MSPTempotron) +% V_rest: resting potential of neuron model (see MSPTempotron) +% tau_m: membrane time constant of neuron model (see MSPTempotron) +% tau_s: synapse time constant of neuron model (see MSPTempotron) +% lr: learning rate parameter +% n_iter: total number of iterations performed +% optimizer: one of 'sgd', 'adagrad', 'rmsprop', 'adam' +% fn_target: function handle to custom error function with signature fn(sample_idx, t_out, target_cum_reward) + +function [w, t_crit, dv_dw, errs, outputs, w_hist, anneal_lr, t_adam] = fit_msp_tempotron(ts, trials, labels, w, V_thresh, V_rest, tau_m, tau_s, lr, n_iter, optimizer, fn_target) + + if nargin < 12 + fn_target = []; + end + + if nargin < 11 + optimizer = 'rmsprop'; + end + + dataFormatType = iscell(trials{1}); + if dataFormatType == 0 + % this means, data is formated as cell array with spikes times as + % columns (per synapse) + N_syn = size(trials(1,:), 2); + else + N_syn = length(trials{1}); + end + + errs = []; + memo_exp = memoize(@exp); + memo_exp.CacheSize = N_syn*10; + outputs = zeros(1, size(trials, 1)); + d_momentum = zeros(1, N_syn); + t_crit = 0; + dv_dw = []; + w_hist = []; + grad_cache = zeros(1, N_syn); %adagrad / RMSprop gradient cache + eps = 10^-6; + momentum_mu = 0.99; % momentum hyper param + rms_decay_rate = 0.99; % RMSprop leak + lr_step = 100; + lr_step_size = 0.001; % annealing step size + lr_min = 0.0001; + anneal_lr = lr; + + % ADAM hyper params + beta1 = 0.9; + beta2 = 0.999; + m = grad_cache; + v = grad_cache; + t_adam = max(1, n_iter); + + shuffle_idx = randperm(size(trials, 1)); + profile_start = tic; + for i=1:size(trials,1) + % determine format of pattern + if dataFormatType == 0 + pattern = cell(trials(i,:)); + else + pattern = trials{i}; + end + + target = labels(i); + + if mod(i, 10) == 0 + tElapsed = toc(profile_start); + %disp(sprintf(' trial %d [%.3f sec]', i, tElapsed)); + profile_start = tic; + end + + [v_t, t_out, t_out_idx, v_unreset, ~, ~, V_0, tau_m, tau_s] = MSPTempotron(memo_exp, ts, pattern, w, V_thresh, V_rest, tau_m, tau_s); + outputs(i) = length(t_out); + % keep track on errors + if (~isempty(fn_target)) + %err = fn_target(shuffle_idx(i), t_out, labels(shuffle_idx(i))) - outputs(shuffle_idx(i)); + err = fn_target(i, t_out, target); + %disp(sprintf(' err=%d out=%d target=%d', err, outputs(shuffle_idx(i)), labels(shuffle_idx(i)))); + else + err = target - length(t_out); + end + + if (any(isnan(v_t))) + error('NaNs !!!'); + end + + if (mod(t_adam, lr_step) == 0) + anneal_lr = max(anneal_lr - lr_step_size, lr_min); + end + + if (err ~= 0) % perform training only on error trial + %disp(sprintf(' trial %d | %d -> %d | %d | %.2f | %.2f ', i, outputs(i), labels(i), errs(i), norm(w), mean(v_t))); + + t_adam = t_adam + 1; + errs = [errs err]; + + [pks, pks_idx, t_crit, d_w, 
dw_dir, dv_dw] = msp_grad(V_0, V_thresh, pattern, w, ts, v_t, v_unreset, t_out, t_out_idx, err, tau_m, tau_s); + + if strcmpi(optimizer, 'adagrad') == 1 + % ADAgrad optimizer + %disp('** adagrad'); + grad_cache = grad_cache + d_w.^2; + delta = (((dw_dir * lr) .* d_w) ./ (sqrt(grad_cache) + eps)); + elseif strcmpi(optimizer, 'rmsprop') == 1 + % RMSprop + %disp('** RMSprop'); + grad_cache = rms_decay_rate .* grad_cache + (1 - rms_decay_rate) .* d_w.^2; + delta = (((dw_dir * lr) .* d_w) ./ (sqrt(grad_cache) + eps)); + + elseif strcmpi(optimizer, 'adam') == 1 + % ADAM + %disp('** ADAM'); + m = beta1 .* m + (1-beta1) .* (dw_dir .* d_w); + mt = m ./ (1-beta1.^t_adam); + v = beta2 .* v + (1-beta2) .* (d_w.^2); + vt = v ./ (1-beta2.^t_adam); + delta = (((dw_dir * lr) .* mt) ./ (sqrt(vt) + eps)); + elseif strcmpi(optimizer, 'nesterov') == 1 + %disp('** Nesterov'); + % Nesterov Momentum + error("nesterov momentum not yet implemented - please use vanilla momentum for now."); + %d = (lr .* d_w) + (momentum_mu .* d); + %delta = (dw_dir .* d); + elseif strcmpi(optimizer, 'momentum') == 1 + %disp('** Momentum'); + % Momentum + d_momentum = ((dw_dir * lr) .* d_w) + (momentum_mu .* d_momentum); + delta = d_momentum; + + else + %default: vanilla SGD + %disp('** SGD'); + delta = ((dw_dir * lr) .* d_w); % regular gradient-based learning + end + + % update weights + w = w + delta; + w_hist = [w_hist; w]; + end + end +end \ No newline at end of file diff --git a/matlab/gen_all_models.m b/matlab/gen_all_models.m new file mode 100644 index 0000000..4cf3dff --- /dev/null +++ b/matlab/gen_all_models.m @@ -0,0 +1,60 @@ +n_models = 100; + +% classical (lab) cond. on single pulses: target odor 0 +[w, train_loss, test_loss, w_init] = msp_fit_mbon_labcond('classicalLabCond-0-15', '../data/LabCond_0-3-5-8-15-3sec.mat', 'n_epochs', 1, 'optimizer', 'rmsprop', 'split', 0.25, 'learn_rate', 0.0005, 'n_models', n_models, 'spikes_per_reward', 1, 'target_odor_id', 0, 'odor_ids', [0,15]); +% target odor: 15 +[w, train_loss, test_loss, w_init] = msp_fit_mbon_labcond('classicalLabCond-0-15', '../data/LabCond_0-3-5-8-15-3sec.mat', 'n_epochs', 1, 'optimizer', 'rmsprop', 'split', 0.25, 'learn_rate', 0.0005, 'n_models', n_models, 'spikes_per_reward', 1, 'target_odor_id', 15, 'odor_ids', [0,15]); + +% target odor: 0 sparsity 5% +[w, train_loss, test_loss, w_init] = msp_fit_mbon_labcond('classicalLabCondMediumSparsity-0-15', '../data/LabCondMediumSparsity_0-3-5-8-15-3sec.mat', 'n_epochs', 1, 'optimizer', 'rmsprop', 'split', 0.25, 'learn_rate', 0.0005, 'n_models', n_models, 'spikes_per_reward', 1, 'target_odor_id', 0, 'odor_ids', [0,15]); +% target odor: 0 sparsity 10% +[w, train_loss, test_loss, w_init] = msp_fit_mbon_labcond('classicalLabCondLowSparsity-0-15', '../data/LabCondLowSparsity_0-3-5-8-15-3sec.mat', 'n_epochs', 1, 'optimizer', 'rmsprop', 'split', 0.25, 'learn_rate', 0.0005, 'n_models', n_models, 'spikes_per_reward', 1, 'target_odor_id', 0, 'odor_ids', [0,15]); + + + +sequenceDataSets = { + '../data/PoisonPulse_0-3-15-10sec.mat', + '../data/PoisonPulse_0-3-5-15-10sec.mat', + '../data/PoisonPulse_0-3-5-8-15-10sec.mat', + '../data/PoisonPulse_0-3-8-15-10sec.mat', + '../data/PoisonPulseMediumSparsity_0-3-8-15-10sec.mat', + '../data/PoisonPulseLowSparsity_0-3-8-15-10sec.mat' +}; + +predictionDataSets = { + '../data/PoisonPulse_0-3-15-10sec.mat', + '../data/PoisonPulse_0-3-5-15-10sec.mat', + '../data/PoisonPulse_0-3-5-8-15-10sec.mat', + '../data/PoisonPulse_0-3-8-15-10sec.mat', + '../data/Gaussian_15-0-3-15_10sec.mat', + 
'../data/PoisonPulseMediumSparsity_0-3-8-15-10sec.mat',
+    '../data/PoisonPulseLowSparsity_0-3-8-15-10sec.mat',
+    '../data/GaussianMediumSparsity_15-0-3-15_10sec.mat',
+    '../data/GaussianLowSparsity_15-0-3-15_10sec.mat'
+};
+
+for k=1:length(predictionDataSets)
+
+    if k <= 5
+        %%%% predictions for regular sparsity levels
+        % target odor: 0
+        [w, train_loss, test_loss, w_init] = msp_fit_mbon_labcond('classicalLabCond-0-15', '../data/LabCond_0-3-5-8-15-3sec.mat', 'n_epochs', 1, 'optimizer', 'rmsprop', 'split', 0.25, 'learn_rate', 0.0005, 'n_models', n_models, 'spikes_per_reward', 1, 'target_odor_id', 0, 'odor_ids', [0,15], 'predictDataSet', predictionDataSets{k});
+
+        % target odor: 15
+        [w, train_loss, test_loss, w_init] = msp_fit_mbon_labcond('classicalLabCond-0-15', '../data/LabCond_0-3-5-8-15-3sec.mat', 'n_epochs', 1, 'optimizer', 'rmsprop', 'split', 0.25, 'learn_rate', 0.0005, 'n_models', n_models, 'spikes_per_reward', 1, 'target_odor_id', 15, 'odor_ids', [0,15], 'predictDataSet', predictionDataSets{k});
+    else
+        %%%% predictions for Low/Medium sparsity levels
+        % target odor: 0
+        [w, train_loss, test_loss, w_init] = msp_fit_mbon_labcond('classicalLabCondMediumSparsity-0-15', '../data/LabCondMediumSparsity_0-3-5-8-15-3sec.mat', 'n_epochs', 1, 'optimizer', 'rmsprop', 'split', 0.25, 'learn_rate', 0.0005, 'n_models', n_models, 'spikes_per_reward', 1, 'target_odor_id', 0, 'odor_ids', [0,15], 'predictDataSet', predictionDataSets{k});
+    end
+end
+
+% fit models on the sequential task
+for m=1:length(sequenceDataSets)
+    dataFile = sequenceDataSets{m};
+    modelName = dataFile(9:end-4);
+    % target odor: 0
+    [w, train_loss, test_loss, w_init] = msp_fit_mbon_task(modelName, dataFile, 'n_epochs', 15, 'optimizer', 'rmsprop', 'split', 0.2, 'learn_rate', 0.0005, 'n_models', n_models, 'spikes_per_reward', 1, 'target_odor_id', 0, 'odor_ids', [0,15]);
+    % target odor: 15
+    [w, train_loss, test_loss, w_init] = msp_fit_mbon_task(modelName, dataFile, 'n_epochs', 15, 'optimizer', 'rmsprop', 'split', 0.2, 'learn_rate', 0.0005, 'n_models', n_models, 'spikes_per_reward', 1, 'target_odor_id', 15, 'odor_ids', [0,15]);
+
+end
diff --git a/matlab/msp_fit_mbon_labcond.m b/matlab/msp_fit_mbon_labcond.m
new file mode 100644
index 0000000..fffd31e
--- /dev/null
+++ b/matlab/msp_fit_mbon_labcond.m
@@ -0,0 +1,216 @@
+function [w_out, train_losses, validation_losses, w_init] = msp_fit_mbon_labcond(modelName, dataSetFileName, varargin)
+
+args = inputParser;
+defaultOptimizer='rmsprop';
+validOptimizers = {'rmsprop', 'momentum'};
+checkOptimizer = @(x) any(validatestring(x,validOptimizers));
+defaultCvMethod = 'traintest';
+validCvMethods = {'traintest', 'kfold'};
+checkCvMethod = @(x) any(validatestring(x,validCvMethods));
+
+addRequired(args,'modelName',@ischar);
+addRequired(args,'dataSetFileName',@ischar);
+
+addParameter(args, 'odor_ids', []);
+addParameter(args, 'n_samples', -1, @isnumeric);
+addParameter(args, 'batch_size', 4, @isnumeric);
+addParameter(args, 'n_models', 1, @isnumeric);
+addParameter(args, 'n_epochs', 1, @isnumeric);
+addParameter(args, 'dt', 1/1000, @isnumeric);
+addParameter(args, 'optimizer', defaultOptimizer, checkOptimizer);
+addParameter(args, 'rng_seed', 42, @isnumeric);
+addParameter(args, 'split', 0.3, @isnumeric);
+%addParameter(args, 'cv_method', defaultCvMethod, checkCvMethod);
+addParameter(args, 'learn_rate', 0.001, @isnumeric);
+addParameter(args, 'spikes_per_reward', 1, @isnumeric);
+addParameter(args, 'target_odor_id', 2, @isnumeric);
+addParameter(args, 'early_stopping', -1, @isnumeric);
+addParameter(args, 'predictDataSet', "", @ischar); +addParameter(args, 'predictWeightIdx', -1, @isnumeric); +args.KeepUnmatched = true; +parse(args,modelName, dataSetFileName, varargin{:}); + +n_samples = args.Results.n_samples; +n_epochs = args.Results.n_epochs; +odor_ids = args.Results.odor_ids; +optimizer = args.Results.optimizer; +seed = args.Results.rng_seed; +n_models = args.Results.n_models; +train_test_split = args.Results.split; +%cvMethod = args.Results.cv_method; +cvMethod = defaultCvMethod; +batchSize = args.Results.batch_size; +early_stopping_accuracy = args.Results.early_stopping; + +predictDataSet = ""; +predictWeightIdx = -1; +if ~isempty(args.Results.predictDataSet); + predictDataSet = args.Results.predictDataSet; + predictWeightIdx = args.Results.predictWeightIdx; +end + +f = load(dataSetFileName); +data = f.data; + +w_out = []; +train_losses = []; +validation_losses = []; +w_init = []; + +N_syn = size(data.trials, 2); +dt = args.Results.dt; +T = double(data.T_trial); +lr = args.Results.learn_rate; +ts = 0:dt:T; +reward_size = args.Results.spikes_per_reward; % no of spikes for individual pattern +target_idx = args.Results.target_odor_id + 1; % idx of rewarded odor source (data.rewards contains multiple counts for each odor) + +% neuron model +tau_m = 0.015; +tau_s = 0.005; +V_thresh = 1; +V_rest = 0; +rng(seed); + +if isfield(data, 'odor_ids') && ~isempty(odor_ids) + disp(['filtering for odor_ids: ' num2str(odor_ids)]); + idx = ismember(data.odor_ids, odor_ids); + data.trials = data.trials(idx, :); + data.targets = data.targets(idx, :); + data.odor_ids = data.odor_ids(idx); +end + +y = double(data.targets); +y(:, target_idx) = y(:, target_idx) .* reward_size; +y = y(:,target_idx); + +if n_samples > 0 + %c = cvpartition(y,'HoldOut', size(data.trials, 2) - n_samples) ; + rnd_idx = randperm(size(data.trials,1)); + + % shuffle + data.trials = data.trials(rnd_idx,:); + data.targets = data.targets(rnd_idx, :); + % re-partition + data.trials = data.trials(1:n_samples,:); + data.targets = data.targets(1:n_samples, :); + + if isfield(data, 'odor_ids') + data.odor_ids = data.odor_ids(rnd_idx); + data.odor_ids = data.odor_ids(1:n_samples); + end + + y = y(rnd_idx); + y = y(1:n_samples); + disp(sprintf('adjusted dataSet size: %d | rewards: %d', size(data.trials, 1), length(data.targets))); + +end + + + +if isfield(data, 'odor_ids') + % partition data - stratified HoldOut + c = cvpartition(y, 'HoldOut', train_test_split); +else + % partition data + c = cvpartition(size(data.trials,1), 'HoldOut', train_test_split); +end + + +n_batches = ceil(c.TrainSize / batchSize) - 1; +seeds = randi(98756, 1, n_models); +disp(sprintf('n_samples: %d | optimizer: %s | n_epochs: %d | cvMethod: %s | reward_size: %d | target_idx: %d | n_batches: %d', n_samples, optimizer, n_epochs, cvMethod, reward_size, target_idx, n_batches)); + +for k=1:n_models + n_iter = 0; + ctr = 1; + model_seed = seeds(k); + rng(model_seed); + w_inits = normrnd(0, 1 / N_syn, n_batches, N_syn); + w_outs = w_inits; + outFile = sprintf('model_cache/msp_%s.odor-%d.%d-sp.%d.mat', modelName, args.Results.target_odor_id, args.Results.spikes_per_reward, k); + + if exist(outFile, 'file') == 2 + disp(sprintf('model %d file exists - skipped: %s', k, outFile)); + + if exist(predictDataSet, 'file') == 2 + [rmse,accu,~,~,~] = rmse_mbon_task(outFile, predictDataSet, args.Results.target_odor_id, reward_size, predictWeightIdx); + end + continue; + end + + w_out = squeeze(w_outs(1, :))'; + w_init = squeeze(w_inits(1, :))'; + + train_losses = 
zeros(n_epochs * n_batches, 1); + train_accuracy = zeros(n_epochs * n_batches, 1); + validation_losses = zeros(n_epochs * n_batches, 1); + validation_accuracy = zeros(n_epochs * n_batches, 1); + predictions = cell(n_epochs * n_batches, 4); + + % split train data + train_data = data.trials(c.training(1), :); + y_train = y(c.training(1))'; + n_train = length(y_train); + % shuffle train set labels + shuffle_idx_train = randperm(length(y_train)); + train_data = train_data(shuffle_idx_train, :); + y_train = y_train(shuffle_idx_train); + + % split test/validation data + test_data = data.trials(c.test(1), :); + y_test = y(c.test(1))'; + n_test = length(y_test(1)); + + for i=1:n_epochs + for j=1:n_batches + b_start = 1; %(j-1)*batchSize + 1; + if (batchSize > 1) + b_end = min((j*batchSize), size(train_data,1)); + else + % edge-case: batchSize == 1 + b_end = min(j+1, size(train_data,1)); + end + [w_out, ~, ~, errs, preds, ~, ~, n_iter] = fit_msp_tempotron(ts, train_data(b_start:b_end,:), y_train(b_start:b_end), w_outs(j,:), V_thresh, V_rest, tau_m, tau_s, lr, n_iter, optimizer); + w_outs(j,:) = w_out; + loss = mean(abs(preds-y_train(b_start:b_end))); + n_correct = length(find(preds==y_train(b_start:b_end))); + train_accuracy(i*j) = (n_correct * 100)/length(y_train(b_start:b_end)); + train_losses(i*j) = loss; + predictions{i*j, 1} = preds; + predictions{i*j, 2} = y_train(b_start:b_end); + + [mean_val_loss, ~, val_preds, ~] = validate_msp_tempotron(ts, test_data, y_test, w_out, V_thresh, V_rest, tau_m, tau_s); + validation_losses(i*j) = mean_val_loss; + n_correct = length(find(val_preds==y_test)); + predictions{i*j, 3} = val_preds; + predictions{i*j, 4} = y_test; + + n_total = length(y_test); + validation_accuracy(i*j) = (n_correct * 100)/n_total; + ctr = ctr + 1; + + disp(sprintf('[%d@%s] epoch=%d|%d (%d samples) | lr=%.4f | train_loss: %.3f (train_acc: %.3f) | val_loss: %.3f (val_acc: %.3f | %d/%d)', k, optimizer, i, j, size(train_data(b_start:b_end,:), 1), lr, loss, train_accuracy(i*j), mean_val_loss, validation_accuracy(i*j), n_correct, n_total)); + + if early_stopping_accuracy > 0 && validation_accuracy(1,i*j) >= early_stopping_accuracy + disp(sprintf('[%d] early stopping @ %.3f | %s learning converged after %d epochs val_accuracy: %.3f', k, early_stopping_accuracy, optimizer, i, validation_accuracy(1,i*j))); + break; + end + + if (isempty(errs)) % all zeros => no errors, converged + disp(sprintf('[%d] %s learning converged after %d epochs (%d samples)', k, optimizer, i, i*j*batchSize)); + break; + end + end + end + + disp(sprintf('[%d] checkpoint model saving to: %s', k, outFile)); + save(outFile, 'lr', 'train_test_split', 'cvMethod', 'n_train', 'n_test', 'seed', 'train_losses', 'validation_losses', 'train_accuracy', 'validation_accuracy', 'predictions', 'w_outs', 'w_inits', 'tau_m', 'tau_s', 'V_rest', 'V_thresh', 'T', 'dt', 'ts', 'n_batches', 'batchSize', 'model_seed'); + + if exist(predictDataSet, 'file') == 2 + [rmse,accu,y,y_pred,~] = rmse_mbon_task(outFile, predictDataSet, args.Results.target_odor_id, reward_size, predictWeightIdx); + end +end + + +end \ No newline at end of file diff --git a/matlab/msp_fit_mbon_task.m b/matlab/msp_fit_mbon_task.m new file mode 100644 index 0000000..131977b --- /dev/null +++ b/matlab/msp_fit_mbon_task.m @@ -0,0 +1,192 @@ +function [w_out, train_losses, validation_losses, w_init] = msp_fit_mbon_task(modelName, dataSetFileName, varargin) + +args = inputParser; +defaultOptimizer='rmsprop'; +validOptimizers = {'rmsprop', 'momentum'}; +checkOptimizer = @(x) 
any(validatestring(x,validOptimizers)); +defaultCvMethod = 'traintest'; +validCvMethods = {'traintest', 'kfold'}; +checkCvMethod = @(x) any(validatestring(x,validCvMethods)); + +addRequired(args,'modelName',@ischar); +addRequired(args,'dataSetFileName',@ischar); +addParameter(args, 'odor_ids', []); +addParameter(args, 'n_samples', -1, @isnumeric); +addParameter(args, 'n_epochs', 10, @isnumeric); +addParameter(args, 'dt', 1/1000, @isnumeric); +addParameter(args, 'optimizer', defaultOptimizer, checkOptimizer); +addParameter(args, 'rng_seed', 42, @isnumeric); +addParameter(args, 'split', 0.3, @isnumeric); +addParameter(args, 'cv_method', defaultCvMethod, checkCvMethod); +addParameter(args, 'learn_rate', 0.001, @isnumeric); +addParameter(args, 'spikes_per_reward', 1, @isnumeric); +addParameter(args, 'target_odor_id', 2, @isnumeric); +addParameter(args, 'early_stopping', -1, @isnumeric); +addParameter(args, 'n_models', 15, @isnumeric); + +args.KeepUnmatched = true; +parse(args,modelName, dataSetFileName, varargin{:}); + +n_samples = args.Results.n_samples; +n_models = args.Results.n_models; +n_epochs = args.Results.n_epochs; +odor_ids = args.Results.odor_ids; +optimizer = args.Results.optimizer; +seed = args.Results.rng_seed; +train_test_split = args.Results.split; +cvMethod = args.Results.cv_method; +early_stopping_accuracy = args.Results.early_stopping; +seeds = randi(98756, 1, n_models); + +f = load(dataSetFileName); +data = f.data; + +N_syn = size(data.trials, 2); +dt = args.Results.dt; +T = double(data.T_trial); +lr = args.Results.learn_rate; +ts = 0:dt:T; +reward_size = args.Results.spikes_per_reward; % no of spikes for individual pattern +target_idx = args.Results.target_odor_id + 1; % idx of rewarded odor source (data.rewards contains multiple counts for each odor) + +% neuron model +tau_m = 0.015; +tau_s = 0.005; +V_thresh = 1; +V_rest = 0; +rng(seed); + +if isfield(data, 'odor_ids') && ~isempty(odor_ids) + disp(['filtering for odor_ids: ' num2str(odor_ids)]); + idx = ismember(data.odor_ids, odor_ids); + data.trials = data.trials(idx, :); + data.targets = data.targets(idx, :); + data.odor_ids = data.odor_ids(idx); +end + +y = double(data.targets); +y(:, target_idx) = y(:, target_idx) .* reward_size; +y = y(:,target_idx); + +if n_samples > 0 + %c = cvpartition(y,'HoldOut', size(data.trials, 2) - n_samples) ; + rnd_idx = randperm(size(data.trials,1)); + + % shuffle + data.trials = data.trials(rnd_idx,:); + data.targets = data.targets(rnd_idx, :); + % re-partition + data.trials = data.trials(1:n_samples,:); + data.targets = data.targets(1:n_samples, :); + + if isfield(data, 'odor_ids') + data.odor_ids = data.odor_ids(rnd_idx); + data.odor_ids = data.odor_ids(1:n_samples); + end + + y = y(rnd_idx); + y = y(1:n_samples); + disp(sprintf('adjusted dataSet size: %d | rewards: %d', size(data.trials, 1), length(data.targets))); + +end + +disp(sprintf('n_samples: %d | optimizer: %s | n_epochs: %d | cvMethod: %s | reward_size: %d | target_idx: %d', n_samples, optimizer, n_epochs, cvMethod, reward_size, target_idx)); + + +if isfield(data, 'odor_ids') + % partition data - stratified HoldOut + if strcmpi(cvMethod, 'kfold') == 1 + cval = cvpartition(y,'KFold', train_test_split); + else + cval = cvpartition(y, 'HoldOut', train_test_split); + end +else + % partition data + if strcmpi(cvMethod, 'kfold') == 1 + cval = cvpartition(y,'KFold', train_test_split); + else + cval = cvpartition(size(data.trials,1), 'HoldOut', train_test_split); + end +end + + +n_folds = 1; + +if strcmpi(cvMethod, 'kfold') == 1 + 
+if strcmpi(cvMethod, 'kfold') == 1
+    n_folds = cval.NumTestSets;
+    cval
+end
+
+
+for m=1:n_models
+
+model_seed = seeds(m);
+rng(model_seed);
+outFile = sprintf('model_cache/msp_%s.odor-%d.%d-sp.%d.mat', modelName, args.Results.target_odor_id, args.Results.spikes_per_reward, m);
+c = repartition(cval);
+%w_outs = zeros(n_folds, N_syn);
+%w_inits = zeros(n_folds, N_syn);
+w_inits = normrnd(0, 1 / N_syn, n_folds, N_syn);
+w_outs = w_inits;
+
+train_losses = zeros(n_folds, n_epochs);
+train_accuracy = zeros(n_folds, n_epochs);
+validation_losses = zeros(n_folds, n_epochs);
+validation_accuracy = zeros(n_folds, n_epochs);
+
+for k=1:n_folds
+    n_iter = 0;
+    w_out = w_outs(k, :);
+    w_init = w_inits(k, :);
+
+    % split train data
+    train_data = data.trials(c.training(k), :);
+    y_train = y(c.training(k))';
+    n_train = length(y_train);
+    % shuffle train set labels
+    shuffle_idx_train = randperm(length(y_train));
+    train_data = train_data(shuffle_idx_train, :);
+    y_train = y_train(shuffle_idx_train);
+
+    % split test/validation data
+    test_data = data.trials(c.test(k), :);
+    y_test = y(c.test(k))';
+    n_test = length(y_test);
+
+    for i=1:n_epochs
+        [w_out, ~, ~, errs, preds, ~, ~, n_iter] = fit_msp_tempotron(ts, train_data, y_train, w_out, V_thresh, V_rest, tau_m, tau_s, lr, n_iter, optimizer);
+        w_outs(k,:) = w_out;
+        loss = mean(abs(preds-y_train));
+        n_correct = length(find(preds==y_train));
+        train_accuracy(k,i) = (n_correct * 100)/length(y_train);
+        train_losses(k,i) = loss;
+
+        [mean_val_loss, ~, val_preds, ~] = validate_msp_tempotron(ts, test_data, y_test, w_out, V_thresh, V_rest, tau_m, tau_s);
+        validation_losses(k,i) = mean_val_loss;
+        n_correct = length(find(val_preds==y_test));
+        n_total = length(y_test);
+        validation_accuracy(k,i) = (n_correct * 100)/n_total;
+
+        if early_stopping_accuracy > 0 && validation_accuracy(k,i) >= early_stopping_accuracy
+            disp(sprintf('[%d] early stopping @ %.3f | %s learning converged after %d epochs val_accuracy: %.3f', k, early_stopping_accuracy, optimizer, i, validation_accuracy(k,i)));
+            break;
+        end
+
+        if (isempty(errs)) % all zeros => no errors, converged
+            disp(sprintf('[%d] %s learning converged after %d epochs', k, optimizer, i));
+            break;
+        end
+
+        if (mod(i, 1) == 0) % print progress every epoch
+            disp(sprintf('[%d@%s] epoch=%d | lr=%.4f | train_loss: %.3f (train_acc: %.3f) | val_loss: %.3f (val_acc: %.3f | %d/%d)', k, optimizer, i, lr, loss, train_accuracy(k,i), mean_val_loss, validation_accuracy(k,i), n_correct, n_total));
+        end
+    end
+
+    disp(sprintf('[%d] checkpoint model saved to: %s', m, outFile));
+    save(outFile, 'train_test_split', 'cvMethod', 'n_folds', 'n_train', 'n_test', 'seed', 'train_losses', 'validation_losses', 'train_accuracy', 'validation_accuracy', 'w_outs', 'w_inits', 'tau_m', 'tau_s', 'V_rest', 'V_thresh', 'T', 'dt', 'ts');
+
+end
+
+end
+
+end
\ No newline at end of file
diff --git a/matlab/msp_grad.m b/matlab/msp_grad.m
new file mode 100644
index 0000000..541cc44
--- /dev/null
+++ b/matlab/msp_grad.m
@@ -0,0 +1,161 @@
+% MSP_GRAD(V_0, V_thresh, t_i, w_i, ts, v_t, v_unreset, t_out, t_out_idx, N_output, tau_m, tau_s) - compute gradient theta^* for the multi-spike tempotron learning rule
+% V_0: normalizing constant of neuron model (see MSPTempotron)
+% V_thresh: spiking threshold of neuron model
+% t_i: current input pattern as cell array of input spike times for each synapse
+% w_i: synaptic efficiencies / weights
+% ts: time vector
+% v_t: membrane potential of neuron for given input pattern t_i
+% v_unreset: membrane potential without the post-spike reset
+% t_out: output spike times for given input pattern t_i
+% t_out_idx: indices within ts time vector where output spikes occurred
+% N_output: number of desired output spikes (see below)
+% tau_m: membrane time constant of neuron model (see MSPTempotron)
+% tau_s: synapse time constant of neuron model (see MSPTempotron)
+%
+% N_output - the number of desired ADDITIONAL output spikes
+% if N_output > 0 we want more spikes (search for subthreshold peaks)
+% if N_output < 0 we want fewer spikes (determine smallest peaks in v_unreset)
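+%
+% Example of this convention (a sketch with hypothetical numbers): if the
+% neuron fired 3 spikes but the label demands 5, calling msp_grad with
+% N_output = +2 selects the 2nd-largest subthreshold peak of v_t as the
+% critical threshold theta*; with N_output = -2 it selects the 2nd-smallest
+% suprathreshold peak of v_unreset, i.e. the threshold whose crossing or
+% elimination changes the output spike count by two.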
+function [pks, pks_idx, t_crit, d_w, dw_dir, dv_dw] = msp_grad(V_0, V_thresh, t_i, w_i, ts, v_t, v_unreset, t_out, t_out_idx, N_output, tau_m, tau_s)
+
+    t_crit = [];
+    dt = ts(2) - ts(1);
+
+    if N_output > 0
+        % determine theta_star which will produce N_output
+        % additional output spikes
+        dw_dir = 1; % direction of weight update
+        % want more spikes => increase weights, find the N largest subthreshold voltage peaks
+        [pks,pks_idx] = findpeaks(v_t); % this will also find all output spike times
+        pks_idx = setdiff(pks_idx, t_out_idx-1); % remove output spike times from set
+
+        pks = v_t(pks_idx);
+        [S,I] = sort(pks,'descend');
+        % the N-th peak is the voltage which will produce N additional spikes
+        idx = min(N_output, length(S));
+        if isempty(S)
+            v_crit = v_t(1);
+            t_crit = ts(1);
+        else
+            v_crit = S(idx);
+            v_crit_idx = pks_idx(I(idx));
+            t_crit = ts(v_crit_idx);
+        end
+    elseif N_output < 0 && ~isempty(t_out)
+        % determine theta_star which will eliminate N_output
+        % output spikes
+        dw_dir = -1; % direction of weight update
+        % look for the peak above V_thresh which is
+        % closest to V_thresh in the un-reset voltage
+        [pks,pks_idx] = findpeaks(v_unreset);
+        idx_tmp = find(pks > V_thresh); % we're only interested in peaks above threshold
+
+        % edge case
+        if isempty(idx_tmp)
+            pks_idx = t_out_idx;
+            pks = v_t(pks_idx);
+        else
+            pks = pks(idx_tmp);
+            pks_idx = pks_idx(idx_tmp);
+        end
+
+        [S,I] = sort(pks,'ascend');
+        safe_pos = min(abs(N_output), length(S));
+        v_crit = S(safe_pos);
+        v_crit_idx = pks_idx(I(safe_pos));
+        t_crit = ts(v_crit_idx);
+    end
+
+
+    if (~isempty(t_crit))
+        % compute dv_dw at t_crit, which needs to be normalized
+        % by the gradients of all previous output spikes
+        N_syn = length(w_i);
+        % loop over the set of time points which contribute to the gradient,
+        % that is t_crit (t*) and all output spike times < t_crit
+        % this is the set t_x of eq 28
+        t_x = [t_crit t_out(t_out < t_crit)];
+
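+        % e.g. with t_crit = 0.42 and output spikes at t_out = [0.10 0.30 0.55]
+        % (hypothetical numbers), t_x = [0.42 0.10 0.30]: the spike at 0.55
+        % occurs after t_crit and therefore does not contribute to the gradient
+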
+        % temporal derivative of v(t) before each spike time
+        v_dot = zeros(1, length(t_x));
+        % the weight derivatives at each spike time
+        dv_dw = zeros(N_syn, length(t_x));
+        % eq 31 for each output spike
+        dv_dt_hist = zeros(1, length(t_x));
+        % eq 29 normalizing constant
+        c_tx = zeros(1, length(t_x));
+        % eq 23,24 normalizing constants due to gradient dependency on
+        % previous gradients
+        b_k = zeros(N_syn, length(t_x) - 1);
+        a_k = zeros(N_syn, length(t_x) - 1);
+        % small constant for numerical stability
+        eps = 10^-12;
+
+        for k=1:length(t_x)
+            t_max = t_x(k); % current time point of set t_x
+            t_out_hist = t_out(t_out < t_max); % output spike history up to t_max
+            v_tx = v_t(ts == t_max); % voltage at current timepoint
+
+            % eq 32 - here a numerical derivative is used instead
+            % add eps to prevent division by 0 later on
+            v_dot(k) = ((v_tx - v_t(max(1, find(ts == t_x(k)) - 1)))/dt) + eps;
+
+            if (k == 1 && dw_dir < 0)
+                v_dot(k) = ((V_thresh - v_t(max(1, find(ts == t_x(k)) - 1)))/dt) + eps;
+            end
+            %v_dot(k) = ((v_tx - v_t(find(ts == t_x(k)) - 1))*dt);
+
+            % eq 29
+            c_tx(k) = 1 + sum(exp(-(t_max - t_out_hist) / tau_m));
+
+            % do computations for each synapse
+            for j=1:N_syn
+                t_in_hist = [t_i{j}];
+                t_in_hist = t_in_hist(t_in_hist < t_max);
+
+                % this is equivalent to the simple tempotron learning rule
+                psp_err = sum(V_0 .* (exp(-(t_max - t_in_hist)/tau_m) - exp(-(t_max - t_in_hist)/tau_s)));
+                v_0_tx = -(psp_err .* w_i(j));
+
+                % eq 31 - note: the summation over exp() is missing in eq 31 !
+                dv_dt_hist(k) = (v_0_tx / (c_tx(k)^2)) * (sum(exp((-(t_max - t_out_hist))/tau_m)) / tau_m);
+                % eq 30 - in principle equivalent to the simple tempotron but
+                % normalized by some factor as we have multiple output
+                % spikes now
+                dv_dw(j,k) = (1/c_tx(k)) * psp_err;
+
+                % k == 1 is t_crit but a_k and b_k only depend on outp. spikes
+                if (k > 1)
+                    sum_to = k-2;
+                    v_dot_factor = v_dot(2:(sum_to+1));
+                    dv_dt_hist_factor = dv_dt_hist(2:(sum_to+1));
+
+                    % a_ks are independent of w_i
+                    % => all rows will be identical so this could be moved
+                    % outside the synapse loop
+                    %
+                    a_k(j, k-1) = 1 - sum( (a_k(j, 1:sum_to) ./ v_dot_factor) .* dv_dt_hist_factor );
+                    b_k(j, k-1) = -(dv_dw(j,k)) - sum( (b_k(j, 1:sum_to) ./ v_dot_factor) .* dv_dt_hist_factor );
+                end
+            end
+        end
+
+        % finally, construct the scaling for the gradient at t_crit,
+        % which recursively depends on all gradients of previous output
+        % spikes
+        v_dot_factor_ab = v_dot(2:end);
+        dv_dt_hist_factor_ab = dv_dt_hist(2:end);
+        A_star = (1 - sum((a_k ./ v_dot_factor_ab) .* dv_dt_hist_factor_ab, 2));
+        B_star = ((-(dv_dw(:,1))) - sum((b_k ./ v_dot_factor_ab) .* dv_dt_hist_factor_ab, 2));
+
+        if (~isempty(A_star(A_star == 0)))
+            error('A_star is zero - numeric problem !');
+        end
+
+        d_w = -(B_star ./ (A_star))';
+
+        if (~isempty(d_w(d_w > 1000)) || any(isnan(d_w)))
+            error('diverging gradient ! |d_w|=%.2f', norm(d_w));
+        end
+    end
+end
\ No newline at end of file
diff --git a/matlab/natsort.m b/matlab/natsort.m
new file mode 100644
index 0000000..beb8fa3
--- /dev/null
+++ b/matlab/natsort.m
@@ -0,0 +1,329 @@
+function [X,ndx,dbg] = natsort(X,xpr,varargin) %#ok<*SPERR>
+% Alphanumeric / Natural-Order sort the strings in a cell array of strings (1xN char).
+%
+% (c) 2012 Stephen Cobeldick
+%
+% Alphanumeric sort of a cell array of strings: sorts by character order
+% and also by the values of any numbers that are within the strings. The
+% default is case-insensitive ascending with integer number substrings:
+% optional inputs control the sort direction, case sensitivity, and number
+% matching (see the section "Number Substrings" below).
+%
+%%% Example:
+% X = {'x2', 'x10', 'x1'};
+% sort(X)
+%  ans = 'x1' 'x10' 'x2'
+% natsort(X)
+%  ans = 'x1' 'x2' 'x10'
+%
+%%% Syntax:
+% Y = natsort(X)
+% Y = natsort(X,xpr)
+% Y = natsort(X,xpr,<options>)
+% [Y,ndx] = natsort(X,...)
+% [Y,ndx,dbg] = natsort(X,...)
+%
+% To sort filenames or filepaths use NATSORTFILES (File Exchange 47434).
+% To sort the rows of a cell array of strings use NATSORTROWS (File Exchange 47433).
+%
+% See also NATSORTFILES NATSORTROWS SORT CELLSTR IREGEXP REGEXP SSCANF INTMAX
+%
+%% Number Substrings %%
+%
+% By default consecutive digit characters are interpreted as an integer.
+% The optional regular expression pattern permits the numbers to also
+% include a +/- sign, decimal digits, exponent E-notation, or any literal
+% characters, quantifiers, or look-around requirements. For more information:
+% http://www.mathworks.com/help/matlab/matlab_prog/regular-expressions.html
+%
+% The substrings are then parsed by SSCANF into numeric variables, using
+% either the *default format '%f' or the user-supplied format specifier.
+%
+% This table shows some example regular expression patterns for some common
+% notations and ways of writing numbers (see section "Examples" for more):
+%
+% Regular       | Number Substring | Number Substring              | SSCANF
+% Expression:   | Match Examples:  | Match Description:            | Format Specifier:
+% ==============|==================|===============================|==================
+% *         \d+ | 0, 1, 234, 56789 | unsigned integer              | %f  %u  %lu  %i
+% --------------|------------------|-------------------------------|------------------
+%     (-|+)?\d+ | -1, 23, +45, 678 | integer with optional +/- sign| %f  %d  %ld  %i
+% --------------|------------------|-------------------------------|------------------
+%     \d+\.?\d* | 012, 3.45, 678.9 | integer or decimal            | %f
+% --------------|------------------|-------------------------------|------------------
+%   \d+|Inf|NaN | 123, 4, Inf, NaN | integer, infinite or NaN value| %f
+% --------------|------------------|-------------------------------|------------------
+%  \d+\.\d+e\d+ | 0.123e4, 5.67e08 | exponential notation          | %f
+% --------------|------------------|-------------------------------|------------------
+%       0[0-7]+ | 012, 03456, 0700 | octal prefix & notation       | %o  %i
+% --------------|------------------|-------------------------------|------------------
+%   0X[0-9A-F]+ | 0X0, 0XFF, 0X7C4 | hexadecimal prefix & notation | %x  %i
+% --------------|------------------|-------------------------------|------------------
+%       0B[01]+ | 0B101, 0B0010111 | binary prefix & notation      | %b (not SSCANF)
+% --------------|------------------|-------------------------------|------------------
+%
+% The SSCANF format specifier (including %b) can include literal characters
+% and skipped fields. The octal, hexadecimal and binary prefixes are optional.
+% For more information: http://www.mathworks.com/help/matlab/ref/sscanf.html
+%
+%% Debugging Output Array %%
+%
+% The third output is a cell array <dbg>, to check if the numbers have
+% been matched by the regular expression and converted to numeric
+% by the SSCANF format. The rows of <dbg> are linearly indexed from <X>:
+%
+% [~,~,dbg] = natsort(X)
+% dbg =
+%    'x'    [ 2]
+%    'x'    [10]
+%    'x'    [ 1]
+%
+%% Relative Sort Order %%
+%
+% The sort order of the number substrings relative to the characters
+% can be controlled by providing one of the following string options:
+%
+% Option Token:| Relative Sort Order:                 | Example:
+% =============|======================================|====================
+% 'beforechar' | numbers < char(0:end)                | '1' < '#' < 'A'
+% -------------|--------------------------------------|--------------------
+% 'afterchar'  | char(0:end) < numbers                | '#' < 'A' < '1'
+% -------------|--------------------------------------|--------------------
+% 'asdigit'   *| char(0:47) < numbers < char(48:end)  | '#' < '1' < 'A'
+% -------------|--------------------------------------|--------------------
+%
+% Note that the digit characters have character values 48 to 57, inclusive.
+%
+%% Examples %%
+%
+%%% Multiple integer substrings (e.g. release version numbers):
+% B = {'v10.6', 'v9.10', 'v9.5', 'v10.10', 'v9.10.20', 'v9.10.8'};
+% sort(B)
+%  ans = 'v10.10' 'v10.6' 'v9.10' 'v9.10.20' 'v9.10.8' 'v9.5'
+% natsort(B)
+%  ans = 'v9.5' 'v9.10' 'v9.10.8' 'v9.10.20' 'v10.6' 'v10.10'
+%
+%%% Integer, decimal or Inf number substrings, possibly with +/- signs:
+% C = {'test+Inf', 'test11.5', 'test-1.4', 'test', 'test-Inf', 'test+0.3'};
+% sort(C)
+%  ans = 'test' 'test+0.3' 'test+Inf' 'test-1.4' 'test-Inf' 'test11.5'
+% natsort(C, '(-|+)?(Inf|\d+\.?\d*)')
+%  ans = 'test' 'test-Inf' 'test-1.4' 'test+0.3' 'test11.5' 'test+Inf'
+%
+%%% Integer or decimal number substrings, possibly with an exponent:
+% D = {'0.56e007', '', '4.3E-2', '10000', '9.8'};
+% sort(D)
+%  ans = '' '0.56e007' '10000' '4.3E-2' '9.8'
+% natsort(D, '\d+\.?\d*(E(+|-)?\d+)?')
+%  ans = '' '4.3E-2' '9.8' '10000' '0.56e007'
+%
+%%% Hexadecimal number substrings (possibly with '0X' prefix):
+% E = {'a0X7C4z', 'a0X5z', 'a0X18z', 'aFz'};
+% sort(E)
+%  ans = 'a0X18z' 'a0X5z' 'a0X7C4z' 'aFz'
+% natsort(E, '(?<=a)(0X)?[0-9A-F]+', '%x')
+%  ans = 'a0X5z' 'aFz' 'a0X18z' 'a0X7C4z'
+%
+%%% Binary number substrings (possibly with '0B' prefix):
+% F = {'a11111000100z', 'a0B101z', 'a0B000000000011000z', 'a1111z'};
+% sort(F)
+%  ans = 'a0B000000000011000z' 'a0B101z' 'a11111000100z' 'a1111z'
+% natsort(F, '(0B)?[01]+', '%b')
+%  ans = 'a0B101z' 'a1111z' 'a0B000000000011000z' 'a11111000100z'
+%
+%%% UINT64 number substrings (with full precision!):
+% natsort({'a18446744073709551615z', 'a18446744073709551614z'}, [], '%lu')
+%  ans = 'a18446744073709551614z' 'a18446744073709551615z'
+%
+%%% Case sensitivity:
+% G = {'a2', 'A20', 'A1', 'a10', 'A2', 'a1'};
+% natsort(G, [], 'ignorecase') % default
+%  ans = 'A1' 'a1' 'a2' 'A2' 'a10' 'A20'
+% natsort(G, [], 'matchcase')
+%  ans = 'A1' 'A2' 'A20' 'a1' 'a2' 'a10'
+%
+%%% Sort direction:
+% H = {'2', 'a', '3', 'B', '1'};
+% natsort(H, [], 'ascend') % default
+%  ans = '1' '2' '3' 'a' 'B'
+% natsort(H, [], 'descend')
+%  ans = 'B' 'a' '3' '2' '1'
+%
+%%% Relative sort-order of number substrings compared to characters:
+% V = num2cell(char(32+randperm(63)));
+% cell2mat(natsort(V, [], 'asdigit')) % default
+%  ans = '!"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_'
+% cell2mat(natsort(V, [], 'beforechar'))
+%  ans = '0123456789!"#$%&'()*+,-./:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_'
+% cell2mat(natsort(V, [], 'afterchar'))
+%  ans = '!"#$%&'()*+,-./:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_0123456789'
+%
+%% Input and Output Arguments %%
+%
+%%% Inputs (*=default):
+% X   = CellArrayOfCharRowVectors, to be sorted into natural-order.
+% xpr = CharRowVector, regular expression for number substrings, '\d+'*.
+% <options> tokens can be entered in any order, as many as required:
+%   - Sort direction: 'descend'/'ascend'*.
+%   - Case sensitive/insensitive matching: 'matchcase'/'ignorecase'*.
+%   - Relative sort of numbers: 'beforechar'/'afterchar'/'asdigit'*.
+%   - The SSCANF number conversion format, e.g.: '%x', '%i', '%f'*, etc.
+%
+%%% Outputs:
+% Y   = CellArrayOfCharRowVectors, sorted into natural-order.
+% ndx = NumericArray, such that Y = X(ndx). The same size as <X>.
+% dbg = CellArray of the parsed characters and number values. Each row is
+%       one input char vector, linear-indexed from <X>. To help debug <xpr>.
+%
+% [X,ndx,dbg] = natsort(X,xpr*,<options>)
+%% Input Wrangling %%
+%
+assert(iscell(X),'First input must be a cell array.')
+tmp = cellfun('isclass',X,'char') & cellfun('size',X,1)<2 & cellfun('ndims',X)<3;
+assert(all(tmp(:)),'First input must be a cell array of char row vectors (1xN char).')
+%
+% Regular expression:
+if nargin<2 || isnumeric(xpr)&&isempty(xpr)
+    xpr = '\d+';
+else
+    assert(ischar(xpr)&&isrow(xpr),'Second input must be a regular expression (char row vector).')
+end
+%
+% Optional arguments:
+tmp = cellfun('isclass',varargin,'char') & 1==cellfun('size',varargin,1) & 2==cellfun('ndims',varargin);
+assert(all(tmp(:)),'All optional arguments must be char row vectors (1xN char).')
+% Character case matching:
+ChrM = strcmpi(varargin,'matchcase');
+ChrX = strcmpi(varargin,'ignorecase')|ChrM;
+% Sort direction:
+DrnD = strcmpi(varargin,'descend');
+DrnX = strcmpi(varargin,'ascend')|DrnD;
+% Relative sort-order of numbers compared to characters:
+RsoB = strcmpi(varargin,'beforechar');
+RsoA = strcmpi(varargin,'afterchar');
+RsoX = strcmpi(varargin,'asdigit')|RsoB|RsoA;
+% SSCANF conversion format:
+FmtX = ~(ChrX|DrnX|RsoX);
+%
+if nnz(FmtX)>1
+    tmp = sprintf(', ''%s''',varargin{FmtX});
+    error('Overspecified optional arguments:%s.',tmp(2:end))
+end
+if nnz(DrnX)>1
+    tmp = sprintf(', ''%s''',varargin{DrnX});
+    error('Sort direction is overspecified:%s.',tmp(2:end))
+end
+if nnz(RsoX)>1
+    tmp = sprintf(', ''%s''',varargin{RsoX});
+    error('Relative sort-order is overspecified:%s.',tmp(2:end))
+end
+%
+%% Split Strings %%
+%
+% Split strings into number and remaining substrings:
+[MtS,MtE,MtC,SpC] = regexpi(X(:),xpr,'start','end','match','split',varargin{ChrX});
+%
+% Determine lengths:
+MtcD = cellfun(@minus,MtE,MtS,'UniformOutput',false);
+LenZ = cellfun('length',X(:))-cellfun(@sum,MtcD);
+LenY = max(LenZ);
+LenX = numel(MtC);
+%
+dbg = cell(LenX,LenY);
+NuI = false(LenX,LenY);
+ChI = false(LenX,LenY);
+ChA = char(double(ChI));
+%
+ndx = 1:LenX;
+for k = ndx(LenZ>0)
+    % Determine indices of numbers and characters:
+    ChI(k,1:LenZ(k)) = true;
+    if ~isempty(MtS{k})
+        tmp = MtE{k} - cumsum(MtcD{k});
+        dbg(k,tmp) = MtC{k};
+        NuI(k,tmp) = true;
+        ChI(k,tmp) = false;
+    end
+    % Transfer characters into char array:
+    if any(ChI(k,:))
+        tmp = SpC{k};
+        ChA(k,ChI(k,:)) = [tmp{:}];
+    end
+end
+%
+%% Convert Number Substrings %%
+%
+if nnz(FmtX) % One format specifier
+    fmt = varargin{FmtX};
+    err = ['The supplied format results in an empty output from sscanf: ''',fmt,''''];
+    pct = '(? double
+    NuA(NuI) = sscanf(sprintf('%s\v',dbg{NuI}),'%f\v');
+end
+% Note: NuA's class is determined by SSCANF or the custom binary parser.
+NuA(~NuI) = 0;
+NuA = reshape(NuA,LenX,LenY);
+%
+%% Debugging Array %%
+%
+if nargout>2
+    dbg(:) = {''};
+    for k = reshape(find(NuI),1,[])
+        dbg{k} = NuA(k);
+    end
+    for k = reshape(find(ChI),1,[])
+        dbg{k} = ChA(k);
+    end
+end
+%
+%% Sort Columns %%
+%
+if ~any(ChrM) % ignorecase
+    ChA = upper(ChA);
+end
+%
+ide = ndx.';
+% From the last column to the first...
+for n = LenY:-1:1
+    % ...sort the characters and number values:
+    [C,idc] = sort(ChA(ndx,n),1,varargin{DrnX});
+    [~,idn] = sort(NuA(ndx,n),1,varargin{DrnX});
+    % ...keep only relevant indices:
+    jdc = ChI(ndx(idc),n); % character
+    jdn = NuI(ndx(idn),n); % number
+    jde = ~ChI(ndx,n)&~NuI(ndx,n); % empty
+    % ...define the sort-order of numbers and characters:
+    jdo = any(RsoA)|(~any(RsoB)&C<'0');
+    % ...then combine these indices in the requested direction:
+    if any(DrnD) % descending
+        idx = [idc(jdc&~jdo);idn(jdn);idc(jdc&jdo);ide(jde)];
+    else % ascending
+        idx = [ide(jde);idc(jdc&jdo);idn(jdn);idc(jdc&~jdo)];
+    end
+    ndx = ndx(idx);
+end
+%
+ndx = reshape(ndx,size(X));
+X = X(ndx);
+%
+end
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%natsort
\ No newline at end of file
diff --git a/matlab/rmse_mbon_task.m b/matlab/rmse_mbon_task.m
new file mode 100644
index 0000000..5175ef5
--- /dev/null
+++ b/matlab/rmse_mbon_task.m
@@ -0,0 +1,62 @@
+function [rmse, accu,y,pred_y,sp_times] = rmse_mbon_task(modelName, dataSetFileName, target_odor_id, reward_size, w_idx)
+
+    load(modelName);
+    [filepath,name,ext] = fileparts(modelName);
+
+    %reward_size = 1; % number of spikes coding for an individual pattern
+    %target_idx = 2; % index of the rewarded odor (data.targets holds one pulse count per odor)
+
+
+    load(dataSetFileName);
+    N_samples = size(data.trials, 1);
+    N_syn = size(w_outs, 2);
+    samples = data.trials;
+    rewards = double(data.targets);
+    rewards(:, target_odor_id+1) = rewards(:, target_odor_id+1) .* reward_size;
+    y = rewards(:,target_odor_id+1)';
+    % data-set might have a different duration than the model
+    T = double(data.T_trial);
+    ts = 0:dt:T;
+    if (w_idx <= 0)
+        w_out = w_outs(end,:);
+    else
+        disp(sprintf('using weights of epoch %d', w_idx));
+        w_out = w_outs(w_idx,:);
+    end
+
+    [~,dataSetName,~] = fileparts(dataSetFileName);
+    disp(sprintf('computing RMSE over %d samples from file: %s', N_samples, dataSetFileName));
+
+    [mean_val_loss, ~, pred_y, sp_times] = validate_msp_tempotron(ts, samples, rewards, w_out, V_thresh, V_rest, tau_m, tau_s);
+
+    accu = (length(find(y == pred_y)) * 100)/length(y);
+    rmse = sqrt(mean((y-pred_y).^2));
+    disp(sprintf('RMSE=%.2f | accuracy=%.2f | N_samples=%d | N_syn=%d', rmse, accu, N_samples, N_syn));
+
+    % only save if dataSet contains > 1 sample
+    if N_samples > 1
+        targetPath = [filepath filesep 'predictions' filesep name];
+        [status, msg, msgID] = mkdir(targetPath);
+        outFile = fullfile(targetPath, strcat(dataSetName,ext));
+        save(outFile, 'pred_y', 'y', 'accu', 'rmse', 'sp_times');
+        disp(sprintf('saved results to: %s', outFile));
+    else
+        if ~isfield(data, 'predictions')
+            [data(:).predictions] = {};
+        end
+
+        exists = 0;
+        for i=1:length(data.predictions)
+            if 1 == strcmp(data.predictions{i}{1}, name)
+                data.predictions{i} = {name, pred_y,y,accu,rmse,sp_times};
+                exists = 1;
+            end
+        end
+
+        if (exists < 1)
+            data.predictions{end+1} = {name, pred_y,y,accu,rmse,sp_times};
+        end
+        save(dataSetFileName, 'data');
+        disp(sprintf('saved results to dataSet: %s', dataSetFileName));
+    end
+end
\ No newline at end of file
diff --git a/matlab/validate_msp_tempotron.m
b/matlab/validate_msp_tempotron.m new file mode 100644 index 0000000..5c95cc7 --- /dev/null +++ b/matlab/validate_msp_tempotron.m @@ -0,0 +1,24 @@ +function [mean_loss, validation_errors, predictions, spiketimes] = validate_msp_tempotron(ts, trials, labels, w, V_thresh, V_rest, tau_m, tau_s) + + memo_exp = memoize(@exp); + memo_exp.CacheSize = size(trials, 2)*10; + validation_errors = zeros(1, size(trials, 1)); + predictions = zeros(1, size(trials, 1)); + dataFormatType = iscell(trials{1}); + %dataFormatType = size(trials{1},2) ~= size(trials{2},2) || size(trials{1},2) + size(trials{2},2) == 0; + spiketimes = cell(1, size(trials, 1)); + for j=1:size(trials, 1) + if dataFormatType == 0 + pattern = cell(trials(j,:)); + else + pattern = trials{j}; + end + + [v_t, t_sp, ~, ~, ~] = MSPTempotron(memo_exp, ts, pattern, w, V_thresh, V_rest, tau_m, tau_s); + validation_errors(1, j) = abs(length(t_sp) - labels(j)); + predictions(1,j) = length(t_sp); + spiketimes{1,j} = t_sp; + end + + mean_loss = mean(validation_errors); +end \ No newline at end of file diff --git a/mkDataSet_DrosoArtificialStim.py b/mkDataSet_DrosoArtificialStim.py new file mode 100644 index 0000000..88a5ddc --- /dev/null +++ b/mkDataSet_DrosoArtificialStim.py @@ -0,0 +1,438 @@ +from olnet import run_sim, save_sim, save_sim_hdf5 +from olnet.tuning import get_orn_tuning, get_receptor_tuning, create_stimulation_matrix, gen_shot_noise, combine_noise_with_protocol, gen_gauss_sequence +from brian2 import * +import numpy as np +import olnet.models.droso_mushroombody_apl as droso_mb +from olnet import AttrDict +import sys,argparse,time, os +import traceback + +def current_milli_time(): + return int(round(time.time() * 1000)) + +def flatten(a,b): + return a+b + +def divide(lst, mean_rate, split_size): + it = iter(lst) + from itertools import islice + size = len(lst) + for i in range(split_size - 1,0,-1): + s = max(2, int(np.random.poisson(mean_rate, size=1))) + yield list(islice(it,0,s)) + size -= s + yield list(it) + + +def gen_pulsed_gauss_stimulus(odor_ids, T, dt, mu, std, n_pulses, primary_odor_id, n_stim=1, allow_overlap=False): + """ + generate a pulsed stimulus that simulates a gaussian plume cone profile + Returns tuple of (stimulus,n_pulses,pulse_times) + :param odor_ids: + :param T: + :param dt: + :param mu: mean of gaussian profile (or T/2 if None) + :param std: std. dev. of gaussian + :param n_pulses: tuple(n_primary,n_bg) number of pulses to generate for primary_odor_id and all other background odors + :param primary_odor_id: the primary odor_id that should use the gaussian profile. all other odor_ids will be background distractors + :param n_stim: + :param allow_overlap: + :return: + """ + + if type(n_pulses) == tuple: + n_primary,n_other = n_pulses + else: + n_primary = n_pulses + n_other = n_pulses // 3 + + X = np.zeros((n_stim, int(T / dt))) + X[primary_odor_id, :] = gen_gauss_sequence(T, dt, std, mu, n_primary) + + for odor_id in odor_ids: + if odor_id != primary_odor_id: + bg_idx = np.random.choice(X.shape[1], n_other) + X[odor_id, bg_idx] = 1 + if allow_overlap is False: + overlap_idx = np.where(X[odor_id, :] == X[primary_odor_id, :])[0] + X[odor_id, overlap_idx] = 0 + + X_prime = np.c_[np.zeros((n_stim, 1)), X] + pulse_times = [(np.where(np.diff(X_prime[n, :]) == 1)[0] * dt).tolist() for n in range(n_stim)] + return X, [len(p) for p in pulse_times], pulse_times + +def gen_pulsed_stimulus(odor_ids, T, dt, pulse_rate, pulse_duration=(0.1, 0.5), n_stim=1, allow_overlap=False): + """ + generate a pulsed stimulus where no. 
of pulses is Poisson-distributed with rate pulse_rate
+    and gaps between subsequent pulses are drawn uniformly. Returns tuple of (stimulus,n_pulses,pulse_times)
+    :param odor_ids:
+    :param T:
+    :param dt:
+    :param pulse_rate:
+    :param pulse_duration:
+    :param n_stim:
+    :param allow_overlap:
+    :return: tuple
+    """
+    from functools import reduce
+    n_bins = int(T / dt) * n_stim
+    pulses = np.random.uniform(pulse_duration[0], pulse_duration[1], size=np.random.poisson(pulse_rate))
+    pulse_bins = (pulses / dt).astype(int)
+    X = np.zeros((n_stim, int(T / dt)))
+
+    spaced_pulses = [[1] * p for p in pulse_bins.tolist()]
+    print("n_pulses: {} | spaced_pulses: {}".format(len(pulse_bins), len(spaced_pulses)))
+    pad_x = [0] * int((n_bins // n_stim) - pulse_bins.sum())
+    pad_sizes = np.random.randint(10, int((len(pad_x) - len(pulse_bins)) // len(spaced_pulses)),
+                                  size=len(spaced_pulses)).tolist()
+    ptr = 0
+    start = 0
+    pulse_times = [[] for _ in range(n_stim)]
+    for k, s in enumerate(spaced_pulses):
+        stim_idx = np.random.choice(odor_ids, size=1, p=[1. / len(odor_ids) for _ in odor_ids])[0]
+        print("assigning pulse {}/{} to odor_id {}".format(k, len(spaced_pulses), stim_idx))
+        stop = pad_sizes.pop(0)
+        pad = pad_x[start:start + stop]
+        #print("pulse #{} len: {} n_pad: {}".format(k, len(s), len(pad)))
+        ptr += len(pad)
+        X[stim_idx, ptr:ptr + len(s)] = s
+        #print("stim_idx: {} / {} / {}".format(stim_idx, n_stim, pulse_times))
+        pulse_times[stim_idx].append(ptr * dt)
+        ptr += len(s)
+        start = stop
+
+    X_prime = np.c_[np.zeros((n_stim, 1)), X]
+    return X, [len(np.where(np.diff(X_prime[n, :]) == 1)[0]) for n in range(n_stim)], pulse_times
+
+def gen_shared_params(model_params, params, neuron_models):
+    # automatically create shared model params for each neuron model type
+    for v in params:
+        for n in neuron_models:
+            k = '{}{}'.format(v, n)
+            if k not in model_params:
+                model_params.update({k: model_params[v]})
+    return model_params
+
+
+def gen_shotnoise_input(dt, warmup_time, pulse_stim, n_receptors, N_glo, odor_ids, ORNperGlo, receptors_per_odor, stimulus_rate, bg_rate, stim_scale=0.003, bg_scale=0.001):
+    #print("pulse_stim: {}".format(pulse_stim.shape))
+    simtime = ((warmup_time/second) + (dt/second) * pulse_stim.shape[1]) * second
+    print("ORNs={} glomeruli={} ORNperGlu={} receptors_per_odor={}, n_receptors={}, odor_ids={}".format(n_receptors, N_glo, ORNperGlo,
+                                                                                                        receptors_per_odor,
+                                                                                                        n_receptors, odor_ids))
+    pad_n = int(warmup_time / dt)
+    # TODO: support more than 2 odors
+    y = []
+    for odor_idx in odor_ids:
+        y.append(np.array([0]*pad_n + pulse_stim[odor_idx, :].tolist())) # odor A pulse stim
+    #y2 = np.array([0]*pad_n + pulse_stim[1, :].tolist()) # odor B pulse stim
+
+    ORN_noise = gen_shot_noise(stimulus_rate, simtime / second, tau=0.6, dt=dt/second, dim=n_receptors, scale=stim_scale)
+
+    S = get_receptor_tuning(N_glo, N_glo, receptors_per_odor, peak_rate=stimulus_rate) / stimulus_rate
+    M = get_orn_tuning(S, n_orns=ORNperGlo)
+    M_prime = M[odor_ids, :, :]
+
+    A = (gen_shot_noise(bg_rate, simtime / second, tau=0.5, dt=dt / second, dim=n_receptors, scale=bg_scale).values)
+
+    # subsequently add/superimpose the stimuli of all odors
+    #print("y={}".format(len(y)))
+    for i,odor_idx in enumerate(odor_ids):
+        y1 = combine_noise_with_protocol(TimedArray(y[i], dt=dt), ORN_noise)
+        A += (y1.values * np.tile(M_prime[i], (1, y1.values.shape[0])).T)
+
+
+    #A = (gen_shot_noise(bg_rate, simtime / second, tau=0.9, dt=dt/second, dim=n_receptors, scale=bg_scale).values) \
+    #    + ((y1.values * np.tile(M_prime[0], (1,
y1.values.shape[0])).T) + (y2.values * np.tile(M_prime[1], (1, y2.values.shape[0])).T)) + + print("created stimulation matrix for {} odors: {}".format(len(odor_ids), A.shape)) + stimulus = TimedArray(A * uA, dt=dt) + print("created stimulus TimedArray: {} warmup: {}".format(stimulus.values.shape, warmup_time)) + return simtime, stimulus, M_prime + + +def run_model(model_params, N_glo, ORNperGlo, N_KC, simtime, stimulus, dt = 0.1 * ms, network_seed=42): + # use fixed random seed to build same network arch. + np.random.seed(network_seed) + seed(network_seed) + + model_params = gen_shared_params(model_params, ['C', 'gL', 'EL', 'Vt', 'Vr', 'tau_Ia'], ['ORN', 'PN', 'LN', 'KC', 'APL']) + model_params.update({'stimulus': stimulus}) + + + NG, c = droso_mb.network(model_params, + None, + droso_mb.model_ORN, + droso_mb.model_PN, + droso_mb.model_LN, + droso_mb.model_KC, + droso_mb.model_APL, + wORNinputORN=1 * model_params['w0'], + wORNPN=1.1282 * model_params['w0'], + wORNLN=1 * model_params['w0'], + wLNPN=2.5 * model_params['w0'], # enable lateral inhib. + wPNKC=double(model_params['wPNKC']) * model_params['w0'], + wKCAPL=double(model_params['wKCAPL']) * model_params['w0'], + wAPLKC=double(model_params['wAPLKC']) * model_params['w0'], + N_glu=N_glo, + ORNperGlu=ORNperGlo, + N_KC=N_KC, + PNperKC=double(model_params['PNperKC']), + V0min=model_params['EL'], + V0max=model_params['Vt'], + apl_delay=model_params['apl_delay']) + + var_mons = [ + ('ORN', ('v', 'g_i', 'g_e'), [360]), + ('PN', ('v', 'g_i', 'g_e'), [15]), + ('LN', ('v', 'g_i', 'g_e'), [15]), + ('APL', ('v', 'g_i', 'g_e'), [0]) + ] + + return run_sim(model_params, NG, c, simtime, sim_dt=dt, + spike_monitors=['ORN', 'PN', 'LN', 'KC', 'APL'], + rate_monitors=['ORN', 'PN', 'LN', 'KC', 'APL'], + state_monitors=var_mons) + + + + +def worker(args): + (id, name, seed, model_params, args, plot) = args + np.random.seed(seed) + + t_start = current_milli_time() + N_glo = 52 + ORNperGlo = (2080 // N_glo) # Droso: roughly 2000 ORNs total + N_KC = 2000 # droso: 2000 + n_receptors = N_glo * ORNperGlo # * model_params['orn_input_multiplier'] + receptors_per_odor = 15 + warmup_time = args.warmup_time * second # 2 * second + sim_dt = 0.1 * ms + stim_dt = args.stimulus_dt * ms # 1*ms # time-resolution for stimulus TimedArray + bg_rate = args.bg_rate + stimulus_rate = args.stimulus_rate + T = args.T + stim_noise_scale = args.stim_noise_scale # 0.003 + bg_noise_scale = args.bg_noise_scale #0.001 + + model_params.update({ + 'seed': seed, + 'T': T, + 'pulse_rate': args.pulse_rate, + 'min_pulse_duration': args.min_pulse_duration, + 'max_pulse_duration': args.max_pulse_duration, + 'stim_noise_scale': stim_noise_scale, + 'bg_noise_scale': bg_noise_scale, + 'stim_dt': stim_dt / second, + 'noise_bg_rate': bg_rate, + 'noise_stim_rate': stimulus_rate, + 'N_KC': N_KC, + 'ORNperGlo': ORNperGlo, + 'n_receptors': n_receptors, + 'receptors_per_odor': receptors_per_odor + }) + + print("worker[{}] started odor_ids: {} ...".format(id, args.odor_ids)) + pulse_stim, rewards = None, None + + # loop - to catch rare cases where stimulus could not be generated + while (pulse_stim is None): + try: + if args.gaussian <= 0: + pulse_stim, rewards, pulse_times = gen_pulsed_stimulus(args.odor_ids, T, stim_dt / second, args.pulse_rate, + pulse_duration=(args.min_pulse_duration, args.max_pulse_duration), + n_stim=N_glo) + else: + pulse_stim, rewards, pulse_times = gen_pulsed_gauss_stimulus(args.odor_ids, T, stim_dt / second, + args.gauss_mean, args.gauss_std, + 
(args.pulse_rate,args.gauss_rate_other), + int(args.gauss_primary_odor_id), + n_stim=N_glo) + except Exception as e: + traceback.print_exc() + pulse_stim, rewards, pulse_times = None, None, [] + + simtime, stimulus, M = gen_shotnoise_input(stim_dt, warmup_time, pulse_stim, n_receptors, N_glo, args.odor_ids, ORNperGlo, receptors_per_odor, + stimulus_rate, bg_rate, stim_scale=stim_noise_scale, bg_scale=bg_noise_scale) + + + spikemons, pop_mons, state_mons, var_mons = run_model(model_params, N_glo, ORNperGlo, N_KC, (T + args.warmup_time) * second, stimulus, sim_dt, args.network_seed) + t_stop = current_milli_time() + print("worker[{}] finished (took {} sec)".format(id, (t_stop-t_start)/1000)) + + model_params.update({'rewards': rewards}) + model_params.update({'stimulation_times': pulse_times}) + model_params.pop('stimulus', None) # TimedArray is not pickle-able - remove it + + if plot: + fileName = "sim-{}-{}".format(id, seed) + data = save_sim("cache/{}/{}.npz".format(name, fileName), + model_params, + spikemons, pop_mons, state_mons, simtime, warmup_time, sim_dt, + stimulus=np.flipud(stimulus.values.T), + tuning=M, + stimulus_times=pulse_times, + n_receptors=n_receptors + ) + + if plot: + from olnet.plotting.figures import figure1 + f = figure1(data) + f.savefig("figures/{}/{}.png".format(name, fileName), dpi=f.dpi) + print("worker[{}] saved figure: figures/{}/{}.png".format(id, name, fileName)) + + + # align spiketrains to warmup offset + sp_trains_aligned = {} + for k, v in spikemons.items(): + trial_sp = [] + for s in v.spike_trains().values(): + sp_times = (s / second) - args.warmup_time + trial_sp.append(list(sp_times)) + sp_trains_aligned[k] = trial_sp + + spikeData = AttrDict({ + k: AttrDict({'count': v.count[:], + 't': (v.t[:] / second), + 't_aligned': (v.t[:] / second) - args.warmup_time, + 'i': v.i[:], + 'spike_trains': v.spike_trains(), + 'spike_trains_aligned': sp_trains_aligned[k]}) for k, v in spikemons.items() + }) + + return (id, rewards, spikeData, pulse_times, (t_stop-t_start)) + + + +if __name__ == "__main__": + from concurrent.futures import ProcessPoolExecutor + import scipy.io as scpio + + argv = sys.argv[1:] + + parser = argparse.ArgumentParser(description='Generate data set of KC spike-times using drosoMB model and artificially generated stimulus sequences') + + parser.add_argument('-n', '--name', type=str, nargs='?', help='name of data-set') + parser.add_argument('-N', '--N', type=int, nargs='?', help = 'number of samples to generate', default=10) + parser.add_argument('--network_seed', type=int, nargs='?', help='RNG seed used to build network model', default=42) + parser.add_argument('--odor_ids', type=int, action='append', help='indices of odors to use', required=True) + parser.add_argument('--n_cpu', type=int, nargs='?', help = 'no of CPUs to use for parallel simulations', default=4) + parser.add_argument('--bg_rate', type=int, nargs='?', help = 'background shot noise poisson rate', default=300) + parser.add_argument('--stimulus_rate', type=int, nargs='?', help = 'stimulus shot noise poisson rate', default=300) + parser.add_argument('-T', type=float, nargs='?', help = 'stimulus duration (in seconds)', default=10) + parser.add_argument('--warmup_time', type=float, nargs='?', help = 'duration of warmup phase (in seconds)', default=2) + parser.add_argument('--stimulus_dt', type=float, nargs='?', help = 'dt of stimulus TimedArray (in ms)', default=0.5) + parser.add_argument('--pulse_rate', type=int, nargs='?', help = 'max. 
number of pulses within a sequence', default=8)
+    parser.add_argument('--max_pulse_duration', type=float, nargs='?', help = 'max. duration of a single pulse (in seconds)', default=0.5)
+    parser.add_argument('--min_pulse_duration', type=float, nargs='?', help = 'min. duration of a single pulse (in seconds)', default=0.1)
+    parser.add_argument('--stim_noise_scale', type=float, nargs='?', help = 'scale of shot-noise for stimulus', default=0.004)
+    parser.add_argument('--bg_noise_scale', type=float, nargs='?', help = 'scale of shot-noise for background activity', default=0.0055)
+
+    parser.add_argument('--gaussian', type=int, nargs='?', help='whether to use gaussian profile', default=0)
+    parser.add_argument('--gauss_mean', type=float, nargs='?', help='mean of gaussian profile',required=True)
+    parser.add_argument('--gauss_std', type=float, nargs='?', help='std. dev. of gaussian profile', default=1.5)
+    parser.add_argument('--gauss_primary_odor_id', type=int, nargs='?', help='primary odor_id to use for gaussian profile', required=True)
+    parser.add_argument('--gauss_rate_other', type=int, nargs='?', help='number of stimuli to draw for other odor_ids (uniform)',default=3)
+
+    parser.add_argument('-o', '--outfile', nargs='?', type=str, help = 'output filename for MAT file')
+    parser.add_argument("--modelParams", action='append', type=lambda kv: kv.split("="), dest='customModelParams')
+
+    args = parser.parse_args()
+
+
+    os.makedirs("cache/{}".format(args.name), exist_ok=True)
+    os.makedirs("figures/{}".format(args.name), exist_ok=True)
+
+    print(args)
+
+    model_params = {
+        # 'orn_input_multiplier': 1, # distribute total poisson rate over 10 indep. processes
+        # Neuron Parameters
+        'C': 289.5 * pF,
+        'gL': 28.95 * nS,
+        'EL': -70 * mV,
+        'Vt': -57 * mV,
+        'Vr': -70 * mV,
+        'tau_ref': 5 * ms,
+        # APL parameters
+        'VtAPL': -50 * mV,
+        'VrAPL': -55 * mV,
+        'ELAPL': -55 * mV,
+        'gLAPL': 0.5 * nS,
+        'CAPL': 10 * pF,
+        'apl_delay': 0.2 * ms,
+        # Synaptic Parameters
+        'Ee': 0 * mV,
+        'Ei': -75 * mV,
+        'EIa': -90 * mV, # reversal potential
+        'tau_syn_e': 2 * ms,
+        'tau_syn_i': 10 * ms,
+        'tau_Ia': 1000 * ms, # adaptation conductance time constant
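+        # note: gen_shared_params() copies the shared values (C, gL, EL, Vt, Vr,
+        # tau_Ia) to per-population keys such as tau_IaORN or tau_IaKC unless they
+        # are set explicitly - so the tau_IaKC entry below wins for the KC population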
+        'tau_IaKC': 50 * ms, # adaptation time constant for KCs
+        # Weights
+        'w0': 1 * nS,
+        # Adaptation Parameters
+        'bORN': 2 * nS,
+        'bKC': 5 * nS,
+        'bLN': 0 * nS,
+        'bPN': 0 * nS,
+        'D': 0.005,
+        'PNperKC': 6, # this will achieve ~8% KC activity
+        'wPNKC': 14,
+        'wKCAPL': 3,
+        'wAPLKC': 3
+    }
+
+    if args.customModelParams is not None:
+        model_params.update(args.customModelParams)
+    else:
+        args.customModelParams = {}
+
+    print(model_params)
+
+    trial_ids = []
+    samples = []
+    samples_alt = []
+    odor_ids = []
+    rewards = []
+    stim_times = []
+    durations = []
+    warmup = args.warmup_time
+
+    worker_args = [(id, args.name, seed, model_params, args, id in list(range(5))) for id,seed in enumerate(np.random.randint(142, size=args.N))]
+    with ProcessPoolExecutor(max_workers=args.n_cpu) as executor:
+        #result = executor.map(worker, worker_args)
+        for params, result in zip(worker_args, executor.map(worker, worker_args)):
+            task_id,reward,sp_data,pulse_times,duration = result
+            rewards.append(reward)
+            trial_ids.append(task_id)
+            trial_sp = []
+            for sp in sp_data.KC.spike_trains_aligned:
+                sp_times = filter(lambda s: s >= 0.0, sp) # only spikes AFTER warmup
+                trial_sp.append(list(sp_times))
+
+            samples.append(trial_sp)
+            odor_ids.append(args.odor_ids)
+            samples_alt.append(dict({'t': sp_data.KC.t_aligned, 'i': sp_data.KC.i}))
+            durations.append(duration)
+            stim_times.append(pulse_times)
+            print("{} finished - avg. duration: {}".format(task_id, np.array(durations).mean()))
+
+
+    output = {
+        'trial_ids': trial_ids,
+        'targets': rewards,
+        'odor_ids': odor_ids,
+        'stimulus_times': stim_times,
+        'trials': samples,
+        'trials_tuples': samples_alt,
+        'T_trial': args.T,
+        'N_trials': len(rewards)
+    }
+
+    scpio.savemat(args.outfile, {'data': output, 'args': args})
+    print("saved to MATLAB file: {}".format(args.outfile))
+    npzFile = args.outfile[:-4] + ".npz"
+    np.savez(npzFile, data=output, args=args)
+    print("saved to NPZ file: {}".format(npzFile))
\ No newline at end of file
diff --git a/mkDataSet_DrosoCustomProtocol.py b/mkDataSet_DrosoCustomProtocol.py
new file mode 100644
index 0000000..150207f
--- /dev/null
+++ b/mkDataSet_DrosoCustomProtocol.py
@@ -0,0 +1,398 @@
+from olnet import run_sim, save_sim, save_sim_hdf5
+from olnet.tuning import get_orn_tuning, get_receptor_tuning, create_stimulation_matrix, gen_shot_noise, combine_noise_with_protocol
+from brian2 import *
+import numpy as np
+import olnet.models.droso_mushroombody_apl as droso_mb
+from olnet import AttrDict
+import sys,argparse,time, os
+import random
+import ast
+
+def current_milli_time():
+    return int(round(time.time() * 1000))
+
+
+def gen_pulsed_stimulus(T, dt, odor_idx, pulse_duration=(0.1, 0.5), n_stim=1):
+    """
+    generate a single pulsed stimulus that is randomly positioned within [0,T].
+    Pulse duration is randomly sampled from the given bounds.
+    Returns tuple of (stimulus,n_pulses,pulse_times)
+    :param T:
+    :param dt:
+    :param odor_idx: index of specific odor to generate stimulus for
+    :param pulse_duration:
+    :param n_stim: number of total stimulation types (e.g. odors)
+    :return: tuple
+    """
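+    # Example (hypothetical numbers): one pulse of odor 5 in a 3 s window at
+    # dt = 0.5 ms, with 52 stimulation channels (one per glomerulus):
+    #   X, n_pulses, t_pulses = gen_pulsed_stimulus(3.0, 5e-4, 5, n_stim=52)
+    #   X has shape (52, 6000) and t_pulses[5] holds the pulse onset in seconds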
+    # randomly sample pulse duration
+    pulses = np.random.uniform(pulse_duration[0], pulse_duration[1], size=100)
+    pulse_bins = (pulses / dt).astype(int).tolist()
+    random.shuffle(pulse_bins)
+    X = np.zeros((n_stim, int(T / dt)))
+
+    # randomly position the pulse (uniformly; a poisson variant is left commented out below)
+    n_bins = int(T / dt)
+    spaced_pulses = [[1] * pulse_bins[0]]
+    start_bins = np.random.randint(0, n_bins-int(0.1/dt), size=200).astype(int).tolist()
+    #start_bins = (np.random.poisson(int(T * 100), size=10) / 100 / dt).astype(np.int)
+    #print("start_bins: {}".format(start_bins))
+    start_bins = list(filter(lambda p: (p+pulse_bins[0]) < (n_bins-5), start_bins))
+    random.shuffle(start_bins)
+    #print("start_bins: {} | {} sec".format(start_bins, np.array(start_bins)*dt))
+    start_bin = start_bins[0]
+    print("pulse offset: {}sec | duration: {}sec".format(start_bin * dt, pulse_bins[0] * dt))
+
+    pulse_times = [[] for _ in range(n_stim)]
+    for k, s in enumerate(spaced_pulses):
+        X[odor_idx, start_bin:start_bin + len(s)] = s
+        pulse_times[odor_idx].append(start_bin * dt)
+
+    X_prime = np.c_[np.zeros((n_stim, 1)), X]
+    return X, [len(np.where(np.diff(X_prime[n, :]) == 1)[0]) for n in range(n_stim)], pulse_times
+
+def gen_shared_params(model_params, params, neuron_models):
+    # automatically create shared model params for each neuron model type
+    for v in params:
+        for n in neuron_models:
+            k = '{}{}'.format(v, n)
+            if k not in model_params:
+                model_params.update({k: model_params[v]})
+    return model_params
+
+
+def gen_shotnoise_input(protocol_dt, dt, warmup_time, pulse_stim, n_receptors, N_glo, odor_ids, ORNperGlo, receptors_per_odor, stimulus_rate, bg_rate, stim_scale=0.003, bg_scale=0.001):
+    #print("pulse_stim: {}".format(pulse_stim.shape))
+    simtime = ((warmup_time/second) + (protocol_dt/second) * pulse_stim.shape[1]) * second
+    print("ORNs={} glomeruli={} ORNperGlu={} receptors_per_odor={}, n_receptors={}, odor_ids={}, simtime={}".format(n_receptors, N_glo, ORNperGlo,
+                                                                                                                    receptors_per_odor,
+                                                                                                                    n_receptors, odor_ids,simtime))
+    pad_n = int(warmup_time / protocol_dt)
+    y = []
+    for odor_idx in odor_ids:
+        y.append(np.array([0]*pad_n + pulse_stim[odor_idx, :].tolist())) # odor A pulse stim
+    #y2 = np.array([0]*pad_n + pulse_stim[1, :].tolist()) # odor B pulse stim
+
+    ORN_noise = gen_shot_noise(stimulus_rate, simtime / second, tau=0.6, dt=dt/second, dim=n_receptors, scale=stim_scale)
+
+    S = get_receptor_tuning(N_glo, N_glo, receptors_per_odor, peak_rate=stimulus_rate) / stimulus_rate
+    M = get_orn_tuning(S, n_orns=ORNperGlo)
+    M_prime = M[odor_ids, :, :]
+
+    A = (gen_shot_noise(bg_rate, simtime / second, tau=0.5, dt=dt / second, dim=n_receptors, scale=bg_scale).values)
+
+    # subsequently add/superimpose the stimuli of all odors
+    #print("y={}".format(len(y)))
+    for i,odor_idx in enumerate(odor_ids):
+        y1 = combine_noise_with_protocol(TimedArray(y[i], dt=protocol_dt), ORN_noise)
+        A += (y1.values * np.tile(M_prime[i], (1, y1.values.shape[0])).T)
+
+
+    print("created stimulation matrix for {} odors: {}".format(len(odor_ids), A.shape))
+    stimulus = TimedArray(A * uA, dt=dt)
+    print("created stimulus TimedArray: {} warmup: {}".format(stimulus.values.shape, warmup_time))
+    return simtime, stimulus, M_prime
+
+
+
+def run_model(model_params, N_glo, ORNperGlo, N_KC, simtime, stimulus, dt = 0.1 * ms, network_seed=42):
+    # use fixed random seed to build same network arch.
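+    # (the per-trial RNG seed set in worker() only shapes the stimulus; reseeding
+    # with network_seed here keeps the synaptic wiring identical across trials)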
+ np.random.seed(network_seed) + seed(network_seed) + + model_params = gen_shared_params(model_params, ['C', 'gL', 'EL', 'Vt', 'Vr', 'tau_Ia'], ['ORN', 'PN', 'LN', 'KC', 'APL']) + model_params.update({'stimulus': stimulus}) + + + NG, c = droso_mb.network(model_params, + None, + droso_mb.model_ORN, + droso_mb.model_PN, + droso_mb.model_LN, + droso_mb.model_KC, + droso_mb.model_APL, + wORNinputORN=1 * model_params['w0'], + wORNPN=1.1282 * model_params['w0'], + wORNLN=1 * model_params['w0'], + wLNPN=2.5 * model_params['w0'], # enable lateral inhib. + wPNKC=double(model_params['wPNKC']) * model_params['w0'], + wKCAPL=double(model_params['wKCAPL']) * model_params['w0'], + wAPLKC=double(model_params['wAPLKC']) * model_params['w0'], + N_glu=N_glo, + ORNperGlu=ORNperGlo, + N_KC=N_KC, + PNperKC=double(model_params['PNperKC']), + V0min=model_params['EL'], + V0max=model_params['Vt'], + apl_delay=model_params['apl_delay']) + + var_mons = [ + ('ORN', ('v', 'g_i', 'g_e'), [360]), + ('PN', ('v', 'g_i', 'g_e'), [15]), + ('LN', ('v', 'g_i', 'g_e'), [15]), + ('APL', ('v', 'g_i', 'g_e'), [0]) + ] + + return run_sim(model_params, NG, c, simtime, sim_dt=dt, + spike_monitors=['ORN', 'PN', 'LN', 'KC', 'APL'], + rate_monitors=['ORN', 'PN', 'LN', 'KC', 'APL'], + state_monitors=var_mons) + + + + +def worker(args): + (id, stimulus_id, network_id, name, seed, odor_ids, stimulus_protocol, model_params, args, plot) = args + + N_odors = 1 + a = np.array(stimulus_protocol) + stim_protocols = [stimulus_protocol] + T = len(stimulus_protocol) * args.dt + + if len(a.shape) > 1: + N_odors = a.shape[0] + T = a.shape[1] * args.dt + stim_protocols = stimulus_protocol + else: + a = np.array(stim_protocols) + + assert N_odors == len(odor_ids), "number of stimulation protocols ({}) must be equal to number of odors ({})".format(N_odors, len(odor_ids)) + + np.random.seed(seed) + + + t_start = current_milli_time() + N_glo = 52 + ORNperGlo = (2080 // N_glo) # Droso: roughly 2000 ORNs total + N_KC = 2000 # droso: 2000 + n_receptors = N_glo * ORNperGlo # * model_params['orn_input_multiplier'] + receptors_per_odor = 15 + warmup_time = args.warmup_time * second # 2 * second + sim_dt = 0.1 * ms + stim_dt = args.stimulus_dt * ms # 1*ms # time-resolution for stimulus TimedArray + bg_rate = args.bg_rate + stimulus_rate = args.stimulus_rate + stim_noise_scale = args.stim_noise_scale # 0.003 + bg_noise_scale = args.bg_noise_scale #0.001 + + model_params.update({ + 'seed': seed, + 'T': T, + 'stimulus_protocol': stimulus_protocol, + 'stimulus_protocol_dt': args.dt, + 'stim_noise_scale': stim_noise_scale, + 'bg_noise_scale': bg_noise_scale, + 'stim_dt': stim_dt / second, + 'noise_bg_rate': bg_rate, + 'noise_stim_rate': stimulus_rate, + 'N_KC': N_KC, + 'ORNperGlo': ORNperGlo, + 'n_receptors': n_receptors, + 'receptors_per_odor': receptors_per_odor + }) + + print("worker[{} stimulus_id {} (n_odors: {}) network_id: {}] started ...".format(id, stimulus_id, N_odors, network_id)) + + X = np.zeros((N_glo, a.shape[1])) + for j,odor_idx in enumerate(odor_ids): + X[odor_idx, :] = stim_protocols[j] + + X_prime = np.c_[np.zeros((N_glo, 1)), X] + rewards = [len(np.where(np.diff(X_prime[n, :]) == 1)[0]) for n in range(N_glo)] + pulse_times = [(np.where(X[n,:] == 1)[0] * args.dt).tolist() for n in range(N_glo)] + + simtime, stimulus, M = gen_shotnoise_input(args.dt * second, stim_dt, warmup_time, X, n_receptors, N_glo, odor_ids, + ORNperGlo, receptors_per_odor, + stimulus_rate, bg_rate, stim_scale=stim_noise_scale, + bg_scale=bg_noise_scale) + + + spikemons, 
pop_mons, state_mons, var_mons = run_model(model_params, N_glo, ORNperGlo, N_KC, (T + args.warmup_time) * second, stimulus, sim_dt, network_id)
+    t_stop = current_milli_time()
+    print("worker[{}] finished (took {} sec)".format(id, (t_stop-t_start)/1000))
+
+    model_params.update({'rewards': rewards})
+    model_params.update({'stimulation_times': pulse_times})
+    model_params.pop('stimulus', None) # TimedArray is not pickle-able - remove it
+
+    if plot:
+        fileName = "sim-{}-stim_id-{}-net_id-{}-seed-{}".format(id,stimulus_id, network_id, seed)
+        data = save_sim("cache/{}/{}.npz".format(name, fileName),
+                        model_params,
+                        spikemons, pop_mons, state_mons, simtime, warmup_time, sim_dt,
+                        stimulus=np.flipud(stimulus.values.T),
+                        tuning=M,
+                        stimulus_times=pulse_times,
+                        n_receptors=n_receptors,
+                        stimulus_protocol=stimulus_protocol
+                        )
+
+    if plot:
+        from olnet.plotting.figures import figure1
+        f = figure1(data)
+        f.savefig("figures/{}/{}.png".format(name, fileName), dpi=f.dpi)
+        print("worker[{}] saved figure: figures/{}/{}.png".format(id, name, fileName))
+
+    # align spiketrains to warmup offset
+    sp_trains_aligned = {}
+    for k,v in spikemons.items():
+        trial_sp = []
+        for s in v.spike_trains().values():
+            sp_times = (s / second) - args.warmup_time
+            trial_sp.append(list(sp_times))
+        sp_trains_aligned[k] = trial_sp
+
+    spikeData = AttrDict({
+        k: AttrDict({'count': v.count[:],
+                     't': (v.t[:] / second),
+                     't_aligned': (v.t[:] / second) - args.warmup_time,
+                     'i': v.i[:],
+                     'spike_trains': v.spike_trains(),
+                     'spike_trains_aligned': sp_trains_aligned[k]}) for k, v in spikemons.items()
+    })
+
+    return (id, stimulus_id, network_id, seed, odor_ids, rewards, spikeData, pulse_times, T, (t_stop-t_start))
+
+
+def arg_as_list(s):
+    v = ast.literal_eval(s)
+    if type(v) is not list:
+        raise argparse.ArgumentTypeError("Argument \"%s\" is not a list" % (s))
+    return v
+
+if __name__ == "__main__":
+    from concurrent.futures import ProcessPoolExecutor
+    import scipy.io as scpio
+
+    argv = sys.argv[1:]
+
+    parser = argparse.ArgumentParser(description='Generate data set of KC spike-times using drosoMB model and custom stimulation protocol')
+
+    parser.add_argument('-n', '--name', type=str, nargs='?', help='name of data-set')
+    parser.add_argument('-N', '--N', type=int, nargs='?', help = 'number of samples to generate for each stimulation', default=5)
+    parser.add_argument('--network_seeds', type=arg_as_list, nargs='?', help='RNG seed(s) used to build network model. If the list has > 1 entry, multiple indep. networks will be simulated', default=[42]) # pass as a list literal, e.g. "[42,43]"
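+    # example (hypothetical invocation, two odors with interleaved pulses):
+    #   python mkDataSet_DrosoCustomProtocol.py -n demo -N 5 -dt 0.5 \
+    #     --protocols "[[1,1,0,0],[0,1,1,0]]" --odor_ids "[5,15]" -o data/demo.mat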
+    parser.add_argument('--protocols', type=arg_as_list, action='append', help='stimulation protocols to run (single odor: [1,0,0], 2 odors: [[1,1,0,0],[0,1,1,0]])', required=True)
+    parser.add_argument('--odor_ids', type=arg_as_list, action='append', help='odor_ids used in each stimulus (single odor: [0], 2 odors: [5, 15])', required=True)
+    parser.add_argument('--n_cpu', type=int, nargs='?', help = 'no of CPUs to use for parallel simulations', default=1)
+    parser.add_argument('--bg_rate', type=int, nargs='?', help = 'background shot noise poisson rate', default=300)
+    parser.add_argument('--stimulus_rate', type=int, nargs='?', help = 'stimulus shot noise poisson rate', default=300)
+    parser.add_argument('-dt', type=float, nargs='?', help = 'dt of stimulation protocol (in seconds)', default=0.5)
+    parser.add_argument('--warmup_time', type=float, nargs='?', help = 'duration of warmup phase (in seconds)', default=2)
+    parser.add_argument('--stimulus_dt', type=float, nargs='?', help = 'dt of stimulus TimedArray (in ms)', default=0.5)
+    parser.add_argument('--stim_noise_scale', type=float, nargs='?', help = 'scale of shot-noise for stimulus', default=0.004)
+    parser.add_argument('--bg_noise_scale', type=float, nargs='?', help = 'scale of shot-noise for background activity', default=0.0055) # use 0.0056 for more noise
+    parser.add_argument('-o', '--outfile', nargs='?', type=str, help = 'output filename for MAT file')
+    parser.add_argument("--modelParams", action='append', type=lambda kv: kv.split("="), dest='customModelParams')
+
+    args = parser.parse_args()
+
+    os.makedirs("cache/{}".format(args.name), exist_ok=True)
+    os.makedirs("figures/{}".format(args.name), exist_ok=True)
+
+    if os.path.isfile(args.outfile):
+        print("skipped - cache file exists at: {}".format(args.outfile))
+        exit(0)
+
+    print(args)
+
+
+    model_params = {
+        # 'orn_input_multiplier': 1, # distribute total poisson rate over 10 indep. processes
+        # Neuron Parameters
+        'C': 289.5 * pF,
+        'gL': 28.95 * nS,
+        'EL': -70 * mV,
+        'Vt': -57 * mV,
+        'Vr': -70 * mV,
+        'tau_ref': 5 * ms,
+        # APL parameters
+        'VtAPL': -50 * mV,
+        'VrAPL': -55 * mV,
+        'ELAPL': -55 * mV,
+        'gLAPL': 0.5 * nS,
+        'CAPL': 10 * pF,
+        'apl_delay': 0.2 * ms,
+        # Synaptic Parameters
+        'Ee': 0 * mV,
+        'Ei': -75 * mV,
+        'EIa': -90 * mV, # reversal potential
+        'tau_syn_e': 2 * ms,
+        'tau_syn_i': 10 * ms,
+        'tau_Ia': 1000 * ms, # adaptation conductance time constant
+        'tau_IaKC': 50 * ms, # adaptation time constant for KCs
+        # Weights
+        'w0': 1 * nS,
+        # Adaptation Parameters
+        'bORN': 2 * nS,
+        'bKC': 5 * nS,
+        'bLN': 0 * nS,
+        'bPN': 0 * nS,
+        'D': 0.005,
+        'PNperKC': 6, # this will achieve ~8% KC activity
+        'wPNKC': 14,
+        'wKCAPL': 3,
+        'wAPLKC': 3
+    }
+
+    if args.customModelParams is not None:
+        model_params.update(args.customModelParams)
+    else:
+        args.customModelParams = {}
+
+    print(model_params)
+
+    samples = []
+    trial_ids = []
+    samples_alt = []
+    rewards = []
+    stimulus_ids = []
+    odor_ids = []
+    network_ids = []
+    stim_times = []
+    trial_durations = []
+    durations = []
+    warmup = args.warmup_time
+
+    worker_args = []
+    for stim_id,(stim_protocol,protocol_odor_ids) in enumerate(zip(args.protocols,args.odor_ids)):
+        for network_id in args.network_seeds:
+            worker_args.extend([(id, stim_id, network_id, args.name, seed, protocol_odor_ids, stim_protocol, model_params, args, id in list(range(5))) for id,seed in enumerate(np.random.randint(142, size=args.N))])
+
+    with ProcessPoolExecutor(max_workers=args.n_cpu) as executor:
+        #result = executor.map(worker, worker_args)
+        for params, result in zip(worker_args, executor.map(worker, worker_args)):
+            task_id,stim_id,net_id,_,protocol_odor_ids,reward,sp_data,pulse_times,T,duration = result
+            rewards.append(reward)
+            network_ids.append(net_id)
+            odor_ids.append(protocol_odor_ids)
+            stimulus_ids.append(stim_id)
+            trial_ids.append(task_id)
+            trial_durations.append(T)
+            trial_sp = []
+            for sp in sp_data.KC.spike_trains_aligned:
+                sp_times = filter(lambda s: s >= 0.0, sp) # only spikes AFTER warmup
+                trial_sp.append(list(sp_times))
+
+            samples.append(trial_sp)
+            samples_alt.append(dict({'t': sp_data.KC.t_aligned, 'i': sp_data.KC.i}))
+            durations.append(duration)
+            stim_times.append(pulse_times)
+            print("{} finished - avg. duration: {}".format(task_id, np.array(durations).mean()))
+
+
+    output = {
+        'trial_ids': trial_ids,
+        'targets': rewards,
+        'odor_ids': odor_ids,
+        'stimulus_ids': stimulus_ids,
+        'network_ids': network_ids,
+        'stimulus_times': stim_times,
+        'trials': samples,
+        'trials_tuples': samples_alt,
+        'T_trials': trial_durations
+    }
+
+    scpio.savemat(args.outfile, {'data':output, 'args': args})
+    print("saved to MATLAB file: {}".format(args.outfile))
+    npzFile = args.outfile[:-4] + ".npz"
+    np.savez(npzFile, data=output, args=args)
+    print("saved to NPZ file: {}".format(npzFile))
\ No newline at end of file
diff --git a/mkDataSet_DrosoLabCondition.py b/mkDataSet_DrosoLabCondition.py
new file mode 100644
index 0000000..c21acba
--- /dev/null
+++ b/mkDataSet_DrosoLabCondition.py
@@ -0,0 +1,364 @@
+from olnet import run_sim, save_sim, save_sim_hdf5
+from olnet.tuning import get_orn_tuning, get_receptor_tuning, create_stimulation_matrix, gen_shot_noise, combine_noise_with_protocol
+from brian2 import *
+import numpy as np
+import olnet.models.droso_mushroombody_apl as droso_mb
+from olnet import AttrDict
+import sys,argparse,time, os
+import traceback
+import random
+
+N_samples = 800
+
+def current_milli_time():
+    return int(round(time.time() * 1000))
+
+
+def gen_pulsed_stimulus(T, dt, odor_idx, pulse_duration=(0.1, 0.5), n_stim=1):
+    """
+    generate a single pulsed stimulus that is randomly positioned within [0,T].
+ Pulse duration is randomly sampled from given bounds + Returns tuple of (stimulus,n_pulses,pulse_times) + :param T: + :param dt: + :param odor_idx: index of specific odor to generate stimulus for + :param pulse_duration: + :param n_stim: number of total stimulation types (e.g. odors) + :return: tuple + """ + # randomly sample pulse duration + pulses = np.random.uniform(pulse_duration[0], pulse_duration[1], size=100) + pulse_bins = (pulses / dt).astype(np.int).tolist() + random.shuffle(pulse_bins) + X = np.zeros((n_stim, int(T / dt))) + + # randomly position pulse - use poisson to have more variability in positioning + n_bins = int(T / dt) + spaced_pulses = [[1] * pulse_bins[0]] + start_bins = np.random.randint(0, n_bins-int(0.1/dt), size=200).astype(np.int).tolist() + #start_bins = (np.random.poisson(int(T * 100), size=10) / 100 / dt).astype(np.int) + #print("start_bins: {}".format(start_bins)) + start_bins = list(filter(lambda p: (p+pulse_bins[0]) < (n_bins-5), start_bins)) + random.shuffle(start_bins) + #print("start_bins: {} | {} sec".format(start_bins, np.array(start_bins)*dt)) + start_bin = start_bins[0] + print("pulse offset: {}sec | duration: {}sec".format(start_bin * dt, pulse_bins[0] * dt)) + + pulse_times = [[] for _ in range(n_stim)] + for k, s in enumerate(spaced_pulses): + X[odor_idx, start_bin:start_bin + len(s)] = s + pulse_times[odor_idx].append(start_bin * dt) + + X_prime = np.c_[np.zeros((n_stim, 1)), X] + return X, [len(np.where(np.diff(X_prime[n, :]) == 1)[0]) for n in range(n_stim)], pulse_times + +def gen_shared_params(model_params, params, neuron_models): + # autom. create shared model params for each neuron model type + for v in params: + for n in neuron_models: + k = '{}{}'.format(v, n) + if k not in model_params: + model_params.update({k: model_params[v]}) + return model_params + + +def gen_shotnoise_input(dt, warmup_time, pulse_stim, n_odors, odor_idx, n_receptors, N_glo, ORNperGlo, receptors_per_odor, stimulus_rate, bg_rate, stim_scale=0.003, bg_scale=0.001): + #print("pulse_stim: {}".format(pulse_stim.shape)) + simtime = ((warmup_time/second) + (dt/second) * pulse_stim.shape[1]) * second + print("ORNs={} glumeruli={} ORNperGlu={} receptors_per_odor={}, n_receptors={}".format(n_receptors, N_glo, ORNperGlo, + receptors_per_odor, n_receptors)) + pad_n = int(warmup_time / dt) + + y_1 = np.array([0]*pad_n + pulse_stim[odor_idx, :].tolist()) + + ORN_noise = gen_shot_noise(stimulus_rate, simtime / second, tau=0.6, dt=dt/second, dim=n_receptors, scale=stim_scale) + + S = get_receptor_tuning(N_glo, n_odors, receptors_per_odor, peak_rate=stimulus_rate) / stimulus_rate + M = get_orn_tuning(S, n_orns=ORNperGlo) + M_prime = M[[odor_idx], :, :] + + y1 = combine_noise_with_protocol(TimedArray(y_1, dt=dt), ORN_noise) + + A = (gen_shot_noise(bg_rate, simtime / second, tau=0.5, dt=dt/second, dim=n_receptors, scale=bg_scale).values) \ + + (y1.values * np.tile(M_prime[0], (1, y1.values.shape[0])).T) + + print("created stimulation matrix odor_idx={}: {}".format(odor_idx, A.shape)) + stimulus = TimedArray(A * uA, dt=dt) + print("created stimulus TimedArray: {} warmup: {}".format(stimulus.values.shape, warmup_time)) + return simtime, stimulus, M_prime + + +def run_model(model_params, N_glo, ORNperGlo, N_KC, simtime, stimulus, dt = 0.1 * ms, network_seed=42): + # use fixed random seed to build same network arch. 
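+ # note: both RNGs are seeded deliberately - numpy's generator produces the random initial membrane voltages drawn in network(), while Brian2's seed() controls Brian's own RNG used e.g. by the probabilistic connect(p=...) calls; seeding only one of the two would leave residual run-to-run variability in the architecture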
+ np.random.seed(network_seed) + seed(network_seed) + + model_params = gen_shared_params(model_params, ['C', 'gL', 'EL', 'Vt', 'Vr', 'tau_Ia'], ['ORN', 'PN', 'LN', 'KC', 'APL']) + model_params.update({'stimulus': stimulus}) + + + NG, c = droso_mb.network(model_params, + None, + droso_mb.model_ORN, + droso_mb.model_PN, + droso_mb.model_LN, + droso_mb.model_KC, + droso_mb.model_APL, + wORNinputORN=1 * model_params['w0'], + wORNPN=1.1282 * model_params['w0'], + wORNLN=1 * model_params['w0'], + wLNPN=2.5 * model_params['w0'], # enable lateral inhib. + wPNKC=double(model_params['wPNKC']) * model_params['w0'], + wKCAPL=double(model_params['wKCAPL']) * model_params['w0'], + wAPLKC=double(model_params['wAPLKC']) * model_params['w0'], + N_glu=N_glo, + ORNperGlu=ORNperGlo, + N_KC=N_KC, + PNperKC=double(model_params['PNperKC']), + V0min=model_params['EL'], + V0max=model_params['Vt'], + apl_delay=model_params['apl_delay']) + + var_mons = [ + ('ORN', ('v', 'g_i', 'g_e'), [360]), + ('PN', ('v', 'g_i', 'g_e'), [15]), + ('LN', ('v', 'g_i', 'g_e'), [15]), + ('APL', ('v', 'g_i', 'g_e'), [0]) + ] + + return run_sim(model_params, NG, c, simtime, sim_dt=dt, + spike_monitors=['ORN', 'PN', 'LN', 'KC', 'APL'], + rate_monitors=['ORN', 'PN', 'LN', 'KC', 'APL'], + state_monitors=var_mons) + + + + +def worker(args): + (id, name, seed, odor_id, N_odors, model_params, args, plot) = args + np.random.seed(seed) + + t_start = current_milli_time() + N_glo = 52 + ORNperGlo = (2080 // N_glo) # Droso: roughly 2000 ORNs total + N_KC = 2000 # droso: 2000 + n_receptors = N_glo * ORNperGlo # * model_params['orn_input_multiplier'] + receptors_per_odor = 15 + warmup_time = args.warmup_time * second # 2 * second + sim_dt = 0.1 * ms + stim_dt = args.stimulus_dt * ms # 1*ms # time-resolution for stimulus TimedArray + bg_rate = args.bg_rate + stimulus_rate = args.stimulus_rate + T = args.T + stim_noise_scale = args.stim_noise_scale # 0.003 + bg_noise_scale = args.bg_noise_scale #0.001 + + model_params.update({ + 'seed': seed, + 'T': T, + 'odor_id': odor_id, + 'min_pulse_duration': args.min_pulse_duration, + 'max_pulse_duration': args.max_pulse_duration, + 'stim_noise_scale': stim_noise_scale, + 'bg_noise_scale': bg_noise_scale, + 'stim_dt': stim_dt / second, + 'noise_bg_rate': bg_rate, + 'noise_stim_rate': stimulus_rate, + 'N_KC': N_KC, + 'ORNperGlo': ORNperGlo, + 'n_receptors': n_receptors, + 'receptors_per_odor': receptors_per_odor + }) + + print("worker[{} odor_id {} / {}] started ...".format(id, odor_id, N_odors)) + pulse_stim, rewards = None, None + + # loop - to catch rare cases where stimulus could not be generated + while (pulse_stim is None): + try: + pulse_stim, rewards, pulse_times = gen_pulsed_stimulus(T, stim_dt / second, odor_id, + pulse_duration=(args.min_pulse_duration, args.max_pulse_duration), + n_stim=N_glo) + except Exception as e: + traceback.print_exc() + pulse_stim, rewards, pulse_times = None, None, [] + + simtime, stimulus, M = gen_shotnoise_input(stim_dt, warmup_time, pulse_stim, N_glo, odor_id, n_receptors, N_glo, ORNperGlo, receptors_per_odor, + stimulus_rate, bg_rate, stim_scale=stim_noise_scale, bg_scale=bg_noise_scale) + + + spikemons, pop_mons, state_mons, var_mons = run_model(model_params, N_glo, ORNperGlo, N_KC, (T + args.warmup_time) * second, stimulus, sim_dt, args.network_seed) + t_stop = current_milli_time() + print("worker[{}] finished (took {} sec)".format(id, (t_stop-t_start)/1000)) + + model_params.update({'rewards': rewards}) + model_params.update({'stimulation_times': pulse_times}) + 
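# note: --modelParams overrides merged into model_params (in __main__) are raw strings from kv.split("="), e.g. "wPNKC=20" yields {'wPNKC': '20'}; this is why run_model above wraps wPNKC, wKCAPL, wAPLKC and PNperKC in double() (numpy's float64, re-exported by brian2) before scaling by w0 +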
model_params.pop('stimulus', None) # TimedArray is not pickle-able - remove it + + if plot: + fileName = "sim-odor-{}-{}-{}".format(odor_id, id, seed) + data = save_sim("cache/{}/{}.npz".format(name, fileName), + model_params, + spikemons, pop_mons, state_mons, simtime, warmup_time, sim_dt, + stimulus=np.flipud(stimulus.values.T), + tuning=M, + stimulus_times=pulse_times, + n_receptors=n_receptors, + odor_id=odor_id + ) + + if plot: + from olnet.plotting.figures import figure1 + f = figure1(data) + f.savefig("figures/{}/{}.png".format(name, fileName), dpi=f.dpi) + print("worker[{}] saved figure: figures/{}/{}.png".format(id, name, fileName)) + + # align spiketrains to warmup offset + sp_trains_aligned = {} + for k,v in spikemons.items(): + trial_sp = [] + for s in v.spike_trains().values(): + sp_times = (s / second) - args.warmup_time + trial_sp.append(list(sp_times)) + sp_trains_aligned[k] = trial_sp + + spikeData = AttrDict({ + k: AttrDict({'count': v.count[:], + 't': (v.t[:] / second), + 't_aligned': (v.t[:] / second) - args.warmup_time, + 'i': v.i[:], + 'spike_trains': v.spike_trains(), + 'spike_trains_aligned': sp_trains_aligned[k]}) for k, v in spikemons.items() + }) + + return (id, odor_id, rewards, spikeData, pulse_times, (t_stop-t_start)) + + + +if __name__ == "__main__": + from concurrent.futures import ProcessPoolExecutor + import scipy.io as scpio + + argv = sys.argv[1:] + + parser = argparse.ArgumentParser(description='Generate data set of KC spike-times using drosoMB model and single pulse stimulus of single odors') + + parser.add_argument('-n', '--name', type=str, nargs='?', help='name of data-set') + parser.add_argument('-N', '--N', type=int, nargs='?', help = 'number of samples to generate for each odor', default=10) + parser.add_argument('--network_seed', type=int, nargs='?', help='RNG seed used to build network model', default=42) + parser.add_argument('--odor_ids', type=int, action='append', help='indices of different odors to use', required=True) + parser.add_argument('--n_cpu', type=int, nargs='?', help = 'no of CPUs to use for parallel simulations', default=4) + parser.add_argument('--bg_rate', type=int, nargs='?', help = 'background shot noise poisson rate', default=300) + parser.add_argument('--stimulus_rate', type=int, nargs='?', help = 'stimulus shot noise poisson rate', default=300) + parser.add_argument('-T', type=float, nargs='?', help = 'stimulus duration (in seconds)', default=5) + parser.add_argument('--warmup_time', type=float, nargs='?', help = 'duration of warmup phase (in seconds)', default=2) + parser.add_argument('--stimulus_dt', type=float, nargs='?', help = 'dt of stimulus TimedArray (in ms)', default=0.5) + parser.add_argument('--max_pulse_duration', type=float, nargs='?', help = 'max. duration of a single pulse (in seconds)', default=1.0) + parser.add_argument('--min_pulse_duration', type=float, nargs='?', help = 'min.
duration of a single pulse (in seconds)', default=0.1) + parser.add_argument('--stim_noise_scale', type=float, nargs='?', help = 'scale of shot-noise for stimulus', default=0.004) + parser.add_argument('--bg_noise_scale', type=float, nargs='?', help = 'scale of shot-noise for background activity', default=0.0055) # use 0.0056 for more noise + parser.add_argument('-o', '--outfile', nargs='?', type=str, help = 'output filename for MAT file') + parser.add_argument("--modelParams", action='append', type=lambda kv: kv.split("="), dest='customModelParams') + + args = parser.parse_args() + + + os.makedirs("cache/{}".format(args.name), exist_ok=True) + os.makedirs("figures/{}".format(args.name), exist_ok=True) + + print(args) + + model_params = { + # 'orn_input_multiplier': 1, # distribute total poisson rate over 10 indep. processes + # Neuron Parameters + 'C': 289.5 * pF, + 'gL': 28.95 * nS, + 'EL': -70 * mV, + 'Vt': -57 * mV, + 'Vr': -70 * mV, + 'tau_ref': 5 * ms, + # APL parameters + 'VtAPL': -50 * mV, + 'VrAPL': -55 * mV, + 'ELAPL': -55 * mV, + 'gLAPL': 0.5 * nS, + 'CAPL': 10 * pF, + 'apl_delay': 0.2 * ms, + # Synaptic Parameters + 'Ee': 0 * mV, + 'Ei': -75 * mV, + 'EIa': -90 * mV, # reversal potential + 'tau_syn_e': 2 * ms, + 'tau_syn_i': 10 * ms, + 'tau_Ia': 1000 * ms, # adaptation conduct. time constant + 'tau_IaKC': 50 * ms, # adaptation time constant for KCs + # Weights + 'w0': 1 * nS, + # Adaptation Parameters + 'bORN': 2 * nS, + 'bKC': 5 * nS, + 'bLN': 0 * nS, + 'bPN': 0 * nS, + 'D': 0.005, + 'PNperKC': 6, # this will achieve ~8% KC activity + 'wPNKC': 14, + 'wKCAPL': 3, + 'wAPLKC': 3 + } + + if args.customModelParams is not None: + model_params.update(args.customModelParams) + else: + args.customModelParams = {} + + print(model_params) + + samples = [] + trial_ids = [] + samples_alt = [] + rewards = [] + odor_ids = [] + stim_times = [] + durations = [] + warmup = args.warmup_time + + worker_args = [] + for odor_id in args.odor_ids: + worker_args.extend([(id, args.name, seed, int(odor_id), len(args.odor_ids), model_params, args, id in list(range(5))) for id,seed in enumerate(np.random.randint(142, size=args.N))]) + + with ProcessPoolExecutor(max_workers=args.n_cpu) as executor: + #result = executor.map(worker, worker_args) + for params, result in zip(worker_args, executor.map(worker, worker_args)): + task_id,odor_id,reward,sp_data,pulse_times,duration = result + rewards.append(reward) + odor_ids.append(odor_id) + trial_ids.append(task_id) + trial_sp = [] + for sp in sp_data.KC.spike_trains_aligned: + sp_times = filter(lambda s: s >= 0.0, sp) # only spikes AFTER warmup + trial_sp.append(list(sp_times)) + + samples.append(trial_sp) + samples_alt.append(dict({'t': sp_data.KC.t_aligned, 'i': sp_data.KC.i})) + durations.append(duration) + stim_times.append(pulse_times) + print("{} finished - avg.
duration: {}".format(task_id, np.array(durations).mean())) + + + output = { + 'trial_ids': trial_ids, + 'targets': rewards, + 'odor_ids': odor_ids, + 'stimulus_times': stim_times, + 'trials': samples, + 'trials_tuples': samples_alt, + 'T_trial': args.T, + 'N_trials': len(rewards) + } + + print(args) + scpio.savemat(args.outfile, {'data':output, 'args': args}) + print("saved to MATLAB file: {}".format(args.outfile)) + npzFile = args.outfile[:-4] + ".npz" + np.savez(npzFile, data=output, args=args) + print("saved to NPZ file: {}".format(npzFile)) diff --git a/olnet/__init__.py b/olnet/__init__.py new file mode 100644 index 0000000..ec64dae --- /dev/null +++ b/olnet/__init__.py @@ -0,0 +1,285 @@ +from brian2 import * +import inspect +#from collections import namedtuple +#__all__ = ["echo", "surround", "reverse"] + +class AttrDict(dict): + """ + dict subclass which allows access to keys as attributes: mydict.myattr + """ + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + +def get_args_from_dict(fn, params): + """ + extract function parameters by name from a dict + :param fn: + :param params: + :return: dict + """ + arg_keys = inspect.signature(fn).parameters.keys() + return dict((k, params[k]) for k in arg_keys if k in params) + +def run_sim(params, NG, + c, + simtime, + sim_dt=0.1 *ms, + rv_timestep=500, + report='text', + rate_monitors=None, + state_monitors=None, + spike_monitors=None, + recvars=None): + """ + run a BRIAN2 simulation given the network architecture as NeuronGroups (NG) and connections/synapses (c) + :param NG: dict of neuron groups with keys == neurons/layers + :param c: dict of connections / synapses + :param simtime: duration of simulation + :param sim_dt: simulation temporal resolution (timestep) + :param rv_timestep: + :param report: + :param rate_monitors: list of neuron group names + :param state_monitors: list of tuple: (neuron group, variables (tuple), indices (optional)) + :param spike_monitors: list of neuron group names + :param recvars: creates StateMonitor to record given variables from ALL neuron groups + :return: + """ + + defaultclock.dt = sim_dt + net = Network(NG.values(), c.values()) + + ### monitors + if spike_monitors is not None: + spmons = [SpikeMonitor(NG[mon], record=True) for mon in spike_monitors] + net.add(spmons) + + if rate_monitors is not None: + rate_mons = [PopulationRateMonitor(NG[mon], name='rate_{}'.format(mon)) for mon in rate_monitors] + net.add(rate_mons) + + if recvars is not None: + var_mons = [StateMonitor(NG[mon], variables=recvars, record=True, dt=rv_timestep) for mon in spike_monitors] + net.add(var_mons) + else: + var_mons = None + + if state_monitors is not None: + state_mons = [StateMonitor(NG[mon[0]], variables=mon[1], record=(True if len(mon) <= 2 else mon[2]), name='state_{}'.format(mon[0])) for mon in state_monitors] + net.add(state_mons) + + # RateKC = PopulationRateMonitor(NG['KC']) + # stateKC = StateMonitor(NG['KC'], 'v', record=True) + # net.add(stateKC) + # net.add(RateKC) + + ### run + net.run(simtime, report=report, namespace=params) + + if spike_monitors is not None: + out_spmons = dict((spike_monitors[i], sm) for i, sm in enumerate(spmons)) + else: + out_spmons = None + + # out_spmons.update(dict(('population_' + spike_monitors[i], sm) for i, sm in enumerate(rate_mons))) + # out_spmons = dict((spike_monitors[i], sm) for i, sm in enumerate(spmons)) + + # prepare rate monitors + if rate_monitors is not None: + out_pop_mons = dict((rate_monitors[i], sm) 
for i, sm in enumerate(rate_mons)) + else: + out_pop_mons = None + + # prepare state monitors + if state_monitors is not None: + out_statemons = dict((state_monitors[i][0], sm) for i, sm in enumerate(state_mons)) + else: + out_statemons = None + + # prepare recvar monitors (this is probably redundant to state_mons ?) + if var_mons is not None: + out_var_mons = dict( + (mon, dict((var, statemon.values) for var, statemon in m.iteritems())) for mon, m in zip(spike_monitors, var_mons)) + else: + out_var_mons = None + + return out_spmons, out_pop_mons, out_statemons, out_var_mons + + +def load_sim(filename): + """ + convenience function to load simulation results from numpy file. + :param filename: + :return: AttrDict + """ + return np.load(filename, allow_pickle=True)['data'][()] + +def save_sim(filename, params, spmons, popmons, statemons, simtime, warmup, dt, **kwargs): + """ + save all results (all monitors, model parameters, ...) from run_sim into a numpy file. + The monitors will be stored in a pickle-able object with the same attributes as the Brian2 monitors (t,i,spike_trains etc..). + All time values (Monitors, simtime ...) are being stored in seconds. + :param filename: + :param params: + :param spmons: + :param popmons: + :param statemons: + :param simtime: + :param warmup: + :param dt: + :param kwargs: custom data to be stored (e.g. stimulus input, tuning profiles ...). All items must be pickle-able. + :return: + """ + + #SpikeMon = namedtuple('SpikeMonitorLike', ['t', 'i', 'spike_trains'], verbose=False) + #PopMon = namedtuple('PopulationRateMonitorLike', ['t', 'rate', 'smooth_rate'], verbose=False) + + stateMonData = dict() + if statemons is not None: + for k, v in statemons.items(): + # TODO: also store the Quantity / unit ? + data = {var: v.variables[var].get_value().T for var in v.record_variables} + data.update({'t': v.t[:] / second}) + stateMonData.update({k: AttrDict(data)}) + + + data = { + 'spikes': AttrDict({k: AttrDict({'count': v.count[:], 't': v.t[:]/second, 'i': v.i[:], 'spike_trains': v.spike_trains()}) for k,v in spmons.items()}), + 'rates': AttrDict({k: AttrDict({'t': v.t[:]/second, 'rate': v.rate[:]/Hz, 'smooth_rate': v.smooth_rate(window='flat', width=50*ms)[:] / Hz}) for k,v in popmons.items()}) if popmons is not None else AttrDict({}), + 'variables': AttrDict(stateMonData), + 'simtime': simtime / second, + 'warmup': warmup / second, + 'dt': dt / second, + 'params': params + } + + data.update(kwargs) + + d = AttrDict(data) + np.savez(filename, data=d) + return d + +def save_sim_hdf5(filename, params, spmons, popmons, statemons, simtime, warmup, dt, **kwargs): + """ + save all results (all monitors, model parameters, ...) from run_sim into a HDF5 file. + The monitors will be stored in a pickle-able object with the same attributes as the Brian2 monitors (t,i,spike_trains etc..). + All time values (Monitors, simtime ...) are being stored in seconds. + :param filename: + :param params: + :param spmons: + :param popmons: + :param statemons: + :param simtime: + :param warmup: + :param dt: + :param kwargs: custom data to be stored (e.g. stimulus input, tuning profiles ...). All items must be pickle-able. + :return: + """ + import h5py + #SpikeMon = namedtuple('SpikeMonitorLike', ['t', 'i', 'spike_trains'], verbose=False) + #PopMon = namedtuple('PopulationRateMonitorLike', ['t', 'rate', 'smooth_rate'], verbose=False) + + def recursively_save_dict_contents_to_group(h5file, path, dic): + """ + .... 
+ """ + for key, item in dic.items(): + if isinstance(item, (np.ndarray, np.int64, np.float64, str, bytes)): + h5file[path + key] = item + elif isinstance(item, dict): + recursively_save_dict_contents_to_group(h5file, path + key + '/', item) + else: + raise ValueError('Cannot save %s type' % type(item)) + + + stateMonData = dict() + for k, v in statemons.items(): + # TODO: also store the Quantity / unit ? + data = {var: v.variables[var].get_value().T for var in v.record_variables} + data.update({'t': v.t[:] / second}) + stateMonData.update({k: AttrDict(data)}) + + + data = { + 'spikes': AttrDict({k: AttrDict({'t': v.t[:]/second, 'i': v.i[:], 'spike_trains': v.spike_trains()}) for k,v in spmons.items()}), + 'rates': AttrDict({k: AttrDict({'t': v.t[:]/second, 'rate': v.rate[:]/Hz, 'smooth_rate': v.smooth_rate(window='flat', width=50*ms)[:] / Hz}) for k,v in popmons.items()}), + 'variables': AttrDict(stateMonData), + 'simtime': simtime / second, + 'warmup': warmup / second, + 'dt': dt / second, + 'params': params + } + + data.update(kwargs) + + f = h5py.File(filename, "w") + recursively_save_dict_contents_to_group(f, '/', data) + f.close() + return f + #return np.savez(filename, data=AttrDict(data)) + +def export_sim_matlab(filename, matFile=None): + """ + export a simulation file saved with save_sim to MAT file + :param filename: sim file + :param matFile: opt. matlab file + :return: + """ + import scipy.io as scpio + + data = load_sim(filename) + + sp_trains_aligned = {} + for k, v in data.spikes.items(): + trial_sp = [] + for s in v.spike_trains.values(): + sp_times = (s / second) - data.warmup + trial_sp.append(list(sp_times)) + sp_trains_aligned[k] = trial_sp + + spikeData = AttrDict({ + k: AttrDict({'count': v.count[:], + 't': v.t[:], + 't_aligned': v.t[:] - data.warmup, + 'i': v.i[:], + 'spike_trains': v.spike_trains.values(), + 'spike_trains_aligned': sp_trains_aligned[k]}) for k, v in data.spikes.items() + }) + + trial_ids = [] + samples = [] + odor_ids = [] + rewards = [] + stim_times = [] + warmup = data.warmup + + trial_sp = [] + for sp in spikeData.KC.spike_trains_aligned: + sp_times = filter(lambda s: s >= 0.0, sp) # only spikes AFTER warmup + trial_sp.append(list(sp_times)) + + samples.append(trial_sp) + stim_times.append(data.stimulus_times) + rewards.append(data.params['rewards']) + try: + odor_ids.append(data.odor_id) + except AttributeError: + pass + + output = { + 'trial_ids': trial_ids, + 'targets': rewards, + 'odor_ids': odor_ids, + 'stimulus_times': stim_times, + 'trials': samples, + 'T_trial': data.params['T'], + 'N_trials': len(rewards) + } + + if matFile is None: + matFile = "{}.mat".format(filename[:-4]) + + + scpio.savemat(matFile, {'data': output}) + print("exported {} to: {}", filename, matFile) + return matFile \ No newline at end of file diff --git a/olnet/models/__init__.py b/olnet/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/olnet/models/droso_mushroombody.py b/olnet/models/droso_mushroombody.py new file mode 100644 index 0000000..f565096 --- /dev/null +++ b/olnet/models/droso_mushroombody.py @@ -0,0 +1,204 @@ +from brian2 import * +import numpy as np +from olnet import get_args_from_dict + + +def model_ORN_input(CORN, gLORN, ELORN, tau_IaORN, Ee, tau_syn_e, Ei, tau_syn_i, EIa, VtORN, VrORN, tau_ref, stimulus): + """ + same ase model_ORN but without adaptation + :return: + """ + neuron_eqs = ''' + dv/dt = (g_l*(E_l-v) + g_e*(E_e-v) - g_i*(E_i-v) + I0 + stimulus(t,i))/C_m : volt (unless refractory) # Ia is the spike triggered 
adaptation + dg_e/dt = -g_e/tau_e : siemens # post-synaptic exc. conductance # synapses + dg_i/dt = -g_i/tau_i : siemens # post-synaptic inh. conductance + I0 : amp + ''' + + neuron_modelORN = dict() + neuron_modelORN['model'] = Equations(neuron_eqs, g_l=gLORN, E_l=ELORN, E_e=Ee, E_i=Ei, E_Ia=EIa, C_m=CORN, tau_e=tau_syn_e,tau_i=tau_syn_i, tau_Ia=tau_IaORN) + neuron_modelORN['threshold'] = 'v > VtORN' + neuron_modelORN['reset'] = '''v = VrORN''' # at reset, membrane v is reset and spike triggered adaptation conductance is increased + neuron_modelORN['refractory'] = tau_ref + + return neuron_modelORN + +def model_ORN(CORN, gLORN, ELORN, tau_IaORN, Ee, tau_syn_e, Ei, tau_syn_i, EIa, VtORN, VrORN, tau_ref, bORN, stimulus): + + neuron_eqs = ''' + dv/dt = (g_l*(E_l-v) + g_e*(E_e-v) - g_i*(E_i-v) - g_Ia*(E_Ia-v) + I0 + stimulus(t,i))/C_m : volt (unless refractory) # Ia is the spike triggered adaptation + dg_e/dt = -g_e/tau_e : siemens # post-synaptic exc. conductance # synapses + dg_i/dt = -g_i/tau_i : siemens # post-synaptic inh. conductance + dg_Ia/dt = -g_Ia/tau_Ia : siemens # conductance adaptation 'current' + I0 : amp + ''' + + neuron_modelORN = dict() + neuron_modelORN['model'] = Equations(neuron_eqs, g_l=gLORN, E_l=ELORN, E_e=Ee, E_i=Ei, E_Ia=EIa, C_m=CORN, tau_e=tau_syn_e,tau_i=tau_syn_i, tau_Ia=tau_IaORN) + neuron_modelORN['threshold'] = 'v > VtORN' + neuron_modelORN['reset'] = '''v = VrORN; g_Ia-=bORN''' # at reset, membrane v is reset and spike triggered adaptation conductance is increased + neuron_modelORN['refractory'] = tau_ref + + return neuron_modelORN + + +def model_PN(CPN, gLPN, ELPN, tau_IaPN, Ee, tau_syn_e, Ei, tau_syn_i, EIa, VtPN, VrPN, tau_ref, bPN): + + neuron_eqs = ''' + dv/dt = (g_l*(E_l-v) + g_e*(E_e-v) - g_i*(E_i-v) - g_Ia*(E_Ia-v) + I0)/C_m : volt (unless refractory) # Ia is the spike triggered adaptation + dg_e/dt = -g_e/tau_e : siemens # post-synaptic exc. conductance # synapses + dg_i/dt = -g_i/tau_i : siemens # post-synaptic inh. conductance + dg_Ia/dt = -g_Ia/tau_Ia : siemens # conductance adaptation 'current' + I0 : amp + ''' + + neuron_modelPN = dict() + neuron_modelPN['model'] = Equations(neuron_eqs, g_l=gLPN, E_l=ELPN, E_e=Ee, E_i=Ei,E_Ia = EIa, C_m=CPN, tau_e=tau_syn_e, tau_i=tau_syn_i,tau_Ia=tau_IaPN) + neuron_modelPN['threshold'] = 'v > VtPN' + neuron_modelPN['reset'] = '''v = VrPN; g_Ia-=bPN''' + neuron_modelPN['refractory'] = tau_ref + + return neuron_modelPN + +def model_LN(CLN, gLLN, ELLN, tau_IaLN, Ee, tau_syn_e, Ei, tau_syn_i, EIa, VtLN, VrLN, tau_ref, bLN): + + neuron_eqs = ''' + dv/dt = (g_l*(E_l-v) + g_e*(E_e-v) - g_i*(E_i-v) - g_Ia*(E_Ia-v) + I0)/C_m : volt (unless refractory) # Ia is the spike triggered adaptation + dg_e/dt = -g_e/tau_e : siemens # post-synaptic exc. conductance # synapses + dg_i/dt = -g_i/tau_i : siemens # post-synaptic inh. 
conductance + dg_Ia/dt = -g_Ia/tau_Ia : siemens # conductance adaptation 'current' + I0 : amp + ''' + + neuron_modelLN = dict() + neuron_modelLN['model'] = Equations(neuron_eqs, g_l=gLLN, E_l=ELLN, E_e=Ee, E_i=Ei, E_Ia=EIa, C_m=CLN, tau_e=tau_syn_e, tau_i=tau_syn_i, tau_Ia=tau_IaLN) + neuron_modelLN['threshold'] = 'v > VtLN' + neuron_modelLN['reset'] = '''v = VrLN; g_Ia-=bLN''' + neuron_modelLN['refractory'] = tau_ref + + return neuron_modelLN + + +def model_KC(CKC, gLKC, ELKC, tau_IaKC, Ee, tau_syn_e, Ei, tau_syn_i, EIa, VtKC, VrKC, tau_ref, bKC): + + neuron_eqs = ''' + dv/dt = (g_l*(E_l-v) + g_e*(E_e-v) - g_i*(E_i-v) - g_Ia*(E_Ia-v) + I0)/C_m : volt (unless refractory) # Ia is the spike triggered adaptation + dg_e/dt = -g_e/tau_e : siemens # post-synaptic exc. conductance # synapses + dg_i/dt = -g_i/tau_i : siemens # post-synaptic inh. conductance + dg_Ia/dt = -g_Ia/tau_Ia : siemens # conductance adaptation 'current' + I0 : amp + ''' + + neuron_modelKC = dict() + neuron_modelKC['model'] = Equations(neuron_eqs, DeltaT=1 * mV, g_l=gLKC, E_l=ELKC, E_e=Ee, E_i=Ei, E_Ia=EIa, C_m=CKC, tau_e=tau_syn_e,tau_i=tau_syn_i, tau_Ia=tau_IaKC) + neuron_modelKC['threshold'] = 'v > VtKC' + neuron_modelKC['reset'] = '''v = VrKC; g_Ia-=bKC''' + neuron_modelKC['refractory'] = tau_ref + + return neuron_modelKC + + + +def network(params, + input_ng, + neuron_modelORN, + neuron_modelPN, + neuron_modelLN, + neuron_modelKC, + wORNinputORN, + wORNPN, + wORNLN, + wLNPN, + wPNKC, + N_glu, + ORNperGlu, + N_KC, + PNperKC, + V0min, + V0max, + I0_PN = 0*nA, + I0_LN = 0*nA, + I0_KC = 0*nA, + inh_delay=0 * ms): + ''' + ## ToDo documentation ## + Connect ORNs to PNs such that ORNperGlu ORNs representing input to one Glu connects to 1 PN + repeat for every Glu, using connect_full. Connects ORNs to LNs in the same way. 
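+ Illustrative fan-in arithmetic: with the values used by the dataset scripts (N_glu=52, PNperKC=6), connect(p=PNperKC/N_glu) uses p ~= 0.115, so every Kenyon cell draws a Binomial(52, 0.115)-distributed number of PN inputs, i.e. 6 on average (cf. the binomial note at the PN-KC synapse below).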
+ ''' + + ######################### NEURONGROUPS ######################### + + NG = dict() + + # ORN Input + #n_receptors = ORNperGlu * N_glu + + if input_ng is not None: + validInputTypes = (PoissonGroup, Group, SpikeSource) + assert isinstance(input_ng, validInputTypes), "parameter 'input_ng' must be of type: {}".format(validInputTypes) + NG['ORNinput'] = input_ng + + neuron_params_orn = get_args_from_dict(neuron_modelORN, params) + neuron_params_pn = get_args_from_dict(neuron_modelPN, params) + neuron_params_ln = get_args_from_dict(neuron_modelLN, params) + neuron_params_kc = get_args_from_dict(neuron_modelKC, params) + + NG['ORN'] = NeuronGroup(N_glu*ORNperGlu, **neuron_modelORN(**neuron_params_orn), namespace=params, method='euler', name='ORNs') + NG['ORN'].I0 = I0_PN + NG['PN'] = NeuronGroup(N_glu, **neuron_modelPN(**neuron_params_pn), namespace=params, method='euler', name='PNs') + NG['PN'].I0=I0_PN + NG['LN'] = NeuronGroup(N_glu, **neuron_modelLN(**neuron_params_ln), namespace=params, method='euler', name='LNs') + NG['LN'].I0=I0_LN + NG['KC'] = NeuronGroup(N_KC, **neuron_modelKC(**neuron_params_kc), namespace=params, method='euler', name='KCs') + NG['KC'].I0=I0_KC + + ######################### CONNECTIONS ######################### + c = dict() + + if input_ng is not None: + ### input-ORN ### + c['ORNinputORN'] = Synapses(NG['ORNinput'], NG['ORN'], 'w : siemens', on_pre='g_e+=w', namespace=params) + for i in np.arange(len(NG['ORN'])): + #c['ORNinputORN'].connect(i=list(range(i * orn_input_multiplier, (i + 1) * orn_input_multiplier)), j=i) + c['ORNinputORN'].connect(i=i, j=i) + c['ORNinputORN'].w = wORNinputORN + + ### ORN-PN ### + c['ORNPN'] = Synapses(NG['ORN'], NG['PN'], 'w : siemens', on_pre='g_e += w', namespace=params) + for i in np.arange(N_glu): + c['ORNPN'].connect(i=list(range(i * ORNperGlu, (i + 1) * ORNperGlu)), j=i) + c['ORNPN'].w = wORNPN + + ### ORN-LN ### + c['ORNLN'] = Synapses(NG['ORN'], NG['LN'], 'w : siemens', on_pre='g_e += w', namespace=params) + for i in np.arange(N_glu): + c['ORNLN'].connect(i=list(range(i * ORNperGlu, (i + 1) * ORNperGlu)), j=i) + c['ORNLN'].w = wORNLN + + ### LN-PN ### + c['LNPN'] = Synapses(NG['LN'], NG['PN'], 'w : siemens', on_pre='g_i -= w', delay=inh_delay, namespace=params) + c['LNPN'].connect() # connect_all + c['LNPN'].w = wLNPN + + + ## PN-KC ## + c['KC'] = Synapses(NG['PN'], NG['KC'], 'w : siemens', on_pre='g_e += w', namespace=params) + c['KC'].connect(p=PNperKC / float(N_glu)) + c['KC'].w = wPNKC + # the total number of possible synapses is N_pre*N_post + # when the connection probability is 0.05 then N_syn = N_pre*N_post*0.05 (on average) + # every postsynaptic neuron will receive N_syn/N_post synaptic inputs _on average_ + # and every presynaptic input will send out N_syn/N_pre _on average_ + # number of inputs per KC is given by the biominal distribution + + ######################### INITIAL VALUES ######################### + #NG['PN'].v = np.random.rand(len(NG['PN']))*(V0max-V0min)+V0min + #NG['LN'].v = np.random.rand(len(NG['LN']))*(V0max-V0min)+V0min + #NG['KC'].v = np.random.rand(len(NG['KC']))*(V0max-V0min)+V0min + + NG['ORN'].v = np.random.uniform(V0min, V0max, size=len(NG['ORN'])) * volt + NG['PN'].v = np.random.uniform(V0min, V0max, size=len(NG['PN'])) * volt + NG['LN'].v = np.random.uniform(V0min, V0max, size=len(NG['LN'])) * volt + NG['KC'].v = np.random.uniform(V0min, V0max, size=len(NG['KC'])) * volt + + return NG, c \ No newline at end of file diff --git a/olnet/models/droso_mushroombody_apl.py 
b/olnet/models/droso_mushroombody_apl.py new file mode 100644 index 0000000..aa230f4 --- /dev/null +++ b/olnet/models/droso_mushroombody_apl.py @@ -0,0 +1,237 @@ +from brian2 import * +import numpy as np +from olnet import get_args_from_dict + + +def model_ORN_input(CORN, gLORN, ELORN, tau_IaORN, Ee, tau_syn_e, Ei, tau_syn_i, EIa, VtORN, VrORN, tau_ref, stimulus): + """ + same ase model_ORN but without adaptation + :return: + """ + neuron_eqs = ''' + dv/dt = (g_l*(E_l-v) + g_e*(E_e-v) - g_i*(E_i-v) + I0 + stimulus(t,i))/C_m : volt (unless refractory) # Ia is the spike triggered adaptation + dg_e/dt = -g_e/tau_e : siemens # post-synaptic exc. conductance # synapses + dg_i/dt = -g_i/tau_i : siemens # post-synaptic inh. conductance + I0 : amp + ''' + + neuron_modelORN = dict() + neuron_modelORN['model'] = Equations(neuron_eqs, g_l=gLORN, E_l=ELORN, E_e=Ee, E_i=Ei, E_Ia=EIa, C_m=CORN, tau_e=tau_syn_e,tau_i=tau_syn_i, tau_Ia=tau_IaORN) + neuron_modelORN['threshold'] = 'v > VtORN' + neuron_modelORN['reset'] = '''v = VrORN''' # at reset, membrane v is reset and spike triggered adaptation conductance is increased + neuron_modelORN['refractory'] = tau_ref + + return neuron_modelORN + +def model_ORN(CORN, gLORN, ELORN, tau_IaORN, Ee, tau_syn_e, Ei, tau_syn_i, EIa, VtORN, VrORN, tau_ref, bORN, stimulus): + + neuron_eqs = ''' + dv/dt = (g_l*(E_l-v) + g_e*(E_e-v) - g_i*(E_i-v) - g_Ia*(E_Ia-v) + I0 + stimulus(t,i))/C_m : volt (unless refractory) # Ia is the spike triggered adaptation + dg_e/dt = -g_e/tau_e : siemens # post-synaptic exc. conductance # synapses + dg_i/dt = -g_i/tau_i : siemens # post-synaptic inh. conductance + dg_Ia/dt = -g_Ia/tau_Ia : siemens # conductance adaptation 'current' + I0 : amp + ''' + + neuron_modelORN = dict() + neuron_modelORN['model'] = Equations(neuron_eqs, g_l=gLORN, E_l=ELORN, E_e=Ee, E_i=Ei, E_Ia=EIa, C_m=CORN, tau_e=tau_syn_e,tau_i=tau_syn_i, tau_Ia=tau_IaORN) + neuron_modelORN['threshold'] = 'v > VtORN' + neuron_modelORN['reset'] = '''v = VrORN; g_Ia-=bORN''' # at reset, membrane v is reset and spike triggered adaptation conductance is increased + neuron_modelORN['refractory'] = tau_ref + + return neuron_modelORN + + +def model_PN(CPN, gLPN, ELPN, tau_IaPN, Ee, tau_syn_e, Ei, tau_syn_i, EIa, VtPN, VrPN, tau_ref, bPN): + + neuron_eqs = ''' + dv/dt = (g_l*(E_l-v) + g_e*(E_e-v) - g_i*(E_i-v) - g_Ia*(E_Ia-v) + I0)/C_m : volt (unless refractory) # Ia is the spike triggered adaptation + dg_e/dt = -g_e/tau_e : siemens # post-synaptic exc. conductance # synapses + dg_i/dt = -g_i/tau_i : siemens # post-synaptic inh. conductance + dg_Ia/dt = -g_Ia/tau_Ia : siemens # conductance adaptation 'current' + I0 : amp + ''' + + neuron_modelPN = dict() + neuron_modelPN['model'] = Equations(neuron_eqs, g_l=gLPN, E_l=ELPN, E_e=Ee, E_i=Ei,E_Ia = EIa, C_m=CPN, tau_e=tau_syn_e, tau_i=tau_syn_i,tau_Ia=tau_IaPN) + neuron_modelPN['threshold'] = 'v > VtPN' + neuron_modelPN['reset'] = '''v = VrPN; g_Ia-=bPN''' + neuron_modelPN['refractory'] = tau_ref + + return neuron_modelPN + +def model_LN(CLN, gLLN, ELLN, tau_IaLN, Ee, tau_syn_e, Ei, tau_syn_i, EIa, VtLN, VrLN, tau_ref, bLN): + + neuron_eqs = ''' + dv/dt = (g_l*(E_l-v) + g_e*(E_e-v) - g_i*(E_i-v) - g_Ia*(E_Ia-v) + I0)/C_m : volt (unless refractory) # Ia is the spike triggered adaptation + dg_e/dt = -g_e/tau_e : siemens # post-synaptic exc. conductance # synapses + dg_i/dt = -g_i/tau_i : siemens # post-synaptic inh. 
conductance + dg_Ia/dt = -g_Ia/tau_Ia : siemens # conductance adaptation 'current' + I0 : amp + ''' + + neuron_modelLN = dict() + neuron_modelLN['model'] = Equations(neuron_eqs, g_l=gLLN, E_l=ELLN, E_e=Ee, E_i=Ei, E_Ia=EIa, C_m=CLN, tau_e=tau_syn_e, tau_i=tau_syn_i, tau_Ia=tau_IaLN) + neuron_modelLN['threshold'] = 'v > VtLN' + neuron_modelLN['reset'] = '''v = VrLN; g_Ia-=bLN''' + neuron_modelLN['refractory'] = tau_ref + + return neuron_modelLN + + +def model_KC(CKC, gLKC, ELKC, tau_IaKC, Ee, tau_syn_e, Ei, tau_syn_i, EIa, VtKC, VrKC, tau_ref, bKC): + + neuron_eqs = ''' + dv/dt = (g_l*(E_l-v) + g_e*(E_e-v) - g_i*(E_i-v) - g_Ia*(E_Ia-v) + I0)/C_m : volt (unless refractory) # Ia is the spike triggered adaptation + dg_e/dt = -g_e/tau_e : siemens # post-synaptic exc. conductance # synapses + dg_i/dt = -g_i/tau_i : siemens # post-synaptic inh. conductance + dg_Ia/dt = -g_Ia/tau_Ia : siemens # conductance adaptation 'current' + I0 : amp + ''' + + neuron_modelKC = dict() + neuron_modelKC['model'] = Equations(neuron_eqs, DeltaT=1 * mV, g_l=gLKC, E_l=ELKC, E_e=Ee, E_i=Ei, E_Ia=EIa, C_m=CKC, tau_e=tau_syn_e,tau_i=tau_syn_i, tau_Ia=tau_IaKC) + neuron_modelKC['threshold'] = 'v > VtKC' + neuron_modelKC['reset'] = '''v = VrKC; g_Ia-=bKC''' + neuron_modelKC['refractory'] = tau_ref + + return neuron_modelKC + +def model_APL(CAPL, gLAPL, ELAPL, Ee, tau_syn_e, Ei, tau_syn_i, EIa, VtAPL, VrAPL, tau_ref): + + neuron_eqs = ''' + dv/dt = (g_l*(E_l-v) + g_e*(E_e-v) - g_i*(E_i-v) + I0)/C_m : volt (unless refractory) # Ia is the spike triggered adaptation + dg_e/dt = -g_e/tau_e : siemens # post-synaptic exc. conductance # synapses + dg_i/dt = -g_i/tau_i : siemens # post-synaptic inh. conductance + I0 : amp + ''' + + neuron_modelAPL = dict() + neuron_modelAPL['model'] = Equations(neuron_eqs, g_l=gLAPL, E_l=ELAPL, E_e=Ee, E_i=Ei,E_Ia = EIa, C_m=CAPL, tau_e=tau_syn_e, tau_i=tau_syn_i) + neuron_modelAPL['threshold'] = 'v > VtAPL' + neuron_modelAPL['reset'] = '''v = VrAPL''' + neuron_modelAPL['refractory'] = tau_ref + + return neuron_modelAPL + +def network(params, + input_ng, + neuron_modelORN, + neuron_modelPN, + neuron_modelLN, + neuron_modelKC, + neuron_modelAPL, + wORNinputORN, + wORNPN, + wORNLN, + wLNPN, + wPNKC, + wKCAPL, + wAPLKC, + N_glu, + ORNperGlu, + N_KC, + PNperKC, + V0min, + V0max, + I0_PN = 0*nA, + I0_LN = 0*nA, + I0_KC = 0*nA, + inh_delay=0 * ms, + apl_delay=0 * ms): + ''' + ## ToDo documentation ## + Connect ORNs to PNs such that ORNperGlu ORNs representing input to one Glu connects to 1 PN + repeat for every Glu, using connect_full. Connects ORNs to LNs in the same way. 
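+ Note on the inhibitory sign convention: the membrane equations subtract g_i*(E_i-v) while inhibitory synapses use on_pre='g_i -= w', so g_i accumulates negative values and the two minus signs recover the usual hyperpolarizing conductance term. Worked example: a single APL->KC spike with w = 3*w0 = 3 nS, at v = -60 mV and E_i = -75 mV, leaves g_i = -3 nS and contributes -g_i*(E_i-v) = -45 pA to the KC membrane current.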
+ ''' + + ######################### NEURONGROUPS ######################### + + NG = dict() + + # ORN Input + #n_receptors = ORNperGlu * N_glu + + if input_ng is not None: + validInputTypes = (PoissonGroup, Group, SpikeSource) + assert isinstance(input_ng, validInputTypes), "parameter 'input_ng' must be of type: {}".format(validInputTypes) + NG['ORNinput'] = input_ng + + neuron_params_orn = get_args_from_dict(neuron_modelORN, params) + neuron_params_pn = get_args_from_dict(neuron_modelPN, params) + neuron_params_ln = get_args_from_dict(neuron_modelLN, params) + neuron_params_kc = get_args_from_dict(neuron_modelKC, params) + neuron_params_apl = get_args_from_dict(neuron_modelAPL, params) + + NG['ORN'] = NeuronGroup(N_glu*ORNperGlu, **neuron_modelORN(**neuron_params_orn), namespace=params, method='euler', name='ORNs') + NG['ORN'].I0 = I0_PN + NG['PN'] = NeuronGroup(N_glu, **neuron_modelPN(**neuron_params_pn), namespace=params, method='euler', name='PNs') + NG['PN'].I0=I0_PN + NG['LN'] = NeuronGroup(N_glu, **neuron_modelLN(**neuron_params_ln), namespace=params, method='euler', name='LNs') + NG['LN'].I0=I0_LN + NG['KC'] = NeuronGroup(N_KC, **neuron_modelKC(**neuron_params_kc), namespace=params, method='euler', name='KCs') + NG['KC'].I0=I0_KC + NG['APL'] = NeuronGroup(1, **neuron_modelAPL(**neuron_params_apl), namespace=params, method='euler', name='APL') + NG['APL'].I0 = 0*nA + + ######################### CONNECTIONS ######################### + c = dict() + + if input_ng is not None: + ### input-ORN ### + c['ORNinputORN'] = Synapses(NG['ORNinput'], NG['ORN'], 'w : siemens', on_pre='g_e+=w', namespace=params) + for i in np.arange(len(NG['ORN'])): + #c['ORNinputORN'].connect(i=list(range(i * orn_input_multiplier, (i + 1) * orn_input_multiplier)), j=i) + c['ORNinputORN'].connect(i=i, j=i) + c['ORNinputORN'].w = wORNinputORN + + ### ORN-PN ### + c['ORNPN'] = Synapses(NG['ORN'], NG['PN'], 'w : siemens', on_pre='g_e += w', namespace=params) + for i in np.arange(N_glu): + c['ORNPN'].connect(i=list(range(i * ORNperGlu, (i + 1) * ORNperGlu)), j=i) + c['ORNPN'].w = wORNPN + + ### ORN-LN ### + c['ORNLN'] = Synapses(NG['ORN'], NG['LN'], 'w : siemens', on_pre='g_e += w', namespace=params) + for i in np.arange(N_glu): + c['ORNLN'].connect(i=list(range(i * ORNperGlu, (i + 1) * ORNperGlu)), j=i) + c['ORNLN'].w = wORNLN + + ### LN-PN ### + c['LNPN'] = Synapses(NG['LN'], NG['PN'], 'w : siemens', on_pre='g_i -= w', delay=inh_delay, namespace=params) + c['LNPN'].connect() # connect_all + c['LNPN'].w = wLNPN + + + ## PN-KC ## + c['KC'] = Synapses(NG['PN'], NG['KC'], 'w : siemens', on_pre='g_e += w', namespace=params) + c['KC'].connect(p=PNperKC / float(N_glu)) + c['KC'].w = wPNKC + # the total number of possible synapses is N_pre*N_post + # when the connection probability is 0.05 then N_syn = N_pre*N_post*0.05 (on average) + # every postsynaptic neuron will receive N_syn/N_post synaptic inputs _on average_ + # and every presynaptic input will send out N_syn/N_pre _on average_ + # number of inputs per KC is given by the biominal distribution + + ## KC-APL ## + c['KCAPL'] = Synapses(NG['KC'], NG['APL'], 'w : siemens', on_pre='g_e += w', delay=apl_delay, namespace=params) + c['KCAPL'].connect() # connect_all + c['KCAPL'].w = wKCAPL + + ## APL-KC ## + c['APLKC'] = Synapses(NG['APL'], NG['KC'], 'w : siemens', on_pre='g_i -= w', delay=apl_delay, namespace=params) + c['APLKC'].connect(p=1) + c['APLKC'].w = wAPLKC + + ######################### INITIAL VALUES ######################### + #NG['PN'].v = 
np.random.rand(len(NG['PN']))*(V0max-V0min)+V0min + #NG['LN'].v = np.random.rand(len(NG['LN']))*(V0max-V0min)+V0min + #NG['KC'].v = np.random.rand(len(NG['KC']))*(V0max-V0min)+V0min + + NG['ORN'].v = np.random.uniform(V0min, V0max, size=len(NG['ORN'])) * volt + NG['PN'].v = np.random.uniform(V0min, V0max, size=len(NG['PN'])) * volt + NG['LN'].v = np.random.uniform(V0min, V0max, size=len(NG['LN'])) * volt + NG['KC'].v = np.random.uniform(V0min, V0max, size=len(NG['KC'])) * volt + NG['APL'].v = np.random.uniform(V0min, V0max, size=len(NG['APL'])) * volt + + return NG, c \ No newline at end of file diff --git a/olnet/plotting/__init__.py b/olnet/plotting/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/olnet/plotting/figures.mplstyle b/olnet/plotting/figures.mplstyle new file mode 100644 index 0000000..c4ef9e1 --- /dev/null +++ b/olnet/plotting/figures.mplstyle @@ -0,0 +1,10 @@ +axes.linewidth : 1 +xtick.labelsize : 8 +ytick.labelsize : 8 +axes.labelsize : 8 +lines.linewidth : 1 +lines.markersize : 2 +legend.frameon : False +legend.fontsize : 8 +axes.prop_cycle : cycler(color=['e41a1c', '377eb8', '4daf4a', '984ea3', 'c51b7d', '4d9221', '542788', '8c510a', 'b2182b', '2166ac', '01665e']) +#font.sans-serif : Helvetica, Arial, sans-serif \ No newline at end of file diff --git a/olnet/plotting/figures.py b/olnet/plotting/figures.py new file mode 100644 index 0000000..5f8dfc8 --- /dev/null +++ b/olnet/plotting/figures.py @@ -0,0 +1,492 @@ +from brian2 import * +import matplotlib.pyplot as plt +from matplotlib import gridspec, cm +import matplotlib.colors as cor +from olnet import load_sim +import numpy as np +import scipy.io as scpio +from scipy import stats +from mpl_toolkits.axes_grid1 import make_axes_locatable +import string +#plt.style.use('ggplot') +plt.style.use('figures.mplstyle') +plt.rc('text', usetex=False) +#plt.rc('text.latex', preamble=r'\usepackage{amsmath} \usepackage{wasysym}'+ +# r'\usepackage[dvipsnames]{xcolor} \usepackage{MnSymbol} \usepackage{txfonts}') + +def params_as_dict(params): + layers = ['ORN', 'PN', 'LN', 'KC'] + str_dict = dict({k: dict() for k in layers}) + str_dict['global'] = dict() + for k in params: + matched = list(filter(k.endswith, layers)) + + if len(matched): + l_ = matched[0] + str_dict[l_].update({k[:-len(l_)]: params[k]}) + else: + str_dict['global'].update({k: params[k]}) + + return str_dict + +def params_as_latex_table(d): + + header = "\\textbf{Parameter} & \\textbf{Value} & \\textbf{Unit} \\\\ \n" + + print_section = lambda name: "\multicolumn{3}{l}{\\cellcolor[HTML]{343434}{\\color[HTML]{FFFFFF} \\textbf{" + name + "}}} \\\\ \n" + header + print_row = lambda var,val,unit: "{} & ${:.5f}$ & {} \\\\ \n".format(var.replace('_', '\_'), val, unit if unit is not None else "") + tab_start = """\\begin{table}[]\n\\centering\n\\begin{tabular}{lll}""" + + tab_end = """\n\end{tabular}\n\end{table}""" + + str = "" + + for sec in d.keys(): + if sec == "global": + str += tab_start + str += print_section("Other") + for k in d[sec].keys(): + if isinstance(d[sec][k], Quantity): + v,u = d[sec][k].in_best_unit(5).split(" ") + str += print_row(k, float(v), u) + elif not isinstance(d[sec][k], (list, tuple, np.ndarray)): + str += print_row(k, d[sec][k], None) + else: + pass + str = str[:-4] + str += tab_end + else: + str += tab_start + str += print_section(sec) + for k in d[sec].keys(): + if isinstance(d[sec][k], Quantity): + v,u = d[sec][k].in_best_unit(5).split(" ") + str += print_row(k, float(v), u) + elif isinstance(d[sec][k], (list, tuple, 
np.ndarray)): + pass + else: + str += print_row(k, d[sec][k], None) + str = str[:-4] + str += tab_end + + + return str + +def legendAsLatex(axes, rotation=90) : + '''Generate a latex code to be used instead of the legend. + Uses the label, color, marker and linestyle provided to the pyplot.plot. + The marker and the linestyle must be defined using the one or two character + abreviations shown in the help of pyplot.plot. + Rotation of the markers must be multiple of 90. + ''' + latexLine = {'-':'\\textbf{\Large ---}', + '-.':'\\textbf{\Large --\:\!$\\boldsymbol{\cdot}$\:\!--}', + '--':'\\textbf{\Large --\,--}',':':'\\textbf{\Large -\:\!-}'} + latexSymbol = {'o':'medbullet', 'd':'diamond', 's':'filledmedsquare', + 'D':'Diamondblack', '*':'bigstar', '+':'boldsymbol{\plus}', + 'x':'boldsymbol{\\times}', 'p':'pentagon', 'h':'hexagon', + ',':'boldsymbol{\cdot}', '_':'boldsymbol{\minus}','<':'LHD', + '>':'RHD','v':'blacktriangledown', '^':'blacktriangle'} + rot90=['^','<','v','>'] + di = [0,-1,2,1][rotation%360//90] + latexSymbol.update({rot90[i]:latexSymbol[rot90[(i+di)%4]] for i in range(4)}) + return ', '.join(['\\textcolor[rgb]{'\ + + ','.join([str(x) for x in cor.to_rgb(handle.get_color())]) +'}{' + + '$\\'+latexSymbol.get(handle.get_marker(),';')+'$' + + latexLine.get(handle.get_linestyle(),'') + '} ' + label + for handle,label in zip(*axes.get_legend_handles_labels())]) + +def figure1_mst(file, modelIdx=3, odorIdx=0, modelType='msp_classicalLabCondLowSparsity-0-15.odor-0.1-sp', **args): + dataSetname = file[6:file[6:].index('/') + 6] + mstFile = "matlab/model_cache/predictions/{}.{}/{}.mat".format(modelType, modelIdx, dataSetname) + trialIdx = int(file.split('/')[-1].split('-')[1]) + print("trialIdx: {} | mstFile: {}".format(trialIdx, mstFile)) + data = load_sim(file) + return figure1(data, mstMatFile=mstFile, mstTrialIdx=trialIdx, mstOdorIdx=odorIdx, **args) + + +def figure1(data, show_rate=False, t_min=0, t_max=None, orn_range=None, pn_range=None, ln_range=None, kc_range=None, cmap=None, mstMatFile=None, mstOdorIdx=None, mstTrialIdx=None, fig_size=None): + + if type(data) == str: + data = load_sim(data) + + dt = data.dt + simtime = data.simtime + warmup_time= data.warmup + M = data.tuning + stimulus = data.stimulus + spikemons = data.spikes + pop_mons = data.rates + state_mons = data.variables + + t_max = data.simtime - data.warmup if t_max is None else t_max + t_min = 0 if t_min is None else t_min + col_map = 'Reds' if cmap is None else cmap + t_offset = 0 + tempotron_sp = None + + if mstMatFile is not None: + mat = scpio.loadmat(mstMatFile) + tempotron_sp = mat['sp_times'][0, mstTrialIdx] + print("MST spikes #{}: {}".format(len(tempotron_sp), tempotron_sp)) + print("stimulus times odor #{} #{}: {}".format(mstOdorIdx, len(data.stimulus_times[mstOdorIdx]), data.stimulus_times[mstOdorIdx])) + if len(data.stimulus_times[mstOdorIdx]) == 1: + t_offset = data.stimulus_times[mstOdorIdx][0] + print("stimulus onset: {}".format(t_offset)) + + #orn_range = [1, np.max(spikemons['ORN'].i) + 1] if orn_range is None else orn_range + #pn_range = [1, np.max(spikemons['PN'].i) + 1] if pn_range is None else pn_range + #ln_range = [1, np.max(spikemons['LN'].i) + 1] if ln_range is None else ln_range + #kc_range = [1, np.max(spikemons['KC'].i) + 1] if kc_range is None else kc_range + + print("orn_range: {}".format(orn_range)) + # trunacte stimulus array to t_min,t_max + stim_dt = data.simtime / data.stimulus.shape[1] + stim_from = int((data.warmup + t_min) / stim_dt) + stim_to = int((data.warmup + t_max) / 
stim_dt) + stimulus = data.stimulus[:, stim_from:stim_to] + + + smons = [ + ('ORN', ('v', 'g_i', 'g_e'), [360]), + ('PN', ('v', 'g_i', 'g_e'), [15]), + ('LN', ('v', 'g_i', 'g_e'), [15]) + ] + + axs = [] + connectivity_axs = [] + + fig = plt.figure(num=1, figsize=(4, 6) if fig_size is None else fig_size) + n_ticks_timeaxis = 10 + y_label_pad = -0.01 + + rate_win = 50 * ms + # KC_y_lim = 100 + + outer_grid = gridspec.GridSpec(2, 2, hspace=0.0, wspace=0.01, height_ratios=[.1, .9], width_ratios=[.9, .1]) + + network_cell = outer_grid[1:, :-1] + if tempotron_sp is None: + gs_spikes = gridspec.GridSpecFromSubplotSpec(7, 1, network_cell, hspace=0.0, + wspace=0.0, height_ratios=[.1, .2, .1, .3, .1, .1, .1]) + else: + gs_spikes = gridspec.GridSpecFromSubplotSpec(8, 1, network_cell, hspace=0.0, + wspace=0.0, height_ratios=[.1, .1, .1, .3, .1, .1, .1, .1]) + + # gs_connectivity = gridspec.GridSpecFromSubplotSpec(3, 3, outer_grid[4:, -1]) + + # stimulus + axs.append(plt.subplot(outer_grid[0:1, :-1])) + # ORN tuning + #divider = make_axes_locatable(axs[0]) + #ax_tuning = divider.append_axes("right", 1, pad=0.0) + #axs.append(ax_tuning) + + # ORN tuning + axs.append(plt.subplot(outer_grid[0:1, -1])) #1 + + # network spiking activity + if orn_range == -1: + axs.append(plt.subplot(gs_spikes[0:2, :])) # 2 no ORNs - only plot PNs ! + else: + axs.append(plt.subplot(gs_spikes[0, :])) # 2 ORNs + axs.append(plt.subplot(gs_spikes[1, :])) #sharex=axs[2] # 3 PNs + + axs.append(plt.subplot(gs_spikes[2, :])) #sharex=axs[3] # LNs + + if tempotron_sp is not None: + axs.append(plt.subplot(gs_spikes[3:-3, :])) # KCs + axs.append(plt.subplot(gs_spikes[-3, :])) # KC histogram + axs.append(plt.subplot(gs_spikes[-2, :])) # APL + axs.append(plt.subplot(gs_spikes[-1, :])) # MST output + else: + axs.append(plt.subplot(gs_spikes[3:-2, :])) # KCs + axs.append(plt.subplot(gs_spikes[-2, :])) # KC histogram + axs.append(plt.subplot(gs_spikes[-1, :])) # APL + + + # axis handles + if orn_range == -1: + ax_orn = None + ax_pn, ax_ln, ax_kc, ax_kc_hist = (2, 3, 4, 5) + else: + ax_orn, ax_pn, ax_ln, ax_kc, ax_kc_hist = (2, 3, 4, 5, 6) + + ax_apl = -1 + #ax_orn_hist, ax_pn_hist, ax_ln_hist, ax_kc_hist = (6, 7, 8, 9) + + # ORN tuning plot + chars = string.ascii_uppercase #['A', 'B'] + xs = np.arange(M.shape[1]) + for i, p in enumerate(range(M.shape[0])): + axs[1].plot(M[i, :, 0], xs, color='C{}'.format(8 - i), label='{}'.format(chars[i])) + #axs[1].yaxis.set_major_locator(MaxNLocator(4, integer=True)) + axs[1].set_ylabel('') + axs[1].set_yticks([]) + axs[1].set_xlabel('') + axs[1].set_xticks([]) + #axs[1].set_ylabel(legendAsLatex(axs[1])) + #axs[1].set_xlabel('sensivity [a.u.]') + #axs[1].xaxis.set_label_position('top') + #axs[1].xaxis.set_ticks_position('top') + leg = axs[1].legend(loc='upper left', frameon=False, markerscale=.2) + plt.setp(leg.get_texts(), fontsize='8') + axs[1].spines['top'].set_visible(False) + axs[1].spines['right'].set_visible(False) + axs[1].spines['left'].set_visible(False) + axs[1].spines['bottom'].set_visible(False) + + # stimulus plot + coords = (t_min, t_max, 1, M.shape[1]) + h = axs[0].imshow(stimulus, interpolation='hamming', aspect='auto', extent=coords, cmap=col_map) + #divider = make_axes_locatable(axs[0]) + #cax = divider.append_axes("top", size="5%", pad=0.05) + #cb = fig.colorbar(h, cax=cax, orientation="horizontal", format=plt.NullFormatter()) + #cb.set_label('intensity [a.u.]') + #cax.xaxis.set_ticks_position("top") + #cax.xaxis.set_label_position("top") + 
#axs[0].yaxis.set_major_locator(MaxNLocator(4, integer=True)) + #axs[0].xaxis.set_major_locator(MaxNLocator(n_ticks_timeaxis, integer=True)) + axs[0].set_xticks([]) + axs[0].set_yticks([]) + #axs[0].set_yticks([stimulus.shape[0]]) + axs[0].set_ylabel("ORNs\n") + axs[0].yaxis.set_label_coords(y_label_pad, .5) + axs[0].spines['bottom'].set_visible(False) + # axs[0].set_xlabel('time [sec]') + # axs[0].set_title('ORN stimulation input') + + + + # ORN rasterplot + if ax_orn is not None: + axs[ax_orn].plot(spikemons['ORN'].t - warmup_time - t_offset, spikemons['ORN'].i + 1, '|', linewidth=0.5, markersize=1, + color='C1') + axs[ax_orn].set_xlim(t_min - t_offset, t_max - t_offset) + # axs[ax_orn].yaxis.set_major_locator(MaxNLocator(2, integer=True)) + # axs[ax_orn].xaxis.set_major_locator(MaxNLocator(n_ticks_timeaxis, integer=True)) + axs[ax_orn].set_ylabel("ORNs\n({})".format(orn_range[-1] - orn_range[0] if orn_range is not None else len(spikemons['ORN'].count[:]))) + axs[ax_orn].set_xticks([]) + if orn_range: + axs[ax_orn].set_ylim(orn_range) + #axs[ax_orn].set_yticks([orn_range[-1] - orn_range[0]]) + axs[ax_orn].set_yticks([]) + axs[ax_orn].yaxis.set_label_coords(y_label_pad, .5) + axs[ax_orn].spines['bottom'].set_visible(False) + # pop rate + if (show_rate): + ax_rate = axs[ax_orn].twinx() + ax_rate.plot((pop_mons['ORN'].t - warmup_time), + pop_mons['ORN'].smooth_rate, color='k', alpha=0.8) + mean_rate = pop_mons['ORN'].smooth_rate[int(warmup_time / dt):].mean() + ax_rate.set_yticks([mean_rate]) + + + # PN rasterplot + axs[ax_pn].plot(spikemons['PN'].t - warmup_time - t_offset, spikemons['PN'].i + 1, '|', linewidth=0.5, markersize=1, + color='C0') + axs[ax_pn].set_xlim(t_min - t_offset, t_max - t_offset) + axs[ax_pn].set_xticks([]) + axs[ax_pn].set_ylabel("PNs\n({})".format(pn_range[-1] - pn_range[0] if pn_range is not None else len(spikemons['PN'].count[:]))) + if pn_range: + axs[ax_pn].set_ylim(pn_range) + #axs[ax_pn].set_yticks([pn_range[-1] - pn_range[0]]) + axs[ax_pn].set_yticks([]) + axs[ax_pn].yaxis.set_label_coords(y_label_pad, .5) + axs[ax_pn].spines['bottom'].set_visible(False) + # pop rate + if (show_rate): + ax_rate = axs[ax_pn].twinx() + ax_rate.plot((pop_mons['PN'].t - warmup_time), + pop_mons['PN'].smooth_rate, color='k', alpha=0.8) + mean_rate = pop_mons['PN'].smooth_rate[int(warmup_time/dt):].mean() + ax_rate.set_yticks([mean_rate]) + + # LN rasterplot + axs[ax_ln].plot(spikemons['LN'].t - warmup_time - t_offset, spikemons['LN'].i + 1, '|', linewidth=0.5, markersize=1, + color='C2') + axs[ax_ln].set_xlim(t_min - t_offset, t_max - t_offset) + axs[ax_ln].set_xticks([]) + axs[ax_ln].set_ylabel("LNs\n({})".format(ln_range[-1] - ln_range[0] if ln_range is not None else len(spikemons['LN'].count[:]))) + if ln_range: + axs[ax_ln].set_ylim(ln_range) + #axs[ax_ln].set_yticks([ln_range[-1] - ln_range[0]]) + axs[ax_ln].set_yticks([]) + axs[ax_ln].yaxis.set_label_coords(y_label_pad, .5) + axs[ax_ln].spines['bottom'].set_visible(False) + # pop. 
rate + if (show_rate): + ax_rate = axs[ax_ln].twinx() + ax_rate.plot((pop_mons['LN'].t - warmup_time), + pop_mons['LN'].smooth_rate, color='k', alpha=0.8) + mean_rate = pop_mons['LN'].smooth_rate[int(warmup_time / dt):].mean() + ax_rate.set_yticks([mean_rate]) + + # KC rasterplot + print((spikemons['KC'].i)) + print((spikemons['KC'].t - warmup_time - t_offset)) + axs[ax_kc].plot(spikemons['KC'].t - warmup_time - t_offset, spikemons['KC'].i + 1, '|', linewidth=0.5, markersize=1, + color='C3') + #axs[ax_kc].set_ylim(1, len(spikemons['KC'].count[:])) + axs[ax_kc].set_xlim(t_min - t_offset, t_max - t_offset) + axs[ax_kc].set_xticks([]) + axs[ax_kc].yaxis.set_major_locator(MaxNLocator(2, integer=True)) + #axs[ax_kc].xaxis.set_major_locator(MaxNLocator(n_ticks_timeaxis, integer=True)) + axs[ax_kc].set_ylabel("KCs\n({})".format(kc_range[-1] - kc_range[0] if kc_range is not None else len(spikemons['KC'].count[:]))) + if kc_range: + axs[ax_kc].set_ylim(kc_range) + else: + axs[ax_kc].set_ylim(1, len(spikemons['KC'].count[:])) + #axs[ax_kc].set_yticks([kc_range[-1] - kc_range[0]]) + axs[ax_kc].set_yticks([]) + axs[ax_kc].yaxis.set_label_coords(y_label_pad, .5) + + # pop. rate + if (show_rate): + ax_rate = axs[ax_kc].twinx() + ax_rate.set_ylabel('pop. rate [Hz]') + ax_rate.yaxis.set_label_coords(1.01, .5) + ax_rate.plot((pop_mons['KC'].t - warmup_time), + pop_mons['KC'].smooth_rate, color='k', alpha=0.8) + mean_rate = pop_mons['KC'].smooth_rate[int(warmup_time / dt):].mean() + ax_rate.set_yticks([mean_rate]) + + + ts = np.arange(0, simtime, dt) + n_kcs = len(spikemons['KC'].count[:]) + bins = np.arange(0, len(spikemons['KC'].count[:]) + 1) + print("#KCs={}".format(len(spikemons['KC'].count[:]))) + kc_active_percentage = [np.where(np.logical_and(spikemons['KC'].t >= t, spikemons['KC'].t <= t + dt)) for t in ts] + + for i,idx in enumerate(kc_active_percentage): + kc_active_percentage[i] = (len(np.unique(spikemons['KC'].i[idx])) * 100) / n_kcs + + #y = [1 - ((x.mean()**2) / ((x*x).mean()+ 1e-8)) for x in tmp] + #print(kc_active_percentage) + idx = np.where((ts - warmup_time) >= t_min)[0].tolist() + axs[ax_kc_hist].plot(ts - warmup_time - t_offset, kc_active_percentage, 'k', drawstyle='steps') + axs[ax_kc_hist].fill_between(ts - warmup_time - t_offset, kc_active_percentage, color='k', step="pre") + axs[ax_kc_hist].set_ylabel("% KCs\n") + axs[ax_kc_hist].yaxis.set_label_coords(y_label_pad, .5) + axs[ax_kc_hist].yaxis.set_tick_params({'pad': 0.05}) + #axs[ax_kc_hist].yaxis.set_major_locator(MaxNLocator(2, integer=True)) + axs[ax_kc_hist].set_xlim(t_min - t_offset, t_max - t_offset) + axs[ax_kc_hist].set_ylim(0, np.max(kc_active_percentage[idx[0]:]) + 0.5) + axs[ax_kc_hist].set_xticks([]) + axs[ax_kc_hist].set_yticks([0, np.ceil(np.max(kc_active_percentage[idx[0]:]))]) + #ax_rate.plot(ts - warmup_time, y, 'k', alpha=0.8) + #ax_rate.set_ylim(0.8,1.1) + #ax_rate.set_ylabel('sparseness') + # axs[ax_kc].set_xlabel('time [sec]') + #axs[ax_kc].set_yticks( + # np.arange(start=1, step=int(len(spikemons['KC'].count[:]) // 2) - 1, stop=len(spikemons['KC'].count[:]))) + + if tempotron_sp is not None: + ax_apl = -2 + #axs[-1].vlines(np.array(data.stimulus_times[mstOdorIdx]) - t_offset, 1.5, 2.5, linewidth=1.5, color=[.3,.3,.3]) + axs[-1].vlines(np.array(tempotron_sp) - t_offset, 0, 1, linewidth=1.5, color=[.3,.3,.3]) + axs[-1].set_yticks([]) + axs[-1].set_ylim(-.1, 1.5) + axs[ax_kc_hist].yaxis.set_label_coords(y_label_pad, .5) + axs[-1].yaxis.set_tick_params({'pad': 0.05}) + axs[-1].set_ylabel("MBON\n") + + # APL 
+    if 'APL' in spikemons.keys():
+        # print(spikemons['APL'].t - warmup_time - t_offset)
+        axs[ax_apl].plot(spikemons['APL'].t - warmup_time - t_offset, spikemons['APL'].i + 1, '|',
+                         linewidth=.5, markersize=10, color='k')
+        axs[ax_apl].set_xlim(t_min - t_offset, t_max - t_offset)
+        axs[ax_apl].set_xticks([])
+        axs[ax_apl].yaxis.set_major_locator(MaxNLocator(2, integer=True))
+        axs[ax_apl].set_ylabel("APL\n")
+        axs[ax_apl].set_ylim([0.9, 1.2])
+        # axs[ax_apl].set_ylim(1)
+        axs[ax_apl].set_yticks([])
+        axs[ax_apl].yaxis.set_tick_params(pad=0.05)
+        axs[ax_apl].yaxis.set_label_coords(y_label_pad, .5)
+
+    # time axis on last subplot
+    axs[-1].set_xlim(t_min - t_offset, t_max - t_offset)
+    axs[-1].xaxis.set_major_locator(MaxNLocator(5, integer=True))
+    axs[-1].set_xlabel('time [sec]')
+
+    fig.align_labels([axs[ax_kc]])
+
+    # fig.tight_layout(w_pad=0.5, h_pad=0.5)
+    return fig
+
+
+def tempotron_response(matFile, mstOdorIdx, ax=None, modelName=None, showGaussian=False, t_min=0, t_max=10):
+    """
+    plot Multispike Tempotron (MBON) output spikes together with the stimulus (filament crossing)
+    times for a list of prediction MAT files, concatenated along the time axis.
+    """
+    fig = None
+    if ax is None:
+        print("creating new figure ...")
+        fig = plt.figure(figsize=(8, 2))
+        ax = plt.gca()
+
+    def get_tempotron_spikes(matFile, modelName=None):
+        # navigate the nested cell-array structure of the MATLAB prediction export
+        mat = scpio.loadmat(matFile)
+        tempotron_sp = None
+        N_models = len(mat['data']['predictions'][0][0][0])
+        if modelName is not None:
+            for k in range(N_models):
+                row = mat['data']['predictions'][0][0][0][k]
+                if modelName == row[0][0][0]:
+                    tempotron_sp = row[0][5][0][0][0]
+            if tempotron_sp is None:
+                raise Exception("model {} not found".format(modelName))
+            return tempotron_sp
+        else:
+            # use the first available prediction
+            return mat['data']['predictions'][0][0][0][0][0][5][0][0][0]
+
+    spikes_y_offset = .8
+    tick_size = .5
+
+    ax.spines['right'].set_visible(False)
+    ax.spines['top'].set_visible(False)
+    ax.spines['left'].set_visible(False)
+
+    # dt = 1/10000
+    # t = np.arange(0, kernel_duration, dt)
+    # x, _ = np.histogram(tempotron_sp, bins=np.arange(t_min, t_max + 2 * kernel_duration + dt, dt))
+    # rate_est = np.convolve(x, norm_kernel, 'same')[len(t):len(t) + 1]
+
+    for k, f in enumerate(matFile):
+        data = load_sim(f[:-4] + ".npz")
+        tempotron_sp = get_tempotron_spikes(f, modelName)
+
+        ax.vlines(tempotron_sp + (k * t_max), spikes_y_offset, spikes_y_offset + tick_size,
+                  linewidth=1., color='r', label='output spike')
+        ax.vlines(np.array(data.stimulus_times[mstOdorIdx]) + (k * t_max), 0, tick_size,
+                  linewidth=1., color='b', label='filament crossing')
+
+        if showGaussian:
+            # overlay the Gaussian stimulus profile Norm(mu, sigma) given as showGaussian=(mu, sigma)
+            xs = np.linspace(t_min, t_max, 10000)
+            ax2 = ax.twinx()
+            pdf_g = stats.norm.pdf(xs, showGaussian[0], showGaussian[1])
+            ax2.plot(xs + (k * t_max), pdf_g + spikes_y_offset, color='k', alpha=.6, linewidth=2.,
+                     label="Norm({},{})".format(showGaussian[0], showGaussian[1]))
+            ax2.fill_between(xs + (k * t_max), pdf_g + spikes_y_offset, facecolor='b', alpha=0.1)
+            ax2.set_yticks([])
+            ax2.set_ylim(spikes_y_offset, spikes_y_offset + tick_size)
+            ax2.set_ylabel(None)
+            ax2.spines['top'].set_visible(False)
+            ax2.spines['left'].set_visible(False)
+            ax2.spines['right'].set_visible(False)
+            ax2.spines['bottom'].set_visible(False)
+
+            ax2.legend(['Pr(filament)'], loc='upper right')
+
+    ax.set_yticks([])
+    ax.set_ylim(0, (spikes_y_offset + tick_size) * 1.2)
+    ax.set_xlim(t_min, len(matFile) * t_max)
+    ax.set_xlabel('time [sec]')
+    ax.legend(['filament crossing', 'output spike'], loc='upper left', ncol=2)
+
+    if fig is not None:
+        fig.tight_layout()
+
+    return fig
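For orientation, here is a minimal, hedged sketch of calling `tempotron_response` outside of `make_paper_figures.py`. It assumes the function is importable from `olnet.plotting.figures` alongside `figure1`; the `.mat` path is a hypothetical placeholder, and each prediction file needs its matching `.npz` simulation dump next to it, since the function loads `f[:-4] + ".npz"`:

```python
# Hedged sketch: the path below is a placeholder, not a file shipped with the repo.
from olnet.plotting.figures import tempotron_response  # assumed module location

matFiles = ["model_cache/predictions/my_model/PoisonPulse_example.mat"]

# MBON output spikes vs. filament crossings for odor 0, with the
# Gaussian stimulus profile Norm(5, 1.5) overlaid on each 10 s trial
fig = tempotron_response(matFiles, mstOdorIdx=0, showGaussian=(5, 1.5), t_min=0, t_max=10)
fig.savefig("figures/tempotron_response_example.png", dpi=300)
```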
diff --git a/olnet/tuning.py b/olnet/tuning.py
new file mode 100644
index 0000000..007e215
--- /dev/null
+++ b/olnet/tuning.py
@@ -0,0 +1,155 @@
+import numpy as np
+from math import pi as PI
+from brian2 import *
+
+
+def get_receptor_tuning(n_receptor_types, n_odors=1, receptors_per_odor=5, peak_rate=150):
+    """
+    compute a circular tuning profile over n_receptor_types for a number of odors (n_odors).
+    Returns an array of shape (n_receptor_types, n_odors)
+    :param n_receptor_types: equal to the number of glomeruli
+    :param n_odors: number of different odors
+    :param receptors_per_odor: number of receptors activated by each odor (must be odd)
+    :param peak_rate: peak rate of each half-sine tuning profile
+    :return: ndarray
+    """
+    assert n_odors <= n_receptor_types, "no. of odors must be <= no. of receptor types"
+    assert np.mod(receptors_per_odor, 2) != 0, "receptors per odor must be an odd number"
+
+    y = np.zeros((n_receptor_types, n_odors), dtype=int)
+    receptors = np.arange(0, n_receptor_types)
+
+    for idx_odor, odor_id in enumerate(range(0, n_odors)):
+        x = np.mod(receptors - odor_id, n_receptor_types) / (receptors_per_odor + 1.0)
+        idx = np.logical_and(x > 0, x < 1)
+        x[~idx] = 0
+        r = peak_rate * np.sin(x * PI)  # tuning by a half sine cycle with peak amplitude peak_rate
+        y[:, idx_odor] = r  # activation profiles for the different stimuli
+
+    return y
+
+
+def get_orn_tuning(receptor_tuning, odor_ids=None, n_orns=50, dim=1):
+    """
+    compute the tuning of individual ORNs from the receptor-type tuning, i.e. upsample the
+    per-glomerulus profiles to single-ORN resolution.
+    Returns a matrix of shape (n_odors, n_receptor_types * n_orns, dim)
+    :param receptor_tuning: matrix of receptor-type tuning
+    :param odor_ids: odors to compute the ORN tuning for. If None, it is computed for ALL odors in receptor_tuning
+    :param n_orns: number of ORN neurons per receptor type/glomerulus
+    :param dim: size of each ORN tuning - defaults to 1 (e.g. time dimension)
+    :return: ndarray
+    """
+    if odor_ids is None:
+        odor_ids = list(range(receptor_tuning.shape[1]))
+
+    if np.isscalar(odor_ids):
+        odor_ids = [odor_ids]
+
+    n_receptors = receptor_tuning.shape[0] * n_orns
+    Y = np.zeros((len(odor_ids), n_receptors, dim))
+    for odor_idx, odor_id in enumerate(odor_ids):
+        for rec_idx, rec_rate in enumerate(receptor_tuning[:, odor_id]):
+            Y[odor_idx, rec_idx * n_orns:((rec_idx + 1) * n_orns), :] = rec_rate
+
+    return Y
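To make the shapes concrete, here is a small usage sketch of the two tuning helpers above; the numbers are illustrative only, not the parameters used in the paper:

```python
# Illustrative sketch of the tuning helpers; parameter values are made up.
from olnet.tuning import get_receptor_tuning, get_orn_tuning

rec = get_receptor_tuning(n_receptor_types=24, n_odors=16, receptors_per_odor=5, peak_rate=150)
assert rec.shape == (24, 16)         # one column of receptor-type rates per odor

orn = get_orn_tuning(rec, odor_ids=[0, 15], n_orns=50)
assert orn.shape == (2, 24 * 50, 1)  # (n_odors, n_receptor_types * n_orns, dim)
```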
+
+
+def create_stimulation_matrix(trigger_ts, receptor_profiles, bg_rate=0, warmup_padding=None, odor_idx=None):
+    """
+    given a list of binary time-series, create a stimulation matrix from the given receptor_profiles.
+    Returns an array of shape (n_receptors, time-steps)
+    :param trigger_ts: list of (multiple) binary time-series
+    :param receptor_profiles: receptor profiles for each stimulation type (e.g. each binary time-series)
+    :param bg_rate: background rate
+    :param warmup_padding: optionally prepend this many 'warmup' timesteps at the background rate
+    :param odor_idx: if given, triggers of these odor indices are excluded from the matrix
+    :return: ndarray
+    """
+    assert len(trigger_ts) == receptor_profiles.shape[0], "number of binary stimulation protocols must equal the number of receptor profiles"
+    n_receptors = receptor_profiles.shape[1]
+
+    A = np.ones((n_receptors, len(trigger_ts[0]))) * bg_rate
+    for i, odor in enumerate(trigger_ts):
+        if odor_idx is not None and i in odor_idx:
+            continue
+
+        idx = np.where(odor == 1)
+        for j in idx[0]:
+            A[:, j] += receptor_profiles[i, :, 0]
+
+    if warmup_padding is not None:
+        warmup_pad = np.ones((n_receptors, warmup_padding)) * bg_rate
+        return np.hstack((warmup_pad, A))
+
+    return A
+
+
+def gen_shot_noise(rate, T, tau=8, dt=0.001, dim=1, scale=1.0):
+    """
+    generate shot-noise time-series by filtering Poisson noise with an exponential kernel.
+    Returns a TimedArray whose values have shape (T/dt, dim), normalized to a maximum of `scale`.
+    :param rate: Poisson rate
+    :param T: duration
+    :param tau: exp. filter time constant
+    :param dt: sampling / binning interval
+    :param dim: number of time-series to generate
+    :param scale: scale noise to a maximum value of scale
+    :return: TimedArray
+    """
+    X = np.zeros((dim, int(T / dt)))
+    n_spikes = rate * T
+    kernel_duration = 1
+    t = np.arange(1e-8, kernel_duration, dt)
+
+    assert X.shape[1] > t.shape[0], "duration must be larger than filter window"
+
+    exp_kernel = lambda x: np.exp(-tau / x)
+    kernel = exp_kernel(t)
+    norm_kernel = kernel / np.sum(kernel)
+
+    for d in range(dim):
+        sp_times = np.random.uniform(0, T + 2 * kernel_duration, np.random.poisson(n_spikes))
+        x, _ = np.histogram(sp_times, bins=np.arange(0, T + 2 * kernel_duration + dt, dt))
+        X[d, :] = np.convolve(x, norm_kernel, 'same')[len(t):len(t) + int(T / dt)]
+
+    # X = np.linalg.norm(X, axis=1)[:, np.newaxis]
+    return TimedArray((X / X.max(1)[:, np.newaxis]).T * scale, dt=dt * second)
+
+
+def combine_noise_with_protocol(X_protocol: TimedArray, X_noise: TimedArray):
+    """
+    combine a (noise) signal with a stimulation protocol time-series.
+    The noise signal is gated (masked) by the binary protocol time-series.
+    :param X_protocol: stimulation protocol (binary/rect signal)
+    :param X_noise: noise signal to gate with the protocol
+    :return: TimedArray
+    """
+    dt = X_noise.dt
+    T = X_protocol.values.shape[0] * X_protocol.dt
+    t = np.arange(0, T, dt) * second
+    print("T={} | dt={} | proto_dt: {} | X_noise: {} | X_protocol: {} | X_protocol(t): {}".format(
+        T, dt, X_protocol.dt, X_noise.values.shape, X_protocol.values.shape, X_protocol(t).shape))
+    mask = np.tile(X_protocol(t), (X_noise.values.shape[1], 1)).T
+    L = X_noise.values * mask
+    return TimedArray(L, dt=dt * second)
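The two noise helpers compose as follows. This is a minimal sketch assuming the BRIAN2 `TimedArray` conventions used above; the pulse timing and scale values are illustrative placeholders (the paper's datasets are generated by the `mkDataSet_*` scripts instead):

```python
# Illustrative sketch: gate shot noise with a binary pulse protocol.
import numpy as np
from brian2 import TimedArray, second
from olnet.tuning import gen_shot_noise, combine_noise_with_protocol

T, dt = 3.0, 0.001                  # 3 s at 1 ms resolution
protocol = np.zeros(int(T / dt))
protocol[1000:1500] = 1             # a single 0.5 s odor pulse
X_protocol = TimedArray(protocol, dt=dt * second)

X_noise = gen_shot_noise(rate=50, T=T, tau=8, dt=dt, dim=1, scale=0.004)
X_stim = combine_noise_with_protocol(X_protocol, X_noise)  # noise only during the pulse
```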
+
+
+def gen_gauss_sequence(T, dt, std=1.5, mu=None, N=8):
+    """
+    generate a sequence of discrete events whose times are approximately Gaussian distributed.
+    :param T: duration of the sequence in sec
+    :param dt: time resolution in sec
+    :param std: std. dev. of the Gaussian
+    :param mu: mean of the Gaussian (defaults to T/2)
+    :param N: no. of samples to draw
+    :return: binary ndarray of length T/dt
+    """
+    ts = np.zeros(int(T / dt))
+    # make the number of samples a bit noisy
+    stim_pos = np.random.normal(T / 2 if mu is None else mu, std, N + np.random.poisson(5))
+    stim_pos_idx = stim_pos / dt
+    # add noise to the duration of each stimulus by smearing events into neighbouring bins
+    stim_pos_noise = [[n, n + np.random.choice(2, 1, p=[.3, .7])[0], n - np.random.choice(2, 1, p=[.2, .8])[0]]
+                      for n in np.random.choice(stim_pos_idx.astype(int), len(stim_pos_idx))]
+    flat_noise = [item for sublist in stim_pos_noise for item in sublist]
+    stim_idx = np.union1d(stim_pos_idx.astype(int), flat_noise).astype(int)
+    ts[stim_idx] = 1
+    return ts
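Putting the pieces together, here is a hedged end-to-end sketch that turns a Gaussian event sequence into an ORN stimulation matrix; all parameters are illustrative:

```python
# Illustrative sketch: from Gaussian-placed events to an ORN stimulation matrix.
import numpy as np
from olnet.tuning import (get_receptor_tuning, get_orn_tuning,
                          create_stimulation_matrix, gen_gauss_sequence)

T, dt = 10.0, 0.001
trigger = gen_gauss_sequence(T, dt, std=1.5, N=8)   # binary event train of length T/dt

rec = get_receptor_tuning(n_receptor_types=24, n_odors=1)
profiles = get_orn_tuning(rec, n_orns=50)           # shape (1, 1200, 1)

# one binary trigger per profile; prepend 2 s of warmup at the background rate
stim = create_stimulation_matrix([trigger], profiles, bg_rate=5,
                                 warmup_padding=int(2.0 / dt))
print(stim.shape)                                   # (1200, 12000)
```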