prompt.go

  1package prompt
  2
  3import (
  4	"cmp"
  5	"context"
  6	"fmt"
  7	"os"
  8	"path/filepath"
  9	"runtime"
 10	"strings"
 11	"text/template"
 12	"time"
 13
 14	"git.secluded.site/crush/internal/config"
 15	"git.secluded.site/crush/internal/home"
 16	"git.secluded.site/crush/internal/shell"
 17	"git.secluded.site/crush/internal/skills"
 18)
 19
 20// Prompt represents a template-based prompt generator.
 21type Prompt struct {
 22	name        string
 23	template    string
 24	now         func() time.Time
 25	platform    string
 26	workingDir  string
 27	modelFamily ModelFamily
 28}
 29
 30type PromptDat struct {
 31	Provider      string
 32	Model         string
 33	ModelFamily   ModelFamily
 34	Config        config.Config
 35	WorkingDir    string
 36	IsGitRepo     bool
 37	Platform      string
 38	Date          string
 39	GitStatus     string
 40	ContextFiles  []ContextFile
 41	MemoryFiles   []ContextFile
 42	AvailSkillXML string
 43}
 44
 45type ContextFile struct {
 46	Path    string
 47	Content string
 48}
 49
 50type Option func(*Prompt)
 51
 52func WithTimeFunc(fn func() time.Time) Option {
 53	return func(p *Prompt) {
 54		p.now = fn
 55	}
 56}
 57
 58func WithPlatform(platform string) Option {
 59	return func(p *Prompt) {
 60		p.platform = platform
 61	}
 62}
 63
 64func WithWorkingDir(workingDir string) Option {
 65	return func(p *Prompt) {
 66		p.workingDir = workingDir
 67	}
 68}
 69
 70func WithModelFamily(modelFamily ModelFamily) Option {
 71	return func(p *Prompt) {
 72		p.modelFamily = modelFamily
 73	}
 74}
 75
 76func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
 77	p := &Prompt{
 78		name:     name,
 79		template: promptTemplate,
 80		now:      time.Now,
 81	}
 82	for _, opt := range opts {
 83		opt(p)
 84	}
 85	return p, nil
 86}
 87
 88func (p *Prompt) Build(ctx context.Context, provider, model string, cfg config.Config) (string, error) {
 89	t, err := template.New(p.name).Parse(p.template)
 90	if err != nil {
 91		return "", fmt.Errorf("parsing template: %w", err)
 92	}
 93	var sb strings.Builder
 94	d, err := p.promptData(ctx, provider, model, cfg)
 95	if err != nil {
 96		return "", err
 97	}
 98	if err := t.Execute(&sb, d); err != nil {
 99		return "", fmt.Errorf("executing template: %w", err)
100	}
101
102	return sb.String(), nil
103}
104
105func processFile(filePath string) *ContextFile {
106	content, err := os.ReadFile(filePath)
107	if err != nil {
108		return nil
109	}
110	return &ContextFile{
111		Path:    filePath,
112		Content: string(content),
113	}
114}
115
116func processContextPath(p string, cfg config.Config) []ContextFile {
117	var contexts []ContextFile
118	fullPath := p
119	if !filepath.IsAbs(p) {
120		fullPath = filepath.Join(cfg.WorkingDir(), p)
121	}
122	info, err := os.Stat(fullPath)
123	if err != nil {
124		return contexts
125	}
126	if info.IsDir() {
127		filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
128			if err != nil {
129				return err
130			}
131			if !d.IsDir() {
132				if result := processFile(path); result != nil {
133					contexts = append(contexts, *result)
134				}
135			}
136			return nil
137		})
138	} else {
139		result := processFile(fullPath)
140		if result != nil {
141			contexts = append(contexts, *result)
142		}
143	}
144	return contexts
145}
146
147// expandPath expands ~ and environment variables in file paths
148func expandPath(path string, cfg config.Config) string {
149	path = home.Long(path)
150	// Handle environment variable expansion using the same pattern as config
151	if strings.HasPrefix(path, "$") {
152		if expanded, err := cfg.Resolver().ResolveValue(path); err == nil {
153			path = expanded
154		}
155	}
156
157	return path
158}
159
160func (p *Prompt) promptData(ctx context.Context, provider, model string, cfg config.Config) (PromptDat, error) {
161	workingDir := cmp.Or(p.workingDir, cfg.WorkingDir())
162	platform := cmp.Or(p.platform, runtime.GOOS)
163
164	contextFiles := map[string][]ContextFile{}
165	memoryFiles := map[string][]ContextFile{}
166
167	for _, pth := range cfg.Options.ContextPaths {
168		expanded := expandPath(pth, cfg)
169		pathKey := strings.ToLower(expanded)
170		if _, ok := contextFiles[pathKey]; ok {
171			continue
172		}
173		content := processContextPath(expanded, cfg)
174		contextFiles[pathKey] = content
175	}
176
177	for _, pth := range cfg.Options.MemoryPaths {
178		expanded := expandPath(pth, cfg)
179		pathKey := strings.ToLower(expanded)
180		if _, ok := memoryFiles[pathKey]; ok {
181			continue
182		}
183		content := processContextPath(expanded, cfg)
184		memoryFiles[pathKey] = content
185	}
186
187	// Discover and load skills metadata.
188	var availSkillXML string
189	if len(cfg.Options.SkillsPaths) > 0 {
190		expandedPaths := make([]string, 0, len(cfg.Options.SkillsPaths))
191		for _, pth := range cfg.Options.SkillsPaths {
192			expandedPaths = append(expandedPaths, expandPath(pth, cfg))
193		}
194		if discoveredSkills := skills.Discover(expandedPaths); len(discoveredSkills) > 0 {
195			availSkillXML = skills.ToPromptXML(discoveredSkills)
196		}
197	}
198
199	isGit := isGitRepo(cfg.WorkingDir())
200	data := PromptDat{
201		Provider:      provider,
202		Model:         model,
203		ModelFamily:   p.modelFamily,
204		Config:        cfg,
205		WorkingDir:    filepath.ToSlash(workingDir),
206		IsGitRepo:     isGit,
207		Platform:      platform,
208		Date:          p.now().Format("1/2/2006"),
209		AvailSkillXML: availSkillXML,
210	}
211	if isGit {
212		var err error
213		data.GitStatus, err = getGitStatus(ctx, cfg.WorkingDir())
214		if err != nil {
215			return PromptDat{}, err
216		}
217	}
218
219	for _, files := range contextFiles {
220		data.ContextFiles = append(data.ContextFiles, files...)
221	}
222	for _, files := range memoryFiles {
223		data.MemoryFiles = append(data.MemoryFiles, files...)
224	}
225	return data, nil
226}
227
// isGitRepo reports whether dir contains a ".git" entry. Either a directory
// or a regular file counts, because os.Stat succeeds for both; git itself
// uses a .git file for worktrees and submodules. Any Stat error, including
// a permission error, is treated as "not a repository".
func isGitRepo(dir string) bool {
	if _, statErr := os.Stat(filepath.Join(dir, ".git")); statErr != nil {
		return false
	}
	return true
}
232
233func getGitStatus(ctx context.Context, dir string) (string, error) {
234	sh := shell.NewShell(&shell.Options{
235		WorkingDir: dir,
236	})
237	branch, err := getGitBranch(ctx, sh)
238	if err != nil {
239		return "", err
240	}
241	status, err := getGitStatusSummary(ctx, sh)
242	if err != nil {
243		return "", err
244	}
245	commits, err := getGitRecentCommits(ctx, sh)
246	if err != nil {
247		return "", err
248	}
249	return branch + status + commits, nil
250}
251
252func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
253	out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
254	if err != nil {
255		return "", nil
256	}
257	out = strings.TrimSpace(out)
258	if out == "" {
259		return "", nil
260	}
261	return fmt.Sprintf("Current branch: %s\n", out), nil
262}
263
264func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
265	out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
266	if err != nil {
267		return "", nil
268	}
269	out = strings.TrimSpace(out)
270	if out == "" {
271		return "Status: clean\n", nil
272	}
273	return fmt.Sprintf("Status:\n%s\n", out), nil
274}
275
276func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
277	out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
278	if err != nil || out == "" {
279		return "", nil
280	}
281	out = strings.TrimSpace(out)
282	return fmt.Sprintf("Recent commits:\n%s\n", out), nil
283}
284
285func (p *Prompt) Name() string {
286	return p.name
287}