Commit
Tidy up
ines committed Aug 2, 2019
1 parent e7983d1 · commit 753be22
Showing 4 changed files with 8 additions and 7 deletions.
spacy_pytorch_transformers/language.py (2 changes: 1 addition & 1 deletion)
@@ -120,7 +120,7 @@ def resume_training(self, sgd=None, component_cfg=None, **kwargs):
         for name, component in self.pipeline:
             if name == tok2vec_name:
                 continue
-            elif getattr(component, "model", None) != True:
+            elif getattr(component, "model", None) is not True:
                 continue
             elif not hasattr(component, "begin_training"):
                 continue
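
An aside, not part of the commit: in spaCy 2.x a pipeline component's model attribute typically holds the literal placeholder True until a model has been created, so an identity check (is not True) only matches that placeholder, whereas != True delegates to the stored object's equality protocol and can misbehave for array-like models. A minimal sketch under those assumptions, using hypothetical stand-in objects:

# Sketch (hypothetical objects, not from this repo): why "is not True" is
# safer than "!= True" when an attribute may hold either the literal
# placeholder True or an arbitrary model-like object.
import numpy

placeholder = True              # spaCy-style "no model created yet" marker
array_like = numpy.zeros((2,))  # stands in for a model with element-wise __eq__

print(placeholder is not True)  # False: identity test matches the placeholder
print(array_like is not True)   # True: any other object passes the check

print(placeholder != True)      # False, same result as the identity test here
try:
    if array_like != True:      # what the old comparison effectively did
        pass
except ValueError as err:
    # numpy returns an element-wise array, whose truth value is ambiguous
    print("!= True failed:", err)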
spacy_pytorch_transformers/model_registry.py (7 changes: 3 additions & 4 deletions)
@@ -1,6 +1,6 @@
 from thinc.api import layerize, chain, flatten_add_lengths, with_getitem
-from thinc.t2v import Pooling, mean_pool, max_pool
-from thinc.v2v import Softmax, Maxout
+from thinc.t2v import Pooling, mean_pool
+from thinc.v2v import Softmax
 from thinc.neural.util import get_array_module


@@ -39,8 +39,7 @@ def fine_tune_class_vector(nr_class, *, exclusive_classes=True, **cfg):
     return chain(
         get_pytt_class_tokens,
         flatten_add_lengths,
-        with_getitem(0,
-            Softmax(nr_class, cfg["token_vector_width"])),
+        with_getitem(0, Softmax(nr_class, cfg["token_vector_width"])),
         Pooling(mean_pool),
     )

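For orientation, the chain above reads: extract the class-token vectors for each doc, flatten them into one array plus per-doc lengths, apply the softmax layer to the flat array only (that is what with_getitem(0, ...) does), then mean-pool back to one vector per doc. A rough pure-numpy sketch of that plumbing, with stand-in helpers rather than the real thinc ops:

# Pure-numpy sketch of the data flow in fine_tune_class_vector; the names
# below are stand-ins for the real thinc layers, not the library API.
import numpy

def flatten_add_lengths(docs):
    # list of (n_tokens_i, width) arrays -> (flat array, per-doc lengths)
    lengths = numpy.asarray([len(d) for d in docs], dtype="i")
    return numpy.concatenate(docs), lengths

def softmax(X):
    e = numpy.exp(X - X.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def mean_pool(flat, lengths):
    # average the rows belonging to each doc back into one vector per doc
    out, start = [], 0
    for n in lengths:
        out.append(flat[start:start + n].mean(axis=0))
        start += n
    return numpy.stack(out)

# Two "docs" with 1 and 2 class tokens of width 4, projected to 3 classes
docs = [numpy.random.rand(1, 4), numpy.random.rand(2, 4)]
W = numpy.random.rand(4, 3)                # stands in for the Softmax weights
flat, lengths = flatten_add_lengths(docs)  # with_getitem(0, ...) only touches `flat`
probs = softmax(flat @ W)                  # per-token class distributions
pooled = mean_pool(probs, lengths)         # shape (2, 3): one distribution per doc
print(pooled.shape)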
spacy_pytorch_transformers/pipeline.py (4 changes: 3 additions & 1 deletion)
@@ -27,5 +27,7 @@ def Model(cls, nr_class=1, exclusive_classes=False, **cfg):
         **cfg: Optional config parameters.
         RETURNS (thinc.neural.Model): The model.
         """
-        make_model = get_model_function(cfg.get("architecture", "fine_tune_class_vector"))
+        make_model = get_model_function(
+            cfg.get("architecture", "fine_tune_class_vector")
+        )
         return make_model(nr_class, exclusive_classes=exclusive_classes, **cfg)
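
The cfg.get("architecture", "fine_tune_class_vector") lookup is the registry pattern that model_registry.py exposes through get_model_function: resolve a model-building callable by name, falling back to the fine-tuning architecture. A rough sketch of how such a lookup could be wired up (assumed structure, not the package's actual internals):

# Hypothetical registry sketch; the real get_model_function in
# spacy_pytorch_transformers.model_registry may be organised differently.
MODEL_FUNCTIONS = {}

def register_model(name):
    def decorator(func):
        MODEL_FUNCTIONS[name] = func
        return func
    return decorator

def get_model_function(name):
    if name not in MODEL_FUNCTIONS:
        raise KeyError(f"Unknown model architecture: {name}")
    return MODEL_FUNCTIONS[name]

@register_model("fine_tune_class_vector")
def fine_tune_class_vector(nr_class, *, exclusive_classes=True, **cfg):
    ...  # build and return the thinc model shown above

# Mirrors pipeline.Model: fall back to the default architecture name.
cfg = {}
make_model = get_model_function(cfg.get("architecture", "fine_tune_class_vector"))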
spacy_pytorch_transformers/wrapper.py (2 changes: 1 addition & 1 deletion)
@@ -49,7 +49,7 @@ def predict(self, ids: Array):
        y_var = self._model(ids, **model_kwargs)
        self._model.training = is_training
        return Activations.from_pytt(y_var, is_grad=False)

    def begin_update(
        self, ids: Array, drop: Dropout = 0.0
    ) -> Tuple[Activations, Callable[..., None]]:
