package config

import (
    "encoding/json"
    "errors"
    "fmt"
    "log/slog"
    "maps"
    "os"
    "path/filepath"
    "slices"
    "strings"
    "sync"

    "github.com/charmbracelet/crush/internal/fur/provider"
    "github.com/charmbracelet/crush/internal/logging"
    "github.com/invopop/jsonschema"
)

const (
    defaultDataDirectory = ".crush"
    defaultLogLevel = "info"
    appName = "crush"

    MaxTokensFallbackDefault = 4096
)

var defaultContextPaths = []string{
    ".github/copilot-instructions.md",
    ".cursorrules",
    ".cursor/rules/",
    "CLAUDE.md",
    "CLAUDE.local.md",
    "GEMINI.md",
    "gemini.md",
    "crush.md",
    "crush.local.md",
    "Crush.md",
    "Crush.local.md",
    "CRUSH.md",
    "CRUSH.local.md",
}

type AgentID string

const (
    AgentCoder AgentID = "coder"
    AgentTask AgentID = "task"
)

type ModelType string

const (
    LargeModel ModelType = "large"
    SmallModel ModelType = "small"
)

type Model struct {
    ID string `json:"id" jsonschema:"title=Model ID,description=Unique identifier for the model (the API model ID)"`
    Name string `json:"name" jsonschema:"title=Model Name,description=Display name of the model"`
    CostPer1MIn float64 `json:"cost_per_1m_in,omitempty" jsonschema:"title=Input Cost,description=Cost per 1 million input tokens,minimum=0"`
    CostPer1MOut float64 `json:"cost_per_1m_out,omitempty" jsonschema:"title=Output Cost,description=Cost per 1 million output tokens,minimum=0"`
    CostPer1MInCached float64 `json:"cost_per_1m_in_cached,omitempty" jsonschema:"title=Cached Input Cost,description=Cost per 1 million cached input tokens,minimum=0"`
    CostPer1MOutCached float64 `json:"cost_per_1m_out_cached,omitempty" jsonschema:"title=Cached Output Cost,description=Cost per 1 million cached output tokens,minimum=0"`
    ContextWindow int64 `json:"context_window" jsonschema:"title=Context Window,description=Maximum context window size in tokens,minimum=1"`
    DefaultMaxTokens int64 `json:"default_max_tokens" jsonschema:"title=Default Max Tokens,description=Default maximum tokens for responses,minimum=1"`
    CanReason bool `json:"can_reason,omitempty" jsonschema:"title=Can Reason,description=Whether the model supports reasoning capabilities"`
    ReasoningEffort string `json:"reasoning_effort,omitempty" jsonschema:"title=Reasoning Effort,description=Default reasoning effort level for reasoning models"`
    HasReasoningEffort bool `json:"has_reasoning_effort,omitempty" jsonschema:"title=Has Reasoning Effort,description=Whether the model supports reasoning effort configuration"`
    SupportsImages bool `json:"supports_attachments,omitempty" jsonschema:"title=Supports Images,description=Whether the model supports image attachments"`
}

type VertexAIOptions struct {
    APIKey string `json:"api_key,omitempty"`
    Project string `json:"project,omitempty"`
    Location string `json:"location,omitempty"`
}

type ProviderConfig struct {
    ID provider.InferenceProvider `json:"id,omitempty" jsonschema:"title=Provider ID,description=Unique identifier for the provider"`
    BaseURL string `json:"base_url,omitempty" jsonschema:"title=Base URL,description=Base URL for the provider API (required for custom providers)"`
    ProviderType provider.Type `json:"provider_type" jsonschema:"title=Provider Type,description=Type of the provider (e.g. openai or anthropic)"`
    APIKey string `json:"api_key,omitempty" jsonschema:"title=API Key,description=API key for authenticating with the provider"`
    Disabled bool `json:"disabled,omitempty" jsonschema:"title=Disabled,description=Whether this provider is disabled,default=false"`
    ExtraHeaders map[string]string `json:"extra_headers,omitempty" jsonschema:"title=Extra Headers,description=Additional HTTP headers to send with requests"`
    // Used e.g. by Vertex AI to set the project and location.
    ExtraParams map[string]string `json:"extra_params,omitempty" jsonschema:"title=Extra Parameters,description=Additional provider-specific parameters"`

    DefaultLargeModel string `json:"default_large_model,omitempty" jsonschema:"title=Default Large Model,description=Default model ID for large model type"`
    DefaultSmallModel string `json:"default_small_model,omitempty" jsonschema:"title=Default Small Model,description=Default model ID for small model type"`

    Models []Model `json:"models,omitempty" jsonschema:"title=Models,description=List of available models for this provider"`
}
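
// Illustrative sketch of a custom provider entry as it could appear under
// "providers" in crush.json. The provider name, URL, key, and model are
// hypothetical; field names follow the json tags on ProviderConfig and Model:
//
//  "providers": {
//    "my-proxy": {
//      "provider_type": "openai",
//      "base_url": "https://llm.example.com/v1",
//      "api_key": "sk-...",
//      "models": [
//        {"id": "my-model", "name": "My Model", "context_window": 128000, "default_max_tokens": 4096}
//      ]
//    }
//  }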

type Agent struct {
    ID AgentID `json:"id,omitempty" jsonschema:"title=Agent ID,description=Unique identifier for the agent,enum=coder,enum=task"`
    Name string `json:"name,omitempty" jsonschema:"title=Name,description=Display name of the agent"`
    Description string `json:"description,omitempty" jsonschema:"title=Description,description=Description of what the agent does"`
    Disabled bool `json:"disabled,omitempty" jsonschema:"title=Disabled,description=Whether this agent is disabled,default=false"`

    Model ModelType `json:"model" jsonschema:"title=Model Type,description=Type of model to use (large or small),enum=large,enum=small"`

    // The tools available to the agent.
    // If nil, all tools are available.
    AllowedTools []string `json:"allowed_tools,omitempty" jsonschema:"title=Allowed Tools,description=List of tools this agent is allowed to use (if nil all tools are allowed)"`

    // The MCP servers available to this agent.
    // If the map is empty, all MCPs are available.
    // The string slice lists the tools the agent may use from that MCP;
    // if the slice is nil, all tools from that MCP are available.
    AllowedMCP map[string][]string `json:"allowed_mcp,omitempty" jsonschema:"title=Allowed MCP,description=Map of MCP servers this agent can use and their allowed tools"`

    // The LSP servers available to this agent.
    // If nil, all LSPs are available.
    AllowedLSP []string `json:"allowed_lsp,omitempty" jsonschema:"title=Allowed LSP,description=List of LSP servers this agent can use (if nil all LSPs are allowed)"`

    // Additional context paths for this agent (additive to the global context paths).
    ContextPaths []string `json:"context_paths,omitempty" jsonschema:"title=Context Paths,description=Custom context paths for this agent (additive to global context paths)"`
}
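
// Illustrative sketch of an agent override under "agents" in crush.json
// (the MCP name and context path are hypothetical; per the comments above, a
// null tool list means all tools from that MCP are allowed):
//
//  "agents": {
//    "coder": {
//      "model": "large",
//      "allowed_mcp": {"filesystem": null},
//      "context_paths": ["docs/ARCHITECTURE.md"]
//    }
//  }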

type MCPType string

const (
    MCPStdio MCPType = "stdio"
    MCPSse MCPType = "sse"
)

type MCP struct {
    Command string `json:"command" jsonschema:"title=Command,description=Command to execute for stdio MCP servers"`
    Env []string `json:"env,omitempty" jsonschema:"title=Environment,description=Environment variables for the MCP server"`
    Args []string `json:"args,omitempty" jsonschema:"title=Arguments,description=Command line arguments for the MCP server"`
    Type MCPType `json:"type" jsonschema:"title=Type,description=Type of MCP connection,enum=stdio,enum=sse,default=stdio"`
    URL string `json:"url,omitempty" jsonschema:"title=URL,description=URL for SSE MCP servers"`
    // TODO: maybe make it possible to get the value from the env
    Headers map[string]string `json:"headers,omitempty" jsonschema:"title=Headers,description=HTTP headers for SSE MCP servers"`
}
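
// Illustrative sketch of MCP entries in crush.json, one per connection type
// (commands, URLs, and headers are hypothetical):
//
//  "mcp": {
//    "filesystem": {"type": "stdio", "command": "mcp-fs", "args": ["--root", "."]},
//    "remote": {"type": "sse", "url": "https://mcp.example.com/sse", "headers": {"Authorization": "Bearer ..."}}
//  }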

type LSPConfig struct {
    Disabled bool `json:"disabled,omitempty" jsonschema:"title=Disabled,description=Whether this LSP server is disabled,default=false"`
    Command string `json:"command" jsonschema:"title=Command,description=Command to execute for the LSP server"`
    Args []string `json:"args,omitempty" jsonschema:"title=Arguments,description=Command line arguments for the LSP server"`
    Options any `json:"options,omitempty" jsonschema:"title=Options,description=LSP server specific options"`
}
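
// Illustrative sketch of LSP entries in crush.json (the server commands are
// examples, not defaults shipped by this package):
//
//  "lsp": {
//    "go": {"command": "gopls"},
//    "typescript": {"command": "typescript-language-server", "args": ["--stdio"]}
//  }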

type TUIOptions struct {
    CompactMode bool `json:"compact_mode" jsonschema:"title=Compact Mode,description=Enable compact mode for the TUI,default=false"`
    // Here we can add themes later or any TUI related options
}

type Options struct {
    ContextPaths []string `json:"context_paths,omitempty" jsonschema:"title=Context Paths,description=List of paths to search for context files"`
    TUI TUIOptions `json:"tui,omitempty" jsonschema:"title=TUI Options,description=Terminal UI configuration options"`
    Debug bool `json:"debug,omitempty" jsonschema:"title=Debug,description=Enable debug logging,default=false"`
    DebugLSP bool `json:"debug_lsp,omitempty" jsonschema:"title=Debug LSP,description=Enable LSP debug logging,default=false"`
    DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty" jsonschema:"title=Disable Auto Summarize,description=Disable automatic conversation summarization,default=false"`
    // Relative to the cwd
    DataDirectory string `json:"data_directory,omitempty" jsonschema:"title=Data Directory,description=Directory for storing application data,default=.crush"`
}

type PreferredModel struct {
    ModelID string `json:"model_id" jsonschema:"title=Model ID,description=ID of the preferred model"`
    Provider provider.InferenceProvider `json:"provider" jsonschema:"title=Provider,description=Provider for the preferred model"`
    // ReasoningEffort overrides the default reasoning effort for this model
    ReasoningEffort string `json:"reasoning_effort,omitempty" jsonschema:"title=Reasoning Effort,description=Override reasoning effort for this model"`
    // MaxTokens overrides the default max tokens for this model
    MaxTokens int64 `json:"max_tokens,omitempty" jsonschema:"title=Max Tokens,description=Override max tokens for this model,minimum=1"`

    // Think indicates if the model should think, only applicable for anthropic reasoning models
    Think bool `json:"think,omitempty" jsonschema:"title=Think,description=Enable thinking for reasoning models,default=false"`
}

type PreferredModels struct {
    Large PreferredModel `json:"large,omitempty" jsonschema:"title=Large Model,description=Preferred model configuration for large model type"`
    Small PreferredModel `json:"small,omitempty" jsonschema:"title=Small Model,description=Preferred model configuration for small model type"`
}
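
// Illustrative sketch of the "models" section in crush.json; the provider and
// model IDs are hypothetical and must match an entry under "providers":
//
//  "models": {
//    "large": {"provider": "anthropic", "model_id": "claude-...", "max_tokens": 8192, "think": true},
//    "small": {"provider": "anthropic", "model_id": "claude-..."}
//  }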

type Config struct {
    Models PreferredModels `json:"models,omitempty" jsonschema:"title=Models,description=Preferred model configurations for large and small model types"`
    // List of configured providers
    Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty" jsonschema:"title=Providers,description=LLM provider configurations"`

    // List of configured agents
    Agents map[AgentID]Agent `json:"agents,omitempty" jsonschema:"title=Agents,description=Agent configurations for different tasks"`

    // List of configured MCPs
    MCP map[string]MCP `json:"mcp,omitempty" jsonschema:"title=MCP,description=Model Context Protocol server configurations"`

    // List of configured LSPs
    LSP map[string]LSPConfig `json:"lsp,omitempty" jsonschema:"title=LSP,description=Language Server Protocol configurations"`

    // Miscellaneous options
    Options Options `json:"options,omitempty" jsonschema:"title=Options,description=General application options and settings"`
}
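
// Minimal illustrative crush.json tying the sections above together (values
// are hypothetical). The global config file is read first and ./crush.json in
// the working directory is merged on top of it:
//
//  {
//    "providers": {
//      "openai": {"api_key": "sk-..."}
//    },
//    "options": {
//      "context_paths": ["CRUSH.md"],
//      "debug": true
//    }
//  }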

var (
    instance *Config // the singleton Config instance
    cwd string // working directory captured at Init time
    once sync.Once // ensures initialization happens only once
)

// loadConfig builds the configuration by starting from environment-based
// defaults and layering the global config file and the local ./crush.json on top.
func loadConfig(cwd string, debug bool) (*Config, error) {
    // First read the global config file
    cfgPath := ConfigPath()

    cfg := defaultConfigBasedOnEnv()
    cfg.Options.Debug = debug
    defaultLevel := slog.LevelInfo
    if cfg.Options.Debug {
        defaultLevel = slog.LevelDebug
    }
    if os.Getenv("CRUSH_DEV_DEBUG") == "true" {
        loggingFile := filepath.Join(cfg.Options.DataDirectory, "debug.log")

        // if the file does not exist, create it
        if _, err := os.Stat(loggingFile); os.IsNotExist(err) {
            if err := os.MkdirAll(cfg.Options.DataDirectory, 0o755); err != nil {
                return cfg, fmt.Errorf("failed to create directory: %w", err)
            }
            if _, err := os.Create(loggingFile); err != nil {
                return cfg, fmt.Errorf("failed to create log file: %w", err)
            }
        }

        messagesPath := filepath.Join(cfg.Options.DataDirectory, "messages")

        if _, err := os.Stat(messagesPath); os.IsNotExist(err) {
            if err := os.MkdirAll(messagesPath, 0o755); err != nil {
                return cfg, fmt.Errorf("failed to create directory: %w", err)
            }
        }
        logging.MessageDir = messagesPath

        sloggingFileWriter, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666)
        if err != nil {
            return cfg, fmt.Errorf("failed to open log file: %w", err)
        }
        // Configure logger
        logger := slog.New(slog.NewTextHandler(sloggingFileWriter, &slog.HandlerOptions{
            Level: defaultLevel,
        }))
        slog.SetDefault(logger)
    } else {
        // Configure logger
        logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{
            Level: defaultLevel,
        }))
        slog.SetDefault(logger)
    }
    var globalCfg *Config
    if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
        // some other error occurred while checking the file
        return nil, err
    } else if err == nil {
        // config file exists, read it
        file, err := os.ReadFile(cfgPath)
        if err != nil {
            return nil, err
        }
        globalCfg = &Config{}
        if err := json.Unmarshal(file, globalCfg); err != nil {
            return nil, err
        }
    } else {
        // config file does not exist, start from an empty config
        globalCfg = &Config{}
    }

    var localConfig *Config
    // Global config loaded, now read the local config file
    localConfigPath := filepath.Join(cwd, "crush.json")
    if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
        // some other error occurred while checking the file
        return nil, err
    } else if err == nil {
        // local config file exists, read it
        file, err := os.ReadFile(localConfigPath)
        if err != nil {
            return nil, err
        }
        localConfig = &Config{}
        if err := json.Unmarshal(file, localConfig); err != nil {
            return nil, err
        }
    }

    // merge options
    mergeOptions(cfg, globalCfg, localConfig)

    mergeProviderConfigs(cfg, globalCfg, localConfig)
    // no providers found; the app is not initialized yet
    if len(cfg.Providers) == 0 {
        return cfg, nil
    }
    preferredProvider := getPreferredProvider(cfg.Providers)
    if preferredProvider != nil {
        cfg.Models = PreferredModels{
            Large: PreferredModel{
                ModelID: preferredProvider.DefaultLargeModel,
                Provider: preferredProvider.ID,
            },
            Small: PreferredModel{
                ModelID: preferredProvider.DefaultSmallModel,
                Provider: preferredProvider.ID,
            },
        }
    } else {
        // No valid providers found, set empty models
        cfg.Models = PreferredModels{}
    }

    mergeModels(cfg, globalCfg, localConfig)

    agents := map[AgentID]Agent{
        AgentCoder: {
            ID: AgentCoder,
            Name: "Coder",
            Description: "An agent that helps with executing coding tasks.",
            Model: LargeModel,
            ContextPaths: cfg.Options.ContextPaths,
            // All tools allowed
        },
        AgentTask: {
            ID: AgentTask,
            Name: "Task",
            Description: "An agent that helps with searching for context and finding implementation details.",
            Model: LargeModel,
            ContextPaths: cfg.Options.ContextPaths,
            AllowedTools: []string{
                "glob",
                "grep",
                "ls",
                "sourcegraph",
                "view",
            },
            // No MCPs or LSPs by default
            AllowedMCP: map[string][]string{},
            AllowedLSP: []string{},
        },
    }
    cfg.Agents = agents
    mergeAgents(cfg, globalCfg, localConfig)
    mergeMCPs(cfg, globalCfg, localConfig)
    mergeLSPs(cfg, globalCfg, localConfig)

    // Validate the final configuration
    if err := cfg.Validate(); err != nil {
        return cfg, fmt.Errorf("configuration validation failed: %w", err)
    }

    return cfg, nil
}

// Init loads the configuration once and caches it for later calls to Get.
func Init(workingDir string, debug bool) (*Config, error) {
    var err error
    once.Do(func() {
        cwd = workingDir
        instance, err = loadConfig(cwd, debug)
        if err != nil {
            logging.Error("Failed to load config", "error", err)
        }
    })

    return instance, err
}
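
// Typical call sequence from application startup (illustrative; identifiers
// are qualified with the package name as seen from a caller):
//
//  cfg, err := config.Init(workingDir, false)
//  if err != nil {
//      // handle the error; cfg may still carry partial defaults
//  }
//  model := config.GetAgentModel(config.AgentCoder)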

// Get returns the previously initialized configuration and panics if Init has
// not been called yet.
func Get() *Config {
    if instance == nil {
        // TODO: Handle this better
        panic("Config not initialized. Call Init first.")
    }
    return instance
}

func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
    providers := Providers()
    for _, p := range providers {
        if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
            return &providerConfig
        }
    }
    // if none is found, return the first enabled configured provider
    for _, providerConfig := range configuredProviders {
        if !providerConfig.Disabled {
            return &providerConfig
        }
    }
    return nil
}

func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
    if other.APIKey != "" {
        base.APIKey = other.APIKey
    }
    // Only change these options if the provider is not a known provider
    if !slices.Contains(provider.KnownProviders(), p) {
        if other.BaseURL != "" {
            base.BaseURL = other.BaseURL
        }
        if other.ProviderType != "" {
            base.ProviderType = other.ProviderType
        }
        if len(other.ExtraHeaders) > 0 {
            if base.ExtraHeaders == nil {
                base.ExtraHeaders = make(map[string]string)
            }
            maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
        }
        if len(other.ExtraParams) > 0 {
            if base.ExtraParams == nil {
                base.ExtraParams = make(map[string]string)
            }
            maps.Copy(base.ExtraParams, other.ExtraParams)
        }
    }

    if other.Disabled {
        base.Disabled = other.Disabled
    }

    if other.DefaultLargeModel != "" {
        base.DefaultLargeModel = other.DefaultLargeModel
    }
    if other.DefaultSmallModel != "" {
        base.DefaultSmallModel = other.DefaultSmallModel
    }
    // Add new models if they don't exist
    if other.Models != nil {
        for _, model := range other.Models {
            // check if the model already exists
            exists := false
            for _, existingModel := range base.Models {
                if existingModel.ID == model.ID {
                    exists = true
                    break
                }
            }
            if !exists {
                base.Models = append(base.Models, model)
            }
        }
    }

    return base
}

func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
    if !slices.Contains(provider.KnownProviders(), p) {
        if providerConfig.ProviderType != provider.TypeOpenAI {
            return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
        }
        if providerConfig.BaseURL == "" {
            return errors.New("base URL must be set for custom providers")
        }
        if providerConfig.APIKey == "" {
            return errors.New("API key must be set for custom providers")
        }
    }
    return nil
}

func mergeModels(base, global, local *Config) {
    for _, cfg := range []*Config{global, local} {
        if cfg == nil {
            continue
        }
        if cfg.Models.Large.ModelID != "" && cfg.Models.Large.Provider != "" {
            base.Models.Large = cfg.Models.Large
        }

        if cfg.Models.Small.ModelID != "" && cfg.Models.Small.Provider != "" {
            base.Models.Small = cfg.Models.Small
        }
    }
}

func mergeOptions(base, global, local *Config) {
    for _, cfg := range []*Config{global, local} {
        if cfg == nil {
            continue
        }
        baseOptions := base.Options
        other := cfg.Options
        if len(other.ContextPaths) > 0 {
            baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
        }

        if other.TUI.CompactMode {
            baseOptions.TUI.CompactMode = other.TUI.CompactMode
        }

        if other.Debug {
            baseOptions.Debug = other.Debug
        }

        if other.DebugLSP {
            baseOptions.DebugLSP = other.DebugLSP
        }

        if other.DisableAutoSummarize {
            baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
        }

        if other.DataDirectory != "" {
            baseOptions.DataDirectory = other.DataDirectory
        }
        base.Options = baseOptions
    }
}

func mergeAgents(base, global, local *Config) {
    for _, cfg := range []*Config{global, local} {
        if cfg == nil {
            continue
        }
        for agentID, newAgent := range cfg.Agents {
            if _, ok := base.Agents[agentID]; !ok {
                newAgent.ID = agentID
                if newAgent.Model == "" {
                    newAgent.Model = LargeModel
                }
                if len(newAgent.ContextPaths) > 0 {
                    newAgent.ContextPaths = append(base.Options.ContextPaths, newAgent.ContextPaths...)
                } else {
                    newAgent.ContextPaths = base.Options.ContextPaths
                }
                base.Agents[agentID] = newAgent
            } else {
                baseAgent := base.Agents[agentID]

                if agentID == AgentCoder || agentID == AgentTask {
                    if newAgent.Model != "" {
                        baseAgent.Model = newAgent.Model
                    }
                    if newAgent.AllowedMCP != nil {
                        baseAgent.AllowedMCP = newAgent.AllowedMCP
                    }
                    if newAgent.AllowedLSP != nil {
                        baseAgent.AllowedLSP = newAgent.AllowedLSP
                    }
                    // Context paths are additive for known agents too
                    if len(newAgent.ContextPaths) > 0 {
                        baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...)
                    }
                } else {
                    if newAgent.Name != "" {
                        baseAgent.Name = newAgent.Name
                    }
                    if newAgent.Description != "" {
                        baseAgent.Description = newAgent.Description
                    }
                    if newAgent.Model != "" {
                        baseAgent.Model = newAgent.Model
                    } else if baseAgent.Model == "" {
                        baseAgent.Model = LargeModel
                    }

                    baseAgent.Disabled = newAgent.Disabled

                    if newAgent.AllowedTools != nil {
                        baseAgent.AllowedTools = newAgent.AllowedTools
                    }
                    if newAgent.AllowedMCP != nil {
                        baseAgent.AllowedMCP = newAgent.AllowedMCP
                    }
                    if newAgent.AllowedLSP != nil {
                        baseAgent.AllowedLSP = newAgent.AllowedLSP
                    }
                    if len(newAgent.ContextPaths) > 0 {
                        baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...)
                    }
                }

                base.Agents[agentID] = baseAgent
            }
        }
    }
}

func mergeMCPs(base, global, local *Config) {
    for _, cfg := range []*Config{global, local} {
        if cfg == nil {
            continue
        }
        maps.Copy(base.MCP, cfg.MCP)
    }
}

func mergeLSPs(base, global, local *Config) {
    for _, cfg := range []*Config{global, local} {
        if cfg == nil {
            continue
        }
        maps.Copy(base.LSP, cfg.LSP)
    }
}

func mergeProviderConfigs(base, global, local *Config) {
    for _, cfg := range []*Config{global, local} {
        if cfg == nil {
            continue
        }
        for providerName, p := range cfg.Providers {
            p.ID = providerName
            if _, ok := base.Providers[providerName]; !ok {
                base.Providers[providerName] = p
            } else {
                base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], p)
            }
        }
    }

    finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
    for providerName, providerConfig := range base.Providers {
        err := validateProvider(providerName, providerConfig)
        if err != nil {
            logging.Warn("Skipping provider", "name", providerName, "error", err)
            continue // Skip invalid providers
        }
        finalProviders[providerName] = providerConfig
    }
    base.Providers = finalProviders
}

func providerDefaultConfig(providerID provider.InferenceProvider) ProviderConfig {
    switch providerID {
    case provider.InferenceProviderAnthropic:
        return ProviderConfig{
            ID: providerID,
            ProviderType: provider.TypeAnthropic,
        }
    case provider.InferenceProviderOpenAI:
        return ProviderConfig{
            ID: providerID,
            ProviderType: provider.TypeOpenAI,
        }
    case provider.InferenceProviderGemini:
        return ProviderConfig{
            ID: providerID,
            ProviderType: provider.TypeGemini,
        }
    case provider.InferenceProviderBedrock:
        return ProviderConfig{
            ID: providerID,
            ProviderType: provider.TypeBedrock,
        }
    case provider.InferenceProviderAzure:
        return ProviderConfig{
            ID: providerID,
            ProviderType: provider.TypeAzure,
        }
    case provider.InferenceProviderOpenRouter:
        return ProviderConfig{
            ID: providerID,
            ProviderType: provider.TypeOpenAI,
            BaseURL: "https://openrouter.ai/api/v1",
            ExtraHeaders: map[string]string{
                "HTTP-Referer": "crush.charm.land",
                "X-Title": "Crush",
            },
        }
    case provider.InferenceProviderXAI:
        return ProviderConfig{
            ID: providerID,
            ProviderType: provider.TypeXAI,
            BaseURL: "https://api.x.ai/v1",
        }
    case provider.InferenceProviderVertexAI:
        return ProviderConfig{
            ID: providerID,
            ProviderType: provider.TypeVertexAI,
        }
    default:
        return ProviderConfig{
            ID: providerID,
            ProviderType: provider.TypeOpenAI,
        }
    }
}

func defaultConfigBasedOnEnv() *Config {
    cfg := &Config{
        Options: Options{
            DataDirectory: defaultDataDirectory,
            ContextPaths: defaultContextPaths,
        },
        Providers: make(map[provider.InferenceProvider]ProviderConfig),
        Agents: make(map[AgentID]Agent),
        LSP: make(map[string]LSPConfig),
        MCP: make(map[string]MCP),
    }

    providers := Providers()

    for _, p := range providers {
        if strings.HasPrefix(p.APIKey, "$") {
            envVar := strings.TrimPrefix(p.APIKey, "$")
            if apiKey := os.Getenv(envVar); apiKey != "" {
                providerConfig := providerDefaultConfig(p.ID)
                providerConfig.APIKey = apiKey
                providerConfig.DefaultLargeModel = p.DefaultLargeModelID
                providerConfig.DefaultSmallModel = p.DefaultSmallModelID
                baseURL := p.APIEndpoint
                if strings.HasPrefix(baseURL, "$") {
                    envVar := strings.TrimPrefix(baseURL, "$")
                    baseURL = os.Getenv(envVar)
                }
                providerConfig.BaseURL = baseURL
                for _, model := range p.Models {
                    configModel := Model{
                        ID: model.ID,
                        Name: model.Name,
                        CostPer1MIn: model.CostPer1MIn,
                        CostPer1MOut: model.CostPer1MOut,
                        CostPer1MInCached: model.CostPer1MInCached,
                        CostPer1MOutCached: model.CostPer1MOutCached,
                        ContextWindow: model.ContextWindow,
                        DefaultMaxTokens: model.DefaultMaxTokens,
                        CanReason: model.CanReason,
                        SupportsImages: model.SupportsImages,
                    }
                    // Set reasoning effort for reasoning models
                    if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
                        configModel.HasReasoningEffort = model.HasReasoningEffort
                        configModel.ReasoningEffort = model.DefaultReasoningEffort
                    }
                    providerConfig.Models = append(providerConfig.Models, configModel)
                }
                cfg.Providers[p.ID] = providerConfig
            }
        }
    }
    // TODO: support local models

    if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
        providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
        providerConfig.ExtraParams = map[string]string{
            "project": os.Getenv("GOOGLE_CLOUD_PROJECT"),
            "location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
        }
        // Find the VertexAI provider definition to get default models
        for _, p := range providers {
            if p.ID == provider.InferenceProviderVertexAI {
                providerConfig.DefaultLargeModel = p.DefaultLargeModelID
                providerConfig.DefaultSmallModel = p.DefaultSmallModelID
                for _, model := range p.Models {
                    configModel := Model{
                        ID: model.ID,
                        Name: model.Name,
                        CostPer1MIn: model.CostPer1MIn,
                        CostPer1MOut: model.CostPer1MOut,
                        CostPer1MInCached: model.CostPer1MInCached,
                        CostPer1MOutCached: model.CostPer1MOutCached,
                        ContextWindow: model.ContextWindow,
                        DefaultMaxTokens: model.DefaultMaxTokens,
                        CanReason: model.CanReason,
                        SupportsImages: model.SupportsImages,
                    }
                    // Set reasoning effort for reasoning models
                    if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
                        configModel.HasReasoningEffort = model.HasReasoningEffort
                        configModel.ReasoningEffort = model.DefaultReasoningEffort
                    }
                    providerConfig.Models = append(providerConfig.Models, configModel)
                }
                break
            }
        }
        cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
    }

    if hasAWSCredentials() {
        providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
        providerConfig.ExtraParams = map[string]string{
            "region": os.Getenv("AWS_DEFAULT_REGION"),
        }
        if providerConfig.ExtraParams["region"] == "" {
            providerConfig.ExtraParams["region"] = os.Getenv("AWS_REGION")
        }
        // Find the Bedrock provider definition to get default models
        for _, p := range providers {
            if p.ID == provider.InferenceProviderBedrock {
                providerConfig.DefaultLargeModel = p.DefaultLargeModelID
                providerConfig.DefaultSmallModel = p.DefaultSmallModelID
                for _, model := range p.Models {
                    configModel := Model{
                        ID: model.ID,
                        Name: model.Name,
                        CostPer1MIn: model.CostPer1MIn,
                        CostPer1MOut: model.CostPer1MOut,
                        CostPer1MInCached: model.CostPer1MInCached,
                        CostPer1MOutCached: model.CostPer1MOutCached,
                        ContextWindow: model.ContextWindow,
                        DefaultMaxTokens: model.DefaultMaxTokens,
                        CanReason: model.CanReason,
                        SupportsImages: model.SupportsImages,
                    }
                    // Set reasoning effort for reasoning models
                    if model.HasReasoningEffort && model.DefaultReasoningEffort != "" {
                        configModel.HasReasoningEffort = model.HasReasoningEffort
                        configModel.ReasoningEffort = model.DefaultReasoningEffort
                    }
                    providerConfig.Models = append(providerConfig.Models, configModel)
                }
                break
            }
        }
        cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
    }
    return cfg
}

func hasAWSCredentials() bool {
    if os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" {
        return true
    }

    if os.Getenv("AWS_PROFILE") != "" || os.Getenv("AWS_DEFAULT_PROFILE") != "" {
        return true
    }

    if os.Getenv("AWS_REGION") != "" || os.Getenv("AWS_DEFAULT_REGION") != "" {
        return true
    }

    if os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
        os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
        return true
    }

    return false
}

// WorkingDirectory returns the working directory the configuration was initialized with.
func WorkingDirectory() string {
    return cwd
}

// TODO: Handle error state

// GetAgentModel returns the model configuration resolved for the given agent.
func GetAgentModel(agentID AgentID) Model {
    cfg := Get()
    agent, ok := cfg.Agents[agentID]
    if !ok {
        logging.Error("Agent not found", "agent_id", agentID)
        return Model{}
    }

    var model PreferredModel
    switch agent.Model {
    case LargeModel:
        model = cfg.Models.Large
    case SmallModel:
        model = cfg.Models.Small
    default:
        logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
        model = cfg.Models.Large // Fallback to large model
    }
    providerConfig, ok := cfg.Providers[model.Provider]
    if !ok {
        logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
        return Model{}
    }

    for _, m := range providerConfig.Models {
        if m.ID == model.ModelID {
            return m
        }
    }

    logging.Error("Model not found for agent", "agent_id", agentID, "model_id", model.ModelID)
    return Model{}
}

// GetAgentEffectiveMaxTokens returns the effective max tokens for an agent,
// considering any overrides from the preferred model configuration
func GetAgentEffectiveMaxTokens(agentID AgentID) int64 {
    cfg := Get()
    agent, ok := cfg.Agents[agentID]
    if !ok {
        logging.Error("Agent not found", "agent_id", agentID)
        return 0
    }

    var preferredModel PreferredModel
    switch agent.Model {
    case LargeModel:
        preferredModel = cfg.Models.Large
    case SmallModel:
        preferredModel = cfg.Models.Small
    default:
        logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
        preferredModel = cfg.Models.Large // Fallback to large model
    }

    // Get the base model configuration
    baseModel := GetAgentModel(agentID)
    if baseModel.ID == "" {
        return 0
    }

    // Start with the default max tokens from the base model
    maxTokens := baseModel.DefaultMaxTokens

    // Override with preferred model max tokens if set
    if preferredModel.MaxTokens > 0 {
        maxTokens = preferredModel.MaxTokens
    }

    return maxTokens
}

// GetAgentProvider returns the provider configuration backing the given agent's model.
func GetAgentProvider(agentID AgentID) ProviderConfig {
    cfg := Get()
    agent, ok := cfg.Agents[agentID]
    if !ok {
        logging.Error("Agent not found", "agent_id", agentID)
        return ProviderConfig{}
    }

    var model PreferredModel
    switch agent.Model {
    case LargeModel:
        model = cfg.Models.Large
    case SmallModel:
        model = cfg.Models.Small
    default:
        logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
        model = cfg.Models.Large // Fallback to large model
    }

    providerConfig, ok := cfg.Providers[model.Provider]
    if !ok {
        logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
        return ProviderConfig{}
    }

    return providerConfig
}

// GetProviderModel returns the model with the given ID from the given provider's configuration.
func GetProviderModel(providerID provider.InferenceProvider, modelID string) Model {
    cfg := Get()
    providerConfig, ok := cfg.Providers[providerID]
    if !ok {
        logging.Error("Provider not found", "provider", providerID)
        return Model{}
    }

    for _, model := range providerConfig.Models {
        if model.ID == modelID {
            return model
        }
    }

    logging.Error("Model not found for provider", "provider", providerID, "model_id", modelID)
    return Model{}
}

// GetModel returns the configured model for the given model type.
func GetModel(modelType ModelType) Model {
    cfg := Get()
    var model PreferredModel
    switch modelType {
    case LargeModel:
        model = cfg.Models.Large
    case SmallModel:
        model = cfg.Models.Small
    default:
        model = cfg.Models.Large // Fallback to large model
    }
    providerConfig, ok := cfg.Providers[model.Provider]
    if !ok {
        return Model{}
    }

    for _, m := range providerConfig.Models {
        if m.ID == model.ModelID {
            return m
        }
    }
    return Model{}
}

// UpdatePreferredModel updates the in-memory preferred model for the given model type.
func UpdatePreferredModel(modelType ModelType, model PreferredModel) error {
    cfg := Get()
    switch modelType {
    case LargeModel:
        cfg.Models.Large = model
    case SmallModel:
        cfg.Models.Small = model
    default:
        return fmt.Errorf("unknown model type: %s", modelType)
    }
    return nil
}
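
// Illustrative example of switching the preferred large model at runtime
// (the provider and model ID are hypothetical). Note that this only mutates
// the in-memory configuration; nothing is persisted to disk here:
//
//  err := config.UpdatePreferredModel(config.LargeModel, config.PreferredModel{
//      Provider: provider.InferenceProviderOpenAI,
//      ModelID:  "gpt-...",
//  })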

// ValidationError represents a configuration validation error
type ValidationError struct {
    Field string
    Message string
}

func (e ValidationError) Error() string {
    return fmt.Sprintf("validation error in %s: %s", e.Field, e.Message)
}

// ValidationErrors represents multiple validation errors
type ValidationErrors []ValidationError

func (e ValidationErrors) Error() string {
    if len(e) == 0 {
        return "no validation errors"
    }
    if len(e) == 1 {
        return e[0].Error()
    }

    var messages []string
    for _, err := range e {
        messages = append(messages, err.Error())
    }
    return fmt.Sprintf("multiple validation errors: %s", strings.Join(messages, "; "))
}

// HasErrors returns true if there are any validation errors
func (e ValidationErrors) HasErrors() bool {
    return len(e) > 0
}

// Add appends a new validation error
func (e *ValidationErrors) Add(field, message string) {
    *e = append(*e, ValidationError{Field: field, Message: message})
}

// Validate performs comprehensive validation of the configuration
func (c *Config) Validate() error {
    var errors ValidationErrors

    // Validate providers
    c.validateProviders(&errors)

    // Validate models
    c.validateModels(&errors)

    // Validate agents
    c.validateAgents(&errors)

    // Validate options
    c.validateOptions(&errors)

    // Validate MCP configurations
    c.validateMCPs(&errors)

    // Validate LSP configurations
    c.validateLSPs(&errors)

    // Validate cross-references
    c.validateCrossReferences(&errors)

    // Validate completeness
    c.validateCompleteness(&errors)

    if errors.HasErrors() {
        return errors
    }

    return nil
}
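
// Illustrative usage: validating a configuration and inspecting individual
// field errors (Validate returns a ValidationErrors value when any check fails):
//
//  if err := cfg.Validate(); err != nil {
//      var verrs ValidationErrors
//      if errors.As(err, &verrs) {
//          for _, ve := range verrs {
//              logging.Warn("config issue", "field", ve.Field, "message", ve.Message)
//          }
//      }
//  }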

// validateProviders validates all provider configurations
func (c *Config) validateProviders(errors *ValidationErrors) {
    if c.Providers == nil {
        c.Providers = make(map[provider.InferenceProvider]ProviderConfig)
    }

    knownProviders := provider.KnownProviders()
    validTypes := []provider.Type{
        provider.TypeOpenAI,
        provider.TypeAnthropic,
        provider.TypeGemini,
        provider.TypeAzure,
        provider.TypeBedrock,
        provider.TypeVertexAI,
        provider.TypeXAI,
    }

    for providerID, providerConfig := range c.Providers {
        fieldPrefix := fmt.Sprintf("providers.%s", providerID)

        // Validate API key for non-disabled providers
        if !providerConfig.Disabled && providerConfig.APIKey == "" {
            // Special case for AWS Bedrock and VertexAI which may use other auth methods
            if providerID != provider.InferenceProviderBedrock && providerID != provider.InferenceProviderVertexAI {
                errors.Add(fieldPrefix+".api_key", "API key is required for non-disabled providers")
            }
        }

        // Validate provider type
        validType := slices.Contains(validTypes, providerConfig.ProviderType)
        if !validType {
            errors.Add(fieldPrefix+".provider_type", fmt.Sprintf("invalid provider type: %s", providerConfig.ProviderType))
        }

        // Validate custom providers
        isKnownProvider := slices.Contains(knownProviders, providerID)

        if !isKnownProvider {
            // Custom provider validation
            if providerConfig.BaseURL == "" {
                errors.Add(fieldPrefix+".base_url", "BaseURL is required for custom providers")
            }
            if providerConfig.ProviderType != provider.TypeOpenAI {
                errors.Add(fieldPrefix+".provider_type", "custom providers currently only support OpenAI type")
            }
        }

        // Validate models
        modelIDs := make(map[string]bool)
        for i, model := range providerConfig.Models {
            modelFieldPrefix := fmt.Sprintf("%s.models[%d]", fieldPrefix, i)

            // Check for duplicate model IDs
            if modelIDs[model.ID] {
                errors.Add(modelFieldPrefix+".id", fmt.Sprintf("duplicate model ID: %s", model.ID))
            }
            modelIDs[model.ID] = true

            // Validate required model fields
            if model.ID == "" {
                errors.Add(modelFieldPrefix+".id", "model ID is required")
            }
            if model.Name == "" {
                errors.Add(modelFieldPrefix+".name", "model name is required")
            }
            if model.ContextWindow <= 0 {
                errors.Add(modelFieldPrefix+".context_window", "context window must be positive")
            }
            if model.DefaultMaxTokens <= 0 {
                errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens must be positive")
            }
            if model.DefaultMaxTokens > model.ContextWindow {
                errors.Add(modelFieldPrefix+".default_max_tokens", "default max tokens cannot exceed context window")
            }

            // Validate cost fields
            if model.CostPer1MIn < 0 {
                errors.Add(modelFieldPrefix+".cost_per_1m_in", "cost per 1M input tokens cannot be negative")
            }
            if model.CostPer1MOut < 0 {
                errors.Add(modelFieldPrefix+".cost_per_1m_out", "cost per 1M output tokens cannot be negative")
            }
            if model.CostPer1MInCached < 0 {
                errors.Add(modelFieldPrefix+".cost_per_1m_in_cached", "cached cost per 1M input tokens cannot be negative")
            }
            if model.CostPer1MOutCached < 0 {
                errors.Add(modelFieldPrefix+".cost_per_1m_out_cached", "cached cost per 1M output tokens cannot be negative")
            }
        }

        // Validate default model references
        if providerConfig.DefaultLargeModel != "" {
            if !modelIDs[providerConfig.DefaultLargeModel] {
                errors.Add(fieldPrefix+".default_large_model", fmt.Sprintf("default large model '%s' not found in provider models", providerConfig.DefaultLargeModel))
            }
        }
        if providerConfig.DefaultSmallModel != "" {
            if !modelIDs[providerConfig.DefaultSmallModel] {
                errors.Add(fieldPrefix+".default_small_model", fmt.Sprintf("default small model '%s' not found in provider models", providerConfig.DefaultSmallModel))
            }
        }

        // Validate provider-specific requirements
        c.validateProviderSpecific(providerID, providerConfig, errors)
    }
}

// validateProviderSpecific validates provider-specific requirements
func (c *Config) validateProviderSpecific(providerID provider.InferenceProvider, providerConfig ProviderConfig, errors *ValidationErrors) {
    fieldPrefix := fmt.Sprintf("providers.%s", providerID)

    switch providerID {
    case provider.InferenceProviderVertexAI:
        if !providerConfig.Disabled {
            if providerConfig.ExtraParams == nil {
                errors.Add(fieldPrefix+".extra_params", "VertexAI requires extra_params configuration")
            } else {
                if providerConfig.ExtraParams["project"] == "" {
                    errors.Add(fieldPrefix+".extra_params.project", "VertexAI requires project parameter")
                }
                if providerConfig.ExtraParams["location"] == "" {
                    errors.Add(fieldPrefix+".extra_params.location", "VertexAI requires location parameter")
                }
            }
        }
    case provider.InferenceProviderBedrock:
        if !providerConfig.Disabled {
            if providerConfig.ExtraParams == nil || providerConfig.ExtraParams["region"] == "" {
                errors.Add(fieldPrefix+".extra_params.region", "Bedrock requires region parameter")
            }
            // Check for AWS credentials in environment
            if !hasAWSCredentials() {
                errors.Add(fieldPrefix, "Bedrock requires AWS credentials in environment")
            }
        }
    }
}

// validateModels validates preferred model configurations
func (c *Config) validateModels(errors *ValidationErrors) {
    // Validate large model
    if c.Models.Large.ModelID != "" || c.Models.Large.Provider != "" {
        if c.Models.Large.ModelID == "" {
            errors.Add("models.large.model_id", "large model ID is required when provider is set")
        }
        if c.Models.Large.Provider == "" {
            errors.Add("models.large.provider", "large model provider is required when model ID is set")
        }

        // Check if provider exists and is not disabled
        if providerConfig, exists := c.Providers[c.Models.Large.Provider]; exists {
            if providerConfig.Disabled {
                errors.Add("models.large.provider", "large model provider is disabled")
            }

            // Check if model exists in provider
            modelExists := false
            for _, model := range providerConfig.Models {
                if model.ID == c.Models.Large.ModelID {
                    modelExists = true
                    break
                }
            }
            if !modelExists {
                errors.Add("models.large.model_id", fmt.Sprintf("large model '%s' not found in provider '%s'", c.Models.Large.ModelID, c.Models.Large.Provider))
            }
        } else {
            errors.Add("models.large.provider", fmt.Sprintf("large model provider '%s' not found", c.Models.Large.Provider))
        }
    }

    // Validate small model
    if c.Models.Small.ModelID != "" || c.Models.Small.Provider != "" {
        if c.Models.Small.ModelID == "" {
            errors.Add("models.small.model_id", "small model ID is required when provider is set")
        }
        if c.Models.Small.Provider == "" {
            errors.Add("models.small.provider", "small model provider is required when model ID is set")
        }

        // Check if provider exists and is not disabled
        if providerConfig, exists := c.Providers[c.Models.Small.Provider]; exists {
            if providerConfig.Disabled {
                errors.Add("models.small.provider", "small model provider is disabled")
            }

            // Check if model exists in provider
            modelExists := false
            for _, model := range providerConfig.Models {
                if model.ID == c.Models.Small.ModelID {
                    modelExists = true
                    break
                }
            }
            if !modelExists {
                errors.Add("models.small.model_id", fmt.Sprintf("small model '%s' not found in provider '%s'", c.Models.Small.ModelID, c.Models.Small.Provider))
            }
        } else {
            errors.Add("models.small.provider", fmt.Sprintf("small model provider '%s' not found", c.Models.Small.Provider))
        }
    }
}

// validateAgents validates agent configurations
func (c *Config) validateAgents(errors *ValidationErrors) {
    if c.Agents == nil {
        c.Agents = make(map[AgentID]Agent)
    }

    validTools := []string{
        "bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent",
    }

    for agentID, agent := range c.Agents {
        fieldPrefix := fmt.Sprintf("agents.%s", agentID)

        // Validate agent ID consistency
        if agent.ID != agentID {
            errors.Add(fieldPrefix+".id", fmt.Sprintf("agent ID mismatch: expected '%s', got '%s'", agentID, agent.ID))
        }

        // Validate required fields
        if agent.ID == "" {
            errors.Add(fieldPrefix+".id", "agent ID is required")
        }
        if agent.Name == "" {
            errors.Add(fieldPrefix+".name", "agent name is required")
        }

        // Validate model type
        if agent.Model != LargeModel && agent.Model != SmallModel {
            errors.Add(fieldPrefix+".model", fmt.Sprintf("invalid model type: %s (must be 'large' or 'small')", agent.Model))
        }

        // Validate allowed tools
        if agent.AllowedTools != nil {
            for i, tool := range agent.AllowedTools {
                validTool := slices.Contains(validTools, tool)
                if !validTool {
                    errors.Add(fmt.Sprintf("%s.allowed_tools[%d]", fieldPrefix, i), fmt.Sprintf("unknown tool: %s", tool))
                }
            }
        }

        // Validate MCP references
        if agent.AllowedMCP != nil {
            for mcpName := range agent.AllowedMCP {
                if _, exists := c.MCP[mcpName]; !exists {
                    errors.Add(fieldPrefix+".allowed_mcp", fmt.Sprintf("referenced MCP '%s' not found", mcpName))
                }
            }
        }

        // Validate LSP references
        if agent.AllowedLSP != nil {
            for _, lspName := range agent.AllowedLSP {
                if _, exists := c.LSP[lspName]; !exists {
                    errors.Add(fieldPrefix+".allowed_lsp", fmt.Sprintf("referenced LSP '%s' not found", lspName))
                }
            }
        }

        // Validate context paths (basic path validation)
        for i, contextPath := range agent.ContextPaths {
            if contextPath == "" {
                errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path cannot be empty")
            }
            // Check for invalid characters in path
            if strings.Contains(contextPath, "\x00") {
                errors.Add(fmt.Sprintf("%s.context_paths[%d]", fieldPrefix, i), "context path contains invalid characters")
            }
        }

        // Validate known agents maintain their core properties
        if agentID == AgentCoder {
            if agent.Name != "Coder" {
                errors.Add(fieldPrefix+".name", "coder agent name cannot be changed")
            }
            if agent.Description != "An agent that helps with executing coding tasks." {
                errors.Add(fieldPrefix+".description", "coder agent description cannot be changed")
            }
        } else if agentID == AgentTask {
            if agent.Name != "Task" {
                errors.Add(fieldPrefix+".name", "task agent name cannot be changed")
            }
            if agent.Description != "An agent that helps with searching for context and finding implementation details." {
                errors.Add(fieldPrefix+".description", "task agent description cannot be changed")
            }
            expectedTools := []string{"glob", "grep", "ls", "sourcegraph", "view"}
            if agent.AllowedTools != nil && !slices.Equal(agent.AllowedTools, expectedTools) {
                errors.Add(fieldPrefix+".allowed_tools", "task agent allowed tools cannot be changed")
            }
        }
    }
}

// validateOptions validates configuration options
func (c *Config) validateOptions(errors *ValidationErrors) {
    // Validate data directory
    if c.Options.DataDirectory == "" {
        errors.Add("options.data_directory", "data directory is required")
    }

    // Validate context paths
    for i, contextPath := range c.Options.ContextPaths {
        if contextPath == "" {
            errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path cannot be empty")
        }
        if strings.Contains(contextPath, "\x00") {
            errors.Add(fmt.Sprintf("options.context_paths[%d]", i), "context path contains invalid characters")
        }
    }
}

// validateMCPs validates MCP configurations
func (c *Config) validateMCPs(errors *ValidationErrors) {
    if c.MCP == nil {
        c.MCP = make(map[string]MCP)
    }

    for mcpName, mcpConfig := range c.MCP {
        fieldPrefix := fmt.Sprintf("mcp.%s", mcpName)

        // Validate MCP type
        if mcpConfig.Type != MCPStdio && mcpConfig.Type != MCPSse {
            errors.Add(fieldPrefix+".type", fmt.Sprintf("invalid MCP type: %s (must be 'stdio' or 'sse')", mcpConfig.Type))
        }

        // Validate based on type
        if mcpConfig.Type == MCPStdio {
            if mcpConfig.Command == "" {
                errors.Add(fieldPrefix+".command", "command is required for stdio MCP")
            }
        } else if mcpConfig.Type == MCPSse {
            if mcpConfig.URL == "" {
                errors.Add(fieldPrefix+".url", "URL is required for SSE MCP")
            }
        }
    }
}

// validateLSPs validates LSP configurations
func (c *Config) validateLSPs(errors *ValidationErrors) {
    if c.LSP == nil {
        c.LSP = make(map[string]LSPConfig)
    }

    for lspName, lspConfig := range c.LSP {
        fieldPrefix := fmt.Sprintf("lsp.%s", lspName)

        if lspConfig.Command == "" {
            errors.Add(fieldPrefix+".command", "command is required for LSP")
        }
    }
}

// validateCrossReferences validates cross-references between different config sections
func (c *Config) validateCrossReferences(errors *ValidationErrors) {
    // Validate that agents can use their assigned model types
    for agentID, agent := range c.Agents {
        fieldPrefix := fmt.Sprintf("agents.%s", agentID)

        var preferredModel PreferredModel
        switch agent.Model {
        case LargeModel:
            preferredModel = c.Models.Large
        case SmallModel:
            preferredModel = c.Models.Small
        }

        if preferredModel.Provider != "" {
            if providerConfig, exists := c.Providers[preferredModel.Provider]; exists {
                if providerConfig.Disabled {
                    errors.Add(fieldPrefix+".model", fmt.Sprintf("agent cannot use model type '%s' because provider '%s' is disabled", agent.Model, preferredModel.Provider))
                }
            }
        }
    }
}

// validateCompleteness validates that the configuration is complete and usable
func (c *Config) validateCompleteness(errors *ValidationErrors) {
    // Check for at least one valid, non-disabled provider
    hasValidProvider := false
    for _, providerConfig := range c.Providers {
        if !providerConfig.Disabled {
            hasValidProvider = true
            break
        }
    }
    if !hasValidProvider {
        errors.Add("providers", "at least one non-disabled provider is required")
    }

    // Check that default agents exist
    if _, exists := c.Agents[AgentCoder]; !exists {
        errors.Add("agents", "coder agent is required")
    }
    if _, exists := c.Agents[AgentTask]; !exists {
        errors.Add("agents", "task agent is required")
    }

    // Check that preferred models are set if providers exist
    if hasValidProvider {
        if c.Models.Large.ModelID == "" || c.Models.Large.Provider == "" {
            errors.Add("models.large", "large preferred model must be configured when providers are available")
        }
        if c.Models.Small.ModelID == "" || c.Models.Small.Provider == "" {
            errors.Add("models.small", "small preferred model must be configured when providers are available")
        }
    }
}

// JSONSchemaExtend adds custom schema properties for AgentID
func (AgentID) JSONSchemaExtend(schema *jsonschema.Schema) {
    schema.Enum = []any{
        string(AgentCoder),
        string(AgentTask),
    }
}

// JSONSchemaExtend adds custom schema properties for ModelType
func (ModelType) JSONSchemaExtend(schema *jsonschema.Schema) {
    schema.Enum = []any{
        string(LargeModel),
        string(SmallModel),
    }
}

// JSONSchemaExtend adds custom schema properties for MCPType
func (MCPType) JSONSchemaExtend(schema *jsonschema.Schema) {
    schema.Enum = []any{
        string(MCPStdio),
        string(MCPSse),
    }
}