1package prompt
2
3import (
4 "cmp"
5 "context"
6 "fmt"
7 "os"
8 "path/filepath"
9 "runtime"
10 "strings"
11 "text/template"
12 "time"
13
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/home"
16 "github.com/charmbracelet/crush/internal/shell"
17 "github.com/charmbracelet/crush/internal/skills"
18)
19
// Prompt renders a named text/template into a system prompt. Construct it with
// NewPrompt: the zero value has a nil clock, and Build calls that clock.
type Prompt struct {
	// name identifies the template and is passed to template.New in Build.
	name string
	// template is the raw, unparsed template source.
	template string

	// Optional overrides set through Option values. When empty, Build falls
	// back to runtime.GOOS and the config store's working directory.
	platform   string
	workingDir string

	// now supplies the timestamp for the prompt's Date field.
	now func() time.Time
}
28
// PromptDat is the data value passed to the prompt template; templates
// reference its exported fields (for example {{.WorkingDir}}).
type PromptDat struct {
	// Provider and Model are passed through unchanged from Build's arguments.
	Provider string
	Model    string
	// Config is a copy of the configuration held by the config store.
	Config config.Config
	// WorkingDir is the working directory with forward slashes on every platform.
	WorkingDir string
	// IsGitRepo reports whether the store's working directory contains a .git entry.
	IsGitRepo bool
	// Platform is runtime.GOOS unless overridden with WithPlatform.
	Platform string
	// Date is the current date formatted with the layout "1/2/2006" (M/D/YYYY).
	Date string
	// GitStatus holds branch, status and recent-commit summaries; it is only
	// populated when IsGitRepo is true.
	GitStatus string
	// ContextFiles holds the files loaded from the configured context paths.
	ContextFiles []ContextFile
	// AvailSkillXML is the XML listing of discovered skills, or "" if none.
	AvailSkillXML string
}
41
// ContextFile is one file loaded from a configured context path.
type ContextFile struct {
	// Path is the filesystem path the file was read from.
	Path string
	// Content is the complete file contents as a string.
	Content string
}
46
// Option customizes a Prompt; pass any number of them to NewPrompt.
type Option func(*Prompt)
48
49func WithTimeFunc(fn func() time.Time) Option {
50 return func(p *Prompt) {
51 p.now = fn
52 }
53}
54
55func WithPlatform(platform string) Option {
56 return func(p *Prompt) {
57 p.platform = platform
58 }
59}
60
61func WithWorkingDir(workingDir string) Option {
62 return func(p *Prompt) {
63 p.workingDir = workingDir
64 }
65}
66
67func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
68 p := &Prompt{
69 name: name,
70 template: promptTemplate,
71 now: time.Now,
72 }
73 for _, opt := range opts {
74 opt(p)
75 }
76 return p, nil
77}
78
79func (p *Prompt) Build(ctx context.Context, provider, model string, store *config.ConfigStore) (string, error) {
80 t, err := template.New(p.name).Parse(p.template)
81 if err != nil {
82 return "", fmt.Errorf("parsing template: %w", err)
83 }
84 var sb strings.Builder
85 d, err := p.promptData(ctx, provider, model, store)
86 if err != nil {
87 return "", err
88 }
89 if err := t.Execute(&sb, d); err != nil {
90 return "", fmt.Errorf("executing template: %w", err)
91 }
92
93 return sb.String(), nil
94}
95
96func processFile(filePath string) *ContextFile {
97 content, err := os.ReadFile(filePath)
98 if err != nil {
99 return nil
100 }
101 return &ContextFile{
102 Path: filePath,
103 Content: string(content),
104 }
105}
106
107func processContextPath(p string, store *config.ConfigStore) []ContextFile {
108 var contexts []ContextFile
109 fullPath := p
110 if !filepath.IsAbs(p) {
111 fullPath = filepath.Join(store.WorkingDir(), p)
112 }
113 info, err := os.Stat(fullPath)
114 if err != nil {
115 return contexts
116 }
117 if info.IsDir() {
118 filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
119 if err != nil {
120 return err
121 }
122 if !d.IsDir() {
123 if result := processFile(path); result != nil {
124 contexts = append(contexts, *result)
125 }
126 }
127 return nil
128 })
129 } else {
130 result := processFile(fullPath)
131 if result != nil {
132 contexts = append(contexts, *result)
133 }
134 }
135 return contexts
136}
137
138// expandPath expands ~ and environment variables in file paths
139func expandPath(path string, store *config.ConfigStore) string {
140 path = home.Long(path)
141 // Handle environment variable expansion using the same pattern as config
142 if strings.HasPrefix(path, "$") {
143 if expanded, err := store.Resolver().ResolveValue(path); err == nil {
144 path = expanded
145 }
146 }
147
148 return path
149}
150
151func (p *Prompt) promptData(ctx context.Context, provider, model string, store *config.ConfigStore) (PromptDat, error) {
152 workingDir := cmp.Or(p.workingDir, store.WorkingDir())
153 platform := cmp.Or(p.platform, runtime.GOOS)
154
155 files := map[string][]ContextFile{}
156
157 cfg := store.Config()
158 for _, pth := range cfg.Options.ContextPaths {
159 expanded := expandPath(pth, store)
160 pathKey := strings.ToLower(expanded)
161 if _, ok := files[pathKey]; ok {
162 continue
163 }
164 content := processContextPath(expanded, store)
165 files[pathKey] = content
166 }
167
168 // Discover and load skills metadata.
169 var availSkillXML string
170 if len(cfg.Options.SkillsPaths) > 0 {
171 expandedPaths := make([]string, 0, len(cfg.Options.SkillsPaths))
172 for _, pth := range cfg.Options.SkillsPaths {
173 expandedPaths = append(expandedPaths, expandPath(pth, store))
174 }
175 if discoveredSkills := skills.Discover(expandedPaths); len(discoveredSkills) > 0 {
176 availSkillXML = skills.ToPromptXML(discoveredSkills)
177 }
178 }
179
180 isGit := isGitRepo(store.WorkingDir())
181 data := PromptDat{
182 Provider: provider,
183 Model: model,
184 Config: *cfg,
185 WorkingDir: filepath.ToSlash(workingDir),
186 IsGitRepo: isGit,
187 Platform: platform,
188 Date: p.now().Format("1/2/2006"),
189 AvailSkillXML: availSkillXML,
190 }
191 if isGit {
192 var err error
193 data.GitStatus, err = getGitStatus(ctx, store.WorkingDir())
194 if err != nil {
195 return PromptDat{}, err
196 }
197 }
198
199 for _, contextFiles := range files {
200 data.ContextFiles = append(data.ContextFiles, contextFiles...)
201 }
202 return data, nil
203}
204
// isGitRepo reports whether dir directly contains a ".git" entry. Either a
// directory or a file counts (worktrees and submodules use a .git file).
// Parent directories are not searched.
func isGitRepo(dir string) bool {
	if _, statErr := os.Stat(filepath.Join(dir, ".git")); statErr != nil {
		return false
	}
	return true
}
209
210func getGitStatus(ctx context.Context, dir string) (string, error) {
211 sh := shell.NewShell(&shell.Options{
212 WorkingDir: dir,
213 })
214 branch, err := getGitBranch(ctx, sh)
215 if err != nil {
216 return "", err
217 }
218 status, err := getGitStatusSummary(ctx, sh)
219 if err != nil {
220 return "", err
221 }
222 commits, err := getGitRecentCommits(ctx, sh)
223 if err != nil {
224 return "", err
225 }
226 return branch + status + commits, nil
227}
228
229func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
230 out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
231 if err != nil {
232 return "", nil
233 }
234 out = strings.TrimSpace(out)
235 if out == "" {
236 return "", nil
237 }
238 return fmt.Sprintf("Current branch: %s\n", out), nil
239}
240
241func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
242 out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
243 if err != nil {
244 return "", nil
245 }
246 out = strings.TrimSpace(out)
247 if out == "" {
248 return "Status: clean\n", nil
249 }
250 return fmt.Sprintf("Status:\n%s\n", out), nil
251}
252
253func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
254 out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
255 if err != nil || out == "" {
256 return "", nil
257 }
258 out = strings.TrimSpace(out)
259 return fmt.Sprintf("Recent commits:\n%s\n", out), nil
260}
261
262func (p *Prompt) Name() string {
263 return p.name
264}