Handle message answer

This commit is contained in:
Yang Luo
2023-05-02 01:30:06 +08:00
parent 2cd6f9df8e
commit 84a7fdcd07
7 changed files with 93 additions and 17 deletions

View File

@ -85,7 +85,7 @@ func QueryAnswerStream(authToken string, question string, writer io.Writer, buil
openai.CompletionRequest{ openai.CompletionRequest{
Model: openai.GPT3TextDavinci003, Model: openai.GPT3TextDavinci003,
Prompt: question, Prompt: question,
MaxTokens: 50, MaxTokens: 3000,
Stream: true, Stream: true,
}, },
) )

View File

@ -74,6 +74,14 @@ func (c *ApiController) GetMessage() {
c.ServeJSON() c.ServeJSON()
} }
func (c *ApiController) ResponseErrorStream(errorText string) {
event := fmt.Sprintf("event: myerror\ndata: %s\n\n", errorText)
_, err := c.Ctx.ResponseWriter.Write([]byte(event))
if err != nil {
panic(err)
}
}
// GetMessageAnswer // GetMessageAnswer
// @Title GetMessageAnswer // @Title GetMessageAnswer
// @Tag Message API // @Tag Message API
@ -84,33 +92,48 @@ func (c *ApiController) GetMessage() {
func (c *ApiController) GetMessageAnswer() { func (c *ApiController) GetMessageAnswer() {
id := c.Input().Get("id") id := c.Input().Get("id")
c.Ctx.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
c.Ctx.ResponseWriter.Header().Set("Cache-Control", "no-cache")
c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive")
message := object.GetMessage(id) message := object.GetMessage(id)
if message == nil { if message == nil {
c.ResponseError(fmt.Sprintf(c.T("chat:The message: %s is not found"), id)) c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The message: %s is not found"), id))
return
}
if message.Author != "AI" || message.ReplyTo == "" || message.Text != "" {
c.ResponseErrorStream(c.T("chat:The message is invalid"))
return return
} }
chatId := util.GetId(message.Owner, message.Chat) chatId := util.GetId(message.Owner, message.Chat)
chat := object.GetChat(chatId) chat := object.GetChat(chatId)
if chat == nil { if chat == nil {
c.ResponseError(fmt.Sprintf(c.T("chat:The chat: %s is not found"), chatId)) c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The chat: %s is not found"), chatId))
return return
} }
if chat.Type != "AI" { if chat.Type != "AI" {
c.ResponseError(c.T("chat:The chat type must be \"AI\"")) c.ResponseErrorStream(c.T("chat:The chat type must be \"AI\""))
return
}
questionMessage := object.GetMessage(message.ReplyTo)
if questionMessage == nil {
c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The message: %s is not found"), id))
return return
} }
providerId := util.GetId(chat.Owner, chat.User2) providerId := util.GetId(chat.Owner, chat.User2)
provider := object.GetProvider(providerId) provider := object.GetProvider(providerId)
if provider == nil { if provider == nil {
c.ResponseError(fmt.Sprintf(c.T("chat:The provider: %s is not found"), providerId)) c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The provider: %s is not found"), providerId))
return return
} }
if provider.Category != "AI" || provider.ClientSecret == "" { if provider.Category != "AI" || provider.ClientSecret == "" {
c.ResponseError(fmt.Sprintf(c.T("chat:The provider: %s is invalid"), providerId)) c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The provider: %s is invalid"), providerId))
return return
} }
@ -119,18 +142,27 @@ func (c *ApiController) GetMessageAnswer() {
c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive") c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive")
authToken := provider.ClientSecret authToken := provider.ClientSecret
question := message.Text question := questionMessage.Text
var stringBuilder strings.Builder var stringBuilder strings.Builder
err := ai.QueryAnswerStream(authToken, question, c.Ctx.ResponseWriter, &stringBuilder) err := ai.QueryAnswerStream(authToken, question, c.Ctx.ResponseWriter, &stringBuilder)
if err != nil {
c.ResponseErrorStream(err.Error())
return
}
event := fmt.Sprintf("event: end\ndata: %s\n\n", "end")
_, err = c.Ctx.ResponseWriter.Write([]byte(event))
if err != nil { if err != nil {
panic(err) panic(err)
} }
answer := stringBuilder.String() answer := stringBuilder.String()
fmt.Printf("Question: [%s]\n", message.Text) fmt.Printf("Question: [%s]\n", questionMessage.Text)
fmt.Printf("Answer: [%s]\n", answer) fmt.Printf("Answer: [%s]\n", answer)
message.Text = answer
object.UpdateMessage(message.GetId(), message)
} }
// UpdateMessage // UpdateMessage
@ -181,6 +213,7 @@ func (c *ApiController) AddMessage() {
CreatedTime: util.GetCurrentTimeEx(message.CreatedTime), CreatedTime: util.GetCurrentTimeEx(message.CreatedTime),
Organization: message.Organization, Organization: message.Organization,
Chat: message.Chat, Chat: message.Chat,
ReplyTo: message.GetId(),
Author: "AI", Author: "AI",
Text: "", Text: "",
} }

View File

@ -28,6 +28,7 @@ type Message struct {
Organization string `xorm:"varchar(100)" json:"organization"` Organization string `xorm:"varchar(100)" json:"organization"`
Chat string `xorm:"varchar(100) index" json:"chat"` Chat string `xorm:"varchar(100) index" json:"chat"`
ReplyTo string `xorm:"varchar(100) index" json:"replyTo"`
Author string `xorm:"varchar(100)" json:"author"` Author string `xorm:"varchar(100)" json:"author"`
Text string `xorm:"mediumtext" json:"text"` Text string `xorm:"mediumtext" json:"text"`
} }

View File

@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
import React from "react"; import React from "react";
import {Avatar, Input, List, Spin} from "antd"; import {Alert, Avatar, Input, List, Spin} from "antd";
import {CopyOutlined, DislikeOutlined, LikeOutlined, SendOutlined} from "@ant-design/icons"; import {CopyOutlined, DislikeOutlined, LikeOutlined, SendOutlined} from "@ant-design/icons";
import i18next from "i18next"; import i18next from "i18next";
@ -107,7 +107,15 @@ class ChatBox extends React.Component {
<div style={{width: "800px", margin: "0 auto", position: "relative"}}> <div style={{width: "800px", margin: "0 auto", position: "relative"}}>
<List.Item.Meta <List.Item.Meta
avatar={<Avatar style={{width: "30px", height: "30px", borderRadius: "3px"}} src={item.author === `${this.props.account.owner}/${this.props.account.name}` ? this.props.account.avatar : "https://cdn.casbin.com/casdoor/resource/built-in/admin/gpt.png"} />} avatar={<Avatar style={{width: "30px", height: "30px", borderRadius: "3px"}} src={item.author === `${this.props.account.owner}/${this.props.account.name}` ? this.props.account.avatar : "https://cdn.casbin.com/casdoor/resource/built-in/admin/gpt.png"} />}
title={<div style={{fontSize: "16px", fontWeight: "normal", lineHeight: "24px", marginTop: "-15px", marginLeft: "5px", marginRight: "80px"}}>{item.text}</div>} title={
<div style={{fontSize: "16px", fontWeight: "normal", lineHeight: "24px", marginTop: "-15px", marginLeft: "5px", marginRight: "80px"}}>
{
!item.text.includes("#ERROR#") ? item.text : (
<Alert message={item.text.slice("#ERROR#: ".length)} type="error" showIcon />
)
}
</div>
}
/> />
<div style={{position: "absolute", top: "0px", right: "0px"}} <div style={{position: "absolute", top: "0px", right: "0px"}}
> >

View File

@ -56,6 +56,7 @@ class ChatPage extends BaseListPage {
createdTime: moment().format(), createdTime: moment().format(),
organization: this.props.account.owner, organization: this.props.account.owner,
chat: this.state.chatName, chat: this.state.chatName,
replyTo: "",
author: `${this.props.account.owner}/${this.props.account.name}`, author: `${this.props.account.owner}/${this.props.account.name}`,
text: text, text: text,
}; };
@ -83,6 +84,31 @@ class ChatPage extends BaseListPage {
messages: messages, messages: messages,
}); });
if (messages.length > 0) {
const lastMessage = messages[messages.length - 1];
if (lastMessage.author === "AI" && lastMessage.replyTo !== "" && lastMessage.text === "") {
let text = "";
MessageBackend.getMessageAnswer(lastMessage.owner, lastMessage.name, (data) => {
const lastMessage2 = Setting.deepCopy(lastMessage);
text += data;
lastMessage2.text = text;
messages[messages.length - 1] = lastMessage2;
this.setState({
messages: messages,
});
}, (error) => {
Setting.showMessage("error", `${i18next.t("general:Failed to get answer")}: ${error}`);
const lastMessage2 = Setting.deepCopy(lastMessage);
lastMessage2.text = `#ERROR#: ${error}`;
messages[messages.length - 1] = lastMessage2;
this.setState({
messages: messages,
});
});
}
}
Setting.scrollToDiv(`chatbox-list-item-${messages.length}`); Setting.scrollToDiv(`chatbox-list-item-${messages.length}`);
}); });
} }

View File

@ -31,6 +31,7 @@ class MessageListPage extends BaseListPage {
createdTime: moment().format(), createdTime: moment().format(),
organization: this.props.account.owner, organization: this.props.account.owner,
chat: "", chat: "",
replyTo: "",
author: `${this.props.account.owner}/${this.props.account.name}`, author: `${this.props.account.owner}/${this.props.account.name}`,
text: "", text: "",
}; };

View File

@ -45,13 +45,20 @@ export function getMessage(owner, name) {
} }
export function getMessageAnswer(owner, name, onMessage, onError) { export function getMessageAnswer(owner, name, onMessage, onError) {
const source = new EventSource(`${Setting.ServerUrl}/api/get-message-answer?id=${owner}/${encodeURIComponent(name)}`); const eventSource = new EventSource(`${Setting.ServerUrl}/api/get-message-answer?id=${owner}/${encodeURIComponent(name)}`);
source.onmessage = function(event) {
onMessage(event.data); eventSource.addEventListener("message", (e) => {
}; onMessage(e.data);
source.onerror = function(error) { });
onError(error);
}; eventSource.addEventListener("myerror", (e) => {
onError(e.data);
eventSource.close();
});
eventSource.addEventListener("end", (e) => {
eventSource.close();
});
} }
export function updateMessage(owner, name, message) { export function updateMessage(owner, name, message) {