1package config
2
3import (
4 "encoding/json"
5 "errors"
6 "fmt"
7 "log/slog"
8 "maps"
9 "os"
10 "path/filepath"
11 "slices"
12 "strings"
13 "sync"
14
15 "github.com/charmbracelet/crush/internal/fur/provider"
16 "github.com/charmbracelet/crush/internal/logging"
17)
18
// Package-level defaults used when building a Config.
const (
	// defaultDataDirectory is the default Options.DataDirectory (a relative
	// path); loadConfig writes debug.log there when CRUSH_DEV_DEBUG=true.
	defaultDataDirectory = ".crush"
	// defaultLogLevel and appName are not referenced in the visible part of
	// this file.
	defaultLogLevel = "info"
	appName         = "crush"

	// MaxTokensFallbackDefault is a fallback max-token count; presumably used
	// by callers when a model does not define its own limit — confirm.
	MaxTokensFallbackDefault = 4096
)

// defaultContextPaths is the default value of Options.ContextPaths: project
// instruction files (and the .cursor/rules/ directory) picked up as context.
// This is shared package state; configs should copy it rather than mutate it.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
42
// AgentID identifies an agent; it is the key of Config.Agents.
type AgentID string

// Built-in agents. loadConfig always defines both before applying config
// files, and mergeAgents restricts which of their fields may be overridden.
const (
	AgentCoder AgentID = "coder"
	AgentTask  AgentID = "task"
)

// ModelType selects which preferred model an agent uses:
// Config.Models.Large or Config.Models.Small.
type ModelType string

const (
	LargeModel ModelType = "large"
	SmallModel ModelType = "small"
)
56
// Model describes one model offered by a provider: pricing, limits and
// capabilities. Note that some JSON keys differ from the Go field names:
// Name is serialized as "model" and SupportsImages as "supports_attachments".
type Model struct {
	ID                 string  `json:"id"`
	Name               string  `json:"model"`
	CostPer1MIn        float64 `json:"cost_per_1m_in"`
	CostPer1MOut       float64 `json:"cost_per_1m_out"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
	ContextWindow      int64   `json:"context_window"`
	DefaultMaxTokens   int64   `json:"default_max_tokens"`
	CanReason          bool    `json:"can_reason"`
	ReasoningEffort    string  `json:"reasoning_effort"`
	SupportsImages     bool    `json:"supports_attachments"`
}

// VertexAIOptions holds Vertex AI settings. It is not referenced in the
// visible part of this file (Vertex AI project/location are currently passed
// through ProviderConfig.ExtraParams).
type VertexAIOptions struct {
	APIKey   string `json:"api_key,omitempty"`
	Project  string `json:"project,omitempty"`
	Location string `json:"location,omitempty"`
}

// ProviderConfig is the effective configuration of one inference provider.
type ProviderConfig struct {
	// ID should equal the provider's key in Config.Providers: loadConfig
	// copies it into Config.Models, which is later used as that map key.
	ID provider.InferenceProvider `json:"id"`
	// BaseURL is only overridable from config files for custom providers.
	BaseURL string `json:"base_url,omitempty"`
	// ProviderType is the provider.Type used for this provider; custom
	// providers must use provider.TypeOpenAI.
	ProviderType provider.Type `json:"provider_type"`
	APIKey       string        `json:"api_key,omitempty"`
	// Disabled providers are skipped when choosing the preferred provider.
	Disabled     bool              `json:"disabled"`
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
	// Provider-specific settings, e.g. "project"/"location" for Vertex AI
	// and "region" for Bedrock.
	ExtraParams map[string]string `json:"extra_params,omitempty"`

	// Model IDs used to seed Config.Models when this provider is preferred.
	DefaultLargeModel string `json:"default_large_model,omitempty"`
	DefaultSmallModel string `json:"default_small_model,omitempty"`

	// Models offered by this provider.
	Models []Model `json:"models,omitempty"`
}
92
// Agent describes an agent: which preferred model it uses and which tools,
// MCP servers, LSPs and context files it may access.
type Agent struct {
	ID          AgentID `json:"id"`
	Name        string  `json:"name"`
	Description string  `json:"description,omitempty"`
	// Disabled turns the agent off. For custom agents mergeAgents always
	// copies this flag from the overriding config, including false.
	Disabled bool `json:"disabled"`

	// Model selects Config.Models.Large or Config.Models.Small; mergeAgents
	// defaults it to LargeModel when left empty.
	Model ModelType `json:"model"`

	// The available tools for the agent
	// if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools"`

	// this tells us which MCPs are available for this agent
	// if this is empty all mcps are available
	// the string array is the list of tools from the AllowedMCP the agent has available
	// if the string array is nil, all tools from the AllowedMCP are available
	AllowedMCP map[string][]string `json:"allowed_mcp"`

	// The list of LSPs that this agent can use
	// if this is nil, all LSPs are available
	AllowedLSP []string `json:"allowed_lsp"`

	// Additional context paths for this agent. They are added to (not a
	// replacement for) the global Options.ContextPaths when configs merge.
	ContextPaths []string `json:"context_paths"`
}
119
// MCPType selects how an MCP server is reached ("stdio" or "sse").
type MCPType string

const (
	MCPStdio MCPType = "stdio"
	MCPSse   MCPType = "sse"
)

// MCP configures one MCP server; its key in Config.MCP is the server name.
// NOTE(review): Command/Args/Env presumably apply to stdio servers and
// URL/Headers to sse servers — confirm against the MCP client code.
type MCP struct {
	Command string            `json:"command"`
	Env     []string          `json:"env"`
	Args    []string          `json:"args"`
	Type    MCPType           `json:"type"`
	URL     string            `json:"url"`
	Headers map[string]string `json:"headers"`
}
135
// LSPConfig configures one language server; its key in Config.LSP is the
// server name.
//
// Disabled is read from the "disabled" JSON key. It used to be tagged
// "enabled", which inverted the meaning of config files: writing
// "enabled": true actually disabled the server.
type LSPConfig struct {
	Disabled bool     `json:"disabled"`
	Command  string   `json:"command"`
	Args     []string `json:"args"`
	Options  any      `json:"options"`
}
142
// TUIOptions holds terminal-UI preferences.
type TUIOptions struct {
	CompactMode bool `json:"compact_mode"`
	// Here we can add themes later or any TUI related options
}

// Options holds miscellaneous settings. When config files are merged,
// context paths are appended, the boolean flags can only be switched on,
// and a non-empty data directory replaces the default.
type Options struct {
	ContextPaths         []string   `json:"context_paths"`
	TUI                  TUIOptions `json:"tui"`
	Debug                bool       `json:"debug"`
	DebugLSP             bool       `json:"debug_lsp"`
	DisableAutoSummarize bool       `json:"disable_auto_summarize"`
	// Relative to the cwd (defaults to defaultDataDirectory, ".crush").
	DataDirectory string `json:"data_directory"`
}

// PreferredModel points at a model by provider ID and model ID.
type PreferredModel struct {
	ModelID  string                     `json:"model_id"`
	Provider provider.InferenceProvider `json:"provider"`
}

// PreferredModels holds the model used for each ModelType.
type PreferredModels struct {
	Large PreferredModel `json:"large"`
	Small PreferredModel `json:"small"`
}

// Config is the merged configuration (environment defaults, global config
// file, then project-local crush.json).
type Config struct {
	Models PreferredModels `json:"models"`
	// Configured providers, keyed by provider ID.
	Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"`

	// Configured agents, keyed by agent ID.
	Agents map[AgentID]Agent `json:"agents,omitempty"`

	// Configured MCP servers, keyed by name.
	MCP map[string]MCP `json:"mcp,omitempty"`

	// Configured language servers, keyed by name.
	LSP map[string]LSPConfig `json:"lsp,omitempty"`

	// Miscellaneous options
	Options Options `json:"options"`
}
185
// Process-wide configuration state, written once by Init.
var (
	instance *Config   // config loaded by Init; Get panics while this is nil
	cwd      string    // working directory passed to Init; see WorkingDirectory
	once     sync.Once // ensures Init loads the configuration only once
)
192
193func loadConfig(cwd string, debug bool) (*Config, error) {
194 // First read the global config file
195 cfgPath := ConfigPath()
196
197 cfg := defaultConfigBasedOnEnv()
198 cfg.Options.Debug = debug
199 defaultLevel := slog.LevelInfo
200 if cfg.Options.Debug {
201 defaultLevel = slog.LevelDebug
202 }
203 if os.Getenv("CRUSH_DEV_DEBUG") == "true" {
204 loggingFile := fmt.Sprintf("%s/%s", cfg.Options.DataDirectory, "debug.log")
205
206 // if file does not exist create it
207 if _, err := os.Stat(loggingFile); os.IsNotExist(err) {
208 if err := os.MkdirAll(cfg.Options.DataDirectory, 0o755); err != nil {
209 return cfg, fmt.Errorf("failed to create directory: %w", err)
210 }
211 if _, err := os.Create(loggingFile); err != nil {
212 return cfg, fmt.Errorf("failed to create log file: %w", err)
213 }
214 }
215
216 sloggingFileWriter, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666)
217 if err != nil {
218 return cfg, fmt.Errorf("failed to open log file: %w", err)
219 }
220 // Configure logger
221 logger := slog.New(slog.NewTextHandler(sloggingFileWriter, &slog.HandlerOptions{
222 Level: defaultLevel,
223 }))
224 slog.SetDefault(logger)
225 } else {
226 // Configure logger
227 logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{
228 Level: defaultLevel,
229 }))
230 slog.SetDefault(logger)
231 }
232 var globalCfg *Config
233 if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
234 // some other error occurred while checking the file
235 return nil, err
236 } else if err == nil {
237 // config file exists, read it
238 file, err := os.ReadFile(cfgPath)
239 if err != nil {
240 return nil, err
241 }
242 globalCfg = &Config{}
243 if err := json.Unmarshal(file, globalCfg); err != nil {
244 return nil, err
245 }
246 } else {
247 // config file does not exist, create a new one
248 globalCfg = &Config{}
249 }
250
251 var localConfig *Config
252 // Global config loaded, now read the local config file
253 localConfigPath := filepath.Join(cwd, "crush.json")
254 if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
255 // some other error occurred while checking the file
256 return nil, err
257 } else if err == nil {
258 // local config file exists, read it
259 file, err := os.ReadFile(localConfigPath)
260 if err != nil {
261 return nil, err
262 }
263 localConfig = &Config{}
264 if err := json.Unmarshal(file, localConfig); err != nil {
265 return nil, err
266 }
267 }
268
269 // merge options
270 mergeOptions(cfg, globalCfg, localConfig)
271
272 mergeProviderConfigs(cfg, globalCfg, localConfig)
273 // no providers found the app is not initialized yet
274 if len(cfg.Providers) == 0 {
275 return cfg, nil
276 }
277 preferredProvider := getPreferredProvider(cfg.Providers)
278 if preferredProvider != nil {
279 cfg.Models = PreferredModels{
280 Large: PreferredModel{
281 ModelID: preferredProvider.DefaultLargeModel,
282 Provider: preferredProvider.ID,
283 },
284 Small: PreferredModel{
285 ModelID: preferredProvider.DefaultSmallModel,
286 Provider: preferredProvider.ID,
287 },
288 }
289 } else {
290 // No valid providers found, set empty models
291 cfg.Models = PreferredModels{}
292 }
293
294 mergeModels(cfg, globalCfg, localConfig)
295
296 agents := map[AgentID]Agent{
297 AgentCoder: {
298 ID: AgentCoder,
299 Name: "Coder",
300 Description: "An agent that helps with executing coding tasks.",
301 Model: LargeModel,
302 ContextPaths: cfg.Options.ContextPaths,
303 // All tools allowed
304 },
305 AgentTask: {
306 ID: AgentTask,
307 Name: "Task",
308 Description: "An agent that helps with searching for context and finding implementation details.",
309 Model: LargeModel,
310 ContextPaths: cfg.Options.ContextPaths,
311 AllowedTools: []string{
312 "glob",
313 "grep",
314 "ls",
315 "sourcegraph",
316 "view",
317 },
318 // NO MCPs or LSPs by default
319 AllowedMCP: map[string][]string{},
320 AllowedLSP: []string{},
321 },
322 }
323 cfg.Agents = agents
324 mergeAgents(cfg, globalCfg, localConfig)
325 mergeMCPs(cfg, globalCfg, localConfig)
326 mergeLSPs(cfg, globalCfg, localConfig)
327
328 // Validate the final configuration
329 if err := cfg.Validate(); err != nil {
330 return cfg, fmt.Errorf("configuration validation failed: %w", err)
331 }
332
333 return cfg, nil
334}
335
336func Init(workingDir string, debug bool) (*Config, error) {
337 var err error
338 once.Do(func() {
339 cwd = workingDir
340 instance, err = loadConfig(cwd, debug)
341 if err != nil {
342 logging.Error("Failed to load config", "error", err)
343 }
344 })
345
346 return instance, err
347}
348
349func Get() *Config {
350 if instance == nil {
351 // TODO: Handle this better
352 panic("Config not initialized. Call InitConfig first.")
353 }
354 return instance
355}
356
357func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
358 providers := Providers()
359 for _, p := range providers {
360 if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
361 return &providerConfig
362 }
363 }
364 // if none found return the first configured provider
365 for _, providerConfig := range configuredProviders {
366 if !providerConfig.Disabled {
367 return &providerConfig
368 }
369 }
370 return nil
371}
372
373func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
374 if other.APIKey != "" {
375 base.APIKey = other.APIKey
376 }
377 // Only change these options if the provider is not a known provider
378 if !slices.Contains(provider.KnownProviders(), p) {
379 if other.BaseURL != "" {
380 base.BaseURL = other.BaseURL
381 }
382 if other.ProviderType != "" {
383 base.ProviderType = other.ProviderType
384 }
385 if len(other.ExtraHeaders) > 0 {
386 if base.ExtraHeaders == nil {
387 base.ExtraHeaders = make(map[string]string)
388 }
389 maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
390 }
391 if len(other.ExtraParams) > 0 {
392 if base.ExtraParams == nil {
393 base.ExtraParams = make(map[string]string)
394 }
395 maps.Copy(base.ExtraParams, other.ExtraParams)
396 }
397 }
398
399 if other.Disabled {
400 base.Disabled = other.Disabled
401 }
402
403 if other.DefaultLargeModel != "" {
404 base.DefaultLargeModel = other.DefaultLargeModel
405 }
406 // Add new models if they don't exist
407 if other.Models != nil {
408 for _, model := range other.Models {
409 // check if the model already exists
410 exists := false
411 for _, existingModel := range base.Models {
412 if existingModel.ID == model.ID {
413 exists = true
414 break
415 }
416 }
417 if !exists {
418 base.Models = append(base.Models, model)
419 }
420 }
421 }
422
423 return base
424}
425
426func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
427 if !slices.Contains(provider.KnownProviders(), p) {
428 if providerConfig.ProviderType != provider.TypeOpenAI {
429 return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
430 }
431 if providerConfig.BaseURL == "" {
432 return errors.New("base URL must be set for custom providers")
433 }
434 if providerConfig.APIKey == "" {
435 return errors.New("API key must be set for custom providers")
436 }
437 }
438 return nil
439}
440
441func mergeModels(base, global, local *Config) {
442 for _, cfg := range []*Config{global, local} {
443 if cfg == nil {
444 continue
445 }
446 if cfg.Models.Large.ModelID != "" && cfg.Models.Large.Provider != "" {
447 base.Models.Large = cfg.Models.Large
448 }
449
450 if cfg.Models.Small.ModelID != "" && cfg.Models.Small.Provider != "" {
451 base.Models.Small = cfg.Models.Small
452 }
453 }
454}
455
456func mergeOptions(base, global, local *Config) {
457 for _, cfg := range []*Config{global, local} {
458 if cfg == nil {
459 continue
460 }
461 baseOptions := base.Options
462 other := cfg.Options
463 if len(other.ContextPaths) > 0 {
464 baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
465 }
466
467 if other.TUI.CompactMode {
468 baseOptions.TUI.CompactMode = other.TUI.CompactMode
469 }
470
471 if other.Debug {
472 baseOptions.Debug = other.Debug
473 }
474
475 if other.DebugLSP {
476 baseOptions.DebugLSP = other.DebugLSP
477 }
478
479 if other.DisableAutoSummarize {
480 baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
481 }
482
483 if other.DataDirectory != "" {
484 baseOptions.DataDirectory = other.DataDirectory
485 }
486 base.Options = baseOptions
487 }
488}
489
490func mergeAgents(base, global, local *Config) {
491 for _, cfg := range []*Config{global, local} {
492 if cfg == nil {
493 continue
494 }
495 for agentID, newAgent := range cfg.Agents {
496 if _, ok := base.Agents[agentID]; !ok {
497 // New agent - apply defaults
498 newAgent.ID = agentID // Ensure the ID is set correctly
499 if newAgent.Model == "" {
500 newAgent.Model = LargeModel // Default model type
501 }
502 // Context paths are always additive - start with global, then add custom
503 if len(newAgent.ContextPaths) > 0 {
504 newAgent.ContextPaths = append(base.Options.ContextPaths, newAgent.ContextPaths...)
505 } else {
506 newAgent.ContextPaths = base.Options.ContextPaths // Use global context paths only
507 }
508 base.Agents[agentID] = newAgent
509 } else {
510 baseAgent := base.Agents[agentID]
511
512 // Special handling for known agents - only allow model changes
513 if agentID == AgentCoder || agentID == AgentTask {
514 if newAgent.Model != "" {
515 baseAgent.Model = newAgent.Model
516 }
517 // For known agents, only allow MCP and LSP configuration
518 if newAgent.AllowedMCP != nil {
519 baseAgent.AllowedMCP = newAgent.AllowedMCP
520 }
521 if newAgent.AllowedLSP != nil {
522 baseAgent.AllowedLSP = newAgent.AllowedLSP
523 }
524 // Context paths are additive for known agents too
525 if len(newAgent.ContextPaths) > 0 {
526 baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...)
527 }
528 } else {
529 // Custom agents - allow full merging
530 if newAgent.Name != "" {
531 baseAgent.Name = newAgent.Name
532 }
533 if newAgent.Description != "" {
534 baseAgent.Description = newAgent.Description
535 }
536 if newAgent.Model != "" {
537 baseAgent.Model = newAgent.Model
538 } else if baseAgent.Model == "" {
539 baseAgent.Model = LargeModel // Default fallback
540 }
541
542 // Boolean fields - always update (including false values)
543 baseAgent.Disabled = newAgent.Disabled
544
545 // Slice/Map fields - update if provided (including empty slices/maps)
546 if newAgent.AllowedTools != nil {
547 baseAgent.AllowedTools = newAgent.AllowedTools
548 }
549 if newAgent.AllowedMCP != nil {
550 baseAgent.AllowedMCP = newAgent.AllowedMCP
551 }
552 if newAgent.AllowedLSP != nil {
553 baseAgent.AllowedLSP = newAgent.AllowedLSP
554 }
555 // Context paths are additive for custom agents too
556 if len(newAgent.ContextPaths) > 0 {
557 baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...)
558 }
559 }
560
561 base.Agents[agentID] = baseAgent
562 }
563 }
564 }
565}
566
567func mergeMCPs(base, global, local *Config) {
568 for _, cfg := range []*Config{global, local} {
569 if cfg == nil {
570 continue
571 }
572 maps.Copy(base.MCP, cfg.MCP)
573 }
574}
575
576func mergeLSPs(base, global, local *Config) {
577 for _, cfg := range []*Config{global, local} {
578 if cfg == nil {
579 continue
580 }
581 maps.Copy(base.LSP, cfg.LSP)
582 }
583}
584
585func mergeProviderConfigs(base, global, local *Config) {
586 for _, cfg := range []*Config{global, local} {
587 if cfg == nil {
588 continue
589 }
590 for providerName, globalProvider := range cfg.Providers {
591 if _, ok := base.Providers[providerName]; !ok {
592 base.Providers[providerName] = globalProvider
593 } else {
594 base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider)
595 }
596 }
597 }
598
599 finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
600 for providerName, providerConfig := range base.Providers {
601 err := validateProvider(providerName, providerConfig)
602 if err != nil {
603 logging.Warn("Skipping provider", "name", providerName, "error", err)
604 continue // Skip invalid providers
605 }
606 finalProviders[providerName] = providerConfig
607 }
608 base.Providers = finalProviders
609}
610
611func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig {
612 switch providerId {
613 case provider.InferenceProviderAnthropic:
614 return ProviderConfig{
615 ID: providerId,
616 ProviderType: provider.TypeAnthropic,
617 }
618 case provider.InferenceProviderOpenAI:
619 return ProviderConfig{
620 ID: providerId,
621 ProviderType: provider.TypeOpenAI,
622 }
623 case provider.InferenceProviderGemini:
624 return ProviderConfig{
625 ID: providerId,
626 ProviderType: provider.TypeGemini,
627 }
628 case provider.InferenceProviderBedrock:
629 return ProviderConfig{
630 ID: providerId,
631 ProviderType: provider.TypeBedrock,
632 }
633 case provider.InferenceProviderAzure:
634 return ProviderConfig{
635 ID: providerId,
636 ProviderType: provider.TypeAzure,
637 }
638 case provider.InferenceProviderOpenRouter:
639 return ProviderConfig{
640 ID: providerId,
641 ProviderType: provider.TypeOpenAI,
642 BaseURL: "https://openrouter.ai/api/v1",
643 ExtraHeaders: map[string]string{
644 "HTTP-Referer": "crush.charm.land",
645 "X-Title": "Crush",
646 },
647 }
648 case provider.InferenceProviderXAI:
649 return ProviderConfig{
650 ID: providerId,
651 ProviderType: provider.TypeXAI,
652 BaseURL: "https://api.x.ai/v1",
653 }
654 case provider.InferenceProviderVertexAI:
655 return ProviderConfig{
656 ID: providerId,
657 ProviderType: provider.TypeVertexAI,
658 }
659 default:
660 return ProviderConfig{
661 ID: providerId,
662 ProviderType: provider.TypeOpenAI,
663 }
664 }
665}
666
667func defaultConfigBasedOnEnv() *Config {
668 cfg := &Config{
669 Options: Options{
670 DataDirectory: defaultDataDirectory,
671 ContextPaths: defaultContextPaths,
672 },
673 Providers: make(map[provider.InferenceProvider]ProviderConfig),
674 Agents: make(map[AgentID]Agent),
675 LSP: make(map[string]LSPConfig),
676 MCP: make(map[string]MCP),
677 }
678
679 providers := Providers()
680
681 for _, p := range providers {
682 if strings.HasPrefix(p.APIKey, "$") {
683 envVar := strings.TrimPrefix(p.APIKey, "$")
684 if apiKey := os.Getenv(envVar); apiKey != "" {
685 providerConfig := providerDefaultConfig(p.ID)
686 providerConfig.APIKey = apiKey
687 providerConfig.DefaultLargeModel = p.DefaultLargeModelID
688 providerConfig.DefaultSmallModel = p.DefaultSmallModelID
689 baseURL := p.APIEndpoint
690 if strings.HasPrefix(baseURL, "$") {
691 envVar := strings.TrimPrefix(baseURL, "$")
692 baseURL = os.Getenv(envVar)
693 }
694 providerConfig.BaseURL = baseURL
695 for _, model := range p.Models {
696 providerConfig.Models = append(providerConfig.Models, Model{
697 ID: model.ID,
698 Name: model.Name,
699 CostPer1MIn: model.CostPer1MIn,
700 CostPer1MOut: model.CostPer1MOut,
701 CostPer1MInCached: model.CostPer1MInCached,
702 CostPer1MOutCached: model.CostPer1MOutCached,
703 ContextWindow: model.ContextWindow,
704 DefaultMaxTokens: model.DefaultMaxTokens,
705 CanReason: model.CanReason,
706 SupportsImages: model.SupportsImages,
707 })
708 }
709 cfg.Providers[p.ID] = providerConfig
710 }
711 }
712 }
713 // TODO: support local models
714
715 if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
716 providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
717 providerConfig.ExtraParams = map[string]string{
718 "project": os.Getenv("GOOGLE_CLOUD_PROJECT"),
719 "location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
720 }
721 cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
722 }
723
724 if hasAWSCredentials() {
725 providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
726 providerConfig.ExtraParams = map[string]string{
727 "region": os.Getenv("AWS_DEFAULT_REGION"),
728 }
729 if providerConfig.ExtraParams["region"] == "" {
730 providerConfig.ExtraParams["region"] = os.Getenv("AWS_REGION")
731 }
732 cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
733 }
734 return cfg
735}
736
// hasAWSCredentials reports whether the environment looks configured for AWS.
// Any one of the following is enough:
//   - a static key pair (AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY);
//   - a named profile;
//   - a region (a region alone also counts);
//   - a container credentials endpoint.
func hasAWSCredentials() bool {
	isSet := func(name string) bool { return os.Getenv(name) != "" }

	switch {
	case isSet("AWS_ACCESS_KEY_ID") && isSet("AWS_SECRET_ACCESS_KEY"):
		return true
	case isSet("AWS_PROFILE") || isSet("AWS_DEFAULT_PROFILE"):
		return true
	case isSet("AWS_REGION") || isSet("AWS_DEFAULT_REGION"):
		return true
	case isSet("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") || isSet("AWS_CONTAINER_CREDENTIALS_FULL_URI"):
		return true
	default:
		return false
	}
}
757
758func WorkingDirectory() string {
759 return cwd
760}
761
762// TODO: Handle error state
763
764func GetAgentModel(agentID AgentID) Model {
765 cfg := Get()
766 agent, ok := cfg.Agents[agentID]
767 if !ok {
768 logging.Error("Agent not found", "agent_id", agentID)
769 return Model{}
770 }
771
772 var model PreferredModel
773 switch agent.Model {
774 case LargeModel:
775 model = cfg.Models.Large
776 case SmallModel:
777 model = cfg.Models.Small
778 default:
779 logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
780 model = cfg.Models.Large // Fallback to large model
781 }
782 providerConfig, ok := cfg.Providers[model.Provider]
783 if !ok {
784 logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
785 return Model{}
786 }
787
788 for _, m := range providerConfig.Models {
789 if m.ID == model.ModelID {
790 return m
791 }
792 }
793
794 logging.Error("Model not found for agent", "agent_id", agentID, "model", agent.Model)
795 return Model{}
796}
797
798func GetAgentProvider(agentID AgentID) ProviderConfig {
799 cfg := Get()
800 agent, ok := cfg.Agents[agentID]
801 if !ok {
802 logging.Error("Agent not found", "agent_id", agentID)
803 return ProviderConfig{}
804 }
805
806 var model PreferredModel
807 switch agent.Model {
808 case LargeModel:
809 model = cfg.Models.Large
810 case SmallModel:
811 model = cfg.Models.Small
812 default:
813 logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
814 model = cfg.Models.Large // Fallback to large model
815 }
816
817 providerConfig, ok := cfg.Providers[model.Provider]
818 if !ok {
819 logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
820 return ProviderConfig{}
821 }
822
823 return providerConfig
824}
825
826func GetProviderModel(provider provider.InferenceProvider, modelID string) Model {
827 cfg := Get()
828 providerConfig, ok := cfg.Providers[provider]
829 if !ok {
830 logging.Error("Provider not found", "provider", provider)
831 return Model{}
832 }
833
834 for _, model := range providerConfig.Models {
835 if model.ID == modelID {
836 return model
837 }
838 }
839
840 logging.Error("Model not found for provider", "provider", provider, "model_id", modelID)
841 return Model{}
842}
843
844func GetModel(modelType ModelType) Model {
845 cfg := Get()
846 var model PreferredModel
847 switch modelType {
848 case LargeModel:
849 model = cfg.Models.Large
850 case SmallModel:
851 model = cfg.Models.Small
852 default:
853 model = cfg.Models.Large // Fallback to large model
854 }
855 providerConfig, ok := cfg.Providers[model.Provider]
856 if !ok {
857 return Model{}
858 }
859
860 for _, m := range providerConfig.Models {
861 if m.ID == model.ModelID {
862 return m
863 }
864 }
865 return Model{}
866}
867
868func UpdatePreferredModel(modelType ModelType, model PreferredModel) error {
869 cfg := Get()
870 switch modelType {
871 case LargeModel:
872 cfg.Models.Large = model
873 case SmallModel:
874 cfg.Models.Small = model
875 default:
876 return fmt.Errorf("unknown model type: %s", modelType)
877 }
878 return nil
879}
880
// ValidationError describes one invalid configuration field.
type ValidationError struct {
	Field   string // dotted path of the offending field, e.g. "providers.<id>.api_key"
	Message string // human-readable description of the problem
}

// Error formats the error as "validation error in <field>: <message>".
func (ve ValidationError) Error() string {
	return fmt.Sprintf("validation error in %s: %s", ve.Field, ve.Message)
}

// ValidationErrors collects every problem found while validating a Config.
type ValidationErrors []ValidationError

// Error returns a single message. With no entries it is a placeholder and
// with exactly one it is that entry's text. Otherwise all entries are joined
// with "; " after a "multiple validation errors: " prefix.
func (errs ValidationErrors) Error() string {
	switch len(errs) {
	case 0:
		return "no validation errors"
	case 1:
		return errs[0].Error()
	}

	msgs := make([]string, 0, len(errs))
	for _, ve := range errs {
		msgs = append(msgs, ve.Error())
	}
	return "multiple validation errors: " + strings.Join(msgs, "; ")
}

// HasErrors reports whether at least one error has been recorded.
func (errs ValidationErrors) HasErrors() bool {
	return len(errs) != 0
}

// Add records a new error for field.
func (errs *ValidationErrors) Add(field, message string) {
	*errs = append(*errs, ValidationError{Field: field, Message: message})
}
918
919// Validate performs comprehensive validation of the configuration
920func (c *Config) Validate() error {
921 var errors ValidationErrors
922
923 // Validate providers
924 c.validateProviders(&errors)
925
926 // Validate models
927 c.validateModels(&errors)
928
929 // Validate agents
930 c.validateAgents(&errors)
931
932 // Validate options
933 c.validateOptions(&errors)
934
935 // Validate MCP configurations
936 c.validateMCPs(&errors)
937
938 // Validate LSP configurations
939 c.validateLSPs(&errors)
940
941 // Validate cross-references
942 c.validateCrossReferences(&errors)
943
944 // Validate completeness
945 c.validateCompleteness(&errors)
946
947 if errors.HasErrors() {
948 return errors
949 }
950
951 return nil
952}
953
954// validateProviders validates all provider configurations
955func (c *Config) validateProviders(errors *ValidationErrors) {
956 if c.Providers == nil {
957 c.Providers = make(map[provider.InferenceProvider]ProviderConfig)
958 }
959
960 knownProviders := provider.KnownProviders()
961 validTypes := []provider.Type{
962 provider.TypeOpenAI,
963 provider.TypeAnthropic,
964 provider.TypeGemini,
965 provider.TypeAzure,
966 provider.TypeBedrock,
967 provider.TypeVertexAI,
968 provider.TypeXAI,
969 }
970
971 for providerID, providerConfig := range c.Providers {
972 fieldPrefix := fmt.Sprintf("providers.%s", providerID)
973
974 // Validate API key for non-disabled providers
975 if !providerConfig.Disabled && providerConfig.APIKey == "" {
976 // Special case for AWS Bedrock and VertexAI which may use other auth methods
977 if providerID != provider.InferenceProviderBedrock && providerID != provider.InferenceProviderVertexAI {
978 errors.Add(fieldPrefix+".api_key", "API key is required for non-disabled providers")
979 }
980 }
981
982 // Validate provider type
983 validType := false
984 for _, vt := range validTypes {
985 if providerConfig.ProviderType == vt {
986 validType = true
987 break
988 }
989 }
990 if !validType {
991 errors.Add(fieldPrefix+".provider_type", fmt.Sprintf("invalid provider type: %s", providerConfig.ProviderType))
992 }
993
994 // Validate custom providers
995 isKnownProvider := false
996 for _, kp := range knownProviders {
997 if providerID == kp {
998 isKnownProvider = true
999 break
1000 }
1001 }
1002
1003 if !isKnownProvider {
1004 // Custom provider validation
1005 if providerConfig.BaseURL == "" {
1006 errors.Add(fieldPrefix+".base_url", "BaseURL is required for custom providers")
1007 }
1008 if providerConfig.ProviderType != provider.TypeOpenAI {
1009 errors.Add(fieldPrefix+".provider_type", "custom providers currently only support OpenAI type")
1010 }
1011 }
1012
1013 // Validate models
1014 modelIDs := make(map[string]bool)
1015 for i, model := range providerConfig.Models {
1016 modelFieldPrefix := fmt.Sprintf("%s.models[%d]", fieldPrefix, i)
1017
1018 // Check for duplicate model IDs
1019 if modelIDs[model.ID] {
1020 errors.Add(modelFieldPrefix+".id", fmt.Sprintf("duplicate model ID: %s", model.ID))
1021 }
1022 modelIDs[model.ID] = true
1023
1024 // Validate required model fields
1025 if model.ID == "" {
1026 errors.Add(modelFieldPrefix+".id", "model ID is required")
1027 }
1028 if model.Name == "" {
1029 errors.Add(modelFieldPrefix+".name", "model name is required")
1030 }
1031 if model.ContextWindow <= 0 {
1032 errors.Add(modelFieldPrefix+".context_window", "context window must be positive")
1033 }
1034 if model.DefaultMaxTokens <= 0 {
1035 errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens must be positive")
1036 }
1037 if model.DefaultMaxTokens > model.ContextWindow {
1038 errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens cannot exceed context window")
1039 }
1040
1041 // Validate cost fields
1042 if model.CostPer1MIn < 0 {
1043 errors.Add(modelFieldPrefix+".cost_per_1m_in", "cost per 1M input tokens cannot be negative")
1044 }
1045 if model.CostPer1MOut < 0 {
1046 errors.Add(modelFieldPrefix+".cost_per_1m_out", "cost per 1M output tokens cannot be negative")
1047 }
1048 if model.CostPer1MInCached < 0 {
1049 errors.Add(modelFieldPrefix+".cost_per_1m_in_cached", "cached cost per 1M input tokens cannot be negative")
1050 }
1051 if model.CostPer1MOutCached < 0 {
1052 errors.Add(modelFieldPrefix+".cost_per_1m_out_cached", "cached cost per 1M output tokens cannot be negative")
1053 }
1054 }
1055
1056 // Validate default model references
1057 if providerConfig.DefaultLargeModel != "" {
1058 if !modelIDs[providerConfig.DefaultLargeModel] {
1059 errors.Add(fieldPrefix+".default_large_model", fmt.Sprintf("default large model '%s' not found in provider models", providerConfig.DefaultLargeModel))
1060 }
1061 }
1062 if providerConfig.DefaultSmallModel != "" {
1063 if !modelIDs[providerConfig.DefaultSmallModel] {
1064 errors.Add(fieldPrefix+".default_small_model", fmt.Sprintf("default small model '%s' not found in provider models", providerConfig.DefaultSmallModel))
1065 }
1066 }
1067
1068 // Validate provider-specific requirements
1069 c.validateProviderSpecific(providerID, providerConfig, errors)
1070 }
1071}
1072
1073// validateProviderSpecific validates provider-specific requirements
1074func (c *Config) validateProviderSpecific(providerID provider.InferenceProvider, providerConfig ProviderConfig, errors *ValidationErrors) {
1075 fieldPrefix := fmt.Sprintf("providers.%s", providerID)
1076
1077 switch providerID {
1078 case provider.InferenceProviderVertexAI:
1079 if !providerConfig.Disabled {
1080 if providerConfig.ExtraParams == nil {
1081 errors.Add(fieldPrefix+".extra_params", "VertexAI requires extra_params configuration")
1082 } else {
1083 if providerConfig.ExtraParams["project"] == "" {
1084 errors.Add(fieldPrefix+".extra_params.project", "VertexAI requires project parameter")
1085 }
1086 if providerConfig.ExtraParams["location"] == "" {
1087 errors.Add(fieldPrefix+".extra_params.location", "VertexAI requires location parameter")
1088 }
1089 }
1090 }
1091 case provider.InferenceProviderBedrock:
1092 if !providerConfig.Disabled {
1093 if providerConfig.ExtraParams == nil || providerConfig.ExtraParams["region"] == "" {
1094 errors.Add(fieldPrefix+".extra_params.region", "Bedrock requires region parameter")
1095 }
1096 // Check for AWS credentials in environment
1097 if !hasAWSCredentials() {
1098 errors.Add(fieldPrefix, "Bedrock requires AWS credentials in environment")
1099 }
1100 }
1101 }
1102}
1103
1104// validateModels validates preferred model configurations
1105func (c *Config) validateModels(errors *ValidationErrors) {
1106 // Validate large model
1107 if c.Models.Large.ModelID != "" || c.Models.Large.Provider != "" {
1108 if c.Models.Large.ModelID == "" {
1109 errors.Add("models.large.model_id", "large model ID is required when provider is set")
1110 }
1111 if c.Models.Large.Provider == "" {
1112 errors.Add("models.large.provider", "large model provider is required when model ID is set")
1113 }
1114
1115 // Check if provider exists and is not disabled
1116 if providerConfig, exists := c.Providers[c.Models.Large.Provider]; exists {
1117 if providerConfig.Disabled {
1118 errors.Add("models.large.provider", "large model provider is disabled")
1119 }
1120
1121 // Check if model exists in provider
1122 modelExists := false
1123 for _, model := range providerConfig.Models {
1124 if model.ID == c.Models.Large.ModelID {
1125 modelExists = true
1126 break
1127 }
1128 }
1129 if !modelExists {
1130 errors.Add("models.large.model_id", fmt.Sprintf("large model '%s' not found in provider '%s'", c.Models.Large.ModelID, c.Models.Large.Provider))
1131 }
1132 } else {
1133 errors.Add("models.large.provider", fmt.Sprintf("large model provider '%s' not found", c.Models.Large.Provider))
1134 }
1135 }
1136
1137 // Validate small model
1138 if c.Models.Small.ModelID != "" || c.Models.Small.Provider != "" {
1139 if c.Models.Small.ModelID == "" {
1140 errors.Add("models.small.model_id", "small model ID is required when provider is set")
1141 }
1142 if c.Models.Small.Provider == "" {
1143 errors.Add("models.small.provider", "small model provider is required when model ID is set")
1144 }
1145
1146 // Check if provider exists and is not disabled
1147 if providerConfig, exists := c.Providers[c.Models.Small.Provider]; exists {
1148 if providerConfig.Disabled {
1149 errors.Add("models.small.provider", "small model provider is disabled")
1150 }
1151
1152 // Check if model exists in provider
1153 modelExists := false
1154 for _, model := range providerConfig.Models {
1155 if model.ID == c.Models.Small.ModelID {
1156 modelExists = true
1157 break
1158 }
1159 }
1160 if !modelExists {
1161 errors.Add("models.small.model_id", fmt.Sprintf("small model '%s' not found in provider '%s'", c.Models.Small.ModelID, c.Models.Small.Provider))
1162 }
1163 } else {
1164 errors.Add("models.small.provider", fmt.Sprintf("small model provider '%s' not found", c.Models.Small.Provider))
1165 }
1166 }
1167}
1168
1169// validateAgents validates agent configurations
1170func (c *Config) validateAgents(errors *ValidationErrors) {
1171 if c.Agents == nil {
1172 c.Agents = make(map[AgentID]Agent)
1173 }
1174
1175 validTools := []string{
1176 "bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent",
1177 }
1178
1179 for agentID, agent := range c.Agents {
1180 fieldPrefix := fmt.Sprintf("agents.%s", agentID)
1181
1182 // Validate agent ID consistency
1183 if agent.ID != agentID {
1184 errors.Add(fieldPrefix+".id", fmt.Sprintf("agent ID mismatch: expected '%s', got '%s'", agentID, agent.ID))
1185 }
1186
1187 // Validate required fields
1188 if agent.ID == "" {
1189 errors.Add(fieldPrefix+".id", "agent ID is required")
1190 }
1191 if agent.Name == "" {
1192 errors.Add(fieldPrefix+".name", "agent name is required")
1193 }
1194
1195 // Validate model type
1196 if agent.Model != LargeModel && agent.Model != SmallModel {
1197 errors.Add(fieldPrefix+".model", fmt.Sprintf("invalid model type: %s (must be 'large' or 'small')", agent.Model))
1198 }
1199
1200 // Validate allowed tools
1201 if agent.AllowedTools != nil {
1202 for i, tool := range agent.AllowedTools {
1203 validTool := false
1204 for _, vt := range validTools {
1205 if tool == vt {
1206 validTool = true
1207 break
1208 }
1209 }
1210 if !validTool {
1211 errors.Add(fmt.Sprintf("%s.allowed_tools[%d]", fieldPrefix, i), fmt.Sprintf("unknown tool: %s", tool))
1212 }
1213 }
1214 }
1215
1216 // Validate MCP references
1217 if agent.AllowedMCP != nil {
1218 for mcpName := range agent.AllowedMCP {
1219 if _, exists := c.MCP[mcpName]; !exists {
1220 errors.Add(fieldPrefix+".allowed_mcp", fmt.Sprintf("referenced MCP '%s' not found", mcpName))
1221 }
1222 }
1223 }
1224
1225 // Validate LSP references
1226 if agent.AllowedLSP != nil {
1227 for _, lspName := range agent.AllowedLSP {
1228 if _, exists := c.LSP[lspName]; !exists {
1229 errors.Add(fieldPrefix+".allowed_lsp", fmt.Sprintf("referenced LSP '%s' not found", lspName))
1230 }
1231 }
1232 }
1233
1234 // Validate context paths (basic path validation)
1235 for i, contextPath := range agent.ContextPaths {
1236 if contextPath == "" {
1237 errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path cannot be empty")
1238 }
1239 // Check for invalid characters in path
1240 if strings.Contains(contextPath, "\x00") {
1241 errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path contains invalid characters")
1242 }
1243 }
1244
1245 // Validate known agents maintain their core properties
1246 if agentID == AgentCoder {
1247 if agent.Name != "Coder" {
1248 errors.Add(fieldPrefix+".name", "coder agent name cannot be changed")
1249 }
1250 if agent.Description != "An agent that helps with executing coding tasks." {
1251 errors.Add(fieldPrefix+".description", "coder agent description cannot be changed")
1252 }
1253 } else if agentID == AgentTask {
1254 if agent.Name != "Task" {
1255 errors.Add(fieldPrefix+".name", "task agent name cannot be changed")
1256 }
1257 if agent.Description != "An agent that helps with searching for context and finding implementation details." {
1258 errors.Add(fieldPrefix+".description", "task agent description cannot be changed")
1259 }
1260 expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"}
1261 if agent.AllowedTools != nil && !slices.Equal(agent.AllowedTools, expectedTools) {
1262 errors.Add(fieldPrefix+".allowed_tools", "task agent allowed tools cannot be changed")
1263 }
1264 }
1265 }
1266}
1267
1268// validateOptions validates configuration options
1269func (c *Config) validateOptions(errors *ValidationErrors) {
1270 // Validate data directory
1271 if c.Options.DataDirectory == "" {
1272 errors.Add("options.data_directory", "data directory is required")
1273 }
1274
1275 // Validate context paths
1276 for i, contextPath := range c.Options.ContextPaths {
1277 if contextPath == "" {
1278 errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path cannot be empty")
1279 }
1280 if strings.Contains(contextPath, "\x00") {
1281 errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path contains invalid characters")
1282 }
1283 }
1284}
1285
1286// validateMCPs validates MCP configurations
1287func (c *Config) validateMCPs(errors *ValidationErrors) {
1288 if c.MCP == nil {
1289 c.MCP = make(map[string]MCP)
1290 }
1291
1292 for mcpName, mcpConfig := range c.MCP {
1293 fieldPrefix := fmt.Sprintf("mcp.%s", mcpName)
1294
1295 // Validate MCP type
1296 if mcpConfig.Type != MCPStdio && mcpConfig.Type != MCPSse {
1297 errors.Add(fieldPrefix+".type", fmt.Sprintf("invalid MCP type: %s (must be 'stdio' or 'sse')", mcpConfig.Type))
1298 }
1299
1300 // Validate based on type
1301 if mcpConfig.Type == MCPStdio {
1302 if mcpConfig.Command == "" {
1303 errors.Add(fieldPrefix+".command", "command is required for stdio MCP")
1304 }
1305 } else if mcpConfig.Type == MCPSse {
1306 if mcpConfig.URL == "" {
1307 errors.Add(fieldPrefix+".url", "URL is required for SSE MCP")
1308 }
1309 }
1310 }
1311}
1312
1313// validateLSPs validates LSP configurations
1314func (c *Config) validateLSPs(errors *ValidationErrors) {
1315 if c.LSP == nil {
1316 c.LSP = make(map[string]LSPConfig)
1317 }
1318
1319 for lspName, lspConfig := range c.LSP {
1320 fieldPrefix := fmt.Sprintf("lsp.%s", lspName)
1321
1322 if lspConfig.Command == "" {
1323 errors.Add(fieldPrefix+".command", "command is required for LSP")
1324 }
1325 }
1326}
1327
1328// validateCrossReferences validates cross-references between different config sections
1329func (c *Config) validateCrossReferences(errors *ValidationErrors) {
1330 // Validate that agents can use their assigned model types
1331 for agentID, agent := range c.Agents {
1332 fieldPrefix := fmt.Sprintf("agents.%s", agentID)
1333
1334 var preferredModel PreferredModel
1335 switch agent.Model {
1336 case LargeModel:
1337 preferredModel = c.Models.Large
1338 case SmallModel:
1339 preferredModel = c.Models.Small
1340 }
1341
1342 if preferredModel.Provider != "" {
1343 if providerConfig, exists := c.Providers[preferredModel.Provider]; exists {
1344 if providerConfig.Disabled {
1345 errors.Add(fieldPrefix+".model", fmt.Sprintf("agent cannot use model type '%s' because provider '%s' is disabled", agent.Model, preferredModel.Provider))
1346 }
1347 }
1348 }
1349 }
1350}
1351
1352// validateCompleteness validates that the configuration is complete and usable
1353func (c *Config) validateCompleteness(errors *ValidationErrors) {
1354 // Check for at least one valid, non-disabled provider
1355 hasValidProvider := false
1356 for _, providerConfig := range c.Providers {
1357 if !providerConfig.Disabled {
1358 hasValidProvider = true
1359 break
1360 }
1361 }
1362 if !hasValidProvider {
1363 errors.Add("providers", "at least one non-disabled provider is required")
1364 }
1365
1366 // Check that default agents exist
1367 if _, exists := c.Agents[AgentCoder]; !exists {
1368 errors.Add("agents", "coder agent is required")
1369 }
1370 if _, exists := c.Agents[AgentTask]; !exists {
1371 errors.Add("agents", "task agent is required")
1372 }
1373
1374 // Check that preferred models are set if providers exist
1375 if hasValidProvider {
1376 if c.Models.Large.ModelID == "" || c.Models.Large.Provider == "" {
1377 errors.Add("models.large", "large preferred model must be configured when providers are available")
1378 }
1379 if c.Models.Small.ModelID == "" || c.Models.Small.Provider == "" {
1380 errors.Add("models.small", "small preferred model must be configured when providers are available")
1381 }
1382 }
1383}