prompt.go

  1package prompt
  2
  3import (
  4	"os"
  5	"path/filepath"
  6	"strings"
  7	"sync"
  8
  9	"github.com/charmbracelet/crush/internal/config"
 10	"github.com/charmbracelet/crush/internal/csync"
 11	"github.com/charmbracelet/crush/internal/home"
 12)
 13
 14type PromptID string
 15
 16const (
 17	PromptCoder      PromptID = "coder"
 18	PromptTitle      PromptID = "title"
 19	PromptTask       PromptID = "task"
 20	PromptSummarizer PromptID = "summarizer"
 21	PromptDefault    PromptID = "default"
 22)
 23
 24func GetPrompt(cfg *config.Config, promptID PromptID, provider string, contextPaths ...string) string {
 25	basePrompt := ""
 26	switch promptID {
 27	case PromptCoder:
 28		basePrompt = CoderPrompt(cfg, provider, contextPaths...)
 29	case PromptTitle:
 30		basePrompt = TitlePrompt()
 31	case PromptTask:
 32		basePrompt = TaskPrompt(cfg)
 33	case PromptSummarizer:
 34		basePrompt = SummarizerPrompt()
 35	default:
 36		basePrompt = "You are a helpful assistant"
 37	}
 38	return basePrompt
 39}
 40
 41func getContextFromPaths(workingDir string, contextPaths []string) string {
 42	return processContextPaths(workingDir, contextPaths)
 43}
 44
 45// expandPath expands ~ and environment variables in file paths
 46func expandPath(path string) string {
 47	path = home.Long(path)
 48
 49	// Handle environment variable expansion using the same pattern as config
 50	if strings.HasPrefix(path, "$") {
 51		resolver := config.NewEnvironmentVariableResolver(os.Environ())
 52		if expanded, err := resolver.ResolveValue(path); err == nil {
 53			path = expanded
 54		}
 55	}
 56
 57	return path
 58}
 59
 60func processContextPaths(workDir string, paths []string) string {
 61	var (
 62		wg       sync.WaitGroup
 63		resultCh = make(chan string)
 64	)
 65
 66	// Track processed files to avoid duplicates
 67	processedFiles := csync.NewMap[string, bool]()
 68
 69	for _, path := range paths {
 70		wg.Add(1)
 71		go func(p string) {
 72			defer wg.Done()
 73
 74			// Expand ~ and environment variables before processing
 75			p = expandPath(p)
 76
 77			// Use absolute path if provided, otherwise join with workDir
 78			fullPath := p
 79			if !filepath.IsAbs(p) {
 80				fullPath = filepath.Join(workDir, p)
 81			}
 82
 83			// Check if the path is a directory using os.Stat
 84			info, err := os.Stat(fullPath)
 85			if err != nil {
 86				return // Skip if path doesn't exist or can't be accessed
 87			}
 88
 89			if info.IsDir() {
 90				filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
 91					if err != nil {
 92						return err
 93					}
 94					if !d.IsDir() {
 95						// Check if we've already processed this file (case-insensitive)
 96						lowerPath := strings.ToLower(path)
 97
 98						if alreadyProcessed, _ := processedFiles.Get(lowerPath); !alreadyProcessed {
 99							processedFiles.Set(lowerPath, true)
100							if result := processFile(path); result != "" {
101								resultCh <- result
102							}
103						}
104					}
105					return nil
106				})
107			} else {
108				// It's a file, process it directly
109				// Check if we've already processed this file (case-insensitive)
110				lowerPath := strings.ToLower(fullPath)
111
112				if alreadyProcessed, _ := processedFiles.Get(lowerPath); !alreadyProcessed {
113					processedFiles.Set(lowerPath, true)
114					result := processFile(fullPath)
115					if result != "" {
116						resultCh <- result
117					}
118				}
119			}
120		}(path)
121	}
122
123	go func() {
124		wg.Wait()
125		close(resultCh)
126	}()
127
128	results := make([]string, 0)
129	for result := range resultCh {
130		results = append(results, result)
131	}
132
133	return strings.Join(results, "\n")
134}
135
// processFile returns the contents of filePath preceded by a
// "# From:<filePath>" header line. If the file cannot be read, it returns the
// empty string so callers can skip it.
func processFile(filePath string) string {
	data, readErr := os.ReadFile(filePath)
	if readErr != nil {
		return ""
	}

	const header = "# From:"
	var b strings.Builder
	b.Grow(len(header) + len(filePath) + 1 + len(data))
	b.WriteString(header)
	b.WriteString(filePath)
	b.WriteByte('\n')
	b.Write(data)
	return b.String()
}