1// Package uicmd provides functionality to load and handle custom commands
2// from markdown files and MCP prompts.
3// TODO: Move this into internal/ui after refactoring.
4package uicmd
5
6import (
7 "cmp"
8 "context"
9 "fmt"
10 "io/fs"
11 "os"
12 "path/filepath"
13 "regexp"
14 "strings"
15
16 tea "charm.land/bubbletea/v2"
17 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
18 "github.com/charmbracelet/crush/internal/config"
19 "github.com/charmbracelet/crush/internal/home"
20 "github.com/charmbracelet/crush/internal/tui/components/chat"
21 "github.com/charmbracelet/crush/internal/tui/util"
22)
23
// Command represents a command that can be executed.
//
// In this file Commands are built from two sources:
//   - markdown files: ID and Title are both the prefixed ID derived from the
//     file path (e.g. "user:git:commit");
//   - MCP prompts: ID is "<mcp name>:<prompt name>" and Title is the prompt's
//     title, or its name when the title is empty.
//
// Handler receives a Command (the handlers built in this file ignore it) and
// returns the tea.Cmd that runs the command.
type Command struct {
	ID          string
	Title       string
	Description string
	Shortcut    string // Optional shortcut for the command
	Handler     func(cmd Command) tea.Cmd
}
32
// ShowArgumentsDialogMsg is a message that is sent to show the arguments dialog.
//
// In this file it is emitted for a markdown custom command whose content
// contains $NAME placeholders. CommandID and Description identify the command.
// ArgNames lists the placeholder names (without the leading '$') in order of
// first appearance. OnSubmit is expected to be called with argument values
// keyed by those names, and returns the command that runs the expanded prompt.
type ShowArgumentsDialogMsg struct {
	CommandID   string
	Description string
	ArgNames    []string
	OnSubmit    func(args map[string]string) tea.Cmd
}
40
// CloseArgumentsDialogMsg is a message that is sent when the arguments dialog is closed.
//
// It is neither produced nor consumed in this file. NOTE(review): presumably
// Submit distinguishes a submit from a cancel, and Args holds the entered
// values keyed by argument name; confirm against the dialog implementation.
type CloseArgumentsDialogMsg struct {
	Submit    bool
	CommandID string
	Content   string
	Args      map[string]string
}
48
// Prefixes prepended to the IDs of markdown commands. They record whether a
// command came from a user-level directory (XDG config or home directory) or
// from the project's data directory.
const (
	userCommandPrefix    = "user:"
	projectCommandPrefix = "project:"
)

// namedArgPattern matches a "$NAME" placeholder in custom command content.
// NAME starts with an uppercase ASCII letter, followed by any number of
// uppercase letters, digits, or underscores. Submatch 1 is NAME without the '$'.
var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
55
// commandLoader loads markdown custom commands from an ordered list of sources.
type commandLoader struct {
	sources []commandSource
}

// commandSource is a directory that is walked recursively for markdown
// command files. prefix is prepended to the ID of every command found there.
type commandSource struct {
	path   string
	prefix string
}
64
65func LoadCustomCommands() ([]Command, error) {
66 return LoadCustomCommandsFromConfig(config.Get())
67}
68
69func LoadCustomCommandsFromConfig(cfg *config.Config) ([]Command, error) {
70 if cfg == nil {
71 return nil, fmt.Errorf("config not loaded")
72 }
73
74 loader := &commandLoader{
75 sources: buildCommandSources(cfg),
76 }
77
78 return loader.loadAll()
79}
80
81func buildCommandSources(cfg *config.Config) []commandSource {
82 var sources []commandSource
83
84 // XDG config directory
85 if dir := getXDGCommandsDir(); dir != "" {
86 sources = append(sources, commandSource{
87 path: dir,
88 prefix: userCommandPrefix,
89 })
90 }
91
92 // Home directory
93 if home := home.Dir(); home != "" {
94 sources = append(sources, commandSource{
95 path: filepath.Join(home, ".crush", "commands"),
96 prefix: userCommandPrefix,
97 })
98 }
99
100 // Project directory
101 sources = append(sources, commandSource{
102 path: filepath.Join(cfg.Options.DataDirectory, "commands"),
103 prefix: projectCommandPrefix,
104 })
105
106 return sources
107}
108
109func getXDGCommandsDir() string {
110 xdgHome := os.Getenv("XDG_CONFIG_HOME")
111 if xdgHome == "" {
112 if home := home.Dir(); home != "" {
113 xdgHome = filepath.Join(home, ".config")
114 }
115 }
116 if xdgHome != "" {
117 return filepath.Join(xdgHome, "crush", "commands")
118 }
119 return ""
120}
121
122func (l *commandLoader) loadAll() ([]Command, error) {
123 var commands []Command
124
125 for _, source := range l.sources {
126 if cmds, err := l.loadFromSource(source); err == nil {
127 commands = append(commands, cmds...)
128 }
129 }
130
131 return commands, nil
132}
133
134func (l *commandLoader) loadFromSource(source commandSource) ([]Command, error) {
135 if err := ensureDir(source.path); err != nil {
136 return nil, err
137 }
138
139 var commands []Command
140
141 err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
142 if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
143 return err
144 }
145
146 cmd, err := l.loadCommand(path, source.path, source.prefix)
147 if err != nil {
148 return nil // Skip invalid files
149 }
150
151 commands = append(commands, cmd)
152 return nil
153 })
154
155 return commands, err
156}
157
158func (l *commandLoader) loadCommand(path, baseDir, prefix string) (Command, error) {
159 content, err := os.ReadFile(path)
160 if err != nil {
161 return Command{}, err
162 }
163
164 id := buildCommandID(path, baseDir, prefix)
165 desc := fmt.Sprintf("Custom command from %s", filepath.Base(path))
166
167 return Command{
168 ID: id,
169 Title: id,
170 Description: desc,
171 Handler: createCommandHandler(id, desc, string(content)),
172 }, nil
173}
174
// buildCommandID derives a command ID from path relative to baseDir. It drops
// the file extension, turns directory separators into ':', and prepends
// prefix. For example, <baseDir>/git/commit.md with prefix "user:" yields
// "user:git:commit".
func buildCommandID(path, baseDir, prefix string) string {
	// path always comes from walking baseDir, so Rel is not expected to fail.
	// On error rel is "" and the ID degrades to just the prefix.
	rel, _ := filepath.Rel(baseDir, path)
	rel = strings.TrimSuffix(rel, filepath.Ext(rel))
	return prefix + strings.ReplaceAll(rel, string(filepath.Separator), ":")
}
187
188func createCommandHandler(id, desc, content string) func(Command) tea.Cmd {
189 return func(cmd Command) tea.Cmd {
190 args := extractArgNames(content)
191
192 if len(args) == 0 {
193 return util.CmdHandler(CommandRunCustomMsg{
194 Content: content,
195 })
196 }
197 return util.CmdHandler(ShowArgumentsDialogMsg{
198 CommandID: id,
199 Description: desc,
200 ArgNames: args,
201 OnSubmit: func(args map[string]string) tea.Cmd {
202 return execUserPrompt(content, args)
203 },
204 })
205 }
206}
207
208func execUserPrompt(content string, args map[string]string) tea.Cmd {
209 return func() tea.Msg {
210 for name, value := range args {
211 placeholder := "$" + name
212 content = strings.ReplaceAll(content, placeholder, value)
213 }
214 return CommandRunCustomMsg{
215 Content: content,
216 }
217 }
218}
219
220func extractArgNames(content string) []string {
221 matches := namedArgPattern.FindAllStringSubmatch(content, -1)
222 if len(matches) == 0 {
223 return nil
224 }
225
226 seen := make(map[string]bool)
227 var args []string
228
229 for _, match := range matches {
230 arg := match[1]
231 if !seen[arg] {
232 seen[arg] = true
233 args = append(args, arg)
234 }
235 }
236
237 return args
238}
239
// ensureDir creates path, including any missing parents, with mode 0o755 when
// os.Stat reports that it does not exist. Every other Stat outcome is treated
// as success and nothing is created. That includes an existing path (even a
// non-directory) and any other Stat error.
func ensureDir(path string) error {
	_, statErr := os.Stat(path)
	if !os.IsNotExist(statErr) {
		return nil
	}
	return os.MkdirAll(path, 0o755)
}
246
// isMarkdownFile reports whether name has a ".md" extension, compared
// case-insensitively (so "README.MD" counts).
func isMarkdownFile(name string) bool {
	ext := filepath.Ext(name)
	return strings.EqualFold(ext, ".md")
}
250
// CommandRunCustomMsg carries the prompt text of a markdown custom command
// that is ready to run. Content is the file content, with argument
// placeholders substituted when the command takes arguments.
type CommandRunCustomMsg struct {
	Content string
}
254
255func LoadMCPPrompts() []Command {
256 var commands []Command
257 for mcpName, prompts := range mcp.Prompts() {
258 for _, prompt := range prompts {
259 key := mcpName + ":" + prompt.Name
260 commands = append(commands, Command{
261 ID: key,
262 Title: cmp.Or(prompt.Title, prompt.Name),
263 Description: prompt.Description,
264 Handler: createMCPPromptHandler(mcpName, prompt.Name, prompt),
265 })
266 }
267 }
268
269 return commands
270}
271
272func createMCPPromptHandler(mcpName, promptName string, prompt *mcp.Prompt) func(Command) tea.Cmd {
273 return func(cmd Command) tea.Cmd {
274 if len(prompt.Arguments) == 0 {
275 return execMCPPrompt(mcpName, promptName, nil)
276 }
277 return util.CmdHandler(ShowMCPPromptArgumentsDialogMsg{
278 Prompt: prompt,
279 OnSubmit: func(args map[string]string) tea.Cmd {
280 return execMCPPrompt(mcpName, promptName, args)
281 },
282 })
283 }
284}
285
286func execMCPPrompt(clientName, promptName string, args map[string]string) tea.Cmd {
287 return func() tea.Msg {
288 ctx := context.Background()
289 result, err := mcp.GetPromptMessages(ctx, clientName, promptName, args)
290 if err != nil {
291 return util.ReportError(err)
292 }
293
294 return chat.SendMsg{
295 Text: strings.Join(result, " "),
296 }
297 }
298}
299
// ShowMCPPromptArgumentsDialogMsg asks the UI to collect argument values for
// Prompt. In this file it is sent only for prompts that declare arguments.
// OnSubmit returns the command that fetches the prompt from its MCP client
// and sends it to the chat. It is expected to receive the values keyed by
// argument name (presumably the names in Prompt.Arguments; confirm against
// the dialog).
type ShowMCPPromptArgumentsDialogMsg struct {
	Prompt   *mcp.Prompt
	OnSubmit func(arg map[string]string) tea.Cmd
}