uicmd.go

  1// Package uicmd provides functionality to load and handle custom commands
  2// from markdown files and MCP prompts.
  3// TODO: Move this into internal/ui after refactoring.
  4package uicmd
  5
import (
	"cmp"
	"context"
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
	"regexp"
	"slices"
	"strings"

	tea "charm.land/bubbletea/v2"
	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/home"
	"github.com/charmbracelet/crush/internal/tui/components/chat"
	"github.com/charmbracelet/crush/internal/tui/util"
)
 23
 24type CommandType uint
 25
 26func (c CommandType) String() string { return []string{"System", "User", "MCP"}[c] }
 27
 28const (
 29	SystemCommands CommandType = iota
 30	UserCommands
 31	MCPPrompts
 32)
 33
 34// Command represents a command that can be executed
 35type Command struct {
 36	ID          string
 37	Title       string
 38	Description string
 39	Shortcut    string // Optional shortcut for the command
 40	Handler     func(cmd Command) tea.Cmd
 41}
 42
 43// ShowArgumentsDialogMsg is a message that is sent to show the arguments dialog.
 44type ShowArgumentsDialogMsg struct {
 45	CommandID   string
 46	Description string
 47	ArgNames    []string
 48	OnSubmit    func(args map[string]string) tea.Cmd
 49}
 50
 51// CloseArgumentsDialogMsg is a message that is sent when the arguments dialog is closed.
 52type CloseArgumentsDialogMsg struct {
 53	Submit    bool
 54	CommandID string
 55	Content   string
 56	Args      map[string]string
 57}
 58
 59const (
 60	userCommandPrefix    = "user:"
 61	projectCommandPrefix = "project:"
 62)
 63
 64var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
 65
 66type commandLoader struct {
 67	sources []commandSource
 68}
 69
 70type commandSource struct {
 71	path   string
 72	prefix string
 73}
 74
 75func LoadCustomCommands() ([]Command, error) {
 76	return LoadCustomCommandsFromConfig(config.Get())
 77}
 78
 79func LoadCustomCommandsFromConfig(cfg *config.Config) ([]Command, error) {
 80	if cfg == nil {
 81		return nil, fmt.Errorf("config not loaded")
 82	}
 83
 84	loader := &commandLoader{
 85		sources: buildCommandSources(cfg),
 86	}
 87
 88	return loader.loadAll()
 89}
 90
 91func buildCommandSources(cfg *config.Config) []commandSource {
 92	var sources []commandSource
 93
 94	// XDG config directory
 95	if dir := getXDGCommandsDir(); dir != "" {
 96		sources = append(sources, commandSource{
 97			path:   dir,
 98			prefix: userCommandPrefix,
 99		})
100	}
101
102	// Home directory
103	if home := home.Dir(); home != "" {
104		sources = append(sources, commandSource{
105			path:   filepath.Join(home, ".crush", "commands"),
106			prefix: userCommandPrefix,
107		})
108	}
109
110	// Project directory
111	sources = append(sources, commandSource{
112		path:   filepath.Join(cfg.Options.DataDirectory, "commands"),
113		prefix: projectCommandPrefix,
114	})
115
116	return sources
117}
118
119func getXDGCommandsDir() string {
120	xdgHome := os.Getenv("XDG_CONFIG_HOME")
121	if xdgHome == "" {
122		if home := home.Dir(); home != "" {
123			xdgHome = filepath.Join(home, ".config")
124		}
125	}
126	if xdgHome != "" {
127		return filepath.Join(xdgHome, "crush", "commands")
128	}
129	return ""
130}
131
132func (l *commandLoader) loadAll() ([]Command, error) {
133	var commands []Command
134
135	for _, source := range l.sources {
136		if cmds, err := l.loadFromSource(source); err == nil {
137			commands = append(commands, cmds...)
138		}
139	}
140
141	return commands, nil
142}
143
144func (l *commandLoader) loadFromSource(source commandSource) ([]Command, error) {
145	if err := ensureDir(source.path); err != nil {
146		return nil, err
147	}
148
149	var commands []Command
150
151	err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
152		if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
153			return err
154		}
155
156		cmd, err := l.loadCommand(path, source.path, source.prefix)
157		if err != nil {
158			return nil // Skip invalid files
159		}
160
161		commands = append(commands, cmd)
162		return nil
163	})
164
165	return commands, err
166}
167
168func (l *commandLoader) loadCommand(path, baseDir, prefix string) (Command, error) {
169	content, err := os.ReadFile(path)
170	if err != nil {
171		return Command{}, err
172	}
173
174	id := buildCommandID(path, baseDir, prefix)
175	desc := fmt.Sprintf("Custom command from %s", filepath.Base(path))
176
177	return Command{
178		ID:          id,
179		Title:       id,
180		Description: desc,
181		Handler:     createCommandHandler(id, desc, string(content)),
182	}, nil
183}
184
// buildCommandID derives a command ID from a file's location. It takes the
// path relative to baseDir, drops the extension, turns each directory
// separator into ':' and prepends prefix. For example, with baseDir "/cmds"
// and prefix "user:", "/cmds/git/commit.md" becomes "user:git:commit".
func buildCommandID(path, baseDir, prefix string) string {
	// The Rel error is ignored deliberately: this file only calls the function
	// with paths found by walking baseDir. If Rel did fail, rel would be "" and
	// the ID would be the bare prefix.
	rel, _ := filepath.Rel(baseDir, path)
	// filepath.Ext only looks past the last separator, so this trims the
	// file name's extension and leaves dotted directory names intact.
	rel = strings.TrimSuffix(rel, filepath.Ext(rel))
	return prefix + strings.ReplaceAll(rel, string(filepath.Separator), ":")
}
197
198func createCommandHandler(id, desc, content string) func(Command) tea.Cmd {
199	return func(cmd Command) tea.Cmd {
200		args := extractArgNames(content)
201
202		if len(args) == 0 {
203			return util.CmdHandler(CommandRunCustomMsg{
204				Content: content,
205			})
206		}
207		return util.CmdHandler(ShowArgumentsDialogMsg{
208			CommandID:   id,
209			Description: desc,
210			ArgNames:    args,
211			OnSubmit: func(args map[string]string) tea.Cmd {
212				return execUserPrompt(content, args)
213			},
214		})
215	}
216}
217
218func execUserPrompt(content string, args map[string]string) tea.Cmd {
219	return func() tea.Msg {
220		for name, value := range args {
221			placeholder := "$" + name
222			content = strings.ReplaceAll(content, placeholder, value)
223		}
224		return CommandRunCustomMsg{
225			Content: content,
226		}
227	}
228}
229
230func extractArgNames(content string) []string {
231	matches := namedArgPattern.FindAllStringSubmatch(content, -1)
232	if len(matches) == 0 {
233		return nil
234	}
235
236	seen := make(map[string]bool)
237	var args []string
238
239	for _, match := range matches {
240		arg := match[1]
241		if !seen[arg] {
242			seen[arg] = true
243			args = append(args, arg)
244		}
245	}
246
247	return args
248}
249
// ensureDir creates path and any missing parents (mode 0o755) when os.Stat
// reports that path does not exist. In every other case, including other Stat
// errors, it does nothing and returns nil.
func ensureDir(path string) error {
	_, statErr := os.Stat(path)
	if !os.IsNotExist(statErr) {
		return nil
	}
	return os.MkdirAll(path, 0o755)
}
256
// isMarkdownFile reports whether name ends in ".md", compared
// case-insensitively (so "README.MD" also qualifies).
func isMarkdownFile(name string) bool {
	const ext = ".md"
	if len(name) < len(ext) {
		return false
	}
	return strings.EqualFold(name[len(name)-len(ext):], ext)
}
260
// CommandRunCustomMsg carries the prompt text of a custom command to run.
type CommandRunCustomMsg struct {
	// Content is the command body, with any placeholders already expanded.
	Content string
}
264
265func LoadMCPPrompts() []Command {
266	var commands []Command
267	for mcpName, prompts := range mcp.Prompts() {
268		for _, prompt := range prompts {
269			key := mcpName + ":" + prompt.Name
270			commands = append(commands, Command{
271				ID:          key,
272				Title:       cmp.Or(prompt.Title, prompt.Name),
273				Description: prompt.Description,
274				Handler:     createMCPPromptHandler(mcpName, prompt.Name, prompt),
275			})
276		}
277	}
278
279	return commands
280}
281
282func createMCPPromptHandler(mcpName, promptName string, prompt *mcp.Prompt) func(Command) tea.Cmd {
283	return func(cmd Command) tea.Cmd {
284		if len(prompt.Arguments) == 0 {
285			return execMCPPrompt(mcpName, promptName, nil)
286		}
287		return util.CmdHandler(ShowMCPPromptArgumentsDialogMsg{
288			Prompt: prompt,
289			OnSubmit: func(args map[string]string) tea.Cmd {
290				return execMCPPrompt(mcpName, promptName, args)
291			},
292		})
293	}
294}
295
296func execMCPPrompt(clientName, promptName string, args map[string]string) tea.Cmd {
297	return func() tea.Msg {
298		ctx := context.Background()
299		result, err := mcp.GetPromptMessages(ctx, clientName, promptName, args)
300		if err != nil {
301			return util.ReportError(err)
302		}
303
304		return chat.SendMsg{
305			Text: strings.Join(result, " "),
306		}
307	}
308}
309
310type ShowMCPPromptArgumentsDialogMsg struct {
311	Prompt   *mcp.Prompt
312	OnSubmit func(arg map[string]string) tea.Cmd
313}