Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

typing overhaul + pre-commit updates #200

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ repos:
rev: 23.3.0
hooks:
- id: black
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]
- repo: https://github.com/fsfe/reuse-tool
rev: v1.1.2
hooks:
Expand All @@ -18,7 +23,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/pycqa/pylint
rev: v2.17.4
rev: v3.0.4
hooks:
- id: pylint
name: pylint (library code)
Expand All @@ -40,3 +45,7 @@ repos:
files: "^tests/"
args:
- --disable=missing-docstring,consider-using-f-string,duplicate-code
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.0
hooks:
- id: mypy
64 changes: 41 additions & 23 deletions adafruit_ble/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@

from __future__ import annotations

import sys

# pylint: disable=wrong-import-position

import sys

if sys.implementation.name == "circuitpython" and sys.implementation.version[0] <= 4:
raise ImportError(
Expand All @@ -24,17 +25,31 @@

import _bleio

from .services import Service
from .advertising import Advertisement
from .services import Service
from .uuid import UUID

try:
from typing import Iterator, NoReturn, Optional, Tuple, Type, TYPE_CHECKING, Union
from typing import (
TYPE_CHECKING,
Dict,
Iterator,
List,
NoReturn,
Optional,
Tuple,
Type,
Union,
)

from typing_extensions import Literal

if TYPE_CHECKING:
from circuitpython_typing import ReadableBuffer
from adafruit_ble.uuid import UUID
from adafruit_ble.characteristics import Characteristic

from adafruit_ble.uuid import StandardUUID, VendorUUID

Uuid = Union[StandardUUID, VendorUUID]

except ImportError:
pass
Expand All @@ -55,11 +70,11 @@ class BLEConnection:
def __init__(self, bleio_connection: _bleio.Connection) -> None:
self._bleio_connection = bleio_connection
# _bleio.Service objects representing services found during discovery.
self._discovered_bleio_services = {}
self._discovered_bleio_services: Dict[Uuid, _bleio.Service] = {}
# Service objects that wrap remote services.
self._constructed_services = {}
self._constructed_services: Dict[Uuid, Service] = {}

def _discover_remote(self, uuid: UUID) -> Optional[_bleio.Service]:
def _discover_remote(self, uuid: Uuid) -> Optional[_bleio.Service]:
remote_service = None
if uuid in self._discovered_bleio_services:
remote_service = self._discovered_bleio_services[uuid]
Expand All @@ -72,7 +87,7 @@ def _discover_remote(self, uuid: UUID) -> Optional[_bleio.Service]:
self._discovered_bleio_services[uuid] = remote_service
return remote_service

def __contains__(self, key: Union[UUID, Type[Service]]) -> bool:
def __contains__(self, key: Union[Uuid, Type[Service]]) -> bool:
"""
Allows easy testing for a particular Service class or a particular UUID
associated with this connection.
Expand All @@ -85,16 +100,15 @@ def __contains__(self, key: Union[UUID, Type[Service]]) -> bool:
if StandardUUID(0x1234) in connection:
# do something
"""
uuid = key
if hasattr(key, "uuid"):
uuid = key.uuid
uuid = key if isinstance(key, UUID) else key.uuid
return self._discover_remote(uuid) is not None

def __getitem__(self, key: Union[UUID, Type[Service]]) -> Optional[Service]:
def __getitem__(self, key: Union[Uuid, Type[Service]]) -> Optional[Service]:
"""Return the Service for the given Service class or uuid, if any."""
uuid = key
maybe_service = False
if hasattr(key, "uuid"):
if isinstance(key, UUID):
uuid = key
maybe_service = False
else:
uuid = key.uuid
maybe_service = True

Expand All @@ -104,7 +118,7 @@ def __getitem__(self, key: Union[UUID, Type[Service]]) -> Optional[Service]:
remote_service = self._discover_remote(uuid)
if remote_service:
constructed_service = None
if maybe_service:
if maybe_service and not isinstance(key, UUID):
constructed_service = key(service=remote_service)
self._constructed_services[uuid] = constructed_service
return constructed_service
Expand Down Expand Up @@ -166,7 +180,7 @@ def __init__(self, adapter: Optional[_bleio.Adapter] = None) -> None:
raise RuntimeError("No adapter available")
self._adapter = adapter or _bleio.adapter
self._current_advertisement = None
self._connection_cache = {}
self._connection_cache: Dict[_bleio.Connection, BLEConnection] = {}

def start_advertising(
self,
Expand Down Expand Up @@ -223,7 +237,7 @@ def stop_advertising(self) -> None:
"""Stops advertising."""
self._adapter.stop_advertising()

def start_scan(
def start_scan( # pylint: disable=too-many-arguments
self,
*advertisement_types: Type[Advertisement],
buffer_size: int = 512,
Expand Down Expand Up @@ -311,9 +325,13 @@ def connect(
:return: the connection to the peer
:rtype: BLEConnection
"""
if not isinstance(peer, _bleio.Address):
peer = peer.address
connection = self._adapter.connect(peer, timeout=timeout)
if isinstance(peer, _bleio.Address):
peer_ = peer
else:
assert peer.address is not None
peer_ = peer.address

connection = self._adapter.connect(peer_, timeout=timeout)
self._clean_connection_cache()
self._connection_cache[connection] = BLEConnection(connection)
return self._connection_cache[connection]
Expand All @@ -328,7 +346,7 @@ def connections(self) -> Tuple[Optional[BLEConnection], ...]:
"""A tuple of active `BLEConnection` objects."""
self._clean_connection_cache()
connections = self._adapter.connections
wrapped_connections = [None] * len(connections)
wrapped_connections: List[Optional[BLEConnection]] = [None] * len(connections)
for i, connection in enumerate(connections):
if connection not in self._connection_cache:
self._connection_cache[connection] = BLEConnection(connection)
Expand Down
68 changes: 45 additions & 23 deletions adafruit_ble/advertising/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,29 @@
import struct

try:
from typing import Dict, Any, Union, List, Optional, Type, TypeVar, TYPE_CHECKING
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)

from typing_extensions import Literal

if TYPE_CHECKING:
from _bleio import ScanEntry
from _bleio import Address, ScanEntry

LazyObjectField_GivenClass = TypeVar( # pylint: disable=invalid-name
"LazyObjectField_GivenClass"
)

DataDict = Dict[int, Union[bytes, List[bytes]]]

except ImportError:
pass

Expand All @@ -35,13 +48,11 @@ def to_bytes_literal(seq: bytes) -> str:
return 'b"' + "".join("\\x{:02x}".format(v) for v in seq) + '"'


def decode_data(
data: bytes, *, key_encoding: str = "B"
) -> Dict[Any, Union[bytes, List[bytes]]]:
def decode_data(data: bytes, *, key_encoding: str = "B") -> DataDict:
"""Helper which decodes length encoded structures into a dictionary with the given key
encoding."""
i = 0
data_dict = {}
data_dict: DataDict = {}
key_size = struct.calcsize(key_encoding)
while i < len(data):
item_length = data[i]
Expand All @@ -51,18 +62,18 @@ def decode_data(
key = struct.unpack_from(key_encoding, data, i)[0]
value = data[i + key_size : i + item_length]
if key in data_dict:
if not isinstance(data_dict[key], list):
data_dict[key] = [data_dict[key]]
data_dict[key].append(value)
cur_value = data_dict[key]
if isinstance(cur_value, list):
cur_value.append(value)
else:
data_dict[key] = [cur_value, value]
else:
data_dict[key] = value
i += item_length
return data_dict


def compute_length(
data_dict: Dict[Any, Union[bytes, List[bytes]]], *, key_encoding: str = "B"
) -> int:
def compute_length(data_dict: DataDict, *, key_encoding: str = "B") -> int:
"""Computes the length of the encoded data dictionary."""
value_size = 0
for value in data_dict.values():
Expand All @@ -74,9 +85,7 @@ def compute_length(
return len(data_dict) + len(data_dict) * struct.calcsize(key_encoding) + value_size


def encode_data(
data_dict: Dict[Any, Union[bytes, List[bytes]]], *, key_encoding: str = "B"
) -> bytes:
def encode_data(data_dict: DataDict, *, key_encoding: str = "B") -> bytes:
"""Helper which encodes dictionaries into length encoded structures with the given key
encoding."""
length = compute_length(data_dict, key_encoding=key_encoding)
Expand Down Expand Up @@ -137,7 +146,9 @@ def __init__(
self._adt = advertising_data_type
self.flags = 0
if self._adt in self._advertisement.data_dict:
self.flags = self._advertisement.data_dict[self._adt][0]
value = self._advertisement.data_dict[self._adt]
assert not isinstance(value, list)
self.flags = value[0]

def __len__(self) -> Literal[1]:
return 1
Expand Down Expand Up @@ -170,7 +181,9 @@ def __get__(
return self
if self._adt not in obj.data_dict:
return None
return str(obj.data_dict[self._adt], "utf-8")
value = obj.data_dict[self._adt]
assert not isinstance(value, list)
return str(value, "utf-8")

def __set__(self, obj: "Advertisement", value: str) -> None:
obj.data_dict[self._adt] = value.encode("utf-8")
Expand All @@ -190,7 +203,9 @@ def __get__(
return self
if self._adt not in obj.data_dict:
return None
return struct.unpack(self._format, obj.data_dict[self._adt])[0]
value = obj.data_dict[self._adt]
assert not isinstance(value, list)
return struct.unpack(self._format, value)[0]

def __set__(self, obj: "Advertisement", value: Any) -> None:
obj.data_dict[self._adt] = struct.pack(self._format, value)
Expand Down Expand Up @@ -237,7 +252,10 @@ class Advertisement:
bytestring prefixes to match against the multiple data structures in the advertisement.
"""

match_prefixes = ()
address: Optional[Address]
_rssi: Optional[int]

match_prefixes: Optional[Tuple[bytes, ...]] = ()
"""For Advertisement, :py:attr:`~adafruit_ble.advertising.Advertisement.match_prefixes`
will always return ``True``. Subclasses may override this value."""
# cached bytes of merged prefixes.
Expand Down Expand Up @@ -293,22 +311,26 @@ def rssi(self) -> Optional[int]:
return self._rssi

@classmethod
def get_prefix_bytes(cls) -> Optional[bytes]:
def get_prefix_bytes(cls) -> bytes:
"""Return a merged version of match_prefixes as a single bytes object,
with length headers.
"""
# Check for deprecated `prefix` class attribute.
cls._prefix_bytes = getattr(cls, "prefix", None)
prefix_bytes: Optional[bytes] = getattr(cls, "prefix", None)

# Do merge once and memoize it.
if cls._prefix_bytes is None:
cls._prefix_bytes = (
cls._prefix_bytes = (
(
b""
if cls.match_prefixes is None
else b"".join(
len(prefix).to_bytes(1, "little") + prefix
for prefix in cls.match_prefixes
)
)
if prefix_bytes is None
else prefix_bytes
)

return cls._prefix_bytes

Expand Down
1 change: 1 addition & 0 deletions adafruit_ble/advertising/adafruit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""

import struct

from micropython import const

from . import Advertisement, LazyObjectField
Expand Down
Loading