loader.go

  1package commands
  2
  3import (
  4	"fmt"
  5	"io/fs"
  6	"os"
  7	"path/filepath"
  8	"regexp"
  9	"strings"
 10
 11	tea "github.com/charmbracelet/bubbletea/v2"
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/tui/util"
 14)
 15
 16const (
 17	UserCommandPrefix    = "user:"
 18	ProjectCommandPrefix = "project:"
 19)
 20
 21var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
 22
 23type commandLoader struct {
 24	sources []commandSource
 25}
 26
 27type commandSource struct {
 28	path   string
 29	prefix string
 30}
 31
 32func LoadCustomCommands(cfg *config.Config) ([]Command, error) {
 33	if cfg == nil {
 34		return nil, fmt.Errorf("config not loaded")
 35	}
 36
 37	loader := &commandLoader{
 38		sources: buildCommandSources(cfg),
 39	}
 40
 41	return loader.loadAll()
 42}
 43
 44func buildCommandSources(cfg *config.Config) []commandSource {
 45	var sources []commandSource
 46
 47	// XDG config directory
 48	if dir := getXDGCommandsDir(); dir != "" {
 49		sources = append(sources, commandSource{
 50			path:   dir,
 51			prefix: UserCommandPrefix,
 52		})
 53	}
 54
 55	// Home directory
 56	if home, err := os.UserHomeDir(); err == nil {
 57		sources = append(sources, commandSource{
 58			path:   filepath.Join(home, ".crush", "commands"),
 59			prefix: UserCommandPrefix,
 60		})
 61	}
 62
 63	// Project directory
 64	sources = append(sources, commandSource{
 65		path:   filepath.Join(cfg.Options.DataDirectory, "commands"),
 66		prefix: ProjectCommandPrefix,
 67	})
 68
 69	return sources
 70}
 71
 72func getXDGCommandsDir() string {
 73	xdgHome := os.Getenv("XDG_CONFIG_HOME")
 74	if xdgHome == "" {
 75		if home, err := os.UserHomeDir(); err == nil {
 76			xdgHome = filepath.Join(home, ".config")
 77		}
 78	}
 79	if xdgHome != "" {
 80		return filepath.Join(xdgHome, "crush", "commands")
 81	}
 82	return ""
 83}
 84
 85func (l *commandLoader) loadAll() ([]Command, error) {
 86	var commands []Command
 87
 88	for _, source := range l.sources {
 89		if cmds, err := l.loadFromSource(source); err == nil {
 90			commands = append(commands, cmds...)
 91		}
 92	}
 93
 94	return commands, nil
 95}
 96
 97func (l *commandLoader) loadFromSource(source commandSource) ([]Command, error) {
 98	if err := ensureDir(source.path); err != nil {
 99		return nil, err
100	}
101
102	var commands []Command
103
104	err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
105		if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
106			return err
107		}
108
109		cmd, err := l.loadCommand(path, source.path, source.prefix)
110		if err != nil {
111			return nil // Skip invalid files
112		}
113
114		commands = append(commands, cmd)
115		return nil
116	})
117
118	return commands, err
119}
120
121func (l *commandLoader) loadCommand(path, baseDir, prefix string) (Command, error) {
122	content, err := os.ReadFile(path)
123	if err != nil {
124		return Command{}, err
125	}
126
127	id := buildCommandID(path, baseDir, prefix)
128
129	return Command{
130		ID:          id,
131		Title:       id,
132		Description: fmt.Sprintf("Custom command from %s", filepath.Base(path)),
133		Handler:     createCommandHandler(id, string(content)),
134	}, nil
135}
136
// buildCommandID turns a command file path into a ':'-separated command ID.
// The path is made relative to baseDir, and each directory level becomes
// one segment. The file extension is removed from the final segment, and
// prefix is prepended. For example, baseDir "/cmds",
// path "/cmds/git/commit.md" and prefix "user:" give "user:git:commit".
//
// If path cannot be made relative to baseDir, the file's base name alone is
// used, so the result is never just the bare prefix.
func buildCommandID(path, baseDir, prefix string) string {
	relPath, err := filepath.Rel(baseDir, path)
	if err != nil {
		relPath = filepath.Base(path)
	}

	// strings.Split with a non-empty separator always returns at least one
	// element, so parts[last] is always valid.
	parts := strings.Split(relPath, string(filepath.Separator))

	// Drop the extension (e.g. ".md") from the file-name segment.
	last := len(parts) - 1
	parts[last] = strings.TrimSuffix(parts[last], filepath.Ext(parts[last]))

	return prefix + strings.Join(parts, ":")
}
149
150func createCommandHandler(id string, content string) func(Command) tea.Cmd {
151	return func(cmd Command) tea.Cmd {
152		args := extractArgNames(content)
153
154		if len(args) > 0 {
155			return util.CmdHandler(ShowArgumentsDialogMsg{
156				CommandID: id,
157				Content:   content,
158				ArgNames:  args,
159			})
160		}
161
162		return util.CmdHandler(CommandRunCustomMsg{
163			Content: content,
164		})
165	}
166}
167
168func extractArgNames(content string) []string {
169	matches := namedArgPattern.FindAllStringSubmatch(content, -1)
170	if len(matches) == 0 {
171		return nil
172	}
173
174	seen := make(map[string]bool)
175	var args []string
176
177	for _, match := range matches {
178		arg := match[1]
179		if !seen[arg] {
180			seen[arg] = true
181			args = append(args, arg)
182		}
183	}
184
185	return args
186}
187
// ensureDir creates path, including any missing parents with mode 0755
// (before umask), when os.Stat reports that it does not exist. If the path
// exists, or Stat fails for any other reason, it returns nil and does
// nothing.
func ensureDir(path string) error {
	_, statErr := os.Stat(path)
	if !os.IsNotExist(statErr) {
		return nil
	}
	return os.MkdirAll(path, 0o755)
}
194
// isMarkdownFile reports whether name ends in ".md", ignoring ASCII case
// (so "README.MD" also qualifies).
func isMarkdownFile(name string) bool {
	const ext = ".md"
	if len(name) < len(ext) {
		return false
	}
	return strings.EqualFold(name[len(name)-len(ext):], ext)
}
198
// CommandRunCustomMsg carries the body of a custom command to run. In this
// file it is emitted by the handler from createCommandHandler when the
// content has no $NAME placeholders.
type CommandRunCustomMsg struct {
	Content string // full content of the custom command file
}