prompt.go

  1package prompt
  2
  3import (
  4	"cmp"
  5	"context"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"path/filepath"
 10	"runtime"
 11	"strings"
 12	"testing"
 13	"text/template"
 14	"time"
 15
 16	"git.secluded.site/crush/internal/config"
 17	"git.secluded.site/crush/internal/home"
 18	"git.secluded.site/crush/internal/shell"
 19	"git.secluded.site/crush/internal/skills"
 20)
 21
 22// Prompt represents a template-based prompt generator.
 23type Prompt struct {
 24	name       string
 25	template   string
 26	now        func() time.Time
 27	platform   string
 28	workingDir string
 29}
 30
 31type PromptDat struct {
 32	Provider           string
 33	Model              string
 34	Config             config.Config
 35	WorkingDir         string
 36	IsGitRepo          bool
 37	Platform           string
 38	Date               string
 39	GitStatus          string
 40	ContextFiles       []ContextFile
 41	GlobalContextFiles []ContextFile
 42	AvailSkillXML      string
 43}
 44
 45type ContextFile struct {
 46	Path    string
 47	Content string
 48}
 49
 50type Option func(*Prompt)
 51
 52func WithTimeFunc(fn func() time.Time) Option {
 53	return func(p *Prompt) {
 54		p.now = fn
 55	}
 56}
 57
 58func WithPlatform(platform string) Option {
 59	return func(p *Prompt) {
 60		p.platform = platform
 61	}
 62}
 63
 64func WithWorkingDir(workingDir string) Option {
 65	return func(p *Prompt) {
 66		p.workingDir = workingDir
 67	}
 68}
 69
 70func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
 71	p := &Prompt{
 72		name:     name,
 73		template: promptTemplate,
 74		now:      time.Now,
 75	}
 76	for _, opt := range opts {
 77		opt(p)
 78	}
 79	return p, nil
 80}
 81
 82func (p *Prompt) Build(ctx context.Context, provider, model string, store *config.ConfigStore) (string, error) {
 83	t, err := template.New(p.name).Parse(p.template)
 84	if err != nil {
 85		return "", fmt.Errorf("parsing template: %w", err)
 86	}
 87	var sb strings.Builder
 88	d, err := p.promptData(ctx, provider, model, store)
 89	if err != nil {
 90		return "", err
 91	}
 92	if err := t.Execute(&sb, d); err != nil {
 93		return "", fmt.Errorf("executing template: %w", err)
 94	}
 95
 96	return sb.String(), nil
 97}
 98
 99func processFile(filePath string) *ContextFile {
100	content, err := os.ReadFile(filePath)
101	if err != nil {
102		return nil
103	}
104	return &ContextFile{
105		Path:    filePath,
106		Content: string(content),
107	}
108}
109
110func processContextPath(p string, store *config.ConfigStore) []ContextFile {
111	var contexts []ContextFile
112	fullPath := p
113	if !filepath.IsAbs(p) {
114		fullPath = filepath.Join(store.WorkingDir(), p)
115	}
116	info, err := os.Stat(fullPath)
117	if err != nil {
118		return contexts
119	}
120	if info.IsDir() {
121		filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
122			if err != nil {
123				return err
124			}
125			if !d.IsDir() {
126				if result := processFile(path); result != nil {
127					contexts = append(contexts, *result)
128				}
129			}
130			return nil
131		})
132	} else {
133		result := processFile(fullPath)
134		if result != nil {
135			contexts = append(contexts, *result)
136		}
137	}
138	return contexts
139}
140
141// expandPath expands ~ and environment variables in file paths
142func expandPath(path string, store *config.ConfigStore) string {
143	path = home.Long(path)
144	// Handle environment variable expansion using the same pattern as config
145	if strings.HasPrefix(path, "$") {
146		if expanded, err := store.Resolver().ResolveValue(path); err == nil {
147			path = expanded
148		}
149	}
150
151	return path
152}
153
154func (p *Prompt) promptData(ctx context.Context, provider, model string, store *config.ConfigStore) (PromptDat, error) {
155	workingDir := cmp.Or(p.workingDir, store.WorkingDir())
156	platform := cmp.Or(p.platform, runtime.GOOS)
157
158	contextFiles := map[string][]ContextFile{}
159	globalContextFiles := map[string][]ContextFile{}
160
161	cfg := store.Config()
162	for _, pth := range cfg.Options.ContextPaths {
163		expanded := expandPath(pth, store)
164		pathKey := strings.ToLower(expanded)
165		if _, ok := contextFiles[pathKey]; ok {
166			continue
167		}
168		content := processContextPath(expanded, store)
169		contextFiles[pathKey] = content
170	}
171
172	for _, pth := range cfg.Options.GlobalContextPaths {
173		expanded := expandPath(pth, store)
174		pathKey := strings.ToLower(expanded)
175		if _, ok := globalContextFiles[pathKey]; ok {
176			continue
177		}
178		content := processContextPath(expanded, store)
179		globalContextFiles[pathKey] = content
180	}
181
182	// Discover and load skills metadata.
183	var availSkillXML string
184
185	// Start with builtin skills.
186	allSkills := skills.DiscoverBuiltin()
187	builtinNames := make(map[string]bool, len(allSkills))
188	for _, s := range allSkills {
189		builtinNames[s.Name] = true
190	}
191
192	// Discover user skills from configured paths.
193	if len(cfg.Options.SkillsPaths) > 0 {
194		expandedPaths := make([]string, 0, len(cfg.Options.SkillsPaths))
195		for _, pth := range cfg.Options.SkillsPaths {
196			expandedPaths = append(expandedPaths, expandPath(pth, store))
197		}
198		for _, userSkill := range skills.Discover(expandedPaths) {
199			if builtinNames[userSkill.Name] {
200				slog.Warn("User skill overrides builtin skill", "name", userSkill.Name)
201			}
202			allSkills = append(allSkills, userSkill)
203		}
204	}
205
206	// Deduplicate: user skills override builtins with the same name.
207	allSkills = skills.Deduplicate(allSkills)
208
209	// Filter out disabled skills.
210	allSkills = skills.Filter(allSkills, cfg.Options.DisabledSkills)
211
212	if len(allSkills) > 0 {
213		availSkillXML = skills.ToPromptXML(allSkills)
214	}
215
216	isGit := isGitRepo(store.WorkingDir())
217	data := PromptDat{
218		Provider:      provider,
219		Model:         model,
220		Config:        *cfg,
221		WorkingDir:    filepath.ToSlash(workingDir),
222		IsGitRepo:     isGit,
223		Platform:      platform,
224		Date:          p.now().Format("1/2/2006"),
225		AvailSkillXML: availSkillXML,
226	}
227	if isGit {
228		var err error
229		data.GitStatus, err = getGitStatus(ctx, store.WorkingDir())
230		if err != nil {
231			return PromptDat{}, err
232		}
233	}
234
235	for _, files := range contextFiles {
236		data.ContextFiles = append(data.ContextFiles, files...)
237	}
238	if !testing.Testing() {
239		for _, files := range globalContextFiles {
240			data.GlobalContextFiles = append(data.GlobalContextFiles, files...)
241		}
242	}
243	return data, nil
244}
245
// isGitRepo reports whether dir directly contains an entry named ".git"
// (file or directory). Parent directories are not searched, and any stat
// error is treated as "not a repository".
func isGitRepo(dir string) bool {
	gitPath := filepath.Join(dir, ".git")
	if _, err := os.Stat(gitPath); err != nil {
		return false
	}
	return true
}
250
251func getGitStatus(ctx context.Context, dir string) (string, error) {
252	sh := shell.NewShell(&shell.Options{
253		WorkingDir: dir,
254	})
255	branch, err := getGitBranch(ctx, sh)
256	if err != nil {
257		return "", err
258	}
259	status, err := getGitStatusSummary(ctx, sh)
260	if err != nil {
261		return "", err
262	}
263	commits, err := getGitRecentCommits(ctx, sh)
264	if err != nil {
265		return "", err
266	}
267	return branch + status + commits, nil
268}
269
270func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
271	out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
272	if err != nil {
273		return "", nil
274	}
275	out = strings.TrimSpace(out)
276	if out == "" {
277		return "", nil
278	}
279	return fmt.Sprintf("Current branch: %s\n", out), nil
280}
281
282func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
283	out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
284	if err != nil {
285		return "", nil
286	}
287	out = strings.TrimSpace(out)
288	if out == "" {
289		return "Status: clean\n", nil
290	}
291	return fmt.Sprintf("Status:\n%s\n", out), nil
292}
293
294func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
295	out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
296	if err != nil || out == "" {
297		return "", nil
298	}
299	out = strings.TrimSpace(out)
300	return fmt.Sprintf("Recent commits:\n%s\n", out), nil
301}
302
303func (p *Prompt) Name() string {
304	return p.name
305}