1package prompt
2
3import (
4 "fmt"
5 "os"
6 "path/filepath"
7 "runtime"
8 "strings"
9 "text/template"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/home"
14)
15
// Prompt renders a named text/template into a system prompt.
//
// Build one with NewPrompt; the optional settings below are filled in by the
// With* options.
type Prompt struct {
	// name is given to the parsed template; template is its unparsed source.
	name, template string

	// now supplies the time rendered as PromptDat.Date.
	now func() time.Time

	// platform and workingDir override runtime.GOOS and cfg.WorkingDir()
	// respectively; the empty string means "no override".
	platform, workingDir string
}
24
25type PromptDat struct {
26 Provider string
27 Model string
28 Config config.Config
29 WorkingDir string
30 IsGitRepo bool
31 Platform string
32 Date string
33 ContextFiles []ContextFile
34}
35
// ContextFile is one file loaded from a configured context path and exposed
// to templates through PromptDat.ContextFiles.
type ContextFile struct {
	// Path is the path the file was read from; Content is its text.
	Path, Content string
}
40
41type Option func(*Prompt)
42
43func WithTimeFunc(fn func() time.Time) Option {
44 return func(p *Prompt) {
45 p.now = fn
46 }
47}
48
49func WithPlatform(platform string) Option {
50 return func(p *Prompt) {
51 p.platform = platform
52 }
53}
54
55func WithWorkingDir(workingDir string) Option {
56 return func(p *Prompt) {
57 p.workingDir = workingDir
58 }
59}
60
61func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
62 p := &Prompt{
63 name: name,
64 template: promptTemplate,
65 now: time.Now,
66 }
67 for _, opt := range opts {
68 opt(p)
69 }
70 return p, nil
71}
72
73func (p *Prompt) Build(provider, model string, cfg config.Config) (string, error) {
74 t, err := template.New(p.name).Parse(p.template)
75 if err != nil {
76 return "", fmt.Errorf("parsing template: %w", err)
77 }
78 var sb strings.Builder
79 if err := t.Execute(&sb, p.promptData(provider, model, cfg)); err != nil {
80 return "", fmt.Errorf("executing template: %w", err)
81 }
82
83 return sb.String(), nil
84}
85
86func processFile(filePath string) *ContextFile {
87 content, err := os.ReadFile(filePath)
88 if err != nil {
89 return nil
90 }
91 return &ContextFile{
92 Path: filePath,
93 Content: string(content),
94 }
95}
96
97func processContextPath(p string, cfg config.Config) []ContextFile {
98 var contexts []ContextFile
99 fullPath := p
100 if !filepath.IsAbs(p) {
101 fullPath = filepath.Join(cfg.WorkingDir(), p)
102 }
103 info, err := os.Stat(fullPath)
104 if err != nil {
105 return contexts
106 }
107 if info.IsDir() {
108 filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
109 if err != nil {
110 return err
111 }
112 if !d.IsDir() {
113 if result := processFile(path); result != nil {
114 contexts = append(contexts, *result)
115 }
116 }
117 return nil
118 })
119 } else {
120 result := processFile(fullPath)
121 if result != nil {
122 contexts = append(contexts, *result)
123 }
124 }
125 return contexts
126}
127
128// expandPath expands ~ and environment variables in file paths
129func expandPath(path string, cfg config.Config) string {
130 path = home.Long(path)
131 // Handle environment variable expansion using the same pattern as config
132 if strings.HasPrefix(path, "$") {
133 if expanded, err := cfg.Resolver().ResolveValue(path); err == nil {
134 path = expanded
135 }
136 }
137
138 return path
139}
140
141func (p *Prompt) promptData(provider, model string, cfg config.Config) PromptDat {
142 workingDir := cfg.WorkingDir()
143 if p.workingDir != "" {
144 workingDir = p.workingDir
145 }
146 platform := runtime.GOOS
147 if p.platform != "" {
148 platform = p.platform
149 }
150
151 files := map[string][]ContextFile{}
152
153 for _, pth := range cfg.Options.ContextPaths {
154 expanded := expandPath(pth, cfg)
155 pathKey := strings.ToLower(expanded)
156 if _, ok := files[pathKey]; ok {
157 continue
158 }
159 content := processContextPath(expanded, cfg)
160 files[pathKey] = content
161 }
162
163 data := PromptDat{
164 Provider: provider,
165 Model: model,
166 Config: cfg,
167 WorkingDir: workingDir,
168 IsGitRepo: isGitRepo(cfg.WorkingDir()),
169 Platform: platform,
170 Date: p.now().Format("1/2/2006"),
171 }
172
173 for _, contextFiles := range files {
174 data.ContextFiles = append(data.ContextFiles, contextFiles...)
175 }
176 return data
177}
178
// isGitRepo reports whether dir contains an entry named ".git". Any kind of
// entry counts, so both a directory and a plain file qualify. Any Stat
// failure, including a missing dir, yields false.
func isGitRepo(dir string) bool {
	if _, statErr := os.Stat(filepath.Join(dir, ".git")); statErr != nil {
		return false
	}
	return true
}
183
184func (p *Prompt) Name() string {
185 return p.name
186}