commands.go

  1package commands
  2
import (
	"context"
	"io/fs"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strings"
	"time"

	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/home"
)
 16
 17var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
 18
 19const (
 20	userCommandPrefix    = "user:"
 21	projectCommandPrefix = "project:"
 22)
 23
 24// Argument represents a command argument with its metadata.
 25type Argument struct {
 26	ID          string
 27	Title       string
 28	Description string
 29	Required    bool
 30}
 31
 32// MCPPrompt represents a custom command loaded from an MCP server.
 33type MCPPrompt struct {
 34	ID          string
 35	Title       string
 36	Description string
 37	PromptID    string
 38	ClientID    string
 39	Arguments   []Argument
 40}
 41
 42// CustomCommand represents a user-defined custom command loaded from markdown files.
 43type CustomCommand struct {
 44	ID        string
 45	Name      string
 46	Content   string
 47	Arguments []Argument
 48}
 49
 50type commandSource struct {
 51	path   string
 52	prefix string
 53}
 54
 55// LoadCustomCommands loads custom commands from multiple sources including
 56// XDG config directory, home directory, and project directory.
 57func LoadCustomCommands(cfg *config.Config) ([]CustomCommand, error) {
 58	return loadAll(buildCommandSources(cfg))
 59}
 60
 61// LoadMCPPrompts loads custom commands from available MCP servers.
 62func LoadMCPPrompts() ([]MCPPrompt, error) {
 63	var commands []MCPPrompt
 64	for mcpName, prompts := range mcp.Prompts() {
 65		for _, prompt := range prompts {
 66			key := mcpName + ":" + prompt.Name
 67			var args []Argument
 68			for _, arg := range prompt.Arguments {
 69				title := arg.Title
 70				if title == "" {
 71					title = arg.Name
 72				}
 73				args = append(args, Argument{
 74					ID:          arg.Name,
 75					Title:       title,
 76					Description: arg.Description,
 77					Required:    arg.Required,
 78				})
 79			}
 80			commands = append(commands, MCPPrompt{
 81				ID:          key,
 82				Title:       prompt.Title,
 83				Description: prompt.Description,
 84				PromptID:    prompt.Name,
 85				ClientID:    mcpName,
 86				Arguments:   args,
 87			})
 88		}
 89	}
 90	return commands, nil
 91}
 92
 93func buildCommandSources(cfg *config.Config) []commandSource {
 94	var sources []commandSource
 95
 96	// XDG config directory
 97	if dir := getXDGCommandsDir(); dir != "" {
 98		sources = append(sources, commandSource{
 99			path:   dir,
100			prefix: userCommandPrefix,
101		})
102	}
103
104	// Home directory
105	if home := home.Dir(); home != "" {
106		sources = append(sources, commandSource{
107			path:   filepath.Join(home, ".crush", "commands"),
108			prefix: userCommandPrefix,
109		})
110	}
111
112	// Project directory
113	sources = append(sources, commandSource{
114		path:   filepath.Join(cfg.Options.DataDirectory, "commands"),
115		prefix: projectCommandPrefix,
116	})
117
118	return sources
119}
120
121func loadAll(sources []commandSource) ([]CustomCommand, error) {
122	var commands []CustomCommand
123
124	for _, source := range sources {
125		if cmds, err := loadFromSource(source); err == nil {
126			commands = append(commands, cmds...)
127		}
128	}
129
130	return commands, nil
131}
132
133func loadFromSource(source commandSource) ([]CustomCommand, error) {
134	if _, err := os.Stat(source.path); os.IsNotExist(err) {
135		return nil, nil
136	}
137
138	var commands []CustomCommand
139
140	err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
141		if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
142			return err
143		}
144
145		cmd, err := loadCommand(path, source.path, source.prefix)
146		if err != nil {
147			return nil // Skip invalid files
148		}
149
150		commands = append(commands, cmd)
151		return nil
152	})
153
154	return commands, err
155}
156
157func loadCommand(path, baseDir, prefix string) (CustomCommand, error) {
158	content, err := os.ReadFile(path)
159	if err != nil {
160		return CustomCommand{}, err
161	}
162
163	id := buildCommandID(path, baseDir, prefix)
164
165	return CustomCommand{
166		ID:        id,
167		Name:      id,
168		Content:   string(content),
169		Arguments: extractArgNames(string(content)),
170	}, nil
171}
172
173func extractArgNames(content string) []Argument {
174	matches := namedArgPattern.FindAllStringSubmatch(content, -1)
175	if len(matches) == 0 {
176		return nil
177	}
178
179	seen := make(map[string]bool)
180	var args []Argument
181
182	for _, match := range matches {
183		arg := match[1]
184		if !seen[arg] {
185			seen[arg] = true
186			// for normal custom commands, all args are required
187			args = append(args, Argument{ID: arg, Title: arg, Required: true})
188		}
189	}
190
191	return args
192}
193
194func buildCommandID(path, baseDir, prefix string) string {
195	relPath, _ := filepath.Rel(baseDir, path)
196	parts := strings.Split(relPath, string(filepath.Separator))
197
198	// Remove .md extension from last part
199	if len(parts) > 0 {
200		lastIdx := len(parts) - 1
201		parts[lastIdx] = strings.TrimSuffix(parts[lastIdx], filepath.Ext(parts[lastIdx]))
202	}
203
204	return prefix + strings.Join(parts, ":")
205}
206
207func getXDGCommandsDir() string {
208	xdgHome := os.Getenv("XDG_CONFIG_HOME")
209	if xdgHome == "" {
210		if home := home.Dir(); home != "" {
211			xdgHome = filepath.Join(home, ".config")
212		}
213	}
214	if xdgHome != "" {
215		return filepath.Join(xdgHome, "crush", "commands")
216	}
217	return ""
218}
219
// isMarkdownFile reports whether name ends in ".md", ignoring case
// (e.g. "README.MD" counts, "notes.md.bak" does not).
func isMarkdownFile(name string) bool {
	return strings.EqualFold(filepath.Ext(name), ".md")
}
223
224func GetMCPPrompt(cfg *config.ConfigStore, clientID, promptID string, args map[string]string) (string, error) {
225	// Create a context with timeout since tea.Cmd doesn't support context passing.
226	// The MCP client has its own timeout, but this provides an additional safeguard.
227	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
228	defer cancel()
229
230	result, err := mcp.GetPromptMessages(ctx, cfg, clientID, promptID, args)
231	if err != nil {
232		return "", err
233	}
234	return strings.Join(result, " "), nil
235}