diff --git a/ai/ai.go b/ai/ai.go index 2b1a046f..fa97d4da 100644 --- a/ai/ai.go +++ b/ai/ai.go @@ -80,12 +80,22 @@ func QueryAnswerStream(authToken string, question string, writer io.Writer, buil ctx := context.Background() + // https://platform.openai.com/tokenizer + // https://github.com/pkoukk/tiktoken-go#available-encodings + promptTokens, err := getTokenSize(openai.GPT3TextDavinci003, question) + if err != nil { + return err + } + + // https://platform.openai.com/docs/models/gpt-3-5 + maxTokens := 4097 - promptTokens + respStream, err := client.CreateCompletionStream( ctx, openai.CompletionRequest{ Model: openai.GPT3TextDavinci003, Prompt: question, - MaxTokens: 3000, + MaxTokens: maxTokens, Stream: true, }, ) diff --git a/ai/ai_test.go b/ai/ai_test.go index 063bfd62..b4a094cb 100644 --- a/ai/ai_test.go +++ b/ai/ai_test.go @@ -19,6 +19,7 @@ import ( "github.com/casdoor/casdoor/object" "github.com/casdoor/casdoor/proxy" + "github.com/sashabaranov/go-openai" ) func TestRun(t *testing.T) { @@ -32,3 +33,7 @@ func TestRun(t *testing.T) { println(text) } + +func TestToken(t *testing.T) { + println(getTokenSize(openai.GPT3TextDavinci003, "")) +} diff --git a/ai/util.go b/ai/util.go new file mode 100644 index 00000000..e0f1085f --- /dev/null +++ b/ai/util.go @@ -0,0 +1,28 @@ +// Copyright 2023 The Casdoor Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ai + +import "github.com/pkoukk/tiktoken-go" + +func getTokenSize(model string, prompt string) (int, error) { + tkm, err := tiktoken.EncodingForModel(model) + if err != nil { + return 0, err + } + + token := tkm.Encode(prompt, nil, nil) + res := len(token) + return res, nil +} diff --git a/go.mod b/go.mod index d2a13d44..9046d405 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/casdoor/xorm-adapter/v3 v3.0.4 github.com/dchest/captcha v0.0.0-20200903113550-03f5f0333e1f github.com/denisenkom/go-mssqldb v0.9.0 + github.com/dlclark/regexp2 v1.9.0 // indirect github.com/fogleman/gg v1.3.0 github.com/forestmgy/ldapserver v1.1.0 github.com/go-git/go-git/v5 v5.6.0 @@ -36,6 +37,7 @@ require ( github.com/markbates/goth v1.75.2 github.com/nu7hatch/gouuid v0.0.0-20131221200532-179d4d0c4d8d // indirect github.com/nyaruka/phonenumbers v1.1.5 + github.com/pkoukk/tiktoken-go v0.1.1 github.com/prometheus/client_golang v1.7.0 github.com/prometheus/client_model v0.2.0 github.com/qiangmzsx/string-adapter/v2 v2.1.0 @@ -47,7 +49,7 @@ require ( github.com/shirou/gopsutil v3.21.11+incompatible github.com/siddontang/go-log v0.0.0-20190221022429-1e957dd83bed github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e - github.com/stretchr/testify v1.8.1 + github.com/stretchr/testify v1.8.2 github.com/tealeg/xlsx v1.0.5 github.com/thanhpk/randstr v1.0.4 github.com/tklauser/go-sysconf v0.3.10 // indirect diff --git a/go.sum b/go.sum index dd28f058..3b492866 100644 --- a/go.sum +++ b/go.sum @@ -161,6 +161,9 @@ github.com/denisenkom/go-mssqldb v0.9.0 h1:RSohk2RsiZqLZ0zCjtfn3S4Gp4exhpBWHyQ7D github.com/denisenkom/go-mssqldb v0.9.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= +github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dlclark/regexp2 v1.9.0 h1:pTK/l/3qYIKaRXuHnEnIf7Y5NxfRPfpb7dis6/gdlVI= +github.com/dlclark/regexp2 v1.9.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M= @@ -476,6 +479,8 @@ github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkoukk/tiktoken-go v0.1.1 h1:jtkYlIECjyM9OW1w4rjPmTohK4arORP9V25y6TM6nXo= +github.com/pkoukk/tiktoken-go v0.1.1/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= @@ -573,8 +578,9 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/syndtr/goleveldb v0.0.0-20160425020131-cfa635847112/go.mod h1:Z4AUp2Km+PwemOoO/VB5AOx9XSsIItzFjoJlOSiYmn0= github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE= github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ=