1package config
  2
import (
	"fmt"
	"os"
	"path/filepath"
	"slices"
	"strings"

	"github.com/charmbracelet/crush/internal/env"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/tidwall/sjson"
	"golang.org/x/exp/slog"
)
 14
 15const (
 16	appName              = "crush"
 17	defaultDataDirectory = ".crush"
 18	defaultLogLevel      = "info"
 19)
 20
 21var defaultContextPaths = []string{
 22	".github/copilot-instructions.md",
 23	".cursorrules",
 24	".cursor/rules/",
 25	"CLAUDE.md",
 26	"CLAUDE.local.md",
 27	"GEMINI.md",
 28	"gemini.md",
 29	"crush.md",
 30	"crush.local.md",
 31	"Crush.md",
 32	"Crush.local.md",
 33	"CRUSH.md",
 34	"CRUSH.local.md",
 35}
 36
 37type SelectedModelType string
 38
 39const (
 40	SelectedModelTypeLarge SelectedModelType = "large"
 41	SelectedModelTypeSmall SelectedModelType = "small"
 42)
 43
 44type SelectedModel struct {
 45	// The model id as used by the provider API.
 46	// Required.
 47	Model string `json:"model"`
 48	// The model provider, same as the key/id used in the providers config.
 49	// Required.
 50	Provider string `json:"provider"`
 51
 52	// Only used by models that use the openai provider and need this set.
 53	ReasoningEffort string `json:"reasoning_effort,omitempty"`
 54
 55	// Overrides the default model configuration.
 56	MaxTokens int64 `json:"max_tokens,omitempty"`
 57
 58	// Used by anthropic models that can reason to indicate if the model should think.
 59	Think bool `json:"think,omitempty"`
 60}
 61
 62type ProviderConfig struct {
 63	// The provider's id.
 64	ID string `json:"id,omitempty"`
 65	// The provider's name, used for display purposes.
 66	Name string `json:"name,omitempty"`
 67	// The provider's API endpoint.
 68	BaseURL string `json:"base_url,omitempty"`
 69	// The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
 70	Type provider.Type `json:"type,omitempty"`
 71	// The provider's API key.
 72	APIKey string `json:"api_key,omitempty"`
 73	// Marks the provider as disabled.
 74	Disable bool `json:"disable,omitempty"`
 75
 76	// Extra headers to send with each request to the provider.
 77	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
 78
 79	// Used to pass extra parameters to the provider.
 80	ExtraParams map[string]string `json:"-"`
 81
 82	// The provider models
 83	Models []provider.Model `json:"models,omitempty"`
 84}
 85
 86type MCPType string
 87
 88const (
 89	MCPStdio MCPType = "stdio"
 90	MCPSse   MCPType = "sse"
 91	MCPHttp  MCPType = "http"
 92)
 93
 94type MCPConfig struct {
 95	Command  string            `json:"command,omitempty" `
 96	Env      map[string]string `json:"env,omitempty"`
 97	Args     []string          `json:"args,omitempty"`
 98	Type     MCPType           `json:"type"`
 99	URL      string            `json:"url,omitempty"`
100	Disabled bool              `json:"disabled,omitempty"`
101
102	// TODO: maybe make it possible to get the value from the env
103	Headers map[string]string `json:"headers,omitempty"`
104}
105
// LSPConfig describes a language server: the command used to start it, its
// arguments and options, and whether it is disabled.
type LSPConfig struct {
	// Disabled turns this language server off. It used to be tagged
	// "enabled", so `"enabled": true` in a config file silently disabled the
	// server. The key now matches its meaning and MCPConfig's "disabled".
	Disabled bool `json:"disabled,omitempty"`
	// Command is the executable to run.
	Command string `json:"command"`
	// Args are passed to Command.
	Args []string `json:"args,omitempty"`
	// Options is an arbitrary value; how it is consumed is outside this file.
	Options any `json:"options,omitempty"`
}
112
// TUIOptions groups terminal-UI settings. Themes or other TUI options can be
// added here later.
type TUIOptions struct {
	// CompactMode enables the TUI's compact mode.
	CompactMode bool `json:"compact_mode,omitempty"`
}

// Options groups general application settings.
type Options struct {
	// ContextPaths lists the context files/directories to use.
	ContextPaths []string `json:"context_paths,omitempty"`
	// TUI holds terminal-UI settings and may be nil.
	TUI *TUIOptions `json:"tui,omitempty"`
	// Debug enables debug mode.
	Debug bool `json:"debug,omitempty"`
	// DebugLSP enables LSP debugging.
	DebugLSP bool `json:"debug_lsp,omitempty"`
	// DisableAutoSummarize turns automatic summarization off.
	DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty"`
	// DataDirectory is relative to the current working directory.
	DataDirectory string `json:"data_directory,omitempty"`
	// SkipPermissionsRequests automatically accepts every permission request
	// ("YOLO mode"). It is never read from or written to JSON.
	SkipPermissionsRequests bool `json:"-"`
}
127
128type MCPs map[string]MCPConfig
129
130type MCP struct {
131	Name string    `json:"name"`
132	MCP  MCPConfig `json:"mcp"`
133}
134
135func (m MCPs) Sorted() []MCP {
136	sorted := make([]MCP, 0, len(m))
137	for k, v := range m {
138		sorted = append(sorted, MCP{
139			Name: k,
140			MCP:  v,
141		})
142	}
143	slices.SortFunc(sorted, func(a, b MCP) int {
144		return strings.Compare(a.Name, b.Name)
145	})
146	return sorted
147}
148
149type LSPs map[string]LSPConfig
150
151type LSP struct {
152	Name string    `json:"name"`
153	LSP  LSPConfig `json:"lsp"`
154}
155
156func (l LSPs) Sorted() []LSP {
157	sorted := make([]LSP, 0, len(l))
158	for k, v := range l {
159		sorted = append(sorted, LSP{
160			Name: k,
161			LSP:  v,
162		})
163	}
164	slices.SortFunc(sorted, func(a, b LSP) int {
165		return strings.Compare(a.Name, b.Name)
166	})
167	return sorted
168}
169
170func (m MCPConfig) ResolvedEnv() []string {
171	resolver := NewShellVariableResolver(env.New())
172	for e, v := range m.Env {
173		var err error
174		m.Env[e], err = resolver.ResolveValue(v)
175		if err != nil {
176			slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
177			continue
178		}
179	}
180
181	env := make([]string, 0, len(m.Env))
182	for k, v := range m.Env {
183		env = append(env, fmt.Sprintf("%s=%s", k, v))
184	}
185	return env
186}
187
188func (m MCPConfig) ResolvedHeaders() map[string]string {
189	resolver := NewShellVariableResolver(env.New())
190	for e, v := range m.Headers {
191		var err error
192		m.Headers[e], err = resolver.ResolveValue(v)
193		if err != nil {
194			slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
195			continue
196		}
197	}
198	return m.Headers
199}
200
201type Agent struct {
202	ID          string `json:"id,omitempty"`
203	Name        string `json:"name,omitempty"`
204	Description string `json:"description,omitempty"`
205	// This is the id of the system prompt used by the agent
206	Disabled bool `json:"disabled,omitempty"`
207
208	Model SelectedModelType `json:"model"`
209
210	// The available tools for the agent
211	//  if this is nil, all tools are available
212	AllowedTools []string `json:"allowed_tools,omitempty"`
213
214	// this tells us which MCPs are available for this agent
215	//  if this is empty all mcps are available
216	//  the string array is the list of tools from the AllowedMCP the agent has available
217	//  if the string array is nil, all tools from the AllowedMCP are available
218	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`
219
220	// The list of LSPs that this agent can use
221	//  if this is nil, all LSPs are available
222	AllowedLSP []string `json:"allowed_lsp,omitempty"`
223
224	// Overrides the context paths for this agent
225	ContextPaths []string `json:"context_paths,omitempty"`
226}
227
228// Config holds the configuration for crush.
229type Config struct {
230	// We currently only support large/small as values here.
231	Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`
232
233	// The providers that are configured
234	Providers map[string]ProviderConfig `json:"providers,omitempty"`
235
236	MCP MCPs `json:"mcp,omitempty"`
237
238	LSP LSPs `json:"lsp,omitempty"`
239
240	Options *Options `json:"options,omitempty"`
241
242	// Internal
243	workingDir string `json:"-"`
244	// TODO: most likely remove this concept when I come back to it
245	Agents map[string]Agent `json:"-"`
246	// TODO: find a better way to do this this should probably not be part of the config
247	resolver       VariableResolver
248	dataConfigDir  string              `json:"-"`
249	knownProviders []provider.Provider `json:"-"`
250}
251
252func (c *Config) WorkingDir() string {
253	return c.workingDir
254}
255
256func (c *Config) EnabledProviders() []ProviderConfig {
257	enabled := make([]ProviderConfig, 0, len(c.Providers))
258	for _, p := range c.Providers {
259		if !p.Disable {
260			enabled = append(enabled, p)
261		}
262	}
263	return enabled
264}
265
266// IsConfigured  return true if at least one provider is configured
267func (c *Config) IsConfigured() bool {
268	return len(c.EnabledProviders()) > 0
269}
270
271func (c *Config) GetModel(provider, model string) *provider.Model {
272	if providerConfig, ok := c.Providers[provider]; ok {
273		for _, m := range providerConfig.Models {
274			if m.ID == model {
275				return &m
276			}
277		}
278	}
279	return nil
280}
281
282func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
283	model, ok := c.Models[modelType]
284	if !ok {
285		return nil
286	}
287	if providerConfig, ok := c.Providers[model.Provider]; ok {
288		return &providerConfig
289	}
290	return nil
291}
292
293func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model {
294	model, ok := c.Models[modelType]
295	if !ok {
296		return nil
297	}
298	return c.GetModel(model.Provider, model.Model)
299}
300
301func (c *Config) LargeModel() *provider.Model {
302	model, ok := c.Models[SelectedModelTypeLarge]
303	if !ok {
304		return nil
305	}
306	return c.GetModel(model.Provider, model.Model)
307}
308
309func (c *Config) SmallModel() *provider.Model {
310	model, ok := c.Models[SelectedModelTypeSmall]
311	if !ok {
312		return nil
313	}
314	return c.GetModel(model.Provider, model.Model)
315}
316
317func (c *Config) SetCompactMode(enabled bool) error {
318	if c.Options == nil {
319		c.Options = &Options{}
320	}
321	c.Options.TUI.CompactMode = enabled
322	return c.SetConfigField("options.tui.compact_mode", enabled)
323}
324
// Resolve expands key using the Config's variable resolver. It returns an
// error when no resolver has been set.
func (c *Config) Resolve(key string) (string, error) {
	if r := c.resolver; r != nil {
		return r.ResolveValue(key)
	}
	return "", fmt.Errorf("no variable resolver configured")
}
331
// UpdatePreferredModel selects model for the modelType slot. It updates
// c.Models and persists the choice under "models.<modelType>".
//
// The in-memory value is updated even if persisting fails.
func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
	// Writing to a nil map panics; Models is nil unless something populated it.
	if c.Models == nil {
		c.Models = make(map[SelectedModelType]SelectedModel)
	}
	c.Models[modelType] = model
	if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
		return fmt.Errorf("failed to update preferred model: %w", err)
	}
	return nil
}
339
// SetConfigField sets key, an sjson path such as "options.tui.compact_mode",
// to value in the config file at c.dataConfigDir, then writes the file back.
// Only the file changes; the in-memory Config is left alone.
//
// A missing file is treated as an empty JSON object, and its parent
// directory is created when needed.
func (c *Config) SetConfigField(key string, value any) error {
	data, err := os.ReadFile(c.dataConfigDir)
	if err != nil {
		if !os.IsNotExist(err) {
			return fmt.Errorf("failed to read config file: %w", err)
		}
		data = []byte("{}")
	}

	newValue, err := sjson.Set(string(data), key, value)
	if err != nil {
		return fmt.Errorf("failed to set config field %s: %w", key, err)
	}

	// os.WriteFile does not create missing parent directories, so writing to
	// a path whose directory does not exist yet would otherwise fail.
	if err := os.MkdirAll(filepath.Dir(c.dataConfigDir), 0o755); err != nil {
		return fmt.Errorf("failed to create config directory: %w", err)
	}
	// The file can hold provider API keys (see SetProviderAPIKey), so keep it
	// private to the user. The mode only applies when the file is created.
	if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o600); err != nil {
		return fmt.Errorf("failed to write config file: %w", err)
	}
	return nil
}
360
361func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
362	// First save to the config file
363	err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
364	if err != nil {
365		return fmt.Errorf("failed to save API key to config file: %w", err)
366	}
367
368	if c.Providers == nil {
369		c.Providers = make(map[string]ProviderConfig)
370	}
371
372	providerConfig, exists := c.Providers[providerID]
373	if exists {
374		providerConfig.APIKey = apiKey
375		c.Providers[providerID] = providerConfig
376		return nil
377	}
378
379	var foundProvider *provider.Provider
380	for _, p := range c.knownProviders {
381		if string(p.ID) == providerID {
382			foundProvider = &p
383			break
384		}
385	}
386
387	if foundProvider != nil {
388		// Create new provider config based on known provider
389		providerConfig = ProviderConfig{
390			ID:           providerID,
391			Name:         foundProvider.Name,
392			BaseURL:      foundProvider.APIEndpoint,
393			Type:         foundProvider.Type,
394			APIKey:       apiKey,
395			Disable:      false,
396			ExtraHeaders: make(map[string]string),
397			ExtraParams:  make(map[string]string),
398			Models:       foundProvider.Models,
399		}
400	} else {
401		return fmt.Errorf("provider with ID %s not found in known providers", providerID)
402	}
403	// Store the updated provider config
404	c.Providers[providerID] = providerConfig
405	return nil
406}
407
408func (c *Config) SetupAgents() {
409	agents := map[string]Agent{
410		"coder": {
411			ID:           "coder",
412			Name:         "Coder",
413			Description:  "An agent that helps with executing coding tasks.",
414			Model:        SelectedModelTypeLarge,
415			ContextPaths: c.Options.ContextPaths,
416			// All tools allowed
417		},
418		"task": {
419			ID:           "task",
420			Name:         "Task",
421			Description:  "An agent that helps with searching for context and finding implementation details.",
422			Model:        SelectedModelTypeLarge,
423			ContextPaths: c.Options.ContextPaths,
424			AllowedTools: []string{
425				"glob",
426				"grep",
427				"ls",
428				"sourcegraph",
429				"view",
430			},
431			// NO MCPs or LSPs by default
432			AllowedMCP: map[string][]string{},
433			AllowedLSP: []string{},
434		},
435	}
436	c.Agents = agents
437}