Made last layer activation an explicit block
EnricoTrizio committed May 20, 2024
1 parent f5598d5 commit dd3b353
Showing 1 changed file with 11 additions and 15 deletions.
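What changed: the sharp sigmoid on the output layer is no longer baked into the FeedForward network via last_layer_activation=True; it now lives in its own "sigmoid" block (a Custom_Sigmoid module) listed in BLOCKS, so it can be configured per block, or switched off entirely, through the options dict. A minimal usage sketch, assuming the constructor takes layers, mass, alpha, and options as in mlcolvar's Committor (the layer sizes, mass, and alpha values are illustrative placeholders, not from this commit):

    import torch
    from mlcolvar.cvs import Committor

    mass = torch.ones(3)  # placeholder masses tensor for a toy 3-input system

    # configure each block under its own key of the per-block options dict
    model = Committor(
        layers=[3, 16, 16, 1],
        mass=mass,
        alpha=1.0,
        options={
            "nn": {"activation": "tanh"},  # kwargs forwarded to FeedForward
            "sigmoid": {},                 # kwargs forwarded to Custom_Sigmoid
        },
    )

    # pass None (or False) under "sigmoid" to skip the output activation,
    # matching the explicit check added in __init__ below
    model_raw = Committor(
        layers=[3, 16, 16, 1], mass=mass, alpha=1.0, options={"sigmoid": None}
    )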
26 changes: 11 additions & 15 deletions mlcolvar/cvs/committor/committor.py
@@ -3,6 +3,7 @@
from mlcolvar.cvs import BaseCV
from mlcolvar.core import FeedForward
from mlcolvar.core.loss import CommittorLoss
+from mlcolvar.core.nn.utils import Custom_Sigmoid

__all__ = ["Committor"]
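The newly imported Custom_Sigmoid is used below as the explicit output block. As a rough mental model only (this stand-in is hypothetical; the sharpness parameter p and its default are assumptions, not mlcolvar's actual signature), it behaves like a logistic function with tunable steepness, matching the "sharp sigmoid for output layer" wording removed further down:

    import torch

    class CustomSigmoidSketch(torch.nn.Module):
        # hypothetical stand-in for mlcolvar's Custom_Sigmoid
        def __init__(self, p=3.0):
            super().__init__()
            self.p = p  # assumed sharpness factor

        def forward(self, x):
            # steeper-than-standard logistic: 1 / (1 + exp(-p * x))
            return torch.sigmoid(self.p * x)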

@@ -31,7 +32,7 @@ class Committor(BaseCV, lightning.LightningModule):
Utils to initialize the masses tensor for the training
"""

BLOCKS = ["nn"]
BLOCKS = ["nn", "sigmoid"]

def __init__(
self,
@@ -79,23 +80,16 @@ def __init__(
# ======= OPTIONS =======
# parse and sanitize
options = self.parse_options(options)

-# add the relevant nn options, set tanh for hidden layers and sharp sigmoid for output layer
-activ_list = ["tanh" for i in range( len(layers) - 2 )]
-activ_list.append("custom_sigmoid")
-
-# update options dict for activations if not already set
-if not "activation" in options["nn"]:
-    options["nn"]["activation"] = activ_list

# ======= CHECKS =======
# should be empty in this case


# ======= BLOCKS =======
-# initialize NN turning on last layer activation
+# initialize NN
o = "nn"
-self.nn = FeedForward(layers, last_layer_activation=True, **options[o])
+self.nn = FeedForward(layers, **options[o])

+# separately add sigmoid activation on the last layer, so that it can be deactivated
+o = "sigmoid"
+if (options[o] is not False) and (options[o] is not None):
+    self.sigmoid = Custom_Sigmoid(**options[o])


def training_step(self, train_batch, batch_idx):
@@ -153,5 +147,7 @@ def test_committor():
trainer = lightning.Trainer(max_epochs=5, logger=None, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0)
trainer.fit(model, datamodule)

+model(X).sum().backward()  # smoke-test that gradients flow through the assembled blocks

if __name__ == "__main__":
test_committor()
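For reference, with the new BLOCKS the committor CV is conceptually just the two blocks applied in order; a rough sketch of the equivalent forward logic (the actual chaining is handled by the BaseCV machinery, not by code in this file):

    def forward_blocks(model, x):
        x = model.nn(x)                # FeedForward network
        if hasattr(model, "sigmoid"):  # attribute is absent when deactivated via options
            x = model.sigmoid(x)       # Custom_Sigmoid maps the output into (0, 1)
        return x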
