From f8d8ce33e15d33eed1d638ec341a8b76afbe38d8 Mon Sep 17 00:00:00 2001 From: Tai Groot Date: Thu, 10 Jul 2025 18:15:38 -0700 Subject: [PATCH 1/4] allow for custom contextFiles outside of workingDir path --- internal/llm/agent/agent.go | 4 +- internal/llm/prompt/prompt.go | 48 ++++++++++++- internal/llm/prompt/prompt_test.go | 108 +++++++++++++++++++++++++++++ 3 files changed, 155 insertions(+), 5 deletions(-) create mode 100644 internal/llm/prompt/prompt_test.go diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 313b83c0448d8a668e2390368c6797c82dd22452..fbb5b4fd8c6390ff0dfad0e072af35342355ba41 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -149,7 +149,7 @@ func NewAgent( } opts := []provider.ProviderClientOption{ provider.WithModel(agentCfg.Model), - provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)), + provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)), } agentProvider, err := provider.NewProvider(*providerCfg, opts...) if err != nil { @@ -827,7 +827,7 @@ func (a *agent) UpdateModel() error { opts := []provider.ProviderClientOption{ provider.WithModel(a.agentCfg.Model), - provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)), + provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)), } newProvider, err := provider.NewProvider(*currentProviderCfg, opts...) diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go index 835279b4f4c0e08e46aaad271b7cb7f2a59b467f..7f1f58d6f7dcb163a7a9c64bf0fac8f3e63455b3 100644 --- a/internal/llm/prompt/prompt.go +++ b/internal/llm/prompt/prompt.go @@ -5,6 +5,9 @@ import ( "path/filepath" "strings" "sync" + + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/env" ) type PromptID string @@ -21,7 +24,7 @@ func GetPrompt(promptID PromptID, provider string, contextPaths ...string) strin basePrompt := "" switch promptID { case PromptCoder: - basePrompt = CoderPrompt(provider) + basePrompt = CoderPrompt(provider, contextPaths...) case PromptTitle: basePrompt = TitlePrompt() case PromptTask: @@ -38,6 +41,32 @@ func getContextFromPaths(workingDir string, contextPaths []string) string { return processContextPaths(workingDir, contextPaths) } +// expandPath expands ~ and environment variables in file paths +func expandPath(path string) string { + // Handle tilde expansion + if strings.HasPrefix(path, "~/") { + homeDir, err := os.UserHomeDir() + if err == nil { + path = filepath.Join(homeDir, path[2:]) + } + } else if path == "~" { + homeDir, err := os.UserHomeDir() + if err == nil { + path = homeDir + } + } + + // Handle environment variable expansion using the same pattern as config + if strings.HasPrefix(path, "$") { + resolver := config.NewEnvironmentVariableResolver(env.New()) + if expanded, err := resolver.ResolveValue(path); err == nil { + path = expanded + } + } + + return path +} + func processContextPaths(workDir string, paths []string) string { var ( wg sync.WaitGroup @@ -53,8 +82,16 @@ func processContextPaths(workDir string, paths []string) string { go func(p string) { defer wg.Done() + // Expand ~ and environment variables before processing + p = expandPath(p) + if strings.HasSuffix(p, "/") { - filepath.WalkDir(filepath.Join(workDir, p), func(path string, d os.DirEntry, err error) error { + // Use absolute path if provided, otherwise join with workDir + dirPath := p + if !filepath.IsAbs(p) { + dirPath = filepath.Join(workDir, p) + } + filepath.WalkDir(dirPath, func(path string, d os.DirEntry, err error) error { if err != nil { return err } @@ -78,7 +115,12 @@ func processContextPaths(workDir string, paths []string) string { return nil }) } else { - fullPath := filepath.Join(workDir, p) + // Expand ~ and environment variables before processing + // Use absolute path if provided, otherwise join with workDir + fullPath := p + if !filepath.IsAbs(p) { + fullPath = filepath.Join(workDir, p) + } // Check if we've already processed this file (case-insensitive) lowerPath := strings.ToLower(fullPath) diff --git a/internal/llm/prompt/prompt_test.go b/internal/llm/prompt/prompt_test.go new file mode 100644 index 0000000000000000000000000000000000000000..77fe86a827749e0f7f0ef285e100c043b908bdea --- /dev/null +++ b/internal/llm/prompt/prompt_test.go @@ -0,0 +1,108 @@ +package prompt + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestExpandPath(t *testing.T) { + tests := []struct { + name string + input string + expected func() string + }{ + { + name: "regular path unchanged", + input: "/absolute/path", + expected: func() string { + return "/absolute/path" + }, + }, + { + name: "tilde expansion", + input: "~/documents", + expected: func() string { + home, _ := os.UserHomeDir() + return filepath.Join(home, "documents") + }, + }, + { + name: "tilde only", + input: "~", + expected: func() string { + home, _ := os.UserHomeDir() + return home + }, + }, + { + name: "environment variable expansion", + input: "$HOME", + expected: func() string { + return os.Getenv("HOME") + }, + }, + { + name: "relative path unchanged", + input: "relative/path", + expected: func() string { + return "relative/path" + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := expandPath(tt.input) + expected := tt.expected() + + // Skip test if environment variable is not set + if strings.HasPrefix(tt.input, "$") && expected == "" { + t.Skip("Environment variable not set") + } + + if result != expected { + t.Errorf("expandPath(%q) = %q, want %q", tt.input, result, expected) + } + }) + } +} + +func TestProcessContextPaths(t *testing.T) { + // Create a temporary directory and file for testing + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + testContent := "test content" + + err := os.WriteFile(testFile, []byte(testContent), 0o644) + if err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + + // Test with absolute path + result := processContextPaths("", []string{testFile}) + expected := "# From:" + testFile + "\n" + testContent + + if result != expected { + t.Errorf("processContextPaths with absolute path failed.\nGot: %q\nWant: %q", result, expected) + } + + // Test with tilde expansion (if we can create a file in home directory) + home, err := os.UserHomeDir() + if err == nil { + homeTestFile := filepath.Join(home, "crush_test_file.txt") + err = os.WriteFile(homeTestFile, []byte(testContent), 0o644) + if err == nil { + defer os.Remove(homeTestFile) // Clean up + + tildeFile := "~/crush_test_file.txt" + result = processContextPaths("", []string{tildeFile}) + expected = "# From:" + homeTestFile + "\n" + testContent + + if result != expected { + t.Errorf("processContextPaths with tilde expansion failed.\nGot: %q\nWant: %q", result, expected) + } + } + } +} From b7935b4ef8a834da5ec4514e96aeab65a66f537f Mon Sep 17 00:00:00 2001 From: Tai Groot Date: Thu, 10 Jul 2025 18:21:37 -0700 Subject: [PATCH 2/4] Update internal/llm/prompt/prompt_test.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- internal/llm/prompt/prompt_test.go | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/internal/llm/prompt/prompt_test.go b/internal/llm/prompt/prompt_test.go index 77fe86a827749e0f7f0ef285e100c043b908bdea..2087ca149a372209e8cd8c8cdb56aaf8cbc4d68e 100644 --- a/internal/llm/prompt/prompt_test.go +++ b/internal/llm/prompt/prompt_test.go @@ -89,20 +89,19 @@ func TestProcessContextPaths(t *testing.T) { } // Test with tilde expansion (if we can create a file in home directory) - home, err := os.UserHomeDir() + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + homeTestFile := filepath.Join(tmpDir, "crush_test_file.txt") + err := os.WriteFile(homeTestFile, []byte(testContent), 0o644) if err == nil { - homeTestFile := filepath.Join(home, "crush_test_file.txt") - err = os.WriteFile(homeTestFile, []byte(testContent), 0o644) - if err == nil { - defer os.Remove(homeTestFile) // Clean up + defer os.Remove(homeTestFile) // Clean up - tildeFile := "~/crush_test_file.txt" - result = processContextPaths("", []string{tildeFile}) - expected = "# From:" + homeTestFile + "\n" + testContent + tildeFile := "~/crush_test_file.txt" + result = processContextPaths("", []string{tildeFile}) + expected = "# From:" + homeTestFile + "\n" + testContent - if result != expected { - t.Errorf("processContextPaths with tilde expansion failed.\nGot: %q\nWant: %q", result, expected) - } + if result != expected { + t.Errorf("processContextPaths with tilde expansion failed.\nGot: %q\nWant: %q", result, expected) } } } From f6d6ffdc01bd72c1682a5f79cafbd17be89f040e Mon Sep 17 00:00:00 2001 From: Tai Groot Date: Thu, 10 Jul 2025 18:24:18 -0700 Subject: [PATCH 3/4] fixup suggestions from copilot --- internal/llm/prompt/prompt.go | 29 +++++++++++++++-------------- internal/llm/prompt/prompt_test.go | 8 +++++++- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go index 7f1f58d6f7dcb163a7a9c64bf0fac8f3e63455b3..4a2661bb9f663d9f93cf0371ac5d71dd513392c7 100644 --- a/internal/llm/prompt/prompt.go +++ b/internal/llm/prompt/prompt.go @@ -85,13 +85,20 @@ func processContextPaths(workDir string, paths []string) string { // Expand ~ and environment variables before processing p = expandPath(p) - if strings.HasSuffix(p, "/") { - // Use absolute path if provided, otherwise join with workDir - dirPath := p - if !filepath.IsAbs(p) { - dirPath = filepath.Join(workDir, p) - } - filepath.WalkDir(dirPath, func(path string, d os.DirEntry, err error) error { + // Use absolute path if provided, otherwise join with workDir + fullPath := p + if !filepath.IsAbs(p) { + fullPath = filepath.Join(workDir, p) + } + + // Check if the path is a directory using os.Stat + info, err := os.Stat(fullPath) + if err != nil { + return // Skip if path doesn't exist or can't be accessed + } + + if info.IsDir() { + filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error { if err != nil { return err } @@ -115,13 +122,7 @@ func processContextPaths(workDir string, paths []string) string { return nil }) } else { - // Expand ~ and environment variables before processing - // Use absolute path if provided, otherwise join with workDir - fullPath := p - if !filepath.IsAbs(p) { - fullPath = filepath.Join(workDir, p) - } - + // It's a file, process it directly // Check if we've already processed this file (case-insensitive) lowerPath := strings.ToLower(fullPath) diff --git a/internal/llm/prompt/prompt_test.go b/internal/llm/prompt/prompt_test.go index 2087ca149a372209e8cd8c8cdb56aaf8cbc4d68e..3f87c435a18daf251bbe6745ff73085b195cf718 100644 --- a/internal/llm/prompt/prompt_test.go +++ b/internal/llm/prompt/prompt_test.go @@ -80,7 +80,7 @@ func TestProcessContextPaths(t *testing.T) { t.Fatalf("Failed to create test file: %v", err) } - // Test with absolute path + // Test with absolute path to file result := processContextPaths("", []string{testFile}) expected := "# From:" + testFile + "\n" + testContent @@ -88,6 +88,12 @@ func TestProcessContextPaths(t *testing.T) { t.Errorf("processContextPaths with absolute path failed.\nGot: %q\nWant: %q", result, expected) } + // Test with directory path (should process all files in directory) + result = processContextPaths("", []string{tmpDir}) + if !strings.Contains(result, testContent) { + t.Errorf("processContextPaths with directory path failed to include file content") + } + // Test with tilde expansion (if we can create a file in home directory) tmpDir := t.TempDir() t.Setenv("HOME", tmpDir) From b6a97d222f29c2a383ee02e63636cbe7ec3f13a9 Mon Sep 17 00:00:00 2001 From: Tai Groot Date: Thu, 10 Jul 2025 18:28:40 -0700 Subject: [PATCH 4/4] fixup test --- internal/llm/prompt/prompt_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/llm/prompt/prompt_test.go b/internal/llm/prompt/prompt_test.go index 3f87c435a18daf251bbe6745ff73085b195cf718..ce7fa0fb35cfdf021b886a96a828202001588a7f 100644 --- a/internal/llm/prompt/prompt_test.go +++ b/internal/llm/prompt/prompt_test.go @@ -95,10 +95,10 @@ func TestProcessContextPaths(t *testing.T) { } // Test with tilde expansion (if we can create a file in home directory) - tmpDir := t.TempDir() + tmpDir = t.TempDir() t.Setenv("HOME", tmpDir) homeTestFile := filepath.Join(tmpDir, "crush_test_file.txt") - err := os.WriteFile(homeTestFile, []byte(testContent), 0o644) + err = os.WriteFile(homeTestFile, []byte(testContent), 0o644) if err == nil { defer os.Remove(homeTestFile) // Clean up