prompt.go

  1package prompt
  2
  3import (
  4	"cmp"
  5	"context"
  6	"fmt"
  7	"os"
  8	"path/filepath"
  9	"runtime"
 10	"strings"
 11	"text/template"
 12	"time"
 13
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/home"
 16	"github.com/charmbracelet/crush/internal/shell"
 17	"github.com/charmbracelet/crush/internal/skills"
 18)
 19
 20// Prompt represents a template-based prompt generator.
 21type Prompt struct {
 22	name       string
 23	template   string
 24	now        func() time.Time
 25	platform   string
 26	workingDir string
 27}
 28
// PromptDat is the data value a Prompt passes to its template on Build.
//
// Fields, as populated by promptData:
//   - Provider, Model: the provider and model strings given to Build.
//   - Config: the configuration exposed to templates (the LSP settings).
//   - WorkingDir: the working directory, converted to forward slashes.
//   - IsGitRepo: whether a .git entry exists in the service's working
//     directory.
//   - Platform: the platform override if one was set, otherwise runtime.GOOS.
//   - Date: the current date formatted with the layout "1/2/2006".
//   - GitStatus: branch, short status, and recent commits; only filled in
//     when IsGitRepo is true.
//   - ContextFiles: the files loaded from the configured context paths.
//   - AvailSkillXML: the output of skills.ToPromptXML for the discovered
//     skills, or empty when none were discovered.
type PromptDat struct {
	Provider      string
	Model         string
	Config        promptConfig
	WorkingDir    string
	IsGitRepo     bool
	Platform      string
	Date          string
	GitStatus     string
	ContextFiles  []ContextFile
	AvailSkillXML string
}
 41
// promptConfig is the slice of configuration made available to templates via
// PromptDat.Config. promptData fills LSP from the config service's LSP method.
type promptConfig struct {
	LSP config.LSPs
}
 45
// ContextFile is one file loaded from a configured context path. Path is the
// location the file was read from and Content is the file's bytes converted
// to a string.
type ContextFile struct {
	Path    string
	Content string
}
 50
// Option customizes a Prompt. NewPrompt applies every Option, in the order
// given, after it has set the defaults.
type Option func(*Prompt)
 52
 53func WithTimeFunc(fn func() time.Time) Option {
 54	return func(p *Prompt) {
 55		p.now = fn
 56	}
 57}
 58
 59func WithPlatform(platform string) Option {
 60	return func(p *Prompt) {
 61		p.platform = platform
 62	}
 63}
 64
 65func WithWorkingDir(workingDir string) Option {
 66	return func(p *Prompt) {
 67		p.workingDir = workingDir
 68	}
 69}
 70
 71func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
 72	p := &Prompt{
 73		name:     name,
 74		template: promptTemplate,
 75		now:      time.Now,
 76	}
 77	for _, opt := range opts {
 78		opt(p)
 79	}
 80	return p, nil
 81}
 82
 83func (p *Prompt) Build(ctx context.Context, provider, model string, svc *config.Service) (string, error) {
 84	t, err := template.New(p.name).Parse(p.template)
 85	if err != nil {
 86		return "", fmt.Errorf("parsing template: %w", err)
 87	}
 88	var sb strings.Builder
 89	d, err := p.promptData(ctx, provider, model, svc)
 90	if err != nil {
 91		return "", err
 92	}
 93	if err := t.Execute(&sb, d); err != nil {
 94		return "", fmt.Errorf("executing template: %w", err)
 95	}
 96
 97	return sb.String(), nil
 98}
 99
100func processFile(filePath string) *ContextFile {
101	content, err := os.ReadFile(filePath)
102	if err != nil {
103		return nil
104	}
105	return &ContextFile{
106		Path:    filePath,
107		Content: string(content),
108	}
109}
110
111func processContextPath(p string, workingDir string) []ContextFile {
112	var contexts []ContextFile
113	fullPath := p
114	if !filepath.IsAbs(p) {
115		fullPath = filepath.Join(workingDir, p)
116	}
117	info, err := os.Stat(fullPath)
118	if err != nil {
119		return contexts
120	}
121	if info.IsDir() {
122		filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
123			if err != nil {
124				return err
125			}
126			if !d.IsDir() {
127				if result := processFile(path); result != nil {
128					contexts = append(contexts, *result)
129				}
130			}
131			return nil
132		})
133	} else {
134		result := processFile(fullPath)
135		if result != nil {
136			contexts = append(contexts, *result)
137		}
138	}
139	return contexts
140}
141
142// expandPath expands ~ and environment variables in file paths
143func expandPath(path string, resolver config.VariableResolver) string {
144	path = home.Long(path)
145	if strings.HasPrefix(path, "$") {
146		if resolver != nil {
147			if expanded, err := resolver.ResolveValue(path); err == nil {
148				path = expanded
149			}
150		}
151	}
152
153	return path
154}
155
156func (p *Prompt) promptData(ctx context.Context, provider, model string, svc *config.Service) (PromptDat, error) {
157	workingDir := cmp.Or(p.workingDir, svc.WorkingDir())
158	platform := cmp.Or(p.platform, runtime.GOOS)
159
160	files := map[string][]ContextFile{}
161
162	for _, pth := range svc.ContextPaths() {
163		expanded := expandPath(pth, svc.Resolver())
164		pathKey := strings.ToLower(expanded)
165		if _, ok := files[pathKey]; ok {
166			continue
167		}
168		content := processContextPath(expanded, svc.WorkingDir())
169		files[pathKey] = content
170	}
171
172	var availSkillXML string
173	if len(svc.SkillsPaths()) > 0 {
174		expandedPaths := make([]string, 0, len(svc.SkillsPaths()))
175		for _, pth := range svc.SkillsPaths() {
176			expandedPaths = append(expandedPaths, expandPath(pth, svc.Resolver()))
177		}
178		if discoveredSkills := skills.Discover(expandedPaths); len(discoveredSkills) > 0 {
179			availSkillXML = skills.ToPromptXML(discoveredSkills)
180		}
181	}
182
183	isGit := isGitRepo(svc.WorkingDir())
184	data := PromptDat{
185		Provider:      provider,
186		Model:         model,
187		Config:        promptConfig{LSP: svc.LSP()},
188		WorkingDir:    filepath.ToSlash(workingDir),
189		IsGitRepo:     isGit,
190		Platform:      platform,
191		Date:          p.now().Format("1/2/2006"),
192		AvailSkillXML: availSkillXML,
193	}
194	if isGit {
195		var err error
196		data.GitStatus, err = getGitStatus(ctx, svc.WorkingDir())
197		if err != nil {
198			return PromptDat{}, err
199		}
200	}
201
202	for _, contextFiles := range files {
203		data.ContextFiles = append(data.ContextFiles, contextFiles...)
204	}
205	return data, nil
206}
207
// isGitRepo reports whether dir directly contains an entry named ".git". Only
// dir itself is inspected, not its parents, and any Stat failure is treated as
// "not a repository".
func isGitRepo(dir string) bool {
	marker := filepath.Join(dir, ".git")
	if _, statErr := os.Stat(marker); statErr != nil {
		return false
	}
	return true
}
212
213func getGitStatus(ctx context.Context, dir string) (string, error) {
214	sh := shell.NewShell(&shell.Options{
215		WorkingDir: dir,
216	})
217	branch, err := getGitBranch(ctx, sh)
218	if err != nil {
219		return "", err
220	}
221	status, err := getGitStatusSummary(ctx, sh)
222	if err != nil {
223		return "", err
224	}
225	commits, err := getGitRecentCommits(ctx, sh)
226	if err != nil {
227		return "", err
228	}
229	return branch + status + commits, nil
230}
231
232func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
233	out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
234	if err != nil {
235		return "", nil
236	}
237	out = strings.TrimSpace(out)
238	if out == "" {
239		return "", nil
240	}
241	return fmt.Sprintf("Current branch: %s\n", out), nil
242}
243
244func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
245	out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
246	if err != nil {
247		return "", nil
248	}
249	out = strings.TrimSpace(out)
250	if out == "" {
251		return "Status: clean\n", nil
252	}
253	return fmt.Sprintf("Status:\n%s\n", out), nil
254}
255
256func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
257	out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
258	if err != nil || out == "" {
259		return "", nil
260	}
261	out = strings.TrimSpace(out)
262	return fmt.Sprintf("Recent commits:\n%s\n", out), nil
263}
264
// Name returns the name the Prompt was created with.
func (p *Prompt) Name() string {
	return p.name
}