1package anthropic
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "math"
9 "net/http"
10 "net/http/httptest"
11 "strings"
12 "testing"
13 "time"
14
15 "charm.land/fantasy"
16 "github.com/charmbracelet/anthropic-sdk-go"
17 "github.com/stretchr/testify/require"
18)
19
20// noopComputerRun is a no-op run function for tests that only need
21// to inspect the tool definition, not execute it.
22var noopComputerRun = func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
23 return fantasy.ToolResponse{}, nil
24}
25
26func TestToPrompt_DropsEmptyMessages(t *testing.T) {
27 t.Parallel()
28
29 t.Run("should drop assistant messages with only reasoning content", func(t *testing.T) {
30 t.Parallel()
31
32 prompt := fantasy.Prompt{
33 {
34 Role: fantasy.MessageRoleUser,
35 Content: []fantasy.MessagePart{
36 fantasy.TextPart{Text: "Hello"},
37 },
38 },
39 {
40 Role: fantasy.MessageRoleAssistant,
41 Content: []fantasy.MessagePart{
42 fantasy.ReasoningPart{
43 Text: "Let me think about this...",
44 ProviderOptions: fantasy.ProviderOptions{
45 Name: &ReasoningOptionMetadata{
46 Signature: "abc123",
47 },
48 },
49 },
50 },
51 },
52 }
53
54 systemBlocks, messages, warnings := toPrompt(prompt, true)
55
56 require.Empty(t, systemBlocks)
57 require.Len(t, messages, 1, "should only have user message, assistant message should be dropped")
58 require.Len(t, warnings, 1)
59 require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
60 require.Contains(t, warnings[0].Message, "dropping empty assistant message")
61 require.Contains(t, warnings[0].Message, "neither user-facing content nor tool calls")
62 })
63
64 t.Run("should drop assistant reasoning when sendReasoning disabled", func(t *testing.T) {
65 t.Parallel()
66
67 prompt := fantasy.Prompt{
68 {
69 Role: fantasy.MessageRoleUser,
70 Content: []fantasy.MessagePart{
71 fantasy.TextPart{Text: "Hello"},
72 },
73 },
74 {
75 Role: fantasy.MessageRoleAssistant,
76 Content: []fantasy.MessagePart{
77 fantasy.ReasoningPart{
78 Text: "Let me think about this...",
79 ProviderOptions: fantasy.ProviderOptions{
80 Name: &ReasoningOptionMetadata{
81 Signature: "def456",
82 },
83 },
84 },
85 },
86 },
87 }
88
89 systemBlocks, messages, warnings := toPrompt(prompt, false)
90
91 require.Empty(t, systemBlocks)
92 require.Len(t, messages, 1, "should only have user message, assistant message should be dropped")
93 require.Len(t, warnings, 2)
94 require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
95 require.Contains(t, warnings[0].Message, "sending reasoning content is disabled")
96 require.Equal(t, fantasy.CallWarningTypeOther, warnings[1].Type)
97 require.Contains(t, warnings[1].Message, "dropping empty assistant message")
98 })
99
100 t.Run("should drop truly empty assistant messages", func(t *testing.T) {
101 t.Parallel()
102
103 prompt := fantasy.Prompt{
104 {
105 Role: fantasy.MessageRoleUser,
106 Content: []fantasy.MessagePart{
107 fantasy.TextPart{Text: "Hello"},
108 },
109 },
110 {
111 Role: fantasy.MessageRoleAssistant,
112 Content: []fantasy.MessagePart{},
113 },
114 }
115
116 systemBlocks, messages, warnings := toPrompt(prompt, true)
117
118 require.Empty(t, systemBlocks)
119 require.Len(t, messages, 1, "should only have user message")
120 require.Len(t, warnings, 1)
121 require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
122 require.Contains(t, warnings[0].Message, "dropping empty assistant message")
123 })
124
125 t.Run("should keep assistant messages with text content", func(t *testing.T) {
126 t.Parallel()
127
128 prompt := fantasy.Prompt{
129 {
130 Role: fantasy.MessageRoleUser,
131 Content: []fantasy.MessagePart{
132 fantasy.TextPart{Text: "Hello"},
133 },
134 },
135 {
136 Role: fantasy.MessageRoleAssistant,
137 Content: []fantasy.MessagePart{
138 fantasy.TextPart{Text: "Hi there!"},
139 },
140 },
141 }
142
143 systemBlocks, messages, warnings := toPrompt(prompt, true)
144
145 require.Empty(t, systemBlocks)
146 require.Len(t, messages, 2, "should have both user and assistant messages")
147 require.Empty(t, warnings)
148 })
149
150 t.Run("should keep assistant messages with tool calls", func(t *testing.T) {
151 t.Parallel()
152
153 prompt := fantasy.Prompt{
154 {
155 Role: fantasy.MessageRoleUser,
156 Content: []fantasy.MessagePart{
157 fantasy.TextPart{Text: "What's the weather?"},
158 },
159 },
160 {
161 Role: fantasy.MessageRoleAssistant,
162 Content: []fantasy.MessagePart{
163 fantasy.ToolCallPart{
164 ToolCallID: "call_123",
165 ToolName: "get_weather",
166 Input: `{"location":"NYC"}`,
167 },
168 },
169 },
170 }
171
172 systemBlocks, messages, warnings := toPrompt(prompt, true)
173
174 require.Empty(t, systemBlocks)
175 require.Len(t, messages, 2, "should have both user and assistant messages")
176 require.Empty(t, warnings)
177 })
178
179 t.Run("should drop assistant messages with invalid tool input", func(t *testing.T) {
180 t.Parallel()
181
182 prompt := fantasy.Prompt{
183 {
184 Role: fantasy.MessageRoleUser,
185 Content: []fantasy.MessagePart{
186 fantasy.TextPart{Text: "Hi"},
187 },
188 },
189 {
190 Role: fantasy.MessageRoleAssistant,
191 Content: []fantasy.MessagePart{
192 fantasy.ToolCallPart{
193 ToolCallID: "call_123",
194 ToolName: "get_weather",
195 Input: "{not-json",
196 },
197 },
198 },
199 }
200
201 systemBlocks, messages, warnings := toPrompt(prompt, true)
202
203 require.Empty(t, systemBlocks)
204 require.Len(t, messages, 1, "should only have user message")
205 require.Len(t, warnings, 1)
206 require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
207 require.Contains(t, warnings[0].Message, "dropping empty assistant message")
208 })
209
210 t.Run("should keep assistant messages with reasoning and text", func(t *testing.T) {
211 t.Parallel()
212
213 prompt := fantasy.Prompt{
214 {
215 Role: fantasy.MessageRoleUser,
216 Content: []fantasy.MessagePart{
217 fantasy.TextPart{Text: "Hello"},
218 },
219 },
220 {
221 Role: fantasy.MessageRoleAssistant,
222 Content: []fantasy.MessagePart{
223 fantasy.ReasoningPart{
224 Text: "Let me think...",
225 ProviderOptions: fantasy.ProviderOptions{
226 Name: &ReasoningOptionMetadata{
227 Signature: "abc123",
228 },
229 },
230 },
231 fantasy.TextPart{Text: "Hi there!"},
232 },
233 },
234 }
235
236 systemBlocks, messages, warnings := toPrompt(prompt, true)
237
238 require.Empty(t, systemBlocks)
239 require.Len(t, messages, 2, "should have both user and assistant messages")
240 require.Empty(t, warnings)
241 })
242
243 t.Run("should keep user messages with image content", func(t *testing.T) {
244 t.Parallel()
245
246 prompt := fantasy.Prompt{
247 {
248 Role: fantasy.MessageRoleUser,
249 Content: []fantasy.MessagePart{
250 fantasy.FilePart{
251 Data: []byte{0x01, 0x02, 0x03},
252 MediaType: "image/png",
253 },
254 },
255 },
256 }
257
258 systemBlocks, messages, warnings := toPrompt(prompt, true)
259
260 require.Empty(t, systemBlocks)
261 require.Len(t, messages, 1)
262 require.Empty(t, warnings)
263 })
264
265 t.Run("should drop user messages without visible content", func(t *testing.T) {
266 t.Parallel()
267
268 prompt := fantasy.Prompt{
269 {
270 Role: fantasy.MessageRoleUser,
271 Content: []fantasy.MessagePart{
272 fantasy.FilePart{
273 Data: []byte("not supported"),
274 MediaType: "application/pdf",
275 },
276 },
277 },
278 }
279
280 systemBlocks, messages, warnings := toPrompt(prompt, true)
281
282 require.Empty(t, systemBlocks)
283 require.Empty(t, messages)
284 require.Len(t, warnings, 1)
285 require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
286 require.Contains(t, warnings[0].Message, "dropping empty user message")
287 require.Contains(t, warnings[0].Message, "neither user-facing content nor tool results")
288 })
289
290 t.Run("should keep user messages with tool results", func(t *testing.T) {
291 t.Parallel()
292
293 prompt := fantasy.Prompt{
294 {
295 Role: fantasy.MessageRoleTool,
296 Content: []fantasy.MessagePart{
297 fantasy.ToolResultPart{
298 ToolCallID: "call_123",
299 Output: fantasy.ToolResultOutputContentText{Text: "done"},
300 },
301 },
302 },
303 }
304
305 systemBlocks, messages, warnings := toPrompt(prompt, true)
306
307 require.Empty(t, systemBlocks)
308 require.Len(t, messages, 1)
309 require.Empty(t, warnings)
310 })
311
312 t.Run("should keep user messages with tool error results", func(t *testing.T) {
313 t.Parallel()
314
315 prompt := fantasy.Prompt{
316 {
317 Role: fantasy.MessageRoleTool,
318 Content: []fantasy.MessagePart{
319 fantasy.ToolResultPart{
320 ToolCallID: "call_456",
321 Output: fantasy.ToolResultOutputContentError{Error: errors.New("boom")},
322 },
323 },
324 },
325 }
326
327 systemBlocks, messages, warnings := toPrompt(prompt, true)
328
329 require.Empty(t, systemBlocks)
330 require.Len(t, messages, 1)
331 require.Empty(t, warnings)
332 })
333
334 t.Run("should keep user messages with tool media results", func(t *testing.T) {
335 t.Parallel()
336
337 prompt := fantasy.Prompt{
338 {
339 Role: fantasy.MessageRoleTool,
340 Content: []fantasy.MessagePart{
341 fantasy.ToolResultPart{
342 ToolCallID: "call_789",
343 Output: fantasy.ToolResultOutputContentMedia{
344 Data: "AQID",
345 MediaType: "image/png",
346 },
347 },
348 },
349 },
350 }
351
352 systemBlocks, messages, warnings := toPrompt(prompt, true)
353
354 require.Empty(t, systemBlocks)
355 require.Len(t, messages, 1)
356 require.Empty(t, warnings)
357 })
358}
359
360func TestParseContextTooLargeError(t *testing.T) {
361 t.Parallel()
362
363 tests := []struct {
364 name string
365 message string
366 wantErr bool
367 wantUsed int
368 wantMax int
369 }{
370 {
371 name: "matches anthropic format",
372 message: "prompt is too long: 202630 tokens > 200000 maximum",
373 wantErr: true,
374 wantUsed: 202630,
375 wantMax: 200000,
376 },
377 {
378 name: "matches with different numbers",
379 message: "prompt is too long: 150000 tokens > 128000 maximum",
380 wantErr: true,
381 wantUsed: 150000,
382 wantMax: 128000,
383 },
384 {
385 name: "matches with extra whitespace",
386 message: "prompt is too long: 202630 tokens > 200000 maximum",
387 wantErr: true,
388 wantUsed: 202630,
389 wantMax: 200000,
390 },
391 {
392 name: "does not match unrelated error",
393 message: "invalid api key",
394 wantErr: false,
395 },
396 {
397 name: "does not match rate limit error",
398 message: "rate limit exceeded",
399 wantErr: false,
400 },
401 }
402
403 for _, tt := range tests {
404 t.Run(tt.name, func(t *testing.T) {
405 t.Parallel()
406 providerErr := &fantasy.ProviderError{Message: tt.message}
407 parseContextTooLargeError(tt.message, providerErr)
408
409 if tt.wantErr {
410 require.True(t, providerErr.IsContextTooLarge())
411 require.Equal(t, tt.wantUsed, providerErr.ContextUsedTokens)
412 require.Equal(t, tt.wantMax, providerErr.ContextMaxTokens)
413 } else {
414 require.False(t, providerErr.IsContextTooLarge())
415 }
416 })
417 }
418}
419
420func TestParseOptions_Effort(t *testing.T) {
421 t.Parallel()
422
423 options, err := ParseOptions(map[string]any{
424 "send_reasoning": true,
425 "thinking": map[string]any{"budget_tokens": int64(2048)},
426 "effort": "medium",
427 "disable_parallel_tool_use": true,
428 })
429 require.NoError(t, err)
430 require.NotNil(t, options.SendReasoning)
431 require.True(t, *options.SendReasoning)
432 require.NotNil(t, options.Thinking)
433 require.Equal(t, int64(2048), options.Thinking.BudgetTokens)
434 require.NotNil(t, options.Effort)
435 require.Equal(t, EffortMedium, *options.Effort)
436 require.NotNil(t, options.DisableParallelToolUse)
437 require.True(t, *options.DisableParallelToolUse)
438}
439
440func TestGenerate_SendsOutputConfigEffort(t *testing.T) {
441 t.Parallel()
442
443 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
444 defer server.Close()
445
446 provider, err := New(
447 WithAPIKey("test-api-key"),
448 WithBaseURL(server.URL),
449 )
450 require.NoError(t, err)
451
452 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
453 require.NoError(t, err)
454
455 effort := EffortMedium
456 _, err = model.Generate(context.Background(), fantasy.Call{
457 Prompt: testPrompt(),
458 ProviderOptions: NewProviderOptions(&ProviderOptions{
459 Effort: &effort,
460 }),
461 })
462 require.NoError(t, err)
463
464 call := awaitAnthropicCall(t, calls)
465 require.Equal(t, "POST", call.method)
466 require.Equal(t, "/v1/messages", call.path)
467 requireAnthropicEffort(t, call.body, EffortMedium)
468}
469
470func TestStream_SendsOutputConfigEffort(t *testing.T) {
471 t.Parallel()
472
473 server, calls := newAnthropicStreamingServer([]string{
474 "event: message_start\n",
475 "data: {\"type\":\"message_start\",\"message\":{}}\n\n",
476 "event: message_stop\n",
477 "data: {\"type\":\"message_stop\"}\n\n",
478 })
479 defer server.Close()
480
481 provider, err := New(
482 WithAPIKey("test-api-key"),
483 WithBaseURL(server.URL),
484 )
485 require.NoError(t, err)
486
487 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
488 require.NoError(t, err)
489
490 effort := EffortHigh
491 stream, err := model.Stream(context.Background(), fantasy.Call{
492 Prompt: testPrompt(),
493 ProviderOptions: NewProviderOptions(&ProviderOptions{
494 Effort: &effort,
495 }),
496 })
497 require.NoError(t, err)
498
499 stream(func(fantasy.StreamPart) bool { return true })
500
501 call := awaitAnthropicCall(t, calls)
502 require.Equal(t, "POST", call.method)
503 require.Equal(t, "/v1/messages", call.path)
504 requireAnthropicEffort(t, call.body, EffortHigh)
505}
506
// anthropicCall is a snapshot of one HTTP request received by a mock
// Anthropic test server.
type anthropicCall struct {
	// method and path are the request's HTTP method and URL path.
	method, path string
	// body is the request body decoded as a JSON object.
	body map[string]any
}
512
513func newAnthropicJSONServer(response map[string]any) (*httptest.Server, <-chan anthropicCall) {
514 calls := make(chan anthropicCall, 4)
515
516 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
517 var body map[string]any
518 if r.Body != nil {
519 _ = json.NewDecoder(r.Body).Decode(&body)
520 }
521
522 calls <- anthropicCall{
523 method: r.Method,
524 path: r.URL.Path,
525 body: body,
526 }
527
528 w.Header().Set("Content-Type", "application/json")
529 _ = json.NewEncoder(w).Encode(response)
530 }))
531
532 return server, calls
533}
534
535func newAnthropicStreamingServer(chunks []string) (*httptest.Server, <-chan anthropicCall) {
536 calls := make(chan anthropicCall, 4)
537
538 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
539 var body map[string]any
540 if r.Body != nil {
541 _ = json.NewDecoder(r.Body).Decode(&body)
542 }
543
544 calls <- anthropicCall{
545 method: r.Method,
546 path: r.URL.Path,
547 body: body,
548 }
549
550 w.Header().Set("Content-Type", "text/event-stream")
551 w.Header().Set("Cache-Control", "no-cache")
552 w.Header().Set("Connection", "keep-alive")
553 w.WriteHeader(http.StatusOK)
554
555 for _, chunk := range chunks {
556 _, _ = fmt.Fprint(w, chunk)
557 if flusher, ok := w.(http.Flusher); ok {
558 flusher.Flush()
559 }
560 }
561 }))
562
563 return server, calls
564}
565
566func awaitAnthropicCall(t *testing.T, calls <-chan anthropicCall) anthropicCall {
567 t.Helper()
568
569 select {
570 case call := <-calls:
571 return call
572 case <-time.After(2 * time.Second):
573 t.Fatal("timed out waiting for Anthropic request")
574 return anthropicCall{}
575 }
576}
577
578func assertNoAnthropicCall(t *testing.T, calls <-chan anthropicCall) {
579 t.Helper()
580
581 select {
582 case call := <-calls:
583 t.Fatalf("expected no Anthropic API call, but got %s %s", call.method, call.path)
584 case <-time.After(200 * time.Millisecond):
585 }
586}
587
588func requireAnthropicEffort(t *testing.T, body map[string]any, expected Effort) {
589 t.Helper()
590
591 outputConfig, ok := body["output_config"].(map[string]any)
592 thinking, ok := body["thinking"].(map[string]any)
593 require.True(t, ok)
594 require.Equal(t, string(expected), outputConfig["effort"])
595 require.Equal(t, "adaptive", thinking["type"])
596}
597
598func testPrompt() fantasy.Prompt {
599 return fantasy.Prompt{
600 {
601 Role: fantasy.MessageRoleUser,
602 Content: []fantasy.MessagePart{
603 fantasy.TextPart{Text: "Hello"},
604 },
605 },
606 }
607}
608
// mockAnthropicGenerateResponse builds the canned response body served by the
// mock server for plain Generate tests: an assistant message with one text
// block ("Hi there"), stop reason "end_turn", and a usage section.
func mockAnthropicGenerateResponse() map[string]any {
	textBlock := map[string]any{
		"type": "text",
		"text": "Hi there",
	}

	cacheCreation := map[string]any{
		"ephemeral_1h_input_tokens": 0,
		"ephemeral_5m_input_tokens": 0,
	}

	serverToolUse := map[string]any{
		"web_search_requests": 0,
	}

	usage := map[string]any{
		"cache_creation":              cacheCreation,
		"cache_creation_input_tokens": 0,
		"cache_read_input_tokens":     0,
		"input_tokens":                5,
		"output_tokens":               2,
		"server_tool_use":             serverToolUse,
		"service_tier":                "standard",
	}

	response := map[string]any{
		"id":            "msg_01Test",
		"type":          "message",
		"role":          "assistant",
		"model":         "claude-sonnet-4-20250514",
		"content":       []any{textBlock},
		"stop_reason":   "end_turn",
		"stop_sequence": "",
		"usage":         usage,
	}
	return response
}
639
// mockAnthropicWebSearchResponse builds the canned response body served by the
// mock server for web search tests. Its content holds, in order, a
// server_tool_use block for web_search, a web_search_tool_result block with
// two results, and a final text block.
func mockAnthropicWebSearchResponse() map[string]any {
	// Each block gets its own caller map so no map instance is shared.
	directCaller := func() map[string]any {
		return map[string]any{"type": "direct"}
	}

	toolUse := map[string]any{
		"type":   "server_tool_use",
		"id":     "srvtoolu_01",
		"name":   "web_search",
		"input":  map[string]any{"query": "latest AI news"},
		"caller": directCaller(),
	}

	firstHit := map[string]any{
		"type":              "web_search_result",
		"url":               "https://example.com/ai-news",
		"title":             "Latest AI News",
		"encrypted_content": "encrypted_abc123",
		"page_age":          "2 hours ago",
	}
	// The second hit carries an empty page_age.
	secondHit := map[string]any{
		"type":              "web_search_result",
		"url":               "https://example.com/ml-update",
		"title":             "ML Update",
		"encrypted_content": "encrypted_def456",
		"page_age":          "",
	}

	toolResult := map[string]any{
		"type":        "web_search_tool_result",
		"tool_use_id": "srvtoolu_01",
		"caller":      directCaller(),
		"content":     []any{firstHit, secondHit},
	}

	answer := map[string]any{
		"type": "text",
		"text": "Based on recent search results, here is the latest AI news.",
	}

	usage := map[string]any{
		"input_tokens":                100,
		"output_tokens":               50,
		"cache_creation_input_tokens": 0,
		"cache_read_input_tokens":     0,
		"server_tool_use": map[string]any{
			"web_search_requests": 1,
		},
	}

	return map[string]any{
		"id":            "msg_01WebSearch",
		"type":          "message",
		"role":          "assistant",
		"model":         "claude-sonnet-4-20250514",
		"content":       []any{toolUse, toolResult, answer},
		"stop_reason":   "end_turn",
		"stop_sequence": nil,
		"usage":         usage,
	}
}
693
694func TestToPrompt_WebSearchProviderExecutedToolResults(t *testing.T) {
695 t.Parallel()
696
697 prompt := fantasy.Prompt{
698 // User message.
699 {
700 Role: fantasy.MessageRoleUser,
701 Content: []fantasy.MessagePart{
702 fantasy.TextPart{Text: "Search for the latest AI news"},
703 },
704 },
705 // Assistant message with a provider-executed tool call, its
706 // result, and trailing text. toResponseMessages routes
707 // provider-executed results into the assistant message, so
708 // the prompt already reflects that structure.
709 {
710 Role: fantasy.MessageRoleAssistant,
711 Content: []fantasy.MessagePart{
712 fantasy.ToolCallPart{
713 ToolCallID: "srvtoolu_01",
714 ToolName: "web_search",
715 Input: `{"query":"latest AI news"}`,
716 ProviderExecuted: true,
717 },
718 fantasy.ToolResultPart{
719 ToolCallID: "srvtoolu_01",
720 ProviderExecuted: true,
721 ProviderOptions: fantasy.ProviderOptions{
722 Name: &WebSearchResultMetadata{
723 Results: []WebSearchResultItem{
724 {
725 URL: "https://example.com/ai-news",
726 Title: "Latest AI News",
727 EncryptedContent: "encrypted_abc123",
728 PageAge: "2 hours ago",
729 },
730 {
731 URL: "https://example.com/ml-update",
732 Title: "ML Update",
733 EncryptedContent: "encrypted_def456",
734 },
735 },
736 },
737 },
738 },
739 fantasy.TextPart{Text: "Here is what I found."},
740 },
741 },
742 }
743
744 _, messages, warnings := toPrompt(prompt, true)
745
746 // No warnings expected; the provider-executed result is in the
747 // assistant message so there is no empty tool message to drop.
748 require.Empty(t, warnings)
749
750 // We should have a user message and an assistant message.
751 require.Len(t, messages, 2, "expected user + assistant messages")
752
753 assistantMsg := messages[1]
754 require.Len(t, assistantMsg.Content, 3,
755 "expected server_tool_use + web_search_tool_result + text")
756
757 // First content block: reconstructed server_tool_use.
758 serverToolUse := assistantMsg.Content[0]
759 require.NotNil(t, serverToolUse.OfServerToolUse,
760 "first block should be a server_tool_use")
761 require.Equal(t, "srvtoolu_01", serverToolUse.OfServerToolUse.ID)
762 require.Equal(t, anthropic.ServerToolUseBlockParamName("web_search"),
763 serverToolUse.OfServerToolUse.Name)
764
765 // Second content block: reconstructed web_search_tool_result with
766 // encrypted_content preserved for multi-turn round-tripping.
767 webResult := assistantMsg.Content[1]
768 require.NotNil(t, webResult.OfWebSearchToolResult,
769 "second block should be a web_search_tool_result")
770 require.Equal(t, "srvtoolu_01", webResult.OfWebSearchToolResult.ToolUseID)
771
772 results := webResult.OfWebSearchToolResult.Content.OfWebSearchToolResultBlockItem
773 require.Len(t, results, 2)
774 require.Equal(t, "https://example.com/ai-news", results[0].URL)
775 require.Equal(t, "Latest AI News", results[0].Title)
776 require.Equal(t, "encrypted_abc123", results[0].EncryptedContent)
777 require.Equal(t, "https://example.com/ml-update", results[1].URL)
778 require.Equal(t, "encrypted_def456", results[1].EncryptedContent)
779 // PageAge should be set for the first result and absent for the second.
780 require.True(t, results[0].PageAge.Valid())
781 require.Equal(t, "2 hours ago", results[0].PageAge.Value)
782 require.False(t, results[1].PageAge.Valid())
783
784 // Third content block: plain text.
785 require.NotNil(t, assistantMsg.Content[2].OfText)
786 require.Equal(t, "Here is what I found.", assistantMsg.Content[2].OfText.Text)
787}
788
789func TestGenerate_WebSearchResponse(t *testing.T) {
790 t.Parallel()
791
792 server, calls := newAnthropicJSONServer(mockAnthropicWebSearchResponse())
793 defer server.Close()
794
795 provider, err := New(
796 WithAPIKey("test-api-key"),
797 WithBaseURL(server.URL),
798 )
799 require.NoError(t, err)
800
801 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
802 require.NoError(t, err)
803
804 resp, err := model.Generate(context.Background(), fantasy.Call{
805 Prompt: testPrompt(),
806 Tools: []fantasy.Tool{
807 WebSearchTool(nil),
808 },
809 })
810 require.NoError(t, err)
811
812 call := awaitAnthropicCall(t, calls)
813 require.Equal(t, "POST", call.method)
814 require.Equal(t, "/v1/messages", call.path)
815
816 // Walk the response content and categorise each item.
817 var (
818 toolCalls []fantasy.ToolCallContent
819 sources []fantasy.SourceContent
820 toolResults []fantasy.ToolResultContent
821 texts []fantasy.TextContent
822 )
823 for _, c := range resp.Content {
824 switch v := c.(type) {
825 case fantasy.ToolCallContent:
826 toolCalls = append(toolCalls, v)
827 case fantasy.SourceContent:
828 sources = append(sources, v)
829 case fantasy.ToolResultContent:
830 toolResults = append(toolResults, v)
831 case fantasy.TextContent:
832 texts = append(texts, v)
833 }
834 }
835
836 // ToolCallContent for the provider-executed web_search.
837 require.Len(t, toolCalls, 1)
838 require.True(t, toolCalls[0].ProviderExecuted)
839 require.Equal(t, "web_search", toolCalls[0].ToolName)
840 require.Equal(t, "srvtoolu_01", toolCalls[0].ToolCallID)
841
842 // SourceContent entries for each search result.
843 require.Len(t, sources, 2)
844 require.Equal(t, "https://example.com/ai-news", sources[0].URL)
845 require.Equal(t, "Latest AI News", sources[0].Title)
846 require.Equal(t, fantasy.SourceTypeURL, sources[0].SourceType)
847 require.Equal(t, "https://example.com/ml-update", sources[1].URL)
848 require.Equal(t, "ML Update", sources[1].Title)
849
850 // ToolResultContent with provider metadata preserving encrypted_content.
851 require.Len(t, toolResults, 1)
852 require.True(t, toolResults[0].ProviderExecuted)
853 require.Equal(t, "web_search", toolResults[0].ToolName)
854 require.Equal(t, "srvtoolu_01", toolResults[0].ToolCallID)
855
856 searchMeta, ok := toolResults[0].ProviderMetadata[Name]
857 require.True(t, ok, "providerMetadata should contain anthropic key")
858 webMeta, ok := searchMeta.(*WebSearchResultMetadata)
859 require.True(t, ok, "metadata should be *WebSearchResultMetadata")
860 require.Len(t, webMeta.Results, 2)
861 require.Equal(t, "encrypted_abc123", webMeta.Results[0].EncryptedContent)
862 require.Equal(t, "encrypted_def456", webMeta.Results[1].EncryptedContent)
863 require.Equal(t, "2 hours ago", webMeta.Results[0].PageAge)
864
865 // TextContent with the final answer.
866 require.Len(t, texts, 1)
867 require.Equal(t,
868 "Based on recent search results, here is the latest AI news.",
869 texts[0].Text,
870 )
871}
872
873func TestGenerate_WebSearchToolInRequest(t *testing.T) {
874 t.Parallel()
875
876 t.Run("basic web_search tool", func(t *testing.T) {
877 t.Parallel()
878
879 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
880 defer server.Close()
881
882 provider, err := New(
883 WithAPIKey("test-api-key"),
884 WithBaseURL(server.URL),
885 )
886 require.NoError(t, err)
887
888 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
889 require.NoError(t, err)
890
891 _, err = model.Generate(context.Background(), fantasy.Call{
892 Prompt: testPrompt(),
893 Tools: []fantasy.Tool{
894 WebSearchTool(nil),
895 },
896 })
897 require.NoError(t, err)
898
899 call := awaitAnthropicCall(t, calls)
900 tools, ok := call.body["tools"].([]any)
901 require.True(t, ok, "request body should have tools array")
902 require.Len(t, tools, 1)
903
904 tool, ok := tools[0].(map[string]any)
905 require.True(t, ok)
906 require.Equal(t, "web_search_20250305", tool["type"])
907 })
908
909 t.Run("with allowed_domains and blocked_domains", func(t *testing.T) {
910 t.Parallel()
911
912 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
913 defer server.Close()
914
915 provider, err := New(
916 WithAPIKey("test-api-key"),
917 WithBaseURL(server.URL),
918 )
919 require.NoError(t, err)
920
921 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
922 require.NoError(t, err)
923
924 _, err = model.Generate(context.Background(), fantasy.Call{
925 Prompt: testPrompt(),
926 Tools: []fantasy.Tool{
927 WebSearchTool(&WebSearchToolOptions{
928 AllowedDomains: []string{"example.com", "test.com"},
929 }),
930 },
931 })
932 require.NoError(t, err)
933
934 call := awaitAnthropicCall(t, calls)
935 tools, ok := call.body["tools"].([]any)
936 require.True(t, ok)
937 require.Len(t, tools, 1)
938
939 tool, ok := tools[0].(map[string]any)
940 require.True(t, ok)
941 require.Equal(t, "web_search_20250305", tool["type"])
942
943 domains, ok := tool["allowed_domains"].([]any)
944 require.True(t, ok, "tool should have allowed_domains")
945 require.Len(t, domains, 2)
946 require.Equal(t, "example.com", domains[0])
947 require.Equal(t, "test.com", domains[1])
948 })
949
950 t.Run("with max uses and user location", func(t *testing.T) {
951 t.Parallel()
952
953 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
954 defer server.Close()
955
956 provider, err := New(
957 WithAPIKey("test-api-key"),
958 WithBaseURL(server.URL),
959 )
960 require.NoError(t, err)
961
962 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
963 require.NoError(t, err)
964
965 _, err = model.Generate(context.Background(), fantasy.Call{
966 Prompt: testPrompt(),
967 Tools: []fantasy.Tool{
968 WebSearchTool(&WebSearchToolOptions{
969 MaxUses: 5,
970 UserLocation: &UserLocation{
971 City: "San Francisco",
972 Country: "US",
973 },
974 }),
975 },
976 })
977 require.NoError(t, err)
978
979 call := awaitAnthropicCall(t, calls)
980 tools, ok := call.body["tools"].([]any)
981 require.True(t, ok)
982 require.Len(t, tools, 1)
983
984 tool, ok := tools[0].(map[string]any)
985 require.True(t, ok)
986 require.Equal(t, "web_search_20250305", tool["type"])
987
988 // max_uses is serialized as a JSON number; json.Unmarshal
989 // into map[string]any decodes numbers as float64.
990 maxUses, ok := tool["max_uses"].(float64)
991 require.True(t, ok, "tool should have max_uses")
992 require.Equal(t, float64(5), maxUses)
993
994 userLoc, ok := tool["user_location"].(map[string]any)
995 require.True(t, ok, "tool should have user_location")
996 require.Equal(t, "San Francisco", userLoc["city"])
997 require.Equal(t, "US", userLoc["country"])
998 require.Equal(t, "approximate", userLoc["type"])
999 })
1000
1001 t.Run("with max uses", func(t *testing.T) {
1002 t.Parallel()
1003
1004 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1005 defer server.Close()
1006
1007 provider, err := New(
1008 WithAPIKey("test-api-key"),
1009 WithBaseURL(server.URL),
1010 )
1011 require.NoError(t, err)
1012
1013 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1014 require.NoError(t, err)
1015
1016 _, err = model.Generate(context.Background(), fantasy.Call{
1017 Prompt: testPrompt(),
1018 Tools: []fantasy.Tool{
1019 WebSearchTool(&WebSearchToolOptions{
1020 MaxUses: 3,
1021 }),
1022 },
1023 })
1024 require.NoError(t, err)
1025
1026 call := awaitAnthropicCall(t, calls)
1027 tools, ok := call.body["tools"].([]any)
1028 require.True(t, ok)
1029 require.Len(t, tools, 1)
1030
1031 tool, ok := tools[0].(map[string]any)
1032 require.True(t, ok)
1033 require.Equal(t, "web_search_20250305", tool["type"])
1034
1035 maxUses, ok := tool["max_uses"].(float64)
1036 require.True(t, ok, "tool should have max_uses")
1037 require.Equal(t, float64(3), maxUses)
1038 })
1039
1040 t.Run("with json-round-tripped provider tool args", func(t *testing.T) {
1041 t.Parallel()
1042
1043 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1044 defer server.Close()
1045
1046 provider, err := New(
1047 WithAPIKey("test-api-key"),
1048 WithBaseURL(server.URL),
1049 )
1050 require.NoError(t, err)
1051
1052 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1053 require.NoError(t, err)
1054
1055 baseTool := WebSearchTool(&WebSearchToolOptions{
1056 MaxUses: 7,
1057 BlockedDomains: []string{"example.com", "test.com"},
1058 UserLocation: &UserLocation{
1059 City: "San Francisco",
1060 Region: "CA",
1061 Country: "US",
1062 Timezone: "America/Los_Angeles",
1063 },
1064 })
1065
1066 data, err := json.Marshal(baseTool)
1067 require.NoError(t, err)
1068
1069 var roundTripped fantasy.ProviderDefinedTool
1070 err = json.Unmarshal(data, &roundTripped)
1071 require.NoError(t, err)
1072
1073 _, err = model.Generate(context.Background(), fantasy.Call{
1074 Prompt: testPrompt(),
1075 Tools: []fantasy.Tool{roundTripped},
1076 })
1077 require.NoError(t, err)
1078
1079 call := awaitAnthropicCall(t, calls)
1080 tools, ok := call.body["tools"].([]any)
1081 require.True(t, ok)
1082 require.Len(t, tools, 1)
1083
1084 tool, ok := tools[0].(map[string]any)
1085 require.True(t, ok)
1086 require.Equal(t, "web_search_20250305", tool["type"])
1087
1088 domains, ok := tool["blocked_domains"].([]any)
1089 require.True(t, ok, "tool should have blocked_domains")
1090 require.Len(t, domains, 2)
1091 require.Equal(t, "example.com", domains[0])
1092 require.Equal(t, "test.com", domains[1])
1093
1094 maxUses, ok := tool["max_uses"].(float64)
1095 require.True(t, ok, "tool should have max_uses")
1096 require.Equal(t, float64(7), maxUses)
1097
1098 userLoc, ok := tool["user_location"].(map[string]any)
1099 require.True(t, ok, "tool should have user_location")
1100 require.Equal(t, "San Francisco", userLoc["city"])
1101 require.Equal(t, "CA", userLoc["region"])
1102 require.Equal(t, "US", userLoc["country"])
1103 require.Equal(t, "America/Los_Angeles", userLoc["timezone"])
1104 require.Equal(t, "approximate", userLoc["type"])
1105 })
1106}
1107
1108func TestAnyToStringSlice(t *testing.T) {
1109 t.Parallel()
1110
1111 t.Run("from string slice", func(t *testing.T) {
1112 t.Parallel()
1113
1114 got := anyToStringSlice([]string{"example.com", ""})
1115 require.Equal(t, []string{"example.com", ""}, got)
1116 })
1117
1118 t.Run("from any slice filters non-strings and empty", func(t *testing.T) {
1119 t.Parallel()
1120
1121 got := anyToStringSlice([]any{"example.com", 123, "", "test.com"})
1122 require.Equal(t, []string{"example.com", "test.com"}, got)
1123 })
1124
1125 t.Run("unsupported type", func(t *testing.T) {
1126 t.Parallel()
1127
1128 got := anyToStringSlice("example.com")
1129 require.Nil(t, got)
1130 })
1131}
1132
1133func TestAnyToInt64(t *testing.T) {
1134 t.Parallel()
1135
1136 tests := []struct {
1137 name string
1138 input any
1139 want int64
1140 wantOK bool
1141 }{
1142 {name: "int64", input: int64(7), want: 7, wantOK: true},
1143 {name: "float64 integer", input: float64(7), want: 7, wantOK: true},
1144 {name: "float32 integer", input: float32(9), want: 9, wantOK: true},
1145 {name: "float64 non-integer", input: float64(7.5), wantOK: false},
1146 {name: "float64 max exact int ok", input: float64(1<<53 - 1), want: 1<<53 - 1, wantOK: true},
1147 {name: "float64 over max exact int", input: float64(1 << 53), wantOK: false},
1148 {name: "json number int", input: json.Number("42"), want: 42, wantOK: true},
1149 {name: "json number float", input: json.Number("4.2"), wantOK: false},
1150 {name: "nan", input: math.NaN(), wantOK: false},
1151 {name: "inf", input: math.Inf(1), wantOK: false},
1152 {name: "uint64 overflow", input: uint64(math.MaxInt64) + 1, wantOK: false},
1153 }
1154
1155 for _, tt := range tests {
1156 t.Run(tt.name, func(t *testing.T) {
1157 got, ok := anyToInt64(tt.input)
1158 require.Equal(t, tt.wantOK, ok)
1159 if tt.wantOK {
1160 require.Equal(t, tt.want, got)
1161 }
1162 })
1163 }
1164}
1165
1166func TestAnyToUserLocation(t *testing.T) {
1167 t.Parallel()
1168
1169 t.Run("pointer passthrough", func(t *testing.T) {
1170 t.Parallel()
1171
1172 input := &UserLocation{City: "San Francisco", Country: "US"}
1173 got := anyToUserLocation(input)
1174 require.Same(t, input, got)
1175 })
1176
1177 t.Run("struct value", func(t *testing.T) {
1178 t.Parallel()
1179
1180 got := anyToUserLocation(UserLocation{City: "San Francisco", Country: "US"})
1181 require.NotNil(t, got)
1182 require.Equal(t, "San Francisco", got.City)
1183 require.Equal(t, "US", got.Country)
1184 })
1185
1186 t.Run("map value", func(t *testing.T) {
1187 t.Parallel()
1188
1189 got := anyToUserLocation(map[string]any{
1190 "city": "San Francisco",
1191 "region": "CA",
1192 "country": "US",
1193 "timezone": "America/Los_Angeles",
1194 "type": "approximate",
1195 })
1196 require.NotNil(t, got)
1197 require.Equal(t, "San Francisco", got.City)
1198 require.Equal(t, "CA", got.Region)
1199 require.Equal(t, "US", got.Country)
1200 require.Equal(t, "America/Los_Angeles", got.Timezone)
1201 })
1202
1203 t.Run("empty map", func(t *testing.T) {
1204 t.Parallel()
1205
1206 got := anyToUserLocation(map[string]any{"type": "approximate"})
1207 require.Nil(t, got)
1208 })
1209
1210 t.Run("unsupported type", func(t *testing.T) {
1211 t.Parallel()
1212
1213 got := anyToUserLocation("San Francisco")
1214 require.Nil(t, got)
1215 })
1216}
1217
1218func TestStream_WebSearchResponse(t *testing.T) {
1219 t.Parallel()
1220
1221 // Build SSE chunks that simulate a web search streaming response.
1222 // The Anthropic SDK accumulates content blocks via
1223 // acc.Accumulate(event). We read the Content and ToolUseID
1224 // directly from struct fields instead of using AsAny(), which
1225 // avoids the SDK's re-marshal limitation that previously dropped
1226 // source data.
1227 webSearchResultContent, _ := json.Marshal([]any{
1228 map[string]any{
1229 "type": "web_search_result",
1230 "url": "https://example.com/ai-news",
1231 "title": "Latest AI News",
1232 "encrypted_content": "encrypted_abc123",
1233 "page_age": "2 hours ago",
1234 },
1235 })
1236
1237 chunks := []string{
1238 // message_start
1239 "event: message_start\n",
1240 `data: {"type":"message_start","message":{"id":"msg_01WebSearch","type":"message","role":"assistant","model":"claude-sonnet-4-20250514","content":[],"stop_reason":null,"usage":{"input_tokens":100,"output_tokens":0}}}` + "\n\n",
1241 // Block 0: server_tool_use
1242 "event: content_block_start\n",
1243 `data: {"type":"content_block_start","index":0,"content_block":{"type":"server_tool_use","id":"srvtoolu_01","name":"web_search","input":{}}}` + "\n\n",
1244 "event: content_block_stop\n",
1245 `data: {"type":"content_block_stop","index":0}` + "\n\n",
1246 // Block 1: web_search_tool_result
1247 "event: content_block_start\n",
1248 `data: {"type":"content_block_start","index":1,"content_block":{"type":"web_search_tool_result","tool_use_id":"srvtoolu_01","content":` + string(webSearchResultContent) + `}}` + "\n\n",
1249 "event: content_block_stop\n",
1250 `data: {"type":"content_block_stop","index":1}` + "\n\n",
1251 // Block 2: text
1252 "event: content_block_start\n",
1253 `data: {"type":"content_block_start","index":2,"content_block":{"type":"text","text":""}}` + "\n\n",
1254 "event: content_block_delta\n",
1255 `data: {"type":"content_block_delta","index":2,"delta":{"type":"text_delta","text":"Here are the results."}}` + "\n\n",
1256 "event: content_block_stop\n",
1257 `data: {"type":"content_block_stop","index":2}` + "\n\n",
1258 // message_stop
1259 "event: message_stop\n",
1260 `data: {"type":"message_stop"}` + "\n\n",
1261 }
1262
1263 server, calls := newAnthropicStreamingServer(chunks)
1264 defer server.Close()
1265
1266 provider, err := New(
1267 WithAPIKey("test-api-key"),
1268 WithBaseURL(server.URL),
1269 )
1270 require.NoError(t, err)
1271
1272 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1273 require.NoError(t, err)
1274
1275 stream, err := model.Stream(context.Background(), fantasy.Call{
1276 Prompt: testPrompt(),
1277 Tools: []fantasy.Tool{
1278 WebSearchTool(nil),
1279 },
1280 })
1281 require.NoError(t, err)
1282
1283 var parts []fantasy.StreamPart
1284 stream(func(part fantasy.StreamPart) bool {
1285 parts = append(parts, part)
1286 return true
1287 })
1288
1289 _ = awaitAnthropicCall(t, calls)
1290
1291 // Collect parts by type for assertions.
1292 var (
1293 toolInputStarts []fantasy.StreamPart
1294 toolCalls []fantasy.StreamPart
1295 toolResults []fantasy.StreamPart
1296 sourceParts []fantasy.StreamPart
1297 textDeltas []fantasy.StreamPart
1298 )
1299 for _, p := range parts {
1300 switch p.Type {
1301 case fantasy.StreamPartTypeToolInputStart:
1302 toolInputStarts = append(toolInputStarts, p)
1303 case fantasy.StreamPartTypeToolCall:
1304 toolCalls = append(toolCalls, p)
1305 case fantasy.StreamPartTypeToolResult:
1306 toolResults = append(toolResults, p)
1307 case fantasy.StreamPartTypeSource:
1308 sourceParts = append(sourceParts, p)
1309 case fantasy.StreamPartTypeTextDelta:
1310 textDeltas = append(textDeltas, p)
1311 }
1312 }
1313
1314 // server_tool_use emits a ToolInputStart with ProviderExecuted.
1315 require.NotEmpty(t, toolInputStarts, "should have a tool input start")
1316 require.True(t, toolInputStarts[0].ProviderExecuted)
1317 require.Equal(t, "web_search", toolInputStarts[0].ToolCallName)
1318
1319 // server_tool_use emits a ToolCall with ProviderExecuted.
1320 require.NotEmpty(t, toolCalls, "should have a tool call")
1321 require.True(t, toolCalls[0].ProviderExecuted)
1322
1323 // web_search_tool_result always emits a ToolResult even when
1324 // the SDK drops source data. The ToolUseID comes from the raw
1325 // union field as a fallback.
1326 require.NotEmpty(t, toolResults, "should have a tool result")
1327 require.True(t, toolResults[0].ProviderExecuted)
1328 require.Equal(t, "web_search", toolResults[0].ToolCallName)
1329 require.Equal(t, "srvtoolu_01", toolResults[0].ID,
1330 "tool result ID should match the tool_use_id")
1331
1332 // Source parts are now correctly emitted by reading struct fields
1333 // directly instead of using AsAny().
1334 require.Len(t, sourceParts, 1)
1335 require.Equal(t, "https://example.com/ai-news", sourceParts[0].URL)
1336 require.Equal(t, "Latest AI News", sourceParts[0].Title)
1337 require.Equal(t, fantasy.SourceTypeURL, sourceParts[0].SourceType)
1338
1339 // Text block emits a text delta.
1340 require.NotEmpty(t, textDeltas, "should have text deltas")
1341 require.Equal(t, "Here are the results.", textDeltas[0].Delta)
1342}
1343
1344func TestGenerate_ToolChoiceNone(t *testing.T) {
1345 t.Parallel()
1346
1347 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1348 defer server.Close()
1349
1350 provider, err := New(
1351 WithAPIKey("test-api-key"),
1352 WithBaseURL(server.URL),
1353 )
1354 require.NoError(t, err)
1355
1356 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1357 require.NoError(t, err)
1358
1359 toolChoiceNone := fantasy.ToolChoiceNone
1360 _, err = model.Generate(context.Background(), fantasy.Call{
1361 Prompt: testPrompt(),
1362 Tools: []fantasy.Tool{
1363 WebSearchTool(nil),
1364 },
1365 ToolChoice: &toolChoiceNone,
1366 })
1367 require.NoError(t, err)
1368
1369 call := awaitAnthropicCall(t, calls)
1370 toolChoice, ok := call.body["tool_choice"].(map[string]any)
1371 require.True(t, ok, "request body should have tool_choice")
1372 require.Equal(t, "none", toolChoice["type"], "tool_choice should be 'none'")
1373}
1374
1375// --- Computer Use Tests ---
1376
1377// jsonRoundTripTool simulates a JSON round-trip on a
1378// ProviderDefinedTool so that its Args map contains float64
1379// values (as json.Unmarshal produces) rather than the int64
1380// values that NewComputerUseTool stores directly. The
1381// production toBetaTools code asserts float64.
1382func jsonRoundTripTool(t *testing.T, tool fantasy.ExecutableProviderTool) fantasy.ProviderDefinedTool {
1383 t.Helper()
1384 pdt := tool.Definition()
1385 data, err := json.Marshal(pdt.Args)
1386 require.NoError(t, err)
1387 var args map[string]any
1388 require.NoError(t, json.Unmarshal(data, &args))
1389 pdt.Args = args
1390 return pdt
1391}
1392
1393func TestNewComputerUseTool(t *testing.T) {
1394 t.Parallel()
1395
1396 t.Run("creates tool with correct ID and name", func(t *testing.T) {
1397 t.Parallel()
1398 tool := NewComputerUseTool(ComputerUseToolOptions{
1399 DisplayWidthPx: 1920,
1400 DisplayHeightPx: 1080,
1401 ToolVersion: ComputerUse20250124,
1402 }, noopComputerRun).Definition()
1403 require.Equal(t, "anthropic.computer", tool.ID)
1404 require.Equal(t, "computer", tool.Name)
1405 require.Equal(t, int64(1920), tool.Args["display_width_px"])
1406 require.Equal(t, int64(1080), tool.Args["display_height_px"])
1407 require.Equal(t, string(ComputerUse20250124), tool.Args["tool_version"])
1408 })
1409
1410 t.Run("includes optional fields when set", func(t *testing.T) {
1411 t.Parallel()
1412 displayNum := int64(1)
1413 enableZoom := true
1414 tool := NewComputerUseTool(ComputerUseToolOptions{
1415 DisplayWidthPx: 1024,
1416 DisplayHeightPx: 768,
1417 DisplayNumber: &displayNum,
1418 EnableZoom: &enableZoom,
1419 ToolVersion: ComputerUse20251124,
1420 CacheControl: &CacheControl{Type: "ephemeral"},
1421 }, noopComputerRun).Definition()
1422 require.Equal(t, int64(1), tool.Args["display_number"])
1423 require.Equal(t, true, tool.Args["enable_zoom"])
1424 require.NotNil(t, tool.Args["cache_control"])
1425 })
1426
1427 t.Run("omits optional fields when nil", func(t *testing.T) {
1428 t.Parallel()
1429 tool := NewComputerUseTool(ComputerUseToolOptions{
1430 DisplayWidthPx: 1920,
1431 DisplayHeightPx: 1080,
1432 ToolVersion: ComputerUse20250124,
1433 }, noopComputerRun).Definition()
1434 _, hasDisplayNum := tool.Args["display_number"]
1435 _, hasEnableZoom := tool.Args["enable_zoom"]
1436 _, hasCacheControl := tool.Args["cache_control"]
1437 require.False(t, hasDisplayNum)
1438 require.False(t, hasEnableZoom)
1439 require.False(t, hasCacheControl)
1440 })
1441}
1442
1443func TestIsComputerUseTool(t *testing.T) {
1444 t.Parallel()
1445
1446 t.Run("returns true for computer use tool", func(t *testing.T) {
1447 t.Parallel()
1448 tool := NewComputerUseTool(ComputerUseToolOptions{
1449 DisplayWidthPx: 1920,
1450 DisplayHeightPx: 1080,
1451 ToolVersion: ComputerUse20250124,
1452 }, noopComputerRun)
1453 require.True(t, IsComputerUseTool(tool.Definition()))
1454 })
1455
1456 t.Run("returns false for function tool", func(t *testing.T) {
1457 t.Parallel()
1458 tool := fantasy.FunctionTool{
1459 Name: "test",
1460 Description: "test tool",
1461 }
1462 require.False(t, IsComputerUseTool(tool))
1463 })
1464
1465 t.Run("returns false for other provider defined tool", func(t *testing.T) {
1466 t.Parallel()
1467 tool := fantasy.ProviderDefinedTool{
1468 ID: "other.tool",
1469 Name: "other",
1470 }
1471 require.False(t, IsComputerUseTool(tool))
1472 })
1473}
1474
1475func TestNeedsBetaAPI(t *testing.T) {
1476 t.Parallel()
1477
1478 lm := languageModel{options: options{}}
1479
1480 t.Run("returns false for empty tools", func(t *testing.T) {
1481 t.Parallel()
1482 _, _, _, betaFlags := lm.toTools(nil, nil, false)
1483 require.Empty(t, betaFlags)
1484 _, _, _, betaFlags = lm.toTools([]fantasy.Tool{}, nil, false)
1485 require.Empty(t, betaFlags)
1486 })
1487
1488 t.Run("returns false for only function tools", func(t *testing.T) {
1489 t.Parallel()
1490 tools := []fantasy.Tool{
1491 fantasy.FunctionTool{Name: "test"},
1492 }
1493 _, _, _, betaFlags := lm.toTools(tools, nil, false)
1494 require.Empty(t, betaFlags)
1495 })
1496
1497 t.Run("returns beta flags when computer use tool present", func(t *testing.T) {
1498 t.Parallel()
1499 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1500 DisplayWidthPx: 1920,
1501 DisplayHeightPx: 1080,
1502 ToolVersion: ComputerUse20250124,
1503 }, noopComputerRun))
1504 tools := []fantasy.Tool{
1505 fantasy.FunctionTool{Name: "test"},
1506 cuTool,
1507 }
1508 _, _, _, betaFlags := lm.toTools(tools, nil, false)
1509 require.NotEmpty(t, betaFlags)
1510 })
1511}
1512
1513func TestComputerUseToolJSON(t *testing.T) {
1514 t.Parallel()
1515
1516 t.Run("builds JSON for version 20250124", func(t *testing.T) {
1517 t.Parallel()
1518 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1519 DisplayWidthPx: 1920,
1520 DisplayHeightPx: 1080,
1521 ToolVersion: ComputerUse20250124,
1522 }, noopComputerRun))
1523 data, err := computerUseToolJSON(cuTool)
1524 require.NoError(t, err)
1525 var m map[string]any
1526 require.NoError(t, json.Unmarshal(data, &m))
1527 require.Equal(t, "computer_20250124", m["type"])
1528 require.Equal(t, "computer", m["name"])
1529 require.InDelta(t, 1920, m["display_width_px"], 0)
1530 require.InDelta(t, 1080, m["display_height_px"], 0)
1531 })
1532
1533 t.Run("builds JSON for version 20251124 with enable_zoom", func(t *testing.T) {
1534 t.Parallel()
1535 enableZoom := true
1536 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1537 DisplayWidthPx: 1024,
1538 DisplayHeightPx: 768,
1539 EnableZoom: &enableZoom,
1540 ToolVersion: ComputerUse20251124,
1541 }, noopComputerRun))
1542 data, err := computerUseToolJSON(cuTool)
1543 require.NoError(t, err)
1544 var m map[string]any
1545 require.NoError(t, json.Unmarshal(data, &m))
1546 require.Equal(t, "computer_20251124", m["type"])
1547 require.Equal(t, true, m["enable_zoom"])
1548 })
1549
1550 t.Run("handles int64 args without JSON round-trip", func(t *testing.T) {
1551 t.Parallel()
1552 // Direct construction stores int64 values.
1553 cuTool := NewComputerUseTool(ComputerUseToolOptions{
1554 DisplayWidthPx: 1920,
1555 DisplayHeightPx: 1080,
1556 ToolVersion: ComputerUse20250124,
1557 }, noopComputerRun)
1558 data, err := computerUseToolJSON(cuTool.Definition())
1559 require.NoError(t, err)
1560 var m map[string]any
1561 require.NoError(t, json.Unmarshal(data, &m))
1562 require.InDelta(t, 1920, m["display_width_px"], 0)
1563 })
1564
1565 t.Run("returns error when version is missing", func(t *testing.T) {
1566 t.Parallel()
1567 pdt := fantasy.ProviderDefinedTool{
1568 ID: "anthropic.computer",
1569 Name: "computer",
1570 Args: map[string]any{
1571 "display_width_px": float64(1920),
1572 "display_height_px": float64(1080),
1573 },
1574 }
1575 _, err := computerUseToolJSON(pdt)
1576 require.Error(t, err)
1577 require.Contains(t, err.Error(), "tool_version arg is missing") })
1578
1579 t.Run("returns error for unsupported version", func(t *testing.T) {
1580 t.Parallel()
1581 pdt := fantasy.ProviderDefinedTool{
1582 ID: "anthropic.computer",
1583 Name: "computer",
1584 Args: map[string]any{
1585 "display_width_px": float64(1920),
1586 "display_height_px": float64(1080),
1587 "tool_version": "computer_99991231",
1588 },
1589 }
1590 _, err := computerUseToolJSON(pdt)
1591 require.Error(t, err)
1592 require.Contains(t, err.Error(), "unsupported")
1593 })
1594}
1595
1596func TestParseComputerUseInput_CoordinateValidation(t *testing.T) {
1597 t.Parallel()
1598
1599 t.Run("rejects coordinate with 1 element", func(t *testing.T) {
1600 t.Parallel()
1601 _, err := ParseComputerUseInput(`{"action":"left_click","coordinate":[100]}`)
1602 require.Error(t, err)
1603 require.Contains(t, err.Error(), "coordinate")
1604 })
1605
1606 t.Run("rejects coordinate with 3 elements", func(t *testing.T) {
1607 t.Parallel()
1608 _, err := ParseComputerUseInput(`{"action":"left_click","coordinate":[100,200,300]}`)
1609 require.Error(t, err)
1610 require.Contains(t, err.Error(), "coordinate")
1611 })
1612
1613 t.Run("rejects start_coordinate with 1 element", func(t *testing.T) {
1614 t.Parallel()
1615 _, err := ParseComputerUseInput(`{"action":"left_click_drag","coordinate":[100,200],"start_coordinate":[50]}`)
1616 require.Error(t, err)
1617 require.Contains(t, err.Error(), "start_coordinate")
1618 })
1619
1620 t.Run("rejects region with 3 elements", func(t *testing.T) {
1621 t.Parallel()
1622 _, err := ParseComputerUseInput(`{"action":"zoom","region":[10,20,30]}`)
1623 require.Error(t, err)
1624 require.Contains(t, err.Error(), "region")
1625 })
1626
1627 t.Run("accepts valid coordinate", func(t *testing.T) {
1628 t.Parallel()
1629 result, err := ParseComputerUseInput(`{"action":"left_click","coordinate":[100,200]}`)
1630 require.NoError(t, err)
1631 require.Equal(t, [2]int64{100, 200}, result.Coordinate)
1632 })
1633
1634 t.Run("accepts absent optional arrays", func(t *testing.T) {
1635 t.Parallel()
1636 result, err := ParseComputerUseInput(`{"action":"screenshot"}`)
1637 require.NoError(t, err)
1638 require.Equal(t, ActionScreenshot, result.Action)
1639 })
1640}
1641
1642func TestToTools_RawJSON(t *testing.T) {
1643 t.Parallel()
1644
1645 lm := languageModel{options: options{}}
1646
1647 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1648 DisplayWidthPx: 1920,
1649 DisplayHeightPx: 1080,
1650 ToolVersion: ComputerUse20250124,
1651 }, noopComputerRun))
1652
1653 tools := []fantasy.Tool{
1654 fantasy.FunctionTool{
1655 Name: "weather",
1656 Description: "Get weather",
1657 InputSchema: map[string]any{
1658 "properties": map[string]any{
1659 "location": map[string]any{"type": "string"},
1660 },
1661 "required": []string{"location"},
1662 },
1663 },
1664 WebSearchTool(nil),
1665 cuTool,
1666 }
1667
1668 rawTools, toolChoice, warnings, betaFlags := lm.toTools(tools, nil, false)
1669
1670 require.Len(t, rawTools, 3)
1671 require.Nil(t, toolChoice)
1672 require.Empty(t, warnings)
1673 require.NotEmpty(t, betaFlags)
1674
1675 // Verify each raw tool is valid JSON.
1676 for i, raw := range rawTools {
1677 var m map[string]any
1678 require.NoError(t, json.Unmarshal(raw, &m), "tool %d should be valid JSON", i)
1679 }
1680
1681 // Check function tool.
1682 var funcTool map[string]any
1683 require.NoError(t, json.Unmarshal(rawTools[0], &funcTool))
1684 require.Equal(t, "weather", funcTool["name"])
1685
1686 // Check web search tool.
1687 var webTool map[string]any
1688 require.NoError(t, json.Unmarshal(rawTools[1], &webTool))
1689 require.Equal(t, "web_search_20250305", webTool["type"])
1690
1691 // Check computer use tool.
1692 var cuToolJSON map[string]any
1693 require.NoError(t, json.Unmarshal(rawTools[2], &cuToolJSON))
1694 require.Equal(t, "computer_20250124", cuToolJSON["type"])
1695 require.Equal(t, "computer", cuToolJSON["name"])
1696}
1697
1698func TestGenerate_BetaAPI(t *testing.T) {
1699 t.Parallel()
1700
1701 t.Run("sends beta header for computer use", func(t *testing.T) {
1702 t.Parallel()
1703
1704 var capturedHeaders http.Header
1705 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1706 capturedHeaders = r.Header.Clone()
1707 w.Header().Set("Content-Type", "application/json")
1708 _ = json.NewEncoder(w).Encode(mockAnthropicGenerateResponse())
1709 }))
1710 defer server.Close()
1711
1712 provider, err := New(
1713 WithAPIKey("test-api-key"),
1714 WithBaseURL(server.URL),
1715 )
1716 require.NoError(t, err)
1717
1718 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1719 require.NoError(t, err)
1720
1721 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1722 DisplayWidthPx: 1920,
1723 DisplayHeightPx: 1080,
1724 ToolVersion: ComputerUse20250124,
1725 }, noopComputerRun))
1726
1727 _, err = model.Generate(context.Background(), fantasy.Call{
1728 Prompt: testPrompt(),
1729 Tools: []fantasy.Tool{cuTool},
1730 })
1731 require.NoError(t, err)
1732 require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-01-24")
1733 })
1734
1735 t.Run("sends beta header for computer use 20251124", func(t *testing.T) {
1736 t.Parallel()
1737
1738 var capturedHeaders http.Header
1739 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1740 capturedHeaders = r.Header.Clone()
1741 w.Header().Set("Content-Type", "application/json")
1742 _ = json.NewEncoder(w).Encode(mockAnthropicGenerateResponse())
1743 }))
1744 defer server.Close()
1745
1746 provider, err := New(
1747 WithAPIKey("test-api-key"),
1748 WithBaseURL(server.URL),
1749 )
1750 require.NoError(t, err)
1751
1752 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1753 require.NoError(t, err)
1754
1755 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1756 DisplayWidthPx: 1920,
1757 DisplayHeightPx: 1080,
1758 ToolVersion: ComputerUse20251124,
1759 }, noopComputerRun))
1760
1761 _, err = model.Generate(context.Background(), fantasy.Call{
1762 Prompt: testPrompt(),
1763 Tools: []fantasy.Tool{cuTool},
1764 })
1765 require.NoError(t, err)
1766 require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-11-24")
1767 })
1768
1769 t.Run("returns tool use from beta response", func(t *testing.T) {
1770 t.Parallel()
1771
1772 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1773 w.Header().Set("Content-Type", "application/json")
1774 _ = json.NewEncoder(w).Encode(map[string]any{
1775 "id": "msg_01Test",
1776 "type": "message",
1777 "role": "assistant",
1778 "model": "claude-sonnet-4-20250514",
1779 "content": []any{
1780 map[string]any{
1781 "type": "tool_use",
1782 "id": "toolu_01",
1783 "name": "computer",
1784 "input": map[string]any{"action": "screenshot"},
1785 },
1786 },
1787 "stop_reason": "tool_use",
1788 "usage": map[string]any{
1789 "input_tokens": 10,
1790 "output_tokens": 5,
1791 "cache_creation": map[string]any{
1792 "ephemeral_1h_input_tokens": 0,
1793 "ephemeral_5m_input_tokens": 0,
1794 },
1795 "cache_creation_input_tokens": 0,
1796 "cache_read_input_tokens": 0,
1797 "server_tool_use": map[string]any{
1798 "web_search_requests": 0,
1799 },
1800 "service_tier": "standard",
1801 },
1802 })
1803 }))
1804 defer server.Close()
1805
1806 provider, err := New(
1807 WithAPIKey("test-api-key"),
1808 WithBaseURL(server.URL),
1809 )
1810 require.NoError(t, err)
1811
1812 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1813 require.NoError(t, err)
1814
1815 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1816 DisplayWidthPx: 1920,
1817 DisplayHeightPx: 1080,
1818 ToolVersion: ComputerUse20250124,
1819 }, noopComputerRun))
1820
1821 resp, err := model.Generate(context.Background(), fantasy.Call{
1822 Prompt: testPrompt(),
1823 Tools: []fantasy.Tool{cuTool},
1824 })
1825 require.NoError(t, err)
1826
1827 toolCalls := resp.Content.ToolCalls()
1828 require.Len(t, toolCalls, 1)
1829 require.Equal(t, "computer", toolCalls[0].ToolName)
1830 require.Equal(t, "toolu_01", toolCalls[0].ToolCallID)
1831 require.Contains(t, toolCalls[0].Input, "screenshot")
1832 require.Equal(t, fantasy.FinishReasonToolCalls, resp.FinishReason)
1833
1834 // Verify typed parsing works on the tool call input.
1835 parsed, err := ParseComputerUseInput(toolCalls[0].Input)
1836 require.NoError(t, err)
1837 require.Equal(t, ActionScreenshot, parsed.Action)
1838 })
1839}
1840
1841func TestBedrockRegionPrefixing(t *testing.T) {
1842 t.Run("uses us prefix when region is unset", func(t *testing.T) {
1843 t.Setenv("AWS_REGION", "")
1844 t.Setenv("AWS_DEFAULT_REGION", "")
1845 require.Equal(t, "us.anthropic.claude-sonnet-4-5-20250929-v1:0", bedrockPrefixModelWithRegion("anthropic.claude-sonnet-4-5-20250929-v1:0"))
1846 })
1847
1848 t.Run("uses AWS_DEFAULT_REGION when AWS_REGION is unset", func(t *testing.T) {
1849 t.Setenv("AWS_REGION", "")
1850 t.Setenv("AWS_DEFAULT_REGION", "eu-west-1")
1851 require.Equal(t, "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", bedrockPrefixModelWithRegion("anthropic.claude-sonnet-4-5-20250929-v1:0"))
1852 })
1853
1854 t.Run("keeps already-prefixed model IDs", func(t *testing.T) {
1855 t.Setenv("AWS_REGION", "us-east-1")
1856 t.Setenv("AWS_DEFAULT_REGION", "")
1857 require.Equal(t, "global.anthropic.claude-sonnet-4-5-20250929-v1:0", bedrockPrefixModelWithRegion("global.anthropic.claude-sonnet-4-5-20250929-v1:0"))
1858 })
1859
1860 t.Run("maps known regions to profile prefixes", func(t *testing.T) {
1861 require.Equal(t, "us", bedrockRegionPrefix("us-west-2"))
1862 require.Equal(t, "us", bedrockRegionPrefix("ca-central-1"))
1863 require.Equal(t, "eu", bedrockRegionPrefix("eu-central-1"))
1864 require.Equal(t, "jp", bedrockRegionPrefix("ap-northeast-1"))
1865 require.Equal(t, "au", bedrockRegionPrefix("ap-southeast-2"))
1866 require.Equal(t, "global", bedrockRegionPrefix("ap-south-1"))
1867 require.Equal(t, "global", bedrockRegionPrefix("sa-east-1"))
1868 })
1869}
1870
1871func TestStream_BetaAPI(t *testing.T) {
1872 t.Parallel()
1873
1874 t.Run("streams via beta API for computer use", func(t *testing.T) {
1875 t.Parallel()
1876
1877 var capturedHeaders http.Header
1878 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1879 capturedHeaders = r.Header.Clone()
1880 w.Header().Set("Content-Type", "text/event-stream")
1881 w.Header().Set("Cache-Control", "no-cache")
1882 w.WriteHeader(http.StatusOK)
1883 chunks := []string{
1884 "event: message_start\n",
1885 "data: {\"type\":\"message_start\",\"message\":{}}\n\n",
1886 "event: message_stop\n",
1887 "data: {\"type\":\"message_stop\"}\n\n",
1888 }
1889 for _, chunk := range chunks {
1890 _, _ = fmt.Fprint(w, chunk)
1891 if flusher, ok := w.(http.Flusher); ok {
1892 flusher.Flush()
1893 }
1894 }
1895 }))
1896 defer server.Close()
1897
1898 provider, err := New(
1899 WithAPIKey("test-api-key"),
1900 WithBaseURL(server.URL),
1901 )
1902 require.NoError(t, err)
1903
1904 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1905 require.NoError(t, err)
1906
1907 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1908 DisplayWidthPx: 1920,
1909 DisplayHeightPx: 1080,
1910 ToolVersion: ComputerUse20250124,
1911 }, noopComputerRun))
1912
1913 stream, err := model.Stream(context.Background(), fantasy.Call{
1914 Prompt: testPrompt(),
1915 Tools: []fantasy.Tool{cuTool},
1916 })
1917 require.NoError(t, err)
1918
1919 stream(func(fantasy.StreamPart) bool { return true })
1920
1921 require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-01-24")
1922 })
1923
1924 t.Run("streams via beta API for computer use 20251124", func(t *testing.T) {
1925 t.Parallel()
1926
1927 var capturedHeaders http.Header
1928 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1929 capturedHeaders = r.Header.Clone()
1930 w.Header().Set("Content-Type", "text/event-stream")
1931 w.Header().Set("Cache-Control", "no-cache")
1932 w.WriteHeader(http.StatusOK)
1933 chunks := []string{
1934 "event: message_start\n",
1935 "data: {\"type\":\"message_start\",\"message\":{}}\n\n",
1936 "event: message_stop\n",
1937 "data: {\"type\":\"message_stop\"}\n\n",
1938 }
1939 for _, chunk := range chunks {
1940 _, _ = fmt.Fprint(w, chunk)
1941 if flusher, ok := w.(http.Flusher); ok {
1942 flusher.Flush()
1943 }
1944 }
1945 }))
1946 defer server.Close()
1947
1948 provider, err := New(
1949 WithAPIKey("test-api-key"),
1950 WithBaseURL(server.URL),
1951 )
1952 require.NoError(t, err)
1953
1954 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1955 require.NoError(t, err)
1956
1957 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1958 DisplayWidthPx: 1920,
1959 DisplayHeightPx: 1080,
1960 ToolVersion: ComputerUse20251124,
1961 }, noopComputerRun))
1962
1963 stream, err := model.Stream(context.Background(), fantasy.Call{
1964 Prompt: testPrompt(),
1965 Tools: []fantasy.Tool{cuTool},
1966 })
1967 require.NoError(t, err)
1968
1969 stream(func(fantasy.StreamPart) bool { return true })
1970
1971 require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-11-24")
1972 })
1973}
1974
1975// TestGenerate_ComputerUseTool runs a multi-turn computer use session
1976// via model.Generate, passing the ExecutableProviderTool directly into
1977// Call.Tools (no .Definition(), no jsonRoundTripTool). The mock server
1978// walks through a scripted sequence of actions — screenshot, click,
1979// type, key, scroll — then finishes with a text reply. Each turn the
1980// test parses the tool call, builds a screenshot result, and appends
1981// both to the prompt for the next request.
1982func TestGenerate_ComputerUseTool(t *testing.T) {
1983 t.Parallel()
1984
1985 type actionStep struct {
1986 input map[string]any
1987 want ComputerUseInput
1988 }
1989 steps := []actionStep{
1990 {
1991 input: map[string]any{"action": "screenshot"},
1992 want: ComputerUseInput{Action: ActionScreenshot},
1993 },
1994 {
1995 input: map[string]any{"action": "left_click", "coordinate": []any{100, 200}},
1996 want: ComputerUseInput{Action: ActionLeftClick, Coordinate: [2]int64{100, 200}},
1997 },
1998 {
1999 input: map[string]any{"action": "type", "text": "hello world"},
2000 want: ComputerUseInput{Action: ActionType, Text: "hello world"},
2001 },
2002 {
2003 input: map[string]any{"action": "key", "text": "Return"},
2004 want: ComputerUseInput{Action: ActionKey, Text: "Return"},
2005 },
2006 {
2007 input: map[string]any{
2008 "action": "scroll",
2009 "coordinate": []any{500, 300},
2010 "scroll_direction": "down",
2011 "scroll_amount": 3,
2012 },
2013 want: ComputerUseInput{
2014 Action: ActionScroll,
2015 Coordinate: [2]int64{500, 300},
2016 ScrollDirection: "down",
2017 ScrollAmount: 3,
2018 },
2019 },
2020 {
2021 input: map[string]any{"action": "screenshot"},
2022 want: ComputerUseInput{Action: ActionScreenshot},
2023 },
2024 }
2025
2026 var (
2027 requestIdx int
2028 betaHeaders []string
2029 )
2030 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2031 betaHeaders = append(betaHeaders, r.Header.Get("Anthropic-Beta"))
2032 idx := requestIdx
2033 requestIdx++
2034
2035 w.Header().Set("Content-Type", "application/json")
2036 if idx < len(steps) {
2037 _ = json.NewEncoder(w).Encode(map[string]any{
2038 "id": fmt.Sprintf("msg_%02d", idx),
2039 "type": "message",
2040 "role": "assistant",
2041 "model": "claude-sonnet-4-20250514",
2042 "content": []any{map[string]any{
2043 "type": "tool_use",
2044 "id": fmt.Sprintf("toolu_%02d", idx),
2045 "name": "computer",
2046 "input": steps[idx].input,
2047 }},
2048 "stop_reason": "tool_use",
2049 "usage": map[string]any{"input_tokens": 10, "output_tokens": 5},
2050 })
2051 return
2052 }
2053 _ = json.NewEncoder(w).Encode(map[string]any{
2054 "id": "msg_final",
2055 "type": "message",
2056 "role": "assistant",
2057 "model": "claude-sonnet-4-20250514",
2058 "content": []any{map[string]any{
2059 "type": "text",
2060 "text": "Done! I have completed all the requested actions.",
2061 }},
2062 "stop_reason": "end_turn",
2063 "usage": map[string]any{"input_tokens": 10, "output_tokens": 15},
2064 })
2065 }))
2066 defer server.Close()
2067
2068 provider, err := New(WithAPIKey("test-api-key"), WithBaseURL(server.URL))
2069 require.NoError(t, err)
2070
2071 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
2072 require.NoError(t, err)
2073
2074 // Pass the ExecutableProviderTool directly — the whole point is
2075 // to verify that the Tool interface works without unwrapping.
2076 cuTool := NewComputerUseTool(ComputerUseToolOptions{
2077 DisplayWidthPx: 1920,
2078 DisplayHeightPx: 1080,
2079 ToolVersion: ComputerUse20250124,
2080 }, noopComputerRun)
2081
2082 var got []ComputerUseInput
2083 prompt := testPrompt()
2084 fakePNG := []byte("fake-screenshot-png")
2085
2086 for turn := 0; turn <= len(steps); turn++ {
2087 resp, err := model.Generate(context.Background(), fantasy.Call{
2088 Prompt: prompt,
2089 Tools: []fantasy.Tool{cuTool},
2090 })
2091 require.NoError(t, err, "turn %d", turn)
2092
2093 if resp.FinishReason != fantasy.FinishReasonToolCalls {
2094 require.Equal(t, fantasy.FinishReasonStop, resp.FinishReason)
2095 require.Contains(t, resp.Content.Text(), "Done")
2096 break
2097 }
2098
2099 toolCalls := resp.Content.ToolCalls()
2100 require.Len(t, toolCalls, 1, "turn %d", turn)
2101 require.Equal(t, "computer", toolCalls[0].ToolName, "turn %d", turn)
2102
2103 parsed, err := ParseComputerUseInput(toolCalls[0].Input)
2104 require.NoError(t, err, "turn %d", turn)
2105 got = append(got, parsed)
2106
2107 // Build the next prompt: append the assistant tool-call turn
2108 // and the user screenshot-result turn.
2109 prompt = append(prompt,
2110 fantasy.Message{
2111 Role: fantasy.MessageRoleAssistant,
2112 Content: []fantasy.MessagePart{
2113 fantasy.ToolCallPart{
2114 ToolCallID: toolCalls[0].ToolCallID,
2115 ToolName: toolCalls[0].ToolName,
2116 Input: toolCalls[0].Input,
2117 },
2118 },
2119 },
2120 fantasy.Message{
2121 // Use MessageRoleTool for tool results — this matches
2122 // what the agent loop produces.
2123 Role: fantasy.MessageRoleTool,
2124 Content: []fantasy.MessagePart{
2125 NewComputerUseScreenshotResult(toolCalls[0].ToolCallID, fakePNG),
2126 },
2127 },
2128 )
2129 }
2130
2131 // Every scripted action was received and parsed correctly.
2132 require.Len(t, got, len(steps))
2133 for i, step := range steps {
2134 require.Equal(t, step.want.Action, got[i].Action, "step %d", i)
2135 require.Equal(t, step.want.Coordinate, got[i].Coordinate, "step %d", i)
2136 require.Equal(t, step.want.Text, got[i].Text, "step %d", i)
2137 require.Equal(t, step.want.ScrollDirection, got[i].ScrollDirection, "step %d", i)
2138 require.Equal(t, step.want.ScrollAmount, got[i].ScrollAmount, "step %d", i)
2139 }
2140
2141 // Beta header was sent on every request.
2142 require.Len(t, betaHeaders, len(steps)+1)
2143 for i, h := range betaHeaders {
2144 require.Contains(t, h, "computer-use-2025-01-24", "request %d", i)
2145 }
2146}
2147
2148// TestStream_ComputerUseTool runs a multi-turn computer use session
2149// via model.Stream, verifying that the ExecutableProviderTool works
2150// through the streaming path end-to-end.
2151func TestStream_ComputerUseTool(t *testing.T) {
2152 t.Parallel()
2153
2154 type streamStep struct {
2155 input map[string]any
2156 wantAction ComputerAction
2157 }
2158 steps := []streamStep{
2159 {input: map[string]any{"action": "screenshot"}, wantAction: ActionScreenshot},
2160 {input: map[string]any{"action": "left_click", "coordinate": []any{150, 250}}, wantAction: ActionLeftClick},
2161 {input: map[string]any{"action": "type", "text": "search query"}, wantAction: ActionType},
2162 }
2163
2164 var (
2165 requestIdx int
2166 betaHeaders []string
2167 )
2168
2169 // streamToolUseChunks returns SSE chunks for a single
2170 // computer-use tool_use content block.
2171 streamToolUseChunks := func(id string, input map[string]any) []string {
2172 inputJSON, _ := json.Marshal(input)
2173 escaped := strings.ReplaceAll(string(inputJSON), `"`, `\"`)
2174 return []string{
2175 "event: message_start\n",
2176 `data: {"type":"message_start","message":{"id":"` + id + `","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4-20250514","stop_reason":null,"usage":{"input_tokens":10,"output_tokens":0}}}` + "\n\n",
2177 "event: content_block_start\n",
2178 `data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"` + id + `","name":"computer","input":{}}}` + "\n\n",
2179 "event: content_block_delta\n",
2180 `data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"` + escaped + `"}}` + "\n\n",
2181 "event: content_block_stop\n",
2182 `data: {"type":"content_block_stop","index":0}` + "\n\n",
2183 "event: message_delta\n",
2184 `data: {"type":"message_delta","delta":{"stop_reason":"tool_use"},"usage":{"output_tokens":5}}` + "\n\n",
2185 "event: message_stop\n",
2186 `data: {"type":"message_stop"}` + "\n\n",
2187 }
2188 }
2189
2190 streamTextChunks := func() []string {
2191 return []string{
2192 "event: message_start\n",
2193 `data: {"type":"message_start","message":{"id":"msg_final","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4-20250514","stop_reason":null,"usage":{"input_tokens":10,"output_tokens":0}}}` + "\n\n",
2194 "event: content_block_start\n",
2195 `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}` + "\n\n",
2196 "event: content_block_delta\n",
2197 `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"All done."}}` + "\n\n",
2198 "event: content_block_stop\n",
2199 `data: {"type":"content_block_stop","index":0}` + "\n\n",
2200 "event: message_delta\n",
2201 `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":10}}` + "\n\n",
2202 "event: message_stop\n",
2203 `data: {"type":"message_stop"}` + "\n\n",
2204 }
2205 }
2206
2207 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2208 betaHeaders = append(betaHeaders, r.Header.Get("Anthropic-Beta"))
2209 idx := requestIdx
2210 requestIdx++
2211
2212 w.Header().Set("Content-Type", "text/event-stream")
2213 w.Header().Set("Cache-Control", "no-cache")
2214 w.WriteHeader(http.StatusOK)
2215
2216 var chunks []string
2217 if idx < len(steps) {
2218 chunks = streamToolUseChunks(
2219 fmt.Sprintf("toolu_%02d", idx),
2220 steps[idx].input,
2221 )
2222 } else {
2223 chunks = streamTextChunks()
2224 }
2225 for _, chunk := range chunks {
2226 _, _ = fmt.Fprint(w, chunk)
2227 if f, ok := w.(http.Flusher); ok {
2228 f.Flush()
2229 }
2230 }
2231 }))
2232 defer server.Close()
2233
2234 provider, err := New(WithAPIKey("test-api-key"), WithBaseURL(server.URL))
2235 require.NoError(t, err)
2236
2237 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
2238 require.NoError(t, err)
2239
2240 cuTool := NewComputerUseTool(ComputerUseToolOptions{
2241 DisplayWidthPx: 1920,
2242 DisplayHeightPx: 1080,
2243 ToolVersion: ComputerUse20250124,
2244 }, noopComputerRun)
2245
2246 var gotActions []ComputerAction
2247 prompt := testPrompt()
2248 fakePNG := []byte("fake-screenshot-png")
2249
2250 for turn := 0; turn <= len(steps); turn++ {
2251 stream, err := model.Stream(context.Background(), fantasy.Call{
2252 Prompt: prompt,
2253 Tools: []fantasy.Tool{cuTool},
2254 })
2255 require.NoError(t, err, "turn %d", turn)
2256
2257 var (
2258 toolCallName string
2259 toolCallID string
2260 toolCallInput string
2261 finishReason fantasy.FinishReason
2262 gotText string
2263 )
2264 stream(func(part fantasy.StreamPart) bool {
2265 switch part.Type {
2266 case fantasy.StreamPartTypeToolCall:
2267 toolCallName = part.ToolCallName
2268 toolCallID = part.ID
2269 toolCallInput = part.ToolCallInput
2270 case fantasy.StreamPartTypeFinish:
2271 finishReason = part.FinishReason
2272 case fantasy.StreamPartTypeTextDelta:
2273 gotText += part.Delta
2274 }
2275 return true
2276 })
2277
2278 if finishReason != fantasy.FinishReasonToolCalls {
2279 require.Contains(t, gotText, "All done")
2280 break
2281 }
2282
2283 require.Equal(t, "computer", toolCallName, "turn %d", turn)
2284
2285 parsed, err := ParseComputerUseInput(toolCallInput)
2286 require.NoError(t, err, "turn %d", turn)
2287 gotActions = append(gotActions, parsed.Action)
2288
2289 prompt = append(prompt,
2290 fantasy.Message{
2291 Role: fantasy.MessageRoleAssistant,
2292 Content: []fantasy.MessagePart{
2293 fantasy.ToolCallPart{
2294 ToolCallID: toolCallID,
2295 ToolName: toolCallName,
2296 Input: toolCallInput,
2297 },
2298 },
2299 },
2300 fantasy.Message{
2301 // Use MessageRoleTool for tool results — this matches
2302 // what the agent loop produces.
2303 Role: fantasy.MessageRoleTool,
2304 Content: []fantasy.MessagePart{
2305 NewComputerUseScreenshotResult(toolCallID, fakePNG),
2306 },
2307 },
2308 )
2309 }
2310
2311 require.Len(t, gotActions, len(steps))
2312 for i, step := range steps {
2313 require.Equal(t, step.wantAction, gotActions[i], "step %d", i)
2314 }
2315
2316 require.Len(t, betaHeaders, len(steps)+1)
2317 for i, h := range betaHeaders {
2318 require.Contains(t, h, "computer-use-2025-01-24", "request %d", i)
2319 }
2320}