config.go

  1package config
  2
import (
	"context"
	"fmt"
	"net/http"
	"os"
	"path/filepath"
	"slices"
	"strings"
	"time"

	"github.com/charmbracelet/crush/internal/env"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/tidwall/sjson"
	"golang.org/x/exp/slog"
)
 17
// Package-level defaults.
const (
	// appName is the application name. NOTE(review): its uses are outside
	// this section of the file — confirm what it feeds (paths, logging, ...).
	appName              = "crush"
	// defaultDataDirectory is presumably the fallback for
	// Options.DataDirectory, which is documented as relative to the cwd.
	defaultDataDirectory = ".crush"
	// defaultLogLevel is presumably the log level used when none is configured.
	defaultLogLevel      = "info"
)

// defaultContextPaths lists project-instruction files used by several AI
// coding tools (Copilot, Cursor, Claude, Gemini, Crush). Entries ending in
// "/" presumably name directories rather than files.
// NOTE(review): presumably the default for Options.ContextPaths; the code
// that applies and reads these paths is not in this section — confirm.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
 39
// SelectedModelType names a model slot. Config.Models is keyed by it and
// Agent.Model refers to it.
type SelectedModelType string

const (
	// SelectedModelTypeLarge is the "large" model slot.
	SelectedModelTypeLarge SelectedModelType = "large"
	// SelectedModelTypeSmall is the "small" model slot.
	SelectedModelTypeSmall SelectedModelType = "small"
)

// SelectedModel is the model chosen for one slot; it is persisted under
// "models.<slot>" in the config file.
type SelectedModel struct {
	// The model id as used by the provider API.
	// Required.
	Model string `json:"model"`
	// The model provider, same as the key/id used in the providers config.
	// Required.
	Provider string `json:"provider"`

	// Only used by models that use the openai provider and need this set.
	ReasoningEffort string `json:"reasoning_effort,omitempty"`

	// Overrides the default model configuration; presumably the model's
	// default max-token limit when non-zero.
	MaxTokens int64 `json:"max_tokens,omitempty"`

	// Used by anthropic models that can reason to indicate if the model should think.
	Think bool `json:"think,omitempty"`
}
 64
// ProviderConfig describes one entry of Config.Providers.
type ProviderConfig struct {
	// The provider's id.
	ID string `json:"id,omitempty"`
	// The provider's name, used for display purposes.
	Name string `json:"name,omitempty"`
	// The provider's API endpoint. May contain variables; TestConnection
	// expands it through a VariableResolver.
	BaseURL string `json:"base_url,omitempty"`
	// The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
	Type provider.Type `json:"type,omitempty"`
	// The provider's API key. May contain variables; TestConnection expands
	// it through a VariableResolver.
	APIKey string `json:"api_key,omitempty"`
	// Marks the provider as disabled; EnabledProviders skips such entries.
	Disable bool `json:"disable,omitempty"`

	// Extra headers to send with each request to the provider.
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`

	// Used to pass extra parameters to the provider. Tagged json:"-", so it
	// is never read from or written to the config file.
	ExtraParams map[string]string `json:"-"`

	// The provider models; GetModel searches this list by model ID.
	Models []provider.Model `json:"models,omitempty"`
}
 88
// MCPType selects how an MCP server is reached.
type MCPType string

const (
	// MCPStdio is the "stdio" MCP type.
	MCPStdio MCPType = "stdio"
	// MCPSse is the "sse" MCP type.
	MCPSse   MCPType = "sse"
	// MCPHttp is the "http" MCP type.
	MCPHttp  MCPType = "http"
)

// MCPConfig is the configuration of a single MCP server.
type MCPConfig struct {
	// Command and Args are presumably used to launch stdio servers.
	Command  string            `json:"command,omitempty" `
	// Env values may reference variables; ResolvedEnv expands them.
	Env      map[string]string `json:"env,omitempty"`
	Args     []string          `json:"args,omitempty"`
	// Type is always serialized (no omitempty).
	Type     MCPType           `json:"type"`
	// URL is presumably used by the sse and http types.
	URL      string            `json:"url,omitempty"`
	Disabled bool              `json:"disabled,omitempty"`

	// Headers sent to the server. ResolvedHeaders passes each value through
	// the shell variable resolver.
	// NOTE(review): that appears to cover the old TODO about reading header
	// values from the env — confirm and drop the TODO if so.
	Headers map[string]string `json:"headers,omitempty"`
}
108
// LSPConfig is the configuration of a single language server.
type LSPConfig struct {
	// Disabled turns the server off. It is serialized as "disabled" to match
	// MCPConfig.Disabled; the previous "enabled" tag inverted its meaning, so
	// `"enabled": true` in a config file would have disabled the server.
	Disabled bool `json:"disabled,omitempty"`
	// Command is the executable to run.
	Command string `json:"command"`
	// Args are passed to Command.
	Args []string `json:"args,omitempty"`
	// Options is an arbitrary, server-specific value; its shape is not
	// interpreted here.
	Options any `json:"options,omitempty"`
}
115
// TUIOptions holds terminal-UI preferences.
type TUIOptions struct {
	// CompactMode is persisted as "options.tui.compact_mode".
	CompactMode bool `json:"compact_mode,omitempty"`
	// Here we can add themes later or any TUI related options
}

// Options holds general, non-provider settings.
type Options struct {
	ContextPaths []string `json:"context_paths,omitempty"`
	// TUI is a pointer and may be nil when no TUI options are configured.
	TUI                     *TUIOptions `json:"tui,omitempty"`
	Debug                   bool        `json:"debug,omitempty"`
	DebugLSP                bool        `json:"debug_lsp,omitempty"`
	DisableAutoSummarize    bool        `json:"disable_auto_summarize,omitempty"`
	DataDirectory           string      `json:"data_directory,omitempty"` // Relative to the cwd
	SkipPermissionsRequests bool        `json:"-"`                        // Automatically accept all permissions (YOLO mode); never read from the config file
}
130
// MCPs maps an MCP server name to its configuration.
type MCPs map[string]MCPConfig

// MCP pairs an MCP server name with its configuration; it is the element
// type returned by MCPs.Sorted.
type MCP struct {
	Name string    `json:"name"`
	MCP  MCPConfig `json:"mcp"`
}
137
138func (m MCPs) Sorted() []MCP {
139	sorted := make([]MCP, 0, len(m))
140	for k, v := range m {
141		sorted = append(sorted, MCP{
142			Name: k,
143			MCP:  v,
144		})
145	}
146	slices.SortFunc(sorted, func(a, b MCP) int {
147		return strings.Compare(a.Name, b.Name)
148	})
149	return sorted
150}
151
// LSPs maps a language-server name to its configuration.
type LSPs map[string]LSPConfig

// LSP pairs a language-server name with its configuration; it is the element
// type returned by LSPs.Sorted.
type LSP struct {
	Name string    `json:"name"`
	LSP  LSPConfig `json:"lsp"`
}
158
159func (l LSPs) Sorted() []LSP {
160	sorted := make([]LSP, 0, len(l))
161	for k, v := range l {
162		sorted = append(sorted, LSP{
163			Name: k,
164			LSP:  v,
165		})
166	}
167	slices.SortFunc(sorted, func(a, b LSP) int {
168		return strings.Compare(a.Name, b.Name)
169	})
170	return sorted
171}
172
173func (m MCPConfig) ResolvedEnv() []string {
174	resolver := NewShellVariableResolver(env.New())
175	for e, v := range m.Env {
176		var err error
177		m.Env[e], err = resolver.ResolveValue(v)
178		if err != nil {
179			slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
180			continue
181		}
182	}
183
184	env := make([]string, 0, len(m.Env))
185	for k, v := range m.Env {
186		env = append(env, fmt.Sprintf("%s=%s", k, v))
187	}
188	return env
189}
190
191func (m MCPConfig) ResolvedHeaders() map[string]string {
192	resolver := NewShellVariableResolver(env.New())
193	for e, v := range m.Headers {
194		var err error
195		m.Headers[e], err = resolver.ResolveValue(v)
196		if err != nil {
197			slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
198			continue
199		}
200	}
201	return m.Headers
202}
203
// Agent describes an agent definition; SetupAgents builds the built-in ones.
type Agent struct {
	// ID identifies the agent (e.g. "coder", "task").
	// NOTE(review): a comment previously sitting above Disabled called this
	// "the id of the system prompt used by the agent" — confirm which is meant.
	ID          string `json:"id,omitempty"`
	Name        string `json:"name,omitempty"`
	Description string `json:"description,omitempty"`
	// Disabled presumably switches the agent off.
	Disabled bool `json:"disabled,omitempty"`

	// Model is the model slot (large/small) the agent runs on.
	Model SelectedModelType `json:"model"`

	// The available tools for the agent
	//  if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools,omitempty"`

	// this tells us which MCPs are available for this agent
	//  the key is the MCP name; the string array is the list of tools from
	//  that MCP the agent has available
	//  if the string array is nil, all tools from that MCP are available
	// NOTE(review): this previously said "if this is empty all mcps are
	// available", but SetupAgents sets an empty (non-nil) map on "task" under
	// the comment "NO MCPs ... by default". Presumably nil means all and
	// empty means none, like AllowedLSP — confirm against the enforcing code.
	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`

	// The list of LSPs that this agent can use
	//  if this is nil, all LSPs are available
	AllowedLSP []string `json:"allowed_lsp,omitempty"`

	// Overrides the context paths for this agent
	ContextPaths []string `json:"context_paths,omitempty"`
}
230
// Config holds the configuration for crush.
type Config struct {
	// We currently only support large/small as values here.
	Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`

	// The providers that are configured, keyed by provider id (the value
	// SelectedModel.Provider refers to).
	Providers map[string]ProviderConfig `json:"providers,omitempty"`

	// MCP servers keyed by name.
	MCP MCPs `json:"mcp,omitempty"`

	// Language servers keyed by name.
	LSP LSPs `json:"lsp,omitempty"`

	// Options may be nil when the config file has no "options" section.
	Options *Options `json:"options,omitempty"`

	// Internal
	workingDir string `json:"-"`
	// TODO: most likely remove this concept when I come back to it
	// Populated by SetupAgents, keyed by agent ID.
	Agents map[string]Agent `json:"-"`
	// TODO: find a better way to do this this should probably not be part of the config
	// Used by Resolve; may be nil.
	resolver       VariableResolver
	// NOTE(review): despite the name, SetConfigField reads and writes this as
	// the path of the config file itself, not a directory.
	dataConfigDir  string              `json:"-"`
	// Catalog SetProviderAPIKey consults for providers not yet in Providers.
	knownProviders []provider.Provider `json:"-"`
}
254
255func (c *Config) WorkingDir() string {
256	return c.workingDir
257}
258
259func (c *Config) EnabledProviders() []ProviderConfig {
260	enabled := make([]ProviderConfig, 0, len(c.Providers))
261	for _, p := range c.Providers {
262		if !p.Disable {
263			enabled = append(enabled, p)
264		}
265	}
266	return enabled
267}
268
269// IsConfigured  return true if at least one provider is configured
270func (c *Config) IsConfigured() bool {
271	return len(c.EnabledProviders()) > 0
272}
273
274func (c *Config) GetModel(provider, model string) *provider.Model {
275	if providerConfig, ok := c.Providers[provider]; ok {
276		for _, m := range providerConfig.Models {
277			if m.ID == model {
278				return &m
279			}
280		}
281	}
282	return nil
283}
284
285func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
286	model, ok := c.Models[modelType]
287	if !ok {
288		return nil
289	}
290	if providerConfig, ok := c.Providers[model.Provider]; ok {
291		return &providerConfig
292	}
293	return nil
294}
295
296func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model {
297	model, ok := c.Models[modelType]
298	if !ok {
299		return nil
300	}
301	return c.GetModel(model.Provider, model.Model)
302}
303
304func (c *Config) LargeModel() *provider.Model {
305	model, ok := c.Models[SelectedModelTypeLarge]
306	if !ok {
307		return nil
308	}
309	return c.GetModel(model.Provider, model.Model)
310}
311
312func (c *Config) SmallModel() *provider.Model {
313	model, ok := c.Models[SelectedModelTypeSmall]
314	if !ok {
315		return nil
316	}
317	return c.GetModel(model.Provider, model.Model)
318}
319
320func (c *Config) SetCompactMode(enabled bool) error {
321	if c.Options == nil {
322		c.Options = &Options{}
323	}
324	c.Options.TUI.CompactMode = enabled
325	return c.SetConfigField("options.tui.compact_mode", enabled)
326}
327
328func (c *Config) Resolve(key string) (string, error) {
329	if c.resolver == nil {
330		return "", fmt.Errorf("no variable resolver configured")
331	}
332	return c.resolver.ResolveValue(key)
333}
334
335func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
336	c.Models[modelType] = model
337	if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
338		return fmt.Errorf("failed to update preferred model: %w", err)
339	}
340	return nil
341}
342
343func (c *Config) SetConfigField(key string, value any) error {
344	// read the data
345	data, err := os.ReadFile(c.dataConfigDir)
346	if err != nil {
347		if os.IsNotExist(err) {
348			data = []byte("{}")
349		} else {
350			return fmt.Errorf("failed to read config file: %w", err)
351		}
352	}
353
354	newValue, err := sjson.Set(string(data), key, value)
355	if err != nil {
356		return fmt.Errorf("failed to set config field %s: %w", key, err)
357	}
358	if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o644); err != nil {
359		return fmt.Errorf("failed to write config file: %w", err)
360	}
361	return nil
362}
363
364func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
365	// First save to the config file
366	err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
367	if err != nil {
368		return fmt.Errorf("failed to save API key to config file: %w", err)
369	}
370
371	if c.Providers == nil {
372		c.Providers = make(map[string]ProviderConfig)
373	}
374
375	providerConfig, exists := c.Providers[providerID]
376	if exists {
377		providerConfig.APIKey = apiKey
378		c.Providers[providerID] = providerConfig
379		return nil
380	}
381
382	var foundProvider *provider.Provider
383	for _, p := range c.knownProviders {
384		if string(p.ID) == providerID {
385			foundProvider = &p
386			break
387		}
388	}
389
390	if foundProvider != nil {
391		// Create new provider config based on known provider
392		providerConfig = ProviderConfig{
393			ID:           providerID,
394			Name:         foundProvider.Name,
395			BaseURL:      foundProvider.APIEndpoint,
396			Type:         foundProvider.Type,
397			APIKey:       apiKey,
398			Disable:      false,
399			ExtraHeaders: make(map[string]string),
400			ExtraParams:  make(map[string]string),
401			Models:       foundProvider.Models,
402		}
403	} else {
404		return fmt.Errorf("provider with ID %s not found in known providers", providerID)
405	}
406	// Store the updated provider config
407	c.Providers[providerID] = providerConfig
408	return nil
409}
410
411func (c *Config) SetupAgents() {
412	agents := map[string]Agent{
413		"coder": {
414			ID:           "coder",
415			Name:         "Coder",
416			Description:  "An agent that helps with executing coding tasks.",
417			Model:        SelectedModelTypeLarge,
418			ContextPaths: c.Options.ContextPaths,
419			// All tools allowed
420		},
421		"task": {
422			ID:           "task",
423			Name:         "Task",
424			Description:  "An agent that helps with searching for context and finding implementation details.",
425			Model:        SelectedModelTypeLarge,
426			ContextPaths: c.Options.ContextPaths,
427			AllowedTools: []string{
428				"glob",
429				"grep",
430				"ls",
431				"sourcegraph",
432				"view",
433			},
434			// NO MCPs or LSPs by default
435			AllowedMCP: map[string][]string{},
436			AllowedLSP: []string{},
437		},
438	}
439	c.Agents = agents
440}
441
442func (c *Config) Resolver() VariableResolver {
443	return c.resolver
444}
445
446func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
447	testURL := ""
448	headers := make(map[string]string)
449	apiKey, _ := resolver.ResolveValue(c.APIKey)
450	switch c.Type {
451	case provider.TypeOpenAI:
452		baseURL, _ := resolver.ResolveValue(c.BaseURL)
453		if baseURL == "" {
454			baseURL = "https://api.openai.com/v1"
455		}
456		testURL = baseURL + "/models"
457		headers["Authorization"] = "Bearer " + apiKey
458	case provider.TypeAnthropic:
459		baseURL, _ := resolver.ResolveValue(c.BaseURL)
460		if baseURL == "" {
461			baseURL = "https://api.anthropic.com/v1"
462		}
463		testURL = baseURL + "/models"
464		headers["x-api-key"] = apiKey
465		headers["anthropic-version"] = "2023-06-01"
466	}
467	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
468	defer cancel()
469	client := &http.Client{}
470	req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
471	if err != nil {
472		return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
473	}
474	for k, v := range headers {
475		req.Header.Set(k, v)
476	}
477	for k, v := range c.ExtraHeaders {
478		req.Header.Set(k, v)
479	}
480	b, err := client.Do(req)
481	if err != nil {
482		return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
483	}
484	if b.StatusCode != http.StatusOK {
485		return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status)
486	}
487	_ = b.Body.Close()
488	return nil
489}