diff --git a/electrmangnetic/src/main/java/com/electromagnetic/industry/software/manage/ai/ChatTaskThread2.java b/electrmangnetic/src/main/java/com/electromagnetic/industry/software/manage/ai/ChatTaskThread.java similarity index 81% rename from electrmangnetic/src/main/java/com/electromagnetic/industry/software/manage/ai/ChatTaskThread2.java rename to electrmangnetic/src/main/java/com/electromagnetic/industry/software/manage/ai/ChatTaskThread.java index 5a8bbce..bbc2e2b 100644 --- a/electrmangnetic/src/main/java/com/electromagnetic/industry/software/manage/ai/ChatTaskThread2.java +++ b/electrmangnetic/src/main/java/com/electromagnetic/industry/software/manage/ai/ChatTaskThread.java @@ -10,12 +10,13 @@ import java.util.concurrent.Callable; @AllArgsConstructor @NoArgsConstructor -public class ChatTaskThread2 implements Callable> { +public class ChatTaskThread implements Callable> { + private ChatService chatService; private QueryDTO queryDTO; @Override public Flux call() throws Exception { - return chatService.chatStreamStr(queryDTO.getMsg()); + return chatService.chatStreamStr(queryDTO); } } diff --git a/electrmangnetic/src/main/java/com/electromagnetic/industry/software/manage/ai/ChatTaskThread1.java b/electrmangnetic/src/main/java/com/electromagnetic/industry/software/manage/ai/ChatTaskThread1.java deleted file mode 100644 index 0642af1..0000000 --- a/electrmangnetic/src/main/java/com/electromagnetic/industry/software/manage/ai/ChatTaskThread1.java +++ /dev/null @@ -1,21 +0,0 @@ -package com.electromagnetic.industry.software.manage.ai; - -import com.electromagnetic.industry.software.manage.pojo.req.QueryDTO; -import com.electromagnetic.industry.software.manage.service.serviceimpl.ChatService; -import lombok.AllArgsConstructor; -import lombok.NoArgsConstructor; - -import java.util.concurrent.Callable; - -@NoArgsConstructor -@AllArgsConstructor -public class ChatTaskThread1 implements Callable { - - private ChatService chatService; - private QueryDTO queryDTO; - - @Override - public String call() throws Exception { - return chatService.chat(queryDTO); - } -} diff --git a/electrmangnetic/src/main/java/com/electromagnetic/industry/software/manage/ai/ChatTaskThread3.java b/electrmangnetic/src/main/java/com/electromagnetic/industry/software/manage/ai/ChatTaskThread3.java deleted file mode 100644 index a5d26fb..0000000 --- a/electrmangnetic/src/main/java/com/electromagnetic/industry/software/manage/ai/ChatTaskThread3.java +++ /dev/null @@ -1,23 +0,0 @@ -package com.electromagnetic.industry.software.manage.ai; - -import com.electromagnetic.industry.software.manage.pojo.req.QueryDTO; -import com.electromagnetic.industry.software.manage.service.serviceimpl.ChatService; -import lombok.AllArgsConstructor; -import lombok.NoArgsConstructor; -import org.springframework.ai.chat.model.ChatResponse; -import reactor.core.publisher.Flux; - -import java.util.concurrent.Callable; - -@AllArgsConstructor -@NoArgsConstructor -public class ChatTaskThread3 implements Callable> { - - private ChatService chatService; - private QueryDTO queryDTO; - - @Override - public Flux call() throws Exception { - return chatService.chatStreamResponse(queryDTO.getMsg()); - } -} diff --git a/electrmangnetic/src/main/java/com/electromagnetic/industry/software/manage/controller/AiController.java b/electrmangnetic/src/main/java/com/electromagnetic/industry/software/manage/controller/AiController.java index 07ae5d7..5a53cd0 100644 --- a/electrmangnetic/src/main/java/com/electromagnetic/industry/software/manage/controller/AiController.java +++ b/electrmangnetic/src/main/java/com/electromagnetic/industry/software/manage/controller/AiController.java @@ -2,16 +2,12 @@ package com.electromagnetic.industry.software.manage.controller; import cn.hutool.core.util.StrUtil; import com.electromagnetic.industry.software.common.resp.ElectromagneticResult; -import com.electromagnetic.industry.software.common.util.ElectromagneticResultUtil; -import com.electromagnetic.industry.software.manage.ai.ChatTaskThread1; -import com.electromagnetic.industry.software.manage.ai.ChatTaskThread2; -import com.electromagnetic.industry.software.manage.ai.ChatTaskThread3; +import com.electromagnetic.industry.software.manage.ai.ChatTaskThread; import com.electromagnetic.industry.software.manage.ai.ThreadUtil; import com.electromagnetic.industry.software.manage.pojo.req.QueryDTO; import com.electromagnetic.industry.software.manage.service.serviceimpl.ChatService; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; -import org.springframework.ai.chat.model.ChatResponse; import org.springframework.http.MediaType; import org.springframework.web.bind.annotation.*; import org.springframework.web.multipart.MultipartFile; @@ -33,34 +29,34 @@ public class AiController { return chatService.addFromUpload(file); } - @PostMapping("/chat") - public ElectromagneticResult chat(@RequestBody QueryDTO queryDTO) throws Exception { - log.info("question is --->" + queryDTO.getMsg()); - ChatTaskThread1 chatTaskThread = new ChatTaskThread1(chatService, queryDTO); - Future future = ThreadUtil.getThreadPool().submit(chatTaskThread); - String res = future.get(); - log.info("answer is --->" + res); - return ElectromagneticResultUtil.success(res); - } +// @PostMapping("/chat") +// public ElectromagneticResult chat(@RequestBody QueryDTO queryDTO) throws Exception { +// log.info("question is --->" + queryDTO.getMsg()); +// ChatTaskThread chatTaskThread = new ChatTaskThread<>(chatService, queryDTO); +// Future future = ThreadUtil.getThreadPool().submit(chatTaskThread); +// String res = future.get(); +// log.info("answer is --->" + res); +// return ElectromagneticResultUtil.success(res); +// } @PostMapping(path = "/chatStreamStr", produces = MediaType.TEXT_EVENT_STREAM_VALUE) public Flux chatStreamStr(@RequestBody QueryDTO queryDTO) throws ExecutionException, InterruptedException { if (StrUtil.isEmpty(queryDTO.getMsg())) { return Flux.empty(); } - ChatTaskThread2 chatTaskThread = new ChatTaskThread2(chatService, queryDTO); + ChatTaskThread chatTaskThread = new ChatTaskThread(chatService, queryDTO); Future> future = ThreadUtil.getThreadPool().submit(chatTaskThread); return future.get(); } - @PostMapping(path = "/chatStreamResp", produces = MediaType.TEXT_EVENT_STREAM_VALUE) - public Flux chatStreamResp(@RequestBody QueryDTO queryDTO) throws ExecutionException, InterruptedException { - if (StrUtil.isEmpty(queryDTO.getMsg())) { - return Flux.empty(); - } - ChatTaskThread3 chatTaskThread = new ChatTaskThread3(chatService, queryDTO); - Future> future = ThreadUtil.getThreadPool().submit(chatTaskThread); - return future.get(); - } +// @PostMapping(path = "/chatStreamResp", produces = MediaType.TEXT_EVENT_STREAM_VALUE) +// public Flux chatStreamResp(@RequestBody QueryDTO queryDTO) throws ExecutionException, InterruptedException { +// if (StrUtil.isEmpty(queryDTO.getMsg())) { +// return Flux.empty(); +// } +// ChatTaskThread> chatTaskThread = new ChatTaskThread<>(chatService, queryDTO); +// Future> future = ThreadUtil.getThreadPool().submit(chatTaskThread); +// return future.get(); +// } } diff --git a/electrmangnetic/src/main/java/com/electromagnetic/industry/software/manage/service/serviceimpl/ChatService.java b/electrmangnetic/src/main/java/com/electromagnetic/industry/software/manage/service/serviceimpl/ChatService.java index ec09401..16a1a80 100644 --- a/electrmangnetic/src/main/java/com/electromagnetic/industry/software/manage/service/serviceimpl/ChatService.java +++ b/electrmangnetic/src/main/java/com/electromagnetic/industry/software/manage/service/serviceimpl/ChatService.java @@ -17,12 +17,8 @@ import com.electromagnetic.industry.software.manage.pojo.req.QueryDTO; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor; import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor; import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.document.Document; import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.ai.vectorstore.VectorStore; @@ -107,25 +103,25 @@ public class ChatService { return ElectromagneticResultUtil.success(fileMd5); } - public String chat(QueryDTO queryDTO) { +// public String chat(String msg) { +// +// log.info("Start call model to answer"); +// +// return ChatClient.builder(model).defaultAdvisors(messageChatMemoryAdvisor, questionAnswerAdvisor).build().prompt() +// .user(msg) +// .advisors(advisorSpec -> advisorSpec +//// .param(CHAT_MEMORY_CONVERSATION_ID_KEY, queryDTO.getUserId()) +// .param(AbstractChatMemoryAdvisor.CHAT_MEMORY_RETRIEVE_SIZE_KEY, 100)) +// .call() +// .content(); +// } +// +// public Flux chatStreamResponse(String msg) { +// ChatClient.StreamResponseSpec stream = ChatClient.builder(model).defaultAdvisors(messageChatMemoryAdvisor, questionAnswerAdvisor).build().prompt(new Prompt(new UserMessage(msg))).stream(); +// return stream.chatResponse(); +// } - log.info("Start call model to answer"); - - return ChatClient.builder(model).defaultAdvisors(messageChatMemoryAdvisor, questionAnswerAdvisor).build().prompt() - .user(queryDTO.getMsg()) - .advisors(advisorSpec -> advisorSpec -// .param(CHAT_MEMORY_CONVERSATION_ID_KEY, queryDTO.getUserId()) - .param(AbstractChatMemoryAdvisor.CHAT_MEMORY_RETRIEVE_SIZE_KEY, 100)) - .call() - .content(); - } - - public Flux chatStreamStr(String msg) { - return ChatClient.builder(model).defaultAdvisors(messageChatMemoryAdvisor, questionAnswerAdvisor).build().prompt(msg).stream().content(); - } - - public Flux chatStreamResponse(String msg) { - ChatClient.StreamResponseSpec stream = ChatClient.builder(model).defaultAdvisors(messageChatMemoryAdvisor, questionAnswerAdvisor).build().prompt(new Prompt(new UserMessage(msg))).stream(); - return stream.chatResponse(); + public Flux chatStreamStr(QueryDTO queryDTO) { + return ChatClient.builder(model).defaultAdvisors(messageChatMemoryAdvisor, questionAnswerAdvisor).build().prompt(queryDTO.getMsg()).stream().content(); } }