1package commands
2
3import (
4 "context"
5 "fmt"
6 "io/fs"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "regexp"
11 "strings"
12
13 tea "github.com/charmbracelet/bubbletea/v2"
14 "github.com/modelcontextprotocol/go-sdk/mcp"
15
16 "github.com/charmbracelet/crush/internal/config"
17 "github.com/charmbracelet/crush/internal/home"
18 "github.com/charmbracelet/crush/internal/llm/agent"
19 "github.com/charmbracelet/crush/internal/tui/components/chat"
20 "github.com/charmbracelet/crush/internal/tui/util"
21)
22
// Command ID prefixes that record where a custom command was loaded from.
const (
	userCommandPrefix    = "user:"    // user-level directories (XDG config dir, ~/.crush)
	projectCommandPrefix = "project:" // the project's data directory
)

// namedArgPattern matches argument placeholders such as $FILE or $ISSUE_ID in
// a command file: a '$' followed by an upper-case letter and then any run of
// upper-case letters, digits or underscores. Submatch 1 is the bare name.
var namedArgPattern = regexp.MustCompile(`\$([A-Z][0-9A-Z_]*)`)
29
// commandLoader gathers custom commands from an ordered list of sources.
type commandLoader struct {
	sources []commandSource // visited in order by loadAll
}

// commandSource is one directory to scan for markdown command files, together
// with the prefix prepended to the IDs of the commands found there.
type commandSource struct {
	path, prefix string
}
38
39func LoadCustomCommands() ([]Command, error) {
40 cfg := config.Get()
41 if cfg == nil {
42 return nil, fmt.Errorf("config not loaded")
43 }
44
45 loader := &commandLoader{
46 sources: buildCommandSources(cfg),
47 }
48
49 return loader.loadAll()
50}
51
52func buildCommandSources(cfg *config.Config) []commandSource {
53 var sources []commandSource
54
55 // XDG config directory
56 if dir := getXDGCommandsDir(); dir != "" {
57 sources = append(sources, commandSource{
58 path: dir,
59 prefix: userCommandPrefix,
60 })
61 }
62
63 // Home directory
64 if home := home.Dir(); home != "" {
65 sources = append(sources, commandSource{
66 path: filepath.Join(home, ".crush", "commands"),
67 prefix: userCommandPrefix,
68 })
69 }
70
71 // Project directory
72 sources = append(sources, commandSource{
73 path: filepath.Join(cfg.Options.DataDirectory, "commands"),
74 prefix: projectCommandPrefix,
75 })
76
77 return sources
78}
79
80func getXDGCommandsDir() string {
81 xdgHome := os.Getenv("XDG_CONFIG_HOME")
82 if xdgHome == "" {
83 if home := home.Dir(); home != "" {
84 xdgHome = filepath.Join(home, ".config")
85 }
86 }
87 if xdgHome != "" {
88 return filepath.Join(xdgHome, "crush", "commands")
89 }
90 return ""
91}
92
93func (l *commandLoader) loadAll() ([]Command, error) {
94 var commands []Command
95
96 for _, source := range l.sources {
97 if cmds, err := l.loadFromSource(source); err == nil {
98 commands = append(commands, cmds...)
99 }
100 }
101
102 return commands, nil
103}
104
105func (l *commandLoader) loadFromSource(source commandSource) ([]Command, error) {
106 if err := ensureDir(source.path); err != nil {
107 return nil, err
108 }
109
110 var commands []Command
111
112 err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
113 if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
114 return err
115 }
116
117 cmd, err := l.loadCommand(path, source.path, source.prefix)
118 if err != nil {
119 return nil // Skip invalid files
120 }
121
122 commands = append(commands, cmd)
123 return nil
124 })
125
126 return commands, err
127}
128
129func (l *commandLoader) loadCommand(path, baseDir, prefix string) (Command, error) {
130 content, err := os.ReadFile(path)
131 if err != nil {
132 return Command{}, err
133 }
134
135 id := buildCommandID(path, baseDir, prefix)
136 desc := fmt.Sprintf("Custom command from %s", filepath.Base(path))
137
138 return Command{
139 ID: id,
140 Title: id,
141 Description: desc,
142 Handler: createCommandHandler(id, desc, string(content)),
143 }, nil
144}
145
// buildCommandID derives a command ID from a file's path relative to baseDir.
// It drops the file extension, turns path separators into ':' and prepends
// prefix. For example, <baseDir>/git/commit.md with prefix "user:" yields
// "user:git:commit". If path cannot be made relative to baseDir, the error is
// ignored and only prefix remains. Callers pass paths found by walking
// baseDir, so this is not expected to happen.
func buildCommandID(path, baseDir, prefix string) string {
	rel, _ := filepath.Rel(baseDir, path)
	// filepath.Ext only looks at the final element, so this trims the file's
	// extension and never touches a dotted directory name.
	rel = strings.TrimSuffix(rel, filepath.Ext(rel))
	return prefix + strings.ReplaceAll(rel, string(filepath.Separator), ":")
}
158
159func createCommandHandler(id, desc, content string) func(Command) tea.Cmd {
160 return func(cmd Command) tea.Cmd {
161 args := extractArgNames(content)
162
163 if len(args) == 0 {
164 return util.CmdHandler(CommandRunCustomMsg{
165 Content: content,
166 })
167 }
168 return util.CmdHandler(ShowArgumentsDialogMsg{
169 CommandID: id,
170 Description: desc,
171 ArgNames: args,
172 OnSubmit: func(args map[string]string) tea.Cmd {
173 return execUserPrompt(content, args)
174 },
175 })
176 }
177}
178
179func execUserPrompt(content string, args map[string]string) tea.Cmd {
180 return func() tea.Msg {
181 for name, value := range args {
182 placeholder := "$" + name
183 content = strings.ReplaceAll(content, placeholder, value)
184 }
185 return CommandRunCustomMsg{
186 Content: content,
187 }
188 }
189}
190
191func extractArgNames(content string) []string {
192 matches := namedArgPattern.FindAllStringSubmatch(content, -1)
193 if len(matches) == 0 {
194 return nil
195 }
196
197 seen := make(map[string]bool)
198 var args []string
199
200 for _, match := range matches {
201 arg := match[1]
202 if !seen[arg] {
203 seen[arg] = true
204 args = append(args, arg)
205 }
206 }
207
208 return args
209}
210
// ensureDir creates path, including missing parents, with mode 0o755 when
// os.Stat reports that it does not exist. If path already exists, even as a
// non-directory, or Stat fails for another reason, it returns nil and does
// nothing.
func ensureDir(path string) error {
	_, statErr := os.Stat(path)
	if !os.IsNotExist(statErr) {
		return nil
	}
	return os.MkdirAll(path, 0o755)
}
217
// isMarkdownFile reports whether name ends in a ".md" extension, compared
// case-insensitively (so "README.MD" counts).
func isMarkdownFile(name string) bool {
	return strings.EqualFold(filepath.Ext(name), ".md")
}
221
// CommandRunCustomMsg is emitted when a user-defined custom command runs.
// Content holds the command file's text. When the command takes arguments,
// its $NAME placeholders have already been replaced by the submitted values.
// NOTE(review): how Content is consumed (presumably sent to the chat as a
// prompt) is defined elsewhere — confirm at the handler.
type CommandRunCustomMsg struct {
	Content string
}
225
226func loadMCPPrompts() []Command {
227 prompts := agent.GetMCPPrompts()
228 commands := make([]Command, 0, len(prompts))
229
230 for key, prompt := range prompts {
231 clientName, promptName, ok := strings.Cut(key, ":")
232 if !ok {
233 slog.Warn("prompt not found", "key", key)
234 continue
235 }
236 commands = append(commands, Command{
237 ID: key,
238 Title: clientName + ":" + promptName,
239 Description: prompt.Description,
240 Handler: createMCPPromptHandler(clientName, promptName, prompt),
241 })
242 }
243
244 return commands
245}
246
247func createMCPPromptHandler(clientName, promptName string, prompt *mcp.Prompt) func(Command) tea.Cmd {
248 return func(cmd Command) tea.Cmd {
249 if len(prompt.Arguments) == 0 {
250 return execMCPPrompt(clientName, promptName, nil)
251 }
252 return util.CmdHandler(ShowMCPPromptArgumentsDialogMsg{
253 Prompt: prompt,
254 OnSubmit: func(args map[string]string) tea.Cmd {
255 return execMCPPrompt(clientName, promptName, args)
256 },
257 })
258 }
259}
260
261func execMCPPrompt(clientName, promptName string, args map[string]string) tea.Cmd {
262 return func() tea.Msg {
263 ctx := context.Background()
264 result, err := agent.GetMCPPromptContent(ctx, clientName, promptName, args)
265 if err != nil {
266 return util.ReportError(err)
267 }
268
269 var content strings.Builder
270 for _, msg := range result.Messages {
271 if msg.Role == "user" {
272 if textContent, ok := msg.Content.(*mcp.TextContent); ok {
273 content.WriteString(textContent.Text)
274 content.WriteString("\n")
275 }
276 }
277 }
278
279 return chat.SendMsg{
280 Text: content.String(),
281 }
282 }
283}
284
285type ShowMCPPromptArgumentsDialogMsg struct {
286 Prompt *mcp.Prompt
287 OnSubmit func(arg map[string]string) tea.Cmd
288}