config.go

  1package config
  2
import (
	"fmt"
	"os"
	"path/filepath"
	"slices"
	"strings"

	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/tidwall/sjson"
)
 12
 13const (
 14	appName              = "crush"
 15	defaultDataDirectory = ".crush"
 16	defaultLogLevel      = "info"
 17)
 18
 19var defaultContextPaths = []string{
 20	".github/copilot-instructions.md",
 21	".cursorrules",
 22	".cursor/rules/",
 23	"CLAUDE.md",
 24	"CLAUDE.local.md",
 25	"GEMINI.md",
 26	"gemini.md",
 27	"crush.md",
 28	"crush.local.md",
 29	"Crush.md",
 30	"Crush.local.md",
 31	"CRUSH.md",
 32	"CRUSH.local.md",
 33}
 34
 35type SelectedModelType string
 36
 37const (
 38	SelectedModelTypeLarge SelectedModelType = "large"
 39	SelectedModelTypeSmall SelectedModelType = "small"
 40)
 41
 42type SelectedModel struct {
 43	// The model id as used by the provider API.
 44	// Required.
 45	Model string `json:"model"`
 46	// The model provider, same as the key/id used in the providers config.
 47	// Required.
 48	Provider string `json:"provider"`
 49
 50	// Only used by models that use the openai provider and need this set.
 51	ReasoningEffort string `json:"reasoning_effort,omitempty"`
 52
 53	// Overrides the default model configuration.
 54	MaxTokens int64 `json:"max_tokens,omitempty"`
 55
 56	// Used by anthropic models that can reason to indicate if the model should think.
 57	Think bool `json:"think,omitempty"`
 58}
 59
 60type ProviderConfig struct {
 61	// The provider's id.
 62	ID string `json:"id,omitempty"`
 63	// The provider's name, used for display purposes.
 64	Name string `json:"name,omitempty"`
 65	// The provider's API endpoint.
 66	BaseURL string `json:"base_url,omitempty"`
 67	// The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
 68	Type provider.Type `json:"type,omitempty"`
 69	// The provider's API key.
 70	APIKey string `json:"api_key,omitempty"`
 71	// Marks the provider as disabled.
 72	Disable bool `json:"disable,omitempty"`
 73
 74	// Extra headers to send with each request to the provider.
 75	ExtraHeaders map[string]string
 76
 77	// Used to pass extra parameters to the provider.
 78	ExtraParams map[string]string `json:"-"`
 79
 80	// The provider models
 81	Models []provider.Model `json:"models,omitempty"`
 82}
 83
 84type MCPType string
 85
 86const (
 87	MCPStdio MCPType = "stdio"
 88	MCPSse   MCPType = "sse"
 89	MCPHttp  MCPType = "http"
 90)
 91
 92type MCPConfig struct {
 93	Command  string   `json:"command,omitempty" `
 94	Env      []string `json:"env,omitempty"`
 95	Args     []string `json:"args,omitempty"`
 96	Type     MCPType  `json:"type"`
 97	URL      string   `json:"url,omitempty"`
 98	Disabled bool     `json:"disabled,omitempty"`
 99
100	// TODO: maybe make it possible to get the value from the env
101	Headers map[string]string `json:"headers,omitempty"`
102}
103
// LSPConfig configures one language server entry of the "lsp" section.
type LSPConfig struct {
	// Disabled turns this language server off. Its JSON key is "disabled",
	// matching MCPConfig; it was previously (mis)tagged "enabled", so writing
	// "enabled": true actually disabled the server.
	Disabled bool     `json:"disabled,omitempty"`
	Command  string   `json:"command"`
	Args     []string `json:"args,omitempty"`
	Options  any      `json:"options,omitempty"`
}
110
// TUIOptions groups settings that only affect the terminal UI.
type TUIOptions struct {
	// CompactMode switches the TUI to its compact layout.
	CompactMode bool `json:"compact_mode,omitempty"`
	// Themes or any other TUI-related options can be added here later.
}

// Options holds general settings under the "options" key of the config.
type Options struct {
	ContextPaths         []string    `json:"context_paths,omitempty"`
	TUI                  *TUIOptions `json:"tui,omitempty"`
	Debug                bool        `json:"debug,omitempty"`
	DebugLSP             bool        `json:"debug_lsp,omitempty"`
	DisableAutoSummarize bool        `json:"disable_auto_summarize,omitempty"`

	// DataDirectory is interpreted relative to the current working directory.
	DataDirectory string `json:"data_directory,omitempty"`
}
125
126type MCPs map[string]MCPConfig
127
128type MCP struct {
129	Name string    `json:"name"`
130	MCP  MCPConfig `json:"mcp"`
131}
132
133func (m MCPs) Sorted() []MCP {
134	sorted := make([]MCP, 0, len(m))
135	for k, v := range m {
136		sorted = append(sorted, MCP{
137			Name: k,
138			MCP:  v,
139		})
140	}
141	slices.SortFunc(sorted, func(a, b MCP) int {
142		return strings.Compare(a.Name, b.Name)
143	})
144	return sorted
145}
146
147type LSPs map[string]LSPConfig
148
149type LSP struct {
150	Name string    `json:"name"`
151	LSP  LSPConfig `json:"lsp"`
152}
153
154func (l LSPs) Sorted() []LSP {
155	sorted := make([]LSP, 0, len(l))
156	for k, v := range l {
157		sorted = append(sorted, LSP{
158			Name: k,
159			LSP:  v,
160		})
161	}
162	slices.SortFunc(sorted, func(a, b LSP) int {
163		return strings.Compare(a.Name, b.Name)
164	})
165	return sorted
166}
167
168type Agent struct {
169	ID          string `json:"id,omitempty"`
170	Name        string `json:"name,omitempty"`
171	Description string `json:"description,omitempty"`
172	// This is the id of the system prompt used by the agent
173	Disabled bool `json:"disabled,omitempty"`
174
175	Model SelectedModelType `json:"model"`
176
177	// The available tools for the agent
178	//  if this is nil, all tools are available
179	AllowedTools []string `json:"allowed_tools,omitempty"`
180
181	// this tells us which MCPs are available for this agent
182	//  if this is empty all mcps are available
183	//  the string array is the list of tools from the AllowedMCP the agent has available
184	//  if the string array is nil, all tools from the AllowedMCP are available
185	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`
186
187	// The list of LSPs that this agent can use
188	//  if this is nil, all LSPs are available
189	AllowedLSP []string `json:"allowed_lsp,omitempty"`
190
191	// Overrides the context paths for this agent
192	ContextPaths []string `json:"context_paths,omitempty"`
193}
194
195// Config holds the configuration for crush.
196type Config struct {
197	// We currently only support large/small as values here.
198	Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`
199
200	// The providers that are configured
201	Providers map[string]ProviderConfig `json:"providers,omitempty"`
202
203	MCP MCPs `json:"mcp,omitempty"`
204
205	LSP LSPs `json:"lsp,omitempty"`
206
207	Options *Options `json:"options,omitempty"`
208
209	// Internal
210	workingDir string `json:"-"`
211	// TODO: most likely remove this concept when I come back to it
212	Agents map[string]Agent `json:"-"`
213	// TODO: find a better way to do this this should probably not be part of the config
214	resolver       VariableResolver
215	dataConfigDir  string              `json:"-"`
216	knownProviders []provider.Provider `json:"-"`
217}
218
219func (c *Config) WorkingDir() string {
220	return c.workingDir
221}
222
223func (c *Config) EnabledProviders() []ProviderConfig {
224	enabled := make([]ProviderConfig, 0, len(c.Providers))
225	for _, p := range c.Providers {
226		if !p.Disable {
227			enabled = append(enabled, p)
228		}
229	}
230	return enabled
231}
232
233// IsConfigured  return true if at least one provider is configured
234func (c *Config) IsConfigured() bool {
235	return len(c.EnabledProviders()) > 0
236}
237
238func (c *Config) GetModel(provider, model string) *provider.Model {
239	if providerConfig, ok := c.Providers[provider]; ok {
240		for _, m := range providerConfig.Models {
241			if m.ID == model {
242				return &m
243			}
244		}
245	}
246	return nil
247}
248
249func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
250	model, ok := c.Models[modelType]
251	if !ok {
252		return nil
253	}
254	if providerConfig, ok := c.Providers[model.Provider]; ok {
255		return &providerConfig
256	}
257	return nil
258}
259
260func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model {
261	model, ok := c.Models[modelType]
262	if !ok {
263		return nil
264	}
265	return c.GetModel(model.Provider, model.Model)
266}
267
268func (c *Config) LargeModel() *provider.Model {
269	model, ok := c.Models[SelectedModelTypeLarge]
270	if !ok {
271		return nil
272	}
273	return c.GetModel(model.Provider, model.Model)
274}
275
276func (c *Config) SmallModel() *provider.Model {
277	model, ok := c.Models[SelectedModelTypeSmall]
278	if !ok {
279		return nil
280	}
281	return c.GetModel(model.Provider, model.Model)
282}
283
284func (c *Config) SetCompactMode(enabled bool) error {
285	if c.Options == nil {
286		c.Options = &Options{}
287	}
288	c.Options.TUI.CompactMode = enabled
289	return c.SetConfigField("options.tui.compact_mode", enabled)
290}
291
// Resolve expands key through the configured variable resolver. It fails when
// no resolver has been set on this Config.
func (c *Config) Resolve(key string) (string, error) {
	if r := c.resolver; r != nil {
		return r.ResolveValue(key)
	}
	return "", fmt.Errorf("no variable resolver configured")
}
298
// UpdatePreferredModel selects model for the modelType slot, updating the
// in-memory Models map and persisting it under "models.<modelType>" in the
// config file.
func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
	// Models may be nil (e.g. a zero Config); assigning into a nil map panics.
	if c.Models == nil {
		c.Models = make(map[SelectedModelType]SelectedModel)
	}
	c.Models[modelType] = model
	if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
		return fmt.Errorf("failed to update preferred model: %w", err)
	}
	return nil
}
306
// SetConfigField writes value at key (sjson path syntax, e.g.
// "options.tui.compact_mode") into the JSON file at c.dataConfigDir, leaving
// every other key in the file untouched. A missing file is treated as "{}";
// the file and its parent directory are created when needed. Only the file on
// disk is changed — the in-memory Config is not updated here.
func (c *Config) SetConfigField(key string, value any) error {
	data, err := os.ReadFile(c.dataConfigDir)
	switch {
	case os.IsNotExist(err):
		data = []byte("{}")
	case err != nil:
		return fmt.Errorf("failed to read config file: %w", err)
	}

	newValue, err := sjson.Set(string(data), key, value)
	if err != nil {
		return fmt.Errorf("failed to set config field %s: %w", key, err)
	}

	// On first use the directory holding the file may not exist yet, in which
	// case os.WriteFile fails with ENOENT.
	if err := os.MkdirAll(filepath.Dir(c.dataConfigDir), 0o755); err != nil {
		return fmt.Errorf("failed to create config directory: %w", err)
	}
	// The file can hold provider API keys (SetProviderAPIKey stores them through
	// this function), so a newly created file is owner-readable only. The mode
	// argument does not change permissions of an existing file.
	if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o600); err != nil {
		return fmt.Errorf("failed to write config file: %w", err)
	}
	return nil
}
327
328func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
329	// First save to the config file
330	err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
331	if err != nil {
332		return fmt.Errorf("failed to save API key to config file: %w", err)
333	}
334
335	if c.Providers == nil {
336		c.Providers = make(map[string]ProviderConfig)
337	}
338
339	providerConfig, exists := c.Providers[providerID]
340	if exists {
341		providerConfig.APIKey = apiKey
342		c.Providers[providerID] = providerConfig
343		return nil
344	}
345
346	var foundProvider *provider.Provider
347	for _, p := range c.knownProviders {
348		if string(p.ID) == providerID {
349			foundProvider = &p
350			break
351		}
352	}
353
354	if foundProvider != nil {
355		// Create new provider config based on known provider
356		providerConfig = ProviderConfig{
357			ID:           providerID,
358			Name:         foundProvider.Name,
359			BaseURL:      foundProvider.APIEndpoint,
360			Type:         foundProvider.Type,
361			APIKey:       apiKey,
362			Disable:      false,
363			ExtraHeaders: make(map[string]string),
364			ExtraParams:  make(map[string]string),
365			Models:       foundProvider.Models,
366		}
367	} else {
368		return fmt.Errorf("provider with ID %s not found in known providers", providerID)
369	}
370	// Store the updated provider config
371	c.Providers[providerID] = providerConfig
372	return nil
373}