prompt.go

  1package prompt
  2
  3import (
  4	"cmp"
  5	"context"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"path/filepath"
 10	"runtime"
 11	"strings"
 12	"text/template"
 13	"time"
 14
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/filepathext"
 17	"github.com/charmbracelet/crush/internal/home"
 18	"github.com/charmbracelet/crush/internal/shell"
 19	"github.com/charmbracelet/crush/internal/skills"
 20)
 21
 22// Prompt represents a template-based prompt generator.
 23type Prompt struct {
 24	name       string
 25	template   string
 26	now        func() time.Time
 27	platform   string
 28	workingDir string
 29}
 30
 31type PromptDat struct {
 32	Provider      string
 33	Model         string
 34	Config        config.Config
 35	WorkingDir    string
 36	IsGitRepo     bool
 37	Platform      string
 38	Date          string
 39	GitStatus     string
 40	ContextFiles  []ContextFile
 41	AvailSkillXML string
 42}
 43
 44type ContextFile struct {
 45	Path    string
 46	Content string
 47}
 48
 49type Option func(*Prompt)
 50
 51func WithTimeFunc(fn func() time.Time) Option {
 52	return func(p *Prompt) {
 53		p.now = fn
 54	}
 55}
 56
 57func WithPlatform(platform string) Option {
 58	return func(p *Prompt) {
 59		p.platform = platform
 60	}
 61}
 62
 63func WithWorkingDir(workingDir string) Option {
 64	return func(p *Prompt) {
 65		p.workingDir = workingDir
 66	}
 67}
 68
 69func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
 70	p := &Prompt{
 71		name:     name,
 72		template: promptTemplate,
 73		now:      time.Now,
 74	}
 75	for _, opt := range opts {
 76		opt(p)
 77	}
 78	return p, nil
 79}
 80
 81func (p *Prompt) Build(ctx context.Context, provider, model string, store *config.ConfigStore) (string, error) {
 82	t, err := template.New(p.name).Parse(p.template)
 83	if err != nil {
 84		return "", fmt.Errorf("parsing template: %w", err)
 85	}
 86	var sb strings.Builder
 87	d, err := p.promptData(ctx, provider, model, store)
 88	if err != nil {
 89		return "", err
 90	}
 91	if err := t.Execute(&sb, d); err != nil {
 92		return "", fmt.Errorf("executing template: %w", err)
 93	}
 94
 95	return sb.String(), nil
 96}
 97
 98func processFile(filePath string) *ContextFile {
 99	content, err := os.ReadFile(filePath)
100	if err != nil {
101		return nil
102	}
103	return &ContextFile{
104		Path:    filePath,
105		Content: string(content),
106	}
107}
108
109func processContextPath(p string, store *config.ConfigStore) []ContextFile {
110	var contexts []ContextFile
111	fullPath := filepathext.SmartJoin(store.WorkingDir(), p)
112	info, err := os.Stat(fullPath)
113	if err != nil {
114		return contexts
115	}
116	if info.IsDir() {
117		filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
118			if err != nil {
119				return err
120			}
121			if !d.IsDir() {
122				if result := processFile(path); result != nil {
123					contexts = append(contexts, *result)
124				}
125			}
126			return nil
127		})
128	} else {
129		result := processFile(fullPath)
130		if result != nil {
131			contexts = append(contexts, *result)
132		}
133	}
134	return contexts
135}
136
137// expandPath expands ~ and environment variables in file paths
138func expandPath(path string, store *config.ConfigStore) string {
139	path = home.Long(path)
140	// Handle environment variable expansion using the same pattern as config
141	if strings.HasPrefix(path, "$") {
142		if expanded, err := store.Resolver().ResolveValue(path); err == nil {
143			path = expanded
144		}
145	}
146
147	return path
148}
149
150func (p *Prompt) promptData(ctx context.Context, provider, model string, store *config.ConfigStore) (PromptDat, error) {
151	workingDir := cmp.Or(p.workingDir, store.WorkingDir())
152	platform := cmp.Or(p.platform, runtime.GOOS)
153
154	files := map[string][]ContextFile{}
155
156	cfg := store.Config()
157	for _, pth := range cfg.Options.ContextPaths {
158		expanded := expandPath(pth, store)
159		pathKey := strings.ToLower(expanded)
160		if _, ok := files[pathKey]; ok {
161			continue
162		}
163		content := processContextPath(expanded, store)
164		files[pathKey] = content
165	}
166
167	// Discover and load skills metadata.
168	var availSkillXML string
169
170	// Start with builtin skills.
171	allSkills := skills.DiscoverBuiltin()
172	builtinNames := make(map[string]bool, len(allSkills))
173	for _, s := range allSkills {
174		builtinNames[s.Name] = true
175	}
176
177	// Discover user skills from configured paths.
178	if len(cfg.Options.SkillsPaths) > 0 {
179		expandedPaths := make([]string, 0, len(cfg.Options.SkillsPaths))
180		for _, pth := range cfg.Options.SkillsPaths {
181			expandedPaths = append(expandedPaths, expandPath(pth, store))
182		}
183		for _, userSkill := range skills.Discover(expandedPaths) {
184			if builtinNames[userSkill.Name] {
185				slog.Warn("User skill overrides builtin skill", "name", userSkill.Name)
186			}
187			allSkills = append(allSkills, userSkill)
188		}
189	}
190
191	// Deduplicate: user skills override builtins with the same name.
192	allSkills = skills.Deduplicate(allSkills)
193
194	// Filter out disabled skills.
195	allSkills = skills.Filter(allSkills, cfg.Options.DisabledSkills)
196
197	if len(allSkills) > 0 {
198		availSkillXML = skills.ToPromptXML(allSkills)
199	}
200
201	isGit := isGitRepo(store.WorkingDir())
202	data := PromptDat{
203		Provider:      provider,
204		Model:         model,
205		Config:        *cfg,
206		WorkingDir:    filepath.ToSlash(workingDir),
207		IsGitRepo:     isGit,
208		Platform:      platform,
209		Date:          p.now().Format("1/2/2006"),
210		AvailSkillXML: availSkillXML,
211	}
212	if isGit {
213		var err error
214		data.GitStatus, err = getGitStatus(ctx, store.WorkingDir())
215		if err != nil {
216			return PromptDat{}, err
217		}
218	}
219
220	for _, contextFiles := range files {
221		data.ContextFiles = append(data.ContextFiles, contextFiles...)
222	}
223	return data, nil
224}
225
// isGitRepo reports whether dir directly contains a ".git" entry (directory or
// file). Parent directories are not searched.
func isGitRepo(dir string) bool {
	if _, err := os.Stat(filepath.Join(dir, ".git")); err != nil {
		return false
	}
	return true
}
230
231func getGitStatus(ctx context.Context, dir string) (string, error) {
232	sh := shell.NewShell(&shell.Options{
233		WorkingDir: dir,
234	})
235	branch, err := getGitBranch(ctx, sh)
236	if err != nil {
237		return "", err
238	}
239	status, err := getGitStatusSummary(ctx, sh)
240	if err != nil {
241		return "", err
242	}
243	commits, err := getGitRecentCommits(ctx, sh)
244	if err != nil {
245		return "", err
246	}
247	return branch + status + commits, nil
248}
249
250func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
251	out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
252	if err != nil {
253		return "", nil
254	}
255	out = strings.TrimSpace(out)
256	if out == "" {
257		return "", nil
258	}
259	return fmt.Sprintf("Current branch: %s\n", out), nil
260}
261
262func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
263	out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
264	if err != nil {
265		return "", nil
266	}
267	out = strings.TrimSpace(out)
268	if out == "" {
269		return "Status: clean\n", nil
270	}
271	return fmt.Sprintf("Status:\n%s\n", out), nil
272}
273
274func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
275	out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
276	if err != nil || out == "" {
277		return "", nil
278	}
279	out = strings.TrimSpace(out)
280	return fmt.Sprintf("Recent commits:\n%s\n", out), nil
281}
282
283func (p *Prompt) Name() string {
284	return p.name
285}