Skip to content

Commit

Permalink
Adds requires attribute to Command (#822)
Browse files Browse the repository at this point in the history
The requires attribute is used by workload_commands to check if the command includes all necessary arguments.
  • Loading branch information
danjujan authored Sep 12, 2023
1 parent 3c019db commit 04bfd56
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 7 deletions.
14 changes: 14 additions & 0 deletions tests/experiment/test_workload_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,20 @@ def test_workload_commands_tags_selected(self) -> None:
)
self.assertEqual(len(commands), 1)

def test_workload_commands_requires(self) -> None:
revision = Revision(Xz, Variant(Xz.SOURCE[0], "c5c7ceb08a"))
project = Xz(revision=revision)
binary = Xz.binaries_for_revision(ShortCommitHash("c5c7ceb08a"))[0]

commands = wu.workload_commands(
project, binary, [wu.WorkloadCategory.EXAMPLE]
)
self.assertEqual(len(commands), 1)
commands = wu.workload_commands(
project, binary, [wu.WorkloadCategory.MEDIUM]
)
self.assertEqual(len(commands), 1)


class TestWorkloadFilenames(unittest.TestCase):

Expand Down
28 changes: 27 additions & 1 deletion varats-core/varats/experiment/workload_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Command,
)

from varats.experiment.experiment_util import get_extra_config_options
from varats.project.project_util import ProjectBinaryWrapper
from varats.project.varats_project import VProject
from varats.report.report import KeyedReportAggregate, ReportTy
Expand Down Expand Up @@ -92,8 +93,33 @@ def workload_commands(
)
]

# Filter commands that have required args set.
extra_options = set(get_extra_config_options(project))

def requires_any_filter(prj_cmd: ProjectCommand) -> bool:
if hasattr(
prj_cmd.command, "requires_any"
) and prj_cmd.command.requires_any:
args = set(prj_cmd.command._args).union(extra_options)
return bool(args.intersection(prj_cmd.command.requires_any))
return True

def requires_all_filter(prj_cmd: ProjectCommand) -> bool:
if hasattr(
prj_cmd.command, "requires_all"
) and prj_cmd.command.requires_all:
args = set(prj_cmd.command._args).union(extra_options)
return bool(prj_cmd.command.requires_all.issubset(args))
return True

available_cmds = filter(
requires_all_filter, filter(requires_any_filter, project_cmds)
)

return list(
filter(lambda prj_cmd: prj_cmd.path.name == binary.name, project_cmds)
filter(
lambda prj_cmd: prj_cmd.path.name == binary.name, available_cmds
)
)


Expand Down
33 changes: 33 additions & 0 deletions varats-core/varats/project/project_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import benchbuild as bb
import pygit2
from benchbuild.command import Command
from benchbuild.source import Git
from benchbuild.utils.cmd import git
from plumbum import local
Expand Down Expand Up @@ -382,3 +383,35 @@ def copy_renamed_git_to_dest(src_dir: Path, dest_dir: Path) -> None:
for name in dirs:
if name == ".gitted":
os.rename(os.path.join(root, name), os.path.join(root, ".git"))


class VCommand(Command): # type: ignore [misc]
"""
Wrapper around benchbuild's Command class.
Attributes:
requires_any: sufficient args that must be available for successful execution.
requires_all: all args that must be available for successful execution.
"""

_requires: tp.Set[str]

def __init__(
self,
*args: tp.Any,
requires_any: tp.Optional[tp.Set[str]] = None,
requires_all: tp.Optional[tp.Set[str]] = None,
**kwargs: tp.Union[str, tp.List[str]],
) -> None:

super().__init__(*args, **kwargs)
self._requires_any = requires_any if requires_any else set()
self._requires_all = requires_all if requires_all else set()

@property
def requires_any(self) -> tp.Set[str]:
return self._requires_any

@property
def requires_all(self) -> tp.Set[str]:
return self._requires_all
20 changes: 14 additions & 6 deletions varats/varats/projects/c_projects/xz.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import typing as tp

import benchbuild as bb
from benchbuild.command import Command, SourceRoot, WorkloadSet
from benchbuild.command import SourceRoot, WorkloadSet
from benchbuild.source import HTTPMultiple
from benchbuild.utils.cmd import autoreconf, make
from benchbuild.utils.revision_ranges import (
Expand All @@ -18,6 +18,7 @@
from varats.paper.paper_config import PaperConfigSpecificGit
from varats.project.project_domain import ProjectDomains
from varats.project.project_util import (
VCommand,
ProjectBinaryWrapper,
get_local_project_git_path,
BinaryType,
Expand Down Expand Up @@ -84,26 +85,33 @@ class Xz(VProject):

WORKLOADS = {
WorkloadSet(WorkloadCategory.EXAMPLE): [
Command(
VCommand(
SourceRoot("xz") / RSBinary("xz"),
"-k",
"geo-maps/countries-land-1km.geo.json",
# Use output_param to ensure input file
# gets appended after all arguments.
output_param=["{output}"],
output=SourceRoot("geo-maps/countries-land-250m.geo.json"),
label="countries-land-1km",
creates=["geo-maps/countries-land-1km.geo.json.xz"]
)
],
WorkloadSet(WorkloadCategory.MEDIUM): [
Command(
VCommand(
SourceRoot("xz") / RSBinary("xz"),
"-k",
"-9e",
"--compress",
"--threads=1",
"--format=xz",
"-vv",
"geo-maps/countries-land-250m.geo.json",
# Use output_param to ensure input file
# gets appended after all arguments.
output_param=["{output}"],
output=SourceRoot("geo-maps/countries-land-250m.geo.json"),
label="countries-land-250m",
creates=["geo-maps/countries-land-250m.geo.json.xz"]
creates=["geo-maps/countries-land-250m.geo.json.xz"],
requires_all={"--compress"},
)
],
}
Expand Down

0 comments on commit 04bfd56

Please sign in to comment.