Skip to content

Commit 8b4966a

Browse files
authored
Merge pull request #28 from janumiko/janumiko/redis
Janumiko/redis
2 parents 91fb147 + 269df7c commit 8b4966a

File tree

9 files changed

+344
-21
lines changed

9 files changed

+344
-21
lines changed

.env.example

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
GEMINI_API_KEY=<TOKEN>
2+
REDIS_PASSWORD=<REDIS_PASSWORD>
File renamed without changes.

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ For containerized deployment:
7373

7474
3. **Run the container**:
7575
```sh
76-
sudo docker run -p 7860:7860 --env-file .env reagentai
76+
sudo docker compose up
7777
```
7878

7979
4. **Access the application**:

docker-compose.yml

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
services:
  app:
    container_name: reagent_app
    build:
      context: .
      dockerfile: Dockerfile.reagent
    ports:
      - "7860:7860"
    env_file:
      - .env
    environment:
      - REDIS_HOST=redis
      - REDIS_PORT=6379
      - REDIS_PASSWORD=${REDIS_PASSWORD}
    networks:
      - app-network
    # Quoted: a bare `no` is a YAML 1.1 boolean (false) and fails Compose
    # schema validation for the `restart` field, which expects a string.
    restart: "no"
    depends_on:
      redis:
        # The redis service defines a healthcheck below, so wait for it to
        # be healthy rather than merely started — otherwise the app can
        # race redis and fail its first connection attempts.
        condition: service_healthy

  redis:
    container_name: reagent_redis
    image: redis:7-alpine
    volumes:
      - redis_data:/data
    networks:
      - app-network
    restart: "no"
    command: redis-server --appendonly yes --maxmemory 256mb --maxmemory-policy allkeys-lru --requirepass ${REDIS_PASSWORD}
    healthcheck:
      # NOTE(review): the password appears on the CLI and is visible in `ps`
      # inside the container; acceptable for local dev, revisit for prod.
      test: ["CMD", "redis-cli", "-a", "${REDIS_PASSWORD}", "ping"]
      interval: 60s
      timeout: 5s
      retries: 3

volumes:
  redis_data:
    driver: local

networks:
  app-network:
    driver: bridge

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies = [
1111
"gradio>=5.29.1",
1212
"pydantic-ai-slim[duckduckgo]>=0.2.4",
1313
"pubchempy>=1.0.4",
14+
"redis[hiredis]>=6.2.0",
1415
]
1516

1617
[tool.black]

src/reagentai/agents/main/main_agent.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(
4040
self,
4141
model_name: str,
4242
instructions: str,
43-
tools: list[Tool],
43+
tools: list[Tool[MainAgentDependencyTypes]],
4444
dependency_types: type[MainAgentDependencyTypes],
4545
dependencies: MainAgentDependencyTypes,
4646
output_type: type[str],
@@ -80,12 +80,13 @@ def _create_agent(self) -> Agent[MainAgentDependencyTypes, str]:
8080
Agent[MainAgentDependencyTypes, str]: An instance of the Agent configured with the main agent's model and instructions.
8181
"""
8282

83-
return Agent(
83+
return Agent[MainAgentDependencyTypes, str](
8484
self.model_name,
8585
tools=self.tools,
8686
instructions=self.instructions,
8787
deps_type=self.dependency_types,
8888
output_type=self.output_type,
89+
retries=3,
8990
)
9091

9192
def remove_last_messages(self, remove_user_prompt: bool = True):
@@ -122,7 +123,7 @@ def get_total_token_usage(self) -> int:
122123
Returns:
123124
int: The total number of tokens used by the agent.
124125
"""
125-
if self.usage:
126+
if self.usage and self.usage.total_tokens:
126127
return self.usage.total_tokens
127128
else:
128129
return 0
@@ -137,7 +138,9 @@ def clear_history(self):
137138
self.usage = None
138139

139140
@asynccontextmanager
140-
async def run_stream(self, user_query: str) -> AsyncIterator[result.StreamedRunResult]:
141+
async def run_stream(
142+
self, user_query: str
143+
) -> AsyncIterator[result.StreamedRunResult[MainAgentDependencyTypes, str]]:
141144
"""
142145
Streams the response from the agent asynchronously.
143146
@@ -199,14 +202,14 @@ def create_main_agent() -> MainAgent:
199202
instructions = instructions_file.read()
200203

201204
tools = [
202-
Tool(perform_retrosynthesis, takes_ctx=True),
203-
Tool(is_valid_smiles),
204-
Tool(smiles_to_image),
205-
Tool(route_to_image),
206-
Tool(find_similar_molecules),
207-
Tool(get_smiles_from_name),
208-
Tool(get_compound_info),
209-
Tool(get_name_from_smiles),
205+
Tool[MainAgentDependencyTypes](perform_retrosynthesis, takes_ctx=True),
206+
Tool[MainAgentDependencyTypes](is_valid_smiles),
207+
Tool[MainAgentDependencyTypes](smiles_to_image),
208+
Tool[MainAgentDependencyTypes](route_to_image),
209+
Tool[MainAgentDependencyTypes](find_similar_molecules),
210+
Tool[MainAgentDependencyTypes](get_smiles_from_name),
211+
Tool[MainAgentDependencyTypes](get_compound_info),
212+
Tool[MainAgentDependencyTypes](get_name_from_smiles),
210213
duckduckgo_search_tool(),
211214
]
212215

Lines changed: 130 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,148 @@
1+
import logging
import pickle
from typing import ClassVar

from aizynthfinder.context.config import Configuration

from src.reagentai.common.utils.redis import RedisManager
from src.reagentai.models.retrosynthesis import RouteCollection

logger = logging.getLogger(__name__)


class RetrosynthesisCache:
    """
    A cache for storing retrosynthesis routes keyed by target SMILES strings.

    Supports both an in-memory dict and a Redis backend: writes go to both,
    reads check memory first and fall back to Redis, and every Redis failure
    degrades gracefully to memory-only operation.
    """

    # Class-level cache for fast access (process-local, no TTL).
    _memory_cache: ClassVar[dict[str, RouteCollection]] = {}
    # Shared AiZynthFinder configuration, set by the retrosynthesis tooling.
    finder_config: ClassVar[Configuration | None] = None
    _cache_prefix: ClassVar[str] = "retrosynthesis"
    _default_ttl: ClassVar[int] = 86400  # 24 hours

    @classmethod
    def _serialize_data(cls, data: RouteCollection) -> bytes:
        """Serialize a RouteCollection for Redis storage.

        Raises:
            Exception: re-raised after logging if pickling fails.
        """
        try:
            return pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
        except Exception as e:
            logger.error("Failed to serialize data: %s", e)
            raise

    @classmethod
    def _deserialize_data(cls, data: bytes) -> RouteCollection:
        """Deserialize a RouteCollection from Redis storage.

        SECURITY NOTE: pickle.loads executes arbitrary code if the Redis
        instance is compromised — only safe because this Redis is private
        to the deployment and password-protected.

        Raises:
            Exception: re-raised after logging if unpickling fails.
        """
        try:
            return pickle.loads(data)
        except Exception as e:
            logger.error("Failed to deserialize data: %s", e)
            raise

    @classmethod
    def _get_cache_key(cls, target_smile: str) -> str:
        """Generate the standardized Redis key for a target SMILES.

        Only surrounding whitespace is stripped. SMILES is case-significant
        (lowercase letters denote aromatic atoms), so lower-casing here would
        collide distinct molecules (e.g. C1CCCCC1 vs c1ccccc1) onto one key.
        """
        normalized_smile = target_smile.strip()
        return f"{cls._cache_prefix}:{normalized_smile}"

    @classmethod
    def add(cls, target_smile: str, data: RouteCollection, ttl: int | None = None) -> bool:
        """Add a route collection to the cache.

        Args:
            target_smile: Target molecule SMILES used as the cache key.
            data: Routes to cache.
            ttl: Redis expiry in seconds; defaults to ``_default_ttl``.

        Returns:
            bool: True if stored (in Redis or, on Redis failure, in memory
            only); False for empty/falsy inputs.
        """
        if not target_smile or not data:
            logger.warning("Invalid input for cache add operation")
            return False

        # Always store in the memory cache, regardless of Redis availability.
        cls._memory_cache[target_smile] = data

        # Attempt Redis storage (best-effort).
        ttl = ttl or cls._default_ttl
        cache_key = cls._get_cache_key(target_smile)

        with RedisManager.get_client() as redis_client:
            if redis_client:
                try:
                    serialized_data = cls._serialize_data(data)
                    result = redis_client.setex(cache_key, ttl, serialized_data)
                    if result:
                        logger.debug("Cached to Redis: %s", cache_key)
                        return True
                except Exception as e:
                    logger.warning("Failed to cache to Redis: %s", e)

        logger.debug("Cached to memory only: %s", target_smile)
        return True

    @classmethod
    def get(cls, target_smile: str) -> RouteCollection | None:
        """Retrieve a route collection, checking memory first, then Redis.

        A Redis hit is promoted into the memory cache for faster reuse.
        Returns None on a miss or for empty input.
        """
        if not target_smile:
            return None

        # Check memory cache first.
        if target_smile in cls._memory_cache:
            logger.debug("Cache hit (memory): %s", target_smile)
            return cls._memory_cache[target_smile]

        # Check Redis cache.
        cache_key = cls._get_cache_key(target_smile)

        with RedisManager.get_client() as redis_client:
            if redis_client:
                try:
                    cached_data = redis_client.get(cache_key)
                    if cached_data and isinstance(cached_data, bytes):
                        data = cls._deserialize_data(cached_data)
                        cls._memory_cache[target_smile] = data
                        logger.debug("Cache hit (Redis): %s", target_smile)
                        return data
                except Exception as e:
                    logger.warning("Failed to retrieve from Redis: %s", e)

        logger.debug("Cache miss: %s", target_smile)
        return None

    @classmethod
    def delete(cls, target_smile: str) -> bool:
        """Delete one entry from both backends.

        Returns True unless input is empty; a Redis failure still returns
        True because the memory entry was removed (best-effort semantics).
        """
        if not target_smile:
            return False

        cls._memory_cache.pop(target_smile, None)
        cache_key = cls._get_cache_key(target_smile)

        with RedisManager.get_client() as redis_client:
            if redis_client:
                try:
                    result = redis_client.delete(cache_key)
                    logger.debug("Deleted from cache: %s", target_smile)
                    return bool(result)
                except Exception as e:
                    logger.warning("Failed to delete from Redis: %s", e)

        return True

    @classmethod
    def clear(cls) -> bool:
        """Clear all cached routes from memory and Redis.

        Redis keys are discovered with SCAN (non-blocking, unlike KEYS) and
        deleted in a single pipeline round-trip.
        """
        cls._memory_cache.clear()

        with RedisManager.get_client() as redis_client:
            if redis_client:
                try:
                    pipeline = redis_client.pipeline()
                    for key in redis_client.scan_iter(match=f"{cls._cache_prefix}:*", count=100):
                        pipeline.delete(key)
                    pipeline.execute()
                    logger.info("Cleared Redis cache")
                    return True
                except Exception as e:
                    logger.warning("Failed to clear Redis cache: %s", e)

        logger.info("Cleared memory cache")
        return True

    @classmethod
    def close(cls):
        """Close Redis connections and clean up pooled resources."""
        RedisManager.close()
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import logging
import os
from collections.abc import Generator
from contextlib import contextmanager

import redis
from redis.exceptions import ConnectionError, RedisError, TimeoutError

logger = logging.getLogger(__name__)


class RedisManager:
    """Centralized Redis connection management.

    Lazily builds a single process-wide connection pool from the
    REDIS_HOST / REDIS_PORT / REDIS_PASSWORD environment variables and
    hands out clients through a context manager. All failures degrade to
    ``None`` so callers can fall back to non-Redis behavior.
    """

    _pool: redis.ConnectionPool | None = None

    @classmethod
    def get_pool(cls) -> redis.ConnectionPool | None:
        """Get or create the Redis connection pool.

        Returns:
            The shared pool, or None if Redis is unreachable. A failed
            attempt leaves ``_pool`` as None, so the next call retries.
        """
        if cls._pool is None:
            try:
                cls._pool = redis.ConnectionPool(
                    host=os.getenv("REDIS_HOST", "localhost"),
                    port=int(os.getenv("REDIS_PORT", "6379")),
                    password=os.getenv("REDIS_PASSWORD"),
                    decode_responses=False,  # cached payloads are raw bytes (pickle)
                    socket_timeout=5,
                    socket_connect_timeout=5,
                    retry_on_timeout=True,
                    max_connections=20,
                    health_check_interval=30,
                )
                # Probe the connection once so callers can trust a non-None pool.
                with redis.Redis(connection_pool=cls._pool) as client:
                    client.ping()
                logger.info("Redis connection pool initialized successfully")
            except (ConnectionError, TimeoutError, RedisError) as e:
                logger.warning("Redis connection failed: %s", e)
                cls._pool = None
            except Exception as e:
                logger.error("Unexpected error initializing Redis: %s", e)
                cls._pool = None
        return cls._pool

    @classmethod
    @contextmanager
    def get_client(cls) -> Generator[redis.Redis | None, None, None]:
        """Context manager yielding a Redis client, or None if unavailable.

        Redis errors raised by the caller's ``with`` body are logged and
        suppressed (caching is best-effort). BUGFIX: the previous version
        re-yielded inside ``except``; a generator-based context manager must
        yield exactly once, and a second yield after ``throw()`` makes
        contextlib raise ``RuntimeError: generator didn't stop after throw()``,
        masking the original error.
        """
        pool = cls.get_pool()
        if pool is None:
            yield None
            return

        client = redis.Redis(connection_pool=pool)
        try:
            yield client
        except (ConnectionError, TimeoutError, RedisError) as e:
            # Suppressing here makes every Redis operation best-effort.
            logger.warning("Redis operation failed: %s", e)
        except Exception as e:
            logger.error("Unexpected Redis error: %s", e)
        finally:
            try:
                client.close()
            except Exception:
                # Closing a pooled client just returns the connection; a
                # failure here must never shadow the caller's outcome.
                pass

    @classmethod
    def close(cls):
        """Close the Redis connection pool and reset the singleton."""
        if cls._pool:
            try:
                cls._pool.disconnect()
                cls._pool = None
                logger.info("Redis connection pool closed")
            except Exception as e:
                logger.warning("Error closing Redis pool: %s", e)

0 commit comments

Comments
 (0)