Skip to content

Commit

Permalink
Fix doc/toml/gy6or6_predict.toml
Browse files Browse the repository at this point in the history
  • Loading branch information
NickleDave committed Aug 12, 2024
1 parent 21be547 commit 3b3f729
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions doc/toml/gy6or6_predict.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,16 @@ min_segment_dur = 0.01

# dataset.params = parameters used for datasets
# for a frame classification model, we use dataset classes with a specific `window_size`
[vak.predict.dataset.params]
window_size = 176
[vak.predict.dataset]
path = "/copy/path/from/train/config/here"
params = { window_size = 176 }

# Note we do not specify any options for the network, and just use the defaults
# We need to put this table here though, to indicate which model we are using.
[vak.predict.model.TweetyNet]
# We put this table though vak knows which model we are using
[vak.predict.model.TweetyNet.network]
# hidden_size: the number of elements in the hidden state in the recurrent layer of the network
# we trained with hidden size = 256 so we need to evaluate with the same hidden size;
# otherwise we'll get an error about "shapes do not match" when torch tries to load the checkpoint
hidden_size = 256

# this sub-table configures the `lightning.pytorch.Trainer`
[vak.predict.trainer]
Expand Down

0 comments on commit 3b3f729

Please sign in to comment.