1package anthropic
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "math"
9 "net/http"
10 "net/http/httptest"
11 "strings"
12 "testing"
13 "time"
14
15 "charm.land/fantasy"
16 "github.com/charmbracelet/anthropic-sdk-go"
17 "github.com/stretchr/testify/require"
18)
19
20// noopComputerRun is a no-op run function for tests that only need
21// to inspect the tool definition, not execute it.
22var noopComputerRun = func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
23 return fantasy.ToolResponse{}, nil
24}
25
// TestToPrompt_DropsEmptyMessages exercises how toPrompt treats messages that
// end up with nothing sendable. Assistant messages with neither user-facing
// content nor tool calls, and user messages with neither user-facing content
// nor tool results, are dropped with a CallWarningTypeOther warning. Messages
// carrying text, tool calls, tool results or supported file parts are kept
// without any warning.
func TestToPrompt_DropsEmptyMessages(t *testing.T) {
	t.Parallel()

	// Reasoning alone is not user-facing content, so the assistant turn is
	// dropped even though reasoning is being sent.
	t.Run("should drop assistant messages with only reasoning content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ReasoningPart{
						Text: "Let me think about this...",
						ProviderOptions: fantasy.ProviderOptions{
							Name: &ReasoningOptionMetadata{
								Signature: "abc123",
							},
						},
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 1, "should only have user message, assistant message should be dropped")
		require.Len(t, warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
		require.Contains(t, warnings[0].Message, "neither user-facing content nor tool calls")
	})

	// With sendReasoning=false the reasoning part is skipped first (first
	// warning). That leaves the assistant message empty, so it is then
	// dropped (second warning). The assertions pin this ordering.
	t.Run("should drop assistant reasoning when sendReasoning disabled", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ReasoningPart{
						Text: "Let me think about this...",
						ProviderOptions: fantasy.ProviderOptions{
							Name: &ReasoningOptionMetadata{
								Signature: "def456",
							},
						},
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, false)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 1, "should only have user message, assistant message should be dropped")
		require.Len(t, warnings, 2)
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
		require.Contains(t, warnings[0].Message, "sending reasoning content is disabled")
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[1].Type)
		require.Contains(t, warnings[1].Message, "dropping empty assistant message")
	})

	// An assistant message with an empty Content slice is dropped.
	t.Run("should drop truly empty assistant messages", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role:    fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 1, "should only have user message")
		require.Len(t, warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
	})

	t.Run("should keep assistant messages with text content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hi there!"},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	t.Run("should keep assistant messages with tool calls", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "What's the weather?"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ToolCallPart{
						ToolCallID: "call_123",
						ToolName:   "get_weather",
						Input:      `{"location":"NYC"}`,
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	// A tool call whose Input is not valid JSON is expected not to be
	// emitted. That leaves the assistant message empty, so only the
	// "dropping empty" warning is produced.
	t.Run("should drop assistant messages with invalid tool input", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hi"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ToolCallPart{
						ToolCallID: "call_123",
						ToolName:   "get_weather",
						Input:      "{not-json",
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 1, "should only have user message")
		require.Len(t, warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
	})

	// Reasoning alongside text is fine: the text makes the message non-empty.
	t.Run("should keep assistant messages with reasoning and text", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "Hello"},
				},
			},
			{
				Role: fantasy.MessageRoleAssistant,
				Content: []fantasy.MessagePart{
					fantasy.ReasoningPart{
						Text: "Let me think...",
						ProviderOptions: fantasy.ProviderOptions{
							Name: &ReasoningOptionMetadata{
								Signature: "abc123",
							},
						},
					},
					fantasy.TextPart{Text: "Hi there!"},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 2, "should have both user and assistant messages")
		require.Empty(t, warnings)
	})

	// The next three subtests show that image, PDF and text/markdown file
	// parts each count as visible user content on their own.
	t.Run("should keep user messages with image content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data:      []byte{0x01, 0x02, 0x03},
						MediaType: "image/png",
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep user messages with PDF content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data:      []byte("fake pdf data"),
						MediaType: "application/pdf",
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep user messages with text document content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data:      []byte("# Hello World\nSome markdown content"),
						MediaType: "text/markdown",
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})

	// An application/zip file part contributes nothing sendable, so the
	// user message is dropped. Exactly one warning is produced: the drop.
	t.Run("should drop user messages without visible content", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.FilePart{
						Data:      []byte("not supported"),
						MediaType: "application/zip",
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Empty(t, messages)
		require.Len(t, warnings, 1)
		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
		require.Contains(t, warnings[0].Message, "dropping empty user message")
		require.Contains(t, warnings[0].Message, "neither user-facing content nor tool results")
	})

	// The remaining subtests use MessageRoleTool messages, which the subtest
	// names refer to as user messages. Text, error and media tool-result
	// outputs each keep the message without warnings.
	t.Run("should keep user messages with tool results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_123",
						Output:     fantasy.ToolResultOutputContentText{Text: "done"},
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep user messages with tool error results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_456",
						Output:     fantasy.ToolResultOutputContentError{Error: errors.New("boom")},
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})

	t.Run("should keep user messages with tool media results", func(t *testing.T) {
		t.Parallel()

		prompt := fantasy.Prompt{
			{
				Role: fantasy.MessageRoleTool,
				Content: []fantasy.MessagePart{
					fantasy.ToolResultPart{
						ToolCallID: "call_789",
						Output: fantasy.ToolResultOutputContentMedia{
							// "AQID" is base64 for the bytes 0x01 0x02 0x03.
							Data:      "AQID",
							MediaType: "image/png",
						},
					},
				},
			},
		}

		systemBlocks, messages, warnings := toPrompt(prompt, true)

		require.Empty(t, systemBlocks)
		require.Len(t, messages, 1)
		require.Empty(t, warnings)
	})
}
403
404func TestParseContextTooLargeError(t *testing.T) {
405 t.Parallel()
406
407 tests := []struct {
408 name string
409 message string
410 wantErr bool
411 wantUsed int
412 wantMax int
413 }{
414 {
415 name: "matches anthropic format",
416 message: "prompt is too long: 202630 tokens > 200000 maximum",
417 wantErr: true,
418 wantUsed: 202630,
419 wantMax: 200000,
420 },
421 {
422 name: "matches with different numbers",
423 message: "prompt is too long: 150000 tokens > 128000 maximum",
424 wantErr: true,
425 wantUsed: 150000,
426 wantMax: 128000,
427 },
428 {
429 name: "matches with extra whitespace",
430 message: "prompt is too long: 202630 tokens > 200000 maximum",
431 wantErr: true,
432 wantUsed: 202630,
433 wantMax: 200000,
434 },
435 {
436 name: "does not match unrelated error",
437 message: "invalid api key",
438 wantErr: false,
439 },
440 {
441 name: "does not match rate limit error",
442 message: "rate limit exceeded",
443 wantErr: false,
444 },
445 }
446
447 for _, tt := range tests {
448 t.Run(tt.name, func(t *testing.T) {
449 t.Parallel()
450 providerErr := &fantasy.ProviderError{Message: tt.message}
451 parseContextTooLargeError(tt.message, providerErr)
452
453 if tt.wantErr {
454 require.True(t, providerErr.IsContextTooLarge())
455 require.Equal(t, tt.wantUsed, providerErr.ContextUsedTokens)
456 require.Equal(t, tt.wantMax, providerErr.ContextMaxTokens)
457 } else {
458 require.False(t, providerErr.IsContextTooLarge())
459 }
460 })
461 }
462}
463
464func TestParseOptions_Effort(t *testing.T) {
465 t.Parallel()
466
467 options, err := ParseOptions(map[string]any{
468 "send_reasoning": true,
469 "thinking": map[string]any{"budget_tokens": int64(2048)},
470 "effort": "medium",
471 "disable_parallel_tool_use": true,
472 })
473 require.NoError(t, err)
474 require.NotNil(t, options.SendReasoning)
475 require.True(t, *options.SendReasoning)
476 require.NotNil(t, options.Thinking)
477 require.Equal(t, int64(2048), options.Thinking.BudgetTokens)
478 require.NotNil(t, options.Effort)
479 require.Equal(t, EffortMedium, *options.Effort)
480 require.NotNil(t, options.DisableParallelToolUse)
481 require.True(t, *options.DisableParallelToolUse)
482}
483
484func TestGenerate_SendsOutputConfigEffort(t *testing.T) {
485 t.Parallel()
486
487 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
488 defer server.Close()
489
490 provider, err := New(
491 WithAPIKey("test-api-key"),
492 WithBaseURL(server.URL),
493 )
494 require.NoError(t, err)
495
496 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
497 require.NoError(t, err)
498
499 effort := EffortMedium
500 _, err = model.Generate(context.Background(), fantasy.Call{
501 Prompt: testPrompt(),
502 ProviderOptions: NewProviderOptions(&ProviderOptions{
503 Effort: &effort,
504 }),
505 })
506 require.NoError(t, err)
507
508 call := awaitAnthropicCall(t, calls)
509 require.Equal(t, "POST", call.method)
510 require.Equal(t, "/v1/messages", call.path)
511 requireAnthropicEffort(t, call.body, EffortMedium)
512}
513
514func TestStream_SendsOutputConfigEffort(t *testing.T) {
515 t.Parallel()
516
517 server, calls := newAnthropicStreamingServer([]string{
518 "event: message_start\n",
519 "data: {\"type\":\"message_start\",\"message\":{}}\n\n",
520 "event: message_stop\n",
521 "data: {\"type\":\"message_stop\"}\n\n",
522 })
523 defer server.Close()
524
525 provider, err := New(
526 WithAPIKey("test-api-key"),
527 WithBaseURL(server.URL),
528 )
529 require.NoError(t, err)
530
531 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
532 require.NoError(t, err)
533
534 effort := EffortHigh
535 stream, err := model.Stream(context.Background(), fantasy.Call{
536 Prompt: testPrompt(),
537 ProviderOptions: NewProviderOptions(&ProviderOptions{
538 Effort: &effort,
539 }),
540 })
541 require.NoError(t, err)
542
543 stream(func(fantasy.StreamPart) bool { return true })
544
545 call := awaitAnthropicCall(t, calls)
546 require.Equal(t, "POST", call.method)
547 require.Equal(t, "/v1/messages", call.path)
548 requireAnthropicEffort(t, call.body, EffortHigh)
549}
550
// anthropicCall is one HTTP request as recorded by the mock Anthropic servers
// in this file, kept for later assertions.
type anthropicCall struct {
	method, path string         // HTTP method and URL path of the request
	body         map[string]any // JSON-decoded request body; may be nil if it could not be decoded
}
556
557func newAnthropicJSONServer(response map[string]any) (*httptest.Server, <-chan anthropicCall) {
558 calls := make(chan anthropicCall, 4)
559
560 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
561 var body map[string]any
562 if r.Body != nil {
563 _ = json.NewDecoder(r.Body).Decode(&body)
564 }
565
566 calls <- anthropicCall{
567 method: r.Method,
568 path: r.URL.Path,
569 body: body,
570 }
571
572 w.Header().Set("Content-Type", "application/json")
573 _ = json.NewEncoder(w).Encode(response)
574 }))
575
576 return server, calls
577}
578
579func newAnthropicStreamingServer(chunks []string) (*httptest.Server, <-chan anthropicCall) {
580 calls := make(chan anthropicCall, 4)
581
582 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
583 var body map[string]any
584 if r.Body != nil {
585 _ = json.NewDecoder(r.Body).Decode(&body)
586 }
587
588 calls <- anthropicCall{
589 method: r.Method,
590 path: r.URL.Path,
591 body: body,
592 }
593
594 w.Header().Set("Content-Type", "text/event-stream")
595 w.Header().Set("Cache-Control", "no-cache")
596 w.Header().Set("Connection", "keep-alive")
597 w.WriteHeader(http.StatusOK)
598
599 for _, chunk := range chunks {
600 _, _ = fmt.Fprint(w, chunk)
601 if flusher, ok := w.(http.Flusher); ok {
602 flusher.Flush()
603 }
604 }
605 }))
606
607 return server, calls
608}
609
610func awaitAnthropicCall(t *testing.T, calls <-chan anthropicCall) anthropicCall {
611 t.Helper()
612
613 select {
614 case call := <-calls:
615 return call
616 case <-time.After(2 * time.Second):
617 t.Fatal("timed out waiting for Anthropic request")
618 return anthropicCall{}
619 }
620}
621
622func assertNoAnthropicCall(t *testing.T, calls <-chan anthropicCall) {
623 t.Helper()
624
625 select {
626 case call := <-calls:
627 t.Fatalf("expected no Anthropic API call, but got %s %s", call.method, call.path)
628 case <-time.After(200 * time.Millisecond):
629 }
630}
631
632func requireAnthropicEffort(t *testing.T, body map[string]any, expected Effort) {
633 t.Helper()
634
635 outputConfig, ok := body["output_config"].(map[string]any)
636 thinking, ok := body["thinking"].(map[string]any)
637 require.True(t, ok)
638 require.Equal(t, string(expected), outputConfig["effort"])
639 require.Equal(t, "adaptive", thinking["type"])
640}
641
642func testPrompt() fantasy.Prompt {
643 return fantasy.Prompt{
644 {
645 Role: fantasy.MessageRoleUser,
646 Content: []fantasy.MessagePart{
647 fantasy.TextPart{Text: "Hello"},
648 },
649 },
650 }
651}
652
// mockAnthropicGenerateResponse builds a minimal non-streaming Messages API
// reply. It contains a single "Hi there" text block, stop reason end_turn,
// and a usage section reporting 5 input and 2 output tokens. A new map is
// built on every call, so callers may mutate the result.
func mockAnthropicGenerateResponse() map[string]any {
	textBlock := map[string]any{"type": "text", "text": "Hi there"}

	usage := map[string]any{
		"cache_creation": map[string]any{
			"ephemeral_1h_input_tokens": 0,
			"ephemeral_5m_input_tokens": 0,
		},
		"cache_creation_input_tokens": 0,
		"cache_read_input_tokens":     0,
		"input_tokens":                5,
		"output_tokens":               2,
		"server_tool_use":             map[string]any{"web_search_requests": 0},
		"service_tier":                "standard",
	}

	return map[string]any{
		"id":            "msg_01Test",
		"type":          "message",
		"role":          "assistant",
		"model":         "claude-sonnet-4-20250514",
		"content":       []any{textBlock},
		"stop_reason":   "end_turn",
		"stop_sequence": "",
		"usage":         usage,
	}
}
683
// mockAnthropicWebSearchResponse builds a Messages API reply for one
// server-side web_search turn. The content has three blocks: a
// server_tool_use call, its web_search_tool_result with two encrypted
// results (only the first has a page_age), and a final text answer. Usage
// reports one web search request. Every nested map is freshly allocated on
// each call.
func mockAnthropicWebSearchResponse() map[string]any {
	directCaller := func() map[string]any {
		return map[string]any{"type": "direct"}
	}
	searchResult := func(url, title, encrypted, pageAge string) map[string]any {
		return map[string]any{
			"type":              "web_search_result",
			"url":               url,
			"title":             title,
			"encrypted_content": encrypted,
			"page_age":          pageAge,
		}
	}

	toolUse := map[string]any{
		"type":   "server_tool_use",
		"id":     "srvtoolu_01",
		"name":   "web_search",
		"input":  map[string]any{"query": "latest AI news"},
		"caller": directCaller(),
	}
	toolResult := map[string]any{
		"type":        "web_search_tool_result",
		"tool_use_id": "srvtoolu_01",
		"caller":      directCaller(),
		"content": []any{
			searchResult("https://example.com/ai-news", "Latest AI News", "encrypted_abc123", "2 hours ago"),
			searchResult("https://example.com/ml-update", "ML Update", "encrypted_def456", ""),
		},
	}
	answer := map[string]any{
		"type": "text",
		"text": "Based on recent search results, here is the latest AI news.",
	}

	return map[string]any{
		"id":            "msg_01WebSearch",
		"type":          "message",
		"role":          "assistant",
		"model":         "claude-sonnet-4-20250514",
		"content":       []any{toolUse, toolResult, answer},
		"stop_reason":   "end_turn",
		"stop_sequence": nil,
		"usage": map[string]any{
			"input_tokens":                100,
			"output_tokens":               50,
			"cache_creation_input_tokens": 0,
			"cache_read_input_tokens":     0,
			"server_tool_use":             map[string]any{"web_search_requests": 1},
		},
	}
}
737
738func TestToPrompt_WebSearchProviderExecutedToolResults(t *testing.T) {
739 t.Parallel()
740
741 prompt := fantasy.Prompt{
742 // User message.
743 {
744 Role: fantasy.MessageRoleUser,
745 Content: []fantasy.MessagePart{
746 fantasy.TextPart{Text: "Search for the latest AI news"},
747 },
748 },
749 // Assistant message with a provider-executed tool call, its
750 // result, and trailing text. toResponseMessages routes
751 // provider-executed results into the assistant message, so
752 // the prompt already reflects that structure.
753 {
754 Role: fantasy.MessageRoleAssistant,
755 Content: []fantasy.MessagePart{
756 fantasy.ToolCallPart{
757 ToolCallID: "srvtoolu_01",
758 ToolName: "web_search",
759 Input: `{"query":"latest AI news"}`,
760 ProviderExecuted: true,
761 },
762 fantasy.ToolResultPart{
763 ToolCallID: "srvtoolu_01",
764 ProviderExecuted: true,
765 ProviderOptions: fantasy.ProviderOptions{
766 Name: &WebSearchResultMetadata{
767 Results: []WebSearchResultItem{
768 {
769 URL: "https://example.com/ai-news",
770 Title: "Latest AI News",
771 EncryptedContent: "encrypted_abc123",
772 PageAge: "2 hours ago",
773 },
774 {
775 URL: "https://example.com/ml-update",
776 Title: "ML Update",
777 EncryptedContent: "encrypted_def456",
778 },
779 },
780 },
781 },
782 },
783 fantasy.TextPart{Text: "Here is what I found."},
784 },
785 },
786 }
787
788 _, messages, warnings := toPrompt(prompt, true)
789
790 // No warnings expected; the provider-executed result is in the
791 // assistant message so there is no empty tool message to drop.
792 require.Empty(t, warnings)
793
794 // We should have a user message and an assistant message.
795 require.Len(t, messages, 2, "expected user + assistant messages")
796
797 assistantMsg := messages[1]
798 require.Len(t, assistantMsg.Content, 3,
799 "expected server_tool_use + web_search_tool_result + text")
800
801 // First content block: reconstructed server_tool_use.
802 serverToolUse := assistantMsg.Content[0]
803 require.NotNil(t, serverToolUse.OfServerToolUse,
804 "first block should be a server_tool_use")
805 require.Equal(t, "srvtoolu_01", serverToolUse.OfServerToolUse.ID)
806 require.Equal(t, anthropic.ServerToolUseBlockParamName("web_search"),
807 serverToolUse.OfServerToolUse.Name)
808
809 // Second content block: reconstructed web_search_tool_result with
810 // encrypted_content preserved for multi-turn round-tripping.
811 webResult := assistantMsg.Content[1]
812 require.NotNil(t, webResult.OfWebSearchToolResult,
813 "second block should be a web_search_tool_result")
814 require.Equal(t, "srvtoolu_01", webResult.OfWebSearchToolResult.ToolUseID)
815
816 results := webResult.OfWebSearchToolResult.Content.OfWebSearchToolResultBlockItem
817 require.Len(t, results, 2)
818 require.Equal(t, "https://example.com/ai-news", results[0].URL)
819 require.Equal(t, "Latest AI News", results[0].Title)
820 require.Equal(t, "encrypted_abc123", results[0].EncryptedContent)
821 require.Equal(t, "https://example.com/ml-update", results[1].URL)
822 require.Equal(t, "encrypted_def456", results[1].EncryptedContent)
823 // PageAge should be set for the first result and absent for the second.
824 require.True(t, results[0].PageAge.Valid())
825 require.Equal(t, "2 hours ago", results[0].PageAge.Value)
826 require.False(t, results[1].PageAge.Valid())
827
828 // Third content block: plain text.
829 require.NotNil(t, assistantMsg.Content[2].OfText)
830 require.Equal(t, "Here is what I found.", assistantMsg.Content[2].OfText.Text)
831}
832
833func TestGenerate_WebSearchResponse(t *testing.T) {
834 t.Parallel()
835
836 server, calls := newAnthropicJSONServer(mockAnthropicWebSearchResponse())
837 defer server.Close()
838
839 provider, err := New(
840 WithAPIKey("test-api-key"),
841 WithBaseURL(server.URL),
842 )
843 require.NoError(t, err)
844
845 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
846 require.NoError(t, err)
847
848 resp, err := model.Generate(context.Background(), fantasy.Call{
849 Prompt: testPrompt(),
850 Tools: []fantasy.Tool{
851 WebSearchTool(nil),
852 },
853 })
854 require.NoError(t, err)
855
856 call := awaitAnthropicCall(t, calls)
857 require.Equal(t, "POST", call.method)
858 require.Equal(t, "/v1/messages", call.path)
859
860 // Walk the response content and categorise each item.
861 var (
862 toolCalls []fantasy.ToolCallContent
863 sources []fantasy.SourceContent
864 toolResults []fantasy.ToolResultContent
865 texts []fantasy.TextContent
866 )
867 for _, c := range resp.Content {
868 switch v := c.(type) {
869 case fantasy.ToolCallContent:
870 toolCalls = append(toolCalls, v)
871 case fantasy.SourceContent:
872 sources = append(sources, v)
873 case fantasy.ToolResultContent:
874 toolResults = append(toolResults, v)
875 case fantasy.TextContent:
876 texts = append(texts, v)
877 }
878 }
879
880 // ToolCallContent for the provider-executed web_search.
881 require.Len(t, toolCalls, 1)
882 require.True(t, toolCalls[0].ProviderExecuted)
883 require.Equal(t, "web_search", toolCalls[0].ToolName)
884 require.Equal(t, "srvtoolu_01", toolCalls[0].ToolCallID)
885
886 // SourceContent entries for each search result.
887 require.Len(t, sources, 2)
888 require.Equal(t, "https://example.com/ai-news", sources[0].URL)
889 require.Equal(t, "Latest AI News", sources[0].Title)
890 require.Equal(t, fantasy.SourceTypeURL, sources[0].SourceType)
891 require.Equal(t, "https://example.com/ml-update", sources[1].URL)
892 require.Equal(t, "ML Update", sources[1].Title)
893
894 // ToolResultContent with provider metadata preserving encrypted_content.
895 require.Len(t, toolResults, 1)
896 require.True(t, toolResults[0].ProviderExecuted)
897 require.Equal(t, "web_search", toolResults[0].ToolName)
898 require.Equal(t, "srvtoolu_01", toolResults[0].ToolCallID)
899
900 searchMeta, ok := toolResults[0].ProviderMetadata[Name]
901 require.True(t, ok, "providerMetadata should contain anthropic key")
902 webMeta, ok := searchMeta.(*WebSearchResultMetadata)
903 require.True(t, ok, "metadata should be *WebSearchResultMetadata")
904 require.Len(t, webMeta.Results, 2)
905 require.Equal(t, "encrypted_abc123", webMeta.Results[0].EncryptedContent)
906 require.Equal(t, "encrypted_def456", webMeta.Results[1].EncryptedContent)
907 require.Equal(t, "2 hours ago", webMeta.Results[0].PageAge)
908
909 // TextContent with the final answer.
910 require.Len(t, texts, 1)
911 require.Equal(t,
912 "Based on recent search results, here is the latest AI news.",
913 texts[0].Text,
914 )
915}
916
917func TestGenerate_WebSearchToolInRequest(t *testing.T) {
918 t.Parallel()
919
920 t.Run("basic web_search tool", func(t *testing.T) {
921 t.Parallel()
922
923 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
924 defer server.Close()
925
926 provider, err := New(
927 WithAPIKey("test-api-key"),
928 WithBaseURL(server.URL),
929 )
930 require.NoError(t, err)
931
932 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
933 require.NoError(t, err)
934
935 _, err = model.Generate(context.Background(), fantasy.Call{
936 Prompt: testPrompt(),
937 Tools: []fantasy.Tool{
938 WebSearchTool(nil),
939 },
940 })
941 require.NoError(t, err)
942
943 call := awaitAnthropicCall(t, calls)
944 tools, ok := call.body["tools"].([]any)
945 require.True(t, ok, "request body should have tools array")
946 require.Len(t, tools, 1)
947
948 tool, ok := tools[0].(map[string]any)
949 require.True(t, ok)
950 require.Equal(t, "web_search_20250305", tool["type"])
951 })
952
953 t.Run("with allowed_domains and blocked_domains", func(t *testing.T) {
954 t.Parallel()
955
956 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
957 defer server.Close()
958
959 provider, err := New(
960 WithAPIKey("test-api-key"),
961 WithBaseURL(server.URL),
962 )
963 require.NoError(t, err)
964
965 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
966 require.NoError(t, err)
967
968 _, err = model.Generate(context.Background(), fantasy.Call{
969 Prompt: testPrompt(),
970 Tools: []fantasy.Tool{
971 WebSearchTool(&WebSearchToolOptions{
972 AllowedDomains: []string{"example.com", "test.com"},
973 }),
974 },
975 })
976 require.NoError(t, err)
977
978 call := awaitAnthropicCall(t, calls)
979 tools, ok := call.body["tools"].([]any)
980 require.True(t, ok)
981 require.Len(t, tools, 1)
982
983 tool, ok := tools[0].(map[string]any)
984 require.True(t, ok)
985 require.Equal(t, "web_search_20250305", tool["type"])
986
987 domains, ok := tool["allowed_domains"].([]any)
988 require.True(t, ok, "tool should have allowed_domains")
989 require.Len(t, domains, 2)
990 require.Equal(t, "example.com", domains[0])
991 require.Equal(t, "test.com", domains[1])
992 })
993
994 t.Run("with max uses and user location", func(t *testing.T) {
995 t.Parallel()
996
997 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
998 defer server.Close()
999
1000 provider, err := New(
1001 WithAPIKey("test-api-key"),
1002 WithBaseURL(server.URL),
1003 )
1004 require.NoError(t, err)
1005
1006 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1007 require.NoError(t, err)
1008
1009 _, err = model.Generate(context.Background(), fantasy.Call{
1010 Prompt: testPrompt(),
1011 Tools: []fantasy.Tool{
1012 WebSearchTool(&WebSearchToolOptions{
1013 MaxUses: 5,
1014 UserLocation: &UserLocation{
1015 City: "San Francisco",
1016 Country: "US",
1017 },
1018 }),
1019 },
1020 })
1021 require.NoError(t, err)
1022
1023 call := awaitAnthropicCall(t, calls)
1024 tools, ok := call.body["tools"].([]any)
1025 require.True(t, ok)
1026 require.Len(t, tools, 1)
1027
1028 tool, ok := tools[0].(map[string]any)
1029 require.True(t, ok)
1030 require.Equal(t, "web_search_20250305", tool["type"])
1031
1032 // max_uses is serialized as a JSON number; json.Unmarshal
1033 // into map[string]any decodes numbers as float64.
1034 maxUses, ok := tool["max_uses"].(float64)
1035 require.True(t, ok, "tool should have max_uses")
1036 require.Equal(t, float64(5), maxUses)
1037
1038 userLoc, ok := tool["user_location"].(map[string]any)
1039 require.True(t, ok, "tool should have user_location")
1040 require.Equal(t, "San Francisco", userLoc["city"])
1041 require.Equal(t, "US", userLoc["country"])
1042 require.Equal(t, "approximate", userLoc["type"])
1043 })
1044
1045 t.Run("with max uses", func(t *testing.T) {
1046 t.Parallel()
1047
1048 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1049 defer server.Close()
1050
1051 provider, err := New(
1052 WithAPIKey("test-api-key"),
1053 WithBaseURL(server.URL),
1054 )
1055 require.NoError(t, err)
1056
1057 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1058 require.NoError(t, err)
1059
1060 _, err = model.Generate(context.Background(), fantasy.Call{
1061 Prompt: testPrompt(),
1062 Tools: []fantasy.Tool{
1063 WebSearchTool(&WebSearchToolOptions{
1064 MaxUses: 3,
1065 }),
1066 },
1067 })
1068 require.NoError(t, err)
1069
1070 call := awaitAnthropicCall(t, calls)
1071 tools, ok := call.body["tools"].([]any)
1072 require.True(t, ok)
1073 require.Len(t, tools, 1)
1074
1075 tool, ok := tools[0].(map[string]any)
1076 require.True(t, ok)
1077 require.Equal(t, "web_search_20250305", tool["type"])
1078
1079 maxUses, ok := tool["max_uses"].(float64)
1080 require.True(t, ok, "tool should have max_uses")
1081 require.Equal(t, float64(3), maxUses)
1082 })
1083
1084 t.Run("with json-round-tripped provider tool args", func(t *testing.T) {
1085 t.Parallel()
1086
1087 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1088 defer server.Close()
1089
1090 provider, err := New(
1091 WithAPIKey("test-api-key"),
1092 WithBaseURL(server.URL),
1093 )
1094 require.NoError(t, err)
1095
1096 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1097 require.NoError(t, err)
1098
1099 baseTool := WebSearchTool(&WebSearchToolOptions{
1100 MaxUses: 7,
1101 BlockedDomains: []string{"example.com", "test.com"},
1102 UserLocation: &UserLocation{
1103 City: "San Francisco",
1104 Region: "CA",
1105 Country: "US",
1106 Timezone: "America/Los_Angeles",
1107 },
1108 })
1109
1110 data, err := json.Marshal(baseTool)
1111 require.NoError(t, err)
1112
1113 var roundTripped fantasy.ProviderDefinedTool
1114 err = json.Unmarshal(data, &roundTripped)
1115 require.NoError(t, err)
1116
1117 _, err = model.Generate(context.Background(), fantasy.Call{
1118 Prompt: testPrompt(),
1119 Tools: []fantasy.Tool{roundTripped},
1120 })
1121 require.NoError(t, err)
1122
1123 call := awaitAnthropicCall(t, calls)
1124 tools, ok := call.body["tools"].([]any)
1125 require.True(t, ok)
1126 require.Len(t, tools, 1)
1127
1128 tool, ok := tools[0].(map[string]any)
1129 require.True(t, ok)
1130 require.Equal(t, "web_search_20250305", tool["type"])
1131
1132 domains, ok := tool["blocked_domains"].([]any)
1133 require.True(t, ok, "tool should have blocked_domains")
1134 require.Len(t, domains, 2)
1135 require.Equal(t, "example.com", domains[0])
1136 require.Equal(t, "test.com", domains[1])
1137
1138 maxUses, ok := tool["max_uses"].(float64)
1139 require.True(t, ok, "tool should have max_uses")
1140 require.Equal(t, float64(7), maxUses)
1141
1142 userLoc, ok := tool["user_location"].(map[string]any)
1143 require.True(t, ok, "tool should have user_location")
1144 require.Equal(t, "San Francisco", userLoc["city"])
1145 require.Equal(t, "CA", userLoc["region"])
1146 require.Equal(t, "US", userLoc["country"])
1147 require.Equal(t, "America/Los_Angeles", userLoc["timezone"])
1148 require.Equal(t, "approximate", userLoc["type"])
1149 })
1150}
1151
1152func TestAnyToStringSlice(t *testing.T) {
1153 t.Parallel()
1154
1155 t.Run("from string slice", func(t *testing.T) {
1156 t.Parallel()
1157
1158 got := anyToStringSlice([]string{"example.com", ""})
1159 require.Equal(t, []string{"example.com", ""}, got)
1160 })
1161
1162 t.Run("from any slice filters non-strings and empty", func(t *testing.T) {
1163 t.Parallel()
1164
1165 got := anyToStringSlice([]any{"example.com", 123, "", "test.com"})
1166 require.Equal(t, []string{"example.com", "test.com"}, got)
1167 })
1168
1169 t.Run("unsupported type", func(t *testing.T) {
1170 t.Parallel()
1171
1172 got := anyToStringSlice("example.com")
1173 require.Nil(t, got)
1174 })
1175}
1176
1177func TestAnyToInt64(t *testing.T) {
1178 t.Parallel()
1179
1180 tests := []struct {
1181 name string
1182 input any
1183 want int64
1184 wantOK bool
1185 }{
1186 {name: "int64", input: int64(7), want: 7, wantOK: true},
1187 {name: "float64 integer", input: float64(7), want: 7, wantOK: true},
1188 {name: "float32 integer", input: float32(9), want: 9, wantOK: true},
1189 {name: "float64 non-integer", input: float64(7.5), wantOK: false},
1190 {name: "float64 max exact int ok", input: float64(1<<53 - 1), want: 1<<53 - 1, wantOK: true},
1191 {name: "float64 over max exact int", input: float64(1 << 53), wantOK: false},
1192 {name: "json number int", input: json.Number("42"), want: 42, wantOK: true},
1193 {name: "json number float", input: json.Number("4.2"), wantOK: false},
1194 {name: "nan", input: math.NaN(), wantOK: false},
1195 {name: "inf", input: math.Inf(1), wantOK: false},
1196 {name: "uint64 overflow", input: uint64(math.MaxInt64) + 1, wantOK: false},
1197 }
1198
1199 for _, tt := range tests {
1200 t.Run(tt.name, func(t *testing.T) {
1201 got, ok := anyToInt64(tt.input)
1202 require.Equal(t, tt.wantOK, ok)
1203 if tt.wantOK {
1204 require.Equal(t, tt.want, got)
1205 }
1206 })
1207 }
1208}
1209
1210func TestAnyToUserLocation(t *testing.T) {
1211 t.Parallel()
1212
1213 t.Run("pointer passthrough", func(t *testing.T) {
1214 t.Parallel()
1215
1216 input := &UserLocation{City: "San Francisco", Country: "US"}
1217 got := anyToUserLocation(input)
1218 require.Same(t, input, got)
1219 })
1220
1221 t.Run("struct value", func(t *testing.T) {
1222 t.Parallel()
1223
1224 got := anyToUserLocation(UserLocation{City: "San Francisco", Country: "US"})
1225 require.NotNil(t, got)
1226 require.Equal(t, "San Francisco", got.City)
1227 require.Equal(t, "US", got.Country)
1228 })
1229
1230 t.Run("map value", func(t *testing.T) {
1231 t.Parallel()
1232
1233 got := anyToUserLocation(map[string]any{
1234 "city": "San Francisco",
1235 "region": "CA",
1236 "country": "US",
1237 "timezone": "America/Los_Angeles",
1238 "type": "approximate",
1239 })
1240 require.NotNil(t, got)
1241 require.Equal(t, "San Francisco", got.City)
1242 require.Equal(t, "CA", got.Region)
1243 require.Equal(t, "US", got.Country)
1244 require.Equal(t, "America/Los_Angeles", got.Timezone)
1245 })
1246
1247 t.Run("empty map", func(t *testing.T) {
1248 t.Parallel()
1249
1250 got := anyToUserLocation(map[string]any{"type": "approximate"})
1251 require.Nil(t, got)
1252 })
1253
1254 t.Run("unsupported type", func(t *testing.T) {
1255 t.Parallel()
1256
1257 got := anyToUserLocation("San Francisco")
1258 require.Nil(t, got)
1259 })
1260}
1261
1262func TestStream_WebSearchResponse(t *testing.T) {
1263 t.Parallel()
1264
1265 // Build SSE chunks that simulate a web search streaming response.
1266 // The Anthropic SDK accumulates content blocks via
1267 // acc.Accumulate(event). We read the Content and ToolUseID
1268 // directly from struct fields instead of using AsAny(), which
1269 // avoids the SDK's re-marshal limitation that previously dropped
1270 // source data.
1271 webSearchResultContent, _ := json.Marshal([]any{
1272 map[string]any{
1273 "type": "web_search_result",
1274 "url": "https://example.com/ai-news",
1275 "title": "Latest AI News",
1276 "encrypted_content": "encrypted_abc123",
1277 "page_age": "2 hours ago",
1278 },
1279 })
1280
1281 chunks := []string{
1282 // message_start
1283 "event: message_start\n",
1284 `data: {"type":"message_start","message":{"id":"msg_01WebSearch","type":"message","role":"assistant","model":"claude-sonnet-4-20250514","content":[],"stop_reason":null,"usage":{"input_tokens":100,"output_tokens":0}}}` + "\n\n",
1285 // Block 0: server_tool_use
1286 "event: content_block_start\n",
1287 `data: {"type":"content_block_start","index":0,"content_block":{"type":"server_tool_use","id":"srvtoolu_01","name":"web_search","input":{}}}` + "\n\n",
1288 "event: content_block_stop\n",
1289 `data: {"type":"content_block_stop","index":0}` + "\n\n",
1290 // Block 1: web_search_tool_result
1291 "event: content_block_start\n",
1292 `data: {"type":"content_block_start","index":1,"content_block":{"type":"web_search_tool_result","tool_use_id":"srvtoolu_01","content":` + string(webSearchResultContent) + `}}` + "\n\n",
1293 "event: content_block_stop\n",
1294 `data: {"type":"content_block_stop","index":1}` + "\n\n",
1295 // Block 2: text
1296 "event: content_block_start\n",
1297 `data: {"type":"content_block_start","index":2,"content_block":{"type":"text","text":""}}` + "\n\n",
1298 "event: content_block_delta\n",
1299 `data: {"type":"content_block_delta","index":2,"delta":{"type":"text_delta","text":"Here are the results."}}` + "\n\n",
1300 "event: content_block_stop\n",
1301 `data: {"type":"content_block_stop","index":2}` + "\n\n",
1302 // message_stop
1303 "event: message_stop\n",
1304 `data: {"type":"message_stop"}` + "\n\n",
1305 }
1306
1307 server, calls := newAnthropicStreamingServer(chunks)
1308 defer server.Close()
1309
1310 provider, err := New(
1311 WithAPIKey("test-api-key"),
1312 WithBaseURL(server.URL),
1313 )
1314 require.NoError(t, err)
1315
1316 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1317 require.NoError(t, err)
1318
1319 stream, err := model.Stream(context.Background(), fantasy.Call{
1320 Prompt: testPrompt(),
1321 Tools: []fantasy.Tool{
1322 WebSearchTool(nil),
1323 },
1324 })
1325 require.NoError(t, err)
1326
1327 var parts []fantasy.StreamPart
1328 stream(func(part fantasy.StreamPart) bool {
1329 parts = append(parts, part)
1330 return true
1331 })
1332
1333 _ = awaitAnthropicCall(t, calls)
1334
1335 // Collect parts by type for assertions.
1336 var (
1337 toolInputStarts []fantasy.StreamPart
1338 toolCalls []fantasy.StreamPart
1339 toolResults []fantasy.StreamPart
1340 sourceParts []fantasy.StreamPart
1341 textDeltas []fantasy.StreamPart
1342 )
1343 for _, p := range parts {
1344 switch p.Type {
1345 case fantasy.StreamPartTypeToolInputStart:
1346 toolInputStarts = append(toolInputStarts, p)
1347 case fantasy.StreamPartTypeToolCall:
1348 toolCalls = append(toolCalls, p)
1349 case fantasy.StreamPartTypeToolResult:
1350 toolResults = append(toolResults, p)
1351 case fantasy.StreamPartTypeSource:
1352 sourceParts = append(sourceParts, p)
1353 case fantasy.StreamPartTypeTextDelta:
1354 textDeltas = append(textDeltas, p)
1355 }
1356 }
1357
1358 // server_tool_use emits a ToolInputStart with ProviderExecuted.
1359 require.NotEmpty(t, toolInputStarts, "should have a tool input start")
1360 require.True(t, toolInputStarts[0].ProviderExecuted)
1361 require.Equal(t, "web_search", toolInputStarts[0].ToolCallName)
1362
1363 // server_tool_use emits a ToolCall with ProviderExecuted.
1364 require.NotEmpty(t, toolCalls, "should have a tool call")
1365 require.True(t, toolCalls[0].ProviderExecuted)
1366
1367 // web_search_tool_result always emits a ToolResult even when
1368 // the SDK drops source data. The ToolUseID comes from the raw
1369 // union field as a fallback.
1370 require.NotEmpty(t, toolResults, "should have a tool result")
1371 require.True(t, toolResults[0].ProviderExecuted)
1372 require.Equal(t, "web_search", toolResults[0].ToolCallName)
1373 require.Equal(t, "srvtoolu_01", toolResults[0].ID,
1374 "tool result ID should match the tool_use_id")
1375
1376 // Source parts are now correctly emitted by reading struct fields
1377 // directly instead of using AsAny().
1378 require.Len(t, sourceParts, 1)
1379 require.Equal(t, "https://example.com/ai-news", sourceParts[0].URL)
1380 require.Equal(t, "Latest AI News", sourceParts[0].Title)
1381 require.Equal(t, fantasy.SourceTypeURL, sourceParts[0].SourceType)
1382
1383 // Text block emits a text delta.
1384 require.NotEmpty(t, textDeltas, "should have text deltas")
1385 require.Equal(t, "Here are the results.", textDeltas[0].Delta)
1386}
1387
1388func TestGenerate_ToolChoiceNone(t *testing.T) {
1389 t.Parallel()
1390
1391 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1392 defer server.Close()
1393
1394 provider, err := New(
1395 WithAPIKey("test-api-key"),
1396 WithBaseURL(server.URL),
1397 )
1398 require.NoError(t, err)
1399
1400 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1401 require.NoError(t, err)
1402
1403 toolChoiceNone := fantasy.ToolChoiceNone
1404 _, err = model.Generate(context.Background(), fantasy.Call{
1405 Prompt: testPrompt(),
1406 Tools: []fantasy.Tool{
1407 WebSearchTool(nil),
1408 },
1409 ToolChoice: &toolChoiceNone,
1410 })
1411 require.NoError(t, err)
1412
1413 call := awaitAnthropicCall(t, calls)
1414 toolChoice, ok := call.body["tool_choice"].(map[string]any)
1415 require.True(t, ok, "request body should have tool_choice")
1416 require.Equal(t, "none", toolChoice["type"], "tool_choice should be 'none'")
1417}
1418
1419// --- Computer Use Tests ---
1420
1421// jsonRoundTripTool simulates a JSON round-trip on a
1422// ProviderDefinedTool so that its Args map contains float64
1423// values (as json.Unmarshal produces) rather than the int64
1424// values that NewComputerUseTool stores directly. The
1425// production toBetaTools code asserts float64.
1426func jsonRoundTripTool(t *testing.T, tool fantasy.ExecutableProviderTool) fantasy.ProviderDefinedTool {
1427 t.Helper()
1428 pdt := tool.Definition()
1429 data, err := json.Marshal(pdt.Args)
1430 require.NoError(t, err)
1431 var args map[string]any
1432 require.NoError(t, json.Unmarshal(data, &args))
1433 pdt.Args = args
1434 return pdt
1435}
1436
1437func TestNewComputerUseTool(t *testing.T) {
1438 t.Parallel()
1439
1440 t.Run("creates tool with correct ID and name", func(t *testing.T) {
1441 t.Parallel()
1442 tool := NewComputerUseTool(ComputerUseToolOptions{
1443 DisplayWidthPx: 1920,
1444 DisplayHeightPx: 1080,
1445 ToolVersion: ComputerUse20250124,
1446 }, noopComputerRun).Definition()
1447 require.Equal(t, "anthropic.computer", tool.ID)
1448 require.Equal(t, "computer", tool.Name)
1449 require.Equal(t, int64(1920), tool.Args["display_width_px"])
1450 require.Equal(t, int64(1080), tool.Args["display_height_px"])
1451 require.Equal(t, string(ComputerUse20250124), tool.Args["tool_version"])
1452 })
1453
1454 t.Run("includes optional fields when set", func(t *testing.T) {
1455 t.Parallel()
1456 displayNum := int64(1)
1457 enableZoom := true
1458 tool := NewComputerUseTool(ComputerUseToolOptions{
1459 DisplayWidthPx: 1024,
1460 DisplayHeightPx: 768,
1461 DisplayNumber: &displayNum,
1462 EnableZoom: &enableZoom,
1463 ToolVersion: ComputerUse20251124,
1464 CacheControl: &CacheControl{Type: "ephemeral"},
1465 }, noopComputerRun).Definition()
1466 require.Equal(t, int64(1), tool.Args["display_number"])
1467 require.Equal(t, true, tool.Args["enable_zoom"])
1468 require.NotNil(t, tool.Args["cache_control"])
1469 })
1470
1471 t.Run("omits optional fields when nil", func(t *testing.T) {
1472 t.Parallel()
1473 tool := NewComputerUseTool(ComputerUseToolOptions{
1474 DisplayWidthPx: 1920,
1475 DisplayHeightPx: 1080,
1476 ToolVersion: ComputerUse20250124,
1477 }, noopComputerRun).Definition()
1478 _, hasDisplayNum := tool.Args["display_number"]
1479 _, hasEnableZoom := tool.Args["enable_zoom"]
1480 _, hasCacheControl := tool.Args["cache_control"]
1481 require.False(t, hasDisplayNum)
1482 require.False(t, hasEnableZoom)
1483 require.False(t, hasCacheControl)
1484 })
1485}
1486
1487func TestIsComputerUseTool(t *testing.T) {
1488 t.Parallel()
1489
1490 t.Run("returns true for computer use tool", func(t *testing.T) {
1491 t.Parallel()
1492 tool := NewComputerUseTool(ComputerUseToolOptions{
1493 DisplayWidthPx: 1920,
1494 DisplayHeightPx: 1080,
1495 ToolVersion: ComputerUse20250124,
1496 }, noopComputerRun)
1497 require.True(t, IsComputerUseTool(tool.Definition()))
1498 })
1499
1500 t.Run("returns false for function tool", func(t *testing.T) {
1501 t.Parallel()
1502 tool := fantasy.FunctionTool{
1503 Name: "test",
1504 Description: "test tool",
1505 }
1506 require.False(t, IsComputerUseTool(tool))
1507 })
1508
1509 t.Run("returns false for other provider defined tool", func(t *testing.T) {
1510 t.Parallel()
1511 tool := fantasy.ProviderDefinedTool{
1512 ID: "other.tool",
1513 Name: "other",
1514 }
1515 require.False(t, IsComputerUseTool(tool))
1516 })
1517}
1518
1519func TestNeedsBetaAPI(t *testing.T) {
1520 t.Parallel()
1521
1522 lm := languageModel{options: options{}}
1523
1524 t.Run("returns false for empty tools", func(t *testing.T) {
1525 t.Parallel()
1526 _, _, _, betaFlags := lm.toTools(nil, nil, false)
1527 require.Empty(t, betaFlags)
1528 _, _, _, betaFlags = lm.toTools([]fantasy.Tool{}, nil, false)
1529 require.Empty(t, betaFlags)
1530 })
1531
1532 t.Run("returns false for only function tools", func(t *testing.T) {
1533 t.Parallel()
1534 tools := []fantasy.Tool{
1535 fantasy.FunctionTool{Name: "test"},
1536 }
1537 _, _, _, betaFlags := lm.toTools(tools, nil, false)
1538 require.Empty(t, betaFlags)
1539 })
1540
1541 t.Run("returns beta flags when computer use tool present", func(t *testing.T) {
1542 t.Parallel()
1543 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1544 DisplayWidthPx: 1920,
1545 DisplayHeightPx: 1080,
1546 ToolVersion: ComputerUse20250124,
1547 }, noopComputerRun))
1548 tools := []fantasy.Tool{
1549 fantasy.FunctionTool{Name: "test"},
1550 cuTool,
1551 }
1552 _, _, _, betaFlags := lm.toTools(tools, nil, false)
1553 require.NotEmpty(t, betaFlags)
1554 })
1555}
1556
1557func TestComputerUseToolJSON(t *testing.T) {
1558 t.Parallel()
1559
1560 t.Run("builds JSON for version 20250124", func(t *testing.T) {
1561 t.Parallel()
1562 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1563 DisplayWidthPx: 1920,
1564 DisplayHeightPx: 1080,
1565 ToolVersion: ComputerUse20250124,
1566 }, noopComputerRun))
1567 data, err := computerUseToolJSON(cuTool)
1568 require.NoError(t, err)
1569 var m map[string]any
1570 require.NoError(t, json.Unmarshal(data, &m))
1571 require.Equal(t, "computer_20250124", m["type"])
1572 require.Equal(t, "computer", m["name"])
1573 require.InDelta(t, 1920, m["display_width_px"], 0)
1574 require.InDelta(t, 1080, m["display_height_px"], 0)
1575 })
1576
1577 t.Run("builds JSON for version 20251124 with enable_zoom", func(t *testing.T) {
1578 t.Parallel()
1579 enableZoom := true
1580 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1581 DisplayWidthPx: 1024,
1582 DisplayHeightPx: 768,
1583 EnableZoom: &enableZoom,
1584 ToolVersion: ComputerUse20251124,
1585 }, noopComputerRun))
1586 data, err := computerUseToolJSON(cuTool)
1587 require.NoError(t, err)
1588 var m map[string]any
1589 require.NoError(t, json.Unmarshal(data, &m))
1590 require.Equal(t, "computer_20251124", m["type"])
1591 require.Equal(t, true, m["enable_zoom"])
1592 })
1593
1594 t.Run("handles int64 args without JSON round-trip", func(t *testing.T) {
1595 t.Parallel()
1596 // Direct construction stores int64 values.
1597 cuTool := NewComputerUseTool(ComputerUseToolOptions{
1598 DisplayWidthPx: 1920,
1599 DisplayHeightPx: 1080,
1600 ToolVersion: ComputerUse20250124,
1601 }, noopComputerRun)
1602 data, err := computerUseToolJSON(cuTool.Definition())
1603 require.NoError(t, err)
1604 var m map[string]any
1605 require.NoError(t, json.Unmarshal(data, &m))
1606 require.InDelta(t, 1920, m["display_width_px"], 0)
1607 })
1608
1609 t.Run("returns error when version is missing", func(t *testing.T) {
1610 t.Parallel()
1611 pdt := fantasy.ProviderDefinedTool{
1612 ID: "anthropic.computer",
1613 Name: "computer",
1614 Args: map[string]any{
1615 "display_width_px": float64(1920),
1616 "display_height_px": float64(1080),
1617 },
1618 }
1619 _, err := computerUseToolJSON(pdt)
1620 require.Error(t, err)
1621 require.Contains(t, err.Error(), "tool_version arg is missing")
1622 })
1623
1624 t.Run("returns error for unsupported version", func(t *testing.T) {
1625 t.Parallel()
1626 pdt := fantasy.ProviderDefinedTool{
1627 ID: "anthropic.computer",
1628 Name: "computer",
1629 Args: map[string]any{
1630 "display_width_px": float64(1920),
1631 "display_height_px": float64(1080),
1632 "tool_version": "computer_99991231",
1633 },
1634 }
1635 _, err := computerUseToolJSON(pdt)
1636 require.Error(t, err)
1637 require.Contains(t, err.Error(), "unsupported")
1638 })
1639}
1640
1641func TestParseComputerUseInput_CoordinateValidation(t *testing.T) {
1642 t.Parallel()
1643
1644 t.Run("rejects coordinate with 1 element", func(t *testing.T) {
1645 t.Parallel()
1646 _, err := ParseComputerUseInput(`{"action":"left_click","coordinate":[100]}`)
1647 require.Error(t, err)
1648 require.Contains(t, err.Error(), "coordinate")
1649 })
1650
1651 t.Run("rejects coordinate with 3 elements", func(t *testing.T) {
1652 t.Parallel()
1653 _, err := ParseComputerUseInput(`{"action":"left_click","coordinate":[100,200,300]}`)
1654 require.Error(t, err)
1655 require.Contains(t, err.Error(), "coordinate")
1656 })
1657
1658 t.Run("rejects start_coordinate with 1 element", func(t *testing.T) {
1659 t.Parallel()
1660 _, err := ParseComputerUseInput(`{"action":"left_click_drag","coordinate":[100,200],"start_coordinate":[50]}`)
1661 require.Error(t, err)
1662 require.Contains(t, err.Error(), "start_coordinate")
1663 })
1664
1665 t.Run("rejects region with 3 elements", func(t *testing.T) {
1666 t.Parallel()
1667 _, err := ParseComputerUseInput(`{"action":"zoom","region":[10,20,30]}`)
1668 require.Error(t, err)
1669 require.Contains(t, err.Error(), "region")
1670 })
1671
1672 t.Run("accepts valid coordinate", func(t *testing.T) {
1673 t.Parallel()
1674 result, err := ParseComputerUseInput(`{"action":"left_click","coordinate":[100,200]}`)
1675 require.NoError(t, err)
1676 require.Equal(t, [2]int64{100, 200}, result.Coordinate)
1677 })
1678
1679 t.Run("accepts absent optional arrays", func(t *testing.T) {
1680 t.Parallel()
1681 result, err := ParseComputerUseInput(`{"action":"screenshot"}`)
1682 require.NoError(t, err)
1683 require.Equal(t, ActionScreenshot, result.Action)
1684 })
1685}
1686
1687func TestToTools_RawJSON(t *testing.T) {
1688 t.Parallel()
1689
1690 lm := languageModel{options: options{}}
1691
1692 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1693 DisplayWidthPx: 1920,
1694 DisplayHeightPx: 1080,
1695 ToolVersion: ComputerUse20250124,
1696 }, noopComputerRun))
1697
1698 tools := []fantasy.Tool{
1699 fantasy.FunctionTool{
1700 Name: "weather",
1701 Description: "Get weather",
1702 InputSchema: map[string]any{
1703 "properties": map[string]any{
1704 "location": map[string]any{"type": "string"},
1705 },
1706 "required": []string{"location"},
1707 },
1708 },
1709 WebSearchTool(nil),
1710 cuTool,
1711 }
1712
1713 rawTools, toolChoice, warnings, betaFlags := lm.toTools(tools, nil, false)
1714
1715 require.Len(t, rawTools, 3)
1716 require.Nil(t, toolChoice)
1717 require.Empty(t, warnings)
1718 require.NotEmpty(t, betaFlags)
1719
1720 // Verify each raw tool is valid JSON.
1721 for i, raw := range rawTools {
1722 var m map[string]any
1723 require.NoError(t, json.Unmarshal(raw, &m), "tool %d should be valid JSON", i)
1724 }
1725
1726 // Check function tool.
1727 var funcTool map[string]any
1728 require.NoError(t, json.Unmarshal(rawTools[0], &funcTool))
1729 require.Equal(t, "weather", funcTool["name"])
1730
1731 // Check web search tool.
1732 var webTool map[string]any
1733 require.NoError(t, json.Unmarshal(rawTools[1], &webTool))
1734 require.Equal(t, "web_search_20250305", webTool["type"])
1735
1736 // Check computer use tool.
1737 var cuToolJSON map[string]any
1738 require.NoError(t, json.Unmarshal(rawTools[2], &cuToolJSON))
1739 require.Equal(t, "computer_20250124", cuToolJSON["type"])
1740 require.Equal(t, "computer", cuToolJSON["name"])
1741}
1742
1743func TestGenerate_BetaAPI(t *testing.T) {
1744 t.Parallel()
1745
1746 t.Run("sends beta header for computer use", func(t *testing.T) {
1747 t.Parallel()
1748
1749 var capturedHeaders http.Header
1750 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1751 capturedHeaders = r.Header.Clone()
1752 w.Header().Set("Content-Type", "application/json")
1753 _ = json.NewEncoder(w).Encode(mockAnthropicGenerateResponse())
1754 }))
1755 defer server.Close()
1756
1757 provider, err := New(
1758 WithAPIKey("test-api-key"),
1759 WithBaseURL(server.URL),
1760 )
1761 require.NoError(t, err)
1762
1763 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1764 require.NoError(t, err)
1765
1766 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1767 DisplayWidthPx: 1920,
1768 DisplayHeightPx: 1080,
1769 ToolVersion: ComputerUse20250124,
1770 }, noopComputerRun))
1771
1772 _, err = model.Generate(context.Background(), fantasy.Call{
1773 Prompt: testPrompt(),
1774 Tools: []fantasy.Tool{cuTool},
1775 })
1776 require.NoError(t, err)
1777 require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-01-24")
1778 })
1779
1780 t.Run("sends beta header for computer use 20251124", func(t *testing.T) {
1781 t.Parallel()
1782
1783 var capturedHeaders http.Header
1784 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1785 capturedHeaders = r.Header.Clone()
1786 w.Header().Set("Content-Type", "application/json")
1787 _ = json.NewEncoder(w).Encode(mockAnthropicGenerateResponse())
1788 }))
1789 defer server.Close()
1790
1791 provider, err := New(
1792 WithAPIKey("test-api-key"),
1793 WithBaseURL(server.URL),
1794 )
1795 require.NoError(t, err)
1796
1797 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1798 require.NoError(t, err)
1799
1800 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1801 DisplayWidthPx: 1920,
1802 DisplayHeightPx: 1080,
1803 ToolVersion: ComputerUse20251124,
1804 }, noopComputerRun))
1805
1806 _, err = model.Generate(context.Background(), fantasy.Call{
1807 Prompt: testPrompt(),
1808 Tools: []fantasy.Tool{cuTool},
1809 })
1810 require.NoError(t, err)
1811 require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-11-24")
1812 })
1813
1814 t.Run("returns tool use from beta response", func(t *testing.T) {
1815 t.Parallel()
1816
1817 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1818 w.Header().Set("Content-Type", "application/json")
1819 _ = json.NewEncoder(w).Encode(map[string]any{
1820 "id": "msg_01Test",
1821 "type": "message",
1822 "role": "assistant",
1823 "model": "claude-sonnet-4-20250514",
1824 "content": []any{
1825 map[string]any{
1826 "type": "tool_use",
1827 "id": "toolu_01",
1828 "name": "computer",
1829 "input": map[string]any{"action": "screenshot"},
1830 },
1831 },
1832 "stop_reason": "tool_use",
1833 "usage": map[string]any{
1834 "input_tokens": 10,
1835 "output_tokens": 5,
1836 "cache_creation": map[string]any{
1837 "ephemeral_1h_input_tokens": 0,
1838 "ephemeral_5m_input_tokens": 0,
1839 },
1840 "cache_creation_input_tokens": 0,
1841 "cache_read_input_tokens": 0,
1842 "server_tool_use": map[string]any{
1843 "web_search_requests": 0,
1844 },
1845 "service_tier": "standard",
1846 },
1847 })
1848 }))
1849 defer server.Close()
1850
1851 provider, err := New(
1852 WithAPIKey("test-api-key"),
1853 WithBaseURL(server.URL),
1854 )
1855 require.NoError(t, err)
1856
1857 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1858 require.NoError(t, err)
1859
1860 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1861 DisplayWidthPx: 1920,
1862 DisplayHeightPx: 1080,
1863 ToolVersion: ComputerUse20250124,
1864 }, noopComputerRun))
1865
1866 resp, err := model.Generate(context.Background(), fantasy.Call{
1867 Prompt: testPrompt(),
1868 Tools: []fantasy.Tool{cuTool},
1869 })
1870 require.NoError(t, err)
1871
1872 toolCalls := resp.Content.ToolCalls()
1873 require.Len(t, toolCalls, 1)
1874 require.Equal(t, "computer", toolCalls[0].ToolName)
1875 require.Equal(t, "toolu_01", toolCalls[0].ToolCallID)
1876 require.Contains(t, toolCalls[0].Input, "screenshot")
1877 require.Equal(t, fantasy.FinishReasonToolCalls, resp.FinishReason)
1878
1879 // Verify typed parsing works on the tool call input.
1880 parsed, err := ParseComputerUseInput(toolCalls[0].Input)
1881 require.NoError(t, err)
1882 require.Equal(t, ActionScreenshot, parsed.Action)
1883 })
1884}
1885
1886func TestStream_BetaAPI(t *testing.T) {
1887 t.Parallel()
1888
1889 t.Run("streams via beta API for computer use", func(t *testing.T) {
1890 t.Parallel()
1891
1892 var capturedHeaders http.Header
1893 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1894 capturedHeaders = r.Header.Clone()
1895 w.Header().Set("Content-Type", "text/event-stream")
1896 w.Header().Set("Cache-Control", "no-cache")
1897 w.WriteHeader(http.StatusOK)
1898 chunks := []string{
1899 "event: message_start\n",
1900 "data: {\"type\":\"message_start\",\"message\":{}}\n\n",
1901 "event: message_stop\n",
1902 "data: {\"type\":\"message_stop\"}\n\n",
1903 }
1904 for _, chunk := range chunks {
1905 _, _ = fmt.Fprint(w, chunk)
1906 if flusher, ok := w.(http.Flusher); ok {
1907 flusher.Flush()
1908 }
1909 }
1910 }))
1911 defer server.Close()
1912
1913 provider, err := New(
1914 WithAPIKey("test-api-key"),
1915 WithBaseURL(server.URL),
1916 )
1917 require.NoError(t, err)
1918
1919 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1920 require.NoError(t, err)
1921
1922 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1923 DisplayWidthPx: 1920,
1924 DisplayHeightPx: 1080,
1925 ToolVersion: ComputerUse20250124,
1926 }, noopComputerRun))
1927
1928 stream, err := model.Stream(context.Background(), fantasy.Call{
1929 Prompt: testPrompt(),
1930 Tools: []fantasy.Tool{cuTool},
1931 })
1932 require.NoError(t, err)
1933
1934 stream(func(fantasy.StreamPart) bool { return true })
1935
1936 require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-01-24")
1937 })
1938
1939 t.Run("streams via beta API for computer use 20251124", func(t *testing.T) {
1940 t.Parallel()
1941
1942 var capturedHeaders http.Header
1943 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1944 capturedHeaders = r.Header.Clone()
1945 w.Header().Set("Content-Type", "text/event-stream")
1946 w.Header().Set("Cache-Control", "no-cache")
1947 w.WriteHeader(http.StatusOK)
1948 chunks := []string{
1949 "event: message_start\n",
1950 "data: {\"type\":\"message_start\",\"message\":{}}\n\n",
1951 "event: message_stop\n",
1952 "data: {\"type\":\"message_stop\"}\n\n",
1953 }
1954 for _, chunk := range chunks {
1955 _, _ = fmt.Fprint(w, chunk)
1956 if flusher, ok := w.(http.Flusher); ok {
1957 flusher.Flush()
1958 }
1959 }
1960 }))
1961 defer server.Close()
1962
1963 provider, err := New(
1964 WithAPIKey("test-api-key"),
1965 WithBaseURL(server.URL),
1966 )
1967 require.NoError(t, err)
1968
1969 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1970 require.NoError(t, err)
1971
1972 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1973 DisplayWidthPx: 1920,
1974 DisplayHeightPx: 1080,
1975 ToolVersion: ComputerUse20251124,
1976 }, noopComputerRun))
1977
1978 stream, err := model.Stream(context.Background(), fantasy.Call{
1979 Prompt: testPrompt(),
1980 Tools: []fantasy.Tool{cuTool},
1981 })
1982 require.NoError(t, err)
1983
1984 stream(func(fantasy.StreamPart) bool { return true })
1985
1986 require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-11-24")
1987 })
1988}
1989
1990// TestGenerate_ComputerUseTool runs a multi-turn computer use session
1991// via model.Generate, passing the ExecutableProviderTool directly into
1992// Call.Tools (no .Definition(), no jsonRoundTripTool). The mock server
1993// walks through a scripted sequence of actions — screenshot, click,
1994// type, key, scroll — then finishes with a text reply. Each turn the
1995// test parses the tool call, builds a screenshot result, and appends
1996// both to the prompt for the next request.
1997func TestGenerate_ComputerUseTool(t *testing.T) {
1998 t.Parallel()
1999
2000 type actionStep struct {
2001 input map[string]any
2002 want ComputerUseInput
2003 }
2004 steps := []actionStep{
2005 {
2006 input: map[string]any{"action": "screenshot"},
2007 want: ComputerUseInput{Action: ActionScreenshot},
2008 },
2009 {
2010 input: map[string]any{"action": "left_click", "coordinate": []any{100, 200}},
2011 want: ComputerUseInput{Action: ActionLeftClick, Coordinate: [2]int64{100, 200}},
2012 },
2013 {
2014 input: map[string]any{"action": "type", "text": "hello world"},
2015 want: ComputerUseInput{Action: ActionType, Text: "hello world"},
2016 },
2017 {
2018 input: map[string]any{"action": "key", "text": "Return"},
2019 want: ComputerUseInput{Action: ActionKey, Text: "Return"},
2020 },
2021 {
2022 input: map[string]any{
2023 "action": "scroll",
2024 "coordinate": []any{500, 300},
2025 "scroll_direction": "down",
2026 "scroll_amount": 3,
2027 },
2028 want: ComputerUseInput{
2029 Action: ActionScroll,
2030 Coordinate: [2]int64{500, 300},
2031 ScrollDirection: "down",
2032 ScrollAmount: 3,
2033 },
2034 },
2035 {
2036 input: map[string]any{"action": "screenshot"},
2037 want: ComputerUseInput{Action: ActionScreenshot},
2038 },
2039 }
2040
2041 var (
2042 requestIdx int
2043 betaHeaders []string
2044 )
2045 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2046 betaHeaders = append(betaHeaders, r.Header.Get("Anthropic-Beta"))
2047 idx := requestIdx
2048 requestIdx++
2049
2050 w.Header().Set("Content-Type", "application/json")
2051 if idx < len(steps) {
2052 _ = json.NewEncoder(w).Encode(map[string]any{
2053 "id": fmt.Sprintf("msg_%02d", idx),
2054 "type": "message",
2055 "role": "assistant",
2056 "model": "claude-sonnet-4-20250514",
2057 "content": []any{map[string]any{
2058 "type": "tool_use",
2059 "id": fmt.Sprintf("toolu_%02d", idx),
2060 "name": "computer",
2061 "input": steps[idx].input,
2062 }},
2063 "stop_reason": "tool_use",
2064 "usage": map[string]any{"input_tokens": 10, "output_tokens": 5},
2065 })
2066 return
2067 }
2068 _ = json.NewEncoder(w).Encode(map[string]any{
2069 "id": "msg_final",
2070 "type": "message",
2071 "role": "assistant",
2072 "model": "claude-sonnet-4-20250514",
2073 "content": []any{map[string]any{
2074 "type": "text",
2075 "text": "Done! I have completed all the requested actions.",
2076 }},
2077 "stop_reason": "end_turn",
2078 "usage": map[string]any{"input_tokens": 10, "output_tokens": 15},
2079 })
2080 }))
2081 defer server.Close()
2082
2083 provider, err := New(WithAPIKey("test-api-key"), WithBaseURL(server.URL))
2084 require.NoError(t, err)
2085
2086 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
2087 require.NoError(t, err)
2088
2089 // Pass the ExecutableProviderTool directly — the whole point is
2090 // to verify that the Tool interface works without unwrapping.
2091 cuTool := NewComputerUseTool(ComputerUseToolOptions{
2092 DisplayWidthPx: 1920,
2093 DisplayHeightPx: 1080,
2094 ToolVersion: ComputerUse20250124,
2095 }, noopComputerRun)
2096
2097 var got []ComputerUseInput
2098 prompt := testPrompt()
2099 fakePNG := []byte("fake-screenshot-png")
2100
2101 for turn := 0; turn <= len(steps); turn++ {
2102 resp, err := model.Generate(context.Background(), fantasy.Call{
2103 Prompt: prompt,
2104 Tools: []fantasy.Tool{cuTool},
2105 })
2106 require.NoError(t, err, "turn %d", turn)
2107
2108 if resp.FinishReason != fantasy.FinishReasonToolCalls {
2109 require.Equal(t, fantasy.FinishReasonStop, resp.FinishReason)
2110 require.Contains(t, resp.Content.Text(), "Done")
2111 break
2112 }
2113
2114 toolCalls := resp.Content.ToolCalls()
2115 require.Len(t, toolCalls, 1, "turn %d", turn)
2116 require.Equal(t, "computer", toolCalls[0].ToolName, "turn %d", turn)
2117
2118 parsed, err := ParseComputerUseInput(toolCalls[0].Input)
2119 require.NoError(t, err, "turn %d", turn)
2120 got = append(got, parsed)
2121
2122 // Build the next prompt: append the assistant tool-call turn
2123 // and the user screenshot-result turn.
2124 prompt = append(prompt,
2125 fantasy.Message{
2126 Role: fantasy.MessageRoleAssistant,
2127 Content: []fantasy.MessagePart{
2128 fantasy.ToolCallPart{
2129 ToolCallID: toolCalls[0].ToolCallID,
2130 ToolName: toolCalls[0].ToolName,
2131 Input: toolCalls[0].Input,
2132 },
2133 },
2134 },
2135 fantasy.Message{
2136 // Use MessageRoleTool for tool results — this matches
2137 // what the agent loop produces.
2138 Role: fantasy.MessageRoleTool,
2139 Content: []fantasy.MessagePart{
2140 NewComputerUseScreenshotResult(toolCalls[0].ToolCallID, fakePNG),
2141 },
2142 },
2143 )
2144 }
2145
2146 // Every scripted action was received and parsed correctly.
2147 require.Len(t, got, len(steps))
2148 for i, step := range steps {
2149 require.Equal(t, step.want.Action, got[i].Action, "step %d", i)
2150 require.Equal(t, step.want.Coordinate, got[i].Coordinate, "step %d", i)
2151 require.Equal(t, step.want.Text, got[i].Text, "step %d", i)
2152 require.Equal(t, step.want.ScrollDirection, got[i].ScrollDirection, "step %d", i)
2153 require.Equal(t, step.want.ScrollAmount, got[i].ScrollAmount, "step %d", i)
2154 }
2155
2156 // Beta header was sent on every request.
2157 require.Len(t, betaHeaders, len(steps)+1)
2158 for i, h := range betaHeaders {
2159 require.Contains(t, h, "computer-use-2025-01-24", "request %d", i)
2160 }
2161}
2162
2163// TestStream_ComputerUseTool runs a multi-turn computer use session
2164// via model.Stream, verifying that the ExecutableProviderTool works
2165// through the streaming path end-to-end.
2166func TestStream_ComputerUseTool(t *testing.T) {
2167 t.Parallel()
2168
2169 type streamStep struct {
2170 input map[string]any
2171 wantAction ComputerAction
2172 }
2173 steps := []streamStep{
2174 {input: map[string]any{"action": "screenshot"}, wantAction: ActionScreenshot},
2175 {input: map[string]any{"action": "left_click", "coordinate": []any{150, 250}}, wantAction: ActionLeftClick},
2176 {input: map[string]any{"action": "type", "text": "search query"}, wantAction: ActionType},
2177 }
2178
2179 var (
2180 requestIdx int
2181 betaHeaders []string
2182 )
2183
2184 // streamToolUseChunks returns SSE chunks for a single
2185 // computer-use tool_use content block.
2186 streamToolUseChunks := func(id string, input map[string]any) []string {
2187 inputJSON, _ := json.Marshal(input)
2188 escaped := strings.ReplaceAll(string(inputJSON), `"`, `\"`)
2189 return []string{
2190 "event: message_start\n",
2191 `data: {"type":"message_start","message":{"id":"` + id + `","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4-20250514","stop_reason":null,"usage":{"input_tokens":10,"output_tokens":0}}}` + "\n\n",
2192 "event: content_block_start\n",
2193 `data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"` + id + `","name":"computer","input":{}}}` + "\n\n",
2194 "event: content_block_delta\n",
2195 `data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"` + escaped + `"}}` + "\n\n",
2196 "event: content_block_stop\n",
2197 `data: {"type":"content_block_stop","index":0}` + "\n\n",
2198 "event: message_delta\n",
2199 `data: {"type":"message_delta","delta":{"stop_reason":"tool_use"},"usage":{"output_tokens":5}}` + "\n\n",
2200 "event: message_stop\n",
2201 `data: {"type":"message_stop"}` + "\n\n",
2202 }
2203 }
2204
2205 streamTextChunks := func() []string {
2206 return []string{
2207 "event: message_start\n",
2208 `data: {"type":"message_start","message":{"id":"msg_final","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4-20250514","stop_reason":null,"usage":{"input_tokens":10,"output_tokens":0}}}` + "\n\n",
2209 "event: content_block_start\n",
2210 `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}` + "\n\n",
2211 "event: content_block_delta\n",
2212 `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"All done."}}` + "\n\n",
2213 "event: content_block_stop\n",
2214 `data: {"type":"content_block_stop","index":0}` + "\n\n",
2215 "event: message_delta\n",
2216 `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":10}}` + "\n\n",
2217 "event: message_stop\n",
2218 `data: {"type":"message_stop"}` + "\n\n",
2219 }
2220 }
2221
2222 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2223 betaHeaders = append(betaHeaders, r.Header.Get("Anthropic-Beta"))
2224 idx := requestIdx
2225 requestIdx++
2226
2227 w.Header().Set("Content-Type", "text/event-stream")
2228 w.Header().Set("Cache-Control", "no-cache")
2229 w.WriteHeader(http.StatusOK)
2230
2231 var chunks []string
2232 if idx < len(steps) {
2233 chunks = streamToolUseChunks(
2234 fmt.Sprintf("toolu_%02d", idx),
2235 steps[idx].input,
2236 )
2237 } else {
2238 chunks = streamTextChunks()
2239 }
2240 for _, chunk := range chunks {
2241 _, _ = fmt.Fprint(w, chunk)
2242 if f, ok := w.(http.Flusher); ok {
2243 f.Flush()
2244 }
2245 }
2246 }))
2247 defer server.Close()
2248
2249 provider, err := New(WithAPIKey("test-api-key"), WithBaseURL(server.URL))
2250 require.NoError(t, err)
2251
2252 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
2253 require.NoError(t, err)
2254
2255 cuTool := NewComputerUseTool(ComputerUseToolOptions{
2256 DisplayWidthPx: 1920,
2257 DisplayHeightPx: 1080,
2258 ToolVersion: ComputerUse20250124,
2259 }, noopComputerRun)
2260
2261 var gotActions []ComputerAction
2262 prompt := testPrompt()
2263 fakePNG := []byte("fake-screenshot-png")
2264
2265 for turn := 0; turn <= len(steps); turn++ {
2266 stream, err := model.Stream(context.Background(), fantasy.Call{
2267 Prompt: prompt,
2268 Tools: []fantasy.Tool{cuTool},
2269 })
2270 require.NoError(t, err, "turn %d", turn)
2271
2272 var (
2273 toolCallName string
2274 toolCallID string
2275 toolCallInput string
2276 finishReason fantasy.FinishReason
2277 gotText string
2278 )
2279 stream(func(part fantasy.StreamPart) bool {
2280 switch part.Type {
2281 case fantasy.StreamPartTypeToolCall:
2282 toolCallName = part.ToolCallName
2283 toolCallID = part.ID
2284 toolCallInput = part.ToolCallInput
2285 case fantasy.StreamPartTypeFinish:
2286 finishReason = part.FinishReason
2287 case fantasy.StreamPartTypeTextDelta:
2288 gotText += part.Delta
2289 }
2290 return true
2291 })
2292
2293 if finishReason != fantasy.FinishReasonToolCalls {
2294 require.Contains(t, gotText, "All done")
2295 break
2296 }
2297
2298 require.Equal(t, "computer", toolCallName, "turn %d", turn)
2299
2300 parsed, err := ParseComputerUseInput(toolCallInput)
2301 require.NoError(t, err, "turn %d", turn)
2302 gotActions = append(gotActions, parsed.Action)
2303
2304 prompt = append(prompt,
2305 fantasy.Message{
2306 Role: fantasy.MessageRoleAssistant,
2307 Content: []fantasy.MessagePart{
2308 fantasy.ToolCallPart{
2309 ToolCallID: toolCallID,
2310 ToolName: toolCallName,
2311 Input: toolCallInput,
2312 },
2313 },
2314 },
2315 fantasy.Message{
2316 // Use MessageRoleTool for tool results — this matches
2317 // what the agent loop produces.
2318 Role: fantasy.MessageRoleTool,
2319 Content: []fantasy.MessagePart{
2320 NewComputerUseScreenshotResult(toolCallID, fakePNG),
2321 },
2322 },
2323 )
2324 }
2325
2326 require.Len(t, gotActions, len(steps))
2327 for i, step := range steps {
2328 require.Equal(t, step.wantAction, gotActions[i], "step %d", i)
2329 }
2330
2331 require.Len(t, betaHeaders, len(steps)+1)
2332 for i, h := range betaHeaders {
2333 require.Contains(t, h, "computer-use-2025-01-24", "request %d", i)
2334 }
2335}