config.go

  1package config
  2
  3import (
  4	"fmt"
  5	"os"
  6	"slices"
  7	"strings"
  8
  9	"github.com/charmbracelet/crush/internal/env"
 10	"github.com/charmbracelet/crush/internal/fur/provider"
 11	"github.com/tidwall/sjson"
 12	"golang.org/x/exp/slog"
 13)
 14
// Package-level defaults:
//   - appName is the application's name.
//   - defaultDataDirectory is the default name of the data directory.
//   - defaultLogLevel is the default log level.
//
// NOTE(review): none of these is referenced in the visible part of the file;
// confirm how each one is consumed before changing a value.
const (
	appName              = "crush"
	defaultDataDirectory = ".crush"
	defaultLogLevel      = "info"
)
 20
// defaultContextPaths lists instruction/context locations of several coding
// assistants (Copilot, Cursor, Claude, Gemini) followed by the case variants
// of crush.md and crush.local.md. An entry ending in "/" names a directory.
//
// NOTE(review): presumably the fallback for Options.ContextPaths; the code
// that applies it is not visible here — confirm.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
 36
// SelectedModelType identifies one of the model slots a configuration can
// fill. It is the key type of Config.Models.
type SelectedModelType string

// The model slots currently defined.
const (
	SelectedModelTypeLarge SelectedModelType = "large"
	SelectedModelTypeSmall SelectedModelType = "small"
)
 43
// SelectedModel records which provider/model pair fills a model slot, plus
// optional per-selection overrides. Optional fields are omitted from JSON when
// they hold their zero value.
type SelectedModel struct {
	// The model id as used by the provider API.
	// Required.
	Model string `json:"model"`
	// The model provider, same as the key/id used in the providers config.
	// Required.
	Provider string `json:"provider"`

	// Only used by models that use the openai provider and need this set.
	ReasoningEffort string `json:"reasoning_effort,omitempty"`

	// Overrides the default model configuration.
	MaxTokens int64 `json:"max_tokens,omitempty"`

	// Used by anthropic models that can reason to indicate if the model should think.
	Think bool `json:"think,omitempty"`
}
 61
// ProviderConfig describes a single model provider: how to reach it, how to
// authenticate, and which models it offers. Instances are stored in
// Config.Providers keyed by provider id.
type ProviderConfig struct {
	// The provider's id.
	ID string `json:"id,omitempty"`
	// The provider's name, used for display purposes.
	Name string `json:"name,omitempty"`
	// The provider's API endpoint.
	BaseURL string `json:"base_url,omitempty"`
	// The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
	Type provider.Type `json:"type,omitempty"`
	// The provider's API key.
	APIKey string `json:"api_key,omitempty"`
	// Marks the provider as disabled.
	Disable bool `json:"disable,omitempty"`

	// Extra headers to send with each request to the provider.
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
	// Extra body fields.
	// NOTE(review): presumably merged into the request body sent to the
	// provider; the consumer is not visible here — confirm.
	ExtraBody map[string]any `json:"extra_body,omitempty"`

	// Used to pass extra parameters to the provider.
	// Never read from or written to JSON (tag "-").
	ExtraParams map[string]string `json:"-"`

	// The provider models
	Models []provider.Model `json:"models,omitempty"`
}
 87
// MCPType is the value of the "type" field of an MCP entry.
// NOTE(review): the values look like transport names (standard I/O,
// server-sent events, HTTP); confirm against the code that starts MCP clients.
type MCPType string

// The MCP types defined by this package.
const (
	MCPStdio MCPType = "stdio"
	MCPSse   MCPType = "sse"
	MCPHttp  MCPType = "http"
)
 95
// MCPConfig is the configuration of one MCP entry.
//
// Type is always serialized; every other field is omitted from JSON when it
// holds its zero value. Env values are passed through a shell variable
// resolver by ResolvedEnv, and Headers values by ResolvedHeaders.
//
// NOTE(review): presumably Command/Args/Env apply to the "stdio" type and
// URL/Headers to "sse"/"http"; nothing in this file enforces that — confirm.
type MCPConfig struct {
	Command  string            `json:"command,omitempty" `
	Env      map[string]string `json:"env,omitempty"`
	Args     []string          `json:"args,omitempty"`
	Type     MCPType           `json:"type"`
	URL      string            `json:"url,omitempty"`
	Disabled bool              `json:"disabled,omitempty"`

	// TODO: maybe make it possible to get the value from the env
	Headers map[string]string `json:"headers,omitempty"`
}
107
// LSPConfig is the configuration of one language server entry.
//
// Disabled is serialized under the "disabled" key, matching MCPConfig and
// Agent. It was previously tagged "enabled", which inverted the meaning of the
// key: writing "enabled": true switched the server off.
type LSPConfig struct {
	// Disabled turns this language server off when true.
	Disabled bool `json:"disabled,omitempty"`
	// Command is the executable to run; always serialized.
	Command string `json:"command"`
	// Args are the arguments passed to Command.
	Args []string `json:"args,omitempty"`
	// Options holds free-form, server-specific settings.
	Options any `json:"options,omitempty"`
}
114
// TUIOptions groups the settings of the terminal user interface. It is
// referenced from Options.TUI and serialized under the "tui" key.
type TUIOptions struct {
	// CompactMode is the value written by Config.SetCompactMode.
	CompactMode bool `json:"compact_mode,omitempty"`
	// Here we can add themes later or any TUI related options
}
119
// Options holds the general, non-provider settings of a Config. Every field
// except SkipPermissionsRequests is serialized, and omitted when it holds its
// zero value; SkipPermissionsRequests is never read from or written to JSON.
// TUI is a pointer and may be nil.
type Options struct {
	ContextPaths            []string    `json:"context_paths,omitempty"`
	TUI                     *TUIOptions `json:"tui,omitempty"`
	Debug                   bool        `json:"debug,omitempty"`
	DebugLSP                bool        `json:"debug_lsp,omitempty"`
	DisableAutoSummarize    bool        `json:"disable_auto_summarize,omitempty"`
	DataDirectory           string      `json:"data_directory,omitempty"` // Relative to the cwd
	SkipPermissionsRequests bool        `json:"-"`                        // Automatically accept all permissions (YOLO mode)
}
129
// MCPs maps an MCP entry's name to its configuration.
type MCPs map[string]MCPConfig

// MCP pairs an MCP entry's name with its configuration. It is the element
// type returned by MCPs.Sorted.
type MCP struct {
	Name string    `json:"name"`
	MCP  MCPConfig `json:"mcp"`
}
136
137func (m MCPs) Sorted() []MCP {
138	sorted := make([]MCP, 0, len(m))
139	for k, v := range m {
140		sorted = append(sorted, MCP{
141			Name: k,
142			MCP:  v,
143		})
144	}
145	slices.SortFunc(sorted, func(a, b MCP) int {
146		return strings.Compare(a.Name, b.Name)
147	})
148	return sorted
149}
150
// LSPs maps a language server entry's name to its configuration.
type LSPs map[string]LSPConfig

// LSP pairs a language server entry's name with its configuration. It is the
// element type returned by LSPs.Sorted.
type LSP struct {
	Name string    `json:"name"`
	LSP  LSPConfig `json:"lsp"`
}
157
158func (l LSPs) Sorted() []LSP {
159	sorted := make([]LSP, 0, len(l))
160	for k, v := range l {
161		sorted = append(sorted, LSP{
162			Name: k,
163			LSP:  v,
164		})
165	}
166	slices.SortFunc(sorted, func(a, b LSP) int {
167		return strings.Compare(a.Name, b.Name)
168	})
169	return sorted
170}
171
172func (m MCPConfig) ResolvedEnv() []string {
173	resolver := NewShellVariableResolver(env.New())
174	for e, v := range m.Env {
175		var err error
176		m.Env[e], err = resolver.ResolveValue(v)
177		if err != nil {
178			slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
179			continue
180		}
181	}
182
183	env := make([]string, 0, len(m.Env))
184	for k, v := range m.Env {
185		env = append(env, fmt.Sprintf("%s=%s", k, v))
186	}
187	return env
188}
189
190func (m MCPConfig) ResolvedHeaders() map[string]string {
191	resolver := NewShellVariableResolver(env.New())
192	for e, v := range m.Headers {
193		var err error
194		m.Headers[e], err = resolver.ResolveValue(v)
195		if err != nil {
196			slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
197			continue
198		}
199	}
200	return m.Headers
201}
202
// Agent describes one agent: its identity, the model slot it uses, and which
// tools, MCPs and LSPs it may access. Agents are stored in Config.Agents and
// built by Config.SetupAgents.
type Agent struct {
	ID          string `json:"id,omitempty"`
	Name        string `json:"name,omitempty"`
	Description string `json:"description,omitempty"`
	// Marks the agent as disabled.
	// NOTE(review): the comment previously here described a system prompt id,
	// which no field of this struct carries.
	Disabled bool `json:"disabled,omitempty"`

	// The model slot (large/small) this agent uses.
	Model SelectedModelType `json:"model"`

	// The available tools for the agent
	//  if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools,omitempty"`

	// this tells us which MCPs are available for this agent
	//  if this is empty all mcps are available
	//  the string array is the list of tools from the AllowedMCP the agent has available
	//  if the string array is nil, all tools from the AllowedMCP are available
	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`

	// The list of LSPs that this agent can use
	//  if this is nil, all LSPs are available
	AllowedLSP []string `json:"allowed_lsp,omitempty"`

	// Overrides the context paths for this agent
	ContextPaths []string `json:"context_paths,omitempty"`
}
229
// Config holds the configuration for crush.
//
// The exported fields other than Agents are read from and written to JSON.
// Options is a pointer and may be nil. The unexported fields and Agents are
// runtime state that is never serialized.
type Config struct {
	// We currently only support large/small as values here.
	Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`

	// The providers that are configured
	Providers map[string]ProviderConfig `json:"providers,omitempty"`

	MCP MCPs `json:"mcp,omitempty"`

	LSP LSPs `json:"lsp,omitempty"`

	Options *Options `json:"options,omitempty"`

	// Internal
	workingDir string `json:"-"`
	// TODO: most likely remove this concept when I come back to it
	Agents map[string]Agent `json:"-"`
	// TODO: find a better way to do this this should probably not be part of the config
	// resolver backs Resolve; it may be nil.
	resolver VariableResolver
	// dataConfigDir is, despite its name, the path of the JSON file that
	// SetConfigField reads and rewrites.
	dataConfigDir string `json:"-"`
	// knownProviders is the catalogue SetProviderAPIKey searches when the
	// requested provider is not yet in Providers.
	knownProviders []provider.Provider `json:"-"`
}
253
// WorkingDir returns the working directory stored in the configuration.
func (c *Config) WorkingDir() string {
	return c.workingDir
}
257
258func (c *Config) EnabledProviders() []ProviderConfig {
259	enabled := make([]ProviderConfig, 0, len(c.Providers))
260	for _, p := range c.Providers {
261		if !p.Disable {
262			enabled = append(enabled, p)
263		}
264	}
265	return enabled
266}
267
268// IsConfigured  return true if at least one provider is configured
269func (c *Config) IsConfigured() bool {
270	return len(c.EnabledProviders()) > 0
271}
272
273func (c *Config) GetModel(provider, model string) *provider.Model {
274	if providerConfig, ok := c.Providers[provider]; ok {
275		for _, m := range providerConfig.Models {
276			if m.ID == model {
277				return &m
278			}
279		}
280	}
281	return nil
282}
283
284func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
285	model, ok := c.Models[modelType]
286	if !ok {
287		return nil
288	}
289	if providerConfig, ok := c.Providers[model.Provider]; ok {
290		return &providerConfig
291	}
292	return nil
293}
294
295func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model {
296	model, ok := c.Models[modelType]
297	if !ok {
298		return nil
299	}
300	return c.GetModel(model.Provider, model.Model)
301}
302
303func (c *Config) LargeModel() *provider.Model {
304	model, ok := c.Models[SelectedModelTypeLarge]
305	if !ok {
306		return nil
307	}
308	return c.GetModel(model.Provider, model.Model)
309}
310
311func (c *Config) SmallModel() *provider.Model {
312	model, ok := c.Models[SelectedModelTypeSmall]
313	if !ok {
314		return nil
315	}
316	return c.GetModel(model.Provider, model.Model)
317}
318
319func (c *Config) SetCompactMode(enabled bool) error {
320	if c.Options == nil {
321		c.Options = &Options{}
322	}
323	c.Options.TUI.CompactMode = enabled
324	return c.SetConfigField("options.tui.compact_mode", enabled)
325}
326
327func (c *Config) Resolve(key string) (string, error) {
328	if c.resolver == nil {
329		return "", fmt.Errorf("no variable resolver configured")
330	}
331	return c.resolver.ResolveValue(key)
332}
333
334func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
335	c.Models[modelType] = model
336	if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
337		return fmt.Errorf("failed to update preferred model: %w", err)
338	}
339	return nil
340}
341
342func (c *Config) SetConfigField(key string, value any) error {
343	// read the data
344	data, err := os.ReadFile(c.dataConfigDir)
345	if err != nil {
346		if os.IsNotExist(err) {
347			data = []byte("{}")
348		} else {
349			return fmt.Errorf("failed to read config file: %w", err)
350		}
351	}
352
353	newValue, err := sjson.Set(string(data), key, value)
354	if err != nil {
355		return fmt.Errorf("failed to set config field %s: %w", key, err)
356	}
357	if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o644); err != nil {
358		return fmt.Errorf("failed to write config file: %w", err)
359	}
360	return nil
361}
362
363func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
364	// First save to the config file
365	err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
366	if err != nil {
367		return fmt.Errorf("failed to save API key to config file: %w", err)
368	}
369
370	if c.Providers == nil {
371		c.Providers = make(map[string]ProviderConfig)
372	}
373
374	providerConfig, exists := c.Providers[providerID]
375	if exists {
376		providerConfig.APIKey = apiKey
377		c.Providers[providerID] = providerConfig
378		return nil
379	}
380
381	var foundProvider *provider.Provider
382	for _, p := range c.knownProviders {
383		if string(p.ID) == providerID {
384			foundProvider = &p
385			break
386		}
387	}
388
389	if foundProvider != nil {
390		// Create new provider config based on known provider
391		providerConfig = ProviderConfig{
392			ID:           providerID,
393			Name:         foundProvider.Name,
394			BaseURL:      foundProvider.APIEndpoint,
395			Type:         foundProvider.Type,
396			APIKey:       apiKey,
397			Disable:      false,
398			ExtraHeaders: make(map[string]string),
399			ExtraParams:  make(map[string]string),
400			Models:       foundProvider.Models,
401		}
402	} else {
403		return fmt.Errorf("provider with ID %s not found in known providers", providerID)
404	}
405	// Store the updated provider config
406	c.Providers[providerID] = providerConfig
407	return nil
408}
409
410func (c *Config) SetupAgents() {
411	agents := map[string]Agent{
412		"coder": {
413			ID:           "coder",
414			Name:         "Coder",
415			Description:  "An agent that helps with executing coding tasks.",
416			Model:        SelectedModelTypeLarge,
417			ContextPaths: c.Options.ContextPaths,
418			// All tools allowed
419		},
420		"task": {
421			ID:           "task",
422			Name:         "Task",
423			Description:  "An agent that helps with searching for context and finding implementation details.",
424			Model:        SelectedModelTypeLarge,
425			ContextPaths: c.Options.ContextPaths,
426			AllowedTools: []string{
427				"glob",
428				"grep",
429				"ls",
430				"sourcegraph",
431				"view",
432			},
433			// NO MCPs or LSPs by default
434			AllowedMCP: map[string][]string{},
435			AllowedLSP: []string{},
436		},
437	}
438	c.Agents = agents
439}