loader.go

  1package commands
  2
  3import (
  4	"cmp"
  5	"context"
  6	"fmt"
  7	"io/fs"
  8	"os"
  9	"path/filepath"
 10	"regexp"
 11	"strings"
 12
 13	tea "github.com/charmbracelet/bubbletea/v2"
 14	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/home"
 17	"github.com/charmbracelet/crush/internal/tui/components/chat"
 18	"github.com/charmbracelet/crush/internal/tui/util"
 19)
 20
 21const (
 22	userCommandPrefix    = "user:"
 23	projectCommandPrefix = "project:"
 24)
 25
 26var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
 27
 28type commandLoader struct {
 29	sources []commandSource
 30}
 31
 32type commandSource struct {
 33	path   string
 34	prefix string
 35}
 36
 37func LoadCustomCommands() ([]Command, error) {
 38	cfg := config.Get()
 39	if cfg == nil {
 40		return nil, fmt.Errorf("config not loaded")
 41	}
 42
 43	loader := &commandLoader{
 44		sources: buildCommandSources(cfg),
 45	}
 46
 47	return loader.loadAll()
 48}
 49
 50func buildCommandSources(cfg *config.Config) []commandSource {
 51	var sources []commandSource
 52
 53	// XDG config directory
 54	if dir := getXDGCommandsDir(); dir != "" {
 55		sources = append(sources, commandSource{
 56			path:   dir,
 57			prefix: userCommandPrefix,
 58		})
 59	}
 60
 61	// Home directory
 62	if home := home.Dir(); home != "" {
 63		sources = append(sources, commandSource{
 64			path:   filepath.Join(home, ".crush", "commands"),
 65			prefix: userCommandPrefix,
 66		})
 67	}
 68
 69	// Project directory
 70	sources = append(sources, commandSource{
 71		path:   filepath.Join(cfg.Options.DataDirectory, "commands"),
 72		prefix: projectCommandPrefix,
 73	})
 74
 75	return sources
 76}
 77
 78func getXDGCommandsDir() string {
 79	xdgHome := os.Getenv("XDG_CONFIG_HOME")
 80	if xdgHome == "" {
 81		if home := home.Dir(); home != "" {
 82			xdgHome = filepath.Join(home, ".config")
 83		}
 84	}
 85	if xdgHome != "" {
 86		return filepath.Join(xdgHome, "crush", "commands")
 87	}
 88	return ""
 89}
 90
 91func (l *commandLoader) loadAll() ([]Command, error) {
 92	var commands []Command
 93
 94	for _, source := range l.sources {
 95		if cmds, err := l.loadFromSource(source); err == nil {
 96			commands = append(commands, cmds...)
 97		}
 98	}
 99
100	return commands, nil
101}
102
103func (l *commandLoader) loadFromSource(source commandSource) ([]Command, error) {
104	if err := ensureDir(source.path); err != nil {
105		return nil, err
106	}
107
108	var commands []Command
109
110	err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
111		if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
112			return err
113		}
114
115		cmd, err := l.loadCommand(path, source.path, source.prefix)
116		if err != nil {
117			return nil // Skip invalid files
118		}
119
120		commands = append(commands, cmd)
121		return nil
122	})
123
124	return commands, err
125}
126
127func (l *commandLoader) loadCommand(path, baseDir, prefix string) (Command, error) {
128	content, err := os.ReadFile(path)
129	if err != nil {
130		return Command{}, err
131	}
132
133	id := buildCommandID(path, baseDir, prefix)
134	desc := fmt.Sprintf("Custom command from %s", filepath.Base(path))
135
136	return Command{
137		ID:          id,
138		Title:       id,
139		Description: desc,
140		Handler:     createCommandHandler(id, desc, string(content)),
141	}, nil
142}
143
// buildCommandID turns a command file path into a colon-separated ID. It takes
// the path relative to baseDir, drops the final extension, replaces each path
// separator with ":", and prepends prefix. For example, with prefix "user:",
// <baseDir>/git/commit.md becomes "user:git:commit".
func buildCommandID(path, baseDir, prefix string) string {
	// The Rel error is ignored: in this file path always comes from walking
	// baseDir. If Rel did fail, rel would be "" and the ID would be just the
	// prefix.
	rel, _ := filepath.Rel(baseDir, path)
	rel = strings.TrimSuffix(rel, filepath.Ext(rel))
	return prefix + strings.ReplaceAll(rel, string(filepath.Separator), ":")
}
156
157func createCommandHandler(id, desc, content string) func(Command) tea.Cmd {
158	return func(cmd Command) tea.Cmd {
159		args := extractArgNames(content)
160
161		if len(args) == 0 {
162			return util.CmdHandler(CommandRunCustomMsg{
163				Content: content,
164			})
165		}
166		return util.CmdHandler(ShowArgumentsDialogMsg{
167			CommandID:   id,
168			Description: desc,
169			ArgNames:    args,
170			OnSubmit: func(args map[string]string) tea.Cmd {
171				return execUserPrompt(content, args)
172			},
173		})
174	}
175}
176
177func execUserPrompt(content string, args map[string]string) tea.Cmd {
178	return func() tea.Msg {
179		for name, value := range args {
180			placeholder := "$" + name
181			content = strings.ReplaceAll(content, placeholder, value)
182		}
183		return CommandRunCustomMsg{
184			Content: content,
185		}
186	}
187}
188
189func extractArgNames(content string) []string {
190	matches := namedArgPattern.FindAllStringSubmatch(content, -1)
191	if len(matches) == 0 {
192		return nil
193	}
194
195	seen := make(map[string]bool)
196	var args []string
197
198	for _, match := range matches {
199		arg := match[1]
200		if !seen[arg] {
201			seen[arg] = true
202			args = append(args, arg)
203		}
204	}
205
206	return args
207}
208
// ensureDir makes sure path exists as a directory. When needed it creates the
// directory and any missing parents with permission bits 0o755 (before umask).
//
// It returns an error when path exists but is not a directory, or when the
// directory cannot be created. os.MkdirAll already treats an existing
// directory as success, so no separate Stat check is needed. A separate check
// would also hide non-"not exist" Stat errors and non-directory paths.
func ensureDir(path string) error {
	return os.MkdirAll(path, 0o755)
}
215
// isMarkdownFile reports whether name ends in ".md", ignoring case, so both
// "notes.md" and "README.MD" qualify.
func isMarkdownFile(name string) bool {
	const ext = ".md"
	if len(name) < len(ext) {
		return false
	}
	return strings.EqualFold(name[len(name)-len(ext):], ext)
}
219
// CommandRunCustomMsg is emitted by the custom-command handlers in this file.
// Content is the command's prompt text, with its $NAME placeholders
// substituted when arguments were collected.
type CommandRunCustomMsg struct{ Content string }
223
224func loadMCPPrompts() []Command {
225	var commands []Command
226	for mcpName, prompts := range mcp.Prompts() {
227		for _, prompt := range prompts {
228			key := mcpName + ":" + prompt.Name
229			commands = append(commands, Command{
230				ID:          key,
231				Title:       cmp.Or(prompt.Title, prompt.Name),
232				Description: prompt.Description,
233				Handler:     createMCPPromptHandler(mcpName, prompt.Name, prompt),
234			})
235		}
236	}
237
238	return commands
239}
240
241func createMCPPromptHandler(mcpName, promptName string, prompt *mcp.Prompt) func(Command) tea.Cmd {
242	return func(cmd Command) tea.Cmd {
243		if len(prompt.Arguments) == 0 {
244			return execMCPPrompt(mcpName, promptName, nil)
245		}
246		return util.CmdHandler(ShowMCPPromptArgumentsDialogMsg{
247			Prompt: prompt,
248			OnSubmit: func(args map[string]string) tea.Cmd {
249				return execMCPPrompt(mcpName, promptName, args)
250			},
251		})
252	}
253}
254
255func execMCPPrompt(clientName, promptName string, args map[string]string) tea.Cmd {
256	return func() tea.Msg {
257		ctx := context.Background()
258		result, err := mcp.GetPromptMessages(ctx, clientName, promptName, args)
259		if err != nil {
260			return util.ReportError(err)
261		}
262
263		return chat.SendMsg{
264			Text: strings.Join(result, " "),
265		}
266	}
267}
268
269type ShowMCPPromptArgumentsDialogMsg struct {
270	Prompt   *mcp.Prompt
271	OnSubmit func(arg map[string]string) tea.Cmd
272}