1package prompt
2
3import (
4 "cmp"
5 "context"
6 "fmt"
7 "log/slog"
8 "os"
9 "path/filepath"
10 "runtime"
11 "strings"
12 "text/template"
13 "time"
14
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/filepathext"
17 "github.com/charmbracelet/crush/internal/home"
18 "github.com/charmbracelet/crush/internal/shell"
19 "github.com/charmbracelet/crush/internal/skills"
20)
21
// Prompt represents a template-based prompt generator. It holds the raw
// text/template source plus optional overrides (clock, platform, working
// directory) that are applied when the template data is assembled.
type Prompt struct {
	// name is the prompt's identifier and the name given to the parsed template.
	// template is the unparsed text/template source.
	name, template string
	// now supplies the current time used for the prompt's date.
	now func() time.Time
	// platform and workingDir override runtime.GOOS and the store's working
	// directory respectively when non-empty.
	platform, workingDir string
}
30
31type PromptDat struct {
32 Provider string
33 Model string
34 Config config.Config
35 WorkingDir string
36 IsGitRepo bool
37 Platform string
38 Date string
39 GitStatus string
40 ContextFiles []ContextFile
41 GlobalContextFiles []ContextFile
42 AvailSkillXML string
43}
44
// ContextFile is a file whose contents are injected into the prompt.
type ContextFile struct {
	// Path is the path the file was read from; Content is its full text.
	Path, Content string
}
49
50type Option func(*Prompt)
51
52func WithTimeFunc(fn func() time.Time) Option {
53 return func(p *Prompt) {
54 p.now = fn
55 }
56}
57
58func WithPlatform(platform string) Option {
59 return func(p *Prompt) {
60 p.platform = platform
61 }
62}
63
64func WithWorkingDir(workingDir string) Option {
65 return func(p *Prompt) {
66 p.workingDir = workingDir
67 }
68}
69
70func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
71 p := &Prompt{
72 name: name,
73 template: promptTemplate,
74 now: time.Now,
75 }
76 for _, opt := range opts {
77 opt(p)
78 }
79 return p, nil
80}
81
82func (p *Prompt) Build(ctx context.Context, provider, model string, store *config.ConfigStore) (string, error) {
83 t, err := template.New(p.name).Parse(p.template)
84 if err != nil {
85 return "", fmt.Errorf("parsing template: %w", err)
86 }
87 var sb strings.Builder
88 d, err := p.promptData(ctx, provider, model, store)
89 if err != nil {
90 return "", err
91 }
92 if err := t.Execute(&sb, d); err != nil {
93 return "", fmt.Errorf("executing template: %w", err)
94 }
95
96 return sb.String(), nil
97}
98
99func processFile(filePath string) *ContextFile {
100 content, err := os.ReadFile(filePath)
101 if err != nil {
102 return nil
103 }
104 return &ContextFile{
105 Path: filePath,
106 Content: string(content),
107 }
108}
109
110func processContextPath(p string, store *config.ConfigStore) []ContextFile {
111 var contexts []ContextFile
112 fullPath := filepathext.SmartJoin(store.WorkingDir(), p)
113 info, err := os.Stat(fullPath)
114 if err != nil {
115 return contexts
116 }
117 if info.IsDir() {
118 filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
119 if err != nil {
120 return err
121 }
122 if !d.IsDir() {
123 if result := processFile(path); result != nil {
124 contexts = append(contexts, *result)
125 }
126 }
127 return nil
128 })
129 } else {
130 result := processFile(fullPath)
131 if result != nil {
132 contexts = append(contexts, *result)
133 }
134 }
135 return contexts
136}
137
138// expandPath expands ~ and environment variables in file paths
139func expandPath(path string, store *config.ConfigStore) string {
140 path = home.Long(path)
141 // Handle environment variable expansion using the same pattern as config
142 if strings.HasPrefix(path, "$") {
143 if expanded, err := store.Resolver().ResolveValue(path); err == nil {
144 path = expanded
145 }
146 }
147
148 return path
149}
150
151// loadContextFiles loads and deduplicates context files from a list of paths.
152func loadContextFiles(paths []string, store *config.ConfigStore) map[string][]ContextFile {
153 files := map[string][]ContextFile{}
154 for _, pth := range paths {
155 expanded := expandPath(pth, store)
156 pathKey := strings.ToLower(expanded)
157 if _, ok := files[pathKey]; ok {
158 continue
159 }
160 files[pathKey] = processContextPath(expanded, store)
161 }
162 return files
163}
164
165func (p *Prompt) promptData(ctx context.Context, provider, model string, store *config.ConfigStore) (PromptDat, error) {
166 workingDir := cmp.Or(p.workingDir, store.WorkingDir())
167 platform := cmp.Or(p.platform, runtime.GOOS)
168
169 cfg := store.Config()
170 contextFiles := loadContextFiles(cfg.Options.ContextPaths, store)
171 globalContextFiles := loadContextFiles(cfg.Options.GlobalContextPaths, store)
172
173 // Discover and load skills metadata.
174 var availSkillXML string
175
176 // Start with builtin skills.
177 allSkills := skills.DiscoverBuiltin()
178 builtinNames := make(map[string]bool, len(allSkills))
179 for _, s := range allSkills {
180 builtinNames[s.Name] = true
181 }
182
183 // Discover user skills from configured paths.
184 if len(cfg.Options.SkillsPaths) > 0 {
185 expandedPaths := make([]string, 0, len(cfg.Options.SkillsPaths))
186 for _, pth := range cfg.Options.SkillsPaths {
187 expandedPaths = append(expandedPaths, expandPath(pth, store))
188 }
189 for _, userSkill := range skills.Discover(expandedPaths) {
190 if builtinNames[userSkill.Name] {
191 slog.Warn("User skill overrides builtin skill", "name", userSkill.Name)
192 }
193 allSkills = append(allSkills, userSkill)
194 }
195 }
196
197 // Deduplicate: user skills override builtins with the same name.
198 allSkills = skills.Deduplicate(allSkills)
199
200 // Filter out disabled skills.
201 allSkills = skills.Filter(allSkills, cfg.Options.DisabledSkills)
202
203 if len(allSkills) > 0 {
204 availSkillXML = skills.ToPromptXML(allSkills)
205 }
206
207 isGit := isGitRepo(store.WorkingDir())
208 data := PromptDat{
209 Provider: provider,
210 Model: model,
211 Config: *cfg,
212 WorkingDir: filepath.ToSlash(workingDir),
213 IsGitRepo: isGit,
214 Platform: platform,
215 Date: p.now().Format("1/2/2006"),
216 AvailSkillXML: availSkillXML,
217 }
218 if isGit {
219 var err error
220 data.GitStatus, err = getGitStatus(ctx, store.WorkingDir())
221 if err != nil {
222 return PromptDat{}, err
223 }
224 }
225
226 for _, files := range contextFiles {
227 data.ContextFiles = append(data.ContextFiles, files...)
228 }
229 for _, files := range globalContextFiles {
230 data.GlobalContextFiles = append(data.GlobalContextFiles, files...)
231 }
232 return data, nil
233}
234
// isGitRepo reports whether dir contains a ".git" entry. The entry may be a
// directory or a file; any Stat error counts as "not a repo".
func isGitRepo(dir string) bool {
	if _, statErr := os.Stat(filepath.Join(dir, ".git")); statErr != nil {
		return false
	}
	return true
}
239
240func getGitStatus(ctx context.Context, dir string) (string, error) {
241 sh := shell.NewShell(&shell.Options{
242 WorkingDir: dir,
243 })
244 branch, err := getGitBranch(ctx, sh)
245 if err != nil {
246 return "", err
247 }
248 status, err := getGitStatusSummary(ctx, sh)
249 if err != nil {
250 return "", err
251 }
252 commits, err := getGitRecentCommits(ctx, sh)
253 if err != nil {
254 return "", err
255 }
256 return branch + status + commits, nil
257}
258
259func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
260 out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
261 if err != nil {
262 return "", nil
263 }
264 out = strings.TrimSpace(out)
265 if out == "" {
266 return "", nil
267 }
268 return fmt.Sprintf("Current branch: %s\n", out), nil
269}
270
271func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
272 out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
273 if err != nil {
274 return "", nil
275 }
276 out = strings.TrimSpace(out)
277 if out == "" {
278 return "Status: clean\n", nil
279 }
280 return fmt.Sprintf("Status:\n%s\n", out), nil
281}
282
283func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
284 out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
285 if err != nil || out == "" {
286 return "", nil
287 }
288 out = strings.TrimSpace(out)
289 return fmt.Sprintf("Recent commits:\n%s\n", out), nil
290}
291
292func (p *Prompt) Name() string {
293 return p.name
294}