diff --git a/.env.example b/.env.example
index 769f774..d0f2364 100644
--- a/.env.example
+++ b/.env.example
@@ -20,7 +20,9 @@
 OPENROUTER_BASE_URL=https://openrouter.ai
 EXTERNAL_OPENROUTER_KEY=
 GROK_BASE_URL=https://api.groq.com
 EXTERNAL_GROK_KEY=
+# Base URL for Venice provider
 VENICE_BASE_URL=https://api.venice.ai
+# API key for Venice
 EXTERNAL_VENICE_KEY=
 # Rate limiting settings
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 36786d5..1d67b5b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@
 
 - llm-d cluster support (`make k3s-up` and router forwarding)
 - Redis caching layer with TTL (`REDIS_URL`, `CACHE_TTL`)
+- Venice provider forwarding support
 
 ## [MVP Release]
 
diff --git a/IMPLEMENTATION_STATUS.md b/IMPLEMENTATION_STATUS.md
index 03448b9..dbabf7d 100644
--- a/IMPLEMENTATION_STATUS.md
+++ b/IMPLEMENTATION_STATUS.md
@@ -63,7 +63,7 @@ This section tracks features, integrations, and improvements to be implemented a
 
 - [x] OpenRouter
 - Grok
-- Venice
+- [x] Venice
 
 ---
diff --git a/router/main.py b/router/main.py
index 63f3f95..8662c07 100644
--- a/router/main.py
+++ b/router/main.py
@@ -14,7 +14,6 @@
 from pathlib import Path
-
 import httpx
 from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import Response, StreamingResponse, JSONResponse
@@ -191,6 +190,20 @@ async def _startup() -> None:
     logger.addHandler(stream_handler)
 
 
+class Message(BaseModel):
+    role: str
+    content: str
+
+
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: List[Message]
+    max_tokens: Optional[int] = None
+    temperature: Optional[float] = None
+    stream: Optional[bool] = False
+
+
 class AgentRegistration(BaseModel):
     name: str
     endpoint: str
@@ -251,13 +264,9 @@ def make_cache_key(payload: ChatCompletionRequest) -> str:
     serialized = json.dumps(payload.dict(), sort_keys=True)
     digest = hashlib.sha256(serialized.encode()).hexdigest()
-    return digest
-
     return f"chat:{digest}"
-
-
 async def forward_to_local_agent(payload: ChatCompletionRequest) -> dict:
     async with httpx.AsyncClient(base_url=LOCAL_AGENT_URL) as client:
         resp = await client.post("/infer", json=payload.dict())
@@ -364,7 +373,6 @@ async def forward_to_venice(payload: ChatCompletionRequest):
     return await venice.forward(payload, VENICE_BASE_URL, EXTERNAL_VENICE_KEY)
 
-
 @app.post("/register")
 async def register_agent(payload: AgentRegistration) -> dict:
     """Register a local agent and update the model registry."""
@@ -454,3 +462,94 @@
 async def metrics() -> Response:
     """Expose Prometheus metrics."""
     return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)
+
+
+@app.post("/v1/chat/completions")
+async def chat_completions(payload: ChatCompletionRequest):
+    """Route a chat-completion request to the appropriate backend."""
+    backend = select_backend(payload)
+
+    if backend == "local":
+        return await forward_to_local_agent(payload)
+
+    if backend == "openai":
+        return await forward_to_openai(payload)
+
+    cache_key = make_cache_key(payload)
+    if not payload.stream:
+        cached = await redis_client.get(cache_key)
+        if cached:
+            return json.loads(cached)
+
+    entry = MODEL_REGISTRY.get(payload.model)
+
+    if entry is not None:
+        if entry.type == "local":
+            data = await forward_to_local_agent(payload)
+            if not payload.stream:
+                await redis_client.setex(cache_key, CACHE_TTL, json.dumps(data))
+            return data
+        if entry.type == "openai":
+            return await forward_to_openai(payload)
+        if entry.type == "llm-d":
+            return await forward_to_llmd(payload)
+        if entry.type == "anthropic":
+            return await anthropic.forward(
+                payload, ANTHROPIC_BASE_URL, EXTERNAL_ANTHROPIC_KEY
+            )
+        if entry.type == "google":
+            return await google.forward(payload, GOOGLE_BASE_URL, EXTERNAL_GOOGLE_KEY)
+        if entry.type == "openrouter":
+            return await openrouter.forward(
+                payload, OPENROUTER_BASE_URL, EXTERNAL_OPENROUTER_KEY
+            )
+        if entry.type == "grok":
+            return await grok.forward(payload, GROK_BASE_URL, EXTERNAL_GROK_KEY)
+        if entry.type == "venice":
+            return await venice.forward(payload, VENICE_BASE_URL, EXTERNAL_VENICE_KEY)
+
+        data = await forward_to_openai(payload)
+        if not payload.stream:
+            await redis_client.setex(cache_key, CACHE_TTL, json.dumps(data))
+        return data
+
+    if payload.model.startswith("local"):
+        data = await forward_to_local_agent(payload)
+        if not payload.stream:
+            await redis_client.setex(cache_key, CACHE_TTL, json.dumps(data))
+        return data
+
+    if payload.model.startswith("gpt-"):
+        data = await forward_to_openai(payload)
+        if not payload.stream:
+            await redis_client.setex(cache_key, CACHE_TTL, json.dumps(data))
+        return data
+
+    if payload.model.startswith("llmd-"):
+        return await forward_to_llmd(payload)
+
+    dummy_text = "Hello world"
+    response = {
+        "id": f"cmpl-{uuid.uuid4().hex}",
+        "object": "chat.completion",
+        "created": int(time.time()),
+        "model": payload.model,
+        "choices": [
+            {
+                "index": 0,
+                "message": {"role": "assistant", "content": dummy_text},
+                "finish_reason": "stop",
+            }
+        ],
+        "usage": {
+            "prompt_tokens": 0,
+            "completion_tokens": 0,
+            "total_tokens": 0,
+        },
+    }
+    if not payload.stream:
+        await redis_client.setex(cache_key, CACHE_TTL, json.dumps(response))
+    return response
+
diff --git a/tests/router/test_provider_venice.py b/tests/router/test_provider_venice.py
new file mode 100644
index 0000000..3863a4a
--- /dev/null
+++ b/tests/router/test_provider_venice.py
@@ -0,0 +1,60 @@
+import httpx
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+
+import router.main as router_main
+import router.registry as registry
+from sqlalchemy import create_engine
+
+venice_app = FastAPI()
+
+
+@venice_app.post("/v1/chat/completions")
+async def _completions(payload: router_main.ChatCompletionRequest):
+    user_msg = payload.messages[-1].content if payload.messages else ""
+    content = f"Venice: {user_msg}"
+    return {
+        "id": "ven-1",
+        "object": "chat.completion",
+        "created": 0,
+        "model": payload.model,
+        "choices": [
+            {
+                "index": 0,
+                "message": {"role": "assistant", "content": content},
+                "finish_reason": "stop",
+            }
+        ],
+        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+    }
+
+
+def test_forward_to_venice(monkeypatch, tmp_path) -> None:
+    monkeypatch.setattr(router_main, "VENICE_BASE_URL", "http://testserver")
+    monkeypatch.setattr(router_main, "EXTERNAL_VENICE_KEY", "dummy")
+
+    db_path = tmp_path / "models.db"
+    monkeypatch.setattr(router_main, "SQLITE_DB_PATH", str(db_path))
+    registry.SQLITE_DB_PATH = str(db_path)
+    registry.engine = create_engine(f"sqlite:///{db_path}")
+    registry.SessionLocal = registry.sessionmaker(bind=registry.engine)
+    registry.create_tables()
+    with registry.get_session() as session:
+        registry.upsert_model(session, "venus-1", "venice", "unused")
+
+    real_async_client = httpx.AsyncClient
+    transport = httpx.ASGITransport(app=venice_app)
+
+    def client_factory(*args, **kwargs):
+        return real_async_client(transport=transport, base_url="http://testserver")
+
+    monkeypatch.setattr(router_main.httpx, "AsyncClient", client_factory)
+
+    client = TestClient(router_main.app)
+    payload = {
+        "model": "venus-1",
+        "messages": [{"role": "user", "content": "hi"}],
+    }
+    response = client.post("/v1/chat/completions", json=payload)
+    assert response.status_code == 200
+    assert response.json()["choices"][0]["message"]["content"] == "Venice: hi"