diff --git a/internal/config/config.go b/internal/config/config.go index 4a8c47dd8686ca562b60c97efa2c15c31daf88ad..589cd5c0ca30811d2fa47ae527e2880d82ccedcd 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -200,6 +200,28 @@ var ( ) +func readConfigFile(path string) (*Config, error) { + var cfg *Config + if _, err := os.Stat(path); err != nil && !os.IsNotExist(err) { + // some other error occurred while checking the file + return nil, err + } else if err == nil { + // config file exists, read it + file, err := os.ReadFile(path) + if err != nil { + return nil, err + } + cfg = &Config{} + if err := json.Unmarshal(file, cfg); err != nil { + return nil, err + } + } else { + // config file does not exist, create a new one + cfg = &Config{} + } + return cfg, nil +} + func loadConfig(cwd string, debug bool) (*Config, error) { // First read the global config file cfgPath := ConfigPath() @@ -248,47 +270,29 @@ func loadConfig(cwd string, debug bool) (*Config, error) { })) slog.SetDefault(logger) } - var globalCfg *Config - if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) { - // some other error occurred while checking the file - return nil, err - } else if err == nil { - // config file exists, read it - file, err := os.ReadFile(cfgPath) - if err != nil { - return nil, err - } - globalCfg = &Config{} - if err := json.Unmarshal(file, globalCfg); err != nil { - return nil, err - } - } else { - // config file does not exist, create a new one - globalCfg = &Config{} + + priorityOrderedConfigFiles := []string{ + cfgPath, // Global config file + filepath.Join(cwd, "crush.json"), // Local config file + filepath.Join(cwd, ".crush.json"), // Local config file } - var localConfig *Config - // Global config loaded, now read the local config file - localConfigPath := filepath.Join(cwd, "crush.json") - if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) { - // some other error occurred while checking the file - return nil, err - } else if err == nil { - // local config file exists, read it - file, err := os.ReadFile(localConfigPath) + configs := make([]*Config, 0) + for _, path := range priorityOrderedConfigFiles { + localConfig, err := readConfigFile(path) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to read config file %s: %w", path, err) } - localConfig = &Config{} - if err := json.Unmarshal(file, localConfig); err != nil { - return nil, err + if localConfig != nil { + // If the config file was read successfully, add it to the list + configs = append(configs, localConfig) } } // merge options - mergeOptions(cfg, globalCfg, localConfig) + mergeOptions(cfg, configs...) - mergeProviderConfigs(cfg, globalCfg, localConfig) + mergeProviderConfigs(cfg, configs...) // no providers found the app is not initialized yet if len(cfg.Providers) == 0 { return cfg, nil @@ -310,7 +314,7 @@ func loadConfig(cwd string, debug bool) (*Config, error) { cfg.Models = PreferredModels{} } - mergeModels(cfg, globalCfg, localConfig) + mergeModels(cfg, configs...) agents := map[AgentID]Agent{ AgentCoder: { @@ -340,9 +344,9 @@ func loadConfig(cwd string, debug bool) (*Config, error) { }, } cfg.Agents = agents - mergeAgents(cfg, globalCfg, localConfig) - mergeMCPs(cfg, globalCfg, localConfig) - mergeLSPs(cfg, globalCfg, localConfig) + mergeAgents(cfg, configs...) + mergeMCPs(cfg, configs...) + mergeLSPs(cfg, configs...) // Validate the final configuration if err := cfg.Validate(); err != nil { @@ -457,8 +461,8 @@ func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfi return nil } -func mergeModels(base, global, local *Config) { - for _, cfg := range []*Config{global, local} { +func mergeModels(base *Config, others ...*Config) { + for _, cfg := range others { if cfg == nil { continue } @@ -472,8 +476,8 @@ func mergeModels(base, global, local *Config) { } } -func mergeOptions(base, global, local *Config) { - for _, cfg := range []*Config{global, local} { +func mergeOptions(base *Config, others ...*Config) { + for _, cfg := range others { if cfg == nil { continue } @@ -506,8 +510,8 @@ func mergeOptions(base, global, local *Config) { } } -func mergeAgents(base, global, local *Config) { - for _, cfg := range []*Config{global, local} { +func mergeAgents(base *Config, others ...*Config) { + for _, cfg := range others { if cfg == nil { continue } @@ -575,8 +579,8 @@ func mergeAgents(base, global, local *Config) { } } -func mergeMCPs(base, global, local *Config) { - for _, cfg := range []*Config{global, local} { +func mergeMCPs(base *Config, others ...*Config) { + for _, cfg := range others { if cfg == nil { continue } @@ -584,8 +588,8 @@ func mergeMCPs(base, global, local *Config) { } } -func mergeLSPs(base, global, local *Config) { - for _, cfg := range []*Config{global, local} { +func mergeLSPs(base *Config, others ...*Config) { + for _, cfg := range others { if cfg == nil { continue } @@ -593,15 +597,27 @@ func mergeLSPs(base, global, local *Config) { } } -func mergeProviderConfigs(base, global, local *Config) { - for _, cfg := range []*Config{global, local} { +func mergeProviderConfigs(base *Config, others ...*Config) { + for _, cfg := range others { if cfg == nil { continue } for providerName, p := range cfg.Providers { p.ID = providerName if _, ok := base.Providers[providerName]; !ok { - base.Providers[providerName] = p + if slices.Contains(provider.KnownProviders(), providerName) { + providers := Providers() + for _, providerDef := range providers { + if providerDef.ID == providerName { + logging.Info("Using default provider config for", "provider", providerName) + baseProvider := getDefaultProviderConfig(providerDef, providerDef.APIKey) + base.Providers[providerName] = mergeProviderConfig(providerName, baseProvider, p) + break + } + } + } else { + base.Providers[providerName] = p + } } else { base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], p) } @@ -676,6 +692,40 @@ func providerDefaultConfig(providerID provider.InferenceProvider) ProviderConfig } } +func getDefaultProviderConfig(p provider.Provider, apiKey string) ProviderConfig { + providerConfig := providerDefaultConfig(p.ID) + providerConfig.APIKey = apiKey + providerConfig.DefaultLargeModel = p.DefaultLargeModelID + providerConfig.DefaultSmallModel = p.DefaultSmallModelID + baseURL := p.APIEndpoint + if strings.HasPrefix(baseURL, "$") { + envVar := strings.TrimPrefix(baseURL, "$") + baseURL = os.Getenv(envVar) + } + providerConfig.BaseURL = baseURL + for _, model := range p.Models { + configModel := Model{ + ID: model.ID, + Name: model.Name, + CostPer1MIn: model.CostPer1MIn, + CostPer1MOut: model.CostPer1MOut, + CostPer1MInCached: model.CostPer1MInCached, + CostPer1MOutCached: model.CostPer1MOutCached, + ContextWindow: model.ContextWindow, + DefaultMaxTokens: model.DefaultMaxTokens, + CanReason: model.CanReason, + SupportsImages: model.SupportsImages, + } + // Set reasoning effort for reasoning models + if model.HasReasoningEffort && model.DefaultReasoningEffort != "" { + configModel.HasReasoningEffort = model.HasReasoningEffort + configModel.ReasoningEffort = model.DefaultReasoningEffort + } + providerConfig.Models = append(providerConfig.Models, configModel) + } + return providerConfig +} + func defaultConfigBasedOnEnv() *Config { cfg := &Config{ Options: Options{ @@ -694,37 +744,7 @@ func defaultConfigBasedOnEnv() *Config { if strings.HasPrefix(p.APIKey, "$") { envVar := strings.TrimPrefix(p.APIKey, "$") if apiKey := os.Getenv(envVar); apiKey != "" { - providerConfig := providerDefaultConfig(p.ID) - providerConfig.APIKey = apiKey - providerConfig.DefaultLargeModel = p.DefaultLargeModelID - providerConfig.DefaultSmallModel = p.DefaultSmallModelID - baseURL := p.APIEndpoint - if strings.HasPrefix(baseURL, "$") { - envVar := strings.TrimPrefix(baseURL, "$") - baseURL = os.Getenv(envVar) - } - providerConfig.BaseURL = baseURL - for _, model := range p.Models { - configModel := Model{ - ID: model.ID, - Name: model.Name, - CostPer1MIn: model.CostPer1MIn, - CostPer1MOut: model.CostPer1MOut, - CostPer1MInCached: model.CostPer1MInCached, - CostPer1MOutCached: model.CostPer1MOutCached, - ContextWindow: model.ContextWindow, - DefaultMaxTokens: model.DefaultMaxTokens, - CanReason: model.CanReason, - SupportsImages: model.SupportsImages, - } - // Set reasoning effort for reasoning models - if model.HasReasoningEffort && model.DefaultReasoningEffort != "" { - configModel.HasReasoningEffort = model.HasReasoningEffort - configModel.ReasoningEffort = model.DefaultReasoningEffort - } - providerConfig.Models = append(providerConfig.Models, configModel) - } - cfg.Providers[p.ID] = providerConfig + cfg.Providers[p.ID] = getDefaultProviderConfig(p, apiKey) } } } diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 05f09ad77e224a59bcb825e85f353e317c7c4a83..a1c1414f159a0d6282c2dbfb678726602edf1d1f 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -382,6 +382,10 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err return false, 0, err } + if attempts > maxRetries { + return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries) + } + if apiErr.StatusCode == 401 { a.providerOptions.apiKey, err = config.ResolveAPIKey(a.providerOptions.config.APIKey) if err != nil { @@ -395,10 +399,6 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err return false, 0, err } - if attempts > maxRetries { - return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries) - } - retryMs := 0 retryAfterValues := apiErr.Response.Header.Values("Retry-After") diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index bddf820c2d3ccf9bba1a683ed4fe469d05fa31bf..46c0b210f6caa3adc4de131e251dc4d865fc5f80 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -347,6 +347,10 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) return false, 0, err } + if attempts > maxRetries { + return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries) + } + // Check for token expiration (401 Unauthorized) if apiErr.StatusCode == 401 { o.providerOptions.apiKey, err = config.ResolveAPIKey(o.providerOptions.config.APIKey) @@ -361,10 +365,6 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) return false, 0, err } - if attempts > maxRetries { - return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries) - } - retryMs := 0 retryAfterValues := apiErr.Response.Header.Values("Retry-After")