1package commands
2
3import (
4 "context"
5 "fmt"
6 "io/fs"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "regexp"
11 "strings"
12
13 tea "github.com/charmbracelet/bubbletea/v2"
14 "github.com/modelcontextprotocol/go-sdk/mcp"
15
16 "github.com/charmbracelet/crush/internal/config"
17 "github.com/charmbracelet/crush/internal/home"
18 "github.com/charmbracelet/crush/internal/llm/agent"
19 "github.com/charmbracelet/crush/internal/tui/components/chat"
20 "github.com/charmbracelet/crush/internal/tui/util"
21)
22
23const (
24 userCommandPrefix = "user:"
25 projectCommandPrefix = "project:"
26)
27
28var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
29
30type commandLoader struct {
31 sources []commandSource
32}
33
34type commandSource struct {
35 path string
36 prefix string
37}
38
39func LoadCustomCommands() ([]Command, error) {
40 cfg := config.Get()
41 if cfg == nil {
42 return nil, fmt.Errorf("config not loaded")
43 }
44
45 loader := &commandLoader{
46 sources: buildCommandSources(cfg),
47 }
48
49 return loader.loadAll()
50}
51
52func buildCommandSources(cfg *config.Config) []commandSource {
53 var sources []commandSource
54
55 // XDG config directory
56 if dir := getXDGCommandsDir(); dir != "" {
57 sources = append(sources, commandSource{
58 path: dir,
59 prefix: userCommandPrefix,
60 })
61 }
62
63 // Home directory
64 if home := home.Dir(); home != "" {
65 sources = append(sources, commandSource{
66 path: filepath.Join(home, ".crush", "commands"),
67 prefix: userCommandPrefix,
68 })
69 }
70
71 // Project directory
72 sources = append(sources, commandSource{
73 path: filepath.Join(cfg.Options.DataDirectory, "commands"),
74 prefix: projectCommandPrefix,
75 })
76
77 return sources
78}
79
80func getXDGCommandsDir() string {
81 xdgHome := os.Getenv("XDG_CONFIG_HOME")
82 if xdgHome == "" {
83 if home := home.Dir(); home != "" {
84 xdgHome = filepath.Join(home, ".config")
85 }
86 }
87 if xdgHome != "" {
88 return filepath.Join(xdgHome, "crush", "commands")
89 }
90 return ""
91}
92
93func (l *commandLoader) loadAll() ([]Command, error) {
94 var commands []Command
95
96 for _, source := range l.sources {
97 if cmds, err := l.loadFromSource(source); err == nil {
98 commands = append(commands, cmds...)
99 }
100 }
101
102 return commands, nil
103}
104
105func (l *commandLoader) loadFromSource(source commandSource) ([]Command, error) {
106 if err := ensureDir(source.path); err != nil {
107 return nil, err
108 }
109
110 var commands []Command
111
112 err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
113 if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
114 return err
115 }
116
117 cmd, err := l.loadCommand(path, source.path, source.prefix)
118 if err != nil {
119 return nil // Skip invalid files
120 }
121
122 commands = append(commands, cmd)
123 return nil
124 })
125
126 return commands, err
127}
128
129func (l *commandLoader) loadCommand(path, baseDir, prefix string) (Command, error) {
130 content, err := os.ReadFile(path)
131 if err != nil {
132 return Command{}, err
133 }
134
135 id := buildCommandID(path, baseDir, prefix)
136 desc := fmt.Sprintf("Custom command from %s", filepath.Base(path))
137
138 return Command{
139 ID: id,
140 Title: id,
141 Description: desc,
142 Handler: createCommandHandler(id, desc, string(content)),
143 }, nil
144}
145
146func buildCommandID(path, baseDir, prefix string) string {
147 relPath, _ := filepath.Rel(baseDir, path)
148 parts := strings.Split(relPath, string(filepath.Separator))
149
150 // Remove .md extension from last part
151 if len(parts) > 0 {
152 lastIdx := len(parts) - 1
153 parts[lastIdx] = strings.TrimSuffix(parts[lastIdx], filepath.Ext(parts[lastIdx]))
154 }
155
156 return prefix + strings.Join(parts, ":")
157}
158
159func createCommandHandler(id, desc, content string) func(Command) tea.Cmd {
160 return func(cmd Command) tea.Cmd {
161 args := extractArgNames(content)
162
163 if len(args) > 0 {
164 return util.CmdHandler(ShowArgumentsDialogMsg{
165 CommandID: id,
166 Description: desc,
167 Content: content,
168 ArgNames: args,
169 })
170 }
171
172 return util.CmdHandler(CommandRunCustomMsg{
173 Content: content,
174 })
175 }
176}
177
178func extractArgNames(content string) []string {
179 matches := namedArgPattern.FindAllStringSubmatch(content, -1)
180 if len(matches) == 0 {
181 return nil
182 }
183
184 seen := make(map[string]bool)
185 var args []string
186
187 for _, match := range matches {
188 arg := match[1]
189 if !seen[arg] {
190 seen[arg] = true
191 args = append(args, arg)
192 }
193 }
194
195 return args
196}
197
198func ensureDir(path string) error {
199 if _, err := os.Stat(path); os.IsNotExist(err) {
200 return os.MkdirAll(path, 0o755)
201 }
202 return nil
203}
204
205func isMarkdownFile(name string) bool {
206 return strings.HasSuffix(strings.ToLower(name), ".md")
207}
208
209type CommandRunCustomMsg struct {
210 Content string
211}
212
213func LoadMCPPrompts() []Command {
214 prompts := agent.GetMCPPrompts()
215 commands := make([]Command, 0, len(prompts))
216
217 for key, prompt := range prompts {
218 clientName, promptName, ok := strings.Cut(key, ":")
219 if !ok {
220 slog.Warn("prompt not found", "key", key)
221 continue
222 }
223 commands = append(commands, Command{
224 ID: key,
225 Title: clientName + ":" + promptName,
226 Description: prompt.Description,
227 Handler: createMCPPromptHandler(clientName, promptName, prompt),
228 })
229 }
230
231 return commands
232}
233
234func createMCPPromptHandler(clientName, promptName string, prompt *mcp.Prompt) func(Command) tea.Cmd {
235 return func(cmd Command) tea.Cmd {
236 if len(prompt.Arguments) == 0 {
237 return execMCPPrompt(clientName, promptName, nil)
238 }
239 return util.CmdHandler(ShowMCPPromptArgumentsDialogMsg{
240 Prompt: prompt,
241 OnSubmit: func(args map[string]string) tea.Cmd {
242 return execMCPPrompt(clientName, promptName, args)
243 },
244 })
245 }
246}
247
248func execMCPPrompt(clientName, promptName string, args map[string]string) tea.Cmd {
249 return func() tea.Msg {
250 ctx := context.Background()
251 result, err := agent.GetMCPPromptContent(ctx, clientName, promptName, args)
252 if err != nil {
253 return util.ReportError(err)
254 }
255
256 var content strings.Builder
257 for _, msg := range result.Messages {
258 if msg.Role == "user" {
259 if textContent, ok := msg.Content.(*mcp.TextContent); ok {
260 content.WriteString(textContent.Text)
261 content.WriteString("\n")
262 }
263 }
264 }
265
266 return chat.SendMsg{
267 Text: content.String(),
268 }
269 }
270}
271
272type ShowMCPPromptArgumentsDialogMsg struct {
273 Prompt *mcp.Prompt
274 OnSubmit func(arg map[string]string) tea.Cmd
275}