Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/main/java/com/UoB/AILearningTool/DatabaseController.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import org.springframework.stereotype.Service;


// Communication with SQL database.
@Service
public class DatabaseController {
private Map<String, User> users = new HashMap<>();
private Map<String, Chat> chats = new HashMap<>();
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/com/UoB/AILearningTool/OpenAIAPIController.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import org.springframework.stereotype.Service;

@Service
public class OpenAIAPIController {
private final Logger log = LoggerFactory.getLogger(OpenAIAPIController.class);
private final OpenAIAuthenticator authenticator;
Expand Down
112 changes: 75 additions & 37 deletions src/main/java/com/UoB/AILearningTool/SpringController.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,24 @@
import org.springframework.web.bind.annotation.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.io.IOException;

@RestController
@Service
public class SpringController {
private final Logger log = LoggerFactory.getLogger(SpringController.class);
private final DatabaseController DBC = new DatabaseController();
private final DatabaseController DBC;
// TODO: Replace with OpenAIAPIController with WatsonxAPIController when API quota issue will be resolved.
// private final WatsonxAPIController WXC = new WatsonxAPIController();
private final OpenAIAPIController WXC = new OpenAIAPIController();
private final OpenAIAPIController WXC;

@Autowired
public SpringController(DatabaseController DBC, OpenAIAPIController WXC) {
this.DBC = DBC;
this.WXC = WXC;
}
// Assign a unique user ID for the user.
@GetMapping("/signup")
public void signup(@CookieValue(value = "optionalConsent", defaultValue = "false") boolean optionalConsent,
Expand Down Expand Up @@ -56,17 +63,27 @@ public void createChat(@CookieValue(value = "userID", defaultValue = "") String
// Create a chat
User user = DBC.getUser(userID);
if (user != null) {
WatsonxResponse wresponse;
String chatID = DBC.createChat(user, initialMessage);

// Send the message history (system prompt and initial message) to Watsonx API,
// add the AI response to the message history of the chat.
try {
wresponse = WXC.sendUserMessage(DBC.getChat(DBC.getUser(userID), chatID).getMessageHistory(user));
response.getWriter().write(chatID);
response.setStatus(wresponse.statusCode);
if (wresponse.statusCode == 200) {
DBC.getChat(DBC.getUser(userID), chatID).addAIMessage(userID, wresponse.responseText);
// Grab the chat one time
Chat chat = DBC.getChat(user, chatID);
if (chat != null) {
String messageHistory = chat.getMessageHistory(user);
WatsonxResponse wresponse = WXC.sendUserMessage(messageHistory);

response.getWriter().write(chatID);
response.setStatus(wresponse.statusCode);

if (wresponse.statusCode == 200) {
// Reuse the same chat and user objects, instead of creating new ones
chat.addAIMessage(userID, wresponse.responseText);
}
} else {
response.getWriter().write("null");
response.setStatus(400);
}
} catch (IOException e) {
log.warn(String.valueOf(e));
Expand All @@ -87,25 +104,34 @@ public void sendMessage(@CookieValue(value = "userID", defaultValue = "") String
@RequestParam(name = "newMessage") String newMessage,
@RequestParam(name = "chatID") String chatID,
HttpServletResponse response) {
Chat chat = DBC.getChat(DBC.getUser(userID), chatID);
String inputString = chat.getMessageHistory(DBC.getUser(userID));

// Default response - Unauthorised
WatsonxResponse wresponse = new WatsonxResponse(401, "");
response.setContentType("text/plain");
response.setStatus(401);

if (chat != null && inputString != null) {
// If a message can be added to the message history of a chat, then send the message history
// to Watsonx API.
boolean success = chat.addUserMessage(userID, newMessage);
if (success) {
inputString = chat.getMessageHistory(DBC.getUser(userID));
// TODO: Revert wresponse when issue with Watsonx API quota will be resolved.
// wresponse = WXC.sendUserMessage(StringTools.messageHistoryPrepare(inputString));
wresponse = WXC.sendUserMessage(inputString);
response.setStatus(wresponse.statusCode);
}
response.setStatus(401); // Default

// 1) Get the user
User user = DBC.getUser(userID);
// 2) get the chat
Chat chat = DBC.getChat(user, chatID);
if (chat == null) {
// null chat
return;
}

// 3) Get existing history
String inputString = chat.getMessageHistory(user);
if (inputString == null) {
return;
}

// 4) Try to add the new user message
boolean success = chat.addUserMessage(userID, newMessage);
if (success) {
// 5) Re-fetch history & call AI
inputString = chat.getMessageHistory(user);
WatsonxResponse wresponse = WXC.sendUserMessage(inputString);

// Update status from AI
response.setStatus(wresponse.statusCode);

try {
if (wresponse.statusCode == 200) {
chat.addAIMessage(userID, wresponse.responseText);
Expand Down Expand Up @@ -152,24 +178,36 @@ public void sendIncognitoMessage(@CookieValue(value = "userID") String userID,
public void getChatHistory(@CookieValue(value = "userID", defaultValue = "") String userID,
@RequestParam(name = "chatID") String chatID,
HttpServletResponse response) {
User user = DBC.getUser(userID);
Chat chat = DBC.getChat(DBC.getUser(userID), chatID);
String messageHistory = chat.getMessageHistory(DBC.getUser(userID));
if (chat != null && messageHistory != null) {
response.setContentType("text/plain");
response.setStatus(200);
try {
response.getWriter().write(messageHistory);
} catch (IOException e) {
log.warn(String.valueOf(e));
if (chat == null) {
response.setStatus(401);
try {
response.getWriter().write("");
} catch (IOException e) {
log.warn(String.valueOf(e));
}
} else {
response.setContentType("text/plain");
return;
}

String messageHistory = chat.getMessageHistory(user);
if (messageHistory == null) {
response.setStatus(401);
try {
response.getWriter().write("");
} catch (IOException e) {
log.warn(String.valueOf(e));
}
return;

}
// If we reach here, chat != null and messageHistory != null
response.setStatus(200);
try {
response.getWriter().write(messageHistory);
} catch (IOException e) {
log.warn(String.valueOf(e));
response.setStatus(500);
}
}
}

This file was deleted.

140 changes: 140 additions & 0 deletions src/test/java/com/UoB/AILearningTool/SpringControllerTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package com.UoB.AILearningTool;

import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.BeforeEach;
import org.mockito.*;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.Cookie;
import org.mockito.junit.jupiter.MockitoExtension;

import java.io.IOException;
import java.io.PrintWriter;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;

// integrates Mockito with JUnit
@ExtendWith(MockitoExtension.class)
class SpringControllerTest {
@Mock
private DatabaseController mockDBController;

@Mock
private OpenAIAPIController mockWXC;

@Mock
private HttpServletResponse mockResponse;

@Mock
private PrintWriter mockWriter;

// This Tells Mockito to inject the @Mock fields into SpringController's constructor
@InjectMocks
private SpringController springController;

@Captor
private ArgumentCaptor<Cookie> cookieCaptor;

@BeforeEach
public void setUp() throws IOException {
Mockito.lenient().when(mockResponse.getWriter()).thenReturn(mockWriter);
}

@Test
public void testNewUserCreated() {

// mock addUser to return a specific ID "user123" when called wit optional consent = true
boolean optionalConsent = true;
String generatedUserID = "user123";
when(mockDBController.addUser(optionalConsent)).thenReturn(generatedUserID);

// call the signup method with the arranged parameters
springController.signup(optionalConsent, mockResponse);

// verify that addUser was called with optionalConsent = true
verify(mockDBController, times(1)).addUser(optionalConsent);

// capture the cookie added to the response
verify(mockResponse, times(1)).addCookie(cookieCaptor.capture());
Cookie capturedCookie = cookieCaptor.getValue();


// assertions on the captured cookie
assertEquals("userID", capturedCookie.getName(), "Cookie name should be userID");
assertEquals(generatedUserID, capturedCookie.getValue(), "Cookie value should match generated userID");
assertEquals(30 * 24 * 60 * 60, capturedCookie.getMaxAge(), "Cookie max age should be 30 days");
}

@Test
public void testRevokeConsent() {

// define a userID "user123" to revoke consent for
String userID = "user123";

// mock removeUser to return true
when(mockDBController.removeUser(userID)).thenReturn(true);

springController.revokeConsent(userID, mockResponse);

// verify that removeUser was called with the correct userID
verify(mockDBController, times(1)).removeUser(userID);

// verify that the status was set to 200 when user was removed
verify(mockResponse, times(1)).setStatus(HttpServletResponse.SC_OK);

// capture the cookie added to the response
verify(mockResponse, times(1)).addCookie(cookieCaptor.capture());
Cookie capturedCookie = cookieCaptor.getValue();

// assertions on captured cookie
assertEquals("userID", capturedCookie.getName(), "Cookie name should be userID");
assertEquals("", capturedCookie.getValue(), "Cookie value should be empty");
assertEquals(0, capturedCookie.getMaxAge(), "Cookie max age should be 0");
}

@Test
void testCreateChat() {
// Given
String userID = "user123";
String initialMessage = "Hello chatbot!";
String generatedChatID = "abc123";
User mockUser = new User(true);
Chat mockChat = mock(Chat.class);

// Mock out DB calls
when(mockDBController.getUser(userID)).thenReturn(mockUser);
when(mockDBController.createChat(mockUser, initialMessage)).thenReturn(generatedChatID);
when(mockDBController.getChat(mockUser, generatedChatID)).thenReturn(mockChat);

// Mock chat & AI calls
String messageHistory = "System prompt\nHello chatbot!";
when(mockChat.getMessageHistory(mockUser)).thenReturn(messageHistory);

WatsonxResponse aiResponse = new WatsonxResponse(200, "AI says hi");
when(mockWXC.sendUserMessage(messageHistory)).thenReturn(aiResponse);

// When
springController.createChat(userID, initialMessage, mockResponse);

// Then
verify(mockDBController).getUser(userID);
verify(mockDBController).createChat(mockUser, initialMessage);
verify(mockDBController).getChat(mockUser, generatedChatID);

verify(mockChat).getMessageHistory(mockUser);
verify(mockWXC).sendUserMessage(messageHistory);

// AI responded with status 200, so we expect addAIMessage to be called:
verify(mockChat).addAIMessage(userID, "AI says hi");

// Finally, verify response
verify(mockResponse).setContentType("text/plain");
verify(mockResponse).setStatus(200);
verify(mockWriter).write(generatedChatID);
}

}



Loading