anthropic_test.go

   1package anthropic
   2
   3import (
   4	"context"
   5	"encoding/json"
   6	"errors"
   7	"fmt"
   8	"math"
   9	"net/http"
  10	"net/http/httptest"
  11	"testing"
  12	"time"
  13
  14	"charm.land/fantasy"
  15	"github.com/charmbracelet/anthropic-sdk-go"
  16	"github.com/stretchr/testify/require"
  17)
  18
  19func TestToPrompt_DropsEmptyMessages(t *testing.T) {
  20	t.Parallel()
  21
  22	t.Run("should drop assistant messages with only reasoning content", func(t *testing.T) {
  23		t.Parallel()
  24
  25		prompt := fantasy.Prompt{
  26			{
  27				Role: fantasy.MessageRoleUser,
  28				Content: []fantasy.MessagePart{
  29					fantasy.TextPart{Text: "Hello"},
  30				},
  31			},
  32			{
  33				Role: fantasy.MessageRoleAssistant,
  34				Content: []fantasy.MessagePart{
  35					fantasy.ReasoningPart{
  36						Text: "Let me think about this...",
  37						ProviderOptions: fantasy.ProviderOptions{
  38							Name: &ReasoningOptionMetadata{
  39								Signature: "abc123",
  40							},
  41						},
  42					},
  43				},
  44			},
  45		}
  46
  47		systemBlocks, messages, warnings := toPrompt(prompt, true)
  48
  49		require.Empty(t, systemBlocks)
  50		require.Len(t, messages, 1, "should only have user message, assistant message should be dropped")
  51		require.Len(t, warnings, 1)
  52		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
  53		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
  54		require.Contains(t, warnings[0].Message, "neither user-facing content nor tool calls")
  55	})
  56
  57	t.Run("should drop assistant reasoning when sendReasoning disabled", func(t *testing.T) {
  58		t.Parallel()
  59
  60		prompt := fantasy.Prompt{
  61			{
  62				Role: fantasy.MessageRoleUser,
  63				Content: []fantasy.MessagePart{
  64					fantasy.TextPart{Text: "Hello"},
  65				},
  66			},
  67			{
  68				Role: fantasy.MessageRoleAssistant,
  69				Content: []fantasy.MessagePart{
  70					fantasy.ReasoningPart{
  71						Text: "Let me think about this...",
  72						ProviderOptions: fantasy.ProviderOptions{
  73							Name: &ReasoningOptionMetadata{
  74								Signature: "def456",
  75							},
  76						},
  77					},
  78				},
  79			},
  80		}
  81
  82		systemBlocks, messages, warnings := toPrompt(prompt, false)
  83
  84		require.Empty(t, systemBlocks)
  85		require.Len(t, messages, 1, "should only have user message, assistant message should be dropped")
  86		require.Len(t, warnings, 2)
  87		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
  88		require.Contains(t, warnings[0].Message, "sending reasoning content is disabled")
  89		require.Equal(t, fantasy.CallWarningTypeOther, warnings[1].Type)
  90		require.Contains(t, warnings[1].Message, "dropping empty assistant message")
  91	})
  92
  93	t.Run("should drop truly empty assistant messages", func(t *testing.T) {
  94		t.Parallel()
  95
  96		prompt := fantasy.Prompt{
  97			{
  98				Role: fantasy.MessageRoleUser,
  99				Content: []fantasy.MessagePart{
 100					fantasy.TextPart{Text: "Hello"},
 101				},
 102			},
 103			{
 104				Role:    fantasy.MessageRoleAssistant,
 105				Content: []fantasy.MessagePart{},
 106			},
 107		}
 108
 109		systemBlocks, messages, warnings := toPrompt(prompt, true)
 110
 111		require.Empty(t, systemBlocks)
 112		require.Len(t, messages, 1, "should only have user message")
 113		require.Len(t, warnings, 1)
 114		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
 115		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
 116	})
 117
 118	t.Run("should keep assistant messages with text content", func(t *testing.T) {
 119		t.Parallel()
 120
 121		prompt := fantasy.Prompt{
 122			{
 123				Role: fantasy.MessageRoleUser,
 124				Content: []fantasy.MessagePart{
 125					fantasy.TextPart{Text: "Hello"},
 126				},
 127			},
 128			{
 129				Role: fantasy.MessageRoleAssistant,
 130				Content: []fantasy.MessagePart{
 131					fantasy.TextPart{Text: "Hi there!"},
 132				},
 133			},
 134		}
 135
 136		systemBlocks, messages, warnings := toPrompt(prompt, true)
 137
 138		require.Empty(t, systemBlocks)
 139		require.Len(t, messages, 2, "should have both user and assistant messages")
 140		require.Empty(t, warnings)
 141	})
 142
 143	t.Run("should keep assistant messages with tool calls", func(t *testing.T) {
 144		t.Parallel()
 145
 146		prompt := fantasy.Prompt{
 147			{
 148				Role: fantasy.MessageRoleUser,
 149				Content: []fantasy.MessagePart{
 150					fantasy.TextPart{Text: "What's the weather?"},
 151				},
 152			},
 153			{
 154				Role: fantasy.MessageRoleAssistant,
 155				Content: []fantasy.MessagePart{
 156					fantasy.ToolCallPart{
 157						ToolCallID: "call_123",
 158						ToolName:   "get_weather",
 159						Input:      `{"location":"NYC"}`,
 160					},
 161				},
 162			},
 163		}
 164
 165		systemBlocks, messages, warnings := toPrompt(prompt, true)
 166
 167		require.Empty(t, systemBlocks)
 168		require.Len(t, messages, 2, "should have both user and assistant messages")
 169		require.Empty(t, warnings)
 170	})
 171
 172	t.Run("should drop assistant messages with invalid tool input", func(t *testing.T) {
 173		t.Parallel()
 174
 175		prompt := fantasy.Prompt{
 176			{
 177				Role: fantasy.MessageRoleUser,
 178				Content: []fantasy.MessagePart{
 179					fantasy.TextPart{Text: "Hi"},
 180				},
 181			},
 182			{
 183				Role: fantasy.MessageRoleAssistant,
 184				Content: []fantasy.MessagePart{
 185					fantasy.ToolCallPart{
 186						ToolCallID: "call_123",
 187						ToolName:   "get_weather",
 188						Input:      "{not-json",
 189					},
 190				},
 191			},
 192		}
 193
 194		systemBlocks, messages, warnings := toPrompt(prompt, true)
 195
 196		require.Empty(t, systemBlocks)
 197		require.Len(t, messages, 1, "should only have user message")
 198		require.Len(t, warnings, 1)
 199		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
 200		require.Contains(t, warnings[0].Message, "dropping empty assistant message")
 201	})
 202
 203	t.Run("should keep assistant messages with reasoning and text", func(t *testing.T) {
 204		t.Parallel()
 205
 206		prompt := fantasy.Prompt{
 207			{
 208				Role: fantasy.MessageRoleUser,
 209				Content: []fantasy.MessagePart{
 210					fantasy.TextPart{Text: "Hello"},
 211				},
 212			},
 213			{
 214				Role: fantasy.MessageRoleAssistant,
 215				Content: []fantasy.MessagePart{
 216					fantasy.ReasoningPart{
 217						Text: "Let me think...",
 218						ProviderOptions: fantasy.ProviderOptions{
 219							Name: &ReasoningOptionMetadata{
 220								Signature: "abc123",
 221							},
 222						},
 223					},
 224					fantasy.TextPart{Text: "Hi there!"},
 225				},
 226			},
 227		}
 228
 229		systemBlocks, messages, warnings := toPrompt(prompt, true)
 230
 231		require.Empty(t, systemBlocks)
 232		require.Len(t, messages, 2, "should have both user and assistant messages")
 233		require.Empty(t, warnings)
 234	})
 235
 236	t.Run("should keep user messages with image content", func(t *testing.T) {
 237		t.Parallel()
 238
 239		prompt := fantasy.Prompt{
 240			{
 241				Role: fantasy.MessageRoleUser,
 242				Content: []fantasy.MessagePart{
 243					fantasy.FilePart{
 244						Data:      []byte{0x01, 0x02, 0x03},
 245						MediaType: "image/png",
 246					},
 247				},
 248			},
 249		}
 250
 251		systemBlocks, messages, warnings := toPrompt(prompt, true)
 252
 253		require.Empty(t, systemBlocks)
 254		require.Len(t, messages, 1)
 255		require.Empty(t, warnings)
 256	})
 257
 258	t.Run("should drop user messages without visible content", func(t *testing.T) {
 259		t.Parallel()
 260
 261		prompt := fantasy.Prompt{
 262			{
 263				Role: fantasy.MessageRoleUser,
 264				Content: []fantasy.MessagePart{
 265					fantasy.FilePart{
 266						Data:      []byte("not supported"),
 267						MediaType: "application/pdf",
 268					},
 269				},
 270			},
 271		}
 272
 273		systemBlocks, messages, warnings := toPrompt(prompt, true)
 274
 275		require.Empty(t, systemBlocks)
 276		require.Empty(t, messages)
 277		require.Len(t, warnings, 1)
 278		require.Equal(t, fantasy.CallWarningTypeOther, warnings[0].Type)
 279		require.Contains(t, warnings[0].Message, "dropping empty user message")
 280		require.Contains(t, warnings[0].Message, "neither user-facing content nor tool results")
 281	})
 282
 283	t.Run("should keep user messages with tool results", func(t *testing.T) {
 284		t.Parallel()
 285
 286		prompt := fantasy.Prompt{
 287			{
 288				Role: fantasy.MessageRoleTool,
 289				Content: []fantasy.MessagePart{
 290					fantasy.ToolResultPart{
 291						ToolCallID: "call_123",
 292						Output:     fantasy.ToolResultOutputContentText{Text: "done"},
 293					},
 294				},
 295			},
 296		}
 297
 298		systemBlocks, messages, warnings := toPrompt(prompt, true)
 299
 300		require.Empty(t, systemBlocks)
 301		require.Len(t, messages, 1)
 302		require.Empty(t, warnings)
 303	})
 304
 305	t.Run("should keep user messages with tool error results", func(t *testing.T) {
 306		t.Parallel()
 307
 308		prompt := fantasy.Prompt{
 309			{
 310				Role: fantasy.MessageRoleTool,
 311				Content: []fantasy.MessagePart{
 312					fantasy.ToolResultPart{
 313						ToolCallID: "call_456",
 314						Output:     fantasy.ToolResultOutputContentError{Error: errors.New("boom")},
 315					},
 316				},
 317			},
 318		}
 319
 320		systemBlocks, messages, warnings := toPrompt(prompt, true)
 321
 322		require.Empty(t, systemBlocks)
 323		require.Len(t, messages, 1)
 324		require.Empty(t, warnings)
 325	})
 326
 327	t.Run("should keep user messages with tool media results", func(t *testing.T) {
 328		t.Parallel()
 329
 330		prompt := fantasy.Prompt{
 331			{
 332				Role: fantasy.MessageRoleTool,
 333				Content: []fantasy.MessagePart{
 334					fantasy.ToolResultPart{
 335						ToolCallID: "call_789",
 336						Output: fantasy.ToolResultOutputContentMedia{
 337							Data:      "AQID",
 338							MediaType: "image/png",
 339						},
 340					},
 341				},
 342			},
 343		}
 344
 345		systemBlocks, messages, warnings := toPrompt(prompt, true)
 346
 347		require.Empty(t, systemBlocks)
 348		require.Len(t, messages, 1)
 349		require.Empty(t, warnings)
 350	})
 351}
 352
 353func TestParseContextTooLargeError(t *testing.T) {
 354	t.Parallel()
 355
 356	tests := []struct {
 357		name     string
 358		message  string
 359		wantErr  bool
 360		wantUsed int
 361		wantMax  int
 362	}{
 363		{
 364			name:     "matches anthropic format",
 365			message:  "prompt is too long: 202630 tokens > 200000 maximum",
 366			wantErr:  true,
 367			wantUsed: 202630,
 368			wantMax:  200000,
 369		},
 370		{
 371			name:     "matches with different numbers",
 372			message:  "prompt is too long: 150000 tokens > 128000 maximum",
 373			wantErr:  true,
 374			wantUsed: 150000,
 375			wantMax:  128000,
 376		},
 377		{
 378			name:     "matches with extra whitespace",
 379			message:  "prompt is too long:  202630  tokens  >  200000  maximum",
 380			wantErr:  true,
 381			wantUsed: 202630,
 382			wantMax:  200000,
 383		},
 384		{
 385			name:    "does not match unrelated error",
 386			message: "invalid api key",
 387			wantErr: false,
 388		},
 389		{
 390			name:    "does not match rate limit error",
 391			message: "rate limit exceeded",
 392			wantErr: false,
 393		},
 394	}
 395
 396	for _, tt := range tests {
 397		t.Run(tt.name, func(t *testing.T) {
 398			t.Parallel()
 399			providerErr := &fantasy.ProviderError{Message: tt.message}
 400			parseContextTooLargeError(tt.message, providerErr)
 401
 402			if tt.wantErr {
 403				require.True(t, providerErr.IsContextTooLarge())
 404				require.Equal(t, tt.wantUsed, providerErr.ContextUsedTokens)
 405				require.Equal(t, tt.wantMax, providerErr.ContextMaxTokens)
 406			} else {
 407				require.False(t, providerErr.IsContextTooLarge())
 408			}
 409		})
 410	}
 411}
 412
 413func TestParseOptions_Effort(t *testing.T) {
 414	t.Parallel()
 415
 416	options, err := ParseOptions(map[string]any{
 417		"send_reasoning":            true,
 418		"thinking":                  map[string]any{"budget_tokens": int64(2048)},
 419		"effort":                    "medium",
 420		"disable_parallel_tool_use": true,
 421	})
 422	require.NoError(t, err)
 423	require.NotNil(t, options.SendReasoning)
 424	require.True(t, *options.SendReasoning)
 425	require.NotNil(t, options.Thinking)
 426	require.Equal(t, int64(2048), options.Thinking.BudgetTokens)
 427	require.NotNil(t, options.Effort)
 428	require.Equal(t, EffortMedium, *options.Effort)
 429	require.NotNil(t, options.DisableParallelToolUse)
 430	require.True(t, *options.DisableParallelToolUse)
 431}
 432
 433func TestGenerate_SendsOutputConfigEffort(t *testing.T) {
 434	t.Parallel()
 435
 436	server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
 437	defer server.Close()
 438
 439	provider, err := New(
 440		WithAPIKey("test-api-key"),
 441		WithBaseURL(server.URL),
 442	)
 443	require.NoError(t, err)
 444
 445	model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
 446	require.NoError(t, err)
 447
 448	effort := EffortMedium
 449	_, err = model.Generate(context.Background(), fantasy.Call{
 450		Prompt: testPrompt(),
 451		ProviderOptions: NewProviderOptions(&ProviderOptions{
 452			Effort: &effort,
 453		}),
 454	})
 455	require.NoError(t, err)
 456
 457	call := awaitAnthropicCall(t, calls)
 458	require.Equal(t, "POST", call.method)
 459	require.Equal(t, "/v1/messages", call.path)
 460	requireAnthropicEffort(t, call.body, EffortMedium)
 461}
 462
 463func TestStream_SendsOutputConfigEffort(t *testing.T) {
 464	t.Parallel()
 465
 466	server, calls := newAnthropicStreamingServer([]string{
 467		"event: message_start\n",
 468		"data: {\"type\":\"message_start\",\"message\":{}}\n\n",
 469		"event: message_stop\n",
 470		"data: {\"type\":\"message_stop\"}\n\n",
 471	})
 472	defer server.Close()
 473
 474	provider, err := New(
 475		WithAPIKey("test-api-key"),
 476		WithBaseURL(server.URL),
 477	)
 478	require.NoError(t, err)
 479
 480	model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
 481	require.NoError(t, err)
 482
 483	effort := EffortHigh
 484	stream, err := model.Stream(context.Background(), fantasy.Call{
 485		Prompt: testPrompt(),
 486		ProviderOptions: NewProviderOptions(&ProviderOptions{
 487			Effort: &effort,
 488		}),
 489	})
 490	require.NoError(t, err)
 491
 492	stream(func(fantasy.StreamPart) bool { return true })
 493
 494	call := awaitAnthropicCall(t, calls)
 495	require.Equal(t, "POST", call.method)
 496	require.Equal(t, "/v1/messages", call.path)
 497	requireAnthropicEffort(t, call.body, EffortHigh)
 498}
 499
 500type anthropicCall struct {
 501	method string
 502	path   string
 503	body   map[string]any
 504}
 505
 506func newAnthropicJSONServer(response map[string]any) (*httptest.Server, <-chan anthropicCall) {
 507	calls := make(chan anthropicCall, 4)
 508
 509	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 510		var body map[string]any
 511		if r.Body != nil {
 512			_ = json.NewDecoder(r.Body).Decode(&body)
 513		}
 514
 515		calls <- anthropicCall{
 516			method: r.Method,
 517			path:   r.URL.Path,
 518			body:   body,
 519		}
 520
 521		w.Header().Set("Content-Type", "application/json")
 522		_ = json.NewEncoder(w).Encode(response)
 523	}))
 524
 525	return server, calls
 526}
 527
 528func newAnthropicStreamingServer(chunks []string) (*httptest.Server, <-chan anthropicCall) {
 529	calls := make(chan anthropicCall, 4)
 530
 531	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 532		var body map[string]any
 533		if r.Body != nil {
 534			_ = json.NewDecoder(r.Body).Decode(&body)
 535		}
 536
 537		calls <- anthropicCall{
 538			method: r.Method,
 539			path:   r.URL.Path,
 540			body:   body,
 541		}
 542
 543		w.Header().Set("Content-Type", "text/event-stream")
 544		w.Header().Set("Cache-Control", "no-cache")
 545		w.Header().Set("Connection", "keep-alive")
 546		w.WriteHeader(http.StatusOK)
 547
 548		for _, chunk := range chunks {
 549			_, _ = fmt.Fprint(w, chunk)
 550			if flusher, ok := w.(http.Flusher); ok {
 551				flusher.Flush()
 552			}
 553		}
 554	}))
 555
 556	return server, calls
 557}
 558
 559func awaitAnthropicCall(t *testing.T, calls <-chan anthropicCall) anthropicCall {
 560	t.Helper()
 561
 562	select {
 563	case call := <-calls:
 564		return call
 565	case <-time.After(2 * time.Second):
 566		t.Fatal("timed out waiting for Anthropic request")
 567		return anthropicCall{}
 568	}
 569}
 570
 571func assertNoAnthropicCall(t *testing.T, calls <-chan anthropicCall) {
 572	t.Helper()
 573
 574	select {
 575	case call := <-calls:
 576		t.Fatalf("expected no Anthropic API call, but got %s %s", call.method, call.path)
 577	case <-time.After(200 * time.Millisecond):
 578	}
 579}
 580
 581func requireAnthropicEffort(t *testing.T, body map[string]any, expected Effort) {
 582	t.Helper()
 583
 584	outputConfig, ok := body["output_config"].(map[string]any)
 585	thinking, ok := body["thinking"].(map[string]any)
 586	require.True(t, ok)
 587	require.Equal(t, string(expected), outputConfig["effort"])
 588	require.Equal(t, "adaptive", thinking["type"])
 589}
 590
 591func testPrompt() fantasy.Prompt {
 592	return fantasy.Prompt{
 593		{
 594			Role: fantasy.MessageRoleUser,
 595			Content: []fantasy.MessagePart{
 596				fantasy.TextPart{Text: "Hello"},
 597			},
 598		},
 599	}
 600}
 601
 602func mockAnthropicGenerateResponse() map[string]any {
 603	return map[string]any{
 604		"id":    "msg_01Test",
 605		"type":  "message",
 606		"role":  "assistant",
 607		"model": "claude-sonnet-4-20250514",
 608		"content": []any{
 609			map[string]any{
 610				"type": "text",
 611				"text": "Hi there",
 612			},
 613		},
 614		"stop_reason":   "end_turn",
 615		"stop_sequence": "",
 616		"usage": map[string]any{
 617			"cache_creation": map[string]any{
 618				"ephemeral_1h_input_tokens": 0,
 619				"ephemeral_5m_input_tokens": 0,
 620			},
 621			"cache_creation_input_tokens": 0,
 622			"cache_read_input_tokens":     0,
 623			"input_tokens":                5,
 624			"output_tokens":               2,
 625			"server_tool_use": map[string]any{
 626				"web_search_requests": 0,
 627			},
 628			"service_tier": "standard",
 629		},
 630	}
 631}
 632
 633func mockAnthropicWebSearchResponse() map[string]any {
 634	return map[string]any{
 635		"id":    "msg_01WebSearch",
 636		"type":  "message",
 637		"role":  "assistant",
 638		"model": "claude-sonnet-4-20250514",
 639		"content": []any{
 640			map[string]any{
 641				"type":   "server_tool_use",
 642				"id":     "srvtoolu_01",
 643				"name":   "web_search",
 644				"input":  map[string]any{"query": "latest AI news"},
 645				"caller": map[string]any{"type": "direct"},
 646			},
 647			map[string]any{
 648				"type":        "web_search_tool_result",
 649				"tool_use_id": "srvtoolu_01",
 650				"caller":      map[string]any{"type": "direct"},
 651				"content": []any{
 652					map[string]any{
 653						"type":              "web_search_result",
 654						"url":               "https://example.com/ai-news",
 655						"title":             "Latest AI News",
 656						"encrypted_content": "encrypted_abc123",
 657						"page_age":          "2 hours ago",
 658					},
 659					map[string]any{
 660						"type":              "web_search_result",
 661						"url":               "https://example.com/ml-update",
 662						"title":             "ML Update",
 663						"encrypted_content": "encrypted_def456",
 664						"page_age":          "",
 665					},
 666				},
 667			},
 668			map[string]any{
 669				"type": "text",
 670				"text": "Based on recent search results, here is the latest AI news.",
 671			},
 672		},
 673		"stop_reason":   "end_turn",
 674		"stop_sequence": nil,
 675		"usage": map[string]any{
 676			"input_tokens":                100,
 677			"output_tokens":               50,
 678			"cache_creation_input_tokens": 0,
 679			"cache_read_input_tokens":     0,
 680			"server_tool_use": map[string]any{
 681				"web_search_requests": 1,
 682			},
 683		},
 684	}
 685}
 686
 687func TestToPrompt_WebSearchProviderExecutedToolResults(t *testing.T) {
 688	t.Parallel()
 689
 690	prompt := fantasy.Prompt{
 691		// User message.
 692		{
 693			Role: fantasy.MessageRoleUser,
 694			Content: []fantasy.MessagePart{
 695				fantasy.TextPart{Text: "Search for the latest AI news"},
 696			},
 697		},
 698		// Assistant message with a provider-executed tool call, its
 699		// result, and trailing text. toResponseMessages routes
 700		// provider-executed results into the assistant message, so
 701		// the prompt already reflects that structure.
 702		{
 703			Role: fantasy.MessageRoleAssistant,
 704			Content: []fantasy.MessagePart{
 705				fantasy.ToolCallPart{
 706					ToolCallID:       "srvtoolu_01",
 707					ToolName:         "web_search",
 708					Input:            `{"query":"latest AI news"}`,
 709					ProviderExecuted: true,
 710				},
 711				fantasy.ToolResultPart{
 712					ToolCallID:       "srvtoolu_01",
 713					ProviderExecuted: true,
 714					ProviderOptions: fantasy.ProviderOptions{
 715						Name: &WebSearchResultMetadata{
 716							Results: []WebSearchResultItem{
 717								{
 718									URL:              "https://example.com/ai-news",
 719									Title:            "Latest AI News",
 720									EncryptedContent: "encrypted_abc123",
 721									PageAge:          "2 hours ago",
 722								},
 723								{
 724									URL:              "https://example.com/ml-update",
 725									Title:            "ML Update",
 726									EncryptedContent: "encrypted_def456",
 727								},
 728							},
 729						},
 730					},
 731				},
 732				fantasy.TextPart{Text: "Here is what I found."},
 733			},
 734		},
 735	}
 736
 737	_, messages, warnings := toPrompt(prompt, true)
 738
 739	// No warnings expected; the provider-executed result is in the
 740	// assistant message so there is no empty tool message to drop.
 741	require.Empty(t, warnings)
 742
 743	// We should have a user message and an assistant message.
 744	require.Len(t, messages, 2, "expected user + assistant messages")
 745
 746	assistantMsg := messages[1]
 747	require.Len(t, assistantMsg.Content, 3,
 748		"expected server_tool_use + web_search_tool_result + text")
 749
 750	// First content block: reconstructed server_tool_use.
 751	serverToolUse := assistantMsg.Content[0]
 752	require.NotNil(t, serverToolUse.OfServerToolUse,
 753		"first block should be a server_tool_use")
 754	require.Equal(t, "srvtoolu_01", serverToolUse.OfServerToolUse.ID)
 755	require.Equal(t, anthropic.ServerToolUseBlockParamName("web_search"),
 756		serverToolUse.OfServerToolUse.Name)
 757
 758	// Second content block: reconstructed web_search_tool_result with
 759	// encrypted_content preserved for multi-turn round-tripping.
 760	webResult := assistantMsg.Content[1]
 761	require.NotNil(t, webResult.OfWebSearchToolResult,
 762		"second block should be a web_search_tool_result")
 763	require.Equal(t, "srvtoolu_01", webResult.OfWebSearchToolResult.ToolUseID)
 764
 765	results := webResult.OfWebSearchToolResult.Content.OfWebSearchToolResultBlockItem
 766	require.Len(t, results, 2)
 767	require.Equal(t, "https://example.com/ai-news", results[0].URL)
 768	require.Equal(t, "Latest AI News", results[0].Title)
 769	require.Equal(t, "encrypted_abc123", results[0].EncryptedContent)
 770	require.Equal(t, "https://example.com/ml-update", results[1].URL)
 771	require.Equal(t, "encrypted_def456", results[1].EncryptedContent)
 772	// PageAge should be set for the first result and absent for the second.
 773	require.True(t, results[0].PageAge.Valid())
 774	require.Equal(t, "2 hours ago", results[0].PageAge.Value)
 775	require.False(t, results[1].PageAge.Valid())
 776
 777	// Third content block: plain text.
 778	require.NotNil(t, assistantMsg.Content[2].OfText)
 779	require.Equal(t, "Here is what I found.", assistantMsg.Content[2].OfText.Text)
 780}
 781
 782func TestGenerate_WebSearchResponse(t *testing.T) {
 783	t.Parallel()
 784
 785	server, calls := newAnthropicJSONServer(mockAnthropicWebSearchResponse())
 786	defer server.Close()
 787
 788	provider, err := New(
 789		WithAPIKey("test-api-key"),
 790		WithBaseURL(server.URL),
 791	)
 792	require.NoError(t, err)
 793
 794	model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
 795	require.NoError(t, err)
 796
 797	resp, err := model.Generate(context.Background(), fantasy.Call{
 798		Prompt: testPrompt(),
 799		Tools: []fantasy.Tool{
 800			WebSearchTool(nil),
 801		},
 802	})
 803	require.NoError(t, err)
 804
 805	call := awaitAnthropicCall(t, calls)
 806	require.Equal(t, "POST", call.method)
 807	require.Equal(t, "/v1/messages", call.path)
 808
 809	// Walk the response content and categorise each item.
 810	var (
 811		toolCalls   []fantasy.ToolCallContent
 812		sources     []fantasy.SourceContent
 813		toolResults []fantasy.ToolResultContent
 814		texts       []fantasy.TextContent
 815	)
 816	for _, c := range resp.Content {
 817		switch v := c.(type) {
 818		case fantasy.ToolCallContent:
 819			toolCalls = append(toolCalls, v)
 820		case fantasy.SourceContent:
 821			sources = append(sources, v)
 822		case fantasy.ToolResultContent:
 823			toolResults = append(toolResults, v)
 824		case fantasy.TextContent:
 825			texts = append(texts, v)
 826		}
 827	}
 828
 829	// ToolCallContent for the provider-executed web_search.
 830	require.Len(t, toolCalls, 1)
 831	require.True(t, toolCalls[0].ProviderExecuted)
 832	require.Equal(t, "web_search", toolCalls[0].ToolName)
 833	require.Equal(t, "srvtoolu_01", toolCalls[0].ToolCallID)
 834
 835	// SourceContent entries for each search result.
 836	require.Len(t, sources, 2)
 837	require.Equal(t, "https://example.com/ai-news", sources[0].URL)
 838	require.Equal(t, "Latest AI News", sources[0].Title)
 839	require.Equal(t, fantasy.SourceTypeURL, sources[0].SourceType)
 840	require.Equal(t, "https://example.com/ml-update", sources[1].URL)
 841	require.Equal(t, "ML Update", sources[1].Title)
 842
 843	// ToolResultContent with provider metadata preserving encrypted_content.
 844	require.Len(t, toolResults, 1)
 845	require.True(t, toolResults[0].ProviderExecuted)
 846	require.Equal(t, "web_search", toolResults[0].ToolName)
 847	require.Equal(t, "srvtoolu_01", toolResults[0].ToolCallID)
 848
 849	searchMeta, ok := toolResults[0].ProviderMetadata[Name]
 850	require.True(t, ok, "providerMetadata should contain anthropic key")
 851	webMeta, ok := searchMeta.(*WebSearchResultMetadata)
 852	require.True(t, ok, "metadata should be *WebSearchResultMetadata")
 853	require.Len(t, webMeta.Results, 2)
 854	require.Equal(t, "encrypted_abc123", webMeta.Results[0].EncryptedContent)
 855	require.Equal(t, "encrypted_def456", webMeta.Results[1].EncryptedContent)
 856	require.Equal(t, "2 hours ago", webMeta.Results[0].PageAge)
 857
 858	// TextContent with the final answer.
 859	require.Len(t, texts, 1)
 860	require.Equal(t,
 861		"Based on recent search results, here is the latest AI news.",
 862		texts[0].Text,
 863	)
 864}
 865
 866func TestGenerate_WebSearchToolInRequest(t *testing.T) {
 867	t.Parallel()
 868
 869	t.Run("basic web_search tool", func(t *testing.T) {
 870		t.Parallel()
 871
 872		server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
 873		defer server.Close()
 874
 875		provider, err := New(
 876			WithAPIKey("test-api-key"),
 877			WithBaseURL(server.URL),
 878		)
 879		require.NoError(t, err)
 880
 881		model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
 882		require.NoError(t, err)
 883
 884		_, err = model.Generate(context.Background(), fantasy.Call{
 885			Prompt: testPrompt(),
 886			Tools: []fantasy.Tool{
 887				WebSearchTool(nil),
 888			},
 889		})
 890		require.NoError(t, err)
 891
 892		call := awaitAnthropicCall(t, calls)
 893		tools, ok := call.body["tools"].([]any)
 894		require.True(t, ok, "request body should have tools array")
 895		require.Len(t, tools, 1)
 896
 897		tool, ok := tools[0].(map[string]any)
 898		require.True(t, ok)
 899		require.Equal(t, "web_search_20250305", tool["type"])
 900	})
 901
 902	t.Run("with allowed_domains and blocked_domains", func(t *testing.T) {
 903		t.Parallel()
 904
 905		server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
 906		defer server.Close()
 907
 908		provider, err := New(
 909			WithAPIKey("test-api-key"),
 910			WithBaseURL(server.URL),
 911		)
 912		require.NoError(t, err)
 913
 914		model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
 915		require.NoError(t, err)
 916
 917		_, err = model.Generate(context.Background(), fantasy.Call{
 918			Prompt: testPrompt(),
 919			Tools: []fantasy.Tool{
 920				WebSearchTool(&WebSearchToolOptions{
 921					AllowedDomains: []string{"example.com", "test.com"},
 922				}),
 923			},
 924		})
 925		require.NoError(t, err)
 926
 927		call := awaitAnthropicCall(t, calls)
 928		tools, ok := call.body["tools"].([]any)
 929		require.True(t, ok)
 930		require.Len(t, tools, 1)
 931
 932		tool, ok := tools[0].(map[string]any)
 933		require.True(t, ok)
 934		require.Equal(t, "web_search_20250305", tool["type"])
 935
 936		domains, ok := tool["allowed_domains"].([]any)
 937		require.True(t, ok, "tool should have allowed_domains")
 938		require.Len(t, domains, 2)
 939		require.Equal(t, "example.com", domains[0])
 940		require.Equal(t, "test.com", domains[1])
 941	})
 942
 943	t.Run("with max uses and user location", func(t *testing.T) {
 944		t.Parallel()
 945
 946		server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
 947		defer server.Close()
 948
 949		provider, err := New(
 950			WithAPIKey("test-api-key"),
 951			WithBaseURL(server.URL),
 952		)
 953		require.NoError(t, err)
 954
 955		model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
 956		require.NoError(t, err)
 957
 958		_, err = model.Generate(context.Background(), fantasy.Call{
 959			Prompt: testPrompt(),
 960			Tools: []fantasy.Tool{
 961				WebSearchTool(&WebSearchToolOptions{
 962					MaxUses: 5,
 963					UserLocation: &UserLocation{
 964						City:    "San Francisco",
 965						Country: "US",
 966					},
 967				}),
 968			},
 969		})
 970		require.NoError(t, err)
 971
 972		call := awaitAnthropicCall(t, calls)
 973		tools, ok := call.body["tools"].([]any)
 974		require.True(t, ok)
 975		require.Len(t, tools, 1)
 976
 977		tool, ok := tools[0].(map[string]any)
 978		require.True(t, ok)
 979		require.Equal(t, "web_search_20250305", tool["type"])
 980
 981		// max_uses is serialized as a JSON number; json.Unmarshal
 982		// into map[string]any decodes numbers as float64.
 983		maxUses, ok := tool["max_uses"].(float64)
 984		require.True(t, ok, "tool should have max_uses")
 985		require.Equal(t, float64(5), maxUses)
 986
 987		userLoc, ok := tool["user_location"].(map[string]any)
 988		require.True(t, ok, "tool should have user_location")
 989		require.Equal(t, "San Francisco", userLoc["city"])
 990		require.Equal(t, "US", userLoc["country"])
 991		require.Equal(t, "approximate", userLoc["type"])
 992	})
 993
 994	t.Run("with max uses", func(t *testing.T) {
 995		t.Parallel()
 996
 997		server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
 998		defer server.Close()
 999
1000		provider, err := New(
1001			WithAPIKey("test-api-key"),
1002			WithBaseURL(server.URL),
1003		)
1004		require.NoError(t, err)
1005
1006		model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1007		require.NoError(t, err)
1008
1009		_, err = model.Generate(context.Background(), fantasy.Call{
1010			Prompt: testPrompt(),
1011			Tools: []fantasy.Tool{
1012				WebSearchTool(&WebSearchToolOptions{
1013					MaxUses: 3,
1014				}),
1015			},
1016		})
1017		require.NoError(t, err)
1018
1019		call := awaitAnthropicCall(t, calls)
1020		tools, ok := call.body["tools"].([]any)
1021		require.True(t, ok)
1022		require.Len(t, tools, 1)
1023
1024		tool, ok := tools[0].(map[string]any)
1025		require.True(t, ok)
1026		require.Equal(t, "web_search_20250305", tool["type"])
1027
1028		maxUses, ok := tool["max_uses"].(float64)
1029		require.True(t, ok, "tool should have max_uses")
1030		require.Equal(t, float64(3), maxUses)
1031	})
1032
1033	t.Run("with json-round-tripped provider tool args", func(t *testing.T) {
1034		t.Parallel()
1035
1036		server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
1037		defer server.Close()
1038
1039		provider, err := New(
1040			WithAPIKey("test-api-key"),
1041			WithBaseURL(server.URL),
1042		)
1043		require.NoError(t, err)
1044
1045		model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1046		require.NoError(t, err)
1047
1048		baseTool := WebSearchTool(&WebSearchToolOptions{
1049			MaxUses:        7,
1050			BlockedDomains: []string{"example.com", "test.com"},
1051			UserLocation: &UserLocation{
1052				City:     "San Francisco",
1053				Region:   "CA",
1054				Country:  "US",
1055				Timezone: "America/Los_Angeles",
1056			},
1057		})
1058
1059		data, err := json.Marshal(baseTool)
1060		require.NoError(t, err)
1061
1062		var roundTripped fantasy.ProviderDefinedTool
1063		err = json.Unmarshal(data, &roundTripped)
1064		require.NoError(t, err)
1065
1066		_, err = model.Generate(context.Background(), fantasy.Call{
1067			Prompt: testPrompt(),
1068			Tools:  []fantasy.Tool{roundTripped},
1069		})
1070		require.NoError(t, err)
1071
1072		call := awaitAnthropicCall(t, calls)
1073		tools, ok := call.body["tools"].([]any)
1074		require.True(t, ok)
1075		require.Len(t, tools, 1)
1076
1077		tool, ok := tools[0].(map[string]any)
1078		require.True(t, ok)
1079		require.Equal(t, "web_search_20250305", tool["type"])
1080
1081		domains, ok := tool["blocked_domains"].([]any)
1082		require.True(t, ok, "tool should have blocked_domains")
1083		require.Len(t, domains, 2)
1084		require.Equal(t, "example.com", domains[0])
1085		require.Equal(t, "test.com", domains[1])
1086
1087		maxUses, ok := tool["max_uses"].(float64)
1088		require.True(t, ok, "tool should have max_uses")
1089		require.Equal(t, float64(7), maxUses)
1090
1091		userLoc, ok := tool["user_location"].(map[string]any)
1092		require.True(t, ok, "tool should have user_location")
1093		require.Equal(t, "San Francisco", userLoc["city"])
1094		require.Equal(t, "CA", userLoc["region"])
1095		require.Equal(t, "US", userLoc["country"])
1096		require.Equal(t, "America/Los_Angeles", userLoc["timezone"])
1097		require.Equal(t, "approximate", userLoc["type"])
1098	})
1099}
1100
1101func TestAnyToStringSlice(t *testing.T) {
1102	t.Parallel()
1103
1104	t.Run("from string slice", func(t *testing.T) {
1105		t.Parallel()
1106
1107		got := anyToStringSlice([]string{"example.com", ""})
1108		require.Equal(t, []string{"example.com", ""}, got)
1109	})
1110
1111	t.Run("from any slice filters non-strings and empty", func(t *testing.T) {
1112		t.Parallel()
1113
1114		got := anyToStringSlice([]any{"example.com", 123, "", "test.com"})
1115		require.Equal(t, []string{"example.com", "test.com"}, got)
1116	})
1117
1118	t.Run("unsupported type", func(t *testing.T) {
1119		t.Parallel()
1120
1121		got := anyToStringSlice("example.com")
1122		require.Nil(t, got)
1123	})
1124}
1125
1126func TestAnyToInt64(t *testing.T) {
1127	t.Parallel()
1128
1129	tests := []struct {
1130		name   string
1131		input  any
1132		want   int64
1133		wantOK bool
1134	}{
1135		{name: "int64", input: int64(7), want: 7, wantOK: true},
1136		{name: "float64 integer", input: float64(7), want: 7, wantOK: true},
1137		{name: "float32 integer", input: float32(9), want: 9, wantOK: true},
1138		{name: "float64 non-integer", input: float64(7.5), wantOK: false},
1139		{name: "float64 max exact int ok", input: float64(1<<53 - 1), want: 1<<53 - 1, wantOK: true},
1140		{name: "float64 over max exact int", input: float64(1 << 53), wantOK: false},
1141		{name: "json number int", input: json.Number("42"), want: 42, wantOK: true},
1142		{name: "json number float", input: json.Number("4.2"), wantOK: false},
1143		{name: "nan", input: math.NaN(), wantOK: false},
1144		{name: "inf", input: math.Inf(1), wantOK: false},
1145		{name: "uint64 overflow", input: uint64(math.MaxInt64) + 1, wantOK: false},
1146	}
1147
1148	for _, tt := range tests {
1149		t.Run(tt.name, func(t *testing.T) {
1150			got, ok := anyToInt64(tt.input)
1151			require.Equal(t, tt.wantOK, ok)
1152			if tt.wantOK {
1153				require.Equal(t, tt.want, got)
1154			}
1155		})
1156	}
1157}
1158
1159func TestAnyToUserLocation(t *testing.T) {
1160	t.Parallel()
1161
1162	t.Run("pointer passthrough", func(t *testing.T) {
1163		t.Parallel()
1164
1165		input := &UserLocation{City: "San Francisco", Country: "US"}
1166		got := anyToUserLocation(input)
1167		require.Same(t, input, got)
1168	})
1169
1170	t.Run("struct value", func(t *testing.T) {
1171		t.Parallel()
1172
1173		got := anyToUserLocation(UserLocation{City: "San Francisco", Country: "US"})
1174		require.NotNil(t, got)
1175		require.Equal(t, "San Francisco", got.City)
1176		require.Equal(t, "US", got.Country)
1177	})
1178
1179	t.Run("map value", func(t *testing.T) {
1180		t.Parallel()
1181
1182		got := anyToUserLocation(map[string]any{
1183			"city":     "San Francisco",
1184			"region":   "CA",
1185			"country":  "US",
1186			"timezone": "America/Los_Angeles",
1187			"type":     "approximate",
1188		})
1189		require.NotNil(t, got)
1190		require.Equal(t, "San Francisco", got.City)
1191		require.Equal(t, "CA", got.Region)
1192		require.Equal(t, "US", got.Country)
1193		require.Equal(t, "America/Los_Angeles", got.Timezone)
1194	})
1195
1196	t.Run("empty map", func(t *testing.T) {
1197		t.Parallel()
1198
1199		got := anyToUserLocation(map[string]any{"type": "approximate"})
1200		require.Nil(t, got)
1201	})
1202
1203	t.Run("unsupported type", func(t *testing.T) {
1204		t.Parallel()
1205
1206		got := anyToUserLocation("San Francisco")
1207		require.Nil(t, got)
1208	})
1209}
1210
1211func TestStream_WebSearchResponse(t *testing.T) {
1212	t.Parallel()
1213
1214	// Build SSE chunks that simulate a web search streaming response.
1215	// The Anthropic SDK accumulates content blocks via
1216	// acc.Accumulate(event). We read the Content and ToolUseID
1217	// directly from struct fields instead of using AsAny(), which
1218	// avoids the SDK's re-marshal limitation that previously dropped
1219	// source data.
1220	webSearchResultContent, _ := json.Marshal([]any{
1221		map[string]any{
1222			"type":              "web_search_result",
1223			"url":               "https://example.com/ai-news",
1224			"title":             "Latest AI News",
1225			"encrypted_content": "encrypted_abc123",
1226			"page_age":          "2 hours ago",
1227		},
1228	})
1229
1230	chunks := []string{
1231		// message_start
1232		"event: message_start\n",
1233		`data: {"type":"message_start","message":{"id":"msg_01WebSearch","type":"message","role":"assistant","model":"claude-sonnet-4-20250514","content":[],"stop_reason":null,"usage":{"input_tokens":100,"output_tokens":0}}}` + "\n\n",
1234		// Block 0: server_tool_use
1235		"event: content_block_start\n",
1236		`data: {"type":"content_block_start","index":0,"content_block":{"type":"server_tool_use","id":"srvtoolu_01","name":"web_search","input":{}}}` + "\n\n",
1237		"event: content_block_stop\n",
1238		`data: {"type":"content_block_stop","index":0}` + "\n\n",
1239		// Block 1: web_search_tool_result
1240		"event: content_block_start\n",
1241		`data: {"type":"content_block_start","index":1,"content_block":{"type":"web_search_tool_result","tool_use_id":"srvtoolu_01","content":` + string(webSearchResultContent) + `}}` + "\n\n",
1242		"event: content_block_stop\n",
1243		`data: {"type":"content_block_stop","index":1}` + "\n\n",
1244		// Block 2: text
1245		"event: content_block_start\n",
1246		`data: {"type":"content_block_start","index":2,"content_block":{"type":"text","text":""}}` + "\n\n",
1247		"event: content_block_delta\n",
1248		`data: {"type":"content_block_delta","index":2,"delta":{"type":"text_delta","text":"Here are the results."}}` + "\n\n",
1249		"event: content_block_stop\n",
1250		`data: {"type":"content_block_stop","index":2}` + "\n\n",
1251		// message_stop
1252		"event: message_stop\n",
1253		`data: {"type":"message_stop"}` + "\n\n",
1254	}
1255
1256	server, calls := newAnthropicStreamingServer(chunks)
1257	defer server.Close()
1258
1259	provider, err := New(
1260		WithAPIKey("test-api-key"),
1261		WithBaseURL(server.URL),
1262	)
1263	require.NoError(t, err)
1264
1265	model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
1266	require.NoError(t, err)
1267
1268	stream, err := model.Stream(context.Background(), fantasy.Call{
1269		Prompt: testPrompt(),
1270		Tools: []fantasy.Tool{
1271			WebSearchTool(nil),
1272		},
1273	})
1274	require.NoError(t, err)
1275
1276	var parts []fantasy.StreamPart
1277	stream(func(part fantasy.StreamPart) bool {
1278		parts = append(parts, part)
1279		return true
1280	})
1281
1282	_ = awaitAnthropicCall(t, calls)
1283
1284	// Collect parts by type for assertions.
1285	var (
1286		toolInputStarts []fantasy.StreamPart
1287		toolCalls       []fantasy.StreamPart
1288		toolResults     []fantasy.StreamPart
1289		sourceParts     []fantasy.StreamPart
1290		textDeltas      []fantasy.StreamPart
1291	)
1292	for _, p := range parts {
1293		switch p.Type {
1294		case fantasy.StreamPartTypeToolInputStart:
1295			toolInputStarts = append(toolInputStarts, p)
1296		case fantasy.StreamPartTypeToolCall:
1297			toolCalls = append(toolCalls, p)
1298		case fantasy.StreamPartTypeToolResult:
1299			toolResults = append(toolResults, p)
1300		case fantasy.StreamPartTypeSource:
1301			sourceParts = append(sourceParts, p)
1302		case fantasy.StreamPartTypeTextDelta:
1303			textDeltas = append(textDeltas, p)
1304		}
1305	}
1306
1307	// server_tool_use emits a ToolInputStart with ProviderExecuted.
1308	require.NotEmpty(t, toolInputStarts, "should have a tool input start")
1309	require.True(t, toolInputStarts[0].ProviderExecuted)
1310	require.Equal(t, "web_search", toolInputStarts[0].ToolCallName)
1311
1312	// server_tool_use emits a ToolCall with ProviderExecuted.
1313	require.NotEmpty(t, toolCalls, "should have a tool call")
1314	require.True(t, toolCalls[0].ProviderExecuted)
1315
1316	// web_search_tool_result always emits a ToolResult even when
1317	// the SDK drops source data. The ToolUseID comes from the raw
1318	// union field as a fallback.
1319	require.NotEmpty(t, toolResults, "should have a tool result")
1320	require.True(t, toolResults[0].ProviderExecuted)
1321	require.Equal(t, "web_search", toolResults[0].ToolCallName)
1322	require.Equal(t, "srvtoolu_01", toolResults[0].ID,
1323		"tool result ID should match the tool_use_id")
1324
1325	// Source parts are now correctly emitted by reading struct fields
1326	// directly instead of using AsAny().
1327	require.Len(t, sourceParts, 1)
1328	require.Equal(t, "https://example.com/ai-news", sourceParts[0].URL)
1329	require.Equal(t, "Latest AI News", sourceParts[0].Title)
1330	require.Equal(t, fantasy.SourceTypeURL, sourceParts[0].SourceType)
1331
1332	// Text block emits a text delta.
1333	require.NotEmpty(t, textDeltas, "should have text deltas")
1334	require.Equal(t, "Here are the results.", textDeltas[0].Delta)
1335}