prompt.go

  1package prompt
  2
  3import (
  4	"context"
  5	"fmt"
  6	"os"
  7	"path/filepath"
  8	"runtime"
  9	"strings"
 10	"text/template"
 11	"time"
 12
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/home"
 15	"github.com/charmbracelet/crush/internal/shell"
 16)
 17
// Prompt represents a template-based prompt generator. It holds a
// text/template source that Build renders against data gathered from the
// config, the runtime platform, the configured context files and, when the
// working directory is a git repository, git.
type Prompt struct {
	name       string           // template name passed to template.New
	template   string           // raw text/template source; parsed on every Build
	now        func() time.Time // clock used for PromptDat.Date; time.Now unless overridden
	platform   string           // when non-empty, replaces runtime.GOOS in PromptDat.Platform
	workingDir string           // when non-empty, replaces cfg.WorkingDir() in PromptDat.WorkingDir
}
 26
// PromptDat is the data value Build executes the prompt template with, so
// each field is reachable from the template as {{.FieldName}}.
type PromptDat struct {
	Provider     string        // provider name as passed to Build
	Model        string        // model name as passed to Build
	Config       config.Config // the config passed to Build
	WorkingDir   string        // cfg.WorkingDir(), or the WithWorkingDir override
	IsGitRepo    bool          // whether cfg.WorkingDir() contains a .git entry
	Platform     string        // runtime.GOOS, or the WithPlatform override
	Date         string        // current date from the Prompt's clock, layout "1/2/2006"
	GitStatus    string        // branch, short status and recent commits; empty when not a git repo
	ContextFiles []ContextFile // files loaded from cfg.Options.ContextPaths
}
 38
// ContextFile is one file loaded from a configured context path and handed
// to the prompt template.
type ContextFile struct {
	Path    string // resolved path the file was read from
	Content string // raw file contents converted to a string (not checked for valid UTF-8)
}
 43
// Option configures a Prompt. NewPrompt applies options in the order given,
// so a later option overrides an earlier one that sets the same field.
type Option func(*Prompt)
 45
 46func WithTimeFunc(fn func() time.Time) Option {
 47	return func(p *Prompt) {
 48		p.now = fn
 49	}
 50}
 51
 52func WithPlatform(platform string) Option {
 53	return func(p *Prompt) {
 54		p.platform = platform
 55	}
 56}
 57
 58func WithWorkingDir(workingDir string) Option {
 59	return func(p *Prompt) {
 60		p.workingDir = workingDir
 61	}
 62}
 63
 64func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
 65	p := &Prompt{
 66		name:     name,
 67		template: promptTemplate,
 68		now:      time.Now,
 69	}
 70	for _, opt := range opts {
 71		opt(p)
 72	}
 73	return p, nil
 74}
 75
 76func (p *Prompt) Build(ctx context.Context, provider, model string, cfg config.Config) (string, error) {
 77	t, err := template.New(p.name).Parse(p.template)
 78	if err != nil {
 79		return "", fmt.Errorf("parsing template: %w", err)
 80	}
 81	var sb strings.Builder
 82	d, err := p.promptData(ctx, provider, model, cfg)
 83	if err != nil {
 84		return "", err
 85	}
 86	if err := t.Execute(&sb, d); err != nil {
 87		return "", fmt.Errorf("executing template: %w", err)
 88	}
 89
 90	return sb.String(), nil
 91}
 92
 93func processFile(filePath string) *ContextFile {
 94	content, err := os.ReadFile(filePath)
 95	if err != nil {
 96		return nil
 97	}
 98	return &ContextFile{
 99		Path:    filePath,
100		Content: string(content),
101	}
102}
103
// processContextPath loads the context file(s) at p. A relative p is
// resolved against cfg.WorkingDir().
//
// If p names a directory, every non-directory entry beneath it is loaded in
// the order filepath.WalkDir visits them. Otherwise p itself is loaded.
// Anything that cannot be stat'ed or read is skipped, so the result may be
// nil.
func processContextPath(p string, cfg config.Config) []ContextFile {
	fullPath := p
	if !filepath.IsAbs(p) {
		fullPath = filepath.Join(cfg.WorkingDir(), p)
	}

	info, err := os.Stat(fullPath)
	if err != nil {
		return nil
	}

	if !info.IsDir() {
		if result := processFile(fullPath); result != nil {
			return []ContextFile{*result}
		}
		return nil
	}

	var contexts []ContextFile
	// The callback never returns an error, so WalkDir's result is always nil.
	_ = filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
		if err != nil {
			// Skip the unreadable entry (e.g. a subdirectory we lack
			// permission for) but keep walking. Returning err here would
			// abort the walk and drop every file not yet visited.
			return nil
		}
		if d.IsDir() {
			return nil
		}
		if result := processFile(path); result != nil {
			contexts = append(contexts, *result)
		}
		return nil
	})
	return contexts
}
134
// expandPath expands a configured context path.
//
// The path is first passed through home.Long (intended to expand ~). If the
// result then starts with "$", it is resolved with the config's resolver, the
// same mechanism the config uses for variable expansion. If resolution fails,
// the unresolved path is returned.
func expandPath(path string, cfg config.Config) string {
	expanded := home.Long(path)
	if !strings.HasPrefix(expanded, "$") {
		return expanded
	}
	if resolved, err := cfg.Resolver().ResolveValue(expanded); err == nil {
		return resolved
	}
	return expanded
}
147
// promptData gathers everything the prompt template can reference.
//
// WorkingDir and Platform default to cfg.WorkingDir() and runtime.GOOS.
// WithWorkingDir and WithPlatform can override them.
//
// Context files are loaded from cfg.Options.ContextPaths after expandPath.
// They appear in ContextFiles in the order the paths are configured. Paths
// that compare equal ignoring case after expansion are loaded only once; the
// first occurrence wins.
//
// If cfg.WorkingDir() contains a .git entry, GitStatus is filled in from git.
// An error there is returned (wrapped) together with a zero PromptDat.
func (p *Prompt) promptData(ctx context.Context, provider, model string, cfg config.Config) (PromptDat, error) {
	workingDir := cfg.WorkingDir()
	if p.workingDir != "" {
		workingDir = p.workingDir
	}
	platform := runtime.GOOS
	if p.platform != "" {
		platform = p.platform
	}

	// Keep configuration order instead of ranging over a map. Map iteration
	// order is randomized, which made the rendered prompt differ from run to
	// run.
	seen := make(map[string]struct{}, len(cfg.Options.ContextPaths))
	var contextFiles []ContextFile
	for _, pth := range cfg.Options.ContextPaths {
		expanded := expandPath(pth, cfg)
		// Case-insensitive dedupe, so e.g. "AGENTS.md" and "agents.md" are
		// only loaded once.
		key := strings.ToLower(expanded)
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		contextFiles = append(contextFiles, processContextPath(expanded, cfg)...)
	}

	// NOTE(review): git detection and status use cfg.WorkingDir(), not the
	// WithWorkingDir override held in workingDir. Confirm this is intended.
	isGit := isGitRepo(cfg.WorkingDir())
	data := PromptDat{
		Provider:     provider,
		Model:        model,
		Config:       cfg,
		WorkingDir:   workingDir,
		IsGitRepo:    isGit,
		Platform:     platform,
		Date:         p.now().Format("1/2/2006"),
		ContextFiles: contextFiles,
	}
	if isGit {
		status, err := getGitStatus(ctx, cfg.WorkingDir())
		if err != nil {
			return PromptDat{}, fmt.Errorf("getting git status: %w", err)
		}
		data.GitStatus = status
	}
	return data, nil
}
193
// isGitRepo reports whether dir directly contains an entry named ".git".
// Any kind of entry counts (a directory or a file). Any Stat error, not
// only "does not exist", is treated as "not a repository".
func isGitRepo(dir string) bool {
	if _, err := os.Stat(filepath.Join(dir, ".git")); err != nil {
		return false
	}
	return true
}
198
199func getGitStatus(ctx context.Context, dir string) (string, error) {
200	sh := shell.NewShell(&shell.Options{
201		WorkingDir: dir,
202	})
203	branch, err := getGitBranch(ctx, sh)
204	if err != nil {
205		return "", err
206	}
207	status, err := getGitStatusSummary(ctx, sh)
208	if err != nil {
209		return "", err
210	}
211	commits, err := getGitRecentCommits(ctx, sh)
212	if err != nil {
213		return "", err
214	}
215	return branch + status + commits, nil
216}
217
// getGitBranch returns "Current branch: <name>\n" for the shell's repository.
// It is best-effort: if the git command fails or prints nothing, it returns
// an empty string. The error result is always nil.
func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
	stdout, _, execErr := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
	if execErr != nil {
		return "", nil
	}
	if branch := strings.TrimSpace(stdout); branch != "" {
		return fmt.Sprintf("Current branch: %s\n", branch), nil
	}
	return "", nil
}
229
// getGitStatusSummary returns up to the first 20 lines of `git status --short`
// under a "Status:" header, or "Status: clean\n" when there is no output.
// It is best-effort: if the command fails, it returns an empty string. The
// error result is always nil.
func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
	stdout, _, execErr := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
	if execErr != nil {
		return "", nil
	}
	summary := strings.TrimSpace(stdout)
	if summary == "" {
		return "Status: clean\n", nil
	}
	return fmt.Sprintf("Status:\n%s\n", summary), nil
}
241
// getGitRecentCommits returns the last three commits in one-line form under a
// "Recent commits:" header.
//
// It is best-effort: if the command fails, or its output is empty or only
// whitespace, it returns an empty string. The error result is always nil.
func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
	out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
	if err != nil {
		return "", nil
	}
	// Trim before the emptiness check, matching the sibling helpers.
	// Whitespace-only output would otherwise produce an empty
	// "Recent commits:" section.
	out = strings.TrimSpace(out)
	if out == "" {
		return "", nil
	}
	return fmt.Sprintf("Recent commits:\n%s\n", out), nil
}
250
251func (p *Prompt) Name() string {
252	return p.name
253}