Skip to content

Commit

Permalink
add test copy_files
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Dec 3, 2024
1 parent 6f93f5a commit 30216bd
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 24 deletions.
34 changes: 18 additions & 16 deletions src/aiida_pythonjob/parsers/pythonjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,29 +45,34 @@ def parse(self, **kwargs):
elif isinstance(exit_code, int):
exit_code = ExitCode(exit_code)
return exit_code

if len(top_level_output_list) > 1:
if len(top_level_output_list) == 1:
# if output name in results, use it
if top_level_output_list[0]["name"] in results:
top_level_output_list[0]["value"] = self.serialize_output(
results.pop(top_level_output_list[0]["name"]),
top_level_output_list[0],
)
# if there are any remaining results, raise an warning
if len(results) > 0:
self.logger.warning(
f"Found extra results that are not included in the output: {results.keys()}"
)
# otherwise, we assume the results is the output
else:
top_level_output_list[0]["value"] = self.serialize_output(results, top_level_output_list[0])
elif len(top_level_output_list) > 1:
for output in top_level_output_list:
if output["name"] not in results:
if output.get("required", True):
return self.exit_codes.ERROR_MISSING_OUTPUT
else:
output["value"] = self.serialize_output(results.pop(output["name"]), output)
# if there are any remaining results, raise an warning
if results:
if len(results) > 0:
self.logger.warning(
f"Found extra results that are not included in the output: {results.keys()}"
)
elif len(top_level_output_list) == 1:
# if output name in results, use it
if top_level_output_list[0]["name"] in results:
top_level_output_list[0]["value"] = self.serialize_output(
results[top_level_output_list[0]["name"]],
top_level_output_list[0],
)
# otherwise, we assume the results is the output
else:
top_level_output_list[0]["value"] = self.serialize_output(results, top_level_output_list[0])

elif len(top_level_output_list) == 1:
# otherwise it returns a single value, we assume the results is the output
top_level_output_list[0]["value"] = self.serialize_output(results, top_level_output_list[0])
Expand All @@ -80,9 +85,6 @@ def parse(self, **kwargs):
except ValueError as exception:
self.logger.error(exception)
return self.exit_codes.ERROR_INVALID_OUTPUT
except Exception as exception:
self.logger.error(exception)
return self.exit_codes.ERROR_INVALID_OUTPUT

def find_output(self, name):
"""Find the output with the given name."""
Expand Down
48 changes: 40 additions & 8 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,26 @@

import cloudpickle as pickle
from aiida import orm
from aiida.cmdline.utils.common import get_workchain_report
from aiida.common.links import LinkType
from aiida_pythonjob.parsers import PythonJobParser


def create_retrieved_folder(result: dict):
def create_retrieved_folder(result: dict, output_filename="results.pickle"):
# Create a retrieved ``FolderData`` node with results
with tempfile.TemporaryDirectory() as tmpdir:
dirpath = pathlib.Path(tmpdir)
with open((dirpath / "results.pickle"), "wb") as handle:
with open((dirpath / output_filename), "wb") as handle:
pickle.dump(result, handle)
folder_data = orm.FolderData(tree=dirpath.absolute())
return folder_data


def create_process_node(result: dict, function_data: dict):
def create_process_node(result: dict, function_data: dict, output_filename: str = "results.pickle"):
node = orm.CalcJobNode()
node.set_process_type("aiida.calculations:pythonjob.pythonjob")
function_data = orm.Dict(function_data)
retrieved = create_retrieved_folder(result)
retrieved = create_retrieved_folder(result, output_filename=output_filename)
node.base.links.add_incoming(function_data, link_type=LinkType.INPUT_CALC, link_label="function_data")
retrieved.base.links.add_incoming(node, link_type=LinkType.CREATE, link_label="retrieved")
function_data.store()
Expand All @@ -30,8 +31,8 @@ def create_process_node(result: dict, function_data: dict):
return node


def create_parser(result, function_data):
node = create_process_node(result, function_data)
def create_parser(result, function_data, output_filename="results.pickle"):
node = create_process_node(result, function_data, output_filename=output_filename)
parser = PythonJobParser(node=node)
return parser

Expand All @@ -55,11 +56,13 @@ def test_tuple_result_mismatch(fixture_localhost):

def test_dict_result(fixture_localhost):
result = {"a": 1, "b": 2, "c": 3}
function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]}
function_data = {"outputs": [{"name": "a"}, {"name": "b"}]}
parser = create_parser(result, function_data)
exit_code = parser.parse()
assert exit_code is None
assert len(parser.outputs) == 3
assert len(parser.outputs) == 2
report = get_workchain_report(parser.node, levelname="WARNING")
assert "Found extra results that are not included in the output: dict_keys(['c'])" in report


def test_dict_result_missing(fixture_localhost):
Expand All @@ -70,6 +73,27 @@ def test_dict_result_missing(fixture_localhost):
assert exit_code == parser.exit_codes.ERROR_MISSING_OUTPUT


def test_dict_result_as_one_output(fixture_localhost):
result = {"a": 1, "b": 2, "c": 3}
function_data = {"outputs": [{"name": "result"}]}
parser = create_parser(result, function_data)
exit_code = parser.parse()
assert exit_code is None
assert len(parser.outputs) == 1
assert parser.outputs["result"] == result


def test_dict_result_only_show_one_output(fixture_localhost):
result = {"a": 1, "b": 2}
function_data = {"outputs": [{"name": "a"}]}
parser = create_parser(result, function_data)
parser.parse()
assert len(parser.outputs) == 1
assert parser.outputs["a"] == 1
report = get_workchain_report(parser.node, levelname="WARNING")
assert "Found extra results that are not included in the output: dict_keys(['b'])" in report


def test_exit_code(fixture_localhost):
result = {"exit_code": {"status": 1, "message": "error"}}
function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]}
Expand All @@ -78,3 +102,11 @@ def test_exit_code(fixture_localhost):
assert exit_code is not None
assert exit_code.status == 1
assert exit_code.message == "error"


def test_no_output_file(fixture_localhost):
result = {"a": 1, "b": 2, "c": 3}
function_data = {"outputs": [{"name": "result"}]}
parser = create_parser(result, function_data, output_filename="not_results.pickle")
exit_code = parser.parse()
assert exit_code == parser.exit_codes.ERROR_READING_OUTPUT_FILE
24 changes: 24 additions & 0 deletions tests/test_pythonjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,30 @@ def add(x, y):
assert "result.txt" in result["retrieved"].list_object_names()


def test_copy_files(fixture_localhost):
"""Test function with copy files."""

def add(x, y):
z = x + y
with open("result.txt", "w") as f:
f.write(str(z))

def multiply(x_folder_name, y):
with open(f"{x_folder_name}/result.txt", "r") as f:
x = int(f.read())
return x * y

inputs = prepare_pythonjob_inputs(add, function_inputs={"x": 1, "y": 2})
result, node = run_get_node(PythonJob, inputs=inputs)
inputs = prepare_pythonjob_inputs(
multiply,
function_inputs={"x_folder_name": "x_folder_name", "y": 2},
copy_files={"x_folder_name": result["remote_folder"]},
)
result, node = run_get_node(PythonJob, inputs=inputs)
assert result["result"].value == 6


def test_exit_code(fixture_localhost):
"""Test function with exit code."""
from numpy import array
Expand Down

0 comments on commit 30216bd

Please sign in to comment.