prompt.go

  1package prompt
  2
  3import (
  4	"cmp"
  5	"context"
  6	"fmt"
  7	"os"
  8	"path/filepath"
  9	"runtime"
 10	"strings"
 11	"text/template"
 12	"time"
 13
 14	"git.secluded.site/crush/internal/config"
 15	"git.secluded.site/crush/internal/home"
 16	"git.secluded.site/crush/internal/shell"
 17	"git.secluded.site/crush/internal/skills"
 18)
 19
// Prompt represents a template-based prompt generator. It holds a
// text/template source that Build renders with a PromptDat value, plus
// overrides set through Option functions passed to NewPrompt.
type Prompt struct {
	// name is passed to template.New when parsing and returned by Name.
	name       string
	// template is the raw text/template source; Build parses it on every call.
	template   string
	// now supplies the time used for PromptDat.Date (time.Now unless overridden).
	now        func() time.Time
	// platform overrides PromptDat.Platform; empty means runtime.GOOS.
	platform   string
	// workingDir overrides PromptDat.WorkingDir; empty means cfg.WorkingDir().
	workingDir string
}
 28
// PromptDat is the data value the prompt template is executed with; each
// field is available in the template as {{.FieldName}}.
type PromptDat struct {
	// Provider is the provider name passed to Build.
	Provider      string
	// Model is the model name passed to Build.
	Model         string
	// Config is the full configuration value passed to Build.
	Config        config.Config
	// WorkingDir is the WithWorkingDir override if set, otherwise
	// cfg.WorkingDir(), converted to forward slashes.
	WorkingDir    string
	// IsGitRepo reports whether cfg.WorkingDir() contains a .git entry.
	IsGitRepo     bool
	// Platform is the WithPlatform override if set, otherwise runtime.GOOS.
	Platform      string
	// Date is the current date from the Prompt's clock, formatted with the
	// layout "1/2/2006" (month/day/year, no zero padding).
	Date          string
	// GitStatus holds the current branch, a short status summary and recent
	// commits; it is empty when IsGitRepo is false.
	GitStatus     string
	// ContextFiles are the files loaded from cfg.Options.ContextPaths; paths
	// that are equal after expansion, ignoring case, are loaded once.
	ContextFiles  []ContextFile
	// MemoryFiles are the files loaded from cfg.Options.MemoryPaths, with the
	// same expansion and de-duplication as ContextFiles.
	MemoryFiles   []ContextFile
	// AvailSkillXML is the skills.ToPromptXML rendering of skills discovered
	// under cfg.Options.SkillsPaths; empty when none are found.
	AvailSkillXML string
}
 42
// ContextFile is a file whose contents are injected into the prompt.
type ContextFile struct {
	// Path is the path the file was read from.
	Path    string
	// Content is the file's full contents, converted to a string.
	Content string
}
 47
// Option customizes a Prompt; options are applied in order by NewPrompt.
type Option func(*Prompt)
 49
 50func WithTimeFunc(fn func() time.Time) Option {
 51	return func(p *Prompt) {
 52		p.now = fn
 53	}
 54}
 55
 56func WithPlatform(platform string) Option {
 57	return func(p *Prompt) {
 58		p.platform = platform
 59	}
 60}
 61
 62func WithWorkingDir(workingDir string) Option {
 63	return func(p *Prompt) {
 64		p.workingDir = workingDir
 65	}
 66}
 67
 68func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
 69	p := &Prompt{
 70		name:     name,
 71		template: promptTemplate,
 72		now:      time.Now,
 73	}
 74	for _, opt := range opts {
 75		opt(p)
 76	}
 77	return p, nil
 78}
 79
 80func (p *Prompt) Build(ctx context.Context, provider, model string, cfg config.Config) (string, error) {
 81	t, err := template.New(p.name).Parse(p.template)
 82	if err != nil {
 83		return "", fmt.Errorf("parsing template: %w", err)
 84	}
 85	var sb strings.Builder
 86	d, err := p.promptData(ctx, provider, model, cfg)
 87	if err != nil {
 88		return "", err
 89	}
 90	if err := t.Execute(&sb, d); err != nil {
 91		return "", fmt.Errorf("executing template: %w", err)
 92	}
 93
 94	return sb.String(), nil
 95}
 96
 97func processFile(filePath string) *ContextFile {
 98	content, err := os.ReadFile(filePath)
 99	if err != nil {
100		return nil
101	}
102	return &ContextFile{
103		Path:    filePath,
104		Content: string(content),
105	}
106}
107
108func processContextPath(p string, cfg config.Config) []ContextFile {
109	var contexts []ContextFile
110	fullPath := p
111	if !filepath.IsAbs(p) {
112		fullPath = filepath.Join(cfg.WorkingDir(), p)
113	}
114	info, err := os.Stat(fullPath)
115	if err != nil {
116		return contexts
117	}
118	if info.IsDir() {
119		filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
120			if err != nil {
121				return err
122			}
123			if !d.IsDir() {
124				if result := processFile(path); result != nil {
125					contexts = append(contexts, *result)
126				}
127			}
128			return nil
129		})
130	} else {
131		result := processFile(fullPath)
132		if result != nil {
133			contexts = append(contexts, *result)
134		}
135	}
136	return contexts
137}
138
139// expandPath expands ~ and environment variables in file paths
140func expandPath(path string, cfg config.Config) string {
141	path = home.Long(path)
142	// Handle environment variable expansion using the same pattern as config
143	if strings.HasPrefix(path, "$") {
144		if expanded, err := cfg.Resolver().ResolveValue(path); err == nil {
145			path = expanded
146		}
147	}
148
149	return path
150}
151
152func (p *Prompt) promptData(ctx context.Context, provider, model string, cfg config.Config) (PromptDat, error) {
153	workingDir := cmp.Or(p.workingDir, cfg.WorkingDir())
154	platform := cmp.Or(p.platform, runtime.GOOS)
155
156	contextFiles := map[string][]ContextFile{}
157	memoryFiles := map[string][]ContextFile{}
158
159	for _, pth := range cfg.Options.ContextPaths {
160		expanded := expandPath(pth, cfg)
161		pathKey := strings.ToLower(expanded)
162		if _, ok := contextFiles[pathKey]; ok {
163			continue
164		}
165		content := processContextPath(expanded, cfg)
166		contextFiles[pathKey] = content
167	}
168
169	for _, pth := range cfg.Options.MemoryPaths {
170		expanded := expandPath(pth, cfg)
171		pathKey := strings.ToLower(expanded)
172		if _, ok := memoryFiles[pathKey]; ok {
173			continue
174		}
175		content := processContextPath(expanded, cfg)
176		memoryFiles[pathKey] = content
177	}
178
179	// Discover and load skills metadata.
180	var availSkillXML string
181	if len(cfg.Options.SkillsPaths) > 0 {
182		expandedPaths := make([]string, 0, len(cfg.Options.SkillsPaths))
183		for _, pth := range cfg.Options.SkillsPaths {
184			expandedPaths = append(expandedPaths, expandPath(pth, cfg))
185		}
186		if discoveredSkills := skills.Discover(expandedPaths); len(discoveredSkills) > 0 {
187			availSkillXML = skills.ToPromptXML(discoveredSkills)
188		}
189	}
190
191	isGit := isGitRepo(cfg.WorkingDir())
192	data := PromptDat{
193		Provider:      provider,
194		Model:         model,
195		Config:        cfg,
196		WorkingDir:    filepath.ToSlash(workingDir),
197		IsGitRepo:     isGit,
198		Platform:      platform,
199		Date:          p.now().Format("1/2/2006"),
200		AvailSkillXML: availSkillXML,
201	}
202	if isGit {
203		var err error
204		data.GitStatus, err = getGitStatus(ctx, cfg.WorkingDir())
205		if err != nil {
206			return PromptDat{}, err
207		}
208	}
209
210	for _, files := range contextFiles {
211		data.ContextFiles = append(data.ContextFiles, files...)
212	}
213	for _, files := range memoryFiles {
214		data.MemoryFiles = append(data.MemoryFiles, files...)
215	}
216	return data, nil
217}
218
// isGitRepo reports whether dir contains a ".git" entry. Any kind of entry
// counts (a directory, or a file as used by worktrees); any Stat error,
// including "not found", is reported as false.
func isGitRepo(dir string) bool {
	if _, err := os.Stat(filepath.Join(dir, ".git")); err != nil {
		return false
	}
	return true
}
223
224func getGitStatus(ctx context.Context, dir string) (string, error) {
225	sh := shell.NewShell(&shell.Options{
226		WorkingDir: dir,
227	})
228	branch, err := getGitBranch(ctx, sh)
229	if err != nil {
230		return "", err
231	}
232	status, err := getGitStatusSummary(ctx, sh)
233	if err != nil {
234		return "", err
235	}
236	commits, err := getGitRecentCommits(ctx, sh)
237	if err != nil {
238		return "", err
239	}
240	return branch + status + commits, nil
241}
242
243func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
244	out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
245	if err != nil {
246		return "", nil
247	}
248	out = strings.TrimSpace(out)
249	if out == "" {
250		return "", nil
251	}
252	return fmt.Sprintf("Current branch: %s\n", out), nil
253}
254
255func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
256	out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
257	if err != nil {
258		return "", nil
259	}
260	out = strings.TrimSpace(out)
261	if out == "" {
262		return "Status: clean\n", nil
263	}
264	return fmt.Sprintf("Status:\n%s\n", out), nil
265}
266
267func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
268	out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
269	if err != nil || out == "" {
270		return "", nil
271	}
272	out = strings.TrimSpace(out)
273	return fmt.Sprintf("Recent commits:\n%s\n", out), nil
274}
275
276func (p *Prompt) Name() string {
277	return p.name
278}