Skip to content

Commit

Permalink
Hide all builtin attribute of namespace socket
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Dec 10, 2024
1 parent 9dbbe2f commit ffa2b69
Show file tree
Hide file tree
Showing 28 changed files with 118 additions and 114 deletions.
6 changes: 3 additions & 3 deletions docs/gallery/howto/autogen/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def aggregate(
aggregate_task = wg.add_task(aggregate, name="aggregate_task")

# we have to increase the link limit because by default workgraph only supports one link per input socket
aggregate_task.inputs["collected_values"].socket_link_limit = 50
aggregate_task.inputs["collected_values"]._link_limit = 50

for i in range(2): # this can be chosen as wanted
generator_task = wg.add_task(generator, name=f"generator{i}", seed=Int(i))
Expand Down Expand Up @@ -188,8 +188,8 @@ def aggregate(**collected_values):

# we have to increase the link limit because by default workgraph only supports
# one link per input socket.
aggregate_task.inputs["collected_ints"].socket_link_limit = 50
aggregate_task.inputs["collected_floats"].socket_link_limit = 50
aggregate_task.inputs["collected_ints"]._link_limit = 50
aggregate_task.inputs["collected_floats"]._link_limit = 50


for i in range(2): # this can be chosen as wanted
Expand Down
16 changes: 8 additions & 8 deletions src/aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,12 +381,12 @@ def build_task_from_workgraph(wg: any) -> Task:
}
)
for socket in task.inputs:
if socket.socket_name in builtin_input_names:
if socket._name in builtin_input_names:
continue
inputs.append(
{
"identifier": socket._socket_identifier,
"name": f"{task.name}.{socket.socket_name}",
"identifier": socket._identifier,
"name": f"{task.name}.{socket._name}",
}
)
# outputs
Expand All @@ -397,18 +397,18 @@ def build_task_from_workgraph(wg: any) -> Task:
}
)
for socket in task.outputs:
if socket.socket_name in builtin_output_names:
if socket._name in builtin_output_names:
continue
outputs.append(
{
"identifier": socket._socket_identifier,
"name": f"{task.name}.{socket.socket_name}",
"identifier": socket._identifier,
"name": f"{task.name}.{socket._name}",
}
)
group_outputs.append(
{
"name": f"{task.name}.{socket.socket_name}",
"from": f"{task.name}.{socket.socket_name}",
"name": f"{task.name}.{socket._name}",
"from": f"{task.name}.{socket._name}",
}
)
# add built-in sockets
Expand Down
6 changes: 3 additions & 3 deletions src/aiida_workgraph/engine/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def get_task(self, name: str):
task = Task.from_dict(self.ctx._tasks[name])
# update task results
for output in task.outputs:
output.socket_value = get_nested_dict(
output.value = get_nested_dict(
self.ctx._tasks[name]["results"],
output.socket_name,
default=output.socket_value,
output._name,
default=output.value,
)
return task

Expand Down
10 changes: 5 additions & 5 deletions src/aiida_workgraph/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,21 @@ def node_value(self):

def get_node_value(self):
"""Obtain the actual Python `value` of the object attached to the Socket."""
if isinstance(self.socket_value, orm.Data):
if hasattr(self.socket_value, "value"):
return self.socket_value.value
if isinstance(self.value, orm.Data):
if hasattr(self.value, "value"):
return self.value.value
else:
raise ValueError(
"Data node does not have a value attribute. We do not know how to extract the raw Python value."
)
else:
return self.socket_value
return self.value


class TaskSocketNamespace(NodeSocketNamespace):
"""Represent a namespace of a Task in the AiiDA WorkGraph."""

_socket_identifier = "workgraph.namespace"
_identifier = "workgraph.namespace"
_socket_property_class = TaskProperty
_type_mapping: dict = type_mapping

Expand Down
28 changes: 14 additions & 14 deletions src/aiida_workgraph/sockets/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,96 +4,96 @@
class SocketAny(TaskSocket):
"""Any socket."""

_socket_identifier: str = "workgraph.any"
_identifier: str = "workgraph.any"
_socket_property_identifier: str = "workgraph.any"


class SocketFloat(TaskSocket):
"""Float socket."""

_socket_identifier: str = "workgraph.float"
_identifier: str = "workgraph.float"
_socket_property_identifier: str = "workgraph.float"


class SocketInt(TaskSocket):
"""Int socket."""

_socket_identifier: str = "workgraph.int"
_identifier: str = "workgraph.int"
_socket_property_identifier: str = "workgraph.int"


class SocketString(TaskSocket):
"""String socket."""

_socket_identifier: str = "workgraph.string"
_identifier: str = "workgraph.string"
_socket_property_identifier: str = "workgraph.string"


class SocketBool(TaskSocket):
"""Bool socket."""

_socket_identifier: str = "workgraph.bool"
_identifier: str = "workgraph.bool"
_socket_property_identifier: str = "workgraph.bool"


class SocketAiiDAFloat(TaskSocket):
"""AiiDAFloat socket."""

_socket_identifier: str = "workgraph.aiida_float"
_identifier: str = "workgraph.aiida_float"
_socket_property_identifier: str = "workgraph.aiida_float"


class SocketAiiDAInt(TaskSocket):
"""AiiDAInt socket."""

_socket_identifier: str = "workgraph.aiida_int"
_identifier: str = "workgraph.aiida_int"
_socket_property_identifier: str = "workgraph.aiida_int"


class SocketAiiDAString(TaskSocket):
"""AiiDAString socket."""

_socket_identifier: str = "workgraph.aiida_string"
_identifier: str = "workgraph.aiida_string"
_socket_property_identifier: str = "workgraph.aiida_string"


class SocketAiiDABool(TaskSocket):
"""AiiDABool socket."""

_socket_identifier: str = "workgraph.aiida_bool"
_identifier: str = "workgraph.aiida_bool"
_socket_property_identifier: str = "workgraph.aiida_bool"


class SocketAiiDAList(TaskSocket):
"""AiiDAList socket."""

_socket_identifier: str = "workgraph.aiida_list"
_identifier: str = "workgraph.aiida_list"
_socket_property_identifier: str = "workgraph.aiida_list"


class SocketAiiDADict(TaskSocket):
"""AiiDADict socket."""

_socket_identifier: str = "workgraph.aiida_dict"
_identifier: str = "workgraph.aiida_dict"
_socket_property_identifier: str = "workgraph.aiida_dict"


class SocketAiiDAIntVector(TaskSocket):
"""Socket with a AiiDAIntVector property."""

_socket_identifier: str = "workgraph.aiida_int_vector"
_identifier: str = "workgraph.aiida_int_vector"
_socket_property_identifier: str = "workgraph.aiida_int_vector"


class SocketAiiDAFloatVector(TaskSocket):
"""Socket with a FloatVector property."""

_socket_identifier: str = "workgraph.aiida_float_vector"
_identifier: str = "workgraph.aiida_float_vector"
_socket_property_identifier: str = "workgraph.aiida_float_vector"


class SocketStructureData(TaskSocket):
"""Any socket."""

_socket_identifier: str = "workgraph.aiida_structuredata"
_identifier: str = "workgraph.aiida_structuredata"
_socket_property_identifier: str = "workgraph.aiida_structuredata"
6 changes: 3 additions & 3 deletions src/aiida_workgraph/tasks/monitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def create_sockets(self) -> None:
self.add_input(
"workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000
)
inp.socket_link_limit = 100000
inp._link_limit = 100000
self.add_output("workgraph.any", "result")
self.add_output("workgraph.any", "_wait")

Expand Down Expand Up @@ -54,7 +54,7 @@ def create_sockets(self) -> None:
self.add_input(
"workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000
)
inp.socket_link_limit = 100000
inp._link_limit = 100000
self.add_output("workgraph.any", "result")
self.add_output("workgraph.any", "_wait")

Expand Down Expand Up @@ -85,6 +85,6 @@ def create_sockets(self) -> None:
self.add_input(
"workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000
)
inp.socket_link_limit = 100000
inp._link_limit = 100000
self.add_output("workgraph.any", "result")
self.add_output("workgraph.any", "_wait")
14 changes: 8 additions & 6 deletions src/aiida_workgraph/tasks/pythonjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,22 @@ class PythonJob(Task):

def update_from_dict(self, data: Dict[str, Any], **kwargs) -> "PythonJob":
"""Overwrite the update_from_dict method to handle the PythonJob data."""
self.deserialize_pythonjob_data(data)
self.deserialize_pythonjob_data(data["inputs"])
super().update_from_dict(data)

@classmethod
def serialize_pythonjob_data(cls, tdata: Dict[str, Any]):
def serialize_pythonjob_data(cls, input_data: Dict[str, Any]):
"""Serialize the properties for PythonJob."""

for input in tdata["inputs"].values():
for input in input_data.values():
if input["metadata"].get("is_function_input", False):
if input.get("property", {}).get("value") is not None:
if ["identifier"] == "workgraph.namespace":
cls.serialize_socket_data(input["sockets"])
elif input.get("property", {}).get("value") is not None:
input["property"]["value"] = cls.serialize_socket_data(input)

@classmethod
def deserialize_pythonjob_data(cls, tdata: Dict[str, Any]) -> None:
def deserialize_pythonjob_data(cls, input_data: Dict[str, Any]) -> None:
"""
Process the task data dictionary for a PythonJob.
It load the orignal Python data from the AiiDA Data node for the
Expand All @@ -37,7 +39,7 @@ def deserialize_pythonjob_data(cls, tdata: Dict[str, Any]) -> None:
Dict[str, Any]: The processed data dictionary.
"""

for input in tdata["inputs"].values():
for input in input_data.values():
if input["metadata"].get("is_function_input", False):
input["property"]["value"] = cls.deserialize_socket_data(input)

Expand Down
2 changes: 1 addition & 1 deletion src/aiida_workgraph/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def serialize_workgraph_inputs(wgdata):
if not data["handler"]["use_module_path"]:
pickle_callable(data["handler"])
if task["metadata"]["node_type"].upper() == "PYTHONJOB":
PythonJob.serialize_pythonjob_data(task)
PythonJob.serialize_pythonjob_data(task["inputs"])
for _, input in task["inputs"].items():
if input.get("property"):
prop = input["property"]
Expand Down
8 changes: 4 additions & 4 deletions src/aiida_workgraph/utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def link_creation_hook(self, link: Any) -> None:
"type": "add_link",
"data": {
"from_node": link.from_node.name,
"from_socket": link.from_socket.socket_name,
"from_socket": link.from_socket._name,
"to_node": link.to_node.name,
"to_socket": link.to_socket.socket_name,
"to_socket": link.to_socket._name,
},
}
)
Expand All @@ -65,9 +65,9 @@ def link_deletion_hook(self, link: Any) -> None:
"type": "delete_link",
"data": {
"from_node": link.from_node.name,
"from_socket": link.from_socket.socket_name,
"from_socket": link.from_socket._name,
"to_node": link.to_node.name,
"to_socket": link.to_socket.socket_name,
"to_socket": link.to_socket._name,
},
}
)
8 changes: 5 additions & 3 deletions src/aiida_workgraph/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,12 +295,14 @@ def update(self) -> None:
if node.is_finished_ok:
# update the output sockets
for socket in self.tasks[name].outputs:
socket.socket_value = get_nested_dict(
node.outputs, socket.socket_name, default=None
if socket._identifier == "workgraph.namespace":
continue
socket.value = get_nested_dict(
node.outputs, socket._name, default=None
)
# read results from the process outputs
elif isinstance(node, aiida.orm.Data):
self.tasks[name].outputs[0].socket_value = node
self.tasks[name].outputs[0].value = node
execution_count = getattr(self.process.outputs, "execution_count", None)
self.execution_count = execution_count if execution_count else 0
if self._widget is not None:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_pause_play_task(wg_calcjob):
# Seems the daemon is not responding to the play signal
wg.play_tasks(["add2"])
wg.wait(interval=5)
assert wg.tasks["add2"].outputs["sum"].socket_value == 9
assert wg.tasks["add2"].outputs["sum"].value == 9


def test_pause_play_error_handler(wg_calcjob, finished_process_node):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_awaitable_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def awaitable_func(x, y):
wg.run()
report = get_workchain_report(wg.process, "REPORT")
assert "Waiting for child processes: awaitable_func1" in report
assert add1.outputs["result"].socket_value == 4
assert add1.outputs["result"].value == 4


def test_monitor_decorator():
Expand Down Expand Up @@ -60,7 +60,7 @@ def test_time_monitor(decorated_add):
wg.run()
report = get_workchain_report(wg.process, "REPORT")
assert "Waiting for child processes: monitor1" in report
assert add1.outputs["result"].socket_value == 3
assert add1.outputs["result"].value == 3


def test_file_monitor(decorated_add, tmp_path):
Expand All @@ -84,7 +84,7 @@ async def create_test_file(filepath="/tmp/test_file_monitor.txt", t=2):
wg.run()
report = get_workchain_report(wg.process, "REPORT")
assert "Waiting for child processes: monitor1" in report
assert add1.outputs["result"].socket_value == 3
assert add1.outputs["result"].value == 3


@pytest.mark.usefixtures("started_daemon_client")
Expand Down
4 changes: 2 additions & 2 deletions tests/test_calcfunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_run(wg_calcfunction: WorkGraph) -> None:
print("state: ", wg.state)
# print("results: ", results[])
assert wg.tasks["sumdiff2"].node.outputs.sum == 9
assert wg.tasks["sumdiff2"].outputs["sum"].socket_value == 9
assert wg.tasks["sumdiff2"].outputs["sum"].value == 9


@pytest.mark.usefixtures("started_daemon_client")
Expand All @@ -27,4 +27,4 @@ def add(**kwargs):
wg = WorkGraph("test_dynamic_inputs")
wg.add_task(add, name="add1", x=orm.Int(1), y=orm.Int(2))
wg.run()
assert wg.tasks["add1"].outputs["result"].socket_value == 3
assert wg.tasks["add1"].outputs["result"].value == 3
2 changes: 1 addition & 1 deletion tests/test_calcjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ def test_submit(wg_calcjob: WorkGraph) -> None:
wg = wg_calcjob
wg.name = "test_submit_calcjob"
wg.submit(wait=True)
assert wg.tasks["add2"].outputs["sum"].socket_value == 9
assert wg.tasks["add2"].outputs["sum"].value == 9
4 changes: 2 additions & 2 deletions tests/test_ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_workgraph_ctx(decorated_add: Callable) -> None:
get_ctx1.waiting_on.add(add1)
add2 = wg.add_task(decorated_add, "add2", x=get_ctx1.outputs["result"], y=1)
wg.run()
assert add2.outputs["result"].socket_value == 6
assert add2.outputs["result"].value == 6


@pytest.mark.usefixtures("started_daemon_client")
Expand All @@ -40,4 +40,4 @@ def test_task_set_ctx(decorated_add: Callable) -> None:
add2 = wg.add_task(decorated_add, "add2", y="{{ sum }}")
wg.add_link(add1.outputs[0], add2.inputs["x"])
wg.submit(wait=True)
assert add2.outputs["result"].socket_value == 10
assert add2.outputs["result"].value == 10
Loading

0 comments on commit ffa2b69

Please sign in to comment.