prompt.go

  1package prompt
  2
  3import (
  4	"fmt"
  5	"os"
  6	"path/filepath"
  7	"strings"
  8	"sync"
  9
 10	"github.com/charmbracelet/crush/internal/config"
 11	"github.com/charmbracelet/crush/internal/llm/models"
 12	"github.com/charmbracelet/crush/internal/logging"
 13)
 14
 15func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) string {
 16	basePrompt := ""
 17	switch agentName {
 18	case config.AgentCoder:
 19		basePrompt = CoderPrompt(provider)
 20	case config.AgentTitle:
 21		basePrompt = TitlePrompt(provider)
 22	case config.AgentTask:
 23		basePrompt = TaskPrompt(provider)
 24	case config.AgentSummarizer:
 25		basePrompt = SummarizerPrompt(provider)
 26	default:
 27		basePrompt = "You are a helpful assistant"
 28	}
 29
 30	if agentName == config.AgentCoder || agentName == config.AgentTask {
 31		// Add context from project-specific instruction files if they exist
 32		contextContent := getContextFromPaths()
 33		logging.Debug("Context content", "Context", contextContent)
 34		if contextContent != "" {
 35			return fmt.Sprintf("%s\n\n# Project-Specific Context\n Make sure to follow the instructions in the context below\n%s", basePrompt, contextContent)
 36		}
 37	}
 38	return basePrompt
 39}
 40
 41var (
 42	onceContext    sync.Once
 43	contextContent string
 44)
 45
 46func getContextFromPaths() string {
 47	onceContext.Do(func() {
 48		var (
 49			cfg          = config.Get()
 50			workDir      = cfg.WorkingDir
 51			contextPaths = cfg.ContextPaths
 52		)
 53
 54		contextContent = processContextPaths(workDir, contextPaths)
 55	})
 56
 57	return contextContent
 58}
 59
 60func processContextPaths(workDir string, paths []string) string {
 61	var (
 62		wg       sync.WaitGroup
 63		resultCh = make(chan string)
 64	)
 65
 66	// Track processed files to avoid duplicates
 67	processedFiles := make(map[string]bool)
 68	var processedMutex sync.Mutex
 69
 70	for _, path := range paths {
 71		wg.Add(1)
 72		go func(p string) {
 73			defer wg.Done()
 74
 75			if strings.HasSuffix(p, "/") {
 76				filepath.WalkDir(filepath.Join(workDir, p), func(path string, d os.DirEntry, err error) error {
 77					if err != nil {
 78						return err
 79					}
 80					if !d.IsDir() {
 81						// Check if we've already processed this file (case-insensitive)
 82						processedMutex.Lock()
 83						lowerPath := strings.ToLower(path)
 84						if !processedFiles[lowerPath] {
 85							processedFiles[lowerPath] = true
 86							processedMutex.Unlock()
 87
 88							if result := processFile(path); result != "" {
 89								resultCh <- result
 90							}
 91						} else {
 92							processedMutex.Unlock()
 93						}
 94					}
 95					return nil
 96				})
 97			} else {
 98				fullPath := filepath.Join(workDir, p)
 99
100				// Check if we've already processed this file (case-insensitive)
101				processedMutex.Lock()
102				lowerPath := strings.ToLower(fullPath)
103				if !processedFiles[lowerPath] {
104					processedFiles[lowerPath] = true
105					processedMutex.Unlock()
106
107					result := processFile(fullPath)
108					if result != "" {
109						resultCh <- result
110					}
111				} else {
112					processedMutex.Unlock()
113				}
114			}
115		}(path)
116	}
117
118	go func() {
119		wg.Wait()
120		close(resultCh)
121	}()
122
123	results := make([]string, 0)
124	for result := range resultCh {
125		results = append(results, result)
126	}
127
128	return strings.Join(results, "\n")
129}
130
131func processFile(filePath string) string {
132	content, err := os.ReadFile(filePath)
133	if err != nil {
134		return ""
135	}
136	return "# From:" + filePath + "\n" + string(content)
137}