1package prompt
  2
  3import (
  4	"fmt"
  5	"os"
  6	"path/filepath"
  7	"runtime"
  8	"strings"
  9	"text/template"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/home"
 14)
 15
 16// Prompt represents a template-based prompt generator.
 17type Prompt struct {
 18	name       string
 19	template   string
 20	now        func() time.Time
 21	platform   string
 22	workingDir string
 23}
 24
 25type PromptDat struct {
 26	Provider     string
 27	Model        string
 28	Config       config.Config
 29	WorkingDir   string
 30	IsGitRepo    bool
 31	Platform     string
 32	Date         string
 33	ContextFiles []ContextFile
 34}
 35
 36type ContextFile struct {
 37	Path    string
 38	Content string
 39}
 40
 41type Option func(*Prompt)
 42
 43func WithTimeFunc(fn func() time.Time) Option {
 44	return func(p *Prompt) {
 45		p.now = fn
 46	}
 47}
 48
 49func WithPlatform(platform string) Option {
 50	return func(p *Prompt) {
 51		p.platform = platform
 52	}
 53}
 54
 55func WithWorkingDir(workingDir string) Option {
 56	return func(p *Prompt) {
 57		p.workingDir = workingDir
 58	}
 59}
 60
 61func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
 62	p := &Prompt{
 63		name:     name,
 64		template: promptTemplate,
 65		now:      time.Now,
 66	}
 67	for _, opt := range opts {
 68		opt(p)
 69	}
 70	return p, nil
 71}
 72
 73func (p *Prompt) Build(provider, model string, cfg config.Config) (string, error) {
 74	t, err := template.New(p.name).Parse(p.template)
 75	if err != nil {
 76		return "", fmt.Errorf("parsing template: %w", err)
 77	}
 78	var sb strings.Builder
 79	if err := t.Execute(&sb, p.promptData(provider, model, cfg)); err != nil {
 80		return "", fmt.Errorf("executing template: %w", err)
 81	}
 82
 83	return sb.String(), nil
 84}
 85
 86func processFile(filePath string) *ContextFile {
 87	content, err := os.ReadFile(filePath)
 88	if err != nil {
 89		return nil
 90	}
 91	return &ContextFile{
 92		Path:    filePath,
 93		Content: string(content),
 94	}
 95}
 96
 97func processContextPath(p string, cfg config.Config) []ContextFile {
 98	var contexts []ContextFile
 99	fullPath := p
100	if !filepath.IsAbs(p) {
101		fullPath = filepath.Join(cfg.WorkingDir(), p)
102	}
103	info, err := os.Stat(fullPath)
104	if err != nil {
105		return contexts
106	}
107	if info.IsDir() {
108		filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
109			if err != nil {
110				return err
111			}
112			if !d.IsDir() {
113				if result := processFile(path); result != nil {
114					contexts = append(contexts, *result)
115				}
116			}
117			return nil
118		})
119	} else {
120		result := processFile(fullPath)
121		if result != nil {
122			contexts = append(contexts, *result)
123		}
124	}
125	return contexts
126}
127
128// expandPath expands ~ and environment variables in file paths
129func expandPath(path string, cfg config.Config) string {
130	path = home.Long(path)
131	// Handle environment variable expansion using the same pattern as config
132	if strings.HasPrefix(path, "$") {
133		if expanded, err := cfg.Resolver().ResolveValue(path); err == nil {
134			path = expanded
135		}
136	}
137
138	return path
139}
140
141func (p *Prompt) promptData(provider, model string, cfg config.Config) PromptDat {
142	workingDir := cfg.WorkingDir()
143	if p.workingDir != "" {
144		workingDir = p.workingDir
145	}
146	platform := runtime.GOOS
147	if p.platform != "" {
148		platform = p.platform
149	}
150
151	files := map[string][]ContextFile{}
152
153	for _, pth := range cfg.Options.ContextPaths {
154		expanded := expandPath(pth, cfg)
155		pathKey := strings.ToLower(expanded)
156		if _, ok := files[pathKey]; ok {
157			continue
158		}
159		content := processContextPath(expanded, cfg)
160		files[pathKey] = content
161	}
162
163	data := PromptDat{
164		Provider:   provider,
165		Model:      model,
166		Config:     cfg,
167		WorkingDir: workingDir,
168		IsGitRepo:  isGitRepo(cfg.WorkingDir()),
169		Platform:   platform,
170		Date:       p.now().Format("1/2/2006"),
171	}
172
173	for _, contextFiles := range files {
174		data.ContextFiles = append(data.ContextFiles, contextFiles...)
175	}
176	return data
177}
178
// isGitRepo reports whether dir directly contains a ".git" entry. It may be
// a directory or a file, since only existence is checked with os.Stat.
// Parent directories are not searched.
func isGitRepo(dir string) bool {
	if _, err := os.Stat(filepath.Join(dir, ".git")); err != nil {
		return false
	}
	return true
}
183
184func (p *Prompt) Name() string {
185	return p.name
186}