From 5044207b44b10025ff5799b1b13cbd1520fbc8d9 Mon Sep 17 00:00:00 2001 From: Eddie Dunn <45917906+eddiedunn@users.noreply.github.com> Date: Tue, 10 Jun 2025 10:51:15 -0400 Subject: [PATCH] feat: add provider registry --- router/main.py | 7 +++---- router/providers/__init__.py | 28 +++++++++++++++++++++++++++- router/providers/anthropic.py | 5 +++++ router/providers/google.py | 5 +++++ router/providers/grok.py | 5 +++++ router/providers/huggingface.py | 5 +++++ router/providers/openai.py | 5 +++++ router/providers/openrouter.py | 5 +++++ router/providers/venice.py | 5 +++++ tests/router/test_utils.py | 9 ++++++++- 10 files changed, 73 insertions(+), 6 deletions(-) diff --git a/router/main.py b/router/main.py index a8f59dc..e99e1a5 100644 --- a/router/main.py +++ b/router/main.py @@ -160,13 +160,12 @@ def get_weight_provider(name: str) -> WeightProvider: provider = WEIGHT_PROVIDERS.get(name) if provider is None: - try: - module = getattr(providers, name) - except AttributeError as exc: # coverage: ignore -- defensive + module = providers.PROVIDER_REGISTRY.get(name) + if module is None: raise HTTPException( status_code=500, detail=f"Unsupported provider '{name}'", - ) from exc + ) class_name = "".join(part.capitalize() for part in name.split("_")) + "Provider" provider_cls = getattr(module, class_name, None) if provider_cls is None: diff --git a/router/providers/__init__.py b/router/providers/__init__.py index 5aec1b1..47cf63e 100644 --- a/router/providers/__init__.py +++ b/router/providers/__init__.py @@ -1,11 +1,37 @@ """Collection of provider implementations for external APIs.""" -from . import anthropic, google, openrouter, grok, venice, openai, huggingface +# ruff: noqa: E402 + +from __future__ import annotations + +from types import ModuleType + from .base import ApiProvider, WeightProvider +PROVIDER_REGISTRY: dict[str, ModuleType] = {} + + +def register_provider(name: str, module: ModuleType) -> None: + """Register ``module`` under ``name``.""" + + PROVIDER_REGISTRY[name] = module + + +from . import ( + anthropic, + google, + openrouter, + grok, + venice, + openai, + huggingface, +) # noqa: E402 + __all__ = [ "ApiProvider", "WeightProvider", + "PROVIDER_REGISTRY", + "register_provider", "anthropic", "google", "openrouter", diff --git a/router/providers/anthropic.py b/router/providers/anthropic.py index 374d812..aca068f 100644 --- a/router/providers/anthropic.py +++ b/router/providers/anthropic.py @@ -10,6 +10,8 @@ from typing import AsyncIterator from .base import ApiProvider +from . import register_provider +import sys class AnthropicProvider(ApiProvider): @@ -57,3 +59,6 @@ async def forward(payload: ChatCompletionRequest, base_url: str, api_key: str | """Backward compatible wrapper for ``AnthropicProvider``.""" provider = AnthropicProvider() return await provider.forward(payload, base_url, api_key) + + +register_provider("anthropic", sys.modules[__name__]) diff --git a/router/providers/google.py b/router/providers/google.py index 6a5ab21..bfb3d48 100644 --- a/router/providers/google.py +++ b/router/providers/google.py @@ -10,6 +10,8 @@ from typing import AsyncIterator from .base import ApiProvider +from . import register_provider +import sys class GoogleProvider(ApiProvider): @@ -61,3 +63,6 @@ async def forward(payload: ChatCompletionRequest, base_url: str, api_key: str | """Backward compatible wrapper for ``GoogleProvider``.""" provider = GoogleProvider() return await provider.forward(payload, base_url, api_key) + + +register_provider("google", sys.modules[__name__]) diff --git a/router/providers/grok.py b/router/providers/grok.py index ee8a485..fb0cb91 100644 --- a/router/providers/grok.py +++ b/router/providers/grok.py @@ -10,6 +10,8 @@ from typing import AsyncIterator from .base import ApiProvider +from . import register_provider +import sys class GrokProvider(ApiProvider): @@ -57,3 +59,6 @@ async def forward(payload: ChatCompletionRequest, base_url: str, api_key: str | """Backward compatible wrapper for ``GrokProvider``.""" provider = GrokProvider() return await provider.forward(payload, base_url, api_key) + + +register_provider("grok", sys.modules[__name__]) diff --git a/router/providers/huggingface.py b/router/providers/huggingface.py index 540e488..a7ce0d5 100644 --- a/router/providers/huggingface.py +++ b/router/providers/huggingface.py @@ -13,6 +13,8 @@ from ..schemas import ChatCompletionRequest from .base import WeightProvider +from . import register_provider +import sys class HuggingFaceProvider(WeightProvider): @@ -82,3 +84,6 @@ async def forward(payload: ChatCompletionRequest, base_url: str) -> dict: provider = HuggingFaceProvider() return await provider.forward(payload, base_url) + + +register_provider("huggingface", sys.modules[__name__]) diff --git a/router/providers/openai.py b/router/providers/openai.py index ddb721d..3dd03ef 100644 --- a/router/providers/openai.py +++ b/router/providers/openai.py @@ -9,6 +9,8 @@ from ..schemas import ChatCompletionRequest from ..utils import stream_resp from .base import ApiProvider +from . import register_provider +import sys class OpenAIProvider(ApiProvider): @@ -52,3 +54,6 @@ async def forward(payload: ChatCompletionRequest, base_url: str, api_key: str | """Backward compatible wrapper for ``OpenAIProvider``.""" provider = OpenAIProvider() return await provider.forward(payload, base_url, api_key) + + +register_provider("openai", sys.modules[__name__]) diff --git a/router/providers/openrouter.py b/router/providers/openrouter.py index efa3a8f..e53df2d 100644 --- a/router/providers/openrouter.py +++ b/router/providers/openrouter.py @@ -10,6 +10,8 @@ from typing import AsyncIterator from .base import ApiProvider +from . import register_provider +import sys class OpenRouterProvider(ApiProvider): @@ -57,3 +59,6 @@ async def forward(payload: ChatCompletionRequest, base_url: str, api_key: str | """Backward compatible wrapper for ``OpenRouterProvider``.""" provider = OpenRouterProvider() return await provider.forward(payload, base_url, api_key) + + +register_provider("openrouter", sys.modules[__name__]) diff --git a/router/providers/venice.py b/router/providers/venice.py index 88bc7df..798b3ac 100644 --- a/router/providers/venice.py +++ b/router/providers/venice.py @@ -10,6 +10,8 @@ from typing import AsyncIterator from .base import ApiProvider +from . import register_provider +import sys class VeniceProvider(ApiProvider): @@ -57,3 +59,6 @@ async def forward(payload: ChatCompletionRequest, base_url: str, api_key: str | """Backward compatible wrapper for ``VeniceProvider``.""" provider = VeniceProvider() return await provider.forward(payload, base_url, api_key) + + +register_provider("venice", sys.modules[__name__]) diff --git a/tests/router/test_utils.py b/tests/router/test_utils.py index 1a42b25..ae55ae7 100644 --- a/tests/router/test_utils.py +++ b/tests/router/test_utils.py @@ -52,13 +52,20 @@ async def forward(self, payload, base_url): return {} dummy_module = types.SimpleNamespace(DummyProvider=Dummy) - monkeypatch.setattr(router_main.providers, "dummy", dummy_module, raising=False) + monkeypatch.setitem(router_main.providers.PROVIDER_REGISTRY, "dummy", dummy_module) router_main.WEIGHT_PROVIDERS.clear() provider1 = router_main.get_weight_provider("dummy") provider2 = router_main.get_weight_provider("dummy") assert isinstance(provider1, Dummy) assert provider1 is provider2 + assert router_main.providers.PROVIDER_REGISTRY["dummy"] is dummy_module with pytest.raises(router_main.HTTPException): router_main.get_weight_provider("missing") + + +def test_provider_registry_contains_defaults(): + reg = router_main.providers.PROVIDER_REGISTRY + assert "openai" in reg + assert reg["openai"] is router_main.providers.openai