anthropic_test.go

   1package anthropic
   2
   3import (
   4	"context"
   5	"encoding/json"
   6	"errors"
   7	"fmt"
   8	"math"
   9	"net/http"
  10	"net/http/httptest"
  11	"strings"
  12	"testing"
  13	"time"
  14
  15	"charm.land/fantasy"
  16	"github.com/charmbracelet/anthropic-sdk-go"
  17	"github.com/stretchr/testify/require"
  18)
  19
  20// noopComputerRun is a no-op run function for tests that only need
  21// to inspect the tool definition, not execute it.
  22var noopComputerRun = func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
  23	return fantasy.ToolResponse{}, nil
  24}
  25
  26func TestToPrompt_DropsEmptyMessages(t *testing.T) {
  27	t.Parallel()
  28
  29	t.Run("should drop assistant messages with only reasoning content", func(t *testing.T) {
  30		t.Parallel()
  31
  32		prompt := fantasy.Prompt{
  33			{
  34				Role: fantasy.MessageRoleUser,
  35				Content: []fantasy.MessagePart{
  36					fantasy.TextPart{Text: "Hello"},
  37				},
  38			},
  39			{
  40				Role: fantasy.MessageRoleAssistant,
  41				Content: []fantasy.MessagePart{
  42					fantasy.ReasoningPart{
  43						Text: "Let me think about this...",
  44						ProviderOptions: fantasy.ProviderOptions{
  45							Name: &ReasoningOptionMetadata{
  46								Signature: "abc123",
  47							},
  48						},
  49					},
  50				},
  51			},
  52		}
  53
  54		systemBlocks, messages, warnings := toPrompt(prompt, true)
  55
  56		require.Empty(t, systemBlocks)
  57		require.Len(t, messages, 1, "should only have user message, assistant message should be dropped")
  58		require.Len(t, warnings, 1)
  59		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
  60		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
  61		require.Contains(t, warnings[0].Message, "neither user-facing content nor tool calls")
  62	})
  63
  64	t.Run("should drop assistant reasoning when sendReasoning disabled", func(t *testing.T) {
  65		t.Parallel()
  66
  67		prompt := fantasy.Prompt{
  68			{
  69				Role: fantasy.MessageRoleUser,
  70				Content: []fantasy.MessagePart{
  71					fantasy.TextPart{Text: "Hello"},
  72				},
  73			},
  74			{
  75				Role: fantasy.MessageRoleAssistant,
  76				Content: []fantasy.MessagePart{
  77					fantasy.ReasoningPart{
  78						Text: "Let me think about this...",
  79						ProviderOptions: fantasy.ProviderOptions{
  80							Name: &ReasoningOptionMetadata{
  81								Signature: "def456",
  82							},
  83						},
  84					},
  85				},
  86			},
  87		}
  88
  89		systemBlocks, messages, warnings := toPrompt(prompt, false)
  90
  91		require.Empty(t, systemBlocks)
  92		require.Len(t, messages, 1, "should only have user message, assistant message should be dropped")
  93		require.Len(t, warnings, 2)
  94		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
  95		require.Contains(t, warnings[0].Message, "sending reasoning content is disabled")
  96		require.Equal(t, fantasy.CallWarningTypeOther, warnings[1].Type)
  97		require.Contains(t, warnings[1].Message, "dropping empty assistant message")
  98	})
  99
 100	t.Run("should drop truly empty assistant messages", func(t *testing.T) {
 101		t.Parallel()
 102
 103		prompt := fantasy.Prompt{
 104			{
 105				Role: fantasy.MessageRoleUser,
 106				Content: []fantasy.MessagePart{
 107					fantasy.TextPart{Text: "Hello"},
 108				},
 109			},
 110			{
 111				Role:    fantasy.MessageRoleAssistant,
 112				Content: []fantasy.MessagePart{},
 113			},
 114		}
 115
 116		systemBlocks, messages, warnings := toPrompt(prompt, true)
 117
 118		require.Empty(t, systemBlocks)
 119		require.Len(t, messages, 1, "should only have user message")
 120		require.Len(t, warnings, 1)
 121		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
 122		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
 123	})
 124
 125	t.Run("should keep assistant messages with text content", func(t *testing.T) {
 126		t.Parallel()
 127
 128		prompt := fantasy.Prompt{
 129			{
 130				Role: fantasy.MessageRoleUser,
 131				Content: []fantasy.MessagePart{
 132					fantasy.TextPart{Text: "Hello"},
 133				},
 134			},
 135			{
 136				Role: fantasy.MessageRoleAssistant,
 137				Content: []fantasy.MessagePart{
 138					fantasy.TextPart{Text: "Hi there!"},
 139				},
 140			},
 141		}
 142
 143		systemBlocks, messages, warnings := toPrompt(prompt, true)
 144
 145		require.Empty(t, systemBlocks)
 146		require.Len(t, messages, 2, "should have both user and assistant messages")
 147		require.Empty(t, warnings)
 148	})
 149
 150	t.Run("should keep assistant messages with tool calls", func(t *testing.T) {
 151		t.Parallel()
 152
 153		prompt := fantasy.Prompt{
 154			{
 155				Role: fantasy.MessageRoleUser,
 156				Content: []fantasy.MessagePart{
 157					fantasy.TextPart{Text: "What's the weather?"},
 158				},
 159			},
 160			{
 161				Role: fantasy.MessageRoleAssistant,
 162				Content: []fantasy.MessagePart{
 163					fantasy.ToolCallPart{
 164						ToolCallID: "call_123",
 165						ToolName:   "get_weather",
 166						Input:      `{"location":"NYC"}`,
 167					},
 168				},
 169			},
 170		}
 171
 172		systemBlocks, messages, warnings := toPrompt(prompt, true)
 173
 174		require.Empty(t, systemBlocks)
 175		require.Len(t, messages, 2, "should have both user and assistant messages")
 176		require.Empty(t, warnings)
 177	})
 178
 179	t.Run("should drop assistant messages with invalid tool input", func(t *testing.T) {
 180		t.Parallel()
 181
 182		prompt := fantasy.Prompt{
 183			{
 184				Role: fantasy.MessageRoleUser,
 185				Content: []fantasy.MessagePart{
 186					fantasy.TextPart{Text: "Hi"},
 187				},
 188			},
 189			{
 190				Role: fantasy.MessageRoleAssistant,
 191				Content: []fantasy.MessagePart{
 192					fantasy.ToolCallPart{
 193						ToolCallID: "call_123",
 194						ToolName:   "get_weather",
 195						Input:      "{not-json",
 196					},
 197				},
 198			},
 199		}
 200
 201		systemBlocks, messages, warnings := toPrompt(prompt, true)
 202
 203		require.Empty(t, systemBlocks)
 204		require.Len(t, messages, 1, "should only have user message")
 205		require.Len(t, warnings, 1)
 206		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
 207		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
 208	})
 209
 210	t.Run("should keep assistant messages with reasoning and text", func(t *testing.T) {
 211		t.Parallel()
 212
 213		prompt := fantasy.Prompt{
 214			{
 215				Role: fantasy.MessageRoleUser,
 216				Content: []fantasy.MessagePart{
 217					fantasy.TextPart{Text: "Hello"},
 218				},
 219			},
 220			{
 221				Role: fantasy.MessageRoleAssistant,
 222				Content: []fantasy.MessagePart{
 223					fantasy.ReasoningPart{
 224						Text: "Let me think...",
 225						ProviderOptions: fantasy.ProviderOptions{
 226							Name: &ReasoningOptionMetadata{
 227								Signature: "abc123",
 228							},
 229						},
 230					},
 231					fantasy.TextPart{Text: "Hi there!"},
 232				},
 233			},
 234		}
 235
 236		systemBlocks, messages, warnings := toPrompt(prompt, true)
 237
 238		require.Empty(t, systemBlocks)
 239		require.Len(t, messages, 2, "should have both user and assistant messages")
 240		require.Empty(t, warnings)
 241	})
 242
 243	t.Run("should keep user messages with image content", func(t *testing.T) {
 244		t.Parallel()
 245
 246		prompt := fantasy.Prompt{
 247			{
 248				Role: fantasy.MessageRoleUser,
 249				Content: []fantasy.MessagePart{
 250					fantasy.FilePart{
 251						Data:      []byte{0x01, 0x02, 0x03},
 252						MediaType: "image/png",
 253					},
 254				},
 255			},
 256		}
 257
 258		systemBlocks, messages, warnings := toPrompt(prompt, true)
 259
 260		require.Empty(t, systemBlocks)
 261		require.Len(t, messages, 1)
 262		require.Empty(t, warnings)
 263	})
 264
 265	t.Run("should drop user messages without visible content", func(t *testing.T) {
 266		t.Parallel()
 267
 268		prompt := fantasy.Prompt{
 269			{
 270				Role: fantasy.MessageRoleUser,
 271				Content: []fantasy.MessagePart{
 272					fantasy.FilePart{
 273						Data:      []byte("not supported"),
 274						MediaType: "application/pdf",
 275					},
 276				},
 277			},
 278		}
 279
 280		systemBlocks, messages, warnings := toPrompt(prompt, true)
 281
 282		require.Empty(t, systemBlocks)
 283		require.Empty(t, messages)
 284		require.Len(t, warnings, 1)
 285		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
 286		require.Contains(t, warnings[0].Message, "dropping empty user message")
 287		require.Contains(t, warnings[0].Message, "neither user-facing content nor tool results")
 288	})
 289
 290	t.Run("should keep user messages with tool results", func(t *testing.T) {
 291		t.Parallel()
 292
 293		prompt := fantasy.Prompt{
 294			{
 295				Role: fantasy.MessageRoleTool,
 296				Content: []fantasy.MessagePart{
 297					fantasy.ToolResultPart{
 298						ToolCallID: "call_123",
 299						Output:     fantasy.ToolResultOutputContentText{Text: "done"},
 300					},
 301				},
 302			},
 303		}
 304
 305		systemBlocks, messages, warnings := toPrompt(prompt, true)
 306
 307		require.Empty(t, systemBlocks)
 308		require.Len(t, messages, 1)
 309		require.Empty(t, warnings)
 310	})
 311
 312	t.Run("should keep user messages with tool error results", func(t *testing.T) {
 313		t.Parallel()
 314
 315		prompt := fantasy.Prompt{
 316			{
 317				Role: fantasy.MessageRoleTool,
 318				Content: []fantasy.MessagePart{
 319					fantasy.ToolResultPart{
 320						ToolCallID: "call_456",
 321						Output:     fantasy.ToolResultOutputContentError{Error: errors.New("boom")},
 322					},
 323				},
 324			},
 325		}
 326
 327		systemBlocks, messages, warnings := toPrompt(prompt, true)
 328
 329		require.Empty(t, systemBlocks)
 330		require.Len(t, messages, 1)
 331		require.Empty(t, warnings)
 332	})
 333
 334	t.Run("should keep user messages with tool media results", func(t *testing.T) {
 335		t.Parallel()
 336
 337		prompt := fantasy.Prompt{
 338			{
 339				Role: fantasy.MessageRoleTool,
 340				Content: []fantasy.MessagePart{
 341					fantasy.ToolResultPart{
 342						ToolCallID: "call_789",
 343						Output: fantasy.ToolResultOutputContentMedia{
 344							Data:      "AQID",
 345							MediaType: "image/png",
 346						},
 347					},
 348				},
 349			},
 350		}
 351
 352		systemBlocks, messages, warnings := toPrompt(prompt, true)
 353
 354		require.Empty(t, systemBlocks)
 355		require.Len(t, messages, 1)
 356		require.Empty(t, warnings)
 357	})
 358}
 359
 360func TestParseContextTooLargeError(t *testing.T) {
 361	t.Parallel()
 362
 363	tests := []struct {
 364		name     string
 365		message  string
 366		wantErr  bool
 367		wantUsed int
 368		wantMax  int
 369	}{
 370		{
 371			name:     "matches anthropic format",
 372			message:  "prompt is too long: 202630 tokens > 200000 maximum",
 373			wantErr:  true,
 374			wantUsed: 202630,
 375			wantMax:  200000,
 376		},
 377		{
 378			name:     "matches with different numbers",
 379			message:  "prompt is too long: 150000 tokens > 128000 maximum",
 380			wantErr:  true,
 381			wantUsed: 150000,
 382			wantMax:  128000,
 383		},
 384		{
 385			name:     "matches with extra whitespace",
 386			message:  "prompt is too long:  202630  tokens  >  200000  maximum",
 387			wantErr:  true,
 388			wantUsed: 202630,
 389			wantMax:  200000,
 390		},
 391		{
 392			name:    "does not match unrelated error",
 393			message: "invalid api key",
 394			wantErr: false,
 395		},
 396		{
 397			name:    "does not match rate limit error",
 398			message: "rate limit exceeded",
 399			wantErr: false,
 400		},
 401	}
 402
 403	for _, tt := range tests {
 404		t.Run(tt.name, func(t *testing.T) {
 405			t.Parallel()
 406			providerErr := &fantasy.ProviderError{Message: tt.message}
 407			parseContextTooLargeError(tt.message, providerErr)
 408
 409			if tt.wantErr {
 410				require.True(t, providerErr.IsContextTooLarge())
 411				require.Equal(t, tt.wantUsed, providerErr.ContextUsedTokens)
 412				require.Equal(t, tt.wantMax, providerErr.ContextMaxTokens)
 413			} else {
 414				require.False(t, providerErr.IsContextTooLarge())
 415			}
 416		})
 417	}
 418}
 419
 420func TestParseOptions_Effort(t *testing.T) {
 421	t.Parallel()
 422
 423	options, err := ParseOptions(map[string]any{
 424		"send_reasoning":            true,
 425		"thinking":                  map[string]any{"budget_tokens": int64(2048)},
 426		"effort":                    "medium",
 427		"disable_parallel_tool_use": true,
 428	})
 429	require.NoError(t, err)
 430	require.NotNil(t, options.SendReasoning)
 431	require.True(t, *options.SendReasoning)
 432	require.NotNil(t, options.Thinking)
 433	require.Equal(t, int64(2048), options.Thinking.BudgetTokens)
 434	require.NotNil(t, options.Effort)
 435	require.Equal(t, EffortMedium, *options.Effort)
 436	require.NotNil(t, options.DisableParallelToolUse)
 437	require.True(t, *options.DisableParallelToolUse)
 438}
 439
 440func TestGenerate_SendsOutputConfigEffort(t *testing.T) {
 441	t.Parallel()
 442
 443	server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
 444	defer server.Close()
 445
 446	provider, err := New(
 447		WithAPIKey("test-api-key"),
 448		WithBaseURL(server.URL),
 449	)
 450	require.NoError(t, err)
 451
 452	model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
 453	require.NoError(t, err)
 454
 455	effort := EffortMedium
 456	_, err = model.Generate(context.Background(), fantasy.Call{
 457		Prompt: testPrompt(),
 458		ProviderOptions: NewProviderOptions(&ProviderOptions{
 459			Effort: &effort,
 460		}),
 461	})
 462	require.NoError(t, err)
 463
 464	call := awaitAnthropicCall(t, calls)
 465	require.Equal(t, "POST", call.method)
 466	require.Equal(t, "/v1/messages", call.path)
 467	requireAnthropicEffort(t, call.body, EffortMedium)
 468}
 469
 470func TestStream_SendsOutputConfigEffort(t *testing.T) {
 471	t.Parallel()
 472
 473	server, calls := newAnthropicStreamingServer([]string{
 474		"event: message_start\n",
 475		"data: {\"type\":\"message_start\",\"message\":{}}\n\n",
 476		"event: message_stop\n",
 477		"data: {\"type\":\"message_stop\"}\n\n",
 478	})
 479	defer server.Close()
 480
 481	provider, err := New(
 482		WithAPIKey("test-api-key"),
 483		WithBaseURL(server.URL),
 484	)
 485	require.NoError(t, err)
 486
 487	model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
 488	require.NoError(t, err)
 489
 490	effort := EffortHigh
 491	stream, err := model.Stream(context.Background(), fantasy.Call{
 492		Prompt: testPrompt(),
 493		ProviderOptions: NewProviderOptions(&ProviderOptions{
 494			Effort: &effort,
 495		}),
 496	})
 497	require.NoError(t, err)
 498
 499	stream(func(fantasy.StreamPart) bool { return true })
 500
 501	call := awaitAnthropicCall(t, calls)
 502	require.Equal(t, "POST", call.method)
 503	require.Equal(t, "/v1/messages", call.path)
 504	requireAnthropicEffort(t, call.body, EffortHigh)
 505}
 506
 507type anthropicCall struct {
 508	method string
 509	path   string
 510	body   map[string]any
 511}
 512
 513func newAnthropicJSONServer(response map[string]any) (*httptest.Server, <-chan anthropicCall) {
 514	calls := make(chan anthropicCall, 4)
 515
 516	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 517		var body map[string]any
 518		if r.Body != nil {
 519			_ = json.NewDecoder(r.Body).Decode(&body)
 520		}
 521
 522		calls <- anthropicCall{
 523			method: r.Method,
 524			path:   r.URL.Path,
 525			body:   body,
 526		}
 527
 528		w.Header().Set("Content-Type", "application/json")
 529		_ = json.NewEncoder(w).Encode(response)
 530	}))
 531
 532	return server, calls
 533}
 534
 535func newAnthropicStreamingServer(chunks []string) (*httptest.Server, <-chan anthropicCall) {
 536	calls := make(chan anthropicCall, 4)
 537
 538	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 539		var body map[string]any
 540		if r.Body != nil {
 541			_ = json.NewDecoder(r.Body).Decode(&body)
 542		}
 543
 544		calls <- anthropicCall{
 545			method: r.Method,
 546			path:   r.URL.Path,
 547			body:   body,
 548		}
 549
 550		w.Header().Set("Content-Type", "text/event-stream")
 551		w.Header().Set("Cache-Control", "no-cache")
 552		w.Header().Set("Connection", "keep-alive")
 553		w.WriteHeader(http.StatusOK)
 554
 555		for _, chunk := range chunks {
 556			_, _ = fmt.Fprint(w, chunk)
 557			if flusher, ok := w.(http.Flusher); ok {
 558				flusher.Flush()
 559			}
 560		}
 561	}))
 562
 563	return server, calls
 564}
 565
 566func awaitAnthropicCall(t *testing.T, calls <-chan anthropicCall) anthropicCall {
 567	t.Helper()
 568
 569	select {
 570	case call := <-calls:
 571		return call
 572	case <-time.After(2 * time.Second):
 573		t.Fatal("timed out waiting for Anthropic request")
 574		return anthropicCall{}
 575	}
 576}
 577
 578func assertNoAnthropicCall(t *testing.T, calls <-chan anthropicCall) {
 579	t.Helper()
 580
 581	select {
 582	case call := <-calls:
 583		t.Fatalf("expected no Anthropic API call, but got %s %s", call.method, call.path)
 584	case <-time.After(200 * time.Millisecond):
 585	}
 586}
 587
 588func requireAnthropicEffort(t *testing.T, body map[string]any, expected Effort) {
 589	t.Helper()
 590
 591	outputConfig, ok := body["output_config"].(map[string]any)
 592	thinking, ok := body["thinking"].(map[string]any)
 593	require.True(t, ok)
 594	require.Equal(t, string(expected), outputConfig["effort"])
 595	require.Equal(t, "adaptive", thinking["type"])
 596}
 597
 598func testPrompt() fantasy.Prompt {
 599	return fantasy.Prompt{
 600		{
 601			Role: fantasy.MessageRoleUser,
 602			Content: []fantasy.MessagePart{
 603				fantasy.TextPart{Text: "Hello"},
 604			},
 605		},
 606	}
 607}
 608
 609func mockAnthropicGenerateResponse() map[string]any {
 610	return map[string]any{
 611		"id":    "msg_01Test",
 612		"type":  "message",
 613		"role":  "assistant",
 614		"model": "claude-sonnet-4-20250514",
 615		"content": []any{
 616			map[string]any{
 617				"type": "text",
 618				"text": "Hi there",
 619			},
 620		},
 621		"stop_reason":   "end_turn",
 622		"stop_sequence": "",
 623		"usage": map[string]any{
 624			"cache_creation": map[string]any{
 625				"ephemeral_1h_input_tokens": 0,
 626				"ephemeral_5m_input_tokens": 0,
 627			},
 628			"cache_creation_input_tokens": 0,
 629			"cache_read_input_tokens":     0,
 630			"input_tokens":                5,
 631			"output_tokens":               2,
 632			"server_tool_use": map[string]any{
 633				"web_search_requests": 0,
 634			},
 635			"service_tier": "standard",
 636		},
 637	}
 638}
 639
 640func mockAnthropicWebSearchResponse() map[string]any {
 641	return map[string]any{
 642		"id":    "msg_01WebSearch",
 643		"type":  "message",
 644		"role":  "assistant",
 645		"model": "claude-sonnet-4-20250514",
 646		"content": []any{
 647			map[string]any{
 648				"type":   "server_tool_use",
 649				"id":     "srvtoolu_01",
 650				"name":   "web_search",
 651				"input":  map[string]any{"query": "latest AI news"},
 652				"caller": map[string]any{"type": "direct"},
 653			},
 654			map[string]any{
 655				"type":        "web_search_tool_result",
 656				"tool_use_id": "srvtoolu_01",
 657				"caller":      map[string]any{"type": "direct"},
 658				"content": []any{
 659					map[string]any{
 660						"type":              "web_search_result",
 661						"url":               "https://example.com/ai-news",
 662						"title":             "Latest AI News",
 663						"encrypted_content": "encrypted_abc123",
 664						"page_age":          "2 hours ago",
 665					},
 666					map[string]any{
 667						"type":              "web_search_result",
 668						"url":               "https://example.com/ml-update",
 669						"title":             "ML Update",
 670						"encrypted_content": "encrypted_def456",
 671						"page_age":          "",
 672					},
 673				},
 674			},
 675			map[string]any{
 676				"type": "text",
 677				"text": "Based on recent search results, here is the latest AI news.",
 678			},
 679		},
 680		"stop_reason":   "end_turn",
 681		"stop_sequence": nil,
 682		"usage": map[string]any{
 683			"input_tokens":                100,
 684			"output_tokens":               50,
 685			"cache_creation_input_tokens": 0,
 686			"cache_read_input_tokens":     0,
 687			"server_tool_use": map[string]any{
 688				"web_search_requests": 1,
 689			},
 690		},
 691	}
 692}
 693
 694func TestToPrompt_WebSearchProviderExecutedToolResults(t *testing.T) {
 695	t.Parallel()
 696
 697	prompt := fantasy.Prompt{
 698		// User message.
 699		{
 700			Role: fantasy.MessageRoleUser,
 701			Content: []fantasy.MessagePart{
 702				fantasy.TextPart{Text: "Search for the latest AI news"},
 703			},
 704		},
 705		// Assistant message with a provider-executed tool call, its
 706		// result, and trailing text. toResponseMessages routes
 707		// provider-executed results into the assistant message, so
 708		// the prompt already reflects that structure.
 709		{
 710			Role: fantasy.MessageRoleAssistant,
 711			Content: []fantasy.MessagePart{
 712				fantasy.ToolCallPart{
 713					ToolCallID:       "srvtoolu_01",
 714					ToolName:         "web_search",
 715					Input:            `{"query":"latest AI news"}`,
 716					ProviderExecuted: true,
 717				},
 718				fantasy.ToolResultPart{
 719					ToolCallID:       "srvtoolu_01",
 720					ProviderExecuted: true,
 721					ProviderOptions: fantasy.ProviderOptions{
 722						Name: &WebSearchResultMetadata{
 723							Results: []WebSearchResultItem{
 724								{
 725									URL:              "https://example.com/ai-news",
 726									Title:            "Latest AI News",
 727									EncryptedContent: "encrypted_abc123",
 728									PageAge:          "2 hours ago",
 729								},
 730								{
 731									URL:              "https://example.com/ml-update",
 732									Title:            "ML Update",
 733									EncryptedContent: "encrypted_def456",
 734								},
 735							},
 736						},
 737					},
 738				},
 739				fantasy.TextPart{Text: "Here is what I found."},
 740			},
 741		},
 742	}
 743
 744	_, messages, warnings := toPrompt(prompt, true)
 745
 746	// No warnings expected; the provider-executed result is in the
 747	// assistant message so there is no empty tool message to drop.
 748	require.Empty(t, warnings)
 749
 750	// We should have a user message and an assistant message.
 751	require.Len(t, messages, 2, "expected user + assistant messages")
 752
 753	assistantMsg := messages[1]
 754	require.Len(t, assistantMsg.Content, 3,
 755		"expected server_tool_use + web_search_tool_result + text")
 756
 757	// First content block: reconstructed server_tool_use.
 758	serverToolUse := assistantMsg.Content[0]
 759	require.NotNil(t, serverToolUse.OfServerToolUse,
 760		"first block should be a server_tool_use")
 761	require.Equal(t, "srvtoolu_01", serverToolUse.OfServerToolUse.ID)
 762	require.Equal(t, anthropic.ServerToolUseBlockParamName("web_search"),
 763		serverToolUse.OfServerToolUse.Name)
 764
 765	// Second content block: reconstructed web_search_tool_result with
 766	// encrypted_content preserved for multi-turn round-tripping.
 767	webResult := assistantMsg.Content[1]
 768	require.NotNil(t, webResult.OfWebSearchToolResult,
 769		"second block should be a web_search_tool_result")
 770	require.Equal(t, "srvtoolu_01", webResult.OfWebSearchToolResult.ToolUseID)
 771
 772	results := webResult.OfWebSearchToolResult.Content.OfWebSearchToolResultBlockItem
 773	require.Len(t, results, 2)
 774	require.Equal(t, "https://example.com/ai-news", results[0].URL)
 775	require.Equal(t, "Latest AI News", results[0].Title)
 776	require.Equal(t, "encrypted_abc123", results[0].EncryptedContent)
 777	require.Equal(t, "https://example.com/ml-update", results[1].URL)
 778	require.Equal(t, "encrypted_def456", results[1].EncryptedContent)
 779	// PageAge should be set for the first result and absent for the second.
 780	require.True(t, results[0].PageAge.Valid())
 781	require.Equal(t, "2 hours ago", results[0].PageAge.Value)
 782	require.False(t, results[1].PageAge.Valid())
 783
 784	// Third content block: plain text.
 785	require.NotNil(t, assistantMsg.Content[2].OfText)
 786	require.Equal(t, "Here is what I found.", assistantMsg.Content[2].OfText.Text)
 787}
 788
 789func TestGenerate_WebSearchResponse(t *testing.T) {
 790	t.Parallel()
 791
 792	server, calls := newAnthropicJSONServer(mockAnthropicWebSearchResponse())
 793	defer server.Close()
 794
 795	provider, err := New(
 796		WithAPIKey("test-api-key"),
 797		WithBaseURL(server.URL),
 798	)
 799	require.NoError(t, err)
 800
 801	model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
 802	require.NoError(t, err)
 803
 804	resp, err := model.Generate(context.Background(), fantasy.Call{
 805		Prompt: testPrompt(),
 806		Tools: []fantasy.Tool{
 807			WebSearchTool(nil),
 808		},
 809	})
 810	require.NoError(t, err)
 811
 812	call := awaitAnthropicCall(t, calls)
 813	require.Equal(t, "POST", call.method)
 814	require.Equal(t, "/v1/messages", call.path)
 815
 816	// Walk the response content and categorise each item.
 817	var (
 818		toolCalls   []fantasy.ToolCallContent
 819		sources     []fantasy.SourceContent
 820		toolResults []fantasy.ToolResultContent
 821		texts       []fantasy.TextContent
 822	)
 823	for _, c := range resp.Content {
 824		switch v := c.(type) {
 825		case fantasy.ToolCallContent:
 826			toolCalls = append(toolCalls, v)
 827		case fantasy.SourceContent:
 828			sources = append(sources, v)
 829		case fantasy.ToolResultContent:
 830			toolResults = append(toolResults, v)
 831		case fantasy.TextContent:
 832			texts = append(texts, v)
 833		}
 834	}
 835
 836	// ToolCallContent for the provider-executed web_search.
 837	require.Len(t, toolCalls, 1)
 838	require.True(t, toolCalls[0].ProviderExecuted)
 839	require.Equal(t, "web_search", toolCalls[0].ToolName)
 840	require.Equal(t, "srvtoolu_01", toolCalls[0].ToolCallID)
 841
 842	// SourceContent entries for each search result.
 843	require.Len(t, sources, 2)
 844	require.Equal(t, "https://example.com/ai-news", sources[0].URL)
 845	require.Equal(t, "Latest AI News", sources[0].Title)
 846	require.Equal(t, fantasy.SourceTypeURL, sources[0].SourceType)
 847	require.Equal(t, "https://example.com/ml-update", sources[1].URL)
 848	require.Equal(t, "ML Update", sources[1].Title)
 849
 850	// ToolResultContent with provider metadata preserving encrypted_content.
 851	require.Len(t, toolResults, 1)
 852	require.True(t, toolResults[0].ProviderExecuted)
 853	require.Equal(t, "web_search", toolResults[0].ToolName)
 854	require.Equal(t, "srvtoolu_01", toolResults[0].ToolCallID)
 855
 856	searchMeta, ok := toolResults[0].ProviderMetadata[Name]
 857	require.True(t, ok, "providerMetadata should contain anthropic key")
 858	webMeta, ok := searchMeta.(*WebSearchResultMetadata)
 859	require.True(t, ok, "metadata should be *WebSearchResultMetadata")
 860	require.Len(t, webMeta.Results, 2)
 861	require.Equal(t, "encrypted_abc123", webMeta.Results[0].EncryptedContent)
 862	require.Equal(t, "encrypted_def456", webMeta.Results[1].EncryptedContent)
 863	require.Equal(t, "2 hours ago", webMeta.Results[0].PageAge)
 864
 865	// TextContent with the final answer.
 866	require.Len(t, texts, 1)
 867	require.Equal(t,
 868		"Based on recent search results, here is the latest AI news.",
 869		texts[0].Text,
 870	)
 871}
 872
 873func TestGenerate_WebSearchToolInRequest(t *testing.T) {
 874	t.Parallel()
 875
 876	t.Run("basic web_search tool", func(t *testing.T) {
 877		t.Parallel()
 878
 879		server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
 880		defer server.Close()
 881
 882		provider, err := New(
 883			WithAPIKey("test-api-key"),
 884			WithBaseURL(server.URL),
 885		)
 886		require.NoError(t, err)
 887
 888		model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
 889		require.NoError(t, err)
 890
 891		_, err = model.Generate(context.Background(), fantasy.Call{
 892			Prompt: testPrompt(),
 893			Tools: []fantasy.Tool{
 894				WebSearchTool(nil),
 895			},
 896		})
 897		require.NoError(t, err)
 898
 899		call := awaitAnthropicCall(t, calls)
 900		tools, ok := call.body["tools"].([]any)
 901		require.True(t, ok, "request body should have tools array")
 902		require.Len(t, tools, 1)
 903
 904		tool, ok := tools[0].(map[string]any)
 905		require.True(t, ok)
 906		require.Equal(t, "web_search_20250305", tool["type"])
 907	})
 908
 909	t.Run("with allowed_domains and blocked_domains", func(t *testing.T) {
 910		t.Parallel()
 911
 912		server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
 913		defer server.Close()
 914
 915		provider, err := New(
 916			WithAPIKey("test-api-key"),
 917			WithBaseURL(server.URL),
 918		)
 919		require.NoError(t, err)
 920
 921		model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
 922		require.NoError(t, err)
 923
 924		_, err = model.Generate(context.Background(), fantasy.Call{
 925			Prompt: testPrompt(),
 926			Tools: []fantasy.Tool{
 927				WebSearchTool(&WebSearchToolOptions{
 928					AllowedDomains: []string{"example.com", "test.com"},
 929				}),
 930			},
 931		})
 932		require.NoError(t, err)
 933
 934		call := awaitAnthropicCall(t, calls)
 935		tools, ok := call.body["tools"].([]any)
 936		require.True(t, ok)
 937		require.Len(t, tools, 1)
 938
 939		tool, ok := tools[0].(map[string]any)
 940		require.True(t, ok)
 941		require.Equal(t, "web_search_20250305", tool["type"])
 942
 943		domains, ok := tool["allowed_domains"].([]any)
 944		require.True(t, ok, "tool should have allowed_domains")
 945		require.Len(t, domains, 2)
 946		require.Equal(t, "example.com", domains[0])
 947		require.Equal(t, "test.com", domains[1])
 948	})
 949
 950	t.Run("with max uses and user location", func(t *testing.T) {
 951		t.Parallel()
 952
 953		server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
 954		defer server.Close()
 955
 956		provider, err := New(
 957			WithAPIKey("test-api-key"),
 958			WithBaseURL(server.URL),
 959		)
 960		require.NoError(t, err)
 961
 962		model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
 963		require.NoError(t, err)
 964
 965		_, err = model.Generate(context.Background(), fantasy.Call{
 966			Prompt: testPrompt(),
 967			Tools: []fantasy.Tool{
 968				WebSearchTool(&WebSearchToolOptions{
 969					MaxUses: 5,
 970					UserLocation: &UserLocation{
 971						City:    "San Francisco",
 972						Country: "US",
 973					},
 974				}),
 975			},
 976		})
 977		require.NoError(t, err)
 978
 979		call := awaitAnthropicCall(t, calls)
 980		tools, ok := call.body["tools"].([]any)
 981		require.True(t, ok)
 982		require.Len(t, tools, 1)
 983
 984		tool, ok := tools[0].(map[string]any)
 985		require.True(t, ok)
 986		require.Equal(t, "web_search_20250305", tool["type"])
 987
 988		// max_uses is serialized as a JSON number; json.Unmarshal
 989		// into map[string]any decodes numbers as float64.
 990		maxUses, ok := tool["max_uses"].(float64)
 991		require.True(t, ok, "tool should have max_uses")
 992		require.Equal(t, float64(5), maxUses)
 993
 994		userLoc, ok := tool["user_location"].(map[string]any)
 995		require.True(t, ok, "tool should have user_location")
 996		require.Equal(t, "San Francisco", userLoc["city"])
 997		require.Equal(t, "US", userLoc["country"])
 998		require.Equal(t, "approximate", userLoc["type"])
 999	})
1000
1001	t.Run("with max uses", func(t *testing.T) {
1002		t.Parallel()
1003
1004		server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1005		defer server.Close()
1006
1007		provider, err := New(
1008			WithAPIKey("test-api-key"),
1009			WithBaseURL(server.URL),
1010		)
1011		require.NoError(t, err)
1012
1013		model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1014		require.NoError(t, err)
1015
1016		_, err = model.Generate(context.Background(), fantasy.Call{
1017			Prompt: testPrompt(),
1018			Tools: []fantasy.Tool{
1019				WebSearchTool(&WebSearchToolOptions{
1020					MaxUses: 3,
1021				}),
1022			},
1023		})
1024		require.NoError(t, err)
1025
1026		call := awaitAnthropicCall(t, calls)
1027		tools, ok := call.body["tools"].([]any)
1028		require.True(t, ok)
1029		require.Len(t, tools, 1)
1030
1031		tool, ok := tools[0].(map[string]any)
1032		require.True(t, ok)
1033		require.Equal(t, "web_search_20250305", tool["type"])
1034
1035		maxUses, ok := tool["max_uses"].(float64)
1036		require.True(t, ok, "tool should have max_uses")
1037		require.Equal(t, float64(3), maxUses)
1038	})
1039
1040	t.Run("with json-round-tripped provider tool args", func(t *testing.T) {
1041		t.Parallel()
1042
1043		server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1044		defer server.Close()
1045
1046		provider, err := New(
1047			WithAPIKey("test-api-key"),
1048			WithBaseURL(server.URL),
1049		)
1050		require.NoError(t, err)
1051
1052		model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1053		require.NoError(t, err)
1054
1055		baseTool := WebSearchTool(&WebSearchToolOptions{
1056			MaxUses:        7,
1057			BlockedDomains: []string{"example.com", "test.com"},
1058			UserLocation: &UserLocation{
1059				City:     "San Francisco",
1060				Region:   "CA",
1061				Country:  "US",
1062				Timezone: "America/Los_Angeles",
1063			},
1064		})
1065
1066		data, err := json.Marshal(baseTool)
1067		require.NoError(t, err)
1068
1069		var roundTripped fantasy.ProviderDefinedTool
1070		err = json.Unmarshal(data, &roundTripped)
1071		require.NoError(t, err)
1072
1073		_, err = model.Generate(context.Background(), fantasy.Call{
1074			Prompt: testPrompt(),
1075			Tools:  []fantasy.Tool{roundTripped},
1076		})
1077		require.NoError(t, err)
1078
1079		call := awaitAnthropicCall(t, calls)
1080		tools, ok := call.body["tools"].([]any)
1081		require.True(t, ok)
1082		require.Len(t, tools, 1)
1083
1084		tool, ok := tools[0].(map[string]any)
1085		require.True(t, ok)
1086		require.Equal(t, "web_search_20250305", tool["type"])
1087
1088		domains, ok := tool["blocked_domains"].([]any)
1089		require.True(t, ok, "tool should have blocked_domains")
1090		require.Len(t, domains, 2)
1091		require.Equal(t, "example.com", domains[0])
1092		require.Equal(t, "test.com", domains[1])
1093
1094		maxUses, ok := tool["max_uses"].(float64)
1095		require.True(t, ok, "tool should have max_uses")
1096		require.Equal(t, float64(7), maxUses)
1097
1098		userLoc, ok := tool["user_location"].(map[string]any)
1099		require.True(t, ok, "tool should have user_location")
1100		require.Equal(t, "San Francisco", userLoc["city"])
1101		require.Equal(t, "CA", userLoc["region"])
1102		require.Equal(t, "US", userLoc["country"])
1103		require.Equal(t, "America/Los_Angeles", userLoc["timezone"])
1104		require.Equal(t, "approximate", userLoc["type"])
1105	})
1106}
1107
1108func TestAnyToStringSlice(t *testing.T) {
1109	t.Parallel()
1110
1111	t.Run("from string slice", func(t *testing.T) {
1112		t.Parallel()
1113
1114		got := anyToStringSlice([]string{"example.com", ""})
1115		require.Equal(t, []string{"example.com", ""}, got)
1116	})
1117
1118	t.Run("from any slice filters non-strings and empty", func(t *testing.T) {
1119		t.Parallel()
1120
1121		got := anyToStringSlice([]any{"example.com", 123, "", "test.com"})
1122		require.Equal(t, []string{"example.com", "test.com"}, got)
1123	})
1124
1125	t.Run("unsupported type", func(t *testing.T) {
1126		t.Parallel()
1127
1128		got := anyToStringSlice("example.com")
1129		require.Nil(t, got)
1130	})
1131}
1132
1133func TestAnyToInt64(t *testing.T) {
1134	t.Parallel()
1135
1136	tests := []struct {
1137		name   string
1138		input  any
1139		want   int64
1140		wantOK bool
1141	}{
1142		{name: "int64", input: int64(7), want: 7, wantOK: true},
1143		{name: "float64 integer", input: float64(7), want: 7, wantOK: true},
1144		{name: "float32 integer", input: float32(9), want: 9, wantOK: true},
1145		{name: "float64 non-integer", input: float64(7.5), wantOK: false},
1146		{name: "float64 max exact int ok", input: float64(1<<53 - 1), want: 1<<53 - 1, wantOK: true},
1147		{name: "float64 over max exact int", input: float64(1 << 53), wantOK: false},
1148		{name: "json number int", input: json.Number("42"), want: 42, wantOK: true},
1149		{name: "json number float", input: json.Number("4.2"), wantOK: false},
1150		{name: "nan", input: math.NaN(), wantOK: false},
1151		{name: "inf", input: math.Inf(1), wantOK: false},
1152		{name: "uint64 overflow", input: uint64(math.MaxInt64) + 1, wantOK: false},
1153	}
1154
1155	for _, tt := range tests {
1156		t.Run(tt.name, func(t *testing.T) {
1157			got, ok := anyToInt64(tt.input)
1158			require.Equal(t, tt.wantOK, ok)
1159			if tt.wantOK {
1160				require.Equal(t, tt.want, got)
1161			}
1162		})
1163	}
1164}
1165
1166func TestAnyToUserLocation(t *testing.T) {
1167	t.Parallel()
1168
1169	t.Run("pointer passthrough", func(t *testing.T) {
1170		t.Parallel()
1171
1172		input := &UserLocation{City: "San Francisco", Country: "US"}
1173		got := anyToUserLocation(input)
1174		require.Same(t, input, got)
1175	})
1176
1177	t.Run("struct value", func(t *testing.T) {
1178		t.Parallel()
1179
1180		got := anyToUserLocation(UserLocation{City: "San Francisco", Country: "US"})
1181		require.NotNil(t, got)
1182		require.Equal(t, "San Francisco", got.City)
1183		require.Equal(t, "US", got.Country)
1184	})
1185
1186	t.Run("map value", func(t *testing.T) {
1187		t.Parallel()
1188
1189		got := anyToUserLocation(map[string]any{
1190			"city":     "San Francisco",
1191			"region":   "CA",
1192			"country":  "US",
1193			"timezone": "America/Los_Angeles",
1194			"type":     "approximate",
1195		})
1196		require.NotNil(t, got)
1197		require.Equal(t, "San Francisco", got.City)
1198		require.Equal(t, "CA", got.Region)
1199		require.Equal(t, "US", got.Country)
1200		require.Equal(t, "America/Los_Angeles", got.Timezone)
1201	})
1202
1203	t.Run("empty map", func(t *testing.T) {
1204		t.Parallel()
1205
1206		got := anyToUserLocation(map[string]any{"type": "approximate"})
1207		require.Nil(t, got)
1208	})
1209
1210	t.Run("unsupported type", func(t *testing.T) {
1211		t.Parallel()
1212
1213		got := anyToUserLocation("San Francisco")
1214		require.Nil(t, got)
1215	})
1216}
1217
1218func TestStream_WebSearchResponse(t *testing.T) {
1219	t.Parallel()
1220
1221	// Build SSE chunks that simulate a web search streaming response.
1222	// The Anthropic SDK accumulates content blocks via
1223	// acc.Accumulate(event). We read the Content and ToolUseID
1224	// directly from struct fields instead of using AsAny(), which
1225	// avoids the SDK's re-marshal limitation that previously dropped
1226	// source data.
1227	webSearchResultContent, _ := json.Marshal([]any{
1228		map[string]any{
1229			"type":              "web_search_result",
1230			"url":               "https://example.com/ai-news",
1231			"title":             "Latest AI News",
1232			"encrypted_content": "encrypted_abc123",
1233			"page_age":          "2 hours ago",
1234		},
1235	})
1236
1237	chunks := []string{
1238		// message_start
1239		"event: message_start\n",
1240		`data: {"type":"message_start","message":{"id":"msg_01WebSearch","type":"message","role":"assistant","model":"claude-sonnet-4-20250514","content":[],"stop_reason":null,"usage":{"input_tokens":100,"output_tokens":0}}}` + "\n\n",
1241		// Block 0: server_tool_use
1242		"event: content_block_start\n",
1243		`data: {"type":"content_block_start","index":0,"content_block":{"type":"server_tool_use","id":"srvtoolu_01","name":"web_search","input":{}}}` + "\n\n",
1244		"event: content_block_stop\n",
1245		`data: {"type":"content_block_stop","index":0}` + "\n\n",
1246		// Block 1: web_search_tool_result
1247		"event: content_block_start\n",
1248		`data: {"type":"content_block_start","index":1,"content_block":{"type":"web_search_tool_result","tool_use_id":"srvtoolu_01","content":` + string(webSearchResultContent) + `}}` + "\n\n",
1249		"event: content_block_stop\n",
1250		`data: {"type":"content_block_stop","index":1}` + "\n\n",
1251		// Block 2: text
1252		"event: content_block_start\n",
1253		`data: {"type":"content_block_start","index":2,"content_block":{"type":"text","text":""}}` + "\n\n",
1254		"event: content_block_delta\n",
1255		`data: {"type":"content_block_delta","index":2,"delta":{"type":"text_delta","text":"Here are the results."}}` + "\n\n",
1256		"event: content_block_stop\n",
1257		`data: {"type":"content_block_stop","index":2}` + "\n\n",
1258		// message_stop
1259		"event: message_stop\n",
1260		`data: {"type":"message_stop"}` + "\n\n",
1261	}
1262
1263	server, calls := newAnthropicStreamingServer(chunks)
1264	defer server.Close()
1265
1266	provider, err := New(
1267		WithAPIKey("test-api-key"),
1268		WithBaseURL(server.URL),
1269	)
1270	require.NoError(t, err)
1271
1272	model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1273	require.NoError(t, err)
1274
1275	stream, err := model.Stream(context.Background(), fantasy.Call{
1276		Prompt: testPrompt(),
1277		Tools: []fantasy.Tool{
1278			WebSearchTool(nil),
1279		},
1280	})
1281	require.NoError(t, err)
1282
1283	var parts []fantasy.StreamPart
1284	stream(func(part fantasy.StreamPart) bool {
1285		parts = append(parts, part)
1286		return true
1287	})
1288
1289	_ = awaitAnthropicCall(t, calls)
1290
1291	// Collect parts by type for assertions.
1292	var (
1293		toolInputStarts []fantasy.StreamPart
1294		toolCalls       []fantasy.StreamPart
1295		toolResults     []fantasy.StreamPart
1296		sourceParts     []fantasy.StreamPart
1297		textDeltas      []fantasy.StreamPart
1298	)
1299	for _, p := range parts {
1300		switch p.Type {
1301		case fantasy.StreamPartTypeToolInputStart:
1302			toolInputStarts = append(toolInputStarts, p)
1303		case fantasy.StreamPartTypeToolCall:
1304			toolCalls = append(toolCalls, p)
1305		case fantasy.StreamPartTypeToolResult:
1306			toolResults = append(toolResults, p)
1307		case fantasy.StreamPartTypeSource:
1308			sourceParts = append(sourceParts, p)
1309		case fantasy.StreamPartTypeTextDelta:
1310			textDeltas = append(textDeltas, p)
1311		}
1312	}
1313
1314	// server_tool_use emits a ToolInputStart with ProviderExecuted.
1315	require.NotEmpty(t, toolInputStarts, "should have a tool input start")
1316	require.True(t, toolInputStarts[0].ProviderExecuted)
1317	require.Equal(t, "web_search", toolInputStarts[0].ToolCallName)
1318
1319	// server_tool_use emits a ToolCall with ProviderExecuted.
1320	require.NotEmpty(t, toolCalls, "should have a tool call")
1321	require.True(t, toolCalls[0].ProviderExecuted)
1322
1323	// web_search_tool_result always emits a ToolResult even when
1324	// the SDK drops source data. The ToolUseID comes from the raw
1325	// union field as a fallback.
1326	require.NotEmpty(t, toolResults, "should have a tool result")
1327	require.True(t, toolResults[0].ProviderExecuted)
1328	require.Equal(t, "web_search", toolResults[0].ToolCallName)
1329	require.Equal(t, "srvtoolu_01", toolResults[0].ID,
1330		"tool result ID should match the tool_use_id")
1331
1332	// Source parts are now correctly emitted by reading struct fields
1333	// directly instead of using AsAny().
1334	require.Len(t, sourceParts, 1)
1335	require.Equal(t, "https://example.com/ai-news", sourceParts[0].URL)
1336	require.Equal(t, "Latest AI News", sourceParts[0].Title)
1337	require.Equal(t, fantasy.SourceTypeURL, sourceParts[0].SourceType)
1338
1339	// Text block emits a text delta.
1340	require.NotEmpty(t, textDeltas, "should have text deltas")
1341	require.Equal(t, "Here are the results.", textDeltas[0].Delta)
1342}
1343
1344func TestGenerate_ToolChoiceNone(t *testing.T) {
1345	t.Parallel()
1346
1347	server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1348	defer server.Close()
1349
1350	provider, err := New(
1351		WithAPIKey("test-api-key"),
1352		WithBaseURL(server.URL),
1353	)
1354	require.NoError(t, err)
1355
1356	model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1357	require.NoError(t, err)
1358
1359	toolChoiceNone := fantasy.ToolChoiceNone
1360	_, err = model.Generate(context.Background(), fantasy.Call{
1361		Prompt: testPrompt(),
1362		Tools: []fantasy.Tool{
1363			WebSearchTool(nil),
1364		},
1365		ToolChoice: &toolChoiceNone,
1366	})
1367	require.NoError(t, err)
1368
1369	call := awaitAnthropicCall(t, calls)
1370	toolChoice, ok := call.body["tool_choice"].(map[string]any)
1371	require.True(t, ok, "request body should have tool_choice")
1372	require.Equal(t, "none", toolChoice["type"], "tool_choice should be 'none'")
1373}
1374
1375// --- Computer Use Tests ---
1376
1377// jsonRoundTripTool simulates a JSON round-trip on a
1378// ProviderDefinedTool so that its Args map contains float64
1379// values (as json.Unmarshal produces) rather than the int64
1380// values that NewComputerUseTool stores directly. The
1381// production toBetaTools code asserts float64.
1382func jsonRoundTripTool(t *testing.T, tool fantasy.ExecutableProviderTool) fantasy.ProviderDefinedTool {
1383	t.Helper()
1384	pdt := tool.Definition()
1385	data, err := json.Marshal(pdt.Args)
1386	require.NoError(t, err)
1387	var args map[string]any
1388	require.NoError(t, json.Unmarshal(data, &args))
1389	pdt.Args = args
1390	return pdt
1391}
1392
1393func TestNewComputerUseTool(t *testing.T) {
1394	t.Parallel()
1395
1396	t.Run("creates tool with correct ID and name", func(t *testing.T) {
1397		t.Parallel()
1398		tool := NewComputerUseTool(ComputerUseToolOptions{
1399			DisplayWidthPx:  1920,
1400			DisplayHeightPx: 1080,
1401			ToolVersion:     ComputerUse20250124,
1402		}, noopComputerRun).Definition()
1403		require.Equal(t, "anthropic.computer", tool.ID)
1404		require.Equal(t, "computer", tool.Name)
1405		require.Equal(t, int64(1920), tool.Args["display_width_px"])
1406		require.Equal(t, int64(1080), tool.Args["display_height_px"])
1407		require.Equal(t, string(ComputerUse20250124), tool.Args["tool_version"])
1408	})
1409
1410	t.Run("includes optional fields when set", func(t *testing.T) {
1411		t.Parallel()
1412		displayNum := int64(1)
1413		enableZoom := true
1414		tool := NewComputerUseTool(ComputerUseToolOptions{
1415			DisplayWidthPx:  1024,
1416			DisplayHeightPx: 768,
1417			DisplayNumber:   &displayNum,
1418			EnableZoom:      &enableZoom,
1419			ToolVersion:     ComputerUse20251124,
1420			CacheControl:    &CacheControl{Type: "ephemeral"},
1421		}, noopComputerRun).Definition()
1422		require.Equal(t, int64(1), tool.Args["display_number"])
1423		require.Equal(t, true, tool.Args["enable_zoom"])
1424		require.NotNil(t, tool.Args["cache_control"])
1425	})
1426
1427	t.Run("omits optional fields when nil", func(t *testing.T) {
1428		t.Parallel()
1429		tool := NewComputerUseTool(ComputerUseToolOptions{
1430			DisplayWidthPx:  1920,
1431			DisplayHeightPx: 1080,
1432			ToolVersion:     ComputerUse20250124,
1433		}, noopComputerRun).Definition()
1434		_, hasDisplayNum := tool.Args["display_number"]
1435		_, hasEnableZoom := tool.Args["enable_zoom"]
1436		_, hasCacheControl := tool.Args["cache_control"]
1437		require.False(t, hasDisplayNum)
1438		require.False(t, hasEnableZoom)
1439		require.False(t, hasCacheControl)
1440	})
1441}
1442
1443func TestIsComputerUseTool(t *testing.T) {
1444	t.Parallel()
1445
1446	t.Run("returns true for computer use tool", func(t *testing.T) {
1447		t.Parallel()
1448		tool := NewComputerUseTool(ComputerUseToolOptions{
1449			DisplayWidthPx:  1920,
1450			DisplayHeightPx: 1080,
1451			ToolVersion:     ComputerUse20250124,
1452		}, noopComputerRun)
1453		require.True(t, IsComputerUseTool(tool.Definition()))
1454	})
1455
1456	t.Run("returns false for function tool", func(t *testing.T) {
1457		t.Parallel()
1458		tool := fantasy.FunctionTool{
1459			Name:        "test",
1460			Description: "test tool",
1461		}
1462		require.False(t, IsComputerUseTool(tool))
1463	})
1464
1465	t.Run("returns false for other provider defined tool", func(t *testing.T) {
1466		t.Parallel()
1467		tool := fantasy.ProviderDefinedTool{
1468			ID:   "other.tool",
1469			Name: "other",
1470		}
1471		require.False(t, IsComputerUseTool(tool))
1472	})
1473}
1474
1475func TestNeedsBetaAPI(t *testing.T) {
1476	t.Parallel()
1477
1478	lm := languageModel{options: options{}}
1479
1480	t.Run("returns false for empty tools", func(t *testing.T) {
1481		t.Parallel()
1482		_, _, _, betaFlags := lm.toTools(nil, nil, false)
1483		require.Empty(t, betaFlags)
1484		_, _, _, betaFlags = lm.toTools([]fantasy.Tool{}, nil, false)
1485		require.Empty(t, betaFlags)
1486	})
1487
1488	t.Run("returns false for only function tools", func(t *testing.T) {
1489		t.Parallel()
1490		tools := []fantasy.Tool{
1491			fantasy.FunctionTool{Name: "test"},
1492		}
1493		_, _, _, betaFlags := lm.toTools(tools, nil, false)
1494		require.Empty(t, betaFlags)
1495	})
1496
1497	t.Run("returns beta flags when computer use tool present", func(t *testing.T) {
1498		t.Parallel()
1499		cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1500			DisplayWidthPx:  1920,
1501			DisplayHeightPx: 1080,
1502			ToolVersion:     ComputerUse20250124,
1503		}, noopComputerRun))
1504		tools := []fantasy.Tool{
1505			fantasy.FunctionTool{Name: "test"},
1506			cuTool,
1507		}
1508		_, _, _, betaFlags := lm.toTools(tools, nil, false)
1509		require.NotEmpty(t, betaFlags)
1510	})
1511}
1512
1513func TestComputerUseToolJSON(t *testing.T) {
1514	t.Parallel()
1515
1516	t.Run("builds JSON for version 20250124", func(t *testing.T) {
1517		t.Parallel()
1518		cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1519			DisplayWidthPx:  1920,
1520			DisplayHeightPx: 1080,
1521			ToolVersion:     ComputerUse20250124,
1522		}, noopComputerRun))
1523		data, err := computerUseToolJSON(cuTool)
1524		require.NoError(t, err)
1525		var m map[string]any
1526		require.NoError(t, json.Unmarshal(data, &m))
1527		require.Equal(t, "computer_20250124", m["type"])
1528		require.Equal(t, "computer", m["name"])
1529		require.InDelta(t, 1920, m["display_width_px"], 0)
1530		require.InDelta(t, 1080, m["display_height_px"], 0)
1531	})
1532
1533	t.Run("builds JSON for version 20251124 with enable_zoom", func(t *testing.T) {
1534		t.Parallel()
1535		enableZoom := true
1536		cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1537			DisplayWidthPx:  1024,
1538			DisplayHeightPx: 768,
1539			EnableZoom:      &enableZoom,
1540			ToolVersion:     ComputerUse20251124,
1541		}, noopComputerRun))
1542		data, err := computerUseToolJSON(cuTool)
1543		require.NoError(t, err)
1544		var m map[string]any
1545		require.NoError(t, json.Unmarshal(data, &m))
1546		require.Equal(t, "computer_20251124", m["type"])
1547		require.Equal(t, true, m["enable_zoom"])
1548	})
1549
1550	t.Run("handles int64 args without JSON round-trip", func(t *testing.T) {
1551		t.Parallel()
1552		// Direct construction stores int64 values.
1553		cuTool := NewComputerUseTool(ComputerUseToolOptions{
1554			DisplayWidthPx:  1920,
1555			DisplayHeightPx: 1080,
1556			ToolVersion:     ComputerUse20250124,
1557		}, noopComputerRun)
1558		data, err := computerUseToolJSON(cuTool.Definition())
1559		require.NoError(t, err)
1560		var m map[string]any
1561		require.NoError(t, json.Unmarshal(data, &m))
1562		require.InDelta(t, 1920, m["display_width_px"], 0)
1563	})
1564
1565	t.Run("returns error when version is missing", func(t *testing.T) {
1566		t.Parallel()
1567		pdt := fantasy.ProviderDefinedTool{
1568			ID:   "anthropic.computer",
1569			Name: "computer",
1570			Args: map[string]any{
1571				"display_width_px":  float64(1920),
1572				"display_height_px": float64(1080),
1573			},
1574		}
1575		_, err := computerUseToolJSON(pdt)
1576		require.Error(t, err)
1577		require.Contains(t, err.Error(), "tool_version arg is missing")
1578	})
1579
1580	t.Run("returns error for unsupported version", func(t *testing.T) {
1581		t.Parallel()
1582		pdt := fantasy.ProviderDefinedTool{
1583			ID:   "anthropic.computer",
1584			Name: "computer",
1585			Args: map[string]any{
1586				"display_width_px":  float64(1920),
1587				"display_height_px": float64(1080),
1588				"tool_version":      "computer_99991231",
1589			},
1590		}
1591		_, err := computerUseToolJSON(pdt)
1592		require.Error(t, err)
1593		require.Contains(t, err.Error(), "unsupported")
1594	})
1595}
1596
1597func TestParseComputerUseInput_CoordinateValidation(t *testing.T) {
1598	t.Parallel()
1599
1600	t.Run("rejects coordinate with 1 element", func(t *testing.T) {
1601		t.Parallel()
1602		_, err := ParseComputerUseInput(`{"action":"left_click","coordinate":[100]}`)
1603		require.Error(t, err)
1604		require.Contains(t, err.Error(), "coordinate")
1605	})
1606
1607	t.Run("rejects coordinate with 3 elements", func(t *testing.T) {
1608		t.Parallel()
1609		_, err := ParseComputerUseInput(`{"action":"left_click","coordinate":[100,200,300]}`)
1610		require.Error(t, err)
1611		require.Contains(t, err.Error(), "coordinate")
1612	})
1613
1614	t.Run("rejects start_coordinate with 1 element", func(t *testing.T) {
1615		t.Parallel()
1616		_, err := ParseComputerUseInput(`{"action":"left_click_drag","coordinate":[100,200],"start_coordinate":[50]}`)
1617		require.Error(t, err)
1618		require.Contains(t, err.Error(), "start_coordinate")
1619	})
1620
1621	t.Run("rejects region with 3 elements", func(t *testing.T) {
1622		t.Parallel()
1623		_, err := ParseComputerUseInput(`{"action":"zoom","region":[10,20,30]}`)
1624		require.Error(t, err)
1625		require.Contains(t, err.Error(), "region")
1626	})
1627
1628	t.Run("accepts valid coordinate", func(t *testing.T) {
1629		t.Parallel()
1630		result, err := ParseComputerUseInput(`{"action":"left_click","coordinate":[100,200]}`)
1631		require.NoError(t, err)
1632		require.Equal(t, [2]int64{100, 200}, result.Coordinate)
1633	})
1634
1635	t.Run("accepts absent optional arrays", func(t *testing.T) {
1636		t.Parallel()
1637		result, err := ParseComputerUseInput(`{"action":"screenshot"}`)
1638		require.NoError(t, err)
1639		require.Equal(t, ActionScreenshot, result.Action)
1640	})
1641}
1642
1643func TestToTools_RawJSON(t *testing.T) {
1644	t.Parallel()
1645
1646	lm := languageModel{options: options{}}
1647
1648	cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1649		DisplayWidthPx:  1920,
1650		DisplayHeightPx: 1080,
1651		ToolVersion:     ComputerUse20250124,
1652	}, noopComputerRun))
1653
1654	tools := []fantasy.Tool{
1655		fantasy.FunctionTool{
1656			Name:        "weather",
1657			Description: "Get weather",
1658			InputSchema: map[string]any{
1659				"properties": map[string]any{
1660					"location": map[string]any{"type": "string"},
1661				},
1662				"required": []string{"location"},
1663			},
1664		},
1665		WebSearchTool(nil),
1666		cuTool,
1667	}
1668
1669	rawTools, toolChoice, warnings, betaFlags := lm.toTools(tools, nil, false)
1670
1671	require.Len(t, rawTools, 3)
1672	require.Nil(t, toolChoice)
1673	require.Empty(t, warnings)
1674	require.NotEmpty(t, betaFlags)
1675
1676	// Verify each raw tool is valid JSON.
1677	for i, raw := range rawTools {
1678		var m map[string]any
1679		require.NoError(t, json.Unmarshal(raw, &m), "tool %d should be valid JSON", i)
1680	}
1681
1682	// Check function tool.
1683	var funcTool map[string]any
1684	require.NoError(t, json.Unmarshal(rawTools[0], &funcTool))
1685	require.Equal(t, "weather", funcTool["name"])
1686
1687	// Check web search tool.
1688	var webTool map[string]any
1689	require.NoError(t, json.Unmarshal(rawTools[1], &webTool))
1690	require.Equal(t, "web_search_20250305", webTool["type"])
1691
1692	// Check computer use tool.
1693	var cuToolJSON map[string]any
1694	require.NoError(t, json.Unmarshal(rawTools[2], &cuToolJSON))
1695	require.Equal(t, "computer_20250124", cuToolJSON["type"])
1696	require.Equal(t, "computer", cuToolJSON["name"])
1697}
1698
1699func TestGenerate_BetaAPI(t *testing.T) {
1700	t.Parallel()
1701
1702	t.Run("sends beta header for computer use", func(t *testing.T) {
1703		t.Parallel()
1704
1705		var capturedHeaders http.Header
1706		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1707			capturedHeaders = r.Header.Clone()
1708			w.Header().Set("Content-Type", "application/json")
1709			_ = json.NewEncoder(w).Encode(mockAnthropicGenerateResponse())
1710		}))
1711		defer server.Close()
1712
1713		provider, err := New(
1714			WithAPIKey("test-api-key"),
1715			WithBaseURL(server.URL),
1716		)
1717		require.NoError(t, err)
1718
1719		model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1720		require.NoError(t, err)
1721
1722		cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1723			DisplayWidthPx:  1920,
1724			DisplayHeightPx: 1080,
1725			ToolVersion:     ComputerUse20250124,
1726		}, noopComputerRun))
1727
1728		_, err = model.Generate(context.Background(), fantasy.Call{
1729			Prompt: testPrompt(),
1730			Tools:  []fantasy.Tool{cuTool},
1731		})
1732		require.NoError(t, err)
1733		require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-01-24")
1734	})
1735
1736	t.Run("sends beta header for computer use 20251124", func(t *testing.T) {
1737		t.Parallel()
1738
1739		var capturedHeaders http.Header
1740		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1741			capturedHeaders = r.Header.Clone()
1742			w.Header().Set("Content-Type", "application/json")
1743			_ = json.NewEncoder(w).Encode(mockAnthropicGenerateResponse())
1744		}))
1745		defer server.Close()
1746
1747		provider, err := New(
1748			WithAPIKey("test-api-key"),
1749			WithBaseURL(server.URL),
1750		)
1751		require.NoError(t, err)
1752
1753		model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1754		require.NoError(t, err)
1755
1756		cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1757			DisplayWidthPx:  1920,
1758			DisplayHeightPx: 1080,
1759			ToolVersion:     ComputerUse20251124,
1760		}, noopComputerRun))
1761
1762		_, err = model.Generate(context.Background(), fantasy.Call{
1763			Prompt: testPrompt(),
1764			Tools:  []fantasy.Tool{cuTool},
1765		})
1766		require.NoError(t, err)
1767		require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-11-24")
1768	})
1769
1770	t.Run("returns tool use from beta response", func(t *testing.T) {
1771		t.Parallel()
1772
1773		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1774			w.Header().Set("Content-Type", "application/json")
1775			_ = json.NewEncoder(w).Encode(map[string]any{
1776				"id":    "msg_01Test",
1777				"type":  "message",
1778				"role":  "assistant",
1779				"model": "claude-sonnet-4-20250514",
1780				"content": []any{
1781					map[string]any{
1782						"type":  "tool_use",
1783						"id":    "toolu_01",
1784						"name":  "computer",
1785						"input": map[string]any{"action": "screenshot"},
1786					},
1787				},
1788				"stop_reason": "tool_use",
1789				"usage": map[string]any{
1790					"input_tokens":  10,
1791					"output_tokens": 5,
1792					"cache_creation": map[string]any{
1793						"ephemeral_1h_input_tokens": 0,
1794						"ephemeral_5m_input_tokens": 0,
1795					},
1796					"cache_creation_input_tokens": 0,
1797					"cache_read_input_tokens":     0,
1798					"server_tool_use": map[string]any{
1799						"web_search_requests": 0,
1800					},
1801					"service_tier": "standard",
1802				},
1803			})
1804		}))
1805		defer server.Close()
1806
1807		provider, err := New(
1808			WithAPIKey("test-api-key"),
1809			WithBaseURL(server.URL),
1810		)
1811		require.NoError(t, err)
1812
1813		model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1814		require.NoError(t, err)
1815
1816		cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1817			DisplayWidthPx:  1920,
1818			DisplayHeightPx: 1080,
1819			ToolVersion:     ComputerUse20250124,
1820		}, noopComputerRun))
1821
1822		resp, err := model.Generate(context.Background(), fantasy.Call{
1823			Prompt: testPrompt(),
1824			Tools:  []fantasy.Tool{cuTool},
1825		})
1826		require.NoError(t, err)
1827
1828		toolCalls := resp.Content.ToolCalls()
1829		require.Len(t, toolCalls, 1)
1830		require.Equal(t, "computer", toolCalls[0].ToolName)
1831		require.Equal(t, "toolu_01", toolCalls[0].ToolCallID)
1832		require.Contains(t, toolCalls[0].Input, "screenshot")
1833		require.Equal(t, fantasy.FinishReasonToolCalls, resp.FinishReason)
1834
1835		// Verify typed parsing works on the tool call input.
1836		parsed, err := ParseComputerUseInput(toolCalls[0].Input)
1837		require.NoError(t, err)
1838		require.Equal(t, ActionScreenshot, parsed.Action)
1839	})
1840}
1841
1842func TestStream_BetaAPI(t *testing.T) {
1843	t.Parallel()
1844
1845	t.Run("streams via beta API for computer use", func(t *testing.T) {
1846		t.Parallel()
1847
1848		var capturedHeaders http.Header
1849		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1850			capturedHeaders = r.Header.Clone()
1851			w.Header().Set("Content-Type", "text/event-stream")
1852			w.Header().Set("Cache-Control", "no-cache")
1853			w.WriteHeader(http.StatusOK)
1854			chunks := []string{
1855				"event: message_start\n",
1856				"data: {\"type\":\"message_start\",\"message\":{}}\n\n",
1857				"event: message_stop\n",
1858				"data: {\"type\":\"message_stop\"}\n\n",
1859			}
1860			for _, chunk := range chunks {
1861				_, _ = fmt.Fprint(w, chunk)
1862				if flusher, ok := w.(http.Flusher); ok {
1863					flusher.Flush()
1864				}
1865			}
1866		}))
1867		defer server.Close()
1868
1869		provider, err := New(
1870			WithAPIKey("test-api-key"),
1871			WithBaseURL(server.URL),
1872		)
1873		require.NoError(t, err)
1874
1875		model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1876		require.NoError(t, err)
1877
1878		cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1879			DisplayWidthPx:  1920,
1880			DisplayHeightPx: 1080,
1881			ToolVersion:     ComputerUse20250124,
1882		}, noopComputerRun))
1883
1884		stream, err := model.Stream(context.Background(), fantasy.Call{
1885			Prompt: testPrompt(),
1886			Tools:  []fantasy.Tool{cuTool},
1887		})
1888		require.NoError(t, err)
1889
1890		stream(func(fantasy.StreamPart) bool { return true })
1891
1892		require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-01-24")
1893	})
1894
1895	t.Run("streams via beta API for computer use 20251124", func(t *testing.T) {
1896		t.Parallel()
1897
1898		var capturedHeaders http.Header
1899		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1900			capturedHeaders = r.Header.Clone()
1901			w.Header().Set("Content-Type", "text/event-stream")
1902			w.Header().Set("Cache-Control", "no-cache")
1903			w.WriteHeader(http.StatusOK)
1904			chunks := []string{
1905				"event: message_start\n",
1906				"data: {\"type\":\"message_start\",\"message\":{}}\n\n",
1907				"event: message_stop\n",
1908				"data: {\"type\":\"message_stop\"}\n\n",
1909			}
1910			for _, chunk := range chunks {
1911				_, _ = fmt.Fprint(w, chunk)
1912				if flusher, ok := w.(http.Flusher); ok {
1913					flusher.Flush()
1914				}
1915			}
1916		}))
1917		defer server.Close()
1918
1919		provider, err := New(
1920			WithAPIKey("test-api-key"),
1921			WithBaseURL(server.URL),
1922		)
1923		require.NoError(t, err)
1924
1925		model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1926		require.NoError(t, err)
1927
1928		cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1929			DisplayWidthPx:  1920,
1930			DisplayHeightPx: 1080,
1931			ToolVersion:     ComputerUse20251124,
1932		}, noopComputerRun))
1933
1934		stream, err := model.Stream(context.Background(), fantasy.Call{
1935			Prompt: testPrompt(),
1936			Tools:  []fantasy.Tool{cuTool},
1937		})
1938		require.NoError(t, err)
1939
1940		stream(func(fantasy.StreamPart) bool { return true })
1941
1942		require.Contains(t, capturedHeaders.Get("Anthropic-Beta"), "computer-use-2025-11-24")
1943	})
1944}
1945
1946// TestGenerate_ComputerUseTool runs a multi-turn computer use session
1947// via model.Generate, passing the ExecutableProviderTool directly into
1948// Call.Tools (no .Definition(), no jsonRoundTripTool). The mock server
1949// walks through a scripted sequence of actions — screenshot, click,
1950// type, key, scroll — then finishes with a text reply. Each turn the
1951// test parses the tool call, builds a screenshot result, and appends
1952// both to the prompt for the next request.
1953func TestGenerate_ComputerUseTool(t *testing.T) {
1954	t.Parallel()
1955
1956	type actionStep struct {
1957		input map[string]any
1958		want  ComputerUseInput
1959	}
1960	steps := []actionStep{
1961		{
1962			input: map[string]any{"action": "screenshot"},
1963			want:  ComputerUseInput{Action: ActionScreenshot},
1964		},
1965		{
1966			input: map[string]any{"action": "left_click", "coordinate": []any{100, 200}},
1967			want:  ComputerUseInput{Action: ActionLeftClick, Coordinate: [2]int64{100, 200}},
1968		},
1969		{
1970			input: map[string]any{"action": "type", "text": "hello world"},
1971			want:  ComputerUseInput{Action: ActionType, Text: "hello world"},
1972		},
1973		{
1974			input: map[string]any{"action": "key", "text": "Return"},
1975			want:  ComputerUseInput{Action: ActionKey, Text: "Return"},
1976		},
1977		{
1978			input: map[string]any{
1979				"action":           "scroll",
1980				"coordinate":       []any{500, 300},
1981				"scroll_direction": "down",
1982				"scroll_amount":    3,
1983			},
1984			want: ComputerUseInput{
1985				Action:          ActionScroll,
1986				Coordinate:      [2]int64{500, 300},
1987				ScrollDirection: "down",
1988				ScrollAmount:    3,
1989			},
1990		},
1991		{
1992			input: map[string]any{"action": "screenshot"},
1993			want:  ComputerUseInput{Action: ActionScreenshot},
1994		},
1995	}
1996
1997	var (
1998		requestIdx  int
1999		betaHeaders []string
2000	)
2001	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2002		betaHeaders = append(betaHeaders, r.Header.Get("Anthropic-Beta"))
2003		idx := requestIdx
2004		requestIdx++
2005
2006		w.Header().Set("Content-Type", "application/json")
2007		if idx < len(steps) {
2008			_ = json.NewEncoder(w).Encode(map[string]any{
2009				"id":    fmt.Sprintf("msg_%02d", idx),
2010				"type":  "message",
2011				"role":  "assistant",
2012				"model": "claude-sonnet-4-20250514",
2013				"content": []any{map[string]any{
2014					"type":  "tool_use",
2015					"id":    fmt.Sprintf("toolu_%02d", idx),
2016					"name":  "computer",
2017					"input": steps[idx].input,
2018				}},
2019				"stop_reason": "tool_use",
2020				"usage":       map[string]any{"input_tokens": 10, "output_tokens": 5},
2021			})
2022			return
2023		}
2024		_ = json.NewEncoder(w).Encode(map[string]any{
2025			"id":    "msg_final",
2026			"type":  "message",
2027			"role":  "assistant",
2028			"model": "claude-sonnet-4-20250514",
2029			"content": []any{map[string]any{
2030				"type": "text",
2031				"text": "Done! I have completed all the requested actions.",
2032			}},
2033			"stop_reason": "end_turn",
2034			"usage":       map[string]any{"input_tokens": 10, "output_tokens": 15},
2035		})
2036	}))
2037	defer server.Close()
2038
2039	provider, err := New(WithAPIKey("test-api-key"), WithBaseURL(server.URL))
2040	require.NoError(t, err)
2041
2042	model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
2043	require.NoError(t, err)
2044
2045	// Pass the ExecutableProviderTool directly — the whole point is
2046	// to verify that the Tool interface works without unwrapping.
2047	cuTool := NewComputerUseTool(ComputerUseToolOptions{
2048		DisplayWidthPx:  1920,
2049		DisplayHeightPx: 1080,
2050		ToolVersion:     ComputerUse20250124,
2051	}, noopComputerRun)
2052
2053	var got []ComputerUseInput
2054	prompt := testPrompt()
2055	fakePNG := []byte("fake-screenshot-png")
2056
2057	for turn := 0; turn <= len(steps); turn++ {
2058		resp, err := model.Generate(context.Background(), fantasy.Call{
2059			Prompt: prompt,
2060			Tools:  []fantasy.Tool{cuTool},
2061		})
2062		require.NoError(t, err, "turn %d", turn)
2063
2064		if resp.FinishReason != fantasy.FinishReasonToolCalls {
2065			require.Equal(t, fantasy.FinishReasonStop, resp.FinishReason)
2066			require.Contains(t, resp.Content.Text(), "Done")
2067			break
2068		}
2069
2070		toolCalls := resp.Content.ToolCalls()
2071		require.Len(t, toolCalls, 1, "turn %d", turn)
2072		require.Equal(t, "computer", toolCalls[0].ToolName, "turn %d", turn)
2073
2074		parsed, err := ParseComputerUseInput(toolCalls[0].Input)
2075		require.NoError(t, err, "turn %d", turn)
2076		got = append(got, parsed)
2077
2078		// Build the next prompt: append the assistant tool-call turn
2079		// and the user screenshot-result turn.
2080		prompt = append(prompt,
2081			fantasy.Message{
2082				Role: fantasy.MessageRoleAssistant,
2083				Content: []fantasy.MessagePart{
2084					fantasy.ToolCallPart{
2085						ToolCallID: toolCalls[0].ToolCallID,
2086						ToolName:   toolCalls[0].ToolName,
2087						Input:      toolCalls[0].Input,
2088					},
2089				},
2090			},
2091			fantasy.Message{
2092				// Use MessageRoleTool for tool results — this matches
2093				// what the agent loop produces.
2094				Role: fantasy.MessageRoleTool,
2095				Content: []fantasy.MessagePart{
2096					NewComputerUseScreenshotResult(toolCalls[0].ToolCallID, fakePNG),
2097				},
2098			},
2099		)
2100	}
2101
2102	// Every scripted action was received and parsed correctly.
2103	require.Len(t, got, len(steps))
2104	for i, step := range steps {
2105		require.Equal(t, step.want.Action, got[i].Action, "step %d", i)
2106		require.Equal(t, step.want.Coordinate, got[i].Coordinate, "step %d", i)
2107		require.Equal(t, step.want.Text, got[i].Text, "step %d", i)
2108		require.Equal(t, step.want.ScrollDirection, got[i].ScrollDirection, "step %d", i)
2109		require.Equal(t, step.want.ScrollAmount, got[i].ScrollAmount, "step %d", i)
2110	}
2111
2112	// Beta header was sent on every request.
2113	require.Len(t, betaHeaders, len(steps)+1)
2114	for i, h := range betaHeaders {
2115		require.Contains(t, h, "computer-use-2025-01-24", "request %d", i)
2116	}
2117}
2118
2119// TestStream_ComputerUseTool runs a multi-turn computer use session
2120// via model.Stream, verifying that the ExecutableProviderTool works
2121// through the streaming path end-to-end.
2122func TestStream_ComputerUseTool(t *testing.T) {
2123	t.Parallel()
2124
2125	type streamStep struct {
2126		input      map[string]any
2127		wantAction ComputerAction
2128	}
2129	steps := []streamStep{
2130		{input: map[string]any{"action": "screenshot"}, wantAction: ActionScreenshot},
2131		{input: map[string]any{"action": "left_click", "coordinate": []any{150, 250}}, wantAction: ActionLeftClick},
2132		{input: map[string]any{"action": "type", "text": "search query"}, wantAction: ActionType},
2133	}
2134
2135	var (
2136		requestIdx  int
2137		betaHeaders []string
2138	)
2139
2140	// streamToolUseChunks returns SSE chunks for a single
2141	// computer-use tool_use content block.
2142	streamToolUseChunks := func(id string, input map[string]any) []string {
2143		inputJSON, _ := json.Marshal(input)
2144		escaped := strings.ReplaceAll(string(inputJSON), `"`, `\"`)
2145		return []string{
2146			"event: message_start\n",
2147			`data: {"type":"message_start","message":{"id":"` + id + `","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4-20250514","stop_reason":null,"usage":{"input_tokens":10,"output_tokens":0}}}` + "\n\n",
2148			"event: content_block_start\n",
2149			`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"` + id + `","name":"computer","input":{}}}` + "\n\n",
2150			"event: content_block_delta\n",
2151			`data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"` + escaped + `"}}` + "\n\n",
2152			"event: content_block_stop\n",
2153			`data: {"type":"content_block_stop","index":0}` + "\n\n",
2154			"event: message_delta\n",
2155			`data: {"type":"message_delta","delta":{"stop_reason":"tool_use"},"usage":{"output_tokens":5}}` + "\n\n",
2156			"event: message_stop\n",
2157			`data: {"type":"message_stop"}` + "\n\n",
2158		}
2159	}
2160
2161	streamTextChunks := func() []string {
2162		return []string{
2163			"event: message_start\n",
2164			`data: {"type":"message_start","message":{"id":"msg_final","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4-20250514","stop_reason":null,"usage":{"input_tokens":10,"output_tokens":0}}}` + "\n\n",
2165			"event: content_block_start\n",
2166			`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}` + "\n\n",
2167			"event: content_block_delta\n",
2168			`data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"All done."}}` + "\n\n",
2169			"event: content_block_stop\n",
2170			`data: {"type":"content_block_stop","index":0}` + "\n\n",
2171			"event: message_delta\n",
2172			`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":10}}` + "\n\n",
2173			"event: message_stop\n",
2174			`data: {"type":"message_stop"}` + "\n\n",
2175		}
2176	}
2177
2178	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2179		betaHeaders = append(betaHeaders, r.Header.Get("Anthropic-Beta"))
2180		idx := requestIdx
2181		requestIdx++
2182
2183		w.Header().Set("Content-Type", "text/event-stream")
2184		w.Header().Set("Cache-Control", "no-cache")
2185		w.WriteHeader(http.StatusOK)
2186
2187		var chunks []string
2188		if idx < len(steps) {
2189			chunks = streamToolUseChunks(
2190				fmt.Sprintf("toolu_%02d", idx),
2191				steps[idx].input,
2192			)
2193		} else {
2194			chunks = streamTextChunks()
2195		}
2196		for _, chunk := range chunks {
2197			_, _ = fmt.Fprint(w, chunk)
2198			if f, ok := w.(http.Flusher); ok {
2199				f.Flush()
2200			}
2201		}
2202	}))
2203	defer server.Close()
2204
2205	provider, err := New(WithAPIKey("test-api-key"), WithBaseURL(server.URL))
2206	require.NoError(t, err)
2207
2208	model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
2209	require.NoError(t, err)
2210
2211	cuTool := NewComputerUseTool(ComputerUseToolOptions{
2212		DisplayWidthPx:  1920,
2213		DisplayHeightPx: 1080,
2214		ToolVersion:     ComputerUse20250124,
2215	}, noopComputerRun)
2216
2217	var gotActions []ComputerAction
2218	prompt := testPrompt()
2219	fakePNG := []byte("fake-screenshot-png")
2220
2221	for turn := 0; turn <= len(steps); turn++ {
2222		stream, err := model.Stream(context.Background(), fantasy.Call{
2223			Prompt: prompt,
2224			Tools:  []fantasy.Tool{cuTool},
2225		})
2226		require.NoError(t, err, "turn %d", turn)
2227
2228		var (
2229			toolCallName  string
2230			toolCallID    string
2231			toolCallInput string
2232			finishReason  fantasy.FinishReason
2233			gotText       string
2234		)
2235		stream(func(part fantasy.StreamPart) bool {
2236			switch part.Type {
2237			case fantasy.StreamPartTypeToolCall:
2238				toolCallName = part.ToolCallName
2239				toolCallID = part.ID
2240				toolCallInput = part.ToolCallInput
2241			case fantasy.StreamPartTypeFinish:
2242				finishReason = part.FinishReason
2243			case fantasy.StreamPartTypeTextDelta:
2244				gotText += part.Delta
2245			}
2246			return true
2247		})
2248
2249		if finishReason != fantasy.FinishReasonToolCalls {
2250			require.Contains(t, gotText, "All done")
2251			break
2252		}
2253
2254		require.Equal(t, "computer", toolCallName, "turn %d", turn)
2255
2256		parsed, err := ParseComputerUseInput(toolCallInput)
2257		require.NoError(t, err, "turn %d", turn)
2258		gotActions = append(gotActions, parsed.Action)
2259
2260		prompt = append(prompt,
2261			fantasy.Message{
2262				Role: fantasy.MessageRoleAssistant,
2263				Content: []fantasy.MessagePart{
2264					fantasy.ToolCallPart{
2265						ToolCallID: toolCallID,
2266						ToolName:   toolCallName,
2267						Input:      toolCallInput,
2268					},
2269				},
2270			},
2271			fantasy.Message{
2272				// Use MessageRoleTool for tool results — this matches
2273				// what the agent loop produces.
2274				Role: fantasy.MessageRoleTool,
2275				Content: []fantasy.MessagePart{
2276					NewComputerUseScreenshotResult(toolCallID, fakePNG),
2277				},
2278			},
2279		)
2280	}
2281
2282	require.Len(t, gotActions, len(steps))
2283	for i, step := range steps {
2284		require.Equal(t, step.wantAction, gotActions[i], "step %d", i)
2285	}
2286
2287	require.Len(t, betaHeaders, len(steps)+1)
2288	for i, h := range betaHeaders {
2289		require.Contains(t, h, "computer-use-2025-01-24", "request %d", i)
2290	}
2291}