prompt.go

  1package prompt
  2
  3import (
  4	"fmt"
  5	"os"
  6	"path/filepath"
  7	"runtime"
  8	"strings"
  9	"text/template"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/home"
 14)
 15
// Prompt represents a template-based prompt generator.
//
// It stores the raw template text; the template is parsed each time Build is
// called.
type Prompt struct {
	// name is the name given to the text/template instance and returned by Name.
	name string
	// template is the unparsed text/template source.
	template string
}
 21
// PromptDat is the data value passed to the template when a Prompt is built.
// Its exported fields are what template actions can reference.
type PromptDat struct {
	// Provider is the provider string passed to Build.
	Provider string
	// Model is the model string passed to Build.
	Model string
	// Config is the configuration passed to Build.
	Config config.Config
	// WorkingDir is the value of Config.WorkingDir() at build time.
	WorkingDir string
	// IsGitRepo reports whether WorkingDir directly contains a ".git" entry.
	IsGitRepo bool
	// Platform is runtime.GOOS.
	Platform string
	// Date is the build-time local date formatted with the layout "1/2/2006"
	// (month/day/year without zero padding).
	Date string
}
 31
// ContextFile is a file loaded from disk for use by the "contextFiles"
// template function.
type ContextFile struct {
	// Path is the path the file was read from.
	Path string
	// Content is the complete file content converted to a string.
	Content string
}
 36
// NewPrompt returns a Prompt with the given name and template text.
//
// The template is not parsed here, so syntax errors in promptTemplate only
// surface from Build. The returned error is currently always nil.
func NewPrompt(name, promptTemplate string) (*Prompt, error) {
	return &Prompt{
		name:     name,
		template: promptTemplate,
	}, nil
}
 43
 44func (p *Prompt) Build(provider, model string, cfg config.Config) (string, error) {
 45	t, err := template.New(p.name).Funcs(p.funcMap(cfg)).Parse(p.template)
 46	if err != nil {
 47		return "", fmt.Errorf("parsing template: %w", err)
 48	}
 49	var sb strings.Builder
 50	if err := t.Execute(&sb, promptData(provider, model, cfg)); err != nil {
 51		return "", fmt.Errorf("executing template: %w", err)
 52	}
 53
 54	return sb.String(), nil
 55}
 56
 57func (p *Prompt) funcMap(cfg config.Config) template.FuncMap {
 58	return template.FuncMap{
 59		"contextFiles": func(path string) []ContextFile {
 60			path = expandPath(path, cfg)
 61			return processContextPath(path, cfg)
 62		},
 63	}
 64}
 65
 66func processFile(filePath string) *ContextFile {
 67	content, err := os.ReadFile(filePath)
 68	if err != nil {
 69		return nil
 70	}
 71	return &ContextFile{
 72		Path:    filePath,
 73		Content: string(content),
 74	}
 75}
 76
 77func processContextPath(p string, cfg config.Config) []ContextFile {
 78	var contexts []ContextFile
 79	fullPath := p
 80	if !filepath.IsAbs(p) {
 81		fullPath = filepath.Join(cfg.WorkingDir(), p)
 82	}
 83	info, err := os.Stat(fullPath)
 84	if err != nil {
 85		return contexts
 86	}
 87	if info.IsDir() {
 88		filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
 89			if err != nil {
 90				return err
 91			}
 92			if !d.IsDir() {
 93				if result := processFile(path); result != nil {
 94					contexts = append(contexts, *result)
 95				}
 96			}
 97			return nil
 98		})
 99	} else {
100		result := processFile(fullPath)
101		if result != nil {
102			contexts = append(contexts, *result)
103		}
104	}
105	return contexts
106}
107
108// expandPath expands ~ and environment variables in file paths
109func expandPath(path string, cfg config.Config) string {
110	path = home.Long(path)
111	// Handle environment variable expansion using the same pattern as config
112	if strings.HasPrefix(path, "$") {
113		if expanded, err := cfg.Resolver().ResolveValue(path); err == nil {
114			path = expanded
115		}
116	}
117
118	return path
119}
120
121func promptData(provider, model string, cfg config.Config) PromptDat {
122	return PromptDat{
123		Provider:   provider,
124		Model:      model,
125		Config:     cfg,
126		WorkingDir: cfg.WorkingDir(),
127		IsGitRepo:  isGitRepo(cfg.WorkingDir()),
128		Platform:   runtime.GOOS,
129		Date:       time.Now().Format("1/2/2006"),
130	}
131}
132
// isGitRepo reports whether dir directly contains an entry named ".git".
//
// Any kind of entry counts (directory or file). Parent directories are not
// searched, and any stat error, including "not found", yields false.
func isGitRepo(dir string) bool {
	gitPath := filepath.Join(dir, ".git")
	if _, err := os.Stat(gitPath); err != nil {
		return false
	}
	return true
}
137
// Name returns the name the Prompt was created with.
func (p *Prompt) Name() string {
	return p.name
}