commands.go

  1package commands
  2
  3import (
  4	"io/fs"
  5	"os"
  6	"path/filepath"
  7	"regexp"
  8	"strings"
  9
 10	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/home"
 13)
 14
// namedArgPattern matches named placeholders in command content: a dollar
// sign followed by an uppercase letter and then any run of uppercase letters,
// digits, or underscores (for example $FILE_NAME). Capture group 1 holds the
// name without the dollar sign.
var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)

// ID prefixes that record which kind of directory a command was loaded from.
const (
	userCommandPrefix    = "user:"    // XDG config and home directory commands
	projectCommandPrefix = "project:" // commands under the configured data directory
)
 21
// Argument describes one argument of a command: the name it is referred to by
// and whether it is marked as required.
type Argument struct {
	Name     string // argument name
	Required bool   // true when the argument is marked as required
}
 27
// MCPCustomCommand represents a custom command backed by a prompt that an MCP
// server exposes. Values are built by LoadMCPCustomCommands.
type MCPCustomCommand struct {
	ID        string     // Client and Name joined by ":"
	Name      string     // prompt name
	Client    string     // key the prompt is listed under in mcp.Prompts
	Arguments []Argument // arguments declared by the prompt; nil when it has none
}
 35
// CustomCommand represents a user-defined custom command loaded from a
// markdown file found in one of the command directories.
type CustomCommand struct {
	ID        string     // source prefix plus the ":"-joined relative path, extension removed
	Name      string     // set to the same value as ID by loadCommand
	Content   string     // full text of the markdown file
	Arguments []Argument // distinct $NAME placeholders in Content, in order of first use
}
 43
// commandSource pairs a directory to scan for command files with the prefix
// given to the IDs of the commands found there.
type commandSource struct {
	path   string // root directory that is walked for markdown files
	prefix string // ID prefix, e.g. userCommandPrefix or projectCommandPrefix
}
 48
 49// LoadCustomCommands loads custom commands from multiple sources including
 50// XDG config directory, home directory, and project directory.
 51func LoadCustomCommands(cfg *config.Config) ([]CustomCommand, error) {
 52	return loadAll(buildCommandSources(cfg))
 53}
 54
 55// LoadMCPCustomCommands loads custom commands from available MCP servers.
 56func LoadMCPCustomCommands() ([]MCPCustomCommand, error) {
 57	var commands []MCPCustomCommand
 58	for mcpName, prompts := range mcp.Prompts() {
 59		for _, prompt := range prompts {
 60			key := mcpName + ":" + prompt.Name
 61			var args []Argument
 62			for _, arg := range prompt.Arguments {
 63				args = append(args, Argument{Name: arg.Name, Required: arg.Required})
 64			}
 65
 66			commands = append(commands, MCPCustomCommand{
 67				ID:        key,
 68				Name:      prompt.Name,
 69				Client:    mcpName,
 70				Arguments: args,
 71			})
 72		}
 73	}
 74	return commands, nil
 75}
 76
 77func buildCommandSources(cfg *config.Config) []commandSource {
 78	var sources []commandSource
 79
 80	// XDG config directory
 81	if dir := getXDGCommandsDir(); dir != "" {
 82		sources = append(sources, commandSource{
 83			path:   dir,
 84			prefix: userCommandPrefix,
 85		})
 86	}
 87
 88	// Home directory
 89	if home := home.Dir(); home != "" {
 90		sources = append(sources, commandSource{
 91			path:   filepath.Join(home, ".crush", "commands"),
 92			prefix: userCommandPrefix,
 93		})
 94	}
 95
 96	// Project directory
 97	sources = append(sources, commandSource{
 98		path:   filepath.Join(cfg.Options.DataDirectory, "commands"),
 99		prefix: projectCommandPrefix,
100	})
101
102	return sources
103}
104
105func loadAll(sources []commandSource) ([]CustomCommand, error) {
106	var commands []CustomCommand
107
108	for _, source := range sources {
109		if cmds, err := loadFromSource(source); err == nil {
110			commands = append(commands, cmds...)
111		}
112	}
113
114	return commands, nil
115}
116
117func loadFromSource(source commandSource) ([]CustomCommand, error) {
118	if err := ensureDir(source.path); err != nil {
119		return nil, err
120	}
121
122	var commands []CustomCommand
123
124	err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
125		if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
126			return err
127		}
128
129		cmd, err := loadCommand(path, source.path, source.prefix)
130		if err != nil {
131			return nil // Skip invalid files
132		}
133
134		commands = append(commands, cmd)
135		return nil
136	})
137
138	return commands, err
139}
140
141func loadCommand(path, baseDir, prefix string) (CustomCommand, error) {
142	content, err := os.ReadFile(path)
143	if err != nil {
144		return CustomCommand{}, err
145	}
146
147	id := buildCommandID(path, baseDir, prefix)
148
149	return CustomCommand{
150		ID:        id,
151		Name:      id,
152		Content:   string(content),
153		Arguments: extractArgNames(string(content)),
154	}, nil
155}
156
157func extractArgNames(content string) []Argument {
158	matches := namedArgPattern.FindAllStringSubmatch(content, -1)
159	if len(matches) == 0 {
160		return nil
161	}
162
163	seen := make(map[string]bool)
164	var args []Argument
165
166	for _, match := range matches {
167		arg := match[1]
168		if !seen[arg] {
169			seen[arg] = true
170			// for normal custom commands, all args are required
171			args = append(args, Argument{Name: arg, Required: true})
172		}
173	}
174
175	return args
176}
177
// buildCommandID derives a command ID from a command file's location. The
// path is taken relative to baseDir, the file extension is dropped, path
// separators are replaced by ":", and prefix is prepended. For example, with
// baseDir "/cmds" and prefix "user:", the file "/cmds/git/commit.md" yields
// "user:git:commit".
//
// If path cannot be expressed relative to baseDir, the ID falls back to the
// file's base name so the command still receives a non-empty name instead of
// the bare prefix.
func buildCommandID(path, baseDir, prefix string) string {
	relPath, err := filepath.Rel(baseDir, path)
	if err != nil {
		relPath = filepath.Base(path)
	}

	// filepath.Ext only inspects the final path element, so directory names
	// that contain dots are left intact.
	relPath = strings.TrimSuffix(relPath, filepath.Ext(relPath))

	return prefix + strings.ReplaceAll(relPath, string(filepath.Separator), ":")
}
190
191func getXDGCommandsDir() string {
192	xdgHome := os.Getenv("XDG_CONFIG_HOME")
193	if xdgHome == "" {
194		if home := home.Dir(); home != "" {
195			xdgHome = filepath.Join(home, ".config")
196		}
197	}
198	if xdgHome != "" {
199		return filepath.Join(xdgHome, "crush", "commands")
200	}
201	return ""
202}
203
// ensureDir makes sure a directory exists at path, creating it and any
// missing parents with mode 0o755. It returns nil when the directory is
// already present and an error when it cannot be created, including when
// path exists but is not a directory.
func ensureDir(path string) error {
	// MkdirAll is a no-op for an existing directory and fails for an existing
	// non-directory, so a separate Stat check is unnecessary; such a check
	// would also hide Stat failures other than "does not exist".
	return os.MkdirAll(path, 0o755)
}
210
// isMarkdownFile reports whether name ends in ".md", ignoring letter case.
func isMarkdownFile(name string) bool {
	const ext = ".md"
	if len(name) < len(ext) {
		return false
	}
	tail := name[len(name)-len(ext):]
	return strings.EqualFold(tail, ext)
}