diff --git a/api/Dockerfile b/api/Dockerfile index 21dd924..d0e7e7a 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -12,7 +12,7 @@ RUN mkdir -p $FOLDER # Install packages COPY ./requirements.txt $FOLDER/requirements.txt -RUN pip install -r $FOLDER/requirements.txt +RUN pip install --no-cache-dir -r $FOLDER/requirements.txt # Copy the project files into the container COPY ./src $FOLDER/src diff --git a/api/requirements.txt b/api/requirements.txt index a390d8d..9673330 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -7,4 +7,7 @@ retry==0.9.2 tiktoken==0.4.0 python-dotenv==1.0.0 websockets===11.0.3 -gunicorn===20.1.0 \ No newline at end of file +gunicorn===20.1.0 +transformers +torch==2.3.0 +sentence-transformers diff --git a/api/src/embedding/openai.py b/api/src/embedding/openai.py index 7663332..cd54807 100644 --- a/api/src/embedding/openai.py +++ b/api/src/embedding/openai.py @@ -1,19 +1,38 @@ -import openai -from embedding.base_embedding import BaseEmbedding +# import openai +# from embedding.base_embedding import BaseEmbedding -class OpenAIEmbedding(BaseEmbedding): - """Wrapper around OpenAI embedding models.""" +# class OpenAIEmbedding(BaseEmbedding): +# """Wrapper around OpenAI embedding models.""" + +# def __init__( +# self, openai_api_key: str, model_name: str = "text-embedding-ada-002" +# ) -> None: +# openai.api_key = openai_api_key +# self.model = model_name + +# def generate( +# self, +# input: str, +# ) -> str: +# embedding = openai.Embedding.create(input=input, model=self.model) +# return embedding["data"][0]["embedding"] + +from sentence_transformers import SentenceTransformer +from embedding.base_embedding import BaseEmbedding + + +class LlamaEmbedding(BaseEmbedding): +    """Wrapper around HuggingFace embedding models.""" def __init__( -        self, openai_api_key: str, model_name: str = "text-embedding-ada-002" +        self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2" ) -> None: -        openai.api_key = openai_api_key -        self.model = model_name +        self.model = 
SentenceTransformer(model_name) def generate( self, input: str, -    ) -> str: -        embedding = openai.Embedding.create(input=input, model=self.model) -        return embedding["data"][0]["embedding"] +    ) -> list: +        embedding = self.model.encode(input) +        return embedding.tolist() diff --git a/api/src/llm/openai.py b/api/src/llm/openai.py index b09e311..68edbac 100644 --- a/api/src/llm/openai.py +++ b/api/src/llm/openai.py @@ -1,26 +1,109 @@ +# from typing import ( +#     Callable, +#     List, +# ) + +# import openai +# import tiktoken +from llm.basellm import BaseLLM +# from retry import retry + + +# class OpenAIChat(BaseLLM): +#     """Wrapper around OpenAI Chat large language models.""" + +#     def __init__( +#         self, +#         openai_api_key: str, +#         model_name: str = "gpt-3.5-turbo", +#         max_tokens: int = 1000, +#         temperature: float = 0.0, +#     ) -> None: +#         openai.api_key = openai_api_key +#         self.model = model_name +#         self.max_tokens = max_tokens +#         self.temperature = temperature + +#     @retry(tries=3, delay=1) +#     def generate( +#         self, +#         messages: List[str], +#     ) -> str: +#         try: +#             completions = openai.ChatCompletion.create( +#                 model=self.model, +#                 temperature=self.temperature, +#                 max_tokens=self.max_tokens, +#                 messages=messages, +#             ) +#             return completions.choices[0].message.content +#         # catch context length / do not retry +#         except openai.error.InvalidRequestError as e: +#             return str(f"Error: {e}") +#         # catch authorization errors / do not retry +#         except openai.error.AuthenticationError as e: +#             return "Error: The provided OpenAI API key is invalid" +#         except Exception as e: +#             print(f"Retrying LLM call {e}") +#             raise Exception() + +#     async def generateStreaming( +#         self, +#         messages: List[str], +#         onTokenCallback=Callable[[str], None], +#     ) -> str: +#         result = [] +#         completions = openai.ChatCompletion.create( +#             model=self.model, +#             temperature=self.temperature, +#             max_tokens=self.max_tokens, +#             messages=messages, +#             stream=True, +#         ) +#         result = [] +#         for message in completions: +#             # Process the 
streamed messages or perform any other desired action +# delta = message["choices"][0]["delta"] +# if "content" in delta: +# result.append(delta["content"]) +# await onTokenCallback(message) +# return result + +# def num_tokens_from_string(self, string: str) -> int: +# encoding = tiktoken.encoding_for_model(self.model) +# num_tokens = len(encoding.encode(string)) +# return num_tokens + +# def max_allowed_token_length(self) -> int: +# # TODO: list all models and their max tokens from api +# return 2049 from typing import ( Callable, List, ) -import openai -import tiktoken -from llm.basellm import BaseLLM +# from transformers import LlamaForCausalLM, LlamaTokenizer +import torch +# from basellm import BaseLLM from retry import retry +# Load model directly +from transformers import AutoTokenizer, AutoModelForCausalLM + +# tokenizer = AutoTokenizer.from_pretrained("TheBloke/Llama-2-7B-32K-Instruct-GPTQ", trust_remote_code=True) +# model = AutoModelForCausalLM.from_pretrained("TheBloke/Llama-2-7B-32K-Instruct-GPTQ", trust_remote_code=True) -class OpenAIChat(BaseLLM): - """Wrapper around OpenAI Chat large language models.""" +class Llama2Chat(BaseLLM): + """Wrapper around HuggingFace Llama2 large language models.""" def __init__( self, - openai_api_key: str, - model_name: str = "gpt-3.5-turbo", - max_tokens: int = 1000, + model_name: str = "TheBloke/Llama-2-7B-32K-Instruct-GPTQ", + max_tokens: int = 2056, temperature: float = 0.0, ) -> None: - openai.api_key = openai_api_key - self.model = model_name + self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) self.max_tokens = max_tokens self.temperature = temperature @@ -30,50 +113,39 @@ def generate( messages: List[str], ) -> str: try: - completions = openai.ChatCompletion.create( - model=self.model, - temperature=self.temperature, - max_tokens=self.max_tokens, - messages=messages, - ) - return 
completions.choices[0].message.content - # catch context length / do not retry - except openai.error.InvalidRequestError as e: - return str(f"Error: {e}") - # catch authorization errors / do not retry - except openai.error.AuthenticationError as e: - return "Error: The provided OpenAI API key is invalid" + # Concatenate the messages into a single string + input_text = " ".join(messages) + inputs = self.tokenizer(input_text, return_tensors="pt", max_length=self.max_tokens, truncation=True) + outputs = self.model.generate(**inputs, max_length=self.max_tokens, temperature=self.temperature) + return self.tokenizer.decode(outputs[0], skip_special_tokens=True) except Exception as e: print(f"Retrying LLM call {e}") - raise Exception() + raise Exception(f"Error: {e}") async def generateStreaming( self, messages: List[str], onTokenCallback=Callable[[str], None], ) -> str: - result = [] - completions = openai.ChatCompletion.create( - model=self.model, - temperature=self.temperature, - max_tokens=self.max_tokens, - messages=messages, - stream=True, - ) - result = [] - for message in completions: - # Process the streamed messages or perform any other desired action - delta = message["choices"][0]["delta"] - if "content" in delta: - result.append(delta["content"]) - await onTokenCallback(message) - return result + try: + input_text = " ".join(messages) + inputs = self.tokenizer(input_text, return_tensors="pt", max_length=self.max_tokens, truncation=True) + outputs = self.model.generate(**inputs, max_length=self.max_tokens, temperature=self.temperature) + + result = [] + for token_id in outputs[0]: + token = self.tokenizer.decode(token_id, skip_special_tokens=True) + result.append(token) + await onTokenCallback(token) + return result + except Exception as e: + print(f"Error during streaming generation: {e}") + raise Exception(f"Error: {e}") def num_tokens_from_string(self, string: str) -> int: - encoding = tiktoken.encoding_for_model(self.model) - num_tokens = 
len(encoding.encode(string)) - return num_tokens + inputs = self.tokenizer(string, return_tensors="pt") + return inputs.input_ids.shape[1] def max_allowed_token_length(self) -> int: - # TODO: list all models and their max tokens from api - return 2049 + return self.tokenizer.model_max_length + diff --git a/api/src/main.py b/api/src/main.py index cbb35b2..c9ec7d6 100644 --- a/api/src/main.py +++ b/api/src/main.py @@ -17,7 +17,9 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from fewshot_examples import get_fewshot_examples -from llm.openai import OpenAIChat +from llm.openai import Llama2Chat +# from llm.openai import OpenAIChat + from pydantic import BaseModel @@ -79,11 +81,16 @@ async def questionProposalsForCurrentDb(payload: questionProposalPayload): questionProposalGenerator = QuestionProposalGenerator( database=neo4j_connection, - llm=OpenAIChat( - openai_api_key=api_key, - model_name="gpt-3.5-turbo-0613", + # llm=OpenAIChat( + # openai_api_key=api_key, + # model_name="gpt-3.5-turbo-0613", + # max_tokens=512, + # temperature=0.8, + llm=Llama2Chat( + # openai_api_key=api_key, + model_name="TheBloke/Llama-2-7B-32K-Instruct-GPTQ", max_tokens=512, - temperature=0.8, + temperature=0, ), ) @@ -128,17 +135,33 @@ async def onToken(token): ) api_key = openai_api_key if openai_api_key else data.get("api_key") - default_llm = OpenAIChat( - openai_api_key=api_key, - model_name=data.get("model_name", "gpt-3.5-turbo-0613"), + # default_llm = OpenAIChat( + # openai_api_key=api_key, + # model_name=data.get("model_name", "gpt-3.5-turbo-0613"), + # ) + + default_llm = Llama2Chat( + # openai_api_key=api_key, + model_name=data.get("model_name", "TheBloke/Llama-2-7B-32K-Instruct-GPTQ"), + max_tokens=512, + temperature=0, ) + + # summarize_results = SummarizeCypherResult( + # llm=OpenAIChat( + # openai_api_key=api_key, + # model_name="gpt-3.5-turbo-0613", + # max_tokens=128, + # ) + # ) summarize_results = SummarizeCypherResult( - 
llm=OpenAIChat( - openai_api_key=api_key, - model_name="gpt-3.5-turbo-0613", - max_tokens=128, - ) + llm=Llama2Chat( + # openai_api_key=api_key, + model_name=data.get("model_name", "TheBloke/Llama-2-7B-32K-Instruct-GPTQ"), + max_tokens=128, + temperature=0, ) + ) text2cypher = Text2Cypher( database=neo4j_connection, @@ -205,9 +228,16 @@ async def root(payload: ImportPayload): try: result = "" - llm = OpenAIChat( - openai_api_key=api_key, model_name="gpt-3.5-turbo-16k", max_tokens=4000 - ) + # llm = OpenAIChat( + # openai_api_key=api_key, model_name="gpt-3.5-turbo-16k", max_tokens=4000 + # ) + llm=Llama2Chat( + # openai_api_key=api_key, + model_name="TheBloke/Llama-2-7B-32K-Instruct-GPTQ", + max_tokens=512, + temperature=0, + ) + if not payload.neo4j_schema: extractor = DataExtractor(llm=llm) @@ -246,11 +276,17 @@ async def companyInformation(payload: companyReportPayload): ) api_key = openai_api_key if openai_api_key else payload.api_key - llm = OpenAIChat( - openai_api_key=api_key, - model_name="gpt-3.5-turbo-16k-0613", - max_tokens=512, - ) + # llm = OpenAIChat( + # openai_api_key=api_key, + # model_name="gpt-3.5-turbo-16k-0613", + # max_tokens=512, + # ) + llm=Llama2Chat( + model_name="TheBloke/Llama-2-7B-32K-Instruct-GPTQ", + max_tokens=512, + temperature=0, + ) + print("Running company report for " + payload.company) company_report = CompanyReport(neo4j_connection, payload.company, llm) result = company_report.run() diff --git a/docker-compose.yml b/docker-compose.yml index bf3039b..7bab8c7 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,4 +1,4 @@ -version: "3.7" +version: "3.8" services: backend: build: @@ -20,3 +20,8 @@ services: container_name: ui ports: - 4173:4173 + volumes: + - .:/app + - /app/node_modules + environment: + - NODE_ENV=development diff --git a/ui/Dockerfile b/ui/Dockerfile index e41e721..17c3a4b 100644 --- a/ui/Dockerfile +++ b/ui/Dockerfile @@ -10,10 +10,29 @@ COPY . /ui/. 
WORKDIR $FOLDER +# Copy CA certificate +# COPY your-ca-certificate.crt /usr/local/share/ca-certificates/your-ca-certificate.crt + +# # Update CA certificates +# RUN update-ca-certificates + +# # Configure npm to use the custom CA +# RUN npm config set cafile /usr/local/share/ca-certificates/your-ca-certificate.crt + +RUN rm -rf node_modules +RUN rm -rf dist + +# Disable strict SSL +RUN npm config set strict-ssl false +RUN npm install vite +RUN npm install --save react react-dom +RUN npm install --save-dev @types/react @types/react-dom typescript +# RUN npm install --save-dev @types/vite +RUN npm install react-use-websocket RUN npm install +# RUN npm run build EXPOSE 4173 -RUN npm run build -CMD ["npm", "run", "preview"] +CMD ["npm", "run", "start"] diff --git a/ui/import.meta.env b/ui/import.meta.env new file mode 100644 index 0000000..9c252e4 --- /dev/null +++ b/ui/import.meta.env @@ -0,0 +1,8 @@ +const apiUrl = import.meta.env.VITE_API_URL; +const API_KEY_ENDPOINT = import.meta.env.VITE_HAS_API_KEY_ENDPOINT; +const KG_CHAT_SAMPLE_QUESTIONS_ENDPOINT = import.meta.env.VITE_KG_CHAT_SAMPLE_QUESTIONS_ENDPOINT; + +console.log(apiUrl); +console.log(API_KEY_ENDPOINT); +console.log(KG_CHAT_SAMPLE_QUESTIONS_ENDPOINT); + diff --git a/ui/package.json b/ui/package.json index 1a74502..dd86c62 100644 --- a/ui/package.json +++ b/ui/package.json @@ -58,6 +58,6 @@ "vite": "^4.3.8" }, "engines": { - "node": "^18.0.0" + "node": ">=18.0.0" } } diff --git a/ui/src/chat-with-kg/App.tsx b/ui/src/chat-with-kg/App.tsx index 12a39c4..561beef 100644 --- a/ui/src/chat-with-kg/App.tsx +++ b/ui/src/chat-with-kg/App.tsx @@ -1,14 +1,11 @@ -import { useCallback, useEffect, useState, ChangeEvent } from "react"; +import React, { useCallback, useEffect, useState, ChangeEvent } from "react"; +import { useWebSocket, ReadyState } from "react-use-websocket"; + import ChatContainer from "./ChatContainer"; import type { ChatMessageObject } from "./ChatMessage"; import ChatInput from "./ChatInput"; -import 
useWebSocket, { ReadyState } from "react-use-websocket"; import KeyModal from "../components/keymodal"; -import type { - ConversationState, - WebSocketRequest, - WebSocketResponse, -} from "./types/websocketTypes"; +import type { ConversationState, WebSocketRequest, WebSocketResponse } from "../../types/websocketTypes"; const SEND_REQUESTS = true; @@ -49,7 +46,7 @@ function loadKeyFromStorage() { return localStorage.getItem("api_key"); } -const QUESTION_PREFIX_REGEXP = /^[0-9]{1,2}[\w]*[\.\)\-]*[\w]*/; +const QUESTION_PREFIX_REGEXP = /^[0-9]{1,2}[\\w]*[\\.\\)\\-]*[\\w]*/; function stripQuestionPrefix(question: string): string { if (question.match(QUESTION_PREFIX_REGEXP)) { @@ -59,255 +56,57 @@ function stripQuestionPrefix(question: string): string { } function App() { - const [serverAvailable, setServerAvailable] = useState(true); - const [needsApiKeyLoading, setNeedsApiKeyLoading] = useState(true); - const [needsApiKey, setNeedsApiKey] = useState(true); - const [chatMessages, setChatMessages] = useState(chatMessageObjects); - const [conversationState, setConversationState] = - useState("ready"); - const { sendJsonMessage, lastMessage, readyState } = useWebSocket(URI, { + const [messages, setMessages] = useState(chatMessageObjects); + const [apiKey, setApiKey] = useState(loadKeyFromStorage()); + const [isKeyModalOpen, setIsKeyModalOpen] = useState(!apiKey); + + const { sendMessage, lastMessage, readyState } = useWebSocket(URI, { + queryParams: { apiKey: apiKey ?? 
"" }, + onOpen: () => console.log("Connected to WebSocket"), + onClose: () => console.log("Disconnected from WebSocket"), shouldReconnect: () => true, - reconnectInterval: 5000, }); - const [errorMessage, setErrorMessage] = useState(null); - const [modalIsOpen, setModalIsOpen] = useState(false); - const [apiKey, setApiKey] = useState(loadKeyFromStorage() || ""); - const [sampleQuestions, setSampleQuestions] = useState([]); - const [text2cypherModel, setText2cypherModel] = useState("gpt-3.5-turbo-0613"); - - const showContent = serverAvailable && !needsApiKeyLoading; - - function loadSampleQuestions() { - const body = { - api_key: apiKey, - }; - const options = { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify(body), - }; - fetch(QUESTIONS_URI, options).then( - (response) => { - response.json().then( - (result) => { - if (result.output && result.output.length > 0) { - setSampleQuestions(result.output.map(stripQuestionPrefix)); - } else { - setSampleQuestions([]); - } - }, - (error) => { - setSampleQuestions([]); - } - ); - }, - (error) => { - setSampleQuestions([]); - } - ); - } useEffect(() => { - fetch(HAS_API_KEY_URI).then( - (response) => { - response.json().then( - (result) => { - // const needsKey = result.output; - const needsKey = !result.output; - setNeedsApiKey(needsKey); - setNeedsApiKeyLoading(false); - if (needsKey) { - const api_key = loadKeyFromStorage(); - if (api_key) { - setApiKey(api_key); - loadSampleQuestions(); - } else { - setModalIsOpen(true); - } - } else { - loadSampleQuestions(); - } - }, - (error) => { - setNeedsApiKeyLoading(false); - setServerAvailable(false); - } - ); - }, - (error) => { - setNeedsApiKeyLoading(false); - setServerAvailable(false); - } - ); - }, []); - - useEffect(() => { - if (!lastMessage || !serverAvailable) { - return; - } - - const websocketResponse = JSON.parse(lastMessage.data) as WebSocketResponse; - - if (websocketResponse.type === "debug") { - 
console.log(websocketResponse.detail); - } else if (websocketResponse.type === "error") { - setConversationState("error"); - setErrorMessage(websocketResponse.detail); - console.error(websocketResponse.detail); - } else if (websocketResponse.type === "start") { - setConversationState("streaming"); - - setChatMessages((chatMessages) => [ - ...chatMessages, - { - id: chatMessages.length, - type: "text", - sender: "bot", - message: "", - complete: false, - }, + if (lastMessage !== null) { + const response: WebSocketResponse = JSON.parse(lastMessage.data); + setMessages((prevMessages) => [ + ...prevMessages, + { id: prevMessages.length, type: "text", sender: "bot", message: response.data, complete: true }, ]); - } else if (websocketResponse.type === "stream") { - setChatMessages((chatMessages) => { - const lastChatMessage = chatMessages[chatMessages.length - 1]; - const rest = chatMessages.slice(0, -1); - - return [ - ...rest, - { - ...lastChatMessage, - message: lastChatMessage.message + websocketResponse.output, - }, - ]; - }); - } else if (websocketResponse.type === "end") { - setChatMessages((chatMessages) => { - const lastChatMessage = chatMessages[chatMessages.length - 1]; - const rest = chatMessages.slice(0, -1); - return [ - ...rest, - { - ...lastChatMessage, - complete: true, - cypher: websocketResponse.generated_cypher, - }, - ]; - }); - setConversationState("ready"); } }, [lastMessage]); - useEffect(() => { - if (conversationState === "error") { - const timeout = setTimeout(() => { - setConversationState("ready"); - }, 1000); - return () => clearTimeout(timeout); - } - }, [conversationState]); - - const sendQuestion = (question: string) => { - const webSocketRequest: WebSocketRequest = { - type: "question", - question: question, - }; - if (serverAvailable && !needsApiKeyLoading && needsApiKey && apiKey) { - webSocketRequest.api_key = apiKey; - } - webSocketRequest.model_name = text2cypherModel; - sendJsonMessage(webSocketRequest); + const handleSendMessage = 
(message: string) => { + const newMessage: ChatMessageObject = { id: messages.length, type: "input", sender: "self", message, complete: true }; + setMessages((prevMessages) => [...prevMessages, newMessage]); + const request: WebSocketRequest = { message }; + sendMessage(JSON.stringify(request)); }; - const onChatInput = (message: string) => { - if (conversationState === "ready") { - setChatMessages((chatMessages) => - chatMessages.concat([ - { - id: chatMessages.length, - type: "input", - sender: "self", - message: message, - complete: true, - }, - ]) - ); - if (SEND_REQUESTS) { - setConversationState("waiting"); - sendQuestion(message); - } - setErrorMessage(null); - } + const handleApiKeyChange = (key: string) => { + setApiKey(key); + localStorage.setItem("api_key", key); }; - const openModal = () => { - setModalIsOpen(true); + const handleCloseModal = () => { + setIsKeyModalOpen(false); }; - const onCloseModal = () => { - setModalIsOpen(false); - if (apiKey && sampleQuestions.length === 0) { - loadSampleQuestions(); - } - }; - - const onApiKeyChange = (newApiKey: string) => { - setApiKey(newApiKey); - localStorage.setItem("api_key", newApiKey); - }; - - const handleModelChange = (e: ChangeEvent) => { - setText2cypherModel(e.target.value) - } - return ( -
- {needsApiKey && ( -
- -
- )} -
- -
-
- {!serverAvailable && ( -
Server is unavailable, please reload the page to try again.
- )} - {serverAvailable && needsApiKeyLoading &&
Initializing...
} - - {showContent && readyState === ReadyState.OPEN && ( - <> - - - {errorMessage} - - )}{" "} - {showContent && readyState === ReadyState.CONNECTING && ( -
Connecting...
- )} - {showContent && readyState === ReadyState.CLOSED && ( -
-
Could not connect to server, reconnecting...
-
- )} -
+
+
+

Chat with Knowledge Graph

+
+ + +
); } diff --git a/ui/src/chat-with-kg/main.tsx b/ui/src/chat-with-kg/main.tsx index 3fb9bd4..a427d7d 100644 --- a/ui/src/chat-with-kg/main.tsx +++ b/ui/src/chat-with-kg/main.tsx @@ -1,8 +1,13 @@ import React from "react"; import { createRoot } from "react-dom/client"; -import App from "./App.js"; +import App from "./App"; import Modal from "react-modal"; +import ReactDOM from 'react-dom'; + +// import * as React from 'react'; +// import * as ReactDOM from 'react-dom'; +ReactDOM.render(, document.getElementById('root')); import "@neo4j-ndl/base/lib/neo4j-ds-styles.css"; import "./index.css"; diff --git a/ui/src/chat-with-kg/types/websocketTypes.ts b/ui/src/chat-with-kg/types/websocketTypes.ts deleted file mode 100644 index f88cfa6..0000000 --- a/ui/src/chat-with-kg/types/websocketTypes.ts +++ /dev/null @@ -1,28 +0,0 @@ -export type WebSocketRequest = { - type: "question"; - question: string; - api_key?: string; - model_name?: string; -}; - -export type WebSocketResponse = - | { type: "start" } - | { - type: "stream"; - output: string; - } - | { - type: "end"; - output: string; - generated_cypher: string | null; - } - | { - type: "error"; - detail: string; - } - | { - type: "debug"; - detail: string; - }; - -export type ConversationState = "waiting" | "streaming" | "ready" | "error"; diff --git a/ui/src/react-use-websocket.d.ts b/ui/src/react-use-websocket.d.ts new file mode 100644 index 0000000..63d9a98 --- /dev/null +++ b/ui/src/react-use-websocket.d.ts @@ -0,0 +1,38 @@ +// react-use-websocket.d.ts +declare module 'react-use-websocket' { + import { ComponentType } from 'react'; + + export type ReadyState = number; + export type SendMessage = (message: string) => void; + export type Options = { + retryOnError?: boolean; + reconnectAttempts?: number; + reconnectInterval?: number; + share?: boolean; + onOpen?: () => void; + onClose?: () => void; + onMessage?: (message: WebSocketEventMap['message']) => void; + onError?: (error: WebSocketEventMap['error']) => void; + 
filter?: () => boolean; + }; + + export function useWebSocket( + url: string, + options?: Options + ): { + sendMessage: SendMessage; + sendJsonMessage: (message: any) => void; + lastMessage: WebSocketEventMap['message'] | null; + readyState: ReadyState; + }; + + export const ReadyState: { + CONNECTING: number; + OPEN: number; + CLOSING: number; + CLOSED: number; + }; + + const WebSocketComponent: ComponentType<{ url: string; options?: Options }>; + export default WebSocketComponent; +} diff --git a/ui/src/vite-env.d.ts b/ui/src/vite-env.d.ts new file mode 100644 index 0000000..55659a4 --- /dev/null +++ b/ui/src/vite-env.d.ts @@ -0,0 +1,26 @@ +// src/vite-env.d.ts + +interface ImportMetaEnv { + readonly VITE_API_URL: string; + readonly VITE_ANOTHER_ENV_VAR: string; + // Add other environment variables here as needed + } + + interface ImportMeta { + readonly env: ImportMetaEnv; + } + + +/// + + +interface ImportMetaEnv { + readonly VITE_KG_CHAT_BACKEND_ENDPOINT: string; + readonly VITE_HAS_API_KEY_ENDPOINT: string; + readonly VITE_KG_CHAT_SAMPLE_QUESTIONS_ENDPOINT: string; + // Add other environment variables here... 
+} + +interface ImportMeta { + readonly env: ImportMetaEnv; +} \ No newline at end of file diff --git a/ui/tsconfig.json b/ui/tsconfig.json index 1cb4f27..e24b39d 100644 --- a/ui/tsconfig.json +++ b/ui/tsconfig.json @@ -7,15 +7,23 @@ "skipLibCheck": true, "esModuleInterop": false, "allowSyntheticDefaultImports": true, - "strict": true, - "forceConsistentCasingInFileNames": true, + "noFallthroughCasesInSwitch": true, + "strict": true, // + "forceConsistentCasingInFileNames": true, // "module": "ESNext", "moduleResolution": "Node", "resolveJsonModule": true, "isolatedModules": true, "noEmit": true, "types": ["vite/client"], - "jsx": "react-jsx" - }, - "include": ["src"], - } \ No newline at end of file + "jsx": "react-jsx", + "baseUrl": "./src" + }, + "include": [ + "src", + // "src/vite-env.d.ts", // Adjust the path according to your setup + "types", + // "src/react-use-websocket.d.ts" // Adjust the path as needed + ] +} + diff --git a/ui/types/vite-env.d.ts b/ui/types/vite-env.d.ts new file mode 100644 index 0000000..b03d52e --- /dev/null +++ b/ui/types/vite-env.d.ts @@ -0,0 +1,14 @@ +// vite-env.d.ts + +interface ImportMetaEnv { + readonly VITE_API_URL: string; + readonly VITE_ANOTHER_ENV_VAR: string; + // Add other environment variables here as needed + } + + interface ImportMeta { + readonly env: ImportMetaEnv; + } + + +/// \ No newline at end of file diff --git a/ui/types/websocketTypes.ts b/ui/types/websocketTypes.ts new file mode 100644 index 0000000..8e1d672 --- /dev/null +++ b/ui/types/websocketTypes.ts @@ -0,0 +1,27 @@ +// types/websocketTypes.ts + +export type ConversationState = "ready" | "waiting" | "streaming" | "error"; + +export interface WebSocketRequest { + type: string; + message: string; + question: string; + api_key?: string; + model_name?: string; +} + +export interface WebSocketResponse { + type: string; + detail: string; + output?: string; + generated_cypher?: string; +} + +export interface ChatMessageObject { + id: number; + type: "input" | 
"text"; + sender: "self" | "bot"; + message: string; + complete: boolean; +} + diff --git a/ui/vite.config.js b/ui/vite.config.js index a28ef37..70d4c52 100644 --- a/ui/vite.config.js +++ b/ui/vite.config.js @@ -1,6 +1,6 @@ -import { resolve } from 'path' -import { defineConfig } from 'vite' -import react from '@vitejs/plugin-react' +import { resolve } from 'path'; +import { defineConfig } from 'vite'; +import react from '@vitejs/plugin-react'; // https://vitejs.dev/config/ export default defineConfig({