@@ -77,15 +77,10 @@ func Load(workingDir string, debug bool) (*Config, error) {
return cfg, nil
}
- // largeModel, ok := cfg.Models[SelectedModelTypeLarge]
- // if !ok {
- // // set default
- // }
- // smallModel, ok := cfg.Models[SelectedModelTypeSmall]
- // if !ok {
- // // set default
- // }
- //
+ if err := cfg.configureSelectedModels(providers); err != nil {
+ return nil, fmt.Errorf("failed to configure selected models: %w", err)
+ }
+
return cfg, nil
}
@@ -236,34 +231,6 @@ func (cfg *Config) configureProviders(env env.Env, resolver VariableResolver, kn
return nil
}
-func hasVertexCredentials(env env.Env) bool {
- useVertex := env.Get("GOOGLE_GENAI_USE_VERTEXAI") == "true"
- hasProject := env.Get("GOOGLE_CLOUD_PROJECT") != ""
- hasLocation := env.Get("GOOGLE_CLOUD_LOCATION") != ""
- return useVertex && hasProject && hasLocation
-}
-
-func hasAWSCredentials(env env.Env) bool {
- if env.Get("AWS_ACCESS_KEY_ID") != "" && env.Get("AWS_SECRET_ACCESS_KEY") != "" {
- return true
- }
-
- if env.Get("AWS_PROFILE") != "" || env.Get("AWS_DEFAULT_PROFILE") != "" {
- return true
- }
-
- if env.Get("AWS_REGION") != "" || env.Get("AWS_DEFAULT_REGION") != "" {
- return true
- }
-
- if env.Get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
- env.Get("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
- return true
- }
-
- return false
-}
-
func (cfg *Config) setDefaults(workingDir string) {
cfg.workingDir = workingDir
if cfg.Options == nil {
@@ -297,6 +264,135 @@ func (cfg *Config) setDefaults(workingDir string) {
cfg.Options.ContextPaths = slices.Compact(cfg.Options.ContextPaths)
}
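+
+// defaultModelSelection picks the default large and small models: it prefers the
+// first enabled provider in the known-providers order and, if none of those is
+// enabled, falls back to the first configured provider sorted by ID.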
+func (cfg *Config) defaultModelSelection(knownProviders []provider.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) {
+ if len(knownProviders) == 0 && len(cfg.Providers) == 0 {
+ err = fmt.Errorf("no providers configured, please configure at least one provider")
+ return
+ }
+
+ // Use the first enabled provider, following the order of the known providers.
+ // If none of the known providers is enabled, fall back to the first configured provider.
+ for _, p := range knownProviders {
+ providerConfig, ok := cfg.Providers[string(p.ID)]
+ if !ok || providerConfig.Disable {
+ continue
+ }
+ defaultLargeModel := cfg.GetModel(string(p.ID), p.DefaultLargeModelID)
+ if defaultLargeModel == nil {
+ err = fmt.Errorf("default large model %s not found for provider %s", p.DefaultLargeModelID, p.ID)
+ return
+ }
+ largeModel = SelectedModel{
+ Provider: string(p.ID),
+ Model: defaultLargeModel.ID,
+ MaxTokens: defaultLargeModel.DefaultMaxTokens,
+ }
+
+ defaultSmallModel := cfg.GetModel(string(p.ID), p.DefaultSmallModelID)
+ if defaultSmallModel == nil {
+ err = fmt.Errorf("default small model %s not found for provider %s", p.DefaultSmallModelID, p.ID)
+ }
+ smallModel = SelectedModel{
+ Provider: string(p.ID),
+ Model: defaultSmallModel.ID,
+ MaxTokens: defaultSmallModel.DefaultMaxTokens,
+ }
+ return
+ }
+
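+ // None of the known providers is enabled; fall back to the enabled providers
+ // from the config, sorted by ID so the fallback is deterministic.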
+ enabledProviders := cfg.EnabledProviders()
+ slices.SortFunc(enabledProviders, func(a, b ProviderConfig) int {
+ return strings.Compare(a.ID, b.ID)
+ })
+
+ if len(enabledProviders) == 0 {
+ err = fmt.Errorf("no providers configured, please configure at least one provider")
+ return
+ }
+
+ providerConfig := enabledProviders[0]
+ if len(providerConfig.Models) == 0 {
+ err = fmt.Errorf("provider %s has no models configured", providerConfig.ID)
+ return
+ }
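+ // Both the large and small defaults use the provider's first model.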
+ defaultLargeModel := cfg.GetModel(providerConfig.ID, providerConfig.Models[0].ID)
+ largeModel = SelectedModel{
+ Provider: providerConfig.ID,
+ Model: defaultLargeModel.ID,
+ MaxTokens: defaultLargeModel.DefaultMaxTokens,
+ }
+ defaultSmallModel := cfg.GetModel(providerConfig.ID, providerConfig.Models[0].ID)
+ smallModel = SelectedModel{
+ Provider: providerConfig.ID,
+ Model: defaultSmallModel.ID,
+ MaxTokens: defaultSmallModel.DefaultMaxTokens,
+ }
+ return
+}
+
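+// configureSelectedModels merges the user-configured model selections on top of
+// the computed defaults and validates that the selected models exist for their
+// providers before storing them back on cfg.Models.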
+func (cfg *Config) configureSelectedModels(knownProviders []provider.Provider) error {
+ large, small, err := cfg.defaultModelSelection(knownProviders)
+ if err != nil {
+ return fmt.Errorf("failed to select default models: %w", err)
+ }
+
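+ // Apply user overrides for the large model on top of the defaults.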
+ largeModelSelected, largeModelConfigured := cfg.Models[SelectedModelTypeLarge]
+ if largeModelConfigured {
+ if largeModelSelected.Model != "" {
+ large.Model = largeModelSelected.Model
+ }
+ if largeModelSelected.Provider != "" {
+ large.Provider = largeModelSelected.Provider
+ }
+ model := cfg.GetModel(large.Provider, large.Model)
+ if model == nil {
+ return fmt.Errorf("large model %s not found for provider %s", large.Model, large.Provider)
+ }
+ if largeModelSelected.MaxTokens > 0 {
+ large.MaxTokens = largeModelSelected.MaxTokens
+ } else {
+ large.MaxTokens = model.DefaultMaxTokens
+ }
+ large.ReasoningEffort = largeModelSelected.ReasoningEffort
+ large.Think = largeModelSelected.Think
+ }
+
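+ // Apply user overrides for the small model in the same way.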
+ smallModelSelected, smallModelConfigured := cfg.Models[SelectedModelTypeSmall]
+ if smallModelConfigured {
+ if smallModelSelected.Model != "" {
+ small.Model = smallModelSelected.Model
+ }
+ if smallModelSelected.Provider != "" {
+ small.Provider = smallModelSelected.Provider
+ }
+
+ model := cfg.GetModel(small.Provider, small.Model)
+ if model == nil {
+ return fmt.Errorf("large model %s not found for provider %s", large.Model, large.Provider)
+ }
+ if smallModelSelected.MaxTokens > 0 {
+ small.MaxTokens = smallModelSelected.MaxTokens
+ } else {
+ small.MaxTokens = model.DefaultMaxTokens
+ }
+ small.ReasoningEffort = smallModelSelected.ReasoningEffort
+ small.Think = smallModelSelected.Think
+ }
+
+ // validate the selected models
+ largeModel := cfg.GetModel(large.Provider, large.Model)
+ if largeModel == nil {
+ return fmt.Errorf("large model %s not found for provider %s", large.Model, large.Provider)
+ }
+ smallModel := cfg.GetModel(small.Provider, small.Model)
+ if smallModel == nil {
+ return fmt.Errorf("small model %s not found for provider %s", small.Model, small.Provider)
+ }
+ cfg.Models[SelectedModelTypeLarge] = large
+ cfg.Models[SelectedModelTypeSmall] = small
+ return nil
+}
+
func loadFromConfigPaths(configPaths []string) (*Config, error) {
var configs []io.Reader
@@ -329,6 +425,34 @@ func loadFromReaders(readers []io.Reader) (*Config, error) {
return LoadReader(merged)
}
+func hasVertexCredentials(env env.Env) bool {
+ useVertex := env.Get("GOOGLE_GENAI_USE_VERTEXAI") == "true"
+ hasProject := env.Get("GOOGLE_CLOUD_PROJECT") != ""
+ hasLocation := env.Get("GOOGLE_CLOUD_LOCATION") != ""
+ return useVertex && hasProject && hasLocation
+}
+
+func hasAWSCredentials(env env.Env) bool {
+ if env.Get("AWS_ACCESS_KEY_ID") != "" && env.Get("AWS_SECRET_ACCESS_KEY") != "" {
+ return true
+ }
+
+ if env.Get("AWS_PROFILE") != "" || env.Get("AWS_DEFAULT_PROFILE") != "" {
+ return true
+ }
+
+ if env.Get("AWS_REGION") != "" || env.Get("AWS_DEFAULT_REGION") != "" {
+ return true
+ }
+
+ if env.Get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
+ env.Get("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
+ return true
+ }
+
+ return false
+}
+
func globalConfig() string {
xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
if xdgConfigHome != "" {
@@ -760,3 +760,390 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
assert.True(t, exists)
})
}
+
+func TestConfig_defaultModelSelection(t *testing.T) {
+ t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: "openai",
+ APIKey: "abc",
+ DefaultLargeModelID: "large-model",
+ DefaultSmallModelID: "small-model",
+ Models: []provider.Model{
+ {
+ ID: "large-model",
+ DefaultMaxTokens: 1000,
+ },
+ {
+ ID: "small-model",
+ DefaultMaxTokens: 500,
+ },
+ },
+ },
+ }
+
+ cfg := &Config{}
+ cfg.setDefaults("/tmp")
+ env := env.NewFromMap(map[string]string{})
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, knownProviders)
+ assert.NoError(t, err)
+
+ large, small, err := cfg.defaultModelSelection(knownProviders)
+ assert.NoError(t, err)
+ assert.Equal(t, "large-model", large.Model)
+ assert.Equal(t, "openai", large.Provider)
+ assert.Equal(t, int64(1000), large.MaxTokens)
+ assert.Equal(t, "small-model", small.Model)
+ assert.Equal(t, "openai", small.Provider)
+ assert.Equal(t, int64(500), small.MaxTokens)
+ })
+ t.Run("should error if no providers configured", func(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: "openai",
+ APIKey: "$MISSING_KEY", // will not resolve, so the provider is not included in the config
+ DefaultLargeModelID: "large-model",
+ DefaultSmallModelID: "small-model",
+ Models: []provider.Model{
+ {
+ ID: "large-model",
+ DefaultMaxTokens: 1000,
+ },
+ {
+ ID: "small-model",
+ DefaultMaxTokens: 500,
+ },
+ },
+ },
+ }
+
+ cfg := &Config{}
+ cfg.setDefaults("/tmp")
+ env := env.NewFromMap(map[string]string{})
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, knownProviders)
+ assert.NoError(t, err)
+
+ _, _, err = cfg.defaultModelSelection(knownProviders)
+ assert.Error(t, err)
+ })
+ t.Run("should error if model is missing", func(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: "openai",
+ APIKey: "abc",
+ DefaultLargeModelID: "large-model",
+ DefaultSmallModelID: "small-model",
+ Models: []provider.Model{
+ {
+ ID: "not-large-model",
+ DefaultMaxTokens: 1000,
+ },
+ {
+ ID: "small-model",
+ DefaultMaxTokens: 500,
+ },
+ },
+ },
+ }
+
+ cfg := &Config{}
+ cfg.setDefaults("/tmp")
+ env := env.NewFromMap(map[string]string{})
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, knownProviders)
+ assert.NoError(t, err)
+ _, _, err = cfg.defaultModelSelection(knownProviders)
+ assert.Error(t, err)
+ })
+
+ t.Run("should configure the default models with a custom provider", func(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: "openai",
+ APIKey: "$MISSING", // will not be included in the config
+ DefaultLargeModelID: "large-model",
+ DefaultSmallModelID: "small-model",
+ Models: []provider.Model{
+ {
+ ID: "not-large-model",
+ DefaultMaxTokens: 1000,
+ },
+ {
+ ID: "small-model",
+ DefaultMaxTokens: 500,
+ },
+ },
+ },
+ }
+
+ cfg := &Config{
+ Providers: map[string]ProviderConfig{
+ "custom": {
+ APIKey: "test-key",
+ BaseURL: "https://api.custom.com/v1",
+ Models: []provider.Model{
+ {
+ ID: "model",
+ DefaultMaxTokens: 600,
+ },
+ },
+ },
+ },
+ }
+ cfg.setDefaults("/tmp")
+ env := env.NewFromMap(map[string]string{})
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, knownProviders)
+ assert.NoError(t, err)
+ large, small, err := cfg.defaultModelSelection(knownProviders)
+ assert.NoError(t, err)
+ assert.Equal(t, "model", large.Model)
+ assert.Equal(t, "custom", large.Provider)
+ assert.Equal(t, int64(600), large.MaxTokens)
+ assert.Equal(t, "model", small.Model)
+ assert.Equal(t, "custom", small.Provider)
+ assert.Equal(t, int64(600), small.MaxTokens)
+ })
+
+ t.Run("should fail if no model configured", func(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: "openai",
+ APIKey: "$MISSING", // will not be included in the config
+ DefaultLargeModelID: "large-model",
+ DefaultSmallModelID: "small-model",
+ Models: []provider.Model{
+ {
+ ID: "not-large-model",
+ DefaultMaxTokens: 1000,
+ },
+ {
+ ID: "small-model",
+ DefaultMaxTokens: 500,
+ },
+ },
+ },
+ }
+
+ cfg := &Config{
+ Providers: map[string]ProviderConfig{
+ "custom": {
+ APIKey: "test-key",
+ BaseURL: "https://api.custom.com/v1",
+ Models: []provider.Model{},
+ },
+ },
+ }
+ cfg.setDefaults("/tmp")
+ env := env.NewFromMap(map[string]string{})
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, knownProviders)
+ assert.NoError(t, err)
+ _, _, err = cfg.defaultModelSelection(knownProviders)
+ assert.Error(t, err)
+ })
+ t.Run("should use the default provider first", func(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: "openai",
+ APIKey: "set",
+ DefaultLargeModelID: "large-model",
+ DefaultSmallModelID: "small-model",
+ Models: []provider.Model{
+ {
+ ID: "large-model",
+ DefaultMaxTokens: 1000,
+ },
+ {
+ ID: "small-model",
+ DefaultMaxTokens: 500,
+ },
+ },
+ },
+ }
+
+ cfg := &Config{
+ Providers: map[string]ProviderConfig{
+ "custom": {
+ APIKey: "test-key",
+ BaseURL: "https://api.custom.com/v1",
+ Models: []provider.Model{
+ {
+ ID: "large-model",
+ DefaultMaxTokens: 1000,
+ },
+ },
+ },
+ },
+ }
+ cfg.setDefaults("/tmp")
+ env := env.NewFromMap(map[string]string{})
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, knownProviders)
+ assert.NoError(t, err)
+ large, small, err := cfg.defaultModelSelection(knownProviders)
+ assert.NoError(t, err)
+ assert.Equal(t, "large-model", large.Model)
+ assert.Equal(t, "openai", large.Provider)
+ assert.Equal(t, int64(1000), large.MaxTokens)
+ assert.Equal(t, "small-model", small.Model)
+ assert.Equal(t, "openai", small.Provider)
+ assert.Equal(t, int64(500), small.MaxTokens)
+ })
+}
+
+func TestConfig_configureSelectedModels(t *testing.T) {
+ t.Run("should override defaults", func(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: "openai",
+ APIKey: "abc",
+ DefaultLargeModelID: "large-model",
+ DefaultSmallModelID: "small-model",
+ Models: []provider.Model{
+ {
+ ID: "larger-model",
+ DefaultMaxTokens: 2000,
+ },
+ {
+ ID: "large-model",
+ DefaultMaxTokens: 1000,
+ },
+ {
+ ID: "small-model",
+ DefaultMaxTokens: 500,
+ },
+ },
+ },
+ }
+
+ cfg := &Config{
+ Models: map[SelectedModelType]SelectedModel{
+ "large": {
+ Model: "larger-model",
+ },
+ },
+ }
+ cfg.setDefaults("/tmp")
+ env := env.NewFromMap(map[string]string{})
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, knownProviders)
+ assert.NoError(t, err)
+
+ err = cfg.configureSelectedModels(knownProviders)
+ assert.NoError(t, err)
+ large := cfg.Models[SelectedModelTypeLarge]
+ small := cfg.Models[SelectedModelTypeSmall]
+ assert.Equal(t, "larger-model", large.Model)
+ assert.Equal(t, "openai", large.Provider)
+ assert.Equal(t, int64(2000), large.MaxTokens)
+ assert.Equal(t, "small-model", small.Model)
+ assert.Equal(t, "openai", small.Provider)
+ assert.Equal(t, int64(500), small.MaxTokens)
+ })
+ t.Run("should be possible to use multiple providers", func(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: "openai",
+ APIKey: "abc",
+ DefaultLargeModelID: "large-model",
+ DefaultSmallModelID: "small-model",
+ Models: []provider.Model{
+ {
+ ID: "large-model",
+ DefaultMaxTokens: 1000,
+ },
+ {
+ ID: "small-model",
+ DefaultMaxTokens: 500,
+ },
+ },
+ },
+ {
+ ID: "anthropic",
+ APIKey: "abc",
+ DefaultLargeModelID: "a-large-model",
+ DefaultSmallModelID: "a-small-model",
+ Models: []provider.Model{
+ {
+ ID: "a-large-model",
+ DefaultMaxTokens: 1000,
+ },
+ {
+ ID: "a-small-model",
+ DefaultMaxTokens: 200,
+ },
+ },
+ },
+ }
+
+ cfg := &Config{
+ Models: map[SelectedModelType]SelectedModel{
+ "small": {
+ Model: "a-small-model",
+ Provider: "anthropic",
+ MaxTokens: 300,
+ },
+ },
+ }
+ cfg.setDefaults("/tmp")
+ env := env.NewFromMap(map[string]string{})
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, knownProviders)
+ assert.NoError(t, err)
+
+ err = cfg.configureSelectedModels(knownProviders)
+ assert.NoError(t, err)
+ large := cfg.Models[SelectedModelTypeLarge]
+ small := cfg.Models[SelectedModelTypeSmall]
+ assert.Equal(t, "large-model", large.Model)
+ assert.Equal(t, "openai", large.Provider)
+ assert.Equal(t, int64(1000), large.MaxTokens)
+ assert.Equal(t, "a-small-model", small.Model)
+ assert.Equal(t, "anthropic", small.Provider)
+ assert.Equal(t, int64(300), small.MaxTokens)
+ })
+
+ t.Run("should override the max tokens only", func(t *testing.T) {
+ knownProviders := []provider.Provider{
+ {
+ ID: "openai",
+ APIKey: "abc",
+ DefaultLargeModelID: "large-model",
+ DefaultSmallModelID: "small-model",
+ Models: []provider.Model{
+ {
+ ID: "large-model",
+ DefaultMaxTokens: 1000,
+ },
+ {
+ ID: "small-model",
+ DefaultMaxTokens: 500,
+ },
+ },
+ },
+ }
+
+ cfg := &Config{
+ Models: map[SelectedModelType]SelectedModel{
+ "large": {
+ MaxTokens: 100,
+ },
+ },
+ }
+ cfg.setDefaults("/tmp")
+ env := env.NewFromMap(map[string]string{})
+ resolver := NewEnvironmentVariableResolver(env)
+ err := cfg.configureProviders(env, resolver, knownProviders)
+ assert.NoError(t, err)
+
+ err = cfg.configureSelectedModels(knownProviders)
+ assert.NoError(t, err)
+ large := cfg.Models[SelectedModelTypeLarge]
+ assert.Equal(t, "large-model", large.Model)
+ assert.Equal(t, "openai", large.Provider)
+ assert.Equal(t, int64(100), large.MaxTokens)
+ })
+}