From 3b3f72985ed6b79a5a75f3c0967d9e5e8c170550 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sun, 11 Aug 2024 21:31:51 -0400 Subject: [PATCH] Fix doc/toml/gy6or6_predict.toml --- doc/toml/gy6or6_predict.toml | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/doc/toml/gy6or6_predict.toml b/doc/toml/gy6or6_predict.toml index 5061488be..b82cf048c 100644 --- a/doc/toml/gy6or6_predict.toml +++ b/doc/toml/gy6or6_predict.toml @@ -61,12 +61,16 @@ min_segment_dur = 0.01 # dataset.params = parameters used for datasets # for a frame classification model, we use dataset classes with a specific `window_size` -[vak.predict.dataset.params] -window_size = 176 +[vak.predict.dataset] +path = "/copy/path/from/train/config/here" +params = { window_size = 176 } -# Note we do not specify any options for the network, and just use the defaults -# We need to put this table here though, to indicate which model we are using. -[vak.predict.model.TweetyNet] +# We put this table though vak knows which model we are using +[vak.predict.model.TweetyNet.network] +# hidden_size: the number of elements in the hidden state in the recurrent layer of the network +# we trained with hidden size = 256 so we need to evaluate with the same hidden size; +# otherwise we'll get an error about "shapes do not match" when torch tries to load the checkpoint +hidden_size = 256 # this sub-table configures the `lightning.pytorch.Trainer` [vak.predict.trainer]