From e9443e7655265d7ac6974fd0b072810aea553135 Mon Sep 17 00:00:00 2001 From: Duco Sebel <74970928+DCSBL@users.noreply.github.com> Date: Sun, 17 Nov 2024 13:59:19 +0100 Subject: [PATCH 1/4] Implement mimimum WebSocket API --- homewizard_energy/v2/__init__.py | 81 +++++++++---- homewizard_energy/v2/const.py | 12 +- homewizard_energy/v2/websocket.py | 185 ++++++++++++++++++++++++++++++ 3 files changed, 252 insertions(+), 26 deletions(-) create mode 100644 homewizard_energy/v2/websocket.py diff --git a/homewizard_energy/v2/__init__.py b/homewizard_energy/v2/__init__.py index 8aa4698..2868ccf 100644 --- a/homewizard_energy/v2/__init__.py +++ b/homewizard_energy/v2/__init__.py @@ -5,20 +5,13 @@ import asyncio import logging import ssl -from collections.abc import Callable, Coroutine +from collections.abc import Coroutine from http import HTTPStatus -from typing import Any, TypeVar +from typing import Any, Callable, TypeVar +import aiohttp import async_timeout import backoff -from aiohttp.client import ( - ClientError, - ClientResponseError, - ClientSession, - ClientTimeout, - TCPConnector, -) -from aiohttp.hdrs import METH_DELETE, METH_GET, METH_POST, METH_PUT from homewizard_energy.errors import ( DisabledError, @@ -30,6 +23,7 @@ from .cacert import CACERT from .models import Device, Measurement, System, SystemUpdate +from .websocket import Websocket _LOGGER = logging.getLogger(__name__) @@ -54,9 +48,10 @@ async def wrapper(self, *args, **kwargs) -> T: class HomeWizardEnergyV2: """Communicate with a HomeWizard Energy device.""" - _clientsession: ClientSession | None = None + _clientsession: aiohttp.ClientSession | None = None _close_clientsession: bool = False _request_timeout: int = 10 + _websocket: Websocket | None = None def __init__( self, @@ -89,6 +84,16 @@ def host(self) -> str: """ return self._host + @property + def websocket(self) -> Websocket: + """Return the websocket object. + + Create a new websocket object if it does not exist. + """ + if self._websocket is None: + self._websocket = Websocket(self) + return self._websocket + @authorized_method async def device(self) -> Device: """Return the device object.""" @@ -115,7 +120,7 @@ async def system( if update is not None: data = update.as_dict() status, response = await self._request( - "/api/system", method=METH_PUT, data=data + "/api/system", method=aiohttp.hdrs.METH_PUT, data=data ) else: @@ -133,7 +138,7 @@ async def identify( self, ) -> bool: """Send identify request.""" - await self._request("/api/system/identify", method=METH_PUT) + await self._request("/api/system/identify", method=aiohttp.hdrs.METH_PUT) return True async def get_token( @@ -142,7 +147,7 @@ async def get_token( ) -> str: """Get authorization token from device.""" status, response = await self._request( - "/api/user", method=METH_POST, data={"name": f"local/{name}"} + "/api/user", method=aiohttp.hdrs.METH_POST, data={"name": f"local/{name}"} ) if status == HTTPStatus.FORBIDDEN: @@ -168,7 +173,7 @@ async def delete_token( """Delete authorization token from device.""" status, response = await self._request( "/api/user", - method=METH_DELETE, + method=aiohttp.hdrs.METH_DELETE, data={"name": name} if name is not None else None, ) @@ -180,11 +185,34 @@ async def delete_token( if name is None: self._token = None - async def _get_clientsession(self) -> ClientSession: + @property + def token(self) -> str | None: + """Return the token of the device. + + Returns: + token: The used token + + """ + return self._token + + @property + def request_timeout(self) -> int: + """Return the request timeout of the device. + + Returns: + request_timeout: The used request timeout + + """ + return self._request_timeout + + async def get_clientsession(self) -> aiohttp.ClientSession: """ Get a clientsession that is tuned for communication with the HomeWizard Energy Device """ + if self._clientsession is not None: + return self._clientsession + def _build_ssl_context() -> ssl.SSLContext: context = ssl.create_default_context(cadata=CACERT) if self._identifier is not None: @@ -199,26 +227,28 @@ def _build_ssl_context() -> ssl.SSLContext: loop = asyncio.get_running_loop() context = await loop.run_in_executor(None, _build_ssl_context) - connector = TCPConnector( + connector = aiohttp.TCPConnector( enable_cleanup_closed=True, ssl=context, limit_per_host=1, ) - return ClientSession( - connector=connector, timeout=ClientTimeout(total=self._request_timeout) + self._clientsession = aiohttp.ClientSession( + connector=connector, + timeout=aiohttp.ClientTimeout(total=self._request_timeout), ) + return self._clientsession + @backoff.on_exception(backoff.expo, RequestError, max_tries=5, logger=None) async def _request( - self, path: str, method: str = METH_GET, data: object = None + self, path: str, method: str = aiohttp.hdrs.METH_GET, data: object = None ) -> Any: """Make a request to the API.""" - if self._clientsession is None: - self._clientsession = await self._get_clientsession() + _clientsession = await self.get_clientsession() - if self._clientsession.closed: + if _clientsession.closed: # Avoid runtime errors when connection is closed. # This solves an issue when updates were scheduled and clientsession was closed. return None @@ -235,7 +265,7 @@ async def _request( try: async with async_timeout.timeout(self._request_timeout): - resp = await self._clientsession.request( + resp = await _clientsession.request( method, url, json=data, @@ -249,7 +279,7 @@ async def _request( raise RequestError( f"Timeout occurred while connecting to the HomeWizard Energy device at {self.host}" ) from exception - except (ClientError, ClientResponseError) as exception: + except (aiohttp.ClientError, aiohttp.ClientResponseError) as exception: raise RequestError( f"Error occurred while communicating with the HomeWizard Energy device at {self.host}" ) from exception @@ -276,6 +306,7 @@ async def close(self) -> None: _LOGGER.debug("Closing clientsession") if self._clientsession is not None: await self._clientsession.close() + self._clientsession = None async def __aenter__(self) -> HomeWizardEnergyV2: """Async enter. diff --git a/homewizard_energy/v2/const.py b/homewizard_energy/v2/const.py index 520289d..d1884f4 100644 --- a/homewizard_energy/v2/const.py +++ b/homewizard_energy/v2/const.py @@ -1,6 +1,16 @@ """Constants for HomeWizard Energy.""" -SUPPORTED_API_VERSION = "v1" +from enum import StrEnum + +SUPPORTED_API_VERSION = "2.0.0" SUPPORTS_STATE = ["HWE-SKT"] SUPPORTS_IDENTIFY = ["HWE-SKT", "HWE-P1", "HWE-WTR"] + + +class WebsocketTopic(StrEnum): + """Websocket topics.""" + + DEVICE = "device" + MEASUREMENT = "measurement" + SYSTEM = "system" diff --git a/homewizard_energy/v2/websocket.py b/homewizard_energy/v2/websocket.py new file mode 100644 index 0000000..fdb816f --- /dev/null +++ b/homewizard_energy/v2/websocket.py @@ -0,0 +1,185 @@ +"""Websocket client for HomeWizard Energy API.""" + +import asyncio +import logging +from typing import TYPE_CHECKING, Callable + +import aiohttp + +from homewizard_energy.errors import UnauthorizedError + +from .const import WebsocketTopic +from .models import Device, Measurement, System + +if TYPE_CHECKING: + from . import HomeWizardEnergyV2 + +_LOGGER = logging.getLogger(__name__) + + +class Websocket: + """Websocket client for HomeWizard Energy API.""" + + _connect_lock: asyncio.Lock = asyncio.Lock() + _ws_connection: aiohttp.ClientWebSocketResponse | None = None + _ws_subscriptions: list[tuple[str, Callable[[aiohttp.WSMessage], None]]] = [] + _ws_authenticated: bool = False + + def __init__(self, parent: "HomeWizardEnergyV2"): + self._parent = parent + + async def connect(self) -> bool: + """Connect the websocket.""" + + if self._connect_lock.locked(): + _LOGGER.debug("Another connect is already happening") + return False + try: + await asyncio.wait_for(self._connect_lock.acquire(), timeout=0.1) + except asyncio.TimeoutError: + _LOGGER.debug("Failed to get connection lock") + + start_event = asyncio.Event() + _LOGGER.debug("Scheduling WS connect...") + asyncio.create_task(self._websocket_loop(start_event)) + + try: + await asyncio.wait_for( + start_event.wait(), timeout=self._parent.request_timeout + ) + except asyncio.TimeoutError: + _LOGGER.warning("Timed out while waiting for Websocket to connect") + await self.disconnect() + + self._connect_lock.release() + if self._ws_connection is None: + _LOGGER.debug("Failed to connect to Websocket") + return False + _LOGGER.debug("Connected to Websocket successfully") + return True + + async def disconnect(self) -> None: + """Disconnect the websocket.""" + if self._ws_connection is not None and not self._ws_connection.closed: + await self._ws_connection.close() + self._ws_connection = None + + def subscribe( + self, topic: WebsocketTopic, ws_callback: Callable[[aiohttp.WSMessage], None] + ) -> Callable[[], None]: + """ + Subscribe to raw websocket messages. + + Returns a callback that will unsubscribe. + """ + + def _unsub_ws_callback() -> None: + self._ws_subscriptions.remove({topic, ws_callback}) + + _LOGGER.debug("Adding subscription: %s, %s", topic, ws_callback) + self._ws_subscriptions.append((topic, ws_callback)) + + if self._ws_connection is not None and self._ws_authenticated: + asyncio.create_task( + self._ws_connection.send_json({"type": "subscribe", "data": topic}) + ) + + return _unsub_ws_callback + + async def _websocket_loop(self, start_event: asyncio.Event) -> None: + _LOGGER.debug("Connecting WS...") + + _clientsession = await self._parent.get_clientsession() + + # catch any and all errors for Websocket so we can clean up correctly + try: + self._ws_connection = await _clientsession.ws_connect( + f"wss://{self._parent.host}/api/ws", ssl=False + ) + start_event.set() + + async for msg in self._ws_connection: + _LOGGER.info("Received message: %s", msg) + if not await self._process_message(msg): + break + except aiohttp.ClientError as e: + _LOGGER.exception("Websocket disconnect error: %s", e) + finally: + _LOGGER.debug("Websocket disconnected") + if self._ws_connection is not None and not self._ws_connection.closed: + await self._ws_connection.close() + self._ws_connection = None + # make sure event does not timeout + start_event.set() + + async def _on_authorization_requested(self, msg_type: str, msg_data: str) -> None: + del msg_type, msg_data + + _LOGGER.info("Authorization requested") + if self._ws_authenticated: + raise UnauthorizedError("Already authenticated") + + await self._ws_connection.send_json( + {"type": "authorization", "data": self._parent.token} + ) + + async def _on_authorized(self, msg_type: str, msg_data: str) -> None: + del msg_type, msg_data + + _LOGGER.info("Authorized") + self._ws_authenticated = True + + # Send subscription requests + print(self._ws_subscriptions) + for topic, _ in self._ws_subscriptions: + _LOGGER.info("Sending subscription request for %s", topic) + await self._ws_connection.send_json({"type": "subscribe", "data": topic}) + + async def _process_message(self, msg: aiohttp.WSMessage) -> bool: + if msg.type == aiohttp.WSMsgType.ERROR: + raise ValueError(f"Error from Websocket: {msg.data}") + + _LOGGER.debug("Received message: %s", msg.data) + + if msg.type == aiohttp.WSMsgType.TEXT: + try: + msg = msg.json() + except ValueError as ex: + raise ValueError(f"Invalid JSON received: {msg.data}") from ex + + if "type" not in msg: + raise ValueError(f"Missing 'type' in message: {msg}") + + msg_type = msg.get("type") + msg_data = msg.get("data") + parsed_data = None + + match msg_type: + case "authorization_requested": + await self._on_authorization_requested(msg_type, msg_data) + return True + + case "authorized": + await self._on_authorized(msg_type, msg_data) + return True + + case WebsocketTopic.MEASUREMENT: + parsed_data = Measurement.from_dict(msg_data) + + case WebsocketTopic.SYSTEM: + parsed_data = System.from_dict(msg_data) + + case WebsocketTopic.DEVICE: + parsed_data = Device.from_dict(msg_data) + + if parsed_data is None: + raise ValueError(f"Unknown message type: {msg_type}") + + for topic, callback in self._ws_subscriptions: + if topic == msg_type: + try: + await callback(topic, parsed_data) + except Exception: # pylint: disable=broad-except + _LOGGER.exception("Error processing websocket message") + + return True From 3c2a743db3644c2b52063fbd5629ac76aeca9d53 Mon Sep 17 00:00:00 2001 From: Duco Sebel <74970928+DCSBL@users.noreply.github.com> Date: Sun, 17 Nov 2024 14:06:45 +0100 Subject: [PATCH 2/4] Adjust callback type --- homewizard_energy/v2/websocket.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/homewizard_energy/v2/websocket.py b/homewizard_energy/v2/websocket.py index fdb816f..19688eb 100644 --- a/homewizard_energy/v2/websocket.py +++ b/homewizard_energy/v2/websocket.py @@ -14,6 +14,8 @@ if TYPE_CHECKING: from . import HomeWizardEnergyV2 +OnMessageCallbackType = Callable[[str, Device | Measurement | System], None] + _LOGGER = logging.getLogger(__name__) @@ -22,7 +24,8 @@ class Websocket: _connect_lock: asyncio.Lock = asyncio.Lock() _ws_connection: aiohttp.ClientWebSocketResponse | None = None - _ws_subscriptions: list[tuple[str, Callable[[aiohttp.WSMessage], None]]] = [] + _ws_subscriptions: list[tuple[str, OnMessageCallbackType]] = [] + _ws_authenticated: bool = False def __init__(self, parent: "HomeWizardEnergyV2"): @@ -65,7 +68,7 @@ async def disconnect(self) -> None: self._ws_connection = None def subscribe( - self, topic: WebsocketTopic, ws_callback: Callable[[aiohttp.WSMessage], None] + self, topic: WebsocketTopic, ws_callback: OnMessageCallbackType ) -> Callable[[], None]: """ Subscribe to raw websocket messages. From 843d889003c7e9fa12ac8a43b0fec3b5c09aa470 Mon Sep 17 00:00:00 2001 From: Duco Sebel <74970928+DCSBL@users.noreply.github.com> Date: Sun, 17 Nov 2024 15:39:22 +0100 Subject: [PATCH 3/4] add strenum dep --- poetry.lock | 18 +++++++++++++++++- pyproject.toml | 1 + 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 28a5c57..af51874 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2119,6 +2119,22 @@ files = [ [package.dependencies] pbr = ">=2.0.0" +[[package]] +name = "strenum" +version = "0.4.15" +description = "An Enum that inherits from str." +optional = false +python-versions = "*" +files = [ + {file = "StrEnum-0.4.15-py3-none-any.whl", hash = "sha256:a30cda4af7cc6b5bf52c8055bc4bf4b2b6b14a93b574626da33df53cf7740659"}, + {file = "StrEnum-0.4.15.tar.gz", hash = "sha256:878fb5ab705442070e4dd1929bb5e2249511c0bcf2b0eeacf3bcd80875c82eff"}, +] + +[package.extras] +docs = ["myst-parser[linkify]", "sphinx", "sphinx-rtd-theme"] +release = ["twine"] +test = ["pylint", "pytest", "pytest-black", "pytest-cov", "pytest-pylint"] + [[package]] name = "syrupy" version = "4.7.2" @@ -2362,4 +2378,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "2ce12f9ead867d0626dc6bbe4a477154215bd2a287a3fcb0dc8de072d176fa1d" +content-hash = "036bcff7c45c1dbfd2f344cc0808fe465912c704d3dbab9b338e73e286bb7f43" diff --git a/pyproject.toml b/pyproject.toml index a8b5efc..c4e1cfe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ aiohttp = ">=3.0.0" async-timeout = "^4.0.3" multidict = "^6.0.5" ## To fix aiohttp dependency at python 3.12 backoff = "^2.2.1" +strenum = "^0.4.15" [tool.poetry.dev-dependencies] aresponses = "^3.0.0" From 2a7c01f228ca65346deafd2e0807a2a093910693 Mon Sep 17 00:00:00 2001 From: Duco Sebel <74970928+DCSBL@users.noreply.github.com> Date: Sun, 17 Nov 2024 15:42:50 +0100 Subject: [PATCH 4/4] And remove strenum again --- poetry.lock | 18 +----------------- pyproject.toml | 1 - 2 files changed, 1 insertion(+), 18 deletions(-) diff --git a/poetry.lock b/poetry.lock index af51874..28a5c57 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2119,22 +2119,6 @@ files = [ [package.dependencies] pbr = ">=2.0.0" -[[package]] -name = "strenum" -version = "0.4.15" -description = "An Enum that inherits from str." -optional = false -python-versions = "*" -files = [ - {file = "StrEnum-0.4.15-py3-none-any.whl", hash = "sha256:a30cda4af7cc6b5bf52c8055bc4bf4b2b6b14a93b574626da33df53cf7740659"}, - {file = "StrEnum-0.4.15.tar.gz", hash = "sha256:878fb5ab705442070e4dd1929bb5e2249511c0bcf2b0eeacf3bcd80875c82eff"}, -] - -[package.extras] -docs = ["myst-parser[linkify]", "sphinx", "sphinx-rtd-theme"] -release = ["twine"] -test = ["pylint", "pytest", "pytest-black", "pytest-cov", "pytest-pylint"] - [[package]] name = "syrupy" version = "4.7.2" @@ -2378,4 +2362,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "036bcff7c45c1dbfd2f344cc0808fe465912c704d3dbab9b338e73e286bb7f43" +content-hash = "2ce12f9ead867d0626dc6bbe4a477154215bd2a287a3fcb0dc8de072d176fa1d" diff --git a/pyproject.toml b/pyproject.toml index c4e1cfe..a8b5efc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,6 @@ aiohttp = ">=3.0.0" async-timeout = "^4.0.3" multidict = "^6.0.5" ## To fix aiohttp dependency at python 3.12 backoff = "^2.2.1" -strenum = "^0.4.15" [tool.poetry.dev-dependencies] aresponses = "^3.0.0"