diff --git a/tests/trace_server/test_custom_provider.py b/tests/trace_server/test_custom_provider.py
index 68c57804c207..d1379697bbb8 100644
--- a/tests/trace_server/test_custom_provider.py
+++ b/tests/trace_server/test_custom_provider.py
@@ -477,6 +477,85 @@ def test_custom_provider_ollama_model(client):
             _secret_fetcher_context.reset(token)
 
 
+def test_custom_provider_trailing_slash_normalization(client):
+    """Check that a trailing slash on a custom provider's base_url is removed.
+
+    A provider is registered with ``http://localhost:11434/`` and a completion
+    is requested through it. The test passes only if ``litellm.completion`` is
+    called exactly once with ``api_base`` equal to ``http://localhost:11434``.
+
+    Why it matters: servers often answer a trailing-slash URL with a 301/302
+    redirect to the canonical URL, and most HTTP clients turn the redirected
+    POST into a GET, which breaks the API call.
+    """
+    if client_is_sqlite(client):
+        # no need to test in sqlite
+        return
+
+    # Fresh provider id per run; model name is custom::{provider_id}::{model_id}
+    provider_id = f"test-trailing-slash-{uuid.uuid4()}"
+    model_id = "test-model"
+    model_name = f"custom::{provider_id}::{model_id}"
+
+    # Provider whose base_url deliberately ends with "/"
+    provider_obj = create_provider_obj(
+        project_id=client._project_id(),
+        provider_id=provider_id,
+        base_url="http://localhost:11434/",
+        api_key_name="TEST_API_KEY",
+        extra_headers={},
+    )
+    provider_model_obj = create_provider_model_obj(
+        project_id=client._project_id(),
+        provider_id=provider_id,
+        model_id=model_id,
+        model_name=model_id,
+    )
+
+    # Side effect for the patched obj_read, built from the two objects above
+    fake_obj_read = create_mock_obj_read(provider_obj, provider_model_obj)
+
+    # Payload for the ModelResponse the patched litellm.completion returns
+    canned_response = create_mock_completion_response(
+        model_name=model_id,
+        content="Hello!",
+    )
+
+    # Completion inputs that target the custom provider model
+    inputs = {
+        "model": model_name,
+        "messages": [{"role": "user", "content": "Hello, world!"}],
+    }
+
+    with with_tracing_disabled():
+        _fetcher, token = setup_test_environment()
+        try:
+            with patch(
+                "weave.trace_server.clickhouse_trace_server_batched.ClickHouseTraceServer.obj_read"
+            ) as obj_read_patch:
+                obj_read_patch.side_effect = fake_obj_read
+                with patch("litellm.completion") as completion_patch:
+                    completion_patch.return_value = ModelResponse.model_validate(
+                        canned_response
+                    )
+                    request = tsi.CompletionsCreateReq.model_validate(
+                        {
+                            "project_id": client._project_id(),
+                            "inputs": inputs,
+                        }
+                    )
+                    client.server.completions_create(request)
+
+                    # Exactly one call, and api_base must carry no trailing slash
+                    completion_patch.assert_called_once()
+                    api_base = completion_patch.call_args.kwargs["api_base"]
+                    assert api_base == "http://localhost:11434", (
+                        f"Expected trailing slash to be stripped. Got '{api_base}'"
+                    )
+        finally:
+            _secret_fetcher_context.reset(token)
+
+
 def test_get_custom_provider_info():
     """Test the get_custom_provider_info function directly."""
     # Set up test data
diff --git a/weave/trace_server/llm_completion.py b/weave/trace_server/llm_completion.py
index 90929373b2ed..d923b72dbbd8 100644
--- a/weave/trace_server/llm_completion.py
+++ b/weave/trace_server/llm_completion.py
@@ -231,6 +231,11 @@ def lite_llm_completion(
     extra_headers: dict[str, str] | None = None,
     return_type: str | None = None,
 ) -> tsi.CompletionsCreateRes:
+    # Normalize base_url to prevent issues with trailing slashes causing redirects
+    # that change POST to GET (HTTP 301/302 redirect behavior)
+    if base_url:
+        base_url = base_url.rstrip("/")
+
     # Setup provider-specific credentials and model modifications
     (
         aws_access_key_id,
@@ -534,6 +539,11 @@ def lite_llm_completion_stream(
     follows the non-streaming version: any exception is surfaced to the caller
     as a single error chunk and the iterator terminates.
     """
+    # Normalize base_url to prevent issues with trailing slashes causing redirects
+    # that change POST to GET (HTTP 301/302 redirect behavior)
+    if base_url:
+        base_url = base_url.rstrip("/")
+
     # Setup provider-specific credentials and model modifications
     (
         aws_access_key_id,