|
| 1 | +""" |
| 2 | +Unit tests for daily summary race condition fix (#4594). |
| 3 | +
|
| 4 | +Verifies that: |
| 5 | +1. try_acquire_daily_summary_lock uses atomic SETNX |
| 6 | +2. Only the first caller acquires the lock; concurrent callers are rejected |
| 7 | +3. _send_summary_notification skips work when lock is already held |
| 8 | +""" |
| 9 | + |
| 10 | +import os |
| 11 | +import sys |
| 12 | +import types |
| 13 | +import threading |
| 14 | +from unittest.mock import MagicMock, patch |
| 15 | + |
| 16 | +os.environ.setdefault( |
| 17 | + "ENCRYPTION_SECRET", |
| 18 | + "omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv", |
| 19 | +) |
| 20 | + |
| 21 | + |
| 22 | +def _stub_module(name: str) -> types.ModuleType: |
| 23 | + mod = types.ModuleType(name) |
| 24 | + sys.modules[name] = mod |
| 25 | + return mod |
| 26 | + |
| 27 | + |
| 28 | +# Stub database package and submodules to avoid Firestore init. |
| 29 | +if "database" not in sys.modules: |
| 30 | + database_mod = _stub_module("database") |
| 31 | + database_mod.__path__ = [] |
| 32 | +else: |
| 33 | + database_mod = sys.modules["database"] |
| 34 | + |
| 35 | +for submodule in [ |
| 36 | + "redis_db", |
| 37 | + "chat", |
| 38 | + "conversations", |
| 39 | + "notifications", |
| 40 | + "users", |
| 41 | + "daily_summaries", |
| 42 | + "_client", |
| 43 | + "auth", |
| 44 | +]: |
| 45 | + full_name = f"database.{submodule}" |
| 46 | + if full_name not in sys.modules: |
| 47 | + mod = _stub_module(full_name) |
| 48 | + setattr(database_mod, submodule, mod) |
| 49 | + |
| 50 | +# Set up mock redis and real lock function |
| 51 | +redis_db_mod = sys.modules["database.redis_db"] |
| 52 | +mock_r = MagicMock() |
| 53 | +redis_db_mod.r = mock_r |
| 54 | + |
| 55 | + |
| 56 | +def try_acquire_daily_summary_lock(uid: str, date: str, ttl: int = 60 * 60 * 2) -> bool: |
| 57 | + result = mock_r.set(f'users:{uid}:daily_summary_lock:{date}', '1', ex=ttl, nx=True) |
| 58 | + return result is not None |
| 59 | + |
| 60 | + |
| 61 | +redis_db_mod.try_acquire_daily_summary_lock = try_acquire_daily_summary_lock |
| 62 | + |
| 63 | +# Set up mock auth |
| 64 | +auth_mod = sys.modules["database.auth"] |
| 65 | +auth_mod.get_user_name = MagicMock(return_value="Test User") |
| 66 | + |
| 67 | +# Set up mock client |
| 68 | +client_mod = sys.modules["database._client"] |
| 69 | +client_mod.db = MagicMock() |
| 70 | +client_mod.document_id_from_seed = MagicMock(return_value="doc-id") |
| 71 | + |
| 72 | +# Stub utils modules that pull in heavy dependencies. |
| 73 | +for name in [ |
| 74 | + "utils.llm.external_integrations", |
| 75 | + "utils.notifications", |
| 76 | + "utils.webhooks", |
| 77 | +]: |
| 78 | + if name not in sys.modules: |
| 79 | + _stub_module(name) |
| 80 | + |
| 81 | +# Add needed attrs to stubs |
| 82 | +utils_llm_ext = sys.modules["utils.llm.external_integrations"] |
| 83 | +utils_llm_ext.get_conversation_summary = MagicMock() |
| 84 | +utils_llm_ext.generate_comprehensive_daily_summary = MagicMock() |
| 85 | + |
| 86 | +utils_notifications = sys.modules["utils.notifications"] |
| 87 | +utils_notifications.send_bulk_notification = MagicMock() |
| 88 | +utils_notifications.send_notification = MagicMock() |
| 89 | + |
| 90 | +utils_webhooks = sys.modules["utils.webhooks"] |
| 91 | +utils_webhooks.day_summary_webhook = MagicMock() |
| 92 | + |
| 93 | +# Stub models |
| 94 | +for name in ["models.notification_message", "models.conversation"]: |
| 95 | + if name not in sys.modules: |
| 96 | + _stub_module(name) |
| 97 | + |
| 98 | +models_notif = sys.modules["models.notification_message"] |
| 99 | +mock_notification_message = MagicMock() |
| 100 | +mock_notification_message.get_message_as_dict = MagicMock(return_value={}) |
| 101 | +models_notif.NotificationMessage = mock_notification_message |
| 102 | + |
| 103 | +models_convo = sys.modules["models.conversation"] |
| 104 | +models_convo.Conversation = MagicMock() |
| 105 | + |
| 106 | +# Now we can safely import |
| 107 | +from utils.other.notifications import _send_summary_notification |
| 108 | + |
| 109 | + |
| 110 | +class TestTryAcquireDailySummaryLock: |
| 111 | + """Tests for the atomic SETNX lock function.""" |
| 112 | + |
| 113 | + def test_lock_acquired_returns_true(self): |
| 114 | + mock_r.set.return_value = True |
| 115 | + assert try_acquire_daily_summary_lock('uid1', '2026-02-07') is True |
| 116 | + mock_r.set.assert_called_with('users:uid1:daily_summary_lock:2026-02-07', '1', ex=7200, nx=True) |
| 117 | + |
| 118 | + def test_lock_already_held_returns_false(self): |
| 119 | + mock_r.set.return_value = None # SETNX returns None when key exists |
| 120 | + assert try_acquire_daily_summary_lock('uid1', '2026-02-07') is False |
| 121 | + |
| 122 | + def test_custom_ttl(self): |
| 123 | + mock_r.set.return_value = True |
| 124 | + try_acquire_daily_summary_lock('uid1', '2026-02-07', ttl=3600) |
| 125 | + mock_r.set.assert_called_with('users:uid1:daily_summary_lock:2026-02-07', '1', ex=3600, nx=True) |
| 126 | + |
| 127 | + def test_different_users_get_separate_locks(self): |
| 128 | + mock_r.set.return_value = True |
| 129 | + try_acquire_daily_summary_lock('uid1', '2026-02-07') |
| 130 | + try_acquire_daily_summary_lock('uid2', '2026-02-07') |
| 131 | + calls = mock_r.set.call_args_list[-2:] |
| 132 | + assert calls[0][0][0] == 'users:uid1:daily_summary_lock:2026-02-07' |
| 133 | + assert calls[1][0][0] == 'users:uid2:daily_summary_lock:2026-02-07' |
| 134 | + |
| 135 | + def test_different_dates_get_separate_locks(self): |
| 136 | + mock_r.set.return_value = True |
| 137 | + try_acquire_daily_summary_lock('uid1', '2026-02-06') |
| 138 | + try_acquire_daily_summary_lock('uid1', '2026-02-07') |
| 139 | + calls = mock_r.set.call_args_list[-2:] |
| 140 | + assert calls[0][0][0] == 'users:uid1:daily_summary_lock:2026-02-06' |
| 141 | + assert calls[1][0][0] == 'users:uid1:daily_summary_lock:2026-02-07' |
| 142 | + |
| 143 | + |
| 144 | +class TestRaceConditionPrevention: |
| 145 | + """Simulate concurrent calls to verify only one wins the lock.""" |
| 146 | + |
| 147 | + def test_concurrent_lock_attempts_only_one_wins(self): |
| 148 | + call_count = 0 |
| 149 | + |
| 150 | + def setnx_side_effect(*args, **kwargs): |
| 151 | + nonlocal call_count |
| 152 | + call_count += 1 |
| 153 | + # First caller wins, rest get None |
| 154 | + return True if call_count == 1 else None |
| 155 | + |
| 156 | + mock_r.set.side_effect = setnx_side_effect |
| 157 | + |
| 158 | + results = [] |
| 159 | + barrier = threading.Barrier(5) |
| 160 | + |
| 161 | + def attempt_lock(): |
| 162 | + barrier.wait() |
| 163 | + result = try_acquire_daily_summary_lock('uid1', '2026-02-07') |
| 164 | + results.append(result) |
| 165 | + |
| 166 | + threads = [threading.Thread(target=attempt_lock) for _ in range(5)] |
| 167 | + for t in threads: |
| 168 | + t.start() |
| 169 | + for t in threads: |
| 170 | + t.join() |
| 171 | + |
| 172 | + assert results.count(True) == 1 |
| 173 | + assert results.count(False) == 4 |
| 174 | + |
| 175 | + # Reset side_effect |
| 176 | + mock_r.set.side_effect = None |
| 177 | + |
| 178 | + def test_redis_error_propagates_no_silent_swallow(self): |
| 179 | + """Transient Redis failure must propagate — no state mutation should happen.""" |
| 180 | + mock_r.set.side_effect = ConnectionError("Redis unavailable") |
| 181 | + |
| 182 | + try: |
| 183 | + try_acquire_daily_summary_lock('uid1', '2026-02-07') |
| 184 | + assert False, "Expected ConnectionError to propagate" |
| 185 | + except ConnectionError: |
| 186 | + pass # Expected: error propagates, no silent swallow |
| 187 | + |
| 188 | + mock_r.set.side_effect = None |
| 189 | + |
| 190 | + |
| 191 | +class TestSendSummaryNotificationLockIntegration: |
| 192 | + """Verify _send_summary_notification respects the lock.""" |
| 193 | + |
| 194 | + @patch('utils.other.notifications.try_acquire_daily_summary_lock', return_value=False) |
| 195 | + def test_skips_when_lock_not_acquired(self, mock_lock): |
| 196 | + convos_db = sys.modules["database.conversations"] |
| 197 | + convos_db.get_conversations = MagicMock() |
| 198 | + gen_mock = sys.modules["utils.llm.external_integrations"].generate_comprehensive_daily_summary |
| 199 | + send_mock = sys.modules["utils.notifications"].send_notification |
| 200 | + |
| 201 | + convos_db.get_conversations.reset_mock() |
| 202 | + gen_mock.reset_mock() |
| 203 | + send_mock.reset_mock() |
| 204 | + |
| 205 | + user_data = ('uid1', ['token1'], 'America/New_York') |
| 206 | + _send_summary_notification(user_data) |
| 207 | + |
| 208 | + mock_lock.assert_called_once() |
| 209 | + convos_db.get_conversations.assert_not_called() |
| 210 | + gen_mock.assert_not_called() |
| 211 | + send_mock.assert_not_called() |
| 212 | + |
| 213 | + @patch('utils.other.notifications.try_acquire_daily_summary_lock', return_value=True) |
| 214 | + def test_proceeds_when_lock_acquired(self, mock_lock): |
| 215 | + convos_db = sys.modules["database.conversations"] |
| 216 | + convos_db.get_conversations = MagicMock(return_value=[{'id': 'c1'}]) |
| 217 | + |
| 218 | + gen_mock = sys.modules["utils.llm.external_integrations"].generate_comprehensive_daily_summary |
| 219 | + gen_mock.return_value = {'day_emoji': '!', 'headline': 'Test', 'overview': 'Summary'} |
| 220 | + |
| 221 | + daily_db = sys.modules["database.daily_summaries"] |
| 222 | + daily_db.create_daily_summary = MagicMock(return_value='summary-123') |
| 223 | + |
| 224 | + send_mock = sys.modules["utils.notifications"].send_notification |
| 225 | + send_mock.reset_mock() |
| 226 | + |
| 227 | + user_data = ('uid1', ['token1'], 'America/New_York') |
| 228 | + _send_summary_notification(user_data) |
| 229 | + |
| 230 | + mock_lock.assert_called_once() |
| 231 | + convos_db.get_conversations.assert_called_once() |
| 232 | + gen_mock.assert_called_once() |
| 233 | + send_mock.assert_called_once() |
| 234 | + |
| 235 | + @patch('utils.other.notifications.try_acquire_daily_summary_lock', return_value=True) |
| 236 | + def test_no_conversations_skips_llm(self, mock_lock): |
| 237 | + convos_db = sys.modules["database.conversations"] |
| 238 | + convos_db.get_conversations = MagicMock(return_value=[]) |
| 239 | + |
| 240 | + gen_mock = sys.modules["utils.llm.external_integrations"].generate_comprehensive_daily_summary |
| 241 | + gen_mock.reset_mock() |
| 242 | + |
| 243 | + send_mock = sys.modules["utils.notifications"].send_notification |
| 244 | + send_mock.reset_mock() |
| 245 | + |
| 246 | + user_data = ('uid1', ['token1'], 'America/New_York') |
| 247 | + _send_summary_notification(user_data) |
| 248 | + |
| 249 | + mock_lock.assert_called_once() |
| 250 | + convos_db.get_conversations.assert_called_once() |
| 251 | + gen_mock.assert_not_called() |
| 252 | + send_mock.assert_not_called() |
| 253 | + |
| 254 | + @patch('utils.other.notifications.try_acquire_daily_summary_lock', return_value=False) |
| 255 | + def test_utc_fallback_still_acquires_lock(self, mock_lock): |
| 256 | + """User data without timezone falls back to UTC; lock must still be called.""" |
| 257 | + convos_db = sys.modules["database.conversations"] |
| 258 | + convos_db.get_conversations = MagicMock() |
| 259 | + convos_db.get_conversations.reset_mock() |
| 260 | + |
| 261 | + gen_mock = sys.modules["utils.llm.external_integrations"].generate_comprehensive_daily_summary |
| 262 | + gen_mock.reset_mock() |
| 263 | + |
| 264 | + # No timezone element in tuple — triggers UTC fallback |
| 265 | + user_data = ('uid1', ['token1']) |
| 266 | + _send_summary_notification(user_data) |
| 267 | + |
| 268 | + mock_lock.assert_called_once() |
| 269 | + # Lock denied, so no downstream work |
| 270 | + convos_db.get_conversations.assert_not_called() |
| 271 | + gen_mock.assert_not_called() |
0 commit comments