prompt.go

  1package prompt
  2
  3import (
  4	"cmp"
  5	"context"
  6	"fmt"
  7	"os"
  8	"path/filepath"
  9	"runtime"
 10	"strings"
 11	"text/template"
 12	"time"
 13
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/home"
 16	"github.com/charmbracelet/crush/internal/shell"
 17	"github.com/charmbracelet/crush/internal/skills"
 18)
 19
 20// Prompt represents a template-based prompt generator.
 21type Prompt struct {
 22	name        string
 23	template    string
 24	now         func() time.Time
 25	platform    string
 26	workingDir  string
 27	modelFamily ModelFamily
 28}
 29
 30type PromptDat struct {
 31	Provider      string
 32	Model         string
 33	ModelFamily   ModelFamily
 34	Config        config.Config
 35	WorkingDir    string
 36	IsGitRepo     bool
 37	Platform      string
 38	Date          string
 39	GitStatus     string
 40	ContextFiles  []ContextFile
 41	AvailSkillXML string
 42}
 43
 44type ContextFile struct {
 45	Path    string
 46	Content string
 47}
 48
 49type Option func(*Prompt)
 50
 51func WithTimeFunc(fn func() time.Time) Option {
 52	return func(p *Prompt) {
 53		p.now = fn
 54	}
 55}
 56
 57func WithPlatform(platform string) Option {
 58	return func(p *Prompt) {
 59		p.platform = platform
 60	}
 61}
 62
 63func WithWorkingDir(workingDir string) Option {
 64	return func(p *Prompt) {
 65		p.workingDir = workingDir
 66	}
 67}
 68
 69func WithModelFamily(modelFamily ModelFamily) Option {
 70	return func(p *Prompt) {
 71		p.modelFamily = modelFamily
 72	}
 73}
 74
 75func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
 76	p := &Prompt{
 77		name:     name,
 78		template: promptTemplate,
 79		now:      time.Now,
 80	}
 81	for _, opt := range opts {
 82		opt(p)
 83	}
 84	return p, nil
 85}
 86
 87func (p *Prompt) Build(ctx context.Context, provider, model string, cfg config.Config) (string, error) {
 88	t, err := template.New(p.name).Parse(p.template)
 89	if err != nil {
 90		return "", fmt.Errorf("parsing template: %w", err)
 91	}
 92	var sb strings.Builder
 93	d, err := p.promptData(ctx, provider, model, cfg)
 94	if err != nil {
 95		return "", err
 96	}
 97	if err := t.Execute(&sb, d); err != nil {
 98		return "", fmt.Errorf("executing template: %w", err)
 99	}
100
101	return sb.String(), nil
102}
103
104func processFile(filePath string) *ContextFile {
105	content, err := os.ReadFile(filePath)
106	if err != nil {
107		return nil
108	}
109	return &ContextFile{
110		Path:    filePath,
111		Content: string(content),
112	}
113}
114
115func processContextPath(p string, cfg config.Config) []ContextFile {
116	var contexts []ContextFile
117	fullPath := p
118	if !filepath.IsAbs(p) {
119		fullPath = filepath.Join(cfg.WorkingDir(), p)
120	}
121	info, err := os.Stat(fullPath)
122	if err != nil {
123		return contexts
124	}
125	if info.IsDir() {
126		filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
127			if err != nil {
128				return err
129			}
130			if !d.IsDir() {
131				if result := processFile(path); result != nil {
132					contexts = append(contexts, *result)
133				}
134			}
135			return nil
136		})
137	} else {
138		result := processFile(fullPath)
139		if result != nil {
140			contexts = append(contexts, *result)
141		}
142	}
143	return contexts
144}
145
146// expandPath expands ~ and environment variables in file paths
147func expandPath(path string, cfg config.Config) string {
148	path = home.Long(path)
149	// Handle environment variable expansion using the same pattern as config
150	if strings.HasPrefix(path, "$") {
151		if expanded, err := cfg.Resolver().ResolveValue(path); err == nil {
152			path = expanded
153		}
154	}
155
156	return path
157}
158
159func (p *Prompt) promptData(ctx context.Context, provider, model string, cfg config.Config) (PromptDat, error) {
160	workingDir := cmp.Or(p.workingDir, cfg.WorkingDir())
161	platform := cmp.Or(p.platform, runtime.GOOS)
162
163	files := map[string][]ContextFile{}
164
165	for _, pth := range cfg.Options.ContextPaths {
166		expanded := expandPath(pth, cfg)
167		pathKey := strings.ToLower(expanded)
168		if _, ok := files[pathKey]; ok {
169			continue
170		}
171		content := processContextPath(expanded, cfg)
172		files[pathKey] = content
173	}
174
175	// Discover and load skills metadata.
176	var availSkillXML string
177	if len(cfg.Options.SkillsPaths) > 0 {
178		expandedPaths := make([]string, 0, len(cfg.Options.SkillsPaths))
179		for _, pth := range cfg.Options.SkillsPaths {
180			expandedPaths = append(expandedPaths, expandPath(pth, cfg))
181		}
182		if discoveredSkills := skills.Discover(expandedPaths); len(discoveredSkills) > 0 {
183			availSkillXML = skills.ToPromptXML(discoveredSkills)
184		}
185	}
186
187	isGit := isGitRepo(cfg.WorkingDir())
188	data := PromptDat{
189		Provider:      provider,
190		Model:         model,
191		ModelFamily:   p.modelFamily,
192		Config:        cfg,
193		WorkingDir:    filepath.ToSlash(workingDir),
194		IsGitRepo:     isGit,
195		Platform:      platform,
196		Date:          p.now().Format("1/2/2006"),
197		AvailSkillXML: availSkillXML,
198	}
199	if isGit {
200		var err error
201		data.GitStatus, err = getGitStatus(ctx, cfg.WorkingDir())
202		if err != nil {
203			return PromptDat{}, err
204		}
205	}
206
207	for _, contextFiles := range files {
208		data.ContextFiles = append(data.ContextFiles, contextFiles...)
209	}
210	return data, nil
211}
212
// isGitRepo reports whether dir directly contains a ".git" entry. The entry
// may be of any kind: a directory for a normal clone, or a plain file as
// git uses for linked worktrees and submodules. Any Stat error, not just
// "does not exist", yields false.
func isGitRepo(dir string) bool {
	if _, statErr := os.Stat(filepath.Join(dir, ".git")); statErr != nil {
		return false
	}
	return true
}
217
218func getGitStatus(ctx context.Context, dir string) (string, error) {
219	sh := shell.NewShell(&shell.Options{
220		WorkingDir: dir,
221	})
222	branch, err := getGitBranch(ctx, sh)
223	if err != nil {
224		return "", err
225	}
226	status, err := getGitStatusSummary(ctx, sh)
227	if err != nil {
228		return "", err
229	}
230	commits, err := getGitRecentCommits(ctx, sh)
231	if err != nil {
232		return "", err
233	}
234	return branch + status + commits, nil
235}
236
237func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
238	out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
239	if err != nil {
240		return "", nil
241	}
242	out = strings.TrimSpace(out)
243	if out == "" {
244		return "", nil
245	}
246	return fmt.Sprintf("Current branch: %s\n", out), nil
247}
248
249func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
250	out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
251	if err != nil {
252		return "", nil
253	}
254	out = strings.TrimSpace(out)
255	if out == "" {
256		return "Status: clean\n", nil
257	}
258	return fmt.Sprintf("Status:\n%s\n", out), nil
259}
260
261func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
262	out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
263	if err != nil || out == "" {
264		return "", nil
265	}
266	out = strings.TrimSpace(out)
267	return fmt.Sprintf("Recent commits:\n%s\n", out), nil
268}
269
270func (p *Prompt) Name() string {
271	return p.name
272}