1package commands
2
3import (
4 "context"
5 "io/fs"
6 "os"
7 "path/filepath"
8 "regexp"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/home"
15 "github.com/charmbracelet/crush/internal/skills"
16)
17
// Prefixes prepended to command IDs to record which kind of source a
// command was loaded from.
const (
	userCommandPrefix    = "user:"
	projectCommandPrefix = "project:"
)

// namedArgPattern matches placeholders of the form $NAME, where NAME starts
// with an upper-case ASCII letter and continues with upper-case letters,
// digits, or underscores. The name, without the dollar sign, is capture
// group 1.
var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
24
// Argument describes one input a command expects, together with the
// metadata attached to it: an identifier, a title, a description, and
// whether a value is required.
type Argument struct {
	ID, Title, Description string
	Required               bool
}
32
33// MCPPrompt represents a custom command loaded from an MCP server.
34type MCPPrompt struct {
35 ID string
36 Title string
37 Description string
38 PromptID string
39 ClientID string
40 Arguments []Argument
41}
42
43// CustomCommand represents a user-defined custom command loaded from markdown files.
44type CustomCommand struct {
45 ID string
46 Name string
47 Content string
48 Arguments []Argument
49 // Skill is set when this command represents a user-invocable skill
50 Skill *skills.Skill
51}
52
// commandSource pairs a directory to scan for command files with the
// prefix given to the IDs of the commands found there.
type commandSource struct {
	path, prefix string
}
57
58// LoadCustomCommands loads custom commands from multiple sources including
59// XDG config directory, home directory, and project directory.
60func LoadCustomCommands(cfg *config.Config) ([]CustomCommand, error) {
61 return loadAll(buildCommandSources(cfg))
62}
63
64// FromSkillCatalog converts user-invocable catalog entries into custom
65// command entries for the command palette.
66func FromSkillCatalog(entries []skills.CatalogEntry) []CustomCommand {
67 commands := make([]CustomCommand, 0, len(entries))
68 for _, entry := range entries {
69 if !entry.UserInvocable {
70 continue
71 }
72 name := entry.Label
73 if name == "" {
74 name = userCommandPrefix + entry.Name
75 }
76 commands = append(commands, CustomCommand{
77 ID: name,
78 Name: name,
79 Skill: &skills.Skill{
80 Name: entry.Name,
81 Description: entry.Description,
82 SkillFilePath: entry.ID,
83 },
84 })
85 }
86 return commands
87}
88
89// LoadMCPPrompts loads custom commands from available MCP servers.
90func LoadMCPPrompts() ([]MCPPrompt, error) {
91 var commands []MCPPrompt
92 for mcpName, prompts := range mcp.Prompts() {
93 for _, prompt := range prompts {
94 key := mcpName + ":" + prompt.Name
95 var args []Argument
96 for _, arg := range prompt.Arguments {
97 title := arg.Title
98 if title == "" {
99 title = arg.Name
100 }
101 args = append(args, Argument{
102 ID: arg.Name,
103 Title: title,
104 Description: arg.Description,
105 Required: arg.Required,
106 })
107 }
108 commands = append(commands, MCPPrompt{
109 ID: key,
110 Title: prompt.Title,
111 Description: prompt.Description,
112 PromptID: prompt.Name,
113 ClientID: mcpName,
114 Arguments: args,
115 })
116 }
117 }
118 return commands, nil
119}
120
121func buildCommandSources(cfg *config.Config) []commandSource {
122 return []commandSource{
123 {
124 path: filepath.Join(home.Config(), "crush", "commands"),
125 prefix: userCommandPrefix,
126 },
127 {
128 path: filepath.Join(home.Dir(), ".crush", "commands"),
129 prefix: userCommandPrefix,
130 },
131 {
132 path: filepath.Join(cfg.Options.DataDirectory, "commands"),
133 prefix: projectCommandPrefix,
134 },
135 }
136}
137
138func loadAll(sources []commandSource) ([]CustomCommand, error) {
139 var commands []CustomCommand
140
141 for _, source := range sources {
142 if cmds, err := loadFromSource(source); err == nil {
143 commands = append(commands, cmds...)
144 }
145 }
146
147 return commands, nil
148}
149
150func loadFromSource(source commandSource) ([]CustomCommand, error) {
151 if _, err := os.Stat(source.path); os.IsNotExist(err) {
152 return nil, nil
153 }
154
155 var commands []CustomCommand
156
157 err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
158 if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
159 return err
160 }
161
162 cmd, err := loadCommand(path, source.path, source.prefix)
163 if err != nil {
164 return nil // Skip invalid files
165 }
166
167 commands = append(commands, cmd)
168 return nil
169 })
170
171 return commands, err
172}
173
174func loadCommand(path, baseDir, prefix string) (CustomCommand, error) {
175 content, err := os.ReadFile(path)
176 if err != nil {
177 return CustomCommand{}, err
178 }
179
180 id := buildCommandID(path, baseDir, prefix)
181
182 return CustomCommand{
183 ID: id,
184 Name: id,
185 Content: string(content),
186 Arguments: extractArgNames(string(content)),
187 }, nil
188}
189
190func extractArgNames(content string) []Argument {
191 matches := namedArgPattern.FindAllStringSubmatch(content, -1)
192 if len(matches) == 0 {
193 return nil
194 }
195
196 seen := make(map[string]bool)
197 var args []Argument
198
199 for _, match := range matches {
200 arg := match[1]
201 if !seen[arg] {
202 seen[arg] = true
203 // for normal custom commands, all args are required
204 args = append(args, Argument{ID: arg, Title: arg, Required: true})
205 }
206 }
207
208 return args
209}
210
// buildCommandID derives a command ID from the location of a command file.
//
// The path is taken relative to baseDir, the extension of the final element
// is removed, and the remaining path elements are joined with ":" and
// prefixed with prefix. For example "<baseDir>/git/commit.md" with prefix
// "user:" becomes "user:git:commit".
//
// If path cannot be expressed relative to baseDir, the ID is built from the
// file's base name alone, so the result is never just the bare prefix.
func buildCommandID(path, baseDir, prefix string) string {
	relPath, err := filepath.Rel(baseDir, path)
	if err != nil {
		// Fall back to the file name so the command keeps a usable ID.
		relPath = filepath.Base(path)
	}

	// filepath.Ext only looks at the final path element, so this strips the
	// extension of the file name and leaves directory names untouched.
	relPath = strings.TrimSuffix(relPath, filepath.Ext(relPath))

	return prefix + strings.ReplaceAll(relPath, string(filepath.Separator), ":")
}
223
// isMarkdownFile reports whether name ends in the ".md" extension, compared
// without regard to letter case.
func isMarkdownFile(name string) bool {
	ext := filepath.Ext(name)
	return strings.EqualFold(ext, ".md")
}
227
228func GetMCPPrompt(cfg *config.ConfigStore, clientID, promptID string, args map[string]string) (string, error) {
229 // Create a context with timeout since tea.Cmd doesn't support context passing.
230 // The MCP client has its own timeout, but this provides an additional safeguard.
231 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
232 defer cancel()
233
234 result, err := mcp.GetPromptMessages(ctx, cfg, clientID, promptID, args)
235 if err != nil {
236 return "", err
237 }
238 return strings.Join(result, " "), nil
239}