Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions tests/trace_server/test_custom_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,87 @@ def test_custom_provider_ollama_model(client):
_secret_fetcher_context.reset(token)


def test_custom_provider_trailing_slash_normalization(client):
"""Test that trailing slashes in base_url are stripped to prevent redirect issues.

When a base_url has a trailing slash, HTTP servers often redirect to the
canonical URL without the slash using 301/302. These redirects cause most
HTTP clients to change POST to GET, breaking the API call.

This test verifies that trailing slashes are stripped before making the request.
"""
is_sqlite = client_is_sqlite(client)
if is_sqlite:
# no need to test in sqlite
return

# Create provider ID and model ID for testing
provider_id = f"test-trailing-slash-{uuid.uuid4()}"
model_id = "test-model"
model_name = f"custom::{provider_id}::{model_id}"

# Create a Provider object with a trailing slash in base_url
provider_obj = create_provider_obj(
project_id=client._project_id(),
provider_id=provider_id,
base_url="http://localhost:11434/", # Note the trailing slash
api_key_name="TEST_API_KEY",
extra_headers={},
)

# Create a ProviderModel object
provider_model_obj = create_provider_model_obj(
project_id=client._project_id(),
provider_id=provider_id,
model_id=model_id,
model_name=model_id,
)

# Mock responses for obj_read calls
mock_obj_read = create_mock_obj_read(provider_obj, provider_model_obj)

# Create test input
inputs = {
"model": model_name,
"messages": [{"role": "user", "content": "Hello, world!"}],
}

# Mock response from LiteLLM
mock_response = create_mock_completion_response(
model_name=model_id,
content="Hello!",
)

with with_tracing_disabled():
mock_secret_fetcher, token = setup_test_environment()
try:
with patch(
"weave.trace_server.clickhouse_trace_server_batched.ClickHouseTraceServer.obj_read"
) as mock_read:
mock_read.side_effect = mock_obj_read
with patch("litellm.completion") as mock_completion:
mock_completion.return_value = ModelResponse.model_validate(
mock_response
)
client.server.completions_create(
tsi.CompletionsCreateReq.model_validate(
{
"project_id": client._project_id(),
"inputs": inputs,
}
)
)

# Verify the trailing slash was stripped from api_base
mock_completion.assert_called_once()
call_args = mock_completion.call_args[1]
assert call_args["api_base"] == "http://localhost:11434", (
f"Expected trailing slash to be stripped. Got '{call_args['api_base']}'"
)
finally:
_secret_fetcher_context.reset(token)


def test_get_custom_provider_info():
"""Test the get_custom_provider_info function directly."""
# Set up test data
Expand Down
10 changes: 10 additions & 0 deletions weave/trace_server/llm_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,11 @@ def lite_llm_completion(
extra_headers: dict[str, str] | None = None,
return_type: str | None = None,
) -> tsi.CompletionsCreateRes:
# Normalize base_url to prevent issues with trailing slashes causing redirects
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fact that this fixes the issue makes me wonder if we have brittle url-creation code elsewhere. Is that other code something that we can change? If so, a more durable and holistic fix would be to make the other code more resilient to trailing slashes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mscavezze-cw Yes unfortunately I believe this issue is on the litellm side of the url construction, we just send them the base url inputted by the user here.

# that change POST to GET (HTTP 301/302 redirect behavior)
if base_url:
base_url = base_url.rstrip("/")

# Setup provider-specific credentials and model modifications
(
aws_access_key_id,
Expand Down Expand Up @@ -534,6 +539,11 @@ def lite_llm_completion_stream(
follows the non-streaming version: any exception is surfaced to the caller
as a single error chunk and the iterator terminates.
"""
# Normalize base_url to prevent issues with trailing slashes causing redirects
# that change POST to GET (HTTP 301/302 redirect behavior)
if base_url:
base_url = base_url.rstrip("/")

# Setup provider-specific credentials and model modifications
(
aws_access_key_id,
Expand Down