diff --git a/train.py b/train.py index 7a626977..40750e0c 100644 --- a/train.py +++ b/train.py @@ -41,14 +41,14 @@ def train(opt): histories = {} if opt.start_from is not None: # open old infos and check if models are compatible - with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')) as f: + with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl'), 'rb') as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme - if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')): + if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl'), 'rb'): with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')) as f: histories = cPickle.load(f)