Skip to content

Commit

Permalink
Replace with GetModelByName, output friendly name
Browse files Browse the repository at this point in the history
  • Loading branch information
cheshire137 committed Oct 8, 2024
1 parent f9b7694 commit c873994
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
3 changes: 2 additions & 1 deletion cmd/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,11 @@ func NewRunCommand() *cobra.Command {
modelName = args[0]
}

modelName, err = util.ValidateModelName(modelName, models)
model, err := util.GetModelByName(modelName, models)
if err != nil {
return err
}
modelName = model.Name

initialPrompt := ""
singleShot := false
Expand Down
4 changes: 2 additions & 2 deletions cmd/view/view.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ func NewViewCommand() *cobra.Command {
modelName = args[0]
}

modelName, err = util.ValidateModelName(modelName, models)
model, err := util.GetModelByName(modelName, models)
if err != nil {
return err
}

io.WriteString(out, "You selected: "+modelName+"\n")
io.WriteString(out, "You selected: "+model.FriendlyName+"\n")
return nil
},
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/util/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ import (
"github.com/github/gh-models/internal/azure_models"
)

// ValidateModelName checks whether the given model name is a valid one, based on the provided list of models.
func ValidateModelName(modelName string, models []*azure_models.ModelSummary) (string, error) {
// GetModelByName returns the model with the specified name, or an error if no such model exists within the given list.
func GetModelByName(modelName string, models []*azure_models.ModelSummary) (*azure_models.ModelSummary, error) {
for _, model := range models {
if strings.EqualFold(model.FriendlyName, modelName) || strings.EqualFold(model.Name, modelName) {
return model.Name, nil
return model, nil
}
}
return "", fmt.Errorf("the specified model name is not supported: %s", modelName)
return nil, fmt.Errorf("the specified model name is not supported: %s", modelName)
}

0 comments on commit c873994

Please sign in to comment.