Commit: Auto-format

ines committed Aug 27, 2019
1 parent 541505e · commit dd04a1a

Showing 4 changed files with 37 additions and 17 deletions.
45 changes: 32 additions & 13 deletions spacy_pytorch_transformers/_train.py
@@ -4,11 +4,22 @@
 from .util import cyclic_triangular_rate
 
 
-def train_while_improving(nlp, train_data, evaluate, *,
-        learning_rate: float, batch_size: int,
-        weight_decay: float, classifier_lr: float, dropout: float,
-        lr_range: int, lr_period: int,
-        steps_per_batch: int, patience: int, eval_every: int):
+def train_while_improving(
+    nlp,
+    train_data,
+    evaluate,
+    *,
+    learning_rate: float,
+    batch_size: int,
+    weight_decay: float,
+    classifier_lr: float,
+    dropout: float,
+    lr_range: int,
+    lr_period: int,
+    steps_per_batch: int,
+    patience: int,
+    eval_every: int
+):
     """Train until an evaluation stops improving. Works as a generator,
     with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
     where info is a dict, and is_best_checkpoint is in [True, False, None] --
@@ -99,9 +110,7 @@ def train_while_improving(nlp, train_data, evaluate, *,
     steps_per_epoch = nr_batch * steps_per_batch
     optimizer = nlp.resume_training()
     learn_rates = cyclic_triangular_rate(
-        learning_rate / lr_range,
-        learning_rate * lr_range,
-        steps_per_epoch
+        learning_rate / lr_range, learning_rate * lr_range, steps_per_epoch
     )
     optimizer.pytt_lr = next(learn_rates)
     optimizer.pytt_weight_decay = HP.weight_decay
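
Note on the hunk above: `cyclic_triangular_rate` lives in util.py and is not part of this diff. As a hypothetical sketch of what a cyclic triangular learning-rate schedule does (names and details here are illustrative, not the package's code), the rate climbs linearly from a minimum to a maximum and back over each period, and the training loop refreshes `optimizer.pytt_lr` from the generator on every sub-batch:

import itertools

def cyclic_triangular_rate_sketch(min_lr, max_lr, period):
    # Yield rates that climb linearly from min_lr to max_lr over `period`
    # steps, then fall back to min_lr, repeating forever.
    for step in itertools.count(1):
        position = (step % (2 * period)) / (2 * period)  # position in the cycle, [0, 1)
        frac = 2 * position if position < 0.5 else 2 * (1 - position)
        yield min_lr + (max_lr - min_lr) * frac
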
@@ -118,8 +127,13 @@ def train_while_improving(nlp, train_data, evaluate, *,
             optimizer.pytt_lr = next(learn_rates)
             docs, golds = zip(*batch)
             losses = {}
-            nlp.update(docs, golds, drop=HP.dropout, losses=losses,
-                sgd=(optimizer if (step % steps_per_batch == 0) else None))
+            nlp.update(
+                docs,
+                golds,
+                drop=HP.dropout,
+                losses=losses,
+                sgd=(optimizer if (step % steps_per_batch == 0) else None),
+            )
             if step != 0 and not (step % (eval_every * steps_per_batch)):
                 with nlp.use_params(optimizer.averages):
                     score, other_scores = evaluate()
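
Note on the `sgd=...` argument above: passing `sgd=None` to `nlp.update` accumulates gradients without applying them, so weights are only stepped once every `steps_per_batch` sub-batches (gradient accumulation). A self-contained toy illustration of the same pattern; nothing here is spaCy API, it only shows the accumulate-then-update idea:

# Minimise f(w) = (w - 3) ** 2, summing gradients over steps_per_batch
# sub-steps and applying one weight update per accumulation window.
steps_per_batch = 4
learning_rate = 0.05
w = 0.0
grad_accum = 0.0
for step in range(200):
    grad_accum += 2.0 * (w - 3.0)      # gradient of (w - 3) ** 2
    if step % steps_per_batch == 0:    # mirrors the condition in the diff
        w -= learning_rate * grad_accum
        grad_accum = 0.0
print(round(w, 3))                     # converges towards 3.0
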
@@ -128,9 +142,14 @@ def train_while_improving(nlp, train_data, evaluate, *,
             else:
                 score, other_scores = (None, None)
                 is_best_checkpoint = None
-            info = {"epoch": epoch, "step": step, "score": score,
-                "other_scores": other_scores, "loss": losses,
-                "checkpoints": results}
+            info = {
+                "epoch": epoch,
+                "step": step,
+                "score": score,
+                "other_scores": other_scores,
+                "loss": losses,
+                "checkpoints": results,
+            }
             yield batch, info, is_best_checkpoint
             step += 1
         epoch += 1
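
As the docstring above describes, `train_while_improving` is a generator yielding `(batch, info, is_best_checkpoint)` tuples, so the caller decides what to do at each checkpoint. A hypothetical driver loop, assuming the function from this file; the hyperparameter values and the output handling are illustrative, not part of this commit:

def run_training(nlp, train_data, evaluate, output_dir):
    # Illustrative hyperparameters; tune these for a real run.
    training = train_while_improving(
        nlp,
        train_data,
        evaluate,
        learning_rate=2e-5,
        batch_size=32,
        weight_decay=0.005,
        classifier_lr=5e-4,
        dropout=0.1,
        lr_range=2,
        lr_period=1,
        steps_per_batch=4,
        patience=10,
        eval_every=100,
    )
    for batch, info, is_best_checkpoint in training:
        if is_best_checkpoint:
            # Score is only set on evaluation steps; keep the best model so far.
            print(info["epoch"], info["step"], info["score"])
            nlp.to_disk(output_dir)
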
5 changes: 4 additions & 1 deletion spacy_pytorch_transformers/model_registry.py
@@ -93,7 +93,10 @@ def softmax_last_hidden(nr_class, *, exclusive_classes=True, **cfg):
     """
     width = cfg["token_vector_width"]
     return chain(
-        get_pytt_last_hidden, flatten_add_lengths, Pooling(mean_pool), Softmax(nr_class, width)
+        get_pytt_last_hidden,
+        flatten_add_lengths,
+        Pooling(mean_pool),
+        Softmax(nr_class, width),
     )
 
 
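
`softmax_last_hidden` (above) chains the transformer's last hidden state through mean pooling into a `Softmax` layer with `nr_class` outputs. Conceptually the composed model does the following; this numpy sketch is only an illustration of the chain, not the thinc implementation:

import numpy as np

def softmax_last_hidden_sketch(last_hidden, W, b):
    # last_hidden: (n_tokens, width); W: (nr_class, width); b: (nr_class,)
    pooled = last_hidden.mean(axis=0)     # mean pooling over tokens
    logits = W @ pooled + b               # linear layer of the Softmax block
    exps = np.exp(logits - logits.max())  # numerically stable softmax
    return exps / exps.sum()
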
2 changes: 1 addition & 1 deletion spacy_pytorch_transformers/util.py
@@ -238,7 +238,7 @@ def get_segment_ids(name: str, *lengths) -> List[int]:
         return get_gpt2_segment_ids(length1, length2)
     elif "roberta" in name:
         return get_roberta_segment_ids(length1, length2)
-
+
     else:
         raise ValueError(f"Unexpected model name: {name}")

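
`get_segment_ids` (above) dispatches on the model name because each transformer family marks sentence-pair membership differently. For a BERT-style pair the idea is simply id 0 for every position in the first segment and id 1 for every position in the second; the helper below is a hypothetical illustration, the real per-model logic lives in util.py:

from typing import List

def bert_style_segment_ids(length1: int, length2: int) -> List[int]:
    # First segment (including its special tokens) gets id 0, the second gets id 1.
    return [0] * length1 + [1] * length2

# bert_style_segment_ids(5, 3) == [0, 0, 0, 0, 0, 1, 1, 1]
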
2 changes: 0 additions & 2 deletions spacy_pytorch_transformers/wrapper.py
@@ -253,5 +253,3 @@ def from_bytes(self, data):
         else:
             map_location = torch.device("cuda")
         self._model.load_state_dict(torch.load(filelike, map_location=map_location))
-
-
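
The `from_bytes` hunk above chooses a `map_location` so a state dict saved on GPU can still be deserialised on a CPU-only machine. A hypothetical standalone version of the same pattern (the helper name is illustrative):

import io
import torch

def load_state_dict_portably(model, data: bytes):
    # Pick the deserialisation device from CUDA availability, then load the weights.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    state_dict = torch.load(io.BytesIO(data), map_location=device)
    model.load_state_dict(state_dict)
    return model
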