|
14 | 14 | from pathlib import Path |
15 | 15 |
|
16 | 16 |
|
17 | | - |
18 | 17 | import httpx |
19 | 18 | from fastapi import FastAPI, HTTPException, Request |
20 | 19 | from fastapi.responses import Response, StreamingResponse, JSONResponse |
@@ -191,6 +190,20 @@ async def _startup() -> None: |
191 | 190 | logger.addHandler(stream_handler) |
192 | 191 |
|
193 | 192 |
|
| 193 | + |
class Message(BaseModel):
    """A single chat turn in an OpenAI-style ``messages`` array."""

    # Speaker of the turn. Nothing in the visible code restricts the value —
    # presumably "system" / "user" / "assistant" per OpenAI convention; confirm.
    role: str
    # Plain-text body of the turn.
    content: str
| 197 | + |
| 198 | + |
class ChatCompletionRequest(BaseModel):
    """Request body for an OpenAI-compatible chat-completions call.

    Also doubles as the payload forwarded verbatim (via ``.dict()``) to the
    local agent and external providers elsewhere in this file.
    """

    # Model identifier; routing in this file keys off this value
    # (MODEL_REGISTRY lookup, then "local"/"gpt-"/"llmd-" prefix heuristics).
    model: str
    # Ordered conversation history.
    messages: List[Message]
    # Optional generation parameters; None means "let the backend decide".
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    # When True the caller wants a streamed response; streamed requests skip
    # the Redis cache (see the `if not payload.stream` guards below).
    stream: Optional[bool] = False
| 205 | + |
| 206 | + |
194 | 207 | class AgentRegistration(BaseModel): |
195 | 208 | name: str |
196 | 209 | endpoint: str |
def make_cache_key(payload: ChatCompletionRequest) -> str:
    """Derive the Redis cache key for a chat-completion request.

    The payload is serialized with sorted keys so that two logically
    identical requests always hash to the same key, then digested with
    SHA-256 and namespaced under ``chat:``.
    """
    canonical = json.dumps(payload.dict(), sort_keys=True)
    checksum = hashlib.sha256(canonical.encode()).hexdigest()
    return f"chat:{checksum}"
258 | 269 |
|
259 | | - |
260 | | - |
261 | 270 | async def forward_to_local_agent(payload: ChatCompletionRequest) -> dict: |
262 | 271 | async with httpx.AsyncClient(base_url=LOCAL_AGENT_URL) as client: |
263 | 272 | resp = await client.post("/infer", json=payload.dict()) |
@@ -364,7 +373,6 @@ async def forward_to_venice(payload: ChatCompletionRequest): |
364 | 373 |
|
365 | 374 | return await venice.forward(payload, VENICE_BASE_URL, EXTERNAL_VENICE_KEY) |
366 | 375 |
|
367 | | - |
368 | 376 | @app.post("/register") |
369 | 377 | async def register_agent(payload: AgentRegistration) -> dict: |
370 | 378 | """Register a local agent and update the model registry.""" |
@@ -454,3 +462,92 @@ async def metrics() -> Response: |
454 | 462 | """Expose Prometheus metrics.""" |
455 | 463 |
|
456 | 464 | return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST) |
| 465 | + |
| 466 | + |
# NOTE(review): this added chunk appears to be the TAIL of a chat-completions
# route handler — it references a `payload` (shaped like ChatCompletionRequest)
# and uses bare `return`s — but the enclosing `def` / route decorator is not
# visible in this hunk. As rendered, the code is appended at module level after
# the /metrics endpoint, where `return` is a syntax error. Confirm the
# function header before merging.

# Explicit backend override: dispatch immediately.
# NOTE(review): unlike the registry/prefix paths below, these two early
# returns never consult or populate the Redis cache — confirm intentional.
backend = select_backend(payload)

if backend == "local":
    return await forward_to_local_agent(payload)

if backend == "openai":
    return await forward_to_openai(payload)

# Cache lookup. Streaming responses cannot be replayed from a cached JSON
# string, so only non-stream requests read (or later write) Redis.
cache_key = make_cache_key(payload)
if not payload.stream:
    cached = await redis_client.get(cache_key)
    if cached:
        return json.loads(cached)

# Preferred routing source: the model registry (populated via /register).
entry = MODEL_REGISTRY.get(payload.model)

if entry is not None:
    if entry.type == "local":
        data = await forward_to_local_agent(payload)
        if not payload.stream:
            await redis_client.setex(cache_key, CACHE_TTL, json.dumps(data))
        return data
    if entry.type == "openai":

        return await forward_to_openai(payload)

    if entry.type == "llm-d":
        return await forward_to_llmd(payload)

    if entry.type == "anthropic":
        return await anthropic.forward(
            payload, ANTHROPIC_BASE_URL, EXTERNAL_ANTHROPIC_KEY
        )
    if entry.type == "google":
        return await google.forward(payload, GOOGLE_BASE_URL, EXTERNAL_GOOGLE_KEY)
    if entry.type == "openrouter":
        return await openrouter.forward(
            payload, OPENROUTER_BASE_URL, EXTERNAL_OPENROUTER_KEY
        )
    if entry.type == "grok":
        return await grok.forward(payload, GROK_BASE_URL, EXTERNAL_GROK_KEY)
    if entry.type == "venice":
        return await venice.forward(payload, VENICE_BASE_URL, EXTERNAL_VENICE_KEY)

    # Registered entry with an unrecognized type: fall through to OpenAI.
    # NOTE(review): only the "local" branch and this fallback write the cache;
    # every other provider branch above returns uncached — confirm intentional.
    data = await forward_to_openai(payload)
    if not payload.stream:
        await redis_client.setex(cache_key, CACHE_TTL, json.dumps(data))
    return data

# Unregistered model: route heuristically by model-name prefix.
if payload.model.startswith("local"):
    data = await forward_to_local_agent(payload)
    if not payload.stream:
        await redis_client.setex(cache_key, CACHE_TTL, json.dumps(data))
    return data

if payload.model.startswith("gpt-"):
    data = await forward_to_openai(payload)
    if not payload.stream:
        await redis_client.setex(cache_key, CACHE_TTL, json.dumps(data))
    return data

if payload.model.startswith("llmd-"):
    # NOTE(review): llm-d responses are never cached, unlike the two
    # prefix branches above — confirm intentional.
    return await forward_to_llmd(payload)

# Last resort: no registry entry, no recognized prefix. Return a canned
# placeholder completion in the OpenAI chat-completion response shape
# (zeroed usage, single "stop" choice).
dummy_text = "Hello world"
response = {
    "id": f"cmpl-{uuid.uuid4().hex}",
    "object": "chat.completion",
    "created": int(time.time()),
    "model": payload.model,
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": dummy_text},
            "finish_reason": "stop",
        }
    ],
    "usage": {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
    },
}
if not payload.stream:
    # Even the placeholder is cached, so repeat unknown-model requests
    # within CACHE_TTL get the identical id/timestamp back.
    await redis_client.setex(cache_key, CACHE_TTL, json.dumps(response))
return response
| 553 | + |
0 commit comments