1package commands
2
3import (
4 "context"
5 "io/fs"
6 "os"
7 "path/filepath"
8 "regexp"
9 "strings"
10
11 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/home"
14)
15
// namedArgPattern matches `$NAME` placeholders in custom command files: a
// dollar sign followed by an uppercase letter and then any run of uppercase
// letters, digits, or underscores. Submatch 1 is the bare name.
var namedArgPattern = regexp.MustCompile("\\$([A-Z][A-Z0-9_]*)")
17
// ID prefixes recording which kind of directory a custom command came from.
const (
	// projectCommandPrefix is prepended to commands found under the
	// configured data directory.
	projectCommandPrefix = "project:"
	// userCommandPrefix is prepended to commands found in user-level
	// directories (the XDG config location and ~/.crush).
	userCommandPrefix = "user:"
)
22
// Argument describes one parameter a command accepts. ID is the name used to
// look the value up, Title and Description are descriptive text, and Required
// marks the argument as mandatory.
type Argument struct {
	ID, Title, Description string
	Required               bool
}
30
31// MCPPrompt represents a custom command loaded from an MCP server.
32type MCPPrompt struct {
33 ID string
34 Title string
35 Description string
36 PromptID string
37 ClientID string
38 Arguments []Argument
39}
40
41// CustomCommand represents a user-defined custom command loaded from markdown files.
42type CustomCommand struct {
43 ID string
44 Name string
45 Content string
46 Arguments []Argument
47}
48
// commandSource pairs a directory to scan for command files with the prefix
// given to the IDs of the commands found there.
type commandSource struct {
	path, prefix string
}
53
54// LoadCustomCommands loads custom commands from multiple sources including
55// XDG config directory, home directory, and project directory.
56func LoadCustomCommands(cfg *config.Config) ([]CustomCommand, error) {
57 return loadAll(buildCommandSources(cfg))
58}
59
60// LoadMCPPrompts loads custom commands from available MCP servers.
61func LoadMCPPrompts() ([]MCPPrompt, error) {
62 var commands []MCPPrompt
63 for mcpName, prompts := range mcp.Prompts() {
64 for _, prompt := range prompts {
65 key := mcpName + ":" + prompt.Name
66 var args []Argument
67 for _, arg := range prompt.Arguments {
68 title := arg.Title
69 if title == "" {
70 title = arg.Name
71 }
72 args = append(args, Argument{
73 ID: arg.Name,
74 Title: title,
75 Description: arg.Description,
76 Required: arg.Required,
77 })
78 }
79 commands = append(commands, MCPPrompt{
80 ID: key,
81 Title: prompt.Title,
82 Description: prompt.Description,
83 PromptID: prompt.Name,
84 ClientID: mcpName,
85 Arguments: args,
86 })
87 }
88 }
89 return commands, nil
90}
91
92func buildCommandSources(cfg *config.Config) []commandSource {
93 var sources []commandSource
94
95 // XDG config directory
96 if dir := getXDGCommandsDir(); dir != "" {
97 sources = append(sources, commandSource{
98 path: dir,
99 prefix: userCommandPrefix,
100 })
101 }
102
103 // Home directory
104 if home := home.Dir(); home != "" {
105 sources = append(sources, commandSource{
106 path: filepath.Join(home, ".crush", "commands"),
107 prefix: userCommandPrefix,
108 })
109 }
110
111 // Project directory
112 sources = append(sources, commandSource{
113 path: filepath.Join(cfg.Options.DataDirectory, "commands"),
114 prefix: projectCommandPrefix,
115 })
116
117 return sources
118}
119
120func loadAll(sources []commandSource) ([]CustomCommand, error) {
121 var commands []CustomCommand
122
123 for _, source := range sources {
124 if cmds, err := loadFromSource(source); err == nil {
125 commands = append(commands, cmds...)
126 }
127 }
128
129 return commands, nil
130}
131
132func loadFromSource(source commandSource) ([]CustomCommand, error) {
133 if err := ensureDir(source.path); err != nil {
134 return nil, err
135 }
136
137 var commands []CustomCommand
138
139 err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
140 if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
141 return err
142 }
143
144 cmd, err := loadCommand(path, source.path, source.prefix)
145 if err != nil {
146 return nil // Skip invalid files
147 }
148
149 commands = append(commands, cmd)
150 return nil
151 })
152
153 return commands, err
154}
155
156func loadCommand(path, baseDir, prefix string) (CustomCommand, error) {
157 content, err := os.ReadFile(path)
158 if err != nil {
159 return CustomCommand{}, err
160 }
161
162 id := buildCommandID(path, baseDir, prefix)
163
164 return CustomCommand{
165 ID: id,
166 Name: id,
167 Content: string(content),
168 Arguments: extractArgNames(string(content)),
169 }, nil
170}
171
172func extractArgNames(content string) []Argument {
173 matches := namedArgPattern.FindAllStringSubmatch(content, -1)
174 if len(matches) == 0 {
175 return nil
176 }
177
178 seen := make(map[string]bool)
179 var args []Argument
180
181 for _, match := range matches {
182 arg := match[1]
183 if !seen[arg] {
184 seen[arg] = true
185 // for normal custom commands, all args are required
186 args = append(args, Argument{ID: arg, Title: arg, Required: true})
187 }
188 }
189
190 return args
191}
192
// buildCommandID turns a command file path into its ID. It takes path
// relative to baseDir, drops the final extension, replaces path separators
// with ":", and prepends prefix. For example, baseDir/git/commit.md with
// prefix "user:" becomes "user:git:commit".
func buildCommandID(path, baseDir, prefix string) string {
	// A Rel failure is ignored on purpose: relPath is then "" and the ID
	// collapses to the bare prefix.
	relPath, _ := filepath.Rel(baseDir, path)
	withoutExt := strings.TrimSuffix(relPath, filepath.Ext(relPath))
	return prefix + strings.ReplaceAll(withoutExt, string(filepath.Separator), ":")
}
205
206func getXDGCommandsDir() string {
207 xdgHome := os.Getenv("XDG_CONFIG_HOME")
208 if xdgHome == "" {
209 if home := home.Dir(); home != "" {
210 xdgHome = filepath.Join(home, ".config")
211 }
212 }
213 if xdgHome != "" {
214 return filepath.Join(xdgHome, "crush", "commands")
215 }
216 return ""
217}
218
// ensureDir creates path, including any missing parents, with mode 0755 when
// os.Stat reports it does not exist. Any other Stat outcome, including
// errors other than not-exist, is left alone and yields nil.
func ensureDir(path string) error {
	_, statErr := os.Stat(path)
	if !os.IsNotExist(statErr) {
		return nil
	}
	return os.MkdirAll(path, 0o755)
}
225
// isMarkdownFile reports whether name ends in ".md", ignoring case.
func isMarkdownFile(name string) bool {
	const ext = ".md"
	if len(name) < len(ext) {
		return false
	}
	return strings.EqualFold(name[len(name)-len(ext):], ext)
}
229
230func GetMCPPrompt(clientID, promptID string, args map[string]string) (string, error) {
231 // TODO: we should pass the context down
232 result, err := mcp.GetPromptMessages(context.Background(), clientID, promptID, args)
233 if err != nil {
234 return "", err
235 }
236 return strings.Join(result, " "), nil
237}