Skip to content

Commit

Permalink
Do login entirely within AesTransport (python-kasa#580)
Browse files Browse the repository at this point in the history
* Do login entirely within AesTransport

* Remove login and handshake attributes from BaseTransport

* Add AesTransport tests

* Synchronise transport and protocol __init__ signatures and rename internal variables

* Update after review
  • Loading branch information
sdb9696 authored Dec 19, 2023
1 parent 209391c commit 20ea670
Show file tree
Hide file tree
Showing 13 changed files with 468 additions and 237 deletions.
81 changes: 33 additions & 48 deletions kasa/aestransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import hashlib
import logging
import time
from typing import Optional, Union
from typing import Optional

import httpx
from cryptography.hazmat.primitives import padding, serialization
Expand Down Expand Up @@ -47,6 +47,7 @@ class AesTransport(BaseTransport):
protocol, sometimes used by newer firmware versions on kasa devices.
"""

DEFAULT_PORT = 80
DEFAULT_TIMEOUT = 5
SESSION_COOKIE_NAME = "TP_SESSIONID"
COMMON_HEADERS = {
Expand All @@ -59,12 +60,16 @@ def __init__(
self,
host: str,
*,
port: Optional[int] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
super().__init__(host=host)

self._credentials = credentials or Credentials(username="", password="")
super().__init__(
host,
port=port or self.DEFAULT_PORT,
credentials=credentials,
timeout=timeout,
)

self._handshake_done = False

Expand All @@ -77,7 +82,7 @@ def __init__(
self._http_client: httpx.AsyncClient = httpx.AsyncClient()
self._login_token = None

_LOGGER.debug("Created AES object for %s", self.host)
_LOGGER.debug("Created AES transport for %s", self._host)

def hash_credentials(self, login_v2):
"""Hash the credentials."""
Expand Down Expand Up @@ -123,7 +128,7 @@ def _handle_response_error_code(self, resp_dict: dict, msg: str):
if (
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
) != SmartErrorCode.SUCCESS:
msg = f"{msg}: {self.host}: {error_code.name}({error_code.value})"
msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})"
if error_code in SMART_TIMEOUT_ERRORS:
raise TimeoutException(msg)
if error_code in SMART_RETRYABLE_ERRORS:
Expand All @@ -136,7 +141,7 @@ def _handle_response_error_code(self, resp_dict: dict, msg: str):

async def send_secure_passthrough(self, request: str):
"""Send encrypted message as passthrough."""
url = f"http://{self.host}/app"
url = f"http://{self._host}/app"
if self._login_token:
url += f"?token={self._login_token}"

Expand All @@ -150,7 +155,7 @@ async def send_secure_passthrough(self, request: str):

if status_code != 200:
raise SmartDeviceException(
f"{self.host} responded with an unexpected "
f"{self._host} responded with an unexpected "
+ f"status code {status_code} to passthrough"
)

Expand All @@ -164,49 +169,31 @@ async def send_secure_passthrough(self, request: str):
resp_dict = json_loads(response)
return resp_dict

async def perform_login(self, login_request: Union[str, dict], *, login_v2: bool):
async def _perform_login_for_version(self, *, login_version: int = 1):
"""Login to the device."""
self._login_token = None

if isinstance(login_request, str):
login_request_dict: dict = json_loads(login_request)
else:
login_request_dict = login_request

un, pw = self.hash_credentials(login_v2)
login_request_dict["params"] = {"password": pw, "username": un}
request = json_dumps(login_request_dict)
un, pw = self.hash_credentials(login_version == 2)
password_field_name = "password2" if login_version == 2 else "password"
login_request = {
"method": "login_device",
"params": {password_field_name: pw, "username": un},
"request_time_milis": round(time.time() * 1000),
}
request = json_dumps(login_request)
try:
resp_dict = await self.send_secure_passthrough(request)
except SmartDeviceException as ex:
raise AuthenticationException(ex) from ex
self._login_token = resp_dict["result"]["token"]

@property
def needs_login(self) -> bool:
"""Return true if the transport needs to do a login."""
return self._login_token is None

async def login(self, request: str) -> None:
async def perform_login(self) -> None:
"""Login to the device."""
try:
if self.needs_handshake:
raise SmartDeviceException(
"Handshake must be complete before trying to login"
)
await self.perform_login(request, login_v2=False)
await self._perform_login_for_version(login_version=2)
except AuthenticationException:
_LOGGER.warning("Login version 2 failed, trying version 1")
await self.perform_handshake()
await self.perform_login(request, login_v2=True)

@property
def needs_handshake(self) -> bool:
"""Return true if the transport needs to do a handshake."""
return not self._handshake_done or self._handshake_session_expired()

async def handshake(self) -> None:
"""Perform the encryption handshake."""
await self.perform_handshake()
await self._perform_login_for_version(login_version=1)

async def perform_handshake(self):
"""Perform the handshake."""
Expand All @@ -217,7 +204,7 @@ async def perform_handshake(self):
self._session_expire_at = None
self._session_cookie = None

url = f"http://{self.host}/app"
url = f"http://{self._host}/app"
key_pair = KeyPair.create_key_pair()

pub_key = (
Expand All @@ -238,7 +225,7 @@ async def perform_handshake(self):

if status_code != 200:
raise SmartDeviceException(
f"{self.host} responded with an unexpected "
f"{self._host} responded with an unexpected "
+ f"status code {status_code} to handshake"
)

Expand All @@ -261,7 +248,7 @@ async def perform_handshake(self):

self._handshake_done = True

_LOGGER.debug("Handshake with %s complete", self.host)
_LOGGER.debug("Handshake with %s complete", self._host)

def _handshake_session_expired(self):
"""Return true if session has expired."""
Expand All @@ -272,12 +259,10 @@ def _handshake_session_expired(self):

async def send(self, request: str):
"""Send the request."""
if self.needs_handshake:
raise SmartDeviceException(
"Handshake must be complete before trying to send"
)
if self.needs_login:
raise SmartDeviceException("Login must be complete before trying to send")
if not self._handshake_done or self._handshake_session_expired():
await self.perform_handshake()
if not self._login_token:
await self.perform_login()

return await self.send_secure_passthrough(request)

Expand Down
19 changes: 14 additions & 5 deletions kasa/device_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,12 @@ async def connect(
host=host, port=port, credentials=credentials, timeout=timeout
)
if protocol_class is not None:
dev.protocol = protocol_class(host, credentials=credentials)
dev.protocol = protocol_class(
host,
transport=AesTransport(
host, port=port, credentials=credentials, timeout=timeout
),
)
await dev.update()
if debug_enabled:
end_time = time.perf_counter()
Expand All @@ -90,7 +95,13 @@ async def connect(
host=host, port=port, credentials=credentials, timeout=timeout
)
if protocol_class is not None:
unknown_dev.protocol = protocol_class(host, credentials=credentials)
# TODO this will be replaced with connection params
unknown_dev.protocol = protocol_class(
host,
transport=AesTransport(
host, port=port, credentials=credentials, timeout=timeout
),
)
await unknown_dev.update()
device_class = get_device_class_from_sys_info(unknown_dev.internal_state)
dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout)
Expand Down Expand Up @@ -163,7 +174,5 @@ def get_protocol_from_connection_name(

protocol_class, transport_class = supported_device_protocols.get(connection_name) # type: ignore
transport: BaseTransport = transport_class(host, credentials=credentials)
protocol: TPLinkProtocol = protocol_class(
host, credentials=credentials, transport=transport
)
protocol: TPLinkProtocol = protocol_class(host, transport=transport)
return protocol
44 changes: 13 additions & 31 deletions kasa/iotprotocol.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
"""Module for the IOT legacy IOT KASA protocol."""
import asyncio
import logging
from typing import Dict, Optional, Union
from typing import Dict, Union

import httpx

from .credentials import Credentials
from .exceptions import AuthenticationException, SmartDeviceException
from .json import dumps as json_dumps
from .klaptransport import KlapTransport
from .protocol import BaseTransport, TPLinkProtocol

_LOGGER = logging.getLogger(__name__)
Expand All @@ -17,24 +15,14 @@
class IotProtocol(TPLinkProtocol):
"""Class for the legacy TPLink IOT KASA Protocol."""

DEFAULT_PORT = 80

def __init__(
self,
host: str,
*,
transport: Optional[BaseTransport] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
transport: BaseTransport,
) -> None:
super().__init__(host=host, port=self.DEFAULT_PORT)

self._credentials: Credentials = credentials or Credentials(
username="", password=""
)
self._transport: BaseTransport = transport or KlapTransport(
host, credentials=self._credentials, timeout=timeout
)
"""Create a protocol object."""
super().__init__(host, transport=transport)

self._query_lock = asyncio.Lock()

Expand All @@ -54,45 +42,39 @@ async def _query(self, request: str, retry_count: int = 3) -> Dict:
except httpx.CloseError as sdex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {sdex}"
f"Unable to connect to the device: {self._host}: {sdex}"
) from sdex
continue
except httpx.ConnectError as cex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {cex}"
f"Unable to connect to the device: {self._host}: {cex}"
) from cex
except TimeoutError as tex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device, timed out: {self.host}: {tex}"
f"Unable to connect to the device, timed out: {self._host}: {tex}"
) from tex
except AuthenticationException as auex:
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
_LOGGER.debug(
"Unable to authenticate with %s, not retrying", self._host
)
raise auex
except Exception as ex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {ex}"
f"Unable to connect to the device: {self._host}: {ex}"
) from ex
continue

# make mypy happy, this should never be reached..
raise SmartDeviceException("Query reached somehow to unreachable")

async def _execute_query(self, request: str, retry_count: int) -> Dict:
if self._transport.needs_handshake:
await self._transport.handshake()

if self._transport.needs_login: # This shouln't happen
raise SmartDeviceException(
"IOT Protocol needs to login to transport but is not login aware"
)

return await self._transport.send(request)

async def close(self) -> None:
Expand Down
Loading

0 comments on commit 20ea670

Please sign in to comment.