diff --git a/switchbot/devices/device.py b/switchbot/devices/device.py index 475657c1..ce02deca 100644 --- a/switchbot/devices/device.py +++ b/switchbot/devices/device.py @@ -8,6 +8,7 @@ import time from collections.abc import Callable from dataclasses import replace +from enum import IntEnum from typing import Any, TypeVar, cast from uuid import UUID @@ -142,6 +143,21 @@ class SwitchbotOperationError(Exception): """Raised when an operation fails.""" +class AESMode(IntEnum): + """Supported AES modes for encrypted devices.""" + + CTR = 0 + GCM = 1 + + +def _normalize_encryption_mode(mode: int) -> AESMode: + """Normalize encryption mode to AESMode (only 0/1 allowed).""" + try: + return AESMode(mode) + except (TypeError, ValueError) as exc: + raise ValueError(f"Unsupported encryption mode: {mode}") from exc + + def _sb_uuid(comms_type: str = "service") -> UUID | str: """Return Switchbot UUID.""" _uuid = {"tx": "002", "rx": "003", "service": "d00"} @@ -982,7 +998,8 @@ def __init__( self._key_id = key_id self._encryption_key = bytearray.fromhex(encryption_key) self._iv: bytes | None = None - self._cipher: bytes | None = None + self._cipher: Cipher | None = None + self._encryption_mode: AESMode | None = None super().__init__(device, None, interface, **kwargs) self._model = model @@ -1081,9 +1098,8 @@ async def _send_command( _LOGGER.error("Failed to initialize encryption") return None - encrypted = ( - key[:2] + self._key_id + self._iv[0:2].hex() + self._encrypt(key[2:]) - ) + ciphertext_hex, header_hex = self._encrypt(key[2:]) + encrypted = key[:2] + self._key_id + header_hex + ciphertext_hex command = bytearray.fromhex(self._commandkey(encrypted)) _LOGGER.debug("%s: Scheduling command %s", self.name, command.hex()) max_attempts = retry + 1 @@ -1093,7 +1109,10 @@ async def _send_command( ) if result is None: return None - return result[:1] + self._decrypt(result[4:]) + decrypted = self._decrypt(result[4:]) + if self._encryption_mode == AESMode.GCM: + self._increment_gcm_iv() + return result[:1] + decrypted async def _ensure_encryption_initialized(self) -> bool: """Ensure encryption is initialized, must be called with operation lock held.""" @@ -1117,34 +1136,71 @@ async def _ensure_encryption_initialized(self) -> bool: return False if ok := self._check_command_result(result, 0, {1}): - self._iv = result[4:] + _LOGGER.debug("%s: Encryption init response: %s", self.name, result.hex()) + mode_byte = result[2] if len(result) > 2 else None + self._resolve_encryption_mode(mode_byte) + if self._encryption_mode == AESMode.GCM: + iv = result[4:-4] + expected_iv_len = 12 + else: + iv = result[4:] + expected_iv_len = 16 + if len(iv) != expected_iv_len: + _LOGGER.error( + "%s: Invalid IV length %d for mode %s (expected %d)", + self.name, + len(iv), + self._encryption_mode.name, + expected_iv_len, + ) + return False + self._iv = iv self._cipher = None # Reset cipher when IV changes _LOGGER.debug("%s: Encryption initialized successfully", self.name) return ok async def _execute_disconnect(self) -> None: + """ + Reset encryption state and disconnect. + + Clears IV, cipher, and encryption mode so they can be + re-detected on the next connection (e.g., after firmware update). + """ async with self._connect_lock: self._iv = None self._cipher = None + self._encryption_mode = None await self._execute_disconnect_with_lock() def _get_cipher(self) -> Cipher: if self._cipher is None: if self._iv is None: raise RuntimeError("Cannot create cipher: IV is None") - self._cipher = Cipher( - algorithms.AES128(self._encryption_key), modes.CTR(self._iv) - ) + if self._encryption_mode == AESMode.GCM: + self._cipher = Cipher( + algorithms.AES128(self._encryption_key), modes.GCM(self._iv) + ) + else: + self._cipher = Cipher( + algorithms.AES128(self._encryption_key), modes.CTR(self._iv) + ) return self._cipher - def _encrypt(self, data: str) -> str: + def _encrypt(self, data: str) -> tuple[str, str]: if len(data) == 0: - return "" + return "", "" if self._iv is None: raise RuntimeError("Cannot encrypt: IV is None") encryptor = self._get_cipher().encryptor() - return (encryptor.update(bytearray.fromhex(data)) + encryptor.finalize()).hex() + ciphertext = encryptor.update(bytearray.fromhex(data)) + encryptor.finalize() + if self._encryption_mode == AESMode.GCM: + header_hex = encryptor.tag[:2].hex() + # GCM cipher is single-use; clear it so _get_cipher() creates a fresh one + self._cipher = None + else: + header_hex = self._iv[0:2].hex() + return ciphertext.hex(), header_hex def _decrypt(self, data: bytearray) -> bytes: if len(data) == 0: @@ -1157,9 +1213,40 @@ def _decrypt(self, data: bytearray) -> bytes: ) return b"" raise RuntimeError("Cannot decrypt: IV is None") + if self._encryption_mode == AESMode.GCM: + # Firmware only returns a 2-byte partial tag which can't be used for + # verification. Use a dummy 16-byte tag and skip finalize() since + # authentication is handled by the firmware. + decryptor = Cipher( + algorithms.AES128(self._encryption_key), + modes.GCM(self._iv, b"\x00" * 16), + ).decryptor() + return decryptor.update(data) decryptor = self._get_cipher().decryptor() return decryptor.update(data) + decryptor.finalize() + def _increment_gcm_iv(self) -> None: + """Increment GCM IV by 1 (big-endian). Called after each encrypted command.""" + if self._iv is None: + raise RuntimeError("Cannot increment GCM IV: IV is None") + if len(self._iv) != 12: + raise RuntimeError("Cannot increment GCM IV: IV length is not 12 bytes") + iv_int = int.from_bytes(self._iv, "big") + 1 + self._iv = iv_int.to_bytes(12, "big") + self._cipher = None + + def _resolve_encryption_mode(self, mode_byte: int | None) -> None: + """Resolve encryption mode from device response when available.""" + if mode_byte is None: + raise ValueError("Encryption mode byte is missing") + detected_mode = _normalize_encryption_mode(mode_byte) + if self._encryption_mode is not None and self._encryption_mode != detected_mode: + raise ValueError( + f"Conflicting encryption modes detected: {self._encryption_mode.name} vs {detected_mode.name}" + ) + self._encryption_mode = detected_mode + _LOGGER.debug("%s: Detected encryption mode: %s", self.name, detected_mode.name) + class SwitchbotDeviceOverrideStateDuringConnection(SwitchbotBaseDevice): """ diff --git a/tests/test_encrypted_device.py b/tests/test_encrypted_device.py index 1ed071a7..941cf076 100644 --- a/tests/test_encrypted_device.py +++ b/tests/test_encrypted_device.py @@ -10,9 +10,7 @@ from bleak.exc import BleakDBusError from switchbot import SwitchbotModel -from switchbot.devices.device import ( - SwitchbotEncryptedDevice, -) +from switchbot.devices.device import AESMode, SwitchbotEncryptedDevice from .test_adv_parser import generate_ble_device @@ -133,7 +131,8 @@ async def test_send_command_iv_already_initialized() -> None: patch.object(device, "_decrypt") as mock_decrypt, ): mock_encrypt.return_value = ( - "656e637279707465645f64617461" # "encrypted_data" in hex + "656e637279707465645f64617461", # "encrypted_data" in hex + "abcd", ) mock_decrypt.return_value = b"decrypted_response" mock_send.return_value = b"\x01\x00\x00\x00encrypted_response" @@ -171,7 +170,7 @@ async def simulate_disconnect() -> None: patch.object(device, "_encrypt") as mock_encrypt, patch.object(device, "_decrypt") as mock_decrypt, ): - mock_encrypt.return_value = "656e63727970746564" # "encrypted" in hex + mock_encrypt.return_value = ("656e63727970746564", "abcd") mock_decrypt.return_value = b"response" mock_send.return_value = b"\x01\x00\x00\x00response" @@ -210,6 +209,166 @@ async def test_ensure_encryption_initialized_with_lock_held() -> None: assert device._cipher is None # Should be reset when IV changes +@pytest.mark.asyncio +async def test_ensure_encryption_initialized_sets_gcm_mode() -> None: + """Test that GCM mode is detected from device response.""" + device = create_encrypted_device() + + gcm_iv = b"\x01" * 12 + response = b"\x01\x00\x01\x00" + gcm_iv + b"\x00\x00\x00\x00" + + async with device._operation_lock: + with patch.object(device, "_send_command_locked_with_retry") as mock_send: + mock_send.return_value = response + + result = await device._ensure_encryption_initialized() + + assert result is True + assert device._encryption_mode == AESMode.GCM + assert device._iv == gcm_iv + + +@pytest.mark.asyncio +async def test_ensure_encryption_initialized_invalid_iv_length_gcm() -> None: + """Test that invalid IV length for GCM mode returns False.""" + device = create_encrypted_device() + + # GCM expects 12 bytes IV, but response has wrong length (only 8 bytes after trimming) + response = b"\x01\x00\x01\x00" + b"\x01" * 8 + b"\x00\x00\x00\x00" + + async with device._operation_lock: + with patch.object(device, "_send_command_locked_with_retry") as mock_send: + mock_send.return_value = response + + result = await device._ensure_encryption_initialized() + + assert result is False + assert device._iv is None + + +@pytest.mark.asyncio +async def test_ensure_encryption_initialized_invalid_iv_length_ctr() -> None: + """Test that invalid IV length for CTR mode returns False.""" + device = create_encrypted_device() + + # CTR expects 16 bytes IV, but response has only 8 bytes + response = b"\x01\x00\x00\x00" + b"\x01" * 8 + + async with device._operation_lock: + with patch.object(device, "_send_command_locked_with_retry") as mock_send: + mock_send.return_value = response + + result = await device._ensure_encryption_initialized() + + assert result is False + assert device._iv is None + + +@pytest.mark.asyncio +async def test_device_with_gcm_mode() -> None: + """Test that device initializes correctly in GCM mode and increments GCM IV.""" + device = create_encrypted_device() + device._encryption_mode = AESMode.GCM + device._iv = b"\x01" * 12 + + with ( + patch.object(device, "_ensure_encryption_initialized") as mock_ensure, + patch.object(device, "_send_command_locked_with_retry") as mock_send, + patch.object(device, "_decrypt") as mock_decrypt, + patch.object(device, "_encrypt") as mock_encrypt, + patch.object(device, "_increment_gcm_iv") as mock_inc_iv, + ): + mock_ensure.return_value = True + mock_encrypt.return_value = ("10203040", "abcd") + mock_send.return_value = b"\x01\x00\x00\x00\x10\x20\x30\x40" + mock_decrypt.return_value = b"\x10\x20\x30\x40" + + await device._send_command("570200") + + mock_inc_iv.assert_called_once() + + +@pytest.mark.asyncio +async def test_resolve_encryption_mode_invalid() -> None: + """Test that invalid mode byte raises error.""" + device = create_encrypted_device() + + with pytest.raises(ValueError, match="Unsupported encryption mode"): + device._resolve_encryption_mode(2) + + +@pytest.mark.asyncio +async def test_resolve_encryption_mode_missing() -> None: + """Test that missing mode byte raises error.""" + device = create_encrypted_device() + + with pytest.raises(ValueError, match="Encryption mode byte is missing"): + device._resolve_encryption_mode(None) + + +@pytest.mark.asyncio +async def test_resolve_encryption_mode_conflict() -> None: + """Test that conflicting encryption modes raise error.""" + device = create_encrypted_device() + device._encryption_mode = AESMode.CTR + + with pytest.raises( + ValueError, + match="Conflicting encryption modes detected: CTR vs GCM", + ): + device._resolve_encryption_mode(1) + + +@pytest.mark.asyncio +async def test_increment_gcm_iv() -> None: + """Test GCM IV increment logic.""" + device = create_encrypted_device() + device._encryption_mode = AESMode.GCM + device._iv = b"\x00" * 11 + b"\x01" + + device._increment_gcm_iv() + + assert device._iv == b"\x00" * 11 + b"\x02" + assert device._cipher is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("initial_iv", "expected_exception", "expected_message"), + [ + (None, RuntimeError, "Cannot increment GCM IV: IV is None"), + ( + b"\x00" * 10, + RuntimeError, + "Cannot increment GCM IV: IV length is not 12 bytes", + ), + ], +) +async def test_increment_gcm_iv_invalid( + initial_iv, expected_exception, expected_message +) -> None: + """Test GCM IV increment with invalid IV states.""" + device = create_encrypted_device() + device._encryption_mode = AESMode.GCM + device._iv = initial_iv + + with pytest.raises(expected_exception, match=expected_message): + device._increment_gcm_iv() + + +@pytest.mark.asyncio +async def test_gcm_encrypt_decrypt_without_finalize() -> None: + """Test GCM encrypt/decrypt works without finalize in decrypt.""" + device = create_encrypted_device() + device._encryption_mode = AESMode.GCM + device._iv = b"\x10" * 12 + + ciphertext_hex, _ = device._encrypt("48656c6c6f") + decrypted = device._decrypt(bytearray.fromhex(ciphertext_hex)) + + assert decrypted.hex() == "48656c6c6f" + + @pytest.mark.asyncio async def test_ensure_encryption_initialized_failure() -> None: """Test _ensure_encryption_initialized when IV initialization fails.""" @@ -233,12 +392,13 @@ async def test_encrypt_decrypt_with_valid_iv() -> None: device._iv = b"\x00" * 16 # Use zeros for predictable test # Test encryption - encrypted = device._encrypt("48656c6c6f") # "Hello" in hex - assert isinstance(encrypted, str) - assert len(encrypted) > 0 + ciphertext_hex, header_hex = device._encrypt("48656c6c6f") # "Hello" in hex + assert isinstance(ciphertext_hex, str) + assert isinstance(header_hex, str) + assert len(ciphertext_hex) > 0 # Test decryption - decrypted = device._decrypt(bytearray.fromhex(encrypted)) + decrypted = device._decrypt(bytearray.fromhex(ciphertext_hex)) assert decrypted.hex() == "48656c6c6f" @@ -278,6 +438,7 @@ async def test_execute_disconnect_clears_encryption_state() -> None: device = create_encrypted_device() device._iv = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x12\x34\x56\x78\x9a\xbc\xde\xf0" device._cipher = None # type: ignore[assignment] + device._encryption_mode = AESMode.CTR # Mock client mock_client = AsyncMock() @@ -288,6 +449,7 @@ async def test_execute_disconnect_clears_encryption_state() -> None: assert device._iv is None assert device._cipher is None + assert device._encryption_mode is None mock_disconnect.assert_called_once() @@ -304,7 +466,7 @@ async def test_concurrent_commands_with_same_device() -> None: patch.object(device, "_encrypt") as mock_encrypt, patch.object(device, "_decrypt") as mock_decrypt, ): - mock_encrypt.return_value = "656e63727970746564" # "encrypted" in hex + mock_encrypt.return_value = ("656e63727970746564", "abcd") mock_decrypt.return_value = b"response" mock_send.return_value = b"\x01\x00\x00\x00data" @@ -337,7 +499,7 @@ async def test_command_retry_with_encryption() -> None: patch.object(device, "_encrypt") as mock_encrypt, patch.object(device, "_decrypt") as mock_decrypt, ): - mock_encrypt.return_value = "656e63727970746564" # "encrypted" in hex + mock_encrypt.return_value = ("656e63727970746564", "abcd") mock_decrypt.return_value = b"response" # First attempt fails, second succeeds @@ -360,7 +522,7 @@ async def test_empty_data_encryption_decryption() -> None: # Test empty encryption encrypted = device._encrypt("") - assert encrypted == "" + assert encrypted == ("", "") # Test empty decryption decrypted = device._decrypt(bytearray())