From 70550f99640c2dd6b0f87e2e3013c5a5c5f9f350 Mon Sep 17 00:00:00 2001 From: Arun Sharma Date: Wed, 14 Oct 2020 17:31:22 -0700 Subject: [PATCH] Phase Sequence: modifier functions that copy must use attr.evolve (#961) PiperOrigin-RevId: 337203373 --- openhtf/core/phase_collections.py | 15 ++++++--- test/core/phase_branches_test.py | 52 +++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 5 deletions(-) diff --git a/openhtf/core/phase_collections.py b/openhtf/core/phase_collections.py index d58d1d78c..a1b1e79d7 100644 --- a/openhtf/core/phase_collections.py +++ b/openhtf/core/phase_collections.py @@ -155,28 +155,33 @@ def _asdict(self) -> Dict[Text, Any]: def with_args(self: SequenceClassT, **kwargs: Any) -> SequenceClassT: """Send these keyword-arguments when phases are called.""" - return type(self)( + return attr.evolve( + self, nodes=tuple(n.with_args(**kwargs) for n in self.nodes), name=util.format_string(self.name, kwargs)) def with_plugs(self: SequenceClassT, **subplugs: Type[base_plugs.BasePlug]) -> SequenceClassT: """Substitute plugs for placeholders for this phase, error on unknowns.""" - return type(self)( + return attr.evolve( + self, nodes=tuple(n.with_plugs(**subplugs) for n in self.nodes), name=util.format_string(self.name, subplugs)) def load_code_info(self: SequenceClassT) -> SequenceClassT: """Load coded info for all contained phases.""" - return type(self)( - nodes=tuple(n.load_code_info() for n in self.nodes), name=self.name) + return attr.evolve( + self, + nodes=tuple(n.load_code_info() for n in self.nodes), + name=self.name) def apply_to_all_phases( self: SequenceClassT, func: Callable[[phase_descriptor.PhaseDescriptor], phase_descriptor.PhaseDescriptor] ) -> SequenceClassT: """Apply func to all contained phases.""" - return type(self)( + return attr.evolve( + self, nodes=tuple(n.apply_to_all_phases(func) for n in self.nodes), name=self.name) diff --git a/test/core/phase_branches_test.py b/test/core/phase_branches_test.py index c660de16f..da461dce6 100644 --- a/test/core/phase_branches_test.py +++ b/test/core/phase_branches_test.py @@ -72,6 +72,58 @@ def test_as_dict(self): } self.assertEqual(expected, branch._asdict()) + def test_with_args(self): + branch = phase_branches.BranchSequence( + phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET), + nodes=(run_phase,), + name='name_{arg}') + expected = phase_branches.BranchSequence( + phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET), + nodes=(run_phase.with_args(arg=1),), + name='name_1') + + self.assertEqual(expected, branch.with_args(arg=1)) + + def test_with_plugs(self): + + class MyPlug(htf.BasePlug): + pass + + branch = phase_branches.BranchSequence( + phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET), + nodes=(run_phase,), + name='name_{my_plug.__name__}') + expected = phase_branches.BranchSequence( + phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET), + nodes=(run_phase.with_plugs(my_plug=MyPlug),), + name='name_MyPlug') + + self.assertEqual(expected, branch.with_plugs(my_plug=MyPlug)) + + def test_load_code_info(self): + branch = phase_branches.BranchSequence( + phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET), + nodes=(run_phase,)) + expected = phase_branches.BranchSequence( + phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET), + nodes=(run_phase.load_code_info(),)) + + self.assertEqual(expected, branch.load_code_info()) + + def test_apply_to_all_phases(self): + + def do_rename(phase): + return _rename(phase, 'blah_blah') + + branch = phase_branches.BranchSequence( + phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET), + nodes=(run_phase,)) + expected = phase_branches.BranchSequence( + phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET), + nodes=(do_rename(run_phase),)) + + self.assertEqual(expected, branch.apply_to_all_phases(do_rename)) + class BranchSequenceIntegrationTest(htf_test.TestCase):