Skip to content

Commit

Permalink
Modify lazy_dyndep loading to trigger inside workspace. (pytorch#41687)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#41687

Specifically, this makes a new library (lazy), which can be used from both core
and workspace.

This allows workspace.Createnet to trigger lazy loading of dyndep dependencies.

Test Plan: Added a unit test specifically for workspace.CreateNet

Reviewed By: dzhulgakov

Differential Revision: D22441877

fbshipit-source-id: 3a9d1af9962585d08ea2566c9c85bec7377d39f2
  • Loading branch information
c00w authored and facebook-github-bot committed Jul 22, 2020
1 parent af5d0bf commit dfa914a
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 19 deletions.
23 changes: 6 additions & 17 deletions caffe2/python/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from caffe2.proto import caffe2_pb2
from caffe2.python import scope, utils, workspace
from caffe2.python.lazy import TriggerLazyImport
from caffe2.python.control_ops_grad import \
gen_do_gradient, gen_if_gradient, gen_while_gradient, disambiguate_grad_if_op_output

Expand Down Expand Up @@ -49,18 +50,6 @@ def _InitDataType():

_InitDataType()

_import_lazy_calls = []

def RegisterLazyImport(lazy):
global _import_lazy_calls
_import_lazy_calls += [lazy]


def _import_lazy():
global _import_lazy_calls
for lazy in _import_lazy_calls:
lazy()


def _GetRegisteredOperators():
return set(workspace.RegisteredOperators())
Expand All @@ -71,7 +60,7 @@ def _GetRegisteredOperators():

def RefreshRegisteredOperators(trigger_lazy=True):
if trigger_lazy:
_import_lazy()
TriggerLazyImport()
global _REGISTERED_OPERATORS
_REGISTERED_OPERATORS = _GetRegisteredOperators()

Expand All @@ -80,7 +69,7 @@ def RefreshRegisteredOperators(trigger_lazy=True):


def GlobalInit(args):
_import_lazy()
TriggerLazyImport()
_GLOBAL_INIT_ARGS.extend(args[1:])
C.global_init(args)

Expand All @@ -94,7 +83,7 @@ def IsOperator(op_type):


def IsOperatorWithEngine(op_type, engine):
_import_lazy()
TriggerLazyImport()
return C.op_registry_key(op_type, engine) in _REGISTERED_OPERATORS


Expand Down Expand Up @@ -294,7 +283,7 @@ def __getattr__(self, op_type):
op_type, *args, **kwargs)

def __dir__(self):
_import_lazy()
TriggerLazyImport()
additional_methods = [
op
for op in _REGISTERED_OPERATORS
Expand Down Expand Up @@ -2228,7 +2217,7 @@ def __getattr__(self, op_type):
op_type, *args, **kwargs)

def __dir__(self):
_import_lazy()
TriggerLazyImport()
additional_methods = [
op
for op in _REGISTERED_OPERATORS
Expand Down
14 changes: 14 additions & 0 deletions caffe2/python/lazy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
## @package workspace
# Module caffe2.python.lazy

_import_lazy_calls = []

def RegisterLazyImport(lazy):
global _import_lazy_calls
_import_lazy_calls += [lazy]


def TriggerLazyImport():
global _import_lazy_calls
for lazy in _import_lazy_calls:
lazy()
4 changes: 2 additions & 2 deletions caffe2/python/lazy_dyndep.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import unicode_literals

import os
from caffe2.python import core, dyndep
from caffe2.python import dyndep, lazy


def RegisterOpsLibrary(name):
Expand Down Expand Up @@ -81,4 +81,4 @@ def _import_lazy():
finally:
_LAZY_IMPORTED_DYNDEPS.remove(name)

core.RegisterLazyImport(_import_lazy)
lazy.RegisterLazyImport(_import_lazy)
14 changes: 14 additions & 0 deletions caffe2/python/lazy_dyndep_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,20 @@ def handlernoop(e):
lazy_dyndep.RegisterOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")
core.RefreshRegisteredOperators()

def test_workspacecreatenet(self):
from caffe2.python import workspace, lazy_dyndep
import tempfile

with tempfile.NamedTemporaryFile() as f:
lazy_dyndep.RegisterOpsLibrary(f.name)
called = False

def handler(e):
raise ValueError("test")
lazy_dyndep.SetErrorHandler(handler)
with self.assertRaises(ValueError, msg="test"):
workspace.CreateNet("fake")


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions caffe2/python/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from caffe2.proto import caffe2_pb2
from caffe2.python import scope, utils
from caffe2.python.lazy import TriggerLazyImport

import caffe2.python._import_c_extension as C

Expand Down Expand Up @@ -172,6 +173,7 @@ def ResetWorkspace(root_folder=None):


def CreateNet(net, overwrite=False, input_blobs=None):
TriggerLazyImport()
if input_blobs is None:
input_blobs = []
for input_blob in input_blobs:
Expand Down

0 comments on commit dfa914a

Please sign in to comment.