1package tools
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "os"
8 "path/filepath"
9 "strings"
10 "testing"
11 "time"
12
13 "charm.land/fantasy"
14 "github.com/charmbracelet/crush/internal/filetracker"
15 "github.com/charmbracelet/crush/internal/permission"
16 "github.com/charmbracelet/crush/internal/pubsub"
17 "github.com/stretchr/testify/require"
18)
19
20func TestReadTextFileBoundaryCases(t *testing.T) {
21 t.Parallel()
22
23 tmpDir := t.TempDir()
24 filePath := filepath.Join(tmpDir, "sample.txt")
25
26 var allLines []string
27 for i := range 5 {
28 allLines = append(allLines, fmt.Sprintf("line %d", i+1))
29 }
30 require.NoError(t, os.WriteFile(filePath, []byte(strings.Join(allLines, "\n")), 0o644))
31
32 tests := []struct {
33 name string
34 offset int
35 limit int
36 wantContent string
37 wantHasMore bool
38 }{
39 {
40 name: "exactly limit lines remaining",
41 offset: 0,
42 limit: 5,
43 wantContent: "line 1\nline 2\nline 3\nline 4\nline 5",
44 wantHasMore: false,
45 },
46 {
47 name: "limit plus one line remaining",
48 offset: 0,
49 limit: 4,
50 wantContent: "line 1\nline 2\nline 3\nline 4",
51 wantHasMore: true,
52 },
53 {
54 name: "offset at last line",
55 offset: 4,
56 limit: 3,
57 wantContent: "line 5",
58 wantHasMore: false,
59 },
60 {
61 name: "offset beyond eof",
62 offset: 10,
63 limit: 3,
64 wantContent: "",
65 wantHasMore: false,
66 },
67 }
68
69 for _, tt := range tests {
70 t.Run(tt.name, func(t *testing.T) {
71 t.Parallel()
72
73 gotContent, gotHasMore, err := readTextFile(filePath, tt.offset, tt.limit, 0)
74 require.NoError(t, err)
75 require.Equal(t, tt.wantContent, gotContent)
76 require.Equal(t, tt.wantHasMore, gotHasMore)
77 })
78 }
79}
80
81func TestReadTextFileTruncatesLongLines(t *testing.T) {
82 t.Parallel()
83
84 tmpDir := t.TempDir()
85 filePath := filepath.Join(tmpDir, "longline.txt")
86
87 longLine := strings.Repeat("a", MaxLineLength+10)
88 require.NoError(t, os.WriteFile(filePath, []byte(longLine), 0o644))
89
90 content, hasMore, err := readTextFile(filePath, 0, 1, 0)
91 require.NoError(t, err)
92 require.False(t, hasMore)
93 require.Equal(t, strings.Repeat("a", MaxLineLength)+"...", content)
94}
95
96func TestViewToolAllowsSmallSectionsOfLargeFiles(t *testing.T) {
97 t.Parallel()
98
99 workingDir := t.TempDir()
100 filePath := filepath.Join(workingDir, "large.txt")
101 lines := []string{strings.Repeat("a", MaxViewSize+1), "target line", "after target"}
102 require.NoError(t, os.WriteFile(filePath, []byte(strings.Join(lines, "\n")), 0o644))
103
104 tool := newViewToolForTest(workingDir)
105 ctx := context.WithValue(context.Background(), SessionIDContextKey, "test-session")
106 resp := runViewTool(t, tool, ctx, ViewParams{
107 FilePath: filePath,
108 Offset: 1,
109 Limit: 1,
110 })
111
112 require.False(t, resp.IsError)
113 require.Contains(t, resp.Content, " 2|target line")
114 require.NotContains(t, resp.Content, "File is too large")
115
116 var meta ViewResponseMetadata
117 require.NoError(t, json.Unmarshal([]byte(resp.Metadata), &meta))
118 require.Equal(t, "target line", meta.Content)
119}
120
121func TestViewToolBlocksOversizedReturnedSections(t *testing.T) {
122 t.Parallel()
123
124 workingDir := t.TempDir()
125 filePath := filepath.Join(workingDir, "large-section.txt")
126 lines := make([]string, DefaultReadLimit)
127 for i := range lines {
128 lines[i] = strings.Repeat("a", MaxLineLength)
129 }
130 require.NoError(t, os.WriteFile(filePath, []byte(strings.Join(lines, "\n")), 0o644))
131
132 tool := newViewToolForTest(workingDir)
133 ctx := context.WithValue(context.Background(), SessionIDContextKey, "test-session")
134 resp := runViewTool(t, tool, ctx, ViewParams{
135 FilePath: filePath,
136 })
137
138 require.True(t, resp.IsError)
139 require.Contains(t, resp.Content, "Content section is too large")
140}
141
142func TestViewToolBlocksOversizedImages(t *testing.T) {
143 t.Parallel()
144
145 workingDir := t.TempDir()
146 filePath := filepath.Join(workingDir, "large.png")
147 require.NoError(t, os.WriteFile(filePath, []byte(strings.Repeat("a", MaxViewSize+1)), 0o644))
148
149 tool := newViewToolForTest(workingDir)
150 ctx := context.WithValue(context.Background(), SessionIDContextKey, "test-session")
151 ctx = context.WithValue(ctx, SupportsImagesContextKey, true)
152 resp := runViewTool(t, tool, ctx, ViewParams{
153 FilePath: filePath,
154 })
155
156 require.True(t, resp.IsError)
157 require.Contains(t, resp.Content, "Image file is too large")
158}
159
160func TestReadTextFileEnforcesMaxContentSize(t *testing.T) {
161 t.Parallel()
162
163 workingDir := t.TempDir()
164 filePath := filepath.Join(workingDir, "oversized.txt")
165 lines := []string{
166 strings.Repeat("a", MaxLineLength),
167 strings.Repeat("b", MaxLineLength),
168 "target line",
169 }
170 require.NoError(t, os.WriteFile(filePath, []byte(strings.Join(lines, "\n")), 0o644))
171
172 content, hasMore, err := readTextFile(filePath, 0, len(lines), MaxLineLength)
173 require.ErrorAs(t, err, &contentTooLargeError{})
174 require.Empty(t, content)
175 require.False(t, hasMore)
176
177 content, hasMore, err = readTextFile(filePath, 2, 1, MaxLineLength)
178 require.NoError(t, err)
179 require.Equal(t, "target line", content)
180 require.False(t, hasMore)
181}
182
183func TestReadTextFileAllowsExactMaxContentSize(t *testing.T) {
184 t.Parallel()
185
186 workingDir := t.TempDir()
187 filePath := filepath.Join(workingDir, "exact-size.txt")
188 require.NoError(t, os.WriteFile(filePath, []byte("abcd\nefgh"), 0o644))
189
190 content, hasMore, err := readTextFile(filePath, 0, 2, len("abcd\nefgh"))
191 require.NoError(t, err)
192 require.Equal(t, "abcd\nefgh", content)
193 require.False(t, hasMore)
194}
195
196type mockViewPermissionService struct {
197 *pubsub.Broker[permission.PermissionRequest]
198}
199
200func (m *mockViewPermissionService) Request(ctx context.Context, req permission.CreatePermissionRequest) (bool, error) {
201 return true, nil
202}
203
204func (m *mockViewPermissionService) Grant(req permission.PermissionRequest) {}
205
206func (m *mockViewPermissionService) Deny(req permission.PermissionRequest) {}
207
208func (m *mockViewPermissionService) GrantPersistent(req permission.PermissionRequest) {}
209
210func (m *mockViewPermissionService) AutoApproveSession(sessionID string) {}
211
212func (m *mockViewPermissionService) SetSkipRequests(skip bool) {}
213
214func (m *mockViewPermissionService) SkipRequests() bool {
215 return false
216}
217
218func (m *mockViewPermissionService) SubscribeNotifications(ctx context.Context) <-chan pubsub.Event[permission.PermissionNotification] {
219 return make(<-chan pubsub.Event[permission.PermissionNotification])
220}
221
// mockFileTracker is a no-op file tracker. It records nothing, reports the
// zero time for every path, and lists no read files.
type mockFileTracker struct{}

// RecordRead discards the read event.
func (mockFileTracker) RecordRead(context.Context, string, string) {}

// LastReadTime always returns the zero time.Time.
func (mockFileTracker) LastReadTime(context.Context, string, string) time.Time {
	var never time.Time
	return never
}

// ListReadFiles always returns a nil slice and a nil error.
func (mockFileTracker) ListReadFiles(context.Context, string) ([]string, error) {
	return nil, nil
}
233
234func newViewToolForTest(workingDir string) fantasy.AgentTool {
235 permissions := &mockViewPermissionService{Broker: pubsub.NewBroker[permission.PermissionRequest]()}
236 return NewViewTool(nil, permissions, mockFileTracker{}, nil, workingDir)
237}
238
239func runViewTool(t *testing.T, tool fantasy.AgentTool, ctx context.Context, params ViewParams) fantasy.ToolResponse {
240 t.Helper()
241
242 input, err := json.Marshal(params)
243 require.NoError(t, err)
244
245 call := fantasy.ToolCall{
246 ID: "test-call",
247 Name: ViewToolName,
248 Input: string(input),
249 }
250
251 resp, err := tool.Run(ctx, call)
252 require.NoError(t, err)
253 return resp
254}
255
// Compile-time check that the value type mockFileTracker satisfies
// filetracker.Service.
var _ filetracker.Service = mockFileTracker{}
257
258func TestReadBuiltinFile(t *testing.T) {
259 t.Parallel()
260
261 t.Run("reads crush-config skill", func(t *testing.T) {
262 t.Parallel()
263
264 resp, err := readBuiltinFile(ViewParams{
265 FilePath: "crush://skills/crush-config/SKILL.md",
266 }, nil)
267 require.NoError(t, err)
268 require.NotEmpty(t, resp.Content)
269 require.Contains(t, resp.Content, "Crush Configuration")
270 })
271
272 t.Run("not found", func(t *testing.T) {
273 t.Parallel()
274
275 resp, err := readBuiltinFile(ViewParams{
276 FilePath: "crush://skills/nonexistent/SKILL.md",
277 }, nil)
278 require.NoError(t, err)
279 require.True(t, resp.IsError)
280 })
281
282 t.Run("metadata has skill info", func(t *testing.T) {
283 t.Parallel()
284
285 resp, err := readBuiltinFile(ViewParams{
286 FilePath: "crush://skills/crush-config/SKILL.md",
287 }, nil)
288 require.NoError(t, err)
289
290 var meta ViewResponseMetadata
291 require.NoError(t, json.Unmarshal([]byte(resp.Metadata), &meta))
292 require.Equal(t, ViewResourceSkill, meta.ResourceType)
293 require.Equal(t, "crush-config", meta.ResourceName)
294 require.NotEmpty(t, meta.ResourceDescription)
295 })
296
297 t.Run("respects offset", func(t *testing.T) {
298 t.Parallel()
299
300 resp, err := readBuiltinFile(ViewParams{
301 FilePath: "crush://skills/crush-config/SKILL.md",
302 Offset: 5,
303 }, nil)
304 require.NoError(t, err)
305 require.NotContains(t, resp.Content, " 1|")
306 })
307}
308
309func TestSniffImageMimeType(t *testing.T) {
310 t.Parallel()
311
312 jpegMagic := []byte{0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10, 'J', 'F', 'I', 'F'}
313 pngMagic := []byte{0x89, 'P', 'N', 'G', 0x0d, 0x0a, 0x1a, 0x0a}
314 gifMagic := []byte("GIF89a")
315 // Minimal RIFF/WEBP header.
316 webpMagic := append([]byte("RIFF\x00\x00\x00\x00WEBPVP8 "), make([]byte, 16)...)
317 random := []byte("not an image at all, just text")
318
319 cases := []struct {
320 name string
321 data []byte
322 fallback string
323 want string
324 }{
325 {"jpeg bytes in .png file uses sniffed", jpegMagic, "image/png", "image/jpeg"},
326 {"png bytes in .jpg file uses sniffed", pngMagic, "image/jpeg", "image/png"},
327 {"gif bytes uses sniffed", gifMagic, "image/png", "image/gif"},
328 {"webp bytes uses sniffed", webpMagic, "image/png", "image/webp"},
329 {"matching extension and content keeps sniffed", pngMagic, "image/png", "image/png"},
330 {"unsniffable content falls back", random, "image/png", "image/png"},
331 {"empty content falls back", nil, "image/jpeg", "image/jpeg"},
332 }
333 for _, tc := range cases {
334 t.Run(tc.name, func(t *testing.T) {
335 t.Parallel()
336 require.Equal(t, tc.want, sniffImageMimeType(tc.data, tc.fallback))
337 })
338 }
339}