@@ -200,6 +200,28 @@ var (
)
+func readConfigFile(path string) (*Config, error) {
+ var cfg *Config
+ if _, err := os.Stat(path); err != nil && !os.IsNotExist(err) {
+ // some other error occurred while checking the file
+ return nil, err
+ } else if err == nil {
+ // config file exists, read it
+ file, err := os.ReadFile(path)
+ if err != nil {
+ return nil, err
+ }
+ cfg = &Config{}
+ if err := json.Unmarshal(file, cfg); err != nil {
+ return nil, err
+ }
+ } else {
+ // config file does not exist, create a new one
+ cfg = &Config{}
+ }
+ return cfg, nil
+}
+
func loadConfig(cwd string, debug bool) (*Config, error) {
// First read the global config file
cfgPath := ConfigPath()
@@ -248,47 +270,29 @@ func loadConfig(cwd string, debug bool) (*Config, error) {
}))
slog.SetDefault(logger)
}
- var globalCfg *Config
- if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
- // some other error occurred while checking the file
- return nil, err
- } else if err == nil {
- // config file exists, read it
- file, err := os.ReadFile(cfgPath)
- if err != nil {
- return nil, err
- }
- globalCfg = &Config{}
- if err := json.Unmarshal(file, globalCfg); err != nil {
- return nil, err
- }
- } else {
- // config file does not exist, create a new one
- globalCfg = &Config{}
+
+ priorityOrderedConfigFiles := []string{
+ cfgPath, // Global config file
+ filepath.Join(cwd, "crush.json"), // Local config file
+ filepath.Join(cwd, ".crush.json"), // Local config file
}
- var localConfig *Config
- // Global config loaded, now read the local config file
- localConfigPath := filepath.Join(cwd, "crush.json")
- if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
- // some other error occurred while checking the file
- return nil, err
- } else if err == nil {
- // local config file exists, read it
- file, err := os.ReadFile(localConfigPath)
+ configs := make([]*Config, 0)
+ for _, path := range priorityOrderedConfigFiles {
+ localConfig, err := readConfigFile(path)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("failed to read config file %s: %w", path, err)
}
- localConfig = &Config{}
- if err := json.Unmarshal(file, localConfig); err != nil {
- return nil, err
+ if localConfig != nil {
+ // If the config file was read successfully, add it to the list
+ configs = append(configs, localConfig)
}
}
// merge options
- mergeOptions(cfg, globalCfg, localConfig)
+ mergeOptions(cfg, configs...)
- mergeProviderConfigs(cfg, globalCfg, localConfig)
+ mergeProviderConfigs(cfg, configs...)
// no providers found the app is not initialized yet
if len(cfg.Providers) == 0 {
return cfg, nil
@@ -310,7 +314,7 @@ func loadConfig(cwd string, debug bool) (*Config, error) {
cfg.Models = PreferredModels{}
}
- mergeModels(cfg, globalCfg, localConfig)
+ mergeModels(cfg, configs...)
agents := map[AgentID]Agent{
AgentCoder: {
@@ -340,9 +344,9 @@ func loadConfig(cwd string, debug bool) (*Config, error) {
},
}
cfg.Agents = agents
- mergeAgents(cfg, globalCfg, localConfig)
- mergeMCPs(cfg, globalCfg, localConfig)
- mergeLSPs(cfg, globalCfg, localConfig)
+ mergeAgents(cfg, configs...)
+ mergeMCPs(cfg, configs...)
+ mergeLSPs(cfg, configs...)
// Validate the final configuration
if err := cfg.Validate(); err != nil {
@@ -457,8 +461,8 @@ func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfi
return nil
}
-func mergeModels(base, global, local *Config) {
- for _, cfg := range []*Config{global, local} {
+func mergeModels(base *Config, others ...*Config) {
+ for _, cfg := range others {
if cfg == nil {
continue
}
@@ -472,8 +476,8 @@ func mergeModels(base, global, local *Config) {
}
}
-func mergeOptions(base, global, local *Config) {
- for _, cfg := range []*Config{global, local} {
+func mergeOptions(base *Config, others ...*Config) {
+ for _, cfg := range others {
if cfg == nil {
continue
}
@@ -506,8 +510,8 @@ func mergeOptions(base, global, local *Config) {
}
}
-func mergeAgents(base, global, local *Config) {
- for _, cfg := range []*Config{global, local} {
+func mergeAgents(base *Config, others ...*Config) {
+ for _, cfg := range others {
if cfg == nil {
continue
}
@@ -575,8 +579,8 @@ func mergeAgents(base, global, local *Config) {
}
}
-func mergeMCPs(base, global, local *Config) {
- for _, cfg := range []*Config{global, local} {
+func mergeMCPs(base *Config, others ...*Config) {
+ for _, cfg := range others {
if cfg == nil {
continue
}
@@ -584,8 +588,8 @@ func mergeMCPs(base, global, local *Config) {
}
}
-func mergeLSPs(base, global, local *Config) {
- for _, cfg := range []*Config{global, local} {
+func mergeLSPs(base *Config, others ...*Config) {
+ for _, cfg := range others {
if cfg == nil {
continue
}
@@ -593,15 +597,27 @@ func mergeLSPs(base, global, local *Config) {
}
}
-func mergeProviderConfigs(base, global, local *Config) {
- for _, cfg := range []*Config{global, local} {
+func mergeProviderConfigs(base *Config, others ...*Config) {
+ for _, cfg := range others {
if cfg == nil {
continue
}
for providerName, p := range cfg.Providers {
p.ID = providerName
if _, ok := base.Providers[providerName]; !ok {
- base.Providers[providerName] = p
+ if slices.Contains(provider.KnownProviders(), providerName) {
+ providers := Providers()
+ for _, providerDef := range providers {
+ if providerDef.ID == providerName {
+ logging.Info("Using default provider config for", "provider", providerName)
+ baseProvider := getDefaultProviderConfig(providerDef, providerDef.APIKey)
+ base.Providers[providerName] = mergeProviderConfig(providerName, baseProvider, p)
+ break
+ }
+ }
+ } else {
+ base.Providers[providerName] = p
+ }
} else {
base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], p)
}
@@ -676,6 +692,40 @@ func providerDefaultConfig(providerID provider.InferenceProvider) ProviderConfig
}
}
+func getDefaultProviderConfig(p provider.Provider, apiKey string) ProviderConfig {
+ providerConfig := providerDefaultConfig(p.ID)
+ providerConfig.APIKey = apiKey
+ providerConfig.DefaultLargeModel = p.DefaultLargeModelID
+ providerConfig.DefaultSmallModel = p.DefaultSmallModelID
+ baseURL := p.APIEndpoint
+ if strings.HasPrefix(baseURL, "$") {
+ envVar := strings.TrimPrefix(baseURL, "$")
+ baseURL = os.Getenv(envVar)
+ }
+ providerConfig.BaseURL = baseURL
+ for _, model := range p.Models {
+ configModel := Model{
+ ID: model.ID,
+ Name: model.Name,
+ CostPer1MIn: model.CostPer1MIn,
+ CostPer1MOut: model.CostPer1MOut,
+ CostPer1MInCached: model.CostPer1MInCached,
+ CostPer1MOutCached: model.CostPer1MOutCached,
+ ContextWindow: model.ContextWindow,
+ DefaultMaxTokens: model.DefaultMaxTokens,
+ CanReason: model.CanReason,
+ SupportsImages: model.SupportsImages,
+ }
+ // Set reasoning effort for reasoning models
+ if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
+ configModel.HasReasoningEffort = model.HasReasoningEffort
+ configModel.ReasoningEffort = model.DefaultReasoningEffort
+ }
+ providerConfig.Models = append(providerConfig.Models, configModel)
+ }
+ return providerConfig
+}
+
func defaultConfigBasedOnEnv() *Config {
cfg := &Config{
Options: Options{
@@ -694,37 +744,7 @@ func defaultConfigBasedOnEnv() *Config {
if strings.HasPrefix(p.APIKey, "$") {
envVar := strings.TrimPrefix(p.APIKey, "$")
if apiKey := os.Getenv(envVar); apiKey != "" {
- providerConfig := providerDefaultConfig(p.ID)
- providerConfig.APIKey = apiKey
- providerConfig.DefaultLargeModel = p.DefaultLargeModelID
- providerConfig.DefaultSmallModel = p.DefaultSmallModelID
- baseURL := p.APIEndpoint
- if strings.HasPrefix(baseURL, "$") {
- envVar := strings.TrimPrefix(baseURL, "$")
- baseURL = os.Getenv(envVar)
- }
- providerConfig.BaseURL = baseURL
- for _, model := range p.Models {
- configModel := Model{
- ID: model.ID,
- Name: model.Name,
- CostPer1MIn: model.CostPer1MIn,
- CostPer1MOut: model.CostPer1MOut,
- CostPer1MInCached: model.CostPer1MInCached,
- CostPer1MOutCached: model.CostPer1MOutCached,
- ContextWindow: model.ContextWindow,
- DefaultMaxTokens: model.DefaultMaxTokens,
- CanReason: model.CanReason,
- SupportsImages: model.SupportsImages,
- }
- // Set reasoning effort for reasoning models
- if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
- configModel.HasReasoningEffort = model.HasReasoningEffort
- configModel.ReasoningEffort = model.DefaultReasoningEffort
- }
- providerConfig.Models = append(providerConfig.Models, configModel)
- }
- cfg.Providers[p.ID] = providerConfig
+ cfg.Providers[p.ID] = getDefaultProviderConfig(p, apiKey)
}
}
}