1package prompt
2
3import (
4 "cmp"
5 "context"
6 "fmt"
7 "os"
8 "path/filepath"
9 "runtime"
10 "strings"
11 "text/template"
12 "time"
13
14 "git.secluded.site/crush/internal/config"
15 "git.secluded.site/crush/internal/home"
16 "git.secluded.site/crush/internal/shell"
17)
18
// Prompt represents a template-based prompt generator.
//
// A Prompt holds raw text/template source and renders it with Build. The
// clock, platform and working directory used while rendering can be
// overridden with Option values passed to NewPrompt.
type Prompt struct {
	// name is the template name given to template.New and returned by Name.
	name string
	// template is the raw text/template source; Build parses it on every call.
	template string
	// now supplies the time rendered into PromptDat.Date; NewPrompt defaults it to time.Now.
	now func() time.Time
	// platform overrides PromptDat.Platform; empty means runtime.GOOS.
	platform string
	// workingDir overrides PromptDat.WorkingDir; empty means cfg.WorkingDir().
	workingDir string
}
27
// PromptDat is the data a prompt template is executed against in Build.
// The exported field names are what templates reference (e.g. {{.Model}}).
type PromptDat struct {
	// Provider is the provider argument passed to Build.
	Provider string
	// Model is the model argument passed to Build.
	Model string
	// Config is the configuration passed to Build.
	Config config.Config
	// WorkingDir is the working directory in forward-slash form: the
	// WithWorkingDir override if set, otherwise cfg.WorkingDir().
	WorkingDir string
	// IsGitRepo reports whether cfg.WorkingDir() was detected as a git repository.
	IsGitRepo bool
	// Platform is the WithPlatform override if set, otherwise runtime.GOOS.
	Platform string
	// Date is the current date formatted with the layout "1/2/2006"
	// (month/day/year, no zero padding).
	Date string
	// GitStatus holds the branch, short status and recent commits; it is
	// only populated when IsGitRepo is true.
	GitStatus string
	// ContextFiles are the files loaded from cfg.Options.ContextPaths.
	ContextFiles []ContextFile
	// MemoryFiles are the files loaded from cfg.Options.MemoryPaths.
	MemoryFiles []ContextFile
}
40
// ContextFile is a file whose contents are made available to the prompt template.
type ContextFile struct {
	// Path is the path the file was read from.
	Path string
	// Content is the full file contents as read from disk.
	Content string
}
45
// Option customizes a Prompt; options are applied in order by NewPrompt.
type Option func(*Prompt)
47
48func WithTimeFunc(fn func() time.Time) Option {
49 return func(p *Prompt) {
50 p.now = fn
51 }
52}
53
54func WithPlatform(platform string) Option {
55 return func(p *Prompt) {
56 p.platform = platform
57 }
58}
59
60func WithWorkingDir(workingDir string) Option {
61 return func(p *Prompt) {
62 p.workingDir = workingDir
63 }
64}
65
66func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
67 p := &Prompt{
68 name: name,
69 template: promptTemplate,
70 now: time.Now,
71 }
72 for _, opt := range opts {
73 opt(p)
74 }
75 return p, nil
76}
77
78func (p *Prompt) Build(ctx context.Context, provider, model string, cfg config.Config) (string, error) {
79 t, err := template.New(p.name).Parse(p.template)
80 if err != nil {
81 return "", fmt.Errorf("parsing template: %w", err)
82 }
83 var sb strings.Builder
84 d, err := p.promptData(ctx, provider, model, cfg)
85 if err != nil {
86 return "", err
87 }
88 if err := t.Execute(&sb, d); err != nil {
89 return "", fmt.Errorf("executing template: %w", err)
90 }
91
92 return sb.String(), nil
93}
94
95func processFile(filePath string) *ContextFile {
96 content, err := os.ReadFile(filePath)
97 if err != nil {
98 return nil
99 }
100 return &ContextFile{
101 Path: filePath,
102 Content: string(content),
103 }
104}
105
106func processContextPath(p string, cfg config.Config) []ContextFile {
107 var contexts []ContextFile
108 fullPath := p
109 if !filepath.IsAbs(p) {
110 fullPath = filepath.Join(cfg.WorkingDir(), p)
111 }
112 info, err := os.Stat(fullPath)
113 if err != nil {
114 return contexts
115 }
116 if info.IsDir() {
117 filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
118 if err != nil {
119 return err
120 }
121 if !d.IsDir() {
122 if result := processFile(path); result != nil {
123 contexts = append(contexts, *result)
124 }
125 }
126 return nil
127 })
128 } else {
129 result := processFile(fullPath)
130 if result != nil {
131 contexts = append(contexts, *result)
132 }
133 }
134 return contexts
135}
136
137// expandPath expands ~ and environment variables in file paths
138func expandPath(path string, cfg config.Config) string {
139 path = home.Long(path)
140 // Handle environment variable expansion using the same pattern as config
141 if strings.HasPrefix(path, "$") {
142 if expanded, err := cfg.Resolver().ResolveValue(path); err == nil {
143 path = expanded
144 }
145 }
146
147 return path
148}
149
150func (p *Prompt) promptData(ctx context.Context, provider, model string, cfg config.Config) (PromptDat, error) {
151 workingDir := cmp.Or(p.workingDir, cfg.WorkingDir())
152 platform := cmp.Or(p.platform, runtime.GOOS)
153
154 contextFiles := map[string][]ContextFile{}
155 memoryFiles := map[string][]ContextFile{}
156
157 for _, pth := range cfg.Options.ContextPaths {
158 expanded := expandPath(pth, cfg)
159 pathKey := strings.ToLower(expanded)
160 if _, ok := contextFiles[pathKey]; ok {
161 continue
162 }
163 content := processContextPath(expanded, cfg)
164 contextFiles[pathKey] = content
165 }
166
167 for _, pth := range cfg.Options.MemoryPaths {
168 expanded := expandPath(pth, cfg)
169 pathKey := strings.ToLower(expanded)
170 if _, ok := memoryFiles[pathKey]; ok {
171 continue
172 }
173 content := processContextPath(expanded, cfg)
174 memoryFiles[pathKey] = content
175 }
176
177 isGit := isGitRepo(cfg.WorkingDir())
178 data := PromptDat{
179 Provider: provider,
180 Model: model,
181 Config: cfg,
182 WorkingDir: filepath.ToSlash(workingDir),
183 IsGitRepo: isGit,
184 Platform: platform,
185 Date: p.now().Format("1/2/2006"),
186 }
187 if isGit {
188 var err error
189 data.GitStatus, err = getGitStatus(ctx, cfg.WorkingDir())
190 if err != nil {
191 return PromptDat{}, err
192 }
193 }
194
195 for _, files := range contextFiles {
196 data.ContextFiles = append(data.ContextFiles, files...)
197 }
198 for _, files := range memoryFiles {
199 data.MemoryFiles = append(data.MemoryFiles, files...)
200 }
201 return data, nil
202}
203
// isGitRepo reports whether dir lies inside a git work tree, i.e. whether
// dir or any of its ancestors contains a ".git" entry.
//
// Only the entry's existence is checked. In a regular clone ".git" is a
// directory; in linked worktrees and submodules it is a file. Walking
// upwards means a working directory that is a subdirectory of a repository
// is detected too, matching where git commands themselves succeed.
func isGitRepo(dir string) bool {
	abs, err := filepath.Abs(dir)
	if err != nil {
		// Cannot resolve ancestors; fall back to checking dir itself.
		_, statErr := os.Stat(filepath.Join(dir, ".git"))
		return statErr == nil
	}
	for {
		if _, err := os.Stat(filepath.Join(abs, ".git")); err == nil {
			return true
		}
		parent := filepath.Dir(abs)
		if parent == abs {
			// Reached the filesystem root without finding .git.
			return false
		}
		abs = parent
	}
}
208
209func getGitStatus(ctx context.Context, dir string) (string, error) {
210 sh := shell.NewShell(&shell.Options{
211 WorkingDir: dir,
212 })
213 branch, err := getGitBranch(ctx, sh)
214 if err != nil {
215 return "", err
216 }
217 status, err := getGitStatusSummary(ctx, sh)
218 if err != nil {
219 return "", err
220 }
221 commits, err := getGitRecentCommits(ctx, sh)
222 if err != nil {
223 return "", err
224 }
225 return branch + status + commits, nil
226}
227
228func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
229 out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
230 if err != nil {
231 return "", nil
232 }
233 out = strings.TrimSpace(out)
234 if out == "" {
235 return "", nil
236 }
237 return fmt.Sprintf("Current branch: %s\n", out), nil
238}
239
240func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
241 out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
242 if err != nil {
243 return "", nil
244 }
245 out = strings.TrimSpace(out)
246 if out == "" {
247 return "Status: clean\n", nil
248 }
249 return fmt.Sprintf("Status:\n%s\n", out), nil
250}
251
252func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
253 out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
254 if err != nil || out == "" {
255 return "", nil
256 }
257 out = strings.TrimSpace(out)
258 return fmt.Sprintf("Recent commits:\n%s\n", out), nil
259}
260
261func (p *Prompt) Name() string {
262 return p.name
263}