From 84a7fdcd07d7f490a6796e32c67874ca8678e62a Mon Sep 17 00:00:00 2001 From: Yang Luo Date: Tue, 2 May 2023 01:30:06 +0800 Subject: [PATCH] Handle message answer --- ai/ai.go | 2 +- controllers/message.go | 47 ++++++++++++++++++++++++++----- object/message.go | 1 + web/src/ChatBox.js | 12 ++++++-- web/src/ChatPage.js | 26 +++++++++++++++++ web/src/MessageListPage.js | 1 + web/src/backend/MessageBackend.js | 21 +++++++++----- 7 files changed, 93 insertions(+), 17 deletions(-) diff --git a/ai/ai.go b/ai/ai.go index 3db48c40..2b1a046f 100644 --- a/ai/ai.go +++ b/ai/ai.go @@ -85,7 +85,7 @@ func QueryAnswerStream(authToken string, question string, writer io.Writer, buil openai.CompletionRequest{ Model: openai.GPT3TextDavinci003, Prompt: question, - MaxTokens: 50, + MaxTokens: 3000, Stream: true, }, ) diff --git a/controllers/message.go b/controllers/message.go index 902c1715..6766df67 100644 --- a/controllers/message.go +++ b/controllers/message.go @@ -74,6 +74,14 @@ func (c *ApiController) GetMessage() { c.ServeJSON() } +func (c *ApiController) ResponseErrorStream(errorText string) { + event := fmt.Sprintf("event: myerror\ndata: %s\n\n", errorText) + _, err := c.Ctx.ResponseWriter.Write([]byte(event)) + if err != nil { + panic(err) + } +} + // GetMessageAnswer // @Title GetMessageAnswer // @Tag Message API @@ -84,33 +92,48 @@ func (c *ApiController) GetMessage() { func (c *ApiController) GetMessageAnswer() { id := c.Input().Get("id") + c.Ctx.ResponseWriter.Header().Set("Content-Type", "text/event-stream") + c.Ctx.ResponseWriter.Header().Set("Cache-Control", "no-cache") + c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive") + message := object.GetMessage(id) if message == nil { - c.ResponseError(fmt.Sprintf(c.T("chat:The message: %s is not found"), id)) + c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The message: %s is not found"), id)) + return + } + + if message.Author != "AI" || message.ReplyTo == "" || message.Text != "" { + c.ResponseErrorStream(c.T("chat:The message is invalid")) return } chatId := util.GetId(message.Owner, message.Chat) chat := object.GetChat(chatId) if chat == nil { - c.ResponseError(fmt.Sprintf(c.T("chat:The chat: %s is not found"), chatId)) + c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The chat: %s is not found"), chatId)) return } if chat.Type != "AI" { - c.ResponseError(c.T("chat:The chat type must be \"AI\"")) + c.ResponseErrorStream(c.T("chat:The chat type must be \"AI\"")) + return + } + + questionMessage := object.GetMessage(message.ReplyTo) + if questionMessage == nil { + c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The message: %s is not found"), id)) return } providerId := util.GetId(chat.Owner, chat.User2) provider := object.GetProvider(providerId) if provider == nil { - c.ResponseError(fmt.Sprintf(c.T("chat:The provider: %s is not found"), providerId)) + c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The provider: %s is not found"), providerId)) return } if provider.Category != "AI" || provider.ClientSecret == "" { - c.ResponseError(fmt.Sprintf(c.T("chat:The provider: %s is invalid"), providerId)) + c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The provider: %s is invalid"), providerId)) return } @@ -119,18 +142,27 @@ func (c *ApiController) GetMessageAnswer() { c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive") authToken := provider.ClientSecret - question := message.Text + question := questionMessage.Text var stringBuilder strings.Builder err := ai.QueryAnswerStream(authToken, question, c.Ctx.ResponseWriter, &stringBuilder) + if err != nil { + c.ResponseErrorStream(err.Error()) + return + } + + event := fmt.Sprintf("event: end\ndata: %s\n\n", "end") + _, err = c.Ctx.ResponseWriter.Write([]byte(event)) if err != nil { panic(err) } answer := stringBuilder.String() - fmt.Printf("Question: [%s]\n", message.Text) + fmt.Printf("Question: [%s]\n", questionMessage.Text) fmt.Printf("Answer: [%s]\n", answer) + message.Text = answer + object.UpdateMessage(message.GetId(), message) } // UpdateMessage @@ -181,6 +213,7 @@ func (c *ApiController) AddMessage() { CreatedTime: util.GetCurrentTimeEx(message.CreatedTime), Organization: message.Organization, Chat: message.Chat, + ReplyTo: message.GetId(), Author: "AI", Text: "", } diff --git a/object/message.go b/object/message.go index 9ec3cc76..a5fb6b57 100644 --- a/object/message.go +++ b/object/message.go @@ -28,6 +28,7 @@ type Message struct { Organization string `xorm:"varchar(100)" json:"organization"` Chat string `xorm:"varchar(100) index" json:"chat"` + ReplyTo string `xorm:"varchar(100) index" json:"replyTo"` Author string `xorm:"varchar(100)" json:"author"` Text string `xorm:"mediumtext" json:"text"` } diff --git a/web/src/ChatBox.js b/web/src/ChatBox.js index 54337c3d..edf6e7e2 100644 --- a/web/src/ChatBox.js +++ b/web/src/ChatBox.js @@ -13,7 +13,7 @@ // limitations under the License. import React from "react"; -import {Avatar, Input, List, Spin} from "antd"; +import {Alert, Avatar, Input, List, Spin} from "antd"; import {CopyOutlined, DislikeOutlined, LikeOutlined, SendOutlined} from "@ant-design/icons"; import i18next from "i18next"; @@ -107,7 +107,15 @@ class ChatBox extends React.Component {
} - title={
{item.text}
} + title={ +
+ { + !item.text.includes("#ERROR#") ? item.text : ( + + ) + } +
+ } />
diff --git a/web/src/ChatPage.js b/web/src/ChatPage.js index fad63b01..71c643fd 100644 --- a/web/src/ChatPage.js +++ b/web/src/ChatPage.js @@ -56,6 +56,7 @@ class ChatPage extends BaseListPage { createdTime: moment().format(), organization: this.props.account.owner, chat: this.state.chatName, + replyTo: "", author: `${this.props.account.owner}/${this.props.account.name}`, text: text, }; @@ -83,6 +84,31 @@ class ChatPage extends BaseListPage { messages: messages, }); + if (messages.length > 0) { + const lastMessage = messages[messages.length - 1]; + if (lastMessage.author === "AI" && lastMessage.replyTo !== "" && lastMessage.text === "") { + let text = ""; + MessageBackend.getMessageAnswer(lastMessage.owner, lastMessage.name, (data) => { + const lastMessage2 = Setting.deepCopy(lastMessage); + text += data; + lastMessage2.text = text; + messages[messages.length - 1] = lastMessage2; + this.setState({ + messages: messages, + }); + }, (error) => { + Setting.showMessage("error", `${i18next.t("general:Failed to get answer")}: ${error}`); + + const lastMessage2 = Setting.deepCopy(lastMessage); + lastMessage2.text = `#ERROR#: ${error}`; + messages[messages.length - 1] = lastMessage2; + this.setState({ + messages: messages, + }); + }); + } + } + Setting.scrollToDiv(`chatbox-list-item-${messages.length}`); }); } diff --git a/web/src/MessageListPage.js b/web/src/MessageListPage.js index 4a482fe4..a6f4fd52 100644 --- a/web/src/MessageListPage.js +++ b/web/src/MessageListPage.js @@ -31,6 +31,7 @@ class MessageListPage extends BaseListPage { createdTime: moment().format(), organization: this.props.account.owner, chat: "", + replyTo: "", author: `${this.props.account.owner}/${this.props.account.name}`, text: "", }; diff --git a/web/src/backend/MessageBackend.js b/web/src/backend/MessageBackend.js index 2a0c8439..2e2183f7 100644 --- a/web/src/backend/MessageBackend.js +++ b/web/src/backend/MessageBackend.js @@ -45,13 +45,20 @@ export function getMessage(owner, name) { } export function getMessageAnswer(owner, name, onMessage, onError) { - const source = new EventSource(`${Setting.ServerUrl}/api/get-message-answer?id=${owner}/${encodeURIComponent(name)}`); - source.onmessage = function(event) { - onMessage(event.data); - }; - source.onerror = function(error) { - onError(error); - }; + const eventSource = new EventSource(`${Setting.ServerUrl}/api/get-message-answer?id=${owner}/${encodeURIComponent(name)}`); + + eventSource.addEventListener("message", (e) => { + onMessage(e.data); + }); + + eventSource.addEventListener("myerror", (e) => { + onError(e.data); + eventSource.close(); + }); + + eventSource.addEventListener("end", (e) => { + eventSource.close(); + }); } export function updateMessage(owner, name, message) {