Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions apisix/plugins/ai-drivers/openai-base.lua
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,26 @@ local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
local HTTP_GATEWAY_TIMEOUT = ngx.HTTP_GATEWAY_TIMEOUT


function _M.new(opt)
return setmetatable(opt, mt)
local function merge_request_query_params(ctx, query_params)
    -- Fold the client's original query-string arguments into query_params
    -- (mutated in place). No-op when the request carried no query string.
    local args = ctx.var.args
    if ctx.var.is_args ~= "?" or not args or #args == 0 then
        return
    end

    -- decode_args may fail on malformed input; only merge when we really
    -- got a table back
    local decoded = core.string.decode_args(args)
    if type(decoded) == "table" then
        core.table.merge(query_params, decoded)
    end
end


function _M.new(opts)
    -- Driver constructor.
    --
    -- Keeps the original transparent pass-through behaviour: every field
    -- the caller supplies on opts (host, port, path, request_filter,
    -- response_filter, and any future field) stays visible on the returned
    -- instance, instead of copying an explicit allow-list of fields, which
    -- would silently drop anything not listed here.
    --
    -- Only remove_model needs hoisting out of the nested options table so
    -- callers of the instance can read self.remove_model directly.
    if opts.options then
        opts.remove_model = opts.options.remove_model
    end
    return setmetatable(opts, mt)
end


Expand Down Expand Up @@ -295,7 +313,30 @@ function _M.request(self, ctx, conf, request_table, extra_opts)
end
end

local path = (parsed_url and parsed_url.path or self.path)
local path_mode = extra_opts.path_mode or "fixed"
local endpoint_path = parsed_url and parsed_url.path
local req_path = ctx.var.uri
local path

if path_mode == "preserve" then
path = req_path
merge_request_query_params(ctx, query_params)
elseif path_mode == "append" then
local prefix = endpoint_path or ""
if prefix == "" or prefix == "/" then
path = req_path
else
path = prefix .. req_path
path = path:gsub("//+", "/")
end
merge_request_query_params(ctx, query_params)
else
if endpoint_path and endpoint_path ~= "" then
path = endpoint_path
else
path = self.path
end
end

local headers = auth.header or {}
headers["Content-Type"] = "application/json"
Expand Down
1 change: 1 addition & 0 deletions apisix/plugins/ai-proxy/base.lua
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ function _M.before_proxy(conf, ctx, on_error)
local extra_opts = {
name = ai_instance.name,
endpoint = core.table.try_read_attr(ai_instance, "override", "endpoint"),
path_mode = core.table.try_read_attr(ai_instance, "override", "path_mode"),
model_options = ai_instance.options,
conf = ai_instance.provider_conf or {},
auth = ai_instance.auth,
Expand Down
12 changes: 12 additions & 0 deletions apisix/plugins/ai-proxy/schema.lua
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ local model_options_schema = {
additionalProperties = true,
}

-- JSON-schema fragment for the per-instance `override.path_mode` field,
-- which controls how the upstream request path is derived.
local path_mode_schema = {
    type = "string",
    enum = { "fixed", "preserve", "append" },
    default = "fixed",
    description = "How to determine the upstream request path: "
        .. "fixed (default) uses endpoint path or driver default, "
        .. "preserve uses the original request URI path, "
        .. "append appends the original request URI path to the endpoint path",
}

local provider_vertex_ai_schema = {
type = "object",
properties = {
Expand Down Expand Up @@ -122,6 +132,7 @@ local ai_instance_schema = {
type = "string",
description = "To be specified to override the endpoint of the AI Instance",
},
path_mode = path_mode_schema,
},
},
checks = {
Expand Down Expand Up @@ -198,6 +209,7 @@ _M.ai_proxy_schema = {
type = "string",
description = "To be specified to override the endpoint of the AI Instance",
},
path_mode = path_mode_schema,
},
},
},
Expand Down
155 changes: 155 additions & 0 deletions t/plugin/ai-proxy-multi.openai-compatible.t
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,31 @@ _EOC_
ngx.say("path override works")
}
}

location /proxy/v1/chat/completions {
content_by_lua_block {
-- Mock upstream for the path_mode=append test: the driver is expected
-- to request endpoint path "/proxy" .. route uri "/v1/chat/completions".
local json = require("cjson.safe")
ngx.req.read_body()
local body, err = ngx.req.get_body_data()
body, err = json.decode(body)

-- reject requests that do not carry the bearer token configured on the
-- instance auth header
local header_auth = ngx.req.get_headers()["authorization"]
if header_auth ~= "Bearer token" then
ngx.status = 401
ngx.say("Unauthorized")
return
end

-- a chat-completion request must contain at least one message
if not body.messages or #body.messages < 1 then
ngx.status = 400
ngx.say([[{ "error": "bad request"}]])
return
end

-- echo a marker payload so the test can verify which path was hit
ngx.status = 200
ngx.say([[{"path_mode": "append works", "path": "/proxy/v1/chat/completions"}]])
}
}
}
_EOC_

Expand Down Expand Up @@ -296,3 +321,133 @@ passed
}
--- response_body_like eval
qr/6data: \[DONE\]\n\n/



=== TEST 5: path_mode=preserve - route uri preserved as upstream path
--- config
location /t {
content_by_lua_block {
-- create a route whose single instance overrides the endpoint and sets
-- path_mode=preserve, so the upstream path should be the route uri
local t = require("lib.test_admin").test
local code, body = t('/apisix/admin/routes/1',
ngx.HTTP_PUT,
[[{
"uri": "/v1/chat/completions",
"plugins": {
"ai-proxy-multi": {
"instances": [
{
"name": "self-hosted",
"provider": "openai-compatible",
"weight": 1,
"auth": {
"header": {
"Authorization": "Bearer token"
}
},
"options": {
"model": "custom"
},
"override": {
"endpoint": "http://localhost:6724",
"path_mode": "preserve"
}
}
],
"ssl_verify": false
}
}
}]]
)

if code >= 300 then
ngx.status = code
ngx.say(body)
return
end

-- hit the route; the mock upstream serves /v1/chat/completions, i.e.
-- exactly the preserved route uri
local code, body, actual_body = t("/v1/chat/completions",
ngx.HTTP_POST,
[[{
"messages": [
{ "role": "system", "content": "You are a mathematician" },
{ "role": "user", "content": "What is 1+1?" }
]
}]],
nil,
{
["Content-Type"] = "application/json",
}
)

ngx.status = code
ngx.say(actual_body)
}
}
--- response_body eval
qr/\{ "content": "1 \+ 1 = 2\.", "role": "assistant" \}/



=== TEST 6: path_mode=append - endpoint path + request uri
--- config
location /t {
content_by_lua_block {
-- create a route whose instance endpoint has a path prefix (/proxy) and
-- path_mode=append, so the upstream path should be prefix .. route uri
local t = require("lib.test_admin").test
local code, body = t('/apisix/admin/routes/1',
ngx.HTTP_PUT,
[[{
"uri": "/v1/chat/completions",
"plugins": {
"ai-proxy-multi": {
"instances": [
{
"name": "self-hosted",
"provider": "openai-compatible",
"weight": 1,
"auth": {
"header": {
"Authorization": "Bearer token"
}
},
"options": {
"model": "custom"
},
"override": {
"endpoint": "http://localhost:6724/proxy",
"path_mode": "append"
}
}
],
"ssl_verify": false
}
}
}]]
)

if code >= 300 then
ngx.status = code
ngx.say(body)
return
end

-- hit the route; the mock upstream serves /proxy/v1/chat/completions and
-- echoes the path it was reached on, asserted below
local code, body, actual_body = t("/v1/chat/completions",
ngx.HTTP_POST,
[[{
"messages": [
{ "role": "system", "content": "You are a mathematician" },
{ "role": "user", "content": "What is 1+1?" }
]
}]],
nil,
{
["Content-Type"] = "application/json",
}
)

ngx.status = code
ngx.say(actual_body)
}
}
--- response_body
{"path_mode": "append works", "path": "/proxy/v1/chat/completions"}
Loading
Loading