prompt.go

  1package prompt
  2
  3import (
  4	"cmp"
  5	"context"
  6	"fmt"
  7	"os"
  8	"path/filepath"
  9	"runtime"
 10	"strings"
 11	"text/template"
 12	"time"
 13
 14	"git.secluded.site/crush/internal/config"
 15	"git.secluded.site/crush/internal/home"
 16	"git.secluded.site/crush/internal/shell"
 17)
 18
// Prompt represents a template-based prompt generator.
//
// A Prompt is created with NewPrompt and rendered with Build. The optional
// fields are set through Option values and fall back to defaults when left
// at their zero value.
type Prompt struct {
	// name is used as the template name when parsing and is returned by Name.
	name string
	// template is the text/template source that Build parses and executes.
	template string
	// now supplies the current time used for the Date field; NewPrompt
	// defaults it to time.Now.
	now func() time.Time
	// platform overrides the reported platform; when empty, runtime.GOOS is
	// used instead.
	platform string
	// workingDir overrides the reported working directory; when empty, the
	// configuration's working directory is used instead.
	workingDir string
}
 27
// PromptDat is the data passed to the prompt template when it is executed.
type PromptDat struct {
	// Provider and Model are the values passed to Build.
	Provider string
	Model    string
	// Config is the configuration the prompt was built with.
	Config config.Config
	// WorkingDir is the working directory with forward slashes as separators.
	WorkingDir string
	// IsGitRepo reports whether a ".git" entry exists in the configuration's
	// working directory.
	IsGitRepo bool
	// Platform is the platform override, or runtime.GOOS when none is set.
	Platform string
	// Date is the current date formatted as month/day/year ("1/2/2006").
	Date string
	// GitStatus holds the branch, status summary and recent commits; it is
	// only populated when IsGitRepo is true.
	GitStatus string
	// ContextFiles are the files loaded from the configured context paths.
	ContextFiles []ContextFile
	// MemoryFiles are the files loaded from the configured memory paths.
	MemoryFiles []ContextFile
}
 40
// ContextFile is a file that was read from disk for inclusion in a prompt.
type ContextFile struct {
	// Path is the path the file was read from.
	Path string
	// Content is the complete file content as a string.
	Content string
}
 45
// Option customizes a Prompt; NewPrompt applies options in the order given.
type Option func(*Prompt)
 47
 48func WithTimeFunc(fn func() time.Time) Option {
 49	return func(p *Prompt) {
 50		p.now = fn
 51	}
 52}
 53
 54func WithPlatform(platform string) Option {
 55	return func(p *Prompt) {
 56		p.platform = platform
 57	}
 58}
 59
 60func WithWorkingDir(workingDir string) Option {
 61	return func(p *Prompt) {
 62		p.workingDir = workingDir
 63	}
 64}
 65
 66func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
 67	p := &Prompt{
 68		name:     name,
 69		template: promptTemplate,
 70		now:      time.Now,
 71	}
 72	for _, opt := range opts {
 73		opt(p)
 74	}
 75	return p, nil
 76}
 77
 78func (p *Prompt) Build(ctx context.Context, provider, model string, cfg config.Config) (string, error) {
 79	t, err := template.New(p.name).Parse(p.template)
 80	if err != nil {
 81		return "", fmt.Errorf("parsing template: %w", err)
 82	}
 83	var sb strings.Builder
 84	d, err := p.promptData(ctx, provider, model, cfg)
 85	if err != nil {
 86		return "", err
 87	}
 88	if err := t.Execute(&sb, d); err != nil {
 89		return "", fmt.Errorf("executing template: %w", err)
 90	}
 91
 92	return sb.String(), nil
 93}
 94
 95func processFile(filePath string) *ContextFile {
 96	content, err := os.ReadFile(filePath)
 97	if err != nil {
 98		return nil
 99	}
100	return &ContextFile{
101		Path:    filePath,
102		Content: string(content),
103	}
104}
105
106func processContextPath(p string, cfg config.Config) []ContextFile {
107	var contexts []ContextFile
108	fullPath := p
109	if !filepath.IsAbs(p) {
110		fullPath = filepath.Join(cfg.WorkingDir(), p)
111	}
112	info, err := os.Stat(fullPath)
113	if err != nil {
114		return contexts
115	}
116	if info.IsDir() {
117		filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
118			if err != nil {
119				return err
120			}
121			if !d.IsDir() {
122				if result := processFile(path); result != nil {
123					contexts = append(contexts, *result)
124				}
125			}
126			return nil
127		})
128	} else {
129		result := processFile(fullPath)
130		if result != nil {
131			contexts = append(contexts, *result)
132		}
133	}
134	return contexts
135}
136
137// expandPath expands ~ and environment variables in file paths
138func expandPath(path string, cfg config.Config) string {
139	path = home.Long(path)
140	// Handle environment variable expansion using the same pattern as config
141	if strings.HasPrefix(path, "$") {
142		if expanded, err := cfg.Resolver().ResolveValue(path); err == nil {
143			path = expanded
144		}
145	}
146
147	return path
148}
149
150func (p *Prompt) promptData(ctx context.Context, provider, model string, cfg config.Config) (PromptDat, error) {
151	workingDir := cmp.Or(p.workingDir, cfg.WorkingDir())
152	platform := cmp.Or(p.platform, runtime.GOOS)
153
154	contextFiles := map[string][]ContextFile{}
155	memoryFiles := map[string][]ContextFile{}
156
157	for _, pth := range cfg.Options.ContextPaths {
158		expanded := expandPath(pth, cfg)
159		pathKey := strings.ToLower(expanded)
160		if _, ok := contextFiles[pathKey]; ok {
161			continue
162		}
163		content := processContextPath(expanded, cfg)
164		contextFiles[pathKey] = content
165	}
166
167	for _, pth := range cfg.Options.MemoryPaths {
168		expanded := expandPath(pth, cfg)
169		pathKey := strings.ToLower(expanded)
170		if _, ok := memoryFiles[pathKey]; ok {
171			continue
172		}
173		content := processContextPath(expanded, cfg)
174		memoryFiles[pathKey] = content
175	}
176
177	isGit := isGitRepo(cfg.WorkingDir())
178	data := PromptDat{
179		Provider:   provider,
180		Model:      model,
181		Config:     cfg,
182		WorkingDir: filepath.ToSlash(workingDir),
183		IsGitRepo:  isGit,
184		Platform:   platform,
185		Date:       p.now().Format("1/2/2006"),
186	}
187	if isGit {
188		var err error
189		data.GitStatus, err = getGitStatus(ctx, cfg.WorkingDir())
190		if err != nil {
191			return PromptDat{}, err
192		}
193	}
194
195	for _, files := range contextFiles {
196		data.ContextFiles = append(data.ContextFiles, files...)
197	}
198	for _, files := range memoryFiles {
199		data.MemoryFiles = append(data.MemoryFiles, files...)
200	}
201	return data, nil
202}
203
// isGitRepo reports whether dir contains a ".git" entry. Any error from
// os.Stat, including "not exist" and permission errors, counts as "not a
// repository". Parent directories are not searched.
func isGitRepo(dir string) bool {
	gitPath := filepath.Join(dir, ".git")
	if _, statErr := os.Stat(gitPath); statErr != nil {
		return false
	}
	return true
}
208
209func getGitStatus(ctx context.Context, dir string) (string, error) {
210	sh := shell.NewShell(&shell.Options{
211		WorkingDir: dir,
212	})
213	branch, err := getGitBranch(ctx, sh)
214	if err != nil {
215		return "", err
216	}
217	status, err := getGitStatusSummary(ctx, sh)
218	if err != nil {
219		return "", err
220	}
221	commits, err := getGitRecentCommits(ctx, sh)
222	if err != nil {
223		return "", err
224	}
225	return branch + status + commits, nil
226}
227
228func getGitBranch(ctx context.Context, sh *shell.Shell) (string, error) {
229	out, _, err := sh.Exec(ctx, "git branch --show-current 2>/dev/null")
230	if err != nil {
231		return "", nil
232	}
233	out = strings.TrimSpace(out)
234	if out == "" {
235		return "", nil
236	}
237	return fmt.Sprintf("Current branch: %s\n", out), nil
238}
239
240func getGitStatusSummary(ctx context.Context, sh *shell.Shell) (string, error) {
241	out, _, err := sh.Exec(ctx, "git status --short 2>/dev/null | head -20")
242	if err != nil {
243		return "", nil
244	}
245	out = strings.TrimSpace(out)
246	if out == "" {
247		return "Status: clean\n", nil
248	}
249	return fmt.Sprintf("Status:\n%s\n", out), nil
250}
251
252func getGitRecentCommits(ctx context.Context, sh *shell.Shell) (string, error) {
253	out, _, err := sh.Exec(ctx, "git log --oneline -n 3 2>/dev/null")
254	if err != nil || out == "" {
255		return "", nil
256	}
257	out = strings.TrimSpace(out)
258	return fmt.Sprintf("Recent commits:\n%s\n", out), nil
259}
260
261func (p *Prompt) Name() string {
262	return p.name
263}