1package config
2
3import (
4 "encoding/json"
5 "errors"
6 "fmt"
7 "log/slog"
8 "maps"
9 "os"
10 "path/filepath"
11 "slices"
12 "strings"
13 "sync"
14
15 "github.com/charmbracelet/crush/internal/fur/provider"
16 "github.com/charmbracelet/crush/internal/logging"
17)
18
const (
	// defaultDataDirectory is the default Options.DataDirectory, a path
	// relative to the working directory. In CRUSH_DEV_DEBUG mode loadConfig
	// writes debug.log inside it.
	defaultDataDirectory = ".crush"
	// defaultLogLevel is a textual default log level.
	// NOTE(review): not referenced in the visible part of this file; loadConfig
	// derives its slog level from the debug flag instead — confirm usage elsewhere.
	defaultLogLevel = "info"
	appName         = "crush"

	// MaxTokensFallbackDefault is a fallback max-token count.
	// NOTE(review): presumably used when a model does not specify
	// DefaultMaxTokens — confirm at the call sites.
	MaxTokensFallbackDefault = 4096
)

// defaultContextPaths is the initial Options.ContextPaths list. Config files
// append their own entries to it (see mergeOptions) and the built-in agents
// inherit the merged list. Entries ending in "/" (".cursor/rules/") name
// directories rather than single files.
// NOTE(review): presumably these are instruction files loaded into the agent's
// context when present — confirm with the consumer of ContextPaths.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
42
// AgentID identifies an agent; it is the key type of Config.Agents.
type AgentID string

const (
	// AgentCoder is the built-in coding agent; by default every tool is allowed.
	AgentCoder AgentID = "coder"
	// AgentTask is the built-in search/context agent; by default it is limited
	// to the glob, grep, ls, sourcegraph and view tools and gets no MCPs or LSPs.
	AgentTask AgentID = "task"
)

// ModelType selects which preferred model (Config.Models) an agent runs on.
type ModelType string

const (
	// LargeModel selects Config.Models.Large. Lookups also fall back to it for
	// unrecognized model types.
	LargeModel ModelType = "large"
	// SmallModel selects Config.Models.Small.
	SmallModel ModelType = "small"
)
56
// Model describes one model offered by a provider together with its pricing
// and capability metadata. defaultConfigBasedOnEnv fills these in from the
// provider catalog returned by Providers().
type Model struct {
	ID string `json:"id"`
	// Name is serialized under the "model" key.
	Name string `json:"model"`
	// Prices per one million tokens.
	// NOTE(review): currency is not established here — presumably USD.
	CostPer1MIn        float64 `json:"cost_per_1m_in"`
	CostPer1MOut       float64 `json:"cost_per_1m_out"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
	// ContextWindow and DefaultMaxTokens must both be positive and
	// DefaultMaxTokens may not exceed ContextWindow (see validateProviders).
	ContextWindow    int64 `json:"context_window"`
	DefaultMaxTokens int64 `json:"default_max_tokens"`
	CanReason        bool  `json:"can_reason"`
	// ReasoningEffort/HasReasoningEffort are only populated from the catalog
	// when the model supports reasoning effort and declares a default effort.
	ReasoningEffort    string `json:"reasoning_effort"`
	HasReasoningEffort bool   `json:"has_reasoning_effort"`
	// SupportsImages is serialized under the "supports_attachments" key.
	SupportsImages bool `json:"supports_attachments"`
}

// VertexAIOptions groups Vertex AI settings.
// NOTE(review): not referenced in the visible part of this file; Vertex AI
// project/location are currently passed via ProviderConfig.ExtraParams.
type VertexAIOptions struct {
	APIKey   string `json:"api_key,omitempty"`
	Project  string `json:"project,omitempty"`
	Location string `json:"location,omitempty"`
}

// ProviderConfig is the configuration of one inference provider, built from
// the environment and then merged with the global and local config files.
type ProviderConfig struct {
	// ID is expected to equal the key under which the provider is stored in
	// Config.Providers: preferred models record ID and are later resolved by
	// indexing Config.Providers with it.
	ID provider.InferenceProvider `json:"id"`
	// BaseURL is mandatory for custom (non built-in) providers.
	BaseURL      string        `json:"base_url,omitempty"`
	ProviderType provider.Type `json:"provider_type"`
	APIKey       string        `json:"api_key,omitempty"`
	Disabled     bool          `json:"disabled"`
	// ExtraHeaders are additional request headers, e.g. OpenRouter's default
	// "HTTP-Referer" and "X-Title".
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
	// ExtraParams holds provider-specific settings, e.g. "project" and
	// "location" for Vertex AI or "region" for Bedrock.
	ExtraParams map[string]string `json:"extra_params,omitempty"`

	// DefaultLargeModel and DefaultSmallModel are model IDs; when set they must
	// name an entry of Models.
	DefaultLargeModel string `json:"default_large_model,omitempty"`
	DefaultSmallModel string `json:"default_small_model,omitempty"`

	// Models lists the models available from this provider.
	Models []Model `json:"models,omitempty"`
}
93
// Agent configures one agent. Agents live in Config.Agents keyed by AgentID;
// the built-in ones are AgentCoder and AgentTask.
type Agent struct {
	ID          AgentID `json:"id"`
	Name        string  `json:"name"`
	Description string  `json:"description,omitempty"`
	// Disabled marks the agent as turned off.
	// NOTE(review): how a disabled agent is treated is decided outside this
	// file — confirm with the consumer.
	Disabled bool `json:"disabled"`

	// Model selects the preferred large or small model; empty defaults to
	// LargeModel when the agent comes from a config file.
	Model ModelType `json:"model"`

	// The available tools for the agent
	// if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools"`

	// AllowedMCP maps an MCP server name to the list of that server's tools the
	// agent may use; a nil tool list means all tools from that server.
	// NOTE(review): the original comment said an empty map allows every MCP,
	// but the built-in Task agent sets an empty (non-nil) map with the intent
	// "no MCPs". This suggests nil means "all" and empty means "none" — confirm
	// with the consumer.
	AllowedMCP map[string][]string `json:"allowed_mcp"`

	// The list of LSPs that this agent can use
	// if this is nil, all LSPs are available
	AllowedLSP []string `json:"allowed_lsp"`

	// ContextPaths overrides the context paths for this agent. When merging
	// config files these are appended to the inherited paths, not substituted.
	ContextPaths []string `json:"context_paths"`
}
120
// MCPType names how an MCP server is reached: "stdio" or "sse".
type MCPType string

const (
	MCPStdio MCPType = "stdio"
	MCPSse   MCPType = "sse"
)

// MCP configures one MCP server; servers are stored in Config.MCP by name.
// NOTE(review): presumably Command/Args/Env apply to stdio servers and
// URL/Headers to SSE servers — confirm with the MCP client code.
type MCP struct {
	Command string            `json:"command"`
	Env     []string          `json:"env"`
	Args    []string          `json:"args"`
	Type    MCPType           `json:"type"`
	URL     string            `json:"url"`
	Headers map[string]string `json:"headers"`
}
136
// LSPConfig configures one language server; servers are stored in Config.LSP
// by name. Command and Args presumably launch the server process and Options
// is passed through to it untyped — confirm with the LSP client code.
type LSPConfig struct {
	// Disabled turns this language server off. It is read from the "disabled"
	// key, matching ProviderConfig. The key used to be "enabled", which inverted
	// the setting: "enabled": true in a config file disabled the server.
	Disabled bool     `json:"disabled"`
	Command  string   `json:"command"`
	Args     []string `json:"args"`
	Options  any      `json:"options"`
}
143
// TUIOptions holds terminal-UI settings.
type TUIOptions struct {
	CompactMode bool `json:"compact_mode"`
	// Here we can add themes later or any TUI related options
}

// Options holds miscellaneous settings. When config files are merged
// (mergeOptions), boolean flags can only be switched on, ContextPaths are
// appended, and a non-empty DataDirectory replaces the previous value.
type Options struct {
	ContextPaths         []string   `json:"context_paths"`
	TUI                  TUIOptions `json:"tui"`
	Debug                bool       `json:"debug"`
	DebugLSP             bool       `json:"debug_lsp"`
	DisableAutoSummarize bool       `json:"disable_auto_summarize"`
	// DataDirectory is relative to the working directory.
	DataDirectory string `json:"data_directory"`
}

// PreferredModel identifies a model by provider and model ID.
type PreferredModel struct {
	ModelID         string                     `json:"model_id"`
	Provider        provider.InferenceProvider `json:"provider"`
	ReasoningEffort string                     `json:"reasoning_effort,omitempty"`
}

// PreferredModels are the models used by agents whose ModelType is LargeModel
// or SmallModel respectively.
type PreferredModels struct {
	Large PreferredModel `json:"large"`
	Small PreferredModel `json:"small"`
}

// Config is the complete application configuration after defaults, the global
// config file and the local crush.json have been merged.
type Config struct {
	Models PreferredModels `json:"models"`
	// Configured providers, keyed by provider ID.
	Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"`

	// Configured agents, keyed by agent ID.
	Agents map[AgentID]Agent `json:"agents,omitempty"`

	// Configured MCP servers, keyed by name.
	MCP map[string]MCP `json:"mcp,omitempty"`

	// Configured language servers, keyed by name.
	LSP map[string]LSPConfig `json:"lsp,omitempty"`

	// Miscellaneous options
	Options Options `json:"options"`
}
187
var (
	instance *Config   // Config loaded by Init and returned by Get; nil until Init succeeds in producing one.
	cwd      string    // Working directory passed to Init; returned by WorkingDirectory.
	once     sync.Once // Guards Init so the configuration is loaded at most once per process.

)
194
// loadConfig builds the effective configuration for the working directory cwd.
//
// Steps, in order:
//  1. start from defaultConfigBasedOnEnv and apply the debug flag;
//  2. install the default slog logger: debug.log inside Options.DataDirectory
//     when CRUSH_DEV_DEBUG=true, logging.NewWriter() otherwise, at debug level
//     when debug is set;
//  3. read the global config file (ConfigPath()) and the local <cwd>/crush.json;
//     either may be absent;
//  4. merge options, providers, preferred models, agents, MCPs and LSPs, with
//     the local file taking precedence over the global one;
//  5. validate the result.
//
// When no provider ends up configured, the config is returned early without
// agents and without validation. Read/parse errors of the config files return
// a nil config. Logging-setup and validation errors return the partially built
// config together with the error.
func loadConfig(cwd string, debug bool) (*Config, error) {
	cfgPath := ConfigPath()

	cfg := defaultConfigBasedOnEnv()
	cfg.Options.Debug = debug
	defaultLevel := slog.LevelInfo
	if cfg.Options.Debug {
		defaultLevel = slog.LevelDebug
	}

	if os.Getenv("CRUSH_DEV_DEBUG") == "true" {
		// MkdirAll is a no-op for an existing directory. O_CREATE creates the log
		// file on demand, so a separate os.Create (whose handle would never be
		// closed) is unnecessary.
		if err := os.MkdirAll(cfg.Options.DataDirectory, 0o755); err != nil {
			return cfg, fmt.Errorf("failed to create directory: %w", err)
		}
		loggingFile := filepath.Join(cfg.Options.DataDirectory, "debug.log")
		// Deliberately never closed: it backs the process-wide default logger.
		logFile, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666)
		if err != nil {
			return cfg, fmt.Errorf("failed to open log file: %w", err)
		}
		slog.SetDefault(slog.New(slog.NewTextHandler(logFile, &slog.HandlerOptions{
			Level: defaultLevel,
		})))
	} else {
		slog.SetDefault(slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{
			Level: defaultLevel,
		})))
	}

	// Global config: a missing file is equivalent to an empty one.
	globalCfg := &Config{}
	if data, err := os.ReadFile(cfgPath); err == nil {
		if err := json.Unmarshal(data, globalCfg); err != nil {
			return nil, fmt.Errorf("failed to parse config file %q: %w", cfgPath, err)
		}
	} else if !errors.Is(err, os.ErrNotExist) {
		return nil, fmt.Errorf("failed to read config file %q: %w", cfgPath, err)
	}

	// Local config: stays nil when <cwd>/crush.json does not exist.
	var localConfig *Config
	localConfigPath := filepath.Join(cwd, "crush.json")
	if data, err := os.ReadFile(localConfigPath); err == nil {
		localConfig = &Config{}
		if err := json.Unmarshal(data, localConfig); err != nil {
			return nil, fmt.Errorf("failed to parse config file %q: %w", localConfigPath, err)
		}
	} else if !errors.Is(err, os.ErrNotExist) {
		return nil, fmt.Errorf("failed to read config file %q: %w", localConfigPath, err)
	}

	mergeOptions(cfg, globalCfg, localConfig)

	mergeProviderConfigs(cfg, globalCfg, localConfig)
	// No providers found: the app is not initialized yet.
	if len(cfg.Providers) == 0 {
		return cfg, nil
	}

	if preferredProvider := getPreferredProvider(cfg.Providers); preferredProvider != nil {
		cfg.Models = PreferredModels{
			Large: PreferredModel{
				ModelID:  preferredProvider.DefaultLargeModel,
				Provider: preferredProvider.ID,
			},
			Small: PreferredModel{
				ModelID:  preferredProvider.DefaultSmallModel,
				Provider: preferredProvider.ID,
			},
		}
	} else {
		// No enabled provider found: leave the preferred models empty.
		cfg.Models = PreferredModels{}
	}

	mergeModels(cfg, globalCfg, localConfig)

	// Each agent gets its own copy of the context paths. Sharing one backing
	// array would let later appends for one agent overwrite another's entries.
	cfg.Agents = map[AgentID]Agent{
		AgentCoder: {
			ID:           AgentCoder,
			Name:         "Coder",
			Description:  "An agent that helps with executing coding tasks.",
			Model:        LargeModel,
			ContextPaths: slices.Clone(cfg.Options.ContextPaths),
			// All tools allowed
		},
		AgentTask: {
			ID:           AgentTask,
			Name:         "Task",
			Description:  "An agent that helps with searching for context and finding implementation details.",
			Model:        LargeModel,
			ContextPaths: slices.Clone(cfg.Options.ContextPaths),
			AllowedTools: []string{
				"glob",
				"grep",
				"ls",
				"sourcegraph",
				"view",
			},
			// NO MCPs or LSPs by default
			AllowedMCP: map[string][]string{},
			AllowedLSP: []string{},
		},
	}
	mergeAgents(cfg, globalCfg, localConfig)
	mergeMCPs(cfg, globalCfg, localConfig)
	mergeLSPs(cfg, globalCfg, localConfig)

	if err := cfg.Validate(); err != nil {
		return cfg, fmt.Errorf("configuration validation failed: %w", err)
	}

	return cfg, nil
}
337
338func Init(workingDir string, debug bool) (*Config, error) {
339 var err error
340 once.Do(func() {
341 cwd = workingDir
342 instance, err = loadConfig(cwd, debug)
343 if err != nil {
344 logging.Error("Failed to load config", "error", err)
345 }
346 })
347
348 return instance, err
349}
350
351func Get() *Config {
352 if instance == nil {
353 // TODO: Handle this better
354 panic("Config not initialized. Call InitConfig first.")
355 }
356 return instance
357}
358
// getPreferredProvider picks the provider whose default models become the
// preferred models. Providers listed by Providers() win, in that order. If
// none of them is configured and enabled, the enabled configured provider with
// the smallest ID is used. Returns nil when every configured provider is
// disabled (or none is configured).
func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
	for _, p := range Providers() {
		if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
			return &providerConfig
		}
	}

	// Map iteration order is random. Sort the IDs so the fallback choice is the
	// same on every run.
	ids := make([]provider.InferenceProvider, 0, len(configuredProviders))
	for id := range configuredProviders {
		ids = append(ids, id)
	}
	slices.Sort(ids)
	for _, id := range ids {
		if providerConfig := configuredProviders[id]; !providerConfig.Disabled {
			return &providerConfig
		}
	}
	return nil
}
374
// mergeProviderConfig overlays other onto base for provider p and returns the
// merged value. Non-empty fields of other win:
//   - APIKey, DefaultLargeModel and DefaultSmallModel can always be overridden;
//   - BaseURL, ProviderType, ExtraHeaders and ExtraParams can only be
//     overridden for custom providers (not in provider.KnownProviders()), and
//     map entries are merged key by key;
//   - Disabled can be switched on but never back off;
//   - models from other are appended unless base already has one with the
//     same ID.
func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
	if other.APIKey != "" {
		base.APIKey = other.APIKey
	}
	// Only change these options if the provider is not a known provider
	if !slices.Contains(provider.KnownProviders(), p) {
		if other.BaseURL != "" {
			base.BaseURL = other.BaseURL
		}
		if other.ProviderType != "" {
			base.ProviderType = other.ProviderType
		}
		if len(other.ExtraHeaders) > 0 {
			if base.ExtraHeaders == nil {
				base.ExtraHeaders = make(map[string]string)
			}
			maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
		}
		if len(other.ExtraParams) > 0 {
			if base.ExtraParams == nil {
				base.ExtraParams = make(map[string]string)
			}
			maps.Copy(base.ExtraParams, other.ExtraParams)
		}
	}

	if other.Disabled {
		base.Disabled = true
	}

	if other.DefaultLargeModel != "" {
		base.DefaultLargeModel = other.DefaultLargeModel
	}
	// Merge the small default too, so config files can override it just like
	// the large one.
	if other.DefaultSmallModel != "" {
		base.DefaultSmallModel = other.DefaultSmallModel
	}

	// Add new models if they don't exist yet. base.Models grows as we go, so a
	// model repeated within other is only added once.
	for _, model := range other.Models {
		exists := slices.ContainsFunc(base.Models, func(m Model) bool { return m.ID == model.ID })
		if !exists {
			base.Models = append(base.Models, model)
		}
	}

	return base
}
427
// validateProvider reports whether provider p may be kept after merging.
// Built-in providers (provider.KnownProviders()) always pass. Custom providers
// must use the OpenAI provider type and set both a base URL and an API key.
// The first failing requirement is returned.
func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
	if slices.Contains(provider.KnownProviders(), p) {
		return nil
	}

	switch {
	case providerConfig.ProviderType != provider.TypeOpenAI:
		return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
	case providerConfig.BaseURL == "":
		return errors.New("base URL must be set for custom providers")
	case providerConfig.APIKey == "":
		return errors.New("API key must be set for custom providers")
	default:
		return nil
	}
}
442
443func mergeModels(base, global, local *Config) {
444 for _, cfg := range []*Config{global, local} {
445 if cfg == nil {
446 continue
447 }
448 if cfg.Models.Large.ModelID != "" && cfg.Models.Large.Provider != "" {
449 base.Models.Large = cfg.Models.Large
450 }
451
452 if cfg.Models.Small.ModelID != "" && cfg.Models.Small.Provider != "" {
453 base.Models.Small = cfg.Models.Small
454 }
455 }
456}
457
458func mergeOptions(base, global, local *Config) {
459 for _, cfg := range []*Config{global, local} {
460 if cfg == nil {
461 continue
462 }
463 baseOptions := base.Options
464 other := cfg.Options
465 if len(other.ContextPaths) > 0 {
466 baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
467 }
468
469 if other.TUI.CompactMode {
470 baseOptions.TUI.CompactMode = other.TUI.CompactMode
471 }
472
473 if other.Debug {
474 baseOptions.Debug = other.Debug
475 }
476
477 if other.DebugLSP {
478 baseOptions.DebugLSP = other.DebugLSP
479 }
480
481 if other.DisableAutoSummarize {
482 baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
483 }
484
485 if other.DataDirectory != "" {
486 baseOptions.DataDirectory = other.DataDirectory
487 }
488 base.Options = baseOptions
489 }
490}
491
// mergeAgents applies the agent definitions from the global and then the local
// config onto base.Agents (local wins).
//
// Agents not yet in base.Agents are added with ID forced to the map key, Model
// defaulting to LargeModel, and ContextPaths set to base.Options.ContextPaths
// followed by the agent's own paths.
//
// For the built-in agents (AgentCoder, AgentTask) only Model, AllowedMCP,
// AllowedLSP and ContextPaths can be changed. Other existing agents accept
// every field; their Disabled flag is always overwritten (false included).
// ContextPaths are additive in every case.
//
// Every append goes into freshly allocated storage. The built-in agents may
// start out sharing the backing array of base.Options.ContextPaths, so
// appending in place could let one agent's paths overwrite another's.
func mergeAgents(base, global, local *Config) {
	for _, cfg := range []*Config{global, local} {
		if cfg == nil {
			continue
		}
		for agentID, newAgent := range cfg.Agents {
			baseAgent, exists := base.Agents[agentID]
			if !exists {
				// New agent - apply defaults
				newAgent.ID = agentID // Ensure the ID is set correctly
				if newAgent.Model == "" {
					newAgent.Model = LargeModel
				}
				// Global context paths first, then the agent's own.
				newAgent.ContextPaths = append(slices.Clone(base.Options.ContextPaths), newAgent.ContextPaths...)
				base.Agents[agentID] = newAgent
				continue
			}

			if agentID == AgentCoder || agentID == AgentTask {
				// Built-in agents: only model, MCP, LSP and context paths are configurable.
				if newAgent.Model != "" {
					baseAgent.Model = newAgent.Model
				}
				if newAgent.AllowedMCP != nil {
					baseAgent.AllowedMCP = newAgent.AllowedMCP
				}
				if newAgent.AllowedLSP != nil {
					baseAgent.AllowedLSP = newAgent.AllowedLSP
				}
			} else {
				// Custom agents - allow full merging
				if newAgent.Name != "" {
					baseAgent.Name = newAgent.Name
				}
				if newAgent.Description != "" {
					baseAgent.Description = newAgent.Description
				}
				if newAgent.Model != "" {
					baseAgent.Model = newAgent.Model
				} else if baseAgent.Model == "" {
					baseAgent.Model = LargeModel // Default fallback
				}

				// Boolean fields - always update (including false values)
				baseAgent.Disabled = newAgent.Disabled

				// Slice/Map fields - update if provided (including empty slices/maps)
				if newAgent.AllowedTools != nil {
					baseAgent.AllowedTools = newAgent.AllowedTools
				}
				if newAgent.AllowedMCP != nil {
					baseAgent.AllowedMCP = newAgent.AllowedMCP
				}
				if newAgent.AllowedLSP != nil {
					baseAgent.AllowedLSP = newAgent.AllowedLSP
				}
			}

			// Context paths are additive for every agent. Clip forces append to
			// allocate instead of writing into a possibly shared backing array.
			if len(newAgent.ContextPaths) > 0 {
				baseAgent.ContextPaths = append(slices.Clip(baseAgent.ContextPaths), newAgent.ContextPaths...)
			}

			base.Agents[agentID] = baseAgent
		}
	}
}
568
// mergeMCPs copies the MCP server definitions from the global and then the
// local config into base.MCP. An entry with the same name replaces the
// previous one wholesale.
func mergeMCPs(base, global, local *Config) {
	layers := [...]*Config{global, local}
	for i := range layers {
		if layers[i] != nil {
			maps.Copy(base.MCP, layers[i].MCP)
		}
	}
}
577
// mergeLSPs copies the language-server definitions from the global and then
// the local config into base.LSP. An entry with the same name replaces the
// previous one wholesale.
func mergeLSPs(base, global, local *Config) {
	layers := [...]*Config{global, local}
	for i := range layers {
		if layers[i] != nil {
			maps.Copy(base.LSP, layers[i].LSP)
		}
	}
}
586
587func mergeProviderConfigs(base, global, local *Config) {
588 for _, cfg := range []*Config{global, local} {
589 if cfg == nil {
590 continue
591 }
592 for providerName, globalProvider := range cfg.Providers {
593 if _, ok := base.Providers[providerName]; !ok {
594 base.Providers[providerName] = globalProvider
595 } else {
596 base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider)
597 }
598 }
599 }
600
601 finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
602 for providerName, providerConfig := range base.Providers {
603 err := validateProvider(providerName, providerConfig)
604 if err != nil {
605 logging.Warn("Skipping provider", "name", providerName, "error", err)
606 continue // Skip invalid providers
607 }
608 finalProviders[providerName] = providerConfig
609 }
610 base.Providers = finalProviders
611}
612
613func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig {
614 switch providerId {
615 case provider.InferenceProviderAnthropic:
616 return ProviderConfig{
617 ID: providerId,
618 ProviderType: provider.TypeAnthropic,
619 }
620 case provider.InferenceProviderOpenAI:
621 return ProviderConfig{
622 ID: providerId,
623 ProviderType: provider.TypeOpenAI,
624 }
625 case provider.InferenceProviderGemini:
626 return ProviderConfig{
627 ID: providerId,
628 ProviderType: provider.TypeGemini,
629 }
630 case provider.InferenceProviderBedrock:
631 return ProviderConfig{
632 ID: providerId,
633 ProviderType: provider.TypeBedrock,
634 }
635 case provider.InferenceProviderAzure:
636 return ProviderConfig{
637 ID: providerId,
638 ProviderType: provider.TypeAzure,
639 }
640 case provider.InferenceProviderOpenRouter:
641 return ProviderConfig{
642 ID: providerId,
643 ProviderType: provider.TypeOpenAI,
644 BaseURL: "https://openrouter.ai/api/v1",
645 ExtraHeaders: map[string]string{
646 "HTTP-Referer": "crush.charm.land",
647 "X-Title": "Crush",
648 },
649 }
650 case provider.InferenceProviderXAI:
651 return ProviderConfig{
652 ID: providerId,
653 ProviderType: provider.TypeXAI,
654 BaseURL: "https://api.x.ai/v1",
655 }
656 case provider.InferenceProviderVertexAI:
657 return ProviderConfig{
658 ID: providerId,
659 ProviderType: provider.TypeVertexAI,
660 }
661 default:
662 return ProviderConfig{
663 ID: providerId,
664 ProviderType: provider.TypeOpenAI,
665 }
666 }
667}
668
// defaultConfigBasedOnEnv returns the baseline configuration before any config
// file is applied:
//   - default data directory and a private copy of defaultContextPaths;
//   - every catalog provider (Providers()) whose APIKey is a "$ENV_VAR"
//     reference resolving to a non-empty value, with its models and default
//     model IDs (an APIEndpoint of the form "$ENV_VAR" is resolved the same
//     way);
//   - Vertex AI when GOOGLE_GENAI_USE_VERTEXAI=true, with project and location
//     taken from GOOGLE_CLOUD_PROJECT / GOOGLE_CLOUD_LOCATION;
//   - Bedrock when hasAWSCredentials reports usable AWS settings, with region
//     from AWS_DEFAULT_REGION, falling back to AWS_REGION.
func defaultConfigBasedOnEnv() *Config {
	cfg := &Config{
		Options: Options{
			DataDirectory: defaultDataDirectory,
			// Clone so later appends or edits to this config can never touch the
			// package-level default list.
			ContextPaths: slices.Clone(defaultContextPaths),
		},
		Providers: make(map[provider.InferenceProvider]ProviderConfig),
		Agents:    make(map[AgentID]Agent),
		LSP:       make(map[string]LSPConfig),
		MCP:       make(map[string]MCP),
	}

	for _, p := range Providers() {
		// Only catalog entries whose key references an environment variable
		// that is actually set are enabled.
		if !strings.HasPrefix(p.APIKey, "$") {
			continue
		}
		apiKey := os.Getenv(strings.TrimPrefix(p.APIKey, "$"))
		if apiKey == "" {
			continue
		}

		providerConfig := providerDefaultConfig(p.ID)
		providerConfig.APIKey = apiKey
		providerConfig.DefaultLargeModel = p.DefaultLargeModelID
		providerConfig.DefaultSmallModel = p.DefaultSmallModelID

		baseURL := p.APIEndpoint
		if strings.HasPrefix(baseURL, "$") {
			baseURL = os.Getenv(strings.TrimPrefix(baseURL, "$"))
		}
		providerConfig.BaseURL = baseURL

		for _, model := range p.Models {
			configModel := Model{
				ID:                 model.ID,
				Name:               model.Name,
				CostPer1MIn:        model.CostPer1MIn,
				CostPer1MOut:       model.CostPer1MOut,
				CostPer1MInCached:  model.CostPer1MInCached,
				CostPer1MOutCached: model.CostPer1MOutCached,
				ContextWindow:      model.ContextWindow,
				DefaultMaxTokens:   model.DefaultMaxTokens,
				CanReason:          model.CanReason,
				SupportsImages:     model.SupportsImages,
			}
			// Reasoning effort is only recorded when the model supports it and
			// the catalog supplies a default.
			if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
				configModel.HasReasoningEffort = model.HasReasoningEffort
				configModel.ReasoningEffort = model.DefaultReasoningEffort
			}
			providerConfig.Models = append(providerConfig.Models, configModel)
		}
		cfg.Providers[p.ID] = providerConfig
	}
	// TODO: support local models

	if os.Getenv("GOOGLE_GENAI_USE_VERTEXAI") == "true" {
		providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
		providerConfig.ExtraParams = map[string]string{
			"project":  os.Getenv("GOOGLE_CLOUD_PROJECT"),
			"location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
		}
		cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
	}

	if hasAWSCredentials() {
		providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
		region := os.Getenv("AWS_DEFAULT_REGION")
		if region == "" {
			region = os.Getenv("AWS_REGION")
		}
		providerConfig.ExtraParams = map[string]string{
			"region": region,
		}
		cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
	}
	return cfg
}
744
// hasAWSCredentials reports whether the environment suggests that AWS
// credentials can be resolved. Any one of the following is enough:
//   - both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are set;
//   - AWS_PROFILE or AWS_DEFAULT_PROFILE is set;
//   - AWS_REGION or AWS_DEFAULT_REGION is set;
//   - a container credentials URI (relative or full) is set.
//
// Only presence (non-empty value) is checked; nothing is validated.
func hasAWSCredentials() bool {
	present := func(key string) bool { return os.Getenv(key) != "" }

	staticKeys := present("AWS_ACCESS_KEY_ID") && present("AWS_SECRET_ACCESS_KEY")
	profile := present("AWS_PROFILE") || present("AWS_DEFAULT_PROFILE")
	region := present("AWS_REGION") || present("AWS_DEFAULT_REGION")
	container := present("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") ||
		present("AWS_CONTAINER_CREDENTIALS_FULL_URI")

	return staticKeys || profile || region || container
}
765
766func WorkingDirectory() string {
767 return cwd
768}
769
770// TODO: Handle error state
771
772func GetAgentModel(agentID AgentID) Model {
773 cfg := Get()
774 agent, ok := cfg.Agents[agentID]
775 if !ok {
776 logging.Error("Agent not found", "agent_id", agentID)
777 return Model{}
778 }
779
780 var model PreferredModel
781 switch agent.Model {
782 case LargeModel:
783 model = cfg.Models.Large
784 case SmallModel:
785 model = cfg.Models.Small
786 default:
787 logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
788 model = cfg.Models.Large // Fallback to large model
789 }
790 providerConfig, ok := cfg.Providers[model.Provider]
791 if !ok {
792 logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
793 return Model{}
794 }
795
796 for _, m := range providerConfig.Models {
797 if m.ID == model.ModelID {
798 return m
799 }
800 }
801
802 logging.Error("Model not found for agent", "agent_id", agentID, "model", agent.Model)
803 return Model{}
804}
805
806func GetAgentProvider(agentID AgentID) ProviderConfig {
807 cfg := Get()
808 agent, ok := cfg.Agents[agentID]
809 if !ok {
810 logging.Error("Agent not found", "agent_id", agentID)
811 return ProviderConfig{}
812 }
813
814 var model PreferredModel
815 switch agent.Model {
816 case LargeModel:
817 model = cfg.Models.Large
818 case SmallModel:
819 model = cfg.Models.Small
820 default:
821 logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
822 model = cfg.Models.Large // Fallback to large model
823 }
824
825 providerConfig, ok := cfg.Providers[model.Provider]
826 if !ok {
827 logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
828 return ProviderConfig{}
829 }
830
831 return providerConfig
832}
833
// GetProviderModel looks up the model with ID modelID in the given provider's
// model list. An unknown provider or model is logged and yields a zero Model.
func GetProviderModel(provider provider.InferenceProvider, modelID string) Model {
	providers := Get().Providers
	providerCfg, found := providers[provider]
	if !found {
		logging.Error("Provider not found", "provider", provider)
		return Model{}
	}

	idx := slices.IndexFunc(providerCfg.Models, func(m Model) bool { return m.ID == modelID })
	if idx >= 0 {
		return providerCfg.Models[idx]
	}

	logging.Error("Model not found for provider", "provider", provider, "model_id", modelID)
	return Model{}
}
851
852func GetModel(modelType ModelType) Model {
853 cfg := Get()
854 var model PreferredModel
855 switch modelType {
856 case LargeModel:
857 model = cfg.Models.Large
858 case SmallModel:
859 model = cfg.Models.Small
860 default:
861 model = cfg.Models.Large // Fallback to large model
862 }
863 providerConfig, ok := cfg.Providers[model.Provider]
864 if !ok {
865 return Model{}
866 }
867
868 for _, m := range providerConfig.Models {
869 if m.ID == model.ModelID {
870 return m
871 }
872 }
873 return Model{}
874}
875
876func UpdatePreferredModel(modelType ModelType, model PreferredModel) error {
877 cfg := Get()
878 switch modelType {
879 case LargeModel:
880 cfg.Models.Large = model
881 case SmallModel:
882 cfg.Models.Small = model
883 default:
884 return fmt.Errorf("unknown model type: %s", modelType)
885 }
886 return nil
887}
888
889// ValidationError represents a configuration validation error
890type ValidationError struct {
891 Field string
892 Message string
893}
894
895func (e ValidationError) Error() string {
896 return fmt.Sprintf("validation error in %s: %s", e.Field, e.Message)
897}
898
899// ValidationErrors represents multiple validation errors
900type ValidationErrors []ValidationError
901
902func (e ValidationErrors) Error() string {
903 if len(e) == 0 {
904 return "no validation errors"
905 }
906 if len(e) == 1 {
907 return e[0].Error()
908 }
909
910 var messages []string
911 for _, err := range e {
912 messages = append(messages, err.Error())
913 }
914 return fmt.Sprintf("multiple validation errors: %s", strings.Join(messages, "; "))
915}
916
917// HasErrors returns true if there are any validation errors
918func (e ValidationErrors) HasErrors() bool {
919 return len(e) > 0
920}
921
922// Add appends a new validation error
923func (e *ValidationErrors) Add(field, message string) {
924 *e = append(*e, ValidationError{Field: field, Message: message})
925}
926
927// Validate performs comprehensive validation of the configuration
928func (c *Config) Validate() error {
929 var errors ValidationErrors
930
931 // Validate providers
932 c.validateProviders(&errors)
933
934 // Validate models
935 c.validateModels(&errors)
936
937 // Validate agents
938 c.validateAgents(&errors)
939
940 // Validate options
941 c.validateOptions(&errors)
942
943 // Validate MCP configurations
944 c.validateMCPs(&errors)
945
946 // Validate LSP configurations
947 c.validateLSPs(&errors)
948
949 // Validate cross-references
950 c.validateCrossReferences(&errors)
951
952 // Validate completeness
953 c.validateCompleteness(&errors)
954
955 if errors.HasErrors() {
956 return errors
957 }
958
959 return nil
960}
961
// validateProviders checks every configured provider and records each problem
// in errs. For each provider it verifies that:
//   - an API key is present for enabled providers (Bedrock and Vertex AI are
//     exempt because they can authenticate through other means);
//   - the provider type is one of the supported types;
//   - custom (non built-in) providers set a BaseURL and use the OpenAI type;
//   - every model has a non-empty, unique ID, a name, a positive context
//     window and max-token count (the latter not exceeding the window), and
//     non-negative costs;
//   - DefaultLargeModel/DefaultSmallModel, when set, name one of its models;
//   - provider-specific settings are present (validateProviderSpecific).
//
// Providers are visited in sorted ID order so the resulting error list is
// stable across runs. A nil Providers map is replaced with an empty one.
func (c *Config) validateProviders(errs *ValidationErrors) {
	if c.Providers == nil {
		c.Providers = make(map[provider.InferenceProvider]ProviderConfig)
	}

	knownProviders := provider.KnownProviders()
	validTypes := []provider.Type{
		provider.TypeOpenAI,
		provider.TypeAnthropic,
		provider.TypeGemini,
		provider.TypeAzure,
		provider.TypeBedrock,
		provider.TypeVertexAI,
		provider.TypeXAI,
	}

	// Map iteration order is random; sort so errors come out deterministically.
	providerIDs := make([]provider.InferenceProvider, 0, len(c.Providers))
	for id := range c.Providers {
		providerIDs = append(providerIDs, id)
	}
	slices.Sort(providerIDs)

	for _, providerID := range providerIDs {
		providerConfig := c.Providers[providerID]
		fieldPrefix := fmt.Sprintf("providers.%s", providerID)

		// Validate API key for non-disabled providers
		if !providerConfig.Disabled && providerConfig.APIKey == "" {
			// Special case for AWS Bedrock and VertexAI which may use other auth methods
			if providerID != provider.InferenceProviderBedrock && providerID != provider.InferenceProviderVertexAI {
				errs.Add(fieldPrefix+".api_key", "API key is required for non-disabled providers")
			}
		}

		if !slices.Contains(validTypes, providerConfig.ProviderType) {
			errs.Add(fieldPrefix+".provider_type", fmt.Sprintf("invalid provider type: %s", providerConfig.ProviderType))
		}

		// Custom provider validation
		if !slices.Contains(knownProviders, providerID) {
			if providerConfig.BaseURL == "" {
				errs.Add(fieldPrefix+".base_url", "BaseURL is required for custom providers")
			}
			if providerConfig.ProviderType != provider.TypeOpenAI {
				errs.Add(fieldPrefix+".provider_type", "custom providers currently only support OpenAI type")
			}
		}

		modelIDs := make(map[string]bool, len(providerConfig.Models))
		for i, model := range providerConfig.Models {
			modelFieldPrefix := fmt.Sprintf("%s.models[%d]", fieldPrefix, i)

			// Check emptiness before uniqueness, so that two models without an
			// ID are reported as missing IDs rather than as duplicates.
			switch {
			case model.ID == "":
				errs.Add(modelFieldPrefix+".id", "model ID is required")
			case modelIDs[model.ID]:
				errs.Add(modelFieldPrefix+".id", fmt.Sprintf("duplicate model ID: %s", model.ID))
			default:
				modelIDs[model.ID] = true
			}

			if model.Name == "" {
				errs.Add(modelFieldPrefix+".name", "model name is required")
			}
			if model.ContextWindow <= 0 {
				errs.Add(modelFieldPrefix+".context_window", "context window must be positive")
			}
			if model.DefaultMaxTokens <= 0 {
				errs.Add(modelFieldPrefix+".default_max_tokens", "default max tokens must be positive")
			}
			// Only meaningful for a valid (positive) window. An invalid window
			// has already been reported above.
			if model.ContextWindow > 0 && model.DefaultMaxTokens > model.ContextWindow {
				errs.Add(modelFieldPrefix+".default_max_tokens", "default max tokens cannot exceed context window")
			}

			if model.CostPer1MIn < 0 {
				errs.Add(modelFieldPrefix+".cost_per_1m_in", "cost per 1M input tokens cannot be negative")
			}
			if model.CostPer1MOut < 0 {
				errs.Add(modelFieldPrefix+".cost_per_1m_out", "cost per 1M output tokens cannot be negative")
			}
			if model.CostPer1MInCached < 0 {
				errs.Add(modelFieldPrefix+".cost_per_1m_in_cached", "cached cost per 1M input tokens cannot be negative")
			}
			if model.CostPer1MOutCached < 0 {
				errs.Add(modelFieldPrefix+".cost_per_1m_out_cached", "cached cost per 1M output tokens cannot be negative")
			}
		}

		// Validate default model references
		if providerConfig.DefaultLargeModel != "" && !modelIDs[providerConfig.DefaultLargeModel] {
			errs.Add(fieldPrefix+".default_large_model", fmt.Sprintf("default large model '%s' not found in provider models", providerConfig.DefaultLargeModel))
		}
		if providerConfig.DefaultSmallModel != "" && !modelIDs[providerConfig.DefaultSmallModel] {
			errs.Add(fieldPrefix+".default_small_model", fmt.Sprintf("default small model '%s' not found in provider models", providerConfig.DefaultSmallModel))
		}

		c.validateProviderSpecific(providerID, providerConfig, errs)
	}
}
1068
// validateProviderSpecific records problems with settings that only some
// providers need. Disabled providers are skipped entirely.
//   - Vertex AI needs extra_params with non-empty "project" and "location".
//   - Bedrock needs a non-empty extra_params "region" and AWS credentials
//     detectable in the environment (hasAWSCredentials).
func (c *Config) validateProviderSpecific(providerID provider.InferenceProvider, providerConfig ProviderConfig, errs *ValidationErrors) {
	if providerConfig.Disabled {
		return
	}

	fieldPrefix := fmt.Sprintf("providers.%s", providerID)
	params := providerConfig.ExtraParams

	switch providerID {
	case provider.InferenceProviderVertexAI:
		if params == nil {
			errs.Add(fieldPrefix+".extra_params", "VertexAI requires extra_params configuration")
			return
		}
		if params["project"] == "" {
			errs.Add(fieldPrefix+".extra_params.project", "VertexAI requires project parameter")
		}
		if params["location"] == "" {
			errs.Add(fieldPrefix+".extra_params.location", "VertexAI requires location parameter")
		}
	case provider.InferenceProviderBedrock:
		// Indexing a nil map yields "", so this also covers missing extra_params.
		if params["region"] == "" {
			errs.Add(fieldPrefix+".extra_params.region", "Bedrock requires region parameter")
		}
		if !hasAWSCredentials() {
			errs.Add(fieldPrefix, "Bedrock requires AWS credentials in environment")
		}
	}
}
1099
1100// validateModels validates preferred model configurations
1101func (c *Config) validateModels(errors *ValidationErrors) {
1102 // Validate large model
1103 if c.Models.Large.ModelID != "" || c.Models.Large.Provider != "" {
1104 if c.Models.Large.ModelID == "" {
1105 errors.Add("models.large.model_id", "large model ID is required when provider is set")
1106 }
1107 if c.Models.Large.Provider == "" {
1108 errors.Add("models.large.provider", "large model provider is required when model ID is set")
1109 }
1110
1111 // Check if provider exists and is not disabled
1112 if providerConfig, exists := c.Providers[c.Models.Large.Provider]; exists {
1113 if providerConfig.Disabled {
1114 errors.Add("models.large.provider", "large model provider is disabled")
1115 }
1116
1117 // Check if model exists in provider
1118 modelExists := false
1119 for _, model := range providerConfig.Models {
1120 if model.ID == c.Models.Large.ModelID {
1121 modelExists = true
1122 break
1123 }
1124 }
1125 if !modelExists {
1126 errors.Add("models.large.model_id", fmt.Sprintf("large model '%s' not found in provider '%s'", c.Models.Large.ModelID, c.Models.Large.Provider))
1127 }
1128 } else {
1129 errors.Add("models.large.provider", fmt.Sprintf("large model provider '%s' not found", c.Models.Large.Provider))
1130 }
1131 }
1132
1133 // Validate small model
1134 if c.Models.Small.ModelID != "" || c.Models.Small.Provider != "" {
1135 if c.Models.Small.ModelID == "" {
1136 errors.Add("models.small.model_id", "small model ID is required when provider is set")
1137 }
1138 if c.Models.Small.Provider == "" {
1139 errors.Add("models.small.provider", "small model provider is required when model ID is set")
1140 }
1141
1142 // Check if provider exists and is not disabled
1143 if providerConfig, exists := c.Providers[c.Models.Small.Provider]; exists {
1144 if providerConfig.Disabled {
1145 errors.Add("models.small.provider", "small model provider is disabled")
1146 }
1147
1148 // Check if model exists in provider
1149 modelExists := false
1150 for _, model := range providerConfig.Models {
1151 if model.ID == c.Models.Small.ModelID {
1152 modelExists = true
1153 break
1154 }
1155 }
1156 if !modelExists {
1157 errors.Add("models.small.model_id", fmt.Sprintf("small model '%s' not found in provider '%s'", c.Models.Small.ModelID, c.Models.Small.Provider))
1158 }
1159 } else {
1160 errors.Add("models.small.provider", fmt.Sprintf("small model provider '%s' not found", c.Models.Small.Provider))
1161 }
1162 }
1163}
1164
1165// validateAgents validates agent configurations
1166func (c *Config) validateAgents(errors *ValidationErrors) {
1167 if c.Agents == nil {
1168 c.Agents = make(map[AgentID]Agent)
1169 }
1170
1171 validTools := []string{
1172 "bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent",
1173 }
1174
1175 for agentID, agent := range c.Agents {
1176 fieldPrefix := fmt.Sprintf("agents.%s", agentID)
1177
1178 // Validate agent ID consistency
1179 if agent.ID != agentID {
1180 errors.Add(fieldPrefix+".id", fmt.Sprintf("agent ID mismatch: expected '%s', got '%s'", agentID, agent.ID))
1181 }
1182
1183 // Validate required fields
1184 if agent.ID == "" {
1185 errors.Add(fieldPrefix+".id", "agent ID is required")
1186 }
1187 if agent.Name == "" {
1188 errors.Add(fieldPrefix+".name", "agent name is required")
1189 }
1190
1191 // Validate model type
1192 if agent.Model != LargeModel && agent.Model != SmallModel {
1193 errors.Add(fieldPrefix+".model", fmt.Sprintf("invalid model type: %s (must be 'large' or 'small')", agent.Model))
1194 }
1195
1196 // Validate allowed tools
1197 if agent.AllowedTools != nil {
1198 for i, tool := range agent.AllowedTools {
1199 validTool := slices.Contains(validTools, tool)
1200 if !validTool {
1201 errors.Add(fmt.Sprintf("%s.allowed_tools[%d]", fieldPrefix, i), fmt.Sprintf("unknown tool: %s", tool))
1202 }
1203 }
1204 }
1205
1206 // Validate MCP references
1207 if agent.AllowedMCP != nil {
1208 for mcpName := range agent.AllowedMCP {
1209 if _, exists := c.MCP[mcpName]; !exists {
1210 errors.Add(fieldPrefix+".allowed_mcp", fmt.Sprintf("referenced MCP '%s' not found", mcpName))
1211 }
1212 }
1213 }
1214
1215 // Validate LSP references
1216 if agent.AllowedLSP != nil {
1217 for _, lspName := range agent.AllowedLSP {
1218 if _, exists := c.LSP[lspName]; !exists {
1219 errors.Add(fieldPrefix+".allowed_lsp", fmt.Sprintf("referenced LSP '%s' not found", lspName))
1220 }
1221 }
1222 }
1223
1224 // Validate context paths (basic path validation)
1225 for i, contextPath := range agent.ContextPaths {
1226 if contextPath == "" {
1227 errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path cannot be empty")
1228 }
1229 // Check for invalid characters in path
1230 if strings.Contains(contextPath, "\x00") {
1231 errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path contains invalid characters")
1232 }
1233 }
1234
1235 // Validate known agents maintain their core properties
1236 if agentID == AgentCoder {
1237 if agent.Name != "Coder" {
1238 errors.Add(fieldPrefix+".name", "coder agent name cannot be changed")
1239 }
1240 if agent.Description != "An agent that helps with executing coding tasks." {
1241 errors.Add(fieldPrefix+".description", "coder agent description cannot be changed")
1242 }
1243 } else if agentID == AgentTask {
1244 if agent.Name != "Task" {
1245 errors.Add(fieldPrefix+".name", "task agent name cannot be changed")
1246 }
1247 if agent.Description != "An agent that helps with searching for context and finding implementation details." {
1248 errors.Add(fieldPrefix+".description", "task agent description cannot be changed")
1249 }
1250 expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"}
1251 if agent.AllowedTools != nil && !slices.Equal(agent.AllowedTools, expectedTools) {
1252 errors.Add(fieldPrefix+".allowed_tools", "task agent allowed tools cannot be changed")
1253 }
1254 }
1255 }
1256}
1257
1258// validateOptions validates configuration options
1259func (c *Config) validateOptions(errors *ValidationErrors) {
1260 // Validate data directory
1261 if c.Options.DataDirectory == "" {
1262 errors.Add("options.data_directory", "data directory is required")
1263 }
1264
1265 // Validate context paths
1266 for i, contextPath := range c.Options.ContextPaths {
1267 if contextPath == "" {
1268 errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path cannot be empty")
1269 }
1270 if strings.Contains(contextPath, "\x00") {
1271 errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path contains invalid characters")
1272 }
1273 }
1274}
1275
1276// validateMCPs validates MCP configurations
1277func (c *Config) validateMCPs(errors *ValidationErrors) {
1278 if c.MCP == nil {
1279 c.MCP = make(map[string]MCP)
1280 }
1281
1282 for mcpName, mcpConfig := range c.MCP {
1283 fieldPrefix := fmt.Sprintf("mcp.%s", mcpName)
1284
1285 // Validate MCP type
1286 if mcpConfig.Type != MCPStdio && mcpConfig.Type != MCPSse {
1287 errors.Add(fieldPrefix+".type", fmt.Sprintf("invalid MCP type: %s (must be 'stdio' or 'sse')", mcpConfig.Type))
1288 }
1289
1290 // Validate based on type
1291 if mcpConfig.Type == MCPStdio {
1292 if mcpConfig.Command == "" {
1293 errors.Add(fieldPrefix+".command", "command is required for stdio MCP")
1294 }
1295 } else if mcpConfig.Type == MCPSse {
1296 if mcpConfig.URL == "" {
1297 errors.Add(fieldPrefix+".url", "URL is required for SSE MCP")
1298 }
1299 }
1300 }
1301}
1302
1303// validateLSPs validates LSP configurations
1304func (c *Config) validateLSPs(errors *ValidationErrors) {
1305 if c.LSP == nil {
1306 c.LSP = make(map[string]LSPConfig)
1307 }
1308
1309 for lspName, lspConfig := range c.LSP {
1310 fieldPrefix := fmt.Sprintf("lsp.%s", lspName)
1311
1312 if lspConfig.Command == "" {
1313 errors.Add(fieldPrefix+".command", "command is required for LSP")
1314 }
1315 }
1316}
1317
1318// validateCrossReferences validates cross-references between different config sections
1319func (c *Config) validateCrossReferences(errors *ValidationErrors) {
1320 // Validate that agents can use their assigned model types
1321 for agentID, agent := range c.Agents {
1322 fieldPrefix := fmt.Sprintf("agents.%s", agentID)
1323
1324 var preferredModel PreferredModel
1325 switch agent.Model {
1326 case LargeModel:
1327 preferredModel = c.Models.Large
1328 case SmallModel:
1329 preferredModel = c.Models.Small
1330 }
1331
1332 if preferredModel.Provider != "" {
1333 if providerConfig, exists := c.Providers[preferredModel.Provider]; exists {
1334 if providerConfig.Disabled {
1335 errors.Add(fieldPrefix+".model", fmt.Sprintf("agent cannot use model type '%s' because provider '%s' is disabled", agent.Model, preferredModel.Provider))
1336 }
1337 }
1338 }
1339 }
1340}
1341
1342// validateCompleteness validates that the configuration is complete and usable
1343func (c *Config) validateCompleteness(errors *ValidationErrors) {
1344 // Check for at least one valid, non-disabled provider
1345 hasValidProvider := false
1346 for _, providerConfig := range c.Providers {
1347 if !providerConfig.Disabled {
1348 hasValidProvider = true
1349 break
1350 }
1351 }
1352 if !hasValidProvider {
1353 errors.Add("providers", "at least one non-disabled provider is required")
1354 }
1355
1356 // Check that default agents exist
1357 if _, exists := c.Agents[AgentCoder]; !exists {
1358 errors.Add("agents", "coder agent is required")
1359 }
1360 if _, exists := c.Agents[AgentTask]; !exists {
1361 errors.Add("agents", "task agent is required")
1362 }
1363
1364 // Check that preferred models are set if providers exist
1365 if hasValidProvider {
1366 if c.Models.Large.ModelID == "" || c.Models.Large.Provider == "" {
1367 errors.Add("models.large", "large preferred model must be configured when providers are available")
1368 }
1369 if c.Models.Small.ModelID == "" || c.Models.Small.Provider == "" {
1370 errors.Add("models.small", "small preferred model must be configured when providers are available")
1371 }
1372 }
1373}