Advanced Initialization Methods #66
base: master
machine/models/seq2seq.py

@@ -1,4 +1,5 @@
 import torch.nn.functional as F
+import torch.nn as nn

 from .baseModel import BaseModel

@@ -8,9 +9,13 @@ class Seq2seq(BaseModel):
     and decoder.
     """

-    def __init__(self, encoder, decoder, decode_function=F.log_softmax):
+    def __init__(self, encoder, decoder, decode_function=F.log_softmax,
+                 uniform_init=0, glorot_init=False):
         super(Seq2seq, self).__init__(encoder_module=encoder,
-                                      decoder_module=decoder, decode_function=decode_function)
+                                      decoder_module=decoder,
+                                      decode_function=decode_function)
+        # Initialize Weights
+        self._init_weights(uniform_init, glorot_init)

     def flatten_parameters(self):
         """
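For reference, a minimal usage sketch of the new constructor arguments. The EncoderRNN/DecoderRNN arguments are copied from the test setup further down in this PR; 0.1 is an arbitrary example value, not a recommended default:

from machine.models.EncoderRNN import EncoderRNN
from machine.models.DecoderRNN import DecoderRNN
from machine.models.seq2seq import Seq2seq

encoder = EncoderRNN(100, 10, 50, 16, n_layers=2, dropout_p=0.5)
decoder = DecoderRNN(100, 50, 16, 0, 1, input_dropout_p=0)

Seq2seq(encoder, decoder)                    # PyTorch default initialization
Seq2seq(encoder, decoder, uniform_init=0.1)  # every parameter ~ U(-0.1, 0.1)
Seq2seq(encoder, decoder, glorot_init=True)  # Xavier/Glorot on dim > 1 parameters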
@@ -32,3 +37,15 @@ def forward(self, inputs, input_lengths=None, targets={},
                               function=self.decode_function,
                               teacher_forcing_ratio=teacher_forcing_ratio)
         return result
+
+    def _init_weights(self, uniform_init=0.0, glorot_init=False):
+        # initialize weights using uniform distribution
+        if uniform_init > 0.0:
+            for p in self.parameters():
+                p.data.uniform_(-uniform_init, uniform_init)
+
+        # xavier/glorot initialization if glorot_init
+        if glorot_init:
+            for p in self.parameters():
+                if p.dim() > 1:
+                    nn.init.xavier_uniform_(p)

Review thread on the glorot_init branch:

Reviewer: I don't know much about initialisations, so this is mostly a question: does the Glorot initialisation depend on the probability of the uniform one? If they are both specified, does glorot_init overwrite the uniform_init parameter? In the tests it seems they are not mutually exclusive. Perhaps we could add a docstring explaining how they behave/interact?

Author: Yes, if both are passed, then Glorot overrides the uniform initialization for all parameters except biases. You usually use either uniform or Glorot/Xavier at a time, so I could add a docstring about this, but it's not very likely that someone activates both at once. Also, code-wise this is exactly how OpenNMT does it.

Reviewer: Alright, makes sense. My confusion was mostly coming from my ignorance :).
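To make the interaction described in the thread concrete, here is a small standalone sketch (not part of the PR) that applies both flags in sequence to an arbitrary module. Weight matrices (dim > 1) end up Xavier-initialized, while the 1-D biases keep their uniform values:

import torch.nn as nn

module = nn.GRU(input_size=8, hidden_size=8)  # stand-in for the seq2seq model
uniform_init, glorot_init = 0.1, True

# Step 1: the uniform init touches every parameter, biases included.
if uniform_init > 0.0:
    for p in module.parameters():
        p.data.uniform_(-uniform_init, uniform_init)

# Step 2: Xavier/Glorot then overwrites only the multi-dimensional
# parameters, so the biases (dim == 1) keep their uniform values.
if glorot_init:
    for p in module.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

for name, p in module.named_parameters():
    print(name, p.dim())  # weight_* are 2-D (Xavier), bias_* are 1-D (uniform)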
Tests for Seq2seq:

@@ -1,5 +1,29 @@
 import unittest


+from machine.models.EncoderRNN import EncoderRNN
+from machine.models.DecoderRNN import DecoderRNN
+from machine.models.seq2seq import Seq2seq
+
+
 class TestSeq2seq(unittest.TestCase):
-    pass
+    def setUp(self):
+        self.decoder = DecoderRNN(100, 50, 16, 0, 1, input_dropout_p=0)
+        self.encoder = EncoderRNN(100, 10, 50, 16, n_layers=2, dropout_p=0.5)
+
+    def test_standard_init(self):
+        Seq2seq(self.encoder, self.decoder)
+        Seq2seq(self.encoder, self.decoder, uniform_init=-1)
+
+    def test_uniform_init(self):
+        Seq2seq(self.encoder, self.decoder, uniform_init=1)
+
+    def test_xavier_init(self):
+        Seq2seq(self.encoder, self.decoder, glorot_init=True)
+
+    def test_uniform_xavier_init(self):
+        Seq2seq(self.encoder, self.decoder, uniform_init=1, glorot_init=True)
+
+
+if __name__ == '__main__':
+    unittest.main()
Review thread on the PR:

Reviewer: Putting all this in the init like this perhaps doesn't generalise very well if you also want to add other kinds of initialisations. Perhaps we could instead pass a function for initialisation? Does that make sense?

Author: Ah, I see what you mean: we would move this initialization function outside of the model class, maybe into a util, and one would have the option of passing an initialization function to any class and having the weights of that class initialized accordingly?

Author: I didn't think about it that way initially because I was following how OpenNMT has it, but I think that actually makes more sense. So ignore this PR for now, haha.

Reviewer: Ok, thanks!
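For completeness, a hypothetical sketch of the alternative discussed above (the names Seq2seqSketch, init_fn, and glorot_initializer are illustrative, not actual machine APIs): the model accepts an optional callable and applies it to itself, so new initialization schemes need no constructor changes:

import torch.nn as nn

def glorot_initializer(model):
    # Example initializer that could live in a util module.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

class Seq2seqSketch(nn.Module):
    def __init__(self, encoder, decoder, init_fn=None):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        if init_fn is not None:
            init_fn(self)  # caller-supplied initialization strategy

# usage: Seq2seqSketch(encoder, decoder, init_fn=glorot_initializer)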