1package config
2
3import (
4 "encoding/json"
5 "errors"
6 "fmt"
7 "log/slog"
8 "maps"
9 "os"
10 "path/filepath"
11 "slices"
12 "strings"
13 "sync"
14
15 "github.com/charmbracelet/crush/internal/fur/provider"
16 "github.com/charmbracelet/crush/internal/logging"
17 "github.com/invopop/jsonschema"
18)
19
const (
	// defaultDataDirectory is where application data is stored,
	// relative to the working directory.
	defaultDataDirectory = ".crush"
	// defaultLogLevel is the log level used when none is configured.
	defaultLogLevel = "info"
	appName         = "crush"

	// MaxTokensFallbackDefault is used when a model does not declare
	// its own default max token count.
	MaxTokensFallbackDefault = 4096
)
27
// defaultContextPaths are the files/directories searched for agent context
// instructions; several casings of the same names are listed so lookups work
// on case-sensitive filesystems.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
43
// AgentID identifies a built-in or user-defined agent.
type AgentID string

const (
	// AgentCoder executes coding tasks.
	AgentCoder AgentID = "coder"
	// AgentTask searches for context and implementation details.
	AgentTask AgentID = "task"
)
50
// ModelType selects which preferred model (large or small) an agent uses.
type ModelType string

const (
	LargeModel ModelType = "large"
	SmallModel ModelType = "small"
)
57
// Model describes a single LLM offered by a provider, including pricing,
// context limits, and capability flags.
type Model struct {
	ID                 string  `json:"id" jsonschema:"title=Model ID,description=Unique identifier for the model, the API model"`
	Name               string  `json:"name" jsonschema:"title=Model Name,description=Display name of the model"`
	CostPer1MIn        float64 `json:"cost_per_1m_in,omitempty" jsonschema:"title=Input Cost,description=Cost per 1 million input tokens,minimum=0"`
	CostPer1MOut       float64 `json:"cost_per_1m_out,omitempty" jsonschema:"title=Output Cost,description=Cost per 1 million output tokens,minimum=0"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached,omitempty" jsonschema:"title=Cached Input Cost,description=Cost per 1 million cached input tokens,minimum=0"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached,omitempty" jsonschema:"title=Cached Output Cost,description=Cost per 1 million cached output tokens,minimum=0"`
	ContextWindow      int64   `json:"context_window" jsonschema:"title=Context Window,description=Maximum context window size in tokens,minimum=1"`
	DefaultMaxTokens   int64   `json:"default_max_tokens" jsonschema:"title=Default Max Tokens,description=Default maximum tokens for responses,minimum=1"`
	CanReason          bool    `json:"can_reason,omitempty" jsonschema:"title=Can Reason,description=Whether the model supports reasoning capabilities"`
	ReasoningEffort    string  `json:"reasoning_effort,omitempty" jsonschema:"title=Reasoning Effort,description=Default reasoning effort level for reasoning models"`
	HasReasoningEffort bool    `json:"has_reasoning_effort,omitempty" jsonschema:"title=Has Reasoning Effort,description=Whether the model supports reasoning effort configuration"`
	// NOTE: the JSON key is "supports_attachments" for historical reasons.
	SupportsImages bool `json:"supports_attachments,omitempty" jsonschema:"title=Supports Images,description=Whether the model supports image attachments"`
}
72
// VertexAIOptions holds Google Vertex AI specific connection settings.
type VertexAIOptions struct {
	APIKey   string `json:"api_key,omitempty"`
	Project  string `json:"project,omitempty"`
	Location string `json:"location,omitempty"`
}
78
// ProviderConfig is the user-facing configuration for one LLM provider,
// including credentials, connection settings, and its model catalog.
type ProviderConfig struct {
	ID           provider.InferenceProvider `json:"id,omitempty" jsonschema:"title=Provider ID,description=Unique identifier for the provider"`
	BaseURL      string                     `json:"base_url,omitempty" jsonschema:"title=Base URL,description=Base URL for the provider API (required for custom providers)"`
	ProviderType provider.Type              `json:"provider_type" jsonschema:"title=Provider Type,description=Type of the provider (openai, anthropic, etc.)"`
	APIKey       string                     `json:"api_key,omitempty" jsonschema:"title=API Key,description=API key for authenticating with the provider"`
	Disabled     bool                       `json:"disabled,omitempty" jsonschema:"title=Disabled,description=Whether this provider is disabled,default=false"`
	ExtraHeaders map[string]string          `json:"extra_headers,omitempty" jsonschema:"title=Extra Headers,description=Additional HTTP headers to send with requests"`
	// Used e.g. by Vertex AI to set the project and location.
	ExtraParams map[string]string `json:"extra_params,omitempty" jsonschema:"title=Extra Parameters,description=Additional provider-specific parameters"`

	DefaultLargeModel string `json:"default_large_model,omitempty" jsonschema:"title=Default Large Model,description=Default model ID for large model type"`
	DefaultSmallModel string `json:"default_small_model,omitempty" jsonschema:"title=Default Small Model,description=Default model ID for small model type"`

	Models []Model `json:"models,omitempty" jsonschema:"title=Models,description=List of available models for this provider"`
}
94
// Agent configures one agent: which model type it uses and which tools,
// MCP servers, and LSP servers it may access.
type Agent struct {
	ID          AgentID `json:"id,omitempty" jsonschema:"title=Agent ID,description=Unique identifier for the agent,enum=coder,enum=task"`
	Name        string  `json:"name,omitempty" jsonschema:"title=Name,description=Display name of the agent"`
	Description string  `json:"description,omitempty" jsonschema:"title=Description,description=Description of what the agent does"`
	// Disabled turns the agent off entirely.
	Disabled bool `json:"disabled,omitempty" jsonschema:"title=Disabled,description=Whether this agent is disabled,default=false"`

	Model ModelType `json:"model" jsonschema:"title=Model Type,description=Type of model to use (large or small),enum=large,enum=small"`

	// The available tools for the agent
	// if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools,omitempty" jsonschema:"title=Allowed Tools,description=List of tools this agent is allowed to use (if nil all tools are allowed)"`

	// this tells us which MCPs are available for this agent
	// if this is empty all mcps are available
	// the string array is the list of tools from the AllowedMCP the agent has available
	// if the string array is nil, all tools from the AllowedMCP are available
	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty" jsonschema:"title=Allowed MCP,description=Map of MCP servers this agent can use and their allowed tools"`

	// The list of LSPs that this agent can use
	// if this is nil, all LSPs are available
	AllowedLSP []string `json:"allowed_lsp,omitempty" jsonschema:"title=Allowed LSP,description=List of LSP servers this agent can use (if nil all LSPs are allowed)"`

	// Overrides the context paths for this agent
	ContextPaths []string `json:"context_paths,omitempty" jsonschema:"title=Context Paths,description=Custom context paths for this agent (additive to global context paths)"`
}
121
// MCPType is the transport used to talk to an MCP server.
type MCPType string

const (
	// MCPStdio communicates over a spawned process's stdin/stdout.
	MCPStdio MCPType = "stdio"
	// MCPSse communicates over HTTP server-sent events.
	MCPSse MCPType = "sse"
)
128
// MCP configures a single Model Control Protocol server. Command/Env/Args
// apply to stdio servers; URL/Headers apply to SSE servers.
type MCP struct {
	Command string   `json:"command" jsonschema:"title=Command,description=Command to execute for stdio MCP servers"`
	Env     []string `json:"env,omitempty" jsonschema:"title=Environment,description=Environment variables for the MCP server"`
	Args    []string `json:"args,omitempty" jsonschema:"title=Arguments,description=Command line arguments for the MCP server"`
	Type    MCPType  `json:"type" jsonschema:"title=Type,description=Type of MCP connection,enum=stdio,enum=sse,default=stdio"`
	URL     string   `json:"url,omitempty" jsonschema:"title=URL,description=URL for SSE MCP servers"`
	// TODO: maybe make it possible to get the value from the env
	Headers map[string]string `json:"headers,omitempty" jsonschema:"title=Headers,description=HTTP headers for SSE MCP servers"`
}
138
// LSPConfig configures a single Language Server Protocol server.
//
// BUG FIX: the Disabled field was previously tagged `json:"enabled"` with a
// schema saying "Whether this LSP server is enabled" — so `"enabled": true`
// in a config file actually disabled the server. The tag and schema now
// match the field's meaning.
type LSPConfig struct {
	Disabled bool     `json:"disabled,omitempty" jsonschema:"title=Disabled,description=Whether this LSP server is disabled,default=false"`
	Command  string   `json:"command" jsonschema:"title=Command,description=Command to execute for the LSP server"`
	Args     []string `json:"args,omitempty" jsonschema:"title=Arguments,description=Command line arguments for the LSP server"`
	Options  any      `json:"options,omitempty" jsonschema:"title=Options,description=LSP server specific options"`
}
145
// TUIOptions holds terminal UI settings.
type TUIOptions struct {
	CompactMode bool `json:"compact_mode" jsonschema:"title=Compact Mode,description=Enable compact mode for the TUI,default=false"`
	// Here we can add themes later or any TUI related options
}
150
// Options are miscellaneous application-wide settings.
type Options struct {
	ContextPaths         []string   `json:"context_paths,omitempty" jsonschema:"title=Context Paths,description=List of paths to search for context files"`
	TUI                  TUIOptions `json:"tui,omitempty" jsonschema:"title=TUI Options,description=Terminal UI configuration options"`
	Debug                bool       `json:"debug,omitempty" jsonschema:"title=Debug,description=Enable debug logging,default=false"`
	DebugLSP             bool       `json:"debug_lsp,omitempty" jsonschema:"title=Debug LSP,description=Enable LSP debug logging,default=false"`
	DisableAutoSummarize bool       `json:"disable_auto_summarize,omitempty" jsonschema:"title=Disable Auto Summarize,description=Disable automatic conversation summarization,default=false"`
	// DataDirectory is relative to the cwd.
	DataDirectory string `json:"data_directory,omitempty" jsonschema:"title=Data Directory,description=Directory for storing application data,default=.crush"`
}
160
// PreferredModel names a (provider, model) pair the user prefers, plus
// optional per-model overrides.
type PreferredModel struct {
	ModelID  string                     `json:"model_id" jsonschema:"title=Model ID,description=ID of the preferred model"`
	Provider provider.InferenceProvider `json:"provider" jsonschema:"title=Provider,description=Provider for the preferred model"`
	// ReasoningEffort overrides the default reasoning effort for this model
	ReasoningEffort string `json:"reasoning_effort,omitempty" jsonschema:"title=Reasoning Effort,description=Override reasoning effort for this model"`
	// MaxTokens overrides the default max tokens for this model
	MaxTokens int64 `json:"max_tokens,omitempty" jsonschema:"title=Max Tokens,description=Override max tokens for this model,minimum=1"`

	// Think indicates if the model should think, only applicable for anthropic reasoning models
	Think bool `json:"think,omitempty" jsonschema:"title=Think,description=Enable thinking for reasoning models,default=false"`
}
172
// PreferredModels pairs the user's preferred large and small models.
type PreferredModels struct {
	Large PreferredModel `json:"large,omitempty" jsonschema:"title=Large Model,description=Preferred model configuration for large model type"`
	Small PreferredModel `json:"small,omitempty" jsonschema:"title=Small Model,description=Preferred model configuration for small model type"`
}
177
// Config is the root application configuration, produced by merging the
// environment defaults, the global config file, and the local crush.json.
type Config struct {
	Models PreferredModels `json:"models,omitempty" jsonschema:"title=Models,description=Preferred model configurations for large and small model types"`
	// List of configured providers
	Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty" jsonschema:"title=Providers,description=LLM provider configurations"`

	// List of configured agents
	Agents map[AgentID]Agent `json:"agents,omitempty" jsonschema:"title=Agents,description=Agent configurations for different tasks"`

	// List of configured MCPs
	MCP map[string]MCP `json:"mcp,omitempty" jsonschema:"title=MCP,description=Model Control Protocol server configurations"`

	// List of configured LSPs
	LSP map[string]LSPConfig `json:"lsp,omitempty" jsonschema:"title=LSP,description=Language Server Protocol configurations"`

	// Miscellaneous options
	Options Options `json:"options,omitempty" jsonschema:"title=Options,description=General application options and settings"`
}
195
var (
	// instance is the process-wide singleton Config, set once by Init.
	instance *Config
	// cwd is the working directory captured at Init time.
	cwd string
	// once guards the one-time initialization in Init.
	once sync.Once
)
202
203func loadConfig(cwd string, debug bool) (*Config, error) {
204 // First read the global config file
205 cfgPath := ConfigPath()
206
207 cfg := defaultConfigBasedOnEnv()
208 cfg.Options.Debug = debug
209 defaultLevel := slog.LevelInfo
210 if cfg.Options.Debug {
211 defaultLevel = slog.LevelDebug
212 }
213 if os.Getenv("CRUSH_DEV_DEBUG") == "true" {
214 loggingFile := fmt.Sprintf("%s/%s", cfg.Options.DataDirectory, "debug.log")
215
216 // if file does not exist create it
217 if _, err := os.Stat(loggingFile); os.IsNotExist(err) {
218 if err := os.MkdirAll(cfg.Options.DataDirectory, 0o755); err != nil {
219 return cfg, fmt.Errorf("failed to create directory: %w", err)
220 }
221 if _, err := os.Create(loggingFile); err != nil {
222 return cfg, fmt.Errorf("failed to create log file: %w", err)
223 }
224 }
225
226 sloggingFileWriter, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666)
227 if err != nil {
228 return cfg, fmt.Errorf("failed to open log file: %w", err)
229 }
230 // Configure logger
231 logger := slog.New(slog.NewTextHandler(sloggingFileWriter, &slog.HandlerOptions{
232 Level: defaultLevel,
233 }))
234 slog.SetDefault(logger)
235 } else {
236 // Configure logger
237 logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{
238 Level: defaultLevel,
239 }))
240 slog.SetDefault(logger)
241 }
242 var globalCfg *Config
243 if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
244 // some other error occurred while checking the file
245 return nil, err
246 } else if err == nil {
247 // config file exists, read it
248 file, err := os.ReadFile(cfgPath)
249 if err != nil {
250 return nil, err
251 }
252 globalCfg = &Config{}
253 if err := json.Unmarshal(file, globalCfg); err != nil {
254 return nil, err
255 }
256 } else {
257 // config file does not exist, create a new one
258 globalCfg = &Config{}
259 }
260
261 var localConfig *Config
262 // Global config loaded, now read the local config file
263 localConfigPath := filepath.Join(cwd, "crush.json")
264 if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
265 // some other error occurred while checking the file
266 return nil, err
267 } else if err == nil {
268 // local config file exists, read it
269 file, err := os.ReadFile(localConfigPath)
270 if err != nil {
271 return nil, err
272 }
273 localConfig = &Config{}
274 if err := json.Unmarshal(file, localConfig); err != nil {
275 return nil, err
276 }
277 }
278
279 // merge options
280 mergeOptions(cfg, globalCfg, localConfig)
281
282 mergeProviderConfigs(cfg, globalCfg, localConfig)
283 // no providers found the app is not initialized yet
284 if len(cfg.Providers) == 0 {
285 return cfg, nil
286 }
287 preferredProvider := getPreferredProvider(cfg.Providers)
288 if preferredProvider != nil {
289 cfg.Models = PreferredModels{
290 Large: PreferredModel{
291 ModelID: preferredProvider.DefaultLargeModel,
292 Provider: preferredProvider.ID,
293 },
294 Small: PreferredModel{
295 ModelID: preferredProvider.DefaultSmallModel,
296 Provider: preferredProvider.ID,
297 },
298 }
299 } else {
300 // No valid providers found, set empty models
301 cfg.Models = PreferredModels{}
302 }
303
304 mergeModels(cfg, globalCfg, localConfig)
305
306 agents := map[AgentID]Agent{
307 AgentCoder: {
308 ID: AgentCoder,
309 Name: "Coder",
310 Description: "An agent that helps with executing coding tasks.",
311 Model: LargeModel,
312 ContextPaths: cfg.Options.ContextPaths,
313 // All tools allowed
314 },
315 AgentTask: {
316 ID: AgentTask,
317 Name: "Task",
318 Description: "An agent that helps with searching for context and finding implementation details.",
319 Model: LargeModel,
320 ContextPaths: cfg.Options.ContextPaths,
321 AllowedTools: []string{
322 "glob",
323 "grep",
324 "ls",
325 "sourcegraph",
326 "view",
327 "definitions",
328 },
329 // NO MCPs or LSPs by default
330 AllowedMCP: map[string][]string{},
331 AllowedLSP: []string{},
332 },
333 }
334 cfg.Agents = agents
335 mergeAgents(cfg, globalCfg, localConfig)
336 mergeMCPs(cfg, globalCfg, localConfig)
337 mergeLSPs(cfg, globalCfg, localConfig)
338
339 // Validate the final configuration
340 if err := cfg.Validate(); err != nil {
341 return cfg, fmt.Errorf("configuration validation failed: %w", err)
342 }
343
344 return cfg, nil
345}
346
// Init loads the configuration exactly once and stores it as the package
// singleton. Subsequent calls return the already-loaded instance.
//
// NOTE(review): if the first call fails, later calls return the (possibly
// nil or partial) instance with a nil error, since the once.Do closure does
// not rerun — callers should treat only the first call's error as reliable.
func Init(workingDir string, debug bool) (*Config, error) {
	var err error
	once.Do(func() {
		cwd = workingDir
		instance, err = loadConfig(cwd, debug)
		if err != nil {
			logging.Error("Failed to load config", "error", err)
		}
	})

	return instance, err
}
359
360func Get() *Config {
361 if instance == nil {
362 // TODO: Handle this better
363 panic("Config not initialized. Call InitConfig first.")
364 }
365 return instance
366}
367
368func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
369 providers := Providers()
370 for _, p := range providers {
371 if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
372 return &providerConfig
373 }
374 }
375 // if none found return the first configured provider
376 for _, providerConfig := range configuredProviders {
377 if !providerConfig.Disabled {
378 return &providerConfig
379 }
380 }
381 return nil
382}
383
384func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
385 if other.APIKey != "" {
386 base.APIKey = other.APIKey
387 }
388 // Only change these options if the provider is not a known provider
389 if !slices.Contains(provider.KnownProviders(), p) {
390 if other.BaseURL != "" {
391 base.BaseURL = other.BaseURL
392 }
393 if other.ProviderType != "" {
394 base.ProviderType = other.ProviderType
395 }
396 if len(other.ExtraHeaders) > 0 {
397 if base.ExtraHeaders == nil {
398 base.ExtraHeaders = make(map[string]string)
399 }
400 maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
401 }
402 if len(other.ExtraParams) > 0 {
403 if base.ExtraParams == nil {
404 base.ExtraParams = make(map[string]string)
405 }
406 maps.Copy(base.ExtraParams, other.ExtraParams)
407 }
408 }
409
410 if other.Disabled {
411 base.Disabled = other.Disabled
412 }
413
414 if other.DefaultLargeModel != "" {
415 base.DefaultLargeModel = other.DefaultLargeModel
416 }
417 // Add new models if they don't exist
418 if other.Models != nil {
419 for _, model := range other.Models {
420 // check if the model already exists
421 exists := false
422 for _, existingModel := range base.Models {
423 if existingModel.ID == model.ID {
424 exists = true
425 break
426 }
427 }
428 if !exists {
429 base.Models = append(base.Models, model)
430 }
431 }
432 }
433
434 return base
435}
436
437func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
438 if !slices.Contains(provider.KnownProviders(), p) {
439 if providerConfig.ProviderType != provider.TypeOpenAI {
440 return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
441 }
442 if providerConfig.BaseURL == "" {
443 return errors.New("base URL must be set for custom providers")
444 }
445 if providerConfig.APIKey == "" {
446 return errors.New("API key must be set for custom providers")
447 }
448 }
449 return nil
450}
451
452func mergeModels(base, global, local *Config) {
453 for _, cfg := range []*Config{global, local} {
454 if cfg == nil {
455 continue
456 }
457 if cfg.Models.Large.ModelID != "" && cfg.Models.Large.Provider != "" {
458 base.Models.Large = cfg.Models.Large
459 }
460
461 if cfg.Models.Small.ModelID != "" && cfg.Models.Small.Provider != "" {
462 base.Models.Small = cfg.Models.Small
463 }
464 }
465}
466
467func mergeOptions(base, global, local *Config) {
468 for _, cfg := range []*Config{global, local} {
469 if cfg == nil {
470 continue
471 }
472 baseOptions := base.Options
473 other := cfg.Options
474 if len(other.ContextPaths) > 0 {
475 baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
476 }
477
478 if other.TUI.CompactMode {
479 baseOptions.TUI.CompactMode = other.TUI.CompactMode
480 }
481
482 if other.Debug {
483 baseOptions.Debug = other.Debug
484 }
485
486 if other.DebugLSP {
487 baseOptions.DebugLSP = other.DebugLSP
488 }
489
490 if other.DisableAutoSummarize {
491 baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
492 }
493
494 if other.DataDirectory != "" {
495 baseOptions.DataDirectory = other.DataDirectory
496 }
497 base.Options = baseOptions
498 }
499}
500
// mergeAgents overlays agent definitions from the global and then local
// configs onto base. Unknown agents are added wholesale (with defaults
// filled in); the built-in coder/task agents only accept a restricted set
// of overrides so their identity and tool policy stay intact.
func mergeAgents(base, global, local *Config) {
	for _, cfg := range []*Config{global, local} {
		if cfg == nil {
			continue
		}
		for agentID, newAgent := range cfg.Agents {
			if _, ok := base.Agents[agentID]; !ok {
				// New agent: normalize its ID, default the model type to
				// large, and prepend the global context paths.
				newAgent.ID = agentID
				if newAgent.Model == "" {
					newAgent.Model = LargeModel
				}
				if len(newAgent.ContextPaths) > 0 {
					newAgent.ContextPaths = append(base.Options.ContextPaths, newAgent.ContextPaths...)
				} else {
					newAgent.ContextPaths = base.Options.ContextPaths
				}
				base.Agents[agentID] = newAgent
			} else {
				baseAgent := base.Agents[agentID]

				if agentID == AgentCoder || agentID == AgentTask {
					// Built-in agents: only model type, MCP/LSP allowances,
					// and extra context paths may be overridden.
					if newAgent.Model != "" {
						baseAgent.Model = newAgent.Model
					}
					if newAgent.AllowedMCP != nil {
						baseAgent.AllowedMCP = newAgent.AllowedMCP
					}
					if newAgent.AllowedLSP != nil {
						baseAgent.AllowedLSP = newAgent.AllowedLSP
					}
					// Context paths are additive for known agents too
					if len(newAgent.ContextPaths) > 0 {
						baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...)
					}
				} else {
					// Custom agents: non-empty fields replace the base values.
					if newAgent.Name != "" {
						baseAgent.Name = newAgent.Name
					}
					if newAgent.Description != "" {
						baseAgent.Description = newAgent.Description
					}
					if newAgent.Model != "" {
						baseAgent.Model = newAgent.Model
					} else if baseAgent.Model == "" {
						baseAgent.Model = LargeModel
					}

					// Disabled is copied unconditionally, so a later layer can
					// re-enable an agent disabled by an earlier one.
					baseAgent.Disabled = newAgent.Disabled

					if newAgent.AllowedTools != nil {
						baseAgent.AllowedTools = newAgent.AllowedTools
					}
					if newAgent.AllowedMCP != nil {
						baseAgent.AllowedMCP = newAgent.AllowedMCP
					}
					if newAgent.AllowedLSP != nil {
						baseAgent.AllowedLSP = newAgent.AllowedLSP
					}
					if len(newAgent.ContextPaths) > 0 {
						baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...)
					}
				}

				base.Agents[agentID] = baseAgent
			}
		}
	}
}
569
570func mergeMCPs(base, global, local *Config) {
571 for _, cfg := range []*Config{global, local} {
572 if cfg == nil {
573 continue
574 }
575 maps.Copy(base.MCP, cfg.MCP)
576 }
577}
578
579func mergeLSPs(base, global, local *Config) {
580 for _, cfg := range []*Config{global, local} {
581 if cfg == nil {
582 continue
583 }
584 maps.Copy(base.LSP, cfg.LSP)
585 }
586}
587
588func mergeProviderConfigs(base, global, local *Config) {
589 for _, cfg := range []*Config{global, local} {
590 if cfg == nil {
591 continue
592 }
593 for providerName, p := range cfg.Providers {
594 p.ID = providerName
595 if _, ok := base.Providers[providerName]; !ok {
596 base.Providers[providerName] = p
597 } else {
598 base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], p)
599 }
600 }
601 }
602
603 finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
604 for providerName, providerConfig := range base.Providers {
605 err := validateProvider(providerName, providerConfig)
606 if err != nil {
607 logging.Warn("Skipping provider", "name", providerName, "error", err)
608 continue // Skip invalid providers
609 }
610 finalProviders[providerName] = providerConfig
611 }
612 base.Providers = finalProviders
613}
614
615func providerDefaultConfig(providerID provider.InferenceProvider) ProviderConfig {
616 switch providerID {
617 case provider.InferenceProviderAnthropic:
618 return ProviderConfig{
619 ID: providerID,
620 ProviderType: provider.TypeAnthropic,
621 }
622 case provider.InferenceProviderOpenAI:
623 return ProviderConfig{
624 ID: providerID,
625 ProviderType: provider.TypeOpenAI,
626 }
627 case provider.InferenceProviderGemini:
628 return ProviderConfig{
629 ID: providerID,
630 ProviderType: provider.TypeGemini,
631 }
632 case provider.InferenceProviderBedrock:
633 return ProviderConfig{
634 ID: providerID,
635 ProviderType: provider.TypeBedrock,
636 }
637 case provider.InferenceProviderAzure:
638 return ProviderConfig{
639 ID: providerID,
640 ProviderType: provider.TypeAzure,
641 }
642 case provider.InferenceProviderOpenRouter:
643 return ProviderConfig{
644 ID: providerID,
645 ProviderType: provider.TypeOpenAI,
646 BaseURL: "https://openrouter.ai/api/v1",
647 ExtraHeaders: map[string]string{
648 "HTTP-Referer": "crush.charm.land",
649 "X-Title": "Crush",
650 },
651 }
652 case provider.InferenceProviderXAI:
653 return ProviderConfig{
654 ID: providerID,
655 ProviderType: provider.TypeXAI,
656 BaseURL: "https://api.x.ai/v1",
657 }
658 case provider.InferenceProviderVertexAI:
659 return ProviderConfig{
660 ID: providerID,
661 ProviderType: provider.TypeVertexAI,
662 }
663 default:
664 return ProviderConfig{
665 ID: providerID,
666 ProviderType: provider.TypeOpenAI,
667 }
668 }
669}
670
// defaultConfigBasedOnEnv builds the baseline config from environment
// variables: one provider entry per known provider whose API-key env var is
// set, plus special-cased Vertex AI (GOOGLE_GENAI_USE_VERTEXAI) and AWS
// Bedrock (detected via hasAWSCredentials).
//
// NOTE(review): the provider-model -> config.Model conversion below is
// duplicated three times (generic, Vertex AI, Bedrock); extracting a shared
// helper would need the provider package's model type name — worth doing
// where that type is visible.
func defaultConfigBasedOnEnv() *Config {
	cfg := &Config{
		Options: Options{
			DataDirectory: defaultDataDirectory,
			ContextPaths:  defaultContextPaths,
		},
		Providers: make(map[provider.InferenceProvider]ProviderConfig),
		Agents:    make(map[AgentID]Agent),
		LSP:       make(map[string]LSPConfig),
		MCP:       make(map[string]MCP),
	}

	providers := Providers()

	for _, p := range providers {
		// A provider definition whose APIKey field starts with "$" names the
		// env var that holds the real key; the provider is only configured
		// when that variable is set.
		if strings.HasPrefix(p.APIKey, "$") {
			envVar := strings.TrimPrefix(p.APIKey, "$")
			if apiKey := os.Getenv(envVar); apiKey != "" {
				providerConfig := providerDefaultConfig(p.ID)
				providerConfig.APIKey = apiKey
				providerConfig.DefaultLargeModel = p.DefaultLargeModelID
				providerConfig.DefaultSmallModel = p.DefaultSmallModelID
				// The endpoint may also be indirected through an env var.
				baseURL := p.APIEndpoint
				if strings.HasPrefix(baseURL, "$") {
					envVar := strings.TrimPrefix(baseURL, "$")
					baseURL = os.Getenv(envVar)
				}
				providerConfig.BaseURL = baseURL
				for _, model := range p.Models {
					configModel := Model{
						ID:                 model.ID,
						Name:               model.Name,
						CostPer1MIn:        model.CostPer1MIn,
						CostPer1MOut:       model.CostPer1MOut,
						CostPer1MInCached:  model.CostPer1MInCached,
						CostPer1MOutCached: model.CostPer1MOutCached,
						ContextWindow:      model.ContextWindow,
						DefaultMaxTokens:   model.DefaultMaxTokens,
						CanReason:          model.CanReason,
						SupportsImages:     model.SupportsImages,
					}
					// Set reasoning effort for reasoning models
					if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
						configModel.HasReasoningEffort = model.HasReasoningEffort
						configModel.ReasoningEffort = model.DefaultReasoningEffort
					}
					providerConfig.Models = append(providerConfig.Models, configModel)
				}
				cfg.Providers[p.ID] = providerConfig
			}
		}
	}
	// TODO: support local models

	// Vertex AI is opted into explicitly and configured from the standard
	// Google Cloud project/location env vars rather than an API key.
	if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
		providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
		providerConfig.ExtraParams = map[string]string{
			"project":  os.Getenv("GOOGLE_CLOUD_PROJECT"),
			"location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
		}
		// Find the VertexAI provider definition to get default models
		for _, p := range providers {
			if p.ID == provider.InferenceProviderVertexAI {
				providerConfig.DefaultLargeModel = p.DefaultLargeModelID
				providerConfig.DefaultSmallModel = p.DefaultSmallModelID
				for _, model := range p.Models {
					configModel := Model{
						ID:                 model.ID,
						Name:               model.Name,
						CostPer1MIn:        model.CostPer1MIn,
						CostPer1MOut:       model.CostPer1MOut,
						CostPer1MInCached:  model.CostPer1MInCached,
						CostPer1MOutCached: model.CostPer1MOutCached,
						ContextWindow:      model.ContextWindow,
						DefaultMaxTokens:   model.DefaultMaxTokens,
						CanReason:          model.CanReason,
						SupportsImages:     model.SupportsImages,
					}
					// Set reasoning effort for reasoning models
					if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
						configModel.HasReasoningEffort = model.HasReasoningEffort
						configModel.ReasoningEffort = model.DefaultReasoningEffort
					}
					providerConfig.Models = append(providerConfig.Models, configModel)
				}
				break
			}
		}
		cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
	}

	// Bedrock uses ambient AWS credentials; region falls back from
	// AWS_DEFAULT_REGION to AWS_REGION.
	if hasAWSCredentials() {
		providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
		providerConfig.ExtraParams = map[string]string{
			"region": os.Getenv("AWS_DEFAULT_REGION"),
		}
		if providerConfig.ExtraParams["region"] == "" {
			providerConfig.ExtraParams["region"] = os.Getenv("AWS_REGION")
		}
		// Find the Bedrock provider definition to get default models
		for _, p := range providers {
			if p.ID == provider.InferenceProviderBedrock {
				providerConfig.DefaultLargeModel = p.DefaultLargeModelID
				providerConfig.DefaultSmallModel = p.DefaultSmallModelID
				for _, model := range p.Models {
					configModel := Model{
						ID:                 model.ID,
						Name:               model.Name,
						CostPer1MIn:        model.CostPer1MIn,
						CostPer1MOut:       model.CostPer1MOut,
						CostPer1MInCached:  model.CostPer1MInCached,
						CostPer1MOutCached: model.CostPer1MOutCached,
						ContextWindow:      model.ContextWindow,
						DefaultMaxTokens:   model.DefaultMaxTokens,
						CanReason:          model.CanReason,
						SupportsImages:     model.SupportsImages,
					}
					// Set reasoning effort for reasoning models
					if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
						configModel.HasReasoningEffort = model.HasReasoningEffort
						configModel.ReasoningEffort = model.DefaultReasoningEffort
					}
					providerConfig.Models = append(providerConfig.Models, configModel)
				}
				break
			}
		}
		cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
	}
	return cfg
}
802
// hasAWSCredentials reports whether any of the standard AWS credential or
// configuration signals are present in the environment: a full access-key
// pair, a named profile, a region, or container credential endpoints.
func hasAWSCredentials() bool {
	anySet := func(names ...string) bool {
		for _, name := range names {
			if os.Getenv(name) != "" {
				return true
			}
		}
		return false
	}

	switch {
	case os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "":
		return true
	case anySet("AWS_PROFILE", "AWS_DEFAULT_PROFILE"):
		return true
	case anySet("AWS_REGION", "AWS_DEFAULT_REGION"):
		return true
	case anySet("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "AWS_CONTAINER_CREDENTIALS_FULL_URI"):
		return true
	default:
		return false
	}
}
823
// WorkingDirectory returns the working directory captured by Init.
func WorkingDirectory() string {
	return cwd
}
827
828// TODO: Handle error state
829
830func GetAgentModel(agentID AgentID) Model {
831 cfg := Get()
832 agent, ok := cfg.Agents[agentID]
833 if !ok {
834 logging.Error("Agent not found", "agent_id", agentID)
835 return Model{}
836 }
837
838 var model PreferredModel
839 switch agent.Model {
840 case LargeModel:
841 model = cfg.Models.Large
842 case SmallModel:
843 model = cfg.Models.Small
844 default:
845 logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
846 model = cfg.Models.Large // Fallback to large model
847 }
848 providerConfig, ok := cfg.Providers[model.Provider]
849 if !ok {
850 logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
851 return Model{}
852 }
853
854 for _, m := range providerConfig.Models {
855 if m.ID == model.ModelID {
856 return m
857 }
858 }
859
860 logging.Error("Model not found for agent", "agent_id", agentID, "model", agent.Model)
861 return Model{}
862}
863
864// GetAgentEffectiveMaxTokens returns the effective max tokens for an agent,
865// considering any overrides from the preferred model configuration
866func GetAgentEffectiveMaxTokens(agentID AgentID) int64 {
867 cfg := Get()
868 agent, ok := cfg.Agents[agentID]
869 if !ok {
870 logging.Error("Agent not found", "agent_id", agentID)
871 return 0
872 }
873
874 var preferredModel PreferredModel
875 switch agent.Model {
876 case LargeModel:
877 preferredModel = cfg.Models.Large
878 case SmallModel:
879 preferredModel = cfg.Models.Small
880 default:
881 logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
882 preferredModel = cfg.Models.Large // Fallback to large model
883 }
884
885 // Get the base model configuration
886 baseModel := GetAgentModel(agentID)
887 if baseModel.ID == "" {
888 return 0
889 }
890
891 // Start with the default max tokens from the base model
892 maxTokens := baseModel.DefaultMaxTokens
893
894 // Override with preferred model max tokens if set
895 if preferredModel.MaxTokens > 0 {
896 maxTokens = preferredModel.MaxTokens
897 }
898
899 return maxTokens
900}
901
902func GetAgentProvider(agentID AgentID) ProviderConfig {
903 cfg := Get()
904 agent, ok := cfg.Agents[agentID]
905 if !ok {
906 logging.Error("Agent not found", "agent_id", agentID)
907 return ProviderConfig{}
908 }
909
910 var model PreferredModel
911 switch agent.Model {
912 case LargeModel:
913 model = cfg.Models.Large
914 case SmallModel:
915 model = cfg.Models.Small
916 default:
917 logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
918 model = cfg.Models.Large // Fallback to large model
919 }
920
921 providerConfig, ok := cfg.Providers[model.Provider]
922 if !ok {
923 logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
924 return ProviderConfig{}
925 }
926
927 return providerConfig
928}
929
930func GetProviderModel(provider provider.InferenceProvider, modelID string) Model {
931 cfg := Get()
932 providerConfig, ok := cfg.Providers[provider]
933 if !ok {
934 logging.Error("Provider not found", "provider", provider)
935 return Model{}
936 }
937
938 for _, model := range providerConfig.Models {
939 if model.ID == modelID {
940 return model
941 }
942 }
943
944 logging.Error("Model not found for provider", "provider", provider, "model_id", modelID)
945 return Model{}
946}
947
948func GetModel(modelType ModelType) Model {
949 cfg := Get()
950 var model PreferredModel
951 switch modelType {
952 case LargeModel:
953 model = cfg.Models.Large
954 case SmallModel:
955 model = cfg.Models.Small
956 default:
957 model = cfg.Models.Large // Fallback to large model
958 }
959 providerConfig, ok := cfg.Providers[model.Provider]
960 if !ok {
961 return Model{}
962 }
963
964 for _, m := range providerConfig.Models {
965 if m.ID == model.ModelID {
966 return m
967 }
968 }
969 return Model{}
970}
971
972func UpdatePreferredModel(modelType ModelType, model PreferredModel) error {
973 cfg := Get()
974 switch modelType {
975 case LargeModel:
976 cfg.Models.Large = model
977 case SmallModel:
978 cfg.Models.Small = model
979 default:
980 return fmt.Errorf("unknown model type: %s", modelType)
981 }
982 return nil
983}
984
// ValidationError describes a single configuration validation failure,
// identifying the offending field path and a human-readable message.
type ValidationError struct {
	Field   string
	Message string
}

// Error implements the error interface.
func (e ValidationError) Error() string {
	return "validation error in " + e.Field + ": " + e.Message
}

// ValidationErrors collects every validation failure found in a config.
type ValidationErrors []ValidationError

// Error implements the error interface, rendering either a placeholder,
// the single error, or all errors joined with semicolons.
func (e ValidationErrors) Error() string {
	switch len(e) {
	case 0:
		return "no validation errors"
	case 1:
		return e[0].Error()
	}

	messages := make([]string, 0, len(e))
	for _, ve := range e {
		messages = append(messages, ve.Error())
	}
	return fmt.Sprintf("multiple validation errors: %s", strings.Join(messages, "; "))
}

// HasErrors reports whether at least one validation error was recorded.
func (e ValidationErrors) HasErrors() bool {
	return len(e) != 0
}

// Add appends a new validation error for the given field.
func (e *ValidationErrors) Add(field, message string) {
	*e = append(*e, ValidationError{Field: field, Message: message})
}
1022
1023// Validate performs comprehensive validation of the configuration
1024func (c *Config) Validate() error {
1025 var errors ValidationErrors
1026
1027 // Validate providers
1028 c.validateProviders(&errors)
1029
1030 // Validate models
1031 c.validateModels(&errors)
1032
1033 // Validate agents
1034 c.validateAgents(&errors)
1035
1036 // Validate options
1037 c.validateOptions(&errors)
1038
1039 // Validate MCP configurations
1040 c.validateMCPs(&errors)
1041
1042 // Validate LSP configurations
1043 c.validateLSPs(&errors)
1044
1045 // Validate cross-references
1046 c.validateCrossReferences(&errors)
1047
1048 // Validate completeness
1049 c.validateCompleteness(&errors)
1050
1051 if errors.HasErrors() {
1052 return errors
1053 }
1054
1055 return nil
1056}
1057
// validateProviders validates all provider configurations.
// NOTE: as a side effect it initializes c.Providers when nil so later
// validation passes can index into the map safely.
func (c *Config) validateProviders(errors *ValidationErrors) {
	if c.Providers == nil {
		c.Providers = make(map[provider.InferenceProvider]ProviderConfig)
	}

	knownProviders := provider.KnownProviders()
	// Every provider back-end type this application understands.
	validTypes := []provider.Type{
		provider.TypeOpenAI,
		provider.TypeAnthropic,
		provider.TypeGemini,
		provider.TypeAzure,
		provider.TypeBedrock,
		provider.TypeVertexAI,
		provider.TypeXAI,
	}

	for providerID, providerConfig := range c.Providers {
		fieldPrefix := fmt.Sprintf("providers.%s", providerID)

		// Validate API key for non-disabled providers
		if !providerConfig.Disabled && providerConfig.APIKey == "" {
			// Special case for AWS Bedrock and VertexAI which may use other auth methods
			// (their credentials are checked separately in validateProviderSpecific).
			if providerID != provider.InferenceProviderBedrock && providerID != provider.InferenceProviderVertexAI {
				errors.Add(fieldPrefix+".api_key", "API key is required for non-disabled providers")
			}
		}

		// Validate provider type
		validType := slices.Contains(validTypes, providerConfig.ProviderType)
		if !validType {
			errors.Add(fieldPrefix+".provider_type", fmt.Sprintf("invalid provider type: %s", providerConfig.ProviderType))
		}

		// Validate custom providers
		// (anything not in the known-provider catalog is user-defined).
		isKnownProvider := slices.Contains(knownProviders, providerID)

		if !isKnownProvider {
			// Custom provider validation
			if providerConfig.BaseURL == "" {
				errors.Add(fieldPrefix+".base_url", "BaseURL is required for custom providers")
			}
			if providerConfig.ProviderType != provider.TypeOpenAI {
				errors.Add(fieldPrefix+".provider_type", "custom providers currently only support OpenAI type")
			}
		}

		// Validate models
		// modelIDs doubles as a duplicate detector here and as the lookup
		// set for the default large/small model references below.
		modelIDs := make(map[string]bool)
		for i, model := range providerConfig.Models {
			modelFieldPrefix := fmt.Sprintf("%s.models[%d]", fieldPrefix, i)

			// Check for duplicate model IDs
			// NOTE(review): an empty ID is recorded too, so two models with
			// empty IDs report a duplicate in addition to "ID is required".
			if modelIDs[model.ID] {
				errors.Add(modelFieldPrefix+".id", fmt.Sprintf("duplicate model ID: %s", model.ID))
			}
			modelIDs[model.ID] = true

			// Validate required model fields
			if model.ID == "" {
				errors.Add(modelFieldPrefix+".id", "model ID is required")
			}
			if model.Name == "" {
				errors.Add(modelFieldPrefix+".name", "model name is required")
			}
			if model.ContextWindow <= 0 {
				errors.Add(modelFieldPrefix+".context_window", "context window must be positive")
			}
			if model.DefaultMaxTokens <= 0 {
				errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens must be positive")
			}
			if model.DefaultMaxTokens > model.ContextWindow {
				errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens cannot exceed context window")
			}

			// Validate cost fields
			// (all costs are per 1M tokens and must be non-negative).
			if model.CostPer1MIn < 0 {
				errors.Add(modelFieldPrefix+".cost_per_1m_in", "cost per 1M input tokens cannot be negative")
			}
			if model.CostPer1MOut < 0 {
				errors.Add(modelFieldPrefix+".cost_per_1m_out", "cost per 1M output tokens cannot be negative")
			}
			if model.CostPer1MInCached < 0 {
				errors.Add(modelFieldPrefix+".cost_per_1m_in_cached", "cached cost per 1M input tokens cannot be negative")
			}
			if model.CostPer1MOutCached < 0 {
				errors.Add(modelFieldPrefix+".cost_per_1m_out_cached", "cached cost per 1M output tokens cannot be negative")
			}
		}

		// Validate default model references
		// (defaults are optional, but when set they must resolve to a model).
		if providerConfig.DefaultLargeModel != "" {
			if !modelIDs[providerConfig.DefaultLargeModel] {
				errors.Add(fieldPrefix+".default_large_model", fmt.Sprintf("default large model '%s' not found in provider models", providerConfig.DefaultLargeModel))
			}
		}
		if providerConfig.DefaultSmallModel != "" {
			if !modelIDs[providerConfig.DefaultSmallModel] {
				errors.Add(fieldPrefix+".default_small_model", fmt.Sprintf("default small model '%s' not found in provider models", providerConfig.DefaultSmallModel))
			}
		}

		// Validate provider-specific requirements
		c.validateProviderSpecific(providerID, providerConfig, errors)
	}
}
1164
1165// validateProviderSpecific validates provider-specific requirements
1166func (c *Config) validateProviderSpecific(providerID provider.InferenceProvider, providerConfig ProviderConfig, errors *ValidationErrors) {
1167 fieldPrefix := fmt.Sprintf("providers.%s", providerID)
1168
1169 switch providerID {
1170 case provider.InferenceProviderVertexAI:
1171 if !providerConfig.Disabled {
1172 if providerConfig.ExtraParams == nil {
1173 errors.Add(fieldPrefix+".extra_params", "VertexAI requires extra_params configuration")
1174 } else {
1175 if providerConfig.ExtraParams["project"] == "" {
1176 errors.Add(fieldPrefix+".extra_params.project", "VertexAI requires project parameter")
1177 }
1178 if providerConfig.ExtraParams["location"] == "" {
1179 errors.Add(fieldPrefix+".extra_params.location", "VertexAI requires location parameter")
1180 }
1181 }
1182 }
1183 case provider.InferenceProviderBedrock:
1184 if !providerConfig.Disabled {
1185 if providerConfig.ExtraParams == nil || providerConfig.ExtraParams["region"] == "" {
1186 errors.Add(fieldPrefix+".extra_params.region", "Bedrock requires region parameter")
1187 }
1188 // Check for AWS credentials in environment
1189 if !hasAWSCredentials() {
1190 errors.Add(fieldPrefix, "Bedrock requires AWS credentials in environment")
1191 }
1192 }
1193 }
1194}
1195
1196// validateModels validates preferred model configurations
1197func (c *Config) validateModels(errors *ValidationErrors) {
1198 // Validate large model
1199 if c.Models.Large.ModelID != "" || c.Models.Large.Provider != "" {
1200 if c.Models.Large.ModelID == "" {
1201 errors.Add("models.large.model_id", "large model ID is required when provider is set")
1202 }
1203 if c.Models.Large.Provider == "" {
1204 errors.Add("models.large.provider", "large model provider is required when model ID is set")
1205 }
1206
1207 // Check if provider exists and is not disabled
1208 if providerConfig, exists := c.Providers[c.Models.Large.Provider]; exists {
1209 if providerConfig.Disabled {
1210 errors.Add("models.large.provider", "large model provider is disabled")
1211 }
1212
1213 // Check if model exists in provider
1214 modelExists := false
1215 for _, model := range providerConfig.Models {
1216 if model.ID == c.Models.Large.ModelID {
1217 modelExists = true
1218 break
1219 }
1220 }
1221 if !modelExists {
1222 errors.Add("models.large.model_id", fmt.Sprintf("large model '%s' not found in provider '%s'", c.Models.Large.ModelID, c.Models.Large.Provider))
1223 }
1224 } else {
1225 errors.Add("models.large.provider", fmt.Sprintf("large model provider '%s' not found", c.Models.Large.Provider))
1226 }
1227 }
1228
1229 // Validate small model
1230 if c.Models.Small.ModelID != "" || c.Models.Small.Provider != "" {
1231 if c.Models.Small.ModelID == "" {
1232 errors.Add("models.small.model_id", "small model ID is required when provider is set")
1233 }
1234 if c.Models.Small.Provider == "" {
1235 errors.Add("models.small.provider", "small model provider is required when model ID is set")
1236 }
1237
1238 // Check if provider exists and is not disabled
1239 if providerConfig, exists := c.Providers[c.Models.Small.Provider]; exists {
1240 if providerConfig.Disabled {
1241 errors.Add("models.small.provider", "small model provider is disabled")
1242 }
1243
1244 // Check if model exists in provider
1245 modelExists := false
1246 for _, model := range providerConfig.Models {
1247 if model.ID == c.Models.Small.ModelID {
1248 modelExists = true
1249 break
1250 }
1251 }
1252 if !modelExists {
1253 errors.Add("models.small.model_id", fmt.Sprintf("small model '%s' not found in provider '%s'", c.Models.Small.ModelID, c.Models.Small.Provider))
1254 }
1255 } else {
1256 errors.Add("models.small.provider", fmt.Sprintf("small model provider '%s' not found", c.Models.Small.Provider))
1257 }
1258 }
1259}
1260
// validateAgents validates agent configurations.
// NOTE: as a side effect it initializes c.Agents when nil.
func (c *Config) validateAgents(errors *ValidationErrors) {
	if c.Agents == nil {
		c.Agents = make(map[AgentID]Agent)
	}

	// The complete set of built-in tool names an agent may reference.
	validTools := []string{
		"bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent", "definitions", "diagnostics",
	}

	for agentID, agent := range c.Agents {
		fieldPrefix := fmt.Sprintf("agents.%s", agentID)

		// Validate agent ID consistency
		// (the map key and the embedded ID field must agree).
		if agent.ID != agentID {
			errors.Add(fieldPrefix+".id", fmt.Sprintf("agent ID mismatch: expected '%s', got '%s'", agentID, agent.ID))
		}

		// Validate required fields
		if agent.ID == "" {
			errors.Add(fieldPrefix+".id", "agent ID is required")
		}
		if agent.Name == "" {
			errors.Add(fieldPrefix+".name", "agent name is required")
		}

		// Validate model type
		if agent.Model != LargeModel && agent.Model != SmallModel {
			errors.Add(fieldPrefix+".model", fmt.Sprintf("invalid model type: %s (must be 'large' or 'small')", agent.Model))
		}

		// Validate allowed tools
		// (nil means "default tool set" and is accepted as-is).
		if agent.AllowedTools != nil {
			for i, tool := range agent.AllowedTools {
				validTool := slices.Contains(validTools, tool)
				if !validTool {
					errors.Add(fmt.Sprintf("%s.allowed_tools[%d]", fieldPrefix, i), fmt.Sprintf("unknown tool: %s", tool))
				}
			}
		}

		// Validate MCP references
		// (every MCP an agent may use must exist in c.MCP).
		if agent.AllowedMCP != nil {
			for mcpName := range agent.AllowedMCP {
				if _, exists := c.MCP[mcpName]; !exists {
					errors.Add(fieldPrefix+".allowed_mcp", fmt.Sprintf("referenced MCP '%s' not found", mcpName))
				}
			}
		}

		// Validate LSP references
		// (every LSP an agent may use must exist in c.LSP).
		if agent.AllowedLSP != nil {
			for _, lspName := range agent.AllowedLSP {
				if _, exists := c.LSP[lspName]; !exists {
					errors.Add(fieldPrefix+".allowed_lsp", fmt.Sprintf("referenced LSP '%s' not found", lspName))
				}
			}
		}

		// Validate context paths (basic path validation)
		for i, contextPath := range agent.ContextPaths {
			if contextPath == "" {
				errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path cannot be empty")
			}
			// Check for invalid characters in path
			// (an embedded NUL byte is never valid in a file path).
			if strings.Contains(contextPath, "\x00") {
				errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path contains invalid characters")
			}
		}

		// Validate known agents maintain their core properties
		// (the built-in coder/task agents are not user-customizable).
		if agentID == AgentCoder {
			if agent.Name != "Coder" {
				errors.Add(fieldPrefix+".name", "coder agent name cannot be changed")
			}
			if agent.Description != "An agent that helps with executing coding tasks." {
				errors.Add(fieldPrefix+".description", "coder agent description cannot be changed")
			}
		} else if agentID == AgentTask {
			if agent.Name != "Task" {
				errors.Add(fieldPrefix+".name", "task agent name cannot be changed")
			}
			if agent.Description != "An agent that helps with searching for context and finding implementation details." {
				errors.Add(fieldPrefix+".description", "task agent description cannot be changed")
			}
			// The task agent's tool set is fixed; nil means defaults and passes.
			expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view", "definitions"}
			if agent.AllowedTools != nil && !slices.Equal(agent.AllowedTools, expectedTools) {
				errors.Add(fieldPrefix+".allowed_tools", "task agent allowed tools cannot be changed")
			}
		}
	}
}
1353
1354// validateOptions validates configuration options
1355func (c *Config) validateOptions(errors *ValidationErrors) {
1356 // Validate data directory
1357 if c.Options.DataDirectory == "" {
1358 errors.Add("options.data_directory", "data directory is required")
1359 }
1360
1361 // Validate context paths
1362 for i, contextPath := range c.Options.ContextPaths {
1363 if contextPath == "" {
1364 errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path cannot be empty")
1365 }
1366 if strings.Contains(contextPath, "\x00") {
1367 errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path contains invalid characters")
1368 }
1369 }
1370}
1371
1372// validateMCPs validates MCP configurations
1373func (c *Config) validateMCPs(errors *ValidationErrors) {
1374 if c.MCP == nil {
1375 c.MCP = make(map[string]MCP)
1376 }
1377
1378 for mcpName, mcpConfig := range c.MCP {
1379 fieldPrefix := fmt.Sprintf("mcp.%s", mcpName)
1380
1381 // Validate MCP type
1382 if mcpConfig.Type != MCPStdio && mcpConfig.Type != MCPSse {
1383 errors.Add(fieldPrefix+".type", fmt.Sprintf("invalid MCP type: %s (must be 'stdio' or 'sse')", mcpConfig.Type))
1384 }
1385
1386 // Validate based on type
1387 if mcpConfig.Type == MCPStdio {
1388 if mcpConfig.Command == "" {
1389 errors.Add(fieldPrefix+".command", "command is required for stdio MCP")
1390 }
1391 } else if mcpConfig.Type == MCPSse {
1392 if mcpConfig.URL == "" {
1393 errors.Add(fieldPrefix+".url", "URL is required for SSE MCP")
1394 }
1395 }
1396 }
1397}
1398
1399// validateLSPs validates LSP configurations
1400func (c *Config) validateLSPs(errors *ValidationErrors) {
1401 if c.LSP == nil {
1402 c.LSP = make(map[string]LSPConfig)
1403 }
1404
1405 for lspName, lspConfig := range c.LSP {
1406 fieldPrefix := fmt.Sprintf("lsp.%s", lspName)
1407
1408 if lspConfig.Command == "" {
1409 errors.Add(fieldPrefix+".command", "command is required for LSP")
1410 }
1411 }
1412}
1413
1414// validateCrossReferences validates cross-references between different config sections
1415func (c *Config) validateCrossReferences(errors *ValidationErrors) {
1416 // Validate that agents can use their assigned model types
1417 for agentID, agent := range c.Agents {
1418 fieldPrefix := fmt.Sprintf("agents.%s", agentID)
1419
1420 var preferredModel PreferredModel
1421 switch agent.Model {
1422 case LargeModel:
1423 preferredModel = c.Models.Large
1424 case SmallModel:
1425 preferredModel = c.Models.Small
1426 }
1427
1428 if preferredModel.Provider != "" {
1429 if providerConfig, exists := c.Providers[preferredModel.Provider]; exists {
1430 if providerConfig.Disabled {
1431 errors.Add(fieldPrefix+".model", fmt.Sprintf("agent cannot use model type '%s' because provider '%s' is disabled", agent.Model, preferredModel.Provider))
1432 }
1433 }
1434 }
1435 }
1436}
1437
1438// validateCompleteness validates that the configuration is complete and usable
1439func (c *Config) validateCompleteness(errors *ValidationErrors) {
1440 // Check for at least one valid, non-disabled provider
1441 hasValidProvider := false
1442 for _, providerConfig := range c.Providers {
1443 if !providerConfig.Disabled {
1444 hasValidProvider = true
1445 break
1446 }
1447 }
1448 if !hasValidProvider {
1449 errors.Add("providers", "at least one non-disabled provider is required")
1450 }
1451
1452 // Check that default agents exist
1453 if _, exists := c.Agents[AgentCoder]; !exists {
1454 errors.Add("agents", "coder agent is required")
1455 }
1456 if _, exists := c.Agents[AgentTask]; !exists {
1457 errors.Add("agents", "task agent is required")
1458 }
1459
1460 // Check that preferred models are set if providers exist
1461 if hasValidProvider {
1462 if c.Models.Large.ModelID == "" || c.Models.Large.Provider == "" {
1463 errors.Add("models.large", "large preferred model must be configured when providers are available")
1464 }
1465 if c.Models.Small.ModelID == "" || c.Models.Small.Provider == "" {
1466 errors.Add("models.small", "small preferred model must be configured when providers are available")
1467 }
1468 }
1469}
1470
1471// JSONSchemaExtend adds custom schema properties for AgentID
1472func (AgentID) JSONSchemaExtend(schema *jsonschema.Schema) {
1473 schema.Enum = []any{
1474 string(AgentCoder),
1475 string(AgentTask),
1476 }
1477}
1478
1479// JSONSchemaExtend adds custom schema properties for ModelType
1480func (ModelType) JSONSchemaExtend(schema *jsonschema.Schema) {
1481 schema.Enum = []any{
1482 string(LargeModel),
1483 string(SmallModel),
1484 }
1485}
1486
1487// JSONSchemaExtend adds custom schema properties for MCPType
1488func (MCPType) JSONSchemaExtend(schema *jsonschema.Schema) {
1489 schema.Enum = []any{
1490 string(MCPStdio),
1491 string(MCPSse),
1492 }
1493}