diff --git a/integration_test.sh b/integration_test.sh index d1ef3ee3..13ae3809 100755 --- a/integration_test.sh +++ b/integration_test.sh @@ -110,6 +110,14 @@ echo "\n\nTest multiple layers with pre-rnn attention" python3 train_model.py --train $TRAIN_PATH --dev $DEV_PATH --output_dir $EXPT_DIR --print_every 50 --embedding_size $EMB_SIZE --hidden_size $H_SIZE --rnn_cell $CELL --epoch $EPOCH --save_every $CP_EVERY --n_layers 3 --attention 'post-rnn' --attention_method 'dot' ERR=$((ERR+$?)); EX=$((EX+1)) +echo "\n\nTest Xavier/Glorot Initialization" +python3 train_model.py --train $TRAIN_PATH --dev $DEV_PATH --output_dir $EXPT_DIR --print_every 50 --embedding_size $EMB_SIZE --hidden_size $H_SIZE --rnn_cell $CELL --epoch 1 --save_every $CP_EVERY --n_layers 2 --glorot_init +ERR=$((ERR+$?)); EX=$((EX+1)) + +echo "\n\nTest uniform Initialization" +python3 train_model.py --train $TRAIN_PATH --dev $DEV_PATH --output_dir $EXPT_DIR --print_every 50 --embedding_size $EMB_SIZE --hidden_size $H_SIZE --rnn_cell $CELL --epoch 1 --save_every $CP_EVERY --n_layers 2 --uniform_init 0.1 +ERR=$((ERR+$?)); EX=$((EX+1)) + echo "\n\n\n$EX tests executed, $ERR tests failed\n\n" rm -r $EXPT_DIR diff --git a/machine/models/seq2seq.py b/machine/models/seq2seq.py index d90e12fe..539f3d8c 100644 --- a/machine/models/seq2seq.py +++ b/machine/models/seq2seq.py @@ -1,4 +1,5 @@ import torch.nn.functional as F +import torch.nn as nn from .baseModel import BaseModel @@ -8,9 +9,13 @@ class Seq2seq(BaseModel): and decoder. """ - def __init__(self, encoder, decoder, decode_function=F.log_softmax): + def __init__(self, encoder, decoder, decode_function=F.log_softmax, + uniform_init=0, glorot_init=False): super(Seq2seq, self).__init__(encoder_module=encoder, - decoder_module=decoder, decode_function=decode_function) + decoder_module=decoder, + decode_function=decode_function) + # Initialize Weights + self._init_weights(uniform_init, glorot_init) def flatten_parameters(self): """ @@ -32,3 +37,15 @@ def forward(self, inputs, input_lengths=None, targets={}, function=self.decode_function, teacher_forcing_ratio=teacher_forcing_ratio) return result + + def _init_weights(self, uniform_init=0.0, glorot_init=False): + # initialize weights using uniform distribution + if uniform_init > 0.0: + for p in self.parameters(): + p.data.uniform_(-uniform_init, uniform_init) + + # xavier/glorot initialization if glorot_init + if glorot_init: + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) diff --git a/test/test_seq2seq.py b/test/test_seq2seq.py index b996d8c9..72cd8f35 100644 --- a/test/test_seq2seq.py +++ b/test/test_seq2seq.py @@ -1,5 +1,29 @@ import unittest +from machine.models.EncoderRNN import EncoderRNN +from machine.models.DecoderRNN import DecoderRNN +from machine.models.seq2seq import Seq2seq + + class TestSeq2seq(unittest.TestCase): - pass + def setUp(self): + self.decoder = DecoderRNN(100, 50, 16, 0, 1, input_dropout_p=0) + self.encoder = EncoderRNN(100, 10, 50, 16, n_layers=2, dropout_p=0.5) + + def test_standard_init(self): + Seq2seq(self.encoder, self.decoder) + Seq2seq(self.encoder, self.decoder, uniform_init=-1) + + def test_uniform_init(self): + Seq2seq(self.encoder, self.decoder, uniform_init=1) + + def test_xavier_init(self): + Seq2seq(self.encoder, self.decoder, glorot_init=True) + + def test_uniform_xavier_init(self): + Seq2seq(self.encoder, self.decoder, uniform_init=1, glorot_init=True) + + +if __name__ == '__main__': + unittest.main() diff --git a/train_model.py b/train_model.py index cfa92ffa..08d0b9d7 100644 --- a/train_model.py +++ b/train_model.py @@ -85,6 +85,11 @@ def init_argparser(): choices=['adam', 'adadelta', 'adagrad', 'adamax', 'rmsprop', 'sgd']) parser.add_argument('--max_len', type=int, help='Maximum sequence length', default=50) + parser.add_argument('--uniform_init', type=float, + help='Initializes weights of model from uniform distribution in range (-uniform_init, uniform_init). \ + If <= 0, standard pytorch init is used. (default, 0.0)', default=0.0) + parser.add_argument('--glorot_init', action='store_true', + help='Initializes weights of model using glorot/xavier distribution') parser.add_argument( '--rnn_cell', help="Chose type of rnn cell", default='lstm') parser.add_argument('--bidirectional', action='store_true', @@ -252,7 +257,9 @@ def initialize_model(opt, src, tgt, train): bidirectional=opt.bidirectional, rnn_cell=opt.rnn_cell, eos_id=tgt.eos_id, sos_id=tgt.sos_id) - seq2seq = Seq2seq(encoder, decoder) + seq2seq = Seq2seq(encoder, decoder, + glorot_init=opt.glorot_init, + uniform_init=opt.uniform_init) seq2seq.to(device) return seq2seq, input_vocab, output_vocab