1package prompt
2
3import (
4 "fmt"
5 "os"
6 "path/filepath"
7 "runtime"
8 "strings"
9 "text/template"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/home"
14)
15
// Prompt is a named text/template source that Build renders into a prompt
// string. The template text is stored raw and parsed on each Build call.
type Prompt struct {
	name, template string
}
21
22type PromptDat struct {
23 Provider string
24 Model string
25 Config config.Config
26 WorkingDir string
27 IsGitRepo bool
28 Platform string
29 Date string
30}
31
// ContextFile pairs a file path with the full text read from it; it is what
// the contextFiles template helper yields for each file it finds.
type ContextFile struct {
	Path, Content string
}
36
37func NewPrompt(name, promptTemplate string) (*Prompt, error) {
38 return &Prompt{
39 name: name,
40 template: promptTemplate,
41 }, nil
42}
43
// Build parses the prompt's template and renders it against the data returned by
// promptData(provider, model, cfg). The template may call the helpers from
// funcMap(cfg). Parse failures are wrapped as "parsing template: ..." and
// execution failures as "executing template: ...".
func (p *Prompt) Build(provider, model string, cfg config.Config) (string, error) {
	tmpl := template.New(p.name)
	tmpl = tmpl.Funcs(p.funcMap(cfg))
	parsed, err := tmpl.Parse(p.template)
	if err != nil {
		return "", fmt.Errorf("parsing template: %w", err)
	}

	data := promptData(provider, model, cfg)
	var out strings.Builder
	if err = parsed.Execute(&out, data); err != nil {
		return "", fmt.Errorf("executing template: %w", err)
	}
	return out.String(), nil
}
56
// funcMap returns the helper functions exposed to prompt templates.
//
// contextFiles(path) runs path through expandPath and then returns whatever
// processContextPath collects from the result.
func (p *Prompt) funcMap(cfg config.Config) template.FuncMap {
	contextFiles := func(path string) []ContextFile {
		return processContextPath(expandPath(path, cfg), cfg)
	}

	funcs := template.FuncMap{}
	funcs["contextFiles"] = contextFiles
	return funcs
}
65
// processFile reads filePath in full and wraps it in a ContextFile.
// Any error from os.ReadFile yields nil, so callers can treat unreadable
// files as absent.
func processFile(filePath string) *ContextFile {
	data, readErr := os.ReadFile(filePath)
	if readErr == nil {
		return &ContextFile{Path: filePath, Content: string(data)}
	}
	return nil
}
76
// processContextPath collects context files from p for the contextFiles
// template helper. A relative p is resolved against cfg.WorkingDir().
//
//   - If p names a file, that single file is returned.
//   - If p names a directory (or a symlink to one), every readable
//     non-directory entry beneath it is returned in filepath.WalkDir order.
//   - If p cannot be stat'ed (e.g. it does not exist), nil is returned.
//
// Collection is best-effort: unreadable files and directories are skipped
// instead of aborting the walk.
func processContextPath(p string, cfg config.Config) []ContextFile {
	fullPath := p
	if !filepath.IsAbs(p) {
		fullPath = filepath.Join(cfg.WorkingDir(), p)
	}
	info, err := os.Stat(fullPath)
	if err != nil {
		return nil
	}
	if !info.IsDir() {
		if result := processFile(fullPath); result != nil {
			return []ContextFile{*result}
		}
		return nil
	}

	// os.Stat follows symlinks but filepath.WalkDir does not, so a symlink to
	// a directory would otherwise be visited as a single unreadable entry and
	// yield nothing. Walk the resolved target and report paths under fullPath.
	root := fullPath
	if resolved, evalErr := filepath.EvalSymlinks(fullPath); evalErr == nil {
		root = resolved
	}

	var contexts []ContextFile
	// The callback never returns an error, so WalkDir's result is always nil.
	_ = filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error {
		if err != nil {
			// Skip the unreadable entry but keep walking; returning err here
			// would abort the walk and drop every file not yet visited.
			return nil
		}
		if d.IsDir() {
			return nil
		}
		result := processFile(path)
		if result == nil {
			return nil
		}
		if root != fullPath {
			// Rebase onto the caller-visible path so symlinked roots report
			// the same paths a direct walk of fullPath would.
			if rel, relErr := filepath.Rel(root, path); relErr == nil {
				result.Path = filepath.Join(fullPath, rel)
			}
		}
		contexts = append(contexts, *result)
		return nil
	})
	return contexts
}
107
// expandPath expands ~ in path via home.Long. If the result then starts with
// "$", it is resolved through cfg.Resolver(), the same mechanism the config
// uses for variables. When resolution fails, the ~-expanded path is returned
// unchanged.
func expandPath(path string, cfg config.Config) string {
	expanded := home.Long(path)
	if !strings.HasPrefix(expanded, "$") {
		return expanded
	}
	resolved, err := cfg.Resolver().ResolveValue(expanded)
	if err != nil {
		return expanded
	}
	return resolved
}
120
// promptData assembles the PromptDat that Build executes templates against.
// Date is the current local date in "1/2/2006" (month/day/year) form and
// Platform is runtime.GOOS.
func promptData(provider, model string, cfg config.Config) PromptDat {
	// Resolve the working directory once so WorkingDir and IsGitRepo are
	// guaranteed to describe the same directory.
	workingDir := cfg.WorkingDir()
	return PromptDat{
		Provider:   provider,
		Model:      model,
		Config:     cfg,
		WorkingDir: workingDir,
		IsGitRepo:  isGitRepo(workingDir),
		Platform:   runtime.GOOS,
		Date:       time.Now().Format("1/2/2006"),
	}
}
132
// isGitRepo reports whether dir directly contains a ".git" entry. Either a
// directory or a file counts. Any os.Stat error, not only "not exist", is
// treated as "not a repository". Parent directories are not searched.
func isGitRepo(dir string) bool {
	if _, statErr := os.Stat(filepath.Join(dir, ".git")); statErr != nil {
		return false
	}
	return true
}
137
138func (p *Prompt) Name() string {
139 return p.name
140}