1package config
2
3import (
4 "encoding/json"
5 "errors"
6 "fmt"
7 "log/slog"
8 "maps"
9 "os"
10 "path/filepath"
11 "slices"
12 "strings"
13 "sync"
14
15 "github.com/charmbracelet/crush/internal/fur/provider"
16 "github.com/charmbracelet/crush/internal/logging"
17 "github.com/invopop/jsonschema"
18)
19
const (
	// defaultDataDirectory is the data directory used unless a config file sets
	// options.data_directory.
	defaultDataDirectory = ".crush"
	defaultLogLevel      = "info"
	appName              = "crush"

	// MaxTokensFallbackDefault is a fallback maximum-token value.
	// NOTE(review): not referenced in this part of the file; presumably used
	// when no max-tokens value is available — confirm at the call sites.
	MaxTokensFallbackDefault = 4096
)

// defaultContextPaths is the initial list of context paths (files and
// directories, given as relative paths) placed in Options.ContextPaths by
// defaultConfigBasedOnEnv. Config files append to it via options.context_paths.
// This is shared package-level state: callers must not modify it in place.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
43
// AgentID identifies an agent; it is the key type of Config.Agents.
type AgentID string

const (
	// AgentCoder is the built-in coding agent; loadConfig gives it access to all tools.
	AgentCoder AgentID = "coder"
	// AgentTask is the built-in search agent; loadConfig restricts it to a
	// small set of read/search tools and no MCPs or LSPs by default.
	AgentTask AgentID = "task"
)

// ModelType selects which preferred model (Config.Models.Large or
// Config.Models.Small) an agent uses.
type ModelType string

const (
	LargeModel ModelType = "large"
	SmallModel ModelType = "small"
)
57
// Model describes one model offered by a provider (an entry in
// ProviderConfig.Models).
//
// The jsonschema tags are comma-separated key=value lists, so description
// text must not contain a literal comma or it is cut short.
type Model struct {
	// ID is the model identifier used with the provider API. GetAgentModel
	// matches it against PreferredModel.ModelID.
	ID   string `json:"id" jsonschema:"title=Model ID,description=Unique identifier for the model as used by the provider API"`
	Name string `json:"name" jsonschema:"title=Model Name,description=Display name of the model"`

	// Prices per one million tokens.
	CostPer1MIn        float64 `json:"cost_per_1m_in,omitempty" jsonschema:"title=Input Cost,description=Cost per 1 million input tokens,minimum=0"`
	CostPer1MOut       float64 `json:"cost_per_1m_out,omitempty" jsonschema:"title=Output Cost,description=Cost per 1 million output tokens,minimum=0"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached,omitempty" jsonschema:"title=Cached Input Cost,description=Cost per 1 million cached input tokens,minimum=0"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached,omitempty" jsonschema:"title=Cached Output Cost,description=Cost per 1 million cached output tokens,minimum=0"`

	ContextWindow int64 `json:"context_window" jsonschema:"title=Context Window,description=Maximum context window size in tokens,minimum=1"`
	// DefaultMaxTokens is used unless PreferredModel.MaxTokens overrides it.
	DefaultMaxTokens   int64  `json:"default_max_tokens" jsonschema:"title=Default Max Tokens,description=Default maximum tokens for responses,minimum=1"`
	CanReason          bool   `json:"can_reason,omitempty" jsonschema:"title=Can Reason,description=Whether the model supports reasoning capabilities"`
	ReasoningEffort    string `json:"reasoning_effort,omitempty" jsonschema:"title=Reasoning Effort,description=Default reasoning effort level for reasoning models"`
	HasReasoningEffort bool   `json:"has_reasoning_effort,omitempty" jsonschema:"title=Has Reasoning Effort,description=Whether the model supports reasoning effort configuration"`
	// SupportsImages is serialized under the "supports_attachments" key.
	SupportsImages bool `json:"supports_attachments,omitempty" jsonschema:"title=Supports Images,description=Whether the model supports image attachments"`
}

// VertexAIOptions holds Vertex AI connection settings.
// NOTE(review): not referenced in this part of the file; Vertex project and
// location are currently carried in ProviderConfig.ExtraParams instead.
type VertexAIOptions struct {
	APIKey   string `json:"api_key,omitempty"`
	Project  string `json:"project,omitempty"`
	Location string `json:"location,omitempty"`
}
78
79type ProviderConfig struct {
80 ID provider.InferenceProvider `json:"id,omitempty" jsonschema:"title=Provider ID,description=Unique identifier for the provider"`
81 BaseURL string `json:"base_url,omitempty" jsonschema:"title=Base URL,description=Base URL for the provider API (required for custom providers)"`
82 ProviderType provider.Type `json:"provider_type" jsonschema:"title=Provider Type,description=Type of the provider (openai, anthropic, etc.)"`
83 APIKey string `json:"api_key,omitempty" jsonschema:"title=API Key,description=API key for authenticating with the provider"`
84 Disabled bool `json:"disabled,omitempty" jsonschema:"title=Disabled,description=Whether this provider is disabled,default=false"`
85 ExtraHeaders map[string]string `json:"extra_headers,omitempty" jsonschema:"title=Extra Headers,description=Additional HTTP headers to send with requests"`
86 // used for e.x for vertex to set the project
87 ExtraParams map[string]string `json:"extra_params,omitempty" jsonschema:"title=Extra Parameters,description=Additional provider-specific parameters"`
88
89 DefaultLargeModel string `json:"default_large_model,omitempty" jsonschema:"title=Default Large Model,description=Default model ID for large model type"`
90 DefaultSmallModel string `json:"default_small_model,omitempty" jsonschema:"title=Default Small Model,description=Default model ID for small model type"`
91
92 Models []Model `json:"models,omitempty" jsonschema:"title=Models,description=List of available models for this provider"`
93}
94
// Agent is the configuration of one agent (a value in Config.Agents).
// loadConfig defines the built-in AgentCoder and AgentTask; config files may
// tweak those or add custom agents (see mergeAgents).
type Agent struct {
	ID          AgentID `json:"id,omitempty" jsonschema:"title=Agent ID,description=Unique identifier for the agent,enum=coder,enum=task"`
	Name        string  `json:"name,omitempty" jsonschema:"title=Name,description=Display name of the agent"`
	Description string  `json:"description,omitempty" jsonschema:"title=Description,description=Description of what the agent does"`
	// Disabled marks the agent as disabled. mergeAgents only applies it to
	// custom agents; it is not merged for the built-in coder/task agents.
	Disabled bool `json:"disabled,omitempty" jsonschema:"title=Disabled,description=Whether this agent is disabled,default=false"`

	// Model picks Config.Models.Large or Config.Models.Small (see GetAgentModel).
	Model ModelType `json:"model" jsonschema:"title=Model Type,description=Type of model to use (large or small),enum=large,enum=small"`

	// The available tools for the agent
	// if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools,omitempty" jsonschema:"title=Allowed Tools,description=List of tools this agent is allowed to use (if nil all tools are allowed)"`

	// AllowedMCP maps MCP server names to the tools of that server the agent
	// may use; a nil tool list means all tools of that server.
	// NOTE(review): this comment used to say an empty map allows all MCPs, but
	// loadConfig gives the task agent an empty, non-nil map labelled "NO MCPs".
	// Presumably nil = all and empty = none — confirm with the consumer.
	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty" jsonschema:"title=Allowed MCP,description=Map of MCP servers this agent can use and their allowed tools"`

	// The list of LSPs that this agent can use
	// if this is nil, all LSPs are available
	AllowedLSP []string `json:"allowed_lsp,omitempty" jsonschema:"title=Allowed LSP,description=List of LSP servers this agent can use (if nil all LSPs are allowed)"`

	// ContextPaths are extra context paths for this agent. They do not replace
	// the global Options.ContextPaths: mergeAgents appends them to it.
	ContextPaths []string `json:"context_paths,omitempty" jsonschema:"title=Context Paths,description=Custom context paths for this agent (additive to global context paths)"`
}
121
// MCPType selects how an MCP server is reached.
type MCPType string

const (
	// MCPStdio servers are started from MCP.Command and reached over stdio.
	MCPStdio MCPType = "stdio"
	// MCPSse servers are reached at MCP.URL over server-sent events.
	MCPSse MCPType = "sse"
)

// MCP configures one MCP server (a value in Config.MCP). Command, Args and Env
// apply to stdio servers; URL and Headers apply to SSE servers.
type MCP struct {
	Command string `json:"command" jsonschema:"title=Command,description=Command to execute for stdio MCP servers"`
	// Env is passed to the server process; presumably KEY=VALUE entries — confirm with the launcher.
	Env  []string `json:"env,omitempty" jsonschema:"title=Environment,description=Environment variables for the MCP server"`
	Args []string `json:"args,omitempty" jsonschema:"title=Arguments,description=Command line arguments for the MCP server"`
	Type MCPType  `json:"type" jsonschema:"title=Type,description=Type of MCP connection,enum=stdio,enum=sse,default=stdio"`
	URL  string   `json:"url,omitempty" jsonschema:"title=URL,description=URL for SSE MCP servers"`
	// TODO: maybe make it possible to get the value from the env
	Headers map[string]string `json:"headers,omitempty" jsonschema:"title=Headers,description=HTTP headers for SSE MCP servers"`
}
138
// LSPConfig configures one language server (a value in Config.LSP).
type LSPConfig struct {
	// Disabled turns the server off. It is read from the "disabled" key so
	// that the JSON name matches the field's meaning. The previous "enabled"
	// key inverted it: "enabled": true disabled the server.
	Disabled bool     `json:"disabled,omitempty" jsonschema:"title=Disabled,description=Whether this LSP server is disabled,default=false"`
	Command  string   `json:"command" jsonschema:"title=Command,description=Command to execute for the LSP server"`
	Args     []string `json:"args,omitempty" jsonschema:"title=Arguments,description=Command line arguments for the LSP server"`
	// Options is passed through as-is; its shape depends on the server.
	Options any `json:"options,omitempty" jsonschema:"title=Options,description=LSP server specific options"`
}
145
// TUIOptions holds terminal UI settings.
type TUIOptions struct {
	CompactMode bool `json:"compact_mode" jsonschema:"title=Compact Mode,description=Enable compact mode for the TUI,default=false"`
	// Here we can add themes later or any TUI related options
}

// Options holds general application settings. mergeOptions combines them
// across config files: context paths are appended, boolean flags can only be
// switched on, and a non-empty data directory replaces the previous one.
type Options struct {
	ContextPaths []string `json:"context_paths,omitempty" jsonschema:"title=Context Paths,description=List of paths to search for context files"`
	// NOTE: encoding/json ignores omitempty on struct-typed fields, so "tui" is always emitted.
	TUI                  TUIOptions `json:"tui,omitempty" jsonschema:"title=TUI Options,description=Terminal UI configuration options"`
	Debug                bool       `json:"debug,omitempty" jsonschema:"title=Debug,description=Enable debug logging,default=false"`
	DebugLSP             bool       `json:"debug_lsp,omitempty" jsonschema:"title=Debug LSP,description=Enable LSP debug logging,default=false"`
	DisableAutoSummarize bool       `json:"disable_auto_summarize,omitempty" jsonschema:"title=Disable Auto Summarize,description=Disable automatic conversation summarization,default=false"`
	// AutoOpenVSCodeDiff is forced on by loadConfig when VSCODE_INJECTION=1.
	AutoOpenVSCodeDiff bool `json:"auto_open_vscode_diff,omitempty" jsonschema:"title=Auto Open VS Code Diff,description=Automatically open VS Code diff view when showing file changes,default=false"`
	// DataDirectory holds application data such as the debug log.
	// NOTE(review): documented as relative to the cwd, but loadConfig uses the
	// value as-is, so a relative path resolves against the process working
	// directory, not the directory passed to Init — confirm this is intended.
	DataDirectory string `json:"data_directory,omitempty" jsonschema:"title=Data Directory,description=Directory for storing application data,default=.crush"`
}

// PreferredModel selects a model by provider and model ID, with optional
// per-model overrides.
type PreferredModel struct {
	// ModelID is matched against Model.ID in the provider's model list.
	ModelID string `json:"model_id" jsonschema:"title=Model ID,description=ID of the preferred model"`
	// Provider is the key into Config.Providers.
	Provider provider.InferenceProvider `json:"provider" jsonschema:"title=Provider,description=Provider for the preferred model"`
	// ReasoningEffort overrides the default reasoning effort for this model
	ReasoningEffort string `json:"reasoning_effort,omitempty" jsonschema:"title=Reasoning Effort,description=Override reasoning effort for this model"`
	// MaxTokens overrides the model's DefaultMaxTokens when > 0 (see
	// GetAgentEffectiveMaxTokens).
	MaxTokens int64 `json:"max_tokens,omitempty" jsonschema:"title=Max Tokens,description=Override max tokens for this model,minimum=1"`

	// Think indicates if the model should think, only applicable for anthropic reasoning models
	Think bool `json:"think,omitempty" jsonschema:"title=Think,description=Enable thinking for reasoning models,default=false"`
}

// PreferredModels holds the large and small model choices. Agents pick one of
// them via Agent.Model.
type PreferredModels struct {
	Large PreferredModel `json:"large,omitempty" jsonschema:"title=Large Model,description=Preferred model configuration for large model type"`
	Small PreferredModel `json:"small,omitempty" jsonschema:"title=Small Model,description=Preferred model configuration for small model type"`
}
178
179type Config struct {
180 Models PreferredModels `json:"models,omitempty" jsonschema:"title=Models,description=Preferred model configurations for large and small model types"`
181 // List of configured providers
182 Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty" jsonschema:"title=Providers,description=LLM provider configurations"`
183
184 // List of configured agents
185 Agents map[AgentID]Agent `json:"agents,omitempty" jsonschema:"title=Agents,description=Agent configurations for different tasks"`
186
187 // List of configured MCPs
188 MCP map[string]MCP `json:"mcp,omitempty" jsonschema:"title=MCP,description=Model Control Protocol server configurations"`
189
190 // List of configured LSPs
191 LSP map[string]LSPConfig `json:"lsp,omitempty" jsonschema:"title=LSP,description=Language Server Protocol configurations"`
192
193 // Miscellaneous options
194 Options Options `json:"options,omitempty" jsonschema:"title=Options,description=General application options and settings"`
195}
196
197var (
198 instance *Config // The single instance of the Singleton
199 cwd string
200 once sync.Once // Ensures the initialization happens only once
201
202)
203
204func readConfigFile(path string) (*Config, error) {
205 var cfg *Config
206 if _, err := os.Stat(path); err != nil && !os.IsNotExist(err) {
207 // some other error occurred while checking the file
208 return nil, err
209 } else if err == nil {
210 // config file exists, read it
211 file, err := os.ReadFile(path)
212 if err != nil {
213 return nil, err
214 }
215 cfg = &Config{}
216 if err := json.Unmarshal(file, cfg); err != nil {
217 return nil, err
218 }
219 } else {
220 // config file does not exist, create a new one
221 cfg = &Config{}
222 }
223 return cfg, nil
224}
225
226func loadConfig(cwd string, debug bool) (*Config, error) {
227 // First read the global config file
228 cfgPath := ConfigPath()
229
230 cfg := defaultConfigBasedOnEnv()
231 cfg.Options.Debug = debug
232 defaultLevel := slog.LevelInfo
233 if cfg.Options.Debug {
234 defaultLevel = slog.LevelDebug
235 }
236 if os.Getenv("CRUSH_DEV_DEBUG") == "true" {
237 loggingFile := fmt.Sprintf("%s/%s", cfg.Options.DataDirectory, "debug.log")
238
239 // if file does not exist create it
240 if _, err := os.Stat(loggingFile); os.IsNotExist(err) {
241 if err := os.MkdirAll(cfg.Options.DataDirectory, 0o755); err != nil {
242 return cfg, fmt.Errorf("failed to create directory: %w", err)
243 }
244 if _, err := os.Create(loggingFile); err != nil {
245 return cfg, fmt.Errorf("failed to create log file: %w", err)
246 }
247 }
248
249 messagesPath := fmt.Sprintf("%s/%s", cfg.Options.DataDirectory, "messages")
250
251 if _, err := os.Stat(messagesPath); os.IsNotExist(err) {
252 if err := os.MkdirAll(messagesPath, 0o756); err != nil {
253 return cfg, fmt.Errorf("failed to create directory: %w", err)
254 }
255 }
256 logging.MessageDir = messagesPath
257
258 sloggingFileWriter, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666)
259 if err != nil {
260 return cfg, fmt.Errorf("failed to open log file: %w", err)
261 }
262 // Configure logger
263 logger := slog.New(slog.NewTextHandler(sloggingFileWriter, &slog.HandlerOptions{
264 Level: defaultLevel,
265 }))
266 slog.SetDefault(logger)
267 } else {
268 // Configure logger
269 logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{
270 Level: defaultLevel,
271 }))
272 slog.SetDefault(logger)
273 }
274
275 priorityOrderedConfigFiles := []string{
276 cfgPath, // Global config file
277 filepath.Join(cwd, "crush.json"), // Local config file
278 filepath.Join(cwd, ".crush.json"), // Local config file
279 }
280
281 configs := make([]*Config, 0)
282 for _, path := range priorityOrderedConfigFiles {
283 localConfig, err := readConfigFile(path)
284 if err != nil {
285 return nil, fmt.Errorf("failed to read config file %s: %w", path, err)
286 }
287 if localConfig != nil {
288 // If the config file was read successfully, add it to the list
289 configs = append(configs, localConfig)
290 }
291 }
292
293 // merge options
294 mergeOptions(cfg, configs...)
295
296 mergeProviderConfigs(cfg, configs...)
297 // no providers found the app is not initialized yet
298 if len(cfg.Providers) == 0 {
299 return cfg, nil
300 }
301 preferredProvider := getPreferredProvider(cfg.Providers)
302 if preferredProvider != nil {
303 cfg.Models = PreferredModels{
304 Large: PreferredModel{
305 ModelID: preferredProvider.DefaultLargeModel,
306 Provider: preferredProvider.ID,
307 },
308 Small: PreferredModel{
309 ModelID: preferredProvider.DefaultSmallModel,
310 Provider: preferredProvider.ID,
311 },
312 }
313 } else {
314 // No valid providers found, set empty models
315 cfg.Models = PreferredModels{}
316 }
317
318 mergeModels(cfg, configs...)
319
320 agents := map[AgentID]Agent{
321 AgentCoder: {
322 ID: AgentCoder,
323 Name: "Coder",
324 Description: "An agent that helps with executing coding tasks.",
325 Model: LargeModel,
326 ContextPaths: cfg.Options.ContextPaths,
327 // All tools allowed
328 },
329 AgentTask: {
330 ID: AgentTask,
331 Name: "Task",
332 Description: "An agent that helps with searching for context and finding implementation details.",
333 Model: LargeModel,
334 ContextPaths: cfg.Options.ContextPaths,
335 AllowedTools: []string{
336 "glob",
337 "grep",
338 "ls",
339 "sourcegraph",
340 "view",
341 },
342 // NO MCPs or LSPs by default
343 AllowedMCP: map[string][]string{},
344 AllowedLSP: []string{},
345 },
346 }
347 cfg.Agents = agents
348 mergeAgents(cfg, configs...)
349 mergeMCPs(cfg, configs...)
350 mergeLSPs(cfg, configs...)
351
352 // Force enable VS Code diff when running inside VS Code
353 if os.Getenv("VSCODE_INJECTION") == "1" {
354 cfg.Options.AutoOpenVSCodeDiff = true
355 }
356
357 // Validate the final configuration
358 if err := cfg.Validate(); err != nil {
359 return cfg, fmt.Errorf("configuration validation failed: %w", err)
360 }
361
362 return cfg, nil
363}
364
365func Init(workingDir string, debug bool) (*Config, error) {
366 var err error
367 once.Do(func() {
368 cwd = workingDir
369 instance, err = loadConfig(cwd, debug)
370 if err != nil {
371 logging.Error("Failed to load config", "error", err)
372 }
373 })
374
375 return instance, err
376}
377
378func Get() *Config {
379 if instance == nil {
380 // TODO: Handle this better
381 panic("Config not initialized. Call InitConfig first.")
382 }
383 return instance
384}
385
386func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
387 providers := Providers()
388 for _, p := range providers {
389 if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
390 return &providerConfig
391 }
392 }
393 // if none found return the first configured provider
394 for _, providerConfig := range configuredProviders {
395 if !providerConfig.Disabled {
396 return &providerConfig
397 }
398 }
399 return nil
400}
401
402func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
403 if other.APIKey != "" {
404 base.APIKey = other.APIKey
405 }
406 // Only change these options if the provider is not a known provider
407 if !slices.Contains(provider.KnownProviders(), p) {
408 if other.BaseURL != "" {
409 base.BaseURL = other.BaseURL
410 }
411 if other.ProviderType != "" {
412 base.ProviderType = other.ProviderType
413 }
414 if len(other.ExtraHeaders) > 0 {
415 if base.ExtraHeaders == nil {
416 base.ExtraHeaders = make(map[string]string)
417 }
418 maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
419 }
420 if len(other.ExtraParams) > 0 {
421 if base.ExtraParams == nil {
422 base.ExtraParams = make(map[string]string)
423 }
424 maps.Copy(base.ExtraParams, other.ExtraParams)
425 }
426 }
427
428 if other.Disabled {
429 base.Disabled = other.Disabled
430 }
431
432 if other.DefaultLargeModel != "" {
433 base.DefaultLargeModel = other.DefaultLargeModel
434 }
435 // Add new models if they don't exist
436 if other.Models != nil {
437 for _, model := range other.Models {
438 // check if the model already exists
439 exists := false
440 for _, existingModel := range base.Models {
441 if existingModel.ID == model.ID {
442 exists = true
443 break
444 }
445 }
446 if !exists {
447 base.Models = append(base.Models, model)
448 }
449 }
450 }
451
452 return base
453}
454
455func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
456 if !slices.Contains(provider.KnownProviders(), p) {
457 if providerConfig.ProviderType != provider.TypeOpenAI {
458 return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
459 }
460 if providerConfig.BaseURL == "" {
461 return errors.New("base URL must be set for custom providers")
462 }
463 if providerConfig.APIKey == "" {
464 return errors.New("API key must be set for custom providers")
465 }
466 }
467 return nil
468}
469
470func mergeModels(base *Config, others ...*Config) {
471 for _, cfg := range others {
472 if cfg == nil {
473 continue
474 }
475 if cfg.Models.Large.ModelID != "" && cfg.Models.Large.Provider != "" {
476 base.Models.Large = cfg.Models.Large
477 }
478
479 if cfg.Models.Small.ModelID != "" && cfg.Models.Small.Provider != "" {
480 base.Models.Small = cfg.Models.Small
481 }
482 }
483}
484
485func mergeOptions(base *Config, others ...*Config) {
486 for _, cfg := range others {
487 if cfg == nil {
488 continue
489 }
490 baseOptions := base.Options
491 other := cfg.Options
492 if len(other.ContextPaths) > 0 {
493 baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
494 }
495
496 if other.TUI.CompactMode {
497 baseOptions.TUI.CompactMode = other.TUI.CompactMode
498 }
499
500 if other.Debug {
501 baseOptions.Debug = other.Debug
502 }
503
504 if other.DebugLSP {
505 baseOptions.DebugLSP = other.DebugLSP
506 }
507
508 if other.DisableAutoSummarize {
509 baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
510 }
511
512 if other.AutoOpenVSCodeDiff {
513 baseOptions.AutoOpenVSCodeDiff = other.AutoOpenVSCodeDiff
514 }
515
516 if other.DataDirectory != "" {
517 baseOptions.DataDirectory = other.DataDirectory
518 }
519 base.Options = baseOptions
520 }
521}
522
523func mergeAgents(base *Config, others ...*Config) {
524 for _, cfg := range others {
525 if cfg == nil {
526 continue
527 }
528 for agentID, newAgent := range cfg.Agents {
529 if _, ok := base.Agents[agentID]; !ok {
530 newAgent.ID = agentID
531 if newAgent.Model == "" {
532 newAgent.Model = LargeModel
533 }
534 if len(newAgent.ContextPaths) > 0 {
535 newAgent.ContextPaths = append(base.Options.ContextPaths, newAgent.ContextPaths...)
536 } else {
537 newAgent.ContextPaths = base.Options.ContextPaths
538 }
539 base.Agents[agentID] = newAgent
540 } else {
541 baseAgent := base.Agents[agentID]
542
543 if agentID == AgentCoder || agentID == AgentTask {
544 if newAgent.Model != "" {
545 baseAgent.Model = newAgent.Model
546 }
547 if newAgent.AllowedMCP != nil {
548 baseAgent.AllowedMCP = newAgent.AllowedMCP
549 }
550 if newAgent.AllowedLSP != nil {
551 baseAgent.AllowedLSP = newAgent.AllowedLSP
552 }
553 // Context paths are additive for known agents too
554 if len(newAgent.ContextPaths) > 0 {
555 baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...)
556 }
557 } else {
558 if newAgent.Name != "" {
559 baseAgent.Name = newAgent.Name
560 }
561 if newAgent.Description != "" {
562 baseAgent.Description = newAgent.Description
563 }
564 if newAgent.Model != "" {
565 baseAgent.Model = newAgent.Model
566 } else if baseAgent.Model == "" {
567 baseAgent.Model = LargeModel
568 }
569
570 baseAgent.Disabled = newAgent.Disabled
571
572 if newAgent.AllowedTools != nil {
573 baseAgent.AllowedTools = newAgent.AllowedTools
574 }
575 if newAgent.AllowedMCP != nil {
576 baseAgent.AllowedMCP = newAgent.AllowedMCP
577 }
578 if newAgent.AllowedLSP != nil {
579 baseAgent.AllowedLSP = newAgent.AllowedLSP
580 }
581 if len(newAgent.ContextPaths) > 0 {
582 baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...)
583 }
584 }
585
586 base.Agents[agentID] = baseAgent
587 }
588 }
589 }
590}
591
592func mergeMCPs(base *Config, others ...*Config) {
593 for _, cfg := range others {
594 if cfg == nil {
595 continue
596 }
597 maps.Copy(base.MCP, cfg.MCP)
598 }
599}
600
601func mergeLSPs(base *Config, others ...*Config) {
602 for _, cfg := range others {
603 if cfg == nil {
604 continue
605 }
606 maps.Copy(base.LSP, cfg.LSP)
607 }
608}
609
610func mergeProviderConfigs(base *Config, others ...*Config) {
611 for _, cfg := range others {
612 if cfg == nil {
613 continue
614 }
615 for providerName, p := range cfg.Providers {
616 p.ID = providerName
617 if _, ok := base.Providers[providerName]; !ok {
618 if slices.Contains(provider.KnownProviders(), providerName) {
619 providers := Providers()
620 for _, providerDef := range providers {
621 if providerDef.ID == providerName {
622 logging.Info("Using default provider config for", "provider", providerName)
623 baseProvider := getDefaultProviderConfig(providerDef, providerDef.APIKey)
624 base.Providers[providerName] = mergeProviderConfig(providerName, baseProvider, p)
625 break
626 }
627 }
628 } else {
629 base.Providers[providerName] = p
630 }
631 } else {
632 base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], p)
633 }
634 }
635 }
636
637 finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
638 for providerName, providerConfig := range base.Providers {
639 err := validateProvider(providerName, providerConfig)
640 if err != nil {
641 logging.Warn("Skipping provider", "name", providerName, "error", err)
642 continue // Skip invalid providers
643 }
644 finalProviders[providerName] = providerConfig
645 }
646 base.Providers = finalProviders
647}
648
649func providerDefaultConfig(providerID provider.InferenceProvider) ProviderConfig {
650 switch providerID {
651 case provider.InferenceProviderAnthropic:
652 return ProviderConfig{
653 ID: providerID,
654 ProviderType: provider.TypeAnthropic,
655 }
656 case provider.InferenceProviderOpenAI:
657 return ProviderConfig{
658 ID: providerID,
659 ProviderType: provider.TypeOpenAI,
660 }
661 case provider.InferenceProviderGemini:
662 return ProviderConfig{
663 ID: providerID,
664 ProviderType: provider.TypeGemini,
665 }
666 case provider.InferenceProviderBedrock:
667 return ProviderConfig{
668 ID: providerID,
669 ProviderType: provider.TypeBedrock,
670 }
671 case provider.InferenceProviderAzure:
672 return ProviderConfig{
673 ID: providerID,
674 ProviderType: provider.TypeAzure,
675 }
676 case provider.InferenceProviderOpenRouter:
677 return ProviderConfig{
678 ID: providerID,
679 ProviderType: provider.TypeOpenAI,
680 BaseURL: "https://openrouter.ai/api/v1",
681 ExtraHeaders: map[string]string{
682 "HTTP-Referer": "crush.charm.land",
683 "X-Title": "Crush",
684 },
685 }
686 case provider.InferenceProviderXAI:
687 return ProviderConfig{
688 ID: providerID,
689 ProviderType: provider.TypeXAI,
690 BaseURL: "https://api.x.ai/v1",
691 }
692 case provider.InferenceProviderVertexAI:
693 return ProviderConfig{
694 ID: providerID,
695 ProviderType: provider.TypeVertexAI,
696 }
697 default:
698 return ProviderConfig{
699 ID: providerID,
700 ProviderType: provider.TypeOpenAI,
701 }
702 }
703}
704
705func getDefaultProviderConfig(p provider.Provider, apiKey string) ProviderConfig {
706 providerConfig := providerDefaultConfig(p.ID)
707 providerConfig.APIKey = apiKey
708 providerConfig.DefaultLargeModel = p.DefaultLargeModelID
709 providerConfig.DefaultSmallModel = p.DefaultSmallModelID
710 baseURL := p.APIEndpoint
711 if strings.HasPrefix(baseURL, "$") {
712 envVar := strings.TrimPrefix(baseURL, "$")
713 baseURL = os.Getenv(envVar)
714 }
715 providerConfig.BaseURL = baseURL
716 for _, model := range p.Models {
717 configModel := Model{
718 ID: model.ID,
719 Name: model.Name,
720 CostPer1MIn: model.CostPer1MIn,
721 CostPer1MOut: model.CostPer1MOut,
722 CostPer1MInCached: model.CostPer1MInCached,
723 CostPer1MOutCached: model.CostPer1MOutCached,
724 ContextWindow: model.ContextWindow,
725 DefaultMaxTokens: model.DefaultMaxTokens,
726 CanReason: model.CanReason,
727 SupportsImages: model.SupportsImages,
728 }
729 // Set reasoning effort for reasoning models
730 if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
731 configModel.HasReasoningEffort = model.HasReasoningEffort
732 configModel.ReasoningEffort = model.DefaultReasoningEffort
733 }
734 providerConfig.Models = append(providerConfig.Models, configModel)
735 }
736 return providerConfig
737}
738
739func defaultConfigBasedOnEnv() *Config {
740 // VS Code diff is disabled by default
741 autoOpenVSCodeDiff := false
742
743 cfg := &Config{
744 Options: Options{
745 DataDirectory: defaultDataDirectory,
746 ContextPaths: defaultContextPaths,
747 AutoOpenVSCodeDiff: autoOpenVSCodeDiff,
748 },
749 Providers: make(map[provider.InferenceProvider]ProviderConfig),
750 Agents: make(map[AgentID]Agent),
751 LSP: make(map[string]LSPConfig),
752 MCP: make(map[string]MCP),
753 }
754
755 providers := Providers()
756
757 for _, p := range providers {
758 if strings.HasPrefix(p.APIKey, "$") {
759 envVar := strings.TrimPrefix(p.APIKey, "$")
760 if apiKey := os.Getenv(envVar); apiKey != "" {
761 cfg.Providers[p.ID] = getDefaultProviderConfig(p, apiKey)
762 }
763 }
764 }
765 // TODO: support local models
766
767 if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
768 providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
769 providerConfig.ExtraParams = map[string]string{
770 "project": os.Getenv("GOOGLE_CLOUD_PROJECT"),
771 "location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
772 }
773 // Find the VertexAI provider definition to get default models
774 for _, p := range providers {
775 if p.ID == provider.InferenceProviderVertexAI {
776 providerConfig.DefaultLargeModel = p.DefaultLargeModelID
777 providerConfig.DefaultSmallModel = p.DefaultSmallModelID
778 for _, model := range p.Models {
779 configModel := Model{
780 ID: model.ID,
781 Name: model.Name,
782 CostPer1MIn: model.CostPer1MIn,
783 CostPer1MOut: model.CostPer1MOut,
784 CostPer1MInCached: model.CostPer1MInCached,
785 CostPer1MOutCached: model.CostPer1MOutCached,
786 ContextWindow: model.ContextWindow,
787 DefaultMaxTokens: model.DefaultMaxTokens,
788 CanReason: model.CanReason,
789 SupportsImages: model.SupportsImages,
790 }
791 // Set reasoning effort for reasoning models
792 if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
793 configModel.HasReasoningEffort = model.HasReasoningEffort
794 configModel.ReasoningEffort = model.DefaultReasoningEffort
795 }
796 providerConfig.Models = append(providerConfig.Models, configModel)
797 }
798 break
799 }
800 }
801 cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
802 }
803
804 if hasAWSCredentials() {
805 providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
806 providerConfig.ExtraParams = map[string]string{
807 "region": os.Getenv("AWS_DEFAULT_REGION"),
808 }
809 if providerConfig.ExtraParams["region"] == "" {
810 providerConfig.ExtraParams["region"] = os.Getenv("AWS_REGION")
811 }
812 // Find the Bedrock provider definition to get default models
813 for _, p := range providers {
814 if p.ID == provider.InferenceProviderBedrock {
815 providerConfig.DefaultLargeModel = p.DefaultLargeModelID
816 providerConfig.DefaultSmallModel = p.DefaultSmallModelID
817 for _, model := range p.Models {
818 configModel := Model{
819 ID: model.ID,
820 Name: model.Name,
821 CostPer1MIn: model.CostPer1MIn,
822 CostPer1MOut: model.CostPer1MOut,
823 CostPer1MInCached: model.CostPer1MInCached,
824 CostPer1MOutCached: model.CostPer1MOutCached,
825 ContextWindow: model.ContextWindow,
826 DefaultMaxTokens: model.DefaultMaxTokens,
827 CanReason: model.CanReason,
828 SupportsImages: model.SupportsImages,
829 }
830 // Set reasoning effort for reasoning models
831 if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
832 configModel.HasReasoningEffort = model.HasReasoningEffort
833 configModel.ReasoningEffort = model.DefaultReasoningEffort
834 }
835 providerConfig.Models = append(providerConfig.Models, configModel)
836 }
837 break
838 }
839 }
840 cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
841 }
842 return cfg
843}
844
// hasAWSCredentials reports whether the environment carries AWS settings. The
// conditions are checked in order and short-circuit:
//   - both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are set;
//   - AWS_PROFILE or AWS_DEFAULT_PROFILE is set;
//   - AWS_REGION or AWS_DEFAULT_REGION is set;
//   - a container credentials URI (relative or full) is set.
//
// NOTE(review): a region or profile alone counts. Presumably the AWS SDK is
// expected to find the actual credentials elsewhere — confirm.
func hasAWSCredentials() bool {
	isSet := func(name string) bool { return os.Getenv(name) != "" }

	return (isSet("AWS_ACCESS_KEY_ID") && isSet("AWS_SECRET_ACCESS_KEY")) ||
		isSet("AWS_PROFILE") || isSet("AWS_DEFAULT_PROFILE") ||
		isSet("AWS_REGION") || isSet("AWS_DEFAULT_REGION") ||
		isSet("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") ||
		isSet("AWS_CONTAINER_CREDENTIALS_FULL_URI")
}
865
866func WorkingDirectory() string {
867 return cwd
868}
869
// TODO: Handle error state: the agent lookup helpers below log and return zero values instead of reporting errors to callers.
871
872func GetAgentModel(agentID AgentID) Model {
873 cfg := Get()
874 agent, ok := cfg.Agents[agentID]
875 if !ok {
876 logging.Error("Agent not found", "agent_id", agentID)
877 return Model{}
878 }
879
880 var model PreferredModel
881 switch agent.Model {
882 case LargeModel:
883 model = cfg.Models.Large
884 case SmallModel:
885 model = cfg.Models.Small
886 default:
887 logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
888 model = cfg.Models.Large // Fallback to large model
889 }
890 providerConfig, ok := cfg.Providers[model.Provider]
891 if !ok {
892 logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
893 return Model{}
894 }
895
896 for _, m := range providerConfig.Models {
897 if m.ID == model.ModelID {
898 return m
899 }
900 }
901
902 logging.Error("Model not found for agent", "agent_id", agentID, "model", agent.Model)
903 return Model{}
904}
905
906// GetAgentEffectiveMaxTokens returns the effective max tokens for an agent,
907// considering any overrides from the preferred model configuration
908func GetAgentEffectiveMaxTokens(agentID AgentID) int64 {
909 cfg := Get()
910 agent, ok := cfg.Agents[agentID]
911 if !ok {
912 logging.Error("Agent not found", "agent_id", agentID)
913 return 0
914 }
915
916 var preferredModel PreferredModel
917 switch agent.Model {
918 case LargeModel:
919 preferredModel = cfg.Models.Large
920 case SmallModel:
921 preferredModel = cfg.Models.Small
922 default:
923 logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
924 preferredModel = cfg.Models.Large // Fallback to large model
925 }
926
927 // Get the base model configuration
928 baseModel := GetAgentModel(agentID)
929 if baseModel.ID == "" {
930 return 0
931 }
932
933 // Start with the default max tokens from the base model
934 maxTokens := baseModel.DefaultMaxTokens
935
936 // Override with preferred model max tokens if set
937 if preferredModel.MaxTokens > 0 {
938 maxTokens = preferredModel.MaxTokens
939 }
940
941 return maxTokens
942}
943
944func GetAgentProvider(agentID AgentID) ProviderConfig {
945 cfg := Get()
946 agent, ok := cfg.Agents[agentID]
947 if !ok {
948 logging.Error("Agent not found", "agent_id", agentID)
949 return ProviderConfig{}
950 }
951
952 var model PreferredModel
953 switch agent.Model {
954 case LargeModel:
955 model = cfg.Models.Large
956 case SmallModel:
957 model = cfg.Models.Small
958 default:
959 logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
960 model = cfg.Models.Large // Fallback to large model
961 }
962
963 providerConfig, ok := cfg.Providers[model.Provider]
964 if !ok {
965 logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
966 return ProviderConfig{}
967 }
968
969 return providerConfig
970}
971
972func GetProviderModel(provider provider.InferenceProvider, modelID string) Model {
973 cfg := Get()
974 providerConfig, ok := cfg.Providers[provider]
975 if !ok {
976 logging.Error("Provider not found", "provider", provider)
977 return Model{}
978 }
979
980 for _, model := range providerConfig.Models {
981 if model.ID == modelID {
982 return model
983 }
984 }
985
986 logging.Error("Model not found for provider", "provider", provider, "model_id", modelID)
987 return Model{}
988}
989
990func GetModel(modelType ModelType) Model {
991 cfg := Get()
992 var model PreferredModel
993 switch modelType {
994 case LargeModel:
995 model = cfg.Models.Large
996 case SmallModel:
997 model = cfg.Models.Small
998 default:
999 model = cfg.Models.Large // Fallback to large model
1000 }
1001 providerConfig, ok := cfg.Providers[model.Provider]
1002 if !ok {
1003 return Model{}
1004 }
1005
1006 for _, m := range providerConfig.Models {
1007 if m.ID == model.ModelID {
1008 return m
1009 }
1010 }
1011 return Model{}
1012}
1013
1014func UpdatePreferredModel(modelType ModelType, model PreferredModel) error {
1015 cfg := Get()
1016 switch modelType {
1017 case LargeModel:
1018 cfg.Models.Large = model
1019 case SmallModel:
1020 cfg.Models.Small = model
1021 default:
1022 return fmt.Errorf("unknown model type: %s", modelType)
1023 }
1024 return nil
1025}
1026
// ValidationError describes a single configuration problem: the dotted path
// of the offending field and a human-readable explanation.
type ValidationError struct {
	Field   string
	Message string
}

// Error implements the error interface.
func (e ValidationError) Error() string {
	return fmt.Sprintf("validation error in %s: %s", e.Field, e.Message)
}

// ValidationErrors collects every problem found during a validation pass.
type ValidationErrors []ValidationError

// Error implements the error interface. It returns a single error's message
// unchanged and joins two or more messages with "; ".
func (e ValidationErrors) Error() string {
	switch len(e) {
	case 0:
		return "no validation errors"
	case 1:
		return e[0].Error()
	}

	messages := make([]string, 0, len(e))
	for _, err := range e {
		messages = append(messages, err.Error())
	}
	return fmt.Sprintf("multiple validation errors: %s", strings.Join(messages, "; "))
}

// Unwrap exposes the individual ValidationError values. This lets
// errors.Is and errors.As (Go 1.20+ multi-error support) inspect each one.
func (e ValidationErrors) Unwrap() []error {
	errs := make([]error, len(e))
	for i, ve := range e {
		errs[i] = ve
	}
	return errs
}

// HasErrors returns true if there are any validation errors
func (e ValidationErrors) HasErrors() bool {
	return len(e) > 0
}

// Add appends a new validation error
func (e *ValidationErrors) Add(field, message string) {
	*e = append(*e, ValidationError{Field: field, Message: message})
}
1064
1065// Validate performs comprehensive validation of the configuration
1066func (c *Config) Validate() error {
1067 var errors ValidationErrors
1068
1069 // Validate providers
1070 c.validateProviders(&errors)
1071
1072 // Validate models
1073 c.validateModels(&errors)
1074
1075 // Validate agents
1076 c.validateAgents(&errors)
1077
1078 // Validate options
1079 c.validateOptions(&errors)
1080
1081 // Validate MCP configurations
1082 c.validateMCPs(&errors)
1083
1084 // Validate LSP configurations
1085 c.validateLSPs(&errors)
1086
1087 // Validate cross-references
1088 c.validateCrossReferences(&errors)
1089
1090 // Validate completeness
1091 c.validateCompleteness(&errors)
1092
1093 if errors.HasErrors() {
1094 return errors
1095 }
1096
1097 return nil
1098}
1099
1100// validateProviders validates all provider configurations
1101func (c *Config) validateProviders(errors *ValidationErrors) {
1102 if c.Providers == nil {
1103 c.Providers = make(map[provider.InferenceProvider]ProviderConfig)
1104 }
1105
1106 knownProviders := provider.KnownProviders()
1107 validTypes := []provider.Type{
1108 provider.TypeOpenAI,
1109 provider.TypeAnthropic,
1110 provider.TypeGemini,
1111 provider.TypeAzure,
1112 provider.TypeBedrock,
1113 provider.TypeVertexAI,
1114 provider.TypeXAI,
1115 }
1116
1117 for providerID, providerConfig := range c.Providers {
1118 fieldPrefix := fmt.Sprintf("providers.%s", providerID)
1119
1120 // Validate API key for non-disabled providers
1121 if !providerConfig.Disabled && providerConfig.APIKey == "" {
1122 // Special case for AWS Bedrock and VertexAI which may use other auth methods
1123 if providerID != provider.InferenceProviderBedrock && providerID != provider.InferenceProviderVertexAI {
1124 errors.Add(fieldPrefix+".api_key", "API key is required for non-disabled providers")
1125 }
1126 }
1127
1128 // Validate provider type
1129 validType := slices.Contains(validTypes, providerConfig.ProviderType)
1130 if !validType {
1131 errors.Add(fieldPrefix+".provider_type", fmt.Sprintf("invalid provider type: %s", providerConfig.ProviderType))
1132 }
1133
1134 // Validate custom providers
1135 isKnownProvider := slices.Contains(knownProviders, providerID)
1136
1137 if !isKnownProvider {
1138 // Custom provider validation
1139 if providerConfig.BaseURL == "" {
1140 errors.Add(fieldPrefix+".base_url", "BaseURL is required for custom providers")
1141 }
1142 if providerConfig.ProviderType != provider.TypeOpenAI {
1143 errors.Add(fieldPrefix+".provider_type", "custom providers currently only support OpenAI type")
1144 }
1145 }
1146
1147 // Validate models
1148 modelIDs := make(map[string]bool)
1149 for i, model := range providerConfig.Models {
1150 modelFieldPrefix := fmt.Sprintf("%s.models[%d]", fieldPrefix, i)
1151
1152 // Check for duplicate model IDs
1153 if modelIDs[model.ID] {
1154 errors.Add(modelFieldPrefix+".id", fmt.Sprintf("duplicate model ID: %s", model.ID))
1155 }
1156 modelIDs[model.ID] = true
1157
1158 // Validate required model fields
1159 if model.ID == "" {
1160 errors.Add(modelFieldPrefix+".id", "model ID is required")
1161 }
1162 if model.Name == "" {
1163 errors.Add(modelFieldPrefix+".name", "model name is required")
1164 }
1165 if model.ContextWindow <= 0 {
1166 errors.Add(modelFieldPrefix+".context_window", "context window must be positive")
1167 }
1168 if model.DefaultMaxTokens <= 0 {
1169 errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens must be positive")
1170 }
1171 if model.DefaultMaxTokens > model.ContextWindow {
1172 errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens cannot exceed context window")
1173 }
1174
1175 // Validate cost fields
1176 if model.CostPer1MIn < 0 {
1177 errors.Add(modelFieldPrefix+".cost_per_1m_in", "cost per 1M input tokens cannot be negative")
1178 }
1179 if model.CostPer1MOut < 0 {
1180 errors.Add(modelFieldPrefix+".cost_per_1m_out", "cost per 1M output tokens cannot be negative")
1181 }
1182 if model.CostPer1MInCached < 0 {
1183 errors.Add(modelFieldPrefix+".cost_per_1m_in_cached", "cached cost per 1M input tokens cannot be negative")
1184 }
1185 if model.CostPer1MOutCached < 0 {
1186 errors.Add(modelFieldPrefix+".cost_per_1m_out_cached", "cached cost per 1M output tokens cannot be negative")
1187 }
1188 }
1189
1190 // Validate default model references
1191 if providerConfig.DefaultLargeModel != "" {
1192 if !modelIDs[providerConfig.DefaultLargeModel] {
1193 errors.Add(fieldPrefix+".default_large_model", fmt.Sprintf("default large model '%s' not found in provider models", providerConfig.DefaultLargeModel))
1194 }
1195 }
1196 if providerConfig.DefaultSmallModel != "" {
1197 if !modelIDs[providerConfig.DefaultSmallModel] {
1198 errors.Add(fieldPrefix+".default_small_model", fmt.Sprintf("default small model '%s' not found in provider models", providerConfig.DefaultSmallModel))
1199 }
1200 }
1201
1202 // Validate provider-specific requirements
1203 c.validateProviderSpecific(providerID, providerConfig, errors)
1204 }
1205}
1206
1207// validateProviderSpecific validates provider-specific requirements
1208func (c *Config) validateProviderSpecific(providerID provider.InferenceProvider, providerConfig ProviderConfig, errors *ValidationErrors) {
1209 fieldPrefix := fmt.Sprintf("providers.%s", providerID)
1210
1211 switch providerID {
1212 case provider.InferenceProviderVertexAI:
1213 if !providerConfig.Disabled {
1214 if providerConfig.ExtraParams == nil {
1215 errors.Add(fieldPrefix+".extra_params", "VertexAI requires extra_params configuration")
1216 } else {
1217 if providerConfig.ExtraParams["project"] == "" {
1218 errors.Add(fieldPrefix+".extra_params.project", "VertexAI requires project parameter")
1219 }
1220 if providerConfig.ExtraParams["location"] == "" {
1221 errors.Add(fieldPrefix+".extra_params.location", "VertexAI requires location parameter")
1222 }
1223 }
1224 }
1225 case provider.InferenceProviderBedrock:
1226 if !providerConfig.Disabled {
1227 if providerConfig.ExtraParams == nil || providerConfig.ExtraParams["region"] == "" {
1228 errors.Add(fieldPrefix+".extra_params.region", "Bedrock requires region parameter")
1229 }
1230 // Check for AWS credentials in environment
1231 if !hasAWSCredentials() {
1232 errors.Add(fieldPrefix, "Bedrock requires AWS credentials in environment")
1233 }
1234 }
1235 }
1236}
1237
1238// validateModels validates preferred model configurations
1239func (c *Config) validateModels(errors *ValidationErrors) {
1240 // Validate large model
1241 if c.Models.Large.ModelID != "" || c.Models.Large.Provider != "" {
1242 if c.Models.Large.ModelID == "" {
1243 errors.Add("models.large.model_id", "large model ID is required when provider is set")
1244 }
1245 if c.Models.Large.Provider == "" {
1246 errors.Add("models.large.provider", "large model provider is required when model ID is set")
1247 }
1248
1249 // Check if provider exists and is not disabled
1250 if providerConfig, exists := c.Providers[c.Models.Large.Provider]; exists {
1251 if providerConfig.Disabled {
1252 errors.Add("models.large.provider", "large model provider is disabled")
1253 }
1254
1255 // Check if model exists in provider
1256 modelExists := false
1257 for _, model := range providerConfig.Models {
1258 if model.ID == c.Models.Large.ModelID {
1259 modelExists = true
1260 break
1261 }
1262 }
1263 if !modelExists {
1264 errors.Add("models.large.model_id", fmt.Sprintf("large model '%s' not found in provider '%s'", c.Models.Large.ModelID, c.Models.Large.Provider))
1265 }
1266 } else {
1267 errors.Add("models.large.provider", fmt.Sprintf("large model provider '%s' not found", c.Models.Large.Provider))
1268 }
1269 }
1270
1271 // Validate small model
1272 if c.Models.Small.ModelID != "" || c.Models.Small.Provider != "" {
1273 if c.Models.Small.ModelID == "" {
1274 errors.Add("models.small.model_id", "small model ID is required when provider is set")
1275 }
1276 if c.Models.Small.Provider == "" {
1277 errors.Add("models.small.provider", "small model provider is required when model ID is set")
1278 }
1279
1280 // Check if provider exists and is not disabled
1281 if providerConfig, exists := c.Providers[c.Models.Small.Provider]; exists {
1282 if providerConfig.Disabled {
1283 errors.Add("models.small.provider", "small model provider is disabled")
1284 }
1285
1286 // Check if model exists in provider
1287 modelExists := false
1288 for _, model := range providerConfig.Models {
1289 if model.ID == c.Models.Small.ModelID {
1290 modelExists = true
1291 break
1292 }
1293 }
1294 if !modelExists {
1295 errors.Add("models.small.model_id", fmt.Sprintf("small model '%s' not found in provider '%s'", c.Models.Small.ModelID, c.Models.Small.Provider))
1296 }
1297 } else {
1298 errors.Add("models.small.provider", fmt.Sprintf("small model provider '%s' not found", c.Models.Small.Provider))
1299 }
1300 }
1301}
1302
1303// validateAgents validates agent configurations
1304func (c *Config) validateAgents(errors *ValidationErrors) {
1305 if c.Agents == nil {
1306 c.Agents = make(map[AgentID]Agent)
1307 }
1308
1309 validTools := []string{
1310 "bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "vscode_diff", "write", "agent",
1311 }
1312
1313 for agentID, agent := range c.Agents {
1314 fieldPrefix := fmt.Sprintf("agents.%s", agentID)
1315
1316 // Validate agent ID consistency
1317 if agent.ID != agentID {
1318 errors.Add(fieldPrefix+".id", fmt.Sprintf("agent ID mismatch: expected '%s', got '%s'", agentID, agent.ID))
1319 }
1320
1321 // Validate required fields
1322 if agent.ID == "" {
1323 errors.Add(fieldPrefix+".id", "agent ID is required")
1324 }
1325 if agent.Name == "" {
1326 errors.Add(fieldPrefix+".name", "agent name is required")
1327 }
1328
1329 // Validate model type
1330 if agent.Model != LargeModel && agent.Model != SmallModel {
1331 errors.Add(fieldPrefix+".model", fmt.Sprintf("invalid model type: %s (must be 'large' or 'small')", agent.Model))
1332 }
1333
1334 // Validate allowed tools
1335 if agent.AllowedTools != nil {
1336 for i, tool := range agent.AllowedTools {
1337 validTool := slices.Contains(validTools, tool)
1338 if !validTool {
1339 errors.Add(fmt.Sprintf("%s.allowed_tools[%d]", fieldPrefix, i), fmt.Sprintf("unknown tool: %s", tool))
1340 }
1341 }
1342 }
1343
1344 // Validate MCP references
1345 if agent.AllowedMCP != nil {
1346 for mcpName := range agent.AllowedMCP {
1347 if _, exists := c.MCP[mcpName]; !exists {
1348 errors.Add(fieldPrefix+".allowed_mcp", fmt.Sprintf("referenced MCP '%s' not found", mcpName))
1349 }
1350 }
1351 }
1352
1353 // Validate LSP references
1354 if agent.AllowedLSP != nil {
1355 for _, lspName := range agent.AllowedLSP {
1356 if _, exists := c.LSP[lspName]; !exists {
1357 errors.Add(fieldPrefix+".allowed_lsp", fmt.Sprintf("referenced LSP '%s' not found", lspName))
1358 }
1359 }
1360 }
1361
1362 // Validate context paths (basic path validation)
1363 for i, contextPath := range agent.ContextPaths {
1364 if contextPath == "" {
1365 errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path cannot be empty")
1366 }
1367 // Check for invalid characters in path
1368 if strings.Contains(contextPath, "\x00") {
1369 errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path contains invalid characters")
1370 }
1371 }
1372
1373 // Validate known agents maintain their core properties
1374 if agentID == AgentCoder {
1375 if agent.Name != "Coder" {
1376 errors.Add(fieldPrefix+".name", "coder agent name cannot be changed")
1377 }
1378 if agent.Description != "An agent that helps with executing coding tasks." {
1379 errors.Add(fieldPrefix+".description", "coder agent description cannot be changed")
1380 }
1381 } else if agentID == AgentTask {
1382 if agent.Name != "Task" {
1383 errors.Add(fieldPrefix+".name", "task agent name cannot be changed")
1384 }
1385 if agent.Description != "An agent that helps with searching for context and finding implementation details." {
1386 errors.Add(fieldPrefix+".description", "task agent description cannot be changed")
1387 }
1388 expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"}
1389 if agent.AllowedTools != nil && !slices.Equal(agent.AllowedTools, expectedTools) {
1390 errors.Add(fieldPrefix+".allowed_tools", "task agent allowed tools cannot be changed")
1391 }
1392 }
1393 }
1394}
1395
1396// validateOptions validates configuration options
1397func (c *Config) validateOptions(errors *ValidationErrors) {
1398 // Validate data directory
1399 if c.Options.DataDirectory == "" {
1400 errors.Add("options.data_directory", "data directory is required")
1401 }
1402
1403 // Validate context paths
1404 for i, contextPath := range c.Options.ContextPaths {
1405 if contextPath == "" {
1406 errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path cannot be empty")
1407 }
1408 if strings.Contains(contextPath, "\x00") {
1409 errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path contains invalid characters")
1410 }
1411 }
1412}
1413
1414// validateMCPs validates MCP configurations
1415func (c *Config) validateMCPs(errors *ValidationErrors) {
1416 if c.MCP == nil {
1417 c.MCP = make(map[string]MCP)
1418 }
1419
1420 for mcpName, mcpConfig := range c.MCP {
1421 fieldPrefix := fmt.Sprintf("mcp.%s", mcpName)
1422
1423 // Validate MCP type
1424 if mcpConfig.Type != MCPStdio && mcpConfig.Type != MCPSse {
1425 errors.Add(fieldPrefix+".type", fmt.Sprintf("invalid MCP type: %s (must be 'stdio' or 'sse')", mcpConfig.Type))
1426 }
1427
1428 // Validate based on type
1429 if mcpConfig.Type == MCPStdio {
1430 if mcpConfig.Command == "" {
1431 errors.Add(fieldPrefix+".command", "command is required for stdio MCP")
1432 }
1433 } else if mcpConfig.Type == MCPSse {
1434 if mcpConfig.URL == "" {
1435 errors.Add(fieldPrefix+".url", "URL is required for SSE MCP")
1436 }
1437 }
1438 }
1439}
1440
1441// validateLSPs validates LSP configurations
1442func (c *Config) validateLSPs(errors *ValidationErrors) {
1443 if c.LSP == nil {
1444 c.LSP = make(map[string]LSPConfig)
1445 }
1446
1447 for lspName, lspConfig := range c.LSP {
1448 fieldPrefix := fmt.Sprintf("lsp.%s", lspName)
1449
1450 if lspConfig.Command == "" {
1451 errors.Add(fieldPrefix+".command", "command is required for LSP")
1452 }
1453 }
1454}
1455
1456// validateCrossReferences validates cross-references between different config sections
1457func (c *Config) validateCrossReferences(errors *ValidationErrors) {
1458 // Validate that agents can use their assigned model types
1459 for agentID, agent := range c.Agents {
1460 fieldPrefix := fmt.Sprintf("agents.%s", agentID)
1461
1462 var preferredModel PreferredModel
1463 switch agent.Model {
1464 case LargeModel:
1465 preferredModel = c.Models.Large
1466 case SmallModel:
1467 preferredModel = c.Models.Small
1468 }
1469
1470 if preferredModel.Provider != "" {
1471 if providerConfig, exists := c.Providers[preferredModel.Provider]; exists {
1472 if providerConfig.Disabled {
1473 errors.Add(fieldPrefix+".model", fmt.Sprintf("agent cannot use model type '%s' because provider '%s' is disabled", agent.Model, preferredModel.Provider))
1474 }
1475 }
1476 }
1477 }
1478}
1479
1480// validateCompleteness validates that the configuration is complete and usable
1481func (c *Config) validateCompleteness(errors *ValidationErrors) {
1482 // Check for at least one valid, non-disabled provider
1483 hasValidProvider := false
1484 for _, providerConfig := range c.Providers {
1485 if !providerConfig.Disabled {
1486 hasValidProvider = true
1487 break
1488 }
1489 }
1490 if !hasValidProvider {
1491 errors.Add("providers", "at least one non-disabled provider is required")
1492 }
1493
1494 // Check that default agents exist
1495 if _, exists := c.Agents[AgentCoder]; !exists {
1496 errors.Add("agents", "coder agent is required")
1497 }
1498 if _, exists := c.Agents[AgentTask]; !exists {
1499 errors.Add("agents", "task agent is required")
1500 }
1501
1502 // Check that preferred models are set if providers exist
1503 if hasValidProvider {
1504 if c.Models.Large.ModelID == "" || c.Models.Large.Provider == "" {
1505 errors.Add("models.large", "large preferred model must be configured when providers are available")
1506 }
1507 if c.Models.Small.ModelID == "" || c.Models.Small.Provider == "" {
1508 errors.Add("models.small", "small preferred model must be configured when providers are available")
1509 }
1510 }
1511}
1512
1513// JSONSchemaExtend adds custom schema properties for AgentID
1514func (AgentID) JSONSchemaExtend(schema *jsonschema.Schema) {
1515 schema.Enum = []any{
1516 string(AgentCoder),
1517 string(AgentTask),
1518 }
1519}
1520
1521// JSONSchemaExtend adds custom schema properties for ModelType
1522func (ModelType) JSONSchemaExtend(schema *jsonschema.Schema) {
1523 schema.Enum = []any{
1524 string(LargeModel),
1525 string(SmallModel),
1526 }
1527}
1528
1529// JSONSchemaExtend adds custom schema properties for MCPType
1530func (MCPType) JSONSchemaExtend(schema *jsonschema.Schema) {
1531 schema.Enum = []any{
1532 string(MCPStdio),
1533 string(MCPSse),
1534 }
1535}