prompt.go

  1package prompt
  2
  3import (
  4	"fmt"
  5	"os"
  6	"path/filepath"
  7	"runtime"
  8	"strings"
  9	"text/template"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/home"
 14)
 15
 16// Prompt represents a template-based prompt generator.
 17type Prompt struct {
 18	name       string
 19	template   string
 20	now        func() time.Time
 21	platform   string
 22	workingDir string
 23}
 24
 25type PromptDat struct {
 26	Provider   string
 27	Model      string
 28	Config     config.Config
 29	WorkingDir string
 30	IsGitRepo  bool
 31	Platform   string
 32	Date       string
 33}
 34
 35type ContextFile struct {
 36	Path    string
 37	Content string
 38}
 39
 40type Option func(*Prompt)
 41
 42func WithTimeFunc(fn func() time.Time) Option {
 43	return func(p *Prompt) {
 44		p.now = fn
 45	}
 46}
 47
 48func WithPlatform(platform string) Option {
 49	return func(p *Prompt) {
 50		p.platform = platform
 51	}
 52}
 53
 54func WithWorkingDir(workingDir string) Option {
 55	return func(p *Prompt) {
 56		p.workingDir = workingDir
 57	}
 58}
 59
 60func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) {
 61	p := &Prompt{
 62		name:     name,
 63		template: promptTemplate,
 64		now:      time.Now,
 65	}
 66	for _, opt := range opts {
 67		opt(p)
 68	}
 69	return p, nil
 70}
 71
 72func (p *Prompt) Build(provider, model string, cfg config.Config) (string, error) {
 73	t, err := template.New(p.name).Funcs(p.funcMap(cfg)).Parse(p.template)
 74	if err != nil {
 75		return "", fmt.Errorf("parsing template: %w", err)
 76	}
 77	var sb strings.Builder
 78	if err := t.Execute(&sb, p.promptData(provider, model, cfg)); err != nil {
 79		return "", fmt.Errorf("executing template: %w", err)
 80	}
 81
 82	return sb.String(), nil
 83}
 84
 85func (p *Prompt) funcMap(cfg config.Config) template.FuncMap {
 86	return template.FuncMap{
 87		"contextFiles": func(path string) []ContextFile {
 88			path = expandPath(path, cfg)
 89			return processContextPath(path, cfg)
 90		},
 91	}
 92}
 93
 94func processFile(filePath string) *ContextFile {
 95	content, err := os.ReadFile(filePath)
 96	if err != nil {
 97		return nil
 98	}
 99	return &ContextFile{
100		Path:    filePath,
101		Content: string(content),
102	}
103}
104
105func processContextPath(p string, cfg config.Config) []ContextFile {
106	var contexts []ContextFile
107	fullPath := p
108	if !filepath.IsAbs(p) {
109		fullPath = filepath.Join(cfg.WorkingDir(), p)
110	}
111	info, err := os.Stat(fullPath)
112	if err != nil {
113		return contexts
114	}
115	if info.IsDir() {
116		filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
117			if err != nil {
118				return err
119			}
120			if !d.IsDir() {
121				if result := processFile(path); result != nil {
122					contexts = append(contexts, *result)
123				}
124			}
125			return nil
126		})
127	} else {
128		result := processFile(fullPath)
129		if result != nil {
130			contexts = append(contexts, *result)
131		}
132	}
133	return contexts
134}
135
136// expandPath expands ~ and environment variables in file paths
137func expandPath(path string, cfg config.Config) string {
138	path = home.Long(path)
139	// Handle environment variable expansion using the same pattern as config
140	if strings.HasPrefix(path, "$") {
141		if expanded, err := cfg.Resolver().ResolveValue(path); err == nil {
142			path = expanded
143		}
144	}
145
146	return path
147}
148
149func (p *Prompt) promptData(provider, model string, cfg config.Config) PromptDat {
150	workingDir := cfg.WorkingDir()
151	if p.workingDir != "" {
152		workingDir = p.workingDir
153	}
154	platform := runtime.GOOS
155	if p.platform != "" {
156		platform = p.platform
157	}
158	return PromptDat{
159		Provider:   provider,
160		Model:      model,
161		Config:     cfg,
162		WorkingDir: workingDir,
163		IsGitRepo:  isGitRepo(cfg.WorkingDir()),
164		Platform:   platform,
165		Date:       p.now().Format("1/2/2006"),
166	}
167}
168
// isGitRepo reports whether dir contains an entry named ".git" (a directory
// or a file). Any Stat failure, not only "does not exist", is treated as
// "not a repository".
func isGitRepo(dir string) bool {
	gitPath := filepath.Join(dir, ".git")
	if _, err := os.Stat(gitPath); err != nil {
		return false
	}
	return true
}
173
174func (p *Prompt) Name() string {
175	return p.name
176}