Add /api/get-message-answer API

This commit is contained in:
Yang Luo 2023-05-01 23:15:51 +08:00
parent eea2e1d271
commit 2cd6f9df8e
7 changed files with 159 additions and 2 deletions

View File

@ -17,6 +17,7 @@ package ai
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"strings" "strings"
"time" "time"
@ -73,3 +74,43 @@ func QueryAnswerSafe(authToken string, question string) string {
return res return res
} }
func QueryAnswerStream(authToken string, question string, writer io.Writer, builder *strings.Builder) error {
client := getProxyClientFromToken(authToken)
ctx := context.Background()
respStream, err := client.CreateCompletionStream(
ctx,
openai.CompletionRequest{
Model: openai.GPT3TextDavinci003,
Prompt: question,
MaxTokens: 50,
Stream: true,
},
)
if err != nil {
return err
}
defer respStream.Close()
for {
completion, streamErr := respStream.Recv()
if streamErr != nil {
if streamErr == io.EOF {
break
}
return streamErr
}
// Write the streamed data as Server-Sent Events
if _, err := fmt.Fprintf(writer, "data: %s\n\n", completion.Choices[0].Text); err != nil {
return err
}
// Append the response to the strings.Builder
builder.WriteString(completion.Choices[0].Text)
}
return nil
}

View File

@ -16,8 +16,11 @@ package controllers
import ( import (
"encoding/json" "encoding/json"
"fmt"
"strings"
"github.com/beego/beego/utils/pagination" "github.com/beego/beego/utils/pagination"
"github.com/casdoor/casdoor/ai"
"github.com/casdoor/casdoor/object" "github.com/casdoor/casdoor/object"
"github.com/casdoor/casdoor/util" "github.com/casdoor/casdoor/util"
) )
@ -71,6 +74,65 @@ func (c *ApiController) GetMessage() {
c.ServeJSON() c.ServeJSON()
} }
// GetMessageAnswer
// @Title GetMessageAnswer
// @Tag Message API
// @Description get message answer
// @Param id query string true "The id ( owner/name ) of the message"
// @Success 200 {object} object.Message The Response object
// @router /get-message-answer [get]
func (c *ApiController) GetMessageAnswer() {
id := c.Input().Get("id")
message := object.GetMessage(id)
if message == nil {
c.ResponseError(fmt.Sprintf(c.T("chat:The message: %s is not found"), id))
return
}
chatId := util.GetId(message.Owner, message.Chat)
chat := object.GetChat(chatId)
if chat == nil {
c.ResponseError(fmt.Sprintf(c.T("chat:The chat: %s is not found"), chatId))
return
}
if chat.Type != "AI" {
c.ResponseError(c.T("chat:The chat type must be \"AI\""))
return
}
providerId := util.GetId(chat.Owner, chat.User2)
provider := object.GetProvider(providerId)
if provider == nil {
c.ResponseError(fmt.Sprintf(c.T("chat:The provider: %s is not found"), providerId))
return
}
if provider.Category != "AI" || provider.ClientSecret == "" {
c.ResponseError(fmt.Sprintf(c.T("chat:The provider: %s is invalid"), providerId))
return
}
c.Ctx.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
c.Ctx.ResponseWriter.Header().Set("Cache-Control", "no-cache")
c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive")
authToken := provider.ClientSecret
question := message.Text
var stringBuilder strings.Builder
err := ai.QueryAnswerStream(authToken, question, c.Ctx.ResponseWriter, &stringBuilder)
if err != nil {
panic(err)
}
answer := stringBuilder.String()
fmt.Printf("Question: [%s]\n", message.Text)
fmt.Printf("Answer: [%s]\n", answer)
}
// UpdateMessage // UpdateMessage
// @Title UpdateMessage // @Title UpdateMessage
// @Tag Message API // @Tag Message API
@ -108,7 +170,25 @@ func (c *ApiController) AddMessage() {
return return
} }
c.Data["json"] = wrapActionResponse(object.AddMessage(&message)) affected := object.AddMessage(&message)
if affected {
chatId := util.GetId(message.Owner, message.Chat)
chat := object.GetChat(chatId)
if chat != nil && chat.Type == "AI" {
answerMessage := &object.Message{
Owner: message.Owner,
Name: fmt.Sprintf("message_%s", util.GetRandomName()),
CreatedTime: util.GetCurrentTimeEx(message.CreatedTime),
Organization: message.Organization,
Chat: message.Chat,
Author: "AI",
Text: "",
}
object.AddMessage(answerMessage)
}
}
c.Data["json"] = wrapActionResponse(affected)
c.ServeJSON() c.ServeJSON()
} }

View File

@ -128,7 +128,7 @@ func (c *ApiController) GetProviderFromContext(category string) (*object.Provide
if providerName != "" { if providerName != "" {
provider := object.GetProvider(util.GetId("admin", providerName)) provider := object.GetProvider(util.GetId("admin", providerName))
if provider == nil { if provider == nil {
c.ResponseError(c.T("util:The provider: %s is not found"), providerName) c.ResponseError(fmt.Sprintf(c.T("util:The provider: %s is not found"), providerName))
return nil, nil, false return nil, nil, false
} }
return provider, nil, true return provider, nil, true

View File

@ -197,6 +197,7 @@ func initAPI() {
beego.Router("/api/get-messages", &controllers.ApiController{}, "GET:GetMessages") beego.Router("/api/get-messages", &controllers.ApiController{}, "GET:GetMessages")
beego.Router("/api/get-message", &controllers.ApiController{}, "GET:GetMessage") beego.Router("/api/get-message", &controllers.ApiController{}, "GET:GetMessage")
beego.Router("/api/get-message-answer", &controllers.ApiController{}, "GET:GetMessageAnswer")
beego.Router("/api/update-message", &controllers.ApiController{}, "POST:UpdateMessage") beego.Router("/api/update-message", &controllers.ApiController{}, "POST:UpdateMessage")
beego.Router("/api/add-message", &controllers.ApiController{}, "POST:AddMessage") beego.Router("/api/add-message", &controllers.ApiController{}, "POST:AddMessage")
beego.Router("/api/delete-message", &controllers.ApiController{}, "POST:DeleteMessage") beego.Router("/api/delete-message", &controllers.ApiController{}, "POST:DeleteMessage")

View File

@ -20,6 +20,7 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"math/rand"
"os" "os"
"strconv" "strconv"
"strings" "strings"
@ -141,6 +142,16 @@ func GenerateSimpleTimeId() string {
return t return t
} }
func GetRandomName() string {
rand.Seed(time.Now().UnixNano())
const charset = "0123456789abcdefghijklmnopqrstuvwxyz"
result := make([]byte, 6)
for i := range result {
result[i] = charset[rand.Intn(len(charset))]
}
return string(result)
}
func GetId(owner, name string) string { func GetId(owner, name string) string {
return fmt.Sprintf("%s/%s", owner, name) return fmt.Sprintf("%s/%s", owner, name)
} }

View File

@ -25,6 +25,20 @@ func GetCurrentTime() string {
return tm.Format(time.RFC3339) return tm.Format(time.RFC3339)
} }
func GetCurrentTimeEx(timestamp string) string {
tm := time.Now()
inputTime, err := time.Parse(time.RFC3339, timestamp)
if err != nil {
panic(err)
}
if !tm.After(inputTime) {
tm = inputTime.Add(1 * time.Millisecond)
}
return tm.Format("2006-01-02T15:04:05.999Z07:00")
}
func GetCurrentUnixTime() string { func GetCurrentUnixTime() string {
return strconv.FormatInt(time.Now().UnixNano(), 10) return strconv.FormatInt(time.Now().UnixNano(), 10)
} }

View File

@ -44,6 +44,16 @@ export function getMessage(owner, name) {
}).then(res => res.json()); }).then(res => res.json());
} }
export function getMessageAnswer(owner, name, onMessage, onError) {
const source = new EventSource(`${Setting.ServerUrl}/api/get-message-answer?id=${owner}/${encodeURIComponent(name)}`);
source.onmessage = function(event) {
onMessage(event.data);
};
source.onerror = function(error) {
onError(error);
};
}
export function updateMessage(owner, name, message) { export function updateMessage(owner, name, message) {
const newMessage = Setting.deepCopy(message); const newMessage = Setting.deepCopy(message);
return fetch(`${Setting.ServerUrl}/api/update-message?id=${owner}/${encodeURIComponent(name)}`, { return fetch(`${Setting.ServerUrl}/api/update-message?id=${owner}/${encodeURIComponent(name)}`, {