config.go

  1package config
  2
import (
	"context"
	"fmt"
	"net/http"
	"os"
	"path/filepath"
	"slices"
	"strings"
	"time"

	"github.com/charmbracelet/crush/internal/env"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/tidwall/sjson"
	"golang.org/x/exp/slog"
)
 17
 18const (
 19	appName              = "crush"
 20	defaultDataDirectory = ".crush"
 21	defaultLogLevel      = "info"
 22)
 23
 24var defaultContextPaths = []string{
 25	".github/copilot-instructions.md",
 26	".cursorrules",
 27	".cursor/rules/",
 28	"CLAUDE.md",
 29	"CLAUDE.local.md",
 30	"GEMINI.md",
 31	"gemini.md",
 32	"crush.md",
 33	"crush.local.md",
 34	"Crush.md",
 35	"Crush.local.md",
 36	"CRUSH.md",
 37	"CRUSH.local.md",
 38}
 39
 40type SelectedModelType string
 41
 42const (
 43	SelectedModelTypeLarge SelectedModelType = "large"
 44	SelectedModelTypeSmall SelectedModelType = "small"
 45)
 46
 47type SelectedModel struct {
 48	// The model id as used by the provider API.
 49	// Required.
 50	Model string `json:"model"`
 51	// The model provider, same as the key/id used in the providers config.
 52	// Required.
 53	Provider string `json:"provider"`
 54
 55	// Only used by models that use the openai provider and need this set.
 56	ReasoningEffort string `json:"reasoning_effort,omitempty"`
 57
 58	// Overrides the default model configuration.
 59	MaxTokens int64 `json:"max_tokens,omitempty"`
 60
 61	// Used by anthropic models that can reason to indicate if the model should think.
 62	Think bool `json:"think,omitempty"`
 63}
 64
 65type ProviderConfig struct {
 66	// The provider's id.
 67	ID string `json:"id,omitempty"`
 68	// The provider's name, used for display purposes.
 69	Name string `json:"name,omitempty"`
 70	// The provider's API endpoint.
 71	BaseURL string `json:"base_url,omitempty"`
 72	// The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
 73	Type provider.Type `json:"type,omitempty"`
 74	// The provider's API key.
 75	APIKey string `json:"api_key,omitempty"`
 76	// Marks the provider as disabled.
 77	Disable bool `json:"disable,omitempty"`
 78
 79	// Extra headers to send with each request to the provider.
 80	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
 81	// Extra body
 82	ExtraBody map[string]any `json:"extra_body,omitempty"`
 83
 84	// Used to pass extra parameters to the provider.
 85	ExtraParams map[string]string `json:"-"`
 86
 87	// The provider models
 88	Models []provider.Model `json:"models,omitempty"`
 89}
 90
 91type MCPType string
 92
 93const (
 94	MCPStdio MCPType = "stdio"
 95	MCPSse   MCPType = "sse"
 96	MCPHttp  MCPType = "http"
 97)
 98
 99type MCPConfig struct {
100	Command  string            `json:"command,omitempty" `
101	Env      map[string]string `json:"env,omitempty"`
102	Args     []string          `json:"args,omitempty"`
103	Type     MCPType           `json:"type"`
104	URL      string            `json:"url,omitempty"`
105	Disabled bool              `json:"disabled,omitempty"`
106
107	// TODO: maybe make it possible to get the value from the env
108	Headers map[string]string `json:"headers,omitempty"`
109}
110
// LSPConfig configures a single language server.
type LSPConfig struct {
	// Disabled turns the language server off. The JSON key is "disabled"; it
	// was previously tagged "enabled", which inverted the meaning of the key
	// for anyone writing `"enabled": true` in their config.
	Disabled bool `json:"disabled,omitempty"`
	// Command is the language server executable.
	Command string `json:"command"`
	// Args are the command-line arguments passed to Command.
	Args []string `json:"args,omitempty"`
	// Options is an arbitrary value carried alongside the server config.
	// NOTE(review): presumably forwarded to the server as initialization options; confirm.
	Options any `json:"options,omitempty"`
}
117
// TUIOptions groups settings for the terminal UI; it is the intended home for
// future TUI settings such as themes.
type TUIOptions struct {
	// CompactMode toggles the compact layout.
	CompactMode bool `json:"compact_mode,omitempty"`
}

// Options holds general application settings.
type Options struct {
	// ContextPaths lists the project context/instruction files to use.
	ContextPaths []string `json:"context_paths,omitempty"`
	// TUI holds terminal-UI settings; nil when not configured.
	TUI *TUIOptions `json:"tui,omitempty"`
	// Debug enables debug mode.
	Debug bool `json:"debug,omitempty"`
	// DebugLSP enables debug mode for language servers.
	DebugLSP bool `json:"debug_lsp,omitempty"`
	// DisableAutoSummarize turns off automatic summarization.
	DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty"`
	// DataDirectory is the data directory, relative to the cwd.
	DataDirectory string `json:"data_directory,omitempty"`
	// SkipPermissionsRequests auto-accepts every permission request (YOLO
	// mode). Never read from or written to JSON.
	SkipPermissionsRequests bool `json:"-"`
}
132
133type MCPs map[string]MCPConfig
134
135type MCP struct {
136	Name string    `json:"name"`
137	MCP  MCPConfig `json:"mcp"`
138}
139
140func (m MCPs) Sorted() []MCP {
141	sorted := make([]MCP, 0, len(m))
142	for k, v := range m {
143		sorted = append(sorted, MCP{
144			Name: k,
145			MCP:  v,
146		})
147	}
148	slices.SortFunc(sorted, func(a, b MCP) int {
149		return strings.Compare(a.Name, b.Name)
150	})
151	return sorted
152}
153
154type LSPs map[string]LSPConfig
155
156type LSP struct {
157	Name string    `json:"name"`
158	LSP  LSPConfig `json:"lsp"`
159}
160
161func (l LSPs) Sorted() []LSP {
162	sorted := make([]LSP, 0, len(l))
163	for k, v := range l {
164		sorted = append(sorted, LSP{
165			Name: k,
166			LSP:  v,
167		})
168	}
169	slices.SortFunc(sorted, func(a, b LSP) int {
170		return strings.Compare(a.Name, b.Name)
171	})
172	return sorted
173}
174
175func (m MCPConfig) ResolvedEnv() []string {
176	resolver := NewShellVariableResolver(env.New())
177	for e, v := range m.Env {
178		var err error
179		m.Env[e], err = resolver.ResolveValue(v)
180		if err != nil {
181			slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
182			continue
183		}
184	}
185
186	env := make([]string, 0, len(m.Env))
187	for k, v := range m.Env {
188		env = append(env, fmt.Sprintf("%s=%s", k, v))
189	}
190	return env
191}
192
193func (m MCPConfig) ResolvedHeaders() map[string]string {
194	resolver := NewShellVariableResolver(env.New())
195	for e, v := range m.Headers {
196		var err error
197		m.Headers[e], err = resolver.ResolveValue(v)
198		if err != nil {
199			slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
200			continue
201		}
202	}
203	return m.Headers
204}
205
206type Agent struct {
207	ID          string `json:"id,omitempty"`
208	Name        string `json:"name,omitempty"`
209	Description string `json:"description,omitempty"`
210	// This is the id of the system prompt used by the agent
211	Disabled bool `json:"disabled,omitempty"`
212
213	Model SelectedModelType `json:"model"`
214
215	// The available tools for the agent
216	//  if this is nil, all tools are available
217	AllowedTools []string `json:"allowed_tools,omitempty"`
218
219	// this tells us which MCPs are available for this agent
220	//  if this is empty all mcps are available
221	//  the string array is the list of tools from the AllowedMCP the agent has available
222	//  if the string array is nil, all tools from the AllowedMCP are available
223	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`
224
225	// The list of LSPs that this agent can use
226	//  if this is nil, all LSPs are available
227	AllowedLSP []string `json:"allowed_lsp,omitempty"`
228
229	// Overrides the context paths for this agent
230	ContextPaths []string `json:"context_paths,omitempty"`
231}
232
233// Config holds the configuration for crush.
234type Config struct {
235	// We currently only support large/small as values here.
236	Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`
237
238	// The providers that are configured
239	Providers map[string]ProviderConfig `json:"providers,omitempty"`
240
241	MCP MCPs `json:"mcp,omitempty"`
242
243	LSP LSPs `json:"lsp,omitempty"`
244
245	Options *Options `json:"options,omitempty"`
246
247	// Internal
248	workingDir string `json:"-"`
249	// TODO: most likely remove this concept when I come back to it
250	Agents map[string]Agent `json:"-"`
251	// TODO: find a better way to do this this should probably not be part of the config
252	resolver       VariableResolver
253	dataConfigDir  string              `json:"-"`
254	knownProviders []provider.Provider `json:"-"`
255}
256
257func (c *Config) WorkingDir() string {
258	return c.workingDir
259}
260
261func (c *Config) EnabledProviders() []ProviderConfig {
262	enabled := make([]ProviderConfig, 0, len(c.Providers))
263	for _, p := range c.Providers {
264		if !p.Disable {
265			enabled = append(enabled, p)
266		}
267	}
268	return enabled
269}
270
271// IsConfigured  return true if at least one provider is configured
272func (c *Config) IsConfigured() bool {
273	return len(c.EnabledProviders()) > 0
274}
275
276func (c *Config) GetModel(provider, model string) *provider.Model {
277	if providerConfig, ok := c.Providers[provider]; ok {
278		for _, m := range providerConfig.Models {
279			if m.ID == model {
280				return &m
281			}
282		}
283	}
284	return nil
285}
286
287func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
288	model, ok := c.Models[modelType]
289	if !ok {
290		return nil
291	}
292	if providerConfig, ok := c.Providers[model.Provider]; ok {
293		return &providerConfig
294	}
295	return nil
296}
297
298func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model {
299	model, ok := c.Models[modelType]
300	if !ok {
301		return nil
302	}
303	return c.GetModel(model.Provider, model.Model)
304}
305
306func (c *Config) LargeModel() *provider.Model {
307	model, ok := c.Models[SelectedModelTypeLarge]
308	if !ok {
309		return nil
310	}
311	return c.GetModel(model.Provider, model.Model)
312}
313
314func (c *Config) SmallModel() *provider.Model {
315	model, ok := c.Models[SelectedModelTypeSmall]
316	if !ok {
317		return nil
318	}
319	return c.GetModel(model.Provider, model.Model)
320}
321
322func (c *Config) SetCompactMode(enabled bool) error {
323	if c.Options == nil {
324		c.Options = &Options{}
325	}
326	c.Options.TUI.CompactMode = enabled
327	return c.SetConfigField("options.tui.compact_mode", enabled)
328}
329
330func (c *Config) Resolve(key string) (string, error) {
331	if c.resolver == nil {
332		return "", fmt.Errorf("no variable resolver configured")
333	}
334	return c.resolver.ResolveValue(key)
335}
336
337func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
338	c.Models[modelType] = model
339	if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
340		return fmt.Errorf("failed to update preferred model: %w", err)
341	}
342	return nil
343}
344
345func (c *Config) SetConfigField(key string, value any) error {
346	// read the data
347	data, err := os.ReadFile(c.dataConfigDir)
348	if err != nil {
349		if os.IsNotExist(err) {
350			data = []byte("{}")
351		} else {
352			return fmt.Errorf("failed to read config file: %w", err)
353		}
354	}
355
356	newValue, err := sjson.Set(string(data), key, value)
357	if err != nil {
358		return fmt.Errorf("failed to set config field %s: %w", key, err)
359	}
360	if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o644); err != nil {
361		return fmt.Errorf("failed to write config file: %w", err)
362	}
363	return nil
364}
365
366func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
367	// First save to the config file
368	err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
369	if err != nil {
370		return fmt.Errorf("failed to save API key to config file: %w", err)
371	}
372
373	if c.Providers == nil {
374		c.Providers = make(map[string]ProviderConfig)
375	}
376
377	providerConfig, exists := c.Providers[providerID]
378	if exists {
379		providerConfig.APIKey = apiKey
380		c.Providers[providerID] = providerConfig
381		return nil
382	}
383
384	var foundProvider *provider.Provider
385	for _, p := range c.knownProviders {
386		if string(p.ID) == providerID {
387			foundProvider = &p
388			break
389		}
390	}
391
392	if foundProvider != nil {
393		// Create new provider config based on known provider
394		providerConfig = ProviderConfig{
395			ID:           providerID,
396			Name:         foundProvider.Name,
397			BaseURL:      foundProvider.APIEndpoint,
398			Type:         foundProvider.Type,
399			APIKey:       apiKey,
400			Disable:      false,
401			ExtraHeaders: make(map[string]string),
402			ExtraParams:  make(map[string]string),
403			Models:       foundProvider.Models,
404		}
405	} else {
406		return fmt.Errorf("provider with ID %s not found in known providers", providerID)
407	}
408	// Store the updated provider config
409	c.Providers[providerID] = providerConfig
410	return nil
411}
412
413func (c *Config) SetupAgents() {
414	agents := map[string]Agent{
415		"coder": {
416			ID:           "coder",
417			Name:         "Coder",
418			Description:  "An agent that helps with executing coding tasks.",
419			Model:        SelectedModelTypeLarge,
420			ContextPaths: c.Options.ContextPaths,
421			// All tools allowed
422		},
423		"task": {
424			ID:           "task",
425			Name:         "Task",
426			Description:  "An agent that helps with searching for context and finding implementation details.",
427			Model:        SelectedModelTypeLarge,
428			ContextPaths: c.Options.ContextPaths,
429			AllowedTools: []string{
430				"glob",
431				"grep",
432				"ls",
433				"sourcegraph",
434				"view",
435			},
436			// NO MCPs or LSPs by default
437			AllowedMCP: map[string][]string{},
438			AllowedLSP: []string{},
439		},
440	}
441	c.Agents = agents
442}
443
444func (c *Config) Resolver() VariableResolver {
445	return c.resolver
446}
447
448func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
449	testURL := ""
450	headers := make(map[string]string)
451	apiKey, _ := resolver.ResolveValue(c.APIKey)
452	switch c.Type {
453	case provider.TypeOpenAI:
454		baseURL, _ := resolver.ResolveValue(c.BaseURL)
455		if baseURL == "" {
456			baseURL = "https://api.openai.com/v1"
457		}
458		testURL = baseURL + "/models"
459		headers["Authorization"] = "Bearer " + apiKey
460	case provider.TypeAnthropic:
461		baseURL, _ := resolver.ResolveValue(c.BaseURL)
462		if baseURL == "" {
463			baseURL = "https://api.anthropic.com/v1"
464		}
465		testURL = baseURL + "/models"
466		headers["x-api-key"] = apiKey
467		headers["anthropic-version"] = "2023-06-01"
468	}
469	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
470	defer cancel()
471	client := &http.Client{}
472	req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
473	if err != nil {
474		return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
475	}
476	for k, v := range headers {
477		req.Header.Set(k, v)
478	}
479	for k, v := range c.ExtraHeaders {
480		req.Header.Set(k, v)
481	}
482	b, err := client.Do(req)
483	if err != nil {
484		return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
485	}
486	if b.StatusCode != http.StatusOK {
487		return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status)
488	}
489	_ = b.Body.Close()
490	return nil
491}