1package config
2
3import (
4 "encoding/json"
5 "errors"
6 "fmt"
7 "log/slog"
8 "maps"
9 "os"
10 "path/filepath"
11 "slices"
12 "strings"
13 "sync"
14
15 "github.com/charmbracelet/crush/internal/fur/provider"
16 "github.com/charmbracelet/crush/internal/logging"
17)
18
// Package-level defaults.
const (
	// defaultDataDirectory is the initial value of Options.DataDirectory
	// before any config file overrides it.
	defaultDataDirectory = ".crush"
	// defaultLogLevel names the default log level.
	// NOTE(review): not referenced in this part of the file — confirm usage.
	defaultLogLevel = "info"
	// appName is the application name.
	// NOTE(review): not referenced in this part of the file — confirm usage.
	appName = "crush"

	// MaxTokensFallbackDefault is a fallback token limit; presumably applied
	// when a model does not define DefaultMaxTokens — confirm against callers.
	MaxTokensFallbackDefault = 4096
)
26
// defaultContextPaths is the initial value of Options.ContextPaths. Paths
// from config files are appended to this list rather than replacing it.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
42
// AgentID identifies an agent and is the key of Config.Agents.
type AgentID string

// Built-in agents that loadConfig always registers. Config files may only
// adjust a limited set of their fields (see mergeAgents).
const (
	// AgentCoder is the general coding agent with every tool allowed.
	AgentCoder AgentID = "coder"
	// AgentTask is the search/context agent with a restricted tool list.
	AgentTask AgentID = "task"
)
49
// ModelType names one of the preferred-model slots in Config.Models.
type ModelType string

const (
	// LargeModel selects Config.Models.Large.
	LargeModel ModelType = "large"
	// SmallModel selects Config.Models.Small.
	SmallModel ModelType = "small"
)
56
// Model describes a single model offered by a provider, together with its
// pricing and capability metadata.
type Model struct {
	// ID is the value matched against PreferredModel.ModelID.
	ID string `json:"id"`
	// Name is serialized under the JSON key "model".
	Name string `json:"model"`
	// Cost fields; presumably a price per one million tokens for regular and
	// cached input/output — confirm the unit and currency with the provider data.
	CostPer1MIn        float64 `json:"cost_per_1m_in"`
	CostPer1MOut       float64 `json:"cost_per_1m_out"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
	ContextWindow      int64   `json:"context_window"`
	DefaultMaxTokens   int64   `json:"default_max_tokens"`
	CanReason          bool    `json:"can_reason"`
	ReasoningEffort    string  `json:"reasoning_effort"`
	// SupportsImages is serialized under the JSON key "supports_attachments".
	SupportsImages bool `json:"supports_attachments"`
}
70
// VertexAIOptions groups Vertex AI settings: an API key, a project and a
// location. All fields are omitted from JSON when empty.
// NOTE(review): not referenced in this part of the file — confirm usage.
type VertexAIOptions struct {
	APIKey   string `json:"api_key,omitempty"`
	Project  string `json:"project,omitempty"`
	Location string `json:"location,omitempty"`
}
76
// ProviderConfig holds the settings for one inference provider: how to reach
// it, its credentials, and the models it offers.
type ProviderConfig struct {
	ID           provider.InferenceProvider `json:"id"`
	BaseURL      string                     `json:"base_url,omitempty"`
	ProviderType provider.Type              `json:"provider_type"`
	APIKey       string                     `json:"api_key,omitempty"`
	// Disabled excludes the provider from preferred-provider selection.
	Disabled     bool              `json:"disabled"`
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
	// ExtraParams carries provider-specific settings, e.g. the Vertex AI
	// project/location or the Bedrock region.
	ExtraParams map[string]string `json:"extra_params,omitempty"`

	// DefaultLargeModel and DefaultSmallModel are the model IDs used for the
	// large/small slots when this provider is the preferred one.
	DefaultLargeModel string `json:"default_large_model,omitempty"`
	DefaultSmallModel string `json:"default_small_model,omitempty"`

	Models []Model `json:"models,omitempty"`
}
92
// Agent describes one agent: which model slot it uses and which tools, MCPs,
// LSPs and context paths are available to it.
type Agent struct {
	ID          AgentID `json:"id"`
	Name        string  `json:"name"`
	Description string  `json:"description,omitempty"`
	// Disabled marks the agent as turned off.
	Disabled bool `json:"disabled"`

	// Model selects the preferred-model slot (large or small) for this agent.
	Model ModelType `json:"model"`

	// The available tools for the agent.
	// If this is nil, all tools are available.
	AllowedTools []string `json:"allowed_tools"`

	// AllowedMCP tells us which MCPs are available for this agent.
	// If this is empty, all MCPs are available.
	// The string slice is the list of tools from that MCP the agent may use;
	// if the slice is nil, all tools from that MCP are available.
	AllowedMCP map[string][]string `json:"allowed_mcp"`

	// The list of LSPs that this agent can use.
	// If this is nil, all LSPs are available.
	AllowedLSP []string `json:"allowed_lsp"`

	// ContextPaths are the context paths for this agent. When merged from a
	// config file they are added to the existing paths (see mergeAgents).
	ContextPaths []string `json:"context_paths"`
}
119
// MCPType is the value of the MCP.Type field.
type MCPType string

const (
	// MCPStdio is the "stdio" MCP type.
	MCPStdio MCPType = "stdio"
	// MCPSse is the "sse" MCP type.
	MCPSse MCPType = "sse"
)
126
// MCP is the configuration of one MCP entry, keyed by name in Config.MCP.
// Presumably Command/Env/Args apply to the stdio type and URL/Headers to the
// sse type — confirm against the MCP client code.
type MCP struct {
	Command string            `json:"command"`
	Env     []string          `json:"env"`
	Args    []string          `json:"args"`
	Type    MCPType           `json:"type"`
	URL     string            `json:"url"`
	Headers map[string]string `json:"headers"`
}
135
// LSPConfig is the configuration of one language server, keyed by name in
// Config.LSP.
type LSPConfig struct {
	// NOTE(review): the field is named Disabled but is serialized under the
	// JSON key "enabled", so `"enabled": true` in a config file sets
	// Disabled to true. Confirm which meaning is intended before changing
	// either side, since the key is part of the config file format.
	Disabled bool     `json:"enabled"`
	Command  string   `json:"command"`
	Args     []string `json:"args"`
	// Options is decoded as arbitrary JSON.
	Options any `json:"options"`
}
142
// TUIOptions holds settings for the terminal UI.
type TUIOptions struct {
	CompactMode bool `json:"compact_mode"`
	// Here we can add themes later or any TUI related options
}
147
// Options holds the miscellaneous application settings. See mergeOptions for
// how values from the global and local config files are combined.
type Options struct {
	// ContextPaths starts as defaultContextPaths; paths from config files are
	// appended to it.
	ContextPaths []string   `json:"context_paths"`
	TUI          TUIOptions `json:"tui"`
	// Debug selects the debug log level when the logger is configured.
	Debug                bool `json:"debug"`
	DebugLSP             bool `json:"debug_lsp"`
	DisableAutoSummarize bool `json:"disable_auto_summarize"`
	// Relative to the cwd
	DataDirectory string `json:"data_directory"`
}
157
// PreferredModel points at one model of one provider: Provider is the key
// into Config.Providers and ModelID is matched against Model.ID.
type PreferredModel struct {
	ModelID  string                     `json:"model_id"`
	Provider provider.InferenceProvider `json:"provider"`
}
162
// PreferredModels holds the selected model for each ModelType slot.
type PreferredModels struct {
	Large PreferredModel `json:"large"`
	Small PreferredModel `json:"small"`
}
167
// Config is the root configuration object. It is both the schema of the
// global/local JSON config files and the merged result built by loadConfig.
type Config struct {
	// Models holds the preferred large and small models.
	Models PreferredModels `json:"models"`
	// List of configured providers
	Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"`

	// List of configured agents
	Agents map[AgentID]Agent `json:"agents,omitempty"`

	// List of configured MCPs
	MCP map[string]MCP `json:"mcp,omitempty"`

	// List of configured LSPs
	LSP map[string]LSPConfig `json:"lsp,omitempty"`

	// Miscellaneous options
	Options Options `json:"options"`
}
185
// Package-level singleton state, populated by Init.
var (
	instance *Config   // The single instance of the Singleton
	cwd      string    // Working directory passed to Init
	once     sync.Once // Ensures the initialization happens only once

)
192
193func loadConfig(cwd string, debug bool) (*Config, error) {
194 // First read the global config file
195 cfgPath := ConfigPath()
196
197 cfg := defaultConfigBasedOnEnv()
198 cfg.Options.Debug = debug
199 defaultLevel := slog.LevelInfo
200 if cfg.Options.Debug {
201 defaultLevel = slog.LevelDebug
202 }
203 if os.Getenv("CRUSH_DEV_DEBUG") == "true" {
204 loggingFile := fmt.Sprintf("%s/%s", cfg.Options.DataDirectory, "debug.log")
205
206 // if file does not exist create it
207 if _, err := os.Stat(loggingFile); os.IsNotExist(err) {
208 if err := os.MkdirAll(cfg.Options.DataDirectory, 0o755); err != nil {
209 return cfg, fmt.Errorf("failed to create directory: %w", err)
210 }
211 if _, err := os.Create(loggingFile); err != nil {
212 return cfg, fmt.Errorf("failed to create log file: %w", err)
213 }
214 }
215
216 sloggingFileWriter, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666)
217 if err != nil {
218 return cfg, fmt.Errorf("failed to open log file: %w", err)
219 }
220 // Configure logger
221 logger := slog.New(slog.NewTextHandler(sloggingFileWriter, &slog.HandlerOptions{
222 Level: defaultLevel,
223 }))
224 slog.SetDefault(logger)
225 } else {
226 // Configure logger
227 logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{
228 Level: defaultLevel,
229 }))
230 slog.SetDefault(logger)
231 }
232 var globalCfg *Config
233 if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
234 // some other error occurred while checking the file
235 return nil, err
236 } else if err == nil {
237 // config file exists, read it
238 file, err := os.ReadFile(cfgPath)
239 if err != nil {
240 return nil, err
241 }
242 globalCfg = &Config{}
243 if err := json.Unmarshal(file, globalCfg); err != nil {
244 return nil, err
245 }
246 } else {
247 // config file does not exist, create a new one
248 globalCfg = &Config{}
249 }
250
251 var localConfig *Config
252 // Global config loaded, now read the local config file
253 localConfigPath := filepath.Join(cwd, "crush.json")
254 if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
255 // some other error occurred while checking the file
256 return nil, err
257 } else if err == nil {
258 // local config file exists, read it
259 file, err := os.ReadFile(localConfigPath)
260 if err != nil {
261 return nil, err
262 }
263 localConfig = &Config{}
264 if err := json.Unmarshal(file, localConfig); err != nil {
265 return nil, err
266 }
267 }
268
269 // merge options
270 mergeOptions(cfg, globalCfg, localConfig)
271
272 mergeProviderConfigs(cfg, globalCfg, localConfig)
273 // no providers found the app is not initialized yet
274 if len(cfg.Providers) == 0 {
275 return cfg, nil
276 }
277 preferredProvider := getPreferredProvider(cfg.Providers)
278 if preferredProvider != nil {
279 cfg.Models = PreferredModels{
280 Large: PreferredModel{
281 ModelID: preferredProvider.DefaultLargeModel,
282 Provider: preferredProvider.ID,
283 },
284 Small: PreferredModel{
285 ModelID: preferredProvider.DefaultSmallModel,
286 Provider: preferredProvider.ID,
287 },
288 }
289 } else {
290 // No valid providers found, set empty models
291 cfg.Models = PreferredModels{}
292 }
293
294 mergeModels(cfg, globalCfg, localConfig)
295
296 agents := map[AgentID]Agent{
297 AgentCoder: {
298 ID: AgentCoder,
299 Name: "Coder",
300 Description: "An agent that helps with executing coding tasks.",
301 Model: LargeModel,
302 ContextPaths: cfg.Options.ContextPaths,
303 // All tools allowed
304 },
305 AgentTask: {
306 ID: AgentTask,
307 Name: "Task",
308 Description: "An agent that helps with searching for context and finding implementation details.",
309 Model: LargeModel,
310 ContextPaths: cfg.Options.ContextPaths,
311 AllowedTools: []string{
312 "glob",
313 "grep",
314 "ls",
315 "sourcegraph",
316 "view",
317 },
318 // NO MCPs or LSPs by default
319 AllowedMCP: map[string][]string{},
320 AllowedLSP: []string{},
321 },
322 }
323 cfg.Agents = agents
324 mergeAgents(cfg, globalCfg, localConfig)
325 mergeMCPs(cfg, globalCfg, localConfig)
326 mergeLSPs(cfg, globalCfg, localConfig)
327
328 return cfg, nil
329}
330
331func Init(workingDir string, debug bool) (*Config, error) {
332 var err error
333 once.Do(func() {
334 cwd = workingDir
335 instance, err = loadConfig(cwd, debug)
336 if err != nil {
337 logging.Error("Failed to load config", "error", err)
338 }
339 })
340
341 return instance, err
342}
343
344func Get() *Config {
345 if instance == nil {
346 // TODO: Handle this better
347 panic("Config not initialized. Call InitConfig first.")
348 }
349 return instance
350}
351
352func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
353 providers := Providers()
354 for _, p := range providers {
355 if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
356 return &providerConfig
357 }
358 }
359 // if none found return the first configured provider
360 for _, providerConfig := range configuredProviders {
361 if !providerConfig.Disabled {
362 return &providerConfig
363 }
364 }
365 return nil
366}
367
368func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
369 if other.APIKey != "" {
370 base.APIKey = other.APIKey
371 }
372 // Only change these options if the provider is not a known provider
373 if !slices.Contains(provider.KnownProviders(), p) {
374 if other.BaseURL != "" {
375 base.BaseURL = other.BaseURL
376 }
377 if other.ProviderType != "" {
378 base.ProviderType = other.ProviderType
379 }
380 if len(other.ExtraHeaders) > 0 {
381 if base.ExtraHeaders == nil {
382 base.ExtraHeaders = make(map[string]string)
383 }
384 maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
385 }
386 if len(other.ExtraParams) > 0 {
387 if base.ExtraParams == nil {
388 base.ExtraParams = make(map[string]string)
389 }
390 maps.Copy(base.ExtraParams, other.ExtraParams)
391 }
392 }
393
394 if other.Disabled {
395 base.Disabled = other.Disabled
396 }
397
398 if other.DefaultLargeModel != "" {
399 base.DefaultLargeModel = other.DefaultLargeModel
400 }
401 // Add new models if they don't exist
402 if other.Models != nil {
403 for _, model := range other.Models {
404 // check if the model already exists
405 exists := false
406 for _, existingModel := range base.Models {
407 if existingModel.ID == model.ID {
408 exists = true
409 break
410 }
411 }
412 if !exists {
413 base.Models = append(base.Models, model)
414 }
415 }
416 }
417
418 return base
419}
420
421func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
422 if !slices.Contains(provider.KnownProviders(), p) {
423 if providerConfig.ProviderType != provider.TypeOpenAI {
424 return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
425 }
426 if providerConfig.BaseURL == "" {
427 return errors.New("base URL must be set for custom providers")
428 }
429 if providerConfig.APIKey == "" {
430 return errors.New("API key must be set for custom providers")
431 }
432 }
433 return nil
434}
435
436func mergeModels(base, global, local *Config) {
437 for _, cfg := range []*Config{global, local} {
438 if cfg == nil {
439 continue
440 }
441 if cfg.Models.Large.ModelID != "" && cfg.Models.Large.Provider != "" {
442 base.Models.Large = cfg.Models.Large
443 }
444
445 if cfg.Models.Small.ModelID != "" && cfg.Models.Small.Provider != "" {
446 base.Models.Small = cfg.Models.Small
447 }
448 }
449}
450
451func mergeOptions(base, global, local *Config) {
452 for _, cfg := range []*Config{global, local} {
453 if cfg == nil {
454 continue
455 }
456 baseOptions := base.Options
457 other := cfg.Options
458 if len(other.ContextPaths) > 0 {
459 baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
460 }
461
462 if other.TUI.CompactMode {
463 baseOptions.TUI.CompactMode = other.TUI.CompactMode
464 }
465
466 if other.Debug {
467 baseOptions.Debug = other.Debug
468 }
469
470 if other.DebugLSP {
471 baseOptions.DebugLSP = other.DebugLSP
472 }
473
474 if other.DisableAutoSummarize {
475 baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
476 }
477
478 if other.DataDirectory != "" {
479 baseOptions.DataDirectory = other.DataDirectory
480 }
481 base.Options = baseOptions
482 }
483}
484
485func mergeAgents(base, global, local *Config) {
486 for _, cfg := range []*Config{global, local} {
487 if cfg == nil {
488 continue
489 }
490 for agentID, newAgent := range cfg.Agents {
491 if _, ok := base.Agents[agentID]; !ok {
492 // New agent - apply defaults
493 newAgent.ID = agentID // Ensure the ID is set correctly
494 if newAgent.Model == "" {
495 newAgent.Model = LargeModel // Default model type
496 }
497 // Context paths are always additive - start with global, then add custom
498 if len(newAgent.ContextPaths) > 0 {
499 newAgent.ContextPaths = append(base.Options.ContextPaths, newAgent.ContextPaths...)
500 } else {
501 newAgent.ContextPaths = base.Options.ContextPaths // Use global context paths only
502 }
503 base.Agents[agentID] = newAgent
504 } else {
505 baseAgent := base.Agents[agentID]
506
507 // Special handling for known agents - only allow model changes
508 if agentID == AgentCoder || agentID == AgentTask {
509 if newAgent.Model != "" {
510 baseAgent.Model = newAgent.Model
511 }
512 // For known agents, only allow MCP and LSP configuration
513 if newAgent.AllowedMCP != nil {
514 baseAgent.AllowedMCP = newAgent.AllowedMCP
515 }
516 if newAgent.AllowedLSP != nil {
517 baseAgent.AllowedLSP = newAgent.AllowedLSP
518 }
519 // Context paths are additive for known agents too
520 if len(newAgent.ContextPaths) > 0 {
521 baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...)
522 }
523 } else {
524 // Custom agents - allow full merging
525 if newAgent.Name != "" {
526 baseAgent.Name = newAgent.Name
527 }
528 if newAgent.Description != "" {
529 baseAgent.Description = newAgent.Description
530 }
531 if newAgent.Model != "" {
532 baseAgent.Model = newAgent.Model
533 } else if baseAgent.Model == "" {
534 baseAgent.Model = LargeModel // Default fallback
535 }
536
537 // Boolean fields - always update (including false values)
538 baseAgent.Disabled = newAgent.Disabled
539
540 // Slice/Map fields - update if provided (including empty slices/maps)
541 if newAgent.AllowedTools != nil {
542 baseAgent.AllowedTools = newAgent.AllowedTools
543 }
544 if newAgent.AllowedMCP != nil {
545 baseAgent.AllowedMCP = newAgent.AllowedMCP
546 }
547 if newAgent.AllowedLSP != nil {
548 baseAgent.AllowedLSP = newAgent.AllowedLSP
549 }
550 // Context paths are additive for custom agents too
551 if len(newAgent.ContextPaths) > 0 {
552 baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...)
553 }
554 }
555
556 base.Agents[agentID] = baseAgent
557 }
558 }
559 }
560}
561
562func mergeMCPs(base, global, local *Config) {
563 for _, cfg := range []*Config{global, local} {
564 if cfg == nil {
565 continue
566 }
567 maps.Copy(base.MCP, cfg.MCP)
568 }
569}
570
571func mergeLSPs(base, global, local *Config) {
572 for _, cfg := range []*Config{global, local} {
573 if cfg == nil {
574 continue
575 }
576 maps.Copy(base.LSP, cfg.LSP)
577 }
578}
579
580func mergeProviderConfigs(base, global, local *Config) {
581 for _, cfg := range []*Config{global, local} {
582 if cfg == nil {
583 continue
584 }
585 for providerName, globalProvider := range cfg.Providers {
586 if _, ok := base.Providers[providerName]; !ok {
587 base.Providers[providerName] = globalProvider
588 } else {
589 base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider)
590 }
591 }
592 }
593
594 finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
595 for providerName, providerConfig := range base.Providers {
596 err := validateProvider(providerName, providerConfig)
597 if err != nil {
598 logging.Warn("Skipping provider", "name", providerName, "error", err)
599 continue // Skip invalid providers
600 }
601 finalProviders[providerName] = providerConfig
602 }
603 base.Providers = finalProviders
604}
605
606func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig {
607 switch providerId {
608 case provider.InferenceProviderAnthropic:
609 return ProviderConfig{
610 ID: providerId,
611 ProviderType: provider.TypeAnthropic,
612 }
613 case provider.InferenceProviderOpenAI:
614 return ProviderConfig{
615 ID: providerId,
616 ProviderType: provider.TypeOpenAI,
617 }
618 case provider.InferenceProviderGemini:
619 return ProviderConfig{
620 ID: providerId,
621 ProviderType: provider.TypeGemini,
622 }
623 case provider.InferenceProviderBedrock:
624 return ProviderConfig{
625 ID: providerId,
626 ProviderType: provider.TypeBedrock,
627 }
628 case provider.InferenceProviderAzure:
629 return ProviderConfig{
630 ID: providerId,
631 ProviderType: provider.TypeAzure,
632 }
633 case provider.InferenceProviderOpenRouter:
634 return ProviderConfig{
635 ID: providerId,
636 ProviderType: provider.TypeOpenAI,
637 BaseURL: "https://openrouter.ai/api/v1",
638 ExtraHeaders: map[string]string{
639 "HTTP-Referer": "crush.charm.land",
640 "X-Title": "Crush",
641 },
642 }
643 case provider.InferenceProviderXAI:
644 return ProviderConfig{
645 ID: providerId,
646 ProviderType: provider.TypeXAI,
647 BaseURL: "https://api.x.ai/v1",
648 }
649 case provider.InferenceProviderVertexAI:
650 return ProviderConfig{
651 ID: providerId,
652 ProviderType: provider.TypeVertexAI,
653 }
654 default:
655 return ProviderConfig{
656 ID: providerId,
657 ProviderType: provider.TypeOpenAI,
658 }
659 }
660}
661
662func defaultConfigBasedOnEnv() *Config {
663 cfg := &Config{
664 Options: Options{
665 DataDirectory: defaultDataDirectory,
666 ContextPaths: defaultContextPaths,
667 },
668 Providers: make(map[provider.InferenceProvider]ProviderConfig),
669 }
670
671 providers := Providers()
672
673 for _, p := range providers {
674 if strings.HasPrefix(p.APIKey, "$") {
675 envVar := strings.TrimPrefix(p.APIKey, "$")
676 if apiKey := os.Getenv(envVar); apiKey != "" {
677 providerConfig := providerDefaultConfig(p.ID)
678 providerConfig.APIKey = apiKey
679 providerConfig.DefaultLargeModel = p.DefaultLargeModelID
680 providerConfig.DefaultSmallModel = p.DefaultSmallModelID
681 baseURL := p.APIEndpoint
682 if strings.HasPrefix(baseURL, "$") {
683 envVar := strings.TrimPrefix(baseURL, "$")
684 if url := os.Getenv(envVar); url != "" {
685 baseURL = url
686 }
687 }
688 providerConfig.BaseURL = baseURL
689 for _, model := range p.Models {
690 providerConfig.Models = append(providerConfig.Models, Model{
691 ID: model.ID,
692 Name: model.Name,
693 CostPer1MIn: model.CostPer1MIn,
694 CostPer1MOut: model.CostPer1MOut,
695 CostPer1MInCached: model.CostPer1MInCached,
696 CostPer1MOutCached: model.CostPer1MOutCached,
697 ContextWindow: model.ContextWindow,
698 DefaultMaxTokens: model.DefaultMaxTokens,
699 CanReason: model.CanReason,
700 SupportsImages: model.SupportsImages,
701 })
702 }
703 cfg.Providers[p.ID] = providerConfig
704 }
705 }
706 }
707 // TODO: support local models
708
709 if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
710 providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
711 providerConfig.ExtraParams = map[string]string{
712 "project": os.Getenv("GOOGLE_CLOUD_PROJECT"),
713 "location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
714 }
715 cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
716 }
717
718 if hasAWSCredentials() {
719 providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
720 providerConfig.ExtraParams = map[string]string{
721 "region": os.Getenv("AWS_DEFAULT_REGION"),
722 }
723 if providerConfig.ExtraParams["region"] == "" {
724 providerConfig.ExtraParams["region"] = os.Getenv("AWS_REGION")
725 }
726 cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
727 }
728 return cfg
729}
730
// hasAWSCredentials reports whether the environment hints at usable AWS
// configuration: an access key pair, a profile, a region, or container
// credential URIs. It only checks that the variables are non-empty.
func hasAWSCredentials() bool {
	isSet := func(name string) bool {
		return os.Getenv(name) != ""
	}

	switch {
	case isSet("AWS_ACCESS_KEY_ID") && isSet("AWS_SECRET_ACCESS_KEY"):
		return true
	case isSet("AWS_PROFILE") || isSet("AWS_DEFAULT_PROFILE"):
		return true
	case isSet("AWS_REGION") || isSet("AWS_DEFAULT_REGION"):
		return true
	case isSet("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") || isSet("AWS_CONTAINER_CREDENTIALS_FULL_URI"):
		return true
	}
	return false
}
751
// WorkingDirectory returns the working directory recorded by Init. It is the
// empty string until Init has run.
func WorkingDirectory() string {
	return cwd
}
755
756// TODO: Handle error state
757
758func GetAgentModel(agentID AgentID) Model {
759 cfg := Get()
760 agent, ok := cfg.Agents[agentID]
761 if !ok {
762 logging.Error("Agent not found", "agent_id", agentID)
763 return Model{}
764 }
765
766 var model PreferredModel
767 switch agent.Model {
768 case LargeModel:
769 model = cfg.Models.Large
770 case SmallModel:
771 model = cfg.Models.Small
772 default:
773 logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
774 model = cfg.Models.Large // Fallback to large model
775 }
776 providerConfig, ok := cfg.Providers[model.Provider]
777 if !ok {
778 logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
779 return Model{}
780 }
781
782 for _, m := range providerConfig.Models {
783 if m.ID == model.ModelID {
784 return m
785 }
786 }
787
788 logging.Error("Model not found for agent", "agent_id", agentID, "model", agent.Model)
789 return Model{}
790}
791
792func GetAgentProvider(agentID AgentID) ProviderConfig {
793 cfg := Get()
794 agent, ok := cfg.Agents[agentID]
795 if !ok {
796 logging.Error("Agent not found", "agent_id", agentID)
797 return ProviderConfig{}
798 }
799
800 var model PreferredModel
801 switch agent.Model {
802 case LargeModel:
803 model = cfg.Models.Large
804 case SmallModel:
805 model = cfg.Models.Small
806 default:
807 logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
808 model = cfg.Models.Large // Fallback to large model
809 }
810
811 providerConfig, ok := cfg.Providers[model.Provider]
812 if !ok {
813 logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
814 return ProviderConfig{}
815 }
816
817 return providerConfig
818}
819
820func GetProviderModel(provider provider.InferenceProvider, modelID string) Model {
821 cfg := Get()
822 providerConfig, ok := cfg.Providers[provider]
823 if !ok {
824 logging.Error("Provider not found", "provider", provider)
825 return Model{}
826 }
827
828 for _, model := range providerConfig.Models {
829 if model.ID == modelID {
830 return model
831 }
832 }
833
834 logging.Error("Model not found for provider", "provider", provider, "model_id", modelID)
835 return Model{}
836}
837
838func GetModel(modelType ModelType) Model {
839 cfg := Get()
840 var model PreferredModel
841 switch modelType {
842 case LargeModel:
843 model = cfg.Models.Large
844 case SmallModel:
845 model = cfg.Models.Small
846 default:
847 model = cfg.Models.Large // Fallback to large model
848 }
849 providerConfig, ok := cfg.Providers[model.Provider]
850 if !ok {
851 return Model{}
852 }
853
854 for _, m := range providerConfig.Models {
855 if m.ID == model.ModelID {
856 return m
857 }
858 }
859 return Model{}
860}
861
862func UpdatePreferredModel(modelType ModelType, model PreferredModel) error {
863 cfg := Get()
864 switch modelType {
865 case LargeModel:
866 cfg.Models.Large = model
867 case SmallModel:
868 cfg.Models.Small = model
869 default:
870 return fmt.Errorf("unknown model type: %s", modelType)
871 }
872 return nil
873}