prompt.go

  1package prompt
  2
  3import (
  4	"os"
  5	"path/filepath"
  6	"strings"
  7	"sync"
  8
  9	"github.com/charmbracelet/crush/internal/config"
 10	"github.com/charmbracelet/crush/internal/env"
 11)
 12
 // PromptID identifies which built-in prompt GetPrompt should produce.
type PromptID string

// Prompt identifiers understood by GetPrompt.
const (
	// PromptCoder selects the prompt built by CoderPrompt.
	PromptCoder      PromptID = "coder"
	// PromptTitle selects the prompt returned by TitlePrompt.
	PromptTitle      PromptID = "title"
	// PromptTask selects the prompt returned by TaskPrompt.
	PromptTask       PromptID = "task"
	// PromptSummarizer selects the prompt returned by SummarizerPrompt.
	PromptSummarizer PromptID = "summarizer"
	// PromptDefault has no dedicated case in GetPrompt; it is served by
	// GetPrompt's generic fallback.
	PromptDefault    PromptID = "default"
)
 22
 23func GetPrompt(promptID PromptID, provider string, contextPaths ...string) string {
 24	basePrompt := ""
 25	switch promptID {
 26	case PromptCoder:
 27		basePrompt = CoderPrompt(provider, contextPaths...)
 28	case PromptTitle:
 29		basePrompt = TitlePrompt()
 30	case PromptTask:
 31		basePrompt = TaskPrompt()
 32	case PromptSummarizer:
 33		basePrompt = SummarizerPrompt()
 34	default:
 35		basePrompt = "You are a helpful assistant"
 36	}
 37	return basePrompt
 38}
 39
 40func getContextFromPaths(workingDir string, contextPaths []string) string {
 41	return processContextPaths(workingDir, contextPaths)
 42}
 43
 // expandPath expands a leading "~" to the current user's home directory and,
// when the path starts with "$", resolves it through the config package's
// environment-variable resolver. If the home directory cannot be determined,
// or the resolver reports an error, that step is skipped and the path is
// returned as it stood.
//
// A trailing "/" on a "~/..." input is preserved: callers use that suffix to
// tell directory entries from file entries, and filepath.Join would otherwise
// strip it.
func expandPath(path string) string {
	switch {
	case path == "~":
		if homeDir, err := os.UserHomeDir(); err == nil {
			path = homeDir
		}
	case strings.HasPrefix(path, "~/"):
		if homeDir, err := os.UserHomeDir(); err == nil {
			expanded := filepath.Join(homeDir, path[2:])
			// filepath.Join cleans its result, dropping any trailing "/".
			// Restore it so entries like "~/notes/" stay directories.
			if strings.HasSuffix(path, "/") && !strings.HasSuffix(expanded, "/") {
				expanded += "/"
			}
			path = expanded
		}
	}

	// Handle environment variable expansion using the same pattern as config.
	// NOTE(review): only paths that *start* with "$" are passed to the
	// resolver, and the whole string is handed over; whether forms like
	// "$VAR/suffix" resolve depends on ResolveValue — confirm.
	if strings.HasPrefix(path, "$") {
		resolver := config.NewEnvironmentVariableResolver(env.New())
		if expanded, err := resolver.ResolveValue(path); err == nil {
			path = expanded
		}
	}

	return path
}
 69
 // processContextPaths loads the context files named by paths and returns
// their contents joined with "\n". Each file contributes whatever processFile
// returns for it.
//
// Every entry is first expanded with expandPath. An entry ending in "/" is a
// directory whose non-directory entries are collected recursively; any other
// entry is a single file. Relative entries are resolved against workDir.
// Files are de-duplicated by case-insensitive path, and files that cannot be
// read are skipped.
//
// Entries are resolved and read concurrently, but the output order is
// deterministic. Files appear in the order of paths and, within a directory,
// in the lexical order filepath.WalkDir visits them.
func processContextPaths(workDir string, paths []string) string {
	// loaded is one candidate file together with the outcome of reading it.
	type loaded struct {
		key     string // lower-cased path, used for de-duplication
		content string // processFile output; "" when the file could not be read
	}

	// One slot per entry: each goroutine writes only its own slot, so the
	// results can be reassembled in input order without locking.
	perEntry := make([][]loaded, len(paths))

	var wg sync.WaitGroup
	for i, entry := range paths {
		wg.Add(1)
		go func(i int, entry string) {
			defer wg.Done()
			for _, file := range resolveContextFiles(workDir, entry) {
				perEntry[i] = append(perEntry[i], loaded{
					key:     strings.ToLower(file),
					content: processFile(file),
				})
			}
		}(i, entry)
	}
	wg.Wait()

	// De-duplicate sequentially, in input order. Only a successful read claims
	// a key. Otherwise a missing variant such as "CRUSH.md" could suppress an
	// existing "crush.md" on a case-sensitive filesystem.
	seen := make(map[string]bool)
	var results []string
	for _, files := range perEntry {
		for _, f := range files {
			if f.content == "" || seen[f.key] {
				continue
			}
			seen[f.key] = true
			results = append(results, f.content)
		}
	}

	return strings.Join(results, "\n")
}

// resolveContextFiles expands one context-path entry into the file paths it
// names. A file entry yields the entry itself. A directory entry (one ending
// in "/") yields every non-directory entry found beneath it.
func resolveContextFiles(workDir, entry string) []string {
	// Check the suffix both before and after expansion: expanding "~/dir/"
	// goes through filepath.Join, which may drop the trailing slash.
	isDir := strings.HasSuffix(entry, "/")

	// Expand ~ and environment variables before resolving.
	p := expandPath(entry)
	isDir = isDir || strings.HasSuffix(p, "/")

	// Use the path as-is if absolute, otherwise resolve it against workDir.
	if !filepath.IsAbs(p) {
		p = filepath.Join(workDir, p)
	}
	if !isDir {
		return []string{p}
	}

	var files []string
	// Best effort: the callback never fails, so WalkDir's own result carries
	// nothing actionable and is deliberately ignored.
	_ = filepath.WalkDir(p, func(path string, d os.DirEntry, err error) error {
		if err != nil {
			// Skip unreadable entries (d may be nil here) instead of aborting
			// the walk and losing the remaining files.
			return nil
		}
		if !d.IsDir() {
			files = append(files, path)
		}
		return nil
	})
	return files
}
157
// processFile reads filePath and returns its contents preceded by a
// "# From:<filePath>" header line. If the file cannot be read, it returns
// the empty string instead.
func processFile(filePath string) string {
	header := "# From:" + filePath + "\n"
	if data, readErr := os.ReadFile(filePath); readErr == nil {
		return header + string(data)
	}
	return ""
}