5454import ai .chat2db .server .web .api .http .response .EsTableSchemaResponse ;
5555import ai .chat2db .server .web .api .http .response .TableSchemaResponse ;
5656import ai .chat2db .server .web .api .util .ApplicationContextUtil ;
57+ import ai .chat2db .spi .MetaData ;
58+ import ai .chat2db .spi .model .Table ;
59+ import ai .chat2db .spi .sql .Chat2DBContext ;
60+ import ai .chat2db .spi .sql .ConnectInfo ;
5761import cn .hutool .core .util .StrUtil ;
5862import cn .hutool .json .JSONUtil ;
5963import com .alibaba .fastjson2 .JSON ;
64+ import com .google .common .collect .ImmutableMap ;
6065import com .google .common .collect .Lists ;
66+ import com .unfbx .chatgpt .entity .chat .ChatCompletion ;
6167import com .unfbx .chatgpt .entity .chat .Message ;
68+ import com .unfbx .chatgpt .entity .chat .Parameters ;
69+ import com .unfbx .chatgpt .entity .chat .tool .Tools ;
70+ import com .unfbx .chatgpt .entity .chat .tool .ToolsFunction ;
6271import jakarta .annotation .Resource ;
6372import lombok .extern .slf4j .Slf4j ;
6473import org .apache .commons .collections4 .CollectionUtils ;
@@ -171,7 +180,7 @@ public SseEmitter customChat(@RequestBody ChatRequest queryRequest) throws IOExc
171180 /**
172181 * 自定义模型非流式输出接口DEMO
173182 * <p>
174- * Note:使用自己本地的飞流式输出自定义AI,接口输入和输出需与该样例保持一致
183+ * Note:使用自己本地的飞流式输出自定义AI,接口输入和输出需与该样例保持一致
175184 * </p>
176185 *
177186 * @param queryRequest
@@ -276,11 +285,11 @@ private SseEmitter chatWithRestAi(ChatQueryRequest prompt, SseEmitter sseEmitter
276285 * @throws IOException
277286 */
278287 private SseEmitter chatWithOpenAi (ChatQueryRequest queryRequest , SseEmitter sseEmitter , String uid )
279- throws IOException {
280- String prompt = buildPrompt (queryRequest );
288+ throws IOException {
289+ String prompt = buildPrompt2 (queryRequest );
281290 if (prompt .length () / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH ) {
282291 log .error ("提示语超出最大长度:{},输入长度:{}, 请重新输入" , MAX_PROMPT_LENGTH ,
283- prompt .length () / TOKEN_CONVERT_CHAR_LENGTH );
292+ prompt .length () / TOKEN_CONVERT_CHAR_LENGTH );
284293 throw new ParamBusinessException ();
285294 }
286295
@@ -290,9 +299,28 @@ private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseE
290299 Message currentMessage = Message .builder ().content (prompt ).role (Message .Role .USER ).build ();
291300 messages .add (currentMessage );
292301 buildSseEmitter (sseEmitter , uid );
293-
294- OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener (sseEmitter );
295- OpenAIClient .getInstance ().streamChatCompletion (messages , openAIEventSourceListener );
302+ ConnectInfo connectInfo = Chat2DBContext .getConnectInfo ();
303+ OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener (sseEmitter , messages , connectInfo , queryRequest );
304+ ToolsFunction function = ToolsFunction .builder ()
305+ .name ("get_table_columns" )
306+ .description ("获取指定表的字段名,类型" )
307+ .parameters (Parameters .builder ()
308+ .type ("object" )
309+ .properties (ImmutableMap .builder ()
310+ .put ("table_name" , ImmutableMap .builder ()
311+ .put ("type" , "string" )
312+ .put ("description" , "表名,例如```User```" )
313+ .build ())
314+ .build ())
315+ .required (List .of ("table_name" ))
316+ .build ())
317+ .build ();
318+ ChatCompletion chatCompletion = ChatCompletion .builder ()
319+ .model ("gpt-3.5-turbo-1106" )
320+ .tools (List .of (new Tools (Tools .Type .FUNCTION .getName (), function )))
321+ .toolChoice ("auto" )
322+ .messages (messages ).stream (true ).build ();
323+ OpenAIClient .getInstance ().streamChatCompletion (chatCompletion , openAIEventSourceListener );
296324 LocalCache .CACHE .put (uid , JSONUtil .toJsonStr (messages ), LocalCache .TIMEOUT );
297325 return sseEmitter ;
298326 }
@@ -630,6 +658,47 @@ private String buildPrompt(ChatQueryRequest queryRequest) {
630658 return cleanedInput ;
631659 }
632660
661+ /**
662+ * 构建prompt
663+ *
664+ * @param queryRequest
665+ * @return
666+ */
667+ private String buildPrompt2 (ChatQueryRequest queryRequest ) {
668+ if (PromptType .TEXT_GENERATION .getCode ().equals (queryRequest .getPromptType ())) {
669+ return queryRequest .getMessage ();
670+ }
671+
672+ // 查询schema信息
673+ String dataSourceType = queryDatabaseType (queryRequest );
674+ String properties = "" ;
675+ if (CollectionUtils .isNotEmpty (queryRequest .getTableNames ())) {
676+ properties = queryRequest .getTableNames ().stream ().collect (Collectors .joining ("," ));
677+ } else {
678+ properties = queryDatabaseSchema2 (queryRequest );
679+ }
680+ String prompt = queryRequest .getMessage ();
681+ String promptType = StringUtils .isBlank (queryRequest .getPromptType ()) ? PromptType .NL_2_SQL .getCode ()
682+ : queryRequest .getPromptType ();
683+ PromptType pType = EasyEnumUtils .getEnum (PromptType .class , promptType );
684+ String ext = StringUtils .isNotBlank (queryRequest .getExt ()) ? queryRequest .getExt () : "" ;
685+ String schemaProperty = StringUtils .isNotEmpty (properties ) ? String .format (
686+ "### 请根据以下table properties和SQL input%s. %s\n #\n ### %s SQL tables:\n #\n # "
687+ + "%s\n #\n #\n ### SQL input: %s" , pType .getDescription (), ext , dataSourceType ,
688+ properties , prompt ) : String .format ("### 请根据以下SQL input%s. %s\n #\n ### SQL input: %s" ,
689+ pType .getDescription (), ext , prompt );
690+ switch (pType ) {
691+ case SQL_2_SQL :
692+ schemaProperty = StringUtils .isNotBlank (queryRequest .getDestSqlType ()) ? String .format (
693+ "%s\n #\n ### 目标SQL类型: %s" , schemaProperty , queryRequest .getDestSqlType ()) : String .format (
694+ "%s\n #\n ### 目标SQL类型: %s" , schemaProperty , dataSourceType );
695+ default :
696+ break ;
697+ }
698+ String cleanedInput = schemaProperty .replaceAll ("[\r \t ]" , "" );
699+ return cleanedInput ;
700+ }
701+
633702 /**
634703 * query chat2db apikey
635704 *
@@ -727,6 +796,28 @@ public String queryDatabaseSchema(ChatQueryRequest queryRequest) {
727796 }
728797 }
729798
799+
800+ /**
801+ * query database schema
802+ *
803+ * @param queryRequest
804+ * @return
805+ * @throws IOException
806+ */
807+ public String queryDatabaseSchema2 (ChatQueryRequest queryRequest ) {
808+ MetaData metaSchema = Chat2DBContext .getMetaData ();
809+ try {
810+ List <Table > tables = metaSchema .tables (Chat2DBContext .getConnection (), queryRequest .getDatabaseName (), queryRequest .getSchemaName (), null );
811+ return tables .stream ()
812+ .map (table -> StringUtils .isBlank (table .getComment ()) ? table .getName ()
813+ : table .getName () + "(" + table .getComment () + ")" )
814+ .collect (Collectors .joining ("," ));
815+ } catch (Exception e ) {
816+ log .error ("query table error:{}, do nothing" , e .getMessage ());
817+ return "" ;
818+ }
819+ }
820+
730821 /**
731822 * query database schema
732823 *
0 commit comments