diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 313b83c0448d8a668e2390368c6797c82dd22452..fbb5b4fd8c6390ff0dfad0e072af35342355ba41 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -149,7 +149,7 @@ func NewAgent( } opts := []provider.ProviderClientOption{ provider.WithModel(agentCfg.Model), - provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)), + provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)), } agentProvider, err := provider.NewProvider(*providerCfg, opts...) if err != nil { @@ -827,7 +827,7 @@ func (a *agent) UpdateModel() error { opts := []provider.ProviderClientOption{ provider.WithModel(a.agentCfg.Model), - provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)), + provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)), } newProvider, err := provider.NewProvider(*currentProviderCfg, opts...) diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go index 835279b4f4c0e08e46aaad271b7cb7f2a59b467f..4a2661bb9f663d9f93cf0371ac5d71dd513392c7 100644 --- a/internal/llm/prompt/prompt.go +++ b/internal/llm/prompt/prompt.go @@ -5,6 +5,9 @@ import ( "path/filepath" "strings" "sync" + + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/env" ) type PromptID string @@ -21,7 +24,7 @@ func GetPrompt(promptID PromptID, provider string, contextPaths ...string) strin basePrompt := "" switch promptID { case PromptCoder: - basePrompt = CoderPrompt(provider) + basePrompt = CoderPrompt(provider, contextPaths...) case PromptTitle: basePrompt = TitlePrompt() case PromptTask: @@ -38,6 +41,32 @@ func getContextFromPaths(workingDir string, contextPaths []string) string { return processContextPaths(workingDir, contextPaths) } +// expandPath expands ~ and environment variables in file paths +func expandPath(path string) string { + // Handle tilde expansion + if strings.HasPrefix(path, "~/") { + homeDir, err := os.UserHomeDir() + if err == nil { + path = filepath.Join(homeDir, path[2:]) + } + } else if path == "~" { + homeDir, err := os.UserHomeDir() + if err == nil { + path = homeDir + } + } + + // Handle environment variable expansion using the same pattern as config + if strings.HasPrefix(path, "$") { + resolver := config.NewEnvironmentVariableResolver(env.New()) + if expanded, err := resolver.ResolveValue(path); err == nil { + path = expanded + } + } + + return path +} + func processContextPaths(workDir string, paths []string) string { var ( wg sync.WaitGroup @@ -53,8 +82,23 @@ func processContextPaths(workDir string, paths []string) string { go func(p string) { defer wg.Done() - if strings.HasSuffix(p, "/") { - filepath.WalkDir(filepath.Join(workDir, p), func(path string, d os.DirEntry, err error) error { + // Expand ~ and environment variables before processing + p = expandPath(p) + + // Use absolute path if provided, otherwise join with workDir + fullPath := p + if !filepath.IsAbs(p) { + fullPath = filepath.Join(workDir, p) + } + + // Check if the path is a directory using os.Stat + info, err := os.Stat(fullPath) + if err != nil { + return // Skip if path doesn't exist or can't be accessed + } + + if info.IsDir() { + filepath.WalkDir(fullPath, func(path string, d os.DirEntry, err error) error { if err != nil { return err } @@ -78,8 +122,7 @@ func processContextPaths(workDir string, paths []string) string { return nil }) } else { - fullPath := filepath.Join(workDir, p) - + // It's a file, process it directly // Check if we've already processed this file (case-insensitive) lowerPath := strings.ToLower(fullPath) diff --git a/internal/llm/prompt/prompt_test.go b/internal/llm/prompt/prompt_test.go new file mode 100644 index 0000000000000000000000000000000000000000..ce7fa0fb35cfdf021b886a96a828202001588a7f --- /dev/null +++ b/internal/llm/prompt/prompt_test.go @@ -0,0 +1,113 @@ +package prompt + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestExpandPath(t *testing.T) { + tests := []struct { + name string + input string + expected func() string + }{ + { + name: "regular path unchanged", + input: "/absolute/path", + expected: func() string { + return "/absolute/path" + }, + }, + { + name: "tilde expansion", + input: "~/documents", + expected: func() string { + home, _ := os.UserHomeDir() + return filepath.Join(home, "documents") + }, + }, + { + name: "tilde only", + input: "~", + expected: func() string { + home, _ := os.UserHomeDir() + return home + }, + }, + { + name: "environment variable expansion", + input: "$HOME", + expected: func() string { + return os.Getenv("HOME") + }, + }, + { + name: "relative path unchanged", + input: "relative/path", + expected: func() string { + return "relative/path" + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := expandPath(tt.input) + expected := tt.expected() + + // Skip test if environment variable is not set + if strings.HasPrefix(tt.input, "$") && expected == "" { + t.Skip("Environment variable not set") + } + + if result != expected { + t.Errorf("expandPath(%q) = %q, want %q", tt.input, result, expected) + } + }) + } +} + +func TestProcessContextPaths(t *testing.T) { + // Create a temporary directory and file for testing + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + testContent := "test content" + + err := os.WriteFile(testFile, []byte(testContent), 0o644) + if err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + + // Test with absolute path to file + result := processContextPaths("", []string{testFile}) + expected := "# From:" + testFile + "\n" + testContent + + if result != expected { + t.Errorf("processContextPaths with absolute path failed.\nGot: %q\nWant: %q", result, expected) + } + + // Test with directory path (should process all files in directory) + result = processContextPaths("", []string{tmpDir}) + if !strings.Contains(result, testContent) { + t.Errorf("processContextPaths with directory path failed to include file content") + } + + // Test with tilde expansion (if we can create a file in home directory) + tmpDir = t.TempDir() + t.Setenv("HOME", tmpDir) + homeTestFile := filepath.Join(tmpDir, "crush_test_file.txt") + err = os.WriteFile(homeTestFile, []byte(testContent), 0o644) + if err == nil { + defer os.Remove(homeTestFile) // Clean up + + tildeFile := "~/crush_test_file.txt" + result = processContextPaths("", []string{tildeFile}) + expected = "# From:" + homeTestFile + "\n" + testContent + + if result != expected { + t.Errorf("processContextPaths with tilde expansion failed.\nGot: %q\nWant: %q", result, expected) + } + } +}