1package prompt
2
3import (
4 "cmp"
5 "context"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "runtime"
11 "strings"
12 "testing"
13 "text/template"
14 "time"
15
16 "git.secluded.site/crush/internal/config"
17 "git.secluded.site/crush/internal/filepathext"
18 "git.secluded.site/crush/internal/home"
19 "git.secluded.site/crush/internal/shell"
20 "git.secluded.site/crush/internal/skills"
21)
22
// Prompt represents a template-based prompt generator.
//
// It holds a text/template source that Build parses and renders on every
// call, plus optional overrides applied through Option values.
type Prompt struct {
	// name is the template name passed to template.New and returned by Name.
	name string
	// template is the raw text/template source; it is parsed on each Build.
	template string
	// now supplies the time used for PromptDat.Date. NewPrompt defaults it to
	// time.Now and WithTimeFunc can replace it.
	now func() time.Time
	// platform overrides the platform reported to the template; when empty,
	// runtime.GOOS is used.
	platform string
	// workingDir overrides the working directory reported to the template;
	// when empty, the config store's working directory is used.
	workingDir string
}
31
// PromptDat is the data value passed to the prompt template when a Prompt is
// built; template actions such as {{.Model}} refer to these fields.
type PromptDat struct {
	// Provider is the provider name passed to Build.
	Provider string
	// Model is the model name passed to Build.
	Model string
	// Config is a copy of the store's configuration at build time.
	Config config.Config
	// WorkingDir is the working directory with slash separators
	// (filepath.ToSlash), honouring the WithWorkingDir override.
	WorkingDir string
	// IsGitRepo reports whether the store's working directory has a .git entry.
	IsGitRepo bool
	// Platform is the WithPlatform override, or runtime.GOOS when unset.
	Platform string
	// Date is the build date formatted with the layout "1/2/2006".
	Date string
	// GitStatus holds branch, short status and recent commits; it is only
	// populated when IsGitRepo is true.
	GitStatus string
	// ContextFiles are the files loaded from the configured context paths.
	ContextFiles []ContextFile
	// GlobalContextFiles are the files loaded from the configured global
	// context paths; they are left empty when running under `go test`.
	GlobalContextFiles []ContextFile
	// AvailSkillXML is the XML listing of enabled skills, or "" when none.
	AvailSkillXML string
}
45
// ContextFile is a file whose contents are included verbatim in the prompt.
type ContextFile struct {
	// Path is the path the file was read from.
	Path string
	// Content is the full file contents.
	Content string
}
50
// Option customizes a Prompt; NewPrompt applies options in the order given.
type Option func(*Prompt)
52
53func WithTimeFunc(fn func() time.Time) Option {
54 return func(p *Prompt) {
55 p.now = fn
56 }
57}
58
59func WithPlatform(platform string) Option {
60 return func(p *Prompt) {
61 p.platform = platform
62 }
63}
64
65func WithWorkingDir(workingDir string) Option {
66 return func(p *Prompt) {
67 p.workingDir = workingDir
68 }
69}
70
71func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
72 p := &Prompt{
73 name: name,
74 template: promptTemplate,
75 now: time.Now,
76 }
77 for _, opt := range opts {
78 opt(p)
79 }
80 return p, nil
81}
82
83func (p *Prompt) Build(ctx context.Context, provider, model string, store *config.ConfigStore) (string, error) {
84 t, err := template.New(p.name).Parse(p.template)
85 if err != nil {
86 return "", fmt.Errorf("parsing template: %w", err)
87 }
88 var sb strings.Builder
89 d, err := p.promptData(ctx, provider, model, store)
90 if err != nil {
91 return "", err
92 }
93 if err := t.Execute(&sb, d); err != nil {
94 return "", fmt.Errorf("executing template: %w", err)
95 }
96
97 return sb.String(), nil
98}
99
100func processFile(filePath string) *ContextFile {
101 content, err := os.ReadFile(filePath)
102 if err != nil {
103 return nil
104 }
105 return &ContextFile{
106 Path: filePath,
107 Content: string(content),
108 }
109}
110
111func processContextPath(p string, store *config.ConfigStore) []ContextFile {
112 var contexts []ContextFile
113 fullPath := filepathext.SmartJoin(store.WorkingDir(), p)
114 info, err := os.Stat(fullPath)
115 if err != nil {
116 return contexts
117 }
118 if info.IsDir() {
119 filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
120 if err != nil {
121 return err
122 }
123 if !d.IsDir() {
124 if result := processFile(path); result != nil {
125 contexts = append(contexts, *result)
126 }
127 }
128 return nil
129 })
130 } else {
131 result := processFile(fullPath)
132 if result != nil {
133 contexts = append(contexts, *result)
134 }
135 }
136 return contexts
137}
138
139// expandPath expands ~ and environment variables in file paths
140func expandPath(path string, store *config.ConfigStore) string {
141 path = home.Long(path)
142 // Handle environment variable expansion using the same pattern as config
143 if strings.HasPrefix(path, "$") {
144 if expanded, err := store.Resolver().ResolveValue(path); err == nil {
145 path = expanded
146 }
147 }
148
149 return path
150}
151
152func (p *Prompt) promptData(ctx context.Context, provider, model string, store *config.ConfigStore) (PromptDat, error) {
153 workingDir := cmp.Or(p.workingDir, store.WorkingDir())
154 platform := cmp.Or(p.platform, runtime.GOOS)
155
156 contextFiles := map[string][]ContextFile{}
157 globalContextFiles := map[string][]ContextFile{}
158
159 cfg := store.Config()
160 for _, pth := range cfg.Options.ContextPaths {
161 expanded := expandPath(pth, store)
162 pathKey := strings.ToLower(expanded)
163 if _, ok := contextFiles[pathKey]; ok {
164 continue
165 }
166 content := processContextPath(expanded, store)
167 contextFiles[pathKey] = content
168 }
169
170 for _, pth := range cfg.Options.GlobalContextPaths {
171 expanded := expandPath(pth, store)
172 pathKey := strings.ToLower(expanded)
173 if _, ok := globalContextFiles[pathKey]; ok {
174 continue
175 }
176 content := processContextPath(expanded, store)
177 globalContextFiles[pathKey] = content
178 }
179
180 // Discover and load skills metadata.
181 var availSkillXML string
182
183 // Start with builtin skills.
184 allSkills := skills.DiscoverBuiltin()
185 builtinNames := make(map[string]bool, len(allSkills))
186 for _, s := range allSkills {
187 builtinNames[s.Name] = true
188 }
189
190 // Discover user skills from configured paths.
191 if len(cfg.Options.SkillsPaths) > 0 {
192 expandedPaths := make([]string, 0, len(cfg.Options.SkillsPaths))
193 for _, pth := range cfg.Options.SkillsPaths {
194 expandedPaths = append(expandedPaths, expandPath(pth, store))
195 }
196 for _, userSkill := range skills.Discover(expandedPaths) {
197 if builtinNames[userSkill.Name] {
198 slog.Warn("User skill overrides builtin skill", "name", userSkill.Name)
199 }
200 allSkills = append(allSkills, userSkill)
201 }
202 }
203
204 // Deduplicate: user skills override builtins with the same name.
205 allSkills = skills.Deduplicate(allSkills)
206
207 // Filter out disabled skills.
208 allSkills = skills.Filter(allSkills, cfg.Options.DisabledSkills)
209
210 if len(allSkills) > 0 {
211 availSkillXML = skills.ToPromptXML(allSkills)
212 }
213
214 isGit := isGitRepo(store.WorkingDir())
215 data := PromptDat{
216 Provider: provider,
217 Model: model,
218 Config: *cfg,
219 WorkingDir: filepath.ToSlash(workingDir),
220 IsGitRepo: isGit,
221 Platform: platform,
222 Date: p.now().Format("1/2/2006"),
223 AvailSkillXML: availSkillXML,
224 }
225 if isGit {
226 var err error
227 data.GitStatus, err = getGitStatus(ctx, store.WorkingDir())
228 if err != nil {
229 return PromptDat{}, err
230 }
231 }
232
233 for _, files := range contextFiles {
234 data.ContextFiles = append(data.ContextFiles, files...)
235 }
236 if !testing.Testing() {
237 for _, files := range globalContextFiles {
238 data.GlobalContextFiles = append(data.GlobalContextFiles, files...)
239 }
240 }
241 return data, nil
242}
243
// isGitRepo reports whether dir contains a ".git" entry. Any entry type
// counts, so a ".git" file (as used by git worktrees and submodules) is
// accepted as well. Any Stat error, including permission errors, yields false.
func isGitRepo(dir string) bool {
	if _, statErr := os.Stat(filepath.Join(dir, ".git")); statErr != nil {
		return false
	}
	return true
}
248
249func getGitStatus(ctx context.Context, dir string) (string, error) {
250 sh := shell.NewShell(&shell.Options{
251 WorkingDir: dir,
252 })
253 branch, err := getGitBranch(ctx, sh)
254 if err != nil {
255 return "", err
256 }
257 status, err := getGitStatusSummary(ctx, sh)
258 if err != nil {
259 return "", err
260 }
261 commits, err := getGitRecentCommits(ctx, sh)
262 if err != nil {
263 return "", err
264 }
265 return branch + status + commits, nil
266}
267
268func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
269 out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
270 if err != nil {
271 return "", nil
272 }
273 out = strings.TrimSpace(out)
274 if out == "" {
275 return "", nil
276 }
277 return fmt.Sprintf("Current branch: %s\n", out), nil
278}
279
280func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
281 out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
282 if err != nil {
283 return "", nil
284 }
285 out = strings.TrimSpace(out)
286 if out == "" {
287 return "Status: clean\n", nil
288 }
289 return fmt.Sprintf("Status:\n%s\n", out), nil
290}
291
292func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
293 out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
294 if err != nil || out == "" {
295 return "", nil
296 }
297 out = strings.TrimSpace(out)
298 return fmt.Sprintf("Recent commits:\n%s\n", out), nil
299}
300
// Name returns the name the prompt was created with. Build also uses it as
// the template name.
func (p *Prompt) Name() string {
	return p.name
}