config.go

  1package config
  2
import (
	"fmt"
	"os"
	"path/filepath"
	"slices"
	"strings"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/llm/agent"
	"github.com/charmbracelet/crush/internal/llm/provider"
	"github.com/charmbracelet/crush/internal/resolver"
	"github.com/tidwall/sjson"
)
 16
// Package-level defaults. The code that consumes these values is not shown
// here.
const (
	// appName is the application name.
	appName = "crush"
	// defaultDataDirectory is presumably the fallback for
	// Options.DataDirectory (which is relative to the cwd) — TODO confirm
	// against the config loader.
	defaultDataDirectory = ".crush"
	// defaultLogLevel is presumably the log level used when none is
	// configured — verify against the logging setup.
	defaultLogLevel = "info"
)
 22
// defaultContextPaths lists files and directories (the entry with a trailing
// "/" is presumably a directory) that are likely read as project context when
// Options.ContextPaths is not set — TODO confirm against the loader.
// Several spellings of the same name (crush.md, Crush.md, CRUSH.md, ...) are
// listed explicitly, presumably so lookups work on case-sensitive filesystems.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
 38
// SelectedModelType identifies a model slot; it is the key type of
// Config.Models.
type SelectedModelType string

// The supported model slots. UpdatePreferredModel also uses the string value
// as the path segment when persisting a selection ("models.large",
// "models.small").
const (
	SelectedModelTypeLarge SelectedModelType = "large"
	SelectedModelTypeSmall SelectedModelType = "small"
)
 45
 46type LSPConfig struct {
 47	Disabled bool     `json:"enabled,omitempty"`
 48	Command  string   `json:"command"`
 49	Args     []string `json:"args,omitempty"`
 50	Options  any      `json:"options,omitempty"`
 51}
 52
// TUIOptions holds terminal-UI settings (the "tui" key under "options").
type TUIOptions struct {
	// CompactMode is toggled by Config.SetCompactMode and persisted under
	// "options.tui.compact_mode".
	CompactMode bool `json:"compact_mode,omitempty"`
	// Further TUI-related settings (e.g. themes) can be added here later.
}
 57
// Permissions configures permission prompts for tools.
type Permissions struct {
	// AllowedTools lists tools that don't require permission prompts.
	AllowedTools []string `json:"allowed_tools,omitempty"`
	// SkipRequests automatically accepts all permissions (YOLO mode). It is
	// tagged json:"-", so it is never read from or written to the config
	// file and can only be set in code.
	SkipRequests bool `json:"-"`
}
 62
// Options holds general settings from the "options" section of the config.
type Options struct {
	// ContextPaths lists context files/directories; presumably it overrides
	// defaultContextPaths when set — TODO confirm against the loader.
	ContextPaths []string `json:"context_paths,omitempty"`
	// TUI holds terminal-UI settings. It is a pointer and may be nil.
	TUI *TUIOptions `json:"tui,omitempty"`
	// Debug enables debug mode; its exact effect is defined elsewhere.
	Debug bool `json:"debug,omitempty"`
	// DebugLSP presumably enables debug output for LSP handling — verify.
	DebugLSP bool `json:"debug_lsp,omitempty"`
	// DisableAutoSummarize presumably turns off automatic summarization —
	// verify against the agent code.
	DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty"`
	// DataDirectory is the data directory, relative to the cwd.
	DataDirectory string `json:"data_directory,omitempty"`
}
 71
// MCPs maps an MCP entry name to its configuration.
type MCPs map[string]agent.MCPConfig

// MCP pairs an MCP entry name with its configuration; MCPs.Sorted returns a
// slice of these ordered by name.
type MCP struct {
	Name string          `json:"name"`
	MCP  agent.MCPConfig `json:"mcp"`
}
 78
 79func (m MCPs) Sorted() []MCP {
 80	sorted := make([]MCP, 0, len(m))
 81	for k, v := range m {
 82		sorted = append(sorted, MCP{
 83			Name: k,
 84			MCP:  v,
 85		})
 86	}
 87	slices.SortFunc(sorted, func(a, b MCP) int {
 88		return strings.Compare(a.Name, b.Name)
 89	})
 90	return sorted
 91}
 92
// LSPs maps an LSP entry name to its configuration.
type LSPs map[string]LSPConfig

// LSP pairs an LSP entry name with its configuration; LSPs.Sorted returns a
// slice of these ordered by name.
type LSP struct {
	Name string    `json:"name"`
	LSP  LSPConfig `json:"lsp"`
}
 99
100func (l LSPs) Sorted() []LSP {
101	sorted := make([]LSP, 0, len(l))
102	for k, v := range l {
103		sorted = append(sorted, LSP{
104			Name: k,
105			LSP:  v,
106		})
107	}
108	slices.SortFunc(sorted, func(a, b LSP) int {
109		return strings.Compare(a.Name, b.Name)
110	})
111	return sorted
112}
113
// Config holds the configuration for crush.
type Config struct {
	// Models maps each model slot to its selected model. We currently only
	// support large/small (SelectedModelTypeLarge / SelectedModelTypeSmall)
	// as keys here. Like any map field it is nil on a zero Config.
	Models map[SelectedModelType]agent.Model `json:"models,omitempty"`

	// Providers holds the configured providers, keyed by provider ID.
	Providers *csync.Map[string, provider.Config] `json:"providers,omitempty"`

	// MCP holds MCP configurations keyed by name.
	MCP MCPs `json:"mcp,omitempty"`

	// LSP holds LSP configurations keyed by name.
	LSP LSPs `json:"lsp,omitempty"`

	// Options holds general settings; it is a pointer and may be nil.
	Options *Options `json:"options,omitempty"`

	// Permissions holds tool-permission settings; it is a pointer and may be
	// nil.
	Permissions *Permissions `json:"permissions,omitempty"`

	// Internal state. These fields are unexported, so encoding/json ignores
	// them regardless of their tags.

	// workingDir is exposed through WorkingDir.
	workingDir string `json:"-"`
	// resolver expands values in Resolve; Resolve returns an error when it is
	// nil.
	// TODO: find a better way to do this; this should probably not be part of
	// the config.
	resolver resolver.Resolver
	// dataConfigDir is, despite its name, used as a file path: SetConfigField
	// reads and rewrites the JSON config file at this location.
	dataConfigDir string `json:"-"`
	// knownProviders is searched by SetProviderAPIKey to build a config for a
	// provider ID that is not yet present in Providers.
	knownProviders []catwalk.Provider `json:"-"`
}
137
138func (c *Config) WorkingDir() string {
139	return c.workingDir
140}
141
142func (c *Config) EnabledProviders() []provider.Config {
143	var enabled []provider.Config
144	for p := range c.Providers.Seq() {
145		if !p.Disable {
146			enabled = append(enabled, p)
147		}
148	}
149	return enabled
150}
151
152// IsConfigured  return true if at least one provider is configured
153func (c *Config) IsConfigured() bool {
154	return len(c.EnabledProviders()) > 0
155}
156
157func (c *Config) GetModel(provider, model string) *catwalk.Model {
158	if providerConfig, ok := c.Providers.Get(provider); ok {
159		for _, m := range providerConfig.Models {
160			if m.ID == model {
161				return &m
162			}
163		}
164	}
165	return nil
166}
167
168func (c *Config) GetProviderForModel(modelType SelectedModelType) *provider.Config {
169	model, ok := c.Models[modelType]
170	if !ok {
171		return nil
172	}
173	if providerConfig, ok := c.Providers.Get(model.Provider); ok {
174		return &providerConfig
175	}
176	return nil
177}
178
179func (c *Config) GetModelByType(modelType SelectedModelType) *catwalk.Model {
180	model, ok := c.Models[modelType]
181	if !ok {
182		return nil
183	}
184	return c.GetModel(model.Provider, model.Model)
185}
186
187func (c *Config) LargeModel() *catwalk.Model {
188	model, ok := c.Models[SelectedModelTypeLarge]
189	if !ok {
190		return nil
191	}
192	return c.GetModel(model.Provider, model.Model)
193}
194
195func (c *Config) SmallModel() *catwalk.Model {
196	model, ok := c.Models[SelectedModelTypeSmall]
197	if !ok {
198		return nil
199	}
200	return c.GetModel(model.Provider, model.Model)
201}
202
203func (c *Config) SetCompactMode(enabled bool) error {
204	if c.Options == nil {
205		c.Options = &Options{}
206	}
207	c.Options.TUI.CompactMode = enabled
208	return c.SetConfigField("options.tui.compact_mode", enabled)
209}
210
211func (c *Config) Resolve(key string) (string, error) {
212	if c.resolver == nil {
213		return "", fmt.Errorf("no variable resolver configured")
214	}
215	return c.resolver.ResolveValue(key)
216}
217
218func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model agent.Model) error {
219	c.Models[modelType] = model
220	if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
221		return fmt.Errorf("failed to update preferred model: %w", err)
222	}
223	return nil
224}
225
226func (c *Config) SetConfigField(key string, value any) error {
227	// read the data
228	data, err := os.ReadFile(c.dataConfigDir)
229	if err != nil {
230		if os.IsNotExist(err) {
231			data = []byte("{}")
232		} else {
233			return fmt.Errorf("failed to read config file: %w", err)
234		}
235	}
236
237	newValue, err := sjson.Set(string(data), key, value)
238	if err != nil {
239		return fmt.Errorf("failed to set config field %s: %w", key, err)
240	}
241	if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o644); err != nil {
242		return fmt.Errorf("failed to write config file: %w", err)
243	}
244	return nil
245}
246
247func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
248	// First save to the config file
249	err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
250	if err != nil {
251		return fmt.Errorf("failed to save API key to config file: %w", err)
252	}
253
254	providerConfig, exists := c.Providers.Get(providerID)
255	if exists {
256		providerConfig.APIKey = apiKey
257		c.Providers.Set(providerID, providerConfig)
258		return nil
259	}
260
261	var foundProvider *catwalk.Provider
262	for _, p := range c.knownProviders {
263		if string(p.ID) == providerID {
264			foundProvider = &p
265			break
266		}
267	}
268
269	if foundProvider != nil {
270		// Create new provider config based on known provider
271		providerConfig = provider.Config{
272			ID:           providerID,
273			Name:         foundProvider.Name,
274			BaseURL:      foundProvider.APIEndpoint,
275			Type:         foundProvider.Type,
276			APIKey:       apiKey,
277			Disable:      false,
278			ExtraHeaders: make(map[string]string),
279			ExtraParams:  make(map[string]string),
280			Models:       foundProvider.Models,
281		}
282	} else {
283		return fmt.Errorf("provider with ID %s not found in known providers", providerID)
284	}
285	// Store the updated provider config
286	c.Providers.Set(providerID, providerConfig)
287	return nil
288}
289
290func (c *Config) Resolver() resolver.Resolver {
291	return c.resolver
292}