loader.go

  1package commands
  2
import (
	"context"
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
	"regexp"
	"slices"
	"strings"

	tea "github.com/charmbracelet/bubbletea/v2"
	"github.com/modelcontextprotocol/go-sdk/mcp"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/home"
	"github.com/charmbracelet/crush/internal/llm/agent"
	"github.com/charmbracelet/crush/internal/tui/components/chat"
	"github.com/charmbracelet/crush/internal/tui/util"
)
 21
 22const (
 23	userCommandPrefix    = "user:"
 24	projectCommandPrefix = "project:"
 25)
 26
 27var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
 28
 29type commandLoader struct {
 30	sources []commandSource
 31}
 32
 33type commandSource struct {
 34	path   string
 35	prefix string
 36}
 37
 38func LoadCustomCommands() ([]Command, error) {
 39	cfg := config.Get()
 40	if cfg == nil {
 41		return nil, fmt.Errorf("config not loaded")
 42	}
 43
 44	loader := &commandLoader{
 45		sources: buildCommandSources(cfg),
 46	}
 47
 48	return loader.loadAll()
 49}
 50
 51func buildCommandSources(cfg *config.Config) []commandSource {
 52	var sources []commandSource
 53
 54	// XDG config directory
 55	if dir := getXDGCommandsDir(); dir != "" {
 56		sources = append(sources, commandSource{
 57			path:   dir,
 58			prefix: userCommandPrefix,
 59		})
 60	}
 61
 62	// Home directory
 63	if home := home.Dir(); home != "" {
 64		sources = append(sources, commandSource{
 65			path:   filepath.Join(home, ".crush", "commands"),
 66			prefix: userCommandPrefix,
 67		})
 68	}
 69
 70	// Project directory
 71	sources = append(sources, commandSource{
 72		path:   filepath.Join(cfg.Options.DataDirectory, "commands"),
 73		prefix: projectCommandPrefix,
 74	})
 75
 76	return sources
 77}
 78
 79func getXDGCommandsDir() string {
 80	xdgHome := os.Getenv("XDG_CONFIG_HOME")
 81	if xdgHome == "" {
 82		if home := home.Dir(); home != "" {
 83			xdgHome = filepath.Join(home, ".config")
 84		}
 85	}
 86	if xdgHome != "" {
 87		return filepath.Join(xdgHome, "crush", "commands")
 88	}
 89	return ""
 90}
 91
 92func (l *commandLoader) loadAll() ([]Command, error) {
 93	var commands []Command
 94
 95	for _, source := range l.sources {
 96		if cmds, err := l.loadFromSource(source); err == nil {
 97			commands = append(commands, cmds...)
 98		}
 99	}
100
101	return commands, nil
102}
103
// loadFromSource recursively walks source.path and turns every markdown file
// into a Command whose ID carries source.prefix.
//
// The directory is created first if missing. The root is then resolved
// through any symbolic links, because filepath.WalkDir does not follow a
// symlinked root. Without this step a linked commands directory would
// silently produce no commands.
//
// Files that cannot be read, and subdirectories that cannot be listed, are
// skipped so one bad entry does not hide the commands found elsewhere. An
// error is returned only when the directory cannot be created or the root
// itself cannot be walked.
func (l *commandLoader) loadFromSource(source commandSource) ([]Command, error) {
	if err := ensureDir(source.path); err != nil {
		return nil, err
	}

	root := source.path
	if resolved, err := filepath.EvalSymlinks(root); err == nil {
		root = resolved
	}

	var commands []Command
	err := filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error {
		if walkErr != nil {
			// The root itself is unusable: report it.
			if path == root || d == nil {
				return walkErr
			}
			// An unreadable subdirectory: skip it and keep walking.
			if d.IsDir() {
				return fs.SkipDir
			}
			return nil
		}
		if d.IsDir() || !isMarkdownFile(d.Name()) {
			return nil
		}

		cmd, err := l.loadCommand(path, root, source.prefix)
		if err != nil {
			return nil // Skip files that cannot be read.
		}
		commands = append(commands, cmd)
		return nil
	})

	return commands, err
}
127
// loadCommand reads the markdown file at path and wraps it in a Command.
//
// The ID, which doubles as the title, is derived from path relative to
// baseDir and prefixed with prefix. The file's full text becomes the command
// body. A read error is returned unchanged together with a zero Command.
func (l *commandLoader) loadCommand(path, baseDir, prefix string) (Command, error) {
	body, err := os.ReadFile(path)
	if err != nil {
		return Command{}, err
	}

	id := buildCommandID(path, baseDir, prefix)
	description := fmt.Sprintf("Custom command from %s", filepath.Base(path))

	cmd := Command{ID: id, Title: id, Description: description}
	cmd.Handler = createCommandHandler(id, description, string(body))
	return cmd, nil
}
144
// buildCommandID derives a command ID from a file's location.
//
// It takes path relative to baseDir, drops the final extension, joins the
// remaining path segments with ':' and prepends prefix. For example,
// "<base>/git/commit.md" with prefix "user:" becomes "user:git:commit".
func buildCommandID(path, baseDir, prefix string) string {
	// The Rel error is ignored; on failure rel is "" and the ID is just prefix.
	rel, _ := filepath.Rel(baseDir, path)

	// Ext only looks at the last path element, so trimming it from the whole
	// relative path removes the extension of the file name alone.
	rel = strings.TrimSuffix(rel, filepath.Ext(rel))

	segments := strings.Split(rel, string(filepath.Separator))
	return prefix + strings.Join(segments, ":")
}
157
158func createCommandHandler(id, desc, content string) func(Command) tea.Cmd {
159	return func(cmd Command) tea.Cmd {
160		args := extractArgNames(content)
161
162		if len(args) > 0 {
163			return util.CmdHandler(ShowArgumentsDialogMsg{
164				CommandID:   id,
165				Description: desc,
166				Content:     content,
167				ArgNames:    args,
168			})
169		}
170
171		return util.CmdHandler(CommandRunCustomMsg{
172			Content: content,
173		})
174	}
175}
176
177func extractArgNames(content string) []string {
178	matches := namedArgPattern.FindAllStringSubmatch(content, -1)
179	if len(matches) == 0 {
180		return nil
181	}
182
183	seen := make(map[string]bool)
184	var args []string
185
186	for _, match := range matches {
187		arg := match[1]
188		if !seen[arg] {
189			seen[arg] = true
190			args = append(args, arg)
191		}
192	}
193
194	return args
195}
196
// ensureDir makes sure path exists as a directory. It creates the directory,
// and any missing parents, with mode 0o755 when needed.
//
// os.MkdirAll returns nil when path is already a directory. It returns an
// error when path exists as something else (such as a regular file) or cannot
// be created. The previous Stat-then-create check only acted on "not exist",
// so it silently accepted a regular file or an unreadable path as a valid
// directory.
func ensureDir(path string) error {
	return os.MkdirAll(path, 0o755)
}
203
// isMarkdownFile reports whether name ends in a ".md" extension, compared
// case-insensitively (so "notes.MD" qualifies).
func isMarkdownFile(name string) bool {
	return strings.EqualFold(filepath.Ext(name), ".md")
}
207
// CommandRunCustomMsg is emitted by a custom command's handler when the
// command body contains no $NAME placeholders.
type CommandRunCustomMsg struct {
	// Content is the command file's text, passed through unchanged.
	Content string
}
211
// LoadMCPPrompts returns one Command per prompt reported by
// agent.GetMCPPrompts.
//
// Prompt keys are split at the first ':' into client and prompt names. Keys
// without a ':' are skipped. Commands are returned sorted by key: ranging
// over the map directly yields a different, randomized order on every call,
// which made the resulting command list unstable.
func LoadMCPPrompts() []Command {
	prompts := agent.GetMCPPrompts()

	keys := make([]string, 0, len(prompts))
	for key := range prompts {
		keys = append(keys, key)
	}
	slices.Sort(keys)

	commands := make([]Command, 0, len(keys))
	for _, key := range keys {
		clientName, promptName, ok := strings.Cut(key, ":")
		if !ok {
			continue
		}
		prompt := prompts[key]
		commands = append(commands, Command{
			ID:          key,
			Title:       clientName + ":" + promptName,
			Description: prompt.Description,
			Handler:     createMCPPromptHandler(key, promptName, prompt),
		})
	}

	return commands
}
232
233func createMCPPromptHandler(key, promptName string, prompt *mcp.Prompt) func(Command) tea.Cmd {
234	return func(cmd Command) tea.Cmd {
235		if len(prompt.Arguments) == 0 {
236			return executeMCPPromptWithoutArgs(key, promptName)
237		}
238		return util.CmdHandler(ShowMCPPromptArgumentsDialogMsg{
239			PromptID:   cmd.ID,
240			PromptName: promptName,
241		})
242	}
243}
244
// executeMCPPromptWithoutArgs returns a command that fetches an argument-less
// MCP prompt and sends its text to the chat.
//
// key is expected as "clientName:promptName". Only the text content of
// messages with role "user" is kept, each followed by a newline. A malformed
// key or a fetch failure is reported as an error message instead.
func executeMCPPromptWithoutArgs(key, promptName string) tea.Cmd {
	return func() tea.Msg {
		clientName, _, found := strings.Cut(key, ":")
		if !found {
			return util.ReportError(fmt.Errorf("invalid prompt key: %s", key))
		}

		result, err := agent.GetMCPPromptContent(context.Background(), clientName, promptName, nil)
		if err != nil {
			return util.ReportError(err)
		}

		var text strings.Builder
		for _, message := range result.Messages {
			if message.Role != "user" {
				continue
			}
			tc, isText := message.Content.(*mcp.TextContent)
			if !isText {
				continue
			}
			text.WriteString(tc.Text)
			text.WriteString("\n")
		}

		return chat.SendMsg{Text: text.String()}
	}
}
275
// ShowMCPPromptArgumentsDialogMsg is emitted when an MCP prompt command is
// invoked for a prompt that declares arguments. Presumably the UI uses it to
// collect those arguments before fetching the prompt; confirm against the
// receiver.
type ShowMCPPromptArgumentsDialogMsg struct {
	// PromptID is the invoking command's ID. For commands built by
	// LoadMCPPrompts this is the "clientName:promptName" key.
	PromptID string
	// PromptName is the prompt name without the client prefix.
	PromptName string
}