1package config
  2
  3import (
  4	"context"
  5	"fmt"
  6	"net/http"
  7	"os"
  8	"slices"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/csync"
 13	"github.com/charmbracelet/crush/internal/env"
 14	"github.com/charmbracelet/crush/internal/fur/provider"
 15	"github.com/tidwall/sjson"
 16	"golang.org/x/exp/slog"
 17)
 18
// Package-level defaults.
// NOTE(review): none of these constants are referenced in the visible part of
// this file; their consumers are presumably in the config loader — confirm.
const (
	// appName is the application's name.
	appName              = "crush"
	// defaultDataDirectory is presumably the fallback for Options.DataDirectory.
	defaultDataDirectory = ".crush"
	// defaultLogLevel is the name of the default log level.
	defaultLogLevel      = "info"
)
 24
// defaultContextPaths lists instruction/rules files (and one directory) used
// by several coding assistants, plus the crush.md variants in different
// casings.
// NOTE(review): presumably the default for Options.ContextPaths — confirm
// against the loader; it is not referenced in the visible part of this file.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
 40
// SelectedModelType names a model slot in Config.Models.
type SelectedModelType string

// The model slots currently defined.
const (
	// SelectedModelTypeLarge is the key of the "large" model slot.
	SelectedModelTypeLarge SelectedModelType = "large"
	// SelectedModelTypeSmall is the key of the "small" model slot.
	SelectedModelTypeSmall SelectedModelType = "small"
)
 47
// SelectedModel identifies the model chosen for a slot, together with
// optional per-selection overrides.
type SelectedModel struct {
	// The model id as used by the provider API.
	// Required.
	Model string `json:"model"`
	// The model provider, same as the key/id used in the providers config.
	// Required.
	Provider string `json:"provider"`

	// Only used by models that use the openai provider and need this set.
	// Omitted from JSON when empty.
	ReasoningEffort string `json:"reasoning_effort,omitempty"`

	// Overrides the default model configuration.
	// Omitted from JSON when zero.
	MaxTokens int64 `json:"max_tokens,omitempty"`

	// Used by anthropic models that can reason to indicate if the model should think.
	// Omitted from JSON when false.
	Think bool `json:"think,omitempty"`
}
 65
// ProviderConfig describes a single model provider: how to reach it, how to
// authenticate, and which models it offers.
type ProviderConfig struct {
	// The provider's id.
	ID string `json:"id,omitempty"`
	// The provider's name, used for display purposes.
	Name string `json:"name,omitempty"`
	// The provider's API endpoint.
	BaseURL string `json:"base_url,omitempty"`
	// The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
	Type provider.Type `json:"type,omitempty"`
	// The provider's API key. TestConnection passes it through a
	// VariableResolver before use, so it may hold a variable reference
	// rather than a literal key.
	APIKey string `json:"api_key,omitempty"`
	// Marks the provider as disabled.
	Disable bool `json:"disable,omitempty"`

	// Extra headers to send with each request to the provider.
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
	// Extra body fields.
	// NOTE(review): presumably merged into request bodies sent to the
	// provider — confirm against the provider client.
	ExtraBody map[string]any `json:"extra_body,omitempty"`

	// Used to pass extra parameters to the provider.
	// Never serialized to JSON (tag "-").
	ExtraParams map[string]string `json:"-"`

	// The provider models
	Models []provider.Model `json:"models,omitempty"`
}
 91
// MCPType is the value of an MCPConfig's "type" field.
type MCPType string

// The MCP types defined here.
// NOTE(review): the names suggest transports (stdio process, server-sent
// events, HTTP); how each is handled is not visible in this file.
const (
	MCPStdio MCPType = "stdio"
	MCPSse   MCPType = "sse"
	MCPHttp  MCPType = "http"
)
 99
// MCPConfig is the configuration of a single MCP server.
// NOTE(review): presumably Command/Args/Env apply to the stdio type and
// URL/Headers to the sse/http types — confirm against the MCP client.
type MCPConfig struct {
	Command  string            `json:"command,omitempty" `
	// Env values are passed through a shell variable resolver by ResolvedEnv.
	Env      map[string]string `json:"env,omitempty"`
	Args     []string          `json:"args,omitempty"`
	// Type is always serialized (no omitempty).
	Type     MCPType           `json:"type"`
	URL      string            `json:"url,omitempty"`
	Disabled bool              `json:"disabled,omitempty"`

	// Headers values are passed through a shell variable resolver by
	// ResolvedHeaders.
	// TODO: maybe make it possible to get the value from the env
	Headers map[string]string `json:"headers,omitempty"`
}
111
// LSPConfig is the configuration of a single language server.
type LSPConfig struct {
	// Disabled marks the language server as disabled. It is serialized under
	// the "disabled" key so the JSON name agrees with the field's meaning
	// (and with MCPConfig.Disabled); the previous "enabled" key inverted it.
	Disabled bool     `json:"disabled,omitempty"`
	// Command is the language server command. Always serialized.
	Command  string   `json:"command"`
	// Args are the arguments for Command.
	Args     []string `json:"args,omitempty"`
	// Options is free-form, server-specific data.
	// NOTE(review): presumably forwarded to the server as initialization
	// options — confirm against the LSP client.
	Options  any      `json:"options,omitempty"`
}
118
// TUIOptions holds the terminal UI related options.
type TUIOptions struct {
	// CompactMode is the compact-mode flag; it is set and persisted by
	// Config.SetCompactMode.
	CompactMode bool `json:"compact_mode,omitempty"`
	// Here we can add themes later or any TUI related options
}
123
// Options holds the general, non-provider settings.
// SetupAgents copies ContextPaths into each built-in agent; the meaning of
// the debug and summarize flags is not visible in this file.
type Options struct {
	ContextPaths            []string    `json:"context_paths,omitempty"`
	TUI                     *TUIOptions `json:"tui,omitempty"`
	Debug                   bool        `json:"debug,omitempty"`
	DebugLSP                bool        `json:"debug_lsp,omitempty"`
	DisableAutoSummarize    bool        `json:"disable_auto_summarize,omitempty"`
	DataDirectory           string      `json:"data_directory,omitempty"` // Relative to the cwd
	SkipPermissionsRequests bool        `json:"-"`                        // Automatically accept all permissions (YOLO mode); never serialized
}
133
// MCPs maps an MCP server name to its configuration.
type MCPs map[string]MCPConfig

// MCP pairs an MCP server name with its configuration; it is the element
// type returned by MCPs.Sorted.
type MCP struct {
	Name string    `json:"name"`
	MCP  MCPConfig `json:"mcp"`
}
140
141func (m MCPs) Sorted() []MCP {
142	sorted := make([]MCP, 0, len(m))
143	for k, v := range m {
144		sorted = append(sorted, MCP{
145			Name: k,
146			MCP:  v,
147		})
148	}
149	slices.SortFunc(sorted, func(a, b MCP) int {
150		return strings.Compare(a.Name, b.Name)
151	})
152	return sorted
153}
154
// LSPs maps a language server name to its configuration.
type LSPs map[string]LSPConfig

// LSP pairs a language server name with its configuration; it is the element
// type returned by LSPs.Sorted.
type LSP struct {
	Name string    `json:"name"`
	LSP  LSPConfig `json:"lsp"`
}
161
162func (l LSPs) Sorted() []LSP {
163	sorted := make([]LSP, 0, len(l))
164	for k, v := range l {
165		sorted = append(sorted, LSP{
166			Name: k,
167			LSP:  v,
168		})
169	}
170	slices.SortFunc(sorted, func(a, b LSP) int {
171		return strings.Compare(a.Name, b.Name)
172	})
173	return sorted
174}
175
176func (m MCPConfig) ResolvedEnv() []string {
177	resolver := NewShellVariableResolver(env.New())
178	for e, v := range m.Env {
179		var err error
180		m.Env[e], err = resolver.ResolveValue(v)
181		if err != nil {
182			slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
183			continue
184		}
185	}
186
187	env := make([]string, 0, len(m.Env))
188	for k, v := range m.Env {
189		env = append(env, fmt.Sprintf("%s=%s", k, v))
190	}
191	return env
192}
193
194func (m MCPConfig) ResolvedHeaders() map[string]string {
195	resolver := NewShellVariableResolver(env.New())
196	for e, v := range m.Headers {
197		var err error
198		m.Headers[e], err = resolver.ResolveValue(v)
199		if err != nil {
200			slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
201			continue
202		}
203	}
204	return m.Headers
205}
206
// Agent describes an agent: which model slot it uses and which tools, MCPs,
// LSPs and context paths it may access. The built-in agents are created by
// Config.SetupAgents.
type Agent struct {
	ID          string `json:"id,omitempty"`
	Name        string `json:"name,omitempty"`
	Description string `json:"description,omitempty"`
	// Marks the agent as disabled.
	// NOTE(review): a comment about "the id of the system prompt used by the
	// agent" used to sit here, but no such field exists in this struct.
	Disabled bool `json:"disabled,omitempty"`

	// The model slot (large/small) this agent uses. Always serialized.
	Model SelectedModelType `json:"model"`

	// The available tools for the agent
	//  if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools,omitempty"`

	// this tells us which MCPs are available for this agent
	//  if this is empty all mcps are available
	//  the string array is the list of tools from the AllowedMCP the agent has available
	//  if the string array is nil, all tools from the AllowedMCP are available
	AllowedMCP map[string][]string `json:"allowed_mcp,omitempty"`

	// The list of LSPs that this agent can use
	//  if this is nil, all LSPs are available
	AllowedLSP []string `json:"allowed_lsp,omitempty"`

	// Overrides the context paths for this agent
	ContextPaths []string `json:"context_paths,omitempty"`
}
233
// Config holds the configuration for crush.
//
// The exported fields with JSON names make up the serialized configuration;
// Agents and the unexported fields are runtime state and are never
// serialized.
type Config struct {
	// We currently only support large/small as values here.
	Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`

	// The providers that are configured
	Providers *csync.Map[string, ProviderConfig] `json:"providers,omitempty"`

	// The configured MCP servers, keyed by name.
	MCP MCPs `json:"mcp,omitempty"`

	// The configured language servers, keyed by name.
	LSP LSPs `json:"lsp,omitempty"`

	// General options; may be nil (SetCompactMode allocates it on demand).
	Options *Options `json:"options,omitempty"`

	// Internal
	workingDir string `json:"-"`
	// TODO: most likely remove this concept when I come back to it
	Agents map[string]Agent `json:"-"`
	// TODO: find a better way to do this this should probably not be part of the config
	resolver       VariableResolver
	// dataConfigDir is, despite its name, used as the path of the config
	// FILE that SetConfigField reads and rewrites.
	dataConfigDir  string              `json:"-"`
	// knownProviders is the catalog SetProviderAPIKey searches when the
	// provider is not yet present in Providers.
	knownProviders []provider.Provider `json:"-"`
}
257
// WorkingDir returns the working directory stored in the config.
func (c *Config) WorkingDir() string {
	return c.workingDir
}
261
262func (c *Config) EnabledProviders() []ProviderConfig {
263	var enabled []ProviderConfig
264	for _, p := range c.Providers.Seq2() {
265		if !p.Disable {
266			enabled = append(enabled, p)
267		}
268	}
269	return enabled
270}
271
272// IsConfigured  return true if at least one provider is configured
273func (c *Config) IsConfigured() bool {
274	return len(c.EnabledProviders()) > 0
275}
276
277func (c *Config) GetModel(provider, model string) *provider.Model {
278	if providerConfig, ok := c.Providers.Get(provider); ok {
279		for _, m := range providerConfig.Models {
280			if m.ID == model {
281				return &m
282			}
283		}
284	}
285	return nil
286}
287
288func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig {
289	model, ok := c.Models[modelType]
290	if !ok {
291		return nil
292	}
293	if providerConfig, ok := c.Providers.Get(model.Provider); ok {
294		return &providerConfig
295	}
296	return nil
297}
298
299func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model {
300	model, ok := c.Models[modelType]
301	if !ok {
302		return nil
303	}
304	return c.GetModel(model.Provider, model.Model)
305}
306
307func (c *Config) LargeModel() *provider.Model {
308	model, ok := c.Models[SelectedModelTypeLarge]
309	if !ok {
310		return nil
311	}
312	return c.GetModel(model.Provider, model.Model)
313}
314
315func (c *Config) SmallModel() *provider.Model {
316	model, ok := c.Models[SelectedModelTypeSmall]
317	if !ok {
318		return nil
319	}
320	return c.GetModel(model.Provider, model.Model)
321}
322
323func (c *Config) SetCompactMode(enabled bool) error {
324	if c.Options == nil {
325		c.Options = &Options{}
326	}
327	c.Options.TUI.CompactMode = enabled
328	return c.SetConfigField("options.tui.compact_mode", enabled)
329}
330
331func (c *Config) Resolve(key string) (string, error) {
332	if c.resolver == nil {
333		return "", fmt.Errorf("no variable resolver configured")
334	}
335	return c.resolver.ResolveValue(key)
336}
337
338func (c *Config) UpdatePreferredModel(modelType SelectedModelType, model SelectedModel) error {
339	c.Models[modelType] = model
340	if err := c.SetConfigField(fmt.Sprintf("models.%s", modelType), model); err != nil {
341		return fmt.Errorf("failed to update preferred model: %w", err)
342	}
343	return nil
344}
345
346func (c *Config) SetConfigField(key string, value any) error {
347	// read the data
348	data, err := os.ReadFile(c.dataConfigDir)
349	if err != nil {
350		if os.IsNotExist(err) {
351			data = []byte("{}")
352		} else {
353			return fmt.Errorf("failed to read config file: %w", err)
354		}
355	}
356
357	newValue, err := sjson.Set(string(data), key, value)
358	if err != nil {
359		return fmt.Errorf("failed to set config field %s: %w", key, err)
360	}
361	if err := os.WriteFile(c.dataConfigDir, []byte(newValue), 0o644); err != nil {
362		return fmt.Errorf("failed to write config file: %w", err)
363	}
364	return nil
365}
366
367func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
368	// First save to the config file
369	err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
370	if err != nil {
371		return fmt.Errorf("failed to save API key to config file: %w", err)
372	}
373
374	providerConfig, exists := c.Providers.Get(providerID)
375	if exists {
376		providerConfig.APIKey = apiKey
377		c.Providers.Set(providerID, providerConfig)
378		return nil
379	}
380
381	var foundProvider *provider.Provider
382	for _, p := range c.knownProviders {
383		if string(p.ID) == providerID {
384			foundProvider = &p
385			break
386		}
387	}
388
389	if foundProvider != nil {
390		// Create new provider config based on known provider
391		providerConfig = ProviderConfig{
392			ID:           providerID,
393			Name:         foundProvider.Name,
394			BaseURL:      foundProvider.APIEndpoint,
395			Type:         foundProvider.Type,
396			APIKey:       apiKey,
397			Disable:      false,
398			ExtraHeaders: make(map[string]string),
399			ExtraParams:  make(map[string]string),
400			Models:       foundProvider.Models,
401		}
402	} else {
403		return fmt.Errorf("provider with ID %s not found in known providers", providerID)
404	}
405	// Store the updated provider config
406	c.Providers.Set(providerID, providerConfig)
407	return nil
408}
409
410func (c *Config) SetupAgents() {
411	agents := map[string]Agent{
412		"coder": {
413			ID:           "coder",
414			Name:         "Coder",
415			Description:  "An agent that helps with executing coding tasks.",
416			Model:        SelectedModelTypeLarge,
417			ContextPaths: c.Options.ContextPaths,
418			// All tools allowed
419		},
420		"task": {
421			ID:           "task",
422			Name:         "Task",
423			Description:  "An agent that helps with searching for context and finding implementation details.",
424			Model:        SelectedModelTypeLarge,
425			ContextPaths: c.Options.ContextPaths,
426			AllowedTools: []string{
427				"glob",
428				"grep",
429				"ls",
430				"sourcegraph",
431				"view",
432			},
433			// NO MCPs or LSPs by default
434			AllowedMCP: map[string][]string{},
435			AllowedLSP: []string{},
436		},
437	}
438	c.Agents = agents
439}
440
// Resolver returns the config's variable resolver; it is nil when none has
// been set.
func (c *Config) Resolver() VariableResolver {
	return c.resolver
}
444
445func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
446	testURL := ""
447	headers := make(map[string]string)
448	apiKey, _ := resolver.ResolveValue(c.APIKey)
449	switch c.Type {
450	case provider.TypeOpenAI:
451		baseURL, _ := resolver.ResolveValue(c.BaseURL)
452		if baseURL == "" {
453			baseURL = "https://api.openai.com/v1"
454		}
455		testURL = baseURL + "/models"
456		headers["Authorization"] = "Bearer " + apiKey
457	case provider.TypeAnthropic:
458		baseURL, _ := resolver.ResolveValue(c.BaseURL)
459		if baseURL == "" {
460			baseURL = "https://api.anthropic.com/v1"
461		}
462		testURL = baseURL + "/models"
463		headers["x-api-key"] = apiKey
464		headers["anthropic-version"] = "2023-06-01"
465	}
466	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
467	defer cancel()
468	client := &http.Client{}
469	req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
470	if err != nil {
471		return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
472	}
473	for k, v := range headers {
474		req.Header.Set(k, v)
475	}
476	for k, v := range c.ExtraHeaders {
477		req.Header.Set(k, v)
478	}
479	b, err := client.Do(req)
480	if err != nil {
481		return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
482	}
483	if b.StatusCode != http.StatusOK {
484		return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status)
485	}
486	_ = b.Body.Close()
487	return nil
488}