anthropic_test.go

   1package anthropic
   2
   3import (
   4	"context"
   5	"encoding/json"
   6	"errors"
   7	"fmt"
   8	"math"
   9	"net/http"
  10	"net/http/httptest"
  11	"strings"
  12	"testing"
  13	"time"
  14
  15	"charm.land/fantasy"
  16	"github.com/charmbracelet/anthropic-sdk-go"
  17	"github.com/stretchr/testify/require"
  18)
  19
  20// noopComputerRun is a no-op run function for tests that only need
  21// to inspect the tool definition, not execute it.
  22var noopComputerRun = func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
  23	return fantasy.ToolResponse{}, nil
  24}
  25
  26func TestToPrompt_DropsEmptyMessages(t *testing.T) {
  27	t.Parallel()
  28
  29	t.Run("should drop assistant messages with only reasoning content", func(t *testing.T) {
  30		t.Parallel()
  31
  32		prompt := fantasy.Prompt{
  33			{
  34				Role: fantasy.MessageRoleUser,
  35				Content: []fantasy.MessagePart{
  36					fantasy.TextPart{Text: "Hello"},
  37				},
  38			},
  39			{
  40				Role: fantasy.MessageRoleAssistant,
  41				Content: []fantasy.MessagePart{
  42					fantasy.ReasoningPart{
  43						Text: "Let me think about this...",
  44						ProviderOptions: fantasy.ProviderOptions{
  45							Name: &ReasoningOptionMetadata{
  46								Signature: "abc123",
  47							},
  48						},
  49					},
  50				},
  51			},
  52		}
  53
  54		systemBlocks, messages, warnings := toPrompt(prompt, true)
  55
  56		require.Empty(t, systemBlocks)
  57		require.Len(t, messages, 1, "should only have user message, assistant message should be dropped")
  58		require.Len(t, warnings, 1)
  59		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
  60		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
  61		require.Contains(t, warnings[0].Message, "neither user-facing content nor tool calls")
  62	})
  63
  64	t.Run("should drop assistant reasoning when sendReasoning disabled", func(t *testing.T) {
  65		t.Parallel()
  66
  67		prompt := fantasy.Prompt{
  68			{
  69				Role: fantasy.MessageRoleUser,
  70				Content: []fantasy.MessagePart{
  71					fantasy.TextPart{Text: "Hello"},
  72				},
  73			},
  74			{
  75				Role: fantasy.MessageRoleAssistant,
  76				Content: []fantasy.MessagePart{
  77					fantasy.ReasoningPart{
  78						Text: "Let me think about this...",
  79						ProviderOptions: fantasy.ProviderOptions{
  80							Name: &ReasoningOptionMetadata{
  81								Signature: "def456",
  82							},
  83						},
  84					},
  85				},
  86			},
  87		}
  88
  89		systemBlocks, messages, warnings := toPrompt(prompt, false)
  90
  91		require.Empty(t, systemBlocks)
  92		require.Len(t, messages, 1, "should only have user message, assistant message should be dropped")
  93		require.Len(t, warnings, 2)
  94		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
  95		require.Contains(t, warnings[0].Message, "sending reasoning content is disabled")
  96		require.Equal(t, fantasy.CallWarningTypeOther, warnings[1].Type)
  97		require.Contains(t, warnings[1].Message, "dropping empty assistant message")
  98	})
  99
 100	t.Run("should drop truly empty assistant messages", func(t *testing.T) {
 101		t.Parallel()
 102
 103		prompt := fantasy.Prompt{
 104			{
 105				Role: fantasy.MessageRoleUser,
 106				Content: []fantasy.MessagePart{
 107					fantasy.TextPart{Text: "Hello"},
 108				},
 109			},
 110			{
 111				Role:    fantasy.MessageRoleAssistant,
 112				Content: []fantasy.MessagePart{},
 113			},
 114		}
 115
 116		systemBlocks, messages, warnings := toPrompt(prompt, true)
 117
 118		require.Empty(t, systemBlocks)
 119		require.Len(t, messages, 1, "should only have user message")
 120		require.Len(t, warnings, 1)
 121		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
 122		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
 123	})
 124
 125	t.Run("should keep assistant messages with text content", func(t *testing.T) {
 126		t.Parallel()
 127
 128		prompt := fantasy.Prompt{
 129			{
 130				Role: fantasy.MessageRoleUser,
 131				Content: []fantasy.MessagePart{
 132					fantasy.TextPart{Text: "Hello"},
 133				},
 134			},
 135			{
 136				Role: fantasy.MessageRoleAssistant,
 137				Content: []fantasy.MessagePart{
 138					fantasy.TextPart{Text: "Hi there!"},
 139				},
 140			},
 141		}
 142
 143		systemBlocks, messages, warnings := toPrompt(prompt, true)
 144
 145		require.Empty(t, systemBlocks)
 146		require.Len(t, messages, 2, "should have both user and assistant messages")
 147		require.Empty(t, warnings)
 148	})
 149
 150	t.Run("should keep assistant messages with tool calls", func(t *testing.T) {
 151		t.Parallel()
 152
 153		prompt := fantasy.Prompt{
 154			{
 155				Role: fantasy.MessageRoleUser,
 156				Content: []fantasy.MessagePart{
 157					fantasy.TextPart{Text: "What's the weather?"},
 158				},
 159			},
 160			{
 161				Role: fantasy.MessageRoleAssistant,
 162				Content: []fantasy.MessagePart{
 163					fantasy.ToolCallPart{
 164						ToolCallID: "call_123",
 165						ToolName:   "get_weather",
 166						Input:      `{"location":"NYC"}`,
 167					},
 168				},
 169			},
 170		}
 171
 172		systemBlocks, messages, warnings := toPrompt(prompt, true)
 173
 174		require.Empty(t, systemBlocks)
 175		require.Len(t, messages, 2, "should have both user and assistant messages")
 176		require.Empty(t, warnings)
 177	})
 178
 179	t.Run("should drop assistant messages with invalid tool input", func(t *testing.T) {
 180		t.Parallel()
 181
 182		prompt := fantasy.Prompt{
 183			{
 184				Role: fantasy.MessageRoleUser,
 185				Content: []fantasy.MessagePart{
 186					fantasy.TextPart{Text: "Hi"},
 187				},
 188			},
 189			{
 190				Role: fantasy.MessageRoleAssistant,
 191				Content: []fantasy.MessagePart{
 192					fantasy.ToolCallPart{
 193						ToolCallID: "call_123",
 194						ToolName:   "get_weather",
 195						Input:      "{not-json",
 196					},
 197				},
 198			},
 199		}
 200
 201		systemBlocks, messages, warnings := toPrompt(prompt, true)
 202
 203		require.Empty(t, systemBlocks)
 204		require.Len(t, messages, 1, "should only have user message")
 205		require.Len(t, warnings, 1)
 206		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
 207		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
 208	})
 209
 210	t.Run("should keep assistant messages with reasoning and text", func(t *testing.T) {
 211		t.Parallel()
 212
 213		prompt := fantasy.Prompt{
 214			{
 215				Role: fantasy.MessageRoleUser,
 216				Content: []fantasy.MessagePart{
 217					fantasy.TextPart{Text: "Hello"},
 218				},
 219			},
 220			{
 221				Role: fantasy.MessageRoleAssistant,
 222				Content: []fantasy.MessagePart{
 223					fantasy.ReasoningPart{
 224						Text: "Let me think...",
 225						ProviderOptions: fantasy.ProviderOptions{
 226							Name: &ReasoningOptionMetadata{
 227								Signature: "abc123",
 228							},
 229						},
 230					},
 231					fantasy.TextPart{Text: "Hi there!"},
 232				},
 233			},
 234		}
 235
 236		systemBlocks, messages, warnings := toPrompt(prompt, true)
 237
 238		require.Empty(t, systemBlocks)
 239		require.Len(t, messages, 2, "should have both user and assistant messages")
 240		require.Empty(t, warnings)
 241	})
 242
 243	t.Run("should keep user messages with image content", func(t *testing.T) {
 244		t.Parallel()
 245
 246		prompt := fantasy.Prompt{
 247			{
 248				Role: fantasy.MessageRoleUser,
 249				Content: []fantasy.MessagePart{
 250					fantasy.FilePart{
 251						Data:      []byte{0x01, 0x02, 0x03},
 252						MediaType: "image/png",
 253					},
 254				},
 255			},
 256		}
 257
 258		systemBlocks, messages, warnings := toPrompt(prompt, true)
 259
 260		require.Empty(t, systemBlocks)
 261		require.Len(t, messages, 1)
 262		require.Empty(t, warnings)
 263	})
 264
 265	t.Run("should drop user messages without visible content", func(t *testing.T) {
 266		t.Parallel()
 267
 268		prompt := fantasy.Prompt{
 269			{
 270				Role: fantasy.MessageRoleUser,
 271				Content: []fantasy.MessagePart{
 272					fantasy.FilePart{
 273						Data:      []byte("not supported"),
 274						MediaType: "application/pdf",
 275					},
 276				},
 277			},
 278		}
 279
 280		systemBlocks, messages, warnings := toPrompt(prompt, true)
 281
 282		require.Empty(t, systemBlocks)
 283		require.Empty(t, messages)
 284		require.Len(t, warnings, 1)
 285		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
 286		require.Contains(t, warnings[0].Message, "dropping empty user message")
 287		require.Contains(t, warnings[0].Message, "neither user-facing content nor tool results")
 288	})
 289
 290	t.Run("should keep user messages with tool results", func(t *testing.T) {
 291		t.Parallel()
 292
 293		prompt := fantasy.Prompt{
 294			{
 295				Role: fantasy.MessageRoleTool,
 296				Content: []fantasy.MessagePart{
 297					fantasy.ToolResultPart{
 298						ToolCallID: "call_123",
 299						Output:     fantasy.ToolResultOutputContentText{Text: "done"},
 300					},
 301				},
 302			},
 303		}
 304
 305		systemBlocks, messages, warnings := toPrompt(prompt, true)
 306
 307		require.Empty(t, systemBlocks)
 308		require.Len(t, messages, 1)
 309		require.Empty(t, warnings)
 310	})
 311
 312	t.Run("should keep user messages with tool error results", func(t *testing.T) {
 313		t.Parallel()
 314
 315		prompt := fantasy.Prompt{
 316			{
 317				Role: fantasy.MessageRoleTool,
 318				Content: []fantasy.MessagePart{
 319					fantasy.ToolResultPart{
 320						ToolCallID: "call_456",
 321						Output:     fantasy.ToolResultOutputContentError{Error: errors.New("boom")},
 322					},
 323				},
 324			},
 325		}
 326
 327		systemBlocks, messages, warnings := toPrompt(prompt, true)
 328
 329		require.Empty(t, systemBlocks)
 330		require.Len(t, messages, 1)
 331		require.Empty(t, warnings)
 332	})
 333
 334	t.Run("should keep user messages with tool media results", func(t *testing.T) {
 335		t.Parallel()
 336
 337		prompt := fantasy.Prompt{
 338			{
 339				Role: fantasy.MessageRoleTool,
 340				Content: []fantasy.MessagePart{
 341					fantasy.ToolResultPart{
 342						ToolCallID: "call_789",
 343						Output: fantasy.ToolResultOutputContentMedia{
 344							Data:      "AQID",
 345							MediaType: "image/png",
 346						},
 347					},
 348				},
 349			},
 350		}
 351
 352		systemBlocks, messages, warnings := toPrompt(prompt, true)
 353
 354		require.Empty(t, systemBlocks)
 355		require.Len(t, messages, 1)
 356		require.Empty(t, warnings)
 357	})
 358}
 359
 360func TestParseContextTooLargeError(t *testing.T) {
 361	t.Parallel()
 362
 363	tests := []struct {
 364		name     string
 365		message  string
 366		wantErr  bool
 367		wantUsed int
 368		wantMax  int
 369	}{
 370		{
 371			name:     "matches anthropic format",
 372			message:  "prompt is too long: 202630 tokens > 200000 maximum",
 373			wantErr:  true,
 374			wantUsed: 202630,
 375			wantMax:  200000,
 376		},
 377		{
 378			name:     "matches with different numbers",
 379			message:  "prompt is too long: 150000 tokens > 128000 maximum",
 380			wantErr:  true,
 381			wantUsed: 150000,
 382			wantMax:  128000,
 383		},
 384		{
 385			name:     "matches with extra whitespace",
 386			message:  "prompt is too long:  202630  tokens  >  200000  maximum",
 387			wantErr:  true,
 388			wantUsed: 202630,
 389			wantMax:  200000,
 390		},
 391		{
 392			name:    "does not match unrelated error",
 393			message: "invalid api key",
 394			wantErr: false,
 395		},
 396		{
 397			name:    "does not match rate limit error",
 398			message: "rate limit exceeded",
 399			wantErr: false,
 400		},
 401	}
 402
 403	for _, tt := range tests {
 404		t.Run(tt.name, func(t *testing.T) {
 405			t.Parallel()
 406			providerErr := &fantasy.ProviderError{Message: tt.message}
 407			parseContextTooLargeError(tt.message, providerErr)
 408
 409			if tt.wantErr {
 410				require.True(t, providerErr.IsContextTooLarge())
 411				require.Equal(t, tt.wantUsed, providerErr.ContextUsedTokens)
 412				require.Equal(t, tt.wantMax, providerErr.ContextMaxTokens)
 413			} else {
 414				require.False(t, providerErr.IsContextTooLarge())
 415			}
 416		})
 417	}
 418}
 419
 420func TestParseOptions_Effort(t *testing.T) {
 421	t.Parallel()
 422
 423	options, err := ParseOptions(map[string]any{
 424		"send_reasoning":            true,
 425		"thinking":                  map[string]any{"budget_tokens": int64(2048)},
 426		"effort":                    "medium",
 427		"disable_parallel_tool_use": true,
 428	})
 429	require.NoError(t, err)
 430	require.NotNil(t, options.SendReasoning)
 431	require.True(t, *options.SendReasoning)
 432	require.NotNil(t, options.Thinking)
 433	require.Equal(t, int64(2048), options.Thinking.BudgetTokens)
 434	require.NotNil(t, options.Effort)
 435	require.Equal(t, EffortMedium, *options.Effort)
 436	require.NotNil(t, options.DisableParallelToolUse)
 437	require.True(t, *options.DisableParallelToolUse)
 438}
 439
 440func TestGenerate_SendsOutputConfigEffort(t *testing.T) {
 441	t.Parallel()
 442
 443	server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
 444	defer server.Close()
 445
 446	provider, err := New(
 447		WithAPIKey("test-api-key"),
 448		WithBaseURL(server.URL),
 449	)
 450	require.NoError(t, err)
 451
 452	model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
 453	require.NoError(t, err)
 454
 455	effort := EffortMedium
 456	_, err = model.Generate(context.Background(), fantasy.Call{
 457		Prompt: testPrompt(),
 458		ProviderOptions: NewProviderOptions(&ProviderOptions{
 459			Effort: &effort,
 460		}),
 461	})
 462	require.NoError(t, err)
 463
 464	call := awaitAnthropicCall(t, calls)
 465	require.Equal(t, "POST", call.method)
 466	require.Equal(t, "/v1/messages", call.path)
 467	requireAnthropicEffort(t, call.body, EffortMedium)
 468}
 469
 470func TestStream_SendsOutputConfigEffort(t *testing.T) {
 471	t.Parallel()
 472
 473	server, calls := newAnthropicStreamingServer([]string{
 474		"event: message_start\n",
 475		"data: {\"type\":\"message_start\",\"message\":{}}\n\n",
 476		"event: message_stop\n",
 477		"data: {\"type\":\"message_stop\"}\n\n",
 478	})
 479	defer server.Close()
 480
 481	provider, err := New(
 482		WithAPIKey("test-api-key"),
 483		WithBaseURL(server.URL),
 484	)
 485	require.NoError(t, err)
 486
 487	model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
 488	require.NoError(t, err)
 489
 490	effort := EffortHigh
 491	stream, err := model.Stream(context.Background(), fantasy.Call{
 492		Prompt: testPrompt(),
 493		ProviderOptions: NewProviderOptions(&ProviderOptions{
 494			Effort: &effort,
 495		}),
 496	})
 497	require.NoError(t, err)
 498
 499	stream(func(fantasy.StreamPart) bool { return true })
 500
 501	call := awaitAnthropicCall(t, calls)
 502	require.Equal(t, "POST", call.method)
 503	require.Equal(t, "/v1/messages", call.path)
 504	requireAnthropicEffort(t, call.body, EffortHigh)
 505}
 506
 507type anthropicCall struct {
 508	method string
 509	path   string
 510	body   map[string]any
 511}
 512
 513func newAnthropicJSONServer(response map[string]any) (*httptest.Server, <-chan anthropicCall) {
 514	calls := make(chan anthropicCall, 4)
 515
 516	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 517		var body map[string]any
 518		if r.Body != nil {
 519			_ = json.NewDecoder(r.Body).Decode(&body)
 520		}
 521
 522		calls <- anthropicCall{
 523			method: r.Method,
 524			path:   r.URL.Path,
 525			body:   body,
 526		}
 527
 528		w.Header().Set("Content-Type", "application/json")
 529		_ = json.NewEncoder(w).Encode(response)
 530	}))
 531
 532	return server, calls
 533}
 534
 535func newAnthropicStreamingServer(chunks []string) (*httptest.Server, <-chan anthropicCall) {
 536	calls := make(chan anthropicCall, 4)
 537
 538	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 539		var body map[string]any
 540		if r.Body != nil {
 541			_ = json.NewDecoder(r.Body).Decode(&body)
 542		}
 543
 544		calls <- anthropicCall{
 545			method: r.Method,
 546			path:   r.URL.Path,
 547			body:   body,
 548		}
 549
 550		w.Header().Set("Content-Type", "text/event-stream")
 551		w.Header().Set("Cache-Control", "no-cache")
 552		w.Header().Set("Connection", "keep-alive")
 553		w.WriteHeader(http.StatusOK)
 554
 555		for _, chunk := range chunks {
 556			_, _ = fmt.Fprint(w, chunk)
 557			if flusher, ok := w.(http.Flusher); ok {
 558				flusher.Flush()
 559			}
 560		}
 561	}))
 562
 563	return server, calls
 564}
 565
 566func awaitAnthropicCall(t *testing.T, calls <-chan anthropicCall) anthropicCall {
 567	t.Helper()
 568
 569	select {
 570	case call := <-calls:
 571		return call
 572	case <-time.After(2 * time.Second):
 573		t.Fatal("timed out waiting for Anthropic request")
 574		return anthropicCall{}
 575	}
 576}
 577
 578func assertNoAnthropicCall(t *testing.T, calls <-chan anthropicCall) {
 579	t.Helper()
 580
 581	select {
 582	case call := <-calls:
 583		t.Fatalf("expected no Anthropic API call, but got %s %s", call.method, call.path)
 584	case <-time.After(200 * time.Millisecond):
 585	}
 586}
 587
 588func requireAnthropicEffort(t *testing.T, body map[string]any, expected Effort) {
 589	t.Helper()
 590
 591	outputConfig, ok := body["output_config"].(map[string]any)
 592	thinking, ok := body["thinking"].(map[string]any)
 593	require.True(t, ok)
 594	require.Equal(t, string(expected), outputConfig["effort"])
 595	require.Equal(t, "adaptive", thinking["type"])
 596}
 597
 598func testPrompt() fantasy.Prompt {
 599	return fantasy.Prompt{
 600		{
 601			Role: fantasy.MessageRoleUser,
 602			Content: []fantasy.MessagePart{
 603				fantasy.TextPart{Text: "Hello"},
 604			},
 605		},
 606	}
 607}
 608
 609func mockAnthropicGenerateResponse() map[string]any {
 610	return map[string]any{
 611		"id":    "msg_01Test",
 612		"type":  "message",
 613		"role":  "assistant",
 614		"model": "claude-sonnet-4-20250514",
 615		"content": []any{
 616			map[string]any{
 617				"type": "text",
 618				"text": "Hi there",
 619			},
 620		},
 621		"stop_reason":   "end_turn",
 622		"stop_sequence": "",
 623		"usage": map[string]any{
 624			"cache_creation": map[string]any{
 625				"ephemeral_1h_input_tokens": 0,
 626				"ephemeral_5m_input_tokens": 0,
 627			},
 628			"cache_creation_input_tokens": 0,
 629			"cache_read_input_tokens":     0,
 630			"input_tokens":                5,
 631			"output_tokens":               2,
 632			"server_tool_use": map[string]any{
 633				"web_search_requests": 0,
 634			},
 635			"service_tier": "standard",
 636		},
 637	}
 638}
 639
 640func mockAnthropicWebSearchResponse() map[string]any {
 641	return map[string]any{
 642		"id":    "msg_01WebSearch",
 643		"type":  "message",
 644		"role":  "assistant",
 645		"model": "claude-sonnet-4-20250514",
 646		"content": []any{
 647			map[string]any{
 648				"type":   "server_tool_use",
 649				"id":     "srvtoolu_01",
 650				"name":   "web_search",
 651				"input":  map[string]any{"query": "latest AI news"},
 652				"caller": map[string]any{"type": "direct"},
 653			},
 654			map[string]any{
 655				"type":        "web_search_tool_result",
 656				"tool_use_id": "srvtoolu_01",
 657				"caller":      map[string]any{"type": "direct"},
 658				"content": []any{
 659					map[string]any{
 660						"type":              "web_search_result",
 661						"url":               "https://example.com/ai-news",
 662						"title":             "Latest AI News",
 663						"encrypted_content": "encrypted_abc123",
 664						"page_age":          "2 hours ago",
 665					},
 666					map[string]any{
 667						"type":              "web_search_result",
 668						"url":               "https://example.com/ml-update",
 669						"title":             "ML Update",
 670						"encrypted_content": "encrypted_def456",
 671						"page_age":          "",
 672					},
 673				},
 674			},
 675			map[string]any{
 676				"type": "text",
 677				"text": "Based on recent search results, here is the latest AI news.",
 678			},
 679		},
 680		"stop_reason":   "end_turn",
 681		"stop_sequence": nil,
 682		"usage": map[string]any{
 683			"input_tokens":                100,
 684			"output_tokens":               50,
 685			"cache_creation_input_tokens": 0,
 686			"cache_read_input_tokens":     0,
 687			"server_tool_use": map[string]any{
 688				"web_search_requests": 1,
 689			},
 690		},
 691	}
 692}
 693
 694func TestToPrompt_WebSearchProviderExecutedToolResults(t *testing.T) {
 695	t.Parallel()
 696
 697	prompt := fantasy.Prompt{
 698		// User message.
 699		{
 700			Role: fantasy.MessageRoleUser,
 701			Content: []fantasy.MessagePart{
 702				fantasy.TextPart{Text: "Search for the latest AI news"},
 703			},
 704		},
 705		// Assistant message with a provider-executed tool call, its
 706		// result, and trailing text. toResponseMessages routes
 707		// provider-executed results into the assistant message, so
 708		// the prompt already reflects that structure.
 709		{
 710			Role: fantasy.MessageRoleAssistant,
 711			Content: []fantasy.MessagePart{
 712				fantasy.ToolCallPart{
 713					ToolCallID:       "srvtoolu_01",
 714					ToolName:         "web_search",
 715					Input:            `{"query":"latest AI news"}`,
 716					ProviderExecuted: true,
 717				},
 718				fantasy.ToolResultPart{
 719					ToolCallID:       "srvtoolu_01",
 720					ProviderExecuted: true,
 721					ProviderOptions: fantasy.ProviderOptions{
 722						Name: &WebSearchResultMetadata{
 723							Results: []WebSearchResultItem{
 724								{
 725									URL:              "https://example.com/ai-news",
 726									Title:            "Latest AI News",
 727									EncryptedContent: "encrypted_abc123",
 728									PageAge:          "2 hours ago",
 729								},
 730								{
 731									URL:              "https://example.com/ml-update",
 732									Title:            "ML Update",
 733									EncryptedContent: "encrypted_def456",
 734								},
 735							},
 736						},
 737					},
 738				},
 739				fantasy.TextPart{Text: "Here is what I found."},
 740			},
 741		},
 742	}
 743
 744	_, messages, warnings := toPrompt(prompt, true)
 745
 746	// No warnings expected; the provider-executed result is in the
 747	// assistant message so there is no empty tool message to drop.
 748	require.Empty(t, warnings)
 749
 750	// We should have a user message and an assistant message.
 751	require.Len(t, messages, 2, "expected user + assistant messages")
 752
 753	assistantMsg := messages[1]
 754	require.Len(t, assistantMsg.Content, 3,
 755		"expected server_tool_use + web_search_tool_result + text")
 756
 757	// First content block: reconstructed server_tool_use.
 758	serverToolUse := assistantMsg.Content[0]
 759	require.NotNil(t, serverToolUse.OfServerToolUse,
 760		"first block should be a server_tool_use")
 761	require.Equal(t, "srvtoolu_01", serverToolUse.OfServerToolUse.ID)
 762	require.Equal(t, anthropic.ServerToolUseBlockParamName("web_search"),
 763		serverToolUse.OfServerToolUse.Name)
 764
 765	// Second content block: reconstructed web_search_tool_result with
 766	// encrypted_content preserved for multi-turn round-tripping.
 767	webResult := assistantMsg.Content[1]
 768	require.NotNil(t, webResult.OfWebSearchToolResult,
 769		"second block should be a web_search_tool_result")
 770	require.Equal(t, "srvtoolu_01", webResult.OfWebSearchToolResult.ToolUseID)
 771
 772	results := webResult.OfWebSearchToolResult.Content.OfWebSearchToolResultBlockItem
 773	require.Len(t, results, 2)
 774	require.Equal(t, "https://example.com/ai-news", results[0].URL)
 775	require.Equal(t, "Latest AI News", results[0].Title)
 776	require.Equal(t, "encrypted_abc123", results[0].EncryptedContent)
 777	require.Equal(t, "https://example.com/ml-update", results[1].URL)
 778	require.Equal(t, "encrypted_def456", results[1].EncryptedContent)
 779	// PageAge should be set for the first result and absent for the second.
 780	require.True(t, results[0].PageAge.Valid())
 781	require.Equal(t, "2 hours ago", results[0].PageAge.Value)
 782	require.False(t, results[1].PageAge.Valid())
 783
 784	// Third content block: plain text.
 785	require.NotNil(t, assistantMsg.Content[2].OfText)
 786	require.Equal(t, "Here is what I found.", assistantMsg.Content[2].OfText.Text)
 787}
 788
 789func TestGenerate_WebSearchResponse(t *testing.T) {
 790	t.Parallel()
 791
 792	server, calls := newAnthropicJSONServer(mockAnthropicWebSearchResponse())
 793	defer server.Close()
 794
 795	provider, err := New(
 796		WithAPIKey("test-api-key"),
 797		WithBaseURL(server.URL),
 798	)
 799	require.NoError(t, err)
 800
 801	model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
 802	require.NoError(t, err)
 803
 804	resp, err := model.Generate(context.Background(), fantasy.Call{
 805		Prompt: testPrompt(),
 806		Tools: []fantasy.Tool{
 807			WebSearchTool(nil),
 808		},
 809	})
 810	require.NoError(t, err)
 811
 812	call := awaitAnthropicCall(t, calls)
 813	require.Equal(t, "POST", call.method)
 814	require.Equal(t, "/v1/messages", call.path)
 815
 816	// Walk the response content and categorise each item.
 817	var (
 818		toolCalls   []fantasy.ToolCallContent
 819		sources     []fantasy.SourceContent
 820		toolResults []fantasy.ToolResultContent
 821		texts       []fantasy.TextContent
 822	)
 823	for _, c := range resp.Content {
 824		switch v := c.(type) {
 825		case fantasy.ToolCallContent:
 826			toolCalls = append(toolCalls, v)
 827		case fantasy.SourceContent:
 828			sources = append(sources, v)
 829		case fantasy.ToolResultContent:
 830			toolResults = append(toolResults, v)
 831		case fantasy.TextContent:
 832			texts = append(texts, v)
 833		}
 834	}
 835
 836	// ToolCallContent for the provider-executed web_search.
 837	require.Len(t, toolCalls, 1)
 838	require.True(t, toolCalls[0].ProviderExecuted)
 839	require.Equal(t, "web_search", toolCalls[0].ToolName)
 840	require.Equal(t, "srvtoolu_01", toolCalls[0].ToolCallID)
 841
 842	// SourceContent entries for each search result.
 843	require.Len(t, sources, 2)
 844	require.Equal(t, "https://example.com/ai-news", sources[0].URL)
 845	require.Equal(t, "Latest AI News", sources[0].Title)
 846	require.Equal(t, fantasy.SourceTypeURL, sources[0].SourceType)
 847	require.Equal(t, "https://example.com/ml-update", sources[1].URL)
 848	require.Equal(t, "ML Update", sources[1].Title)
 849
 850	// ToolResultContent with provider metadata preserving encrypted_content.
 851	require.Len(t, toolResults, 1)
 852	require.True(t, toolResults[0].ProviderExecuted)
 853	require.Equal(t, "web_search", toolResults[0].ToolName)
 854	require.Equal(t, "srvtoolu_01", toolResults[0].ToolCallID)
 855
 856	searchMeta, ok := toolResults[0].ProviderMetadata[Name]
 857	require.True(t, ok, "providerMetadata should contain anthropic key")
 858	webMeta, ok := searchMeta.(*WebSearchResultMetadata)
 859	require.True(t, ok, "metadata should be *WebSearchResultMetadata")
 860	require.Len(t, webMeta.Results, 2)
 861	require.Equal(t, "encrypted_abc123", webMeta.Results[0].EncryptedContent)
 862	require.Equal(t, "encrypted_def456", webMeta.Results[1].EncryptedContent)
 863	require.Equal(t, "2 hours ago", webMeta.Results[0].PageAge)
 864
 865	// TextContent with the final answer.
 866	require.Len(t, texts, 1)
 867	require.Equal(t,
 868		"Based on recent search results, here is the latest AI news.",
 869		texts[0].Text,
 870	)
 871}
 872
 873func TestGenerate_WebSearchToolInRequest(t *testing.T) {
 874	t.Parallel()
 875
 876	t.Run("basic web_search tool", func(t *testing.T) {
 877		t.Parallel()
 878
 879		server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
 880		defer server.Close()
 881
 882		provider, err := New(
 883			WithAPIKey("test-api-key"),
 884			WithBaseURL(server.URL),
 885		)
 886		require.NoError(t, err)
 887
 888		model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
 889		require.NoError(t, err)
 890
 891		_, err = model.Generate(context.Background(), fantasy.Call{
 892			Prompt: testPrompt(),
 893			Tools: []fantasy.Tool{
 894				WebSearchTool(nil),
 895			},
 896		})
 897		require.NoError(t, err)
 898
 899		call := awaitAnthropicCall(t, calls)
 900		tools, ok := call.body["tools"].([]any)
 901		require.True(t, ok, "request body should have tools array")
 902		require.Len(t, tools, 1)
 903
 904		tool, ok := tools[0].(map[string]any)
 905		require.True(t, ok)
 906		require.Equal(t, "web_search_20250305", tool["type"])
 907	})
 908
 909	t.Run("with allowed_domains and blocked_domains", func(t *testing.T) {
 910		t.Parallel()
 911
 912		server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
 913		defer server.Close()
 914
 915		provider, err := New(
 916			WithAPIKey("test-api-key"),
 917			WithBaseURL(server.URL),
 918		)
 919		require.NoError(t, err)
 920
 921		model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
 922		require.NoError(t, err)
 923
 924		_, err = model.Generate(context.Background(), fantasy.Call{
 925			Prompt: testPrompt(),
 926			Tools: []fantasy.Tool{
 927				WebSearchTool(&WebSearchToolOptions{
 928					AllowedDomains: []string{"example.com", "test.com"},
 929				}),
 930			},
 931		})
 932		require.NoError(t, err)
 933
 934		call := awaitAnthropicCall(t, calls)
 935		tools, ok := call.body["tools"].([]any)
 936		require.True(t, ok)
 937		require.Len(t, tools, 1)
 938
 939		tool, ok := tools[0].(map[string]any)
 940		require.True(t, ok)
 941		require.Equal(t, "web_search_20250305", tool["type"])
 942
 943		domains, ok := tool["allowed_domains"].([]any)
 944		require.True(t, ok, "tool should have allowed_domains")
 945		require.Len(t, domains, 2)
 946		require.Equal(t, "example.com", domains[0])
 947		require.Equal(t, "test.com", domains[1])
 948	})
 949
 950	t.Run("with max uses and user location", func(t *testing.T) {
 951		t.Parallel()
 952
 953		server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
 954		defer server.Close()
 955
 956		provider, err := New(
 957			WithAPIKey("test-api-key"),
 958			WithBaseURL(server.URL),
 959		)
 960		require.NoError(t, err)
 961
 962		model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
 963		require.NoError(t, err)
 964
 965		_, err = model.Generate(context.Background(), fantasy.Call{
 966			Prompt: testPrompt(),
 967			Tools: []fantasy.Tool{
 968				WebSearchTool(&WebSearchToolOptions{
 969					MaxUses: 5,
 970					UserLocation: &UserLocation{
 971						City:    "San Francisco",
 972						Country: "US",
 973					},
 974				}),
 975			},
 976		})
 977		require.NoError(t, err)
 978
 979		call := awaitAnthropicCall(t, calls)
 980		tools, ok := call.body["tools"].([]any)
 981		require.True(t, ok)
 982		require.Len(t, tools, 1)
 983
 984		tool, ok := tools[0].(map[string]any)
 985		require.True(t, ok)
 986		require.Equal(t, "web_search_20250305", tool["type"])
 987
 988		// max_uses is serialized as a JSON number; json.Unmarshal
 989		// into map[string]any decodes numbers as float64.
 990		maxUses, ok := tool["max_uses"].(float64)
 991		require.True(t, ok, "tool should have max_uses")
 992		require.Equal(t, float64(5), maxUses)
 993
 994		userLoc, ok := tool["user_location"].(map[string]any)
 995		require.True(t, ok, "tool should have user_location")
 996		require.Equal(t, "San Francisco", userLoc["city"])
 997		require.Equal(t, "US", userLoc["country"])
 998		require.Equal(t, "approximate", userLoc["type"])
 999	})
1000
1001	t.Run("with max uses", func(t *testing.T) {
1002		t.Parallel()
1003
1004		server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1005		defer server.Close()
1006
1007		provider, err := New(
1008			WithAPIKey("test-api-key"),
1009			WithBaseURL(server.URL),
1010		)
1011		require.NoError(t, err)
1012
1013		model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1014		require.NoError(t, err)
1015
1016		_, err = model.Generate(context.Background(), fantasy.Call{
1017			Prompt: testPrompt(),
1018			Tools: []fantasy.Tool{
1019				WebSearchTool(&WebSearchToolOptions{
1020					MaxUses: 3,
1021				}),
1022			},
1023		})
1024		require.NoError(t, err)
1025
1026		call := awaitAnthropicCall(t, calls)
1027		tools, ok := call.body["tools"].([]any)
1028		require.True(t, ok)
1029		require.Len(t, tools, 1)
1030
1031		tool, ok := tools[0].(map[string]any)
1032		require.True(t, ok)
1033		require.Equal(t, "web_search_20250305", tool["type"])
1034
1035		maxUses, ok := tool["max_uses"].(float64)
1036		require.True(t, ok, "tool should have max_uses")
1037		require.Equal(t, float64(3), maxUses)
1038	})
1039
1040	t.Run("with json-round-tripped provider tool args", func(t *testing.T) {
1041		t.Parallel()
1042
1043		server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1044		defer server.Close()
1045
1046		provider, err := New(
1047			WithAPIKey("test-api-key"),
1048			WithBaseURL(server.URL),
1049		)
1050		require.NoError(t, err)
1051
1052		model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1053		require.NoError(t, err)
1054
1055		baseTool := WebSearchTool(&WebSearchToolOptions{
1056			MaxUses:        7,
1057			BlockedDomains: []string{"example.com", "test.com"},
1058			UserLocation: &UserLocation{
1059				City:     "San Francisco",
1060				Region:   "CA",
1061				Country:  "US",
1062				Timezone: "America/Los_Angeles",
1063			},
1064		})
1065
1066		data, err := json.Marshal(baseTool)
1067		require.NoError(t, err)
1068
1069		var roundTripped fantasy.ProviderDefinedTool
1070		err = json.Unmarshal(data, &roundTripped)
1071		require.NoError(t, err)
1072
1073		_, err = model.Generate(context.Background(), fantasy.Call{
1074			Prompt: testPrompt(),
1075			Tools:  []fantasy.Tool{roundTripped},
1076		})
1077		require.NoError(t, err)
1078
1079		call := awaitAnthropicCall(t, calls)
1080		tools, ok := call.body["tools"].([]any)
1081		require.True(t, ok)
1082		require.Len(t, tools, 1)
1083
1084		tool, ok := tools[0].(map[string]any)
1085		require.True(t, ok)
1086		require.Equal(t, "web_search_20250305", tool["type"])
1087
1088		domains, ok := tool["blocked_domains"].([]any)
1089		require.True(t, ok, "tool should have blocked_domains")
1090		require.Len(t, domains, 2)
1091		require.Equal(t, "example.com", domains[0])
1092		require.Equal(t, "test.com", domains[1])
1093
1094		maxUses, ok := tool["max_uses"].(float64)
1095		require.True(t, ok, "tool should have max_uses")
1096		require.Equal(t, float64(7), maxUses)
1097
1098		userLoc, ok := tool["user_location"].(map[string]any)
1099		require.True(t, ok, "tool should have user_location")
1100		require.Equal(t, "San Francisco", userLoc["city"])
1101		require.Equal(t, "CA", userLoc["region"])
1102		require.Equal(t, "US", userLoc["country"])
1103		require.Equal(t, "America/Los_Angeles", userLoc["timezone"])
1104		require.Equal(t, "approximate", userLoc["type"])
1105	})
1106}
1107
// TestAnyToStringSlice covers anyToStringSlice for a []string input
// (returned as-is, empty strings included), a []any input (non-string and
// empty entries dropped), and an unsupported input type (nil result).
func TestAnyToStringSlice(t *testing.T) {
	t.Parallel()

	// expect registers a parallel subtest asserting that converting in
	// yields exactly want. Values arrive as closure parameters, so each
	// subtest has its own copy.
	expect := func(name string, in any, want []string) {
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			require.Equal(t, want, anyToStringSlice(in))
		})
	}

	expect("from string slice",
		[]string{"example.com", ""},
		[]string{"example.com", ""})
	expect("from any slice filters non-strings and empty",
		[]any{"example.com", 123, "", "test.com"},
		[]string{"example.com", "test.com"})
	expect("unsupported type", "example.com", nil)
}
1132
// TestAnyToInt64 checks which numeric representations anyToInt64 accepts.
// Integral values convert; non-integral floats, floats beyond the largest
// exactly representable integer (2^53-1), NaN, +Inf, non-integer
// json.Number values, and uint64 values above math.MaxInt64 are rejected.
func TestAnyToInt64(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name   string
		input  any
		want   int64
		wantOK bool
	}{
		{name: "int64", input: int64(7), want: 7, wantOK: true},
		{name: "float64 integer", input: float64(7), want: 7, wantOK: true},
		{name: "float32 integer", input: float32(9), want: 9, wantOK: true},
		{name: "float64 non-integer", input: float64(7.5), wantOK: false},
		{name: "float64 max exact int ok", input: float64(1<<53 - 1), want: 1<<53 - 1, wantOK: true},
		{name: "float64 over max exact int", input: float64(1 << 53), wantOK: false},
		{name: "json number int", input: json.Number("42"), want: 42, wantOK: true},
		{name: "json number float", input: json.Number("4.2"), wantOK: false},
		{name: "nan", input: math.NaN(), wantOK: false},
		{name: "inf", input: math.Inf(1), wantOK: false},
		{name: "uint64 overflow", input: uint64(math.MaxInt64) + 1, wantOK: false},
	}

	for i := range tests {
		// Take a per-iteration copy so the parallel subtest closure does
		// not depend on Go 1.22 per-iteration loop variables.
		tt := tests[i]
		t.Run(tt.name, func(t *testing.T) {
			// Run in parallel like every other subtest in this file.
			t.Parallel()

			got, ok := anyToInt64(tt.input)
			require.Equal(t, tt.wantOK, ok)
			if tt.wantOK {
				require.Equal(t, tt.want, got)
			}
		})
	}
}
1165
// TestAnyToUserLocation covers the inputs anyToUserLocation understands:
// - a *UserLocation is returned unchanged (same pointer);
// - a UserLocation value and a map[string]any with string fields are
//   converted;
// - a map carrying only "type", and unsupported input types, yield nil.
func TestAnyToUserLocation(t *testing.T) {
	t.Parallel()

	t.Run("pointer passthrough", func(t *testing.T) {
		t.Parallel()

		want := &UserLocation{City: "San Francisco", Country: "US"}
		require.Same(t, want, anyToUserLocation(want))
	})

	t.Run("struct value", func(t *testing.T) {
		t.Parallel()

		loc := anyToUserLocation(UserLocation{City: "San Francisco", Country: "US"})
		require.NotNil(t, loc)
		require.Equal(t, "San Francisco", loc.City)
		require.Equal(t, "US", loc.Country)
	})

	t.Run("map value", func(t *testing.T) {
		t.Parallel()

		fields := map[string]any{
			"city":     "San Francisco",
			"region":   "CA",
			"country":  "US",
			"timezone": "America/Los_Angeles",
			"type":     "approximate",
		}
		loc := anyToUserLocation(fields)
		require.NotNil(t, loc)
		require.Equal(t, "San Francisco", loc.City)
		require.Equal(t, "CA", loc.Region)
		require.Equal(t, "US", loc.Country)
		require.Equal(t, "America/Los_Angeles", loc.Timezone)
	})

	// expectNil registers a parallel subtest asserting that input does
	// not convert to a location.
	expectNil := func(name string, input any) {
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			require.Nil(t, anyToUserLocation(input))
		})
	}

	expectNil("empty map", map[string]any{"type": "approximate"})
	expectNil("unsupported type", "San Francisco")
}
1217
// TestStream_WebSearchResponse streams a scripted web search exchange and
// checks the stream parts emitted for each content block. The exchange is
// a server_tool_use block, then a web_search_tool_result block, then a
// text block. The expected parts are:
// - a provider-executed tool input start and tool call;
// - a provider-executed tool result whose ID is the block's tool_use_id;
// - one URL source per search hit;
// - the text delta.
func TestStream_WebSearchResponse(t *testing.T) {
	t.Parallel()

	// Content of the web_search_tool_result block: a single search hit.
	// The provider is expected to read this block's Content and ToolUseID
	// from the accumulated block's struct fields rather than through
	// AsAny(); the AsAny() re-marshal previously dropped the source data.
	// The source assertions below guard against that regression.
	webSearchResultContent, err := json.Marshal([]any{
		map[string]any{
			"type":              "web_search_result",
			"url":               "https://example.com/ai-news",
			"title":             "Latest AI News",
			"encrypted_content": "encrypted_abc123",
			"page_age":          "2 hours ago",
		},
	})
	require.NoError(t, err)

	chunks := []string{
		// message_start
		"event: message_start\n",
		`data: {"type":"message_start","message":{"id":"msg_01WebSearch","type":"message","role":"assistant","model":"claude-sonnet-4-20250514","content":[],"stop_reason":null,"usage":{"input_tokens":100,"output_tokens":0}}}` + "\n\n",
		// Block 0: server_tool_use
		"event: content_block_start\n",
		`data: {"type":"content_block_start","index":0,"content_block":{"type":"server_tool_use","id":"srvtoolu_01","name":"web_search","input":{}}}` + "\n\n",
		"event: content_block_stop\n",
		`data: {"type":"content_block_stop","index":0}` + "\n\n",
		// Block 1: web_search_tool_result
		"event: content_block_start\n",
		`data: {"type":"content_block_start","index":1,"content_block":{"type":"web_search_tool_result","tool_use_id":"srvtoolu_01","content":` + string(webSearchResultContent) + `}}` + "\n\n",
		"event: content_block_stop\n",
		`data: {"type":"content_block_stop","index":1}` + "\n\n",
		// Block 2: text
		"event: content_block_start\n",
		`data: {"type":"content_block_start","index":2,"content_block":{"type":"text","text":""}}` + "\n\n",
		"event: content_block_delta\n",
		`data: {"type":"content_block_delta","index":2,"delta":{"type":"text_delta","text":"Here are the results."}}` + "\n\n",
		"event: content_block_stop\n",
		`data: {"type":"content_block_stop","index":2}` + "\n\n",
		// message_stop
		"event: message_stop\n",
		`data: {"type":"message_stop"}` + "\n\n",
	}

	server, calls := newAnthropicStreamingServer(chunks)
	defer server.Close()

	provider, err := New(
		WithAPIKey("test-api-key"),
		WithBaseURL(server.URL),
	)
	require.NoError(t, err)

	model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
	require.NoError(t, err)

	stream, err := model.Stream(context.Background(), fantasy.Call{
		Prompt: testPrompt(),
		Tools: []fantasy.Tool{
			WebSearchTool(nil),
		},
	})
	require.NoError(t, err)

	var parts []fantasy.StreamPart
	stream(func(part fantasy.StreamPart) bool {
		parts = append(parts, part)
		return true
	})

	_ = awaitAnthropicCall(t, calls)

	// Collect parts by type for assertions.
	var (
		toolInputStarts []fantasy.StreamPart
		toolCalls       []fantasy.StreamPart
		toolResults     []fantasy.StreamPart
		sourceParts     []fantasy.StreamPart
		textDeltas      []fantasy.StreamPart
	)
	for _, p := range parts {
		switch p.Type {
		case fantasy.StreamPartTypeToolInputStart:
			toolInputStarts = append(toolInputStarts, p)
		case fantasy.StreamPartTypeToolCall:
			toolCalls = append(toolCalls, p)
		case fantasy.StreamPartTypeToolResult:
			toolResults = append(toolResults, p)
		case fantasy.StreamPartTypeSource:
			sourceParts = append(sourceParts, p)
		case fantasy.StreamPartTypeTextDelta:
			textDeltas = append(textDeltas, p)
		}
	}

	// server_tool_use emits a ToolInputStart with ProviderExecuted.
	require.NotEmpty(t, toolInputStarts, "should have a tool input start")
	require.True(t, toolInputStarts[0].ProviderExecuted)
	require.Equal(t, "web_search", toolInputStarts[0].ToolCallName)

	// server_tool_use emits a ToolCall with ProviderExecuted.
	require.NotEmpty(t, toolCalls, "should have a tool call")
	require.True(t, toolCalls[0].ProviderExecuted)

	// web_search_tool_result emits a provider-executed ToolResult whose
	// ID is the block's tool_use_id.
	require.NotEmpty(t, toolResults, "should have a tool result")
	require.True(t, toolResults[0].ProviderExecuted)
	require.Equal(t, "web_search", toolResults[0].ToolCallName)
	require.Equal(t, "srvtoolu_01", toolResults[0].ID,
		"tool result ID should match the tool_use_id")

	// Each search hit in the result content becomes one URL source.
	require.Len(t, sourceParts, 1)
	require.Equal(t, "https://example.com/ai-news", sourceParts[0].URL)
	require.Equal(t, "Latest AI News", sourceParts[0].Title)
	require.Equal(t, fantasy.SourceTypeURL, sourceParts[0].SourceType)

	// Text block emits a text delta.
	require.NotEmpty(t, textDeltas, "should have text deltas")
	require.Equal(t, "Here are the results.", textDeltas[0].Delta)
}
1343
// TestGenerate_ToolChoiceNone verifies that fantasy.ToolChoiceNone is
// sent as {"type": "none"} in the request's tool_choice field.
func TestGenerate_ToolChoiceNone(t *testing.T) {
	t.Parallel()

	srv, requests := newAnthropicJSONServer(mockAnthropicGenerateResponse())
	defer srv.Close()

	p, err := New(WithAPIKey("test-api-key"), WithBaseURL(srv.URL))
	require.NoError(t, err)

	lm, err := p.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
	require.NoError(t, err)

	none := fantasy.ToolChoiceNone
	call := fantasy.Call{
		Prompt:     testPrompt(),
		Tools:      []fantasy.Tool{WebSearchTool(nil)},
		ToolChoice: &none,
	}
	_, err = lm.Generate(context.Background(), call)
	require.NoError(t, err)

	req := awaitAnthropicCall(t, requests)
	choice, ok := req.body["tool_choice"].(map[string]any)
	require.True(t, ok, "request body should have tool_choice")
	require.Equal(t, "none", choice["type"], "tool_choice should be 'none'")
}
1374
1375// --- Computer Use Tests ---
1376
// jsonRoundTripTool returns tool's ProviderDefinedTool definition after
// passing its Args through json.Marshal and json.Unmarshal. Numeric args
// that NewComputerUseTool stores as int64 therefore come back as float64,
// which is the shape a definition has once it has been serialized and
// decoded. The int64 shape produced by calling Definition() directly is
// exercised separately by the tests that skip this helper.
func jsonRoundTripTool(t *testing.T, tool fantasy.ExecutableProviderTool) fantasy.ProviderDefinedTool {
	t.Helper()

	def := tool.Definition()
	raw, err := json.Marshal(def.Args)
	require.NoError(t, err, "marshal tool args")

	var decoded map[string]any
	require.NoError(t, json.Unmarshal(raw, &decoded), "unmarshal tool args")

	// Replace (not mutate) Args so the original definition's map is
	// left untouched.
	def.Args = decoded
	return def
}
1392
// TestNewComputerUseTool checks the definition produced by
// NewComputerUseTool:
// - the fixed ID and name;
// - the required display and version args, stored as int64 and string;
// - the optional args, which appear only when their options are set.
func TestNewComputerUseTool(t *testing.T) {
	t.Parallel()

	t.Run("creates tool with correct ID and name", func(t *testing.T) {
		t.Parallel()

		def := NewComputerUseTool(ComputerUseToolOptions{
			DisplayWidthPx:  1920,
			DisplayHeightPx: 1080,
			ToolVersion:     ComputerUse20250124,
		}, noopComputerRun).Definition()

		require.Equal(t, "anthropic.computer", def.ID)
		require.Equal(t, "computer", def.Name)

		args := def.Args
		require.Equal(t, int64(1920), args["display_width_px"])
		require.Equal(t, int64(1080), args["display_height_px"])
		require.Equal(t, string(ComputerUse20250124), args["tool_version"])
	})

	t.Run("includes optional fields when set", func(t *testing.T) {
		t.Parallel()

		display := int64(1)
		zoom := true
		def := NewComputerUseTool(ComputerUseToolOptions{
			DisplayWidthPx:  1024,
			DisplayHeightPx: 768,
			DisplayNumber:   &display,
			EnableZoom:      &zoom,
			ToolVersion:     ComputerUse20251124,
			CacheControl:    &CacheControl{Type: "ephemeral"},
		}, noopComputerRun).Definition()

		args := def.Args
		require.Equal(t, int64(1), args["display_number"])
		require.Equal(t, true, args["enable_zoom"])
		require.NotNil(t, args["cache_control"])
	})

	t.Run("omits optional fields when nil", func(t *testing.T) {
		t.Parallel()

		def := NewComputerUseTool(ComputerUseToolOptions{
			DisplayWidthPx:  1920,
			DisplayHeightPx: 1080,
			ToolVersion:     ComputerUse20250124,
		}, noopComputerRun).Definition()

		for _, key := range []string{"display_number", "enable_zoom", "cache_control"} {
			_, present := def.Args[key]
			require.False(t, present)
		}
	})
}
1442
// TestIsComputerUseTool checks that IsComputerUseTool recognizes only the
// computer use definition, not function tools or other provider tools.
func TestIsComputerUseTool(t *testing.T) {
	t.Parallel()

	t.Run("returns true for computer use tool", func(t *testing.T) {
		t.Parallel()

		def := NewComputerUseTool(ComputerUseToolOptions{
			DisplayWidthPx:  1920,
			DisplayHeightPx: 1080,
			ToolVersion:     ComputerUse20250124,
		}, noopComputerRun).Definition()
		require.True(t, IsComputerUseTool(def))
	})

	t.Run("returns false for function tool", func(t *testing.T) {
		t.Parallel()

		require.False(t, IsComputerUseTool(fantasy.FunctionTool{
			Name:        "test",
			Description: "test tool",
		}))
	})

	t.Run("returns false for other provider defined tool", func(t *testing.T) {
		t.Parallel()

		require.False(t, IsComputerUseTool(fantasy.ProviderDefinedTool{
			ID:   "other.tool",
			Name: "other",
		}))
	})
}
1474
// TestNeedsBetaAPI checks the beta flags returned by toTools. There are
// none for nil, empty, or function-only tool lists, and some once a
// (JSON-decoded) computer use tool is included.
func TestNeedsBetaAPI(t *testing.T) {
	t.Parallel()

	model := languageModel{options: options{}}

	t.Run("returns false for empty tools", func(t *testing.T) {
		t.Parallel()

		// Both a nil and a non-nil empty tool list.
		for _, tools := range [][]fantasy.Tool{nil, {}} {
			_, _, _, flags := model.toTools(tools, nil, false)
			require.Empty(t, flags)
		}
	})

	t.Run("returns false for only function tools", func(t *testing.T) {
		t.Parallel()

		_, _, _, flags := model.toTools([]fantasy.Tool{
			fantasy.FunctionTool{Name: "test"},
		}, nil, false)
		require.Empty(t, flags)
	})

	t.Run("returns beta flags when computer use tool present", func(t *testing.T) {
		t.Parallel()

		computer := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
			DisplayWidthPx:  1920,
			DisplayHeightPx: 1080,
			ToolVersion:     ComputerUse20250124,
		}, noopComputerRun))

		_, _, _, flags := model.toTools([]fantasy.Tool{
			fantasy.FunctionTool{Name: "test"},
			computer,
		}, nil, false)
		require.NotEmpty(t, flags)
	})
}
1512
// TestComputerUseToolJSON checks the raw JSON built by computerUseToolJSON:
// - the type/name and display fields for each supported tool version;
// - that Args work in both the JSON-decoded (float64) and directly
//   constructed (int64) shapes;
// - the errors returned for a missing or unsupported tool_version.
func TestComputerUseToolJSON(t *testing.T) {
	t.Parallel()

	t.Run("builds JSON for version 20250124", func(t *testing.T) {
		t.Parallel()
		cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
			DisplayWidthPx:  1920,
			DisplayHeightPx: 1080,
			ToolVersion:     ComputerUse20250124,
		}, noopComputerRun))
		data, err := computerUseToolJSON(cuTool)
		require.NoError(t, err)
		var m map[string]any
		require.NoError(t, json.Unmarshal(data, &m))
		require.Equal(t, "computer_20250124", m["type"])
		require.Equal(t, "computer", m["name"])
		require.InDelta(t, 1920, m["display_width_px"], 0)
		require.InDelta(t, 1080, m["display_height_px"], 0)
	})

	t.Run("builds JSON for version 20251124 with enable_zoom", func(t *testing.T) {
		t.Parallel()
		enableZoom := true
		cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
			DisplayWidthPx:  1024,
			DisplayHeightPx: 768,
			EnableZoom:      &enableZoom,
			ToolVersion:     ComputerUse20251124,
		}, noopComputerRun))
		data, err := computerUseToolJSON(cuTool)
		require.NoError(t, err)
		var m map[string]any
		require.NoError(t, json.Unmarshal(data, &m))
		require.Equal(t, "computer_20251124", m["type"])
		require.Equal(t, true, m["enable_zoom"])
	})

	t.Run("handles int64 args without JSON round-trip", func(t *testing.T) {
		t.Parallel()
		// Direct construction stores int64 values.
		cuTool := NewComputerUseTool(ComputerUseToolOptions{
			DisplayWidthPx:  1920,
			DisplayHeightPx: 1080,
			ToolVersion:     ComputerUse20250124,
		}, noopComputerRun)
		data, err := computerUseToolJSON(cuTool.Definition())
		require.NoError(t, err)
		var m map[string]any
		require.NoError(t, json.Unmarshal(data, &m))
		require.InDelta(t, 1920, m["display_width_px"], 0)
	})

	t.Run("returns error when version is missing", func(t *testing.T) {
		t.Parallel()
		pdt := fantasy.ProviderDefinedTool{
			ID:   "anthropic.computer",
			Name: "computer",
			Args: map[string]any{
				"display_width_px":  float64(1920),
				"display_height_px": float64(1080),
			},
		}
		_, err := computerUseToolJSON(pdt)
		require.Error(t, err)
		require.Contains(t, err.Error(), "tool_version arg is missing")
	})

	t.Run("returns error for unsupported version", func(t *testing.T) {
		t.Parallel()
		pdt := fantasy.ProviderDefinedTool{
			ID:   "anthropic.computer",
			Name: "computer",
			Args: map[string]any{
				"display_width_px":  float64(1920),
				"display_height_px": float64(1080),
				"tool_version":      "computer_99991231",
			},
		}
		_, err := computerUseToolJSON(pdt)
		require.Error(t, err)
		require.Contains(t, err.Error(), "unsupported")
	})
}
1595
// TestParseComputerUseInput_CoordinateValidation checks that fixed-size
// array fields (coordinate, start_coordinate, region) with the wrong
// number of elements are rejected with an error naming the field. It
// also checks that well-formed or absent arrays parse.
func TestParseComputerUseInput_CoordinateValidation(t *testing.T) {
	t.Parallel()

	// rejects registers a parallel subtest asserting that parsing input
	// fails with an error mentioning field.
	rejects := func(name, input, field string) {
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			_, err := ParseComputerUseInput(input)
			require.Error(t, err)
			require.Contains(t, err.Error(), field)
		})
	}

	rejects("rejects coordinate with 1 element",
		`{"action":"left_click","coordinate":[100]}`, "coordinate")
	rejects("rejects coordinate with 3 elements",
		`{"action":"left_click","coordinate":[100,200,300]}`, "coordinate")
	rejects("rejects start_coordinate with 1 element",
		`{"action":"left_click_drag","coordinate":[100,200],"start_coordinate":[50]}`, "start_coordinate")
	rejects("rejects region with 3 elements",
		`{"action":"zoom","region":[10,20,30]}`, "region")

	t.Run("accepts valid coordinate", func(t *testing.T) {
		t.Parallel()

		parsed, err := ParseComputerUseInput(`{"action":"left_click","coordinate":[100,200]}`)
		require.NoError(t, err)
		require.Equal(t, [2]int64{100, 200}, parsed.Coordinate)
	})

	t.Run("accepts absent optional arrays", func(t *testing.T) {
		t.Parallel()

		parsed, err := ParseComputerUseInput(`{"action":"screenshot"}`)
		require.NoError(t, err)
		require.Equal(t, ActionScreenshot, parsed.Action)
	})
}
1641
1642func TestToTools_RawJSON(t *testing.T) {
1643	t.Parallel()
1644
1645	lm := languageModel{options: options{}}
1646
1647	cuTool := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
1648		DisplayWidthPx:  1920,
1649		DisplayHeightPx: 1080,
1650		ToolVersion:     ComputerUse20250124,
1651	}, noopComputerRun))
1652
1653	tools := []fantasy.Tool{
1654		fantasy.FunctionTool{
1655			Name:        "weather",
1656			Description: "Get weather",
1657			InputSchema: map[string]any{
1658				"properties": map[string]any{
1659					"location": map[string]any{"type": "string"},
1660				},
1661				"required": []string{"location"},
1662			},
1663		},
1664		WebSearchTool(nil),
1665		cuTool,
1666	}
1667
1668	rawTools, toolChoice, warnings, betaFlags := lm.toTools(tools, nil, false)
1669
1670	require.Len(t, rawTools, 3)
1671	require.Nil(t, toolChoice)
1672	require.Empty(t, warnings)
1673	require.NotEmpty(t, betaFlags)
1674
1675	// Verify each raw tool is valid JSON.
1676	for i, raw := range rawTools {
1677		var m map[string]any
1678		require.NoError(t, json.Unmarshal(raw, &m), "tool %d should be valid JSON", i)
1679	}
1680
1681	// Check function tool.
1682	var funcTool map[string]any
1683	require.NoError(t, json.Unmarshal(rawTools[0], &funcTool))
1684	require.Equal(t, "weather", funcTool["name"])
1685
1686	// Check web search tool.
1687	var webTool map[string]any
1688	require.NoError(t, json.Unmarshal(rawTools[1], &webTool))
1689	require.Equal(t, "web_search_20250305", webTool["type"])
1690
1691	// Check computer use tool.
1692	var cuToolJSON map[string]any
1693	require.NoError(t, json.Unmarshal(rawTools[2], &cuToolJSON))
1694	require.Equal(t, "computer_20250124", cuToolJSON["type"])
1695	require.Equal(t, "computer", cuToolJSON["name"])
1696}
1697
// TestGenerate_BetaAPI exercises Generate with a computer use tool. It
// checks that the Anthropic-Beta header names the matching computer-use
// beta for each tool version. It also checks that a tool_use block in the
// response becomes a tool call whose input ParseComputerUseInput accepts.
func TestGenerate_BetaAPI(t *testing.T) {
	t.Parallel()

	// expectBetaHeader registers a parallel subtest. The subtest sends one
	// Generate request carrying a computer use tool built from opts, then
	// asserts that the Anthropic-Beta request header contains wantBeta.
	expectBetaHeader := func(name string, opts ComputerUseToolOptions, wantBeta string) {
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			var seen http.Header
			srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				seen = r.Header.Clone()
				w.Header().Set("Content-Type", "application/json")
				// Best-effort write from a test handler.
				_ = json.NewEncoder(w).Encode(mockAnthropicGenerateResponse())
			}))
			defer srv.Close()

			p, err := New(WithAPIKey("test-api-key"), WithBaseURL(srv.URL))
			require.NoError(t, err)

			lm, err := p.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
			require.NoError(t, err)

			computer := jsonRoundTripTool(t, NewComputerUseTool(opts, noopComputerRun))
			_, err = lm.Generate(context.Background(), fantasy.Call{
				Prompt: testPrompt(),
				Tools:  []fantasy.Tool{computer},
			})
			require.NoError(t, err)
			require.Contains(t, seen.Get("Anthropic-Beta"), wantBeta)
		})
	}

	expectBetaHeader("sends beta header for computer use", ComputerUseToolOptions{
		DisplayWidthPx:  1920,
		DisplayHeightPx: 1080,
		ToolVersion:     ComputerUse20250124,
	}, "computer-use-2025-01-24")

	expectBetaHeader("sends beta header for computer use 20251124", ComputerUseToolOptions{
		DisplayWidthPx:  1920,
		DisplayHeightPx: 1080,
		ToolVersion:     ComputerUse20251124,
	}, "computer-use-2025-11-24")

	t.Run("returns tool use from beta response", func(t *testing.T) {
		t.Parallel()

		reply := map[string]any{
			"id":    "msg_01Test",
			"type":  "message",
			"role":  "assistant",
			"model": "claude-sonnet-4-20250514",
			"content": []any{
				map[string]any{
					"type":  "tool_use",
					"id":    "toolu_01",
					"name":  "computer",
					"input": map[string]any{"action": "screenshot"},
				},
			},
			"stop_reason": "tool_use",
			"usage": map[string]any{
				"input_tokens":  10,
				"output_tokens": 5,
				"cache_creation": map[string]any{
					"ephemeral_1h_input_tokens": 0,
					"ephemeral_5m_input_tokens": 0,
				},
				"cache_creation_input_tokens": 0,
				"cache_read_input_tokens":     0,
				"server_tool_use": map[string]any{
					"web_search_requests": 0,
				},
				"service_tier": "standard",
			},
		}
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.Header().Set("Content-Type", "application/json")
			// Best-effort write from a test handler.
			_ = json.NewEncoder(w).Encode(reply)
		}))
		defer srv.Close()

		p, err := New(WithAPIKey("test-api-key"), WithBaseURL(srv.URL))
		require.NoError(t, err)

		lm, err := p.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
		require.NoError(t, err)

		computer := jsonRoundTripTool(t, NewComputerUseTool(ComputerUseToolOptions{
			DisplayWidthPx:  1920,
			DisplayHeightPx: 1080,
			ToolVersion:     ComputerUse20250124,
		}, noopComputerRun))

		resp, err := lm.Generate(context.Background(), fantasy.Call{
			Prompt: testPrompt(),
			Tools:  []fantasy.Tool{computer},
		})
		require.NoError(t, err)

		calls := resp.Content.ToolCalls()
		require.Len(t, calls, 1)
		first := calls[0]
		require.Equal(t, "computer", first.ToolName)
		require.Equal(t, "toolu_01", first.ToolCallID)
		require.Contains(t, first.Input, "screenshot")
		require.Equal(t, fantasy.FinishReasonToolCalls, resp.FinishReason)

		// The raw input must also parse into the typed action.
		action, err := ParseComputerUseInput(first.Input)
		require.NoError(t, err)
		require.Equal(t, ActionScreenshot, action.Action)
	})
}
1840
// TestStream_BetaAPI checks that Stream with a computer use tool sends
// the Anthropic-Beta header matching each supported tool version.
func TestStream_BetaAPI(t *testing.T) {
	t.Parallel()

	// Minimal SSE transcript: a message that starts and immediately stops.
	// Only the request headers are under test here.
	sse := []string{
		"event: message_start\n",
		"data: {\"type\":\"message_start\",\"message\":{}}\n\n",
		"event: message_stop\n",
		"data: {\"type\":\"message_stop\"}\n\n",
	}

	// expectBetaHeader registers a parallel subtest. The subtest streams one
	// request carrying a computer use tool built from opts, then asserts
	// that the Anthropic-Beta request header contains wantBeta.
	expectBetaHeader := func(name string, opts ComputerUseToolOptions, wantBeta string) {
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			var seen http.Header
			srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				seen = r.Header.Clone()
				w.Header().Set("Content-Type", "text/event-stream")
				w.Header().Set("Cache-Control", "no-cache")
				w.WriteHeader(http.StatusOK)
				flusher, canFlush := w.(http.Flusher)
				for _, line := range sse {
					// Best-effort write from a test handler.
					_, _ = fmt.Fprint(w, line)
					if canFlush {
						flusher.Flush()
					}
				}
			}))
			defer srv.Close()

			p, err := New(WithAPIKey("test-api-key"), WithBaseURL(srv.URL))
			require.NoError(t, err)

			lm, err := p.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
			require.NoError(t, err)

			computer := jsonRoundTripTool(t, NewComputerUseTool(opts, noopComputerRun))
			stream, err := lm.Stream(context.Background(), fantasy.Call{
				Prompt: testPrompt(),
				Tools:  []fantasy.Tool{computer},
			})
			require.NoError(t, err)

			// Drain the stream; the parts themselves are not inspected.
			stream(func(fantasy.StreamPart) bool { return true })

			require.Contains(t, seen.Get("Anthropic-Beta"), wantBeta)
		})
	}

	expectBetaHeader("streams via beta API for computer use", ComputerUseToolOptions{
		DisplayWidthPx:  1920,
		DisplayHeightPx: 1080,
		ToolVersion:     ComputerUse20250124,
	}, "computer-use-2025-01-24")

	expectBetaHeader("streams via beta API for computer use 20251124", ComputerUseToolOptions{
		DisplayWidthPx:  1920,
		DisplayHeightPx: 1080,
		ToolVersion:     ComputerUse20251124,
	}, "computer-use-2025-11-24")
}
1944
1945// TestGenerate_ComputerUseTool runs a multi-turn computer use session
1946// via model.Generate, passing the ExecutableProviderTool directly into
1947// Call.Tools (no .Definition(), no jsonRoundTripTool). The mock server
1948// walks through a scripted sequence of actions — screenshot, click,
1949// type, key, scroll — then finishes with a text reply. Each turn the
1950// test parses the tool call, builds a screenshot result, and appends
1951// both to the prompt for the next request.
1952func TestGenerate_ComputerUseTool(t *testing.T) {
1953	t.Parallel()
1954
1955	type actionStep struct {
1956		input map[string]any
1957		want  ComputerUseInput
1958	}
1959	steps := []actionStep{
1960		{
1961			input: map[string]any{"action": "screenshot"},
1962			want:  ComputerUseInput{Action: ActionScreenshot},
1963		},
1964		{
1965			input: map[string]any{"action": "left_click", "coordinate": []any{100, 200}},
1966			want:  ComputerUseInput{Action: ActionLeftClick, Coordinate: [2]int64{100, 200}},
1967		},
1968		{
1969			input: map[string]any{"action": "type", "text": "hello world"},
1970			want:  ComputerUseInput{Action: ActionType, Text: "hello world"},
1971		},
1972		{
1973			input: map[string]any{"action": "key", "text": "Return"},
1974			want:  ComputerUseInput{Action: ActionKey, Text: "Return"},
1975		},
1976		{
1977			input: map[string]any{
1978				"action":           "scroll",
1979				"coordinate":       []any{500, 300},
1980				"scroll_direction": "down",
1981				"scroll_amount":    3,
1982			},
1983			want: ComputerUseInput{
1984				Action:          ActionScroll,
1985				Coordinate:      [2]int64{500, 300},
1986				ScrollDirection: "down",
1987				ScrollAmount:    3,
1988			},
1989		},
1990		{
1991			input: map[string]any{"action": "screenshot"},
1992			want:  ComputerUseInput{Action: ActionScreenshot},
1993		},
1994	}
1995
1996	var (
1997		requestIdx  int
1998		betaHeaders []string
1999	)
2000	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2001		betaHeaders = append(betaHeaders, r.Header.Get("Anthropic-Beta"))
2002		idx := requestIdx
2003		requestIdx++
2004
2005		w.Header().Set("Content-Type", "application/json")
2006		if idx < len(steps) {
2007			_ = json.NewEncoder(w).Encode(map[string]any{
2008				"id":    fmt.Sprintf("msg_%02d", idx),
2009				"type":  "message",
2010				"role":  "assistant",
2011				"model": "claude-sonnet-4-20250514",
2012				"content": []any{map[string]any{
2013					"type":  "tool_use",
2014					"id":    fmt.Sprintf("toolu_%02d", idx),
2015					"name":  "computer",
2016					"input": steps[idx].input,
2017				}},
2018				"stop_reason": "tool_use",
2019				"usage":       map[string]any{"input_tokens": 10, "output_tokens": 5},
2020			})
2021			return
2022		}
2023		_ = json.NewEncoder(w).Encode(map[string]any{
2024			"id":    "msg_final",
2025			"type":  "message",
2026			"role":  "assistant",
2027			"model": "claude-sonnet-4-20250514",
2028			"content": []any{map[string]any{
2029				"type": "text",
2030				"text": "Done! I have completed all the requested actions.",
2031			}},
2032			"stop_reason": "end_turn",
2033			"usage":       map[string]any{"input_tokens": 10, "output_tokens": 15},
2034		})
2035	}))
2036	defer server.Close()
2037
2038	provider, err := New(WithAPIKey("test-api-key"), WithBaseURL(server.URL))
2039	require.NoError(t, err)
2040
2041	model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
2042	require.NoError(t, err)
2043
2044	// Pass the ExecutableProviderTool directly — the whole point is
2045	// to verify that the Tool interface works without unwrapping.
2046	cuTool := NewComputerUseTool(ComputerUseToolOptions{
2047		DisplayWidthPx:  1920,
2048		DisplayHeightPx: 1080,
2049		ToolVersion:     ComputerUse20250124,
2050	}, noopComputerRun)
2051
2052	var got []ComputerUseInput
2053	prompt := testPrompt()
2054	fakePNG := []byte("fake-screenshot-png")
2055
2056	for turn := 0; turn <= len(steps); turn++ {
2057		resp, err := model.Generate(context.Background(), fantasy.Call{
2058			Prompt: prompt,
2059			Tools:  []fantasy.Tool{cuTool},
2060		})
2061		require.NoError(t, err, "turn %d", turn)
2062
2063		if resp.FinishReason != fantasy.FinishReasonToolCalls {
2064			require.Equal(t, fantasy.FinishReasonStop, resp.FinishReason)
2065			require.Contains(t, resp.Content.Text(), "Done")
2066			break
2067		}
2068
2069		toolCalls := resp.Content.ToolCalls()
2070		require.Len(t, toolCalls, 1, "turn %d", turn)
2071		require.Equal(t, "computer", toolCalls[0].ToolName, "turn %d", turn)
2072
2073		parsed, err := ParseComputerUseInput(toolCalls[0].Input)
2074		require.NoError(t, err, "turn %d", turn)
2075		got = append(got, parsed)
2076
2077		// Build the next prompt: append the assistant tool-call turn
2078		// and the user screenshot-result turn.
2079		prompt = append(prompt,
2080			fantasy.Message{
2081				Role: fantasy.MessageRoleAssistant,
2082				Content: []fantasy.MessagePart{
2083					fantasy.ToolCallPart{
2084						ToolCallID: toolCalls[0].ToolCallID,
2085						ToolName:   toolCalls[0].ToolName,
2086						Input:      toolCalls[0].Input,
2087					},
2088				},
2089			},
2090			fantasy.Message{
2091				// Use MessageRoleTool for tool results — this matches
2092				// what the agent loop produces.
2093				Role: fantasy.MessageRoleTool,
2094				Content: []fantasy.MessagePart{
2095					NewComputerUseScreenshotResult(toolCalls[0].ToolCallID, fakePNG),
2096				},
2097			},
2098		)
2099	}
2100
2101	// Every scripted action was received and parsed correctly.
2102	require.Len(t, got, len(steps))
2103	for i, step := range steps {
2104		require.Equal(t, step.want.Action, got[i].Action, "step %d", i)
2105		require.Equal(t, step.want.Coordinate, got[i].Coordinate, "step %d", i)
2106		require.Equal(t, step.want.Text, got[i].Text, "step %d", i)
2107		require.Equal(t, step.want.ScrollDirection, got[i].ScrollDirection, "step %d", i)
2108		require.Equal(t, step.want.ScrollAmount, got[i].ScrollAmount, "step %d", i)
2109	}
2110
2111	// Beta header was sent on every request.
2112	require.Len(t, betaHeaders, len(steps)+1)
2113	for i, h := range betaHeaders {
2114		require.Contains(t, h, "computer-use-2025-01-24", "request %d", i)
2115	}
2116}
2117
2118// TestStream_ComputerUseTool runs a multi-turn computer use session
2119// via model.Stream, verifying that the ExecutableProviderTool works
2120// through the streaming path end-to-end.
2121func TestStream_ComputerUseTool(t *testing.T) {
2122	t.Parallel()
2123
2124	type streamStep struct {
2125		input      map[string]any
2126		wantAction ComputerAction
2127	}
2128	steps := []streamStep{
2129		{input: map[string]any{"action": "screenshot"}, wantAction: ActionScreenshot},
2130		{input: map[string]any{"action": "left_click", "coordinate": []any{150, 250}}, wantAction: ActionLeftClick},
2131		{input: map[string]any{"action": "type", "text": "search query"}, wantAction: ActionType},
2132	}
2133
2134	var (
2135		requestIdx  int
2136		betaHeaders []string
2137	)
2138
2139	// streamToolUseChunks returns SSE chunks for a single
2140	// computer-use tool_use content block.
2141	streamToolUseChunks := func(id string, input map[string]any) []string {
2142		inputJSON, _ := json.Marshal(input)
2143		escaped := strings.ReplaceAll(string(inputJSON), `"`, `\"`)
2144		return []string{
2145			"event: message_start\n",
2146			`data: {"type":"message_start","message":{"id":"` + id + `","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4-20250514","stop_reason":null,"usage":{"input_tokens":10,"output_tokens":0}}}` + "\n\n",
2147			"event: content_block_start\n",
2148			`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"` + id + `","name":"computer","input":{}}}` + "\n\n",
2149			"event: content_block_delta\n",
2150			`data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"` + escaped + `"}}` + "\n\n",
2151			"event: content_block_stop\n",
2152			`data: {"type":"content_block_stop","index":0}` + "\n\n",
2153			"event: message_delta\n",
2154			`data: {"type":"message_delta","delta":{"stop_reason":"tool_use"},"usage":{"output_tokens":5}}` + "\n\n",
2155			"event: message_stop\n",
2156			`data: {"type":"message_stop"}` + "\n\n",
2157		}
2158	}
2159
2160	streamTextChunks := func() []string {
2161		return []string{
2162			"event: message_start\n",
2163			`data: {"type":"message_start","message":{"id":"msg_final","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4-20250514","stop_reason":null,"usage":{"input_tokens":10,"output_tokens":0}}}` + "\n\n",
2164			"event: content_block_start\n",
2165			`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}` + "\n\n",
2166			"event: content_block_delta\n",
2167			`data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"All done."}}` + "\n\n",
2168			"event: content_block_stop\n",
2169			`data: {"type":"content_block_stop","index":0}` + "\n\n",
2170			"event: message_delta\n",
2171			`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":10}}` + "\n\n",
2172			"event: message_stop\n",
2173			`data: {"type":"message_stop"}` + "\n\n",
2174		}
2175	}
2176
2177	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2178		betaHeaders = append(betaHeaders, r.Header.Get("Anthropic-Beta"))
2179		idx := requestIdx
2180		requestIdx++
2181
2182		w.Header().Set("Content-Type", "text/event-stream")
2183		w.Header().Set("Cache-Control", "no-cache")
2184		w.WriteHeader(http.StatusOK)
2185
2186		var chunks []string
2187		if idx < len(steps) {
2188			chunks = streamToolUseChunks(
2189				fmt.Sprintf("toolu_%02d", idx),
2190				steps[idx].input,
2191			)
2192		} else {
2193			chunks = streamTextChunks()
2194		}
2195		for _, chunk := range chunks {
2196			_, _ = fmt.Fprint(w, chunk)
2197			if f, ok := w.(http.Flusher); ok {
2198				f.Flush()
2199			}
2200		}
2201	}))
2202	defer server.Close()
2203
2204	provider, err := New(WithAPIKey("test-api-key"), WithBaseURL(server.URL))
2205	require.NoError(t, err)
2206
2207	model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
2208	require.NoError(t, err)
2209
2210	cuTool := NewComputerUseTool(ComputerUseToolOptions{
2211		DisplayWidthPx:  1920,
2212		DisplayHeightPx: 1080,
2213		ToolVersion:     ComputerUse20250124,
2214	}, noopComputerRun)
2215
2216	var gotActions []ComputerAction
2217	prompt := testPrompt()
2218	fakePNG := []byte("fake-screenshot-png")
2219
2220	for turn := 0; turn <= len(steps); turn++ {
2221		stream, err := model.Stream(context.Background(), fantasy.Call{
2222			Prompt: prompt,
2223			Tools:  []fantasy.Tool{cuTool},
2224		})
2225		require.NoError(t, err, "turn %d", turn)
2226
2227		var (
2228			toolCallName  string
2229			toolCallID    string
2230			toolCallInput string
2231			finishReason  fantasy.FinishReason
2232			gotText       string
2233		)
2234		stream(func(part fantasy.StreamPart) bool {
2235			switch part.Type {
2236			case fantasy.StreamPartTypeToolCall:
2237				toolCallName = part.ToolCallName
2238				toolCallID = part.ID
2239				toolCallInput = part.ToolCallInput
2240			case fantasy.StreamPartTypeFinish:
2241				finishReason = part.FinishReason
2242			case fantasy.StreamPartTypeTextDelta:
2243				gotText += part.Delta
2244			}
2245			return true
2246		})
2247
2248		if finishReason != fantasy.FinishReasonToolCalls {
2249			require.Contains(t, gotText, "All done")
2250			break
2251		}
2252
2253		require.Equal(t, "computer", toolCallName, "turn %d", turn)
2254
2255		parsed, err := ParseComputerUseInput(toolCallInput)
2256		require.NoError(t, err, "turn %d", turn)
2257		gotActions = append(gotActions, parsed.Action)
2258
2259		prompt = append(prompt,
2260			fantasy.Message{
2261				Role: fantasy.MessageRoleAssistant,
2262				Content: []fantasy.MessagePart{
2263					fantasy.ToolCallPart{
2264						ToolCallID: toolCallID,
2265						ToolName:   toolCallName,
2266						Input:      toolCallInput,
2267					},
2268				},
2269			},
2270			fantasy.Message{
2271				// Use MessageRoleTool for tool results — this matches
2272				// what the agent loop produces.
2273				Role: fantasy.MessageRoleTool,
2274				Content: []fantasy.MessagePart{
2275					NewComputerUseScreenshotResult(toolCallID, fakePNG),
2276				},
2277			},
2278		)
2279	}
2280
2281	require.Len(t, gotActions, len(steps))
2282	for i, step := range steps {
2283		require.Equal(t, step.wantAction, gotActions[i], "step %d", i)
2284	}
2285
2286	require.Len(t, betaHeaders, len(steps)+1)
2287	for i, h := range betaHeaders {
2288		require.Contains(t, h, "computer-use-2025-01-24", "request %d", i)
2289	}
2290}