1package anthropic
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "math"
9 "net/http"
10 "net/http/httptest"
11 "strings"
12 "testing"
13 "time"
14
15 "charm.land/fantasy"
16 "github.com/charmbracelet/anthropic-sdk-go"
17 "github.com/stretchr/testify/require"
18)
19
20// noopComputerRun is a no-op run function for tests that only need
21// to inspect the tool definition, not execute it.
22var noopComputerRun = func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
23 return fantasy.ToolResponse{}, nil
24}
25
26func TestToPrompt_DropsEmptyMessages(t *testing.T) {
27 t.Parallel()
28
29 t.Run("should drop assistant messages with only reasoning content", func(t *testing.T) {
30 t.Parallel()
31
32 prompt := fantasy.Prompt{
33 {
34 Role: fantasy.MessageRoleUser,
35 Content: []fantasy.MessagePart{
36 fantasy.TextPart{Text: "Hello"},
37 },
38 },
39 {
40 Role: fantasy.MessageRoleAssistant,
41 Content: []fantasy.MessagePart{
42 fantasy.ReasoningPart{
43 Text: "Let me think about this...",
44 ProviderOptions: fantasy.ProviderOptions{
45 Name: &ReasoningOptionMetadata{
46 Signature: "abc123",
47 },
48 },
49 },
50 },
51 },
52 }
53
54 systemBlocks, messages, warnings := toPrompt(prompt, true)
55
56 require.Empty(t, systemBlocks)
57 require.Len(t, messages, 1, "should only have user message, assistant message should be dropped")
58 require.Len(t, warnings, 1)
59 require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
60 require.Contains(t, warnings[0].Message, "dropping empty assistant message")
61 require.Contains(t, warnings[0].Message, "neither user-facing content nor tool calls")
62 })
63
64 t.Run("should drop assistant reasoning when sendReasoning disabled", func(t *testing.T) {
65 t.Parallel()
66
67 prompt := fantasy.Prompt{
68 {
69 Role: fantasy.MessageRoleUser,
70 Content: []fantasy.MessagePart{
71 fantasy.TextPart{Text: "Hello"},
72 },
73 },
74 {
75 Role: fantasy.MessageRoleAssistant,
76 Content: []fantasy.MessagePart{
77 fantasy.ReasoningPart{
78 Text: "Let me think about this...",
79 ProviderOptions: fantasy.ProviderOptions{
80 Name: &ReasoningOptionMetadata{
81 Signature: "def456",
82 },
83 },
84 },
85 },
86 },
87 }
88
89 systemBlocks, messages, warnings := toPrompt(prompt, false)
90
91 require.Empty(t, systemBlocks)
92 require.Len(t, messages, 1, "should only have user message, assistant message should be dropped")
93 require.Len(t, warnings, 2)
94 require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
95 require.Contains(t, warnings[0].Message, "sending reasoning content is disabled")
96 require.Equal(t, fantasy.CallWarningTypeOther, warnings[1].Type)
97 require.Contains(t, warnings[1].Message, "dropping empty assistant message")
98 })
99
100 t.Run("should drop truly empty assistant messages", func(t *testing.T) {
101 t.Parallel()
102
103 prompt := fantasy.Prompt{
104 {
105 Role: fantasy.MessageRoleUser,
106 Content: []fantasy.MessagePart{
107 fantasy.TextPart{Text: "Hello"},
108 },
109 },
110 {
111 Role: fantasy.MessageRoleAssistant,
112 Content: []fantasy.MessagePart{},
113 },
114 }
115
116 systemBlocks, messages, warnings := toPrompt(prompt, true)
117
118 require.Empty(t, systemBlocks)
119 require.Len(t, messages, 1, "should only have user message")
120 require.Len(t, warnings, 1)
121 require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
122 require.Contains(t, warnings[0].Message, "dropping empty assistant message")
123 })
124
125 t.Run("should keep assistant messages with text content", func(t *testing.T) {
126 t.Parallel()
127
128 prompt := fantasy.Prompt{
129 {
130 Role: fantasy.MessageRoleUser,
131 Content: []fantasy.MessagePart{
132 fantasy.TextPart{Text: "Hello"},
133 },
134 },
135 {
136 Role: fantasy.MessageRoleAssistant,
137 Content: []fantasy.MessagePart{
138 fantasy.TextPart{Text: "Hi there!"},
139 },
140 },
141 }
142
143 systemBlocks, messages, warnings := toPrompt(prompt, true)
144
145 require.Empty(t, systemBlocks)
146 require.Len(t, messages, 2, "should have both user and assistant messages")
147 require.Empty(t, warnings)
148 })
149
150 t.Run("should keep assistant messages with tool calls", func(t *testing.T) {
151 t.Parallel()
152
153 prompt := fantasy.Prompt{
154 {
155 Role: fantasy.MessageRoleUser,
156 Content: []fantasy.MessagePart{
157 fantasy.TextPart{Text: "What's the weather?"},
158 },
159 },
160 {
161 Role: fantasy.MessageRoleAssistant,
162 Content: []fantasy.MessagePart{
163 fantasy.ToolCallPart{
164 ToolCallID: "call_123",
165 ToolName: "get_weather",
166 Input: `{"location":"NYC"}`,
167 },
168 },
169 },
170 }
171
172 systemBlocks, messages, warnings := toPrompt(prompt, true)
173
174 require.Empty(t, systemBlocks)
175 require.Len(t, messages, 2, "should have both user and assistant messages")
176 require.Empty(t, warnings)
177 })
178
179 t.Run("should drop assistant messages with invalid tool input", func(t *testing.T) {
180 t.Parallel()
181
182 prompt := fantasy.Prompt{
183 {
184 Role: fantasy.MessageRoleUser,
185 Content: []fantasy.MessagePart{
186 fantasy.TextPart{Text: "Hi"},
187 },
188 },
189 {
190 Role: fantasy.MessageRoleAssistant,
191 Content: []fantasy.MessagePart{
192 fantasy.ToolCallPart{
193 ToolCallID: "call_123",
194 ToolName: "get_weather",
195 Input: "{not-json",
196 },
197 },
198 },
199 }
200
201 systemBlocks, messages, warnings := toPrompt(prompt, true)
202
203 require.Empty(t, systemBlocks)
204 require.Len(t, messages, 1, "should only have user message")
205 require.Len(t, warnings, 1)
206 require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
207 require.Contains(t, warnings[0].Message, "dropping empty assistant message")
208 })
209
210 t.Run("should keep assistant messages with reasoning and text", func(t *testing.T) {
211 t.Parallel()
212
213 prompt := fantasy.Prompt{
214 {
215 Role: fantasy.MessageRoleUser,
216 Content: []fantasy.MessagePart{
217 fantasy.TextPart{Text: "Hello"},
218 },
219 },
220 {
221 Role: fantasy.MessageRoleAssistant,
222 Content: []fantasy.MessagePart{
223 fantasy.ReasoningPart{
224 Text: "Let me think...",
225 ProviderOptions: fantasy.ProviderOptions{
226 Name: &ReasoningOptionMetadata{
227 Signature: "abc123",
228 },
229 },
230 },
231 fantasy.TextPart{Text: "Hi there!"},
232 },
233 },
234 }
235
236 systemBlocks, messages, warnings := toPrompt(prompt, true)
237
238 require.Empty(t, systemBlocks)
239 require.Len(t, messages, 2, "should have both user and assistant messages")
240 require.Empty(t, warnings)
241 })
242
243 t.Run("should keep user messages with image content", func(t *testing.T) {
244 t.Parallel()
245
246 prompt := fantasy.Prompt{
247 {
248 Role: fantasy.MessageRoleUser,
249 Content: []fantasy.MessagePart{
250 fantasy.FilePart{
251 Data: []byte{0x01, 0x02, 0x03},
252 MediaType: "image/png",
253 },
254 },
255 },
256 }
257
258 systemBlocks, messages, warnings := toPrompt(prompt, true)
259
260 require.Empty(t, systemBlocks)
261 require.Len(t, messages, 1)
262 require.Empty(t, warnings)
263 })
264
265 t.Run("should drop user messages without visible content", func(t *testing.T) {
266 t.Parallel()
267
268 prompt := fantasy.Prompt{
269 {
270 Role: fantasy.MessageRoleUser,
271 Content: []fantasy.MessagePart{
272 fantasy.FilePart{
273 Data: []byte("not supported"),
274 MediaType: "application/pdf",
275 },
276 },
277 },
278 }
279
280 systemBlocks, messages, warnings := toPrompt(prompt, true)
281
282 require.Empty(t, systemBlocks)
283 require.Empty(t, messages)
284 require.Len(t, warnings, 1)
285 require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
286 require.Contains(t, warnings[0].Message, "dropping empty user message")
287 require.Contains(t, warnings[0].Message, "neither user-facing content nor tool results")
288 })
289
290 t.Run("should keep user messages with tool results", func(t *testing.T) {
291 t.Parallel()
292
293 prompt := fantasy.Prompt{
294 {
295 Role: fantasy.MessageRoleTool,
296 Content: []fantasy.MessagePart{
297 fantasy.ToolResultPart{
298 ToolCallID: "call_123",
299 Output: fantasy.ToolResultOutputContentText{Text: "done"},
300 },
301 },
302 },
303 }
304
305 systemBlocks, messages, warnings := toPrompt(prompt, true)
306
307 require.Empty(t, systemBlocks)
308 require.Len(t, messages, 1)
309 require.Empty(t, warnings)
310 })
311
312 t.Run("should keep user messages with tool error results", func(t *testing.T) {
313 t.Parallel()
314
315 prompt := fantasy.Prompt{
316 {
317 Role: fantasy.MessageRoleTool,
318 Content: []fantasy.MessagePart{
319 fantasy.ToolResultPart{
320 ToolCallID: "call_456",
321 Output: fantasy.ToolResultOutputContentError{Error: errors.New("boom")},
322 },
323 },
324 },
325 }
326
327 systemBlocks, messages, warnings := toPrompt(prompt, true)
328
329 require.Empty(t, systemBlocks)
330 require.Len(t, messages, 1)
331 require.Empty(t, warnings)
332 })
333
334 t.Run("should keep user messages with tool media results", func(t *testing.T) {
335 t.Parallel()
336
337 prompt := fantasy.Prompt{
338 {
339 Role: fantasy.MessageRoleTool,
340 Content: []fantasy.MessagePart{
341 fantasy.ToolResultPart{
342 ToolCallID: "call_789",
343 Output: fantasy.ToolResultOutputContentMedia{
344 Data: "AQID",
345 MediaType: "image/png",
346 },
347 },
348 },
349 },
350 }
351
352 systemBlocks, messages, warnings := toPrompt(prompt, true)
353
354 require.Empty(t, systemBlocks)
355 require.Len(t, messages, 1)
356 require.Empty(t, warnings)
357 })
358}
359
360func TestParseContextTooLargeError(t *testing.T) {
361 t.Parallel()
362
363 tests := []struct {
364 name string
365 message string
366 wantErr bool
367 wantUsed int
368 wantMax int
369 }{
370 {
371 name: "matches anthropic format",
372 message: "prompt is too long: 202630 tokens > 200000 maximum",
373 wantErr: true,
374 wantUsed: 202630,
375 wantMax: 200000,
376 },
377 {
378 name: "matches with different numbers",
379 message: "prompt is too long: 150000 tokens > 128000 maximum",
380 wantErr: true,
381 wantUsed: 150000,
382 wantMax: 128000,
383 },
384 {
385 name: "matches with extra whitespace",
386 message: "prompt is too long: 202630 tokens > 200000 maximum",
387 wantErr: true,
388 wantUsed: 202630,
389 wantMax: 200000,
390 },
391 {
392 name: "does not match unrelated error",
393 message: "invalid api key",
394 wantErr: false,
395 },
396 {
397 name: "does not match rate limit error",
398 message: "rate limit exceeded",
399 wantErr: false,
400 },
401 }
402
403 for _, tt := range tests {
404 t.Run(tt.name, func(t *testing.T) {
405 t.Parallel()
406 providerErr := &fantasy.ProviderError{Message: tt.message}
407 parseContextTooLargeError(tt.message, providerErr)
408
409 if tt.wantErr {
410 require.True(t, providerErr.IsContextTooLarge())
411 require.Equal(t, tt.wantUsed, providerErr.ContextUsedTokens)
412 require.Equal(t, tt.wantMax, providerErr.ContextMaxTokens)
413 } else {
414 require.False(t, providerErr.IsContextTooLarge())
415 }
416 })
417 }
418}
419
420func TestParseOptions_Effort(t *testing.T) {
421 t.Parallel()
422
423 options, err := ParseOptions(map[string]any{
424 "send_reasoning": true,
425 "thinking": map[string]any{"budget_tokens": int64(2048)},
426 "effort": "medium",
427 "disable_parallel_tool_use": true,
428 })
429 require.NoError(t, err)
430 require.NotNil(t, options.SendReasoning)
431 require.True(t, *options.SendReasoning)
432 require.NotNil(t, options.Thinking)
433 require.Equal(t, int64(2048), options.Thinking.BudgetTokens)
434 require.NotNil(t, options.Effort)
435 require.Equal(t, EffortMedium, *options.Effort)
436 require.NotNil(t, options.DisableParallelToolUse)
437 require.True(t, *options.DisableParallelToolUse)
438}
439
440func TestGenerate_SendsOutputConfigEffort(t *testing.T) {
441 t.Parallel()
442
443 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
444 defer server.Close()
445
446 provider, err := New(
447 WithAPIKey("test-api-key"),
448 WithBaseURL(server.URL),
449 )
450 require.NoError(t, err)
451
452 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
453 require.NoError(t, err)
454
455 effort := EffortMedium
456 _, err = model.Generate(context.Background(), fantasy.Call{
457 Prompt: testPrompt(),
458 ProviderOptions: NewProviderOptions(&ProviderOptions{
459 Effort: &effort,
460 }),
461 })
462 require.NoError(t, err)
463
464 call := awaitAnthropicCall(t, calls)
465 require.Equal(t, "POST", call.method)
466 require.Equal(t, "/v1/messages", call.path)
467 requireAnthropicEffort(t, call.body, EffortMedium)
468}
469
470func TestStream_SendsOutputConfigEffort(t *testing.T) {
471 t.Parallel()
472
473 server, calls := newAnthropicStreamingServer([]string{
474 "event: message_start\n",
475 "data: {\"type\":\"message_start\",\"message\":{}}\n\n",
476 "event: message_stop\n",
477 "data: {\"type\":\"message_stop\"}\n\n",
478 })
479 defer server.Close()
480
481 provider, err := New(
482 WithAPIKey("test-api-key"),
483 WithBaseURL(server.URL),
484 )
485 require.NoError(t, err)
486
487 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
488 require.NoError(t, err)
489
490 effort := EffortHigh
491 stream, err := model.Stream(context.Background(), fantasy.Call{
492 Prompt: testPrompt(),
493 ProviderOptions: NewProviderOptions(&ProviderOptions{
494 Effort: &effort,
495 }),
496 })
497 require.NoError(t, err)
498
499 stream(func(fantasy.StreamPart) bool { return true })
500
501 call := awaitAnthropicCall(t, calls)
502 require.Equal(t, "POST", call.method)
503 require.Equal(t, "/v1/messages", call.path)
504 requireAnthropicEffort(t, call.body, EffortHigh)
505}
506
// anthropicCall is one HTTP request as recorded by the mock servers in this
// file.
type anthropicCall struct {
	// method and path are copied from the incoming request's Method and
	// URL.Path.
	method, path string
	// body is the request body decoded as a JSON object.
	body map[string]any
}
512
513func newAnthropicJSONServer(response map[string]any) (*httptest.Server, <-chan anthropicCall) {
514 calls := make(chan anthropicCall, 4)
515
516 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
517 var body map[string]any
518 if r.Body != nil {
519 _ = json.NewDecoder(r.Body).Decode(&body)
520 }
521
522 calls <- anthropicCall{
523 method: r.Method,
524 path: r.URL.Path,
525 body: body,
526 }
527
528 w.Header().Set("Content-Type", "application/json")
529 _ = json.NewEncoder(w).Encode(response)
530 }))
531
532 return server, calls
533}
534
535func newAnthropicStreamingServer(chunks []string) (*httptest.Server, <-chan anthropicCall) {
536 calls := make(chan anthropicCall, 4)
537
538 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
539 var body map[string]any
540 if r.Body != nil {
541 _ = json.NewDecoder(r.Body).Decode(&body)
542 }
543
544 calls <- anthropicCall{
545 method: r.Method,
546 path: r.URL.Path,
547 body: body,
548 }
549
550 w.Header().Set("Content-Type", "text/event-stream")
551 w.Header().Set("Cache-Control", "no-cache")
552 w.Header().Set("Connection", "keep-alive")
553 w.WriteHeader(http.StatusOK)
554
555 for _, chunk := range chunks {
556 _, _ = fmt.Fprint(w, chunk)
557 if flusher, ok := w.(http.Flusher); ok {
558 flusher.Flush()
559 }
560 }
561 }))
562
563 return server, calls
564}
565
566func awaitAnthropicCall(t *testing.T, calls <-chan anthropicCall) anthropicCall {
567 t.Helper()
568
569 select {
570 case call := <-calls:
571 return call
572 case <-time.After(2 * time.Second):
573 t.Fatal("timed out waiting for Anthropic request")
574 return anthropicCall{}
575 }
576}
577
578func assertNoAnthropicCall(t *testing.T, calls <-chan anthropicCall) {
579 t.Helper()
580
581 select {
582 case call := <-calls:
583 t.Fatalf("expected no Anthropic API call, but got %s %s", call.method, call.path)
584 case <-time.After(200 * time.Millisecond):
585 }
586}
587
588func requireAnthropicEffort(t *testing.T, body map[string]any, expected Effort) {
589 t.Helper()
590
591 outputConfig, ok := body["output_config"].(map[string]any)
592 thinking, ok := body["thinking"].(map[string]any)
593 require.True(t, ok)
594 require.Equal(t, string(expected), outputConfig["effort"])
595 require.Equal(t, "adaptive", thinking["type"])
596}
597
598func testPrompt() fantasy.Prompt {
599 return fantasy.Prompt{
600 {
601 Role: fantasy.MessageRoleUser,
602 Content: []fantasy.MessagePart{
603 fantasy.TextPart{Text: "Hello"},
604 },
605 },
606 }
607}
608
// mockAnthropicGenerateResponse returns a freshly built response body for
// the mock server: an assistant message with a single text block
// ("Hi there"), stop reason "end_turn", and a usage object.
func mockAnthropicGenerateResponse() map[string]any {
	textBlock := map[string]any{
		"type": "text",
		"text": "Hi there",
	}

	cacheCreation := map[string]any{
		"ephemeral_1h_input_tokens": 0,
		"ephemeral_5m_input_tokens": 0,
	}

	usage := map[string]any{
		"cache_creation":              cacheCreation,
		"cache_creation_input_tokens": 0,
		"cache_read_input_tokens":     0,
		"input_tokens":                5,
		"output_tokens":               2,
		"server_tool_use":             map[string]any{"web_search_requests": 0},
		"service_tier":                "standard",
	}

	return map[string]any{
		"id":            "msg_01Test",
		"type":          "message",
		"role":          "assistant",
		"model":         "claude-sonnet-4-20250514",
		"content":       []any{textBlock},
		"stop_reason":   "end_turn",
		"stop_sequence": "",
		"usage":         usage,
	}
}
639
// mockAnthropicWebSearchResponse returns a freshly built response body for
// the mock server containing, in order, a server_tool_use block for
// web_search, a web_search_tool_result block with two results, and a final
// text block.
func mockAnthropicWebSearchResponse() map[string]any {
	// directCaller builds a new caller object on every call so that the
	// blocks do not share a map.
	directCaller := func() map[string]any {
		return map[string]any{"type": "direct"}
	}

	// searchResult builds one web_search_result entry.
	searchResult := func(url, title, encrypted, pageAge string) map[string]any {
		return map[string]any{
			"type":              "web_search_result",
			"url":               url,
			"title":             title,
			"encrypted_content": encrypted,
			"page_age":          pageAge,
		}
	}

	toolUse := map[string]any{
		"type":   "server_tool_use",
		"id":     "srvtoolu_01",
		"name":   "web_search",
		"input":  map[string]any{"query": "latest AI news"},
		"caller": directCaller(),
	}

	toolResult := map[string]any{
		"type":        "web_search_tool_result",
		"tool_use_id": "srvtoolu_01",
		"caller":      directCaller(),
		"content": []any{
			searchResult("https://example.com/ai-news", "Latest AI News", "encrypted_abc123", "2 hours ago"),
			searchResult("https://example.com/ml-update", "ML Update", "encrypted_def456", ""),
		},
	}

	answer := map[string]any{
		"type": "text",
		"text": "Based on recent search results, here is the latest AI news.",
	}

	usage := map[string]any{
		"input_tokens":                100,
		"output_tokens":               50,
		"cache_creation_input_tokens": 0,
		"cache_read_input_tokens":     0,
		"server_tool_use":             map[string]any{"web_search_requests": 1},
	}

	return map[string]any{
		"id":            "msg_01WebSearch",
		"type":          "message",
		"role":          "assistant",
		"model":         "claude-sonnet-4-20250514",
		"content":       []any{toolUse, toolResult, answer},
		"stop_reason":   "end_turn",
		"stop_sequence": nil,
		"usage":         usage,
	}
}
693
694func TestToPrompt_WebSearchProviderExecutedToolResults(t *testing.T) {
695 t.Parallel()
696
697 prompt := fantasy.Prompt{
698 // User message.
699 {
700 Role: fantasy.MessageRoleUser,
701 Content: []fantasy.MessagePart{
702 fantasy.TextPart{Text: "Search for the latest AI news"},
703 },
704 },
705 // Assistant message with a provider-executed tool call, its
706 // result, and trailing text. toResponseMessages routes
707 // provider-executed results into the assistant message, so
708 // the prompt already reflects that structure.
709 {
710 Role: fantasy.MessageRoleAssistant,
711 Content: []fantasy.MessagePart{
712 fantasy.ToolCallPart{
713 ToolCallID: "srvtoolu_01",
714 ToolName: "web_search",
715 Input: `{"query":"latest AI news"}`,
716 ProviderExecuted: true,
717 },
718 fantasy.ToolResultPart{
719 ToolCallID: "srvtoolu_01",
720 ProviderExecuted: true,
721 ProviderOptions: fantasy.ProviderOptions{
722 Name: &WebSearchResultMetadata{
723 Results: []WebSearchResultItem{
724 {
725 URL: "https://example.com/ai-news",
726 Title: "Latest AI News",
727 EncryptedContent: "encrypted_abc123",
728 PageAge: "2 hours ago",
729 },
730 {
731 URL: "https://example.com/ml-update",
732 Title: "ML Update",
733 EncryptedContent: "encrypted_def456",
734 },
735 },
736 },
737 },
738 },
739 fantasy.TextPart{Text: "Here is what I found."},
740 },
741 },
742 }
743
744 _, messages, warnings := toPrompt(prompt, true)
745
746 // No warnings expected; the provider-executed result is in the
747 // assistant message so there is no empty tool message to drop.
748 require.Empty(t, warnings)
749
750 // We should have a user message and an assistant message.
751 require.Len(t, messages, 2, "expected user + assistant messages")
752
753 assistantMsg := messages[1]
754 require.Len(t, assistantMsg.Content, 3,
755 "expected server_tool_use + web_search_tool_result + text")
756
757 // First content block: reconstructed server_tool_use.
758 serverToolUse := assistantMsg.Content[0]
759 require.NotNil(t, serverToolUse.OfServerToolUse,
760 "first block should be a server_tool_use")
761 require.Equal(t, "srvtoolu_01", serverToolUse.OfServerToolUse.ID)
762 require.Equal(t, anthropic.ServerToolUseBlockParamName("web_search"),
763 serverToolUse.OfServerToolUse.Name)
764
765 // Second content block: reconstructed web_search_tool_result with
766 // encrypted_content preserved for multi-turn round-tripping.
767 webResult := assistantMsg.Content[1]
768 require.NotNil(t, webResult.OfWebSearchToolResult,
769 "second block should be a web_search_tool_result")
770 require.Equal(t, "srvtoolu_01", webResult.OfWebSearchToolResult.ToolUseID)
771
772 results := webResult.OfWebSearchToolResult.Content.OfWebSearchToolResultBlockItem
773 require.Len(t, results, 2)
774 require.Equal(t, "https://example.com/ai-news", results[0].URL)
775 require.Equal(t, "Latest AI News", results[0].Title)
776 require.Equal(t, "encrypted_abc123", results[0].EncryptedContent)
777 require.Equal(t, "https://example.com/ml-update", results[1].URL)
778 require.Equal(t, "encrypted_def456", results[1].EncryptedContent)
779 // PageAge should be set for the first result and absent for the second.
780 require.True(t, results[0].PageAge.Valid())
781 require.Equal(t, "2 hours ago", results[0].PageAge.Value)
782 require.False(t, results[1].PageAge.Valid())
783
784 // Third content block: plain text.
785 require.NotNil(t, assistantMsg.Content[2].OfText)
786 require.Equal(t, "Here is what I found.", assistantMsg.Content[2].OfText.Text)
787}
788
789func TestGenerate_WebSearchResponse(t *testing.T) {
790 t.Parallel()
791
792 server, calls := newAnthropicJSONServer(mockAnthropicWebSearchResponse())
793 defer server.Close()
794
795 provider, err := New(
796 WithAPIKey("test-api-key"),
797 WithBaseURL(server.URL),
798 )
799 require.NoError(t, err)
800
801 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
802 require.NoError(t, err)
803
804 resp, err := model.Generate(context.Background(), fantasy.Call{
805 Prompt: testPrompt(),
806 Tools: []fantasy.Tool{
807 WebSearchTool(nil),
808 },
809 })
810 require.NoError(t, err)
811
812 call := awaitAnthropicCall(t, calls)
813 require.Equal(t, "POST", call.method)
814 require.Equal(t, "/v1/messages", call.path)
815
816 // Walk the response content and categorise each item.
817 var (
818 toolCalls []fantasy.ToolCallContent
819 sources []fantasy.SourceContent
820 toolResults []fantasy.ToolResultContent
821 texts []fantasy.TextContent
822 )
823 for _, c := range resp.Content {
824 switch v := c.(type) {
825 case fantasy.ToolCallContent:
826 toolCalls = append(toolCalls, v)
827 case fantasy.SourceContent:
828 sources = append(sources, v)
829 case fantasy.ToolResultContent:
830 toolResults = append(toolResults, v)
831 case fantasy.TextContent:
832 texts = append(texts, v)
833 }
834 }
835
836 // ToolCallContent for the provider-executed web_search.
837 require.Len(t, toolCalls, 1)
838 require.True(t, toolCalls[0].ProviderExecuted)
839 require.Equal(t, "web_search", toolCalls[0].ToolName)
840 require.Equal(t, "srvtoolu_01", toolCalls[0].ToolCallID)
841
842 // SourceContent entries for each search result.
843 require.Len(t, sources, 2)
844 require.Equal(t, "https://example.com/ai-news", sources[0].URL)
845 require.Equal(t, "Latest AI News", sources[0].Title)
846 require.Equal(t, fantasy.SourceTypeURL, sources[0].SourceType)
847 require.Equal(t, "https://example.com/ml-update", sources[1].URL)
848 require.Equal(t, "ML Update", sources[1].Title)
849
850 // ToolResultContent with provider metadata preserving encrypted_content.
851 require.Len(t, toolResults, 1)
852 require.True(t, toolResults[0].ProviderExecuted)
853 require.Equal(t, "web_search", toolResults[0].ToolName)
854 require.Equal(t, "srvtoolu_01", toolResults[0].ToolCallID)
855
856 searchMeta, ok := toolResults[0].ProviderMetadata[Name]
857 require.True(t, ok, "providerMetadata should contain anthropic key")
858 webMeta, ok := searchMeta.(*WebSearchResultMetadata)
859 require.True(t, ok, "metadata should be *WebSearchResultMetadata")
860 require.Len(t, webMeta.Results, 2)
861 require.Equal(t, "encrypted_abc123", webMeta.Results[0].EncryptedContent)
862 require.Equal(t, "encrypted_def456", webMeta.Results[1].EncryptedContent)
863 require.Equal(t, "2 hours ago", webMeta.Results[0].PageAge)
864
865 // TextContent with the final answer.
866 require.Len(t, texts, 1)
867 require.Equal(t,
868 "Based on recent search results, here is the latest AI news.",
869 texts[0].Text,
870 )
871}
872
873func TestGenerate_WebSearchToolInRequest(t *testing.T) {
874 t.Parallel()
875
876 t.Run("basic web_search tool", func(t *testing.T) {
877 t.Parallel()
878
879 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
880 defer server.Close()
881
882 provider, err := New(
883 WithAPIKey("test-api-key"),
884 WithBaseURL(server.URL),
885 )
886 require.NoError(t, err)
887
888 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
889 require.NoError(t, err)
890
891 _, err = model.Generate(context.Background(), fantasy.Call{
892 Prompt: testPrompt(),
893 Tools: []fantasy.Tool{
894 WebSearchTool(nil),
895 },
896 })
897 require.NoError(t, err)
898
899 call := awaitAnthropicCall(t, calls)
900 tools, ok := call.body["tools"].([]any)
901 require.True(t, ok, "request body should have tools array")
902 require.Len(t, tools, 1)
903
904 tool, ok := tools[0].(map[string]any)
905 require.True(t, ok)
906 require.Equal(t, "web_search_20250305", tool["type"])
907 })
908
909 t.Run("with allowed_domains and blocked_domains", func(t *testing.T) {
910 t.Parallel()
911
912 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
913 defer server.Close()
914
915 provider, err := New(
916 WithAPIKey("test-api-key"),
917 WithBaseURL(server.URL),
918 )
919 require.NoError(t, err)
920
921 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
922 require.NoError(t, err)
923
924 _, err = model.Generate(context.Background(), fantasy.Call{
925 Prompt: testPrompt(),
926 Tools: []fantasy.Tool{
927 WebSearchTool(&WebSearchToolOptions{
928 AllowedDomains: []string{"example.com", "test.com"},
929 }),
930 },
931 })
932 require.NoError(t, err)
933
934 call := awaitAnthropicCall(t, calls)
935 tools, ok := call.body["tools"].([]any)
936 require.True(t, ok)
937 require.Len(t, tools, 1)
938
939 tool, ok := tools[0].(map[string]any)
940 require.True(t, ok)
941 require.Equal(t, "web_search_20250305", tool["type"])
942
943 domains, ok := tool["allowed_domains"].([]any)
944 require.True(t, ok, "tool should have allowed_domains")
945 require.Len(t, domains, 2)
946 require.Equal(t, "example.com", domains[0])
947 require.Equal(t, "test.com", domains[1])
948 })
949
950 t.Run("with max uses and user location", func(t *testing.T) {
951 t.Parallel()
952
953 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
954 defer server.Close()
955
956 provider, err := New(
957 WithAPIKey("test-api-key"),
958 WithBaseURL(server.URL),
959 )
960 require.NoError(t, err)
961
962 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
963 require.NoError(t, err)
964
965 _, err = model.Generate(context.Background(), fantasy.Call{
966 Prompt: testPrompt(),
967 Tools: []fantasy.Tool{
968 WebSearchTool(&WebSearchToolOptions{
969 MaxUses: 5,
970 UserLocation: &UserLocation{
971 City: "San Francisco",
972 Country: "US",
973 },
974 }),
975 },
976 })
977 require.NoError(t, err)
978
979 call := awaitAnthropicCall(t, calls)
980 tools, ok := call.body["tools"].([]any)
981 require.True(t, ok)
982 require.Len(t, tools, 1)
983
984 tool, ok := tools[0].(map[string]any)
985 require.True(t, ok)
986 require.Equal(t, "web_search_20250305", tool["type"])
987
988 // max_uses is serialized as a JSON number; json.Unmarshal
989 // into map[string]any decodes numbers as float64.
990 maxUses, ok := tool["max_uses"].(float64)
991 require.True(t, ok, "tool should have max_uses")
992 require.Equal(t, float64(5), maxUses)
993
994 userLoc, ok := tool["user_location"].(map[string]any)
995 require.True(t, ok, "tool should have user_location")
996 require.Equal(t, "San Francisco", userLoc["city"])
997 require.Equal(t, "US", userLoc["country"])
998 require.Equal(t, "approximate", userLoc["type"])
999 })
1000
1001 t.Run("with max uses", func(t *testing.T) {
1002 t.Parallel()
1003
1004 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1005 defer server.Close()
1006
1007 provider, err := New(
1008 WithAPIKey("test-api-key"),
1009 WithBaseURL(server.URL),
1010 )
1011 require.NoError(t, err)
1012
1013 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1014 require.NoError(t, err)
1015
1016 _, err = model.Generate(context.Background(), fantasy.Call{
1017 Prompt: testPrompt(),
1018 Tools: []fantasy.Tool{
1019 WebSearchTool(&WebSearchToolOptions{
1020 MaxUses: 3,
1021 }),
1022 },
1023 })
1024 require.NoError(t, err)
1025
1026 call := awaitAnthropicCall(t, calls)
1027 tools, ok := call.body["tools"].([]any)
1028 require.True(t, ok)
1029 require.Len(t, tools, 1)
1030
1031 tool, ok := tools[0].(map[string]any)
1032 require.True(t, ok)
1033 require.Equal(t, "web_search_20250305", tool["type"])
1034
1035 maxUses, ok := tool["max_uses"].(float64)
1036 require.True(t, ok, "tool should have max_uses")
1037 require.Equal(t, float64(3), maxUses)
1038 })
1039
1040 t.Run("with json-round-tripped provider tool args", func(t *testing.T) {
1041 t.Parallel()
1042
1043 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1044 defer server.Close()
1045
1046 provider, err := New(
1047 WithAPIKey("test-api-key"),
1048 WithBaseURL(server.URL),
1049 )
1050 require.NoError(t, err)
1051
1052 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1053 require.NoError(t, err)
1054
1055 baseTool := WebSearchTool(&WebSearchToolOptions{
1056 MaxUses: 7,
1057 BlockedDomains: []string{"example.com", "test.com"},
1058 UserLocation: &UserLocation{
1059 City: "San Francisco",
1060 Region: "CA",
1061 Country: "US",
1062 Timezone: "America/Los_Angeles",
1063 },
1064 })
1065
1066 data, err := json.Marshal(baseTool)
1067 require.NoError(t, err)
1068
1069 var roundTripped fantasy.ProviderDefinedTool
1070 err = json.Unmarshal(data, &roundTripped)
1071 require.NoError(t, err)
1072
1073 _, err = model.Generate(context.Background(), fantasy.Call{
1074 Prompt: testPrompt(),
1075 Tools: []fantasy.Tool{roundTripped},
1076 })
1077 require.NoError(t, err)
1078
1079 call := awaitAnthropicCall(t, calls)
1080 tools, ok := call.body["tools"].([]any)
1081 require.True(t, ok)
1082 require.Len(t, tools, 1)
1083
1084 tool, ok := tools[0].(map[string]any)
1085 require.True(t, ok)
1086 require.Equal(t, "web_search_20250305", tool["type"])
1087
1088 domains, ok := tool["blocked_domains"].([]any)
1089 require.True(t, ok, "tool should have blocked_domains")
1090 require.Len(t, domains, 2)
1091 require.Equal(t, "example.com", domains[0])
1092 require.Equal(t, "test.com", domains[1])
1093
1094 maxUses, ok := tool["max_uses"].(float64)
1095 require.True(t, ok, "tool should have max_uses")
1096 require.Equal(t, float64(7), maxUses)
1097
1098 userLoc, ok := tool["user_location"].(map[string]any)
1099 require.True(t, ok, "tool should have user_location")
1100 require.Equal(t, "San Francisco", userLoc["city"])
1101 require.Equal(t, "CA", userLoc["region"])
1102 require.Equal(t, "US", userLoc["country"])
1103 require.Equal(t, "America/Los_Angeles", userLoc["timezone"])
1104 require.Equal(t, "approximate", userLoc["type"])
1105 })
1106}
1107
1108func TestAnyToStringSlice(t *testing.T) {
1109 t.Parallel()
1110
1111 t.Run("from string slice", func(t *testing.T) {
1112 t.Parallel()
1113
1114 got := anyToStringSlice([]string{"example.com", ""})
1115 require.Equal(t, []string{"example.com", ""}, got)
1116 })
1117
1118 t.Run("from any slice filters non-strings and empty", func(t *testing.T) {
1119 t.Parallel()
1120
1121 got := anyToStringSlice([]any{"example.com", 123, "", "test.com"})
1122 require.Equal(t, []string{"example.com", "test.com"}, got)
1123 })
1124
1125 t.Run("unsupported type", func(t *testing.T) {
1126 t.Parallel()
1127
1128 got := anyToStringSlice("example.com")
1129 require.Nil(t, got)
1130 })
1131}
1132
1133func TestAnyToInt64(t *testing.T) {
1134 t.Parallel()
1135
1136 tests := []struct {
1137 name string
1138 input any
1139 want int64
1140 wantOK bool
1141 }{
1142 {name: "int64", input: int64(7), want: 7, wantOK: true},
1143 {name: "float64 integer", input: float64(7), want: 7, wantOK: true},
1144 {name: "float32 integer", input: float32(9), want: 9, wantOK: true},
1145 {name: "float64 non-integer", input: float64(7.5), wantOK: false},
1146 {name: "float64 max exact int ok", input: float64(1<<53 - 1), want: 1<<53 - 1, wantOK: true},
1147 {name: "float64 over max exact int", input: float64(1 << 53), wantOK: false},
1148 {name: "json number int", input: json.Number("42"), want: 42, wantOK: true},
1149 {name: "json number float", input: json.Number("4.2"), wantOK: false},
1150 {name: "nan", input: math.NaN(), wantOK: false},
1151 {name: "inf", input: math.Inf(1), wantOK: false},
1152 {name: "uint64 overflow", input: uint64(math.MaxInt64) + 1, wantOK: false},
1153 }
1154
1155 for _, tt := range tests {
1156 t.Run(tt.name, func(t *testing.T) {
1157 got, ok := anyToInt64(tt.input)
1158 require.Equal(t, tt.wantOK, ok)
1159 if tt.wantOK {
1160 require.Equal(t, tt.want, got)
1161 }
1162 })
1163 }
1164}
1165
1166func TestAnyToUserLocation(t *testing.T) {
1167 t.Parallel()
1168
1169 t.Run("pointer passthrough", func(t *testing.T) {
1170 t.Parallel()
1171
1172 input := &UserLocation{City: "San Francisco", Country: "US"}
1173 got := anyToUserLocation(input)
1174 require.Same(t, input, got)
1175 })
1176
1177 t.Run("struct value", func(t *testing.T) {
1178 t.Parallel()
1179
1180 got := anyToUserLocation(UserLocation{City: "San Francisco", Country: "US"})
1181 require.NotNil(t, got)
1182 require.Equal(t, "San Francisco", got.City)
1183 require.Equal(t, "US", got.Country)
1184 })
1185
1186 t.Run("map value", func(t *testing.T) {
1187 t.Parallel()
1188
1189 got := anyToUserLocation(map[string]any{
1190 "city": "San Francisco",
1191 "region": "CA",
1192 "country": "US",
1193 "timezone": "America/Los_Angeles",
1194 "type": "approximate",
1195 })
1196 require.NotNil(t, got)
1197 require.Equal(t, "San Francisco", got.City)
1198 require.Equal(t, "CA", got.Region)
1199 require.Equal(t, "US", got.Country)
1200 require.Equal(t, "America/Los_Angeles", got.Timezone)
1201 })
1202
1203 t.Run("empty map", func(t *testing.T) {
1204 t.Parallel()
1205
1206 got := anyToUserLocation(map[string]any{"type": "approximate"})
1207 require.Nil(t, got)
1208 })
1209
1210 t.Run("unsupported type", func(t *testing.T) {
1211 t.Parallel()
1212
1213 got := anyToUserLocation("San Francisco")
1214 require.Nil(t, got)
1215 })
1216}
1217
1218func TestStream_WebSearchResponse(t *testing.T) {
1219 t.Parallel()
1220
1221 // Build SSE chunks that simulate a web search streaming response.
1222 // The Anthropic SDK accumulates content blocks via
1223 // acc.Accumulate(event). We read the Content and ToolUseID
1224 // directly from struct fields instead of using AsAny(), which
1225 // avoids the SDK's re-marshal limitation that previously dropped
1226 // source data.
1227 webSearchResultContent, _ := json.Marshal([]any{
1228 map[string]any{
1229 "type": "web_search_result",
1230 "url": "https://example.com/ai-news",
1231 "title": "Latest AI News",
1232 "encrypted_content": "encrypted_abc123",
1233 "page_age": "2 hours ago",
1234 },
1235 })
1236
1237 chunks := []string{
1238 // message_start
1239 "event: message_start\n",
1240 `data: {"type":"message_start","message":{"id":"msg_01WebSearch","type":"message","role":"assistant","model":"claude-sonnet-4-20250514","content":[],"stop_reason":null,"usage":{"input_tokens":100,"output_tokens":0}}}` + "\n\n",
1241 // Block 0: server_tool_use
1242 "event: content_block_start\n",
1243 `data: {"type":"content_block_start","index":0,"content_block":{"type":"server_tool_use","id":"srvtoolu_01","name":"web_search","input":{}}}` + "\n\n",
1244 "event: content_block_stop\n",
1245 `data: {"type":"content_block_stop","index":0}` + "\n\n",
1246 // Block 1: web_search_tool_result
1247 "event: content_block_start\n",
1248 `data: {"type":"content_block_start","index":1,"content_block":{"type":"web_search_tool_result","tool_use_id":"srvtoolu_01","content":` + string(webSearchResultContent) + `}}` + "\n\n",
1249 "event: content_block_stop\n",
1250 `data: {"type":"content_block_stop","index":1}` + "\n\n",
1251 // Block 2: text
1252 "event: content_block_start\n",
1253 `data: {"type":"content_block_start","index":2,"content_block":{"type":"text","text":""}}` + "\n\n",
1254 "event: content_block_delta\n",
1255 `data: {"type":"content_block_delta","index":2,"delta":{"type":"text_delta","text":"Here are the results."}}` + "\n\n",
1256 "event: content_block_stop\n",
1257 `data: {"type":"content_block_stop","index":2}` + "\n\n",
1258 // message_stop
1259 "event: message_stop\n",
1260 `data: {"type":"message_stop"}` + "\n\n",
1261 }
1262
1263 server, calls := newAnthropicStreamingServer(chunks)
1264 defer server.Close()
1265
1266 provider, err := New(
1267 WithAPIKey("test-api-key"),
1268 WithBaseURL(server.URL),
1269 )
1270 require.NoError(t, err)
1271
1272 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1273 require.NoError(t, err)
1274
1275 stream, err := model.Stream(context.Background(), fantasy.Call{
1276 Prompt: testPrompt(),
1277 Tools: []fantasy.Tool{
1278 WebSearchTool(nil),
1279 },
1280 })
1281 require.NoError(t, err)
1282
1283 var parts []fantasy.StreamPart
1284 stream(func(part fantasy.StreamPart) bool {
1285 parts = append(parts, part)
1286 return true
1287 })
1288
1289 _ = awaitAnthropicCall(t, calls)
1290
1291 // Collect parts by type for assertions.
1292 var (
1293 toolInputStarts []fantasy.StreamPart
1294 toolCalls []fantasy.StreamPart
1295 toolResults []fantasy.StreamPart
1296 sourceParts []fantasy.StreamPart
1297 textDeltas []fantasy.StreamPart
1298 )
1299 for _, p := range parts {
1300 switch p.Type {
1301 case fantasy.StreamPartTypeToolInputStart:
1302 toolInputStarts = append(toolInputStarts, p)
1303 case fantasy.StreamPartTypeToolCall:
1304 toolCalls = append(toolCalls, p)
1305 case fantasy.StreamPartTypeToolResult:
1306 toolResults = append(toolResults, p)
1307 case fantasy.StreamPartTypeSource:
1308 sourceParts = append(sourceParts, p)
1309 case fantasy.StreamPartTypeTextDelta:
1310 textDeltas = append(textDeltas, p)
1311 }
1312 }
1313
1314 // server_tool_use emits a ToolInputStart with ProviderExecuted.
1315 require.NotEmpty(t, toolInputStarts, "should have a tool input start")
1316 require.True(t, toolInputStarts[0].ProviderExecuted)
1317 require.Equal(t, "web_search", toolInputStarts[0].ToolCallName)
1318
1319 // server_tool_use emits a ToolCall with ProviderExecuted.
1320 require.NotEmpty(t, toolCalls, "should have a tool call")
1321 require.True(t, toolCalls[0].ProviderExecuted)
1322
1323 // web_search_tool_result always emits a ToolResult even when
1324 // the SDK drops source data. The ToolUseID comes from the raw
1325 // union field as a fallback.
1326 require.NotEmpty(t, toolResults, "should have a tool result")
1327 require.True(t, toolResults[0].ProviderExecuted)
1328 require.Equal(t, "web_search", toolResults[0].ToolCallName)
1329 require.Equal(t, "srvtoolu_01", toolResults[0].ID,
1330 "tool result ID should match the tool_use_id")
1331
1332 // Source parts are now correctly emitted by reading struct fields
1333 // directly instead of using AsAny().
1334 require.Len(t, sourceParts, 1)
1335 require.Equal(t, "https://example.com/ai-news", sourceParts[0].URL)
1336 require.Equal(t, "Latest AI News", sourceParts[0].Title)
1337 require.Equal(t, fantasy.SourceTypeURL, sourceParts[0].SourceType)
1338
1339 // Text block emits a text delta.
1340 require.NotEmpty(t, textDeltas, "should have text deltas")
1341 require.Equal(t, "Here are the results.", textDeltas[0].Delta)
1342}
1343
1344func TestGenerate_ToolChoiceNone(t *testing.T) {
1345 t.Parallel()
1346
1347 server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1348 defer server.Close()
1349
1350 provider, err := New(
1351 WithAPIKey("test-api-key"),
1352 WithBaseURL(server.URL),
1353 )
1354 require.NoError(t, err)
1355
1356 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1357 require.NoError(t, err)
1358
1359 toolChoiceNone := fantasy.ToolChoiceNone
1360 _, err = model.Generate(context.Background(), fantasy.Call{
1361 Prompt: testPrompt(),
1362 Tools: []fantasy.Tool{
1363 WebSearchTool(nil),
1364 },
1365 ToolChoice: &toolChoiceNone,
1366 })
1367 require.NoError(t, err)
1368
1369 call := awaitAnthropicCall(t, calls)
1370 toolChoice, ok := call.body["tool_choice"].(map[string]any)
1371 require.True(t, ok, "request body should have tool_choice")
1372 require.Equal(t, "none", toolChoice["type"], "tool_choice should be 'none'")
1373}
1374
1375// --- Computer Use Tests ---
1376
1377// jsonRoundTripTool simulates a JSON round-trip on a
1378// ProviderDefinedTool so that its Args map contains float64
1379// values (as json.Unmarshal produces) rather than the int64
1380// values that NewComputerUseTool stores directly. The
1381// production toBetaTools code asserts float64.
1382func jsonRoundTripTool(t *testing.T, tool fantasy.ExecutableProviderTool) fantasy.ProviderDefinedTool {
1383 t.Helper()
1384 pdt := tool.Definition()
1385 data, err := json.Marshal(pdt.Args)
1386 require.NoError(t, err)
1387 var args map[string]any
1388 require.NoError(t, json.Unmarshal(data, &args))
1389 pdt.Args = args
1390 return pdt
1391}
1392
1393func TestNewComputerUseTool(t *testing.T) {
1394 t.Parallel()
1395
1396 t.Run("creates tool with correct ID and name", func(t *testing.T) {
1397 t.Parallel()
1398 tool := NewComputerUseTool(ComputerUseToolOptions{
1399 DisplayWidthPx: 1920,
1400 DisplayHeightPx: 1080,
1401 ToolVersion: ComputerUse20250124,
1402 }, noopComputerRun).Definition()
1403 require.Equal(t, "anthropic.computer", tool.ID)
1404 require.Equal(t, "computer", tool.Name)
1405 require.Equal(t, int64(1920), tool.Args["display_width_px"])
1406 require.Equal(t, int64(1080), tool.Args["display_height_px"])
1407 require.Equal(t, string(ComputerUse20250124), tool.Args["tool_version"])
1408 })
1409
1410 t.Run("includes optional fields when set", func(t *testing.T) {
1411 t.Parallel()
1412 displayNum := int64(1)
1413 enableZoom := true
1414 tool := NewComputerUseTool(ComputerUseToolOptions{
1415 DisplayWidthPx: 1024,
1416 DisplayHeightPx: 768,
1417 DisplayNumber: &displayNum,
1418 EnableZoom: &enableZoom,
1419 ToolVersion: ComputerUse20251124,
1420 CacheControl: &CacheControl{Type: "ephemeral"},
1421 }, noopComputerRun).Definition()
1422 require.Equal(t, int64(1), tool.Args["display_number"])
1423 require.Equal(t, true, tool.Args["enable_zoom"])
1424 require.NotNil(t, tool.Args["cache_control"])
1425 })
1426
1427 t.Run("omits optional fields when nil", func(t *testing.T) {
1428 t.Parallel()
1429 tool := NewComputerUseTool(ComputerUseToolOptions{
1430 DisplayWidthPx: 1920,
1431 DisplayHeightPx: 1080,
1432 ToolVersion: ComputerUse20250124,
1433 }, noopComputerRun).Definition()
1434 _, hasDisplayNum := tool.Args["display_number"]
1435 _, hasEnableZoom := tool.Args["enable_zoom"]
1436 _, hasCacheControl := tool.Args["cache_control"]
1437 require.False(t, hasDisplayNum)
1438 require.False(t, hasEnableZoom)
1439 require.False(t, hasCacheControl)
1440 })
1441}
1442
1443func TestIsComputerUseTool(t *testing.T) {
1444 t.Parallel()
1445
1446 t.Run("returns true for computer use tool", func(t *testing.T) {
1447 t.Parallel()
1448 tool := NewComputerUseTool(ComputerUseToolOptions{
1449 DisplayWidthPx: 1920,
1450 DisplayHeightPx: 1080,
1451 ToolVersion: ComputerUse20250124,
1452 }, noopComputerRun)
1453 require.True(t, IsComputerUseTool(tool.Definition()))
1454 })
1455
1456 t.Run("returns false for function tool", func(t *testing.T) {
1457 t.Parallel()
1458 tool := fantasy.FunctionTool{
1459 Name: "test",
1460 Description: "test tool",
1461 }
1462 require.False(t, IsComputerUseTool(tool))
1463 })
1464
1465 t.Run("returns false for other provider defined tool", func(t *testing.T) {
1466 t.Parallel()
1467 tool := fantasy.ProviderDefinedTool{
1468 ID: "other.tool",
1469 Name: "other",
1470 }
1471 require.False(t, IsComputerUseTool(tool))
1472 })
1473}
1474
1475func TestNeedsBetaAPI(t *testing.T) {
1476 t.Parallel()
1477
1478 lm := languageModel{options: options{}}
1479
1480 t.Run("returns false for empty tools", func(t *testing.T) {
1481 t.Parallel()
1482 _, _, _, betaFlags := lm.toTools(nil, nil, false)
1483 require.Empty(t, betaFlags)
1484 _, _, _, betaFlags = lm.toTools([]fantasy.Tool{}, nil, false)
1485 require.Empty(t, betaFlags)
1486 })
1487
1488 t.Run("returns false for only function tools", func(t *testing.T) {
1489 t.Parallel()
1490 tools := []fantasy.Tool{
1491 fantasy.FunctionTool{Name: "test"},
1492 }
1493 _, _, _, betaFlags := lm.toTools(tools, nil, false)
1494 require.Empty(t, betaFlags)
1495 })
1496
1497 t.Run("returns beta flags when computer use tool present", func(t *testing.T) {
1498 t.Parallel()
1499 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1500 DisplayWidthPx: 1920,
1501 DisplayHeightPx: 1080,
1502 ToolVersion: ComputerUse20250124,
1503 }, noopComputerRun))
1504 tools := []fantasy.Tool{
1505 fantasy.FunctionTool{Name: "test"},
1506 cuTool,
1507 }
1508 _, _, _, betaFlags := lm.toTools(tools, nil, false)
1509 require.NotEmpty(t, betaFlags)
1510 })
1511}
1512
1513func TestComputerUseToolJSON(t *testing.T) {
1514 t.Parallel()
1515
1516 t.Run("builds JSON for version 20250124", func(t *testing.T) {
1517 t.Parallel()
1518 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1519 DisplayWidthPx: 1920,
1520 DisplayHeightPx: 1080,
1521 ToolVersion: ComputerUse20250124,
1522 }, noopComputerRun))
1523 data, err := computerUseToolJSON(cuTool)
1524 require.NoError(t, err)
1525 var m map[string]any
1526 require.NoError(t, json.Unmarshal(data, &m))
1527 require.Equal(t, "computer_20250124", m["type"])
1528 require.Equal(t, "computer", m["name"])
1529 require.InDelta(t, 1920, m["display_width_px"], 0)
1530 require.InDelta(t, 1080, m["display_height_px"], 0)
1531 })
1532
1533 t.Run("builds JSON for version 20251124 with enable_zoom", func(t *testing.T) {
1534 t.Parallel()
1535 enableZoom := true
1536 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1537 DisplayWidthPx: 1024,
1538 DisplayHeightPx: 768,
1539 EnableZoom: &enableZoom,
1540 ToolVersion: ComputerUse20251124,
1541 }, noopComputerRun))
1542 data, err := computerUseToolJSON(cuTool)
1543 require.NoError(t, err)
1544 var m map[string]any
1545 require.NoError(t, json.Unmarshal(data, &m))
1546 require.Equal(t, "computer_20251124", m["type"])
1547 require.Equal(t, true, m["enable_zoom"])
1548 })
1549
1550 t.Run("handles int64 args without JSON round-trip", func(t *testing.T) {
1551 t.Parallel()
1552 // Direct construction stores int64 values.
1553 cuTool := NewComputerUseTool(ComputerUseToolOptions{
1554 DisplayWidthPx: 1920,
1555 DisplayHeightPx: 1080,
1556 ToolVersion: ComputerUse20250124,
1557 }, noopComputerRun)
1558 data, err := computerUseToolJSON(cuTool.Definition())
1559 require.NoError(t, err)
1560 var m map[string]any
1561 require.NoError(t, json.Unmarshal(data, &m))
1562 require.InDelta(t, 1920, m["display_width_px"], 0)
1563 })
1564
1565 t.Run("returns error when version is missing", func(t *testing.T) {
1566 t.Parallel()
1567 pdt := fantasy.ProviderDefinedTool{
1568 ID: "anthropic.computer",
1569 Name: "computer",
1570 Args: map[string]any{
1571 "display_width_px": float64(1920),
1572 "display_height_px": float64(1080),
1573 },
1574 }
1575 _, err := computerUseToolJSON(pdt)
1576 require.Error(t, err)
1577 require.Contains(t, err.Error(), "tool_version arg is missing") })
1578
1579 t.Run("returns error for unsupported version", func(t *testing.T) {
1580 t.Parallel()
1581 pdt := fantasy.ProviderDefinedTool{
1582 ID: "anthropic.computer",
1583 Name: "computer",
1584 Args: map[string]any{
1585 "display_width_px": float64(1920),
1586 "display_height_px": float64(1080),
1587 "tool_version": "computer_99991231",
1588 },
1589 }
1590 _, err := computerUseToolJSON(pdt)
1591 require.Error(t, err)
1592 require.Contains(t, err.Error(), "unsupported")
1593 })
1594}
1595
1596func TestParseComputerUseInput_CoordinateValidation(t *testing.T) {
1597 t.Parallel()
1598
1599 t.Run("rejects coordinate with 1 element", func(t *testing.T) {
1600 t.Parallel()
1601 _, err := ParseComputerUseInput(`{"action":"left_click","coordinate":[100]}`)
1602 require.Error(t, err)
1603 require.Contains(t, err.Error(), "coordinate")
1604 })
1605
1606 t.Run("rejects coordinate with 3 elements", func(t *testing.T) {
1607 t.Parallel()
1608 _, err := ParseComputerUseInput(`{"action":"left_click","coordinate":[100,200,300]}`)
1609 require.Error(t, err)
1610 require.Contains(t, err.Error(), "coordinate")
1611 })
1612
1613 t.Run("rejects start_coordinate with 1 element", func(t *testing.T) {
1614 t.Parallel()
1615 _, err := ParseComputerUseInput(`{"action":"left_click_drag","coordinate":[100,200],"start_coordinate":[50]}`)
1616 require.Error(t, err)
1617 require.Contains(t, err.Error(), "start_coordinate")
1618 })
1619
1620 t.Run("rejects region with 3 elements", func(t *testing.T) {
1621 t.Parallel()
1622 _, err := ParseComputerUseInput(`{"action":"zoom","region":[10,20,30]}`)
1623 require.Error(t, err)
1624 require.Contains(t, err.Error(), "region")
1625 })
1626
1627 t.Run("accepts valid coordinate", func(t *testing.T) {
1628 t.Parallel()
1629 result, err := ParseComputerUseInput(`{"action":"left_click","coordinate":[100,200]}`)
1630 require.NoError(t, err)
1631 require.Equal(t, [2]int64{100, 200}, result.Coordinate)
1632 })
1633
1634 t.Run("accepts absent optional arrays", func(t *testing.T) {
1635 t.Parallel()
1636 result, err := ParseComputerUseInput(`{"action":"screenshot"}`)
1637 require.NoError(t, err)
1638 require.Equal(t, ActionScreenshot, result.Action)
1639 })
1640}
1641
1642func TestToTools_RawJSON(t *testing.T) {
1643 t.Parallel()
1644
1645 lm := languageModel{options: options{}}
1646
1647 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1648 DisplayWidthPx: 1920,
1649 DisplayHeightPx: 1080,
1650 ToolVersion: ComputerUse20250124,
1651 }, noopComputerRun))
1652
1653 tools := []fantasy.Tool{
1654 fantasy.FunctionTool{
1655 Name: "weather",
1656 Description: "Get weather",
1657 InputSchema: map[string]any{
1658 "properties": map[string]any{
1659 "location": map[string]any{"type": "string"},
1660 },
1661 "required": []string{"location"},
1662 },
1663 },
1664 WebSearchTool(nil),
1665 cuTool,
1666 }
1667
1668 rawTools, toolChoice, warnings, betaFlags := lm.toTools(tools, nil, false)
1669
1670 require.Len(t, rawTools, 3)
1671 require.Nil(t, toolChoice)
1672 require.Empty(t, warnings)
1673 require.NotEmpty(t, betaFlags)
1674
1675 // Verify each raw tool is valid JSON.
1676 for i, raw := range rawTools {
1677 var m map[string]any
1678 require.NoError(t, json.Unmarshal(raw, &m), "tool %d should be valid JSON", i)
1679 }
1680
1681 // Check function tool.
1682 var funcTool map[string]any
1683 require.NoError(t, json.Unmarshal(rawTools[0], &funcTool))
1684 require.Equal(t, "weather", funcTool["name"])
1685
1686 // Check web search tool.
1687 var webTool map[string]any
1688 require.NoError(t, json.Unmarshal(rawTools[1], &webTool))
1689 require.Equal(t, "web_search_20250305", webTool["type"])
1690
1691 // Check computer use tool.
1692 var cuToolJSON map[string]any
1693 require.NoError(t, json.Unmarshal(rawTools[2], &cuToolJSON))
1694 require.Equal(t, "computer_20250124", cuToolJSON["type"])
1695 require.Equal(t, "computer", cuToolJSON["name"])
1696}
1697
1698func TestGenerate_BetaAPI(t *testing.T) {
1699 t.Parallel()
1700
1701 t.Run("sends beta header for computer use", func(t *testing.T) {
1702 t.Parallel()
1703
1704 var capturedHeaders http.Header
1705 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1706 capturedHeaders = r.Header.Clone()
1707 w.Header().Set("Content-Type", "application/json")
1708 _ = json.NewEncoder(w).Encode(mockAnthropicGenerateResponse())
1709 }))
1710 defer server.Close()
1711
1712 provider, err := New(
1713 WithAPIKey("test-api-key"),
1714 WithBaseURL(server.URL),
1715 )
1716 require.NoError(t, err)
1717
1718 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1719 require.NoError(t, err)
1720
1721 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1722 DisplayWidthPx: 1920,
1723 DisplayHeightPx: 1080,
1724 ToolVersion: ComputerUse20250124,
1725 }, noopComputerRun))
1726
1727 _, err = model.Generate(context.Background(), fantasy.Call{
1728 Prompt: testPrompt(),
1729 Tools: []fantasy.Tool{cuTool},
1730 })
1731 require.NoError(t, err)
1732 require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-01-24")
1733 })
1734
1735 t.Run("sends beta header for computer use 20251124", func(t *testing.T) {
1736 t.Parallel()
1737
1738 var capturedHeaders http.Header
1739 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1740 capturedHeaders = r.Header.Clone()
1741 w.Header().Set("Content-Type", "application/json")
1742 _ = json.NewEncoder(w).Encode(mockAnthropicGenerateResponse())
1743 }))
1744 defer server.Close()
1745
1746 provider, err := New(
1747 WithAPIKey("test-api-key"),
1748 WithBaseURL(server.URL),
1749 )
1750 require.NoError(t, err)
1751
1752 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1753 require.NoError(t, err)
1754
1755 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1756 DisplayWidthPx: 1920,
1757 DisplayHeightPx: 1080,
1758 ToolVersion: ComputerUse20251124,
1759 }, noopComputerRun))
1760
1761 _, err = model.Generate(context.Background(), fantasy.Call{
1762 Prompt: testPrompt(),
1763 Tools: []fantasy.Tool{cuTool},
1764 })
1765 require.NoError(t, err)
1766 require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-11-24")
1767 })
1768
1769 t.Run("returns tool use from beta response", func(t *testing.T) {
1770 t.Parallel()
1771
1772 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1773 w.Header().Set("Content-Type", "application/json")
1774 _ = json.NewEncoder(w).Encode(map[string]any{
1775 "id": "msg_01Test",
1776 "type": "message",
1777 "role": "assistant",
1778 "model": "claude-sonnet-4-20250514",
1779 "content": []any{
1780 map[string]any{
1781 "type": "tool_use",
1782 "id": "toolu_01",
1783 "name": "computer",
1784 "input": map[string]any{"action": "screenshot"},
1785 },
1786 },
1787 "stop_reason": "tool_use",
1788 "usage": map[string]any{
1789 "input_tokens": 10,
1790 "output_tokens": 5,
1791 "cache_creation": map[string]any{
1792 "ephemeral_1h_input_tokens": 0,
1793 "ephemeral_5m_input_tokens": 0,
1794 },
1795 "cache_creation_input_tokens": 0,
1796 "cache_read_input_tokens": 0,
1797 "server_tool_use": map[string]any{
1798 "web_search_requests": 0,
1799 },
1800 "service_tier": "standard",
1801 },
1802 })
1803 }))
1804 defer server.Close()
1805
1806 provider, err := New(
1807 WithAPIKey("test-api-key"),
1808 WithBaseURL(server.URL),
1809 )
1810 require.NoError(t, err)
1811
1812 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1813 require.NoError(t, err)
1814
1815 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1816 DisplayWidthPx: 1920,
1817 DisplayHeightPx: 1080,
1818 ToolVersion: ComputerUse20250124,
1819 }, noopComputerRun))
1820
1821 resp, err := model.Generate(context.Background(), fantasy.Call{
1822 Prompt: testPrompt(),
1823 Tools: []fantasy.Tool{cuTool},
1824 })
1825 require.NoError(t, err)
1826
1827 toolCalls := resp.Content.ToolCalls()
1828 require.Len(t, toolCalls, 1)
1829 require.Equal(t, "computer", toolCalls[0].ToolName)
1830 require.Equal(t, "toolu_01", toolCalls[0].ToolCallID)
1831 require.Contains(t, toolCalls[0].Input, "screenshot")
1832 require.Equal(t, fantasy.FinishReasonToolCalls, resp.FinishReason)
1833
1834 // Verify typed parsing works on the tool call input.
1835 parsed, err := ParseComputerUseInput(toolCalls[0].Input)
1836 require.NoError(t, err)
1837 require.Equal(t, ActionScreenshot, parsed.Action)
1838 })
1839}
1840
1841func TestStream_BetaAPI(t *testing.T) {
1842 t.Parallel()
1843
1844 t.Run("streams via beta API for computer use", func(t *testing.T) {
1845 t.Parallel()
1846
1847 var capturedHeaders http.Header
1848 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1849 capturedHeaders = r.Header.Clone()
1850 w.Header().Set("Content-Type", "text/event-stream")
1851 w.Header().Set("Cache-Control", "no-cache")
1852 w.WriteHeader(http.StatusOK)
1853 chunks := []string{
1854 "event: message_start\n",
1855 "data: {\"type\":\"message_start\",\"message\":{}}\n\n",
1856 "event: message_stop\n",
1857 "data: {\"type\":\"message_stop\"}\n\n",
1858 }
1859 for _, chunk := range chunks {
1860 _, _ = fmt.Fprint(w, chunk)
1861 if flusher, ok := w.(http.Flusher); ok {
1862 flusher.Flush()
1863 }
1864 }
1865 }))
1866 defer server.Close()
1867
1868 provider, err := New(
1869 WithAPIKey("test-api-key"),
1870 WithBaseURL(server.URL),
1871 )
1872 require.NoError(t, err)
1873
1874 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1875 require.NoError(t, err)
1876
1877 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1878 DisplayWidthPx: 1920,
1879 DisplayHeightPx: 1080,
1880 ToolVersion: ComputerUse20250124,
1881 }, noopComputerRun))
1882
1883 stream, err := model.Stream(context.Background(), fantasy.Call{
1884 Prompt: testPrompt(),
1885 Tools: []fantasy.Tool{cuTool},
1886 })
1887 require.NoError(t, err)
1888
1889 stream(func(fantasy.StreamPart) bool { return true })
1890
1891 require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-01-24")
1892 })
1893
1894 t.Run("streams via beta API for computer use 20251124", func(t *testing.T) {
1895 t.Parallel()
1896
1897 var capturedHeaders http.Header
1898 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1899 capturedHeaders = r.Header.Clone()
1900 w.Header().Set("Content-Type", "text/event-stream")
1901 w.Header().Set("Cache-Control", "no-cache")
1902 w.WriteHeader(http.StatusOK)
1903 chunks := []string{
1904 "event: message_start\n",
1905 "data: {\"type\":\"message_start\",\"message\":{}}\n\n",
1906 "event: message_stop\n",
1907 "data: {\"type\":\"message_stop\"}\n\n",
1908 }
1909 for _, chunk := range chunks {
1910 _, _ = fmt.Fprint(w, chunk)
1911 if flusher, ok := w.(http.Flusher); ok {
1912 flusher.Flush()
1913 }
1914 }
1915 }))
1916 defer server.Close()
1917
1918 provider, err := New(
1919 WithAPIKey("test-api-key"),
1920 WithBaseURL(server.URL),
1921 )
1922 require.NoError(t, err)
1923
1924 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1925 require.NoError(t, err)
1926
1927 cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1928 DisplayWidthPx: 1920,
1929 DisplayHeightPx: 1080,
1930 ToolVersion: ComputerUse20251124,
1931 }, noopComputerRun))
1932
1933 stream, err := model.Stream(context.Background(), fantasy.Call{
1934 Prompt: testPrompt(),
1935 Tools: []fantasy.Tool{cuTool},
1936 })
1937 require.NoError(t, err)
1938
1939 stream(func(fantasy.StreamPart) bool { return true })
1940
1941 require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-11-24")
1942 })
1943}
1944
1945// TestGenerate_ComputerUseTool runs a multi-turn computer use session
1946// via model.Generate, passing the ExecutableProviderTool directly into
1947// Call.Tools (no .Definition(), no jsonRoundTripTool). The mock server
1948// walks through a scripted sequence of actions — screenshot, click,
1949// type, key, scroll — then finishes with a text reply. Each turn the
1950// test parses the tool call, builds a screenshot result, and appends
1951// both to the prompt for the next request.
1952func TestGenerate_ComputerUseTool(t *testing.T) {
1953 t.Parallel()
1954
1955 type actionStep struct {
1956 input map[string]any
1957 want ComputerUseInput
1958 }
1959 steps := []actionStep{
1960 {
1961 input: map[string]any{"action": "screenshot"},
1962 want: ComputerUseInput{Action: ActionScreenshot},
1963 },
1964 {
1965 input: map[string]any{"action": "left_click", "coordinate": []any{100, 200}},
1966 want: ComputerUseInput{Action: ActionLeftClick, Coordinate: [2]int64{100, 200}},
1967 },
1968 {
1969 input: map[string]any{"action": "type", "text": "hello world"},
1970 want: ComputerUseInput{Action: ActionType, Text: "hello world"},
1971 },
1972 {
1973 input: map[string]any{"action": "key", "text": "Return"},
1974 want: ComputerUseInput{Action: ActionKey, Text: "Return"},
1975 },
1976 {
1977 input: map[string]any{
1978 "action": "scroll",
1979 "coordinate": []any{500, 300},
1980 "scroll_direction": "down",
1981 "scroll_amount": 3,
1982 },
1983 want: ComputerUseInput{
1984 Action: ActionScroll,
1985 Coordinate: [2]int64{500, 300},
1986 ScrollDirection: "down",
1987 ScrollAmount: 3,
1988 },
1989 },
1990 {
1991 input: map[string]any{"action": "screenshot"},
1992 want: ComputerUseInput{Action: ActionScreenshot},
1993 },
1994 }
1995
1996 var (
1997 requestIdx int
1998 betaHeaders []string
1999 )
2000 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2001 betaHeaders = append(betaHeaders, r.Header.Get("Anthropic-Beta"))
2002 idx := requestIdx
2003 requestIdx++
2004
2005 w.Header().Set("Content-Type", "application/json")
2006 if idx < len(steps) {
2007 _ = json.NewEncoder(w).Encode(map[string]any{
2008 "id": fmt.Sprintf("msg_%02d", idx),
2009 "type": "message",
2010 "role": "assistant",
2011 "model": "claude-sonnet-4-20250514",
2012 "content": []any{map[string]any{
2013 "type": "tool_use",
2014 "id": fmt.Sprintf("toolu_%02d", idx),
2015 "name": "computer",
2016 "input": steps[idx].input,
2017 }},
2018 "stop_reason": "tool_use",
2019 "usage": map[string]any{"input_tokens": 10, "output_tokens": 5},
2020 })
2021 return
2022 }
2023 _ = json.NewEncoder(w).Encode(map[string]any{
2024 "id": "msg_final",
2025 "type": "message",
2026 "role": "assistant",
2027 "model": "claude-sonnet-4-20250514",
2028 "content": []any{map[string]any{
2029 "type": "text",
2030 "text": "Done! I have completed all the requested actions.",
2031 }},
2032 "stop_reason": "end_turn",
2033 "usage": map[string]any{"input_tokens": 10, "output_tokens": 15},
2034 })
2035 }))
2036 defer server.Close()
2037
2038 provider, err := New(WithAPIKey("test-api-key"), WithBaseURL(server.URL))
2039 require.NoError(t, err)
2040
2041 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
2042 require.NoError(t, err)
2043
2044 // Pass the ExecutableProviderTool directly — the whole point is
2045 // to verify that the Tool interface works without unwrapping.
2046 cuTool := NewComputerUseTool(ComputerUseToolOptions{
2047 DisplayWidthPx: 1920,
2048 DisplayHeightPx: 1080,
2049 ToolVersion: ComputerUse20250124,
2050 }, noopComputerRun)
2051
2052 var got []ComputerUseInput
2053 prompt := testPrompt()
2054 fakePNG := []byte("fake-screenshot-png")
2055
2056 for turn := 0; turn <= len(steps); turn++ {
2057 resp, err := model.Generate(context.Background(), fantasy.Call{
2058 Prompt: prompt,
2059 Tools: []fantasy.Tool{cuTool},
2060 })
2061 require.NoError(t, err, "turn %d", turn)
2062
2063 if resp.FinishReason != fantasy.FinishReasonToolCalls {
2064 require.Equal(t, fantasy.FinishReasonStop, resp.FinishReason)
2065 require.Contains(t, resp.Content.Text(), "Done")
2066 break
2067 }
2068
2069 toolCalls := resp.Content.ToolCalls()
2070 require.Len(t, toolCalls, 1, "turn %d", turn)
2071 require.Equal(t, "computer", toolCalls[0].ToolName, "turn %d", turn)
2072
2073 parsed, err := ParseComputerUseInput(toolCalls[0].Input)
2074 require.NoError(t, err, "turn %d", turn)
2075 got = append(got, parsed)
2076
2077 // Build the next prompt: append the assistant tool-call turn
2078 // and the user screenshot-result turn.
2079 prompt = append(prompt,
2080 fantasy.Message{
2081 Role: fantasy.MessageRoleAssistant,
2082 Content: []fantasy.MessagePart{
2083 fantasy.ToolCallPart{
2084 ToolCallID: toolCalls[0].ToolCallID,
2085 ToolName: toolCalls[0].ToolName,
2086 Input: toolCalls[0].Input,
2087 },
2088 },
2089 },
2090 fantasy.Message{
2091 // Use MessageRoleTool for tool results — this matches
2092 // what the agent loop produces.
2093 Role: fantasy.MessageRoleTool,
2094 Content: []fantasy.MessagePart{
2095 NewComputerUseScreenshotResult(toolCalls[0].ToolCallID, fakePNG),
2096 },
2097 },
2098 )
2099 }
2100
2101 // Every scripted action was received and parsed correctly.
2102 require.Len(t, got, len(steps))
2103 for i, step := range steps {
2104 require.Equal(t, step.want.Action, got[i].Action, "step %d", i)
2105 require.Equal(t, step.want.Coordinate, got[i].Coordinate, "step %d", i)
2106 require.Equal(t, step.want.Text, got[i].Text, "step %d", i)
2107 require.Equal(t, step.want.ScrollDirection, got[i].ScrollDirection, "step %d", i)
2108 require.Equal(t, step.want.ScrollAmount, got[i].ScrollAmount, "step %d", i)
2109 }
2110
2111 // Beta header was sent on every request.
2112 require.Len(t, betaHeaders, len(steps)+1)
2113 for i, h := range betaHeaders {
2114 require.Contains(t, h, "computer-use-2025-01-24", "request %d", i)
2115 }
2116}
2117
2118// TestStream_ComputerUseTool runs a multi-turn computer use session
2119// via model.Stream, verifying that the ExecutableProviderTool works
2120// through the streaming path end-to-end.
2121func TestStream_ComputerUseTool(t *testing.T) {
2122 t.Parallel()
2123
2124 type streamStep struct {
2125 input map[string]any
2126 wantAction ComputerAction
2127 }
2128 steps := []streamStep{
2129 {input: map[string]any{"action": "screenshot"}, wantAction: ActionScreenshot},
2130 {input: map[string]any{"action": "left_click", "coordinate": []any{150, 250}}, wantAction: ActionLeftClick},
2131 {input: map[string]any{"action": "type", "text": "search query"}, wantAction: ActionType},
2132 }
2133
2134 var (
2135 requestIdx int
2136 betaHeaders []string
2137 )
2138
2139 // streamToolUseChunks returns SSE chunks for a single
2140 // computer-use tool_use content block.
2141 streamToolUseChunks := func(id string, input map[string]any) []string {
2142 inputJSON, _ := json.Marshal(input)
2143 escaped := strings.ReplaceAll(string(inputJSON), `"`, `\"`)
2144 return []string{
2145 "event: message_start\n",
2146 `data: {"type":"message_start","message":{"id":"` + id + `","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4-20250514","stop_reason":null,"usage":{"input_tokens":10,"output_tokens":0}}}` + "\n\n",
2147 "event: content_block_start\n",
2148 `data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"` + id + `","name":"computer","input":{}}}` + "\n\n",
2149 "event: content_block_delta\n",
2150 `data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"` + escaped + `"}}` + "\n\n",
2151 "event: content_block_stop\n",
2152 `data: {"type":"content_block_stop","index":0}` + "\n\n",
2153 "event: message_delta\n",
2154 `data: {"type":"message_delta","delta":{"stop_reason":"tool_use"},"usage":{"output_tokens":5}}` + "\n\n",
2155 "event: message_stop\n",
2156 `data: {"type":"message_stop"}` + "\n\n",
2157 }
2158 }
2159
2160 streamTextChunks := func() []string {
2161 return []string{
2162 "event: message_start\n",
2163 `data: {"type":"message_start","message":{"id":"msg_final","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4-20250514","stop_reason":null,"usage":{"input_tokens":10,"output_tokens":0}}}` + "\n\n",
2164 "event: content_block_start\n",
2165 `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}` + "\n\n",
2166 "event: content_block_delta\n",
2167 `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"All done."}}` + "\n\n",
2168 "event: content_block_stop\n",
2169 `data: {"type":"content_block_stop","index":0}` + "\n\n",
2170 "event: message_delta\n",
2171 `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":10}}` + "\n\n",
2172 "event: message_stop\n",
2173 `data: {"type":"message_stop"}` + "\n\n",
2174 }
2175 }
2176
2177 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2178 betaHeaders = append(betaHeaders, r.Header.Get("Anthropic-Beta"))
2179 idx := requestIdx
2180 requestIdx++
2181
2182 w.Header().Set("Content-Type", "text/event-stream")
2183 w.Header().Set("Cache-Control", "no-cache")
2184 w.WriteHeader(http.StatusOK)
2185
2186 var chunks []string
2187 if idx < len(steps) {
2188 chunks = streamToolUseChunks(
2189 fmt.Sprintf("toolu_%02d", idx),
2190 steps[idx].input,
2191 )
2192 } else {
2193 chunks = streamTextChunks()
2194 }
2195 for _, chunk := range chunks {
2196 _, _ = fmt.Fprint(w, chunk)
2197 if f, ok := w.(http.Flusher); ok {
2198 f.Flush()
2199 }
2200 }
2201 }))
2202 defer server.Close()
2203
2204 provider, err := New(WithAPIKey("test-api-key"), WithBaseURL(server.URL))
2205 require.NoError(t, err)
2206
2207 model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
2208 require.NoError(t, err)
2209
2210 cuTool := NewComputerUseTool(ComputerUseToolOptions{
2211 DisplayWidthPx: 1920,
2212 DisplayHeightPx: 1080,
2213 ToolVersion: ComputerUse20250124,
2214 }, noopComputerRun)
2215
2216 var gotActions []ComputerAction
2217 prompt := testPrompt()
2218 fakePNG := []byte("fake-screenshot-png")
2219
2220 for turn := 0; turn <= len(steps); turn++ {
2221 stream, err := model.Stream(context.Background(), fantasy.Call{
2222 Prompt: prompt,
2223 Tools: []fantasy.Tool{cuTool},
2224 })
2225 require.NoError(t, err, "turn %d", turn)
2226
2227 var (
2228 toolCallName string
2229 toolCallID string
2230 toolCallInput string
2231 finishReason fantasy.FinishReason
2232 gotText string
2233 )
2234 stream(func(part fantasy.StreamPart) bool {
2235 switch part.Type {
2236 case fantasy.StreamPartTypeToolCall:
2237 toolCallName = part.ToolCallName
2238 toolCallID = part.ID
2239 toolCallInput = part.ToolCallInput
2240 case fantasy.StreamPartTypeFinish:
2241 finishReason = part.FinishReason
2242 case fantasy.StreamPartTypeTextDelta:
2243 gotText += part.Delta
2244 }
2245 return true
2246 })
2247
2248 if finishReason != fantasy.FinishReasonToolCalls {
2249 require.Contains(t, gotText, "All done")
2250 break
2251 }
2252
2253 require.Equal(t, "computer", toolCallName, "turn %d", turn)
2254
2255 parsed, err := ParseComputerUseInput(toolCallInput)
2256 require.NoError(t, err, "turn %d", turn)
2257 gotActions = append(gotActions, parsed.Action)
2258
2259 prompt = append(prompt,
2260 fantasy.Message{
2261 Role: fantasy.MessageRoleAssistant,
2262 Content: []fantasy.MessagePart{
2263 fantasy.ToolCallPart{
2264 ToolCallID: toolCallID,
2265 ToolName: toolCallName,
2266 Input: toolCallInput,
2267 },
2268 },
2269 },
2270 fantasy.Message{
2271 // Use MessageRoleTool for tool results — this matches
2272 // what the agent loop produces.
2273 Role: fantasy.MessageRoleTool,
2274 Content: []fantasy.MessagePart{
2275 NewComputerUseScreenshotResult(toolCallID, fakePNG),
2276 },
2277 },
2278 )
2279 }
2280
2281 require.Len(t, gotActions, len(steps))
2282 for i, step := range steps {
2283 require.Equal(t, step.wantAction, gotActions[i], "step %d", i)
2284 }
2285
2286 require.Len(t, betaHeaders, len(steps)+1)
2287 for i, h := range betaHeaders {
2288 require.Contains(t, h, "computer-use-2025-01-24", "request %d", i)
2289 }
2290}