fix: fix config

Created by Kujtim Hoxha

Change summary

internal/config/config.go          | 180 +++++++++++++++++--------------
internal/llm/provider/anthropic.go |   8 
internal/llm/provider/openai.go    |   8 
3 files changed, 108 insertions(+), 88 deletions(-)

Detailed changes

internal/config/config.go

@@ -200,6 +200,28 @@ var (
 
 )
 
+func readConfigFile(path string) (*Config, error) {
+	var cfg *Config
+	if _, err := os.Stat(path); err != nil && !os.IsNotExist(err) {
+		// some other error occurred while checking the file
+		return nil, err
+	} else if err == nil {
+		// config file exists, read it
+		file, err := os.ReadFile(path)
+		if err != nil {
+			return nil, err
+		}
+		cfg = &Config{}
+		if err := json.Unmarshal(file, cfg); err != nil {
+			return nil, err
+		}
+	} else {
+		// config file does not exist, create a new one
+		cfg = &Config{}
+	}
+	return cfg, nil
+}
+
 func loadConfig(cwd string, debug bool) (*Config, error) {
 	// First read the global config file
 	cfgPath := ConfigPath()
@@ -248,47 +270,29 @@ func loadConfig(cwd string, debug bool) (*Config, error) {
 		}))
 		slog.SetDefault(logger)
 	}
-	var globalCfg *Config
-	if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
-		// some other error occurred while checking the file
-		return nil, err
-	} else if err == nil {
-		// config file exists, read it
-		file, err := os.ReadFile(cfgPath)
-		if err != nil {
-			return nil, err
-		}
-		globalCfg = &Config{}
-		if err := json.Unmarshal(file, globalCfg); err != nil {
-			return nil, err
-		}
-	} else {
-		// config file does not exist, create a new one
-		globalCfg = &Config{}
+
+	priorityOrderedConfigFiles := []string{
+		cfgPath,                           // Global config file
+		filepath.Join(cwd, "crush.json"),  // Local config file
+		filepath.Join(cwd, ".crush.json"), // Local config file
 	}
 
-	var localConfig *Config
-	// Global config loaded, now read the local config file
-	localConfigPath := filepath.Join(cwd, "crush.json")
-	if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
-		// some other error occurred while checking the file
-		return nil, err
-	} else if err == nil {
-		// local config file exists, read it
-		file, err := os.ReadFile(localConfigPath)
+	configs := make([]*Config, 0)
+	for _, path := range priorityOrderedConfigFiles {
+		localConfig, err := readConfigFile(path)
 		if err != nil {
-			return nil, err
+			return nil, fmt.Errorf("failed to read config file %s: %w", path, err)
 		}
-		localConfig = &Config{}
-		if err := json.Unmarshal(file, localConfig); err != nil {
-			return nil, err
+		if localConfig != nil {
+			// If the config file was read successfully, add it to the list
+			configs = append(configs, localConfig)
 		}
 	}
 
 	// merge options
-	mergeOptions(cfg, globalCfg, localConfig)
+	mergeOptions(cfg, configs...)
 
-	mergeProviderConfigs(cfg, globalCfg, localConfig)
+	mergeProviderConfigs(cfg, configs...)
 	// no providers found the app is not initialized yet
 	if len(cfg.Providers) == 0 {
 		return cfg, nil
@@ -310,7 +314,7 @@ func loadConfig(cwd string, debug bool) (*Config, error) {
 		cfg.Models = PreferredModels{}
 	}
 
-	mergeModels(cfg, globalCfg, localConfig)
+	mergeModels(cfg, configs...)
 
 	agents := map[AgentID]Agent{
 		AgentCoder: {
@@ -340,9 +344,9 @@ func loadConfig(cwd string, debug bool) (*Config, error) {
 		},
 	}
 	cfg.Agents = agents
-	mergeAgents(cfg, globalCfg, localConfig)
-	mergeMCPs(cfg, globalCfg, localConfig)
-	mergeLSPs(cfg, globalCfg, localConfig)
+	mergeAgents(cfg, configs...)
+	mergeMCPs(cfg, configs...)
+	mergeLSPs(cfg, configs...)
 
 	// Validate the final configuration
 	if err := cfg.Validate(); err != nil {
@@ -457,8 +461,8 @@ func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfi
 	return nil
 }
 
-func mergeModels(base, global, local *Config) {
-	for _, cfg := range []*Config{global, local} {
+func mergeModels(base *Config, others ...*Config) {
+	for _, cfg := range others {
 		if cfg == nil {
 			continue
 		}
@@ -472,8 +476,8 @@ func mergeModels(base, global, local *Config) {
 	}
 }
 
-func mergeOptions(base, global, local *Config) {
-	for _, cfg := range []*Config{global, local} {
+func mergeOptions(base *Config, others ...*Config) {
+	for _, cfg := range others {
 		if cfg == nil {
 			continue
 		}
@@ -506,8 +510,8 @@ func mergeOptions(base, global, local *Config) {
 	}
 }
 
-func mergeAgents(base, global, local *Config) {
-	for _, cfg := range []*Config{global, local} {
+func mergeAgents(base *Config, others ...*Config) {
+	for _, cfg := range others {
 		if cfg == nil {
 			continue
 		}
@@ -575,8 +579,8 @@ func mergeAgents(base, global, local *Config) {
 	}
 }
 
-func mergeMCPs(base, global, local *Config) {
-	for _, cfg := range []*Config{global, local} {
+func mergeMCPs(base *Config, others ...*Config) {
+	for _, cfg := range others {
 		if cfg == nil {
 			continue
 		}
@@ -584,8 +588,8 @@ func mergeMCPs(base, global, local *Config) {
 	}
 }
 
-func mergeLSPs(base, global, local *Config) {
-	for _, cfg := range []*Config{global, local} {
+func mergeLSPs(base *Config, others ...*Config) {
+	for _, cfg := range others {
 		if cfg == nil {
 			continue
 		}
@@ -593,15 +597,27 @@ func mergeLSPs(base, global, local *Config) {
 	}
 }
 
-func mergeProviderConfigs(base, global, local *Config) {
-	for _, cfg := range []*Config{global, local} {
+func mergeProviderConfigs(base *Config, others ...*Config) {
+	for _, cfg := range others {
 		if cfg == nil {
 			continue
 		}
 		for providerName, p := range cfg.Providers {
 			p.ID = providerName
 			if _, ok := base.Providers[providerName]; !ok {
-				base.Providers[providerName] = p
+				if slices.Contains(provider.KnownProviders(), providerName) {
+					providers := Providers()
+					for _, providerDef := range providers {
+						if providerDef.ID == providerName {
+							logging.Info("Using default provider config for", "provider", providerName)
+							baseProvider := getDefaultProviderConfig(providerDef, providerDef.APIKey)
+							base.Providers[providerName] = mergeProviderConfig(providerName, baseProvider, p)
+							break
+						}
+					}
+				} else {
+					base.Providers[providerName] = p
+				}
 			} else {
 				base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], p)
 			}
@@ -676,6 +692,40 @@ func providerDefaultConfig(providerID provider.InferenceProvider) ProviderConfig
 	}
 }
 
+func getDefaultProviderConfig(p provider.Provider, apiKey string) ProviderConfig {
+	providerConfig := providerDefaultConfig(p.ID)
+	providerConfig.APIKey = apiKey
+	providerConfig.DefaultLargeModel = p.DefaultLargeModelID
+	providerConfig.DefaultSmallModel = p.DefaultSmallModelID
+	baseURL := p.APIEndpoint
+	if strings.HasPrefix(baseURL, "$") {
+		envVar := strings.TrimPrefix(baseURL, "$")
+		baseURL = os.Getenv(envVar)
+	}
+	providerConfig.BaseURL = baseURL
+	for _, model := range p.Models {
+		configModel := Model{
+			ID:                 model.ID,
+			Name:               model.Name,
+			CostPer1MIn:        model.CostPer1MIn,
+			CostPer1MOut:       model.CostPer1MOut,
+			CostPer1MInCached:  model.CostPer1MInCached,
+			CostPer1MOutCached: model.CostPer1MOutCached,
+			ContextWindow:      model.ContextWindow,
+			DefaultMaxTokens:   model.DefaultMaxTokens,
+			CanReason:          model.CanReason,
+			SupportsImages:     model.SupportsImages,
+		}
+		// Set reasoning effort for reasoning models
+		if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
+			configModel.HasReasoningEffort = model.HasReasoningEffort
+			configModel.ReasoningEffort = model.DefaultReasoningEffort
+		}
+		providerConfig.Models = append(providerConfig.Models, configModel)
+	}
+	return providerConfig
+}
+
 func defaultConfigBasedOnEnv() *Config {
 	cfg := &Config{
 		Options: Options{
@@ -694,37 +744,7 @@ func defaultConfigBasedOnEnv() *Config {
 		if strings.HasPrefix(p.APIKey, "$") {
 			envVar := strings.TrimPrefix(p.APIKey, "$")
 			if apiKey := os.Getenv(envVar); apiKey != "" {
-				providerConfig := providerDefaultConfig(p.ID)
-				providerConfig.APIKey = apiKey
-				providerConfig.DefaultLargeModel = p.DefaultLargeModelID
-				providerConfig.DefaultSmallModel = p.DefaultSmallModelID
-				baseURL := p.APIEndpoint
-				if strings.HasPrefix(baseURL, "$") {
-					envVar := strings.TrimPrefix(baseURL, "$")
-					baseURL = os.Getenv(envVar)
-				}
-				providerConfig.BaseURL = baseURL
-				for _, model := range p.Models {
-					configModel := Model{
-						ID:                 model.ID,
-						Name:               model.Name,
-						CostPer1MIn:        model.CostPer1MIn,
-						CostPer1MOut:       model.CostPer1MOut,
-						CostPer1MInCached:  model.CostPer1MInCached,
-						CostPer1MOutCached: model.CostPer1MOutCached,
-						ContextWindow:      model.ContextWindow,
-						DefaultMaxTokens:   model.DefaultMaxTokens,
-						CanReason:          model.CanReason,
-						SupportsImages:     model.SupportsImages,
-					}
-					// Set reasoning effort for reasoning models
-					if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
-						configModel.HasReasoningEffort = model.HasReasoningEffort
-						configModel.ReasoningEffort = model.DefaultReasoningEffort
-					}
-					providerConfig.Models = append(providerConfig.Models, configModel)
-				}
-				cfg.Providers[p.ID] = providerConfig
+				cfg.Providers[p.ID] = getDefaultProviderConfig(p, apiKey)
 			}
 		}
 	}

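The config.go change replaces the two hand-rolled global/local loaders with a single readConfigFile helper plus a priority-ordered list of paths (the global config file first, then crush.json and .crush.json in the working directory), and the merge helpers now accept a variadic list of configs. Below is a small standalone sketch of that loading order; the Config struct, the paths, and the single-field merge rule are simplified stand-ins for illustration, not the project's actual types.

// Standalone sketch: read each config file in priority order (global
// first, then local overrides), skip files that do not exist, and fail
// only on real I/O or parse errors. Config is a stand-in struct.
package main

import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
)

type Config struct {
	Debug bool `json:"debug"`
}

func readConfigFile(path string) (*Config, error) {
	if _, err := os.Stat(path); err != nil {
		if os.IsNotExist(err) {
			// A missing file is not an error; treat it as an empty config.
			return &Config{}, nil
		}
		return nil, err
	}
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	cfg := &Config{}
	if err := json.Unmarshal(data, cfg); err != nil {
		return nil, err
	}
	return cfg, nil
}

func main() {
	cwd, _ := os.Getwd()
	paths := []string{
		"/etc/example/crush.json",         // stand-in for the global config path
		filepath.Join(cwd, "crush.json"),  // local config
		filepath.Join(cwd, ".crush.json"), // hidden local config
	}
	merged := &Config{}
	for _, p := range paths {
		cfg, err := readConfigFile(p)
		if err != nil {
			fmt.Fprintf(os.Stderr, "failed to read config file %s: %v\n", p, err)
			os.Exit(1)
		}
		// Simplified merge: later files can enable the flag (stand-in
		// for the project's per-field merge helpers).
		merged.Debug = merged.Debug || cfg.Debug
	}
	fmt.Printf("merged config: %+v\n", merged)
}
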
internal/llm/provider/anthropic.go

@@ -382,6 +382,10 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err
 		return false, 0, err
 	}
 
+	if attempts > maxRetries {
+		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
+	}
+
 	if apiErr.StatusCode == 401 {
 		a.providerOptions.apiKey, err = config.ResolveAPIKey(a.providerOptions.config.APIKey)
 		if err != nil {
@@ -395,10 +399,6 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err
 		return false, 0, err
 	}
 
-	if attempts > maxRetries {
-		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
-	}
-
 	retryMs := 0
 	retryAfterValues := apiErr.Response.Header.Values("Retry-After")
 

internal/llm/provider/openai.go

@@ -347,6 +347,10 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error)
 		return false, 0, err
 	}
 
+	if attempts > maxRetries {
+		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
+	}
+
 	// Check for token expiration (401 Unauthorized)
 	if apiErr.StatusCode == 401 {
 		o.providerOptions.apiKey, err = config.ResolveAPIKey(o.providerOptions.config.APIKey)
@@ -361,10 +365,6 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error)
 		return false, 0, err
 	}
 
-	if attempts > maxRetries {
-		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
-	}
-
 	retryMs := 0
 	retryAfterValues := apiErr.Response.Header.Values("Retry-After")
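
Both provider changes are the same reordering: the maxRetries cap is now checked before the 401 re-auth branch, so re-resolving an expired API key can no longer push a client past its retry budget. A minimal standalone sketch of that guard order follows; the status codes, the maxRetries constant, and the refreshKey callback are illustrative stand-ins, not the providers' actual fields.

package main

import (
	"errors"
	"fmt"
)

const maxRetries = 8 // illustrative value, not the providers' actual constant

// shouldRetry sketches the new ordering: the retry budget is enforced
// before the 401 re-auth branch, so refreshing a key cannot extend the
// number of attempts past maxRetries.
func shouldRetry(attempts int, status int, refreshKey func() (string, error)) (bool, error) {
	if attempts > maxRetries {
		return false, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}
	if status == 401 {
		// Key may have expired: re-resolve it and retry within the same budget.
		if _, err := refreshKey(); err != nil {
			return false, err
		}
		return true, nil
	}
	if status != 429 && status < 500 {
		// Not a rate limit or server error: give up immediately.
		return false, errors.New("non-retryable error")
	}
	return true, nil
}

func main() {
	retry, err := shouldRetry(9, 401, func() (string, error) { return "new-key", nil })
	fmt.Println(retry, err) // false: budget exceeded even though the key could be refreshed
}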