From 94bebbcefa530675ad39ad90e5e303598d06ab0b Mon Sep 17 00:00:00 2001 From: Arun Sharma Date: Wed, 14 Oct 2020 14:58:09 -0700 Subject: [PATCH] Branches (#960) * Drop py2.7 tests from travis config and setup.py. * diagnoses_lib: Serialize diagnoses result values, not names Diagnosis results should values should be serialized out rather than the enum constant names. PiperOrigin-RevId: 318469885 * Attachments: Handle case where NamedTemporaryFile object stops working. We have noticed times where we get bad file descriptor exceptions trying to reload data from the temporary files. For now, just return empty byte strings while we investigate more. Also, adding a callback to close the Attachments' temporary files. PiperOrigin-RevId: 319879865 * test_record: Fix lint issues. PiperOrigin-RevId: 320025482 * output callbacks: Fix CloseAttachments Attachments are stored in a dictionary under the phase record, not a list. PiperOrigin-RevId: 321406529 * OpenHTF: Fix lint issues and run autoformatter. PiperOrigin-RevId: 323661267 * Initial pytype-based type annotations. PiperOrigin-RevId: 324720299 * conf: Use inspect.getfullargspec. PiperOrigin-RevId: 324754078 * Internal change PiperOrigin-RevId: 325271171 * Only have the attachment's temporary file open while reading/writing. If there are a lot of attachments, the program can exceed Linux's max allowed open files per process. PiperOrigin-RevId: 325490340 * Add core annotations and replace mutablerecords and namedtuple with attr. PhaseDescriptor will be in the next commit. PiperOrigin-RevId: 327075946 * Rearrange and add comments for the Measurement fields. PiperOrigin-RevId: 327140593 * Internal change PiperOrigin-RevId: 327735936 * Convert PhaseOptions and PhaseDescriptor to attr. PiperOrigin-RevId: 328778682 * callbacks: Add type annotations Add type annotations to the callbacks library and break apart some complex types. 
Also breaking out the JSON conversion logic to an independent function for easier use by other modules. PiperOrigin-RevId: 328972148 * Internal change PiperOrigin-RevId: 329550447 * Internal change PiperOrigin-RevId: 331590164 * Internal change PiperOrigin-RevId: 331890539 * Remove TestPhase alias. The TestPhase alias for PhaseOptions has long been deprecated. Removing it. PiperOrigin-RevId: 332291802 * Add more type annotations. PiperOrigin-RevId: 332519023 * Remove plugs-phase_descriptor circular dependency Remove the plugs to phase_descriptor circular dependency by moving the pieces phase_descriptor depends on to a new core/base_plugs.py file. PiperOrigin-RevId: 332527874 * PhaseDescriptor: with_plugs and with_args now ignore unknowns Change with_plugs and with_args to use their with_known_plugs and with_known_args implementations instead. PiperOrigin-RevId: 332546916 * Add type checking to unit tests to verify things are working. PiperOrigin-RevId: 332551203 * test_descriptor: Remove Test Teardown. Test teardown was deprecated in favor of PhaseGroup teardowns. Fully removing them. PiperOrigin-RevId: 333112589 * Add PhaseNode and implement PhaseSequence. Phase nodes are now the basic building block of the OpenHTF execution engine. Phase sequence is a phase node that constains a sequence of phase nodes. Phase groups now use phase sequences to contain its setup, main, and teardown phases. PiperOrigin-RevId: 334461205 * Implement Phase Branches Phase branches run phases conditionally based on triggered diagnosis results. PiperOrigin-RevId: 335460872 * PhaseExecutor: Raise on invalid phase result. PiperOrigin-RevId: 335469722 * Implement Subtests. Subtests are a collection of phases that can indepenently fail and skip the rest of the phases while still working with PhaseGroup teardowns. PiperOrigin-RevId: 335472411 * util/test: Allow customizing the test_start_function. 
The OpenHTF TestCase can now customize the test start function by setting the `test_start_function` attribute. This change will now force unit tests to call super().setUp() in all cases. PiperOrigin-RevId: 335509542 * Implement Phase Checkpoints Phase checkpoints are nodes that check if a diagnosis result has been triggered or a simple set of phases has failed. In those cases, they will either be resolved as FAIL_SUBTEST or STOP. PiperOrigin-RevId: 335513243 * Refactor TestApi and expose diagnoses store. The TestApi object should just be a proxy to the functions on the PhaseState and TestState instances. Add the diagnoses store to the API for simpler access during phases. PiperOrigin-RevId: 336213159 * Subtest skip phases when they fail. Subtests should skip the later phases with special handling for other nodes as documented in event_sequence.md. PiperOrigin-RevId: 336386553 * Fix typo in TestApi.diagnoses_store PiperOrigin-RevId: 336408351 * Internal change PiperOrigin-RevId: 336691169 * Add DimensionPivot. DimensionPivot is a validator that runs a subvalidator on each value independently for a dimensioned measurement. If any value fails, the measurement is a failure. PiperOrigin-RevId: 336730422 * Drop PY2 support from OpenHTF. 
PiperOrigin-RevId: 337139629 --- .travis.yml | 2 - CHANGELOG | 22 + bin/units_from_xls.py | 96 +-- docs/event_sequence.md | 94 ++- examples/all_the_things.py | 82 +- examples/checkpoints.py | 23 +- examples/example_plugs.py | 19 +- examples/frontend_example.py | 13 +- examples/hello_world.py | 7 +- examples/ignore_early_canceled_tests.py | 7 +- examples/measurements.py | 30 +- examples/phase_groups.py | 50 +- examples/repeat.py | 33 +- examples/stop_on_first_failure.py | 18 +- examples/with_plugs.py | 45 +- openhtf/__init__.py | 19 +- openhtf/core/base_plugs.py | 192 +++++ openhtf/core/diagnoses_lib.py | 164 ++-- openhtf/core/measurements.py | 413 ++++++---- openhtf/core/monitors.py | 61 +- openhtf/core/phase_branches.py | 272 ++++++ openhtf/core/phase_collections.py | 236 ++++++ openhtf/core/phase_descriptor.py | 311 +++---- openhtf/core/phase_executor.py | 225 +++-- openhtf/core/phase_group.py | 337 +++----- openhtf/core/phase_nodes.py | 67 ++ openhtf/core/test_descriptor.py | 404 +++++---- openhtf/core/test_executor.py | 405 ++++++--- openhtf/core/test_record.py | 386 ++++++--- openhtf/core/test_state.py | 394 +++++---- openhtf/output/callbacks/__init__.py | 98 ++- openhtf/output/callbacks/console_summary.py | 48 +- openhtf/output/callbacks/json_factory.py | 102 ++- openhtf/output/callbacks/mfg_inspector.py | 43 +- openhtf/output/proto/mfg_event_converter.py | 31 +- openhtf/output/proto/test_runs_converter.py | 18 +- openhtf/output/servers/dashboard_server.py | 73 +- openhtf/output/servers/pub_sub.py | 21 +- openhtf/output/servers/station_server.py | 77 +- openhtf/output/servers/web_gui_server.py | 8 +- .../app/plugs/user-input-plug.component.html | 4 +- openhtf/plugs/__init__.py | 361 +++----- openhtf/plugs/cambrionix/__init__.py | 50 +- openhtf/plugs/device_wrapping.py | 37 +- openhtf/plugs/generic/serial_collection.py | 43 +- openhtf/plugs/usb/__init__.py | 35 +- openhtf/plugs/usb/adb_device.py | 47 +- openhtf/plugs/usb/adb_message.py | 61 +- 
openhtf/plugs/usb/adb_protocol.py | 188 +++-- openhtf/plugs/usb/fastboot_device.py | 33 +- openhtf/plugs/usb/fastboot_protocol.py | 104 ++- openhtf/plugs/usb/filesync_service.py | 94 ++- openhtf/plugs/usb/local_usb.py | 54 +- openhtf/plugs/usb/shell_service.py | 41 +- openhtf/plugs/usb/usb_exceptions.py | 30 +- openhtf/plugs/usb/usb_handle.py | 18 +- openhtf/plugs/usb/usb_handle_stub.py | 16 +- openhtf/plugs/user_input.py | 114 +-- openhtf/util/__init__.py | 94 ++- openhtf/util/argv.py | 73 +- openhtf/util/atomic_write.py | 16 +- openhtf/util/checkpoints.py | 34 +- openhtf/util/conf.py | 182 ++-- openhtf/util/console_output.py | 151 ++-- openhtf/util/data.py | 105 ++- openhtf/util/exceptions.py | 74 -- openhtf/util/functions.py | 44 +- openhtf/util/logs.py | 62 +- openhtf/util/multicast.py | 55 +- openhtf/util/test.py | 363 ++++++-- openhtf/util/threads.py | 48 +- openhtf/util/timeouts.py | 104 ++- openhtf/util/units.py | 11 +- openhtf/util/validators.py | 94 ++- openhtf/util/xmlrpcutil.py | 39 +- pylint_plugins/mutablerecords_plugin.py | 49 -- setup.py | 55 +- test/capture_source_test.py | 8 +- test/core/diagnoses_test.py | 85 +- test/core/exe_test.py | 778 +++++++++++++----- test/core/measurements_test.py | 92 ++- test/core/monitors_test.py | 36 +- test/core/phase_branches_test.py | 733 +++++++++++++++++ test/core/phase_collections_test.py | 763 +++++++++++++++++ test/core/phase_group_test.py | 252 +++--- test/core/test_descriptor_test.py | 11 +- test/core/test_record_test.py | 6 +- test/output/callbacks/callbacks_test.py | 20 +- .../callbacks/mfg_event_converter_test.py | 90 +- test/output/callbacks/mfg_inspector_test.py | 35 +- test/phase_descriptor_test.py | 182 ++-- test/plugs/plugs_test.py | 105 +-- test/plugs/user_input_test.py | 4 +- test/test_state_test.py | 34 +- test/util/conf_test.py | 32 +- test/util/data_test.py | 17 +- test/util/functions_test.py | 32 +- test/util/logs_test.py | 3 +- test/util/test_test.py | 23 +- test/util/util_test.py | 30 +- 
test/util/validators_test.py | 136 +-- tox.ini | 2 +- 102 files changed, 7697 insertions(+), 3741 deletions(-) create mode 100644 CHANGELOG create mode 100644 openhtf/core/base_plugs.py create mode 100644 openhtf/core/phase_branches.py create mode 100644 openhtf/core/phase_collections.py create mode 100644 openhtf/core/phase_nodes.py delete mode 100644 openhtf/util/exceptions.py delete mode 100644 pylint_plugins/mutablerecords_plugin.py create mode 100644 test/core/phase_branches_test.py create mode 100644 test/core/phase_collections_test.py diff --git a/.travis.yml b/.travis.yml index d83a2e628..8df6732a6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,6 @@ language: python matrix: include: - - python: 2.7 - env: TOXENV=py27 - python: 3.6 env: TOXENV=py36 addons: diff --git a/CHANGELOG b/CHANGELOG new file mode 100644 index 000000000..f9f5bff5b --- /dev/null +++ b/CHANGELOG @@ -0,0 +1,22 @@ +Changes for 2.0. + +* Dropped Python 2.x support. +* Added type annotations. +* Replaced mutablerecords with attrs. +* PhaseOptions: + * The openhtf.TestPhase alias for PhaseOptions has been deprecated for a + long time. Removing it. +* PhaseDescriptor: + * with_known_plugs and with_known_args are being rolled into with_plugs and + with_args, respectively. They will no longer raise exceptions when + the names are not found. + * If the options name field is a callable, the name property of + PhaseDescriptors will only return the name of the function rather than + the callable. This ensures that the name property is always Text. +* Test: + * The test teardown has been removed in favor of using a PhaseGroup. +* Unit testing: + * Unit tests using openhtf.util.test.TestCase can customize the test start + function when yielding openhtf.Test instances by setting the + `test_start_function` attribute. This can be set to None to remove the + function. 
diff --git a/bin/units_from_xls.py b/bin/units_from_xls.py index 1129c854d..36bc70188 100644 --- a/bin/units_from_xls.py +++ b/bin/units_from_xls.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """Read in a .xls file and generate a units module for OpenHTF. UNECE, the United Nations Economic Commission for Europe, publishes a set of @@ -39,23 +37,19 @@ spaces into underscores, and converting to uppercase. """ - import argparse import os -import shutil import re +import shutil import sys -from tempfile import mkstemp +import tempfile import six import xlrd - # Column names for the columns we care about. This list must be populated in # the expected order: [, , ]. -COLUMN_NAMES = ['Name', - 'Common\nCode', - 'Symbol'] +COLUMN_NAMES = ['Name', 'Common\nCode', 'Symbol'] PRE = '''# coding: utf-8 # THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT. 
@@ -100,7 +94,14 @@ import collections -UnitDescriptor = collections.namedtuple('UnitDescriptor', 'name code suffix') +class UnitDescriptor( + collections.namedtuple('UnitDescriptor', [ + 'name', + 'code', + 'suffix', + ])): + pass + ALL_UNITS = [] @@ -120,8 +121,10 @@ # pylint: enable=line-too-long + class UnitLookup(object): """Facilitates user-friendly access to units.""" + def __init__(self, lookup): self._lookup = lookup @@ -144,41 +147,42 @@ def __call__(self, name_or_suffix): ''' SHEET_NAME = 'Annex II & Annex III' -UNIT_KEY_REPLACEMENTS = {' ': '_', - ',' : '_', - '.': '_', - '-': '_', - '/': '_PER_', - '%': 'PERCENT', - '[': '', - ']': '', - '(': '', - ')': '', - "'": '', - '8': 'EIGHT', - '15': 'FIFTEEN', - '30': 'THIRTY', - '\\': '_', - six.unichr(160): '_', - six.unichr(176): 'DEG_', - six.unichr(186): 'DEG_', - six.unichr(8211): '_', - } +UNIT_KEY_REPLACEMENTS = { + ' ': '_', + ',': '_', + '.': '_', + '-': '_', + '/': '_PER_', + '%': 'PERCENT', + '[': '', + ']': '', + '(': '', + ')': '', + "'": '', + '8': 'EIGHT', + '15': 'FIFTEEN', + '30': 'THIRTY', + '\\': '_', + six.unichr(160): '_', # NO-BREAK SPACE + six.unichr(176): 'DEG_', # DEGREE SIGN + six.unichr(186): 'DEG_', # MASCULINE ORDINAL INDICATOR + six.unichr(8211): '_', # EN DASH +} def main(): """Main entry point for UNECE code .xls parsing.""" parser = argparse.ArgumentParser( description='Reads in a .xls file and generates a units module for ' - 'OpenHTF.', + 'OpenHTF.', prog='python units_from_xls.py') - parser.add_argument('xlsfile', type=str, - help='the .xls file to parse') + parser.add_argument('xlsfile', type=str, help='the .xls file to parse') parser.add_argument( '--outfile', type=str, - default=os.path.join(os.path.dirname(__file__), os.path.pardir, - 'openhtf','util', 'units.py'), + default=os.path.join( + os.path.dirname(__file__), os.path.pardir, 'openhtf', 'util', + 'units.py'), help='where to put the generated .py file.') args = parser.parse_args() @@ -188,14 +192,12 @@ def main(): 
sys.exit() unit_defs = unit_defs_from_sheet( - xlrd.open_workbook(args.xlsfile).sheet_by_name(SHEET_NAME), - COLUMN_NAMES) + xlrd.open_workbook(args.xlsfile).sheet_by_name(SHEET_NAME), COLUMN_NAMES) - _, tmp_path = mkstemp() + _, tmp_path = tempfile.mkstemp() with open(tmp_path, 'w') as new_file: new_file.write(PRE) - new_file.writelines( - [line.encode('utf8', 'replace') for line in unit_defs]) + new_file.writelines([line.encode('utf8', 'replace') for line in unit_defs]) new_file.write(POST) new_file.flush() @@ -209,14 +211,16 @@ def unit_defs_from_sheet(sheet, column_names): Args: sheet: An xldr.sheet object representing a UNECE code worksheet. column_names: A list/tuple with the expected column names corresponding to - the unit name, code and suffix in that order. - Yields: Lines of Python source code that define OpenHTF Unit objects. + the unit name, code and suffix in that order. + + Yields: + Lines of Python source code that define OpenHTF Unit objects. """ seen = set() try: col_indices = {} rows = sheet.get_rows() - + # Find the indices for the columns we care about. for idx, cell in enumerate(six.next(rows)): if cell.value in column_names: @@ -234,9 +238,9 @@ def unit_defs_from_sheet(sheet, column_names): # Split on ' or ' to support the units like '% or pct' for suffix in suffix.split(' or '): - yield "%s = UnitDescriptor('%s', '%s', '''%s''')\n" % ( - key, name, code, suffix) - yield "ALL_UNITS.append(%s)\n" % key + yield "%s = UnitDescriptor('%s', '%s', '''%s''')\n" % (key, name, code, + suffix) + yield 'ALL_UNITS.append(%s)\n' % key except xlrd.XLRDError: sys.stdout.write('Unable to process the .xls file.') diff --git a/docs/event_sequence.md b/docs/event_sequence.md index f18cbae1f..3d95b07cf 100644 --- a/docs/event_sequence.md +++ b/docs/event_sequence.md @@ -6,13 +6,58 @@ further: 1. `test_start`'s plugs are instantiated 1. `test_start` is run in a new thread 1. All plugs for the test are instantiated -1. Each phase is run -1. 
The teardown phase is run +1. Each phase node is run + 1. If the node is a subtest, each node is run until a FAIL_SUBTEST is + returned by a phase. + 1. If the node is a branch, each node is run if the condition is met + 1. If the node is a sequence, each node is run + 1. If the node is a group, each of the groups sequences is run + 1. If the node is a phase descriptor, that phase is run 1. All plugs' `tearDown` function is called 1. All plugs are deleted 1. Test outcome is calculated as PASS or FAIL 1. Output callbacks are called +## Phase node execution + +``` +[PhaseNode] + | + +--[PhaseDescriptor] + | + +--[Checkpoint] + | + \--[PhaseCollection] + | + +--[PhaseSequence] + | | + | +--[PhaseBranch] + | | + | \--[Subtest] + | + \--[PhaseGroup] +``` + +`PhaseNode`s are the basic building block for OpenHTF's phase execution. They +are a base class that defines a few basic operations that can get recursively +applied. The `PhaseDescriptor` is the primary executable unit that wraps the +phase functions. `PhaseCollection` is a base class for a node that contains +multiple nodes. The primary one of these is the `PhaseSequence`, which is a +tuple of phase nodes; each of those nodes is executed in order with nested +execution if those nodes are other collections. `PhaseBranch`s are phase +sequences that are only run when the Diagnosis Result-based conditions are met. +`Checkpoint` nodes check conditions, like phase failure or a triggered +diagnosis; if that condition is met, they act as a failed phase. `PhaseGroup`s +are phase collections that have three sequences as described below. + +### Recursive nesting + +Phase collections allow for nested nodes where each nested level is handled with +recursion. + +OpenHTF does not check for or handle the situation where a node is nested inside +itself. The current collection types are frozen to prevent this from happening. 
+ ## Test error short-circuiting A phase raising an exception won't kill the test, but will initiate a @@ -23,27 +68,52 @@ If a `test_start` phase is terminal, then the executor will skip to Plug Teardown, where only the plugs initialized for `test_start` have their `teardown` functions called. +In all cases with terminal phases, the Test outcome is ERROR for output +callbacks. + +### PhaseGroups + `PhaseGroup` collections behave like contexts. They are entered if their `setup` phases are all non-terminal; if this happens, the `teardown` phases are guarenteed to run. `PhaseGroup` collections can contain additional `PhaseGroup` instances. If a nested group has a terminal phase, the outer groups will trigger the same shortcut logic. -For terminal phases in a `PhaseGroup`, -* If the phase was a `PhaseGroup.setup` phase, then we skip the rest of the - `PhaseGroup`. -* If the phase was a `PhaseGroup.main` phase, then we skip to the - `PhaseGroup.teardown` phases of that `PhaseGroup`. -* If the phase was a `PhaseGroup.teardown` phase, the rest of the `teardown` - phases are run, but outer groups will trigger the shortcut logic. - -In all cases with terminal phases, the Test outcome is ERROR for output -callbacks. +For terminal phases (or phases that return `FAIL_SUBTEST`) in a `PhaseGroup`, +* If the phase was in the `setup` sequence, then we do not run the rest of + the `PhaseGroup`. +* If the phase was in the `main` sequence, then we do not run the rest of the + `main` sequence and proceed to the `teardown` sequence of that `PhaseGroup`. +* If the phase was in the `teardown` sequence, the rest of the `teardown` + sequence ndoes are run, but outer groups will trigger the shortcut logic. + This also applies to all nested phase nodes. NOTE: If a phase calls `os.abort()` or an equivalent to the C++ `die()` function, then the process dies and you cannot recover the results from this, so try to avoid such behavior in any Python or C++ libraries you use. 
+### Subtests + +`Subtest`s are Phase Sequences that allow phases to exit early, but continue on +with other phases. A phase can indicate this by returning +`htf.PhaseResult.FAIL_SUBTEST` or with a checkpoint with that result as its +action. The details of subtests are included in the output test record. + +The rest of the phases in a subtest after the failing node will be processed as: + +* Phase descriptors are all skipped. +* Branches are not run at all, as though their condition was evaluated as false. +* Groups entered after the failing node are entirely skipped, including their + `teardown` sequences. +* Groups with the failing node in its `main` sequence will skip the rest of the + `main` sequence, but will run the teardown phases. +* Groups with the failing node in its `setup` sequence will skip the rest of the + setup phases and will record skips for the `main` and `teardown` sequences. +* Groups with the failing node in its `teardown` sequence will still run the + rest of the `teardown` sequence. +* Sequences are recursively processed by these same rules. + +Phase group teardowns are run properly when nested in a subtest. ## Test abortion short-circuiting diff --git a/examples/all_the_things.py b/examples/all_the_things.py index 13c7fe6e7..0a1595382 100644 --- a/examples/all_the_things.py +++ b/examples/all_the_things.py @@ -11,17 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Example OpenHTF test logic. 
Run with (your virtualenv must be activated first): python all_the_things.py """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import os.path import time @@ -46,15 +41,15 @@ def example_monitor(example, frontend_aware): @htf.measures( - htf.Measurement( - 'widget_type').matches_regex(r'.*Widget$').doc( - '''This measurement tracks the type of widgets.'''), - htf.Measurement( - 'widget_color').doc('Color of the widget'), + htf.Measurement('widget_type').matches_regex(r'.*Widget$').doc( + """This measurement tracks the type of widgets."""), + htf.Measurement('widget_color').doc('Color of the widget'), htf.Measurement('widget_size').in_range(1, 4).doc('Size of widget')) -@htf.measures('specified_as_args', docstring='Helpful docstring', - units=units.HERTZ, - validators=[util.validators.matches_regex('Measurement')]) +@htf.measures( + 'specified_as_args', + docstring='Helpful docstring', + units=units.HERTZ, + validators=[util.validators.matches_regex('Measurement')]) @htf.plug(example=example_plugs.ExamplePlug) @htf.plug(prompts=user_input.UserInput) def hello_world(test, example, prompts): @@ -72,10 +67,9 @@ def hello_world(test, example, prompts): # Timeout if this phase takes longer than 10 seconds. 
-@htf.TestPhase(timeout_s=10) +@htf.PhaseOptions(timeout_s=10) @htf.measures( - *(htf.Measurement( - 'level_%s' % i) for i in ['none', 'some', 'all'])) + *(htf.Measurement('level_%s' % i) for i in ['none', 'some', 'all'])) @htf.monitors('monitor_measurement', example_monitor) def set_measurements(test): """Test phase that sets a measurement.""" @@ -95,25 +89,32 @@ def set_measurements(test): units.HERTZ, units.SECOND, htf.Dimension(description='my_angle', unit=units.RADIAN))) def dimensions(test): + """Phase with dimensioned measurements.""" for dim in range(5): test.measurements.dimensions[dim] = 1 << dim - for x, y, z in zip(list(range(1, 5)), list(range(21, 25)), list(range(101, 105))): + for x, y, z in zip( + list(range(1, 5)), list(range(21, 25)), list(range(101, 105))): test.measurements.lots_of_dims[x, y, z] = x + y + z @htf.measures( - htf.Measurement('replaced_min_only').in_range('{min}', 5, type=int), - htf.Measurement('replaced_max_only').in_range(0, '{max}', type=int), - htf.Measurement('replaced_min_max').in_range('{min}', '{max}', type=int), + htf.Measurement('replaced_min_only').in_range('{minimum}', 5, type=int), + htf.Measurement('replaced_max_only').in_range(0, '{maximum}', type=int), + htf.Measurement('replaced_min_max').in_range( + '{minimum}', '{maximum}', type=int), ) -def measures_with_args(test, min, max): +def measures_with_args(test, minimum, maximum): + """Phase with measurement with arguments.""" + del minimum # Unused. + del maximum # Unused. 
test.measurements.replaced_min_only = 1 test.measurements.replaced_max_only = 1 test.measurements.replaced_min_max = 1 def attachments(test): - test.attach('test_attachment', 'This is test attachment data.'.encode('utf-8')) + test.attach('test_attachment', + 'This is test attachment data.'.encode('utf-8')) test.attach_from_file( os.path.join(os.path.dirname(__file__), 'example_attachment.txt')) @@ -121,12 +122,12 @@ def attachments(test): assert test_attachment.data == b'This is test attachment data.' -@htf.TestPhase(run_if=lambda: False) -def skip_phase(test): +@htf.PhaseOptions(run_if=lambda: False) +def skip_phase(): """Don't run this phase.""" -def analysis(test): +def analysis(test): # pylint: disable=missing-function-docstring level_all = test.get_measurement('level_all') assert level_all.value == 9 test_attachment = test.get_attachment('test_attachment') @@ -136,7 +137,7 @@ def analysis(test): (1, 21, 101, 123), (2, 22, 102, 126), (3, 23, 103, 129), - (4, 24, 104, 132) + (4, 24, 104, 132), ] test.logger.info('Pandas datafram of lots_of_dims \n:%s', lots_of_dims.value.to_dataframe()) @@ -146,21 +147,29 @@ def teardown(test): test.logger.info('Running teardown') -if __name__ == '__main__': +def main(): test = htf.Test( htf.PhaseGroup.with_teardown(teardown)( hello_world, - set_measurements, dimensions, attachments, skip_phase, - measures_with_args.with_args(min=1, max=4), analysis, + set_measurements, + dimensions, + attachments, + skip_phase, + measures_with_args.with_args(minimum=1, maximum=4), + analysis, ), # Some metadata fields, these in particular are used by mfg-inspector, # but you can include any metadata fields. 
- test_name='MyTest', test_description='OpenHTF Example Test', + test_name='MyTest', + test_description='OpenHTF Example Test', test_version='1.0.0') - test.add_output_callbacks(callbacks.OutputToFile( - './{dut_id}.{metadata[test_name]}.{start_time_millis}.pickle')) - test.add_output_callbacks(json_factory.OutputToJSON( - './{dut_id}.{metadata[test_name]}.{start_time_millis}.json', indent=4)) + test.add_output_callbacks( + callbacks.OutputToFile( + './{dut_id}.{metadata[test_name]}.{start_time_millis}.pickle')) + test.add_output_callbacks( + json_factory.OutputToJSON( + './{dut_id}.{metadata[test_name]}.{start_time_millis}.json', + indent=4)) test.add_output_callbacks(console_summary.ConsoleSummary()) # Example of how to output to testrun protobuf format and save to disk then @@ -174,5 +183,8 @@ def teardown(test): # inspector.save_to_disk('./{dut_id}.{start_time_millis}.pb'), # inspector.upload()) - test.execute(test_start=user_input.prompt_for_test_start()) + + +if __name__ == '__main__': + main() diff --git a/examples/checkpoints.py b/examples/checkpoints.py index 572ac4935..3a6938a07 100644 --- a/examples/checkpoints.py +++ b/examples/checkpoints.py @@ -11,21 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- """Example OpenHTF test demonstrating use of checkpoints and measurements.""" import time import openhtf as htf +from examples import measurements as measurements_example from openhtf.output.callbacks import console_summary from openhtf.output.callbacks import json_factory from openhtf.util import checkpoints -from examples import measurements as measurements_example - -@htf.measures(htf.Measurement('fixed_time').in_range(0, 10).doc( - 'This is going to fail validation.').with_units(htf.units.SECOND)) +@htf.measures( + htf.Measurement('fixed_time').in_range( + 0, 10).doc('This is going to fail validation.').with_units( + htf.units.SECOND)) def failing_phase(test): # The 'outcome' of this measurement in the test_record result will be a FAIL # because its value fails the validator specified (0 <= 5 <= 10). @@ -36,12 +36,13 @@ def long_running_phase(test): # A long running phase could be something like a hardware burn-in. This # phase should not run if previous phases have failed, so we make sure # checkpoint phase is run right before this phase. - for i in range(10): + for _ in range(10): test.logger.info('Still running....') time.sleep(10) test.logger.info('Done with long_running_phase') -if __name__ == '__main__': + +def main(): # We instantiate our OpenHTF test with the phases we want to run as args. test = htf.Test( measurements_example.hello_phase, @@ -54,9 +55,7 @@ def long_running_phase(test): # In order to view the result of the test, we have to output it somewhere, # outputting to console is an easy way to do this. - test.add_output_callbacks( - console_summary.ConsoleSummary() - ) + test.add_output_callbacks(console_summary.ConsoleSummary()) # The complete summary is viable in json, including the measurements # included in measurements_example.lots_of_measurements. @@ -66,3 +65,7 @@ def long_running_phase(test): # Unlike hello_world.py, where we prompt for a DUT ID, here we'll just # use an arbitrary one. 
test.execute(test_start=lambda: 'MyDutId') + + +if __name__ == '__main__': + main() diff --git a/examples/example_plugs.py b/examples/example_plugs.py index 8fbe39eb5..afc0e728f 100644 --- a/examples/example_plugs.py +++ b/examples/example_plugs.py @@ -11,18 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Example plugs for OpenHTF.""" -import openhtf.plugs as plugs +from openhtf.core import base_plugs from openhtf.util import conf - -conf.declare('example_plug_increment_size', default_value=1, - description='increment constant for example plug.') +conf.declare( + 'example_plug_increment_size', + default_value=1, + description='increment constant for example plug.') -class ExamplePlug(plugs.BasePlug): # pylint: disable=no-init +class ExamplePlug(base_plugs.BasePlug): # pylint: disable=no-init """Example of a simple plug. This plug simply keeps a value and increments it each time increment() is @@ -79,7 +79,7 @@ def increment(self): return self.value - self.increment_size -class ExampleFrontendAwarePlug(plugs.FrontendAwareBasePlug): +class ExampleFrontendAwarePlug(base_plugs.FrontendAwareBasePlug): """Example of a simple frontend-aware plug. A frontend-aware plug is a plug that agrees to call self.notify_update() @@ -88,9 +88,10 @@ class ExampleFrontendAwarePlug(plugs.FrontendAwareBasePlug): plug's state in real time. 
See also: - - openhtf.plugs.FrontendAwareBasePlug - - openhtf.plugs.user_input.UserInput + - base_plugs.FrontendAwareBasePlug + - base_plugs.user_input.UserInput """ + def __init__(self): super(ExampleFrontendAwarePlug, self).__init__() self.value = 0 diff --git a/examples/frontend_example.py b/examples/frontend_example.py index d7d4d666d..fdd7ac457 100644 --- a/examples/frontend_example.py +++ b/examples/frontend_example.py @@ -11,16 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """Simple OpenHTF test which launches the web GUI client.""" import openhtf as htf -from openhtf.util import conf - from openhtf.output.servers import station_server from openhtf.output.web_gui import web_launcher from openhtf.plugs import user_input +from openhtf.util import conf @htf.measures(htf.Measurement('hello_world_measurement')) @@ -29,11 +26,15 @@ def hello_world(test): test.measurements.hello_world_measurement = 'Hello Again!' -if __name__ == '__main__': +def main(): conf.load(station_server_port='4444') with station_server.StationServer() as server: web_launcher.launch('http://localhost:4444') - for i in range(5): + for _ in range(5): test = htf.Test(hello_world) test.add_output_callbacks(server.publish_final_state) test.execute(test_start=user_input.prompt_for_test_start()) + + +if __name__ == '__main__': + main() diff --git a/examples/hello_world.py b/examples/hello_world.py index bbf5c679c..cff69afba 100644 --- a/examples/hello_world.py +++ b/examples/hello_world.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Example OpenHTF test logic. 
Run with (your virtualenv must be activated first): @@ -64,7 +63,7 @@ def hello_world(test): test.measurements.hello_world_measurement = 'Hello Again!' -if __name__ == '__main__': +def main(): # We instantiate our OpenHTF test with the phases we want to run as args. # Multiple phases would be passed as additional args, and additional # keyword arguments may be passed as well. See other examples for more @@ -87,3 +86,7 @@ def hello_world(test): # be set later (OpenHTF will raise an exception when the test completes if # a DUT ID has not been set). test.execute(test_start=user_input.prompt_for_test_start()) + + +if __name__ == '__main__': + main() diff --git a/examples/ignore_early_canceled_tests.py b/examples/ignore_early_canceled_tests.py index 0fc967fb3..a465ea478 100644 --- a/examples/ignore_early_canceled_tests.py +++ b/examples/ignore_early_canceled_tests.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Example of excluding certain test records from the output callbacks. In this case, we exclude tests which were aborted before a DUT ID was set, since @@ -25,8 +24,8 @@ """ import openhtf as htf -from openhtf.output.callbacks import json_factory from openhtf.core import test_record +from openhtf.output.callbacks import json_factory from openhtf.plugs import user_input from openhtf.util import console_output @@ -36,8 +35,8 @@ class CustomOutputToJSON(json_factory.OutputToJSON): def __call__(self, record): - if (record.outcome == test_record.Outcome.ABORTED - and record.dut_id == DEFAULT_DUT_ID): + if (record.outcome == test_record.Outcome.ABORTED and + record.dut_id == DEFAULT_DUT_ID): console_output.cli_print( 'Test was aborted at test start. 
Skipping output to JSON.') else: diff --git a/examples/measurements.py b/examples/measurements.py index b352f435b..23b044041 100644 --- a/examples/measurements.py +++ b/examples/measurements.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Example OpenHTF test demonstrating use of measurements. Run with (your virtualenv must be activated first): @@ -94,8 +93,10 @@ def lots_of_measurements(test): # measurement against some criteria, or specify additional information # describing the measurement. Validators can get quite complex, for more # details, see the validators.py example. -@htf.measures(htf.Measurement('validated_measurement').in_range(0, 10).doc( - 'This measurement is validated.').with_units(htf.units.SECOND)) +@htf.measures( + htf.Measurement('validated_measurement').in_range( + 0, + 10).doc('This measurement is validated.').with_units(htf.units.SECOND)) def measure_seconds(test): # The 'outcome' of this measurement in the test_record result will be a PASS # because its value passes the validator specified (0 <= 5 <= 10). @@ -107,10 +108,14 @@ def measure_seconds(test): # specify exactly one measurement with that decorator (ie. the first argument # must be a string containing the measurement name). If you want to specify # multiple measurements this way, you can stack multiple decorators. 
-@htf.measures('inline_kwargs', docstring='This measurement is declared inline!', - units=htf.units.HERTZ, validators=[validators.in_range(0, 10)]) +@htf.measures( + 'inline_kwargs', + docstring='This measurement is declared inline!', + units=htf.units.HERTZ, + validators=[validators.in_range(0, 10)]) @htf.measures('another_inline', docstring='Because why not?') def inline_phase(test): + """Phase that declares a measurements validators as a keyword argument.""" # This measurement will have an outcome of FAIL, because the set value of 15 # will not pass the 0 <= x <= 10 validator. test.measurements.inline_kwargs = 15 @@ -122,17 +127,18 @@ def inline_phase(test): # A multidim measurement including how to convert to a pandas dataframe and # a numpy array. -@htf.measures(htf.Measurement('power_time_series') - .with_dimensions('ms', 'V', 'A')) +@htf.measures( + htf.Measurement('power_time_series').with_dimensions('ms', 'V', 'A')) @htf.measures(htf.Measurement('average_voltage').with_units('V')) @htf.measures(htf.Measurement('average_current').with_units('A')) @htf.measures(htf.Measurement('resistance').with_units('ohm').in_range(9, 11)) def multdim_measurements(test): + """Phase with a multidimensional measurement.""" # Create some fake current and voltage over time data for t in range(10): resistance = 10 - voltage = 10 + 10.0*t - current = voltage/resistance + .01*random.random() + voltage = 10 + 10.0 * t + current = voltage / resistance + .01 * random.random() dimensions = (t, voltage, current) test.measurements['power_time_series'][dimensions] = 0 @@ -156,7 +162,7 @@ def multdim_measurements(test): test.measurements['average_current']) -if __name__ == '__main__': +def main(): # We instantiate our OpenHTF test with the phases we want to run as args. 
test = htf.Test(hello_phase, again_phase, lots_of_measurements, measure_seconds, inline_phase, multdim_measurements) @@ -172,3 +178,7 @@ def multdim_measurements(test): # Unlike hello_world.py, where we prompt for a DUT ID, here we'll just # use an arbitrary one. test.execute(test_start=lambda: 'MyDutId') + + +if __name__ == '__main__': + main() diff --git a/examples/phase_groups.py b/examples/phase_groups.py index e2d0eb373..0a9c95d6e 100644 --- a/examples/phase_groups.py +++ b/examples/phase_groups.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Example OpenHTF Phase Groups. PhaseGroups are used to control phase shortcutting due to terminal errors to @@ -56,11 +55,12 @@ def run_basic_group(): In this example, there are no terminal phases; all phases are run. """ - test = htf.Test(htf.PhaseGroup( - setup=[setup_phase], - main=[main_phase], - teardown=[teardown_phase], - )) + test = htf.Test( + htf.PhaseGroup( + setup=[setup_phase], + main=[main_phase], + teardown=[teardown_phase], + )) test.execute() @@ -71,11 +71,12 @@ def run_setup_error_group(): skipped. The PhaseGroup is not entered, so the teardown phases are also skipped. """ - test = htf.Test(htf.PhaseGroup( - setup=[error_setup_phase], - main=[main_phase], - teardown=[teardown_phase], - )) + test = htf.Test( + htf.PhaseGroup( + setup=[error_setup_phase], + main=[main_phase], + teardown=[teardown_phase], + )) test.execute() @@ -86,11 +87,12 @@ def run_main_error_group(): because the setup phases ran without error, so the teardown phases are run. The other main phase is skipped. 
""" - test = htf.Test(htf.PhaseGroup( - setup=[setup_phase], - main=[error_main_phase, main_phase], - teardown=[teardown_phase], - )) + test = htf.Test( + htf.PhaseGroup( + setup=[setup_phase], + main=[error_main_phase, main_phase], + teardown=[teardown_phase], + )) test.execute() @@ -111,9 +113,7 @@ def run_nested_groups(): htf.PhaseGroup.with_teardown(inner_teardown_phase)( inner_main_phase), ], - teardown=[teardown_phase] - ) - ) + teardown=[teardown_phase])) test.execute() @@ -132,8 +132,7 @@ def run_nested_error_groups(): main_phase, ], teardown=[teardown_phase], - ) - ) + )) test.execute() @@ -152,15 +151,18 @@ def run_nested_error_skip_unentered_groups(): main_phase, ], teardown=[teardown_phase], - ) - ) + )) test.execute() -if __name__ == '__main__': +def main(): run_basic_group() run_setup_error_group() run_main_error_group() run_nested_groups() run_nested_error_groups() run_nested_error_skip_unentered_groups() + + +if __name__ == '__main__': + main() diff --git a/examples/repeat.py b/examples/repeat.py index 89f7c216b..5dfd038d9 100644 --- a/examples/repeat.py +++ b/examples/repeat.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Example OpenHTF test logic. Run with (your virtualenv must be activated first): @@ -27,12 +26,12 @@ phase can be limited specifying a PhaseOptions.repeat_limit. 
""" -from __future__ import print_function import openhtf -import openhtf.plugs as plugs +from openhtf import plugs +from openhtf.core import base_plugs -class FailTwicePlug(plugs.BasePlug): +class FailTwicePlug(base_plugs.BasePlug): """Plug that fails twice raising an exception.""" def __init__(self): @@ -48,16 +47,16 @@ def run(self): return True -class FailAlwaysPlug(plugs.BasePlug): +class FailAlwaysPlug(base_plugs.BasePlug): """Plug that always returns False indicating failure.""" def __init__(self): self.count = 0 def run(self): - """Increments counter and returns False indicating failure""" + """Increments counter and returns False indicating failure.""" self.count += 1 - print("FailAlwaysPlug: Run number %s" % (self.count)) + print('FailAlwaysPlug: Run number %s' % (self.count)) return False @@ -66,15 +65,14 @@ def run(self): # returning PhaseResult.REPEAT to trigger a repeat. The phase will be run a # total of three times: two fails followed by a success @plugs.plug(test_plug=FailTwicePlug) -def phase_repeat(test, test_plug): +def phase_repeat(test_plug): try: test_plug.run() - - except: - print("Error in phase_repeat, will retry") + except: # pylint: disable=bare-except + print('Error in phase_repeat, will retry') return openhtf.PhaseResult.REPEAT - print("Completed phase_repeat") + print('Completed phase_repeat') # This phase demonstrates repeating a phase based upon a result returned from a @@ -83,13 +81,18 @@ def phase_repeat(test, test_plug): # limit the number of retries. 
@openhtf.PhaseOptions(repeat_limit=5) @plugs.plug(test_plug=FailAlwaysPlug) -def phase_repeat_with_limit(test, test_plug): +def phase_repeat_with_limit(test_plug): result = test_plug.run() if not result: - print("Invalid result in phase_repeat_with_limit, will retry") + print('Invalid result in phase_repeat_with_limit, will retry') return openhtf.PhaseResult.REPEAT -if __name__ == '__main__': + +def main(): test = openhtf.Test(phase_repeat, phase_repeat_with_limit) test.execute(test_start=lambda: 'RepeatDutID') + + +if __name__ == '__main__': + main() diff --git a/examples/stop_on_first_failure.py b/examples/stop_on_first_failure.py index 97f830d3e..2628b2961 100644 --- a/examples/stop_on_first_failure.py +++ b/examples/stop_on_first_failure.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Example OpenHTF stop_on_first_failure test executor option. This is feature is very useful when you do not want the test to continue @@ -26,21 +25,16 @@ conf.load(stop_on_first_failure=True) """ - import openhtf as htf - +from openhtf.output.callbacks import console_summary from openhtf.plugs import user_input from openhtf.util import conf # pylint: disable=unused-import from openhtf.util import validators -from openhtf.output.callbacks import console_summary @htf.measures('number_sum', validators=[validators.in_range(0, 5)]) def add_numbers_fails(test): - """Add numbers fails phase - - This phase will return a failed measurement number_sum. - """ + """Add numbers, but measurement number_sum fails.""" test.logger.info('Add numbers 2 and 4') number_sum = 2 + 4 test.measurements.number_sum = number_sum @@ -56,7 +50,7 @@ def hello_world(test): test.measurements.hello_world_measurement = 'Hello World!' 
-if __name__ == '__main__': +def main(): test = htf.Test(add_numbers_fails, hello_world) test.add_output_callbacks(console_summary.ConsoleSummary()) # Option 1: test.configure @@ -66,4 +60,8 @@ def hello_world(test): # to get same result # conf.load(stop_on_first_failure=True) - test.execute(test_start=user_input.prompt_for_test_start()) \ No newline at end of file + test.execute(test_start=user_input.prompt_for_test_start()) + + +if __name__ == '__main__': + main() diff --git a/examples/with_plugs.py b/examples/with_plugs.py index 4168d39e1..6c45c4a23 100644 --- a/examples/with_plugs.py +++ b/examples/with_plugs.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Example OpenHTF test logic. Run with (your virtualenv must be activated first): @@ -25,15 +24,14 @@ end up with the 4 phases you want. """ -from __future__ import print_function import subprocess import time import openhtf as htf -from openhtf import plugs +from openhtf.core import base_plugs -class PingPlug(plugs.BasePlug): +class PingPlug(base_plugs.BasePlug): """This plug simply does a ping against the host attribute.""" host = None @@ -45,15 +43,15 @@ def __init__(self): def _get_command(self, count): # Returns the commandline for pinging the host. return [ - 'ping', - '-c', - str(count), - self.host, + 'ping', + '-c', + str(count), + self.host, ] def run(self, count): command = self._get_command(count) - print("running: %s" % ' '.join(command)) + print('running: %s' % ' '.join(command)) return subprocess.call(command) @@ -75,17 +73,22 @@ class PingDnsB(PingPlug): # passed into the phase so each phase has a unique name. 
@htf.PhaseOptions(name='Ping-{pinger.host}-{count}') @htf.plug(pinger=PingPlug.placeholder) -@htf.measures( - 'total_time_{pinger.host}_{count}', - htf.Measurement('retcode').equals('{expected_retcode}', type=int) -) +@htf.measures('total_time_{pinger.host}_{count}', + htf.Measurement('retcode').equals('{expected_retcode}', type=int)) def test_ping(test, pinger, count, expected_retcode): """This tests that we can ping a host. The plug, pinger, is expected to be replaced at test creation time, so the placeholder property was used instead of the class directly. + + Args: + test: The test API. + pinger: pinger plug. + count: number of times to ping; filled in using with_args + expected_retcode: expected return code from pinging; filled in using + with_args. """ - del expected_retcode # Not used in the phase, only used by a measurement + del expected_retcode # Not used in the phase, only used by a measurement. start = time.time() retcode = pinger.run(count) elapsed = time.time() - start @@ -93,17 +96,17 @@ def test_ping(test, pinger, count, expected_retcode): test.measurements.retcode = retcode -if __name__ == '__main__': +def main(): # We instantiate our OpenHTF test with the phases we want to run as args. # We're going to use these these plugs to create all our phases using only 1 # written phase. ping_plugs = [ - PingGoogle, - PingDnsA, - PingDnsB, + PingGoogle, + PingDnsA, + PingDnsB, ] - + phases = [ test_ping.with_plugs(pinger=plug).with_args(count=2, expected_retcode=0) for plug in ping_plugs @@ -114,3 +117,7 @@ def test_ping(test, pinger, count, expected_retcode): # Unlike hello_world.py, where we prompt for a DUT ID, here we'll just # use an arbitrary one. 
test.execute(test_start=lambda: 'MyDutId') + + +if __name__ == '__main__': + main() diff --git a/openhtf/__init__.py b/openhtf/__init__.py index 0d84bd5b6..b148e72ff 100644 --- a/openhtf/__init__.py +++ b/openhtf/__init__.py @@ -11,17 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """The main OpenHTF entry point.""" -import pkg_resources import signal from openhtf import plugs from openhtf.core import phase_executor from openhtf.core import test_record +from openhtf.core.base_plugs import BasePlug from openhtf.core.diagnoses_lib import diagnose +from openhtf.core.diagnoses_lib import DiagnosesStore from openhtf.core.diagnoses_lib import Diagnosis from openhtf.core.diagnoses_lib import DiagnosisComponent from openhtf.core.diagnoses_lib import DiagPriority @@ -33,13 +32,22 @@ from openhtf.core.measurements import Measurement from openhtf.core.measurements import measures from openhtf.core.monitors import monitors +from openhtf.core.phase_branches import BranchSequence +from openhtf.core.phase_branches import DiagnosisCheckpoint +from openhtf.core.phase_branches import DiagnosisCondition +from openhtf.core.phase_branches import PhaseFailureCheckpoint +from openhtf.core.phase_collections import PhaseSequence +from openhtf.core.phase_collections import Subtest from openhtf.core.phase_descriptor import PhaseDescriptor from openhtf.core.phase_descriptor import PhaseOptions from openhtf.core.phase_descriptor import PhaseResult from openhtf.core.phase_group import PhaseGroup +from openhtf.core.phase_nodes import PhaseNode from openhtf.core.test_descriptor import Test from openhtf.core.test_descriptor import TestApi from openhtf.core.test_descriptor import TestDescriptor +from openhtf.core.test_record import PhaseRecord +from openhtf.core.test_record import TestRecord from openhtf.plugs import plug from openhtf.util import conf from 
openhtf.util import console_output @@ -47,9 +55,7 @@ from openhtf.util import functions from openhtf.util import logs from openhtf.util import units - -# TODO: TestPhase is used for legacy reasons and should be deprecated. -TestPhase = PhaseOptions # pylint: disable=invalid-name +import pkg_resources def get_version(): @@ -62,6 +68,7 @@ def get_version(): except pkg_resources.DistributionNotFound: return 'Unknown - Perhaps openhtf was not installed via setup.py or pip.' + __version__ = get_version() # Register signal handler to stop all tests on SIGINT. diff --git a/openhtf/core/base_plugs.py b/openhtf/core/base_plugs.py new file mode 100644 index 000000000..d1d47cd15 --- /dev/null +++ b/openhtf/core/base_plugs.py @@ -0,0 +1,192 @@ +# Copyright 2020 Google Inc. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The plugs module provides boilerplate for accessing hardware. + +Most tests require interaction with external hardware. This module provides +framework support for such interfaces, allowing for automatic setup and +teardown of the objects. + +A plug may be made "frontend-aware", allowing it, in conjunction with the +Station API, to update any frontends each time the plug's state changes. See +FrontendAwareBasePlug for more info. + +Example implementation of a plug: + + from openhtf import plugs + + class ExamplePlug(base_plugs.BasePlug): + '''A Plug that does nothing.''' + + def __init__(self): + print 'Instantiating %s!' 
% type(self).__name__ + + def DoSomething(self): + print '%s doing something!' % type(self).__name__ + + def tearDown(self): + # This method is optional. If implemented, it will be called at the end + # of the test. + print 'Tearing down %s!' % type(self).__name__ + +Example usage of the above plug: + + from openhtf import plugs + from my_custom_plugs_package import example + + @plugs.plug(example=example.ExamplePlug) + def TestPhase(test, example): + print 'Test phase started!' + example.DoSomething() + print 'Test phase done!' + +Putting all this together, when the test is run (with just that phase), you +would see the output (with other framework logs before and after): + + Instantiating ExamplePlug! + Test phase started! + ExamplePlug doing something! + Test phase done! + Tearing down ExamplePlug! + +Plugs will often need to use configuration values. The recommended way +of doing this is with the conf.inject_positional_args decorator: + + from openhtf import plugs + from openhtf.util import conf + + conf.declare('my_config_key', default_value='my_config_value') + + class ExamplePlug(base_plugs.BasePlug): + '''A plug that requires some configuration.''' + + @conf.inject_positional_args + def __init__(self, my_config_key) + self._my_config = my_config_key + +Note that Plug constructors shouldn't take any other arguments; the +framework won't pass any, so you'll get a TypeError. Any values that are only +known at run time must be either passed into other methods or set via explicit +setter methods. See openhtf/conf.py for details, but with the above +example, you would also need a configuration .yaml file with something like: + + my_config_key: my_config_value + +This will result in the ExamplePlug being constructed with +self._my_config having a value of 'my_config_value'. 
+""" + +import logging +from typing import Any, Dict, Set, Text, Type, Union + +import attr + +from openhtf import util + +_LOG = logging.getLogger(__name__) + + +class InvalidPlugError(Exception): + """Raised when a plug declaration or requested name is invalid.""" + + +class BasePlug(object): + """All plug types must subclass this type. + + Attributes: + logger: This attribute will be set by the PlugManager (and as such it + doesn't appear here), and is the same logger as passed into test phases + via TestApi. + """ + # Override this to True in subclasses to support remote Plug access. + enable_remote = False # type: bool + # Allow explicitly disabling remote access to specific attributes. + disable_remote_attrs = set() # type: Set[Text] + # Override this to True in subclasses to support using with_plugs with this + # plug without needing to use placeholder. This will only affect the classes + # that explicitly define this; subclasses do not share the declaration. + auto_placeholder = False # type: bool + # Default logger to be used only in __init__ of subclasses. + # This is overwritten both on the class and the instance so don't store + # a copy of it anywhere. + logger = _LOG # type: logging.Logger + + @util.classproperty + def placeholder(cls) -> 'PlugPlaceholder': # pylint: disable=no-self-argument + """Returns a PlugPlaceholder for the calling class.""" + return PlugPlaceholder(cls) + + def _asdict(self) -> Dict[Text, Any]: + """Returns a dictionary representation of this plug's state. + + This is called repeatedly during phase execution on any plugs that are in + use by that phase. The result is reported via the Station API by the + PlugManager (if the Station API is enabled, which is the default). 
+ + Note that this method is called in a tight loop, it is recommended that you + decorate it with functions.call_at_most_every() to limit the frequency at + which updates happen (pass a number of seconds to it to limit samples to + once per that number of seconds). + + You can also implement an `as_base_types` function that can return a dict + where the values must be base types at all levels. This can help prevent + recursive copying, which is time intensive. + + """ + return {} + + def tearDown(self) -> None: + """This method is called automatically at the end of each Test execution.""" + pass + + @classmethod + def uses_base_tear_down(cls) -> bool: + """Checks whether the tearDown method is the BasePlug implementation.""" + this_tear_down = getattr(cls, 'tearDown') + base_tear_down = getattr(BasePlug, 'tearDown') + return this_tear_down.__code__ is base_tear_down.__code__ + + +class FrontendAwareBasePlug(BasePlug, util.SubscribableStateMixin): + """A plug that notifies of any state updates. + + Plugs inheriting from this class may be used in conjunction with the Station + API to update any frontends each time the plug's state changes. The plug + should call notify_update() when and only when the state returned by _asdict() + changes. + + Since the Station API runs in a separate thread, the _asdict() method of + frontend-aware plugs should be written with thread safety in mind. + """ + enable_remote = True # type: bool + + +@attr.s(slots=True, frozen=True) +class PlugPlaceholder(object): + """Placeholder for a specific plug to be provided before test execution. + + Use the with_plugs() method to provide the plug before test execution. The + with_plugs() method checks to make sure the substitute plug is a subclass of + the PlugPlaceholder's base_class and BasePlug. 
+ """ + + base_class = attr.ib(type=Type[object]) + + +@attr.s(slots=True) +class PhasePlug(object): + """Information about the use of a plug in a phase.""" + + name = attr.ib(type=Text) + cls = attr.ib(type=Union[Type[BasePlug], PlugPlaceholder]) + update_kwargs = attr.ib(type=bool, default=True) diff --git a/openhtf/core/diagnoses_lib.py b/openhtf/core/diagnoses_lib.py index 080481cc3..8b29dc227 100644 --- a/openhtf/core/diagnoses_lib.py +++ b/openhtf/core/diagnoses_lib.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Lint as: python2, python3 +# Lint as: python3 """Diagnoses: Measurement and meta interpreters. Diagnoses are higher level signals that result from processing multiple @@ -125,14 +125,19 @@ def main(): import abc import collections import logging +from typing import Any, Callable, DefaultDict, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Text, Type, TYPE_CHECKING, Union import attr import enum # pylint: disable=g-bad-import-order from openhtf.core import phase_descriptor +from openhtf.core import test_record from openhtf.util import data import six from six.moves import collections_abc +if TYPE_CHECKING: + from openhtf.core import test_state # pylint: disable=g-import-not-at-top + class DiagnoserError(Exception): """Diagnoser was constructed incorrectly..""" @@ -151,19 +156,20 @@ class DiagnosesStore(object): """Storage and lookup of diagnoses.""" _diagnoses_by_results = attr.ib( - default=attr.Factory(dict)) # type: Dict['DiagResultEnum', 'Diagnosis'] - _diagnoses = attr.ib(default=attr.Factory(list)) # type: List['Diagnosis'] + type=Dict['DiagResultEnum', 'Diagnosis'], default=attr.Factory(dict)) + _diagnoses = attr.ib(type=List['Diagnosis'], default=attr.Factory(list)) - def _add_diagnosis(self, diagnosis): + def _add_diagnosis(self, diagnosis: 'Diagnosis') -> None: """Add a diagnosis to the store.""" self._diagnoses_by_results[diagnosis.result] = diagnosis 
self._diagnoses.append(diagnosis) - def has_diagnosis_result(self, diagnosis_result): + def has_diagnosis_result(self, diagnosis_result: 'DiagResultEnum') -> bool: """Returns if the diagnosis_result has been added.""" return diagnosis_result in self._diagnoses_by_results - def get_diagnosis(self, diagnosis_result): + def get_diagnosis( + self, diagnosis_result: 'DiagResultEnum') -> Optional['Diagnosis']: """Returns the latest diagnosis with the passed in result.""" return self._diagnoses_by_results.get(diagnosis_result) @@ -180,13 +186,14 @@ class DiagnosesManager(object): store = attr.ib( type=DiagnosesStore, default=attr.Factory(DiagnosesStore), init=False) - def _add_diagnosis(self, diagnosis): + def _add_diagnosis(self, diagnosis: 'Diagnosis') -> None: """Adds a diagnosis to the internal store.""" if self.store.has_diagnosis_result(diagnosis.result): self._logger.warning('Duplicate diagnosis result: %s', diagnosis) self.store._add_diagnosis(diagnosis) # pylint: disable=protected-access - def _verify_and_fix_diagnosis(self, diag, diagnoser): + def _verify_and_fix_diagnosis(self, diag: 'Diagnosis', + diagnoser: '_BaseDiagnoser') -> 'Diagnosis': if not isinstance(diag.result, diagnoser.result_type): raise InvalidDiagnosisError( 'Diagnoser {} returned different result then its result_type.'.format( @@ -195,7 +202,10 @@ def _verify_and_fix_diagnosis(self, diag, diagnoser): return attr.evolve(diag, is_failure=True) return diag - def _convert_result(self, diagnosis_or_diagnoses, diagnoser): + def _convert_result(self, + diagnosis_or_diagnoses: Union['Diagnosis', + Sequence['Diagnosis']], + diagnoser: '_BaseDiagnoser') -> Iterable['Diagnosis']: """Convert parameter into a list if a single Diagnosis.""" if not diagnosis_or_diagnoses: return @@ -215,14 +225,16 @@ def _convert_result(self, diagnosis_or_diagnoses, diagnoser): type(diag).__name__)) yield self._verify_and_fix_diagnosis(diag, diagnoser) - def execute_phase_diagnoser(self, diagnoser, phase_state, test_record): 
+ def execute_phase_diagnoser(self, diagnoser: 'BasePhaseDiagnoser', + phase_state: 'test_state.PhaseState', + test_rec: test_record.TestRecord) -> None: """Execute a phase diagnoser. Args: diagnoser: BasePhaseDiagnoser, the diagnoser to run for the given phase. phase_state: test_state.PhaseState, the current running phase state context. - test_record: test_record.TestRecord, the current running test's record. + test_rec: test_record.TestRecord, the current running test's record. """ diagnosis_or_diagnoses = diagnoser.run(phase_state.phase_record) for diag in self._convert_result(diagnosis_or_diagnoses, diagnoser): @@ -230,33 +242,36 @@ def execute_phase_diagnoser(self, diagnoser, phase_state, test_record): # Internal diagnosers are not saved to the test record because they are # not serialized. if not diag.is_internal: - test_record.add_diagnosis(diag) + test_rec.add_diagnosis(diag) self._add_diagnosis(diag) - def execute_test_diagnoser(self, diagnoser, test_record): + def execute_test_diagnoser(self, diagnoser: 'BaseTestDiagnoser', + test_rec: test_record.TestRecord) -> None: """Execute a test diagnoser. Args: diagnoser: TestDiagnoser, the diagnoser to run for the test. - test_record: test_record.TestRecord, the current running test's record. + test_rec: test_record.TestRecord, the current running test's record. Raises: InvalidDiagnosisError: when the diagnoser returns an Internal diagnosis. 
""" - diagnosis_or_diagnoses = diagnoser.run(test_record, self.store) + diagnosis_or_diagnoses = diagnoser.run(test_rec, self.store) for diag in self._convert_result(diagnosis_or_diagnoses, diagnoser): if diag.is_internal: raise InvalidDiagnosisError( 'Test-level diagnosis {} cannot be Internal'.format(diag)) - test_record.add_diagnosis(diag) + test_rec.add_diagnosis(diag) self._add_diagnosis(diag) -def check_for_duplicate_results(phase_iterator, test_diagnosers): +def check_for_duplicate_results( + phase_iterator: Iterator[phase_descriptor.PhaseDescriptor], + test_diagnosers: Sequence['BaseTestDiagnoser']) -> None: """Check for any results with the same enum value in different ResultTypes. Args: - phase_iterator: iterator over the phases to check; can be a PhaseGroup. + phase_iterator: iterator over the phases to check. test_diagnosers: list of test level diagnosers. Raises: @@ -270,7 +285,7 @@ def check_for_duplicate_results(phase_iterator, test_diagnosers): all_result_enums.add(test_diag.result_type) values_to_enums = collections.defaultdict( - list) # type: DefaultDict[str, Type['DiagResultEnum'] + list) # type: DefaultDict[str, Type['DiagResultEnum']] for enum_cls in all_result_enums: for entry in enum_cls: values_to_enums[entry.value].append(enum_cls) @@ -282,27 +297,30 @@ def check_for_duplicate_results(phase_iterator, test_diagnosers): result_value, enum_classes)) if not duplicates: return - duplicates.sort() raise DuplicateResultError('Duplicate DiagResultEnum values: {}'.format( '\n'.join(duplicates))) -def _check_diagnoser(diag, diagnoser_cls): +def _check_diagnoser(diagnoser: '_BaseDiagnoser', + diagnoser_cls: Type['_BaseDiagnoser']) -> None: """Check that a diagnoser is properly created.""" - if not isinstance(diag, diagnoser_cls): + if not isinstance(diagnoser, diagnoser_cls): raise DiagnoserError('Diagnoser "{}" is not a {}.'.format( - diag.__class__.__name__, diagnoser_cls.__name__)) - if not diag.result_type: + diagnoser.__class__.__name__, 
diagnoser_cls.__name__)) + if not diagnoser.result_type: raise DiagnoserError( - 'Diagnoser "{}" does not have a result_type set.'.format(diag.name)) - if not issubclass(diag.result_type, DiagResultEnum): + 'Diagnoser "{}" does not have a result_type set.'.format( + diagnoser.name)) + if not issubclass(diagnoser.result_type, DiagResultEnum): raise DiagnoserError( 'Diagnoser "{}" result_type "{}" does not inherit from ' - 'DiagResultEnum.'.format(diag.name, diag.result_type.__name__)) - diag._check_definition() # pylint: disable=protected-access + 'DiagResultEnum.'.format(diagnoser.name, + diagnoser.result_type.__name__)) + diagnoser._check_definition() # pylint: disable=protected-access -def check_diagnosers(diagnosers, diagnoser_cls): +def check_diagnosers(diagnosers: Sequence['_BaseDiagnoser'], + diagnoser_cls: Type['_BaseDiagnoser']) -> None: """Check if all the diagnosers are properly created. Args: @@ -310,8 +328,11 @@ def check_diagnosers(diagnosers, diagnoser_cls): diagnoser_cls: _BaseDiagnoser subclass that all the diagnosers are supposed to be derived from. """ - for diag in diagnosers: - _check_diagnoser(diag, diagnoser_cls) + for diagnoser in diagnosers: + _check_diagnoser(diagnoser, diagnoser_cls) + + +DiagnoserReturnT = Union[None, 'Diagnosis', List['Diagnosis']] @attr.s(slots=True) @@ -324,15 +345,15 @@ class _BaseDiagnoser(object): # The DiagResultEnum-derived enum for the possible results this diagnoser # instance can return. - result_type = attr.ib() # type: Type['DiagResultEnum'] + result_type = attr.ib(type=Type['DiagResultEnum']) # The descriptive name for this diagnoser instance. - name = attr.ib(type=str, default=None) # pylint: disable=g-ambiguous-str-annotation + name = attr.ib(type=Optional[Text], default=None) # If set, diagnoses from this diagnoser will always be marked as failures. 
always_fail = attr.ib(type=bool, default=False) - def as_base_types(self): + def as_base_types(self) -> Dict[Text, Any]: ret = { 'name': self.name, 'possible_results': self.possible_results, @@ -342,10 +363,10 @@ def as_base_types(self): return ret @property - def possible_results(self): - return [r.value for r in self.result_type] + def possible_results(self) -> List[Text]: + return [r.value for r in self.result_type] # pytype: disable=missing-parameter - def _check_definition(self): + def _check_definition(self) -> None: """Internal function to verify that the diagnoser is completely defined.""" pass @@ -356,7 +377,8 @@ class BasePhaseDiagnoser(six.with_metaclass(abc.ABCMeta, _BaseDiagnoser)): __slots__ = () @abc.abstractmethod - def run(self, phase_record): + def run(self, + phase_record: test_record.PhaseRecord) -> DiagnoserReturnT: """Must be implemented to return list of Diagnoses instances. Args:
- # type: Optional[Callable[[test_record.PhaseRecord], - # Union[None, 'Diagnosis', List['Diagnosis']]]] - _run_func = attr.ib(default=None) + _run_func = attr.ib( + type=Optional[Callable[[test_record.PhaseRecord], DiagnoserReturnT]], + default=None) - def __call__(self, func): + def __call__( + self, func: Callable[[test_record.PhaseRecord], DiagnoserReturnT] + ) -> 'PhaseDiagnoser': """Returns PhaseDiagnoser for the provided function.""" if self._run_func: raise DiagnoserError( @@ -387,11 +411,11 @@ def __call__(self, func): changes['name'] = func.__name__ return attr.evolve(self, **changes) - def run(self, phase_record): + def run(self, phase_record: test_record.PhaseRecord) -> DiagnoserReturnT: """Runs the phase diagnoser and returns the diagnoses.""" return self._run_func(phase_record) - def _check_definition(self): + def _check_definition(self) -> None: if not self._run_func: raise DiagnoserError( 'PhaseDiagnoser run function not defined for {}'.format(self.name)) @@ -403,7 +427,8 @@ class BaseTestDiagnoser(six.with_metaclass(abc.ABCMeta, _BaseDiagnoser)): __slots__ = () @abc.abstractmethod - def run(self, test_rec, diagnoses_store): + def run(self, test_rec: test_record.TestRecord, + diagnoses_store: DiagnosesStore) -> DiagnoserReturnT: """Must be implemented to return list of Diagnoses instances. Args: @@ -422,11 +447,15 @@ class TestDiagnoser(BaseTestDiagnoser): """Diagnoser definition for a Test using a function.""" # The function to run. Set with run_func in the initializer. 
- # type: Optional[Callable[[test_record.TestRecord, DiagnosesStore], - # Union[None, 'Diagnosis', List['Diagnosis']]]] - _run_func = attr.ib(default=None) - - def __call__(self, func): + _run_func = attr.ib( + type=Optional[Callable[[test_record.TestRecord, DiagnosesStore], + DiagnoserReturnT]], + default=None) + + def __call__( + self, func: Callable[[test_record.TestRecord, DiagnosesStore], + DiagnoserReturnT] + ) -> 'TestDiagnoser': """Returns TestDiagnoser for the provided function.""" if self._run_func: raise DiagnoserError( @@ -436,11 +465,12 @@ def __call__(self, func): changes['name'] = func.__name__ return attr.evolve(self, **changes) - def run(self, test_record, diagnoses_store): + def run(self, test_rec: test_record.TestRecord, + diagnoses_store: DiagnosesStore) -> DiagnoserReturnT: """Runs the test diagnoser and returns the diagnoses.""" - return self._run_func(test_record, diagnoses_store) + return self._run_func(test_rec, diagnoses_store) - def _check_definition(self): + def _check_definition(self) -> None: if not self._run_func: raise DiagnoserError( 'TestDiagnoser run function not defined for {}'.format(self.name)) @@ -453,7 +483,9 @@ class DiagResultEnum(str, enum.Enum): Users should subclass this enum to add their specific diagnoses. Separate subclasses should be used for unrelated diagnosis results. """ - pass + + def as_base_types(self) -> Text: + return self.value @enum.unique @@ -473,13 +505,13 @@ class DiagnosisComponent(object): """Component definition for a diagnosis.""" # Name of the component. - name = attr.ib(type=str) # pylint: disable=g-ambiguous-str-annotation + name = attr.ib(type=Text) # Unique identifier for the component, like a barcode or serial number. 
- identifier = attr.ib(type=str) # pylint: disable=g-ambiguous-str-annotation + identifier = attr.ib(type=Text) -def _diagnosis_serialize_filter(attribute, value): - return attribute.name not in ('is_failure', 'is_internal') or value +def _diagnosis_serialize_filter(attribute: attr.Attribute, value: Any) -> bool: + return attribute.name not in ('is_failure', 'is_internal') or value # pytype: disable=attribute-error @attr.s(slots=True, frozen=True) @@ -492,10 +524,10 @@ class Diagnosis(object): # Human readable description that gives more information about the failure and # possible what to do with it. - description = attr.ib(type=str, default='') # pylint: disable=g-ambiguous-str-annotation + description = attr.ib(type=Text, default='') # The component that is associated with this diagnosis. - component = attr.ib(type=DiagnosisComponent, default=None) + component = attr.ib(type=Optional[DiagnosisComponent], default=None) # The level of importance for the diagnosis. priority = attr.ib(type=DiagPriority, default=DiagPriority.NORMAL) @@ -508,21 +540,25 @@ class Diagnosis(object): # Diagnosers. 
is_internal = attr.ib(type=bool, default=False) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: if self.is_internal and self.is_failure: raise InvalidDiagnosisError('Internal diagnoses cannot be failures.') - def as_base_types(self): + def as_base_types(self) -> Dict[Text, Any]: return data.convert_to_base_types( attr.asdict(self, filter=_diagnosis_serialize_filter)) -def diagnose(*diagnosers): +def diagnose( + *diagnosers: BasePhaseDiagnoser +) -> Callable[[phase_descriptor.PhaseT], phase_descriptor.PhaseDescriptor]: """Decorator to add diagnosers to a PhaseDescriptor.""" check_diagnosers(diagnosers, BasePhaseDiagnoser) diags = list(diagnosers) - def decorate(wrapped_phase): + def decorate( + wrapped_phase: phase_descriptor.PhaseT + ) -> phase_descriptor.PhaseDescriptor: """Phase decorator to be returned.""" phase = phase_descriptor.PhaseDescriptor.wrap_or_copy(wrapped_phase) phase.diagnosers.extend(diags) diff --git a/openhtf/core/measurements.py b/openhtf/core/measurements.py index 8b99d6385..efc878db9 100644 --- a/openhtf/core/measurements.py +++ b/openhtf/core/measurements.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """Measurements for OpenHTF. Measurements in OpenHTF are used to represent values collected during a Test. 
@@ -60,25 +58,26 @@ def WidgetTestPhase(test): """ - import collections +import enum import functools import logging +from typing import Any, Callable, Dict, Iterator, List, Optional, Text, Tuple, Union import attr -import enum -import mutablerecords from openhtf import util from openhtf.core import diagnoses_lib from openhtf.core import phase_descriptor from openhtf.util import data -from openhtf.util import units +from openhtf.util import units as util_units from openhtf.util import validators import six try: - import pandas # pylint: disable=g-import-not-at-top + # pylint: disable=g-import-not-at-top + import pandas # pytype: disable=import-error + # pylint: enable=g-import-not-at-top except ImportError: pandas = None @@ -123,29 +122,29 @@ class _ConditionalValidator(object): result = attr.ib(type=diagnoses_lib.DiagResultEnum) # The validator to use when the result is present. - validator = attr.ib() # type: Callable[[Any], bool] + validator = attr.ib(type=Callable[[Any], bool]) - def as_base_types(self): + def as_base_types(self) -> Dict[Text, Any]: return dict( result=data.convert_to_base_types(self.result), validator=str(self.validator), ) - def with_args(self, **kwargs): + def with_args(self, **kwargs: Any) -> '_ConditionalValidator': if hasattr(self.validator, 'with_args'): - return _ConditionalValidator( - self.result, self.validator.with_args(**kwargs)) + return _ConditionalValidator(self.result, + self.validator.with_args(**kwargs)) return self -def _coordinates_len(coordinates): +def _coordinates_len(coordinates: Any) -> int: """Returns count of measurement coordinates. Treat single string as a single dimension. Args: - coordinates: any type, measurement coordinates - for multidimensional measurements. + coordinates: any type, measurement coordinates for multidimensional + measurements. 
""" if isinstance(coordinates, six.string_types): return 1 @@ -154,17 +153,19 @@ def _coordinates_len(coordinates): return 1 -class Measurement( # pylint: disable=no-init - mutablerecords.Record( - 'Measurement', ['name'], - {'units': None, 'dimensions': None, 'docstring': None, - '_notification_cb': None, - 'validators': list, - 'conditional_validators': list, - 'transform_fn': None, - 'outcome': Outcome.UNSET, - 'measured_value': None, - '_cached': None})): +UnitInputT = Union[Text, util_units.UnitDescriptor] +DimensionInputT = Union['Dimension', Text, util_units.UnitDescriptor] + + +class _MeasuredValueSentinel(enum.Enum): + UNINITIALIZED = 0 + + +_MEASURED_VALUE_UNINITIALIZED = _MeasuredValueSentinel.UNINITIALIZED + + +@attr.s(slots=True) +class Measurement(object): """Record encapsulating descriptive data for a measurement. This record includes an _asdict() method so it can be easily output. Output @@ -176,45 +177,88 @@ class Measurement( # pylint: disable=no-init docstring: Optional string describing this measurement. units: UOM code of the units for the measurement being taken. dimensions: Tuple of UOM codes for units of dimensions. + transform_fn: A function to apply to measurements as they are ingested. validators: List of callable validator objects to perform pass/fail checks. conditional_validators: List of _ConditionalValidator instances that are called when certain Diagnosis Results are present at the beginning of the associated phase. - transform_fn: A function to apply to measurements as they are ingested. - outcome: One of the Outcome() enumeration values, starting at UNSET. measured_value: An instance of MeasuredValue or DimensionedMeasuredValue containing the value(s) of this Measurement that have been set, if any. + notification_cb: An optional function to be called when the measurement is + set. + outcome: One of the Outcome() enumeration values, starting at UNSET. 
_cached: A cached dict representation of this measurement created initially during as_base_types and updated in place to save allocation time. """ - def __init__(self, name, **kwargs): - super(Measurement, self).__init__(name, **kwargs) - if 'measured_value' not in kwargs: + # Informational fields set during definition. + name = attr.ib(type=Text) + docstring = attr.ib(type=Optional[Text], default=None) + units = attr.ib(type=Optional[util_units.UnitDescriptor], default=None) + + # Fields set during definition that affect how the measurement gets set or + # validated, ordered by when they are used. + _dimensions = attr.ib(type=Optional[Tuple['Dimension', ...]], default=None) + _transform_fn = attr.ib(type=Optional[Callable[[Any], Any]], default=None) + validators = attr.ib(type=List[Callable[[Any], bool]], factory=list) + conditional_validators = attr.ib( + type=List[_ConditionalValidator], factory=list) + + # Fields set during runtime. + # measured_value needs to be initialized in the post init function if and only + # if it wasn't set during initialization. + _measured_value = attr.ib( + type=Union['MeasuredValue', 'DimensionedMeasuredValue'], default=None) + _notification_cb = attr.ib(type=Optional[Callable[[], None]], default=None) + outcome = attr.ib(type=Outcome, default=Outcome.UNSET) + + # Runtime cache to speed up conversions. 
+ _cached = attr.ib(type=Optional[Dict[Text, Any]], default=None) + + def __attrs_post_init__(self) -> None: + if self._measured_value is None: self._initialize_value() - def _initialize_value(self): - if self.measured_value and self.measured_value.is_value_set: + def _initialize_value(self) -> None: + """Initialize the measurement value.""" + if self._measured_value and self._measured_value.is_value_set: raise ValueError('Cannot update a Measurement once a value is set.') if self.dimensions: - self.measured_value = DimensionedMeasuredValue( + self._measured_value = DimensionedMeasuredValue( name=self.name, num_dimensions=len(self.dimensions), transform_fn=self.transform_fn) else: - self.measured_value = MeasuredValue( - name=self.name, - transform_fn=self.transform_fn) + self._measured_value = MeasuredValue( + name=self.name, transform_fn=self.transform_fn) - def __setattr__(self, name, value): - super(Measurement, self).__setattr__(name, value) - # When dimensions or transform_fn change, we may need to update our - # measured_value type. - if name in ['dimensions', 'transform_fn']: - self._initialize_value() + @property + def dimensions(self) -> Optional[Tuple['Dimension', ...]]: + return self._dimensions + + @dimensions.setter + def dimensions(self, value: Optional[Tuple['Dimension', ...]]) -> None: + self._dimensions = value + self._initialize_value() + + @property + def transform_fn(self) -> Optional[Callable[[Any], Any]]: + return self._transform_fn + + @transform_fn.setter + def transform_fn(self, value: Optional[Callable[[Any], Any]]) -> None: + self._transform_fn = value + self._initialize_value() - def __setstate__(self, state): + # TODO(arsharma): Create a common base class for the measured value types. + # Otherwise, pytype will require casting the type whenever one tries to use + # unique functions in those classes. 
+ @property + def measured_value(self) -> Any: + return self._measured_value + + def __setstate__(self, state: Dict[Text, Any]) -> None: """Set this record's state during unpickling. This override is necessary to ensure that the the _initialize_value check @@ -224,21 +268,23 @@ def __setstate__(self, state): state: internal state. """ # TODO(arsharma) Add unit tests for unpickling operations. - dimensions = state.pop('dimensions') - transform_fn = state.pop('transform_fn', None) + dimensions = state.pop('_dimensions') + transform_fn = state.pop('_transform_fn', None) - super(Measurement, self).__setstate__(state) - object.__setattr__(self, 'dimensions', dimensions) - object.__setattr__(self, 'transform_fn', transform_fn) + for name, value in state.items(): + setattr(self, name, value) + setattr(self, '_dimensions', dimensions) + setattr(self, '_transform_fn', transform_fn) - def set_notification_callback(self, notification_cb): + def set_notification_callback( + self, notification_cb: Optional[Callable[[], None]]) -> 'Measurement': """Set the notifier we'll call when measurements are set.""" self._notification_cb = notification_cb if not notification_cb and self.dimensions: - self.measured_value.notify_value_set = None + self._measured_value.notify_value_set = None return self - def notify_value_set(self): + def notify_value_set(self) -> None: if self.dimensions: self.outcome = Outcome.PARTIALLY_SET else: @@ -246,46 +292,47 @@ def notify_value_set(self): if self._notification_cb: self._notification_cb() - def doc(self, docstring): + def doc(self, docstring: Text) -> 'Measurement': """Set this Measurement's docstring, returns self for chaining.""" self.docstring = docstring return self - def _maybe_make_unit_desc(self, unit_desc): + def _maybe_make_unit_desc(self, + unit_desc: UnitInputT) -> util_units.UnitDescriptor: """Return the UnitDescriptor or convert a string to one.""" if isinstance(unit_desc, str) or unit_desc is None: - unit_desc = units.Unit(unit_desc) - 
if not isinstance(unit_desc, units.UnitDescriptor): - raise TypeError('Invalid units for measurement %s: %s' % (self.name, - unit_desc)) + unit_desc = util_units.Unit(unit_desc) + if not isinstance(unit_desc, util_units.UnitDescriptor): + raise TypeError('Invalid units for measurement %s: %s' % + (self.name, unit_desc)) return unit_desc - def _maybe_make_dimension(self, dimension): + def _maybe_make_dimension(self, dimension: DimensionInputT) -> 'Dimension': """Return a `measurements.Dimension` instance.""" # For backwards compatibility the argument can be either a Dimension, a - # string or a `units.UnitDescriptor`. + # string or a `util_units.UnitDescriptor`. if isinstance(dimension, Dimension): return dimension - if isinstance(dimension, units.UnitDescriptor): + if isinstance(dimension, util_units.UnitDescriptor): return Dimension.from_unit_descriptor(dimension) if isinstance(dimension, str): return Dimension.from_string(dimension) raise TypeError('Cannot convert {} to a dimension'.format(dimension)) - def with_units(self, unit_desc): + def with_units(self, unit_desc: UnitInputT) -> 'Measurement': """Declare the units for this Measurement, returns self for chaining.""" self.units = self._maybe_make_unit_desc(unit_desc) return self - def with_dimensions(self, *dimensions): + def with_dimensions(self, *dimensions: DimensionInputT) -> 'Measurement': """Declare dimensions for this Measurement, returns self for chaining.""" self.dimensions = tuple( self._maybe_make_dimension(dim) for dim in dimensions) self._cached = None return self - def with_validator(self, validator): + def with_validator(self, validator: Callable[[Any], bool]) -> 'Measurement': """Add a validator callback to this Measurement, chainable.""" if not callable(validator): raise ValueError('Validator must be callable', validator) @@ -293,7 +340,10 @@ def with_validator(self, validator): self._cached = None return self - def validate_on(self, result_to_validator_mapping): + def validate_on( + self, 
result_to_validator_mapping: Dict[diagnoses_lib.DiagResultEnum, + Callable[[Any], bool]] + ) -> 'Measurement': """Adds conditional validators. Note that results are added by the current phase after measurements are @@ -315,14 +365,14 @@ def validate_on(self, result_to_validator_mapping): self._cached = None return self - def with_precision(self, precision): + def with_precision(self, precision: int) -> 'Measurement': """Set a precision value to round results to.""" if not isinstance(precision, int): - raise TypeError('Precision must be specified as an int, not %s' % type( - precision)) + raise TypeError('Precision must be specified as an int, not %s' % + type(precision)) return self.with_transform(functools.partial(round, ndigits=precision)) - def with_transform(self, transform_fn): + def with_transform(self, transform_fn: Callable[[Any], Any]) -> 'Measurement': """Set the transform function.""" if not callable(transform_fn): raise TypeError('Transform function must be callable.') @@ -331,40 +381,43 @@ def with_transform(self, transform_fn): self.transform_fn = transform_fn return self - def with_args(self, **kwargs): + def with_args(self, **kwargs: Any) -> 'Measurement': """String substitution for names and docstrings.""" new_validators = [ v.with_args(**kwargs) if hasattr(v, 'with_args') else v for v in self.validators ] new_conditional_validators = [ - cv.with_args(**kwargs) for cv in self.conditional_validators] - return mutablerecords.CopyRecord( - self, name=util.format_string(self.name, kwargs), + cv.with_args(**kwargs) for cv in self.conditional_validators + ] + return data.attr_copy( + self, + name=util.format_string(self.name, kwargs), docstring=util.format_string(self.docstring, kwargs), validators=new_validators, conditional_validators=new_conditional_validators, - _cached=None, + cached=None, ) - def __getattr__(self, name): # pylint: disable=invalid-name + def __getattr__(self, name: Text) -> Any: """Support our default set of validators as direct 
attributes.""" # Don't provide a back door to validators.py private stuff accidentally. if name.startswith('_') or not validators.has_validator(name): - raise AttributeError("'%s' object has no attribute '%s'" % ( - type(self).__name__, name)) + raise AttributeError("'%s' object has no attribute '%s'" % + (type(self).__name__, name)) # Create a wrapper to invoke the attribute from within validators. - def _with_validator(*args, **kwargs): # pylint: disable=invalid-name + def _with_validator(*args, **kwargs): return self.with_validator( validators.create_validator(name, *args, **kwargs)) + return _with_validator - def validate(self): + def validate(self) -> 'Measurement': """Validate this measurement and update its 'outcome' field.""" # PASS if all our validators return True, otherwise FAIL. try: - if all(v(self.measured_value.value) for v in self.validators): + if all(v(self._measured_value.value) for v in self.validators): self.outcome = Outcome.PASS else: self.outcome = Outcome.FAIL @@ -378,7 +431,7 @@ def validate(self): if self._cached: self._cached['outcome'] = self.outcome.name - def as_base_types(self): + def as_base_types(self) -> Dict[Text, Any]: """Convert this measurement to a dict of basic types.""" if not self._cached: # Create the single cache file the first time this is called. 
@@ -398,13 +451,13 @@ def as_base_types(self): self._cached['units'] = data.convert_to_base_types(self.units) if self.docstring: self._cached['docstring'] = self.docstring - if self.measured_value.is_value_set: - self._cached['measured_value'] = self.measured_value.basetype_value() + if self._measured_value.is_value_set: + self._cached['measured_value'] = self._measured_value.basetype_value() return self._cached - def to_dataframe(self, columns=None): + def to_dataframe(self, columns: Any = None) -> Any: """Convert a multi-dim to a pandas dataframe.""" - if not isinstance(self.measured_value, DimensionedMeasuredValue): + if not isinstance(self._measured_value, DimensionedMeasuredValue): raise TypeError( 'Only a dimensioned measurement can be converted to a DataFrame') @@ -412,15 +465,13 @@ def to_dataframe(self, columns=None): columns = [d.name for d in self.dimensions] columns += [self.units.name if self.units else 'value'] - dataframe = self.measured_value.to_dataframe(columns) + dataframe = self._measured_value.to_dataframe(columns) return dataframe -class MeasuredValue( - mutablerecords.Record('MeasuredValue', ['name'], - {'transform_fn': None, 'stored_value': None, - 'is_value_set': False, '_cached_value': None})): +@attr.s(slots=True) +class MeasuredValue(object): """Class encapsulating actual values measured. Note that this is really just a value wrapper with some sanity checks. See @@ -439,27 +490,33 @@ class MeasuredValue( is set. 
""" - def __str__(self): + name = attr.ib(type=Text) + transform_fn = attr.ib(type=Optional[Callable[[Any], Any]], default=None) + stored_value = attr.ib(type=Optional[Any], default=None) + is_value_set = attr.ib(type=bool, default=False) + _cached_value = attr.ib(type=Optional[Any], default=None) + + def __str__(self) -> Text: return str(self.value) if self.is_value_set else 'UNSET' - def __eq__(self, other): + def __eq__(self, other: 'MeasuredValue') -> bool: return (type(self) == type(other) and self.name == other.name and # pylint: disable=unidiomatic-typecheck - self.is_value_set == other.is_value_set and - self.stored_value == other.stored_value) + self.is_value_set == other.is_value_set + and self.stored_value == other.stored_value) - def __ne__(self, other): + def __ne__(self, other: 'MeasuredValue') -> bool: return not self.__eq__(other) @property - def value(self): + def value(self) -> Any: if not self.is_value_set: raise MeasurementNotSetError('Measurement not yet set', self.name) return self.stored_value - def basetype_value(self): + def basetype_value(self) -> Any: return self._cached_value - def set(self, value): + def set(self, value: Any) -> None: """Set the value for this measurement, with some sanity checks.""" # Apply transform function if it is set. @@ -481,6 +538,7 @@ def set(self, value): self.is_value_set = True +@attr.s(slots=True) class Dimension(object): """Dimension for multi-dim Measurements. @@ -488,11 +546,12 @@ class Dimension(object): as a drop-in replacement for UnitDescriptor for backwards compatibility. 
""" - __slots__ = ['_description', '_unit', '_cached_dict'] + _description = attr.ib(type=Text, default='') + _unit = attr.ib( + type=util_units.UnitDescriptor, default=util_units.NO_DIMENSION) + _cached_dict = attr.ib(type=Dict[Text, Any], default=None) - def __init__(self, description='', unit=units.NO_DIMENSION): - self._description = description - self._unit = unit + def __attrs_post_init__(self) -> None: self._cached_dict = data.convert_to_base_types({ 'code': self.code, 'description': self.description, @@ -500,62 +559,59 @@ def __init__(self, description='', unit=units.NO_DIMENSION): 'suffix': self.suffix, }) - def __eq__(self, other): + def __eq__(self, other: 'Dimension') -> bool: return self.description == other.description and self.unit == other.unit - def __ne__(self, other): + def __ne__(self, other: 'Dimension') -> bool: return not self == other - def __repr__(self): + def __repr__(self) -> Text: return '<%s: %s>' % (type(self).__name__, self._asdict()) @classmethod - def from_unit_descriptor(cls, unit_desc): + def from_unit_descriptor(cls, + unit_desc: util_units.UnitDescriptor) -> 'Dimension': return cls(unit=unit_desc) @classmethod - def from_string(cls, string): + def from_string(cls, string: Text) -> 'Dimension': """Convert a string into a Dimension.""" # Note: There is some ambiguity as to whether the string passed is intended # to become a unit looked up by name or suffix, or a Dimension descriptor. 
- if string in units.UNITS_BY_ALL: - return cls(description=string, unit=units.Unit(string)) + if string in util_units.UNITS_BY_ALL: + return cls(description=string, unit=util_units.Unit(string)) else: return cls(description=string) @property - def description(self): + def description(self) -> Text: return self._description @property - def unit(self): + def unit(self) -> util_units.UnitDescriptor: return self._unit @property - def code(self): - """Provides backwards compatibility to `units.UnitDescriptor` api.""" + def code(self) -> Text: + """Provides backwards compatibility to `util_units.UnitDescriptor` api.""" return self._unit.code @property - def suffix(self): - """Provides backwards compatibility to `units.UnitDescriptor` api.""" + def suffix(self) -> Optional[Text]: + """Provides backwards compatibility to `util_units.UnitDescriptor` api.""" return self._unit.suffix @property - def name(self): - """Provides backwards compatibility to `units.UnitDescriptor` api.""" + def name(self) -> Text: + """Provides backwards compatibility to `util_units.UnitDescriptor` api.""" return self._description or self._unit.name - def _asdict(self): + def _asdict(self) -> Dict[Text, Any]: return self._cached_dict -class DimensionedMeasuredValue(mutablerecords.Record( - 'DimensionedMeasuredValue', ['name', 'num_dimensions'], - {'transform_fn': None, - 'notify_value_set': None, - 'value_dict': collections.OrderedDict, - '_cached_basetype_values': list})): +@attr.s(slots=True) +class DimensionedMeasuredValue(object): """Class encapsulating actual values measured. See the MeasuredValue class docstring for more info. This class provides a @@ -570,27 +626,36 @@ class DimensionedMeasuredValue(mutablerecords.Record( basetype_value. 
""" - def __str__(self): + name = attr.ib(type=Text) + num_dimensions = attr.ib(type=int) + + transform_fn = attr.ib(type=Optional[Callable[[Any], Any]], default=None) + notify_value_set = attr.ib(type=Optional[Callable[[], None]], default=None) + value_dict = attr.ib(type=Dict[Any, Any], factory=collections.OrderedDict) + _cached_basetype_values = attr.ib(type=List[Any], factory=list) + + def __str__(self) -> Text: return str(self.value) if self.is_value_set else 'UNSET' - def with_notify(self, notify_value_set): + def with_notify( + self, notify_value_set: Callable[[], None]) -> 'DimensionedMeasuredValue': self.notify_value_set = notify_value_set return self @property - def is_value_set(self): + def is_value_set(self) -> bool: return bool(self.value_dict) - def __iter__(self): # pylint: disable=invalid-name + def __iter__(self) -> Iterator[Any]: """Iterate over items, allows easy conversion to a dict.""" return iter(six.iteritems(self.value_dict)) - def __setitem__(self, coordinates, value): # pylint: disable=invalid-name + def __setitem__(self, coordinates: Any, value: Any) -> None: coordinates_len = _coordinates_len(coordinates) if coordinates_len != self.num_dimensions: raise InvalidDimensionsError( - 'Expected %s-dimensional coordinates, got %s' % (self.num_dimensions, - coordinates_len)) + 'Expected %s-dimensional coordinates, got %s' % + (self.num_dimensions, coordinates_len)) # Wrap single dimensions in a tuple so we can assume value_dict keys are # always tuples later. 
@@ -604,8 +669,8 @@ def __setitem__(self, coordinates, value): # pylint: disable=invalid-name self.name, coordinates, self.value_dict[coordinates], value) self._cached_basetype_values = None elif self._cached_basetype_values is not None: - self._cached_basetype_values.append(data.convert_to_base_types( - coordinates + (value,))) + self._cached_basetype_values.append( + data.convert_to_base_types(coordinates + (value,))) except TypeError as e: raise InvalidDimensionsError( 'Mutable objects cannot be used as measurement dimensions: ' + str(e)) @@ -619,7 +684,7 @@ def __setitem__(self, coordinates, value): # pylint: disable=invalid-name if self.notify_value_set: self.notify_value_set() - def __getitem__(self, coordinates): # pylint: disable=invalid-name + def __getitem__(self, coordinates: Any) -> Any: # Wrap single dimensions in a tuple so we can assume value_dict keys are # always tuples later. if self.num_dimensions == 1: @@ -627,7 +692,7 @@ def __getitem__(self, coordinates): # pylint: disable=invalid-name return self.value_dict[coordinates] @property - def value(self): + def value(self) -> List[Any]: """The values stored in this record. 
Raises: @@ -640,17 +705,19 @@ def value(self): """ if not self.is_value_set: raise MeasurementNotSetError('Measurement not yet set', self.name) - return [dimensions + (value,) for dimensions, value in - six.iteritems(self.value_dict)] + return [ + dimensions + (value,) + for dimensions, value in six.iteritems(self.value_dict) + ] - def basetype_value(self): + def basetype_value(self) -> List[Any]: if self._cached_basetype_values is None: self._cached_basetype_values = list( data.convert_to_base_types(coordinates + (value,)) for coordinates, value in six.iteritems(self.value_dict)) return self._cached_basetype_values - def to_dataframe(self, columns=None): + def to_dataframe(self, columns: Any = None) -> Any: """Converts to a `pandas.DataFrame`.""" if not self.is_value_set: raise ValueError('Value must be set before converting to a DataFrame.') @@ -659,7 +726,8 @@ def to_dataframe(self, columns=None): return pandas.DataFrame.from_records(self.value, columns=columns) -class Collection(mutablerecords.Record('Collection', ['_measurements'])): +@attr.s(slots=True) +class Collection(object): """Encapsulates a collection of measurements. 
This collection can have measurement values retrieved and set via getters and @@ -703,54 +771,68 @@ class Collection(mutablerecords.Record('Collection', ['_measurements'])): # [(5, 10), (6, 11)] """ - def _assert_valid_key(self, name): + _measurements = attr.ib(type=Dict[Text, Measurement]) + + def _assert_valid_key(self, name: Text) -> None: """Raises if name is not a valid measurement.""" if name not in self._measurements: raise NotAMeasurementError('Not a measurement', name, self._measurements) - def __iter__(self): # pylint: disable=invalid-name + def __iter__(self) -> Iterator[Tuple[Text, Any]]: """Extract each MeasurementValue's value.""" return ((key, meas.measured_value.value) for key, meas in six.iteritems(self._measurements)) - def __setattr__(self, name, value): # pylint: disable=invalid-name + def _custom_setattr(self, name: Text, value: Any) -> None: + if name == '_measurements': + object.__setattr__(self, name, value) + return self[name] = value - def __getattr__(self, name): # pylint: disable=invalid-name + def __getattr__(self, name: Text) -> Any: return self[name] - def __setitem__(self, name, value): # pylint: disable=invalid-name + def __setitem__(self, name: Text, value: Any) -> None: self._assert_valid_key(name) - if self._measurements[name].dimensions: + m = self._measurements[name] + if m.dimensions: raise InvalidDimensionsError( 'Cannot set dimensioned measurement without indices') - self._measurements[name].measured_value.set(value) - self._measurements[name].notify_value_set() + m.measured_value.set(value) + m.notify_value_set() - def __getitem__(self, name): # pylint: disable=invalid-name + def __getitem__(self, name: Text) -> Any: self._assert_valid_key(name) - if self._measurements[name].dimensions: - return self._measurements[name].measured_value.with_notify( - self._measurements[name].notify_value_set) + m = self._measurements[name] + if m.dimensions: + return m.measured_value.with_notify(m.notify_value_set) # Return the 
MeasuredValue's value, MeasuredValue will raise if not set. - return self._measurements[name].measured_value.value + return m.measured_value.value + +# Work around for attrs bug in 20.1.0; after the next release, this can be +# removed and `Collection._custom_setattr` can be renamed to `__setattr__`. +# https://github.com/python-attrs/attrs/issues/680 +Collection.__setattr__ = Collection._custom_setattr # pylint: disable=protected-access +del Collection._custom_setattr -def measures(*measurements, **kwargs): +def measures( + *measurements: Union[Text, Measurement], **kwargs: Any +) -> Callable[[phase_descriptor.PhaseT], phase_descriptor.PhaseDescriptor]: """Decorator-maker used to declare measurements for phases. See the measurements module docstring for examples of usage. Args: *measurements: Measurement objects to declare, or a string name from which - to create a Measurement. + to create a Measurement. **kwargs: Keyword arguments to pass to Measurement constructor if we're - constructing one. Note that if kwargs are provided, the length - of measurements must be 1, and that value must be a string containing - the measurement name. For valid kwargs, see the definition of the - Measurement class. + constructing one. Note that if kwargs are provided, the length of + measurements must be 1, and that value must be a string containing the + measurement name. For valid kwargs, see the definition of the Measurement + class. Raises: InvalidMeasurementTypeError: When the measurement is not defined correctly. @@ -758,7 +840,8 @@ def measures(*measurements, **kwargs): Returns: A decorator that declares the measurement(s) for the decorated phase. 
""" - def _maybe_make(meas): + + def _maybe_make(meas: Union[Text, Measurement]) -> Measurement: """Turn strings into Measurement objects if necessary.""" if isinstance(meas, Measurement): return meas @@ -779,14 +862,18 @@ def _maybe_make(meas): measurements = [_maybe_make(meas) for meas in measurements] # 'measurements' is guaranteed to be a list of Measurement objects here. - def decorate(wrapped_phase): + def decorate( + wrapped_phase: phase_descriptor.PhaseT + ) -> phase_descriptor.PhaseDescriptor: """Phase decorator to be returned.""" phase = phase_descriptor.PhaseDescriptor.wrap_or_copy(wrapped_phase) - duplicate_names = (set(m.name for m in measurements) & - set(m.name for m in phase.measurements)) + duplicate_names = ( + set(m.name for m in measurements) + & set(m.name for m in phase.measurements)) if duplicate_names: raise DuplicateNameError('Measurement names duplicated', duplicate_names) phase.measurements.extend(measurements) return phase + return decorate diff --git a/openhtf/core/monitors.py b/openhtf/core/monitors.py index 22519a831..2e301b1cf 100644 --- a/openhtf/core/monitors.py +++ b/openhtf/core/monitors.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -"""Monitors provide a mechanism for periodically collecting data and -automatically persisting values in a measurement. +"""Monitors provide a mechanism for periodically collecting a measurement. Monitors are implemented similar to phase functions - they are decorated with plugs.plug() to pass plugs in. 
The return value of a monitor @@ -48,37 +46,52 @@ def MyPhase(test): import functools import inspect import time +from typing import Any, Callable, Dict, Optional, Text import openhtf from openhtf import plugs from openhtf.core import measurements +from openhtf.core import phase_descriptor +from openhtf.core import test_state as core_test_state from openhtf.util import threads from openhtf.util import units as uom +import six class _MonitorThread(threads.KillableThread): + """Background thread that runs a monitor.""" daemon = True - def __init__(self, measurement_name, monitor_desc, extra_kwargs, test_state, - interval_ms): - super(_MonitorThread, self).__init__( - name='%s_MonitorThread' % measurement_name) + def __init__(self, measurement_name: Text, + monitor_desc: phase_descriptor.PhaseDescriptor, + extra_kwargs: Dict[Any, Any], + test_state: core_test_state.TestState, interval_ms: int): + super(_MonitorThread, + self).__init__(name='%s_MonitorThread' % measurement_name) self.measurement_name = measurement_name self.monitor_desc = monitor_desc self.test_state = test_state self.interval_ms = interval_ms self.extra_kwargs = extra_kwargs - def get_value(self): - arg_info = inspect.getargspec(self.monitor_desc.func) - if arg_info.keywords: + def get_value(self) -> Any: + if six.PY3: + argspec = inspect.getfullargspec(self.monitor_desc.func) + argspec_args = argspec.args + argspec_keywords = argspec.varkw + else: + argspec = inspect.getargspec(self.monitor_desc.func) # pylint: disable=deprecated-method + argspec_args = argspec.args + argspec_keywords = argspec.keywords + if argspec_keywords: # Monitor phase takes **kwargs, so just pass everything in. kwargs = self.extra_kwargs else: # Only pass in args that the monitor phase takes. 
- kwargs = {arg: val for arg, val in self.extra_kwargs - if arg in arg_info.args} + kwargs = { + arg: val for arg, val in self.extra_kwargs if arg in argspec_args + } return self.monitor_desc.with_args(**kwargs)(self.test_state) def _thread_proc(self): @@ -120,17 +133,25 @@ def _take_sample(): mean_sample_ms = ((9 * mean_sample_ms) + cur_sample_ms) / 10.0 -def monitors(measurement_name, monitor_func, units=None, poll_interval_ms=1000): +def monitors( + measurement_name: Text, + monitor_func: phase_descriptor.PhaseT, + units: Optional[uom.UnitDescriptor] = None, + poll_interval_ms: int = 1000 +) -> Callable[[phase_descriptor.PhaseT], phase_descriptor.PhaseDescriptor]: + """Returns a decorator that wraps a phase with a monitor.""" monitor_desc = openhtf.PhaseDescriptor.wrap_or_copy(monitor_func) - def wrapper(phase_func): + + def wrapper( + phase_func: phase_descriptor.PhaseT) -> phase_descriptor.PhaseDescriptor: phase_desc = openhtf.PhaseDescriptor.wrap_or_copy(phase_func) # Re-key this dict so we don't have to worry about collisions with # plug.plug() decorators on the phase function. Since we aren't # updating kwargs here, we don't have to worry about collisions with # kwarg names. - monitor_plugs = {('_' * idx) + measurement_name + '_monitor': plug.cls for - idx, plug in enumerate(monitor_desc.plugs, start=1)} + monitor_plugs = {('_' * idx) + measurement_name + '_monitor': plug.cls + for idx, plug in enumerate(monitor_desc.plugs, start=1)} @openhtf.PhaseOptions(requires_state=True) @plugs.plug(update_kwargs=False, **monitor_plugs) @@ -140,14 +161,16 @@ def wrapper(phase_func): @functools.wraps(phase_desc.func) def monitored_phase_func(test_state, *args, **kwargs): # Start monitor thread, it will run monitor_desc periodically. 
- monitor_thread = _MonitorThread( - measurement_name, monitor_desc, phase_desc.extra_kwargs, test_state, - poll_interval_ms) + monitor_thread = _MonitorThread(measurement_name, monitor_desc, + phase_desc.extra_kwargs, test_state, + poll_interval_ms) monitor_thread.start() try: return phase_desc(test_state, *args, **kwargs) finally: monitor_thread.kill() monitor_thread.join() + return monitored_phase_func + return wrapper diff --git a/openhtf/core/phase_branches.py b/openhtf/core/phase_branches.py new file mode 100644 index 000000000..f98333c9f --- /dev/null +++ b/openhtf/core/phase_branches.py @@ -0,0 +1,272 @@ +# Copyright 2020 Google Inc. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Implements phase node branches. + +A BranchSequence is a phase node sequence that runs conditiionally based on the +diagnosis results of the test run. 
+""" + +import abc +from typing import Any, Callable, Dict, Iterator, Text, Tuple, TYPE_CHECKING, Union + +import attr +import enum # pylint: disable=g-bad-import-order +from openhtf import util +from openhtf.core import diagnoses_lib +from openhtf.core import phase_collections +from openhtf.core import phase_descriptor +from openhtf.core import phase_nodes +from openhtf.core import test_record +from openhtf.util import data +import six + +if TYPE_CHECKING: + from openhtf.core import test_state # pylint: disable=g-import-not-at-top + + +class NoPhasesFoundError(Exception): + """No phases were found in the test record.""" + + +class ConditionOn(enum.Enum): + ALL = 'ALL' + ANY = 'ANY' + NOT_ANY = 'NOT_ANY' + NOT_ALL = 'NOT_ALL' + + +class PreviousPhases(enum.Enum): + + # Check the immediately previous phase. + LAST = 'LAST' + + # Check all phases. + ALL = 'ALL' + + +def _not_any(iterable: Iterator[bool]) -> bool: + return not any(iterable) + + +def _not_all(iterable: Iterator[bool]) -> bool: + return not all(iterable) + + +_CONDITION_LOOKUP = { + ConditionOn.ALL: all, + ConditionOn.ANY: any, + ConditionOn.NOT_ANY: _not_any, + ConditionOn.NOT_ALL: _not_all, +} + + +@attr.s(slots=True, frozen=True) +class DiagnosisCondition(object): + """Encapsulated object for evaulating DiagResultEnum conditions.""" + + # Indicates the diagnosis is tested. + condition = attr.ib(type=ConditionOn) + + # The diagnosis results to test on. 
+ diagnosis_results = attr.ib(type=Tuple[diagnoses_lib.DiagResultEnum, ...]) + + @classmethod + def on_all(cls, *diags: diagnoses_lib.DiagResultEnum) -> 'DiagnosisCondition': + return cls(condition=ConditionOn.ALL, diagnosis_results=tuple(diags)) + + @classmethod + def on_any(cls, *diags: diagnoses_lib.DiagResultEnum) -> 'DiagnosisCondition': + return cls(condition=ConditionOn.ANY, diagnosis_results=tuple(diags)) + + @classmethod + def on_not_all(cls, + *diags: diagnoses_lib.DiagResultEnum) -> 'DiagnosisCondition': + return cls(condition=ConditionOn.NOT_ALL, diagnosis_results=tuple(diags)) + + @classmethod + def on_not_any(cls, + *diags: diagnoses_lib.DiagResultEnum) -> 'DiagnosisCondition': + return cls(condition=ConditionOn.NOT_ANY, diagnosis_results=tuple(diags)) + + def check(self, diag_store: diagnoses_lib.DiagnosesStore) -> bool: + condition_func = _CONDITION_LOOKUP[self.condition] + return condition_func( + diag_store.has_diagnosis_result(d) for d in self.diagnosis_results) + + def _asdict(self) -> Dict[Text, Any]: + """Returns a base type dictionary for serialization.""" + return { + 'condition': self.condition, + 'diagnosis_results': list(self.diagnosis_results), + } + + @property + def message(self) -> Text: + return '{}{}'.format(self.condition, self.diagnosis_results) + + +@attr.s(slots=True, frozen=True, init=False) +class BranchSequence(phase_collections.PhaseSequence): + """A node that collects phase sequence that conditionally run. + + This object is immutable. 
+ """ + + diag_condition = attr.ib(type=DiagnosisCondition, default=None) + + def __init__(self, diag_condition: DiagnosisCondition, + *args: phase_collections.SequenceInitializerT, **kwargs: Any): + super(BranchSequence, self).__init__(*args, **kwargs) + object.__setattr__(self, 'diag_condition', diag_condition) + + def _asdict(self) -> Dict[Text, Any]: + """Returns a base type dictionary for serialization.""" + ret = super(BranchSequence, self)._asdict() # type: Dict[Text, Any] + ret.update(diag_condition=self.diag_condition._asdict()) + return ret + + def should_run(self, diag_store: diagnoses_lib.DiagnosesStore) -> bool: + return self.diag_condition.check(diag_store) + + +@attr.s(slots=True, frozen=True) +class Checkpoint(six.with_metaclass(abc.ABCMeta, phase_nodes.PhaseNode)): + """Nodes that check for phase failures or if diagnoses were triggered. + + When the condition for a checkpoint is triggered, a STOP or FAIL_SUBTEST + result is handled by the TestExecutor. + """ + + name = attr.ib(type=Text) + action = attr.ib( + type=phase_descriptor.PhaseResult, + validator=attr.validators.in_([ + phase_descriptor.PhaseResult.STOP, + phase_descriptor.PhaseResult.FAIL_SUBTEST + ]), + default=phase_descriptor.PhaseResult.STOP) + + def _asdict(self) -> Dict[Text, Any]: + return { + 'name': self.name, + 'action': self.action, + } + + def with_args(self, **kwargs: Any) -> 'Checkpoint': + return data.attr_copy(self, name=util.format_string(self.name, kwargs)) + + def with_plugs(self, **subplugs: Any) -> 'Checkpoint': + return data.attr_copy(self, name=util.format_string(self.name, subplugs)) + + def load_code_info(self) -> 'Checkpoint': + return self + + def apply_to_all_phases( + self, func: Callable[[phase_descriptor.PhaseDescriptor], + phase_descriptor.PhaseDescriptor] + ) -> 'Checkpoint': + return self + + def get_result( + self, running_test_state: 'test_state.TestState' + ) -> phase_descriptor.PhaseReturnT: + if self._check_for_action(running_test_state): + return 
self.action + return phase_descriptor.PhaseResult.CONTINUE + + @abc.abstractmethod + def _check_for_action(self, + running_test_state: 'test_state.TestState') -> bool: + """Returns True when the action should be taken.""" + + @abc.abstractmethod + def record_conditional(self) -> Union[PreviousPhases, DiagnosisCondition]: + """Returns the conditional record data.""" + + +@attr.s(slots=True, frozen=True) +class PhaseFailureCheckpoint(Checkpoint): + """Node that checks if a previous phase or all previous phases failed. + + If the phases fail, this will be resolved as `action`. + + When using `all_previous`, this will take in to account all phases; it will + *NOT* limit itself to the subtest when using the FAIL_SUBTEST action. + """ + + previous_phases_to_check = attr.ib( + type=PreviousPhases, default=PreviousPhases.ALL) + + @classmethod + def last(cls, *args, **kwargs) -> 'PhaseFailureCheckpoint': + """Checking that takes action when the last phase fails.""" + kwargs['previous_phases_to_check'] = PreviousPhases.LAST + return cls(*args, **kwargs) + + @classmethod + def all_previous(cls, *args, **kwargs) -> 'PhaseFailureCheckpoint': + kwargs['previous_phases_to_check'] = PreviousPhases.ALL + return cls(*args, **kwargs) + + def _asdict(self) -> Dict[Text, Any]: + ret = super(PhaseFailureCheckpoint, self)._asdict() + ret.update(previous_phases_to_check=self.previous_phases_to_check) + return ret + + def _phase_failed(self, phase_rec: test_record.PhaseRecord) -> bool: + """Returns True if the phase_rec failed; ignores ERRORs.""" + return phase_rec.outcome == test_record.PhaseOutcome.FAIL + + def _check_for_action(self, + running_test_state: 'test_state.TestState') -> bool: + """Returns True when the specific set of phases fail.""" + phase_records = running_test_state.test_record.phases + if not phase_records: + raise NoPhasesFoundError('No phases found in the test record.') + if self.previous_phases_to_check == PreviousPhases.LAST: + return 
self._phase_failed(phase_records[-1]) + else: + for phase_rec in phase_records: + if self._phase_failed(phase_rec): + return True + return False + + def record_conditional(self) -> PreviousPhases: + return self.previous_phases_to_check + + +@attr.s(slots=True, frozen=True, init=False) +class DiagnosisCheckpoint(Checkpoint): + """Checkpoint node that activates when a diagnosis condition is true.""" + + diag_condition = attr.ib(type=DiagnosisCondition, default=None) + + def __init__(self, name, diag_condition, *args, **kwargs): + super(DiagnosisCheckpoint, self).__init__(name, *args, **kwargs) + object.__setattr__(self, 'diag_condition', diag_condition) + + def _asdict(self) -> Dict[Text, Any]: + ret = super(DiagnosisCheckpoint, self)._asdict() + ret.update(diag_condition=self.diag_condition._asdict()) + return ret + + def _check_for_action(self, + running_test_state: 'test_state.TestState') -> bool: + """Returns True if the condition is true.""" + return self.diag_condition.check(running_test_state.diagnoses_manager.store) + + def record_conditional(self) -> DiagnosisCondition: + return self.diag_condition diff --git a/openhtf/core/phase_collections.py b/openhtf/core/phase_collections.py new file mode 100644 index 000000000..d58d1d78c --- /dev/null +++ b/openhtf/core/phase_collections.py @@ -0,0 +1,236 @@ +# Copyright 2020 Google Inc. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Implements the basic PhaseNode collections. 
+ +Phase Sequence are basic collections where each node is sequentially run. These +instances can be nested inside of each other or with any other phase node. A +terminal error during a phase sequence will cause the rest of the nodes to be +skipped. +""" + +import abc +import collections +from typing import Any, Callable, DefaultDict, Dict, Iterable, Iterator, List, Optional, Text, Tuple, Type, TypeVar, Union + +import attr +from openhtf import util +from openhtf.core import base_plugs +from openhtf.core import phase_descriptor +from openhtf.core import phase_nodes +import six +from six.moves import collections_abc + +NodeType = TypeVar('NodeType', bound=phase_nodes.PhaseNode) +SequenceClassT = TypeVar('SequenceClassT', bound='PhaseSequence') +PhasesOrNodesT = Iterable[phase_descriptor.PhaseCallableOrNodeT] +SequenceInitializerT = Union[phase_descriptor.PhaseCallableOrNodeT, + PhasesOrNodesT] + + +class DuplicateSubtestNamesError(Exception): + """Multiple subtests share the same name.""" + + +def _recursive_flatten(n: Any) -> Iterator[phase_nodes.PhaseNode]: + """Yields flattened phase nodes.""" + if isinstance(n, collections_abc.Iterable): + for it in n: + for node in _recursive_flatten(it): + yield node + elif isinstance(n, phase_nodes.PhaseNode): + yield n.copy() + elif isinstance(n, phase_descriptor.PhaseDescriptor) or callable(n): + yield phase_descriptor.PhaseDescriptor.wrap_or_copy(n) + else: + raise ValueError('Cannot flatten {}'.format(n)) + + +def flatten(n: Any) -> List[phase_nodes.PhaseNode]: + """Recursively flatten the argument and return a list of phase nodes.""" + return list(_recursive_flatten(n)) + + +class PhaseCollectionNode( + six.with_metaclass(abc.ABCMeta, phase_nodes.PhaseNode)): + """Base class for a node that contains multiple other phase nodes.""" + + __slots__ = () + + def all_phases(self) -> Iterator[phase_descriptor.PhaseDescriptor]: + """Returns an iterator of all the Phase Descriptors for the collection.""" + return 
self.filter_by_type(phase_descriptor.PhaseDescriptor) + + @abc.abstractmethod + def filter_by_type(self, node_cls: Type[NodeType]) -> Iterator[NodeType]: + """Returns recursively all the nodes of the given type. + + This can return collection nodes that include each other. + + Args: + node_cls: The phase node subtype to iterate over. + """ + + +@attr.s(slots=True, frozen=True, init=False) +class PhaseSequence(PhaseCollectionNode): + """A node that collects a sequence of phase nodes. + + This object is immutable. + """ + + # The sequence of phase nodes. + nodes = attr.ib(type=Tuple[phase_nodes.PhaseNode, ...]) + name = attr.ib(type=Optional[Text], default=None) + + # TODO(arsharma): When fully PY3, replace kwargs with nodes and name keywords. + def __init__(self, *args: SequenceInitializerT, **kwargs: Any): + """Initializer. + + Args: + *args: Sequence of phase nodes, phase callables, or recursive iterables of + either. + **kwargs: Keyword arguments. Allows two: nodes - A tuple of PhaseNode + instances. name - The name of the sequence. + """ + super(PhaseSequence, self).__init__() + name = kwargs.pop('name', None) # type: Optional[Text] + object.__setattr__(self, 'name', name) + nodes = kwargs.pop( + 'nodes', None) # type: Optional[Tuple[phase_nodes.PhaseNode, ...]] + if nodes is None: + nodes = tuple(_recursive_flatten(args)) + elif args: + raise ValueError('args and nodes cannot both be specified') + object.__setattr__(self, 'nodes', nodes) + if kwargs: + raise ValueError('Only allowed keywords are `nodes` and `name`.') + + # TODO(arsharma): When fully PY3, replace kwargs with name keyword. + @classmethod + def combine(cls: Type[SequenceClassT], *sequences: Optional['PhaseSequence'], + **kwargs: Any) -> Optional[SequenceClassT]: + """Combine multiple phase node sequences that could be None. + + Args: + *sequences: The Phase Sequences, which can be None. + **kwargs: Keyword arguments. Allows only name. 
+ + Returns: + The combined phase node sequence if at least one sequence is defined; + otherwise, None. + """ + name = kwargs.pop('name', None) + if kwargs: + raise ValueError('Only allowed keyword is `name`.') + + nodes = [] + + for seq in sequences: + if seq: + nodes.extend(seq.nodes) + + if nodes: + return cls(nodes=tuple(nodes), name=name) + return None + + def _asdict(self) -> Dict[Text, Any]: + """Constructs a base type dictionary for JSON serialization.""" + return { + 'name': self.name, + 'nodes': [node._asdict() for node in self.nodes], + } + + def with_args(self: SequenceClassT, **kwargs: Any) -> SequenceClassT: + """Send these keyword-arguments when phases are called.""" + return type(self)( + nodes=tuple(n.with_args(**kwargs) for n in self.nodes), + name=util.format_string(self.name, kwargs)) + + def with_plugs(self: SequenceClassT, + **subplugs: Type[base_plugs.BasePlug]) -> SequenceClassT: + """Substitute plugs for placeholders for this phase, error on unknowns.""" + return type(self)( + nodes=tuple(n.with_plugs(**subplugs) for n in self.nodes), + name=util.format_string(self.name, subplugs)) + + def load_code_info(self: SequenceClassT) -> SequenceClassT: + """Load coded info for all contained phases.""" + return type(self)( + nodes=tuple(n.load_code_info() for n in self.nodes), name=self.name) + + def apply_to_all_phases( + self: SequenceClassT, func: Callable[[phase_descriptor.PhaseDescriptor], + phase_descriptor.PhaseDescriptor] + ) -> SequenceClassT: + """Apply func to all contained phases.""" + return type(self)( + nodes=tuple(n.apply_to_all_phases(func) for n in self.nodes), + name=self.name) + + def filter_by_type(self, node_cls: Type[NodeType]) -> Iterator[NodeType]: + """Yields recursively all the nodes of the given type. + + This can yield collection nodes that include each other. + + Args: + node_cls: The phase node subtype to iterate over. 
+ """ + for node in self.nodes: + if isinstance(node, node_cls): + yield node + if isinstance(node, PhaseCollectionNode): + for n in node.filter_by_type(node_cls): + yield n + + +@attr.s(slots=True, frozen=True, init=False) +class Subtest(PhaseSequence): + """A node for a subtest. + + A subtest must have a unique name for all subtest nodes in the overarching + test. + """ + + # TODO(arsharma): When fully PY3, replace kwargs with nodes keyword. + def __init__(self, name: Text, *args: SequenceInitializerT, **kwargs: Any): + kwargs['name'] = name + super(Subtest, self).__init__(*args, **kwargs) + + +def check_for_duplicate_subtest_names(sequence: PhaseSequence): + """Check for subtests with duplicate names. + + Args: + sequence: Sequence of phase nodes to check over. + + Raises: + DuplicateSubtestNamesError: when duplicate subtest names are found. + """ + names_to_subtests = collections.defaultdict( + list) # type: DefaultDict[Text, List[Subtest]] + for subtest in sequence.filter_by_type(Subtest): + names_to_subtests[subtest.name].append(subtest) + + duplicates = [] # type: List[Text] + for name, subtests in names_to_subtests.items(): + if len(subtests) > 1: + duplicates.append('Name "{}" used by multiple subtests: {}'.format( + name, subtests)) + if not duplicates: + return + duplicates.sort() + raise DuplicateSubtestNamesError('Duplicate Subtest names: {}'.format( + '\n'.join(duplicates))) diff --git a/openhtf/core/phase_descriptor.py b/openhtf/core/phase_descriptor.py index e499fc2d3..cad849bd4 100644 --- a/openhtf/core/phase_descriptor.py +++ b/openhtf/core/phase_descriptor.py @@ -11,121 +11,135 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """Phases in OpenHTF. Phases in OpenHTF are distinct steps in a test. Each phase is an instance of PhaseDescriptor class. 
""" + +import enum import inspect import pdb -import sys +from typing import Any, Callable, Dict, List, Optional, Text, TYPE_CHECKING, Type, Union -import enum -import mutablerecords +import attr import openhtf from openhtf import util +from openhtf.core import base_plugs +from openhtf.core import phase_nodes from openhtf.core import test_record import openhtf.plugs from openhtf.util import data import six +if TYPE_CHECKING: + from openhtf.core import diagnoses_lib # pylint: disable=g-import-not-at-top + from openhtf.core import measurements as core_measurements # pylint: disable=g-import-not-at-top + from openhtf.core import test_state # pylint: disable=g-import-not-at-top + class PhaseWrapError(Exception): """Error with phase wrapping.""" -# Result of a phase. -# -# These values can be returned by a test phase to control what the framework -# does after the phase. -PhaseResult = enum.Enum('PhaseResult', [ # pylint: disable=invalid-name - # Causes the framework to process the phase measurement outcomes and execute - # the next phase. - 'CONTINUE', - # Causes the framework to mark the phase with a fail outcome and execute the - # next phase. - 'FAIL_AND_CONTINUE', - # Causes the framework to execute the same phase again, ignoring the - # measurement outcomes for this instance. If returned more than the phase's - # repeat_limit option, this will be treated as a STOP. - 'REPEAT', - # Causes the framework to ignore the measurement outcomes and execute the - # next phase. The phase is still logged, unlike with run_if. - 'SKIP', - # Causes the framework to stop executing, indicating a failure. - 'STOP' -]) - - -class PhaseOptions(mutablerecords.Record('PhaseOptions', [], { - 'name': None, 'timeout_s': None, 'run_if': None, 'requires_state': None, - 'repeat_limit': None, 'run_under_pdb': False})): +class PhaseResult(enum.Enum): + """Result of a phase. + + These values can be returned by a test phase to control what the framework + does after the phase. 
+ """ + + # Causes the framework to process the phase measurement outcomes and execute + # the next phase. + CONTINUE = 'CONTINUE' + # Causes the framework to mark the phase with a fail outcome and execute the + # next phase. + FAIL_AND_CONTINUE = 'FAIL_AND_CONTINUE' + # Causes the framework to execute the same phase again, ignoring the + # measurement outcomes for this instance. If returned more than the phase's + # repeat_limit option, this will be treated as a STOP. + REPEAT = 'REPEAT' + # Causes the framework to ignore the measurement outcomes and execute the + # next phase. The phase is still logged, unlike with run_if. + SKIP = 'SKIP' + # Causes the framework to stop executing, indicating a failure. + STOP = 'STOP' + # Causes the framework to stop the current subtest and is otherwise treated as + # a FAIL_AND_CONTINUE. If not in a subtest, this is treated as an ERROR. + FAIL_SUBTEST = 'FAIL_SUBTEST' + + +PhaseReturnT = Optional[PhaseResult] +PhaseCallableT = Callable[..., PhaseReturnT] +PhaseCallableOrNodeT = Union[PhaseCallableT, phase_nodes.PhaseNode] +PhaseT = Union['PhaseDescriptor', PhaseCallableT] +TimeoutT = Union[float, int] + + +@attr.s(slots=True) +class PhaseOptions(object): """Options used to override default test phase behaviors. Attributes: name: Override for the name of the phase. Can be formatted in several - different ways as defined in util.format_string. + different ways as defined in util.format_string. timeout_s: Timeout to use for the phase, in seconds. run_if: Callback that decides whether to run the phase or not; if not run, - the phase will also not be logged. + the phase will also not be logged. requires_state: If True, pass the whole TestState into the first argument, - otherwise only the TestApi will be passed in. This is useful if a - phase needs to wrap another phase for some reason, as - PhaseDescriptors can only be invoked with a TestState instance. - repeat_limit: Maximum number of repeats. 
None indicates a phase will - be repeated infinitely as long as PhaseResult.REPEAT is returned. + otherwise only the TestApi will be passed in. This is useful if a phase + needs to wrap another phase for some reason, as PhaseDescriptors can only + be invoked with a TestState instance. + repeat_limit: Maximum number of repeats. None indicates a phase will be + repeated infinitely as long as PhaseResult.REPEAT is returned. run_under_pdb: If True, run the phase under the Python Debugger (pdb). When - setting this option, increase the phase timeout as well because the - timeout will still apply when under the debugger. - - Example Usages: - @PhaseOptions(timeout_s=1) - def PhaseFunc(test): - pass - - @PhaseOptions(name='Phase({port})') - def PhaseFunc(test, port, other_info): - pass + setting this option, increase the phase timeout as well because the + timeout will still apply when under the debugger. + Example Usages: @PhaseOptions(timeout_s=1) + def PhaseFunc(test): pass @PhaseOptions(name='Phase({port})') + def PhaseFunc(test, port, other_info): pass """ - def format_strings(self, **kwargs): + name = attr.ib(type=Optional[Union[Text, Callable[..., Text]]], default=None) + timeout_s = attr.ib(type=Optional[TimeoutT], default=None) + run_if = attr.ib(type=Optional[Callable[[], bool]], default=None) + requires_state = attr.ib(type=bool, default=False) + repeat_limit = attr.ib(type=Optional[int], default=None) + run_under_pdb = attr.ib(type=bool, default=False) + + def format_strings(self, **kwargs: Any) -> 'PhaseOptions': """String substitution of name.""" - return mutablerecords.CopyRecord( - self, name=util.format_string(self.name, kwargs)) + return data.attr_copy(self, name=util.format_string(self.name, kwargs)) - def update(self, **kwargs): + def update(self, **kwargs: Any) -> None: for key, value in six.iteritems(kwargs): - if key not in self.__slots__: - raise AttributeError('Type %s does not have attribute %s' % ( - type(self).__name__, key)) setattr(self, key, 
value) - def __call__(self, phase_func): + def __call__(self, phase_func: PhaseT) -> 'PhaseDescriptor': phase = PhaseDescriptor.wrap_or_copy(phase_func) - for attr in self.__slots__: - value = getattr(self, attr) - if value is not None: - setattr(phase.options, attr, value) + if self.name: + phase.options.name = self.name + if self.timeout_s is not None: + phase.options.timeout_s = self.timeout_s + if self.run_if: + phase.options.run_if = self.run_if + if self.requires_state: + phase.options.requires_state = self.requires_state + if self.repeat_limit is not None: + phase.options.repeat_limit = self.repeat_limit + if self.run_under_pdb: + phase.options.run_under_pdb = self.run_under_pdb return phase + TestPhase = PhaseOptions -class PhaseDescriptor(mutablerecords.Record( - 'PhaseDescriptor', ['func'], - { - 'options': PhaseOptions, - 'plugs': list, - 'measurements': list, - 'diagnosers': list, - 'extra_kwargs': dict, - 'code_info': test_record.CodeInfo.uncaptured(), - })): +@attr.s(slots=True) +class PhaseDescriptor(phase_nodes.PhaseNode): """Phase function and related information. Attributes: @@ -136,10 +150,23 @@ class PhaseDescriptor(mutablerecords.Record( diagnosers: List of PhaseDiagnoser objects. extra_kwargs: Keyword arguments that will be passed to the function. code_info: Info about the source code of func. + name: Phase name. + doc: Phase documentation. 
""" + func = attr.ib(type=PhaseCallableT) + options = attr.ib(type=PhaseOptions, factory=PhaseOptions) + plugs = attr.ib(type=List[base_plugs.PhasePlug], factory=list) + measurements = attr.ib( + type=List['core_measurements.Measurement'], factory=list) + diagnosers = attr.ib( + type=List['diagnoses_lib.BasePhaseDiagnoser'], factory=list) + extra_kwargs = attr.ib(type=Dict[Text, Any], factory=dict) + code_info = attr.ib( + type=test_record.CodeInfo, factory=test_record.CodeInfo.uncaptured) + @classmethod - def wrap_or_copy(cls, func, **options): + def wrap_or_copy(cls, func: PhaseT, **options: Any) -> 'PhaseDescriptor': """Return a new PhaseDescriptor from the given function or instance. We want to return a new copy so that you can reuse a phase with different @@ -155,80 +182,73 @@ def wrap_or_copy(cls, func, **options): Returns: A new PhaseDescriptor object. """ + # TODO(arsharma): Remove when type annotations are more enforced. if isinstance(func, openhtf.PhaseGroup): - raise PhaseWrapError('Cannot wrap PhaseGroup <%s> as a phase.' % ( - func.name or 'Unnamed')) + raise PhaseWrapError('Cannot wrap PhaseGroup <%s> as a phase.' % + (func.name or 'Unnamed')) # pytype: disable=attribute-error if isinstance(func, cls): # We want to copy so that a phase can be reused with different options # or kwargs. See with_args() below for more details. 
- retval = mutablerecords.CopyRecord(func) + retval = data.attr_copy(func) else: retval = cls(func) retval.options.update(**options) return retval - def _asdict(self): - asdict = { - k: data.convert_to_base_types(getattr(self, k), ignore_keys=('cls',)) - for k in self.optional_attributes - } - asdict.update(name=self.name, doc=self.doc) - return asdict + def _asdict(self) -> Dict[Text, Any]: + ret = attr.asdict(self, filter=attr.filters.exclude('func')) + ret.update(name=self.name, doc=self.doc) + return ret @property - def name(self): - return self.options.name or self.func.__name__ + def name(self) -> Text: + if self.options.name and isinstance(self.options.name, str): + return self.options.name + return self.func.__name__ @property - def doc(self): + def doc(self) -> Optional[Text]: return self.func.__doc__ - def with_known_args(self, **kwargs): - """Send only known keyword-arguments to the phase when called.""" + def with_args(self, **kwargs: Any) -> 'PhaseDescriptor': + """Send keyword-arguments to the phase when called. + + Args: + **kwargs: mapping of argument name to value to be passed to the phase + function when called. Unknown arguments are ignored. + + Returns: + Updated PhaseDescriptor. + """ if six.PY3: argspec = inspect.getfullargspec(self.func) argspec_keywords = argspec.varkw else: - argspec = inspect.getargspec(self.func) + argspec = inspect.getargspec(self.func) # pylint: disable=deprecated-method argspec_keywords = argspec.keywords - stored = {} + known_arguments = {} for key, arg in six.iteritems(kwargs): if key in argspec.args or argspec_keywords: - stored[key] = arg - if stored: - return self.with_args(**stored) - return self - - def with_args(self, **kwargs): - """Send these keyword-arguments to the phase when called.""" - # Make a copy so we can have multiple of the same phase with different args - # in the same test. 
- new_info = mutablerecords.CopyRecord(self) + known_arguments[key] = arg + + new_info = data.attr_copy(self) new_info.options = new_info.options.format_strings(**kwargs) - new_info.extra_kwargs.update(kwargs) + new_info.extra_kwargs.update(known_arguments) new_info.measurements = [m.with_args(**kwargs) for m in self.measurements] return new_info - def with_known_plugs(self, **subplugs): - """Substitute only known plugs for placeholders for this phase.""" - return self._apply_with_plugs(subplugs, error_on_unknown=False) - - def with_plugs(self, **subplugs): - """Substitute plugs for placeholders for this phase, error on unknowns.""" - return self._apply_with_plugs(subplugs, error_on_unknown=True) - - def _apply_with_plugs(self, subplugs, error_on_unknown): + def with_plugs(self, + **subplugs: Type[base_plugs.BasePlug]) -> 'PhaseDescriptor': """Substitute plugs for placeholders for this phase. Args: - subplugs: dict of plug name to plug class, plug classes to replace. - error_on_unknown: bool, if True, then error when an unknown plug name is - provided. + **subplugs: dict of plug name to plug class, plug classes to replace; + unknown plug names are ignored. A base_plugs.InvalidPlugError is raised + when a test includes a phase that still has a placeholder plug. Raises: - openhtf.plugs.InvalidPlugError if for one of the plug names one of the - following is true: - - error_on_unknown is True and the plug name is not registered. + base_plugs.InvalidPlugError: if for one of the plug names one of the + following is true: - The new plug subclass is not a subclass of the original. - The original plug class is not a placeholder or automatic placeholder. @@ -236,37 +256,51 @@ def _apply_with_plugs(self, subplugs, error_on_unknown): PhaseDescriptor with updated plugs. 
""" plugs_by_name = {plug.name: plug for plug in self.plugs} - new_plugs = dict(plugs_by_name) + new_plugs = {} for name, sub_class in six.iteritems(subplugs): original_plug = plugs_by_name.get(name) accept_substitute = True if original_plug is None: - if not error_on_unknown: - continue - accept_substitute = False - elif isinstance(original_plug.cls, openhtf.plugs.PlugPlaceholder): + continue + elif isinstance(original_plug.cls, base_plugs.PlugPlaceholder): accept_substitute = issubclass(sub_class, original_plug.cls.base_class) else: # Check __dict__ to see if the attribute is explicitly defined in the # class, rather than being defined in a parent class. accept_substitute = ('auto_placeholder' in original_plug.cls.__dict__ - and original_plug.cls.auto_placeholder - and issubclass(sub_class, original_plug.cls)) + and original_plug.cls.auto_placeholder and + issubclass(sub_class, original_plug.cls)) if not accept_substitute: - raise openhtf.plugs.InvalidPlugError( + raise base_plugs.InvalidPlugError( 'Could not find valid placeholder for substitute plug %s ' 'required for phase %s' % (name, self.name)) - new_plugs[name] = mutablerecords.CopyRecord(original_plug, cls=sub_class) + new_plugs[name] = data.attr_copy(original_plug, cls=sub_class) + + if not new_plugs: + return self + + plugs_by_name.update(new_plugs) - return mutablerecords.CopyRecord( + return data.attr_copy( self, - plugs=list(new_plugs.values()), + plugs=list(plugs_by_name.values()), options=self.options.format_strings(**subplugs), measurements=[m.with_args(**subplugs) for m in self.measurements]) - def __call__(self, test_state): + def load_code_info(self) -> 'PhaseDescriptor': + """Load code info for this phase.""" + return data.attr_copy( + self, code_info=test_record.CodeInfo.for_function(self.func)) + + def apply_to_all_phases( + self, func: Callable[['PhaseDescriptor'], + 'PhaseDescriptor']) -> 'PhaseDescriptor': + return func(self) + + def __call__(self, + running_test_state: 
'test_state.TestState') -> PhaseReturnT: """Invoke this Phase, passing in the appropriate args. By default, an openhtf.TestApi is passed as the first positional arg, but if @@ -276,30 +310,31 @@ def __call__(self, test_state): with_args(), combined with plugs (plugs override extra_kwargs). Args: - test_state: test_state.TestState for the currently executing Test. + running_test_state: test_state.TestState for the currently executing Test. Returns: The return value from calling the underlying function. """ kwargs = dict(self.extra_kwargs) - kwargs.update(test_state.plug_manager.provide_plugs( - (plug.name, plug.cls) for plug in self.plugs if plug.update_kwargs)) + kwargs.update( + running_test_state.plug_manager.provide_plugs( + (plug.name, plug.cls) for plug in self.plugs if plug.update_kwargs)) - if sys.version_info[0] < 3: - arg_info = inspect.getargspec(self.func) - keywords = arg_info.keywords - else: + if six.PY3: arg_info = inspect.getfullargspec(self.func) keywords = arg_info.varkw + else: + arg_info = inspect.getargspec(self.func) # pylint: disable=deprecated-method + keywords = arg_info.keywords # Pass in test_api if the phase takes *args, or **kwargs with at least 1 # positional, or more positional args than we have keyword args. 
- if arg_info.varargs or (keywords and len(arg_info.args) >= 1) or ( - len(arg_info.args) > len(kwargs)): + if arg_info.varargs or (keywords and len(arg_info.args) >= 1) or (len( + arg_info.args) > len(kwargs)): args = [] if self.options.requires_state: - args.append(test_state) + args.append(running_test_state) else: - args.append(test_state.test_api) + args.append(running_test_state.test_api) if self.options.run_under_pdb: return pdb.runcall(self.func, *args, **kwargs) diff --git a/openhtf/core/phase_executor.py b/openhtf/core/phase_executor.py index e921b959f..de3f401cc 100644 --- a/openhtf/core/phase_executor.py +++ b/openhtf/core/phase_executor.py @@ -11,19 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """PhaseExecutor module for handling the phases of a test. -Each phase is an instance of openhtf.PhaseDescriptor and therefore has +Each phase is an instance of phase_descriptor.PhaseDescriptor and therefore has relevant options. Each option is taken into account when executing a phase, such as checking options.run_if as soon as possible and timing out at the appropriate time. -A phase must return an openhtf.PhaseResult, one of CONTINUE, REPEAT, or STOP. -A phase may also return None, or have no return statement, which is the same as -returning openhtf.PhaseResult.CONTINUE. These results are then acted upon -accordingly and a new test run status is returned. +A phase must return an phase_descriptor.PhaseResult, one of CONTINUE, REPEAT, or +STOP. A phase may also return None, or have no return statement, which is the +same as returning openhtf.PhaseResult.CONTINUE. These results are then acted +upon accordingly and a new test run status is returned. Phases are always run in order and not allowed to loop back, though a phase may choose to repeat itself by returning REPEAT. 
Returning STOP will cause a test @@ -32,44 +30,61 @@ framework. """ -import collections import logging +import pstats import sys import threading import time import traceback - -import openhtf +import types +from typing import Any, Dict, Optional, Text, Tuple, Type, TYPE_CHECKING, Union + +import attr +from openhtf import util +from openhtf.core import phase_branches +from openhtf.core import phase_descriptor +from openhtf.core import test_record from openhtf.util import argv from openhtf.util import threads from openhtf.util import timeouts +if TYPE_CHECKING: + from openhtf.core import test_state as htf_test_state # pylint: disable=g-import-not-at-top + DEFAULT_PHASE_TIMEOUT_S = 3 * 60 -ARG_PARSER = argv.ModuleParser() +ARG_PARSER = argv.module_parser() ARG_PARSER.add_argument( - '--phase_default_timeout_s', default=DEFAULT_PHASE_TIMEOUT_S, - action=argv.StoreInModule, target='%s.DEFAULT_PHASE_TIMEOUT_S' % __name__, + '--phase_default_timeout_s', + default=DEFAULT_PHASE_TIMEOUT_S, + action=argv.StoreInModule, + target='%s.DEFAULT_PHASE_TIMEOUT_S' % __name__, help='Test phase timeout in seconds') +# TODO(arsharma): Use the test state logger. 
_LOG = logging.getLogger(__name__) -class ExceptionInfo(collections.namedtuple( - 'ExceptionInfo', ['exc_type', 'exc_val', 'exc_tb'])): +@attr.s(slots=True, frozen=True) +class ExceptionInfo(object): """Wrap the description of a raised exception and its traceback.""" - def _asdict(self): + exc_type = attr.ib(type=Type[Exception]) + exc_val = attr.ib(type=Exception) + exc_tb = attr.ib(type=types.TracebackType) + + def as_base_types(self) -> Dict[Text, Text]: return { 'exc_type': str(self.exc_type), - 'exc_val': self.exc_val, + 'exc_val': str(self.exc_val), 'exc_tb': self.get_traceback_string(), } - def get_traceback_string(self): - return ''.join(traceback.format_exception(*self)) + def get_traceback_string(self) -> Text: + return ''.join( + traceback.format_exception(self.exc_type, self.exc_val, self.exc_tb)) - def __str__(self): + def __str__(self) -> Text: return self.exc_type.__name__ @@ -77,8 +92,8 @@ class InvalidPhaseResultError(Exception): """Raised when PhaseExecutionOutcome is created with invalid phase result.""" -class PhaseExecutionOutcome(collections.namedtuple( - 'PhaseExecutionOutcome', 'phase_result')): +@attr.s(slots=True, frozen=True) +class PhaseExecutionOutcome(object): """Provide some utility and sanity around phase return values. This should not be confused with openhtf.PhaseResult. PhaseResult is an @@ -94,35 +109,38 @@ class PhaseExecutionOutcome(collections.namedtuple( similarly be used to check for the timeout case. The only accepted values for phase_result are None (timeout), an instance - of Exception (phase raised), or an instance of openhtf.PhaseResult. Any - other value will raise an InvalidPhaseResultError. + of Exception (phase raised), or an instance of openhtf.PhaseResult. 
""" - def __new__(cls, phase_result): - if (phase_result is not None and - not isinstance(phase_result, (openhtf.PhaseResult, ExceptionInfo)) and - not isinstance(phase_result, threads.ThreadTerminationError)): - raise InvalidPhaseResultError('Invalid phase result', phase_result) - self = super(PhaseExecutionOutcome, cls).__new__(cls, phase_result) - return self + phase_result = attr.ib(type=Union[None, phase_descriptor.PhaseResult, + ExceptionInfo, + threads.ThreadTerminationError]) + + @property + def is_aborted(self): + return isinstance(self.phase_result, threads.ThreadTerminationError) @property def is_fail_and_continue(self): - return self.phase_result is openhtf.PhaseResult.FAIL_AND_CONTINUE + return self.phase_result is phase_descriptor.PhaseResult.FAIL_AND_CONTINUE + + @property + def is_fail_subtest(self): + return self.phase_result is phase_descriptor.PhaseResult.FAIL_SUBTEST @property def is_repeat(self): - return self.phase_result is openhtf.PhaseResult.REPEAT + return self.phase_result is phase_descriptor.PhaseResult.REPEAT @property def is_skip(self): - return self.phase_result is openhtf.PhaseResult.SKIP + return self.phase_result is phase_descriptor.PhaseResult.SKIP @property def is_terminal(self): """True if this result will stop the test.""" return (self.raised_exception or self.is_timeout or - self.phase_result == openhtf.PhaseResult.STOP) + self.phase_result is phase_descriptor.PhaseResult.STOP) @property def is_timeout(self): @@ -132,12 +150,8 @@ def is_timeout(self): @property def raised_exception(self): """True if the phase in question raised an exception.""" - return isinstance(self.phase_result, ( - ExceptionInfo, threads.ThreadTerminationError)) - - @property - def is_aborted(self): - return isinstance(self.phase_result, threads.ThreadTerminationError) + return isinstance(self.phase_result, + (ExceptionInfo, threads.ThreadTerminationError)) class PhaseExecutorThread(threads.KillableThread): @@ -149,35 +163,42 @@ class 
PhaseExecutorThread(threads.KillableThread): """ daemon = True - def __init__(self, phase_desc, test_state, run_with_profiling): + def __init__(self, phase_desc: phase_descriptor.PhaseDescriptor, + test_state: 'htf_test_state.TestState', run_with_profiling: bool, + subtest_rec: Optional[test_record.SubtestRecord]): super(PhaseExecutorThread, self).__init__( name='', run_with_profiling=run_with_profiling) self._phase_desc = phase_desc self._test_state = test_state - self._phase_execution_outcome = None + self._subtest_rec = subtest_rec + self._phase_execution_outcome = None # type: Optional[PhaseExecutionOutcome] - def _thread_proc(self): + def _thread_proc(self) -> None: """Execute the encompassed phase and save the result.""" # Call the phase, save the return value, or default it to CONTINUE. phase_return = self._phase_desc(self._test_state) if phase_return is None: - phase_return = openhtf.PhaseResult.CONTINUE - - # If phase_return is invalid, this will raise, and _phase_execution_outcome - # will get set to the InvalidPhaseResultError in _thread_exception instead. 
+ phase_return = phase_descriptor.PhaseResult.CONTINUE + + if not isinstance(phase_return, phase_descriptor.PhaseResult): + raise InvalidPhaseResultError('Invalid phase result', phase_return) + if (phase_return is phase_descriptor.PhaseResult.FAIL_SUBTEST and + not self._subtest_rec): + raise InvalidPhaseResultError( + 'Phase returned FAIL_SUBTEST but a subtest is not running.') self._phase_execution_outcome = PhaseExecutionOutcome(phase_return) - def _log_exception(self, *args): + def _log_exception(self, *args: Any) -> Any: """Log exception, while allowing unit testing to override.""" self._test_state.state_logger.critical(*args) - def _thread_exception(self, *args): + def _thread_exception(self, *args) -> bool: self._phase_execution_outcome = PhaseExecutionOutcome(ExceptionInfo(*args)) self._log_exception('Phase %s raised an exception', self._phase_desc.name) return True # Never propagate exceptions upward. - def join_or_die(self): + def join_or_die(self) -> PhaseExecutionOutcome: """Wait for thread to finish, returning a PhaseExecutionOutcome instance.""" if self._phase_desc.options.timeout_s is not None: self.join(self._phase_desc.options.timeout_s) @@ -185,7 +206,7 @@ def join_or_die(self): self.join(DEFAULT_PHASE_TIMEOUT_S) # We got a return value or an exception and handled it. 
- if isinstance(self._phase_execution_outcome, PhaseExecutionOutcome): + if self._phase_execution_outcome: return self._phase_execution_outcome # Check for timeout, indicated by None for @@ -198,31 +219,37 @@ def join_or_die(self): return PhaseExecutionOutcome(threads.ThreadTerminationError()) @property - def name(self): + def name(self) -> Text: return str(self) - def __str__(self): - return '<%s: (%s)>' % (type(self).__name__, self._phase_desc.name) + def __str__(self) -> Text: + return '<{}: ({})>'.format(type(self).__name__, self._phase_desc.name) class PhaseExecutor(object): """Encompasses the execution of the phases of a test.""" - def __init__(self, test_state): + def __init__(self, test_state: 'htf_test_state.TestState'): self.test_state = test_state # This lock exists to prevent stop() calls from being ignored if called when # _execute_phase_once is setting up the next phase thread. self._current_phase_thread_lock = threading.Lock() - self._current_phase_thread = None + self._current_phase_thread = None # type: Optional[PhaseExecutorThread] self._stopping = threading.Event() - def execute_phase(self, phase, run_with_profiling=False): + def execute_phase( + self, + phase: phase_descriptor.PhaseDescriptor, + run_with_profiling: bool = False, + subtest_rec: Optional[test_record.SubtestRecord] = None + ) -> Tuple[PhaseExecutionOutcome, Optional[pstats.Stats]]: """Executes a phase or skips it, yielding PhaseExecutionOutcome instances. Args: phase: Phase to execute. run_with_profiling: Whether to run with cProfile stat collection for the phase code run inside a thread. + subtest_rec: Optional subtest record. 
Returns: A two-tuple; the first item is the final PhaseExecutionOutcome that wraps @@ -237,7 +264,7 @@ def execute_phase(self, phase, run_with_profiling=False): while not self._stopping.is_set(): is_last_repeat = repeat_count >= repeat_limit phase_execution_outcome, profile_stats = self._execute_phase_once( - phase, is_last_repeat, run_with_profiling) + phase, is_last_repeat, run_with_profiling, subtest_rec) if phase_execution_outcome.is_repeat and not is_last_repeat: repeat_count += 1 @@ -247,17 +274,28 @@ def execute_phase(self, phase, run_with_profiling=False): # We've been cancelled, so just 'timeout' the phase. return PhaseExecutionOutcome(None), None - def _execute_phase_once(self, phase_desc, is_last_repeat, run_with_profiling): + def _execute_phase_once( + self, + phase_desc: phase_descriptor.PhaseDescriptor, + is_last_repeat: bool, + run_with_profiling: bool, + subtest_rec: Optional[test_record.SubtestRecord], + ) -> Tuple[PhaseExecutionOutcome, Optional[pstats.Stats]]: """Executes the given phase, returning a PhaseExecutionOutcome.""" # Check this before we create a PhaseState and PhaseRecord. 
if phase_desc.options.run_if and not phase_desc.options.run_if(): _LOG.debug('Phase %s skipped due to run_if returning falsey.', phase_desc.name) - return PhaseExecutionOutcome(openhtf.PhaseResult.SKIP), None + return PhaseExecutionOutcome(phase_descriptor.PhaseResult.SKIP), None override_result = None with self.test_state.running_phase_context(phase_desc) as phase_state: - _LOG.debug('Executing phase %s', phase_desc.name) + if subtest_rec: + _LOG.debug('Executing phase %s under subtest %s', phase_desc.name, + subtest_rec.name) + phase_state.set_subtest_name(subtest_rec.name) + else: + _LOG.debug('Executing phase %s', phase_desc.name) with self._current_phase_thread_lock: # Checking _stopping must be in the lock context, otherwise there is a # race condition: this thread checks _stopping and then switches to @@ -271,7 +309,7 @@ def _execute_phase_once(self, phase_desc, is_last_repeat, run_with_profiling): phase_state.result = result return result, None phase_thread = PhaseExecutorThread(phase_desc, self.test_state, - run_with_profiling) + run_with_profiling, subtest_rec) phase_thread.start() self._current_phase_thread = phase_thread @@ -279,7 +317,8 @@ def _execute_phase_once(self, phase_desc, is_last_repeat, run_with_profiling): if phase_state.result.is_repeat and is_last_repeat: _LOG.error('Phase returned REPEAT, exceeding repeat_limit.') phase_state.hit_repeat_limit = True - override_result = PhaseExecutionOutcome(openhtf.PhaseResult.STOP) + override_result = PhaseExecutionOutcome( + phase_descriptor.PhaseResult.STOP) self._current_phase_thread = None # Refresh the result in case a validation for a partially set measurement @@ -290,10 +329,64 @@ def _execute_phase_once(self, phase_desc, is_last_repeat, run_with_profiling): return (result, phase_thread.get_profile_stats() if run_with_profiling else None) - def reset_stop(self): + def skip_phase(self, phase_desc: phase_descriptor.PhaseDescriptor, + subtest_rec: Optional[test_record.SubtestRecord]) -> None: + 
"""Skip a phase, but log a record of it.""" + _LOG.debug('Automatically skipping phase %s', phase_desc.name) + with self.test_state.running_phase_context(phase_desc) as phase_state: + if subtest_rec: + phase_state.set_subtest_name(subtest_rec.name) + phase_state.result = PhaseExecutionOutcome( + phase_descriptor.PhaseResult.SKIP) + + def evaluate_checkpoint( + self, checkpoint: phase_branches.Checkpoint, + subtest_rec: Optional[test_record.SubtestRecord] + ) -> PhaseExecutionOutcome: + """Evaluate a checkpoint, returning a PhaseExecutionOutcome.""" + if subtest_rec: + subtest_name = subtest_rec.name + _LOG.debug('Evaluating checkpoint %s under subtest %s', checkpoint.name, + subtest_name) + else: + _LOG.debug('Evaluating checkpoint %s', checkpoint.name) + subtest_name = None + evaluated_millis = util.time_millis() + try: + outcome = PhaseExecutionOutcome(checkpoint.get_result(self.test_state)) + _LOG.debug('Checkpoint %s result: %s', checkpoint.name, + outcome.phase_result) + if outcome.is_fail_subtest and not subtest_rec: + raise InvalidPhaseResultError( + 'Checkpoint returned FAIL_SUBTEST, but subtest not running.') + except Exception: # pylint: disable=broad-except + outcome = PhaseExecutionOutcome(ExceptionInfo(*sys.exc_info())) + + checkpoint_rec = test_record.CheckpointRecord.from_checkpoint( + checkpoint, subtest_name, outcome, evaluated_millis) + + self.test_state.test_record.add_checkpoint_record(checkpoint_rec) + + return outcome + + def skip_checkpoint(self, checkpoint: phase_branches.Checkpoint, + subtest_rec: Optional[test_record.SubtestRecord]) -> None: + """Skip a checkpoint, but log a record of it.""" + _LOG.debug('Automatically skipping checkpoint %s', checkpoint.name) + subtest_name = subtest_rec.name if subtest_rec else None + checkpoint_rec = test_record.CheckpointRecord.from_checkpoint( + checkpoint, subtest_name, + PhaseExecutionOutcome(phase_descriptor.PhaseResult.SKIP), + util.time_millis()) + 
self.test_state.test_record.add_checkpoint_record(checkpoint_rec) + + def reset_stop(self) -> None: self._stopping.clear() - def stop(self, timeout_s=None): + def stop( + self, + timeout_s: Union[None, int, float, + timeouts.PolledTimeout] = None) -> None: """Stops execution of the current phase, if any. It will raise a ThreadTerminationError, which will cause the test to stop diff --git a/openhtf/core/phase_group.py b/openhtf/core/phase_group.py index b44aced11..1694ab1fb 100644 --- a/openhtf/core/phase_group.py +++ b/openhtf/core/phase_group.py @@ -11,23 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Phase Groups in OpenHTF. Phase Groups are collections of Phases that are used to control phase -shortcutting due to terminal errors to better guarentee when teardown phases +shortcutting due to terminal errors to better guarantee when teardown phases run. PhaseGroup instances have three primary member fields: - `setup`: a list of phases, run first. If these phases are all non-terminal, - the PhaseGroup is entered. - `main`: a list of phases, run after the setup phases as long as those are - non-terminal. If any of these phases are terminal, then the rest of the - main phases will be skipped. - `teardown`: a list of phases, guarenteed to run after the main phases as long - as the PhaseGroup was entered. If any are terminal, other teardown phases - will continue to be run. One exception is that a second CTRL-C sent to - the main thread will abort all teardown phases. + `setup`: a sequence of phase nodes, run first. If these phases are all + non-terminal, the PhaseGroup is entered. + `main`: a sequence of phase nodes, run after the setup phases as long as those + are non-terminal. If any of these phases are terminal, then the rest of + the main phases will be skipped. 
+ `teardown`: a sequence of phase nodes, guaranteed to run after the main phases + as long as the PhaseGroup was entered. If any are terminal, other + teardown phases will continue to be run. One exception is that a second + CTRL-C sent to the main thread will abort all teardown phases. + Nested phase collections in a teardown do not have the terminal error + prevention, so further errors will cause those nodes to not be run. There is one optional field: `name`: str, an arbitrary description used for logging. @@ -35,22 +36,29 @@ terminal if any of its Phases or further nested PhaseGroups are also terminal. """ -import functools +from typing import Any, Callable, Dict, Iterator, Optional, Text, Type -import mutablerecords +import attr +from openhtf import util +from openhtf.core import base_plugs +from openhtf.core import phase_collections from openhtf.core import phase_descriptor -from openhtf.core import test_record -from six.moves import collections_abc +from openhtf.util import data + +def _initialize_group_sequence( + seq: Optional[phase_collections.SequenceInitializerT] +) -> Optional[phase_collections.PhaseSequence]: + if not seq: + return None + if isinstance(seq, phase_collections.PhaseSequence): + return seq + return phase_collections.PhaseSequence(seq) -class PhaseGroup(mutablerecords.Record( - 'PhaseGroup', [], { - 'setup': tuple, - 'main': tuple, - 'teardown': tuple, - 'name': None, - })): + +@attr.s(slots=True, frozen=True, init=False) +class PhaseGroup(phase_collections.PhaseCollectionNode): """Phase group with guaranteed end phase running. If the setup phases all continue, then the main phases and teardown phases are @@ -58,218 +66,143 @@ class PhaseGroup(mutablerecords.Record( teardown phases are guaranteed to be run. 
""" - def __init__(self, setup=None, main=None, teardown=None, name=None): - if not setup: - setup = () - elif isinstance(setup, PhaseGroup): - setup = (setup,) - if not main: - main = () - elif isinstance(main, PhaseGroup): - main = (main,) - if not teardown: - teardown = () - elif isinstance(teardown, PhaseGroup): - teardown = (teardown,) - super(PhaseGroup, self).__init__( - setup=tuple(setup), main=tuple(main), teardown=tuple(teardown), - name=name) - - @classmethod - def convert_if_not(cls, phases_or_groups): - """Convert list of phases or groups into a new PhaseGroup if not already.""" - if isinstance(phases_or_groups, PhaseGroup): - return mutablerecords.CopyRecord(phases_or_groups) - - flattened = flatten_phases_and_groups(phases_or_groups) - return cls(main=flattened) + setup = attr.ib(type=Optional[phase_collections.PhaseSequence], default=None) + main = attr.ib(type=Optional[phase_collections.PhaseSequence], default=None) + teardown = attr.ib( + type=Optional[phase_collections.PhaseSequence], default=None) + name = attr.ib(type=Optional[Text], default=None) + + def __init__( + self, + setup: Optional[phase_collections.SequenceInitializerT] = None, + main: Optional[phase_collections.SequenceInitializerT] = None, + teardown: Optional[phase_collections.SequenceInitializerT] = None, + name: Optional[Text] = None): + object.__setattr__(self, 'setup', _initialize_group_sequence(setup)) + object.__setattr__(self, 'main', _initialize_group_sequence(main)) + object.__setattr__(self, 'teardown', _initialize_group_sequence(teardown)) + object.__setattr__(self, 'name', name) @classmethod - def with_context(cls, setup_phases, teardown_phases): + def with_context( + cls, setup_nodes: Optional[phase_collections.SequenceInitializerT], + teardown_nodes: Optional[phase_collections.SequenceInitializerT] + ) -> Callable[..., 'PhaseGroup']: """Create PhaseGroup creator function with setup and teardown phases. 
Args: - setup_phases: list of phase_descriptor.PhaseDescriptors/PhaseGroups/ - callables/iterables, phases to run during the setup for the PhaseGroup - returned from the created function. - teardown_phases: list of phase_descriptor.PhaseDescriptors/PhaseGroups/ - callables/iterables, phases to run during the teardown for the - PhaseGroup returned from the created function. + setup_nodes: phases to run during the setup for the PhaseGroup returned + from the created function. + teardown_nodes: phases to run during the teardown for the PhaseGroup + returned from the created function. Returns: Function that takes *phases and returns a PhaseGroup with the predefined setup and teardown phases, with *phases as the main phases. """ - setup = flatten_phases_and_groups(setup_phases) - teardown = flatten_phases_and_groups(teardown_phases) + setup = phase_collections.PhaseSequence( + setup_nodes) if setup_nodes else None + teardown = phase_collections.PhaseSequence( + teardown_nodes) if teardown_nodes else None + + def _context_wrapper( + *phases: phase_descriptor.PhaseCallableOrNodeT) -> 'PhaseGroup': + return cls( + setup=data.attr_copy(setup) if setup else None, + main=phase_collections.PhaseSequence(phases), + teardown=data.attr_copy(teardown) if teardown else None) - def _context_wrapper(*phases): - return cls(setup=setup, - main=flatten_phases_and_groups(phases), - teardown=teardown) return _context_wrapper @classmethod - def with_setup(cls, *setup_phases): + def with_setup( + cls, *setup_phases: phase_descriptor.PhaseCallableOrNodeT + ) -> Callable[..., 'PhaseGroup']: """Create PhaseGroup creator function with predefined setup phases.""" - return cls.with_context(setup_phases, []) + return cls.with_context(setup_phases, None) @classmethod - def with_teardown(cls, *teardown_phases): + def with_teardown( + cls, *teardown_phases: phase_descriptor.PhaseCallableOrNodeT + ) -> Callable[..., 'PhaseGroup']: """Create PhaseGroup creator function with predefined teardown 
phases.""" - return cls.with_context([], teardown_phases) + return cls.with_context(None, teardown_phases) - def combine(self, other, name=None): + def combine(self, + other: 'PhaseGroup', + name: Optional[Text] = None) -> 'PhaseGroup': """Combine with another PhaseGroup and return the result.""" return PhaseGroup( - setup=self.setup + other.setup, - main=self.main + other.main, - teardown=self.teardown + other.teardown, + setup=phase_collections.PhaseSequence.combine(self.setup, other.setup), + main=phase_collections.PhaseSequence.combine(self.main, other.main), + teardown=phase_collections.PhaseSequence.combine( + self.teardown, other.teardown), name=name) - def wrap(self, main_phases, name=None): + def wrap(self, + main_phases: phase_collections.SequenceInitializerT, + name: Text = None) -> 'PhaseGroup': """Returns PhaseGroup with additional main phases.""" - new_main = list(self.main) - if isinstance(main_phases, collections_abc.Iterable): - new_main.extend(main_phases) - else: - new_main.append(main_phases) - return PhaseGroup( - setup=self.setup, - main=new_main, - teardown=self.teardown, - name=name) - - def transform(self, transform_fn): - return PhaseGroup( - setup=[transform_fn(p) for p in self.setup], - main=[transform_fn(p) for p in self.main], - teardown=[transform_fn(p) for p in self.teardown], - name=self.name) - - def with_args(self, **kwargs): + other = PhaseGroup(main=main_phases) + return self.combine(other, name=name) + + def _asdict(self) -> Dict[Text, Any]: + return { + 'setup': self.setup._asdict() if self.setup else None, + 'main': self.main._asdict() if self.main else None, + 'teardown': self.teardown._asdict() if self.teardown else None, + 'name': self.name, + } + + def with_args(self, **kwargs: Any) -> 'PhaseGroup': """Send known keyword-arguments to each contained phase the when called.""" - return self.transform(functools.partial(optionally_with_args, **kwargs)) - - def with_plugs(self, **subplugs): + return attr.evolve( + self, + 
setup=self.setup.with_args(**kwargs) if self.setup else None, + main=self.main.with_args(**kwargs) if self.main else None, + teardown=self.teardown.with_args(**kwargs) if self.teardown else None, + name=util.format_string(self.name, kwargs)) + + def with_plugs(self, **subplugs: Type[base_plugs.BasePlug]) -> 'PhaseGroup': """Substitute only known plugs for placeholders for each contained phase.""" - return self.transform(functools.partial(optionally_with_plugs, **subplugs)) - - def _iterate(self, phases): - for phase in phases: - if isinstance(phase, PhaseGroup): - for p in phase: - yield p - else: - yield phase - - def __iter__(self): - """Iterate directly over the phases.""" - for phase in self._iterate(self.setup): - yield phase - for phase in self._iterate(self.main): - yield phase - for phase in self._iterate(self.teardown): - yield phase - - def flatten(self): - """Internally flatten out nested iterables.""" return PhaseGroup( - setup=flatten_phases_and_groups(self.setup), - main=flatten_phases_and_groups(self.main), - teardown=flatten_phases_and_groups(self.teardown), - name=self.name) + setup=self.setup.with_plugs(**subplugs) if self.setup else None, + main=self.main.with_plugs(**subplugs) if self.main else None, + teardown=self.teardown.with_plugs( + **subplugs) if self.teardown else None, + name=util.format_string(self.name, subplugs)) - def load_code_info(self): + def load_code_info(self) -> 'PhaseGroup': """Load coded info for all contained phases.""" return PhaseGroup( - setup=load_code_info(self.setup), - main=load_code_info(self.main), - teardown=load_code_info(self.teardown), + setup=self.setup.load_code_info() if self.setup else None, + main=self.main.load_code_info() if self.main else None, + teardown=self.teardown.load_code_info() if self.teardown else None, name=self.name) + def apply_to_all_phases( + self, func: Callable[[phase_descriptor.PhaseDescriptor], + phase_descriptor.PhaseDescriptor] + ) -> 'PhaseGroup': + """Apply func to all contained 
phases.""" + return PhaseGroup( + setup=self.setup.apply_to_all_phases(func) if self.setup else None, + main=self.main.apply_to_all_phases(func) if self.main else None, + teardown=(self.teardown.apply_to_all_phases(func) + if self.teardown else None), + name=self.name) -def load_code_info(phases_or_groups): - """Recursively load code info for a PhaseGroup or list of phases or groups.""" - if isinstance(phases_or_groups, PhaseGroup): - return phases_or_groups.load_code_info() - ret = [] - for phase in phases_or_groups: - if isinstance(phase, PhaseGroup): - ret.append(phase.load_code_info()) - else: - ret.append( - mutablerecords.CopyRecord( - phase, code_info=test_record.CodeInfo.for_function(phase.func))) - return ret - - -def flatten_phases_and_groups(phases_or_groups): - """Recursively flatten nested lists for the list of phases or groups.""" - if isinstance(phases_or_groups, PhaseGroup): - phases_or_groups = [phases_or_groups] - ret = [] - for phase in phases_or_groups: - if isinstance(phase, PhaseGroup): - ret.append(phase.flatten()) - elif isinstance(phase, collections_abc.Iterable): - ret.extend(flatten_phases_and_groups(phase)) - else: - ret.append(phase_descriptor.PhaseDescriptor.wrap_or_copy(phase)) - return ret - - -def optionally_with_args(phase, **kwargs): - """Apply only the args that the phase knows. - - If the phase has a **kwargs-style argument, it counts as knowing all args. - - Args: - phase: phase_descriptor.PhaseDescriptor or PhaseGroup or callable, or - iterable of those, the phase or phase group (or iterable) to apply - with_args to. - **kwargs: arguments to apply to the phase. - - Returns: - phase_descriptor.PhaseDescriptor or PhaseGroup or iterable with the updated - args. 
- """ - if isinstance(phase, PhaseGroup): - return phase.with_args(**kwargs) - if isinstance(phase, collections_abc.Iterable): - return [optionally_with_args(p, **kwargs) for p in phase] - - if not isinstance(phase, phase_descriptor.PhaseDescriptor): - phase = phase_descriptor.PhaseDescriptor.wrap_or_copy(phase) - return phase.with_known_args(**kwargs) - - -def optionally_with_plugs(phase, **subplugs): - """Apply only the with_plugs that the phase knows. - - This will determine the subset of plug overrides for only plugs the phase - actually has. - - Args: - phase: phase_descriptor.PhaseDescriptor or PhaseGroup or callable, or - iterable of those, the phase or phase group (or iterable) to apply the - plug changes to. - **subplugs: mapping from plug name to derived plug class, the subplugs to - apply. + def filter_by_type( + self, node_cls: Type[phase_collections.NodeType] + ) -> Iterator[phase_collections.NodeType]: + """Yields recursively all the nodes of the given type. - Raises: - openhtf.plugs.InvalidPlugError: if a specified subplug class is not a valid - replacement for the specified plug name. + This can yield collection nodes that include each other. - Returns: - phase_descriptor.PhaseDescriptor or PhaseGroup or iterable with the updated - plugs. - """ - if isinstance(phase, PhaseGroup): - return phase.with_plugs(**subplugs) - if isinstance(phase, collections_abc.Iterable): - return [optionally_with_plugs(p, **subplugs) for p in phase] - - if not isinstance(phase, phase_descriptor.PhaseDescriptor): - phase = phase_descriptor.PhaseDescriptor.wrap_or_copy(phase) - return phase.with_known_plugs(**subplugs) + Args: + node_cls: The phase node subtype to iterate over. 
+ """ + for seq in (self.setup, self.main, self.teardown): + if seq: + for phase in seq.filter_by_type(node_cls): + yield phase diff --git a/openhtf/core/phase_nodes.py b/openhtf/core/phase_nodes.py new file mode 100644 index 000000000..b27a982e8 --- /dev/null +++ b/openhtf/core/phase_nodes.py @@ -0,0 +1,67 @@ +# Copyright 2020 Google Inc. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Contains the abstract interfaces for phase nodes.""" + +import abc +from typing import Any, Callable, Dict, Optional, Text, Type, TypeVar, TYPE_CHECKING + +from openhtf.core import base_plugs +from openhtf.util import data +import six + +if TYPE_CHECKING: + from openhtf.core import phase_descriptor # pylint: disable=g-import-not-at-top + +WithModifierT = TypeVar('WithModifierT', bound='PhaseNode') +ApplyAllNodesT = TypeVar('ApplyAllNodesT', bound='PhaseNode') + + +class PhaseNode(six.with_metaclass(abc.ABCMeta, object)): + """Base class for all executable nodes in OpenHTF.""" + + __slots__ = () + + @abc.abstractproperty + def name(self) -> Optional[Text]: + """Returns the name of this node.""" + + @abc.abstractmethod + def _asdict(self) -> Dict[Text, Any]: + """Returns a base type dictionary for serialization.""" + + def copy(self: WithModifierT) -> WithModifierT: + """Create a copy of the PhaseNode.""" + return data.attr_copy(self) + + @abc.abstractmethod + def with_args(self: WithModifierT, **kwargs: Any) -> WithModifierT: + """Send these 
keyword-arguments when phases are called.""" + + @abc.abstractmethod + def with_plugs(self: WithModifierT, + **subplugs: Type[base_plugs.BasePlug]) -> WithModifierT: + """Substitute plugs for placeholders for this phase, error on unknowns.""" + + @abc.abstractmethod + def load_code_info(self: WithModifierT) -> WithModifierT: + """Load coded info for all contained phases.""" + + @abc.abstractmethod + def apply_to_all_phases( + self: WithModifierT, func: Callable[['phase_descriptor.PhaseDescriptor'], + 'phase_descriptor.PhaseDescriptor'] + ) -> WithModifierT: + """Apply func to all contained phases.""" diff --git a/openhtf/core/test_descriptor.py b/openhtf/core/test_descriptor.py index 7d0edb123..50183b6e5 100644 --- a/openhtf/core/test_descriptor.py +++ b/openhtf/core/test_descriptor.py @@ -11,14 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """Tests in OpenHTF. Tests are main entry point for OpenHTF tests. In its simplest form a test is a series of Phases that are executed by the OpenHTF framework. 
""" + import argparse import collections import logging @@ -27,20 +26,25 @@ import textwrap import threading import traceback -from types import LambdaType +import types +import typing +from typing import Any, Callable, Dict, List, Optional, Set, Text, Type, Union import uuid import weakref +import attr import colorama -import mutablerecords from openhtf import util +from openhtf.core import base_plugs from openhtf.core import diagnoses_lib +from openhtf.core import measurements +from openhtf.core import phase_collections from openhtf.core import phase_descriptor from openhtf.core import phase_executor -from openhtf.core import phase_group from openhtf.core import test_executor -from openhtf.core import test_record +from openhtf.core import test_record as htf_test_record +from openhtf.core import test_state from openhtf.util import conf from openhtf.util import console_output @@ -50,19 +54,15 @@ _LOG = logging.getLogger(__name__) -conf.declare('capture_source', description=textwrap.dedent( - '''Whether to capture the source of phases and the test module. This +conf.declare( + 'capture_source', + description=textwrap.dedent( + """Whether to capture the source of phases and the test module. This defaults to False since this potentially reads many files and makes large string copies. - Set to 'true' if you want to capture your test's source.'''), - default_value=False) -# TODO(arsharma): Deprecate this configuration after removing the old teardown -# specification. 
-conf.declare('teardown_timeout_s', default_value=30, description= - 'Default timeout (in seconds) for test teardown functions; ' - 'this option is deprecated and only applies to the deprecated ' - 'Test level teardown function.') + Set to 'true' if you want to capture your test's source."""), + default_value=False) class UnrecognizedTestUidError(Exception): @@ -77,7 +77,7 @@ class InvalidTestStateError(Exception): """Raised when an operation is attempted in an invalid state.""" -def create_arg_parser(add_help=False): +def create_arg_parser(add_help: bool = False) -> argparse.ArgumentParser: """Creates an argparse.ArgumentParser for parsing command line flags. If you want to add arguments, create your own with this as a parent: @@ -103,7 +103,8 @@ def create_arg_parser(add_help=False): ], add_help=add_help) parser.add_argument( - '--config-help', action='store_true', + '--config-help', + action='store_true', help='Instead of executing the test, simply print all available config ' 'keys and their description strings.') return parser @@ -130,7 +131,8 @@ def PhaseTwo(test): HANDLED_SIGINT_ONCE = False DEFAULT_SIGINT_HANDLER = None - def __init__(self, *phases, **metadata): + def __init__(self, *nodes: phase_descriptor.PhaseCallableOrNodeT, + **metadata: Any): # Some sanity checks on special metadata keys we automatically fill in. if 'config' in metadata: raise KeyError( @@ -141,18 +143,18 @@ def __init__(self, *phases, **metadata): self._test_options = TestOptions() self._lock = threading.Lock() self._executor = None - self._test_desc = TestDescriptor( - phases, test_record.CodeInfo.uncaptured(), metadata) + # TODO(arsharma): Drop _flatten at some point. + sequence = phase_collections.PhaseSequence(nodes) + self._test_desc = TestDescriptor(sequence, + htf_test_record.CodeInfo.uncaptured(), + metadata) if conf.capture_source: - # First, we copy the phases with the real CodeInfo for them. 
- group = self._test_desc.phase_group.load_code_info() - - # Then we replace the TestDescriptor with one that stores the test - # module's CodeInfo as well as our newly copied phases. - code_info = test_record.CodeInfo.for_module_from_stack(levels_up=2) - self._test_desc = self._test_desc._replace( - code_info=code_info, phase_group=group) + # Copy the phases with the real CodeInfo for them. + self._test_desc.phase_sequence = ( + self._test_desc.phase_sequence.load_code_info()) + self._test_desc.code_info = ( + htf_test_record.CodeInfo.for_module_from_stack(levels_up=2)) # Make sure configure() gets called at least once before Execute(). The # user might call configure() again to override options, but we don't want @@ -165,7 +167,7 @@ def __init__(self, *phases, **metadata): self.configure() @classmethod - def from_uid(cls, test_uid): + def from_uid(cls, test_uid: Text) -> 'Test': """Get Test by UID. Args: @@ -183,11 +185,12 @@ def from_uid(cls, test_uid): return test @property - def uid(self): + def uid(self) -> Optional[Text]: if self._executor is not None: return self._executor.uid + return None - def make_uid(self): + def make_uid(self) -> Text: """Returns the next test execution's UID. 
This identifier must be unique but trackable across invocations of @@ -200,29 +203,32 @@ def make_uid(self): uuid.uuid4().hex[:16], util.time_millis()) @property - def descriptor(self): + def descriptor(self) -> 'TestDescriptor': """Static data about this test, does not change across Execute() calls.""" return self._test_desc @property - def state(self): + def state(self) -> Optional[test_state.TestState]: """Transient state info about the currently executing test, or None.""" with self._lock: if self._executor: return self._executor.test_state + return None - def get_option(self, option): + def get_option(self, option: Text) -> Any: return getattr(self._test_options, option) - def add_output_callbacks(self, *callbacks): + def add_output_callbacks( + self, *callbacks: Callable[[htf_test_record.TestRecord], None]) -> None: """Add the given function as an output module to this test.""" self._test_options.output_callbacks.extend(callbacks) - def add_test_diagnosers(self, *diagnosers): + def add_test_diagnosers(self, + *diagnosers: diagnoses_lib.BaseTestDiagnoser) -> None: diagnoses_lib.check_diagnosers(diagnosers, diagnoses_lib.BaseTestDiagnoser) self._test_options.diagnosers.extend(diagnosers) - def configure(self, **kwargs): + def configure(self, **kwargs: Any) -> None: """Update test-wide configuration options. See TestOptions for docs.""" # These internally ensure they are safe to call multiple times with no weird # side effects. 
@@ -235,9 +241,10 @@ def configure(self, **kwargs): setattr(self._test_options, key, value) @classmethod - def handle_sig_int(cls, signalnum, handler): + def handle_sig_int(cls, signalnum: Optional[int], handler: Any) -> None: + """Handle the SIGINT callback.""" if not cls.TEST_INSTANCES: - cls.DEFAULT_SIGINT_HANDLER(signalnum, handler) + cls.DEFAULT_SIGINT_HANDLER(signalnum, handler) # pylint: disable=not-callable return _LOG.error('Received SIGINT, stopping all tests.') @@ -249,7 +256,7 @@ def handle_sig_int(cls, signalnum, handler): # Otherwise, does not raise KeyboardInterrupt to ensure that the tests are # cleaned up. - def abort_from_sig_int(self): + def abort_from_sig_int(self) -> None: """Abort test execution abruptly, only in response to SIGINT.""" with self._lock: _LOG.error('Aborting %s due to SIGINT', self) @@ -259,30 +266,17 @@ def abort_from_sig_int(self): _LOG.error('Test state: %s', self._executor.test_state) self._executor.abort() - # TODO(arsharma): teardown_function test option is deprecated; remove this. - def _get_running_test_descriptor(self): - """If there is a teardown_function, wrap current descriptor with it.""" - if not self._test_options.teardown_function: - return self._test_desc - - teardown_phase = phase_descriptor.PhaseDescriptor.wrap_or_copy( - self._test_options.teardown_function) - if not teardown_phase.options.timeout_s: - teardown_phase.options.timeout_s = conf.teardown_timeout_s - return TestDescriptor( - phase_group.PhaseGroup(main=[self._test_desc.phase_group], - teardown=[teardown_phase]), - self._test_desc.code_info, self._test_desc.metadata) - - def execute(self, test_start=None, profile_filename=None): + def execute(self, + test_start: Optional[phase_descriptor.PhaseT] = None, + profile_filename: Optional[Text] = None) -> bool: """Starts the framework and executes the given test. Args: test_start: Either a trigger phase for starting the test, or a function - that returns a DUT ID. 
If neither is provided, defaults to not - setting the DUT ID. + that returns a DUT ID. If neither is provided, defaults to not setting + the DUT ID. profile_filename: Name of file to put profiling stats into. This also - enables profiling data collection. + enables profiling data collection. Returns: Boolean indicating whether the test failed (False) or passed (True). @@ -290,8 +284,11 @@ def execute(self, test_start=None, profile_filename=None): Raises: InvalidTestStateError: if this test is already being executed. """ - diagnoses_lib.check_for_duplicate_results(self._test_desc.phase_group, - self._test_options.diagnosers) + diagnoses_lib.check_for_duplicate_results( + self._test_desc.phase_sequence.all_phases(), + self._test_options.diagnosers) + phase_collections.check_for_duplicate_subtest_names( + self._test_desc.phase_sequence) # Lock this section so we don't .stop() the executor between instantiating # it and .Start()'ing it, doing so does weird things to the executor state. with self._lock: @@ -307,20 +304,21 @@ def execute(self, test_start=None, profile_filename=None): self._test_desc.metadata['config'] = conf._asdict() self.last_run_time_millis = util.time_millis() - if isinstance(test_start, LambdaType): + if isinstance(test_start, types.LambdaType): + @phase_descriptor.PhaseOptions() def trigger_phase(test): - test.test_record.dut_id = test_start() + test.test_record.dut_id = typing.cast(types.LambdaType, test_start)() + trigger = trigger_phase else: trigger = test_start if conf.capture_source: - trigger.code_info = test_record.CodeInfo.for_function(trigger.func) + trigger.code_info = htf_test_record.CodeInfo.for_function(trigger.func) - test_desc = self._get_running_test_descriptor() self._executor = test_executor.TestExecutor( - test_desc, + self._test_desc, self.make_uid(), trigger, self._test_options, @@ -343,59 +341,49 @@ def trigger_phase(test): _LOG.debug('Test completed for %s, outputting now.', final_state.test_record.metadata['test_name']) - 
test_executor.CombineProfileStats(self._executor.phase_profile_stats, - profile_filename) + test_executor.combine_profile_stats(self._executor.phase_profile_stats, + profile_filename) for output_cb in self._test_options.output_callbacks: try: output_cb(final_state.test_record) except Exception: # pylint: disable=broad-except stacktrace = traceback.format_exc() - _LOG.error( - 'Output callback %s raised:\n%s\nContinuing anyway...', - output_cb, stacktrace) + _LOG.error('Output callback %s raised:\n%s\nContinuing anyway...', + output_cb, stacktrace) # Make sure the final outcome of the test is printed last and in a # noticeable color so it doesn't get scrolled off the screen or missed. - if final_state.test_record.outcome == test_record.Outcome.ERROR: + if final_state.test_record.outcome == htf_test_record.Outcome.ERROR: for detail in final_state.test_record.outcome_details: console_output.error_print(detail.description) else: colors = collections.defaultdict(lambda: colorama.Style.BRIGHT) - colors[test_record.Outcome.PASS] = ''.join((colorama.Style.BRIGHT, - colorama.Fore.GREEN)) - colors[test_record.Outcome.FAIL] = ''.join((colorama.Style.BRIGHT, - colorama.Fore.RED)) + colors[htf_test_record.Outcome.PASS] = ''.join( + (colorama.Style.BRIGHT, colorama.Fore.GREEN)) # pytype: disable=wrong-arg-types + colors[htf_test_record.Outcome.FAIL] = ''.join( + (colorama.Style.BRIGHT, colorama.Fore.RED)) # pytype: disable=wrong-arg-types msg_template = 'test: {name} outcome: {color}{outcome}{rst}' - console_output.banner_print(msg_template.format( - name=final_state.test_record.metadata['test_name'], - color=colors[final_state.test_record.outcome], - outcome=final_state.test_record.outcome.name, - rst=colorama.Style.RESET_ALL)) + console_output.banner_print( + msg_template.format( + name=final_state.test_record.metadata['test_name'], + color=colors[final_state.test_record.outcome], + outcome=final_state.test_record.outcome.name, + rst=colorama.Style.RESET_ALL)) finally: del 
self.TEST_INSTANCES[self.uid] self._executor.close() self._executor = None - return final_state.test_record.outcome == test_record.Outcome.PASS + return final_state.test_record.outcome == htf_test_record.Outcome.PASS -# TODO(arsharma): Deprecate the teardown_function in favor of PhaseGroups. -class TestOptions(mutablerecords.Record('TestOptions', [], { - 'name': 'openhtf_test', - 'output_callbacks': list, - 'teardown_function': None, - 'failure_exceptions': list, - 'default_dut_id': 'UNKNOWN_DUT', - 'stop_on_first_failure': False, - 'diagnosers': list, -})): +@attr.s(slots=True) +class TestOptions(object): """Class encapsulating various tunable knobs for Tests and their defaults. name: The name of the test to be put into the metadata. output_callbacks: List of output callbacks to run, typically it's better to use add_output_callbacks(), but you can pass [] here to reset them. - teardown_function: Function to run at teardown. We pass the same arguments to - it as a phase. failure_exceptions: Exceptions to cause a test FAIL instead of ERROR. When a test run exits early due to an exception, the run will be marked as a FAIL if the raised exception matches one of the types in this list. Otherwise, @@ -407,108 +395,196 @@ class TestOptions(mutablerecords.Record('TestOptions', [], { phases. """ + name = attr.ib(type=Text, default='openhtf_test') + output_callbacks = attr.ib( + type=List[Callable[[htf_test_record.TestRecord], None]], factory=list) + failure_exceptions = attr.ib(type=List[Type[Exception]], factory=list) + default_dut_id = attr.ib(type=Text, default='UNKNOWN_DUT') + stop_on_first_failure = attr.ib(type=bool, default=False) + diagnosers = attr.ib(type=List[diagnoses_lib.BaseTestDiagnoser], factory=list) + -class TestDescriptor(collections.namedtuple( - 'TestDescriptor', ['phase_group', 'code_info', 'metadata', 'uid'])): +@attr.s(slots=True) +class TestDescriptor(object): """An object that represents the reusable portions of an OpenHTF test. 
This object encapsulates the static test information that is set once and used by the framework along the way. Attributes: - phase_group: The top level phase group to execute for this Test. + phase_sequence: The top level phase collection for this test. metadata: Any metadata that should be associated with test records. code_info: Information about the module that created the Test. uid: UID for this test. """ - def __new__(cls, phases, code_info, metadata): - group = phase_group.PhaseGroup.convert_if_not(phases) - return super(TestDescriptor, cls).__new__( - cls, group, code_info, metadata, uid=uuid.uuid4().hex[:16]) + phase_sequence = attr.ib(type=phase_collections.PhaseSequence) + code_info = attr.ib(type=htf_test_record.CodeInfo) + metadata = attr.ib(type=Dict[Text, Any]) + uid = attr.ib(type=Text, factory=lambda: uuid.uuid4().hex[:16]) @property - def plug_types(self): + def plug_types(self) -> Set[Type[base_plugs.BasePlug]]: """Returns set of plug types required by this test.""" - return {plug.cls - for phase in self.phase_group - for plug in phase.plugs} + ret = set() + for phase in self.phase_sequence.all_phases(): + for plug in phase.plugs: + ret.add(plug.cls) + return ret -class TestApi(collections.namedtuple('TestApi', [ - 'logger', 'state', 'test_record', 'measurements', 'attachments', - 'attach', 'attach_from_file', 'get_measurement', 'get_attachment', - 'notify_update'])): +@attr.s(slots=True) +class TestApi(object): """Class passed to test phases as the first argument. Attributes: - dut_id: This attribute provides getter and setter access to the DUT ID - of the device under test by the currently running openhtf.Test. A - non-empty DUT ID *must* be set by the end of a test, or no output - will be produced. It may be set via return value from a callable - test_start argument to openhtf.Test.Execute(), or may be set in a - test phase via this attribute. 
- + dut_id: This attribute provides getter and setter access to the DUT ID of + the device under test by the currently running openhtf.Test. A non-empty + DUT ID *must* be set by the end of a test, or no output will be produced. + It may be set via return value from a callable test_start argument to + openhtf.Test.Execute(), or may be set in a test phase via this attribute. logger: A Python Logger instance that can be used to log to the resulting - TestRecord. This object supports all the usual log levels, and - outputs to stdout (configurable) and the frontend via the Station - API, if it's enabled, in addition to the 'log_records' attribute - of the final TestRecord output by the running test. - - measurements: A measurements.Collection object used to get/set - measurement values. See util/measurements.py for more implementation - details, but in the simple case, set measurements directly as - attributes on this object (see examples/measurements.py for examples). - + TestRecord. This object supports all the usual log levels, and outputs to + stdout (configurable) and the frontend via the Station API, if it's + enabled, in addition to the 'log_records' attribute of the final + TestRecord output by the running test. + measurements: A measurements.Collection object used to get/set measurement + values. See util/measurements.py for more implementation details, but in + the simple case, set measurements directly as attributes on this object + (see examples/measurements.py for examples). + attachments: Dict mapping attachment name to test_record.Attachment instance + containing the data that was attached (and the MIME type that was assumed + based on extension, if any). Only attachments that have been attached in + the current phase show up here, and this attribute should not be modified + directly; use TestApi.attach() or TestApi.attach_from_file() instead; read + only. 
state: A dict (initially empty) that is persisted across test phases (but - resets for every invocation of Execute() on an openhtf.Test). This - can be used for any test-wide state you need to persist across phases. - Use this with caution, however, as it is not persisted in the output - TestRecord or displayed on the web frontend in any way. - - test_record: A reference to the output TestRecord for the currently - running openhtf.Test. Direct access to this attribute is *strongly* - discouraged, but provided as a catch-all for interfaces not otherwise - provided by TestApi. If you find yourself using this, please file a + resets for every invocation of Execute() on an openhtf.Test). This can be + used for any test-wide state you need to persist across phases. Use this + with caution, however, as it is not persisted in the output TestRecord or + displayed on the web frontend in any way. + diagnoses_store: The diagnoses storage and lookup instance for this test. + test_record: A reference to the output TestRecord for the currently running + openhtf.Test. Direct access to this attribute is *strongly* discouraged, + but provided as a catch-all for interfaces not otherwise provided by + TestApi. If you find yourself using this, please file a feature request for an alternative at: https://github.com/google/openhtf/issues/new + """ - Callable Attributes: - attach: Attach binary data to the test, see TestState.attach(). - - attach_from_file: Attach binary data from a file, see - TestState.attach_from_file(). - - get_attachment: Get copy of attachment contents from current or previous - phase, see TestState.get_attachement. - - get_measurement: Get copy of a measurement from a current or previous phase, - see TestState.get_measurement(). - - notify_update: Notify any frontends of an interesting update. 
Typically - this is automatically called internally when interesting things happen, - but it can be called by the user (takes no args), for instance if - modifying test_record directly. - - + measurements = attr.ib(type=measurements.Collection) - Read-only Attributes: - attachments: Dict mapping attachment name to test_record.Attachment - instance containing the data that was attached (and the MIME type - that was assumed based on extension, if any). Only attachments - that have been attached in the current phase show up here, and this - attribute should not be modified directly; use TestApi.attach() or - TestApi.attach_from_file() instead. - """ + # Internal state objects. If you find yourself needing to use these, please + # use required_state=True for the phase to use the test_state object instead. + _running_phase_state = attr.ib(type=test_state.PhaseState) + _running_test_state = attr.ib(type=test_state.TestState) @property - def dut_id(self): + def dut_id(self) -> Text: return self.test_record.dut_id @dut_id.setter - def dut_id(self, dut_id): + def dut_id(self, dut_id: Text) -> None: if self.test_record.dut_id: self.logger.warning('Overriding previous DUT ID "%s" with "%s".', self.test_record.dut_id, dut_id) self.test_record.dut_id = dut_id self.notify_update() + + @property + def logger(self) -> logging.Logger: + return self._running_phase_state.logger + + # TODO(arsharma): Change to Dict[Any, Any] when pytype handles it correctly. 
+ @property + def state(self) -> Any: + return self._running_test_state.user_defined_state + + @property + def test_record(self) -> htf_test_record.TestRecord: + return self._running_test_state.test_record + + @property + def attachments(self) -> Dict[Text, htf_test_record.Attachment]: + return self._running_phase_state.attachments + + def attach( + self, + name: Text, + binary_data: Union[Text, bytes], + mimetype: test_state.MimetypeT = test_state.INFER_MIMETYPE) -> None: + """Store the given binary_data as an attachment with the given name. + + Args: + name: Attachment name under which to store this binary_data. + binary_data: Data to attach. + mimetype: One of the following: INFER_MIMETYPE - The type will be guessed + from the attachment name. None - The type will be left unspecified. A + string - The type will be set to the specified value. + + Raises: + DuplicateAttachmentError: Raised if there is already an attachment with + the given name. + """ + self._running_phase_state.attach(name, binary_data, mimetype=mimetype) + + def attach_from_file( + self, + filename: Text, + name: Optional[Text] = None, + mimetype: test_state.MimetypeT = test_state.INFER_MIMETYPE) -> None: + """Store the contents of the given filename as an attachment. + + Args: + filename: The file to read data from to attach. + name: If provided, override the attachment name, otherwise it will default + to the filename. + mimetype: One of the following: + * INFER_MIMETYPE: The type will be guessed first, from the file name, + and second (i.e. as a fallback), from the attachment name. + * None: The type will be left unspecified. + * A string: The type will be set to the specified value. + + Raises: + DuplicateAttachmentError: Raised if there is already an attachment with + the given name. + IOError: Raised if the given filename couldn't be opened. 
+ """ + self._running_phase_state.attach_from_file( + filename, name=name, mimetype=mimetype) + + def get_measurement( + self, + measurement_name: Text) -> Optional[test_state.ImmutableMeasurement]: + """Get a copy of a measurement value from current or previous phase. + + Measurement and phase name uniqueness is not enforced, so this method will + return an immutable copy of the most recent measurement recorded. + + Args: + measurement_name: str of the measurement name + + Returns: + an ImmutableMeasurement or None if the measurement cannot be found. + """ + return self._running_test_state.get_measurement(measurement_name) + + def get_attachment( + self, attachment_name: Text) -> Optional[htf_test_record.Attachment]: + """Get a copy of an attachment from current or previous phases. + + Args: + attachment_name: str of the attachment name + + Returns: + A copy of the attachment or None if the attachment cannot be found. + """ + return self._running_test_state.get_attachment(attachment_name) + + def notify_update(self) -> None: + """Notify any update events that there was an update.""" + self._running_test_state.notify_update() + + @property + def diagnoses_store(self) -> diagnoses_lib.DiagnosesStore: + return self._running_test_state.diagnoses_manager.store diff --git a/openhtf/core/test_executor.py b/openhtf/core/test_executor.py index 91fe53d75..2a4ca4b6e 100644 --- a/openhtf/core/test_executor.py +++ b/openhtf/core/test_executor.py @@ -11,34 +11,48 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- """TestExecutor executes tests.""" +import contextlib +import enum import logging import pstats import sys import tempfile import threading import traceback +from typing import Iterator, List, Optional, Text, Type, TYPE_CHECKING +from openhtf import util +from openhtf.core import base_plugs +from openhtf.core import diagnoses_lib +from openhtf.core import phase_branches +from openhtf.core import phase_collections from openhtf.core import phase_descriptor from openhtf.core import phase_executor from openhtf.core import phase_group +from openhtf.core import phase_nodes from openhtf.core import test_record from openhtf.core import test_state from openhtf.util import conf from openhtf.util import threads +if TYPE_CHECKING: + from openhtf.core import test_descriptor # pylint: disable=g-import-not-at-top _LOG = logging.getLogger(__name__) -conf.declare('cancel_timeout_s', default_value=2, - description='Timeout (in seconds) when the test has been cancelled' - 'to wait for the running phase to exit.') +conf.declare( + 'cancel_timeout_s', + default_value=2, + description='Timeout (in seconds) when the test has been cancelled' + 'to wait for the running phase to exit.') -conf.declare('stop_on_first_failure', default_value=False, - description='Stop current test execution and return Outcome FAIL' - 'on first phase with failed measurement.') +conf.declare( + 'stop_on_first_failure', + default_value=False, + description='Stop current test execution and return Outcome FAIL' + 'on first phase with failed measurement.') class TestExecutionError(Exception): @@ -49,7 +63,17 @@ class TestStopError(Exception): """Test is being stopped.""" -def CombineProfileStats(profile_stats_iter, output_filename): +class _ExecutorReturn(enum.Enum): + CONTINUE = 0 + TERMINAL = 1 + + +def _more_critical(e1: _ExecutorReturn, e2: _ExecutorReturn) -> _ExecutorReturn: + return _ExecutorReturn(max(e1.value, e2.value)) + + +def combine_profile_stats(profile_stats_iter: List[pstats.Stats], + 
output_filename: Text) -> None: """Given an iterable of pstats.Stats, combine them into a single Stats.""" profile_stats_filenames = [] for profile_stats in profile_stats_iter: @@ -66,34 +90,40 @@ class TestExecutor(threads.KillableThread): """Encompasses the execution of a single test.""" daemon = True - def __init__(self, - test_descriptor, - execution_uid, - test_start, - test_options, - run_with_profiling): + def __init__(self, test_descriptor: 'test_descriptor.TestDescriptor', + execution_uid: Text, + test_start: Optional[phase_descriptor.PhaseDescriptor], + test_options: 'test_descriptor.TestOptions', + run_with_profiling: bool): super(TestExecutor, self).__init__( name='TestExecutorThread', run_with_profiling=run_with_profiling) - self.test_state = None + self.test_state = None # type: Optional[test_state.TestState] self._test_descriptor = test_descriptor self._test_start = test_start self._test_options = test_options self._lock = threading.Lock() - self._phase_exec = None + self._phase_exec = None # type: Optional[phase_executor.PhaseExecutor] self.uid = execution_uid - self._last_outcome = None + self._last_outcome = None # type: Optional[phase_executor.PhaseExecutionOutcome] self._abort = threading.Event() self._full_abort = threading.Event() - self._teardown_phases_lock = threading.Lock() - self._phase_profile_stats = [] # Populated if profiling is enabled. + # This is a reentrant lock so that the teardown logic that prevents aborts + # affects nested sequences. + self._teardown_phases_lock = threading.RLock() + # Populated if profiling is enabled. 
+ self._phase_profile_stats = [] # type: List[pstats.Stats] @property - def phase_profile_stats(self): + def logger(self) -> logging.Logger: + return self.test_state.state_logger + + @property + def phase_profile_stats(self) -> List[pstats.Stats]: """Returns iterable of profiling Stats objects, per phase.""" return self._phase_profile_stats - def close(self): + def close(self) -> None: """Close and remove any global registrations. Always call this function when finished with this instance. @@ -104,7 +134,7 @@ def close(self): self.wait() self.test_state.close() - def abort(self): + def abort(self) -> None: """Abort this test.""" if self._abort.is_set(): _LOG.error('Abort already set; forcibly stopping the process.') @@ -118,14 +148,14 @@ def abort(self): # No need to kill this thread because the abort state has been set, it will # end as soon as all queued teardown phases are run. - def finalize(self): + def finalize(self) -> test_state.TestState: """Finalize test execution and output resulting record to callbacks. Should only be called once at the conclusion of a test run, and will raise an exception if end_time_millis is already set. Returns: - Finalized TestState. It should not be modified after this call. + Finalized TestState. It must not be modified after this call. Raises: TestStopError: test @@ -139,7 +169,7 @@ def finalize(self): return self.test_state - def wait(self): + def wait(self) -> None: """Waits until death.""" # Must use a timeout here in case this is called from the main thread. # Otherwise, the SIGINT abort logic in test_descriptor will not get called. @@ -147,17 +177,15 @@ def wait(self): if sys.version_info >= (3, 2): # TIMEOUT_MAX can be too large and cause overflows on 32-bit OSes, so take # whichever timeout is shorter. 
- timeout = min(threading.TIMEOUT_MAX, timeout) + timeout = min(threading.TIMEOUT_MAX, timeout) # pytype: disable=module-attr self.join(timeout) - def _thread_proc(self): + def _thread_proc(self) -> None: """Handles one whole test from start to finish.""" try: # Top level steps required to run a single iteration of the Test. - self.test_state = test_state.TestState( - self._test_descriptor, - self.uid, - self._test_options) + self.test_state = test_state.TestState(self._test_descriptor, self.uid, + self._test_options) phase_exec = phase_executor.PhaseExecutor(self.test_state) # Any access to self._exit_stacks must be done while holding this lock. @@ -177,7 +205,7 @@ def _thread_proc(self): # Everything is set, set status and begin test execution. self.test_state.set_status_running() - self._execute_phase_group(self._test_descriptor.phase_group) + self._execute_node(self._test_descriptor.phase_sequence, None, False) self._execute_test_diagnosers() except: # pylint: disable=bare-except stacktrace = traceback.format_exc() @@ -186,7 +214,9 @@ def _thread_proc(self): finally: self._execute_test_teardown() - def _initialize_plugs(self, plug_types=None): + def _initialize_plugs( + self, + plug_types: Optional[List[Type[base_plugs.BasePlug]]] = None) -> bool: """Initialize plugs. Args: @@ -204,7 +234,7 @@ def _initialize_plugs(self, plug_types=None): phase_executor.ExceptionInfo(*sys.exc_info())) return True - def _execute_test_start(self): + def _execute_test_start(self) -> bool: """Run the start trigger phase, and check that the DUT ID is set after. Initializes any plugs used in the trigger. @@ -219,8 +249,8 @@ def _execute_test_start(self): """ # Have the phase executor run the start trigger phase. Do partial plug # initialization for just the plugs needed by the start trigger phase. 
- if self._initialize_plugs(plug_types=[ - phase_plug.cls for phase_plug in self._test_start.plugs]): + if self._initialize_plugs( + plug_types=[phase_plug.cls for phase_plug in self._test_start.plugs]): return True outcome, profile_stats = self._phase_exec.execute_phase( @@ -237,7 +267,7 @@ def _execute_test_start(self): _LOG.warning('Start trigger did not set a DUT ID.') return False - def _stop_phase_executor(self, force=False): + def _stop_phase_executor(self, force: bool = False) -> None: with self._lock: phase_exec = self._phase_exec if not phase_exec: @@ -254,28 +284,36 @@ def _stop_phase_executor(self, force=False): if not force: self._teardown_phases_lock.release() - # TODO(kschiller): Cleanup the naming here and possibly merge with finalize. - def _execute_test_teardown(self): + def _execute_test_teardown(self) -> None: # Plug teardown does not affect the test outcome. self.test_state.plug_manager.tear_down_plugs() # Now finalize the test state. if self._abort.is_set(): - self.test_state.state_logger.debug('Finishing test with outcome ABORTED.') + self.logger.debug('Finishing test with outcome ABORTED.') self.test_state.abort() elif self._last_outcome and self._last_outcome.is_terminal: self.test_state.finalize_from_phase_outcome(self._last_outcome) else: self.test_state.finalize_normally() - def _handle_phase(self, phase): - if isinstance(phase, phase_group.PhaseGroup): - return self._execute_phase_group(phase) + def _execute_phase(self, phase: phase_descriptor.PhaseDescriptor, + subtest_rec: Optional[test_record.SubtestRecord], + in_teardown: bool) -> _ExecutorReturn: + if subtest_rec: + self.logger.debug('Executing phase %s under subtest %s', phase.name, + subtest_rec.name) + else: + self.logger.debug('Executing phase %s', phase.name) - self.test_state.state_logger.debug('Handling phase %s', phase.name) - outcome, profile_stats = self._phase_exec.execute_phase( - phase, self._run_with_profiling) + if not in_teardown and subtest_rec and 
subtest_rec.is_fail: + self._phase_exec.skip_phase(phase, subtest_rec) + return _ExecutorReturn.CONTINUE + outcome, profile_stats = self._phase_exec.execute_phase( + phase, + run_with_profiling=self._run_with_profiling, + subtest_rec=subtest_rec) if profile_stats is not None: self._phase_profile_stats.append(profile_stats) @@ -287,59 +325,189 @@ def _handle_phase(self, phase): if current_phase_result.outcome == test_record.PhaseOutcome.FAIL: outcome = phase_executor.PhaseExecutionOutcome( phase_descriptor.PhaseResult.STOP) - self.test_state.state_logger.error( - 'Stopping test because stop_on_first_failure is True') + self.logger.error('Stopping test because stop_on_first_failure is True') - if outcome.is_terminal and not self._last_outcome: - self._last_outcome = outcome + if outcome.is_terminal: + if not self._last_outcome: + self._last_outcome = outcome + return _ExecutorReturn.TERMINAL + + if outcome.is_fail_subtest: + if not subtest_rec: + raise TestExecutionError( + 'INVALID STATE: Phase returned outcome FAIL_SUBTEST when not ' + 'in subtest.') + subtest_rec.outcome = test_record.SubtestOutcome.FAIL + return _ExecutorReturn.CONTINUE + + def _execute_checkpoint(self, checkpoint: phase_branches.Checkpoint, + subtest_rec: Optional[test_record.SubtestRecord], + in_teardown: bool) -> _ExecutorReturn: + if not in_teardown and subtest_rec and subtest_rec.is_fail: + self._phase_exec.skip_checkpoint(checkpoint, subtest_rec) + return _ExecutorReturn.CONTINUE + + outcome = self._phase_exec.evaluate_checkpoint(checkpoint, subtest_rec) + if outcome.is_terminal: + if not self._last_outcome: + self._last_outcome = outcome + return _ExecutorReturn.TERMINAL + + if outcome.is_fail_subtest: + if not subtest_rec: + raise TestExecutionError( + 'INVALID STATE: Phase returned outcome FAIL_SUBTEST when not ' + 'in subtest.') + subtest_rec.outcome = test_record.SubtestOutcome.FAIL + return _ExecutorReturn.CONTINUE + + def _log_sequence(self, phase_sequence, override_message): + 
message = phase_sequence.name + if override_message: + message = override_message + if message: + self.logger.debug('Executing phase nodes for %s', message) + + def _execute_sequence( + self, + phase_sequence: phase_collections.PhaseSequence, + subtest_rec: Optional[test_record.SubtestRecord], + in_teardown: bool, + override_message: Optional[Text] = None) -> _ExecutorReturn: + """Execute phase sequence. - return outcome.is_terminal + Args: + phase_sequence: Sequence of phase nodes to run. + subtest_rec: Current subtest record, if any. + in_teardown: Indicates if currently processing a teardown sequence. + override_message: Optional message to override when logging. - def _execute_abortable_phases(self, type_name, phases, group_name): - """Execute phases, returning immediately if any error or abort is triggered. + Returns: + _ExecutorReturn for how to proceed. + """ + self._log_sequence(phase_sequence, override_message) + + if in_teardown: + return self._execute_teardown_sequence(phase_sequence, subtest_rec) + else: + return self._execute_abortable_sequence(phase_sequence, subtest_rec) + + def _execute_abortable_sequence( + self, phase_sequence: phase_collections.PhaseSequence, + subtest_rec: Optional[test_record.SubtestRecord]) -> _ExecutorReturn: + """Execute phase sequence, returning immediately on error or test abort. Args: - type_name: str, type of phases running, usually 'Setup' or 'Main'. - phases: iterable of phase_descriptor.Phase or phase_group.PhaseGroup - instances, the phases to execute. - group_name: str or None, name of the executing group. + phase_sequence: Sequence of phase nodes to run. + subtest_rec: Current subtest record, if any. Returns: - True if there is a terminal error or the test is aborted, False otherwise. + _ExecutorReturn for how to proceed. 
""" - if group_name and phases: - self.test_state.state_logger.debug( - 'Executing %s phases for %s', type_name, group_name) - for phase in phases: - if self._abort.is_set() or self._handle_phase(phase): - return True - return False - - def _execute_teardown_phases(self, teardown_phases, group_name): + for node in phase_sequence.nodes: + if self._abort.is_set(): + return _ExecutorReturn.TERMINAL + exe_ret = self._execute_node(node, subtest_rec, False) + if exe_ret != _ExecutorReturn.CONTINUE: + return exe_ret + return _ExecutorReturn.CONTINUE + + def _execute_teardown_sequence( + self, phase_sequence: phase_collections.PhaseSequence, + subtest_rec: Optional[test_record.SubtestRecord]) -> _ExecutorReturn: """Execute all the teardown phases, regardless of errors. Args: - teardown_phases: iterable of phase_descriptor.Phase or - phase_group.PhaseGroup instances, the phases to execute. - group_name: str or None, name of the executing group. + phase_sequence: Sequence of phase nodes to run. + subtest_rec: Current subtest record, if any. Returns: - True if there is at least one terminal error, False otherwise. + _ExecutorReturn for how to proceed. """ - if group_name and teardown_phases: - self.test_state.state_logger.debug('Executing teardown phases for %s', - group_name) - ret = False + ret = _ExecutorReturn.CONTINUE with self._teardown_phases_lock: - for teardown_phase in teardown_phases: + for node in phase_sequence.nodes: if self._full_abort.is_set(): - ret = True - break - if self._handle_phase(teardown_phase): - ret = True + return _ExecutorReturn.TERMINAL + ret = _more_critical(ret, self._execute_node(node, subtest_rec, True)) + return ret - def _execute_phase_group(self, group): + @contextlib.contextmanager + def _subtest_context( + self, subtest: phase_collections.Subtest + ) -> Iterator[test_record.SubtestRecord]: + """Enter a subtest context. + + This context tracks the subname and sets up the subtest record to track the + timing. 
+ + Args: + subtest: The subtest running during the context. + + Yields: + The subtest record for updating the outcome. + """ + self.logger.debug('%s: Starting subtest.', subtest.name) + subtest_rec = test_record.SubtestRecord( + name=subtest.name, + start_time_millis=util.time_millis(), + outcome=test_record.SubtestOutcome.PASS) + yield subtest_rec + subtest_rec.end_time_millis = util.time_millis() + self.test_state.test_record.add_subtest_record(subtest_rec) + + def _execute_subtest(self, subtest: phase_collections.Subtest, + outer_subtest_rec: Optional[test_record.SubtestRecord], + in_teardown: bool) -> _ExecutorReturn: + """Run a subtest node.""" + with self._subtest_context(subtest) as subtest_rec: + if outer_subtest_rec and outer_subtest_rec.is_fail: + subtest_rec.outcome = test_record.SubtestOutcome.FAIL + + ret = self._execute_sequence(subtest, subtest_rec, in_teardown) + + if ret == _ExecutorReturn.TERMINAL: + subtest_rec.outcome = test_record.SubtestOutcome.STOP + self.logger.debug('%s: Subtest stopping the test.', subtest.name) + else: + if subtest_rec.outcome is test_record.SubtestOutcome.FAIL: + self.logger.debug('%s: Subtest failed;', subtest.name) + else: + self.logger.debug('%s: Subtest passed.', subtest.name) + return ret + + def _execute_phase_branch(self, branch: phase_branches.BranchSequence, + subtest_rec: Optional[test_record.SubtestRecord], + in_teardown: bool) -> _ExecutorReturn: + branch_message = branch.diag_condition.message + if branch.name: + branch_message = '{}:{}'.format(branch.name, branch_message) + if not in_teardown and subtest_rec and subtest_rec.is_fail: + self.logger.debug('%s: Branch not being run due to failed subtest.', + branch_message) + return _ExecutorReturn.CONTINUE + + evaluated_millis = util.time_millis() + if branch.should_run(self.test_state.diagnoses_manager.store): + self.logger.debug('%s: Branch condition met; running phases.', + branch_message) + branch_taken = True + ret = self._execute_sequence(branch, 
subtest_rec, in_teardown) + else: + self.logger.debug('%s: Branch condition NOT met; not running sequence.', + branch_message) + branch_taken = False + ret = _ExecutorReturn.CONTINUE + + branch_rec = test_record.BranchRecord.from_branch(branch, branch_taken, + evaluated_millis) + self.test_state.test_record.add_branch_record(branch_rec) + return ret + + def _execute_phase_group(self, group: phase_group.PhaseGroup, + subtest_rec: Optional[test_record.SubtestRecord], + in_teardown: bool) -> _ExecutorReturn: """Executes the phases in a phase group. This will run the phases in the phase group, ensuring if the setup @@ -351,28 +519,77 @@ def _execute_phase_group(self, group): Args: group: phase_group.PhaseGroup, the phase group to execute. + subtest_rec: Current subtest record, if any. + in_teardown: Indicates if currently processing a teardown sequence. Returns: True if the phases are terminal; otherwise returns False. """ + message_prefix = '' if group.name: - self.test_state.state_logger.debug('Entering PhaseGroup %s', group.name) - if self._execute_abortable_phases( - 'setup', group.setup, group.name): - return True - main_ret = self._execute_abortable_phases( - 'main', group.main, group.name) - teardown_ret = self._execute_teardown_phases( - group.teardown, group.name) - return main_ret or teardown_ret - - def _execute_test_diagnoser(self, diagnoser): + self.logger.debug('Entering PhaseGroup %s', group.name) + message_prefix = group.name + ':' + # If in a subtest and it is already failing, the group will not be entered, + # so the teardown phases will need to be skipped. 
+ skip_teardown = subtest_rec is not None and subtest_rec.is_fail + if group.setup: + setup_ret = self._execute_sequence( + group.setup, + subtest_rec, + in_teardown, + override_message=message_prefix + 'setup') + if setup_ret != _ExecutorReturn.CONTINUE: + return setup_ret + if not skip_teardown: + # If the subtest fails during the setup, the group is still not entered, + # so skip the teardown phases here as well. + skip_teardown = (subtest_rec is not None and subtest_rec.is_fail) + if group.main: + main_ret = self._execute_sequence( + group.main, + subtest_rec, + in_teardown, + override_message=message_prefix + 'main') + else: + main_ret = _ExecutorReturn.CONTINUE + if group.teardown: + teardown_ret = self._execute_sequence( + group.teardown, + subtest_rec, + # If the subtest is already failing, record skips during the teardown + # sequence. + not skip_teardown, + override_message=message_prefix + 'teardown') + else: + teardown_ret = _ExecutorReturn.CONTINUE + return _more_critical(main_ret, teardown_ret) + + def _execute_node(self, node: phase_nodes.PhaseNode, + subtest_rec: Optional[test_record.SubtestRecord], + in_teardown: bool) -> _ExecutorReturn: + if isinstance(node, phase_collections.Subtest): + return self._execute_subtest(node, subtest_rec, in_teardown) + if isinstance(node, phase_branches.BranchSequence): + return self._execute_phase_branch(node, subtest_rec, in_teardown) + if isinstance(node, phase_collections.PhaseSequence): + return self._execute_sequence(node, subtest_rec, in_teardown) + if isinstance(node, phase_group.PhaseGroup): + return self._execute_phase_group(node, subtest_rec, in_teardown) + if isinstance(node, phase_descriptor.PhaseDescriptor): + return self._execute_phase(node, subtest_rec, in_teardown) + if isinstance(node, phase_branches.Checkpoint): + return self._execute_checkpoint(node, subtest_rec, in_teardown) + self.logger.error('Unhandled node type: %s', node) + return _ExecutorReturn.TERMINAL + + def _execute_test_diagnoser( + 
self, diagnoser: diagnoses_lib.BaseTestDiagnoser) -> None: try: self.test_state.diagnoses_manager.execute_test_diagnoser( diagnoser, self.test_state.test_record) except Exception: # pylint: disable=broad-except if self._last_outcome and self._last_outcome.is_terminal: - self.test_state.state_logger.exception( + self.logger.exception( 'Test Diagnoser %s raised an exception, but the test outcome is ' 'already terminal; logging additional exception here.', diagnoser.name) @@ -381,6 +598,6 @@ def _execute_test_diagnoser(self, diagnoser): self._last_outcome = phase_executor.PhaseExecutionOutcome( phase_executor.ExceptionInfo(*sys.exc_info())) - def _execute_test_diagnosers(self): + def _execute_test_diagnosers(self) -> None: for diagnoser in self._test_options.diagnosers: self._execute_test_diagnoser(diagnoser) diff --git a/openhtf/core/test_record.py b/openhtf/core/test_record.py index dfcb13558..8bb768f21 100644 --- a/openhtf/core/test_record.py +++ b/openhtf/core/test_record.py @@ -11,20 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- - """OpenHTF module responsible for managing records of tests.""" -import collections import hashlib import inspect import logging import os import tempfile +from typing import Any, Dict, List, Optional, Text, TYPE_CHECKING, Union -from enum import Enum - -import mutablerecords +import attr +import enum # pylint: disable=g-bad-import-order from openhtf import util from openhtf.util import conf @@ -33,26 +30,36 @@ import six +if TYPE_CHECKING: + from openhtf.core import diagnoses_lib # pylint: disable=g-import-not-at-top + from openhtf.core import measurements as htf_measurements # pylint: disable=g-import-not-at-top + from openhtf.core import phase_descriptor # pylint: disable=g-import-not-at-top + from openhtf.core import phase_executor # pylint: disable=g-import-not-at-top + from openhtf.core import phase_branches # pylint: disable=g-import-not-at-top conf.declare( 'attachments_directory', default_value=None, description='Directory where temprorary files can be safely stored.') - _LOG = logging.getLogger(__name__) -class InvalidMeasurementDimensions(Exception): - """Raised when a measurement is taken with the wrong number of dimensions.""" +@attr.s(slots=True, frozen=True) +class OutcomeDetails(object): + code = attr.ib(type=Union[Text, int]) + description = attr.ib(type=Text) -OutcomeDetails = collections.namedtuple( - 'OutcomeDetails', 'code description') -Outcome = Enum('Outcome', ['PASS', 'FAIL', 'ERROR', 'TIMEOUT', 'ABORTED']) # pylint: disable=invalid-name -# LogRecord is in openhtf.util.logs.LogRecord. +class Outcome(enum.Enum): + PASS = 'PASS' + FAIL = 'FAIL' + ERROR = 'ERROR' + TIMEOUT = 'TIMEOUT' + ABORTED = 'ABORTED' +@attr.s(slots=True, init=False) class Attachment(object): """Encapsulate attachment data and guessed MIME type. @@ -63,67 +70,127 @@ class Attachment(object): Attributes: mimetype: str, MIME type of the data. sha1: str, SHA-1 hash of the data. - _file: pointer temporary File containing the data. 
+ _file: Temporary File containing the data. + data: property that reads the data from the temporary file. """ - __slots__ = ['mimetype', 'sha1', '_file'] + mimetype = attr.ib(type=Text) + sha1 = attr.ib(type=Text) + _filename = attr.ib(type=Text) - def __init__(self, data, mimetype): - data = six.ensure_binary(data) + def __init__(self, contents: Union[Text, bytes], mimetype: Text): + contents = six.ensure_binary(contents) self.mimetype = mimetype - self.sha1 = hashlib.sha1(data).hexdigest() - self._file = self._create_temp_file(data) + self.sha1 = hashlib.sha1(contents).hexdigest() + self._filename = self._create_temp_file(contents) + + def __del__(self): + self.close() - def _create_temp_file(self, data): - tf = tempfile.NamedTemporaryFile('wb+', dir=conf.attachments_directory) - tf.write(data) - tf.flush() - return tf + def _create_temp_file(self, contents: bytes) -> Text: + with tempfile.NamedTemporaryFile( + 'w+b', dir=conf.attachments_directory, delete=False) as tf: + tf.write(contents) + return tf.name @property - def data(self): - self._file.seek(0) - return self._file.read() + def data(self) -> bytes: + with open(self._filename, 'rb') as contents: + return contents.read() - def _asdict(self): + def close(self): + if not self._filename: + return + os.remove(self._filename) + self._filename = None + + def _asdict(self) -> Dict[Text, Any]: # Don't include the attachment data when converting to dict. return { 'mimetype': self.mimetype, 'sha1': self.sha1, } - def __copy__(self): + def __copy__(self) -> 'Attachment': return Attachment(self.data, self.mimetype) - def __deepcopy__(self, memo): - return Attachment(self.data, self.mimetype) + def __deepcopy__(self, memo) -> 'Attachment': + del memo # Unused. + return self.__copy__() + + +def _get_source_safely(obj: Any) -> Text: + try: + return inspect.getsource(obj) + except Exception: # pylint: disable=broad-except + logs.log_once(_LOG.warning, + 'Unable to load source code for %s. 
Only logging this once.', + obj) + return '' + + +@attr.s(slots=True, frozen=True, hash=True) +class CodeInfo(object): + """Information regarding the running tester code.""" + + name = attr.ib(type=Text) + docstring = attr.ib(type=Optional[Text]) + sourcecode = attr.ib(type=Text) + + @classmethod + def for_module_from_stack(cls, levels_up: int = 1) -> 'CodeInfo': + # levels_up is how many frames up to go: + # 0: This function (useless). + # 1: The function calling this (likely). + # 2+: The function calling 'you' (likely in the framework). + frame, filename = inspect.stack(context=0)[levels_up][:2] + module = inspect.getmodule(frame) + source = _get_source_safely(frame) + return cls(os.path.basename(filename), inspect.getdoc(module), source) + + @classmethod + def for_function(cls, func: Any) -> 'CodeInfo': + source = _get_source_safely(func) + return cls(func.__name__, inspect.getdoc(func), source) + + @classmethod + def uncaptured(cls) -> 'CodeInfo': + return cls('', None, '') -class TestRecord( # pylint: disable=no-init - mutablerecords.Record( - 'TestRecord', ['dut_id', 'station_id'], - { - 'start_time_millis': int, - 'end_time_millis': None, - 'outcome': None, - 'outcome_details': list, - 'code_info': None, - 'metadata': dict, - 'phases': list, - 'diagnosers': list, - 'diagnoses': list, - 'log_records': list, - '_cached_record': dict, - '_cached_phases': list, - '_cached_diagnosers': list, - '_cached_diagnoses': list, - '_cached_log_records': list, - '_cached_config_from_metadata': dict, - })): +@attr.s(slots=True) +class TestRecord(object): """The record of a single run of a test.""" - def __init__(self, *args, **kwargs): - super(TestRecord, self).__init__(*args, **kwargs) + dut_id = attr.ib(type=Optional[Text]) + station_id = attr.ib(type=Text) + start_time_millis = attr.ib(type=int, default=0) + end_time_millis = attr.ib(type=Optional[int], default=None) + outcome = attr.ib(type=Optional[Outcome], default=None) + outcome_details = 
attr.ib(type=List[OutcomeDetails], factory=list) + code_info = attr.ib(type=CodeInfo, factory=CodeInfo.uncaptured) + metadata = attr.ib(type=Dict[Text, Any], factory=dict) + phases = attr.ib(type=List['PhaseRecord'], factory=list) + subtests = attr.ib(type=List['SubtestRecord'], factory=list) + branches = attr.ib(type=List['BranchRecord'], factory=list) + checkpoints = attr.ib(type=List['CheckpointRecord'], factory=list) + diagnosers = attr.ib( + type=List['diagnoses_lib.BaseTestDiagnoser'], factory=list) + diagnoses = attr.ib(type=List['diagnoses_lib.Diagnosis'], factory=list) + log_records = attr.ib(type=List[logs.LogRecord], factory=list) + + # Cache fields to reduce repeated base type conversions. + _cached_record = attr.ib(type=Dict[Text, Any], factory=dict) + _cached_phases = attr.ib(type=List[Dict[Text, Any]], factory=list) + _cached_subtests = attr.ib(type=List[Dict[Text, Any]], factory=list) + _cached_branches = attr.ib(type=List[Dict[Text, Any]], factory=list) + _cached_checkpoints = attr.ib(type=List[Dict[Text, Any]], factory=list) + _cached_diagnosers = attr.ib(type=List[Dict[Text, Any]], factory=list) + _cached_diagnoses = attr.ib(type=List[Dict[Text, Any]], factory=list) + _cached_log_records = attr.ib(type=List[Dict[Text, Any]], factory=list) + _cached_config_from_metadata = attr.ib(type=Dict[Text, Any], factory=dict) + + def __attrs_post_init__(self) -> None: # Cache data that does not change during execution. # Cache the metadata config so it does not recursively copied over and over # again. @@ -134,7 +201,9 @@ def __init__(self, *args, **kwargs): } self._cached_diagnosers = data.convert_to_base_types(self.diagnosers) - def add_outcome_details(self, code, description=''): + def add_outcome_details(self, + code: Union[int, Text], + description: Text = '') -> None: """Adds a code with optional description to this record's outcome_details. 
Args: @@ -143,22 +212,36 @@ def add_outcome_details(self, code, description=''): """ self.outcome_details.append(OutcomeDetails(code, description)) - def add_phase_record(self, phase_record): + def add_phase_record(self, phase_record: 'PhaseRecord') -> None: self.phases.append(phase_record) self._cached_phases.append(phase_record.as_base_types()) - def add_diagnosis(self, diagnosis): + def add_subtest_record(self, subtest_record: 'SubtestRecord') -> None: + self.subtests.append(subtest_record) + self._cached_subtests.append(data.convert_to_base_types(subtest_record)) + + def add_branch_record(self, branch_record: 'BranchRecord') -> None: + self.branches.append(branch_record) + self._cached_branches.append(data.convert_to_base_types(branch_record)) + + def add_checkpoint_record(self, + checkpoint_record: 'CheckpointRecord') -> None: + self.checkpoints.append(checkpoint_record) + self._cached_checkpoints.append( + data.convert_to_base_types(checkpoint_record)) + + def add_diagnosis(self, diagnosis: 'diagnoses_lib.Diagnosis') -> None: self.diagnoses.append(diagnosis) self._cached_diagnoses.append(data.convert_to_base_types(diagnosis)) - def add_log_record(self, log_record): + def add_log_record(self, log_record: logs.LogRecord) -> None: self.log_records.append(log_record) self._cached_log_records.append(log_record._asdict()) - def as_base_types(self): + def as_base_types(self) -> Dict[Text, Any]: """Convert to a dict representation composed exclusively of base types.""" - metadata = data.convert_to_base_types(self.metadata, - ignore_keys=('config',)) + metadata = data.convert_to_base_types( + self.metadata, ignore_keys=('config',)) metadata['config'] = self._cached_config_from_metadata ret = { 'dut_id': data.convert_to_base_types(self.dut_id), @@ -168,6 +251,8 @@ def as_base_types(self): 'outcome_details': data.convert_to_base_types(self.outcome_details), 'metadata': metadata, 'phases': self._cached_phases, + 'subtests': self._cached_subtests, + 'branches': 
self._cached_branches, 'diagnosers': self._cached_diagnosers, 'diagnoses': self._cached_diagnoses, 'log_records': self._cached_log_records, @@ -176,31 +261,72 @@ def as_base_types(self): return ret -# PhaseResult enumerations are converted to these outcomes by the PhaseState. -PhaseOutcome = Enum( # pylint: disable=invalid-name - 'PhaseOutcome', [ - 'PASS', # CONTINUE with allowed measurement outcomes. - 'FAIL', # CONTINUE with failed measurements or FAIL_AND_CONTINUE. - 'SKIP', # SKIP or REPEAT when under the phase's repeat limit. - 'ERROR', # Any terminal result. - ]) - - -class PhaseRecord( # pylint: disable=no-init - mutablerecords.Record( - 'PhaseRecord', ['descriptor_id', 'name', 'codeinfo'], - { - 'measurements': None, - 'options': None, - 'diagnosers': list, - 'start_time_millis': int, - 'end_time_millis': None, - 'attachments': dict, - 'diagnosis_results': list, - 'failure_diagnosis_results': list, - 'result': None, - 'outcome': None, - })): +@attr.s(slots=True, frozen=True) +class BranchRecord(object): + """The record of a branch.""" + + name = attr.ib(type=Optional[Text]) + diag_condition = attr.ib(type='phase_branches.DiagnosisCondition') + branch_taken = attr.ib(type=bool) + evaluated_millis = attr.ib(type=int) + + @classmethod + def from_branch(cls, branch: 'phase_branches.BranchSequence', + branch_taken: bool, evaluated_millis: int) -> 'BranchRecord': + return cls( + name=branch.name, + diag_condition=branch.diag_condition, + branch_taken=branch_taken, + evaluated_millis=evaluated_millis) + + +@attr.s(slots=True, frozen=True) +class CheckpointRecord(object): + """The record of a checkpoint.""" + + name = attr.ib(type=Text) + action = attr.ib(type='phase_descriptor.PhaseResult') + conditional = attr.ib(type=Union['phase_branches.PreviousPhases', + 'phase_branches.DiagnosisCondition']) + subtest_name = attr.ib(type=Optional[Text]) + result = attr.ib(type='phase_executor.PhaseExecutionOutcome') + evaluated_millis = attr.ib(type=int) + + @classmethod + 
def from_checkpoint(cls, checkpoint: 'phase_branches.Checkpoint', + subtest_name: Optional[Text], + result: 'phase_executor.PhaseExecutionOutcome', + evaluated_millis: int) -> 'CheckpointRecord': + return cls( + name=checkpoint.name, + action=checkpoint.action, + conditional=checkpoint.record_conditional(), + subtest_name=subtest_name, + result=result, + evaluated_millis=evaluated_millis) + + +class PhaseOutcome(enum.Enum): + """Phase outcomes, converted to from the PhaseState.""" + + # CONTINUE with allowed measurement outcomes. + PASS = 'PASS' + # CONTINUE with failed measurements or FAIL_AND_CONTINUE. + FAIL = 'FAIL' + # SKIP or REPEAT when under the phase's repeat limit. + SKIP = 'SKIP' + # Any terminal result. + ERROR = 'ERROR' + + +def _phase_record_base_type_filter(attribute: attr.Attribute, + value: Any) -> bool: + del value # Unused. + return attribute.name not in ('descriptor_id', 'name', 'codeinfo') # pytype: disable=attribute-error + + +@attr.s(slots=True) +class PhaseRecord(object): """The record of a single run of a phase. Measurement metadata (declarations) and values are stored in separate @@ -220,17 +346,40 @@ class PhaseRecord( # pylint: disable=no-init of the phase's measurements or indicates that the verification was skipped. 
""" - @classmethod - def from_descriptor(cls, phase_desc): - return cls(id(phase_desc), phase_desc.name, phase_desc.code_info, - diagnosers=list(phase_desc.diagnosers)) + descriptor_id = attr.ib(type=int) + name = attr.ib(type=Text) + codeinfo = attr.ib(type=CodeInfo) + + measurements = attr.ib( + type=Dict[Text, 'htf_measurements.Measurement'], default=None) + options = attr.ib(type='phase_descriptor.PhaseOptions', default=None) + diagnosers = attr.ib( + type=List['diagnoses_lib.BasePhaseDiagnoser'], factory=list) + subtest_name = attr.ib(type=Optional[Text], default=None) + start_time_millis = attr.ib(type=int, default=0) + end_time_millis = attr.ib(type=Optional[int], default=None) + attachments = attr.ib(type=Dict[Text, Attachment], factory=dict) + diagnosis_results = attr.ib( + type=List['diagnoses_lib.DiagResultEnum'], factory=list) + failure_diagnosis_results = attr.ib( + type=List['diagnoses_lib.DiagResultEnum'], factory=list) + result = attr.ib( + type=Optional['phase_executor.PhaseExecutionOutcome'], default=None) + outcome = attr.ib(type=Optional[PhaseOutcome], default=None) - def as_base_types(self): + @classmethod + def from_descriptor( + cls, phase_desc: 'phase_descriptor.PhaseDescriptor') -> 'PhaseRecord': + return cls( + id(phase_desc), + phase_desc.name, + phase_desc.code_info, + diagnosers=list(phase_desc.diagnosers)) + + def as_base_types(self) -> Dict[Text, Any]: """Convert to a dict representation composed exclusively of base types.""" - base_types_dict = { - k: data.convert_to_base_types(getattr(self, k)) - for k in self.optional_attributes - } + base_types_dict = data.convert_to_base_types( + attr.asdict(self, recurse=False, filter=_phase_record_base_type_filter)) base_types_dict.update( descriptor_id=self.descriptor_id, name=self.name, @@ -238,46 +387,31 @@ def as_base_types(self): ) return base_types_dict - def record_start_time(self): + def record_start_time(self) -> int: """Record the phase start time and return it.""" 
self.start_time_millis = util.time_millis() return self.start_time_millis - def finalize_phase(self, options): + def finalize_phase(self, options: 'phase_descriptor.PhaseOptions') -> None: self.end_time_millis = util.time_millis() self.options = options -def _get_source_safely(obj): - try: - return inspect.getsource(obj) - except Exception: # pylint: disable=broad-except - logs.log_once( - _LOG.warning, - 'Unable to load source code for %s. Only logging this once.', obj) - return '' +class SubtestOutcome(enum.Enum): + PASS = 'PASS' + FAIL = 'FAIL' + STOP = 'STOP' -class CodeInfo(mutablerecords.HashableRecord( - 'CodeInfo', ['name', 'docstring', 'sourcecode'])): - """Information regarding the running tester code.""" +@attr.s(slots=True) +class SubtestRecord(object): + """The record of a subtest.""" - @classmethod - def for_module_from_stack(cls, levels_up=1): - # levels_up is how many frames up to go: - # 0: This function (useless). - # 1: The function calling this (likely). - # 2+: The function calling 'you' (likely in the framework). 
- frame, filename = inspect.stack(context=0)[levels_up][:2] - module = inspect.getmodule(frame) - source = _get_source_safely(frame) - return cls(os.path.basename(filename), inspect.getdoc(module), source) - - @classmethod - def for_function(cls, func): - source = _get_source_safely(func) - return cls(func.__name__, inspect.getdoc(func), source) + name = attr.ib(type=Text) + start_time_millis = attr.ib(type=int, default=0) + end_time_millis = attr.ib(type=Optional[int], default=None) + outcome = attr.ib(type=Optional[SubtestOutcome], default=None) - @classmethod - def uncaptured(cls): - return cls('', None, '') + @property + def is_fail(self) -> bool: + return self.outcome is SubtestOutcome.FAIL diff --git a/openhtf/core/test_state.py b/openhtf/core/test_state.py index dce80734b..780a347b0 100644 --- a/openhtf/core/test_state.py +++ b/openhtf/core/test_state.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """Module for handling transient state of a running test. 
Classes implemented in this module encapsulate state information about a @@ -27,41 +25,59 @@ import collections import contextlib import copy +import enum import functools +import logging import mimetypes import os import socket import sys +from typing import Any, Dict, Iterator, List, Optional, Set, Text, TYPE_CHECKING, Union -from enum import Enum -import mutablerecords +import attr import openhtf from openhtf import plugs from openhtf import util from openhtf.core import diagnoses_lib from openhtf.core import measurements +from openhtf.core import phase_descriptor from openhtf.core import phase_executor from openhtf.core import test_record from openhtf.util import conf from openhtf.util import data from openhtf.util import logs +from openhtf.util import units from past.builtins import long import six +from typing_extensions import Literal + +if TYPE_CHECKING: + from openhtf.core import test_descriptor # pylint: disable=g-import-not-at-top -conf.declare('allow_unset_measurements', default_value=False, - description='If True, unset measurements do not cause Tests to ' - 'FAIL.') +conf.declare( + 'allow_unset_measurements', + default_value=False, + description='If True, unset measurements do not cause Tests to ' + 'FAIL.') # All tests require a station_id. This can be via the --config-file # automatically loaded by OpenHTF, provided explicitly to the config with # conf.load(station_id='My_OpenHTF_Station'), or alongside other configs loaded # with conf.load_from_dict({..., 'station_id': 'My_Station'}). If none of those # are provided then we'll fall back to the machine's hostname. -conf.declare('station_id', 'The name of this test station', - default_value=socket.gethostname()) +conf.declare( + 'station_id', + 'The name of this test station', + default_value=socket.gethostname()) + + +class _Infer(enum.Enum): + INFER = 0 + # Sentinel value indicating that the mimetype should be inferred. 
-INFER_MIMETYPE = object() +INFER_MIMETYPE = _Infer.INFER +MimetypeT = Union[None, Literal[INFER_MIMETYPE], Text] class BlankDutIdError(Exception): @@ -72,30 +88,39 @@ class DuplicateAttachmentError(Exception): """Raised when two attachments are attached with the same name.""" -class ImmutableMeasurement(collections.namedtuple( - 'ImmutableMeasurement', - ['name', 'value', 'units', 'dimensions', 'outcome'])): +class InternalError(Exception): + """An internal error.""" + + +@attr.s(slots=True, frozen=True) +class ImmutableMeasurement(object): """Immutable copy of a measurement.""" + name = attr.ib(type=Text) + value = attr.ib(type=Any) + units = attr.ib(type=Optional[units.UnitDescriptor]) + dimensions = attr.ib(type=Optional[List[measurements.Dimension]]) + outcome = attr.ib(type=Optional[measurements.Outcome]) + @classmethod - def FromMeasurement(cls, measurement): + def from_measurement( + cls, measurement: measurements.Measurement) -> 'ImmutableMeasurement': """Convert a Measurement into an ImmutableMeasurement.""" measured_value = measurement.measured_value if isinstance(measured_value, measurements.DimensionedMeasuredValue): - value = mutablerecords.CopyRecord( - measured_value, - value_dict=copy.deepcopy(measured_value.value_dict) - ) + value = data.attr_copy( + measured_value, value_dict=copy.deepcopy(measured_value.value_dict)) else: - value = (copy.deepcopy(measured_value.value) - if measured_value.is_value_set else None) + value = ( + copy.deepcopy(measured_value.value) + if measured_value.is_value_set else None) return cls( - measurement.name, - value, - measurement.units, - measurement.dimensions, - measurement.outcome) + name=measurement.name, + value=value, + units=measurement.units, + dimensions=measurement.dimensions, + outcome=measurement.outcome) class TestState(util.SubscribableStateMixin): @@ -106,55 +131,66 @@ class TestState(util.SubscribableStateMixin): data associated with a Test (that is, it remains the same across invocations of 
Test.Execute()). - Init Args: - test_desc: openhtf.TestDescriptor instance describing the test to run, - used to initialize some values here, but it is not modified. - execution_uid: a unique uuid use to identify a test being run. - test_options: test_options passed through from Test. - Attributes: test_record: TestRecord instance for the currently running test. state_logger: Logger that logs to test_record's log_records attribute. plug_manager: PlugManager instance for managing supplying plugs for the - currently running test. + currently running test. diagnoses_manager: DiagnosesManager instance for tracking diagnoses for the - currently running test. - running_phase_state: PhaseState object for the currently running phase, - if any, otherwise None. + currently running test. + running_phase_state: PhaseState object for the currently running phase, if + any, otherwise None. user_defined_state: Dictionary for users to persist state across phase - invocations. It's passed to the user via test_api. - test_api: An openhtf.TestApi instance for passing to test phases, - providing test authors access to necessary state information, while - protecting internal-only structures from being accidentally modified. - Note that if there is no running phase, test_api is also None. + invocations. It's passed to the user via test_api. + test_api: An openhtf.TestApi instance for passing to test phases, providing + test authors access to necessary state information, while protecting + internal-only structures from being accidentally modified. Note that if + there is no running phase, test_api is also None. execution_uid: A UUID that is specific to this execution. 
""" - Status = Enum('Status', ['WAITING_FOR_TEST_START', 'RUNNING', 'COMPLETED']) # pylint: disable=invalid-name - def __init__(self, test_desc, execution_uid, test_options): + class Status(enum.Enum): + WAITING_FOR_TEST_START = 'WAITING_FOR_TEST_START' + RUNNING = 'RUNNING' + COMPLETED = 'COMPLETED' + + def __init__(self, test_desc: 'test_descriptor.TestDescriptor', + execution_uid: Text, + test_options: 'test_descriptor.TestOptions'): + """Initializer. + + Args: + test_desc: openhtf.TestDescriptor instance describing the test to run, + used to initialize some values here, but it is not modified. + execution_uid: a unique uuid use to identify a test being run. + test_options: test_options passed through from Test. + """ super(TestState, self).__init__() - self._status = self.Status.WAITING_FOR_TEST_START + self._status = self.Status.WAITING_FOR_TEST_START # type: TestState.Status self.test_record = test_record.TestRecord( - dut_id=None, station_id=conf.station_id, code_info=test_desc.code_info, + dut_id=None, + station_id=conf.station_id, + code_info=test_desc.code_info, start_time_millis=0, # Copy metadata so we don't modify test_desc. 
metadata=copy.deepcopy(test_desc.metadata), diagnosers=test_options.diagnosers) - logs.initialize_record_handler( - execution_uid, self.test_record, self.notify_update) + logs.initialize_record_handler(execution_uid, self.test_record, + self.notify_update) self.state_logger = logs.get_record_logger_for(execution_uid) - self.plug_manager = plugs.PlugManager( - test_desc.plug_types, self.state_logger) + self.plug_manager = plugs.PlugManager(test_desc.plug_types, + self.state_logger) self.diagnoses_manager = diagnoses_lib.DiagnosesManager( self.state_logger.getChild('diagnoses')) - self.running_phase_state = None - self._running_test_api = None - self.user_defined_state = {} + self.running_phase_state = None # type: Optional['PhaseState'] + self._running_test_api = None # type: Optional['test_descriptor.TestApi'] + # TODO(arsharma): Change to Dict[Any, Any] when pytype handles it correctly. + self.user_defined_state = {} # type: Any self.execution_uid = execution_uid self.test_options = test_options - def close(self): + def close(self) -> None: """Close and remove any global registrations. Always call this function when finished with this instance as it ensures @@ -166,7 +202,7 @@ def close(self): logs.remove_record_handler(self.execution_uid) @property - def logger(self): + def logger(self) -> logging.Logger: if self.running_phase_state: return self.running_phase_state.logger raise RuntimeError( @@ -174,35 +210,32 @@ def logger(self): 'instead.') @property - def test_api(self): + def test_api(self) -> 'test_descriptor.TestApi': """Create a TestApi for access to this TestState. The returned TestApi should be passed as the first argument to test - phases. Note that the return value is None if there is no - self.running_phase_state set. As such, this attribute should only - be accessed within a RunningPhaseContext(). + phases. This attribute must only be accessed within a + running_phase_context(). + + Raises: + ValueError: when not in a running_phase_context. 
Returns: openhtf.TestApi """ if not self.running_phase_state: - raise ValueError( - 'test_api only available when phase is running.') + raise ValueError('test_api only available when phase is running.') if not self._running_test_api: - ps = self.running_phase_state self._running_test_api = openhtf.TestApi( - self.logger, self.user_defined_state, self.test_record, - measurements.Collection(ps.measurements), - ps.attachments, - ps.attach, - ps.attach_from_file, - self.get_measurement, - self.get_attachment, - self.notify_update, + measurements=measurements.Collection( + self.running_phase_state.measurements), + running_phase_state=self.running_phase_state, + running_test_state=self, ) return self._running_test_api - def get_attachment(self, attachment_name): + def get_attachment(self, + attachment_name: Text) -> Optional[test_record.Attachment]: """Get a copy of an attachment from current or previous phases. Args: @@ -210,9 +243,8 @@ def get_attachment(self, attachment_name): Returns: A copy of the attachment or None if the attachment cannot be found. - """ - # Check current running phase state + # Check current running phase state for the attachment name first. if self.running_phase_state: if attachment_name in self.running_phase_state.phase_record.attachments: attachment = self.running_phase_state.phase_record.attachments.get( @@ -227,7 +259,8 @@ def get_attachment(self, attachment_name): self.state_logger.warning('Could not find attachment: %s', attachment_name) return None - def get_measurement(self, measurement_name): + def get_measurement(self, + measurement_name: Text) -> Optional[ImmutableMeasurement]: """Get a copy of a measurement value from current or previous phase. Measurement and phase name uniqueness is not enforced, so this method will @@ -235,6 +268,7 @@ def get_measurement(self, measurement_name): Args: measurement_name: str of the measurement name + Returns: an ImmutableMeasurement or None if the measurement cannot be found. 
""" @@ -244,7 +278,7 @@ def get_measurement(self, measurement_name): # Check current running phase state if self.running_phase_state: if measurement_name in self.running_phase_state.measurements: - return ImmutableMeasurement.FromMeasurement( + return ImmutableMeasurement.from_measurement( self.running_phase_state.measurements[measurement_name]) # Iterate through phases in reversed order to return most recent (necessary @@ -253,14 +287,16 @@ def get_measurement(self, measurement_name): if (phase_record.result not in ignore_outcomes and measurement_name in phase_record.measurements): measurement = phase_record.measurements[measurement_name] - return ImmutableMeasurement.FromMeasurement(measurement) + return ImmutableMeasurement.from_measurement(measurement) - self.state_logger.warning( - 'Could not find measurement: %s', measurement_name) + self.state_logger.warning('Could not find measurement: %s', + measurement_name) return None @contextlib.contextmanager - def running_phase_context(self, phase_desc): + def running_phase_context( + self, + phase_desc: phase_descriptor.PhaseDescriptor) -> Iterator['PhaseState']: """Create a context within which a single phase is running. Yields a PhaseState object for tracking transient state during the @@ -290,7 +326,7 @@ def running_phase_context(self, phase_desc): self._running_test_api = None self.notify_update() # Phase finished. 
- def as_base_types(self): + def as_base_types(self) -> Dict[Text, Any]: """Convert to a dict representation composed exclusively of base types.""" running_phase_state = None if self.running_phase_state: @@ -302,20 +338,20 @@ def as_base_types(self): 'running_phase_state': running_phase_state, } - def _asdict(self): + def _asdict(self) -> Dict[Text, Any]: """Return a dict representation of the test's state.""" return self.as_base_types() @property - def is_finalized(self): + def is_finalized(self) -> bool: return self._status == self.Status.COMPLETED - def stop_running_phase(self): + def stop_running_phase(self) -> None: """Stops the currently running phase, allowing another phase to run.""" self.running_phase_state = None @property - def last_run_phase_name(self): + def last_run_phase_name(self) -> Optional[Text]: """Get the name of the currently running phase, or None. Note that this name is not guaranteed to still be accurate by the time this @@ -325,16 +361,18 @@ def last_run_phase_name(self): Returns: str name of currently running phase or None. """ - return self.running_phase_state and self.running_phase_state.name + if self.running_phase_state: + return self.running_phase_state.name + return None - def mark_test_started(self): + def mark_test_started(self) -> None: """Set the TestRecord's start_time_millis field.""" # Blow up instead of blowing away a previously set start_time_millis. - assert self.test_record.start_time_millis is 0 + assert self.test_record.start_time_millis == 0 self.test_record.start_time_millis = util.time_millis() self.notify_update() - def set_status_running(self): + def set_status_running(self) -> None: """Mark the test as actually running, can't be done once finalized.""" if self._is_aborted(): return @@ -342,13 +380,10 @@ def set_status_running(self): self._status = self.Status.RUNNING self.notify_update() - def finalize_from_phase_outcome(self, phase_execution_outcome): - """Finalize due to the given phase outcome. 
- - Args: - phase_execution_outcome: An instance of - phase_executor.PhaseExecutionOutcome. - """ + def finalize_from_phase_outcome( + self, + phase_execution_outcome: phase_executor.PhaseExecutionOutcome) -> None: + """Finalize due to the given phase outcome.""" if self._is_aborted(): return @@ -394,7 +429,7 @@ def finalize_from_phase_outcome(self, phase_execution_outcome): 'A phase stopped the test run.') self._finalize(test_record.Outcome.FAIL) - def finalize_normally(self): + def finalize_normally(self) -> None: """Mark the state as finished. This method is called on normal test completion. The outcome will be either @@ -421,6 +456,9 @@ def finalize_normally(self): self._finalize(test_record.Outcome.ERROR) elif any(d.is_failure for d in self.test_record.diagnoses): self._finalize(test_record.Outcome.FAIL) + elif any(s.outcome == test_record.SubtestOutcome.FAIL + for s in self.test_record.subtests): + self._finalize(test_record.Outcome.FAIL) else: # Otherwise, the test run was successful. self._finalize(test_record.Outcome.PASS) @@ -429,7 +467,7 @@ def finalize_normally(self): 'Finishing test execution normally with outcome %s.', self.test_record.outcome.name) - def abort(self): + def abort(self) -> None: if self._is_aborted(): return @@ -438,7 +476,7 @@ def abort(self): self.test_record.add_outcome_details('ABORTED', 'Test aborted by operator.') self._finalize(test_record.Outcome.ABORTED) - def _finalize(self, test_outcome): + def _finalize(self, test_outcome: test_record.Outcome) -> None: aborting = test_outcome == test_record.Outcome.ABORTED assert not self.is_finalized or aborting, ( 'Test already completed with status %s!' 
% self._status.name) @@ -454,20 +492,21 @@ def _finalize(self, test_outcome): self._status = self.Status.COMPLETED self.notify_update() - def _is_aborted(self): + def _is_aborted(self) -> bool: if (self.is_finalized and self.test_record.outcome == test_record.Outcome.ABORTED): self.state_logger.debug('Test already aborted.') return True return False - def _outcome_is_failure_exception(self, outcome): + def _outcome_is_failure_exception( + self, outcome: phase_executor.PhaseExecutionOutcome) -> bool: for failure_exception in self.test_options.failure_exceptions: if isinstance(outcome.phase_result.exc_val, failure_exception): return True return False - def __str__(self): + def __str__(self) -> Text: return '<%s: %s@%s Running Phase: %s>' % ( type(self).__name__, self.test_record.dut_id, @@ -476,67 +515,73 @@ def __str__(self): ) -class PhaseState(mutablerecords.Record( - 'PhaseState', - [ - 'name', - 'phase_record', - 'measurements', - 'options', - 'logger', - 'test_state', - 'diagnosers', - ], - { - 'hit_repeat_limit': False, - '_cached': dict, - '_update_measurements': set, - })): +@attr.s +class PhaseState(object): """Data type encapsulating interesting information about a running phase. Attributes: + name: Name of the phase. phase_record: A test_record.PhaseRecord for the running phase. - measurements: A dict mapping measurement name to it's declaration; this - dict can be passed to measurements.Collection to initialize a user- - facing Collection for setting measurements. + measurements: A dict mapping measurement name to it's declaration; this dict + can be passed to measurements.Collection to initialize a user-facing + Collection for setting measurements. options: the PhaseOptions from the phase descriptor. logger: logging.Logger for this phase execution run. test_state: TestState, parent test state. diagnosers: list of PhaseDiagnoser instances to run after the phase - finishes. + finishes. hit_repeat_limit: bool, True when the phase repeat limit was hit. 
_cached: A cached representation of the running test state that; updated in - place to save allocation time. - - Properties: + place to save allocation time. attachments: Convenience accessor for phase_record.attachments. result: Convenience getter/setter for phase_record.result. """ - def __init__(self, *args, **kwargs): - super(PhaseState, self).__init__(*args, **kwargs) + name = attr.ib(type=Text) + phase_record = attr.ib(type=test_record.PhaseRecord) + measurements = attr.ib(type=Dict[Text, measurements.Measurement]) + options = attr.ib(type=phase_descriptor.PhaseOptions) + logger = attr.ib(type=logging.Logger) + test_state = attr.ib(type=TestState) + diagnosers = attr.ib(type=List[diagnoses_lib.BasePhaseDiagnoser]) + hit_repeat_limit = attr.ib(type=bool, default=False) + _cached = attr.ib(type=Dict[Text, Any], factory=dict) + _update_measurements = attr.ib(type=Set[Text], factory=set) + + def __attrs_post_init__(self): for m in six.itervalues(self.measurements): # Using functools.partial to capture the value of the loop variable. m.set_notification_callback(functools.partial(self._notify, m.name)) self._cached = { - 'name': self.name, - 'codeinfo': data.convert_to_base_types(self.phase_record.codeinfo), - 'descriptor_id': data.convert_to_base_types( - self.phase_record.descriptor_id), + 'name': + self.name, + 'codeinfo': + data.convert_to_base_types(self.phase_record.codeinfo), + 'descriptor_id': + data.convert_to_base_types(self.phase_record.descriptor_id), # Options are not set until the phase is finished. 
- 'options': None, + 'options': + None, 'measurements': { - k: m.as_base_types() for k, m in six.iteritems(self.measurements)}, + k: m.as_base_types() for k, m in six.iteritems(self.measurements) + }, 'attachments': {}, - 'start_time_millis': long(self.phase_record.record_start_time()), + 'start_time_millis': + long(self.phase_record.record_start_time()), + 'subtest_name': + None, } @classmethod - def from_descriptor(cls, phase_desc, test_state, logger): + def from_descriptor(cls, phase_desc: phase_descriptor.PhaseDescriptor, + test_state: TestState, + logger: logging.Logger) -> 'PhaseState': + """Create a PhaseState from a phase descriptor.""" # Measurements are copied because their state is modified during the phase # execution. - measurements_copy = [copy.deepcopy(measurement) - for measurement in phase_desc.measurements] + measurements_copy = [ + copy.deepcopy(measurement) for measurement in phase_desc.measurements + ] diag_store = test_state.diagnoses_manager.store for m in measurements_copy: # Check the conditional validators to see if their results have been @@ -556,11 +601,11 @@ def from_descriptor(cls, phase_desc, test_state, logger): diagnosers=phase_desc.diagnosers, ) - def _notify(self, measurement_name): + def _notify(self, measurement_name: Text) -> None: self._update_measurements.add(measurement_name) self.test_state.notify_update() - def as_base_types(self): + def as_base_types(self) -> Dict[Text, Any]: """Convert to a dict representation composed exclusively of base types.""" cur_update_measurements = self._update_measurements self._update_measurements = set() @@ -571,27 +616,33 @@ def as_base_types(self): return self._cached @property - def result(self): + def result(self) -> Optional[phase_executor.PhaseExecutionOutcome]: return self.phase_record.result @result.setter - def result(self, result): + def result(self, result: phase_executor.PhaseExecutionOutcome): self.phase_record.result = result + def set_subtest_name(self, subtest_name: Text) -> 
None: + self.phase_record.subtest_name = subtest_name + self._cached['subtest_name'] = subtest_name + @property - def attachments(self): + def attachments(self) -> Dict[Text, test_record.Attachment]: return self.phase_record.attachments - def attach(self, name, binary_data, mimetype=INFER_MIMETYPE): + def attach(self, + name: Text, + binary_data: Union[Text, bytes], + mimetype: MimetypeT = INFER_MIMETYPE) -> None: """Store the given binary_data as an attachment with the given name. Args: name: Attachment name under which to store this binary_data. binary_data: Data to attach. - mimetype: One of the following: - INFER_MIMETYPE - The type will be guessed from the attachment name. - None - The type will be left unspecified. - A string - The type will be set to the specified value. + mimetype: One of the following: INFER_MIMETYPE - The type will be guessed + from the attachment name. None - The type will be left unspecified. A + string - The type will be set to the specified value. Raises: DuplicateAttachmentError: Raised if there is already an attachment with @@ -610,16 +661,19 @@ def attach(self, name, binary_data, mimetype=INFER_MIMETYPE): self.phase_record.attachments[name] = attach_record self._cached['attachments'][name] = attach_record._asdict() - def attach_from_file(self, filename, name=None, mimetype=INFER_MIMETYPE): + def attach_from_file(self, + filename: Text, + name: Optional[Text] = None, + mimetype: MimetypeT = INFER_MIMETYPE) -> None: """Store the contents of the given filename as an attachment. Args: filename: The file to read data from to attach. - name: If provided, override the attachment name, otherwise it will - default to the filename. + name: If provided, override the attachment name, otherwise it will default + to the filename. mimetype: One of the following: * INFER_MIMETYPE: The type will be guessed first, from the file name, - and second (i.e. as a fallback), from the attachment name. + and second (i.e. 
as a fallback), from the attachment name. * None: The type will be left unspecified. * A string: The type will be set to the specified value. @@ -630,18 +684,19 @@ def attach_from_file(self, filename, name=None, mimetype=INFER_MIMETYPE): """ if mimetype is INFER_MIMETYPE: mimetype = mimetypes.guess_type(filename)[0] or mimetype - with open(filename, 'rb') as f: # pylint: disable=invalid-name + with open(filename, 'rb') as f: self.attach( - name if name is not None else os.path.basename(filename), f.read(), + name if name is not None else os.path.basename(filename), + f.read(), mimetype=mimetype) - def add_diagnosis(self, diagnosis): + def add_diagnosis(self, diagnosis: diagnoses_lib.Diagnosis) -> None: if diagnosis.is_failure: self.phase_record.failure_diagnosis_results.append(diagnosis.result) else: self.phase_record.diagnosis_results.append(diagnosis.result) - def _finalize_measurements(self): + def _finalize_measurements(self) -> None: """Perform end-of-phase finalization steps for measurements. Any UNSET measurements will cause the Phase to FAIL unless @@ -667,7 +722,7 @@ def _finalize_measurements(self): # Set final values on the PhaseRecord. 
self.phase_record.measurements = self.measurements - def _measurements_pass(self): + def _measurements_pass(self) -> bool: allowed_outcomes = {measurements.Outcome.PASS} if conf.allow_unset_measurements: allowed_outcomes.add(measurements.Outcome.UNSET) @@ -675,14 +730,19 @@ def _measurements_pass(self): return all(meas.outcome in allowed_outcomes for meas in self.phase_record.measurements.values()) - def _set_prediagnosis_phase_outcome(self): - if self.result is None or self.result.is_terminal or self.hit_repeat_limit: + def _set_prediagnosis_phase_outcome(self) -> None: + """Set the phase outcome before running diagnosers.""" + result = self.result + if result is None or result.is_terminal or self.hit_repeat_limit: self.logger.debug('Phase outcome is ERROR.') outcome = test_record.PhaseOutcome.ERROR - elif self.result.is_repeat or self.result.is_skip: + elif result.is_repeat or result.is_skip: self.logger.debug('Phase outcome is SKIP.') outcome = test_record.PhaseOutcome.SKIP - elif self.result.is_fail_and_continue: + elif result.is_fail_subtest: + self.logger.debug('Phase outcome is FAIL due to subtest failure.') + outcome = test_record.PhaseOutcome.FAIL + elif result.is_fail_and_continue: self.logger.debug('Phase outcome is FAIL due to phase result.') outcome = test_record.PhaseOutcome.FAIL elif not self._measurements_pass(): @@ -693,7 +753,8 @@ def _set_prediagnosis_phase_outcome(self): outcome = test_record.PhaseOutcome.PASS self.phase_record.outcome = outcome - def _set_postdiagnosis_phase_outcome(self): + def _set_postdiagnosis_phase_outcome(self) -> None: + """Set the phase outcome after diagnosers run.""" if self.phase_record.outcome == test_record.PhaseOutcome.ERROR: return # Check for errors during diagnoser execution. 
@@ -707,7 +768,9 @@ def _set_postdiagnosis_phase_outcome(self): self.logger.debug('Phase outcome is FAIL due to diagnoses.') self.phase_record.outcome = test_record.PhaseOutcome.FAIL - def _execute_phase_diagnoser(self, diagnoser): + def _execute_phase_diagnoser( + self, diagnoser: diagnoses_lib.BasePhaseDiagnoser) -> None: + """Execute a single phase diagnoser.""" try: self.test_state.diagnoses_manager.execute_phase_diagnoser( diagnoser, self, self.test_state.test_record) @@ -721,16 +784,21 @@ def _execute_phase_diagnoser(self, diagnoser): self.phase_record.result = phase_executor.PhaseExecutionOutcome( phase_executor.ExceptionInfo(*sys.exc_info())) - def _execute_phase_diagnosers(self): - if self.result.is_aborted: + def _execute_phase_diagnosers(self) -> None: + """Execute all the diagnosers for this phase.""" + result = self.result + if result is None: + self.logger.error('Internal error occurred; skipping diagnosers.') + return + if result.is_aborted: self.logger.warning('Skipping diagnosers when phase was aborted.') return - if self.result.is_repeat or self.result.is_skip: + if result.is_repeat or result.is_skip: return for diagnoser in self.diagnosers: self._execute_phase_diagnoser(diagnoser) - def finalize(self): + def finalize(self) -> None: self._finalize_measurements() self._set_prediagnosis_phase_outcome() self._execute_phase_diagnosers() diff --git a/openhtf/output/callbacks/__init__.py b/openhtf/output/callbacks/__init__.py index b30860bf5..d83ea1e02 100644 --- a/openhtf/output/callbacks/__init__.py +++ b/openhtf/output/callbacks/__init__.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """This module contains support for various built-in output mechanisms. 
Here, a base OutputToFile class is implemented to provide simple output to @@ -23,31 +22,44 @@ import contextlib import shutil import tempfile +import typing +from typing import BinaryIO, Callable, Iterator, Optional, Text, Union from openhtf import util +from openhtf.core import test_record from openhtf.util import data import six from six.moves import collections_abc +from six.moves import cPickle as pickle + +SerializedTestRecord = Union[Text, bytes, Iterator[Union[Text, bytes]]] # TODO(wallacbe): Switch to util class Atomic(object): """Class that does atomic write in a contextual manner.""" - def __init__(self, filename): + def __init__(self, filename: Text): self.filename = filename self.temp = tempfile.NamedTemporaryFile(delete=False) - def write(self, write_data): - if hasattr(write_data, 'decode'): - return self.temp.write(write_data) - return self.temp.write(write_data.encode()) + def write(self, write_data: Union[Text, bytes]) -> int: + return self.temp.write(six.ensure_binary(write_data)) - def close(self): + def close(self) -> None: self.temp.close() shutil.move(self.temp.name, self.filename) +class CloseAttachments(object): + """Close the attachment files associated with a test record.""" + + def __call__(self, test_rec: test_record.TestRecord) -> None: + for phase_rec in test_rec.phases: + for attachment in six.itervalues(phase_rec.attachments): + attachment.close() + + class OutputToFile(object): """Output the given TestRecord to a file. @@ -58,67 +70,71 @@ class OutputToFile(object): serialize_test_record() method. Additionally, subclasses may implement more complex file naming mechanisms by overriding the open_file() method. - Args: - test_record: The TestRecord to write out to a file. 
- Attributes: - filename_pattern: A string that defines filename pattern with placeholders - to be replaced by test run metadata values - filename: A string that defines the final file name with all the - placeholders replaced + filename_pattern: A string or callable that returns a string that defines + filename pattern with placeholders to be replaced by test run metadata + values. Exclusive with output_file. + output_file: A file object. Exclusive with filename_pattern. """ - def __init__(self, filename_pattern): - self.filename_pattern = filename_pattern - self._pattern_formattable = ( - isinstance(filename_pattern, six.string_types) or - callable(filename_pattern)) + def __init__(self, filename_pattern_or_file: Union[Text, Callable[..., Text], + BinaryIO]): + self.filename_pattern = None # type: Optional[Union[Text, Callable[..., Text]]] + self.output_file = None # type: Optional[BinaryIO] + if (isinstance(filename_pattern_or_file, six.string_types) or + callable(filename_pattern_or_file)): + self.filename_pattern = filename_pattern_or_file + else: + self.output_file = filename_pattern_or_file @staticmethod - def serialize_test_record(test_record): + def serialize_test_record( + test_rec: test_record.TestRecord) -> SerializedTestRecord: """Override method to alter how test records are serialized to file data.""" - return six.moves.pickle.dumps(test_record, -1) + return pickle.dumps(test_rec, -1) @staticmethod - def open_file(filename): + def open_file(filename: Text) -> Atomic: """Override method to alter file open behavior or file types.""" return Atomic(filename) - def create_file_name(self, test_record): - """Use filename_pattern and test_record to create filename.""" + def create_file_name(self, test_rec: test_record.TestRecord) -> Text: + """Use filename_pattern and test_rec to create filename.""" + if self.filename_pattern is None: + raise ValueError( + 'filename_pattern must be string or callable to create file name.') # Ignore keys for the log filename 
to not convert larger data structures. record_dict = data.convert_to_base_types( - test_record, ignore_keys=('code_info', 'phases', 'log_records')) - if self._pattern_formattable: - return util.format_string(self.filename_pattern, record_dict) - else: - raise ValueError( - 'filename_pattern must be string or callable to create file name') + test_rec, ignore_keys=('code_info', 'phases', 'log_records')) + return typing.cast(Text, + util.format_string(self.filename_pattern, record_dict)) @contextlib.contextmanager - def open_output_file(self, test_record): + def open_output_file( + self, + test_rec: test_record.TestRecord) -> Iterator[Union[Atomic, BinaryIO]]: """Open file based on pattern.""" - if self._pattern_formattable: - filename = self.create_file_name(test_record) + if self.filename_pattern: + filename = self.create_file_name(test_rec) output_file = self.open_file(filename) try: yield output_file finally: output_file.close() - elif hasattr(self.filename_pattern, 'write'): - yield self.filename_pattern + elif self.output_file: + yield self.output_file else: - raise ValueError( + raise TypeError( 'filename_pattern must be string, callable, or File-like object') - def __call__(self, test_record): - with self.open_output_file(test_record) as outfile: - serialized_record = self.serialize_test_record(test_record) + def __call__(self, test_rec: test_record.TestRecord) -> None: + with self.open_output_file(test_rec) as outfile: + serialized_record = self.serialize_test_record(test_rec) if isinstance(serialized_record, six.string_types): - outfile.write(serialized_record) + outfile.write(six.ensure_binary(serialized_record)) elif isinstance(serialized_record, collections_abc.Iterable): for chunk in serialized_record: - outfile.write(chunk) + outfile.write(six.ensure_binary(chunk)) else: raise TypeError('Expected string or iterable but got {}.'.format( type(serialized_record))) diff --git a/openhtf/output/callbacks/console_summary.py 
b/openhtf/output/callbacks/console_summary.py index 8ab5589e3..e351dcb74 100644 --- a/openhtf/output/callbacks/console_summary.py +++ b/openhtf/output/callbacks/console_summary.py @@ -3,49 +3,51 @@ import os import sys -from openhtf.core import test_record from openhtf.core import measurements +from openhtf.core import test_record import six class ConsoleSummary(): - """Print test results with failure info on console. """ + """Print test results with failure info on console.""" # pylint: disable=invalid-name def __init__(self, indent=2, output_stream=sys.stdout): self.indent = ' ' * indent - if os.name == 'posix': #Linux and Mac + if os.name == 'posix': # Linux and Mac. self.RED = '\033[91m' self.GREEN = '\033[92m' self.ORANGE = '\033[93m' self.RESET = '\033[0m' self.BOLD = '\033[1m' else: - self.RED = "" - self.GREEN = "" - self.ORANGE = "" - self.RESET = "" - self.BOLD = "" + self.RED = '' + self.GREEN = '' + self.ORANGE = '' + self.RESET = '' + self.BOLD = '' self.color_table = { - test_record.Outcome.PASS:self.GREEN, - test_record.Outcome.FAIL:self.RED, - test_record.Outcome.ERROR:self.ORANGE, - test_record.Outcome.TIMEOUT:self.ORANGE, - test_record.Outcome.ABORTED:self.RED, + test_record.Outcome.PASS: self.GREEN, + test_record.Outcome.FAIL: self.RED, + test_record.Outcome.ERROR: self.ORANGE, + test_record.Outcome.TIMEOUT: self.ORANGE, + test_record.Outcome.ABORTED: self.RED, } self.output_stream = output_stream + # pylint: enable=invalid-name def __call__(self, record): - output_lines = [''.join((self.color_table[record.outcome], - self.BOLD, record.code_info.name, ':', - record.outcome.name, self.RESET))] + output_lines = [ + ''.join((self.color_table[record.outcome], self.BOLD, + record.code_info.name, ':', record.outcome.name, self.RESET)) + ] if record.outcome != test_record.Outcome.PASS: for phase in record.phases: new_phase = True - phase_time_sec = (float(phase.end_time_millis) - - float(phase.start_time_millis)) / 1000.0 + phase_time_sec = 
(float(phase.end_time_millis) - + float(phase.start_time_millis)) / 1000.0 for name, measurement in six.iteritems(phase.measurements): if measurement.outcome != measurements.Outcome.PASS: if new_phase: @@ -56,17 +58,17 @@ def __call__(self, record): output_lines.append('%sfailed_item: %s (%s)' % (self.indent, name, measurement.outcome)) output_lines.append('%smeasured_value: %s' % - (self.indent*2, measurement.measured_value)) - output_lines.append('%svalidators:' % (self.indent*2)) + (self.indent * 2, measurement.measured_value)) + output_lines.append('%svalidators:' % (self.indent * 2)) for validator in measurement.validators: output_lines.append('%svalidator: %s' % - (self.indent*3, str(validator))) + (self.indent * 3, str(validator))) phase_result = phase.result.phase_result - if not phase_result: #Timeout + if not phase_result: # Timeout. output_lines.append('timeout phase: %s [ran for %.2f sec]' % (phase.name, phase_time_sec)) - elif 'CONTINUE' not in str(phase_result): #Exception + elif 'CONTINUE' not in str(phase_result): # Exception. 
output_lines.append('%sexception type: %s' % (self.indent, record.outcome_details[0].code)) diff --git a/openhtf/output/callbacks/json_factory.py b/openhtf/output/callbacks/json_factory.py index b6b6fdf5d..e410f24ab 100644 --- a/openhtf/output/callbacks/json_factory.py +++ b/openhtf/output/callbacks/json_factory.py @@ -2,6 +2,7 @@ import base64 import json +from typing import Any, BinaryIO, Callable, Dict, Iterator, Text, Union from openhtf.core import test_record from openhtf.output import callbacks @@ -11,7 +12,7 @@ class TestRecordEncoder(json.JSONEncoder): - def default(self, obj): + def default(self, obj: Any) -> Any: if isinstance(obj, test_record.Attachment): dct = obj._asdict() dct['data'] = base64.standard_b64encode(obj.data).decode('utf-8') @@ -19,6 +20,51 @@ def default(self, obj): return super(TestRecordEncoder, self).default(obj) +def convert_test_record_to_json( + test_rec: test_record.TestRecord, + inline_attachments: bool = True, allow_nan: bool = False + ) -> Dict[Text, Any]: + """Convert the test record to a JSON object. + + Args: + test_rec: The test record to convert. + inline_attachments: Whether attachments should be included inline in the + output. Set to False if you expect to have large binary attachments. If + True (the default), then attachments are base64 encoded to allow for + binary data that's not supported by JSON directly. + allow_nan: If False, out of range float values will raise ValueError. + + Returns: + The test record encoded as JSON objects. 
+ """ + as_dict = data.convert_to_base_types(test_rec, json_safe=(not allow_nan)) + if inline_attachments: + for phase, original_phase in zip(as_dict['phases'], test_rec.phases): + for name, attachment in six.iteritems(original_phase.attachments): + phase['attachments'][name] = attachment + return as_dict + + +def stream_json( + encoded_test_rec: Dict[Text, Any], allow_nan: bool = False, **kwargs +) -> Iterator[Text]: + """Convert the JSON object encoded test record into a stream of strings. + + Args: + encoded_test_rec: The JSON converted test record. + allow_nan: If False, out of range float values will raise ValueError. + **kwargs: Additional arguments to be passed to the JSON encoder. + + Returns: + Iterable of JSON strings. + """ + json_encoder = TestRecordEncoder(allow_nan=allow_nan, **kwargs) + + # The iterencode return type in typeshed for PY2 is wrong; not worried about + # fixing it as we are droping PY2 support soon. + return json_encoder.iterencode(encoded_test_rec) # pytype: disable=bad-return-type + + class OutputToJSON(callbacks.OutputToFile): """Return an output callback that writes JSON Test Records. @@ -29,34 +75,34 @@ class OutputToJSON(callbacks.OutputToFile): test = openhtf.Test(PhaseOne, PhaseTwo) test.add_output_callback(openhtf.output.callbacks.OutputToJSON( '/data/test_records/{dut_id}.{metadata[test_name]}.json')) - - Args: - filename_pattern: A format string specifying the filename to write to, - will be formatted with the Test Record as a dictionary. May also be a - file-like object to write to directly. - inline_attachments: Whether attachments should be included inline in the - output. Set to False if you expect to have large binary attachments. If - True (the default), then attachments are base64 encoded to allow for - binary data that's not supported by JSON directly. 
""" - def __init__(self, filename_pattern=None, inline_attachments=True, **kwargs): - super(OutputToJSON, self).__init__(filename_pattern) + def __init__(self, filename_pattern_or_file: Union[Text, Callable[..., Text], + BinaryIO], + inline_attachments: bool = True, + allow_nan: bool = False, **json_kwargs: Any): + """Constructor. + + Args: + filename_pattern_or_file: A format string specifying the filename to write + to, will be formatted with the Test Record as a dictionary. May also be + a file-like object to write to directly. + inline_attachments: Whether attachments should be included inline in the + output. Set to False if you expect to have large binary attachments. If + True (the default), then attachments are base64 encoded to allow for + binary data that's not supported by JSON directly. + allow_nan: If False, out of range float values will raise ValueError. + **json_kwargs: Additional arguments to be passed to the JSON encoder. + """ + super(OutputToJSON, self).__init__(filename_pattern_or_file) self.inline_attachments = inline_attachments + self.allow_nan = allow_nan + self._json_kwargs = json_kwargs - # Conform strictly to the JSON spec by default. 
- kwargs.setdefault('allow_nan', False) - self.allow_nan = kwargs['allow_nan'] - self.json_encoder = TestRecordEncoder(**kwargs) - - def serialize_test_record(self, test_record): - return self.json_encoder.iterencode(self.convert_to_dict(test_record)) - - def convert_to_dict(self, test_record): - as_dict = data.convert_to_base_types(test_record, - json_safe=(not self.allow_nan)) - if self.inline_attachments: - for phase, original_phase in zip(as_dict['phases'], test_record.phases): - for name, attachment in six.iteritems(original_phase.attachments): - phase['attachments'][name] = attachment - return as_dict + def serialize_test_record(self, test_rec: test_record.TestRecord + ) -> Iterator[Text]: + encoded = convert_test_record_to_json( + test_rec, inline_attachments=self.inline_attachments, + allow_nan=self.allow_nan) + return stream_json(encoded, allow_nan=self.allow_nan, + **self._json_kwargs) diff --git a/openhtf/output/callbacks/mfg_inspector.py b/openhtf/output/callbacks/mfg_inspector.py index 6a26ad255..28ad52803 100644 --- a/openhtf/output/callbacks/mfg_inspector.py +++ b/openhtf/output/callbacks/mfg_inspector.py @@ -54,17 +54,18 @@ def _send_mfg_inspector_request(envelope_data, credentials, destination_url): raise UploadFailedError(message) -def send_mfg_inspector_data(inspector_proto, credentials, destination_url): +def send_mfg_inspector_data(inspector_proto, credentials, destination_url, + payload_type): """Upload MfgEvent to steam_engine.""" envelope = guzzle_pb2.TestRunEnvelope() envelope.payload = zlib.compress(inspector_proto.SerializeToString()) - envelope.payload_type = guzzle_pb2.COMPRESSED_TEST_RUN + envelope.payload_type = payload_type envelope_data = envelope.SerializeToString() for _ in range(5): try: - result = _send_mfg_inspector_request( - envelope_data, credentials, destination_url) + result = _send_mfg_inspector_request(envelope_data, credentials, + destination_url) return result except UploadFailedError: time.sleep(1) @@ -76,7 +77,6 @@ 
def send_mfg_inspector_data(inspector_proto, credentials, destination_url): class _MemStorage(oauth2client.client.Storage): - # pylint: disable=invalid-name """Helper Storage class that keeps credentials in memory.""" def __init__(self): @@ -125,7 +125,7 @@ class MfgInspector(object): SCOPE_CODE_URI = 'https://www.googleapis.com/auth/glass.infra.quantum_upload' DESTINATION_URL = ('https://clients2.google.com/factoryfactory/' 'uploads/quantum_upload/?json') - PARAMS = ['dut_id', 'end_time_millis', 'start_time_millis', 'station_id'] + PARAMS = ['dut_id', 'end_time_millis', 'start_time_millis', 'station_id'] # These attributes control format of callback and what actions are undertaken # when called. These should either be set by a subclass or via configure. @@ -138,8 +138,11 @@ class MfgInspector(object): # saving to disk via save_to_disk. _default_filename_pattern = None - def __init__(self, user=None, keydata=None, - token_uri=TOKEN_URI, destination_url=DESTINATION_URL): + def __init__(self, + user=None, + keydata=None, + token_uri=TOKEN_URI, + destination_url=DESTINATION_URL): self.user = user self.keydata = keydata self.token_uri = token_uri @@ -175,9 +178,10 @@ def from_json(cls, json_data): Returns: a MfgInspectorCallback with credentials. 
""" - return cls(user=json_data['client_email'], - keydata=json_data['private_key'], - token_uri=json_data['token_uri']) + return cls( + user=json_data['client_email'], + keydata=json_data['private_key'], + token_uri=json_data['token_uri']) def _check_cached_params(self, test_record_obj): """Check if all cached params equal the values in test record.""" @@ -188,7 +192,8 @@ def _check_cached_params(self, test_record_obj): def _convert(self, test_record_obj): """Convert and cache a test record to a mfg-inspector proto.""" - if self._cached_proto is None or not self._check_cached_params(test_record_obj): + if (self._cached_proto is None or + not self._check_cached_params(test_record_obj)): self._cached_proto = self._converter(test_record_obj) for param in self.PARAMS: self._cached_params[param] = getattr(test_record_obj, param) @@ -203,9 +208,8 @@ def save_to_disk(self, filename_pattern=None): pattern = filename_pattern or self._default_filename_pattern if not pattern: - raise RuntimeError( - 'Must specify provide a filename_pattern or set a ' - '_default_filename_pattern on subclass.') + raise RuntimeError('Must specify provide a filename_pattern or set a ' + '_default_filename_pattern on subclass.') def save_to_disk_callback(test_record_obj): proto = self._convert(test_record_obj) @@ -216,7 +220,7 @@ def save_to_disk_callback(test_record_obj): return save_to_disk_callback - def upload(self): + def upload(self, payload_type=guzzle_pb2.COMPRESSED_TEST_RUN): """Returns a callback to convert a test record to a proto and upload.""" if not self._converter: raise RuntimeError( @@ -228,8 +232,9 @@ def upload(self): def upload_callback(test_record_obj): proto = self._convert(test_record_obj) - self.upload_result = send_mfg_inspector_data( - proto, self.credentials, self.destination_url) + self.upload_result = send_mfg_inspector_data(proto, self.credentials, + self.destination_url, + payload_type) return upload_callback @@ -268,6 +273,6 @@ class 
UploadToMfgInspector(MfgInspector): def _converter(test_record_obj): return test_runs_converter.test_run_from_test_record(test_record_obj) - def __call__(self, test_record_obj): # pylint: disable=invalid-name + def __call__(self, test_record_obj): upload_callback = self.upload() upload_callback(test_record_obj) diff --git a/openhtf/output/proto/mfg_event_converter.py b/openhtf/output/proto/mfg_event_converter.py index 1c1acb989..ae4613a9e 100644 --- a/openhtf/output/proto/mfg_event_converter.py +++ b/openhtf/output/proto/mfg_event_converter.py @@ -26,6 +26,7 @@ from past.builtins import unicode +import six TEST_RECORD_ATTACHMENT_NAME = 'OpenHTF_record.json' @@ -104,7 +105,7 @@ def mfg_event_from_test_record(record): def _populate_basic_data(mfg_event, record): """Copies data from the OpenHTF TestRecord to the MfgEvent proto.""" - # TODO: + # TODO(openhtf-team): # * Missing in proto: set run name from metadata. # * `part_tags` field on proto is unused # * `timings` field on proto is unused. @@ -168,19 +169,21 @@ def _attach_record_as_json(mfg_event, record): attachment.type = test_runs_pb2.TEXT_UTF8 -def _convert_object_to_json(obj): +def _convert_object_to_json(obj): # pylint: disable=missing-function-docstring # Since there will be parts of this that may have unicode, either as # measurement or in the logs, we have to be careful and convert everything # to unicode, merge, then encode to UTF-8 to put it into the proto. - json_encoder = json.JSONEncoder(sort_keys=True, indent=2, ensure_ascii=False) - pieces = [] - for piece in json_encoder.iterencode(obj): - if isinstance(piece, bytes): - pieces.append(unicode(piece, errors='replace')) + + def bytes_handler(o): + # For bytes, JSONEncoder will fallback to this function to convert to str. 
+ if six.PY3 and isinstance(o, six.binary_type): + return six.ensure_str(o, encoding='utf-8', errors='replace') else: - pieces.append(piece) + raise TypeError(repr(o) + ' is not JSON serializable') - return (u''.join(pieces)).encode('utf8', errors='replace') + json_encoder = json.JSONEncoder( + sort_keys=True, indent=2, ensure_ascii=False, default=bytes_handler) + return json_encoder.encode(obj).encode('utf-8', errors='replace') def _attach_config(mfg_event, record): @@ -208,7 +211,7 @@ def __init__(self, all_names): self._counts = collections.Counter(all_names) self._seen = collections.Counter() - def make_unique(self, name): + def make_unique(self, name): # pylint: disable=missing-function-docstring count = self._counts[name] assert count >= 1, 'Seeing a new name that was not given to the constructor' if count == 1: @@ -272,12 +275,8 @@ def multidim_measurement_to_attachment(name, measurement): for d in dimensions: if d.suffix is None: suffix = u'' - # Ensure that the suffix is unicode. It's typically str/bytes because - # units.py looks them up against str/bytes. - elif isinstance(d.suffix, unicode): - suffix = d.suffix else: - suffix = d.suffix.decode('utf8') + suffix = six.ensure_text(d.suffix) dims.append({ 'uom_suffix': suffix, 'uom_code': d.code, @@ -374,7 +373,7 @@ def _copy_unidimensional_measurement( elif isinstance(value, bytes): # text_value expects unicode or ascii-compatible strings, so we must # 'decode' it, even if it's actually just garbage bytestring data. - mfg_measurement.text_value = unicode(value, errors='replace') + mfg_measurement.text_value = unicode(value, errors='replace') # pytype: disable=wrong-keyword-args elif isinstance(value, unicode): # Don't waste time and potential errors decoding unicode. 
mfg_measurement.text_value = value diff --git a/openhtf/output/proto/test_runs_converter.py b/openhtf/output/proto/test_runs_converter.py index 0a3d25426..f12a18aed 100644 --- a/openhtf/output/proto/test_runs_converter.py +++ b/openhtf/output/proto/test_runs_converter.py @@ -34,17 +34,11 @@ from openhtf.core import measurements from openhtf.core import test_record from openhtf.output.callbacks import json_factory -from openhtf.util import validators - from openhtf.output.proto import test_runs_pb2 +from openhtf.util import validators +import six -try: - from past.types import unicode # pylint: disable=redefined-builtin,g-import-not-at-top -except ImportError: - pass - -import six # pylint: disable=g-import-not-at-top,g-bad-import-order - +# pylint: disable=g-complex-comprehension # pylint: disable=no-member MIMETYPE_MAP = { @@ -138,9 +132,9 @@ def _attach_json(record, testrun): record: the OpenHTF TestRecord being converted testrun: a TestRun proto """ - record_json = json_factory.OutputToJSON( - inline_attachments=False, - sort_keys=True, indent=2).serialize_test_record(record) + encoded = json_factory.convert_test_record_to_json( + record, inline_attachments=False) + record_json = json_factory.stream_json(encoded, sort_keys=True, indent=2) testrun_param = testrun.info_parameters.add() testrun_param.name = 'OpenHTF_record.json' testrun_param.value_binary = b''.join(r.encode('utf-8') for r in record_json) diff --git a/openhtf/output/servers/dashboard_server.py b/openhtf/output/servers/dashboard_server.py index fcb887a22..76db03c97 100644 --- a/openhtf/output/servers/dashboard_server.py +++ b/openhtf/output/servers/dashboard_server.py @@ -8,28 +8,36 @@ import collections import json import logging -import six import socket import threading import time -import sockjs.tornado -import tornado.web - -from openhtf.output.servers import station_server from openhtf.output.servers import pub_sub +from openhtf.output.servers import station_server from openhtf.output.servers 
import web_gui_server from openhtf.output.web_gui import web_launcher from openhtf.util import data from openhtf.util import multicast +import six +import sockjs.tornado +import tornado.web _LOG = logging.getLogger(__name__) DASHBOARD_SERVER_TYPE = 'dashboard' -StationInfo = collections.namedtuple( - 'StationInfo', - 'cell host port station_id status test_description test_name') + +class StationInfo( # pylint: disable=missing-class-docstring + collections.namedtuple('StationInfo', [ + 'cell', + 'host', + 'port', + 'station_id', + 'status', + 'test_description', + 'test_name', + ])): + pass def _discover(**kwargs): @@ -39,17 +47,18 @@ def _discover(**kwargs): try: result = json.loads(response) except ValueError: - _LOG.warn('Received bad JSON over multicast from %s: %s', host, response) + _LOG.warning('Received bad JSON over multicast from %s: %s', host, + response) try: yield StationInfo(result['cell'], host, result['port'], result['station_id'], 'ONLINE', - result.get('test_description'), - result['test_name']) + result.get('test_description'), result['test_name']) except KeyError: if 'last_activity_time_millis' in result: _LOG.debug('Received old station API response on multicast. 
Ignoring.') else: - _LOG.warn('Received bad multicast response from %s: %s', host, response) + _LOG.warning('Received bad multicast response from %s: %s', host, + response) class StationListHandler(tornado.web.RequestHandler): @@ -118,11 +127,10 @@ def _get_config(self): } def run(self): - _LOG.info( - 'Starting dashboard server at:\n' - ' Local: http://localhost:{port}\n' - ' Remote: http://{host}:{port}' - .format(host=socket.gethostname(), port=self.port)) + _LOG.info('Starting dashboard server at:\n' # pylint: disable=logging-format-interpolation + ' Local: http://localhost:{port}\n' + ' Remote: http://{host}:{port}'.format( + host=socket.gethostname(), port=self.port)) super(DashboardServer, self).run() def stop(self): @@ -135,16 +143,27 @@ def main(): parser = argparse.ArgumentParser( description='Serves web GUI for interacting with multiple OpenHTF ' - 'stations.') - parser.add_argument('--discovery-interval-s', type=int, default=1, - help='Seconds between station discovery attempts.') - parser.add_argument('--launch-web-gui', default=True, action="store_true", - help='Whether to automatically open web GUI.') - parser.add_argument('--no-launch-web-gui', dest="launch_web_gui", - action="store_false", - help='Whether to automatically open web GUI.') - parser.add_argument('--dashboard-server-port', type=int, default=12000, - help='Port on which to serve the dashboard server.') + 'stations.') + parser.add_argument( + '--discovery-interval-s', + type=int, + default=1, + help='Seconds between station discovery attempts.') + parser.add_argument( + '--launch-web-gui', + default=True, + action='store_true', + help='Whether to automatically open web GUI.') + parser.add_argument( + '--no-launch-web-gui', + dest='launch_web_gui', + action='store_false', + help='Whether to automatically open web GUI.') + parser.add_argument( + '--dashboard-server-port', + type=int, + default=12000, + help='Port on which to serve the dashboard server.') # These have default values in 
openhtf.util.multicast.py. parser.add_argument('--station-discovery-address', type=str) diff --git a/openhtf/output/servers/pub_sub.py b/openhtf/output/servers/pub_sub.py index f29e62e54..dc2acfdcd 100644 --- a/openhtf/output/servers/pub_sub.py +++ b/openhtf/output/servers/pub_sub.py @@ -11,28 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +"""Generic pub/sub implementation using SockJS connections.""" import logging +from openhtf import util as htf_util import sockjs.tornado -from openhtf.util import classproperty - _LOG = logging.getLogger(__name__) class PubSub(sockjs.tornado.SockJSConnection): """Generic pub/sub based on SockJS connections.""" - @classproperty + @htf_util.classproperty def _lock(cls): # pylint: disable=no-self-argument """Ensure subclasses don't share subscriber locks by forcing override.""" raise AttributeError( 'The PubSub class should not be instantiated directly. ' 'Instead, subclass it and override the _lock attribute.') - @classproperty + @htf_util.classproperty def subscribers(cls): # pylint: disable=no-self-argument """Ensure subclasses don't share subscribers by forcing override.""" raise AttributeError( @@ -46,23 +45,23 @@ def publish(cls, message, client_filter=None): Args: message: The message to publish. client_filter: A filter function to call passing in each client. Only - clients for whom the function returns True will have the - message sent to them. + clients for whom the function returns True will have the message sent to + them. 
""" - with cls._lock: - for client in cls.subscribers: + with cls._lock: # pylint: disable=not-context-manager + for client in cls.subscribers: # pylint: disable=not-an-iterable if (not client_filter) or client_filter(client): client.send(message) def on_open(self, info): _LOG.debug('New subscriber from %s.', info.ip) - with self._lock: + with self._lock: # pylint: disable=not-context-manager self.subscribers.add(self) self.on_subscribe(info) def on_close(self): _LOG.debug('A client unsubscribed.') - with self._lock: + with self._lock: # pylint: disable=not-context-manager self.subscribers.remove(self) self.on_unsubscribe() diff --git a/openhtf/output/servers/station_server.py b/openhtf/output/servers/station_server.py index ba47f7cdb..117b3ba6a 100644 --- a/openhtf/output/servers/station_server.py +++ b/openhtf/output/servers/station_server.py @@ -11,24 +11,21 @@ import logging import os import re -import six import socket import threading import time import types -import sockjs.tornado - import openhtf -from openhtf.output.callbacks import mfg_inspector from openhtf.output.servers import pub_sub from openhtf.output.servers import web_gui_server from openhtf.util import conf from openhtf.util import data from openhtf.util import functions -from openhtf.util import logs from openhtf.util import multicast from openhtf.util import timeouts +import six +import sockjs.tornado STATION_SERVER_TYPE = 'station' @@ -43,12 +40,16 @@ _WAIT_FOR_ANY_EVENT_POLL_S = 0.05 _WAIT_FOR_EXECUTING_TEST_POLL_S = 0.1 -conf.declare('frontend_throttle_s', default_value=_DEFAULT_FRONTEND_THROTTLE_S, - description=('Min wait time between successive updates to the ' - 'frontend.')) -conf.declare('station_server_port', default_value=0, - description=('Port on which to serve the app. 
If set to zero (the ' - 'default) then an arbitrary port will be chosen.')) +conf.declare( + 'frontend_throttle_s', + default_value=_DEFAULT_FRONTEND_THROTTLE_S, + description=('Min wait time between successive updates to the ' + 'frontend.')) +conf.declare( + 'station_server_port', + default_value=0, + description=('Port on which to serve the app. If set to zero (the ' + 'default) then an arbitrary port will be chosen.')) # These have default values in openhtf.util.multicast.py. conf.declare('station_discovery_address') @@ -75,7 +76,7 @@ def _get_executing_test(): return None, None if len(tests) > 1: - _LOG.warn('Station server does not support multiple executing tests.') + _LOG.warning('Station server does not support multiple executing tests.') test = tests[0] test_state = test.state @@ -123,6 +124,7 @@ def _wait_for_any_event(events, timeout_s): Returns: True if at least one event was set before the timeout expired, else False. """ + def any_event_set(): return any(event.is_set() for event in events) @@ -155,7 +157,7 @@ def run(self): # Note that because logging triggers a call to notify_update(), by # logging a message, we automatically retry publishing the update # after an error occurs. - if error.message == 'dictionary changed size during iteration': + if error.args[0] == 'dictionary changed size during iteration': # These errors occur occasionally and it is infeasible to get rid of # them entirely unless data.convert_to_base_types() is made # thread-safe. Ignore the error and retry quickly. @@ -280,9 +282,12 @@ def _publish_test_state(cls, test_state_dict, message_type): cls._last_message = message def on_subscribe(self, info): - """ - Send the more recent test state to new subscribers when they connect, - unless the test has already completed. + """Send the more recent test state to new subscribers when they connect. + + This is skipped if the test has already completed. + + Args: + info: Subscription info. 
""" test, _ = _get_executing_test() @@ -354,7 +359,8 @@ def get(self, test_uid): phase_descriptors = [ dict(id=id(phase), **data.convert_to_base_types(phase)) - for phase in test.descriptor.phase_group] + for phase in test.descriptor.phase_group + ] # Wrap value in a dict because writing a list directly is prohibited. self.write({'data': phase_descriptors}) @@ -387,12 +393,11 @@ def post(self, test_uid, plug_name): method = getattr(plug, method_name, None) - if not (plug.enable_remote and - isinstance(method, types.MethodType) and + if not (plug.enable_remote and isinstance(method, types.MethodType) and not method_name.startswith('_') and method_name not in plug.disable_remote_attrs): - self.write('Cannot access method %s of plug %s.' % (method_name, - plug_name)) + self.write('Cannot access method %s of plug %s.' % + (method_name, plug_name)) self.set_status(400) return @@ -407,6 +412,8 @@ def post(self, test_uid, plug_name): class BaseHistoryHandler(web_gui_server.CorsRequestHandler): + history_path = None + def initialize(self, history_path): self.history_path = history_path @@ -466,7 +473,7 @@ class HistoryItemHandler(BaseHistoryHandler): """GET endpoint for a test record from the history.""" def get(self, file_name): - # TODO(Kenadia): Implement the history item handler. The implementation + # TODO(kenadia): Implement the history item handler. The implementation # depends on the format used to store test records on disk. self.write('Not implemented.') self.set_status(500) @@ -482,7 +489,7 @@ class HistoryAttachmentsHandler(BaseHistoryHandler): """ def get(self, file_name, attachment_name): - # TODO(Kenadia): Implement the history item handler. The implementation + # TODO(kenadia): Implement the history item handler. The implementation # depends on the format used to store test records on disk. 
self.write('Not implemented.') self.set_status(500) @@ -552,7 +559,7 @@ class StationServer(web_gui_server.WebGuiServer): def __init__(self, history_path=None): # Disable tornado's logging. - # TODO(Kenadia): Enable these logs if verbosity flag is at least -vvv. + # TODO(kenadia): Enable these logs if verbosity flag is at least -vvv. # I think this will require changing how StoreRepsInModule works. # Currently, if we call logs.ARG_PARSER.parse_known_args() multiple # times, we multiply the number of v's that we get. @@ -586,11 +593,16 @@ def __init__(self, history_path=None): # Optionally enable history from disk. if history_path is not None: routes.extend(( - (r'/history', HistoryListHandler, {'history_path': history_path}), - (r'/history/(?P[^/]+)', HistoryItemHandler, - {'history_path': history_path}), + (r'/history', HistoryListHandler, { + 'history_path': history_path + }), + (r'/history/(?P[^/]+)', HistoryItemHandler, { + 'history_path': history_path + }), (r'/history/(?P[^/]+)/attachments/(?P.+)', - HistoryAttachmentsHandler, {'history_path': history_path}), + HistoryAttachmentsHandler, { + 'history_path': history_path + }), )) super(StationServer, self).__init__(routes, port, sockets=sockets) @@ -605,11 +617,10 @@ def run(self): _LOG.info('Announcing station server via multicast on %s:%s', self.station_multicast.address, self.station_multicast.port) self.station_multicast.start() - _LOG.info( - 'Starting station server at:\n' - ' Local: http://localhost:{port}\n' - ' Remote: http://{host}:{port}' - .format(host=socket.gethostname(), port=self.port)) + _LOG.info('Starting station server at:\n' # pylint: disable=logging-format-interpolation + ' Local: http://localhost:{port}\n' + ' Remote: http://{host}:{port}'.format( + host=socket.gethostname(), port=self.port)) super(StationServer, self).run() def stop(self): diff --git a/openhtf/output/servers/web_gui_server.py b/openhtf/output/servers/web_gui_server.py index 2c992aa99..a066cdc5e 100644 --- 
a/openhtf/output/servers/web_gui_server.py +++ b/openhtf/output/servers/web_gui_server.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """Extensible HTTP server serving the OpenHTF Angular frontend.""" import os @@ -105,9 +103,6 @@ def validate_absolute_path(self, root, abspath): class TemplateLoader(tornado.template.Loader): - def __init__(self, root_directory, **kwargs): - super(TemplateLoader, self).__init__(root_directory, **kwargs) - def resolve_path(self, name, parent_path=None): return name @@ -140,7 +135,8 @@ def __init__(self, additional_routes, port, sockets=None): routes, default_handler_class=DefaultHandler, template_loader=TemplateLoader(STATIC_FILES_ROOT), - static_path=STATIC_FILES_ROOT,) + static_path=STATIC_FILES_ROOT, + ) self.server = tornado.httpserver.HTTPServer(application) self.server.add_sockets(sockets) diff --git a/openhtf/output/web_gui/src/app/plugs/user-input-plug.component.html b/openhtf/output/web_gui/src/app/plugs/user-input-plug.component.html index a478a729c..bb84bc562 100644 --- a/openhtf/output/web_gui/src/app/plugs/user-input-plug.component.html +++ b/openhtf/output/web_gui/src/app/plugs/user-input-plug.component.html @@ -10,8 +10,8 @@
- - + + user-input-image
diff --git a/openhtf/plugs/__init__.py b/openhtf/plugs/__init__.py index cdd0286d8..7ba7dc320 100644 --- a/openhtf/plugs/__init__.py +++ b/openhtf/plugs/__init__.py @@ -11,96 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - -"""The plugs module provides boilerplate for accessing hardware. - -Most tests require interaction with external hardware. This module provides -framework support for such interfaces, allowing for automatic setup and -teardown of the objects. +"""The plugs module provides managing plugs. Test phases can be decorated as using Plug objects, which then get passed into the test via parameters. Plugs are all instantiated at the beginning of a test, and all plugs' tearDown() methods are called at the end of a test. It's up to the Plug implementation to do any sort of is-ready check. - -A plug may be made "frontend-aware", allowing it, in conjunction with the -Station API, to update any frontends each time the plug's state changes. See -FrontendAwareBasePlug for more info. - -Example implementation of a plug: - - from openhtf import plugs - - class ExamplePlug(plugs.BasePlug): - '''A Plug that does nothing.''' - - def __init__(self): - print 'Instantiating %s!' % type(self).__name__ - - def DoSomething(self): - print '%s doing something!' % type(self).__name__ - - def tearDown(self): - # This method is optional. If implemented, it will be called at the end - # of the test. - print 'Tearing down %s!' % type(self).__name__ - -Example usage of the above plug: - - from openhtf import plugs - from my_custom_plugs_package import example - - @plugs.plug(example=example.ExamplePlug) - def TestPhase(test, example): - print 'Test phase started!' - example.DoSomething() - print 'Test phase done!' 
- -Putting all this together, when the test is run (with just that phase), you -would see the output (with other framework logs before and after): - - Instantiating ExamplePlug! - Test phase started! - ExamplePlug doing something! - Test phase done! - Tearing down ExamplePlug! - -Plugs will often need to use configuration values. The recommended way -of doing this is with the conf.inject_positional_args decorator: - - from openhtf import plugs - from openhtf.util import conf - - conf.declare('my_config_key', default_value='my_config_value') - - class ExamplePlug(plugs.BasePlug): - '''A plug that requires some configuration.''' - - @conf.inject_positional_args - def __init__(self, my_config_key) - self._my_config = my_config_key - -Note that Plug constructors shouldn't take any other arguments; the -framework won't pass any, so you'll get a TypeError. Any values that are only -known at run time must be either passed into other methods or set via explicit -setter methods. See openhtf/conf.py for details, but with the above -example, you would also need a configuration .yaml file with something like: - - my_config_key: my_config_value - -This will result in the ExamplePlug being constructed with -self._my_config having a value of 'my_config_value'. 
""" import collections import logging +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Text, Tuple, Type, TypeVar, Union -import mutablerecords +import attr from openhtf import util -import openhtf.core.phase_descriptor +from openhtf.core import base_plugs +from openhtf.core import phase_descriptor from openhtf.util import classproperty from openhtf.util import conf from openhtf.util import data @@ -108,28 +36,24 @@ def __init__(self, my_config_key) from openhtf.util import threads import six - _LOG = logging.getLogger(__name__) +_BASE_PLUGS_LOG = base_plugs._LOG # pylint: disable=protected-access +conf.declare( + 'plug_teardown_timeout_s', + default_value=0, + description='Timeout (in seconds) for each plug tearDown function if > 0; ' + 'otherwise, will wait an unlimited time.') -conf.declare('plug_teardown_timeout_s', default_value=0, description= - 'Timeout (in seconds) for each plug tearDown function if > 0; ' - 'otherwise, will wait an unlimited time.') - - -PlugDescriptor = collections.namedtuple('PlugDescriptor', ['mro']) # pylint: disable=invalid-name - -# Placeholder for a specific plug to be provided before test execution. -# -# Use the with_plugs() method to provide the plug before test execution. The -# with_plugs() method checks to make sure the substitute plug is a subclass of -# the PlugPlaceholder's base_class. -PlugPlaceholder = collections.namedtuple('PlugPlaceholder', ['base_class']) # pylint: disable=invalid-name +# TODO(arsharma): Remove this aliases when users have moved to using the core +# library. 
+BasePlug = base_plugs.BasePlug +FrontendAwareBasePlug = base_plugs.FrontendAwareBasePlug -class PhasePlug(mutablerecords.Record( - 'PhasePlug', ['name', 'cls'], {'update_kwargs': True})): - """Information about the use of a plug in a phase.""" +@attr.s(slots=True, frozen=True) +class PlugDescriptor(object): + mro = attr.ib(type=List[Text]) class PlugOverrideError(Exception): @@ -140,81 +64,10 @@ class DuplicatePlugError(Exception): """Raised when the same plug is required multiple times on a phase.""" -class InvalidPlugError(Exception): - """Raised when a plug declaration or requested name is invalid.""" - - -class BasePlug(object): - """All plug types must subclass this type. - - Attributes: - logger: This attribute will be set by the PlugManager (and as such it - doesn't appear here), and is the same logger as passed into test - phases via TestApi. - """ - # Override this to True in subclasses to support remote Plug access. - enable_remote = False - # Allow explicitly disabling remote access to specific attributes. - disable_remote_attrs = set() - # Override this to True in subclasses to support using with_plugs with this - # plug without needing to use placeholder. This will only affect the classes - # that explicitly define this; subclasses do not share the declaration. - auto_placeholder = False - # Default logger to be used only in __init__ of subclasses. - # This is overwritten both on the class and the instance so don't store - # a copy of it anywhere. - logger = _LOG - - @classproperty - def placeholder(cls): - """Returns a PlugPlaceholder for the calling class.""" - return PlugPlaceholder(cls) - - def _asdict(self): - """Return a dictionary representation of this plug's state. - - This is called repeatedly during phase execution on any plugs that are in - use by that phase. The result is reported via the Station API by the - PlugManager (if the Station API is enabled, which is the default). 
- - Note that this method is called in a tight loop, it is recommended that you - decorate it with functions.call_at_most_every() to limit the frequency at - which updates happen (pass a number of seconds to it to limit samples to - once per that number of seconds). - - You can also implement an `as_base_types` function that can return a dict - where the values must be base types at all levels. This can help prevent - recursive copying, which is time intensive. - """ - return {} - - def tearDown(self): - """This method is called automatically at the end of each Test execution.""" - pass - - @classmethod - def uses_base_tear_down(cls): - """Checks whether the tearDown method is the BasePlug implementation.""" - this_tear_down = getattr(cls, 'tearDown') - base_tear_down = getattr(BasePlug, 'tearDown') - return this_tear_down.__code__ is base_tear_down.__code__ - - -class FrontendAwareBasePlug(BasePlug, util.SubscribableStateMixin): - """A plug that notifies of any state updates. - - Plugs inheriting from this class may be used in conjunction with the Station - API to update any frontends each time the plug's state changes. The plug - should call notify_update() when and only when the state returned by _asdict() - changes. - - Since the Station API runs in a separate thread, the _asdict() method of - frontend-aware plugs should be written with thread safety in mind. - """ - enable_remote = True - - -def plug(update_kwargs=True, **plugs_map): +def plug( + update_kwargs: bool = True, + **plugs_map: Union[Type[base_plugs.BasePlug], base_plugs.PlugPlaceholder] +) -> Callable[['phase_descriptor.PhaseT'], 'phase_descriptor.PhaseDescriptor']: """Creates a decorator that passes in plugs when invoked. This function returns a decorator for a function that will replace positional @@ -232,16 +85,18 @@ def plug(update_kwargs=True, **plugs_map): A PhaseDescriptor that will pass plug instances in as kwargs when invoked. 
Raises: - InvalidPlugError: If a type is provided that is not a subclass of BasePlug. + base_plugs.InvalidPlugError: If a type is provided that is not a subclass of + BasePlug. """ for a_plug in plugs_map.values(): - if not (isinstance(a_plug, PlugPlaceholder) - or issubclass(a_plug, BasePlug)): - raise InvalidPlugError( - 'Plug %s is not a subclass of plugs.BasePlug nor a placeholder ' + if not (isinstance(a_plug, base_plugs.PlugPlaceholder) or + issubclass(a_plug, base_plugs.BasePlug)): + raise base_plugs.InvalidPlugError( + 'Plug %s is not a subclass of base_plugs.BasePlug nor a placeholder ' 'for one' % a_plug) - def result(func): + def result( + func: 'phase_descriptor.PhaseT') -> 'phase_descriptor.PhaseDescriptor': """Wrap the given function and return the wrapper. Args: @@ -255,35 +110,39 @@ def result(func): DuplicatePlugError: If a plug name is declared twice for the same function. """ - phase = openhtf.core.phase_descriptor.PhaseDescriptor.wrap_or_copy(func) - duplicates = (frozenset(p.name for p in phase.plugs) & - frozenset(plugs_map)) + phase = phase_descriptor.PhaseDescriptor.wrap_or_copy(func) + duplicates = (frozenset(p.name for p in phase.plugs) & frozenset(plugs_map)) if duplicates: - raise DuplicatePlugError( - 'Plugs %s required multiple times on phase %s' % (duplicates, func)) + raise DuplicatePlugError('Plugs %s required multiple times on phase %s' % + (duplicates, func)) phase.plugs.extend([ - PhasePlug(name, a_plug, update_kwargs=update_kwargs) - for name, a_plug in six.iteritems(plugs_map)]) + base_plugs.PhasePlug(name, a_plug, update_kwargs=update_kwargs) + for name, a_plug in six.iteritems(plugs_map) + ]) return phase + return result class _PlugTearDownThread(threads.KillableThread): """Killable thread that runs a plug's tearDown function.""" - def __init__(self, a_plug, *args, **kwargs): + def __init__(self, a_plug: base_plugs.BasePlug, *args: Any, **kwargs: Any): super(_PlugTearDownThread, self).__init__(*args, **kwargs) self._plug = 
a_plug - def _thread_proc(self): + def _thread_proc(self) -> None: try: self._plug.tearDown() except Exception: # pylint: disable=broad-except # Including the stack trace from ThreadTerminationErrors received when # killed. - _LOG.warning('Exception calling tearDown on %s:', - self._plug, exc_info=True) + _LOG.warning( + 'Exception calling tearDown on %s:', self._plug, exc_info=True) + + +PlugT = TypeVar('PlugT', bound=base_plugs.BasePlug) class PlugManager(object): @@ -297,21 +156,24 @@ class PlugManager(object): main framework thread anyway. Attributes: - _plug_types: Initial set of plug types, additional plug types may be - passed into calls to initialize_plugs(). + _plug_types: Initial set of plug types, additional plug types may be passed + into calls to initialize_plugs(). _plugs_by_type: Dict mapping plug type to plug instance. _plugs_by_name: Dict mapping plug name to plug instance. _plug_descriptors: Dict mapping plug type to plug descriptor. logger: logging.Logger instance that can save logs to the running test - record. + record. """ - def __init__(self, plug_types=None, record_logger=None): + def __init__(self, + plug_types: Set[Type[base_plugs.BasePlug]] = None, + record_logger: Optional[logging.Logger] = None): self._plug_types = plug_types or set() - for plug in self._plug_types: - if isinstance(plug, PlugPlaceholder): - raise InvalidPlugError('Plug %s is a placeholder, replace it using ' - 'with_plugs().' 
% plug) + for plug_type in self._plug_types: + if isinstance(plug_type, base_plugs.PlugPlaceholder): + raise base_plugs.InvalidPlugError( + 'Plug {} is a placeholder, replace it using with_plugs().'.format( + plug_type)) self._plugs_by_type = {} self._plugs_by_name = {} self._plug_descriptors = {} @@ -319,10 +181,10 @@ def __init__(self, plug_types=None, record_logger=None): record_logger = _LOG self.logger = record_logger.getChild('plug') - def as_base_types(self): + def as_base_types(self) -> Dict[Text, Any]: return { 'plug_descriptors': { - name: dict(descriptor._asdict()) # Convert OrderedDict to dict. + name: attr.asdict(descriptor) for name, descriptor in six.iteritems(self._plug_descriptors) }, 'plug_states': { @@ -331,11 +193,12 @@ def as_base_types(self): }, } - def _make_plug_descriptor(self, plug_type): + def _make_plug_descriptor( + self, plug_type: Type[base_plugs.BasePlug]) -> PlugDescriptor: """Returns the plug descriptor, containing info about this plug type.""" return PlugDescriptor(self.get_plug_mro(plug_type)) - def get_plug_mro(self, plug_type): + def get_plug_mro(self, plug_type: Type[base_plugs.BasePlug]) -> List[Text]: """Returns a list of names identifying the plug classes in the plug's MRO. For example: @@ -343,32 +206,40 @@ def get_plug_mro(self, plug_type): Or: ['openhtf.plugs.user_input.UserInput', 'my_module.advanced_user_input.AdvancedUserInput'] + + Args: + plug_type: The plug class to get the MRO for. 
""" - ignored_classes = (BasePlug, FrontendAwareBasePlug) + ignored_classes = (base_plugs.BasePlug, base_plugs.FrontendAwareBasePlug) return [ - self.get_plug_name(base_class) for base_class in plug_type.mro() - if (issubclass(base_class, BasePlug) and + self.get_plug_name(base_class) # pylint: disable=g-complex-comprehension + for base_class in plug_type.mro() + if (issubclass(base_class, base_plugs.BasePlug) and base_class not in ignored_classes) ] - def get_plug_name(self, plug_type): + def get_plug_name(self, plug_type: Type[base_plugs.BasePlug]) -> Text: """Returns the plug's name, which is the class name and module. For example: 'openhtf.plugs.user_input.UserInput' + + Args: + plug_type: The plug class to get the name of. """ return '%s.%s' % (plug_type.__module__, plug_type.__name__) - def initialize_plugs(self, plug_types=None): + def initialize_plugs( + self, + plug_types: Optional[Set[Type[base_plugs.BasePlug]]] = None) -> None: """Instantiate required plugs. Instantiates plug types and saves the instances in self._plugs_by_type for use in provide_plugs(). Args: - plug_types: Plug types may be specified here rather than passed - into the constructor (this is used primarily for unit testing - phases). + plug_types: Plug types may be specified here rather than passed into the + constructor (this is used primarily for unit testing phases). 
""" types = plug_types if plug_types is not None else self._plug_types for plug_type in types: @@ -378,12 +249,13 @@ def initialize_plugs(self, plug_types=None): if plug_type in self._plugs_by_type: continue try: - if not issubclass(plug_type, BasePlug): - raise InvalidPlugError( - 'Plug type "%s" is not an instance of BasePlug' % plug_type) - if plug_type.logger != _LOG: + if not issubclass(plug_type, base_plugs.BasePlug): + raise base_plugs.InvalidPlugError( + 'Plug type "{}" is not an instance of base_plugs.BasePlug'.format( + plug_type)) + if plug_type.logger != _BASE_PLUGS_LOG: # They put a logger attribute on the class itself, overriding ours. - raise InvalidPlugError( + raise base_plugs.InvalidPlugError( 'Do not override "logger" in your plugs.', plug_type) # Override the logger so that __init__'s logging goes into the record. @@ -392,12 +264,12 @@ def initialize_plugs(self, plug_types=None): plug_instance = plug_type() finally: # Now set it back since we'll give the instance a logger in a moment. - plug_type.logger = _LOG - # Set the logger attribute directly (rather than in BasePlug) so we - # don't depend on subclasses' implementation of __init__ to have it - # set. - if plug_instance.logger != _LOG: - raise InvalidPlugError( + plug_type.logger = _BASE_PLUGS_LOG + # Set the logger attribute directly (rather than in base_plugs.BasePlug) + # so we don't depend on subclasses' implementation of __init__ to have + # it set. + if plug_instance.logger != _BASE_PLUGS_LOG: + raise base_plugs.InvalidPlugError( 'Do not set "self.logger" in __init__ in your plugs', plug_type) else: # Now the instance has its own copy of the test logger. @@ -408,7 +280,8 @@ def initialize_plugs(self, plug_types=None): raise self.update_plug(plug_type, plug_instance) - def get_plug_by_class_path(self, plug_name): + def get_plug_by_class_path(self, + plug_name: Text) -> Optional[base_plugs.BasePlug]: """Get a plug instance by name (class path). 
This provides a way for extensions to OpenHTF to access plug instances for @@ -422,7 +295,7 @@ def get_plug_by_class_path(self, plug_name): """ return self._plugs_by_name.get(plug_name) - def update_plug(self, plug_type, plug_value): + def update_plug(self, plug_type: Type[PlugT], plug_value: PlugT) -> None: """Update internal data stores with the given plug value for plug type. Safely tears down the old instance if one was already created, but that's @@ -432,6 +305,10 @@ def update_plug(self, plug_type, plug_value): Note this should only be used inside unittests, as this mechanism is not compatible with RemotePlug support. + + Args: + plug_type: The plug class to update. + plug_value: The plug class instance to store. """ self._plug_types.add(plug_type) if plug_type in self._plugs_by_type: @@ -441,11 +318,13 @@ def update_plug(self, plug_type, plug_value): self._plugs_by_name[plug_name] = plug_value self._plug_descriptors[plug_name] = self._make_plug_descriptor(plug_type) - def provide_plugs(self, plug_name_map): + def provide_plugs( + self, plug_name_map: Iterable[Tuple[Text, Type[base_plugs.BasePlug]]] + ) -> Dict[Text, base_plugs.BasePlug]: """Provide the requested plugs [(name, type),] as {name: plug instance}.""" return {name: self._plugs_by_type[cls] for name, cls in plug_name_map} - def tear_down_plugs(self): + def tear_down_plugs(self) -> None: """Call tearDown() on all instantiated plugs. 
Note that initialize_plugs must have been called before calling @@ -463,9 +342,9 @@ def tear_down_plugs(self): name = '' % plug_type thread = _PlugTearDownThread(plug_instance, name=name) thread.start() - timeout_s = (conf.plug_teardown_timeout_s - if conf.plug_teardown_timeout_s - else None) + timeout_s = ( + conf.plug_teardown_timeout_s + if conf.plug_teardown_timeout_s else None) thread.join(timeout_s) if thread.is_alive(): thread.kill() @@ -474,7 +353,9 @@ def tear_down_plugs(self): self._plugs_by_type.clear() self._plugs_by_name.clear() - def wait_for_plug_update(self, plug_name, remote_state, timeout_s): + def wait_for_plug_update( + self, plug_name: Text, remote_state: Dict[Text, Any], + timeout_s: Union[int, float]) -> Optional[Dict[Text, Any]]: """Wait for a change in the state of a frontend-aware plug. Args: @@ -486,26 +367,30 @@ def wait_for_plug_update(self, plug_name, remote_state, timeout_s): An updated state, or None if the timeout runs out. Raises: - InvalidPlugError: The plug can't be waited on either because it's not in - use or it's not a frontend-aware plug. + base_plugs.InvalidPlugError: The plug can't be waited on either because + it's not in use or it's not a frontend-aware plug. """ - plug = self._plugs_by_name.get(plug_name) + plug_instance = self._plugs_by_name.get(plug_name) - if plug is None: - raise InvalidPlugError('Cannot wait on unknown plug "%s".' % plug_name) + if plug_instance is None: + raise base_plugs.InvalidPlugError( + 'Cannot wait on unknown plug "{}".'.format(plug_name)) - if not isinstance(plug, FrontendAwareBasePlug): - raise InvalidPlugError('Cannot wait on a plug %s that is not an subclass ' - 'of FrontendAwareBasePlug.' 
% plug_name) + if not isinstance(plug_instance, base_plugs.FrontendAwareBasePlug): + raise base_plugs.InvalidPlugError( + 'Cannot wait on a plug {} that is not an subclass ' + 'of FrontendAwareBasePlug.'.format(plug_name)) - state, update_event = plug.asdict_with_event() + state, update_event = plug_instance.asdict_with_event() if state != remote_state: return state if update_event.wait(timeout_s): - return plug._asdict() + return plug_instance._asdict() - def get_frontend_aware_plug_names(self): + def get_frontend_aware_plug_names(self) -> List[Text]: """Returns the names of frontend-aware plugs.""" - return [name for name, plug in six.iteritems(self._plugs_by_name) - if isinstance(plug, FrontendAwareBasePlug)] + return [ + name for name, plug in six.iteritems(self._plugs_by_name) + if isinstance(plug, base_plugs.FrontendAwareBasePlug) + ] diff --git a/openhtf/plugs/cambrionix/__init__.py b/openhtf/plugs/cambrionix/__init__.py index 447f9c676..4049117c2 100644 --- a/openhtf/plugs/cambrionix/__init__.py +++ b/openhtf/plugs/cambrionix/__init__.py @@ -11,31 +11,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Plug for a Cambrionix device.""" import subprocess import time from openhtf.plugs.usb import local_usb + class EtherSync(object): - """EtherSync object for the access of usb device connected to - Cambrionix unit.""" + """EtherSync object for the access of usb device connected to Cambrionix unit.""" port_map = { - '1':'112', - '2':'111', - '3':'114', - '4':'113', - '5':'212', - '6':'211', - '7':'214', - '8':'213', + '1': '112', + '2': '111', + '3': '114', + '4': '113', + '5': '212', + '6': '211', + '7': '214', + '8': '213', } def __init__(self, mac_addr): """Construct a EtherSync object. - Args: - mac_addr: mac address of the Cambrionix unit for EtherSync. + Args: + mac_addr: mac address of the Cambrionix unit for EtherSync. 
""" addr_info = mac_addr.lower().split(':') if len(addr_info) < 6: @@ -45,40 +46,41 @@ def __init__(self, mac_addr): self._addr = ''.join(addr_info[2:]) def get_usb_serial(self, port_num): - """Get the device serial number + """Get the device serial number. Args: - port_num: port number on the Cambrionix unit + port_num: port number on the Cambrionix unit. - Return: - usb device serial number + Returns: + USB device serial number. """ port = self.port_map[str(port_num)] - arg = ''.join(['DEVICE INFO,', self._addr, '.', port]) + arg = 'DEVICE INFO,{}.{}'.format(self._addr, port) cmd = (['esuit64', '-t', arg]) - info = subprocess.check_output(cmd, stderr=subprocess.STDOUT) + info = subprocess.check_output( + cmd, stderr=subprocess.STDOUT).decode('utf-8') serial = None - if "SERIAL" in info: + if 'SERIAL:' in info: serial_info = info.split('SERIAL:')[1] serial = serial_info.split('\n')[0].strip() use_info = info.split('BY')[1].split(' ')[1] if use_info == 'NO': cmd = (['esuit64', '-t', 'AUTO USE ALL']) subprocess.check_output(cmd, stderr=subprocess.STDOUT) - time.sleep(50.0/1000.0) + time.sleep(50.0 / 1000.0) else: raise ValueError('No USB device detected') return serial def open_usb_handle(self, port_num): - """open usb port + """Open USB port. Args: - port_num: port number on the Cambrionix unit + port_num: port number on the Cambrionix unit. - Return: - usb handle + Returns: + USB handle. """ serial = self.get_usb_serial(port_num) return local_usb.LibUsbHandle.open(serial_number=serial) diff --git a/openhtf/plugs/device_wrapping.py b/openhtf/plugs/device_wrapping.py index fbba9d1ac..01472e536 100644 --- a/openhtf/plugs/device_wrapping.py +++ b/openhtf/plugs/device_wrapping.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """OpenHTF base plugs for thinly wrapping existing device abstractions. 
Sometimes you already have a Python interface to a device or instrument; you @@ -23,7 +21,7 @@ import functools import types -import openhtf +from openhtf.core import base_plugs import six @@ -33,7 +31,7 @@ def short_repr(obj, max_len=40): Args: obj: An object for which to return a string representation. max_len: Maximum length of the returned string. Longer reprs will be turned - into a brief descriptive string giving the type and length of obj. + into a brief descriptive string giving the type and length of obj. """ obj_repr = repr(obj) if len(obj_repr) <= max_len: @@ -41,7 +39,7 @@ def short_repr(obj, max_len=40): return '<{} of length {}>'.format(type(obj).__name__, len(obj_repr)) -class DeviceWrappingPlug(openhtf.plugs.BasePlug): +class DeviceWrappingPlug(base_plugs.BasePlug): """A base plug for wrapping existing device abstractions. Attribute access is delegated to the _device attribute, which is normally set @@ -63,23 +61,25 @@ def __init__(self, ble_sniffer_host, ble_sniffer_port): counted on to do sufficient logging, some debug logging is provided here in the plug layer to show which attributes were called and with what arguments. - Args: - device: The device to wrap; must not be None. - Raises: - openhtf.plugs.InvalidPlugError: The _device attribute has the value None + base_plugs.InvalidPlugError: The _device attribute has the value None when attribute access is attempted. """ verbose = True # overwrite on subclass to disable logging_wrapper. def __init__(self, device): + """Constructor. + + Args: + device: The device to wrap; must not be None. 
+ """ super(DeviceWrappingPlug, self).__init__() self._device = device if hasattr(self._device, 'tearDown') and self.uses_base_tear_down(): - self.logger.warning('Wrapped device %s implements a tearDown method, ' - 'but using the no-op BasePlug tearDown method.', - type(self._device)) + self.logger.warning( + 'Wrapped device %s implements a tearDown method, ' + 'but using the no-op BasePlug tearDown method.', type(self._device)) def __setattr__(self, name, value): if (name == '_device' or '_device' not in self.__dict__ or @@ -90,11 +90,8 @@ def __setattr__(self, name, value): def __getattr__(self, attr): if self._device is None: - raise openhtf.plugs.InvalidPlugError( + raise base_plugs.InvalidPlugError( 'DeviceWrappingPlug instances must set the _device attribute.') - if attr == 'as_base_types': - return super(DeviceWrappingPlug, self).__getattr__(attr) - attribute = getattr(self._device, attr) if not self.verbose or not isinstance(attribute, types.MethodType): @@ -102,13 +99,12 @@ def __getattr__(self, attr): # Attribute callable; return a wrapper that logs calls with args and kwargs. functools.wraps(attribute, assigned=('__name__', '__doc__')) + def logging_wrapper(*args, **kwargs): """Wraps a callable with a logging statement.""" args_strings = tuple(short_repr(arg) for arg in args) - kwargs_strings = tuple( - ('%s=%s' % (key, short_repr(val)) - for key, val in six.iteritems(kwargs)) - ) + kwargs_strings = tuple(('%s=%s' % (key, short_repr(val)) + for key, val in six.iteritems(kwargs))) log_line = '%s calling "%s" on device.' 
% (type(self).__name__, attr) if args_strings or kwargs_strings: log_line += ' Args: \n %s' % (', '.join(args_strings + kwargs_strings)) @@ -116,4 +112,3 @@ def logging_wrapper(*args, **kwargs): return attribute(*args, **kwargs) return logging_wrapper - diff --git a/openhtf/plugs/generic/serial_collection.py b/openhtf/plugs/generic/serial_collection.py index fb2bb0e65..f640a45e0 100644 --- a/openhtf/plugs/generic/serial_collection.py +++ b/openhtf/plugs/generic/serial_collection.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """OpenHTF plug for serial port. Allows for writing out to a serial port. @@ -20,18 +18,21 @@ import logging import threading +from typing import Optional + +from openhtf.core import base_plugs +from openhtf.util import conf try: - import serial + # pylint: disable=g-import-not-at-top + import serial # pytype: disable=import-error + # pylint: enable=g-import-not-at-top except ImportError: - logging.error('Failed to import pyserial. Please install the `serial_collection_plug` extra, ' - 'e.g. via `pip install openhtf[serial_collection_plug]`.') + logging.error( + 'Failed to import pyserial. Please install the `serial_collection_plug` extra, ' + 'e.g. via `pip install openhtf[serial_collection_plug]`.') raise -import openhtf -from openhtf.util import conf - - conf.declare( 'serial_collection_port', description='Port on which to collect serial data.', @@ -42,7 +43,7 @@ default_value=115200) -class SerialCollectionPlug(openhtf.plugs.BasePlug): +class SerialCollectionPlug(base_plugs.BasePlug): """Plug that collects data from a serial port. 
Spawns a thread that will open the configured serial port, continuously @@ -55,19 +56,24 @@ class SerialCollectionPlug(openhtf.plugs.BasePlug): # Serial library can raise these exceptions SERIAL_EXCEPTIONS = (serial.SerialException, ValueError) + _serial = None # type: serial.Serial + _serial_port = None # type: int + _collect = None # type: bool + _collection_thread = None # type: Optional[threading.Thread] + @conf.inject_positional_args def __init__(self, serial_collection_port, serial_collection_baud): super(SerialCollectionPlug, self).__init__() # Instantiate the port with no name, then add the name, so it won't be # opened until the collection context is entered. - self._serial = serial.Serial(port=None, - baudrate=serial_collection_baud, - timeout=1) + self._serial = serial.Serial( + port=None, baudrate=serial_collection_baud, timeout=1) self._serial.port = serial_collection_port self._collect = False self._collection_thread = None def start_collection(self, dest): + def _poll(): try: with open(dest, 'w+') as outfile: @@ -75,14 +81,14 @@ def _poll(): data = self._serial.readline().decode() outfile.write(data) except self.SERIAL_EXCEPTIONS: - self.logger.error('Serial port error. Stopping data collection.', - exc_info=True) + self.logger.error( + 'Serial port error. Stopping data collection.', exc_info=True) self._collect = True self._collection_thread = threading.Thread(target=_poll) self._collection_thread.daemon = True - self.logger.debug( - 'Starting serial data collection on port %s.' % self._serial.port) + self.logger.debug('Starting serial data collection on port %s.' 
% + self._serial.port) self._serial.open() self._collection_thread.start() @@ -94,8 +100,7 @@ def is_collecting(self): def stop_collection(self): if not self.is_collecting: - self.logger.warning( - 'Data collection was not running, cannot be stopped.') + self.logger.warning('Data collection was not running, cannot be stopped.') return self._collect = False self._collection_thread.join() diff --git a/openhtf/plugs/usb/__init__.py b/openhtf/plugs/usb/__init__.py index 2956967a6..a2fad00bc 100644 --- a/openhtf/plugs/usb/__init__.py +++ b/openhtf/plugs/usb/__init__.py @@ -1,4 +1,3 @@ - # Copyright 2014 Google Inc. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """Plugs that provide access to USB devices via ADB/Fastboot. For details of what these interfaces look like, see adb_device.py and @@ -31,7 +28,7 @@ def MyPhase(test, adb): import logging import time -import openhtf.plugs as plugs +from openhtf.core import base_plugs from openhtf.plugs import cambrionix from openhtf.plugs.usb import adb_device from openhtf.plugs.usb import adb_protocol @@ -53,8 +50,8 @@ def MyPhase(test, adb): @functions.call_once def init_dependent_flags(): parser = argparse.ArgumentParser( - 'USB Plug flags', parents=[ - adb_protocol.ARG_PARSER, fastboot_protocol.ARG_PARSER], + 'USB Plug flags', + parents=[adb_protocol.ARG_PARSER, fastboot_protocol.ARG_PARSER], add_help=False) parser.parse_known_args() @@ -96,7 +93,7 @@ def _open_usb_handle(serial_number=None, **kwargs): return local_usb.LibUsbHandle.open(serial_number=serial_number, **kwargs) -class FastbootPlug(plugs.BasePlug): +class FastbootPlug(base_plugs.BasePlug): """Plug that provides fastboot.""" def __init__(self): @@ -114,7 +111,7 @@ def __getattr__(self, attr): return getattr(self._device, attr) 
-class AdbPlug(plugs.BasePlug): +class AdbPlug(base_plugs.BasePlug): """Plug that provides ADB.""" serial_number = None @@ -149,8 +146,7 @@ def connect(self): interface_class=adb_device.CLASS, interface_subclass=adb_device.SUBCLASS, interface_protocol=adb_device.PROTOCOL, - serial_number=self.serial_number), - **kwargs) + serial_number=self.serial_number), **kwargs) def __getattr__(self, attr): """Forward other attributes to the device.""" @@ -160,16 +156,17 @@ def __getattr__(self, attr): class AndroidTriggers(object): # pylint: disable=invalid-name """Test start and stop triggers for Android devices.""" + serial_number = None + @classmethod def _try_open(cls): """Try to open a USB handle.""" handle = None - for usb_cls, subcls, protocol in [(adb_device.CLASS, - adb_device.SUBCLASS, - adb_device.PROTOCOL), - (fastboot_device.CLASS, - fastboot_device.SUBCLASS, - fastboot_device.PROTOCOL)]: + for usb_cls, subcls, protocol in [ + (adb_device.CLASS, adb_device.SUBCLASS, adb_device.PROTOCOL), + (fastboot_device.CLASS, fastboot_device.SUBCLASS, + fastboot_device.PROTOCOL) + ]: try: handle = local_usb.LibUsbHandle.open( serial_number=cls.serial_number, @@ -190,8 +187,10 @@ def _try_open(cls): @classmethod def test_start_frontend(cls): """Start when frontend event comes, but get serial from USB.""" - prompt_for_test_start('Connect Android device and press ENTER.', - text_input=False)() + # TODO(arsharma): Reenable after reworking this; one cannot just directly + # call a phase. + # prompt_for_test_start( + # message='Connect Android device and press ENTER.', text_input=False)() return cls.test_start() @classmethod diff --git a/openhtf/plugs/usb/adb_device.py b/openhtf/plugs/usb/adb_device.py index bdfaba001..ede2c104b 100644 --- a/openhtf/plugs/usb/adb_device.py +++ b/openhtf/plugs/usb/adb_device.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. - - """User-facing interface to an ADB device. This module provides the user-facing interface to controlling ADB devices with @@ -30,24 +28,25 @@ class, subclass, and protocol. """ +# pytype: skip-file + import logging import os.path -try: - from M2Crypto import RSA -except ImportError: - logging.error('Failed to import M2Crypto, did you pip install ' - 'openhtf[usb_plugs]?') - raise - from openhtf.plugs.usb import adb_protocol from openhtf.plugs.usb import filesync_service from openhtf.plugs.usb import shell_service from openhtf.plugs.usb import usb_exceptions - from openhtf.util import timeouts import six +try: + from M2Crypto import RSA # pylint: disable=g-import-not-at-top +except ImportError: + logging.error('Failed to import M2Crypto, did you pip install ' + 'openhtf[usb_plugs]?') + raise + # USB interface class, subclass, and protocol for matching against. CLASS = 0xFF SUBCLASS = 0x42 @@ -89,10 +88,10 @@ def __init__(self, adb_connection): adb_connection) def __str__(self): - return '<%s: %s(%s) @%s>' % (type(self).__name__, - self._adb_connection.serial, - self._adb_connection.systemtype, - self._adb_connection.transport) + return '<%s: %s(%s) @%s>' % ( + type(self).__name__, self._adb_connection.serial, + self._adb_connection.systemtype, self._adb_connection.transport) + __repr__ = __str__ def get_system_type(self): @@ -128,8 +127,8 @@ def install(self, apk_path, destination_dir=None, timeout_ms=None): basename = os.path.basename(apk_path) destination_path = destination_dir + basename self.push(apk_path, destination_path, timeout_ms=timeout_ms) - return self.Shell('pm install -r "%s"' % destination_path, - timeout_ms=timeout_ms) + return self.Shell( + 'pm install -r "%s"' % destination_path, timeout_ms=timeout_ms) def push(self, source_file, device_filename, timeout_ms=None): """Push source_file to file on device. 
@@ -147,7 +146,9 @@ def push(self, source_file, device_filename, timeout_ms=None): source_file = open(source_file) self.filesync_service.send( - source_file, device_filename, mtime=mtime, + source_file, + device_filename, + mtime=mtime, timeout=timeouts.PolledTimeout.from_millis(timeout_ms)) def pull(self, device_filename, dest_file=None, timeout_ms=None): @@ -181,7 +182,7 @@ def command(self, command, raw=False, timeout_ms=None): return self.shell_service.command( str(command), raw=raw, timeout_ms=timeout_ms) - Shell = command #pylint: disable=invalid-name + Shell = command # pylint: disable=invalid-name def async_command(self, command, raw=False, timeout_ms=None): """See shell_service.ShellService.async_command().""" @@ -210,7 +211,7 @@ def _check_remote_command(self, destination, timeout_ms, success_msgs=None): stream = self._adb_connection.open_stream(destination, timeout) if not stream: raise usb_exceptions.AdbStreamUnavailableError( - 'Service %s not supported', destination) + 'Service %s not supported' % destination) try: message = stream.read(timeout_ms=timeout) # Some commands report success messages, ignore them. @@ -221,7 +222,7 @@ def _check_remote_command(self, destination, timeout_ms, success_msgs=None): # We expect this if the device is rebooting. 
return raise - raise usb_exceptions.AdbRemoteError('Device message: %s', message) + raise usb_exceptions.AdbRemoteError('Device message: %s' % message) def reboot(self, destination='', timeout_ms=None): """Reboot device, specify 'bootloader' for fastboot.""" @@ -233,9 +234,9 @@ def remount(self, timeout_ms=None): def root(self, timeout_ms=None): """Restart adbd as root on device.""" - self._check_remote_command('root:', timeout_ms, - ['already running as root', - 'restarting adbd as root']) + self._check_remote_command( + 'root:', timeout_ms, + ['already running as root', 'restarting adbd as root']) @classmethod def connect(cls, usb_handle, **kwargs): diff --git a/openhtf/plugs/usb/adb_message.py b/openhtf/plugs/usb/adb_message.py index c17c75ba6..117e4b5a5 100644 --- a/openhtf/plugs/usb/adb_message.py +++ b/openhtf/plugs/usb/adb_message.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """This module contains a class to encapsulate ADB messages. See the following in the Android source for more details: @@ -40,6 +38,8 @@ message entity. See adb_protocol.py for the stateful components. 
""" +# pytype: skip-file + import collections import logging import string @@ -52,6 +52,7 @@ _LOG = logging.getLogger(__name__) + def make_wire_commands(*ids): """Assemble the commands.""" cmd_to_wire = { @@ -61,20 +62,20 @@ def make_wire_commands(*ids): return cmd_to_wire, wire_to_cmd -class RawAdbMessage(collections.namedtuple('RawAdbMessage', - ['cmd', 'arg0', 'arg1', - 'data_length', 'data_checksum', - 'magic'])): +class RawAdbMessage( + collections.namedtuple( + 'RawAdbMessage', + ['cmd', 'arg0', 'arg1', 'data_length', 'data_checksum', 'magic'])): """Helper class for handling the struct -> AdbMessage mapping.""" def to_adb_message(self, data): """Turn the data into an ADB message.""" - message = AdbMessage(AdbMessage.WIRE_TO_CMD.get(self.cmd), - self.arg0, self.arg1, data) + message = AdbMessage( + AdbMessage.WIRE_TO_CMD.get(self.cmd), self.arg0, self.arg1, data) if (len(data) != self.data_length or message.data_crc32 != self.data_checksum): raise usb_exceptions.AdbDataIntegrityError( - '%s (%s) received invalid data: %s', message, self, repr(data)) + '%s (%s) received invalid data: %s' % (message, self, repr(data))) return message @@ -152,12 +153,12 @@ def read_message(self, timeout): raise usb_exceptions.AdbProtocolError('Adb connection lost') try: - raw_message = RawAdbMessage(*struct.unpack( - AdbMessage.HEADER_STRUCT_FORMAT, raw_header)) + raw_message = RawAdbMessage( + *struct.unpack(AdbMessage.HEADER_STRUCT_FORMAT, raw_header)) except struct.error as exception: raise usb_exceptions.AdbProtocolError( - 'Unable to unpack ADB command (%s): %s (%s)', - AdbMessage.HEADER_STRUCT_FORMAT, raw_header, exception) + 'Unable to unpack ADB command (%s): %s (%s)' % + (AdbMessage.HEADER_STRUCT_FORMAT, raw_header, exception)) if raw_message.data_length > 0: if timeout.has_expired(): @@ -181,8 +182,8 @@ def read_until(self, expected_commands, timeout): exceptions that may be raised. 
Args: - expected_commands: Iterable of expected command responses, like - ('CNXN', 'AUTH'). + expected_commands: Iterable of expected command responses, like ('CNXN', + 'AUTH'). timeout: timeouts.PolledTimeout object to use for timeout. Returns: @@ -197,7 +198,7 @@ def read_until(self, expected_commands, timeout): lambda m: m.command in expected_commands, 0) if msg.command not in expected_commands: raise usb_exceptions.AdbTimeoutError( - 'Timed out establishing connection, waiting for: %s', + 'Timed out establishing connection, waiting for: %s' % expected_commands) return msg @@ -247,13 +248,7 @@ class AdbMessage(object): this message. To send a message over the header, send its header, followed by its data if it has any. - Attributes: - header - command - arg0 - arg1 - data - magic + Attributes: header command arg0 arg1 data magic """ PRINTABLE_DATA = set(string.printable) - set(string.whitespace) @@ -264,7 +259,7 @@ class AdbMessage(object): def __init__(self, command, arg0=0, arg1=0, data=''): if command not in self.CMD_TO_WIRE: - raise usb_exceptions.AdbProtocolError('Unrecognized ADB command: %s', + raise usb_exceptions.AdbProtocolError('Unrecognized ADB command: %s' % command) self._command = self.CMD_TO_WIRE[command] self.arg0 = arg0 @@ -275,9 +270,8 @@ def __init__(self, command, arg0=0, arg1=0, data=''): @property def header(self): """The message header.""" - return struct.pack( - self.HEADER_STRUCT_FORMAT, self._command, self.arg0, self.arg1, - len(self.data), self.data_crc32, self.magic) + return struct.pack(self.HEADER_STRUCT_FORMAT, self._command, self.arg0, + self.arg1, len(self.data), self.data_crc32, self.magic) @property def command(self): @@ -286,18 +280,15 @@ def command(self): def __str__(self): return '<%s: %s(%s, %s): %s (%s bytes)>' % ( - type(self).__name__, - self.command, - self.arg0, - self.arg1, - ''.join(char if char in self.PRINTABLE_DATA - else '.' 
for char in self.data[:64]), - len(self.data)) + type(self).__name__, self.command, self.arg0, self.arg1, ''.join( + char if char in self.PRINTABLE_DATA else '.' + for char in self.data[:64]), len(self.data)) + __repr__ = __str__ @property def data_crc32(self): - """Return the sum of all the data bytes. + """Returns the sum of all the data bytes. The "crc32" used by ADB is actually just a sum of all the bytes, but we name this data_crc32 to be consistent with ADB. diff --git a/openhtf/plugs/usb/adb_protocol.py b/openhtf/plugs/usb/adb_protocol.py index 42dbde03a..f2359070c 100644 --- a/openhtf/plugs/usb/adb_protocol.py +++ b/openhtf/plugs/usb/adb_protocol.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """ADB protocol implementation. Implements the ADB protocol as seen in Android's adb/adbd binaries, but only the @@ -77,28 +75,28 @@ """ import collections +import enum import itertools import logging +import sys import threading -from enum import Enum - from openhtf.plugs.usb import adb_message from openhtf.plugs.usb import usb_exceptions from openhtf.util import argv -from openhtf.util import exceptions from openhtf.util import timeouts +import six from six.moves import queue - ADB_MESSAGE_LOG = False -ARG_PARSER = argv.ModuleParser() -ARG_PARSER.add_argument('--adb_messsage_log', - action=argv.StoreTrueInModule, - target='%s.ADB_MESSAGE_LOG' % __name__, - help='Set to True to save all incoming and outgoing ' - 'AdbMessages and print them on Close().') +ARG_PARSER = argv.module_parser() +ARG_PARSER.add_argument( + '--adb_messsage_log', + action=argv.StoreTrueInModule, + target='%s.ADB_MESSAGE_LOG' % __name__, + help='Set to True to save all incoming and outgoing ' + 'AdbMessages and print them on Close().') _LOG = logging.getLogger(__name__) @@ -158,10 +156,10 @@ def is_closed(self): return 
self._transport.is_closed() def __str__(self): - return '<%s: (%s, %s->%s)>' % (type(self).__name__, - self._destination, + return '<%s: (%s, %s->%s)>' % (type(self).__name__, self._destination, self._transport.local_id, self._transport.remote_id) + __repr__ = __str__ def write(self, data, timeout_ms=None): @@ -169,8 +167,8 @@ def write(self, data, timeout_ms=None): Args: data: Data to write. - timeout_ms: Timeout to use for the write/Ack transaction, in - milliseconds (or as a PolledTimeout object). + timeout_ms: Timeout to use for the write/Ack transaction, in milliseconds + (or as a PolledTimeout object). Raises: AdbProtocolError: If an ACK is not received. @@ -180,8 +178,8 @@ def write(self, data, timeout_ms=None): timeout = timeouts.PolledTimeout.from_millis(timeout_ms) # Break the data up into our transport's maxdata sized WRTE messages. while data: - self._transport.write( - data[:self._transport.adb_connection.maxdata], timeout) + self._transport.write(data[:self._transport.adb_connection.maxdata], + timeout) data = data[self._transport.adb_connection.maxdata:] def read(self, length=0, timeout_ms=None): @@ -205,8 +203,8 @@ def read(self, length=0, timeout_ms=None): AdbStreamClosedError: The stream is already closed. AdbTimeoutError: Timed out waiting for a message. """ - return self._transport.read( - length, timeouts.PolledTimeout.from_millis(timeout_ms)) + return self._transport.read(length, + timeouts.PolledTimeout.from_millis(timeout_ms)) def read_until_close(self, timeout_ms=None): """Yield data until this stream is closed. @@ -233,7 +231,7 @@ def close(self, timeout_ms=100): self._transport.close(timeout_ms) -class AdbStreamTransport(object): # pylint: disable=too-many-instance-attributes +class AdbStreamTransport(object): # pylint: disable=too-many-instance-attributes """This class encapsulates the transport aspect of an ADB stream. 
This class handles the interface between AdbStreams and an AdbConnection, @@ -243,12 +241,17 @@ class AdbStreamTransport(object): # pylint: disable=too-many-instance-attributes doesn't have to maintain it for many AdbStreams. Attributes: + adb_connection: The connection to the Android device. local_id: The local stream id for the stream using this transport. remote_id: The remote stream id for the stream using this transport. message_queue: The Queue of AdbMessages intended for this stream. - closed: True if this transport has been closed, from either end. + closed_state: ClosedState. """ - ClosedState = Enum('ClosedState', ['CLOSED', 'PENDING', 'OPEN']) + + class ClosedState(enum.Enum): + CLOSED = 'CLOSED' + PENDING = 'PENDING' + OPEN = 'OPEN' def __init__(self, adb_connection, local_id, message_queue): self.adb_connection = adb_connection @@ -269,9 +272,9 @@ def __init__(self, adb_connection, local_id, message_queue): self._reader_lock = threading.Lock() def __str__(self): - return '<%s: (%s->%s)>' % (type(self).__name__, - self.local_id, + return '<%s: (%s->%s)>' % (type(self).__name__, self.local_id, self.remote_id) + __repr__ = __str__ def _set_or_check_remote_id(self, remote_id): @@ -281,8 +284,8 @@ def _set_or_check_remote_id(self, remote_id): self.remote_id = remote_id self.closed_state = self.ClosedState.OPEN elif self.remote_id != remote_id: - raise usb_exceptions.AdbProtocolError( - '%s remote-id change to %s', self, remote_id) + raise usb_exceptions.AdbProtocolError('%s remote-id change to %s' % + (self, remote_id)) def _send_command(self, command, timeout, data=''): """Send the given command/data over this transport. @@ -298,15 +301,15 @@ def _send_command(self, command, timeout, data=''): data: If provided, data to send with the AdbMessage. 
""" if len(data) > self.adb_connection.maxdata: - raise usb_exceptions.AdbProtocolError('Message data too long (%s>%s): %s', - len(data), - self.adb_connection.maxdata, data) + raise usb_exceptions.AdbProtocolError( + 'Message data too long (%s>%s): %s' % + (len(data), self.adb_connection.maxdata, data)) if not self.remote_id: # If we get here, we probably missed the OKAY response to our OPEN. We # should have failed earlier, but in case someone does something tricky # with multiple threads, we sanity check this here. - raise usb_exceptions.AdbProtocolError('%s send before OKAY: %s', - self, data) + raise usb_exceptions.AdbProtocolError('%s send before OKAY: %s' % + (self, data)) self.adb_connection.transport.write_message( adb_message.AdbMessage(command, self.local_id, self.remote_id, data), timeout) @@ -332,13 +335,13 @@ def _handle_message(self, message, handle_wrte=True): self._set_or_check_remote_id(message.arg0) if not self._expecting_okay: raise usb_exceptions.AdbProtocolError( - '%s received unexpected OKAY: %s', self, message) + '%s received unexpected OKAY: %s' % (self, message)) self._expecting_okay = False elif message.command == 'CLSE': self.closed_state = self.ClosedState.CLOSED elif not handle_wrte: raise usb_exceptions.AdbProtocolError( - '%s received WRTE before OKAY/CLSE: %s', self, message) + '%s received WRTE before OKAY/CLSE: %s' % (self, message)) else: with self._read_buffer_lock: self._read_buffer.append(message.data) @@ -402,7 +405,7 @@ def _read_messages_until_true(self, predicate, timeout): self._message_received.wait(timeout.remaining) if timeout.has_expired(): raise usb_exceptions.AdbTimeoutError( - '%s timed out reading messages.', self) + '%s timed out reading messages.' % self) finally: # Make sure we release this even if an exception occurred. self._message_received.release() @@ -425,8 +428,8 @@ def ensure_opened(self, timeout): Raises: AdbProtocolError: If we receive a WRTE message instead of OKAY/CLSE. 
""" - self._handle_message(self.adb_connection.read_for_stream(self, timeout), - handle_wrte=False) + self._handle_message( + self.adb_connection.read_for_stream(self, timeout), handle_wrte=False) return self.is_open() def is_open(self): @@ -458,13 +461,13 @@ def write(self, data, timeout): """Write data to this stream, using the given timeouts.PolledTimeout.""" if not self.remote_id: raise usb_exceptions.AdbStreamClosedError( - 'Cannot write() to half-opened %s', self) + 'Cannot write() to half-opened %s' % self) if self.closed_state != self.ClosedState.OPEN: - raise usb_exceptions.AdbStreamClosedError( - 'Cannot write() to closed %s', self) + raise usb_exceptions.AdbStreamClosedError('Cannot write() to closed %s' % + self) elif self._expecting_okay: raise usb_exceptions.AdbProtocolError( - 'Previous WRTE failed, %s in unknown state', self) + 'Previous WRTE failed, %s in unknown state' % self) # Make sure we only have one WRTE in flight at a time, because ADB doesn't # identify which WRTE it is ACK'ing when it sends the OKAY message back. @@ -554,7 +557,7 @@ def __init__(self, transport, maxdata, remote_banner): try: self.systemtype, self.serial, self.banner = remote_banner.split(':', 2) except ValueError: - raise usb_exceptions.AdbProtocolError('Received malformed banner %s', + raise usb_exceptions.AdbProtocolError('Received malformed banner %s' % remote_banner) self.transport = transport self.maxdata = maxdata @@ -613,7 +616,7 @@ def _handle_message_for_stream(self, stream_transport, message, timeout): """ if message.command not in ('OKAY', 'CLSE', 'WRTE'): raise usb_exceptions.AdbProtocolError( - '%s received unexpected message: %s', self, message) + '%s received unexpected message: %s' % self, message) if message.arg1 == stream_transport.local_id: # Ack writes immediately. @@ -621,11 +624,11 @@ def _handle_message_for_stream(self, stream_transport, message, timeout): # Make sure we don't get a WRTE before an OKAY/CLSE message. 
if not stream_transport.remote_id: raise usb_exceptions.AdbProtocolError( - '%s received WRTE before OKAY/CLSE: %s', - stream_transport, message) - self.transport.write_message(adb_message.AdbMessage( - 'OKAY', stream_transport.local_id, stream_transport.remote_id), - timeout) + '%s received WRTE before OKAY/CLSE: %s' % + (stream_transport, message)) + self.transport.write_message( + adb_message.AdbMessage('OKAY', stream_transport.local_id, + stream_transport.remote_id), timeout) elif message.command == 'CLSE': self.close_stream_transport(stream_transport, timeout) return message @@ -672,9 +675,9 @@ def open_stream(self, destination, timeout_ms=None): self.transport.write_message( adb_message.AdbMessage( command='OPEN', - arg0=stream_transport.local_id, arg1=0, - data=destination + '\0'), - timeout) + arg0=stream_transport.local_id, + arg1=0, + data=destination + '\0'), timeout) if not stream_transport.ensure_opened(timeout): return None return AdbStream(destination, stream_transport) @@ -700,9 +703,9 @@ def close_stream_transport(self, stream_transport, timeout): del self._stream_transport_map[stream_transport.local_id] # If we never got a remote_id, there's no CLSE message to send. 
if stream_transport.remote_id: - self.transport.write_message(adb_message.AdbMessage( - 'CLSE', stream_transport.local_id, stream_transport.remote_id), - timeout) + self.transport.write_message( + adb_message.AdbMessage('CLSE', stream_transport.local_id, + stream_transport.remote_id), timeout) return True return False @@ -727,7 +730,7 @@ def streaming_command(self, service, command='', timeout_ms=None): stream = self.open_stream('%s:%s' % (service, command), timeout) if not stream: raise usb_exceptions.AdbStreamUnavailableError( - '%s does not support service: %s', self, service) + '%s does not support service: %s' % (self, service)) for data in stream.read_until_close(timeout): yield data @@ -760,12 +763,12 @@ def read_for_stream(self, stream_transport, timeout_ms=None): corresponding CLSE message, and this AdbStream will be marked as closed. Args: - stream_transport: The AdbStreamTransport for the stream that is reading - an AdbMessage from this AdbConnection. + stream_transport: The AdbStreamTransport for the stream that is reading an + AdbMessage from this AdbConnection. timeout_ms: If provided, timeout, in milliseconds, to use. Note this timeout applies to this entire call, not for each individual Read, since - there may be multiple reads if messages for other streams are read. - This argument may be a timeouts.PolledTimeout. + there may be multiple reads if messages for other streams are read. This + argument may be a timeouts.PolledTimeout. Returns: AdbMessage that was read, guaranteed to be one of 'OKAY', 'CLSE', or @@ -811,8 +814,8 @@ def read_for_stream(self, stream_transport, timeout_ms=None): self._reader_lock.release() if timeout.has_expired(): - raise usb_exceptions.AdbTimeoutError( - 'Read timed out for %s', stream_transport) + raise usb_exceptions.AdbTimeoutError('Read timed out for %s' % + stream_transport) # The stream is no longer in the map, so it's closed, but check for any # queued messages. 
@@ -820,33 +823,36 @@ def read_for_stream(self, stream_transport, timeout_ms=None): return stream_transport.message_queue.get_nowait() except queue.Empty: raise usb_exceptions.AdbStreamClosedError( - 'Attempt to read from closed or unknown %s', stream_transport) + 'Attempt to read from closed or unknown %s' % stream_transport) @classmethod - def connect(cls, transport, rsa_keys=None, timeout_ms=1000, + def connect(cls, + transport, + rsa_keys=None, + timeout_ms=1000, auth_timeout_ms=100): """Establish a new connection to a device, connected via transport. Args: - transport: A transport to use for reads/writes from/to the device, - usually an instance of UsbHandle, but really it can be anything with - read() and write() methods. + transport: A transport to use for reads/writes from/to the device, usually + an instance of UsbHandle, but really it can be anything with read() and + write() methods. rsa_keys: List of AuthSigner subclass instances to be used for authentication. The device can either accept one of these via the sign method, or we will send the result of get_public_key from the first one if the device doesn't accept any of them. - timeout_ms: Timeout to wait for the device to respond to our CNXN - request. Actual timeout may take longer if the transport object passed - has a longer default timeout than timeout_ms, or if auth_timeout_ms is - longer than timeout_ms and public key auth is used. This argument may - be a PolledTimeout object. + timeout_ms: Timeout to wait for the device to respond to our CNXN request. + Actual timeout may take longer if the transport object passed has a + longer default timeout than timeout_ms, or if auth_timeout_ms is longer + than timeout_ms and public key auth is used. This argument may be a + PolledTimeout object. auth_timeout_ms: Timeout to wait for when sending a new public key. This is only relevant when we send a new public key. The device shows a - dialog and this timeout is how long to wait for that dialog. 
If used - in automation, this should be low to catch such a case as a failure - quickly; while in interactive settings it should be high to allow - users to accept the dialog. We default to automation here, so it's low - by default. This argument may be a PolledTimeout object. + dialog and this timeout is how long to wait for that dialog. If used in + automation, this should be low to catch such a case as a failure + quickly; while in interactive settings it should be high to allow users + to accept the dialog. We default to automation here, so it's low by + default. This argument may be a PolledTimeout object. Returns: An instance of AdbConnection that is connected to the device. @@ -864,9 +870,10 @@ def connect(cls, transport, rsa_keys=None, timeout_ms=1000, adb_transport = adb_message.AdbTransportAdapter(transport) adb_transport.write_message( adb_message.AdbMessage( - command='CNXN', arg0=ADB_VERSION, arg1=MAX_ADB_DATA, - data='host::%s\0' % ADB_BANNER), - timeout) + command='CNXN', + arg0=ADB_VERSION, + arg1=MAX_ADB_DATA, + data='host::%s\0' % ADB_BANNER), timeout) msg = adb_transport.read_until(('AUTH', 'CNXN'), timeout) if msg.command == 'CNXN': @@ -880,14 +887,15 @@ def connect(cls, transport, rsa_keys=None, timeout_ms=1000, # Loop through our keys, signing the last 'banner' or token. for rsa_key in rsa_keys: if msg.arg0 != cls.AUTH_TOKEN: - raise usb_exceptions.AdbProtocolError('Bad AUTH response: %s', msg) + raise usb_exceptions.AdbProtocolError('Bad AUTH response: %s' % msg) signed_token = rsa_key.sign(msg.data) adb_transport.write_message( adb_message.AdbMessage( - command='AUTH', arg0=cls.AUTH_SIGNATURE, arg1=0, - data=signed_token), - timeout) + command='AUTH', + arg0=cls.AUTH_SIGNATURE, + arg1=0, + data=signed_token), timeout) msg = adb_transport.read_until(('AUTH', 'CNXN'), timeout) if msg.command == 'CNXN': @@ -896,16 +904,20 @@ def connect(cls, transport, rsa_keys=None, timeout_ms=1000, # None of the keys worked, so send a public key. 
adb_transport.write_message( adb_message.AdbMessage( - command='AUTH', arg0=cls.AUTH_RSAPUBLICKEY, arg1=0, - data=rsa_keys[0].get_public_key() + '\0'), - timeout) + command='AUTH', + arg0=cls.AUTH_RSAPUBLICKEY, + arg1=0, + data=rsa_keys[0].get_public_key() + '\0'), timeout) try: msg = adb_transport.read_until( ('CNXN',), timeouts.PolledTimeout.from_millis(auth_timeout_ms)) except usb_exceptions.UsbReadFailedError as exception: if exception.is_timeout(): - exceptions.reraise(usb_exceptions.DeviceAuthError, - 'Accept auth key on device, then retry.') + six.reraise( + usb_exceptions.DeviceAuthError, + usb_exceptions.DeviceAuthError( + message='Accept auth key on device, then retry.'), + sys.exc_info()[2]) raise # The read didn't time-out, so we got a CNXN response. diff --git a/openhtf/plugs/usb/fastboot_device.py b/openhtf/plugs/usb/fastboot_device.py index c9653be76..47251bf60 100644 --- a/openhtf/plugs/usb/fastboot_device.py +++ b/openhtf/plugs/usb/fastboot_device.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- - """Fastboot device.""" import logging @@ -23,13 +21,26 @@ from openhtf.util import timeouts # From fastboot.c -VENDORS = {0x18D1, 0x0451, 0x0502, 0x0FCE, 0x05C6, 0x22B8, 0x0955, - 0x413C, 0x2314, 0x0BB4, 0x8087} +VENDORS = frozenset([ + 0x18D1, + 0x0451, + 0x0502, + 0x0FCE, + 0x05C6, + 0x22B8, + 0x0955, + 0x413C, + 0x2314, + 0x0BB4, + 0x8087, +]) CLASS = 0xFF SUBCLASS = 0x42 PROTOCOL = 0x03 _LOG = logging.getLogger(__name__) + + class FastbootDevice(object): """Libusb fastboot wrapper with retries.""" @@ -49,12 +60,14 @@ def set_boot_config(self, name, value): def get_boot_config(self, name, info_cb=None): """Get bootconfig, either as full dict or specific value for key.""" result = {} + def default_info_cb(msg): """Default Info CB.""" if not msg.message: return key, value = msg.message.split(':', 1) result[key.strip()] = value.strip() + info_cb = info_cb or default_info_cb final_result = self.oem('bootconfig %s' % name, info_cb=info_cb) # Return INFO messages before the final OKAY message. @@ -77,6 +90,7 @@ def __getattr__(self, attr): # pylint: disable=invalid-name Args: attr: Attribute to get. + Returns: Either the attribute from the device or a retrying function-wrapper if attr is a method on the device. @@ -86,12 +100,14 @@ def __getattr__(self, attr): # pylint: disable=invalid-name val = getattr(self._protocol, attr) if callable(val): + def _retry_wrapper(*args, **kwargs): """Wrap the retry function.""" result = _retry_usb_function(self._num_retries, val, *args, **kwargs) - _LOG.debug('LIBUSB FASTBOOT: %s(*%s, **%s) -> %s', - attr, args, kwargs, result) + _LOG.debug('LIBUSB FASTBOOT: %s(*%s, **%s) -> %s', attr, args, kwargs, + result) return result + return _retry_wrapper return val @@ -101,14 +117,15 @@ def connect(cls, usb_handle, **kwargs): Args: usb_handle: UsbHandle instance to use for communication to the device. - **kwargs: Additional args to pass to the class constructor (currently - only num_retries). 
+ **kwargs: Additional args to pass to the class constructor (currently only + num_retries). Returns: An instance of this class if the device connected successfully. """ return cls(fastboot_protocol.FastbootCommands(usb_handle), **kwargs) + def _retry_usb_function(count, func, *args, **kwargs): """Helper function to retry USB.""" helper = timeouts.RetryHelper(count) diff --git a/openhtf/plugs/usb/fastboot_protocol.py b/openhtf/plugs/usb/fastboot_protocol.py index b2d958678..547456914 100644 --- a/openhtf/plugs/usb/fastboot_protocol.py +++ b/openhtf/plugs/usb/fastboot_protocol.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """A libusb1-based fastboot implementation.""" import binascii @@ -25,10 +23,9 @@ from openhtf.util import argv import six - FASTBOOT_DOWNLOAD_CHUNK_SIZE_KB = 1024 -ARG_PARSER = argv.ModuleParser() +ARG_PARSER = argv.module_parser() ARG_PARSER.add_argument( '--fastboot_download_chunk_size_kb', default=FASTBOOT_DOWNLOAD_CHUNK_SIZE_KB, @@ -40,8 +37,14 @@ _LOG = logging.getLogger(__name__) DEFAULT_MESSAGE_CALLBACK = lambda m: _LOG.info('Got %s from device', m) -FastbootMessage = collections.namedtuple( # pylint: disable=invalid-name - 'FastbootMessage', ['message', 'header']) + + +class FastbootMessage( + collections.namedtuple('FastbootMessage', [ + 'message', + 'header', + ])): + pass class FastbootProtocol(object): @@ -72,8 +75,9 @@ def send_command(self, command, arg=None): command = '%s:%s' % (command, arg) self._write(six.StringIO(command), len(command)) - def handle_simple_responses( - self, timeout_ms=None, info_cb=DEFAULT_MESSAGE_CALLBACK): + def handle_simple_responses(self, + timeout_ms=None, + info_cb=DEFAULT_MESSAGE_CALLBACK): """Accepts normal responses from the device. 
Args: @@ -86,9 +90,12 @@ def handle_simple_responses( return self._accept_responses('OKAY', info_cb, timeout_ms=timeout_ms) # pylint: disable=too-many-arguments - def handle_data_sending(self, source_file, source_len, + def handle_data_sending(self, + source_file, + source_len, info_cb=DEFAULT_MESSAGE_CALLBACK, - progress_callback=None, timeout_ms=None): + progress_callback=None, + timeout_ms=None): """Handles the protocol for sending data to the device. Arguments: @@ -101,9 +108,10 @@ def handle_data_sending(self, source_file, source_len, Raises: FastbootTransferError: When fastboot can't handle this amount of data. - FastbootStateMismatch: Fastboot responded with the wrong packet type. - FastbootRemoteFailure: Fastboot reported failure. - FastbootInvalidResponse: Fastboot responded with an unknown packet type. + FastbootStateMismatchError: Fastboot responded with the wrong packet type. + FastbootRemoteFailureError: Fastboot reported failure. + FastbootInvalidResponseError: Fastboot responded with an unknown packet + type. Returns: OKAY packet's message. @@ -115,8 +123,8 @@ def handle_data_sending(self, source_file, source_len, accepted_size, = struct.unpack('>I', accepted_size) if accepted_size != source_len: raise usb_exceptions.FastbootTransferError( - 'Device refused to download %s bytes of data (accepts %s bytes)', - source_len, accepted_size) + 'Device refused to download %s bytes of data (accepts %s bytes)' % + (source_len, accepted_size)) self._write(source_file, accepted_size, progress_callback) return self._accept_responses('OKAY', info_cb, timeout_ms=timeout_ms) @@ -131,9 +139,10 @@ def _accept_responses(self, expected_header, info_cb, timeout_ms=None): timeout_ms: Timeout in milliseconds to wait for each response. Raises: - FastbootStateMismatch: Fastboot responded with the wrong packet type. - FastbootRemoteFailure: Fastboot reported failure. - FastbootInvalidResponse: Fastboot responded with an unknown packet type. 
+ FastbootStateMismatchError: Fastboot responded with the wrong packet type. + FastbootRemoteFailureError: Fastboot reported failure. + FastbootInvalidResponseError: Fastboot responded with an unknown packet + type. Returns: OKAY packet's message. @@ -147,17 +156,17 @@ def _accept_responses(self, expected_header, info_cb, timeout_ms=None): info_cb(FastbootMessage(remaining, header)) elif header in self.FINAL_HEADERS: if header != expected_header: - raise usb_exceptions.FastbootStateMismatch( - 'Expected %s, got %s', expected_header, header) + raise usb_exceptions.FastbootStateMismatchError( + 'Expected %s, got %s' % (expected_header, header)) if header == 'OKAY': info_cb(FastbootMessage(remaining, header)) return remaining elif header == 'FAIL': info_cb(FastbootMessage(remaining, header)) - raise usb_exceptions.FastbootRemoteFailure('FAIL: %s', remaining) + raise usb_exceptions.FastbootRemoteFailureError('FAIL: %s' % remaining) else: - raise usb_exceptions.FastbootInvalidResponse( - 'Got unknown header %s and response %s', header, remaining) + raise usb_exceptions.FastbootInvalidResponseError( + 'Got unknown header %s and response %s' % (header, remaining)) def _handle_progress(self, total, progress_callback): # pylint: disable=no-self-use """Calls the callback with the current progress and total .""" @@ -174,7 +183,7 @@ def _handle_progress(self, total, progress_callback): # pylint: disable=no-self def _write(self, data, length, progress_callback=None): """Sends the data to the device, tracking progress with the callback.""" if progress_callback: - progress = self._handle_progress(length, progress_callback) + progress = self._handle_progress(length, progress_callback) # pylint: disable=assignment-from-no-return six.next(progress) while length: tmp = data.read(FASTBOOT_DOWNLOAD_CHUNK_SIZE_KB * 1024) @@ -213,8 +222,12 @@ def _simple_command(self, command, arg=None, **kwargs): return self._protocol.handle_simple_responses(**kwargs) # pylint: 
disable=too-many-arguments - def flash_from_file(self, partition, source_file, source_len=0, - info_cb=DEFAULT_MESSAGE_CALLBACK, progress_callback=None, + def flash_from_file(self, + partition, + source_file, + source_len=0, + info_cb=DEFAULT_MESSAGE_CALLBACK, + progress_callback=None, timeout_ms=None): """Flashes a partition from the file on disk. @@ -233,28 +246,32 @@ def flash_from_file(self, partition, source_file, source_len=0, # Fall back to stat. source_len = os.stat(source_file).st_size download_response = self.download( - source_file, source_len=source_len, info_cb=info_cb, + source_file, + source_len=source_len, + info_cb=info_cb, progress_callback=progress_callback) - flash_response = self.flash(partition, info_cb=info_cb, - timeout_ms=timeout_ms) + flash_response = self.flash( + partition, info_cb=info_cb, timeout_ms=timeout_ms) return download_response + flash_response # pylint: enable=too-many-arguments - def download(self, source_file, source_len=0, - info_cb=DEFAULT_MESSAGE_CALLBACK, progress_callback=None): + def download(self, + source_file, + source_len=0, + info_cb=DEFAULT_MESSAGE_CALLBACK, + progress_callback=None): """Downloads a file to the device. Args: source_file: A filename or file-like object to download to the device. source_len: Optional length of source_file. If source_file is a file-like - object and source_len is not provided, source_file is read into - memory. + object and source_len is not provided, source_file is read into memory. info_cb: Optional callback accepting FastbootMessage for text sent from - the bootloader. + the bootloader. progress_callback: Optional callback called with the percent of the - source_file downloaded. Note, this doesn't include progress of the - actual flashing. + source_file downloaded. Note, this doesn't include progress of the + actual flashing. Returns: Response to a download request, normally nothing. 
@@ -284,8 +301,8 @@ def flash(self, partition, timeout_ms=None, info_cb=DEFAULT_MESSAGE_CALLBACK): Returns: Response to a download request, normally nothing. """ - return self._simple_command('flash', arg=partition, info_cb=info_cb, - timeout_ms=timeout_ms) + return self._simple_command( + 'flash', arg=partition, info_cb=info_cb, timeout_ms=timeout_ms) def erase(self, partition, timeout_ms=None): """Erases the given partition.""" @@ -297,6 +314,7 @@ def get_var(self, var, info_cb=DEFAULT_MESSAGE_CALLBACK): Args: var: A variable the bootloader tracks, such as version. info_cb: See Download. Usually no messages. + Returns: Value of var according to the current bootloader. """ @@ -309,6 +327,7 @@ def oem(self, command, timeout_ms=None, info_cb=DEFAULT_MESSAGE_CALLBACK): command: The command to execute, such as 'poweroff' or 'bootconfig read'. timeout_ms: Optional timeout in milliseconds to wait for a response. info_cb: See Download. Messages vary based on command. + Returns: The final response from the device. """ @@ -323,14 +342,15 @@ def reboot(self, target_mode=None, timeout_ms=None): """Reboots the device. Args: - target_mode: Normal reboot when unspecified (or None). Can specify - other target modes, such as 'recovery' or 'bootloader'. + target_mode: Normal reboot when unspecified (or None). Can specify other + target modes, such as 'recovery' or 'bootloader'. timeout_ms: Optional timeout in milliseconds to wait for a response. + Returns: Usually the empty string. Depends on the bootloader and the target_mode. 
""" - return self._simple_command('reboot', arg=target_mode, - timeout_ms=timeout_ms) + return self._simple_command( + 'reboot', arg=target_mode, timeout_ms=timeout_ms) def reboot_bootloader(self, timeout_ms=None): """Reboots into the bootloader, usually equiv to Reboot('bootloader').""" diff --git a/openhtf/plugs/usb/filesync_service.py b/openhtf/plugs/usb/filesync_service.py index b3fd3ca96..5d0b6a97e 100644 --- a/openhtf/plugs/usb/filesync_service.py +++ b/openhtf/plugs/usb/filesync_service.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """Implementation of the ADB SYNC protocol, used for push/pull commands. This protocol is build on top of ADB streams, see adb_protocol.py for a high @@ -105,6 +103,7 @@ u32 command = 'DONE' == 0x444F4E45 u32 size = 0 """ +# pytype: skip-file import collections import stat @@ -123,26 +122,34 @@ MAX_PUSH_DATA_BYTES = 64 * 1024 -DeviceFileStat = collections.namedtuple('DeviceFileStat', [ - 'filename', 'mode', 'size', 'mtime']) +class DeviceFileStat( + collections.namedtuple('DeviceFileStat', [ + 'filename', + 'mode', + 'size', + 'mtime', + ])): + pass def _make_message_type(name, attributes, has_data=True): """Make a message type for the AdbTransport subclasses.""" - def assert_command_is(self, command): # pylint: disable=invalid-name + def assert_command_is(self, command): """Assert that a message's command matches the given command.""" if self.command != command: - raise usb_exceptions.AdbProtocolError( - 'Expected %s command, received %s', command, self) + raise usb_exceptions.AdbProtocolError('Expected %s command, received %s' % + (command, self)) - return type(name, (collections.namedtuple(name, attributes),), - { - 'assert_command_is': assert_command_is, - 'has_data': has_data, - # Struct format on the wire has an unsigned int for each attr. 
- 'struct_format': '<%sI' % len(attributes.split()), - }) + return type( + name, + (collections.namedtuple(name, attributes),), + { + 'assert_command_is': assert_command_is, + 'has_data': has_data, + # Struct format on the wire has an unsigned int for each attr. + 'struct_format': '<%sI' % len(attributes.split()), + }) class FilesyncService(object): @@ -163,7 +170,7 @@ class FilesyncService(object): def __init__(self, stream): self.stream = stream - def __del__(self): # pylint: disable=invalid-name + def __del__(self): self.close() def close(self): @@ -191,9 +198,9 @@ def list(self, path, timeout=None): """ transport = DentFilesyncTransport(self.stream) transport.write_data('LIST', path, timeout) - return (DeviceFileStat(dent_msg.name, dent_msg.mode, - dent_msg.size, dent_msg.time) for dent_msg in - transport.read_until_done('DENT', timeout)) + return (DeviceFileStat(dent_msg.name, dent_msg.mode, dent_msg.size, + dent_msg.time) + for dent_msg in transport.read_until_done('DENT', timeout)) def recv(self, filename, dest_file, timeout=None): """Retrieve a file from the device into the file-like dest_file.""" @@ -227,7 +234,11 @@ def _check_for_fail_message(self, transport, exc_info, timeout): # pylint: disa raise_with_traceback(exc_info[0](exc_info[1]), traceback=exc_info[2]) # pylint: disable=too-many-arguments - def send(self, src_file, filename, st_mode=DEFAULT_PUSH_MODE, mtime=None, + def send(self, + src_file, + filename, + st_mode=DEFAULT_PUSH_MODE, + mtime=None, timeout=None): """Push a file-like object to the device. @@ -328,7 +339,15 @@ class AbstractFilesyncTransport(object): receive over this transport. 
""" CMD_TO_WIRE, WIRE_TO_CMD = adb_message.make_wire_commands( - 'STAT', 'LIST', 'SEND', 'RECV', 'DENT', 'DONE', 'DATA', 'OKAY', 'FAIL', + 'STAT', + 'LIST', + 'SEND', + 'RECV', + 'DENT', + 'DONE', + 'DATA', + 'OKAY', + 'FAIL', ) def __init__(self, stream): @@ -354,9 +373,9 @@ def __init__(self, stream): # pylint: disable=no-member def __str__(self): - return '<%s(%s) id(%x), Receives: %s>' % (type(self).__name__, self.stream, - id(self), - self.RECV_MSG_TYPE.__name__) + return '<%s(%s) id(%x), Receives: %s>' % ( + type(self).__name__, self.stream, id(self), self.RECV_MSG_TYPE.__name__) + __repr__ = __str__ def write_data(self, command, data, timeout=None): @@ -381,8 +400,8 @@ def write_message(self, msg, timeout=None): data = msg[-1] replace_dict[msg._fields[-1]] = len(data) - self.stream.write(struct.pack(msg.struct_format, - *msg._replace(**replace_dict)), timeout) + self.stream.write( + struct.pack(msg.struct_format, *msg._replace(**replace_dict)), timeout) if msg.has_data: self.stream.write(data, timeout) @@ -435,12 +454,12 @@ def read_message(self, timeout=None): raw_message = struct.unpack(self.RECV_MSG_TYPE.struct_format, raw_data) except struct.error: raise usb_exceptions.AdbProtocolError( - '%s expected format "%s", got data %s', self, - self.RECV_MSG_TYPE.struct_format, raw_data) + '%s expected format "%s", got data %s' % + (self, self.RECV_MSG_TYPE.struct_format, raw_data)) if raw_message[0] not in self.WIRE_TO_CMD: - raise usb_exceptions.AdbProtocolError( - 'Unrecognized command id: %s', raw_message) + raise usb_exceptions.AdbProtocolError('Unrecognized command id: %s' % + raw_message) # Swap out the wire command with the string equivalent. 
raw_message = (self.WIRE_TO_CMD[raw_message[0]],) + raw_message[1:] @@ -453,11 +472,11 @@ def read_message(self, timeout=None): raw_message = raw_message[:-1] + (self.stream.read(data_len, timeout),) if raw_message[0] not in self.VALID_RESPONSES: - raise usb_exceptions.AdbProtocolError( - '%s not a valid response for %s', raw_message[0], self) + raise usb_exceptions.AdbProtocolError('%s not a valid response for %s' % + (raw_message[0], self)) if raw_message[0] == 'FAIL': - raise usb_exceptions.AdbRemoteError( - 'Remote ADB failure: %s', raw_message) + raise usb_exceptions.AdbRemoteError('Remote ADB failure: %s' % + (raw_message,)) return self.RECV_MSG_TYPE(*raw_message) @@ -474,11 +493,10 @@ class FilesyncMessageTypes(object): we do that. """ # pylint: disable=invalid-name - DoneMessage = _make_message_type('DoneMessage', - 'command mtime', - has_data=False) - StatMessage = _make_message_type('StatMessage', 'command mode size time', - has_data=False) + DoneMessage = _make_message_type( + 'DoneMessage', 'command mtime', has_data=False) + StatMessage = _make_message_type( + 'StatMessage', 'command mode size time', has_data=False) DentMessage = _make_message_type('DentMessage', 'command mode size time name') DataMessage = _make_message_type('DataMessage', 'command data') # pylint: enable=invalid-name diff --git a/openhtf/plugs/usb/local_usb.py b/openhtf/plugs/usb/local_usb.py index 9abcfb701..49cfa89b0 100644 --- a/openhtf/plugs/usb/local_usb.py +++ b/openhtf/plugs/usb/local_usb.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """UsbHandle implementation using libusb to communicate with local devices. 
This implementation of UsbHandle uses python libusb1 bindings to communicate @@ -24,17 +22,20 @@ import logging +from openhtf.plugs.usb import usb_exceptions +from openhtf.plugs.usb import usb_handle +import six + try: - import libusb1 - import usb1 + # pylint: disable=g-import-not-at-top + import libusb1 # pytype: disable=import-error + import usb1 # pytype: disable=import-error + # pylint: enable=g-import-not-at-top except ImportError: logging.error('Failed to import libusb, did you pip install ' 'openhtf[usb_plugs]?') raise -from openhtf.plugs.usb import usb_exceptions -from openhtf.plugs.usb import usb_handle -import six _LOG = logging.getLogger(__name__) @@ -57,8 +58,10 @@ def __init__(self, device, setting, name=None, default_timeout_ms=None): interface supported. IOError: If the device has been disconnected. """ - super(LibUsbHandle, self).__init__(device.getSerialNumber(), name=name, - default_timeout_ms=default_timeout_ms) + super(LibUsbHandle, self).__init__( + device.getSerialNumber(), + name=name, + default_timeout_ms=default_timeout_ms) self._setting = setting self._device = device @@ -111,9 +114,8 @@ def is_closed(self): @staticmethod def _device_to_sysfs_path(device): """Convert device to corresponding sysfs path.""" - return '%s-%s' % ( - device.getBusNumber(), - '.'.join([str(item) for item in device.GetPortNumberList()])) + return '%s-%s' % (device.getBusNumber(), '.'.join( + [str(item) for item in device.GetPortNumberList()])) @property def port_path(self): @@ -124,7 +126,8 @@ def port_path(self): def read(self, length, timeout_ms=None): try: return self._handle.bulkRead( - self._read_endpoint, length, + self._read_endpoint, + length, timeout=self._timeout_or_default(timeout_ms)) except libusb1.USBError as exception: raise usb_exceptions.UsbReadFailedError( @@ -135,7 +138,8 @@ def read(self, length, timeout_ms=None): def write(self, data, timeout_ms=None): try: return self._handle.bulkWrite( - self._write_endpoint, data, + self._write_endpoint, + 
data, timeout=self._timeout_or_default(timeout_ms)) except libusb1.USBError as exception: raise usb_exceptions.UsbWriteFailedError( @@ -163,8 +167,8 @@ def open(cls, **kwargs): handle = six.next(handle_iter) except StopIteration: # No matching interface, raise. - raise usb_exceptions.DeviceNotFoundError( - 'Open failed with args: %s', kwargs) + raise usb_exceptions.DeviceNotFoundError('Open failed with args: %s' % + kwargs) try: multiple_handle = six.next(handle_iter) @@ -179,8 +183,13 @@ def open(cls, **kwargs): # pylint: disable=too-many-arguments @classmethod - def iter_open(cls, name=None, interface_class=None, interface_subclass=None, - interface_protocol=None, serial_number=None, port_path=None, + def iter_open(cls, + name=None, + interface_class=None, + interface_subclass=None, + interface_protocol=None, + serial_number=None, + port_path=None, default_timeout_ms=None): """Find and yield locally connected devices that match. @@ -194,8 +203,8 @@ def iter_open(cls, name=None, interface_class=None, interface_subclass=None, interface_protocol: USB interface_protocol to match. serial_number: USB serial_number to match. port_path: USB Port path to match, like X-X.X.X - default_timeout_ms: Default timeout in milliseconds of reads/writes on - the handles yielded. + default_timeout_ms: Default timeout in milliseconds of reads/writes on the + handles yielded. Yields: UsbHandle instances that match any non-None args given. 
@@ -233,10 +242,11 @@ def iter_open(cls, name=None, interface_class=None, interface_subclass=None, setting.getProtocol() != interface_protocol): continue - yield cls(device, setting, name=name, - default_timeout_ms=default_timeout_ms) + yield cls( + device, setting, name=name, default_timeout_ms=default_timeout_ms) except libusb1.USBError as exception: if (exception.value != libusb1.libusb_error.forward_dict['LIBUSB_ERROR_ACCESS']): raise + # pylint: disable=too-many-arguments diff --git a/openhtf/plugs/usb/shell_service.py b/openhtf/plugs/usb/shell_service.py index 7806a0258..809262f5e 100644 --- a/openhtf/plugs/usb/shell_service.py +++ b/openhtf/plugs/usb/shell_service.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - -"""Some handy interfaces to the ADB :shell service. +r"""Some handy interfaces to the ADB :shell service. The :shell service is pretty straightforward, you send 'shell:command' and the device runs /bin/sh -c 'command'. The ADB daemon on the device sets up a @@ -94,7 +92,7 @@ class AsyncCommandHandle(object): be False, otherwise it will be True. """ - def __init__(self, stream, stdin, stdout, timeout, is_raw): #pylint: disable=too-many-arguments + def __init__(self, stream, stdin, stdout, timeout, is_raw): # pylint: disable=too-many-arguments """Create a handle to use for interfacing with an async_command. Args: @@ -106,23 +104,23 @@ def __init__(self, stream, stdin, stdout, timeout, is_raw): #pylint: disable=to timeout: timeouts.PolledTimeout to use for the command. is_raw: If True, we'll do reads from stdin, otherwise we do readlines instead to play nicer with potential interactive uses (read doesn't - return until EOF, but interactively you want to send each line and - then see the response). stdout is treated the same in either case, - read is used - AdbStreams don't support readline. 
+ return until EOF, but interactively you want to send each line and then + see the response). stdout is treated the same in either case, read is + used - AdbStreams don't support readline. """ self.stream = stream self.stdin = stdin self.stdout = stdout or six.StringIO() self.force_closed_or_timeout = False - self.reader_thread = threading.Thread(target=self._reader_thread_proc, - args=(timeout,)) + self.reader_thread = threading.Thread( + target=self._reader_thread_proc, args=(timeout,)) self.reader_thread.daemon = True self.reader_thread.start() if stdin: - self.writer_thread = threading.Thread(target=self._writer_thread_proc, - args=(is_raw,)) + self.writer_thread = threading.Thread( + target=self._writer_thread_proc, args=(is_raw,)) self.writer_thread.daemon = True self.writer_thread.start() @@ -148,10 +146,10 @@ def _reader_thread_proc(self, timeout): if self.stdout is not None: self.stdout.write(data) - def __enter__(self): # pylint: disable=invalid-name + def __enter__(self): return self - def __exit__(self, exc_type, exc_value, exc_tb): # pylint: disable=invalid-name + def __exit__(self, exc_type, exc_value, exc_tb): if exc_type: return False self.wait() @@ -180,8 +178,8 @@ def wait(self, timeout_ms=None): return value explicitly for None, as the output may be ''. 
""" closed = timeouts.loop_until_timeout_or_true( - timeouts.PolledTimeout.from_millis(timeout_ms), - self.stream.is_closed, .1) + timeouts.PolledTimeout.from_millis(timeout_ms), self.stream.is_closed, + .1) if closed: if hasattr(self.stdout, 'getvalue'): return self.stdout.getvalue() @@ -231,7 +229,11 @@ def streaming_command(self, command, raw=False, timeout_ms=None): return self.adb_connection.streaming_command('shell', command, timeout_ms) # pylint: disable=too-many-arguments - def async_command(self, command, stdin=None, stdout=None, raw=False, + def async_command(self, + command, + stdin=None, + stdout=None, + raw=False, timeout_ms=None): """Run the given command on the device asynchronously. @@ -241,8 +243,8 @@ def async_command(self, command, stdin=None, stdout=None, raw=False, could use sys.stdin and sys.stdout to emulate the 'adb shell' commandline. Args: - command: The command to run, will be run with /bin/sh -c 'command' on - the device. + command: The command to run, will be run with /bin/sh -c 'command' on the + device. stdin: File-like object to read from to pipe to the command's stdin. Can be None, in which case nothing will be written to the command's stdin. stdout: File-like object to write the command's output to. Can be None, @@ -265,12 +267,13 @@ def async_command(self, command, stdin=None, stdout=None, raw=False, stream = self.adb_connection.open_stream('shell:%s' % command, timeout) if not stream: raise usb_exceptions.AdbStreamUnavailableError( - '%s does not support service: shell', self) + '%s does not support service: shell' % self) if raw and stdin is not None: # Short delay to make sure the ioctl to set raw mode happens before we do # any writes to the stream, if we don't do this bad things happen... 
time.sleep(.1) return AsyncCommandHandle(stream, stdin, stdout, timeout, raw) + # pylint: enable=too-many-arguments @classmethod diff --git a/openhtf/plugs/usb/usb_exceptions.py b/openhtf/plugs/usb/usb_exceptions.py index 5619a04c5..b16c93811 100644 --- a/openhtf/plugs/usb/usb_exceptions.py +++ b/openhtf/plugs/usb/usb_exceptions.py @@ -11,14 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """Common exceptions for USB, ADB and Fastboot.""" import logging try: - import libusb1 + # pylint: disable=g-import-not-at-top + import libusb1 # pytype: disable=import-error + # pylint: enable=g-import-not-at-top except ImportError: logging.error('Failed to import libusb, did you pip install ' 'openhtf[usb_plugs]?') @@ -36,22 +36,6 @@ class CommonUsbError(Exception): human-readable, but keeps the arguments in case other code try-excepts it. """ - def __init__(self, message=None, *args): - if message is not None: - if '%' in message: - try: - message %= args - except TypeError: - # This is a fairly obscure failure, so we intercept it and emit a - # more useful error message. - _LOG.error('USB Exceptions expect a format-string, do not include ' - 'percent symbols to disable this functionality: %s', - message) - raise - super(CommonUsbError, self).__init__(message, *args) - else: - super(CommonUsbError, self).__init__(*args) - # USB exceptions, these are not specific to any particular protocol. 
class LibusbWrappingError(CommonUsbError): @@ -114,7 +98,7 @@ class DeviceAuthError(CommonUsbError): """Device authentication failed.""" -class AdbOperationException(Exception): +class AdbOperationExceptionError(Exception): """Failed to communicate over adb with device after multiple retries.""" @@ -151,13 +135,13 @@ class FastbootTransferError(CommonUsbError): """Transfer error.""" -class FastbootRemoteFailure(CommonUsbError): +class FastbootRemoteFailureError(CommonUsbError): """Remote error.""" -class FastbootStateMismatch(CommonUsbError): +class FastbootStateMismatchError(CommonUsbError): """Fastboot and uboot's state machines are arguing. You Lose.""" -class FastbootInvalidResponse(CommonUsbError): +class FastbootInvalidResponseError(CommonUsbError): """Fastboot responded with a header we didn't expect.""" diff --git a/openhtf/plugs/usb/usb_handle.py b/openhtf/plugs/usb/usb_handle.py index 064ccdf91..c4d80f115 100644 --- a/openhtf/plugs/usb/usb_handle.py +++ b/openhtf/plugs/usb/usb_handle.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """Base interface for communicating with USB devices. This module provides the base classes required to support interfacing with USB @@ -21,11 +19,11 @@ A UsbHandle object represents a single USB Interface, *not* an entire device. """ -from future.utils import with_metaclass import abc import functools import logging +from future.utils import with_metaclass from openhtf.plugs.usb import usb_exceptions DEFAULT_TIMEOUT_MS = 5000 @@ -33,7 +31,7 @@ _LOG = logging.getLogger(__name__) -def requires_open_handle(method): # pylint: disable=invalid-name +def requires_open_handle(method): """Decorator to ensure a handle is open for certain methods. 
Subclasses should decorate their Read() and Write() with this rather than @@ -47,15 +45,18 @@ def requires_open_handle(method): # pylint: disable=invalid-name HandleClosedError: If this handle has been closed. Returns: - A wrapper around method that ensures the handle is open before calling through + A wrapper around method that ensures the handle is open before calling + through to the wrapped method. """ + @functools.wraps(method) def wrapper_requiring_open_handle(self, *args, **kwargs): """The wrapper to be returned.""" if self.is_closed(): raise usb_exceptions.HandleClosedError() return method(self, *args, **kwargs) + return wrapper_requiring_open_handle @@ -117,19 +118,20 @@ def __init__(self, serial_number, name=None, default_timeout_ms=None): self.name = name or '' self._default_timeout_ms = default_timeout_ms or DEFAULT_TIMEOUT_MS - def __del__(self): # pylint: disable=invalid-name + def __del__(self): if not self.is_closed(): _LOG.error('!!!!!USB!!!!! %s not closed!', type(self).__name__) def __str__(self): return '<%s: (%s %s)>' % (type(self).__name__, self.name, self.serial_number) + __repr__ = __str__ def _timeout_or_default(self, timeout_ms): """Specify a timeout or take the default.""" - return int(timeout_ms if timeout_ms is not None - else self._default_timeout_ms) + return int( + timeout_ms if timeout_ms is not None else self._default_timeout_ms) def flush_buffers(self): """Default implementation, calls Read() until it blocks.""" diff --git a/openhtf/plugs/usb/usb_handle_stub.py b/openhtf/plugs/usb/usb_handle_stub.py index 5e5b0b6cf..76c13bd4e 100644 --- a/openhtf/plugs/usb/usb_handle_stub.py +++ b/openhtf/plugs/usb/usb_handle_stub.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- - """Stub USB handle implementation for testing.""" import binascii @@ -26,8 +24,8 @@ class StubUsbHandle(usb_handle.UsbHandle): PRINTABLE_DATA = set(string.printable) - set(string.whitespace) def __init__(self, ignore_writes=False): - super(StubUsbHandle, self).__init__('StubSerial', 'StubHandle', - default_timeout_ms=0) + super(StubUsbHandle, self).__init__( + 'StubSerial', 'StubHandle', default_timeout_ms=0) self.expected_write_data = None if ignore_writes else [] self.expected_read_data = [] self.closed = False @@ -45,9 +43,9 @@ def write(self, data, dummy=None): expected_data = self.expected_write_data.pop(0) if expected_data != data: - raise ValueError('Expected %s, got %s (%s)' % ( - self._dotify(expected_data), binascii.hexlify(data), - self._dotify(data))) + raise ValueError('Expected %s, got %s (%s)' % + (self._dotify(expected_data), binascii.hexlify(data), + self._dotify(data))) def read(self, length, dummy=None): """Stub Read method.""" @@ -55,8 +53,8 @@ def read(self, length, dummy=None): data = self.expected_read_data.pop(0) if length < len(data): raise ValueError( - 'Overflow packet length. Read %d bytes, got %d bytes: %s', - length, len(data), self._dotify(data)) + 'Overflow packet length. Read %d bytes, got %d bytes: %s' % + (length, len(data), self._dotify(data))) return data def close(self): diff --git a/openhtf/plugs/user_input.py b/openhtf/plugs/user_input.py index a2f8de1cd..905d8b9a2 100644 --- a/openhtf/plugs/user_input.py +++ b/openhtf/plugs/user_input.py @@ -11,17 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """User input module for OpenHTF. Provides a plug which can be used to prompt the user for input. The prompt can be displayed in the console, the OpenHTF web GUI, and custom frontends. 
""" -from __future__ import print_function - -import collections import functools import logging import os @@ -29,10 +24,13 @@ import select import sys import threading +from typing import Any, Callable, Dict, Optional, Text, Tuple, Union import uuid -from openhtf import PhaseOptions +import attr +import openhtf from openhtf import plugs +from openhtf.core import base_plugs from openhtf.util import console_output from six.moves import input @@ -56,7 +54,12 @@ class PromptUnansweredError(Exception): """Raised when a prompt times out or otherwise comes back unanswered.""" -Prompt = collections.namedtuple('Prompt', 'id message text_input image_url') +@attr.s(slots=True, frozen=True) +class Prompt(object): + id = attr.ib(type=Text) + message = attr.ib(type=Text) + text_input = attr.ib(type=bool) + image_url = attr.ib(type=Optional[Text], default=None) class ConsolePrompt(threading.Thread): @@ -65,7 +68,10 @@ class ConsolePrompt(threading.Thread): This should not be used for processes that run in the background. """ - def __init__(self, message, callback, color=''): + def __init__(self, + message: Text, + callback: Callable[[Text], None], + color: Text = ''): """Initializes a ConsolePrompt. Args: @@ -81,29 +87,27 @@ def __init__(self, message, callback, color=''): self._stop_event = threading.Event() self._answered = False - def Stop(self): + def stop(self) -> None: """Mark this ConsolePrompt as stopped.""" self._stop_event.set() if not self._answered: - console_output.cli_print(os.linesep, color=self._color, - end='', logger=None) + console_output.cli_print( + os.linesep, color=self._color, end='', logger=None) _LOG.debug('Stopping ConsolePrompt--prompt was answered from elsewhere.') - def run(self): + def run(self) -> None: """Main logic for this thread to execute.""" if platform.system() == 'Windows': # Windows doesn't support file-like objects for select(), so fall back # to raw_input(). 
- response = input(''.join((self._message, - os.linesep, - PROMPT))) + response = input(''.join((self._message, os.linesep, PROMPT))) self._answered = True self._callback(response) return # First, display the prompt to the console. - console_output.cli_print(self._message, color=self._color, - end=os.linesep, logger=None) + console_output.cli_print( + self._message, color=self._color, end=os.linesep, logger=None) console_output.cli_print(PROMPT, color=self._color, end='', logger=None) sys.stdout.flush() @@ -129,45 +133,51 @@ def run(self): return -class UserInput(plugs.FrontendAwareBasePlug): +class UserInput(base_plugs.FrontendAwareBasePlug): """Get user input from inside test phases. Attributes: last_response: None, or a pair of (prompt_id, response) indicating the last - user response that was received by the plug. + user response that was received by the plug. """ def __init__(self): super(UserInput, self).__init__() - self.last_response = None - self._prompt = None - self._console_prompt = None - self._response = None + self.last_response = None # type: Optional[Tuple[Text, Text]] + self._prompt = None # type: Optional[Prompt] + self._console_prompt = None # type: Optional[ConsolePrompt] + self._response = None # type: Optional[Text] self._cond = threading.Condition(threading.RLock()) - def _asdict(self): + def _asdict(self) -> Optional[Dict[Text, Any]]: """Return a dictionary representation of the current prompt.""" with self._cond: if self._prompt is None: - return - return {'id': self._prompt.id, - 'message': self._prompt.message, - 'text-input': self._prompt.text_input, - 'image-url': self._prompt.image_url} - - def tearDown(self): + return None + return { + 'id': self._prompt.id, + 'message': self._prompt.message, + 'text-input': self._prompt.text_input + } + + def tearDown(self) -> None: self.remove_prompt() - def remove_prompt(self): + def remove_prompt(self) -> None: """Remove the prompt.""" with self._cond: self._prompt = None if self._console_prompt: 
- self._console_prompt.Stop() + self._console_prompt.stop() self._console_prompt = None self.notify_update() - def prompt(self, message, text_input=False, timeout_s=None, cli_color='', image_url = None): + def prompt(self, + message: Text, + text_input: bool = False, + timeout_s: Union[int, float, None] = None, + cli_color: Text = '', + image_url: Optional[Text] = None) -> Text: """Display a prompt and wait for a response. Args: @@ -175,6 +185,7 @@ def prompt(self, message, text_input=False, timeout_s=None, cli_color='', image_ text_input: A boolean indicating whether the user must respond with text. timeout_s: Seconds to wait before raising a PromptUnansweredError. cli_color: An ANSI color code, or the empty string. + image_url: Optional image URL to display or None. Returns: A string response, or the empty string if text_input was False. @@ -186,13 +197,18 @@ def prompt(self, message, text_input=False, timeout_s=None, cli_color='', image_ self.start_prompt(message, text_input, cli_color, image_url) return self.wait_for_prompt(timeout_s) - def start_prompt(self, message, text_input=False, cli_color='', image_url = None): + def start_prompt(self, + message: Text, + text_input: bool = False, + cli_color: Text = '', + image_url: Optional[Text] = None) -> Text: """Display a prompt. Args: message: A string to be presented to the user. text_input: A boolean indicating whether the user must respond with text. cli_color: An ANSI color code, or the empty string. + image_url: Optional image URL to display or None. Raises: MultiplePromptsError: There was already an existing prompt. @@ -202,14 +218,18 @@ def start_prompt(self, message, text_input=False, cli_color='', image_url = None """ with self._cond: if self._prompt: - raise MultiplePromptsError + raise MultiplePromptsError( + 'Multiple concurrent prompts are not supported.') prompt_id = uuid.uuid4().hex _LOG.debug('Displaying prompt (%s): "%s"%s', prompt_id, message, ', Expects text input.' 
if text_input else '') self._response = None self._prompt = Prompt( - id=prompt_id, message=message, text_input=text_input, image_url=image_url) + id=prompt_id, + message=message, + text_input=text_input, + image_url=image_url) if sys.stdin.isatty(): self._console_prompt = ConsolePrompt( message, functools.partial(self.respond, prompt_id), cli_color) @@ -218,7 +238,7 @@ def start_prompt(self, message, text_input=False, cli_color='', image_url = None self.notify_update() return prompt_id - def wait_for_prompt(self, timeout_s=None): + def wait_for_prompt(self, timeout_s: Union[int, float, None] = None) -> Text: """Wait for the user to respond to the current prompt. Args: @@ -240,7 +260,7 @@ def wait_for_prompt(self, timeout_s=None): raise PromptUnansweredError return self._response - def respond(self, prompt_id, response): + def respond(self, prompt_id: Text, response: Text) -> None: """Respond to the prompt with the given ID. If there is no active prompt or the given ID doesn't match the active @@ -249,24 +269,22 @@ def respond(self, prompt_id, response): Args: prompt_id: A string uniquely identifying the prompt. response: A string response to the given prompt. - - Returns: - True if the prompt with the given ID was active, otherwise False. """ _LOG.debug('Responding to prompt (%s): "%s"', prompt_id, response) with self._cond: if not (self._prompt and self._prompt.id == prompt_id): - return False + return self._response = response self.last_response = (prompt_id, response) self.remove_prompt() self._cond.notifyAll() - return True def prompt_for_test_start( - message='Enter a DUT ID in order to start the test.', timeout_s=60*60*24, - validator=lambda sn: sn, cli_color=''): + message: Text = 'Enter a DUT ID in order to start the test.', + timeout_s: Union[int, float, None] = 60 * 60 * 24, + validator: Callable[[Text], Text] = lambda sn: sn, + cli_color: Text = '') -> openhtf.PhaseDescriptor: """Returns an OpenHTF phase for use as a prompt-based start trigger. 
Args: @@ -276,9 +294,9 @@ def prompt_for_test_start( cli_color: An ANSI color code, or the empty string. """ - @PhaseOptions(timeout_s=timeout_s) + @openhtf.PhaseOptions(timeout_s=timeout_s) @plugs.plug(prompts=UserInput) - def trigger_phase(test, prompts): + def trigger_phase(test: openhtf.TestApi, prompts: UserInput) -> None: """Test start trigger that prompts the user for a DUT ID.""" dut_id = prompts.prompt( message, text_input=True, timeout_s=timeout_s, cli_color=cli_color) diff --git a/openhtf/util/__init__.py b/openhtf/util/__init__.py index d911cd0e9..c8ae629c5 100644 --- a/openhtf/util/__init__.py +++ b/openhtf/util/__init__.py @@ -11,22 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """One-off utilities.""" import logging import re import threading import time +import typing +from typing import Any, Callable, Dict, Iterator, Optional, Text, Tuple, Type, TypeVar, Union import weakref -import mutablerecords +import attr import six -def _log_every_n_to_logger(n, logger, level, message, *args): # pylint: disable=invalid-name +def _log_every_n_to_logger(n: int, logger: Optional[logging.Logger], level: int, + message: Text, *args: Any) -> Callable[[], bool]: """Logs the given message every n calls to a logger. Args: @@ -35,31 +36,36 @@ def _log_every_n_to_logger(n, logger, level, message, *args): # pylint: disable level: The logging level (e.g. logging.INFO). message: A message to log *args: Any format args for the message. + Returns: A method that logs and returns True every n calls. 
""" - logger = logger or logging.getLogger() - def _gen(): # pylint: disable=missing-docstring + logger = logger if logger else logging.getLogger() + + def _gen() -> Iterator[bool]: # pylint: disable=missing-docstring while True: for _ in range(n): yield False logger.log(level, message, *args) yield True + gen = _gen() return lambda: six.next(gen) -def log_every_n(n, level, message, *args): # pylint: disable=invalid-name +def log_every_n(n: int, level: int, message: Text, + *args: Any) -> Callable[[], bool]: """Logs a message every n calls. See _log_every_n_to_logger.""" return _log_every_n_to_logger(n, None, level, message, *args) -def time_millis(): # pylint: disable=invalid-name +def time_millis() -> int: """The time in milliseconds.""" return int(time.time() * 1000) -class NonLocalResult(mutablerecords.Record('NonLocal', [], {'result': None})): +@attr.s(slots=True) +class NonLocalResult(object): """Holds a single result as a nonlocal variable. Comparable to using Python 3's nonlocal keyword, it allows an inner function @@ -77,27 +83,37 @@ def InnerFunction(): return x.result """ + result = attr.ib(type=Any, default=None) + # TODO(jethier): Add a pylint plugin to avoid the no-self-argument for this. -class classproperty(object): +class classproperty(object): # pylint: disable=invalid-name """Exactly what it sounds like. Note that classproperties don't have setters, so setting them will replace the classproperty with the new value. In most use cases (forcing subclasses to override the classproperty, for example) this is desired. """ - def __init__(self, func): + + def __init__(self, func: Callable[..., Any]): self._func = func - def __get__(self, instance, owner): + def __get__(self, instance, owner) -> Any: return self._func(owner) -def partial_format(target, **kwargs): +def partial_format(target: Text, **kwargs: Any) -> Text: """Formats a string without requiring all values to be present. 
This function allows substitutions to be gradually made in several steps rather than all at once. Similar to string.Template.safe_substitute. + + Args: + target: format string. + **kwargs: format replacements. + + Returns: + Formatted string. """ output = target[:] @@ -109,22 +125,48 @@ def partial_format(target, **kwargs): return output + +FormatT = TypeVar('FormatT') + + +@typing.overload +def format_string(target: Text, kwargs: Dict[Text, Any]) -> Text: + pass + + +@typing.overload +def format_string(target: Callable[..., Text], kwargs: Dict[Text, Any]) -> Text: + pass + + +@typing.overload +def format_string(target: None, kwargs: Dict[Text, Any]) -> None: + pass + + +@typing.overload +def format_string(target: FormatT, kwargs: Dict[Text, Any]) -> FormatT: + pass + + def format_string(target, kwargs): """Formats a string in any of three ways (or not at all). Args: target: The target string to format. This can be a function that takes a - dict as its only argument, a string with {}- or %-based formatting, or - a basic string with none of those. In the latter case, the string is - returned as-is, but in all other cases the string is formatted (or the - callback called) with the given kwargs. - If this is None (or otherwise falsey), it is returned immediately. - kwargs: The arguments to use for formatting. - Passed to safe_format, %, or target if it's - callable. + dict as its only argument, a string with {}- or %-based formatting, or a + basic string with none of those. In the latter case, the string is + returned as-is, but in all other cases the string is formatted (or the + callback called) with the given kwargs. If this is None (or otherwise + falsey), it is returned immediately. + kwargs: The arguments to use for formatting. Passed to safe_format, %, or + target if it's callable. + + Returns: + Formatted string. 
""" - if not target: - return target + if target is None: + return None if callable(target): return target(**kwargs) if not isinstance(target, six.string_types): @@ -151,11 +193,11 @@ def __init__(self): self._lock = threading.Lock() self._update_events = weakref.WeakSet() - def _asdict(self): + def _asdict(self) -> Dict[Text, Any]: raise NotImplementedError( 'Subclasses of SubscribableStateMixin must implement _asdict.') - def asdict_with_event(self): + def asdict_with_event(self) -> Tuple[Dict[Text, Any], threading.Event]: """Get a dict representation of this object and an update event. Returns: @@ -168,7 +210,7 @@ def asdict_with_event(self): self._update_events.add(event) return self._asdict(), event - def notify_update(self): + def notify_update(self) -> None: """Notify any update events that there was an update.""" with self._lock: for event in self._update_events: diff --git a/openhtf/util/argv.py b/openhtf/util/argv.py index 8eb49ebf4..0acd01f2b 100644 --- a/openhtf/util/argv.py +++ b/openhtf/util/argv.py @@ -1,12 +1,23 @@ -"""Utilities for handling command line arguments. +"""Utilities for handling command line arguments.""" -StoreInModule: - Enables emulating a gflags-esque API (flag affects global value), but one - doesn't necessarily need to use flags to set values. +import argparse +import sys +import typing +from typing import Any, Optional, Text + + +def module_parser() -> argparse.ArgumentParser: + return argparse.ArgumentParser(add_help=False) + + +class StoreInModule(argparse.Action): + """ArgParse action emulating a gflags-esque API (flag affects global value). + + This doesn't necessarily need to use flags to set values. 
Example usage: DEFAULT_VALUE = 0 - ARG_PARSER = argv.ModuleParser() + ARG_PARSER = argv.module_parser() ARG_PARSER.add_argument( '--override-value', action=argv.StoreInModule, default=DEFAULT_VALUE, target='%s.DEFAULT_VALUE' % __name__) @@ -14,32 +25,28 @@ Then in an entry point (main() function), use that parser as a parent: parser = argparse.ArgumentParser(parents=[other_module.ARG_PARSER]) parser.parse_args() -""" - -import argparse -import sys - - -def ModuleParser(): - return argparse.ArgumentParser(add_help=False) - - -class StoreInModule(argparse.Action): + """ - def __init__(self, *args, **kwargs): - self._tgt_mod, self._tgt_attr = kwargs.pop('target').rsplit('.', 1) + def __init__(self, *args: Any, **kwargs: Any): + self._tgt_mod, self._tgt_attr = typing.cast(Text, + kwargs.pop('target')).rsplit( + '.', 1) proxy_cls = kwargs.pop('proxy', None) if proxy_cls is not None: self._proxy = proxy_cls(*args, **kwargs) super(StoreInModule, self).__init__(*args, **kwargs) - def __call__(self, parser, namespace, values, option_string=None): + def __call__(self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Any, + option_string: Optional[Text] = None) -> None: if hasattr(self, '_proxy'): values = self._proxy(parser, namespace, values) setattr(self._resolve_module(), self._tgt_attr, values) # self.val = values - def _resolve_module(self): + def _resolve_module(self) -> Any: if '.' 
in self._tgt_mod: base, mod = self._tgt_mod.rsplit('.', 1) __import__(base, fromlist=[mod]) @@ -51,11 +58,15 @@ def _resolve_module(self): class _StoreValueInModule(StoreInModule): """Stores a value in a module level variable when set.""" - def __init__(self, const, *args, **kwargs): + def __init__(self, const: Any, *args: Any, **kwargs: Any): kwargs.update(nargs=0, const=const) super(_StoreValueInModule, self).__init__(*args, **kwargs) - def __call__(self, parser, namespace, values, option_string=None): + def __call__(self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Any, + option_string: Optional[Text] = None) -> None: del values # Unused. super(_StoreValueInModule, self).__call__( parser, namespace, self.const, option_string=option_string) @@ -64,30 +75,34 @@ def __call__(self, parser, namespace, values, option_string=None): class StoreTrueInModule(_StoreValueInModule): """Stores True in a module level variable when set.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): super(StoreTrueInModule, self).__init__(True, *args, **kwargs) class StoreFalseInModule(_StoreValueInModule): """Stores False in module level variable when set.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): super(StoreFalseInModule, self).__init__(False, *args, **kwargs) class StoreRepsInModule(StoreInModule): """Store a count of number of times the flag was repeated in a module.""" - def __init__(self, *args, **kwargs): - kwargs.update(nargs=0, const=None) + def __init__(self, *args: Any, **kwargs: Any): + kwargs.update(nargs=0, const=None) # pytype: disable=wrong-arg-types super(StoreRepsInModule, self).__init__(*args, **kwargs) - def __call__(self, parser, namespace, values, option_string=None): + def __call__(self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Any, + option_string: Optional[Text] = None) -> None: del values # Unused. 
old_count = getattr(self._resolve_module(), self._tgt_attr) if old_count is None: super(StoreRepsInModule, self).__call__( - parser, namespace, 0, option_string=option_string) + parser, namespace, 0, option_string=option_string) else: super(StoreRepsInModule, self).__call__( - parser, namespace, old_count + 1, option_string=option_string) + parser, namespace, old_count + 1, option_string=option_string) diff --git a/openhtf/util/atomic_write.py b/openhtf/util/atomic_write.py index fb2b3be50..8fb909cdf 100644 --- a/openhtf/util/atomic_write.py +++ b/openhtf/util/atomic_write.py @@ -11,19 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Utilities for automic_write a new file.""" +import contextlib import os import tempfile -from contextlib import contextmanager -@contextmanager + +@contextlib.contextmanager def atomic_write(filename, filesync=False): - """ atomically write a file (using a temporary file). + """Atomically write a file (using a temporary file). + + Args: + filename: the file to be written + filesync: flush the file to disk - filename: the file to be written - filesync: flush the file to disk + Yields: + File object to write to. """ tmpf = tempfile.NamedTemporaryFile(delete=False) diff --git a/openhtf/util/checkpoints.py b/openhtf/util/checkpoints.py index 88ba76fc8..09eb67dee 100644 --- a/openhtf/util/checkpoints.py +++ b/openhtf/util/checkpoints.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """A simple utility to check whether all previous phases have passed. In general test execution stops on a raised Exception but will continue if a @@ -28,22 +26,24 @@ previous phase has failed. 
""" -from openhtf.core import phase_descriptor -from openhtf.core import test_record +from typing import Optional, Text + +from openhtf.core import phase_branches + -def checkpoint(checkpoint_name=None): - name = checkpoint_name if checkpoint_name else 'Checkpoint' +def checkpoint( + checkpoint_name: Optional[Text] = None +) -> phase_branches.PhaseFailureCheckpoint: + """Creates a checkpoint phase that checks if all the previous phases passed. - @phase_descriptor.PhaseOptions(name=name) - def _checkpoint(test_run): - failed_phases = [] - for phase_record in test_run.test_record.phases: - if phase_record.outcome == test_record.PhaseOutcome.FAIL: - failed_phases.append(phase_record.name) + Args: + checkpoint_name: Optional name for the checkpoint phase; if not specified, + this defaults to 'checkpoint'. - if failed_phases: - test_run.logger.error('Stopping execution because phases failed: %s', - failed_phases) - return phase_descriptor.PhaseResult.STOP + Returns: + The checkpoint phase. + """ + if not checkpoint_name: + checkpoint_name = 'checkpoint' - return _checkpoint + return phase_branches.PhaseFailureCheckpoint.all_previous(checkpoint_name) diff --git a/openhtf/util/conf.py b/openhtf/util/conf.py index d897a466d..61d825064 100644 --- a/openhtf/util/conf.py +++ b/openhtf/util/conf.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Interface to OpenHTF configuration files. As a matter of convention, OpenHTF configuration files should contain values @@ -155,31 +154,35 @@ def do_stuff(): upon execution of the decorated callable, regardless of which keys are updated in the decorator or in the decorated callable. 
""" +# pytype: skip-file import argparse +import enum import functools import inspect import logging import sys import threading -import yaml +from typing import Any, Optional, Text -import mutablerecords +import attr +from openhtf.util import argv +from openhtf.util import threads import six - - -from . import argv -from . import threads +import yaml # If provided, --config-file will cause the given file to be load()ed when the # conf module is initially imported. -ARG_PARSER = argv.ModuleParser() +ARG_PARSER = argv.module_parser() ARG_PARSER.add_argument( - '--config-file', type=argparse.FileType('r'), + '--config-file', + type=argparse.FileType('r'), help='File from which to load configuration values.') ARG_PARSER.add_argument( - '--config-value', action='append', default=[], + '--config-value', + action='append', + default=[], help='Allows specifying a configuration key=value on the command line. ' 'The format should be --config-value=key=value. This value will override ' 'any loaded value, and will be a string.') @@ -208,21 +211,23 @@ class InvalidKeyError(Exception): class UnsetKeyError(Exception): """Raised when a key value is requested but we have no value for it.""" - # pylint: disable=invalid-name,bad-super-call - class Declaration(mutablerecords.Record( - 'Declaration', ['name'], { - 'description': None, 'default_value': None, 'has_default': False})): + @attr.s(slots=True) + class Declaration(object): """Record type encapsulating information about a config declaration.""" - def __init__(self, *args, **kwargs): - super(type(self), self).__init__(*args, **kwargs) - # Track this separately to allow for None as a default value, override - # any value that was passed in explicitly - don't do that. 
- self.has_default = 'default_value' in kwargs - # pylint: enable=invalid-name,bad-super-call - __slots__ = ('_logger', '_lock', '_modules', '_declarations', - '_flag_values', '_flags', '_loaded_values', 'ARG_PARSER', - '__name__') + class _DefaultSetting(enum.Enum): + NOT_SET = 0 + + name = attr.ib(type=Text) + description = attr.ib(type=Optional[Text], default=None) + default_value = attr.ib(type=Any, default=_DefaultSetting.NOT_SET) + + @property + def has_default(self) -> bool: + return self.default_value is not self._DefaultSetting.NOT_SET + + __slots__ = ('_logger', '_lock', '_modules', '_declarations', '_flag_values', + '_flags', '_loaded_values', 'ARG_PARSER', '__name__') def __init__(self, logger, lock, parser, **kwargs): """Initializes the configuration state. @@ -234,13 +239,14 @@ def __init__(self, logger, lock, parser, **kwargs): Args: logger: Logger to use for logging messages within this class. lock: Threading.lock to use for locking access to config values. + parser: the arg parser. **kwargs: Modules we need to access within this class. """ self._logger = logger self._lock = lock self._modules = kwargs self._declarations = {} - self.ARG_PARSER = parser + self.ARG_PARSER = parser # pylint: disable=invalid-name # Parse just the flags we care about, since this happens at import time. 
self._flags, _ = parser.parse_known_args() @@ -275,9 +281,9 @@ def _is_valid_key(key): """Return True if key is a valid configuration key.""" return key and key[0].islower() - def __setattr__(self, attr, value): + def __setattr__(self, field, value): """Provide a useful error when attempting to set a value via setattr().""" - if self._is_valid_key(attr): + if self._is_valid_key(field): raise AttributeError("Can't set conf values by attribute, use load()") # __slots__ is defined above, so this will raise an AttributeError if the # attribute isn't one we expect; this limits the number of ways to abuse the @@ -285,19 +291,19 @@ def __setattr__(self, attr, value): # normally here because of the sys.modules swap (Configuration is no longer # defined, and evaluates to None if used here). # pylint: disable=bad-super-call - super(type(self), self).__setattr__(attr, value) + super(type(self), self).__setattr__(field, value) # Don't use synchronized on this one, because __getitem__ handles it. - def __getattr__(self, attr): # pylint: disable=invalid-name + def __getattr__(self, field): """Get a config value via attribute access.""" - if self._is_valid_key(attr): - return self[attr] + if self._is_valid_key(field): + return self[field] # Config keys all begin with a lowercase letter, so treat this normally. raise AttributeError("'%s' object has no attribute '%s'" % - (type(self).__name__, attr)) + (type(self).__name__, field)) @threads.synchronized - def __getitem__(self, item): # pylint: disable=invalid-name + def __getitem__(self, item): """Get a config value via item access. Order of precedence is: @@ -307,6 +313,13 @@ def __getitem__(self, item): # pylint: disable=invalid-name Args: item: Config key name to get. + + Raises: + UndeclaredKeyError: If the item was not declared. + UnsetKeyError: When the config value was not set and has no default. + + Returns: + The config value. 
""" if item not in self._declarations: raise self.UndeclaredKeyError('Configuration key not declared', item) @@ -314,24 +327,23 @@ def __getitem__(self, item): # pylint: disable=invalid-name if item in self._flag_values: if item in self._loaded_values: self._logger.warning( - 'Overriding loaded value for %s (%s) with flag value: %s', - item, self._loaded_values[item], self._flag_values[item]) + 'Overriding loaded value for %s (%s) with flag value: %s', item, + self._loaded_values[item], self._flag_values[item]) return self._flag_values[item] if item in self._loaded_values: return self._loaded_values[item] if self._declarations[item].has_default: return self._declarations[item].default_value - raise self.UnsetKeyError( - 'Configuration value not set and has no default', item) + raise self.UnsetKeyError('Configuration value not set and has no default', + item) @threads.synchronized - def __contains__(self, name): # pylint: disable=invalid-name + def __contains__(self, name): """True if we have a value for name.""" return (name in self._declarations and (self._declarations[name].has_default or - name in self._loaded_values or - name in self._flag_values)) + name in self._loaded_values or name in self._flag_values)) @threads.synchronized def declare(self, name, description=None, **kwargs): @@ -340,15 +352,19 @@ def declare(self, name, description=None, **kwargs): Args: name: Configuration key to declare, must not have been already declared. description: If provided, use this as the description for this key. - **kwargs: Other kwargs to pass to the Declaration, only default_value - is currently supported. + **kwargs: Other kwargs to pass to the Declaration, only default_value is + currently supported. + + Raises: + InvalidKeyError: When name is not constructed correctly. + KeyAlreadyDeclaredError: When name has already been defined. 
""" if not self._is_valid_key(name): raise self.InvalidKeyError( 'Invalid key name, must begin with a lowercase letter', name) if name in self._declarations: - raise self.KeyAlreadyDeclaredError( - 'Configuration key already declared', name) + raise self.KeyAlreadyDeclaredError('Configuration key already declared', + name) self._declarations[name] = self.Declaration( name, description=description, **kwargs) @@ -365,14 +381,18 @@ def reset(self): if self._flags.config_file is not None: self.load_from_file(self._flags.config_file, _allow_undeclared=True) - def load_from_file(self, yamlfile, _override=True, _allow_undeclared=False): + def load_from_file(self, yamlfile, _override=True, _allow_undeclared=False): # pylint: disable=invalid-name """Loads the configuration from a file. Parsed contents must be a single dict mapping config key to value. Args: - yamlfile: The opened file object to load configuration from. - See load_from_dict() for other args' descriptions. + yamlfile: The opened file object to load configuration from. See + load_from_dict() for other args' descriptions. + _override: If True, new values will override previous values. + _allow_undeclared: If True, silently load undeclared keys, otherwise warn + and ignore the value. Typically used for loading config files before + declarations have been evaluated. Raises: ConfigurationInvalidError: If configuration file can't be read, or can't @@ -384,8 +404,8 @@ def load_from_file(self, yamlfile, _override=True, _allow_undeclared=False): parsed_yaml = self._modules['yaml'].safe_load(yamlfile.read()) except self._modules['yaml'].YAMLError: self._logger.exception('Problem parsing YAML') - raise self.ConfigurationInvalidError( - 'Failed to load from %s as YAML' % yamlfile) + raise self.ConfigurationInvalidError('Failed to load from %s as YAML' % + yamlfile) if not isinstance(parsed_yaml, dict): # Parsed YAML, but it's not a dict. 
@@ -396,13 +416,13 @@ def load_from_file(self, yamlfile, _override=True, _allow_undeclared=False): self.load_from_dict( parsed_yaml, _override=_override, _allow_undeclared=_allow_undeclared) - def load(self, _override=True, _allow_undeclared=False, **kwargs): + def load(self, _override=True, _allow_undeclared=False, **kwargs): # pylint: disable=invalid-name """load configuration values from kwargs, see load_from_dict().""" self.load_from_dict( kwargs, _override=_override, _allow_undeclared=_allow_undeclared) @threads.synchronized - def load_from_dict(self, dictionary, _override=True, _allow_undeclared=False): + def load_from_dict(self, dictionary, _override=True, _allow_undeclared=False): # pylint: disable=invalid-name """Loads the config with values from a dictionary instead of a file. This is meant for testing and bin purposes and shouldn't be used in most @@ -411,9 +431,9 @@ def load_from_dict(self, dictionary, _override=True, _allow_undeclared=False): Args: dictionary: The dictionary containing config keys/values to update. _override: If True, new values will override previous values. - _allow_undeclared: If True, silently load undeclared keys, otherwise - warn and ignore the value. Typically used for loading config - files before declarations have been evaluated. + _allow_undeclared: If True, silently load undeclared keys, otherwise warn + and ignore the value. Typically used for loading config files before + declarations have been evaluated. """ undeclared_keys = [] for key, value in self._modules['six'].iteritems(dictionary): @@ -449,8 +469,11 @@ def _asdict(self): """Create a dictionary snapshot of the current config values.""" # Start with any default values we have, and override with loaded values, # and then override with flag values. 
- retval = {key: self._declarations[key].default_value for - key in self._declarations if self._declarations[key].has_default} + retval = { + key: self._declarations[key].default_value + for key in self._declarations + if self._declarations[key].has_default + } retval.update(self._loaded_values) # Only update keys that are declared so we don't allow injecting # un-declared keys via commandline flags. @@ -466,7 +489,7 @@ def __dict__(self): def help_text(self): """Return a string with all config keys and their descriptions.""" result = [] - for name in sorted(self._declarations.keys()): + for name in sorted(self._declarations): result.append(name) result.append('-' * len(name)) decl = self._declarations[name] @@ -476,14 +499,14 @@ def help_text(self): result.append('(no description found)') if decl.has_default: result.append('') - quotes = '"' if type(decl.default_value) is str else '' + quotes = '"' if isinstance(decl.default_value, str) else '' result.append(' default_value={quotes}{val}{quotes}'.format( quotes=quotes, val=decl.default_value)) result.append('') result.append('') return '\n'.join(result) - def save_and_restore(self, _func=None, **config_values): + def save_and_restore(self, _func=None, **config_values): # pylint: disable=invalid-name """Decorator for saving conf state and restoring it after a function. This decorator is primarily for use in tests, where conf keys may be updated @@ -521,10 +544,10 @@ def MyOtherTestFunc(): Args: _func: The function to wrap. The returned wrapper will invoke the - function and restore the config to the state it was in at invocation. + function and restore the config to the state it was in at invocation. **config_values: Config keys can be set inline at decoration time, see - examples. Note that config keys can't begin with underscore, so - there can be no name collision with _func. + examples. Note that config keys can't begin with underscore, so there + can be no name collision with _func. 
Returns: Wrapper to replace _func, as per Python decorator semantics. @@ -541,7 +564,8 @@ def _saving_wrapper(*args, **kwargs): self.load_from_dict(config_values) return _func(*args, **kwargs) finally: - self._loaded_values = saved_config # pylint: disable=attribute-defined-outside-init + self._loaded_values = saved_config # pylint: disable=attribute-defined-outside-init + return _saving_wrapper def inject_positional_args(self, method): @@ -567,8 +591,12 @@ def inject_positional_args(self, method): A wrapper that, when invoked, will call the wrapped method, passing in configuration values for positional arguments. """ - inspect = self._modules['inspect'] - argspec = inspect.getargspec(method) + inspect = self._modules['inspect'] # pylint: disable=redefined-outer-name + six = self._modules['six'] # pylint: disable=redefined-outer-name + if six.PY3: + argspec = inspect.getfullargspec(method) + else: + argspec = inspect.getargspec(method) # pylint: disable=deprecated-method # Index in argspec.args of the first keyword argument. This index is a # negative number if there are any kwargs, or 0 if there are no kwargs. @@ -588,16 +616,18 @@ def method_wrapper(**kwargs): # Check for keyword args with names that are in the config so we can warn. for kwarg in kwarg_names: if kwarg in self: - self._logger.warning('Keyword arg %s not set from configuration, but ' - 'is a configuration key', kwarg) + self._logger.warning( + 'Keyword arg %s not set from configuration, but ' + 'is a configuration key', kwarg) # Set positional args from configuration values. 
final_kwargs = {name: self[name] for name in arg_names if name in self} for overridden in set(kwargs) & set(final_kwargs): - self._logger.warning('Overriding configuration value for kwarg %s (%s) ' - 'with provided kwarg value: %s', overridden, - self[overridden], kwargs[overridden]) + self._logger.warning( + 'Overriding configuration value for kwarg %s (%s) ' + 'with provided kwarg value: %s', overridden, self[overridden], + kwargs[overridden]) final_kwargs.update(kwargs) if inspect.ismethod(method): @@ -610,16 +640,24 @@ def method_wrapper(**kwargs): # We have to check for a 'self' parameter explicitly because Python doesn't # pass it as a keyword arg, it passes it as the first positional arg. if argspec.args[0] == 'self': + @functools.wraps(method) - def self_wrapper(self, **kwargs): # pylint: disable=invalid-name + def self_wrapper(self, **kwargs): """Wrapper that pulls values from openhtf.util.conf.""" kwargs['self'] = self return method_wrapper(**kwargs) + return self_wrapper return method_wrapper + # Swap out the module for a singleton instance of Configuration so we can # provide __getattr__ and __getitem__ functionality at the module level. sys.modules[__name__] = Configuration( - logging.getLogger(__name__), threading.RLock(), ARG_PARSER, - functools=functools, inspect=inspect, yaml=yaml, six=six) + logging.getLogger(__name__), + threading.RLock(), + ARG_PARSER, + functools=functools, + inspect=inspect, + yaml=yaml, + six=six) diff --git a/openhtf/util/console_output.py b/openhtf/util/console_output.py index 51beabf69..dd1d7d781 100644 --- a/openhtf/util/console_output.py +++ b/openhtf/util/console_output.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """Console output utilities for OpenHTF. 
This module provides convenience methods to format output for the CLI, along @@ -32,7 +30,6 @@ import string import sys import textwrap -import time import colorama import contextlib2 as contextlib @@ -48,12 +45,14 @@ # logging that uses a CliQuietFilter. CLI_QUIET = False -ARG_PARSER = argv.ModuleParser() +ARG_PARSER = argv.module_parser() ARG_PARSER.add_argument( - '--quiet', action=argv.StoreTrueInModule, target='%s.CLI_QUIET' % __name__, - help=textwrap.dedent('''\ + '--quiet', + action=argv.StoreTrueInModule, + target='%s.CLI_QUIET' % __name__, + help=textwrap.dedent("""\ Suppress all CLI output from OpenHTF's printing functions and logging. - This flag will override any verbosity levels set with -v.''')) + This flag will override any verbosity levels set with -v.""")) ANSI_ESC_RE = re.compile(r'\x1B\[[0-?]*[ -/]*[@-~]') @@ -64,8 +63,8 @@ class ActionFailedError(Exception): def _printed_len(some_string): """Compute the visible length of the string when printed.""" - return len([x for x in ANSI_ESC_RE.sub('', some_string) - if x in string.printable]) + return len( + [x for x in ANSI_ESC_RE.sub('', some_string) if x in string.printable]) def _linesep_for_file(file): @@ -83,19 +82,14 @@ def banner_print(msg, color='', width=60, file=sys.stdout, logger=_LOG): Args: msg: The message to print. color: Optional colorama color string to be applied to the message. You can - concatenate colorama color strings together in order to get any set of - effects you want. + concatenate colorama color strings together in order to get any set of + effects you want. width: Total width for the resulting banner. file: A file object to which the banner text will be written. Intended for - use with CLI output file objects like sys.stdout. + use with CLI output file objects like sys.stdout. logger: A logger to use, or None to disable logging. 
- - Example: - - >>> banner_print('Foo Bar Baz') - - ======================== Foo Bar Baz ======================= - + Example: >>> banner_print('Foo Bar Baz') ======================== Foo Bar + Baz ======================= """ if logger: logger.debug(ANSI_ESC_RE.sub('', msg)) @@ -104,7 +98,11 @@ def banner_print(msg, color='', width=60, file=sys.stdout, logger=_LOG): lpad = int(math.ceil((width - _printed_len(msg) - 2) / 2.0)) * '=' rpad = int(math.floor((width - _printed_len(msg) - 2) / 2.0)) * '=' file.write('{sep}{color}{lpad} {msg} {rpad}{reset}{sep}{sep}'.format( - sep=_linesep_for_file(file), color=color, lpad=lpad, msg=msg, rpad=rpad, + sep=_linesep_for_file(file), + color=color, + lpad=lpad, + msg=msg, + rpad=rpad, reset=colorama.Style.RESET_ALL)) file.flush() @@ -115,21 +113,25 @@ def bracket_print(msg, color='', width=8, file=sys.stdout, end_line=True): Args: msg: The message to put inside the brackets (a brief status message). color: Optional colorama color string to be applied to the message. You can - concatenate colorama color strings together in order to get any set of - effects you want. + concatenate colorama color strings together in order to get any set of + effects you want. width: Total desired width of the bracketed message. file: A file object to which the bracketed text will be written. Intended - for use with CLI output file objects like sys.stdout. + for use with CLI output file objects like sys.stdout. end_line: If True, end the line and flush the file object after outputting - the bracketed text. - """ + the bracketed text. 
+ """ if CLI_QUIET: return lpad = int(math.ceil((width - 2 - _printed_len(msg)) / 2.0)) * ' ' rpad = int(math.floor((width - 2 - _printed_len(msg)) / 2.0)) * ' ' file.write('[{lpad}{bright}{color}{msg}{reset}{rpad}]'.format( - lpad=lpad, bright=colorama.Style.BRIGHT, color=color, msg=msg, - reset=colorama.Style.RESET_ALL, rpad=rpad)) + lpad=lpad, + bright=colorama.Style.BRIGHT, + color=color, + msg=msg, + reset=colorama.Style.RESET_ALL, + rpad=rpad)) file.write(colorama.Style.RESET_ALL) if end_line: file.write(_linesep_for_file(file)) @@ -146,15 +148,15 @@ def cli_print(msg, color='', end=None, file=sys.stdout, logger=_LOG): Args: msg: The message to print/log. color: Optional colorama color string to be applied to the message. You can - concatenate colorama color strings together in order to get any set of - effects you want. + concatenate colorama color strings together in order to get any set of + effects you want. end: A custom line-ending string to print instead of newline. file: A file object to which the baracketed text will be written. Intended - for use with CLI output file objects like sys.stdout. + for use with CLI output file objects like sys.stdout. logger: A logger to use, or None to disable logging. """ if logger: - logger.debug('-> {}'.format(msg)) + logger.debug('-> %s', msg) if CLI_QUIET: return if end is None: @@ -169,16 +171,20 @@ def error_print(msg, color=colorama.Fore.RED, file=sys.stderr): Args: msg: The error message to be printed. color: Optional colorama color string to be applied to the message. You can - concatenate colorama color strings together here, but note that style - strings will not be applied. + concatenate colorama color strings together here, but note that style + strings will not be applied. file: A file object to which the baracketed text will be written. Intended - for use with CLI output file objects, specifically sys.stderr. + for use with CLI output file objects, specifically sys.stderr. 
""" if CLI_QUIET: return file.write('{sep}{bright}{color}Error: {normal}{msg}{sep}{reset}'.format( - sep=_linesep_for_file(file), bright=colorama.Style.BRIGHT, color=color, - normal=colorama.Style.NORMAL, msg=msg, reset=colorama.Style.RESET_ALL)) + sep=_linesep_for_file(file), + bright=colorama.Style.BRIGHT, + color=color, + normal=colorama.Style.NORMAL, + msg=msg, + reset=colorama.Style.RESET_ALL)) file.flush() @@ -230,37 +236,8 @@ def action_result_context(action_text, file: Specific file object to write to write CLI output to. logger: A logger to use, or None to disable logging. - Example usage: - with action_result_context('Doing an action that will succeed...') as act: - time.sleep(2) - act.succeed() - - with action_result_context('Doing an action with unset result...') as act: - time.sleep(2) - - with action_result_context('Doing an action that will fail...') as act: - time.sleep(2) - act.fail() - - with action_result_context('Doing an action that will raise...') as act: - time.sleep(2) - import textwrap - raise RuntimeError(textwrap.dedent('''\ - Uh oh, looks like there was a raise in the mix. - - If you see this message, it means you are running the console_output - module directly rather than using it as a library. Things to try: - - * Not running it as a module. - * Running it as a module and enjoying the preview text. - * Getting another coffee.''')) - - Example output: - Doing an action that will succeed... [ OK ] - Doing an action with unset result... [ ???? ] - Doing an action that will fail... [ FAIL ] - Doing an action that will raise... [ FAIL ] - ... + Yields: + ActionResult to declare the result of the action. 
""" if logger: logger.debug('Action - %s', action_text) @@ -272,13 +249,13 @@ def action_result_context(action_text, result = ActionResult() try: yield result - except Exception as err: + except Exception as err: # pylint: disable=broad-except if logger: logger.debug('Result - %s [ %s ]', action_text, fail_text) if not CLI_QUIET: file.write(''.join((action_text, spacing))) - bracket_print(fail_text, width=status_width, color=colorama.Fore.RED, - file=file) + bracket_print( + fail_text, width=status_width, color=colorama.Fore.RED, file=file) if not isinstance(err, ActionFailedError): raise return @@ -289,37 +266,8 @@ def action_result_context(action_text, logger.debug('Result - %s [ %s ]', action_text, result_text) if not CLI_QUIET: file.write(''.join((action_text, spacing))) - bracket_print(result_text, width=status_width, color=result_color, - file=file) - - -# If invoked as a runnable module, this module will invoke its action result -# context in order to print colorized example output. -if __name__ == '__main__': - banner_print('Running pre-flight checks.') - - with action_result_context('Doing an action that will succeed...') as act: - time.sleep(2) - act.succeed() - - with action_result_context('Doing an action with unset result...') as act: - time.sleep(2) - - with action_result_context('Doing an action that will fail...') as act: - time.sleep(2) - act.fail() - - with action_result_context('Doing an action that will raise...') as act: - time.sleep(2) - raise RuntimeError(textwrap.dedent('''\ - Uh oh, looks like there was a raise in the mix. - - If you see this message, it means you are running the console_output - module directly rather than using it as a library. Things to try: - - * Not running it as a module. - * Running it as a module and enjoying the preview text. 
- * Getting another coffee.''')) + bracket_print( + result_text, width=status_width, color=result_color, file=file) class CliQuietFilter(logging.Filter): @@ -329,5 +277,6 @@ class CliQuietFilter(logging.Filter): module, and can thus be overridden in test scripts. This filter should only be used with loggers that print to the CLI. """ + def filter(self, record): return not CLI_QUIET diff --git a/openhtf/util/data.py b/openhtf/util/data.py index cbe848a8a..ce929b0aa 100644 --- a/openhtf/util/data.py +++ b/openhtf/util/data.py @@ -11,14 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Module for utility functions that manipulate or compare data. We use a few special data formats internally, these utility functions make it a little easier to work with them. """ +import copy import difflib +import enum import itertools import logging import math @@ -26,16 +27,16 @@ import pprint import struct import sys +from typing import Any, TypeVar import attr from mutablerecords import records from past.builtins import long from past.builtins import unicode -from enum import Enum import six from six.moves import collections_abc - +from six.moves import zip # Used by convert_to_base_types(). PASSTHROUGH_TYPES = {bool, bytes, int, long, type(None), unicode} @@ -46,7 +47,9 @@ def pprint_diff(first, second, first_name='first', second_name='second'): return difflib.unified_diff( pprint.pformat(first).splitlines(), pprint.pformat(second).splitlines(), - fromfile=first_name, tofile=second_name, lineterm='') + fromfile=first_name, + tofile=second_name, + lineterm='') def equals_log_diff(expected, actual, level=logging.ERROR): @@ -57,8 +60,11 @@ def equals_log_diff(expected, actual, level=logging.ERROR): # Output the diff first. 
logging.log(level, '***** Data mismatch: *****') for line in difflib.unified_diff( - expected.splitlines(), actual.splitlines(), - fromfile='expected', tofile='actual', lineterm=''): + expected.splitlines(), + actual.splitlines(), + fromfile='expected', + tofile='actual', + lineterm=''): logging.log(level, line) logging.log(level, '^^^^^ Data diff ^^^^^') @@ -66,9 +72,16 @@ def equals_log_diff(expected, actual, level=logging.ERROR): def assert_records_equal_nonvolatile(first, second, volatile_fields, indent=0): """Compare two test_record tuples, ignoring any volatile fields. - 'Volatile' fields include any fields that are expected to differ between - successive runs of the same test, mainly timestamps. All other fields - are recursively compared. + Args: + first: test_record.TestRecord to compare. + second: test_record.TestRecord to compare. + volatile_fields: list of str, any fields that are expected to differ between + successive runs of the same test, mainly timestamps. All other fields are + recursively compared. + indent: int, indent level. + + Raises: + AssertionError: when the records are different. 
""" if isinstance(first, dict) and isinstance(second, dict): if set(first) != set(second): @@ -90,7 +103,7 @@ def assert_records_equal_nonvolatile(first, second, volatile_fields, indent=0): assert_records_equal_nonvolatile(first._asdict(), second._asdict(), volatile_fields, indent) elif hasattr(first, '__iter__') and hasattr(second, '__iter__'): - for idx, (fir, sec) in enumerate(itertools.izip(first, second)): + for idx, (fir, sec) in enumerate(zip(first, second)): try: assert_records_equal_nonvolatile(fir, sec, volatile_fields, indent + 2) except AssertionError: @@ -107,7 +120,9 @@ def assert_records_equal_nonvolatile(first, second, volatile_fields, indent=0): assert first == second -def convert_to_base_types(obj, ignore_keys=tuple(), tuple_type=tuple, +def convert_to_base_types(obj, + ignore_keys=tuple(), + tuple_type=tuple, json_safe=True): """Recursively convert objects into base types. @@ -135,10 +150,18 @@ def convert_to_base_types(obj, ignore_keys=tuple(), tuple_type=tuple, encoding that does not differentiate between the two), pass 'tuple_type=list' as an argument. - If `json_safe` is True, then the float 'inf', '-inf', and 'nan' values will be - converted to strings. This ensures that the returned dictionary can be passed - to json.dumps to create valid JSON. Otherwise, json.dumps may return values - such as NaN which are not valid JSON. + Args: + obj: object to recursively convert to base types. + ignore_keys: Iterable of str, keys that should be ignored when recursing on + dict types. + tuple_type: Type used for tuple objects. + json_safe: If True, then the float 'inf', '-inf', and 'nan' values will be + converted to strings. This ensures that the returned dictionary can be + passed to json.dumps to create valid JSON. Otherwise, json.dumps may + return values such as NaN which are not valid JSON. + + Returns: + Version of the object composed of base types. """ # Because it's *really* annoying to pass a single string accidentally. 
assert not isinstance(ignore_keys, six.string_types), 'Pass a real iterable!' @@ -155,21 +178,26 @@ def convert_to_base_types(obj, ignore_keys=tuple(), tuple_type=tuple, new_obj[a] = val obj = new_obj elif attr.has(type(obj)): - obj = attr.asdict(obj) - elif isinstance(obj, Enum): + obj = attr.asdict(obj, recurse=False) + elif isinstance(obj, enum.Enum): obj = obj.name - if type(obj) in PASSTHROUGH_TYPES: + if type(obj) in PASSTHROUGH_TYPES: # pylint: disable=unidiomatic-typecheck return obj # Recursively convert values in dicts, lists, and tuples. if isinstance(obj, dict): - return {convert_to_base_types(k, ignore_keys, tuple_type): - convert_to_base_types(v, ignore_keys, tuple_type) - for k, v in six.iteritems(obj) if k not in ignore_keys} + return { # pylint: disable=g-complex-comprehension + convert_to_base_types(k, ignore_keys, tuple_type): + convert_to_base_types(v, ignore_keys, tuple_type) + for k, v in six.iteritems(obj) + if k not in ignore_keys + } elif isinstance(obj, list): - return [convert_to_base_types(val, ignore_keys, tuple_type, json_safe) - for val in obj] + return [ + convert_to_base_types(val, ignore_keys, tuple_type, json_safe) + for val in obj + ] elif isinstance(obj, tuple): return tuple_type( convert_to_base_types(value, ignore_keys, tuple_type, json_safe) @@ -195,6 +223,7 @@ def convert_to_base_types(obj, ignore_keys=tuple(), tuple_type=tuple, def total_size(obj): """Returns the approximate total memory footprint an object.""" seen = set() + def sizeof(current_obj): try: return _sizeof(current_obj) @@ -211,14 +240,36 @@ def _sizeof(current_obj): size = sys.getsizeof(current_obj) if isinstance(current_obj, dict): - size += sum(map(sizeof, itertools.chain.from_iterable( - six.iteritems(current_obj)))) + size += sum( + map(sizeof, + itertools.chain.from_iterable(six.iteritems(current_obj)))) elif (isinstance(current_obj, collections_abc.Iterable) and not isinstance(current_obj, six.string_types)): size += sum(sizeof(item) for item in 
current_obj) elif isinstance(current_obj, records.RecordClass): - size += sum(sizeof(getattr(current_obj, a)) - for a in current_obj.__slots__) + size += sum( + sizeof(getattr(current_obj, a)) for a in current_obj.__slots__) return size return sizeof(obj) + + +_AttrCopyT = TypeVar('_AttrCopyT') + + +def attr_copy(obj: _AttrCopyT, **overrides: Any) -> _AttrCopyT: + """Recursively copy an attr-defined object.""" + kwargs = dict(overrides) + for field in attr.fields(type(obj)): + name = field.name + init_name = name if name[0] != '_' else name[1:] + # Skip fields being set in the override. + if init_name in overrides: + continue + value = getattr(obj, name) + if attr.has(value): + new_value = attr_copy(value) + else: + new_value = copy.copy(value) + kwargs[init_name] = new_value + return type(obj)(**kwargs) diff --git a/openhtf/util/exceptions.py b/openhtf/util/exceptions.py deleted file mode 100644 index 6aca117bc..000000000 --- a/openhtf/util/exceptions.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2014 Google Inc. All Rights Reserved. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -"""Utils for dealing with exceptions.""" - -import inspect -import sys - - -def reraise(exc_type, message=None, *args, **kwargs): # pylint: disable=invalid-name - """reraises an exception for exception translation. 
- - This is primarily used for when you immediately reraise an exception that is - thrown in a library, so that your client will not have to depend on various - exceptions defined in the library implementation that is being abstracted. The - advantage of this helper function is somewhat preserve traceback information - although it is polluted by the reraise frame. - - Example Code: - def A(): - raise Exception('Whoops') - def main(): - try: - A() - except Exception as e: - exceptions.reraise(ValueError) - main() - - Traceback (most recent call last): - File "exception.py", line 53, in - main() - File "exception.py", line 49, in main - reraise(ValueError) - File "exception.py", line 47, in main - A() - File "exception.py", line 42, in A - raise Exception('Whoops') - ValueError: line 49 - - When this code is run, the additional stack frames for calling A() and raising - within A() are printed out in exception, whereas a bare exception translation - would lose this information. As long as you ignore the reraise stack frame, - the stack trace is okay looking. - - Generally this can be fixed by hacking on CPython to allow modification of - traceback objects ala - https://github.com/mitsuhiko/jinja2/blob/master/jinja2/debug.py, but this is - fixed in Python 3 anyways and that method is the definition of hackery. - - Args: - exc_type: (Exception) Exception class to create. - message: (str) Optional message to place in exception instance. Usually not - needed as the original exception probably has a message that will be - printed out in the modified stacktrace. - *args: Args to pass to exception constructor. - **kwargs: Kwargs to pass to exception constructor. 
- """ - last_lineno = inspect.currentframe().f_back.f_lineno - line_msg = 'line %s: ' % last_lineno - if message: - line_msg += str(message) - raise exc_type(line_msg, *args, **kwargs).raise_with_traceback(sys.exc_info()[2]) diff --git a/openhtf/util/functions.py b/openhtf/util/functions.py index 4fa4e2816..603357c39 100644 --- a/openhtf/util/functions.py +++ b/openhtf/util/functions.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Utilities for functions.""" import collections @@ -19,6 +18,8 @@ import inspect import time +import six + def call_once(func): """Decorate a function to only allow it to be called once. @@ -26,30 +27,55 @@ def call_once(func): Note that it doesn't make sense to only call a function once if it takes arguments (use @functools.lru_cache for that sort of thing), so this only works on callables that take no args. + + Args: + func: function to decorate to only be called once. + + Returns: + The decorated function. """ - argspec = inspect.getargspec(func) - if argspec.args or argspec.varargs or argspec.keywords: + if six.PY3: + argspec = inspect.getfullargspec(func) + argspec_args = argspec.args + argspec_varargs = argspec.varargs + argspec_keywords = argspec.varkw + else: + argspec = inspect.getargspec(func) # pylint: disable=deprecated-method + argspec_args = argspec.args + argspec_varargs = argspec.varargs + argspec_keywords = argspec.keywords + if argspec_args or argspec_varargs or argspec_keywords: raise ValueError('Can only decorate functions with no args', func, argspec) @functools.wraps(func) def _wrapper(): # If we haven't been called yet, actually invoke func and save the result. 
- if not _wrapper.HasRun(): - _wrapper.MarkAsRun() + if not _wrapper.has_run(): + _wrapper.mark_as_run() _wrapper.return_value = func() return _wrapper.return_value - _wrapper.has_run = False - _wrapper.HasRun = lambda: _wrapper.has_run - _wrapper.MarkAsRun = lambda: setattr(_wrapper, 'has_run', True) + _wrapper._has_run = False # pylint: disable=protected-access + _wrapper.has_run = lambda: _wrapper._has_run # pylint: disable=protected-access + _wrapper.mark_as_run = lambda: setattr(_wrapper, '_has_run', True) return _wrapper + def call_at_most_every(seconds, count=1): """Call the decorated function at most count times every seconds seconds. The decorated function will sleep to ensure that at most count invocations occur within any 'seconds' second window. + + Args: + seconds: time in seconds that this function will get called at most count + times over. + count: int, number of times it can be called in seconds duration. + + Returns: + Decorated function. """ + def decorator(func): try: call_history = getattr(func, '_call_history') @@ -69,5 +95,7 @@ def _wrapper(*args, **kwargs): # Append this call, deque will automatically trim old calls using maxlen. call_history.append(time.time()) return func(*args, **kwargs) + return _wrapper + return decorator diff --git a/openhtf/util/logs.py b/openhtf/util/logs.py index ef6374173..0ac4b37be 100644 --- a/openhtf/util/logs.py +++ b/openhtf/util/logs.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """Logging mechanisms for use in OpenHTF. Below is an illustration of the tree of OpenHTF loggers: @@ -119,25 +117,35 @@ def MyPhase(test, helper): # Will be overridden if the ARG_PARSER below parses the -v argument. 
CLI_LOGGING_VERBOSITY = 0 -ARG_PARSER = argv.ModuleParser() +ARG_PARSER = argv.module_parser() ARG_PARSER.add_argument( - '-v', action=argv.StoreRepsInModule, + '-v', + action=argv.StoreRepsInModule, target='%s.CLI_LOGGING_VERBOSITY' % __name__, - help=textwrap.dedent('''\ + help=textwrap.dedent("""\ CLI logging verbosity. Can be repeated to increase verbosity (i.e. -v, - -vv, -vvv).''')) + -vv, -vvv).""")) LOGGER_PREFIX = 'openhtf' RECORD_LOGGER_PREFIX = '.'.join((LOGGER_PREFIX, 'test_record')) -RECORD_LOGGER_RE = re.compile( - r'%s\.(?P[^.]*)\.?' % RECORD_LOGGER_PREFIX) +RECORD_LOGGER_RE = re.compile(r'%s\.(?P[^.]*)\.?' % + RECORD_LOGGER_PREFIX) SUBSYSTEM_LOGGER_RE = re.compile( r'%s\.[^.]*\.(?Pplug|phase)\.(?P[^.]*)' % RECORD_LOGGER_PREFIX) _LOG_ONCE_SEEN = set() -LogRecord = collections.namedtuple( - 'LogRecord', 'level logger_name source lineno timestamp_millis message') + +class LogRecord( + collections.namedtuple('LogRecord', [ + 'level', + 'logger_name', + 'source', + 'lineno', + 'timestamp_millis', + 'message', + ])): + pass class HtfTestLogger(logging.Logger): @@ -172,6 +180,11 @@ def initialize_record_handler(test_uid, test_record, notify_update): For each running test, we attach a record handler to the top-level OpenHTF logger. The handler will append OpenHTF logs to the test record, while filtering out logs that are specific to any other test run. + + Args: + test_uid: UID for the test run. + test_record: The test record for the current test run. + notify_update: Function that gets called when the test record is updated. 
""" htf_logger = logging.getLogger(LOGGER_PREFIX) htf_logger.addHandler(RecordHandler(test_uid, test_record, notify_update)) @@ -196,7 +209,8 @@ def log_once(log_func, msg, *args, **kwargs): class MacAddressLogFilter(logging.Filter): """A filter which redacts MAC addresses.""" - MAC_REPLACE_RE = re.compile(r""" + MAC_REPLACE_RE = re.compile( + r""" ((?:[\dA-F]{2}:){3}) # 3-part prefix, f8:8f:ca means google (?:[\dA-F]{2}(:|\b)){3} # the remaining octets """, re.IGNORECASE | re.VERBOSE) @@ -209,13 +223,14 @@ def filter(self, record): record.msg = self.MAC_REPLACE_RE.sub(self.MAC_REPLACEMENT, record.msg) record.args = tuple([ self.MAC_REPLACE_RE.sub(self.MAC_REPLACEMENT, str(arg)) - if isinstance(arg, six.string_types) - else arg for arg in record.args]) + if isinstance(arg, six.string_types) else arg for arg in record.args + ]) else: - record.msg = self.MAC_REPLACE_RE.sub( - self.MAC_REPLACEMENT, record.getMessage()) + record.msg = self.MAC_REPLACE_RE.sub(self.MAC_REPLACEMENT, + record.getMessage()) return True + # We use one shared instance of this, it has no internal state. MAC_FILTER = MacAddressLogFilter() @@ -279,8 +294,12 @@ def emit(self, record): try: message = self.format(record) log_record = LogRecord( - record.levelno, record.name, os.path.basename(record.pathname), - record.lineno, int(record.created * 1000), message, + record.levelno, + record.name, + os.path.basename(record.pathname), + record.lineno, + int(record.created * 1000), + message, ) self._test_record.add_log_record(log_record) self._notify_update() @@ -304,15 +323,12 @@ def format(self, record): subsys_match = SUBSYSTEM_LOGGER_RE.match(record.name) if subsys_match: terse_name = '<{subsys}: {id}>'.format( - subsys=subsys_match.group('subsys'), - id=subsys_match.group('id')) + subsys=subsys_match.group('subsys'), id=subsys_match.group('id')) else: # Fall back to using the last five characters of the test UUID. 
terse_name = '' % match.group('test_uid')[-5:] - return '{lvl} {time} {logger} - {msg}'.format(lvl=terse_level, - time=terse_time, - logger=terse_name, - msg=record.message) + return '{lvl} {time} {logger} - {msg}'.format( + lvl=terse_level, time=terse_time, logger=terse_name, msg=record.message) @functions.call_once diff --git a/openhtf/util/multicast.py b/openhtf/util/multicast.py index 9f464b246..c66c0d542 100644 --- a/openhtf/util/multicast.py +++ b/openhtf/util/multicast.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """Multicast facilities for sending and receiving messages. This module includes both a MulticastListener that listens on a multicast @@ -45,15 +43,6 @@ class MulticastListener(threading.Thread): The listener will force-bind to the multicast port via the SO_REUSEADDR option, so it's possible for multiple listeners to bind to the same port. - - Args: - callback: A callable to invoke upon receipt of a multicast message. Will be - called with one argument -- the text of the message received. - callback can optionally return a string response, which will be - transmitted back to the sender. - address: Multicast IP address component of the socket to listen on. - port: Multicast UDP port component of the socket to listen on. - ttl: TTL for multicast messages. 1 to keep traffic in-network. """ LISTEN_TIMEOUT_S = 60 # Seconds to listen before retrying. daemon = True @@ -63,6 +52,17 @@ def __init__(self, address=DEFAULT_ADDRESS, port=DEFAULT_PORT, ttl=DEFAULT_TTL): + """Constructor. + + Args: + callback: A callable to invoke upon receipt of a multicast message. Will + be called with one argument -- the text of the message received. + callback can optionally return a string response, which will be + transmitted back to the sender. + address: Multicast IP address component of the socket to listen on. 
+ port: Multicast UDP port component of the socket to listen on. + ttl: TTL for multicast messages. 1 to keep traffic in-network. + """ super(MulticastListener, self).__init__() self.address = address self.port = port @@ -70,9 +70,7 @@ def __init__(self, self._callback = callback self._live = False self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self._sock.setsockopt(socket.IPPROTO_IP, - socket.IP_MULTICAST_TTL, - self.ttl) + self._sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, self.ttl) def stop(self, timeout_s=None): """Stop listening for messages.""" @@ -98,15 +96,13 @@ def run(self): socket.IP_ADD_MEMBERSHIP, # IP_ADD_MEMBERSHIP takes the 8-byte group address followed by the IP # assigned to the interface on which to listen. - struct.pack('!4sL', socket.inet_aton(self.address), interface_ip)) + struct.pack('!4sL', socket.inet_aton(self.address), interface_ip)) # pylint: disable=g-socket-inet-aton if sys.platform == 'darwin': - self._sock.setsockopt(socket.SOL_SOCKET, - socket.SO_REUSEPORT, + self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) # Allow multiple listeners to bind. else: - self._sock.setsockopt(socket.SOL_SOCKET, - socket.SO_REUSEADDR, + self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) # Allow multiple listeners to bind. self._sock.bind((self.address, self.port)) @@ -123,8 +119,8 @@ def run(self): # requests and reply (if they all try to use the multicast socket # to reply, they conflict and this sendto fails). response = response.encode('utf-8') - socket.socket(socket.AF_INET, socket.SOCK_DGRAM).sendto( - response, address) + socket.socket(socket.AF_INET, + socket.SOCK_DGRAM).sendto(response, address) _LOG.debug(log_line) except socket.timeout: pass @@ -145,25 +141,26 @@ def send(query, address: Multicast IP address component of the socket to send to. port: Multicast UDP port component of the socket to send to. ttl: TTL for multicast messages. 1 to keep traffic in-network. 
+ local_only: if True, no packets will leave this host. timeout_s: Seconds to wait for responses. - Returns: A set of all responses that arrived before the timeout expired. - Responses are tuples of (sender_address, message). + Yields: + A set of all responses that arrived before the timeout expired. + Responses are tuples of (sender_address, message). """ # Set up the socket as a UDP Multicast socket with the given timeout. sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl) if local_only: # Set outgoing interface to localhost to ensure no packets leave this host. - sock.setsockopt( - socket.IPPROTO_IP, - socket.IP_MULTICAST_IF, - struct.pack('!L', LOCALHOST_ADDRESS)) + sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, + struct.pack('!L', LOCALHOST_ADDRESS)) sock.settimeout(timeout_s) sock.sendto(query.encode('utf-8'), (address, port)) # Set up our thread-safe Queue for handling responses. recv_queue = queue.Queue() + def _handle_responses(): while True: try: @@ -173,8 +170,8 @@ def _handle_responses(): recv_queue.put(None) break else: - _LOG.debug('Multicast response to query "%s": %s:%s', - query, address[0], data) + _LOG.debug('Multicast response to query "%s": %s:%s', query, address[0], + data) recv_queue.put((address[0], str(data))) # Yield responses as they come in, giving up once timeout expires. diff --git a/openhtf/util/test.py b/openhtf/util/test.py index 57a6c47ca..355b7d0f9 100644 --- a/openhtf/util/test.py +++ b/openhtf/util/test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """Unit test helpers for OpenHTF tests and phases. 
This module provides some utility for unit testing OpenHTF test phases and @@ -117,16 +115,22 @@ def test_multiple(self, mock_my_plug): import logging import sys import types +from typing import Any, Callable, Dict, Iterable, List, Text, Tuple, Type import unittest +import attr import mock import openhtf from openhtf import plugs from openhtf import util +from openhtf.core import base_plugs from openhtf.core import diagnoses_lib from openhtf.core import measurements +from openhtf.core import phase_collections +from openhtf.core import phase_descriptor from openhtf.core import phase_executor +from openhtf.core import phase_nodes from openhtf.core import test_descriptor from openhtf.core import test_record from openhtf.core import test_state @@ -135,7 +139,6 @@ def test_multiple(self, mock_my_plug): import six from six.moves import collections_abc - logs.CLI_LOGGING_VERBOSITY = 2 @@ -143,10 +146,139 @@ class InvalidTestError(Exception): """Raised when there's something invalid about a test.""" +class _ValidTimestamp(int): + + def __eq__(self, other): + return other is not None and other > 0 + + +VALID_TIMESTAMP = _ValidTimestamp() + + +@attr.s(slots=True, frozen=True) +class TestNode(phase_nodes.PhaseNode): + """General base class for comparison nodes. + + This is used to test functions that create phase nodes; it cannot be run as + part of an actual test run. + """ + + def copy(self: phase_nodes.WithModifierT) -> phase_nodes.WithModifierT: + """Create a copy of the PhaseNode.""" + return self + + def with_args(self: phase_nodes.WithModifierT, + **kwargs: Any) -> phase_nodes.WithModifierT: + """Send these keyword-arguments when phases are called.""" + del kwargs # Unused. + return self + + def with_plugs( + self: phase_nodes.WithModifierT, + **subplugs: Type[base_plugs.BasePlug]) -> phase_nodes.WithModifierT: + """Substitute plugs for placeholders for this phase, error on unknowns.""" + del subplugs # Unused. 
+ return self + + def load_code_info( + self: phase_nodes.WithModifierT) -> phase_nodes.WithModifierT: + """Load coded info for all contained phases.""" + return self + + def apply_to_all_phases(self, func: Any) -> 'TestNode': + return self + + +@attr.s(slots=True, frozen=True, eq=False) +class PhaseNodeNameComparable(TestNode): + """Compares truthfully against any phase node with the same name. + + This is used to test functions that create phase nodes; it cannot be run as + part of an actual test run. + """ + + name = attr.ib(type=Text) + + def _asdict(self) -> Dict[Text, Any]: + """Returns a base type dictionary for serialization.""" + return {'name': self.name} + + def __eq__(self, other: phase_nodes.PhaseNode) -> bool: + return self.name == other.name + + +@attr.s(slots=True, frozen=True, eq=False, init=False) +class PhaseNodeComparable(TestNode): + """Compares truthfully only against another with same data. + + This is used to test functions that create phase nodes; it cannot be run as + part of an actual test run. 
+ """ + + name = attr.ib(type=Text) + args = attr.ib(type=Tuple[Any, ...], factory=tuple) + kwargs = attr.ib(type=Dict[Text, Any], factory=dict) + + def __init__(self, name, *args, **kwargs): + super(PhaseNodeComparable, self).__init__() + object.__setattr__(self, 'name', name) + object.__setattr__(self, 'args', tuple(args)) + object.__setattr__(self, 'kwargs', kwargs) + + @classmethod + def create_constructor(cls, name) -> Callable[..., 'PhaseNodeComparable']: + + def constructor(*args, **kwargs): + return cls(name, *args, **kwargs) + + return constructor + + def _asdict(self) -> Dict[Text, Any]: + return {'name': self.name, 'args': self.args, 'kwargs': self.kwargs} + + def __eq__(self, other: phase_nodes.PhaseNode) -> bool: + return (isinstance(other, PhaseNodeComparable) and + self.name == other.name and self.args == other.args and + self.kwargs == other.kwargs) + + +class FakeTestApi(openhtf.TestApi): + """A fake TestApi used to test non-phase helper functions.""" + + def __init__(self): + self.mock_logger = mock.create_autospec(logging.Logger) + self.mock_phase_state = mock.create_autospec( + test_state.PhaseState, logger=self.mock_logger) + self.mock_test_state = mock.create_autospec( + test_state.TestState, + test_record=test_record.TestRecord('DUT', 'STATION'), + user_defined_state={}) + super(FakeTestApi, self).__init__( + measurements={}, + running_phase_state=self.mock_phase_state, + running_test_state=self.mock_test_state) + + +def filter_phases_by_names(phase_recs: Iterable[test_record.PhaseRecord], + *names: Text) -> Iterable[test_record.PhaseRecord]: + all_names = set(names) + for phase_rec in phase_recs: + if phase_rec.name in all_names: + yield phase_rec + + +def filter_phases_by_outcome( + phase_recs: Iterable[test_record.PhaseRecord], + outcome: test_record.PhaseOutcome) -> Iterable[test_record.PhaseRecord]: + for phase_rec in phase_recs: + if phase_rec.outcome == outcome: + yield phase_rec + + class PhaseOrTestIterator(collections_abc.Iterator): 
- def __init__(self, test_case, iterator, mock_plugs, - phase_user_defined_state, phase_diagnoses): + def __init__(self, test_case, iterator, mock_plugs, phase_user_defined_state, + phase_diagnoses): """Create an iterator for iterating over Tests or phases to run. This should only be instantiated internally. @@ -154,13 +286,13 @@ def __init__(self, test_case, iterator, mock_plugs, Args: test_case: TestCase subclass where the test case function is defined. iterator: Child iterator to use for obtaining Tests or test phases, must - be a generator. + be a generator. mock_plugs: Dict mapping plug types to mock objects to use instead of - actually instantiating that type. + actually instantiating that type. phase_user_defined_state: If not None, a dictionary that will be added to - the test_state.user_defined_state when handling phases. + the test_state.user_defined_state when handling phases. phase_diagnoses: If not None, must be a list of Diagnosis instances; these - are added to the DiagnosesManager when handling phases. + are added to the DiagnosesManager when handling phases. Raises: InvalidTestError: when iterator is not a generator. @@ -189,14 +321,14 @@ def _initialize_plugs(self, plug_types): # Make sure we initialize any plugs, this will ignore any that have already # been initialized. 
plug_types = list(plug_types) - self.plug_manager.initialize_plugs(plug_cls for plug_cls in plug_types if - plug_cls not in self.mock_plugs) + self.plug_manager.initialize_plugs( + plug_cls for plug_cls in plug_types if plug_cls not in self.mock_plugs) for plug_type, plug_value in six.iteritems(self.mock_plugs): self.plug_manager.update_plug(plug_type, plug_value) def _handle_phase(self, phase_desc): """Handle execution of a single test phase.""" - diagnoses_lib.check_for_duplicate_results([phase_desc], []) + diagnoses_lib.check_for_duplicate_results(iter([phase_desc]), []) logs.configure_logging() self._initialize_plugs(phase_plug.cls for phase_plug in phase_desc.plugs) @@ -205,10 +337,9 @@ def _handle_phase(self, phase_desc): with mock.patch( 'openhtf.plugs.PlugManager', new=lambda _, __: self.plug_manager): test_state_ = test_state.TestState( - openhtf.TestDescriptor((phase_desc,), phase_desc.code_info, {}), - 'Unittest:StubTest:UID', - test_options - ) + openhtf.TestDescriptor( + phase_collections.PhaseSequence((phase_desc,)), + phase_desc.code_info, {}), 'Unittest:StubTest:UID', test_options) test_state_.mark_test_started() test_state_.user_defined_state.update(self.phase_user_defined_state) @@ -224,11 +355,15 @@ def _handle_phase(self, phase_desc): executor = phase_executor.PhaseExecutor(test_state_) # Log an exception stack when a Phase errors out. with mock.patch.object( - phase_executor.PhaseExecutorThread, '_log_exception', + phase_executor.PhaseExecutorThread, + '_log_exception', side_effect=logging.exception): # Use _execute_phase_once because we want to expose all possible outcomes. 
phase_result, _ = executor._execute_phase_once( - phase_desc, is_last_repeat=False, run_with_profiling=False) + phase_desc, + is_last_repeat=False, + run_with_profiling=False, + subtest_rec=None) if phase_result.raised_exception: failure_message = phase_result.phase_result.get_traceback_string() @@ -250,14 +385,14 @@ def _handle_test(self, test): # Mock the PlugManager to use ours instead, and execute the test. with mock.patch( 'openhtf.plugs.PlugManager', new=lambda _, __: self.plug_manager): - test.execute(test_start=lambda: 'TestDutId') + test.execute(test_start=self.test_case.test_start_function) test_record_ = record_saver.result if test_record_.outcome_details: msgs = [] for detail in test_record_.outcome_details: - msgs.append( - 'code: {}\ndescription: {}'.format(detail.code, detail.description)) + msgs.append('code: {}\ndescription: {}'.format(detail.code, + detail.description)) failure_message = '\n'.join(msgs) else: failure_message = None @@ -273,7 +408,7 @@ def __next__(self): 'individual test phases', phase_or_test) else: self.last_result, failure_message = self._handle_phase( - openhtf.PhaseDescriptor.wrap_or_copy(phase_or_test)) + phase_descriptor.PhaseDescriptor.wrap_or_copy(phase_or_test)) return phase_or_test, self.last_result, failure_message def next(self): @@ -286,7 +421,7 @@ def next(self): 'individual test phases', phase_or_test) else: self.last_result, failure_message = self._handle_phase( - openhtf.PhaseDescriptor.wrap_or_copy(phase_or_test)) + phase_descriptor.PhaseDescriptor.wrap_or_copy(phase_or_test)) return phase_or_test, self.last_result, failure_message @@ -295,11 +430,11 @@ def yields_phases(func): return patch_plugs()(func) -def yields_phases_with(phase_user_defined_state=None, - phase_diagnoses=None): +def yields_phases_with(phase_user_defined_state=None, phase_diagnoses=None): """Apply patch_plugs with no plugs, but add test state modifications.""" - return patch_plugs(phase_user_defined_state=phase_user_defined_state, - 
phase_diagnoses=phase_diagnoses) + return patch_plugs( + phase_user_defined_state=phase_user_defined_state, + phase_diagnoses=phase_diagnoses) def patch_plugs(phase_user_defined_state=None, @@ -333,12 +468,12 @@ def test_my_phase_again(self, my_plug_mock): Args: phase_user_defined_state: If specified, a dictionary that will be added to - the test_state.user_defined_state when handling phases. + the test_state.user_defined_state when handling phases. phase_diagnoses: If specified, must be a list of Diagnosis instances; these - are added to the DiagnosesManager when handling phases. + are added to the DiagnosesManager when handling phases. **mock_plugs: kwargs mapping argument name to be passed to the test case to - a string describing the plug type to mock. The corresponding mock will - be passed to the decorated test case as a keyword argument. + a string describing the plug type to mock. The corresponding mock will be + passed to the decorated test case as a keyword argument. Returns: Function decorator that mocks plugs. 
@@ -348,7 +483,10 @@ def test_my_phase_again(self, my_plug_mock): assert isinstance(diag, diagnoses_lib.Diagnosis) def test_wrapper(test_func): - plug_argspec = inspect.getargspec(test_func) + if six.PY3: + plug_argspec = inspect.getfullargspec(test_func) + else: + plug_argspec = inspect.getargspec(test_func) # pylint: disable=deprecated-method num_defaults = len(plug_argspec.defaults or ()) plug_args = set(plug_argspec.args[1:-num_defaults or None]) @@ -376,18 +514,18 @@ def test_wrapper(test_func): logging.error("Invalid plug type specification %s='%s'", plug_arg_name, plug_fullname) raise - elif issubclass(plug_fullname, plugs.BasePlug): + elif issubclass(plug_fullname, base_plugs.BasePlug): plug_type = plug_fullname else: - raise ValueError('Invalid plug type specification %s="%s"' % ( - plug_arg_name, plug_fullname)) + raise ValueError('Invalid plug type specification %s="%s"' % + (plug_arg_name, plug_fullname)) if issubclass(plug_type, device_wrapping.DeviceWrappingPlug): # We can't strictly spec because calls to attributes are proxied to the # underlying device. plug_mock = mock.MagicMock() else: - plug_mock = mock.create_autospec(plug_type, spec_set=True, - instance=True) + plug_mock = mock.create_autospec( + plug_type, spec_set=True, instance=True) plug_typemap[plug_type] = plug_mock plug_kwargs[plug_arg_name] = plug_mock @@ -395,27 +533,71 @@ def test_wrapper(test_func): # name to match so we don't mess with unittest's TestLoader mechanism. 
@functools.wraps(test_func) def wrapped_test(self): - self.assertIsInstance(self, TestCase, - msg='Must derive from openhtf.util.test.TestCase ' - 'to use yields_phases/patch_plugs.') + self.assertIsInstance( + self, + TestCase, + msg='Must derive from openhtf.util.test.TestCase ' + 'to use yields_phases/patch_plugs.') for phase_or_test, result, failure_message in PhaseOrTestIterator( self, test_func(self, **plug_kwargs), plug_typemap, phase_user_defined_state, phase_diagnoses): logging.info('Ran %s, result: %s', phase_or_test, result) if failure_message: logging.error('Reported error:\n%s', failure_message) + return wrapped_test + return test_wrapper +def _assert_phase_or_test_record(func): + """Decorator for automatically invoking self.assertTestPhases when needed. + + This allows assertions to apply to a single phase or "any phase in the test" + without having to handle the type check themselves. Note that the record, + either PhaseRecord or TestRecord, must be the first argument to the + wrapped assertion method. + + In the case of a TestRecord, the assertion will pass if *any* PhaseRecord in + the TestRecord passes, otherwise the *last* exception raised will be + re-raised. + + Args: + func: the function to wrap. + + Returns: + Function decorator. 
+ """ + + @functools.wraps(func) + def assertion_wrapper(self, phase_or_test_record, *args, **kwargs): + if isinstance(phase_or_test_record, test_record.TestRecord): + exc_info = None + for phase_record in phase_or_test_record.phases: + try: + func(self, phase_record, *args, **kwargs) + break + except Exception: # pylint: disable=broad-except + exc_info = sys.exc_info() + else: + if exc_info: + six.reraise(*exc_info) + elif isinstance(phase_or_test_record, test_record.PhaseRecord): + func(self, phase_or_test_record, *args, **kwargs) + else: + raise InvalidTestError('Expected either a PhaseRecord or TestRecord') + + return assertion_wrapper + + class TestCase(unittest.TestCase): def __init__(self, methodName=None): super(TestCase, self).__init__(methodName=methodName) test_method = getattr(self, methodName) if inspect.isgeneratorfunction(test_method): - raise ValueError( - '%s yields without @openhtf.util.test.yields_phases' % methodName) + raise ValueError('%s yields without @openhtf.util.test.yields_phases' % + methodName) def setUp(self): super(TestCase, self).setUp() @@ -423,40 +605,9 @@ def setUp(self): # attribute will be set to the openhtf.core.test_state.TestState used in the # phase execution. self.last_test_state = None - - def _AssertPhaseOrTestRecord(func): # pylint: disable=no-self-argument,invalid-name - """Decorator for automatically invoking self.assertTestPhases when needed. - - This allows assertions to apply to a single phase or "any phase in the test" - without having to handle the type check themselves. Note that the record, - either PhaseRecord or TestRecord, must be the first argument to the - wrapped assertion method. - - In the case of a TestRecord, the assertion will pass if *any* PhaseRecord in - the TestRecord passes, otherwise the *last* exception raised will be - re-raised. - - Returns: - Function decorator. 
- """ - @functools.wraps(func) - def assertion_wrapper(self, phase_or_test_record, *args, **kwargs): - if isinstance(phase_or_test_record, test_record.TestRecord): - exc_info = None - for phase_record in phase_or_test_record.phases: - try: - func(self, phase_record, *args, **kwargs) - break - except Exception: # pylint: disable=broad-except - exc_info = sys.exc_info() - else: - if exc_info: - raise exc_info[0](exc_info[1]).raise_with_traceback(exc_info[2]) - elif isinstance(phase_or_test_record, test_record.PhaseRecord): - func(self, phase_or_test_record, *args, **kwargs) - else: - raise InvalidTestError('Expected either a PhaseRecord or TestRecord') - return assertion_wrapper + # When a test is yielded, this function is provided to as the test_start + # argument to test.execute. + self.test_start_function = lambda: 'TestDutId' ##### TestRecord Assertions ##### @@ -483,12 +634,16 @@ def assertTestOutcomeCode(self, test_rec, code): ##### PhaseRecord Assertions ##### def assertPhaseContinue(self, phase_record): - self.assertIs( - openhtf.PhaseResult.CONTINUE, phase_record.result.phase_result) + self.assertIs(openhtf.PhaseResult.CONTINUE, + phase_record.result.phase_result) def assertPhaseFailAndContinue(self, phase_record): - self.assertIs( - openhtf.PhaseResult.FAIL_AND_CONTINUE, phase_record.result.phase_result) + self.assertIs(openhtf.PhaseResult.FAIL_AND_CONTINUE, + phase_record.result.phase_result) + + def assertPhaseFailSubtest(self, phase_record): + self.assertIs(openhtf.PhaseResult.FAIL_SUBTEST, + phase_record.result.phase_result) def assertPhaseRepeat(self, phase_record): self.assertIs(openhtf.PhaseResult.REPEAT, phase_record.result.phase_result) @@ -503,9 +658,10 @@ def assertPhaseError(self, phase_record, exc_type=None): self.assertTrue(phase_record.result.raised_exception, 'Phase did not raise an exception') if exc_type: - self.assertIsInstance(phase_record.result.phase_result.exc_val, exc_type, - 'Raised exception %r is not a subclass of %r' % - 
(phase_record.result.phase_result, exc_type)) + self.assertIsInstance( + phase_record.result.phase_result.exc_val, exc_type, + 'Raised exception %r is not a subclass of %r' % + (phase_record.result.phase_result, exc_type)) def assertPhaseTimeout(self, phase_record): self.assertTrue(phase_record.result.is_timeout) @@ -522,9 +678,28 @@ def assertPhaseOutcomeSkip(self, phase_record): def assertPhaseOutcomeError(self, phase_record): self.assertIs(test_record.PhaseOutcome.ERROR, phase_record.outcome) + def assertPhasesOutcomeByName(self, + expected_outcome: test_record.PhaseOutcome, + test_rec: test_record.TestRecord, + *phase_names: Text): + errors = [] # type: List[Text] + for phase_rec in filter_phases_by_names(test_rec.phases, *phase_names): + if phase_rec.outcome is not expected_outcome: + errors.append('Phase "{}" outcome: {}'.format(phase_rec.name, + phase_rec.outcome)) + self.assertFalse( + errors, + msg='Expected phases don\'t all have outcome {}: {}'.format( + expected_outcome.name, errors)) + + def assertPhasesNotRun(self, test_rec, *phase_names): + phases = list(filter_phases_by_names(test_rec.phases, *phase_names)) + self.assertFalse(phases) + ##### Measurement Assertions ##### def assertNotMeasured(self, phase_or_test_record, measurement): + def _check_phase(phase_record, strict=False): if strict: self.assertIn(measurement, phase_record.measurements) @@ -538,11 +713,11 @@ def _check_phase(phase_record, strict=False): if isinstance(phase_or_test_record, test_record.PhaseRecord): _check_phase(phase_or_test_record, True) else: - # Check *all* phases (not *any* like _AssertPhaseOrTestRecord). + # Check *all* phases (not *any* like _assert_phase_or_test_record). 
for phase_record in phase_or_test_record.phases: _check_phase(phase_record) - @_AssertPhaseOrTestRecord + @_assert_phase_or_test_record def assertMeasured(self, phase_record, measurement, value=mock.ANY): self.assertTrue( phase_record.measurements[measurement].measured_value.is_value_set, @@ -554,18 +729,32 @@ def assertMeasured(self, phase_record, measurement, value=mock.ANY): (measurement, value, phase_record.measurements[measurement].measured_value.value)) - @_AssertPhaseOrTestRecord + @_assert_phase_or_test_record def assertMeasurementPass(self, phase_record, measurement): self.assertMeasured(phase_record, measurement) self.assertIs(measurements.Outcome.PASS, phase_record.measurements[measurement].outcome) - @_AssertPhaseOrTestRecord + @_assert_phase_or_test_record def assertMeasurementFail(self, phase_record, measurement): self.assertMeasured(phase_record, measurement) self.assertIs(measurements.Outcome.FAIL, phase_record.measurements[measurement].outcome) + @_assert_phase_or_test_record + def assertAttachment(self, + phase_record, + attachment_name, + expected_contents=mock.ANY): + self.assertIn(attachment_name, phase_record.attachments, + 'Attachment {} not attached.'.format(attachment_name)) + if expected_contents is not mock.ANY: + data = phase_record.attachments[attachment_name].data + self.assertEqual( + expected_contents, data, + 'Attachment {} has wrong value: expected {}, got {}'.format( + attachment_name, expected_contents, data)) + def get_diagnoses_store(self): self.assertIsNotNone(self.last_test_state) return self.last_test_state.diagnoses_manager.store diff --git a/openhtf/util/threads.py b/openhtf/util/threads.py index cd756bef4..38feddb72 100644 --- a/openhtf/util/threads.py +++ b/openhtf/util/threads.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- - """Thread library defining a few helpers.""" import contextlib @@ -26,10 +24,9 @@ import six try: - from six.moves import _thread + from six.moves import _thread # pylint: disable=g-import-not-at-top except ImportError: - from six.moves import _dummy_thread as _thread - + from six.moves import _dummy_thread as _thread # pylint: disable=g-import-not-at-top _LOG = logging.getLogger(__name__) @@ -101,10 +98,12 @@ def _safe_lock_release_py2(rlock): rlock._RLock__count = 0 rlock._RLock__block.release() raise + + # pylint: enable=protected-access -def loop(_=None, force=False): # pylint: disable=invalid-name +def loop(_=None, force=False): """Causes a function to loop indefinitely.""" if not force: raise AttributeError( @@ -113,14 +112,17 @@ def loop(_=None, force=False): # pylint: disable=invalid-name 'and use it as @loop(force=True) for now.') def real_loop(fn): + @functools.wraps(fn) def _proc(*args, **kwargs): """Wrapper to return.""" while True: fn(*args, **kwargs) + _proc.once = fn # way for tests to invoke the function once - # you may need to pass in "self" since this may be unbound. + # you may need to pass in "self" since this may be unbound. return _proc + return real_loop @@ -149,11 +151,15 @@ class is meant to be subclassed. If you were to invoke this with def __init__(self, *args, **kwargs): """Initializer for KillableThread. + The keyword argument `run_with_profiling` is extracted from kwargs. If + True, run this thread with profiling data collection. + Args: - run_with_profiling: Whether to run this thread with profiling data - collection. Must be passed by keyword. + *args: Passed to the base class. + **kwargs: Passed to the base class. 
""" - self._run_with_profiling = kwargs.pop('run_with_profiling', None) + self._run_with_profiling = kwargs.pop('run_with_profiling', + False) # type: bool super(KillableThread, self).__init__(*args, **kwargs) self._running_lock = threading.Lock() self._killed = threading.Event() @@ -180,17 +186,17 @@ def run(self): if self._profiler is not None: self._profiler.disable() - def get_profile_stats(self): + def get_profile_stats(self) -> pstats.Stats: """Returns profile_stats from profiler. Raises if profiling not enabled.""" if self._profiler is not None: return pstats.Stats(self._profiler) raise InvalidUsageError( 'Profiling not enabled via __init__, or thread has not run yet.') - def _is_thread_proc_running(self): + def _is_thread_proc_running(self) -> bool: # Acquire the lock without blocking, though this object is fully implemented # in C, so we cannot specify keyword arguments. - could_acquire = self._running_lock.acquire(0) + could_acquire = self._running_lock.acquire(False) if could_acquire: self._running_lock.release() return False @@ -216,6 +222,8 @@ def _thread_exception(self, exc_type, exc_val, exc_tb): True if the exception should be ignored. The default case ignores the exception raised by the kill functionality. """ + del exc_val # Unused. + del exc_tb # Unused. return exc_type is ThreadTerminationError def kill(self): @@ -237,8 +245,8 @@ def async_raise(self, exc_type): # If the thread has died we don't want to raise an exception so log. if not self.is_alive(): - _LOG.debug('Not raising %s because thread %s (%s) is not alive', - exc_type, self.name, self.ident) + _LOG.debug('Not raising %s because thread %s (%s) is not alive', exc_type, + self.name, self.ident) return result = ctypes.pythonapi.PyThreadState_SetAsyncExc( @@ -249,8 +257,8 @@ def async_raise(self, exc_type): elif result > 1: # Something bad happened, call with a NULL exception to undo. 
ctypes.pythonapi.PyThreadState_SetAsyncExc(self.ident, None) - raise RuntimeError('Error: PyThreadState_SetAsyncExc %s %s (%s) %s' % ( - exc_type, self.name, self.ident, result)) + raise RuntimeError('Error: PyThreadState_SetAsyncExc %s %s (%s) %s' % + (exc_type, self.name, self.ident, result)) class NoneByDefaultThreadLocal(threading.local): @@ -262,12 +270,13 @@ class NoneByDefaultThreadLocal(threading.local): check. """ - def __getattr__(self, _): # pylint: disable=invalid-name + def __getattr__(self, _): return None -def synchronized(func): # pylint: disable=invalid-name +def synchronized(func): """Hold self._lock while executing func.""" + @functools.wraps(func) def synchronized_method(self, *args, **kwargs): """Wrapper to return.""" @@ -282,4 +291,5 @@ def synchronized_method(self, *args, **kwargs): (func.__name__, type(self).__name__, hint)) with self._lock: # pylint: disable=protected-access return func(self, *args, **kwargs) + return synchronized_method diff --git a/openhtf/util/timeouts.py b/openhtf/util/timeouts.py index 5cfce3254..930d56c42 100644 --- a/openhtf/util/timeouts.py +++ b/openhtf/util/timeouts.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """A simple utility to do timeout checking.""" import contextlib @@ -23,6 +21,7 @@ _LOG = logging.getLogger(__name__) + class PolledTimeout(object): """An object which tracks if a timeout has expired.""" @@ -59,7 +58,7 @@ def from_millis(cls, timeout_ms): return cls(timeout_ms / 1000.0) @classmethod - def from_seconds(cls, timeout_s): + def from_seconds(cls, timeout_s) -> 'PolledTimeout': """Create a new PolledTimeout if needed. If timeout_s is already a PolledTimeout, just return it, otherwise create a @@ -67,13 +66,13 @@ def from_seconds(cls, timeout_s): Args: timeout_s: PolledTimeout object, or number of seconds to use for creating - a new one. 
+ a new one. Returns: A PolledTimeout object that will expire in timeout_s seconds, which may be timeout_s itself, or a newly allocated PolledTimeout. """ - if hasattr(timeout_s, 'has_expired'): + if isinstance(timeout_s, cls): return timeout_s return cls(timeout_s) @@ -119,7 +118,7 @@ def remaining_ms(self): # There's now no way to tell if a timeout occurred generically # which sort of sucks (for generic validation fn) -def loop_until_timeout_or_valid(timeout_s, function, validation_fn, sleep_s=1): # pylint: disable=invalid-name +def loop_until_timeout_or_valid(timeout_s, function, validation_fn, sleep_s=1): """Loops until the specified function returns valid or a timeout is reached. Note: The function may return anything which, when passed to validation_fn, @@ -130,11 +129,11 @@ def loop_until_timeout_or_valid(timeout_s, function, validation_fn, sleep_s=1): Args: timeout_s: The number of seconds to wait until a timeout condition is - reached. As a convenience, this accepts None to mean never timeout. Can - also be passed a PolledTimeout object instead of an integer. + reached. As a convenience, this accepts None to mean never timeout. Can + also be passed a PolledTimeout object instead of an integer. function: The function to call each iteration. validation_fn: The validation function called on the function result to - determine whether to keep looping. + determine whether to keep looping. sleep_s: The number of seconds to wait after calling the function. Returns: @@ -151,7 +150,7 @@ def loop_until_timeout_or_valid(timeout_s, function, validation_fn, sleep_s=1): time.sleep(sleep_s) -def loop_until_timeout_or_true(timeout_s, function, sleep_s=1): # pylint: disable=invalid-name +def loop_until_timeout_or_true(timeout_s, function, sleep_s=1): """Loops until the specified function returns True or a timeout is reached. Note: The function may return anything which evaluates to implicit True. 
This @@ -161,8 +160,8 @@ def loop_until_timeout_or_true(timeout_s, function, sleep_s=1): # pylint: disab Args: timeout_s: The number of seconds to wait until a timeout condition is - reached. As a convenience, this accepts None to mean never timeout. Can - also be passed a PolledTimeout object instead of an integer. + reached. As a convenience, this accepts None to mean never timeout. Can + also be passed a PolledTimeout object instead of an integer. function: The function to call each iteration. sleep_s: The number of seconds to wait after calling the function. @@ -172,21 +171,21 @@ def loop_until_timeout_or_true(timeout_s, function, sleep_s=1): # pylint: disab return loop_until_timeout_or_valid(timeout_s, function, lambda x: x, sleep_s) -def loop_until_timeout_or_not_none(timeout_s, function, sleep_s=1): # pylint: disable=invalid-name +def loop_until_timeout_or_not_none(timeout_s, function, sleep_s=1): """Loops until the specified function returns non-None or until a timeout. Args: timeout_s: The number of seconds to wait until a timeout condition is - reached. As a convenience, this accepts None to mean never timeout. Can - also be passed a PolledTimeout object instead of an integer. + reached. As a convenience, this accepts None to mean never timeout. Can + also be passed a PolledTimeout object instead of an integer. function: The function to call each iteration. sleep_s: The number of seconds to wait after calling the function. Returns: Whatever the function returned last. """ - return loop_until_timeout_or_valid( - timeout_s, function, lambda x: x is not None, sleep_s) + return loop_until_timeout_or_valid(timeout_s, function, + lambda x: x is not None, sleep_s) def loop_until_true_else_raise(timeout_s, @@ -198,23 +197,25 @@ def loop_until_true_else_raise(timeout_s, Args: timeout_s: The number of seconds to wait until a timeout condition is - reached. As a convenience, this accepts None to mean never timeout. 
Can - also be passed a PolledTimeout object instead of an integer. + reached. As a convenience, this accepts None to mean never timeout. Can + also be passed a PolledTimeout object instead of an integer. function: The function to call each iteration. invert: If True, wait for the callable to return falsey instead of truthy. message: Optional custom error message to use on a timeout. sleep_s: Seconds to sleep between call attempts. + Raises: + RuntimeError: if the timeout is reached before the function returns truthy. + Returns: The final return value of the function. - - Raises: - RuntimeError if the timeout is reached before the function returns truthy. """ + def validate(x): return bool(x) != invert - result = loop_until_timeout_or_valid(timeout_s, function, validate, sleep_s=1) + result = loop_until_timeout_or_valid( + timeout_s, function, validate, sleep_s=sleep_s) if validate(result): return result @@ -224,12 +225,11 @@ def validate(x): name = '(unknown)' if hasattr(function, '__name__'): name = function.__name__ - elif (isinstance(function, functools.partial) - and hasattr(function.func, '__name__')): + elif (isinstance(function, functools.partial) and + hasattr(function.func, '__name__')): name = function.func.__name__ - raise RuntimeError( - 'Function %s failed to return %s within %d seconds.' - % (name, 'falsey' if invert else 'truthy', timeout_s)) + raise RuntimeError('Function %s failed to return %s within %d seconds.' % + (name, 'falsey' if invert else 'truthy', timeout_s)) class Interval(object): @@ -241,7 +241,7 @@ def __init__(self, method, stop_if_false=False): Args: method: A callable to execute, it should take no arguments. stop_if_false: If True, the interval will exit if the method returns - False. + False. """ self.method = method self.stopped = threading.Event() @@ -259,6 +259,7 @@ def start(self, interval_s): Args: interval_s: The amount of time between executions of the method. + Returns: False if the interval was already running. 
""" @@ -288,7 +289,8 @@ def stop(self, timeout_s=None): Args: timeout_s: The time in seconds to wait on the thread to finish. By - default it's forever. + default it's forever. + Returns: False if a timeout was provided and we timed out. """ @@ -304,6 +306,7 @@ def join(self, timeout_s=None): Args: timeout_s: The time in seconds to wait, defaults to forever. + Returns: True if the interval is still running and we reached the timeout. """ @@ -313,13 +316,14 @@ def join(self, timeout_s=None): return self.running -def execute_forever(method, interval_s): # pylint: disable=invalid-name +def execute_forever(method, interval_s): """Executes a method forever at the specified interval. Args: method: The callable to execute. interval_s: The number of seconds to start the execution after each method - finishes. + finishes. + Returns: An Interval object. """ @@ -328,13 +332,14 @@ def execute_forever(method, interval_s): # pylint: disable=invalid-name return interval -def execute_until_false(method, interval_s): # pylint: disable=invalid-name +def execute_until_false(method, interval_s): """Executes a method forever until the method returns a false value. Args: method: The callable to execute. interval_s: The number of seconds to start the execution after each method - finishes. + finishes. + Returns: An Interval object. 
""" @@ -343,22 +348,29 @@ def execute_until_false(method, interval_s): # pylint: disable=invalid-name return interval -# pylint: disable=invalid-name -def retry_until_true_or_limit_reached(method, limit, sleep_s=1, +def retry_until_true_or_limit_reached(method, + limit, + sleep_s=1, catch_exceptions=()): """Executes a method until the retry limit is hit or True is returned.""" - return retry_until_valid_or_limit_reached( - method, limit, lambda x: x, sleep_s, catch_exceptions) + return retry_until_valid_or_limit_reached(method, limit, lambda x: x, sleep_s, + catch_exceptions) -def retry_until_not_none_or_limit_reached(method, limit, sleep_s=1, +def retry_until_not_none_or_limit_reached(method, + limit, + sleep_s=1, catch_exceptions=()): """Executes a method until the retry limit is hit or not None is returned.""" - return retry_until_valid_or_limit_reached( - method, limit, lambda x: x is not None, sleep_s, catch_exceptions) + return retry_until_valid_or_limit_reached(method, limit, + lambda x: x is not None, sleep_s, + catch_exceptions) -def retry_until_valid_or_limit_reached(method, limit, validation_fn, sleep_s=1, +def retry_until_valid_or_limit_reached(method, + limit, + validation_fn, + sleep_s=1, catch_exceptions=()): """Executes a method until the retry limit or validation_fn returns True. @@ -370,9 +382,10 @@ def retry_until_valid_or_limit_reached(method, limit, validation_fn, sleep_s=1, method: The method to execute should take no arguments. limit: The number of times to try this method. Must be >0. validation_fn: The validation function called on the function result to - determine whether to keep looping. + determine whether to keep looping. sleep_s: The time to sleep in between invocations. catch_exceptions: Tuple of exception types to catch and count as failures. + Returns: Whatever the method last returned, implicit False would indicate the method never succeeded. 
@@ -394,8 +407,6 @@ def _execute_method(helper): result = _execute_method(helper) return result -# pylint: disable=invalid-name - @contextlib.contextmanager def take_at_least_n_seconds(time_s): @@ -410,6 +421,7 @@ def take_at_least_n_seconds(time_s): Args: time_s: The number of seconds this block should take. If it doesn't take at least this time, then this method blocks during __exit__. + Yields: To do some actions then on completion waits the remaining time. """ @@ -429,6 +441,7 @@ def take_at_most_n_seconds(time_s, func, *args, **kwargs): func: Function to call. *args: Arguments to call the function with. **kwargs: Keyword arguments to call the function with. + Returns: True if the function finished in less than time_s seconds. """ @@ -463,6 +476,7 @@ def target(): func(*args, **kwargs) except Exception: # pylint: disable=broad-except _LOG.exception('Error executing %s after %s expires.', func, timeout) + if timeout.remaining is not None: thread = threading.Thread(target=target) thread.start() diff --git a/openhtf/util/units.py b/openhtf/util/units.py index b49e191fa..c4bf4216a 100644 --- a/openhtf/util/units.py +++ b/openhtf/util/units.py @@ -41,7 +41,14 @@ import collections -UnitDescriptor = collections.namedtuple('UnitDescriptor', 'name code suffix') +class UnitDescriptor( + collections.namedtuple('UnitDescriptor', [ + 'name', + 'code', + 'suffix', + ])): + pass + ALL_UNITS = [] @@ -4245,8 +4252,10 @@ # pylint: enable=line-too-long + class UnitLookup(object): """Facilitates user-friendly access to units.""" + def __init__(self, lookup): self._lookup = lookup diff --git a/openhtf/util/validators.py b/openhtf/util/validators.py index c94dd0816..7b43fe187 100644 --- a/openhtf/util/validators.py +++ b/openhtf/util/validators.py @@ -5,8 +5,9 @@ module, and will typically be a type, instances of which are callable: from openhtf.util import validators + from openhtf.util import measurements - class MyLessThanValidator(validators.ValidatorBase): + class 
MyLessThanValidator(ValidatorBase): def __init__(self, limit): self.limit = limit @@ -34,7 +35,7 @@ def LessThan4(value): return value < 4 @measurements.measures( - measurements.Measurement('my_measurement').with_validator(LessThan4)) + measurements.Measurement('my_measurement').with_validator(LessThan4)) def MyPhase(test): test.measurements.my_measurement = 5 # Will also 'FAIL' @@ -79,19 +80,20 @@ def has_validator(name): def create_validator(name, *args, **kwargs): return _VALIDATORS[name](*args, **kwargs) + _identity = lambda x: x class ValidatorBase(with_metaclass(abc.ABCMeta, object)): - @abc.abstractmethod + @abc.abstractmethod def __call__(self, value): """Should validate value, returning a boolean result.""" class RangeValidatorBase(with_metaclass(abc.ABCMeta, ValidatorBase)): - @abc.abstractproperty + @abc.abstractproperty def minimum(self): """Should return the minimum, inclusive value of the range.""" @@ -104,6 +106,7 @@ def maximum(self): class AllInRangeValidator(ValidatorBase): def __init__(self, min_value, max_value): + super(AllInRangeValidator, self).__init__() self.min_value = min_value self.max_value = max_value @@ -114,11 +117,13 @@ def __call__(self, values): class AllEqualsValidator(ValidatorBase): def __init__(self, spec): + super(AllEqualsValidator, self).__init__() self.spec = spec def __call__(self, values): return all([value == self.spec for value in values]) + register(AllInRangeValidator, name='all_in_range') register(AllEqualsValidator, name='all_equals') @@ -126,13 +131,13 @@ def __call__(self, values): class InRange(RangeValidatorBase): """Validator to verify a numeric value is within a range.""" - def __init__(self, minimum=None, maximum=None, type=None): + def __init__(self, minimum=None, maximum=None, type=None): # pylint: disable=redefined-builtin + super(InRange, self).__init__() if minimum is None and maximum is None: raise ValueError('Must specify minimum, maximum, or both') - if (minimum is not None and maximum is not None - 
and isinstance(minimum, numbers.Number) - and isinstance(maximum, numbers.Number) - and minimum > maximum): + if (minimum is not None and maximum is not None and + isinstance(minimum, numbers.Number) and + isinstance(maximum, numbers.Number) and minimum > maximum): raise ValueError('Minimum cannot be greater than maximum') self._minimum = minimum self._maximum = maximum @@ -178,18 +183,19 @@ def __str__(self): return 'x <= %s' % self._maximum def __eq__(self, other): - return (isinstance(other, type(self)) and - self.minimum == other.minimum and self.maximum == other.maximum) + return (isinstance(other, type(self)) and self.minimum == other.minimum and + self.maximum == other.maximum) def __ne__(self, other): return not self == other + in_range = InRange # pylint: disable=invalid-name register(in_range, name='in_range') @register -def equals(value, type=None): +def equals(value, type=None): # pylint: disable=redefined-builtin if isinstance(value, numbers.Number): return InRange(minimum=value, maximum=value, type=type) elif isinstance(value, six.string_types): @@ -203,7 +209,7 @@ def equals(value, type=None): class Equals(object): """Validator to verify an object is equal to the expected value.""" - def __init__(self, expected, type=None): + def __init__(self, expected, type=None): # pylint: disable=redefined-builtin self._expected = expected self._type = type @@ -254,6 +260,7 @@ class WithinPercent(RangeValidatorBase): """Validates that a number is within percent of a value.""" def __init__(self, expected, percent): + super(WithinPercent, self).__init__() if percent < 0: raise ValueError('percent argument is {}, must be >0'.format(percent)) self.expected = expected @@ -279,8 +286,7 @@ def __str__(self): def __eq__(self, other): return (isinstance(other, type(self)) and - self.expected == other.expected and - self.percent == other.percent) + self.expected == other.expected and self.percent == other.percent) def __ne__(self, other): return not self == other @@ 
-291,39 +297,45 @@ def within_percent(expected, percent): return WithinPercent(expected, percent) -class WithinTolerance(RangeValidatorBase): - """Validates that a number is within a given tolerance of a value.""" +class DimensionPivot(ValidatorBase): + """Runs a validator on each actual value of a dimensioned measurement.""" - def __init__(self, expected, tolerance): - if tolerance < 0: - raise ValueError( - 'tolerance argument is {}, must be >0'.format(tolerance)) - self.expected = expected - self.tolerance = tolerance + def __init__(self, sub_validator): + super(DimensionPivot, self).__init__() + self._sub_validator = sub_validator - @property - def minimum(self): - return self.expected - self.tolerance + def __call__(self, dimensioned_value): + return all(self._sub_validator(row[-1]) for row in dimensioned_value) - @property - def maximum(self): - return self.expected + self.tolerance + def __str__(self): + return 'All values pass: {}'.format(str(self._sub_validator)) - def __call__(self, value): - return self.minimum <= value <= self.maximum - def __str__(self): - return "'x' is within {} of {}".format(self.tolerance, self.expected) +@register +def dimension_pivot_validate(sub_validator): + return DimensionPivot(sub_validator) - def __eq__(self, other): - return (isinstance(other, type(self)) and - self.expected == other.expected and - self.tolerance == other.tolerance) - def __ne__(self, other): - return not self == other +class ConsistentEndDimensionPivot(ValidatorBase): + """If any rows validate, all following rows must also validate.""" + + def __init__(self, sub_validator): + super(ConsistentEndDimensionPivot, self).__init__() + self._sub_validator = sub_validator + + def __call__(self, dimensioned_value): + for index, row in enumerate(dimensioned_value): + if self._sub_validator(row[-1]): + i = index + break + else: + return False + return all(self._sub_validator(rest[-1]) for rest in dimensioned_value[i:]) + + def __str__(self): + return 'Once pass, 
rest must also pass: {}'.format(str(self._sub_validator)) @register -def within_tolerance(expected, tolerance): - return WithinTolerance(expected, tolerance) +def consistent_end_dimension_pivot_validate(sub_validator): + return ConsistentEndDimensionPivot(sub_validator) diff --git a/openhtf/util/xmlrpcutil.py b/openhtf/util/xmlrpcutil.py index 75bd04aa4..7e9f81bb9 100644 --- a/openhtf/util/xmlrpcutil.py +++ b/openhtf/util/xmlrpcutil.py @@ -11,29 +11,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Utility helpers for xmlrpclib.""" import http.client -import xmlrpc.server import socketserver import sys import threading import xmlrpc.client -import collections -from six.moves import collections_abc +import xmlrpc.server +import six +from six.moves import collections_abc DEFAULT_PROXY_TIMEOUT_S = 3 # https://github.com/PythonCharmers/python-future/issues/280 +# pylint: disable=g-import-not-at-top,g-importing-member if sys.version_info[0] < 3: - from SimpleXMLRPCServer import SimpleXMLRPCServer + from SimpleXMLRPCServer import SimpleXMLRPCServer # pytype: disable=import-error else: - from xmlrpc.server import SimpleXMLRPCServer as SimpleXMLRPCServer + from xmlrpc.server import SimpleXMLRPCServer # pytype: disable=import-error +# pylint: enable=g-import-not-at-top,g-importing-member -class TimeoutHTTPConnection(http.client.HTTPConnection): +class TimeoutHTTPConnection(http.client.HTTPConnection): # pylint: disable=missing-class-docstring + def __init__(self, timeout_s, *args, **kwargs): http.client.HTTPConnection.__init__(self, *args, **kwargs) self.timeout_s = timeout_s @@ -47,7 +49,8 @@ def connect(self): self.sock.settimeout(self.timeout_s) -class TimeoutTransport(xmlrpc.client.Transport): +class TimeoutTransport(xmlrpc.client.Transport): # pylint: disable=missing-class-docstring + def __init__(self, timeout_s, *args, **kwargs): 
xmlrpc.client.Transport.__init__(self, *args, **kwargs) self._connection = None @@ -70,14 +73,18 @@ class BaseServerProxy(xmlrpc.client.ServerProxy, object): class TimeoutProxyMixin(object): """Timeouts for ServerProxy objects.""" + def __init__(self, *args, **kwargs): - super(TimeoutProxyMixin, self).__init__( + kwargs.update( transport=TimeoutTransport( - kwargs.pop('timeout_s', DEFAULT_PROXY_TIMEOUT_S)), - *args, **kwargs) + kwargs.pop('timeout_s', DEFAULT_PROXY_TIMEOUT_S))) + super(TimeoutProxyMixin, self).__init__(*args, **kwargs) def __settimeout(self, timeout_s): - self.__transport.settimeout(timeout_s) + if six.PY3: + self._transport.settimeout(timeout_s) # pytype: disable=attribute-error + else: + self.__transport.settimeout(timeout_s) # pytype: disable=attribute-error class TimeoutProxyServer(TimeoutProxyMixin, BaseServerProxy): @@ -86,17 +93,19 @@ class TimeoutProxyServer(TimeoutProxyMixin, BaseServerProxy): class LockedProxyMixin(object): """A ServerProxy that locks calls to methods.""" + def __init__(self, *args, **kwargs): super(LockedProxyMixin, self).__init__(*args, **kwargs) self._lock = threading.Lock() def __getattr__(self, attr): - method = super(LockedProxyMixin, self).__getattr__(attr) + method = super(LockedProxyMixin, self).__getattr__(attr) # pytype: disable=attribute-error if isinstance(method, collections_abc.Callable): # xmlrpc doesn't support **kwargs, so only accept *args. def _wrapper(*args): with self._lock: return method(*args) + # functools.wraps() doesn't work with _Method internal type within # xmlrpclib. We only care about the name anyway, so manually set it. 
_wrapper.__name__ = attr @@ -108,7 +117,7 @@ class LockedTimeoutProxy(TimeoutProxyMixin, LockedProxyMixin, BaseServerProxy): """ServerProxy with additional features we use.""" -class SimpleThreadedXmlRpcServer( - socketserver.ThreadingMixIn, SimpleXMLRPCServer): +class SimpleThreadedXmlRpcServer(socketserver.ThreadingMixIn, + SimpleXMLRPCServer): """Helper for handling multiple simultaneous RPCs in threads.""" daemon_threads = True diff --git a/pylint_plugins/mutablerecords_plugin.py b/pylint_plugins/mutablerecords_plugin.py deleted file mode 100644 index 13a5df950..000000000 --- a/pylint_plugins/mutablerecords_plugin.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2016 Google Inc. All Rights Reserved. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import astroid - -from astroid import MANAGER - - -def __init__(self): - pass - - -def mutable_record_transform(cls): - """Transform mutable records usage by updating locals.""" - if not (len(cls.bases) > 0 - and isinstance(cls.bases[0], astroid.Call) - and cls.bases[0].func.as_string() == 'mutablerecords.Record'): - return - - try: - # Add required attributes. - if len(cls.bases[0].args) >= 2: - for a in cls.bases[0].args[1].elts: - cls.locals[a] = [None] - - # Add optional attributes. 
- if len(cls.bases[0].args) >= 3: - for a,b in cls.bases[0].args[2].items: - cls.locals[a.value] = [None] - - except: - raise SyntaxError('Invalid mutablerecords syntax') - - -def register(linter): - """Register transform with the linter.""" - MANAGER.register_transform(astroid.ClassDef, mutable_record_transform) diff --git a/setup.py b/setup.py index 04c8c9fe9..69429901f 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,7 @@ import subprocess import sys +# pylint: disable=g-importing-member,g-bad-import-order from distutils.command.build import build from distutils.command.clean import clean from distutils.cmd import Command @@ -54,7 +55,8 @@ def initialize_options(self): self.skip_proto = False try: prefix = subprocess.check_output( - 'pkg-config --variable prefix protobuf'.split()).strip().decode('utf-8') + 'pkg-config --variable prefix protobuf'.split()).strip().decode( + 'utf-8') except (subprocess.CalledProcessError, OSError): if platform.system() == 'Linux': # Default to /usr? @@ -69,11 +71,11 @@ def initialize_options(self): maybe_protoc = os.path.join(prefix, 'bin', 'protoc') if os.path.isfile(maybe_protoc) and os.access(maybe_protoc, os.X_OK): - self.protoc = maybe_protoc + self.protoc = maybe_protoc else: - print('Warning: protoc not found at %s' % maybe_protoc) - print('setup will attempt to run protoc with no prefix.') - self.protoc = 'protoc' + print('Warning: protoc not found at %s' % maybe_protoc) + print('setup will attempt to run protoc with no prefix.') + self.protoc = 'protoc' self.protodir = os.path.join(prefix, 'include') self.indir = os.getcwd() @@ -95,9 +97,12 @@ def run(self): print('Attempting to build proto files:\n%s' % '\n'.join(protos)) cmd = [ self.protoc, - '--proto_path', self.indir, - '--proto_path', self.protodir, - '--python_out', self.outdir, + '--proto_path', + self.indir, + '--proto_path', + self.protodir, + '--python_out', + self.outdir, ] + protos try: subprocess.check_call(cmd) @@ -123,7 +128,6 @@ def run(self): # Make 
building protos part of building overall. build.sub_commands.insert(0, ('build_proto', None)) - INSTALL_REQUIRES = [ 'attrs>=19.3.0', 'colorama>=0.3.9,<1.0', @@ -138,10 +142,11 @@ def run(self): 'sockjs-tornado>=1.0.3,<2.0', 'tornado>=4.3,<5.0', 'six>=1.13.0', + 'typing-extensions', ] -class PyTestCommand(test): +class PyTestCommand(test): # pylint: disable=missing-class-docstring # Derived from # https://github.com/chainreactionmfg/cara/blob/master/setup.py user_options = [ @@ -163,7 +168,7 @@ def finalize_options(self): def run_tests(self): self.run_command('build_proto') - import pytest + import pytest # pylint: disable=g-import-not-at-top cov = [] if self.pytest_cov is not None: outputs = [] @@ -185,13 +190,14 @@ def run_tests(self): maintainer='Joe Ethier', maintainer_email='jethier@google.com', packages=find_packages(), - package_data={'openhtf': ['output/proto/*.proto', - 'output/web_gui/dist/*.*', - 'output/web_gui/dist/css/*', - 'output/web_gui/dist/js/*', - 'output/web_gui/dist/img/*', - 'output/web_gui/*.*']}, - python_requires='>=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*', + package_data={ + 'openhtf': [ + 'output/proto/*.proto', 'output/web_gui/dist/*.*', + 'output/web_gui/dist/css/*', 'output/web_gui/dist/js/*', + 'output/web_gui/dist/img/*', 'output/web_gui/*.*' + ] + }, + python_requires='>=3.6', cmdclass={ 'build_proto': BuildProtoCommand, 'clean': CleanCommand, @@ -203,17 +209,12 @@ def run_tests(self): 'libusb1>=1.3.0,<2.0', 'M2Crypto>=0.22.3,<1.0', ], - 'update_units': [ - 'xlrd>=1.0.0,<2.0', - ], - 'serial_collection_plug': [ - 'pyserial>=3.3.0,<4.0', - ], - 'examples': [ - 'pandas>=0.22.0', - ], + 'update_units': ['xlrd>=1.0.0,<2.0',], + 'serial_collection_plug': ['pyserial>=3.3.0,<4.0',], + 'examples': ['pandas>=0.22.0',], }, tests_require=[ + 'absl-py>=0.10.0', 'mock>=2.0.0', # Remove max version here after we drop Python 2 support. 
'pandas>=0.22.0,<0.25.0', diff --git a/test/capture_source_test.py b/test/capture_source_test.py index b1e1b0728..3f25cc9e1 100644 --- a/test/capture_source_test.py +++ b/test/capture_source_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import typing import unittest import openhtf as htf @@ -29,12 +29,14 @@ class BasicCodeCaptureTest(unittest.TestCase): def testCaptured(self): htf.conf.load(capture_source=True) test = htf.Test(phase) - phase_descriptor = list(test.descriptor.phase_group)[0] + phase_descriptor = typing.cast(htf.PhaseDescriptor, + test.descriptor.phase_sequence.nodes[0]) self.assertEqual(phase_descriptor.code_info.name, phase.__name__) @htf.conf.save_and_restore def testNotCaptured(self): htf.conf.load(capture_source=False) test = htf.Test(phase) - phase_descriptor = list(test.descriptor.phase_group)[0] + phase_descriptor = typing.cast(htf.PhaseDescriptor, + test.descriptor.phase_sequence.nodes[0]) self.assertEqual(phase_descriptor.code_info.name, '') diff --git a/test/core/diagnoses_test.py b/test/core/diagnoses_test.py index f5027badf..76ac585de 100644 --- a/test/core/diagnoses_test.py +++ b/test/core/diagnoses_test.py @@ -1,10 +1,6 @@ -# Lint as: python2, python3 +# Lint as: python3 """Tests for Diagnoses in OpenHTF.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import time import unittest @@ -31,12 +27,12 @@ class DiagTestError(Exception): pass -@htf.TestPhase() +@htf.PhaseOptions() def basic_phase(): pass -@htf.TestPhase() +@htf.PhaseOptions() def exception_phase(): raise PhaseError('it broke') @@ -120,7 +116,7 @@ def get_mock_diag(**kwargs): kwargs['return_value'] = None mock_diag = mock.MagicMock(**kwargs) return diagnoses_lib.PhaseDiagnoser( - OkayResult, name='mock_diag', run_func=mock_diag) + OkayResult, name='mock_diag', run_func=mock_diag), mock_diag class 
DupeResultA(htf.DiagResultEnum): @@ -183,7 +179,7 @@ def b2(): pass with self.assertRaises(diagnoses_lib.DuplicateResultError): - diagnoses_lib.check_for_duplicate_results([a1, b2], []) + diagnoses_lib.check_for_duplicate_results(iter([a1, b2]), []) def test_phase_phase_same_result(self): @@ -195,7 +191,7 @@ def a1(): def a2(): pass - diagnoses_lib.check_for_duplicate_results([a1, a2], []) + diagnoses_lib.check_for_duplicate_results(iter([a1, a2]), []) def test_phase_phase_same_diagnoser(self): @@ -207,7 +203,7 @@ def a1(): def a2(): pass - diagnoses_lib.check_for_duplicate_results([a1, a2], []) + diagnoses_lib.check_for_duplicate_results(iter([a1, a2]), []) def test_phase_test_dupe(self): @@ -216,7 +212,7 @@ def a1(): pass with self.assertRaises(diagnoses_lib.DuplicateResultError): - diagnoses_lib.check_for_duplicate_results([a1], [dupe_b_test_diag]) + diagnoses_lib.check_for_duplicate_results(iter([a1]), [dupe_b_test_diag]) def test_phase_test_same_result(self): @@ -224,20 +220,20 @@ def test_phase_test_same_result(self): def a1(): pass - diagnoses_lib.check_for_duplicate_results([a1], [dupe_a2_test_diag]) + diagnoses_lib.check_for_duplicate_results(iter([a1]), [dupe_a2_test_diag]) def test_test_test_dupe(self): with self.assertRaises(diagnoses_lib.DuplicateResultError): diagnoses_lib.check_for_duplicate_results( - [], [dupe_a_test_diag, dupe_b_test_diag]) + iter([]), [dupe_a_test_diag, dupe_b_test_diag]) def test_test_test_same_result(self): diagnoses_lib.check_for_duplicate_results( - [], [dupe_a_test_diag, dupe_a2_test_diag]) + iter([]), [dupe_a_test_diag, dupe_a2_test_diag]) def test_test_test_same_diagnoser(self): diagnoses_lib.check_for_duplicate_results( - [], [dupe_a_test_diag, dupe_a_test_diag]) + iter([]), [dupe_a_test_diag, dupe_a_test_diag]) class CheckDiagnosersTest(unittest.TestCase): @@ -249,18 +245,18 @@ class NotDiagnoser(object): with self.assertRaises(diagnoses_lib.DiagnoserError) as cm: diagnoses_lib._check_diagnoser(NotDiagnoser(), - 
diagnoses_lib.PhaseDiagnoser) - self.assertEqual('Diagnoser "NotDiagnoser" is not a PhaseDiagnoser.', + diagnoses_lib.BasePhaseDiagnoser) # pytype: disable=wrong-arg-types + self.assertEqual('Diagnoser "NotDiagnoser" is not a BasePhaseDiagnoser.', cm.exception.args[0]) def test_result_type_not_set(self): - @htf.PhaseDiagnoser(None) + @htf.PhaseDiagnoser(None) # pytype: disable=wrong-arg-types def bad_diag(phase_rec): del phase_rec # Unused. with self.assertRaises(diagnoses_lib.DiagnoserError) as cm: - diagnoses_lib._check_diagnoser(bad_diag, diagnoses_lib.PhaseDiagnoser) + diagnoses_lib._check_diagnoser(bad_diag, diagnoses_lib.BasePhaseDiagnoser) self.assertEqual('Diagnoser "bad_diag" does not have a result_type set.', cm.exception.args[0]) @@ -269,32 +265,34 @@ def test_result_type_not_result_enum(self): class BadEnum(str, enum.Enum): BAD = 'bad' - @htf.PhaseDiagnoser(BadEnum) + @htf.PhaseDiagnoser(BadEnum) # pytype: disable=wrong-arg-types def bad_enum_diag(phase_rec): del phase_rec # Unused. 
with self.assertRaises(diagnoses_lib.DiagnoserError) as cm: diagnoses_lib._check_diagnoser(bad_enum_diag, - diagnoses_lib.PhaseDiagnoser) + diagnoses_lib.BasePhaseDiagnoser) self.assertEqual( 'Diagnoser "bad_enum_diag" result_type "BadEnum" does not inherit ' 'from DiagResultEnum.', cm.exception.args[0]) def test_pass(self): diagnoses_lib._check_diagnoser(basic_wrapper_phase_diagnoser, - diagnoses_lib.PhaseDiagnoser) + diagnoses_lib.BasePhaseDiagnoser) def test_inomplete_phase_diagnoser(self): incomplete = htf.PhaseDiagnoser(BadResult, 'NotFinished') with self.assertRaises(diagnoses_lib.DiagnoserError): - diagnoses_lib._check_diagnoser(incomplete, diagnoses_lib.PhaseDiagnoser) + diagnoses_lib._check_diagnoser(incomplete, + diagnoses_lib.BasePhaseDiagnoser) def test_inomplete_test_diagnoser(self): incomplete = htf.TestDiagnoser(BadResult, 'NotFinished') with self.assertRaises(diagnoses_lib.DiagnoserError): - diagnoses_lib._check_diagnoser(incomplete, diagnoses_lib.TestDiagnoser) + diagnoses_lib._check_diagnoser(incomplete, + diagnoses_lib.BaseTestDiagnoser) class DiagnoserTest(unittest.TestCase): @@ -362,8 +360,9 @@ def reuse(test_record_, store): with self.assertRaises(diagnoses_lib.DiagnoserError): @reuse - def unused_diag(test_record_): + def unused_diag(test_record_, store): del test_record_ # Unused. + del store # Unused. 
return None @@ -427,7 +426,7 @@ def totally_not_a_diagnoser(): pass with self.assertRaises(diagnoses_lib.DiagnoserError): - _ = htf.diagnose(totally_not_a_diagnoser)(basic_phase) + _ = htf.diagnose(totally_not_a_diagnoser)(basic_phase) # pytype: disable=wrong-arg-types def test_test_diagnoses__check_diagnosers_fail(self): @@ -436,7 +435,7 @@ def totally_not_a_diagnoser(): test = htf.Test(basic_phase) with self.assertRaises(diagnoses_lib.DiagnoserError): - test.add_test_diagnosers(totally_not_a_diagnoser) + test.add_test_diagnosers(totally_not_a_diagnoser) # pytype: disable=wrong-arg-types @htf_test.yields_phases def test_phase_no_diagnoses(self): @@ -1013,7 +1012,7 @@ def test_phase_diagnoser__phase_error__diag_fail(self): @htf_test.yields_phases def test_phase_diagnoser__phase_skip__no_diagnosers_run(self): - fake_diag = get_mock_diag() + fake_diag, mock_func = get_mock_diag() @htf.diagnose(fake_diag) def skip_phase(): @@ -1025,11 +1024,11 @@ def skip_phase(): self.assertPhaseOutcomeSkip(phase_rec) self.assertEqual([], phase_rec.diagnosis_results) self.assertEqual([], phase_rec.failure_diagnosis_results) - fake_diag._run_func.assert_not_called() + mock_func.assert_not_called() @htf_test.yields_phases def test_phase_diagnoser__phase_repeat__no_diagnosers_run(self): - fake_diag = get_mock_diag() + fake_diag, mock_func = get_mock_diag() @htf.diagnose(fake_diag) def repeat_phase(): @@ -1041,11 +1040,11 @@ def repeat_phase(): self.assertPhaseOutcomeSkip(phase_rec) self.assertEqual([], phase_rec.diagnosis_results) self.assertEqual([], phase_rec.failure_diagnosis_results) - fake_diag._run_func.assert_not_called() + mock_func.assert_not_called() @htf_test.yields_phases def test_phase_diagnoser__timeout__diagnoser_run(self): - fake_diag = get_mock_diag() + fake_diag, mock_func = get_mock_diag() @htf.diagnose(fake_diag) @htf.PhaseOptions(timeout_s=0) @@ -1056,7 +1055,7 @@ def phase(): phase_rec = yield phase self.assertPhaseTimeout(phase_rec) - 
fake_diag._run_func.assert_called_once() + mock_func.assert_called_once() @htf_test.yields_phases def test_test_diagnoser__exception(self): @@ -1221,31 +1220,31 @@ def test_test_record_diagnosis_serialization(self): converted = data.convert_to_base_types(test_rec) self.assertEqual([ { - 'result': 'OKAY', + 'result': 'okay', 'description': 'Everything is okay.', 'component': None, 'priority': 'NORMAL', }, { - 'result': 'ONE', + 'result': 'bad_one', 'description': 'Oh no!', 'component': None, 'priority': 'NORMAL', 'is_failure': True, }, { - 'result': 'TEST_OK', + 'result': 'test_ok', 'description': 'Okay', 'component': None, 'priority': 'NORMAL', }, ], converted['diagnoses']) - self.assertEqual(['OKAY'], converted['phases'][1]['diagnosis_results']) + self.assertEqual(['okay'], converted['phases'][1]['diagnosis_results']) self.assertEqual([], converted['phases'][1]['failure_diagnosis_results']) self.assertEqual([], converted['phases'][2]['diagnosis_results']) - self.assertEqual(['ONE'], + self.assertEqual(['bad_one'], converted['phases'][2]['failure_diagnosis_results']) @htf_test.yields_phases @@ -1265,8 +1264,8 @@ def check_record_diagnoser(phase_record): 'pass_measure', is_value_set=True, stored_value=True, - _cached_value=True), - _cached=mock.ANY), phase_record.measurements['pass_measure']) + cached_value=True), + cached=mock.ANY), phase_record.measurements['pass_measure']) self.assertEqual( htf.Measurement( 'fail_measure', @@ -1275,9 +1274,9 @@ def check_record_diagnoser(phase_record): 'fail_measure', is_value_set=True, stored_value=False, - _cached_value=False), + cached_value=False), validators=[is_true], - _cached=mock.ANY), phase_record.measurements['fail_measure']) + cached=mock.ANY), phase_record.measurements['fail_measure']) return None @htf.diagnose(check_record_diagnoser) diff --git a/test/core/exe_test.py b/test/core/exe_test.py index 2bfe81ade..13be16cdd 100644 --- a/test/core/exe_test.py +++ b/test/core/exe_test.py @@ -13,33 +13,38 @@ # limitations 
under the License. """Unit tests for the openhtf.exe module.""" +import logging import threading import time import unittest +from absl.testing import parameterized import mock import openhtf from openhtf import plugs from openhtf import util +from openhtf.core import base_plugs +from openhtf.core import diagnoses_lib +from openhtf.core import phase_branches +from openhtf.core import phase_collections from openhtf.core import phase_descriptor from openhtf.core import phase_executor from openhtf.core import phase_group from openhtf.core import test_descriptor from openhtf.core import test_executor +from openhtf.core import test_record from openhtf.core import test_state -from openhtf.core.test_record import Outcome from openhtf.util import conf from openhtf.util import logs from openhtf.util import timeouts - # Default logging to debug level. logs.CLI_LOGGING_VERBOSITY = 2 -class UnittestPlug(plugs.BasePlug): +class UnittestPlug(base_plugs.BasePlug): return_continue_count = 4 @@ -71,7 +76,7 @@ class FailedPlugError(Exception): FAIL_PLUG_MESSAGE = 'Failed' -class FailPlug(plugs.BasePlug): +class FailPlug(base_plugs.BasePlug): def __init__(self): raise FailedPlugError(FAIL_PLUG_MESSAGE) @@ -130,10 +135,23 @@ def fail_plug_phase(fail): del fail +@openhtf.PhaseOptions() +def bad_return_phase(): + return 42 + + def blank_phase(): pass +def _rename(phase, new_name): + return phase_descriptor.PhaseOptions(name=new_name)(phase) + + +def _fake_phases(*new_names): + return [_rename(blank_phase, name) for name in new_names] + + class TeardownError(Exception): pass @@ -152,10 +170,12 @@ def _abort_executor_in_thread(executor_abort): # the wait() call gets the error raised in it. 
ready_to_stop_ev = threading.Event() inner_ev = threading.Event() + def abort_executor(): ready_to_stop_ev.wait(1) executor_abort() inner_ev.set() + threading.Thread(target=abort_executor).start() ready_to_stop_ev.set() inner_ev.wait(2) @@ -187,37 +207,40 @@ def failure_phase(test): # Configure test to throw exception midrun, and check that this causes # Outcome = ERROR. ev = threading.Event() - group = phase_group.PhaseGroup( - main=[failure_phase], - teardown=[lambda: ev.set()], # pylint: disable=unnecessary-lambda - ) + + def set_ev(): + ev.set() + + group = phase_group.PhaseGroup(main=[failure_phase], teardown=[set_ev]) test = openhtf.Test(group) - test.configure( - default_dut_id='dut', - ) + test.configure(default_dut_id='dut',) executor = test_executor.TestExecutor( - test.descriptor, 'uid', start_phase, test._test_options, + test.descriptor, + 'uid', + start_phase, + test._test_options, run_with_profiling=False) executor.start() executor.wait() record = executor.test_state.test_record - self.assertEqual(record.outcome, Outcome.ERROR) + self.assertEqual(record.outcome, test_record.Outcome.ERROR) # Same as above, but now specify that the TestDummyExceptionError should # instead be a FAIL outcome. 
- test.configure( - failure_exceptions=[self.TestDummyExceptionError] - ) + test.configure(failure_exceptions=[self.TestDummyExceptionError]) executor = test_executor.TestExecutor( - test.descriptor, 'uid', start_phase, test._test_options, + test.descriptor, + 'uid', + start_phase, + test._test_options, run_with_profiling=False) executor.start() executor.wait() record = executor.test_state.test_record - self.assertEqual(record.outcome, Outcome.FAIL) + self.assertEqual(record.outcome, test_record.Outcome.FAIL) def test_plug_map(self): test = openhtf.Test(phase_one, phase_two) @@ -232,8 +255,9 @@ def test_test_executor(self): def test_class_string(self): check_list = ['PhaseExecutorThread', 'phase_one'] + mock_test_state = mock.create_autospec(test_state.TestState) phase_thread = phase_executor.PhaseExecutorThread( - phase_one, ' ', run_with_profiling=False) + phase_one, mock_test_state, run_with_profiling=False, subtest_rec=None) name = str(phase_thread) found = True for item in check_list: @@ -255,14 +279,13 @@ def cancel_phase(test): ev = threading.Event() - group = phase_group.PhaseGroup( - teardown=[lambda: ev.set()], # pylint: disable=unnecessary-lambda - ) + def set_ev(): + ev.set() + + group = phase_group.PhaseGroup(teardown=[set_ev]) test = openhtf.Test(group) - test.configure( - default_dut_id='dut', - ) + test.configure(default_dut_id='dut',) # Cancel during test start phase. 
executor = test_executor.TestExecutor( @@ -270,8 +293,7 @@ def cancel_phase(test): 'uid', cancel_phase, test._test_options, - run_with_profiling=False - ) + run_with_profiling=False) executor.start() executor.wait() @@ -294,14 +316,18 @@ def cancel_phase(): _abort_executor_in_thread(executor.abort) ev = threading.Event() - group = phase_group.PhaseGroup(main=[cancel_phase], - teardown=[lambda: ev.set()]) # pylint: disable=unnecessary-lambda + + def set_ev(): + ev.set() + + group = phase_group.PhaseGroup(main=[cancel_phase], teardown=[set_ev]) test = openhtf.Test(group) - test.configure( - default_dut_id='dut', - ) + test.configure(default_dut_id='dut',) executor = test_executor.TestExecutor( - test.descriptor, 'uid', start_phase, test._test_options, + test.descriptor, + 'uid', + start_phase, + test._test_options, run_with_profiling=False) executor.start() @@ -316,6 +342,7 @@ def cancel_phase(): executor.close() def test_cancel_phase_with_diagnoser(self): + class DiagResult(openhtf.DiagResultEnum): RESULT = 'result' @@ -331,14 +358,18 @@ def cancel_phase(): _abort_executor_in_thread(executor.abort) ev = threading.Event() - group = phase_group.PhaseGroup(main=[cancel_phase], - teardown=[lambda: ev.set()]) # pylint: disable=unnecessary-lambda + + def set_ev(): + ev.set() + + group = phase_group.PhaseGroup(main=[cancel_phase], teardown=[set_ev]) test = openhtf.Test(group) - test.configure( - default_dut_id='dut', - ) + test.configure(default_dut_id='dut',) executor = test_executor.TestExecutor( - test.descriptor, 'uid', start_phase, test._test_options, + test.descriptor, + 'uid', + start_phase, + test._test_options, run_with_profiling=False) executor.start() @@ -381,14 +412,15 @@ def teardown2_phase(): teardown_running = threading.Event() ev = threading.Event() ev2 = threading.Event() - group = phase_group.PhaseGroup(main=[cancel_twice_phase], - teardown=[teardown_phase, teardown2_phase]) + group = phase_group.PhaseGroup( + main=[cancel_twice_phase], 
teardown=[teardown_phase, teardown2_phase]) test = openhtf.Test(group) - test.configure( - default_dut_id='dut', - ) + test.configure(default_dut_id='dut',) executor = test_executor.TestExecutor( - test.descriptor, 'uid', start_phase, test._test_options, + test.descriptor, + 'uid', + start_phase, + test._test_options, run_with_profiling=False) executor.start() @@ -405,22 +437,24 @@ def teardown2_phase(): def test_failure_during_plug_init(self): ev = threading.Event() - group = phase_group.PhaseGroup( - main=[fail_plug_phase], - teardown=[lambda: ev.set()], # pylint: disable=unnecessary-lambda - ) + + def set_ev(): + ev.set() + + group = phase_group.PhaseGroup(main=[fail_plug_phase], teardown=[set_ev]) test = openhtf.Test(group) - test.configure( - default_dut_id='dut', - ) + test.configure(default_dut_id='dut',) executor = test_executor.TestExecutor( - test.descriptor, 'uid', None, test._test_options, + test.descriptor, + 'uid', + None, + test._test_options, run_with_profiling=False) executor.start() executor.wait() record = executor.test_state.test_record - self.assertEqual(record.outcome, Outcome.ERROR) + self.assertEqual(record.outcome, test_record.Outcome.ERROR) self.assertEqual(record.outcome_details[0].code, FailedPlugError.__name__) self.assertEqual(record.outcome_details[0].description, FAIL_PLUG_MESSAGE) # Teardown function should *NOT* be executed. 
@@ -428,21 +462,22 @@ def test_failure_during_plug_init(self): executor.close() def test_failure_during_start_phase_plug_init(self): + def never_gonna_run_phase(): ev2.set() ev = threading.Event() + + def set_ev(): + ev.set() + ev2 = threading.Event() group = phase_group.PhaseGroup( - main=[never_gonna_run_phase], - teardown=[lambda: ev.set()], # pylint: disable=unnecessary-lambda - ) + main=[never_gonna_run_phase], teardown=[set_ev]) test = openhtf.Test(group) - test.configure( - default_dut_id='dut', - ) + test.configure(default_dut_id='dut',) executor = test_executor.TestExecutor( test.descriptor, @@ -453,7 +488,7 @@ def never_gonna_run_phase(): executor.start() executor.wait() record = executor.test_state.test_record - self.assertEqual(record.outcome, Outcome.ERROR) + self.assertEqual(record.outcome, test_record.Outcome.ERROR) self.assertEqual(record.outcome_details[0].code, FailedPlugError.__name__) self.assertEqual(record.outcome_details[0].description, FAIL_PLUG_MESSAGE) # Teardown function should *NOT* be executed. 
@@ -461,21 +496,21 @@ def never_gonna_run_phase(): self.assertFalse(ev2.is_set()) def test_error_during_teardown(self): - group = phase_group.PhaseGroup( - main=[blank_phase], teardown=[teardown_fail]) + group = phase_group.PhaseGroup(main=[blank_phase], teardown=[teardown_fail]) test = openhtf.Test(group) - test.configure( - default_dut_id='dut', - ) + test.configure(default_dut_id='dut',) executor = test_executor.TestExecutor( - test.descriptor, 'uid', start_phase, test._test_options, + test.descriptor, + 'uid', + start_phase, + test._test_options, run_with_profiling=False) executor.start() executor.wait() record = executor.test_state.test_record - self.assertEqual(record.outcome, Outcome.ERROR) + self.assertEqual(record.outcome, test_record.Outcome.ERROR) self.assertEqual(record.outcome_details[0].code, TeardownError.__name__) executor.close() @@ -485,45 +520,51 @@ def test_log_during_teardown(self): def teardown_log(test): test.logger.info(message) - group = phase_group.PhaseGroup( - main=[blank_phase], teardown=[teardown_log]) + group = phase_group.PhaseGroup(main=[blank_phase], teardown=[teardown_log]) test = openhtf.Test(group) - test.configure( - default_dut_id='dut', - ) + test.configure(default_dut_id='dut',) executor = test_executor.TestExecutor( - test.descriptor, 'uid', start_phase, test._test_options, + test.descriptor, + 'uid', + start_phase, + test._test_options, run_with_profiling=False) executor.start() executor.wait() record = executor.test_state.test_record - self.assertEqual(record.outcome, Outcome.PASS) - log_records = [log_record for log_record in record.log_records - if log_record.message == message] + self.assertEqual(record.outcome, test_record.Outcome.PASS) + log_records = [ + log_record for log_record in record.log_records + if log_record.message == message + ] self.assertTrue(log_records) executor.close() def test_stop_on_first_failure_phase(self): ev = threading.Event() - group = 
phase_group.PhaseGroup(main=[phase_return_fail_and_continue, - phase_one], - teardown=[lambda: ev.set()]) # pylint: disable=unnecessary-lambda + + def set_ev(): + ev.set() + + group = phase_group.PhaseGroup( + main=[phase_return_fail_and_continue, phase_one], teardown=[set_ev]) test = openhtf.Test(group) - test.configure( - default_dut_id='dut', - ) + test.configure(default_dut_id='dut',) test.configure(stop_on_first_failure=True) executor = test_executor.TestExecutor( - test.descriptor, 'uid', start_phase, test._test_options, + test.descriptor, + 'uid', + start_phase, + test._test_options, run_with_profiling=False) executor.start() executor.wait() record = executor.test_state.test_record self.assertEqual(record.phases[0].name, start_phase.name) - self.assertTrue(record.outcome, Outcome.FAIL) + self.assertTrue(record.outcome, test_record.Outcome.FAIL) # Verify phase_one was not run ran_phase = [phase.name for phase in record.phases] self.assertNotIn('phase_one', ran_phase) @@ -535,23 +576,27 @@ def test_stop_on_first_failure_phase(self): def test_conf_stop_on_first_failure_phase(self): ev = threading.Event() - group = phase_group.PhaseGroup(main=[phase_return_fail_and_continue, - phase_one], - teardown=[lambda: ev.set()]) # pylint: disable=unnecessary-lambda + + def set_ev(): + ev.set() + + group = phase_group.PhaseGroup( + main=[phase_return_fail_and_continue, phase_one], teardown=[set_ev]) test = openhtf.Test(group) - test.configure( - default_dut_id='dut', - ) + test.configure(default_dut_id='dut',) conf.load(stop_on_first_failure=True) executor = test_executor.TestExecutor( - test.descriptor, 'uid', start_phase, test._test_options, + test.descriptor, + 'uid', + start_phase, + test._test_options, run_with_profiling=False) executor.start() executor.wait() record = executor.test_state.test_record self.assertEqual(record.phases[0].name, start_phase.name) - self.assertTrue(record.outcome, Outcome.FAIL) + self.assertTrue(record.outcome, test_record.Outcome.FAIL) # 
Verify phase_one was not run ran_phase = [phase.name for phase in record.phases] self.assertNotIn('phase_one', ran_phase) @@ -560,10 +605,10 @@ def test_conf_stop_on_first_failure_phase(self): executor.close() -class TestExecutorHandlePhaseTest(unittest.TestCase): +class TestExecutorExecutePhaseTest(unittest.TestCase): def setUp(self): - super(TestExecutorHandlePhaseTest, self).setUp() + super(TestExecutorExecutePhaseTest, self).setUp() self.test_state = mock.MagicMock( spec=test_state.TestState, plug_manager=plugs.PlugManager(), @@ -571,38 +616,31 @@ def setUp(self): state_logger=mock.MagicMock(), test_options=test_descriptor.TestOptions(), test_record=mock.MagicMock()) - self.phase_exec = mock.MagicMock( - spec=phase_executor.PhaseExecutor) - self.test_exec = test_executor.TestExecutor(None, 'uid', None, - test_descriptor.TestOptions(), - run_with_profiling=False) + self.phase_exec = mock.MagicMock(spec=phase_executor.PhaseExecutor) + td = test_descriptor.TestDescriptor( + phase_sequence=phase_collections.PhaseSequence( + phase_group.PhaseGroup()), + code_info=test_record.CodeInfo.uncaptured(), + metadata={}) + self.test_exec = test_executor.TestExecutor( + td, + td.uid, + None, + test_descriptor.TestOptions(), + run_with_profiling=False) self.test_exec.test_state = self.test_state self.test_exec._phase_exec = self.phase_exec - patcher = mock.patch.object(self.test_exec, '_execute_phase_group') - self.mock_execute_phase_group = patcher.start() - - def testPhaseGroup_NotTerminal(self): - self.mock_execute_phase_group.return_value = False - group = phase_group.PhaseGroup(name='test') - self.assertFalse(self.test_exec._handle_phase(group)) - self.mock_execute_phase_group.assert_called_once_with(group) - - def testPhaseGroup_Terminal(self): - self.mock_execute_phase_group.return_value = True - group = phase_group.PhaseGroup(name='test') - self.assertTrue(self.test_exec._handle_phase(group)) - self.mock_execute_phase_group.assert_called_once_with(group) - def 
testPhase_NotTerminal(self): phase = phase_descriptor.PhaseDescriptor(blank_phase) self.phase_exec.execute_phase.return_value = ( phase_executor.PhaseExecutionOutcome( phase_descriptor.PhaseResult.CONTINUE), None) - self.assertFalse(self.test_exec._handle_phase(phase)) + self.assertEqual(test_executor._ExecutorReturn.CONTINUE, + self.test_exec._execute_phase(phase, None, False)) - self.mock_execute_phase_group.assert_not_called() - self.phase_exec.execute_phase.assert_called_once_with(phase, False) + self.phase_exec.execute_phase.assert_called_once_with( + phase, run_with_profiling=False, subtest_rec=None) self.assertIsNone(self.test_exec._last_outcome) def testPhase_NotTerminal_PreviousLastOutcome(self): @@ -613,10 +651,11 @@ def testPhase_NotTerminal_PreviousLastOutcome(self): self.phase_exec.execute_phase.return_value = ( phase_executor.PhaseExecutionOutcome( phase_descriptor.PhaseResult.CONTINUE), None) - self.assertFalse(self.test_exec._handle_phase(phase)) + self.assertEqual(test_executor._ExecutorReturn.CONTINUE, + self.test_exec._execute_phase(phase, None, False)) - self.mock_execute_phase_group.assert_not_called() - self.phase_exec.execute_phase.assert_called_once_with(phase, False) + self.phase_exec.execute_phase.assert_called_once_with( + phase, run_with_profiling=False, subtest_rec=None) self.assertIs(set_outcome, self.test_exec._last_outcome) def testPhase_Terminal_SetLastOutcome(self): @@ -624,10 +663,11 @@ def testPhase_Terminal_SetLastOutcome(self): outcome = phase_executor.PhaseExecutionOutcome( phase_descriptor.PhaseResult.STOP) self.phase_exec.execute_phase.return_value = outcome, None - self.assertTrue(self.test_exec._handle_phase(phase)) + self.assertEqual(test_executor._ExecutorReturn.TERMINAL, + self.test_exec._execute_phase(phase, None, False)) - self.mock_execute_phase_group.assert_not_called() - self.phase_exec.execute_phase.assert_called_once_with(phase, False) + self.phase_exec.execute_phase.assert_called_once_with( + phase, 
run_with_profiling=False, subtest_rec=None) self.assertIs(outcome, self.test_exec._last_outcome) def testPhase_Terminal_PreviousLastOutcome(self): @@ -637,97 +677,167 @@ def testPhase_Terminal_PreviousLastOutcome(self): outcome = phase_executor.PhaseExecutionOutcome( phase_descriptor.PhaseResult.STOP) self.phase_exec.execute_phase.return_value = outcome, None - self.assertTrue(self.test_exec._handle_phase(phase)) + self.assertEqual(test_executor._ExecutorReturn.TERMINAL, + self.test_exec._execute_phase(phase, None, False)) - self.mock_execute_phase_group.assert_not_called() - self.phase_exec.execute_phase.assert_called_once_with(phase, False) + self.phase_exec.execute_phase.assert_called_once_with( + phase, run_with_profiling=False, subtest_rec=None) self.assertIs(set_outcome, self.test_exec._last_outcome) -class TestExecutorExecutePhasesTest(unittest.TestCase): +class TestExecutorExecuteSequencesTest(unittest.TestCase): def setUp(self): - super(TestExecutorExecutePhasesTest, self).setUp() + super(TestExecutorExecuteSequencesTest, self).setUp() self.test_state = mock.MagicMock( spec=test_state.TestState, plug_manager=plugs.PlugManager(), execution_uid='01234567890', state_logger=mock.MagicMock()) - self.test_exec = test_executor.TestExecutor(None, 'uid', None, - test_descriptor.TestOptions(), - run_with_profiling=False) + td = test_descriptor.TestDescriptor( + phase_sequence=phase_collections.PhaseSequence( + phase_group.PhaseGroup()), + code_info=test_record.CodeInfo.uncaptured(), + metadata={}) + self.test_exec = test_executor.TestExecutor( + td, + td.uid, + None, + test_descriptor.TestOptions(), + run_with_profiling=False) self.test_exec.test_state = self.test_state - patcher = mock.patch.object(self.test_exec, '_handle_phase') - self.mock_handle_phase = patcher.start() + patcher = mock.patch.object(self.test_exec, '_execute_node') + self.mock_execute_node = patcher.start() def testExecuteAbortable_NoPhases(self): - 
self.assertFalse(self.test_exec._execute_abortable_phases( - 'main', (), 'group')) - self.mock_handle_phase.assert_not_called() + self.assertEqual( + test_executor._ExecutorReturn.CONTINUE, + self.test_exec._execute_sequence( + phase_collections.PhaseSequence(tuple()), + None, + False, + override_message='main group')) + self.mock_execute_node.assert_not_called() def testExecuteAbortable_Normal(self): - self.mock_handle_phase.side_effect = [False] - self.assertFalse(self.test_exec._execute_abortable_phases( - 'main', ('normal',), 'group')) - self.mock_handle_phase.assert_called_once_with('normal') + self.mock_execute_node.side_effect = [ + test_executor._ExecutorReturn.CONTINUE + ] + sequence = phase_collections.PhaseSequence(_fake_phases('normal')) + all_phases = list(sequence.all_phases()) + self.assertEqual(test_executor._ExecutorReturn.CONTINUE, + self.test_exec._execute_sequence(sequence, None, False)) + self.mock_execute_node.assert_called_once_with(all_phases[0], None, False) def testExecuteAbortable_AbortedPrior(self): self.test_exec.abort() - self.assertTrue(self.test_exec._execute_abortable_phases( - 'main', ('not-run',), 'group')) - self.mock_handle_phase.assert_not_called() + sequence = phase_collections.PhaseSequence(_fake_phases('not-run')) + self.assertEqual(test_executor._ExecutorReturn.TERMINAL, + self.test_exec._execute_sequence(sequence, None, False)) + self.mock_execute_node.assert_not_called() def testExecuteAbortable_AbortedDuring(self): - self.mock_handle_phase.side_effect = lambda x: self.test_exec.abort() - self.assertTrue(self.test_exec._execute_abortable_phases( - 'main', ('abort', 'not-run'), 'group')) - self.mock_handle_phase.assert_called_once_with('abort') + + def execute_node(node, subtest_rec, in_teardown): + del node # Unused. + del subtest_rec # Unused. + del in_teardown # Unused. 
+ self.test_exec.abort() + return test_executor._ExecutorReturn.TERMINAL + + self.mock_execute_node.side_effect = execute_node + sequence = phase_collections.PhaseSequence(_fake_phases('abort', 'not-run')) + all_phases = list(sequence.all_phases()) + self.assertEqual(test_executor._ExecutorReturn.TERMINAL, + self.test_exec._execute_sequence(sequence, None, False)) + self.mock_execute_node.assert_called_once_with(all_phases[0], None, False) def testExecuteAbortable_Terminal(self): - self.mock_handle_phase.side_effect = [False, True] - self.assertTrue(self.test_exec._execute_abortable_phases( - 'main', ('normal', 'abort', 'not_run'), 'group')) - self.assertEqual([mock.call('normal'), mock.call('abort')], - self.mock_handle_phase.call_args_list) + self.mock_execute_node.side_effect = [ + test_executor._ExecutorReturn.CONTINUE, + test_executor._ExecutorReturn.TERMINAL + ] + sequence = phase_collections.PhaseSequence( + _fake_phases('normal', 'abort', 'not_run')) + all_phases = list(sequence.all_phases()) + self.assertEqual(test_executor._ExecutorReturn.TERMINAL, + self.test_exec._execute_sequence(sequence, None, False)) + self.assertEqual([ + mock.call(all_phases[0], None, False), + mock.call(all_phases[1], None, False) + ], self.mock_execute_node.call_args_list) def testExecuteTeardown_Empty(self): - self.assertFalse(self.test_exec._execute_teardown_phases((), 'group')) - self.mock_handle_phase.assert_not_called() + self.assertEqual( + test_executor._ExecutorReturn.CONTINUE, + self.test_exec._execute_sequence( + phase_collections.PhaseSequence(tuple()), + None, + True, + override_message='group')) + self.mock_execute_node.assert_not_called() def testExecuteTeardown_Normal(self): - self.mock_handle_phase.side_effect = [False] - self.assertFalse(self.test_exec._execute_teardown_phases( - ('normal',), 'group')) - self.mock_handle_phase.assert_called_once_with('normal') + self.mock_execute_node.side_effect = [ + test_executor._ExecutorReturn.CONTINUE + ] + sequence = 
phase_collections.PhaseSequence(_fake_phases('normal')) + all_phases = list(sequence.all_phases()) + self.assertEqual(test_executor._ExecutorReturn.CONTINUE, + self.test_exec._execute_sequence(sequence, None, True)) + self.mock_execute_node.assert_called_once_with(all_phases[0], None, True) def testExecuteTeardown_AbortPrior(self): self.test_exec.abort() - self.mock_handle_phase.side_effect = [False] - self.assertFalse(self.test_exec._execute_teardown_phases( - ('normal',), 'group')) - self.mock_handle_phase.assert_called_once_with('normal') + self.mock_execute_node.side_effect = [ + test_executor._ExecutorReturn.CONTINUE + ] + sequence = phase_collections.PhaseSequence(_fake_phases('normal')) + all_phases = list(sequence.all_phases()) + self.assertEqual(test_executor._ExecutorReturn.CONTINUE, + self.test_exec._execute_sequence(sequence, None, True)) + self.mock_execute_node.assert_called_once_with(all_phases[0], None, True) def testExecuteTeardown_AbortedDuring(self): - def handle_phase(fake_phase): - if fake_phase == 'abort': + + def execute_node(node, subtest_rec, in_teardown): + del subtest_rec # Unused. + del in_teardown # Unused. 
+ if node.name == 'abort': self.test_exec.abort() - return False - self.mock_handle_phase.side_effect = handle_phase - self.assertFalse(self.test_exec._execute_teardown_phases( - ('abort', 'still-run'), 'group')) - self.mock_handle_phase.assert_has_calls( - [mock.call('abort'), mock.call('still-run')]) + return test_executor._ExecutorReturn.TERMINAL + return test_executor._ExecutorReturn.CONTINUE + + self.mock_execute_node.side_effect = execute_node + sequence = phase_collections.PhaseSequence( + _fake_phases('abort', 'still-run')) + all_phases = list(sequence.all_phases()) + self.assertEqual(test_executor._ExecutorReturn.TERMINAL, + self.test_exec._execute_sequence(sequence, None, True)) + self.assertEqual([ + mock.call(all_phases[0], None, True), + mock.call(all_phases[1], None, True) + ], self.mock_execute_node.call_args_list) def testExecuteTeardown_Terminal(self): - def handle_phase(fake_phase): - if fake_phase == 'error': - return True - return False - self.mock_handle_phase.side_effect = handle_phase - self.assertTrue(self.test_exec._execute_teardown_phases( - ('error', 'still-run'), 'group')) - self.mock_handle_phase.assert_has_calls( - [mock.call('error'), mock.call('still-run')]) + + def execute_node(node, subtest_rec, in_teardown): + del subtest_rec # Unused. + del in_teardown # Unused. 
+ if node.name == 'error': + return test_executor._ExecutorReturn.TERMINAL + return test_executor._ExecutorReturn.CONTINUE + + self.mock_execute_node.side_effect = execute_node + sequence = phase_collections.PhaseSequence( + _fake_phases('error', 'still-run')) + all_phases = list(sequence.all_phases()) + self.assertEqual(test_executor._ExecutorReturn.TERMINAL, + self.test_exec._execute_sequence(sequence, None, True)) + self.assertEqual([ + mock.call(all_phases[0], None, True), + mock.call(all_phases[1], None, True) + ], self.mock_execute_node.call_args_list) class TestExecutorExecutePhaseGroupTest(unittest.TestCase): @@ -739,71 +849,281 @@ def setUp(self): plug_manager=plugs.PlugManager(), execution_uid='01234567890', state_logger=mock.MagicMock()) - self.test_exec = test_executor.TestExecutor(None, 'uid', None, - test_descriptor.TestOptions(), - run_with_profiling=False) + td = test_descriptor.TestDescriptor( + phase_sequence=phase_collections.PhaseSequence( + phase_group.PhaseGroup()), + code_info=test_record.CodeInfo.uncaptured(), + metadata={}) + self.test_exec = test_executor.TestExecutor( + td, + td.uid, + None, + test_descriptor.TestOptions(), + run_with_profiling=False) self.test_exec.test_state = self.test_state - patcher = mock.patch.object(self.test_exec, '_execute_abortable_phases') - self.mock_execute_abortable = patcher.start() - - patcher = mock.patch.object(self.test_exec, '_execute_teardown_phases') - self.mock_execute_teardown = patcher.start() + patcher = mock.patch.object(self.test_exec, '_execute_sequence') + self.mock_execute_sequence = patcher.start() + @phase_descriptor.PhaseOptions() def setup(): pass - self._setup = setup + self._setup = phase_collections.PhaseSequence((setup,)) + + @phase_descriptor.PhaseOptions() def main(): pass - self._main = main + + self._main = phase_collections.PhaseSequence((main,)) @openhtf.PhaseOptions(timeout_s=30) def teardown(): pass - self._teardown = teardown + + self._teardown = 
phase_collections.PhaseSequence((teardown,)) self.group = phase_group.PhaseGroup( - setup=[setup], main=[main], teardown=[teardown], name='group') + setup=self._setup, + main=self._main, + teardown=self._teardown, + name='group') def testStopDuringSetup(self): - self.mock_execute_abortable.return_value = True - self.assertTrue(self.test_exec._execute_phase_group(self.group)) - self.mock_execute_abortable.assert_called_once_with( - 'setup', (self._setup,), 'group') - self.mock_execute_teardown.assert_not_called() + self.mock_execute_sequence.return_value = ( + test_executor._ExecutorReturn.TERMINAL) + + self.assertEqual( + test_executor._ExecutorReturn.TERMINAL, + self.test_exec._execute_phase_group(self.group, None, False)) + self.mock_execute_sequence.assert_called_once_with( + self._setup, None, False, override_message='group:setup') def testStopDuringMain(self): - self.mock_execute_abortable.side_effect = [False, True] - self.mock_execute_teardown.return_value = False - self.assertTrue(self.test_exec._execute_phase_group(self.group)) - self.mock_execute_abortable.assert_has_calls([ - mock.call('setup', (self._setup,), 'group'), - mock.call('main', (self._main,), 'group'), - ]) - self.mock_execute_teardown.assert_called_once_with( - (self._teardown,), 'group') + self.mock_execute_sequence.side_effect = [ + test_executor._ExecutorReturn.CONTINUE, + test_executor._ExecutorReturn.TERMINAL, + test_executor._ExecutorReturn.CONTINUE, + ] + + self.assertEqual( + test_executor._ExecutorReturn.TERMINAL, + self.test_exec._execute_phase_group(self.group, None, False)) + self.assertEqual([ + mock.call(self._setup, None, False, override_message='group:setup'), + mock.call(self._main, None, False, override_message='group:main'), + mock.call( + self._teardown, None, True, override_message='group:teardown'), + ], self.mock_execute_sequence.call_args_list) def testStopDuringTeardown(self): - self.mock_execute_abortable.return_value = False - 
self.mock_execute_teardown.return_value = True - self.assertTrue(self.test_exec._execute_phase_group(self.group)) - self.mock_execute_abortable.assert_has_calls([ - mock.call('setup', (self._setup,), 'group'), - mock.call('main', (self._main,), 'group'), - ]) - self.mock_execute_teardown.assert_called_once_with( - (self._teardown,), 'group') + self.mock_execute_sequence.side_effect = [ + test_executor._ExecutorReturn.CONTINUE, + test_executor._ExecutorReturn.CONTINUE, + test_executor._ExecutorReturn.TERMINAL, + ] + + self.assertEqual( + test_executor._ExecutorReturn.TERMINAL, + self.test_exec._execute_phase_group(self.group, None, False)) + self.assertEqual([ + mock.call(self._setup, None, False, override_message='group:setup'), + mock.call(self._main, None, False, override_message='group:main'), + mock.call( + self._teardown, None, True, override_message='group:teardown'), + ], self.mock_execute_sequence.call_args_list) def testNoStop(self): - self.mock_execute_abortable.return_value = False - self.mock_execute_teardown.return_value = False - self.assertFalse(self.test_exec._execute_phase_group(self.group)) - self.mock_execute_abortable.assert_has_calls([ - mock.call('setup', (self._setup,), 'group'), - mock.call('main', (self._main,), 'group'), - ]) - self.mock_execute_teardown.assert_called_once_with( - (self._teardown,), 'group') + self.mock_execute_sequence.side_effect = [ + test_executor._ExecutorReturn.CONTINUE, + test_executor._ExecutorReturn.CONTINUE, + test_executor._ExecutorReturn.CONTINUE, + ] + + self.assertEqual( + test_executor._ExecutorReturn.CONTINUE, + self.test_exec._execute_phase_group(self.group, None, False)) + self.assertEqual([ + mock.call(self._setup, None, False, override_message='group:setup'), + mock.call(self._main, None, False, override_message='group:main'), + mock.call( + self._teardown, None, True, override_message='group:teardown'), + ], self.mock_execute_sequence.call_args_list) + + def testEmptyGroup(self): + group = 
phase_group.PhaseGroup() + self.assertEqual(test_executor._ExecutorReturn.CONTINUE, + self.test_exec._execute_phase_group(group, None, False)) + self.mock_execute_sequence.assert_not_called() + + def testNoSetup(self): + self.mock_execute_sequence.side_effect = [ + test_executor._ExecutorReturn.CONTINUE, + test_executor._ExecutorReturn.CONTINUE, + ] + + group = phase_group.PhaseGroup( + main=self._main, teardown=self._teardown, name='group') + self.assertEqual(test_executor._ExecutorReturn.CONTINUE, + self.test_exec._execute_phase_group(group, None, False)) + self.assertEqual([ + mock.call(self._main, None, False, override_message='group:main'), + mock.call( + self._teardown, None, True, override_message='group:teardown'), + ], self.mock_execute_sequence.call_args_list) + + def testNoMain(self): + self.mock_execute_sequence.side_effect = [ + test_executor._ExecutorReturn.CONTINUE, + test_executor._ExecutorReturn.CONTINUE, + ] + + group = phase_group.PhaseGroup( + setup=self._setup, teardown=self._teardown, name='group') + self.assertEqual(test_executor._ExecutorReturn.CONTINUE, + self.test_exec._execute_phase_group(group, None, False)) + self.assertEqual([ + mock.call(self._setup, None, False, override_message='group:setup'), + mock.call( + self._teardown, None, True, override_message='group:teardown'), + ], self.mock_execute_sequence.call_args_list) + + def testNoTeardown(self): + self.mock_execute_sequence.side_effect = [ + test_executor._ExecutorReturn.CONTINUE, + test_executor._ExecutorReturn.CONTINUE, + ] + + group = phase_group.PhaseGroup( + setup=self._setup, main=self._main, name='group') + self.assertEqual(test_executor._ExecutorReturn.CONTINUE, + self.test_exec._execute_phase_group(group, None, False)) + self.assertEqual([ + mock.call(self._setup, None, False, override_message='group:setup'), + mock.call(self._main, None, False, override_message='group:main'), + ], self.mock_execute_sequence.call_args_list) + + +class 
BranchDiag(diagnoses_lib.DiagResultEnum): + ONE = 'one' + TWO = 'two' + THREE = 'three' + + +_NO_RESULTS = tuple() +_ONE_RESULT = (BranchDiag.ONE,) +_ALL_RESULTS = tuple(BranchDiag) + + +class StringInComparer(object): + + def __init__(self, expected_content): + self._expected_content = expected_content + + def __eq__(self, other): + return self._expected_content in other + + +class TestExecutorExecuteBranchTest(parameterized.TestCase): + + def setUp(self): + super(TestExecutorExecuteBranchTest, self).setUp() + self.diag_store = diagnoses_lib.DiagnosesStore() + self.mock_test_record = mock.create_autospec(test_record.TestRecord) + self.mock_logger = mock.create_autospec(logging.Logger) + self.test_state = mock.MagicMock( + spec=test_state.TestState, + plug_manager=plugs.PlugManager(), + diagnoses_manager=mock.MagicMock( + spec=diagnoses_lib.DiagnosesManager, store=self.diag_store), + execution_uid='01234567890', + test_record=self.mock_test_record, + state_logger=self.mock_logger) + td = test_descriptor.TestDescriptor( + phase_sequence=phase_collections.PhaseSequence( + phase_group.PhaseGroup()), + code_info=test_record.CodeInfo.uncaptured(), + metadata={}) + self.test_exec = test_executor.TestExecutor( + td, + td.uid, + None, + test_descriptor.TestOptions(), + run_with_profiling=False) + self.test_exec.test_state = self.test_state + patcher = mock.patch.object(self.test_exec, '_execute_sequence') + self.mock_execute_sequence = patcher.start() + + @parameterized.named_parameters( + # on_all + ('on_all__one__not_triggered', 'on_all', BranchDiag.ONE, _NO_RESULTS, + False), + ('on_all__one__triggered', 'on_all', _ONE_RESULT, _ONE_RESULT, True), + ('on_all__multiple__none', 'on_all', _ALL_RESULTS, _NO_RESULTS, False), + ('on_all__multiple__one', 'on_all', _ALL_RESULTS, _ONE_RESULT, False), + ('on_all__multiple__all', 'on_all', _ALL_RESULTS, _ALL_RESULTS, True), + # on_any + ('on_any__one__not_triggered', 'on_any', BranchDiag.ONE, _NO_RESULTS, + False), + 
('on_any__one__triggered', 'on_any', _ONE_RESULT, _ONE_RESULT, True), + ('on_any__multiple__none', 'on_any', _ALL_RESULTS, _NO_RESULTS, False), + ('on_any__multiple__one', 'on_any', _ALL_RESULTS, _ONE_RESULT, True), + ('on_any__multiple__all', 'on_any', _ALL_RESULTS, _ALL_RESULTS, True), + # on_not_any + ('on_not_any__one__not_triggered', 'on_not_any', BranchDiag.ONE, + _NO_RESULTS, True), + ('on_not_any__one__triggered', 'on_not_any', _ONE_RESULT, _ONE_RESULT, + False), + ('on_not_any__multiple__none', 'on_not_any', _ALL_RESULTS, _NO_RESULTS, + True), + ('on_not_any__multiple__one', 'on_not_any', _ALL_RESULTS, _ONE_RESULT, + False), + ('on_not_any__multiple__all', 'on_not_any', _ALL_RESULTS, _ALL_RESULTS, + False), + # not_all + ('on_not_all__one__not_triggered', 'on_not_all', _ONE_RESULT, _NO_RESULTS, + True), + ('on_not_all__one__triggered', 'on_not_all', _ONE_RESULT, _ONE_RESULT, + False), + ('on_not_all__multiple__none', 'on_not_all', _ALL_RESULTS, _NO_RESULTS, + True), + ('on_not_all__multiple__one', 'on_not_all', _ALL_RESULTS, _ONE_RESULT, + True), + ('on_not_all__multiple__all', 'on_not_all', _ALL_RESULTS, _ALL_RESULTS, + False), + ) + def test_branch(self, constructor_name, constructor_diags, results, called): + diag_cond = getattr(phase_branches.DiagnosisCondition, + constructor_name)(*constructor_diags) + branch = phase_branches.BranchSequence(diag_cond) + for result in results: + self.diag_store._add_diagnosis(diagnoses_lib.Diagnosis(result=result)) + + self.test_exec._execute_phase_branch(branch, None, False) + if called: + self.mock_execute_sequence.assert_called_once_with(branch, None, False) + self.mock_logger.debug.assert_called_once_with( + '%s: Branch condition met; running phases.', diag_cond.message) + else: + self.mock_execute_sequence.assert_not_called() + self.mock_logger.debug.assert_called_once_with( + '%s: Branch condition NOT met; not running sequence.', + diag_cond.message) + 
self.mock_test_record.add_branch_record.assert_called_once_with( + test_record.BranchRecord.from_branch(branch, called, mock.ANY)) + + def test_branch_with_log(self): + diag_cond = phase_branches.DiagnosisCondition.on_all(BranchDiag.ONE) + branch = phase_branches.BranchSequence(diag_cond, name='branch') + self.diag_store._add_diagnosis( + diagnoses_lib.Diagnosis(result=BranchDiag.ONE)) + + self.test_exec._execute_phase_branch(branch, None, False) + self.mock_execute_sequence.assert_called_once_with(branch, None, False) + self.mock_test_record.add_branch_record.assert_called_once_with( + test_record.BranchRecord.from_branch(branch, True, mock.ANY)) + self.mock_logger.debug.assert_called_once_with( + '%s: Branch condition met; running phases.', + 'branch:{}'.format(diag_cond.message)) class PhaseExecutorTest(unittest.TestCase): @@ -815,8 +1135,8 @@ def setUp(self): plug_manager=plugs.PlugManager(), execution_uid='01234567890', state_logger=mock.MagicMock()) - self.test_state.plug_manager.initialize_plugs([ - UnittestPlug, MoreRepeatsUnittestPlug]) + self.test_state.plug_manager.initialize_plugs( + [UnittestPlug, MoreRepeatsUnittestPlug]) self.phase_executor = phase_executor.PhaseExecutor(self.test_state) def test_execute_continue_phase(self): @@ -845,3 +1165,9 @@ def test_execute_phase_return_fail_and_continue(self): result, _ = self.phase_executor.execute_phase( phase_return_fail_and_continue) self.assertEqual(openhtf.PhaseResult.FAIL_AND_CONTINUE, result.phase_result) + + def test_execute_phase_bad_phase_return(self): + result, _ = self.phase_executor.execute_phase(bad_return_phase) + self.assertEqual( + phase_executor.ExceptionInfo(phase_executor.InvalidPhaseResultError, + mock.ANY, mock.ANY), result.phase_result) diff --git a/test/core/measurements_test.py b/test/core/measurements_test.py index b9b51d479..b9b0a7588 100644 --- a/test/core/measurements_test.py +++ b/test/core/measurements_test.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 
either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Test various measurements use cases. The test cases here need improvement - they should check for things that we @@ -29,8 +28,10 @@ from openhtf.util import test as htf_test # Fields that are considered 'volatile' for record comparison. -_VOLATILE_FIELDS = {'start_time_millis', 'end_time_millis', 'timestamp_millis', - 'lineno', 'codeinfo', 'code_info', 'descriptor_id'} +_VOLATILE_FIELDS = { + 'start_time_millis', 'end_time_millis', 'timestamp_millis', 'lineno', + 'codeinfo', 'code_info', 'descriptor_id' +} class BadValidatorError(Exception): @@ -46,15 +47,15 @@ def bad_validator(value): raise BadValidatorError('This is a bad validator.') -@htf.measures(htf.Measurement('bad').with_dimensions('a').with_validator( - bad_validator)) +@htf.measures( + htf.Measurement('bad').with_dimensions('a').with_validator(bad_validator)) def bad_validator_phase(test): test.measurements.bad[1] = 1 test.measurements.bad[2] = 2 -@htf.measures(htf.Measurement('bad').with_dimensions('a').with_validator( - bad_validator)) +@htf.measures( + htf.Measurement('bad').with_dimensions('a').with_validator(bad_validator)) def bad_validator_with_error(test): test.measurements.bad[2] = 2 raise BadPhaseError('Bad phase.') @@ -122,11 +123,12 @@ def test_cache_same_object(self): m.measured_value.set(1) m.notify_value_set() basetypes2 = m.as_base_types() - self.assertEqual({ - 'name': 'measurement', - 'outcome': 'PASS', - 'measured_value': 1, - }, basetypes2) + self.assertEqual( + { + 'name': 'measurement', + 'outcome': 'PASS', + 'measured_value': 1, + }, basetypes2) self.assertIs(basetypes0, basetypes2) @htf_test.patch_plugs(user_mock='openhtf.plugs.user_input.UserInput') @@ -141,23 +143,34 @@ def test_chaining_in_measurement_declarations(self, user_mock): @htf_test.yields_phases def test_measurements_with_dimensions(self): record = yield all_the_things.dimensions - 
self.assertMeasured(record, 'dimensions', - [(0, 1), (1, 2), (2, 4), (3, 8), (4, 16)]) - self.assertMeasured(record, 'lots_of_dims', - [(1, 21, 101, 123), (2, 22, 102, 126), - (3, 23, 103, 129), (4, 24, 104, 132)]) + self.assertMeasured(record, 'dimensions', [ + (0, 1), + (1, 2), + (2, 4), + (3, 8), + (4, 16), + ]) + self.assertMeasured(record, 'lots_of_dims', [ + (1, 21, 101, 123), + (2, 22, 102, 126), + (3, 23, 103, 129), + (4, 24, 104, 132), + ]) @htf_test.yields_phases def test_validator_replacement(self): - record = yield all_the_things.measures_with_args.with_args(min=2, max=4) + record = yield all_the_things.measures_with_args.with_args( + minimum=2, maximum=4) self.assertMeasurementFail(record, 'replaced_min_only') self.assertMeasurementPass(record, 'replaced_max_only') self.assertMeasurementFail(record, 'replaced_min_max') - record = yield all_the_things.measures_with_args.with_args(min=0, max=5) + record = yield all_the_things.measures_with_args.with_args( + minimum=0, maximum=5) self.assertMeasurementPass(record, 'replaced_min_only') self.assertMeasurementPass(record, 'replaced_max_only') self.assertMeasurementPass(record, 'replaced_min_max') - record = yield all_the_things.measures_with_args.with_args(min=-1, max=0) + record = yield all_the_things.measures_with_args.with_args( + minimum=-1, maximum=0) self.assertMeasurementPass(record, 'replaced_min_only') self.assertMeasurementFail(record, 'replaced_max_only') self.assertMeasurementFail(record, 'replaced_min_max') @@ -165,12 +178,13 @@ def test_validator_replacement(self): @htf_test.yields_phases def test_measurement_order(self): record = yield all_the_things.dimensions - self.assertEqual(list(record.measurements.keys()), - ['dimensions', 'lots_of_dims']) - record = yield all_the_things.measures_with_args.with_args(min=2, max=4) - self.assertEqual(list(record.measurements.keys()), - ['replaced_min_only', 'replaced_max_only', - 'replaced_min_max']) + self.assertEqual( + list(record.measurements.keys()), 
['dimensions', 'lots_of_dims']) + record = yield all_the_things.measures_with_args.with_args( + minimum=2, maximum=4) + self.assertEqual( + list(record.measurements.keys()), + ['replaced_min_only', 'replaced_max_only', 'replaced_min_max']) @htf_test.yields_phases def test_bad_validation(self): @@ -202,8 +216,7 @@ def test_to_dataframe__no_pandas(self): def test_to_dataframe(self, units=True): measurement = htf.Measurement('test_multidim') - measurement.with_dimensions('ms', 'assembly', - htf.Dimension('my_zone')) + measurement.with_dimensions('ms', 'assembly', htf.Dimension('my_zone')) if units: measurement.with_units('°C') @@ -222,12 +235,10 @@ def test_to_dataframe(self, units=True): df = measurement.to_dataframe() coordinates = (1, 'A', 2) - query = '(ms == %s) & (assembly == "%s") & (my_zone == %s)' % ( - coordinates) + query = '(ms == %s) & (assembly == "%s") & (my_zone == %s)' % (coordinates) - self.assertEqual( - measurement.measured_value[coordinates], - df.query(query)[measure_column_name].values[0]) + self.assertEqual(measurement.measured_value[coordinates], + df.query(query)[measure_column_name].values[0]) def test_to_dataframe__no_units(self): self.test_to_dataframe(units=False) @@ -256,7 +267,10 @@ def test_cache_dict(self): def test_cached_complex(self): measured_value = measurements.MeasuredValue('complex') - NamedComplex = collections.namedtuple('NamedComplex', ['a']) # pylint: disable=invalid-name + + class NamedComplex(collections.namedtuple('NamedComplex', ['a'])): + pass + named_complex = NamedComplex(10) measured_value.set(named_complex) self.assertEqual({'a': 10}, measured_value._cached_value) @@ -273,7 +287,10 @@ def test_coordinates_len_integer(self): self.assertEqual(length, 1) def test_coordinates_len_tuple(self): - coordinates = ('string', 42,) + coordinates = ( + 'string', + 42, + ) length = measurements._coordinates_len(coordinates) self.assertEqual(length, 2) @@ -305,7 +322,10 @@ def test_single_dimension_mutable_obj_error(self): def 
test_multi_dimension_correct(self): measurement = htf.Measurement('measure') measurement.with_dimensions('dimension1', 'dimension2') - dimension_vals = ('dim val 1', 1234,) + dimension_vals = ( + 'dim val 1', + 1234, + ) try: measurement.measured_value[dimension_vals] = 42 except measurements.InvalidDimensionsError: diff --git a/test/core/monitors_test.py b/test/core/monitors_test.py index 8e88b79c5..ecc67f13d 100644 --- a/test/core/monitors_test.py +++ b/test/core/monitors_test.py @@ -12,26 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest import time -import mock +import unittest +import mock from openhtf import plugs +from openhtf.core import base_plugs from openhtf.core import monitors from six.moves import queue -class EmptyPlug(plugs.BasePlug): +class EmptyPlug(base_plugs.BasePlug): pass class TestMonitors(unittest.TestCase): def setUp(self): + super(TestMonitors, self).setUp() self.test_state = mock.MagicMock(execution_uid='01234567890') - def provide_plugs(plugs): - return {name: cls() for name, cls in plugs} + def provide_plugs(plug_map): + return {name: cls() for name, cls in plug_map} + self.test_state.plug_manager.provide_plugs = provide_plugs def test_basics(self): @@ -40,11 +43,13 @@ def test_basics(self): q = queue.Queue() def monitor_func(test): + del test # Unused. q.put(1) return 1 @monitors.monitors('meas', monitor_func, poll_interval_ms=100) def phase(test): + del test # Unused. while q.qsize() < 2: time.sleep(0.1) @@ -59,21 +64,24 @@ def phase(test): # Measurement time is at the end of the monitor func, which can take # upwards of 100 milliseconds depending on how busy the infrastructure is, # so we only check that it's less than a second. 
- self.assertLessEqual(first_meas[0], 100, - msg='At time 0, there should be a call made.') - self.assertEqual(1, first_meas[1], - msg="And it should be the monitor func's return val") + self.assertLessEqual( + first_meas[0], 100, msg='At time 0, there should be a call made.') + self.assertEqual( + 1, first_meas[1], msg="And it should be the monitor func's return val") def testPlugs(self): q = queue.Queue() @plugs.plug(empty=EmptyPlug) def monitor(test, empty): + del test # Unused. + del empty # Unused. q.put(2) return 2 @monitors.monitors('meas', monitor, poll_interval_ms=100) def phase(test): + del test # Unused. while q.qsize() < 2: time.sleep(0.1) @@ -84,9 +92,7 @@ def phase(test): # Measurement time is at the end of the monitor func, which can take # upwards of 100 milliseconds depending on how busy the infrastructure is, # so we only check that it's less than a second. - self.assertLessEqual(first_meas[0], 100, - msg='At time 0, there should be a call made.') - self.assertEqual(2, first_meas[1], - msg="And it should be the monitor func's return val") - - + self.assertLessEqual( + first_meas[0], 100, msg='At time 0, there should be a call made.') + self.assertEqual( + 2, first_meas[1], msg="And it should be the monitor func's return val") diff --git a/test/core/phase_branches_test.py b/test/core/phase_branches_test.py new file mode 100644 index 000000000..c660de16f --- /dev/null +++ b/test/core/phase_branches_test.py @@ -0,0 +1,733 @@ +"""Tests for google3.third_party.py.openhtf.test.core.phase_branches.""" + +import unittest + +import mock + +import openhtf as htf +from openhtf.core import phase_branches +from openhtf.core import phase_executor +from openhtf.core import test_record +from openhtf.util import test as htf_test + + +class BranchDiagResult(htf.DiagResultEnum): + SET = 'set' + NOT_SET = 'not_set' + + +@htf.PhaseDiagnoser(BranchDiagResult) +def branch_diagnoser(phase_rec): + del phase_rec # Unused. 
+ return htf.Diagnosis(BranchDiagResult.SET) + + +@htf.diagnose(branch_diagnoser) +def add_set_diag(): + pass + + +@htf.PhaseOptions() +def run_phase(): + pass + + +@htf.PhaseOptions() +def fail_phase(): + return htf.PhaseResult.FAIL_AND_CONTINUE + + +@htf.PhaseOptions() +def error_phase(): + raise Exception('broken') + + +def _rename(phase, new_name): + assert isinstance(new_name, str) + return htf.PhaseOptions(name=new_name)(phase) + + +def _fake_phases(*new_names): + return [_rename(run_phase, name) for name in new_names] + + +phase0, phase1, phase2, phase3 = _fake_phases('phase0', 'phase1', 'phase2', + 'phase3') +skip0 = _rename(run_phase, 'skip0') + + +class BranchSequenceTest(unittest.TestCase): + + def test_as_dict(self): + branch = phase_branches.BranchSequence( + phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET), + nodes=(run_phase,)) + expected = { + 'name': None, + 'nodes': [run_phase._asdict()], + 'diag_condition': { + 'condition': phase_branches.ConditionOn.ALL, + 'diagnosis_results': [BranchDiagResult.SET], + }, + } + self.assertEqual(expected, branch._asdict()) + + +class BranchSequenceIntegrationTest(htf_test.TestCase): + + def _assert_phase_names(self, expected_names, test_rec): + run_phase_names = [p.name for p in test_rec.phases[1:]] + self.assertEqual(expected_names, run_phase_names) + + @htf_test.yields_phases + def test_branch_taken(self): + nodes = [ + add_set_diag, + phase_branches.BranchSequence( + phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET), + run_phase), + ] + + test_rec = yield htf.Test(nodes) + self.assertTestPass(test_rec) + self._assert_phase_names(['add_set_diag', 'run_phase'], test_rec) + self.assertEqual([ + test_record.BranchRecord( + name=None, + diag_condition=phase_branches.DiagnosisCondition( + condition=phase_branches.ConditionOn.ALL, + diagnosis_results=(BranchDiagResult.SET,)), + branch_taken=True, + evaluated_millis=mock.ANY) + ], test_rec.branches) + + @htf_test.yields_phases + def 
test_branch_not_taken(self): + nodes = [ + phase_branches.BranchSequence( + phase_branches.DiagnosisCondition.on_all(BranchDiagResult.NOT_SET), + run_phase), + ] + + test_rec = yield htf.Test(nodes) + self.assertTestPass(test_rec) + self._assert_phase_names([], test_rec) + self.assertEqual([ + test_record.BranchRecord( + name=None, + diag_condition=phase_branches.DiagnosisCondition( + condition=phase_branches.ConditionOn.ALL, + diagnosis_results=(BranchDiagResult.NOT_SET,)), + branch_taken=False, + evaluated_millis=mock.ANY) + ], test_rec.branches) + + +class PhaseFailureCheckpointIntegrationTest(htf_test.TestCase): + + def test_invalid_action(self): + with self.assertRaises(ValueError): + phase_branches.PhaseFailureCheckpoint.last( + 'bad_action', action=htf.PhaseResult.CONTINUE) + + def test_asdict(self): + checkpoint = phase_branches.PhaseFailureCheckpoint.last('checkpoint') + self.assertEqual( + { + 'name': 'checkpoint', + 'action': htf.PhaseResult.STOP, + 'previous_phases_to_check': phase_branches.PreviousPhases.LAST, + }, checkpoint._asdict()) + + @htf_test.yields_phases + def test_last__no_previous_phases(self): + self.test_start_function = None + test_rec = yield htf.Test( + phase_branches.PhaseFailureCheckpoint.last('last_prev')) + + self.assertTestError(test_rec) + self.assertTestOutcomeCode(test_rec, 'NoPhasesFoundError') + self.assertEqual(0, len(test_rec.phases)) + + self.assertEqual([ + test_record.CheckpointRecord( + name='last_prev', + action=htf.PhaseResult.STOP, + conditional=phase_branches.PreviousPhases.LAST, + subtest_name=None, + result=phase_executor.PhaseExecutionOutcome( + phase_executor.ExceptionInfo(phase_branches.NoPhasesFoundError, + mock.ANY, mock.ANY)), + evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) + + @htf_test.yields_phases + def test_last__no_failures(self): + test_rec = yield htf.Test( + phase0, phase_branches.PhaseFailureCheckpoint.last('last_pass'), phase1) + + self.assertTestPass(test_rec) + 
self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'phase0', 'phase1') + + self.assertEqual([ + test_record.CheckpointRecord( + name='last_pass', + action=htf.PhaseResult.STOP, + conditional=phase_branches.PreviousPhases.LAST, + subtest_name=None, + result=phase_executor.PhaseExecutionOutcome( + htf.PhaseResult.CONTINUE), + evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) + + @htf_test.yields_phases + def test_last__failure_too_early(self): + test_rec = yield htf.Test( + fail_phase, phase0, + phase_branches.PhaseFailureCheckpoint.last('last_early_fail'), phase1) + + self.assertTestFail(test_rec) + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'phase0', 'phase1') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.FAIL, test_rec, + 'fail_phase') + + self.assertEqual([ + test_record.CheckpointRecord( + name='last_early_fail', + action=htf.PhaseResult.STOP, + conditional=phase_branches.PreviousPhases.LAST, + subtest_name=None, + result=phase_executor.PhaseExecutionOutcome( + htf.PhaseResult.CONTINUE), + evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) + + @htf_test.yields_phases + def test_last__failure_too_late(self): + test_rec = yield htf.Test( + phase0, phase_branches.PhaseFailureCheckpoint.last('last_late_fail'), + fail_phase, phase1) + + self.assertTestFail(test_rec) + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'phase0', 'phase1') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.FAIL, test_rec, + 'fail_phase') + + self.assertEqual([ + test_record.CheckpointRecord( + name='last_late_fail', + action=htf.PhaseResult.STOP, + conditional=phase_branches.PreviousPhases.LAST, + subtest_name=None, + result=phase_executor.PhaseExecutionOutcome( + htf.PhaseResult.CONTINUE), + evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) + + @htf_test.yields_phases + def test_last__failure(self): + test_rec = yield htf.Test( + 
phase0, fail_phase, + phase_branches.PhaseFailureCheckpoint.last('last_fail'), error_phase) + + self.assertTestFail(test_rec) + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'phase0') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.FAIL, test_rec, + 'fail_phase') + + self.assertEqual([ + test_record.CheckpointRecord( + name='last_fail', + action=htf.PhaseResult.STOP, + conditional=phase_branches.PreviousPhases.LAST, + subtest_name=None, + result=phase_executor.PhaseExecutionOutcome(htf.PhaseResult.STOP), + evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) + + @htf_test.yields_phases + def test_last_fail_subtest__not_in_subtest(self): + test_rec = yield htf.Test( + fail_phase, + phase_branches.PhaseFailureCheckpoint.last( + 'last_subtest', action=htf.PhaseResult.FAIL_SUBTEST), error_phase) + + self.assertTestError(test_rec) + self.assertTestOutcomeCode(test_rec, 'InvalidPhaseResultError') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.FAIL, test_rec, + 'fail_phase') + + self.assertEqual([ + test_record.CheckpointRecord( + name='last_subtest', + action=htf.PhaseResult.FAIL_SUBTEST, + conditional=phase_branches.PreviousPhases.LAST, + subtest_name=None, + result=phase_executor.PhaseExecutionOutcome( + phase_executor.ExceptionInfo( + phase_executor.InvalidPhaseResultError, mock.ANY, + mock.ANY)), + evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) + + @htf_test.yields_phases + def test_last_fail_subtest__pass_in_subtest(self): + test_rec = yield htf.Test( + phase0, + htf.Subtest( + 'sub', phase1, + phase_branches.PhaseFailureCheckpoint.last( + 'last_pass_subtest', action=htf.PhaseResult.FAIL_SUBTEST), + phase2), phase3) + + self.assertTestPass(test_rec) + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'phase0', 'phase1', 'phase2', 'phase3') + + self.assertEqual([ + test_record.CheckpointRecord( + name='last_pass_subtest', + 
action=htf.PhaseResult.FAIL_SUBTEST, + conditional=phase_branches.PreviousPhases.LAST, + subtest_name='sub', + result=phase_executor.PhaseExecutionOutcome( + htf.PhaseResult.CONTINUE), + evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) + + @htf_test.yields_phases + def test_last_fail_subtest__early_fail_out_of_subtest(self): + test_rec = yield htf.Test( + fail_phase, phase0, + htf.Subtest( + 'sub', phase1, + phase_branches.PhaseFailureCheckpoint.last( + 'last_pass_early_subtest', action=htf.PhaseResult.FAIL_SUBTEST), + phase2), phase3) + + self.assertTestFail(test_rec) + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'phase0', 'phase1', 'phase2', 'phase3') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.FAIL, test_rec, + 'fail_phase') + + self.assertEqual([ + test_record.CheckpointRecord( + name='last_pass_early_subtest', + action=htf.PhaseResult.FAIL_SUBTEST, + conditional=phase_branches.PreviousPhases.LAST, + subtest_name='sub', + result=phase_executor.PhaseExecutionOutcome( + htf.PhaseResult.CONTINUE), + evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) + + @htf_test.yields_phases + def test_last_fail_subtest__early_fail_in_subtest(self): + test_rec = yield htf.Test( + phase0, + htf.Subtest( + 'sub', fail_phase, phase1, + phase_branches.PhaseFailureCheckpoint.last( + 'last_fail_subtest', action=htf.PhaseResult.FAIL_SUBTEST), + phase2), phase3) + + self.assertTestFail(test_rec) + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'phase0', 'phase1', 'phase2', 'phase3') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.FAIL, test_rec, + 'fail_phase') + + self.assertEqual([ + test_record.CheckpointRecord( + name='last_fail_subtest', + action=htf.PhaseResult.FAIL_SUBTEST, + conditional=phase_branches.PreviousPhases.LAST, + subtest_name='sub', + result=phase_executor.PhaseExecutionOutcome( + htf.PhaseResult.CONTINUE), + 
evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) + + @htf_test.yields_phases + def test_last_fail_subtest__fail_in_subtest(self): + test_rec = yield htf.Test( + phase0, + htf.Subtest( + 'sub', phase1, fail_phase, + phase_branches.PhaseFailureCheckpoint.last( + 'last_fail_subtest', action=htf.PhaseResult.FAIL_SUBTEST), + skip0), phase2) + + self.assertTestFail(test_rec) + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'phase0', 'phase1', 'phase2') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.SKIP, test_rec, + 'skip0') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.FAIL, test_rec, + 'fail_phase') + + self.assertEqual([ + test_record.CheckpointRecord( + name='last_fail_subtest', + action=htf.PhaseResult.FAIL_SUBTEST, + conditional=phase_branches.PreviousPhases.LAST, + subtest_name='sub', + result=phase_executor.PhaseExecutionOutcome( + htf.PhaseResult.FAIL_SUBTEST), + evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) + + @htf_test.yields_phases + def test_all__no_previous_phases(self): + self.test_start_function = None + test_rec = yield htf.Test( + phase_branches.PhaseFailureCheckpoint.all_previous('all_prev')) + + self.assertTestError(test_rec) + self.assertTestOutcomeCode(test_rec, 'NoPhasesFoundError') + + self.assertEqual([ + test_record.CheckpointRecord( + name='all_prev', + action=htf.PhaseResult.STOP, + conditional=phase_branches.PreviousPhases.ALL, + subtest_name=None, + result=phase_executor.PhaseExecutionOutcome( + phase_executor.ExceptionInfo(phase_branches.NoPhasesFoundError, + mock.ANY, mock.ANY)), + evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) + + @htf_test.yields_phases + def test_all__pass(self): + test_rec = yield htf.Test( + phase0, phase_branches.PhaseFailureCheckpoint.all_previous('all_pass'), + phase1) + + self.assertTestPass(test_rec) + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'phase0', 
'phase1') + + self.assertEqual([ + test_record.CheckpointRecord( + name='all_pass', + action=htf.PhaseResult.STOP, + conditional=phase_branches.PreviousPhases.ALL, + subtest_name=None, + result=phase_executor.PhaseExecutionOutcome( + htf.PhaseResult.CONTINUE), + evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) + + @htf_test.yields_phases + def test_all__fail(self): + test_rec = yield htf.Test( + fail_phase, + phase_branches.PhaseFailureCheckpoint.all_previous('all_fail'), + error_phase) + + self.assertTestFail(test_rec) + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.FAIL, test_rec, + 'fail_phase') + + self.assertEqual([ + test_record.CheckpointRecord( + name='all_fail', + action=htf.PhaseResult.STOP, + conditional=phase_branches.PreviousPhases.ALL, + subtest_name=None, + result=phase_executor.PhaseExecutionOutcome(htf.PhaseResult.STOP), + evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) + + @htf_test.yields_phases + def test_all__earlier_fail(self): + test_rec = yield htf.Test( + fail_phase, phase0, + phase_branches.PhaseFailureCheckpoint.all_previous('all_earlier_fail'), + error_phase) + + self.assertTestFail(test_rec) + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'phase0') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.FAIL, test_rec, + 'fail_phase') + + self.assertEqual([ + test_record.CheckpointRecord( + name='all_earlier_fail', + action=htf.PhaseResult.STOP, + conditional=phase_branches.PreviousPhases.ALL, + subtest_name=None, + result=phase_executor.PhaseExecutionOutcome(htf.PhaseResult.STOP), + evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) + + @htf_test.yields_phases + def test_all_fail_subtest__not_in_subtest(self): + test_rec = yield htf.Test( + fail_phase, + phase_branches.PhaseFailureCheckpoint.all_previous( + 'all_subtest', action=htf.PhaseResult.FAIL_SUBTEST), error_phase) + + self.assertTestError(test_rec) + 
self.assertTestOutcomeCode(test_rec, 'InvalidPhaseResultError') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.FAIL, test_rec, + 'fail_phase') + + self.assertEqual([ + test_record.CheckpointRecord( + name='all_subtest', + action=htf.PhaseResult.FAIL_SUBTEST, + conditional=phase_branches.PreviousPhases.ALL, + subtest_name=None, + result=phase_executor.PhaseExecutionOutcome( + phase_executor.ExceptionInfo( + phase_executor.InvalidPhaseResultError, mock.ANY, + mock.ANY)), + evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) + + @htf_test.yields_phases + def test_all_fail_subtest__pass_in_subtest(self): + test_rec = yield htf.Test( + phase0, + htf.Subtest( + 'sub', phase1, + phase_branches.PhaseFailureCheckpoint.all_previous( + 'all_pass_subtest', action=htf.PhaseResult.FAIL_SUBTEST), + phase2), phase3) + + self.assertTestPass(test_rec) + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'phase0', 'phase1', 'phase2', 'phase3') + + self.assertEqual([ + test_record.CheckpointRecord( + name='all_pass_subtest', + action=htf.PhaseResult.FAIL_SUBTEST, + conditional=phase_branches.PreviousPhases.ALL, + subtest_name='sub', + result=phase_executor.PhaseExecutionOutcome( + htf.PhaseResult.CONTINUE), + evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) + + @htf_test.yields_phases + def test_all_fail_subtest__early_fail_out_of_subtest(self): + test_rec = yield htf.Test( + fail_phase, phase0, + htf.Subtest( + 'sub', phase1, + phase_branches.PhaseFailureCheckpoint.all_previous( + 'all_fail_early_subtest', action=htf.PhaseResult.FAIL_SUBTEST), + skip0), phase2) + + self.assertTestFail(test_rec) + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'phase0', 'phase1', 'phase2') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.SKIP, test_rec, + 'skip0') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.FAIL, test_rec, + 'fail_phase') + + self.assertEqual([ + 
test_record.CheckpointRecord( + name='all_fail_early_subtest', + action=htf.PhaseResult.FAIL_SUBTEST, + conditional=phase_branches.PreviousPhases.ALL, + subtest_name='sub', + result=phase_executor.PhaseExecutionOutcome( + htf.PhaseResult.FAIL_SUBTEST), + evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) + + @htf_test.yields_phases + def test_all_fail_subtest__early_fail_in_subtest(self): + test_rec = yield htf.Test( + phase0, + htf.Subtest( + 'sub', fail_phase, phase1, + phase_branches.PhaseFailureCheckpoint.all_previous( + 'all_fail_subtest', action=htf.PhaseResult.FAIL_SUBTEST), + skip0), phase2) + + self.assertTestFail(test_rec) + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'phase0', 'phase1', 'phase2') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.SKIP, test_rec, + 'skip0') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.FAIL, test_rec, + 'fail_phase') + + self.assertEqual([ + test_record.CheckpointRecord( + name='all_fail_subtest', + action=htf.PhaseResult.FAIL_SUBTEST, + conditional=phase_branches.PreviousPhases.ALL, + subtest_name='sub', + result=phase_executor.PhaseExecutionOutcome( + htf.PhaseResult.FAIL_SUBTEST), + evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) + + @htf_test.yields_phases + def test_all_fail_subtest__fail_in_subtest(self): + test_rec = yield htf.Test( + phase0, + htf.Subtest( + 'sub', phase1, fail_phase, + phase_branches.PhaseFailureCheckpoint.all_previous( + 'all_fail_subtest', action=htf.PhaseResult.FAIL_SUBTEST), + skip0), phase2) + + self.assertTestFail(test_rec) + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'phase0', 'phase1', 'phase2') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.SKIP, test_rec, + 'skip0') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.FAIL, test_rec, + 'fail_phase') + + self.assertEqual([ + test_record.CheckpointRecord( + name='all_fail_subtest', + 
action=htf.PhaseResult.FAIL_SUBTEST, + conditional=phase_branches.PreviousPhases.ALL, + subtest_name='sub', + result=phase_executor.PhaseExecutionOutcome( + htf.PhaseResult.FAIL_SUBTEST), + evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) + + +class DiagnosisCheckpointIntegrationTest(htf_test.TestCase): + + def test_asdict(self): + checkpoint = phase_branches.DiagnosisCheckpoint( + 'checkpoint', + phase_branches.DiagnosisCondition.on_any(BranchDiagResult.SET), + action=htf.PhaseResult.FAIL_SUBTEST) + self.assertEqual( + { + 'name': 'checkpoint', + 'action': htf.PhaseResult.FAIL_SUBTEST, + 'diag_condition': { + 'condition': phase_branches.ConditionOn.ANY, + 'diagnosis_results': [BranchDiagResult.SET], + }, + }, checkpoint._asdict()) + + @htf_test.yields_phases + def test_pass(self): + test_rec = yield htf.Test( + phase0, + phase_branches.DiagnosisCheckpoint( + 'diag_pass', + phase_branches.DiagnosisCondition.on_all(BranchDiagResult.NOT_SET)), + phase1) + + self.assertTestPass(test_rec) + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'phase0', 'phase1') + + self.assertEqual([ + test_record.CheckpointRecord( + name='diag_pass', + action=htf.PhaseResult.STOP, + conditional=phase_branches.DiagnosisCondition( + phase_branches.ConditionOn.ALL, (BranchDiagResult.NOT_SET,)), + subtest_name=None, + result=phase_executor.PhaseExecutionOutcome( + htf.PhaseResult.CONTINUE), + evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) + + @htf_test.yields_phases + def test_fail(self): + test_rec = yield htf.Test( + add_set_diag, + phase_branches.DiagnosisCheckpoint( + 'diag_fail', + phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET)), + error_phase) + + self.assertTestFail(test_rec) + self.assertPhasesOutcomeByName( + test_record.PhaseOutcome.PASS, + test_rec, + 'add_set_diag', + ) + + self.assertEqual([ + test_record.CheckpointRecord( + name='diag_fail', + action=htf.PhaseResult.STOP, + 
conditional=phase_branches.DiagnosisCondition( + phase_branches.ConditionOn.ALL, (BranchDiagResult.SET,)), + subtest_name=None, + result=phase_executor.PhaseExecutionOutcome(htf.PhaseResult.STOP), + evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) + + @htf_test.yields_phases + def test_subtest_pass(self): + test_rec = yield htf.Test( + phase0, + htf.Subtest( + 'subtest', phase1, + phase_branches.DiagnosisCheckpoint( + 'diag_subtest_pass', + phase_branches.DiagnosisCondition.on_all( + BranchDiagResult.NOT_SET), + action=htf.PhaseResult.FAIL_SUBTEST), phase2), phase3) + + self.assertTestPass(test_rec) + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'phase0', 'phase1', 'phase2', 'phase1') + + self.assertEqual([ + test_record.CheckpointRecord( + name='diag_subtest_pass', + action=htf.PhaseResult.FAIL_SUBTEST, + conditional=phase_branches.DiagnosisCondition( + phase_branches.ConditionOn.ALL, (BranchDiagResult.NOT_SET,)), + subtest_name='subtest', + result=phase_executor.PhaseExecutionOutcome( + htf.PhaseResult.CONTINUE), + evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) + + @htf_test.yields_phases + def test_subtest_fail(self): + test_rec = yield htf.Test( + add_set_diag, + htf.Subtest( + 'subtest', phase0, + phase_branches.DiagnosisCheckpoint( + 'diag_subtest_pass', + phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET), + action=htf.PhaseResult.FAIL_SUBTEST), skip0), phase1) + + self.assertTestFail(test_rec) + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'add_set_diag', 'phase0', 'phase1') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.SKIP, test_rec, + 'skip0') + + self.assertEqual([ + test_record.CheckpointRecord( + name='diag_subtest_pass', + action=htf.PhaseResult.FAIL_SUBTEST, + conditional=phase_branches.DiagnosisCondition( + phase_branches.ConditionOn.ALL, (BranchDiagResult.SET,)), + subtest_name='subtest', + 
result=phase_executor.PhaseExecutionOutcome( + htf.PhaseResult.FAIL_SUBTEST), + evaluated_millis=htf_test.VALID_TIMESTAMP), + ], test_rec.checkpoints) diff --git a/test/core/phase_collections_test.py b/test/core/phase_collections_test.py new file mode 100644 index 000000000..b7f6f7d91 --- /dev/null +++ b/test/core/phase_collections_test.py @@ -0,0 +1,763 @@ +"""Unit tests for the phase collections library.""" + +import unittest + +import mock +import openhtf as htf +from openhtf import plugs +from openhtf.core import base_plugs +from openhtf.core import phase_collections +from openhtf.core import phase_descriptor +from openhtf.core import phase_executor +from openhtf.core import phase_group +from openhtf.core import phase_nodes +from openhtf.core import test_record +from openhtf.util import test as htf_test + + +def _create_node(name): + return htf_test.PhaseNodeNameComparable(name) + + +def _create_nodes(*names): + return [_create_node(n) for n in names] + + +def _prefix_name(p): + return phase_descriptor.PhaseOptions(name='prefix:' + p.name)(p) + + +def phase(): + pass + + +def fail_subtest_phase(): + return phase_descriptor.PhaseResult.FAIL_SUBTEST + + +class BrokenError(Exception): + pass + + +def error_phase(): + raise BrokenError('broken') + + +def teardown_phase(): + pass + + +teardown_group = phase_group.PhaseGroup(teardown=teardown_phase) + + +@phase_descriptor.PhaseOptions() +def empty_phase(): + pass + + +@phase_descriptor.PhaseOptions() +def skip_phase(): + pass + + +@phase_descriptor.PhaseOptions() +def skip_phase0(): + pass + + +@phase_descriptor.PhaseOptions() +def skip_phase1(): + pass + + +@phase_descriptor.PhaseOptions() +def phase_with_args(arg1=None): + del arg1 + + +class ParentPlug(base_plugs.BasePlug): + pass + + +class ChildPlug(ParentPlug): + pass + + +@plugs.plug(my_plug=ParentPlug.placeholder) +def plug_phase(my_plug): + del my_plug # Unused. 
+ + +class FlattenTest(unittest.TestCase): + + def test_single_node(self): + node = _create_node('a') + expected = _create_nodes('a') + + self.assertEqual(expected, phase_collections.flatten(node)) + + def test_iterable_flat(self): + node1 = _create_node('1') + node2 = _create_node('2') + node3 = _create_node('3') + expected = _create_nodes('1', '2', '3') + + self.assertEqual(expected, phase_collections.flatten([node1, node2, node3])) + + def test_single_phase(self): + expected = _create_nodes('phase') + + self.assertEqual(expected, phase_collections.flatten(phase)) + + def test_iterable_of_iterable(self): + nodes = [[_create_node('1')], + [[_create_node('2'), _create_node('3')], [_create_node('4')], + _create_node('5')], + _create_node('6'), phase] + expected = _create_nodes('1', '2', '3', '4', '5', '6', 'phase') + + self.assertEqual(expected, phase_collections.flatten(nodes)) + + def test_invalid_entry(self): + nodes = 42 + + with self.assertRaises(ValueError): + phase_collections.flatten(nodes) + + def test_flatten_single_list(self): + seq = htf.PhaseSequence(_create_nodes('1', '2')) + expected = [htf.PhaseSequence(_create_nodes('1', '2'))] + + self.assertEqual(expected, phase_collections.flatten([seq])) + + +class PhaseSequenceTest(unittest.TestCase): + + def test_init__nodes_and_args(self): + with self.assertRaises(ValueError): + phase_collections.PhaseSequence(phase, nodes=tuple(_create_nodes('1'))) + + def test_init__extra_kwargs(self): + with self.assertRaises(ValueError): + phase_collections.PhaseSequence(other=1) + + def test_init__single_callable(self): + expected = phase_collections.PhaseSequence( + nodes=tuple((phase_descriptor.PhaseDescriptor.wrap_or_copy(phase),))) + + self.assertEqual(expected, phase_collections.PhaseSequence(phase)) + + def test_asdict(self): + expected = { + 'name': 'sequence_name', + 'nodes': [{ + 'name': '1' + }, { + 'name': '2' + }], + } + seq = phase_collections.PhaseSequence( + _create_nodes('1', '2'), name='sequence_name') + 
+ self.assertEqual(expected, seq._asdict()) + + def test_with_args(self): + mock_node = mock.create_autospec(phase_nodes.PhaseNode) + seq = phase_collections.PhaseSequence( + nodes=(empty_phase, phase_with_args, mock_node), name='seq') + + updated = seq.with_args(arg1=1, ignored_arg=2) + self.assertEqual(seq.name, updated.name) + self.assertEqual(empty_phase, updated.nodes[0]) + self.assertEqual(phase_with_args.with_args(arg1=1), updated.nodes[1]) + self.assertEqual(mock_node.with_args.return_value, updated.nodes[2]) + mock_node.with_args.assert_called_once_with(arg1=1, ignored_arg=2) + + def test_with_plugs(self): + mock_node = mock.create_autospec(phase_nodes.PhaseNode) + seq = phase_collections.PhaseSequence( + nodes=(empty_phase, plug_phase, mock_node), name='seq') + + updated = seq.with_plugs(my_plug=ChildPlug, ignored_plug=ParentPlug) + self.assertEqual(seq.name, updated.name) + self.assertEqual(empty_phase, updated.nodes[0]) + self.assertEqual(plug_phase.with_plugs(my_plug=ChildPlug), updated.nodes[1]) + self.assertEqual(mock_node.with_plugs.return_value, updated.nodes[2]) + mock_node.with_plugs.assert_called_once_with( + my_plug=ChildPlug, ignored_plug=ParentPlug) + + def test_load_code_info(self): + mock_node = mock.create_autospec(phase_nodes.PhaseNode) + seq = phase_collections.PhaseSequence( + nodes=(empty_phase, plug_phase, mock_node), name='seq') + + updated = seq.load_code_info() + self.assertEqual(seq.name, updated.name) + phases = list(updated.all_phases()) + self.assertEqual( + test_record.CodeInfo.for_function(empty_phase.func), + phases[0].code_info) + self.assertEqual( + test_record.CodeInfo.for_function(plug_phase.func), phases[1].code_info) + self.assertEqual(mock_node.load_code_info.return_value, updated.nodes[2]) + mock_node.load_code_info.assert_called_once_with() + + def test_apply_to_all_phases(self): + mock_node = mock.create_autospec(phase_nodes.PhaseNode) + seq = phase_collections.PhaseSequence( + nodes=(empty_phase, plug_phase, 
mock_node), name='seq') + + updated = seq.apply_to_all_phases(_prefix_name) + self.assertEqual(seq.name, updated.name) + self.assertEqual(_prefix_name(empty_phase), updated.nodes[0]) + self.assertEqual(_prefix_name(plug_phase), updated.nodes[1]) + self.assertEqual(mock_node.apply_to_all_phases.return_value, + updated.nodes[2]) + mock_node.apply_to_all_phases.assert_called_once_with(_prefix_name) + + def test_all_phases(self): + mock_node = mock.create_autospec(phase_nodes.PhaseNode) + seq = phase_collections.PhaseSequence( + nodes=(empty_phase, plug_phase, mock_node), name='seq') + self.assertEqual([empty_phase, plug_phase], list(seq.all_phases())) + + +class PhaseSequenceIntegrationTest(htf_test.TestCase): + + @htf_test.yields_phases + def test_nested(self): + seq = phase_collections.PhaseSequence( + phase_collections.PhaseSequence(phase, empty_phase)) + + test_rec = yield htf.Test(seq) + + self.assertTestPass(test_rec) + + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'phase', 'empty_phase') + + +class SubtestTest(unittest.TestCase): + + def test_init__name(self): + subtest = phase_collections.Subtest('subtest', phase) + self.assertEqual('subtest', subtest.name) + + def test_check_duplicates__dupes(self): + seq = phase_collections.PhaseSequence( + nodes=(phase_collections.Subtest('dupe'), + phase_collections.Subtest('dupe'))) + with self.assertRaises(phase_collections.DuplicateSubtestNamesError): + phase_collections.check_for_duplicate_subtest_names(seq) + + def test_check_duplicates__nested_dupes(self): + seq = phase_collections.PhaseSequence( + nodes=(phase_collections.Subtest( + 'dupe', nodes=(phase_collections.Subtest('dupe'),)),)) + with self.assertRaises(phase_collections.DuplicateSubtestNamesError): + phase_collections.check_for_duplicate_subtest_names(seq) + + +class SubtestIntegrationTest(htf_test.TestCase): + + @htf_test.yields_phases + def test_pass(self): + subtest = phase_collections.Subtest('subtest', phase) + + test_rec 
= yield htf.Test(subtest) + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'phase') + + self.assertTestPass(test_rec) + self.assertEqual([ + test_record.SubtestRecord( + name='subtest', + start_time_millis=htf_test.VALID_TIMESTAMP, + end_time_millis=htf_test.VALID_TIMESTAMP, + outcome=test_record.SubtestOutcome.PASS), + ], test_rec.subtests) + self.assertEqual('subtest', test_rec.phases[-1].subtest_name) + + @htf_test.yields_phases + def test_fail_but_still_continues(self): + subtest = phase_collections.Subtest('failure', fail_subtest_phase, + skip_phase) + + test_rec = yield htf.Test(subtest, phase) + + self.assertTestFail(test_rec) + + fail_phase_rec = test_rec.phases[1] + self.assertPhaseOutcomeFail(fail_phase_rec) + self.assertPhaseFailSubtest(fail_phase_rec) + self.assertEqual('failure', fail_phase_rec.subtest_name) + + skip_phase_rec = test_rec.phases[2] + self.assertPhaseOutcomeSkip(skip_phase_rec) + self.assertPhaseSkip(skip_phase_rec) + self.assertEqual('failure', skip_phase_rec.subtest_name) + + continue_phase_rec = test_rec.phases[3] + self.assertPhaseOutcomePass(continue_phase_rec) + self.assertPhaseContinue(continue_phase_rec) + self.assertIsNone(continue_phase_rec.subtest_name) + + self.assertEqual([ + test_record.SubtestRecord( + name='failure', + start_time_millis=htf_test.VALID_TIMESTAMP, + end_time_millis=htf_test.VALID_TIMESTAMP, + outcome=test_record.SubtestOutcome.FAIL), + ], test_rec.subtests) + + @htf_test.yields_phases + def test_error(self): + subtest = phase_collections.Subtest('subtest', error_phase) + + test_rec = yield htf.Test(subtest, phase) + + self.assertTestError(test_rec) + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'phase') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.ERROR, test_rec, + 'error_phase') + self.assertPhasesNotRun(test_rec, 'phase') + + error_phase_rec = test_rec.phases[1] + self.assertPhaseError(error_phase_rec, exc_type=BrokenError) + + 
self.assertEqual([ + test_record.SubtestRecord( + name='subtest', + start_time_millis=htf_test.VALID_TIMESTAMP, + end_time_millis=htf_test.VALID_TIMESTAMP, + outcome=test_record.SubtestOutcome.STOP), + ], test_rec.subtests) + + @htf_test.yields_phases + def test_pass__with_group(self): + subtest = phase_collections.Subtest('subtest', teardown_group.wrap(phase)) + + test_rec = yield htf.Test(subtest) + + self.assertTestPass(test_rec) + + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'phase', 'teardown_phase') + + self.assertEqual([ + test_record.SubtestRecord( + name='subtest', + start_time_millis=htf_test.VALID_TIMESTAMP, + end_time_millis=htf_test.VALID_TIMESTAMP, + outcome=test_record.SubtestOutcome.PASS), + ], test_rec.subtests) + + @htf_test.yields_phases + def test_fail__with_group(self): + subtest = phase_collections.Subtest('it_fails', + teardown_group.wrap(fail_subtest_phase), + skip_phase) + + test_rec = yield htf.Test(subtest, phase) + + self.assertTestFail(test_rec) + + fail_phase_rec = test_rec.phases[1] + self.assertEqual('fail_subtest_phase', fail_phase_rec.name) + self.assertPhaseOutcomeFail(fail_phase_rec) + self.assertPhaseFailSubtest(fail_phase_rec) + self.assertEqual('it_fails', fail_phase_rec.subtest_name) + + teardown_phase_rec = test_rec.phases[2] + self.assertEqual('teardown_phase', teardown_phase_rec.name) + self.assertPhaseContinue(teardown_phase_rec) + self.assertPhaseOutcomePass(teardown_phase_rec) + self.assertEqual('it_fails', teardown_phase_rec.subtest_name) + + skip_phase_rec = test_rec.phases[3] + self.assertEqual('skip_phase', skip_phase_rec.name) + self.assertPhaseSkip(skip_phase_rec) + self.assertPhaseOutcomeSkip(skip_phase_rec) + self.assertEqual('it_fails', skip_phase_rec.subtest_name) + + continue_phase_rec = test_rec.phases[4] + self.assertEqual('phase', continue_phase_rec.name) + self.assertPhaseOutcomePass(continue_phase_rec) + self.assertPhaseContinue(continue_phase_rec) + 
self.assertIsNone((continue_phase_rec.subtest_name)) + + self.assertEqual([ + test_record.SubtestRecord( + name='it_fails', + start_time_millis=htf_test.VALID_TIMESTAMP, + end_time_millis=htf_test.VALID_TIMESTAMP, + outcome=test_record.SubtestOutcome.FAIL), + ], test_rec.subtests) + + @htf_test.yields_phases + def test_fail__with_nested_group_skipped(self): + subtest = phase_collections.Subtest( + 'it_fails', fail_subtest_phase, + htf.PhaseGroup(main=[skip_phase0], teardown=[skip_phase1]), skip_phase) + + test_rec = yield htf.Test(subtest, phase) + + self.assertTestFail(test_rec) + + fail_phase_rec = test_rec.phases[1] + self.assertEqual('fail_subtest_phase', fail_phase_rec.name) + self.assertPhaseOutcomeFail(fail_phase_rec) + self.assertPhaseFailSubtest(fail_phase_rec) + self.assertEqual('it_fails', fail_phase_rec.subtest_name) + + skip_phase0_rec = test_rec.phases[2] + self.assertEqual('skip_phase0', skip_phase0_rec.name) + self.assertPhaseSkip(skip_phase0_rec) + self.assertPhaseOutcomeSkip(skip_phase0_rec) + self.assertEqual('it_fails', skip_phase0_rec.subtest_name) + + skip_phase1_rec = test_rec.phases[3] + self.assertEqual('skip_phase1', skip_phase1_rec.name) + self.assertPhaseSkip(skip_phase1_rec) + self.assertPhaseOutcomeSkip(skip_phase1_rec) + self.assertEqual('it_fails', skip_phase1_rec.subtest_name) + + skip_phase_rec = test_rec.phases[4] + self.assertEqual('skip_phase', skip_phase_rec.name) + self.assertPhaseSkip(skip_phase_rec) + self.assertPhaseOutcomeSkip(skip_phase_rec) + self.assertEqual('it_fails', skip_phase_rec.subtest_name) + + continue_phase_rec = test_rec.phases[5] + self.assertEqual('phase', continue_phase_rec.name) + self.assertPhaseOutcomePass(continue_phase_rec) + self.assertPhaseContinue(continue_phase_rec) + self.assertIsNone((continue_phase_rec.subtest_name)) + + self.assertEqual([ + test_record.SubtestRecord( + name='it_fails', + start_time_millis=htf_test.VALID_TIMESTAMP, + end_time_millis=htf_test.VALID_TIMESTAMP, + 
outcome=test_record.SubtestOutcome.FAIL), + ], test_rec.subtests) + + @htf_test.yields_phases + def test_fail__with_nested_group_fail_in_setup(self): + subtest = phase_collections.Subtest( + 'it_fails', + htf.PhaseGroup( + setup=[fail_subtest_phase], + main=[skip_phase0], + teardown=[skip_phase1]), skip_phase) + + test_rec = yield htf.Test(subtest, phase) + + self.assertTestFail(test_rec) + + fail_phase_rec = test_rec.phases[1] + self.assertEqual('fail_subtest_phase', fail_phase_rec.name) + self.assertPhaseOutcomeFail(fail_phase_rec) + self.assertPhaseFailSubtest(fail_phase_rec) + self.assertEqual('it_fails', fail_phase_rec.subtest_name) + + skip_phase0_rec = test_rec.phases[2] + self.assertEqual('skip_phase0', skip_phase0_rec.name) + self.assertPhaseSkip(skip_phase0_rec) + self.assertPhaseOutcomeSkip(skip_phase0_rec) + self.assertEqual('it_fails', skip_phase0_rec.subtest_name) + + skip_phase1_rec = test_rec.phases[3] + self.assertEqual('skip_phase1', skip_phase1_rec.name) + self.assertPhaseSkip(skip_phase1_rec) + self.assertPhaseOutcomeSkip(skip_phase1_rec) + self.assertEqual('it_fails', skip_phase1_rec.subtest_name) + + skip_phase_rec = test_rec.phases[4] + self.assertEqual('skip_phase', skip_phase_rec.name) + self.assertPhaseSkip(skip_phase_rec) + self.assertPhaseOutcomeSkip(skip_phase_rec) + self.assertEqual('it_fails', skip_phase_rec.subtest_name) + + continue_phase_rec = test_rec.phases[5] + self.assertEqual('phase', continue_phase_rec.name) + self.assertPhaseOutcomePass(continue_phase_rec) + self.assertPhaseContinue(continue_phase_rec) + self.assertIsNone((continue_phase_rec.subtest_name)) + + self.assertEqual([ + test_record.SubtestRecord( + name='it_fails', + start_time_millis=htf_test.VALID_TIMESTAMP, + end_time_millis=htf_test.VALID_TIMESTAMP, + outcome=test_record.SubtestOutcome.FAIL), + ], test_rec.subtests) + + @htf_test.yields_phases + def test_fail__with_nested_group_fail_in_teardown(self): + subtest = phase_collections.Subtest( + 'it_fails', + 
htf.PhaseGroup( + main=[empty_phase], teardown=[fail_subtest_phase, teardown_phase]), + skip_phase) + + test_rec = yield htf.Test(subtest, phase) + + self.assertTestFail(test_rec) + + empty_phase_rec = test_rec.phases[1] + self.assertEqual('empty_phase', empty_phase_rec.name) + self.assertPhaseOutcomePass(empty_phase_rec) + self.assertPhaseContinue(empty_phase_rec) + self.assertEqual('it_fails', empty_phase_rec.subtest_name) + + fail_phase_rec = test_rec.phases[2] + self.assertEqual('fail_subtest_phase', fail_phase_rec.name) + self.assertPhaseOutcomeFail(fail_phase_rec) + self.assertPhaseFailSubtest(fail_phase_rec) + self.assertEqual('it_fails', fail_phase_rec.subtest_name) + + teardown_phase_rec = test_rec.phases[3] + self.assertEqual('teardown_phase', teardown_phase_rec.name) + self.assertPhaseContinue(teardown_phase_rec) + self.assertPhaseOutcomePass(teardown_phase_rec) + self.assertEqual('it_fails', teardown_phase_rec.subtest_name) + + skip_phase_rec = test_rec.phases[4] + self.assertEqual('skip_phase', skip_phase_rec.name) + self.assertPhaseSkip(skip_phase_rec) + self.assertPhaseOutcomeSkip(skip_phase_rec) + self.assertEqual('it_fails', skip_phase_rec.subtest_name) + + continue_phase_rec = test_rec.phases[5] + self.assertEqual('phase', continue_phase_rec.name) + self.assertPhaseOutcomePass(continue_phase_rec) + self.assertPhaseContinue(continue_phase_rec) + self.assertIsNone((continue_phase_rec.subtest_name)) + + self.assertEqual([ + test_record.SubtestRecord( + name='it_fails', + start_time_millis=htf_test.VALID_TIMESTAMP, + end_time_millis=htf_test.VALID_TIMESTAMP, + outcome=test_record.SubtestOutcome.FAIL), + ], test_rec.subtests) + + @htf_test.yields_phases + def test_error__with_group(self): + subtest = phase_collections.Subtest('it_errors', + teardown_group.wrap(error_phase)) + + test_rec = yield htf.Test(subtest, phase) + + self.assertTestError(test_rec) + + error_phase_rec = test_rec.phases[1] + self.assertEqual('error_phase', error_phase_rec.name) + 
self.assertPhaseOutcomeError(error_phase_rec) + self.assertPhaseError(error_phase_rec, exc_type=BrokenError) + self.assertEqual('it_errors', error_phase_rec.subtest_name) + + teardown_phase_rec = test_rec.phases[2] + self.assertEqual('teardown_phase', teardown_phase_rec.name) + self.assertPhaseContinue(teardown_phase_rec) + self.assertPhaseOutcomePass(teardown_phase_rec) + self.assertEqual('it_errors', teardown_phase_rec.subtest_name) + + self.assertEqual([ + test_record.SubtestRecord( + name='it_errors', + start_time_millis=htf_test.VALID_TIMESTAMP, + end_time_millis=htf_test.VALID_TIMESTAMP, + outcome=test_record.SubtestOutcome.STOP), + ], test_rec.subtests) + + @htf_test.yields_phases + def test_nested__pass(self): + subtest = phase_collections.Subtest( + 'outer', phase, phase_collections.Subtest('inner', phase)) + + test_rec = yield htf.Test(subtest) + + self.assertTestPass(test_rec) + + outer_phase_rec = test_rec.phases[1] + self.assertEqual('outer', outer_phase_rec.subtest_name) + + inner_phase_rec = test_rec.phases[2] + self.assertEqual('inner', inner_phase_rec.subtest_name) + + self.assertEqual([ + test_record.SubtestRecord( + name='inner', + start_time_millis=htf_test.VALID_TIMESTAMP, + end_time_millis=htf_test.VALID_TIMESTAMP, + outcome=test_record.SubtestOutcome.PASS), + test_record.SubtestRecord( + name='outer', + start_time_millis=htf_test.VALID_TIMESTAMP, + end_time_millis=htf_test.VALID_TIMESTAMP, + outcome=test_record.SubtestOutcome.PASS), + ], test_rec.subtests) + + @htf_test.yields_phases + def test_nested__fail(self): + subtest = phase_collections.Subtest( + 'outer', phase, + phase_collections.Subtest('inner', fail_subtest_phase, skip_phase), + empty_phase) + + test_rec = yield htf.Test(subtest) + + self.assertTestFail(test_rec) + + outer_phase_rec = test_rec.phases[1] + self.assertEqual('phase', outer_phase_rec.name) + self.assertEqual('outer', outer_phase_rec.subtest_name) + self.assertPhaseOutcomePass(outer_phase_rec) + + inner_phase_rec = 
test_rec.phases[2] + self.assertEqual('fail_subtest_phase', inner_phase_rec.name) + self.assertEqual('inner', inner_phase_rec.subtest_name) + self.assertPhaseOutcomeFail(inner_phase_rec) + + skip_phase_rec = test_rec.phases[3] + self.assertEqual('skip_phase', skip_phase_rec.name) + self.assertEqual('inner', skip_phase_rec.subtest_name) + self.assertPhaseOutcomeSkip(skip_phase_rec) + + outer_phase2_rec = test_rec.phases[4] + self.assertEqual('empty_phase', outer_phase2_rec.name) + self.assertEqual('outer', outer_phase2_rec.subtest_name) + self.assertPhaseOutcomePass(outer_phase2_rec) + + self.assertEqual([ + test_record.SubtestRecord( + name='inner', + start_time_millis=htf_test.VALID_TIMESTAMP, + end_time_millis=htf_test.VALID_TIMESTAMP, + outcome=test_record.SubtestOutcome.FAIL), + test_record.SubtestRecord( + name='outer', + start_time_millis=htf_test.VALID_TIMESTAMP, + end_time_millis=htf_test.VALID_TIMESTAMP, + outcome=test_record.SubtestOutcome.PASS), + ], test_rec.subtests) + + @htf_test.yields_phases + def test_fail_subtest__not_in_subtest(self): + test_rec = yield htf.Test(fail_subtest_phase, phase) + + self.assertTestError( + test_rec, exc_type=phase_executor.InvalidPhaseResultError) + + fail_phase_rec = test_rec.phases[1] + self.assertPhaseError( + fail_phase_rec, exc_type=phase_executor.InvalidPhaseResultError) + self.assertPhaseOutcomeError(fail_phase_rec) + self.assertIsNone(fail_phase_rec.subtest_name) + + @htf_test.yields_phases + def test_fail_subtest__nested_subtest_also_skipped(self): + subtest = phase_collections.Subtest( + 'outer', fail_subtest_phase, skip_phase0, + phase_collections.Subtest('inner', skip_phase), skip_phase1) + + test_rec = yield htf.Test(subtest, phase) + + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.FAIL, test_rec, + 'fail_subtest_phase') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.SKIP, test_rec, + 'skip_phase0', 'skip_phase', 'skip_phase1') + 
self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'phase') + + self.assertEqual([ + test_record.SubtestRecord( + name='inner', + start_time_millis=htf_test.VALID_TIMESTAMP, + end_time_millis=htf_test.VALID_TIMESTAMP, + outcome=test_record.SubtestOutcome.FAIL), + test_record.SubtestRecord( + name='outer', + start_time_millis=htf_test.VALID_TIMESTAMP, + end_time_millis=htf_test.VALID_TIMESTAMP, + outcome=test_record.SubtestOutcome.FAIL), + ], test_rec.subtests) + + @htf_test.yields_phases + def test_fail_subtest__skip_checkpoint(self): + subtest = phase_collections.Subtest( + 'skip_checkpoint', fail_subtest_phase, + htf.PhaseFailureCheckpoint('must_be_skipped'), skip_phase) + + test_rec = yield htf.Test(subtest, phase) + + self.assertTestFail(test_rec) + + fail_phase_rec = test_rec.phases[1] + self.assertPhaseFailSubtest(fail_phase_rec) + self.assertPhaseOutcomeFail(fail_phase_rec) + + skip_phase_rec = test_rec.phases[2] + self.assertPhaseOutcomeSkip(skip_phase_rec) + + continue_phase_rec = test_rec.phases[3] + self.assertPhaseOutcomePass(continue_phase_rec) + + self.assertTrue(test_rec.checkpoints[0].result.is_skip) + + @htf_test.yields_phases + def test_fail_subtest__skip_branch_that_would_not_run(self): + + class _Diag(htf.DiagResultEnum): + NOT_SET = 'not_set' + + subtest = phase_collections.Subtest( + 'skip_branch', fail_subtest_phase, + htf.BranchSequence(_Diag.NOT_SET, error_phase), skip_phase) + + test_rec = yield htf.Test(subtest, phase) + + self.assertTestFail(test_rec) + + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.FAIL, test_rec, + 'fail_subtest_phase') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.SKIP, test_rec, + 'skip_phase') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'phase') + self.assertPhasesNotRun(test_rec, 'error_phase') + + @htf_test.yields_phases + def test_fail_subtest__skip_branch_that_would_run(self): + + class _Diag(htf.DiagResultEnum): + SET = 'set' + + 
@htf.PhaseDiagnoser(_Diag) + def diagnoser(phase_rec): + del phase_rec # Unused. + return htf.Diagnosis(_Diag.SET) + + @htf.diagnose(diagnoser) + def diag_phase(): + pass + + subtest = phase_collections.Subtest( + 'skip_branch', fail_subtest_phase, + htf.BranchSequence( + htf.DiagnosisCondition.on_all(_Diag.SET), error_phase), skip_phase) + + test_rec = yield htf.Test(diag_phase, subtest, phase) + + self.assertTestFail(test_rec) + + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.FAIL, test_rec, + 'fail_subtest_phase') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.SKIP, test_rec, + 'skip_phase') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'diag_phase', 'phase') + self.assertPhasesNotRun(test_rec, 'error_phase') diff --git a/test/core/phase_group_test.py b/test/core/phase_group_test.py index 8ccf8c1b1..7ceb28874 100644 --- a/test/core/phase_group_test.py +++ b/test/core/phase_group_test.py @@ -3,8 +3,12 @@ import threading import unittest +import mock import openhtf as htf from openhtf import plugs +from openhtf.core import base_plugs +from openhtf.core import phase_collections +from openhtf.core import test_record from openhtf.util import test as htf_test @@ -25,6 +29,10 @@ def _rename(phase, new_name): return htf.PhaseOptions(name=new_name)(phase) +def _prefix_name(phase): + return htf.PhaseOptions(name='prefix:' + phase.name)(phase) + + def _fake_phases(*new_names): return [_rename(blank_phase, name) for name in new_names] @@ -36,7 +44,7 @@ def arg_phase(test, arg1=None, arg2=None): del arg2 # Unused. -class ParentPlug(plugs.BasePlug): +class ParentPlug(base_plugs.BasePlug): pass @@ -52,35 +60,29 @@ def plug_phase(my_plug): def _abort_test_in_thread(test): # See note in test/core/exe_test.py for _abort_executor_in_thread. 
inner_ev = threading.Event() + def stop_executor(): test.abort_from_sig_int() inner_ev.set() + threading.Thread(target=stop_executor).start() inner_ev.wait(1) class PhaseGroupTest(unittest.TestCase): - def testInit(self): - setup = [1] - main = [2] - teardown = [3] + def testConstruct(self): + setup = _fake_phases('1') + main = _fake_phases('2') + teardown = _fake_phases('3') name = 'name' pg = htf.PhaseGroup(setup=setup, main=main, teardown=teardown, name=name) - self.assertEqual(tuple(setup), pg.setup) - self.assertEqual(tuple(main), pg.main) - self.assertEqual(tuple(teardown), pg.teardown) + self.assertEqual(phase_collections.PhaseSequence(tuple(setup)), pg.setup) + self.assertEqual(phase_collections.PhaseSequence(tuple(main)), pg.main) + self.assertEqual( + phase_collections.PhaseSequence(tuple(teardown)), pg.teardown) self.assertEqual(name, pg.name) - def testConvertIfNot_Not(self): - phases = _fake_phases('a', 'b', 'c') - expected = htf.PhaseGroup(main=_fake_phases('a', 'b', 'c')) - self.assertEqual(expected, htf.PhaseGroup.convert_if_not(phases)) - - def testConvertIfNot_Group(self): - expected = htf.PhaseGroup() - self.assertEqual(expected, htf.PhaseGroup.convert_if_not(expected)) - def testWithContext(self): setup = _fake_phases('setup') main = _fake_phases('main') @@ -124,6 +126,13 @@ def testCombine(self): teardown=_fake_phases('t1', 't2')) self.assertEqual(expected, group1.combine(group2)) + def testCombine_Empty(self): + group1 = htf.PhaseGroup(main=_fake_phases('m1')) + group2 = htf.PhaseGroup(teardown=_fake_phases('t1')) + expected = htf.PhaseGroup( + main=_fake_phases('m1'), teardown=_fake_phases('t1')) + self.assertEqual(expected, group1.combine(group2)) + def testWrap(self): group = htf.PhaseGroup( setup=_fake_phases('s1'), @@ -151,63 +160,63 @@ def testWrap_SinglePhase(self): def testWithArgs_Setup(self): group = htf.PhaseGroup(setup=[blank_phase, arg_phase]) arg_group = group.with_args(arg1=1) - self.assertEqual(blank_phase, arg_group.setup[0]) 
- self.assertEqual(arg_phase.with_args(arg1=1), arg_group.setup[1]) + self.assertEqual(blank_phase, arg_group.setup.nodes[0]) + self.assertEqual(arg_phase.with_args(arg1=1), arg_group.setup.nodes[1]) def testWithArgs_Main(self): group = htf.PhaseGroup(main=[blank_phase, arg_phase]) arg_group = group.with_args(arg1=1) - self.assertEqual(blank_phase, arg_group.main[0]) - self.assertEqual(arg_phase.with_args(arg1=1), arg_group.main[1]) + self.assertEqual(blank_phase, arg_group.main.nodes[0]) + self.assertEqual(arg_phase.with_args(arg1=1), arg_group.main.nodes[1]) def testWithArgs_Teardown(self): group = htf.PhaseGroup(teardown=[blank_phase, arg_phase]) arg_group = group.with_args(arg1=1) - self.assertEqual(blank_phase, arg_group.teardown[0]) - self.assertEqual(arg_phase.with_args(arg1=1), arg_group.teardown[1]) + self.assertEqual(blank_phase, arg_group.teardown.nodes[0]) + self.assertEqual(arg_phase.with_args(arg1=1), arg_group.teardown.nodes[1]) def testWithArgs_Recursive(self): inner_group = htf.PhaseGroup(main=[blank_phase, arg_phase]) outer_group = htf.PhaseGroup(main=[inner_group, arg_phase]) arg_group = outer_group.with_args(arg2=2) - self.assertEqual(blank_phase, arg_group.main[0].main[0]) - self.assertEqual(arg_phase.with_args(arg2=2), arg_group.main[0].main[1]) - self.assertEqual(arg_phase.with_args(arg2=2), arg_group.main[1]) + all_phases = list(arg_group.all_phases()) + self.assertEqual(blank_phase, all_phases[0]) + self.assertEqual(arg_phase.with_args(arg2=2), all_phases[1]) + self.assertEqual(arg_phase.with_args(arg2=2), all_phases[2]) def testWithPlugs_Setup(self): group = htf.PhaseGroup(setup=[blank_phase, plug_phase]) plug_group = group.with_plugs(my_plug=ChildPlug) - self.assertEqual(blank_phase, plug_group.setup[0]) - self.assertEqual(plug_phase.with_plugs(my_plug=ChildPlug), - plug_group.setup[1]) + self.assertEqual(blank_phase, plug_group.setup.nodes[0]) + self.assertEqual( + plug_phase.with_plugs(my_plug=ChildPlug), plug_group.setup.nodes[1]) def 
testWithPlugs_Main(self): group = htf.PhaseGroup(main=[blank_phase, plug_phase]) plug_group = group.with_plugs(my_plug=ChildPlug) - self.assertEqual(blank_phase, plug_group.main[0]) - self.assertEqual(plug_phase.with_plugs(my_plug=ChildPlug), - plug_group.main[1]) + self.assertEqual(blank_phase, plug_group.main.nodes[0]) + self.assertEqual( + plug_phase.with_plugs(my_plug=ChildPlug), plug_group.main.nodes[1]) def testWithPlugs_Teardown(self): group = htf.PhaseGroup(teardown=[blank_phase, plug_phase]) plug_group = group.with_plugs(my_plug=ChildPlug) - self.assertEqual(blank_phase, plug_group.teardown[0]) - self.assertEqual(plug_phase.with_plugs(my_plug=ChildPlug), - plug_group.teardown[1]) + self.assertEqual(blank_phase, plug_group.teardown.nodes[0]) + self.assertEqual( + plug_phase.with_plugs(my_plug=ChildPlug), plug_group.teardown.nodes[1]) def testWithPlugs_Recursive(self): inner_group = htf.PhaseGroup(main=[blank_phase, plug_phase]) outer_group = htf.PhaseGroup(main=[inner_group, plug_phase]) plug_group = outer_group.with_plugs(my_plug=ChildPlug) - self.assertEqual(blank_phase, plug_group.main[0].main[0]) - self.assertEqual(plug_phase.with_plugs(my_plug=ChildPlug), - plug_group.main[0].main[1]) - self.assertEqual(plug_phase.with_plugs(my_plug=ChildPlug), - plug_group.main[1]) + all_phases = list(plug_group.all_phases()) + self.assertEqual(blank_phase, all_phases[0]) + self.assertEqual(plug_phase.with_plugs(my_plug=ChildPlug), all_phases[1]) + self.assertEqual(plug_phase.with_plugs(my_plug=ChildPlug), all_phases[2]) - def testIterate(self): + def testAllPhases(self): inner_group = htf.PhaseGroup( setup=_fake_phases('a', 'b'), main=_fake_phases('c', 'd'), @@ -225,29 +234,7 @@ def testIterate(self): 'e', 'f', # Inner teardown. '7', # Rest of outer main. '8', '9', # Outer teardown. 
- ), list(outer_group)) - - def testFlatten(self): - inner = htf.PhaseGroup( - setup=_fake_phases('a', 'b') + [_fake_phases('c')], - main=[_fake_phases('d')], - teardown=[_fake_phases('e'), _fake_phases('f')] + _fake_phases('g')) - outer = htf.PhaseGroup( - setup=_fake_phases('1', '2'), - main=[_fake_phases('3')] + [inner, _fake_phases('4')] + - _fake_phases('5'), - teardown=_fake_phases('6') + [_fake_phases('7', '8')] + - _fake_phases('9')) - - expected_inner = htf.PhaseGroup( - setup=_fake_phases('a', 'b', 'c'), - main=_fake_phases('d'), - teardown=_fake_phases('e', 'f', 'g')) - expected_outer = htf.PhaseGroup( - setup=_fake_phases('1', '2'), - main=_fake_phases('3') + [expected_inner] + _fake_phases('4', '5'), - teardown=_fake_phases('6', '7', '8', '9')) - self.assertEqual(expected_outer, outer.flatten()) + ), list(outer_group.all_phases())) # pyformat: disable def testLoadCodeInfo(self): group = htf.PhaseGroup( @@ -255,10 +242,75 @@ def testLoadCodeInfo(self): main=_fake_phases('main'), teardown=_fake_phases('teardown')) code_group = group.load_code_info() - self.assertEqual(blank.__name__, code_group.setup[0].code_info.name) - self.assertEqual(blank.__name__, code_group.main[0].code_info.name) + all_phases = list(code_group.all_phases()) + code_info = test_record.CodeInfo.for_function(blank) + self.assertEqual(code_info, all_phases[0].code_info) + self.assertEqual(code_info, all_phases[1].code_info) + self.assertEqual(code_info, all_phases[2].code_info) + + @mock.patch.object(htf.PhaseDescriptor, '_asdict', autospec=True) + def testAsDict_Full(self, mock_phase_asdict): + + def phase_asdict(self_phase): + return self_phase.name + + mock_phase_asdict.side_effect = phase_asdict + setup = _fake_phases('setup') + main = _fake_phases('main') + teardown = _fake_phases('teardown') + group = htf.PhaseGroup( + setup=setup, main=main, teardown=teardown, name='group') self.assertEqual( - blank.__name__, code_group.teardown[0].code_info.name) + { + 'setup': { + 'nodes': 
['setup'], + 'name': None + }, + 'main': { + 'nodes': ['main'], + 'name': None + }, + 'teardown': { + 'nodes': ['teardown'], + 'name': None + }, + 'name': 'group', + }, group._asdict()) + mock_phase_asdict.assert_has_calls( + [mock.call(setup[0]), + mock.call(main[0]), + mock.call(teardown[0])]) + + @mock.patch.object(htf.PhaseDescriptor, '_asdict', autospec=True) + def testAsDict_Empty(self, mock_phase_asdict): + group = htf.PhaseGroup(name='group') + self.assertEqual( + { + 'setup': None, + 'main': None, + 'teardown': None, + 'name': 'group', + }, group._asdict()) + mock_phase_asdict.assert_not_called() + + def testApplyToAllPhases_Empty(self): + group = htf.PhaseGroup(name='group') + + expected = htf.PhaseGroup(name='group') + self.assertEqual(expected, group.apply_to_all_phases(_prefix_name)) + + def testApplyToAllPhases_Full(self): + group = htf.PhaseGroup( + setup=_fake_phases('setup'), + main=_fake_phases('main'), + teardown=_fake_phases('teardown'), + name='group') + expected = htf.PhaseGroup( + setup=_fake_phases('prefix:setup'), + main=_fake_phases('prefix:main'), + teardown=_fake_phases('prefix:teardown'), + name='group') + self.assertEqual(expected, group.apply_to_all_phases(_prefix_name)) class PhaseGroupIntegrationTest(htf_test.TestCase): @@ -289,19 +341,19 @@ def testRecursive(self): name='inner') recursive = htf.PhaseGroup( setup=_fake_phases('setup'), - main=(_fake_phases('main-pre') + [inner] + - _fake_phases('main-post')), + main=(_fake_phases('main-pre') + [inner] + _fake_phases('main-post')), teardown=_fake_phases('teardown'), name='recursive') test_rec = yield htf.Test(recursive) self.assertTestPass(test_rec) - self._assert_phase_names( - ['setup', 'main-pre', 'inner-setup', 'inner-main', 'inner-teardown', - 'main-post', 'teardown'], - test_rec) + self._assert_phase_names([ + 'setup', 'main-pre', 'inner-setup', 'inner-main', 'inner-teardown', + 'main-post', 'teardown' + ], test_rec) @htf_test.yields_phases def testAbort_Setup(self): + 
@htf.PhaseOptions() def abort_phase(): _abort_test_in_thread(test) @@ -319,6 +371,7 @@ def abort_phase(): @htf_test.yields_phases def testAbort_Main(self): + @htf.PhaseOptions() def abort_phase(): _abort_test_in_thread(test) @@ -336,6 +389,7 @@ def abort_phase(): @htf_test.yields_phases def testAbort_Teardown(self): + @htf.PhaseOptions() def abort_phase(): _abort_test_in_thread(test) @@ -348,8 +402,8 @@ def abort_phase(): test = htf.Test(abort_teardown) test_rec = yield test self.assertTestAborted(test_rec) - self._assert_phase_names( - ['setup0', 'main0', 'td0', 'abort_phase', 'td1'], test_rec) + self._assert_phase_names(['setup0', 'main0', 'td0', 'abort_phase', 'td1'], + test_rec) @htf_test.yields_phases def testFailure_Before(self): @@ -382,8 +436,8 @@ def testFailure_Main(self): name='fail_main') test_rec = yield htf.Test(fail_main) self.assertTestFail(test_rec) - self._assert_phase_names( - ['setup', 'main0', 'stop_phase', 'teardown0'], test_rec) + self._assert_phase_names(['setup', 'main0', 'stop_phase', 'teardown0'], + test_rec) @htf_test.yields_phases def testFailure_Teardown(self): @@ -394,15 +448,14 @@ def testFailure_Teardown(self): name='fail_teardown') test_rec = yield htf.Test(fail_teardown) self.assertTestFail(test_rec) - self._assert_phase_names( - ['setup', 'main', 'td0', 'stop_phase', 'td1'], test_rec) + self._assert_phase_names(['setup', 'main', 'td0', 'stop_phase', 'td1'], + test_rec) @htf_test.yields_phases def testRecursive_FailureSetup(self): inner_fail = htf.PhaseGroup( - setup=( - _fake_phases('inner-setup0') + [stop_phase] + - _fake_phases('not-run')), + setup=(_fake_phases('inner-setup0') + [stop_phase] + + _fake_phases('not-run')), main=_fake_phases('not-run-inner'), teardown=_fake_phases('not-run-inner-teardown'), name='inner_fail') @@ -421,8 +474,7 @@ def testRecursive_FailureSetup(self): def testRecursive_FailureMain(self): inner_fail = htf.PhaseGroup( setup=_fake_phases('inner-setup0'), - main=(_fake_phases('inner0') + [stop_phase] + 
- _fake_phases('not-run')), + main=(_fake_phases('inner0') + [stop_phase] + _fake_phases('not-run')), teardown=_fake_phases('inner-teardown'), name='inner_fail') outer = htf.PhaseGroup( @@ -432,19 +484,25 @@ def testRecursive_FailureMain(self): name='outer') test_rec = yield htf.Test(outer) self.assertTestFail(test_rec) - self._assert_phase_names( - ['setup0', 'outer0', 'inner-setup0', 'inner0', 'stop_phase', - 'inner-teardown', 'teardown0'], - test_rec) + self._assert_phase_names([ + 'setup0', 'outer0', 'inner-setup0', 'inner0', 'stop_phase', + 'inner-teardown', 'teardown0' + ], test_rec) @htf_test.yields_phases - def testOldTeardown(self): - phases = _fake_phases('p0', 'p1', 'p2') - teardown_phase = _rename(blank_phase, 'teardown') + def testRecursive_InTeardownAllRun(self): + group = htf.PhaseGroup( + main=_fake_phases('main0'), + teardown=[ + htf.PhaseSequence([stop_phase] + + _fake_phases('teardown0', 'teardown1')) + ]) - test = htf.Test(phases) - test.configure(teardown_function=teardown_phase) - run = test._get_running_test_descriptor() - test_rec = yield test - self.assertTestPass(test_rec) - self._assert_phase_names(['p0', 'p1', 'p2', 'teardown'], test_rec) + test_rec = yield htf.Test(group) + self.assertTestFail(test_rec) + self._assert_phase_names(['main0', 'stop_phase', 'teardown0', 'teardown1'], + test_rec) + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.PASS, test_rec, + 'main0', 'teardown0', 'teardown1') + self.assertPhasesOutcomeByName(test_record.PhaseOutcome.ERROR, test_rec, + 'stop_phase') diff --git a/test/core/test_descriptor_test.py b/test/core/test_descriptor_test.py index 436441bcd..aadd209df 100644 --- a/test/core/test_descriptor_test.py +++ b/test/core/test_descriptor_test.py @@ -1,15 +1,15 @@ -# Lint as: python2, python3 +# Lint as: python3 """Unit tests for test_descriptor module.""" -import unittest -import mock import re +import unittest +import mock from openhtf.core import test_descriptor -from openhtf.util import 
console_output class RegexMatcher(object): + def __init__(self, pattern): self.pattern = pattern @@ -21,11 +21,14 @@ class TestTest(unittest.TestCase): @mock.patch.object(test_descriptor, '_LOG') def test_output_cb_error_stacktrace_log(self, mock_log): + def phase(): return + def callback(test_record): del test_record raise Exception('test123') + test = test_descriptor.Test(phase) test.add_output_callbacks(callback) test.execute() diff --git a/test/core/test_record_test.py b/test/core/test_record_test.py index 7fcfeb907..7b10072f7 100644 --- a/test/core/test_record_test.py +++ b/test/core/test_record_test.py @@ -1,4 +1,4 @@ -# Lint as: python2, python3 +# Lint as: python3 """Unit tests for test_record module.""" import sys @@ -9,7 +9,7 @@ def _get_obj_size(obj): size = 0 - for attr in obj.__slots__: + for attr in obj.__slots__: # pytype: disable=attribute-error size += sys.getsizeof(attr) size += sys.getsizeof(getattr(obj, attr)) return size @@ -24,7 +24,7 @@ def test_attachment_data(self): self.assertEqual(data, expected_data) def test_attachment_memory_safety(self): - empty_attachment = test_record.Attachment('', 'text') + empty_attachment = test_record.Attachment(b'', 'text') expected_obj_size = _get_obj_size(empty_attachment) large_data = b'test attachment data' * 1000 attachment = test_record.Attachment(large_data, 'text') diff --git a/test/output/callbacks/callbacks_test.py b/test/output/callbacks/callbacks_test.py index 4c8134a2b..e5c19d51a 100644 --- a/test/output/callbacks/callbacks_test.py +++ b/test/output/callbacks/callbacks_test.py @@ -20,7 +20,6 @@ import io import json -import sys import unittest import openhtf as htf @@ -38,10 +37,13 @@ class TestOutput(test.TestCase): @classmethod def setUpClass(cls): + super(TestOutput, cls).setUpClass() # Create input record. 
result = util.NonLocalResult() + def _save_result(test_record): result.result = test_record + cls._test = htf.Test( all_the_things.hello_world, all_the_things.dimensions, @@ -54,12 +56,8 @@ def _save_result(test_record): def test_json(self, user_mock): user_mock.prompt.return_value = 'SomeWidget' record = yield self._test - if sys.version_info[0] < 3: - json_output = io.BytesIO() - else: - json_output = io.StringIO() - json_factory.OutputToJSON( - json_output, sort_keys=True, indent=2)(record) + json_output = io.BytesIO() + json_factory.OutputToJSON(json_output, sort_keys=True, indent=2)(record) json_output.seek(0) json.loads(json_output.read()) @@ -106,10 +104,13 @@ class TestMfgEventOutput(test.TestCase): @classmethod def setUpClass(cls): + super(TestMfgEventOutput, cls).setUpClass() # Create input record. result = util.NonLocalResult() + def _save_result(test_record): result.result = test_record + cls._test = htf.Test( all_the_things.hello_world, all_the_things.dimensions, @@ -154,8 +155,9 @@ def test_mfg_event_from_test_record(self, user_mock): raise AssertionError('No measurement named %s' % measurement_name) # Spot check an attachment (example_attachment.txt) - for attachment_name in ['example_attachment_0.txt', - 'example_attachment_1.txt']: + for attachment_name in [ + 'example_attachment_0.txt', 'example_attachment_1.txt' + ]: for attachment in mfg_event.attachment: if attachment.name == attachment_name: self.assertEqual( diff --git a/test/output/callbacks/mfg_event_converter_test.py b/test/output/callbacks/mfg_event_converter_test.py index 152a0474d..d16224fa0 100644 --- a/test/output/callbacks/mfg_event_converter_test.py +++ b/test/output/callbacks/mfg_event_converter_test.py @@ -16,11 +16,9 @@ from openhtf.util import logs as test_logs from openhtf.util import units - TEST_MULTIDIM_JSON_FILE = os.path.join( - os.path.dirname(__file__), - 'multidim_testdata.json') -with io.open(TEST_MULTIDIM_JSON_FILE, 'r', encoding='utf-8') as f: + 
os.path.dirname(__file__), 'multidim_testdata.json') +with io.open(TEST_MULTIDIM_JSON_FILE, 'rb') as f: TEST_MULTIDIM_JSON = f.read() @@ -45,20 +43,20 @@ def test_mfg_event_from_test_record(self): record.outcome = test_record.Outcome.PASS record.metadata = { 'assembly_events': [assembly_event_pb2.AssemblyEvent()] * 2, - 'config': {'mock-config-key': 'mock-config-value'}, + 'config': { + 'mock-config-key': 'mock-config-value' + }, 'operator_name': 'mock-operator-name', } record.phases = [ - test_record.PhaseRecord( + test_record.PhaseRecord( # pylint: disable=g-complex-comprehension name='phase-%d' % idx, descriptor_id=idx, codeinfo=test_record.CodeInfo.uncaptured(), result=None, attachments={}, start_time_millis=1, - end_time_millis=1 - ) - for idx in range(1, 5) + end_time_millis=1) for idx in range(1, 5) ] for phase in record.phases: phase.measurements = { @@ -67,8 +65,8 @@ def test_mfg_event_from_test_record(self): 'meas-3': measurements.Measurement('meas-3').with_dimensions('V'), } phase.attachments = { - 'attach-1': test_record.Attachment(data='data-1', mimetype=''), - 'attach-2': test_record.Attachment(data='data-2', mimetype=''), + 'attach-1': test_record.Attachment(b'data-1', ''), + 'attach-2': test_record.Attachment(b'data-2', ''), } mfg_event = mfg_event_converter.mfg_event_from_test_record(record) @@ -76,17 +74,36 @@ def test_mfg_event_from_test_record(self): self.assertEqual(mfg_event.dut_serial, record.dut_id) self.assertEqual(len(mfg_event.assembly_events), 2) self.assertEqual(len(mfg_event.measurement), 8) - self.assertEqual(sorted(m.name for m in mfg_event.measurement), - ['meas-1_0', 'meas-1_1', 'meas-1_2', 'meas-1_3', - 'meas-2_0', 'meas-2_1', 'meas-2_2', 'meas-2_3']) + self.assertEqual( # pylint: disable=g-generic-assert + sorted(m.name for m in mfg_event.measurement), [ + 'meas-1_0', + 'meas-1_1', + 'meas-1_2', + 'meas-1_3', + 'meas-2_0', + 'meas-2_1', + 'meas-2_2', + 'meas-2_3', + ]) self.assertEqual(len(mfg_event.attachment), 15) - 
self.assertEqual(sorted(str(m.name) for m in mfg_event.attachment), - ['OpenHTF_record.json', 'argv', - 'attach-1_0', 'attach-1_1', 'attach-1_2', 'attach-1_3', - 'attach-2_0', 'attach-2_1', 'attach-2_2', 'attach-2_3', - 'config', - 'multidim_meas-3_0', 'multidim_meas-3_1', - 'multidim_meas-3_2', 'multidim_meas-3_3']) + self.assertEqual( # pylint: disable=g-generic-assert + sorted(str(m.name) for m in mfg_event.attachment), [ + 'OpenHTF_record.json', + 'argv', + 'attach-1_0', + 'attach-1_1', + 'attach-1_2', + 'attach-1_3', + 'attach-2_0', + 'attach-2_1', + 'attach-2_2', + 'attach-2_3', + 'config', + 'multidim_meas-3_0', + 'multidim_meas-3_1', + 'multidim_meas-3_2', + 'multidim_meas-3_3', + ]) def test_populate_basic_data(self): outcome_details = test_record.OutcomeDetails( @@ -166,9 +183,19 @@ def test_attach_record_as_json(self): self.assertTrue(mfg_event.attachment[0].value_binary) # Assert truthy. self.assertEqual(mfg_event.attachment[0].type, test_runs_pb2.TEXT_UTF8) + def test_convert_object_to_json_with_bytes(self): + input_object = {'foo': b'bar'} + output_json = mfg_event_converter._convert_object_to_json(input_object) + expected_json = (b'{\n' + b' "foo": "bar"\n' + b'}') + self.assertEqual(output_json, expected_json) + def test_attach_config(self): - record = test_record.TestRecord('mock-dut-id', 'mock-station-id', - metadata={'config': {'key': 'value'}}) + record = test_record.TestRecord( + 'mock-dut-id', 'mock-station-id', metadata={'config': { + 'key': 'value' + }}) mfg_event = mfg_event_pb2.MfgEvent() mfg_event_converter._attach_config(mfg_event, record) @@ -181,9 +208,9 @@ def _create_and_set_measurement(self, name, value): measured_value.set(value) measurement = measurements.Measurement( - name=name, outcome=measurements.Outcome.PASS) - # Cannot be set in initialization. 
- measurement.measured_value = measured_value + name=name, + outcome=measurements.Outcome.PASS, + measured_value=measured_value) return measurement def test_copy_measurements_from_phase(self): @@ -262,7 +289,7 @@ def test_copy_measurements_from_phase(self): self.assertEqual(mock_measurement_within_percent.numeric_maximum, 12.0) def testCopyAttachmentsFromPhase(self): - attachment = test_record.Attachment('mock-data', 'text/plain') + attachment = test_record.Attachment(b'mock-data', 'text/plain') phase = test_record.PhaseRecord( name='mock-phase-name', descriptor_id=1, @@ -346,8 +373,8 @@ def test_reversibleish_leagcy_status_int(self): # Re-parse the data, edit the outcome field to a int, then reserialize. data_dict = json.loads(attachment.data) data_dict['outcome'] = test_runs_pb2.Status.Value(data_dict['outcome']) - attachment = test_record.Attachment(json.dumps(data_dict), - test_runs_pb2.MULTIDIM_JSON) + attachment = test_record.Attachment( + json.dumps(data_dict).encode('utf-8'), test_runs_pb2.MULTIDIM_JSON) reversed_mdim = mfg_event_converter.attachment_to_multidim_measurement( attachment) @@ -367,5 +394,6 @@ def assert_same_mdim(self, expected, other): assert k in other.measured_value.value_dict, ( 'expected key %s is not present in other multidim' % k) other_v = other.measured_value.value_dict[k] - self.assertEqual(v, other_v, 'Different values for key: %s (%s != %s)' % ( - k, v, other_v)) + self.assertEqual( + v, other_v, + 'Different values for key: %s (%s != %s)' % (k, v, other_v)) diff --git a/test/output/callbacks/mfg_inspector_test.py b/test/output/callbacks/mfg_inspector_test.py index 4b8bf5977..8d33b229d 100644 --- a/test/output/callbacks/mfg_inspector_test.py +++ b/test/output/callbacks/mfg_inspector_test.py @@ -28,8 +28,7 @@ from openhtf import util from examples import all_the_things from openhtf.output.callbacks import mfg_inspector -from openhtf.output.proto import mfg_event_converter -from openhtf.output.proto import mfg_event_pb2 +from 
openhtf.output.proto import guzzle_pb2 from openhtf.output.proto import test_runs_converter from openhtf.output.proto import test_runs_pb2 from openhtf.util import test @@ -38,32 +37,39 @@ tester_name='mock_test_run', dut_serial='UNITTEST1234', test_status=test_runs_pb2.PASS, - test_info=test_runs_pb2.TestInfo(name='unit_test') -) + test_info=test_runs_pb2.TestInfo(name='unit_test')) -MOCK_TEST_RUN = collections.namedtuple( - 'Testrun', mfg_inspector.MfgInspector.PARAMS)(None, None, None, None) +MOCK_TEST_RUN = collections.namedtuple('Testrun', + mfg_inspector.MfgInspector.PARAMS)(None, + None, + None, + None) class TestMfgInspector(test.TestCase): def setUp(self): + super(TestMfgInspector, self).setUp() self.mock_credentials = mock.patch( - 'oauth2client.client.SignedJwtAssertionCredentials' - ).start().return_value + 'oauth2client.client.SignedJwtAssertionCredentials').start( + ).return_value self.mock_send_mfg_inspector_data = mock.patch.object( mfg_inspector, 'send_mfg_inspector_data').start() def tearDown(self): mock.patch.stopall() + super(TestMfgInspector, self).tearDown() @classmethod def setUpClass(cls): + super(TestMfgInspector, cls).setUpClass() # Create input record. 
result = util.NonLocalResult() + def _save_result(test_record): result.result = test_record + cls._test = htf.Test( all_the_things.hello_world, all_the_things.dimensions, @@ -81,8 +87,7 @@ def test_save_only(self, user_mock): callback = mfg_inspector.MfgInspector() callback.set_converter( - converter=test_runs_converter.test_run_from_test_record, - ) + converter=test_runs_converter.test_run_from_test_record,) save_to_disk_callback = callback.save_to_disk( filename_pattern=testrun_output) save_to_disk_callback(record) @@ -101,13 +106,14 @@ def test_save_only(self, user_mock): def test_upload_only(self): mock_converter = mock.MagicMock(return_value=MOCK_TEST_RUN_PROTO) callback = mfg_inspector.MfgInspector( - user='user', keydata='keydata', token_uri='').set_converter( - mock_converter) + user='user', keydata='keydata', + token_uri='').set_converter(mock_converter) callback.upload()(MOCK_TEST_RUN) self.mock_send_mfg_inspector_data.assert_called_with( - MOCK_TEST_RUN_PROTO, self.mock_credentials, callback.destination_url) + MOCK_TEST_RUN_PROTO, self.mock_credentials, callback.destination_url, + guzzle_pb2.COMPRESSED_TEST_RUN) def test_save_and_upload(self): testrun_output = io.BytesIO() @@ -128,7 +134,8 @@ def test_save_and_upload(self): self.assertEqual(MOCK_TEST_RUN_PROTO, testrun) self.mock_send_mfg_inspector_data.assert_called_with( - MOCK_TEST_RUN_PROTO, self.mock_credentials, callback.destination_url) + MOCK_TEST_RUN_PROTO, self.mock_credentials, callback.destination_url, + guzzle_pb2.COMPRESSED_TEST_RUN) # Make sure mock converter only called once i.e. the test record was # was converted to a proto only once. 
This important because some custom diff --git a/test/phase_descriptor_test.py b/test/phase_descriptor_test.py index b1435c99f..ec545f5cf 100644 --- a/test/phase_descriptor_test.py +++ b/test/phase_descriptor_test.py @@ -14,28 +14,29 @@ import unittest +import attr import mock import openhtf from openhtf import plugs -from openhtf.core import phase_descriptor +from openhtf.core import base_plugs def plain_func(): - """Plain Docstring""" + """Plain Docstring.""" pass -def normal_test_phase(test): +def normal_test_phase(): return 'return value' -@openhtf.PhaseOptions(name='func-name({input[0]})') -def extra_arg_func(input=None): - return input +@openhtf.PhaseOptions(name='func-name({input_value[0]})') +def extra_arg_func(input_value=None): + return input_value -class ExtraPlug(plugs.BasePlug): +class ExtraPlug(base_plugs.BasePlug): name = 'extra_plug_0' def echo(self, phrase): @@ -48,7 +49,7 @@ def extra_plug_func(plug, phrase): return plug.echo(phrase) -class PlaceholderCapablePlug(plugs.BasePlug): +class PlaceholderCapablePlug(base_plugs.BasePlug): auto_placeholder = True @@ -66,86 +67,123 @@ def sub_placeholder_using_plug(subplaced): del subplaced # Unused. +class NonPlugBase(object): + """A base class that is not a BasePlug.""" + + +class PlugVersionOfNonPlug(NonPlugBase, base_plugs.BasePlug): + """Plug implementation of a non-plug base.""" + + +custom_placeholder = base_plugs.PlugPlaceholder(NonPlugBase) + + +@plugs.plug(custom=custom_placeholder) +def custom_placeholder_phase(custom): + del custom # Unused. 
+ + class TestPhaseDescriptor(unittest.TestCase): def setUp(self): - super(TestPhaseDescriptor, self).setUp() - self._phase_data = mock.Mock( - plug_manager=plugs.PlugManager(), - execution_uid='01234567890') + super(TestPhaseDescriptor, self).setUp() + self._phase_data = mock.Mock( + plug_manager=plugs.PlugManager(), execution_uid='01234567890') def test_basics(self): - phase = openhtf.PhaseDescriptor.wrap_or_copy(plain_func) - self.assertIs(phase.func, plain_func) - self.assertEqual(0, len(phase.plugs)) - self.assertEqual('plain_func', phase.name) - self.assertEqual('Plain Docstring', phase.doc) - phase(self._phase_data) + phase = openhtf.PhaseDescriptor.wrap_or_copy(plain_func) + self.assertIs(phase.func, plain_func) # pytype: disable=wrong-arg-types + self.assertEqual(0, len(phase.plugs)) + self.assertEqual('plain_func', phase.name) + self.assertEqual('Plain Docstring.', phase.doc) + phase(self._phase_data) - test_phase = openhtf.PhaseDescriptor.wrap_or_copy(normal_test_phase) - self.assertEqual('normal_test_phase', test_phase.name) - self.assertEqual('return value', test_phase(self._phase_data)) + test_phase = openhtf.PhaseDescriptor.wrap_or_copy(normal_test_phase) + self.assertEqual('normal_test_phase', test_phase.name) + self.assertEqual('return value', test_phase(self._phase_data)) def test_multiple_phases(self): - phase = openhtf.PhaseDescriptor.wrap_or_copy(plain_func) - second_phase = openhtf.PhaseDescriptor.wrap_or_copy(phase) - for attr in type(phase).all_attribute_names: - if attr == 'func': continue - self.assertIsNot(getattr(phase, attr), getattr(second_phase, attr)) - - @mock.patch.object(phase_descriptor.PhaseDescriptor, "with_args") - def test_with_known_args(self, mock_with_args): - phase = openhtf.PhaseDescriptor.wrap_or_copy(extra_arg_func) - kwargs = {"input": True} - phase.with_known_args(**kwargs) - mock_with_args.assert_called_once_with(**kwargs) - - @mock.patch.object(phase_descriptor.PhaseDescriptor, "with_args") - def 
test_with_known_args_no_args(self, mock_with_args): - phase = openhtf.PhaseDescriptor.wrap_or_copy(normal_test_phase) - kwargs = {"input": True} - result = phase.with_known_args(**kwargs) - self.assertEqual(result, phase) - self.assertEqual(mock_with_args.call_count, 0) + phase = openhtf.PhaseDescriptor.wrap_or_copy(plain_func) + second_phase = openhtf.PhaseDescriptor.wrap_or_copy(phase) + for field in attr.fields(type(phase)): + if field.name == 'func': + continue + self.assertIsNot( + getattr(phase, field.name), getattr(second_phase, field.name)) + + def test_callable_name_with_args(self): + + def namer(**kwargs): + return 'renamed_{one}_{two}'.format(**kwargs) + + @openhtf.PhaseOptions(name=namer) + def custom_phase(one=None, two=None): + del one # Unused. + del two # Unused. + + self.assertEqual('custom_phase', custom_phase.name) + arged = custom_phase.with_args(one=1, two=2) + self.assertEqual('renamed_1_2', arged.name) def test_with_args(self): - phase = openhtf.PhaseDescriptor.wrap_or_copy(extra_arg_func) - phase = phase.with_args(input='input arg') - result = phase(self._phase_data) - first_result = phase(self._phase_data) - self.assertEqual('input arg', result) - self.assertEqual('func-name(i)', phase.name) - self.assertEqual('input arg', first_result) - - # Must do with_args() on the original phase, otherwise it has already been - # formatted and the format-arg information is lost. 
- second_phase = extra_arg_func.with_args(input='second input') - second_result = second_phase(self._phase_data) - self.assertEqual('second input', second_result) - self.assertEqual('func-name(s)', second_phase.name) + phase = extra_arg_func.with_args(input_value='input arg') + result = phase(self._phase_data) + first_result = phase(self._phase_data) + self.assertIs(phase.func, extra_arg_func.func) + self.assertEqual('input arg', result) + self.assertEqual('func-name(i)', phase.name) + self.assertEqual('input arg', first_result) + + # Must do with_args() on the original phase, otherwise it has already been + # formatted and the format-arg information is lost. + second_phase = extra_arg_func.with_args(input_value='second input') + second_result = second_phase(self._phase_data) + self.assertEqual('second input', second_result) + self.assertEqual('func-name(s)', second_phase.name) + + def test_with_args_argument_not_specified(self): + phase = extra_arg_func.with_args(arg_does_not_exist=1) + self.assertNotIn('arg_does_not_exist', phase.extra_kwargs) + + def test_with_args_kwargs(self): + @openhtf.PhaseOptions() + def phase(test_api, **kwargs): + del test_api # Unused. + del kwargs # Unused. 
+ + updated = phase.with_args(arg_does_not_exist=1) + self.assertEqual({'arg_does_not_exist': 1}, updated.extra_kwargs) def test_with_plugs(self): - self._phase_data.plug_manager.initialize_plugs([ExtraPlug]) - phase = extra_plug_func.with_plugs(plug=ExtraPlug).with_args(phrase='hello') - self.assertIs(phase.func, extra_plug_func.func) - self.assertEqual(1, len(phase.plugs)) - self.assertEqual('extra_plug_func[extra_plug_0][hello]', phase.options.name) - self.assertEqual('extra_plug_func[extra_plug_0][hello]', phase.name) + self._phase_data.plug_manager.initialize_plugs([ExtraPlug]) + phase = extra_plug_func.with_plugs(plug=ExtraPlug).with_args(phrase='hello') + self.assertIs(phase.func, extra_plug_func.func) + self.assertEqual(1, len(phase.plugs)) + self.assertEqual('extra_plug_func[extra_plug_0][hello]', phase.options.name) + self.assertEqual('extra_plug_func[extra_plug_0][hello]', phase.name) + + result = phase(self._phase_data) + self.assertEqual('extra_plug_0 says hello', result) - result = phase(self._phase_data) - self.assertEqual('extra_plug_0 says hello', result) + def test_with_plugs_unknown_plug_name_ignored(self): + phase = placeholder_using_plug.with_plugs(undefined_plug=ExtraPlug) + self.assertIs(phase, placeholder_using_plug) def test_with_plugs_auto_placeholder(self): - phase = placeholder_using_plug.with_plugs( - placed=SubPlaceholderCapablePlug) - self.assertIs(phase.func, placeholder_using_plug.func) - self.assertEqual(1, len(phase.plugs)) + phase = placeholder_using_plug.with_plugs(placed=SubPlaceholderCapablePlug) + self.assertIs(phase.func, placeholder_using_plug.func) + self.assertEqual(1, len(phase.plugs)) def test_with_plugs_subclass_auto_placeholder_error(self): - with self.assertRaises(plugs.InvalidPlugError): - sub_placeholder_using_plug.with_plugs( - subplaced=SubPlaceholderCapablePlug) + with self.assertRaises(base_plugs.InvalidPlugError): + sub_placeholder_using_plug.with_plugs(subplaced=SubPlaceholderCapablePlug) def 
test_with_plugs_auto_placeholder_non_subclass_error(self): - with self.assertRaises(plugs.InvalidPlugError): - placeholder_using_plug.with_plugs(placed=ExtraPlug) + with self.assertRaises(base_plugs.InvalidPlugError): + placeholder_using_plug.with_plugs(placed=ExtraPlug) + + def test_with_plugs_custom_placeholder_is_base_plug(self): + phase = custom_placeholder_phase.with_plugs(custom=PlugVersionOfNonPlug) + self.assertIs(phase.func, custom_placeholder_phase.func) + self.assertEqual([base_plugs.PhasePlug('custom', PlugVersionOfNonPlug)], + phase.plugs) diff --git a/test/plugs/plugs_test.py b/test/plugs/plugs_test.py index ab68f99b0..6325c7084 100644 --- a/test/plugs/plugs_test.py +++ b/test/plugs/plugs_test.py @@ -14,13 +14,14 @@ import threading import time +import unittest -import openhtf from openhtf import plugs +from openhtf.core import base_plugs from openhtf.util import test -class AdderPlug(plugs.FrontendAwareBasePlug): +class AdderPlug(base_plugs.FrontendAwareBasePlug): INSTANCE_COUNT = 0 LAST_INSTANCE = None @@ -40,7 +41,7 @@ def increment(self): self.notify_update() return self.number - def tearDown(self): + def tearDown(self): # pylint: disable=g-missing-super-call self.state = 'TORN DOWN' @@ -48,20 +49,22 @@ class AdderSubclassPlug(AdderPlug): pass -class DummyPlug(plugs.BasePlug): +class DummyPlug(base_plugs.BasePlug): pass -class TearDownRaisesPlug1(plugs.BasePlug): +class TearDownRaisesPlug1(base_plugs.BasePlug): TORN_DOWN = False - def tearDown(self): + + def tearDown(self): # pylint: disable=g-missing-super-call type(self).TORN_DOWN = True raise Exception() -class TearDownRaisesPlug2(plugs.BasePlug): +class TearDownRaisesPlug2(base_plugs.BasePlug): TORN_DOWN = False - def tearDown(self): + + def tearDown(self): # pylint: disable=g-missing-super-call type(self).TORN_DOWN = True raise Exception() @@ -78,7 +81,7 @@ def tearDown(self): super(PlugsTest, self).tearDown() def test_base_plug(self): - plug = plugs.BasePlug() + plug = base_plugs.BasePlug() 
self.assertEqual({}, plug._asdict()) plug.tearDown() @@ -95,41 +98,42 @@ def test_initialize(self): self.plug_manager.provide_plugs( (('adder_plug', AdderPlug),))['adder_plug']) adder_plug_name = AdderPlug.__module__ + '.AdderPlug' - self.assertEqual( - { - adder_plug_name: {'mro': [adder_plug_name]} - }, - self.plug_manager.as_base_types()['plug_descriptors'] - ) - self.assertEqual( - { - adder_plug_name: {'number': 0} - }, - self.plug_manager.as_base_types()['plug_states'] - ) + self.assertEqual({adder_plug_name: { + 'mro': [adder_plug_name] + }}, + self.plug_manager.as_base_types()['plug_descriptors']) + self.assertEqual({adder_plug_name: { + 'number': 0 + }}, + self.plug_manager.as_base_types()['plug_states']) self.assertEqual('CREATED', AdderPlug.LAST_INSTANCE.state) @test.yields_phases def test_multiple_plugs(self): + @plugs.plug(adder_plug=AdderPlug) @plugs.plug(other_plug=AdderPlug) - def dummy_phase(test_api, adder_plug, other_plug): + def dummy_phase(adder_plug, other_plug): self.assertEqual(1, AdderPlug.INSTANCE_COUNT) self.assertIs(AdderPlug.LAST_INSTANCE, adder_plug) self.assertIs(AdderPlug.LAST_INSTANCE, other_plug) + yield dummy_phase - @plugs.plug(adder_plug=AdderPlug, - other_plug=plugs.BasePlug) - def dummy_phase(test_api, adder_plug, other_plug): + @plugs.plug(adder_plug=AdderPlug, other_plug=base_plugs.BasePlug) + def dummy_phase(adder_plug, other_plug): + del other_plug # Unused. 
self.assertEqual(1, AdderPlug.INSTANCE_COUNT) self.assertIs(AdderPlug.LAST_INSTANCE, adder_plug) + yield dummy_phase @test.yields_phases def test_plug_logging(self): """Test that both __init__ and other functions get the good logger.""" - class LoggingPlug(plugs.BasePlug): + + class LoggingPlug(base_plugs.BasePlug): + def __init__(self): self.logger_seen_init = self.logger @@ -137,7 +141,7 @@ def action(self): self.logger_seen_action = self.logger @plugs.plug(logger=LoggingPlug) - def dummy_phase(test_api, logger): + def dummy_phase(logger): logger.action() self.assertIs(logger.logger_seen_init, logger.logger_seen_action) self.assertIs(logger.logger_seen_init, self.logger) @@ -146,8 +150,8 @@ def dummy_phase(test_api, logger): def test_tear_down_raises(self): """Test that all plugs get torn down even if some raise.""" - self.plug_manager.initialize_plugs({ - TearDownRaisesPlug1, TearDownRaisesPlug2}) + self.plug_manager.initialize_plugs( + {TearDownRaisesPlug1, TearDownRaisesPlug2}) self.plug_manager.tear_down_plugs() self.assertTrue(TearDownRaisesPlug1.TORN_DOWN) self.assertTrue(TearDownRaisesPlug2.TORN_DOWN) @@ -155,47 +159,52 @@ def test_tear_down_raises(self): def test_plug_updates(self): self.plug_manager.initialize_plugs({AdderPlug}) adder_plug_name = AdderPlug.__module__ + '.AdderPlug' - update = self.plug_manager.wait_for_plug_update( - adder_plug_name, {}, .001) + update = self.plug_manager.wait_for_plug_update(adder_plug_name, {}, .001) self.assertEqual({'number': 0}, update) # No update since last time, this should time out (return None). 
- self.assertIsNone(self.plug_manager.wait_for_plug_update( - adder_plug_name, update, .001)) + self.assertIsNone( + self.plug_manager.wait_for_plug_update(adder_plug_name, update, .001)) def _delay_then_update(): time.sleep(.5) self.assertEqual(1, AdderPlug.LAST_INSTANCE.increment()) + threading.Thread(target=_delay_then_update).start() start_time = time.time() - self.assertEqual({'number': 1}, self.plug_manager.wait_for_plug_update( - adder_plug_name, update, 5)) + self.assertEqual({'number': 1}, + self.plug_manager.wait_for_plug_update( + adder_plug_name, update, 5)) self.assertGreater(time.time() - start_time, .2) def test_invalid_plug(self): - with self.assertRaises(plugs.InvalidPlugError): - self.plug_manager.initialize_plugs({object}) - with self.assertRaises(plugs.InvalidPlugError): - plugs.plug(adder_plug=object) - with self.assertRaises(plugs.InvalidPlugError): - self.plug_manager.initialize_plugs({ - type('BadPlug', (plugs.BasePlug,), {'logger': None})}) - with self.assertRaises(plugs.InvalidPlugError): - class BadPlugInit(plugs.BasePlug): + with self.assertRaises(base_plugs.InvalidPlugError): + self.plug_manager.initialize_plugs({object}) # pytype: disable=wrong-arg-types + with self.assertRaises(base_plugs.InvalidPlugError): + plugs.plug(adder_plug=object) # pytype: disable=wrong-arg-types + with self.assertRaises(base_plugs.InvalidPlugError): + self.plug_manager.initialize_plugs( + {type('BadPlug', (base_plugs.BasePlug,), {'logger': None})}) + with self.assertRaises(base_plugs.InvalidPlugError): + + class BadPlugInit(base_plugs.BasePlug): + def __init__(self): self.logger = None + self.plug_manager.initialize_plugs({BadPlugInit}) - with self.assertRaises(plugs.InvalidPlugError): + with self.assertRaises(base_plugs.InvalidPlugError): self.plug_manager.wait_for_plug_update('invalid', {}, 0) def test_duplicate_plug(self): with self.assertRaises(plugs.DuplicatePlugError): + @plugs.plug(adder_plug=AdderPlug) @plugs.plug(adder_plug=AdderPlug) - def 
dummy_phase(test, adder_plug): - pass + def dummy_phase(adder_plug): + del adder_plug # Unused. def test_uses_base_tear_down(self): - self.assertTrue(plugs.BasePlug().uses_base_tear_down()) + self.assertTrue(base_plugs.BasePlug().uses_base_tear_down()) self.assertTrue(DummyPlug().uses_base_tear_down()) self.assertFalse(AdderPlug().uses_base_tear_down()) self.assertFalse(AdderSubclassPlug().uses_base_tear_down()) diff --git a/test/plugs/user_input_test.py b/test/plugs/user_input_test.py index e9b81e5a9..5bdf51103 100644 --- a/test/plugs/user_input_test.py +++ b/test/plugs/user_input_test.py @@ -29,6 +29,7 @@ def tearDown(self): self.plug.tearDown() def test_respond_to_blocking_prompt(self): + def _respond_to_prompt(): as_dict = None while not as_dict: @@ -48,9 +49,8 @@ def test_respond_to_non_blocking_prompt(self): self.assertIsNotNone(self.plug._asdict()) - response_used = self.plug.respond(prompt_id, 'Mock response.') + self.plug.respond(prompt_id, 'Mock response.') - self.assertTrue(response_used) self.assertIsNone(self.plug._asdict()) self.assertEqual(self.plug.last_response, (prompt_id, 'Mock response.')) diff --git a/test/test_state_test.py b/test/test_state_test.py index bdf0c65f0..f03a21f38 100644 --- a/test/test_state_test.py +++ b/test/test_state_test.py @@ -20,7 +20,7 @@ import mock import openhtf -from openhtf.core import phase_group +from openhtf.core import phase_collections from openhtf.core import test_descriptor from openhtf.core import test_record from openhtf.core import test_state @@ -52,6 +52,7 @@ def test_phase(): }, 'attachments': {}, 'start_time_millis': 11235, + 'subtest_name': None, } PHASE_RECORD_BASE_TYPE = copy.deepcopy(PHASE_STATE_BASE_TYPE_INITIAL) @@ -69,7 +70,11 @@ def test_phase(): 'status': 'WAITING_FOR_TEST_START', 'test_record': { 'station_id': conf.station_id, - 'code_info': None, + 'code_info': { + 'docstring': None, + 'name': '', + 'sourcecode': '', + }, 'dut_id': None, 'start_time_millis': 0, 'end_time_millis': None, @@ -79,6 
+84,8 @@ def test_phase(): 'config': {} }, 'phases': [], + 'subtests': [], + 'branches': [], 'diagnosers': [], 'diagnoses': [], 'log_records': [], @@ -95,12 +102,13 @@ class TestTestApi(unittest.TestCase): def setUp(self): super(TestTestApi, self).setUp() - patcher = mock.patch.object(test_record.PhaseRecord, 'record_start_time', - return_value=11235) + patcher = mock.patch.object( + test_record.PhaseRecord, 'record_start_time', return_value=11235) self.mock_record_start_time = patcher.start() self.addCleanup(patcher.stop) self.test_descriptor = test_descriptor.TestDescriptor( - phase_group.PhaseGroup(main=[test_phase]), None, {'config': {}}) + phase_collections.PhaseSequence((test_phase,)), + test_record.CodeInfo.uncaptured(), {'config': {}}) self.test_state = test_state.TestState(self.test_descriptor, 'testing-123', test_descriptor.TestOptions()) self.test_record = self.test_state.test_record @@ -116,6 +124,10 @@ def test_get_attachment(self): self.test_api.attach(attachment_name, input_contents, mimetype) output_attachment = self.test_api.get_attachment(attachment_name) + if not output_attachment: + # Need branch to appease pytype. + self.fail('output_attachment not found') + self.assertEqual(input_contents, output_attachment.data) self.assertEqual(mimetype, output_attachment.mimetype) @@ -123,6 +135,9 @@ def test_get_measurement(self): measurement_val = [1, 2, 3] self.test_api.measurements['test_measurement'] = measurement_val measurement = self.test_api.get_measurement('test_measurement') + if not measurement: + # Need branch to appease pytype. + self.fail('measurement not found.') self.assertEqual(measurement_val, measurement.value) self.assertEqual('test_measurement', measurement.name) @@ -131,6 +146,9 @@ def test_get_measurement_immutable(self): measurement_val = [1, 2, 3] self.test_api.measurements['test_measurement'] = measurement_val measurement = self.test_api.get_measurement('test_measurement') + if not measurement: + # Need branch to appease pytype. 
+ self.fail('measurement not found.') self.assertEqual(measurement_val, measurement.value) self.assertEqual('test_measurement', measurement.name) @@ -145,6 +163,9 @@ def test_infer_mime_type_from_file_name(self): file_name = f.name self.test_api.attach_from_file(file_name, 'attachment') attachment = self.test_api.get_attachment('attachment') + if not attachment: + # Need branch to appease pytype. + self.fail('attachment not found.') self.assertEqual(attachment.mimetype, 'text/plain') def test_infer_mime_type_from_attachment_name(self): @@ -154,6 +175,9 @@ def test_infer_mime_type_from_attachment_name(self): file_name = f.name self.test_api.attach_from_file(file_name, 'attachment.png') attachment = self.test_api.get_attachment('attachment.png') + if not attachment: + # Need branch to appease pytype. + self.fail('attachment not found.') self.assertEqual(attachment.mimetype, 'image/png') def test_phase_state_cache(self): diff --git a/test/util/conf_test.py b/test/util/conf_test.py index a2a4b2c10..fa6076a8c 100644 --- a/test/util/conf_test.py +++ b/test/util/conf_test.py @@ -19,15 +19,16 @@ from openhtf.util import conf import six - args = [ '--config-value=flag_key=flag_value', - '--config-value', 'other_flag=other_value', + '--config-value', + 'other_flag=other_value', # You can specify arbitrary keys, but they'll get ignored if they aren't # actually declared anywhere (included here to make sure of that). 
'--config_value=undeclared_flag=who_cares', '--config-value=true_value=true', - '--config-value', 'num_value=100', + '--config-value', + 'num_value=100', ] conf.declare('flag_key') @@ -49,12 +50,14 @@ class TestConf(unittest.TestCase): NOT_A_DICT = os.path.join(os.path.dirname(__file__), 'bad_config.yaml') def setUp(self): + super(TestConf, self).setUp() flags, _ = conf.ARG_PARSER.parse_known_args(args) conf.load_flag_values(flags) def tearDown(self): conf._flags.config_file = None conf.reset() + super(TestConf, self).tearDown() def test_yaml_config(self): with io.open(self.YAML_FILENAME, encoding='utf-8') as yamlfile: @@ -80,7 +83,7 @@ def test_defaults(self): self.assertEqual('default', conf.string_default) self.assertIsNone(conf.none_default) with self.assertRaises(conf.UnsetKeyError): - conf.no_default + conf.no_default # pylint: disable=pointless-statement def test_flag_values(self): self.assertEqual('flag_value', conf.flag_key) @@ -107,16 +110,15 @@ def test_as_dict(self): 'string_default': 'default', } # assert first dict is a subset of second dict - self.assertLessEqual(six.viewitems(expected_dict), - six.viewitems(conf_dict)) + self.assertLessEqual(six.viewitems(expected_dict), six.viewitems(conf_dict)) def test_undeclared(self): with self.assertRaises(conf.UndeclaredKeyError): - conf.undeclared + conf.undeclared # pylint: disable=pointless-statement def test_weird_attribute(self): with self.assertRaises(AttributeError): - conf._dont_do_this + conf._dont_do_this # pylint: disable=pointless-statement with self.assertRaises(AttributeError): conf._dont_do_this_either = None @@ -147,6 +149,7 @@ def test_bad_config_file(self): conf.reset() def test_save_and_restore(self): + @conf.save_and_restore def modifies_conf(): conf.load(string_default='modified') @@ -157,6 +160,7 @@ def modifies_conf(): self.assertEqual('default', conf.string_default) def test_save_and_restore_kwargs(self): + @conf.save_and_restore(string_default='modified') def modifies_conf(): 
self.assertEqual('modified', conf.string_default) @@ -166,15 +170,17 @@ def modifies_conf(): self.assertEqual('default', conf.string_default) def test_inject_positional_args(self): + @conf.inject_positional_args def test_function(string_default, no_default, not_declared): self.assertEqual('default', string_default) self.assertEqual('passed_value', no_default) self.assertEqual('not_declared', not_declared) - test_function(no_default='passed_value', not_declared='not_declared') + test_function(no_default='passed_value', not_declared='not_declared') # pylint: disable=no-value-for-parameter def test_inject_positional_args_overrides(self): + @conf.inject_positional_args def test_function(string_default, none_default='new_default'): # Make sure when we pass a kwarg, it overrides the config value. @@ -185,10 +191,12 @@ def test_function(string_default, none_default='new_default'): test_function(string_default='overridden') def test_inject_positional_args_class(self): - class test_class(object): + + class TestClass(object): + @conf.inject_positional_args def __init__(self, string_default): self.string_default = string_default - instance = test_class() - self.assertEqual('default', instance.string_default) + instance = TestClass() # pylint: disable=no-value-for-parameter + self.assertEqual('default', instance.string_default) # pytype: disable=attribute-error diff --git a/test/util/data_test.py b/test/util/data_test.py index 7a2a6c1c3..02938cc73 100644 --- a/test/util/data_test.py +++ b/test/util/data_test.py @@ -16,14 +16,14 @@ import unittest import attr -from builtins import int from openhtf.util import data -from past.builtins import long +import past.builtins class TestData(unittest.TestCase): def test_convert_to_base_types(self): + class FloatSubclass(float): pass @@ -61,9 +61,9 @@ class AnotherAttr(object): 'tuple': (10,), 'str': '10', 'unicode': '10', - 'int': 2 ** 40, + 'int': 2**40, 'float': 10.0, - 'long': 2 ** 80, + 'long': 2**80, 'bool': True, 'none': None, 
'complex': 10j, @@ -74,10 +74,8 @@ class AnotherAttr(object): # Some plugs such as UserInputPlug will return None as a response to # AsDict(). 'none_dict': AsDict(), - 'frozen1': FrozenAttr(value=42), 'another_attr': AnotherAttr(frozen=FrozenAttr(value=19)), - } converted = data.convert_to_base_types(example_data) @@ -87,7 +85,7 @@ class AnotherAttr(object): self.assertIsInstance(converted['unicode'], str) self.assertIsInstance(converted['int'], int) self.assertIsInstance(converted['float'], float) - self.assertIsInstance(converted['long'], long) + self.assertIsInstance(converted['long'], past.builtins.long) self.assertIsInstance(converted['bool'], bool) self.assertIsNone(converted['none']) self.assertIsInstance(converted['complex'], str) @@ -96,10 +94,7 @@ class AnotherAttr(object): self.assertEqual(converted['special'], {'safe_value': True}) self.assertIs(converted['not_copied'], not_copied.value) - self.assertEqual(converted['none_dict'], None) + self.assertIsNone(converted['none_dict']) self.assertEqual(converted['frozen1'], {'value': 42}) self.assertEqual(converted['another_attr'], {'frozen': {'value': 19}}) - -if __name__ == '__main__': - unittest.main() diff --git a/test/util/functions_test.py b/test/util/functions_test.py index a27075fde..44f222408 100644 --- a/test/util/functions_test.py +++ b/test/util/functions_test.py @@ -15,15 +15,17 @@ import unittest import mock - from openhtf.util import functions class MockTime(object): + def __init__(self): self._time = 0 + def sleep(self, seconds): self._time += seconds + def time(self): self._time += 1 return self._time - 1 @@ -31,24 +33,30 @@ def time(self): class TestFunctions(unittest.TestCase): - def test_call_once_fails_with_args(self): + def test_call_once_fails_single_arg(self): with self.assertRaises(ValueError): + @functions.call_once - def has_args(x): - pass + def has_args(x): # pylint: disable=unused-variable + del x # Unused. 
+ def test_call_once_fails_star_args(self): with self.assertRaises(ValueError): + @functions.call_once - def has_args(*args): - pass + def has_args(*args): # pylint: disable=unused-variable + del args # Unused. + def test_call_once_fails_kwargs(self): with self.assertRaises(ValueError): + @functions.call_once - def has_args(**kwargs): - pass + def has_args(**kwargs): # pylint: disable=unused-variable + del kwargs # Unused. def test_call_once(self): calls = [] + @functions.call_once def can_only_call_once(): calls.append(None) @@ -61,10 +69,12 @@ def can_only_call_once(): @mock.patch('openhtf.util.functions.time', new_callable=MockTime) def testCallAtMostEvery(self, mock_time): call_times = [] + @functions.call_at_most_every(5) - def CallOnceEveryFiveSeconds(): + def _call_once_every_five_seconds(): call_times.append(mock_time.time()) + for _ in range(100): - CallOnceEveryFiveSeconds() + _call_once_every_five_seconds() # Each call takes "6 seconds", so we get call times up to 600. - self.assertEqual(list(range(2, 600, 6)), call_times) + self.assertEqual(list(range(2, 600, 6)), call_times) diff --git a/test/util/logs_test.py b/test/util/logs_test.py index 17433bcb7..e7baa253e 100644 --- a/test/util/logs_test.py +++ b/test/util/logs_test.py @@ -16,9 +16,9 @@ import unittest import mock - from openhtf.util import logs + class TestLogs(unittest.TestCase): def test_log_once(self): @@ -34,4 +34,3 @@ def test_log_once_utf8(self): logs.log_once(mock_log, u'状态是', 'arg1') assert mock_log.call_count == 1 - diff --git a/test/util/test_test.py b/test/util/test_test.py index 5638056b5..a49a54459 100644 --- a/test/util/test_test.py +++ b/test/util/test_test.py @@ -17,6 +17,7 @@ import openhtf from openhtf import plugs +from openhtf.core import base_plugs from openhtf.core import measurements from openhtf.util import test from openhtf.util import validators @@ -26,7 +27,7 @@ class DummyError(Exception): """Raised for testing phases that raise.""" -class MyPlug(plugs.BasePlug): 
+class MyPlug(base_plugs.BasePlug): """Stub plug for ensuring plugs get mocked correctly.""" def __init__(self): @@ -57,8 +58,10 @@ def raising_phase(): def phase_retval(retval): """Helper function to generate a phase that returns the given retval.""" + def phase(): return retval + return phase @@ -148,8 +151,10 @@ def test_bad_assert(self): self.assertMeasured(None) def test_doesnt_yield(self): + def doesnt_yield(cls_self): # pylint: disable=unused-argument pass + with self.assertRaises(test.InvalidTestError): test.yields_phases(doesnt_yield)(self) @@ -164,18 +169,22 @@ def stub_test_method(cls_self, plug_one, plug_two): # pylint: disable=unused-ar # Test that we catch mocks that aren't expected. with self.assertRaises(test.InvalidTestError): - test.patch_plugs(plug_one='unused', plug_two='unused', - plug_three='unused')(stub_test_method) + test.patch_plugs( + plug_one='unused', plug_two='unused', plug_three='unused')( + stub_test_method) # Test that we catch weird plug specifications. with self.assertRaises(ValueError): - test.patch_plugs(plug_one='bad_spec_no_dots', - plug_two='unused')(stub_test_method) + test.patch_plugs( + plug_one='bad_spec_no_dots', plug_two='unused')( + stub_test_method) with self.assertRaises(KeyError): - test.patch_plugs(plug_one='bad.spec.invalid.module', - plug_two='also.bad')(stub_test_method) + test.patch_plugs( + plug_one='bad.spec.invalid.module', plug_two='also.bad')( + stub_test_method) def test_bad_yield(self): + def bad_test(cls_self): # pylint: disable=unused-argument yield None diff --git a/test/util/util_test.py b/test/util/util_test.py index 38989cdc1..5e0ffc83d 100644 --- a/test/util/util_test.py +++ b/test/util/util_test.py @@ -12,26 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- import copy -import mock import time import unittest +import mock from openhtf import util from openhtf.util import timeouts class TestUtil(unittest.TestCase): - @classmethod - def setUp(cls): - cls.timeout = 60 - cls.polledtimeout = timeouts.PolledTimeout(cls.timeout) - - @classmethod - def tearDown(cls): - pass + def setUp(self): + super(TestUtil, self).setUp() + self.timeout = 60 + self.polledtimeout = timeouts.PolledTimeout(self.timeout) @mock.patch.object(time, 'time') def test_time_expired_false(self, mock_time): @@ -49,8 +44,8 @@ def test_time_expired_true(self): def test_partial_format(self): original = ('Apples are {apple[color]} and {apple[taste]}. ' - 'Pears are {pear.color} and {pear.taste}. ' - 'Oranges are {orange_color} and {orange_taste}.') + 'Pears are {pear.color} and {pear.taste}. ' + 'Oranges are {orange_color} and {orange_taste}.') text = copy.copy(original) apple = { @@ -61,18 +56,21 @@ def test_partial_format(self): class Pear(object): color = 'green' taste = 'tart' + pear = Pear() # Partial formatting res = util.partial_format(text, apple=apple) res = util.partial_format(res, pear=pear) - self.assertEqual('Apples are red and sweet. Pears are green and tart. ' - 'Oranges are {orange_color} and {orange_taste}.', res) + self.assertEqual( + 'Apples are red and sweet. Pears are green and tart. ' + 'Oranges are {orange_color} and {orange_taste}.', res) # Format rest of string res = util.partial_format(res, orange_color='orange', orange_taste='sour') - self.assertEqual('Apples are red and sweet. Pears are green and tart. ' - 'Oranges are orange and sour.', res) + self.assertEqual( + 'Apples are red and sweet. Pears are green and tart. 
' + 'Oranges are orange and sour.', res) # The original text has not changed self.assertEqual(original, text) diff --git a/test/util/validators_test.py b/test/util/validators_test.py index b778ada1f..ccba2a59f 100644 --- a/test/util/validators_test.py +++ b/test/util/validators_test.py @@ -1,12 +1,13 @@ -"""Unit tests for util/validators.py""" +"""Unit tests for util/validators.py.""" import copy import decimal -import six import unittest -from builtins import int +import openhtf as htf +from openhtf.util import test as htf_test from openhtf.util import validators +import six class TestInRange(unittest.TestCase): @@ -68,12 +69,14 @@ def test_with_custom_type(self): class TestEqualsValidator(unittest.TestCase): def test_with_built_in_pods(self): - for val in [1, '1', 1.0, False, (1,), [1], {1:1}]: + for val in [1, '1', 1.0, False, (1,), [1], {1: 1}]: self.assertTrue(validators.Equals(val)(val)) def test_with_custom_class(self): + class MyType(object): A = 10 + my_type = MyType() self.assertTrue(validators.Equals(my_type)(my_type)) @@ -111,8 +114,10 @@ def test_with_string(self): self.assertFalse(string_validator('aardvarka')) def test_with_object(self): + class MyType(object): val = 'A' + my_type_a = MyType() object_validator = validators.equals(my_type_a) self.assertTrue(object_validator(my_type_a)) @@ -151,8 +156,10 @@ def test_within_percent_greater_than_one_hundred(self): def test_equals_equivalent_within_percent_validator(self): validator_a = validators.WithinPercent(expected=100, percent=10) validator_b = validators.WithinPercent(expected=100, percent=10) - self.assertEqual(validator_a, validator_b, - msg='Validators should compare equal, but did not.') + self.assertEqual( + validator_a, + validator_b, + msg='Validators should compare equal, but did not.') def test_not_equals_when_not_equivalent(self): validator_a = validators.WithinPercent(expected=100, percent=10) @@ -174,77 +181,72 @@ def test_is_deep_copyable(self): # state in a non-deepcopyable manner. 
validator_a(1) str(validator_a) - validator_a == 'a' + validator_a == 'a' # pylint: disable=pointless-statement validator_b = copy.deepcopy(validator_a) self.assertEqual(validator_a.expected, validator_b.expected) self.assertEqual(validator_a.percent, validator_b.percent) -class TestWithinTolerance(unittest.TestCase): +class DimensionPivotTest(htf_test.TestCase): + """Tests validators.DimensionPivot. Used with dimensioned measurements.""" - def test_raises_for_negative_tolerance(self): - with six.assertRaisesRegex(self, ValueError, 'tolerance argument is'): - validators.WithinTolerance(expected=5.0, tolerance=-0.1) + _test_value = 10 + _sub_validator = validators.in_range(0, _test_value) + _test_measurement = htf.Measurement('pivot').with_dimensions( + 'height', 'width').dimension_pivot_validate(_sub_validator) - def test_within_tolerance_small(self): - validator = validators.WithinTolerance(expected=5.0, tolerance=0.1) - for valid_value in [5.0, 5.01, 5.09, 5.0999, 4.9, 4.91]: - self.assertTrue( - validator(valid_value), - msg='{} should validate, but did not'.format(valid_value)) - for invalid_value in [0, 0.01, -10.0, 10.0, 5.2, 5.11, 4.89]: - self.assertFalse( - validator(invalid_value), - msg='{} should not validate, but did'.format(invalid_value)) + @htf_test.yields_phases + def testPasses(self): - def test_within_tolerance_large(self): - validator = validators.WithinTolerance(expected=0.0, tolerance=100.0) - for valid_value in [0.0, -90.5, 100.0, -100.0, -1.3, -99.9]: - self.assertTrue( - validator(valid_value), - msg='{} should validate, but did not'.format(valid_value)) - for invalid_value in [100.001, 1000.0, -200.0, -100.1, 1e6]: - self.assertFalse( - validator(invalid_value), - msg='{} should not validate, but did'.format(invalid_value)) + @htf.measures(self._test_measurement) + def phase(test): + test.measurements.pivot[10, 10] = self._test_value - 2 + test.measurements.pivot[11, 10] = self._test_value - 1 - def test_within_tolerance_negative(self): - 
validator = validators.WithinTolerance(expected=5.0, tolerance=0.1) - for valid_value in [5.0, 5.01, 5.09, 5.0999, 4.9, 4.91]: - self.assertTrue( - validator(valid_value), - msg='{} should validate, but did not'.format(valid_value)) - for invalid_value in [0, 0.01, -10.0, 10.0, 5.2, 5.11, 4.89]: - self.assertFalse( - validator(invalid_value), - msg='{} should not validate, but did'.format(invalid_value)) + phase_record = yield phase + self.assertMeasurementPass(phase_record, 'pivot') - def test_equals_equivalent_within_tolerance_validator(self): - validator_a = validators.WithinTolerance(expected=5.0, tolerance=0.1) - validator_b = validators.WithinTolerance(expected=5.0, tolerance=0.1) - self.assertEqual(validator_a, validator_b, - msg='Validators should compare equal, but did not.') + @htf_test.yields_phases + def testFails(self): - def test_not_equals_when_not_equivalent(self): - validator_a = validators.WithinTolerance(expected=5.0, tolerance=0.1) - validator_b = validators.WithinTolerance(expected=5.0, tolerance=0.2) - validator_c = validators.WithinTolerance(expected=4.0, tolerance=0.1) - for validator in [validator_b, validator_c]: - self.assertNotEqual(validator_a, validator) + @htf.measures(self._test_measurement) + def phase(test): + test.measurements.pivot[11, 12] = self._test_value - 1 + test.measurements.pivot[14, 12] = self._test_value + 1 - def test_string_representation_does_not_raise(self): - validator_a = validators.WithinTolerance(expected=5.0, tolerance=0.1) - str(validator_a) - # Check that we constructed a usable validator. - self.assertTrue(validator_a(5.0)) + phase_record = yield phase + self.assertMeasurementFail(phase_record, 'pivot') - def test_is_deep_copyable(self): - validator_a = validators.WithinTolerance(expected=5.0, tolerance=0.1) - # Call implemented functions, try catch the cases where they might change - # state in a non-deepcopyable manner. 
- validator_a(1) - str(validator_a) - validator_a == 'a' - validator_b = copy.deepcopy(validator_a) - self.assertEqual(validator_a.expected, validator_b.expected) - self.assertEqual(validator_a.tolerance, validator_b.tolerance) + +class ConsistentEndDimensionPivotTest(htf_test.TestCase): + """Tests validators.ConsistentEndRange. Similar to DimensionPivot.""" + + _sub_validator = validators.in_range(minimum=5) + _test_measurement = htf.Measurement('pivot').with_dimensions( + 'time').consistent_end_dimension_pivot_validate(_sub_validator) + + @htf_test.yields_phases + def testPasses(self): + + @htf.measures(self._test_measurement) + def phase(test): + test.measurements.pivot[0] = 0 + test.measurements.pivot[1] = 2 + test.measurements.pivot[2] = 6 + test.measurements.pivot[3] = 8 + + phase_record = yield phase + self.assertMeasurementPass(phase_record, 'pivot') + + @htf_test.yields_phases + def testFails(self): + + @htf.measures(self._test_measurement) + def phase(test): + test.measurements.pivot[0] = 3 + test.measurements.pivot[1] = 4 + test.measurements.pivot[2] = 6 + test.measurements.pivot[3] = 4 + + phase_record = yield phase + self.assertMeasurementFail(phase_record, 'pivot') diff --git a/tox.ini b/tox.ini index 61369a961..2f7f667ca 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py27 +envlist = py36,py37 [testenv] deps = -r{toxinidir}/test_reqs.txt