1package config
  2
import (
	"context"
	"fmt"
	"log/slog"
	"net/http"
	"os"
	"path/filepath"
	"slices"
	"strings"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/env"
	"github.com/tidwall/sjson"
)
 18
// Application-wide names and defaults: appName is the application name,
// defaultDataDirectory is the default data directory name (compare
// Options.DataDirectory, which is documented as relative to the cwd), and
// defaultLogLevel is the default log level.
// NOTE(review): none of these are referenced in this chunk; presumably the
// config loader uses them — confirm.
const (
	appName              = "crush"
	defaultDataDirectory = ".crush"
	defaultLogLevel      = "info"
)
 24
// defaultContextPaths lists candidate project-instruction files, covering
// several tools' naming conventions and case variants. The trailing slash on
// ".cursor/rules/" presumably marks a directory entry.
// NOTE(review): not referenced in this chunk; presumably this is the default
// for Options.ContextPaths — confirm in the loader.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
 40
// SelectedModelType names a model slot. It is the key type of Config.Models,
// which currently supports only the "large" and "small" slots below. Both
// built-in agents created by SetupAgents use the large slot.
type SelectedModelType string

// The supported model slots.
const (
	SelectedModelTypeLarge SelectedModelType = "large"
	SelectedModelTypeSmall SelectedModelType = "small"
)
 47
// SelectedModel records which model fills one SelectedModelType slot in
// Config.Models. UpdatePreferredModel persists it to the config file under
// "models.<type>".
type SelectedModel struct {
	// The model id as used by the provider API; GetModel matches it against
	// catwalk.Model.ID in the provider's model list.
	// Required.
	Model string `json:"model"`
	// The model provider, same as the key/id used in the providers config.
	// Required.
	Provider string `json:"provider"`

	// Only used by models that use the openai provider and need this set.
	ReasoningEffort string `json:"reasoning_effort,omitempty"`

	// Overrides the default model configuration.
	// NOTE(review): presumably 0 (omitted) means "no override" — confirm.
	MaxTokens int64 `json:"max_tokens,omitempty"`

	// Used by anthropic models that can reason to indicate if the model should think.
	Think bool `json:"think,omitempty"`
}
 65
// ProviderConfig is the configuration of one model provider. It is stored in
// Config.Providers keyed by its ID. TestConnection passes APIKey and BaseURL
// through a VariableResolver before using them, so those fields may hold
// variable references.
type ProviderConfig struct {
	// The provider's id.
	ID string `json:"id,omitempty"`
	// The provider's name, used for display purposes.
	Name string `json:"name,omitempty"`
	// The provider's API endpoint.
	BaseURL string `json:"base_url,omitempty"`
	// The provider type, e.g. "openai" or "anthropic". If empty, it defaults
	// to openai.
	Type catwalk.Type `json:"type,omitempty"`
	// The provider's API key.
	APIKey string `json:"api_key,omitempty"`
	// Marks the provider as disabled; Config.EnabledProviders skips disabled
	// providers.
	Disable bool `json:"disable,omitempty"`

	// Custom system prompt prefix.
	SystemPromptPrefix string `json:"system_prompt_prefix,omitempty"`

	// Extra headers to send with each request to the provider. TestConnection
	// applies them after its own authentication headers.
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
	// Extra fields for request bodies.
	// NOTE(review): presumably merged into each request body — confirm.
	ExtraBody map[string]any `json:"extra_body,omitempty"`

	// Used to pass extra parameters to the provider. Tagged json:"-", so it is
	// never read from or written to the config file.
	ExtraParams map[string]string `json:"-"`

	// The provider models; GetModel searches this list by model ID.
	Models []catwalk.Model `json:"models,omitempty"`
}
 94
// MCPType is the value of MCPConfig.Type.
// NOTE(review): judging by the values, these select a transport. stdio
// presumably launches MCPConfig.Command; sse/http presumably connect to
// MCPConfig.URL. Confirm against the MCP client.
type MCPType string

const (
	MCPStdio MCPType = "stdio"
	MCPSse   MCPType = "sse"
	MCPHttp  MCPType = "http"
)
102
103type MCPConfig struct {
104	Command  string            `json:"command,omitempty" `
105	Env      map[string]string `json:"env,omitempty"`
106	Args     []string          `json:"args,omitempty"`
107	Type     MCPType           `json:"type"`
108	URL      string            `json:"url,omitempty"`
109	Disabled bool              `json:"disabled,omitempty"`
110
111	// TODO: maybe make it possible to get the value from the env
112	Headers map[string]string `json:"headers,omitempty"`
113}
114
// LSPConfig describes one language server entry in the "lsp" section of the
// config.
type LSPConfig struct {
	// Disabled turns this language server off. It is read from the
	// "disabled" key, consistent with MCPConfig.Disabled; mapping it to an
	// "enabled" key would invert that key's meaning.
	Disabled bool     `json:"disabled,omitempty"`
	Command  string   `json:"command"`
	Args     []string `json:"args,omitempty"`
	// Options is kept as arbitrary JSON.
	// NOTE(review): presumably forwarded to the server as initialization
	// options — confirm.
	Options any `json:"options,omitempty"`
}
121
// TUIOptions holds terminal UI settings, stored under "options.tui" in the
// config file. CompactMode is toggled and persisted by
// Config.SetCompactMode.
type TUIOptions struct {
	CompactMode bool `json:"compact_mode,omitempty"`
	// Here we can add themes later or any TUI related options
}
126
// Permissions controls tool permission prompts. SkipRequests is tagged
// json:"-", so it is never read from or written to the config file.
type Permissions struct {
	AllowedTools []string `json:"allowed_tools,omitempty"` // Tools that don't require permission prompts
	SkipRequests bool     `json:"-"`                       // Automatically accept all permissions (YOLO mode)
}
131
// Options holds general settings stored under "options" in the config file.
// SetupAgents copies ContextPaths into both built-in agents. TUI is a pointer
// and may be nil.
type Options struct {
	ContextPaths         []string    `json:"context_paths,omitempty"`
	TUI                  *TUIOptions `json:"tui,omitempty"`
	Debug                bool        `json:"debug,omitempty"`
	DebugLSP             bool        `json:"debug_lsp,omitempty"`
	DisableAutoSummarize bool        `json:"disable_auto_summarize,omitempty"`
	DataDirectory        string      `json:"data_directory,omitempty"` // Relative to the cwd
}
140
// MCPs maps MCP server names to their configuration (the "mcp" section of
// Config).
type MCPs map[string]MCPConfig

// MCP pairs an MCP server name with its configuration; MCPs.Sorted returns
// these in name order.
type MCP struct {
	Name string    `json:"name"`
	MCP  MCPConfig `json:"mcp"`
}
147
148func (m MCPs) Sorted() []MCP {
149	sorted := make([]MCP, 0, len(m))
150	for k, v := range m {
151		sorted = append(sorted, MCP{
152			Name: k,
153			MCP:  v,
154		})
155	}
156	slices.SortFunc(sorted, func(a, b MCP) int {
157		return strings.Compare(a.Name, b.Name)
158	})
159	return sorted
160}
161
// LSPs maps language server names to their configuration (the "lsp" section
// of Config).
type LSPs map[string]LSPConfig

// LSP pairs a language server name with its configuration; LSPs.Sorted
// returns these in name order.
type LSP struct {
	Name string    `json:"name"`
	LSP  LSPConfig `json:"lsp"`
}
168
169func (l LSPs) Sorted() []LSP {
170	sorted := make([]LSP, 0, len(l))
171	for k, v := range l {
172		sorted = append(sorted, LSP{
173			Name: k,
174			LSP:  v,
175		})
176	}
177	slices.SortFunc(sorted, func(a, b LSP) int {
178		return strings.Compare(a.Name, b.Name)
179	})
180	return sorted
181}
182
183func (m MCPConfig) ResolvedEnv() []string {
184	resolver := NewShellVariableResolver(env.New())
185	for e, v := range m.Env {
186		var err error
187		m.Env[e], err = resolver.ResolveValue(v)
188		if err != nil {
189			slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
190			continue
191		}
192	}
193
194	env := make([]string, 0, len(m.Env))
195	for k, v := range m.Env {
196		env = append(env, fmt.Sprintf("%s=%s", k, v))
197	}
198	return env
199}
200
201func (m MCPConfig) ResolvedHeaders() map[string]string {
202	resolver := NewShellVariableResolver(env.New())
203	for e, v := range m.Headers {
204		var err error
205		m.Headers[e], err = resolver.ResolveValue(v)
206		if err != nil {
207			slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
208			continue
209		}
210	}
211	return m.Headers
212}
213
// Agent describes a built-in agent: the model slot it uses and the tools,
// MCP servers and LSP servers it may access. Agents are not read from the
// config file (Config.Agents is tagged json:"-"); SetupAgents creates them.
type Agent struct {
	ID          string `json:"id,omitempty"`
	Name        string `json:"name,omitempty"`
	Description string `json:"description,omitempty"`
	// Marks the agent as disabled.
	Disabled bool `json:"disabled,omitempty"`

	// The model slot (large or small) this agent runs on.
	Model SelectedModelType `json:"model"`

	// The available tools for the agent
	//  if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools,omitempty"`

	// this tells us which MCPs are available for this agent
	//  if this is empty all mcps are available
	//  the string array is the list of tools from the AllowedMCP the agent has available
	//  if the string array is nil, all tools from the AllowedMCP are available
	// NOTE(review): SetupAgents gives the "task" agent an empty, non-nil map
	// with the stated intent "NO MCPs", which implies nil (not merely empty)
	// means "all MCPs", as with AllowedLSP — confirm against the consumer.
	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`

	// The list of LSPs that this agent can use
	//  if this is nil, all LSPs are available
	AllowedLSP []string `json:"allowed_lsp,omitempty"`

	// Overrides the context paths for this agent
	ContextPaths []string `json:"context_paths,omitempty"`
}
240
// Config holds the configuration for crush.
//
// The exported JSON-tagged fields make up the persisted configuration.
// workingDir, Agents, dataConfigDir and knownProviders are tagged json:"-",
// and resolver is unexported, so none of them are part of the JSON form.
type Config struct {
	// We currently only support large/small as values here.
	Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`

	// The providers that are configured
	Providers *csync.Map[string, ProviderConfig] `json:"providers,omitempty"`

	// MCP servers keyed by name.
	MCP MCPs `json:"mcp,omitempty"`

	// Language servers keyed by name.
	LSP LSPs `json:"lsp,omitempty"`

	// General options; may be nil.
	Options *Options `json:"options,omitempty"`

	// Tool permission settings; may be nil.
	Permissions *Permissions `json:"permissions,omitempty"`

	// Internal
	workingDir string `json:"-"`
	// TODO: most likely remove this concept when I come back to it
	Agents map[string]Agent `json:"-"`
	// TODO: find a better way to do this this should probably not be part of the config
	//
	// resolver backs Resolve and Resolver. dataConfigDir, despite its name,
	// is used as the path of the JSON file that SetConfigField reads and
	// rewrites. knownProviders is consulted by SetProviderAPIKey for
	// providers that are not configured yet.
	resolver       VariableResolver
	dataConfigDir  string             `json:"-"`
	knownProviders []catwalk.Provider `json:"-"`
}
266
267func (c *Config) WorkingDir() string {
268	return c.workingDir
269}
270
271func (c *Config) EnabledProviders() []ProviderConfig {
272	var enabled []ProviderConfig
273	for _, p := range c.Providers.Seq2() {
274		if !p.Disable {
275			enabled = append(enabled, p)
276		}
277	}
278	return enabled
279}
280
281// IsConfigured  return true if at least one provider is configured
282func (c *Config) IsConfigured() bool {
283	return len(c.EnabledProviders()) > 0
284}
285
286func (c *Config) GetModel(provider, model string) *catwalk.Model {
287	if providerConfig, ok := c.Providers.Get(provider); ok {
288		for _, m := range providerConfig.Models {
289			if m.ID == model {
290				return &m
291			}
292		}
293	}
294	return nil
295}
296
297func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
298	model, ok := c.Models[modelType]
299	if !ok {
300		return nil
301	}
302	if providerConfig, ok := c.Providers.Get(model.Provider); ok {
303		return &providerConfig
304	}
305	return nil
306}
307
308func (c *Config) GetModelByType(modelType SelectedModelType) *catwalk.Model {
309	model, ok := c.Models[modelType]
310	if !ok {
311		return nil
312	}
313	return c.GetModel(model.Provider, model.Model)
314}
315
316func (c *Config) LargeModel() *catwalk.Model {
317	model, ok := c.Models[SelectedModelTypeLarge]
318	if !ok {
319		return nil
320	}
321	return c.GetModel(model.Provider, model.Model)
322}
323
324func (c *Config) SmallModel() *catwalk.Model {
325	model, ok := c.Models[SelectedModelTypeSmall]
326	if !ok {
327		return nil
328	}
329	return c.GetModel(model.Provider, model.Model)
330}
331
332func (c *Config) SetCompactMode(enabled bool) error {
333	if c.Options == nil {
334		c.Options = &Options{}
335	}
336	c.Options.TUI.CompactMode = enabled
337	return c.SetConfigField("options.tui.compact_mode", enabled)
338}
339
340func (c *Config) Resolve(key string) (string, error) {
341	if c.resolver == nil {
342		return "", fmt.Errorf("no variable resolver configured")
343	}
344	return c.resolver.ResolveValue(key)
345}
346
347func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
348	c.Models[modelType] = model
349	if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
350		return fmt.Errorf("failed to update preferred model: %w", err)
351	}
352	return nil
353}
354
355func (c *Config) SetConfigField(key string, value any) error {
356	// read the data
357	data, err := os.ReadFile(c.dataConfigDir)
358	if err != nil {
359		if os.IsNotExist(err) {
360			data = []byte("{}")
361		} else {
362			return fmt.Errorf("failed to read config file: %w", err)
363		}
364	}
365
366	newValue, err := sjson.Set(string(data), key, value)
367	if err != nil {
368		return fmt.Errorf("failed to set config field %s: %w", key, err)
369	}
370	if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o644); err != nil {
371		return fmt.Errorf("failed to write config file: %w", err)
372	}
373	return nil
374}
375
376func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
377	// First save to the config file
378	err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
379	if err != nil {
380		return fmt.Errorf("failed to save API key to config file: %w", err)
381	}
382
383	providerConfig, exists := c.Providers.Get(providerID)
384	if exists {
385		providerConfig.APIKey = apiKey
386		c.Providers.Set(providerID, providerConfig)
387		return nil
388	}
389
390	var foundProvider *catwalk.Provider
391	for _, p := range c.knownProviders {
392		if string(p.ID) == providerID {
393			foundProvider = &p
394			break
395		}
396	}
397
398	if foundProvider != nil {
399		// Create new provider config based on known provider
400		providerConfig = ProviderConfig{
401			ID:           providerID,
402			Name:         foundProvider.Name,
403			BaseURL:      foundProvider.APIEndpoint,
404			Type:         foundProvider.Type,
405			APIKey:       apiKey,
406			Disable:      false,
407			ExtraHeaders: make(map[string]string),
408			ExtraParams:  make(map[string]string),
409			Models:       foundProvider.Models,
410		}
411	} else {
412		return fmt.Errorf("provider with ID %s not found in known providers", providerID)
413	}
414	// Store the updated provider config
415	c.Providers.Set(providerID, providerConfig)
416	return nil
417}
418
419func (c *Config) SetupAgents() {
420	agents := map[string]Agent{
421		"coder": {
422			ID:           "coder",
423			Name:         "Coder",
424			Description:  "An agent that helps with executing coding tasks.",
425			Model:        SelectedModelTypeLarge,
426			ContextPaths: c.Options.ContextPaths,
427			// All tools allowed
428		},
429		"task": {
430			ID:           "task",
431			Name:         "Task",
432			Description:  "An agent that helps with searching for context and finding implementation details.",
433			Model:        SelectedModelTypeLarge,
434			ContextPaths: c.Options.ContextPaths,
435			AllowedTools: []string{
436				"glob",
437				"grep",
438				"ls",
439				"sourcegraph",
440				"view",
441			},
442			// NO MCPs or LSPs by default
443			AllowedMCP: map[string][]string{},
444			AllowedLSP: []string{},
445		},
446	}
447	c.Agents = agents
448}
449
450func (c *Config) Resolver() VariableResolver {
451	return c.resolver
452}
453
454func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
455	testURL := ""
456	headers := make(map[string]string)
457	apiKey, _ := resolver.ResolveValue(c.APIKey)
458	switch c.Type {
459	case catwalk.TypeOpenAI:
460		baseURL, _ := resolver.ResolveValue(c.BaseURL)
461		if baseURL == "" {
462			baseURL = "https://api.openai.com/v1"
463		}
464		testURL = baseURL + "/models"
465		headers["Authorization"] = "Bearer " + apiKey
466	case catwalk.TypeAnthropic:
467		baseURL, _ := resolver.ResolveValue(c.BaseURL)
468		if baseURL == "" {
469			baseURL = "https://api.anthropic.com/v1"
470		}
471		testURL = baseURL + "/models"
472		headers["x-api-key"] = apiKey
473		headers["anthropic-version"] = "2023-06-01"
474	}
475	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
476	defer cancel()
477	client := &http.Client{}
478	req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
479	if err != nil {
480		return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
481	}
482	for k, v := range headers {
483		req.Header.Set(k, v)
484	}
485	for k, v := range c.ExtraHeaders {
486		req.Header.Set(k, v)
487	}
488	b, err := client.Do(req)
489	if err != nil {
490		return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
491	}
492	if b.StatusCode != http.StatusOK {
493		return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status)
494	}
495	_ = b.Body.Close()
496	return nil
497}