prompt.go

  1package prompt
  2
  3import (
  4	"cmp"
  5	"context"
  6	"fmt"
  7	"log/slog"
  8	"os"
  9	"path/filepath"
 10	"runtime"
 11	"strings"
 12	"testing"
 13	"text/template"
 14	"time"
 15
 16	"git.secluded.site/crush/internal/config"
 17	"git.secluded.site/crush/internal/filepathext"
 18	"git.secluded.site/crush/internal/home"
 19	"git.secluded.site/crush/internal/shell"
 20	"git.secluded.site/crush/internal/skills"
 21)
 22
 23// Prompt represents a template-based prompt generator.
 24type Prompt struct {
 25	name       string
 26	template   string
 27	now        func() time.Time
 28	platform   string
 29	workingDir string
 30}
 31
 32type PromptDat struct {
 33	Provider           string
 34	Model              string
 35	Config             config.Config
 36	WorkingDir         string
 37	IsGitRepo          bool
 38	Platform           string
 39	Date               string
 40	GitStatus          string
 41	ContextFiles       []ContextFile
 42	GlobalContextFiles []ContextFile
 43	AvailSkillXML      string
 44}
 45
 46type ContextFile struct {
 47	Path    string
 48	Content string
 49}
 50
 51type Option func(*Prompt)
 52
 53func WithTimeFunc(fn func() time.Time) Option {
 54	return func(p *Prompt) {
 55		p.now = fn
 56	}
 57}
 58
 59func WithPlatform(platform string) Option {
 60	return func(p *Prompt) {
 61		p.platform = platform
 62	}
 63}
 64
 65func WithWorkingDir(workingDir string) Option {
 66	return func(p *Prompt) {
 67		p.workingDir = workingDir
 68	}
 69}
 70
 71func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
 72	p := &Prompt{
 73		name:     name,
 74		template: promptTemplate,
 75		now:      time.Now,
 76	}
 77	for _, opt := range opts {
 78		opt(p)
 79	}
 80	return p, nil
 81}
 82
 83func (p *Prompt) Build(ctx context.Context, provider, model string, store *config.ConfigStore) (string, error) {
 84	t, err := template.New(p.name).Parse(p.template)
 85	if err != nil {
 86		return "", fmt.Errorf("parsing template: %w", err)
 87	}
 88	var sb strings.Builder
 89	d, err := p.promptData(ctx, provider, model, store)
 90	if err != nil {
 91		return "", err
 92	}
 93	if err := t.Execute(&sb, d); err != nil {
 94		return "", fmt.Errorf("executing template: %w", err)
 95	}
 96
 97	return sb.String(), nil
 98}
 99
100func processFile(filePath string) *ContextFile {
101	content, err := os.ReadFile(filePath)
102	if err != nil {
103		return nil
104	}
105	return &ContextFile{
106		Path:    filePath,
107		Content: string(content),
108	}
109}
110
111func processContextPath(p string, store *config.ConfigStore) []ContextFile {
112	var contexts []ContextFile
113	fullPath := filepathext.SmartJoin(store.WorkingDir(), p)
114	info, err := os.Stat(fullPath)
115	if err != nil {
116		return contexts
117	}
118	if info.IsDir() {
119		filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
120			if err != nil {
121				return err
122			}
123			if !d.IsDir() {
124				if result := processFile(path); result != nil {
125					contexts = append(contexts, *result)
126				}
127			}
128			return nil
129		})
130	} else {
131		result := processFile(fullPath)
132		if result != nil {
133			contexts = append(contexts, *result)
134		}
135	}
136	return contexts
137}
138
139// expandPath expands ~ and environment variables in file paths
140func expandPath(path string, store *config.ConfigStore) string {
141	path = home.Long(path)
142	// Handle environment variable expansion using the same pattern as config
143	if strings.HasPrefix(path, "$") {
144		if expanded, err := store.Resolver().ResolveValue(path); err == nil {
145			path = expanded
146		}
147	}
148
149	return path
150}
151
152func (p *Prompt) promptData(ctx context.Context, provider, model string, store *config.ConfigStore) (PromptDat, error) {
153	workingDir := cmp.Or(p.workingDir, store.WorkingDir())
154	platform := cmp.Or(p.platform, runtime.GOOS)
155
156	contextFiles := map[string][]ContextFile{}
157	globalContextFiles := map[string][]ContextFile{}
158
159	cfg := store.Config()
160	for _, pth := range cfg.Options.ContextPaths {
161		expanded := expandPath(pth, store)
162		pathKey := strings.ToLower(expanded)
163		if _, ok := contextFiles[pathKey]; ok {
164			continue
165		}
166		content := processContextPath(expanded, store)
167		contextFiles[pathKey] = content
168	}
169
170	for _, pth := range cfg.Options.GlobalContextPaths {
171		expanded := expandPath(pth, store)
172		pathKey := strings.ToLower(expanded)
173		if _, ok := globalContextFiles[pathKey]; ok {
174			continue
175		}
176		content := processContextPath(expanded, store)
177		globalContextFiles[pathKey] = content
178	}
179
180	// Discover and load skills metadata.
181	var availSkillXML string
182
183	// Start with builtin skills.
184	allSkills := skills.DiscoverBuiltin()
185	builtinNames := make(map[string]bool, len(allSkills))
186	for _, s := range allSkills {
187		builtinNames[s.Name] = true
188	}
189
190	// Discover user skills from configured paths.
191	if len(cfg.Options.SkillsPaths) > 0 {
192		expandedPaths := make([]string, 0, len(cfg.Options.SkillsPaths))
193		for _, pth := range cfg.Options.SkillsPaths {
194			expandedPaths = append(expandedPaths, expandPath(pth, store))
195		}
196		for _, userSkill := range skills.Discover(expandedPaths) {
197			if builtinNames[userSkill.Name] {
198				slog.Warn("User skill overrides builtin skill", "name", userSkill.Name)
199			}
200			allSkills = append(allSkills, userSkill)
201		}
202	}
203
204	// Deduplicate: user skills override builtins with the same name.
205	allSkills = skills.Deduplicate(allSkills)
206
207	// Filter out disabled skills.
208	allSkills = skills.Filter(allSkills, cfg.Options.DisabledSkills)
209
210	if len(allSkills) > 0 {
211		availSkillXML = skills.ToPromptXML(allSkills)
212	}
213
214	isGit := isGitRepo(store.WorkingDir())
215	data := PromptDat{
216		Provider:      provider,
217		Model:         model,
218		Config:        *cfg,
219		WorkingDir:    filepath.ToSlash(workingDir),
220		IsGitRepo:     isGit,
221		Platform:      platform,
222		Date:          p.now().Format("1/2/2006"),
223		AvailSkillXML: availSkillXML,
224	}
225	if isGit {
226		var err error
227		data.GitStatus, err = getGitStatus(ctx, store.WorkingDir())
228		if err != nil {
229			return PromptDat{}, err
230		}
231	}
232
233	for _, files := range contextFiles {
234		data.ContextFiles = append(data.ContextFiles, files...)
235	}
236	if !testing.Testing() {
237		for _, files := range globalContextFiles {
238			data.GlobalContextFiles = append(data.GlobalContextFiles, files...)
239		}
240	}
241	return data, nil
242}
243
// isGitRepo reports whether dir directly contains an entry named ".git"
// (of any kind). Parent directories are not searched, and any stat error,
// including a missing dir, yields false.
func isGitRepo(dir string) bool {
	if _, statErr := os.Stat(filepath.Join(dir, ".git")); statErr != nil {
		return false
	}
	return true
}
248
249func getGitStatus(ctx context.Context, dir string) (string, error) {
250	sh := shell.NewShell(&shell.Options{
251		WorkingDir: dir,
252	})
253	branch, err := getGitBranch(ctx, sh)
254	if err != nil {
255		return "", err
256	}
257	status, err := getGitStatusSummary(ctx, sh)
258	if err != nil {
259		return "", err
260	}
261	commits, err := getGitRecentCommits(ctx, sh)
262	if err != nil {
263		return "", err
264	}
265	return branch + status + commits, nil
266}
267
268func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
269	out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
270	if err != nil {
271		return "", nil
272	}
273	out = strings.TrimSpace(out)
274	if out == "" {
275		return "", nil
276	}
277	return fmt.Sprintf("Current branch: %s\n", out), nil
278}
279
280func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
281	out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
282	if err != nil {
283		return "", nil
284	}
285	out = strings.TrimSpace(out)
286	if out == "" {
287		return "Status: clean\n", nil
288	}
289	return fmt.Sprintf("Status:\n%s\n", out), nil
290}
291
292func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
293	out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
294	if err != nil || out == "" {
295		return "", nil
296	}
297	out = strings.TrimSpace(out)
298	return fmt.Sprintf("Recent commits:\n%s\n", out), nil
299}
300
301func (p *Prompt) Name() string {
302	return p.name
303}