loader.go

  1package commands
  2
  import (
	"context"
	"fmt"
	"io/fs"
	"log/slog"
	"os"
	"path/filepath"
	"regexp"
	"slices"
	"strings"

	tea "github.com/charmbracelet/bubbletea/v2"
	"github.com/modelcontextprotocol/go-sdk/mcp"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/home"
	"github.com/charmbracelet/crush/internal/llm/agent"
	"github.com/charmbracelet/crush/internal/tui/components/chat"
	"github.com/charmbracelet/crush/internal/tui/util"
)
 22
 // ID prefixes that record where a custom command file was found:
// userCommandPrefix for the per-user directories (XDG config and
// ~/.crush), projectCommandPrefix for the project's data directory.
const (
	userCommandPrefix    = "user:"
	projectCommandPrefix = "project:"
)
 27
 // namedArgPattern matches an argument placeholder in custom command text:
// a '$' followed by an uppercase letter, then any run of uppercase letters,
// digits or underscores (e.g. $FILE, $TARGET_DIR). Submatch 1 is the name
// without the '$'.
var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
 29
 30type commandLoader struct {
 31	sources []commandSource
 32}
 33
 // commandSource pairs a directory that is scanned recursively for command
// files (path) with the string prepended to the IDs of the commands found
// there (prefix).
type commandSource struct {
	path, prefix string
}
 38
 39func LoadCustomCommands() ([]Command, error) {
 40	cfg := config.Get()
 41	if cfg == nil {
 42		return nil, fmt.Errorf("config not loaded")
 43	}
 44
 45	loader := &commandLoader{
 46		sources: buildCommandSources(cfg),
 47	}
 48
 49	return loader.loadAll()
 50}
 51
 52func buildCommandSources(cfg *config.Config) []commandSource {
 53	var sources []commandSource
 54
 55	// XDG config directory
 56	if dir := getXDGCommandsDir(); dir != "" {
 57		sources = append(sources, commandSource{
 58			path:   dir,
 59			prefix: userCommandPrefix,
 60		})
 61	}
 62
 63	// Home directory
 64	if home := home.Dir(); home != "" {
 65		sources = append(sources, commandSource{
 66			path:   filepath.Join(home, ".crush", "commands"),
 67			prefix: userCommandPrefix,
 68		})
 69	}
 70
 71	// Project directory
 72	sources = append(sources, commandSource{
 73		path:   filepath.Join(cfg.Options.DataDirectory, "commands"),
 74		prefix: projectCommandPrefix,
 75	})
 76
 77	return sources
 78}
 79
 80func getXDGCommandsDir() string {
 81	xdgHome := os.Getenv("XDG_CONFIG_HOME")
 82	if xdgHome == "" {
 83		if home := home.Dir(); home != "" {
 84			xdgHome = filepath.Join(home, ".config")
 85		}
 86	}
 87	if xdgHome != "" {
 88		return filepath.Join(xdgHome, "crush", "commands")
 89	}
 90	return ""
 91}
 92
 93func (l *commandLoader) loadAll() ([]Command, error) {
 94	var commands []Command
 95
 96	for _, source := range l.sources {
 97		if cmds, err := l.loadFromSource(source); err == nil {
 98			commands = append(commands, cmds...)
 99		}
100	}
101
102	return commands, nil
103}
104
105func (l *commandLoader) loadFromSource(source commandSource) ([]Command, error) {
106	if err := ensureDir(source.path); err != nil {
107		return nil, err
108	}
109
110	var commands []Command
111
112	err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
113		if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
114			return err
115		}
116
117		cmd, err := l.loadCommand(path, source.path, source.prefix)
118		if err != nil {
119			return nil // Skip invalid files
120		}
121
122		commands = append(commands, cmd)
123		return nil
124	})
125
126	return commands, err
127}
128
129func (l *commandLoader) loadCommand(path, baseDir, prefix string) (Command, error) {
130	content, err := os.ReadFile(path)
131	if err != nil {
132		return Command{}, err
133	}
134
135	id := buildCommandID(path, baseDir, prefix)
136	desc := fmt.Sprintf("Custom command from %s", filepath.Base(path))
137
138	return Command{
139		ID:          id,
140		Title:       id,
141		Description: desc,
142		Handler:     createCommandHandler(id, desc, string(content)),
143	}, nil
144}
145
// buildCommandID derives a command ID from a file path. It takes the path
// relative to baseDir, removes the final element's extension, joins the
// elements with ':' and prepends prefix. For example, baseDir "/cmds",
// path "/cmds/git/commit.md" and prefix "user:" give "user:git:commit".
//
// If path cannot be expressed relative to baseDir, the file's base name is
// used. This avoids an ID consisting of the bare prefix.
func buildCommandID(path, baseDir, prefix string) string {
	relPath, err := filepath.Rel(baseDir, path)
	if err != nil {
		relPath = filepath.Base(path)
	}
	// Split never returns an empty slice for a non-empty separator.
	parts := strings.Split(relPath, string(filepath.Separator))

	// Only the file name loses its extension; dots in directory names stay.
	last := len(parts) - 1
	parts[last] = strings.TrimSuffix(parts[last], filepath.Ext(parts[last]))

	return prefix + strings.Join(parts, ":")
}
158
159func createCommandHandler(id, desc, content string) func(Command) tea.Cmd {
160	return func(cmd Command) tea.Cmd {
161		args := extractArgNames(content)
162
163		if len(args) == 0 {
164			return util.CmdHandler(CommandRunCustomMsg{
165				Content: content,
166			})
167		}
168		return util.CmdHandler(ShowArgumentsDialogMsg{
169			CommandID:   id,
170			Description: desc,
171			ArgNames:    args,
172			OnSubmit: func(args map[string]string) tea.Cmd {
173				return execUserPrompt(content, args)
174			},
175		})
176	}
177}
178
179func execUserPrompt(content string, args map[string]string) tea.Cmd {
180	return func() tea.Msg {
181		for name, value := range args {
182			placeholder := "$" + name
183			content = strings.ReplaceAll(content, placeholder, value)
184		}
185		return CommandRunCustomMsg{
186			Content: content,
187		}
188	}
189}
190
191func extractArgNames(content string) []string {
192	matches := namedArgPattern.FindAllStringSubmatch(content, -1)
193	if len(matches) == 0 {
194		return nil
195	}
196
197	seen := make(map[string]bool)
198	var args []string
199
200	for _, match := range matches {
201		arg := match[1]
202		if !seen[arg] {
203			seen[arg] = true
204			args = append(args, arg)
205		}
206	}
207
208	return args
209}
210
// ensureDir makes sure path exists as a directory. If needed it creates the
// directory and any missing parents with permission bits 0o755 (before
// umask).
//
// Errors are reported in two cases that a plain "stat, create if missing"
// check would let through:
//   - path exists but is not a directory;
//   - the stat fails for another reason and the directory cannot be created.
func ensureDir(path string) error {
	// MkdirAll returns nil without doing anything when path is already a
	// directory.
	return os.MkdirAll(path, 0o755)
}
217
// isMarkdownFile reports whether name ends in ".md", ignoring case, so
// "notes.MD" and ".md" both qualify.
func isMarkdownFile(name string) bool {
	const ext = ".md"
	n := len(name)
	return n >= len(ext) && strings.EqualFold(name[n-len(ext):], ext)
}
221
// CommandRunCustomMsg carries the final text of a custom command. Its
// handler emits the text unchanged when there are no placeholders, and
// after argument substitution otherwise. It is consumed outside this file.
type CommandRunCustomMsg struct{ Content string }
225
226func LoadMCPPrompts() []Command {
227	prompts := agent.GetMCPPrompts()
228	commands := make([]Command, 0, len(prompts))
229
230	for key, prompt := range prompts {
231		clientName, promptName, ok := strings.Cut(key, ":")
232		if !ok {
233			slog.Warn("prompt not found", "key", key)
234			continue
235		}
236		commands = append(commands, Command{
237			ID:          key,
238			Title:       clientName + ":" + promptName,
239			Description: prompt.Description,
240			Handler:     createMCPPromptHandler(clientName, promptName, prompt),
241		})
242	}
243
244	return commands
245}
246
247func createMCPPromptHandler(clientName, promptName string, prompt *mcp.Prompt) func(Command) tea.Cmd {
248	return func(cmd Command) tea.Cmd {
249		if len(prompt.Arguments) == 0 {
250			return execMCPPrompt(clientName, promptName, nil)
251		}
252		return util.CmdHandler(ShowMCPPromptArgumentsDialogMsg{
253			Prompt: prompt,
254			OnSubmit: func(args map[string]string) tea.Cmd {
255				return execMCPPrompt(clientName, promptName, args)
256			},
257		})
258	}
259}
260
261func execMCPPrompt(clientName, promptName string, args map[string]string) tea.Cmd {
262	return func() tea.Msg {
263		ctx := context.Background()
264		result, err := agent.GetMCPPromptContent(ctx, clientName, promptName, args)
265		if err != nil {
266			return util.ReportError(err)
267		}
268
269		var content strings.Builder
270		for _, msg := range result.Messages {
271			if msg.Role == "user" {
272				if textContent, ok := msg.Content.(*mcp.TextContent); ok {
273					content.WriteString(textContent.Text)
274					content.WriteString("\n")
275				}
276			}
277		}
278
279		return chat.SendMsg{
280			Text: content.String(),
281		}
282	}
283}
284
285type ShowMCPPromptArgumentsDialogMsg struct {
286	Prompt   *mcp.Prompt
287	OnSubmit func(arg map[string]string) tea.Cmd
288}