Skip to content

Commit

Permalink
换用廉价模型 (#271)
Browse files Browse the repository at this point in the history
  • Loading branch information
flycash committed Sep 22, 2024
1 parent b629ce3 commit f5225a1
Show file tree
Hide file tree
Showing 15 changed files with 241 additions and 56 deletions.
5 changes: 2 additions & 3 deletions internal/ai/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,14 @@ func InitHandlerFacade(common []handler.Builder,

func InitZhipu() *zhipu.Handler {
type Config struct {
APIKey string `yaml:"apikey"`
Price float64 `yaml:"price"`
APIKey string `yaml:"apikey"`
}
var cfg Config
err := econf.UnmarshalKey("zhipu", &cfg)
if err != nil {
panic(err)
}
h, err := zhipu.NewHandler(cfg.APIKey, cfg.Price)
h, err := zhipu.NewHandler(cfg.APIKey)
if err != nil {
panic(err)
}
Expand Down
10 changes: 10 additions & 0 deletions internal/ai/internal/domain/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ type LLMResponse struct {
}

type BizConfig struct {
// 使用的模型
Model string
// 多少分钱/1000 token
Price int64

Temperature float64
TopP float64

// 系统 Prompt
SystemPrompt string
// 允许的最长输入
// 这里我们不用计算 token,只需要简单约束一下字符串长度就可以
MaxInput int
Expand Down
23 changes: 16 additions & 7 deletions internal/ai/internal/integration/llm_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (s *LLMServiceSuite) SetupSuite() {
err = s.db.Create(&dao.BizConfig{
Biz: domain.BizQuestionExamine,
MaxInput: 100,
PromptTemplate: "这是问题 %s,这是用户输入 %s",
PromptTemplate: "这是问题 %s,这是问题内容 %s,这是用户输入 %s",
KnowledgeId: knowledgeId,
Ctime: now,
Utime: now,
Expand All @@ -60,7 +60,7 @@ func (s *LLMServiceSuite) SetupSuite() {
err = s.db.Create(&dao.BizConfig{
Biz: domain.BizCaseExamine,
MaxInput: 100,
PromptTemplate: "这是案例 %s,这是用户输入 %s",
PromptTemplate: "这是案例 %s,这是案例内容 %s,这是用户输入 %s",
KnowledgeId: knowledgeId,
Ctime: now,
Utime: now,
Expand Down Expand Up @@ -97,6 +97,7 @@ func (s *LLMServiceSuite) TestService() {
Tid: "11",
Input: []string{
"问题1",
"问题1内容",
"用户输入1",
},
},
Expand Down Expand Up @@ -142,11 +143,12 @@ func (s *LLMServiceSuite) TestService() {
Valid: true,
Val: []string{
"问题1",
"问题1内容",
"用户输入1",
},
},
Status: 1,
PromptTemplate: sqlx.NewNullString("这是问题 %s,这是用户输入 %s"),
PromptTemplate: sqlx.NewNullString("这是问题 %s,这是问题内容 %s,这是用户输入 %s"),
Answer: sqlx.NewNullString("aians"),
}, logModel)
// 校验credit写入的内容是否正确
Expand All @@ -171,6 +173,7 @@ func (s *LLMServiceSuite) TestService() {
Tid: "13",
Input: []string{
"案例1",
"案例1内容",
"用户输入1",
},
},
Expand Down Expand Up @@ -217,11 +220,12 @@ func (s *LLMServiceSuite) TestService() {
Valid: true,
Val: []string{
"案例1",
"案例1内容",
"用户输入1",
},
},
Status: 1,
PromptTemplate: sqlx.NewNullString("这是案例 %s,这是用户输入 %s"),
PromptTemplate: sqlx.NewNullString("这是案例 %s,这是案例内容 %s,这是用户输入 %s"),
Answer: sqlx.NewNullString("aians"),
}, logModel)
// 校验credit写入的内容是否正确
Expand Down Expand Up @@ -274,11 +278,12 @@ func (s *LLMServiceSuite) TestService() {
Valid: true,
Val: []string{
"问题1",
"问题1内容",
"用户输入1",
},
},
Status: domain.RecordStatusFailed.ToUint8(),
PromptTemplate: sqlx.NewNullString("这是问题 %s,这是用户输入 %s"),
PromptTemplate: sqlx.NewNullString("这是问题 %s,这是问题内容 %s,这是用户输入 %s"),
}, logModel)
},
assertFunc: assert.Error,
Expand All @@ -291,6 +296,7 @@ func (s *LLMServiceSuite) TestService() {
Tid: "11",
Input: []string{
"问题1",
"问题1内容",
"用户输入1",
},
},
Expand Down Expand Up @@ -323,11 +329,12 @@ func (s *LLMServiceSuite) TestService() {
Valid: true,
Val: []string{
"问题1",
"问题1内容",
"用户输入1",
},
},
Status: domain.CreditStatusFailed.ToUint8(),
PromptTemplate: sqlx.NewNullString("这是问题 %s,这是用户输入 %s"),
PromptTemplate: sqlx.NewNullString("这是问题 %s,这是问题内容 %s,这是用户输入 %s"),
Answer: sqlx.NewNullString("aians"),
}, logModel)
// 校验credit写入的内容是否正确
Expand All @@ -353,6 +360,7 @@ func (s *LLMServiceSuite) TestService() {
Tid: "11",
Input: []string{
"问题1",
"问题1内容",
"用户输入1",
},
},
Expand Down Expand Up @@ -412,11 +420,12 @@ func (s *LLMServiceSuite) TestService() {
Valid: true,
Val: []string{
"问题1",
"问题1内容",
"用户输入1",
},
},
Status: domain.RecordStatusFailed.ToUint8(),
PromptTemplate: sqlx.NewNullString("这是问题 %s,这是用户输入 %s"),
PromptTemplate: sqlx.NewNullString("这是问题 %s,这是问题内容 %s,这是用户输入 %s"),
Answer: sqlx.NewNullString("aians"),
}, logModel)
// 校验credit写入的内容是否正确
Expand Down
5 changes: 5 additions & 0 deletions internal/ai/internal/repository/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ func (repo *CachedConfigRepository) GetConfig(ctx context.Context, biz string) (
return domain.BizConfig{}, err
}
return domain.BizConfig{
Model: res.Model,
Price: res.Price,
Temperature: res.Temperature,
TopP: res.TopP,
SystemPrompt: res.SystemPrompt,
MaxInput: res.MaxInput,
PromptTemplate: res.PromptTemplate,
KnowledgeId: res.KnowledgeId,
Expand Down
12 changes: 9 additions & 3 deletions internal/ai/internal/repository/dao/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,15 @@ func (dao *GORMConfigDAO) GetConfig(ctx context.Context, biz string) (BizConfig,
}

type BizConfig struct {
Id int64 `gorm:"primaryKey;autoIncrement;comment:AI biz 配置表ID"`
Biz string `gorm:"type:varchar(256);uniqueIndex;not null;comment:业务类型名"`
MaxInput int `gorm:"comment:最大输入长度"`
Id int64 `gorm:"primaryKey;autoIncrement;comment:AI biz 配置表ID"`
Biz string `gorm:"type:varchar(256);uniqueIndex;not null;comment:业务类型名"`
MaxInput int `gorm:"comment:最大输入长度"`
Model string `gorm:"type:varchar(256)"`
Price int64
Temperature float64
TopP float64
// 系统 prompt
SystemPrompt string
PromptTemplate string
KnowledgeId string `gorm:"type:varchar(256);not null;comment:使用的知识库 ID"`
// 其它字段按需添加
Expand Down
5 changes: 3 additions & 2 deletions internal/ai/internal/service/llm/handler/biz/case_examine.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@ func NewCaseExamineBizHandlerBuilder() *CaseExamineBizHandlerBuilder {
func (h *CaseExamineBizHandlerBuilder) Next(next handler.Handler) handler.Handler {
return handler.HandleFunc(func(ctx context.Context, req domain.LLMRequest) (domain.LLMResponse, error) {
title := req.Input[0]
userInput := req.Input[1]
refCase := req.Input[1]
userInput := req.Input[2]
userInputLen := utf8.RuneCount([]byte(userInput))

if userInputLen > req.Config.MaxInput {
return domain.LLMResponse{}, fmt.Errorf("输入太长,最常不超过 %d,现有长度 %d", req.Config.MaxInput, userInputLen)
}
// 把 input 和 prompt 结合起来
prompt := fmt.Sprintf(req.Config.PromptTemplate, title, userInput)
prompt := fmt.Sprintf(req.Config.PromptTemplate, title, refCase, userInput)
req.Prompt = prompt
return next.Handle(ctx, req)
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,15 @@ func NewQuestionExamineBizHandlerBuilder() *QuestionExamineBizHandlerBuilder {
func (h *QuestionExamineBizHandlerBuilder) Next(next handler.Handler) handler.Handler {
return handler.HandleFunc(func(ctx context.Context, req domain.LLMRequest) (domain.LLMResponse, error) {
title := req.Input[0]
userInput := req.Input[1]
answer := req.Input[1]
userInput := req.Input[2]
userInputLen := utf8.RuneCount([]byte(userInput))

if userInputLen > req.Config.MaxInput {
return domain.LLMResponse{}, fmt.Errorf("输入太长,最常不超过 %d,现有长度 %d", req.Config.MaxInput, userInputLen)
}
// 把 input 和 prompt 结合起来
prompt := fmt.Sprintf(req.Config.PromptTemplate, title, userInput)
prompt := fmt.Sprintf(req.Config.PromptTemplate, title, answer, userInput)
req.Prompt = prompt
return next.Handle(ctx, req)
})
Expand Down
50 changes: 34 additions & 16 deletions internal/ai/internal/service/llm/handler/platform/zhipu/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,16 @@ import (
// Handler 如果后续有不同的实现,就提供不同的实现
type Handler struct {
client *zhipu.Client
svc *zhipu.ChatCompletionService
// 价格和 model 进行绑定的
price float64
}

func NewHandler(apikey string,
price float64) (*Handler, error) {
func NewHandler(apikey string) (*Handler, error) {
client, err := zhipu.NewClient(zhipu.WithAPIKey(apikey))
if err != nil {
return nil, err
}
const model = "glm-4-0520"
svc := client.ChatCompletion(model)
return &Handler{
client: client,
// 后续可以做成可配置的
svc: svc,
price: price,
}, err
}

Expand All @@ -38,19 +30,15 @@ func (h *Handler) Name() string {

func (h *Handler) Handle(ctx context.Context, req domain.LLMRequest) (domain.LLMResponse, error) {
// 这边它不会调用 next,因为它是最终的出口
completion, err := h.svc.AddTool(zhipu.ChatCompletionToolRetrieval{
KnowledgeID: req.Config.KnowledgeId,
}).AddMessage(zhipu.ChatCompletionMessage{
Role: "user",
Content: req.Prompt,
}).Do(ctx)
chatReq := h.buildReq(req)
completion, err := chatReq.Do(ctx)
if err != nil {
return domain.LLMResponse{}, err
}
tokens := completion.Usage.TotalTokens
// 现在的报价都是 N/1k token
// 而后向上取整
amt := math.Ceil(float64(tokens) * h.price / 1000)
amt := math.Ceil(float64(tokens*req.Config.Price) / float64(1000))
// 金额只有具体的模型才知道怎么算
resp := domain.LLMResponse{
Tokens: tokens,
Expand All @@ -62,3 +50,33 @@ func (h *Handler) Handle(ctx context.Context, req domain.LLMRequest) (domain.LLM
}
return resp, nil
}

// buildReq assembles the Zhipu chat-completion request described by req.
//
// The model name is taken from req.Config.Model. Messages are added in
// conversation order: the optional system prompt first, followed by the user
// prompt (req.Prompt). Temperature and TopP are applied only when they are
// greater than zero; otherwise they are left unset on the request. The
// knowledge-base retrieval tool is attached only when req.Config.KnowledgeId
// is non-empty.
func (h *Handler) buildReq(req domain.LLMRequest) *zhipu.ChatCompletionService {
	chatReq := h.client.ChatCompletion(req.Config.Model)

	// The system prompt has to precede the user message: the message list is
	// sent in insertion order, and a system instruction placed after the user
	// turn is not where chat-completion APIs expect it.
	if req.Config.SystemPrompt != "" {
		chatReq = chatReq.AddMessage(zhipu.ChatCompletionMessage{
			Role:    zhipu.RoleSystem,
			Content: req.Config.SystemPrompt,
		})
	}

	chatReq = chatReq.AddMessage(zhipu.ChatCompletionMessage{
		Role:    zhipu.RoleUser,
		Content: req.Prompt,
	})

	// A zero value means "not configured", so the sampling parameter is
	// skipped rather than sent as 0.
	if req.Config.Temperature > 0 {
		chatReq = chatReq.SetTemperature(req.Config.Temperature)
	}

	if req.Config.TopP > 0 {
		chatReq = chatReq.SetTopP(req.Config.TopP)
	}

	// Retrieval is optional: only attach the tool when a knowledge base is
	// configured for this biz.
	if req.Config.KnowledgeId != "" {
		chatReq = chatReq.AddTool(zhipu.ChatCompletionToolRetrieval{
			KnowledgeID: req.Config.KnowledgeId,
		})
	}
	return chatReq
}
2 changes: 1 addition & 1 deletion internal/cases/internal/service/examine.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (svc *LLMExamineService) Examine(ctx context.Context,
Uid: uid,
Tid: tid,
Biz: biz,
Input: []string{ca.Title, input},
Input: []string{ca.Title, ca.Content, input},
}
aiResp, err := svc.aiSvc.Invoke(ctx, aiReq)
if err != nil {
Expand Down
19 changes: 18 additions & 1 deletion internal/question/internal/domain/question.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

package domain

import "time"
import (
"strings"
"time"
)

// Question 和 QuestionSet 是一个多对多的关系
type Question struct {
Expand Down Expand Up @@ -47,6 +50,20 @@ type Answer struct {
Utime time.Time
}

// String renders the answer as three labelled lines, in the order
// "15K: " (Basic), "25K: " (Intermediate) and "35K: " (Advanced), each
// followed by the corresponding Content and terminated by a newline.
func (a Answer) String() string {
	lines := []string{
		"15K: " + a.Basic.Content,
		"25K: " + a.Intermediate.Content,
		"35K: " + a.Advanced.Content,
	}
	// Join puts "\n" between the lines; the trailing "\n" closes the last one.
	return strings.Join(lines, "\n") + "\n"
}

type AnswerElement struct {
Id int64
Content string
Expand Down
10 changes: 5 additions & 5 deletions internal/question/internal/integration/examine_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (s *ExamineHandlerTest) SetupSuite() {
return ai.LLMResponse{
Tokens: req.Uid,
Amount: req.Uid,
Answer: "评分:15K",
Answer: "最终评分 \n 1分",
}, nil
}).AnyTimes()
module, err := startup.InitModule(nil, &interactive.Module{}, &permission.Module{}, &ai.Module{Svc: aiSvc})
Expand Down Expand Up @@ -138,7 +138,7 @@ func (s *ExamineHandlerTest) TestExamine() {
Uid: uid,
Qid: 1,
Result: domain.ResultBasic.ToUint8(),
RawResult: "评分:15K",
RawResult: "最终评分 \n 1分",
Tokens: uid,
Amount: uid,
}, record)
Expand Down Expand Up @@ -168,7 +168,7 @@ func (s *ExamineHandlerTest) TestExamine() {
wantResp: test.Result[web.ExamineResult]{
Data: web.ExamineResult{
Result: domain.ResultBasic.ToUint8(),
RawResult: "评分:15K",
RawResult: "最终评分 \n 1分",
Amount: uid,
},
},
Expand Down Expand Up @@ -206,7 +206,7 @@ func (s *ExamineHandlerTest) TestExamine() {
Uid: uid,
Qid: 2,
Result: domain.ResultBasic.ToUint8(),
RawResult: "评分:15K",
RawResult: "最终评分 \n 1分",
Tokens: uid,
Amount: uid,
}, record)
Expand Down Expand Up @@ -236,7 +236,7 @@ func (s *ExamineHandlerTest) TestExamine() {
wantResp: test.Result[web.ExamineResult]{
Data: web.ExamineResult{
Result: domain.ResultBasic.ToUint8(),
RawResult: "评分:15K",
RawResult: "最终评分 \n 1分",
Amount: uid,
},
},
Expand Down
Loading

0 comments on commit f5225a1

Please sign in to comment.