fetch_test.go

  1package tools
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"net/http"
  7	"net/http/httptest"
  8	"strings"
  9	"testing"
 10
 11	"charm.land/fantasy"
 12	"github.com/charmbracelet/crush/internal/permission"
 13	"github.com/charmbracelet/crush/internal/pubsub"
 14	"github.com/stretchr/testify/require"
 15)
 16
// mockFetchPermissionService is the permission stub that
// newFetchToolForTest hands to NewFetchTool. Request always approves; the
// remaining methods do nothing or return fixed values.
type mockFetchPermissionService struct {
	// Embedded broker. Presumably this supplies the publish/subscribe half
	// of the permission service contract — confirm against the permission
	// package.
	*pubsub.Broker[permission.PermissionRequest]
}
 20
 21func (m *mockFetchPermissionService) Request(ctx context.Context, req permission.CreatePermissionRequest) (bool, error) {
 22	return true, nil
 23}
 24
 25func (m *mockFetchPermissionService) Grant(req permission.PermissionRequest)           {}
 26func (m *mockFetchPermissionService) Deny(req permission.PermissionRequest)            {}
 27func (m *mockFetchPermissionService) GrantPersistent(req permission.PermissionRequest) {}
 28func (m *mockFetchPermissionService) AutoApproveSession(sessionID string)              {}
 29func (m *mockFetchPermissionService) SetSkipRequests(skip bool)                        {}
 30func (m *mockFetchPermissionService) SkipRequests() bool                               { return false }
 31func (m *mockFetchPermissionService) SubscribeNotifications(ctx context.Context) <-chan pubsub.Event[permission.PermissionNotification] {
 32	return make(<-chan pubsub.Event[permission.PermissionNotification])
 33}
 34
 35func newFetchToolForTest() fantasy.AgentTool {
 36	permissions := &mockFetchPermissionService{Broker: pubsub.NewBroker[permission.PermissionRequest]()}
 37	return NewFetchTool(permissions, "/tmp", http.DefaultClient)
 38}
 39
 40func runFetchTool(t *testing.T, tool fantasy.AgentTool, params FetchParams) fantasy.ToolResponse {
 41	t.Helper()
 42	input, err := json.Marshal(params)
 43	require.NoError(t, err)
 44	ctx := context.WithValue(context.Background(), SessionIDContextKey, "test-session")
 45	resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "t", Name: FetchToolName, Input: string(input)})
 46	require.NoError(t, err)
 47	return resp
 48}
 49
 50func TestApplyJQ(t *testing.T) {
 51	t.Parallel()
 52
 53	tests := []struct {
 54		name string
 55		body string
 56		expr string
 57		want string
 58	}{
 59		{
 60			name: "length of array",
 61			body: `[1,2,3,4,5]`,
 62			expr: `length`,
 63			want: `5`,
 64		},
 65		{
 66			name: "extract field",
 67			body: `{"name":"crush","version":"1.0"}`,
 68			expr: `.name`,
 69			want: `"crush"`,
 70		},
 71		{
 72			name: "count objects in array",
 73			body: `[{"id":"a"},{"id":"b"},{"id":"c"}]`,
 74			expr: `length`,
 75			want: `3`,
 76		},
 77		{
 78			name: "sum nested array lengths",
 79			body: `[{"models":[1,2]},{"models":[3,4,5]},{"models":[6]}]`,
 80			expr: `[.[].models | length] | add`,
 81			want: `6`,
 82		},
 83		{
 84			name: "extract names",
 85			body: `[{"name":"a"},{"name":"b"}]`,
 86			expr: `[.[].name]`,
 87			want: "[\n  \"a\",\n  \"b\"\n]",
 88		},
 89	}
 90
 91	for _, tt := range tests {
 92		t.Run(tt.name, func(t *testing.T) {
 93			t.Parallel()
 94			got, err := applyJQ(tt.body, tt.expr)
 95			require.NoError(t, err)
 96			require.Equal(t, tt.want, got)
 97		})
 98	}
 99}
100
101func TestApplyJQErrors(t *testing.T) {
102	t.Parallel()
103
104	_, err := applyJQ(`not json`, `.`)
105	require.Error(t, err)
106
107	_, err = applyJQ(`[1,2,3]`, `|||`)
108	require.Error(t, err)
109
110	_, err = applyJQ(``, `.`)
111	require.Error(t, err)
112}
113
114// TestApplyJQShapeHint verifies that when a jq filter fails because it
115// assumed the wrong top-level shape, the error message includes a
116// describeShape() hint so the caller can self-correct.
117func TestApplyJQShapeHint(t *testing.T) {
118	t.Parallel()
119
120	// Filter assumes an object but body is an array. This is the exact
121	// failure mode observed with kimi-k2.5 in the eval harness:
122	// `.providers[]` against a top-level array.
123	_, err := applyJQ(`[{"id":"a"},{"id":"b"}]`, `.providers[]`)
124	require.Error(t, err)
125	require.Contains(t, err.Error(), "input shape:")
126	require.Contains(t, err.Error(), "array of 2 items")
127	require.Contains(t, err.Error(), "object with keys: id")
128
129	// Filter assumes an array index but body is an object.
130	_, err = applyJQ(`{"data":{"x":1},"meta":{}}`, `.[0]`)
131	require.Error(t, err)
132	require.Contains(t, err.Error(), "input shape:")
133	require.Contains(t, err.Error(), "object with keys: data, meta")
134}
135
136func TestDescribeShape(t *testing.T) {
137	t.Parallel()
138
139	tests := []struct {
140		name string
141		json string
142		want string
143	}{
144		{"null", `null`, "null"},
145		{"bool", `true`, "boolean"},
146		{"number", `42`, "number"},
147		{"string", `"hi"`, "string"},
148		{"empty array", `[]`, "empty array"},
149		{"empty object", `{}`, "empty object"},
150		{"array of objects", `[{"a":1,"b":2},{"a":3}]`, "array of 2 items; first item is object with keys: a, b"},
151		{"object with keys", `{"zebra":1,"apple":2,"mango":3}`, "object with keys: apple, mango, zebra"},
152		{
153			"object truncates keys",
154			`{"a":1,"b":2,"c":3,"d":4,"e":5,"f":6,"g":7,"h":8,"i":9,"j":10}`,
155			"object with keys: a, b, c, d, e, f, g, h, ...",
156		},
157	}
158
159	for _, tt := range tests {
160		t.Run(tt.name, func(t *testing.T) {
161			t.Parallel()
162			var v any
163			dec := json.NewDecoder(strings.NewReader(tt.json))
164			dec.UseNumber()
165			require.NoError(t, dec.Decode(&v))
166			require.Equal(t, tt.want, describeShape(v))
167		})
168	}
169}
170
171func TestLooksLikeJSON(t *testing.T) {
172	t.Parallel()
173
174	tests := []struct {
175		name        string
176		contentType string
177		body        string
178		want        bool
179	}{
180		{"json content type", "application/json", "garbage", true},
181		{"json content type uppercase", "Application/JSON; charset=utf-8", "garbage", true},
182		{"array body", "text/plain", "[1,2,3]", true},
183		{"object body", "text/plain", `{"a":1}`, true},
184		{"leading whitespace", "text/plain", "\n  \t[1]", true},
185		{"html body", "text/html", "<html></html>", false},
186		{"plain text", "text/plain", "hello world", false},
187		{"empty", "", "", false},
188	}
189
190	for _, tt := range tests {
191		t.Run(tt.name, func(t *testing.T) {
192			t.Parallel()
193			require.Equal(t, tt.want, looksLikeJSON(tt.contentType, []byte(tt.body)))
194		})
195	}
196}
197
198// TestFetchToolJQHintPlacement verifies that when fetch returns a large
199// JSON body without a jq filter:
200//
201//  1. The hint banner is appended (not prepended).
202//  2. The content up to the banner is the unmodified original JSON and
203//     parses cleanly — this is the "agent pipes response to jq" path we
204//     want to protect from regressions.
205//  3. Small JSON bodies do NOT get the banner (no unnecessary noise).
206//  4. Non-JSON bodies (even if large) do NOT get the banner.
207//  5. When jq is set, there is no banner and format validation is skipped.
208func TestFetchToolJQHintPlacement(t *testing.T) {
209	t.Parallel()
210
211	// Build a JSON body larger than jqHintThreshold.
212	items := make([]map[string]any, 3000)
213	for i := range items {
214		items[i] = map[string]any{"id": i, "name": "item"}
215	}
216	largeJSON, err := json.Marshal(items)
217	require.NoError(t, err)
218	require.Greater(t, len(largeJSON), jqHintThreshold, "fixture must exceed threshold")
219
220	smallJSON := []byte(`[{"id":1,"name":"item"},{"id":2,"name":"item"}]`)
221	require.Less(t, len(smallJSON), jqHintThreshold)
222
223	largeText := strings.Repeat("lorem ipsum dolor sit amet ", 3000)
224	require.Greater(t, len(largeText), jqHintThreshold)
225
226	tests := []struct {
227		name        string
228		contentType string
229		body        []byte
230		params      FetchParams
231		wantBanner  bool
232		wantErr     bool
233	}{
234		{
235			name:        "large JSON without jq gets trailing banner",
236			contentType: "application/json",
237			body:        largeJSON,
238			params:      FetchParams{Format: "text"},
239			wantBanner:  true,
240		},
241		{
242			name:        "large JSON with jq has no banner",
243			contentType: "application/json",
244			body:        largeJSON,
245			params:      FetchParams{JQ: "length"},
246			wantBanner:  false,
247		},
248		{
249			name:        "large JSON with jq and no format still works",
250			contentType: "application/json",
251			body:        largeJSON,
252			params:      FetchParams{JQ: "length"}, // note: Format unset
253			wantBanner:  false,
254		},
255		{
256			name:        "small JSON gets no banner",
257			contentType: "application/json",
258			body:        smallJSON,
259			params:      FetchParams{Format: "text"},
260			wantBanner:  false,
261		},
262		{
263			name:        "large non-JSON text gets no banner",
264			contentType: "text/plain",
265			body:        []byte(largeText),
266			params:      FetchParams{Format: "text"},
267			wantBanner:  false,
268		},
269	}
270
271	for _, tt := range tests {
272		t.Run(tt.name, func(t *testing.T) {
273			t.Parallel()
274
275			srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
276				w.Header().Set("Content-Type", tt.contentType)
277				w.Write(tt.body)
278			}))
279			t.Cleanup(srv.Close)
280
281			tool := newFetchToolForTest()
282			params := tt.params
283			params.URL = srv.URL
284
285			resp := runFetchTool(t, tool, params)
286			require.False(t, resp.IsError, "unexpected error: %s", resp.Content)
287
288			if tt.wantBanner {
289				require.Contains(t, resp.Content, "[crush-hint:")
290				// Banner must be trailing, not leading.
291				require.False(t, strings.HasPrefix(resp.Content, "[crush-hint:"),
292					"banner must be appended, not prepended")
293				// Critical regression guard: the content BEFORE the banner
294				// must be the original JSON body, still parseable.
295				bannerIdx := strings.LastIndex(resp.Content, "\n\n[crush-hint:")
296				require.Greater(t, bannerIdx, 0, "banner not found in expected form")
297				jsonPortion := resp.Content[:bannerIdx]
298				var parsed any
299				require.NoError(t, json.Unmarshal([]byte(jsonPortion), &parsed),
300					"content before banner must be valid JSON")
301			} else {
302				require.NotContains(t, resp.Content, "[crush-hint:")
303			}
304		})
305	}
306}
307
308// TestFetchToolFormatOptionalWithJQ verifies that format is optional
309// (defaults to text) when jq is set, so callers don't get bounced on
310// format="json" or missing format when they're using jq.
311func TestFetchToolFormatOptionalWithJQ(t *testing.T) {
312	t.Parallel()
313
314	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
315		w.Header().Set("Content-Type", "application/json")
316		w.Write([]byte(`[{"id":1},{"id":2},{"id":3}]`))
317	}))
318	t.Cleanup(srv.Close)
319
320	tool := newFetchToolForTest()
321
322	// No format at all.
323	resp := runFetchTool(t, tool, FetchParams{URL: srv.URL, JQ: "length"})
324	require.False(t, resp.IsError, "expected success, got: %s", resp.Content)
325	require.Equal(t, "3", resp.Content)
326
327	// format="json" — historically rejected, should now pass because jq is set.
328	resp = runFetchTool(t, tool, FetchParams{URL: srv.URL, Format: "json", JQ: "length"})
329	require.False(t, resp.IsError, "expected success with format=json + jq, got: %s", resp.Content)
330	require.Equal(t, "3", resp.Content)
331
332	// Sanity: invalid format WITHOUT jq still rejected.
333	resp = runFetchTool(t, tool, FetchParams{URL: srv.URL, Format: "json"})
334	require.True(t, resp.IsError, "invalid format without jq should still error")
335	require.Contains(t, resp.Content, "jq")
336}
337
338// TestFetchToolShapeHintSurfaces verifies that a wrong-shape jq filter
339// against a fetched body returns an error whose message includes the
340// (input shape: ...) hint, so the LLM has enough info to self-correct.
341func TestFetchToolShapeHintSurfaces(t *testing.T) {
342	t.Parallel()
343
344	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
345		w.Header().Set("Content-Type", "application/json")
346		w.Write([]byte(`[{"id":"a"},{"id":"b"}]`))
347	}))
348	t.Cleanup(srv.Close)
349
350	tool := newFetchToolForTest()
351	resp := runFetchTool(t, tool, FetchParams{
352		URL: srv.URL,
353		JQ:  ".providers[].id",
354	})
355	require.True(t, resp.IsError)
356	require.Contains(t, resp.Content, "jq:")
357	require.Contains(t, resp.Content, "input shape:")
358	require.Contains(t, resp.Content, "array of 2 items")
359}