1package prompt
2
3import (
4 "fmt"
5 "os"
6 "path/filepath"
7 "runtime"
8 "strings"
9 "text/template"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/home"
14)
15
// Prompt renders a named text/template into a prompt string. It keeps the raw
// template text (parsed lazily by Build) and an injectable clock that supplies
// the date exposed to the template.
type Prompt struct {
	name, template string
	now            func() time.Time
}
22
23type PromptDat struct {
24 Provider string
25 Model string
26 Config config.Config
27 WorkingDir string
28 IsGitRepo bool
29 Platform string
30 Date string
31}
32
// ContextFile pairs a file's path with the full text that was read from it.
type ContextFile struct {
	Path, Content string
}
37
// Option customizes a Prompt; NewPrompt applies each Option, in order, after
// setting the defaults.
type Option func(*Prompt)
39
40func WithTimeFunc(fn func() time.Time) Option {
41 return func(p *Prompt) {
42 p.now = fn
43 }
44}
45
46func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
47 p := &Prompt{
48 name: name,
49 template: promptTemplate,
50 now: time.Now,
51 }
52 for _, opt := range opts {
53 opt(p)
54 }
55 return p, nil
56}
57
58func (p *Prompt) Build(provider, model string, cfg config.Config) (string, error) {
59 t, err := template.New(p.name).Funcs(p.funcMap(cfg)).Parse(p.template)
60 if err != nil {
61 return "", fmt.Errorf("parsing template: %w", err)
62 }
63 var sb strings.Builder
64 if err := t.Execute(&sb, p.promptData(provider, model, cfg)); err != nil {
65 return "", fmt.Errorf("executing template: %w", err)
66 }
67
68 return sb.String(), nil
69}
70
71func (p *Prompt) funcMap(cfg config.Config) template.FuncMap {
72 return template.FuncMap{
73 "contextFiles": func(path string) []ContextFile {
74 path = expandPath(path, cfg)
75 return processContextPath(path, cfg)
76 },
77 }
78}
79
80func processFile(filePath string) *ContextFile {
81 content, err := os.ReadFile(filePath)
82 if err != nil {
83 return nil
84 }
85 return &ContextFile{
86 Path: filePath,
87 Content: string(content),
88 }
89}
90
91func processContextPath(p string, cfg config.Config) []ContextFile {
92 var contexts []ContextFile
93 fullPath := p
94 if !filepath.IsAbs(p) {
95 fullPath = filepath.Join(cfg.WorkingDir(), p)
96 }
97 info, err := os.Stat(fullPath)
98 if err != nil {
99 return contexts
100 }
101 if info.IsDir() {
102 filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
103 if err != nil {
104 return err
105 }
106 if !d.IsDir() {
107 if result := processFile(path); result != nil {
108 contexts = append(contexts, *result)
109 }
110 }
111 return nil
112 })
113 } else {
114 result := processFile(fullPath)
115 if result != nil {
116 contexts = append(contexts, *result)
117 }
118 }
119 return contexts
120}
121
122// expandPath expands ~ and environment variables in file paths
123func expandPath(path string, cfg config.Config) string {
124 path = home.Long(path)
125 // Handle environment variable expansion using the same pattern as config
126 if strings.HasPrefix(path, "$") {
127 if expanded, err := cfg.Resolver().ResolveValue(path); err == nil {
128 path = expanded
129 }
130 }
131
132 return path
133}
134
135func (p *Prompt) promptData(provider, model string, cfg config.Config) PromptDat {
136 return PromptDat{
137 Provider: provider,
138 Model: model,
139 Config: cfg,
140 WorkingDir: cfg.WorkingDir(),
141 IsGitRepo: isGitRepo(cfg.WorkingDir()),
142 Platform: runtime.GOOS,
143 Date: p.now().Format("1/2/2006"),
144 }
145}
146
// isGitRepo reports whether dir itself contains a ".git" entry, either a
// directory or a file. Parent directories are not consulted, and any os.Stat
// failure, including "does not exist", is treated as false.
func isGitRepo(dir string) bool {
	if _, err := os.Stat(filepath.Join(dir, ".git")); err != nil {
		return false
	}
	return true
}
151
// Name returns the name the Prompt was created with. Build also uses it as
// the name of the parsed template.
func (p *Prompt) Name() string {
	return p.name
}