view_test.go

  1package tools
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"os"
  8	"path/filepath"
  9	"strings"
 10	"testing"
 11	"time"
 12
 13	"charm.land/fantasy"
 14	"github.com/charmbracelet/crush/internal/filetracker"
 15	"github.com/charmbracelet/crush/internal/permission"
 16	"github.com/charmbracelet/crush/internal/pubsub"
 17	"github.com/stretchr/testify/require"
 18)
 19
 20func TestReadTextFileBoundaryCases(t *testing.T) {
 21	t.Parallel()
 22
 23	tmpDir := t.TempDir()
 24	filePath := filepath.Join(tmpDir, "sample.txt")
 25
 26	var allLines []string
 27	for i := range 5 {
 28		allLines = append(allLines, fmt.Sprintf("line %d", i+1))
 29	}
 30	require.NoError(t, os.WriteFile(filePath, []byte(strings.Join(allLines, "\n")), 0o644))
 31
 32	tests := []struct {
 33		name        string
 34		offset      int
 35		limit       int
 36		wantContent string
 37		wantHasMore bool
 38	}{
 39		{
 40			name:        "exactly limit lines remaining",
 41			offset:      0,
 42			limit:       5,
 43			wantContent: "line 1\nline 2\nline 3\nline 4\nline 5",
 44			wantHasMore: false,
 45		},
 46		{
 47			name:        "limit plus one line remaining",
 48			offset:      0,
 49			limit:       4,
 50			wantContent: "line 1\nline 2\nline 3\nline 4",
 51			wantHasMore: true,
 52		},
 53		{
 54			name:        "offset at last line",
 55			offset:      4,
 56			limit:       3,
 57			wantContent: "line 5",
 58			wantHasMore: false,
 59		},
 60		{
 61			name:        "offset beyond eof",
 62			offset:      10,
 63			limit:       3,
 64			wantContent: "",
 65			wantHasMore: false,
 66		},
 67	}
 68
 69	for _, tt := range tests {
 70		t.Run(tt.name, func(t *testing.T) {
 71			t.Parallel()
 72
 73			gotContent, gotHasMore, err := readTextFile(filePath, tt.offset, tt.limit, 0)
 74			require.NoError(t, err)
 75			require.Equal(t, tt.wantContent, gotContent)
 76			require.Equal(t, tt.wantHasMore, gotHasMore)
 77		})
 78	}
 79}
 80
 81func TestReadTextFileTruncatesLongLines(t *testing.T) {
 82	t.Parallel()
 83
 84	tmpDir := t.TempDir()
 85	filePath := filepath.Join(tmpDir, "longline.txt")
 86
 87	longLine := strings.Repeat("a", MaxLineLength+10)
 88	require.NoError(t, os.WriteFile(filePath, []byte(longLine), 0o644))
 89
 90	content, hasMore, err := readTextFile(filePath, 0, 1, 0)
 91	require.NoError(t, err)
 92	require.False(t, hasMore)
 93	require.Equal(t, strings.Repeat("a", MaxLineLength)+"...", content)
 94}
 95
 96func TestViewToolAllowsSmallSectionsOfLargeFiles(t *testing.T) {
 97	t.Parallel()
 98
 99	workingDir := t.TempDir()
100	filePath := filepath.Join(workingDir, "large.txt")
101	lines := []string{strings.Repeat("a", MaxViewSize+1), "target line", "after target"}
102	require.NoError(t, os.WriteFile(filePath, []byte(strings.Join(lines, "\n")), 0o644))
103
104	tool := newViewToolForTest(workingDir)
105	ctx := context.WithValue(context.Background(), SessionIDContextKey, "test-session")
106	resp := runViewTool(t, tool, ctx, ViewParams{
107		FilePath: filePath,
108		Offset:   1,
109		Limit:    1,
110	})
111
112	require.False(t, resp.IsError)
113	require.Contains(t, resp.Content, "     2|target line")
114	require.NotContains(t, resp.Content, "File is too large")
115
116	var meta ViewResponseMetadata
117	require.NoError(t, json.Unmarshal([]byte(resp.Metadata), &meta))
118	require.Equal(t, "target line", meta.Content)
119}
120
121func TestViewToolBlocksOversizedReturnedSections(t *testing.T) {
122	t.Parallel()
123
124	workingDir := t.TempDir()
125	filePath := filepath.Join(workingDir, "large-section.txt")
126	lines := make([]string, DefaultReadLimit)
127	for i := range lines {
128		lines[i] = strings.Repeat("a", MaxLineLength)
129	}
130	require.NoError(t, os.WriteFile(filePath, []byte(strings.Join(lines, "\n")), 0o644))
131
132	tool := newViewToolForTest(workingDir)
133	ctx := context.WithValue(context.Background(), SessionIDContextKey, "test-session")
134	resp := runViewTool(t, tool, ctx, ViewParams{
135		FilePath: filePath,
136	})
137
138	require.True(t, resp.IsError)
139	require.Contains(t, resp.Content, "Content section is too large")
140}
141
142func TestViewToolBlocksOversizedImages(t *testing.T) {
143	t.Parallel()
144
145	workingDir := t.TempDir()
146	filePath := filepath.Join(workingDir, "large.png")
147	require.NoError(t, os.WriteFile(filePath, []byte(strings.Repeat("a", MaxViewSize+1)), 0o644))
148
149	tool := newViewToolForTest(workingDir)
150	ctx := context.WithValue(context.Background(), SessionIDContextKey, "test-session")
151	ctx = context.WithValue(ctx, SupportsImagesContextKey, true)
152	resp := runViewTool(t, tool, ctx, ViewParams{
153		FilePath: filePath,
154	})
155
156	require.True(t, resp.IsError)
157	require.Contains(t, resp.Content, "Image file is too large")
158}
159
160func TestReadTextFileEnforcesMaxContentSize(t *testing.T) {
161	t.Parallel()
162
163	workingDir := t.TempDir()
164	filePath := filepath.Join(workingDir, "oversized.txt")
165	lines := []string{
166		strings.Repeat("a", MaxLineLength),
167		strings.Repeat("b", MaxLineLength),
168		"target line",
169	}
170	require.NoError(t, os.WriteFile(filePath, []byte(strings.Join(lines, "\n")), 0o644))
171
172	content, hasMore, err := readTextFile(filePath, 0, len(lines), MaxLineLength)
173	require.ErrorAs(t, err, &contentTooLargeError{})
174	require.Empty(t, content)
175	require.False(t, hasMore)
176
177	content, hasMore, err = readTextFile(filePath, 2, 1, MaxLineLength)
178	require.NoError(t, err)
179	require.Equal(t, "target line", content)
180	require.False(t, hasMore)
181}
182
183func TestReadTextFileAllowsExactMaxContentSize(t *testing.T) {
184	t.Parallel()
185
186	workingDir := t.TempDir()
187	filePath := filepath.Join(workingDir, "exact-size.txt")
188	require.NoError(t, os.WriteFile(filePath, []byte("abcd\nefgh"), 0o644))
189
190	content, hasMore, err := readTextFile(filePath, 0, 2, len("abcd\nefgh"))
191	require.NoError(t, err)
192	require.Equal(t, "abcd\nefgh", content)
193	require.False(t, hasMore)
194}
195
196type mockViewPermissionService struct {
197	*pubsub.Broker[permission.PermissionRequest]
198}
199
200func (m *mockViewPermissionService) Request(ctx context.Context, req permission.CreatePermissionRequest) (bool, error) {
201	return true, nil
202}
203
204func (m *mockViewPermissionService) Grant(req permission.PermissionRequest) {}
205
206func (m *mockViewPermissionService) Deny(req permission.PermissionRequest) {}
207
208func (m *mockViewPermissionService) GrantPersistent(req permission.PermissionRequest) {}
209
210func (m *mockViewPermissionService) AutoApproveSession(sessionID string) {}
211
212func (m *mockViewPermissionService) SetSkipRequests(skip bool) {}
213
214func (m *mockViewPermissionService) SkipRequests() bool {
215	return false
216}
217
218func (m *mockViewPermissionService) SubscribeNotifications(ctx context.Context) <-chan pubsub.Event[permission.PermissionNotification] {
219	return make(<-chan pubsub.Event[permission.PermissionNotification])
220}
221
// mockFileTracker is a stateless file tracker stub. It records nothing,
// reports the zero time for every path, and lists no read files.
type mockFileTracker struct{}

// RecordRead discards the read event.
func (mockFileTracker) RecordRead(ctx context.Context, sessionID, path string) {}

// LastReadTime always returns the zero time.Time.
func (mockFileTracker) LastReadTime(ctx context.Context, sessionID, path string) time.Time {
	var never time.Time
	return never
}

// ListReadFiles always returns a nil slice and a nil error.
func (mockFileTracker) ListReadFiles(ctx context.Context, sessionID string) ([]string, error) {
	return nil, nil
}
233
234func newViewToolForTest(workingDir string) fantasy.AgentTool {
235	permissions := &mockViewPermissionService{Broker: pubsub.NewBroker[permission.PermissionRequest]()}
236	return NewViewTool(nil, permissions, mockFileTracker{}, nil, workingDir)
237}
238
239func runViewTool(t *testing.T, tool fantasy.AgentTool, ctx context.Context, params ViewParams) fantasy.ToolResponse {
240	t.Helper()
241
242	input, err := json.Marshal(params)
243	require.NoError(t, err)
244
245	call := fantasy.ToolCall{
246		ID:    "test-call",
247		Name:  ViewToolName,
248		Input: string(input),
249	}
250
251	resp, err := tool.Run(ctx, call)
252	require.NoError(t, err)
253	return resp
254}
255
256var _ filetracker.Service = mockFileTracker{}
257
258func TestReadBuiltinFile(t *testing.T) {
259	t.Parallel()
260
261	t.Run("reads crush-config skill", func(t *testing.T) {
262		t.Parallel()
263
264		resp, err := readBuiltinFile(ViewParams{
265			FilePath: "crush://skills/crush-config/SKILL.md",
266		}, nil)
267		require.NoError(t, err)
268		require.NotEmpty(t, resp.Content)
269		require.Contains(t, resp.Content, "Crush Configuration")
270	})
271
272	t.Run("not found", func(t *testing.T) {
273		t.Parallel()
274
275		resp, err := readBuiltinFile(ViewParams{
276			FilePath: "crush://skills/nonexistent/SKILL.md",
277		}, nil)
278		require.NoError(t, err)
279		require.True(t, resp.IsError)
280	})
281
282	t.Run("metadata has skill info", func(t *testing.T) {
283		t.Parallel()
284
285		resp, err := readBuiltinFile(ViewParams{
286			FilePath: "crush://skills/crush-config/SKILL.md",
287		}, nil)
288		require.NoError(t, err)
289
290		var meta ViewResponseMetadata
291		require.NoError(t, json.Unmarshal([]byte(resp.Metadata), &meta))
292		require.Equal(t, ViewResourceSkill, meta.ResourceType)
293		require.Equal(t, "crush-config", meta.ResourceName)
294		require.NotEmpty(t, meta.ResourceDescription)
295	})
296
297	t.Run("respects offset", func(t *testing.T) {
298		t.Parallel()
299
300		resp, err := readBuiltinFile(ViewParams{
301			FilePath: "crush://skills/crush-config/SKILL.md",
302			Offset:   5,
303		}, nil)
304		require.NoError(t, err)
305		require.NotContains(t, resp.Content, "     1|")
306	})
307}
308
309func TestSniffImageMimeType(t *testing.T) {
310	t.Parallel()
311
312	jpegMagic := []byte{0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10, 'J', 'F', 'I', 'F'}
313	pngMagic := []byte{0x89, 'P', 'N', 'G', 0x0d, 0x0a, 0x1a, 0x0a}
314	gifMagic := []byte("GIF89a")
315	// Minimal RIFF/WEBP header.
316	webpMagic := append([]byte("RIFF\x00\x00\x00\x00WEBPVP8 "), make([]byte, 16)...)
317	random := []byte("not an image at all, just text")
318
319	cases := []struct {
320		name     string
321		data     []byte
322		fallback string
323		want     string
324	}{
325		{"jpeg bytes in .png file uses sniffed", jpegMagic, "image/png", "image/jpeg"},
326		{"png bytes in .jpg file uses sniffed", pngMagic, "image/jpeg", "image/png"},
327		{"gif bytes uses sniffed", gifMagic, "image/png", "image/gif"},
328		{"webp bytes uses sniffed", webpMagic, "image/png", "image/webp"},
329		{"matching extension and content keeps sniffed", pngMagic, "image/png", "image/png"},
330		{"unsniffable content falls back", random, "image/png", "image/png"},
331		{"empty content falls back", nil, "image/jpeg", "image/jpeg"},
332	}
333	for _, tc := range cases {
334		t.Run(tc.name, func(t *testing.T) {
335			t.Parallel()
336			require.Equal(t, tc.want, sniffImageMimeType(tc.data, tc.fallback))
337		})
338	}
339}