1package commands
2
import (
	"context"
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strings"

	tea "github.com/charmbracelet/bubbletea/v2"
	"github.com/modelcontextprotocol/go-sdk/mcp"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/home"
	"github.com/charmbracelet/crush/internal/llm/agent"
	"github.com/charmbracelet/crush/internal/tui/components/chat"
	"github.com/charmbracelet/crush/internal/tui/util"
)
21
// ID prefixes for custom commands. Commands found in the user-level
// directories (XDG config dir, ~/.crush/commands) get userCommandPrefix;
// commands from the project's data directory get projectCommandPrefix.
const (
	userCommandPrefix    = "user:"
	projectCommandPrefix = "project:"
)

// namedArgPattern matches a named placeholder inside a custom command's body:
// a '$' followed by an uppercase letter and then any run of uppercase letters,
// digits, or underscores (e.g. $FILE, $TARGET_2). Submatch 1 is the name
// without the leading '$'.
var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
28
// commandLoader collects custom commands from a list of source directories.
type commandLoader struct {
	sources []commandSource // visited in slice order by loadAll
}

// commandSource pairs a directory that is walked recursively for markdown
// command files with the prefix given to the ID of every command found there.
type commandSource struct {
	path   string // root directory to walk
	prefix string // prepended to each command ID (userCommandPrefix or projectCommandPrefix)
}
37
38func LoadCustomCommands() ([]Command, error) {
39 cfg := config.Get()
40 if cfg == nil {
41 return nil, fmt.Errorf("config not loaded")
42 }
43
44 loader := &commandLoader{
45 sources: buildCommandSources(cfg),
46 }
47
48 return loader.loadAll()
49}
50
51func buildCommandSources(cfg *config.Config) []commandSource {
52 var sources []commandSource
53
54 // XDG config directory
55 if dir := getXDGCommandsDir(); dir != "" {
56 sources = append(sources, commandSource{
57 path: dir,
58 prefix: userCommandPrefix,
59 })
60 }
61
62 // Home directory
63 if home := home.Dir(); home != "" {
64 sources = append(sources, commandSource{
65 path: filepath.Join(home, ".crush", "commands"),
66 prefix: userCommandPrefix,
67 })
68 }
69
70 // Project directory
71 sources = append(sources, commandSource{
72 path: filepath.Join(cfg.Options.DataDirectory, "commands"),
73 prefix: projectCommandPrefix,
74 })
75
76 return sources
77}
78
79func getXDGCommandsDir() string {
80 xdgHome := os.Getenv("XDG_CONFIG_HOME")
81 if xdgHome == "" {
82 if home := home.Dir(); home != "" {
83 xdgHome = filepath.Join(home, ".config")
84 }
85 }
86 if xdgHome != "" {
87 return filepath.Join(xdgHome, "crush", "commands")
88 }
89 return ""
90}
91
92func (l *commandLoader) loadAll() ([]Command, error) {
93 var commands []Command
94
95 for _, source := range l.sources {
96 if cmds, err := l.loadFromSource(source); err == nil {
97 commands = append(commands, cmds...)
98 }
99 }
100
101 return commands, nil
102}
103
104func (l *commandLoader) loadFromSource(source commandSource) ([]Command, error) {
105 if err := ensureDir(source.path); err != nil {
106 return nil, err
107 }
108
109 var commands []Command
110
111 err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
112 if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
113 return err
114 }
115
116 cmd, err := l.loadCommand(path, source.path, source.prefix)
117 if err != nil {
118 return nil // Skip invalid files
119 }
120
121 commands = append(commands, cmd)
122 return nil
123 })
124
125 return commands, err
126}
127
128func (l *commandLoader) loadCommand(path, baseDir, prefix string) (Command, error) {
129 content, err := os.ReadFile(path)
130 if err != nil {
131 return Command{}, err
132 }
133
134 id := buildCommandID(path, baseDir, prefix)
135 desc := fmt.Sprintf("Custom command from %s", filepath.Base(path))
136
137 return Command{
138 ID: id,
139 Title: id,
140 Description: desc,
141 Handler: createCommandHandler(id, desc, string(content)),
142 }, nil
143}
144
// buildCommandID derives a command ID from a file's location under baseDir.
// Path separators become ':' and the file extension is dropped, so with
// prefix "user:" the file <baseDir>/git/commit.md yields "user:git:commit".
func buildCommandID(path, baseDir, prefix string) string {
	// The Rel error is ignored; in this file callers pass paths found by
	// walking baseDir, so path is expected to lie beneath it.
	rel, _ := filepath.Rel(baseDir, path)

	// filepath.Ext only inspects the final element, so trimming it from the
	// whole relative path strips the extension of the file name alone.
	rel = strings.TrimSuffix(rel, filepath.Ext(rel))

	return prefix + strings.ReplaceAll(rel, string(filepath.Separator), ":")
}
157
158func createCommandHandler(id, desc, content string) func(Command) tea.Cmd {
159 return func(cmd Command) tea.Cmd {
160 args := extractArgNames(content)
161
162 if len(args) > 0 {
163 return util.CmdHandler(ShowArgumentsDialogMsg{
164 CommandID: id,
165 Description: desc,
166 Content: content,
167 ArgNames: args,
168 })
169 }
170
171 return util.CmdHandler(CommandRunCustomMsg{
172 Content: content,
173 })
174 }
175}
176
177func extractArgNames(content string) []string {
178 matches := namedArgPattern.FindAllStringSubmatch(content, -1)
179 if len(matches) == 0 {
180 return nil
181 }
182
183 seen := make(map[string]bool)
184 var args []string
185
186 for _, match := range matches {
187 arg := match[1]
188 if !seen[arg] {
189 seen[arg] = true
190 args = append(args, arg)
191 }
192 }
193
194 return args
195}
196
// ensureDir makes sure path exists as a directory. If needed, it creates the
// directory and any missing parents with permission bits 0o755 (before umask).
//
// It returns an error when the directory cannot be inspected or created, or
// when path already exists but is not a directory.
func ensureDir(path string) error {
	// MkdirAll is a no-op for an existing directory. Unlike the previous
	// os.Stat + os.IsNotExist pre-check, it does not silently succeed on
	// permission errors or when path names a regular file.
	return os.MkdirAll(path, 0o755)
}
203
// isMarkdownFile reports whether name ends in a ".md" extension, compared
// case-insensitively (so "notes.MD" counts).
func isMarkdownFile(name string) bool {
	ext := filepath.Ext(name)
	return strings.EqualFold(ext, ".md")
}
207
// CommandRunCustomMsg is emitted when a custom command whose body contains no
// named placeholders is invoked.
type CommandRunCustomMsg struct {
	Content string // the command file's full text, unchanged
}
211
212func LoadMCPPrompts() []Command {
213 prompts := agent.GetMCPPrompts()
214 commands := make([]Command, 0, len(prompts))
215
216 for key, prompt := range prompts {
217 p := prompt
218 clientName, promptName, ok := strings.Cut(key, ":")
219 if !ok {
220 continue
221 }
222 commands = append(commands, Command{
223 ID: key,
224 Title: clientName + ":" + promptName,
225 Description: p.Description,
226 Handler: createMCPPromptHandler(key, promptName, p),
227 })
228 }
229
230 return commands
231}
232
233func createMCPPromptHandler(key, promptName string, prompt *mcp.Prompt) func(Command) tea.Cmd {
234 return func(cmd Command) tea.Cmd {
235 if len(prompt.Arguments) == 0 {
236 return executeMCPPromptWithoutArgs(key, promptName)
237 }
238 return util.CmdHandler(ShowMCPPromptArgumentsDialogMsg{
239 PromptID: cmd.ID,
240 PromptName: promptName,
241 })
242 }
243}
244
245func executeMCPPromptWithoutArgs(key, promptName string) tea.Cmd {
246 return func() tea.Msg {
247 // key format is "clientName:promptName"
248 parts := strings.SplitN(key, ":", 2)
249 if len(parts) != 2 {
250 return util.ReportError(fmt.Errorf("invalid prompt key: %s", key))
251 }
252 clientName := parts[0]
253
254 ctx := context.Background()
255 result, err := agent.GetMCPPromptContent(ctx, clientName, promptName, nil)
256 if err != nil {
257 return util.ReportError(err)
258 }
259
260 var content strings.Builder
261 for _, msg := range result.Messages {
262 if msg.Role == "user" {
263 if textContent, ok := msg.Content.(*mcp.TextContent); ok {
264 content.WriteString(textContent.Text)
265 content.WriteString("\n")
266 }
267 }
268 }
269
270 return chat.SendMsg{
271 Text: content.String(),
272 }
273 }
274}
275
// ShowMCPPromptArgumentsDialogMsg is emitted when an MCP prompt command whose
// prompt declares arguments is invoked.
type ShowMCPPromptArgumentsDialogMsg struct {
	PromptID   string // ID of the invoked command ("clientName:promptName" for commands built by LoadMCPPrompts)
	PromptName string // prompt name without the client prefix
}