1package config
2
3import (
4 "encoding/json"
5 "errors"
6 "fmt"
7 "log/slog"
8 "maps"
9 "os"
10 "path/filepath"
11 "slices"
12 "strings"
13 "sync"
14
15 "github.com/charmbracelet/crush/internal/fur/provider"
16 "github.com/charmbracelet/crush/internal/logging"
17 "github.com/invopop/jsonschema"
18)
19
const (
	// defaultDataDirectory is the Options.DataDirectory used when no config
	// file sets "data_directory".
	defaultDataDirectory = ".crush"
	// defaultLogLevel is a textual default log level; it is not referenced in
	// this part of the file.
	defaultLogLevel = "info"
	// appName is the application name; not referenced in this part of the
	// file (presumably used for config paths elsewhere — TODO confirm).
	appName = "crush"

	// MaxTokensFallbackDefault is a fallback max-token value. NOTE(review): not
	// referenced in this part of the file; presumably used when a model does not
	// provide its own limit — confirm against callers.
	MaxTokensFallbackDefault = 4096
)
27
// defaultContextPaths is the initial Options.ContextPaths installed by
// defaultConfigBasedOnEnv: well-known instruction/rules files from other AI
// coding tools plus crush's own CRUSH.md variants. Config files append to this
// list rather than replacing it. Entries such as ".cursor/rules/" appear to be
// directories — how they are expanded is decided by the consumer (TODO confirm).
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
43
// AgentID identifies an agent; it is the key of Config.Agents.
type AgentID string

// Built-in agents installed by loadConfig.
const (
	// AgentCoder executes coding tasks with all tools available.
	AgentCoder AgentID = "coder"
	// AgentTask searches for context using a restricted, read-only tool set.
	AgentTask AgentID = "task"
)
50
// ModelType selects which preferred model (PreferredModels.Large or
// PreferredModels.Small) an agent uses.
type ModelType string

const (
	// LargeModel selects PreferredModels.Large; it is also the fallback for
	// unknown model types.
	LargeModel ModelType = "large"
	// SmallModel selects PreferredModels.Small.
	SmallModel ModelType = "small"
)
57
// Model describes one model offered by a provider: its API identifier, display
// name, pricing (per one million tokens), token limits and capability flags.
//
// ReasoningEffort and HasReasoningEffort are filled from a provider definition
// only when the definition supports reasoning effort AND names a default level
// (see getDefaultProviderConfig).
type Model struct {
	ID                 string  `json:"id" jsonschema:"title=Model ID,description=Unique identifier for the model, the API model"`
	Name               string  `json:"name" jsonschema:"title=Model Name,description=Display name of the model"`
	CostPer1MIn        float64 `json:"cost_per_1m_in,omitempty" jsonschema:"title=Input Cost,description=Cost per 1 million input tokens,minimum=0"`
	CostPer1MOut       float64 `json:"cost_per_1m_out,omitempty" jsonschema:"title=Output Cost,description=Cost per 1 million output tokens,minimum=0"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached,omitempty" jsonschema:"title=Cached Input Cost,description=Cost per 1 million cached input tokens,minimum=0"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached,omitempty" jsonschema:"title=Cached Output Cost,description=Cost per 1 million cached output tokens,minimum=0"`
	ContextWindow      int64   `json:"context_window" jsonschema:"title=Context Window,description=Maximum context window size in tokens,minimum=1"`
	DefaultMaxTokens   int64   `json:"default_max_tokens" jsonschema:"title=Default Max Tokens,description=Default maximum tokens for responses,minimum=1"`
	CanReason          bool    `json:"can_reason,omitempty" jsonschema:"title=Can Reason,description=Whether the model supports reasoning capabilities"`
	ReasoningEffort    string  `json:"reasoning_effort,omitempty" jsonschema:"title=Reasoning Effort,description=Default reasoning effort level for reasoning models"`
	HasReasoningEffort bool    `json:"has_reasoning_effort,omitempty" jsonschema:"title=Has Reasoning Effort,description=Whether the model supports reasoning effort configuration"`
	// SupportsImages is serialized under the "supports_attachments" key.
	SupportsImages bool `json:"supports_attachments,omitempty" jsonschema:"title=Supports Images,description=Whether the model supports image attachments"`
}
72
// VertexAIOptions groups Google Vertex AI connection settings.
//
// NOTE(review): not referenced in this part of the file — defaultConfigBasedOnEnv
// stores the Vertex project/location in ProviderConfig.ExtraParams instead.
// Possibly unused; confirm before relying on it.
type VertexAIOptions struct {
	APIKey   string `json:"api_key,omitempty"`
	Project  string `json:"project,omitempty"`
	Location string `json:"location,omitempty"`
}
78
// ProviderConfig holds the settings for one LLM provider, keyed in
// Config.Providers by its ID.
//
// For known providers, BaseURL, ProviderType, ExtraHeaders and ExtraParams
// coming from config files are ignored (see mergeProviderConfig). Custom
// providers must be OpenAI-compatible and set both BaseURL and APIKey, or they
// are dropped (see validateProvider).
type ProviderConfig struct {
	ID           provider.InferenceProvider `json:"id,omitempty" jsonschema:"title=Provider ID,description=Unique identifier for the provider"`
	BaseURL      string                     `json:"base_url,omitempty" jsonschema:"title=Base URL,description=Base URL for the provider API (required for custom providers)"`
	ProviderType provider.Type              `json:"provider_type" jsonschema:"title=Provider Type,description=Type of the provider (openai, anthropic, etc.)"`
	APIKey       string                     `json:"api_key,omitempty" jsonschema:"title=API Key,description=API key for authenticating with the provider"`
	Disabled     bool                       `json:"disabled,omitempty" jsonschema:"title=Disabled,description=Whether this provider is disabled,default=false"`
	ExtraHeaders map[string]string          `json:"extra_headers,omitempty" jsonschema:"title=Extra Headers,description=Additional HTTP headers to send with requests"`
	// ExtraParams carries provider-specific settings, e.g. defaultConfigBasedOnEnv
	// stores "project"/"location" for Vertex AI and "region" for Bedrock here.
	ExtraParams map[string]string `json:"extra_params,omitempty" jsonschema:"title=Extra Parameters,description=Additional provider-specific parameters"`

	// Model IDs used to seed the preferred large/small models when this
	// provider is chosen by getPreferredProvider.
	DefaultLargeModel string `json:"default_large_model,omitempty" jsonschema:"title=Default Large Model,description=Default model ID for large model type"`
	DefaultSmallModel string `json:"default_small_model,omitempty" jsonschema:"title=Default Small Model,description=Default model ID for small model type"`

	// Models is the provider's model catalog; config files can add models
	// whose IDs are not present yet.
	Models []Model `json:"models,omitempty" jsonschema:"title=Models,description=List of available models for this provider"`
}
94
// Agent configures one agent: the model type it uses and the tools, MCP
// servers, LSP servers and context paths available to it. The built-in agents
// ("coder", "task") are installed by loadConfig; config files may adjust them
// or define custom agents (see mergeAgents).
type Agent struct {
	ID          AgentID `json:"id,omitempty" jsonschema:"title=Agent ID,description=Unique identifier for the agent,enum=coder,enum=task"`
	Name        string  `json:"name,omitempty" jsonschema:"title=Name,description=Display name of the agent"`
	Description string  `json:"description,omitempty" jsonschema:"title=Description,description=Description of what the agent does"`
	// Disabled marks the agent as disabled. mergeAgents does not copy this
	// field from config files for the built-in agents.
	Disabled bool `json:"disabled,omitempty" jsonschema:"title=Disabled,description=Whether this agent is disabled,default=false"`

	// Model selects the preferred large or small model; mergeAgents defaults
	// custom agents to LargeModel when it is unset.
	Model ModelType `json:"model" jsonschema:"title=Model Type,description=Type of model to use (large or small),enum=large,enum=small"`

	// The available tools for the agent
	// if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools,omitempty" jsonschema:"title=Allowed Tools,description=List of tools this agent is allowed to use (if nil all tools are allowed)"`

	// AllowedMCP maps MCP server names to the tools of that server the agent
	// may use; a nil tool list presumably allows every tool of that server.
	// NOTE(review): the built-in task agent sets an empty, non-nil map to mean
	// "no MCP servers", so presumably nil means "all servers" and empty means
	// "none" — confirm against the consumer.
	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty" jsonschema:"title=Allowed MCP,description=Map of MCP servers this agent can use and their allowed tools"`

	// The list of LSPs that this agent can use
	// if this is nil, all LSPs are available (the built-in task agent uses an
	// empty, non-nil slice to allow none)
	AllowedLSP []string `json:"allowed_lsp,omitempty" jsonschema:"title=Allowed LSP,description=List of LSP servers this agent can use (if nil all LSPs are allowed)"`

	// ContextPaths lists extra context paths for this agent; mergeAgents
	// appends them to the global Options.ContextPaths instead of replacing them.
	ContextPaths []string `json:"context_paths,omitempty" jsonschema:"title=Context Paths,description=Custom context paths for this agent (additive to global context paths)"`
}
121
// MCPType is the transport used to talk to an MCP server.
type MCPType string

const (
	// MCPStdio launches the server as a subprocess (MCP.Command/Args/Env).
	MCPStdio MCPType = "stdio"
	// MCPSse connects to a server over SSE (MCP.URL/Headers).
	MCPSse MCPType = "sse"
)
128
// MCP configures one MCP server, keyed by name in Config.MCP. Command, Args and
// Env apply to stdio servers; URL and Headers apply to SSE servers.
type MCP struct {
	Command string `json:"command" jsonschema:"title=Command,description=Command to execute for stdio MCP servers"`
	// Env is passed to the server process; presumably "KEY=VALUE" entries — TODO confirm.
	Env  []string `json:"env,omitempty" jsonschema:"title=Environment,description=Environment variables for the MCP server"`
	Args []string `json:"args,omitempty" jsonschema:"title=Arguments,description=Command line arguments for the MCP server"`
	Type MCPType  `json:"type" jsonschema:"title=Type,description=Type of MCP connection,enum=stdio,enum=sse,default=stdio"`
	URL  string   `json:"url,omitempty" jsonschema:"title=URL,description=URL for SSE MCP servers"`
	// TODO: maybe make it possible to get the value from the env
	Headers map[string]string `json:"headers,omitempty" jsonschema:"title=Headers,description=HTTP headers for SSE MCP servers"`
}
138
// LSPConfig describes how to launch one Language Server Protocol server; it is
// keyed by name in Config.LSP.
type LSPConfig struct {
	// Disabled turns the server off. It is read from the "disabled" key, like
	// the Disabled fields of providers and agents. (Mapping it to an "enabled"
	// key inverted its meaning: "enabled": true disabled the server.)
	Disabled bool     `json:"disabled,omitempty" jsonschema:"title=Disabled,description=Whether this LSP server is disabled,default=false"`
	Command  string   `json:"command" jsonschema:"title=Command,description=Command to execute for the LSP server"`
	Args     []string `json:"args,omitempty" jsonschema:"title=Arguments,description=Command line arguments for the LSP server"`
	// Options is passed through untouched; its shape is server specific.
	Options any `json:"options,omitempty" jsonschema:"title=Options,description=LSP server specific options"`
}
145
// TUIOptions holds terminal-UI settings. Config files can only turn
// CompactMode on, not off (see mergeOptions).
type TUIOptions struct {
	CompactMode bool `json:"compact_mode" jsonschema:"title=Compact Mode,description=Enable compact mode for the TUI,default=false"`
	// Here we can add themes later or any TUI related options
}
150
// Options holds general application settings. Across config files
// (see mergeOptions), ContextPaths are appended, boolean flags can only be
// switched on, and a non-empty DataDirectory replaces the previous value.
type Options struct {
	ContextPaths         []string   `json:"context_paths,omitempty" jsonschema:"title=Context Paths,description=List of paths to search for context files"`
	TUI                  TUIOptions `json:"tui,omitempty" jsonschema:"title=TUI Options,description=Terminal UI configuration options"`
	Debug                bool       `json:"debug,omitempty" jsonschema:"title=Debug,description=Enable debug logging,default=false"`
	DebugLSP             bool       `json:"debug_lsp,omitempty" jsonschema:"title=Debug LSP,description=Enable LSP debug logging,default=false"`
	DisableAutoSummarize bool       `json:"disable_auto_summarize,omitempty" jsonschema:"title=Disable Auto Summarize,description=Disable automatic conversation summarization,default=false"`
	// Relative to the cwd
	DataDirectory string `json:"data_directory,omitempty" jsonschema:"title=Data Directory,description=Directory for storing application data,default=.crush"`
}
160
// PreferredModel selects a concrete model, by ID and provider, for one
// ModelType, with optional per-selection overrides. mergeModels only applies a
// PreferredModel from a config file when both ModelID and Provider are set.
type PreferredModel struct {
	ModelID  string                     `json:"model_id" jsonschema:"title=Model ID,description=ID of the preferred model"`
	Provider provider.InferenceProvider `json:"provider" jsonschema:"title=Provider,description=Provider for the preferred model"`
	// ReasoningEffort overrides the default reasoning effort for this model
	ReasoningEffort string `json:"reasoning_effort,omitempty" jsonschema:"title=Reasoning Effort,description=Override reasoning effort for this model"`
	// MaxTokens overrides the default max tokens for this model; values <= 0
	// keep the model's DefaultMaxTokens (see GetAgentEffectiveMaxTokens).
	MaxTokens int64 `json:"max_tokens,omitempty" jsonschema:"title=Max Tokens,description=Override max tokens for this model,minimum=1"`

	// Think indicates if the model should think, only applicable for anthropic reasoning models
	Think bool `json:"think,omitempty" jsonschema:"title=Think,description=Enable thinking for reasoning models,default=false"`
}
172
// PreferredModels holds the preferred model for each ModelType; an agent picks
// Large or Small through its Agent.Model field.
type PreferredModels struct {
	Large PreferredModel `json:"large,omitempty" jsonschema:"title=Large Model,description=Preferred model configuration for large model type"`
	Small PreferredModel `json:"small,omitempty" jsonschema:"title=Small Model,description=Preferred model configuration for small model type"`
}
177
178type Config struct {
179 Models PreferredModels `json:"models,omitempty" jsonschema:"title=Models,description=Preferred model configurations for large and small model types"`
180 // List of configured providers
181 Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty" jsonschema:"title=Providers,description=LLM provider configurations"`
182
183 // List of configured agents
184 Agents map[AgentID]Agent `json:"agents,omitempty" jsonschema:"title=Agents,description=Agent configurations for different tasks"`
185
186 // List of configured MCPs
187 MCP map[string]MCP `json:"mcp,omitempty" jsonschema:"title=MCP,description=Model Control Protocol server configurations"`
188
189 // List of configured LSPs
190 LSP map[string]LSPConfig `json:"lsp,omitempty" jsonschema:"title=LSP,description=Language Server Protocol configurations"`
191
192 // Miscellaneous options
193 Options Options `json:"options,omitempty" jsonschema:"title=Options,description=General application options and settings"`
194}
195
// Process-wide configuration singleton, populated by Init and read by Get and
// WorkingDirectory.
var (
	instance *Config   // The single instance of the Singleton; nil until Init loads it
	cwd      string    // Working directory passed to the first Init call
	once     sync.Once // Ensures the initialization happens only once
)
202
203func readConfigFile(path string) (*Config, error) {
204 var cfg *Config
205 if _, err := os.Stat(path); err != nil && !os.IsNotExist(err) {
206 // some other error occurred while checking the file
207 return nil, err
208 } else if err == nil {
209 // config file exists, read it
210 file, err := os.ReadFile(path)
211 if err != nil {
212 return nil, err
213 }
214 cfg = &Config{}
215 if err := json.Unmarshal(file, cfg); err != nil {
216 return nil, err
217 }
218 } else {
219 // config file does not exist, create a new one
220 cfg = &Config{}
221 }
222 return cfg, nil
223}
224
225func loadConfig(cwd string, debug bool) (*Config, error) {
226 // First read the global config file
227 cfgPath := ConfigPath()
228
229 cfg := defaultConfigBasedOnEnv()
230 cfg.Options.Debug = debug
231 defaultLevel := slog.LevelInfo
232 if cfg.Options.Debug {
233 defaultLevel = slog.LevelDebug
234 }
235 if os.Getenv("CRUSH_DEV_DEBUG") == "true" {
236 loggingFile := fmt.Sprintf("%s/%s", cfg.Options.DataDirectory, "debug.log")
237
238 // if file does not exist create it
239 if _, err := os.Stat(loggingFile); os.IsNotExist(err) {
240 if err := os.MkdirAll(cfg.Options.DataDirectory, 0o755); err != nil {
241 return cfg, fmt.Errorf("failed to create directory: %w", err)
242 }
243 if _, err := os.Create(loggingFile); err != nil {
244 return cfg, fmt.Errorf("failed to create log file: %w", err)
245 }
246 }
247
248 messagesPath := fmt.Sprintf("%s/%s", cfg.Options.DataDirectory, "messages")
249
250 if _, err := os.Stat(messagesPath); os.IsNotExist(err) {
251 if err := os.MkdirAll(messagesPath, 0o756); err != nil {
252 return cfg, fmt.Errorf("failed to create directory: %w", err)
253 }
254 }
255 logging.MessageDir = messagesPath
256
257 sloggingFileWriter, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666)
258 if err != nil {
259 return cfg, fmt.Errorf("failed to open log file: %w", err)
260 }
261 // Configure logger
262 logger := slog.New(slog.NewTextHandler(sloggingFileWriter, &slog.HandlerOptions{
263 Level: defaultLevel,
264 }))
265 slog.SetDefault(logger)
266 } else {
267 // Configure logger
268 logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{
269 Level: defaultLevel,
270 }))
271 slog.SetDefault(logger)
272 }
273
274 priorityOrderedConfigFiles := []string{
275 cfgPath, // Global config file
276 filepath.Join(cwd, "crush.json"), // Local config file
277 filepath.Join(cwd, ".crush.json"), // Local config file
278 }
279
280 configs := make([]*Config, 0)
281 for _, path := range priorityOrderedConfigFiles {
282 localConfig, err := readConfigFile(path)
283 if err != nil {
284 return nil, fmt.Errorf("failed to read config file %s: %w", path, err)
285 }
286 if localConfig != nil {
287 // If the config file was read successfully, add it to the list
288 configs = append(configs, localConfig)
289 }
290 }
291
292 // merge options
293 mergeOptions(cfg, configs...)
294
295 mergeProviderConfigs(cfg, configs...)
296 // no providers found the app is not initialized yet
297 if len(cfg.Providers) == 0 {
298 return cfg, nil
299 }
300 preferredProvider := getPreferredProvider(cfg.Providers)
301 if preferredProvider != nil {
302 cfg.Models = PreferredModels{
303 Large: PreferredModel{
304 ModelID: preferredProvider.DefaultLargeModel,
305 Provider: preferredProvider.ID,
306 },
307 Small: PreferredModel{
308 ModelID: preferredProvider.DefaultSmallModel,
309 Provider: preferredProvider.ID,
310 },
311 }
312 } else {
313 // No valid providers found, set empty models
314 cfg.Models = PreferredModels{}
315 }
316
317 mergeModels(cfg, configs...)
318
319 agents := map[AgentID]Agent{
320 AgentCoder: {
321 ID: AgentCoder,
322 Name: "Coder",
323 Description: "An agent that helps with executing coding tasks.",
324 Model: LargeModel,
325 ContextPaths: cfg.Options.ContextPaths,
326 // All tools allowed
327 },
328 AgentTask: {
329 ID: AgentTask,
330 Name: "Task",
331 Description: "An agent that helps with searching for context and finding implementation details.",
332 Model: LargeModel,
333 ContextPaths: cfg.Options.ContextPaths,
334 AllowedTools: []string{
335 "glob",
336 "grep",
337 "ls",
338 "sourcegraph",
339 "view",
340 },
341 // NO MCPs or LSPs by default
342 AllowedMCP: map[string][]string{},
343 AllowedLSP: []string{},
344 },
345 }
346 cfg.Agents = agents
347 mergeAgents(cfg, configs...)
348 mergeMCPs(cfg, configs...)
349 mergeLSPs(cfg, configs...)
350
351 // Validate the final configuration
352 if err := cfg.Validate(); err != nil {
353 return cfg, fmt.Errorf("configuration validation failed: %w", err)
354 }
355
356 return cfg, nil
357}
358
359func Init(workingDir string, debug bool) (*Config, error) {
360 var err error
361 once.Do(func() {
362 cwd = workingDir
363 instance, err = loadConfig(cwd, debug)
364 if err != nil {
365 logging.Error("Failed to load config", "error", err)
366 }
367 })
368
369 return instance, err
370}
371
372func Get() *Config {
373 if instance == nil {
374 // TODO: Handle this better
375 panic("Config not initialized. Call InitConfig first.")
376 }
377 return instance
378}
379
380func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
381 providers := Providers()
382 for _, p := range providers {
383 if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
384 return &providerConfig
385 }
386 }
387 // if none found return the first configured provider
388 for _, providerConfig := range configuredProviders {
389 if !providerConfig.Disabled {
390 return &providerConfig
391 }
392 }
393 return nil
394}
395
396func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
397 if other.APIKey != "" {
398 base.APIKey = other.APIKey
399 }
400 // Only change these options if the provider is not a known provider
401 if !slices.Contains(provider.KnownProviders(), p) {
402 if other.BaseURL != "" {
403 base.BaseURL = other.BaseURL
404 }
405 if other.ProviderType != "" {
406 base.ProviderType = other.ProviderType
407 }
408 if len(other.ExtraHeaders) > 0 {
409 if base.ExtraHeaders == nil {
410 base.ExtraHeaders = make(map[string]string)
411 }
412 maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
413 }
414 if len(other.ExtraParams) > 0 {
415 if base.ExtraParams == nil {
416 base.ExtraParams = make(map[string]string)
417 }
418 maps.Copy(base.ExtraParams, other.ExtraParams)
419 }
420 }
421
422 if other.Disabled {
423 base.Disabled = other.Disabled
424 }
425
426 if other.DefaultLargeModel != "" {
427 base.DefaultLargeModel = other.DefaultLargeModel
428 }
429 // Add new models if they don't exist
430 if other.Models != nil {
431 for _, model := range other.Models {
432 // check if the model already exists
433 exists := false
434 for _, existingModel := range base.Models {
435 if existingModel.ID == model.ID {
436 exists = true
437 break
438 }
439 }
440 if !exists {
441 base.Models = append(base.Models, model)
442 }
443 }
444 }
445
446 return base
447}
448
449func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
450 if !slices.Contains(provider.KnownProviders(), p) {
451 if providerConfig.ProviderType != provider.TypeOpenAI {
452 return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
453 }
454 if providerConfig.BaseURL == "" {
455 return errors.New("base URL must be set for custom providers")
456 }
457 if providerConfig.APIKey == "" {
458 return errors.New("API key must be set for custom providers")
459 }
460 }
461 return nil
462}
463
464func mergeModels(base *Config, others ...*Config) {
465 for _, cfg := range others {
466 if cfg == nil {
467 continue
468 }
469 if cfg.Models.Large.ModelID != "" && cfg.Models.Large.Provider != "" {
470 base.Models.Large = cfg.Models.Large
471 }
472
473 if cfg.Models.Small.ModelID != "" && cfg.Models.Small.Provider != "" {
474 base.Models.Small = cfg.Models.Small
475 }
476 }
477}
478
479func mergeOptions(base *Config, others ...*Config) {
480 for _, cfg := range others {
481 if cfg == nil {
482 continue
483 }
484 baseOptions := base.Options
485 other := cfg.Options
486 if len(other.ContextPaths) > 0 {
487 baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
488 }
489
490 if other.TUI.CompactMode {
491 baseOptions.TUI.CompactMode = other.TUI.CompactMode
492 }
493
494 if other.Debug {
495 baseOptions.Debug = other.Debug
496 }
497
498 if other.DebugLSP {
499 baseOptions.DebugLSP = other.DebugLSP
500 }
501
502 if other.DisableAutoSummarize {
503 baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
504 }
505
506 if other.DataDirectory != "" {
507 baseOptions.DataDirectory = other.DataDirectory
508 }
509 base.Options = baseOptions
510 }
511}
512
513func mergeAgents(base *Config, others ...*Config) {
514 for _, cfg := range others {
515 if cfg == nil {
516 continue
517 }
518 for agentID, newAgent := range cfg.Agents {
519 if _, ok := base.Agents[agentID]; !ok {
520 newAgent.ID = agentID
521 if newAgent.Model == "" {
522 newAgent.Model = LargeModel
523 }
524 if len(newAgent.ContextPaths) > 0 {
525 newAgent.ContextPaths = append(base.Options.ContextPaths, newAgent.ContextPaths...)
526 } else {
527 newAgent.ContextPaths = base.Options.ContextPaths
528 }
529 base.Agents[agentID] = newAgent
530 } else {
531 baseAgent := base.Agents[agentID]
532
533 if agentID == AgentCoder || agentID == AgentTask {
534 if newAgent.Model != "" {
535 baseAgent.Model = newAgent.Model
536 }
537 if newAgent.AllowedMCP != nil {
538 baseAgent.AllowedMCP = newAgent.AllowedMCP
539 }
540 if newAgent.AllowedLSP != nil {
541 baseAgent.AllowedLSP = newAgent.AllowedLSP
542 }
543 // Context paths are additive for known agents too
544 if len(newAgent.ContextPaths) > 0 {
545 baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...)
546 }
547 } else {
548 if newAgent.Name != "" {
549 baseAgent.Name = newAgent.Name
550 }
551 if newAgent.Description != "" {
552 baseAgent.Description = newAgent.Description
553 }
554 if newAgent.Model != "" {
555 baseAgent.Model = newAgent.Model
556 } else if baseAgent.Model == "" {
557 baseAgent.Model = LargeModel
558 }
559
560 baseAgent.Disabled = newAgent.Disabled
561
562 if newAgent.AllowedTools != nil {
563 baseAgent.AllowedTools = newAgent.AllowedTools
564 }
565 if newAgent.AllowedMCP != nil {
566 baseAgent.AllowedMCP = newAgent.AllowedMCP
567 }
568 if newAgent.AllowedLSP != nil {
569 baseAgent.AllowedLSP = newAgent.AllowedLSP
570 }
571 if len(newAgent.ContextPaths) > 0 {
572 baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...)
573 }
574 }
575
576 base.Agents[agentID] = baseAgent
577 }
578 }
579 }
580}
581
582func mergeMCPs(base *Config, others ...*Config) {
583 for _, cfg := range others {
584 if cfg == nil {
585 continue
586 }
587 maps.Copy(base.MCP, cfg.MCP)
588 }
589}
590
591func mergeLSPs(base *Config, others ...*Config) {
592 for _, cfg := range others {
593 if cfg == nil {
594 continue
595 }
596 maps.Copy(base.LSP, cfg.LSP)
597 }
598}
599
600func mergeProviderConfigs(base *Config, others ...*Config) {
601 for _, cfg := range others {
602 if cfg == nil {
603 continue
604 }
605 for providerName, p := range cfg.Providers {
606 p.ID = providerName
607 if _, ok := base.Providers[providerName]; !ok {
608 if slices.Contains(provider.KnownProviders(), providerName) {
609 providers := Providers()
610 for _, providerDef := range providers {
611 if providerDef.ID == providerName {
612 logging.Info("Using default provider config for", "provider", providerName)
613 baseProvider := getDefaultProviderConfig(providerDef, providerDef.APIKey)
614 base.Providers[providerName] = mergeProviderConfig(providerName, baseProvider, p)
615 break
616 }
617 }
618 } else {
619 base.Providers[providerName] = p
620 }
621 } else {
622 base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], p)
623 }
624 }
625 }
626
627 finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
628 for providerName, providerConfig := range base.Providers {
629 err := validateProvider(providerName, providerConfig)
630 if err != nil {
631 logging.Warn("Skipping provider", "name", providerName, "error", err)
632 continue // Skip invalid providers
633 }
634 finalProviders[providerName] = providerConfig
635 }
636 base.Providers = finalProviders
637}
638
639func providerDefaultConfig(providerID provider.InferenceProvider) ProviderConfig {
640 switch providerID {
641 case provider.InferenceProviderAnthropic:
642 return ProviderConfig{
643 ID: providerID,
644 ProviderType: provider.TypeAnthropic,
645 }
646 case provider.InferenceProviderOpenAI:
647 return ProviderConfig{
648 ID: providerID,
649 ProviderType: provider.TypeOpenAI,
650 }
651 case provider.InferenceProviderGemini:
652 return ProviderConfig{
653 ID: providerID,
654 ProviderType: provider.TypeGemini,
655 }
656 case provider.InferenceProviderBedrock:
657 return ProviderConfig{
658 ID: providerID,
659 ProviderType: provider.TypeBedrock,
660 }
661 case provider.InferenceProviderAzure:
662 return ProviderConfig{
663 ID: providerID,
664 ProviderType: provider.TypeAzure,
665 }
666 case provider.InferenceProviderOpenRouter:
667 return ProviderConfig{
668 ID: providerID,
669 ProviderType: provider.TypeOpenAI,
670 BaseURL: "https://openrouter.ai/api/v1",
671 ExtraHeaders: map[string]string{
672 "HTTP-Referer": "crush.charm.land",
673 "X-Title": "Crush",
674 },
675 }
676 case provider.InferenceProviderXAI:
677 return ProviderConfig{
678 ID: providerID,
679 ProviderType: provider.TypeXAI,
680 BaseURL: "https://api.x.ai/v1",
681 }
682 case provider.InferenceProviderVertexAI:
683 return ProviderConfig{
684 ID: providerID,
685 ProviderType: provider.TypeVertexAI,
686 }
687 default:
688 return ProviderConfig{
689 ID: providerID,
690 ProviderType: provider.TypeOpenAI,
691 }
692 }
693}
694
695func getDefaultProviderConfig(p provider.Provider, apiKey string) ProviderConfig {
696 providerConfig := providerDefaultConfig(p.ID)
697 providerConfig.APIKey = apiKey
698 providerConfig.DefaultLargeModel = p.DefaultLargeModelID
699 providerConfig.DefaultSmallModel = p.DefaultSmallModelID
700 baseURL := p.APIEndpoint
701 if strings.HasPrefix(baseURL, "$") {
702 envVar := strings.TrimPrefix(baseURL, "$")
703 baseURL = os.Getenv(envVar)
704 }
705 providerConfig.BaseURL = baseURL
706 for _, model := range p.Models {
707 configModel := Model{
708 ID: model.ID,
709 Name: model.Name,
710 CostPer1MIn: model.CostPer1MIn,
711 CostPer1MOut: model.CostPer1MOut,
712 CostPer1MInCached: model.CostPer1MInCached,
713 CostPer1MOutCached: model.CostPer1MOutCached,
714 ContextWindow: model.ContextWindow,
715 DefaultMaxTokens: model.DefaultMaxTokens,
716 CanReason: model.CanReason,
717 SupportsImages: model.SupportsImages,
718 }
719 // Set reasoning effort for reasoning models
720 if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
721 configModel.HasReasoningEffort = model.HasReasoningEffort
722 configModel.ReasoningEffort = model.DefaultReasoningEffort
723 }
724 providerConfig.Models = append(providerConfig.Models, configModel)
725 }
726 return providerConfig
727}
728
729func defaultConfigBasedOnEnv() *Config {
730 cfg := &Config{
731 Options: Options{
732 DataDirectory: defaultDataDirectory,
733 ContextPaths: defaultContextPaths,
734 },
735 Providers: make(map[provider.InferenceProvider]ProviderConfig),
736 Agents: make(map[AgentID]Agent),
737 LSP: make(map[string]LSPConfig),
738 MCP: make(map[string]MCP),
739 }
740
741 providers := Providers()
742
743 for _, p := range providers {
744 if strings.HasPrefix(p.APIKey, "$") {
745 envVar := strings.TrimPrefix(p.APIKey, "$")
746 if apiKey := os.Getenv(envVar); apiKey != "" {
747 cfg.Providers[p.ID] = getDefaultProviderConfig(p, apiKey)
748 }
749 }
750 }
751 // TODO: support local models
752
753 if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
754 providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
755 providerConfig.ExtraParams = map[string]string{
756 "project": os.Getenv("GOOGLE_CLOUD_PROJECT"),
757 "location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
758 }
759 // Find the VertexAI provider definition to get default models
760 for _, p := range providers {
761 if p.ID == provider.InferenceProviderVertexAI {
762 providerConfig.DefaultLargeModel = p.DefaultLargeModelID
763 providerConfig.DefaultSmallModel = p.DefaultSmallModelID
764 for _, model := range p.Models {
765 configModel := Model{
766 ID: model.ID,
767 Name: model.Name,
768 CostPer1MIn: model.CostPer1MIn,
769 CostPer1MOut: model.CostPer1MOut,
770 CostPer1MInCached: model.CostPer1MInCached,
771 CostPer1MOutCached: model.CostPer1MOutCached,
772 ContextWindow: model.ContextWindow,
773 DefaultMaxTokens: model.DefaultMaxTokens,
774 CanReason: model.CanReason,
775 SupportsImages: model.SupportsImages,
776 }
777 // Set reasoning effort for reasoning models
778 if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
779 configModel.HasReasoningEffort = model.HasReasoningEffort
780 configModel.ReasoningEffort = model.DefaultReasoningEffort
781 }
782 providerConfig.Models = append(providerConfig.Models, configModel)
783 }
784 break
785 }
786 }
787 cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
788 }
789
790 if hasAWSCredentials() {
791 providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
792 providerConfig.ExtraParams = map[string]string{
793 "region": os.Getenv("AWS_DEFAULT_REGION"),
794 }
795 if providerConfig.ExtraParams["region"] == "" {
796 providerConfig.ExtraParams["region"] = os.Getenv("AWS_REGION")
797 }
798 // Find the Bedrock provider definition to get default models
799 for _, p := range providers {
800 if p.ID == provider.InferenceProviderBedrock {
801 providerConfig.DefaultLargeModel = p.DefaultLargeModelID
802 providerConfig.DefaultSmallModel = p.DefaultSmallModelID
803 for _, model := range p.Models {
804 configModel := Model{
805 ID: model.ID,
806 Name: model.Name,
807 CostPer1MIn: model.CostPer1MIn,
808 CostPer1MOut: model.CostPer1MOut,
809 CostPer1MInCached: model.CostPer1MInCached,
810 CostPer1MOutCached: model.CostPer1MOutCached,
811 ContextWindow: model.ContextWindow,
812 DefaultMaxTokens: model.DefaultMaxTokens,
813 CanReason: model.CanReason,
814 SupportsImages: model.SupportsImages,
815 }
816 // Set reasoning effort for reasoning models
817 if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
818 configModel.HasReasoningEffort = model.HasReasoningEffort
819 configModel.ReasoningEffort = model.DefaultReasoningEffort
820 }
821 providerConfig.Models = append(providerConfig.Models, configModel)
822 }
823 break
824 }
825 }
826 cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
827 }
828 return cfg
829}
830
// hasAWSCredentials is a heuristic for whether AWS (and thus Bedrock) looks
// configured in the environment. It reports true when any of these is set:
//   - a static key pair: AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY;
//   - a profile: AWS_PROFILE or AWS_DEFAULT_PROFILE;
//   - a region: AWS_REGION or AWS_DEFAULT_REGION (a region alone counts, even
//     though it is not a credential);
//   - container credentials: AWS_CONTAINER_CREDENTIALS_RELATIVE_URI or
//     AWS_CONTAINER_CREDENTIALS_FULL_URI.
//
// An empty value counts as unset.
func hasAWSCredentials() bool {
	isSet := func(name string) bool { return os.Getenv(name) != "" }

	staticKeys := isSet("AWS_ACCESS_KEY_ID") && isSet("AWS_SECRET_ACCESS_KEY")
	profile := isSet("AWS_PROFILE") || isSet("AWS_DEFAULT_PROFILE")
	region := isSet("AWS_REGION") || isSet("AWS_DEFAULT_REGION")
	container := isSet("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") ||
		isSet("AWS_CONTAINER_CREDENTIALS_FULL_URI")

	return staticKeys || profile || region || container
}
851
852func WorkingDirectory() string {
853 return cwd
854}
855
// TODO: Handle error state: the agent lookup helpers below log failures and return zero values instead of returning errors.
857
858func GetAgentModel(agentID AgentID) Model {
859 cfg := Get()
860 agent, ok := cfg.Agents[agentID]
861 if !ok {
862 logging.Error("Agent not found", "agent_id", agentID)
863 return Model{}
864 }
865
866 var model PreferredModel
867 switch agent.Model {
868 case LargeModel:
869 model = cfg.Models.Large
870 case SmallModel:
871 model = cfg.Models.Small
872 default:
873 logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
874 model = cfg.Models.Large // Fallback to large model
875 }
876 providerConfig, ok := cfg.Providers[model.Provider]
877 if !ok {
878 logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
879 return Model{}
880 }
881
882 for _, m := range providerConfig.Models {
883 if m.ID == model.ModelID {
884 return m
885 }
886 }
887
888 logging.Error("Model not found for agent", "agent_id", agentID, "model", agent.Model)
889 return Model{}
890}
891
892// GetAgentEffectiveMaxTokens returns the effective max tokens for an agent,
893// considering any overrides from the preferred model configuration
894func GetAgentEffectiveMaxTokens(agentID AgentID) int64 {
895 cfg := Get()
896 agent, ok := cfg.Agents[agentID]
897 if !ok {
898 logging.Error("Agent not found", "agent_id", agentID)
899 return 0
900 }
901
902 var preferredModel PreferredModel
903 switch agent.Model {
904 case LargeModel:
905 preferredModel = cfg.Models.Large
906 case SmallModel:
907 preferredModel = cfg.Models.Small
908 default:
909 logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
910 preferredModel = cfg.Models.Large // Fallback to large model
911 }
912
913 // Get the base model configuration
914 baseModel := GetAgentModel(agentID)
915 if baseModel.ID == "" {
916 return 0
917 }
918
919 // Start with the default max tokens from the base model
920 maxTokens := baseModel.DefaultMaxTokens
921
922 // Override with preferred model max tokens if set
923 if preferredModel.MaxTokens > 0 {
924 maxTokens = preferredModel.MaxTokens
925 }
926
927 return maxTokens
928}
929
930func GetAgentProvider(agentID AgentID) ProviderConfig {
931 cfg := Get()
932 agent, ok := cfg.Agents[agentID]
933 if !ok {
934 logging.Error("Agent not found", "agent_id", agentID)
935 return ProviderConfig{}
936 }
937
938 var model PreferredModel
939 switch agent.Model {
940 case LargeModel:
941 model = cfg.Models.Large
942 case SmallModel:
943 model = cfg.Models.Small
944 default:
945 logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
946 model = cfg.Models.Large // Fallback to large model
947 }
948
949 providerConfig, ok := cfg.Providers[model.Provider]
950 if !ok {
951 logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
952 return ProviderConfig{}
953 }
954
955 return providerConfig
956}
957
958func GetProviderModel(provider provider.InferenceProvider, modelID string) Model {
959 cfg := Get()
960 providerConfig, ok := cfg.Providers[provider]
961 if !ok {
962 logging.Error("Provider not found", "provider", provider)
963 return Model{}
964 }
965
966 for _, model := range providerConfig.Models {
967 if model.ID == modelID {
968 return model
969 }
970 }
971
972 logging.Error("Model not found for provider", "provider", provider, "model_id", modelID)
973 return Model{}
974}
975
976func GetModel(modelType ModelType) Model {
977 cfg := Get()
978 var model PreferredModel
979 switch modelType {
980 case LargeModel:
981 model = cfg.Models.Large
982 case SmallModel:
983 model = cfg.Models.Small
984 default:
985 model = cfg.Models.Large // Fallback to large model
986 }
987 providerConfig, ok := cfg.Providers[model.Provider]
988 if !ok {
989 return Model{}
990 }
991
992 for _, m := range providerConfig.Models {
993 if m.ID == model.ModelID {
994 return m
995 }
996 }
997 return Model{}
998}
999
1000func UpdatePreferredModel(modelType ModelType, model PreferredModel) error {
1001 cfg := Get()
1002 switch modelType {
1003 case LargeModel:
1004 cfg.Models.Large = model
1005 case SmallModel:
1006 cfg.Models.Small = model
1007 default:
1008 return fmt.Errorf("unknown model type: %s", modelType)
1009 }
1010 return nil
1011}
1012
// ValidationError describes a single configuration problem. Field is the
// dotted path of the offending setting (for example "providers.x.api_key")
// and Message explains what is wrong with it.
type ValidationError struct {
	Field   string
	Message string
}

// Error implements the error interface.
func (e ValidationError) Error() string {
	return fmt.Sprintf("validation error in %s: %s", e.Field, e.Message)
}

// ValidationErrors accumulates every problem found during validation.
type ValidationErrors []ValidationError

// Error summarizes all recorded problems.
//
// An empty list reports "no validation errors" and a single entry reports
// that entry's message. Otherwise all messages are joined with "; ".
func (e ValidationErrors) Error() string {
	switch len(e) {
	case 0:
		return "no validation errors"
	case 1:
		return e[0].Error()
	}

	messages := make([]string, 0, len(e))
	for _, err := range e {
		messages = append(messages, err.Error())
	}
	return fmt.Sprintf("multiple validation errors: %s", strings.Join(messages, "; "))
}

// Unwrap exposes the individual ValidationError values. With this, errors.Is
// and errors.As (Go 1.20+) can inspect them; for example, errors.As(err, &ve)
// with a ValidationError target retrieves the first recorded problem.
func (e ValidationErrors) Unwrap() []error {
	errs := make([]error, len(e))
	for i, ve := range e {
		errs[i] = ve
	}
	return errs
}

// HasErrors reports whether at least one problem has been recorded.
func (e ValidationErrors) HasErrors() bool {
	return len(e) > 0
}

// Add records a problem for field with the given message.
func (e *ValidationErrors) Add(field, message string) {
	*e = append(*e, ValidationError{Field: field, Message: message})
}
1050
1051// Validate performs comprehensive validation of the configuration
1052func (c *Config) Validate() error {
1053 var errors ValidationErrors
1054
1055 // Validate providers
1056 c.validateProviders(&errors)
1057
1058 // Validate models
1059 c.validateModels(&errors)
1060
1061 // Validate agents
1062 c.validateAgents(&errors)
1063
1064 // Validate options
1065 c.validateOptions(&errors)
1066
1067 // Validate MCP configurations
1068 c.validateMCPs(&errors)
1069
1070 // Validate LSP configurations
1071 c.validateLSPs(&errors)
1072
1073 // Validate cross-references
1074 c.validateCrossReferences(&errors)
1075
1076 // Validate completeness
1077 c.validateCompleteness(&errors)
1078
1079 if errors.HasErrors() {
1080 return errors
1081 }
1082
1083 return nil
1084}
1085
1086// validateProviders validates all provider configurations
1087func (c *Config) validateProviders(errors *ValidationErrors) {
1088 if c.Providers == nil {
1089 c.Providers = make(map[provider.InferenceProvider]ProviderConfig)
1090 }
1091
1092 knownProviders := provider.KnownProviders()
1093 validTypes := []provider.Type{
1094 provider.TypeOpenAI,
1095 provider.TypeAnthropic,
1096 provider.TypeGemini,
1097 provider.TypeAzure,
1098 provider.TypeBedrock,
1099 provider.TypeVertexAI,
1100 provider.TypeXAI,
1101 }
1102
1103 for providerID, providerConfig := range c.Providers {
1104 fieldPrefix := fmt.Sprintf("providers.%s", providerID)
1105
1106 // Validate API key for non-disabled providers
1107 if !providerConfig.Disabled && providerConfig.APIKey == "" {
1108 // Special case for AWS Bedrock and VertexAI which may use other auth methods
1109 if providerID != provider.InferenceProviderBedrock && providerID != provider.InferenceProviderVertexAI {
1110 errors.Add(fieldPrefix+".api_key", "API key is required for non-disabled providers")
1111 }
1112 }
1113
1114 // Validate provider type
1115 validType := slices.Contains(validTypes, providerConfig.ProviderType)
1116 if !validType {
1117 errors.Add(fieldPrefix+".provider_type", fmt.Sprintf("invalid provider type: %s", providerConfig.ProviderType))
1118 }
1119
1120 // Validate custom providers
1121 isKnownProvider := slices.Contains(knownProviders, providerID)
1122
1123 if !isKnownProvider {
1124 // Custom provider validation
1125 if providerConfig.BaseURL == "" {
1126 errors.Add(fieldPrefix+".base_url", "BaseURL is required for custom providers")
1127 }
1128 if providerConfig.ProviderType != provider.TypeOpenAI {
1129 errors.Add(fieldPrefix+".provider_type", "custom providers currently only support OpenAI type")
1130 }
1131 }
1132
1133 // Validate models
1134 modelIDs := make(map[string]bool)
1135 for i, model := range providerConfig.Models {
1136 modelFieldPrefix := fmt.Sprintf("%s.models[%d]", fieldPrefix, i)
1137
1138 // Check for duplicate model IDs
1139 if modelIDs[model.ID] {
1140 errors.Add(modelFieldPrefix+".id", fmt.Sprintf("duplicate model ID: %s", model.ID))
1141 }
1142 modelIDs[model.ID] = true
1143
1144 // Validate required model fields
1145 if model.ID == "" {
1146 errors.Add(modelFieldPrefix+".id", "model ID is required")
1147 }
1148 if model.Name == "" {
1149 errors.Add(modelFieldPrefix+".name", "model name is required")
1150 }
1151 if model.ContextWindow <= 0 {
1152 errors.Add(modelFieldPrefix+".context_window", "context window must be positive")
1153 }
1154 if model.DefaultMaxTokens <= 0 {
1155 errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens must be positive")
1156 }
1157 if model.DefaultMaxTokens > model.ContextWindow {
1158 errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens cannot exceed context window")
1159 }
1160
1161 // Validate cost fields
1162 if model.CostPer1MIn < 0 {
1163 errors.Add(modelFieldPrefix+".cost_per_1m_in", "cost per 1M input tokens cannot be negative")
1164 }
1165 if model.CostPer1MOut < 0 {
1166 errors.Add(modelFieldPrefix+".cost_per_1m_out", "cost per 1M output tokens cannot be negative")
1167 }
1168 if model.CostPer1MInCached < 0 {
1169 errors.Add(modelFieldPrefix+".cost_per_1m_in_cached", "cached cost per 1M input tokens cannot be negative")
1170 }
1171 if model.CostPer1MOutCached < 0 {
1172 errors.Add(modelFieldPrefix+".cost_per_1m_out_cached", "cached cost per 1M output tokens cannot be negative")
1173 }
1174 }
1175
1176 // Validate default model references
1177 if providerConfig.DefaultLargeModel != "" {
1178 if !modelIDs[providerConfig.DefaultLargeModel] {
1179 errors.Add(fieldPrefix+".default_large_model", fmt.Sprintf("default large model '%s' not found in provider models", providerConfig.DefaultLargeModel))
1180 }
1181 }
1182 if providerConfig.DefaultSmallModel != "" {
1183 if !modelIDs[providerConfig.DefaultSmallModel] {
1184 errors.Add(fieldPrefix+".default_small_model", fmt.Sprintf("default small model '%s' not found in provider models", providerConfig.DefaultSmallModel))
1185 }
1186 }
1187
1188 // Validate provider-specific requirements
1189 c.validateProviderSpecific(providerID, providerConfig, errors)
1190 }
1191}
1192
1193// validateProviderSpecific validates provider-specific requirements
1194func (c *Config) validateProviderSpecific(providerID provider.InferenceProvider, providerConfig ProviderConfig, errors *ValidationErrors) {
1195 fieldPrefix := fmt.Sprintf("providers.%s", providerID)
1196
1197 switch providerID {
1198 case provider.InferenceProviderVertexAI:
1199 if !providerConfig.Disabled {
1200 if providerConfig.ExtraParams == nil {
1201 errors.Add(fieldPrefix+".extra_params", "VertexAI requires extra_params configuration")
1202 } else {
1203 if providerConfig.ExtraParams["project"] == "" {
1204 errors.Add(fieldPrefix+".extra_params.project", "VertexAI requires project parameter")
1205 }
1206 if providerConfig.ExtraParams["location"] == "" {
1207 errors.Add(fieldPrefix+".extra_params.location", "VertexAI requires location parameter")
1208 }
1209 }
1210 }
1211 case provider.InferenceProviderBedrock:
1212 if !providerConfig.Disabled {
1213 if providerConfig.ExtraParams == nil || providerConfig.ExtraParams["region"] == "" {
1214 errors.Add(fieldPrefix+".extra_params.region", "Bedrock requires region parameter")
1215 }
1216 // Check for AWS credentials in environment
1217 if !hasAWSCredentials() {
1218 errors.Add(fieldPrefix, "Bedrock requires AWS credentials in environment")
1219 }
1220 }
1221 }
1222}
1223
1224// validateModels validates preferred model configurations
1225func (c *Config) validateModels(errors *ValidationErrors) {
1226 // Validate large model
1227 if c.Models.Large.ModelID != "" || c.Models.Large.Provider != "" {
1228 if c.Models.Large.ModelID == "" {
1229 errors.Add("models.large.model_id", "large model ID is required when provider is set")
1230 }
1231 if c.Models.Large.Provider == "" {
1232 errors.Add("models.large.provider", "large model provider is required when model ID is set")
1233 }
1234
1235 // Check if provider exists and is not disabled
1236 if providerConfig, exists := c.Providers[c.Models.Large.Provider]; exists {
1237 if providerConfig.Disabled {
1238 errors.Add("models.large.provider", "large model provider is disabled")
1239 }
1240
1241 // Check if model exists in provider
1242 modelExists := false
1243 for _, model := range providerConfig.Models {
1244 if model.ID == c.Models.Large.ModelID {
1245 modelExists = true
1246 break
1247 }
1248 }
1249 if !modelExists {
1250 errors.Add("models.large.model_id", fmt.Sprintf("large model '%s' not found in provider '%s'", c.Models.Large.ModelID, c.Models.Large.Provider))
1251 }
1252 } else {
1253 errors.Add("models.large.provider", fmt.Sprintf("large model provider '%s' not found", c.Models.Large.Provider))
1254 }
1255 }
1256
1257 // Validate small model
1258 if c.Models.Small.ModelID != "" || c.Models.Small.Provider != "" {
1259 if c.Models.Small.ModelID == "" {
1260 errors.Add("models.small.model_id", "small model ID is required when provider is set")
1261 }
1262 if c.Models.Small.Provider == "" {
1263 errors.Add("models.small.provider", "small model provider is required when model ID is set")
1264 }
1265
1266 // Check if provider exists and is not disabled
1267 if providerConfig, exists := c.Providers[c.Models.Small.Provider]; exists {
1268 if providerConfig.Disabled {
1269 errors.Add("models.small.provider", "small model provider is disabled")
1270 }
1271
1272 // Check if model exists in provider
1273 modelExists := false
1274 for _, model := range providerConfig.Models {
1275 if model.ID == c.Models.Small.ModelID {
1276 modelExists = true
1277 break
1278 }
1279 }
1280 if !modelExists {
1281 errors.Add("models.small.model_id", fmt.Sprintf("small model '%s' not found in provider '%s'", c.Models.Small.ModelID, c.Models.Small.Provider))
1282 }
1283 } else {
1284 errors.Add("models.small.provider", fmt.Sprintf("small model provider '%s' not found", c.Models.Small.Provider))
1285 }
1286 }
1287}
1288
1289// validateAgents validates agent configurations
1290func (c *Config) validateAgents(errors *ValidationErrors) {
1291 if c.Agents == nil {
1292 c.Agents = make(map[AgentID]Agent)
1293 }
1294
1295 validTools := []string{
1296 "bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent",
1297 }
1298
1299 for agentID, agent := range c.Agents {
1300 fieldPrefix := fmt.Sprintf("agents.%s", agentID)
1301
1302 // Validate agent ID consistency
1303 if agent.ID != agentID {
1304 errors.Add(fieldPrefix+".id", fmt.Sprintf("agent ID mismatch: expected '%s', got '%s'", agentID, agent.ID))
1305 }
1306
1307 // Validate required fields
1308 if agent.ID == "" {
1309 errors.Add(fieldPrefix+".id", "agent ID is required")
1310 }
1311 if agent.Name == "" {
1312 errors.Add(fieldPrefix+".name", "agent name is required")
1313 }
1314
1315 // Validate model type
1316 if agent.Model != LargeModel && agent.Model != SmallModel {
1317 errors.Add(fieldPrefix+".model", fmt.Sprintf("invalid model type: %s (must be 'large' or 'small')", agent.Model))
1318 }
1319
1320 // Validate allowed tools
1321 if agent.AllowedTools != nil {
1322 for i, tool := range agent.AllowedTools {
1323 validTool := slices.Contains(validTools, tool)
1324 if !validTool {
1325 errors.Add(fmt.Sprintf("%s.allowed_tools[%d]", fieldPrefix, i), fmt.Sprintf("unknown tool: %s", tool))
1326 }
1327 }
1328 }
1329
1330 // Validate MCP references
1331 if agent.AllowedMCP != nil {
1332 for mcpName := range agent.AllowedMCP {
1333 if _, exists := c.MCP[mcpName]; !exists {
1334 errors.Add(fieldPrefix+".allowed_mcp", fmt.Sprintf("referenced MCP '%s' not found", mcpName))
1335 }
1336 }
1337 }
1338
1339 // Validate LSP references
1340 if agent.AllowedLSP != nil {
1341 for _, lspName := range agent.AllowedLSP {
1342 if _, exists := c.LSP[lspName]; !exists {
1343 errors.Add(fieldPrefix+".allowed_lsp", fmt.Sprintf("referenced LSP '%s' not found", lspName))
1344 }
1345 }
1346 }
1347
1348 // Validate context paths (basic path validation)
1349 for i, contextPath := range agent.ContextPaths {
1350 if contextPath == "" {
1351 errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path cannot be empty")
1352 }
1353 // Check for invalid characters in path
1354 if strings.Contains(contextPath, "\x00") {
1355 errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path contains invalid characters")
1356 }
1357 }
1358
1359 // Validate known agents maintain their core properties
1360 if agentID == AgentCoder {
1361 if agent.Name != "Coder" {
1362 errors.Add(fieldPrefix+".name", "coder agent name cannot be changed")
1363 }
1364 if agent.Description != "An agent that helps with executing coding tasks." {
1365 errors.Add(fieldPrefix+".description", "coder agent description cannot be changed")
1366 }
1367 } else if agentID == AgentTask {
1368 if agent.Name != "Task" {
1369 errors.Add(fieldPrefix+".name", "task agent name cannot be changed")
1370 }
1371 if agent.Description != "An agent that helps with searching for context and finding implementation details." {
1372 errors.Add(fieldPrefix+".description", "task agent description cannot be changed")
1373 }
1374 expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"}
1375 if agent.AllowedTools != nil && !slices.Equal(agent.AllowedTools, expectedTools) {
1376 errors.Add(fieldPrefix+".allowed_tools", "task agent allowed tools cannot be changed")
1377 }
1378 }
1379 }
1380}
1381
1382// validateOptions validates configuration options
1383func (c *Config) validateOptions(errors *ValidationErrors) {
1384 // Validate data directory
1385 if c.Options.DataDirectory == "" {
1386 errors.Add("options.data_directory", "data directory is required")
1387 }
1388
1389 // Validate context paths
1390 for i, contextPath := range c.Options.ContextPaths {
1391 if contextPath == "" {
1392 errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path cannot be empty")
1393 }
1394 if strings.Contains(contextPath, "\x00") {
1395 errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path contains invalid characters")
1396 }
1397 }
1398}
1399
1400// validateMCPs validates MCP configurations
1401func (c *Config) validateMCPs(errors *ValidationErrors) {
1402 if c.MCP == nil {
1403 c.MCP = make(map[string]MCP)
1404 }
1405
1406 for mcpName, mcpConfig := range c.MCP {
1407 fieldPrefix := fmt.Sprintf("mcp.%s", mcpName)
1408
1409 // Validate MCP type
1410 if mcpConfig.Type != MCPStdio && mcpConfig.Type != MCPSse {
1411 errors.Add(fieldPrefix+".type", fmt.Sprintf("invalid MCP type: %s (must be 'stdio' or 'sse')", mcpConfig.Type))
1412 }
1413
1414 // Validate based on type
1415 if mcpConfig.Type == MCPStdio {
1416 if mcpConfig.Command == "" {
1417 errors.Add(fieldPrefix+".command", "command is required for stdio MCP")
1418 }
1419 } else if mcpConfig.Type == MCPSse {
1420 if mcpConfig.URL == "" {
1421 errors.Add(fieldPrefix+".url", "URL is required for SSE MCP")
1422 }
1423 }
1424 }
1425}
1426
1427// validateLSPs validates LSP configurations
1428func (c *Config) validateLSPs(errors *ValidationErrors) {
1429 if c.LSP == nil {
1430 c.LSP = make(map[string]LSPConfig)
1431 }
1432
1433 for lspName, lspConfig := range c.LSP {
1434 fieldPrefix := fmt.Sprintf("lsp.%s", lspName)
1435
1436 if lspConfig.Command == "" {
1437 errors.Add(fieldPrefix+".command", "command is required for LSP")
1438 }
1439 }
1440}
1441
1442// validateCrossReferences validates cross-references between different config sections
1443func (c *Config) validateCrossReferences(errors *ValidationErrors) {
1444 // Validate that agents can use their assigned model types
1445 for agentID, agent := range c.Agents {
1446 fieldPrefix := fmt.Sprintf("agents.%s", agentID)
1447
1448 var preferredModel PreferredModel
1449 switch agent.Model {
1450 case LargeModel:
1451 preferredModel = c.Models.Large
1452 case SmallModel:
1453 preferredModel = c.Models.Small
1454 }
1455
1456 if preferredModel.Provider != "" {
1457 if providerConfig, exists := c.Providers[preferredModel.Provider]; exists {
1458 if providerConfig.Disabled {
1459 errors.Add(fieldPrefix+".model", fmt.Sprintf("agent cannot use model type '%s' because provider '%s' is disabled", agent.Model, preferredModel.Provider))
1460 }
1461 }
1462 }
1463 }
1464}
1465
1466// validateCompleteness validates that the configuration is complete and usable
1467func (c *Config) validateCompleteness(errors *ValidationErrors) {
1468 // Check for at least one valid, non-disabled provider
1469 hasValidProvider := false
1470 for _, providerConfig := range c.Providers {
1471 if !providerConfig.Disabled {
1472 hasValidProvider = true
1473 break
1474 }
1475 }
1476 if !hasValidProvider {
1477 errors.Add("providers", "at least one non-disabled provider is required")
1478 }
1479
1480 // Check that default agents exist
1481 if _, exists := c.Agents[AgentCoder]; !exists {
1482 errors.Add("agents", "coder agent is required")
1483 }
1484 if _, exists := c.Agents[AgentTask]; !exists {
1485 errors.Add("agents", "task agent is required")
1486 }
1487
1488 // Check that preferred models are set if providers exist
1489 if hasValidProvider {
1490 if c.Models.Large.ModelID == "" || c.Models.Large.Provider == "" {
1491 errors.Add("models.large", "large preferred model must be configured when providers are available")
1492 }
1493 if c.Models.Small.ModelID == "" || c.Models.Small.Provider == "" {
1494 errors.Add("models.small", "small preferred model must be configured when providers are available")
1495 }
1496 }
1497}
1498
1499// JSONSchemaExtend adds custom schema properties for AgentID
1500func (AgentID) JSONSchemaExtend(schema *jsonschema.Schema) {
1501 schema.Enum = []any{
1502 string(AgentCoder),
1503 string(AgentTask),
1504 }
1505}
1506
1507// JSONSchemaExtend adds custom schema properties for ModelType
1508func (ModelType) JSONSchemaExtend(schema *jsonschema.Schema) {
1509 schema.Enum = []any{
1510 string(LargeModel),
1511 string(SmallModel),
1512 }
1513}
1514
1515// JSONSchemaExtend adds custom schema properties for MCPType
1516func (MCPType) JSONSchemaExtend(schema *jsonschema.Schema) {
1517 schema.Enum = []any{
1518 string(MCPStdio),
1519 string(MCPSse),
1520 }
1521}