From 0b05cc958729bc76f4513b97a4bf8616b38dc270 Mon Sep 17 00:00:00 2001 From: hejianjun <942156265@qq.com> Date: Sun, 4 Feb 2024 23:13:13 +0800 Subject: [PATCH 01/16] =?UTF-8?q?=E4=BF=AE=E6=94=B9openai=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=EF=BC=8C=E8=AE=A9openai=E5=8F=AF=E4=BB=A5=E4=BD=BF?= =?UTF-8?q?=E7=94=A8function=E7=9A=84=E6=96=B9=E5=BC=8F=E8=8E=B7=E5=8F=96?= =?UTF-8?q?=E8=A1=A8=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../components/ChatInput/index.tsx | 7 +- .../components/SelectBoundInfo/index.tsx | 2 +- .../web/api/controller/ai/ChatController.java | 105 ++++++++++- .../chat2db/client/Chat2DBAIStreamClient.java | 6 - .../ai/openai/client/OpenAIClient.java | 13 +- .../listener/OpenAIEventSourceListener.java | 166 +++++++++++++++--- .../ai/chat2db/spi/sql/Chat2DBContext.java | 1 + chat2db-server/pom.xml | 2 +- 8 files changed, 255 insertions(+), 47 deletions(-) diff --git a/chat2db-client/src/components/ConsoleEditor/components/ChatInput/index.tsx b/chat2db-client/src/components/ConsoleEditor/components/ChatInput/index.tsx index 592755979..997b8bd3d 100644 --- a/chat2db-client/src/components/ConsoleEditor/components/ChatInput/index.tsx +++ b/chat2db-client/src/components/ConsoleEditor/components/ChatInput/index.tsx @@ -42,18 +42,17 @@ const ChatInput = (props: IProps) => { }; const renderSelectTable = () => { - const { tables, onSelectTableSyncModel, selectedTables, onSelectTables } = props; + const { tables, onSelectTableSyncModel, selectedTables, onSelectTables,syncTableModel } = props; const options = (tables || []).map((t) => ({ value: t, label: t })); return (
onSelectTableSyncModel(v.target.value)} - // value={syncTableModel} - value={SyncModelType.MANUAL} + value={syncTableModel} style={{ marginBottom: '8px' }} > - {/* 自动 */} + 自动 手动 diff --git a/chat2db-client/src/components/ConsoleEditor/components/SelectBoundInfo/index.tsx b/chat2db-client/src/components/ConsoleEditor/components/SelectBoundInfo/index.tsx index e4318a2a9..4be3b4b20 100644 --- a/chat2db-client/src/components/ConsoleEditor/components/SelectBoundInfo/index.tsx +++ b/chat2db-client/src/components/ConsoleEditor/components/SelectBoundInfo/index.tsx @@ -186,7 +186,7 @@ const SelectBoundInfo = memo((props: IProps) => { boundInfo.databaseName, boundInfo.schemaName, ); - setSelectedTables(tableNameListTemp.slice(0, 1)); + //setSelectedTables(tableNameListTemp.slice(0, 1)); } }, [allTableList, isActive]); diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java index c9e77806f..aab53bdaf 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java @@ -54,11 +54,20 @@ import ai.chat2db.server.web.api.http.response.EsTableSchemaResponse; import ai.chat2db.server.web.api.http.response.TableSchemaResponse; import ai.chat2db.server.web.api.util.ApplicationContextUtil; +import ai.chat2db.spi.MetaData; +import ai.chat2db.spi.model.Table; +import ai.chat2db.spi.sql.Chat2DBContext; +import ai.chat2db.spi.sql.ConnectInfo; import cn.hutool.core.util.StrUtil; import cn.hutool.json.JSONUtil; import com.alibaba.fastjson2.JSON; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; +import com.unfbx.chatgpt.entity.chat.ChatCompletion; import com.unfbx.chatgpt.entity.chat.Message; +import com.unfbx.chatgpt.entity.chat.Parameters; +import com.unfbx.chatgpt.entity.chat.tool.Tools; +import com.unfbx.chatgpt.entity.chat.tool.ToolsFunction; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections4.CollectionUtils; @@ -171,7 +180,7 @@ public SseEmitter customChat(@RequestBody ChatRequest queryRequest) throws IOExc /** * 自定义模型非流式输出接口DEMO *

- * Note:使用自己本地的飞流式输出自定义AI,接口输入和输出需与该样例保持一致 + * Note:使用自己本地的飞流式输出自定义AI,接口输入和输出需与该样例保持一致 *

* * @param queryRequest @@ -276,11 +285,11 @@ private SseEmitter chatWithRestAi(ChatQueryRequest prompt, SseEmitter sseEmitter * @throws IOException */ private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) - throws IOException { - String prompt = buildPrompt(queryRequest); + throws IOException { + String prompt = buildPrompt2(queryRequest); if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) { log.error("提示语超出最大长度:{},输入长度:{}, 请重新输入", MAX_PROMPT_LENGTH, - prompt.length() / TOKEN_CONVERT_CHAR_LENGTH); + prompt.length() / TOKEN_CONVERT_CHAR_LENGTH); throw new ParamBusinessException(); } @@ -290,9 +299,28 @@ private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseE Message currentMessage = Message.builder().content(prompt).role(Message.Role.USER).build(); messages.add(currentMessage); buildSseEmitter(sseEmitter, uid); - - OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter); - OpenAIClient.getInstance().streamChatCompletion(messages, openAIEventSourceListener); + ConnectInfo connectInfo = Chat2DBContext.getConnectInfo(); + OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter, messages, connectInfo, queryRequest); + ToolsFunction function = ToolsFunction.builder() + .name("get_table_columns") + .description("获取指定表的字段名,类型") + .parameters(Parameters.builder() + .type("object") + .properties(ImmutableMap.builder() + .put("table_name", ImmutableMap.builder() + .put("type", "string") + .put("description", "表名,例如```User```") + .build()) + .build()) + .required(List.of("table_name")) + .build()) + .build(); + ChatCompletion chatCompletion = ChatCompletion.builder() + .model("gpt-3.5-turbo-1106") + .tools(List.of(new Tools(Tools.Type.FUNCTION.getName(), function))) + .toolChoice("auto") + .messages(messages).stream(true).build(); + OpenAIClient.getInstance().streamChatCompletion(chatCompletion, openAIEventSourceListener); LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT); return sseEmitter; } @@ -630,6 +658,47 @@ private String buildPrompt(ChatQueryRequest queryRequest) { return cleanedInput; } + /** + * 构建prompt + * + * @param queryRequest + * @return + */ + private String buildPrompt2(ChatQueryRequest queryRequest) { + if (PromptType.TEXT_GENERATION.getCode().equals(queryRequest.getPromptType())) { + return queryRequest.getMessage(); + } + + // 查询schema信息 + String dataSourceType = queryDatabaseType(queryRequest); + String properties = ""; + if (CollectionUtils.isNotEmpty(queryRequest.getTableNames())) { + properties = queryRequest.getTableNames().stream().collect(Collectors.joining(",")); + } else { + properties = queryDatabaseSchema2(queryRequest); + } + String prompt = queryRequest.getMessage(); + String promptType = StringUtils.isBlank(queryRequest.getPromptType()) ? PromptType.NL_2_SQL.getCode() + : queryRequest.getPromptType(); + PromptType pType = EasyEnumUtils.getEnum(PromptType.class, promptType); + String ext = StringUtils.isNotBlank(queryRequest.getExt()) ? queryRequest.getExt() : ""; + String schemaProperty = StringUtils.isNotEmpty(properties) ? String.format( + "### 请根据以下table properties和SQL input%s. %s\n#\n### %s SQL tables:\n#\n# " + + "%s\n#\n#\n### SQL input: %s", pType.getDescription(), ext, dataSourceType, + properties, prompt) : String.format("### 请根据以下SQL input%s. %s\n#\n### SQL input: %s", + pType.getDescription(), ext, prompt); + switch (pType) { + case SQL_2_SQL: + schemaProperty = StringUtils.isNotBlank(queryRequest.getDestSqlType()) ? String.format( + "%s\n#\n### 目标SQL类型: %s", schemaProperty, queryRequest.getDestSqlType()) : String.format( + "%s\n#\n### 目标SQL类型: %s", schemaProperty, dataSourceType); + default: + break; + } + String cleanedInput = schemaProperty.replaceAll("[\r\t]", ""); + return cleanedInput; + } + /** * query chat2db apikey * @@ -727,6 +796,28 @@ public String queryDatabaseSchema(ChatQueryRequest queryRequest) { } } + + /** + * query database schema + * + * @param queryRequest + * @return + * @throws IOException + */ + public String queryDatabaseSchema2(ChatQueryRequest queryRequest) { + MetaData metaSchema = Chat2DBContext.getMetaData(); + try { + List tables = metaSchema.tables(Chat2DBContext.getConnection(), queryRequest.getDatabaseName(), queryRequest.getSchemaName(), null); + return tables.stream() + .map(table -> StringUtils.isBlank(table.getComment()) ? table.getName() + : table.getName() + "(" + table.getComment() + ")") + .collect(Collectors.joining(",")); + } catch (Exception e) { + log.error("query table error:{}, do nothing", e.getMessage()); + return ""; + } + } + /** * query database schema * diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2DBAIStreamClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2DBAIStreamClient.java index 0f0b6d84f..295d39cff 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2DBAIStreamClient.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2DBAIStreamClient.java @@ -1,21 +1,15 @@ package ai.chat2db.server.web.api.controller.ai.chat2db.client; -import ai.chat2db.server.domain.api.enums.AiSqlSourceEnum; -import ai.chat2db.server.domain.api.model.Config; -import ai.chat2db.server.domain.api.service.ConfigService; -import ai.chat2db.server.tools.base.wrapper.result.DataResult; import ai.chat2db.server.tools.common.exception.ParamBusinessException; import ai.chat2db.server.web.api.controller.ai.chat2db.interceptor.Chat2dbHeaderAuthorizationInterceptor; import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatOpenAiApi; import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbedding; import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse; -import ai.chat2db.server.web.api.util.ApplicationContextUtil; import cn.hutool.http.ContentType; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; import com.unfbx.chatgpt.entity.chat.ChatCompletion; import com.unfbx.chatgpt.entity.chat.Message; -import com.unfbx.chatgpt.interceptor.HeaderAuthorizationInterceptor; import lombok.Getter; import lombok.extern.slf4j.Slf4j; import okhttp3.*; diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/client/OpenAIClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/client/OpenAIClient.java index 9ebf711c2..1d3de3bc7 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/client/OpenAIClient.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/client/OpenAIClient.java @@ -4,6 +4,7 @@ import java.net.InetSocketAddress; import java.net.Proxy; import java.util.Objects; +import java.util.concurrent.TimeUnit; import ai.chat2db.server.domain.api.model.Config; import ai.chat2db.server.domain.api.service.ConfigService; @@ -93,7 +94,17 @@ public static void refresh() { log.info("refresh openai apikey:{}", maskApiKey(apikey)); if (Objects.nonNull(host) && Objects.nonNull(port)) { Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(host, port)); - OkHttpClient okHttpClient = new OkHttpClient.Builder().proxy(proxy).build(); + OkHttpClient okHttpClient = new OkHttpClient.Builder() + // 设置连接超时为10秒 + .connectTimeout(10, TimeUnit.SECONDS) + // 设置读取超时为30秒 + .readTimeout(30, TimeUnit.SECONDS) + // 设置写入超时为15秒 + .writeTimeout(15, TimeUnit.SECONDS) + // 设置整个调用的超时为1分钟 + .callTimeout(1, TimeUnit.MINUTES) + .proxy(proxy) + .build(); OPEN_AI_STREAM_CLIENT = OpenAiStreamClient.builder().apiHost(apiHost).apiKey( Lists.newArrayList(apikey)).okHttpClient(okHttpClient).build(); } else { diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java index ccadf6d68..099fd76b5 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java @@ -1,12 +1,18 @@ package ai.chat2db.server.web.api.controller.ai.openai.listener; -import java.util.Objects; - +import ai.chat2db.server.web.api.controller.ai.openai.client.OpenAIClient; +import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest; import ai.chat2db.server.web.api.controller.ai.response.ChatCompletionResponse; - +import ai.chat2db.spi.MetaData; +import ai.chat2db.spi.sql.Chat2DBContext; +import ai.chat2db.spi.sql.ConnectInfo; +import com.alibaba.fastjson2.JSONObject; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; +import com.unfbx.chatgpt.entity.chat.BaseMessage; import com.unfbx.chatgpt.entity.chat.Message; +import com.unfbx.chatgpt.entity.chat.tool.ToolCallFunction; +import com.unfbx.chatgpt.entity.chat.tool.ToolCalls; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import okhttp3.Response; @@ -15,6 +21,10 @@ import okhttp3.sse.EventSourceListener; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + /** * 描述:OpenAIEventSourceListener * @@ -24,10 +34,78 @@ @Slf4j public class OpenAIEventSourceListener extends EventSourceListener { - private SseEmitter sseEmitter; + private final SseEmitter sseEmitter; + + private final List messages; - public OpenAIEventSourceListener(SseEmitter sseEmitter) { + private final ConnectInfo connectInfo; + + private final ChatQueryRequest queryRequest; + + private List toolCalls = new ArrayList<>(); + + + public OpenAIEventSourceListener(SseEmitter sseEmitter, List messages, ConnectInfo connectInfo, ChatQueryRequest queryRequest) { this.sseEmitter = sseEmitter; + this.messages = messages; + this.connectInfo = connectInfo; + this.queryRequest = queryRequest; + } + + public static List mergeToolCallsLists(List list1, List list2) { + List mergedList = new ArrayList<>(list1); + if (list2.isEmpty()) { + return mergedList; + } + ToolCalls item2 = list2.get(0); + boolean isMerged = false; + // 反向遍历 + for (int i = list1.size() - 1; i >= 0; i--) { + ToolCalls item1 = list1.get(i); + if (item2.getId() == null || Objects.equals(item1.getId(), item2.getId())) { + mergedList.set(i, mergeToolCalls(item1, item2)); + isMerged = true; + break; + } + } + if (!isMerged) { + // 如果 list2 中的对象与 list1 中的任何对象都不匹配,则作为新对象添加 + mergedList.add(item2); + } + return mergedList; + } + + private static ToolCalls mergeToolCalls(ToolCalls tc1, ToolCalls tc2) { + if (tc1 == null) return tc2; + if (tc2 == null) return tc1; + + // 相同的逻辑,只是当 id 为 null 时进行合并 + String id = tc1.getId() != null ? tc1.getId() : tc2.getId(); + String type = mergeStrings(tc1.getType(), tc2.getType()); + ToolCallFunction function = mergeToolCallFunctions(tc1.getFunction(), tc2.getFunction()); + + return new ToolCalls(id, type, function); + } + + private static ToolCallFunction mergeToolCallFunctions(ToolCallFunction f1, ToolCallFunction f2) { + if (f1 == null) return f2; + if (f2 == null) return f1; + + String name = mergeStrings(f1.getName(), f2.getName()); + String arguments = mergeStrings(f1.getArguments(), f2.getArguments()); + + return new ToolCallFunction(name, arguments); + } + + private static String mergeStrings(String str1, String str2) { + if (str1 != null && str2 != null) { + // Concatenate both strings + return str1 + str2; + } else if (str1 != null) { + return str1; + } else { + return str2; + } } /** @@ -46,35 +124,69 @@ public void onOpen(EventSource eventSource, Response response) { public void onEvent(EventSource eventSource, String id, String type, String data) { log.info("OpenAI返回数据:{}", data); if (data.equals("[DONE]")) { - log.info("OpenAI返回数据结束了"); - sseEmitter.send(SseEmitter.event() - .id("[DONE]") - .data("[DONE]") - .reconnectTime(3000)); - sseEmitter.complete(); + if (toolCalls.isEmpty()) { + log.info("OpenAI返回数据结束了"); + sseEmitter.send(SseEmitter.event() + .id("[DONE]") + .data("[DONE]") + .reconnectTime(3000)); + sseEmitter.complete(); + return; + } + messages.add(Message.builder() + .toolCalls(toolCalls) + .role(BaseMessage.Role.ASSISTANT).build()); + Chat2DBContext.putContext(connectInfo); + try { + for (ToolCalls toolCall : toolCalls) { + String callId = toolCall.getId(); + ToolCallFunction function = toolCall.getFunction(); + if (function != null && Objects.nonNull(function.getArguments())) { + String functionName = function.getName(); + JSONObject arguments = JSONObject.parse(function.getArguments()); + if ("get_table_columns".equals(functionName)) { + MetaData metaSchema = Chat2DBContext.getMetaData(); + String ddl = metaSchema.tableDDL(Chat2DBContext.getConnection(), queryRequest.getDatabaseName(), queryRequest.getSchemaName(), arguments.getString("table_name")); + messages.add(Message.builder().role(BaseMessage.Role.TOOL) + .toolCallId(callId) + .name(functionName) + .content(ddl) + .build()); + } + } + } + } finally { + Chat2DBContext.removeContext(); + } + OpenAIClient.getInstance().streamChatCompletion(messages, this); + toolCalls.clear(); return; } ObjectMapper mapper = new ObjectMapper(); mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); // 读取Json ChatCompletionResponse completionResponse = mapper.readValue(data, ChatCompletionResponse.class); - String text = completionResponse.getChoices().get(0).getDelta() == null - ? completionResponse.getChoices().get(0).getText() - : completionResponse.getChoices().get(0).getDelta().getContent(); + Message delta = completionResponse.getChoices().get(0).getDelta(); + if (delta != null && delta.getToolCalls() != null) { + this.toolCalls = mergeToolCallsLists(this.toolCalls, delta.getToolCalls()); + } + String text = delta == null + ? completionResponse.getChoices().get(0).getText() + : delta.getContent(); Message message = new Message(); if (text != null) { message.setContent(text); sseEmitter.send(SseEmitter.event() - .id(completionResponse.getId()) - .data(message) - .reconnectTime(3000)); + .id(completionResponse.getId()) + .data(message) + .reconnectTime(3000)); } } @Override public void onClosed(EventSource eventSource) { - sseEmitter.complete(); - log.info("OpenAI关闭sse连接..."); +// sseEmitter.complete(); +// log.info("OpenAI关闭sse连接..."); } @Override @@ -88,11 +200,11 @@ public void onFailure(EventSource eventSource, Throwable t, Response response) { Message sseMessage = new Message(); sseMessage.setContent(message); sseEmitter.send(SseEmitter.event() - .id("[ERROR]") - .data(sseMessage)); + .id("[ERROR]") + .data(sseMessage)); sseEmitter.send(SseEmitter.event() - .id("[DONE]") - .data("[DONE]")); + .id("[DONE]") + .data("[DONE]")); sseEmitter.complete(); return; } @@ -108,11 +220,11 @@ public void onFailure(EventSource eventSource, Throwable t, Response response) { Message message = new Message(); message.setContent("出现异常,请在帮助中查看详细日志:" + bodyString); sseEmitter.send(SseEmitter.event() - .id("[ERROR]") - .data(message)); + .id("[ERROR]") + .data(message)); sseEmitter.send(SseEmitter.event() - .id("[DONE]") - .data("[DONE]")); + .id("[DONE]") + .data("[DONE]")); sseEmitter.complete(); } catch (Exception exception) { log.error("发送数据异常:", exception); diff --git a/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/sql/Chat2DBContext.java b/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/sql/Chat2DBContext.java index 9e6fce81a..88183d9ad 100644 --- a/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/sql/Chat2DBContext.java +++ b/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/sql/Chat2DBContext.java @@ -142,6 +142,7 @@ public static void removeContext() { try { if (connection != null && !connection.isClosed()) { connection.close(); + connectInfo.setConnection(null); } } catch (SQLException e) { log.error("close connection error", e); diff --git a/chat2db-server/pom.xml b/chat2db-server/pom.xml index 16c693477..b5b0cf43a 100644 --- a/chat2db-server/pom.xml +++ b/chat2db-server/pom.xml @@ -222,7 +222,7 @@ com.unfbx chatgpt-java - 1.0.8 + 1.1.5 org.slf4j From 1b84afbd4814e0088b9bcce4e89fbc3e4df0091f Mon Sep 17 00:00:00 2001 From: hejianjun <942156265@qq.com> Date: Mon, 5 Feb 2024 09:05:49 +0800 Subject: [PATCH 02/16] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=AE=B9=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../java/ai/chat2db/plugin/mysql/MysqlMetaData.java | 9 +++++++-- .../ai/openai/listener/OpenAIEventSourceListener.java | 11 +++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/chat2db-server/chat2db-plugins/chat2db-mysql/src/main/java/ai/chat2db/plugin/mysql/MysqlMetaData.java b/chat2db-server/chat2db-plugins/chat2db-mysql/src/main/java/ai/chat2db/plugin/mysql/MysqlMetaData.java index 40a291955..d08cc4a6a 100644 --- a/chat2db-server/chat2db-plugins/chat2db-mysql/src/main/java/ai/chat2db/plugin/mysql/MysqlMetaData.java +++ b/chat2db-server/chat2db-plugins/chat2db-mysql/src/main/java/ai/chat2db/plugin/mysql/MysqlMetaData.java @@ -33,8 +33,13 @@ public List databases(Connection connection) { @Override public String tableDDL(Connection connection, @NotEmpty String databaseName, String schemaName, @NotEmpty String tableName) { - String sql = "SHOW CREATE TABLE " + format(databaseName) + "." - + format(tableName); + String sql; + if(StringUtils.isEmpty(databaseName)) { + sql = "SHOW CREATE TABLE " + format(tableName); + }else{ + sql = "SHOW CREATE TABLE " + format(databaseName) + "." + + format(tableName); + } return SQLExecutor.getInstance().execute(connection, sql, resultSet -> { if (resultSet.next()) { return resultSet.getString("Create Table"); diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java index 099fd76b5..e30ff1c21 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java @@ -19,6 +19,7 @@ import okhttp3.ResponseBody; import okhttp3.sse.EventSource; import okhttp3.sse.EventSourceListener; +import org.apache.commons.lang3.StringUtils; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import java.util.ArrayList; @@ -146,11 +147,17 @@ public void onEvent(EventSource eventSource, String id, String type, String data JSONObject arguments = JSONObject.parse(function.getArguments()); if ("get_table_columns".equals(functionName)) { MetaData metaSchema = Chat2DBContext.getMetaData(); - String ddl = metaSchema.tableDDL(Chat2DBContext.getConnection(), queryRequest.getDatabaseName(), queryRequest.getSchemaName(), arguments.getString("table_name")); + String content; + try { + content = metaSchema.tableDDL(Chat2DBContext.getConnection(), queryRequest.getDatabaseName(), queryRequest.getSchemaName(), arguments.getString("table_name")); + }catch (Exception e){ + log.error("OpenAI查询表结构失败",e); + content = StringUtils.defaultString(e.getMessage(), "OpenAI查询表结构失败"); + } messages.add(Message.builder().role(BaseMessage.Role.TOOL) .toolCallId(callId) .name(functionName) - .content(ddl) + .content(content) .build()); } } From 16f3dc24e6572ae8d5dfa98e2a80b93d970d3a0d Mon Sep 17 00:00:00 2001 From: hejianjun <942156265@qq.com> Date: Wed, 7 Feb 2024 15:20:37 +0800 Subject: [PATCH 03/16] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=91=98=E8=A6=81?= =?UTF-8?q?=E8=8E=B7=E5=8F=96=E6=96=B9=E5=BC=8F=EF=BC=8C=E4=B9=8B=E5=89=8D?= =?UTF-8?q?=E7=94=A8=E5=AE=98=E7=BD=91=E7=9A=84=E5=A4=AA=E6=B5=AA=E8=B4=B9?= =?UTF-8?q?token=E4=BA=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../web/api/controller/ai/ChatController.java | 384 ++---------------- .../controller/ai/EmbeddingController.java | 6 +- .../controller/ai/KnowledgeController.java | 4 +- .../ai/TextGenerationController.java | 4 +- .../listener/OpenAIEventSourceListener.java | 69 ++-- .../controller/ai/utils/PromptService.java | 364 +++++++++++++++++ 6 files changed, 432 insertions(+), 399 deletions(-) create mode 100644 chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java index aab53bdaf..ff93fcbd4 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java @@ -1,5 +1,7 @@ package ai.chat2db.server.web.api.controller.ai; + + import ai.chat2db.server.domain.api.enums.AiSqlSourceEnum; import ai.chat2db.server.domain.api.model.Config; import ai.chat2db.server.domain.api.model.DataSource; @@ -11,6 +13,8 @@ import ai.chat2db.server.tools.base.enums.WhiteListTypeEnum; import ai.chat2db.server.tools.base.wrapper.result.DataResult; import ai.chat2db.server.tools.common.exception.ParamBusinessException; +import ai.chat2db.server.tools.common.model.LoginUser; +import ai.chat2db.server.tools.common.util.ContextUtils; import ai.chat2db.server.tools.common.util.EasyEnumUtils; import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect; import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAIClient; @@ -41,6 +45,7 @@ import ai.chat2db.server.web.api.controller.ai.rest.listener.RestAIEventSourceListener; import ai.chat2db.server.web.api.controller.ai.tongyi.client.TongyiChatAIClient; import ai.chat2db.server.web.api.controller.ai.tongyi.listener.TongyiChatAIEventSourceListener; +import ai.chat2db.server.web.api.controller.ai.utils.PromptService; import ai.chat2db.server.web.api.controller.ai.wenxin.client.WenxinAIClient; import ai.chat2db.server.web.api.controller.ai.wenxin.listener.WenxinAIEventSourceListener; import ai.chat2db.server.web.api.controller.ai.zhipu.client.ZhipuChatAIClient; @@ -68,6 +73,7 @@ import com.unfbx.chatgpt.entity.chat.Parameters; import com.unfbx.chatgpt.entity.chat.tool.Tools; import com.unfbx.chatgpt.entity.chat.tool.ToolsFunction; +import com.unfbx.chatgpt.entity.chat.BaseChatCompletion.Model; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections4.CollectionUtils; @@ -99,14 +105,6 @@ @Slf4j public class ChatController { - @Autowired - private TableService tableService; - - @Autowired - private ChatConverter chatConverter; - - @Autowired - private DataSourceService dataSourceService; @Value("${chatgpt.context.length}") private Integer contextLength; @@ -117,6 +115,10 @@ public class ChatController { @Resource private GatewayClientService gatewayClientService; + + @Resource + protected PromptService promptService; + /** * chat的超时时间 */ @@ -271,7 +273,7 @@ public SseEmitter distributeAISql(ChatQueryRequest queryRequest, SseEmitter sseE */ private SseEmitter chatWithRestAi(ChatQueryRequest prompt, SseEmitter sseEmitter) { RestAIEventSourceListener eventSourceListener = new RestAIEventSourceListener(sseEmitter); - RestAIClient.getInstance().restCompletions(buildPrompt(prompt), eventSourceListener); + RestAIClient.getInstance().restCompletions(promptService.buildPrompt(prompt), eventSourceListener); return sseEmitter; } @@ -286,7 +288,7 @@ private SseEmitter chatWithRestAi(ChatQueryRequest prompt, SseEmitter sseEmitter */ private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { - String prompt = buildPrompt2(queryRequest); + String prompt = promptService.buildAutoPrompt(queryRequest); if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) { log.error("提示语超出最大长度:{},输入长度:{}, 请重新输入", MAX_PROMPT_LENGTH, prompt.length() / TOKEN_CONVERT_CHAR_LENGTH); @@ -299,9 +301,12 @@ private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseE Message currentMessage = Message.builder().content(prompt).role(Message.Role.USER).build(); messages.add(currentMessage); buildSseEmitter(sseEmitter, uid); - ConnectInfo connectInfo = Chat2DBContext.getConnectInfo(); - OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter, messages, connectInfo, queryRequest); - ToolsFunction function = ToolsFunction.builder() + LoginUser loginUser = ContextUtils.getLoginUser(); + OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter, promptService, queryRequest,loginUser); + ChatCompletion chatCompletion = ChatCompletion.builder() + .messages(messages).stream(true).build(); + if(queryRequest.getDatabaseName()!=null){ + ToolsFunction function = ToolsFunction.builder() .name("get_table_columns") .description("获取指定表的字段名,类型") .parameters(Parameters.builder() @@ -315,11 +320,10 @@ private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseE .required(List.of("table_name")) .build()) .build(); - ChatCompletion chatCompletion = ChatCompletion.builder() - .model("gpt-3.5-turbo-1106") - .tools(List.of(new Tools(Tools.Type.FUNCTION.getName(), function))) - .toolChoice("auto") - .messages(messages).stream(true).build(); + chatCompletion.setModel("gpt-3.5-turbo-0125"); + chatCompletion.setTools(List.of(new Tools(Tools.Type.FUNCTION.getName(), function))); + chatCompletion.setToolChoice("auto"); + } OpenAIClient.getInstance().streamChatCompletion(chatCompletion, openAIEventSourceListener); LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT); return sseEmitter; @@ -336,7 +340,7 @@ private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseE */ private SseEmitter chatWithChat2dbAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { - String prompt = buildPrompt(queryRequest); + String prompt = promptService.buildPrompt(queryRequest); if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) { log.error("exceed max token length:{},input length:{}", MAX_PROMPT_LENGTH, prompt.length() / TOKEN_CONVERT_CHAR_LENGTH); @@ -366,7 +370,7 @@ private SseEmitter chatWithChat2dbAi(ChatQueryRequest queryRequest, SseEmitter s * @throws IOException */ private SseEmitter chatWithAzureAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { - String prompt = buildPrompt(queryRequest); + String prompt = promptService.buildPrompt(queryRequest); if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) { log.error("提示语超出最大长度:{},输入长度:{}, 请重新输入", MAX_PROMPT_LENGTH, prompt.length() / TOKEN_CONVERT_CHAR_LENGTH); @@ -401,7 +405,7 @@ private SseEmitter chatWithAzureAi(ChatQueryRequest queryRequest, SseEmitter sse * @throws IOException */ private SseEmitter chatWithFastChatAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { - String prompt = buildPrompt(queryRequest); + String prompt = promptService.buildPrompt(queryRequest); List messages = getFastChatMessage(uid, prompt); buildSseEmitter(sseEmitter, uid); @@ -422,7 +426,7 @@ private SseEmitter chatWithFastChatAi(ChatQueryRequest queryRequest, SseEmitter * @throws IOException */ private SseEmitter chatWithZhipuChatAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { - String prompt = buildPrompt(queryRequest); + String prompt = promptService.buildPrompt(queryRequest); List messages = getFastChatMessage(uid, prompt); buildSseEmitter(sseEmitter, uid); @@ -443,7 +447,7 @@ private SseEmitter chatWithZhipuChatAi(ChatQueryRequest queryRequest, SseEmitter * @throws IOException */ private SseEmitter chatWithTongyiChatAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { - String prompt = buildPrompt(queryRequest); + String prompt = promptService.buildPrompt(queryRequest); List messages = getFastChatMessage(uid, prompt); buildSseEmitter(sseEmitter, uid); @@ -464,7 +468,7 @@ private SseEmitter chatWithTongyiChatAi(ChatQueryRequest queryRequest, SseEmitte * @throws IOException */ private SseEmitter chatWithBaichuanAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { - String prompt = buildPrompt(queryRequest); + String prompt = promptService.buildPrompt(queryRequest); List messages = getFastChatMessage(uid, prompt); buildSseEmitter(sseEmitter, uid); @@ -506,7 +510,7 @@ private List getFastChatMessage(String uid, String prompt) { * @throws IOException */ private SseEmitter chatWithWenxinAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { - String prompt = buildPrompt(queryRequest); + String prompt = promptService.buildPrompt(queryRequest); List messages = getFastChatMessage(uid, prompt); if (messages.size() >= 2 && messages.size() % 2 == 0) { messages.remove(messages.size() - 1); @@ -531,7 +535,7 @@ private SseEmitter chatWithWenxinAi(ChatQueryRequest queryRequest, SseEmitter ss * @throws IOException */ private SseEmitter chatWithClaudeAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { - String prompt = buildPrompt(queryRequest); + String prompt = promptService.buildPrompt(queryRequest); ClaudeChatMessage claudeChatMessage = new ClaudeChatMessage(); claudeChatMessage.setText(prompt); ClaudeChatCompletionsOptions chatCompletionsOptions = new ClaudeChatCompletionsOptions(); @@ -574,333 +578,5 @@ private SseEmitter buildSseEmitter(SseEmitter sseEmitter, String uid) throws IOE return sseEmitter; } - /** - * 构建schema参数 - * - * @param tableQueryParam - * @param tableNames - * @return - */ - private String buildTableColumn(TableQueryParam tableQueryParam, - List tableNames) { - if (CollectionUtils.isEmpty(tableNames)) { - return ""; - } - List schemaContent = Lists.newArrayList(); - try { - schemaContent = tableNames.stream().map(tableName -> { - tableQueryParam.setTableName(tableName); - return queryTableDdl(tableName, tableQueryParam); - }).collect(Collectors.toList()); - } catch (Exception exception) { - log.error("query table error, do nothing"); - } - - return JSON.toJSONString(schemaContent); - } - - /** - * query table schema - * - * @param tableName - * @param request - * @return - */ - private String queryTableDdl(String tableName, TableQueryParam request) { - ShowCreateTableParam param = new ShowCreateTableParam(); - param.setTableName(tableName); - param.setDataSourceId(request.getDataSourceId()); - param.setDatabaseName(request.getDatabaseName()); - param.setSchemaName(request.getSchemaName()); - DataResult tableSchema = tableService.showCreateTable(param); - return tableSchema.getData(); - } - - /** - * 构建prompt - * - * @param queryRequest - * @return - */ - private String buildPrompt(ChatQueryRequest queryRequest) { - if (PromptType.TEXT_GENERATION.getCode().equals(queryRequest.getPromptType())) { - return queryRequest.getMessage(); - } - - // 查询schema信息 - String dataSourceType = queryDatabaseType(queryRequest); - String properties = ""; - if (CollectionUtils.isNotEmpty(queryRequest.getTableNames())) { - TableQueryParam queryParam = chatConverter.chat2tableQuery(queryRequest); - properties = buildTableColumn(queryParam, queryRequest.getTableNames()); - } else { - properties = mappingDatabaseSchema(queryRequest); - } - String prompt = queryRequest.getMessage(); - String promptType = StringUtils.isBlank(queryRequest.getPromptType()) ? PromptType.NL_2_SQL.getCode() - : queryRequest.getPromptType(); - PromptType pType = EasyEnumUtils.getEnum(PromptType.class, promptType); - String ext = StringUtils.isNotBlank(queryRequest.getExt()) ? queryRequest.getExt() : ""; - String schemaProperty = StringUtils.isNotEmpty(properties) ? String.format( - "### 请根据以下table properties和SQL input%s. %s\n#\n### %s SQL tables, with their properties:\n#\n# " - + "%s\n#\n#\n### SQL input: %s", pType.getDescription(), ext, dataSourceType, - properties, prompt) : String.format("### 请根据以下SQL input%s. %s\n#\n### SQL input: %s", - pType.getDescription(), ext, prompt); - switch (pType) { - case SQL_2_SQL: - schemaProperty = StringUtils.isNotBlank(queryRequest.getDestSqlType()) ? String.format( - "%s\n#\n### 目标SQL类型: %s", schemaProperty, queryRequest.getDestSqlType()) : String.format( - "%s\n#\n### 目标SQL类型: %s", schemaProperty, dataSourceType); - default: - break; - } - String cleanedInput = schemaProperty.replaceAll("[\r\t]", ""); - return cleanedInput; - } - - /** - * 构建prompt - * - * @param queryRequest - * @return - */ - private String buildPrompt2(ChatQueryRequest queryRequest) { - if (PromptType.TEXT_GENERATION.getCode().equals(queryRequest.getPromptType())) { - return queryRequest.getMessage(); - } - - // 查询schema信息 - String dataSourceType = queryDatabaseType(queryRequest); - String properties = ""; - if (CollectionUtils.isNotEmpty(queryRequest.getTableNames())) { - properties = queryRequest.getTableNames().stream().collect(Collectors.joining(",")); - } else { - properties = queryDatabaseSchema2(queryRequest); - } - String prompt = queryRequest.getMessage(); - String promptType = StringUtils.isBlank(queryRequest.getPromptType()) ? PromptType.NL_2_SQL.getCode() - : queryRequest.getPromptType(); - PromptType pType = EasyEnumUtils.getEnum(PromptType.class, promptType); - String ext = StringUtils.isNotBlank(queryRequest.getExt()) ? queryRequest.getExt() : ""; - String schemaProperty = StringUtils.isNotEmpty(properties) ? String.format( - "### 请根据以下table properties和SQL input%s. %s\n#\n### %s SQL tables:\n#\n# " - + "%s\n#\n#\n### SQL input: %s", pType.getDescription(), ext, dataSourceType, - properties, prompt) : String.format("### 请根据以下SQL input%s. %s\n#\n### SQL input: %s", - pType.getDescription(), ext, prompt); - switch (pType) { - case SQL_2_SQL: - schemaProperty = StringUtils.isNotBlank(queryRequest.getDestSqlType()) ? String.format( - "%s\n#\n### 目标SQL类型: %s", schemaProperty, queryRequest.getDestSqlType()) : String.format( - "%s\n#\n### 目标SQL类型: %s", schemaProperty, dataSourceType); - default: - break; - } - String cleanedInput = schemaProperty.replaceAll("[\r\t]", ""); - return cleanedInput; - } - - /** - * query chat2db apikey - * - * @return - */ - public String getApiKey() { - ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); - Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData(); - String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode(); - // only sync for chat2db ai - if (Objects.isNull(config) || !aiSqlSource.equals(config.getContent())) { - return null; - } - Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData(); - if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) { - return null; - } - return keyConfig.getContent(); - } - - /** - * query database type - * - * @param queryRequest - * @return - */ - public String queryDatabaseType(ChatQueryRequest queryRequest) { - // 查询schema信息 - DataResult dataResult = dataSourceService.queryById(queryRequest.getDataSourceId()); - String dataSourceType = dataResult.getData().getType(); - if (StringUtils.isBlank(dataSourceType)) { - dataSourceType = "MYSQL"; - } - return dataSourceType; - } - - public String mappingDatabaseSchema(ChatQueryRequest queryRequest) { - String properties = ""; - String apiKey = getApiKey(); - if (StringUtils.isNotBlank(apiKey)) { - boolean res = gatewayClientService.checkInWhite(new WhiteListRequest(apiKey, WhiteListTypeEnum.VECTOR.getCode())).getData(); - if (res) { -// properties = queryDatabaseSchema(queryRequest) + querySchemaByEs(queryRequest); - properties = queryDatabaseSchema(queryRequest); - } - } - return properties; - } - - /** - * query database schema - * - * @param queryRequest - * @return - * @throws IOException - */ - public String queryDatabaseSchema(ChatQueryRequest queryRequest) { - // request embedding - FastChatEmbeddingResponse response = distributeAIEmbedding(queryRequest.getMessage()); - List> contentVector = new ArrayList<>(); - if (Objects.isNull(response) || CollectionUtils.isEmpty(response.getData())) { - return ""; - } - contentVector.add(response.getData().get(0).getEmbedding()); - - // search embedding - TableSchemaRequest tableSchemaRequest = new TableSchemaRequest(); - tableSchemaRequest.setSchemaVector(contentVector); - tableSchemaRequest.setDataSourceId(queryRequest.getDataSourceId()); - tableSchemaRequest.setDatabaseName(queryRequest.getDatabaseName()); - tableSchemaRequest.setDataSourceSchema(queryRequest.getSchemaName()); - ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); - Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData(); - if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) { - return ""; - } - tableSchemaRequest.setApiKey(keyConfig.getContent()); - try { - DataResult result = gatewayClientService.schemaVectorSearch(tableSchemaRequest); - List schemas = Lists.newArrayList(); - if (Objects.nonNull(result.getData()) && CollectionUtils.isNotEmpty(result.getData().getTableSchemas())) { - for(TableSchema data: result.getData().getTableSchemas()){ - schemas.add(data.getTableSchema()); - } - } - if (CollectionUtils.isEmpty(schemas)) { - return ""; - } - String res = JSON.toJSONString(schemas); - log.info("search vector result:{}", res); - return res; - } catch (Exception exception) { - log.error("query table error, do nothing"); - return ""; - } - } - - - /** - * query database schema - * - * @param queryRequest - * @return - * @throws IOException - */ - public String queryDatabaseSchema2(ChatQueryRequest queryRequest) { - MetaData metaSchema = Chat2DBContext.getMetaData(); - try { - List
tables = metaSchema.tables(Chat2DBContext.getConnection(), queryRequest.getDatabaseName(), queryRequest.getSchemaName(), null); - return tables.stream() - .map(table -> StringUtils.isBlank(table.getComment()) ? table.getName() - : table.getName() + "(" + table.getComment() + ")") - .collect(Collectors.joining(",")); - } catch (Exception e) { - log.error("query table error:{}, do nothing", e.getMessage()); - return ""; - } - } - - /** - * query database schema - * - * @param queryRequest - * @return - * @throws IOException - */ - public String querySchemaByEs(ChatQueryRequest queryRequest) { - // search embedding - EsTableSchemaRequest tableSchemaRequest = new EsTableSchemaRequest(); - tableSchemaRequest.setSearchKey(queryRequest.getMessage()); - tableSchemaRequest.setDataSourceId(queryRequest.getDataSourceId()); - tableSchemaRequest.setDatabaseName(queryRequest.getDatabaseName()); - tableSchemaRequest.setSchemaName(queryRequest.getSchemaName()); - ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); - Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData(); - if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) { - return ""; - } - tableSchemaRequest.setApiKey(keyConfig.getContent()); - try { - DataResult result = gatewayClientService.schemaEsSearch(tableSchemaRequest); - List schemas = Lists.newArrayList(); - if (Objects.nonNull(result.getData()) && CollectionUtils.isNotEmpty(result.getData().getTableSchemas())) { - for(EsTableSchema data: result.getData().getTableSchemas()){ - schemas.add(data.getTableSchemaContent()); - } - } - if (CollectionUtils.isEmpty(schemas)) { - return ""; - } - String res = JSON.toJSONString(schemas); - log.info("search es result:{}", res); - return res; - } catch (Exception exception) { - log.error("query es table error, do nothing"); - return ""; - } - } - - /** - * distribute embedding with different AI - * - * @return - */ - public FastChatEmbeddingResponse distributeAIEmbedding(String input) { - ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); - Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData(); - String aiSqlSource = config.getContent(); - if (Objects.isNull(aiSqlSource)) { - return null; - } - AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(aiSqlSource); - switch (Objects.requireNonNull(aiSqlSourceEnum)) { - case CHAT2DBAI: - return embeddingWithChat2dbAi(input); - case FASTCHATAI: - return embeddingWithFastChatAi(input); - } - return null; - } - - /** - * embedding with fast chat openai - * - * @param input - * @return - * @throws IOException - */ - private FastChatEmbeddingResponse embeddingWithFastChatAi(String input) { - FastChatEmbeddingResponse response = FastChatAIClient.getInstance().embeddings(input); - return response; - } - - /** - * embedding with open ai - * - * @param input - * @return - */ - private FastChatEmbeddingResponse embeddingWithChat2dbAi(String input) { - FastChatEmbeddingResponse embeddings = Chat2dbAIClient.getInstance().embeddings(input); - return embeddings; - } } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/EmbeddingController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/EmbeddingController.java index c8c694309..70df31c6f 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/EmbeddingController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/EmbeddingController.java @@ -242,7 +242,7 @@ public void syncTableVector(TableBriefQueryRequest param) throws Exception { return; } - String apiKey = getApiKey(); + String apiKey = promptService.getApiKey(); if (StringUtils.isBlank(apiKey)) { return; } @@ -281,7 +281,7 @@ private void saveTableEmbedding(String tableSchema, TableSchemaRequest tableSche List> contentVector = new ArrayList<>(); for(String str : schemaList){ // request embedding - FastChatEmbeddingResponse response = distributeAIEmbedding(str); + FastChatEmbeddingResponse response = promptService.distributeAIEmbedding(str); if(response == null){ throw new ParamBusinessException(); } @@ -310,7 +310,7 @@ public void syncTableEs(TableBriefQueryRequest param) throws Exception { return; } - String apiKey = getApiKey(); + String apiKey = promptService.getApiKey(); if (StringUtils.isBlank(apiKey)) { return; } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/KnowledgeController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/KnowledgeController.java index 6ff16ee09..6ef0731ac 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/KnowledgeController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/KnowledgeController.java @@ -70,7 +70,7 @@ public ActionResult embeddings(MultipartFile file, HttpServletRequest request) contentWordCount.add(str.length()); // request embedding - FastChatEmbeddingResponse response = distributeAIEmbedding(str); + FastChatEmbeddingResponse response = promptService.distributeAIEmbedding(str); if(response == null){ continue; } @@ -97,7 +97,7 @@ public ActionResult embeddings(MultipartFile file, HttpServletRequest request) public SseEmitter search(ChatQueryRequest queryRequest, @RequestHeader Map headers) throws Exception { // request embedding - FastChatEmbeddingResponse response = distributeAIEmbedding(queryRequest.getMessage()); + FastChatEmbeddingResponse response = promptService.distributeAIEmbedding(queryRequest.getMessage()); List> contentVector = new ArrayList<>(); contentVector.add(response.getData().get(0).getEmbedding()); diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/TextGenerationController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/TextGenerationController.java index 0c6180667..94caf7d4f 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/TextGenerationController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/TextGenerationController.java @@ -63,8 +63,8 @@ public SseEmitter prompt(ChatQueryRequest queryRequest, @RequestHeader Map messages; - - private final ConnectInfo connectInfo; + private final PromptService promptService;; private final ChatQueryRequest queryRequest; + private final LoginUser loginUser; + private List toolCalls = new ArrayList<>(); - public OpenAIEventSourceListener(SseEmitter sseEmitter, List messages, ConnectInfo connectInfo, ChatQueryRequest queryRequest) { + public OpenAIEventSourceListener(SseEmitter sseEmitter, PromptService promptService, ChatQueryRequest queryRequest, LoginUser loginUser) { this.sseEmitter = sseEmitter; - this.messages = messages; - this.connectInfo = connectInfo; + this.promptService = promptService; this.queryRequest = queryRequest; + this.loginUser = loginUser; } public static List mergeToolCallsLists(List list1, List list2) { @@ -134,37 +134,30 @@ public void onEvent(EventSource eventSource, String id, String type, String data sseEmitter.complete(); return; } - messages.add(Message.builder() - .toolCalls(toolCalls) - .role(BaseMessage.Role.ASSISTANT).build()); - Chat2DBContext.putContext(connectInfo); - try { - for (ToolCalls toolCall : toolCalls) { - String callId = toolCall.getId(); - ToolCallFunction function = toolCall.getFunction(); - if (function != null && Objects.nonNull(function.getArguments())) { - String functionName = function.getName(); + List tableNames = new ArrayList<>(); + for (ToolCalls toolCall : toolCalls) { + String callId = toolCall.getId(); + ToolCallFunction function = toolCall.getFunction(); + if (function != null && Objects.nonNull(function.getArguments())) { + String functionName = function.getName(); + if ("get_table_columns".equals(functionName)) { JSONObject arguments = JSONObject.parse(function.getArguments()); - if ("get_table_columns".equals(functionName)) { - MetaData metaSchema = Chat2DBContext.getMetaData(); - String content; - try { - content = metaSchema.tableDDL(Chat2DBContext.getConnection(), queryRequest.getDatabaseName(), queryRequest.getSchemaName(), arguments.getString("table_name")); - }catch (Exception e){ - log.error("OpenAI查询表结构失败",e); - content = StringUtils.defaultString(e.getMessage(), "OpenAI查询表结构失败"); - } - messages.add(Message.builder().role(BaseMessage.Role.TOOL) - .toolCallId(callId) - .name(functionName) - .content(content) - .build()); - } + tableNames.add(arguments.getString("table_name")); } } - } finally { - Chat2DBContext.removeContext(); } + List messages = new ArrayList<>(); + queryRequest.setTableNames(tableNames); + ContextUtils.setContext(Context.builder() + .loginUser(loginUser) + .build()); + Dbutils.setSession(); + String prompt = promptService.buildPrompt(queryRequest); + Dbutils.removeSession(); + prompt = prompt.replaceAll("#", ""); + log.info(prompt); + Message currentMessage = Message.builder().content(prompt).role(Message.Role.USER).build(); + messages.add(currentMessage); OpenAIClient.getInstance().streamChatCompletion(messages, this); toolCalls.clear(); return; diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java new file mode 100644 index 000000000..3e9ba940e --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java @@ -0,0 +1,364 @@ +package ai.chat2db.server.web.api.controller.ai.utils; + +import java.io.IOException; +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + +import com.alibaba.fastjson2.JSON; +import com.google.common.collect.Lists; + +import ai.chat2db.server.domain.api.enums.AiSqlSourceEnum; +import ai.chat2db.server.domain.api.model.Config; +import ai.chat2db.server.domain.api.model.DataSource; +import ai.chat2db.server.domain.api.param.ShowCreateTableParam; +import ai.chat2db.server.domain.api.param.TableQueryParam; +import ai.chat2db.server.domain.api.service.ConfigService; +import ai.chat2db.server.domain.api.service.DataSourceService; +import ai.chat2db.server.domain.api.service.TableService; +import ai.chat2db.server.tools.base.enums.WhiteListTypeEnum; +import ai.chat2db.server.tools.base.wrapper.result.DataResult; +import ai.chat2db.server.tools.common.util.EasyEnumUtils; +import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect; +import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient; +import ai.chat2db.server.web.api.controller.ai.converter.ChatConverter; +import ai.chat2db.server.web.api.controller.ai.enums.PromptType; +import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatAIClient; +import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse; +import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest; +import ai.chat2db.server.web.api.controller.ai.rest.client.RestAIClient; +import ai.chat2db.server.web.api.http.GatewayClientService; +import ai.chat2db.server.web.api.http.model.EsTableSchema; +import ai.chat2db.server.web.api.http.model.TableSchema; +import ai.chat2db.server.web.api.http.request.EsTableSchemaRequest; +import ai.chat2db.server.web.api.http.request.TableSchemaRequest; +import ai.chat2db.server.web.api.http.request.WhiteListRequest; +import ai.chat2db.server.web.api.http.response.EsTableSchemaResponse; +import ai.chat2db.server.web.api.http.response.TableSchemaResponse; +import ai.chat2db.server.web.api.util.ApplicationContextUtil; +import ai.chat2db.spi.MetaData; +import ai.chat2db.spi.model.Table; +import ai.chat2db.spi.sql.Chat2DBContext; +import jakarta.annotation.Resource; +import lombok.extern.slf4j.Slf4j; + + +@Slf4j +@ConnectionInfoAspect +@Service +public class PromptService { + + + @Autowired + private TableService tableService; + + @Autowired + private DataSourceService dataSourceService; + + + @Autowired + private ChatConverter chatConverter; + + + @Resource + private GatewayClientService gatewayClientService; + + + /** + * 构建prompt + * + * @param queryRequest + * @return + */ + public String buildPrompt(ChatQueryRequest queryRequest) { + if (PromptType.TEXT_GENERATION.getCode().equals(queryRequest.getPromptType())) { + return queryRequest.getMessage(); + } + + // 查询schema信息 + String dataSourceType = queryDatabaseType(queryRequest); + String properties = ""; + if (CollectionUtils.isNotEmpty(queryRequest.getTableNames())) { + TableQueryParam queryParam = chatConverter.chat2tableQuery(queryRequest); + properties = buildTableColumn(queryParam, queryRequest.getTableNames()); + } else { + properties = mappingDatabaseSchema(queryRequest); + } + String prompt = queryRequest.getMessage(); + String promptType = StringUtils.isBlank(queryRequest.getPromptType()) ? PromptType.NL_2_SQL.getCode() + : queryRequest.getPromptType(); + PromptType pType = EasyEnumUtils.getEnum(PromptType.class, promptType); + String ext = StringUtils.isNotBlank(queryRequest.getExt()) ? queryRequest.getExt() : ""; + String schemaProperty = StringUtils.isNotEmpty(properties) ? String.format( + "### 请根据以下table properties和SQL input%s. %s\n#\n### %s SQL tables, with their properties:\n#\n# " + + "%s\n#\n#\n### SQL input: %s", pType.getDescription(), ext, dataSourceType, + properties, prompt) : String.format("### 请根据以下SQL input%s. %s\n#\n### SQL input: %s", + pType.getDescription(), ext, prompt); + switch (pType) { + case SQL_2_SQL: + schemaProperty = StringUtils.isNotBlank(queryRequest.getDestSqlType()) ? String.format( + "%s\n#\n### 目标SQL类型: %s", schemaProperty, queryRequest.getDestSqlType()) : String.format( + "%s\n#\n### 目标SQL类型: %s", schemaProperty, dataSourceType); + default: + break; + } + String cleanedInput = schemaProperty.replaceAll("[\r\t]", ""); + return cleanedInput; + } + + public String mappingDatabaseSchema(ChatQueryRequest queryRequest) { + String properties = ""; + String apiKey = getApiKey(); + if (StringUtils.isNotBlank(apiKey)) { + boolean res = gatewayClientService.checkInWhite(new WhiteListRequest(apiKey, WhiteListTypeEnum.VECTOR.getCode())).getData(); + if (res) { +// properties = queryDatabaseSchema(queryRequest) + querySchemaByEs(queryRequest); + properties = queryDatabaseSchema(queryRequest); + } + } + return properties; + } + + + /** + * query chat2db apikey + * + * @return + */ + public String getApiKey() { + ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); + Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData(); + String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode(); + // only sync for chat2db ai + if (Objects.isNull(config) || !aiSqlSource.equals(config.getContent())) { + return null; + } + Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData(); + if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) { + return null; + } + return keyConfig.getContent(); + } + + /** + * 构建schema参数 + * + * @param tableQueryParam + * @param tableNames + * @return + */ + public String buildTableColumn(TableQueryParam tableQueryParam, + List tableNames) { + if (CollectionUtils.isEmpty(tableNames)) { + return ""; + } + List schemaContent = Lists.newArrayList(); + try { + schemaContent = tableNames.stream().map(tableName -> { + tableQueryParam.setTableName(tableName); + return queryTableDdl(tableName, tableQueryParam); + }).collect(Collectors.toList()); + } catch (Exception exception) { + log.error("query table error, do nothing"); + } + + return JSON.toJSONString(schemaContent); + } + + /** + * query table schema + * + * @param tableName + * @param request + * @return + */ + public String queryTableDdl(String tableName, TableQueryParam request) { + ShowCreateTableParam param = new ShowCreateTableParam(); + param.setTableName(tableName); + param.setDataSourceId(request.getDataSourceId()); + param.setDatabaseName(request.getDatabaseName()); + param.setSchemaName(request.getSchemaName()); + DataResult tableSchema = tableService.showCreateTable(param); + return tableSchema.getData(); + } + + /** + * query database schema + * + * @param queryRequest + * @return + * @throws IOException + */ + public String queryDatabaseSchema(ChatQueryRequest queryRequest) { + // request embedding + FastChatEmbeddingResponse response = distributeAIEmbedding(queryRequest.getMessage()); + List> contentVector = new ArrayList<>(); + if (Objects.isNull(response) || CollectionUtils.isEmpty(response.getData())) { + return ""; + } + contentVector.add(response.getData().get(0).getEmbedding()); + + // search embedding + TableSchemaRequest tableSchemaRequest = new TableSchemaRequest(); + tableSchemaRequest.setSchemaVector(contentVector); + tableSchemaRequest.setDataSourceId(queryRequest.getDataSourceId()); + tableSchemaRequest.setDatabaseName(queryRequest.getDatabaseName()); + tableSchemaRequest.setDataSourceSchema(queryRequest.getSchemaName()); + ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); + Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData(); + if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) { + return ""; + } + tableSchemaRequest.setApiKey(keyConfig.getContent()); + try { + DataResult result = gatewayClientService.schemaVectorSearch(tableSchemaRequest); + List schemas = Lists.newArrayList(); + if (Objects.nonNull(result.getData()) && CollectionUtils.isNotEmpty(result.getData().getTableSchemas())) { + for(TableSchema data: result.getData().getTableSchemas()){ + schemas.add(data.getTableSchema()); + } + } + if (CollectionUtils.isEmpty(schemas)) { + return ""; + } + String res = JSON.toJSONString(schemas); + log.info("search vector result:{}", res); + return res; + } catch (Exception exception) { + log.error("query table error, do nothing"); + return ""; + } + } + + /** + * distribute embedding with different AI + * + * @return + */ + public FastChatEmbeddingResponse distributeAIEmbedding(String input) { + ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); + Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData(); + String aiSqlSource = config.getContent(); + if (Objects.isNull(aiSqlSource)) { + return null; + } + AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(aiSqlSource); + switch (Objects.requireNonNull(aiSqlSourceEnum)) { + case CHAT2DBAI: + return embeddingWithChat2dbAi(input); + case FASTCHATAI: + return embeddingWithFastChatAi(input); + } + return null; + } + + /** + * embedding with fast chat openai + * + * @param input + * @return + * @throws IOException + */ + public FastChatEmbeddingResponse embeddingWithFastChatAi(String input) { + FastChatEmbeddingResponse response = FastChatAIClient.getInstance().embeddings(input); + return response; + } + + /** + * embedding with open ai + * + * @param input + * @return + */ + public FastChatEmbeddingResponse embeddingWithChat2dbAi(String input) { + FastChatEmbeddingResponse embeddings = Chat2dbAIClient.getInstance().embeddings(input); + return embeddings; + } + + /** + * 构建prompt + * + * @param queryRequest + * @return + */ + public String buildAutoPrompt(ChatQueryRequest queryRequest) { + if (PromptType.TEXT_GENERATION.getCode().equals(queryRequest.getPromptType())) { + return queryRequest.getMessage(); + } + + // 查询schema信息 + String dataSourceType = queryDatabaseType(queryRequest); + String properties = ""; + if (CollectionUtils.isNotEmpty(queryRequest.getTableNames())) { + properties = queryRequest.getTableNames().stream().collect(Collectors.joining(",")); + } else { + properties = queryDatabaseTables(queryRequest); + } + String prompt = queryRequest.getMessage(); + String promptType = StringUtils.isBlank(queryRequest.getPromptType()) ? PromptType.NL_2_SQL.getCode() + : queryRequest.getPromptType(); + PromptType pType = EasyEnumUtils.getEnum(PromptType.class, promptType); + String ext = StringUtils.isNotBlank(queryRequest.getExt()) ? queryRequest.getExt() : ""; + String schemaProperty = StringUtils.isNotEmpty(properties) ? String.format( + "### 请根据以下table properties和SQL input%s. %s\n#\n### %s SQL tables:\n#\n# " + + "%s\n#\n#\n### SQL input: %s", pType.getDescription(), ext, dataSourceType, + properties, prompt) : String.format("### 请根据以下SQL input%s. %s\n#\n### SQL input: %s", + pType.getDescription(), ext, prompt); + switch (pType) { + case SQL_2_SQL: + schemaProperty = StringUtils.isNotBlank(queryRequest.getDestSqlType()) ? String.format( + "%s\n#\n### 目标SQL类型: %s", schemaProperty, queryRequest.getDestSqlType()) : String.format( + "%s\n#\n### 目标SQL类型: %s", schemaProperty, dataSourceType); + default: + break; + } + String cleanedInput = schemaProperty.replaceAll("[\r\t]", ""); + return cleanedInput; + } + + + /** + * query database type + * + * @param queryRequest + * @return + */ + public String queryDatabaseType(ChatQueryRequest queryRequest) { + // 查询schema信息 + DataResult dataResult = dataSourceService.queryById(queryRequest.getDataSourceId()); + String dataSourceType = dataResult.getData().getType(); + if (StringUtils.isBlank(dataSourceType)) { + dataSourceType = "MYSQL"; + } + return dataSourceType; + } + + /** + * query database schema + * + * @param queryRequest + * @return + * @throws IOException + */ + public String queryDatabaseTables(ChatQueryRequest queryRequest) { + MetaData metaSchema = Chat2DBContext.getMetaData(); + try { + List
tables = metaSchema.tables(Chat2DBContext.getConnection(), queryRequest.getDatabaseName(), queryRequest.getSchemaName(), null); + return tables.stream() + .map(table -> StringUtils.isBlank(table.getComment()) ? table.getName() + : table.getName() + "(" + table.getComment() + ")") + .collect(Collectors.joining(",")); + } catch (Exception e) { + log.error("query table error:{}, do nothing", e.getMessage()); + return ""; + } + } + +} From d5a8216a67454dd537c2f5a54564ac0c36ec8ddc Mon Sep 17 00:00:00 2001 From: hejianjun <942156265@qq.com> Date: Wed, 14 Feb 2024 23:35:30 +0800 Subject: [PATCH 04/16] =?UTF-8?q?=E6=99=BA=E8=B0=B1ai=E4=BF=AE=E6=94=B9?= =?UTF-8?q?=EF=BC=8C=E6=B7=BB=E5=8A=A0=E5=9B=9E=E8=B0=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../web/api/controller/ai/ChatController.java | 2 +- .../zhipu/client/ZhipuChatAIStreamClient.java | 61 +++++++++++++------ .../model/ZhipuChatCompletionsOptions.java | 60 ++++++++++++++++++ 3 files changed, 105 insertions(+), 18 deletions(-) diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java index ff93fcbd4..5150b3db9 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java @@ -426,7 +426,7 @@ private SseEmitter chatWithFastChatAi(ChatQueryRequest queryRequest, SseEmitter * @throws IOException */ private SseEmitter chatWithZhipuChatAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { - String prompt = promptService.buildPrompt(queryRequest); + String prompt = promptService.buildAutoPrompt(queryRequest); List messages = getFastChatMessage(uid, prompt); buildSseEmitter(sseEmitter, uid); diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIStreamClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIStreamClient.java index 550c929eb..896e180b3 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIStreamClient.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIStreamClient.java @@ -4,9 +4,15 @@ import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage; import ai.chat2db.server.web.api.controller.ai.zhipu.interceptor.ZhipuChatHeaderAuthorizationInterceptor; import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletionsOptions; +import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletionsOptions.Tool; +import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletionsOptions.Tool.Function; +import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletionsOptions.Tool.Function.Parameters; +import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletionsOptions.Tool.Function.Property; import cn.hutool.http.ContentType; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableMap; + import lombok.Getter; import lombok.extern.slf4j.Slf4j; import okhttp3.MediaType; @@ -19,6 +25,7 @@ import org.apache.commons.collections4.CollectionUtils; import org.jetbrains.annotations.NotNull; +import java.util.Arrays; import java.util.List; import java.util.Objects; import java.util.concurrent.TimeUnit; @@ -69,7 +76,6 @@ public class ZhipuChatAIStreamClient { @Getter private OkHttpClient okHttpClient; - /** * @param builder */ @@ -90,13 +96,12 @@ private ZhipuChatAIStreamClient(Builder builder) { * okhttpclient */ private OkHttpClient okHttpClient() { - OkHttpClient okHttpClient = new OkHttpClient - .Builder() - .addInterceptor(new ZhipuChatHeaderAuthorizationInterceptor(this.key, this.secret)) - .connectTimeout(10, TimeUnit.SECONDS) - .writeTimeout(50, TimeUnit.SECONDS) - .readTimeout(50, TimeUnit.SECONDS) - .build(); + OkHttpClient okHttpClient = new OkHttpClient.Builder() + .addInterceptor(new ZhipuChatHeaderAuthorizationInterceptor(this.key, this.secret)) + .connectTimeout(10, TimeUnit.SECONDS) + .writeTimeout(50, TimeUnit.SECONDS) + .readTimeout(50, TimeUnit.SECONDS) + .build(); return okHttpClient; } @@ -195,12 +200,34 @@ public void streamCompletions(List chatMessages, EventSourceLis } log.info("Zhipu Chat AI, prompt:{}", chatMessages.get(chatMessages.size() - 1).getContent()); try { - // 建议直接查看demo包代码,这里更新可能不及时 - ZhipuChatCompletionsOptions completionsOptions = new ZhipuChatCompletionsOptions(); - completionsOptions.setPrompt(chatMessages); - completionsOptions.setModel(this.model); String requestId = String.valueOf(System.currentTimeMillis()); - completionsOptions.setRequestId(requestId); + // 建议直接查看demo包代码,这里更新可能不及时 + ZhipuChatCompletionsOptions completionsOptions = ZhipuChatCompletionsOptions.builder() + .requestId(requestId) + .stream(true) + .sseFormat("data") + .model(this.model) + .toolChoice("auto") + .prompt(chatMessages) + .tools(Arrays.asList( + Tool.builder() + .type("function") + .function(Function.builder() + .name("get_table_columns") + .description("获取指定表的字段名,类型") + .parameters(Parameters.builder() + .type("object") + .properties(ImmutableMap.builder() + .put("table_name", Property.builder() + .type("string") + .description("表名,例如```User```") + .build()) + .build()) + .required(Arrays.asList("table_name")) + .build()) + .build()) + .build())) + .build(); ObjectMapper mapper = new ObjectMapper(); mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); String requestBody = mapper.writeValueAsString(completionsOptions); @@ -208,10 +235,10 @@ public void streamCompletions(List chatMessages, EventSourceLis String url = this.apiHost + "/" + this.model + "/" + "sse-invoke"; EventSource.Factory factory = EventSources.createFactory(this.okHttpClient); Request request = new Request.Builder() - .url(url) - .post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody)) - .build(); - //创建事件 + .url(url) + .post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody)) + .build(); + // 创建事件 EventSource eventSource = factory.newEventSource(request, eventSourceListener); log.info("finish invoking zhipu chat ai"); } catch (Exception e) { diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/model/ZhipuChatCompletionsOptions.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/model/ZhipuChatCompletionsOptions.java index 4b6359cc2..4bcc82d4b 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/model/ZhipuChatCompletionsOptions.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/model/ZhipuChatCompletionsOptions.java @@ -5,15 +5,19 @@ import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage; import com.fasterxml.jackson.annotation.JsonProperty; + +import lombok.Builder; import lombok.Data; import java.util.List; +import java.util.Map; /** * The configuration information for a chat completions request. Completions support a wide variety of tasks and * generate text that continues from or "completes" provided prompt data. */ @Data +@Builder public final class ZhipuChatCompletionsOptions { @JsonProperty(value = "request_id") @@ -45,4 +49,60 @@ public final class ZhipuChatCompletionsOptions { */ @JsonProperty(value = "model") private String model; + + + + // 新添加的参数 + @JsonProperty(value = "tool_choice") + private String toolChoice; // 工具选择策略 + + @JsonProperty(value = "tools") + private List tools; // 工具列表 + + // 工具类 + @Data + @Builder + public static class Tool { + @JsonProperty(value = "type") + private String type; + + @JsonProperty(value = "function") + private Function function; + + @Data + @Builder + public static class Function { + @JsonProperty(value = "name") + private String name; + + @JsonProperty(value = "description") + private String description; + + @JsonProperty(value = "parameters") + private Parameters parameters; + + @Data + @Builder + public static class Parameters { + @JsonProperty(value = "type") + private String type; + + @JsonProperty(value = "properties") + private Map properties; + + @JsonProperty(value = "required") + private List required; + } + + @Data + @Builder + public static class Property { + @JsonProperty(value = "type") + private String type; + + @JsonProperty(value = "description") + private String description; + } + } + } } From d8ac27db770ee6ce537048bb8eb16f85d572e28b Mon Sep 17 00:00:00 2001 From: hejianjun <942156265@qq.com> Date: Tue, 20 Feb 2024 11:40:43 +0800 Subject: [PATCH 05/16] =?UTF-8?q?=E6=99=BA=E8=B0=B1=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=E5=8D=87=E7=BA=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../web/api/controller/ai/ChatController.java | 95 +++--------- .../api/controller/ai/enums/PromptType.java | 6 + .../listener/OpenAIEventSourceListener.java | 17 ++- .../controller/ai/utils/PromptService.java | 57 ++++++- .../ai/zhipu/client/ZhipuChatAIClient.java | 4 +- .../zhipu/client/ZhipuChatAIStreamClient.java | 50 +------ .../ZhipuChatAIEventSourceListener.java | 139 ++++-------------- .../model/ZhipuChatCompletionsOptions.java | 59 +------- 8 files changed, 131 insertions(+), 296 deletions(-) diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java index 5150b3db9..0222b3b05 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java @@ -4,18 +4,10 @@ import ai.chat2db.server.domain.api.enums.AiSqlSourceEnum; import ai.chat2db.server.domain.api.model.Config; -import ai.chat2db.server.domain.api.model.DataSource; -import ai.chat2db.server.domain.api.param.ShowCreateTableParam; -import ai.chat2db.server.domain.api.param.TableQueryParam; import ai.chat2db.server.domain.api.service.ConfigService; -import ai.chat2db.server.domain.api.service.DataSourceService; -import ai.chat2db.server.domain.api.service.TableService; -import ai.chat2db.server.tools.base.enums.WhiteListTypeEnum; -import ai.chat2db.server.tools.base.wrapper.result.DataResult; import ai.chat2db.server.tools.common.exception.ParamBusinessException; import ai.chat2db.server.tools.common.model.LoginUser; import ai.chat2db.server.tools.common.util.ContextUtils; -import ai.chat2db.server.tools.common.util.EasyEnumUtils; import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect; import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAIClient; import ai.chat2db.server.web.api.controller.ai.azure.listener.AzureOpenAIEventSourceListener; @@ -30,10 +22,7 @@ import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeChatCompletionsOptions; import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeChatMessage; import ai.chat2db.server.web.api.controller.ai.config.LocalCache; -import ai.chat2db.server.web.api.controller.ai.converter.ChatConverter; -import ai.chat2db.server.web.api.controller.ai.enums.PromptType; import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatAIClient; -import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse; import ai.chat2db.server.web.api.controller.ai.fastchat.listener.FastChatAIEventSourceListener; import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage; import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatRole; @@ -50,48 +39,31 @@ import ai.chat2db.server.web.api.controller.ai.wenxin.listener.WenxinAIEventSourceListener; import ai.chat2db.server.web.api.controller.ai.zhipu.client.ZhipuChatAIClient; import ai.chat2db.server.web.api.controller.ai.zhipu.listener.ZhipuChatAIEventSourceListener; +import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletionsOptions; import ai.chat2db.server.web.api.http.GatewayClientService; -import ai.chat2db.server.web.api.http.model.EsTableSchema; -import ai.chat2db.server.web.api.http.model.TableSchema; -import ai.chat2db.server.web.api.http.request.EsTableSchemaRequest; -import ai.chat2db.server.web.api.http.request.TableSchemaRequest; -import ai.chat2db.server.web.api.http.request.WhiteListRequest; -import ai.chat2db.server.web.api.http.response.EsTableSchemaResponse; -import ai.chat2db.server.web.api.http.response.TableSchemaResponse; import ai.chat2db.server.web.api.util.ApplicationContextUtil; -import ai.chat2db.spi.MetaData; -import ai.chat2db.spi.model.Table; -import ai.chat2db.spi.sql.Chat2DBContext; -import ai.chat2db.spi.sql.ConnectInfo; import cn.hutool.core.util.StrUtil; import cn.hutool.json.JSONUtil; -import com.alibaba.fastjson2.JSON; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.unfbx.chatgpt.entity.chat.ChatCompletion; import com.unfbx.chatgpt.entity.chat.Message; -import com.unfbx.chatgpt.entity.chat.Parameters; import com.unfbx.chatgpt.entity.chat.tool.Tools; import com.unfbx.chatgpt.entity.chat.tool.ToolsFunction; -import com.unfbx.chatgpt.entity.chat.BaseChatCompletion.Model; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.lang3.StringUtils; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.web.bind.annotation.*; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import java.io.IOException; -import java.math.BigDecimal; import java.time.Duration; import java.time.LocalDateTime; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.stream.Collectors; /** * 描述: @@ -306,20 +278,7 @@ private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseE ChatCompletion chatCompletion = ChatCompletion.builder() .messages(messages).stream(true).build(); if(queryRequest.getDatabaseName()!=null){ - ToolsFunction function = ToolsFunction.builder() - .name("get_table_columns") - .description("获取指定表的字段名,类型") - .parameters(Parameters.builder() - .type("object") - .properties(ImmutableMap.builder() - .put("table_name", ImmutableMap.builder() - .put("type", "string") - .put("description", "表名,例如```User```") - .build()) - .build()) - .required(List.of("table_name")) - .build()) - .build(); + ToolsFunction function = PromptService.getToolsFunction(); chatCompletion.setModel("gpt-3.5-turbo-0125"); chatCompletion.setTools(List.of(new Tools(Tools.Type.FUNCTION.getName(), function))); chatCompletion.setToolChoice("auto"); @@ -406,7 +365,7 @@ private SseEmitter chatWithAzureAi(ChatQueryRequest queryRequest, SseEmitter sse */ private SseEmitter chatWithFastChatAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { String prompt = promptService.buildPrompt(queryRequest); - List messages = getFastChatMessage(uid, prompt); + List messages = promptService.getFastChatMessage(uid, prompt); buildSseEmitter(sseEmitter, uid); @@ -427,12 +386,25 @@ private SseEmitter chatWithFastChatAi(ChatQueryRequest queryRequest, SseEmitter */ private SseEmitter chatWithZhipuChatAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { String prompt = promptService.buildAutoPrompt(queryRequest); - List messages = getFastChatMessage(uid, prompt); + List messages = promptService.getFastChatMessage(uid, prompt); buildSseEmitter(sseEmitter, uid); - - ZhipuChatAIEventSourceListener sourceListener = new ZhipuChatAIEventSourceListener(sseEmitter); - ZhipuChatAIClient.getInstance().streamCompletions(messages, sourceListener); + LoginUser loginUser = ContextUtils.getLoginUser(); + ZhipuChatAIEventSourceListener sourceListener = new ZhipuChatAIEventSourceListener(sseEmitter,promptService,queryRequest,loginUser); + String requestId = String.valueOf(System.currentTimeMillis()); + // 建议直接查看demo包代码,这里更新可能不及时 + ZhipuChatCompletionsOptions completionsOptions = ZhipuChatCompletionsOptions.builder() + .requestId(requestId) + .stream(true) + .toolChoice("auto") + .messages(messages) + .build(); + if(queryRequest.getDatabaseName()!=null){ + ToolsFunction function = PromptService.getToolsFunction(); + completionsOptions.setTools(List.of(new Tools(Tools.Type.FUNCTION.getName(), function))); + completionsOptions.setToolChoice("auto"); + } + ZhipuChatAIClient.getInstance().streamCompletions(completionsOptions, sourceListener); LocalCache.CACHE.put(uid, messages, LocalCache.TIMEOUT); return sseEmitter; } @@ -448,7 +420,7 @@ private SseEmitter chatWithZhipuChatAi(ChatQueryRequest queryRequest, SseEmitter */ private SseEmitter chatWithTongyiChatAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { String prompt = promptService.buildPrompt(queryRequest); - List messages = getFastChatMessage(uid, prompt); + List messages = promptService.getFastChatMessage(uid, prompt); buildSseEmitter(sseEmitter, uid); @@ -469,7 +441,7 @@ private SseEmitter chatWithTongyiChatAi(ChatQueryRequest queryRequest, SseEmitte */ private SseEmitter chatWithBaichuanAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { String prompt = promptService.buildPrompt(queryRequest); - List messages = getFastChatMessage(uid, prompt); + List messages = promptService.getFastChatMessage(uid, prompt); buildSseEmitter(sseEmitter, uid); @@ -479,26 +451,7 @@ private SseEmitter chatWithBaichuanAi(ChatQueryRequest queryRequest, SseEmitter return sseEmitter; } - /** - * get fast chat message - * - * @param uid - * @param prompt - * @return - */ - private List getFastChatMessage(String uid, String prompt) { - List messages = (List)LocalCache.CACHE.get(uid); - if (CollectionUtils.isNotEmpty(messages)) { - if (messages.size() >= contextLength) { - messages = messages.subList(1, contextLength); - } - } else { - messages = Lists.newArrayList(); - } - FastChatMessage currentMessage = new FastChatMessage(FastChatRole.USER).setContent(prompt); - messages.add(currentMessage); - return messages; - } + /** * chat with wenxin chat openai @@ -511,7 +464,7 @@ private List getFastChatMessage(String uid, String prompt) { */ private SseEmitter chatWithWenxinAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { String prompt = promptService.buildPrompt(queryRequest); - List messages = getFastChatMessage(uid, prompt); + List messages = promptService.getFastChatMessage(uid, prompt); if (messages.size() >= 2 && messages.size() % 2 == 0) { messages.remove(messages.size() - 1); } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/enums/PromptType.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/enums/PromptType.java index 9e9745c75..0135e5ea6 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/enums/PromptType.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/enums/PromptType.java @@ -38,6 +38,12 @@ public enum PromptType implements BaseEnum { * text generation */ TEXT_GENERATION("文本生成"), + + + /** + * function call + */ + FUNCTION_CALL("获取指定表的字段名,类型"), ; final String description; diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java index 7f3e6b4f5..6bd49a387 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java @@ -37,11 +37,11 @@ public class OpenAIEventSourceListener extends EventSourceListener { private final SseEmitter sseEmitter; - private final PromptService promptService;; + protected final PromptService promptService;; private final ChatQueryRequest queryRequest; - private final LoginUser loginUser; + public final LoginUser loginUser; private List toolCalls = new ArrayList<>(); @@ -117,6 +117,13 @@ public void onOpen(EventSource eventSource, Response response) { log.info("OpenAI建立sse连接..."); } + + public void functionCall(String prompt){ + List messages = new ArrayList<>(); + Message currentMessage = Message.builder().content(prompt).role(Message.Role.USER).build(); + messages.add(currentMessage); + OpenAIClient.getInstance().streamChatCompletion(messages, this); + } /** * {@inheritDoc} */ @@ -146,7 +153,7 @@ public void onEvent(EventSource eventSource, String id, String type, String data } } } - List messages = new ArrayList<>(); + queryRequest.setTableNames(tableNames); ContextUtils.setContext(Context.builder() .loginUser(loginUser) @@ -156,9 +163,7 @@ public void onEvent(EventSource eventSource, String id, String type, String data Dbutils.removeSession(); prompt = prompt.replaceAll("#", ""); log.info(prompt); - Message currentMessage = Message.builder().content(prompt).role(Message.Role.USER).build(); - messages.add(currentMessage); - OpenAIClient.getInstance().streamChatCompletion(messages, this); + functionCall(prompt); toolCalls.clear(); return; } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java index 3e9ba940e..5a1ae63d2 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java @@ -10,10 +10,14 @@ import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; import com.alibaba.fastjson2.JSON; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; +import com.unfbx.chatgpt.entity.chat.Parameters; +import com.unfbx.chatgpt.entity.chat.tool.ToolsFunction; import ai.chat2db.server.domain.api.enums.AiSqlSourceEnum; import ai.chat2db.server.domain.api.model.Config; @@ -28,19 +32,19 @@ import ai.chat2db.server.tools.common.util.EasyEnumUtils; import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect; import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient; +import ai.chat2db.server.web.api.controller.ai.config.LocalCache; import ai.chat2db.server.web.api.controller.ai.converter.ChatConverter; import ai.chat2db.server.web.api.controller.ai.enums.PromptType; import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatAIClient; import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse; +import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage; +import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatRole; import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest; import ai.chat2db.server.web.api.controller.ai.rest.client.RestAIClient; import ai.chat2db.server.web.api.http.GatewayClientService; -import ai.chat2db.server.web.api.http.model.EsTableSchema; import ai.chat2db.server.web.api.http.model.TableSchema; -import ai.chat2db.server.web.api.http.request.EsTableSchemaRequest; import ai.chat2db.server.web.api.http.request.TableSchemaRequest; import ai.chat2db.server.web.api.http.request.WhiteListRequest; -import ai.chat2db.server.web.api.http.response.EsTableSchemaResponse; import ai.chat2db.server.web.api.http.response.TableSchemaResponse; import ai.chat2db.server.web.api.util.ApplicationContextUtil; import ai.chat2db.spi.MetaData; @@ -56,6 +60,10 @@ public class PromptService { + @Value("${chatgpt.context.length}") + private Integer contextLength; + + @Autowired private TableService tableService; @@ -292,7 +300,6 @@ public String buildAutoPrompt(ChatQueryRequest queryRequest) { if (PromptType.TEXT_GENERATION.getCode().equals(queryRequest.getPromptType())) { return queryRequest.getMessage(); } - // 查询schema信息 String dataSourceType = queryDatabaseType(queryRequest); String properties = ""; @@ -305,6 +312,10 @@ public String buildAutoPrompt(ChatQueryRequest queryRequest) { String promptType = StringUtils.isBlank(queryRequest.getPromptType()) ? PromptType.NL_2_SQL.getCode() : queryRequest.getPromptType(); PromptType pType = EasyEnumUtils.getEnum(PromptType.class, promptType); + if (pType.equals(PromptType.NL_2_SQL)) { + pType = PromptType.FUNCTION_CALL; + } + String ext = StringUtils.isNotBlank(queryRequest.getExt()) ? queryRequest.getExt() : ""; String schemaProperty = StringUtils.isNotEmpty(properties) ? String.format( "### 请根据以下table properties和SQL input%s. %s\n#\n### %s SQL tables:\n#\n# " @@ -361,4 +372,42 @@ public String queryDatabaseTables(ChatQueryRequest queryRequest) { } } + public static ToolsFunction getToolsFunction(){ + return ToolsFunction.builder() + .name("get_table_columns") + .description("获取指定表的字段名,类型") + .parameters(Parameters.builder() + .type("object") + .properties(ImmutableMap.builder() + .put("table_name", ImmutableMap.builder() + .put("type", "string") + .put("description", "表名,例如```User```") + .build()) + .build()) + .required(List.of("table_name")) + .build()) + .build(); + } + + + /** + * get fast chat message + * + * @param uid + * @param prompt + * @return + */ + public List getFastChatMessage(String uid, String prompt) { + List messages = (List)LocalCache.CACHE.get(uid); + if (CollectionUtils.isNotEmpty(messages)) { + if (messages.size() >= contextLength) { + messages = messages.subList(1, contextLength); + } + } else { + messages = Lists.newArrayList(); + } + FastChatMessage currentMessage = new FastChatMessage(FastChatRole.USER).setContent(prompt); + messages.add(currentMessage); + return messages; + } } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIClient.java index f205f17f5..db0d35fa6 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIClient.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIClient.java @@ -58,8 +58,8 @@ private static ZhipuChatAIStreamClient singleton() { public static void refresh() { String apiKey = ""; - String apiHost = "https://open.bigmodel.cn/api/paas/v3/model-api/"; - String model = "chatglm_turbo"; + String apiHost = "https://open.bigmodel.cn/api/paas/v4/chat/completions"; + String model = "glm-4"; ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); Config apiHostConfig = configService.find(ZHIPU_HOST).getData(); if (apiHostConfig != null && StringUtils.isNotBlank(apiHostConfig.getContent())) { diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIStreamClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIStreamClient.java index 896e180b3..ef0ec8071 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIStreamClient.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIStreamClient.java @@ -1,18 +1,11 @@ package ai.chat2db.server.web.api.controller.ai.zhipu.client; import ai.chat2db.server.tools.common.exception.ParamBusinessException; -import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage; import ai.chat2db.server.web.api.controller.ai.zhipu.interceptor.ZhipuChatHeaderAuthorizationInterceptor; import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletionsOptions; -import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletionsOptions.Tool; -import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletionsOptions.Tool.Function; -import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletionsOptions.Tool.Function.Parameters; -import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletionsOptions.Tool.Function.Property; import cn.hutool.http.ContentType; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.collect.ImmutableMap; - import lombok.Getter; import lombok.extern.slf4j.Slf4j; import okhttp3.MediaType; @@ -22,11 +15,8 @@ import okhttp3.sse.EventSource; import okhttp3.sse.EventSourceListener; import okhttp3.sse.EventSources; -import org.apache.commons.collections4.CollectionUtils; import org.jetbrains.annotations.NotNull; -import java.util.Arrays; -import java.util.List; import java.util.Objects; import java.util.concurrent.TimeUnit; @@ -189,50 +179,20 @@ public ZhipuChatAIStreamClient build() { * @param chatMessages * @param eventSourceListener */ - public void streamCompletions(List chatMessages, EventSourceListener eventSourceListener) { - if (CollectionUtils.isEmpty(chatMessages)) { - log.error("param error:Zhipu Chat Prompt cannot be empty"); - throw new ParamBusinessException("prompt"); - } + public void streamCompletions(ZhipuChatCompletionsOptions completionsOptions, EventSourceListener eventSourceListener) { + if (Objects.isNull(eventSourceListener)) { log.error("param error:Zhipu ChatEventSourceListener cannot be empty"); throw new ParamBusinessException(); } - log.info("Zhipu Chat AI, prompt:{}", chatMessages.get(chatMessages.size() - 1).getContent()); + completionsOptions.setModel(this.model); try { - String requestId = String.valueOf(System.currentTimeMillis()); - // 建议直接查看demo包代码,这里更新可能不及时 - ZhipuChatCompletionsOptions completionsOptions = ZhipuChatCompletionsOptions.builder() - .requestId(requestId) - .stream(true) - .sseFormat("data") - .model(this.model) - .toolChoice("auto") - .prompt(chatMessages) - .tools(Arrays.asList( - Tool.builder() - .type("function") - .function(Function.builder() - .name("get_table_columns") - .description("获取指定表的字段名,类型") - .parameters(Parameters.builder() - .type("object") - .properties(ImmutableMap.builder() - .put("table_name", Property.builder() - .type("string") - .description("表名,例如```User```") - .build()) - .build()) - .required(Arrays.asList("table_name")) - .build()) - .build()) - .build())) - .build(); + ObjectMapper mapper = new ObjectMapper(); mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); String requestBody = mapper.writeValueAsString(completionsOptions); - String url = this.apiHost + "/" + this.model + "/" + "sse-invoke"; + String url = this.apiHost; EventSource.Factory factory = EventSources.createFactory(this.okHttpClient); Request request = new Request.Builder() .url(url) diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/listener/ZhipuChatAIEventSourceListener.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/listener/ZhipuChatAIEventSourceListener.java index a8b1ae016..5fd65b128 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/listener/ZhipuChatAIEventSourceListener.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/listener/ZhipuChatAIEventSourceListener.java @@ -1,22 +1,19 @@ package ai.chat2db.server.web.api.controller.ai.zhipu.listener; +import ai.chat2db.server.tools.common.model.LoginUser; import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage; -import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletions; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.unfbx.chatgpt.entity.chat.Message; -import lombok.SneakyThrows; +import ai.chat2db.server.web.api.controller.ai.openai.listener.OpenAIEventSourceListener; +import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest; +import ai.chat2db.server.web.api.controller.ai.utils.PromptService; +import ai.chat2db.server.web.api.controller.ai.zhipu.client.ZhipuChatAIClient; +import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletionsOptions; import lombok.extern.slf4j.Slf4j; -import okhttp3.Response; -import okhttp3.ResponseBody; -import okhttp3.sse.EventSource; -import okhttp3.sse.EventSourceListener; -import org.apache.commons.lang3.StringUtils; -import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; -import java.io.IOException; +import java.util.List; import java.util.Objects; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + /** * 描述:OpenAIEventSourceListener * @@ -24,111 +21,25 @@ * @date 2023-02-22 */ @Slf4j -public class ZhipuChatAIEventSourceListener extends EventSourceListener { - - private SseEmitter sseEmitter; - - private ObjectMapper mapper = new ObjectMapper().disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES); - - public ZhipuChatAIEventSourceListener(SseEmitter sseEmitter) { - this.sseEmitter = sseEmitter; - } - - /** - * {@inheritDoc} - */ - @Override - public void onOpen(EventSource eventSource, Response response) { - log.info("Zhipu Chat Sse connecting..."); - } - - /** - * {@inheritDoc} - */ - @SneakyThrows - @Override - public void onEvent(EventSource eventSource, String id, String type, String data) { - log.info("Zhipu Chat AI response data:{}", data); - if (data.equals("[DONE]")) { - log.info("Zhipu Chat AI closed"); - sseEmitter.send(SseEmitter.event() - .id("[DONE]") - .data("[DONE]") - .reconnectTime(3000)); - sseEmitter.complete(); - return; - } - - ZhipuChatCompletions chatCompletions = mapper.readValue(data, ZhipuChatCompletions.class); - String text = chatCompletions.getData(); - if (Objects.isNull(text)) { - for (FastChatMessage message : chatCompletions.getBody().getChoices()) { - if (message != null && message.getContent() != null) { - text = message.getContent(); - } - } - } - - Message message = new Message(); - message.setContent(text); - sseEmitter.send(SseEmitter.event() - .id(null) - .data(message) - .reconnectTime(3000)); +public class ZhipuChatAIEventSourceListener extends OpenAIEventSourceListener { + + public ZhipuChatAIEventSourceListener(SseEmitter sseEmitter, PromptService promptService, + ChatQueryRequest queryRequest, LoginUser loginUser) { + super(sseEmitter, promptService, queryRequest, loginUser); } - @Override - public void onClosed(EventSource eventSource) { - try { - sseEmitter.send(SseEmitter.event() - .id("[DONE]") - .data("[DONE]")); - } catch (IOException e) { - throw new RuntimeException(e); - } - sseEmitter.complete(); - log.info("ZhipuChatAI close sse connection..."); - } @Override - public void onFailure(EventSource eventSource, Throwable t, Response response) { - try { - if (Objects.isNull(response)) { - String message = t.getMessage(); - Message sseMessage = new Message(); - sseMessage.setContent(message); - sseEmitter.send(SseEmitter.event() - .id("[ERROR]") - .data(sseMessage)); - sseEmitter.send(SseEmitter.event() - .id("[DONE]") - .data("[DONE]")); - sseEmitter.complete(); - return; - } - ResponseBody body = response.body(); - String bodyString = Objects.nonNull(t) ? t.getMessage() : ""; - if (Objects.nonNull(body)) { - bodyString = body.string(); - if (StringUtils.isBlank(bodyString) && Objects.nonNull(t)) { - bodyString = t.getMessage(); - } - log.error("Zhipu Chat AI sse response:{}", bodyString); - } else { - log.error("Zhipu Chat AI sse response:{},error:{}", response, t); - } - eventSource.cancel(); - Message message = new Message(); - message.setContent("Zhipu Chat AI error:" + bodyString); - sseEmitter.send(SseEmitter.event() - .id("[ERROR]") - .data(message)); - sseEmitter.send(SseEmitter.event() - .id("[DONE]") - .data("[DONE]")); - sseEmitter.complete(); - } catch (Exception exception) { - log.error("Zhipu Chat AI send data error:", exception); - } + public void functionCall(String prompt){ + Long uid = loginUser.getId(); + List messages = promptService.getFastChatMessage(Objects.toString(uid), prompt); + String requestId = String.valueOf(System.currentTimeMillis()); + ZhipuChatCompletionsOptions completionsOptions = ZhipuChatCompletionsOptions.builder() + .requestId(requestId) + .stream(true) + .toolChoice("auto") + .messages(messages) + .build(); + ZhipuChatAIClient.getInstance().streamCompletions(completionsOptions, this); } } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/model/ZhipuChatCompletionsOptions.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/model/ZhipuChatCompletionsOptions.java index 4bcc82d4b..06c16bd07 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/model/ZhipuChatCompletionsOptions.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/model/ZhipuChatCompletionsOptions.java @@ -5,12 +5,12 @@ import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage; import com.fasterxml.jackson.annotation.JsonProperty; +import com.unfbx.chatgpt.entity.chat.tool.Tools; import lombok.Builder; import lombok.Data; import java.util.List; -import java.util.Map; /** * The configuration information for a chat completions request. Completions support a wide variety of tasks and @@ -24,12 +24,9 @@ public final class ZhipuChatCompletionsOptions { private String requestId; // sse-params - @JsonProperty(value = "incremental") + @JsonProperty(value = "stream") private Boolean stream = true; - @JsonProperty(value = "sseFormat") - private String sseFormat = "data"; - /* * The collection of context messages associated with this chat completions request. @@ -37,8 +34,8 @@ public final class ZhipuChatCompletionsOptions { * the behavior of the assistant, followed by alternating messages between the User and * Assistant roles. */ - @JsonProperty(value = "prompt") - private List prompt; + @JsonProperty(value = "messages") + private List messages; // @@ -57,52 +54,6 @@ public final class ZhipuChatCompletionsOptions { private String toolChoice; // 工具选择策略 @JsonProperty(value = "tools") - private List tools; // 工具列表 - - // 工具类 - @Data - @Builder - public static class Tool { - @JsonProperty(value = "type") - private String type; - - @JsonProperty(value = "function") - private Function function; - - @Data - @Builder - public static class Function { - @JsonProperty(value = "name") - private String name; - - @JsonProperty(value = "description") - private String description; - - @JsonProperty(value = "parameters") - private Parameters parameters; - - @Data - @Builder - public static class Parameters { - @JsonProperty(value = "type") - private String type; - - @JsonProperty(value = "properties") - private Map properties; - - @JsonProperty(value = "required") - private List required; - } - - @Data - @Builder - public static class Property { - @JsonProperty(value = "type") - private String type; + private List tools; // 工具列表 - @JsonProperty(value = "description") - private String description; - } - } - } } From c330315881d9870a02519bd1047a7135fe39391c Mon Sep 17 00:00:00 2001 From: hejianjun <942156265@qq.com> Date: Wed, 21 Feb 2024 14:44:01 +0800 Subject: [PATCH 06/16] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=97=A5=E5=BF=97?= =?UTF-8?q?=E6=89=93=E5=8D=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../web/api/controller/ai/ChatController.java | 1 + .../listener/OpenAIEventSourceListener.java | 39 +++++++++++++++---- .../controller/ai/utils/PromptService.java | 11 ++++-- .../ZhipuChatAIEventSourceListener.java | 11 +++++- 4 files changed, 49 insertions(+), 13 deletions(-) diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java index 0222b3b05..2a106fc10 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java @@ -386,6 +386,7 @@ private SseEmitter chatWithFastChatAi(ChatQueryRequest queryRequest, SseEmitter */ private SseEmitter chatWithZhipuChatAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { String prompt = promptService.buildAutoPrompt(queryRequest); + log.info("原始提示词{}",prompt); List messages = promptService.getFastChatMessage(uid, prompt); buildSseEmitter(sseEmitter, uid); diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java index 6bd49a387..2d63e9f4c 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java @@ -8,6 +8,8 @@ import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest; import ai.chat2db.server.web.api.controller.ai.response.ChatCompletionResponse; import ai.chat2db.server.web.api.controller.ai.utils.PromptService; + +import com.alibaba.fastjson2.JSONArray; import com.alibaba.fastjson2.JSONObject; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; @@ -109,12 +111,16 @@ private static String mergeStrings(String str1, String str2) { } } + + public String getName() { + return "OpenAI"; + } /** * {@inheritDoc} */ @Override public void onOpen(EventSource eventSource, Response response) { - log.info("OpenAI建立sse连接..."); + log.info("{}建立sse连接...",getName()); } @@ -124,16 +130,32 @@ public void functionCall(String prompt){ messages.add(currentMessage); OpenAIClient.getInstance().streamChatCompletion(messages, this); } + + + public void handleTableNames(List tableNames,Object instance){ + if(instance instanceof JSONArray){ + ((JSONArray)instance).forEach(tableName->{ + handleTableNames(tableNames,tableName); + }); + }else if (instance instanceof JSONObject) { + ((JSONObject)instance).entrySet().forEach(entrySet->{ + handleTableNames(tableNames,entrySet.getValue()); + }); + }else if (instance instanceof String) { + tableNames.add((String)instance); + } + } /** * {@inheritDoc} */ @SneakyThrows @Override public void onEvent(EventSource eventSource, String id, String type, String data) { - log.info("OpenAI返回数据:{}", data); + String scheme = getName(); + log.info("{}返回数据:{}",scheme,data); if (data.equals("[DONE]")) { if (toolCalls.isEmpty()) { - log.info("OpenAI返回数据结束了"); + log.info("{}返回数据结束了",scheme); sseEmitter.send(SseEmitter.event() .id("[DONE]") .data("[DONE]") @@ -149,7 +171,7 @@ public void onEvent(EventSource eventSource, String id, String type, String data String functionName = function.getName(); if ("get_table_columns".equals(functionName)) { JSONObject arguments = JSONObject.parse(function.getArguments()); - tableNames.add(arguments.getString("table_name")); + handleTableNames(tableNames,arguments.get("table_names")); } } } @@ -162,7 +184,7 @@ public void onEvent(EventSource eventSource, String id, String type, String data String prompt = promptService.buildPrompt(queryRequest); Dbutils.removeSession(); prompt = prompt.replaceAll("#", ""); - log.info(prompt); + log.info("{} 新提示词 :{}",scheme,prompt); functionCall(prompt); toolCalls.clear(); return; @@ -196,6 +218,7 @@ public void onClosed(EventSource eventSource) { @Override public void onFailure(EventSource eventSource, Throwable t, Response response) { + String scheme = getName(); try { if (Objects.isNull(response)) { String message = t.getMessage(); @@ -217,9 +240,9 @@ public void onFailure(EventSource eventSource, Throwable t, Response response) { String bodyString = null; if (Objects.nonNull(body)) { bodyString = body.string(); - log.error("OpenAI sse连接异常data:{}", bodyString, t); + log.error("{} sse连接异常data:{}",scheme, bodyString, t); } else { - log.error("OpenAI sse连接异常data:{}", response, t); + log.error("{} sse连接异常data:{}",scheme, response, t); } eventSource.cancel(); Message message = new Message(); @@ -232,7 +255,7 @@ public void onFailure(EventSource eventSource, Throwable t, Response response) { .data("[DONE]")); sseEmitter.complete(); } catch (Exception exception) { - log.error("发送数据异常:", exception); + log.error("{}发送数据异常:", scheme,exception); } } } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java index 5a1ae63d2..d3ac972a9 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java @@ -304,7 +304,8 @@ public String buildAutoPrompt(ChatQueryRequest queryRequest) { String dataSourceType = queryDatabaseType(queryRequest); String properties = ""; if (CollectionUtils.isNotEmpty(queryRequest.getTableNames())) { - properties = queryRequest.getTableNames().stream().collect(Collectors.joining(",")); + TableQueryParam queryParam = chatConverter.chat2tableQuery(queryRequest); + properties = buildTableColumn(queryParam, queryRequest.getTableNames()); } else { properties = queryDatabaseTables(queryRequest); } @@ -375,13 +376,15 @@ public String queryDatabaseTables(ChatQueryRequest queryRequest) { public static ToolsFunction getToolsFunction(){ return ToolsFunction.builder() .name("get_table_columns") - .description("获取指定表的字段名,类型") + .description("获取指定表的属性") .parameters(Parameters.builder() .type("object") .properties(ImmutableMap.builder() - .put("table_name", ImmutableMap.builder() - .put("type", "string") + .put("table_names", ImmutableMap.builder() .put("description", "表名,例如```User```") + .put("type", "array") + .put("items", ImmutableMap.of("type", "string")) + .put("uniqueItems", true) .build()) .build()) .required(List.of("table_name")) diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/listener/ZhipuChatAIEventSourceListener.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/listener/ZhipuChatAIEventSourceListener.java index 5fd65b128..abc07a8e1 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/listener/ZhipuChatAIEventSourceListener.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/listener/ZhipuChatAIEventSourceListener.java @@ -14,6 +14,9 @@ import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; +import com.unfbx.chatgpt.entity.chat.tool.Tools; +import com.unfbx.chatgpt.entity.chat.tool.ToolsFunction; + /** * 描述:OpenAIEventSourceListener * @@ -28,18 +31,24 @@ public ZhipuChatAIEventSourceListener(SseEmitter sseEmitter, PromptService promp super(sseEmitter, promptService, queryRequest, loginUser); } + @Override + public String getName(){ + return "Zhipu"; + } @Override public void functionCall(String prompt){ Long uid = loginUser.getId(); List messages = promptService.getFastChatMessage(Objects.toString(uid), prompt); String requestId = String.valueOf(System.currentTimeMillis()); + ToolsFunction function = PromptService.getToolsFunction(); ZhipuChatCompletionsOptions completionsOptions = ZhipuChatCompletionsOptions.builder() .requestId(requestId) .stream(true) .toolChoice("auto") + .tools(List.of(new Tools(Tools.Type.FUNCTION.getName(), function))) .messages(messages) - .build(); + .build(); ZhipuChatAIClient.getInstance().streamCompletions(completionsOptions, this); } } From e5bbcc2115c30f0d2d549405d568c626953b70cb Mon Sep 17 00:00:00 2001 From: hejianjun <942156265@qq.com> Date: Thu, 22 Feb 2024 09:59:02 +0800 Subject: [PATCH 07/16] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=9F=A5=E8=AF=A2?= =?UTF-8?q?=E6=89=80=E6=9C=89=E8=A1=A8=E4=B8=8D=E6=98=BE=E7=A4=BA=E6=B3=A8?= =?UTF-8?q?=E9=87=8A=E5=90=8D=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../domain/core/impl/TableServiceImpl.java | 62 +++++++++---------- .../api/controller/ai/enums/PromptType.java | 4 +- .../controller/ai/utils/PromptService.java | 4 +- 3 files changed, 33 insertions(+), 37 deletions(-) diff --git a/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/impl/TableServiceImpl.java b/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/impl/TableServiceImpl.java index 1454772d8..a5aaea130 100644 --- a/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/impl/TableServiceImpl.java +++ b/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/impl/TableServiceImpl.java @@ -419,44 +419,40 @@ public ListResult queryTables(TablePageQueryParam param) { private long addDBCache(Long dataSourceId, String databaseName, String schemaName, long version) { String key = getTableKey(dataSourceId, databaseName, schemaName); - Connection connection = Chat2DBContext.getConnection(); long n = 0; - try (ResultSet resultSet = connection.getMetaData().getTables(databaseName, schemaName, null, - new String[]{"TABLE", "SYSTEM TABLE"})) { - List cacheDOS = new ArrayList<>(); - while (resultSet.next()) { - TableCacheDO tableCacheDO = new TableCacheDO(); - tableCacheDO.setDatabaseName(databaseName); - tableCacheDO.setSchemaName(schemaName); - tableCacheDO.setTableName(resultSet.getString("TABLE_NAME")); - tableCacheDO.setExtendInfo(resultSet.getString("REMARKS")); - tableCacheDO.setDataSourceId(dataSourceId); - tableCacheDO.setVersion(version); - tableCacheDO.setKey(key); - cacheDOS.add(tableCacheDO); - if (cacheDOS.size() >= 500) { - getTableCacheMapper().batchInsert(cacheDOS); - cacheDOS = new ArrayList<>(); - } - n++; - } - if (!CollectionUtils.isEmpty(cacheDOS)) { + MetaData metaSchema = Chat2DBContext.getMetaData(); + List
tables = metaSchema.tables(connection, databaseName, schemaName, null); + List cacheDOS = new ArrayList<>(); + for(Table table : tables){ + TableCacheDO tableCacheDO = new TableCacheDO(); + tableCacheDO.setDatabaseName(databaseName); + tableCacheDO.setSchemaName(schemaName); + tableCacheDO.setTableName(table.getName()); + tableCacheDO.setExtendInfo(table.getComment()); + tableCacheDO.setDataSourceId(dataSourceId); + tableCacheDO.setVersion(version); + tableCacheDO.setKey(key); + cacheDOS.add(tableCacheDO); + if (cacheDOS.size() >= 500) { getTableCacheMapper().batchInsert(cacheDOS); + cacheDOS = new ArrayList<>(); } - LambdaQueryWrapper q = new LambdaQueryWrapper(); - q.eq(TableCacheDO::getDataSourceId, dataSourceId); - q.lt(TableCacheDO::getVersion, version); - if (StringUtils.isNotBlank(databaseName)) { - q.eq(TableCacheDO::getDatabaseName, databaseName); - } - if (StringUtils.isNotBlank(schemaName)) { - q.eq(TableCacheDO::getSchemaName, schemaName); - } - getTableCacheMapper().delete(q); - } catch (SQLException e) { - throw new RuntimeException(e); + n++; + } + if (!CollectionUtils.isEmpty(cacheDOS)) { + getTableCacheMapper().batchInsert(cacheDOS); + } + LambdaQueryWrapper q = new LambdaQueryWrapper(); + q.eq(TableCacheDO::getDataSourceId, dataSourceId); + q.lt(TableCacheDO::getVersion, version); + if (StringUtils.isNotBlank(databaseName)) { + q.eq(TableCacheDO::getDatabaseName, databaseName); + } + if (StringUtils.isNotBlank(schemaName)) { + q.eq(TableCacheDO::getSchemaName, schemaName); } + getTableCacheMapper().delete(q); return n; } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/enums/PromptType.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/enums/PromptType.java index 0135e5ea6..f6e833a01 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/enums/PromptType.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/enums/PromptType.java @@ -41,9 +41,9 @@ public enum PromptType implements BaseEnum { /** - * function call + * GET_TABLE_COLUMNS */ - FUNCTION_CALL("获取指定表的字段名,类型"), + GET_TABLE_COLUMNS("获取指定表的属性"), ; final String description; diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java index d3ac972a9..3cb89d1f0 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java @@ -314,7 +314,7 @@ public String buildAutoPrompt(ChatQueryRequest queryRequest) { : queryRequest.getPromptType(); PromptType pType = EasyEnumUtils.getEnum(PromptType.class, promptType); if (pType.equals(PromptType.NL_2_SQL)) { - pType = PromptType.FUNCTION_CALL; + pType = PromptType.GET_TABLE_COLUMNS; } String ext = StringUtils.isNotBlank(queryRequest.getExt()) ? queryRequest.getExt() : ""; @@ -376,7 +376,7 @@ public String queryDatabaseTables(ChatQueryRequest queryRequest) { public static ToolsFunction getToolsFunction(){ return ToolsFunction.builder() .name("get_table_columns") - .description("获取指定表的属性") + .description(PromptType.GET_TABLE_COLUMNS.getDescription()) .parameters(Parameters.builder() .type("object") .properties(ImmutableMap.builder() From 1a721f085e8810cd6640fd4095ceb4338f8e40c5 Mon Sep 17 00:00:00 2001 From: hejianjun <942156265@qq.com> Date: Thu, 22 Feb 2024 15:19:06 +0800 Subject: [PATCH 08/16] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=A4=96=E9=94=AE?= =?UTF-8?q?=E6=98=A0=E5=B0=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../controller/ai/utils/PromptService.java | 69 ++++++++++++++++--- 1 file changed, 59 insertions(+), 10 deletions(-) diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java index 3cb89d1f0..b0ba984b1 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java @@ -3,10 +3,14 @@ import java.io.IOException; import java.math.BigDecimal; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.stream.Collectors; +import org.apache.commons.collections.MapUtils; import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; @@ -49,6 +53,7 @@ import ai.chat2db.server.web.api.util.ApplicationContextUtil; import ai.chat2db.spi.MetaData; import ai.chat2db.spi.model.Table; +import ai.chat2db.spi.model.TableColumn; import ai.chat2db.spi.sql.Chat2DBContext; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; @@ -167,17 +172,16 @@ public String buildTableColumn(TableQueryParam tableQueryParam, if (CollectionUtils.isEmpty(tableNames)) { return ""; } - List schemaContent = Lists.newArrayList(); try { - schemaContent = tableNames.stream().map(tableName -> { + return tableNames.stream().map(tableName -> { tableQueryParam.setTableName(tableName); return queryTableDdl(tableName, tableQueryParam); - }).collect(Collectors.toList()); + }).collect(Collectors.joining(";\n")); } catch (Exception exception) { log.error("query table error, do nothing"); } - return JSON.toJSONString(schemaContent); + return ""; } /** @@ -313,10 +317,9 @@ public String buildAutoPrompt(ChatQueryRequest queryRequest) { String promptType = StringUtils.isBlank(queryRequest.getPromptType()) ? PromptType.NL_2_SQL.getCode() : queryRequest.getPromptType(); PromptType pType = EasyEnumUtils.getEnum(PromptType.class, promptType); - if (pType.equals(PromptType.NL_2_SQL)) { + if (StringUtils.isNotEmpty(properties)) { pType = PromptType.GET_TABLE_COLUMNS; } - String ext = StringUtils.isNotBlank(queryRequest.getExt()) ? queryRequest.getExt() : ""; String schemaProperty = StringUtils.isNotEmpty(properties) ? String.format( "### 请根据以下table properties和SQL input%s. %s\n#\n### %s SQL tables:\n#\n# " @@ -352,6 +355,32 @@ public String queryDatabaseType(ChatQueryRequest queryRequest) { return dataSourceType; } + + /** + * 根据给定的表对象找出所有可能的外键列 + * @return 外键列名列表 + */ + public static List findPossibleForeignKeys(List columns) { + List foreignKeys = new ArrayList<>(); + for (TableColumn column : columns) { + String columnName = column.getName(); + // 假设TableColumn类有一个getTableName方法可以获取列所属的表名 + String tableName = column.getTableName(); + Boolean primaryKey = column.getPrimaryKey(); + + // 检查列名是否符合`关联表_id`的格式,并且列名前半部分不等于表名 + if (columnName != null && columnName.matches(".+_id") && Boolean.FALSE.equals(primaryKey)) { + // 从列名中移除"_id"以获取可能的关联表名 + String potentialForeignKeyTable = columnName.substring(0, columnName.length() - 3); + + if (!potentialForeignKeyTable.equals(tableName)) { + foreignKeys.add(columnName); + } + } + } + return foreignKeys; + } + /** * query database schema * @@ -363,10 +392,30 @@ public String queryDatabaseTables(ChatQueryRequest queryRequest) { MetaData metaSchema = Chat2DBContext.getMetaData(); try { List
tables = metaSchema.tables(Chat2DBContext.getConnection(), queryRequest.getDatabaseName(), queryRequest.getSchemaName(), null); - return tables.stream() - .map(table -> StringUtils.isBlank(table.getComment()) ? table.getName() - : table.getName() + "(" + table.getComment() + ")") - .collect(Collectors.joining(",")); + + return tables.stream().map(table -> { + StringBuilder sb = new StringBuilder(table.getName()); // 直接在初始化时加入表名 + String comment = table.getComment(); + List columns = metaSchema.columns(Chat2DBContext.getConnection(), queryRequest.getDatabaseName(), queryRequest.getSchemaName(), table.getName()); + List foreignKeys = findPossibleForeignKeys(columns); // 假设这个方法已经被定义 + + // 只有当有注释或外键时才添加额外信息 + if(StringUtils.isNotEmpty(comment) || !foreignKeys.isEmpty()){ + sb.append("(").append(comment); + + // 如果存在外键,添加外键信息 + if(!foreignKeys.isEmpty()){ + // 如果注释和外键都存在,先添加一个分隔符 + if(StringUtils.isNotEmpty(comment)) { + sb.append("; "); + } + sb.append("外键:").append(String.join(", ", foreignKeys)); // 优化外键的展示 + } + sb.append(")"); + } + return sb.toString(); // 在映射阶段直接转换为字符串 + }) + .collect(Collectors.joining(",")); } catch (Exception e) { log.error("query table error:{}, do nothing", e.getMessage()); return ""; From f6b4c36430244ebb61fb729f2c011f410122f662 Mon Sep 17 00:00:00 2001 From: hejianjun <942156265@qq.com> Date: Sun, 25 Feb 2024 11:05:26 +0800 Subject: [PATCH 09/16] =?UTF-8?q?=E5=8A=A0=E5=88=97=E7=BC=93=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../domain/core/impl/TableServiceImpl.java | 13 +++++++++++- .../controller/ai/utils/PromptService.java | 21 ++++++++++++++----- .../rdb/converter/RdbWebConverter.java | 9 ++++++++ 3 files changed, 37 insertions(+), 6 deletions(-) diff --git a/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/impl/TableServiceImpl.java b/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/impl/TableServiceImpl.java index a5aaea130..46cf1e033 100644 --- a/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/impl/TableServiceImpl.java +++ b/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/impl/TableServiceImpl.java @@ -339,6 +339,16 @@ public PageResult
pageQuery(TablePageQueryParam param, TableSelector sele t.setComment(tableCacheDO.getExtendInfo()); t.setSchemaName(tableCacheDO.getSchemaName()); t.setDatabaseName(tableCacheDO.getDatabaseName()); + if(Boolean.TRUE.equals(selector.getColumnList())){ + TableQueryParam tableQueryParam = new TableQueryParam(); + tableQueryParam.setDataSourceId(param.getDataSourceId()); + tableQueryParam.setDatabaseName(param.getDatabaseName()); + tableQueryParam.setSchemaName(param.getSchemaName()); + tableQueryParam.setTableName(tableCacheDO.getTableName()); + tableQueryParam.setRefresh(false); + List columns = queryColumns(tableQueryParam); + t.setColumnList(columns); + } tables.add(t); } } @@ -433,6 +443,7 @@ private long addDBCache(Long dataSourceId, String databaseName, String schemaNam tableCacheDO.setDataSourceId(dataSourceId); tableCacheDO.setVersion(version); tableCacheDO.setKey(key); + metaSchema.columns(connection, databaseName, schemaName, table.getName()); cacheDOS.add(tableCacheDO); if (cacheDOS.size() >= 500) { getTableCacheMapper().batchInsert(cacheDOS); @@ -476,7 +487,7 @@ private Long getLock(Long dataSourceId, String databaseName, String schemaName, } } else { long version = versionDO.getVersion() + 1; - LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper(); + LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper<>(); queryWrapper.eq(TableCacheVersionDO::getId, versionDO.getId()); queryWrapper.eq(TableCacheVersionDO::getVersion, versionDO.getVersion()); versionDO.setVersion(version); diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java index b0ba984b1..47c1e1e0b 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java @@ -27,12 +27,15 @@ import ai.chat2db.server.domain.api.model.Config; import ai.chat2db.server.domain.api.model.DataSource; import ai.chat2db.server.domain.api.param.ShowCreateTableParam; +import ai.chat2db.server.domain.api.param.TablePageQueryParam; import ai.chat2db.server.domain.api.param.TableQueryParam; +import ai.chat2db.server.domain.api.param.TableSelector; import ai.chat2db.server.domain.api.service.ConfigService; import ai.chat2db.server.domain.api.service.DataSourceService; import ai.chat2db.server.domain.api.service.TableService; import ai.chat2db.server.tools.base.enums.WhiteListTypeEnum; import ai.chat2db.server.tools.base.wrapper.result.DataResult; +import ai.chat2db.server.tools.base.wrapper.result.PageResult; import ai.chat2db.server.tools.common.util.EasyEnumUtils; import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect; import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient; @@ -45,6 +48,7 @@ import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatRole; import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest; import ai.chat2db.server.web.api.controller.ai.rest.client.RestAIClient; +import ai.chat2db.server.web.api.controller.rdb.converter.RdbWebConverter; import ai.chat2db.server.web.api.http.GatewayClientService; import ai.chat2db.server.web.api.http.model.TableSchema; import ai.chat2db.server.web.api.http.request.TableSchemaRequest; @@ -84,6 +88,10 @@ public class PromptService { private GatewayClientService gatewayClientService; + @Autowired + private RdbWebConverter rdbWebConverter; + + /** * 构建prompt * @@ -391,13 +399,16 @@ public static List findPossibleForeignKeys(List columns) { public String queryDatabaseTables(ChatQueryRequest queryRequest) { MetaData metaSchema = Chat2DBContext.getMetaData(); try { - List
tables = metaSchema.tables(Chat2DBContext.getConnection(), queryRequest.getDatabaseName(), queryRequest.getSchemaName(), null); - - return tables.stream().map(table -> { + TablePageQueryParam queryParam = rdbWebConverter.tablePageRequest2param(queryRequest); + TableSelector tableSelector = new TableSelector(); + tableSelector.setColumnList(true); + tableSelector.setIndexList(false); + PageResult
tables = tableService.pageQuery(queryParam,tableSelector); + return tables.getData().stream().map(table -> { StringBuilder sb = new StringBuilder(table.getName()); // 直接在初始化时加入表名 String comment = table.getComment(); - List columns = metaSchema.columns(Chat2DBContext.getConnection(), queryRequest.getDatabaseName(), queryRequest.getSchemaName(), table.getName()); - List foreignKeys = findPossibleForeignKeys(columns); // 假设这个方法已经被定义 + List columns = table.getColumnList(); + List foreignKeys = findPossibleForeignKeys(columns); // 只有当有注释或外键时才添加额外信息 if(StringUtils.isNotEmpty(comment) || !foreignKeys.isEmpty()){ diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/converter/RdbWebConverter.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/converter/RdbWebConverter.java index b04663dc9..5a37c352a 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/converter/RdbWebConverter.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/converter/RdbWebConverter.java @@ -3,6 +3,7 @@ import java.util.List; import ai.chat2db.server.domain.api.param.*; +import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest; import ai.chat2db.server.web.api.controller.data.source.vo.DatabaseVO; import ai.chat2db.server.web.api.controller.rdb.request.*; import ai.chat2db.server.web.api.controller.rdb.vo.ColumnVO; @@ -99,6 +100,14 @@ public abstract class RdbWebConverter { * @return */ public abstract SqlVO dto2vo(Sql dto); + + /** + * 参数转换 + * + * @param request + * @return + */ + public abstract TablePageQueryParam tablePageRequest2param(ChatQueryRequest request); /** * 参数转换 * From a00f94ffeca5380db660fc893afe662837dbf71d Mon Sep 17 00:00:00 2001 From: hejianjun <942156265@qq.com> Date: Sun, 25 Feb 2024 16:49:09 +0800 Subject: [PATCH 10/16] =?UTF-8?q?er=E5=9B=BE=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../controller/ai/utils/PromptService.java | 1 - .../api/controller/rdb/TableController.java | 37 ++++++++++++++++ .../java/ai/chat2db/spi/model/ErDiagram.java | 42 +++++++++++++++++++ 3 files changed, 79 insertions(+), 1 deletion(-) create mode 100644 chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/model/ErDiagram.java diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java index 47c1e1e0b..b70f8b058 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java @@ -397,7 +397,6 @@ public static List findPossibleForeignKeys(List columns) { * @throws IOException */ public String queryDatabaseTables(ChatQueryRequest queryRequest) { - MetaData metaSchema = Chat2DBContext.getMetaData(); try { TablePageQueryParam queryParam = rdbWebConverter.tablePageRequest2param(queryRequest); TableSelector tableSelector = new TableSelector(); diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/TableController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/TableController.java index 131a6bf6c..016f3bdad 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/TableController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/TableController.java @@ -250,4 +250,41 @@ public ActionResult delete(@Valid @RequestBody TableDeleteRequest request) { DropParam dropParam = rdbWebConverter.tableDelete2dropParam(request); return tableService.drop(dropParam); } + + + /** + * 查询ER图 + * + * @param request + * @return + */ + @GetMapping("/er-diagram") + public DataResult erDiagram(@Valid TableBriefQueryRequest request) { + TablePageQueryParam queryParam = rdbWebConverter.tablePageRequest2param(request); + TableSelector tableSelector = new TableSelector(); + tableSelector.setColumnList(true); + tableSelector.setIndexList(false); + PageResult
tableDTOPageResult = tableService.pageQuery(queryParam, tableSelector); + new ArrayList<>(); + List entityList = tableDTOPageResult.getData().stream().map(table -> { + ErDiagram.Node entity = new ErDiagram.Node(table.getName(), + StringUtils.defaultIfBlank(table.getComment(), table.getName())); + return entity; + }).collect(Collectors.toList()); + List relationList = tableDTOPageResult.getData().stream().flatMap(table -> { + return table.getColumnList().stream().filter(column -> { + String columnName = column.getName(); + Boolean primaryKey = column.getPrimaryKey(); + return columnName != null && columnName.matches(".+_id") && Boolean.FALSE.equals(primaryKey); + }).map(column -> { + String columnName = column.getName(); + String tableName = column.getTableName(); + // 从列名中移除"_id"以获取可能的关联表名 + String potentialForeignKeyTable = columnName.substring(0, columnName.length() - 3); + ErDiagram.Edge relation = new ErDiagram.Edge(columnName,tableName, potentialForeignKeyTable,column.getComment()); + return relation; + }); + }).collect(Collectors.toList()); + return DataResult.of(new ErDiagram(entityList, relationList)); + } } diff --git a/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/model/ErDiagram.java b/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/model/ErDiagram.java new file mode 100644 index 000000000..67a2f9920 --- /dev/null +++ b/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/model/ErDiagram.java @@ -0,0 +1,42 @@ +package ai.chat2db.spi.model; + +import java.util.List; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.experimental.SuperBuilder; + +/** + * er图 + */ +@Data +@SuperBuilder +@NoArgsConstructor +@AllArgsConstructor +public class ErDiagram { + + private List nodes; + private List edges; + + @Data + @SuperBuilder + @NoArgsConstructor + @AllArgsConstructor + public static class Node { + private String id; + private String label; + } + + @Data + @SuperBuilder + @NoArgsConstructor + @AllArgsConstructor + public static class Edge { + private String id; + private String source; + private String target; + private String label; + } + +} From 63cbd75f034b2e275c3102529ff4b526949d6258 Mon Sep 17 00:00:00 2001 From: hejianjun <942156265@qq.com> Date: Wed, 28 Feb 2024 10:40:56 +0800 Subject: [PATCH 11/16] =?UTF-8?q?=E5=88=A0=E9=99=A4=E5=A4=9A=E4=BD=99?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../server/web/api/controller/rdb/TableController.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/TableController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/TableController.java index 016f3bdad..f2ce64dfd 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/TableController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/TableController.java @@ -18,17 +18,18 @@ import ai.chat2db.server.web.api.controller.rdb.vo.SqlVO; import ai.chat2db.server.web.api.controller.rdb.vo.TableVO; import ai.chat2db.spi.model.*; -import ai.chat2db.spi.sql.Chat2DBContext; -import ai.chat2db.spi.sql.ConnectInfo; import com.google.common.collect.Lists; import jakarta.validation.Valid; import lombok.extern.slf4j.Slf4j; + +import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.*; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.stream.Collectors; @Slf4j @ConnectionInfoAspect @@ -265,7 +266,6 @@ public DataResult erDiagram(@Valid TableBriefQueryRequest request) { tableSelector.setColumnList(true); tableSelector.setIndexList(false); PageResult
tableDTOPageResult = tableService.pageQuery(queryParam, tableSelector); - new ArrayList<>(); List entityList = tableDTOPageResult.getData().stream().map(table -> { ErDiagram.Node entity = new ErDiagram.Node(table.getName(), StringUtils.defaultIfBlank(table.getComment(), table.getName())); From 79d6ef3771375c4205582ac30dcf67d1454990e0 Mon Sep 17 00:00:00 2001 From: hejianjun <942156265@qq.com> Date: Wed, 6 Mar 2024 17:34:12 +0800 Subject: [PATCH 12/16] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E7=BD=91=E5=85=B3?= =?UTF-8?q?=E5=92=8Cazure=E6=8E=A5=E5=8F=A3=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 +- chat2db-gateway/pom.xml | 80 ++++++++++ .../main/java/com/hejianjun/Application.java | 18 +++ .../hejianjun/ElasticsearchClientConfig.java | 41 +++++ .../java/com/hejianjun/SchemaDocument.java | 14 ++ .../com/hejianjun/TableSchemaController.java | 56 +++++++ .../com/hejianjun/TableSchemaRequest.java | 34 +++++ .../com/hejianjun/TableSchemaService.java | 107 +++++++++++++ .../aspect/GatewayClientServiceAspect.java | 33 ++++ .../web/api/controller/ai/ChatController.java | 19 ++- .../azure/client/AzureOpenAiStreamClient.java | 12 +- .../AzureOpenAIEventSourceListener.java | 142 ++++-------------- .../model/AzureChatCompletionsOptions.java | 9 ++ .../listener/OpenAIEventSourceListener.java | 5 + .../controller/ai/utils/PromptService.java | 1 + .../ZhipuChatAIEventSourceListener.java | 7 +- 16 files changed, 452 insertions(+), 129 deletions(-) create mode 100644 chat2db-gateway/pom.xml create mode 100644 chat2db-gateway/src/main/java/com/hejianjun/Application.java create mode 100644 chat2db-gateway/src/main/java/com/hejianjun/ElasticsearchClientConfig.java create mode 100644 chat2db-gateway/src/main/java/com/hejianjun/SchemaDocument.java create mode 100644 chat2db-gateway/src/main/java/com/hejianjun/TableSchemaController.java create mode 100644 chat2db-gateway/src/main/java/com/hejianjun/TableSchemaRequest.java create mode 100644 chat2db-gateway/src/main/java/com/hejianjun/TableSchemaService.java create mode 100644 chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/aspect/GatewayClientServiceAspect.java diff --git a/.gitignore b/.gitignore index f8a263aaa..793134d36 100644 --- a/.gitignore +++ b/.gitignore @@ -28,4 +28,5 @@ package-lock.json /chat2db-server/ali-dbhub-server-domain/ali-dbhub-server-domain-support/src/main/resources/lib/* /chat2db-server/ali-dbhub-server-domain/ali-dbhub-server-domain-support/lib/* /lib -/out/* \ No newline at end of file +/out/* +/chat2db-gateway/target diff --git a/chat2db-gateway/pom.xml b/chat2db-gateway/pom.xml new file mode 100644 index 000000000..1e3e4b89a --- /dev/null +++ b/chat2db-gateway/pom.xml @@ -0,0 +1,80 @@ + + + 4.0.0 + + com.hejianjun + chat2db-gateway + 0.0.1-SNAPSHOT + jar + + chat2db-gateway + Project for chat2db-gateway + + + org.springframework.boot + spring-boot-starter-parent + 2.6.7 + + + + + 11 + 8.12.2 + 2.0.1 + + + + + + org.springframework.boot + spring-boot-starter-web + + + + + co.elastic.clients + elasticsearch-java + 8.12.2 + + + jakarta.json + jakarta.json-api + ${jakarta-json.version} + + + com.fasterxml.jackson.core + jackson-databind + 2.12.3 + + + + + + + + org.springframework.boot + spring-boot-starter-test + test + + + + + org.projectlombok + lombok + true + + + + + + + + org.springframework.boot + spring-boot-maven-plugin + + + + + diff --git a/chat2db-gateway/src/main/java/com/hejianjun/Application.java b/chat2db-gateway/src/main/java/com/hejianjun/Application.java new file mode 100644 index 000000000..23f0f58f3 --- /dev/null +++ b/chat2db-gateway/src/main/java/com/hejianjun/Application.java @@ -0,0 +1,18 @@ +package com.hejianjun; + +import lombok.extern.slf4j.Slf4j; +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; + +@Slf4j +@SpringBootApplication +public class Application { + /** + * 主程序入口 + * @param args 命令行参数 + */ + public static void main(String[] args) { + SpringApplication.run(Application.class, args); + } + +} diff --git a/chat2db-gateway/src/main/java/com/hejianjun/ElasticsearchClientConfig.java b/chat2db-gateway/src/main/java/com/hejianjun/ElasticsearchClientConfig.java new file mode 100644 index 000000000..e3d997bc1 --- /dev/null +++ b/chat2db-gateway/src/main/java/com/hejianjun/ElasticsearchClientConfig.java @@ -0,0 +1,41 @@ +package com.hejianjun; + +import co.elastic.clients.elasticsearch.ElasticsearchClient; +import co.elastic.clients.json.jackson.JacksonJsonpMapper; +import co.elastic.clients.transport.ElasticsearchTransport; +import co.elastic.clients.transport.rest_client.RestClientTransport; +import org.apache.http.Header; +import org.apache.http.HttpHost; +import org.apache.http.message.BasicHeader; +import org.elasticsearch.client.RestClient; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@Configuration +public class ElasticsearchClientConfig { + + String apiKey = "DVaOd3B6Rl*9sWUeTIHO"; + + /** + * 创建ElasticsearchClient实例 + * + * @return ElasticsearchClient实例 + */ + @Bean + public ElasticsearchClient elasticsearchClient() { + // 初始化低级客户端 + RestClient restClient = RestClient.builder(new HttpHost("localhost", 9200)) + .setDefaultHeaders(new Header[]{ + new BasicHeader("Authorization", "ApiKey " + apiKey) + }) + .build(); + + // 使用低级客户端创建传输层 + ElasticsearchTransport transport = new RestClientTransport( + restClient, new JacksonJsonpMapper()); + + // 创建ElasticsearchClient实例 + return new ElasticsearchClient(transport); + } + +} diff --git a/chat2db-gateway/src/main/java/com/hejianjun/SchemaDocument.java b/chat2db-gateway/src/main/java/com/hejianjun/SchemaDocument.java new file mode 100644 index 000000000..b077e6508 --- /dev/null +++ b/chat2db-gateway/src/main/java/com/hejianjun/SchemaDocument.java @@ -0,0 +1,14 @@ +package com.hejianjun; + +import lombok.AllArgsConstructor; +import lombok.Data; + +import java.math.BigDecimal; +import java.util.List; + +@Data +@AllArgsConstructor +public class SchemaDocument { + private String schema; + private List vector; +} diff --git a/chat2db-gateway/src/main/java/com/hejianjun/TableSchemaController.java b/chat2db-gateway/src/main/java/com/hejianjun/TableSchemaController.java new file mode 100644 index 000000000..93883b9d6 --- /dev/null +++ b/chat2db-gateway/src/main/java/com/hejianjun/TableSchemaController.java @@ -0,0 +1,56 @@ +package com.hejianjun; + +import co.elastic.clients.json.JsonData; +import lombok.AllArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +@Slf4j +@RestController +@AllArgsConstructor +@RequestMapping("/api/client/milvus") +public class TableSchemaController { + + private final TableSchemaService service; + + + /** + * 保存表结构 + * @param request 表结构请求对象 + * @return 保存成功的文档ID + */ + @PostMapping("/schema/save") + public ResponseEntity> saveSchema(@RequestBody TableSchemaRequest request) { + try { + List documentId = service.saveSchemaBatch(request); + return ResponseEntity.ok(documentId); + } catch (IOException e) { + log.error("保存表结构时发生错误", e); + return ResponseEntity.internalServerError().build(); + } + } + + /** + * 通过向量搜索表结构 + * @param request 表结构搜索请求 + * @return 搜索结果列表 + */ + @PostMapping("/schema/search") + public ResponseEntity searchByVector(@RequestBody TableSchemaRequest request) { + try { + TableSchemaRequest tableSchemaRequest = service.searchByVector(request); + return ResponseEntity.ok(tableSchemaRequest); + } catch (IOException e) { + log.error("Error searching schema", e); + return ResponseEntity.internalServerError().build(); + } + } +} diff --git a/chat2db-gateway/src/main/java/com/hejianjun/TableSchemaRequest.java b/chat2db-gateway/src/main/java/com/hejianjun/TableSchemaRequest.java new file mode 100644 index 000000000..a3c72acf8 --- /dev/null +++ b/chat2db-gateway/src/main/java/com/hejianjun/TableSchemaRequest.java @@ -0,0 +1,34 @@ +package com.hejianjun; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.experimental.SuperBuilder; + +import java.math.BigDecimal; +import java.util.List; + +/** + * 表结构请求 + */ +@Data +@SuperBuilder +@NoArgsConstructor +@AllArgsConstructor +public class TableSchemaRequest { + + // 数据源ID + private Long dataSourceId; + // 数据库名称 + private String databaseName; + // API密钥 + private String apiKey; + // 数据源模式 + private String dataSourceSchema; + // 模式向量 + private List> schemaVector; + // 模式列表 + private List schemaList; + // 插入前删除 + private Boolean deleteBeforeInsert = false; +} diff --git a/chat2db-gateway/src/main/java/com/hejianjun/TableSchemaService.java b/chat2db-gateway/src/main/java/com/hejianjun/TableSchemaService.java new file mode 100644 index 000000000..0bf7c31bc --- /dev/null +++ b/chat2db-gateway/src/main/java/com/hejianjun/TableSchemaService.java @@ -0,0 +1,107 @@ +package com.hejianjun; + +import co.elastic.clients.elasticsearch.ElasticsearchClient; +import co.elastic.clients.elasticsearch.core.BulkRequest; +import co.elastic.clients.elasticsearch.core.BulkResponse; +import co.elastic.clients.elasticsearch.core.IndexResponse; +import co.elastic.clients.elasticsearch.core.SearchResponse; +import co.elastic.clients.elasticsearch.core.bulk.BulkResponseItem; +import co.elastic.clients.elasticsearch.core.search.Hit; +import co.elastic.clients.json.JsonData; +import lombok.AllArgsConstructor; +import org.springframework.stereotype.Service; + +import java.io.IOException; +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * TableSchemaService类用于处理表结构相关的操作。 + */ +@Service +@AllArgsConstructor +public class TableSchemaService { + + private final ElasticsearchClient client; + + /** + * 批量保存表结构。 + * + * @param request 表结构请求对象 + * @return 保存成功后的每个文档的ID列表 + * @throws IOException IO异常 + */ + public List saveSchemaBatch(TableSchemaRequest request) throws IOException { + List documentIds = new ArrayList<>(); + + // 构建批量请求 + BulkRequest.Builder bulkBuilder = new BulkRequest.Builder(); + + String indexName = request.getDataSourceId() + request.getDatabaseName() + request.getDataSourceSchema(); + + for (int i = 0; i < request.getSchemaVector().size(); i++) { + // 假设schemaVector和schemaList的长度相同,并且一一对应 + List vector = request.getSchemaVector().get(i); + String schema = request.getSchemaList().get(i); + + // 创建文档内容,这里简化为Map,具体结构根据需求定义 + SchemaDocument document = new SchemaDocument(schema,vector); + + // 添加到批量请求 + bulkBuilder.operations(op -> op + .index(idx -> idx + .index(indexName) + .document(document) + ) + ); + } + + // 执行批量请求 + BulkResponse bulkResponse = client.bulk(bulkBuilder.build()); + + // 收集文档ID + for (BulkResponseItem item : bulkResponse.items()) { + if (item.error()!=null) { + throw new IOException("Error indexing document: " + item.error().reason()); + } + documentIds.add(item.id()); + } + + return documentIds; + } + + /** + * 根据向量搜索表结构。 + * + * @param request 表结构请求对象 + * @return 搜索结果列表 + * @throws IOException IO异常 + */ + public TableSchemaRequest searchByVector(TableSchemaRequest request) throws IOException { + String indexName = request.getDataSourceId() + request.getDatabaseName() + request.getDataSourceSchema(); + List vector = request.getSchemaVector().get(0); + // 假设schemaVector已转换为适合Elasticsearch的格式 + // 执行k-NN搜索 + SearchResponse response = client.search(s -> s + .index(indexName) + // 这里添加k-NN查询逻辑,具体实现根据实际需求 + , SchemaDocument.class + ); + List> schemaVector = new ArrayList<>(); + List schemaList = new ArrayList<>(); + List> hits = response.hits().hits(); + for (Hit hit: hits) { + SchemaDocument document = hit.source(); + if(document!=null) { + schemaVector.add(document.getVector()); + schemaList.add(document.getSchema()); + } + } + request.setSchemaVector(schemaVector); + request.setSchemaList(schemaList); + return request; + } +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/aspect/GatewayClientServiceAspect.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/aspect/GatewayClientServiceAspect.java new file mode 100644 index 000000000..01e4c0eef --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/aspect/GatewayClientServiceAspect.java @@ -0,0 +1,33 @@ +package ai.chat2db.server.web.api.aspect; + + +import org.aspectj.lang.ProceedingJoinPoint; +import org.aspectj.lang.annotation.Around; +import org.aspectj.lang.annotation.Aspect; +import org.aspectj.lang.annotation.Pointcut; +import org.springframework.stereotype.Component; + +@Aspect +@Component +public class GatewayClientServiceAspect { + /** + * 定义切点,匹配 GatewayClientService 类中的所有方法 + */ + @Pointcut("execution(* ai.chat2db.server.web.api.http.GatewayClientService.*(..)) && !execution(* ai.chat2db.server.web.api.http.GatewayClientService.checkInWhite(..))") + public void gatewayClientServiceMethods() {} + + + + /** + * 环绕通知:在切点方法执行时触发 + * @param joinPoint + * @return + * @throws Throwable + */ + @Around("gatewayClientServiceMethods()") + public Object aroundGatewayClientServiceMethods(ProceedingJoinPoint joinPoint) throws Throwable { + // 这里你可以执行一些自定义的逻辑,如果需要的话 + // 然后返回 null 或其他默认值 + return null; + } +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java index 2a106fc10..973b890e7 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java @@ -11,6 +11,7 @@ import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect; import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAIClient; import ai.chat2db.server.web.api.controller.ai.azure.listener.AzureOpenAIEventSourceListener; +import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatCompletionsOptions; import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatMessage; import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatRole; import ai.chat2db.server.web.api.controller.ai.baichuan.client.BaichuanAIClient; @@ -24,6 +25,7 @@ import ai.chat2db.server.web.api.controller.ai.config.LocalCache; import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatAIClient; import ai.chat2db.server.web.api.controller.ai.fastchat.listener.FastChatAIEventSourceListener; +import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatCompletionsOptions; import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage; import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatRole; import ai.chat2db.server.web.api.controller.ai.openai.client.OpenAIClient; @@ -329,7 +331,7 @@ private SseEmitter chatWithChat2dbAi(ChatQueryRequest queryRequest, SseEmitter s * @throws IOException */ private SseEmitter chatWithAzureAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { - String prompt = promptService.buildPrompt(queryRequest); + String prompt = promptService.buildAutoPrompt(queryRequest); if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) { log.error("提示语超出最大长度:{},输入长度:{}, 请重新输入", MAX_PROMPT_LENGTH, prompt.length() / TOKEN_CONVERT_CHAR_LENGTH); @@ -347,9 +349,16 @@ private SseEmitter chatWithAzureAi(ChatQueryRequest queryRequest, SseEmitter sse messages.add(currentMessage); buildSseEmitter(sseEmitter, uid); - - AzureOpenAIEventSourceListener sourceListener = new AzureOpenAIEventSourceListener(sseEmitter); - AzureOpenAIClient.getInstance().streamCompletions(messages, sourceListener); + LoginUser loginUser = ContextUtils.getLoginUser(); + AzureOpenAIEventSourceListener sourceListener = new AzureOpenAIEventSourceListener(sseEmitter,promptService,queryRequest,loginUser); + AzureChatCompletionsOptions chatCompletionsOptions = new AzureChatCompletionsOptions(messages); + chatCompletionsOptions.setStream(true); + if(queryRequest.getDatabaseName()!=null){ + ToolsFunction function = PromptService.getToolsFunction(); + chatCompletionsOptions.setTools(List.of(new Tools(Tools.Type.FUNCTION.getName(), function))); + chatCompletionsOptions.setToolChoice("auto"); + } + AzureOpenAIClient.getInstance().streamCompletions(chatCompletionsOptions, sourceListener); LocalCache.CACHE.put(uid, messages, LocalCache.TIMEOUT); return sseEmitter; } @@ -397,7 +406,7 @@ private SseEmitter chatWithZhipuChatAi(ChatQueryRequest queryRequest, SseEmitter ZhipuChatCompletionsOptions completionsOptions = ZhipuChatCompletionsOptions.builder() .requestId(requestId) .stream(true) - .toolChoice("auto") + .messages(messages) .build(); if(queryRequest.getDatabaseName()!=null){ diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/client/AzureOpenAiStreamClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/client/AzureOpenAiStreamClient.java index 338f5b1c1..6ae245590 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/client/AzureOpenAiStreamClient.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/client/AzureOpenAiStreamClient.java @@ -149,22 +149,14 @@ public AzureOpenAiStreamClient build() { * @param chatMessages * @param eventSourceListener */ - public void streamCompletions(List chatMessages, EventSourceListener eventSourceListener) { - if (CollectionUtils.isEmpty(chatMessages)) { - log.error("param error:Azure Prompt cannot be empty"); - throw new ParamBusinessException("prompt"); - } + public void streamCompletions(AzureChatCompletionsOptions chatCompletionsOptions, EventSourceListener eventSourceListener) { if (Objects.isNull(eventSourceListener)) { log.error("param error:AzureEventSourceListener cannot be empty"); throw new ParamBusinessException(); } - log.info("Azure Open AI, prompt:{}", chatMessages.get(chatMessages.size() - 1).getContent()); try { - - AzureChatCompletionsOptions chatCompletionsOptions = new AzureChatCompletionsOptions(chatMessages); chatCompletionsOptions.setStream(true); chatCompletionsOptions.setModel(this.deployId); - EventSource.Factory factory = EventSources.createFactory(this.okHttpClient); ObjectMapper mapper = new ObjectMapper(); mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); @@ -172,7 +164,7 @@ public void streamCompletions(List chatMessages, EventSourceLi if (!endpoint.endsWith("/")) { endpoint = endpoint + "/"; } - String url = this.endpoint + "openai/deployments/"+ deployId + "/chat/completions?api-version=2023-05-15"; + String url = this.endpoint + "openai/deployments/"+ deployId + "/chat/completions?api-version=2024-02-15-preview"; Request request = new Request.Builder() .url(url) .post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody)) diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/listener/AzureOpenAIEventSourceListener.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/listener/AzureOpenAIEventSourceListener.java index 4488bd6b8..2b9ab4a99 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/listener/AzureOpenAIEventSourceListener.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/listener/AzureOpenAIEventSourceListener.java @@ -1,15 +1,32 @@ package ai.chat2db.server.web.api.controller.ai.azure.listener; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import java.util.Objects; +import ai.chat2db.server.tools.common.model.LoginUser; +import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAIClient; import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatChoice; import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatCompletions; +import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatCompletionsOptions; import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatMessage; +import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatRole; import ai.chat2db.server.web.api.controller.ai.azure.model.AzureCompletionsUsage; +import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage; +import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatRole; +import ai.chat2db.server.web.api.controller.ai.openai.listener.OpenAIEventSourceListener; +import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest; +import ai.chat2db.server.web.api.controller.ai.utils.PromptService; +import ai.chat2db.server.web.api.controller.ai.zhipu.client.ZhipuChatAIClient; +import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletionsOptions; + import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; import com.unfbx.chatgpt.entity.chat.Message; +import com.unfbx.chatgpt.entity.chat.tool.Tools; +import com.unfbx.chatgpt.entity.chat.tool.ToolsFunction; + import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import okhttp3.Response; @@ -26,123 +43,26 @@ * @date 2023-02-22 */ @Slf4j -public class AzureOpenAIEventSourceListener extends EventSourceListener { +public class AzureOpenAIEventSourceListener extends OpenAIEventSourceListener { - private SseEmitter sseEmitter; - private ObjectMapper mapper = new ObjectMapper().disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES); - - public AzureOpenAIEventSourceListener(SseEmitter sseEmitter) { - this.sseEmitter = sseEmitter; + public AzureOpenAIEventSourceListener(SseEmitter sseEmitter, PromptService promptService, + ChatQueryRequest queryRequest, LoginUser loginUser) { + super(sseEmitter, promptService, queryRequest, loginUser); } - /** - * {@inheritDoc} - */ @Override - public void onOpen(EventSource eventSource, Response response) { - log.info("AzureOpenAI建立sse连接..."); + public String getName(){ + return "AzureOpenAI"; } - /** - * {@inheritDoc} - */ - @SneakyThrows - @Override - public void onEvent(EventSource eventSource, String id, String type, String data) { - log.info("AzureOpenAI返回数据:{}", data); - if (data.equals("[DONE]")) { - log.info("AzureOpenAI返回数据结束了"); - sseEmitter.send(SseEmitter.event() - .id("[DONE]") - .data("[DONE]") - .reconnectTime(3000)); - sseEmitter.complete(); - return; - } - - AzureChatCompletions chatCompletions = mapper.readValue(data, AzureChatCompletions.class); - String text = ""; - log.info("Model ID={} is created at {}.", chatCompletions.getId(), - chatCompletions.getCreated()); - for (AzureChatChoice choice : chatCompletions.getChoices()) { - AzureChatMessage message = choice.getDelta(); - if (message != null) { - log.info("Index: {}, Chat Role: {}", choice.getIndex(), message.getRole()); - if (message.getContent() != null) { - text = message.getContent(); - } - } - } - - AzureCompletionsUsage usage = chatCompletions.getUsage(); - if (usage != null) { - log.info( - "Usage: number of prompt token is {}, number of completion token is {}, and number of total " - + "tokens in request and response is {}.%n", usage.getPromptTokens(), - usage.getCompletionTokens(), usage.getTotalTokens()); - } - - Message message = new Message(); - message.setContent(text); - sseEmitter.send(SseEmitter.event() - .id(null) - .data(message) - .reconnectTime(3000)); - } - - @Override - public void onClosed(EventSource eventSource) { - try { - sseEmitter.send(SseEmitter.event() - .id("[DONE]") - .data("[DONE]")); - } catch (IOException e) { - throw new RuntimeException(e); - } - sseEmitter.complete(); - log.info("AzureOpenAI close sse connection..."); - } - - @Override - public void onFailure(EventSource eventSource, Throwable t, Response response) { - try { - if (Objects.isNull(response)) { - String message = t.getMessage(); - Message sseMessage = new Message(); - sseMessage.setContent(message); - sseEmitter.send(SseEmitter.event() - .id("[ERROR]") - .data(sseMessage)); - sseEmitter.send(SseEmitter.event() - .id("[DONE]") - .data("[DONE]")); - sseEmitter.complete(); - return; - } - ResponseBody body = response.body(); - String bodyString = Objects.nonNull(t) ? t.getMessage() : ""; - if (Objects.nonNull(body)) { - bodyString = body.string(); - if (StringUtils.isBlank(bodyString) && Objects.nonNull(t)) { - bodyString = t.getMessage(); - } - log.error("Azure OpenAI sse response:{}", bodyString); - } else { - log.error("Azure OpenAI sse response:{},error:{}", response, t); - } - eventSource.cancel(); - Message message = new Message(); - message.setContent("Azure OpenAI error:" + bodyString); - sseEmitter.send(SseEmitter.event() - .id("[ERROR]") - .data(message)); - sseEmitter.send(SseEmitter.event() - .id("[DONE]") - .data("[DONE]")); - sseEmitter.complete(); - } catch (Exception exception) { - log.error("Azure OpenAI发送数据异常:", exception); - } + @Override + public void functionCall(String prompt){ + AzureChatMessage currentMessage = new AzureChatMessage(AzureChatRole.USER).setContent(prompt); + List messages = new ArrayList<>(); + messages.add(currentMessage); + AzureChatCompletionsOptions chatCompletionsOptions = new AzureChatCompletionsOptions(messages); + chatCompletionsOptions.setStream(true); + AzureOpenAIClient.getInstance().streamCompletions(chatCompletionsOptions, this); } } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/model/AzureChatCompletionsOptions.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/model/AzureChatCompletionsOptions.java index 1d6198e57..8d33166b3 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/model/AzureChatCompletionsOptions.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/model/AzureChatCompletionsOptions.java @@ -7,6 +7,8 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import com.unfbx.chatgpt.entity.chat.tool.Tools; + import lombok.Data; /** @@ -391,4 +393,11 @@ public AzureChatCompletionsOptions setModel(String model) { this.model = model; return this; } + + // 新添加的参数 + @JsonProperty(value = "tool_choice") + private String toolChoice; // 工具选择策略 + + @JsonProperty(value = "tools") + private List tools; // 工具列表 } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java index 2d63e9f4c..36afe4e02 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java @@ -22,6 +22,8 @@ import okhttp3.ResponseBody; import okhttp3.sse.EventSource; import okhttp3.sse.EventSourceListener; + +import org.apache.commons.collections4.CollectionUtils; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import java.util.ArrayList; @@ -193,6 +195,9 @@ public void onEvent(EventSource eventSource, String id, String type, String data mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); // 读取Json ChatCompletionResponse completionResponse = mapper.readValue(data, ChatCompletionResponse.class); + if(CollectionUtils.isEmpty(completionResponse.getChoices())){ + return; + } Message delta = completionResponse.getChoices().get(0).getDelta(); if (delta != null && delta.getToolCalls() != null) { this.toolCalls = mergeToolCallsLists(this.toolCalls, delta.getToolCalls()); diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java index b70f8b058..c84831224 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java @@ -13,6 +13,7 @@ import org.apache.commons.collections.MapUtils; import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.lang3.StringUtils; +import org.apache.poi.ss.formula.functions.T; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/listener/ZhipuChatAIEventSourceListener.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/listener/ZhipuChatAIEventSourceListener.java index abc07a8e1..a02668775 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/listener/ZhipuChatAIEventSourceListener.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/listener/ZhipuChatAIEventSourceListener.java @@ -2,6 +2,7 @@ import ai.chat2db.server.tools.common.model.LoginUser; import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage; +import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatRole; import ai.chat2db.server.web.api.controller.ai.openai.listener.OpenAIEventSourceListener; import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest; import ai.chat2db.server.web.api.controller.ai.utils.PromptService; @@ -9,6 +10,7 @@ import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletionsOptions; import lombok.extern.slf4j.Slf4j; +import java.util.ArrayList; import java.util.List; import java.util.Objects; @@ -38,8 +40,9 @@ public String getName(){ @Override public void functionCall(String prompt){ - Long uid = loginUser.getId(); - List messages = promptService.getFastChatMessage(Objects.toString(uid), prompt); + FastChatMessage currentMessage = new FastChatMessage(FastChatRole.USER).setContent(prompt); + List messages = new ArrayList<>(); + messages.add(currentMessage); String requestId = String.valueOf(System.currentTimeMillis()); ToolsFunction function = PromptService.getToolsFunction(); ZhipuChatCompletionsOptions completionsOptions = ZhipuChatCompletionsOptions.builder() From b2b41480600599eaca402762ab5af91d0a7d469e Mon Sep 17 00:00:00 2001 From: hejianjun <942156265@qq.com> Date: Sat, 30 Mar 2024 21:48:33 +0800 Subject: [PATCH 13/16] =?UTF-8?q?=E6=A8=A1=E7=B3=8A=E6=9F=A5=E8=AF=A2?= =?UTF-8?q?=E8=A1=A8=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/main/java/com/hejianjun/TableSchemaRequest.java | 4 ++++ .../src/main/resources/mapper/TableCacheMapper.xml | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/chat2db-gateway/src/main/java/com/hejianjun/TableSchemaRequest.java b/chat2db-gateway/src/main/java/com/hejianjun/TableSchemaRequest.java index a3c72acf8..6e28ca0ea 100644 --- a/chat2db-gateway/src/main/java/com/hejianjun/TableSchemaRequest.java +++ b/chat2db-gateway/src/main/java/com/hejianjun/TableSchemaRequest.java @@ -18,16 +18,20 @@ public class TableSchemaRequest { // 数据源ID + @NotNull private Long dataSourceId; // 数据库名称 + @NotNull private String databaseName; // API密钥 private String apiKey; // 数据源模式 private String dataSourceSchema; // 模式向量 + @NotNull private List> schemaVector; // 模式列表 + @NotNull private List schemaList; // 插入前删除 private Boolean deleteBeforeInsert = false; diff --git a/chat2db-server/chat2db-server-domain/chat2db-server-domain-repository/src/main/resources/mapper/TableCacheMapper.xml b/chat2db-server/chat2db-server-domain/chat2db-server-domain-repository/src/main/resources/mapper/TableCacheMapper.xml index c367e2605..37efbd21c 100644 --- a/chat2db-server/chat2db-server-domain/chat2db-server-domain-repository/src/main/resources/mapper/TableCacheMapper.xml +++ b/chat2db-server/chat2db-server-domain/chat2db-server-domain-repository/src/main/resources/mapper/TableCacheMapper.xml @@ -25,7 +25,7 @@ and tc.schema_name = #{schemaName} - and LOWER(tc.table_name) like LOWER(concat('%',#{searchKey},'%')) + and (LOWER(tc.table_name) like LOWER(concat('%',#{searchKey},'%')) or tc.extend_info like concat('%',#{searchKey},'%')) From 6bcdea72d507f80c48c8edc3a3821cac70d7e92b Mon Sep 17 00:00:00 2001 From: hejianjun <942156265@qq.com> Date: Mon, 1 Apr 2024 16:11:03 +0800 Subject: [PATCH 14/16] =?UTF-8?q?=E4=BC=98=E5=8C=96=E8=A1=A8=E9=80=89?= =?UTF-8?q?=E6=8B=A9=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../web/api/controller/ai/ChatController.java | 2 + .../listener/OpenAIEventSourceListener.java | 43 +++++++++++++------ .../controller/ai/utils/PromptService.java | 6 ++- 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java index 973b890e7..8a4ca74eb 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java @@ -336,6 +336,8 @@ private SseEmitter chatWithAzureAi(ChatQueryRequest queryRequest, SseEmitter sse log.error("提示语超出最大长度:{},输入长度:{}, 请重新输入", MAX_PROMPT_LENGTH, prompt.length() / TOKEN_CONVERT_CHAR_LENGTH); throw new ParamBusinessException(); + }else{ + log.info("提示词 :{}",prompt); } List messages = (List)LocalCache.CACHE.get(uid); if (CollectionUtils.isNotEmpty(messages)) { diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java index 36afe4e02..de39b8bcb 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java @@ -24,9 +24,13 @@ import okhttp3.sse.EventSourceListener; import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashSet; +import java.util.Set; import java.util.List; import java.util.Objects; @@ -134,17 +138,24 @@ public void functionCall(String prompt){ } - public void handleTableNames(List tableNames,Object instance){ - if(instance instanceof JSONArray){ - ((JSONArray)instance).forEach(tableName->{ - handleTableNames(tableNames,tableName); - }); - }else if (instance instanceof JSONObject) { - ((JSONObject)instance).entrySet().forEach(entrySet->{ - handleTableNames(tableNames,entrySet.getValue()); - }); - }else if (instance instanceof String) { - tableNames.add((String)instance); + public void handleTableNames(Set tableNames, Object instance) { + if (instance instanceof JSONArray) { + ((JSONArray) instance).forEach(item -> handleTableNames(tableNames, item)); + } else if (instance instanceof JSONObject) { + ((JSONObject) instance).forEach((key, value) -> handleTableNames(tableNames, value)); + } else if (instance instanceof String) { + String tableName = (String) instance; + List queryTableNames = queryRequest.getTableNames(); + if (queryTableNames != null) { + String mostSimilarTableName = queryTableNames.stream() + // 根据相似度排序 + .min(Comparator.comparingInt(existingTableName -> StringUtils.getLevenshteinDistance(existingTableName, tableName))) + .orElse(tableName); + tableNames.add(mostSimilarTableName); + }else{ + tableNames.add(tableName); + } + } } /** @@ -165,7 +176,7 @@ public void onEvent(EventSource eventSource, String id, String type, String data sseEmitter.complete(); return; } - List tableNames = new ArrayList<>(); + Set tableNames = new HashSet<>(); for (ToolCalls toolCall : toolCalls) { String callId = toolCall.getId(); ToolCallFunction function = toolCall.getFunction(); @@ -177,8 +188,12 @@ public void onEvent(EventSource eventSource, String id, String type, String data } } } - - queryRequest.setTableNames(tableNames); + Message message = new Message(); + message.setContent("选择表" + tableNames); + sseEmitter.send(SseEmitter.event() + .data(message) + .reconnectTime(3000)); + queryRequest.setTableNames(new ArrayList<>(tableNames)); ContextUtils.setContext(Context.builder() .loginUser(loginUser) .build()); diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java index c84831224..9b52411bf 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java @@ -404,7 +404,9 @@ public String queryDatabaseTables(ChatQueryRequest queryRequest) { tableSelector.setColumnList(true); tableSelector.setIndexList(false); PageResult
tables = tableService.pageQuery(queryParam,tableSelector); - return tables.getData().stream().map(table -> { + List tableNames = new ArrayList<>(); + String properties = tables.getData().stream().map(table -> { + tableNames.add(table.getName()); StringBuilder sb = new StringBuilder(table.getName()); // 直接在初始化时加入表名 String comment = table.getComment(); List columns = table.getColumnList(); @@ -427,6 +429,8 @@ public String queryDatabaseTables(ChatQueryRequest queryRequest) { return sb.toString(); // 在映射阶段直接转换为字符串 }) .collect(Collectors.joining(",")); + queryRequest.setTableNames(tableNames); + return properties; } catch (Exception e) { log.error("query table error:{}, do nothing", e.getMessage()); return ""; From 3e52d92f6f363e59787eb9dc268b9cfed775ffa7 Mon Sep 17 00:00:00 2001 From: hejianjun <942156265@qq.com> Date: Mon, 8 Apr 2024 11:16:40 +0800 Subject: [PATCH 15/16] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E8=B0=83=E8=AF=95?= =?UTF-8?q?=E8=AF=AD=E5=8F=A5=E5=92=8C=E4=BF=AE=E6=94=B9=E9=80=89=E6=8B=A9?= =?UTF-8?q?=E8=A1=A8bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ai/openai/listener/OpenAIEventSourceListener.java | 3 ++- .../server/web/api/controller/ai/utils/PromptService.java | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java index de39b8bcb..54609e72f 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java @@ -185,11 +185,12 @@ public void onEvent(EventSource eventSource, String id, String type, String data if ("get_table_columns".equals(functionName)) { JSONObject arguments = JSONObject.parse(function.getArguments()); handleTableNames(tableNames,arguments.get("table_names")); + log.info("原始参数:{},处理后:{}",arguments,tableNames); } } } Message message = new Message(); - message.setContent("选择表" + tableNames); + message.setContent("选择表" + tableNames+"\n"); sseEmitter.send(SseEmitter.event() .data(message) .reconnectTime(3000)); diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java index 9b52411bf..503c5543a 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java @@ -400,6 +400,7 @@ public static List findPossibleForeignKeys(List columns) { public String queryDatabaseTables(ChatQueryRequest queryRequest) { try { TablePageQueryParam queryParam = rdbWebConverter.tablePageRequest2param(queryRequest); + queryParam.queryAll(); TableSelector tableSelector = new TableSelector(); tableSelector.setColumnList(true); tableSelector.setIndexList(false); From fc096ef35db85d1707cec279049069b5f9869a1a Mon Sep 17 00:00:00 2001 From: hejianjun <942156265@qq.com> Date: Tue, 9 Apr 2024 17:14:54 +0800 Subject: [PATCH 16/16] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E9=A1=B9=E7=9B=AE?= =?UTF-8?q?=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- chat2db-gateway/pom.xml | 4 ++++ .../java/com/hejianjun/{ => bean}/SchemaDocument.java | 2 +- .../java/com/hejianjun/{ => bean}/TableSchemaRequest.java | 3 ++- .../hejianjun/{ => config}/ElasticsearchClientConfig.java | 4 ++-- .../hejianjun/{ => controller}/TableSchemaController.java | 6 +++--- .../com/hejianjun/{ => service}/TableSchemaService.java | 8 +++----- 6 files changed, 15 insertions(+), 12 deletions(-) rename chat2db-gateway/src/main/java/com/hejianjun/{ => bean}/SchemaDocument.java (89%) rename chat2db-gateway/src/main/java/com/hejianjun/{ => bean}/TableSchemaRequest.java (91%) rename chat2db-gateway/src/main/java/com/hejianjun/{ => config}/ElasticsearchClientConfig.java (94%) rename chat2db-gateway/src/main/java/com/hejianjun/{ => controller}/TableSchemaController.java (93%) rename chat2db-gateway/src/main/java/com/hejianjun/{ => service}/TableSchemaService.java (95%) diff --git a/chat2db-gateway/pom.xml b/chat2db-gateway/pom.xml index 1e3e4b89a..68d8cbc0d 100644 --- a/chat2db-gateway/pom.xml +++ b/chat2db-gateway/pom.xml @@ -65,6 +65,10 @@ lombok true + + org.springframework.boot + spring-boot-starter-validation + diff --git a/chat2db-gateway/src/main/java/com/hejianjun/SchemaDocument.java b/chat2db-gateway/src/main/java/com/hejianjun/bean/SchemaDocument.java similarity index 89% rename from chat2db-gateway/src/main/java/com/hejianjun/SchemaDocument.java rename to chat2db-gateway/src/main/java/com/hejianjun/bean/SchemaDocument.java index b077e6508..82432390c 100644 --- a/chat2db-gateway/src/main/java/com/hejianjun/SchemaDocument.java +++ b/chat2db-gateway/src/main/java/com/hejianjun/bean/SchemaDocument.java @@ -1,4 +1,4 @@ -package com.hejianjun; +package com.hejianjun.bean; import lombok.AllArgsConstructor; import lombok.Data; diff --git a/chat2db-gateway/src/main/java/com/hejianjun/TableSchemaRequest.java b/chat2db-gateway/src/main/java/com/hejianjun/bean/TableSchemaRequest.java similarity index 91% rename from chat2db-gateway/src/main/java/com/hejianjun/TableSchemaRequest.java rename to chat2db-gateway/src/main/java/com/hejianjun/bean/TableSchemaRequest.java index 6e28ca0ea..e3398b435 100644 --- a/chat2db-gateway/src/main/java/com/hejianjun/TableSchemaRequest.java +++ b/chat2db-gateway/src/main/java/com/hejianjun/bean/TableSchemaRequest.java @@ -1,10 +1,11 @@ -package com.hejianjun; +package com.hejianjun.bean; import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; import lombok.experimental.SuperBuilder; +import javax.validation.constraints.NotNull; import java.math.BigDecimal; import java.util.List; diff --git a/chat2db-gateway/src/main/java/com/hejianjun/ElasticsearchClientConfig.java b/chat2db-gateway/src/main/java/com/hejianjun/config/ElasticsearchClientConfig.java similarity index 94% rename from chat2db-gateway/src/main/java/com/hejianjun/ElasticsearchClientConfig.java rename to chat2db-gateway/src/main/java/com/hejianjun/config/ElasticsearchClientConfig.java index e3d997bc1..c95bc67f5 100644 --- a/chat2db-gateway/src/main/java/com/hejianjun/ElasticsearchClientConfig.java +++ b/chat2db-gateway/src/main/java/com/hejianjun/config/ElasticsearchClientConfig.java @@ -1,4 +1,4 @@ -package com.hejianjun; +package com.hejianjun.config; import co.elastic.clients.elasticsearch.ElasticsearchClient; import co.elastic.clients.json.jackson.JacksonJsonpMapper; @@ -14,7 +14,7 @@ @Configuration public class ElasticsearchClientConfig { - String apiKey = "DVaOd3B6Rl*9sWUeTIHO"; + String apiKey = "0E9NGIy7gb8a3TDVM8dC"; /** * 创建ElasticsearchClient实例 diff --git a/chat2db-gateway/src/main/java/com/hejianjun/TableSchemaController.java b/chat2db-gateway/src/main/java/com/hejianjun/controller/TableSchemaController.java similarity index 93% rename from chat2db-gateway/src/main/java/com/hejianjun/TableSchemaController.java rename to chat2db-gateway/src/main/java/com/hejianjun/controller/TableSchemaController.java index 93883b9d6..5abb24133 100644 --- a/chat2db-gateway/src/main/java/com/hejianjun/TableSchemaController.java +++ b/chat2db-gateway/src/main/java/com/hejianjun/controller/TableSchemaController.java @@ -1,6 +1,7 @@ -package com.hejianjun; +package com.hejianjun.controller; -import co.elastic.clients.json.JsonData; +import com.hejianjun.bean.TableSchemaRequest; +import com.hejianjun.service.TableSchemaService; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.http.ResponseEntity; @@ -11,7 +12,6 @@ import java.io.IOException; import java.util.List; -import java.util.Map; @Slf4j @RestController diff --git a/chat2db-gateway/src/main/java/com/hejianjun/TableSchemaService.java b/chat2db-gateway/src/main/java/com/hejianjun/service/TableSchemaService.java similarity index 95% rename from chat2db-gateway/src/main/java/com/hejianjun/TableSchemaService.java rename to chat2db-gateway/src/main/java/com/hejianjun/service/TableSchemaService.java index 0bf7c31bc..3668f3358 100644 --- a/chat2db-gateway/src/main/java/com/hejianjun/TableSchemaService.java +++ b/chat2db-gateway/src/main/java/com/hejianjun/service/TableSchemaService.java @@ -1,13 +1,13 @@ -package com.hejianjun; +package com.hejianjun.service; import co.elastic.clients.elasticsearch.ElasticsearchClient; import co.elastic.clients.elasticsearch.core.BulkRequest; import co.elastic.clients.elasticsearch.core.BulkResponse; -import co.elastic.clients.elasticsearch.core.IndexResponse; import co.elastic.clients.elasticsearch.core.SearchResponse; import co.elastic.clients.elasticsearch.core.bulk.BulkResponseItem; import co.elastic.clients.elasticsearch.core.search.Hit; -import co.elastic.clients.json.JsonData; +import com.hejianjun.bean.SchemaDocument; +import com.hejianjun.bean.TableSchemaRequest; import lombok.AllArgsConstructor; import org.springframework.stereotype.Service; @@ -15,8 +15,6 @@ import java.math.BigDecimal; import java.util.ArrayList; import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; /** * TableSchemaService类用于处理表结构相关的操作。