1package prompt
2
3import (
4 "cmp"
5 "context"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "runtime"
11 "strings"
12 "testing"
13 "text/template"
14 "time"
15
16 "git.secluded.site/crush/internal/config"
17 "git.secluded.site/crush/internal/home"
18 "git.secluded.site/crush/internal/shell"
19 "git.secluded.site/crush/internal/skills"
20)
21
// Prompt represents a template-based prompt generator.
//
// It pairs a named text/template source with a few overridable inputs
// (clock, platform, working directory) that feed the data the template is
// executed against in Build.
type Prompt struct {
	// name identifies the prompt; it is used as the template name in Build
	// and is returned by Name.
	name string
	// template is the text/template source that Build parses on every call.
	template string
	// now supplies the current time used for the rendered date. NewPrompt
	// initializes it to time.Now; WithTimeFunc replaces it.
	now func() time.Time
	// platform, when non-empty, is reported instead of runtime.GOOS.
	platform string
	// workingDir, when non-empty, is reported instead of the config store's
	// working directory.
	workingDir string
}
30
// PromptDat is the value a prompt template is executed against.
type PromptDat struct {
	// Provider is the provider string passed to Build.
	Provider string
	// Model is the model string passed to Build.
	Model string
	// Config is a copy of the configuration held by the config store.
	Config config.Config
	// WorkingDir is the working directory with separators converted to
	// forward slashes.
	WorkingDir string
	// IsGitRepo reports whether a ".git" entry exists in the store's working
	// directory.
	IsGitRepo bool
	// Platform is the platform override, or runtime.GOOS when none is set.
	Platform string
	// Date is the current date formatted with the layout "1/2/2006"
	// (month/day/year, no zero padding).
	Date string
	// GitStatus is the concatenated branch, status and recent-commit
	// summaries; it is only populated when IsGitRepo is true.
	GitStatus string
	// ContextFiles holds the files loaded from the configured context paths.
	ContextFiles []ContextFile
	// GlobalContextFiles holds the files loaded from the configured global
	// context paths; it is left empty when running under "go test".
	GlobalContextFiles []ContextFile
	// AvailSkillXML is the XML rendering of the available skills, or the
	// empty string when there are none.
	AvailSkillXML string
}
44
// ContextFile is a file that was read from disk to be embedded in a prompt.
type ContextFile struct {
	// Path is the path the file was read from.
	Path string
	// Content is the full file content converted to a string.
	Content string
}
49
// Option customizes a Prompt; options are applied in order by NewPrompt.
type Option func(*Prompt)
51
52func WithTimeFunc(fn func() time.Time) Option {
53 return func(p *Prompt) {
54 p.now = fn
55 }
56}
57
58func WithPlatform(platform string) Option {
59 return func(p *Prompt) {
60 p.platform = platform
61 }
62}
63
64func WithWorkingDir(workingDir string) Option {
65 return func(p *Prompt) {
66 p.workingDir = workingDir
67 }
68}
69
70func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
71 p := &Prompt{
72 name: name,
73 template: promptTemplate,
74 now: time.Now,
75 }
76 for _, opt := range opts {
77 opt(p)
78 }
79 return p, nil
80}
81
82func (p *Prompt) Build(ctx context.Context, provider, model string, store *config.ConfigStore) (string, error) {
83 t, err := template.New(p.name).Parse(p.template)
84 if err != nil {
85 return "", fmt.Errorf("parsing template: %w", err)
86 }
87 var sb strings.Builder
88 d, err := p.promptData(ctx, provider, model, store)
89 if err != nil {
90 return "", err
91 }
92 if err := t.Execute(&sb, d); err != nil {
93 return "", fmt.Errorf("executing template: %w", err)
94 }
95
96 return sb.String(), nil
97}
98
99func processFile(filePath string) *ContextFile {
100 content, err := os.ReadFile(filePath)
101 if err != nil {
102 return nil
103 }
104 return &ContextFile{
105 Path: filePath,
106 Content: string(content),
107 }
108}
109
110func processContextPath(p string, store *config.ConfigStore) []ContextFile {
111 var contexts []ContextFile
112 fullPath := p
113 if !filepath.IsAbs(p) {
114 fullPath = filepath.Join(store.WorkingDir(), p)
115 }
116 info, err := os.Stat(fullPath)
117 if err != nil {
118 return contexts
119 }
120 if info.IsDir() {
121 filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
122 if err != nil {
123 return err
124 }
125 if !d.IsDir() {
126 if result := processFile(path); result != nil {
127 contexts = append(contexts, *result)
128 }
129 }
130 return nil
131 })
132 } else {
133 result := processFile(fullPath)
134 if result != nil {
135 contexts = append(contexts, *result)
136 }
137 }
138 return contexts
139}
140
141// expandPath expands ~ and environment variables in file paths
142func expandPath(path string, store *config.ConfigStore) string {
143 path = home.Long(path)
144 // Handle environment variable expansion using the same pattern as config
145 if strings.HasPrefix(path, "$") {
146 if expanded, err := store.Resolver().ResolveValue(path); err == nil {
147 path = expanded
148 }
149 }
150
151 return path
152}
153
154func (p *Prompt) promptData(ctx context.Context, provider, model string, store *config.ConfigStore) (PromptDat, error) {
155 workingDir := cmp.Or(p.workingDir, store.WorkingDir())
156 platform := cmp.Or(p.platform, runtime.GOOS)
157
158 contextFiles := map[string][]ContextFile{}
159 globalContextFiles := map[string][]ContextFile{}
160
161 cfg := store.Config()
162 for _, pth := range cfg.Options.ContextPaths {
163 expanded := expandPath(pth, store)
164 pathKey := strings.ToLower(expanded)
165 if _, ok := contextFiles[pathKey]; ok {
166 continue
167 }
168 content := processContextPath(expanded, store)
169 contextFiles[pathKey] = content
170 }
171
172 for _, pth := range cfg.Options.GlobalContextPaths {
173 expanded := expandPath(pth, store)
174 pathKey := strings.ToLower(expanded)
175 if _, ok := globalContextFiles[pathKey]; ok {
176 continue
177 }
178 content := processContextPath(expanded, store)
179 globalContextFiles[pathKey] = content
180 }
181
182 // Discover and load skills metadata.
183 var availSkillXML string
184
185 // Start with builtin skills.
186 allSkills := skills.DiscoverBuiltin()
187 builtinNames := make(map[string]bool, len(allSkills))
188 for _, s := range allSkills {
189 builtinNames[s.Name] = true
190 }
191
192 // Discover user skills from configured paths.
193 if len(cfg.Options.SkillsPaths) > 0 {
194 expandedPaths := make([]string, 0, len(cfg.Options.SkillsPaths))
195 for _, pth := range cfg.Options.SkillsPaths {
196 expandedPaths = append(expandedPaths, expandPath(pth, store))
197 }
198 for _, userSkill := range skills.Discover(expandedPaths) {
199 if builtinNames[userSkill.Name] {
200 slog.Warn("User skill overrides builtin skill", "name", userSkill.Name)
201 }
202 allSkills = append(allSkills, userSkill)
203 }
204 }
205
206 // Deduplicate: user skills override builtins with the same name.
207 allSkills = skills.Deduplicate(allSkills)
208
209 // Filter out disabled skills.
210 allSkills = skills.Filter(allSkills, cfg.Options.DisabledSkills)
211
212 if len(allSkills) > 0 {
213 availSkillXML = skills.ToPromptXML(allSkills)
214 }
215
216 isGit := isGitRepo(store.WorkingDir())
217 data := PromptDat{
218 Provider: provider,
219 Model: model,
220 Config: *cfg,
221 WorkingDir: filepath.ToSlash(workingDir),
222 IsGitRepo: isGit,
223 Platform: platform,
224 Date: p.now().Format("1/2/2006"),
225 AvailSkillXML: availSkillXML,
226 }
227 if isGit {
228 var err error
229 data.GitStatus, err = getGitStatus(ctx, store.WorkingDir())
230 if err != nil {
231 return PromptDat{}, err
232 }
233 }
234
235 for _, files := range contextFiles {
236 data.ContextFiles = append(data.ContextFiles, files...)
237 }
238 if !testing.Testing() {
239 for _, files := range globalContextFiles {
240 data.GlobalContextFiles = append(data.GlobalContextFiles, files...)
241 }
242 }
243 return data, nil
244}
245
// isGitRepo reports whether dir contains an entry named ".git". Any stat
// failure, including the entry not existing, is treated as "not a repo".
func isGitRepo(dir string) bool {
	gitPath := filepath.Join(dir, ".git")
	if _, statErr := os.Stat(gitPath); statErr != nil {
		return false
	}
	return true
}
250
251func getGitStatus(ctx context.Context, dir string) (string, error) {
252 sh := shell.NewShell(&shell.Options{
253 WorkingDir: dir,
254 })
255 branch, err := getGitBranch(ctx, sh)
256 if err != nil {
257 return "", err
258 }
259 status, err := getGitStatusSummary(ctx, sh)
260 if err != nil {
261 return "", err
262 }
263 commits, err := getGitRecentCommits(ctx, sh)
264 if err != nil {
265 return "", err
266 }
267 return branch + status + commits, nil
268}
269
270func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
271 out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
272 if err != nil {
273 return "", nil
274 }
275 out = strings.TrimSpace(out)
276 if out == "" {
277 return "", nil
278 }
279 return fmt.Sprintf("Current branch: %s\n", out), nil
280}
281
282func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
283 out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
284 if err != nil {
285 return "", nil
286 }
287 out = strings.TrimSpace(out)
288 if out == "" {
289 return "Status: clean\n", nil
290 }
291 return fmt.Sprintf("Status:\n%s\n", out), nil
292}
293
294func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
295 out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
296 if err != nil || out == "" {
297 return "", nil
298 }
299 out = strings.TrimSpace(out)
300 return fmt.Sprintf("Recent commits:\n%s\n", out), nil
301}
302
303func (p *Prompt) Name() string {
304 return p.name
305}