Skip to content

Commit a15eda1

Browse files
authored
Add usage quota to AI creation of flashcards and collections (#82)
1 parent 1cc38ba commit a15eda1

File tree

20 files changed

+419
-52
lines changed

20 files changed

+419
-52
lines changed

backend/.env.example

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ POSTGRES_USER=postgres
1919
POSTGRES_PASSWORD=changethis
2020

2121
# AI
22+
AI_MAX_USAGE_QUOTA=30
23+
AI_QUOTA_TIME_RANGE_DAYS=30 # time in days
2224
AI_MODEL="dummy_model"
2325
AI_API_KEY="dummy_api_key"
2426

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""Add AI Usage Quota tables
2+
3+
Revision ID: d1ea38d75310
4+
Revises: cb16ae472c1e
5+
Create Date: 2025-05-04 09:59:20.325131
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
import sqlmodel.sql.sqltypes
11+
12+
13+
# revision identifiers, used by Alembic.
14+
revision = 'd1ea38d75310'
15+
down_revision = 'cb16ae472c1e'
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade():
21+
op.create_table(
22+
'aiusagequota',
23+
sa.Column('id', sa.UUID(), primary_key=True, nullable=False),
24+
sa.Column('user_id', sa.UUID(), sa.ForeignKey('user.id', ondelete='CASCADE'), index=True, nullable=False),
25+
sa.Column('usage_count', sa.Integer, default=0, nullable=False),
26+
sa.Column('last_reset_time', sa.DateTime(timezone=True), nullable=False),
27+
)
28+
29+
30+
def downgrade():
31+
op.drop_table('aiusagequota')

backend/src/core/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn:
6464

6565
AI_API_KEY: str | None = None
6666
AI_MODEL: str | None = None
67+
AI_MAX_USAGE_QUOTA: int = 30
68+
AI_QUOTA_TIME_RANGE_DAYS: int = 1
6769

6870
COLLECTION_GENERATION_PROMPT: str | None = None
6971
CARD_GENERATION_PROMPT: str | None = None
@@ -92,6 +94,13 @@ def _enforce_non_default_secrets(self) -> Self:
9294
"FIRST_SUPERUSER_PASSWORD", self.FIRST_SUPERUSER_PASSWORD
9395
)
9496

97+
if self.AI_MAX_USAGE_QUOTA is None or self.AI_MAX_USAGE_QUOTA <= 0:
98+
raise ValueError("AI_MAX_USAGE_QUOTA must be set to a positive integer.")
99+
if self.AI_QUOTA_TIME_RANGE_DAYS is None or self.AI_QUOTA_TIME_RANGE_DAYS <= 0:
100+
raise ValueError(
101+
"AI_QUOTA_TIME_RANGE_DAYS must be set to a positive integer."
102+
)
103+
95104
return self
96105

97106

backend/src/flashcards/api.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from src.ai_models.gemini import GeminiProviderDep
88
from src.ai_models.gemini.exceptions import AIGenerationError
99
from src.auth.services import CurrentUser, SessionDep
10+
from src.users.services import check_and_increment_ai_usage_quota
1011

1112
from . import services
1213
from .exceptions import EmptyCollectionError
@@ -52,6 +53,12 @@ async def create_collection(
5253

5354
if collection_in.prompt:
5455
try:
56+
if not await asyncio.to_thread(
57+
lambda: check_and_increment_ai_usage_quota(session, current_user)
58+
):
59+
raise HTTPException(
60+
status_code=429, detail="Quota for AI usage is reached."
61+
)
5562
flashcard_collection = await services.generate_ai_collection(
5663
provider, collection_in.prompt
5764
)
@@ -145,6 +152,12 @@ async def create_card(
145152
if not access_checked:
146153
raise HTTPException(status_code=404, detail="Collection not found")
147154
if card_in.prompt:
155+
if not await asyncio.to_thread(
156+
lambda: check_and_increment_ai_usage_quota(session, current_user)
157+
):
158+
raise HTTPException(
159+
status_code=429, detail="Quota for AI usage is reached."
160+
)
148161
card_base = await services.generate_ai_flashcard(card_in.prompt, provider)
149162
card_in.front = card_base.front
150163
card_in.back = card_base.back

backend/src/users/api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from src.auth.services import CurrentUser, SessionDep
66
from src.core.config import settings
7-
from src.users.schemas import UserCreate, UserPublic, UserRegister
7+
from src.users.schemas import AIUsageQuota, UserCreate, UserPublic, UserRegister
88

99
from . import services
1010

@@ -38,3 +38,8 @@ def register_user(session: SessionDep, user_in: UserRegister) -> Any:
3838
user_create = UserCreate.model_validate(user_in)
3939
user = services.create_user(session=session, user_create=user_create)
4040
return user
41+
42+
43+
@router.get("/users/me/ai-usage-quota", response_model=AIUsageQuota)
44+
def get_my_ai_usage_quota(current_user: CurrentUser):
45+
return services.get_ai_usage_quota_for_user(current_user)

backend/src/users/models.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import uuid
2+
from datetime import datetime, timezone
23
from typing import TYPE_CHECKING
34

4-
from sqlmodel import Field, Relationship
5+
from sqlmodel import Field, Relationship, SQLModel
56

67
from src.users.schemas import UserBase
78

@@ -22,3 +23,19 @@ class User(UserBase, table=True):
2223
cascade_delete=True,
2324
sa_relationship_kwargs={"lazy": "selectin"},
2425
)
26+
ai_usage_quota: "AIUsageQuota" = Relationship(
27+
back_populates="user",
28+
sa_relationship_kwargs={"uselist": False, "lazy": "selectin"},
29+
)
30+
31+
32+
class AIUsageQuota(SQLModel, table=True):
33+
id: uuid.UUID | None = Field(default_factory=uuid.uuid4, primary_key=True)
34+
user_id: uuid.UUID = Field(
35+
foreign_key="user.id", index=True, unique=True, ondelete="CASCADE"
36+
)
37+
usage_count: int = Field(default=0)
38+
last_reset_time: datetime = Field(
39+
default_factory=lambda: datetime.now(timezone.utc)
40+
)
41+
user: "User" = Relationship(back_populates="ai_usage_quota")

backend/src/users/schemas.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import uuid
2+
from datetime import datetime
23

34
from pydantic import EmailStr
45
from sqlmodel import Field, SQLModel
@@ -26,3 +27,9 @@ class UserRegister(SQLModel):
2627

2728
class UserPublic(UserBase):
2829
id: uuid.UUID
30+
31+
32+
class AIUsageQuota(SQLModel):
33+
usage_count: int
34+
max_usage_allowed: int
35+
reset_date: datetime

backend/src/users/services.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import uuid
2+
from datetime import datetime, timedelta, timezone
23
from typing import Any
34

4-
from sqlmodel import Session, select
5+
from sqlalchemy.exc import IntegrityError
6+
from sqlmodel import Session, select, update
57

68
from src.auth.services import get_password_hash
9+
from src.core.config import settings
10+
from src.users.models import AIUsageQuota as AIUsageQuotaModel
711
from src.users.models import User
8-
from src.users.schemas import UserCreate, UserUpdate
12+
from src.users.schemas import AIUsageQuota, UserCreate, UserUpdate
913

1014

1115
def create_user(*, session: Session, user_create: UserCreate) -> User:
@@ -42,3 +46,67 @@ def get_user_by_email(*, session: Session, email: str) -> User | None:
4246
statement = select(User).where(User.email == email)
4347
session_user = session.exec(statement).first()
4448
return session_user
49+
50+
51+
def get_ai_usage_quota_for_user(user: User) -> AIUsageQuota:
52+
quota = user.ai_usage_quota
53+
if not quota:
54+
return AIUsageQuota(
55+
usage_count=0,
56+
max_usage_allowed=settings.AI_MAX_USAGE_QUOTA,
57+
reset_date=(
58+
datetime.now(timezone.utc)
59+
+ timedelta(days=settings.AI_QUOTA_TIME_RANGE_DAYS)
60+
),
61+
)
62+
return AIUsageQuota(
63+
usage_count=quota.usage_count,
64+
max_usage_allowed=settings.AI_MAX_USAGE_QUOTA,
65+
reset_date=(
66+
quota.last_reset_time + timedelta(days=settings.AI_QUOTA_TIME_RANGE_DAYS)
67+
),
68+
)
69+
70+
71+
def check_and_increment_ai_usage_quota(session: Session, user: User) -> bool:
72+
now = datetime.now(timezone.utc)
73+
reset_threshold = now - timedelta(days=settings.AI_QUOTA_TIME_RANGE_DAYS)
74+
75+
if not user.ai_usage_quota:
76+
try:
77+
quota = AIUsageQuotaModel(
78+
user_id=user.id, usage_count=1, last_reset_time=now
79+
)
80+
session.add(quota)
81+
session.commit()
82+
return True
83+
except IntegrityError:
84+
session.rollback()
85+
86+
session.refresh(user)
87+
88+
result_reset = session.exec(
89+
update(AIUsageQuotaModel)
90+
.where(
91+
(AIUsageQuotaModel.user_id == user.id)
92+
& (AIUsageQuotaModel.last_reset_time <= reset_threshold)
93+
)
94+
.values(usage_count=1, last_reset_time=now)
95+
)
96+
97+
if result_reset.rowcount > 0:
98+
session.commit()
99+
return True
100+
101+
result_increment = session.exec(
102+
update(AIUsageQuotaModel)
103+
.where(
104+
(AIUsageQuotaModel.user_id == user.id)
105+
& (AIUsageQuotaModel.last_reset_time > reset_threshold)
106+
& (AIUsageQuotaModel.usage_count < settings.AI_MAX_USAGE_QUOTA)
107+
)
108+
.values(usage_count=AIUsageQuotaModel.usage_count + 1)
109+
)
110+
111+
session.commit()
112+
return result_increment.rowcount > 0

backend/tests/flashcards/card/test_api.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -394,18 +394,22 @@ def test_create_card_with_prompt_ai(
394394
with patch(
395395
"src.flashcards.services.generate_ai_flashcard", new_callable=AsyncMock
396396
) as mock_ai:
397-
mock_ai.return_value = type("Card", (), ai_card)()
398-
card_data = {"prompt": prompt, "front": "", "back": ""}
399-
rsp = client.post(
400-
f"{settings.API_V1_STR}/collections/{collection_id}/cards/",
401-
json=card_data,
402-
headers=normal_user_token_headers,
403-
)
404-
assert rsp.status_code == 200
405-
content = rsp.json()
406-
assert content["front"] == ai_card["front"]
407-
assert content["back"] == ai_card["back"]
408-
mock_ai.assert_called_once_with(prompt, ANY)
397+
with patch(
398+
"src.flashcards.api.check_and_increment_ai_usage_quota"
399+
) as mock_quota_check:
400+
mock_quota_check.return_value = True
401+
mock_ai.return_value = type("Card", (), ai_card)()
402+
card_data = {"prompt": prompt, "front": "", "back": ""}
403+
rsp = client.post(
404+
f"{settings.API_V1_STR}/collections/{collection_id}/cards/",
405+
json=card_data,
406+
headers=normal_user_token_headers,
407+
)
408+
assert rsp.status_code == 200
409+
content = rsp.json()
410+
assert content["front"] == ai_card["front"]
411+
assert content["back"] == ai_card["back"]
412+
mock_ai.assert_called_once_with(prompt, ANY)
409413

410414

411415
def test_create_card_with_prompt_too_long(

backend/tests/flashcards/collection/test_api.py

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -83,25 +83,29 @@ def test_create_collection_with_prompt(
8383
with patch(
8484
"src.flashcards.services.generate_ai_collection", new_callable=AsyncMock
8585
) as mock_ai_generate:
86-
mock_ai_generate.return_value = mock_collection
87-
88-
rsp = client.post(
89-
f"{settings.API_V1_STR}/collections/",
90-
json=collection_data.model_dump(),
91-
headers=normal_user_token_headers,
92-
)
93-
94-
assert rsp.status_code == 200
95-
content = rsp.json()
96-
assert content["name"] == collection_data.name
97-
assert "id" in content
98-
assert isinstance(content["id"], str)
99-
assert len(content["cards"]) == len(mock_collection.cards)
100-
for i, card in enumerate(mock_collection.cards):
101-
assert content["cards"][i]["front"] == card.front
102-
assert content["cards"][i]["back"] == card.back
103-
104-
mock_ai_generate.assert_called_once()
86+
with patch(
87+
"src.flashcards.api.check_and_increment_ai_usage_quota"
88+
) as mock_quota_check:
89+
mock_ai_generate.return_value = mock_collection
90+
mock_quota_check.return_value = True
91+
92+
rsp = client.post(
93+
f"{settings.API_V1_STR}/collections/",
94+
json=collection_data.model_dump(),
95+
headers=normal_user_token_headers,
96+
)
97+
98+
assert rsp.status_code == 200
99+
content = rsp.json()
100+
assert content["name"] == collection_data.name
101+
assert "id" in content
102+
assert isinstance(content["id"], str)
103+
assert len(content["cards"]) == len(mock_collection.cards)
104+
for i, card in enumerate(mock_collection.cards):
105+
assert content["cards"][i]["front"] == card.front
106+
assert content["cards"][i]["back"] == card.back
107+
108+
mock_ai_generate.assert_called_once()
105109

106110

107111
def test_create_collection_with_ai_generation_error(
@@ -114,19 +118,23 @@ def test_create_collection_with_ai_generation_error(
114118
with patch(
115119
"src.flashcards.services.generate_ai_collection", new_callable=AsyncMock
116120
) as mock_ai_generate:
117-
err_msg = "AI service is unavailable"
118-
mock_ai_generate.side_effect = AIGenerationError(err_msg)
119-
120-
rsp = client.post(
121-
f"{settings.API_V1_STR}/collections/",
122-
json=collection_data.model_dump(),
123-
headers=normal_user_token_headers,
124-
)
125-
126-
assert rsp.status_code == 500
127-
content = rsp.json()
128-
assert "detail" in content
129-
assert err_msg in content["detail"]
121+
with patch(
122+
"src.flashcards.api.check_and_increment_ai_usage_quota"
123+
) as mock_quota_check:
124+
err_msg = "AI service is unavailable"
125+
mock_ai_generate.side_effect = AIGenerationError(err_msg)
126+
mock_quota_check.return_value = True
127+
128+
rsp = client.post(
129+
f"{settings.API_V1_STR}/collections/",
130+
json=collection_data.model_dump(),
131+
headers=normal_user_token_headers,
132+
)
133+
134+
assert rsp.status_code == 500
135+
content = rsp.json()
136+
assert "detail" in content
137+
assert err_msg in content["detail"]
130138

131139

132140
def test_read_collection(

0 commit comments

Comments
 (0)