Skip to content

Commit 0b05cc9

Browse files
committed
修改openai接口,让openai可以使用function的方式获取表结构
1 parent e707d64 commit 0b05cc9

File tree

8 files changed

+255
-47
lines changed

8 files changed

+255
-47
lines changed

chat2db-client/src/components/ConsoleEditor/components/ChatInput/index.tsx

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,18 +42,17 @@ const ChatInput = (props: IProps) => {
4242
};
4343

4444
const renderSelectTable = () => {
45-
const { tables, onSelectTableSyncModel, selectedTables, onSelectTables } = props;
45+
const { tables, onSelectTableSyncModel, selectedTables, onSelectTables,syncTableModel } = props;
4646
const options = (tables || []).map((t) => ({ value: t, label: t }));
4747
return (
4848
<div className={styles.aiSelectedTable}>
4949
<Radio.Group
5050
onChange={(v) => onSelectTableSyncModel(v.target.value)}
51-
// value={syncTableModel}
52-
value={SyncModelType.MANUAL}
51+
value={syncTableModel}
5352
style={{ marginBottom: '8px' }}
5453
>
5554
<Space direction="horizontal">
56-
{/* <Radio value={SyncModelType.AUTO}>自动</Radio> */}
55+
<Radio value={SyncModelType.AUTO}>自动</Radio>
5756
<Radio value={SyncModelType.MANUAL}>手动</Radio>
5857
</Space>
5958
</Radio.Group>

chat2db-client/src/components/ConsoleEditor/components/SelectBoundInfo/index.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ const SelectBoundInfo = memo((props: IProps) => {
186186
boundInfo.databaseName,
187187
boundInfo.schemaName,
188188
);
189-
setSelectedTables(tableNameListTemp.slice(0, 1));
189+
//setSelectedTables(tableNameListTemp.slice(0, 1));
190190
}
191191
}, [allTableList, isActive]);
192192

chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java

Lines changed: 98 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,20 @@
5454
import ai.chat2db.server.web.api.http.response.EsTableSchemaResponse;
5555
import ai.chat2db.server.web.api.http.response.TableSchemaResponse;
5656
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
57+
import ai.chat2db.spi.MetaData;
58+
import ai.chat2db.spi.model.Table;
59+
import ai.chat2db.spi.sql.Chat2DBContext;
60+
import ai.chat2db.spi.sql.ConnectInfo;
5761
import cn.hutool.core.util.StrUtil;
5862
import cn.hutool.json.JSONUtil;
5963
import com.alibaba.fastjson2.JSON;
64+
import com.google.common.collect.ImmutableMap;
6065
import com.google.common.collect.Lists;
66+
import com.unfbx.chatgpt.entity.chat.ChatCompletion;
6167
import com.unfbx.chatgpt.entity.chat.Message;
68+
import com.unfbx.chatgpt.entity.chat.Parameters;
69+
import com.unfbx.chatgpt.entity.chat.tool.Tools;
70+
import com.unfbx.chatgpt.entity.chat.tool.ToolsFunction;
6271
import jakarta.annotation.Resource;
6372
import lombok.extern.slf4j.Slf4j;
6473
import org.apache.commons.collections4.CollectionUtils;
@@ -171,7 +180,7 @@ public SseEmitter customChat(@RequestBody ChatRequest queryRequest) throws IOExc
171180
/**
172181
* 自定义模型非流式输出接口DEMO
173182
* <p>
174-
* Note:使用自己本地的飞流式输出自定义AI,接口输入和输出需与该样例保持一致
183+
* Note:使用自己本地的飞流式输出自定义AI,接口输入和输出需与该样例保持一致
175184
* </p>
176185
*
177186
* @param queryRequest
@@ -276,11 +285,11 @@ private SseEmitter chatWithRestAi(ChatQueryRequest prompt, SseEmitter sseEmitter
276285
* @throws IOException
277286
*/
278287
private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid)
279-
throws IOException {
280-
String prompt = buildPrompt(queryRequest);
288+
throws IOException {
289+
String prompt = buildPrompt2(queryRequest);
281290
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
282291
log.error("提示语超出最大长度:{},输入长度:{}, 请重新输入", MAX_PROMPT_LENGTH,
283-
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
292+
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
284293
throw new ParamBusinessException();
285294
}
286295

@@ -290,9 +299,28 @@ private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseE
290299
Message currentMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
291300
messages.add(currentMessage);
292301
buildSseEmitter(sseEmitter, uid);
293-
294-
OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter);
295-
OpenAIClient.getInstance().streamChatCompletion(messages, openAIEventSourceListener);
302+
ConnectInfo connectInfo = Chat2DBContext.getConnectInfo();
303+
OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter, messages, connectInfo, queryRequest);
304+
ToolsFunction function = ToolsFunction.builder()
305+
.name("get_table_columns")
306+
.description("获取指定表的字段名,类型")
307+
.parameters(Parameters.builder()
308+
.type("object")
309+
.properties(ImmutableMap.builder()
310+
.put("table_name", ImmutableMap.builder()
311+
.put("type", "string")
312+
.put("description", "表名,例如```User```")
313+
.build())
314+
.build())
315+
.required(List.of("table_name"))
316+
.build())
317+
.build();
318+
ChatCompletion chatCompletion = ChatCompletion.builder()
319+
.model("gpt-3.5-turbo-1106")
320+
.tools(List.of(new Tools(Tools.Type.FUNCTION.getName(), function)))
321+
.toolChoice("auto")
322+
.messages(messages).stream(true).build();
323+
OpenAIClient.getInstance().streamChatCompletion(chatCompletion, openAIEventSourceListener);
296324
LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);
297325
return sseEmitter;
298326
}
@@ -630,6 +658,47 @@ private String buildPrompt(ChatQueryRequest queryRequest) {
630658
return cleanedInput;
631659
}
632660

661+
/**
662+
* 构建prompt
663+
*
664+
* @param queryRequest
665+
* @return
666+
*/
667+
private String buildPrompt2(ChatQueryRequest queryRequest) {
668+
if (PromptType.TEXT_GENERATION.getCode().equals(queryRequest.getPromptType())) {
669+
return queryRequest.getMessage();
670+
}
671+
672+
// 查询schema信息
673+
String dataSourceType = queryDatabaseType(queryRequest);
674+
String properties = "";
675+
if (CollectionUtils.isNotEmpty(queryRequest.getTableNames())) {
676+
properties = queryRequest.getTableNames().stream().collect(Collectors.joining(","));
677+
} else {
678+
properties = queryDatabaseSchema2(queryRequest);
679+
}
680+
String prompt = queryRequest.getMessage();
681+
String promptType = StringUtils.isBlank(queryRequest.getPromptType()) ? PromptType.NL_2_SQL.getCode()
682+
: queryRequest.getPromptType();
683+
PromptType pType = EasyEnumUtils.getEnum(PromptType.class, promptType);
684+
String ext = StringUtils.isNotBlank(queryRequest.getExt()) ? queryRequest.getExt() : "";
685+
String schemaProperty = StringUtils.isNotEmpty(properties) ? String.format(
686+
"### 请根据以下table properties和SQL input%s. %s\n#\n### %s SQL tables:\n#\n# "
687+
+ "%s\n#\n#\n### SQL input: %s", pType.getDescription(), ext, dataSourceType,
688+
properties, prompt) : String.format("### 请根据以下SQL input%s. %s\n#\n### SQL input: %s",
689+
pType.getDescription(), ext, prompt);
690+
switch (pType) {
691+
case SQL_2_SQL:
692+
schemaProperty = StringUtils.isNotBlank(queryRequest.getDestSqlType()) ? String.format(
693+
"%s\n#\n### 目标SQL类型: %s", schemaProperty, queryRequest.getDestSqlType()) : String.format(
694+
"%s\n#\n### 目标SQL类型: %s", schemaProperty, dataSourceType);
695+
default:
696+
break;
697+
}
698+
String cleanedInput = schemaProperty.replaceAll("[\r\t]", "");
699+
return cleanedInput;
700+
}
701+
633702
/**
634703
* query chat2db apikey
635704
*
@@ -727,6 +796,28 @@ public String queryDatabaseSchema(ChatQueryRequest queryRequest) {
727796
}
728797
}
729798

799+
800+
/**
801+
* query database schema
802+
*
803+
* @param queryRequest
804+
* @return
805+
* @throws IOException
806+
*/
807+
public String queryDatabaseSchema2(ChatQueryRequest queryRequest) {
808+
MetaData metaSchema = Chat2DBContext.getMetaData();
809+
try {
810+
List<Table> tables = metaSchema.tables(Chat2DBContext.getConnection(), queryRequest.getDatabaseName(), queryRequest.getSchemaName(), null);
811+
return tables.stream()
812+
.map(table -> StringUtils.isBlank(table.getComment()) ? table.getName()
813+
: table.getName() + "(" + table.getComment() + ")")
814+
.collect(Collectors.joining(","));
815+
} catch (Exception e) {
816+
log.error("query table error:{}, do nothing", e.getMessage());
817+
return "";
818+
}
819+
}
820+
730821
/**
731822
* query database schema
732823
*

chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2DBAIStreamClient.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,15 @@
11
package ai.chat2db.server.web.api.controller.ai.chat2db.client;
22

3-
import ai.chat2db.server.domain.api.enums.AiSqlSourceEnum;
4-
import ai.chat2db.server.domain.api.model.Config;
5-
import ai.chat2db.server.domain.api.service.ConfigService;
6-
import ai.chat2db.server.tools.base.wrapper.result.DataResult;
73
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
84
import ai.chat2db.server.web.api.controller.ai.chat2db.interceptor.Chat2dbHeaderAuthorizationInterceptor;
95
import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatOpenAiApi;
106
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbedding;
117
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse;
12-
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
138
import cn.hutool.http.ContentType;
149
import com.fasterxml.jackson.databind.DeserializationFeature;
1510
import com.fasterxml.jackson.databind.ObjectMapper;
1611
import com.unfbx.chatgpt.entity.chat.ChatCompletion;
1712
import com.unfbx.chatgpt.entity.chat.Message;
18-
import com.unfbx.chatgpt.interceptor.HeaderAuthorizationInterceptor;
1913
import lombok.Getter;
2014
import lombok.extern.slf4j.Slf4j;
2115
import okhttp3.*;

chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/client/OpenAIClient.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import java.net.InetSocketAddress;
55
import java.net.Proxy;
66
import java.util.Objects;
7+
import java.util.concurrent.TimeUnit;
78

89
import ai.chat2db.server.domain.api.model.Config;
910
import ai.chat2db.server.domain.api.service.ConfigService;
@@ -93,7 +94,17 @@ public static void refresh() {
9394
log.info("refresh openai apikey:{}", maskApiKey(apikey));
9495
if (Objects.nonNull(host) && Objects.nonNull(port)) {
9596
Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(host, port));
96-
OkHttpClient okHttpClient = new OkHttpClient.Builder().proxy(proxy).build();
97+
OkHttpClient okHttpClient = new OkHttpClient.Builder()
98+
// 设置连接超时为10秒
99+
.connectTimeout(10, TimeUnit.SECONDS)
100+
// 设置读取超时为30秒
101+
.readTimeout(30, TimeUnit.SECONDS)
102+
// 设置写入超时为15秒
103+
.writeTimeout(15, TimeUnit.SECONDS)
104+
// 设置整个调用的超时为1分钟
105+
.callTimeout(1, TimeUnit.MINUTES)
106+
.proxy(proxy)
107+
.build();
97108
OPEN_AI_STREAM_CLIENT = OpenAiStreamClient.builder().apiHost(apiHost).apiKey(
98109
Lists.newArrayList(apikey)).okHttpClient(okHttpClient).build();
99110
} else {

0 commit comments

Comments
 (0)