prompt.go

  1package prompt
  2
  3import (
  4	"cmp"
  5	"context"
  6	"fmt"
  7	"os"
  8	"path/filepath"
  9	"runtime"
 10	"strings"
 11	"text/template"
 12	"time"
 13
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/home"
 16	"github.com/charmbracelet/crush/internal/shell"
 17	"github.com/charmbracelet/crush/internal/skills"
 18)
 19
 20// Prompt represents a template-based prompt generator.
 21type Prompt struct {
 22	name       string
 23	template   string
 24	now        func() time.Time
 25	platform   string
 26	workingDir string
 27}
 28
 29type PromptDat struct {
 30	Provider      string
 31	Model         string
 32	Config        config.Config
 33	WorkingDir    string
 34	IsGitRepo     bool
 35	Platform      string
 36	Date          string
 37	GitStatus     string
 38	ContextFiles  []ContextFile
 39	AvailSkillXML string
 40}
 41
 42type ContextFile struct {
 43	Path    string
 44	Content string
 45}
 46
 47type Option func(*Prompt)
 48
 49func WithTimeFunc(fn func() time.Time) Option {
 50	return func(p *Prompt) {
 51		p.now = fn
 52	}
 53}
 54
 55func WithPlatform(platform string) Option {
 56	return func(p *Prompt) {
 57		p.platform = platform
 58	}
 59}
 60
 61func WithWorkingDir(workingDir string) Option {
 62	return func(p *Prompt) {
 63		p.workingDir = workingDir
 64	}
 65}
 66
 67func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
 68	p := &Prompt{
 69		name:     name,
 70		template: promptTemplate,
 71		now:      time.Now,
 72	}
 73	for _, opt := range opts {
 74		opt(p)
 75	}
 76	return p, nil
 77}
 78
 79func (p *Prompt) Build(ctx context.Context, provider, model string, store *config.ConfigStore) (string, error) {
 80	t, err := template.New(p.name).Parse(p.template)
 81	if err != nil {
 82		return "", fmt.Errorf("parsing template: %w", err)
 83	}
 84	var sb strings.Builder
 85	d, err := p.promptData(ctx, provider, model, store)
 86	if err != nil {
 87		return "", err
 88	}
 89	if err := t.Execute(&sb, d); err != nil {
 90		return "", fmt.Errorf("executing template: %w", err)
 91	}
 92
 93	return sb.String(), nil
 94}
 95
 96func processFile(filePath string) *ContextFile {
 97	content, err := os.ReadFile(filePath)
 98	if err != nil {
 99		return nil
100	}
101	return &ContextFile{
102		Path:    filePath,
103		Content: string(content),
104	}
105}
106
107func processContextPath(p string, store *config.ConfigStore) []ContextFile {
108	var contexts []ContextFile
109	fullPath := p
110	if !filepath.IsAbs(p) {
111		fullPath = filepath.Join(store.WorkingDir(), p)
112	}
113	info, err := os.Stat(fullPath)
114	if err != nil {
115		return contexts
116	}
117	if info.IsDir() {
118		filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
119			if err != nil {
120				return err
121			}
122			if !d.IsDir() {
123				if result := processFile(path); result != nil {
124					contexts = append(contexts, *result)
125				}
126			}
127			return nil
128		})
129	} else {
130		result := processFile(fullPath)
131		if result != nil {
132			contexts = append(contexts, *result)
133		}
134	}
135	return contexts
136}
137
138// expandPath expands ~ and environment variables in file paths
139func expandPath(path string, store *config.ConfigStore) string {
140	path = home.Long(path)
141	// Handle environment variable expansion using the same pattern as config
142	if strings.HasPrefix(path, "$") {
143		if expanded, err := store.Resolver().ResolveValue(path); err == nil {
144			path = expanded
145		}
146	}
147
148	return path
149}
150
151func (p *Prompt) promptData(ctx context.Context, provider, model string, store *config.ConfigStore) (PromptDat, error) {
152	workingDir := cmp.Or(p.workingDir, store.WorkingDir())
153	platform := cmp.Or(p.platform, runtime.GOOS)
154
155	files := map[string][]ContextFile{}
156
157	cfg := store.Config()
158	for _, pth := range cfg.Options.ContextPaths {
159		expanded := expandPath(pth, store)
160		pathKey := strings.ToLower(expanded)
161		if _, ok := files[pathKey]; ok {
162			continue
163		}
164		content := processContextPath(expanded, store)
165		files[pathKey] = content
166	}
167
168	// Discover and load skills metadata.
169	var availSkillXML string
170	if len(cfg.Options.SkillsPaths) > 0 {
171		expandedPaths := make([]string, 0, len(cfg.Options.SkillsPaths))
172		for _, pth := range cfg.Options.SkillsPaths {
173			expandedPaths = append(expandedPaths, expandPath(pth, store))
174		}
175		if discoveredSkills := skills.Discover(expandedPaths); len(discoveredSkills) > 0 {
176			availSkillXML = skills.ToPromptXML(discoveredSkills)
177		}
178	}
179
180	isGit := isGitRepo(store.WorkingDir())
181	data := PromptDat{
182		Provider:      provider,
183		Model:         model,
184		Config:        *cfg,
185		WorkingDir:    filepath.ToSlash(workingDir),
186		IsGitRepo:     isGit,
187		Platform:      platform,
188		Date:          p.now().Format("1/2/2006"),
189		AvailSkillXML: availSkillXML,
190	}
191	if isGit {
192		var err error
193		data.GitStatus, err = getGitStatus(ctx, store.WorkingDir())
194		if err != nil {
195			return PromptDat{}, err
196		}
197	}
198
199	for _, contextFiles := range files {
200		data.ContextFiles = append(data.ContextFiles, contextFiles...)
201	}
202	return data, nil
203}
204
// isGitRepo reports whether dir itself contains an entry named .git. Either a
// directory or a file counts; git uses a .git file for worktrees and
// submodules. Parent directories are not searched.
func isGitRepo(dir string) bool {
	if _, statErr := os.Stat(filepath.Join(dir, ".git")); statErr != nil {
		return false
	}
	return true
}
209
210func getGitStatus(ctx context.Context, dir string) (string, error) {
211	sh := shell.NewShell(&shell.Options{
212		WorkingDir: dir,
213	})
214	branch, err := getGitBranch(ctx, sh)
215	if err != nil {
216		return "", err
217	}
218	status, err := getGitStatusSummary(ctx, sh)
219	if err != nil {
220		return "", err
221	}
222	commits, err := getGitRecentCommits(ctx, sh)
223	if err != nil {
224		return "", err
225	}
226	return branch + status + commits, nil
227}
228
229func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
230	out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
231	if err != nil {
232		return "", nil
233	}
234	out = strings.TrimSpace(out)
235	if out == "" {
236		return "", nil
237	}
238	return fmt.Sprintf("Current branch: %s\n", out), nil
239}
240
241func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
242	out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
243	if err != nil {
244		return "", nil
245	}
246	out = strings.TrimSpace(out)
247	if out == "" {
248		return "Status: clean\n", nil
249	}
250	return fmt.Sprintf("Status:\n%s\n", out), nil
251}
252
253func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
254	out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
255	if err != nil || out == "" {
256		return "", nil
257	}
258	out = strings.TrimSpace(out)
259	return fmt.Sprintf("Recent commits:\n%s\n", out), nil
260}
261
262func (p *Prompt) Name() string {
263	return p.name
264}