Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions router/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,12 @@ def get_weight_provider(name: str) -> WeightProvider:

provider = WEIGHT_PROVIDERS.get(name)
if provider is None:
try:
module = getattr(providers, name)
except AttributeError as exc: # coverage: ignore -- defensive
module = providers.PROVIDER_REGISTRY.get(name)
if module is None:
raise HTTPException(
status_code=500,
detail=f"Unsupported provider '{name}'",
) from exc
)
class_name = "".join(part.capitalize() for part in name.split("_")) + "Provider"
provider_cls = getattr(module, class_name, None)
if provider_cls is None:
Expand Down
28 changes: 27 additions & 1 deletion router/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,37 @@
"""Collection of provider implementations for external APIs."""

from . import anthropic, google, openrouter, grok, venice, openai, huggingface
# ruff: noqa: E402

from __future__ import annotations

from types import ModuleType

from .base import ApiProvider, WeightProvider

PROVIDER_REGISTRY: dict[str, ModuleType] = {}


def register_provider(name: str, module: ModuleType) -> None:
"""Register ``module`` under ``name``."""

PROVIDER_REGISTRY[name] = module


from . import (
anthropic,
google,
openrouter,
grok,
venice,
openai,
huggingface,
) # noqa: E402

__all__ = [
"ApiProvider",
"WeightProvider",
"PROVIDER_REGISTRY",
"register_provider",
"anthropic",
"google",
"openrouter",
Expand Down
5 changes: 5 additions & 0 deletions router/providers/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from typing import AsyncIterator

from .base import ApiProvider
from . import register_provider
import sys


class AnthropicProvider(ApiProvider):
Expand Down Expand Up @@ -57,3 +59,6 @@ async def forward(payload: ChatCompletionRequest, base_url: str, api_key: str |
"""Backward compatible wrapper for ``AnthropicProvider``."""
provider = AnthropicProvider()
return await provider.forward(payload, base_url, api_key)


register_provider("anthropic", sys.modules[__name__])
5 changes: 5 additions & 0 deletions router/providers/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from typing import AsyncIterator

from .base import ApiProvider
from . import register_provider
import sys


class GoogleProvider(ApiProvider):
Expand Down Expand Up @@ -61,3 +63,6 @@ async def forward(payload: ChatCompletionRequest, base_url: str, api_key: str |
"""Backward compatible wrapper for ``GoogleProvider``."""
provider = GoogleProvider()
return await provider.forward(payload, base_url, api_key)


register_provider("google", sys.modules[__name__])
5 changes: 5 additions & 0 deletions router/providers/grok.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from typing import AsyncIterator

from .base import ApiProvider
from . import register_provider
import sys


class GrokProvider(ApiProvider):
Expand Down Expand Up @@ -57,3 +59,6 @@ async def forward(payload: ChatCompletionRequest, base_url: str, api_key: str |
"""Backward compatible wrapper for ``GrokProvider``."""
provider = GrokProvider()
return await provider.forward(payload, base_url, api_key)


register_provider("grok", sys.modules[__name__])
5 changes: 5 additions & 0 deletions router/providers/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

from ..schemas import ChatCompletionRequest
from .base import WeightProvider
from . import register_provider
import sys


class HuggingFaceProvider(WeightProvider):
Expand Down Expand Up @@ -82,3 +84,6 @@ async def forward(payload: ChatCompletionRequest, base_url: str) -> dict:

provider = HuggingFaceProvider()
return await provider.forward(payload, base_url)


register_provider("huggingface", sys.modules[__name__])
5 changes: 5 additions & 0 deletions router/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from ..schemas import ChatCompletionRequest
from ..utils import stream_resp
from .base import ApiProvider
from . import register_provider
import sys


class OpenAIProvider(ApiProvider):
Expand Down Expand Up @@ -52,3 +54,6 @@ async def forward(payload: ChatCompletionRequest, base_url: str, api_key: str |
"""Backward compatible wrapper for ``OpenAIProvider``."""
provider = OpenAIProvider()
return await provider.forward(payload, base_url, api_key)


register_provider("openai", sys.modules[__name__])
5 changes: 5 additions & 0 deletions router/providers/openrouter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from typing import AsyncIterator

from .base import ApiProvider
from . import register_provider
import sys


class OpenRouterProvider(ApiProvider):
Expand Down Expand Up @@ -57,3 +59,6 @@ async def forward(payload: ChatCompletionRequest, base_url: str, api_key: str |
"""Backward compatible wrapper for ``OpenRouterProvider``."""
provider = OpenRouterProvider()
return await provider.forward(payload, base_url, api_key)


register_provider("openrouter", sys.modules[__name__])
5 changes: 5 additions & 0 deletions router/providers/venice.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from typing import AsyncIterator

from .base import ApiProvider
from . import register_provider
import sys


class VeniceProvider(ApiProvider):
Expand Down Expand Up @@ -57,3 +59,6 @@ async def forward(payload: ChatCompletionRequest, base_url: str, api_key: str |
"""Backward compatible wrapper for ``VeniceProvider``."""
provider = VeniceProvider()
return await provider.forward(payload, base_url, api_key)


register_provider("venice", sys.modules[__name__])
9 changes: 8 additions & 1 deletion tests/router/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,20 @@ async def forward(self, payload, base_url):
return {}

dummy_module = types.SimpleNamespace(DummyProvider=Dummy)
monkeypatch.setattr(router_main.providers, "dummy", dummy_module, raising=False)
monkeypatch.setitem(router_main.providers.PROVIDER_REGISTRY, "dummy", dummy_module)

router_main.WEIGHT_PROVIDERS.clear()
provider1 = router_main.get_weight_provider("dummy")
provider2 = router_main.get_weight_provider("dummy")
assert isinstance(provider1, Dummy)
assert provider1 is provider2
assert router_main.providers.PROVIDER_REGISTRY["dummy"] is dummy_module

with pytest.raises(router_main.HTTPException):
router_main.get_weight_provider("missing")


def test_provider_registry_contains_defaults():
reg = router_main.providers.PROVIDER_REGISTRY
assert "openai" in reg
assert reg["openai"] is router_main.providers.openai
Loading