1package commands
2
3import (
4 "context"
5 "io/fs"
6 "os"
7 "path/filepath"
8 "regexp"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/home"
15)
16
// namedArgPattern matches named placeholders such as $NAME in command content:
// a "$" followed by an uppercase letter and then any run of uppercase letters,
// digits, or underscores. The first capture group is the name without the "$".
var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
18
// Prefixes prepended to the IDs of commands loaded from disk.
const (
	// userCommandPrefix is used for commands found in the XDG config
	// directory and in the home directory.
	userCommandPrefix = "user:"
	// projectCommandPrefix is used for commands found under the configured
	// data directory.
	projectCommandPrefix = "project:"
)
23
// Argument represents a command argument with its metadata.
type Argument struct {
	// ID is the argument's name: the MCP argument name, or the placeholder
	// name (without the leading "$") for file-based commands.
	ID string
	// Title is the argument's label. For MCP prompts it falls back to the
	// argument name when no title is supplied; for file-based commands it
	// equals ID.
	Title string
	// Description is copied from the MCP prompt argument; it is left empty
	// for file-based commands.
	Description string
	// Required reports whether a value must be supplied. Arguments of
	// file-based commands are always required.
	Required bool
}
31
// MCPPrompt represents a custom command loaded from an MCP server.
type MCPPrompt struct {
	// ID is "<mcp name>:<prompt name>".
	ID string
	// Title is the prompt's title as reported by the MCP server.
	Title string
	// Description is the prompt's description as reported by the MCP server.
	Description string
	// PromptID is the prompt's name on the MCP server.
	PromptID string
	// ClientID is the name of the MCP the prompt came from.
	ClientID string
	// Arguments lists the arguments the prompt accepts; it is nil when the
	// prompt declares none.
	Arguments []Argument
}
41
// CustomCommand represents a user-defined custom command loaded from markdown files.
type CustomCommand struct {
	// ID is the source prefix followed by the file's path relative to the
	// commands directory, with path elements joined by ":" and the file
	// extension removed.
	ID string
	// Name is currently set to the same value as ID.
	Name string
	// Content is the full text of the markdown file.
	Content string
	// Arguments holds one entry per distinct $NAME placeholder found in
	// Content, in order of first appearance.
	Arguments []Argument
}
49
// commandSource describes one directory that is scanned for command files.
type commandSource struct {
	// path is the directory walked for markdown files.
	path string
	// prefix is prepended to the ID of every command found under path.
	prefix string
}
54
55// LoadCustomCommands loads custom commands from multiple sources including
56// XDG config directory, home directory, and project directory.
57func LoadCustomCommands(cfg *config.Config) ([]CustomCommand, error) {
58 return loadAll(buildCommandSources(cfg))
59}
60
61// LoadMCPPrompts loads custom commands from available MCP servers.
62func LoadMCPPrompts() ([]MCPPrompt, error) {
63 var commands []MCPPrompt
64 for mcpName, prompts := range mcp.Prompts() {
65 for _, prompt := range prompts {
66 key := mcpName + ":" + prompt.Name
67 var args []Argument
68 for _, arg := range prompt.Arguments {
69 title := arg.Title
70 if title == "" {
71 title = arg.Name
72 }
73 args = append(args, Argument{
74 ID: arg.Name,
75 Title: title,
76 Description: arg.Description,
77 Required: arg.Required,
78 })
79 }
80 commands = append(commands, MCPPrompt{
81 ID: key,
82 Title: prompt.Title,
83 Description: prompt.Description,
84 PromptID: prompt.Name,
85 ClientID: mcpName,
86 Arguments: args,
87 })
88 }
89 }
90 return commands, nil
91}
92
93func buildCommandSources(cfg *config.Config) []commandSource {
94 var sources []commandSource
95
96 // XDG config directory
97 if dir := getXDGCommandsDir(); dir != "" {
98 sources = append(sources, commandSource{
99 path: dir,
100 prefix: userCommandPrefix,
101 })
102 }
103
104 // Home directory
105 if home := home.Dir(); home != "" {
106 sources = append(sources, commandSource{
107 path: filepath.Join(home, ".crush", "commands"),
108 prefix: userCommandPrefix,
109 })
110 }
111
112 // Project directory
113 sources = append(sources, commandSource{
114 path: filepath.Join(cfg.Options.DataDirectory, "commands"),
115 prefix: projectCommandPrefix,
116 })
117
118 return sources
119}
120
121func loadAll(sources []commandSource) ([]CustomCommand, error) {
122 var commands []CustomCommand
123
124 for _, source := range sources {
125 if cmds, err := loadFromSource(source); err == nil {
126 commands = append(commands, cmds...)
127 }
128 }
129
130 return commands, nil
131}
132
133func loadFromSource(source commandSource) ([]CustomCommand, error) {
134 if _, err := os.Stat(source.path); os.IsNotExist(err) {
135 return nil, nil
136 }
137
138 var commands []CustomCommand
139
140 err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
141 if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
142 return err
143 }
144
145 cmd, err := loadCommand(path, source.path, source.prefix)
146 if err != nil {
147 return nil // Skip invalid files
148 }
149
150 commands = append(commands, cmd)
151 return nil
152 })
153
154 return commands, err
155}
156
157func loadCommand(path, baseDir, prefix string) (CustomCommand, error) {
158 content, err := os.ReadFile(path)
159 if err != nil {
160 return CustomCommand{}, err
161 }
162
163 id := buildCommandID(path, baseDir, prefix)
164
165 return CustomCommand{
166 ID: id,
167 Name: id,
168 Content: string(content),
169 Arguments: extractArgNames(string(content)),
170 }, nil
171}
172
173func extractArgNames(content string) []Argument {
174 matches := namedArgPattern.FindAllStringSubmatch(content, -1)
175 if len(matches) == 0 {
176 return nil
177 }
178
179 seen := make(map[string]bool)
180 var args []Argument
181
182 for _, match := range matches {
183 arg := match[1]
184 if !seen[arg] {
185 seen[arg] = true
186 // for normal custom commands, all args are required
187 args = append(args, Argument{ID: arg, Title: arg, Required: true})
188 }
189 }
190
191 return args
192}
193
// buildCommandID derives a command ID from a markdown file's location.
//
// The path is taken relative to baseDir, the file extension is stripped from
// the final element, and the remaining elements are joined with ":" and
// prefixed with prefix. For example, with baseDir "/cmds" and prefix "user:",
// "/cmds/git/commit.md" becomes "user:git:commit".
//
// If path cannot be expressed relative to baseDir, the ID falls back to the
// file's base name so the command never ends up with only the prefix as its
// ID.
func buildCommandID(path, baseDir, prefix string) string {
	relPath, err := filepath.Rel(baseDir, path)
	if err != nil {
		// Rel fails when, for example, only one of the two paths is absolute.
		relPath = filepath.Base(path)
	}

	parts := strings.Split(relPath, string(filepath.Separator))

	// strings.Split with a non-empty separator always yields at least one
	// element, so indexing the last one is safe.
	last := len(parts) - 1
	parts[last] = strings.TrimSuffix(parts[last], filepath.Ext(parts[last]))

	return prefix + strings.Join(parts, ":")
}
206
207func getXDGCommandsDir() string {
208 xdgHome := os.Getenv("XDG_CONFIG_HOME")
209 if xdgHome == "" {
210 if home := home.Dir(); home != "" {
211 xdgHome = filepath.Join(home, ".config")
212 }
213 }
214 if xdgHome != "" {
215 return filepath.Join(xdgHome, "crush", "commands")
216 }
217 return ""
218}
219
// isMarkdownFile reports whether name ends in ".md", ignoring case.
func isMarkdownFile(name string) bool {
	const ext = ".md"
	if len(name) < len(ext) {
		return false
	}
	return strings.EqualFold(name[len(name)-len(ext):], ext)
}
223
224func GetMCPPrompt(cfg *config.ConfigStore, clientID, promptID string, args map[string]string) (string, error) {
225 // Create a context with timeout since tea.Cmd doesn't support context passing.
226 // The MCP client has its own timeout, but this provides an additional safeguard.
227 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
228 defer cancel()
229
230 result, err := mcp.GetPromptMessages(ctx, cfg, clientID, promptID, args)
231 if err != nil {
232 return "", err
233 }
234 return strings.Join(result, " "), nil
235}