diff --git a/tests/experiment/test_workload_util.py b/tests/experiment/test_workload_util.py index 729dedc24..b8de4db27 100644 --- a/tests/experiment/test_workload_util.py +++ b/tests/experiment/test_workload_util.py @@ -52,6 +52,20 @@ def test_workload_commands_tags_selected(self) -> None: ) self.assertEqual(len(commands), 1) + def test_workload_commands_requires(self) -> None: + revision = Revision(Xz, Variant(Xz.SOURCE[0], "c5c7ceb08a")) + project = Xz(revision=revision) + binary = Xz.binaries_for_revision(ShortCommitHash("c5c7ceb08a"))[0] + + commands = wu.workload_commands( + project, binary, [wu.WorkloadCategory.EXAMPLE] + ) + self.assertEqual(len(commands), 1) + commands = wu.workload_commands( + project, binary, [wu.WorkloadCategory.MEDIUM] + ) + self.assertEqual(len(commands), 1) + class TestWorkloadFilenames(unittest.TestCase): diff --git a/varats-core/varats/experiment/workload_util.py b/varats-core/varats/experiment/workload_util.py index 8cf66daff..4566e2ee7 100644 --- a/varats-core/varats/experiment/workload_util.py +++ b/varats-core/varats/experiment/workload_util.py @@ -19,6 +19,7 @@ Command, ) +from varats.experiment.experiment_util import get_extra_config_options from varats.project.project_util import ProjectBinaryWrapper from varats.project.varats_project import VProject from varats.report.report import KeyedReportAggregate, ReportTy @@ -92,8 +93,33 @@ def workload_commands( ) ] + # Filter commands that have required args set. + extra_options = set(get_extra_config_options(project)) + + def requires_any_filter(prj_cmd: ProjectCommand) -> bool: + if hasattr( + prj_cmd.command, "requires_any" + ) and prj_cmd.command.requires_any: + args = set(prj_cmd.command._args).union(extra_options) + return bool(args.intersection(prj_cmd.command.requires_any)) + return True + + def requires_all_filter(prj_cmd: ProjectCommand) -> bool: + if hasattr( + prj_cmd.command, "requires_all" + ) and prj_cmd.command.requires_all: + args = set(prj_cmd.command._args).union(extra_options) + return bool(prj_cmd.command.requires_all.issubset(args)) + return True + + available_cmds = filter( + requires_all_filter, filter(requires_any_filter, project_cmds) + ) + return list( - filter(lambda prj_cmd: prj_cmd.path.name == binary.name, project_cmds) + filter( + lambda prj_cmd: prj_cmd.path.name == binary.name, available_cmds + ) ) diff --git a/varats-core/varats/project/project_util.py b/varats-core/varats/project/project_util.py index a4c27d74d..3028c438b 100644 --- a/varats-core/varats/project/project_util.py +++ b/varats-core/varats/project/project_util.py @@ -7,6 +7,7 @@ import benchbuild as bb import pygit2 +from benchbuild.command import Command from benchbuild.source import Git from benchbuild.utils.cmd import git from plumbum import local @@ -382,3 +383,35 @@ def copy_renamed_git_to_dest(src_dir: Path, dest_dir: Path) -> None: for name in dirs: if name == ".gitted": os.rename(os.path.join(root, name), os.path.join(root, ".git")) + + +class VCommand(Command): # type: ignore [misc] + """ + Wrapper around benchbuild's Command class. + + Attributes: + requires_any: sufficient args that must be available for successful execution. + requires_all: all args that must be available for successful execution. + """ + + _requires: tp.Set[str] + + def __init__( + self, + *args: tp.Any, + requires_any: tp.Optional[tp.Set[str]] = None, + requires_all: tp.Optional[tp.Set[str]] = None, + **kwargs: tp.Union[str, tp.List[str]], + ) -> None: + + super().__init__(*args, **kwargs) + self._requires_any = requires_any if requires_any else set() + self._requires_all = requires_all if requires_all else set() + + @property + def requires_any(self) -> tp.Set[str]: + return self._requires_any + + @property + def requires_all(self) -> tp.Set[str]: + return self._requires_all diff --git a/varats/varats/projects/c_projects/xz.py b/varats/varats/projects/c_projects/xz.py index 1fac7c349..3299a0e08 100644 --- a/varats/varats/projects/c_projects/xz.py +++ b/varats/varats/projects/c_projects/xz.py @@ -2,7 +2,7 @@ import typing as tp import benchbuild as bb -from benchbuild.command import Command, SourceRoot, WorkloadSet +from benchbuild.command import SourceRoot, WorkloadSet from benchbuild.source import HTTPMultiple from benchbuild.utils.cmd import autoreconf, make from benchbuild.utils.revision_ranges import ( @@ -18,6 +18,7 @@ from varats.paper.paper_config import PaperConfigSpecificGit from varats.project.project_domain import ProjectDomains from varats.project.project_util import ( + VCommand, ProjectBinaryWrapper, get_local_project_git_path, BinaryType, @@ -84,16 +85,19 @@ class Xz(VProject): WORKLOADS = { WorkloadSet(WorkloadCategory.EXAMPLE): [ - Command( + VCommand( SourceRoot("xz") / RSBinary("xz"), "-k", - "geo-maps/countries-land-1km.geo.json", + # Use output_param to ensure input file + # gets appended after all arguments. + output_param=["{output}"], + output=SourceRoot("geo-maps/countries-land-250m.geo.json"), label="countries-land-1km", creates=["geo-maps/countries-land-1km.geo.json.xz"] ) ], WorkloadSet(WorkloadCategory.MEDIUM): [ - Command( + VCommand( SourceRoot("xz") / RSBinary("xz"), "-k", "-9e", @@ -101,9 +105,13 @@ class Xz(VProject): "--threads=1", "--format=xz", "-vv", - "geo-maps/countries-land-250m.geo.json", + # Use output_param to ensure input file + # gets appended after all arguments. + output_param=["{output}"], + output=SourceRoot("geo-maps/countries-land-250m.geo.json"), label="countries-land-250m", - creates=["geo-maps/countries-land-250m.geo.json.xz"] + creates=["geo-maps/countries-land-250m.geo.json.xz"], + requires_all={"--compress"}, ) ], }