1package config
2
3import (
4 "encoding/json"
5 "errors"
6 "fmt"
7 "log/slog"
8 "maps"
9 "os"
10 "path/filepath"
11 "slices"
12 "strings"
13 "sync"
14
15 "github.com/charmbracelet/crush/internal/fur/provider"
16 "github.com/charmbracelet/crush/internal/logging"
17)
18
// Built-in defaults for the config package.
const (
	// defaultDataDirectory is the data directory used until a config file
	// overrides Options.DataDirectory.
	defaultDataDirectory = ".crush"
	defaultLogLevel      = "info"
	appName              = "crush"
)

// MaxTokensFallbackDefault is a fallback max-token value exported for other
// packages; presumably used when a model provides no default — confirm at
// the call sites.
const MaxTokensFallbackDefault = 4096
26
// defaultContextPaths is the baseline list of context paths copied into
// Options.ContextPaths; config files only ever append to it. Entries ending
// in "/" name directories. Order is significant and preserved.
var defaultContextPaths = []string{
	// GitHub Copilot and Cursor instruction files.
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",

	// Claude and Gemini instruction files.
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",

	// Crush's own instruction files, in each accepted capitalization.
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
42
// AgentID is the key of an agent in Config.Agents.
type AgentID string

// IDs of the two built-in agents that loadConfig always creates.
const (
	AgentCoder AgentID = "coder"
	AgentTask  AgentID = "task"
)

// ModelType selects which preferred model slot (Config.Models.Large or
// Config.Models.Small) an agent uses.
type ModelType string

// The available preferred-model slots.
const (
	LargeModel ModelType = "large"
	SmallModel ModelType = "small"
)
56
// Model describes one model offered by a provider: identity, pricing and
// token limits. Field order and JSON tags define the serialized form.
type Model struct {
	// ID identifies the model within its provider; PreferredModel.ModelID and
	// the provider's DefaultLarge/SmallModel refer to it.
	ID string `json:"id"`
	// Name is the model's name; note it is serialized under the "model" key.
	Name string `json:"model"`
	// Prices per one million tokens (input, output, and cached variants);
	// validation rejects negative values.
	CostPer1MIn        float64 `json:"cost_per_1m_in"`
	CostPer1MOut       float64 `json:"cost_per_1m_out"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
	// ContextWindow must be positive and bounds DefaultMaxTokens.
	ContextWindow int64 `json:"context_window"`
	// DefaultMaxTokens must be positive and must not exceed ContextWindow.
	DefaultMaxTokens int64 `json:"default_max_tokens"`
	CanReason        bool  `json:"can_reason"`
	// ReasoningEffort and HasReasoningEffort are populated together from the
	// provider catalog, only when the catalog supplies a default effort.
	ReasoningEffort    string `json:"reasoning_effort"`
	HasReasoningEffort bool   `json:"has_reasoning_effort"`
	// SupportsImages is serialized under the "supports_attachments" key.
	SupportsImages bool `json:"supports_attachments"`
}
71
// VertexAIOptions groups Vertex AI connection settings.
// NOTE(review): not referenced in the visible part of this file; the Vertex AI
// project/location are carried in ProviderConfig.ExtraParams here — confirm
// whether this type is still used elsewhere.
type VertexAIOptions struct {
	APIKey   string `json:"api_key,omitempty"`
	Project  string `json:"project,omitempty"`
	Location string `json:"location,omitempty"`
}
77
// ProviderConfig is the configuration of a single inference provider, keyed
// by its ID in Config.Providers.
type ProviderConfig struct {
	// ID should match the provider's key in Config.Providers; preferred-model
	// lookups use it as that key.
	ID provider.InferenceProvider `json:"id"`
	// BaseURL, ProviderType, ExtraHeaders and ExtraParams can only be
	// overridden from config files for providers that are not built in.
	BaseURL      string        `json:"base_url,omitempty"`
	ProviderType provider.Type `json:"provider_type"`
	APIKey       string        `json:"api_key,omitempty"`
	// Disabled can be switched on by a config file but never back off.
	Disabled     bool              `json:"disabled"`
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
	// Provider-specific parameters, e.g. "project"/"location" for Vertex AI
	// or "region" for Bedrock.
	ExtraParams map[string]string `json:"extra_params,omitempty"`

	// IDs of the models chosen by default for the large and small slots; when
	// set they must appear in Models.
	DefaultLargeModel string `json:"default_large_model,omitempty"`
	DefaultSmallModel string `json:"default_small_model,omitempty"`

	// Models available from this provider; IDs must be unique.
	Models []Model `json:"models,omitempty"`
}
93
// Agent describes one agent entry in Config.Agents.
type Agent struct {
	ID          AgentID `json:"id"`
	Name        string  `json:"name"`
	Description string  `json:"description,omitempty"`
	// Disabled marks the agent as switched off.
	Disabled bool `json:"disabled"`

	// Model selects which preferred model slot the agent uses.
	Model ModelType `json:"model"`

	// The available tools for the agent
	// if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools"`

	// AllowedMCP maps an MCP name to the tools from that MCP the agent may use;
	// a nil tool list allows all of that MCP's tools.
	// NOTE(review): a nil map presumably means "all MCPs available", while the
	// built-in Task agent sets an empty (non-nil) map to mean "no MCPs" —
	// confirm against the consumer.
	AllowedMCP map[string][]string `json:"allowed_mcp"`

	// The list of LSPs that this agent can use
	// if this is nil, all LSPs are available
	AllowedLSP []string `json:"allowed_lsp"`

	// ContextPaths for this agent; entries from config files are appended to
	// the global Options.ContextPaths rather than replacing them.
	ContextPaths []string `json:"context_paths"`
}
120
// MCPType is the transport used to reach an MCP server.
type MCPType string

const (
	MCPStdio MCPType = "stdio"
	MCPSse   MCPType = "sse"
)

// MCP configures one MCP server, keyed by name in Config.MCP. A later config
// layer replaces an entry with the same name entirely.
// NOTE(review): Command/Env/Args presumably apply to stdio servers and
// URL/Headers to sse servers — confirm against the MCP client.
type MCP struct {
	Command string            `json:"command"`
	Env     []string          `json:"env"`
	Args    []string          `json:"args"`
	Type    MCPType           `json:"type"`
	URL     string            `json:"url"`
	Headers map[string]string `json:"headers"`
}
136
// LSPConfig configures one language server, keyed by name in Config.LSP.
type LSPConfig struct {
	// Disabled turns the server off. It is serialized as "disabled": the
	// previous "enabled" tag inverted the meaning, so `"enabled": true`
	// actually disabled the server.
	Disabled bool     `json:"disabled"`
	Command  string   `json:"command"`
	Args     []string `json:"args"`
	// Options is passed through as-is; its shape depends on the server.
	Options any `json:"options"`
}
143
// TUIOptions holds terminal-UI settings.
type TUIOptions struct {
	CompactMode bool `json:"compact_mode"`
	// Here we can add themes later or any TUI related options
}

// Options holds miscellaneous settings. Config files are merged in by
// appending ContextPaths, replacing DataDirectory when non-empty, and only
// ever switching boolean flags on.
type Options struct {
	// ContextPaths starts from defaultContextPaths; config files append to it.
	ContextPaths []string   `json:"context_paths"`
	TUI          TUIOptions `json:"tui"`
	// Debug is seeded from Init's debug argument.
	Debug                bool `json:"debug"`
	DebugLSP             bool `json:"debug_lsp"`
	DisableAutoSummarize bool `json:"disable_auto_summarize"`
	// Relative to the cwd
	DataDirectory string `json:"data_directory"`
}
158
// PreferredModel selects a concrete model (by ID) from a configured provider,
// with optional per-model overrides.
type PreferredModel struct {
	ModelID  string                     `json:"model_id"`
	Provider provider.InferenceProvider `json:"provider"`
	// ReasoningEffort overrides the default reasoning effort for this model
	ReasoningEffort string `json:"reasoning_effort,omitempty"`
	// MaxTokens overrides the default max tokens for this model
	MaxTokens int64 `json:"max_tokens,omitempty"`

	// Think indicates if the model should think, only applicable for anthropic reasoning models
	Think bool `json:"think,omitempty"`
}

// PreferredModels holds the model chosen for each slot. A config file entry
// only takes effect when both its ModelID and Provider are set.
type PreferredModels struct {
	Large PreferredModel `json:"large"`
	Small PreferredModel `json:"small"`
}
175
// Config is the fully merged application configuration: environment-derived
// defaults, then the global config file, then the project's crush.json.
type Config struct {
	// Models holds the preferred large and small models.
	Models PreferredModels `json:"models"`
	// List of configured providers
	Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"`

	// List of configured agents
	Agents map[AgentID]Agent `json:"agents,omitempty"`

	// List of configured MCPs
	MCP map[string]MCP `json:"mcp,omitempty"`

	// List of configured LSPs
	LSP map[string]LSPConfig `json:"lsp,omitempty"`

	// Miscellaneous options
	Options Options `json:"options"`
}
193
194var (
195 instance *Config // The single instance of the Singleton
196 cwd string
197 once sync.Once // Ensures the initialization happens only once
198
199)
200
201func loadConfig(cwd string, debug bool) (*Config, error) {
202 // First read the global config file
203 cfgPath := ConfigPath()
204
205 cfg := defaultConfigBasedOnEnv()
206 cfg.Options.Debug = debug
207 defaultLevel := slog.LevelInfo
208 if cfg.Options.Debug {
209 defaultLevel = slog.LevelDebug
210 }
211 if os.Getenv("CRUSH_DEV_DEBUG") == "true" {
212 loggingFile := fmt.Sprintf("%s/%s", cfg.Options.DataDirectory, "debug.log")
213
214 // if file does not exist create it
215 if _, err := os.Stat(loggingFile); os.IsNotExist(err) {
216 if err := os.MkdirAll(cfg.Options.DataDirectory, 0o755); err != nil {
217 return cfg, fmt.Errorf("failed to create directory: %w", err)
218 }
219 if _, err := os.Create(loggingFile); err != nil {
220 return cfg, fmt.Errorf("failed to create log file: %w", err)
221 }
222 }
223
224 sloggingFileWriter, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666)
225 if err != nil {
226 return cfg, fmt.Errorf("failed to open log file: %w", err)
227 }
228 // Configure logger
229 logger := slog.New(slog.NewTextHandler(sloggingFileWriter, &slog.HandlerOptions{
230 Level: defaultLevel,
231 }))
232 slog.SetDefault(logger)
233 } else {
234 // Configure logger
235 logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{
236 Level: defaultLevel,
237 }))
238 slog.SetDefault(logger)
239 }
240 var globalCfg *Config
241 if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
242 // some other error occurred while checking the file
243 return nil, err
244 } else if err == nil {
245 // config file exists, read it
246 file, err := os.ReadFile(cfgPath)
247 if err != nil {
248 return nil, err
249 }
250 globalCfg = &Config{}
251 if err := json.Unmarshal(file, globalCfg); err != nil {
252 return nil, err
253 }
254 } else {
255 // config file does not exist, create a new one
256 globalCfg = &Config{}
257 }
258
259 var localConfig *Config
260 // Global config loaded, now read the local config file
261 localConfigPath := filepath.Join(cwd, "crush.json")
262 if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
263 // some other error occurred while checking the file
264 return nil, err
265 } else if err == nil {
266 // local config file exists, read it
267 file, err := os.ReadFile(localConfigPath)
268 if err != nil {
269 return nil, err
270 }
271 localConfig = &Config{}
272 if err := json.Unmarshal(file, localConfig); err != nil {
273 return nil, err
274 }
275 }
276
277 // merge options
278 mergeOptions(cfg, globalCfg, localConfig)
279
280 mergeProviderConfigs(cfg, globalCfg, localConfig)
281 // no providers found the app is not initialized yet
282 if len(cfg.Providers) == 0 {
283 return cfg, nil
284 }
285 preferredProvider := getPreferredProvider(cfg.Providers)
286 if preferredProvider != nil {
287 cfg.Models = PreferredModels{
288 Large: PreferredModel{
289 ModelID: preferredProvider.DefaultLargeModel,
290 Provider: preferredProvider.ID,
291 },
292 Small: PreferredModel{
293 ModelID: preferredProvider.DefaultSmallModel,
294 Provider: preferredProvider.ID,
295 },
296 }
297 } else {
298 // No valid providers found, set empty models
299 cfg.Models = PreferredModels{}
300 }
301
302 mergeModels(cfg, globalCfg, localConfig)
303
304 agents := map[AgentID]Agent{
305 AgentCoder: {
306 ID: AgentCoder,
307 Name: "Coder",
308 Description: "An agent that helps with executing coding tasks.",
309 Model: LargeModel,
310 ContextPaths: cfg.Options.ContextPaths,
311 // All tools allowed
312 },
313 AgentTask: {
314 ID: AgentTask,
315 Name: "Task",
316 Description: "An agent that helps with searching for context and finding implementation details.",
317 Model: LargeModel,
318 ContextPaths: cfg.Options.ContextPaths,
319 AllowedTools: []string{
320 "glob",
321 "grep",
322 "ls",
323 "sourcegraph",
324 "view",
325 },
326 // NO MCPs or LSPs by default
327 AllowedMCP: map[string][]string{},
328 AllowedLSP: []string{},
329 },
330 }
331 cfg.Agents = agents
332 mergeAgents(cfg, globalCfg, localConfig)
333 mergeMCPs(cfg, globalCfg, localConfig)
334 mergeLSPs(cfg, globalCfg, localConfig)
335
336 // Validate the final configuration
337 if err := cfg.Validate(); err != nil {
338 return cfg, fmt.Errorf("configuration validation failed: %w", err)
339 }
340
341 return cfg, nil
342}
343
344func Init(workingDir string, debug bool) (*Config, error) {
345 var err error
346 once.Do(func() {
347 cwd = workingDir
348 instance, err = loadConfig(cwd, debug)
349 if err != nil {
350 logging.Error("Failed to load config", "error", err)
351 }
352 })
353
354 return instance, err
355}
356
357func Get() *Config {
358 if instance == nil {
359 // TODO: Handle this better
360 panic("Config not initialized. Call InitConfig first.")
361 }
362 return instance
363}
364
365func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
366 providers := Providers()
367 for _, p := range providers {
368 if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
369 return &providerConfig
370 }
371 }
372 // if none found return the first configured provider
373 for _, providerConfig := range configuredProviders {
374 if !providerConfig.Disabled {
375 return &providerConfig
376 }
377 }
378 return nil
379}
380
381func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
382 if other.APIKey != "" {
383 base.APIKey = other.APIKey
384 }
385 // Only change these options if the provider is not a known provider
386 if !slices.Contains(provider.KnownProviders(), p) {
387 if other.BaseURL != "" {
388 base.BaseURL = other.BaseURL
389 }
390 if other.ProviderType != "" {
391 base.ProviderType = other.ProviderType
392 }
393 if len(other.ExtraHeaders) > 0 {
394 if base.ExtraHeaders == nil {
395 base.ExtraHeaders = make(map[string]string)
396 }
397 maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
398 }
399 if len(other.ExtraParams) > 0 {
400 if base.ExtraParams == nil {
401 base.ExtraParams = make(map[string]string)
402 }
403 maps.Copy(base.ExtraParams, other.ExtraParams)
404 }
405 }
406
407 if other.Disabled {
408 base.Disabled = other.Disabled
409 }
410
411 if other.DefaultLargeModel != "" {
412 base.DefaultLargeModel = other.DefaultLargeModel
413 }
414 // Add new models if they don't exist
415 if other.Models != nil {
416 for _, model := range other.Models {
417 // check if the model already exists
418 exists := false
419 for _, existingModel := range base.Models {
420 if existingModel.ID == model.ID {
421 exists = true
422 break
423 }
424 }
425 if !exists {
426 base.Models = append(base.Models, model)
427 }
428 }
429 }
430
431 return base
432}
433
434func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
435 if !slices.Contains(provider.KnownProviders(), p) {
436 if providerConfig.ProviderType != provider.TypeOpenAI {
437 return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
438 }
439 if providerConfig.BaseURL == "" {
440 return errors.New("base URL must be set for custom providers")
441 }
442 if providerConfig.APIKey == "" {
443 return errors.New("API key must be set for custom providers")
444 }
445 }
446 return nil
447}
448
449func mergeModels(base, global, local *Config) {
450 for _, cfg := range []*Config{global, local} {
451 if cfg == nil {
452 continue
453 }
454 if cfg.Models.Large.ModelID != "" && cfg.Models.Large.Provider != "" {
455 base.Models.Large = cfg.Models.Large
456 }
457
458 if cfg.Models.Small.ModelID != "" && cfg.Models.Small.Provider != "" {
459 base.Models.Small = cfg.Models.Small
460 }
461 }
462}
463
464func mergeOptions(base, global, local *Config) {
465 for _, cfg := range []*Config{global, local} {
466 if cfg == nil {
467 continue
468 }
469 baseOptions := base.Options
470 other := cfg.Options
471 if len(other.ContextPaths) > 0 {
472 baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
473 }
474
475 if other.TUI.CompactMode {
476 baseOptions.TUI.CompactMode = other.TUI.CompactMode
477 }
478
479 if other.Debug {
480 baseOptions.Debug = other.Debug
481 }
482
483 if other.DebugLSP {
484 baseOptions.DebugLSP = other.DebugLSP
485 }
486
487 if other.DisableAutoSummarize {
488 baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
489 }
490
491 if other.DataDirectory != "" {
492 baseOptions.DataDirectory = other.DataDirectory
493 }
494 base.Options = baseOptions
495 }
496}
497
498func mergeAgents(base, global, local *Config) {
499 for _, cfg := range []*Config{global, local} {
500 if cfg == nil {
501 continue
502 }
503 for agentID, newAgent := range cfg.Agents {
504 if _, ok := base.Agents[agentID]; !ok {
505 // New agent - apply defaults
506 newAgent.ID = agentID // Ensure the ID is set correctly
507 if newAgent.Model == "" {
508 newAgent.Model = LargeModel // Default model type
509 }
510 // Context paths are always additive - start with global, then add custom
511 if len(newAgent.ContextPaths) > 0 {
512 newAgent.ContextPaths = append(base.Options.ContextPaths, newAgent.ContextPaths...)
513 } else {
514 newAgent.ContextPaths = base.Options.ContextPaths // Use global context paths only
515 }
516 base.Agents[agentID] = newAgent
517 } else {
518 baseAgent := base.Agents[agentID]
519
520 // Special handling for known agents - only allow model changes
521 if agentID == AgentCoder || agentID == AgentTask {
522 if newAgent.Model != "" {
523 baseAgent.Model = newAgent.Model
524 }
525 // For known agents, only allow MCP and LSP configuration
526 if newAgent.AllowedMCP != nil {
527 baseAgent.AllowedMCP = newAgent.AllowedMCP
528 }
529 if newAgent.AllowedLSP != nil {
530 baseAgent.AllowedLSP = newAgent.AllowedLSP
531 }
532 // Context paths are additive for known agents too
533 if len(newAgent.ContextPaths) > 0 {
534 baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...)
535 }
536 } else {
537 // Custom agents - allow full merging
538 if newAgent.Name != "" {
539 baseAgent.Name = newAgent.Name
540 }
541 if newAgent.Description != "" {
542 baseAgent.Description = newAgent.Description
543 }
544 if newAgent.Model != "" {
545 baseAgent.Model = newAgent.Model
546 } else if baseAgent.Model == "" {
547 baseAgent.Model = LargeModel // Default fallback
548 }
549
550 // Boolean fields - always update (including false values)
551 baseAgent.Disabled = newAgent.Disabled
552
553 // Slice/Map fields - update if provided (including empty slices/maps)
554 if newAgent.AllowedTools != nil {
555 baseAgent.AllowedTools = newAgent.AllowedTools
556 }
557 if newAgent.AllowedMCP != nil {
558 baseAgent.AllowedMCP = newAgent.AllowedMCP
559 }
560 if newAgent.AllowedLSP != nil {
561 baseAgent.AllowedLSP = newAgent.AllowedLSP
562 }
563 // Context paths are additive for custom agents too
564 if len(newAgent.ContextPaths) > 0 {
565 baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...)
566 }
567 }
568
569 base.Agents[agentID] = baseAgent
570 }
571 }
572 }
573}
574
575func mergeMCPs(base, global, local *Config) {
576 for _, cfg := range []*Config{global, local} {
577 if cfg == nil {
578 continue
579 }
580 maps.Copy(base.MCP, cfg.MCP)
581 }
582}
583
584func mergeLSPs(base, global, local *Config) {
585 for _, cfg := range []*Config{global, local} {
586 if cfg == nil {
587 continue
588 }
589 maps.Copy(base.LSP, cfg.LSP)
590 }
591}
592
593func mergeProviderConfigs(base, global, local *Config) {
594 for _, cfg := range []*Config{global, local} {
595 if cfg == nil {
596 continue
597 }
598 for providerName, p := range cfg.Providers {
599 if _, ok := base.Providers[providerName]; !ok {
600 base.Providers[providerName] = p
601 } else {
602 base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], p)
603 }
604 }
605 }
606
607 finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
608 for providerName, providerConfig := range base.Providers {
609 err := validateProvider(providerName, providerConfig)
610 if err != nil {
611 logging.Warn("Skipping provider", "name", providerName, "error", err)
612 continue // Skip invalid providers
613 }
614 finalProviders[providerName] = providerConfig
615 }
616 base.Providers = finalProviders
617}
618
619func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig {
620 switch providerId {
621 case provider.InferenceProviderAnthropic:
622 return ProviderConfig{
623 ID: providerId,
624 ProviderType: provider.TypeAnthropic,
625 }
626 case provider.InferenceProviderOpenAI:
627 return ProviderConfig{
628 ID: providerId,
629 ProviderType: provider.TypeOpenAI,
630 }
631 case provider.InferenceProviderGemini:
632 return ProviderConfig{
633 ID: providerId,
634 ProviderType: provider.TypeGemini,
635 }
636 case provider.InferenceProviderBedrock:
637 return ProviderConfig{
638 ID: providerId,
639 ProviderType: provider.TypeBedrock,
640 }
641 case provider.InferenceProviderAzure:
642 return ProviderConfig{
643 ID: providerId,
644 ProviderType: provider.TypeAzure,
645 }
646 case provider.InferenceProviderOpenRouter:
647 return ProviderConfig{
648 ID: providerId,
649 ProviderType: provider.TypeOpenAI,
650 BaseURL: "https://openrouter.ai/api/v1",
651 ExtraHeaders: map[string]string{
652 "HTTP-Referer": "crush.charm.land",
653 "X-Title": "Crush",
654 },
655 }
656 case provider.InferenceProviderXAI:
657 return ProviderConfig{
658 ID: providerId,
659 ProviderType: provider.TypeXAI,
660 BaseURL: "https://api.x.ai/v1",
661 }
662 case provider.InferenceProviderVertexAI:
663 return ProviderConfig{
664 ID: providerId,
665 ProviderType: provider.TypeVertexAI,
666 }
667 default:
668 return ProviderConfig{
669 ID: providerId,
670 ProviderType: provider.TypeOpenAI,
671 }
672 }
673}
674
675func defaultConfigBasedOnEnv() *Config {
676 cfg := &Config{
677 Options: Options{
678 DataDirectory: defaultDataDirectory,
679 ContextPaths: defaultContextPaths,
680 },
681 Providers: make(map[provider.InferenceProvider]ProviderConfig),
682 Agents: make(map[AgentID]Agent),
683 LSP: make(map[string]LSPConfig),
684 MCP: make(map[string]MCP),
685 }
686
687 providers := Providers()
688
689 for _, p := range providers {
690 if strings.HasPrefix(p.APIKey, "$") {
691 envVar := strings.TrimPrefix(p.APIKey, "$")
692 if apiKey := os.Getenv(envVar); apiKey != "" {
693 providerConfig := providerDefaultConfig(p.ID)
694 providerConfig.APIKey = apiKey
695 providerConfig.DefaultLargeModel = p.DefaultLargeModelID
696 providerConfig.DefaultSmallModel = p.DefaultSmallModelID
697 baseURL := p.APIEndpoint
698 if strings.HasPrefix(baseURL, "$") {
699 envVar := strings.TrimPrefix(baseURL, "$")
700 baseURL = os.Getenv(envVar)
701 }
702 providerConfig.BaseURL = baseURL
703 for _, model := range p.Models {
704 configModel := Model{
705 ID: model.ID,
706 Name: model.Name,
707 CostPer1MIn: model.CostPer1MIn,
708 CostPer1MOut: model.CostPer1MOut,
709 CostPer1MInCached: model.CostPer1MInCached,
710 CostPer1MOutCached: model.CostPer1MOutCached,
711 ContextWindow: model.ContextWindow,
712 DefaultMaxTokens: model.DefaultMaxTokens,
713 CanReason: model.CanReason,
714 SupportsImages: model.SupportsImages,
715 }
716 // Set reasoning effort for reasoning models
717 if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
718 configModel.HasReasoningEffort = model.HasReasoningEffort
719 configModel.ReasoningEffort = model.DefaultReasoningEffort
720 }
721 providerConfig.Models = append(providerConfig.Models, configModel)
722 }
723 cfg.Providers[p.ID] = providerConfig
724 }
725 }
726 }
727 // TODO: support local models
728
729 if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
730 providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
731 providerConfig.ExtraParams = map[string]string{
732 "project": os.Getenv("GOOGLE_CLOUD_PROJECT"),
733 "location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
734 }
735 cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
736 }
737
738 if hasAWSCredentials() {
739 providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
740 providerConfig.ExtraParams = map[string]string{
741 "region": os.Getenv("AWS_DEFAULT_REGION"),
742 }
743 if providerConfig.ExtraParams["region"] == "" {
744 providerConfig.ExtraParams["region"] = os.Getenv("AWS_REGION")
745 }
746 cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
747 }
748 return cfg
749}
750
// hasAWSCredentials reports whether the environment looks configured for AWS.
// It returns true for any of:
//   - a full access-key pair;
//   - a named profile;
//   - a region;
//   - a container credentials endpoint.
//
// It only inspects environment variables and does not check that the
// credentials actually work.
func hasAWSCredentials() bool {
	isSet := func(name string) bool { return os.Getenv(name) != "" }

	keyPair := isSet("AWS_ACCESS_KEY_ID") && isSet("AWS_SECRET_ACCESS_KEY")
	profile := isSet("AWS_PROFILE") || isSet("AWS_DEFAULT_PROFILE")
	region := isSet("AWS_REGION") || isSet("AWS_DEFAULT_REGION")
	container := isSet("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") ||
		isSet("AWS_CONTAINER_CREDENTIALS_FULL_URI")

	return keyPair || profile || region || container
}
771
772func WorkingDirectory() string {
773 return cwd
774}
775
// TODO: Handle error state: the lookup helpers below return zero values on failure instead of reporting an error to the caller.
777
778func GetAgentModel(agentID AgentID) Model {
779 cfg := Get()
780 agent, ok := cfg.Agents[agentID]
781 if !ok {
782 logging.Error("Agent not found", "agent_id", agentID)
783 return Model{}
784 }
785
786 var model PreferredModel
787 switch agent.Model {
788 case LargeModel:
789 model = cfg.Models.Large
790 case SmallModel:
791 model = cfg.Models.Small
792 default:
793 logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
794 model = cfg.Models.Large // Fallback to large model
795 }
796 providerConfig, ok := cfg.Providers[model.Provider]
797 if !ok {
798 logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
799 return Model{}
800 }
801
802 for _, m := range providerConfig.Models {
803 if m.ID == model.ModelID {
804 return m
805 }
806 }
807
808 logging.Error("Model not found for agent", "agent_id", agentID, "model", agent.Model)
809 return Model{}
810}
811
812func GetAgentProvider(agentID AgentID) ProviderConfig {
813 cfg := Get()
814 agent, ok := cfg.Agents[agentID]
815 if !ok {
816 logging.Error("Agent not found", "agent_id", agentID)
817 return ProviderConfig{}
818 }
819
820 var model PreferredModel
821 switch agent.Model {
822 case LargeModel:
823 model = cfg.Models.Large
824 case SmallModel:
825 model = cfg.Models.Small
826 default:
827 logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
828 model = cfg.Models.Large // Fallback to large model
829 }
830
831 providerConfig, ok := cfg.Providers[model.Provider]
832 if !ok {
833 logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
834 return ProviderConfig{}
835 }
836
837 return providerConfig
838}
839
840func GetProviderModel(provider provider.InferenceProvider, modelID string) Model {
841 cfg := Get()
842 providerConfig, ok := cfg.Providers[provider]
843 if !ok {
844 logging.Error("Provider not found", "provider", provider)
845 return Model{}
846 }
847
848 for _, model := range providerConfig.Models {
849 if model.ID == modelID {
850 return model
851 }
852 }
853
854 logging.Error("Model not found for provider", "provider", provider, "model_id", modelID)
855 return Model{}
856}
857
858func GetModel(modelType ModelType) Model {
859 cfg := Get()
860 var model PreferredModel
861 switch modelType {
862 case LargeModel:
863 model = cfg.Models.Large
864 case SmallModel:
865 model = cfg.Models.Small
866 default:
867 model = cfg.Models.Large // Fallback to large model
868 }
869 providerConfig, ok := cfg.Providers[model.Provider]
870 if !ok {
871 return Model{}
872 }
873
874 for _, m := range providerConfig.Models {
875 if m.ID == model.ModelID {
876 return m
877 }
878 }
879 return Model{}
880}
881
882func UpdatePreferredModel(modelType ModelType, model PreferredModel) error {
883 cfg := Get()
884 switch modelType {
885 case LargeModel:
886 cfg.Models.Large = model
887 case SmallModel:
888 cfg.Models.Small = model
889 default:
890 return fmt.Errorf("unknown model type: %s", modelType)
891 }
892 return nil
893}
894
// ValidationError describes one problem found while validating a Config.
// Field is a dotted path such as "providers.openai.api_key".
type ValidationError struct {
	Field   string
	Message string
}

// Error formats the problem as "validation error in <field>: <message>".
func (e ValidationError) Error() string {
	return "validation error in " + e.Field + ": " + e.Message
}

// ValidationErrors collects every problem found by Config.Validate.
type ValidationErrors []ValidationError

// Error renders the collected problems. A single problem is rendered on its
// own; several are joined with "; " under a "multiple validation errors"
// prefix.
func (e ValidationErrors) Error() string {
	switch len(e) {
	case 0:
		return "no validation errors"
	case 1:
		return e[0].Error()
	}

	parts := make([]string, 0, len(e))
	for _, ve := range e {
		parts = append(parts, ve.Error())
	}
	return "multiple validation errors: " + strings.Join(parts, "; ")
}

// HasErrors reports whether at least one problem was recorded.
func (e ValidationErrors) HasErrors() bool {
	return len(e) != 0
}

// Add records a problem for field.
func (e *ValidationErrors) Add(field, message string) {
	*e = append(*e, ValidationError{Field: field, Message: message})
}
932
933// Validate performs comprehensive validation of the configuration
934func (c *Config) Validate() error {
935 var errors ValidationErrors
936
937 // Validate providers
938 c.validateProviders(&errors)
939
940 // Validate models
941 c.validateModels(&errors)
942
943 // Validate agents
944 c.validateAgents(&errors)
945
946 // Validate options
947 c.validateOptions(&errors)
948
949 // Validate MCP configurations
950 c.validateMCPs(&errors)
951
952 // Validate LSP configurations
953 c.validateLSPs(&errors)
954
955 // Validate cross-references
956 c.validateCrossReferences(&errors)
957
958 // Validate completeness
959 c.validateCompleteness(&errors)
960
961 if errors.HasErrors() {
962 return errors
963 }
964
965 return nil
966}
967
968// validateProviders validates all provider configurations
969func (c *Config) validateProviders(errors *ValidationErrors) {
970 if c.Providers == nil {
971 c.Providers = make(map[provider.InferenceProvider]ProviderConfig)
972 }
973
974 knownProviders := provider.KnownProviders()
975 validTypes := []provider.Type{
976 provider.TypeOpenAI,
977 provider.TypeAnthropic,
978 provider.TypeGemini,
979 provider.TypeAzure,
980 provider.TypeBedrock,
981 provider.TypeVertexAI,
982 provider.TypeXAI,
983 }
984
985 for providerID, providerConfig := range c.Providers {
986 fieldPrefix := fmt.Sprintf("providers.%s", providerID)
987
988 // Validate API key for non-disabled providers
989 if !providerConfig.Disabled && providerConfig.APIKey == "" {
990 // Special case for AWS Bedrock and VertexAI which may use other auth methods
991 if providerID != provider.InferenceProviderBedrock && providerID != provider.InferenceProviderVertexAI {
992 errors.Add(fieldPrefix+".api_key", "API key is required for non-disabled providers")
993 }
994 }
995
996 // Validate provider type
997 validType := slices.Contains(validTypes, providerConfig.ProviderType)
998 if !validType {
999 errors.Add(fieldPrefix+".provider_type", fmt.Sprintf("invalid provider type: %s", providerConfig.ProviderType))
1000 }
1001
1002 // Validate custom providers
1003 isKnownProvider := slices.Contains(knownProviders, providerID)
1004
1005 if !isKnownProvider {
1006 // Custom provider validation
1007 if providerConfig.BaseURL == "" {
1008 errors.Add(fieldPrefix+".base_url", "BaseURL is required for custom providers")
1009 }
1010 if providerConfig.ProviderType != provider.TypeOpenAI {
1011 errors.Add(fieldPrefix+".provider_type", "custom providers currently only support OpenAI type")
1012 }
1013 }
1014
1015 // Validate models
1016 modelIDs := make(map[string]bool)
1017 for i, model := range providerConfig.Models {
1018 modelFieldPrefix := fmt.Sprintf("%s.models[%d]", fieldPrefix, i)
1019
1020 // Check for duplicate model IDs
1021 if modelIDs[model.ID] {
1022 errors.Add(modelFieldPrefix+".id", fmt.Sprintf("duplicate model ID: %s", model.ID))
1023 }
1024 modelIDs[model.ID] = true
1025
1026 // Validate required model fields
1027 if model.ID == "" {
1028 errors.Add(modelFieldPrefix+".id", "model ID is required")
1029 }
1030 if model.Name == "" {
1031 errors.Add(modelFieldPrefix+".name", "model name is required")
1032 }
1033 if model.ContextWindow <= 0 {
1034 errors.Add(modelFieldPrefix+".context_window", "context window must be positive")
1035 }
1036 if model.DefaultMaxTokens <= 0 {
1037 errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens must be positive")
1038 }
1039 if model.DefaultMaxTokens > model.ContextWindow {
1040 errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens cannot exceed context window")
1041 }
1042
1043 // Validate cost fields
1044 if model.CostPer1MIn < 0 {
1045 errors.Add(modelFieldPrefix+".cost_per_1m_in", "cost per 1M input tokens cannot be negative")
1046 }
1047 if model.CostPer1MOut < 0 {
1048 errors.Add(modelFieldPrefix+".cost_per_1m_out", "cost per 1M output tokens cannot be negative")
1049 }
1050 if model.CostPer1MInCached < 0 {
1051 errors.Add(modelFieldPrefix+".cost_per_1m_in_cached", "cached cost per 1M input tokens cannot be negative")
1052 }
1053 if model.CostPer1MOutCached < 0 {
1054 errors.Add(modelFieldPrefix+".cost_per_1m_out_cached", "cached cost per 1M output tokens cannot be negative")
1055 }
1056 }
1057
1058 // Validate default model references
1059 if providerConfig.DefaultLargeModel != "" {
1060 if !modelIDs[providerConfig.DefaultLargeModel] {
1061 errors.Add(fieldPrefix+".default_large_model", fmt.Sprintf("default large model '%s' not found in provider models", providerConfig.DefaultLargeModel))
1062 }
1063 }
1064 if providerConfig.DefaultSmallModel != "" {
1065 if !modelIDs[providerConfig.DefaultSmallModel] {
1066 errors.Add(fieldPrefix+".default_small_model", fmt.Sprintf("default small model '%s' not found in provider models", providerConfig.DefaultSmallModel))
1067 }
1068 }
1069
1070 // Validate provider-specific requirements
1071 c.validateProviderSpecific(providerID, providerConfig, errors)
1072 }
1073}
1074
// validateProviderSpecific checks requirements that only apply to particular
// providers. Disabled providers are not checked at all.
//
//   - VertexAI: ExtraParams must be present and carry non-empty "project"
//     and "location" entries.
//   - Bedrock: ExtraParams must carry a non-empty "region" entry, and
//     hasAWSCredentials must report that AWS credentials are available.
//
// Other providers have no extra requirements here.
func (c *Config) validateProviderSpecific(providerID provider.InferenceProvider, providerConfig ProviderConfig, errors *ValidationErrors) {
	if providerConfig.Disabled {
		return
	}

	prefix := fmt.Sprintf("providers.%s", providerID)
	params := providerConfig.ExtraParams

	switch providerID {
	case provider.InferenceProviderVertexAI:
		if params == nil {
			errors.Add(prefix+".extra_params", "VertexAI requires extra_params configuration")
			return
		}
		if params["project"] == "" {
			errors.Add(prefix+".extra_params.project", "VertexAI requires project parameter")
		}
		if params["location"] == "" {
			errors.Add(prefix+".extra_params.location", "VertexAI requires location parameter")
		}

	case provider.InferenceProviderBedrock:
		// Indexing a nil map yields the zero value, so a missing ExtraParams
		// map is reported the same way as a missing region.
		if params["region"] == "" {
			errors.Add(prefix+".extra_params.region", "Bedrock requires region parameter")
		}
		if !hasAWSCredentials() {
			errors.Add(prefix, "Bedrock requires AWS credentials in environment")
		}
	}
}
1105
// validateModels validates the preferred large and small model selections.
//
// A selection whose model ID and provider are both empty is skipped here.
// Otherwise both fields must be set, the provider must exist in c.Providers
// and not be disabled, and the model ID must match one of that provider's
// models. Each problem is reported once: a missing field does not also
// produce a follow-on "not found" error.
func (c *Config) validateModels(errors *ValidationErrors) {
	c.validatePreferredModelSelection("large", c.Models.Large, errors)
	c.validatePreferredModelSelection("small", c.Models.Small, errors)
}

// validatePreferredModelSelection validates one preferred model entry; size
// is "large" or "small" and is used in field paths and error messages.
func (c *Config) validatePreferredModelSelection(size string, model PreferredModel, errors *ValidationErrors) {
	if model.ModelID == "" && model.Provider == "" {
		return
	}

	fieldPrefix := "models." + size

	if model.ModelID == "" {
		errors.Add(fieldPrefix+".model_id", size+" model ID is required when provider is set")
	}
	if model.Provider == "" {
		errors.Add(fieldPrefix+".provider", size+" model provider is required when model ID is set")
		// Nothing to look up without a provider; also reporting
		// "provider '' not found" would just repeat the error above.
		return
	}

	providerConfig, exists := c.Providers[model.Provider]
	if !exists {
		errors.Add(fieldPrefix+".provider", fmt.Sprintf("%s model provider '%s' not found", size, model.Provider))
		return
	}
	if providerConfig.Disabled {
		errors.Add(fieldPrefix+".provider", size+" model provider is disabled")
	}

	if model.ModelID == "" {
		// Already reported as missing; don't also claim model '' is absent.
		return
	}
	for _, m := range providerConfig.Models {
		if m.ID == model.ModelID {
			return
		}
	}
	errors.Add(fieldPrefix+".model_id", fmt.Sprintf("%s model '%s' not found in provider '%s'", size, model.ModelID, model.Provider))
}
1170
// validateAgents validates every entry in c.Agents and records problems in errors.
//
// For each agent it checks that:
//   - the ID is set and equals its map key (a missing ID is reported once,
//     not additionally as a mismatch);
//   - a name is set and Model is LargeModel or SmallModel;
//   - every AllowedTools entry is one of the built-in tool names below;
//   - every AllowedMCP / AllowedLSP reference exists in c.MCP / c.LSP;
//   - context paths are non-empty and contain no NUL byte.
//
// The built-in coder and task agents must also keep their stock name and
// description. When the task agent's AllowedTools is set, it must equal
// the default list exactly, including order.
//
// Side effect: a nil c.Agents is replaced with an empty map.
func (c *Config) validateAgents(errors *ValidationErrors) {
	if c.Agents == nil {
		c.Agents = make(map[AgentID]Agent)
	}

	validTools := []string{
		"bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent",
	}

	for agentID, agent := range c.Agents {
		fieldPrefix := fmt.Sprintf("agents.%s", agentID)

		// A missing ID is the root problem; also emitting
		// "mismatch ... got ''" on the same field would be redundant.
		switch {
		case agent.ID == "":
			errors.Add(fieldPrefix+".id", "agent ID is required")
		case agent.ID != agentID:
			errors.Add(fieldPrefix+".id", fmt.Sprintf("agent ID mismatch: expected '%s', got '%s'", agentID, agent.ID))
		}

		if agent.Name == "" {
			errors.Add(fieldPrefix+".name", "agent name is required")
		}

		if agent.Model != LargeModel && agent.Model != SmallModel {
			errors.Add(fieldPrefix+".model", fmt.Sprintf("invalid model type: %s (must be 'large' or 'small')", agent.Model))
		}

		// Ranging over a nil slice or map is a no-op, so no nil guards are needed.
		for i, tool := range agent.AllowedTools {
			if !slices.Contains(validTools, tool) {
				errors.Add(fmt.Sprintf("%s.allowed_tools[%d]", fieldPrefix, i), fmt.Sprintf("unknown tool: %s", tool))
			}
		}

		for mcpName := range agent.AllowedMCP {
			if _, exists := c.MCP[mcpName]; !exists {
				errors.Add(fieldPrefix+".allowed_mcp", fmt.Sprintf("referenced MCP '%s' not found", mcpName))
			}
		}

		for _, lspName := range agent.AllowedLSP {
			if _, exists := c.LSP[lspName]; !exists {
				errors.Add(fieldPrefix+".allowed_lsp", fmt.Sprintf("referenced LSP '%s' not found", lspName))
			}
		}

		// Basic path sanity only; paths are not checked against the filesystem.
		for i, contextPath := range agent.ContextPaths {
			field := fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i)
			if contextPath == "" {
				errors.Add(field, "context path cannot be empty")
			}
			if strings.Contains(contextPath, "\x00") {
				errors.Add(field, "context path contains invalid characters")
			}
		}

		// Built-in agents must keep their core properties.
		switch agentID {
		case AgentCoder:
			if agent.Name != "Coder" {
				errors.Add(fieldPrefix+".name", "coder agent name cannot be changed")
			}
			if agent.Description != "An agent that helps with executing coding tasks." {
				errors.Add(fieldPrefix+".description", "coder agent description cannot be changed")
			}
		case AgentTask:
			if agent.Name != "Task" {
				errors.Add(fieldPrefix+".name", "task agent name cannot be changed")
			}
			if agent.Description != "An agent that helps with searching for context and finding implementation details." {
				errors.Add(fieldPrefix+".description", "task agent description cannot be changed")
			}
			expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"}
			if agent.AllowedTools != nil && !slices.Equal(agent.AllowedTools, expectedTools) {
				errors.Add(fieldPrefix+".allowed_tools", "task agent allowed tools cannot be changed")
			}
		}
	}
}
1263
// validateOptions checks the global options. A data directory must be
// configured, and each entry in Options.ContextPaths must be non-empty and
// free of NUL bytes.
func (c *Config) validateOptions(errors *ValidationErrors) {
	if c.Options.DataDirectory == "" {
		errors.Add("options.data_directory", "data directory is required")
	}

	for idx, path := range c.Options.ContextPaths {
		field := fmt.Sprintf("options.context_paths[%d]", idx)
		// The two cases are mutually exclusive: an empty path has no NUL byte.
		switch {
		case path == "":
			errors.Add(field, "context path cannot be empty")
		case strings.ContainsRune(path, '\x00'):
			errors.Add(field, "context path contains invalid characters")
		}
	}
}
1281
// validateMCPs checks every MCP server entry in c.MCP. Each entry must have
// Type MCPStdio (which requires a Command) or MCPSse (which requires a URL);
// any other type is reported as invalid.
//
// Side effect: a nil c.MCP is replaced with an empty map.
func (c *Config) validateMCPs(errors *ValidationErrors) {
	if c.MCP == nil {
		c.MCP = make(map[string]MCP)
	}

	for name, server := range c.MCP {
		prefix := fmt.Sprintf("mcp.%s", name)

		switch server.Type {
		case MCPStdio:
			if server.Command == "" {
				errors.Add(prefix+".command", "command is required for stdio MCP")
			}
		case MCPSse:
			if server.URL == "" {
				errors.Add(prefix+".url", "URL is required for SSE MCP")
			}
		default:
			errors.Add(prefix+".type", fmt.Sprintf("invalid MCP type: %s (must be 'stdio' or 'sse')", server.Type))
		}
	}
}
1308
1309// validateLSPs validates LSP configurations
1310func (c *Config) validateLSPs(errors *ValidationErrors) {
1311 if c.LSP == nil {
1312 c.LSP = make(map[string]LSPConfig)
1313 }
1314
1315 for lspName, lspConfig := range c.LSP {
1316 fieldPrefix := fmt.Sprintf("lsp.%s", lspName)
1317
1318 if lspConfig.Command == "" {
1319 errors.Add(fieldPrefix+".command", "command is required for LSP")
1320 }
1321 }
1322}
1323
// validateCrossReferences checks references that span config sections. For
// each agent, it looks up the preferred model matching the agent's Model
// (large or small). If that model's provider is configured but disabled, an
// error is reported on the agent.
func (c *Config) validateCrossReferences(errors *ValidationErrors) {
	for agentID, agent := range c.Agents {
		var selected PreferredModel
		switch agent.Model {
		case LargeModel:
			selected = c.Models.Large
		case SmallModel:
			selected = c.Models.Small
		default:
			// Invalid model types are reported by validateAgents.
			continue
		}

		if selected.Provider == "" {
			continue
		}
		providerConfig, found := c.Providers[selected.Provider]
		if !found || !providerConfig.Disabled {
			continue
		}
		errors.Add(
			fmt.Sprintf("agents.%s.model", agentID),
			fmt.Sprintf("agent cannot use model type '%s' because provider '%s' is disabled", agent.Model, selected.Provider),
		)
	}
}
1347
1348// validateCompleteness validates that the configuration is complete and usable
1349func (c *Config) validateCompleteness(errors *ValidationErrors) {
1350 // Check for at least one valid, non-disabled provider
1351 hasValidProvider := false
1352 for _, providerConfig := range c.Providers {
1353 if !providerConfig.Disabled {
1354 hasValidProvider = true
1355 break
1356 }
1357 }
1358 if !hasValidProvider {
1359 errors.Add("providers", "at least one non-disabled provider is required")
1360 }
1361
1362 // Check that default agents exist
1363 if _, exists := c.Agents[AgentCoder]; !exists {
1364 errors.Add("agents", "coder agent is required")
1365 }
1366 if _, exists := c.Agents[AgentTask]; !exists {
1367 errors.Add("agents", "task agent is required")
1368 }
1369
1370 // Check that preferred models are set if providers exist
1371 if hasValidProvider {
1372 if c.Models.Large.ModelID == "" || c.Models.Large.Provider == "" {
1373 errors.Add("models.large", "large preferred model must be configured when providers are available")
1374 }
1375 if c.Models.Small.ModelID == "" || c.Models.Small.Provider == "" {
1376 errors.Add("models.small", "small preferred model must be configured when providers are available")
1377 }
1378 }
1379}