1package config
2
3import (
4 "encoding/json"
5 "errors"
6 "fmt"
7 "log/slog"
8 "maps"
9 "os"
10 "path/filepath"
11 "slices"
12 "strings"
13 "sync"
14
15 "github.com/charmbracelet/crush/internal/fur/provider"
16 "github.com/charmbracelet/crush/internal/logging"
17)
18
const (
	// defaultDataDirectory is the default Options.DataDirectory, relative to
	// the working directory; loadConfig writes debug.log there in dev mode.
	defaultDataDirectory = ".crush"
	// defaultLogLevel is not referenced in this part of the file.
	// NOTE(review): loadConfig hard-codes slog.LevelInfo instead — confirm
	// whether this constant is still used elsewhere.
	defaultLogLevel = "info"
	// appName is not referenced in this part of the file; presumably used
	// elsewhere (e.g. for config paths) — confirm.
	appName = "crush"

	// MaxTokensFallbackDefault is an exported max-token default that is not
	// referenced in this part of the file; presumably used when a model does
	// not specify its own limit — confirm against callers.
	MaxTokensFallbackDefault = 4096
)

// defaultContextPaths lists the project files and directories that are
// treated as context paths by default. defaultConfigBasedOnEnv seeds
// Options.ContextPaths with them, and config files can append more paths
// (mergeOptions). Entries ending in "/" are presumably directories.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
42
// AgentID identifies an agent; it is the key of Config.Agents.
type AgentID string

const (
	// AgentCoder is the built-in coding agent; loadConfig allows it all tools.
	AgentCoder AgentID = "coder"
	// AgentTask is the built-in search/context agent; loadConfig restricts
	// its default tools to glob, grep, ls, sourcegraph and view.
	AgentTask AgentID = "task"
)

// ModelType selects which preferred model an agent uses:
// Config.Models.Large or Config.Models.Small (see GetAgentModel).
type ModelType string

const (
	LargeModel ModelType = "large"
	SmallModel ModelType = "small"
)
56
// Model describes one model offered by a provider, with pricing and
// capability metadata. defaultConfigBasedOnEnv builds these from the provider
// catalog returned by Providers(); config files may add more.
type Model struct {
	// ID identifies the model; PreferredModel.ModelID is matched against it.
	ID string `json:"id"`
	// Name is presumably a human-readable name; note its JSON key is "model".
	Name string `json:"model"`
	// Costs per one million tokens (input/output, regular and cached). The
	// currency is not specified here; validation rejects negative values.
	CostPer1MIn        float64 `json:"cost_per_1m_in"`
	CostPer1MOut       float64 `json:"cost_per_1m_out"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
	// ContextWindow and DefaultMaxTokens must both be positive, and
	// DefaultMaxTokens may not exceed ContextWindow (validateProviders).
	// Presumably both are measured in tokens.
	ContextWindow    int64 `json:"context_window"`
	DefaultMaxTokens int64 `json:"default_max_tokens"`
	CanReason        bool  `json:"can_reason"`
	// ReasoningEffort and HasReasoningEffort are populated by
	// defaultConfigBasedOnEnv only when the catalog model has reasoning effort
	// AND a non-empty default effort.
	ReasoningEffort    string `json:"reasoning_effort"`
	HasReasoningEffort bool   `json:"has_reasoning_effort"`
	// SupportsImages is serialized under the JSON key "supports_attachments".
	SupportsImages bool `json:"supports_attachments"`
}
71
// VertexAIOptions groups VertexAI connection settings.
// NOTE(review): not referenced in this part of the file. VertexAI project and
// location are currently passed through ProviderConfig.ExtraParams instead
// (see defaultConfigBasedOnEnv). Confirm whether this type is still used.
type VertexAIOptions struct {
	APIKey   string `json:"api_key,omitempty"`
	Project  string `json:"project,omitempty"`
	Location string `json:"location,omitempty"`
}
77
// ProviderConfig is the configuration of one inference provider, stored in
// Config.Providers under its ID.
type ProviderConfig struct {
	// ID identifies the provider; it is copied into PreferredModel.Provider
	// when this provider is chosen as the preferred one (loadConfig).
	ID provider.InferenceProvider `json:"id"`
	// BaseURL is the API endpoint; it is required for custom (unknown)
	// providers (validateProvider).
	BaseURL string `json:"base_url,omitempty"`
	// ProviderType selects the API flavor; custom providers must use
	// provider.TypeOpenAI.
	ProviderType provider.Type `json:"provider_type"`
	// APIKey for the provider. defaultConfigBasedOnEnv stores the value of
	// the referenced "$VAR" env var. In the code shown here, keys from config
	// files are used as written.
	APIKey string `json:"api_key,omitempty"`
	// Disabled providers are skipped when picking the preferred provider and
	// are exempt from the API-key check in validateProviders.
	Disabled bool `json:"disabled"`
	// ExtraHeaders are additional HTTP headers, presumably sent with each
	// request (the OpenRouter defaults set HTTP-Referer and X-Title).
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
	// ExtraParams holds provider-specific settings, e.g. VertexAI reads
	// "project" and "location", and Bedrock reads "region".
	ExtraParams map[string]string `json:"extra_params,omitempty"`

	// IDs of the models used for Config.Models when this provider is the
	// preferred one; each must exist in Models (validateProviders).
	DefaultLargeModel string `json:"default_large_model,omitempty"`
	DefaultSmallModel string `json:"default_small_model,omitempty"`

	// Models lists the models offered by this provider.
	Models []Model `json:"models,omitempty"`
}
93
// Agent configures one agent: the model type it runs on and the tools, MCP
// servers, LSPs and context paths it may use.
type Agent struct {
	// ID is the agent's key in Config.Agents. loadConfig sets it for the
	// built-in agents; mergeAgents sets it from the map key for agents that
	// are first defined in a config file.
	ID AgentID `json:"id"`
	// Name is a human-readable name.
	Name string `json:"name"`
	// Description optionally describes the agent's purpose.
	Description string `json:"description,omitempty"`
	// Disabled presumably turns the agent off. For custom agents,
	// mergeAgents always copies it from the config (false included).
	Disabled bool `json:"disabled"`

	// Model selects Config.Models.Large or .Small; agents defined in config
	// files default to LargeModel.
	Model ModelType `json:"model"`

	// The available tools for the agent
	// if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools"`

	// AllowedMCP maps an MCP server name to the list of its tools the agent
	// may use; a nil tool list allows all tools of that server.
	// NOTE(review): the original comment said an empty map allows all MCPs.
	// However, loadConfig gives the built-in task agent an empty (non-nil) map
	// with the comment "NO MCPs or LSPs by default". That suggests nil means
	// "all" and empty means "none" — confirm against the consumer.
	AllowedMCP map[string][]string `json:"allowed_mcp"`

	// The list of LSPs that this agent can use
	// if this is nil, all LSPs are available
	AllowedLSP []string `json:"allowed_lsp"`

	// ContextPaths for this agent. Despite the name, values from config files
	// do not replace the global list: mergeAgents appends them to
	// Options.ContextPaths (or to the agent's existing paths).
	ContextPaths []string `json:"context_paths"`
}
120
// MCPType selects how an MCP server is reached.
type MCPType string

const (
	// MCPStdio presumably launches Command/Args/Env and talks over stdio.
	MCPStdio MCPType = "stdio"
	// MCPSse presumably connects to URL using server-sent events.
	MCPSse MCPType = "sse"
)

// MCP configures one MCP server; it is stored in Config.MCP under its name,
// and a later config file replaces an entry with the same name (mergeMCPs).
type MCP struct {
	Command string `json:"command"`
	// Env is presumably a list of KEY=VALUE entries — confirm.
	Env  []string `json:"env"`
	Args []string `json:"args"`
	Type MCPType  `json:"type"`
	// URL and Headers presumably apply to MCPSse servers only.
	URL     string            `json:"url"`
	Headers map[string]string `json:"headers"`
}
136
// LSPConfig describes how to run one language server; it is stored in
// Config.LSP under its name.
type LSPConfig struct {
	// Disabled turns the server off. Its JSON key is "disabled". The key used
	// to be "enabled", which inverted the setting: "enabled": true disabled
	// the server.
	Disabled bool `json:"disabled"`
	// Command is presumably the executable to launch, with Args as its arguments.
	Command string   `json:"command"`
	Args    []string `json:"args"`
	// Options is opaque here; presumably passed through to the server — confirm.
	Options any `json:"options"`
}
143
// TUIOptions holds terminal UI settings.
type TUIOptions struct {
	// CompactMode can only be switched on by config files (mergeOptions);
	// presumably selects a denser layout.
	CompactMode bool `json:"compact_mode"`
	// Themes or other TUI-related options can be added here later.
}

// Options holds miscellaneous settings. mergeOptions layers config files onto
// it: ContextPaths are appended, booleans can only be switched on, and a
// non-empty DataDirectory replaces the current one.
type Options struct {
	// ContextPaths starts as defaultContextPaths and grows with config files.
	ContextPaths []string   `json:"context_paths"`
	TUI          TUIOptions `json:"tui"`
	// Debug is seeded from loadConfig's debug argument, which picks the Debug
	// log level. Config files can also switch it on, but only after the
	// logger has been configured.
	Debug bool `json:"debug"`
	// DebugLSP is not used in this part of the file; presumably enables
	// verbose LSP logging.
	DebugLSP bool `json:"debug_lsp"`
	// DisableAutoSummarize is not used in this part of the file; presumably
	// turns off automatic summarization.
	DisableAutoSummarize bool `json:"disable_auto_summarize"`
	// DataDirectory is relative to the cwd; it defaults to ".crush".
	DataDirectory string `json:"data_directory"`
}

// PreferredModel selects a model by provider and model ID.
type PreferredModel struct {
	// ModelID is matched against Model.ID in the provider's Models.
	ModelID string `json:"model_id"`
	// Provider is the key of the provider in Config.Providers.
	Provider provider.InferenceProvider `json:"provider"`
	// Overrides the default reasoning effort for this model
	ReasoningEffort string `json:"reasoning_effort,omitempty"`
	// Overrides the default max tokens for this model
	MaxTokens int64 `json:"max_tokens,omitempty"`
}

// PreferredModels holds the models used for the LargeModel and SmallModel
// agent model types.
type PreferredModels struct {
	Large PreferredModel `json:"large"`
	Small PreferredModel `json:"small"`
}
172
// Config is the fully merged application configuration built by loadConfig
// from environment defaults, the global config file and the local crush.json.
type Config struct {
	// Models are the preferred large/small models. They default to the
	// preferred provider's default models and can be overridden by config files.
	Models PreferredModels `json:"models"`
	// Providers are keyed by provider ID; providers failing validateProvider
	// are dropped while loading.
	Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"`

	// Agents are keyed by agent ID; the built-in coder and task agents are
	// always present after loadConfig.
	Agents map[AgentID]Agent `json:"agents,omitempty"`

	// MCP servers, keyed by name.
	MCP map[string]MCP `json:"mcp,omitempty"`

	// LSP servers, keyed by name.
	LSP map[string]LSPConfig `json:"lsp,omitempty"`

	// Miscellaneous options
	Options Options `json:"options"`
}
190
// Package-level state set up once by Init.
var (
	// instance is the configuration stored by Init and returned by Get.
	instance *Config
	// cwd is the working directory passed to Init (see WorkingDirectory).
	cwd string
	// once ensures Init loads the configuration only once.
	once sync.Once
)
197
198func loadConfig(cwd string, debug bool) (*Config, error) {
199 // First read the global config file
200 cfgPath := ConfigPath()
201
202 cfg := defaultConfigBasedOnEnv()
203 cfg.Options.Debug = debug
204 defaultLevel := slog.LevelInfo
205 if cfg.Options.Debug {
206 defaultLevel = slog.LevelDebug
207 }
208 if os.Getenv("CRUSH_DEV_DEBUG") == "true" {
209 loggingFile := fmt.Sprintf("%s/%s", cfg.Options.DataDirectory, "debug.log")
210
211 // if file does not exist create it
212 if _, err := os.Stat(loggingFile); os.IsNotExist(err) {
213 if err := os.MkdirAll(cfg.Options.DataDirectory, 0o755); err != nil {
214 return cfg, fmt.Errorf("failed to create directory: %w", err)
215 }
216 if _, err := os.Create(loggingFile); err != nil {
217 return cfg, fmt.Errorf("failed to create log file: %w", err)
218 }
219 }
220
221 sloggingFileWriter, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666)
222 if err != nil {
223 return cfg, fmt.Errorf("failed to open log file: %w", err)
224 }
225 // Configure logger
226 logger := slog.New(slog.NewTextHandler(sloggingFileWriter, &slog.HandlerOptions{
227 Level: defaultLevel,
228 }))
229 slog.SetDefault(logger)
230 } else {
231 // Configure logger
232 logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{
233 Level: defaultLevel,
234 }))
235 slog.SetDefault(logger)
236 }
237 var globalCfg *Config
238 if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
239 // some other error occurred while checking the file
240 return nil, err
241 } else if err == nil {
242 // config file exists, read it
243 file, err := os.ReadFile(cfgPath)
244 if err != nil {
245 return nil, err
246 }
247 globalCfg = &Config{}
248 if err := json.Unmarshal(file, globalCfg); err != nil {
249 return nil, err
250 }
251 } else {
252 // config file does not exist, create a new one
253 globalCfg = &Config{}
254 }
255
256 var localConfig *Config
257 // Global config loaded, now read the local config file
258 localConfigPath := filepath.Join(cwd, "crush.json")
259 if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
260 // some other error occurred while checking the file
261 return nil, err
262 } else if err == nil {
263 // local config file exists, read it
264 file, err := os.ReadFile(localConfigPath)
265 if err != nil {
266 return nil, err
267 }
268 localConfig = &Config{}
269 if err := json.Unmarshal(file, localConfig); err != nil {
270 return nil, err
271 }
272 }
273
274 // merge options
275 mergeOptions(cfg, globalCfg, localConfig)
276
277 mergeProviderConfigs(cfg, globalCfg, localConfig)
278 // no providers found the app is not initialized yet
279 if len(cfg.Providers) == 0 {
280 return cfg, nil
281 }
282 preferredProvider := getPreferredProvider(cfg.Providers)
283 if preferredProvider != nil {
284 cfg.Models = PreferredModels{
285 Large: PreferredModel{
286 ModelID: preferredProvider.DefaultLargeModel,
287 Provider: preferredProvider.ID,
288 },
289 Small: PreferredModel{
290 ModelID: preferredProvider.DefaultSmallModel,
291 Provider: preferredProvider.ID,
292 },
293 }
294 } else {
295 // No valid providers found, set empty models
296 cfg.Models = PreferredModels{}
297 }
298
299 mergeModels(cfg, globalCfg, localConfig)
300
301 agents := map[AgentID]Agent{
302 AgentCoder: {
303 ID: AgentCoder,
304 Name: "Coder",
305 Description: "An agent that helps with executing coding tasks.",
306 Model: LargeModel,
307 ContextPaths: cfg.Options.ContextPaths,
308 // All tools allowed
309 },
310 AgentTask: {
311 ID: AgentTask,
312 Name: "Task",
313 Description: "An agent that helps with searching for context and finding implementation details.",
314 Model: LargeModel,
315 ContextPaths: cfg.Options.ContextPaths,
316 AllowedTools: []string{
317 "glob",
318 "grep",
319 "ls",
320 "sourcegraph",
321 "view",
322 },
323 // NO MCPs or LSPs by default
324 AllowedMCP: map[string][]string{},
325 AllowedLSP: []string{},
326 },
327 }
328 cfg.Agents = agents
329 mergeAgents(cfg, globalCfg, localConfig)
330 mergeMCPs(cfg, globalCfg, localConfig)
331 mergeLSPs(cfg, globalCfg, localConfig)
332
333 // Validate the final configuration
334 if err := cfg.Validate(); err != nil {
335 return cfg, fmt.Errorf("configuration validation failed: %w", err)
336 }
337
338 return cfg, nil
339}
340
341func Init(workingDir string, debug bool) (*Config, error) {
342 var err error
343 once.Do(func() {
344 cwd = workingDir
345 instance, err = loadConfig(cwd, debug)
346 if err != nil {
347 logging.Error("Failed to load config", "error", err)
348 }
349 })
350
351 return instance, err
352}
353
354func Get() *Config {
355 if instance == nil {
356 // TODO: Handle this better
357 panic("Config not initialized. Call InitConfig first.")
358 }
359 return instance
360}
361
// getPreferredProvider picks the provider whose default models become the
// preferred models.
//
// It returns the first enabled configured provider in the order returned by
// Providers(). If none matches, it falls back to the enabled configured
// provider with the lexicographically smallest ID. It returns nil when every
// configured provider is disabled.
func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
	for _, p := range Providers() {
		if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
			return &providerConfig
		}
	}

	// Map iteration order is random; sort the IDs so the fallback choice is
	// the same on every run.
	ids := make([]provider.InferenceProvider, 0, len(configuredProviders))
	for id := range configuredProviders {
		ids = append(ids, id)
	}
	slices.Sort(ids)
	for _, id := range ids {
		if providerConfig := configuredProviders[id]; !providerConfig.Disabled {
			return &providerConfig
		}
	}
	return nil
}
377
// mergeProviderConfig overlays other onto base for the provider with ID p and
// returns the result. Neither argument's maps are modified.
//
//   - APIKey, DefaultLargeModel and DefaultSmallModel are replaced when set in other.
//   - BaseURL, ProviderType, ExtraHeaders and ExtraParams can only be changed
//     for providers not in provider.KnownProviders(). Header and param maps
//     are merged key by key, with other winning.
//   - Disabled can only be switched on (false is indistinguishable from unset).
//   - Models from other are appended when their ID is not present yet;
//     existing model definitions are kept.
func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
	if other.APIKey != "" {
		base.APIKey = other.APIKey
	}
	// Only change these options if the provider is not a known provider.
	if !slices.Contains(provider.KnownProviders(), p) {
		if other.BaseURL != "" {
			base.BaseURL = other.BaseURL
		}
		if other.ProviderType != "" {
			base.ProviderType = other.ProviderType
		}
		if len(other.ExtraHeaders) > 0 {
			// Build a fresh map so a map shared with another config is not mutated.
			headers := make(map[string]string, len(base.ExtraHeaders)+len(other.ExtraHeaders))
			maps.Copy(headers, base.ExtraHeaders)
			maps.Copy(headers, other.ExtraHeaders)
			base.ExtraHeaders = headers
		}
		if len(other.ExtraParams) > 0 {
			params := make(map[string]string, len(base.ExtraParams)+len(other.ExtraParams))
			maps.Copy(params, base.ExtraParams)
			maps.Copy(params, other.ExtraParams)
			base.ExtraParams = params
		}
	}

	if other.Disabled {
		base.Disabled = true
	}

	if other.DefaultLargeModel != "" {
		base.DefaultLargeModel = other.DefaultLargeModel
	}
	// The small default must be merged just like the large one; otherwise a
	// configured default_small_model is silently ignored.
	if other.DefaultSmallModel != "" {
		base.DefaultSmallModel = other.DefaultSmallModel
	}

	// Add new models if they don't exist. Checking against the growing
	// base.Models also drops duplicates within other.Models.
	for _, model := range other.Models {
		if !slices.ContainsFunc(base.Models, func(m Model) bool { return m.ID == model.ID }) {
			base.Models = append(base.Models, model)
		}
	}

	return base
}
430
// validateProvider reports whether a merged provider entry is usable.
// Known providers always pass. A custom provider must be of
// provider.TypeOpenAI and must set both a base URL and an API key; the first
// violated rule is returned as the error.
func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
	if slices.Contains(provider.KnownProviders(), p) {
		return nil
	}

	switch {
	case providerConfig.ProviderType != provider.TypeOpenAI:
		return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
	case providerConfig.BaseURL == "":
		return errors.New("base URL must be set for custom providers")
	case providerConfig.APIKey == "":
		return errors.New("API key must be set for custom providers")
	default:
		return nil
	}
}
445
446func mergeModels(base, global, local *Config) {
447 for _, cfg := range []*Config{global, local} {
448 if cfg == nil {
449 continue
450 }
451 if cfg.Models.Large.ModelID != "" && cfg.Models.Large.Provider != "" {
452 base.Models.Large = cfg.Models.Large
453 }
454
455 if cfg.Models.Small.ModelID != "" && cfg.Models.Small.Provider != "" {
456 base.Models.Small = cfg.Models.Small
457 }
458 }
459}
460
461func mergeOptions(base, global, local *Config) {
462 for _, cfg := range []*Config{global, local} {
463 if cfg == nil {
464 continue
465 }
466 baseOptions := base.Options
467 other := cfg.Options
468 if len(other.ContextPaths) > 0 {
469 baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
470 }
471
472 if other.TUI.CompactMode {
473 baseOptions.TUI.CompactMode = other.TUI.CompactMode
474 }
475
476 if other.Debug {
477 baseOptions.Debug = other.Debug
478 }
479
480 if other.DebugLSP {
481 baseOptions.DebugLSP = other.DebugLSP
482 }
483
484 if other.DisableAutoSummarize {
485 baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
486 }
487
488 if other.DataDirectory != "" {
489 baseOptions.DataDirectory = other.DataDirectory
490 }
491 base.Options = baseOptions
492 }
493}
494
// mergeAgents layers the agents from the global and then the local config onto
// base.Agents (which loadConfig pre-populates with the built-in agents).
//
//   - An agent not yet in base is added with its ID forced to the map key, its
//     Model defaulting to LargeModel, and its ContextPaths appended to
//     base.Options.ContextPaths.
//   - For the built-in agents (AgentCoder, AgentTask), only Model, AllowedMCP,
//     AllowedLSP and extra ContextPaths can be changed.
//   - For other existing agents, non-empty strings and non-nil slices/maps
//     replace the base values, and Disabled is always copied.
//
// Context paths are always additive, never replaced.
func mergeAgents(base, global, local *Config) {
	for _, cfg := range []*Config{global, local} {
		if cfg == nil {
			continue
		}
		for agentID, newAgent := range cfg.Agents {
			baseAgent, exists := base.Agents[agentID]
			if !exists {
				newAgent.ID = agentID // the map key is authoritative
				if newAgent.Model == "" {
					newAgent.Model = LargeModel
				}
				if len(newAgent.ContextPaths) > 0 {
					// Clip forces append to allocate. The global list is shared by
					// every agent and may have spare capacity, so appending in place
					// would let agents overwrite each other's extra paths.
					newAgent.ContextPaths = append(slices.Clip(base.Options.ContextPaths), newAgent.ContextPaths...)
				} else {
					newAgent.ContextPaths = base.Options.ContextPaths
				}
				base.Agents[agentID] = newAgent
				continue
			}

			if agentID == AgentCoder || agentID == AgentTask {
				// Built-in agents: only model, MCP and LSP access are configurable.
				if newAgent.Model != "" {
					baseAgent.Model = newAgent.Model
				}
				if newAgent.AllowedMCP != nil {
					baseAgent.AllowedMCP = newAgent.AllowedMCP
				}
				if newAgent.AllowedLSP != nil {
					baseAgent.AllowedLSP = newAgent.AllowedLSP
				}
			} else {
				// Custom agents: full merging.
				if newAgent.Name != "" {
					baseAgent.Name = newAgent.Name
				}
				if newAgent.Description != "" {
					baseAgent.Description = newAgent.Description
				}
				if newAgent.Model != "" {
					baseAgent.Model = newAgent.Model
				} else if baseAgent.Model == "" {
					baseAgent.Model = LargeModel
				}

				// Boolean fields are always updated (including false values).
				baseAgent.Disabled = newAgent.Disabled

				// Slice/map fields are updated when provided (including empty ones).
				if newAgent.AllowedTools != nil {
					baseAgent.AllowedTools = newAgent.AllowedTools
				}
				if newAgent.AllowedMCP != nil {
					baseAgent.AllowedMCP = newAgent.AllowedMCP
				}
				if newAgent.AllowedLSP != nil {
					baseAgent.AllowedLSP = newAgent.AllowedLSP
				}
			}

			if len(newAgent.ContextPaths) > 0 {
				// Clip for the same reason as above: the built-in agents start out
				// sharing the backing array of base.Options.ContextPaths.
				baseAgent.ContextPaths = append(slices.Clip(baseAgent.ContextPaths), newAgent.ContextPaths...)
			}
			base.Agents[agentID] = baseAgent
		}
	}
}
571
// mergeMCPs copies the MCP server definitions of the global and then the local
// config into base.MCP. Same-named entries are replaced wholesale, so local
// definitions win over global ones.
func mergeMCPs(base, global, local *Config) {
	for _, layer := range []*Config{global, local} {
		if layer != nil {
			maps.Copy(base.MCP, layer.MCP)
		}
	}
}
580
// mergeLSPs copies the LSP definitions of the global and then the local config
// into base.LSP. Same-named entries are replaced wholesale, so local
// definitions win over global ones.
func mergeLSPs(base, global, local *Config) {
	for _, layer := range []*Config{global, local} {
		if layer != nil {
			maps.Copy(base.LSP, layer.LSP)
		}
	}
}
589
590func mergeProviderConfigs(base, global, local *Config) {
591 for _, cfg := range []*Config{global, local} {
592 if cfg == nil {
593 continue
594 }
595 for providerName, p := range cfg.Providers {
596 if _, ok := base.Providers[providerName]; !ok {
597 base.Providers[providerName] = p
598 } else {
599 base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], p)
600 }
601 }
602 }
603
604 finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
605 for providerName, providerConfig := range base.Providers {
606 err := validateProvider(providerName, providerConfig)
607 if err != nil {
608 logging.Warn("Skipping provider", "name", providerName, "error", err)
609 continue // Skip invalid providers
610 }
611 finalProviders[providerName] = providerConfig
612 }
613 base.Providers = finalProviders
614}
615
616func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig {
617 switch providerId {
618 case provider.InferenceProviderAnthropic:
619 return ProviderConfig{
620 ID: providerId,
621 ProviderType: provider.TypeAnthropic,
622 }
623 case provider.InferenceProviderOpenAI:
624 return ProviderConfig{
625 ID: providerId,
626 ProviderType: provider.TypeOpenAI,
627 }
628 case provider.InferenceProviderGemini:
629 return ProviderConfig{
630 ID: providerId,
631 ProviderType: provider.TypeGemini,
632 }
633 case provider.InferenceProviderBedrock:
634 return ProviderConfig{
635 ID: providerId,
636 ProviderType: provider.TypeBedrock,
637 }
638 case provider.InferenceProviderAzure:
639 return ProviderConfig{
640 ID: providerId,
641 ProviderType: provider.TypeAzure,
642 }
643 case provider.InferenceProviderOpenRouter:
644 return ProviderConfig{
645 ID: providerId,
646 ProviderType: provider.TypeOpenAI,
647 BaseURL: "https://openrouter.ai/api/v1",
648 ExtraHeaders: map[string]string{
649 "HTTP-Referer": "crush.charm.land",
650 "X-Title": "Crush",
651 },
652 }
653 case provider.InferenceProviderXAI:
654 return ProviderConfig{
655 ID: providerId,
656 ProviderType: provider.TypeXAI,
657 BaseURL: "https://api.x.ai/v1",
658 }
659 case provider.InferenceProviderVertexAI:
660 return ProviderConfig{
661 ID: providerId,
662 ProviderType: provider.TypeVertexAI,
663 }
664 default:
665 return ProviderConfig{
666 ID: providerId,
667 ProviderType: provider.TypeOpenAI,
668 }
669 }
670}
671
// defaultConfigBasedOnEnv returns the starting configuration before any config
// file is applied: default options plus every provider that can be enabled
// from environment variables alone.
//
//   - A catalog provider (from Providers()) is enabled when its APIKey is a
//     "$VAR" reference and VAR is non-empty. A "$VAR" APIEndpoint is expanded
//     the same way, and the catalog's models and default model IDs are copied.
//   - VertexAI is added when GOOGLE_GENAI_USE_VERTEXAI=true, with project and
//     location taken from GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION.
//   - Bedrock is added when hasAWSCredentials reports AWS settings. Its region
//     comes from AWS_DEFAULT_REGION, falling back to AWS_REGION.
//
// The VertexAI and Bedrock entries replace any catalog entry with the same ID.
func defaultConfigBasedOnEnv() *Config {
	cfg := &Config{
		Options: Options{
			DataDirectory: defaultDataDirectory,
			// Clone: the config owns (and later appends to) its context paths
			// and must never alias the package-level defaults.
			ContextPaths: slices.Clone(defaultContextPaths),
		},
		Providers: make(map[provider.InferenceProvider]ProviderConfig),
		Agents:    make(map[AgentID]Agent),
		LSP:       make(map[string]LSPConfig),
		MCP:       make(map[string]MCP),
	}

	for _, p := range Providers() {
		envVar, isEnvRef := strings.CutPrefix(p.APIKey, "$")
		if !isEnvRef {
			continue
		}
		apiKey := os.Getenv(envVar)
		if apiKey == "" {
			continue
		}

		providerConfig := providerDefaultConfig(p.ID)
		providerConfig.APIKey = apiKey
		providerConfig.DefaultLargeModel = p.DefaultLargeModelID
		providerConfig.DefaultSmallModel = p.DefaultSmallModelID
		baseURL := p.APIEndpoint
		if endpointVar, ok := strings.CutPrefix(baseURL, "$"); ok {
			baseURL = os.Getenv(endpointVar)
		}
		providerConfig.BaseURL = baseURL
		for _, model := range p.Models {
			configModel := Model{
				ID:                 model.ID,
				Name:               model.Name,
				CostPer1MIn:        model.CostPer1MIn,
				CostPer1MOut:       model.CostPer1MOut,
				CostPer1MInCached:  model.CostPer1MInCached,
				CostPer1MOutCached: model.CostPer1MOutCached,
				ContextWindow:      model.ContextWindow,
				DefaultMaxTokens:   model.DefaultMaxTokens,
				CanReason:          model.CanReason,
				SupportsImages:     model.SupportsImages,
			}
			// Set reasoning effort for reasoning models.
			if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
				configModel.HasReasoningEffort = model.HasReasoningEffort
				configModel.ReasoningEffort = model.DefaultReasoningEffort
			}
			providerConfig.Models = append(providerConfig.Models, configModel)
		}
		cfg.Providers[p.ID] = providerConfig
	}
	// TODO: support local models

	if os.Getenv("GOOGLE_GENAI_USE_VERTEXAI") == "true" {
		providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
		providerConfig.ExtraParams = map[string]string{
			"project":  os.Getenv("GOOGLE_CLOUD_PROJECT"),
			"location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
		}
		cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
	}

	if hasAWSCredentials() {
		providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
		region := os.Getenv("AWS_DEFAULT_REGION")
		if region == "" {
			region = os.Getenv("AWS_REGION")
		}
		providerConfig.ExtraParams = map[string]string{
			"region": region,
		}
		cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
	}
	return cfg
}
747
// hasAWSCredentials reports whether the environment carries any AWS settings
// that Bedrock could use. Any one of the following is enough:
//   - both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY,
//   - a profile (AWS_PROFILE or AWS_DEFAULT_PROFILE),
//   - a region (AWS_REGION or AWS_DEFAULT_REGION),
//   - container credential endpoints (relative or full URI).
func hasAWSCredentials() bool {
	isSet := func(name string) bool { return os.Getenv(name) != "" }

	staticKeys := isSet("AWS_ACCESS_KEY_ID") && isSet("AWS_SECRET_ACCESS_KEY")
	return staticKeys ||
		isSet("AWS_PROFILE") || isSet("AWS_DEFAULT_PROFILE") ||
		isSet("AWS_REGION") || isSet("AWS_DEFAULT_REGION") ||
		isSet("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") ||
		isSet("AWS_CONTAINER_CREDENTIALS_FULL_URI")
}
768
769func WorkingDirectory() string {
770 return cwd
771}
772
773// TODO: Handle error state
774
775func GetAgentModel(agentID AgentID) Model {
776 cfg := Get()
777 agent, ok := cfg.Agents[agentID]
778 if !ok {
779 logging.Error("Agent not found", "agent_id", agentID)
780 return Model{}
781 }
782
783 var model PreferredModel
784 switch agent.Model {
785 case LargeModel:
786 model = cfg.Models.Large
787 case SmallModel:
788 model = cfg.Models.Small
789 default:
790 logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
791 model = cfg.Models.Large // Fallback to large model
792 }
793 providerConfig, ok := cfg.Providers[model.Provider]
794 if !ok {
795 logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
796 return Model{}
797 }
798
799 for _, m := range providerConfig.Models {
800 if m.ID == model.ModelID {
801 return m
802 }
803 }
804
805 logging.Error("Model not found for agent", "agent_id", agentID, "model", agent.Model)
806 return Model{}
807}
808
809func GetAgentProvider(agentID AgentID) ProviderConfig {
810 cfg := Get()
811 agent, ok := cfg.Agents[agentID]
812 if !ok {
813 logging.Error("Agent not found", "agent_id", agentID)
814 return ProviderConfig{}
815 }
816
817 var model PreferredModel
818 switch agent.Model {
819 case LargeModel:
820 model = cfg.Models.Large
821 case SmallModel:
822 model = cfg.Models.Small
823 default:
824 logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
825 model = cfg.Models.Large // Fallback to large model
826 }
827
828 providerConfig, ok := cfg.Providers[model.Provider]
829 if !ok {
830 logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
831 return ProviderConfig{}
832 }
833
834 return providerConfig
835}
836
// GetProviderModel returns the model with ID modelID from the configured
// provider providerID. A missing provider or model is logged and yields a
// zero Model.
//
// The parameter is named providerID rather than provider so that it does not
// shadow the imported provider package inside the body.
func GetProviderModel(providerID provider.InferenceProvider, modelID string) Model {
	cfg := Get()
	providerConfig, ok := cfg.Providers[providerID]
	if !ok {
		logging.Error("Provider not found", "provider", providerID)
		return Model{}
	}

	if i := slices.IndexFunc(providerConfig.Models, func(m Model) bool { return m.ID == modelID }); i >= 0 {
		return providerConfig.Models[i]
	}

	logging.Error("Model not found for provider", "provider", providerID, "model_id", modelID)
	return Model{}
}
854
855func GetModel(modelType ModelType) Model {
856 cfg := Get()
857 var model PreferredModel
858 switch modelType {
859 case LargeModel:
860 model = cfg.Models.Large
861 case SmallModel:
862 model = cfg.Models.Small
863 default:
864 model = cfg.Models.Large // Fallback to large model
865 }
866 providerConfig, ok := cfg.Providers[model.Provider]
867 if !ok {
868 return Model{}
869 }
870
871 for _, m := range providerConfig.Models {
872 if m.ID == model.ModelID {
873 return m
874 }
875 }
876 return Model{}
877}
878
// UpdatePreferredModel replaces the in-memory preferred model of the given
// type. It does not persist the change, and it returns an error for a model
// type other than LargeModel or SmallModel.
func UpdatePreferredModel(modelType ModelType, model PreferredModel) error {
	models := &Get().Models

	switch modelType {
	case SmallModel:
		models.Small = model
	case LargeModel:
		models.Large = model
	default:
		return fmt.Errorf("unknown model type: %s", modelType)
	}
	return nil
}
891
892// ValidationError represents a configuration validation error
893type ValidationError struct {
894 Field string
895 Message string
896}
897
898func (e ValidationError) Error() string {
899 return fmt.Sprintf("validation error in %s: %s", e.Field, e.Message)
900}
901
902// ValidationErrors represents multiple validation errors
903type ValidationErrors []ValidationError
904
905func (e ValidationErrors) Error() string {
906 if len(e) == 0 {
907 return "no validation errors"
908 }
909 if len(e) == 1 {
910 return e[0].Error()
911 }
912
913 var messages []string
914 for _, err := range e {
915 messages = append(messages, err.Error())
916 }
917 return fmt.Sprintf("multiple validation errors: %s", strings.Join(messages, "; "))
918}
919
920// HasErrors returns true if there are any validation errors
921func (e ValidationErrors) HasErrors() bool {
922 return len(e) > 0
923}
924
925// Add appends a new validation error
926func (e *ValidationErrors) Add(field, message string) {
927 *e = append(*e, ValidationError{Field: field, Message: message})
928}
929
930// Validate performs comprehensive validation of the configuration
931func (c *Config) Validate() error {
932 var errors ValidationErrors
933
934 // Validate providers
935 c.validateProviders(&errors)
936
937 // Validate models
938 c.validateModels(&errors)
939
940 // Validate agents
941 c.validateAgents(&errors)
942
943 // Validate options
944 c.validateOptions(&errors)
945
946 // Validate MCP configurations
947 c.validateMCPs(&errors)
948
949 // Validate LSP configurations
950 c.validateLSPs(&errors)
951
952 // Validate cross-references
953 c.validateCrossReferences(&errors)
954
955 // Validate completeness
956 c.validateCompleteness(&errors)
957
958 if errors.HasErrors() {
959 return errors
960 }
961
962 return nil
963}
964
// validateProviders checks every entry of c.Providers and appends one
// ValidationError per problem to errs. It initializes a nil c.Providers to an
// empty map.
//
// For each provider it checks, in order:
//   - the API key (except for Bedrock and VertexAI, which may authenticate
//     by other means),
//   - the provider type,
//   - custom-provider rules (base URL set, OpenAI type),
//   - each model's ID, name, token limits and costs,
//   - that the default large/small models exist,
//   - the rules in validateProviderSpecific.
//
// The parameter is named errs so that it does not shadow the errors package.
func (c *Config) validateProviders(errs *ValidationErrors) {
	if c.Providers == nil {
		c.Providers = make(map[provider.InferenceProvider]ProviderConfig)
	}

	knownProviders := provider.KnownProviders()
	validTypes := []provider.Type{
		provider.TypeOpenAI,
		provider.TypeAnthropic,
		provider.TypeGemini,
		provider.TypeAzure,
		provider.TypeBedrock,
		provider.TypeVertexAI,
		provider.TypeXAI,
	}

	for providerID, providerConfig := range c.Providers {
		fieldPrefix := fmt.Sprintf("providers.%s", providerID)

		// Bedrock and VertexAI may use other auth methods than an API key.
		if !providerConfig.Disabled && providerConfig.APIKey == "" &&
			providerID != provider.InferenceProviderBedrock &&
			providerID != provider.InferenceProviderVertexAI {
			errs.Add(fieldPrefix+".api_key", "API key is required for non-disabled providers")
		}

		if !slices.Contains(validTypes, providerConfig.ProviderType) {
			errs.Add(fieldPrefix+".provider_type", fmt.Sprintf("invalid provider type: %s", providerConfig.ProviderType))
		}

		// Custom (unknown) providers must be OpenAI-compatible with an explicit endpoint.
		if !slices.Contains(knownProviders, providerID) {
			if providerConfig.BaseURL == "" {
				errs.Add(fieldPrefix+".base_url", "BaseURL is required for custom providers")
			}
			if providerConfig.ProviderType != provider.TypeOpenAI {
				errs.Add(fieldPrefix+".provider_type", "custom providers currently only support OpenAI type")
			}
		}

		modelIDs := make(map[string]bool, len(providerConfig.Models))
		for i, model := range providerConfig.Models {
			modelFieldPrefix := fmt.Sprintf("%s.models[%d]", fieldPrefix, i)

			// An empty ID is reported once as missing. It is no longer also
			// flagged as a "duplicate" of an earlier empty ID.
			if model.ID == "" {
				errs.Add(modelFieldPrefix+".id", "model ID is required")
			} else {
				if modelIDs[model.ID] {
					errs.Add(modelFieldPrefix+".id", fmt.Sprintf("duplicate model ID: %s", model.ID))
				}
				modelIDs[model.ID] = true
			}
			if model.Name == "" {
				errs.Add(modelFieldPrefix+".name", "model name is required")
			}
			if model.ContextWindow <= 0 {
				errs.Add(modelFieldPrefix+".context_window", "context window must be positive")
			}
			if model.DefaultMaxTokens <= 0 {
				errs.Add(modelFieldPrefix+".default_max_tokens", "default max tokens must be positive")
			}
			if model.DefaultMaxTokens > model.ContextWindow {
				errs.Add(modelFieldPrefix+".default_max_tokens", "default max tokens cannot exceed context window")
			}

			if model.CostPer1MIn < 0 {
				errs.Add(modelFieldPrefix+".cost_per_1m_in", "cost per 1M input tokens cannot be negative")
			}
			if model.CostPer1MOut < 0 {
				errs.Add(modelFieldPrefix+".cost_per_1m_out", "cost per 1M output tokens cannot be negative")
			}
			if model.CostPer1MInCached < 0 {
				errs.Add(modelFieldPrefix+".cost_per_1m_in_cached", "cached cost per 1M input tokens cannot be negative")
			}
			if model.CostPer1MOutCached < 0 {
				errs.Add(modelFieldPrefix+".cost_per_1m_out_cached", "cached cost per 1M output tokens cannot be negative")
			}
		}

		// Default model references must point at models of this provider.
		if providerConfig.DefaultLargeModel != "" && !modelIDs[providerConfig.DefaultLargeModel] {
			errs.Add(fieldPrefix+".default_large_model", fmt.Sprintf("default large model '%s' not found in provider models", providerConfig.DefaultLargeModel))
		}
		if providerConfig.DefaultSmallModel != "" && !modelIDs[providerConfig.DefaultSmallModel] {
			errs.Add(fieldPrefix+".default_small_model", fmt.Sprintf("default small model '%s' not found in provider models", providerConfig.DefaultSmallModel))
		}

		c.validateProviderSpecific(providerID, providerConfig, errs)
	}
}
1071
// validateProviderSpecific applies the extra rules of individual providers to
// an enabled provider; disabled providers are skipped.
//   - VertexAI needs extra_params with non-empty "project" and "location".
//   - Bedrock needs a non-empty "region" extra param and AWS settings in the
//     environment (hasAWSCredentials).
func (c *Config) validateProviderSpecific(providerID provider.InferenceProvider, providerConfig ProviderConfig, errs *ValidationErrors) {
	if providerConfig.Disabled {
		return
	}
	fieldPrefix := fmt.Sprintf("providers.%s", providerID)
	params := providerConfig.ExtraParams

	switch providerID {
	case provider.InferenceProviderVertexAI:
		if params == nil {
			errs.Add(fieldPrefix+".extra_params", "VertexAI requires extra_params configuration")
			return
		}
		if params["project"] == "" {
			errs.Add(fieldPrefix+".extra_params.project", "VertexAI requires project parameter")
		}
		if params["location"] == "" {
			errs.Add(fieldPrefix+".extra_params.location", "VertexAI requires location parameter")
		}
	case provider.InferenceProviderBedrock:
		// Indexing a nil map yields "", which covers the missing-map case.
		if params["region"] == "" {
			errs.Add(fieldPrefix+".extra_params.region", "Bedrock requires region parameter")
		}
		if !hasAWSCredentials() {
			errs.Add(fieldPrefix, "Bedrock requires AWS credentials in environment")
		}
	}
}
1102
1103// validateModels validates preferred model configurations
1104func (c *Config) validateModels(errors *ValidationErrors) {
1105 // Validate large model
1106 if c.Models.Large.ModelID != "" || c.Models.Large.Provider != "" {
1107 if c.Models.Large.ModelID == "" {
1108 errors.Add("models.large.model_id", "large model ID is required when provider is set")
1109 }
1110 if c.Models.Large.Provider == "" {
1111 errors.Add("models.large.provider", "large model provider is required when model ID is set")
1112 }
1113
1114 // Check if provider exists and is not disabled
1115 if providerConfig, exists := c.Providers[c.Models.Large.Provider]; exists {
1116 if providerConfig.Disabled {
1117 errors.Add("models.large.provider", "large model provider is disabled")
1118 }
1119
1120 // Check if model exists in provider
1121 modelExists := false
1122 for _, model := range providerConfig.Models {
1123 if model.ID == c.Models.Large.ModelID {
1124 modelExists = true
1125 break
1126 }
1127 }
1128 if !modelExists {
1129 errors.Add("models.large.model_id", fmt.Sprintf("large model '%s' not found in provider '%s'", c.Models.Large.ModelID, c.Models.Large.Provider))
1130 }
1131 } else {
1132 errors.Add("models.large.provider", fmt.Sprintf("large model provider '%s' not found", c.Models.Large.Provider))
1133 }
1134 }
1135
1136 // Validate small model
1137 if c.Models.Small.ModelID != "" || c.Models.Small.Provider != "" {
1138 if c.Models.Small.ModelID == "" {
1139 errors.Add("models.small.model_id", "small model ID is required when provider is set")
1140 }
1141 if c.Models.Small.Provider == "" {
1142 errors.Add("models.small.provider", "small model provider is required when model ID is set")
1143 }
1144
1145 // Check if provider exists and is not disabled
1146 if providerConfig, exists := c.Providers[c.Models.Small.Provider]; exists {
1147 if providerConfig.Disabled {
1148 errors.Add("models.small.provider", "small model provider is disabled")
1149 }
1150
1151 // Check if model exists in provider
1152 modelExists := false
1153 for _, model := range providerConfig.Models {
1154 if model.ID == c.Models.Small.ModelID {
1155 modelExists = true
1156 break
1157 }
1158 }
1159 if !modelExists {
1160 errors.Add("models.small.model_id", fmt.Sprintf("small model '%s' not found in provider '%s'", c.Models.Small.ModelID, c.Models.Small.Provider))
1161 }
1162 } else {
1163 errors.Add("models.small.provider", fmt.Sprintf("small model provider '%s' not found", c.Models.Small.Provider))
1164 }
1165 }
1166}
1167
1168// validateAgents validates agent configurations
1169func (c *Config) validateAgents(errors *ValidationErrors) {
1170 if c.Agents == nil {
1171 c.Agents = make(map[AgentID]Agent)
1172 }
1173
1174 validTools := []string{
1175 "bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent",
1176 }
1177
1178 for agentID, agent := range c.Agents {
1179 fieldPrefix := fmt.Sprintf("agents.%s", agentID)
1180
1181 // Validate agent ID consistency
1182 if agent.ID != agentID {
1183 errors.Add(fieldPrefix+".id", fmt.Sprintf("agent ID mismatch: expected '%s', got '%s'", agentID, agent.ID))
1184 }
1185
1186 // Validate required fields
1187 if agent.ID == "" {
1188 errors.Add(fieldPrefix+".id", "agent ID is required")
1189 }
1190 if agent.Name == "" {
1191 errors.Add(fieldPrefix+".name", "agent name is required")
1192 }
1193
1194 // Validate model type
1195 if agent.Model != LargeModel && agent.Model != SmallModel {
1196 errors.Add(fieldPrefix+".model", fmt.Sprintf("invalid model type: %s (must be 'large' or 'small')", agent.Model))
1197 }
1198
1199 // Validate allowed tools
1200 if agent.AllowedTools != nil {
1201 for i, tool := range agent.AllowedTools {
1202 validTool := slices.Contains(validTools, tool)
1203 if !validTool {
1204 errors.Add(fmt.Sprintf("%s.allowed_tools[%d]", fieldPrefix, i), fmt.Sprintf("unknown tool: %s", tool))
1205 }
1206 }
1207 }
1208
1209 // Validate MCP references
1210 if agent.AllowedMCP != nil {
1211 for mcpName := range agent.AllowedMCP {
1212 if _, exists := c.MCP[mcpName]; !exists {
1213 errors.Add(fieldPrefix+".allowed_mcp", fmt.Sprintf("referenced MCP '%s' not found", mcpName))
1214 }
1215 }
1216 }
1217
1218 // Validate LSP references
1219 if agent.AllowedLSP != nil {
1220 for _, lspName := range agent.AllowedLSP {
1221 if _, exists := c.LSP[lspName]; !exists {
1222 errors.Add(fieldPrefix+".allowed_lsp", fmt.Sprintf("referenced LSP '%s' not found", lspName))
1223 }
1224 }
1225 }
1226
1227 // Validate context paths (basic path validation)
1228 for i, contextPath := range agent.ContextPaths {
1229 if contextPath == "" {
1230 errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path cannot be empty")
1231 }
1232 // Check for invalid characters in path
1233 if strings.Contains(contextPath, "\x00") {
1234 errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path contains invalid characters")
1235 }
1236 }
1237
1238 // Validate known agents maintain their core properties
1239 if agentID == AgentCoder {
1240 if agent.Name != "Coder" {
1241 errors.Add(fieldPrefix+".name", "coder agent name cannot be changed")
1242 }
1243 if agent.Description != "An agent that helps with executing coding tasks." {
1244 errors.Add(fieldPrefix+".description", "coder agent description cannot be changed")
1245 }
1246 } else if agentID == AgentTask {
1247 if agent.Name != "Task" {
1248 errors.Add(fieldPrefix+".name", "task agent name cannot be changed")
1249 }
1250 if agent.Description != "An agent that helps with searching for context and finding implementation details." {
1251 errors.Add(fieldPrefix+".description", "task agent description cannot be changed")
1252 }
1253 expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"}
1254 if agent.AllowedTools != nil && !slices.Equal(agent.AllowedTools, expectedTools) {
1255 errors.Add(fieldPrefix+".allowed_tools", "task agent allowed tools cannot be changed")
1256 }
1257 }
1258 }
1259}
1260
1261// validateOptions validates configuration options
1262func (c *Config) validateOptions(errors *ValidationErrors) {
1263 // Validate data directory
1264 if c.Options.DataDirectory == "" {
1265 errors.Add("options.data_directory", "data directory is required")
1266 }
1267
1268 // Validate context paths
1269 for i, contextPath := range c.Options.ContextPaths {
1270 if contextPath == "" {
1271 errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path cannot be empty")
1272 }
1273 if strings.Contains(contextPath, "\x00") {
1274 errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path contains invalid characters")
1275 }
1276 }
1277}
1278
1279// validateMCPs validates MCP configurations
1280func (c *Config) validateMCPs(errors *ValidationErrors) {
1281 if c.MCP == nil {
1282 c.MCP = make(map[string]MCP)
1283 }
1284
1285 for mcpName, mcpConfig := range c.MCP {
1286 fieldPrefix := fmt.Sprintf("mcp.%s", mcpName)
1287
1288 // Validate MCP type
1289 if mcpConfig.Type != MCPStdio && mcpConfig.Type != MCPSse {
1290 errors.Add(fieldPrefix+".type", fmt.Sprintf("invalid MCP type: %s (must be 'stdio' or 'sse')", mcpConfig.Type))
1291 }
1292
1293 // Validate based on type
1294 if mcpConfig.Type == MCPStdio {
1295 if mcpConfig.Command == "" {
1296 errors.Add(fieldPrefix+".command", "command is required for stdio MCP")
1297 }
1298 } else if mcpConfig.Type == MCPSse {
1299 if mcpConfig.URL == "" {
1300 errors.Add(fieldPrefix+".url", "URL is required for SSE MCP")
1301 }
1302 }
1303 }
1304}
1305
1306// validateLSPs validates LSP configurations
1307func (c *Config) validateLSPs(errors *ValidationErrors) {
1308 if c.LSP == nil {
1309 c.LSP = make(map[string]LSPConfig)
1310 }
1311
1312 for lspName, lspConfig := range c.LSP {
1313 fieldPrefix := fmt.Sprintf("lsp.%s", lspName)
1314
1315 if lspConfig.Command == "" {
1316 errors.Add(fieldPrefix+".command", "command is required for LSP")
1317 }
1318 }
1319}
1320
1321// validateCrossReferences validates cross-references between different config sections
1322func (c *Config) validateCrossReferences(errors *ValidationErrors) {
1323 // Validate that agents can use their assigned model types
1324 for agentID, agent := range c.Agents {
1325 fieldPrefix := fmt.Sprintf("agents.%s", agentID)
1326
1327 var preferredModel PreferredModel
1328 switch agent.Model {
1329 case LargeModel:
1330 preferredModel = c.Models.Large
1331 case SmallModel:
1332 preferredModel = c.Models.Small
1333 }
1334
1335 if preferredModel.Provider != "" {
1336 if providerConfig, exists := c.Providers[preferredModel.Provider]; exists {
1337 if providerConfig.Disabled {
1338 errors.Add(fieldPrefix+".model", fmt.Sprintf("agent cannot use model type '%s' because provider '%s' is disabled", agent.Model, preferredModel.Provider))
1339 }
1340 }
1341 }
1342 }
1343}
1344
1345// validateCompleteness validates that the configuration is complete and usable
1346func (c *Config) validateCompleteness(errors *ValidationErrors) {
1347 // Check for at least one valid, non-disabled provider
1348 hasValidProvider := false
1349 for _, providerConfig := range c.Providers {
1350 if !providerConfig.Disabled {
1351 hasValidProvider = true
1352 break
1353 }
1354 }
1355 if !hasValidProvider {
1356 errors.Add("providers", "at least one non-disabled provider is required")
1357 }
1358
1359 // Check that default agents exist
1360 if _, exists := c.Agents[AgentCoder]; !exists {
1361 errors.Add("agents", "coder agent is required")
1362 }
1363 if _, exists := c.Agents[AgentTask]; !exists {
1364 errors.Add("agents", "task agent is required")
1365 }
1366
1367 // Check that preferred models are set if providers exist
1368 if hasValidProvider {
1369 if c.Models.Large.ModelID == "" || c.Models.Large.Provider == "" {
1370 errors.Add("models.large", "large preferred model must be configured when providers are available")
1371 }
1372 if c.Models.Small.ModelID == "" || c.Models.Small.Provider == "" {
1373 errors.Add("models.small", "small preferred model must be configured when providers are available")
1374 }
1375 }
1376}