1package prompt
2
3import (
4 "cmp"
5 "context"
6 "fmt"
7 "os"
8 "path/filepath"
9 "runtime"
10 "strings"
11 "text/template"
12 "time"
13
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/home"
16 "github.com/charmbracelet/crush/internal/shell"
17 "github.com/charmbracelet/crush/internal/skills"
18)
19
// Prompt is a template-based prompt generator. It keeps the template source
// together with optional overrides that promptData prefers over the values
// it would otherwise take from the configuration and the runtime.
type Prompt struct {
	// name labels the parsed template and is what Name returns; template is
	// the text/template source that Build parses.
	name, template string
	// now supplies the time used for the Date field of the template data.
	now func() time.Time
	// platform and workingDir, when non-empty, take precedence over
	// runtime.GOOS and the configured working directory in promptData.
	platform, workingDir string
}
28
29type PromptDat struct {
30 Provider string
31 Model string
32 Config config.Config
33 WorkingDir string
34 IsGitRepo bool
35 Platform string
36 Date string
37 GitStatus string
38 ContextFiles []ContextFile
39 AvailSkillXML string
40}
41
// ContextFile pairs the path of a file loaded as prompt context with its
// content.
type ContextFile struct {
	// Path is the location the file was read from; Content is the file's
	// bytes converted to a string.
	Path, Content string
}
46
// Option customizes a Prompt. NewPrompt applies the options it is given in
// order, after setting its defaults.
type Option func(*Prompt)
48
49func WithTimeFunc(fn func() time.Time) Option {
50 return func(p *Prompt) {
51 p.now = fn
52 }
53}
54
55func WithPlatform(platform string) Option {
56 return func(p *Prompt) {
57 p.platform = platform
58 }
59}
60
61func WithWorkingDir(workingDir string) Option {
62 return func(p *Prompt) {
63 p.workingDir = workingDir
64 }
65}
66
67func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
68 p := &Prompt{
69 name: name,
70 template: promptTemplate,
71 now: time.Now,
72 }
73 for _, opt := range opts {
74 opt(p)
75 }
76 return p, nil
77}
78
79func (p *Prompt) Build(ctx context.Context, provider, model string, cfg config.Config) (string, error) {
80 t, err := template.New(p.name).Parse(p.template)
81 if err != nil {
82 return "", fmt.Errorf("parsing template: %w", err)
83 }
84 var sb strings.Builder
85 d, err := p.promptData(ctx, provider, model, cfg)
86 if err != nil {
87 return "", err
88 }
89 if err := t.Execute(&sb, d); err != nil {
90 return "", fmt.Errorf("executing template: %w", err)
91 }
92
93 return sb.String(), nil
94}
95
96func processFile(filePath string) *ContextFile {
97 content, err := os.ReadFile(filePath)
98 if err != nil {
99 return nil
100 }
101 return &ContextFile{
102 Path: filePath,
103 Content: string(content),
104 }
105}
106
107func processContextPath(p string, cfg config.Config) []ContextFile {
108 var contexts []ContextFile
109 fullPath := p
110 if !filepath.IsAbs(p) {
111 fullPath = filepath.Join(cfg.WorkingDir(), p)
112 }
113 info, err := os.Stat(fullPath)
114 if err != nil {
115 return contexts
116 }
117 if info.IsDir() {
118 filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
119 if err != nil {
120 return err
121 }
122 if !d.IsDir() {
123 if result := processFile(path); result != nil {
124 contexts = append(contexts, *result)
125 }
126 }
127 return nil
128 })
129 } else {
130 result := processFile(fullPath)
131 if result != nil {
132 contexts = append(contexts, *result)
133 }
134 }
135 return contexts
136}
137
138// expandPath expands ~ and environment variables in file paths
139func expandPath(path string, cfg config.Config) string {
140 path = home.Long(path)
141 // Handle environment variable expansion using the same pattern as config
142 if strings.HasPrefix(path, "$") {
143 if expanded, err := cfg.Resolver().ResolveValue(path); err == nil {
144 path = expanded
145 }
146 }
147
148 return path
149}
150
151func (p *Prompt) promptData(ctx context.Context, provider, model string, cfg config.Config) (PromptDat, error) {
152 workingDir := cmp.Or(p.workingDir, cfg.WorkingDir())
153 platform := cmp.Or(p.platform, runtime.GOOS)
154
155 files := map[string][]ContextFile{}
156
157 for _, pth := range cfg.Options.ContextPaths {
158 expanded := expandPath(pth, cfg)
159 pathKey := strings.ToLower(expanded)
160 if _, ok := files[pathKey]; ok {
161 continue
162 }
163 content := processContextPath(expanded, cfg)
164 files[pathKey] = content
165 }
166
167 // Discover and load skills metadata.
168 var availSkillXML string
169 if len(cfg.Options.SkillsPaths) > 0 {
170 expandedPaths := make([]string, 0, len(cfg.Options.SkillsPaths))
171 for _, pth := range cfg.Options.SkillsPaths {
172 expandedPaths = append(expandedPaths, expandPath(pth, cfg))
173 }
174 if discoveredSkills := skills.Discover(expandedPaths); len(discoveredSkills) > 0 {
175 availSkillXML = skills.ToPromptXML(discoveredSkills)
176 }
177 }
178
179 isGit := isGitRepo(cfg.WorkingDir())
180 data := PromptDat{
181 Provider: provider,
182 Model: model,
183 Config: cfg,
184 WorkingDir: filepath.ToSlash(workingDir),
185 IsGitRepo: isGit,
186 Platform: platform,
187 Date: p.now().Format("1/2/2006"),
188 AvailSkillXML: availSkillXML,
189 }
190 if isGit {
191 var err error
192 data.GitStatus, err = getGitStatus(ctx, cfg.WorkingDir())
193 if err != nil {
194 return PromptDat{}, err
195 }
196 }
197
198 for _, contextFiles := range files {
199 data.ContextFiles = append(data.ContextFiles, contextFiles...)
200 }
201 return data, nil
202}
203
// isGitRepo reports whether os.Stat succeeds for a ".git" entry directly
// inside dir. Any Stat failure, whatever its cause, yields false.
func isGitRepo(dir string) bool {
	gitPath := filepath.Join(dir, ".git")
	if _, statErr := os.Stat(gitPath); statErr != nil {
		return false
	}
	return true
}
208
209func getGitStatus(ctx context.Context, dir string) (string, error) {
210 sh := shell.NewShell(&shell.Options{
211 WorkingDir: dir,
212 })
213 branch, err := getGitBranch(ctx, sh)
214 if err != nil {
215 return "", err
216 }
217 status, err := getGitStatusSummary(ctx, sh)
218 if err != nil {
219 return "", err
220 }
221 commits, err := getGitRecentCommits(ctx, sh)
222 if err != nil {
223 return "", err
224 }
225 return branch + status + commits, nil
226}
227
228func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
229 out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
230 if err != nil {
231 return "", nil
232 }
233 out = strings.TrimSpace(out)
234 if out == "" {
235 return "", nil
236 }
237 return fmt.Sprintf("Current branch: %s\n", out), nil
238}
239
240func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
241 out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
242 if err != nil {
243 return "", nil
244 }
245 out = strings.TrimSpace(out)
246 if out == "" {
247 return "Status: clean\n", nil
248 }
249 return fmt.Sprintf("Status:\n%s\n", out), nil
250}
251
252func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
253 out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
254 if err != nil || out == "" {
255 return "", nil
256 }
257 out = strings.TrimSpace(out)
258 return fmt.Sprintf("Recent commits:\n%s\n", out), nil
259}
260
// Name returns the name the prompt was created with.
func (p *Prompt) Name() string {
	return p.name
}