1package commands
2
3import (
4 "fmt"
5 "io/fs"
6 "os"
7 "path/filepath"
8 "regexp"
9 "strings"
10
11 tea "github.com/charmbracelet/bubbletea/v2"
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/home"
14 "github.com/charmbracelet/crush/internal/tui/util"
15)
16
// ID prefixes that record which kind of location a custom command was loaded
// from. The prefix is prepended to the command ID by buildCommandID.
const (
	// UserCommandPrefix marks commands found in the user-level command
	// directories (the XDG config directory and the home directory).
	UserCommandPrefix = "user:"
	// ProjectCommandPrefix marks commands found under the configured data
	// directory of the project.
	ProjectCommandPrefix = "project:"
)
21
// namedArgPattern matches argument placeholders of the form $NAME: a dollar
// sign followed by an uppercase ASCII letter and then any run of uppercase
// letters, digits, or underscores. Capture group 1 holds the name without the
// leading dollar sign.
var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
23
// commandLoader gathers custom commands from a list of source directories.
type commandLoader struct {
	// sources are the directories to scan; loadAll visits them in slice order.
	sources []commandSource
}
27
// commandSource pairs a directory to scan with the ID prefix given to every
// command found beneath it.
type commandSource struct {
	// path is the root directory that is walked for command files.
	path string
	// prefix is prepended to the ID of each command loaded from path.
	prefix string
}
32
33func LoadCustomCommands() ([]Command, error) {
34 cfg := config.Get()
35 if cfg == nil {
36 return nil, fmt.Errorf("config not loaded")
37 }
38
39 loader := &commandLoader{
40 sources: buildCommandSources(cfg),
41 }
42
43 return loader.loadAll()
44}
45
46func buildCommandSources(cfg *config.Config) []commandSource {
47 var sources []commandSource
48
49 // XDG config directory
50 if dir := getXDGCommandsDir(); dir != "" {
51 sources = append(sources, commandSource{
52 path: dir,
53 prefix: UserCommandPrefix,
54 })
55 }
56
57 // Home directory
58 if home := home.Dir(); home != "" {
59 sources = append(sources, commandSource{
60 path: filepath.Join(home, ".crush", "commands"),
61 prefix: UserCommandPrefix,
62 })
63 }
64
65 // Project directory
66 sources = append(sources, commandSource{
67 path: filepath.Join(cfg.Options.DataDirectory, "commands"),
68 prefix: ProjectCommandPrefix,
69 })
70
71 return sources
72}
73
74func getXDGCommandsDir() string {
75 xdgHome := os.Getenv("XDG_CONFIG_HOME")
76 if xdgHome == "" {
77 if home := home.Dir(); home != "" {
78 xdgHome = filepath.Join(home, ".config")
79 }
80 }
81 if xdgHome != "" {
82 return filepath.Join(xdgHome, "crush", "commands")
83 }
84 return ""
85}
86
87func (l *commandLoader) loadAll() ([]Command, error) {
88 var commands []Command
89
90 for _, source := range l.sources {
91 if cmds, err := l.loadFromSource(source); err == nil {
92 commands = append(commands, cmds...)
93 }
94 }
95
96 return commands, nil
97}
98
99func (l *commandLoader) loadFromSource(source commandSource) ([]Command, error) {
100 if err := ensureDir(source.path); err != nil {
101 return nil, err
102 }
103
104 var commands []Command
105
106 err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
107 if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
108 return err
109 }
110
111 cmd, err := l.loadCommand(path, source.path, source.prefix)
112 if err != nil {
113 return nil // Skip invalid files
114 }
115
116 commands = append(commands, cmd)
117 return nil
118 })
119
120 return commands, err
121}
122
123func (l *commandLoader) loadCommand(path, baseDir, prefix string) (Command, error) {
124 content, err := os.ReadFile(path)
125 if err != nil {
126 return Command{}, err
127 }
128
129 id := buildCommandID(path, baseDir, prefix)
130
131 return Command{
132 ID: id,
133 Title: id,
134 Description: fmt.Sprintf("Custom command from %s", filepath.Base(path)),
135 Handler: createCommandHandler(id, string(content)),
136 }, nil
137}
138
// buildCommandID derives a command ID from the location of its file.
//
// The path is made relative to baseDir, the extension of the final element is
// removed, and the remaining path elements are joined with ":" behind prefix.
// For example, <baseDir>/git/commit.md with prefix "user:" becomes
// "user:git:commit".
func buildCommandID(path, baseDir, prefix string) string {
	// NOTE(review): the error from filepath.Rel is ignored; when it fails rel
	// is empty and the ID is just the prefix.
	rel, _ := filepath.Rel(baseDir, path)

	// filepath.Ext only looks at the final path element, so directory names
	// that contain dots are left intact.
	stem := strings.TrimSuffix(rel, filepath.Ext(rel))

	return prefix + strings.ReplaceAll(stem, string(filepath.Separator), ":")
}
151
152func createCommandHandler(id string, content string) func(Command) tea.Cmd {
153 return func(cmd Command) tea.Cmd {
154 args := extractArgNames(content)
155
156 if len(args) > 0 {
157 return util.CmdHandler(ShowArgumentsDialogMsg{
158 CommandID: id,
159 Content: content,
160 ArgNames: args,
161 })
162 }
163
164 return util.CmdHandler(CommandRunCustomMsg{
165 Content: content,
166 })
167 }
168}
169
170func extractArgNames(content string) []string {
171 matches := namedArgPattern.FindAllStringSubmatch(content, -1)
172 if len(matches) == 0 {
173 return nil
174 }
175
176 seen := make(map[string]bool)
177 var args []string
178
179 for _, match := range matches {
180 arg := match[1]
181 if !seen[arg] {
182 seen[arg] = true
183 args = append(args, arg)
184 }
185 }
186
187 return args
188}
189
// ensureDir makes sure path exists as a directory, creating it and any
// missing parents with mode 0o755.
//
// It returns nil when the directory already exists or was created. Any other
// outcome is reported as the error from os.MkdirAll, including the case where
// path exists but is not a directory.
//
// os.MkdirAll is called directly rather than behind an os.Stat check: it is
// already a no-op for an existing directory, and the pre-check hid stat
// failures and non-directory paths by returning nil for them.
func ensureDir(path string) error {
	return os.MkdirAll(path, 0o755)
}
196
// isMarkdownFile reports whether name ends in ".md", ignoring letter case.
func isMarkdownFile(name string) bool {
	const ext = ".md"
	if len(name) < len(ext) {
		return false
	}
	// Only the trailing bytes matter, so compare just those case-insensitively.
	return strings.EqualFold(name[len(name)-len(ext):], ext)
}
200
// CommandRunCustomMsg is the message built by a custom command's handler when
// the command content contains no named argument placeholders.
type CommandRunCustomMsg struct {
	// Content is the text of the command file.
	Content string
}