prompt.go

  1package prompt
  2
  3import (
  4	"cmp"
  5	"context"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"path/filepath"
 10	"runtime"
 11	"strings"
 12	"text/template"
 13	"time"
 14
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/home"
 17	"github.com/charmbracelet/crush/internal/shell"
 18	"github.com/charmbracelet/crush/internal/skills"
 19)
 20
// Prompt represents a template-based prompt generator.
//
// A Prompt is created with NewPrompt and rendered with Build. The optional
// fields are set through Option values; when left at their zero value the
// corresponding default is chosen at build time.
type Prompt struct {
	// name is the template name handed to template.New and returned by Name.
	name string
	// template is the text/template source parsed on every Build.
	template string
	// now supplies the current time used for the Date field; NewPrompt sets it
	// to time.Now unless WithTimeFunc overrides it.
	now func() time.Time
	// platform, when non-empty, replaces runtime.GOOS in the prompt data.
	platform string
	// workingDir, when non-empty, replaces the store's working directory in
	// the WorkingDir field of the prompt data.
	workingDir string
}
 29
// PromptDat is the value passed to the template when a Prompt is built. Its
// exported fields are what template text can reference.
type PromptDat struct {
	// Provider and Model are passed through unchanged from Build's arguments.
	Provider string
	Model    string
	// Config is a copy of the configuration returned by the config store.
	Config config.Config
	// WorkingDir is the working directory with separators converted to "/".
	WorkingDir string
	// IsGitRepo reports whether the store's working directory contains a
	// ".git" entry.
	IsGitRepo bool
	// Platform is the configured platform override or runtime.GOOS.
	Platform string
	// Date is the current date formatted with the layout "1/2/2006"
	// (month/day/year, no zero padding).
	Date string
	// GitStatus holds the branch, status and recent-commit summary; it is only
	// populated when IsGitRepo is true.
	GitStatus string
	// ContextFiles are the files loaded from the configured context paths.
	ContextFiles []ContextFile
	// AvailSkillXML is the output of skills.ToPromptXML for the enabled
	// skills, or empty when there are none.
	AvailSkillXML string
}
 42
// ContextFile is a single file loaded from a configured context path.
type ContextFile struct {
	// Path is the path the file was read from.
	Path string
	// Content is the complete file content converted to a string.
	Content string
}
 47
// Option customizes a Prompt. NewPrompt applies options in the order given,
// after the defaults have been set.
type Option func(*Prompt)
 49
 50func WithTimeFunc(fn func() time.Time) Option {
 51	return func(p *Prompt) {
 52		p.now = fn
 53	}
 54}
 55
 56func WithPlatform(platform string) Option {
 57	return func(p *Prompt) {
 58		p.platform = platform
 59	}
 60}
 61
 62func WithWorkingDir(workingDir string) Option {
 63	return func(p *Prompt) {
 64		p.workingDir = workingDir
 65	}
 66}
 67
 68func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
 69	p := &Prompt{
 70		name:     name,
 71		template: promptTemplate,
 72		now:      time.Now,
 73	}
 74	for _, opt := range opts {
 75		opt(p)
 76	}
 77	return p, nil
 78}
 79
 80func (p *Prompt) Build(ctx context.Context, provider, model string, store *config.ConfigStore) (string, error) {
 81	t, err := template.New(p.name).Parse(p.template)
 82	if err != nil {
 83		return "", fmt.Errorf("parsing template: %w", err)
 84	}
 85	var sb strings.Builder
 86	d, err := p.promptData(ctx, provider, model, store)
 87	if err != nil {
 88		return "", err
 89	}
 90	if err := t.Execute(&sb, d); err != nil {
 91		return "", fmt.Errorf("executing template: %w", err)
 92	}
 93
 94	return sb.String(), nil
 95}
 96
 97func processFile(filePath string) *ContextFile {
 98	content, err := os.ReadFile(filePath)
 99	if err != nil {
100		return nil
101	}
102	return &ContextFile{
103		Path:    filePath,
104		Content: string(content),
105	}
106}
107
108func processContextPath(p string, store *config.ConfigStore) []ContextFile {
109	var contexts []ContextFile
110	fullPath := p
111	if !filepath.IsAbs(p) {
112		fullPath = filepath.Join(store.WorkingDir(), p)
113	}
114	info, err := os.Stat(fullPath)
115	if err != nil {
116		return contexts
117	}
118	if info.IsDir() {
119		filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
120			if err != nil {
121				return err
122			}
123			if !d.IsDir() {
124				if result := processFile(path); result != nil {
125					contexts = append(contexts, *result)
126				}
127			}
128			return nil
129		})
130	} else {
131		result := processFile(fullPath)
132		if result != nil {
133			contexts = append(contexts, *result)
134		}
135	}
136	return contexts
137}
138
139// expandPath expands ~ and environment variables in file paths
140func expandPath(path string, store *config.ConfigStore) string {
141	path = home.Long(path)
142	// Handle environment variable expansion using the same pattern as config
143	if strings.HasPrefix(path, "$") {
144		if expanded, err := store.Resolver().ResolveValue(path); err == nil {
145			path = expanded
146		}
147	}
148
149	return path
150}
151
152func (p *Prompt) promptData(ctx context.Context, provider, model string, store *config.ConfigStore) (PromptDat, error) {
153	workingDir := cmp.Or(p.workingDir, store.WorkingDir())
154	platform := cmp.Or(p.platform, runtime.GOOS)
155
156	files := map[string][]ContextFile{}
157
158	cfg := store.Config()
159	for _, pth := range cfg.Options.ContextPaths {
160		expanded := expandPath(pth, store)
161		pathKey := strings.ToLower(expanded)
162		if _, ok := files[pathKey]; ok {
163			continue
164		}
165		content := processContextPath(expanded, store)
166		files[pathKey] = content
167	}
168
169	// Discover and load skills metadata.
170	var availSkillXML string
171
172	// Start with builtin skills.
173	allSkills := skills.DiscoverBuiltin()
174	builtinNames := make(map[string]bool, len(allSkills))
175	for _, s := range allSkills {
176		builtinNames[s.Name] = true
177	}
178
179	// Discover user skills from configured paths.
180	if len(cfg.Options.SkillsPaths) > 0 {
181		expandedPaths := make([]string, 0, len(cfg.Options.SkillsPaths))
182		for _, pth := range cfg.Options.SkillsPaths {
183			expandedPaths = append(expandedPaths, expandPath(pth, store))
184		}
185		for _, userSkill := range skills.Discover(expandedPaths) {
186			if builtinNames[userSkill.Name] {
187				slog.Warn("User skill overrides builtin skill", "name", userSkill.Name)
188			}
189			allSkills = append(allSkills, userSkill)
190		}
191	}
192
193	// Deduplicate: user skills override builtins with the same name.
194	allSkills = skills.Deduplicate(allSkills)
195
196	// Filter out disabled skills.
197	allSkills = skills.Filter(allSkills, cfg.Options.DisabledSkills)
198
199	if len(allSkills) > 0 {
200		availSkillXML = skills.ToPromptXML(allSkills)
201	}
202
203	isGit := isGitRepo(store.WorkingDir())
204	data := PromptDat{
205		Provider:      provider,
206		Model:         model,
207		Config:        *cfg,
208		WorkingDir:    filepath.ToSlash(workingDir),
209		IsGitRepo:     isGit,
210		Platform:      platform,
211		Date:          p.now().Format("1/2/2006"),
212		AvailSkillXML: availSkillXML,
213	}
214	if isGit {
215		var err error
216		data.GitStatus, err = getGitStatus(ctx, store.WorkingDir())
217		if err != nil {
218			return PromptDat{}, err
219		}
220	}
221
222	for _, contextFiles := range files {
223		data.ContextFiles = append(data.ContextFiles, contextFiles...)
224	}
225	return data, nil
226}
227
// isGitRepo reports whether dir directly contains an entry named ".git".
// Only dir itself is inspected, and any stat failure counts as "not a
// repository".
func isGitRepo(dir string) bool {
	gitPath := filepath.Join(dir, ".git")
	if _, statErr := os.Stat(gitPath); statErr != nil {
		return false
	}
	return true
}
232
233func getGitStatus(ctx context.Context, dir string) (string, error) {
234	sh := shell.NewShell(&shell.Options{
235		WorkingDir: dir,
236	})
237	branch, err := getGitBranch(ctx, sh)
238	if err != nil {
239		return "", err
240	}
241	status, err := getGitStatusSummary(ctx, sh)
242	if err != nil {
243		return "", err
244	}
245	commits, err := getGitRecentCommits(ctx, sh)
246	if err != nil {
247		return "", err
248	}
249	return branch + status + commits, nil
250}
251
252func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
253	out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
254	if err != nil {
255		return "", nil
256	}
257	out = strings.TrimSpace(out)
258	if out == "" {
259		return "", nil
260	}
261	return fmt.Sprintf("Current branch: %s\n", out), nil
262}
263
264func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
265	out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
266	if err != nil {
267		return "", nil
268	}
269	out = strings.TrimSpace(out)
270	if out == "" {
271		return "Status: clean\n", nil
272	}
273	return fmt.Sprintf("Status:\n%s\n", out), nil
274}
275
276func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
277	out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
278	if err != nil || out == "" {
279		return "", nil
280	}
281	out = strings.TrimSpace(out)
282	return fmt.Sprintf("Recent commits:\n%s\n", out), nil
283}
284
285func (p *Prompt) Name() string {
286	return p.name
287}