1package prompt
  2
  3import (
  4	"cmp"
  5	"context"
  6	"fmt"
  7	"os"
  8	"path/filepath"
  9	"runtime"
 10	"strings"
 11	"text/template"
 12	"time"
 13
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/home"
 16	"github.com/charmbracelet/crush/internal/shell"
 17)
 18
 19// Prompt represents a template-based prompt generator.
 20type Prompt struct {
 21	name       string
 22	template   string
 23	now        func() time.Time
 24	platform   string
 25	workingDir string
 26}
 27
 28type PromptDat struct {
 29	Provider     string
 30	Model        string
 31	Config       config.Config
 32	WorkingDir   string
 33	IsGitRepo    bool
 34	Platform     string
 35	Date         string
 36	GitStatus    string
 37	ContextFiles []ContextFile
 38}
 39
 40type ContextFile struct {
 41	Path    string
 42	Content string
 43}
 44
 45type Option func(*Prompt)
 46
 47func WithTimeFunc(fn func() time.Time) Option {
 48	return func(p *Prompt) {
 49		p.now = fn
 50	}
 51}
 52
 53func WithPlatform(platform string) Option {
 54	return func(p *Prompt) {
 55		p.platform = platform
 56	}
 57}
 58
 59func WithWorkingDir(workingDir string) Option {
 60	return func(p *Prompt) {
 61		p.workingDir = workingDir
 62	}
 63}
 64
 65func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
 66	p := &Prompt{
 67		name:     name,
 68		template: promptTemplate,
 69		now:      time.Now,
 70	}
 71	for _, opt := range opts {
 72		opt(p)
 73	}
 74	return p, nil
 75}
 76
 77func (p *Prompt) Build(ctx context.Context, provider, model string, cfg config.Config) (string, error) {
 78	t, err := template.New(p.name).Parse(p.template)
 79	if err != nil {
 80		return "", fmt.Errorf("parsing template: %w", err)
 81	}
 82	var sb strings.Builder
 83	d, err := p.promptData(ctx, provider, model, cfg)
 84	if err != nil {
 85		return "", err
 86	}
 87	if err := t.Execute(&sb, d); err != nil {
 88		return "", fmt.Errorf("executing template: %w", err)
 89	}
 90
 91	return sb.String(), nil
 92}
 93
 94func processFile(filePath string) *ContextFile {
 95	content, err := os.ReadFile(filePath)
 96	if err != nil {
 97		return nil
 98	}
 99	return &ContextFile{
100		Path:    filePath,
101		Content: string(content),
102	}
103}
104
105func processContextPath(p string, cfg config.Config) []ContextFile {
106	var contexts []ContextFile
107	fullPath := p
108	if !filepath.IsAbs(p) {
109		fullPath = filepath.Join(cfg.WorkingDir(), p)
110	}
111	info, err := os.Stat(fullPath)
112	if err != nil {
113		return contexts
114	}
115	if info.IsDir() {
116		filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
117			if err != nil {
118				return err
119			}
120			if !d.IsDir() {
121				if result := processFile(path); result != nil {
122					contexts = append(contexts, *result)
123				}
124			}
125			return nil
126		})
127	} else {
128		result := processFile(fullPath)
129		if result != nil {
130			contexts = append(contexts, *result)
131		}
132	}
133	return contexts
134}
135
136// expandPath expands ~ and environment variables in file paths
137func expandPath(path string, cfg config.Config) string {
138	path = home.Long(path)
139	// Handle environment variable expansion using the same pattern as config
140	if strings.HasPrefix(path, "$") {
141		if expanded, err := cfg.Resolver().ResolveValue(path); err == nil {
142			path = expanded
143		}
144	}
145
146	return path
147}
148
149func (p *Prompt) promptData(ctx context.Context, provider, model string, cfg config.Config) (PromptDat, error) {
150	workingDir := cmp.Or(p.workingDir, cfg.WorkingDir())
151	platform := cmp.Or(p.platform, runtime.GOOS)
152
153	files := map[string][]ContextFile{}
154
155	for _, pth := range cfg.Options.ContextPaths {
156		expanded := expandPath(pth, cfg)
157		pathKey := strings.ToLower(expanded)
158		if _, ok := files[pathKey]; ok {
159			continue
160		}
161		content := processContextPath(expanded, cfg)
162		files[pathKey] = content
163	}
164
165	isGit := isGitRepo(cfg.WorkingDir())
166	data := PromptDat{
167		Provider:   provider,
168		Model:      model,
169		Config:     cfg,
170		WorkingDir: filepath.ToSlash(workingDir),
171		IsGitRepo:  isGit,
172		Platform:   platform,
173		Date:       p.now().Format("1/2/2006"),
174	}
175	if isGit {
176		var err error
177		data.GitStatus, err = getGitStatus(ctx, cfg.WorkingDir())
178		if err != nil {
179			return PromptDat{}, err
180		}
181	}
182
183	for _, contextFiles := range files {
184		data.ContextFiles = append(data.ContextFiles, contextFiles...)
185	}
186	return data, nil
187}
188
// isGitRepo reports whether dir directly contains an entry named ".git".
// Parent directories are not searched, and any Stat error (including
// "not found") counts as "not a repository".
func isGitRepo(dir string) bool {
	gitPath := filepath.Join(dir, ".git")
	if _, err := os.Stat(gitPath); err != nil {
		return false
	}
	return true
}
193
194func getGitStatus(ctx context.Context, dir string) (string, error) {
195	sh := shell.NewShell(&shell.Options{
196		WorkingDir: dir,
197	})
198	branch, err := getGitBranch(ctx, sh)
199	if err != nil {
200		return "", err
201	}
202	status, err := getGitStatusSummary(ctx, sh)
203	if err != nil {
204		return "", err
205	}
206	commits, err := getGitRecentCommits(ctx, sh)
207	if err != nil {
208		return "", err
209	}
210	return branch + status + commits, nil
211}
212
213func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
214	out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
215	if err != nil {
216		return "", nil
217	}
218	out = strings.TrimSpace(out)
219	if out == "" {
220		return "", nil
221	}
222	return fmt.Sprintf("Current branch: %s\n", out), nil
223}
224
225func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
226	out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
227	if err != nil {
228		return "", nil
229	}
230	out = strings.TrimSpace(out)
231	if out == "" {
232		return "Status: clean\n", nil
233	}
234	return fmt.Sprintf("Status:\n%s\n", out), nil
235}
236
237func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
238	out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
239	if err != nil || out == "" {
240		return "", nil
241	}
242	out = strings.TrimSpace(out)
243	return fmt.Sprintf("Recent commits:\n%s\n", out), nil
244}
245
246func (p *Prompt) Name() string {
247	return p.name
248}