1package config
  2
import (
	"context"
	"fmt"
	"log/slog"
	"net/http"
	"os"
	"path/filepath"
	"slices"
	"strings"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/env"
	"github.com/tidwall/sjson"
)
 18
// Package-level defaults:
//   - appName is the application name.
//   - defaultDataDirectory is the default data directory name (Options.DataDirectory
//     is documented as relative to the cwd; presumably the same holds here — confirm).
//   - defaultLogLevel is the default log level name.
const (
	appName              = "crush"
	defaultDataDirectory = ".crush"
	defaultLogLevel      = "info"
)
 24
// defaultContextPaths lists the file and directory paths (entries ending in
// "/" are directories) that are looked up for context/instruction files.
// NOTE(review): presumably used as the default for Options.ContextPaths when
// the user does not set one — confirm in the config loader.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
 40
// SelectedModelType names a model slot in Config.Models. Only the "large" and
// "small" slots are defined.
type SelectedModelType string

const (
	// SelectedModelTypeLarge is the "large" model slot.
	SelectedModelTypeLarge SelectedModelType = "large"
	// SelectedModelTypeSmall is the "small" model slot.
	SelectedModelTypeSmall SelectedModelType = "small"
)
 47
// SelectedModel is the model chosen for one SelectedModelType slot in
// Config.Models, identified by provider and model id.
type SelectedModel struct {
	// The model id as used by the provider API.
	// Required.
	Model string `json:"model"`
	// The model provider, same as the key/id used in the providers config.
	// Required.
	Provider string `json:"provider"`

	// Only used by models that use the openai provider and need this set.
	ReasoningEffort string `json:"reasoning_effort,omitempty"`

	// Overrides the default model configuration.
	MaxTokens int64 `json:"max_tokens,omitempty"`

	// Used by anthropic models that can reason to indicate if the model should think.
	Think bool `json:"think,omitempty"`
}
 65
// ProviderConfig is the configuration of one model provider, stored in
// Config.Providers keyed by provider id.
type ProviderConfig struct {
	// The provider's id.
	ID string `json:"id,omitempty"`
	// The provider's name, used for display purposes.
	Name string `json:"name,omitempty"`
	// The provider's API endpoint.
	BaseURL string `json:"base_url,omitempty"`
	// The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
	Type catwalk.Type `json:"type,omitempty"`
	// The provider's API key. May contain variable references that are
	// expanded through a VariableResolver (see TestConnection).
	APIKey string `json:"api_key,omitempty"`
	// Marks the provider as disabled.
	Disable bool `json:"disable,omitempty"`

	// Extra headers to send with each request to the provider.
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
	// Extra fields for request bodies sent to the provider.
	// NOTE(review): presumably merged into each request body by the provider
	// client — confirm.
	ExtraBody map[string]any `json:"extra_body,omitempty"`

	// Used to pass extra parameters to the provider. Never read from or
	// written to the JSON config (tagged "-").
	ExtraParams map[string]string `json:"-"`

	// The provider models
	Models []catwalk.Model `json:"models,omitempty"`
}
 91
// MCPType is the transport used to talk to an MCP server.
type MCPType string

// Supported MCP transports: "stdio", "sse" and "http".
const (
	MCPStdio MCPType = "stdio"
	MCPSse   MCPType = "sse"
	MCPHttp  MCPType = "http"
)
 99
// MCPConfig configures a single MCP server. Env values and Headers values are
// expanded through a shell variable resolver by ResolvedEnv and
// ResolvedHeaders respectively.
// NOTE(review): Command/Args/Env appear to apply to stdio servers and
// URL/Headers to sse/http servers — confirm against the MCP client.
type MCPConfig struct {
	Command  string            `json:"command,omitempty" `
	Env      map[string]string `json:"env,omitempty"`
	Args     []string          `json:"args,omitempty"`
	Type     MCPType           `json:"type"`
	URL      string            `json:"url,omitempty"`
	Disabled bool              `json:"disabled,omitempty"`

	// TODO: maybe make it possible to get the value from the env
	// NOTE(review): ResolvedHeaders already resolves header values through the
	// shell variable resolver, so this TODO may be stale.
	Headers map[string]string `json:"headers,omitempty"`
}
111
// LSPConfig configures a single language server.
//
// Disabled turns the server off. It is serialized as "disabled"; it was
// previously mis-tagged as "enabled", which inverted the meaning of that key.
// Command is the executable to launch and Args its arguments. Options is an
// arbitrary JSON value passed through for the server.
type LSPConfig struct {
	Disabled bool     `json:"disabled,omitempty"`
	Command  string   `json:"command"`
	Args     []string `json:"args,omitempty"`
	Options  any      `json:"options,omitempty"`
}
118
// TUIOptions holds terminal UI settings, stored under "options.tui".
// CompactMode is updated at runtime by Config.SetCompactMode.
type TUIOptions struct {
	CompactMode bool `json:"compact_mode,omitempty"`
	// Here we can add themes later or any TUI related options
}
123
// Permissions controls which tool invocations require a permission prompt.
// SkipRequests is never read from or written to JSON (tagged "-").
type Permissions struct {
	AllowedTools []string `json:"allowed_tools,omitempty"` // Tools that don't require permission prompts
	SkipRequests bool     `json:"-"`                       // Automatically accept all permissions (YOLO mode)
}
128
// Options holds general settings stored under the "options" key.
// ContextPaths is copied into the built-in agents by Config.SetupAgents, and
// TUI is lazily allocated by Config.SetCompactMode only if it is nil.
// NOTE(review): Debug/DebugLSP presumably toggle debug logging, and
// DisableAutoSummarize an automatic summarization feature — confirm at the
// call sites.
type Options struct {
	ContextPaths         []string    `json:"context_paths,omitempty"`
	TUI                  *TUIOptions `json:"tui,omitempty"`
	Debug                bool        `json:"debug,omitempty"`
	DebugLSP             bool        `json:"debug_lsp,omitempty"`
	DisableAutoSummarize bool        `json:"disable_auto_summarize,omitempty"`
	DataDirectory        string      `json:"data_directory,omitempty"` // Relative to the cwd
}
137
// MCPs maps an MCP server name to its configuration.
type MCPs map[string]MCPConfig

// MCP pairs an MCP server name with its configuration; it is the element type
// returned by MCPs.Sorted.
type MCP struct {
	Name string    `json:"name"`
	MCP  MCPConfig `json:"mcp"`
}
144
145func (m MCPs) Sorted() []MCP {
146	sorted := make([]MCP, 0, len(m))
147	for k, v := range m {
148		sorted = append(sorted, MCP{
149			Name: k,
150			MCP:  v,
151		})
152	}
153	slices.SortFunc(sorted, func(a, b MCP) int {
154		return strings.Compare(a.Name, b.Name)
155	})
156	return sorted
157}
158
// LSPs maps a language server name to its configuration.
type LSPs map[string]LSPConfig

// LSP pairs a language server name with its configuration; it is the element
// type returned by LSPs.Sorted.
type LSP struct {
	Name string    `json:"name"`
	LSP  LSPConfig `json:"lsp"`
}
165
166func (l LSPs) Sorted() []LSP {
167	sorted := make([]LSP, 0, len(l))
168	for k, v := range l {
169		sorted = append(sorted, LSP{
170			Name: k,
171			LSP:  v,
172		})
173	}
174	slices.SortFunc(sorted, func(a, b LSP) int {
175		return strings.Compare(a.Name, b.Name)
176	})
177	return sorted
178}
179
180func (m MCPConfig) ResolvedEnv() []string {
181	resolver := NewShellVariableResolver(env.New())
182	for e, v := range m.Env {
183		var err error
184		m.Env[e], err = resolver.ResolveValue(v)
185		if err != nil {
186			slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
187			continue
188		}
189	}
190
191	env := make([]string, 0, len(m.Env))
192	for k, v := range m.Env {
193		env = append(env, fmt.Sprintf("%s=%s", k, v))
194	}
195	return env
196}
197
198func (m MCPConfig) ResolvedHeaders() map[string]string {
199	resolver := NewShellVariableResolver(env.New())
200	for e, v := range m.Headers {
201		var err error
202		m.Headers[e], err = resolver.ResolveValue(v)
203		if err != nil {
204			slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
205			continue
206		}
207	}
208	return m.Headers
209}
210
// Agent describes an agent: the model slot it runs on and the tools, MCP
// servers, LSPs and context paths it may use. Agents are not persisted
// (Config.Agents is tagged `json:"-"`); Config.SetupAgents builds them.
type Agent struct {
	ID          string `json:"id,omitempty"`
	Name        string `json:"name,omitempty"`
	Description string `json:"description,omitempty"`
	// Disabled marks the agent as disabled.
	Disabled bool `json:"disabled,omitempty"`

	// Model is the model slot (large or small) the agent runs on.
	Model SelectedModelType `json:"model"`

	// The available tools for the agent
	//  if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools,omitempty"`

	// this tells us which MCPs are available for this agent
	//  if this is empty all mcps are available
	//  the string array is the list of tools from the AllowedMCP the agent has available
	//  if the string array is nil, all tools from the AllowedMCP are available
	// NOTE(review): SetupAgents gives the "task" agent an empty, non-nil map
	// meaning "no MCPs", which suggests only a nil map means "all" — confirm
	// against the consumer.
	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`

	// The list of LSPs that this agent can use
	//  if this is nil, all LSPs are available
	AllowedLSP []string `json:"allowed_lsp,omitempty"`

	// Overrides the context paths for this agent
	ContextPaths []string `json:"context_paths,omitempty"`
}
237
// Config holds the configuration for crush.
type Config struct {
	// We currently only support large/small as values here.
	Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`

	// The providers that are configured
	Providers *csync.Map[string, ProviderConfig] `json:"providers,omitempty"`

	// MCP servers keyed by name.
	MCP MCPs `json:"mcp,omitempty"`

	// Language servers keyed by name.
	LSP LSPs `json:"lsp,omitempty"`

	// General options; may be nil (SetCompactMode allocates it on demand).
	Options *Options `json:"options,omitempty"`

	// Tool permission settings.
	Permissions *Permissions `json:"permissions,omitempty"`

	// Internal fields below are not serialized.
	workingDir string `json:"-"`
	// TODO: most likely remove this concept when I come back to it
	Agents map[string]Agent `json:"-"`
	// TODO: find a better way to do this; it should probably not be part of the config.
	// resolver expands variable references (used by Resolve). dataConfigDir
	// is, despite its name, the path of the JSON config file that
	// SetConfigField reads and rewrites. knownProviders is the provider
	// catalog consulted by SetProviderAPIKey.
	resolver       VariableResolver
	dataConfigDir  string             `json:"-"`
	knownProviders []catwalk.Provider `json:"-"`
}
263
264func (c *Config) WorkingDir() string {
265	return c.workingDir
266}
267
268func (c *Config) EnabledProviders() []ProviderConfig {
269	var enabled []ProviderConfig
270	for _, p := range c.Providers.Seq2() {
271		if !p.Disable {
272			enabled = append(enabled, p)
273		}
274	}
275	return enabled
276}
277
278// IsConfigured  return true if at least one provider is configured
279func (c *Config) IsConfigured() bool {
280	return len(c.EnabledProviders()) > 0
281}
282
283func (c *Config) GetModel(provider, model string) *catwalk.Model {
284	if providerConfig, ok := c.Providers.Get(provider); ok {
285		for _, m := range providerConfig.Models {
286			if m.ID == model {
287				return &m
288			}
289		}
290	}
291	return nil
292}
293
294func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
295	model, ok := c.Models[modelType]
296	if !ok {
297		return nil
298	}
299	if providerConfig, ok := c.Providers.Get(model.Provider); ok {
300		return &providerConfig
301	}
302	return nil
303}
304
305func (c *Config) GetModelByType(modelType SelectedModelType) *catwalk.Model {
306	model, ok := c.Models[modelType]
307	if !ok {
308		return nil
309	}
310	return c.GetModel(model.Provider, model.Model)
311}
312
313func (c *Config) LargeModel() *catwalk.Model {
314	model, ok := c.Models[SelectedModelTypeLarge]
315	if !ok {
316		return nil
317	}
318	return c.GetModel(model.Provider, model.Model)
319}
320
321func (c *Config) SmallModel() *catwalk.Model {
322	model, ok := c.Models[SelectedModelTypeSmall]
323	if !ok {
324		return nil
325	}
326	return c.GetModel(model.Provider, model.Model)
327}
328
329func (c *Config) SetCompactMode(enabled bool) error {
330	if c.Options == nil {
331		c.Options = &Options{}
332	}
333	c.Options.TUI.CompactMode = enabled
334	return c.SetConfigField("options.tui.compact_mode", enabled)
335}
336
337func (c *Config) Resolve(key string) (string, error) {
338	if c.resolver == nil {
339		return "", fmt.Errorf("no variable resolver configured")
340	}
341	return c.resolver.ResolveValue(key)
342}
343
344func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
345	c.Models[modelType] = model
346	if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
347		return fmt.Errorf("failed to update preferred model: %w", err)
348	}
349	return nil
350}
351
352func (c *Config) SetConfigField(key string, value any) error {
353	// read the data
354	data, err := os.ReadFile(c.dataConfigDir)
355	if err != nil {
356		if os.IsNotExist(err) {
357			data = []byte("{}")
358		} else {
359			return fmt.Errorf("failed to read config file: %w", err)
360		}
361	}
362
363	newValue, err := sjson.Set(string(data), key, value)
364	if err != nil {
365		return fmt.Errorf("failed to set config field %s: %w", key, err)
366	}
367	if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o644); err != nil {
368		return fmt.Errorf("failed to write config file: %w", err)
369	}
370	return nil
371}
372
373func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
374	// First save to the config file
375	err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
376	if err != nil {
377		return fmt.Errorf("failed to save API key to config file: %w", err)
378	}
379
380	providerConfig, exists := c.Providers.Get(providerID)
381	if exists {
382		providerConfig.APIKey = apiKey
383		c.Providers.Set(providerID, providerConfig)
384		return nil
385	}
386
387	var foundProvider *catwalk.Provider
388	for _, p := range c.knownProviders {
389		if string(p.ID) == providerID {
390			foundProvider = &p
391			break
392		}
393	}
394
395	if foundProvider != nil {
396		// Create new provider config based on known provider
397		providerConfig = ProviderConfig{
398			ID:           providerID,
399			Name:         foundProvider.Name,
400			BaseURL:      foundProvider.APIEndpoint,
401			Type:         foundProvider.Type,
402			APIKey:       apiKey,
403			Disable:      false,
404			ExtraHeaders: make(map[string]string),
405			ExtraParams:  make(map[string]string),
406			Models:       foundProvider.Models,
407		}
408	} else {
409		return fmt.Errorf("provider with ID %s not found in known providers", providerID)
410	}
411	// Store the updated provider config
412	c.Providers.Set(providerID, providerConfig)
413	return nil
414}
415
416func (c *Config) SetupAgents() {
417	agents := map[string]Agent{
418		"coder": {
419			ID:           "coder",
420			Name:         "Coder",
421			Description:  "An agent that helps with executing coding tasks.",
422			Model:        SelectedModelTypeLarge,
423			ContextPaths: c.Options.ContextPaths,
424			// All tools allowed
425		},
426		"task": {
427			ID:           "task",
428			Name:         "Task",
429			Description:  "An agent that helps with searching for context and finding implementation details.",
430			Model:        SelectedModelTypeLarge,
431			ContextPaths: c.Options.ContextPaths,
432			AllowedTools: []string{
433				"glob",
434				"grep",
435				"ls",
436				"sourcegraph",
437				"view",
438			},
439			// NO MCPs or LSPs by default
440			AllowedMCP: map[string][]string{},
441			AllowedLSP: []string{},
442		},
443	}
444	c.Agents = agents
445}
446
447func (c *Config) Resolver() VariableResolver {
448	return c.resolver
449}
450
451func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
452	testURL := ""
453	headers := make(map[string]string)
454	apiKey, _ := resolver.ResolveValue(c.APIKey)
455	switch c.Type {
456	case catwalk.TypeOpenAI:
457		baseURL, _ := resolver.ResolveValue(c.BaseURL)
458		if baseURL == "" {
459			baseURL = "https://api.openai.com/v1"
460		}
461		testURL = baseURL + "/models"
462		headers["Authorization"] = "Bearer " + apiKey
463	case catwalk.TypeAnthropic:
464		baseURL, _ := resolver.ResolveValue(c.BaseURL)
465		if baseURL == "" {
466			baseURL = "https://api.anthropic.com/v1"
467		}
468		testURL = baseURL + "/models"
469		headers["x-api-key"] = apiKey
470		headers["anthropic-version"] = "2023-06-01"
471	}
472	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
473	defer cancel()
474	client := &http.Client{}
475	req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
476	if err != nil {
477		return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
478	}
479	for k, v := range headers {
480		req.Header.Set(k, v)
481	}
482	for k, v := range c.ExtraHeaders {
483		req.Header.Set(k, v)
484	}
485	b, err := client.Do(req)
486	if err != nil {
487		return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
488	}
489	if b.StatusCode != http.StatusOK {
490		return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status)
491	}
492	_ = b.Body.Close()
493	return nil
494}