From 96a4db49aaf5d7e5da6e0b3d779d599f93c9585c Mon Sep 17 00:00:00 2001 From: Jonathan Vuillemin Date: Mon, 17 Apr 2023 14:02:31 +0200 Subject: [PATCH] feat: added config for max tokens --- CHANGELOG.md | 11 +++++++++++ ai/engine.go | 4 ++-- config/ai.go | 6 ++++++ config/ai_test.go | 10 ++++++++++ config/config.go | 9 ++++++--- config/config_test.go | 10 +++++++--- install.sh | 2 +- 7 files changed, 43 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dff3e5b..57b6278 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,17 @@ All notable changes to this project will be documented in this file. +## 0.3.0 + +### Added + +- Configuration for OpenAI API max-tokens (default 1000) +- Better feedback for install script + +### Updated + +- sashabaranov/go-openai to v1.8.0 + ## 0.2.0 ### Added diff --git a/ai/engine.go b/ai/engine.go index d36342e..f7a86d4 100644 --- a/ai/engine.go +++ b/ai/engine.go @@ -128,7 +128,7 @@ func (e *Engine) ExecCompletion(input string) (*EngineExecOutput, error) { ctx, openai.ChatCompletionRequest{ Model: openai.GPT3Dot5Turbo, - MaxTokens: 1000, + MaxTokens: e.config.GetAiConfig().GetMaxTokens(), Messages: e.prepareCompletionMessages(), }, ) @@ -170,7 +170,7 @@ func (e *Engine) ChatStreamCompletion(input string) error { req := openai.ChatCompletionRequest{ Model: openai.GPT3Dot5Turbo, - MaxTokens: 1000, + MaxTokens: e.config.GetAiConfig().GetMaxTokens(), Messages: e.prepareCompletionMessages(), Stream: true, } diff --git a/config/ai.go b/config/ai.go index 4d9bd2c..354d060 100644 --- a/config/ai.go +++ b/config/ai.go @@ -4,12 +4,14 @@ const ( openai_key = "OPENAI_KEY" openai_proxy = "OPENAI_PROXY" openai_temperature = "OPENAI_TEMPERATURE" + openai_max_tokens = "OPENAI_MAX_TOKENS" ) type AiConfig struct { key string proxy string temperature float64 + maxTokens int } func (c AiConfig) GetKey() string { @@ -23,3 +25,7 @@ func (c AiConfig) GetProxy() string { func (c AiConfig) GetTemperature() float64 { return c.temperature } + +func (c AiConfig) GetMaxTokens() int { + return c.maxTokens +} diff --git a/config/ai_test.go b/config/ai_test.go index 4371bd6..9fc7780 100644 --- a/config/ai_test.go +++ b/config/ai_test.go @@ -10,6 +10,7 @@ func TestAiConfig(t *testing.T) { t.Run("GetKey", testGetKey) t.Run("GetProxy", testGetProxy) t.Run("GetTemperature", testGetTemperature) + t.Run("GetMaxTokens", testGetMaxTokens) } func testGetKey(t *testing.T) { @@ -38,3 +39,12 @@ func testGetTemperature(t *testing.T) { assert.Equal(t, expectedTemperature, actualTemperature, "The two temperatures should be the same.") } + +func testGetMaxTokens(t *testing.T) { + expectedMaxTokens := 2000 + aiConfig := AiConfig{maxTokens: expectedMaxTokens} + + actualMaxTokens := aiConfig.GetMaxTokens() + + assert.Equal(t, expectedMaxTokens, actualMaxTokens, "The two maxTokens should be the same.") +} diff --git a/config/config.go b/config/config.go index 8af1339..0e40144 100644 --- a/config/config.go +++ b/config/config.go @@ -39,8 +39,9 @@ func NewConfig() (*Config, error) { return &Config{ ai: AiConfig{ key: viper.GetString(openai_key), - temperature: viper.GetFloat64(openai_temperature), proxy: viper.GetString(openai_proxy), + temperature: viper.GetFloat64(openai_temperature), + maxTokens: viper.GetInt(openai_max_tokens), }, user: UserConfig{ defaultPromptMode: viper.GetString(user_default_prompt_mode), @@ -52,10 +53,12 @@ func NewConfig() (*Config, error) { func WriteConfig(key string, write bool) (*Config, error) { system := system.Analyse() - + // ai defaults viper.Set(openai_key, key) - viper.SetDefault(openai_temperature, 0.2) viper.SetDefault(openai_proxy, "") + viper.SetDefault(openai_temperature, 0.2) + viper.SetDefault(openai_max_tokens, 1000) + // user defaults viper.SetDefault(user_default_prompt_mode, "exec") viper.SetDefault(user_preferences, "") diff --git a/config/config_test.go b/config/config_test.go index a38cc53..d2136a5 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -24,8 +24,9 @@ func setupViper(t *testing.T) { viper.SetConfigName(strings.ToLower(system.GetApplicationName())) viper.AddConfigPath("/tmp/") viper.Set(openai_key, "test_key") - viper.Set(openai_temperature, 0.2) viper.Set(openai_proxy, "test_proxy") + viper.Set(openai_temperature, 0.2) + viper.Set(openai_max_tokens, 2000) viper.Set(user_default_prompt_mode, "exec") viper.Set(user_preferences, "test_preferences") @@ -47,6 +48,7 @@ func testNewConfig(t *testing.T) { assert.Equal(t, "test_key", cfg.GetAiConfig().GetKey()) assert.Equal(t, "test_proxy", cfg.GetAiConfig().GetProxy()) assert.Equal(t, 0.2, cfg.GetAiConfig().GetTemperature()) + assert.Equal(t, 2000, cfg.GetAiConfig().GetMaxTokens()) assert.Equal(t, "exec", cfg.GetUserConfig().GetDefaultPromptMode()) assert.Equal(t, "test_preferences", cfg.GetUserConfig().GetPreferences()) @@ -61,16 +63,18 @@ func testWriteConfig(t *testing.T) { require.NoError(t, err) assert.Equal(t, "new_test_key", cfg.GetAiConfig().GetKey()) - assert.Equal(t, 0.2, cfg.GetAiConfig().GetTemperature()) assert.Equal(t, "test_proxy", cfg.GetAiConfig().GetProxy()) + assert.Equal(t, 0.2, cfg.GetAiConfig().GetTemperature()) + assert.Equal(t, 2000, cfg.GetAiConfig().GetMaxTokens()) assert.Equal(t, "exec", cfg.GetUserConfig().GetDefaultPromptMode()) assert.Equal(t, "test_preferences", cfg.GetUserConfig().GetPreferences()) assert.NotNil(t, cfg.GetSystemConfig()) assert.Equal(t, "new_test_key", viper.GetString(openai_key)) - assert.Equal(t, 0.2, viper.GetFloat64(openai_temperature)) assert.Equal(t, "test_proxy", viper.GetString(openai_proxy)) + assert.Equal(t, 0.2, viper.GetFloat64(openai_temperature)) + assert.Equal(t, 2000, viper.GetInt(openai_max_tokens)) assert.Equal(t, "exec", viper.GetString(user_default_prompt_mode)) assert.Equal(t, "test_preferences", viper.GetString(user_preferences)) } diff --git a/install.sh b/install.sh index 90c363b..d50d53e 100644 --- a/install.sh +++ b/install.sh @@ -46,7 +46,7 @@ BINNAME="${BINNAME:-yo}" BINDIR="${BINDIR:-/usr/local/bin}" URL="https://github.com/$REPOOWNER/$REPONAME/releases/download/${RELEASETAG}/yo_${RELEASETAG}_${KERNEL}_${MACHINE}.tar.gz" -echo "Downloading from $URL" +echo "Downloading version $RELEASETAG from $URL" echo curl -q --fail --location --progress-bar --output "yo_${RELEASETAG}_${KERNEL}_${MACHINE}.tar.gz" "$URL"