chore: add default models handling and tests

Kujtim Hoxha created

Change summary

pkg/config/config.go    |  11 +
pkg/config/load.go      | 198 +++++++++++++++++----
pkg/config/load_test.go | 387 +++++++++++++++++++++++++++++++++++++++++++
pkg/env/env.go          |   3 
4 files changed, 562 insertions(+), 37 deletions(-)

Detailed changes

pkg/config/config.go 🔗

@@ -195,3 +195,14 @@ func (c *Config) EnabledProviders() []ProviderConfig {
 func (c *Config) IsConfigured() bool {
 	return len(c.EnabledProviders()) > 0
 }
+
+func (c *Config) GetModel(provider, model string) *provider.Model {
+	if providerConfig, ok := c.Providers[provider]; ok {
+		for _, m := range providerConfig.Models {
+			if m.ID == model {
+				return &m
+			}
+		}
+	}
+	return nil
+}

pkg/config/load.go 🔗

@@ -77,15 +77,10 @@ func Load(workingDir string, debug bool) (*Config, error) {
 		return cfg, nil
 	}
 
-	// largeModel, ok := cfg.Models[SelectedModelTypeLarge]
-	// if !ok {
-	// 	// set default
-	// }
-	// smallModel, ok := cfg.Models[SelectedModelTypeSmall]
-	// if !ok {
-	// 	// set default
-	// }
-	//
+	if err := cfg.configureSelectedModels(providers); err != nil {
+		return nil, fmt.Errorf("failed to configure selected models: %w", err)
+	}
+
 	return cfg, nil
 }
 
@@ -236,34 +231,6 @@ func (cfg *Config) configureProviders(env env.Env, resolver VariableResolver, kn
 	return nil
 }
 
-func hasVertexCredentials(env env.Env) bool {
-	useVertex := env.Get("GOOGLE_GENAI_USE_VERTEXAI") == "true"
-	hasProject := env.Get("GOOGLE_CLOUD_PROJECT") != ""
-	hasLocation := env.Get("GOOGLE_CLOUD_LOCATION") != ""
-	return useVertex && hasProject && hasLocation
-}
-
-func hasAWSCredentials(env env.Env) bool {
-	if env.Get("AWS_ACCESS_KEY_ID") != "" && env.Get("AWS_SECRET_ACCESS_KEY") != "" {
-		return true
-	}
-
-	if env.Get("AWS_PROFILE") != "" || env.Get("AWS_DEFAULT_PROFILE") != "" {
-		return true
-	}
-
-	if env.Get("AWS_REGION") != "" || env.Get("AWS_DEFAULT_REGION") != "" {
-		return true
-	}
-
-	if env.Get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
-		env.Get("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
-		return true
-	}
-
-	return false
-}
-
 func (cfg *Config) setDefaults(workingDir string) {
 	cfg.workingDir = workingDir
 	if cfg.Options == nil {
@@ -297,6 +264,135 @@ func (cfg *Config) setDefaults(workingDir string) {
 	cfg.Options.ContextPaths = slices.Compact(cfg.Options.ContextPaths)
 }
 
+func (cfg *Config) defaultModelSelection(knownProviders []provider.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) {
+	if len(knownProviders) == 0 && len(cfg.Providers) == 0 {
+		err = fmt.Errorf("no providers configured, please configure at least one provider")
+		return
+	}
+
+	// Use the first provider enabled based on the known providers order
+	// if no provider found that is known use the first provider configured
+	for _, p := range knownProviders {
+		providerConfig, ok := cfg.Providers[string(p.ID)]
+		if !ok || providerConfig.Disable {
+			continue
+		}
+		defaultLargeModel := cfg.GetModel(string(p.ID), p.DefaultLargeModelID)
+		if defaultLargeModel == nil {
+			err = fmt.Errorf("default large model %s not found for provider %s", p.DefaultLargeModelID, p.ID)
+			return
+		}
+		largeModel = SelectedModel{
+			Provider:  string(p.ID),
+			Model:     defaultLargeModel.ID,
+			MaxTokens: defaultLargeModel.DefaultMaxTokens,
+		}
+
+		defaultSmallModel := cfg.GetModel(string(p.ID), p.DefaultSmallModelID)
+		if defaultSmallModel == nil {
+			err = fmt.Errorf("default small model %s not found for provider %s", p.DefaultSmallModelID, p.ID)
+		}
+		smallModel = SelectedModel{
+			Provider:  string(p.ID),
+			Model:     defaultSmallModel.ID,
+			MaxTokens: defaultSmallModel.DefaultMaxTokens,
+		}
+		return
+	}
+
+	enabledProviders := cfg.EnabledProviders()
+	slices.SortFunc(enabledProviders, func(a, b ProviderConfig) int {
+		return strings.Compare(a.ID, b.ID)
+	})
+
+	if len(enabledProviders) == 0 {
+		err = fmt.Errorf("no providers configured, please configure at least one provider")
+		return
+	}
+
+	providerConfig := enabledProviders[0]
+	if len(providerConfig.Models) == 0 {
+		err = fmt.Errorf("provider %s has no models configured", providerConfig.ID)
+		return
+	}
+	defaultLargeModel := cfg.GetModel(providerConfig.ID, providerConfig.Models[0].ID)
+	largeModel = SelectedModel{
+		Provider:  providerConfig.ID,
+		Model:     defaultLargeModel.ID,
+		MaxTokens: defaultLargeModel.DefaultMaxTokens,
+	}
+	defaultSmallModel := cfg.GetModel(providerConfig.ID, providerConfig.Models[0].ID)
+	smallModel = SelectedModel{
+		Provider:  providerConfig.ID,
+		Model:     defaultSmallModel.ID,
+		MaxTokens: defaultSmallModel.DefaultMaxTokens,
+	}
+	return
+}
+
+func (cfg *Config) configureSelectedModels(knownProviders []provider.Provider) error {
+	large, small, err := cfg.defaultModelSelection(knownProviders)
+	if err != nil {
+		return fmt.Errorf("failed to select default models: %w", err)
+	}
+
+	largeModelSelected, largeModelConfigured := cfg.Models[SelectedModelTypeLarge]
+	if largeModelConfigured {
+		if largeModelSelected.Model != "" {
+			large.Model = largeModelSelected.Model
+		}
+		if largeModelSelected.Provider != "" {
+			large.Provider = largeModelSelected.Provider
+		}
+		model := cfg.GetModel(large.Provider, large.Model)
+		if model == nil {
+			return fmt.Errorf("large model %s not found for provider %s", large.Model, large.Provider)
+		}
+		if largeModelSelected.MaxTokens > 0 {
+			large.MaxTokens = largeModelSelected.MaxTokens
+		} else {
+			large.MaxTokens = model.DefaultMaxTokens
+		}
+		large.ReasoningEffort = largeModelSelected.ReasoningEffort
+		large.Think = largeModelSelected.Think
+
+	}
+	smallModelSelected, smallModelConfigured := cfg.Models[SelectedModelTypeSmall]
+	if smallModelConfigured {
+		if smallModelSelected.Model != "" {
+			small.Model = smallModelSelected.Model
+		}
+		if smallModelSelected.Provider != "" {
+			small.Provider = smallModelSelected.Provider
+		}
+
+		model := cfg.GetModel(small.Provider, small.Model)
+		if model == nil {
+			return fmt.Errorf("large model %s not found for provider %s", large.Model, large.Provider)
+		}
+		if smallModelSelected.MaxTokens > 0 {
+			small.MaxTokens = smallModelSelected.MaxTokens
+		} else {
+			small.MaxTokens = model.DefaultMaxTokens
+		}
+		small.ReasoningEffort = smallModelSelected.ReasoningEffort
+		small.Think = smallModelSelected.Think
+	}
+
+	// validate the selected models
+	largeModel := cfg.GetModel(large.Provider, large.Model)
+	if largeModel == nil {
+		return fmt.Errorf("large model %s not found for provider %s", large.Model, large.Provider)
+	}
+	smallModel := cfg.GetModel(small.Provider, small.Model)
+	if smallModel == nil {
+		return fmt.Errorf("small model %s not found for provider %s", small.Model, small.Provider)
+	}
+	cfg.Models[SelectedModelTypeLarge] = large
+	cfg.Models[SelectedModelTypeSmall] = small
+	return nil
+}
+
 func loadFromConfigPaths(configPaths []string) (*Config, error) {
 	var configs []io.Reader
 
@@ -329,6 +425,34 @@ func loadFromReaders(readers []io.Reader) (*Config, error) {
 	return LoadReader(merged)
 }
 
+func hasVertexCredentials(env env.Env) bool {
+	useVertex := env.Get("GOOGLE_GENAI_USE_VERTEXAI") == "true"
+	hasProject := env.Get("GOOGLE_CLOUD_PROJECT") != ""
+	hasLocation := env.Get("GOOGLE_CLOUD_LOCATION") != ""
+	return useVertex && hasProject && hasLocation
+}
+
+func hasAWSCredentials(env env.Env) bool {
+	if env.Get("AWS_ACCESS_KEY_ID") != "" && env.Get("AWS_SECRET_ACCESS_KEY") != "" {
+		return true
+	}
+
+	if env.Get("AWS_PROFILE") != "" || env.Get("AWS_DEFAULT_PROFILE") != "" {
+		return true
+	}
+
+	if env.Get("AWS_REGION") != "" || env.Get("AWS_DEFAULT_REGION") != "" {
+		return true
+	}
+
+	if env.Get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
+		env.Get("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
+		return true
+	}
+
+	return false
+}
+
 func globalConfig() string {
 	xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
 	if xdgConfigHome != "" {

pkg/config/load_test.go 🔗

@@ -760,3 +760,390 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
 		assert.True(t, exists)
 	})
 }
+
+func TestConfig_defaultModelSelection(t *testing.T) {
+	t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
+		knownProviders := []provider.Provider{
+			{
+				ID:                  "openai",
+				APIKey:              "abc",
+				DefaultLargeModelID: "large-model",
+				DefaultSmallModelID: "small-model",
+				Models: []provider.Model{
+					{
+						ID:               "large-model",
+						DefaultMaxTokens: 1000,
+					},
+					{
+						ID:               "small-model",
+						DefaultMaxTokens: 500,
+					},
+				},
+			},
+		}
+
+		cfg := &Config{}
+		cfg.setDefaults("/tmp")
+		env := env.NewFromMap(map[string]string{})
+		resolver := NewEnvironmentVariableResolver(env)
+		err := cfg.configureProviders(env, resolver, knownProviders)
+		assert.NoError(t, err)
+
+		large, small, err := cfg.defaultModelSelection(knownProviders)
+		assert.NoError(t, err)
+		assert.Equal(t, "large-model", large.Model)
+		assert.Equal(t, "openai", large.Provider)
+		assert.Equal(t, int64(1000), large.MaxTokens)
+		assert.Equal(t, "small-model", small.Model)
+		assert.Equal(t, "openai", small.Provider)
+		assert.Equal(t, int64(500), small.MaxTokens)
+	})
+	t.Run("should error if no providers configured", func(t *testing.T) {
+		knownProviders := []provider.Provider{
+			{
+				ID:                  "openai",
+				APIKey:              "$MISSING_KEY",
+				DefaultLargeModelID: "large-model",
+				DefaultSmallModelID: "small-model",
+				Models: []provider.Model{
+					{
+						ID:               "large-model",
+						DefaultMaxTokens: 1000,
+					},
+					{
+						ID:               "small-model",
+						DefaultMaxTokens: 500,
+					},
+				},
+			},
+		}
+
+		cfg := &Config{}
+		cfg.setDefaults("/tmp")
+		env := env.NewFromMap(map[string]string{})
+		resolver := NewEnvironmentVariableResolver(env)
+		err := cfg.configureProviders(env, resolver, knownProviders)
+		assert.NoError(t, err)
+
+		_, _, err = cfg.defaultModelSelection(knownProviders)
+		assert.Error(t, err)
+	})
+	t.Run("should error if model is missing", func(t *testing.T) {
+		knownProviders := []provider.Provider{
+			{
+				ID:                  "openai",
+				APIKey:              "abc",
+				DefaultLargeModelID: "large-model",
+				DefaultSmallModelID: "small-model",
+				Models: []provider.Model{
+					{
+						ID:               "not-large-model",
+						DefaultMaxTokens: 1000,
+					},
+					{
+						ID:               "small-model",
+						DefaultMaxTokens: 500,
+					},
+				},
+			},
+		}
+
+		cfg := &Config{}
+		cfg.setDefaults("/tmp")
+		env := env.NewFromMap(map[string]string{})
+		resolver := NewEnvironmentVariableResolver(env)
+		err := cfg.configureProviders(env, resolver, knownProviders)
+		assert.NoError(t, err)
+		_, _, err = cfg.defaultModelSelection(knownProviders)
+		assert.Error(t, err)
+	})
+
+	t.Run("should configure the default models with a custom provider", func(t *testing.T) {
+		knownProviders := []provider.Provider{
+			{
+				ID:                  "openai",
+				APIKey:              "$MISSING", // will not be included in the config
+				DefaultLargeModelID: "large-model",
+				DefaultSmallModelID: "small-model",
+				Models: []provider.Model{
+					{
+						ID:               "not-large-model",
+						DefaultMaxTokens: 1000,
+					},
+					{
+						ID:               "small-model",
+						DefaultMaxTokens: 500,
+					},
+				},
+			},
+		}
+
+		cfg := &Config{
+			Providers: map[string]ProviderConfig{
+				"custom": {
+					APIKey:  "test-key",
+					BaseURL: "https://api.custom.com/v1",
+					Models: []provider.Model{
+						{
+							ID:               "model",
+							DefaultMaxTokens: 600,
+						},
+					},
+				},
+			},
+		}
+		cfg.setDefaults("/tmp")
+		env := env.NewFromMap(map[string]string{})
+		resolver := NewEnvironmentVariableResolver(env)
+		err := cfg.configureProviders(env, resolver, knownProviders)
+		assert.NoError(t, err)
+		large, small, err := cfg.defaultModelSelection(knownProviders)
+		assert.NoError(t, err)
+		assert.Equal(t, "model", large.Model)
+		assert.Equal(t, "custom", large.Provider)
+		assert.Equal(t, int64(600), large.MaxTokens)
+		assert.Equal(t, "model", small.Model)
+		assert.Equal(t, "custom", small.Provider)
+		assert.Equal(t, int64(600), small.MaxTokens)
+	})
+
+	t.Run("should fail if no model configured", func(t *testing.T) {
+		knownProviders := []provider.Provider{
+			{
+				ID:                  "openai",
+				APIKey:              "$MISSING", // will not be included in the config
+				DefaultLargeModelID: "large-model",
+				DefaultSmallModelID: "small-model",
+				Models: []provider.Model{
+					{
+						ID:               "not-large-model",
+						DefaultMaxTokens: 1000,
+					},
+					{
+						ID:               "small-model",
+						DefaultMaxTokens: 500,
+					},
+				},
+			},
+		}
+
+		cfg := &Config{
+			Providers: map[string]ProviderConfig{
+				"custom": {
+					APIKey:  "test-key",
+					BaseURL: "https://api.custom.com/v1",
+					Models:  []provider.Model{},
+				},
+			},
+		}
+		cfg.setDefaults("/tmp")
+		env := env.NewFromMap(map[string]string{})
+		resolver := NewEnvironmentVariableResolver(env)
+		err := cfg.configureProviders(env, resolver, knownProviders)
+		assert.NoError(t, err)
+		_, _, err = cfg.defaultModelSelection(knownProviders)
+		assert.Error(t, err)
+	})
+	t.Run("should use the default provider first", func(t *testing.T) {
+		knownProviders := []provider.Provider{
+			{
+				ID:                  "openai",
+				APIKey:              "set",
+				DefaultLargeModelID: "large-model",
+				DefaultSmallModelID: "small-model",
+				Models: []provider.Model{
+					{
+						ID:               "large-model",
+						DefaultMaxTokens: 1000,
+					},
+					{
+						ID:               "small-model",
+						DefaultMaxTokens: 500,
+					},
+				},
+			},
+		}
+
+		cfg := &Config{
+			Providers: map[string]ProviderConfig{
+				"custom": {
+					APIKey:  "test-key",
+					BaseURL: "https://api.custom.com/v1",
+					Models: []provider.Model{
+						{
+							ID:               "large-model",
+							DefaultMaxTokens: 1000,
+						},
+					},
+				},
+			},
+		}
+		cfg.setDefaults("/tmp")
+		env := env.NewFromMap(map[string]string{})
+		resolver := NewEnvironmentVariableResolver(env)
+		err := cfg.configureProviders(env, resolver, knownProviders)
+		assert.NoError(t, err)
+		large, small, err := cfg.defaultModelSelection(knownProviders)
+		assert.NoError(t, err)
+		assert.Equal(t, "large-model", large.Model)
+		assert.Equal(t, "openai", large.Provider)
+		assert.Equal(t, int64(1000), large.MaxTokens)
+		assert.Equal(t, "small-model", small.Model)
+		assert.Equal(t, "openai", small.Provider)
+		assert.Equal(t, int64(500), small.MaxTokens)
+	})
+}
+
+func TestConfig_configureSelectedModels(t *testing.T) {
+	t.Run("should override defaults", func(t *testing.T) {
+		knownProviders := []provider.Provider{
+			{
+				ID:                  "openai",
+				APIKey:              "abc",
+				DefaultLargeModelID: "large-model",
+				DefaultSmallModelID: "small-model",
+				Models: []provider.Model{
+					{
+						ID:               "larger-model",
+						DefaultMaxTokens: 2000,
+					},
+					{
+						ID:               "large-model",
+						DefaultMaxTokens: 1000,
+					},
+					{
+						ID:               "small-model",
+						DefaultMaxTokens: 500,
+					},
+				},
+			},
+		}
+
+		cfg := &Config{
+			Models: map[SelectedModelType]SelectedModel{
+				"large": {
+					Model: "larger-model",
+				},
+			},
+		}
+		cfg.setDefaults("/tmp")
+		env := env.NewFromMap(map[string]string{})
+		resolver := NewEnvironmentVariableResolver(env)
+		err := cfg.configureProviders(env, resolver, knownProviders)
+		assert.NoError(t, err)
+
+		err = cfg.configureSelectedModels(knownProviders)
+		assert.NoError(t, err)
+		large := cfg.Models[SelectedModelTypeLarge]
+		small := cfg.Models[SelectedModelTypeSmall]
+		assert.Equal(t, "larger-model", large.Model)
+		assert.Equal(t, "openai", large.Provider)
+		assert.Equal(t, int64(2000), large.MaxTokens)
+		assert.Equal(t, "small-model", small.Model)
+		assert.Equal(t, "openai", small.Provider)
+		assert.Equal(t, int64(500), small.MaxTokens)
+	})
+	t.Run("should be possible to use multiple providers", func(t *testing.T) {
+		knownProviders := []provider.Provider{
+			{
+				ID:                  "openai",
+				APIKey:              "abc",
+				DefaultLargeModelID: "large-model",
+				DefaultSmallModelID: "small-model",
+				Models: []provider.Model{
+					{
+						ID:               "large-model",
+						DefaultMaxTokens: 1000,
+					},
+					{
+						ID:               "small-model",
+						DefaultMaxTokens: 500,
+					},
+				},
+			},
+			{
+				ID:                  "anthropic",
+				APIKey:              "abc",
+				DefaultLargeModelID: "a-large-model",
+				DefaultSmallModelID: "a-small-model",
+				Models: []provider.Model{
+					{
+						ID:               "a-large-model",
+						DefaultMaxTokens: 1000,
+					},
+					{
+						ID:               "a-small-model",
+						DefaultMaxTokens: 200,
+					},
+				},
+			},
+		}
+
+		cfg := &Config{
+			Models: map[SelectedModelType]SelectedModel{
+				"small": {
+					Model:     "a-small-model",
+					Provider:  "anthropic",
+					MaxTokens: 300,
+				},
+			},
+		}
+		cfg.setDefaults("/tmp")
+		env := env.NewFromMap(map[string]string{})
+		resolver := NewEnvironmentVariableResolver(env)
+		err := cfg.configureProviders(env, resolver, knownProviders)
+		assert.NoError(t, err)
+
+		err = cfg.configureSelectedModels(knownProviders)
+		assert.NoError(t, err)
+		large := cfg.Models[SelectedModelTypeLarge]
+		small := cfg.Models[SelectedModelTypeSmall]
+		assert.Equal(t, "large-model", large.Model)
+		assert.Equal(t, "openai", large.Provider)
+		assert.Equal(t, int64(1000), large.MaxTokens)
+		assert.Equal(t, "a-small-model", small.Model)
+		assert.Equal(t, "anthropic", small.Provider)
+		assert.Equal(t, int64(300), small.MaxTokens)
+	})
+
+	t.Run("should override the max tokens only", func(t *testing.T) {
+		knownProviders := []provider.Provider{
+			{
+				ID:                  "openai",
+				APIKey:              "abc",
+				DefaultLargeModelID: "large-model",
+				DefaultSmallModelID: "small-model",
+				Models: []provider.Model{
+					{
+						ID:               "large-model",
+						DefaultMaxTokens: 1000,
+					},
+					{
+						ID:               "small-model",
+						DefaultMaxTokens: 500,
+					},
+				},
+			},
+		}
+
+		cfg := &Config{
+			Models: map[SelectedModelType]SelectedModel{
+				"large": {
+					MaxTokens: 100,
+				},
+			},
+		}
+		cfg.setDefaults("/tmp")
+		env := env.NewFromMap(map[string]string{})
+		resolver := NewEnvironmentVariableResolver(env)
+		err := cfg.configureProviders(env, resolver, knownProviders)
+		assert.NoError(t, err)
+
+		err = cfg.configureSelectedModels(knownProviders)
+		assert.NoError(t, err)
+		large := cfg.Models[SelectedModelTypeLarge]
+		assert.Equal(t, "large-model", large.Model)
+		assert.Equal(t, "openai", large.Provider)
+		assert.Equal(t, int64(100), large.MaxTokens)
+	})
+}

pkg/env/env.go 🔗

@@ -51,5 +51,8 @@ func (m *mapEnv) Env() []string {
 }
 
 func NewFromMap(m map[string]string) Env {
+	if m == nil {
+		m = make(map[string]string)
+	}
 	return &mapEnv{m: m}
 }