Skip to content

Commit

Permalink
Merge pull request #13 from ekkinox/feat/max-tokens
Browse files Browse the repository at this point in the history
Added config for max tokens
  • Loading branch information
ekkinox authored Apr 17, 2023
2 parents cc1bbc3 + 96a4db4 commit 939f092
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 9 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,17 @@

All notable changes to this project will be documented in this file.

## 0.3.0

### Added

- Configuration for OpenAI API max-tokens (default 1000)
- Better feedback for install script

### Updated

- sashabaranov/go-openai to v1.8.0

## 0.2.0

### Added
Expand Down
4 changes: 2 additions & 2 deletions ai/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (e *Engine) ExecCompletion(input string) (*EngineExecOutput, error) {
ctx,
openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo,
MaxTokens: 1000,
MaxTokens: e.config.GetAiConfig().GetMaxTokens(),
Messages: e.prepareCompletionMessages(),
},
)
Expand Down Expand Up @@ -170,7 +170,7 @@ func (e *Engine) ChatStreamCompletion(input string) error {

req := openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo,
MaxTokens: 1000,
MaxTokens: e.config.GetAiConfig().GetMaxTokens(),
Messages: e.prepareCompletionMessages(),
Stream: true,
}
Expand Down
6 changes: 6 additions & 0 deletions config/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ const (
openai_key = "OPENAI_KEY"
openai_proxy = "OPENAI_PROXY"
openai_temperature = "OPENAI_TEMPERATURE"
openai_max_tokens = "OPENAI_MAX_TOKENS"
)

type AiConfig struct {
key string
proxy string
temperature float64
maxTokens int
}

func (c AiConfig) GetKey() string {
Expand All @@ -23,3 +25,7 @@ func (c AiConfig) GetProxy() string {
func (c AiConfig) GetTemperature() float64 {
return c.temperature
}

func (c AiConfig) GetMaxTokens() int {
return c.maxTokens
}
10 changes: 10 additions & 0 deletions config/ai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ func TestAiConfig(t *testing.T) {
t.Run("GetKey", testGetKey)
t.Run("GetProxy", testGetProxy)
t.Run("GetTemperature", testGetTemperature)
t.Run("GetMaxTokens", testGetMaxTokens)
}

func testGetKey(t *testing.T) {
Expand Down Expand Up @@ -38,3 +39,12 @@ func testGetTemperature(t *testing.T) {

assert.Equal(t, expectedTemperature, actualTemperature, "The two temperatures should be the same.")
}

func testGetMaxTokens(t *testing.T) {
expectedMaxTokens := 2000
aiConfig := AiConfig{maxTokens: expectedMaxTokens}

actualMaxTokens := aiConfig.GetMaxTokens()

assert.Equal(t, expectedMaxTokens, actualMaxTokens, "The two maxTokens should be the same.")
}
9 changes: 6 additions & 3 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ func NewConfig() (*Config, error) {
return &Config{
ai: AiConfig{
key: viper.GetString(openai_key),
temperature: viper.GetFloat64(openai_temperature),
proxy: viper.GetString(openai_proxy),
temperature: viper.GetFloat64(openai_temperature),
maxTokens: viper.GetInt(openai_max_tokens),
},
user: UserConfig{
defaultPromptMode: viper.GetString(user_default_prompt_mode),
Expand All @@ -52,10 +53,12 @@ func NewConfig() (*Config, error) {

func WriteConfig(key string, write bool) (*Config, error) {
system := system.Analyse()

// ai defaults
viper.Set(openai_key, key)
viper.SetDefault(openai_temperature, 0.2)
viper.SetDefault(openai_proxy, "")
viper.SetDefault(openai_temperature, 0.2)
viper.SetDefault(openai_max_tokens, 1000)
// user defaults
viper.SetDefault(user_default_prompt_mode, "exec")
viper.SetDefault(user_preferences, "")

Expand Down
10 changes: 7 additions & 3 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ func setupViper(t *testing.T) {
viper.SetConfigName(strings.ToLower(system.GetApplicationName()))
viper.AddConfigPath("/tmp/")
viper.Set(openai_key, "test_key")
viper.Set(openai_temperature, 0.2)
viper.Set(openai_proxy, "test_proxy")
viper.Set(openai_temperature, 0.2)
viper.Set(openai_max_tokens, 2000)
viper.Set(user_default_prompt_mode, "exec")
viper.Set(user_preferences, "test_preferences")

Expand All @@ -47,6 +48,7 @@ func testNewConfig(t *testing.T) {
assert.Equal(t, "test_key", cfg.GetAiConfig().GetKey())
assert.Equal(t, "test_proxy", cfg.GetAiConfig().GetProxy())
assert.Equal(t, 0.2, cfg.GetAiConfig().GetTemperature())
assert.Equal(t, 2000, cfg.GetAiConfig().GetMaxTokens())
assert.Equal(t, "exec", cfg.GetUserConfig().GetDefaultPromptMode())
assert.Equal(t, "test_preferences", cfg.GetUserConfig().GetPreferences())

Expand All @@ -61,16 +63,18 @@ func testWriteConfig(t *testing.T) {
require.NoError(t, err)

assert.Equal(t, "new_test_key", cfg.GetAiConfig().GetKey())
assert.Equal(t, 0.2, cfg.GetAiConfig().GetTemperature())
assert.Equal(t, "test_proxy", cfg.GetAiConfig().GetProxy())
assert.Equal(t, 0.2, cfg.GetAiConfig().GetTemperature())
assert.Equal(t, 2000, cfg.GetAiConfig().GetMaxTokens())
assert.Equal(t, "exec", cfg.GetUserConfig().GetDefaultPromptMode())
assert.Equal(t, "test_preferences", cfg.GetUserConfig().GetPreferences())

assert.NotNil(t, cfg.GetSystemConfig())

assert.Equal(t, "new_test_key", viper.GetString(openai_key))
assert.Equal(t, 0.2, viper.GetFloat64(openai_temperature))
assert.Equal(t, "test_proxy", viper.GetString(openai_proxy))
assert.Equal(t, 0.2, viper.GetFloat64(openai_temperature))
assert.Equal(t, 2000, viper.GetInt(openai_max_tokens))
assert.Equal(t, "exec", viper.GetString(user_default_prompt_mode))
assert.Equal(t, "test_preferences", viper.GetString(user_preferences))
}
2 changes: 1 addition & 1 deletion install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ BINNAME="${BINNAME:-yo}"
BINDIR="${BINDIR:-/usr/local/bin}"
URL="https://github.com/$REPOOWNER/$REPONAME/releases/download/${RELEASETAG}/yo_${RELEASETAG}_${KERNEL}_${MACHINE}.tar.gz"

echo "Downloading from $URL"
echo "Downloading version $RELEASETAG from $URL"
echo

curl -q --fail --location --progress-bar --output "yo_${RELEASETAG}_${KERNEL}_${MACHINE}.tar.gz" "$URL"
Expand Down

0 comments on commit 939f092

Please sign in to comment.