1package commands
2
3import (
4 "io/fs"
5 "os"
6 "path/filepath"
7 "regexp"
8 "strings"
9
10 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
11 "github.com/charmbracelet/crush/internal/config"
12 "github.com/charmbracelet/crush/internal/home"
13)
14
// Prefixes prepended to the IDs of commands loaded from the user-level
// directories and from the project directory, respectively.
const (
	userCommandPrefix    = "user:"
	projectCommandPrefix = "project:"
)

// namedArgPattern matches $NAME placeholders in command bodies: a '$'
// followed by an uppercase ASCII letter and then any run of uppercase
// letters, digits, and underscores. Submatch 1 is the name without the '$'.
var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
21
type (
	// Argument describes one named parameter of a command and whether a
	// value for it must be supplied.
	Argument struct {
		Name     string
		Required bool
	}

	// MCPCustomCommand is a custom command backed by a prompt exposed by an
	// MCP server.
	MCPCustomCommand struct {
		ID        string     // "<client>:<prompt name>"
		Name      string     // the prompt's name
		Client    string     // name of the MCP server providing the prompt
		Arguments []Argument // the prompt's declared arguments
	}

	// CustomCommand is a user-defined command read from a markdown file.
	CustomCommand struct {
		ID        string     // prefixed, ':'-separated path-derived identifier
		Name      string     // display name
		Content   string     // raw markdown body of the file
		Arguments []Argument // $NAME placeholders found in Content
	}

	// commandSource is a directory scanned for markdown command files,
	// together with the prefix given to the IDs of commands found there.
	commandSource struct {
		path   string
		prefix string
	}
)
48
49// LoadCustomCommands loads custom commands from multiple sources including
50// XDG config directory, home directory, and project directory.
51func LoadCustomCommands(cfg *config.Config) ([]CustomCommand, error) {
52 return loadAll(buildCommandSources(cfg))
53}
54
55// LoadMCPCustomCommands loads custom commands from available MCP servers.
56func LoadMCPCustomCommands() ([]MCPCustomCommand, error) {
57 var commands []MCPCustomCommand
58 for mcpName, prompts := range mcp.Prompts() {
59 for _, prompt := range prompts {
60 key := mcpName + ":" + prompt.Name
61 var args []Argument
62 for _, arg := range prompt.Arguments {
63 args = append(args, Argument{Name: arg.Name, Required: arg.Required})
64 }
65
66 commands = append(commands, MCPCustomCommand{
67 ID: key,
68 Name: prompt.Name,
69 Client: mcpName,
70 Arguments: args,
71 })
72 }
73 }
74 return commands, nil
75}
76
77func buildCommandSources(cfg *config.Config) []commandSource {
78 var sources []commandSource
79
80 // XDG config directory
81 if dir := getXDGCommandsDir(); dir != "" {
82 sources = append(sources, commandSource{
83 path: dir,
84 prefix: userCommandPrefix,
85 })
86 }
87
88 // Home directory
89 if home := home.Dir(); home != "" {
90 sources = append(sources, commandSource{
91 path: filepath.Join(home, ".crush", "commands"),
92 prefix: userCommandPrefix,
93 })
94 }
95
96 // Project directory
97 sources = append(sources, commandSource{
98 path: filepath.Join(cfg.Options.DataDirectory, "commands"),
99 prefix: projectCommandPrefix,
100 })
101
102 return sources
103}
104
105func loadAll(sources []commandSource) ([]CustomCommand, error) {
106 var commands []CustomCommand
107
108 for _, source := range sources {
109 if cmds, err := loadFromSource(source); err == nil {
110 commands = append(commands, cmds...)
111 }
112 }
113
114 return commands, nil
115}
116
117func loadFromSource(source commandSource) ([]CustomCommand, error) {
118 if err := ensureDir(source.path); err != nil {
119 return nil, err
120 }
121
122 var commands []CustomCommand
123
124 err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
125 if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
126 return err
127 }
128
129 cmd, err := loadCommand(path, source.path, source.prefix)
130 if err != nil {
131 return nil // Skip invalid files
132 }
133
134 commands = append(commands, cmd)
135 return nil
136 })
137
138 return commands, err
139}
140
141func loadCommand(path, baseDir, prefix string) (CustomCommand, error) {
142 content, err := os.ReadFile(path)
143 if err != nil {
144 return CustomCommand{}, err
145 }
146
147 id := buildCommandID(path, baseDir, prefix)
148
149 return CustomCommand{
150 ID: id,
151 Name: id,
152 Content: string(content),
153 Arguments: extractArgNames(string(content)),
154 }, nil
155}
156
157func extractArgNames(content string) []Argument {
158 matches := namedArgPattern.FindAllStringSubmatch(content, -1)
159 if len(matches) == 0 {
160 return nil
161 }
162
163 seen := make(map[string]bool)
164 var args []Argument
165
166 for _, match := range matches {
167 arg := match[1]
168 if !seen[arg] {
169 seen[arg] = true
170 // for normal custom commands, all args are required
171 args = append(args, Argument{Name: arg, Required: true})
172 }
173 }
174
175 return args
176}
177
// buildCommandID derives a command ID from path's location relative to
// baseDir. The file extension is dropped, each path component becomes a
// ':'-separated segment, and prefix is prepended. For example, with prefix
// "user:", <baseDir>/git/commit.md becomes "user:git:commit".
func buildCommandID(path, baseDir, prefix string) string {
	// path is expected to lie under baseDir. If Rel fails, rel is "" and the
	// ID degenerates to just prefix.
	rel, _ := filepath.Rel(baseDir, path)

	// filepath.Ext only inspects the final path element, so trimming it from
	// the whole relative path strips just the file's extension.
	rel = strings.TrimSuffix(rel, filepath.Ext(rel))

	return prefix + strings.ReplaceAll(rel, string(filepath.Separator), ":")
}
190
191func getXDGCommandsDir() string {
192 xdgHome := os.Getenv("XDG_CONFIG_HOME")
193 if xdgHome == "" {
194 if home := home.Dir(); home != "" {
195 xdgHome = filepath.Join(home, ".config")
196 }
197 }
198 if xdgHome != "" {
199 return filepath.Join(xdgHome, "crush", "commands")
200 }
201 return ""
202}
203
// ensureDir creates path, including any missing parents, with mode 0755 when
// os.Stat reports that it does not exist. Otherwise, even when Stat fails
// for some other reason or path is not a directory, it does nothing and
// returns nil.
func ensureDir(path string) error {
	_, statErr := os.Stat(path)
	if !os.IsNotExist(statErr) {
		return nil
	}
	return os.MkdirAll(path, 0o755)
}
210
// isMarkdownFile reports whether name ends in ".md", ignoring case (so
// "README.MD" and "notes.Md" both qualify).
func isMarkdownFile(name string) bool {
	const ext = ".md"
	if len(name) < len(ext) {
		return false
	}
	return strings.EqualFold(name[len(name)-len(ext):], ext)
}