1package anthropic
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "math"
9 "net/http"
10 "net/http/httptest"
11 "strings"
12 "testing"
13 "time"
14
15 "charm.land/fantasy"
16 "github.com/charmbracelet/anthropic-sdk-go"
17 "github.com/stretchr/testify/require"
18)
19
20// noopComputerRun is a no-op run function for tests that only need
21// to inspect the tool definition, not execute it.
22var noopComputerRun = func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
23 return fantasy.ToolResponse{}, nil
24}
25
// TestToPrompt_DropsEmptyMessages exercises how toPrompt treats messages
// that may end up with nothing sendable. Each subtest builds a small
// fantasy.Prompt, converts it, and asserts on the number of resulting
// messages and on the warnings that were emitted. No case expects any
// system blocks.
//
// NOTE(review): the boolean passed to toPrompt looks like the
// "send reasoning" switch, judging by the subtest that passes false —
// confirm against toPrompt's signature.
func TestToPrompt_DropsEmptyMessages(t *testing.T) {
	t.Parallel()

	// An assistant turn that carries only a reasoning part is dropped, and
	// exactly one warning explains why.
	t.Run("should drop assistant messages with only reasoning content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ReasoningPart{
						Text: "Let me think about this...",
						ProviderOptions: fantasy.ProviderOptions{
							Name: &ReasoningOptionMetadata{
								Signature: "abc123",
							},
						},
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 1, "should only have user message, assistant message should be dropped")
		require.Len(t, warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
		require.Contains(t, warnings[0].Message, "neither user-facing content nor tool calls")
	})

	// Same reasoning-only assistant turn, but converted with the flag set to
	// false: two warnings are expected, in this order — one about reasoning
	// being disabled, then one about the dropped assistant message.
	t.Run("should drop assistant reasoning when sendReasoning disabled", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ReasoningPart{
						Text: "Let me think about this...",
						ProviderOptions: fantasy.ProviderOptions{
							Name: &ReasoningOptionMetadata{
								Signature: "def456",
							},
						},
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, false)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 1, "should only have user message, assistant message should be dropped")
		require.Len(t, warnings, 2)
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
		require.Contains(t, warnings[0].Message, "sending reasoning content is disabled")
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[1].Type)
		require.Contains(t, warnings[1].Message, "dropping empty assistant message")
	})

	// An assistant turn with an empty content slice is dropped with a single
	// warning.
	t.Run("should drop truly empty assistant messages", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role:    fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 1, "should only have user message")
		require.Len(t, warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
	})

	// Plain assistant text survives conversion and produces no warnings.
	t.Run("should keep assistant messages with text content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hi there!"},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	// An assistant turn consisting solely of a tool call is kept.
	t.Run("should keep assistant messages with tool calls", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "What's the weather?"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ToolCallPart{
						ToolCallID: "call_123",
						ToolName:   "get_weather",
						Input:      `{"location":"NYC"}`,
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	t.Run("should preserve tool_use block when input JSON is malformed", func(t *testing.T) {
		t.Parallel()

		// Anthropic's API rejects any request whose tool_result lacks a
		// matching tool_use in the previous message. Dropping the tool_use
		// because its input failed to parse leaves the next turn's
		// tool_result orphaned and produces a 400. Emit the block with
		// empty arguments instead, and surface the parse error as a warning.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hi"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ToolCallPart{
						ToolCallID: "call_123",
						ToolName:   "get_weather",
						Input:      "{not-json",
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 2, "user + assistant — assistant must be preserved so tool_result can pair")
		assistant := messages[1]
		require.Equal(t, anthropic.MessageParamRoleAssistant, assistant.Role)
		require.Len(t, assistant.Content, 1)
		toolUse := assistant.Content[0].OfToolUse
		require.NotNil(t, toolUse, "tool_use block should be emitted even when input is malformed")
		require.Equal(t, "call_123", toolUse.ID)
		require.Equal(t, "get_weather", toolUse.Name)
		// The single warning must identify both the problem and the tool call.
		require.Len(t, warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
		require.Contains(t, warnings[0].Message, "malformed input JSON")
		require.Contains(t, warnings[0].Message, "call_123")
	})

	t.Run("should preserve tool_use block when input is empty", func(t *testing.T) {
		t.Parallel()

		// Some upstream providers (notably DeepSeek's anthropic-compat
		// endpoint) emit tool_use blocks with empty input for parameterless
		// tool calls. Treat empty input as {} rather than dropping the
		// block, since the next turn's tool_result still needs its pair.
		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hi"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ToolCallPart{
						ToolCallID: "call_empty",
						ToolName:   "ping",
						Input:      "",
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Empty(t, warnings, "empty input is a valid round-trip; no warning")
		require.Len(t, messages, 2)
		require.Equal(t, anthropic.MessageParamRoleAssistant, messages[1].Role)
		toolUse := messages[1].Content[0].OfToolUse
		require.NotNil(t, toolUse)
		require.Equal(t, "call_empty", toolUse.ID)
		require.Equal(t, "ping", toolUse.Name)
	})

	// Reasoning followed by text: the assistant turn is kept and nothing is
	// warned about.
	t.Run("should keep assistant messages with reasoning and text", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ReasoningPart{
						Text: "Let me think...",
						ProviderOptions: fantasy.ProviderOptions{
							Name: &ReasoningOptionMetadata{
								Signature: "abc123",
							},
						},
					},
					fantasy.TextPart{Text: "Hi there!"},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	// A user turn whose only part is an image/png file is kept.
	t.Run("should keep user messages with image content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data:      []byte{0x01, 0x02, 0x03},
						MediaType: "image/png",
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})

	// A user turn whose only part is an application/pdf file is kept.
	t.Run("should keep user messages with PDF content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data:      []byte("fake pdf data"),
						MediaType: "application/pdf",
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})

	// A user turn whose only part is a text/markdown file is kept.
	t.Run("should keep user messages with text document content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data:      []byte("# Hello World\nSome markdown content"),
						MediaType: "text/markdown",
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})

	// A user turn whose only part is an application/zip file yields no
	// messages at all, plus a single "dropping empty user message" warning.
	t.Run("should drop user messages without visible content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data:      []byte("not supported"),
						MediaType: "application/zip",
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Empty(t, messages)
		require.Len(t, warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
		require.Contains(t, warnings[0].Message, "dropping empty user message")
		require.Contains(t, warnings[0].Message, "neither user-facing content nor tool results")
	})

	// A tool-role message carrying a text tool result converts to one
	// message with no warnings.
	t.Run("should keep user messages with tool results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_123",
						Output:     fantasy.ToolResultOutputContentText{Text: "done"},
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})

	// Same as above, but the tool result carries an error output.
	t.Run("should keep user messages with tool error results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_456",
						Output:     fantasy.ToolResultOutputContentError{Error: errors.New("boom")},
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})

	// Same as above, but the tool result carries a media (image/png) output.
	t.Run("should keep user messages with tool media results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_789",
						Output: fantasy.ToolResultOutputContentMedia{
							Data:      "AQID",
							MediaType: "image/png",
						},
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})
}
454
455func TestParseContextTooLargeError(t *testing.T) {
456 t.Parallel()
457
458 tests := []struct {
459 name string
460 message string
461 wantErr bool
462 wantUsed int
463 wantMax int
464 }{
465 {
466 name: "matches anthropic format",
467 message: "prompt is too long: 202630 tokens > 200000 maximum",
468 wantErr: true,
469 wantUsed: 202630,
470 wantMax: 200000,
471 },
472 {
473 name: "matches with different numbers",
474 message: "prompt is too long: 150000 tokens > 128000 maximum",
475 wantErr: true,
476 wantUsed: 150000,
477 wantMax: 128000,
478 },
479 {
480 name: "matches with extra whitespace",
481 message: "prompt is too long: 202630 tokens > 200000 maximum",
482 wantErr: true,
483 wantUsed: 202630,
484 wantMax: 200000,
485 },
486 {
487 name: "does not match unrelated error",
488 message: "invalid api key",
489 wantErr: false,
490 },
491 {
492 name: "does not match rate limit error",
493 message: "rate limit exceeded",
494 wantErr: false,
495 },
496 }
497
498 for _, tt := range tests {
499 t.Run(tt.name, func(t *testing.T) {
500 t.Parallel()
501 providerErr := &fantasy.ProviderError{Message: tt.message}
502 parseContextTooLargeError(tt.message, providerErr)
503
504 if tt.wantErr {
505 require.True(t, providerErr.IsContextTooLarge())
506 require.Equal(t, tt.wantUsed, providerErr.ContextUsedTokens)
507 require.Equal(t, tt.wantMax, providerErr.ContextMaxTokens)
508 } else {
509 require.False(t, providerErr.IsContextTooLarge())
510 }
511 })
512 }
513}
514
515func TestParseOptions_Effort(t *testing.T) {
516 t.Parallel()
517
518 options, err := ParseOptions(map[string]any{
519 "send_reasoning": true,
520 "thinking": map[string]any{"budget_tokens": int64(2048)},
521 "effort": "medium",
522 "disable_parallel_tool_use": true,
523 })
524 require.NoError(t, err)
525 require.NotNil(t, options.SendReasoning)
526 require.True(t, *options.SendReasoning)
527 require.NotNil(t, options.Thinking)
528 require.Equal(t, int64(2048), options.Thinking.BudgetTokens)
529 require.NotNil(t, options.Effort)
530 require.Equal(t, EffortMedium, *options.Effort)
531 require.NotNil(t, options.DisableParallelToolUse)
532 require.True(t, *options.DisableParallelToolUse)
533}
534
535func TestGenerate_SendsOutputConfigEffort(t *testing.T) {
536 t.Parallel()
537
538 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
539 defer server.Close()
540
541 provider, err := New(
542 WithAPIKey("test-api-key"),
543 WithBaseURL(server.URL),
544 )
545 require.NoError(t, err)
546
547 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
548 require.NoError(t, err)
549
550 effort := EffortMedium
551 _, err = model.Generate(context.Background(), fantasy.Call{
552 Prompt: testPrompt(),
553 ProviderOptions: NewProviderOptions(&ProviderOptions{
554 Effort: &effort,
555 }),
556 })
557 require.NoError(t, err)
558
559 call := awaitAnthropicCall(t, calls)
560 require.Equal(t, "POST", call.method)
561 require.Equal(t, "/v1/messages", call.path)
562 requireAnthropicEffort(t, call.body, EffortMedium)
563}
564
565func TestStream_SendsOutputConfigEffort(t *testing.T) {
566 t.Parallel()
567
568 server, calls := newAnthropicStreamingServer([]string{
569 "event: message_start\n",
570 "data: {\"type\":\"message_start\",\"message\":{}}\n\n",
571 "event: message_stop\n",
572 "data: {\"type\":\"message_stop\"}\n\n",
573 })
574 defer server.Close()
575
576 provider, err := New(
577 WithAPIKey("test-api-key"),
578 WithBaseURL(server.URL),
579 )
580 require.NoError(t, err)
581
582 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
583 require.NoError(t, err)
584
585 effort := EffortHigh
586 stream, err := model.Stream(context.Background(), fantasy.Call{
587 Prompt: testPrompt(),
588 ProviderOptions: NewProviderOptions(&ProviderOptions{
589 Effort: &effort,
590 }),
591 })
592 require.NoError(t, err)
593
594 stream(func(fantasy.StreamPart) bool { return true })
595
596 call := awaitAnthropicCall(t, calls)
597 require.Equal(t, "POST", call.method)
598 require.Equal(t, "/v1/messages", call.path)
599 requireAnthropicEffort(t, call.body, EffortHigh)
600}
601
// anthropicCall is a snapshot of one HTTP request: its method, its URL
// path, and its body decoded as a generic JSON object.
type anthropicCall struct {
	method, path string
	body         map[string]any
}
607
608func newAnthropicJSONServer(response map[string]any) (*httptest.Server, <-chan anthropicCall) {
609 calls := make(chan anthropicCall, 4)
610
611 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
612 var body map[string]any
613 if r.Body != nil {
614 _ = json.NewDecoder(r.Body).Decode(&body)
615 }
616
617 calls <- anthropicCall{
618 method: r.Method,
619 path: r.URL.Path,
620 body: body,
621 }
622
623 w.Header().Set("Content-Type", "application/json")
624 _ = json.NewEncoder(w).Encode(response)
625 }))
626
627 return server, calls
628}
629
630func newAnthropicStreamingServer(chunks []string) (*httptest.Server, <-chan anthropicCall) {
631 calls := make(chan anthropicCall, 4)
632
633 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
634 var body map[string]any
635 if r.Body != nil {
636 _ = json.NewDecoder(r.Body).Decode(&body)
637 }
638
639 calls <- anthropicCall{
640 method: r.Method,
641 path: r.URL.Path,
642 body: body,
643 }
644
645 w.Header().Set("Content-Type", "text/event-stream")
646 w.Header().Set("Cache-Control", "no-cache")
647 w.Header().Set("Connection", "keep-alive")
648 w.WriteHeader(http.StatusOK)
649
650 for _, chunk := range chunks {
651 _, _ = fmt.Fprint(w, chunk)
652 if flusher, ok := w.(http.Flusher); ok {
653 flusher.Flush()
654 }
655 }
656 }))
657
658 return server, calls
659}
660
661func awaitAnthropicCall(t *testing.T, calls <-chan anthropicCall) anthropicCall {
662 t.Helper()
663
664 select {
665 case call := <-calls:
666 return call
667 case <-time.After(2 * time.Second):
668 t.Fatal("timed out waiting for Anthropic request")
669 return anthropicCall{}
670 }
671}
672
673func assertNoAnthropicCall(t *testing.T, calls <-chan anthropicCall) {
674 t.Helper()
675
676 select {
677 case call := <-calls:
678 t.Fatalf("expected no Anthropic API call, but got %s %s", call.method, call.path)
679 case <-time.After(200 * time.Millisecond):
680 }
681}
682
683func requireAnthropicEffort(t *testing.T, body map[string]any, expected Effort) {
684 t.Helper()
685
686 outputConfig, ok := body["output_config"].(map[string]any)
687 thinking, ok := body["thinking"].(map[string]any)
688 require.True(t, ok)
689 require.Equal(t, string(expected), outputConfig["effort"])
690 require.Equal(t, "adaptive", thinking["type"])
691}
692
693func testPrompt() fantasy.Prompt {
694 return fantasy.Prompt{
695 {
696 Role: fantasy.MessageRoleUser,
697 Content: []fantasy.MessagePart{
698 fantasy.TextPart{Text: "Hello"},
699 },
700 },
701 }
702}
703
// mockAnthropicGenerateResponse builds a canned message payload as a
// generic JSON-style map: one assistant text block ("Hi there"), an
// "end_turn" stop reason, and a usage section reporting 5 input and 2
// output tokens. A fresh map is built on every call.
func mockAnthropicGenerateResponse() map[string]any {
	textBlock := map[string]any{
		"type": "text",
		"text": "Hi there",
	}

	cacheCreation := map[string]any{
		"ephemeral_1h_input_tokens": 0,
		"ephemeral_5m_input_tokens": 0,
	}

	serverToolUse := map[string]any{
		"web_search_requests": 0,
	}

	usage := map[string]any{
		"cache_creation":              cacheCreation,
		"cache_creation_input_tokens": 0,
		"cache_read_input_tokens":     0,
		"input_tokens":                5,
		"output_tokens":               2,
		"server_tool_use":             serverToolUse,
		"service_tier":                "standard",
	}

	payload := map[string]any{}
	payload["id"] = "msg_01Test"
	payload["type"] = "message"
	payload["role"] = "assistant"
	payload["model"] = "claude-sonnet-4-20250514"
	payload["content"] = []any{textBlock}
	payload["stop_reason"] = "end_turn"
	payload["stop_sequence"] = ""
	payload["usage"] = usage
	return payload
}
734
// mockAnthropicWebSearchResponse builds a canned message payload as a
// generic JSON-style map with three content blocks: a server_tool_use
// block for web_search, the matching web_search_tool_result block holding
// two results, and a closing text block. stop_sequence is present with a
// nil value, and usage reports one web search request. A fresh map is
// built on every call.
func mockAnthropicWebSearchResponse() map[string]any {
	firstHit := map[string]any{
		"type":              "web_search_result",
		"url":               "https://example.com/ai-news",
		"title":             "Latest AI News",
		"encrypted_content": "encrypted_abc123",
		"page_age":          "2 hours ago",
	}
	secondHit := map[string]any{
		"type":              "web_search_result",
		"url":               "https://example.com/ml-update",
		"title":             "ML Update",
		"encrypted_content": "encrypted_def456",
		"page_age":          "",
	}

	toolUse := map[string]any{
		"type":   "server_tool_use",
		"id":     "srvtoolu_01",
		"name":   "web_search",
		"input":  map[string]any{"query": "latest AI news"},
		"caller": map[string]any{"type": "direct"},
	}
	toolResult := map[string]any{
		"type":        "web_search_tool_result",
		"tool_use_id": "srvtoolu_01",
		"caller":      map[string]any{"type": "direct"},
		"content":     []any{firstHit, secondHit},
	}
	answer := map[string]any{
		"type": "text",
		"text": "Based on recent search results, here is the latest AI news.",
	}

	usage := map[string]any{
		"input_tokens":                100,
		"output_tokens":               50,
		"cache_creation_input_tokens": 0,
		"cache_read_input_tokens":     0,
		"server_tool_use": map[string]any{
			"web_search_requests": 1,
		},
	}

	payload := map[string]any{}
	payload["id"] = "msg_01WebSearch"
	payload["type"] = "message"
	payload["role"] = "assistant"
	payload["model"] = "claude-sonnet-4-20250514"
	payload["content"] = []any{toolUse, toolResult, answer}
	payload["stop_reason"] = "end_turn"
	payload["stop_sequence"] = nil
	payload["usage"] = usage
	return payload
}
788
789func TestToPrompt_WebSearchProviderExecutedToolResults(t *testing.T) {
790 t.Parallel()
791
792 prompt := fantasy.Prompt{
793 // User message.
794 {
795 Role: fantasy.MessageRoleUser,
796 Content: []fantasy.MessagePart{
797 fantasy.TextPart{Text: "Search for the latest AI news"},
798 },
799 },
800 // Assistant message with a provider-executed tool call, its
801 // result, and trailing text. toResponseMessages routes
802 // provider-executed results into the assistant message, so
803 // the prompt already reflects that structure.
804 {
805 Role: fantasy.MessageRoleAssistant,
806 Content: []fantasy.MessagePart{
807 fantasy.ToolCallPart{
808 ToolCallID: "srvtoolu_01",
809 ToolName: "web_search",
810 Input: `{"query":"latest AI news"}`,
811 ProviderExecuted: true,
812 },
813 fantasy.ToolResultPart{
814 ToolCallID: "srvtoolu_01",
815 ProviderExecuted: true,
816 ProviderOptions: fantasy.ProviderOptions{
817 Name: &WebSearchResultMetadata{
818 Results: []WebSearchResultItem{
819 {
820 URL: "https://example.com/ai-news",
821 Title: "Latest AI News",
822 EncryptedContent: "encrypted_abc123",
823 PageAge: "2 hours ago",
824 },
825 {
826 URL: "https://example.com/ml-update",
827 Title: "ML Update",
828 EncryptedContent: "encrypted_def456",
829 },
830 },
831 },
832 },
833 },
834 fantasy.TextPart{Text: "Here is what I found."},
835 },
836 },
837 }
838
839 _, messages, warnings := toPrompt(prompt, true)
840
841 // No warnings expected; the provider-executed result is in the
842 // assistant message so there is no empty tool message to drop.
843 require.Empty(t, warnings)
844
845 // We should have a user message and an assistant message.
846 require.Len(t, messages, 2, "expected user + assistant messages")
847
848 assistantMsg := messages[1]
849 require.Len(t, assistantMsg.Content, 3,
850 "expected server_tool_use + web_search_tool_result + text")
851
852 // First content block: reconstructed server_tool_use.
853 serverToolUse := assistantMsg.Content[0]
854 require.NotNil(t, serverToolUse.OfServerToolUse,
855 "first block should be a server_tool_use")
856 require.Equal(t, "srvtoolu_01", serverToolUse.OfServerToolUse.ID)
857 require.Equal(t, anthropic.ServerToolUseBlockParamName("web_search"),
858 serverToolUse.OfServerToolUse.Name)
859
860 // Second content block: reconstructed web_search_tool_result with
861 // encrypted_content preserved for multi-turn round-tripping.
862 webResult := assistantMsg.Content[1]
863 require.NotNil(t, webResult.OfWebSearchToolResult,
864 "second block should be a web_search_tool_result")
865 require.Equal(t, "srvtoolu_01", webResult.OfWebSearchToolResult.ToolUseID)
866
867 results := webResult.OfWebSearchToolResult.Content.OfWebSearchToolResultBlockItem
868 require.Len(t, results, 2)
869 require.Equal(t, "https://example.com/ai-news", results[0].URL)
870 require.Equal(t, "Latest AI News", results[0].Title)
871 require.Equal(t, "encrypted_abc123", results[0].EncryptedContent)
872 require.Equal(t, "https://example.com/ml-update", results[1].URL)
873 require.Equal(t, "encrypted_def456", results[1].EncryptedContent)
874 // PageAge should be set for the first result and absent for the second.
875 require.True(t, results[0].PageAge.Valid())
876 require.Equal(t, "2 hours ago", results[0].PageAge.Value)
877 require.False(t, results[1].PageAge.Valid())
878
879 // Third content block: plain text.
880 require.NotNil(t, assistantMsg.Content[2].OfText)
881 require.Equal(t, "Here is what I found.", assistantMsg.Content[2].OfText.Text)
882}
883
884func TestGenerate_WebSearchResponse(t *testing.T) {
885 t.Parallel()
886
887 server, calls := newAnthropicJSONServer(mockAnthropicWebSearchResponse())
888 defer server.Close()
889
890 provider, err := New(
891 WithAPIKey("test-api-key"),
892 WithBaseURL(server.URL),
893 )
894 require.NoError(t, err)
895
896 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
897 require.NoError(t, err)
898
899 resp, err := model.Generate(context.Background(), fantasy.Call{
900 Prompt: testPrompt(),
901 Tools: []fantasy.Tool{
902 WebSearchTool(nil),
903 },
904 })
905 require.NoError(t, err)
906
907 call := awaitAnthropicCall(t, calls)
908 require.Equal(t, "POST", call.method)
909 require.Equal(t, "/v1/messages", call.path)
910
911 // Walk the response content and categorise each item.
912 var (
913 toolCalls []fantasy.ToolCallContent
914 sources []fantasy.SourceContent
915 toolResults []fantasy.ToolResultContent
916 texts []fantasy.TextContent
917 )
918 for _, c := range resp.Content {
919 switch v := c.(type) {
920 case fantasy.ToolCallContent:
921 toolCalls = append(toolCalls, v)
922 case fantasy.SourceContent:
923 sources = append(sources, v)
924 case fantasy.ToolResultContent:
925 toolResults = append(toolResults, v)
926 case fantasy.TextContent:
927 texts = append(texts, v)
928 }
929 }
930
931 // ToolCallContent for the provider-executed web_search.
932 require.Len(t, toolCalls, 1)
933 require.True(t, toolCalls[0].ProviderExecuted)
934 require.Equal(t, "web_search", toolCalls[0].ToolName)
935 require.Equal(t, "srvtoolu_01", toolCalls[0].ToolCallID)
936
937 // SourceContent entries for each search result.
938 require.Len(t, sources, 2)
939 require.Equal(t, "https://example.com/ai-news", sources[0].URL)
940 require.Equal(t, "Latest AI News", sources[0].Title)
941 require.Equal(t, fantasy.SourceTypeURL, sources[0].SourceType)
942 require.Equal(t, "https://example.com/ml-update", sources[1].URL)
943 require.Equal(t, "ML Update", sources[1].Title)
944
945 // ToolResultContent with provider metadata preserving encrypted_content.
946 require.Len(t, toolResults, 1)
947 require.True(t, toolResults[0].ProviderExecuted)
948 require.Equal(t, "web_search", toolResults[0].ToolName)
949 require.Equal(t, "srvtoolu_01", toolResults[0].ToolCallID)
950
951 searchMeta, ok := toolResults[0].ProviderMetadata[Name]
952 require.True(t, ok, "providerMetadata should contain anthropic key")
953 webMeta, ok := searchMeta.(*WebSearchResultMetadata)
954 require.True(t, ok, "metadata should be *WebSearchResultMetadata")
955 require.Len(t, webMeta.Results, 2)
956 require.Equal(t, "encrypted_abc123", webMeta.Results[0].EncryptedContent)
957 require.Equal(t, "encrypted_def456", webMeta.Results[1].EncryptedContent)
958 require.Equal(t, "2 hours ago", webMeta.Results[0].PageAge)
959
960 // TextContent with the final answer.
961 require.Len(t, texts, 1)
962 require.Equal(t,
963 "Based on recent search results, here is the latest AI news.",
964 texts[0].Text,
965 )
966}
967
968func TestGenerate_WebSearchToolInRequest(t *testing.T) {
969 t.Parallel()
970
971 t.Run("basic web_search tool", func(t *testing.T) {
972 t.Parallel()
973
974 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
975 defer server.Close()
976
977 provider, err := New(
978 WithAPIKey("test-api-key"),
979 WithBaseURL(server.URL),
980 )
981 require.NoError(t, err)
982
983 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
984 require.NoError(t, err)
985
986 _, err = model.Generate(context.Background(), fantasy.Call{
987 Prompt: testPrompt(),
988 Tools: []fantasy.Tool{
989 WebSearchTool(nil),
990 },
991 })
992 require.NoError(t, err)
993
994 call := awaitAnthropicCall(t, calls)
995 tools, ok := call.body["tools"].([]any)
996 require.True(t, ok, "request body should have tools array")
997 require.Len(t, tools, 1)
998
999 tool, ok := tools[0].(map[string]any)
1000 require.True(t, ok)
1001 require.Equal(t, "web_search_20250305", tool["type"])
1002 })
1003
1004 t.Run("with allowed_domains and blocked_domains", func(t *testing.T) {
1005 t.Parallel()
1006
1007 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1008 defer server.Close()
1009
1010 provider, err := New(
1011 WithAPIKey("test-api-key"),
1012 WithBaseURL(server.URL),
1013 )
1014 require.NoError(t, err)
1015
1016 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1017 require.NoError(t, err)
1018
1019 _, err = model.Generate(context.Background(), fantasy.Call{
1020 Prompt: testPrompt(),
1021 Tools: []fantasy.Tool{
1022 WebSearchTool(&WebSearchToolOptions{
1023 AllowedDomains: []string{"example.com", "test.com"},
1024 }),
1025 },
1026 })
1027 require.NoError(t, err)
1028
1029 call := awaitAnthropicCall(t, calls)
1030 tools, ok := call.body["tools"].([]any)
1031 require.True(t, ok)
1032 require.Len(t, tools, 1)
1033
1034 tool, ok := tools[0].(map[string]any)
1035 require.True(t, ok)
1036 require.Equal(t, "web_search_20250305", tool["type"])
1037
1038 domains, ok := tool["allowed_domains"].([]any)
1039 require.True(t, ok, "tool should have allowed_domains")
1040 require.Len(t, domains, 2)
1041 require.Equal(t, "example.com", domains[0])
1042 require.Equal(t, "test.com", domains[1])
1043 })
1044
1045 t.Run("with max uses and user location", func(t *testing.T) {
1046 t.Parallel()
1047
1048 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1049 defer server.Close()
1050
1051 provider, err := New(
1052 WithAPIKey("test-api-key"),
1053 WithBaseURL(server.URL),
1054 )
1055 require.NoError(t, err)
1056
1057 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1058 require.NoError(t, err)
1059
1060 _, err = model.Generate(context.Background(), fantasy.Call{
1061 Prompt: testPrompt(),
1062 Tools: []fantasy.Tool{
1063 WebSearchTool(&WebSearchToolOptions{
1064 MaxUses: 5,
1065 UserLocation: &UserLocation{
1066 City: "San Francisco",
1067 Country: "US",
1068 },
1069 }),
1070 },
1071 })
1072 require.NoError(t, err)
1073
1074 call := awaitAnthropicCall(t, calls)
1075 tools, ok := call.body["tools"].([]any)
1076 require.True(t, ok)
1077 require.Len(t, tools, 1)
1078
1079 tool, ok := tools[0].(map[string]any)
1080 require.True(t, ok)
1081 require.Equal(t, "web_search_20250305", tool["type"])
1082
1083 // max_uses is serialized as a JSON number; json.Unmarshal
1084 // into map[string]any decodes numbers as float64.
1085 maxUses, ok := tool["max_uses"].(float64)
1086 require.True(t, ok, "tool should have max_uses")
1087 require.Equal(t, float64(5), maxUses)
1088
1089 userLoc, ok := tool["user_location"].(map[string]any)
1090 require.True(t, ok, "tool should have user_location")
1091 require.Equal(t, "San Francisco", userLoc["city"])
1092 require.Equal(t, "US", userLoc["country"])
1093 require.Equal(t, "approximate", userLoc["type"])
1094 })
1095
1096 t.Run("with max uses", func(t *testing.T) {
1097 t.Parallel()
1098
1099 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1100 defer server.Close()
1101
1102 provider, err := New(
1103 WithAPIKey("test-api-key"),
1104 WithBaseURL(server.URL),
1105 )
1106 require.NoError(t, err)
1107
1108 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1109 require.NoError(t, err)
1110
1111 _, err = model.Generate(context.Background(), fantasy.Call{
1112 Prompt: testPrompt(),
1113 Tools: []fantasy.Tool{
1114 WebSearchTool(&WebSearchToolOptions{
1115 MaxUses: 3,
1116 }),
1117 },
1118 })
1119 require.NoError(t, err)
1120
1121 call := awaitAnthropicCall(t, calls)
1122 tools, ok := call.body["tools"].([]any)
1123 require.True(t, ok)
1124 require.Len(t, tools, 1)
1125
1126 tool, ok := tools[0].(map[string]any)
1127 require.True(t, ok)
1128 require.Equal(t, "web_search_20250305", tool["type"])
1129
1130 maxUses, ok := tool["max_uses"].(float64)
1131 require.True(t, ok, "tool should have max_uses")
1132 require.Equal(t, float64(3), maxUses)
1133 })
1134
1135 t.Run("with json-round-tripped provider tool args", func(t *testing.T) {
1136 t.Parallel()
1137
1138 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1139 defer server.Close()
1140
1141 provider, err := New(
1142 WithAPIKey("test-api-key"),
1143 WithBaseURL(server.URL),
1144 )
1145 require.NoError(t, err)
1146
1147 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1148 require.NoError(t, err)
1149
1150 baseTool := WebSearchTool(&WebSearchToolOptions{
1151 MaxUses: 7,
1152 BlockedDomains: []string{"example.com", "test.com"},
1153 UserLocation: &UserLocation{
1154 City: "San Francisco",
1155 Region: "CA",
1156 Country: "US",
1157 Timezone: "America/Los_Angeles",
1158 },
1159 })
1160
1161 data, err := json.Marshal(baseTool)
1162 require.NoError(t, err)
1163
1164 var roundTripped fantasy.ProviderDefinedTool
1165 err = json.Unmarshal(data, &roundTripped)
1166 require.NoError(t, err)
1167
1168 _, err = model.Generate(context.Background(), fantasy.Call{
1169 Prompt: testPrompt(),
1170 Tools: []fantasy.Tool{roundTripped},
1171 })
1172 require.NoError(t, err)
1173
1174 call := awaitAnthropicCall(t, calls)
1175 tools, ok := call.body["tools"].([]any)
1176 require.True(t, ok)
1177 require.Len(t, tools, 1)
1178
1179 tool, ok := tools[0].(map[string]any)
1180 require.True(t, ok)
1181 require.Equal(t, "web_search_20250305", tool["type"])
1182
1183 domains, ok := tool["blocked_domains"].([]any)
1184 require.True(t, ok, "tool should have blocked_domains")
1185 require.Len(t, domains, 2)
1186 require.Equal(t, "example.com", domains[0])
1187 require.Equal(t, "test.com", domains[1])
1188
1189 maxUses, ok := tool["max_uses"].(float64)
1190 require.True(t, ok, "tool should have max_uses")
1191 require.Equal(t, float64(7), maxUses)
1192
1193 userLoc, ok := tool["user_location"].(map[string]any)
1194 require.True(t, ok, "tool should have user_location")
1195 require.Equal(t, "San Francisco", userLoc["city"])
1196 require.Equal(t, "CA", userLoc["region"])
1197 require.Equal(t, "US", userLoc["country"])
1198 require.Equal(t, "America/Los_Angeles", userLoc["timezone"])
1199 require.Equal(t, "approximate", userLoc["type"])
1200 })
1201}
1202
1203func TestAnyToStringSlice(t *testing.T) {
1204 t.Parallel()
1205
1206 t.Run("from string slice", func(t *testing.T) {
1207 t.Parallel()
1208
1209 got := anyToStringSlice([]string{"example.com", ""})
1210 require.Equal(t, []string{"example.com", ""}, got)
1211 })
1212
1213 t.Run("from any slice filters non-strings and empty", func(t *testing.T) {
1214 t.Parallel()
1215
1216 got := anyToStringSlice([]any{"example.com", 123, "", "test.com"})
1217 require.Equal(t, []string{"example.com", "test.com"}, got)
1218 })
1219
1220 t.Run("unsupported type", func(t *testing.T) {
1221 t.Parallel()
1222
1223 got := anyToStringSlice("example.com")
1224 require.Nil(t, got)
1225 })
1226}
1227
1228func TestAnyToInt64(t *testing.T) {
1229 t.Parallel()
1230
1231 tests := []struct {
1232 name string
1233 input any
1234 want int64
1235 wantOK bool
1236 }{
1237 {name: "int64", input: int64(7), want: 7, wantOK: true},
1238 {name: "float64 integer", input: float64(7), want: 7, wantOK: true},
1239 {name: "float32 integer", input: float32(9), want: 9, wantOK: true},
1240 {name: "float64 non-integer", input: float64(7.5), wantOK: false},
1241 {name: "float64 max exact int ok", input: float64(1<<53 - 1), want: 1<<53 - 1, wantOK: true},
1242 {name: "float64 over max exact int", input: float64(1 << 53), wantOK: false},
1243 {name: "json number int", input: json.Number("42"), want: 42, wantOK: true},
1244 {name: "json number float", input: json.Number("4.2"), wantOK: false},
1245 {name: "nan", input: math.NaN(), wantOK: false},
1246 {name: "inf", input: math.Inf(1), wantOK: false},
1247 {name: "uint64 overflow", input: uint64(math.MaxInt64) + 1, wantOK: false},
1248 }
1249
1250 for _, tt := range tests {
1251 t.Run(tt.name, func(t *testing.T) {
1252 got, ok := anyToInt64(tt.input)
1253 require.Equal(t, tt.wantOK, ok)
1254 if tt.wantOK {
1255 require.Equal(t, tt.want, got)
1256 }
1257 })
1258 }
1259}
1260
1261func TestAnyToUserLocation(t *testing.T) {
1262 t.Parallel()
1263
1264 t.Run("pointer passthrough", func(t *testing.T) {
1265 t.Parallel()
1266
1267 input := &UserLocation{City: "San Francisco", Country: "US"}
1268 got := anyToUserLocation(input)
1269 require.Same(t, input, got)
1270 })
1271
1272 t.Run("struct value", func(t *testing.T) {
1273 t.Parallel()
1274
1275 got := anyToUserLocation(UserLocation{City: "San Francisco", Country: "US"})
1276 require.NotNil(t, got)
1277 require.Equal(t, "San Francisco", got.City)
1278 require.Equal(t, "US", got.Country)
1279 })
1280
1281 t.Run("map value", func(t *testing.T) {
1282 t.Parallel()
1283
1284 got := anyToUserLocation(map[string]any{
1285 "city": "San Francisco",
1286 "region": "CA",
1287 "country": "US",
1288 "timezone": "America/Los_Angeles",
1289 "type": "approximate",
1290 })
1291 require.NotNil(t, got)
1292 require.Equal(t, "San Francisco", got.City)
1293 require.Equal(t, "CA", got.Region)
1294 require.Equal(t, "US", got.Country)
1295 require.Equal(t, "America/Los_Angeles", got.Timezone)
1296 })
1297
1298 t.Run("empty map", func(t *testing.T) {
1299 t.Parallel()
1300
1301 got := anyToUserLocation(map[string]any{"type": "approximate"})
1302 require.Nil(t, got)
1303 })
1304
1305 t.Run("unsupported type", func(t *testing.T) {
1306 t.Parallel()
1307
1308 got := anyToUserLocation("San Francisco")
1309 require.Nil(t, got)
1310 })
1311}
1312
1313func TestStream_WebSearchResponse(t *testing.T) {
1314 t.Parallel()
1315
1316 // Build SSE chunks that simulate a web search streaming response.
1317 // The Anthropic SDK accumulates content blocks via
1318 // acc.Accumulate(event). We read the Content and ToolUseID
1319 // directly from struct fields instead of using AsAny(), which
1320 // avoids the SDK's re-marshal limitation that previously dropped
1321 // source data.
1322 webSearchResultContent, _ := json.Marshal([]any{
1323 map[string]any{
1324 "type": "web_search_result",
1325 "url": "https://example.com/ai-news",
1326 "title": "Latest AI News",
1327 "encrypted_content": "encrypted_abc123",
1328 "page_age": "2 hours ago",
1329 },
1330 })
1331
1332 chunks := []string{
1333 // message_start
1334 "event: message_start\n",
1335 `data: {"type":"message_start","message":{"id":"msg_01WebSearch","type":"message","role":"assistant","model":"claude-sonnet-4-20250514","content":[],"stop_reason":null,"usage":{"input_tokens":100,"output_tokens":0}}}` + "\n\n",
1336 // Block 0: server_tool_use
1337 "event: content_block_start\n",
1338 `data: {"type":"content_block_start","index":0,"content_block":{"type":"server_tool_use","id":"srvtoolu_01","name":"web_search","input":{}}}` + "\n\n",
1339 "event: content_block_stop\n",
1340 `data: {"type":"content_block_stop","index":0}` + "\n\n",
1341 // Block 1: web_search_tool_result
1342 "event: content_block_start\n",
1343 `data: {"type":"content_block_start","index":1,"content_block":{"type":"web_search_tool_result","tool_use_id":"srvtoolu_01","content":` + string(webSearchResultContent) + `}}` + "\n\n",
1344 "event: content_block_stop\n",
1345 `data: {"type":"content_block_stop","index":1}` + "\n\n",
1346 // Block 2: text
1347 "event: content_block_start\n",
1348 `data: {"type":"content_block_start","index":2,"content_block":{"type":"text","text":""}}` + "\n\n",
1349 "event: content_block_delta\n",
1350 `data: {"type":"content_block_delta","index":2,"delta":{"type":"text_delta","text":"Here are the results."}}` + "\n\n",
1351 "event: content_block_stop\n",
1352 `data: {"type":"content_block_stop","index":2}` + "\n\n",
1353 // message_stop
1354 "event: message_stop\n",
1355 `data: {"type":"message_stop"}` + "\n\n",
1356 }
1357
1358 server, calls := newAnthropicStreamingServer(chunks)
1359 defer server.Close()
1360
1361 provider, err := New(
1362 WithAPIKey("test-api-key"),
1363 WithBaseURL(server.URL),
1364 )
1365 require.NoError(t, err)
1366
1367 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1368 require.NoError(t, err)
1369
1370 stream, err := model.Stream(context.Background(), fantasy.Call{
1371 Prompt: testPrompt(),
1372 Tools: []fantasy.Tool{
1373 WebSearchTool(nil),
1374 },
1375 })
1376 require.NoError(t, err)
1377
1378 var parts []fantasy.StreamPart
1379 stream(func(part fantasy.StreamPart) bool {
1380 parts = append(parts, part)
1381 return true
1382 })
1383
1384 _ = awaitAnthropicCall(t, calls)
1385
1386 // Collect parts by type for assertions.
1387 var (
1388 toolInputStarts []fantasy.StreamPart
1389 toolCalls []fantasy.StreamPart
1390 toolResults []fantasy.StreamPart
1391 sourceParts []fantasy.StreamPart
1392 textDeltas []fantasy.StreamPart
1393 )
1394 for _, p := range parts {
1395 switch p.Type {
1396 case fantasy.StreamPartTypeToolInputStart:
1397 toolInputStarts = append(toolInputStarts, p)
1398 case fantasy.StreamPartTypeToolCall:
1399 toolCalls = append(toolCalls, p)
1400 case fantasy.StreamPartTypeToolResult:
1401 toolResults = append(toolResults, p)
1402 case fantasy.StreamPartTypeSource:
1403 sourceParts = append(sourceParts, p)
1404 case fantasy.StreamPartTypeTextDelta:
1405 textDeltas = append(textDeltas, p)
1406 }
1407 }
1408
1409 // server_tool_use emits a ToolInputStart with ProviderExecuted.
1410 require.NotEmpty(t, toolInputStarts, "should have a tool input start")
1411 require.True(t, toolInputStarts[0].ProviderExecuted)
1412 require.Equal(t, "web_search", toolInputStarts[0].ToolCallName)
1413
1414 // server_tool_use emits a ToolCall with ProviderExecuted.
1415 require.NotEmpty(t, toolCalls, "should have a tool call")
1416 require.True(t, toolCalls[0].ProviderExecuted)
1417
1418 // web_search_tool_result always emits a ToolResult even when
1419 // the SDK drops source data. The ToolUseID comes from the raw
1420 // union field as a fallback.
1421 require.NotEmpty(t, toolResults, "should have a tool result")
1422 require.True(t, toolResults[0].ProviderExecuted)
1423 require.Equal(t, "web_search", toolResults[0].ToolCallName)
1424 require.Equal(t, "srvtoolu_01", toolResults[0].ID,
1425 "tool result ID should match the tool_use_id")
1426
1427 // Source parts are now correctly emitted by reading struct fields
1428 // directly instead of using AsAny().
1429 require.Len(t, sourceParts, 1)
1430 require.Equal(t, "https://example.com/ai-news", sourceParts[0].URL)
1431 require.Equal(t, "Latest AI News", sourceParts[0].Title)
1432 require.Equal(t, fantasy.SourceTypeURL, sourceParts[0].SourceType)
1433
1434 // Text block emits a text delta.
1435 require.NotEmpty(t, textDeltas, "should have text deltas")
1436 require.Equal(t, "Here are the results.", textDeltas[0].Delta)
1437}
1438
1439func TestGenerate_ToolChoiceNone(t *testing.T) {
1440 t.Parallel()
1441
1442 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1443 defer server.Close()
1444
1445 provider, err := New(
1446 WithAPIKey("test-api-key"),
1447 WithBaseURL(server.URL),
1448 )
1449 require.NoError(t, err)
1450
1451 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1452 require.NoError(t, err)
1453
1454 toolChoiceNone := fantasy.ToolChoiceNone
1455 _, err = model.Generate(context.Background(), fantasy.Call{
1456 Prompt: testPrompt(),
1457 Tools: []fantasy.Tool{
1458 WebSearchTool(nil),
1459 },
1460 ToolChoice: &toolChoiceNone,
1461 })
1462 require.NoError(t, err)
1463
1464 call := awaitAnthropicCall(t, calls)
1465 toolChoice, ok := call.body["tool_choice"].(map[string]any)
1466 require.True(t, ok, "request body should have tool_choice")
1467 require.Equal(t, "none", toolChoice["type"], "tool_choice should be 'none'")
1468}
1469
1470// --- Computer Use Tests ---
1471
1472// jsonRoundTripTool simulates a JSON round-trip on a
1473// ProviderDefinedTool so that its Args map contains float64
1474// values (as json.Unmarshal produces) rather than the int64
1475// values that NewComputerUseTool stores directly. The
1476// production toBetaTools code asserts float64.
1477func jsonRoundTripTool(t *testing.T, tool fantasy.ExecutableProviderTool) fantasy.ProviderDefinedTool {
1478 t.Helper()
1479 pdt := tool.Definition()
1480 data, err := json.Marshal(pdt.Args)
1481 require.NoError(t, err)
1482 var args map[string]any
1483 require.NoError(t, json.Unmarshal(data, &args))
1484 pdt.Args = args
1485 return pdt
1486}
1487
1488func TestNewComputerUseTool(t *testing.T) {
1489 t.Parallel()
1490
1491 t.Run("creates tool with correct ID and name", func(t *testing.T) {
1492 t.Parallel()
1493 tool := NewComputerUseTool(ComputerUseToolOptions{
1494 DisplayWidthPx: 1920,
1495 DisplayHeightPx: 1080,
1496 ToolVersion: ComputerUse20250124,
1497 }, noopComputerRun).Definition()
1498 require.Equal(t, "anthropic.computer", tool.ID)
1499 require.Equal(t, "computer", tool.Name)
1500 require.Equal(t, int64(1920), tool.Args["display_width_px"])
1501 require.Equal(t, int64(1080), tool.Args["display_height_px"])
1502 require.Equal(t, string(ComputerUse20250124), tool.Args["tool_version"])
1503 })
1504
1505 t.Run("includes optional fields when set", func(t *testing.T) {
1506 t.Parallel()
1507 displayNum := int64(1)
1508 enableZoom := true
1509 tool := NewComputerUseTool(ComputerUseToolOptions{
1510 DisplayWidthPx: 1024,
1511 DisplayHeightPx: 768,
1512 DisplayNumber: &displayNum,
1513 EnableZoom: &enableZoom,
1514 ToolVersion: ComputerUse20251124,
1515 CacheControl: &CacheControl{Type: "ephemeral"},
1516 }, noopComputerRun).Definition()
1517 require.Equal(t, int64(1), tool.Args["display_number"])
1518 require.Equal(t, true, tool.Args["enable_zoom"])
1519 require.NotNil(t, tool.Args["cache_control"])
1520 })
1521
1522 t.Run("omits optional fields when nil", func(t *testing.T) {
1523 t.Parallel()
1524 tool := NewComputerUseTool(ComputerUseToolOptions{
1525 DisplayWidthPx: 1920,
1526 DisplayHeightPx: 1080,
1527 ToolVersion: ComputerUse20250124,
1528 }, noopComputerRun).Definition()
1529 _, hasDisplayNum := tool.Args["display_number"]
1530 _, hasEnableZoom := tool.Args["enable_zoom"]
1531 _, hasCacheControl := tool.Args["cache_control"]
1532 require.False(t, hasDisplayNum)
1533 require.False(t, hasEnableZoom)
1534 require.False(t, hasCacheControl)
1535 })
1536}
1537
1538func TestIsComputerUseTool(t *testing.T) {
1539 t.Parallel()
1540
1541 t.Run("returns true for computer use tool", func(t *testing.T) {
1542 t.Parallel()
1543 tool := NewComputerUseTool(ComputerUseToolOptions{
1544 DisplayWidthPx: 1920,
1545 DisplayHeightPx: 1080,
1546 ToolVersion: ComputerUse20250124,
1547 }, noopComputerRun)
1548 require.True(t, IsComputerUseTool(tool.Definition()))
1549 })
1550
1551 t.Run("returns false for function tool", func(t *testing.T) {
1552 t.Parallel()
1553 tool := fantasy.FunctionTool{
1554 Name: "test",
1555 Description: "test tool",
1556 }
1557 require.False(t, IsComputerUseTool(tool))
1558 })
1559
1560 t.Run("returns false for other provider defined tool", func(t *testing.T) {
1561 t.Parallel()
1562 tool := fantasy.ProviderDefinedTool{
1563 ID: "other.tool",
1564 Name: "other",
1565 }
1566 require.False(t, IsComputerUseTool(tool))
1567 })
1568}
1569
1570func TestNeedsBetaAPI(t *testing.T) {
1571 t.Parallel()
1572
1573 lm := languageModel{options: options{}}
1574
1575 t.Run("returns false for empty tools", func(t *testing.T) {
1576 t.Parallel()
1577 _, _, _, betaFlags := lm.toTools(nil, nil, false)
1578 require.Empty(t, betaFlags)
1579 _, _, _, betaFlags = lm.toTools([]fantasy.Tool{}, nil, false)
1580 require.Empty(t, betaFlags)
1581 })
1582
1583 t.Run("returns false for only function tools", func(t *testing.T) {
1584 t.Parallel()
1585 tools := []fantasy.Tool{
1586 fantasy.FunctionTool{Name: "test"},
1587 }
1588 _, _, _, betaFlags := lm.toTools(tools, nil, false)
1589 require.Empty(t, betaFlags)
1590 })
1591
1592 t.Run("returns beta flags when computer use tool present", func(t *testing.T) {
1593 t.Parallel()
1594 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1595 DisplayWidthPx: 1920,
1596 DisplayHeightPx: 1080,
1597 ToolVersion: ComputerUse20250124,
1598 }, noopComputerRun))
1599 tools := []fantasy.Tool{
1600 fantasy.FunctionTool{Name: "test"},
1601 cuTool,
1602 }
1603 _, _, _, betaFlags := lm.toTools(tools, nil, false)
1604 require.NotEmpty(t, betaFlags)
1605 })
1606}
1607
1608func TestComputerUseToolJSON(t *testing.T) {
1609 t.Parallel()
1610
1611 t.Run("builds JSON for version 20250124", func(t *testing.T) {
1612 t.Parallel()
1613 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1614 DisplayWidthPx: 1920,
1615 DisplayHeightPx: 1080,
1616 ToolVersion: ComputerUse20250124,
1617 }, noopComputerRun))
1618 data, err := computerUseToolJSON(cuTool)
1619 require.NoError(t, err)
1620 var m map[string]any
1621 require.NoError(t, json.Unmarshal(data, &m))
1622 require.Equal(t, "computer_20250124", m["type"])
1623 require.Equal(t, "computer", m["name"])
1624 require.InDelta(t, 1920, m["display_width_px"], 0)
1625 require.InDelta(t, 1080, m["display_height_px"], 0)
1626 })
1627
1628 t.Run("builds JSON for version 20251124 with enable_zoom", func(t *testing.T) {
1629 t.Parallel()
1630 enableZoom := true
1631 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1632 DisplayWidthPx: 1024,
1633 DisplayHeightPx: 768,
1634 EnableZoom: &enableZoom,
1635 ToolVersion: ComputerUse20251124,
1636 }, noopComputerRun))
1637 data, err := computerUseToolJSON(cuTool)
1638 require.NoError(t, err)
1639 var m map[string]any
1640 require.NoError(t, json.Unmarshal(data, &m))
1641 require.Equal(t, "computer_20251124", m["type"])
1642 require.Equal(t, true, m["enable_zoom"])
1643 })
1644
1645 t.Run("handles int64 args without JSON round-trip", func(t *testing.T) {
1646 t.Parallel()
1647 // Direct construction stores int64 values.
1648 cuTool := NewComputerUseTool(ComputerUseToolOptions{
1649 DisplayWidthPx: 1920,
1650 DisplayHeightPx: 1080,
1651 ToolVersion: ComputerUse20250124,
1652 }, noopComputerRun)
1653 data, err := computerUseToolJSON(cuTool.Definition())
1654 require.NoError(t, err)
1655 var m map[string]any
1656 require.NoError(t, json.Unmarshal(data, &m))
1657 require.InDelta(t, 1920, m["display_width_px"], 0)
1658 })
1659
1660 t.Run("returns error when version is missing", func(t *testing.T) {
1661 t.Parallel()
1662 pdt := fantasy.ProviderDefinedTool{
1663 ID: "anthropic.computer",
1664 Name: "computer",
1665 Args: map[string]any{
1666 "display_width_px": float64(1920),
1667 "display_height_px": float64(1080),
1668 },
1669 }
1670 _, err := computerUseToolJSON(pdt)
1671 require.Error(t, err)
1672 require.Contains(t, err.Error(), "tool_version arg is missing")
1673 })
1674
1675 t.Run("returns error for unsupported version", func(t *testing.T) {
1676 t.Parallel()
1677 pdt := fantasy.ProviderDefinedTool{
1678 ID: "anthropic.computer",
1679 Name: "computer",
1680 Args: map[string]any{
1681 "display_width_px": float64(1920),
1682 "display_height_px": float64(1080),
1683 "tool_version": "computer_99991231",
1684 },
1685 }
1686 _, err := computerUseToolJSON(pdt)
1687 require.Error(t, err)
1688 require.Contains(t, err.Error(), "unsupported")
1689 })
1690}
1691
1692func TestParseComputerUseInput_CoordinateValidation(t *testing.T) {
1693 t.Parallel()
1694
1695 t.Run("rejects coordinate with 1 element", func(t *testing.T) {
1696 t.Parallel()
1697 _, err := ParseComputerUseInput(`{"action":"left_click","coordinate":[100]}`)
1698 require.Error(t, err)
1699 require.Contains(t, err.Error(), "coordinate")
1700 })
1701
1702 t.Run("rejects coordinate with 3 elements", func(t *testing.T) {
1703 t.Parallel()
1704 _, err := ParseComputerUseInput(`{"action":"left_click","coordinate":[100,200,300]}`)
1705 require.Error(t, err)
1706 require.Contains(t, err.Error(), "coordinate")
1707 })
1708
1709 t.Run("rejects start_coordinate with 1 element", func(t *testing.T) {
1710 t.Parallel()
1711 _, err := ParseComputerUseInput(`{"action":"left_click_drag","coordinate":[100,200],"start_coordinate":[50]}`)
1712 require.Error(t, err)
1713 require.Contains(t, err.Error(), "start_coordinate")
1714 })
1715
1716 t.Run("rejects region with 3 elements", func(t *testing.T) {
1717 t.Parallel()
1718 _, err := ParseComputerUseInput(`{"action":"zoom","region":[10,20,30]}`)
1719 require.Error(t, err)
1720 require.Contains(t, err.Error(), "region")
1721 })
1722
1723 t.Run("accepts valid coordinate", func(t *testing.T) {
1724 t.Parallel()
1725 result, err := ParseComputerUseInput(`{"action":"left_click","coordinate":[100,200]}`)
1726 require.NoError(t, err)
1727 require.Equal(t, [2]int64{100, 200}, result.Coordinate)
1728 })
1729
1730 t.Run("accepts absent optional arrays", func(t *testing.T) {
1731 t.Parallel()
1732 result, err := ParseComputerUseInput(`{"action":"screenshot"}`)
1733 require.NoError(t, err)
1734 require.Equal(t, ActionScreenshot, result.Action)
1735 })
1736}
1737
1738func TestToTools_RawJSON(t *testing.T) {
1739 t.Parallel()
1740
1741 lm := languageModel{options: options{}}
1742
1743 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1744 DisplayWidthPx: 1920,
1745 DisplayHeightPx: 1080,
1746 ToolVersion: ComputerUse20250124,
1747 }, noopComputerRun))
1748
1749 tools := []fantasy.Tool{
1750 fantasy.FunctionTool{
1751 Name: "weather",
1752 Description: "Get weather",
1753 InputSchema: map[string]any{
1754 "properties": map[string]any{
1755 "location": map[string]any{"type": "string"},
1756 },
1757 "required": []string{"location"},
1758 },
1759 },
1760 WebSearchTool(nil),
1761 cuTool,
1762 }
1763
1764 rawTools, toolChoice, warnings, betaFlags := lm.toTools(tools, nil, false)
1765
1766 require.Len(t, rawTools, 3)
1767 require.Nil(t, toolChoice)
1768 require.Empty(t, warnings)
1769 require.NotEmpty(t, betaFlags)
1770
1771 // Verify each raw tool is valid JSON.
1772 for i, raw := range rawTools {
1773 var m map[string]any
1774 require.NoError(t, json.Unmarshal(raw, &m), "tool %d should be valid JSON", i)
1775 }
1776
1777 // Check function tool.
1778 var funcTool map[string]any
1779 require.NoError(t, json.Unmarshal(rawTools[0], &funcTool))
1780 require.Equal(t, "weather", funcTool["name"])
1781
1782 // Check web search tool.
1783 var webTool map[string]any
1784 require.NoError(t, json.Unmarshal(rawTools[1], &webTool))
1785 require.Equal(t, "web_search_20250305", webTool["type"])
1786
1787 // Check computer use tool.
1788 var cuToolJSON map[string]any
1789 require.NoError(t, json.Unmarshal(rawTools[2], &cuToolJSON))
1790 require.Equal(t, "computer_20250124", cuToolJSON["type"])
1791 require.Equal(t, "computer", cuToolJSON["name"])
1792}
1793
1794func TestGenerate_BetaAPI(t *testing.T) {
1795 t.Parallel()
1796
1797 t.Run("sends beta header for computer use", func(t *testing.T) {
1798 t.Parallel()
1799
1800 var capturedHeaders http.Header
1801 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1802 capturedHeaders = r.Header.Clone()
1803 w.Header().Set("Content-Type", "application/json")
1804 _ = json.NewEncoder(w).Encode(mockAnthropicGenerateResponse())
1805 }))
1806 defer server.Close()
1807
1808 provider, err := New(
1809 WithAPIKey("test-api-key"),
1810 WithBaseURL(server.URL),
1811 )
1812 require.NoError(t, err)
1813
1814 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1815 require.NoError(t, err)
1816
1817 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1818 DisplayWidthPx: 1920,
1819 DisplayHeightPx: 1080,
1820 ToolVersion: ComputerUse20250124,
1821 }, noopComputerRun))
1822
1823 _, err = model.Generate(context.Background(), fantasy.Call{
1824 Prompt: testPrompt(),
1825 Tools: []fantasy.Tool{cuTool},
1826 })
1827 require.NoError(t, err)
1828 require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-01-24")
1829 })
1830
1831 t.Run("sends beta header for computer use 20251124", func(t *testing.T) {
1832 t.Parallel()
1833
1834 var capturedHeaders http.Header
1835 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1836 capturedHeaders = r.Header.Clone()
1837 w.Header().Set("Content-Type", "application/json")
1838 _ = json.NewEncoder(w).Encode(mockAnthropicGenerateResponse())
1839 }))
1840 defer server.Close()
1841
1842 provider, err := New(
1843 WithAPIKey("test-api-key"),
1844 WithBaseURL(server.URL),
1845 )
1846 require.NoError(t, err)
1847
1848 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1849 require.NoError(t, err)
1850
1851 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1852 DisplayWidthPx: 1920,
1853 DisplayHeightPx: 1080,
1854 ToolVersion: ComputerUse20251124,
1855 }, noopComputerRun))
1856
1857 _, err = model.Generate(context.Background(), fantasy.Call{
1858 Prompt: testPrompt(),
1859 Tools: []fantasy.Tool{cuTool},
1860 })
1861 require.NoError(t, err)
1862 require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-11-24")
1863 })
1864
1865 t.Run("returns tool use from beta response", func(t *testing.T) {
1866 t.Parallel()
1867
1868 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1869 w.Header().Set("Content-Type", "application/json")
1870 _ = json.NewEncoder(w).Encode(map[string]any{
1871 "id": "msg_01Test",
1872 "type": "message",
1873 "role": "assistant",
1874 "model": "claude-sonnet-4-20250514",
1875 "content": []any{
1876 map[string]any{
1877 "type": "tool_use",
1878 "id": "toolu_01",
1879 "name": "computer",
1880 "input": map[string]any{"action": "screenshot"},
1881 },
1882 },
1883 "stop_reason": "tool_use",
1884 "usage": map[string]any{
1885 "input_tokens": 10,
1886 "output_tokens": 5,
1887 "cache_creation": map[string]any{
1888 "ephemeral_1h_input_tokens": 0,
1889 "ephemeral_5m_input_tokens": 0,
1890 },
1891 "cache_creation_input_tokens": 0,
1892 "cache_read_input_tokens": 0,
1893 "server_tool_use": map[string]any{
1894 "web_search_requests": 0,
1895 },
1896 "service_tier": "standard",
1897 },
1898 })
1899 }))
1900 defer server.Close()
1901
1902 provider, err := New(
1903 WithAPIKey("test-api-key"),
1904 WithBaseURL(server.URL),
1905 )
1906 require.NoError(t, err)
1907
1908 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1909 require.NoError(t, err)
1910
1911 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1912 DisplayWidthPx: 1920,
1913 DisplayHeightPx: 1080,
1914 ToolVersion: ComputerUse20250124,
1915 }, noopComputerRun))
1916
1917 resp, err := model.Generate(context.Background(), fantasy.Call{
1918 Prompt: testPrompt(),
1919 Tools: []fantasy.Tool{cuTool},
1920 })
1921 require.NoError(t, err)
1922
1923 toolCalls := resp.Content.ToolCalls()
1924 require.Len(t, toolCalls, 1)
1925 require.Equal(t, "computer", toolCalls[0].ToolName)
1926 require.Equal(t, "toolu_01", toolCalls[0].ToolCallID)
1927 require.Contains(t, toolCalls[0].Input, "screenshot")
1928 require.Equal(t, fantasy.FinishReasonToolCalls, resp.FinishReason)
1929
1930 // Verify typed parsing works on the tool call input.
1931 parsed, err := ParseComputerUseInput(toolCalls[0].Input)
1932 require.NoError(t, err)
1933 require.Equal(t, ActionScreenshot, parsed.Action)
1934 })
1935}
1936
1937func TestStream_BetaAPI(t *testing.T) {
1938 t.Parallel()
1939
1940 t.Run("streams via beta API for computer use", func(t *testing.T) {
1941 t.Parallel()
1942
1943 var capturedHeaders http.Header
1944 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1945 capturedHeaders = r.Header.Clone()
1946 w.Header().Set("Content-Type", "text/event-stream")
1947 w.Header().Set("Cache-Control", "no-cache")
1948 w.WriteHeader(http.StatusOK)
1949 chunks := []string{
1950 "event: message_start\n",
1951 "data: {\"type\":\"message_start\",\"message\":{}}\n\n",
1952 "event: message_stop\n",
1953 "data: {\"type\":\"message_stop\"}\n\n",
1954 }
1955 for _, chunk := range chunks {
1956 _, _ = fmt.Fprint(w, chunk)
1957 if flusher, ok := w.(http.Flusher); ok {
1958 flusher.Flush()
1959 }
1960 }
1961 }))
1962 defer server.Close()
1963
1964 provider, err := New(
1965 WithAPIKey("test-api-key"),
1966 WithBaseURL(server.URL),
1967 )
1968 require.NoError(t, err)
1969
1970 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1971 require.NoError(t, err)
1972
1973 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1974 DisplayWidthPx: 1920,
1975 DisplayHeightPx: 1080,
1976 ToolVersion: ComputerUse20250124,
1977 }, noopComputerRun))
1978
1979 stream, err := model.Stream(context.Background(), fantasy.Call{
1980 Prompt: testPrompt(),
1981 Tools: []fantasy.Tool{cuTool},
1982 })
1983 require.NoError(t, err)
1984
1985 stream(func(fantasy.StreamPart) bool { return true })
1986
1987 require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-01-24")
1988 })
1989
1990 t.Run("streams via beta API for computer use 20251124", func(t *testing.T) {
1991 t.Parallel()
1992
1993 var capturedHeaders http.Header
1994 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1995 capturedHeaders = r.Header.Clone()
1996 w.Header().Set("Content-Type", "text/event-stream")
1997 w.Header().Set("Cache-Control", "no-cache")
1998 w.WriteHeader(http.StatusOK)
1999 chunks := []string{
2000 "event: message_start\n",
2001 "data: {\"type\":\"message_start\",\"message\":{}}\n\n",
2002 "event: message_stop\n",
2003 "data: {\"type\":\"message_stop\"}\n\n",
2004 }
2005 for _, chunk := range chunks {
2006 _, _ = fmt.Fprint(w, chunk)
2007 if flusher, ok := w.(http.Flusher); ok {
2008 flusher.Flush()
2009 }
2010 }
2011 }))
2012 defer server.Close()
2013
2014 provider, err := New(
2015 WithAPIKey("test-api-key"),
2016 WithBaseURL(server.URL),
2017 )
2018 require.NoError(t, err)
2019
2020 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
2021 require.NoError(t, err)
2022
2023 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
2024 DisplayWidthPx: 1920,
2025 DisplayHeightPx: 1080,
2026 ToolVersion: ComputerUse20251124,
2027 }, noopComputerRun))
2028
2029 stream, err := model.Stream(context.Background(), fantasy.Call{
2030 Prompt: testPrompt(),
2031 Tools: []fantasy.Tool{cuTool},
2032 })
2033 require.NoError(t, err)
2034
2035 stream(func(fantasy.StreamPart) bool { return true })
2036
2037 require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-11-24")
2038 })
2039}
2040
2041// TestGenerate_ComputerUseTool runs a multi-turn computer use session
2042// via model.Generate, passing the ExecutableProviderTool directly into
2043// Call.Tools (no .Definition(), no jsonRoundTripTool). The mock server
2044// walks through a scripted sequence of actions — screenshot, click,
2045// type, key, scroll — then finishes with a text reply. Each turn the
2046// test parses the tool call, builds a screenshot result, and appends
2047// both to the prompt for the next request.
2048func TestGenerate_ComputerUseTool(t *testing.T) {
2049 t.Parallel()
2050
2051 type actionStep struct {
2052 input map[string]any
2053 want ComputerUseInput
2054 }
2055 steps := []actionStep{
2056 {
2057 input: map[string]any{"action": "screenshot"},
2058 want: ComputerUseInput{Action: ActionScreenshot},
2059 },
2060 {
2061 input: map[string]any{"action": "left_click", "coordinate": []any{100, 200}},
2062 want: ComputerUseInput{Action: ActionLeftClick, Coordinate: [2]int64{100, 200}},
2063 },
2064 {
2065 input: map[string]any{"action": "type", "text": "hello world"},
2066 want: ComputerUseInput{Action: ActionType, Text: "hello world"},
2067 },
2068 {
2069 input: map[string]any{"action": "key", "text": "Return"},
2070 want: ComputerUseInput{Action: ActionKey, Text: "Return"},
2071 },
2072 {
2073 input: map[string]any{
2074 "action": "scroll",
2075 "coordinate": []any{500, 300},
2076 "scroll_direction": "down",
2077 "scroll_amount": 3,
2078 },
2079 want: ComputerUseInput{
2080 Action: ActionScroll,
2081 Coordinate: [2]int64{500, 300},
2082 ScrollDirection: "down",
2083 ScrollAmount: 3,
2084 },
2085 },
2086 {
2087 input: map[string]any{"action": "screenshot"},
2088 want: ComputerUseInput{Action: ActionScreenshot},
2089 },
2090 }
2091
2092 var (
2093 requestIdx int
2094 betaHeaders []string
2095 )
2096 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2097 betaHeaders = append(betaHeaders, r.Header.Get("Anthropic-Beta"))
2098 idx := requestIdx
2099 requestIdx++
2100
2101 w.Header().Set("Content-Type", "application/json")
2102 if idx < len(steps) {
2103 _ = json.NewEncoder(w).Encode(map[string]any{
2104 "id": fmt.Sprintf("msg_%02d", idx),
2105 "type": "message",
2106 "role": "assistant",
2107 "model": "claude-sonnet-4-20250514",
2108 "content": []any{map[string]any{
2109 "type": "tool_use",
2110 "id": fmt.Sprintf("toolu_%02d", idx),
2111 "name": "computer",
2112 "input": steps[idx].input,
2113 }},
2114 "stop_reason": "tool_use",
2115 "usage": map[string]any{"input_tokens": 10, "output_tokens": 5},
2116 })
2117 return
2118 }
2119 _ = json.NewEncoder(w).Encode(map[string]any{
2120 "id": "msg_final",
2121 "type": "message",
2122 "role": "assistant",
2123 "model": "claude-sonnet-4-20250514",
2124 "content": []any{map[string]any{
2125 "type": "text",
2126 "text": "Done! I have completed all the requested actions.",
2127 }},
2128 "stop_reason": "end_turn",
2129 "usage": map[string]any{"input_tokens": 10, "output_tokens": 15},
2130 })
2131 }))
2132 defer server.Close()
2133
2134 provider, err := New(WithAPIKey("test-api-key"), WithBaseURL(server.URL))
2135 require.NoError(t, err)
2136
2137 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
2138 require.NoError(t, err)
2139
2140 // Pass the ExecutableProviderTool directly — the whole point is
2141 // to verify that the Tool interface works without unwrapping.
2142 cuTool := NewComputerUseTool(ComputerUseToolOptions{
2143 DisplayWidthPx: 1920,
2144 DisplayHeightPx: 1080,
2145 ToolVersion: ComputerUse20250124,
2146 }, noopComputerRun)
2147
2148 var got []ComputerUseInput
2149 prompt := testPrompt()
2150 fakePNG := []byte("fake-screenshot-png")
2151
2152 for turn := 0; turn <= len(steps); turn++ {
2153 resp, err := model.Generate(context.Background(), fantasy.Call{
2154 Prompt: prompt,
2155 Tools: []fantasy.Tool{cuTool},
2156 })
2157 require.NoError(t, err, "turn %d", turn)
2158
2159 if resp.FinishReason != fantasy.FinishReasonToolCalls {
2160 require.Equal(t, fantasy.FinishReasonStop, resp.FinishReason)
2161 require.Contains(t, resp.Content.Text(), "Done")
2162 break
2163 }
2164
2165 toolCalls := resp.Content.ToolCalls()
2166 require.Len(t, toolCalls, 1, "turn %d", turn)
2167 require.Equal(t, "computer", toolCalls[0].ToolName, "turn %d", turn)
2168
2169 parsed, err := ParseComputerUseInput(toolCalls[0].Input)
2170 require.NoError(t, err, "turn %d", turn)
2171 got = append(got, parsed)
2172
2173 // Build the next prompt: append the assistant tool-call turn
2174 // and the user screenshot-result turn.
2175 prompt = append(prompt,
2176 fantasy.Message{
2177 Role: fantasy.MessageRoleAssistant,
2178 Content: []fantasy.MessagePart{
2179 fantasy.ToolCallPart{
2180 ToolCallID: toolCalls[0].ToolCallID,
2181 ToolName: toolCalls[0].ToolName,
2182 Input: toolCalls[0].Input,
2183 },
2184 },
2185 },
2186 fantasy.Message{
2187 // Use MessageRoleTool for tool results — this matches
2188 // what the agent loop produces.
2189 Role: fantasy.MessageRoleTool,
2190 Content: []fantasy.MessagePart{
2191 NewComputerUseScreenshotResult(toolCalls[0].ToolCallID, fakePNG),
2192 },
2193 },
2194 )
2195 }
2196
2197 // Every scripted action was received and parsed correctly.
2198 require.Len(t, got, len(steps))
2199 for i, step := range steps {
2200 require.Equal(t, step.want.Action, got[i].Action, "step %d", i)
2201 require.Equal(t, step.want.Coordinate, got[i].Coordinate, "step %d", i)
2202 require.Equal(t, step.want.Text, got[i].Text, "step %d", i)
2203 require.Equal(t, step.want.ScrollDirection, got[i].ScrollDirection, "step %d", i)
2204 require.Equal(t, step.want.ScrollAmount, got[i].ScrollAmount, "step %d", i)
2205 }
2206
2207 // Beta header was sent on every request.
2208 require.Len(t, betaHeaders, len(steps)+1)
2209 for i, h := range betaHeaders {
2210 require.Contains(t, h, "computer-use-2025-01-24", "request %d", i)
2211 }
2212}
2213
2214// TestStream_ComputerUseTool runs a multi-turn computer use session
2215// via model.Stream, verifying that the ExecutableProviderTool works
2216// through the streaming path end-to-end.
2217func TestStream_ComputerUseTool(t *testing.T) {
2218 t.Parallel()
2219
2220 type streamStep struct {
2221 input map[string]any
2222 wantAction ComputerAction
2223 }
2224 steps := []streamStep{
2225 {input: map[string]any{"action": "screenshot"}, wantAction: ActionScreenshot},
2226 {input: map[string]any{"action": "left_click", "coordinate": []any{150, 250}}, wantAction: ActionLeftClick},
2227 {input: map[string]any{"action": "type", "text": "search query"}, wantAction: ActionType},
2228 }
2229
2230 var (
2231 requestIdx int
2232 betaHeaders []string
2233 )
2234
2235 // streamToolUseChunks returns SSE chunks for a single
2236 // computer-use tool_use content block.
2237 streamToolUseChunks := func(id string, input map[string]any) []string {
2238 inputJSON, _ := json.Marshal(input)
2239 escaped := strings.ReplaceAll(string(inputJSON), `"`, `\"`)
2240 return []string{
2241 "event: message_start\n",
2242 `data: {"type":"message_start","message":{"id":"` + id + `","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4-20250514","stop_reason":null,"usage":{"input_tokens":10,"output_tokens":0}}}` + "\n\n",
2243 "event: content_block_start\n",
2244 `data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"` + id + `","name":"computer","input":{}}}` + "\n\n",
2245 "event: content_block_delta\n",
2246 `data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"` + escaped + `"}}` + "\n\n",
2247 "event: content_block_stop\n",
2248 `data: {"type":"content_block_stop","index":0}` + "\n\n",
2249 "event: message_delta\n",
2250 `data: {"type":"message_delta","delta":{"stop_reason":"tool_use"},"usage":{"output_tokens":5}}` + "\n\n",
2251 "event: message_stop\n",
2252 `data: {"type":"message_stop"}` + "\n\n",
2253 }
2254 }
2255
2256 streamTextChunks := func() []string {
2257 return []string{
2258 "event: message_start\n",
2259 `data: {"type":"message_start","message":{"id":"msg_final","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4-20250514","stop_reason":null,"usage":{"input_tokens":10,"output_tokens":0}}}` + "\n\n",
2260 "event: content_block_start\n",
2261 `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}` + "\n\n",
2262 "event: content_block_delta\n",
2263 `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"All done."}}` + "\n\n",
2264 "event: content_block_stop\n",
2265 `data: {"type":"content_block_stop","index":0}` + "\n\n",
2266 "event: message_delta\n",
2267 `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":10}}` + "\n\n",
2268 "event: message_stop\n",
2269 `data: {"type":"message_stop"}` + "\n\n",
2270 }
2271 }
2272
2273 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2274 betaHeaders = append(betaHeaders, r.Header.Get("Anthropic-Beta"))
2275 idx := requestIdx
2276 requestIdx++
2277
2278 w.Header().Set("Content-Type", "text/event-stream")
2279 w.Header().Set("Cache-Control", "no-cache")
2280 w.WriteHeader(http.StatusOK)
2281
2282 var chunks []string
2283 if idx < len(steps) {
2284 chunks = streamToolUseChunks(
2285 fmt.Sprintf("toolu_%02d", idx),
2286 steps[idx].input,
2287 )
2288 } else {
2289 chunks = streamTextChunks()
2290 }
2291 for _, chunk := range chunks {
2292 _, _ = fmt.Fprint(w, chunk)
2293 if f, ok := w.(http.Flusher); ok {
2294 f.Flush()
2295 }
2296 }
2297 }))
2298 defer server.Close()
2299
2300 provider, err := New(WithAPIKey("test-api-key"), WithBaseURL(server.URL))
2301 require.NoError(t, err)
2302
2303 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
2304 require.NoError(t, err)
2305
2306 cuTool := NewComputerUseTool(ComputerUseToolOptions{
2307 DisplayWidthPx: 1920,
2308 DisplayHeightPx: 1080,
2309 ToolVersion: ComputerUse20250124,
2310 }, noopComputerRun)
2311
2312 var gotActions []ComputerAction
2313 prompt := testPrompt()
2314 fakePNG := []byte("fake-screenshot-png")
2315
2316 for turn := 0; turn <= len(steps); turn++ {
2317 stream, err := model.Stream(context.Background(), fantasy.Call{
2318 Prompt: prompt,
2319 Tools: []fantasy.Tool{cuTool},
2320 })
2321 require.NoError(t, err, "turn %d", turn)
2322
2323 var (
2324 toolCallName string
2325 toolCallID string
2326 toolCallInput string
2327 finishReason fantasy.FinishReason
2328 gotText string
2329 )
2330 stream(func(part fantasy.StreamPart) bool {
2331 switch part.Type {
2332 case fantasy.StreamPartTypeToolCall:
2333 toolCallName = part.ToolCallName
2334 toolCallID = part.ID
2335 toolCallInput = part.ToolCallInput
2336 case fantasy.StreamPartTypeFinish:
2337 finishReason = part.FinishReason
2338 case fantasy.StreamPartTypeTextDelta:
2339 gotText += part.Delta
2340 }
2341 return true
2342 })
2343
2344 if finishReason != fantasy.FinishReasonToolCalls {
2345 require.Contains(t, gotText, "All done")
2346 break
2347 }
2348
2349 require.Equal(t, "computer", toolCallName, "turn %d", turn)
2350
2351 parsed, err := ParseComputerUseInput(toolCallInput)
2352 require.NoError(t, err, "turn %d", turn)
2353 gotActions = append(gotActions, parsed.Action)
2354
2355 prompt = append(prompt,
2356 fantasy.Message{
2357 Role: fantasy.MessageRoleAssistant,
2358 Content: []fantasy.MessagePart{
2359 fantasy.ToolCallPart{
2360 ToolCallID: toolCallID,
2361 ToolName: toolCallName,
2362 Input: toolCallInput,
2363 },
2364 },
2365 },
2366 fantasy.Message{
2367 // Use MessageRoleTool for tool results — this matches
2368 // what the agent loop produces.
2369 Role: fantasy.MessageRoleTool,
2370 Content: []fantasy.MessagePart{
2371 NewComputerUseScreenshotResult(toolCallID, fakePNG),
2372 },
2373 },
2374 )
2375 }
2376
2377 require.Len(t, gotActions, len(steps))
2378 for i, step := range steps {
2379 require.Equal(t, step.wantAction, gotActions[i], "step %d", i)
2380 }
2381
2382 require.Len(t, betaHeaders, len(steps)+1)
2383 for i, h := range betaHeaders {
2384 require.Contains(t, h, "computer-use-2025-01-24", "request %d", i)
2385 }
2386}