Skip to content

Commit

Permalink
Make TraceCtx objects pickleable with dill (#942)
Browse files Browse the repository at this point in the history
  • Loading branch information
apaz-cli authored Aug 8, 2024
1 parent 7433b86 commit 088798e
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 0 deletions.
1 change: 1 addition & 0 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ igraph >=0.10.4
optree >=0.11.0
opt_einsum >= 3.3.0
mpmath <1.4.0 # todo: teporarl pin for `NameError: name '_C' is not defined`
dill >=0.3.8 # Support for 3.12
6 changes: 6 additions & 0 deletions thunder/core/prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def make_prim(
method_name: None | str = None,
_bind_postprocess: None | Callable = None,
_print_as_impl: bool = False,
python_name: str | None = None,
):
sym = Symbol(
name=name,
Expand All @@ -308,6 +309,7 @@ def make_prim(
python_impl=python_impl,
_bind_postprocess=_bind_postprocess,
_print_as_impl=_print_as_impl,
_python_name=python_name,
)

if method_name is not None:
Expand Down Expand Up @@ -486,6 +488,7 @@ def _check_tensor_shape_and_metadata_meta(
"check_tensor_metadata",
meta=_check_tensor_shape_and_metadata_meta,
tags=(OpTags.DONT_DCE,),
python_name="check_tensor_shape_and_metadata",
)


Expand Down Expand Up @@ -1562,6 +1565,7 @@ def python_print_printer(
python_printer=python_print_printer,
python_impl=print,
tags=(OpTags.DONT_DCE,),
python_name="python_print",
)


Expand Down Expand Up @@ -1630,6 +1634,7 @@ def _del_impl(x: Any, /) -> None:
meta=_del_meta,
python_printer=del_printer,
python_impl=_del_impl,
python_name="python_del",
)


Expand Down Expand Up @@ -1667,6 +1672,7 @@ def _return_impl(*args) -> Any:
python_printer=return_printer,
python_impl=_return_impl,
tags=(OpTags.DONT_DCE,),
python_name="python_return",
)

#
Expand Down
54 changes: 54 additions & 0 deletions thunder/core/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class Symbol:
executor: None | Any = None
python_impl: None | Callable = None
_print_as_impl: bool = False # If not None, w
_python_name: str | None = None

# An optional postprocessing function to modify the bound symbol resulting from bind()
_bind_postprocess: None | Callable = None
Expand Down Expand Up @@ -199,6 +200,59 @@ def module(self) -> None | ModuleType:
result = inspect.getmodule(fn_)
return result

@classmethod
def lookup_from_module(cls, name: str, executor: Any, module: ModuleType) -> Symbol: # For unpickling
if module not in sys.modules:
raise RuntimeError(f"Cannot find module {module} for symbol {name}.")

if executor is None:
not_found = object()
sym = getattr(sys.modules[module], name, not_found)
if sym is not_found:
raise RuntimeError(f"Could not find symbol {name} in module {module}.")
assert isinstance(sym, Symbol), (name, module, type(sym), sym)
return sym
else:
# Try to find the executor in all_executors
import thunder.extend

executors = thunder.extend.get_all_executors()

for ex in executors:
implmap = ex.implmap.values()

for key, info in implmap:
assert isinstance(key.id, str)
if key.id == name:
if (
impl.symbol is not None
and module is not None
and impl.module is not None
and module != impl.module
):
continue
return lookup_from_module(name, ex, module)

raise ValueError(f"Could not find an executor for symbol {name} from module {module.__qualname__}.")

def __reduce__(self): # For pickling
if self.module is None:
raise ValueError("Cannot serialize a symbol without a module.")

if hasattr(self, "_python_name") and not self._python_name is None:
name = self._python_name
else:
name = self.name

return (
Symbol.lookup_from_module,
(
name,
None if self.executor is None else self.executor.name,
None if self.module is None else self.module.__name__,
),
)

def __repr__(self) -> str:
return f"[Symbol name={self.name}]"

Expand Down
14 changes: 14 additions & 0 deletions thunder/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2782,6 +2782,20 @@ def foo2(x):
assert x.grad is not None


def test_serialize_trace():
import dill as pickle

def fn(a, b):
return a + b

tm = thunder.jit(fn)
a, b = torch.randn(2, 5, device=("cuda" if torch.cuda.is_available() else "cpu"))
tm(a, b)
trace = thunder.last_traces(tm)[0]

assert str(pickle.loads(pickle.dumps(trace))) == str(trace)


@pytest.mark.parametrize("requires_grad", (True, False))
def test_dataclass_output(requires_grad):
# Test both `requires_grad={True, False}` as both have
Expand Down

0 comments on commit 088798e

Please sign in to comment.