Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 99 additions & 12 deletions switchbot/devices/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import time
from collections.abc import Callable
from dataclasses import replace
from enum import IntEnum
from typing import Any, TypeVar, cast
from uuid import UUID

Expand Down Expand Up @@ -142,6 +143,21 @@ class SwitchbotOperationError(Exception):
"""Raised when an operation fails."""


class AESMode(IntEnum):
"""Supported AES modes for encrypted devices."""

CTR = 0
GCM = 1


def _normalize_encryption_mode(mode: int) -> AESMode:
"""Normalize encryption mode to AESMode (only 0/1 allowed)."""
try:
return AESMode(mode)
except (TypeError, ValueError) as exc:
raise ValueError(f"Unsupported encryption mode: {mode}") from exc


def _sb_uuid(comms_type: str = "service") -> UUID | str:
"""Return Switchbot UUID."""
_uuid = {"tx": "002", "rx": "003", "service": "d00"}
Expand Down Expand Up @@ -982,7 +998,8 @@ def __init__(
self._key_id = key_id
self._encryption_key = bytearray.fromhex(encryption_key)
self._iv: bytes | None = None
self._cipher: bytes | None = None
self._cipher: Cipher | None = None
self._encryption_mode: AESMode | None = None
super().__init__(device, None, interface, **kwargs)
self._model = model

Expand Down Expand Up @@ -1081,9 +1098,8 @@ async def _send_command(
_LOGGER.error("Failed to initialize encryption")
return None

encrypted = (
key[:2] + self._key_id + self._iv[0:2].hex() + self._encrypt(key[2:])
)
ciphertext_hex, header_hex = self._encrypt(key[2:])
encrypted = key[:2] + self._key_id + header_hex + ciphertext_hex
command = bytearray.fromhex(self._commandkey(encrypted))
_LOGGER.debug("%s: Scheduling command %s", self.name, command.hex())
max_attempts = retry + 1
Expand All @@ -1093,7 +1109,10 @@ async def _send_command(
)
if result is None:
return None
return result[:1] + self._decrypt(result[4:])
decrypted = self._decrypt(result[4:])
if self._encryption_mode == AESMode.GCM:
self._increment_gcm_iv()
return result[:1] + decrypted

async def _ensure_encryption_initialized(self) -> bool:
"""Ensure encryption is initialized, must be called with operation lock held."""
Expand All @@ -1117,34 +1136,71 @@ async def _ensure_encryption_initialized(self) -> bool:
return False

if ok := self._check_command_result(result, 0, {1}):
self._iv = result[4:]
_LOGGER.debug("%s: Encryption init response: %s", self.name, result.hex())
mode_byte = result[2] if len(result) > 2 else None
self._resolve_encryption_mode(mode_byte)
if self._encryption_mode == AESMode.GCM:
iv = result[4:-4]
expected_iv_len = 12
else:
iv = result[4:]
expected_iv_len = 16
if len(iv) != expected_iv_len:
_LOGGER.error(
"%s: Invalid IV length %d for mode %s (expected %d)",
self.name,
len(iv),
self._encryption_mode.name,
expected_iv_len,
)
return False
self._iv = iv
self._cipher = None # Reset cipher when IV changes
_LOGGER.debug("%s: Encryption initialized successfully", self.name)

return ok

async def _execute_disconnect(self) -> None:
"""
Reset encryption state and disconnect.

Clears IV, cipher, and encryption mode so they can be
re-detected on the next connection (e.g., after firmware update).
"""
async with self._connect_lock:
self._iv = None
self._cipher = None
self._encryption_mode = None
await self._execute_disconnect_with_lock()

def _get_cipher(self) -> Cipher:
if self._cipher is None:
if self._iv is None:
raise RuntimeError("Cannot create cipher: IV is None")
self._cipher = Cipher(
algorithms.AES128(self._encryption_key), modes.CTR(self._iv)
)
if self._encryption_mode == AESMode.GCM:
self._cipher = Cipher(
algorithms.AES128(self._encryption_key), modes.GCM(self._iv)
)
else:
self._cipher = Cipher(
algorithms.AES128(self._encryption_key), modes.CTR(self._iv)
)
return self._cipher

def _encrypt(self, data: str) -> str:
def _encrypt(self, data: str) -> tuple[str, str]:
if len(data) == 0:
return ""
return "", ""
if self._iv is None:
raise RuntimeError("Cannot encrypt: IV is None")
encryptor = self._get_cipher().encryptor()
return (encryptor.update(bytearray.fromhex(data)) + encryptor.finalize()).hex()
ciphertext = encryptor.update(bytearray.fromhex(data)) + encryptor.finalize()
if self._encryption_mode == AESMode.GCM:
header_hex = encryptor.tag[:2].hex()
# GCM cipher is single-use; clear it so _get_cipher() creates a fresh one
self._cipher = None
else:
header_hex = self._iv[0:2].hex()
return ciphertext.hex(), header_hex

def _decrypt(self, data: bytearray) -> bytes:
if len(data) == 0:
Expand All @@ -1157,9 +1213,40 @@ def _decrypt(self, data: bytearray) -> bytes:
)
return b""
raise RuntimeError("Cannot decrypt: IV is None")
if self._encryption_mode == AESMode.GCM:
# Firmware only returns a 2-byte partial tag which can't be used for
# verification. Use a dummy 16-byte tag and skip finalize() since
# authentication is handled by the firmware.
decryptor = Cipher(
algorithms.AES128(self._encryption_key),
modes.GCM(self._iv, b"\x00" * 16),
).decryptor()
return decryptor.update(data)
decryptor = self._get_cipher().decryptor()
return decryptor.update(data) + decryptor.finalize()

def _increment_gcm_iv(self) -> None:
"""Increment GCM IV by 1 (big-endian). Called after each encrypted command."""
if self._iv is None:
raise RuntimeError("Cannot increment GCM IV: IV is None")
if len(self._iv) != 12:
raise RuntimeError("Cannot increment GCM IV: IV length is not 12 bytes")
iv_int = int.from_bytes(self._iv, "big") + 1
self._iv = iv_int.to_bytes(12, "big")
self._cipher = None

def _resolve_encryption_mode(self, mode_byte: int | None) -> None:
"""Resolve encryption mode from device response when available."""
if mode_byte is None:
raise ValueError("Encryption mode byte is missing")
detected_mode = _normalize_encryption_mode(mode_byte)
if self._encryption_mode is not None and self._encryption_mode != detected_mode:
raise ValueError(
f"Conflicting encryption modes detected: {self._encryption_mode.name} vs {detected_mode.name}"
)
self._encryption_mode = detected_mode
_LOGGER.debug("%s: Detected encryption mode: %s", self.name, detected_mode.name)


class SwitchbotDeviceOverrideStateDuringConnection(SwitchbotBaseDevice):
"""
Expand Down
Loading