Merge pull request #140 from charmbracelet/feature/custom-context-paths

Kujtim Hoxha created

allow for custom contextFiles outside of workingDir path

Change summary

internal/llm/agent/agent.go        |   4 
internal/llm/prompt/prompt.go      |  53 +++++++++++++-
internal/llm/prompt/prompt_test.go | 113 ++++++++++++++++++++++++++++++++
3 files changed, 163 insertions(+), 7 deletions(-)

Detailed changes

internal/llm/agent/agent.go 🔗

@@ -149,7 +149,7 @@ func NewAgent(
 	}
 	opts := []provider.ProviderClientOption{
 		provider.WithModel(agentCfg.Model),
-		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
+		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 	}
 	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 	if err != nil {
@@ -827,7 +827,7 @@ func (a *agent) UpdateModel() error {
 
 		opts := []provider.ProviderClientOption{
 			provider.WithModel(a.agentCfg.Model),
-			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
+			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
 		}
 
 		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)

internal/llm/prompt/prompt.go 🔗

@@ -5,6 +5,9 @@ import (
 	"path/filepath"
 	"strings"
 	"sync"
+
+	"github.com/charmbracelet/crush/internal/config"
+	"github.com/charmbracelet/crush/internal/env"
 )
 
 type PromptID string
@@ -21,7 +24,7 @@ func GetPrompt(promptID PromptID, provider string, contextPaths ...string) strin
 	basePrompt := ""
 	switch promptID {
 	case PromptCoder:
-		basePrompt = CoderPrompt(provider)
+		basePrompt = CoderPrompt(provider, contextPaths...)
 	case PromptTitle:
 		basePrompt = TitlePrompt()
 	case PromptTask:
@@ -38,6 +41,32 @@ func getContextFromPaths(workingDir string, contextPaths []string) string {
 	return processContextPaths(workingDir, contextPaths)
 }
 
+// expandPath expands ~ and environment variables in file paths
+func expandPath(path string) string {
+	// Handle tilde expansion
+	if strings.HasPrefix(path, "~/") {
+		homeDir, err := os.UserHomeDir()
+		if err == nil {
+			path = filepath.Join(homeDir, path[2:])
+		}
+	} else if path == "~" {
+		homeDir, err := os.UserHomeDir()
+		if err == nil {
+			path = homeDir
+		}
+	}
+
+	// Handle environment variable expansion using the same pattern as config
+	if strings.HasPrefix(path, "$") {
+		resolver := config.NewEnvironmentVariableResolver(env.New())
+		if expanded, err := resolver.ResolveValue(path); err == nil {
+			path = expanded
+		}
+	}
+
+	return path
+}
+
 func processContextPaths(workDir string, paths []string) string {
 	var (
 		wg       sync.WaitGroup
@@ -53,8 +82,23 @@ func processContextPaths(workDir string, paths []string) string {
 		go func(p string) {
 			defer wg.Done()
 
-			if strings.HasSuffix(p, "/") {
-				filepath.WalkDir(filepath.Join(workDir, p), func(path string, d os.DirEntry, err error) error {
+			// Expand ~ and environment variables before processing
+			p = expandPath(p)
+
+			// Use absolute path if provided, otherwise join with workDir
+			fullPath := p
+			if !filepath.IsAbs(p) {
+				fullPath = filepath.Join(workDir, p)
+			}
+
+			// Check if the path is a directory using os.Stat
+			info, err := os.Stat(fullPath)
+			if err != nil {
+				return // Skip if path doesn't exist or can't be accessed
+			}
+
+			if info.IsDir() {
+				filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error {
 					if err != nil {
 						return err
 					}
@@ -78,8 +122,7 @@ func processContextPaths(workDir string, paths []string) string {
 					return nil
 				})
 			} else {
-				fullPath := filepath.Join(workDir, p)
-
+				// It's a file, process it directly
 				// Check if we've already processed this file (case-insensitive)
 				lowerPath := strings.ToLower(fullPath)
 

internal/llm/prompt/prompt_test.go 🔗

@@ -0,0 +1,113 @@
+package prompt
+
+import (
+	"os"
+	"path/filepath"
+	"strings"
+	"testing"
+)
+
+func TestExpandPath(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    string
+		expected func() string
+	}{
+		{
+			name:  "regular path unchanged",
+			input: "/absolute/path",
+			expected: func() string {
+				return "/absolute/path"
+			},
+		},
+		{
+			name:  "tilde expansion",
+			input: "~/documents",
+			expected: func() string {
+				home, _ := os.UserHomeDir()
+				return filepath.Join(home, "documents")
+			},
+		},
+		{
+			name:  "tilde only",
+			input: "~",
+			expected: func() string {
+				home, _ := os.UserHomeDir()
+				return home
+			},
+		},
+		{
+			name:  "environment variable expansion",
+			input: "$HOME",
+			expected: func() string {
+				return os.Getenv("HOME")
+			},
+		},
+		{
+			name:  "relative path unchanged",
+			input: "relative/path",
+			expected: func() string {
+				return "relative/path"
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := expandPath(tt.input)
+			expected := tt.expected()
+
+			// Skip test if environment variable is not set
+			if strings.HasPrefix(tt.input, "$") && expected == "" {
+				t.Skip("Environment variable not set")
+			}
+
+			if result != expected {
+				t.Errorf("expandPath(%q) = %q, want %q", tt.input, result, expected)
+			}
+		})
+	}
+}
+
+func TestProcessContextPaths(t *testing.T) {
+	// Create a temporary directory and file for testing
+	tmpDir := t.TempDir()
+	testFile := filepath.Join(tmpDir, "test.txt")
+	testContent := "test content"
+
+	err := os.WriteFile(testFile, []byte(testContent), 0o644)
+	if err != nil {
+		t.Fatalf("Failed to create test file: %v", err)
+	}
+
+	// Test with absolute path to file
+	result := processContextPaths("", []string{testFile})
+	expected := "# From:" + testFile + "\n" + testContent
+
+	if result != expected {
+		t.Errorf("processContextPaths with absolute path failed.\nGot: %q\nWant: %q", result, expected)
+	}
+
+	// Test with directory path (should process all files in directory)
+	result = processContextPaths("", []string{tmpDir})
+	if !strings.Contains(result, testContent) {
+		t.Errorf("processContextPaths with directory path failed to include file content")
+	}
+
+	// Test with tilde expansion (if we can create a file in home directory)
+	tmpDir = t.TempDir()
+	t.Setenv("HOME", tmpDir)
+	homeTestFile := filepath.Join(tmpDir, "crush_test_file.txt")
+	err = os.WriteFile(homeTestFile, []byte(testContent), 0o644)
+	if err == nil {
+		defer os.Remove(homeTestFile) // Clean up
+
+		tildeFile := "~/crush_test_file.txt"
+		result = processContextPaths("", []string{tildeFile})
+		expected = "# From:" + homeTestFile + "\n" + testContent
+
+		if result != expected {
+			t.Errorf("processContextPaths with tilde expansion failed.\nGot: %q\nWant: %q", result, expected)
+		}
+	}
+}