Skip to content

Commit

Permalink
add tab completion
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Dec 8, 2024
1 parent 96ad5e7 commit 137e3df
Show file tree
Hide file tree
Showing 15 changed files with 201 additions and 199 deletions.
10 changes: 5 additions & 5 deletions docs/gallery/concept/autogen/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def multiply(x, y):
return x * y


print("Input ports: ", multiply.task().inputs.keys())
print("Output ports: ", multiply.task().outputs.keys())
print("Input ports: ", multiply.task().get_input_names())
print("Output ports: ", multiply.task().get_output_names())

multiply.task().to_html()

Expand All @@ -49,8 +49,8 @@ def add_minus(x, y):
return {"sum": x + y, "difference": x - y}


print("Input ports: ", add_minus.task().inputs.keys())
print("Ouput ports: ", add_minus.task().outputs.keys())
print("Input ports: ", add_minus.task().get_input_names())
print("Ouput ports: ", add_minus.task().get_output_names())
add_minus.task().to_html()


Expand Down Expand Up @@ -140,7 +140,7 @@ def add(x, y):

def create_sockets(self):
# create a General port.
inp = self.inputs.new("workgraph.Any", "symbols")
inp = self.add_input("workgraph.Any", "symbols")
# add a string property to the port with default value "H".
inp.add_property("String", "default", default="H")

Expand Down
18 changes: 9 additions & 9 deletions docs/gallery/concept/autogen/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def multiply(x, y):
#

add1 = add.task()
print("Inputs:", add1.inputs.keys())
print("Outputs:", add1.outputs.keys())
print("Inputs:", add1.get_input_names())
print("Outputs:", add1.get_output_names())


######################################################################
Expand All @@ -62,8 +62,8 @@ def add_minus(x, y):
return {"sum": x + y, "difference": x - y}


print("Inputs:", add_minus.task().inputs.keys())
print("Outputs:", add_minus.task().outputs.keys())
print("Inputs:", add_minus.task().get_input_names())
print("Outputs:", add_minus.task().get_output_names())

######################################################################
# One can also add an ``identifier`` to indicates the data type. The data
Expand Down Expand Up @@ -194,11 +194,11 @@ class MyAdd(Task):
}

def create_sockets(self):
self.inputs.clear()
self.outputs.clear()
inp = self.inputs.new("workgraph.Any", "x")
inp = self.inputs.new("workgraph.Any", "y")
self.outputs.new("workgraph.Any", "sum")
self.inputs._clear()
self.outputs._clear()
inp = self.add_input("workgraph.Any", "x")
inp = self.add_input("workgraph.Any", "y")
self.add_output("workgraph.Any", "sum")


######################################################################
Expand Down
22 changes: 11 additions & 11 deletions src/aiida_workgraph/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class TaskCollection(NodeCollection):
def new(
def _new(
self,
identifier: Union[Callable, str],
name: Optional[str] = None,
Expand All @@ -33,21 +33,21 @@ def new(
outputs=kwargs.get("outputs", None),
parser_outputs=kwargs.pop("parser_outputs", None),
)
task = super().new(identifier, name, uuid, **kwargs)
task = super()._new(identifier, name, uuid, **kwargs)
return task
if isinstance(identifier, str) and identifier.upper() == "WHILE":
task = super().new("workgraph.while", name, uuid, **kwargs)
task = super()._new("workgraph.while", name, uuid, **kwargs)
return task
if isinstance(identifier, str) and identifier.upper() == "IF":
task = super().new("workgraph.if", name, uuid, **kwargs)
task = super()._new("workgraph.if", name, uuid, **kwargs)
return task
if isinstance(identifier, WorkGraph):
identifier = build_task_from_workgraph(identifier)
return super().new(identifier, name, uuid, **kwargs)
return super()._new(identifier, name, uuid, **kwargs)


class WorkGraphPropertyCollection(PropertyCollection):
def new(
def _new(
self,
identifier: Union[Callable, str],
name: Optional[str] = None,
Expand All @@ -59,11 +59,11 @@ def new(
if callable(identifier):
identifier = build_property_from_AiiDA(identifier)
# Call the original new method
return super().new(identifier, name, **kwargs)
return super()._new(identifier, name, **kwargs)


class WorkGraphInputSocketCollection(InputSocketCollection):
def new(
def _new(
self,
identifier: Union[Callable, str],
name: Optional[str] = None,
Expand All @@ -75,11 +75,11 @@ def new(
if callable(identifier):
identifier = build_socket_from_AiiDA(identifier)
# Call the original new method
return super().new(identifier, name, **kwargs)
return super()._new(identifier, name, **kwargs)


class WorkGraphOutputSocketCollection(OutputSocketCollection):
def new(
def _new(
self,
identifier: Union[Callable, str],
name: Optional[str] = None,
Expand All @@ -91,4 +91,4 @@ def new(
if callable(identifier):
identifier = build_socket_from_AiiDA(identifier)
# Call the original new method
return super().new(identifier, name, **kwargs)
return super()._new(identifier, name, **kwargs)
6 changes: 3 additions & 3 deletions src/aiida_workgraph/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def set_context(self, context: Dict[str, Any]) -> None:
key is the context key, value is the output key.
"""
# all values should belong to the outputs.keys()
remain_keys = set(context.values()).difference(self.outputs.keys())
remain_keys = set(context.values()).difference(self.get_output_names())
if remain_keys:
msg = f"Keys {remain_keys} are not in the outputs of this task."
raise ValueError(msg)
Expand All @@ -103,8 +103,8 @@ def process_nested_inputs(
full_key = f"{base_key}.{sub_key}" if base_key else sub_key

# Create a new input socket if it does not exist
if full_key not in self.inputs.keys() and dynamic:
self.inputs.new(
if full_key not in self.get_input_names() and dynamic:
self.add_input(
"workgraph.any",
name=full_key,
metadata={"required": True},
Expand Down
Loading

0 comments on commit 137e3df

Please sign in to comment.