1package prompt
2
3import (
4 "cmp"
5 "context"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "runtime"
11 "strings"
12 "text/template"
13 "time"
14
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/home"
17 "github.com/charmbracelet/crush/internal/shell"
18 "github.com/charmbracelet/crush/internal/skills"
19)
20
// Prompt represents a template-based prompt generator.
//
// The template text is parsed with text/template and executed against a
// PromptDat each time Build is called.
type Prompt struct {
	// name identifies the prompt (returned by Name) and is used as the
	// text/template name when parsing.
	name string
	// template is the raw text/template source.
	template string
	// now supplies the current time used for the Date field; NewPrompt
	// defaults it to time.Now.
	now func() time.Time
	// platform, when non-empty, replaces runtime.GOOS in the Platform field.
	platform string
	// workingDir, when non-empty, replaces the store's working directory in
	// the rendered WorkingDir field.
	workingDir string
}
29
// PromptDat is the data a prompt template is executed against by Build.
// Templates reference these fields by name (for example {{.WorkingDir}}).
type PromptDat struct {
	// Provider is the provider identifier passed to Build.
	Provider string
	// Model is the model identifier passed to Build.
	Model string
	// Config is a copy of the configuration returned by the store.
	Config config.Config
	// WorkingDir is the working directory in slash-separated form: the
	// WithWorkingDir override when set, otherwise the store's working dir.
	WorkingDir string
	// IsGitRepo reports whether a ".git" entry exists directly inside the
	// store's working directory.
	IsGitRepo bool
	// Platform is the WithPlatform override when set, otherwise runtime.GOOS.
	Platform string
	// Date is the Prompt clock's current date formatted with layout
	// "1/2/2006" (month/day/year, no zero padding).
	Date string
	// GitStatus is the branch, short status and recent-commit summary; it is
	// left empty when IsGitRepo is false.
	GitStatus string
	// ContextFiles holds the files loaded from the configured context paths.
	ContextFiles []ContextFile
	// AvailSkillXML is the XML description of the enabled skills, or empty
	// when no skills remain after filtering.
	AvailSkillXML string
}
42
// ContextFile is a file loaded from one of the configured context paths and
// exposed to the prompt template.
type ContextFile struct {
	// Path is the path the file was read from: the expanded configured path
	// (joined with the store's working directory when relative), or a file
	// found beneath a configured directory.
	Path string
	// Content is the file's full contents as a string.
	Content string
}
47
// Option configures a Prompt; NewPrompt applies options in the order given.
type Option func(*Prompt)
49
50func WithTimeFunc(fn func() time.Time) Option {
51 return func(p *Prompt) {
52 p.now = fn
53 }
54}
55
56func WithPlatform(platform string) Option {
57 return func(p *Prompt) {
58 p.platform = platform
59 }
60}
61
62func WithWorkingDir(workingDir string) Option {
63 return func(p *Prompt) {
64 p.workingDir = workingDir
65 }
66}
67
68func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
69 p := &Prompt{
70 name: name,
71 template: promptTemplate,
72 now: time.Now,
73 }
74 for _, opt := range opts {
75 opt(p)
76 }
77 return p, nil
78}
79
80func (p *Prompt) Build(ctx context.Context, provider, model string, store *config.ConfigStore) (string, error) {
81 t, err := template.New(p.name).Parse(p.template)
82 if err != nil {
83 return "", fmt.Errorf("parsing template: %w", err)
84 }
85 var sb strings.Builder
86 d, err := p.promptData(ctx, provider, model, store)
87 if err != nil {
88 return "", err
89 }
90 if err := t.Execute(&sb, d); err != nil {
91 return "", fmt.Errorf("executing template: %w", err)
92 }
93
94 return sb.String(), nil
95}
96
97func processFile(filePath string) *ContextFile {
98 content, err := os.ReadFile(filePath)
99 if err != nil {
100 return nil
101 }
102 return &ContextFile{
103 Path: filePath,
104 Content: string(content),
105 }
106}
107
108func processContextPath(p string, store *config.ConfigStore) []ContextFile {
109 var contexts []ContextFile
110 fullPath := p
111 if !filepath.IsAbs(p) {
112 fullPath = filepath.Join(store.WorkingDir(), p)
113 }
114 info, err := os.Stat(fullPath)
115 if err != nil {
116 return contexts
117 }
118 if info.IsDir() {
119 filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
120 if err != nil {
121 return err
122 }
123 if !d.IsDir() {
124 if result := processFile(path); result != nil {
125 contexts = append(contexts, *result)
126 }
127 }
128 return nil
129 })
130 } else {
131 result := processFile(fullPath)
132 if result != nil {
133 contexts = append(contexts, *result)
134 }
135 }
136 return contexts
137}
138
139// expandPath expands ~ and environment variables in file paths
140func expandPath(path string, store *config.ConfigStore) string {
141 path = home.Long(path)
142 // Handle environment variable expansion using the same pattern as config
143 if strings.HasPrefix(path, "$") {
144 if expanded, err := store.Resolver().ResolveValue(path); err == nil {
145 path = expanded
146 }
147 }
148
149 return path
150}
151
152func (p *Prompt) promptData(ctx context.Context, provider, model string, store *config.ConfigStore) (PromptDat, error) {
153 workingDir := cmp.Or(p.workingDir, store.WorkingDir())
154 platform := cmp.Or(p.platform, runtime.GOOS)
155
156 files := map[string][]ContextFile{}
157
158 cfg := store.Config()
159 for _, pth := range cfg.Options.ContextPaths {
160 expanded := expandPath(pth, store)
161 pathKey := strings.ToLower(expanded)
162 if _, ok := files[pathKey]; ok {
163 continue
164 }
165 content := processContextPath(expanded, store)
166 files[pathKey] = content
167 }
168
169 // Discover and load skills metadata.
170 var availSkillXML string
171
172 // Start with builtin skills.
173 allSkills := skills.DiscoverBuiltin()
174 builtinNames := make(map[string]bool, len(allSkills))
175 for _, s := range allSkills {
176 builtinNames[s.Name] = true
177 }
178
179 // Discover user skills from configured paths.
180 if len(cfg.Options.SkillsPaths) > 0 {
181 expandedPaths := make([]string, 0, len(cfg.Options.SkillsPaths))
182 for _, pth := range cfg.Options.SkillsPaths {
183 expandedPaths = append(expandedPaths, expandPath(pth, store))
184 }
185 for _, userSkill := range skills.Discover(expandedPaths) {
186 if builtinNames[userSkill.Name] {
187 slog.Warn("User skill overrides builtin skill", "name", userSkill.Name)
188 }
189 allSkills = append(allSkills, userSkill)
190 }
191 }
192
193 // Deduplicate: user skills override builtins with the same name.
194 allSkills = skills.Deduplicate(allSkills)
195
196 // Filter out disabled skills.
197 allSkills = skills.Filter(allSkills, cfg.Options.DisabledSkills)
198
199 if len(allSkills) > 0 {
200 availSkillXML = skills.ToPromptXML(allSkills)
201 }
202
203 isGit := isGitRepo(store.WorkingDir())
204 data := PromptDat{
205 Provider: provider,
206 Model: model,
207 Config: *cfg,
208 WorkingDir: filepath.ToSlash(workingDir),
209 IsGitRepo: isGit,
210 Platform: platform,
211 Date: p.now().Format("1/2/2006"),
212 AvailSkillXML: availSkillXML,
213 }
214 if isGit {
215 var err error
216 data.GitStatus, err = getGitStatus(ctx, store.WorkingDir())
217 if err != nil {
218 return PromptDat{}, err
219 }
220 }
221
222 for _, contextFiles := range files {
223 data.ContextFiles = append(data.ContextFiles, contextFiles...)
224 }
225 return data, nil
226}
227
// isGitRepo reports whether dir directly contains a ".git" entry (a
// directory, or a file as used by worktrees). Parent directories are not
// searched.
func isGitRepo(dir string) bool {
	if _, statErr := os.Stat(filepath.Join(dir, ".git")); statErr != nil {
		return false
	}
	return true
}
232
233func getGitStatus(ctx context.Context, dir string) (string, error) {
234 sh := shell.NewShell(&shell.Options{
235 WorkingDir: dir,
236 })
237 branch, err := getGitBranch(ctx, sh)
238 if err != nil {
239 return "", err
240 }
241 status, err := getGitStatusSummary(ctx, sh)
242 if err != nil {
243 return "", err
244 }
245 commits, err := getGitRecentCommits(ctx, sh)
246 if err != nil {
247 return "", err
248 }
249 return branch + status + commits, nil
250}
251
252func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
253 out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
254 if err != nil {
255 return "", nil
256 }
257 out = strings.TrimSpace(out)
258 if out == "" {
259 return "", nil
260 }
261 return fmt.Sprintf("Current branch: %s\n", out), nil
262}
263
264func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
265 out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
266 if err != nil {
267 return "", nil
268 }
269 out = strings.TrimSpace(out)
270 if out == "" {
271 return "Status: clean\n", nil
272 }
273 return fmt.Sprintf("Status:\n%s\n", out), nil
274}
275
276func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
277 out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
278 if err != nil || out == "" {
279 return "", nil
280 }
281 out = strings.TrimSpace(out)
282 return fmt.Sprintf("Recent commits:\n%s\n", out), nil
283}
284
285func (p *Prompt) Name() string {
286 return p.name
287}