1package prompt
2
3import (
4 "context"
5 "fmt"
6 "os"
7 "path/filepath"
8 "runtime"
9 "strings"
10 "text/template"
11 "time"
12
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/home"
15 "github.com/charmbracelet/crush/internal/shell"
16)
17
// Prompt renders a named text/template into a prompt string (see Build).
// NewPrompt sets now to time.Now; platform and workingDir stay empty unless
// an Option overrides them, in which case they replace runtime.GOOS and the
// config's working directory respectively.
type Prompt struct {
	// name is the template name; template is the raw template source.
	name, template string

	// now supplies the timestamp used for the rendered date.
	now func() time.Time

	// platform and workingDir are optional overrides ("" means use the default).
	platform, workingDir string
}
26
27type PromptDat struct {
28 Provider string
29 Model string
30 Config config.Config
31 WorkingDir string
32 IsGitRepo bool
33 Platform string
34 Date string
35 GitStatus string
36 ContextFiles []ContextFile
37}
38
// ContextFile is one file loaded from a configured context path.
type ContextFile struct {
	// Path is the filesystem path the file was read from; Content is its
	// complete text.
	Path, Content string
}
43
44type Option func(*Prompt)
45
46func WithTimeFunc(fn func() time.Time) Option {
47 return func(p *Prompt) {
48 p.now = fn
49 }
50}
51
52func WithPlatform(platform string) Option {
53 return func(p *Prompt) {
54 p.platform = platform
55 }
56}
57
58func WithWorkingDir(workingDir string) Option {
59 return func(p *Prompt) {
60 p.workingDir = workingDir
61 }
62}
63
64func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
65 p := &Prompt{
66 name: name,
67 template: promptTemplate,
68 now: time.Now,
69 }
70 for _, opt := range opts {
71 opt(p)
72 }
73 return p, nil
74}
75
76func (p *Prompt) Build(ctx context.Context, provider, model string, cfg config.Config) (string, error) {
77 t, err := template.New(p.name).Parse(p.template)
78 if err != nil {
79 return "", fmt.Errorf("parsing template: %w", err)
80 }
81 var sb strings.Builder
82 d, err := p.promptData(ctx, provider, model, cfg)
83 if err != nil {
84 return "", err
85 }
86 if err := t.Execute(&sb, d); err != nil {
87 return "", fmt.Errorf("executing template: %w", err)
88 }
89
90 return sb.String(), nil
91}
92
93func processFile(filePath string) *ContextFile {
94 content, err := os.ReadFile(filePath)
95 if err != nil {
96 return nil
97 }
98 return &ContextFile{
99 Path: filePath,
100 Content: string(content),
101 }
102}
103
104func processContextPath(p string, cfg config.Config) []ContextFile {
105 var contexts []ContextFile
106 fullPath := p
107 if !filepath.IsAbs(p) {
108 fullPath = filepath.Join(cfg.WorkingDir(), p)
109 }
110 info, err := os.Stat(fullPath)
111 if err != nil {
112 return contexts
113 }
114 if info.IsDir() {
115 filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
116 if err != nil {
117 return err
118 }
119 if !d.IsDir() {
120 if result := processFile(path); result != nil {
121 contexts = append(contexts, *result)
122 }
123 }
124 return nil
125 })
126 } else {
127 result := processFile(fullPath)
128 if result != nil {
129 contexts = append(contexts, *result)
130 }
131 }
132 return contexts
133}
134
135// expandPath expands ~ and environment variables in file paths
136func expandPath(path string, cfg config.Config) string {
137 path = home.Long(path)
138 // Handle environment variable expansion using the same pattern as config
139 if strings.HasPrefix(path, "$") {
140 if expanded, err := cfg.Resolver().ResolveValue(path); err == nil {
141 path = expanded
142 }
143 }
144
145 return path
146}
147
148func (p *Prompt) promptData(ctx context.Context, provider, model string, cfg config.Config) (PromptDat, error) {
149 workingDir := cfg.WorkingDir()
150 if p.workingDir != "" {
151 workingDir = p.workingDir
152 }
153 platform := runtime.GOOS
154 if p.platform != "" {
155 platform = p.platform
156 }
157
158 files := map[string][]ContextFile{}
159
160 for _, pth := range cfg.Options.ContextPaths {
161 expanded := expandPath(pth, cfg)
162 pathKey := strings.ToLower(expanded)
163 if _, ok := files[pathKey]; ok {
164 continue
165 }
166 content := processContextPath(expanded, cfg)
167 files[pathKey] = content
168 }
169
170 isGit := isGitRepo(cfg.WorkingDir())
171 data := PromptDat{
172 Provider: provider,
173 Model: model,
174 Config: cfg,
175 WorkingDir: workingDir,
176 IsGitRepo: isGit,
177 Platform: platform,
178 Date: p.now().Format("1/2/2006"),
179 }
180 if isGit {
181 var err error
182 data.GitStatus, err = getGitStatus(ctx, cfg.WorkingDir())
183 if err != nil {
184 return PromptDat{}, err
185 }
186 }
187
188 for _, contextFiles := range files {
189 data.ContextFiles = append(data.ContextFiles, contextFiles...)
190 }
191 return data, nil
192}
193
// isGitRepo reports whether dir directly contains a ".git" entry. Any entry
// type counts, so both a regular .git directory and a .git file satisfy the
// check. A stat error of any kind is treated as "not a repository".
func isGitRepo(dir string) bool {
	gitPath := filepath.Join(dir, ".git")
	if _, statErr := os.Stat(gitPath); statErr != nil {
		return false
	}
	return true
}
198
199func getGitStatus(ctx context.Context, dir string) (string, error) {
200 sh := shell.NewShell(&shell.Options{
201 WorkingDir: dir,
202 })
203 branch, err := getGitBranch(ctx, sh)
204 if err != nil {
205 return "", err
206 }
207 status, err := getGitStatusSummary(ctx, sh)
208 if err != nil {
209 return "", err
210 }
211 commits, err := getGitRecentCommits(ctx, sh)
212 if err != nil {
213 return "", err
214 }
215 return branch + status + commits, nil
216}
217
218func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
219 out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
220 if err != nil {
221 return "", nil
222 }
223 out = strings.TrimSpace(out)
224 if out == "" {
225 return "", nil
226 }
227 return fmt.Sprintf("Current branch: %s\n", out), nil
228}
229
230func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
231 out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
232 if err != nil {
233 return "", nil
234 }
235 out = strings.TrimSpace(out)
236 if out == "" {
237 return "Status: clean\n", nil
238 }
239 return fmt.Sprintf("Status:\n%s\n", out), nil
240}
241
242func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
243 out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
244 if err != nil || out == "" {
245 return "", nil
246 }
247 out = strings.TrimSpace(out)
248 return fmt.Sprintf("Recent commits:\n%s\n", out), nil
249}
250
251func (p *Prompt) Name() string {
252 return p.name
253}