config.go

  1package config
  2
  3import (
  4	"encoding/json"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"maps"
  9	"os"
 10	"path/filepath"
 11	"slices"
 12	"strings"
 13	"sync"
 14
 15	"github.com/charmbracelet/crush/internal/fur/provider"
 16	"github.com/charmbracelet/crush/internal/logging"
 17)
 18
 19const (
 20	defaultDataDirectory = ".crush"
 21	defaultLogLevel      = "info"
 22	appName              = "crush"
 23
 24	MaxTokensFallbackDefault = 4096
 25)
 26
 27var defaultContextPaths = []string{
 28	".github/copilot-instructions.md",
 29	".cursorrules",
 30	".cursor/rules/",
 31	"CLAUDE.md",
 32	"CLAUDE.local.md",
 33	"GEMINI.md",
 34	"gemini.md",
 35	"crush.md",
 36	"crush.local.md",
 37	"Crush.md",
 38	"Crush.local.md",
 39	"CRUSH.md",
 40	"CRUSH.local.md",
 41}
 42
 43type AgentID string
 44
 45const (
 46	AgentCoder AgentID = "coder"
 47	AgentTask  AgentID = "task"
 48)
 49
 50type ModelType string
 51
 52const (
 53	LargeModel ModelType = "large"
 54	SmallModel ModelType = "small"
 55)
 56
// Model describes one model a provider offers: pricing, context window,
// default max-token setting and capability flags. defaultConfigBasedOnEnv
// copies these values from the provider catalog, and config files can append
// further models to a provider (see mergeProviderConfig).
//
// ID is what PreferredModel.ModelID is matched against. Two JSON keys differ
// from the field names: Name is stored under "model", and SupportsImages
// under "supports_attachments".
type Model struct {
	ID                 string  `json:"id"`
	Name               string  `json:"model"`
	CostPer1MIn        float64 `json:"cost_per_1m_in"`
	CostPer1MOut       float64 `json:"cost_per_1m_out"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
	ContextWindow      int64   `json:"context_window"`
	DefaultMaxTokens   int64   `json:"default_max_tokens"`
	CanReason          bool    `json:"can_reason"`
	ReasoningEffort    string  `json:"reasoning_effort"`
	SupportsImages     bool    `json:"supports_attachments"`
}
 70
// VertexAIOptions holds Vertex AI connection settings.
//
// NOTE(review): this type is not referenced in the code shown here. The
// env-based Vertex AI setup passes project and location through
// ProviderConfig.ExtraParams instead — confirm whether this type is still used.
type VertexAIOptions struct {
	APIKey   string `json:"api_key,omitempty"`
	Project  string `json:"project,omitempty"`
	Location string `json:"location,omitempty"`
}
 76
// ProviderConfig configures one inference provider. It is stored in
// Config.Providers under its provider ID.
//
// When several config layers declare the same provider, mergeProviderConfig
// decides what a later layer may change. Notably, BaseURL, ProviderType,
// ExtraHeaders and ExtraParams can only be overridden for providers that are
// not in provider.KnownProviders(). Disabled providers are skipped when
// choosing the preferred provider.
type ProviderConfig struct {
	ID           provider.InferenceProvider `json:"id"`
	BaseURL      string                     `json:"base_url,omitempty"`
	ProviderType provider.Type              `json:"provider_type"`
	APIKey       string                     `json:"api_key,omitempty"`
	Disabled     bool                       `json:"disabled"`
	ExtraHeaders map[string]string          `json:"extra_headers,omitempty"`
	// Provider-specific parameters, e.g. "project"/"location" for Vertex AI
	// or "region" for Bedrock (see defaultConfigBasedOnEnv).
	ExtraParams map[string]string `json:"extra_params,omitempty"`

	// Model IDs used for the large/small preferred models when this provider
	// is picked as the preferred provider.
	DefaultLargeModel string `json:"default_large_model,omitempty"`
	DefaultSmallModel string `json:"default_small_model,omitempty"`

	// Models offered by this provider; PreferredModel.ModelID is looked up here.
	Models []Model `json:"models,omitempty"`
}
 92
// Agent configures one agent. loadConfig defines the built-in AgentCoder and
// AgentTask agents. The global and local config files may override parts of
// them or declare further agents (see mergeAgents).
type Agent struct {
	ID          AgentID `json:"id"`
	Name        string  `json:"name"`
	Description string  `json:"description,omitempty"`
	// Disabled marks the agent as turned off.
	// NOTE(review): not read by the code shown here, and mergeAgents does not
	// let config layers change it for the built-in coder/task agents.
	Disabled bool `json:"disabled"`

	// Model selects which preferred model slot (LargeModel or SmallModel) the
	// agent uses. mergeAgents fills an empty value with LargeModel.
	Model ModelType `json:"model"`

	// The available tools for the agent.
	// If this is nil, all tools are available.
	AllowedTools []string `json:"allowed_tools"`

	// The MCP servers available to this agent.
	// If this is empty, all MCPs are available.
	// Each value lists the tools of that MCP the agent may use; a nil list
	// means all tools of that MCP are available.
	AllowedMCP map[string][]string `json:"allowed_mcp"`

	// The list of LSPs that this agent can use.
	// If this is nil, all LSPs are available.
	AllowedLSP []string `json:"allowed_lsp"`

	// Extra context paths for this agent. mergeAgents appends them to the
	// global Options.ContextPaths rather than replacing those.
	ContextPaths []string `json:"context_paths"`
}
119
// MCPType is the value of MCP.Type and selects how an MCP server is reached.
type MCPType string

// MCPStdio is the MCP.Type value "stdio".
const MCPStdio MCPType = "stdio"

// MCPSse is the MCP.Type value "sse".
const MCPSse MCPType = "sse"
126
// MCP configures one MCP server. It is stored in Config.MCP under a string
// key (presumably the server name). Later config layers replace entries with
// the same key wholesale.
//
// NOTE(review): presumably Command, Args and Env apply to MCPStdio servers,
// and URL and Headers to MCPSse servers — confirm against the MCP client.
type MCP struct {
	Command string            `json:"command"`
	Env     []string          `json:"env"`
	Args    []string          `json:"args"`
	Type    MCPType           `json:"type"`
	URL     string            `json:"url"`
	Headers map[string]string `json:"headers"`
}
135
// LSPConfig configures one language server. It is stored in Config.LSP under
// a string key, and later config layers replace entries with the same key
// wholesale.
type LSPConfig struct {
	// Disabled turns this language server off. Its JSON key used to be
	// "enabled", which inverted the meaning: "enabled": true disabled the
	// server. The key now matches the field.
	Disabled bool     `json:"disabled"`
	Command  string   `json:"command"`
	Args     []string `json:"args"`
	// Options is passed through as arbitrary JSON; its schema is not defined here.
	Options any `json:"options"`
}
142
// TUIOptions holds settings for the terminal UI.
type TUIOptions struct {
	// CompactMode can be switched on by any config layer; mergeOptions never
	// switches it back off.
	CompactMode bool `json:"compact_mode"`
	// Themes or other TUI-related options can be added here later.
}
147
// Options holds miscellaneous settings.
//
// When layers are merged (mergeOptions):
//   - boolean options can only be switched on, never off;
//   - ContextPaths from each layer are appended after the defaults;
//   - a non-empty DataDirectory replaces the previous one.
type Options struct {
	ContextPaths         []string   `json:"context_paths"`
	TUI                  TUIOptions `json:"tui"`
	Debug                bool       `json:"debug"`
	DebugLSP             bool       `json:"debug_lsp"`
	DisableAutoSummarize bool       `json:"disable_auto_summarize"`
	// Where Crush stores its data (default ".crush"), relative to the cwd.
	DataDirectory string `json:"data_directory"`
}
157
// PreferredModel names a model by provider and model ID. ModelID is matched
// against Model.ID within Config.Providers[Provider].Models.
type PreferredModel struct {
	ModelID  string                     `json:"model_id"`
	Provider provider.InferenceProvider `json:"provider"`
}
162
// PreferredModels holds the two model slots that agents select with
// LargeModel and SmallModel.
type PreferredModels struct {
	Large PreferredModel `json:"large"`
	Small PreferredModel `json:"small"`
}
167
// Config is the merged Crush configuration: environment-derived defaults
// overlaid by the global config file and then the project-local crush.json
// (see loadConfig).
type Config struct {
	// Preferred large/small model selections.
	Models PreferredModels `json:"models"`
	// Configured providers, keyed by provider ID.
	Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"`

	// Configured agents, keyed by agent ID.
	Agents map[AgentID]Agent `json:"agents,omitempty"`

	// Configured MCP servers.
	MCP map[string]MCP `json:"mcp,omitempty"`

	// Configured language servers.
	LSP map[string]LSPConfig `json:"lsp,omitempty"`

	// Miscellaneous options.
	Options Options `json:"options"`
}
185
186var (
187	instance *Config // The single instance of the Singleton
188	cwd      string
189	once     sync.Once // Ensures the initialization happens only once
190
191)
192
193func loadConfig(cwd string, debug bool) (*Config, error) {
194	// First read the global config file
195	cfgPath := ConfigPath()
196
197	cfg := defaultConfigBasedOnEnv()
198	cfg.Options.Debug = debug
199	defaultLevel := slog.LevelInfo
200	if cfg.Options.Debug {
201		defaultLevel = slog.LevelDebug
202	}
203	if os.Getenv("CRUSH_DEV_DEBUG") == "true" {
204		loggingFile := fmt.Sprintf("%s/%s", cfg.Options.DataDirectory, "debug.log")
205
206		// if file does not exist create it
207		if _, err := os.Stat(loggingFile); os.IsNotExist(err) {
208			if err := os.MkdirAll(cfg.Options.DataDirectory, 0o755); err != nil {
209				return cfg, fmt.Errorf("failed to create directory: %w", err)
210			}
211			if _, err := os.Create(loggingFile); err != nil {
212				return cfg, fmt.Errorf("failed to create log file: %w", err)
213			}
214		}
215
216		sloggingFileWriter, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666)
217		if err != nil {
218			return cfg, fmt.Errorf("failed to open log file: %w", err)
219		}
220		// Configure logger
221		logger := slog.New(slog.NewTextHandler(sloggingFileWriter, &slog.HandlerOptions{
222			Level: defaultLevel,
223		}))
224		slog.SetDefault(logger)
225	} else {
226		// Configure logger
227		logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{
228			Level: defaultLevel,
229		}))
230		slog.SetDefault(logger)
231	}
232	var globalCfg *Config
233	if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
234		// some other error occurred while checking the file
235		return nil, err
236	} else if err == nil {
237		// config file exists, read it
238		file, err := os.ReadFile(cfgPath)
239		if err != nil {
240			return nil, err
241		}
242		globalCfg = &Config{}
243		if err := json.Unmarshal(file, globalCfg); err != nil {
244			return nil, err
245		}
246	} else {
247		// config file does not exist, create a new one
248		globalCfg = &Config{}
249	}
250
251	var localConfig *Config
252	// Global config loaded, now read the local config file
253	localConfigPath := filepath.Join(cwd, "crush.json")
254	if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
255		// some other error occurred while checking the file
256		return nil, err
257	} else if err == nil {
258		// local config file exists, read it
259		file, err := os.ReadFile(localConfigPath)
260		if err != nil {
261			return nil, err
262		}
263		localConfig = &Config{}
264		if err := json.Unmarshal(file, localConfig); err != nil {
265			return nil, err
266		}
267	}
268
269	// merge options
270	mergeOptions(cfg, globalCfg, localConfig)
271
272	mergeProviderConfigs(cfg, globalCfg, localConfig)
273	// no providers found the app is not initialized yet
274	if len(cfg.Providers) == 0 {
275		return cfg, nil
276	}
277	preferredProvider := getPreferredProvider(cfg.Providers)
278	if preferredProvider != nil {
279		cfg.Models = PreferredModels{
280			Large: PreferredModel{
281				ModelID:  preferredProvider.DefaultLargeModel,
282				Provider: preferredProvider.ID,
283			},
284			Small: PreferredModel{
285				ModelID:  preferredProvider.DefaultSmallModel,
286				Provider: preferredProvider.ID,
287			},
288		}
289	} else {
290		// No valid providers found, set empty models
291		cfg.Models = PreferredModels{}
292	}
293
294	mergeModels(cfg, globalCfg, localConfig)
295
296	agents := map[AgentID]Agent{
297		AgentCoder: {
298			ID:           AgentCoder,
299			Name:         "Coder",
300			Description:  "An agent that helps with executing coding tasks.",
301			Model:        LargeModel,
302			ContextPaths: cfg.Options.ContextPaths,
303			// All tools allowed
304		},
305		AgentTask: {
306			ID:           AgentTask,
307			Name:         "Task",
308			Description:  "An agent that helps with searching for context and finding implementation details.",
309			Model:        LargeModel,
310			ContextPaths: cfg.Options.ContextPaths,
311			AllowedTools: []string{
312				"glob",
313				"grep",
314				"ls",
315				"sourcegraph",
316				"view",
317			},
318			// NO MCPs or LSPs by default
319			AllowedMCP: map[string][]string{},
320			AllowedLSP: []string{},
321		},
322	}
323	cfg.Agents = agents
324	mergeAgents(cfg, globalCfg, localConfig)
325	mergeMCPs(cfg, globalCfg, localConfig)
326	mergeLSPs(cfg, globalCfg, localConfig)
327
328	return cfg, nil
329}
330
331func Init(workingDir string, debug bool) (*Config, error) {
332	var err error
333	once.Do(func() {
334		cwd = workingDir
335		instance, err = loadConfig(cwd, debug)
336		if err != nil {
337			logging.Error("Failed to load config", "error", err)
338		}
339	})
340
341	return instance, err
342}
343
344func Get() *Config {
345	if instance == nil {
346		// TODO: Handle this better
347		panic("Config not initialized. Call InitConfig first.")
348	}
349	return instance
350}
351
352func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
353	providers := Providers()
354	for _, p := range providers {
355		if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
356			return &providerConfig
357		}
358	}
359	// if none found return the first configured provider
360	for _, providerConfig := range configuredProviders {
361		if !providerConfig.Disabled {
362			return &providerConfig
363		}
364	}
365	return nil
366}
367
368func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
369	if other.APIKey != "" {
370		base.APIKey = other.APIKey
371	}
372	// Only change these options if the provider is not a known provider
373	if !slices.Contains(provider.KnownProviders(), p) {
374		if other.BaseURL != "" {
375			base.BaseURL = other.BaseURL
376		}
377		if other.ProviderType != "" {
378			base.ProviderType = other.ProviderType
379		}
380		if len(other.ExtraHeaders) > 0 {
381			if base.ExtraHeaders == nil {
382				base.ExtraHeaders = make(map[string]string)
383			}
384			maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
385		}
386		if len(other.ExtraParams) > 0 {
387			if base.ExtraParams == nil {
388				base.ExtraParams = make(map[string]string)
389			}
390			maps.Copy(base.ExtraParams, other.ExtraParams)
391		}
392	}
393
394	if other.Disabled {
395		base.Disabled = other.Disabled
396	}
397
398	if other.DefaultLargeModel != "" {
399		base.DefaultLargeModel = other.DefaultLargeModel
400	}
401	// Add new models if they don't exist
402	if other.Models != nil {
403		for _, model := range other.Models {
404			// check if the model already exists
405			exists := false
406			for _, existingModel := range base.Models {
407				if existingModel.ID == model.ID {
408					exists = true
409					break
410				}
411			}
412			if !exists {
413				base.Models = append(base.Models, model)
414			}
415		}
416	}
417
418	return base
419}
420
421func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
422	if !slices.Contains(provider.KnownProviders(), p) {
423		if providerConfig.ProviderType != provider.TypeOpenAI {
424			return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
425		}
426		if providerConfig.BaseURL == "" {
427			return errors.New("base URL must be set for custom providers")
428		}
429		if providerConfig.APIKey == "" {
430			return errors.New("API key must be set for custom providers")
431		}
432	}
433	return nil
434}
435
436func mergeModels(base, global, local *Config) {
437	for _, cfg := range []*Config{global, local} {
438		if cfg == nil {
439			continue
440		}
441		if cfg.Models.Large.ModelID != "" && cfg.Models.Large.Provider != "" {
442			base.Models.Large = cfg.Models.Large
443		}
444
445		if cfg.Models.Small.ModelID != "" && cfg.Models.Small.Provider != "" {
446			base.Models.Small = cfg.Models.Small
447		}
448	}
449}
450
451func mergeOptions(base, global, local *Config) {
452	for _, cfg := range []*Config{global, local} {
453		if cfg == nil {
454			continue
455		}
456		baseOptions := base.Options
457		other := cfg.Options
458		if len(other.ContextPaths) > 0 {
459			baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
460		}
461
462		if other.TUI.CompactMode {
463			baseOptions.TUI.CompactMode = other.TUI.CompactMode
464		}
465
466		if other.Debug {
467			baseOptions.Debug = other.Debug
468		}
469
470		if other.DebugLSP {
471			baseOptions.DebugLSP = other.DebugLSP
472		}
473
474		if other.DisableAutoSummarize {
475			baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
476		}
477
478		if other.DataDirectory != "" {
479			baseOptions.DataDirectory = other.DataDirectory
480		}
481		base.Options = baseOptions
482	}
483}
484
485func mergeAgents(base, global, local *Config) {
486	for _, cfg := range []*Config{global, local} {
487		if cfg == nil {
488			continue
489		}
490		for agentID, newAgent := range cfg.Agents {
491			if _, ok := base.Agents[agentID]; !ok {
492				// New agent - apply defaults
493				newAgent.ID = agentID // Ensure the ID is set correctly
494				if newAgent.Model == "" {
495					newAgent.Model = LargeModel // Default model type
496				}
497				// Context paths are always additive - start with global, then add custom
498				if len(newAgent.ContextPaths) > 0 {
499					newAgent.ContextPaths = append(base.Options.ContextPaths, newAgent.ContextPaths...)
500				} else {
501					newAgent.ContextPaths = base.Options.ContextPaths // Use global context paths only
502				}
503				base.Agents[agentID] = newAgent
504			} else {
505				baseAgent := base.Agents[agentID]
506				
507				// Special handling for known agents - only allow model changes
508				if agentID == AgentCoder || agentID == AgentTask {
509					if newAgent.Model != "" {
510						baseAgent.Model = newAgent.Model
511					}
512					// For known agents, only allow MCP and LSP configuration
513					if newAgent.AllowedMCP != nil {
514						baseAgent.AllowedMCP = newAgent.AllowedMCP
515					}
516					if newAgent.AllowedLSP != nil {
517						baseAgent.AllowedLSP = newAgent.AllowedLSP
518					}
519					// Context paths are additive for known agents too
520					if len(newAgent.ContextPaths) > 0 {
521						baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...)
522					}
523				} else {
524					// Custom agents - allow full merging
525					if newAgent.Name != "" {
526						baseAgent.Name = newAgent.Name
527					}
528					if newAgent.Description != "" {
529						baseAgent.Description = newAgent.Description
530					}
531					if newAgent.Model != "" {
532						baseAgent.Model = newAgent.Model
533					} else if baseAgent.Model == "" {
534						baseAgent.Model = LargeModel // Default fallback
535					}
536					
537					// Boolean fields - always update (including false values)
538					baseAgent.Disabled = newAgent.Disabled
539					
540					// Slice/Map fields - update if provided (including empty slices/maps)
541					if newAgent.AllowedTools != nil {
542						baseAgent.AllowedTools = newAgent.AllowedTools
543					}
544					if newAgent.AllowedMCP != nil {
545						baseAgent.AllowedMCP = newAgent.AllowedMCP
546					}
547					if newAgent.AllowedLSP != nil {
548						baseAgent.AllowedLSP = newAgent.AllowedLSP
549					}
550					// Context paths are additive for custom agents too
551					if len(newAgent.ContextPaths) > 0 {
552						baseAgent.ContextPaths = append(baseAgent.ContextPaths, newAgent.ContextPaths...)
553					}
554				}
555				
556				base.Agents[agentID] = baseAgent
557			}
558		}
559	}
560}
561
562func mergeMCPs(base, global, local *Config) {
563	for _, cfg := range []*Config{global, local} {
564		if cfg == nil {
565			continue
566		}
567		maps.Copy(base.MCP, cfg.MCP)
568	}
569}
570
571func mergeLSPs(base, global, local *Config) {
572	for _, cfg := range []*Config{global, local} {
573		if cfg == nil {
574			continue
575		}
576		maps.Copy(base.LSP, cfg.LSP)
577	}
578}
579
580func mergeProviderConfigs(base, global, local *Config) {
581	for _, cfg := range []*Config{global, local} {
582		if cfg == nil {
583			continue
584		}
585		for providerName, globalProvider := range cfg.Providers {
586			if _, ok := base.Providers[providerName]; !ok {
587				base.Providers[providerName] = globalProvider
588			} else {
589				base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider)
590			}
591		}
592	}
593
594	finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
595	for providerName, providerConfig := range base.Providers {
596		err := validateProvider(providerName, providerConfig)
597		if err != nil {
598			logging.Warn("Skipping provider", "name", providerName, "error", err)
599			continue // Skip invalid providers
600		}
601		finalProviders[providerName] = providerConfig
602	}
603	base.Providers = finalProviders
604}
605
606func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig {
607	switch providerId {
608	case provider.InferenceProviderAnthropic:
609		return ProviderConfig{
610			ID:           providerId,
611			ProviderType: provider.TypeAnthropic,
612		}
613	case provider.InferenceProviderOpenAI:
614		return ProviderConfig{
615			ID:           providerId,
616			ProviderType: provider.TypeOpenAI,
617		}
618	case provider.InferenceProviderGemini:
619		return ProviderConfig{
620			ID:           providerId,
621			ProviderType: provider.TypeGemini,
622		}
623	case provider.InferenceProviderBedrock:
624		return ProviderConfig{
625			ID:           providerId,
626			ProviderType: provider.TypeBedrock,
627		}
628	case provider.InferenceProviderAzure:
629		return ProviderConfig{
630			ID:           providerId,
631			ProviderType: provider.TypeAzure,
632		}
633	case provider.InferenceProviderOpenRouter:
634		return ProviderConfig{
635			ID:           providerId,
636			ProviderType: provider.TypeOpenAI,
637			BaseURL:      "https://openrouter.ai/api/v1",
638			ExtraHeaders: map[string]string{
639				"HTTP-Referer": "crush.charm.land",
640				"X-Title":      "Crush",
641			},
642		}
643	case provider.InferenceProviderXAI:
644		return ProviderConfig{
645			ID:           providerId,
646			ProviderType: provider.TypeXAI,
647			BaseURL:      "https://api.x.ai/v1",
648		}
649	case provider.InferenceProviderVertexAI:
650		return ProviderConfig{
651			ID:           providerId,
652			ProviderType: provider.TypeVertexAI,
653		}
654	default:
655		return ProviderConfig{
656			ID:           providerId,
657			ProviderType: provider.TypeOpenAI,
658		}
659	}
660}
661
662func defaultConfigBasedOnEnv() *Config {
663	cfg := &Config{
664		Options: Options{
665			DataDirectory: defaultDataDirectory,
666			ContextPaths:  defaultContextPaths,
667		},
668		Providers: make(map[provider.InferenceProvider]ProviderConfig),
669	}
670
671	providers := Providers()
672
673	for _, p := range providers {
674		if strings.HasPrefix(p.APIKey, "$") {
675			envVar := strings.TrimPrefix(p.APIKey, "$")
676			if apiKey := os.Getenv(envVar); apiKey != "" {
677				providerConfig := providerDefaultConfig(p.ID)
678				providerConfig.APIKey = apiKey
679				providerConfig.DefaultLargeModel = p.DefaultLargeModelID
680				providerConfig.DefaultSmallModel = p.DefaultSmallModelID
681				baseURL := p.APIEndpoint
682				if strings.HasPrefix(baseURL, "$") {
683					envVar := strings.TrimPrefix(baseURL, "$")
684					if url := os.Getenv(envVar); url != "" {
685						baseURL = url
686					}
687				}
688				providerConfig.BaseURL = baseURL
689				for _, model := range p.Models {
690					providerConfig.Models = append(providerConfig.Models, Model{
691						ID:                 model.ID,
692						Name:               model.Name,
693						CostPer1MIn:        model.CostPer1MIn,
694						CostPer1MOut:       model.CostPer1MOut,
695						CostPer1MInCached:  model.CostPer1MInCached,
696						CostPer1MOutCached: model.CostPer1MOutCached,
697						ContextWindow:      model.ContextWindow,
698						DefaultMaxTokens:   model.DefaultMaxTokens,
699						CanReason:          model.CanReason,
700						SupportsImages:     model.SupportsImages,
701					})
702				}
703				cfg.Providers[p.ID] = providerConfig
704			}
705		}
706	}
707	// TODO: support local models
708
709	if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
710		providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
711		providerConfig.ExtraParams = map[string]string{
712			"project":  os.Getenv("GOOGLE_CLOUD_PROJECT"),
713			"location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
714		}
715		cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
716	}
717
718	if hasAWSCredentials() {
719		providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
720		providerConfig.ExtraParams = map[string]string{
721			"region": os.Getenv("AWS_DEFAULT_REGION"),
722		}
723		if providerConfig.ExtraParams["region"] == "" {
724			providerConfig.ExtraParams["region"] = os.Getenv("AWS_REGION")
725		}
726		cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
727	}
728	return cfg
729}
730
// hasAWSCredentials reports whether the environment carries any AWS
// configuration that should enable Bedrock. That is any of:
//   - both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY;
//   - a profile (AWS_PROFILE or AWS_DEFAULT_PROFILE);
//   - a region (AWS_REGION or AWS_DEFAULT_REGION);
//   - container credential URIs.
//
// Only non-emptiness is checked; the values are not validated.
func hasAWSCredentials() bool {
	set := func(name string) bool { return os.Getenv(name) != "" }

	switch {
	case set("AWS_ACCESS_KEY_ID") && set("AWS_SECRET_ACCESS_KEY"):
		return true
	case set("AWS_PROFILE"), set("AWS_DEFAULT_PROFILE"):
		return true
	case set("AWS_REGION"), set("AWS_DEFAULT_REGION"):
		return true
	case set("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"), set("AWS_CONTAINER_CREDENTIALS_FULL_URI"):
		return true
	default:
		return false
	}
}
751
752func WorkingDirectory() string {
753	return cwd
754}
755
756// TODO: Handle error state
757
758func GetAgentModel(agentID AgentID) Model {
759	cfg := Get()
760	agent, ok := cfg.Agents[agentID]
761	if !ok {
762		logging.Error("Agent not found", "agent_id", agentID)
763		return Model{}
764	}
765
766	var model PreferredModel
767	switch agent.Model {
768	case LargeModel:
769		model = cfg.Models.Large
770	case SmallModel:
771		model = cfg.Models.Small
772	default:
773		logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
774		model = cfg.Models.Large // Fallback to large model
775	}
776	providerConfig, ok := cfg.Providers[model.Provider]
777	if !ok {
778		logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
779		return Model{}
780	}
781
782	for _, m := range providerConfig.Models {
783		if m.ID == model.ModelID {
784			return m
785		}
786	}
787
788	logging.Error("Model not found for agent", "agent_id", agentID, "model", agent.Model)
789	return Model{}
790}
791
792func GetAgentProvider(agentID AgentID) ProviderConfig {
793	cfg := Get()
794	agent, ok := cfg.Agents[agentID]
795	if !ok {
796		logging.Error("Agent not found", "agent_id", agentID)
797		return ProviderConfig{}
798	}
799
800	var model PreferredModel
801	switch agent.Model {
802	case LargeModel:
803		model = cfg.Models.Large
804	case SmallModel:
805		model = cfg.Models.Small
806	default:
807		logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
808		model = cfg.Models.Large // Fallback to large model
809	}
810
811	providerConfig, ok := cfg.Providers[model.Provider]
812	if !ok {
813		logging.Error("Provider not found for agent", "agent_id", agentID, "provider", model.Provider)
814		return ProviderConfig{}
815	}
816
817	return providerConfig
818}
819
820func GetProviderModel(provider provider.InferenceProvider, modelID string) Model {
821	cfg := Get()
822	providerConfig, ok := cfg.Providers[provider]
823	if !ok {
824		logging.Error("Provider not found", "provider", provider)
825		return Model{}
826	}
827
828	for _, model := range providerConfig.Models {
829		if model.ID == modelID {
830			return model
831		}
832	}
833
834	logging.Error("Model not found for provider", "provider", provider, "model_id", modelID)
835	return Model{}
836}
837
838func GetModel(modelType ModelType) Model {
839	cfg := Get()
840	var model PreferredModel
841	switch modelType {
842	case LargeModel:
843		model = cfg.Models.Large
844	case SmallModel:
845		model = cfg.Models.Small
846	default:
847		model = cfg.Models.Large // Fallback to large model
848	}
849	providerConfig, ok := cfg.Providers[model.Provider]
850	if !ok {
851		return Model{}
852	}
853
854	for _, m := range providerConfig.Models {
855		if m.ID == model.ModelID {
856			return m
857		}
858	}
859	return Model{}
860}
861
862func UpdatePreferredModel(modelType ModelType, model PreferredModel) error {
863	cfg := Get()
864	switch modelType {
865	case LargeModel:
866		cfg.Models.Large = model
867	case SmallModel:
868		cfg.Models.Small = model
869	default:
870		return fmt.Errorf("unknown model type: %s", modelType)
871	}
872	return nil
873}