1package prompt
2
3import (
4 "fmt"
5 "os"
6 "path/filepath"
7 "testing"
8
9 "github.com/charmbracelet/crush/internal/config"
10 "github.com/stretchr/testify/assert"
11 "github.com/stretchr/testify/require"
12)
13
14func TestGetContextFromPaths(t *testing.T) {
15 t.Parallel()
16
17 tmpDir := t.TempDir()
18 _, err := config.Init(tmpDir, false)
19 if err != nil {
20 t.Fatalf("Failed to load config: %v", err)
21 }
22 testFiles := []string{
23 "file.txt",
24 "directory/file_a.txt",
25 "directory/file_b.txt",
26 "directory/file_c.txt",
27 }
28
29 createTestFiles(t, tmpDir, testFiles)
30
31 context := getContextFromPaths(
32 []string{
33 "file.txt",
34 "directory/",
35 },
36 )
37 expectedContext := fmt.Sprintf("# From:%s/file.txt\nfile.txt: test content\n# From:%s/directory/file_a.txt\ndirectory/file_a.txt: test content\n# From:%s/directory/file_b.txt\ndirectory/file_b.txt: test content\n# From:%s/directory/file_c.txt\ndirectory/file_c.txt: test content", tmpDir, tmpDir, tmpDir, tmpDir)
38 assert.Equal(t, expectedContext, context)
39}
40
41func createTestFiles(t *testing.T, tmpDir string, testFiles []string) {
42 t.Helper()
43 for _, path := range testFiles {
44 fullPath := filepath.Join(tmpDir, path)
45 if path[len(path)-1] == '/' {
46 err := os.MkdirAll(fullPath, 0o755)
47 require.NoError(t, err)
48 } else {
49 dir := filepath.Dir(fullPath)
50 err := os.MkdirAll(dir, 0o755)
51 require.NoError(t, err)
52 err = os.WriteFile(fullPath, []byte(path+": test content"), 0o644)
53 require.NoError(t, err)
54 }
55 }
56}