uicmd.go

  1// Package uicmd provides functionality to load and handle custom commands
  2// from markdown files and MCP prompts.
  3// TODO: Move this into internal/ui after refactoring.
  4package uicmd
  5
  6import (
  7	"cmp"
  8	"context"
  9	"fmt"
 10	"io/fs"
 11	"os"
 12	"path/filepath"
 13	"regexp"
 14	"strings"
 15
 16	tea "charm.land/bubbletea/v2"
 17	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
 18	"github.com/charmbracelet/crush/internal/config"
 19	"github.com/charmbracelet/crush/internal/home"
 20	"github.com/charmbracelet/crush/internal/tui/components/chat"
 21	"github.com/charmbracelet/crush/internal/tui/util"
 22)
 23
 24// Command represents a command that can be executed
 25type Command struct {
 26	ID          string
 27	Title       string
 28	Description string
 29	Shortcut    string // Optional shortcut for the command
 30	Handler     func(cmd Command) tea.Cmd
 31}
 32
 33// ShowArgumentsDialogMsg is a message that is sent to show the arguments dialog.
 34type ShowArgumentsDialogMsg struct {
 35	CommandID   string
 36	Description string
 37	ArgNames    []string
 38	OnSubmit    func(args map[string]string) tea.Cmd
 39}
 40
 41// CloseArgumentsDialogMsg is a message that is sent when the arguments dialog is closed.
 42type CloseArgumentsDialogMsg struct {
 43	Submit    bool
 44	CommandID string
 45	Content   string
 46	Args      map[string]string
 47}
 48
 49const (
 50	userCommandPrefix    = "user:"
 51	projectCommandPrefix = "project:"
 52)
 53
 54var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
 55
 56type commandLoader struct {
 57	sources []commandSource
 58}
 59
 60type commandSource struct {
 61	path   string
 62	prefix string
 63}
 64
 65func LoadCustomCommands() ([]Command, error) {
 66	return LoadCustomCommandsFromConfig(config.Get())
 67}
 68
 69func LoadCustomCommandsFromConfig(cfg *config.Config) ([]Command, error) {
 70	if cfg == nil {
 71		return nil, fmt.Errorf("config not loaded")
 72	}
 73
 74	loader := &commandLoader{
 75		sources: buildCommandSources(cfg),
 76	}
 77
 78	return loader.loadAll()
 79}
 80
 81func buildCommandSources(cfg *config.Config) []commandSource {
 82	var sources []commandSource
 83
 84	// XDG config directory
 85	if dir := getXDGCommandsDir(); dir != "" {
 86		sources = append(sources, commandSource{
 87			path:   dir,
 88			prefix: userCommandPrefix,
 89		})
 90	}
 91
 92	// Home directory
 93	if home := home.Dir(); home != "" {
 94		sources = append(sources, commandSource{
 95			path:   filepath.Join(home, ".crush", "commands"),
 96			prefix: userCommandPrefix,
 97		})
 98	}
 99
100	// Project directory
101	sources = append(sources, commandSource{
102		path:   filepath.Join(cfg.Options.DataDirectory, "commands"),
103		prefix: projectCommandPrefix,
104	})
105
106	return sources
107}
108
109func getXDGCommandsDir() string {
110	xdgHome := os.Getenv("XDG_CONFIG_HOME")
111	if xdgHome == "" {
112		if home := home.Dir(); home != "" {
113			xdgHome = filepath.Join(home, ".config")
114		}
115	}
116	if xdgHome != "" {
117		return filepath.Join(xdgHome, "crush", "commands")
118	}
119	return ""
120}
121
122func (l *commandLoader) loadAll() ([]Command, error) {
123	var commands []Command
124
125	for _, source := range l.sources {
126		if cmds, err := l.loadFromSource(source); err == nil {
127			commands = append(commands, cmds...)
128		}
129	}
130
131	return commands, nil
132}
133
134func (l *commandLoader) loadFromSource(source commandSource) ([]Command, error) {
135	if err := ensureDir(source.path); err != nil {
136		return nil, err
137	}
138
139	var commands []Command
140
141	err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
142		if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
143			return err
144		}
145
146		cmd, err := l.loadCommand(path, source.path, source.prefix)
147		if err != nil {
148			return nil // Skip invalid files
149		}
150
151		commands = append(commands, cmd)
152		return nil
153	})
154
155	return commands, err
156}
157
158func (l *commandLoader) loadCommand(path, baseDir, prefix string) (Command, error) {
159	content, err := os.ReadFile(path)
160	if err != nil {
161		return Command{}, err
162	}
163
164	id := buildCommandID(path, baseDir, prefix)
165	desc := fmt.Sprintf("Custom command from %s", filepath.Base(path))
166
167	return Command{
168		ID:          id,
169		Title:       id,
170		Description: desc,
171		Handler:     createCommandHandler(id, desc, string(content)),
172	}, nil
173}
174
// buildCommandID derives a command ID from the location of its file: the
// path relative to baseDir, without the file extension, with path separators
// replaced by ":" and prefix put in front. For example, "sub/cmd.md" below
// baseDir with prefix "user:" yields "user:sub:cmd".
//
// When path cannot be made relative to baseDir, the file's base name is used
// so that the ID still names the file instead of collapsing to the bare
// prefix.
func buildCommandID(path, baseDir, prefix string) string {
	relPath, err := filepath.Rel(baseDir, path)
	if err != nil {
		relPath = filepath.Base(path)
	}

	// filepath.Ext only looks at the final path element, so directory names
	// containing dots are left untouched.
	relPath = strings.TrimSuffix(relPath, filepath.Ext(relPath))

	return prefix + strings.ReplaceAll(relPath, string(filepath.Separator), ":")
}
187
188func createCommandHandler(id, desc, content string) func(Command) tea.Cmd {
189	return func(cmd Command) tea.Cmd {
190		args := extractArgNames(content)
191
192		if len(args) == 0 {
193			return util.CmdHandler(CommandRunCustomMsg{
194				Content: content,
195			})
196		}
197		return util.CmdHandler(ShowArgumentsDialogMsg{
198			CommandID:   id,
199			Description: desc,
200			ArgNames:    args,
201			OnSubmit: func(args map[string]string) tea.Cmd {
202				return execUserPrompt(content, args)
203			},
204		})
205	}
206}
207
208func execUserPrompt(content string, args map[string]string) tea.Cmd {
209	return func() tea.Msg {
210		for name, value := range args {
211			placeholder := "$" + name
212			content = strings.ReplaceAll(content, placeholder, value)
213		}
214		return CommandRunCustomMsg{
215			Content: content,
216		}
217	}
218}
219
220func extractArgNames(content string) []string {
221	matches := namedArgPattern.FindAllStringSubmatch(content, -1)
222	if len(matches) == 0 {
223		return nil
224	}
225
226	seen := make(map[string]bool)
227	var args []string
228
229	for _, match := range matches {
230		arg := match[1]
231		if !seen[arg] {
232			seen[arg] = true
233			args = append(args, arg)
234		}
235	}
236
237	return args
238}
239
// ensureDir makes sure path exists as a directory, creating it and any
// missing parents with mode 0o755.
//
// os.MkdirAll already succeeds without doing anything when path is an
// existing directory, so no prior existence check is needed. Calling it
// directly also reports failures that a stat-then-create sequence hides,
// such as a path component that is a regular file.
func ensureDir(path string) error {
	return os.MkdirAll(path, 0o755)
}
246
// isMarkdownFile reports whether name ends in ".md", ignoring letter case.
func isMarkdownFile(name string) bool {
	const ext = ".md"
	if len(name) < len(ext) {
		return false
	}
	return strings.EqualFold(name[len(name)-len(ext):], ext)
}
250
// CommandRunCustomMsg is the message emitted to run a custom command.
// Content is the command's prompt text, with argument values already filled
// in when the command takes arguments.
type CommandRunCustomMsg struct{ Content string }
254
255func LoadMCPPrompts() []Command {
256	var commands []Command
257	for mcpName, prompts := range mcp.Prompts() {
258		for _, prompt := range prompts {
259			key := mcpName + ":" + prompt.Name
260			commands = append(commands, Command{
261				ID:          key,
262				Title:       cmp.Or(prompt.Title, prompt.Name),
263				Description: prompt.Description,
264				Handler:     createMCPPromptHandler(mcpName, prompt.Name, prompt),
265			})
266		}
267	}
268
269	return commands
270}
271
272func createMCPPromptHandler(mcpName, promptName string, prompt *mcp.Prompt) func(Command) tea.Cmd {
273	return func(cmd Command) tea.Cmd {
274		if len(prompt.Arguments) == 0 {
275			return execMCPPrompt(mcpName, promptName, nil)
276		}
277		return util.CmdHandler(ShowMCPPromptArgumentsDialogMsg{
278			Prompt: prompt,
279			OnSubmit: func(args map[string]string) tea.Cmd {
280				return execMCPPrompt(mcpName, promptName, args)
281			},
282		})
283	}
284}
285
286func execMCPPrompt(clientName, promptName string, args map[string]string) tea.Cmd {
287	return func() tea.Msg {
288		ctx := context.Background()
289		result, err := mcp.GetPromptMessages(ctx, clientName, promptName, args)
290		if err != nil {
291			return util.ReportError(err)
292		}
293
294		return chat.SendMsg{
295			Text: strings.Join(result, " "),
296		}
297	}
298}
299
300type ShowMCPPromptArgumentsDialogMsg struct {
301	Prompt   *mcp.Prompt
302	OnSubmit func(arg map[string]string) tea.Cmd
303}