Skip to content

Commit

Permalink
rename socket's attribute's name
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Dec 10, 2024
1 parent 137e3df commit c74bf4e
Show file tree
Hide file tree
Showing 14 changed files with 107 additions and 138 deletions.
6 changes: 3 additions & 3 deletions docs/gallery/howto/autogen/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def aggregate(
aggregate_task = wg.add_task(aggregate, name="aggregate_task")

# we have to increase the link limit because by default workgraph only supports one link per input socket
aggregate_task.inputs["collected_values"].link_limit = 50
aggregate_task.inputs["collected_values"].socket_link_limit = 50

for i in range(2): # this can be chosen as wanted
generator_task = wg.add_task(generator, name=f"generator{i}", seed=Int(i))
Expand Down Expand Up @@ -188,8 +188,8 @@ def aggregate(**collected_values):

# we have to increase the link limit because by default workgraph only supports
# one link per input socket.
aggregate_task.inputs["collected_ints"].link_limit = 50
aggregate_task.inputs["collected_floats"].link_limit = 50
aggregate_task.inputs["collected_ints"].socket_link_limit = 50
aggregate_task.inputs["collected_floats"].socket_link_limit = 50


for i in range(2): # this can be chosen as wanted
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph"

[project.entry-points."aiida_workgraph.socket"]
"workgraph.any" = "aiida_workgraph.sockets.builtins:SocketAny"
"workgraph.namespace" = "aiida_workgraph.sockets.builtins:SocketNamespace"
"workgraph.namespace" = "aiida_workgraph.socket:TaskSocketNamespace"
"workgraph.int" = "aiida_workgraph.sockets.builtins:SocketInt"
"workgraph.float" = "aiida_workgraph.sockets.builtins:SocketFloat"
"workgraph.string" = "aiida_workgraph.sockets.builtins:SocketString"
Expand Down
34 changes: 0 additions & 34 deletions src/aiida_workgraph/collection.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from node_graph.collection import (
NodeCollection,
PropertyCollection,
InputSocketCollection,
OutputSocketCollection,
)
from typing import Any, Callable, Optional, Union

Expand Down Expand Up @@ -60,35 +58,3 @@ def _new(
identifier = build_property_from_AiiDA(identifier)
# Call the original new method
return super()._new(identifier, name, **kwargs)


class WorkGraphInputSocketCollection(InputSocketCollection):
def _new(
self,
identifier: Union[Callable, str],
name: Optional[str] = None,
**kwargs: Any
) -> Any:
from aiida_workgraph.socket import build_socket_from_AiiDA

# build the socket on the fly if the identifier is a callable
if callable(identifier):
identifier = build_socket_from_AiiDA(identifier)
# Call the original new method
return super()._new(identifier, name, **kwargs)


class WorkGraphOutputSocketCollection(OutputSocketCollection):
def _new(
self,
identifier: Union[Callable, str],
name: Optional[str] = None,
**kwargs: Any
) -> Any:
from aiida_workgraph.socket import build_socket_from_AiiDA

# build the socket on the fly if the identifier is a callable
if callable(identifier):
identifier = build_socket_from_AiiDA(identifier)
# Call the original new method
return super()._new(identifier, name, **kwargs)
4 changes: 1 addition & 3 deletions src/aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def add_input_recursive(
"name": port_name,
"arg_type": "kwargs",
"metadata": {"required": required, "dynamic": port.dynamic},
"property": {"identifier": "workgraph.any", "default": None},
}
)
for value in port.values():
Expand Down Expand Up @@ -248,11 +247,10 @@ def build_task_from_AiiDA(
if name not in [input["name"] for input in inputs]:
inputs.append(
{
"identifier": "workgraph.any",
"identifier": "workgraph.namespace",
"name": name,
"arg_type": "var_kwargs",
"metadata": {"dynamic": True},
"property": {"identifier": "workgraph.any", "default": {}},
}
)

Expand Down
12 changes: 9 additions & 3 deletions src/aiida_workgraph/engine/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def run_tasks(self, names: List[str], continue_workgraph: bool = True) -> None:
kwargs[key] = args[i]
# update the port namespace
kwargs = update_nested_dict_with_special_keys(kwargs)
print("kwargs: ", kwargs)
# kwargs["meta.label"] = name
# output must be a Data type or a mapping of {string: Data}
task["results"] = {}
Expand Down Expand Up @@ -592,9 +593,14 @@ def get_inputs(
for name, input in task["inputs"].items():
# print(f"input: {input['name']}")
if len(input["links"]) == 0:
inputs[name] = self.ctx_manager.update_context_variable(
input["property"]["value"]
)
if input["identifier"] == "workgraph.namespace":
inputs[name] = self.ctx_manager.update_context_variable(
input["value"]
)
else:
inputs[name] = self.ctx_manager.update_context_variable(
input["property"]["value"]
)
elif len(input["links"]) == 1:
link = input["links"][0]
if self.ctx._tasks[link["from_node"]]["results"] is None:
Expand Down
22 changes: 16 additions & 6 deletions src/aiida_workgraph/socket.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Type

from aiida import orm
from node_graph.socket import NodeSocket
from node_graph.socket import NodeSocket, NodeSocketNamespace

from aiida_workgraph.property import TaskProperty

Expand All @@ -11,23 +11,33 @@ class TaskSocket(NodeSocket):

# use TaskProperty from aiida_workgraph.property
# to override the default NodeProperty from node_graph
node_property = TaskProperty
_socket_property_class = TaskProperty

@property
def node_value(self):
return self.get_node_value()

def get_node_value(self):
"""Obtain the actual Python `value` of the object attached to the Socket."""
if isinstance(self.value, orm.Data):
if hasattr(self.value, "value"):
return self.value.value
if isinstance(self.socket_value, orm.Data):
if hasattr(self.socket_value, "value"):
return self.socket_value.value
else:
raise ValueError(
"Data node does not have a value attribute. We do not know how to extract the raw Python value."
)
else:
return self.value
return self.socket_value


class TaskSocketNamespace(NodeSocketNamespace):
"""Represent a namespace of a Task in the AiiDA WorkGraph."""

_socket_identifier = "workgraph.namespace"
_socket_property_class = TaskProperty

def __init__(self, *args, **kwargs):
super().__init__(*args, entry_point="aiida_workgraph.socket", **kwargs)


def build_socket_from_AiiDA(DataClass: Type[Any]) -> Type[TaskSocket]:
Expand Down
63 changes: 28 additions & 35 deletions src/aiida_workgraph/sockets/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,103 +4,96 @@
class SocketAny(TaskSocket):
"""Any socket."""

identifier: str = "workgraph.any"
property_identifier: str = "workgraph.any"


class SocketNamespace(TaskSocket):
"""Namespace socket."""

identifier: str = "workgraph.namespace"
property_identifier: str = "workgraph.any"
_socket_identifier: str = "workgraph.any"
_socket_property_identifier: str = "workgraph.any"


class SocketFloat(TaskSocket):
"""Float socket."""

identifier: str = "workgraph.float"
property_identifier: str = "workgraph.float"
_socket_identifier: str = "workgraph.float"
_socket_property_identifier: str = "workgraph.float"


class SocketInt(TaskSocket):
"""Int socket."""

identifier: str = "workgraph.int"
property_identifier: str = "workgraph.int"
_socket_identifier: str = "workgraph.int"
_socket_property_identifier: str = "workgraph.int"


class SocketString(TaskSocket):
"""String socket."""

identifier: str = "workgraph.string"
property_identifier: str = "workgraph.string"
_socket_identifier: str = "workgraph.string"
_socket_property_identifier: str = "workgraph.string"


class SocketBool(TaskSocket):
"""Bool socket."""

identifier: str = "workgraph.bool"
property_identifier: str = "workgraph.bool"
_socket_identifier: str = "workgraph.bool"
_socket_property_identifier: str = "workgraph.bool"


class SocketAiiDAFloat(TaskSocket):
"""AiiDAFloat socket."""

identifier: str = "workgraph.aiida_float"
property_identifier: str = "workgraph.aiida_float"
_socket_identifier: str = "workgraph.aiida_float"
_socket_property_identifier: str = "workgraph.aiida_float"


class SocketAiiDAInt(TaskSocket):
"""AiiDAInt socket."""

identifier: str = "workgraph.aiida_int"
property_identifier: str = "workgraph.aiida_int"
_socket_identifier: str = "workgraph.aiida_int"
_socket_property_identifier: str = "workgraph.aiida_int"


class SocketAiiDAString(TaskSocket):
"""AiiDAString socket."""

identifier: str = "workgraph.aiida_string"
property_identifier: str = "workgraph.aiida_string"
_socket_identifier: str = "workgraph.aiida_string"
_socket_property_identifier: str = "workgraph.aiida_string"


class SocketAiiDABool(TaskSocket):
"""AiiDABool socket."""

identifier: str = "workgraph.aiida_bool"
property_identifier: str = "workgraph.aiida_bool"
_socket_identifier: str = "workgraph.aiida_bool"
_socket_property_identifier: str = "workgraph.aiida_bool"


class SocketAiiDAList(TaskSocket):
"""AiiDAList socket."""

identifier: str = "workgraph.aiida_list"
property_identifier: str = "workgraph.aiida_list"
_socket_identifier: str = "workgraph.aiida_list"
_socket_property_identifier: str = "workgraph.aiida_list"


class SocketAiiDADict(TaskSocket):
"""AiiDADict socket."""

identifier: str = "workgraph.aiida_dict"
property_identifier: str = "workgraph.aiida_dict"
_socket_identifier: str = "workgraph.aiida_dict"
_socket_property_identifier: str = "workgraph.aiida_dict"


class SocketAiiDAIntVector(TaskSocket):
"""Socket with a AiiDAIntVector property."""

identifier: str = "workgraph.aiida_int_vector"
property_identifier: str = "workgraph.aiida_int_vector"
_socket_identifier: str = "workgraph.aiida_int_vector"
_socket_property_identifier: str = "workgraph.aiida_int_vector"


class SocketAiiDAFloatVector(TaskSocket):
"""Socket with a FloatVector property."""

identifier: str = "workgraph.aiida_float_vector"
property_identifier: str = "workgraph.aiida_float_vector"
_socket_identifier: str = "workgraph.aiida_float_vector"
_socket_property_identifier: str = "workgraph.aiida_float_vector"


class SocketStructureData(TaskSocket):
"""Any socket."""

identifier: str = "workgraph.aiida_structuredata"
property_identifier: str = "workgraph.aiida_structuredata"
_socket_identifier: str = "workgraph.aiida_structuredata"
_socket_property_identifier: str = "workgraph.aiida_structuredata"
7 changes: 3 additions & 4 deletions src/aiida_workgraph/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@

from aiida_workgraph.properties import property_pool
from aiida_workgraph.sockets import socket_pool
from aiida_workgraph.socket import NodeSocketNamespace
from node_graph_widget import NodeGraphWidget
from aiida_workgraph.collection import (
WorkGraphPropertyCollection,
WorkGraphInputSocketCollection,
WorkGraphOutputSocketCollection,
)
import aiida
from typing import Any, Dict, Optional, Union, Callable, List, Set, Iterable
Expand Down Expand Up @@ -38,8 +37,8 @@ def __init__(
"""
super().__init__(
property_collection_class=WorkGraphPropertyCollection,
input_collection_class=WorkGraphInputSocketCollection,
output_collection_class=WorkGraphOutputSocketCollection,
input_collection_class=NodeSocketNamespace,
output_collection_class=NodeSocketNamespace,
**kwargs,
)
self.context_mapping = {} if context_mapping is None else context_mapping
Expand Down
6 changes: 3 additions & 3 deletions src/aiida_workgraph/tasks/monitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def create_sockets(self) -> None:
inp = self.add_input("workgraph.any", "timeout")
inp.add_property("workgraph.any", default=86400.0)
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
inp.link_limit = 100000
inp.socket_link_limit = 100000
self.add_output("workgraph.any", "result")
self.add_output("workgraph.any", "_wait")

Expand All @@ -50,7 +50,7 @@ def create_sockets(self) -> None:
inp = self.add_input("workgraph.any", "timeout")
inp.add_property("workgraph.any", default=86400.0)
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
inp.link_limit = 100000
inp.socket_link_limit = 100000
self.add_output("workgraph.any", "result")
self.add_output("workgraph.any", "_wait")

Expand Down Expand Up @@ -79,6 +79,6 @@ def create_sockets(self) -> None:
inp = self.add_input("workgraph.any", "timeout")
inp.add_property("workgraph.any", default=86400.0)
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
inp.link_limit = 100000
inp.socket_link_limit = 100000
self.add_output("workgraph.any", "result")
self.add_output("workgraph.any", "_wait")
8 changes: 4 additions & 4 deletions src/aiida_workgraph/utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def link_creation_hook(self, link: Any) -> None:
"type": "add_link",
"data": {
"from_node": link.from_node.name,
"from_socket": link.from_socket.name,
"from_socket": link.from_socket.socket_name,
"to_node": link.to_node.name,
"to_socket": link.to_socket.name,
"to_socket": link.to_socket.socket_name,
},
}
)
Expand All @@ -65,9 +65,9 @@ def link_deletion_hook(self, link: Any) -> None:
"type": "delete_link",
"data": {
"from_node": link.from_node.name,
"from_socket": link.from_socket.name,
"from_socket": link.from_socket.socket_name,
"to_node": link.to_node.name,
"to_socket": link.to_socket.name,
"to_socket": link.to_socket.socket_name,
},
}
)
Loading

0 comments on commit c74bf4e

Please sign in to comment.