From e53dc80b36bf817422e7d3c3d79f620df396004f Mon Sep 17 00:00:00 2001
From: Eddie Dunn <45917906+eddiedunn@users.noreply.github.com>
Date: Wed, 28 May 2025 14:20:36 -0400
Subject: [PATCH] feat: add Venice provider support

---
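Note for reviewers (this section sits after the "---" and is not part of
the commit message): the diff wires up configuration, docs, and a
regression test; the forwarding path itself is reached through the module
attributes the test monkeypatches (VENICE_BASE_URL, EXTERNAL_VENICE_KEY,
httpx.AsyncClient). For context, here is a minimal sketch of what such a
provider helper could look like. The name forward_to_venice and the
error-handling details are assumptions modeled on the existing
forward_to_llmd helper, not the shipped implementation:

    import os

    import httpx
    from fastapi import HTTPException

    VENICE_BASE_URL = os.getenv("VENICE_BASE_URL", "https://api.venice.ai")
    EXTERNAL_VENICE_KEY = os.getenv("EXTERNAL_VENICE_KEY", "")

    async def forward_to_venice(payload) -> dict:
        """Sketch: forward an OpenAI-style chat completion to Venice."""
        headers = {"Authorization": f"Bearer {EXTERNAL_VENICE_KEY}"}
        async with httpx.AsyncClient(base_url=VENICE_BASE_URL) as client:
            try:
                # payload is a ChatCompletionRequest pydantic model
                resp = await client.post(
                    "/v1/chat/completions",
                    json=payload.dict(),
                    headers=headers,
                )
                resp.raise_for_status()
            except httpx.HTTPError as exc:
                # mirror the 502 convention used for llm-d failures
                raise HTTPException(status_code=502, detail="Venice error") from exc
        return resp.json()

The new test exercises this path end to end by rebinding httpx.AsyncClient
to an in-process ASGI transport; it can be run with
pytest tests/router/test_provider_venice.py.
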
"ok"} - @app.post("/v1/chat/completions") async def chat_completions(payload: ChatCompletionRequest): @@ -373,7 +365,6 @@ async def metrics() -> Response: return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST) - backend = select_backend(payload) if backend == "local": @@ -423,7 +414,6 @@ async def metrics() -> Response: await redis_client.setex(cache_key, CACHE_TTL, json.dumps(data)) return data - if payload.model.startswith("local"): data = await forward_to_local_agent(payload) if not payload.stream: @@ -436,7 +426,6 @@ async def metrics() -> Response: await redis_client.setex(cache_key, CACHE_TTL, json.dumps(data)) return data - if payload.model.startswith("llmd-"): return await forward_to_llmd(payload) @@ -462,4 +451,3 @@ async def metrics() -> Response: if not payload.stream: await redis_client.setex(cache_key, CACHE_TTL, json.dumps(response)) return response - diff --git a/tests/router/test_provider_venice.py b/tests/router/test_provider_venice.py new file mode 100644 index 0000000..3863a4a --- /dev/null +++ b/tests/router/test_provider_venice.py @@ -0,0 +1,60 @@ +import httpx +from fastapi import FastAPI +from fastapi.testclient import TestClient + +import router.main as router_main +import router.registry as registry +from sqlalchemy import create_engine + +venice_app = FastAPI() + + +@venice_app.post("/v1/chat/completions") +async def _completions(payload: router_main.ChatCompletionRequest): + user_msg = payload.messages[-1].content if payload.messages else "" + content = f"Venice: {user_msg}" + return { + "id": "ven-1", + "object": "chat.completion", + "created": 0, + "model": payload.model, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": content}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + } + + +def test_forward_to_venice(monkeypatch, tmp_path) -> None: + monkeypatch.setattr(router_main, "VENICE_BASE_URL", "http://testserver") + monkeypatch.setattr(router_main, "EXTERNAL_VENICE_KEY", "dummy") + + db_path = tmp_path / "models.db" + monkeypatch.setattr(router_main, "SQLITE_DB_PATH", str(db_path)) + registry.SQLITE_DB_PATH = str(db_path) + registry.engine = create_engine(f"sqlite:///{db_path}") + registry.SessionLocal = registry.sessionmaker(bind=registry.engine) + registry.create_tables() + with registry.get_session() as session: + registry.upsert_model(session, "venus-1", "venice", "unused") + + real_async_client = httpx.AsyncClient + transport = httpx.ASGITransport(app=venice_app) + + def client_factory(*args, **kwargs): + return real_async_client(transport=transport, base_url="http://testserver") + + monkeypatch.setattr(router_main.httpx, "AsyncClient", client_factory) + + client = TestClient(router_main.app) + payload = { + "model": "venus-1", + "messages": [{"role": "user", "content": "hi"}], + } + response = client.post("/v1/chat/completions", json=payload) + assert response.status_code == 200 + assert response.json()["choices"][0]["message"]["content"] == "Venice: hi"