1package prompt
2
3import (
4 "cmp"
5 "context"
6 "fmt"
7 "os"
8 "path/filepath"
9 "runtime"
10 "strings"
11 "text/template"
12 "time"
13
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/home"
16 "github.com/charmbracelet/crush/internal/shell"
17 "github.com/charmbracelet/crush/internal/skills"
18)
19
// Prompt represents a template-based prompt generator.
//
// It stores the raw text/template source together with the overrides applied
// by Option values. The template is parsed and rendered by Build.
type Prompt struct {
	name       string           // template name; also returned by Name
	template   string           // raw text/template source, parsed on every Build
	now        func() time.Time // clock used to render the Date field; NewPrompt sets time.Now
	platform   string           // platform override; runtime.GOOS is used when empty
	workingDir string           // working-directory override; the config service's directory is used when empty
}
28
// PromptDat is the value the prompt template is executed against in Build.
// Template actions refer to its fields, e.g. {{.WorkingDir}} or {{.GitStatus}}.
type PromptDat struct {
	Provider      string        // provider name passed to Build
	Model         string        // model name passed to Build
	Config        promptConfig  // subset of the configuration exposed to the template
	WorkingDir    string        // working directory, converted with filepath.ToSlash
	IsGitRepo     bool          // whether a .git entry was found (see isGitRepo)
	Platform      string        // platform name; runtime.GOOS unless overridden
	Date          string        // current date rendered with layout "1/2/2006"
	GitStatus     string        // branch / status / recent-commit summary; empty when not a git repo
	ContextFiles  []ContextFile // contents of the configured context files
	AvailSkillXML string        // XML describing discovered skills; empty when none were found
}
41
// promptConfig is the slice of configuration made available to the prompt
// template as {{.Config}}.
type promptConfig struct {
	LSP config.LSPs // LSP configuration, taken from the config service's LSP()
}
45
// ContextFile is one file loaded from a configured context path and injected
// into the prompt.
type ContextFile struct {
	Path    string // path the file was read from
	Content string // full file contents, read verbatim
}
50
// Option customizes a Prompt; options are applied in order by NewPrompt.
type Option func(*Prompt)
52
53func WithTimeFunc(fn func() time.Time) Option {
54 return func(p *Prompt) {
55 p.now = fn
56 }
57}
58
59func WithPlatform(platform string) Option {
60 return func(p *Prompt) {
61 p.platform = platform
62 }
63}
64
65func WithWorkingDir(workingDir string) Option {
66 return func(p *Prompt) {
67 p.workingDir = workingDir
68 }
69}
70
71func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
72 p := &Prompt{
73 name: name,
74 template: promptTemplate,
75 now: time.Now,
76 }
77 for _, opt := range opts {
78 opt(p)
79 }
80 return p, nil
81}
82
83func (p *Prompt) Build(ctx context.Context, provider, model string, svc *config.Service) (string, error) {
84 t, err := template.New(p.name).Parse(p.template)
85 if err != nil {
86 return "", fmt.Errorf("parsing template: %w", err)
87 }
88 var sb strings.Builder
89 d, err := p.promptData(ctx, provider, model, svc)
90 if err != nil {
91 return "", err
92 }
93 if err := t.Execute(&sb, d); err != nil {
94 return "", fmt.Errorf("executing template: %w", err)
95 }
96
97 return sb.String(), nil
98}
99
100func processFile(filePath string) *ContextFile {
101 content, err := os.ReadFile(filePath)
102 if err != nil {
103 return nil
104 }
105 return &ContextFile{
106 Path: filePath,
107 Content: string(content),
108 }
109}
110
111func processContextPath(p string, workingDir string) []ContextFile {
112 var contexts []ContextFile
113 fullPath := p
114 if !filepath.IsAbs(p) {
115 fullPath = filepath.Join(workingDir, p)
116 }
117 info, err := os.Stat(fullPath)
118 if err != nil {
119 return contexts
120 }
121 if info.IsDir() {
122 filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
123 if err != nil {
124 return err
125 }
126 if !d.IsDir() {
127 if result := processFile(path); result != nil {
128 contexts = append(contexts, *result)
129 }
130 }
131 return nil
132 })
133 } else {
134 result := processFile(fullPath)
135 if result != nil {
136 contexts = append(contexts, *result)
137 }
138 }
139 return contexts
140}
141
142// expandPath expands ~ and environment variables in file paths
143func expandPath(path string, resolver config.VariableResolver) string {
144 path = home.Long(path)
145 if strings.HasPrefix(path, "$") {
146 if resolver != nil {
147 if expanded, err := resolver.ResolveValue(path); err == nil {
148 path = expanded
149 }
150 }
151 }
152
153 return path
154}
155
156func (p *Prompt) promptData(ctx context.Context, provider, model string, svc *config.Service) (PromptDat, error) {
157 workingDir := cmp.Or(p.workingDir, svc.WorkingDir())
158 platform := cmp.Or(p.platform, runtime.GOOS)
159
160 files := map[string][]ContextFile{}
161
162 for _, pth := range svc.ContextPaths() {
163 expanded := expandPath(pth, svc.Resolver())
164 pathKey := strings.ToLower(expanded)
165 if _, ok := files[pathKey]; ok {
166 continue
167 }
168 content := processContextPath(expanded, svc.WorkingDir())
169 files[pathKey] = content
170 }
171
172 var availSkillXML string
173 if len(svc.SkillsPaths()) > 0 {
174 expandedPaths := make([]string, 0, len(svc.SkillsPaths()))
175 for _, pth := range svc.SkillsPaths() {
176 expandedPaths = append(expandedPaths, expandPath(pth, svc.Resolver()))
177 }
178 if discoveredSkills := skills.Discover(expandedPaths); len(discoveredSkills) > 0 {
179 availSkillXML = skills.ToPromptXML(discoveredSkills)
180 }
181 }
182
183 isGit := isGitRepo(svc.WorkingDir())
184 data := PromptDat{
185 Provider: provider,
186 Model: model,
187 Config: promptConfig{LSP: svc.LSP()},
188 WorkingDir: filepath.ToSlash(workingDir),
189 IsGitRepo: isGit,
190 Platform: platform,
191 Date: p.now().Format("1/2/2006"),
192 AvailSkillXML: availSkillXML,
193 }
194 if isGit {
195 var err error
196 data.GitStatus, err = getGitStatus(ctx, svc.WorkingDir())
197 if err != nil {
198 return PromptDat{}, err
199 }
200 }
201
202 for _, contextFiles := range files {
203 data.ContextFiles = append(data.ContextFiles, contextFiles...)
204 }
205 return data, nil
206}
207
// isGitRepo reports whether dir directly contains a ".git" entry. Any entry
// that os.Stat accepts counts, whether it is a directory or a plain file (git
// worktrees use a .git file). Parent directories are not searched, and any
// stat error, including permission errors, yields false.
func isGitRepo(dir string) bool {
	if _, statErr := os.Stat(filepath.Join(dir, ".git")); statErr != nil {
		return false
	}
	return true
}
212
213func getGitStatus(ctx context.Context, dir string) (string, error) {
214 sh := shell.NewShell(&shell.Options{
215 WorkingDir: dir,
216 })
217 branch, err := getGitBranch(ctx, sh)
218 if err != nil {
219 return "", err
220 }
221 status, err := getGitStatusSummary(ctx, sh)
222 if err != nil {
223 return "", err
224 }
225 commits, err := getGitRecentCommits(ctx, sh)
226 if err != nil {
227 return "", err
228 }
229 return branch + status + commits, nil
230}
231
232func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
233 out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
234 if err != nil {
235 return "", nil
236 }
237 out = strings.TrimSpace(out)
238 if out == "" {
239 return "", nil
240 }
241 return fmt.Sprintf("Current branch: %s\n", out), nil
242}
243
244func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
245 out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
246 if err != nil {
247 return "", nil
248 }
249 out = strings.TrimSpace(out)
250 if out == "" {
251 return "Status: clean\n", nil
252 }
253 return fmt.Sprintf("Status:\n%s\n", out), nil
254}
255
256func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
257 out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
258 if err != nil || out == "" {
259 return "", nil
260 }
261 out = strings.TrimSpace(out)
262 return fmt.Sprintf("Recent commits:\n%s\n", out), nil
263}
264
// Name returns the name the prompt was created with. Build also uses it as
// the template name when parsing.
func (p *Prompt) Name() string {
	return p.name
}