Add /api/get-message-answer API

This commit is contained in:
Yang Luo
2023-05-01 23:15:51 +08:00
parent eea2e1d271
commit 2cd6f9df8e
7 changed files with 159 additions and 2 deletions

View File

@ -16,8 +16,11 @@ package controllers
import (
"encoding/json"
"fmt"
"strings"
"github.com/beego/beego/utils/pagination"
"github.com/casdoor/casdoor/ai"
"github.com/casdoor/casdoor/object"
"github.com/casdoor/casdoor/util"
)
@ -71,6 +74,65 @@ func (c *ApiController) GetMessage() {
c.ServeJSON()
}
// GetMessageAnswer
// @Title GetMessageAnswer
// @Tag Message API
// @Description get message answer
// @Param id query string true "The id ( owner/name ) of the message"
// @Success 200 {object} object.Message The Response object
// @router /get-message-answer [get]
func (c *ApiController) GetMessageAnswer() {
id := c.Input().Get("id")
message := object.GetMessage(id)
if message == nil {
c.ResponseError(fmt.Sprintf(c.T("chat:The message: %s is not found"), id))
return
}
chatId := util.GetId(message.Owner, message.Chat)
chat := object.GetChat(chatId)
if chat == nil {
c.ResponseError(fmt.Sprintf(c.T("chat:The chat: %s is not found"), chatId))
return
}
if chat.Type != "AI" {
c.ResponseError(c.T("chat:The chat type must be \"AI\""))
return
}
providerId := util.GetId(chat.Owner, chat.User2)
provider := object.GetProvider(providerId)
if provider == nil {
c.ResponseError(fmt.Sprintf(c.T("chat:The provider: %s is not found"), providerId))
return
}
if provider.Category != "AI" || provider.ClientSecret == "" {
c.ResponseError(fmt.Sprintf(c.T("chat:The provider: %s is invalid"), providerId))
return
}
c.Ctx.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
c.Ctx.ResponseWriter.Header().Set("Cache-Control", "no-cache")
c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive")
authToken := provider.ClientSecret
question := message.Text
var stringBuilder strings.Builder
err := ai.QueryAnswerStream(authToken, question, c.Ctx.ResponseWriter, &stringBuilder)
if err != nil {
panic(err)
}
answer := stringBuilder.String()
fmt.Printf("Question: [%s]\n", message.Text)
fmt.Printf("Answer: [%s]\n", answer)
}
// UpdateMessage
// @Title UpdateMessage
// @Tag Message API
@ -108,7 +170,25 @@ func (c *ApiController) AddMessage() {
return
}
c.Data["json"] = wrapActionResponse(object.AddMessage(&message))
affected := object.AddMessage(&message)
if affected {
chatId := util.GetId(message.Owner, message.Chat)
chat := object.GetChat(chatId)
if chat != nil && chat.Type == "AI" {
answerMessage := &object.Message{
Owner: message.Owner,
Name: fmt.Sprintf("message_%s", util.GetRandomName()),
CreatedTime: util.GetCurrentTimeEx(message.CreatedTime),
Organization: message.Organization,
Chat: message.Chat,
Author: "AI",
Text: "",
}
object.AddMessage(answerMessage)
}
}
c.Data["json"] = wrapActionResponse(affected)
c.ServeJSON()
}