diff --git a/openhtf/output/callbacks/json_factory.py b/openhtf/output/callbacks/json_factory.py index b6b6fdf5d..796135ecf 100644 --- a/openhtf/output/callbacks/json_factory.py +++ b/openhtf/output/callbacks/json_factory.py @@ -6,19 +6,10 @@ from openhtf.core import test_record from openhtf.output import callbacks from openhtf.util import data +from openhtf.util import json_encoder import six -class TestRecordEncoder(json.JSONEncoder): - - def default(self, obj): - if isinstance(obj, test_record.Attachment): - dct = obj._asdict() - dct['data'] = base64.standard_b64encode(obj.data).decode('utf-8') - return dct - return super(TestRecordEncoder, self).default(obj) - - class OutputToJSON(callbacks.OutputToFile): """Return an output callback that writes JSON Test Records. @@ -47,7 +38,7 @@ def __init__(self, filename_pattern=None, inline_attachments=True, **kwargs): # Conform strictly to the JSON spec by default. kwargs.setdefault('allow_nan', False) self.allow_nan = kwargs['allow_nan'] - self.json_encoder = TestRecordEncoder(**kwargs) + self.json_encoder = json_encoder.TestRecordEncoder(**kwargs) def serialize_test_record(self, test_record): return self.json_encoder.iterencode(self.convert_to_dict(test_record)) diff --git a/openhtf/output/proto/mfg_event_converter.py b/openhtf/output/proto/mfg_event_converter.py index 1c1acb989..5e6525a59 100644 --- a/openhtf/output/proto/mfg_event_converter.py +++ b/openhtf/output/proto/mfg_event_converter.py @@ -23,6 +23,7 @@ from openhtf.util import data as htf_data from openhtf.util import units from openhtf.util import validators +from openhtf.util import json_encoder from past.builtins import unicode @@ -172,9 +173,9 @@ def _convert_object_to_json(obj): # Since there will be parts of this that may have unicode, either as # measurement or in the logs, we have to be careful and convert everything # to unicode, merge, then encode to UTF-8 to put it into the proto. - json_encoder = json.JSONEncoder(sort_keys=True, indent=2, ensure_ascii=False) + encoder = json_encoder.TestRecordEncoder(sort_keys=True, indent=2, ensure_ascii=False) pieces = [] - for piece in json_encoder.iterencode(obj): + for piece in encoder.iterencode(obj): if isinstance(piece, bytes): pieces.append(unicode(piece, errors='replace')) else: diff --git a/openhtf/output/servers/station_server.py b/openhtf/output/servers/station_server.py index fe20bb2bb..d6bc9d6da 100644 --- a/openhtf/output/servers/station_server.py +++ b/openhtf/output/servers/station_server.py @@ -29,6 +29,7 @@ from openhtf.util import logs from openhtf.util import multicast from openhtf.util import timeouts +from openhtf.util import json_encoder STATION_SERVER_TYPE = 'station' @@ -556,6 +557,11 @@ def __init__(self, history_path=None): if not tornado_logger.handlers: tornado_logger.addHandler(logging.NullHandler()) + # Override tornado's json encoding to handle our Attachments. + def _json_encode(value): + return json_encoder.TestRecordEncoder().encode(value) + sockjs.tornado.proto.json_encode = _json_encode + # Bind port early so that the correct port number can be used in the routes. sockets, port = web_gui_server.bind_port(int(conf.station_server_port)) diff --git a/openhtf/util/json_encoder.py b/openhtf/util/json_encoder.py new file mode 100644 index 000000000..bd3ede332 --- /dev/null +++ b/openhtf/util/json_encoder.py @@ -0,0 +1,29 @@ +# Copyright 2016 Google Inc. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import json + +from openhtf.core import test_record + + +class TestRecordEncoder(json.JSONEncoder): + """JSON encoder that supports Attachments.""" + + def default(self, obj): + if isinstance(obj, test_record.Attachment): + dct = obj._asdict() + dct['data'] = base64.standard_b64encode(obj.data).decode('utf-8') + return dct + return super(TestRecordEncoder, self).default(obj)