6 changes: 5 additions & 1 deletion xllm_service/common/call_data.h
@@ -70,10 +70,12 @@ class StreamCallData : public CallData {
brpc::Controller* controller,
bool stream,
::google::protobuf::Closure* done,
Request* request,
Response* response,
std::function<void(const std::string&)> trace_callback = nullptr)
: controller_(controller),
done_(done),
request_(request),
response_(response),
trace_callback_(std::move(trace_callback)) {
stream_ = stream;
@@ -200,6 +202,7 @@ class StreamCallData : public CallData {
return true;
}

Request& request() { return *request_; }
Response& response() { return *response_; }
::google::protobuf::Closure* done() { return done_; }
bool finished() { return finished_; }
@@ -208,7 +211,8 @@
brpc::Controller* controller_;
::google::protobuf::Closure* done_;

Response* response_;
Request* request_ = nullptr;
Response* response_ = nullptr;

bool stream_ = false;
butil::intrusive_ptr<brpc::ProgressiveAttachment> pa_;
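For orientation, a minimal usage sketch of the widened constructor. `CompletionCallData` is assumed to be a `StreamCallData` instantiation over the completion protos, and the arena-allocated `req_pb`/`resp_pb` follow the pattern in `service.cpp` below:

```cpp
// Sketch: construct call data that now stores both protos.
auto call_data = std::make_shared<CompletionCallData>(
    cntl,                  // brpc::Controller of the inbound HTTP request
    /*stream=*/true,
    done_guard.release(),  // closure to run once the call completes
    req_pb,                // new: stored request, exposed via request()
    resp_pb);              // stored response, exposed via response()

// Callers can now hand the stored request proto straight to a stub:
// stub.Completions(redirect_cntl, &call_data->request(), nullptr, done);
```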
3 changes: 3 additions & 0 deletions xllm_service/common/xllm/output.h
@@ -107,6 +107,9 @@ struct RequestOutput {

// whether the request is finished.
bool finished = false;

// whether the prefill stage has finished on the prefill instance.
bool finished_on_prefill_instance = false;
};

inline std::optional<std::string> to_string(FinishReason reason) {
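A hedged sketch of how a consumer might act on the new flag. `on_request_output` is a hypothetical hook, though the `Scheduler` methods all appear elsewhere in this PR:

```cpp
// Illustrative only: route metric updates based on the new flag.
void on_request_output(const RequestOutput& output,
                       Scheduler* scheduler,
                       const std::string& service_request_id) {
  if (output.finished_on_prefill_instance) {
    // Prefill is done; remaining tokens stream in from the decode instance.
    scheduler->update_token_latency_metrics_for_prefill(service_request_id);
    scheduler->update_request_metrics_for_prefill(service_request_id);
  }
  if (output.finished) {
    scheduler->finish_request(service_request_id, /*error=*/false);
  }
}
```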
210 changes: 101 additions & 109 deletions xllm_service/http_service/service.cpp
@@ -20,6 +20,7 @@ limitations under the License.
#include <brpc/controller.h>
#include <brpc/progressive_reader.h>
#include <glog/logging.h>
#include <google/protobuf/util/json_util.h>
#include <json2pb/json_to_pb.h>
#include <json2pb/pb_to_json.h>

@@ -30,9 +31,11 @@ limitations under the License.
#include "common/call_data.h"
#include "common/closure_guard.h"
#include "common/utils.h"
#include "common/xllm/status.h"
#include "common/xllm/uuid.h"
#include "completion.pb.h"
#include "scheduler/scheduler.h"
#include "xllm_service.pb.h"

namespace xllm_service {

@@ -88,46 +91,20 @@ void handle_non_stream_response(brpc::Controller* cntl,
call_data->write_and_finish(cntl->response_attachment().to_string());
}

// fire and forget
template <typename T>
void handle_first_response(brpc::Controller* cntl,
std::shared_ptr<T> call_data,
Scheduler* scheduler,
std::string service_request_id,
bool stream) {
void handle_first_send_request(brpc::Controller* cntl,
std::shared_ptr<T> call_data,
Scheduler* scheduler,
std::string service_request_id,
bool stream) {
std::unique_ptr<brpc::Controller> cntl_guard(cntl);
if (cntl->Failed()) {
LOG(ERROR) << "Fail to send stream generation, " << cntl->ErrorText();
call_data->finish_with_error(cntl->ErrorText());
scheduler->finish_request(service_request_id, /*error*/ true);
return;
}

if (stream) {
// write first token from prefill
std::string response = cntl->response_attachment().to_string();
// check response for stream request to handle error in prefill instance
// Currently, 1.response with "data:" prefix means no error and return the
// first token 2.empty response means the first token can not directly
// generate characters
if (!response.empty()) {
if (response.find("data:") != 0) {
LOG(ERROR) << "Fail in the prefill instance, " << response;
call_data->finish_with_error(response);
scheduler->finish_request(service_request_id, /*error*/ true);
return;
}
call_data->write(response);
}
}
// non-stream: all generated tokens will be sent from decode via the rpc
// service, and any error in the prefill instance will be handled through
// cntl->SetFailed().

// update token latency metrics
scheduler->update_token_latency_metrics_for_prefill(service_request_id);

// update request metrics for prefill finished request
scheduler->update_request_metrics_for_prefill(service_request_id);
}

template <typename T>
@@ -163,12 +140,34 @@ class CustomProgressiveReader : public brpc::ProgressiveReader {
};
} // namespace
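One point worth being explicit about in this fire-and-forget pattern: closures returned by `brpc::NewCallback` delete themselves after `Run()`, so the callback only has to reclaim the heap-allocated controller. A distilled sketch of the idiom behind `handle_first_send_request` (the proto type name is an assumed placeholder):

```cpp
// Minimal sketch of the async brpc idiom used above.
static void on_prefill_done(brpc::Controller* cntl) {
  // Take ownership: the NewCallback closure frees itself after Run(),
  // but the heap-allocated controller is ours to delete.
  std::unique_ptr<brpc::Controller> cntl_guard(cntl);
  if (cntl->Failed()) {
    LOG(ERROR) << "prefill RPC failed: " << cntl->ErrorText();
    return;
  }
  // Success: tokens arrive separately from the decode instance.
}

void send_to_prefill(xllm::proto::XllmAPIService_Stub& stub,
                     const xllm::proto::CompletionRequest& req) {
  auto* cntl = new brpc::Controller();
  // A non-null done makes the call asynchronous: CallMethod returns
  // immediately and on_prefill_done runs when the reply (or error) lands.
  stub.Completions(cntl, &req, /*response=*/nullptr,
                   brpc::NewCallback(&on_prefill_done, cntl));
}
```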

namespace {

constexpr char kInferContentLength[] = "Infer-Content-Length";
constexpr char kContentLength[] = "Content-Length";

size_t GetJsonContentLength(const brpc::Controller* ctrl) {
const auto infer_content_len =
ctrl->http_request().GetHeader(kInferContentLength);
if (infer_content_len != nullptr) {
return std::stoul(*infer_content_len);
}

const auto content_len = ctrl->http_request().GetHeader(kContentLength);
if (content_len != nullptr) {
return std::stoul(*content_len);
}

LOG(FATAL) << "Content-Length header is missing.";
return (size_t)-1L;
}

} // namespace
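A short usage sketch, mirroring how `ChatCompletions` below slices the JSON body out of the attachment. Note that `std::stoul` throws on a non-numeric header value, so this assumes well-formed headers:

```cpp
// Sketch: read only the JSON body, whose length is reported by
// Infer-Content-Length (preferred) or Content-Length.
const size_t content_len = GetJsonContentLength(cntl);
std::string body;
cntl->request_attachment().copy_to(&body, content_len, /*pos=*/0);
// body now holds exactly content_len bytes of JSON.
```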

template <typename T>
void XllmHttpServiceImpl::handle(std::shared_ptr<T> call_data,
const std::string& req_attachment,
std::shared_ptr<Request> request,
const std::string& method) {
std::shared_ptr<Request> request) {
// record request
auto& req_pb = call_data->request();
bool success = scheduler_->record_new_request(call_data, request);
if (!success) {
LOG(ERROR) << "rpc service add new request error: "
@@ -181,32 +180,29 @@ void XllmHttpServiceImpl::handle(std::shared_ptr<T> call_data,
// TODO: optimize the thread pool to async mode.
auto& target_uri = request->routing.prefill_name;
brpc::Channel* channel_ptr = scheduler_->get_channel(target_uri).get();
// use stub
xllm::proto::XllmAPIService_Stub stub(channel_ptr);
// xllm::proto::Status* resp_pb = new xllm::proto::Status();
brpc::Controller* redirect_cntl = new brpc::Controller();
google::protobuf::Closure* done =
brpc::NewCallback(&handle_first_send_request<T>,
redirect_cntl,
call_data,
scheduler_,
request->service_request_id,
request->stream);

if constexpr (std::is_same_v<T, CompletionCallData>) {
stub.Completions(redirect_cntl, &req_pb, nullptr, done);
} else if constexpr (std::is_same_v<T, ChatCallData>) {
stub.ChatCompletions(redirect_cntl, &req_pb, nullptr, done);
} else {
delete redirect_cntl;
delete done;
LOG(ERROR) << "Unknown call_data type";
}

// send request to prefill instance.
thread_pool_->schedule([this,
request,
req_attachment = std::move(req_attachment),
call_data,
channel_ptr,
target_uri = target_uri + method]() {
brpc::Controller* redirect_cntl = new brpc::Controller();
redirect_cntl->http_request().uri() = target_uri.c_str();
redirect_cntl->http_request().set_method(brpc::HTTP_METHOD_POST);

// redirect the input request content
redirect_cntl->request_attachment().append(req_attachment);

// tokens will be received via rpc channel.
google::protobuf::Closure* done =
brpc::NewCallback(&handle_first_response<T>,
redirect_cntl,
call_data,
scheduler_,
request->service_request_id,
request->stream);
channel_ptr->CallMethod(NULL, redirect_cntl, NULL, NULL, done);
return;
});
}
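The compile-time dispatch above hinges on two details: `if constexpr` instantiates exactly one branch per CallData type, and the fallback branch must free the controller and closure by hand, since a `NewCallback` closure only frees itself once it runs. A distilled sketch:

```cpp
// Distilled from handle(): dispatch to the matching stub method at
// compile time; clean up manually if no RPC is ever issued.
template <typename T>
void dispatch_to_prefill(xllm::proto::XllmAPIService_Stub& stub,
                         brpc::Controller* cntl,
                         T& call_data,
                         google::protobuf::Closure* done) {
  if constexpr (std::is_same_v<T, CompletionCallData>) {
    stub.Completions(cntl, &call_data.request(), nullptr, done);
  } else if constexpr (std::is_same_v<T, ChatCallData>) {
    stub.ChatCompletions(cntl, &call_data.request(), nullptr, done);
  } else {
    // No call was started, so done will never Run() and self-delete.
    delete cntl;
    delete done;
    LOG(ERROR) << "Unknown call_data type";
  }
}
```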

template <typename T>
@@ -240,22 +236,31 @@ std::shared_ptr<Request> XllmHttpServiceImpl::generate_request(
}

namespace {
void handle_get_response(brpc::Controller* cntl,
std::shared_ptr<CompletionCallData> call_data,
google::protobuf::Closure* done) {
void handle_get_model_response(brpc::Controller* cntl,
std::shared_ptr<CompletionCallData> call_data,
google::protobuf::Closure* done,
xllm::proto::ModelListResponse* resp_pb) {
std::unique_ptr<brpc::Controller> cntl_guard(cntl);
std::unique_ptr<google::protobuf::Closure> done_guard(done);
std::unique_ptr<xllm::proto::ModelListResponse> resp_pb_guard(resp_pb);

if (cntl->Failed()) {
LOG(ERROR) << "Fail to send stream generation, " << cntl->ErrorText();
call_data->finish_with_error(cntl->ErrorText());
return;
}
call_data->write_and_finish(cntl->response_attachment().to_string());
std::string err_msg;
std::string json_output;
if (!json2pb::ProtoMessageToJson(*resp_pb, &json_output, &err_msg)) {
call_data->finish_with_error(err_msg);
LOG(ERROR) << "ProtoMessageToJson failed: " << err_msg;
return;
}
LOG(INFO) << "ProtoMessageToJson: " << json_output;
call_data->write_and_finish(json_output);
}
} // namespace

void XllmHttpServiceImpl::get_serving(
const std::string& serving_method,
void XllmHttpServiceImpl::get_serving_models(
::google::protobuf::RpcController* controller,
const proto::HttpRequest* request,
proto::HttpResponse* response,
@@ -269,11 +274,15 @@
cntl->SetFailed("brpc request | respose | controller is null");
return;
}
auto arena = response->GetArena();
auto req_pb =
google::protobuf::Arena::CreateMessage<::xllm::proto::ModelListRequest>(
arena);

// auto call_data = std::make_shared<StreamCallData>(cntl, false,
// done_guard.release());
auto call_data = std::make_shared<CompletionCallData>(
cntl, false, done_guard.release(), nullptr);
cntl, false, done_guard.release(), nullptr, nullptr);

auto service_request = std::make_shared<Request>();
if (!scheduler_->schedule(service_request)) {
@@ -284,22 +293,13 @@

brpc::Channel* channel_ptr =
scheduler_->get_channel(service_request->routing.prefill_name).get();
std::string target_uri =
service_request->routing.prefill_name + serving_method;

thread_pool_->schedule(
[/*req_attachment, */ call_data, cntl, channel_ptr, target_uri]() {
brpc::Controller* redirect_cntl = new brpc::Controller();
redirect_cntl->http_request().uri() = target_uri.c_str();
redirect_cntl->http_request().set_method(brpc::HTTP_METHOD_GET);

google::protobuf::Closure* done = brpc::NewCallback(
&handle_get_response, redirect_cntl, call_data, done);

// Because `done'(last parameter) is NULL, this function waits until
// the response comes back or error occurs(including timeout).
channel_ptr->CallMethod(NULL, redirect_cntl, NULL, NULL, done);
});

xllm::proto::XllmAPIService_Stub stub(channel_ptr);
brpc::Controller* redirect_cntl = new brpc::Controller();
auto* resp_pb = new xllm::proto::ModelListResponse();
google::protobuf::Closure* done_callback = brpc::NewCallback(
&handle_get_model_response, redirect_cntl, call_data, done, resp_pb);
stub.Models(redirect_cntl, req_pb, resp_pb, done_callback);
}

void XllmHttpServiceImpl::Completions(
@@ -359,16 +359,9 @@ void XllmHttpServiceImpl::Completions(
req_pb->mutable_routing()->set_decode_name(
service_request->routing.decode_name);

std::string req_attachment;
if (!json2pb::ProtoMessageToJson(*req_pb, &req_attachment)) {
cntl->SetFailed("proto to json failed");
LOG(ERROR) << "proto to json failed";
return;
}

auto call_data = std::make_shared<CompletionCallData>(
cntl, service_request->stream, done_guard.release(), resp_pb);
handle(call_data, req_attachment, service_request, "/v1/completions");
cntl, service_request->stream, done_guard.release(), req_pb, resp_pb);
handle(call_data, service_request);
}

void XllmHttpServiceImpl::ChatCompletions(
@@ -393,12 +386,17 @@
google::protobuf::Arena::CreateMessage<::xllm::proto::ChatResponse>(
arena);

std::string attachment = std::move(cntl->request_attachment().to_string());
std::string error;
auto st = json2pb::JsonToProtoMessage(attachment, req_pb, &error);
if (!st) {
cntl->SetFailed(error);
LOG(ERROR) << "parse json to proto failed: " << error;
auto content_len = GetJsonContentLength(cntl);
std::string attachment;
cntl->request_attachment().copy_to(&attachment, content_len, 0);

google::protobuf::util::JsonParseOptions options;
options.ignore_unknown_fields = true;
auto status =
google::protobuf::util::JsonStringToMessage(attachment, req_pb, options);
if (!status.ok()) {
cntl->SetFailed(status.ToString());
LOG(ERROR) << "parse json to proto failed: " << status.ToString();
return;
}

@@ -430,16 +428,9 @@
req_pb->mutable_routing()->set_decode_name(
service_request->routing.decode_name);

std::string req_attachment;
if (!json2pb::ProtoMessageToJson(*req_pb, &req_attachment)) {
cntl->SetFailed("proto to json failed");
LOG(ERROR) << "proto to json failed";
return;
}

auto call_data = std::make_shared<ChatCallData>(
cntl, service_request->stream, done_guard.release(), resp_pb);
handle(call_data, req_attachment, service_request, "/v1/chat/completions");
cntl, service_request->stream, done_guard.release(), req_pb, resp_pb);
handle(call_data, service_request);
}

void XllmHttpServiceImpl::Embeddings(
@@ -465,14 +456,15 @@ void XllmHttpServiceImpl::Models(::google::protobuf::RpcController* controller,
const proto::HttpRequest* request,
proto::HttpResponse* response,
::google::protobuf::Closure* done) {
get_serving("/v1/models", controller, request, response, done);
get_serving_models(controller, request, response, done);
}

void XllmHttpServiceImpl::Metrics(::google::protobuf::RpcController* controller,
const proto::HttpRequest* request,
proto::HttpResponse* response,
::google::protobuf::Closure* done) {
get_serving("/metrics", controller, request, response, done);
ClosureGuard done_guard(done);
// TODO: implement metrics endpoint
}

} // namespace xllm_service
12 changes: 7 additions & 5 deletions xllm_service/http_service/service.h
@@ -82,11 +82,13 @@ class XllmHttpServiceImpl : public proto::XllmHttpService {
std::shared_ptr<Request> request,
const std::string& method);

void get_serving(const std::string& serving_method,
::google::protobuf::RpcController* controller,
const proto::HttpRequest* request,
proto::HttpResponse* response,
::google::protobuf::Closure* done);
template <typename T>
void handle(std::shared_ptr<T> call_data, std::shared_ptr<Request> request);

void get_serving_models(::google::protobuf::RpcController* controller,
const proto::HttpRequest* request,
proto::HttpResponse* response,
::google::protobuf::Closure* done);

private:
Options options_;