From 6a7d5c50e23faf7f27672262d08ccf515a4da00e Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sat, 5 Jul 2025 11:11:53 +0200 Subject: [PATCH] chore: add default models handling and tests --- pkg/config/config.go | 11 ++ pkg/config/load.go | 198 ++++++++++++++++---- pkg/config/load_test.go | 387 ++++++++++++++++++++++++++++++++++++++++ pkg/env/env.go | 3 + 4 files changed, 562 insertions(+), 37 deletions(-) diff --git a/pkg/config/config.go b/pkg/config/config.go index de1d2abe08c91e23e5ac6c10c8b45e0a72efa2ba..e1554e23d3755637178a033c25aca2fee75525b5 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -195,3 +195,14 @@ func (c *Config) EnabledProviders() []ProviderConfig { func (c *Config) IsConfigured() bool { return len(c.EnabledProviders()) > 0 } + +func (c *Config) GetModel(provider, model string) *provider.Model { + if providerConfig, ok := c.Providers[provider]; ok { + for _, m := range providerConfig.Models { + if m.ID == model { + return &m + } + } + } + return nil +} diff --git a/pkg/config/load.go b/pkg/config/load.go index 8294edef917b8e9ab9f190676181eaa90f62b8b9..ba563f4bdb87f5e0d1847c6e0152cef310a7ac08 100644 --- a/pkg/config/load.go +++ b/pkg/config/load.go @@ -77,15 +77,10 @@ func Load(workingDir string, debug bool) (*Config, error) { return cfg, nil } - // largeModel, ok := cfg.Models[SelectedModelTypeLarge] - // if !ok { - // // set default - // } - // smallModel, ok := cfg.Models[SelectedModelTypeSmall] - // if !ok { - // // set default - // } - // + if err := cfg.configureSelectedModels(providers); err != nil { + return nil, fmt.Errorf("failed to configure selected models: %w", err) + } + return cfg, nil } @@ -236,34 +231,6 @@ func (cfg *Config) configureProviders(env env.Env, resolver VariableResolver, kn return nil } -func hasVertexCredentials(env env.Env) bool { - useVertex := env.Get("GOOGLE_GENAI_USE_VERTEXAI") == "true" - hasProject := env.Get("GOOGLE_CLOUD_PROJECT") != "" - hasLocation := env.Get("GOOGLE_CLOUD_LOCATION") != "" - return useVertex && hasProject && hasLocation -} - -func hasAWSCredentials(env env.Env) bool { - if env.Get("AWS_ACCESS_KEY_ID") != "" && env.Get("AWS_SECRET_ACCESS_KEY") != "" { - return true - } - - if env.Get("AWS_PROFILE") != "" || env.Get("AWS_DEFAULT_PROFILE") != "" { - return true - } - - if env.Get("AWS_REGION") != "" || env.Get("AWS_DEFAULT_REGION") != "" { - return true - } - - if env.Get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" || - env.Get("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" { - return true - } - - return false -} - func (cfg *Config) setDefaults(workingDir string) { cfg.workingDir = workingDir if cfg.Options == nil { @@ -297,6 +264,135 @@ func (cfg *Config) setDefaults(workingDir string) { cfg.Options.ContextPaths = slices.Compact(cfg.Options.ContextPaths) } +func (cfg *Config) defaultModelSelection(knownProviders []provider.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) { + if len(knownProviders) == 0 && len(cfg.Providers) == 0 { + err = fmt.Errorf("no providers configured, please configure at least one provider") + return + } + + // Use the first provider enabled based on the known providers order + // if no provider found that is known use the first provider configured + for _, p := range knownProviders { + providerConfig, ok := cfg.Providers[string(p.ID)] + if !ok || providerConfig.Disable { + continue + } + defaultLargeModel := cfg.GetModel(string(p.ID), p.DefaultLargeModelID) + if defaultLargeModel == nil { + err = fmt.Errorf("default large model %s not found for provider %s", p.DefaultLargeModelID, p.ID) + return + } + largeModel = SelectedModel{ + Provider: string(p.ID), + Model: defaultLargeModel.ID, + MaxTokens: defaultLargeModel.DefaultMaxTokens, + } + + defaultSmallModel := cfg.GetModel(string(p.ID), p.DefaultSmallModelID) + if defaultSmallModel == nil { + err = fmt.Errorf("default small model %s not found for provider %s", p.DefaultSmallModelID, p.ID) + } + smallModel = SelectedModel{ + Provider: string(p.ID), + Model: defaultSmallModel.ID, + MaxTokens: defaultSmallModel.DefaultMaxTokens, + } + return + } + + enabledProviders := cfg.EnabledProviders() + slices.SortFunc(enabledProviders, func(a, b ProviderConfig) int { + return strings.Compare(a.ID, b.ID) + }) + + if len(enabledProviders) == 0 { + err = fmt.Errorf("no providers configured, please configure at least one provider") + return + } + + providerConfig := enabledProviders[0] + if len(providerConfig.Models) == 0 { + err = fmt.Errorf("provider %s has no models configured", providerConfig.ID) + return + } + defaultLargeModel := cfg.GetModel(providerConfig.ID, providerConfig.Models[0].ID) + largeModel = SelectedModel{ + Provider: providerConfig.ID, + Model: defaultLargeModel.ID, + MaxTokens: defaultLargeModel.DefaultMaxTokens, + } + defaultSmallModel := cfg.GetModel(providerConfig.ID, providerConfig.Models[0].ID) + smallModel = SelectedModel{ + Provider: providerConfig.ID, + Model: defaultSmallModel.ID, + MaxTokens: defaultSmallModel.DefaultMaxTokens, + } + return +} + +func (cfg *Config) configureSelectedModels(knownProviders []provider.Provider) error { + large, small, err := cfg.defaultModelSelection(knownProviders) + if err != nil { + return fmt.Errorf("failed to select default models: %w", err) + } + + largeModelSelected, largeModelConfigured := cfg.Models[SelectedModelTypeLarge] + if largeModelConfigured { + if largeModelSelected.Model != "" { + large.Model = largeModelSelected.Model + } + if largeModelSelected.Provider != "" { + large.Provider = largeModelSelected.Provider + } + model := cfg.GetModel(large.Provider, large.Model) + if model == nil { + return fmt.Errorf("large model %s not found for provider %s", large.Model, large.Provider) + } + if largeModelSelected.MaxTokens > 0 { + large.MaxTokens = largeModelSelected.MaxTokens + } else { + large.MaxTokens = model.DefaultMaxTokens + } + large.ReasoningEffort = largeModelSelected.ReasoningEffort + large.Think = largeModelSelected.Think + + } + smallModelSelected, smallModelConfigured := cfg.Models[SelectedModelTypeSmall] + if smallModelConfigured { + if smallModelSelected.Model != "" { + small.Model = smallModelSelected.Model + } + if smallModelSelected.Provider != "" { + small.Provider = smallModelSelected.Provider + } + + model := cfg.GetModel(small.Provider, small.Model) + if model == nil { + return fmt.Errorf("large model %s not found for provider %s", large.Model, large.Provider) + } + if smallModelSelected.MaxTokens > 0 { + small.MaxTokens = smallModelSelected.MaxTokens + } else { + small.MaxTokens = model.DefaultMaxTokens + } + small.ReasoningEffort = smallModelSelected.ReasoningEffort + small.Think = smallModelSelected.Think + } + + // validate the selected models + largeModel := cfg.GetModel(large.Provider, large.Model) + if largeModel == nil { + return fmt.Errorf("large model %s not found for provider %s", large.Model, large.Provider) + } + smallModel := cfg.GetModel(small.Provider, small.Model) + if smallModel == nil { + return fmt.Errorf("small model %s not found for provider %s", small.Model, small.Provider) + } + cfg.Models[SelectedModelTypeLarge] = large + cfg.Models[SelectedModelTypeSmall] = small + return nil +} + func loadFromConfigPaths(configPaths []string) (*Config, error) { var configs []io.Reader @@ -329,6 +425,34 @@ func loadFromReaders(readers []io.Reader) (*Config, error) { return LoadReader(merged) } +func hasVertexCredentials(env env.Env) bool { + useVertex := env.Get("GOOGLE_GENAI_USE_VERTEXAI") == "true" + hasProject := env.Get("GOOGLE_CLOUD_PROJECT") != "" + hasLocation := env.Get("GOOGLE_CLOUD_LOCATION") != "" + return useVertex && hasProject && hasLocation +} + +func hasAWSCredentials(env env.Env) bool { + if env.Get("AWS_ACCESS_KEY_ID") != "" && env.Get("AWS_SECRET_ACCESS_KEY") != "" { + return true + } + + if env.Get("AWS_PROFILE") != "" || env.Get("AWS_DEFAULT_PROFILE") != "" { + return true + } + + if env.Get("AWS_REGION") != "" || env.Get("AWS_DEFAULT_REGION") != "" { + return true + } + + if env.Get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" || + env.Get("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" { + return true + } + + return false +} + func globalConfig() string { xdgConfigHome := os.Getenv("XDG_CONFIG_HOME") if xdgConfigHome != "" { diff --git a/pkg/config/load_test.go b/pkg/config/load_test.go index 911b203de887b255fdc80c25d4433fd5de52a419..01cf088c5b639683c30dc7d61505cde6b28ff593 100644 --- a/pkg/config/load_test.go +++ b/pkg/config/load_test.go @@ -760,3 +760,390 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) { assert.True(t, exists) }) } + +func TestConfig_defaultModelSelection(t *testing.T) { + t.Run("default behavior uses the default models for given provider", func(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: "openai", + APIKey: "abc", + DefaultLargeModelID: "large-model", + DefaultSmallModelID: "small-model", + Models: []provider.Model{ + { + ID: "large-model", + DefaultMaxTokens: 1000, + }, + { + ID: "small-model", + DefaultMaxTokens: 500, + }, + }, + }, + } + + cfg := &Config{} + cfg.setDefaults("/tmp") + env := env.NewFromMap(map[string]string{}) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, knownProviders) + assert.NoError(t, err) + + large, small, err := cfg.defaultModelSelection(knownProviders) + assert.NoError(t, err) + assert.Equal(t, "large-model", large.Model) + assert.Equal(t, "openai", large.Provider) + assert.Equal(t, int64(1000), large.MaxTokens) + assert.Equal(t, "small-model", small.Model) + assert.Equal(t, "openai", small.Provider) + assert.Equal(t, int64(500), small.MaxTokens) + }) + t.Run("should error if no providers configured", func(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: "openai", + APIKey: "$MISSING_KEY", + DefaultLargeModelID: "large-model", + DefaultSmallModelID: "small-model", + Models: []provider.Model{ + { + ID: "large-model", + DefaultMaxTokens: 1000, + }, + { + ID: "small-model", + DefaultMaxTokens: 500, + }, + }, + }, + } + + cfg := &Config{} + cfg.setDefaults("/tmp") + env := env.NewFromMap(map[string]string{}) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, knownProviders) + assert.NoError(t, err) + + _, _, err = cfg.defaultModelSelection(knownProviders) + assert.Error(t, err) + }) + t.Run("should error if model is missing", func(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: "openai", + APIKey: "abc", + DefaultLargeModelID: "large-model", + DefaultSmallModelID: "small-model", + Models: []provider.Model{ + { + ID: "not-large-model", + DefaultMaxTokens: 1000, + }, + { + ID: "small-model", + DefaultMaxTokens: 500, + }, + }, + }, + } + + cfg := &Config{} + cfg.setDefaults("/tmp") + env := env.NewFromMap(map[string]string{}) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, knownProviders) + assert.NoError(t, err) + _, _, err = cfg.defaultModelSelection(knownProviders) + assert.Error(t, err) + }) + + t.Run("should configure the default models with a custom provider", func(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: "openai", + APIKey: "$MISSING", // will not be included in the config + DefaultLargeModelID: "large-model", + DefaultSmallModelID: "small-model", + Models: []provider.Model{ + { + ID: "not-large-model", + DefaultMaxTokens: 1000, + }, + { + ID: "small-model", + DefaultMaxTokens: 500, + }, + }, + }, + } + + cfg := &Config{ + Providers: map[string]ProviderConfig{ + "custom": { + APIKey: "test-key", + BaseURL: "https://api.custom.com/v1", + Models: []provider.Model{ + { + ID: "model", + DefaultMaxTokens: 600, + }, + }, + }, + }, + } + cfg.setDefaults("/tmp") + env := env.NewFromMap(map[string]string{}) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, knownProviders) + assert.NoError(t, err) + large, small, err := cfg.defaultModelSelection(knownProviders) + assert.NoError(t, err) + assert.Equal(t, "model", large.Model) + assert.Equal(t, "custom", large.Provider) + assert.Equal(t, int64(600), large.MaxTokens) + assert.Equal(t, "model", small.Model) + assert.Equal(t, "custom", small.Provider) + assert.Equal(t, int64(600), small.MaxTokens) + }) + + t.Run("should fail if no model configured", func(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: "openai", + APIKey: "$MISSING", // will not be included in the config + DefaultLargeModelID: "large-model", + DefaultSmallModelID: "small-model", + Models: []provider.Model{ + { + ID: "not-large-model", + DefaultMaxTokens: 1000, + }, + { + ID: "small-model", + DefaultMaxTokens: 500, + }, + }, + }, + } + + cfg := &Config{ + Providers: map[string]ProviderConfig{ + "custom": { + APIKey: "test-key", + BaseURL: "https://api.custom.com/v1", + Models: []provider.Model{}, + }, + }, + } + cfg.setDefaults("/tmp") + env := env.NewFromMap(map[string]string{}) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, knownProviders) + assert.NoError(t, err) + _, _, err = cfg.defaultModelSelection(knownProviders) + assert.Error(t, err) + }) + t.Run("should use the default provider first", func(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: "openai", + APIKey: "set", + DefaultLargeModelID: "large-model", + DefaultSmallModelID: "small-model", + Models: []provider.Model{ + { + ID: "large-model", + DefaultMaxTokens: 1000, + }, + { + ID: "small-model", + DefaultMaxTokens: 500, + }, + }, + }, + } + + cfg := &Config{ + Providers: map[string]ProviderConfig{ + "custom": { + APIKey: "test-key", + BaseURL: "https://api.custom.com/v1", + Models: []provider.Model{ + { + ID: "large-model", + DefaultMaxTokens: 1000, + }, + }, + }, + }, + } + cfg.setDefaults("/tmp") + env := env.NewFromMap(map[string]string{}) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, knownProviders) + assert.NoError(t, err) + large, small, err := cfg.defaultModelSelection(knownProviders) + assert.NoError(t, err) + assert.Equal(t, "large-model", large.Model) + assert.Equal(t, "openai", large.Provider) + assert.Equal(t, int64(1000), large.MaxTokens) + assert.Equal(t, "small-model", small.Model) + assert.Equal(t, "openai", small.Provider) + assert.Equal(t, int64(500), small.MaxTokens) + }) +} + +func TestConfig_configureSelectedModels(t *testing.T) { + t.Run("should override defaults", func(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: "openai", + APIKey: "abc", + DefaultLargeModelID: "large-model", + DefaultSmallModelID: "small-model", + Models: []provider.Model{ + { + ID: "larger-model", + DefaultMaxTokens: 2000, + }, + { + ID: "large-model", + DefaultMaxTokens: 1000, + }, + { + ID: "small-model", + DefaultMaxTokens: 500, + }, + }, + }, + } + + cfg := &Config{ + Models: map[SelectedModelType]SelectedModel{ + "large": { + Model: "larger-model", + }, + }, + } + cfg.setDefaults("/tmp") + env := env.NewFromMap(map[string]string{}) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, knownProviders) + assert.NoError(t, err) + + err = cfg.configureSelectedModels(knownProviders) + assert.NoError(t, err) + large := cfg.Models[SelectedModelTypeLarge] + small := cfg.Models[SelectedModelTypeSmall] + assert.Equal(t, "larger-model", large.Model) + assert.Equal(t, "openai", large.Provider) + assert.Equal(t, int64(2000), large.MaxTokens) + assert.Equal(t, "small-model", small.Model) + assert.Equal(t, "openai", small.Provider) + assert.Equal(t, int64(500), small.MaxTokens) + }) + t.Run("should be possible to use multiple providers", func(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: "openai", + APIKey: "abc", + DefaultLargeModelID: "large-model", + DefaultSmallModelID: "small-model", + Models: []provider.Model{ + { + ID: "large-model", + DefaultMaxTokens: 1000, + }, + { + ID: "small-model", + DefaultMaxTokens: 500, + }, + }, + }, + { + ID: "anthropic", + APIKey: "abc", + DefaultLargeModelID: "a-large-model", + DefaultSmallModelID: "a-small-model", + Models: []provider.Model{ + { + ID: "a-large-model", + DefaultMaxTokens: 1000, + }, + { + ID: "a-small-model", + DefaultMaxTokens: 200, + }, + }, + }, + } + + cfg := &Config{ + Models: map[SelectedModelType]SelectedModel{ + "small": { + Model: "a-small-model", + Provider: "anthropic", + MaxTokens: 300, + }, + }, + } + cfg.setDefaults("/tmp") + env := env.NewFromMap(map[string]string{}) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, knownProviders) + assert.NoError(t, err) + + err = cfg.configureSelectedModels(knownProviders) + assert.NoError(t, err) + large := cfg.Models[SelectedModelTypeLarge] + small := cfg.Models[SelectedModelTypeSmall] + assert.Equal(t, "large-model", large.Model) + assert.Equal(t, "openai", large.Provider) + assert.Equal(t, int64(1000), large.MaxTokens) + assert.Equal(t, "a-small-model", small.Model) + assert.Equal(t, "anthropic", small.Provider) + assert.Equal(t, int64(300), small.MaxTokens) + }) + + t.Run("should override the max tokens only", func(t *testing.T) { + knownProviders := []provider.Provider{ + { + ID: "openai", + APIKey: "abc", + DefaultLargeModelID: "large-model", + DefaultSmallModelID: "small-model", + Models: []provider.Model{ + { + ID: "large-model", + DefaultMaxTokens: 1000, + }, + { + ID: "small-model", + DefaultMaxTokens: 500, + }, + }, + }, + } + + cfg := &Config{ + Models: map[SelectedModelType]SelectedModel{ + "large": { + MaxTokens: 100, + }, + }, + } + cfg.setDefaults("/tmp") + env := env.NewFromMap(map[string]string{}) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, knownProviders) + assert.NoError(t, err) + + err = cfg.configureSelectedModels(knownProviders) + assert.NoError(t, err) + large := cfg.Models[SelectedModelTypeLarge] + assert.Equal(t, "large-model", large.Model) + assert.Equal(t, "openai", large.Provider) + assert.Equal(t, int64(100), large.MaxTokens) + }) +} diff --git a/pkg/env/env.go b/pkg/env/env.go index f223bea50e465c28d924072b35dd042be12b0054..24d44d10fca5a374732283d0aca4ddc8166b879b 100644 --- a/pkg/env/env.go +++ b/pkg/env/env.go @@ -51,5 +51,8 @@ func (m *mapEnv) Env() []string { } func NewFromMap(m map[string]string) Env { + if m == nil { + m = make(map[string]string) + } return &mapEnv{m: m} }