diff --git a/ai/ai.go b/ai/ai.go index 3db48c40..2b1a046f 100644 --- a/ai/ai.go +++ b/ai/ai.go @@ -85,7 +85,7 @@ func QueryAnswerStream(authToken string, question string, writer io.Writer, buil openai.CompletionRequest{ Model: openai.GPT3TextDavinci003, Prompt: question, - MaxTokens: 50, + MaxTokens: 3000, Stream: true, }, ) diff --git a/controllers/message.go b/controllers/message.go index 902c1715..6766df67 100644 --- a/controllers/message.go +++ b/controllers/message.go @@ -74,6 +74,14 @@ func (c *ApiController) GetMessage() { c.ServeJSON() } +func (c *ApiController) ResponseErrorStream(errorText string) { + event := fmt.Sprintf("event: myerror\ndata: %s\n\n", errorText) + _, err := c.Ctx.ResponseWriter.Write([]byte(event)) + if err != nil { + panic(err) + } +} + // GetMessageAnswer // @Title GetMessageAnswer // @Tag Message API @@ -84,33 +92,48 @@ func (c *ApiController) GetMessage() { func (c *ApiController) GetMessageAnswer() { id := c.Input().Get("id") + c.Ctx.ResponseWriter.Header().Set("Content-Type", "text/event-stream") + c.Ctx.ResponseWriter.Header().Set("Cache-Control", "no-cache") + c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive") + message := object.GetMessage(id) if message == nil { - c.ResponseError(fmt.Sprintf(c.T("chat:The message: %s is not found"), id)) + c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The message: %s is not found"), id)) + return + } + + if message.Author != "AI" || message.ReplyTo == "" || message.Text != "" { + c.ResponseErrorStream(c.T("chat:The message is invalid")) return } chatId := util.GetId(message.Owner, message.Chat) chat := object.GetChat(chatId) if chat == nil { - c.ResponseError(fmt.Sprintf(c.T("chat:The chat: %s is not found"), chatId)) + c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The chat: %s is not found"), chatId)) return } if chat.Type != "AI" { - c.ResponseError(c.T("chat:The chat type must be \"AI\"")) + c.ResponseErrorStream(c.T("chat:The chat type must be \"AI\"")) + return + } + + questionMessage := object.GetMessage(message.ReplyTo) + if questionMessage == nil { + c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The message: %s is not found"), id)) return } providerId := util.GetId(chat.Owner, chat.User2) provider := object.GetProvider(providerId) if provider == nil { - c.ResponseError(fmt.Sprintf(c.T("chat:The provider: %s is not found"), providerId)) + c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The provider: %s is not found"), providerId)) return } if provider.Category != "AI" || provider.ClientSecret == "" { - c.ResponseError(fmt.Sprintf(c.T("chat:The provider: %s is invalid"), providerId)) + c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The provider: %s is invalid"), providerId)) return } @@ -119,18 +142,27 @@ func (c *ApiController) GetMessageAnswer() { c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive") authToken := provider.ClientSecret - question := message.Text + question := questionMessage.Text var stringBuilder strings.Builder err := ai.QueryAnswerStream(authToken, question, c.Ctx.ResponseWriter, &stringBuilder) + if err != nil { + c.ResponseErrorStream(err.Error()) + return + } + + event := fmt.Sprintf("event: end\ndata: %s\n\n", "end") + _, err = c.Ctx.ResponseWriter.Write([]byte(event)) if err != nil { panic(err) } answer := stringBuilder.String() - fmt.Printf("Question: [%s]\n", message.Text) + fmt.Printf("Question: [%s]\n", questionMessage.Text) fmt.Printf("Answer: [%s]\n", answer) + message.Text = answer + object.UpdateMessage(message.GetId(), message) } // UpdateMessage @@ -181,6 +213,7 @@ func (c *ApiController) AddMessage() { CreatedTime: util.GetCurrentTimeEx(message.CreatedTime), Organization: message.Organization, Chat: message.Chat, + ReplyTo: message.GetId(), Author: "AI", Text: "", } diff --git a/object/message.go b/object/message.go index 9ec3cc76..a5fb6b57 100644 --- a/object/message.go +++ b/object/message.go @@ -28,6 +28,7 @@ type Message struct { Organization string `xorm:"varchar(100)" json:"organization"` Chat string `xorm:"varchar(100) index" json:"chat"` + ReplyTo string `xorm:"varchar(100) index" json:"replyTo"` Author string `xorm:"varchar(100)" json:"author"` Text string `xorm:"mediumtext" json:"text"` } diff --git a/web/src/ChatBox.js b/web/src/ChatBox.js index 54337c3d..edf6e7e2 100644 --- a/web/src/ChatBox.js +++ b/web/src/ChatBox.js @@ -13,7 +13,7 @@ // limitations under the License. import React from "react"; -import {Avatar, Input, List, Spin} from "antd"; +import {Alert, Avatar, Input, List, Spin} from "antd"; import {CopyOutlined, DislikeOutlined, LikeOutlined, SendOutlined} from "@ant-design/icons"; import i18next from "i18next"; @@ -107,7 +107,15 @@ class ChatBox extends React.Component {