diff --git a/ai/ai.go b/ai/ai.go index 59828223..3db48c40 100644 --- a/ai/ai.go +++ b/ai/ai.go @@ -17,6 +17,7 @@ package ai import ( "context" "fmt" + "io" "strings" "time" @@ -73,3 +74,43 @@ func QueryAnswerSafe(authToken string, question string) string { return res } + +func QueryAnswerStream(authToken string, question string, writer io.Writer, builder *strings.Builder) error { + client := getProxyClientFromToken(authToken) + + ctx := context.Background() + + respStream, err := client.CreateCompletionStream( + ctx, + openai.CompletionRequest{ + Model: openai.GPT3TextDavinci003, + Prompt: question, + MaxTokens: 50, + Stream: true, + }, + ) + if err != nil { + return err + } + defer respStream.Close() + + for { + completion, streamErr := respStream.Recv() + if streamErr != nil { + if streamErr == io.EOF { + break + } + return streamErr + } + + // Write the streamed data as Server-Sent Events + if _, err := fmt.Fprintf(writer, "data: %s\n\n", completion.Choices[0].Text); err != nil { + return err + } + + // Append the response to the strings.Builder + builder.WriteString(completion.Choices[0].Text) + } + + return nil +} diff --git a/controllers/message.go b/controllers/message.go index 7adb7314..902c1715 100644 --- a/controllers/message.go +++ b/controllers/message.go @@ -16,8 +16,11 @@ package controllers import ( "encoding/json" + "fmt" + "strings" "github.com/beego/beego/utils/pagination" + "github.com/casdoor/casdoor/ai" "github.com/casdoor/casdoor/object" "github.com/casdoor/casdoor/util" ) @@ -71,6 +74,65 @@ func (c *ApiController) GetMessage() { c.ServeJSON() } +// GetMessageAnswer +// @Title GetMessageAnswer +// @Tag Message API +// @Description get message answer +// @Param id query string true "The id ( owner/name ) of the message" +// @Success 200 {object} object.Message The Response object +// @router /get-message-answer [get] +func (c *ApiController) GetMessageAnswer() { + id := c.Input().Get("id") + + message := object.GetMessage(id) + if message == nil { + c.ResponseError(fmt.Sprintf(c.T("chat:The message: %s is not found"), id)) + return + } + + chatId := util.GetId(message.Owner, message.Chat) + chat := object.GetChat(chatId) + if chat == nil { + c.ResponseError(fmt.Sprintf(c.T("chat:The chat: %s is not found"), chatId)) + return + } + + if chat.Type != "AI" { + c.ResponseError(c.T("chat:The chat type must be \"AI\"")) + return + } + + providerId := util.GetId(chat.Owner, chat.User2) + provider := object.GetProvider(providerId) + if provider == nil { + c.ResponseError(fmt.Sprintf(c.T("chat:The provider: %s is not found"), providerId)) + return + } + + if provider.Category != "AI" || provider.ClientSecret == "" { + c.ResponseError(fmt.Sprintf(c.T("chat:The provider: %s is invalid"), providerId)) + return + } + + c.Ctx.ResponseWriter.Header().Set("Content-Type", "text/event-stream") + c.Ctx.ResponseWriter.Header().Set("Cache-Control", "no-cache") + c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive") + + authToken := provider.ClientSecret + question := message.Text + var stringBuilder strings.Builder + err := ai.QueryAnswerStream(authToken, question, c.Ctx.ResponseWriter, &stringBuilder) + if err != nil { + panic(err) + } + + answer := stringBuilder.String() + + fmt.Printf("Question: [%s]\n", message.Text) + fmt.Printf("Answer: [%s]\n", answer) + +} + // UpdateMessage // @Title UpdateMessage // @Tag Message API @@ -108,7 +170,25 @@ func (c *ApiController) AddMessage() { return } - c.Data["json"] = wrapActionResponse(object.AddMessage(&message)) + affected := object.AddMessage(&message) + if affected { + chatId := util.GetId(message.Owner, message.Chat) + chat := object.GetChat(chatId) + if chat != nil && chat.Type == "AI" { + answerMessage := &object.Message{ + Owner: message.Owner, + Name: fmt.Sprintf("message_%s", util.GetRandomName()), + CreatedTime: util.GetCurrentTimeEx(message.CreatedTime), + Organization: message.Organization, + Chat: message.Chat, + Author: "AI", + Text: "", + } + object.AddMessage(answerMessage) + } + } + + c.Data["json"] = wrapActionResponse(affected) c.ServeJSON() } diff --git a/controllers/util.go b/controllers/util.go index f3cc8428..498b6b92 100644 --- a/controllers/util.go +++ b/controllers/util.go @@ -128,7 +128,7 @@ func (c *ApiController) GetProviderFromContext(category string) (*object.Provide if providerName != "" { provider := object.GetProvider(util.GetId("admin", providerName)) if provider == nil { - c.ResponseError(c.T("util:The provider: %s is not found"), providerName) + c.ResponseError(fmt.Sprintf(c.T("util:The provider: %s is not found"), providerName)) return nil, nil, false } return provider, nil, true diff --git a/routers/router.go b/routers/router.go index 4e0baf28..cc6c9e70 100644 --- a/routers/router.go +++ b/routers/router.go @@ -197,6 +197,7 @@ func initAPI() { beego.Router("/api/get-messages", &controllers.ApiController{}, "GET:GetMessages") beego.Router("/api/get-message", &controllers.ApiController{}, "GET:GetMessage") + beego.Router("/api/get-message-answer", &controllers.ApiController{}, "GET:GetMessageAnswer") beego.Router("/api/update-message", &controllers.ApiController{}, "POST:UpdateMessage") beego.Router("/api/add-message", &controllers.ApiController{}, "POST:AddMessage") beego.Router("/api/delete-message", &controllers.ApiController{}, "POST:DeleteMessage") diff --git a/util/string.go b/util/string.go index d5409f47..51e0487d 100644 --- a/util/string.go +++ b/util/string.go @@ -20,6 +20,7 @@ import ( "encoding/hex" "errors" "fmt" + "math/rand" "os" "strconv" "strings" @@ -141,6 +142,16 @@ func GenerateSimpleTimeId() string { return t } +func GetRandomName() string { + rand.Seed(time.Now().UnixNano()) + const charset = "0123456789abcdefghijklmnopqrstuvwxyz" + result := make([]byte, 6) + for i := range result { + result[i] = charset[rand.Intn(len(charset))] + } + return string(result) +} + func GetId(owner, name string) string { return fmt.Sprintf("%s/%s", owner, name) } diff --git a/util/time.go b/util/time.go index 7edcaa4c..e72f1762 100644 --- a/util/time.go +++ b/util/time.go @@ -25,6 +25,20 @@ func GetCurrentTime() string { return tm.Format(time.RFC3339) } +func GetCurrentTimeEx(timestamp string) string { + tm := time.Now() + inputTime, err := time.Parse(time.RFC3339, timestamp) + if err != nil { + panic(err) + } + + if !tm.After(inputTime) { + tm = inputTime.Add(1 * time.Millisecond) + } + + return tm.Format("2006-01-02T15:04:05.999Z07:00") +} + func GetCurrentUnixTime() string { return strconv.FormatInt(time.Now().UnixNano(), 10) } diff --git a/web/src/backend/MessageBackend.js b/web/src/backend/MessageBackend.js index e46e5c75..2a0c8439 100644 --- a/web/src/backend/MessageBackend.js +++ b/web/src/backend/MessageBackend.js @@ -44,6 +44,16 @@ export function getMessage(owner, name) { }).then(res => res.json()); } +export function getMessageAnswer(owner, name, onMessage, onError) { + const source = new EventSource(`${Setting.ServerUrl}/api/get-message-answer?id=${owner}/${encodeURIComponent(name)}`); + source.onmessage = function(event) { + onMessage(event.data); + }; + source.onerror = function(error) { + onError(error); + }; +} + export function updateMessage(owner, name, message) { const newMessage = Setting.deepCopy(message); return fetch(`${Setting.ServerUrl}/api/update-message?id=${owner}/${encodeURIComponent(name)}`, {