1package config
2
3import (
4 "encoding/json"
5 "errors"
6 "fmt"
7 "log/slog"
8 "maps"
9 "os"
10 "path/filepath"
11 "slices"
12 "strings"
13 "sync"
14
15 "github.com/charmbracelet/crush/internal/fur/provider"
16 "github.com/charmbracelet/crush/internal/logging"
17)
18
// Package-level defaults. defaultDataDirectory seeds Options.DataDirectory,
// which the Options type documents as relative to the cwd.
// NOTE(review): defaultLogLevel, appName and MaxTokensFallbackDefault are not
// referenced in this file; confirm their uses elsewhere in the package.
const (
	defaultDataDirectory = ".crush"
	defaultLogLevel      = "info"
	appName              = "crush"

	// MaxTokensFallbackDefault is a fallback maximum-token value; presumably
	// used when a model has no DefaultMaxTokens — TODO confirm at call sites.
	MaxTokensFallbackDefault = 4096
)

// defaultContextPaths seeds Options.ContextPaths in defaultConfigBasedOnEnv;
// mergeOptions appends paths from the global and local config files after
// these. The entries look like the instruction-file conventions of various AI
// coding tools, plus one directory (".cursor/rules/").
// NOTE(review): how these paths are resolved and whether missing ones are
// skipped is decided by their consumers, not here — confirm.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
42
// AgentID identifies an agent; it is the key type of Config.Agents.
type AgentID string

const (
	// AgentCoder is the built-in coding agent; loadConfig gives it all tools.
	AgentCoder AgentID = "coder"
	// AgentTask is the built-in search agent; loadConfig restricts it to
	// glob, grep, ls, sourcegraph and view, with no MCPs or LSPs.
	AgentTask AgentID = "task"
)

// ModelType selects which preferred-model slot (Config.Models.Large or
// Config.Models.Small) an agent or caller uses.
type ModelType string

const (
	LargeModel ModelType = "large"
	SmallModel ModelType = "small"
)
56
// Model describes one model offered by a provider. defaultConfigBasedOnEnv
// copies these values from the catalogue returned by Providers();
// ReasoningEffort is not populated there.
type Model struct {
	ID string `json:"id"`
	// Name is serialized under the "model" key.
	Name string `json:"model"`

	// Per-million-token prices for input and output, with separate rates for
	// cached tokens (the currency is not specified in this file).
	CostPer1MIn        float64 `json:"cost_per_1m_in"`
	CostPer1MOut       float64 `json:"cost_per_1m_out"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`

	// Token limits. NOTE(review): presumably ContextWindow is the total
	// context size and DefaultMaxTokens the default output cap — confirm.
	ContextWindow    int64 `json:"context_window"`
	DefaultMaxTokens int64 `json:"default_max_tokens"`

	CanReason       bool   `json:"can_reason"`
	ReasoningEffort string `json:"reasoning_effort"`
	// SupportsImages is serialized under the "supports_attachments" key.
	SupportsImages bool `json:"supports_attachments"`
}

// VertexAIOptions groups Vertex AI settings.
// NOTE(review): not referenced in this file; defaultConfigBasedOnEnv passes
// Vertex settings through ProviderConfig.ExtraParams instead — confirm whether
// this type is still used.
type VertexAIOptions struct {
	APIKey   string `json:"api_key,omitempty"`
	Project  string `json:"project,omitempty"`
	Location string `json:"location,omitempty"`
}

// ProviderConfig configures one inference provider; Config.Providers maps a
// provider ID to it.
type ProviderConfig struct {
	// ID is the provider identifier; for environment-derived providers it
	// equals the Config.Providers key (set by providerDefaultConfig).
	ID provider.InferenceProvider `json:"id"`
	// BaseURL overrides the provider endpoint. When a config file overrides an
	// existing entry, BaseURL only changes for providers outside
	// provider.KnownProviders() (see mergeProviderConfig).
	BaseURL string `json:"base_url,omitempty"`
	// ProviderType selects the API flavour; custom providers must use
	// provider.TypeOpenAI (see validateProvider).
	ProviderType provider.Type `json:"provider_type"`
	APIKey       string        `json:"api_key,omitempty"`
	// Disabled providers are skipped by getPreferredProvider.
	Disabled bool `json:"disabled"`
	// ExtraHeaders are additional HTTP headers, presumably sent with every
	// request (the OpenRouter default sets HTTP-Referer and X-Title).
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
	// ExtraParams carries provider-specific settings, e.g. "project" and
	// "location" for Vertex AI or "region" for Bedrock
	// (see defaultConfigBasedOnEnv).
	ExtraParams map[string]string `json:"extra_params,omitempty"`

	// Model IDs used to seed Config.Models when this provider is the
	// preferred one (see loadConfig).
	DefaultLargeModel string `json:"default_large_model,omitempty"`
	DefaultSmallModel string `json:"default_small_model,omitempty"`

	// Models lists the models available from this provider.
	Models []Model `json:"models,omitempty"`
}
92
// Agent configures one agent; Config.Agents stores it under its ID.
type Agent struct {
	ID          AgentID `json:"id"`
	Name        string  `json:"name"`
	Description string  `json:"description,omitempty"`

	// Disabled marks the agent as turned off (enforced by consumers of
	// Config.Agents, not in this file).
	Disabled bool `json:"disabled"`

	// Model selects the preferred-model slot (large or small) the agent uses;
	// GetAgentModel falls back to the large model for unknown values.
	Model ModelType `json:"model"`

	// The available tools for the agent
	// if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools"`

	// AllowedMCP says which MCP servers the agent may use (keys presumably
	// match Config.MCP names). Each value lists the tools from that MCP the
	// agent may call; a nil value allows all of that MCP's tools.
	// NOTE(review): the previous comment said an *empty* map allows all MCPs,
	// but loadConfig gives the task agent an empty non-nil map to mean "no
	// MCPs". Presumably nil = all and empty = none — confirm consumers
	// distinguish the two.
	AllowedMCP map[string][]string `json:"allowed_mcp"`

	// The list of LSPs that this agent can use
	// if this is nil, all LSPs are available (loadConfig uses an empty,
	// non-nil slice for the task agent to allow none)
	AllowedLSP []string `json:"allowed_lsp"`

	// Overrides the context paths for this agent; loadConfig initialises the
	// built-in agents with Options.ContextPaths.
	ContextPaths []string `json:"context_paths"`
}
119
// MCPType names the transport used to reach an MCP server.
type MCPType string

const (
	MCPStdio MCPType = "stdio"
	MCPSse   MCPType = "sse"
)

// MCP configures one MCP server; Config.MCP maps a name to each server.
// NOTE(review): presumably Command/Args/Env apply to MCPStdio servers and
// URL/Headers to MCPSse servers, with Env entries in KEY=VALUE form — confirm
// against the MCP client code.
type MCP struct {
	Command string            `json:"command"`
	Env     []string          `json:"env"`
	Args    []string          `json:"args"`
	Type    MCPType           `json:"type"`
	URL     string            `json:"url"`
	Headers map[string]string `json:"headers"`
}
135
// LSPConfig configures one language server; Config.LSP maps a name to it.
type LSPConfig struct {
	// Disabled turns the server off. It is decoded from "disabled". The
	// previous "enabled" tag inverted its meaning: `"enabled": true` in a
	// config file disabled the server.
	Disabled bool `json:"disabled"`
	// Command and Args presumably start the server process — confirm in the
	// LSP client code.
	Command string   `json:"command"`
	Args    []string `json:"args"`
	// Options is decoded opaquely; presumably forwarded to the server as
	// initialization options — TODO confirm.
	Options any `json:"options"`
}
142
// TUIOptions holds terminal-UI settings.
type TUIOptions struct {
	// CompactMode presumably selects a more compact layout — confirm in the
	// TUI code. mergeOptions can only turn it on, never off.
	CompactMode bool `json:"compact_mode"`
	// Here we can add themes later or any TUI related options
}

// Options holds miscellaneous settings. When config files are merged
// (mergeOptions), ContextPaths are appended, boolean flags can only be turned
// on, and a non-empty DataDirectory replaces the current one.
type Options struct {
	// ContextPaths starts as defaultContextPaths.
	ContextPaths []string   `json:"context_paths"`
	TUI          TUIOptions `json:"tui"`
	// Debug is seeded from Init's debug argument. loadConfig picks the log
	// level from that value before config files are merged, so a config-file
	// "debug" does not affect the log level.
	Debug                bool `json:"debug"`
	DebugLSP             bool `json:"debug_lsp"`
	DisableAutoSummarize bool `json:"disable_auto_summarize"`
	// Relative to the cwd. When CRUSH_DEV_DEBUG=true, loadConfig writes
	// debug.log into the default value of this directory.
	DataDirectory string `json:"data_directory"`
}

// PreferredModel names one model: ModelID within the provider whose
// Config.Providers key is Provider.
type PreferredModel struct {
	ModelID  string                     `json:"model_id"`
	Provider provider.InferenceProvider `json:"provider"`
}

// PreferredModels holds the large and small model slots; agents pick one via
// Agent.Model.
type PreferredModels struct {
	Large PreferredModel `json:"large"`
	Small PreferredModel `json:"small"`
}

// Config is the merged application configuration. The same type is decoded
// from the global config file and from the local crush.json (see loadConfig).
type Config struct {
	Models PreferredModels `json:"models"`
	// List of configured providers, keyed by provider ID
	Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"`

	// List of configured agents, keyed by agent ID
	Agents map[AgentID]Agent `json:"agents,omitempty"`

	// List of configured MCPs, keyed by a user-chosen name
	MCP map[string]MCP `json:"mcp,omitempty"`

	// List of configured LSPs, keyed by a user-chosen name
	LSP map[string]LSPConfig `json:"lsp,omitempty"`

	// Miscellaneous options
	Options Options `json:"options"`
}

var (
	// instance is the configuration loaded by Init and returned by Get.
	instance *Config
	// cwd is the working directory passed to Init; see WorkingDirectory.
	cwd string
	// once ensures Init loads the configuration only once.
	once sync.Once
)
192
193func loadConfig(cwd string, debug bool) (*Config, error) {
194 // First read the global config file
195 cfgPath := ConfigPath()
196
197 cfg := defaultConfigBasedOnEnv()
198 cfg.Options.Debug = debug
199 defaultLevel := slog.LevelInfo
200 if cfg.Options.Debug {
201 defaultLevel = slog.LevelDebug
202 }
203 if os.Getenv("CRUSH_DEV_DEBUG") == "true" {
204 loggingFile := fmt.Sprintf("%s/%s", cfg.Options.DataDirectory, "debug.log")
205
206 // if file does not exist create it
207 if _, err := os.Stat(loggingFile); os.IsNotExist(err) {
208 if err := os.MkdirAll(cfg.Options.DataDirectory, 0o755); err != nil {
209 return cfg, fmt.Errorf("failed to create directory: %w", err)
210 }
211 if _, err := os.Create(loggingFile); err != nil {
212 return cfg, fmt.Errorf("failed to create log file: %w", err)
213 }
214 }
215
216 sloggingFileWriter, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666)
217 if err != nil {
218 return cfg, fmt.Errorf("failed to open log file: %w", err)
219 }
220 // Configure logger
221 logger := slog.New(slog.NewTextHandler(sloggingFileWriter, &slog.HandlerOptions{
222 Level: defaultLevel,
223 }))
224 slog.SetDefault(logger)
225 } else {
226 // Configure logger
227 logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{
228 Level: defaultLevel,
229 }))
230 slog.SetDefault(logger)
231 }
232 var globalCfg *Config
233 if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
234 // some other error occurred while checking the file
235 return nil, err
236 } else if err == nil {
237 // config file exists, read it
238 file, err := os.ReadFile(cfgPath)
239 if err != nil {
240 return nil, err
241 }
242 globalCfg = &Config{}
243 if err := json.Unmarshal(file, globalCfg); err != nil {
244 return nil, err
245 }
246 } else {
247 // config file does not exist, create a new one
248 globalCfg = &Config{}
249 }
250
251 var localConfig *Config
252 // Global config loaded, now read the local config file
253 localConfigPath := filepath.Join(cwd, "crush.json")
254 if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
255 // some other error occurred while checking the file
256 return nil, err
257 } else if err == nil {
258 // local config file exists, read it
259 file, err := os.ReadFile(localConfigPath)
260 if err != nil {
261 return nil, err
262 }
263 localConfig = &Config{}
264 if err := json.Unmarshal(file, localConfig); err != nil {
265 return nil, err
266 }
267 }
268
269 // merge options
270 mergeOptions(cfg, globalCfg, localConfig)
271
272 mergeProviderConfigs(cfg, globalCfg, localConfig)
273 // no providers found the app is not initialized yet
274 if len(cfg.Providers) == 0 {
275 return cfg, nil
276 }
277 preferredProvider := getPreferredProvider(cfg.Providers)
278 cfg.Models = PreferredModels{
279 Large: PreferredModel{
280 ModelID: preferredProvider.DefaultLargeModel,
281 Provider: preferredProvider.ID,
282 },
283 Small: PreferredModel{
284 ModelID: preferredProvider.DefaultSmallModel,
285 Provider: preferredProvider.ID,
286 },
287 }
288
289 mergeModels(cfg, globalCfg, localConfig)
290
291 if preferredProvider == nil {
292 return nil, errors.New("no valid providers configured")
293 }
294
295 agents := map[AgentID]Agent{
296 AgentCoder: {
297 ID: AgentCoder,
298 Name: "Coder",
299 Description: "An agent that helps with executing coding tasks.",
300 Model: LargeModel,
301 ContextPaths: cfg.Options.ContextPaths,
302 // All tools allowed
303 },
304 AgentTask: {
305 ID: AgentTask,
306 Name: "Task",
307 Description: "An agent that helps with searching for context and finding implementation details.",
308 Model: LargeModel,
309 ContextPaths: cfg.Options.ContextPaths,
310 AllowedTools: []string{
311 "glob",
312 "grep",
313 "ls",
314 "sourcegraph",
315 "view",
316 },
317 // NO MCPs or LSPs by default
318 AllowedMCP: map[string][]string{},
319 AllowedLSP: []string{},
320 },
321 }
322 cfg.Agents = agents
323 mergeAgents(cfg, globalCfg, localConfig)
324 mergeMCPs(cfg, globalCfg, localConfig)
325 mergeLSPs(cfg, globalCfg, localConfig)
326
327 return cfg, nil
328}
329
330func Init(workingDir string, debug bool) (*Config, error) {
331 var err error
332 once.Do(func() {
333 cwd = workingDir
334 instance, err = loadConfig(cwd, debug)
335 if err != nil {
336 logging.Error("Failed to load config", "error", err)
337 }
338 })
339
340 return instance, err
341}
342
343func Get() *Config {
344 if instance == nil {
345 // TODO: Handle this better
346 panic("Config not initialized. Call InitConfig first.")
347 }
348 return instance
349}
350
351func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
352 providers := Providers()
353 for _, p := range providers {
354 if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
355 return &providerConfig
356 }
357 }
358 // if none found return the first configured provider
359 for _, providerConfig := range configuredProviders {
360 if !providerConfig.Disabled {
361 return &providerConfig
362 }
363 }
364 return nil
365}
366
367func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
368 if other.APIKey != "" {
369 base.APIKey = other.APIKey
370 }
371 // Only change these options if the provider is not a known provider
372 if !slices.Contains(provider.KnownProviders(), p) {
373 if other.BaseURL != "" {
374 base.BaseURL = other.BaseURL
375 }
376 if other.ProviderType != "" {
377 base.ProviderType = other.ProviderType
378 }
379 if len(base.ExtraHeaders) > 0 {
380 if base.ExtraHeaders == nil {
381 base.ExtraHeaders = make(map[string]string)
382 }
383 maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
384 }
385 if len(other.ExtraParams) > 0 {
386 if base.ExtraParams == nil {
387 base.ExtraParams = make(map[string]string)
388 }
389 maps.Copy(base.ExtraParams, other.ExtraParams)
390 }
391 }
392
393 if other.Disabled {
394 base.Disabled = other.Disabled
395 }
396
397 if other.DefaultLargeModel != "" {
398 base.DefaultLargeModel = other.DefaultLargeModel
399 }
400 // Add new models if they don't exist
401 if other.Models != nil {
402 for _, model := range other.Models {
403 // check if the model already exists
404 exists := false
405 for _, existingModel := range base.Models {
406 if existingModel.ID == model.ID {
407 exists = true
408 break
409 }
410 }
411 if !exists {
412 base.Models = append(base.Models, model)
413 }
414 }
415 }
416
417 return base
418}
419
420func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
421 if !slices.Contains(provider.KnownProviders(), p) {
422 if providerConfig.ProviderType != provider.TypeOpenAI {
423 return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
424 }
425 if providerConfig.BaseURL == "" {
426 return errors.New("base URL must be set for custom providers")
427 }
428 if providerConfig.APIKey == "" {
429 return errors.New("API key must be set for custom providers")
430 }
431 }
432 return nil
433}
434
435func mergeModels(base, global, local *Config) {
436 for _, cfg := range []*Config{global, local} {
437 if cfg == nil {
438 continue
439 }
440 if cfg.Models.Large.ModelID != "" && cfg.Models.Large.Provider != "" {
441 base.Models.Large = cfg.Models.Large
442 }
443
444 if cfg.Models.Small.ModelID != "" && cfg.Models.Small.Provider != "" {
445 base.Models.Small = cfg.Models.Small
446 }
447 }
448}
449
450func mergeOptions(base, global, local *Config) {
451 for _, cfg := range []*Config{global, local} {
452 if cfg == nil {
453 continue
454 }
455 baseOptions := base.Options
456 other := cfg.Options
457 if len(other.ContextPaths) > 0 {
458 baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
459 }
460
461 if other.TUI.CompactMode {
462 baseOptions.TUI.CompactMode = other.TUI.CompactMode
463 }
464
465 if other.Debug {
466 baseOptions.Debug = other.Debug
467 }
468
469 if other.DebugLSP {
470 baseOptions.DebugLSP = other.DebugLSP
471 }
472
473 if other.DisableAutoSummarize {
474 baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
475 }
476
477 if other.DataDirectory != "" {
478 baseOptions.DataDirectory = other.DataDirectory
479 }
480 base.Options = baseOptions
481 }
482}
483
484func mergeAgents(base, global, local *Config) {
485 for _, cfg := range []*Config{global, local} {
486 if cfg == nil {
487 continue
488 }
489 for agentID, newAgent := range cfg.Agents {
490 if _, ok := base.Agents[agentID]; !ok {
491 newAgent.ID = agentID // Ensure the ID is set correctly
492 base.Agents[agentID] = newAgent
493 } else {
494 switch agentID {
495 case AgentCoder:
496 baseAgent := base.Agents[agentID]
497 if newAgent.Model != "" {
498 baseAgent.Model = newAgent.Model
499 }
500 baseAgent.AllowedMCP = newAgent.AllowedMCP
501 baseAgent.AllowedLSP = newAgent.AllowedLSP
502 base.Agents[agentID] = baseAgent
503 default:
504 baseAgent := base.Agents[agentID]
505 baseAgent.Name = newAgent.Name
506 baseAgent.Description = newAgent.Description
507 baseAgent.Disabled = newAgent.Disabled
508 if newAgent.Model == "" {
509 baseAgent.Model = LargeModel
510 }
511 baseAgent.AllowedTools = newAgent.AllowedTools
512 baseAgent.AllowedMCP = newAgent.AllowedMCP
513 baseAgent.AllowedLSP = newAgent.AllowedLSP
514 base.Agents[agentID] = baseAgent
515 }
516 }
517 }
518 }
519}
520
521func mergeMCPs(base, global, local *Config) {
522 for _, cfg := range []*Config{global, local} {
523 if cfg == nil {
524 continue
525 }
526 maps.Copy(base.MCP, cfg.MCP)
527 }
528}
529
530func mergeLSPs(base, global, local *Config) {
531 for _, cfg := range []*Config{global, local} {
532 if cfg == nil {
533 continue
534 }
535 maps.Copy(base.LSP, cfg.LSP)
536 }
537}
538
539func mergeProviderConfigs(base, global, local *Config) {
540 for _, cfg := range []*Config{global, local} {
541 if cfg == nil {
542 continue
543 }
544 for providerName, globalProvider := range cfg.Providers {
545 if _, ok := base.Providers[providerName]; !ok {
546 base.Providers[providerName] = globalProvider
547 } else {
548 base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider)
549 }
550 }
551 }
552
553 finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
554 for providerName, providerConfig := range base.Providers {
555 err := validateProvider(providerName, providerConfig)
556 if err != nil {
557 logging.Warn("Skipping provider", "name", providerName, "error", err)
558 }
559 finalProviders[providerName] = providerConfig
560 }
561 base.Providers = finalProviders
562}
563
564func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig {
565 switch providerId {
566 case provider.InferenceProviderAnthropic:
567 return ProviderConfig{
568 ID: providerId,
569 ProviderType: provider.TypeAnthropic,
570 }
571 case provider.InferenceProviderOpenAI:
572 return ProviderConfig{
573 ID: providerId,
574 ProviderType: provider.TypeOpenAI,
575 }
576 case provider.InferenceProviderGemini:
577 return ProviderConfig{
578 ID: providerId,
579 ProviderType: provider.TypeGemini,
580 }
581 case provider.InferenceProviderBedrock:
582 return ProviderConfig{
583 ID: providerId,
584 ProviderType: provider.TypeBedrock,
585 }
586 case provider.InferenceProviderAzure:
587 return ProviderConfig{
588 ID: providerId,
589 ProviderType: provider.TypeAzure,
590 }
591 case provider.InferenceProviderOpenRouter:
592 return ProviderConfig{
593 ID: providerId,
594 ProviderType: provider.TypeOpenAI,
595 BaseURL: "https://openrouter.ai/api/v1",
596 ExtraHeaders: map[string]string{
597 "HTTP-Referer": "crush.charm.land",
598 "X-Title": "Crush",
599 },
600 }
601 case provider.InferenceProviderXAI:
602 return ProviderConfig{
603 ID: providerId,
604 ProviderType: provider.TypeXAI,
605 BaseURL: "https://api.x.ai/v1",
606 }
607 case provider.InferenceProviderVertexAI:
608 return ProviderConfig{
609 ID: providerId,
610 ProviderType: provider.TypeVertexAI,
611 }
612 default:
613 return ProviderConfig{
614 ID: providerId,
615 ProviderType: provider.TypeOpenAI,
616 }
617 }
618}
619
620func defaultConfigBasedOnEnv() *Config {
621 cfg := &Config{
622 Options: Options{
623 DataDirectory: defaultDataDirectory,
624 ContextPaths: defaultContextPaths,
625 },
626 Providers: make(map[provider.InferenceProvider]ProviderConfig),
627 }
628
629 providers := Providers()
630
631 for _, p := range providers {
632 if strings.HasPrefix(p.APIKey, "$") {
633 envVar := strings.TrimPrefix(p.APIKey, "$")
634 if apiKey := os.Getenv(envVar); apiKey != "" {
635 providerConfig := providerDefaultConfig(p.ID)
636 providerConfig.APIKey = apiKey
637 providerConfig.DefaultLargeModel = p.DefaultLargeModelID
638 providerConfig.DefaultSmallModel = p.DefaultSmallModelID
639 baseURL := p.APIEndpoint
640 if strings.HasPrefix(baseURL, "$") {
641 envVar := strings.TrimPrefix(baseURL, "$")
642 if url := os.Getenv(envVar); url != "" {
643 baseURL = url
644 }
645 }
646 providerConfig.BaseURL = baseURL
647 for _, model := range p.Models {
648 providerConfig.Models = append(providerConfig.Models, Model{
649 ID: model.ID,
650 Name: model.Name,
651 CostPer1MIn: model.CostPer1MIn,
652 CostPer1MOut: model.CostPer1MOut,
653 CostPer1MInCached: model.CostPer1MInCached,
654 CostPer1MOutCached: model.CostPer1MOutCached,
655 ContextWindow: model.ContextWindow,
656 DefaultMaxTokens: model.DefaultMaxTokens,
657 CanReason: model.CanReason,
658 SupportsImages: model.SupportsImages,
659 })
660 }
661 cfg.Providers[p.ID] = providerConfig
662 }
663 }
664 }
665 // TODO: support local models
666
667 if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
668 providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
669 providerConfig.ExtraParams = map[string]string{
670 "project": os.Getenv("GOOGLE_CLOUD_PROJECT"),
671 "location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
672 }
673 cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
674 }
675
676 if hasAWSCredentials() {
677 providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
678 providerConfig.ExtraParams = map[string]string{
679 "region": os.Getenv("AWS_DEFAULT_REGION"),
680 }
681 if providerConfig.ExtraParams["region"] == "" {
682 providerConfig.ExtraParams["region"] = os.Getenv("AWS_REGION")
683 }
684 cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
685 }
686 return cfg
687}
688
// hasAWSCredentials reports whether the environment carries AWS settings that
// make Bedrock worth configuring. Any one of these is enough: a static access
// key pair, a named profile, a region, or a container credentials URI.
//
// A region alone counts, so this is a heuristic. It does not verify that
// usable credentials actually exist.
func hasAWSCredentials() bool {
	envSet := func(name string) bool { return os.Getenv(name) != "" }

	switch {
	case envSet("AWS_ACCESS_KEY_ID") && envSet("AWS_SECRET_ACCESS_KEY"):
		return true
	case envSet("AWS_PROFILE") || envSet("AWS_DEFAULT_PROFILE"):
		return true
	case envSet("AWS_REGION") || envSet("AWS_DEFAULT_REGION"):
		return true
	}

	return envSet("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") ||
		envSet("AWS_CONTAINER_CREDENTIALS_FULL_URI")
}
709
710func WorkingDirectory() string {
711 return cwd
712}
713
// TODO: Handle error state: the Get* lookup helpers below return zero values on failure instead of reporting errors to callers.
715
716func GetAgentModel(agentID AgentID) Model {
717 cfg := Get()
718 agent, ok := cfg.Agents[agentID]
719 if !ok {
720 logging.Error("Agent not found", "agent_id", agentID)
721 return Model{}
722 }
723
724 var model PreferredModel
725 switch agent.Model {
726 case LargeModel:
727 model = cfg.Models.Large
728 case SmallModel:
729 model = cfg.Models.Small
730 default:
731 logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
732 model = cfg.Models.Large // Fallback to large model
733 }
734 providerConfig, ok := cfg.Providers[model.Provider]
735 if !ok {
736 logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
737 return Model{}
738 }
739
740 for _, m := range providerConfig.Models {
741 if m.ID == model.ModelID {
742 return m
743 }
744 }
745
746 logging.Error("Model not found for agent", "agent_id", agentID, "model", agent.Model)
747 return Model{}
748}
749
750func GetAgentProvider(agentID AgentID) ProviderConfig {
751 cfg := Get()
752 agent, ok := cfg.Agents[agentID]
753 if !ok {
754 logging.Error("Agent not found", "agent_id", agentID)
755 return ProviderConfig{}
756 }
757
758 var model PreferredModel
759 switch agent.Model {
760 case LargeModel:
761 model = cfg.Models.Large
762 case SmallModel:
763 model = cfg.Models.Small
764 default:
765 logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
766 model = cfg.Models.Large // Fallback to large model
767 }
768
769 providerConfig, ok := cfg.Providers[model.Provider]
770 if !ok {
771 logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
772 return ProviderConfig{}
773 }
774
775 return providerConfig
776}
777
778func GetProviderModel(provider provider.InferenceProvider, modelID string) Model {
779 cfg := Get()
780 providerConfig, ok := cfg.Providers[provider]
781 if !ok {
782 logging.Error("Provider not found", "provider", provider)
783 return Model{}
784 }
785
786 for _, model := range providerConfig.Models {
787 if model.ID == modelID {
788 return model
789 }
790 }
791
792 logging.Error("Model not found for provider", "provider", provider, "model_id", modelID)
793 return Model{}
794}
795
796func GetModel(modelType ModelType) Model {
797 cfg := Get()
798 var model PreferredModel
799 switch modelType {
800 case LargeModel:
801 model = cfg.Models.Large
802 case SmallModel:
803 model = cfg.Models.Small
804 default:
805 model = cfg.Models.Large // Fallback to large model
806 }
807 providerConfig, ok := cfg.Providers[model.Provider]
808 if !ok {
809 return Model{}
810 }
811
812 for _, m := range providerConfig.Models {
813 if m.ID == model.ModelID {
814 return m
815 }
816 }
817 return Model{}
818}
819
820func UpdatePreferredModel(modelType ModelType, model PreferredModel) error {
821 cfg := Get()
822 switch modelType {
823 case LargeModel:
824 cfg.Models.Large = model
825 case SmallModel:
826 cfg.Models.Small = model
827 default:
828 return fmt.Errorf("unknown model type: %s", modelType)
829 }
830 return nil
831}