Add getTokenSize()

This commit is contained in:
Yang Luo
2023-05-02 10:04:11 +08:00
parent d699774179
commit cb13d693e6
5 changed files with 54 additions and 3 deletions

View File

@ -80,12 +80,22 @@ func QueryAnswerStream(authToken string, question string, writer io.Writer, buil
ctx := context.Background()
// https://platform.openai.com/tokenizer
// https://github.com/pkoukk/tiktoken-go#available-encodings
promptTokens, err := getTokenSize(openai.GPT3TextDavinci003, question)
if err != nil {
return err
}
// https://platform.openai.com/docs/models/gpt-3-5
maxTokens := 4097 - promptTokens
respStream, err := client.CreateCompletionStream(
ctx,
openai.CompletionRequest{
Model: openai.GPT3TextDavinci003,
Prompt: question,
MaxTokens: 3000,
MaxTokens: maxTokens,
Stream: true,
},
)

View File

@ -19,6 +19,7 @@ import (
"github.com/casdoor/casdoor/object"
"github.com/casdoor/casdoor/proxy"
"github.com/sashabaranov/go-openai"
)
func TestRun(t *testing.T) {
@ -32,3 +33,7 @@ func TestRun(t *testing.T) {
println(text)
}
func TestToken(t *testing.T) {
println(getTokenSize(openai.GPT3TextDavinci003, ""))
}

28
ai/util.go Normal file
View File

@ -0,0 +1,28 @@
// Copyright 2023 The Casdoor Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ai
import "github.com/pkoukk/tiktoken-go"
func getTokenSize(model string, prompt string) (int, error) {
tkm, err := tiktoken.EncodingForModel(model)
if err != nil {
return 0, err
}
token := tkm.Encode(prompt, nil, nil)
res := len(token)
return res, nil
}