config.go

  1package config
  2
  3import (
  4	"context"
  5	"fmt"
  6	"net/http"
  7	"os"
  8	"slices"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/env"
 13	"github.com/charmbracelet/crush/internal/fur/provider"
 14	"github.com/tidwall/sjson"
 15	"golang.org/x/exp/slog"
 16)
 17
 18const (
 19	appName              = "crush"
 20	defaultDataDirectory = ".crush"
 21	defaultLogLevel      = "info"
 22)
 23
 24var defaultContextPaths = []string{
 25	".github/copilot-instructions.md",
 26	".cursorrules",
 27	".cursor/rules/",
 28	"CLAUDE.md",
 29	"CLAUDE.local.md",
 30	"GEMINI.md",
 31	"gemini.md",
 32	"crush.md",
 33	"crush.local.md",
 34	"Crush.md",
 35	"Crush.local.md",
 36	"CRUSH.md",
 37	"CRUSH.local.md",
 38}
 39
 40type SelectedModelType string
 41
 42const (
 43	SelectedModelTypeLarge SelectedModelType = "large"
 44	SelectedModelTypeSmall SelectedModelType = "small"
 45)
 46
// SelectedModel is the model chosen for one model slot (see Config.Models,
// keyed by SelectedModelType), together with per-selection overrides.
type SelectedModel struct {
	// The model id as used by the provider API.
	// Required.
	Model string `json:"model"`
	// The model provider, same as the key/id used in the providers config.
	// Required.
	Provider string `json:"provider"`

	// Only used by models that use the openai provider and need this set.
	ReasoningEffort string `json:"reasoning_effort,omitempty"`

	// Overrides the default max-tokens value of the model configuration.
	// Zero is omitted from JSON and presumably means "no override" — confirm
	// at the call site.
	MaxTokens int64 `json:"max_tokens,omitempty"`

	// Used by anthropic models that can reason to indicate if the model should think.
	Think bool `json:"think,omitempty"`
}
 64
// ProviderConfig is the configuration of one model provider, as stored under
// the "providers" key of the config file (Config.Providers, keyed by
// provider ID).
type ProviderConfig struct {
	// The provider's id.
	ID string `json:"id,omitempty"`
	// The provider's name, used for display purposes.
	Name string `json:"name,omitempty"`
	// The provider's API endpoint. TestConnection passes it through a
	// VariableResolver before use, so it may contain a variable reference.
	BaseURL string `json:"base_url,omitempty"`
	// The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
	// NOTE(review): TestConnection does not apply that default itself; it is
	// presumably applied when the config is loaded — confirm.
	Type provider.Type `json:"type,omitempty"`
	// The provider's API key. Like BaseURL, it is passed through a
	// VariableResolver in TestConnection, so it may be a variable reference.
	APIKey string `json:"api_key,omitempty"`
	// Marks the provider as disabled; disabled providers are excluded by
	// Config.EnabledProviders.
	Disable bool `json:"disable,omitempty"`

	// Extra headers to send with each request to the provider. TestConnection
	// applies them after (and so overriding) its built-in auth headers.
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
	// Extra body fields; presumably merged into request bodies sent to the
	// provider — confirm at the call site.
	ExtraBody map[string]any `json:"extra_body,omitempty"`

	// Used to pass extra parameters to the provider. Never read from or
	// written to JSON (json:"-").
	ExtraParams map[string]string `json:"-"`

	// The provider models, searched by ID in Config.GetModel.
	Models []provider.Model `json:"models,omitempty"`
}
 90
 91type MCPType string
 92
 93const (
 94	MCPStdio MCPType = "stdio"
 95	MCPSse   MCPType = "sse"
 96	MCPHttp  MCPType = "http"
 97)
 98
// MCPConfig describes how to reach one MCP server.
//
// Type selects the transport (see MCPType). Command/Args/Env presumably
// describe the process to launch for the stdio transport, and URL/Headers
// the endpoint for sse and http — confirm against the MCP client code.
// Disabled marks the server as disabled.
type MCPConfig struct {
	Command  string            `json:"command,omitempty" `
	Env      map[string]string `json:"env,omitempty"`
	Args     []string          `json:"args,omitempty"`
	Type     MCPType           `json:"type"`
	URL      string            `json:"url,omitempty"`
	Disabled bool              `json:"disabled,omitempty"`

	// Header values are passed through a shell variable resolver by
	// ResolvedHeaders, just as Env values are by ResolvedEnv.
	// NOTE(review): the TODO below appears to be covered by ResolvedHeaders.
	// TODO: maybe make it possible to get the value from the env
	Headers map[string]string `json:"headers,omitempty"`
}
110
// LSPConfig describes a language server entry in the configuration.
type LSPConfig struct {
	// Disabled turns this language server off. It is serialized as
	// "disabled", matching MCPConfig. The field used to be tagged "enabled",
	// which inverted the setting: `"enabled": true` disabled the server.
	Disabled bool `json:"disabled,omitempty"`
	// Command is the executable for the language server; Args are its
	// command-line arguments.
	Command string   `json:"command"`
	Args    []string `json:"args,omitempty"`
	// Options holds free-form, server-specific settings (presumably
	// forwarded to the server — confirm at the call site).
	Options any `json:"options,omitempty"`
}
117
// TUIOptions holds terminal UI settings (Options.TUI).
type TUIOptions struct {
	// CompactMode presumably selects a denser TUI layout; it is persisted
	// by Config.SetCompactMode under "options.tui.compact_mode".
	CompactMode bool `json:"compact_mode,omitempty"`
	// Here we can add themes later or any TUI related options
}
122
// Options holds general settings that are not tied to a provider.
//
//   - ContextPaths: context file/directory paths; SetupAgents hands them to
//     the built-in agents.
//   - TUI: terminal UI settings; nil when the "tui" key is absent.
//   - Debug / DebugLSP: debug switches (general and LSP-specific).
//   - DisableAutoSummarize: turns off automatic summarization.
//   - DataDirectory: data directory, relative to the current working directory.
//   - SkipPermissionsRequests: auto-accept all permission requests ("YOLO
//     mode"); tagged json:"-", so it can never be set from the config file.
//   - AllowedCommands: commands that run without a permission prompt.
type Options struct {
	ContextPaths            []string    `json:"context_paths,omitempty"`
	TUI                     *TUIOptions `json:"tui,omitempty"`
	Debug                   bool        `json:"debug,omitempty"`
	DebugLSP                bool        `json:"debug_lsp,omitempty"`
	DisableAutoSummarize    bool        `json:"disable_auto_summarize,omitempty"`
	DataDirectory           string      `json:"data_directory,omitempty"`   // Relative to the cwd
	SkipPermissionsRequests bool        `json:"-"`                          // Automatically accept all permissions (YOLO mode)
	AllowedCommands         []string    `json:"allowed_commands,omitempty"` // Commands that don't require permission prompts
}
133
// MCPs maps an MCP server name to its configuration (Config.MCP).
type MCPs map[string]MCPConfig

// MCP pairs an MCP server's name with its configuration; it is the element
// type returned by MCPs.Sorted.
type MCP struct {
	Name string    `json:"name"`
	MCP  MCPConfig `json:"mcp"`
}
140
141func (m MCPs) Sorted() []MCP {
142	sorted := make([]MCP, 0, len(m))
143	for k, v := range m {
144		sorted = append(sorted, MCP{
145			Name: k,
146			MCP:  v,
147		})
148	}
149	slices.SortFunc(sorted, func(a, b MCP) int {
150		return strings.Compare(a.Name, b.Name)
151	})
152	return sorted
153}
154
// LSPs maps a language server name to its configuration (Config.LSP).
type LSPs map[string]LSPConfig

// LSP pairs a language server's name with its configuration; it is the
// element type returned by LSPs.Sorted.
type LSP struct {
	Name string    `json:"name"`
	LSP  LSPConfig `json:"lsp"`
}
161
162func (l LSPs) Sorted() []LSP {
163	sorted := make([]LSP, 0, len(l))
164	for k, v := range l {
165		sorted = append(sorted, LSP{
166			Name: k,
167			LSP:  v,
168		})
169	}
170	slices.SortFunc(sorted, func(a, b LSP) int {
171		return strings.Compare(a.Name, b.Name)
172	})
173	return sorted
174}
175
176func (m MCPConfig) ResolvedEnv() []string {
177	resolver := NewShellVariableResolver(env.New())
178	for e, v := range m.Env {
179		var err error
180		m.Env[e], err = resolver.ResolveValue(v)
181		if err != nil {
182			slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
183			continue
184		}
185	}
186
187	env := make([]string, 0, len(m.Env))
188	for k, v := range m.Env {
189		env = append(env, fmt.Sprintf("%s=%s", k, v))
190	}
191	return env
192}
193
194func (m MCPConfig) ResolvedHeaders() map[string]string {
195	resolver := NewShellVariableResolver(env.New())
196	for e, v := range m.Headers {
197		var err error
198		m.Headers[e], err = resolver.ResolveValue(v)
199		if err != nil {
200			slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
201			continue
202		}
203	}
204	return m.Headers
205}
206
// Agent describes one built-in agent and the resources it may use. Agents
// are created by Config.SetupAgents rather than read from the config file.
type Agent struct {
	// ID is the agent's key in Config.Agents (SetupAgents uses "coder" and
	// "task"); Name is presumably its display name and Description a short
	// summary of its purpose.
	// NOTE(review): a comment here used to read "This is the id of the
	// system prompt used by the agent"; it sat above Disabled, but most likely
	// refers to ID — confirm.
	ID          string `json:"id,omitempty"`
	Name        string `json:"name,omitempty"`
	Description string `json:"description,omitempty"`
	// Disabled marks the agent as disabled.
	Disabled bool `json:"disabled,omitempty"`

	// Model selects which model slot (large or small) the agent uses.
	Model SelectedModelType `json:"model"`

	// The available tools for the agent
	//  if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools,omitempty"`

	// this tells us which MCPs are available for this agent
	//  if this is nil all mcps are available
	//  NOTE(review): SetupAgents sets a non-nil *empty* map for the "task"
	//  agent with the comment "NO MCPs", so an empty map presumably means
	//  no MCPs — confirm in the consumer.
	//  the string array is the list of tools from the AllowedMCP the agent has available
	//  if the string array is nil, all tools from the AllowedMCP are available
	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`

	// The list of LSPs that this agent can use
	//  if this is nil, all LSPs are available
	//  (SetupAgents uses a non-nil empty slice to mean "no LSPs".)
	AllowedLSP []string `json:"allowed_lsp,omitempty"`

	// Overrides the context paths for this agent
	ContextPaths []string `json:"context_paths,omitempty"`
}
233
// Config holds the configuration for crush.
//
// Exported fields are (de)serialized as JSON. The unexported fields at the
// end are runtime state and are never serialized, because encoding/json
// ignores unexported fields.
type Config struct {
	// We currently only support large/small as values here.
	Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`

	// The providers that are configured, keyed by provider ID.
	Providers map[string]ProviderConfig `json:"providers,omitempty"`

	// MCP servers keyed by name.
	MCP MCPs `json:"mcp,omitempty"`

	// Language servers keyed by name.
	LSP LSPs `json:"lsp,omitempty"`

	// General options; nil when the "options" key is absent.
	Options *Options `json:"options,omitempty"`

	// Internal
	workingDir string `json:"-"`
	// TODO: most likely remove this concept when I come back to it
	// Agents is populated by SetupAgents, never from JSON.
	Agents map[string]Agent `json:"-"`
	// TODO: find a better way to do this; this should probably not be part of the config
	//
	// resolver expands variable references in config values (see Resolve).
	// dataConfigDir is, despite its name, the path of the JSON config file
	// that SetConfigField reads and rewrites.
	// knownProviders is the provider catalog SetProviderAPIKey consults for
	// providers not yet present in Providers.
	resolver       VariableResolver
	dataConfigDir  string              `json:"-"`
	knownProviders []provider.Provider `json:"-"`
}
257
258func (c *Config) WorkingDir() string {
259	return c.workingDir
260}
261
262func (c *Config) EnabledProviders() []ProviderConfig {
263	enabled := make([]ProviderConfig, 0, len(c.Providers))
264	for _, p := range c.Providers {
265		if !p.Disable {
266			enabled = append(enabled, p)
267		}
268	}
269	return enabled
270}
271
272// IsConfigured  return true if at least one provider is configured
273func (c *Config) IsConfigured() bool {
274	return len(c.EnabledProviders()) > 0
275}
276
277func (c *Config) GetModel(provider, model string) *provider.Model {
278	if providerConfig, ok := c.Providers[provider]; ok {
279		for _, m := range providerConfig.Models {
280			if m.ID == model {
281				return &m
282			}
283		}
284	}
285	return nil
286}
287
288func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
289	model, ok := c.Models[modelType]
290	if !ok {
291		return nil
292	}
293	if providerConfig, ok := c.Providers[model.Provider]; ok {
294		return &providerConfig
295	}
296	return nil
297}
298
299func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model {
300	model, ok := c.Models[modelType]
301	if !ok {
302		return nil
303	}
304	return c.GetModel(model.Provider, model.Model)
305}
306
307func (c *Config) LargeModel() *provider.Model {
308	model, ok := c.Models[SelectedModelTypeLarge]
309	if !ok {
310		return nil
311	}
312	return c.GetModel(model.Provider, model.Model)
313}
314
315func (c *Config) SmallModel() *provider.Model {
316	model, ok := c.Models[SelectedModelTypeSmall]
317	if !ok {
318		return nil
319	}
320	return c.GetModel(model.Provider, model.Model)
321}
322
323func (c *Config) SetCompactMode(enabled bool) error {
324	if c.Options == nil {
325		c.Options = &Options{}
326	}
327	c.Options.TUI.CompactMode = enabled
328	return c.SetConfigField("options.tui.compact_mode", enabled)
329}
330
331func (c *Config) Resolve(key string) (string, error) {
332	if c.resolver == nil {
333		return "", fmt.Errorf("no variable resolver configured")
334	}
335	return c.resolver.ResolveValue(key)
336}
337
338func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
339	c.Models[modelType] = model
340	if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
341		return fmt.Errorf("failed to update preferred model: %w", err)
342	}
343	return nil
344}
345
346func (c *Config) SetConfigField(key string, value any) error {
347	// read the data
348	data, err := os.ReadFile(c.dataConfigDir)
349	if err != nil {
350		if os.IsNotExist(err) {
351			data = []byte("{}")
352		} else {
353			return fmt.Errorf("failed to read config file: %w", err)
354		}
355	}
356
357	newValue, err := sjson.Set(string(data), key, value)
358	if err != nil {
359		return fmt.Errorf("failed to set config field %s: %w", key, err)
360	}
361	if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o644); err != nil {
362		return fmt.Errorf("failed to write config file: %w", err)
363	}
364	return nil
365}
366
367func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
368	// First save to the config file
369	err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
370	if err != nil {
371		return fmt.Errorf("failed to save API key to config file: %w", err)
372	}
373
374	if c.Providers == nil {
375		c.Providers = make(map[string]ProviderConfig)
376	}
377
378	providerConfig, exists := c.Providers[providerID]
379	if exists {
380		providerConfig.APIKey = apiKey
381		c.Providers[providerID] = providerConfig
382		return nil
383	}
384
385	var foundProvider *provider.Provider
386	for _, p := range c.knownProviders {
387		if string(p.ID) == providerID {
388			foundProvider = &p
389			break
390		}
391	}
392
393	if foundProvider != nil {
394		// Create new provider config based on known provider
395		providerConfig = ProviderConfig{
396			ID:           providerID,
397			Name:         foundProvider.Name,
398			BaseURL:      foundProvider.APIEndpoint,
399			Type:         foundProvider.Type,
400			APIKey:       apiKey,
401			Disable:      false,
402			ExtraHeaders: make(map[string]string),
403			ExtraParams:  make(map[string]string),
404			Models:       foundProvider.Models,
405		}
406	} else {
407		return fmt.Errorf("provider with ID %s not found in known providers", providerID)
408	}
409	// Store the updated provider config
410	c.Providers[providerID] = providerConfig
411	return nil
412}
413
414func (c *Config) SetupAgents() {
415	agents := map[string]Agent{
416		"coder": {
417			ID:           "coder",
418			Name:         "Coder",
419			Description:  "An agent that helps with executing coding tasks.",
420			Model:        SelectedModelTypeLarge,
421			ContextPaths: c.Options.ContextPaths,
422			// All tools allowed
423		},
424		"task": {
425			ID:           "task",
426			Name:         "Task",
427			Description:  "An agent that helps with searching for context and finding implementation details.",
428			Model:        SelectedModelTypeLarge,
429			ContextPaths: c.Options.ContextPaths,
430			AllowedTools: []string{
431				"glob",
432				"grep",
433				"ls",
434				"sourcegraph",
435				"view",
436			},
437			// NO MCPs or LSPs by default
438			AllowedMCP: map[string][]string{},
439			AllowedLSP: []string{},
440		},
441	}
442	c.Agents = agents
443}
444
445func (c *Config) Resolver() VariableResolver {
446	return c.resolver
447}
448
449func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
450	testURL := ""
451	headers := make(map[string]string)
452	apiKey, _ := resolver.ResolveValue(c.APIKey)
453	switch c.Type {
454	case provider.TypeOpenAI:
455		baseURL, _ := resolver.ResolveValue(c.BaseURL)
456		if baseURL == "" {
457			baseURL = "https://api.openai.com/v1"
458		}
459		testURL = baseURL + "/models"
460		headers["Authorization"] = "Bearer " + apiKey
461	case provider.TypeAnthropic:
462		baseURL, _ := resolver.ResolveValue(c.BaseURL)
463		if baseURL == "" {
464			baseURL = "https://api.anthropic.com/v1"
465		}
466		testURL = baseURL + "/models"
467		headers["x-api-key"] = apiKey
468		headers["anthropic-version"] = "2023-06-01"
469	}
470	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
471	defer cancel()
472	client := &http.Client{}
473	req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
474	if err != nil {
475		return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
476	}
477	for k, v := range headers {
478		req.Header.Set(k, v)
479	}
480	for k, v := range c.ExtraHeaders {
481		req.Header.Set(k, v)
482	}
483	b, err := client.Do(req)
484	if err != nil {
485		return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
486	}
487	if b.StatusCode != http.StatusOK {
488		return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status)
489	}
490	_ = b.Body.Close()
491	return nil
492}