Skip to content

Commit

Permalink
Save pickled function as AiiDA node
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Dec 5, 2024
1 parent c08af7a commit ba8120d
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 37 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies = [
"numpy~=1.21",
"scipy",
"ase",
"node-graph==0.1.6",
"node-graph==0.1.7",
"node-graph-widget",
"aiida-core>=2.3",
"cloudpickle",
Expand Down
1 change: 0 additions & 1 deletion src/aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,6 @@ def build_task_from_workgraph(wg: any) -> Task:
"name": "WorkGraphEngine",
"wgdata": serialize(wg.to_dict(store_nodes=True)),
"type": tdata["metadata"]["task_type"],
"is_pickle": False,
}
tdata["metadata"]["group_outputs"] = group_outputs
tdata["executor"] = executor
Expand Down
7 changes: 4 additions & 3 deletions src/aiida_workgraph/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def sort_socket_data(socket_data: dict) -> dict:
def prepare_for_python_task(task: dict, kwargs: dict, var_kwargs: dict) -> dict:
"""Prepare the inputs for PythonJob"""
from aiida_pythonjob import prepare_pythonjob_inputs
from aiida_workgraph.utils import get_executor

function_inputs = kwargs.pop("function_inputs", {})
for _, input in task["inputs"].items():
Expand Down Expand Up @@ -69,8 +70,8 @@ def prepare_for_python_task(task: dict, kwargs: dict, var_kwargs: dict) -> dict:

metadata = kwargs.pop("metadata", {})
metadata.update({"call_link_label": task["name"]})
# get the source code of the function
executor = task["executor"]
# get the function from executor
func, _ = get_executor(task["executor"])
function_outputs = []
for output in task["outputs"].values():
if output["metadata"].get("is_function_output", False):
Expand All @@ -85,7 +86,7 @@ def prepare_for_python_task(task: dict, kwargs: dict, var_kwargs: dict) -> dict:
)

inputs = prepare_pythonjob_inputs(
function_data=executor,
function=func,
function_inputs=function_inputs,
function_outputs=function_outputs,
code=code,
Expand Down
18 changes: 1 addition & 17 deletions src/aiida_workgraph/orm/function_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ def __init__(self, value=None, **kwargs):
self.set_attribute(value)

def __str__(self):
return (
f"PickledFunction<{self.base.attributes.get('function_name')}> pk={self.pk}"
)
return f"PickledFunction<{self.base.attributes.get('name')}> pk={self.pk}"

@property
def metadata(self):
Expand All @@ -34,21 +32,7 @@ def metadata(self):
"source_code_without_decorator"
),
"type": "function",
"is_pickle": True,
}

@classmethod
def build_callable(cls, func):
"""Return the executor for this node."""
import cloudpickle as pickle

executor = {
"executor": pickle.dumps(func),
"type": "function",
"is_pickle": True,
}
executor.update(cls.inspect_function(func))
return executor

def set_attribute(self, value):
"""Set the contents of this node by pickling the provided function.
Expand Down
3 changes: 3 additions & 0 deletions src/aiida_workgraph/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def to_dict(self, short: bool = False) -> Dict[str, Any]:
from aiida.orm.utils.serialize import serialize

tdata = super().to_dict(short=short)
# clear unused keys
for key in ["ctrl_inputs", "ctrl_outputs"]:
tdata.pop(key, None)
tdata["context_mapping"] = self.context_mapping
tdata["wait"] = [task.name for task in self.waiting_on]
tdata["children"] = []
Expand Down
32 changes: 17 additions & 15 deletions src/aiida_workgraph/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,21 @@ def build_callable(obj: Callable) -> Dict[str, Any]:
to its module and name.
"""
import types
from aiida_workgraph.orm.function_data import PickledFunction

# Check if the callable is a function or class
if isinstance(obj, (types.FunctionType, type)):
# Check if callable is nested (contains dots in __qualname__ after the first segment)
if obj.__module__ == "__main__" or "." in obj.__qualname__.split(".", 1)[-1]:
# Local or nested callable, so pickle the callable
executor = PickledFunction.build_callable(obj)
executor = {"callable": obj, "use_module_path": False}
else:
# Global callable (function/class), store its module and name for reference
executor = {
"module": obj.__module__,
"name": obj.__name__,
"is_pickle": False,
"use_module_path": True,
}
elif isinstance(obj, PickledFunction) or isinstance(obj, dict):
elif isinstance(obj, dict):
executor = obj
else:
raise TypeError("Provided object is not a callable function or class.")
Expand Down Expand Up @@ -76,19 +75,12 @@ def get_executor(data: Dict[str, Any]) -> Union[Process, Any]:
"""Import executor from path and return the executor and type."""
import importlib
from aiida.plugins import CalculationFactory, WorkflowFactory, DataFactory
from aiida_workgraph.orm.function_data import PickledFunction

data = data or {}
is_pickle = data.get("is_pickle", False)
use_module_path = data.get("use_module_path", True)
type = data.get("type", "function")
if is_pickle:
import cloudpickle as pickle

try:
executor = pickle.loads(data["executor"])
except Exception as e:
print("Error in loading executor: ", e)
executor = None
else:
if use_module_path:
if type == "WorkflowFactory":
executor = WorkflowFactory(data["name"])
elif type == "CalculationFactory":
Expand All @@ -100,6 +92,10 @@ def get_executor(data: Dict[str, Any]) -> Union[Process, Any]:
else:
module = importlib.import_module("{}".format(data.get("module", "")))
executor = getattr(module, data["name"])
else:
if not isinstance(data["callable"], PickledFunction):
raise ValueError("The callable should be PickledFunction.")
executor = data["callable"].value

return executor, type

Expand Down Expand Up @@ -446,11 +442,17 @@ def serialize_properties(wgdata):
defined in a scope, e.g., local function in another function.
    So, if a function is used as input, we need to serialize the function.
"""
from aiida_workgraph.orm.function_data import PickledLocalFunction
from aiida_workgraph.orm.function_data import PickledLocalFunction, PickledFunction
from aiida_workgraph.tasks.pythonjob import PythonJob
import inspect

for _, task in wgdata["tasks"].items():
# find the pickled executor, create a pickleddata for it
# then use the pk of the pickleddata as the value of the executor
if task["executor"] and not task["executor"]["use_module_path"]:
executor = task["executor"]["callable"]
pickled_obj = PickledFunction(executor).store()
task["executor"]["callable"] = pickled_obj
if task["metadata"]["node_type"].upper() == "PYTHONJOB":
PythonJob.serialize_pythonjob_data(task)
for _, input in task["inputs"].items():
Expand Down

0 comments on commit ba8120d

Please sign in to comment.