1package prompt
2
3import (
4 "cmp"
5 "context"
6 "fmt"
7 "os"
8 "path/filepath"
9 "runtime"
10 "strings"
11 "text/template"
12 "time"
13
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/home"
16 "github.com/charmbracelet/crush/internal/shell"
17 "github.com/charmbracelet/crush/internal/skills"
18)
19
20// Prompt represents a template-based prompt generator.
21type Prompt struct {
22 name string
23 template string
24 now func() time.Time
25 platform string
26 workingDir string
27 modelFamily ModelFamily
28}
29
30type PromptDat struct {
31 Provider string
32 Model string
33 ModelFamily ModelFamily
34 Config config.Config
35 WorkingDir string
36 IsGitRepo bool
37 Platform string
38 Date string
39 GitStatus string
40 ContextFiles []ContextFile
41 AvailSkillXML string
42}
43
// ContextFile is one file loaded from a configured context path, as exposed
// to the prompt template.
type ContextFile struct {
	// Path is the location the file was read from; Content holds its
	// complete contents.
	Path, Content string
}
48
49type Option func(*Prompt)
50
51func WithTimeFunc(fn func() time.Time) Option {
52 return func(p *Prompt) {
53 p.now = fn
54 }
55}
56
57func WithPlatform(platform string) Option {
58 return func(p *Prompt) {
59 p.platform = platform
60 }
61}
62
63func WithWorkingDir(workingDir string) Option {
64 return func(p *Prompt) {
65 p.workingDir = workingDir
66 }
67}
68
69func WithModelFamily(modelFamily ModelFamily) Option {
70 return func(p *Prompt) {
71 p.modelFamily = modelFamily
72 }
73}
74
75func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
76 p := &Prompt{
77 name: name,
78 template: promptTemplate,
79 now: time.Now,
80 }
81 for _, opt := range opts {
82 opt(p)
83 }
84 return p, nil
85}
86
87func (p *Prompt) Build(ctx context.Context, provider, model string, cfg config.Config) (string, error) {
88 t, err := template.New(p.name).Parse(p.template)
89 if err != nil {
90 return "", fmt.Errorf("parsing template: %w", err)
91 }
92 var sb strings.Builder
93 d, err := p.promptData(ctx, provider, model, cfg)
94 if err != nil {
95 return "", err
96 }
97 if err := t.Execute(&sb, d); err != nil {
98 return "", fmt.Errorf("executing template: %w", err)
99 }
100
101 return sb.String(), nil
102}
103
104func processFile(filePath string) *ContextFile {
105 content, err := os.ReadFile(filePath)
106 if err != nil {
107 return nil
108 }
109 return &ContextFile{
110 Path: filePath,
111 Content: string(content),
112 }
113}
114
115func processContextPath(p string, cfg config.Config) []ContextFile {
116 var contexts []ContextFile
117 fullPath := p
118 if !filepath.IsAbs(p) {
119 fullPath = filepath.Join(cfg.WorkingDir(), p)
120 }
121 info, err := os.Stat(fullPath)
122 if err != nil {
123 return contexts
124 }
125 if info.IsDir() {
126 filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
127 if err != nil {
128 return err
129 }
130 if !d.IsDir() {
131 if result := processFile(path); result != nil {
132 contexts = append(contexts, *result)
133 }
134 }
135 return nil
136 })
137 } else {
138 result := processFile(fullPath)
139 if result != nil {
140 contexts = append(contexts, *result)
141 }
142 }
143 return contexts
144}
145
146// expandPath expands ~ and environment variables in file paths
147func expandPath(path string, cfg config.Config) string {
148 path = home.Long(path)
149 // Handle environment variable expansion using the same pattern as config
150 if strings.HasPrefix(path, "$") {
151 if expanded, err := cfg.Resolver().ResolveValue(path); err == nil {
152 path = expanded
153 }
154 }
155
156 return path
157}
158
159func (p *Prompt) promptData(ctx context.Context, provider, model string, cfg config.Config) (PromptDat, error) {
160 workingDir := cmp.Or(p.workingDir, cfg.WorkingDir())
161 platform := cmp.Or(p.platform, runtime.GOOS)
162
163 files := map[string][]ContextFile{}
164
165 for _, pth := range cfg.Options.ContextPaths {
166 expanded := expandPath(pth, cfg)
167 pathKey := strings.ToLower(expanded)
168 if _, ok := files[pathKey]; ok {
169 continue
170 }
171 content := processContextPath(expanded, cfg)
172 files[pathKey] = content
173 }
174
175 // Discover and load skills metadata.
176 var availSkillXML string
177 if len(cfg.Options.SkillsPaths) > 0 {
178 expandedPaths := make([]string, 0, len(cfg.Options.SkillsPaths))
179 for _, pth := range cfg.Options.SkillsPaths {
180 expandedPaths = append(expandedPaths, expandPath(pth, cfg))
181 }
182 if discoveredSkills := skills.Discover(expandedPaths); len(discoveredSkills) > 0 {
183 availSkillXML = skills.ToPromptXML(discoveredSkills)
184 }
185 }
186
187 isGit := isGitRepo(cfg.WorkingDir())
188 data := PromptDat{
189 Provider: provider,
190 Model: model,
191 ModelFamily: p.modelFamily,
192 Config: cfg,
193 WorkingDir: filepath.ToSlash(workingDir),
194 IsGitRepo: isGit,
195 Platform: platform,
196 Date: p.now().Format("1/2/2006"),
197 AvailSkillXML: availSkillXML,
198 }
199 if isGit {
200 var err error
201 data.GitStatus, err = getGitStatus(ctx, cfg.WorkingDir())
202 if err != nil {
203 return PromptDat{}, err
204 }
205 }
206
207 for _, contextFiles := range files {
208 data.ContextFiles = append(data.ContextFiles, contextFiles...)
209 }
210 return data, nil
211}
212
// isGitRepo reports whether dir contains an entry named .git. Any entry type
// counts, including a plain file. Any Stat error, not only "does not exist",
// is treated as "not a repository".
func isGitRepo(dir string) bool {
	if _, err := os.Stat(filepath.Join(dir, ".git")); err != nil {
		return false
	}
	return true
}
217
218func getGitStatus(ctx context.Context, dir string) (string, error) {
219 sh := shell.NewShell(&shell.Options{
220 WorkingDir: dir,
221 })
222 branch, err := getGitBranch(ctx, sh)
223 if err != nil {
224 return "", err
225 }
226 status, err := getGitStatusSummary(ctx, sh)
227 if err != nil {
228 return "", err
229 }
230 commits, err := getGitRecentCommits(ctx, sh)
231 if err != nil {
232 return "", err
233 }
234 return branch + status + commits, nil
235}
236
237func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
238 out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
239 if err != nil {
240 return "", nil
241 }
242 out = strings.TrimSpace(out)
243 if out == "" {
244 return "", nil
245 }
246 return fmt.Sprintf("Current branch: %s\n", out), nil
247}
248
249func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
250 out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
251 if err != nil {
252 return "", nil
253 }
254 out = strings.TrimSpace(out)
255 if out == "" {
256 return "Status: clean\n", nil
257 }
258 return fmt.Sprintf("Status:\n%s\n", out), nil
259}
260
261func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
262 out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
263 if err != nil || out == "" {
264 return "", nil
265 }
266 out = strings.TrimSpace(out)
267 return fmt.Sprintf("Recent commits:\n%s\n", out), nil
268}
269
270func (p *Prompt) Name() string {
271 return p.name
272}