1package prompt
2
3import (
4 "cmp"
5 "context"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "runtime"
11 "strings"
12 "text/template"
13 "time"
14
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/filepathext"
17 "github.com/charmbracelet/crush/internal/home"
18 "github.com/charmbracelet/crush/internal/shell"
19 "github.com/charmbracelet/crush/internal/skills"
20)
21
// Prompt renders a named text/template prompt.
//
// Construct values with NewPrompt: the zero value has no clock, and building
// the prompt calls it to fill in the Date field.
type Prompt struct {
	name, template       string           // template name and its source text
	now                  func() time.Time // clock used to fill PromptDat.Date
	platform, workingDir string           // optional overrides; empty selects the default
}
30
31type PromptDat struct {
32 Provider string
33 Model string
34 Config config.Config
35 WorkingDir string
36 IsGitRepo bool
37 Platform string
38 Date string
39 GitStatus string
40 ContextFiles []ContextFile
41 AvailSkillXML string
42}
43
// ContextFile is a single file loaded as additional prompt context.
type ContextFile struct {
	Path, Content string // path the file was read from, and its full contents
}
48
// Option customizes a Prompt under construction. NewPrompt applies the
// options it receives in the order given, after setting its defaults.
type Option func(*Prompt)
50
51func WithTimeFunc(fn func() time.Time) Option {
52 return func(p *Prompt) {
53 p.now = fn
54 }
55}
56
57func WithPlatform(platform string) Option {
58 return func(p *Prompt) {
59 p.platform = platform
60 }
61}
62
63func WithWorkingDir(workingDir string) Option {
64 return func(p *Prompt) {
65 p.workingDir = workingDir
66 }
67}
68
69func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
70 p := &Prompt{
71 name: name,
72 template: promptTemplate,
73 now: time.Now,
74 }
75 for _, opt := range opts {
76 opt(p)
77 }
78 return p, nil
79}
80
81func (p *Prompt) Build(ctx context.Context, provider, model string, store *config.ConfigStore) (string, error) {
82 t, err := template.New(p.name).Parse(p.template)
83 if err != nil {
84 return "", fmt.Errorf("parsing template: %w", err)
85 }
86 var sb strings.Builder
87 d, err := p.promptData(ctx, provider, model, store)
88 if err != nil {
89 return "", err
90 }
91 if err := t.Execute(&sb, d); err != nil {
92 return "", fmt.Errorf("executing template: %w", err)
93 }
94
95 return sb.String(), nil
96}
97
98func processFile(filePath string) *ContextFile {
99 content, err := os.ReadFile(filePath)
100 if err != nil {
101 return nil
102 }
103 return &ContextFile{
104 Path: filePath,
105 Content: string(content),
106 }
107}
108
109func processContextPath(p string, store *config.ConfigStore) []ContextFile {
110 var contexts []ContextFile
111 fullPath := filepathext.SmartJoin(store.WorkingDir(), p)
112 info, err := os.Stat(fullPath)
113 if err != nil {
114 return contexts
115 }
116 if info.IsDir() {
117 filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
118 if err != nil {
119 return err
120 }
121 if !d.IsDir() {
122 if result := processFile(path); result != nil {
123 contexts = append(contexts, *result)
124 }
125 }
126 return nil
127 })
128 } else {
129 result := processFile(fullPath)
130 if result != nil {
131 contexts = append(contexts, *result)
132 }
133 }
134 return contexts
135}
136
137// expandPath expands ~ and environment variables in file paths
138func expandPath(path string, store *config.ConfigStore) string {
139 path = home.Long(path)
140 // Handle environment variable expansion using the same pattern as config
141 if strings.HasPrefix(path, "$") {
142 if expanded, err := store.Resolver().ResolveValue(path); err == nil {
143 path = expanded
144 }
145 }
146
147 return path
148}
149
150func (p *Prompt) promptData(ctx context.Context, provider, model string, store *config.ConfigStore) (PromptDat, error) {
151 workingDir := cmp.Or(p.workingDir, store.WorkingDir())
152 platform := cmp.Or(p.platform, runtime.GOOS)
153
154 files := map[string][]ContextFile{}
155
156 cfg := store.Config()
157 for _, pth := range cfg.Options.ContextPaths {
158 expanded := expandPath(pth, store)
159 pathKey := strings.ToLower(expanded)
160 if _, ok := files[pathKey]; ok {
161 continue
162 }
163 content := processContextPath(expanded, store)
164 files[pathKey] = content
165 }
166
167 // Discover and load skills metadata.
168 var availSkillXML string
169
170 // Start with builtin skills.
171 allSkills := skills.DiscoverBuiltin()
172 builtinNames := make(map[string]bool, len(allSkills))
173 for _, s := range allSkills {
174 builtinNames[s.Name] = true
175 }
176
177 // Discover user skills from configured paths.
178 if len(cfg.Options.SkillsPaths) > 0 {
179 expandedPaths := make([]string, 0, len(cfg.Options.SkillsPaths))
180 for _, pth := range cfg.Options.SkillsPaths {
181 expandedPaths = append(expandedPaths, expandPath(pth, store))
182 }
183 for _, userSkill := range skills.Discover(expandedPaths) {
184 if builtinNames[userSkill.Name] {
185 slog.Warn("User skill overrides builtin skill", "name", userSkill.Name)
186 }
187 allSkills = append(allSkills, userSkill)
188 }
189 }
190
191 // Deduplicate: user skills override builtins with the same name.
192 allSkills = skills.Deduplicate(allSkills)
193
194 // Filter out disabled skills.
195 allSkills = skills.Filter(allSkills, cfg.Options.DisabledSkills)
196
197 if len(allSkills) > 0 {
198 availSkillXML = skills.ToPromptXML(allSkills)
199 }
200
201 isGit := isGitRepo(store.WorkingDir())
202 data := PromptDat{
203 Provider: provider,
204 Model: model,
205 Config: *cfg,
206 WorkingDir: filepath.ToSlash(workingDir),
207 IsGitRepo: isGit,
208 Platform: platform,
209 Date: p.now().Format("1/2/2006"),
210 AvailSkillXML: availSkillXML,
211 }
212 if isGit {
213 var err error
214 data.GitStatus, err = getGitStatus(ctx, store.WorkingDir())
215 if err != nil {
216 return PromptDat{}, err
217 }
218 }
219
220 for _, contextFiles := range files {
221 data.ContextFiles = append(data.ContextFiles, contextFiles...)
222 }
223 return data, nil
224}
225
// isGitRepo reports whether dir directly contains a ".git" entry. Any kind of
// entry counts, whether a directory or a file. Parent directories are not
// searched.
func isGitRepo(dir string) bool {
	if _, statErr := os.Stat(filepath.Join(dir, ".git")); statErr != nil {
		return false
	}
	return true
}
230
231func getGitStatus(ctx context.Context, dir string) (string, error) {
232 sh := shell.NewShell(&shell.Options{
233 WorkingDir: dir,
234 })
235 branch, err := getGitBranch(ctx, sh)
236 if err != nil {
237 return "", err
238 }
239 status, err := getGitStatusSummary(ctx, sh)
240 if err != nil {
241 return "", err
242 }
243 commits, err := getGitRecentCommits(ctx, sh)
244 if err != nil {
245 return "", err
246 }
247 return branch + status + commits, nil
248}
249
250func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
251 out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
252 if err != nil {
253 return "", nil
254 }
255 out = strings.TrimSpace(out)
256 if out == "" {
257 return "", nil
258 }
259 return fmt.Sprintf("Current branch: %s\n", out), nil
260}
261
262func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
263 out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
264 if err != nil {
265 return "", nil
266 }
267 out = strings.TrimSpace(out)
268 if out == "" {
269 return "Status: clean\n", nil
270 }
271 return fmt.Sprintf("Status:\n%s\n", out), nil
272}
273
274func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
275 out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
276 if err != nil || out == "" {
277 return "", nil
278 }
279 out = strings.TrimSpace(out)
280 return fmt.Sprintf("Recent commits:\n%s\n", out), nil
281}
282
// Name returns the name the Prompt was created with. Build also uses it as
// the text/template name.
func (p *Prompt) Name() string {
	return p.name
}