1package config
  2
import (
	"fmt"
	"os"
	"path/filepath"
	"slices"
	"strings"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/llm/agent"
	"github.com/charmbracelet/crush/internal/llm/provider"
	"github.com/charmbracelet/crush/internal/resolver"
	"github.com/tidwall/sjson"
)
 16
// Package-level defaults: the application name, the default name of the data
// directory (presumably the fallback for Options.DataDirectory — confirm with
// the loader), and the default log level.
const (
	appName              = "crush"
	defaultDataDirectory = ".crush"
	defaultLogLevel      = "info"
)
 22
// defaultContextPaths lists instruction/context files and directories (entries
// ending in "/") that are presumably used when Options.ContextPaths is not set
// — NOTE(review): confirm against the loader.
//
// Several entries differ only in letter case (e.g. "crush.md", "Crush.md",
// "CRUSH.md"); on case-insensitive filesystems they may name the same file.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
 38
// SelectedModelType names a model slot. It is the key type of Config.Models.
type SelectedModelType string

// The model slots currently supported by Config.Models: a "large" model
// (see LargeModel) and a "small" model (see SmallModel).
const (
	SelectedModelTypeLarge SelectedModelType = "large"
	SelectedModelTypeSmall SelectedModelType = "small"
)
 45
 46type LSPConfig struct {
 47	Disabled bool     `json:"enabled,omitempty"`
 48	Command  string   `json:"command"`
 49	Args     []string `json:"args,omitempty"`
 50	Options  any      `json:"options,omitempty"`
 51}
 52
// TUIOptions holds terminal-UI settings. It is stored under the "tui" key of
// Options, so CompactMode lives at "options.tui.compact_mode" in the config
// file (the key written by SetCompactMode).
type TUIOptions struct {
	CompactMode bool `json:"compact_mode,omitempty"`
	// Further TUI settings (for example themes) can be added here.
}
 57
// Permissions controls which tool invocations require a permission prompt.
type Permissions struct {
	AllowedTools []string `json:"allowed_tools,omitempty"` // Tools that don't require permission prompts
	SkipRequests bool     `json:"-"`                       // Automatically accept all permissions (YOLO mode); never read from or written to JSON
}
 62
// Options holds general settings, stored under the "options" key of Config.
// It contains:
//   - context file paths (presumably overriding defaultContextPaths; confirm)
//   - TUI settings
//   - debug switches for the app and for LSP handling
//   - whether automatic summarization is disabled
//   - the data directory
type Options struct {
	ContextPaths         []string    `json:"context_paths,omitempty"`
	TUI                  *TUIOptions `json:"tui,omitempty"`
	Debug                bool        `json:"debug,omitempty"`
	DebugLSP             bool        `json:"debug_lsp,omitempty"`
	DisableAutoSummarize bool        `json:"disable_auto_summarize,omitempty"`
	DataDirectory        string      `json:"data_directory,omitempty"` // Relative to the cwd
}
 71
// MCPs maps a user-chosen name (the keys under "mcp" in Config) to its MCP
// configuration.
type MCPs map[string]agent.MCPConfig

// MCP pairs an MCP name with its configuration. It is the element type
// returned by MCPs.Sorted.
type MCP struct {
	Name string          `json:"name"`
	MCP  agent.MCPConfig `json:"mcp"`
}
 78
 79func (m MCPs) Sorted() []MCP {
 80	sorted := make([]MCP, 0, len(m))
 81	for k, v := range m {
 82		sorted = append(sorted, MCP{
 83			Name: k,
 84			MCP:  v,
 85		})
 86	}
 87	slices.SortFunc(sorted, func(a, b MCP) int {
 88		return strings.Compare(a.Name, b.Name)
 89	})
 90	return sorted
 91}
 92
// LSPs maps a user-chosen name (the keys under "lsp" in Config) to the
// configuration of that language server.
type LSPs map[string]LSPConfig

// LSP pairs a language-server name with its configuration. It is the element
// type returned by LSPs.Sorted.
type LSP struct {
	Name string    `json:"name"`
	LSP  LSPConfig `json:"lsp"`
}
 99
100func (l LSPs) Sorted() []LSP {
101	sorted := make([]LSP, 0, len(l))
102	for k, v := range l {
103		sorted = append(sorted, LSP{
104			Name: k,
105			LSP:  v,
106		})
107	}
108	slices.SortFunc(sorted, func(a, b LSP) int {
109		return strings.Compare(a.Name, b.Name)
110	})
111	return sorted
112}
113
// Config holds the configuration for crush.
//
// The exported fields mirror the JSON configuration. The unexported fields are
// runtime state. They are unexported, so encoding/json never reads or writes
// them.
type Config struct {
	// Models maps a model slot to the selected model. Only the "large" and
	// "small" slots (SelectedModelTypeLarge / SelectedModelTypeSmall) are
	// currently supported. The map may be nil when the config has no "models".
	Models map[SelectedModelType]agent.Model `json:"models,omitempty"`

	// Providers holds the configured providers, keyed by provider ID.
	Providers *csync.Map[string, provider.Config] `json:"providers,omitempty"`

	// MCP holds the named MCP configurations.
	MCP MCPs `json:"mcp,omitempty"`

	// LSP holds the named language-server configurations.
	LSP LSPs `json:"lsp,omitempty"`

	// Options holds general settings; may be nil.
	Options *Options `json:"options,omitempty"`

	// Permissions holds tool-permission settings; may be nil.
	Permissions *Permissions `json:"permissions,omitempty"`

	// Internal runtime state:
	//   - workingDir: returned by WorkingDir (presumably the directory crush
	//     was started in; confirm with the loader).
	//   - resolver: used by Resolve to expand variables; may be nil.
	//   - dataConfigDir: despite its name, used as the path of the JSON file
	//     that SetConfigField reads and rewrites.
	//   - knownProviders: catalog that SetProviderAPIKey uses to build a
	//     provider config when the provider is not yet in Providers.
	workingDir     string `json:"-"`
	resolver       resolver.Resolver
	dataConfigDir  string             `json:"-"`
	knownProviders []catwalk.Provider `json:"-"`
}
136
// WorkingDir returns the working directory stored in this Config (presumably
// the directory crush was started in; confirm with the loader).
func (c *Config) WorkingDir() string {
	return c.workingDir
}
140
141func (c *Config) EnabledProviders() []provider.Config {
142	var enabled []provider.Config
143	for p := range c.Providers.Seq() {
144		if !p.Disable {
145			enabled = append(enabled, p)
146		}
147	}
148	return enabled
149}
150
151// IsConfigured  return true if at least one provider is configured
152func (c *Config) IsConfigured() bool {
153	return len(c.EnabledProviders()) > 0
154}
155
156func (c *Config) GetModel(provider, model string) *catwalk.Model {
157	if providerConfig, ok := c.Providers.Get(provider); ok {
158		for _, m := range providerConfig.Models {
159			if m.ID == model {
160				return &m
161			}
162		}
163	}
164	return nil
165}
166
167func (c *Config) GetProviderForModel(modelType SelectedModelType) *provider.Config {
168	model, ok := c.Models[modelType]
169	if !ok {
170		return nil
171	}
172	if providerConfig, ok := c.Providers.Get(model.Provider); ok {
173		return &providerConfig
174	}
175	return nil
176}
177
178func (c *Config) GetModelByType(modelType SelectedModelType) *catwalk.Model {
179	model, ok := c.Models[modelType]
180	if !ok {
181		return nil
182	}
183	return c.GetModel(model.Provider, model.Model)
184}
185
186func (c *Config) LargeModel() *catwalk.Model {
187	model, ok := c.Models[SelectedModelTypeLarge]
188	if !ok {
189		return nil
190	}
191	return c.GetModel(model.Provider, model.Model)
192}
193
194func (c *Config) SmallModel() *catwalk.Model {
195	model, ok := c.Models[SelectedModelTypeSmall]
196	if !ok {
197		return nil
198	}
199	return c.GetModel(model.Provider, model.Model)
200}
201
202func (c *Config) SetCompactMode(enabled bool) error {
203	if c.Options == nil {
204		c.Options = &Options{}
205	}
206	c.Options.TUI.CompactMode = enabled
207	return c.SetConfigField("options.tui.compact_mode", enabled)
208}
209
210func (c *Config) Resolve(key string) (string, error) {
211	if c.resolver == nil {
212		return "", fmt.Errorf("no variable resolver configured")
213	}
214	return c.resolver.ResolveValue(key)
215}
216
217func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model agent.Model) error {
218	c.Models[modelType] = model
219	if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
220		return fmt.Errorf("failed to update preferred model: %w", err)
221	}
222	return nil
223}
224
225func (c *Config) SetConfigField(key string, value any) error {
226	// read the data
227	data, err := os.ReadFile(c.dataConfigDir)
228	if err != nil {
229		if os.IsNotExist(err) {
230			data = []byte("{}")
231		} else {
232			return fmt.Errorf("failed to read config file: %w", err)
233		}
234	}
235
236	newValue, err := sjson.Set(string(data), key, value)
237	if err != nil {
238		return fmt.Errorf("failed to set config field %s: %w", key, err)
239	}
240	if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o644); err != nil {
241		return fmt.Errorf("failed to write config file: %w", err)
242	}
243	return nil
244}
245
246func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
247	// First save to the config file
248	err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
249	if err != nil {
250		return fmt.Errorf("failed to save API key to config file: %w", err)
251	}
252
253	providerConfig, exists := c.Providers.Get(providerID)
254	if exists {
255		providerConfig.APIKey = apiKey
256		c.Providers.Set(providerID, providerConfig)
257		return nil
258	}
259
260	var foundProvider *catwalk.Provider
261	for _, p := range c.knownProviders {
262		if string(p.ID) == providerID {
263			foundProvider = &p
264			break
265		}
266	}
267
268	if foundProvider != nil {
269		// Create new provider config based on known provider
270		providerConfig = provider.Config{
271			ID:           providerID,
272			Name:         foundProvider.Name,
273			BaseURL:      foundProvider.APIEndpoint,
274			Type:         foundProvider.Type,
275			APIKey:       apiKey,
276			Disable:      false,
277			ExtraHeaders: make(map[string]string),
278			ExtraParams:  make(map[string]string),
279			Models:       foundProvider.Models,
280		}
281	} else {
282		return fmt.Errorf("provider with ID %s not found in known providers", providerID)
283	}
284	// Store the updated provider config
285	c.Providers.Set(providerID, providerConfig)
286	return nil
287}
288
// Resolver returns the variable resolver used by Resolve. It may be nil when
// none has been configured.
func (c *Config) Resolver() resolver.Resolver {
	return c.resolver
}