1// Package uicmd provides functionality to load and handle custom commands
2// from markdown files and MCP prompts.
3// TODO: Move this into internal/ui after refactoring.
4package uicmd
5
6import (
7 "cmp"
8 "context"
9 "fmt"
10 "io/fs"
11 "os"
12 "path/filepath"
13 "regexp"
14 "strings"
15
16 tea "charm.land/bubbletea/v2"
17 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
18 "github.com/charmbracelet/crush/internal/config"
19 "github.com/charmbracelet/crush/internal/home"
20 "github.com/charmbracelet/crush/internal/tui/components/chat"
21 "github.com/charmbracelet/crush/internal/tui/util"
22)
23
// CommandType classifies the origin of a Command.
type CommandType uint

// String returns the display label for c: "System", "User" or "MCP".
// Values outside the declared constants are rendered as "CommandType(N)"
// rather than panicking with an index-out-of-range error.
func (c CommandType) String() string {
	switch c {
	case SystemCommands:
		return "System"
	case UserCommands:
		return "User"
	case MCPPrompts:
		return "MCP"
	default:
		return fmt.Sprintf("CommandType(%d)", uint(c))
	}
}

// The known command types.
const (
	SystemCommands CommandType = iota
	UserCommands
	MCPPrompts
)
33
34// Command represents a command that can be executed
35type Command struct {
36 ID string
37 Title string
38 Description string
39 Shortcut string // Optional shortcut for the command
40 Handler func(cmd Command) tea.Cmd
41}
42
43// ShowArgumentsDialogMsg is a message that is sent to show the arguments dialog.
44type ShowArgumentsDialogMsg struct {
45 CommandID string
46 Description string
47 ArgNames []string
48 OnSubmit func(args map[string]string) tea.Cmd
49}
50
51// CloseArgumentsDialogMsg is a message that is sent when the arguments dialog is closed.
52type CloseArgumentsDialogMsg struct {
53 Submit bool
54 CommandID string
55 Content string
56 Args map[string]string
57}
58
// userCommandPrefix is prepended to the IDs of commands found in the
// per-user command directories.
const userCommandPrefix = "user:"

// projectCommandPrefix is prepended to the IDs of commands found in the
// project's command directory.
const projectCommandPrefix = "project:"

// namedArgPattern matches "$NAME" placeholders, where NAME is an uppercase
// letter followed by uppercase letters, digits or underscores. Submatch 1 is
// the bare NAME. The match is greedy, so "$FOO_BAR" is one placeholder.
var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
65
// commandLoader gathers custom commands from an ordered list of sources.
type commandLoader struct {
	sources []commandSource
}

// commandSource is one directory to scan for markdown commands. prefix is
// prepended to the ID of every command discovered under path.
type commandSource struct {
	path, prefix string
}
74
75func LoadCustomCommands() ([]Command, error) {
76 return LoadCustomCommandsFromConfig(config.Get())
77}
78
79func LoadCustomCommandsFromConfig(cfg *config.Config) ([]Command, error) {
80 if cfg == nil {
81 return nil, fmt.Errorf("config not loaded")
82 }
83
84 loader := &commandLoader{
85 sources: buildCommandSources(cfg),
86 }
87
88 return loader.loadAll()
89}
90
91func buildCommandSources(cfg *config.Config) []commandSource {
92 var sources []commandSource
93
94 // XDG config directory
95 if dir := getXDGCommandsDir(); dir != "" {
96 sources = append(sources, commandSource{
97 path: dir,
98 prefix: userCommandPrefix,
99 })
100 }
101
102 // Home directory
103 if home := home.Dir(); home != "" {
104 sources = append(sources, commandSource{
105 path: filepath.Join(home, ".crush", "commands"),
106 prefix: userCommandPrefix,
107 })
108 }
109
110 // Project directory
111 sources = append(sources, commandSource{
112 path: filepath.Join(cfg.Options.DataDirectory, "commands"),
113 prefix: projectCommandPrefix,
114 })
115
116 return sources
117}
118
119func getXDGCommandsDir() string {
120 xdgHome := os.Getenv("XDG_CONFIG_HOME")
121 if xdgHome == "" {
122 if home := home.Dir(); home != "" {
123 xdgHome = filepath.Join(home, ".config")
124 }
125 }
126 if xdgHome != "" {
127 return filepath.Join(xdgHome, "crush", "commands")
128 }
129 return ""
130}
131
132func (l *commandLoader) loadAll() ([]Command, error) {
133 var commands []Command
134
135 for _, source := range l.sources {
136 if cmds, err := l.loadFromSource(source); err == nil {
137 commands = append(commands, cmds...)
138 }
139 }
140
141 return commands, nil
142}
143
144func (l *commandLoader) loadFromSource(source commandSource) ([]Command, error) {
145 if err := ensureDir(source.path); err != nil {
146 return nil, err
147 }
148
149 var commands []Command
150
151 err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
152 if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
153 return err
154 }
155
156 cmd, err := l.loadCommand(path, source.path, source.prefix)
157 if err != nil {
158 return nil // Skip invalid files
159 }
160
161 commands = append(commands, cmd)
162 return nil
163 })
164
165 return commands, err
166}
167
168func (l *commandLoader) loadCommand(path, baseDir, prefix string) (Command, error) {
169 content, err := os.ReadFile(path)
170 if err != nil {
171 return Command{}, err
172 }
173
174 id := buildCommandID(path, baseDir, prefix)
175 desc := fmt.Sprintf("Custom command from %s", filepath.Base(path))
176
177 return Command{
178 ID: id,
179 Title: id,
180 Description: desc,
181 Handler: createCommandHandler(id, desc, string(content)),
182 }, nil
183}
184
// buildCommandID derives a command ID from path relative to baseDir. It joins
// the path components with ":", drops the file extension and prepends prefix.
// For example, <baseDir>/git/commit.md with prefix "user:" becomes
// "user:git:commit".
//
// If path cannot be made relative to baseDir, the file name alone is used.
// Previously the ignored error left an empty relative path, so the ID
// collapsed to just the prefix.
func buildCommandID(path, baseDir, prefix string) string {
	relPath, err := filepath.Rel(baseDir, path)
	if err != nil {
		relPath = filepath.Base(path)
	}

	// strings.Split with a non-empty separator always yields at least one
	// element, so the last index is always valid.
	parts := strings.Split(relPath, string(filepath.Separator))
	last := len(parts) - 1
	parts[last] = strings.TrimSuffix(parts[last], filepath.Ext(parts[last]))

	return prefix + strings.Join(parts, ":")
}
197
198func createCommandHandler(id, desc, content string) func(Command) tea.Cmd {
199 return func(cmd Command) tea.Cmd {
200 args := extractArgNames(content)
201
202 if len(args) == 0 {
203 return util.CmdHandler(CommandRunCustomMsg{
204 Content: content,
205 })
206 }
207 return util.CmdHandler(ShowArgumentsDialogMsg{
208 CommandID: id,
209 Description: desc,
210 ArgNames: args,
211 OnSubmit: func(args map[string]string) tea.Cmd {
212 return execUserPrompt(content, args)
213 },
214 })
215 }
216}
217
218func execUserPrompt(content string, args map[string]string) tea.Cmd {
219 return func() tea.Msg {
220 for name, value := range args {
221 placeholder := "$" + name
222 content = strings.ReplaceAll(content, placeholder, value)
223 }
224 return CommandRunCustomMsg{
225 Content: content,
226 }
227 }
228}
229
230func extractArgNames(content string) []string {
231 matches := namedArgPattern.FindAllStringSubmatch(content, -1)
232 if len(matches) == 0 {
233 return nil
234 }
235
236 seen := make(map[string]bool)
237 var args []string
238
239 for _, match := range matches {
240 arg := match[1]
241 if !seen[arg] {
242 seen[arg] = true
243 args = append(args, arg)
244 }
245 }
246
247 return args
248}
249
// ensureDir creates path, along with any missing parents, using mode 0o755
// when os.Stat reports that it does not exist. Every other outcome counts as
// success: path already exists (even as a regular file), or Stat failed for
// another reason.
func ensureDir(path string) error {
	_, statErr := os.Stat(path)
	if !os.IsNotExist(statErr) {
		return nil
	}
	return os.MkdirAll(path, 0o755)
}
256
// isMarkdownFile reports whether name ends in ".md", compared
// case-insensitively (so "README.MD" qualifies).
func isMarkdownFile(name string) bool {
	const ext = ".md"
	if len(name) < len(ext) {
		return false
	}
	return strings.EqualFold(name[len(name)-len(ext):], ext)
}
260
// CommandRunCustomMsg carries the text of a custom command that is ready to
// run. When it is produced by execUserPrompt, the placeholders have already
// been substituted.
type CommandRunCustomMsg struct {
	Content string // final command text
}
264
265func LoadMCPPrompts() []Command {
266 var commands []Command
267 for mcpName, prompts := range mcp.Prompts() {
268 for _, prompt := range prompts {
269 key := mcpName + ":" + prompt.Name
270 commands = append(commands, Command{
271 ID: key,
272 Title: cmp.Or(prompt.Title, prompt.Name),
273 Description: prompt.Description,
274 Handler: createMCPPromptHandler(mcpName, prompt.Name, prompt),
275 })
276 }
277 }
278
279 return commands
280}
281
282func createMCPPromptHandler(mcpName, promptName string, prompt *mcp.Prompt) func(Command) tea.Cmd {
283 return func(cmd Command) tea.Cmd {
284 if len(prompt.Arguments) == 0 {
285 return execMCPPrompt(mcpName, promptName, nil)
286 }
287 return util.CmdHandler(ShowMCPPromptArgumentsDialogMsg{
288 Prompt: prompt,
289 OnSubmit: func(args map[string]string) tea.Cmd {
290 return execMCPPrompt(mcpName, promptName, args)
291 },
292 })
293 }
294}
295
296func execMCPPrompt(clientName, promptName string, args map[string]string) tea.Cmd {
297 return func() tea.Msg {
298 ctx := context.Background()
299 result, err := mcp.GetPromptMessages(ctx, clientName, promptName, args)
300 if err != nil {
301 return util.ReportError(err)
302 }
303
304 return chat.SendMsg{
305 Text: strings.Join(result, " "),
306 }
307 }
308}
309
310type ShowMCPPromptArgumentsDialogMsg struct {
311 Prompt *mcp.Prompt
312 OnSubmit func(arg map[string]string) tea.Cmd
313}