Skip to content

Commit

Permalink
Add stricter checking when decoding with .from_payload()
Browse files Browse the repository at this point in the history
Earlier, it was possible that a protobuf payload not including a certain
message was decoded using .from_payload(). For example when using
GetConfigsRequest.from_payload() with a message missing
GenericMessage::wirepas::get_configs_req, the library was not reporting
an error and setting the req_id to 0 even though it was not initialized
in the protobuf message.

If that happens now, and InvalidMessageType exception is raised. This
exception has the GatewayAPIParsingException as its parent to help with
backwards compatibility.

Also refactoring the tests and having a new test file for decoding
related error testing.
  • Loading branch information
sgunes-wirepas committed Nov 22, 2024
1 parent c66bfe9 commit afa7ad0
Show file tree
Hide file tree
Showing 14 changed files with 77 additions and 96 deletions.
2 changes: 0 additions & 2 deletions tests/default_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@

IMPLEMENTED_API_VERSION = 0

INVALID_PROTOBUF_MESSAGE = bytes([0])

# Todo add more fields in config
NODE_CONFIG_1 = dict([("sink_id", SINK_ID), ("node_address", 123)])

Expand Down
64 changes: 64 additions & 0 deletions tests/test_decoding_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# flake8: noqa

import pytest

import wirepas_mesh_messaging
from wirepas_mesh_messaging.proto import GenericMessage
from wirepas_mesh_messaging.wirepas_exceptions import (
GatewayAPIParsingException,
InvalidMessageType,
)

MESSAGE_CLASSES = [
wirepas_mesh_messaging.GetConfigsRequest,
wirepas_mesh_messaging.GetConfigsResponse,
wirepas_mesh_messaging.GetGatewayInfoRequest,
wirepas_mesh_messaging.GetGatewayInfoResponse,
wirepas_mesh_messaging.GetScratchpadStatusRequest,
wirepas_mesh_messaging.GetScratchpadStatusResponse,
wirepas_mesh_messaging.ProcessScratchpadRequest,
wirepas_mesh_messaging.ProcessScratchpadResponse,
wirepas_mesh_messaging.ReceivedDataEvent,
wirepas_mesh_messaging.SendDataRequest,
wirepas_mesh_messaging.SendDataResponse,
wirepas_mesh_messaging.SetConfigRequest,
wirepas_mesh_messaging.SetConfigResponse,
wirepas_mesh_messaging.SetScratchpadTargetAndActionRequest,
wirepas_mesh_messaging.SetScratchpadTargetAndActionResponse,
wirepas_mesh_messaging.StatusEvent,
wirepas_mesh_messaging.UploadScratchpadRequest,
wirepas_mesh_messaging.UploadScratchpadResponse,
]


def _get_payload_excluding_message_type(message_class):
if message_class == wirepas_mesh_messaging.GetConfigsRequest:
return wirepas_mesh_messaging.GetGatewayInfoRequest().payload

return wirepas_mesh_messaging.GetConfigsRequest().payload


@pytest.mark.parametrize("message_class", MESSAGE_CLASSES)
def test_decoding_errors(message_class):
invalid_protobuf_message = bytes([0])

with pytest.raises(GatewayAPIParsingException, match="Cannot decode"):
message_class.from_payload(invalid_protobuf_message)


@pytest.mark.parametrize("message_class", MESSAGE_CLASSES)
def test_decoding_wrong_message_type(message_class):
payload = _get_payload_excluding_message_type(message_class)

with pytest.raises(InvalidMessageType, match=message_class.__name__):
message_class.from_payload(payload)


@pytest.mark.parametrize("message_class", MESSAGE_CLASSES)
def test_decoding_missing_message_type(message_class):
message = GenericMessage()
message.wirepas.SetInParent()
payload = message.SerializeToString()

with pytest.raises(InvalidMessageType, match=message_class.__name__):
message_class.from_payload(payload)
10 changes: 0 additions & 10 deletions tests/test_get_configs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# flake8: noqa

import pytest
import wirepas_mesh_messaging
from default_value import *
from wirepas_mesh_messaging.wirepas_exceptions import GatewayAPIParsingException

DUMMY_CONFIGS = [NODE_CONFIG_2]

Expand All @@ -30,11 +28,3 @@ def test_generate_parse_response():

for k, v in request.__dict__.items():
assert v == request2.__dict__[k]

def test_request_decoding_error():
with pytest.raises(GatewayAPIParsingException):
wirepas_mesh_messaging.GetConfigsRequest.from_payload(INVALID_PROTOBUF_MESSAGE)

def test_response_decoding_error():
with pytest.raises(GatewayAPIParsingException):
wirepas_mesh_messaging.GetConfigsResponse.from_payload(INVALID_PROTOBUF_MESSAGE)
11 changes: 0 additions & 11 deletions tests/test_get_gw_info.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
# flake8: noqa

import pytest
import wirepas_mesh_messaging
from default_value import *
import time

from wirepas_mesh_messaging.wirepas_exceptions import GatewayAPIParsingException


def test_generate_parse_request():
request = wirepas_mesh_messaging.GetGatewayInfoRequest(REQUEST_ID)
Expand Down Expand Up @@ -57,11 +54,3 @@ def test_generate_parse_response_not_all_optional():
for k, v in request.__dict__.items():
print(k)
assert v == request2.__dict__[k]

def test_request_decoding_error():
with pytest.raises(GatewayAPIParsingException):
wirepas_mesh_messaging.GetGatewayInfoRequest.from_payload(INVALID_PROTOBUF_MESSAGE)

def test_response_decoding_error():
with pytest.raises(GatewayAPIParsingException):
wirepas_mesh_messaging.GetGatewayInfoResponse.from_payload(INVALID_PROTOBUF_MESSAGE)
10 changes: 0 additions & 10 deletions tests/test_get_scratchpad_status.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# flake8: noqa

import pytest
import wirepas_mesh_messaging
import enum
from default_value import *
from wirepas_mesh_messaging.wirepas_exceptions import GatewayAPIParsingException


def test_generate_parse_request():
Expand Down Expand Up @@ -44,11 +42,3 @@ def test_generate_parse_response():
assert v.value == request2.__dict__[k].value
else:
assert v == request2.__dict__[k]

def test_request_decoding_error():
with pytest.raises(GatewayAPIParsingException):
wirepas_mesh_messaging.GetScratchpadStatusRequest.from_payload(INVALID_PROTOBUF_MESSAGE)

def test_response_decoding_error():
with pytest.raises(GatewayAPIParsingException):
wirepas_mesh_messaging.GetScratchpadStatusResponse.from_payload(INVALID_PROTOBUF_MESSAGE)
10 changes: 0 additions & 10 deletions tests/test_process_scratchpad.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# flake8: noqa

import pytest
import wirepas_mesh_messaging
from default_value import *
from wirepas_mesh_messaging.wirepas_exceptions import GatewayAPIParsingException


def test_generate_parse_request():
Expand Down Expand Up @@ -31,11 +29,3 @@ def test_generate_parse_response():

for k, v in request.__dict__.items():
assert v == request2.__dict__[k]

def test_request_decoding_error():
with pytest.raises(GatewayAPIParsingException):
wirepas_mesh_messaging.ProcessScratchpadRequest.from_payload(INVALID_PROTOBUF_MESSAGE)

def test_response_decoding_error():
with pytest.raises(GatewayAPIParsingException):
wirepas_mesh_messaging.ProcessScratchpadResponse.from_payload(INVALID_PROTOBUF_MESSAGE)
6 changes: 0 additions & 6 deletions tests/test_received_data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# flake8: noqa

import pytest
import wirepas_mesh_messaging
from default_value import *
from wirepas_mesh_messaging.wirepas_exceptions import GatewayAPIParsingException


def test_generate_parse_event():
Expand Down Expand Up @@ -75,7 +73,3 @@ def test_generate_parse_event_with_network_address():

for k, v in status.__dict__.items():
assert v == status2.__dict__[k]

def test_event_decoding_error():
with pytest.raises(GatewayAPIParsingException):
wirepas_mesh_messaging.ReceivedDataEvent.from_payload(INVALID_PROTOBUF_MESSAGE)
10 changes: 0 additions & 10 deletions tests/test_send_data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# flake8: noqa

import pytest
import wirepas_mesh_messaging
from default_value import *
from wirepas_mesh_messaging.wirepas_exceptions import GatewayAPIParsingException


def test_generate_parse_request():
Expand Down Expand Up @@ -39,11 +37,3 @@ def test_generate_parse_response():

for k, v in response.__dict__.items():
assert v == response2.__dict__[k]

def test_request_decoding_error():
with pytest.raises(GatewayAPIParsingException):
wirepas_mesh_messaging.SendDataRequest.from_payload(INVALID_PROTOBUF_MESSAGE)

def test_response_decoding_error():
with pytest.raises(GatewayAPIParsingException):
wirepas_mesh_messaging.SendDataResponse.from_payload(INVALID_PROTOBUF_MESSAGE)
10 changes: 0 additions & 10 deletions tests/test_set_config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# flake8: noqa

import pytest
import wirepas_mesh_messaging
from default_value import *
from wirepas_mesh_messaging.wirepas_exceptions import GatewayAPIParsingException


def test_generate_parse_request():
Expand All @@ -30,11 +28,3 @@ def test_generate_parse_response():

for k, v in request.__dict__.items():
assert v == request2.__dict__[k]

def test_request_decoding_error():
with pytest.raises(GatewayAPIParsingException):
wirepas_mesh_messaging.SetConfigRequest.from_payload(INVALID_PROTOBUF_MESSAGE)

def test_response_decoding_error():
with pytest.raises(GatewayAPIParsingException):
wirepas_mesh_messaging.SetConfigResponse.from_payload(INVALID_PROTOBUF_MESSAGE)
10 changes: 0 additions & 10 deletions tests/test_set_scratchpad_target.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# flake8: noqa

import pytest
import wirepas_mesh_messaging
import enum
from default_value import *
from wirepas_mesh_messaging.wirepas_exceptions import GatewayAPIParsingException


def test_generate_parse_request_with_raw():
Expand Down Expand Up @@ -64,11 +62,3 @@ def test_generate_parse_response():
assert v.value == request2.__dict__[k].value
else:
assert v == request2.__dict__[k]

def test_request_decoding_error():
with pytest.raises(GatewayAPIParsingException):
wirepas_mesh_messaging.SetScratchpadTargetAndActionRequest.from_payload(INVALID_PROTOBUF_MESSAGE)

def test_response_decoding_error():
with pytest.raises(GatewayAPIParsingException):
wirepas_mesh_messaging.SetScratchpadTargetAndActionResponse.from_payload(INVALID_PROTOBUF_MESSAGE)
6 changes: 0 additions & 6 deletions tests/test_status.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# flake8: noqa

import pytest
import wirepas_mesh_messaging
from default_value import *
from wirepas_mesh_messaging.wirepas_exceptions import GatewayAPIParsingException

DUMMY_CONFIGS = [NODE_CONFIG_1, NODE_CONFIG_2]

Expand All @@ -30,7 +28,3 @@ def test_generate_parse_event_with_max_size():

for k, v in status.__dict__.items():
assert v == status2.__dict__[k]

def test_event_decoding_error():
with pytest.raises(GatewayAPIParsingException):
wirepas_mesh_messaging.StatusEvent.from_payload(INVALID_PROTOBUF_MESSAGE)
10 changes: 0 additions & 10 deletions tests/test_upload_scratchpad.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# flake8: noqa

import pytest
import wirepas_mesh_messaging
from default_value import *
from wirepas_mesh_messaging.wirepas_exceptions import GatewayAPIParsingException


def test_generate_parse_request_clear():
Expand Down Expand Up @@ -54,11 +52,3 @@ def test_generate_parse_response():

for k, v in request.__dict__.items():
assert v == request2.__dict__[k]

def test_request_decoding_error():
with pytest.raises(GatewayAPIParsingException):
wirepas_mesh_messaging.UploadScratchpadRequest.from_payload(INVALID_PROTOBUF_MESSAGE)

def test_response_decoding_error():
with pytest.raises(GatewayAPIParsingException):
wirepas_mesh_messaging.UploadScratchpadResponse.from_payload(INVALID_PROTOBUF_MESSAGE)
5 changes: 5 additions & 0 deletions wirepas_mesh_messaging/wirepas_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,8 @@ class GatewayAPIParsingException(Exception):
"""
Wirepas Gateway API generic Exception
"""

class InvalidMessageType(GatewayAPIParsingException):
"""
Exception indicating wrong message type during deserialization
"""
9 changes: 8 additions & 1 deletion wirepas_mesh_messaging/wirepas_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from abc import ABC, abstractmethod

from .proto import GenericMessage
from .wirepas_exceptions import GatewayAPIParsingException
from .wirepas_exceptions import GatewayAPIParsingException, InvalidMessageType


class WirepasMessage(ABC):
Expand Down Expand Up @@ -67,4 +67,11 @@ def _decode_and_get_related_message(cls, payload):

contained_message = cls._get_related_message(generic_message)

# Works by checking if all required fields of contained_message are
# set. In our case, every message holds a required header field.
if not contained_message.IsInitialized():
raise InvalidMessageType(
f"Could not find relevant Wirepas message for {cls.__name__}"
)

return contained_message

0 comments on commit afa7ad0

Please sign in to comment.