loader.go

  1package commands
  2
import (
	"cmp"
	"context"
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
	"regexp"
	"slices"
	"strings"

	tea "github.com/charmbracelet/bubbletea/v2"
	"github.com/modelcontextprotocol/go-sdk/mcp"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/home"
	"github.com/charmbracelet/crush/internal/llm/agent"
	"github.com/charmbracelet/crush/internal/tui/components/chat"
	"github.com/charmbracelet/crush/internal/tui/util"
)
 22
// Prefixes prepended to custom command IDs to record where a command was
// found: userCommandPrefix for the user-level directories (the XDG config
// directory and ~/.crush/commands), projectCommandPrefix for the project's
// data directory.
const (
	userCommandPrefix    = "user:"
	projectCommandPrefix = "project:"
)
 27
// namedArgPattern matches argument placeholders in custom command content: a
// '$' followed by an uppercase letter and then any run of uppercase letters,
// digits, or underscores (e.g. $FILE_NAME). Submatch 1 is the name without
// the leading '$'.
var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
 29
// commandLoader gathers custom commands from an ordered list of source
// directories.
type commandLoader struct {
	sources []commandSource // searched in order; their commands are concatenated
}
 33
// commandSource is one directory searched for Markdown command files.
type commandSource struct {
	path   string // directory walked recursively for Markdown (.md) files
	prefix string // prepended to the ID of every command loaded from path
}
 38
 39func LoadCustomCommands() ([]Command, error) {
 40	cfg := config.Get()
 41	if cfg == nil {
 42		return nil, fmt.Errorf("config not loaded")
 43	}
 44
 45	loader := &commandLoader{
 46		sources: buildCommandSources(cfg),
 47	}
 48
 49	return loader.loadAll()
 50}
 51
 52func buildCommandSources(cfg *config.Config) []commandSource {
 53	var sources []commandSource
 54
 55	// XDG config directory
 56	if dir := getXDGCommandsDir(); dir != "" {
 57		sources = append(sources, commandSource{
 58			path:   dir,
 59			prefix: userCommandPrefix,
 60		})
 61	}
 62
 63	// Home directory
 64	if home := home.Dir(); home != "" {
 65		sources = append(sources, commandSource{
 66			path:   filepath.Join(home, ".crush", "commands"),
 67			prefix: userCommandPrefix,
 68		})
 69	}
 70
 71	// Project directory
 72	sources = append(sources, commandSource{
 73		path:   filepath.Join(cfg.Options.DataDirectory, "commands"),
 74		prefix: projectCommandPrefix,
 75	})
 76
 77	return sources
 78}
 79
 80func getXDGCommandsDir() string {
 81	xdgHome := os.Getenv("XDG_CONFIG_HOME")
 82	if xdgHome == "" {
 83		if home := home.Dir(); home != "" {
 84			xdgHome = filepath.Join(home, ".config")
 85		}
 86	}
 87	if xdgHome != "" {
 88		return filepath.Join(xdgHome, "crush", "commands")
 89	}
 90	return ""
 91}
 92
 93func (l *commandLoader) loadAll() ([]Command, error) {
 94	var commands []Command
 95
 96	for _, source := range l.sources {
 97		if cmds, err := l.loadFromSource(source); err == nil {
 98			commands = append(commands, cmds...)
 99		}
100	}
101
102	return commands, nil
103}
104
// loadFromSource makes sure source.path exists (creating it if needed) and
// walks it recursively. Every Markdown file becomes a command whose ID
// carries source.prefix.
//
// Problems below the root are skipped rather than aborting the walk. A file
// that cannot be loaded is ignored, and so is an entry (typically a
// subdirectory) that cannot be read, so one bad entry does not hide all of
// the source's other commands. An error is returned only when the root
// itself cannot be created or walked.
func (l *commandLoader) loadFromSource(source commandSource) ([]Command, error) {
	if err := ensureDir(source.path); err != nil {
		return nil, err
	}

	var commands []Command

	err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, walkErr error) error {
		if walkErr != nil {
			if path == source.path {
				// The root itself is unusable (d may be nil here); report it.
				return walkErr
			}
			// Skip the unreadable entry and keep walking its siblings,
			// mirroring how unloadable command files are skipped below.
			return nil
		}
		if d.IsDir() || !isMarkdownFile(d.Name()) {
			return nil
		}

		cmd, err := l.loadCommand(path, source.path, source.prefix)
		if err != nil {
			return nil // skip files that cannot be loaded
		}

		commands = append(commands, cmd)
		return nil
	})

	return commands, err
}
128
129func (l *commandLoader) loadCommand(path, baseDir, prefix string) (Command, error) {
130	content, err := os.ReadFile(path)
131	if err != nil {
132		return Command{}, err
133	}
134
135	id := buildCommandID(path, baseDir, prefix)
136	desc := fmt.Sprintf("Custom command from %s", filepath.Base(path))
137
138	return Command{
139		ID:          id,
140		Title:       id,
141		Description: desc,
142		Handler:     createCommandHandler(id, desc, string(content)),
143	}, nil
144}
145
// buildCommandID derives a command ID from the file at path. It takes the
// path relative to baseDir, drops the final element's extension, replaces
// path separators with ':' and prepends prefix. For example, <baseDir>/git/commit.md
// with prefix "user:" becomes "user:git:commit".
func buildCommandID(path, baseDir, prefix string) string {
	// path always comes from walking baseDir, so Rel is not expected to fail;
	// on failure rel is "" and the ID degenerates to just prefix.
	rel, _ := filepath.Rel(baseDir, path)

	// Ext only looks at the final path element, so this strips just the
	// file's own extension.
	rel = strings.TrimSuffix(rel, filepath.Ext(rel))

	return prefix + strings.ReplaceAll(rel, string(filepath.Separator), ":")
}
158
159func createCommandHandler(id, desc, content string) func(Command) tea.Cmd {
160	return func(cmd Command) tea.Cmd {
161		args := extractArgNames(content)
162
163		if len(args) > 0 {
164			return util.CmdHandler(ShowArgumentsDialogMsg{
165				CommandID:   id,
166				Description: desc,
167				Content:     content,
168				ArgNames:    args,
169			})
170		}
171
172		return util.CmdHandler(CommandRunCustomMsg{
173			Content: content,
174		})
175	}
176}
177
178func extractArgNames(content string) []string {
179	matches := namedArgPattern.FindAllStringSubmatch(content, -1)
180	if len(matches) == 0 {
181		return nil
182	}
183
184	seen := make(map[string]bool)
185	var args []string
186
187	for _, match := range matches {
188		arg := match[1]
189		if !seen[arg] {
190			seen[arg] = true
191			args = append(args, arg)
192		}
193	}
194
195	return args
196}
197
// ensureDir makes sure path exists as a directory. It creates path and any
// missing parents with permission bits 0o755 (before umask). It returns nil
// if path is already a directory, and an error if path exists as something
// else or cannot be created.
//
// os.MkdirAll already treats an existing directory as success, so no separate
// existence check is needed. A Stat-then-create sequence would also silently
// accept a regular file at path and swallow Stat errors other than
// "does not exist".
func ensureDir(path string) error {
	return os.MkdirAll(path, 0o755)
}
204
// isMarkdownFile reports whether name has a ".md" extension, compared
// case-insensitively (so "README.MD" qualifies).
func isMarkdownFile(name string) bool {
	return strings.EqualFold(filepath.Ext(name), ".md")
}
208
// CommandRunCustomMsg carries the text of a custom command to run. In this
// file it is emitted by a custom command's handler, with the command file's
// content, when that content has no $NAME argument placeholders.
type CommandRunCustomMsg struct {
	Content string
}
212
// LoadMCPPrompts exposes every prompt reported by agent.GetMCPPrompts as a
// command. Prompt keys have the form "clientName:promptName". For each prompt:
//   - the command ID is the key;
//   - the title is the client name followed by the prompt's title, or by its
//     name when the title is empty;
//   - the description is the prompt's description tagged with the client name.
//
// Entries whose key has no ':' separator, or whose prompt is nil, are skipped.
// Commands are returned sorted by key, because ranging over the prompts map
// directly would yield a different order on every call.
func LoadMCPPrompts() []Command {
	prompts := agent.GetMCPPrompts()

	keys := make([]string, 0, len(prompts))
	for key := range prompts {
		keys = append(keys, key)
	}
	slices.Sort(keys)

	commands := make([]Command, 0, len(keys))
	for _, key := range keys {
		prompt := prompts[key]
		if prompt == nil {
			continue
		}
		clientName, promptName, ok := strings.Cut(key, ":")
		if !ok {
			continue
		}
		commands = append(commands, Command{
			ID:          key,
			Title:       clientName + " " + cmp.Or(prompt.Title, promptName),
			Description: fmt.Sprintf("[%s] %s", clientName, prompt.Description),
			Handler:     createMCPPromptHandler(key, promptName, prompt),
		})
	}

	return commands
}
236
237func createMCPPromptHandler(key, promptName string, prompt *mcp.Prompt) func(Command) tea.Cmd {
238	return func(cmd Command) tea.Cmd {
239		if len(prompt.Arguments) == 0 {
240			return executeMCPPromptWithoutArgs(key, promptName)
241		}
242		return util.CmdHandler(ShowMCPPromptArgumentsDialogMsg{
243			PromptID:   cmd.ID,
244			PromptName: promptName,
245		})
246	}
247}
248
// executeMCPPromptWithoutArgs returns a command that fetches promptName, with
// no arguments, from the MCP client named by key's "clientName:" prefix. It
// then concatenates the text content of the prompt's user-role messages, each
// followed by a newline, and emits the result as a chat.SendMsg. An invalid
// key or a fetch failure is reported via util.ReportError instead.
func executeMCPPromptWithoutArgs(key, promptName string) tea.Cmd {
	return func() tea.Msg {
		clientName, _, found := strings.Cut(key, ":")
		if !found {
			return util.ReportError(fmt.Errorf("invalid prompt key: %s", key))
		}

		result, err := agent.GetMCPPromptContent(context.Background(), clientName, promptName, nil)
		if err != nil {
			return util.ReportError(err)
		}

		var sb strings.Builder
		for _, m := range result.Messages {
			if m.Role != "user" {
				continue
			}
			text, isText := m.Content.(*mcp.TextContent)
			if !isText {
				continue // only plain-text content is forwarded
			}
			sb.WriteString(text.Text)
			sb.WriteString("\n")
		}

		return chat.SendMsg{Text: sb.String()}
	}
}
279
// ShowMCPPromptArgumentsDialogMsg is emitted when a selected MCP prompt
// declares at least one argument, presumably so the UI can collect the
// arguments before running the prompt.
type ShowMCPPromptArgumentsDialogMsg struct {
	PromptID   string // ID of the selected command, i.e. its "clientName:promptName" key
	PromptName string // prompt name without the client prefix
}