diff --git a/src/powerapi/database/socket_db.py b/src/powerapi/database/socket_db.py index 461dfd4d..825fd3a1 100644 --- a/src/powerapi/database/socket_db.py +++ b/src/powerapi/database/socket_db.py @@ -27,99 +27,151 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import asyncio -from typing import Type, List -import json - -from powerapi.utils import JsonStream +import logging +from json import JSONDecoder, JSONDecodeError +from queue import SimpleQueue, Empty +from socketserver import ThreadingMixIn, TCPServer, StreamRequestHandler +from threading import Thread +from typing import Type, Iterator + +from powerapi.database.base_db import IterDB, BaseDB from powerapi.report import Report -from .base_db import IterDB, BaseDB, DBError -class IterSocketDB(IterDB): +class ThreadedTCPServer(ThreadingMixIn, TCPServer): """ - iterator connected to a socket that receive report from a sensor + TCP Server implementation. + Each client connected will be served by a separate thread. """ + daemon_threads = True + allow_reuse_address = True - def __init__(self, report_type, stream_mode, queue): + def __init__(self, server_address, request_handler_class, received_data_queue: SimpleQueue): """ + :param server_address: The address to listen on + :param request_handler_class: The request handler class to use when receiving requests + :param received_data_queue: The data queue to store the received data """ - IterDB.__init__(self, None, report_type, stream_mode) + super().__init__(server_address, request_handler_class) + self.received_data_queue = received_data_queue + - self.queue = queue +class JsonRequestHandler(StreamRequestHandler): + """ + Request handler that handles JSON documents received from the client. + """ + server: ThreadedTCPServer + + @staticmethod + def parse_json_documents(data: str) -> Iterator[dict]: + """ + Try to parse json document(s) from the given string. + This function tolerates truncated and malformed json documents. + :param data: The raw data to decode + :return: Iterator over parsed json documents + """ + decoder = JSONDecoder() + idx = 0 + while idx < len(data): + try: + obj, end_idx = decoder.raw_decode(data, idx) + yield obj + idx += end_idx + + # Search and try to parse the remaining document(s) + except JSONDecodeError as e: + idx = data.find('{', e.pos) + if idx == -1: + break + + def handle(self): + """ + Handle incoming connections. + The received data is parsed and the result(s) stored in the data queue for further processing. + It is expected for the data to be in json format (utf-8 charset) and newline terminated. + """ + caddr = '{}:{}'.format(*self.client_address) + logging.info('New incoming connection from %s', caddr) + + while True: + try: + data = self.rfile.readline() + if not data: + break + + for obj in self.parse_json_documents(data.decode('utf-8')): + self.server.received_data_queue.put(obj) + + except ValueError as e: + logging.warning('[%s] Received malformed data: %s', caddr, e) + continue + except OSError as e: + logging.error('[%s] Caught OSError while handling request: %s', caddr, e) + break + except KeyboardInterrupt: + break + + logging.info('Connection from %s closed', caddr) + + +class IterSocketDB(IterDB): + """ + SocketDB iterator that returns the received data. 
+ """ - def __aiter__(self): + def __iter__(self): return self - async def __anext__(self): + def __next__(self): try: - json_str = await asyncio.wait_for(self.queue.get(), 2) - # json = self.queue.get_nowait() - # self.queue.get() - report = self.report_type.from_json(json.loads(json_str)) - return report - # except Empty: - except asyncio.TimeoutError: - return None + document = self.db.received_data_queue.get(block=False) + return self.report_type.from_json(document) + except Empty as e: + raise StopIteration from e class SocketDB(BaseDB): """ - Database that act as a server that expose a socket where data source will push data + Database implementation that exposes a TCP socket the clients can connect to. """ def __init__(self, report_type: Type[Report], host: str, port: int): - BaseDB.__init__(self, report_type, is_async=True) - self.queue = None - self.host = host - self.port = port - self.server = None + """ + :param report_type: The type of report to create + :param host: The host address to listen on + :param port: The port number to listen on + """ + super().__init__(report_type) - async def connect(self): + self.server_address = (host, port) + + self.received_data_queue = None + self.background_thread = None + + def _tcpserver_background_thread_target(self): """ - Connect to the socket database. + Target function of the thread that will run the TCP server in background. """ - self.queue = asyncio.Queue() - self.server = await asyncio.start_server(self._gen_server_callback(), host=self.host, port=self.port) + with ThreadedTCPServer(self.server_address, JsonRequestHandler, self.received_data_queue) as server: + logging.info('TCP socket is listening on %s:%s', *self.server_address) + server.serve_forever() - async def disconnect(self): + def connect(self): """ - Disconnect from the socket database. + Connect to the socket database. """ + self.received_data_queue = SimpleQueue() + self.background_thread = Thread(target=self._tcpserver_background_thread_target, daemon=True) + self.background_thread.start() - async def stop(self): + def disconnect(self): """ - stop server connection + Disconnect from the socket database. """ - self.server.close() - await self.server.wait_closed() def iter(self, stream_mode: bool = False) -> IterSocketDB: - return IterSocketDB(self.report_type, stream_mode, self.queue) - - def _gen_server_callback(self): - async def callback(stream_reader, _): - stream = JsonStream(stream_reader) - count = 0 # If 10 times in a row we don't have a full message we stop - while True: - json_str = await stream.read_json_object() - if json_str is None: - if count > 10: - break - count += 1 - continue - count = 0 - await self.queue.put(json_str) - - # self.queue.put(json_str) - - return callback - - def __iter__(self): - raise DBError('Socket db don\'t support __iter__ method') - - def save(self, report: Report): - raise DBError('Socket db don\'t support save method') - - def save_many(self, reports: List[Report]): - raise DBError('Socket db don\'t support save_many method') + """ + Create the data iterator for the socket database. + :param stream_mode: Whether the data should be pulled continuously or not. 
+ """ + return IterSocketDB(self, self.report_type, stream_mode) diff --git a/src/powerapi/puller/handlers.py b/src/powerapi/puller/handlers.py index c470d8db..c46fde59 100644 --- a/src/powerapi/puller/handlers.py +++ b/src/powerapi/puller/handlers.py @@ -27,19 +27,18 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import time -import asyncio import logging +import time from threading import Thread from powerapi.actor import State -from powerapi.message import Message -from powerapi.exception import PowerAPIException, BadInputData +from powerapi.database import DBError +from powerapi.exception import PowerAPIException from powerapi.filter import FilterUselessError from powerapi.handler import StartHandler, PoisonPillMessageHandler -from powerapi.database import DBError from powerapi.message import ErrorMessage, PoisonPillMessage -from powerapi.report.report import DeserializationFail +from powerapi.message import Message +from powerapi.report import BadInputData class NoReportExtractedException(PowerAPIException): @@ -55,37 +54,11 @@ class DBPullerThread(Thread): """ def __init__(self, state, timeout, handler): - Thread.__init__(self, daemon=True) + super().__init__(daemon=True) self.timeout = timeout self.state = state - self.loop = None self.handler = handler - def _connect(self): - try: - self.state.database.connect() - self.loop.run_until_complete(self.state.database.connect()) - self.state.database_it = self.state.database.iter(self.state.stream_mode) - except DBError as error: - self.state.actor.send_control(ErrorMessage(sender_name='system', error_message=error.msg)) - self.state.alive = False - - def _pull_database(self): - try: - if self.state.database.is_async: - report = self.loop.run_until_complete(anext(self.state.database_it)) - if report is None: - raise StopIteration() - return report - - return next(self.state.database_it) - - except (StopIteration, BadInputData, DeserializationFail) as database_problem: - raise NoReportExtractedException() from database_problem - - def _get_dispatchers(self, report): - return self.state.report_filter.route(report) - def run(self): """ Read data from Database and send it to the dispatchers. @@ -95,37 +68,26 @@ def run(self): :param None msg: None. 
""" - if self.state.database.is_async: - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - self.state.loop = self.loop - self.loop.set_debug(enabled=True) - logging.basicConfig(level=logging.DEBUG) - - self._connect() - while self.state.alive: try: - raw_report = self._pull_database() - - dispatchers = self._get_dispatchers(raw_report) + raw_report = next(self.state.database_it) + dispatchers = self.state.report_filter.route(raw_report) for dispatcher in dispatchers: dispatcher.send_data(raw_report) - except NoReportExtractedException: - time.sleep(self.state.timeout_puller / 1000) - self.state.actor.logger.debug('NoReportExtractedException with stream mode ' + - str(self.state.stream_mode)) - if not self.state.stream_mode: - self.handler.handle_internal_msg(PoisonPillMessage(soft=False, sender_name='system')) - return - except FilterUselessError: - self.handler.handle_internal_msg(PoisonPillMessage(soft=False, sender_name='system')) + self.handler.handle_internal_msg(PoisonPillMessage(False, self.name)) return + except BadInputData as exn: + logging.error('Received malformed report from database: %s', exn.msg) + logging.debug('Raw report value: %s', exn.input_data) + except StopIteration: - continue + time.sleep(self.state.timeout_puller / 1000) + if not self.state.stream_mode: + self.handler.handle_internal_msg(PoisonPillMessage(False, self.name)) + return class PullerPoisonPillMessageHandler(PoisonPillMessageHandler): @@ -166,15 +128,15 @@ def handle_internal_msg(self, msg): StartHandler.delegate_message_handling(self, msg) def initialization(self): - - self._database_connection() - if not self.state.report_filter.filters: raise PullerInitializationException('No filters') + # Connect to all dispatcher for _, dispatcher in self.state.report_filter.filters: dispatcher.connect_data() + self._connect_database() + def handle(self, msg: Message): try: StartHandler.handle(self, msg) @@ -183,7 +145,7 @@ def handle(self, msg: Message): self.pull_db() - self.handle_internal_msg(PoisonPillMessage(soft=False, sender_name='system')) + self.handle_internal_msg(PoisonPillMessage(False, self.state.actor.name)) def pull_db(self): """ @@ -198,12 +160,10 @@ def pull_db(self): if msg is not None: self.handle_internal_msg(msg) - def _database_connection(self): + def _connect_database(self): try: - if not self.state.database.is_async: - self.state.database.connect() - self.state.database_it = self.state.database.iter(stream_mode=self.state.stream_mode) - + self.state.database.connect() + self.state.database_it = self.state.database.iter(self.state.stream_mode) except DBError as error: - self.state.actor.send_control(ErrorMessage(self.state.actor.name, error.msg)) self.state.alive = False + self.state.actor.send_control(ErrorMessage(self.state.actor.name, error.msg)) diff --git a/src/powerapi/puller/puller_actor.py b/src/powerapi/puller/puller_actor.py index 71f78766..238fd655 100644 --- a/src/powerapi/puller/puller_actor.py +++ b/src/powerapi/puller/puller_actor.py @@ -28,19 +28,12 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import logging + from powerapi.actor import Actor, State -from powerapi.exception import PowerAPIException from powerapi.message import PoisonPillMessage, StartMessage from powerapi.puller.handlers import PullerPoisonPillMessageHandler, PullerStartHandler -class NoReportExtractedException(PowerAPIException): - """ - Exception raised when we can't extract a report from the given - database - """ - - class PullerState(State): """ Puller Actor State diff --git a/src/powerapi/report/hwpc_report.py b/src/powerapi/report/hwpc_report.py index a7229686..2b728502 100644 --- a/src/powerapi/report/hwpc_report.py +++ b/src/powerapi/report/hwpc_report.py @@ -104,10 +104,12 @@ def from_json(data: Dict) -> HWPCReport: ts = Report._extract_timestamp(data[TIMESTAMP_KEY]) metadata = {} if METADATA_KEY not in data else data[METADATA_KEY] return HWPCReport(ts, data[SENSOR_KEY], data[TARGET_KEY], data[GROUPS_KEY], metadata) + except TypeError as exn: + raise BadInputData(f'Invalid input document: {exn.args[0]}', data) from exn except KeyError as exn: - raise BadInputData('HWPC report require field ' + str(exn.args[0]) + ' in json document', data) from exn + raise BadInputData(f'Missing required field "{exn.args[0]}" from input document', data) from exn except ValueError as exn: - raise BadInputData(exn.args[0], data) from exn + raise BadInputData(f'Unexpected field value in input document: {exn.args}', data) from exn @staticmethod def to_json(report: HWPCReport) -> Dict: diff --git a/src/powerapi/utils/__init__.py b/src/powerapi/utils/__init__.py index 6d947d09..5541640a 100644 --- a/src/powerapi/utils/__init__.py +++ b/src/powerapi/utils/__init__.py @@ -27,5 +27,4 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from .json_stream import JsonStream from .utils import timestamp_to_datetime diff --git a/src/powerapi/utils/json_stream.py b/src/powerapi/utils/json_stream.py deleted file mode 100644 index 27cf5d33..00000000 --- a/src/powerapi/utils/json_stream.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright (c) 2021, INRIA -# Copyright (c) 2021, University of Lille -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# * Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -DEFAULT_BUFFER_SIZE = 4096 - - -class JsonStream: - """read data received from a input utf-8 byte stream socket as a json stream - - :param stream_reader: - :param buffer_size: size of the buffer used to receive data from the socket, - it must match the average size of received json string - (default 4096 bytes) - """ - - def __init__(self, stream_reader, buffer_size=4096): - self.stream_reader = stream_reader - self.json_buffer = b'' - self.buffer_size = buffer_size - self.open_brackets = 0 - - async def _get_bytes(self): - data = await self.stream_reader.read(n=self.buffer_size) - return b'' if data is None else data - - def _extract_json_end_position(self, first_new_byte): - """ - Find the first valable report in the stream - """ - i = first_new_byte - # print("buffer is ", self.json_buffer[first_new_byte:]) - if len(self.json_buffer) == 0: - return -1 - if self.json_buffer[0] != 123: # ASCII code opening bracket - return -1 - - while i < len(self.json_buffer): - if self.json_buffer[i] == 125: # ASCII code closing bracket - self.open_brackets -= 1 - # print("opening : ", self.open_brackets) - elif self.json_buffer[i] == 123: # ASCII code opening bracket - self.open_brackets += 1 - # print("closing :",self.open_brackets) - if self.open_brackets == 0: - return i - i += 1 - - return -1 - - async def read_json_object(self): - """ - return all the json object received from the connection as a iteration of string - """ - if len(self.json_buffer) != 0 and self.open_brackets == 0: - # Last iteration _extract_json_end_position returned a json_object - # and breaked. If the buffer isn't empty wasn't treated, so we have - # to treat it - first_new_byte = 0 - else: - first_new_byte = len(self.json_buffer) - self.json_buffer += await self._get_bytes() - i = self._extract_json_end_position(first_new_byte) - - if i == -1: - return None - if i == len(self.json_buffer) - 1: - json_str = self.json_buffer[:] - self.json_buffer = b'' - # print("buffer empty") - # print(self.open_brackets) - else: - json_str = self.json_buffer[:i + 1] - self.json_buffer = self.json_buffer[i + 1:] - # print("buffer non empty ") - # print(self.open_brackets) - return json_str.decode('utf-8') diff --git a/tests/unit/cli/test_generator.py b/tests/unit/cli/test_generator.py index 6e9de2c1..187ec3c5 100644 --- a/tests/unit/cli/test_generator.py +++ b/tests/unit/cli/test_generator.py @@ -108,7 +108,7 @@ def test_generate_several_pullers_from_config(several_inputs_outputs_stream_conf elif current_puller_infos['type'] == 'socket': assert isinstance(db, SocketDB) - assert db.port == current_puller_infos['port'] + assert db.server_address == (current_puller_infos['host'], current_puller_infos['port']) else: assert False diff --git a/tests/unit/database/__init__.py b/tests/unit/database/__init__.py new file mode 100644 index 00000000..c1a0f439 --- /dev/null +++ b/tests/unit/database/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) 2024, Inria +# Copyright (c) 2024, University of Lille +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/tests/unit/database/conftest.py b/tests/unit/database/conftest.py new file mode 100644 index 00000000..c1a0f439 --- /dev/null +++ b/tests/unit/database/conftest.py @@ -0,0 +1,28 @@ +# Copyright (c) 2024, Inria +# Copyright (c) 2024, University of Lille +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
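Note: below is a minimal standalone sketch (not part of this patch) of the incremental parsing technique used by JsonRequestHandler.parse_json_documents and exercised by the tests that follow. json.JSONDecoder.raw_decode() is called repeatedly; each call returns the parsed object together with the absolute index in the input where that document ends (so the cursor is assigned, not incremented, in this sketch), and decode errors are recovered from by skipping ahead to the next opening brace. The function name and sample input are illustrative only.

from json import JSONDecoder, JSONDecodeError
from typing import Iterator


def parse_json_documents(data: str) -> Iterator[dict]:
    """Yield every JSON document found in data, skipping malformed fragments."""
    decoder = JSONDecoder()
    idx = 0
    while idx < len(data):
        try:
            # raw_decode returns the decoded object and the absolute index
            # in data where that document ended.
            obj, end = decoder.raw_decode(data, idx)
            yield obj
            idx = end
        except JSONDecodeError as e:
            # Malformed or truncated fragment: resume at the next candidate document.
            idx = data.find('{', e.pos)
            if idx == -1:
                break


# Two documents concatenated on a single line yield two dictionaries.
for document in parse_json_documents('{"a": 1}{"b": 2}'):
    print(document)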
diff --git a/tests/unit/database/test_socket.py b/tests/unit/database/test_socket.py new file mode 100644 index 00000000..1413d083 --- /dev/null +++ b/tests/unit/database/test_socket.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Inria +# Copyright (c) 2024, University of Lille +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +from powerapi.database.socket_db import JsonRequestHandler + + +def test_parse_json_empty_document(): + """ + Test parsing an empty JSON document. + """ + document = '' + results = JsonRequestHandler.parse_json_documents(document) + + with pytest.raises(StopIteration): + next(results) + + +def test_parse_single_json_document(): + """ + Test parsing a single JSON document. + """ + document = '''{"a": 1, "b": 2, "c": 3}''' + results = JsonRequestHandler.parse_json_documents(document) + + first_result = next(results) + + assert first_result == {"a": 1, "b": 2, "c": 3} + + +def test_parse_single_invalid_json_document(): + """ + Test parsing a single invalid JSON document. + """ + document = '''{"a": 1, "b": 2, "c":''' + results = JsonRequestHandler.parse_json_documents(document) + + with pytest.raises(StopIteration): + next(results) + + +def test_parse_multiple_json_documents(): + """ + Test parsing multiple JSON documents. + """ + document = '''{"a": 1, "b": 2, "c": 3}{"d": 4, "e": 5, "f": 6}''' + results = JsonRequestHandler.parse_json_documents(document) + + print(results) + + first_result = next(results) + second_result = next(results) + + assert first_result == {"a": 1, "b": 2, "c": 3} + assert second_result == {"d": 4, "e": 5, "f": 6} + + +def test_parse_multiple_documents_first_valid_second_invalid(): + """ + Test parsing multiple JSON documents where the first document is valid and second is invalid. 
+ """ + document = '''{"a": 1, "b": 2, "c": 3}{"d": 4, "e":''' + results = JsonRequestHandler.parse_json_documents(document) + + first_result = next(results) + assert first_result == {"a": 1, "b": 2, "c": 3} + + with pytest.raises(StopIteration): + next(results) diff --git a/tests/unit/utils/test_JsonStream.py b/tests/unit/utils/test_JsonStream.py deleted file mode 100644 index 10594fde..00000000 --- a/tests/unit/utils/test_JsonStream.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright (c) 2021, INRIA -# Copyright (c) 2021, University of Lille -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# * Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -import time - -import pytest -from mock import Mock - -from powerapi.utils import JsonStream - -SOCKET_TIMEOUT = 0.2 - - -class MockedStreamReader(Mock): - def __init__(self, message): - Mock.__init__(self) - self.message = message - - async def read(self, n=-1): - if self.message == '': - time.sleep(SOCKET_TIMEOUT) - return None - else: - byte_to_read = min(len(self.message), n) - data = self.message[:byte_to_read] - self.message = self.message[byte_to_read:] - return bytes(data, 'utf-8') - - -@pytest.mark.asyncio -async def test_read_json_object_from_a_socket_without_data_return_None(): - socket = MockedStreamReader('') - stream = JsonStream(socket) - - result = await stream.read_json_object() - assert result is None - - -@pytest.mark.asyncio -async def test_read_json_object_from_a_socket_with_one_json_object_must_return_one_json_string(): - json_string = '{"a":1}' - socket = MockedStreamReader(json_string) - stream = JsonStream(socket) - - result = await stream.read_json_object() - assert result == json_string - - -@pytest.mark.asyncio -async def test_read_json_object_twice_from_a_socket_with_one_json_object_must_return_only_one_json_string(): - json_string = '{"a":1}' - socket = MockedStreamReader(json_string) - stream = JsonStream(socket) - - result = await stream.read_json_object() - assert result == json_string - - result = await stream.read_json_object() - assert result is None - - -@pytest.mark.asyncio -async def test_read_json_object_from_a_socket_with_an_incomplete_json_object_must_return_None(): - json_string = '{"a":1' - socket = MockedStreamReader(json_string) - stream = JsonStream(socket) - - result = await stream.read_json_object() - assert result is None - - -@pytest.mark.asyncio -async def test_read_json_object_from_a_socket_with_an_complete_json_object_and_incomplete_json_object_must_return_only_one_json_string(): - json_string = '{"a":1}{"a":1' - socket = MockedStreamReader(json_string) - stream = JsonStream(socket) - - result = await stream.read_json_object() - assert result == '{"a":1}' - - result = await stream.read_json_object() - assert result is None - - -@pytest.mark.asyncio -async def test_read_json_object_twice_from_a_socket_with_two_json_object_must_return_two_json_string(): - json1 = '{"a":1}' - json2 = '{"b":2}' - socket = MockedStreamReader(json1 + json2) - stream = JsonStream(socket) - - result = await stream.read_json_object() - assert result == json1 - - result = await stream.read_json_object() - assert result == json2 - - result = await stream.read_json_object() - assert result is None
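For illustration, a minimal client-side sketch (not part of this patch) of how a data source could push reports to the new socket listener. JsonRequestHandler reads the stream line by line, so each report is sent as a UTF-8 encoded JSON document terminated by a newline. The address and payload below are placeholders; a real sensor would send documents matching the configured report type (for instance, the fields expected by HWPCReport.from_json).

import json
import socket

# Hypothetical endpoint: must match the host/port the SocketDB instance listens on.
SERVER_ADDRESS = ('127.0.0.1', 12345)

# Placeholder documents; real reports must match the configured report type.
reports = [
    {'measure': 1},
    {'measure': 2},
]

with socket.create_connection(SERVER_ADDRESS) as sock:
    for report in reports:
        # One newline-terminated JSON document per report, as expected by
        # JsonRequestHandler.handle() on the server side.
        sock.sendall(json.dumps(report).encode('utf-8') + b'\n')

On the receiving side, SocketDB.connect() starts the ThreadedTCPServer in a daemon thread and every parsed document is pushed onto a SimpleQueue; IterSocketDB pops from that queue and raises StopIteration when it is empty, which the puller thread treats as "no report available yet" (sleeping and retrying in stream mode, shutting down otherwise).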