config.go

  1package config
  2
import (
	"context"
	"fmt"
	"net/http"
	"os"
	"path/filepath"
	"slices"
	"strings"
	"time"

	"github.com/charmbracelet/crush/internal/env"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/tidwall/sjson"
	"golang.org/x/exp/slog"
)
 17
// Application-wide defaults: the application name, the default data
// directory name, and the default log level name.
// NOTE(review): defaultDataDirectory is presumably resolved relative to the
// working directory, like Options.DataDirectory — confirm at the use site.
const (
	appName              = "crush"
	defaultDataDirectory = ".crush"
	defaultLogLevel      = "info"
)
 23
// defaultContextPaths is the default list of instruction/context files that
// other tools and crush itself use (Copilot, Cursor, Claude, Gemini, crush).
// Entries ending in "/" (".cursor/rules/") name directories rather than files.
// NOTE(review): presumably these are relative to the working directory and
// are used when Options.ContextPaths does not override them — confirm against
// the loader, which is not in this section.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
 39
// SelectedModelType names a model slot in Config.Models. Only the "large" and
// "small" slots are supported.
type SelectedModelType string

const (
	// SelectedModelTypeLarge is the slot used by both built-in agents
	// (see SetupAgents).
	SelectedModelTypeLarge SelectedModelType = "large"
	// SelectedModelTypeSmall is the secondary model slot.
	SelectedModelTypeSmall SelectedModelType = "small"
)
 46
// SelectedModel identifies the concrete model (and per-model overrides)
// assigned to one slot of Config.Models. Model and Provider are used together
// to look the model up via Config.GetModel.
type SelectedModel struct {
	// The model id as used by the provider API.
	// Required.
	Model string `json:"model"`
	// The model provider, same as the key/id used in the providers config.
	// Required.
	Provider string `json:"provider"`

	// Only used by models that use the openai provider and need this set.
	ReasoningEffort string `json:"reasoning_effort,omitempty"`

	// Overrides the default model configuration.
	MaxTokens int64 `json:"max_tokens,omitempty"`

	// Used by anthropic models that can reason to indicate if the model should think.
	Think bool `json:"think,omitempty"`
}
 64
// ProviderConfig describes one LLM provider entry under Config.Providers.
// Entries can also be synthesized from a known provider by SetProviderAPIKey.
type ProviderConfig struct {
	// The provider's id.
	ID string `json:"id,omitempty"`
	// The provider's name, used for display purposes.
	Name string `json:"name,omitempty"`
	// The provider's API endpoint.
	BaseURL string `json:"base_url,omitempty"`
	// The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
	Type provider.Type `json:"type,omitempty"`
	// The provider's API key.
	APIKey string `json:"api_key,omitempty"`
	// Marks the provider as disabled.
	Disable bool `json:"disable,omitempty"`

	// Extra headers to send with each request to the provider.
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
	// Extra fields for the provider request body.
	// NOTE(review): how they are merged into requests is decided by the
	// provider client, not in this file — confirm there.
	ExtraBody map[string]any `json:"extra_body,omitempty"`

	// Used to pass extra parameters to the provider. Tagged json:"-", so it is
	// never read from or written to the config file.
	ExtraParams map[string]string `json:"-"`

	// The provider models
	Models []provider.Model `json:"models,omitempty"`
}
 90
// MCPType selects how an MCP server is reached. Judging by the values, these
// are the stdio, SSE and HTTP transports.
type MCPType string

const (
	MCPStdio MCPType = "stdio"
	MCPSse   MCPType = "sse"
	MCPHttp  MCPType = "http"
)

// MCPConfig configures one MCP server. Config.MCP maps server names to it.
// NOTE(review): Command/Args/Env presumably apply to stdio servers and
// URL/Headers to sse/http ones — confirm against the MCP client code.
type MCPConfig struct {
	Command  string            `json:"command,omitempty" `
	Env      map[string]string `json:"env,omitempty"`
	Args     []string          `json:"args,omitempty"`
	Type     MCPType           `json:"type"`
	URL      string            `json:"url,omitempty"`
	Disabled bool              `json:"disabled,omitempty"`

	// TODO: maybe make it possible to get the value from the env
	// NOTE(review): ResolvedHeaders already passes every value through the
	// shell variable resolver built from env.New(), so this TODO may be stale.
	Headers map[string]string `json:"headers,omitempty"`
}
110
// LSPConfig configures one language server. Config.LSP maps names to it.
type LSPConfig struct {
	// Disabled turns this server off. It is serialized as "disabled", matching
	// MCPConfig. It was previously tagged "enabled", which made
	// `"enabled": true` in a config file disable the server.
	Disabled bool     `json:"disabled,omitempty"`
	Command  string   `json:"command"`
	Args     []string `json:"args,omitempty"`
	// Options is free-form server-specific data, kept as decoded JSON.
	Options any `json:"options,omitempty"`
}
117
// TUIOptions holds terminal-UI settings.
type TUIOptions struct {
	CompactMode bool `json:"compact_mode,omitempty"`
	// Here we can add themes later or any TUI related options
}

// Permissions controls which tool calls require a permission prompt.
type Permissions struct {
	AllowedTools []string `json:"allowed_tools,omitempty"` // Tools that don't require permission prompts
	SkipRequests bool     `json:"-"`                       // Automatically accept all permissions (YOLO mode)
}

// Options holds general settings from the "options" section of the config.
// TUI is a pointer, so it is nil when the section has no "tui" entry; callers
// that write through it must allocate it first.
type Options struct {
	ContextPaths         []string    `json:"context_paths,omitempty"`
	TUI                  *TUIOptions `json:"tui,omitempty"`
	Debug                bool        `json:"debug,omitempty"`
	DebugLSP             bool        `json:"debug_lsp,omitempty"`
	DisableAutoSummarize bool        `json:"disable_auto_summarize,omitempty"`
	DataDirectory        string      `json:"data_directory,omitempty"` // Relative to the cwd
}

// MCPs maps MCP server names to their configuration.
type MCPs map[string]MCPConfig

// MCP pairs an MCP server name with its configuration. MCPs.Sorted returns
// a slice of these.
type MCP struct {
	Name string    `json:"name"`
	MCP  MCPConfig `json:"mcp"`
}
143
144func (m MCPs) Sorted() []MCP {
145	sorted := make([]MCP, 0, len(m))
146	for k, v := range m {
147		sorted = append(sorted, MCP{
148			Name: k,
149			MCP:  v,
150		})
151	}
152	slices.SortFunc(sorted, func(a, b MCP) int {
153		return strings.Compare(a.Name, b.Name)
154	})
155	return sorted
156}
157
// LSPs maps language-server names to their configuration.
type LSPs map[string]LSPConfig

// LSP pairs a language-server name with its configuration. LSPs.Sorted
// returns a slice of these.
type LSP struct {
	Name string    `json:"name"`
	LSP  LSPConfig `json:"lsp"`
}
164
165func (l LSPs) Sorted() []LSP {
166	sorted := make([]LSP, 0, len(l))
167	for k, v := range l {
168		sorted = append(sorted, LSP{
169			Name: k,
170			LSP:  v,
171		})
172	}
173	slices.SortFunc(sorted, func(a, b LSP) int {
174		return strings.Compare(a.Name, b.Name)
175	})
176	return sorted
177}
178
179func (m MCPConfig) ResolvedEnv() []string {
180	resolver := NewShellVariableResolver(env.New())
181	for e, v := range m.Env {
182		var err error
183		m.Env[e], err = resolver.ResolveValue(v)
184		if err != nil {
185			slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
186			continue
187		}
188	}
189
190	env := make([]string, 0, len(m.Env))
191	for k, v := range m.Env {
192		env = append(env, fmt.Sprintf("%s=%s", k, v))
193	}
194	return env
195}
196
197func (m MCPConfig) ResolvedHeaders() map[string]string {
198	resolver := NewShellVariableResolver(env.New())
199	for e, v := range m.Headers {
200		var err error
201		m.Headers[e], err = resolver.ResolveValue(v)
202		if err != nil {
203			slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
204			continue
205		}
206	}
207	return m.Headers
208}
209
// Agent describes an agent: the model slot it uses and the tools, MCP servers
// and LSPs it may access. Config.Agents is tagged json:"-", so agents are not
// read from the config file; SetupAgents builds them.
type Agent struct {
	ID          string `json:"id,omitempty"`
	Name        string `json:"name,omitempty"`
	Description string `json:"description,omitempty"`
	// Disabled marks the agent as disabled.
	// NOTE(review): this field previously carried the comment "This is the id
	// of the system prompt used by the agent", which does not describe it.
	Disabled bool `json:"disabled,omitempty"`

	Model SelectedModelType `json:"model"`

	// The available tools for the agent
	//  if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools,omitempty"`

	// this tells us which MCPs are available for this agent
	//  if this is empty all mcps are available
	//  the string array is the list of tools from the AllowedMCP the agent has available
	//  if the string array is nil, all tools from the AllowedMCP are available
	// NOTE(review): SetupAgents gives the "task" agent an empty, non-nil map
	// with the comment "NO MCPs ... by default". That suggests nil means "all"
	// and empty means "none", contrary to the line above. Confirm in the
	// consumer.
	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`

	// The list of LSPs that this agent can use
	//  if this is nil, all LSPs are available
	AllowedLSP []string `json:"allowed_lsp,omitempty"`

	// Overrides the context paths for this agent
	ContextPaths []string `json:"context_paths,omitempty"`
}
236
// Config holds the configuration for crush.
//
// The exported, JSON-tagged fields mirror the config file. The trailing
// unexported fields and Agents are runtime state and are never serialized.
type Config struct {
	// We currently only support large/small as values here.
	Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`

	// The providers that are configured
	Providers map[string]ProviderConfig `json:"providers,omitempty"`

	MCP MCPs `json:"mcp,omitempty"`

	LSP LSPs `json:"lsp,omitempty"`

	Options *Options `json:"options,omitempty"`

	Permissions *Permissions `json:"permissions,omitempty"`

	// Internal
	workingDir string `json:"-"`
	// TODO: most likely remove this concept when I come back to it
	Agents map[string]Agent `json:"-"`
	// TODO: find a better way to do this this should probably not be part of the config
	// resolver may be nil; Resolve reports an error in that case.
	// dataConfigDir, despite its name, is the path of the JSON config file that
	// SetConfigField reads and rewrites.
	// knownProviders is the catalogue SetProviderAPIKey uses to build entries
	// for providers not yet in Providers.
	resolver       VariableResolver
	dataConfigDir  string              `json:"-"`
	knownProviders []provider.Provider `json:"-"`
}
262
263func (c *Config) WorkingDir() string {
264	return c.workingDir
265}
266
267func (c *Config) EnabledProviders() []ProviderConfig {
268	enabled := make([]ProviderConfig, 0, len(c.Providers))
269	for _, p := range c.Providers {
270		if !p.Disable {
271			enabled = append(enabled, p)
272		}
273	}
274	return enabled
275}
276
277// IsConfigured  return true if at least one provider is configured
278func (c *Config) IsConfigured() bool {
279	return len(c.EnabledProviders()) > 0
280}
281
282func (c *Config) GetModel(provider, model string) *provider.Model {
283	if providerConfig, ok := c.Providers[provider]; ok {
284		for _, m := range providerConfig.Models {
285			if m.ID == model {
286				return &m
287			}
288		}
289	}
290	return nil
291}
292
293func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
294	model, ok := c.Models[modelType]
295	if !ok {
296		return nil
297	}
298	if providerConfig, ok := c.Providers[model.Provider]; ok {
299		return &providerConfig
300	}
301	return nil
302}
303
304func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model {
305	model, ok := c.Models[modelType]
306	if !ok {
307		return nil
308	}
309	return c.GetModel(model.Provider, model.Model)
310}
311
312func (c *Config) LargeModel() *provider.Model {
313	model, ok := c.Models[SelectedModelTypeLarge]
314	if !ok {
315		return nil
316	}
317	return c.GetModel(model.Provider, model.Model)
318}
319
320func (c *Config) SmallModel() *provider.Model {
321	model, ok := c.Models[SelectedModelTypeSmall]
322	if !ok {
323		return nil
324	}
325	return c.GetModel(model.Provider, model.Model)
326}
327
328func (c *Config) SetCompactMode(enabled bool) error {
329	if c.Options == nil {
330		c.Options = &Options{}
331	}
332	c.Options.TUI.CompactMode = enabled
333	return c.SetConfigField("options.tui.compact_mode", enabled)
334}
335
336func (c *Config) Resolve(key string) (string, error) {
337	if c.resolver == nil {
338		return "", fmt.Errorf("no variable resolver configured")
339	}
340	return c.resolver.ResolveValue(key)
341}
342
343func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
344	c.Models[modelType] = model
345	if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
346		return fmt.Errorf("failed to update preferred model: %w", err)
347	}
348	return nil
349}
350
351func (c *Config) SetConfigField(key string, value any) error {
352	// read the data
353	data, err := os.ReadFile(c.dataConfigDir)
354	if err != nil {
355		if os.IsNotExist(err) {
356			data = []byte("{}")
357		} else {
358			return fmt.Errorf("failed to read config file: %w", err)
359		}
360	}
361
362	newValue, err := sjson.Set(string(data), key, value)
363	if err != nil {
364		return fmt.Errorf("failed to set config field %s: %w", key, err)
365	}
366	if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o644); err != nil {
367		return fmt.Errorf("failed to write config file: %w", err)
368	}
369	return nil
370}
371
372func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
373	// First save to the config file
374	err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
375	if err != nil {
376		return fmt.Errorf("failed to save API key to config file: %w", err)
377	}
378
379	if c.Providers == nil {
380		c.Providers = make(map[string]ProviderConfig)
381	}
382
383	providerConfig, exists := c.Providers[providerID]
384	if exists {
385		providerConfig.APIKey = apiKey
386		c.Providers[providerID] = providerConfig
387		return nil
388	}
389
390	var foundProvider *provider.Provider
391	for _, p := range c.knownProviders {
392		if string(p.ID) == providerID {
393			foundProvider = &p
394			break
395		}
396	}
397
398	if foundProvider != nil {
399		// Create new provider config based on known provider
400		providerConfig = ProviderConfig{
401			ID:           providerID,
402			Name:         foundProvider.Name,
403			BaseURL:      foundProvider.APIEndpoint,
404			Type:         foundProvider.Type,
405			APIKey:       apiKey,
406			Disable:      false,
407			ExtraHeaders: make(map[string]string),
408			ExtraParams:  make(map[string]string),
409			Models:       foundProvider.Models,
410		}
411	} else {
412		return fmt.Errorf("provider with ID %s not found in known providers", providerID)
413	}
414	// Store the updated provider config
415	c.Providers[providerID] = providerConfig
416	return nil
417}
418
419func (c *Config) SetupAgents() {
420	agents := map[string]Agent{
421		"coder": {
422			ID:           "coder",
423			Name:         "Coder",
424			Description:  "An agent that helps with executing coding tasks.",
425			Model:        SelectedModelTypeLarge,
426			ContextPaths: c.Options.ContextPaths,
427			// All tools allowed
428		},
429		"task": {
430			ID:           "task",
431			Name:         "Task",
432			Description:  "An agent that helps with searching for context and finding implementation details.",
433			Model:        SelectedModelTypeLarge,
434			ContextPaths: c.Options.ContextPaths,
435			AllowedTools: []string{
436				"glob",
437				"grep",
438				"ls",
439				"sourcegraph",
440				"view",
441			},
442			// NO MCPs or LSPs by default
443			AllowedMCP: map[string][]string{},
444			AllowedLSP: []string{},
445		},
446	}
447	c.Agents = agents
448}
449
450func (c *Config) Resolver() VariableResolver {
451	return c.resolver
452}
453
454func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
455	testURL := ""
456	headers := make(map[string]string)
457	apiKey, _ := resolver.ResolveValue(c.APIKey)
458	switch c.Type {
459	case provider.TypeOpenAI:
460		baseURL, _ := resolver.ResolveValue(c.BaseURL)
461		if baseURL == "" {
462			baseURL = "https://api.openai.com/v1"
463		}
464		testURL = baseURL + "/models"
465		headers["Authorization"] = "Bearer " + apiKey
466	case provider.TypeAnthropic:
467		baseURL, _ := resolver.ResolveValue(c.BaseURL)
468		if baseURL == "" {
469			baseURL = "https://api.anthropic.com/v1"
470		}
471		testURL = baseURL + "/models"
472		headers["x-api-key"] = apiKey
473		headers["anthropic-version"] = "2023-06-01"
474	}
475	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
476	defer cancel()
477	client := &http.Client{}
478	req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
479	if err != nil {
480		return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
481	}
482	for k, v := range headers {
483		req.Header.Set(k, v)
484	}
485	for k, v := range c.ExtraHeaders {
486		req.Header.Set(k, v)
487	}
488	b, err := client.Do(req)
489	if err != nil {
490		return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
491	}
492	if b.StatusCode != http.StatusOK {
493		return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status)
494	}
495	_ = b.Body.Close()
496	return nil
497}