1package commands
2
import (
	"context"
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strings"

	tea "github.com/charmbracelet/bubbletea/v2"
	"github.com/modelcontextprotocol/go-sdk/mcp"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/home"
	"github.com/charmbracelet/crush/internal/llm/agent"
	"github.com/charmbracelet/crush/internal/tui/components/chat"
	"github.com/charmbracelet/crush/internal/tui/util"
)
21
// Prefixes prepended to command IDs to record where a command came from.
const (
	// UserCommandPrefix marks commands loaded from the user-level command
	// directories (the XDG config dir and ~/.crush/commands).
	UserCommandPrefix = "user:"
	// ProjectCommandPrefix marks commands loaded from the "commands" folder
	// inside the configured data directory.
	ProjectCommandPrefix = "project:"
	// MCPPromptPrefix marks commands backed by prompts exposed by MCP clients.
	MCPPromptPrefix = "mcp:"
)

// namedArgPattern matches $NAME placeholders in custom command files: a '$'
// followed by an uppercase letter and then uppercase letters, digits, or
// underscores. Submatch 1 is the name without the leading '$'.
var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
29
// commandLoader turns Markdown files found in a list of directories into
// Commands.
type commandLoader struct {
	// sources are scanned in order; earlier sources contribute first.
	sources []commandSource
}

// commandSource is one directory to scan for command files, together with
// the ID prefix given to every command found in it.
type commandSource struct {
	path, prefix string
}
38
39func LoadCustomCommands() ([]Command, error) {
40 cfg := config.Get()
41 if cfg == nil {
42 return nil, fmt.Errorf("config not loaded")
43 }
44
45 loader := &commandLoader{
46 sources: buildCommandSources(cfg),
47 }
48
49 return loader.loadAll()
50}
51
52func buildCommandSources(cfg *config.Config) []commandSource {
53 var sources []commandSource
54
55 // XDG config directory
56 if dir := getXDGCommandsDir(); dir != "" {
57 sources = append(sources, commandSource{
58 path: dir,
59 prefix: UserCommandPrefix,
60 })
61 }
62
63 // Home directory
64 if home := home.Dir(); home != "" {
65 sources = append(sources, commandSource{
66 path: filepath.Join(home, ".crush", "commands"),
67 prefix: UserCommandPrefix,
68 })
69 }
70
71 // Project directory
72 sources = append(sources, commandSource{
73 path: filepath.Join(cfg.Options.DataDirectory, "commands"),
74 prefix: ProjectCommandPrefix,
75 })
76
77 return sources
78}
79
80func getXDGCommandsDir() string {
81 xdgHome := os.Getenv("XDG_CONFIG_HOME")
82 if xdgHome == "" {
83 if home := home.Dir(); home != "" {
84 xdgHome = filepath.Join(home, ".config")
85 }
86 }
87 if xdgHome != "" {
88 return filepath.Join(xdgHome, "crush", "commands")
89 }
90 return ""
91}
92
93func (l *commandLoader) loadAll() ([]Command, error) {
94 var commands []Command
95
96 for _, source := range l.sources {
97 if cmds, err := l.loadFromSource(source); err == nil {
98 commands = append(commands, cmds...)
99 }
100 }
101
102 return commands, nil
103}
104
105func (l *commandLoader) loadFromSource(source commandSource) ([]Command, error) {
106 if err := ensureDir(source.path); err != nil {
107 return nil, err
108 }
109
110 var commands []Command
111
112 err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
113 if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
114 return err
115 }
116
117 cmd, err := l.loadCommand(path, source.path, source.prefix)
118 if err != nil {
119 return nil // Skip invalid files
120 }
121
122 commands = append(commands, cmd)
123 return nil
124 })
125
126 return commands, err
127}
128
129func (l *commandLoader) loadCommand(path, baseDir, prefix string) (Command, error) {
130 content, err := os.ReadFile(path)
131 if err != nil {
132 return Command{}, err
133 }
134
135 id := buildCommandID(path, baseDir, prefix)
136
137 return Command{
138 ID: id,
139 Title: id,
140 Description: fmt.Sprintf("Custom command from %s", filepath.Base(path)),
141 Handler: createCommandHandler(id, string(content)),
142 }, nil
143}
144
// buildCommandID derives a command ID from path's location under baseDir.
// The file extension is dropped, directory separators become ':', and prefix
// is prepended. For example, <baseDir>/git/commit.md with prefix "user:"
// yields "user:git:commit".
func buildCommandID(path, baseDir, prefix string) string {
	// The error is ignored because callers pass paths found by walking
	// baseDir, so Rel always succeeds; on failure rel is "" and the ID is
	// just prefix.
	rel, _ := filepath.Rel(baseDir, path)

	// filepath.Ext never looks past a separator, so this strips the extension
	// from the final path element only.
	rel = strings.TrimSuffix(rel, filepath.Ext(rel))
	return prefix + strings.ReplaceAll(rel, string(filepath.Separator), ":")
}
157
158func createCommandHandler(id string, content string) func(Command) tea.Cmd {
159 return func(cmd Command) tea.Cmd {
160 args := extractArgNames(content)
161
162 if len(args) > 0 {
163 return util.CmdHandler(ShowArgumentsDialogMsg{
164 CommandID: id,
165 Content: content,
166 ArgNames: args,
167 })
168 }
169
170 return util.CmdHandler(CommandRunCustomMsg{
171 Content: content,
172 })
173 }
174}
175
176func extractArgNames(content string) []string {
177 matches := namedArgPattern.FindAllStringSubmatch(content, -1)
178 if len(matches) == 0 {
179 return nil
180 }
181
182 seen := make(map[string]bool)
183 var args []string
184
185 for _, match := range matches {
186 arg := match[1]
187 if !seen[arg] {
188 seen[arg] = true
189 args = append(args, arg)
190 }
191 }
192
193 return args
194}
195
// ensureDir makes sure path exists as a directory. It creates the directory
// and any missing parents with mode 0o755 (subject to the process umask).
//
// It returns an error when path exists but is not a directory, or when it
// cannot be inspected or created. The previous os.Stat pre-check silently
// returned nil for any Stat error other than "not exist", such as a
// permission problem.
func ensureDir(path string) error {
	// MkdirAll returns nil if path is already a directory and an error if it
	// exists as something else, so no separate Stat is needed.
	return os.MkdirAll(path, 0o755)
}
202
// isMarkdownFile reports whether name ends in ".md", compared case-insensitively.
func isMarkdownFile(name string) bool {
	ext := filepath.Ext(name)
	return strings.EqualFold(ext, ".md")
}
206
// CommandRunCustomMsg asks for a custom command to be run as-is.
//
// In this file it is produced by custom-command handlers when the command
// file contains no $NAME placeholders.
type CommandRunCustomMsg struct {
	// Content is the raw text of the command file.
	Content string
}
210
211func LoadMCPPrompts() []Command {
212 prompts := agent.GetMCPPrompts()
213 commands := make([]Command, 0, len(prompts))
214
215 for key, prompt := range prompts {
216 p := prompt
217 // key format is "clientName:promptName"
218 parts := strings.SplitN(key, ":", 2)
219 if len(parts) != 2 {
220 continue
221 }
222 clientName, promptName := parts[0], parts[1]
223
224 displayName := promptName
225 if p.Title != "" {
226 displayName = p.Title
227 }
228
229 commands = append(commands, Command{
230 ID: MCPPromptPrefix + key,
231 Title: displayName,
232 Description: fmt.Sprintf("[%s] %s", clientName, p.Description),
233 Handler: createMCPPromptHandler(key, promptName, p),
234 })
235 }
236
237 return commands
238}
239
240func createMCPPromptHandler(key, promptName string, prompt *mcp.Prompt) func(Command) tea.Cmd {
241 return func(cmd Command) tea.Cmd {
242 if len(prompt.Arguments) == 0 {
243 return executeMCPPromptWithoutArgs(key, promptName)
244 }
245 return util.CmdHandler(ShowMCPPromptArgumentsDialogMsg{
246 PromptID: cmd.ID,
247 PromptName: promptName,
248 })
249 }
250}
251
252func executeMCPPromptWithoutArgs(key, promptName string) tea.Cmd {
253 return func() tea.Msg {
254 // key format is "clientName:promptName"
255 parts := strings.SplitN(key, ":", 2)
256 if len(parts) != 2 {
257 return util.ReportError(fmt.Errorf("invalid prompt key: %s", key))
258 }
259 clientName := parts[0]
260
261 ctx := context.Background()
262 result, err := agent.GetMCPPromptContent(ctx, clientName, promptName, nil)
263 if err != nil {
264 return util.ReportError(err)
265 }
266
267 var content strings.Builder
268 for _, msg := range result.Messages {
269 if msg.Role == "user" {
270 if textContent, ok := msg.Content.(*mcp.TextContent); ok {
271 content.WriteString(textContent.Text)
272 content.WriteString("\n")
273 }
274 }
275 }
276
277 return chat.SendMsg{
278 Text: content.String(),
279 }
280 }
281}
282
// ShowMCPPromptArgumentsDialogMsg is emitted by MCP prompt command handlers
// in this file when the prompt declares arguments. The receiver is expected
// to collect those arguments before running the prompt (presumably via a
// dialog, per the type name; the handling code is outside this file).
type ShowMCPPromptArgumentsDialogMsg struct {
	// PromptID is the invoked command's ID, i.e. MCPPromptPrefix + "clientName:promptName".
	PromptID string
	// PromptName is the prompt's name on its MCP client.
	PromptName string
}