1package prompt
2
3import (
4 "cmp"
5 "context"
6 "fmt"
7 "os"
8 "path/filepath"
9 "runtime"
10 "strings"
11 "text/template"
12 "time"
13
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/home"
16 "github.com/charmbracelet/crush/internal/shell"
17)
18
// Prompt is a named text/template source together with the overridable
// inputs used when rendering it. Create one with NewPrompt.
type Prompt struct {
	// name is passed to template.New and reported by Name.
	name string
	// template is the raw template text; it is parsed on every Build.
	template string
	// now supplies the time used for the rendered date. NewPrompt sets it
	// to time.Now; WithTimeFunc replaces it.
	now func() time.Time
	// platform, when non-empty, is used instead of runtime.GOOS.
	platform string
	// workingDir, when non-empty, is used instead of the config's working
	// directory for the WorkingDir value exposed to the template.
	workingDir string
}
27
28type PromptDat struct {
29 Provider string
30 Model string
31 Config config.Config
32 WorkingDir string
33 IsGitRepo bool
34 Platform string
35 Date string
36 GitStatus string
37 ContextFiles []ContextFile
38}
39
// ContextFile is one file loaded from a configured context path.
type ContextFile struct {
	// Path is the path the file was read from.
	Path string
	// Content is the complete file content converted to a string.
	Content string
}
44
45type Option func(*Prompt)
46
47func WithTimeFunc(fn func() time.Time) Option {
48 return func(p *Prompt) {
49 p.now = fn
50 }
51}
52
53func WithPlatform(platform string) Option {
54 return func(p *Prompt) {
55 p.platform = platform
56 }
57}
58
59func WithWorkingDir(workingDir string) Option {
60 return func(p *Prompt) {
61 p.workingDir = workingDir
62 }
63}
64
65func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
66 p := &Prompt{
67 name: name,
68 template: promptTemplate,
69 now: time.Now,
70 }
71 for _, opt := range opts {
72 opt(p)
73 }
74 return p, nil
75}
76
77func (p *Prompt) Build(ctx context.Context, provider, model string, cfg config.Config) (string, error) {
78 t, err := template.New(p.name).Parse(p.template)
79 if err != nil {
80 return "", fmt.Errorf("parsing template: %w", err)
81 }
82 var sb strings.Builder
83 d, err := p.promptData(ctx, provider, model, cfg)
84 if err != nil {
85 return "", err
86 }
87 if err := t.Execute(&sb, d); err != nil {
88 return "", fmt.Errorf("executing template: %w", err)
89 }
90
91 return sb.String(), nil
92}
93
94func processFile(filePath string) *ContextFile {
95 content, err := os.ReadFile(filePath)
96 if err != nil {
97 return nil
98 }
99 return &ContextFile{
100 Path: filePath,
101 Content: string(content),
102 }
103}
104
105func processContextPath(p string, cfg config.Config) []ContextFile {
106 var contexts []ContextFile
107 fullPath := p
108 if !filepath.IsAbs(p) {
109 fullPath = filepath.Join(cfg.WorkingDir(), p)
110 }
111 info, err := os.Stat(fullPath)
112 if err != nil {
113 return contexts
114 }
115 if info.IsDir() {
116 filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
117 if err != nil {
118 return err
119 }
120 if !d.IsDir() {
121 if result := processFile(path); result != nil {
122 contexts = append(contexts, *result)
123 }
124 }
125 return nil
126 })
127 } else {
128 result := processFile(fullPath)
129 if result != nil {
130 contexts = append(contexts, *result)
131 }
132 }
133 return contexts
134}
135
136// expandPath expands ~ and environment variables in file paths
137func expandPath(path string, cfg config.Config) string {
138 path = home.Long(path)
139 // Handle environment variable expansion using the same pattern as config
140 if strings.HasPrefix(path, "$") {
141 if expanded, err := cfg.Resolver().ResolveValue(path); err == nil {
142 path = expanded
143 }
144 }
145
146 return path
147}
148
149func (p *Prompt) promptData(ctx context.Context, provider, model string, cfg config.Config) (PromptDat, error) {
150 workingDir := cmp.Or(p.workingDir, cfg.WorkingDir())
151 platform := cmp.Or(p.platform, runtime.GOOS)
152
153 files := map[string][]ContextFile{}
154
155 for _, pth := range cfg.Options.ContextPaths {
156 expanded := expandPath(pth, cfg)
157 pathKey := strings.ToLower(expanded)
158 if _, ok := files[pathKey]; ok {
159 continue
160 }
161 content := processContextPath(expanded, cfg)
162 files[pathKey] = content
163 }
164
165 isGit := isGitRepo(cfg.WorkingDir())
166 data := PromptDat{
167 Provider: provider,
168 Model: model,
169 Config: cfg,
170 WorkingDir: workingDir,
171 IsGitRepo: isGit,
172 Platform: platform,
173 Date: p.now().Format("1/2/2006"),
174 }
175 if isGit {
176 var err error
177 data.GitStatus, err = getGitStatus(ctx, cfg.WorkingDir())
178 if err != nil {
179 return PromptDat{}, err
180 }
181 }
182
183 for _, contextFiles := range files {
184 data.ContextFiles = append(data.ContextFiles, contextFiles...)
185 }
186 return data, nil
187}
188
// isGitRepo reports whether dir contains an entry named ".git". Any error
// from os.Stat, including "not found", counts as "not a repository".
func isGitRepo(dir string) bool {
	gitPath := filepath.Join(dir, ".git")
	if _, err := os.Stat(gitPath); err != nil {
		return false
	}
	return true
}
193
194func getGitStatus(ctx context.Context, dir string) (string, error) {
195 sh := shell.NewShell(&shell.Options{
196 WorkingDir: dir,
197 })
198 branch, err := getGitBranch(ctx, sh)
199 if err != nil {
200 return "", err
201 }
202 status, err := getGitStatusSummary(ctx, sh)
203 if err != nil {
204 return "", err
205 }
206 commits, err := getGitRecentCommits(ctx, sh)
207 if err != nil {
208 return "", err
209 }
210 return branch + status + commits, nil
211}
212
213func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
214 out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
215 if err != nil {
216 return "", nil
217 }
218 out = strings.TrimSpace(out)
219 if out == "" {
220 return "", nil
221 }
222 return fmt.Sprintf("Current branch: %s\n", out), nil
223}
224
225func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
226 out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
227 if err != nil {
228 return "", nil
229 }
230 out = strings.TrimSpace(out)
231 if out == "" {
232 return "Status: clean\n", nil
233 }
234 return fmt.Sprintf("Status:\n%s\n", out), nil
235}
236
237func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
238 out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
239 if err != nil || out == "" {
240 return "", nil
241 }
242 out = strings.TrimSpace(out)
243 return fmt.Sprintf("Recent commits:\n%s\n", out), nil
244}
245
246func (p *Prompt) Name() string {
247 return p.name
248}