1package prompt
2
3import (
4 "fmt"
5 "os"
6 "path/filepath"
7 "runtime"
8 "strings"
9 "text/template"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/home"
14)
15
// Prompt renders a named text/template into a prompt string via Build.
// Construct one with NewPrompt; the optional fields are filled in by Options.
type Prompt struct {
	// name identifies the prompt and is used as the parsed template's name;
	// template holds the raw template source.
	name, template string
	// now supplies the current time used for the template's Date field.
	now func() time.Time
	// platform and workingDir, when non-empty, replace runtime.GOOS and
	// cfg.WorkingDir() respectively in the data handed to the template.
	platform, workingDir string
}
24
25type PromptDat struct {
26 Provider string
27 Model string
28 Config config.Config
29 WorkingDir string
30 IsGitRepo bool
31 Platform string
32 Date string
33}
34
// ContextFile is one file gathered by the contextFiles template helper:
// the path it was read from and its full contents as text.
type ContextFile struct {
	Path, Content string
}
39
40type Option func(*Prompt)
41
42func WithTimeFunc(fn func() time.Time) Option {
43 return func(p *Prompt) {
44 p.now = fn
45 }
46}
47
48func WithPlatform(platform string) Option {
49 return func(p *Prompt) {
50 p.platform = platform
51 }
52}
53
54func WithWorkingDir(workingDir string) Option {
55 return func(p *Prompt) {
56 p.workingDir = workingDir
57 }
58}
59
60func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
61 p := &Prompt{
62 name: name,
63 template: promptTemplate,
64 now: time.Now,
65 }
66 for _, opt := range opts {
67 opt(p)
68 }
69 return p, nil
70}
71
72func (p *Prompt) Build(provider, model string, cfg config.Config) (string, error) {
73 t, err := template.New(p.name).Funcs(p.funcMap(cfg)).Parse(p.template)
74 if err != nil {
75 return "", fmt.Errorf("parsing template: %w", err)
76 }
77 var sb strings.Builder
78 if err := t.Execute(&sb, p.promptData(provider, model, cfg)); err != nil {
79 return "", fmt.Errorf("executing template: %w", err)
80 }
81
82 return sb.String(), nil
83}
84
85func (p *Prompt) funcMap(cfg config.Config) template.FuncMap {
86 return template.FuncMap{
87 "contextFiles": func(path string) []ContextFile {
88 path = expandPath(path, cfg)
89 return processContextPath(path, cfg)
90 },
91 }
92}
93
94func processFile(filePath string) *ContextFile {
95 content, err := os.ReadFile(filePath)
96 if err != nil {
97 return nil
98 }
99 return &ContextFile{
100 Path: filePath,
101 Content: string(content),
102 }
103}
104
105func processContextPath(p string, cfg config.Config) []ContextFile {
106 var contexts []ContextFile
107 fullPath := p
108 if !filepath.IsAbs(p) {
109 fullPath = filepath.Join(cfg.WorkingDir(), p)
110 }
111 info, err := os.Stat(fullPath)
112 if err != nil {
113 return contexts
114 }
115 if info.IsDir() {
116 filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
117 if err != nil {
118 return err
119 }
120 if !d.IsDir() {
121 if result := processFile(path); result != nil {
122 contexts = append(contexts, *result)
123 }
124 }
125 return nil
126 })
127 } else {
128 result := processFile(fullPath)
129 if result != nil {
130 contexts = append(contexts, *result)
131 }
132 }
133 return contexts
134}
135
136// expandPath expands ~ and environment variables in file paths
137func expandPath(path string, cfg config.Config) string {
138 path = home.Long(path)
139 // Handle environment variable expansion using the same pattern as config
140 if strings.HasPrefix(path, "$") {
141 if expanded, err := cfg.Resolver().ResolveValue(path); err == nil {
142 path = expanded
143 }
144 }
145
146 return path
147}
148
149func (p *Prompt) promptData(provider, model string, cfg config.Config) PromptDat {
150 workingDir := cfg.WorkingDir()
151 if p.workingDir != "" {
152 workingDir = p.workingDir
153 }
154 platform := runtime.GOOS
155 if p.platform != "" {
156 platform = p.platform
157 }
158 return PromptDat{
159 Provider: provider,
160 Model: model,
161 Config: cfg,
162 WorkingDir: workingDir,
163 IsGitRepo: isGitRepo(cfg.WorkingDir()),
164 Platform: platform,
165 Date: p.now().Format("1/2/2006"),
166 }
167}
168
// isGitRepo reports whether dir directly contains a ".git" entry. Any
// os.Stat error, including a missing entry, is treated as "not a repo".
// The entry may be a directory or a file.
func isGitRepo(dir string) bool {
	if _, statErr := os.Stat(filepath.Join(dir, ".git")); statErr != nil {
		return false
	}
	return true
}
173
174func (p *Prompt) Name() string {
175 return p.name
176}