prompt.go

  1package prompt
  2
  3import (
  4	"fmt"
  5	"os"
  6	"path/filepath"
  7	"strings"
  8	"sync"
  9
 10	"github.com/charmbracelet/crush/internal/config"
 11	"github.com/charmbracelet/crush/internal/llm/models"
 12	"github.com/charmbracelet/crush/internal/logging"
 13)
 14
 15func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) string {
 16	basePrompt := ""
 17	switch agentName {
 18	case config.AgentCoder:
 19		basePrompt = CoderPrompt(provider)
 20	case config.AgentTitle:
 21		basePrompt = TitlePrompt(provider)
 22	case config.AgentTask:
 23		basePrompt = TaskPrompt(provider)
 24	case config.AgentSummarizer:
 25		basePrompt = SummarizerPrompt(provider)
 26	default:
 27		basePrompt = "You are a helpful assistant"
 28	}
 29
 30	if agentName == config.AgentCoder || agentName == config.AgentTask {
 31		// Add context from project-specific instruction files if they exist
 32		contextContent := getContextFromPaths()
 33		logging.Debug("Context content", "Context", contextContent)
 34		if contextContent != "" {
 35			return fmt.Sprintf("%s\n\n# Project-Specific Context\n Make sure to follow the instructions in the context below\n%s", basePrompt, contextContent)
 36		}
 37	}
 38	return basePrompt
 39}
 40
 41var (
 42	onceContext    sync.Once
 43	contextContent string
 44)
 45
 46func getContextFromPaths() string {
 47	onceContext.Do(func() {
 48		var (
 49			cfg          = config.Get()
 50			workDir      = cfg.WorkingDir
 51			contextPaths = cfg.ContextPaths
 52		)
 53
 54		contextContent = processContextPaths(workDir, contextPaths)
 55	})
 56
 57	return contextContent
 58}
 59
 60func processContextPaths(workDir string, paths []string) string {
 61	var (
 62		wg       sync.WaitGroup
 63		resultCh = make(chan string)
 64	)
 65
 66	// Track processed files to avoid duplicates
 67	processedFiles := make(map[string]bool)
 68	var processedMutex sync.Mutex
 69
 70	for _, path := range paths {
 71		wg.Add(1)
 72		go func(p string) {
 73			defer wg.Done()
 74
 75			if strings.HasSuffix(p, "/") {
 76				filepath.WalkDir(filepath.Join(workDir, p), func(path string, d os.DirEntry, err error) error {
 77					if err != nil {
 78						return err
 79					}
 80					if !d.IsDir() {
 81						// Check if we've already processed this file (case-insensitive)
 82						lowerPath := strings.ToLower(path)
 83
 84						processedMutex.Lock()
 85						alreadyProcessed := processedFiles[lowerPath]
 86						if !alreadyProcessed {
 87							processedFiles[lowerPath] = true
 88						}
 89						processedMutex.Unlock()
 90
 91						if !alreadyProcessed {
 92							if result := processFile(path); result != "" {
 93								resultCh <- result
 94							}
 95						}
 96					}
 97					return nil
 98				})
 99			} else {
100				fullPath := filepath.Join(workDir, p)
101
102				// Check if we've already processed this file (case-insensitive)
103				lowerPath := strings.ToLower(fullPath)
104
105				processedMutex.Lock()
106				alreadyProcessed := processedFiles[lowerPath]
107				if !alreadyProcessed {
108					processedFiles[lowerPath] = true
109				}
110				processedMutex.Unlock()
111
112				if !alreadyProcessed {
113					result := processFile(fullPath)
114					if result != "" {
115						resultCh <- result
116					}
117				}
118			}
119		}(path)
120	}
121
122	go func() {
123		wg.Wait()
124		close(resultCh)
125	}()
126
127	results := make([]string, 0)
128	for result := range resultCh {
129		results = append(results, result)
130	}
131
132	return strings.Join(results, "\n")
133}
134
// processFile reads filePath and returns its contents preceded by a
// "# From:<path>" header line. Any read error (missing file, permission
// problem, path is a directory, ...) yields the empty string, so callers can
// treat the file as absent.
func processFile(filePath string) string {
	data, readErr := os.ReadFile(filePath)
	if readErr == nil {
		header := "# From:" + filePath
		return header + "\n" + string(data)
	}
	return ""
}