config.go

  1package config
  2
  3import (
  4	"encoding/json"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"maps"
  9	"os"
 10	"path/filepath"
 11	"slices"
 12	"strings"
 13	"sync"
 14
 15	"github.com/charmbracelet/crush/internal/fur/provider"
 16	"github.com/charmbracelet/crush/internal/logging"
 17)
 18
// Package-level defaults and identifiers used by the configuration loader.
const (
	// defaultDataDirectory is the data directory used when no config overrides it.
	defaultDataDirectory = ".crush"
	defaultLogLevel      = "info"
	appName              = "crush"

	// MaxTokensFallbackDefault is not referenced in this file.
	// NOTE(review): presumably the max-token limit applied when a model
	// does not declare one — confirm against callers.
	MaxTokensFallbackDefault = 4096
)
 26
// defaultContextPaths is the initial value of Options.ContextPaths in the
// default configuration; paths from config files are appended to it.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
 42
// AgentID identifies an agent; it is the key of Config.Agents.
type AgentID string

// Identifiers of the two agents that loadConfig always defines.
const (
	AgentCoder AgentID = "coder"
	AgentTask  AgentID = "task"
)
 49
// ModelType selects which of the preferred models (Config.Models.Large or
// Config.Models.Small) an agent or caller wants to use.
type ModelType string

// The supported model slots.
const (
	LargeModel ModelType = "large"
	SmallModel ModelType = "small"
)
 56
// Model describes a single model offered by a provider.
//
// Note that two JSON keys differ from their field names: Name is serialized
// as "model" and SupportsImages as "supports_attachments".
type Model struct {
	ID                 string  `json:"id"`
	Name               string  `json:"model"`
	CostPer1MIn        float64 `json:"cost_per_1m_in"`
	CostPer1MOut       float64 `json:"cost_per_1m_out"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
	ContextWindow      int64   `json:"context_window"`
	DefaultMaxTokens   int64   `json:"default_max_tokens"`
	CanReason          bool    `json:"can_reason"`
	ReasoningEffort    string  `json:"reasoning_effort"`
	SupportsImages     bool    `json:"supports_attachments"`
}
 70
// VertexAIOptions groups Vertex AI settings; every field is omitted from JSON
// when empty.
// NOTE(review): not referenced elsewhere in this file — confirm it is still used.
type VertexAIOptions struct {
	APIKey   string `json:"api_key,omitempty"`
	Project  string `json:"project,omitempty"`
	Location string `json:"location,omitempty"`
}
 76
// ProviderConfig holds the settings for one inference provider: how to reach
// it, its credentials, its default models and the models it offers.
type ProviderConfig struct {
	ID           provider.InferenceProvider `json:"id"`
	BaseURL      string                     `json:"base_url,omitempty"`
	ProviderType provider.Type              `json:"provider_type"`
	APIKey       string                     `json:"api_key,omitempty"`
	Disabled     bool                       `json:"disabled"`
	ExtraHeaders map[string]string          `json:"extra_headers,omitempty"`
	// Provider-specific parameters, e.g. "project" and "location" for
	// Vertex AI or "region" for Bedrock.
	ExtraParams map[string]string `json:"extra_params,omitempty"`

	// IDs of the models used by default for the large and small slots.
	DefaultLargeModel string `json:"default_large_model,omitempty"`
	DefaultSmallModel string `json:"default_small_model,omitempty"`

	// Models offered by this provider.
	Models []Model `json:"models,omitempty"`
}
 92
// Agent describes one agent: its identity, which model slot it uses and which
// tools, MCPs and LSPs it may access.
type Agent struct {
	ID          AgentID `json:"id"`
	Name        string  `json:"name"`
	Description string  `json:"description,omitempty"`
	// NOTE(review): the comment previously here described a system prompt id,
	// but the field below is the Disabled flag — the old comment looked stale.
	Disabled bool `json:"disabled"`

	// Which preferred model slot (large or small) the agent uses.
	Model ModelType `json:"model"`

	// The available tools for the agent.
	//  If this is nil, all tools are available.
	AllowedTools []string `json:"allowed_tools"`

	// This tells us which MCPs are available for this agent.
	//  If this is empty, all MCPs are available.
	//  The string array is the list of tools from the AllowedMCP the agent has available.
	//  If the string array is nil, all tools from the AllowedMCP are available.
	AllowedMCP map[string][]string `json:"allowed_mcp"`

	// The list of LSPs that this agent can use.
	//  If this is nil, all LSPs are available.
	AllowedLSP []string `json:"allowed_lsp"`

	// Overrides the context paths for this agent.
	ContextPaths []string `json:"context_paths"`
}
119
// MCPType names the kind of an MCP entry (see MCP.Type).
type MCPType string

// The MCP kinds declared by this package.
const (
	MCPStdio MCPType = "stdio"
	MCPSse   MCPType = "sse"
)
126
// MCP is the configuration of one MCP entry, keyed by name in Config.MCP.
// NOTE(review): presumably Command/Env/Args apply to the "stdio" type and
// URL/Headers to the "sse" type — confirm against the code that starts MCPs.
type MCP struct {
	Command string            `json:"command"`
	Env     []string          `json:"env"`
	Args    []string          `json:"args"`
	Type    MCPType           `json:"type"`
	URL     string            `json:"url"`
	Headers map[string]string `json:"headers"`
}
135
// LSPConfig is the configuration of one language server, keyed by name in
// Config.LSP.
type LSPConfig struct {
	// NOTE(review): the Go field is named Disabled but it is read from the
	// JSON key "enabled", so `"enabled": true` sets Disabled to true. This
	// looks inverted; fixing it changes the config file schema, so confirm
	// the intended key before renaming.
	Disabled bool     `json:"enabled"`
	Command  string   `json:"command"`
	Args     []string `json:"args"`
	Options  any      `json:"options"`
}
142
// TUIOptions holds options specific to the terminal UI.
type TUIOptions struct {
	CompactMode bool `json:"compact_mode"`
	// Here we can add themes later or any TUI related options
}
147
// Options holds the miscellaneous, non-provider settings of the application.
// mergeOptions layers the global and local config files over the defaults.
type Options struct {
	ContextPaths         []string   `json:"context_paths"`
	TUI                  TUIOptions `json:"tui"`
	Debug                bool       `json:"debug"`
	DebugLSP             bool       `json:"debug_lsp"`
	DisableAutoSummarize bool       `json:"disable_auto_summarize"`
	// Relative to the cwd
	DataDirectory string `json:"data_directory"`
}
157
// PreferredModel points at one model by its ID and the provider that offers it.
type PreferredModel struct {
	ModelID  string                     `json:"model_id"`
	Provider provider.InferenceProvider `json:"provider"`
}
162
// PreferredModels holds the model selected for each ModelType slot.
type PreferredModels struct {
	Large PreferredModel `json:"large"`
	Small PreferredModel `json:"small"`
}
167
// Config is the root configuration object. It is both the schema of the JSON
// config files and the merged, effective configuration returned by Init/Get.
type Config struct {
	// Preferred large and small models
	Models PreferredModels `json:"models"`
	// List of configured providers
	Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"`

	// List of configured agents
	Agents map[AgentID]Agent `json:"agents,omitempty"`

	// List of configured MCPs
	MCP map[string]MCP `json:"mcp,omitempty"`

	// List of configured LSPs
	LSP map[string]LSPConfig `json:"lsp,omitempty"`

	// Miscellaneous options
	Options Options `json:"options"`
}
185
// Package-level singleton state, populated by Init.
var (
	instance *Config // The single instance of the Singleton
	cwd      string  // Working directory passed to Init; returned by WorkingDirectory
	once     sync.Once // Ensures the initialization happens only once

)
192
193func loadConfig(cwd string, debug bool) (*Config, error) {
194	// First read the global config file
195	cfgPath := ConfigPath()
196
197	cfg := defaultConfigBasedOnEnv()
198	cfg.Options.Debug = debug
199	defaultLevel := slog.LevelInfo
200	if cfg.Options.Debug {
201		defaultLevel = slog.LevelDebug
202	}
203	if os.Getenv("CRUSH_DEV_DEBUG") == "true" {
204		loggingFile := fmt.Sprintf("%s/%s", cfg.Options.DataDirectory, "debug.log")
205
206		// if file does not exist create it
207		if _, err := os.Stat(loggingFile); os.IsNotExist(err) {
208			if err := os.MkdirAll(cfg.Options.DataDirectory, 0o755); err != nil {
209				return cfg, fmt.Errorf("failed to create directory: %w", err)
210			}
211			if _, err := os.Create(loggingFile); err != nil {
212				return cfg, fmt.Errorf("failed to create log file: %w", err)
213			}
214		}
215
216		sloggingFileWriter, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666)
217		if err != nil {
218			return cfg, fmt.Errorf("failed to open log file: %w", err)
219		}
220		// Configure logger
221		logger := slog.New(slog.NewTextHandler(sloggingFileWriter, &slog.HandlerOptions{
222			Level: defaultLevel,
223		}))
224		slog.SetDefault(logger)
225	} else {
226		// Configure logger
227		logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{
228			Level: defaultLevel,
229		}))
230		slog.SetDefault(logger)
231	}
232	var globalCfg *Config
233	if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
234		// some other error occurred while checking the file
235		return nil, err
236	} else if err == nil {
237		// config file exists, read it
238		file, err := os.ReadFile(cfgPath)
239		if err != nil {
240			return nil, err
241		}
242		globalCfg = &Config{}
243		if err := json.Unmarshal(file, globalCfg); err != nil {
244			return nil, err
245		}
246	} else {
247		// config file does not exist, create a new one
248		globalCfg = &Config{}
249	}
250
251	var localConfig *Config
252	// Global config loaded, now read the local config file
253	localConfigPath := filepath.Join(cwd, "crush.json")
254	if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
255		// some other error occurred while checking the file
256		return nil, err
257	} else if err == nil {
258		// local config file exists, read it
259		file, err := os.ReadFile(localConfigPath)
260		if err != nil {
261			return nil, err
262		}
263		localConfig = &Config{}
264		if err := json.Unmarshal(file, localConfig); err != nil {
265			return nil, err
266		}
267	}
268
269	// merge options
270	mergeOptions(cfg, globalCfg, localConfig)
271
272	mergeProviderConfigs(cfg, globalCfg, localConfig)
273	// no providers found the app is not initialized yet
274	if len(cfg.Providers) == 0 {
275		return cfg, nil
276	}
277	preferredProvider := getPreferredProvider(cfg.Providers)
278	cfg.Models = PreferredModels{
279		Large: PreferredModel{
280			ModelID:  preferredProvider.DefaultLargeModel,
281			Provider: preferredProvider.ID,
282		},
283		Small: PreferredModel{
284			ModelID:  preferredProvider.DefaultSmallModel,
285			Provider: preferredProvider.ID,
286		},
287	}
288
289	mergeModels(cfg, globalCfg, localConfig)
290
291	if preferredProvider == nil {
292		return nil, errors.New("no valid providers configured")
293	}
294
295	agents := map[AgentID]Agent{
296		AgentCoder: {
297			ID:           AgentCoder,
298			Name:         "Coder",
299			Description:  "An agent that helps with executing coding tasks.",
300			Model:        LargeModel,
301			ContextPaths: cfg.Options.ContextPaths,
302			// All tools allowed
303		},
304		AgentTask: {
305			ID:           AgentTask,
306			Name:         "Task",
307			Description:  "An agent that helps with searching for context and finding implementation details.",
308			Model:        LargeModel,
309			ContextPaths: cfg.Options.ContextPaths,
310			AllowedTools: []string{
311				"glob",
312				"grep",
313				"ls",
314				"sourcegraph",
315				"view",
316			},
317			// NO MCPs or LSPs by default
318			AllowedMCP: map[string][]string{},
319			AllowedLSP: []string{},
320		},
321	}
322	cfg.Agents = agents
323	mergeAgents(cfg, globalCfg, localConfig)
324	mergeMCPs(cfg, globalCfg, localConfig)
325	mergeLSPs(cfg, globalCfg, localConfig)
326
327	return cfg, nil
328}
329
330func Init(workingDir string, debug bool) (*Config, error) {
331	var err error
332	once.Do(func() {
333		cwd = workingDir
334		instance, err = loadConfig(cwd, debug)
335		if err != nil {
336			logging.Error("Failed to load config", "error", err)
337		}
338	})
339
340	return instance, err
341}
342
343func Get() *Config {
344	if instance == nil {
345		// TODO: Handle this better
346		panic("Config not initialized. Call InitConfig first.")
347	}
348	return instance
349}
350
351func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
352	providers := Providers()
353	for _, p := range providers {
354		if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
355			return &providerConfig
356		}
357	}
358	// if none found return the first configured provider
359	for _, providerConfig := range configuredProviders {
360		if !providerConfig.Disabled {
361			return &providerConfig
362		}
363	}
364	return nil
365}
366
367func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
368	if other.APIKey != "" {
369		base.APIKey = other.APIKey
370	}
371	// Only change these options if the provider is not a known provider
372	if !slices.Contains(provider.KnownProviders(), p) {
373		if other.BaseURL != "" {
374			base.BaseURL = other.BaseURL
375		}
376		if other.ProviderType != "" {
377			base.ProviderType = other.ProviderType
378		}
379		if len(base.ExtraHeaders) > 0 {
380			if base.ExtraHeaders == nil {
381				base.ExtraHeaders = make(map[string]string)
382			}
383			maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
384		}
385		if len(other.ExtraParams) > 0 {
386			if base.ExtraParams == nil {
387				base.ExtraParams = make(map[string]string)
388			}
389			maps.Copy(base.ExtraParams, other.ExtraParams)
390		}
391	}
392
393	if other.Disabled {
394		base.Disabled = other.Disabled
395	}
396
397	if other.DefaultLargeModel != "" {
398		base.DefaultLargeModel = other.DefaultLargeModel
399	}
400	// Add new models if they don't exist
401	if other.Models != nil {
402		for _, model := range other.Models {
403			// check if the model already exists
404			exists := false
405			for _, existingModel := range base.Models {
406				if existingModel.ID == model.ID {
407					exists = true
408					break
409				}
410			}
411			if !exists {
412				base.Models = append(base.Models, model)
413			}
414		}
415	}
416
417	return base
418}
419
420func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
421	if !slices.Contains(provider.KnownProviders(), p) {
422		if providerConfig.ProviderType != provider.TypeOpenAI {
423			return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
424		}
425		if providerConfig.BaseURL == "" {
426			return errors.New("base URL must be set for custom providers")
427		}
428		if providerConfig.APIKey == "" {
429			return errors.New("API key must be set for custom providers")
430		}
431	}
432	return nil
433}
434
435func mergeModels(base, global, local *Config) {
436	for _, cfg := range []*Config{global, local} {
437		if cfg == nil {
438			continue
439		}
440		if cfg.Models.Large.ModelID != "" && cfg.Models.Large.Provider != "" {
441			base.Models.Large = cfg.Models.Large
442		}
443
444		if cfg.Models.Small.ModelID != "" && cfg.Models.Small.Provider != "" {
445			base.Models.Small = cfg.Models.Small
446		}
447	}
448}
449
450func mergeOptions(base, global, local *Config) {
451	for _, cfg := range []*Config{global, local} {
452		if cfg == nil {
453			continue
454		}
455		baseOptions := base.Options
456		other := cfg.Options
457		if len(other.ContextPaths) > 0 {
458			baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
459		}
460
461		if other.TUI.CompactMode {
462			baseOptions.TUI.CompactMode = other.TUI.CompactMode
463		}
464
465		if other.Debug {
466			baseOptions.Debug = other.Debug
467		}
468
469		if other.DebugLSP {
470			baseOptions.DebugLSP = other.DebugLSP
471		}
472
473		if other.DisableAutoSummarize {
474			baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
475		}
476
477		if other.DataDirectory != "" {
478			baseOptions.DataDirectory = other.DataDirectory
479		}
480		base.Options = baseOptions
481	}
482}
483
484func mergeAgents(base, global, local *Config) {
485	for _, cfg := range []*Config{global, local} {
486		if cfg == nil {
487			continue
488		}
489		for agentID, newAgent := range cfg.Agents {
490			if _, ok := base.Agents[agentID]; !ok {
491				newAgent.ID = agentID // Ensure the ID is set correctly
492				base.Agents[agentID] = newAgent
493			} else {
494				switch agentID {
495				case AgentCoder:
496					baseAgent := base.Agents[agentID]
497					if newAgent.Model != "" {
498						baseAgent.Model = newAgent.Model
499					}
500					baseAgent.AllowedMCP = newAgent.AllowedMCP
501					baseAgent.AllowedLSP = newAgent.AllowedLSP
502					base.Agents[agentID] = baseAgent
503				default:
504					baseAgent := base.Agents[agentID]
505					baseAgent.Name = newAgent.Name
506					baseAgent.Description = newAgent.Description
507					baseAgent.Disabled = newAgent.Disabled
508					if newAgent.Model == "" {
509						baseAgent.Model = LargeModel
510					}
511					baseAgent.AllowedTools = newAgent.AllowedTools
512					baseAgent.AllowedMCP = newAgent.AllowedMCP
513					baseAgent.AllowedLSP = newAgent.AllowedLSP
514					base.Agents[agentID] = baseAgent
515				}
516			}
517		}
518	}
519}
520
521func mergeMCPs(base, global, local *Config) {
522	for _, cfg := range []*Config{global, local} {
523		if cfg == nil {
524			continue
525		}
526		maps.Copy(base.MCP, cfg.MCP)
527	}
528}
529
530func mergeLSPs(base, global, local *Config) {
531	for _, cfg := range []*Config{global, local} {
532		if cfg == nil {
533			continue
534		}
535		maps.Copy(base.LSP, cfg.LSP)
536	}
537}
538
539func mergeProviderConfigs(base, global, local *Config) {
540	for _, cfg := range []*Config{global, local} {
541		if cfg == nil {
542			continue
543		}
544		for providerName, globalProvider := range cfg.Providers {
545			if _, ok := base.Providers[providerName]; !ok {
546				base.Providers[providerName] = globalProvider
547			} else {
548				base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider)
549			}
550		}
551	}
552
553	finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
554	for providerName, providerConfig := range base.Providers {
555		err := validateProvider(providerName, providerConfig)
556		if err != nil {
557			logging.Warn("Skipping provider", "name", providerName, "error", err)
558		}
559		finalProviders[providerName] = providerConfig
560	}
561	base.Providers = finalProviders
562}
563
564func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig {
565	switch providerId {
566	case provider.InferenceProviderAnthropic:
567		return ProviderConfig{
568			ID:           providerId,
569			ProviderType: provider.TypeAnthropic,
570		}
571	case provider.InferenceProviderOpenAI:
572		return ProviderConfig{
573			ID:           providerId,
574			ProviderType: provider.TypeOpenAI,
575		}
576	case provider.InferenceProviderGemini:
577		return ProviderConfig{
578			ID:           providerId,
579			ProviderType: provider.TypeGemini,
580		}
581	case provider.InferenceProviderBedrock:
582		return ProviderConfig{
583			ID:           providerId,
584			ProviderType: provider.TypeBedrock,
585		}
586	case provider.InferenceProviderAzure:
587		return ProviderConfig{
588			ID:           providerId,
589			ProviderType: provider.TypeAzure,
590		}
591	case provider.InferenceProviderOpenRouter:
592		return ProviderConfig{
593			ID:           providerId,
594			ProviderType: provider.TypeOpenAI,
595			BaseURL:      "https://openrouter.ai/api/v1",
596			ExtraHeaders: map[string]string{
597				"HTTP-Referer": "crush.charm.land",
598				"X-Title":      "Crush",
599			},
600		}
601	case provider.InferenceProviderXAI:
602		return ProviderConfig{
603			ID:           providerId,
604			ProviderType: provider.TypeXAI,
605			BaseURL:      "https://api.x.ai/v1",
606		}
607	case provider.InferenceProviderVertexAI:
608		return ProviderConfig{
609			ID:           providerId,
610			ProviderType: provider.TypeVertexAI,
611		}
612	default:
613		return ProviderConfig{
614			ID:           providerId,
615			ProviderType: provider.TypeOpenAI,
616		}
617	}
618}
619
620func defaultConfigBasedOnEnv() *Config {
621	cfg := &Config{
622		Options: Options{
623			DataDirectory: defaultDataDirectory,
624			ContextPaths:  defaultContextPaths,
625		},
626		Providers: make(map[provider.InferenceProvider]ProviderConfig),
627	}
628
629	providers := Providers()
630
631	for _, p := range providers {
632		if strings.HasPrefix(p.APIKey, "$") {
633			envVar := strings.TrimPrefix(p.APIKey, "$")
634			if apiKey := os.Getenv(envVar); apiKey != "" {
635				providerConfig := providerDefaultConfig(p.ID)
636				providerConfig.APIKey = apiKey
637				providerConfig.DefaultLargeModel = p.DefaultLargeModelID
638				providerConfig.DefaultSmallModel = p.DefaultSmallModelID
639				baseURL := p.APIEndpoint
640				if strings.HasPrefix(baseURL, "$") {
641					envVar := strings.TrimPrefix(baseURL, "$")
642					if url := os.Getenv(envVar); url != "" {
643						baseURL = url
644					}
645				}
646				providerConfig.BaseURL = baseURL
647				for _, model := range p.Models {
648					providerConfig.Models = append(providerConfig.Models, Model{
649						ID:                 model.ID,
650						Name:               model.Name,
651						CostPer1MIn:        model.CostPer1MIn,
652						CostPer1MOut:       model.CostPer1MOut,
653						CostPer1MInCached:  model.CostPer1MInCached,
654						CostPer1MOutCached: model.CostPer1MOutCached,
655						ContextWindow:      model.ContextWindow,
656						DefaultMaxTokens:   model.DefaultMaxTokens,
657						CanReason:          model.CanReason,
658						SupportsImages:     model.SupportsImages,
659					})
660				}
661				cfg.Providers[p.ID] = providerConfig
662			}
663		}
664	}
665	// TODO: support local models
666
667	if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
668		providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
669		providerConfig.ExtraParams = map[string]string{
670			"project":  os.Getenv("GOOGLE_CLOUD_PROJECT"),
671			"location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
672		}
673		cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
674	}
675
676	if hasAWSCredentials() {
677		providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
678		providerConfig.ExtraParams = map[string]string{
679			"region": os.Getenv("AWS_DEFAULT_REGION"),
680		}
681		if providerConfig.ExtraParams["region"] == "" {
682			providerConfig.ExtraParams["region"] = os.Getenv("AWS_REGION")
683		}
684		cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
685	}
686	return cfg
687}
688
// hasAWSCredentials reports whether the environment contains any of the AWS
// settings this package treats as a sign that AWS is configured: an access
// key pair, a profile, a region, or a container credentials URI.
func hasAWSCredentials() bool {
	isSet := func(name string) bool {
		return os.Getenv(name) != ""
	}

	switch {
	case isSet("AWS_ACCESS_KEY_ID") && isSet("AWS_SECRET_ACCESS_KEY"):
		// Static credentials need both halves of the key pair.
		return true
	case isSet("AWS_PROFILE") || isSet("AWS_DEFAULT_PROFILE"):
		return true
	case isSet("AWS_REGION") || isSet("AWS_DEFAULT_REGION"):
		return true
	case isSet("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") || isSet("AWS_CONTAINER_CREDENTIALS_FULL_URI"):
		return true
	}
	return false
}
709
// WorkingDirectory returns the working directory that was passed to Init.
// It is the empty string until Init has run.
func WorkingDirectory() string {
	return cwd
}
713
714// TODO: Handle error state
715
716func GetAgentModel(agentID AgentID) Model {
717	cfg := Get()
718	agent, ok := cfg.Agents[agentID]
719	if !ok {
720		logging.Error("Agent not found", "agent_id", agentID)
721		return Model{}
722	}
723
724	var model PreferredModel
725	switch agent.Model {
726	case LargeModel:
727		model = cfg.Models.Large
728	case SmallModel:
729		model = cfg.Models.Small
730	default:
731		logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
732		model = cfg.Models.Large // Fallback to large model
733	}
734	providerConfig, ok := cfg.Providers[model.Provider]
735	if !ok {
736		logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
737		return Model{}
738	}
739
740	for _, m := range providerConfig.Models {
741		if m.ID == model.ModelID {
742			return m
743		}
744	}
745
746	logging.Error("Model not found for agent", "agent_id", agentID, "model", agent.Model)
747	return Model{}
748}
749
750func GetAgentProvider(agentID AgentID) ProviderConfig {
751	cfg := Get()
752	agent, ok := cfg.Agents[agentID]
753	if !ok {
754		logging.Error("Agent not found", "agent_id", agentID)
755		return ProviderConfig{}
756	}
757
758	var model PreferredModel
759	switch agent.Model {
760	case LargeModel:
761		model = cfg.Models.Large
762	case SmallModel:
763		model = cfg.Models.Small
764	default:
765		logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
766		model = cfg.Models.Large // Fallback to large model
767	}
768
769	providerConfig, ok := cfg.Providers[model.Provider]
770	if !ok {
771		logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
772		return ProviderConfig{}
773	}
774
775	return providerConfig
776}
777
778func GetProviderModel(provider provider.InferenceProvider, modelID string) Model {
779	cfg := Get()
780	providerConfig, ok := cfg.Providers[provider]
781	if !ok {
782		logging.Error("Provider not found", "provider", provider)
783		return Model{}
784	}
785
786	for _, model := range providerConfig.Models {
787		if model.ID == modelID {
788			return model
789		}
790	}
791
792	logging.Error("Model not found for provider", "provider", provider, "model_id", modelID)
793	return Model{}
794}
795
796func GetModel(modelType ModelType) Model {
797	cfg := Get()
798	var model PreferredModel
799	switch modelType {
800	case LargeModel:
801		model = cfg.Models.Large
802	case SmallModel:
803		model = cfg.Models.Small
804	default:
805		model = cfg.Models.Large // Fallback to large model
806	}
807	providerConfig, ok := cfg.Providers[model.Provider]
808	if !ok {
809		return Model{}
810	}
811
812	for _, m := range providerConfig.Models {
813		if m.ID == model.ModelID {
814			return m
815		}
816	}
817	return Model{}
818}
819
820func UpdatePreferredModel(modelType ModelType, model PreferredModel) error {
821	cfg := Get()
822	switch modelType {
823	case LargeModel:
824		cfg.Models.Large = model
825	case SmallModel:
826		cfg.Models.Small = model
827	default:
828		return fmt.Errorf("unknown model type: %s", modelType)
829	}
830	return nil
831}