loader.go

  1package commands
  2
  3import (
  4	"context"
  5	"fmt"
  6	"io/fs"
  7	"os"
  8	"path/filepath"
  9	"regexp"
 10	"strings"
 11
 12	tea "github.com/charmbracelet/bubbletea/v2"
 13	"github.com/modelcontextprotocol/go-sdk/mcp"
 14
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/home"
 17	"github.com/charmbracelet/crush/internal/llm/agent"
 18	"github.com/charmbracelet/crush/internal/tui/components/chat"
 19	"github.com/charmbracelet/crush/internal/tui/util"
 20)
 21
// ID prefixes that record where a command was loaded from.
const (
	UserCommandPrefix    = "user:"    // commands found in the user-level command directories
	ProjectCommandPrefix = "project:" // commands found under the project data directory
	MCPPromptPrefix      = "mcp:"     // commands generated from MCP prompts
)
 27
// namedArgPattern matches named placeholders such as $FILE_NAME: a dollar sign
// followed by an uppercase letter and then any run of uppercase letters,
// digits, or underscores. The name without the "$" is capture group 1.
var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
 29
// commandLoader loads custom commands from a list of source directories.
type commandLoader struct {
	// sources are visited in slice order by loadAll.
	sources []commandSource
}
 33
// commandSource describes one directory that is scanned for command files.
type commandSource struct {
	// path is the root directory that is walked for Markdown files.
	path string
	// prefix is prepended to the ID of every command found under path.
	prefix string
}
 38
 39func LoadCustomCommands() ([]Command, error) {
 40	cfg := config.Get()
 41	if cfg == nil {
 42		return nil, fmt.Errorf("config not loaded")
 43	}
 44
 45	loader := &commandLoader{
 46		sources: buildCommandSources(cfg),
 47	}
 48
 49	return loader.loadAll()
 50}
 51
 52func buildCommandSources(cfg *config.Config) []commandSource {
 53	var sources []commandSource
 54
 55	// XDG config directory
 56	if dir := getXDGCommandsDir(); dir != "" {
 57		sources = append(sources, commandSource{
 58			path:   dir,
 59			prefix: UserCommandPrefix,
 60		})
 61	}
 62
 63	// Home directory
 64	if home := home.Dir(); home != "" {
 65		sources = append(sources, commandSource{
 66			path:   filepath.Join(home, ".crush", "commands"),
 67			prefix: UserCommandPrefix,
 68		})
 69	}
 70
 71	// Project directory
 72	sources = append(sources, commandSource{
 73		path:   filepath.Join(cfg.Options.DataDirectory, "commands"),
 74		prefix: ProjectCommandPrefix,
 75	})
 76
 77	return sources
 78}
 79
 80func getXDGCommandsDir() string {
 81	xdgHome := os.Getenv("XDG_CONFIG_HOME")
 82	if xdgHome == "" {
 83		if home := home.Dir(); home != "" {
 84			xdgHome = filepath.Join(home, ".config")
 85		}
 86	}
 87	if xdgHome != "" {
 88		return filepath.Join(xdgHome, "crush", "commands")
 89	}
 90	return ""
 91}
 92
 93func (l *commandLoader) loadAll() ([]Command, error) {
 94	var commands []Command
 95
 96	for _, source := range l.sources {
 97		if cmds, err := l.loadFromSource(source); err == nil {
 98			commands = append(commands, cmds...)
 99		}
100	}
101
102	return commands, nil
103}
104
105func (l *commandLoader) loadFromSource(source commandSource) ([]Command, error) {
106	if err := ensureDir(source.path); err != nil {
107		return nil, err
108	}
109
110	var commands []Command
111
112	err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
113		if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
114			return err
115		}
116
117		cmd, err := l.loadCommand(path, source.path, source.prefix)
118		if err != nil {
119			return nil // Skip invalid files
120		}
121
122		commands = append(commands, cmd)
123		return nil
124	})
125
126	return commands, err
127}
128
129func (l *commandLoader) loadCommand(path, baseDir, prefix string) (Command, error) {
130	content, err := os.ReadFile(path)
131	if err != nil {
132		return Command{}, err
133	}
134
135	id := buildCommandID(path, baseDir, prefix)
136
137	return Command{
138		ID:          id,
139		Title:       id,
140		Description: fmt.Sprintf("Custom command from %s", filepath.Base(path)),
141		Handler:     createCommandHandler(id, string(content)),
142	}, nil
143}
144
// buildCommandID derives a command ID from a file path.
//
// The path is made relative to baseDir, the extension of its final element is
// removed, and the remaining path elements are joined with ":" and prefixed
// with prefix. For example, "<baseDir>/git/commit.md" with prefix "user:"
// yields "user:git:commit".
//
// An error from filepath.Rel is ignored; the relative path is then empty and
// the result is just prefix.
func buildCommandID(path, baseDir, prefix string) string {
	rel, _ := filepath.Rel(baseDir, path)

	// strings.Split with a non-empty separator always yields at least one
	// element, so indexing the last segment is safe.
	segments := strings.Split(rel, string(filepath.Separator))
	last := len(segments) - 1
	name := segments[last]
	segments[last] = name[:len(name)-len(filepath.Ext(name))]

	return prefix + strings.Join(segments, ":")
}
157
158func createCommandHandler(id string, content string) func(Command) tea.Cmd {
159	return func(cmd Command) tea.Cmd {
160		args := extractArgNames(content)
161
162		if len(args) > 0 {
163			return util.CmdHandler(ShowArgumentsDialogMsg{
164				CommandID: id,
165				Content:   content,
166				ArgNames:  args,
167			})
168		}
169
170		return util.CmdHandler(CommandRunCustomMsg{
171			Content: content,
172		})
173	}
174}
175
176func extractArgNames(content string) []string {
177	matches := namedArgPattern.FindAllStringSubmatch(content, -1)
178	if len(matches) == 0 {
179		return nil
180	}
181
182	seen := make(map[string]bool)
183	var args []string
184
185	for _, match := range matches {
186		arg := match[1]
187		if !seen[arg] {
188			seen[arg] = true
189			args = append(args, arg)
190		}
191	}
192
193	return args
194}
195
// ensureDir creates path, including any missing parents, with mode 0o755
// when os.Stat reports that it does not exist.
//
// In every other case it returns nil without touching the filesystem. That
// includes a path that exists but is not a directory, and a Stat failure
// other than "not exist".
func ensureDir(path string) error {
	_, statErr := os.Stat(path)
	if !os.IsNotExist(statErr) {
		return nil
	}
	return os.MkdirAll(path, 0o755)
}
202
// isMarkdownFile reports whether name ends in ".md", ignoring case.
func isMarkdownFile(name string) bool {
	const ext = ".md"
	if len(name) < len(ext) {
		return false
	}
	return strings.EqualFold(name[len(name)-len(ext):], ext)
}
206
// CommandRunCustomMsg is the message emitted by a custom command's handler
// when the command content contains no named placeholders.
type CommandRunCustomMsg struct {
	// Content is the command file content, passed through unchanged.
	Content string
}
210
211func LoadMCPPrompts() []Command {
212	prompts := agent.GetMCPPrompts()
213	commands := make([]Command, 0, len(prompts))
214
215	for key, prompt := range prompts {
216		p := prompt
217		// key format is "clientName:promptName"
218		parts := strings.SplitN(key, ":", 2)
219		if len(parts) != 2 {
220			continue
221		}
222		clientName, promptName := parts[0], parts[1]
223
224		displayName := promptName
225		if p.Title != "" {
226			displayName = p.Title
227		}
228
229		commands = append(commands, Command{
230			ID:          MCPPromptPrefix + key,
231			Title:       displayName,
232			Description: fmt.Sprintf("[%s] %s", clientName, p.Description),
233			Handler:     createMCPPromptHandler(key, promptName, p),
234		})
235	}
236
237	return commands
238}
239
240func createMCPPromptHandler(key, promptName string, prompt *mcp.Prompt) func(Command) tea.Cmd {
241	return func(cmd Command) tea.Cmd {
242		if len(prompt.Arguments) == 0 {
243			return executeMCPPromptWithoutArgs(key, promptName)
244		}
245		return util.CmdHandler(ShowMCPPromptArgumentsDialogMsg{
246			PromptID:   cmd.ID,
247			PromptName: promptName,
248		})
249	}
250}
251
252func executeMCPPromptWithoutArgs(key, promptName string) tea.Cmd {
253	return func() tea.Msg {
254		// key format is "clientName:promptName"
255		parts := strings.SplitN(key, ":", 2)
256		if len(parts) != 2 {
257			return util.ReportError(fmt.Errorf("invalid prompt key: %s", key))
258		}
259		clientName := parts[0]
260
261		ctx := context.Background()
262		result, err := agent.GetMCPPromptContent(ctx, clientName, promptName, nil)
263		if err != nil {
264			return util.ReportError(err)
265		}
266
267		var content strings.Builder
268		for _, msg := range result.Messages {
269			if msg.Role == "user" {
270				if textContent, ok := msg.Content.(*mcp.TextContent); ok {
271					content.WriteString(textContent.Text)
272					content.WriteString("\n")
273				}
274			}
275		}
276
277		return chat.SendMsg{
278			Text: content.String(),
279		}
280	}
281}
282
// ShowMCPPromptArgumentsDialogMsg is the message emitted by an MCP prompt
// command's handler when the prompt declares at least one argument.
type ShowMCPPromptArgumentsDialogMsg struct {
	// PromptID is the ID of the command that was invoked.
	PromptID string
	// PromptName is the prompt name, i.e. the part of the prompt key after
	// the first ":".
	PromptName string
}