1package commands
2
3import (
4 "cmp"
5 "context"
6 "fmt"
7 "io/fs"
8 "os"
9 "path/filepath"
10 "regexp"
11 "strings"
12
13 tea "github.com/charmbracelet/bubbletea/v2"
14 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/home"
17 "github.com/charmbracelet/crush/internal/tui/components/chat"
18 "github.com/charmbracelet/crush/internal/tui/util"
19)
20
const (
	// userCommandPrefix is prepended to the IDs of commands loaded from the
	// user-level directories (the XDG config directory and ~/.crush/commands).
	userCommandPrefix = "user:"

	// projectCommandPrefix is prepended to the IDs of commands loaded from the
	// "commands" folder inside the configured data directory.
	projectCommandPrefix = "project:"
)
25
// namedArgPattern matches argument placeholders inside command content: a "$"
// followed by an uppercase ASCII letter and then any run of uppercase letters,
// digits, or underscores (for example $FILE or $ISSUE_2). The first capture
// group holds the name without the leading "$".
var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
27
// commandLoader loads custom commands from an ordered list of source
// directories.
type commandLoader struct {
	// sources lists the directories to scan, in the order they are visited.
	sources []commandSource
}
31
// commandSource describes one directory that may contain command files.
type commandSource struct {
	// path is the root directory that is walked for Markdown files.
	path string

	// prefix is prepended to the ID of every command found under path.
	prefix string
}
36
37func LoadCustomCommands() ([]Command, error) {
38 cfg := config.Get()
39 if cfg == nil {
40 return nil, fmt.Errorf("config not loaded")
41 }
42
43 loader := &commandLoader{
44 sources: buildCommandSources(cfg),
45 }
46
47 return loader.loadAll()
48}
49
50func buildCommandSources(cfg *config.Config) []commandSource {
51 var sources []commandSource
52
53 // XDG config directory
54 if dir := getXDGCommandsDir(); dir != "" {
55 sources = append(sources, commandSource{
56 path: dir,
57 prefix: userCommandPrefix,
58 })
59 }
60
61 // Home directory
62 if home := home.Dir(); home != "" {
63 sources = append(sources, commandSource{
64 path: filepath.Join(home, ".crush", "commands"),
65 prefix: userCommandPrefix,
66 })
67 }
68
69 // Project directory
70 sources = append(sources, commandSource{
71 path: filepath.Join(cfg.Options.DataDirectory, "commands"),
72 prefix: projectCommandPrefix,
73 })
74
75 return sources
76}
77
78func getXDGCommandsDir() string {
79 xdgHome := os.Getenv("XDG_CONFIG_HOME")
80 if xdgHome == "" {
81 if home := home.Dir(); home != "" {
82 xdgHome = filepath.Join(home, ".config")
83 }
84 }
85 if xdgHome != "" {
86 return filepath.Join(xdgHome, "crush", "commands")
87 }
88 return ""
89}
90
91func (l *commandLoader) loadAll() ([]Command, error) {
92 var commands []Command
93
94 for _, source := range l.sources {
95 if cmds, err := l.loadFromSource(source); err == nil {
96 commands = append(commands, cmds...)
97 }
98 }
99
100 return commands, nil
101}
102
103func (l *commandLoader) loadFromSource(source commandSource) ([]Command, error) {
104 if err := ensureDir(source.path); err != nil {
105 return nil, err
106 }
107
108 var commands []Command
109
110 err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
111 if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
112 return err
113 }
114
115 cmd, err := l.loadCommand(path, source.path, source.prefix)
116 if err != nil {
117 return nil // Skip invalid files
118 }
119
120 commands = append(commands, cmd)
121 return nil
122 })
123
124 return commands, err
125}
126
127func (l *commandLoader) loadCommand(path, baseDir, prefix string) (Command, error) {
128 content, err := os.ReadFile(path)
129 if err != nil {
130 return Command{}, err
131 }
132
133 id := buildCommandID(path, baseDir, prefix)
134 desc := fmt.Sprintf("Custom command from %s", filepath.Base(path))
135
136 return Command{
137 ID: id,
138 Title: id,
139 Description: desc,
140 Handler: createCommandHandler(id, desc, string(content)),
141 }, nil
142}
143
// buildCommandID derives a command ID from a command file's location.
//
// The path is made relative to baseDir, the file extension of the last
// element is dropped, and path separators become ":"; prefix is prepended to
// the result. For example, <baseDir>/git/commit.md with prefix "user:" yields
// "user:git:commit".
//
// If path cannot be expressed relative to baseDir, the file's base name is
// used instead so the ID never collapses to the bare prefix.
func buildCommandID(path, baseDir, prefix string) string {
	relPath, err := filepath.Rel(baseDir, path)
	if err != nil {
		relPath = filepath.Base(path)
	}

	// filepath.Ext only inspects the final path element, so dots in parent
	// directory names are left untouched.
	relPath = strings.TrimSuffix(relPath, filepath.Ext(relPath))

	return prefix + strings.ReplaceAll(relPath, string(filepath.Separator), ":")
}
156
157func createCommandHandler(id, desc, content string) func(Command) tea.Cmd {
158 return func(cmd Command) tea.Cmd {
159 args := extractArgNames(content)
160
161 if len(args) == 0 {
162 return util.CmdHandler(CommandRunCustomMsg{
163 Content: content,
164 })
165 }
166 return util.CmdHandler(ShowArgumentsDialogMsg{
167 CommandID: id,
168 Description: desc,
169 ArgNames: args,
170 OnSubmit: func(args map[string]string) tea.Cmd {
171 return execUserPrompt(content, args)
172 },
173 })
174 }
175}
176
177func execUserPrompt(content string, args map[string]string) tea.Cmd {
178 return func() tea.Msg {
179 for name, value := range args {
180 placeholder := "$" + name
181 content = strings.ReplaceAll(content, placeholder, value)
182 }
183 return CommandRunCustomMsg{
184 Content: content,
185 }
186 }
187}
188
189func extractArgNames(content string) []string {
190 matches := namedArgPattern.FindAllStringSubmatch(content, -1)
191 if len(matches) == 0 {
192 return nil
193 }
194
195 seen := make(map[string]bool)
196 var args []string
197
198 for _, match := range matches {
199 arg := match[1]
200 if !seen[arg] {
201 seen[arg] = true
202 args = append(args, arg)
203 }
204 }
205
206 return args
207}
208
// ensureDir makes sure path exists as a directory, creating it and any
// missing parents with mode 0o755.
//
// It returns nil when the directory already exists. It returns an error when
// the directory cannot be created or when path (or one of its parents)
// exists but is not a directory, instead of silently reporting success.
func ensureDir(path string) error {
	// MkdirAll is already a no-op for an existing directory, so no separate
	// existence check is needed.
	return os.MkdirAll(path, 0o755)
}
215
// isMarkdownFile reports whether name ends in ".md", ignoring letter case.
func isMarkdownFile(name string) bool {
	const ext = ".md"
	if len(name) < len(ext) {
		return false
	}
	return strings.EqualFold(name[len(name)-len(ext):], ext)
}
219
// CommandRunCustomMsg is the message emitted when a custom command is run.
type CommandRunCustomMsg struct {
	// Content is the command file's text, with argument placeholders replaced
	// by the submitted values when the command declares any.
	Content string
}
223
224func loadMCPPrompts() []Command {
225 var commands []Command
226 for mcpName, prompts := range mcp.Prompts() {
227 for _, prompt := range prompts {
228 key := mcpName + ":" + prompt.Name
229 commands = append(commands, Command{
230 ID: key,
231 Title: cmp.Or(prompt.Title, prompt.Name),
232 Description: prompt.Description,
233 Handler: createMCPPromptHandler(mcpName, prompt.Name, prompt),
234 })
235 }
236 }
237
238 return commands
239}
240
241func createMCPPromptHandler(mcpName, promptName string, prompt *mcp.Prompt) func(Command) tea.Cmd {
242 return func(cmd Command) tea.Cmd {
243 if len(prompt.Arguments) == 0 {
244 return execMCPPrompt(mcpName, promptName, nil)
245 }
246 return util.CmdHandler(ShowMCPPromptArgumentsDialogMsg{
247 Prompt: prompt,
248 OnSubmit: func(args map[string]string) tea.Cmd {
249 return execMCPPrompt(mcpName, promptName, args)
250 },
251 })
252 }
253}
254
255func execMCPPrompt(clientName, promptName string, args map[string]string) tea.Cmd {
256 return func() tea.Msg {
257 ctx := context.Background()
258 result, err := mcp.GetPromptMessages(ctx, clientName, promptName, args)
259 if err != nil {
260 return util.ReportError(err)
261 }
262
263 return chat.SendMsg{
264 Text: strings.Join(result, " "),
265 }
266 }
267}
268
// ShowMCPPromptArgumentsDialogMsg is emitted when an MCP prompt that declares
// arguments is selected, so that values for them can be collected.
type ShowMCPPromptArgumentsDialogMsg struct {
	// Prompt is the MCP prompt whose arguments are being requested.
	Prompt *mcp.Prompt

	// OnSubmit receives the argument values keyed by name and returns the
	// command to run with them.
	OnSubmit func(arg map[string]string) tea.Cmd
}