1package commands
2
import (
	"cmp"
	"context"
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
	"regexp"
	"slices"
	"strings"

	tea "github.com/charmbracelet/bubbletea/v2"
	"github.com/modelcontextprotocol/go-sdk/mcp"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/home"
	"github.com/charmbracelet/crush/internal/llm/agent"
	"github.com/charmbracelet/crush/internal/tui/components/chat"
	"github.com/charmbracelet/crush/internal/tui/util"
)
22
// userCommandPrefix is prepended to the IDs of commands loaded from the
// user-level command directories.
const userCommandPrefix = "user:"

// projectCommandPrefix is prepended to the IDs of commands loaded from the
// project data directory.
const projectCommandPrefix = "project:"
27
var (
	// namedArgPattern matches a "$" followed by an identifier that starts
	// with an uppercase ASCII letter and continues with uppercase letters,
	// digits or underscores. Capture group 1 holds the identifier without
	// the leading "$".
	namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
)
29
30type commandLoader struct {
31 sources []commandSource
32}
33
// commandSource describes one directory that is scanned for command files.
type commandSource struct {
	// path is the root directory to walk; prefix is prepended to the ID of
	// every command loaded from it.
	path, prefix string
}
38
39func LoadCustomCommands() ([]Command, error) {
40 cfg := config.Get()
41 if cfg == nil {
42 return nil, fmt.Errorf("config not loaded")
43 }
44
45 loader := &commandLoader{
46 sources: buildCommandSources(cfg),
47 }
48
49 return loader.loadAll()
50}
51
52func buildCommandSources(cfg *config.Config) []commandSource {
53 var sources []commandSource
54
55 // XDG config directory
56 if dir := getXDGCommandsDir(); dir != "" {
57 sources = append(sources, commandSource{
58 path: dir,
59 prefix: userCommandPrefix,
60 })
61 }
62
63 // Home directory
64 if home := home.Dir(); home != "" {
65 sources = append(sources, commandSource{
66 path: filepath.Join(home, ".crush", "commands"),
67 prefix: userCommandPrefix,
68 })
69 }
70
71 // Project directory
72 sources = append(sources, commandSource{
73 path: filepath.Join(cfg.Options.DataDirectory, "commands"),
74 prefix: projectCommandPrefix,
75 })
76
77 return sources
78}
79
80func getXDGCommandsDir() string {
81 xdgHome := os.Getenv("XDG_CONFIG_HOME")
82 if xdgHome == "" {
83 if home := home.Dir(); home != "" {
84 xdgHome = filepath.Join(home, ".config")
85 }
86 }
87 if xdgHome != "" {
88 return filepath.Join(xdgHome, "crush", "commands")
89 }
90 return ""
91}
92
93func (l *commandLoader) loadAll() ([]Command, error) {
94 var commands []Command
95
96 for _, source := range l.sources {
97 if cmds, err := l.loadFromSource(source); err == nil {
98 commands = append(commands, cmds...)
99 }
100 }
101
102 return commands, nil
103}
104
105func (l *commandLoader) loadFromSource(source commandSource) ([]Command, error) {
106 if err := ensureDir(source.path); err != nil {
107 return nil, err
108 }
109
110 var commands []Command
111
112 err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
113 if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
114 return err
115 }
116
117 cmd, err := l.loadCommand(path, source.path, source.prefix)
118 if err != nil {
119 return nil // Skip invalid files
120 }
121
122 commands = append(commands, cmd)
123 return nil
124 })
125
126 return commands, err
127}
128
129func (l *commandLoader) loadCommand(path, baseDir, prefix string) (Command, error) {
130 content, err := os.ReadFile(path)
131 if err != nil {
132 return Command{}, err
133 }
134
135 id := buildCommandID(path, baseDir, prefix)
136 desc := fmt.Sprintf("Custom command from %s", filepath.Base(path))
137
138 return Command{
139 ID: id,
140 Title: id,
141 Description: desc,
142 Handler: createCommandHandler(id, desc, string(content)),
143 }, nil
144}
145
// buildCommandID derives a command ID from the location of a command file.
//
// The path is made relative to baseDir, the file extension of the final
// element is removed, and the remaining path elements are joined with ":"
// and prefixed with prefix. For example, "<baseDir>/git/commit.md" with
// prefix "user:" yields "user:git:commit".
//
// If path cannot be expressed relative to baseDir, the file's base name is
// used so that the resulting ID is never just the bare prefix.
func buildCommandID(path, baseDir, prefix string) string {
	relPath, err := filepath.Rel(baseDir, path)
	if err != nil {
		relPath = filepath.Base(path)
	}

	// filepath.Ext only looks at the final path element, so this strips the
	// extension of the file name and leaves directory names untouched.
	relPath = strings.TrimSuffix(relPath, filepath.Ext(relPath))

	return prefix + strings.ReplaceAll(relPath, string(filepath.Separator), ":")
}
158
159func createCommandHandler(id, desc, content string) func(Command) tea.Cmd {
160 return func(cmd Command) tea.Cmd {
161 args := extractArgNames(content)
162
163 if len(args) > 0 {
164 return util.CmdHandler(ShowArgumentsDialogMsg{
165 CommandID: id,
166 Description: desc,
167 Content: content,
168 ArgNames: args,
169 })
170 }
171
172 return util.CmdHandler(CommandRunCustomMsg{
173 Content: content,
174 })
175 }
176}
177
178func extractArgNames(content string) []string {
179 matches := namedArgPattern.FindAllStringSubmatch(content, -1)
180 if len(matches) == 0 {
181 return nil
182 }
183
184 seen := make(map[string]bool)
185 var args []string
186
187 for _, match := range matches {
188 arg := match[1]
189 if !seen[arg] {
190 seen[arg] = true
191 args = append(args, arg)
192 }
193 }
194
195 return args
196}
197
// ensureDir makes sure path exists as a directory, creating it and any
// missing parents with mode 0o755.
//
// os.MkdirAll is a no-op returning nil when path is already a directory, so
// no prior existence check is needed. Unlike a Stat-then-create sequence it
// also reports an error when path exists but is not a directory, or when it
// cannot be inspected.
func ensureDir(path string) error {
	return os.MkdirAll(path, 0o755)
}
204
// isMarkdownFile reports whether name ends in ".md", ignoring case.
func isMarkdownFile(name string) bool {
	const suffix = ".md"
	if len(name) < len(suffix) {
		return false
	}
	tail := name[len(name)-len(suffix):]
	return strings.EqualFold(tail, suffix)
}
208
// CommandRunCustomMsg is the message emitted when a custom command that
// takes no named arguments is selected.
type CommandRunCustomMsg struct {
	// Content is the text of the command file.
	Content string
}
212
213func LoadMCPPrompts() []Command {
214 prompts := agent.GetMCPPrompts()
215 commands := make([]Command, 0, len(prompts))
216
217 for key, prompt := range prompts {
218 p := prompt
219 // key format is "clientName:promptName"
220 parts := strings.SplitN(key, ":", 2)
221 if len(parts) != 2 {
222 continue
223 }
224 clientName, promptName := parts[0], parts[1]
225 displayName := clientName + " " + cmp.Or(p.Title, promptName)
226 commands = append(commands, Command{
227 ID: key,
228 Title: displayName,
229 Description: fmt.Sprintf("[%s] %s", clientName, p.Description),
230 Handler: createMCPPromptHandler(key, promptName, p),
231 })
232 }
233
234 return commands
235}
236
237func createMCPPromptHandler(key, promptName string, prompt *mcp.Prompt) func(Command) tea.Cmd {
238 return func(cmd Command) tea.Cmd {
239 if len(prompt.Arguments) == 0 {
240 return executeMCPPromptWithoutArgs(key, promptName)
241 }
242 return util.CmdHandler(ShowMCPPromptArgumentsDialogMsg{
243 PromptID: cmd.ID,
244 PromptName: promptName,
245 })
246 }
247}
248
249func executeMCPPromptWithoutArgs(key, promptName string) tea.Cmd {
250 return func() tea.Msg {
251 // key format is "clientName:promptName"
252 parts := strings.SplitN(key, ":", 2)
253 if len(parts) != 2 {
254 return util.ReportError(fmt.Errorf("invalid prompt key: %s", key))
255 }
256 clientName := parts[0]
257
258 ctx := context.Background()
259 result, err := agent.GetMCPPromptContent(ctx, clientName, promptName, nil)
260 if err != nil {
261 return util.ReportError(err)
262 }
263
264 var content strings.Builder
265 for _, msg := range result.Messages {
266 if msg.Role == "user" {
267 if textContent, ok := msg.Content.(*mcp.TextContent); ok {
268 content.WriteString(textContent.Text)
269 content.WriteString("\n")
270 }
271 }
272 }
273
274 return chat.SendMsg{
275 Text: content.String(),
276 }
277 }
278}
279
// ShowMCPPromptArgumentsDialogMsg is the message emitted when an MCP prompt
// that declares arguments is selected.
type ShowMCPPromptArgumentsDialogMsg struct {
	// PromptID is the ID of the selected command; PromptName is the prompt
	// name without the client part.
	PromptID, PromptName string
}