1package commands
2
import (
	"context"
	"io/fs"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strings"
	"time"

	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/home"
	"github.com/charmbracelet/crush/internal/skills"
)
17
// namedArgPattern matches "$NAME" placeholders in custom command files: a "$"
// followed by an uppercase letter and then any run of uppercase letters,
// digits, or underscores. Capture group 1 holds the name without the "$".
var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)

// Prefixes prepended to command IDs to record where a command was loaded from.
const (
	// userCommandPrefix marks commands from the user's home/config command
	// directories and skills from the global skills directories.
	userCommandPrefix    = "user:"
	// projectCommandPrefix marks commands from the project data directory and
	// skills from the project skills directories.
	projectCommandPrefix = "project:"
)
24
// Argument represents a command argument with its metadata.
type Argument struct {
	// ID identifies the argument: the placeholder name (without "$") for
	// custom commands, or the argument name for MCP prompts.
	ID          string
	// Title is the display label. Custom command placeholders use the
	// placeholder name; MCP arguments without a title fall back to their name.
	Title       string
	// Description is help text; it is set for MCP prompt arguments and left
	// empty for custom command placeholders.
	Description string
	// Required reports whether a value must be supplied. Placeholders found
	// in custom command files are always marked required.
	Required    bool
}

// MCPPrompt represents a custom command loaded from an MCP server.
type MCPPrompt struct {
	// ID is "<server name>:<prompt name>".
	ID          string
	// Title and Description are copied from the prompt as the server reports it.
	Title       string
	Description string
	// PromptID is the prompt's name on its server.
	PromptID    string
	// ClientID is the name of the MCP server that provides the prompt.
	ClientID    string
	// Arguments lists the prompt's declared arguments.
	Arguments   []Argument
}

// CustomCommand represents a user-defined custom command loaded from markdown files.
type CustomCommand struct {
	// ID is a source prefix ("user:" or "project:") followed by either the
	// colon-joined relative path of the markdown file (extension removed) or
	// the skill's name.
	ID        string
	// Name is set to the same value as ID by the loaders in this file.
	Name      string
	// Content is the command text: the markdown file body, or the skill's
	// instructions for skill-backed commands.
	Content   string
	// Arguments lists the distinct $NAME placeholders found in Content, in
	// first-occurrence order; nil for skill-backed commands.
	Arguments []Argument
	// Skill is set when this command represents a user-invocable skill
	Skill     *skills.Skill
}

// commandSource is a directory to scan for custom command markdown files.
type commandSource struct {
	path   string // directory walked recursively for .md files (case-insensitive)
	prefix string // prepended to the ID of every command loaded from path
}
57
58// LoadCustomCommands loads custom commands from multiple sources including
59// XDG config directory, home directory, and project directory.
60func LoadCustomCommands(cfg *config.Config) ([]CustomCommand, error) {
61 return loadAll(buildCommandSources(cfg))
62}
63
64// LoadSkillCommands loads user-invocable skills as custom commands.
65func LoadSkillCommands() []CustomCommand {
66 var commands []CustomCommand
67
68 // Load from global skills directories with "user:" prefix
69 for _, dir := range config.GlobalSkillsDirs() {
70 commands = append(commands, loadInvocableSkillsFromDir(dir, userCommandPrefix)...)
71 }
72
73 return commands
74}
75
76// LoadProjectSkillCommands loads user-invocable skills from project directories as custom commands.
77func LoadProjectSkillCommands(workingDir string) []CustomCommand {
78 var commands []CustomCommand
79
80 // Load from project skills directories with "project:" prefix
81 for _, dir := range config.ProjectSkillsDir(workingDir) {
82 commands = append(commands, loadInvocableSkillsFromDir(dir, projectCommandPrefix)...)
83 }
84
85 return commands
86}
87
// loadInvocableSkillsFromDir looks in each immediate subdirectory of dir for a
// skills.SkillFileName file and parses it. It returns a CustomCommand for every
// skill marked UserInvocable. Each command's ID and Name are prefix followed by
// the skill's name, its Content is the skill's instructions, and it has no
// Arguments.
//
// Loading is best-effort. A missing or unreadable dir yields nil. Entries that
// are not directories, or whose skill file fails to parse, are skipped.
// Symbolic links that resolve to directories are followed, so a skill linked
// into dir is treated the same as one stored there directly.
func loadInvocableSkillsFromDir(dir, prefix string) []CustomCommand {
	// os.ReadDir also fails for a non-existent dir, so no separate Stat is needed.
	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil
	}

	var commands []CustomCommand
	for _, entry := range entries {
		if !skillEntryIsDir(dir, entry) {
			continue
		}

		skillPath := filepath.Join(dir, entry.Name(), skills.SkillFileName)
		skill, err := skills.Parse(skillPath)
		if err != nil || !skill.UserInvocable {
			continue
		}

		name := prefix + skill.Name
		commands = append(commands, CustomCommand{
			ID:        name,
			Name:      name,
			Content:   skill.Instructions,
			Arguments: nil,
			Skill:     skill,
		})
	}

	return commands
}

// skillEntryIsDir reports whether entry (a child of dir) is a directory,
// resolving a symbolic link to check what it points at. DirEntry.IsDir alone
// reports false for every symlink.
func skillEntryIsDir(dir string, entry fs.DirEntry) bool {
	if entry.IsDir() {
		return true
	}
	if entry.Type()&fs.ModeSymlink == 0 {
		return false
	}
	info, err := os.Stat(filepath.Join(dir, entry.Name()))
	return err == nil && info.IsDir()
}
127
// LoadMCPPrompts returns every prompt currently exposed by the connected MCP
// servers as an MCPPrompt. Each prompt's ID is "<server>:<prompt name>". An
// argument without a title uses its name as the title.
//
// The result is ordered by server name. Prompts from the same server keep the
// order in which mcp.Prompts() yielded them. The error is currently always nil.
func LoadMCPPrompts() ([]MCPPrompt, error) {
	var commands []MCPPrompt
	for mcpName, prompts := range mcp.Prompts() {
		for _, prompt := range prompts {
			key := mcpName + ":" + prompt.Name
			var args []Argument
			for _, arg := range prompt.Arguments {
				title := arg.Title
				if title == "" {
					title = arg.Name
				}
				args = append(args, Argument{
					ID:          arg.Name,
					Title:       title,
					Description: arg.Description,
					Required:    arg.Required,
				})
			}
			commands = append(commands, MCPPrompt{
				ID:          key,
				Title:       prompt.Title,
				Description: prompt.Description,
				PromptID:    prompt.Name,
				ClientID:    mcpName,
				Arguments:   args,
			})
		}
	}

	// NOTE(review): mcp.Prompts() is ranged as key/value pairs. If it is backed
	// by a map, its iteration order changes from call to call and the command
	// list would shuffle. A stable sort on the server name makes the result
	// deterministic while preserving each server's own prompt order.
	sort.SliceStable(commands, func(i, j int) bool {
		return commands[i].ClientID < commands[j].ClientID
	})
	return commands, nil
}
159
// buildCommandSources lists the directories scanned for custom commands, in
// load order. The two home-based locations are tagged with the user prefix.
// The commands directory under cfg.Options.DataDirectory is tagged with the
// project prefix.
func buildCommandSources(cfg *config.Config) []commandSource {
	xdgCommands := filepath.Join(home.Config(), "crush", "commands")
	dotCrushCommands := filepath.Join(home.Dir(), ".crush", "commands")
	projectCommands := filepath.Join(cfg.Options.DataDirectory, "commands")

	return []commandSource{
		{path: xdgCommands, prefix: userCommandPrefix},
		{path: dotCrushCommands, prefix: userCommandPrefix},
		{path: projectCommands, prefix: projectCommandPrefix},
	}
}
176
177func loadAll(sources []commandSource) ([]CustomCommand, error) {
178 var commands []CustomCommand
179
180 for _, source := range sources {
181 if cmds, err := loadFromSource(source); err == nil {
182 commands = append(commands, cmds...)
183 }
184 }
185
186 return commands, nil
187}
188
// loadFromSource recursively walks source.path and loads every markdown file
// beneath it as a CustomCommand whose ID starts with source.prefix.
//
// A directory that does not exist yields no commands and no error. Loading is
// best-effort:
//   - If source.path is itself a symbolic link (for example a dotfiles-managed
//     commands directory), its target is walked. filepath.WalkDir on its own
//     does not follow a symlinked root and would find nothing.
//   - A directory that cannot be read, or a markdown file that cannot be
//     loaded, is skipped instead of aborting the walk. One bad entry therefore
//     does not discard every other command from the same source.
func loadFromSource(source commandSource) ([]CustomCommand, error) {
	if _, err := os.Stat(source.path); os.IsNotExist(err) {
		return nil, nil
	}

	// filepath.WalkDir Lstat's its root, so resolve a symlinked root first.
	// Fall back to the configured path if resolution fails.
	root := source.path
	if resolved, err := filepath.EvalSymlinks(root); err == nil {
		root = resolved
	}

	var commands []CustomCommand

	err := filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			// WalkDir reports an unreadable root or directory here. Returning
			// nil skips just that entry (keeping any children it did read)
			// instead of abandoning the whole source.
			return nil
		}
		if d.IsDir() || !isMarkdownFile(d.Name()) {
			return nil
		}

		// IDs are built relative to root, so they are the same whether or not
		// the configured path was a symlink.
		cmd, err := loadCommand(path, root, source.prefix)
		if err != nil {
			return nil // Skip files that cannot be read.
		}

		commands = append(commands, cmd)
		return nil
	})

	return commands, err
}
212
// loadCommand reads the markdown file at path and turns it into a
// CustomCommand. The command's ID (also used as its Name) is derived from the
// file's location relative to baseDir. Its Arguments are the $NAME
// placeholders in the file. A read error is returned with a zero CustomCommand.
func loadCommand(path, baseDir, prefix string) (CustomCommand, error) {
	raw, err := os.ReadFile(path)
	if err != nil {
		return CustomCommand{}, err
	}

	body := string(raw)
	id := buildCommandID(path, baseDir, prefix)
	cmd := CustomCommand{
		ID:        id,
		Name:      id,
		Content:   body,
		Arguments: extractArgNames(body),
	}
	return cmd, nil
}
228
229func extractArgNames(content string) []Argument {
230 matches := namedArgPattern.FindAllStringSubmatch(content, -1)
231 if len(matches) == 0 {
232 return nil
233 }
234
235 seen := make(map[string]bool)
236 var args []Argument
237
238 for _, match := range matches {
239 arg := match[1]
240 if !seen[arg] {
241 seen[arg] = true
242 // for normal custom commands, all args are required
243 args = append(args, Argument{ID: arg, Title: arg, Required: true})
244 }
245 }
246
247 return args
248}
249
// buildCommandID derives a command ID from a file's path relative to baseDir.
// Path separators become ":", the final element loses its extension, and
// prefix is prepended. For example, "<baseDir>/git/commit.md" with prefix
// "user:" becomes "user:git:commit". If path cannot be made relative to
// baseDir, the result degenerates to the bare prefix.
func buildCommandID(path, baseDir, prefix string) string {
	rel, _ := filepath.Rel(baseDir, path)
	segments := strings.Split(rel, string(filepath.Separator))

	// strings.Split with a non-empty separator always yields at least one element.
	last := len(segments) - 1
	segments[last] = strings.TrimSuffix(segments[last], filepath.Ext(segments[last]))

	return prefix + strings.Join(segments, ":")
}
262
// isMarkdownFile reports whether name has a ".md" extension, ignoring case.
func isMarkdownFile(name string) bool {
	return strings.EqualFold(filepath.Ext(name), ".md")
}
266
267func GetMCPPrompt(cfg *config.ConfigStore, clientID, promptID string, args map[string]string) (string, error) {
268 // Create a context with timeout since tea.Cmd doesn't support context passing.
269 // The MCP client has its own timeout, but this provides an additional safeguard.
270 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
271 defer cancel()
272
273 result, err := mcp.GetPromptMessages(ctx, cfg, clientID, promptID, args)
274 if err != nil {
275 return "", err
276 }
277 return strings.Join(result, " "), nil
278}