1package anthropic
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "math"
9 "net/http"
10 "net/http/httptest"
11 "strings"
12 "testing"
13 "time"
14
15 "charm.land/fantasy"
16 "github.com/charmbracelet/anthropic-sdk-go"
17 "github.com/stretchr/testify/require"
18)
19
20// noopComputerRun is a no-op run function for tests that only need
21// to inspect the tool definition, not execute it.
22var noopComputerRun = func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
23 return fantasy.ToolResponse{}, nil
24}
25
26func TestToPrompt_DropsEmptyMessages(t *testing.T) {
27 t.Parallel()
28
29 t.Run("should drop assistant messages with only reasoning content", func(t *testing.T) {
30 t.Parallel()
31
32 prompt := fantasy.Prompt{
33 {
34 Role: fantasy.MessageRoleUser,
35 Content: []fantasy.MessagePart{
36 fantasy.TextPart{Text: "Hello"},
37 },
38 },
39 {
40 Role: fantasy.MessageRoleAssistant,
41 Content: []fantasy.MessagePart{
42 fantasy.ReasoningPart{
43 Text: "Let me think about this...",
44 ProviderOptions: fantasy.ProviderOptions{
45 Name: &ReasoningOptionMetadata{
46 Signature: "abc123",
47 },
48 },
49 },
50 },
51 },
52 }
53
54 systemBlocks, messages, warnings := toPrompt(prompt, true)
55
56 require.Empty(t, systemBlocks)
57 require.Len(t, messages, 1, "should only have user message, assistant message should be dropped")
58 require.Len(t, warnings, 1)
59 require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
60 require.Contains(t, warnings[0].Message, "dropping empty assistant message")
61 require.Contains(t, warnings[0].Message, "neither user-facing content nor tool calls")
62 })
63
64 t.Run("should drop assistant reasoning when sendReasoning disabled", func(t *testing.T) {
65 t.Parallel()
66
67 prompt := fantasy.Prompt{
68 {
69 Role: fantasy.MessageRoleUser,
70 Content: []fantasy.MessagePart{
71 fantasy.TextPart{Text: "Hello"},
72 },
73 },
74 {
75 Role: fantasy.MessageRoleAssistant,
76 Content: []fantasy.MessagePart{
77 fantasy.ReasoningPart{
78 Text: "Let me think about this...",
79 ProviderOptions: fantasy.ProviderOptions{
80 Name: &ReasoningOptionMetadata{
81 Signature: "def456",
82 },
83 },
84 },
85 },
86 },
87 }
88
89 systemBlocks, messages, warnings := toPrompt(prompt, false)
90
91 require.Empty(t, systemBlocks)
92 require.Len(t, messages, 1, "should only have user message, assistant message should be dropped")
93 require.Len(t, warnings, 2)
94 require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
95 require.Contains(t, warnings[0].Message, "sending reasoning content is disabled")
96 require.Equal(t, fantasy.CallWarningTypeOther, warnings[1].Type)
97 require.Contains(t, warnings[1].Message, "dropping empty assistant message")
98 })
99
100 t.Run("should drop truly empty assistant messages", func(t *testing.T) {
101 t.Parallel()
102
103 prompt := fantasy.Prompt{
104 {
105 Role: fantasy.MessageRoleUser,
106 Content: []fantasy.MessagePart{
107 fantasy.TextPart{Text: "Hello"},
108 },
109 },
110 {
111 Role: fantasy.MessageRoleAssistant,
112 Content: []fantasy.MessagePart{},
113 },
114 }
115
116 systemBlocks, messages, warnings := toPrompt(prompt, true)
117
118 require.Empty(t, systemBlocks)
119 require.Len(t, messages, 1, "should only have user message")
120 require.Len(t, warnings, 1)
121 require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
122 require.Contains(t, warnings[0].Message, "dropping empty assistant message")
123 })
124
125 t.Run("should keep assistant messages with text content", func(t *testing.T) {
126 t.Parallel()
127
128 prompt := fantasy.Prompt{
129 {
130 Role: fantasy.MessageRoleUser,
131 Content: []fantasy.MessagePart{
132 fantasy.TextPart{Text: "Hello"},
133 },
134 },
135 {
136 Role: fantasy.MessageRoleAssistant,
137 Content: []fantasy.MessagePart{
138 fantasy.TextPart{Text: "Hi there!"},
139 },
140 },
141 }
142
143 systemBlocks, messages, warnings := toPrompt(prompt, true)
144
145 require.Empty(t, systemBlocks)
146 require.Len(t, messages, 2, "should have both user and assistant messages")
147 require.Empty(t, warnings)
148 })
149
150 t.Run("should keep assistant messages with tool calls", func(t *testing.T) {
151 t.Parallel()
152
153 prompt := fantasy.Prompt{
154 {
155 Role: fantasy.MessageRoleUser,
156 Content: []fantasy.MessagePart{
157 fantasy.TextPart{Text: "What's the weather?"},
158 },
159 },
160 {
161 Role: fantasy.MessageRoleAssistant,
162 Content: []fantasy.MessagePart{
163 fantasy.ToolCallPart{
164 ToolCallID: "call_123",
165 ToolName: "get_weather",
166 Input: `{"location":"NYC"}`,
167 },
168 },
169 },
170 }
171
172 systemBlocks, messages, warnings := toPrompt(prompt, true)
173
174 require.Empty(t, systemBlocks)
175 require.Len(t, messages, 2, "should have both user and assistant messages")
176 require.Empty(t, warnings)
177 })
178
179 t.Run("should drop assistant messages with invalid tool input", func(t *testing.T) {
180 t.Parallel()
181
182 prompt := fantasy.Prompt{
183 {
184 Role: fantasy.MessageRoleUser,
185 Content: []fantasy.MessagePart{
186 fantasy.TextPart{Text: "Hi"},
187 },
188 },
189 {
190 Role: fantasy.MessageRoleAssistant,
191 Content: []fantasy.MessagePart{
192 fantasy.ToolCallPart{
193 ToolCallID: "call_123",
194 ToolName: "get_weather",
195 Input: "{not-json",
196 },
197 },
198 },
199 }
200
201 systemBlocks, messages, warnings := toPrompt(prompt, true)
202
203 require.Empty(t, systemBlocks)
204 require.Len(t, messages, 1, "should only have user message")
205 require.Len(t, warnings, 1)
206 require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
207 require.Contains(t, warnings[0].Message, "dropping empty assistant message")
208 })
209
210 t.Run("should keep assistant messages with reasoning and text", func(t *testing.T) {
211 t.Parallel()
212
213 prompt := fantasy.Prompt{
214 {
215 Role: fantasy.MessageRoleUser,
216 Content: []fantasy.MessagePart{
217 fantasy.TextPart{Text: "Hello"},
218 },
219 },
220 {
221 Role: fantasy.MessageRoleAssistant,
222 Content: []fantasy.MessagePart{
223 fantasy.ReasoningPart{
224 Text: "Let me think...",
225 ProviderOptions: fantasy.ProviderOptions{
226 Name: &ReasoningOptionMetadata{
227 Signature: "abc123",
228 },
229 },
230 },
231 fantasy.TextPart{Text: "Hi there!"},
232 },
233 },
234 }
235
236 systemBlocks, messages, warnings := toPrompt(prompt, true)
237
238 require.Empty(t, systemBlocks)
239 require.Len(t, messages, 2, "should have both user and assistant messages")
240 require.Empty(t, warnings)
241 })
242
243 t.Run("should keep user messages with image content", func(t *testing.T) {
244 t.Parallel()
245
246 prompt := fantasy.Prompt{
247 {
248 Role: fantasy.MessageRoleUser,
249 Content: []fantasy.MessagePart{
250 fantasy.FilePart{
251 Data: []byte{0x01, 0x02, 0x03},
252 MediaType: "image/png",
253 },
254 },
255 },
256 }
257
258 systemBlocks, messages, warnings := toPrompt(prompt, true)
259
260 require.Empty(t, systemBlocks)
261 require.Len(t, messages, 1)
262 require.Empty(t, warnings)
263 })
264
265 t.Run("should drop user messages without visible content", func(t *testing.T) {
266 t.Parallel()
267
268 prompt := fantasy.Prompt{
269 {
270 Role: fantasy.MessageRoleUser,
271 Content: []fantasy.MessagePart{
272 fantasy.FilePart{
273 Data: []byte("not supported"),
274 MediaType: "application/pdf",
275 },
276 },
277 },
278 }
279
280 systemBlocks, messages, warnings := toPrompt(prompt, true)
281
282 require.Empty(t, systemBlocks)
283 require.Empty(t, messages)
284 require.Len(t, warnings, 1)
285 require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
286 require.Contains(t, warnings[0].Message, "dropping empty user message")
287 require.Contains(t, warnings[0].Message, "neither user-facing content nor tool results")
288 })
289
290 t.Run("should keep user messages with tool results", func(t *testing.T) {
291 t.Parallel()
292
293 prompt := fantasy.Prompt{
294 {
295 Role: fantasy.MessageRoleTool,
296 Content: []fantasy.MessagePart{
297 fantasy.ToolResultPart{
298 ToolCallID: "call_123",
299 Output: fantasy.ToolResultOutputContentText{Text: "done"},
300 },
301 },
302 },
303 }
304
305 systemBlocks, messages, warnings := toPrompt(prompt, true)
306
307 require.Empty(t, systemBlocks)
308 require.Len(t, messages, 1)
309 require.Empty(t, warnings)
310 })
311
312 t.Run("should keep user messages with tool error results", func(t *testing.T) {
313 t.Parallel()
314
315 prompt := fantasy.Prompt{
316 {
317 Role: fantasy.MessageRoleTool,
318 Content: []fantasy.MessagePart{
319 fantasy.ToolResultPart{
320 ToolCallID: "call_456",
321 Output: fantasy.ToolResultOutputContentError{Error: errors.New("boom")},
322 },
323 },
324 },
325 }
326
327 systemBlocks, messages, warnings := toPrompt(prompt, true)
328
329 require.Empty(t, systemBlocks)
330 require.Len(t, messages, 1)
331 require.Empty(t, warnings)
332 })
333
334 t.Run("should keep user messages with tool media results", func(t *testing.T) {
335 t.Parallel()
336
337 prompt := fantasy.Prompt{
338 {
339 Role: fantasy.MessageRoleTool,
340 Content: []fantasy.MessagePart{
341 fantasy.ToolResultPart{
342 ToolCallID: "call_789",
343 Output: fantasy.ToolResultOutputContentMedia{
344 Data: "AQID",
345 MediaType: "image/png",
346 },
347 },
348 },
349 },
350 }
351
352 systemBlocks, messages, warnings := toPrompt(prompt, true)
353
354 require.Empty(t, systemBlocks)
355 require.Len(t, messages, 1)
356 require.Empty(t, warnings)
357 })
358}
359
360func TestParseContextTooLargeError(t *testing.T) {
361 t.Parallel()
362
363 tests := []struct {
364 name string
365 message string
366 wantErr bool
367 wantUsed int
368 wantMax int
369 }{
370 {
371 name: "matches anthropic format",
372 message: "prompt is too long: 202630 tokens > 200000 maximum",
373 wantErr: true,
374 wantUsed: 202630,
375 wantMax: 200000,
376 },
377 {
378 name: "matches with different numbers",
379 message: "prompt is too long: 150000 tokens > 128000 maximum",
380 wantErr: true,
381 wantUsed: 150000,
382 wantMax: 128000,
383 },
384 {
385 name: "matches with extra whitespace",
386 message: "prompt is too long: 202630 tokens > 200000 maximum",
387 wantErr: true,
388 wantUsed: 202630,
389 wantMax: 200000,
390 },
391 {
392 name: "does not match unrelated error",
393 message: "invalid api key",
394 wantErr: false,
395 },
396 {
397 name: "does not match rate limit error",
398 message: "rate limit exceeded",
399 wantErr: false,
400 },
401 }
402
403 for _, tt := range tests {
404 t.Run(tt.name, func(t *testing.T) {
405 t.Parallel()
406 providerErr := &fantasy.ProviderError{Message: tt.message}
407 parseContextTooLargeError(tt.message, providerErr)
408
409 if tt.wantErr {
410 require.True(t, providerErr.IsContextTooLarge())
411 require.Equal(t, tt.wantUsed, providerErr.ContextUsedTokens)
412 require.Equal(t, tt.wantMax, providerErr.ContextMaxTokens)
413 } else {
414 require.False(t, providerErr.IsContextTooLarge())
415 }
416 })
417 }
418}
419
420func TestParseOptions_Effort(t *testing.T) {
421 t.Parallel()
422
423 options, err := ParseOptions(map[string]any{
424 "send_reasoning": true,
425 "thinking": map[string]any{"budget_tokens": int64(2048)},
426 "effort": "medium",
427 "disable_parallel_tool_use": true,
428 })
429 require.NoError(t, err)
430 require.NotNil(t, options.SendReasoning)
431 require.True(t, *options.SendReasoning)
432 require.NotNil(t, options.Thinking)
433 require.Equal(t, int64(2048), options.Thinking.BudgetTokens)
434 require.NotNil(t, options.Effort)
435 require.Equal(t, EffortMedium, *options.Effort)
436 require.NotNil(t, options.DisableParallelToolUse)
437 require.True(t, *options.DisableParallelToolUse)
438}
439
440func TestGenerate_SendsOutputConfigEffort(t *testing.T) {
441 t.Parallel()
442
443 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
444 defer server.Close()
445
446 provider, err := New(
447 WithAPIKey("test-api-key"),
448 WithBaseURL(server.URL),
449 )
450 require.NoError(t, err)
451
452 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
453 require.NoError(t, err)
454
455 effort := EffortMedium
456 _, err = model.Generate(context.Background(), fantasy.Call{
457 Prompt: testPrompt(),
458 ProviderOptions: NewProviderOptions(&ProviderOptions{
459 Effort: &effort,
460 }),
461 })
462 require.NoError(t, err)
463
464 call := awaitAnthropicCall(t, calls)
465 require.Equal(t, "POST", call.method)
466 require.Equal(t, "/v1/messages", call.path)
467 requireAnthropicEffort(t, call.body, EffortMedium)
468}
469
470func TestStream_SendsOutputConfigEffort(t *testing.T) {
471 t.Parallel()
472
473 server, calls := newAnthropicStreamingServer([]string{
474 "event: message_start\n",
475 "data: {\"type\":\"message_start\",\"message\":{}}\n\n",
476 "event: message_stop\n",
477 "data: {\"type\":\"message_stop\"}\n\n",
478 })
479 defer server.Close()
480
481 provider, err := New(
482 WithAPIKey("test-api-key"),
483 WithBaseURL(server.URL),
484 )
485 require.NoError(t, err)
486
487 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
488 require.NoError(t, err)
489
490 effort := EffortHigh
491 stream, err := model.Stream(context.Background(), fantasy.Call{
492 Prompt: testPrompt(),
493 ProviderOptions: NewProviderOptions(&ProviderOptions{
494 Effort: &effort,
495 }),
496 })
497 require.NoError(t, err)
498
499 stream(func(fantasy.StreamPart) bool { return true })
500
501 call := awaitAnthropicCall(t, calls)
502 require.Equal(t, "POST", call.method)
503 require.Equal(t, "/v1/messages", call.path)
504 requireAnthropicEffort(t, call.body, EffortHigh)
505}
506
// anthropicCall is a snapshot of one HTTP request received by a mock
// Anthropic server in these tests.
type anthropicCall struct {
	// method and path are the request's HTTP method and URL path.
	method, path string
	// body is the request payload decoded from JSON; it stays nil when
	// nothing could be decoded.
	body map[string]any
}
512
513func newAnthropicJSONServer(response map[string]any) (*httptest.Server, <-chan anthropicCall) {
514 calls := make(chan anthropicCall, 4)
515
516 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
517 var body map[string]any
518 if r.Body != nil {
519 _ = json.NewDecoder(r.Body).Decode(&body)
520 }
521
522 calls <- anthropicCall{
523 method: r.Method,
524 path: r.URL.Path,
525 body: body,
526 }
527
528 w.Header().Set("Content-Type", "application/json")
529 _ = json.NewEncoder(w).Encode(response)
530 }))
531
532 return server, calls
533}
534
535func newAnthropicStreamingServer(chunks []string) (*httptest.Server, <-chan anthropicCall) {
536 calls := make(chan anthropicCall, 4)
537
538 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
539 var body map[string]any
540 if r.Body != nil {
541 _ = json.NewDecoder(r.Body).Decode(&body)
542 }
543
544 calls <- anthropicCall{
545 method: r.Method,
546 path: r.URL.Path,
547 body: body,
548 }
549
550 w.Header().Set("Content-Type", "text/event-stream")
551 w.Header().Set("Cache-Control", "no-cache")
552 w.Header().Set("Connection", "keep-alive")
553 w.WriteHeader(http.StatusOK)
554
555 for _, chunk := range chunks {
556 _, _ = fmt.Fprint(w, chunk)
557 if flusher, ok := w.(http.Flusher); ok {
558 flusher.Flush()
559 }
560 }
561 }))
562
563 return server, calls
564}
565
566func awaitAnthropicCall(t *testing.T, calls <-chan anthropicCall) anthropicCall {
567 t.Helper()
568
569 select {
570 case call := <-calls:
571 return call
572 case <-time.After(2 * time.Second):
573 t.Fatal("timed out waiting for Anthropic request")
574 return anthropicCall{}
575 }
576}
577
578func assertNoAnthropicCall(t *testing.T, calls <-chan anthropicCall) {
579 t.Helper()
580
581 select {
582 case call := <-calls:
583 t.Fatalf("expected no Anthropic API call, but got %s %s", call.method, call.path)
584 case <-time.After(200 * time.Millisecond):
585 }
586}
587
588func requireAnthropicEffort(t *testing.T, body map[string]any, expected Effort) {
589 t.Helper()
590
591 outputConfig, ok := body["output_config"].(map[string]any)
592 thinking, ok := body["thinking"].(map[string]any)
593 require.True(t, ok)
594 require.Equal(t, string(expected), outputConfig["effort"])
595 require.Equal(t, "adaptive", thinking["type"])
596}
597
598func testPrompt() fantasy.Prompt {
599 return fantasy.Prompt{
600 {
601 Role: fantasy.MessageRoleUser,
602 Content: []fantasy.MessagePart{
603 fantasy.TextPart{Text: "Hello"},
604 },
605 },
606 }
607}
608
// mockAnthropicGenerateResponse returns a canned message payload for
// mock servers: one assistant text block reading "Hi there", an
// "end_turn" stop reason, and a usage section reporting 5 input and
// 2 output tokens.
func mockAnthropicGenerateResponse() map[string]any {
	textBlock := map[string]any{
		"type": "text",
		"text": "Hi there",
	}

	cacheCreation := map[string]any{
		"ephemeral_1h_input_tokens": 0,
		"ephemeral_5m_input_tokens": 0,
	}

	serverToolUse := map[string]any{
		"web_search_requests": 0,
	}

	usage := map[string]any{
		"cache_creation":              cacheCreation,
		"cache_creation_input_tokens": 0,
		"cache_read_input_tokens":     0,
		"input_tokens":                5,
		"output_tokens":               2,
		"server_tool_use":             serverToolUse,
		"service_tier":                "standard",
	}

	return map[string]any{
		"id":            "msg_01Test",
		"type":          "message",
		"role":          "assistant",
		"model":         "claude-sonnet-4-20250514",
		"content":       []any{textBlock},
		"stop_reason":   "end_turn",
		"stop_sequence": "",
		"usage":         usage,
	}
}
639
// mockAnthropicWebSearchResponse returns a canned message payload for
// mock servers containing three content blocks, in order: a
// server_tool_use block for web_search, a web_search_tool_result block
// with two results, and a final text block.
func mockAnthropicWebSearchResponse() map[string]any {
	directCaller := func() map[string]any {
		return map[string]any{"type": "direct"}
	}

	toolUse := map[string]any{
		"type":   "server_tool_use",
		"id":     "srvtoolu_01",
		"name":   "web_search",
		"input":  map[string]any{"query": "latest AI news"},
		"caller": directCaller(),
	}

	firstHit := map[string]any{
		"type":              "web_search_result",
		"url":               "https://example.com/ai-news",
		"title":             "Latest AI News",
		"encrypted_content": "encrypted_abc123",
		"page_age":          "2 hours ago",
	}

	// The second result carries an empty page_age.
	secondHit := map[string]any{
		"type":              "web_search_result",
		"url":               "https://example.com/ml-update",
		"title":             "ML Update",
		"encrypted_content": "encrypted_def456",
		"page_age":          "",
	}

	toolResult := map[string]any{
		"type":        "web_search_tool_result",
		"tool_use_id": "srvtoolu_01",
		"caller":      directCaller(),
		"content":     []any{firstHit, secondHit},
	}

	answer := map[string]any{
		"type": "text",
		"text": "Based on recent search results, here is the latest AI news.",
	}

	usage := map[string]any{
		"input_tokens":                100,
		"output_tokens":               50,
		"cache_creation_input_tokens": 0,
		"cache_read_input_tokens":     0,
		"server_tool_use": map[string]any{
			"web_search_requests": 1,
		},
	}

	return map[string]any{
		"id":            "msg_01WebSearch",
		"type":          "message",
		"role":          "assistant",
		"model":         "claude-sonnet-4-20250514",
		"content":       []any{toolUse, toolResult, answer},
		"stop_reason":   "end_turn",
		"stop_sequence": nil,
		"usage":         usage,
	}
}
693
694func TestToPrompt_WebSearchProviderExecutedToolResults(t *testing.T) {
695 t.Parallel()
696
697 prompt := fantasy.Prompt{
698 // User message.
699 {
700 Role: fantasy.MessageRoleUser,
701 Content: []fantasy.MessagePart{
702 fantasy.TextPart{Text: "Search for the latest AI news"},
703 },
704 },
705 // Assistant message with a provider-executed tool call, its
706 // result, and trailing text. toResponseMessages routes
707 // provider-executed results into the assistant message, so
708 // the prompt already reflects that structure.
709 {
710 Role: fantasy.MessageRoleAssistant,
711 Content: []fantasy.MessagePart{
712 fantasy.ToolCallPart{
713 ToolCallID: "srvtoolu_01",
714 ToolName: "web_search",
715 Input: `{"query":"latest AI news"}`,
716 ProviderExecuted: true,
717 },
718 fantasy.ToolResultPart{
719 ToolCallID: "srvtoolu_01",
720 ProviderExecuted: true,
721 ProviderOptions: fantasy.ProviderOptions{
722 Name: &WebSearchResultMetadata{
723 Results: []WebSearchResultItem{
724 {
725 URL: "https://example.com/ai-news",
726 Title: "Latest AI News",
727 EncryptedContent: "encrypted_abc123",
728 PageAge: "2 hours ago",
729 },
730 {
731 URL: "https://example.com/ml-update",
732 Title: "ML Update",
733 EncryptedContent: "encrypted_def456",
734 },
735 },
736 },
737 },
738 },
739 fantasy.TextPart{Text: "Here is what I found."},
740 },
741 },
742 }
743
744 _, messages, warnings := toPrompt(prompt, true)
745
746 // No warnings expected; the provider-executed result is in the
747 // assistant message so there is no empty tool message to drop.
748 require.Empty(t, warnings)
749
750 // We should have a user message and an assistant message.
751 require.Len(t, messages, 2, "expected user + assistant messages")
752
753 assistantMsg := messages[1]
754 require.Len(t, assistantMsg.Content, 3,
755 "expected server_tool_use + web_search_tool_result + text")
756
757 // First content block: reconstructed server_tool_use.
758 serverToolUse := assistantMsg.Content[0]
759 require.NotNil(t, serverToolUse.OfServerToolUse,
760 "first block should be a server_tool_use")
761 require.Equal(t, "srvtoolu_01", serverToolUse.OfServerToolUse.ID)
762 require.Equal(t, anthropic.ServerToolUseBlockParamName("web_search"),
763 serverToolUse.OfServerToolUse.Name)
764
765 // Second content block: reconstructed web_search_tool_result with
766 // encrypted_content preserved for multi-turn round-tripping.
767 webResult := assistantMsg.Content[1]
768 require.NotNil(t, webResult.OfWebSearchToolResult,
769 "second block should be a web_search_tool_result")
770 require.Equal(t, "srvtoolu_01", webResult.OfWebSearchToolResult.ToolUseID)
771
772 results := webResult.OfWebSearchToolResult.Content.OfWebSearchToolResultBlockItem
773 require.Len(t, results, 2)
774 require.Equal(t, "https://example.com/ai-news", results[0].URL)
775 require.Equal(t, "Latest AI News", results[0].Title)
776 require.Equal(t, "encrypted_abc123", results[0].EncryptedContent)
777 require.Equal(t, "https://example.com/ml-update", results[1].URL)
778 require.Equal(t, "encrypted_def456", results[1].EncryptedContent)
779 // PageAge should be set for the first result and absent for the second.
780 require.True(t, results[0].PageAge.Valid())
781 require.Equal(t, "2 hours ago", results[0].PageAge.Value)
782 require.False(t, results[1].PageAge.Valid())
783
784 // Third content block: plain text.
785 require.NotNil(t, assistantMsg.Content[2].OfText)
786 require.Equal(t, "Here is what I found.", assistantMsg.Content[2].OfText.Text)
787}
788
789func TestGenerate_WebSearchResponse(t *testing.T) {
790 t.Parallel()
791
792 server, calls := newAnthropicJSONServer(mockAnthropicWebSearchResponse())
793 defer server.Close()
794
795 provider, err := New(
796 WithAPIKey("test-api-key"),
797 WithBaseURL(server.URL),
798 )
799 require.NoError(t, err)
800
801 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
802 require.NoError(t, err)
803
804 resp, err := model.Generate(context.Background(), fantasy.Call{
805 Prompt: testPrompt(),
806 Tools: []fantasy.Tool{
807 WebSearchTool(nil),
808 },
809 })
810 require.NoError(t, err)
811
812 call := awaitAnthropicCall(t, calls)
813 require.Equal(t, "POST", call.method)
814 require.Equal(t, "/v1/messages", call.path)
815
816 // Walk the response content and categorise each item.
817 var (
818 toolCalls []fantasy.ToolCallContent
819 sources []fantasy.SourceContent
820 toolResults []fantasy.ToolResultContent
821 texts []fantasy.TextContent
822 )
823 for _, c := range resp.Content {
824 switch v := c.(type) {
825 case fantasy.ToolCallContent:
826 toolCalls = append(toolCalls, v)
827 case fantasy.SourceContent:
828 sources = append(sources, v)
829 case fantasy.ToolResultContent:
830 toolResults = append(toolResults, v)
831 case fantasy.TextContent:
832 texts = append(texts, v)
833 }
834 }
835
836 // ToolCallContent for the provider-executed web_search.
837 require.Len(t, toolCalls, 1)
838 require.True(t, toolCalls[0].ProviderExecuted)
839 require.Equal(t, "web_search", toolCalls[0].ToolName)
840 require.Equal(t, "srvtoolu_01", toolCalls[0].ToolCallID)
841
842 // SourceContent entries for each search result.
843 require.Len(t, sources, 2)
844 require.Equal(t, "https://example.com/ai-news", sources[0].URL)
845 require.Equal(t, "Latest AI News", sources[0].Title)
846 require.Equal(t, fantasy.SourceTypeURL, sources[0].SourceType)
847 require.Equal(t, "https://example.com/ml-update", sources[1].URL)
848 require.Equal(t, "ML Update", sources[1].Title)
849
850 // ToolResultContent with provider metadata preserving encrypted_content.
851 require.Len(t, toolResults, 1)
852 require.True(t, toolResults[0].ProviderExecuted)
853 require.Equal(t, "web_search", toolResults[0].ToolName)
854 require.Equal(t, "srvtoolu_01", toolResults[0].ToolCallID)
855
856 searchMeta, ok := toolResults[0].ProviderMetadata[Name]
857 require.True(t, ok, "providerMetadata should contain anthropic key")
858 webMeta, ok := searchMeta.(*WebSearchResultMetadata)
859 require.True(t, ok, "metadata should be *WebSearchResultMetadata")
860 require.Len(t, webMeta.Results, 2)
861 require.Equal(t, "encrypted_abc123", webMeta.Results[0].EncryptedContent)
862 require.Equal(t, "encrypted_def456", webMeta.Results[1].EncryptedContent)
863 require.Equal(t, "2 hours ago", webMeta.Results[0].PageAge)
864
865 // TextContent with the final answer.
866 require.Len(t, texts, 1)
867 require.Equal(t,
868 "Based on recent search results, here is the latest AI news.",
869 texts[0].Text,
870 )
871}
872
873func TestGenerate_WebSearchToolInRequest(t *testing.T) {
874 t.Parallel()
875
876 t.Run("basic web_search tool", func(t *testing.T) {
877 t.Parallel()
878
879 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
880 defer server.Close()
881
882 provider, err := New(
883 WithAPIKey("test-api-key"),
884 WithBaseURL(server.URL),
885 )
886 require.NoError(t, err)
887
888 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
889 require.NoError(t, err)
890
891 _, err = model.Generate(context.Background(), fantasy.Call{
892 Prompt: testPrompt(),
893 Tools: []fantasy.Tool{
894 WebSearchTool(nil),
895 },
896 })
897 require.NoError(t, err)
898
899 call := awaitAnthropicCall(t, calls)
900 tools, ok := call.body["tools"].([]any)
901 require.True(t, ok, "request body should have tools array")
902 require.Len(t, tools, 1)
903
904 tool, ok := tools[0].(map[string]any)
905 require.True(t, ok)
906 require.Equal(t, "web_search_20250305", tool["type"])
907 })
908
909 t.Run("with allowed_domains and blocked_domains", func(t *testing.T) {
910 t.Parallel()
911
912 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
913 defer server.Close()
914
915 provider, err := New(
916 WithAPIKey("test-api-key"),
917 WithBaseURL(server.URL),
918 )
919 require.NoError(t, err)
920
921 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
922 require.NoError(t, err)
923
924 _, err = model.Generate(context.Background(), fantasy.Call{
925 Prompt: testPrompt(),
926 Tools: []fantasy.Tool{
927 WebSearchTool(&WebSearchToolOptions{
928 AllowedDomains: []string{"example.com", "test.com"},
929 }),
930 },
931 })
932 require.NoError(t, err)
933
934 call := awaitAnthropicCall(t, calls)
935 tools, ok := call.body["tools"].([]any)
936 require.True(t, ok)
937 require.Len(t, tools, 1)
938
939 tool, ok := tools[0].(map[string]any)
940 require.True(t, ok)
941 require.Equal(t, "web_search_20250305", tool["type"])
942
943 domains, ok := tool["allowed_domains"].([]any)
944 require.True(t, ok, "tool should have allowed_domains")
945 require.Len(t, domains, 2)
946 require.Equal(t, "example.com", domains[0])
947 require.Equal(t, "test.com", domains[1])
948 })
949
950 t.Run("with max uses and user location", func(t *testing.T) {
951 t.Parallel()
952
953 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
954 defer server.Close()
955
956 provider, err := New(
957 WithAPIKey("test-api-key"),
958 WithBaseURL(server.URL),
959 )
960 require.NoError(t, err)
961
962 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
963 require.NoError(t, err)
964
965 _, err = model.Generate(context.Background(), fantasy.Call{
966 Prompt: testPrompt(),
967 Tools: []fantasy.Tool{
968 WebSearchTool(&WebSearchToolOptions{
969 MaxUses: 5,
970 UserLocation: &UserLocation{
971 City: "San Francisco",
972 Country: "US",
973 },
974 }),
975 },
976 })
977 require.NoError(t, err)
978
979 call := awaitAnthropicCall(t, calls)
980 tools, ok := call.body["tools"].([]any)
981 require.True(t, ok)
982 require.Len(t, tools, 1)
983
984 tool, ok := tools[0].(map[string]any)
985 require.True(t, ok)
986 require.Equal(t, "web_search_20250305", tool["type"])
987
988 // max_uses is serialized as a JSON number; json.Unmarshal
989 // into map[string]any decodes numbers as float64.
990 maxUses, ok := tool["max_uses"].(float64)
991 require.True(t, ok, "tool should have max_uses")
992 require.Equal(t, float64(5), maxUses)
993
994 userLoc, ok := tool["user_location"].(map[string]any)
995 require.True(t, ok, "tool should have user_location")
996 require.Equal(t, "San Francisco", userLoc["city"])
997 require.Equal(t, "US", userLoc["country"])
998 require.Equal(t, "approximate", userLoc["type"])
999 })
1000
1001 t.Run("with max uses", func(t *testing.T) {
1002 t.Parallel()
1003
1004 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1005 defer server.Close()
1006
1007 provider, err := New(
1008 WithAPIKey("test-api-key"),
1009 WithBaseURL(server.URL),
1010 )
1011 require.NoError(t, err)
1012
1013 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1014 require.NoError(t, err)
1015
1016 _, err = model.Generate(context.Background(), fantasy.Call{
1017 Prompt: testPrompt(),
1018 Tools: []fantasy.Tool{
1019 WebSearchTool(&WebSearchToolOptions{
1020 MaxUses: 3,
1021 }),
1022 },
1023 })
1024 require.NoError(t, err)
1025
1026 call := awaitAnthropicCall(t, calls)
1027 tools, ok := call.body["tools"].([]any)
1028 require.True(t, ok)
1029 require.Len(t, tools, 1)
1030
1031 tool, ok := tools[0].(map[string]any)
1032 require.True(t, ok)
1033 require.Equal(t, "web_search_20250305", tool["type"])
1034
1035 maxUses, ok := tool["max_uses"].(float64)
1036 require.True(t, ok, "tool should have max_uses")
1037 require.Equal(t, float64(3), maxUses)
1038 })
1039
1040 t.Run("with json-round-tripped provider tool args", func(t *testing.T) {
1041 t.Parallel()
1042
1043 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1044 defer server.Close()
1045
1046 provider, err := New(
1047 WithAPIKey("test-api-key"),
1048 WithBaseURL(server.URL),
1049 )
1050 require.NoError(t, err)
1051
1052 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1053 require.NoError(t, err)
1054
1055 baseTool := WebSearchTool(&WebSearchToolOptions{
1056 MaxUses: 7,
1057 BlockedDomains: []string{"example.com", "test.com"},
1058 UserLocation: &UserLocation{
1059 City: "San Francisco",
1060 Region: "CA",
1061 Country: "US",
1062 Timezone: "America/Los_Angeles",
1063 },
1064 })
1065
1066 data, err := json.Marshal(baseTool)
1067 require.NoError(t, err)
1068
1069 var roundTripped fantasy.ProviderDefinedTool
1070 err = json.Unmarshal(data, &roundTripped)
1071 require.NoError(t, err)
1072
1073 _, err = model.Generate(context.Background(), fantasy.Call{
1074 Prompt: testPrompt(),
1075 Tools: []fantasy.Tool{roundTripped},
1076 })
1077 require.NoError(t, err)
1078
1079 call := awaitAnthropicCall(t, calls)
1080 tools, ok := call.body["tools"].([]any)
1081 require.True(t, ok)
1082 require.Len(t, tools, 1)
1083
1084 tool, ok := tools[0].(map[string]any)
1085 require.True(t, ok)
1086 require.Equal(t, "web_search_20250305", tool["type"])
1087
1088 domains, ok := tool["blocked_domains"].([]any)
1089 require.True(t, ok, "tool should have blocked_domains")
1090 require.Len(t, domains, 2)
1091 require.Equal(t, "example.com", domains[0])
1092 require.Equal(t, "test.com", domains[1])
1093
1094 maxUses, ok := tool["max_uses"].(float64)
1095 require.True(t, ok, "tool should have max_uses")
1096 require.Equal(t, float64(7), maxUses)
1097
1098 userLoc, ok := tool["user_location"].(map[string]any)
1099 require.True(t, ok, "tool should have user_location")
1100 require.Equal(t, "San Francisco", userLoc["city"])
1101 require.Equal(t, "CA", userLoc["region"])
1102 require.Equal(t, "US", userLoc["country"])
1103 require.Equal(t, "America/Los_Angeles", userLoc["timezone"])
1104 require.Equal(t, "approximate", userLoc["type"])
1105 })
1106}
1107
1108func TestAnyToStringSlice(t *testing.T) {
1109 t.Parallel()
1110
1111 t.Run("from string slice", func(t *testing.T) {
1112 t.Parallel()
1113
1114 got := anyToStringSlice([]string{"example.com", ""})
1115 require.Equal(t, []string{"example.com", ""}, got)
1116 })
1117
1118 t.Run("from any slice filters non-strings and empty", func(t *testing.T) {
1119 t.Parallel()
1120
1121 got := anyToStringSlice([]any{"example.com", 123, "", "test.com"})
1122 require.Equal(t, []string{"example.com", "test.com"}, got)
1123 })
1124
1125 t.Run("unsupported type", func(t *testing.T) {
1126 t.Parallel()
1127
1128 got := anyToStringSlice("example.com")
1129 require.Nil(t, got)
1130 })
1131}
1132
1133func TestAnyToInt64(t *testing.T) {
1134 t.Parallel()
1135
1136 tests := []struct {
1137 name string
1138 input any
1139 want int64
1140 wantOK bool
1141 }{
1142 {name: "int64", input: int64(7), want: 7, wantOK: true},
1143 {name: "float64 integer", input: float64(7), want: 7, wantOK: true},
1144 {name: "float32 integer", input: float32(9), want: 9, wantOK: true},
1145 {name: "float64 non-integer", input: float64(7.5), wantOK: false},
1146 {name: "float64 max exact int ok", input: float64(1<<53 - 1), want: 1<<53 - 1, wantOK: true},
1147 {name: "float64 over max exact int", input: float64(1 << 53), wantOK: false},
1148 {name: "json number int", input: json.Number("42"), want: 42, wantOK: true},
1149 {name: "json number float", input: json.Number("4.2"), wantOK: false},
1150 {name: "nan", input: math.NaN(), wantOK: false},
1151 {name: "inf", input: math.Inf(1), wantOK: false},
1152 {name: "uint64 overflow", input: uint64(math.MaxInt64) + 1, wantOK: false},
1153 }
1154
1155 for _, tt := range tests {
1156 t.Run(tt.name, func(t *testing.T) {
1157 got, ok := anyToInt64(tt.input)
1158 require.Equal(t, tt.wantOK, ok)
1159 if tt.wantOK {
1160 require.Equal(t, tt.want, got)
1161 }
1162 })
1163 }
1164}
1165
1166func TestAnyToUserLocation(t *testing.T) {
1167 t.Parallel()
1168
1169 t.Run("pointer passthrough", func(t *testing.T) {
1170 t.Parallel()
1171
1172 input := &UserLocation{City: "San Francisco", Country: "US"}
1173 got := anyToUserLocation(input)
1174 require.Same(t, input, got)
1175 })
1176
1177 t.Run("struct value", func(t *testing.T) {
1178 t.Parallel()
1179
1180 got := anyToUserLocation(UserLocation{City: "San Francisco", Country: "US"})
1181 require.NotNil(t, got)
1182 require.Equal(t, "San Francisco", got.City)
1183 require.Equal(t, "US", got.Country)
1184 })
1185
1186 t.Run("map value", func(t *testing.T) {
1187 t.Parallel()
1188
1189 got := anyToUserLocation(map[string]any{
1190 "city": "San Francisco",
1191 "region": "CA",
1192 "country": "US",
1193 "timezone": "America/Los_Angeles",
1194 "type": "approximate",
1195 })
1196 require.NotNil(t, got)
1197 require.Equal(t, "San Francisco", got.City)
1198 require.Equal(t, "CA", got.Region)
1199 require.Equal(t, "US", got.Country)
1200 require.Equal(t, "America/Los_Angeles", got.Timezone)
1201 })
1202
1203 t.Run("empty map", func(t *testing.T) {
1204 t.Parallel()
1205
1206 got := anyToUserLocation(map[string]any{"type": "approximate"})
1207 require.Nil(t, got)
1208 })
1209
1210 t.Run("unsupported type", func(t *testing.T) {
1211 t.Parallel()
1212
1213 got := anyToUserLocation("San Francisco")
1214 require.Nil(t, got)
1215 })
1216}
1217
1218func TestStream_WebSearchResponse(t *testing.T) {
1219 t.Parallel()
1220
1221 // Build SSE chunks that simulate a web search streaming response.
1222 // The Anthropic SDK accumulates content blocks via
1223 // acc.Accumulate(event). We read the Content and ToolUseID
1224 // directly from struct fields instead of using AsAny(), which
1225 // avoids the SDK's re-marshal limitation that previously dropped
1226 // source data.
1227 webSearchResultContent, _ := json.Marshal([]any{
1228 map[string]any{
1229 "type": "web_search_result",
1230 "url": "https://example.com/ai-news",
1231 "title": "Latest AI News",
1232 "encrypted_content": "encrypted_abc123",
1233 "page_age": "2 hours ago",
1234 },
1235 })
1236
1237 chunks := []string{
1238 // message_start
1239 "event: message_start\n",
1240 `data: {"type":"message_start","message":{"id":"msg_01WebSearch","type":"message","role":"assistant","model":"claude-sonnet-4-20250514","content":[],"stop_reason":null,"usage":{"input_tokens":100,"output_tokens":0}}}` + "\n\n",
1241 // Block 0: server_tool_use
1242 "event: content_block_start\n",
1243 `data: {"type":"content_block_start","index":0,"content_block":{"type":"server_tool_use","id":"srvtoolu_01","name":"web_search","input":{}}}` + "\n\n",
1244 "event: content_block_stop\n",
1245 `data: {"type":"content_block_stop","index":0}` + "\n\n",
1246 // Block 1: web_search_tool_result
1247 "event: content_block_start\n",
1248 `data: {"type":"content_block_start","index":1,"content_block":{"type":"web_search_tool_result","tool_use_id":"srvtoolu_01","content":` + string(webSearchResultContent) + `}}` + "\n\n",
1249 "event: content_block_stop\n",
1250 `data: {"type":"content_block_stop","index":1}` + "\n\n",
1251 // Block 2: text
1252 "event: content_block_start\n",
1253 `data: {"type":"content_block_start","index":2,"content_block":{"type":"text","text":""}}` + "\n\n",
1254 "event: content_block_delta\n",
1255 `data: {"type":"content_block_delta","index":2,"delta":{"type":"text_delta","text":"Here are the results."}}` + "\n\n",
1256 "event: content_block_stop\n",
1257 `data: {"type":"content_block_stop","index":2}` + "\n\n",
1258 // message_stop
1259 "event: message_stop\n",
1260 `data: {"type":"message_stop"}` + "\n\n",
1261 }
1262
1263 server, calls := newAnthropicStreamingServer(chunks)
1264 defer server.Close()
1265
1266 provider, err := New(
1267 WithAPIKey("test-api-key"),
1268 WithBaseURL(server.URL),
1269 )
1270 require.NoError(t, err)
1271
1272 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1273 require.NoError(t, err)
1274
1275 stream, err := model.Stream(context.Background(), fantasy.Call{
1276 Prompt: testPrompt(),
1277 Tools: []fantasy.Tool{
1278 WebSearchTool(nil),
1279 },
1280 })
1281 require.NoError(t, err)
1282
1283 var parts []fantasy.StreamPart
1284 stream(func(part fantasy.StreamPart) bool {
1285 parts = append(parts, part)
1286 return true
1287 })
1288
1289 _ = awaitAnthropicCall(t, calls)
1290
1291 // Collect parts by type for assertions.
1292 var (
1293 toolInputStarts []fantasy.StreamPart
1294 toolCalls []fantasy.StreamPart
1295 toolResults []fantasy.StreamPart
1296 sourceParts []fantasy.StreamPart
1297 textDeltas []fantasy.StreamPart
1298 )
1299 for _, p := range parts {
1300 switch p.Type {
1301 case fantasy.StreamPartTypeToolInputStart:
1302 toolInputStarts = append(toolInputStarts, p)
1303 case fantasy.StreamPartTypeToolCall:
1304 toolCalls = append(toolCalls, p)
1305 case fantasy.StreamPartTypeToolResult:
1306 toolResults = append(toolResults, p)
1307 case fantasy.StreamPartTypeSource:
1308 sourceParts = append(sourceParts, p)
1309 case fantasy.StreamPartTypeTextDelta:
1310 textDeltas = append(textDeltas, p)
1311 }
1312 }
1313
1314 // server_tool_use emits a ToolInputStart with ProviderExecuted.
1315 require.NotEmpty(t, toolInputStarts, "should have a tool input start")
1316 require.True(t, toolInputStarts[0].ProviderExecuted)
1317 require.Equal(t, "web_search", toolInputStarts[0].ToolCallName)
1318
1319 // server_tool_use emits a ToolCall with ProviderExecuted.
1320 require.NotEmpty(t, toolCalls, "should have a tool call")
1321 require.True(t, toolCalls[0].ProviderExecuted)
1322
1323 // web_search_tool_result always emits a ToolResult even when
1324 // the SDK drops source data. The ToolUseID comes from the raw
1325 // union field as a fallback.
1326 require.NotEmpty(t, toolResults, "should have a tool result")
1327 require.True(t, toolResults[0].ProviderExecuted)
1328 require.Equal(t, "web_search", toolResults[0].ToolCallName)
1329 require.Equal(t, "srvtoolu_01", toolResults[0].ID,
1330 "tool result ID should match the tool_use_id")
1331
1332 // Source parts are now correctly emitted by reading struct fields
1333 // directly instead of using AsAny().
1334 require.Len(t, sourceParts, 1)
1335 require.Equal(t, "https://example.com/ai-news", sourceParts[0].URL)
1336 require.Equal(t, "Latest AI News", sourceParts[0].Title)
1337 require.Equal(t, fantasy.SourceTypeURL, sourceParts[0].SourceType)
1338
1339 // Text block emits a text delta.
1340 require.NotEmpty(t, textDeltas, "should have text deltas")
1341 require.Equal(t, "Here are the results.", textDeltas[0].Delta)
1342}
1343
1344func TestGenerate_ToolChoiceNone(t *testing.T) {
1345 t.Parallel()
1346
1347 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1348 defer server.Close()
1349
1350 provider, err := New(
1351 WithAPIKey("test-api-key"),
1352 WithBaseURL(server.URL),
1353 )
1354 require.NoError(t, err)
1355
1356 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1357 require.NoError(t, err)
1358
1359 toolChoiceNone := fantasy.ToolChoiceNone
1360 _, err = model.Generate(context.Background(), fantasy.Call{
1361 Prompt: testPrompt(),
1362 Tools: []fantasy.Tool{
1363 WebSearchTool(nil),
1364 },
1365 ToolChoice: &toolChoiceNone,
1366 })
1367 require.NoError(t, err)
1368
1369 call := awaitAnthropicCall(t, calls)
1370 toolChoice, ok := call.body["tool_choice"].(map[string]any)
1371 require.True(t, ok, "request body should have tool_choice")
1372 require.Equal(t, "none", toolChoice["type"], "tool_choice should be 'none'")
1373}
1374
1375// --- Computer Use Tests ---
1376
1377// jsonRoundTripTool simulates a JSON round-trip on a
1378// ProviderDefinedTool so that its Args map contains float64
1379// values (as json.Unmarshal produces) rather than the int64
1380// values that NewComputerUseTool stores directly. The
1381// production toBetaTools code asserts float64.
1382func jsonRoundTripTool(t *testing.T, tool fantasy.ExecutableProviderTool) fantasy.ProviderDefinedTool {
1383 t.Helper()
1384 pdt := tool.Definition()
1385 data, err := json.Marshal(pdt.Args)
1386 require.NoError(t, err)
1387 var args map[string]any
1388 require.NoError(t, json.Unmarshal(data, &args))
1389 pdt.Args = args
1390 return pdt
1391}
1392
1393func TestNewComputerUseTool(t *testing.T) {
1394 t.Parallel()
1395
1396 t.Run("creates tool with correct ID and name", func(t *testing.T) {
1397 t.Parallel()
1398 tool := NewComputerUseTool(ComputerUseToolOptions{
1399 DisplayWidthPx: 1920,
1400 DisplayHeightPx: 1080,
1401 ToolVersion: ComputerUse20250124,
1402 }, noopComputerRun).Definition()
1403 require.Equal(t, "anthropic.computer", tool.ID)
1404 require.Equal(t, "computer", tool.Name)
1405 require.Equal(t, int64(1920), tool.Args["display_width_px"])
1406 require.Equal(t, int64(1080), tool.Args["display_height_px"])
1407 require.Equal(t, string(ComputerUse20250124), tool.Args["tool_version"])
1408 })
1409
1410 t.Run("includes optional fields when set", func(t *testing.T) {
1411 t.Parallel()
1412 displayNum := int64(1)
1413 enableZoom := true
1414 tool := NewComputerUseTool(ComputerUseToolOptions{
1415 DisplayWidthPx: 1024,
1416 DisplayHeightPx: 768,
1417 DisplayNumber: &displayNum,
1418 EnableZoom: &enableZoom,
1419 ToolVersion: ComputerUse20251124,
1420 CacheControl: &CacheControl{Type: "ephemeral"},
1421 }, noopComputerRun).Definition()
1422 require.Equal(t, int64(1), tool.Args["display_number"])
1423 require.Equal(t, true, tool.Args["enable_zoom"])
1424 require.NotNil(t, tool.Args["cache_control"])
1425 })
1426
1427 t.Run("omits optional fields when nil", func(t *testing.T) {
1428 t.Parallel()
1429 tool := NewComputerUseTool(ComputerUseToolOptions{
1430 DisplayWidthPx: 1920,
1431 DisplayHeightPx: 1080,
1432 ToolVersion: ComputerUse20250124,
1433 }, noopComputerRun).Definition()
1434 _, hasDisplayNum := tool.Args["display_number"]
1435 _, hasEnableZoom := tool.Args["enable_zoom"]
1436 _, hasCacheControl := tool.Args["cache_control"]
1437 require.False(t, hasDisplayNum)
1438 require.False(t, hasEnableZoom)
1439 require.False(t, hasCacheControl)
1440 })
1441}
1442
1443func TestIsComputerUseTool(t *testing.T) {
1444 t.Parallel()
1445
1446 t.Run("returns true for computer use tool", func(t *testing.T) {
1447 t.Parallel()
1448 tool := NewComputerUseTool(ComputerUseToolOptions{
1449 DisplayWidthPx: 1920,
1450 DisplayHeightPx: 1080,
1451 ToolVersion: ComputerUse20250124,
1452 }, noopComputerRun)
1453 require.True(t, IsComputerUseTool(tool.Definition()))
1454 })
1455
1456 t.Run("returns false for function tool", func(t *testing.T) {
1457 t.Parallel()
1458 tool := fantasy.FunctionTool{
1459 Name: "test",
1460 Description: "test tool",
1461 }
1462 require.False(t, IsComputerUseTool(tool))
1463 })
1464
1465 t.Run("returns false for other provider defined tool", func(t *testing.T) {
1466 t.Parallel()
1467 tool := fantasy.ProviderDefinedTool{
1468 ID: "other.tool",
1469 Name: "other",
1470 }
1471 require.False(t, IsComputerUseTool(tool))
1472 })
1473}
1474
1475func TestNeedsBetaAPI(t *testing.T) {
1476 t.Parallel()
1477
1478 lm := languageModel{options: options{}}
1479
1480 t.Run("returns false for empty tools", func(t *testing.T) {
1481 t.Parallel()
1482 _, _, _, betaFlags := lm.toTools(nil, nil, false)
1483 require.Empty(t, betaFlags)
1484 _, _, _, betaFlags = lm.toTools([]fantasy.Tool{}, nil, false)
1485 require.Empty(t, betaFlags)
1486 })
1487
1488 t.Run("returns false for only function tools", func(t *testing.T) {
1489 t.Parallel()
1490 tools := []fantasy.Tool{
1491 fantasy.FunctionTool{Name: "test"},
1492 }
1493 _, _, _, betaFlags := lm.toTools(tools, nil, false)
1494 require.Empty(t, betaFlags)
1495 })
1496
1497 t.Run("returns beta flags when computer use tool present", func(t *testing.T) {
1498 t.Parallel()
1499 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1500 DisplayWidthPx: 1920,
1501 DisplayHeightPx: 1080,
1502 ToolVersion: ComputerUse20250124,
1503 }, noopComputerRun))
1504 tools := []fantasy.Tool{
1505 fantasy.FunctionTool{Name: "test"},
1506 cuTool,
1507 }
1508 _, _, _, betaFlags := lm.toTools(tools, nil, false)
1509 require.NotEmpty(t, betaFlags)
1510 })
1511}
1512
1513func TestComputerUseToolJSON(t *testing.T) {
1514 t.Parallel()
1515
1516 t.Run("builds JSON for version 20250124", func(t *testing.T) {
1517 t.Parallel()
1518 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1519 DisplayWidthPx: 1920,
1520 DisplayHeightPx: 1080,
1521 ToolVersion: ComputerUse20250124,
1522 }, noopComputerRun))
1523 data, err := computerUseToolJSON(cuTool)
1524 require.NoError(t, err)
1525 var m map[string]any
1526 require.NoError(t, json.Unmarshal(data, &m))
1527 require.Equal(t, "computer_20250124", m["type"])
1528 require.Equal(t, "computer", m["name"])
1529 require.InDelta(t, 1920, m["display_width_px"], 0)
1530 require.InDelta(t, 1080, m["display_height_px"], 0)
1531 })
1532
1533 t.Run("builds JSON for version 20251124 with enable_zoom", func(t *testing.T) {
1534 t.Parallel()
1535 enableZoom := true
1536 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1537 DisplayWidthPx: 1024,
1538 DisplayHeightPx: 768,
1539 EnableZoom: &enableZoom,
1540 ToolVersion: ComputerUse20251124,
1541 }, noopComputerRun))
1542 data, err := computerUseToolJSON(cuTool)
1543 require.NoError(t, err)
1544 var m map[string]any
1545 require.NoError(t, json.Unmarshal(data, &m))
1546 require.Equal(t, "computer_20251124", m["type"])
1547 require.Equal(t, true, m["enable_zoom"])
1548 })
1549
1550 t.Run("handles int64 args without JSON round-trip", func(t *testing.T) {
1551 t.Parallel()
1552 // Direct construction stores int64 values.
1553 cuTool := NewComputerUseTool(ComputerUseToolOptions{
1554 DisplayWidthPx: 1920,
1555 DisplayHeightPx: 1080,
1556 ToolVersion: ComputerUse20250124,
1557 }, noopComputerRun)
1558 data, err := computerUseToolJSON(cuTool.Definition())
1559 require.NoError(t, err)
1560 var m map[string]any
1561 require.NoError(t, json.Unmarshal(data, &m))
1562 require.InDelta(t, 1920, m["display_width_px"], 0)
1563 })
1564
1565 t.Run("returns error when version is missing", func(t *testing.T) {
1566 t.Parallel()
1567 pdt := fantasy.ProviderDefinedTool{
1568 ID: "anthropic.computer",
1569 Name: "computer",
1570 Args: map[string]any{
1571 "display_width_px": float64(1920),
1572 "display_height_px": float64(1080),
1573 },
1574 }
1575 _, err := computerUseToolJSON(pdt)
1576 require.Error(t, err)
1577 require.Contains(t, err.Error(), "tool_version arg is missing")
1578 })
1579
1580 t.Run("returns error for unsupported version", func(t *testing.T) {
1581 t.Parallel()
1582 pdt := fantasy.ProviderDefinedTool{
1583 ID: "anthropic.computer",
1584 Name: "computer",
1585 Args: map[string]any{
1586 "display_width_px": float64(1920),
1587 "display_height_px": float64(1080),
1588 "tool_version": "computer_99991231",
1589 },
1590 }
1591 _, err := computerUseToolJSON(pdt)
1592 require.Error(t, err)
1593 require.Contains(t, err.Error(), "unsupported")
1594 })
1595}
1596
1597func TestParseComputerUseInput_CoordinateValidation(t *testing.T) {
1598 t.Parallel()
1599
1600 t.Run("rejects coordinate with 1 element", func(t *testing.T) {
1601 t.Parallel()
1602 _, err := ParseComputerUseInput(`{"action":"left_click","coordinate":[100]}`)
1603 require.Error(t, err)
1604 require.Contains(t, err.Error(), "coordinate")
1605 })
1606
1607 t.Run("rejects coordinate with 3 elements", func(t *testing.T) {
1608 t.Parallel()
1609 _, err := ParseComputerUseInput(`{"action":"left_click","coordinate":[100,200,300]}`)
1610 require.Error(t, err)
1611 require.Contains(t, err.Error(), "coordinate")
1612 })
1613
1614 t.Run("rejects start_coordinate with 1 element", func(t *testing.T) {
1615 t.Parallel()
1616 _, err := ParseComputerUseInput(`{"action":"left_click_drag","coordinate":[100,200],"start_coordinate":[50]}`)
1617 require.Error(t, err)
1618 require.Contains(t, err.Error(), "start_coordinate")
1619 })
1620
1621 t.Run("rejects region with 3 elements", func(t *testing.T) {
1622 t.Parallel()
1623 _, err := ParseComputerUseInput(`{"action":"zoom","region":[10,20,30]}`)
1624 require.Error(t, err)
1625 require.Contains(t, err.Error(), "region")
1626 })
1627
1628 t.Run("accepts valid coordinate", func(t *testing.T) {
1629 t.Parallel()
1630 result, err := ParseComputerUseInput(`{"action":"left_click","coordinate":[100,200]}`)
1631 require.NoError(t, err)
1632 require.Equal(t, [2]int64{100, 200}, result.Coordinate)
1633 })
1634
1635 t.Run("accepts absent optional arrays", func(t *testing.T) {
1636 t.Parallel()
1637 result, err := ParseComputerUseInput(`{"action":"screenshot"}`)
1638 require.NoError(t, err)
1639 require.Equal(t, ActionScreenshot, result.Action)
1640 })
1641}
1642
1643func TestToTools_RawJSON(t *testing.T) {
1644 t.Parallel()
1645
1646 lm := languageModel{options: options{}}
1647
1648 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1649 DisplayWidthPx: 1920,
1650 DisplayHeightPx: 1080,
1651 ToolVersion: ComputerUse20250124,
1652 }, noopComputerRun))
1653
1654 tools := []fantasy.Tool{
1655 fantasy.FunctionTool{
1656 Name: "weather",
1657 Description: "Get weather",
1658 InputSchema: map[string]any{
1659 "properties": map[string]any{
1660 "location": map[string]any{"type": "string"},
1661 },
1662 "required": []string{"location"},
1663 },
1664 },
1665 WebSearchTool(nil),
1666 cuTool,
1667 }
1668
1669 rawTools, toolChoice, warnings, betaFlags := lm.toTools(tools, nil, false)
1670
1671 require.Len(t, rawTools, 3)
1672 require.Nil(t, toolChoice)
1673 require.Empty(t, warnings)
1674 require.NotEmpty(t, betaFlags)
1675
1676 // Verify each raw tool is valid JSON.
1677 for i, raw := range rawTools {
1678 var m map[string]any
1679 require.NoError(t, json.Unmarshal(raw, &m), "tool %d should be valid JSON", i)
1680 }
1681
1682 // Check function tool.
1683 var funcTool map[string]any
1684 require.NoError(t, json.Unmarshal(rawTools[0], &funcTool))
1685 require.Equal(t, "weather", funcTool["name"])
1686
1687 // Check web search tool.
1688 var webTool map[string]any
1689 require.NoError(t, json.Unmarshal(rawTools[1], &webTool))
1690 require.Equal(t, "web_search_20250305", webTool["type"])
1691
1692 // Check computer use tool.
1693 var cuToolJSON map[string]any
1694 require.NoError(t, json.Unmarshal(rawTools[2], &cuToolJSON))
1695 require.Equal(t, "computer_20250124", cuToolJSON["type"])
1696 require.Equal(t, "computer", cuToolJSON["name"])
1697}
1698
1699func TestGenerate_BetaAPI(t *testing.T) {
1700 t.Parallel()
1701
1702 t.Run("sends beta header for computer use", func(t *testing.T) {
1703 t.Parallel()
1704
1705 var capturedHeaders http.Header
1706 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1707 capturedHeaders = r.Header.Clone()
1708 w.Header().Set("Content-Type", "application/json")
1709 _ = json.NewEncoder(w).Encode(mockAnthropicGenerateResponse())
1710 }))
1711 defer server.Close()
1712
1713 provider, err := New(
1714 WithAPIKey("test-api-key"),
1715 WithBaseURL(server.URL),
1716 )
1717 require.NoError(t, err)
1718
1719 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1720 require.NoError(t, err)
1721
1722 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1723 DisplayWidthPx: 1920,
1724 DisplayHeightPx: 1080,
1725 ToolVersion: ComputerUse20250124,
1726 }, noopComputerRun))
1727
1728 _, err = model.Generate(context.Background(), fantasy.Call{
1729 Prompt: testPrompt(),
1730 Tools: []fantasy.Tool{cuTool},
1731 })
1732 require.NoError(t, err)
1733 require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-01-24")
1734 })
1735
1736 t.Run("sends beta header for computer use 20251124", func(t *testing.T) {
1737 t.Parallel()
1738
1739 var capturedHeaders http.Header
1740 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1741 capturedHeaders = r.Header.Clone()
1742 w.Header().Set("Content-Type", "application/json")
1743 _ = json.NewEncoder(w).Encode(mockAnthropicGenerateResponse())
1744 }))
1745 defer server.Close()
1746
1747 provider, err := New(
1748 WithAPIKey("test-api-key"),
1749 WithBaseURL(server.URL),
1750 )
1751 require.NoError(t, err)
1752
1753 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1754 require.NoError(t, err)
1755
1756 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1757 DisplayWidthPx: 1920,
1758 DisplayHeightPx: 1080,
1759 ToolVersion: ComputerUse20251124,
1760 }, noopComputerRun))
1761
1762 _, err = model.Generate(context.Background(), fantasy.Call{
1763 Prompt: testPrompt(),
1764 Tools: []fantasy.Tool{cuTool},
1765 })
1766 require.NoError(t, err)
1767 require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-11-24")
1768 })
1769
1770 t.Run("returns tool use from beta response", func(t *testing.T) {
1771 t.Parallel()
1772
1773 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1774 w.Header().Set("Content-Type", "application/json")
1775 _ = json.NewEncoder(w).Encode(map[string]any{
1776 "id": "msg_01Test",
1777 "type": "message",
1778 "role": "assistant",
1779 "model": "claude-sonnet-4-20250514",
1780 "content": []any{
1781 map[string]any{
1782 "type": "tool_use",
1783 "id": "toolu_01",
1784 "name": "computer",
1785 "input": map[string]any{"action": "screenshot"},
1786 },
1787 },
1788 "stop_reason": "tool_use",
1789 "usage": map[string]any{
1790 "input_tokens": 10,
1791 "output_tokens": 5,
1792 "cache_creation": map[string]any{
1793 "ephemeral_1h_input_tokens": 0,
1794 "ephemeral_5m_input_tokens": 0,
1795 },
1796 "cache_creation_input_tokens": 0,
1797 "cache_read_input_tokens": 0,
1798 "server_tool_use": map[string]any{
1799 "web_search_requests": 0,
1800 },
1801 "service_tier": "standard",
1802 },
1803 })
1804 }))
1805 defer server.Close()
1806
1807 provider, err := New(
1808 WithAPIKey("test-api-key"),
1809 WithBaseURL(server.URL),
1810 )
1811 require.NoError(t, err)
1812
1813 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1814 require.NoError(t, err)
1815
1816 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1817 DisplayWidthPx: 1920,
1818 DisplayHeightPx: 1080,
1819 ToolVersion: ComputerUse20250124,
1820 }, noopComputerRun))
1821
1822 resp, err := model.Generate(context.Background(), fantasy.Call{
1823 Prompt: testPrompt(),
1824 Tools: []fantasy.Tool{cuTool},
1825 })
1826 require.NoError(t, err)
1827
1828 toolCalls := resp.Content.ToolCalls()
1829 require.Len(t, toolCalls, 1)
1830 require.Equal(t, "computer", toolCalls[0].ToolName)
1831 require.Equal(t, "toolu_01", toolCalls[0].ToolCallID)
1832 require.Contains(t, toolCalls[0].Input, "screenshot")
1833 require.Equal(t, fantasy.FinishReasonToolCalls, resp.FinishReason)
1834
1835 // Verify typed parsing works on the tool call input.
1836 parsed, err := ParseComputerUseInput(toolCalls[0].Input)
1837 require.NoError(t, err)
1838 require.Equal(t, ActionScreenshot, parsed.Action)
1839 })
1840}
1841
1842func TestStream_BetaAPI(t *testing.T) {
1843 t.Parallel()
1844
1845 t.Run("streams via beta API for computer use", func(t *testing.T) {
1846 t.Parallel()
1847
1848 var capturedHeaders http.Header
1849 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1850 capturedHeaders = r.Header.Clone()
1851 w.Header().Set("Content-Type", "text/event-stream")
1852 w.Header().Set("Cache-Control", "no-cache")
1853 w.WriteHeader(http.StatusOK)
1854 chunks := []string{
1855 "event: message_start\n",
1856 "data: {\"type\":\"message_start\",\"message\":{}}\n\n",
1857 "event: message_stop\n",
1858 "data: {\"type\":\"message_stop\"}\n\n",
1859 }
1860 for _, chunk := range chunks {
1861 _, _ = fmt.Fprint(w, chunk)
1862 if flusher, ok := w.(http.Flusher); ok {
1863 flusher.Flush()
1864 }
1865 }
1866 }))
1867 defer server.Close()
1868
1869 provider, err := New(
1870 WithAPIKey("test-api-key"),
1871 WithBaseURL(server.URL),
1872 )
1873 require.NoError(t, err)
1874
1875 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1876 require.NoError(t, err)
1877
1878 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1879 DisplayWidthPx: 1920,
1880 DisplayHeightPx: 1080,
1881 ToolVersion: ComputerUse20250124,
1882 }, noopComputerRun))
1883
1884 stream, err := model.Stream(context.Background(), fantasy.Call{
1885 Prompt: testPrompt(),
1886 Tools: []fantasy.Tool{cuTool},
1887 })
1888 require.NoError(t, err)
1889
1890 stream(func(fantasy.StreamPart) bool { return true })
1891
1892 require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-01-24")
1893 })
1894
1895 t.Run("streams via beta API for computer use 20251124", func(t *testing.T) {
1896 t.Parallel()
1897
1898 var capturedHeaders http.Header
1899 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1900 capturedHeaders = r.Header.Clone()
1901 w.Header().Set("Content-Type", "text/event-stream")
1902 w.Header().Set("Cache-Control", "no-cache")
1903 w.WriteHeader(http.StatusOK)
1904 chunks := []string{
1905 "event: message_start\n",
1906 "data: {\"type\":\"message_start\",\"message\":{}}\n\n",
1907 "event: message_stop\n",
1908 "data: {\"type\":\"message_stop\"}\n\n",
1909 }
1910 for _, chunk := range chunks {
1911 _, _ = fmt.Fprint(w, chunk)
1912 if flusher, ok := w.(http.Flusher); ok {
1913 flusher.Flush()
1914 }
1915 }
1916 }))
1917 defer server.Close()
1918
1919 provider, err := New(
1920 WithAPIKey("test-api-key"),
1921 WithBaseURL(server.URL),
1922 )
1923 require.NoError(t, err)
1924
1925 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1926 require.NoError(t, err)
1927
1928 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1929 DisplayWidthPx: 1920,
1930 DisplayHeightPx: 1080,
1931 ToolVersion: ComputerUse20251124,
1932 }, noopComputerRun))
1933
1934 stream, err := model.Stream(context.Background(), fantasy.Call{
1935 Prompt: testPrompt(),
1936 Tools: []fantasy.Tool{cuTool},
1937 })
1938 require.NoError(t, err)
1939
1940 stream(func(fantasy.StreamPart) bool { return true })
1941
1942 require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-11-24")
1943 })
1944}
1945
1946// TestGenerate_ComputerUseTool runs a multi-turn computer use session
1947// via model.Generate, passing the ExecutableProviderTool directly into
1948// Call.Tools (no .Definition(), no jsonRoundTripTool). The mock server
1949// walks through a scripted sequence of actions — screenshot, click,
1950// type, key, scroll — then finishes with a text reply. Each turn the
1951// test parses the tool call, builds a screenshot result, and appends
1952// both to the prompt for the next request.
1953func TestGenerate_ComputerUseTool(t *testing.T) {
1954 t.Parallel()
1955
1956 type actionStep struct {
1957 input map[string]any
1958 want ComputerUseInput
1959 }
1960 steps := []actionStep{
1961 {
1962 input: map[string]any{"action": "screenshot"},
1963 want: ComputerUseInput{Action: ActionScreenshot},
1964 },
1965 {
1966 input: map[string]any{"action": "left_click", "coordinate": []any{100, 200}},
1967 want: ComputerUseInput{Action: ActionLeftClick, Coordinate: [2]int64{100, 200}},
1968 },
1969 {
1970 input: map[string]any{"action": "type", "text": "hello world"},
1971 want: ComputerUseInput{Action: ActionType, Text: "hello world"},
1972 },
1973 {
1974 input: map[string]any{"action": "key", "text": "Return"},
1975 want: ComputerUseInput{Action: ActionKey, Text: "Return"},
1976 },
1977 {
1978 input: map[string]any{
1979 "action": "scroll",
1980 "coordinate": []any{500, 300},
1981 "scroll_direction": "down",
1982 "scroll_amount": 3,
1983 },
1984 want: ComputerUseInput{
1985 Action: ActionScroll,
1986 Coordinate: [2]int64{500, 300},
1987 ScrollDirection: "down",
1988 ScrollAmount: 3,
1989 },
1990 },
1991 {
1992 input: map[string]any{"action": "screenshot"},
1993 want: ComputerUseInput{Action: ActionScreenshot},
1994 },
1995 }
1996
1997 var (
1998 requestIdx int
1999 betaHeaders []string
2000 )
2001 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2002 betaHeaders = append(betaHeaders, r.Header.Get("Anthropic-Beta"))
2003 idx := requestIdx
2004 requestIdx++
2005
2006 w.Header().Set("Content-Type", "application/json")
2007 if idx < len(steps) {
2008 _ = json.NewEncoder(w).Encode(map[string]any{
2009 "id": fmt.Sprintf("msg_%02d", idx),
2010 "type": "message",
2011 "role": "assistant",
2012 "model": "claude-sonnet-4-20250514",
2013 "content": []any{map[string]any{
2014 "type": "tool_use",
2015 "id": fmt.Sprintf("toolu_%02d", idx),
2016 "name": "computer",
2017 "input": steps[idx].input,
2018 }},
2019 "stop_reason": "tool_use",
2020 "usage": map[string]any{"input_tokens": 10, "output_tokens": 5},
2021 })
2022 return
2023 }
2024 _ = json.NewEncoder(w).Encode(map[string]any{
2025 "id": "msg_final",
2026 "type": "message",
2027 "role": "assistant",
2028 "model": "claude-sonnet-4-20250514",
2029 "content": []any{map[string]any{
2030 "type": "text",
2031 "text": "Done! I have completed all the requested actions.",
2032 }},
2033 "stop_reason": "end_turn",
2034 "usage": map[string]any{"input_tokens": 10, "output_tokens": 15},
2035 })
2036 }))
2037 defer server.Close()
2038
2039 provider, err := New(WithAPIKey("test-api-key"), WithBaseURL(server.URL))
2040 require.NoError(t, err)
2041
2042 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
2043 require.NoError(t, err)
2044
2045 // Pass the ExecutableProviderTool directly — the whole point is
2046 // to verify that the Tool interface works without unwrapping.
2047 cuTool := NewComputerUseTool(ComputerUseToolOptions{
2048 DisplayWidthPx: 1920,
2049 DisplayHeightPx: 1080,
2050 ToolVersion: ComputerUse20250124,
2051 }, noopComputerRun)
2052
2053 var got []ComputerUseInput
2054 prompt := testPrompt()
2055 fakePNG := []byte("fake-screenshot-png")
2056
2057 for turn := 0; turn <= len(steps); turn++ {
2058 resp, err := model.Generate(context.Background(), fantasy.Call{
2059 Prompt: prompt,
2060 Tools: []fantasy.Tool{cuTool},
2061 })
2062 require.NoError(t, err, "turn %d", turn)
2063
2064 if resp.FinishReason != fantasy.FinishReasonToolCalls {
2065 require.Equal(t, fantasy.FinishReasonStop, resp.FinishReason)
2066 require.Contains(t, resp.Content.Text(), "Done")
2067 break
2068 }
2069
2070 toolCalls := resp.Content.ToolCalls()
2071 require.Len(t, toolCalls, 1, "turn %d", turn)
2072 require.Equal(t, "computer", toolCalls[0].ToolName, "turn %d", turn)
2073
2074 parsed, err := ParseComputerUseInput(toolCalls[0].Input)
2075 require.NoError(t, err, "turn %d", turn)
2076 got = append(got, parsed)
2077
2078 // Build the next prompt: append the assistant tool-call turn
2079 // and the user screenshot-result turn.
2080 prompt = append(prompt,
2081 fantasy.Message{
2082 Role: fantasy.MessageRoleAssistant,
2083 Content: []fantasy.MessagePart{
2084 fantasy.ToolCallPart{
2085 ToolCallID: toolCalls[0].ToolCallID,
2086 ToolName: toolCalls[0].ToolName,
2087 Input: toolCalls[0].Input,
2088 },
2089 },
2090 },
2091 fantasy.Message{
2092 // Use MessageRoleTool for tool results — this matches
2093 // what the agent loop produces.
2094 Role: fantasy.MessageRoleTool,
2095 Content: []fantasy.MessagePart{
2096 NewComputerUseScreenshotResult(toolCalls[0].ToolCallID, fakePNG),
2097 },
2098 },
2099 )
2100 }
2101
2102 // Every scripted action was received and parsed correctly.
2103 require.Len(t, got, len(steps))
2104 for i, step := range steps {
2105 require.Equal(t, step.want.Action, got[i].Action, "step %d", i)
2106 require.Equal(t, step.want.Coordinate, got[i].Coordinate, "step %d", i)
2107 require.Equal(t, step.want.Text, got[i].Text, "step %d", i)
2108 require.Equal(t, step.want.ScrollDirection, got[i].ScrollDirection, "step %d", i)
2109 require.Equal(t, step.want.ScrollAmount, got[i].ScrollAmount, "step %d", i)
2110 }
2111
2112 // Beta header was sent on every request.
2113 require.Len(t, betaHeaders, len(steps)+1)
2114 for i, h := range betaHeaders {
2115 require.Contains(t, h, "computer-use-2025-01-24", "request %d", i)
2116 }
2117}
2118
2119// TestStream_ComputerUseTool runs a multi-turn computer use session
2120// via model.Stream, verifying that the ExecutableProviderTool works
2121// through the streaming path end-to-end.
2122func TestStream_ComputerUseTool(t *testing.T) {
2123 t.Parallel()
2124
2125 type streamStep struct {
2126 input map[string]any
2127 wantAction ComputerAction
2128 }
2129 steps := []streamStep{
2130 {input: map[string]any{"action": "screenshot"}, wantAction: ActionScreenshot},
2131 {input: map[string]any{"action": "left_click", "coordinate": []any{150, 250}}, wantAction: ActionLeftClick},
2132 {input: map[string]any{"action": "type", "text": "search query"}, wantAction: ActionType},
2133 }
2134
2135 var (
2136 requestIdx int
2137 betaHeaders []string
2138 )
2139
2140 // streamToolUseChunks returns SSE chunks for a single
2141 // computer-use tool_use content block.
2142 streamToolUseChunks := func(id string, input map[string]any) []string {
2143 inputJSON, _ := json.Marshal(input)
2144 escaped := strings.ReplaceAll(string(inputJSON), `"`, `\"`)
2145 return []string{
2146 "event: message_start\n",
2147 `data: {"type":"message_start","message":{"id":"` + id + `","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4-20250514","stop_reason":null,"usage":{"input_tokens":10,"output_tokens":0}}}` + "\n\n",
2148 "event: content_block_start\n",
2149 `data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"` + id + `","name":"computer","input":{}}}` + "\n\n",
2150 "event: content_block_delta\n",
2151 `data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"` + escaped + `"}}` + "\n\n",
2152 "event: content_block_stop\n",
2153 `data: {"type":"content_block_stop","index":0}` + "\n\n",
2154 "event: message_delta\n",
2155 `data: {"type":"message_delta","delta":{"stop_reason":"tool_use"},"usage":{"output_tokens":5}}` + "\n\n",
2156 "event: message_stop\n",
2157 `data: {"type":"message_stop"}` + "\n\n",
2158 }
2159 }
2160
2161 streamTextChunks := func() []string {
2162 return []string{
2163 "event: message_start\n",
2164 `data: {"type":"message_start","message":{"id":"msg_final","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4-20250514","stop_reason":null,"usage":{"input_tokens":10,"output_tokens":0}}}` + "\n\n",
2165 "event: content_block_start\n",
2166 `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}` + "\n\n",
2167 "event: content_block_delta\n",
2168 `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"All done."}}` + "\n\n",
2169 "event: content_block_stop\n",
2170 `data: {"type":"content_block_stop","index":0}` + "\n\n",
2171 "event: message_delta\n",
2172 `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":10}}` + "\n\n",
2173 "event: message_stop\n",
2174 `data: {"type":"message_stop"}` + "\n\n",
2175 }
2176 }
2177
2178 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2179 betaHeaders = append(betaHeaders, r.Header.Get("Anthropic-Beta"))
2180 idx := requestIdx
2181 requestIdx++
2182
2183 w.Header().Set("Content-Type", "text/event-stream")
2184 w.Header().Set("Cache-Control", "no-cache")
2185 w.WriteHeader(http.StatusOK)
2186
2187 var chunks []string
2188 if idx < len(steps) {
2189 chunks = streamToolUseChunks(
2190 fmt.Sprintf("toolu_%02d", idx),
2191 steps[idx].input,
2192 )
2193 } else {
2194 chunks = streamTextChunks()
2195 }
2196 for _, chunk := range chunks {
2197 _, _ = fmt.Fprint(w, chunk)
2198 if f, ok := w.(http.Flusher); ok {
2199 f.Flush()
2200 }
2201 }
2202 }))
2203 defer server.Close()
2204
2205 provider, err := New(WithAPIKey("test-api-key"), WithBaseURL(server.URL))
2206 require.NoError(t, err)
2207
2208 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
2209 require.NoError(t, err)
2210
2211 cuTool := NewComputerUseTool(ComputerUseToolOptions{
2212 DisplayWidthPx: 1920,
2213 DisplayHeightPx: 1080,
2214 ToolVersion: ComputerUse20250124,
2215 }, noopComputerRun)
2216
2217 var gotActions []ComputerAction
2218 prompt := testPrompt()
2219 fakePNG := []byte("fake-screenshot-png")
2220
2221 for turn := 0; turn <= len(steps); turn++ {
2222 stream, err := model.Stream(context.Background(), fantasy.Call{
2223 Prompt: prompt,
2224 Tools: []fantasy.Tool{cuTool},
2225 })
2226 require.NoError(t, err, "turn %d", turn)
2227
2228 var (
2229 toolCallName string
2230 toolCallID string
2231 toolCallInput string
2232 finishReason fantasy.FinishReason
2233 gotText string
2234 )
2235 stream(func(part fantasy.StreamPart) bool {
2236 switch part.Type {
2237 case fantasy.StreamPartTypeToolCall:
2238 toolCallName = part.ToolCallName
2239 toolCallID = part.ID
2240 toolCallInput = part.ToolCallInput
2241 case fantasy.StreamPartTypeFinish:
2242 finishReason = part.FinishReason
2243 case fantasy.StreamPartTypeTextDelta:
2244 gotText += part.Delta
2245 }
2246 return true
2247 })
2248
2249 if finishReason != fantasy.FinishReasonToolCalls {
2250 require.Contains(t, gotText, "All done")
2251 break
2252 }
2253
2254 require.Equal(t, "computer", toolCallName, "turn %d", turn)
2255
2256 parsed, err := ParseComputerUseInput(toolCallInput)
2257 require.NoError(t, err, "turn %d", turn)
2258 gotActions = append(gotActions, parsed.Action)
2259
2260 prompt = append(prompt,
2261 fantasy.Message{
2262 Role: fantasy.MessageRoleAssistant,
2263 Content: []fantasy.MessagePart{
2264 fantasy.ToolCallPart{
2265 ToolCallID: toolCallID,
2266 ToolName: toolCallName,
2267 Input: toolCallInput,
2268 },
2269 },
2270 },
2271 fantasy.Message{
2272 // Use MessageRoleTool for tool results — this matches
2273 // what the agent loop produces.
2274 Role: fantasy.MessageRoleTool,
2275 Content: []fantasy.MessagePart{
2276 NewComputerUseScreenshotResult(toolCallID, fakePNG),
2277 },
2278 },
2279 )
2280 }
2281
2282 require.Len(t, gotActions, len(steps))
2283 for i, step := range steps {
2284 require.Equal(t, step.wantAction, gotActions[i], "step %d", i)
2285 }
2286
2287 require.Len(t, betaHeaders, len(steps)+1)
2288 for i, h := range betaHeaders {
2289 require.Contains(t, h, "computer-use-2025-01-24", "request %d", i)
2290 }
2291}