Skip to content

Commit

Permalink
upgrades pytorch lightning
Browse files Browse the repository at this point in the history
  • Loading branch information
cheind committed Dec 9, 2021
1 parent 8a4a2f0 commit 150ecf4
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 42 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,16 @@ python -m autoregressive.scripts.wavenet_mnist classify --config "models/mnist_q
## Train
To train / reproduce a model
```bash
python -m autoregressive.scripts.train --config "models/mnist_q2/config.yaml"
python -m autoregressive.scripts.train fit --config "models/mnist_q2/config.yaml"
```
Progress is logged to Tensorboard
```
tensorboard --logdir lightning_logs
```
To generate a training configuration file for a specific dataset use
```
python -m autoregressive.scripts.train fit --data autoregressive.datasets.FSeriesDataModule --print_config > fseries_config.yaml
```

## Test
To run the tests
Expand Down
25 changes: 14 additions & 11 deletions models/fseries_q127/config.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
seed_everything: null
trainer:
logger: true
checkpoint_callback: true
checkpoint_callback: null
enable_checkpointing: true
callbacks: null
default_root_dir: null
gradient_clip_val: 0.0
gradient_clip_algorithm: norm
gradient_clip_val: null
gradient_clip_algorithm: null
process_position: 0
num_nodes: 1
num_processes: 1
Expand All @@ -16,30 +17,32 @@ trainer:
ipus: null
log_gpu_memory: null
progress_bar_refresh_rate: null
enable_progress_bar: true
overfit_batches: 0.0
track_grad_norm: -1
check_val_every_n_epoch: 1
fast_dev_run: false
accumulate_grad_batches: 1
accumulate_grad_batches: null
max_epochs: 30
min_epochs: null
max_steps: null
max_steps: -1
min_steps: null
max_time: null
limit_train_batches: 1.0
limit_val_batches: 1.0
limit_test_batches: 1.0
limit_predict_batches: 1.0
val_check_interval: 1.0
flush_logs_every_n_steps: 100
flush_logs_every_n_steps: null
log_every_n_steps: 25
accelerator: null
strategy: null
sync_batchnorm: false
precision: 32
enable_model_summary: true
weights_summary: top
weights_save_path: null
num_sanity_val_steps: 2
truncated_bptt_steps: null
resume_from_checkpoint: null
profiler: null
benchmark: false
Expand All @@ -48,16 +51,16 @@ trainer:
reload_dataloaders_every_epoch: false
auto_lr_find: false
replace_sampler_ddp: true
terminate_on_nan: false
detect_anomaly: false
auto_scale_batch_size: false
prepare_data_per_node: true
prepare_data_per_node: null
plugins: null
amp_backend: native
amp_level: O2
distributed_backend: null
amp_level: null
move_metrics_to_cpu: false
multiple_trainloader_mode: max_size_cycle
stochastic_weight_avg: false
terminate_on_nan: null
model:
wave_dilations:
- 1
Expand Down
29 changes: 16 additions & 13 deletions models/mnist_q2/config.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
seed_everything: null
trainer:
logger: true
checkpoint_callback: true
checkpoint_callback: null
enable_checkpointing: true
callbacks: null
default_root_dir: null
gradient_clip_val: 0.0
gradient_clip_algorithm: norm
gradient_clip_val: null
gradient_clip_algorithm: null
process_position: 0
num_nodes: 1
num_processes: 1
Expand All @@ -16,30 +17,32 @@ trainer:
ipus: null
log_gpu_memory: null
progress_bar_refresh_rate: null
enable_progress_bar: true
overfit_batches: 0.0
track_grad_norm: -1
check_val_every_n_epoch: 1
fast_dev_run: false
accumulate_grad_batches: 1
max_epochs: 3000
accumulate_grad_batches: null
max_epochs: 30
min_epochs: null
max_steps: null
max_steps: -1
min_steps: null
max_time: null
limit_train_batches: 1.0
limit_val_batches: 0.2
limit_val_batches: 1.0
limit_test_batches: 1.0
limit_predict_batches: 1.0
val_check_interval: 1.0
flush_logs_every_n_steps: 100
flush_logs_every_n_steps: null
log_every_n_steps: 25
accelerator: null
strategy: null
sync_batchnorm: false
precision: 32
enable_model_summary: true
weights_summary: top
weights_save_path: null
num_sanity_val_steps: 2
truncated_bptt_steps: null
resume_from_checkpoint: null
profiler: null
benchmark: false
Expand All @@ -48,16 +51,16 @@ trainer:
reload_dataloaders_every_epoch: false
auto_lr_find: false
replace_sampler_ddp: true
terminate_on_nan: false
detect_anomaly: false
auto_scale_batch_size: false
prepare_data_per_node: true
prepare_data_per_node: null
plugins: null
amp_backend: native
amp_level: O2
distributed_backend: null
amp_level: null
move_metrics_to_cpu: false
multiple_trainloader_mode: max_size_cycle
stochastic_weight_avg: false
terminate_on_nan: null
model:
wave_dilations:
- 1
Expand Down
29 changes: 16 additions & 13 deletions models/mnist_q256/config.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
seed_everything: null
trainer:
logger: true
checkpoint_callback: true
checkpoint_callback: null
enable_checkpointing: true
callbacks: null
default_root_dir: null
gradient_clip_val: 0.0
gradient_clip_algorithm: norm
gradient_clip_val: null
gradient_clip_algorithm: null
process_position: 0
num_nodes: 1
num_processes: 1
Expand All @@ -16,30 +17,32 @@ trainer:
ipus: null
log_gpu_memory: null
progress_bar_refresh_rate: null
enable_progress_bar: true
overfit_batches: 0.0
track_grad_norm: -1
check_val_every_n_epoch: 1
fast_dev_run: false
accumulate_grad_batches: 1
max_epochs: 3000
accumulate_grad_batches: null
max_epochs: 30
min_epochs: null
max_steps: null
max_steps: -1
min_steps: null
max_time: null
limit_train_batches: 1.0
limit_val_batches: 0.2
limit_val_batches: 1.0
limit_test_batches: 1.0
limit_predict_batches: 1.0
val_check_interval: 1.0
flush_logs_every_n_steps: 100
flush_logs_every_n_steps: null
log_every_n_steps: 25
accelerator: null
strategy: null
sync_batchnorm: false
precision: 32
enable_model_summary: true
weights_summary: top
weights_save_path: null
num_sanity_val_steps: 2
truncated_bptt_steps: null
resume_from_checkpoint: null
profiler: null
benchmark: false
Expand All @@ -48,16 +51,16 @@ trainer:
reload_dataloaders_every_epoch: false
auto_lr_find: false
replace_sampler_ddp: true
terminate_on_nan: false
detect_anomaly: false
auto_scale_batch_size: false
prepare_data_per_node: true
prepare_data_per_node: null
plugins: null
amp_backend: native
amp_level: O2
distributed_backend: null
amp_level: null
move_metrics_to_cpu: false
multiple_trainloader_mode: max_size_cycle
stochastic_weight_avg: false
terminate_on_nan: null
model:
wave_dilations:
- 1
Expand Down
6 changes: 3 additions & 3 deletions requirements/common.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
torch
torchvision
pytorch_lightning
torch>=1.9.0
torchvision>=0.10.0
pytorch_lightning>=1.5.5
jsonargparse[signatures]
numpy
matplotlib
Expand Down
3 changes: 2 additions & 1 deletion requirements/dev.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
scipy
scipy
pytest

0 comments on commit 150ecf4

Please sign in to comment.