Skip to content

Commit

Permalink
Major refactorings, plotting notebooks, license change
Browse files Browse the repository at this point in the history
  • Loading branch information
jsvetter committed Aug 20, 2024
1 parent 7103022 commit 8b0b55d
Show file tree
Hide file tree
Showing 41 changed files with 2,676 additions and 1,442 deletions.
682 changes: 21 additions & 661 deletions LICENSE

Large diffs are not rendered by default.

40 changes: 24 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,33 +1,41 @@
# Generating realistic neurophysiological time series with denoising diffusion probabilistic models

This repository contains research code for the preprint [Generating realistic neurophysiological time series with denoising diffusion probabilistic models](https://www.biorxiv.org/content/10.1101/2023.08.23.554148v1).
Code for [Generating realistic neurophysiological time series with denoising diffusion probabilistic models](https://www.biorxiv.org/content/10.1101/2023.08.23.554148v1).

It contains all the scripts to run the experiments presented in the paper.
To run the experiments, first download the datasets (links in the paper) and change the `filepath` parameter in the dataset configuration files in `conf/dataset`. To save all outputs locally, set the `save_path` parameter in `conf/base` to the desired location. Experiment tracking and saving with [Weights & Biases](https://wandb.ai/site) is also supported, but disabled by default.
The repository contains all the scripts to run the experiments presented in the paper.
To run the experiments, first download the datasets (links in the paper) and change the `filepath` parameter in the dataset configuration files in `conf/dataset`. To save all outputs locally, set the `save_path` parameter in `conf/base` to the desired location.

Experiment tracking and saving with [Weights & Biases](https://wandb.ai/site) is also supported, but disabled by default.
Finally, use the shell scripts provided in `scripts` to run the experiments!

Additionally, an example jupyter notebook `example_ner_notebook.ipynb`, where a diffusion model is trained to generate the BCI Challenge @ NER 2015 data, is provided together with the data.

Coming soon: Notebooks for plotting the figures!


## Usage

Install dependencies:
## Installation
Install dependencies via pip:
```shell
git clone https://github.com/mackelab/neural_timeseries_diffusion.git
cd neural_timeseries_diffusion
pip install -e .
```

Run shell scripts:
## Further information

#### How can apply the DDPM to my datasets?
This requires writing a dataloader for your dataset. Example dataloaders (for all the datasets used in paper) can be found in
`datasets.py`.

#### What options do I have for the denoiser architectures?
Two denoiser different denoiser architectures based on [structrured convolutions](https://arxiv.org/abs/2210.09298) are provided, which only differ in the way conditional information is passed to the network.

In general, the architecture using adaptive layer norm (`AdaConv`) is preferable and a good first choice over the one where conitional information is just concatenated (`CatConv`) (see [this paper](https://arxiv.org/abs/2212.09748) for a comparison of both approaches for an image diffusion model).

#### White noise vs. OU process
The standard white noise is a good initial choice. However, using the OU process can improve the quality of the generated power spectra. This requires setting the length scale of the OU process as an additional hyperparameter in the `diffusion_kernel` config files.

## Running the experiments
After downloading the datasets and updating the file paths, run the experiment shell scripts:
```shell
cd scripts
./<name_of_script>.sh
```

Run the example jupyter notebook:
- Make sure that the `jupyter` package is installed
- Open and execute the notebook


This produces result files, which can then be passed to the plotting notebooks.
4 changes: 2 additions & 2 deletions conf/base/wandb_default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ run_experiment: True
use_cuda_if_available: True

wandb_mode: disabled
wandb_project: time_series_diffusion
wandb_project: TODO
wandb_entity: TODO
home_path: "TODO" # wandb home path
save_path: null # Local saving
save_path: null # local saving
2 changes: 1 addition & 1 deletion conf/config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
defaults:
- base: wandb_default
- dataset: ner
- network: lconv_ner
- network: ada_conv_ner
- diffusion_kernel: white_noise
- diffusion: diffusion_linear_200
- optimizer: base_optimizer
Expand Down
2 changes: 0 additions & 2 deletions conf/dataset/ner.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,5 @@ self: ner
filepath: "../data"
patient_id: "S02"
signal_length: 260
with_time_emb: True
cond_time_dim: 32
train_test_split: 1.0
split_seed: 0
4 changes: 0 additions & 4 deletions conf/diffusion/diffusion_cosine_500.yaml

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _global_
swr_prediction_experiment:
channel: 0 # can be integer or list of integers
channel: 0
batch_size: 1000
13 changes: 13 additions & 0 deletions conf/network/ada_conv_ajile.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
self: ada_conv
signal_channel: 64 # only for P07
cond_dim: 128 # only for P07
hidden_channel: 8
in_kernel_size: 53
out_kernel_size: 53
slconv_kernel_size: 53
num_scales: 4
num_blocks: 3
num_off_diag: 8
padding_mode: circular
use_fft_conv: False
use_pos_emb: False
13 changes: 13 additions & 0 deletions conf/network/ada_conv_crcns.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
self: ada_conv
signal_channel: 3
cond_dim: 0
hidden_channel: 64
in_kernel_size: 64
out_kernel_size: 64
slconv_kernel_size: 64
num_scales: 5
num_blocks: 4
num_off_diag: 64
padding_mode: circular
use_fft_conv: False
use_pos_emb: False
13 changes: 13 additions & 0 deletions conf/network/ada_conv_ner.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
self: ada_conv
signal_channel: 56
cond_dim: 0
hidden_channel: 8
in_kernel_size: 65
out_kernel_size: 65
slconv_kernel_size: 65
num_scales: 1
num_blocks: 3
num_off_diag: 8
padding_mode: circular
use_fft_conv: False
use_pos_emb: True
13 changes: 13 additions & 0 deletions conf/network/ada_conv_tycho.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
self: ada_conv
signal_channel: 12
cond_dim: 1
hidden_channel: 32
in_kernel_size: 101
out_kernel_size: 101
slconv_kernel_size: 101
num_scales: 1
num_blocks: 3
num_off_diag: 32
padding_mode: circular
use_fft_conv: False
use_pos_emb: False
15 changes: 15 additions & 0 deletions conf/network/cat_conv_ajile.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
self: cat_conv
signal_channel: 64 # only for P07
time_dim: 16
cond_channel: 0
hidden_channel: 32
in_kernel_size: 1
out_kernel_size: 1
slconv_kernel_size: 53
num_scales: 4
heads: 3
num_blocks: 3
num_off_diag: 4
padding_mode: circular
use_fft_conv: False
use_pos_emb: False
15 changes: 15 additions & 0 deletions conf/network/cat_conv_crcns.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
self: cat_conv
signal_channel: 3
time_dim: 16
cond_channel: 0
hidden_channel: 64
in_kernel_size: 32
out_kernel_size: 32
slconv_kernel_size: 64
num_scales: 5
heads: 3
num_blocks: 3
num_off_diag: 64
padding_mode: zeros
use_fft_conv: False
use_pos_emb: False
15 changes: 15 additions & 0 deletions conf/network/cat_conv_ner.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
self: cat_conv
signal_channel: 56
time_dim: 16
cond_channel: 0
hidden_channel: 16
in_kernel_size: 1
out_kernel_size: 1
slconv_kernel_size: 65
num_scales: 1
heads: 1
num_blocks: 3
num_off_diag: 8
padding_mode: circular
use_fft_conv: False
use_pos_emb: True
15 changes: 15 additions & 0 deletions conf/network/cat_conv_tycho.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
self: cat_conv
signal_channel: 12
time_dim: 16
cond_channel: 0
hidden_channel: 32
in_kernel_size: 1
out_kernel_size: 1
slconv_kernel_size: 101
num_scales: 1
heads: 3
num_blocks: 3
num_off_diag: 4
padding_mode: circular
use_fft_conv: False
use_pos_emb: False
19 changes: 0 additions & 19 deletions conf/network/lconv_ajile.yaml

This file was deleted.

19 changes: 0 additions & 19 deletions conf/network/lconv_crcns.yaml

This file was deleted.

19 changes: 0 additions & 19 deletions conf/network/lconv_ner.yaml

This file was deleted.

19 changes: 0 additions & 19 deletions conf/network/lconv_tycho.yaml

This file was deleted.

4 changes: 2 additions & 2 deletions conf/optimizer/base_optimizer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ lr: 0.0004
weight_decay: 0.01
num_epochs: 100
train_batch_size: 32
scheduler_milestones: # Large milestone to avoid using scheduler.
- 1000000
scheduler_milestones:
- 1000000 # dummy value, no scheduler
scheduler_gamma: 1.0
55 changes: 55 additions & 0 deletions matplotlibrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# http://matplotlib.org/users/customizing.html

# Note: Units are in pt not in px
#
# How to convert px to pt in Inkscape
# > Inkscape pixel is 1/90 of an inch, other software usually uses 1/72.
# > This means if you need 10pt - use 12.5 in Inkscape (multiply with 1.25).
# > http://www.inkscapeforum.com/viewtopic.php?f=6&t=5964

text.usetex : False
mathtext.default : regular

font.serif : Arial, sans-serif
font.sans-serif : Arial, sans-serif
font.cursive : Arial, sans-serif
font.size : 10
figure.titlesize : medium
legend.fontsize : medium
axes.titlesize : medium
axes.labelsize : medium
xtick.labelsize : medium
ytick.labelsize : medium

image.interpolation : nearest
image.resample : False
image.composite_image : True

axes.spines.left : True
axes.spines.bottom : True
axes.spines.top : False
axes.spines.right : False
axes.facecolor : 'white'

axes.linewidth : 1.5
xtick.major.width : 1.5
xtick.minor.width : 1.5
ytick.major.width : 1.5
ytick.minor.width : 1.5
lines.linewidth : 1.5 # default
lines.markersize : 3.0

legend.frameon : False

# Formatting locally for now
# figure.autolayout: True
# figure.constrained_layout.use: True
# savefig.dpi : 300
# savefig.format : pdf
# savefig.bbox : tight
# savefig.pad_inches : 0.1
savefig.facecolor : 'white'

pdf.fonttype : 42

figure.max_open_warning : 0
Loading

0 comments on commit 8b0b55d

Please sign in to comment.