From dd04a1a56ff59d04d080a7a760e58574e7d40e9c Mon Sep 17 00:00:00 2001
From: Ines Montani
Date: Tue, 27 Aug 2019 11:18:24 +0200
Subject: [PATCH] Auto-format

---
 spacy_pytorch_transformers/_train.py         | 45 ++++++++++++++------
 spacy_pytorch_transformers/model_registry.py |  5 ++-
 spacy_pytorch_transformers/util.py           |  2 +-
 spacy_pytorch_transformers/wrapper.py        |  2 -
 4 files changed, 37 insertions(+), 17 deletions(-)

diff --git a/spacy_pytorch_transformers/_train.py b/spacy_pytorch_transformers/_train.py
index 2ae2abf7..d62d975b 100644
--- a/spacy_pytorch_transformers/_train.py
+++ b/spacy_pytorch_transformers/_train.py
@@ -4,11 +4,22 @@ from .util import cyclic_triangular_rate
 
 
-def train_while_improving(nlp, train_data, evaluate, *,
-        learning_rate: float, batch_size: int,
-        weight_decay: float, classifier_lr: float, dropout: float,
-        lr_range: int, lr_period: int,
-        steps_per_batch: int, patience: int, eval_every: int):
+def train_while_improving(
+    nlp,
+    train_data,
+    evaluate,
+    *,
+    learning_rate: float,
+    batch_size: int,
+    weight_decay: float,
+    classifier_lr: float,
+    dropout: float,
+    lr_range: int,
+    lr_period: int,
+    steps_per_batch: int,
+    patience: int,
+    eval_every: int
+):
     """Train until an evaluation stops improving. Works as a generator,
     with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
     where info is a dict, and is_best_checkpoint is in [True, False, None] --
@@ -99,9 +110,7 @@ def train_while_improving(nlp, train_data, evaluate, *,
     steps_per_epoch = nr_batch * steps_per_batch
     optimizer = nlp.resume_training()
     learn_rates = cyclic_triangular_rate(
-        learning_rate / lr_range,
-        learning_rate * lr_range,
-        steps_per_epoch
+        learning_rate / lr_range, learning_rate * lr_range, steps_per_epoch
     )
     optimizer.pytt_lr = next(learn_rates)
     optimizer.pytt_weight_decay = HP.weight_decay
@@ -118,8 +127,13 @@ def train_while_improving(nlp, train_data, evaluate, *,
             optimizer.pytt_lr = next(learn_rates)
             docs, golds = zip(*batch)
             losses = {}
-            nlp.update(docs, golds, drop=HP.dropout, losses=losses,
-                sgd=(optimizer if (step % steps_per_batch == 0) else None))
+            nlp.update(
+                docs,
+                golds,
+                drop=HP.dropout,
+                losses=losses,
+                sgd=(optimizer if (step % steps_per_batch == 0) else None),
+            )
             if step != 0 and not (step % (eval_every * steps_per_batch)):
                 with nlp.use_params(optimizer.averages):
                     score, other_scores = evaluate()
@@ -128,9 +142,14 @@ def train_while_improving(nlp, train_data, evaluate, *,
             else:
                 score, other_scores = (None, None)
                 is_best_checkpoint = None
-            info = {"epoch": epoch, "step": step, "score": score,
-                "other_scores": other_scores, "loss": losses,
-                "checkpoints": results}
+            info = {
+                "epoch": epoch,
+                "step": step,
+                "score": score,
+                "other_scores": other_scores,
+                "loss": losses,
+                "checkpoints": results,
+            }
             yield batch, info, is_best_checkpoint
             step += 1
         epoch += 1
diff --git a/spacy_pytorch_transformers/model_registry.py b/spacy_pytorch_transformers/model_registry.py
index f3b75350..bc4d84e8 100644
--- a/spacy_pytorch_transformers/model_registry.py
+++ b/spacy_pytorch_transformers/model_registry.py
@@ -93,7 +93,10 @@ def softmax_last_hidden(nr_class, *, exclusive_classes=True, **cfg):
     """
     width = cfg["token_vector_width"]
     return chain(
-        get_pytt_last_hidden, flatten_add_lengths, Pooling(mean_pool), Softmax(nr_class, width)
+        get_pytt_last_hidden,
+        flatten_add_lengths,
+        Pooling(mean_pool),
+        Softmax(nr_class, width),
     )
 
 
diff --git a/spacy_pytorch_transformers/util.py b/spacy_pytorch_transformers/util.py
index 33737477..ae9caaa9 100644
--- a/spacy_pytorch_transformers/util.py
+++ b/spacy_pytorch_transformers/util.py
@@ -238,7 +238,7 @@ def get_segment_ids(name: str, *lengths) -> List[int]:
         return get_gpt2_segment_ids(length1, length2)
     elif "roberta" in name:
         return get_roberta_segment_ids(length1, length2)
-    
+
     else:
         raise ValueError(f"Unexpected model name: {name}")
 
diff --git a/spacy_pytorch_transformers/wrapper.py b/spacy_pytorch_transformers/wrapper.py
index e58d5c09..04ed22ae 100644
--- a/spacy_pytorch_transformers/wrapper.py
+++ b/spacy_pytorch_transformers/wrapper.py
@@ -253,5 +253,3 @@ def from_bytes(self, data):
         else:
             map_location = torch.device("cuda")
         self._model.load_state_dict(torch.load(filelike, map_location=map_location))
-
-
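
Usage note: per the docstring reformatted above, train_while_improving works as a generator, with each iteration yielding a tuple (batch, info, is_best_checkpoint), where info is the dict built in the last _train.py hunk ("epoch", "step", "score", "other_scores", "loss", "checkpoints"). Below is a minimal caller sketch, assuming an already-built nlp pipeline, train_data as (doc, gold) pairs accepted by nlp.update, and an evaluate callable returning (score, other_scores); the hyperparameter values are placeholders, not taken from this commit.

    # Sketch only: `nlp`, `train_data` and `evaluate` are assumed to exist;
    # the keyword values are illustrative, not defaults from this patch.
    from spacy_pytorch_transformers._train import train_while_improving

    training = train_while_improving(
        nlp,
        train_data,
        evaluate,
        learning_rate=2e-5,
        batch_size=8,
        weight_decay=0.005,
        classifier_lr=0.001,
        dropout=0.1,
        lr_range=2,
        lr_period=1,
        steps_per_batch=3,
        patience=10,
        eval_every=100,
    )
    for batch, info, is_best_checkpoint in training:
        # `info` is the dict assembled in the patch above; `is_best_checkpoint`
        # is True, False, or None (None when no evaluation ran at this step).
        if is_best_checkpoint:
            print(info["epoch"], info["step"], info["score"])
            # Saving the pipeline here, e.g. with nlp.to_disk(...), is one option.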