Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 4 additions & 25 deletions backend/database/redis_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,28 +791,7 @@ def delete_speech_profile_duration(uid: str):
# ******************************************************


def set_daily_summary_sent(uid: str, date: str, ttl: int = 60 * 60 * 2):
"""
Mark that a daily summary was sent to user for a specific date.
Default TTL is 2 hours to prevent duplicate sends within the same hour window.
Args:
uid: User ID
date: Date string in YYYY-MM-DD format
ttl: Time to live in seconds (default: 2 hours)
"""
r.set(f'users:{uid}:daily_summary_sent:{date}', '1', ex=ttl)


def has_daily_summary_been_sent(uid: str, date: str) -> bool:
"""
Check if daily summary was already sent to user for a specific date.
Args:
uid: User ID
date: Date string in YYYY-MM-DD format
Returns:
True if summary was already sent for this date, False otherwise
"""
return r.exists(f'users:{uid}:daily_summary_sent:{date}')
def try_acquire_daily_summary_lock(uid: str, date: str, ttl: int = 60 * 60 * 2) -> bool:
"""Atomically acquire lock BEFORE expensive LLM work. Returns True if acquired, False if another job instance already holds it."""
result = r.set(f'users:{uid}:daily_summary_lock:{date}', '1', ex=ttl, nx=True)
return result is not None
1 change: 1 addition & 0 deletions backend/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ pytest tests/unit/test_process_conversation_usage_context.py -v
pytest tests/unit/test_llm_usage_db.py -v
pytest tests/unit/test_llm_usage_endpoints.py -v
pytest tests/unit/test_app_uid_keyerror.py -v
pytest tests/unit/test_daily_summary_race_condition.py -v
271 changes: 271 additions & 0 deletions backend/tests/unit/test_daily_summary_race_condition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
"""
Unit tests for daily summary race condition fix (#4594).

Verifies that:
1. try_acquire_daily_summary_lock uses atomic SETNX
2. Only the first caller acquires the lock; concurrent callers are rejected
3. _send_summary_notification skips work when lock is already held
"""

import os
import sys
import types
import threading
from unittest.mock import MagicMock, patch

os.environ.setdefault(
"ENCRYPTION_SECRET",
"omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv",
)


def _stub_module(name: str) -> types.ModuleType:
mod = types.ModuleType(name)
sys.modules[name] = mod
return mod


# Stub database package and submodules to avoid Firestore init.
if "database" not in sys.modules:
database_mod = _stub_module("database")
database_mod.__path__ = []
else:
database_mod = sys.modules["database"]

for submodule in [
"redis_db",
"chat",
"conversations",
"notifications",
"users",
"daily_summaries",
"_client",
"auth",
]:
full_name = f"database.{submodule}"
if full_name not in sys.modules:
mod = _stub_module(full_name)
setattr(database_mod, submodule, mod)

# Set up mock redis and real lock function
redis_db_mod = sys.modules["database.redis_db"]
mock_r = MagicMock()
redis_db_mod.r = mock_r


def try_acquire_daily_summary_lock(uid: str, date: str, ttl: int = 60 * 60 * 2) -> bool:
result = mock_r.set(f'users:{uid}:daily_summary_lock:{date}', '1', ex=ttl, nx=True)
return result is not None


redis_db_mod.try_acquire_daily_summary_lock = try_acquire_daily_summary_lock
Comment on lines +56 to +61
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This test is testing a local, copy-pasted version of try_acquire_daily_summary_lock, not the actual implementation from database/redis_db.py. This means that if the logic in the original file changes, this test might still pass, giving a false sense of security. The test should import and test the production code directly.

To fix this, you should import the actual function and mock its dependencies. Given the extensive module stubbing in this file, a robust way to achieve this is to patch the redis.Redis client at the top of the file before any application modules are imported. This will prevent the real Redis client from being instantiated.

Example of how you could refactor the test setup:

# At the top of your test file
from unittest.mock import MagicMock, patch

mock_r = MagicMock()
patch('redis.Redis', return_value=mock_r).start()

# Now you can safely import from your application modules
# You might need to adjust the existing module stubbing
from database.redis_db import try_acquire_daily_summary_lock

# ...

class TestTryAcquireDailySummaryLock:
    """Tests for the atomic SETNX lock function."""

    def test_lock_acquired_returns_true(self):
        mock_r.reset_mock()
        mock_r.set.return_value = True
        assert try_acquire_daily_summary_lock('uid1', '2026-02-07') is True
        mock_r.set.assert_called_with('users:uid1:daily_summary_lock:2026-02-07', '1', ex=7200, nx=True)

    # ... other tests for try_acquire_daily_summary_lock

This change is critical to ensure the tests are correctly validating the production code.


# Set up mock auth
auth_mod = sys.modules["database.auth"]
auth_mod.get_user_name = MagicMock(return_value="Test User")

# Set up mock client
client_mod = sys.modules["database._client"]
client_mod.db = MagicMock()
client_mod.document_id_from_seed = MagicMock(return_value="doc-id")

# Stub utils modules that pull in heavy dependencies.
for name in [
"utils.llm.external_integrations",
"utils.notifications",
"utils.webhooks",
]:
if name not in sys.modules:
_stub_module(name)

# Add needed attrs to stubs
utils_llm_ext = sys.modules["utils.llm.external_integrations"]
utils_llm_ext.get_conversation_summary = MagicMock()
utils_llm_ext.generate_comprehensive_daily_summary = MagicMock()

utils_notifications = sys.modules["utils.notifications"]
utils_notifications.send_bulk_notification = MagicMock()
utils_notifications.send_notification = MagicMock()

utils_webhooks = sys.modules["utils.webhooks"]
utils_webhooks.day_summary_webhook = MagicMock()

# Stub models
for name in ["models.notification_message", "models.conversation"]:
if name not in sys.modules:
_stub_module(name)

models_notif = sys.modules["models.notification_message"]
mock_notification_message = MagicMock()
mock_notification_message.get_message_as_dict = MagicMock(return_value={})
models_notif.NotificationMessage = mock_notification_message

models_convo = sys.modules["models.conversation"]
models_convo.Conversation = MagicMock()

# Now we can safely import
from utils.other.notifications import _send_summary_notification


class TestTryAcquireDailySummaryLock:
"""Tests for the atomic SETNX lock function."""

def test_lock_acquired_returns_true(self):
mock_r.set.return_value = True
assert try_acquire_daily_summary_lock('uid1', '2026-02-07') is True
mock_r.set.assert_called_with('users:uid1:daily_summary_lock:2026-02-07', '1', ex=7200, nx=True)

def test_lock_already_held_returns_false(self):
mock_r.set.return_value = None # SETNX returns None when key exists
assert try_acquire_daily_summary_lock('uid1', '2026-02-07') is False

def test_custom_ttl(self):
mock_r.set.return_value = True
try_acquire_daily_summary_lock('uid1', '2026-02-07', ttl=3600)
mock_r.set.assert_called_with('users:uid1:daily_summary_lock:2026-02-07', '1', ex=3600, nx=True)

def test_different_users_get_separate_locks(self):
mock_r.set.return_value = True
try_acquire_daily_summary_lock('uid1', '2026-02-07')
try_acquire_daily_summary_lock('uid2', '2026-02-07')
calls = mock_r.set.call_args_list[-2:]
assert calls[0][0][0] == 'users:uid1:daily_summary_lock:2026-02-07'
assert calls[1][0][0] == 'users:uid2:daily_summary_lock:2026-02-07'

def test_different_dates_get_separate_locks(self):
mock_r.set.return_value = True
try_acquire_daily_summary_lock('uid1', '2026-02-06')
try_acquire_daily_summary_lock('uid1', '2026-02-07')
calls = mock_r.set.call_args_list[-2:]
assert calls[0][0][0] == 'users:uid1:daily_summary_lock:2026-02-06'
assert calls[1][0][0] == 'users:uid1:daily_summary_lock:2026-02-07'


class TestRaceConditionPrevention:
"""Simulate concurrent calls to verify only one wins the lock."""

def test_concurrent_lock_attempts_only_one_wins(self):
call_count = 0

def setnx_side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
# First caller wins, rest get None
return True if call_count == 1 else None

mock_r.set.side_effect = setnx_side_effect

results = []
barrier = threading.Barrier(5)

def attempt_lock():
barrier.wait()
result = try_acquire_daily_summary_lock('uid1', '2026-02-07')
results.append(result)

threads = [threading.Thread(target=attempt_lock) for _ in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()

assert results.count(True) == 1
assert results.count(False) == 4

# Reset side_effect
mock_r.set.side_effect = None

def test_redis_error_propagates_no_silent_swallow(self):
"""Transient Redis failure must propagate — no state mutation should happen."""
mock_r.set.side_effect = ConnectionError("Redis unavailable")

try:
try_acquire_daily_summary_lock('uid1', '2026-02-07')
assert False, "Expected ConnectionError to propagate"
except ConnectionError:
pass # Expected: error propagates, no silent swallow

mock_r.set.side_effect = None


class TestSendSummaryNotificationLockIntegration:
"""Verify _send_summary_notification respects the lock."""

@patch('utils.other.notifications.try_acquire_daily_summary_lock', return_value=False)
def test_skips_when_lock_not_acquired(self, mock_lock):
convos_db = sys.modules["database.conversations"]
convos_db.get_conversations = MagicMock()
gen_mock = sys.modules["utils.llm.external_integrations"].generate_comprehensive_daily_summary
send_mock = sys.modules["utils.notifications"].send_notification

convos_db.get_conversations.reset_mock()
gen_mock.reset_mock()
send_mock.reset_mock()

user_data = ('uid1', ['token1'], 'America/New_York')
_send_summary_notification(user_data)

mock_lock.assert_called_once()
convos_db.get_conversations.assert_not_called()
gen_mock.assert_not_called()
send_mock.assert_not_called()

@patch('utils.other.notifications.try_acquire_daily_summary_lock', return_value=True)
def test_proceeds_when_lock_acquired(self, mock_lock):
convos_db = sys.modules["database.conversations"]
convos_db.get_conversations = MagicMock(return_value=[{'id': 'c1'}])

gen_mock = sys.modules["utils.llm.external_integrations"].generate_comprehensive_daily_summary
gen_mock.return_value = {'day_emoji': '!', 'headline': 'Test', 'overview': 'Summary'}

daily_db = sys.modules["database.daily_summaries"]
daily_db.create_daily_summary = MagicMock(return_value='summary-123')

send_mock = sys.modules["utils.notifications"].send_notification
send_mock.reset_mock()

user_data = ('uid1', ['token1'], 'America/New_York')
_send_summary_notification(user_data)

mock_lock.assert_called_once()
convos_db.get_conversations.assert_called_once()
gen_mock.assert_called_once()
send_mock.assert_called_once()

@patch('utils.other.notifications.try_acquire_daily_summary_lock', return_value=True)
def test_no_conversations_skips_llm(self, mock_lock):
convos_db = sys.modules["database.conversations"]
convos_db.get_conversations = MagicMock(return_value=[])

gen_mock = sys.modules["utils.llm.external_integrations"].generate_comprehensive_daily_summary
gen_mock.reset_mock()

send_mock = sys.modules["utils.notifications"].send_notification
send_mock.reset_mock()

user_data = ('uid1', ['token1'], 'America/New_York')
_send_summary_notification(user_data)

mock_lock.assert_called_once()
convos_db.get_conversations.assert_called_once()
gen_mock.assert_not_called()
send_mock.assert_not_called()

@patch('utils.other.notifications.try_acquire_daily_summary_lock', return_value=False)
def test_utc_fallback_still_acquires_lock(self, mock_lock):
"""User data without timezone falls back to UTC; lock must still be called."""
convos_db = sys.modules["database.conversations"]
convos_db.get_conversations = MagicMock()
convos_db.get_conversations.reset_mock()

gen_mock = sys.modules["utils.llm.external_integrations"].generate_comprehensive_daily_summary
gen_mock.reset_mock()

# No timezone element in tuple — triggers UTC fallback
user_data = ('uid1', ['token1'])
_send_summary_notification(user_data)

mock_lock.assert_called_once()
# Lock denied, so no downstream work
convos_db.get_conversations.assert_not_called()
gen_mock.assert_not_called()
16 changes: 5 additions & 11 deletions backend/utils/other/notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
import database.chat as chat_db
import database.conversations as conversations_db
import database.notifications as notification_db
from database.redis_db import set_daily_summary_sent, has_daily_summary_been_sent
from database.redis_db import try_acquire_daily_summary_lock
from models.notification_message import NotificationMessage
from models.conversation import Conversation
from utils.llm.external_integrations import get_conversation_summary
from utils.llm.external_integrations import get_conversation_summary, generate_comprehensive_daily_summary
from utils.notifications import send_bulk_notification, send_notification
from utils.webhooks import day_summary_webhook
import database.daily_summaries as daily_summaries_db


def should_run_job():
Expand Down Expand Up @@ -115,8 +116,8 @@ def _send_summary_notification(user_data: tuple):
display_date = now_utc.date()
date_str = display_date.strftime('%Y-%m-%d')

# Check if summary already sent for this date
if has_daily_summary_been_sent(uid, date_str):
# Atomically acquire lock BEFORE expensive LLM work to prevent race condition
if not try_acquire_daily_summary_lock(uid, date_str):
return

conversations_data = conversations_db.get_conversations(uid, start_date=start_date_utc, end_date=end_date_utc)
Expand All @@ -125,10 +126,6 @@ def _send_summary_notification(user_data: tuple):

conversations = [Conversation(**convo_data) for convo_data in conversations_data]

# Generate comprehensive daily summary
from utils.llm.external_integrations import generate_comprehensive_daily_summary
import database.daily_summaries as daily_summaries_db

summary_data = generate_comprehensive_daily_summary(uid, conversations, date_str, start_date_utc, end_date_utc)

# Store in database
Expand Down Expand Up @@ -158,9 +155,6 @@ def _send_summary_notification(user_data: tuple):
uid, daily_summary_title, summary_body, NotificationMessage.get_message_as_dict(ai_message), tokens=tokens
)

# Mark that summary was sent for this date
set_daily_summary_sent(uid, date_str)


async def _send_bulk_summary_notification(users: list):
loop = asyncio.get_running_loop()
Expand Down