commands.go

  1package commands
  2
  3import (
  4	"context"
  5	"io/fs"
  6	"os"
  7	"path/filepath"
  8	"regexp"
  9	"strings"
 10
 11	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/home"
 14)
 15
 16var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
 17
 18const (
 19	userCommandPrefix    = "user:"
 20	projectCommandPrefix = "project:"
 21)
 22
 23// Argument represents a command argument with its metadata.
 24type Argument struct {
 25	ID          string
 26	Title       string
 27	Description string
 28	Required    bool
 29}
 30
 31// MCPPrompt represents a custom command loaded from an MCP server.
 32type MCPPrompt struct {
 33	ID          string
 34	Title       string
 35	Description string
 36	PromptID    string
 37	ClientID    string
 38	Arguments   []Argument
 39}
 40
 41// CustomCommand represents a user-defined custom command loaded from markdown files.
 42type CustomCommand struct {
 43	ID        string
 44	Name      string
 45	Content   string
 46	Arguments []Argument
 47}
 48
 49type commandSource struct {
 50	path   string
 51	prefix string
 52}
 53
 54// LoadCustomCommands loads custom commands from multiple sources including
 55// XDG config directory, home directory, and project directory.
 56func LoadCustomCommands(cfg *config.Config) ([]CustomCommand, error) {
 57	return loadAll(buildCommandSources(cfg))
 58}
 59
 60// LoadMCPPrompts loads custom commands from available MCP servers.
 61func LoadMCPPrompts() ([]MCPPrompt, error) {
 62	var commands []MCPPrompt
 63	for mcpName, prompts := range mcp.Prompts() {
 64		for _, prompt := range prompts {
 65			key := mcpName + ":" + prompt.Name
 66			var args []Argument
 67			for _, arg := range prompt.Arguments {
 68				title := arg.Title
 69				if title == "" {
 70					title = arg.Name
 71				}
 72				args = append(args, Argument{
 73					ID:          arg.Name,
 74					Title:       title,
 75					Description: arg.Description,
 76					Required:    arg.Required,
 77				})
 78			}
 79			commands = append(commands, MCPPrompt{
 80				ID:          key,
 81				Title:       prompt.Title,
 82				Description: prompt.Description,
 83				PromptID:    prompt.Name,
 84				ClientID:    mcpName,
 85				Arguments:   args,
 86			})
 87		}
 88	}
 89	return commands, nil
 90}
 91
 92func buildCommandSources(cfg *config.Config) []commandSource {
 93	var sources []commandSource
 94
 95	// XDG config directory
 96	if dir := getXDGCommandsDir(); dir != "" {
 97		sources = append(sources, commandSource{
 98			path:   dir,
 99			prefix: userCommandPrefix,
100		})
101	}
102
103	// Home directory
104	if home := home.Dir(); home != "" {
105		sources = append(sources, commandSource{
106			path:   filepath.Join(home, ".crush", "commands"),
107			prefix: userCommandPrefix,
108		})
109	}
110
111	// Project directory
112	sources = append(sources, commandSource{
113		path:   filepath.Join(cfg.Options.DataDirectory, "commands"),
114		prefix: projectCommandPrefix,
115	})
116
117	return sources
118}
119
120func loadAll(sources []commandSource) ([]CustomCommand, error) {
121	var commands []CustomCommand
122
123	for _, source := range sources {
124		if cmds, err := loadFromSource(source); err == nil {
125			commands = append(commands, cmds...)
126		}
127	}
128
129	return commands, nil
130}
131
132func loadFromSource(source commandSource) ([]CustomCommand, error) {
133	if err := ensureDir(source.path); err != nil {
134		return nil, err
135	}
136
137	var commands []CustomCommand
138
139	err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
140		if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
141			return err
142		}
143
144		cmd, err := loadCommand(path, source.path, source.prefix)
145		if err != nil {
146			return nil // Skip invalid files
147		}
148
149		commands = append(commands, cmd)
150		return nil
151	})
152
153	return commands, err
154}
155
156func loadCommand(path, baseDir, prefix string) (CustomCommand, error) {
157	content, err := os.ReadFile(path)
158	if err != nil {
159		return CustomCommand{}, err
160	}
161
162	id := buildCommandID(path, baseDir, prefix)
163
164	return CustomCommand{
165		ID:        id,
166		Name:      id,
167		Content:   string(content),
168		Arguments: extractArgNames(string(content)),
169	}, nil
170}
171
172func extractArgNames(content string) []Argument {
173	matches := namedArgPattern.FindAllStringSubmatch(content, -1)
174	if len(matches) == 0 {
175		return nil
176	}
177
178	seen := make(map[string]bool)
179	var args []Argument
180
181	for _, match := range matches {
182		arg := match[1]
183		if !seen[arg] {
184			seen[arg] = true
185			// for normal custom commands, all args are required
186			args = append(args, Argument{ID: arg, Title: arg, Required: true})
187		}
188	}
189
190	return args
191}
192
// buildCommandID turns a command file path into its ID. It takes the path
// relative to baseDir, drops the file's extension, replaces path separators
// with ':', and prepends prefix. For example, base/git/commit.md with prefix
// "user:" becomes "user:git:commit".
func buildCommandID(path, baseDir, prefix string) string {
	// path always comes from walking baseDir, so Rel is not expected to fail;
	// on failure rel is "" and the ID degrades to the bare prefix.
	rel, _ := filepath.Rel(baseDir, path)
	// Ext only looks at the final path element, so this trims the file's
	// extension without touching dots in directory names.
	rel = strings.TrimSuffix(rel, filepath.Ext(rel))
	return prefix + strings.ReplaceAll(rel, string(filepath.Separator), ":")
}
205
206func getXDGCommandsDir() string {
207	xdgHome := os.Getenv("XDG_CONFIG_HOME")
208	if xdgHome == "" {
209		if home := home.Dir(); home != "" {
210			xdgHome = filepath.Join(home, ".config")
211		}
212	}
213	if xdgHome != "" {
214		return filepath.Join(xdgHome, "crush", "commands")
215	}
216	return ""
217}
218
// ensureDir creates path (and any missing parents, mode 0o755) when os.Stat
// reports that it does not exist. Any other Stat outcome — success, or an
// error other than not-exist — is left alone and nil is returned.
func ensureDir(path string) error {
	_, statErr := os.Stat(path)
	if !os.IsNotExist(statErr) {
		return nil
	}
	return os.MkdirAll(path, 0o755)
}
225
// isMarkdownFile reports whether name ends in ".md", ignoring case.
func isMarkdownFile(name string) bool {
	lowered := strings.ToLower(name)
	return strings.HasSuffix(lowered, ".md")
}
229
230func GetMCPPrompt(clientID, promptID string, args map[string]string) (string, error) {
231	// TODO: we should pass the context down
232	result, err := mcp.GetPromptMessages(context.Background(), clientID, promptID, args)
233	if err != nil {
234		return "", err
235	}
236	return strings.Join(result, " "), nil
237}