commands.go

  1package commands
  2
  3import (
  4	"context"
  5	"io/fs"
  6	"os"
  7	"path/filepath"
  8	"regexp"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/home"
 15	"github.com/charmbracelet/crush/internal/skills"
 16)
 17
 18var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
 19
 20const (
 21	userCommandPrefix    = "user:"
 22	projectCommandPrefix = "project:"
 23)
 24
 25// Argument represents a command argument with its metadata.
 26type Argument struct {
 27	ID          string
 28	Title       string
 29	Description string
 30	Required    bool
 31}
 32
 33// MCPPrompt represents a custom command loaded from an MCP server.
 34type MCPPrompt struct {
 35	ID          string
 36	Title       string
 37	Description string
 38	PromptID    string
 39	ClientID    string
 40	Arguments   []Argument
 41}
 42
 43// CustomCommand represents a user-defined custom command loaded from markdown files.
 44type CustomCommand struct {
 45	ID        string
 46	Name      string
 47	Content   string
 48	Arguments []Argument
 49	// Skill is set when this command represents a user-invocable skill
 50	Skill *skills.Skill
 51}
 52
 53type commandSource struct {
 54	path   string
 55	prefix string
 56}
 57
 58// LoadCustomCommands loads custom commands from multiple sources including
 59// XDG config directory, home directory, and project directory.
 60func LoadCustomCommands(cfg *config.Config) ([]CustomCommand, error) {
 61	return loadAll(buildCommandSources(cfg))
 62}
 63
 64// FromSkillCatalog converts user-invocable catalog entries into custom
 65// command entries for the command palette.
 66func FromSkillCatalog(entries []skills.CatalogEntry) []CustomCommand {
 67	commands := make([]CustomCommand, 0, len(entries))
 68	for _, entry := range entries {
 69		if !entry.UserInvocable {
 70			continue
 71		}
 72		name := entry.Label
 73		if name == "" {
 74			name = userCommandPrefix + entry.Name
 75		}
 76		commands = append(commands, CustomCommand{
 77			ID:   name,
 78			Name: name,
 79			Skill: &skills.Skill{
 80				Name:          entry.Name,
 81				Description:   entry.Description,
 82				SkillFilePath: entry.ID,
 83			},
 84		})
 85	}
 86	return commands
 87}
 88
 89// LoadMCPPrompts loads custom commands from available MCP servers.
 90func LoadMCPPrompts() ([]MCPPrompt, error) {
 91	var commands []MCPPrompt
 92	for mcpName, prompts := range mcp.Prompts() {
 93		for _, prompt := range prompts {
 94			key := mcpName + ":" + prompt.Name
 95			var args []Argument
 96			for _, arg := range prompt.Arguments {
 97				title := arg.Title
 98				if title == "" {
 99					title = arg.Name
100				}
101				args = append(args, Argument{
102					ID:          arg.Name,
103					Title:       title,
104					Description: arg.Description,
105					Required:    arg.Required,
106				})
107			}
108			commands = append(commands, MCPPrompt{
109				ID:          key,
110				Title:       prompt.Title,
111				Description: prompt.Description,
112				PromptID:    prompt.Name,
113				ClientID:    mcpName,
114				Arguments:   args,
115			})
116		}
117	}
118	return commands, nil
119}
120
121func buildCommandSources(cfg *config.Config) []commandSource {
122	return []commandSource{
123		{
124			path:   filepath.Join(home.Config(), "crush", "commands"),
125			prefix: userCommandPrefix,
126		},
127		{
128			path:   filepath.Join(home.Dir(), ".crush", "commands"),
129			prefix: userCommandPrefix,
130		},
131		{
132			path:   filepath.Join(cfg.Options.DataDirectory, "commands"),
133			prefix: projectCommandPrefix,
134		},
135	}
136}
137
138func loadAll(sources []commandSource) ([]CustomCommand, error) {
139	var commands []CustomCommand
140
141	for _, source := range sources {
142		if cmds, err := loadFromSource(source); err == nil {
143			commands = append(commands, cmds...)
144		}
145	}
146
147	return commands, nil
148}
149
150func loadFromSource(source commandSource) ([]CustomCommand, error) {
151	if _, err := os.Stat(source.path); os.IsNotExist(err) {
152		return nil, nil
153	}
154
155	var commands []CustomCommand
156
157	err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
158		if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
159			return err
160		}
161
162		cmd, err := loadCommand(path, source.path, source.prefix)
163		if err != nil {
164			return nil // Skip invalid files
165		}
166
167		commands = append(commands, cmd)
168		return nil
169	})
170
171	return commands, err
172}
173
174func loadCommand(path, baseDir, prefix string) (CustomCommand, error) {
175	content, err := os.ReadFile(path)
176	if err != nil {
177		return CustomCommand{}, err
178	}
179
180	id := buildCommandID(path, baseDir, prefix)
181
182	return CustomCommand{
183		ID:        id,
184		Name:      id,
185		Content:   string(content),
186		Arguments: extractArgNames(string(content)),
187	}, nil
188}
189
190func extractArgNames(content string) []Argument {
191	matches := namedArgPattern.FindAllStringSubmatch(content, -1)
192	if len(matches) == 0 {
193		return nil
194	}
195
196	seen := make(map[string]bool)
197	var args []Argument
198
199	for _, match := range matches {
200		arg := match[1]
201		if !seen[arg] {
202			seen[arg] = true
203			// for normal custom commands, all args are required
204			args = append(args, Argument{ID: arg, Title: arg, Required: true})
205		}
206	}
207
208	return args
209}
210
211func buildCommandID(path, baseDir, prefix string) string {
212	relPath, _ := filepath.Rel(baseDir, path)
213	parts := strings.Split(relPath, string(filepath.Separator))
214
215	// Remove .md extension from last part
216	if len(parts) > 0 {
217		lastIdx := len(parts) - 1
218		parts[lastIdx] = strings.TrimSuffix(parts[lastIdx], filepath.Ext(parts[lastIdx]))
219	}
220
221	return prefix + strings.Join(parts, ":")
222}
223
// isMarkdownFile reports whether name ends in ".md", ignoring case (so
// "README.MD" and "notes.Md" both qualify).
func isMarkdownFile(name string) bool {
	const ext = ".md"
	if len(name) < len(ext) {
		return false
	}
	return strings.EqualFold(name[len(name)-len(ext):], ext)
}
227
228func GetMCPPrompt(cfg *config.ConfigStore, clientID, promptID string, args map[string]string) (string, error) {
229	// Create a context with timeout since tea.Cmd doesn't support context passing.
230	// The MCP client has its own timeout, but this provides an additional safeguard.
231	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
232	defer cancel()
233
234	result, err := mcp.GetPromptMessages(ctx, cfg, clientID, promptID, args)
235	if err != nil {
236		return "", err
237	}
238	return strings.Join(result, " "), nil
239}