From eaa111e354f3994eb07a062bdebb40875a040631 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Sat, 14 Mar 2026 22:52:09 -0400 Subject: [PATCH] chore(anthroipic): harden web_search coercion for json round-trips This guards against cases when web_search options could be silently dropped when tool args pass through generic/JSON-shaped data, so domain filters, max uses, and user location reliably reach Anthropic. --- providers/anthropic/anthropic.go | 121 ++++++++++++++++- providers/anthropic/anthropic_test.go | 178 ++++++++++++++++++++++++++ 2 files changed, 295 insertions(+), 4 deletions(-) diff --git a/providers/anthropic/anthropic.go b/providers/anthropic/anthropic.go index 61c1b1e0fa0fb1b92f5f894c46398270a7d03ac0..89d14959a2b811ec7531f21b2d7ff58814be3e8d 100644 --- a/providers/anthropic/anthropic.go +++ b/providers/anthropic/anthropic.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "maps" + "math" "strings" "charm.land/fantasy" @@ -407,6 +408,118 @@ func groupIntoBlocks(prompt fantasy.Prompt) []*messageBlock { return blocks } +func anyToStringSlice(v any) []string { + switch typed := v.(type) { + case []string: + if len(typed) == 0 { + return nil + } + out := make([]string, len(typed)) + copy(out, typed) + return out + case []any: + if len(typed) == 0 { + return nil + } + out := make([]string, 0, len(typed)) + for _, item := range typed { + s, ok := item.(string) + if !ok || s == "" { + continue + } + out = append(out, s) + } + if len(out) == 0 { + return nil + } + return out + default: + return nil + } +} + +const maxExactIntFloat64 = float64(1<<53 - 1) + +func anyToInt64(v any) (int64, bool) { + switch typed := v.(type) { + case int: + return int64(typed), true + case int8: + return int64(typed), true + case int16: + return int64(typed), true + case int32: + return int64(typed), true + case int64: + return typed, true + case uint: + if uint64(typed) > math.MaxInt64 { + return 0, false + } + return int64(typed), true + case uint8: + return int64(typed), true + case uint16: + return int64(typed), true + case uint32: + return int64(typed), true + case uint64: + if typed > math.MaxInt64 { + return 0, false + } + return int64(typed), true + case float32: + f := float64(typed) + if math.Trunc(f) != f || math.IsNaN(f) || math.IsInf(f, 0) || f < -maxExactIntFloat64 || f > maxExactIntFloat64 { + return 0, false + } + return int64(f), true + case float64: + if math.Trunc(typed) != typed || math.IsNaN(typed) || math.IsInf(typed, 0) || typed < -maxExactIntFloat64 || typed > maxExactIntFloat64 { + return 0, false + } + return int64(typed), true + case json.Number: + parsed, err := typed.Int64() + if err != nil { + return 0, false + } + return parsed, true + default: + return 0, false + } +} + +func anyToUserLocation(v any) *UserLocation { + switch typed := v.(type) { + case *UserLocation: + return typed + case UserLocation: + loc := typed + return &loc + case map[string]any: + loc := &UserLocation{} + if city, ok := typed["city"].(string); ok { + loc.City = city + } + if region, ok := typed["region"].(string); ok { + loc.Region = region + } + if country, ok := typed["country"].(string); ok { + loc.Country = country + } + if timezone, ok := typed["timezone"].(string); ok { + loc.Timezone = timezone + } + if loc.City == "" && loc.Region == "" && loc.Country == "" && loc.Timezone == "" { + return nil + } + return loc + default: + return nil + } +} + func (a languageModel) toTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice, disableParallelToolCalls bool) (anthropicTools []anthropic.ToolUnionParam, anthropicToolChoice *anthropic.ToolChoiceUnionParam, warnings []fantasy.CallWarning) { for _, tool := range tools { if tool.GetType() == fantasy.ToolTypeFunction { @@ -449,16 +562,16 @@ func (a languageModel) toTools(tools []fantasy.Tool, toolChoice *fantasy.ToolCho case "web_search": webSearchTool := anthropic.WebSearchTool20250305Param{} if pt.Args != nil { - if domains, ok := pt.Args["allowed_domains"].([]string); ok && len(domains) > 0 { + if domains := anyToStringSlice(pt.Args["allowed_domains"]); len(domains) > 0 { webSearchTool.AllowedDomains = domains } - if domains, ok := pt.Args["blocked_domains"].([]string); ok && len(domains) > 0 { + if domains := anyToStringSlice(pt.Args["blocked_domains"]); len(domains) > 0 { webSearchTool.BlockedDomains = domains } - if maxUses, ok := pt.Args["max_uses"].(int64); ok && maxUses > 0 { + if maxUses, ok := anyToInt64(pt.Args["max_uses"]); ok && maxUses > 0 { webSearchTool.MaxUses = param.NewOpt(maxUses) } - if loc, ok := pt.Args["user_location"].(*UserLocation); ok && loc != nil { + if loc := anyToUserLocation(pt.Args["user_location"]); loc != nil { var ulp anthropic.UserLocationParam if loc.City != "" { ulp.City = param.NewOpt(loc.City) diff --git a/providers/anthropic/anthropic_test.go b/providers/anthropic/anthropic_test.go index 545d71b3e2928fdec80c866bfc6ad0b05c6a7a5d..c6c2704fa9134bf4fc667c7cd190b95a53a1845a 100644 --- a/providers/anthropic/anthropic_test.go +++ b/providers/anthropic/anthropic_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "math" "net/http" "net/http/httptest" "testing" @@ -1027,6 +1028,183 @@ func TestGenerate_WebSearchToolInRequest(t *testing.T) { require.True(t, ok, "tool should have max_uses") require.Equal(t, float64(3), maxUses) }) + + t.Run("with json-round-tripped provider tool args", func(t *testing.T) { + t.Parallel() + + server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse()) + defer server.Close() + + provider, err := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.URL), + ) + require.NoError(t, err) + + model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514") + require.NoError(t, err) + + baseTool := WebSearchTool(&WebSearchToolOptions{ + MaxUses: 7, + BlockedDomains: []string{"example.com", "test.com"}, + UserLocation: &UserLocation{ + City: "San Francisco", + Region: "CA", + Country: "US", + Timezone: "America/Los_Angeles", + }, + }) + + data, err := json.Marshal(baseTool) + require.NoError(t, err) + + var roundTripped fantasy.ProviderDefinedTool + err = json.Unmarshal(data, &roundTripped) + require.NoError(t, err) + + _, err = model.Generate(context.Background(), fantasy.Call{ + Prompt: testPrompt(), + Tools: []fantasy.Tool{roundTripped}, + }) + require.NoError(t, err) + + call := awaitAnthropicCall(t, calls) + tools, ok := call.body["tools"].([]any) + require.True(t, ok) + require.Len(t, tools, 1) + + tool, ok := tools[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "web_search_20250305", tool["type"]) + + domains, ok := tool["blocked_domains"].([]any) + require.True(t, ok, "tool should have blocked_domains") + require.Len(t, domains, 2) + require.Equal(t, "example.com", domains[0]) + require.Equal(t, "test.com", domains[1]) + + maxUses, ok := tool["max_uses"].(float64) + require.True(t, ok, "tool should have max_uses") + require.Equal(t, float64(7), maxUses) + + userLoc, ok := tool["user_location"].(map[string]any) + require.True(t, ok, "tool should have user_location") + require.Equal(t, "San Francisco", userLoc["city"]) + require.Equal(t, "CA", userLoc["region"]) + require.Equal(t, "US", userLoc["country"]) + require.Equal(t, "America/Los_Angeles", userLoc["timezone"]) + require.Equal(t, "approximate", userLoc["type"]) + }) +} + +func TestAnyToStringSlice(t *testing.T) { + t.Parallel() + + t.Run("from string slice", func(t *testing.T) { + t.Parallel() + + got := anyToStringSlice([]string{"example.com", ""}) + require.Equal(t, []string{"example.com", ""}, got) + }) + + t.Run("from any slice filters non-strings and empty", func(t *testing.T) { + t.Parallel() + + got := anyToStringSlice([]any{"example.com", 123, "", "test.com"}) + require.Equal(t, []string{"example.com", "test.com"}, got) + }) + + t.Run("unsupported type", func(t *testing.T) { + t.Parallel() + + got := anyToStringSlice("example.com") + require.Nil(t, got) + }) +} + +func TestAnyToInt64(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input any + want int64 + wantOK bool + }{ + {name: "int64", input: int64(7), want: 7, wantOK: true}, + {name: "float64 integer", input: float64(7), want: 7, wantOK: true}, + {name: "float32 integer", input: float32(9), want: 9, wantOK: true}, + {name: "float64 non-integer", input: float64(7.5), wantOK: false}, + {name: "float64 max exact int ok", input: float64(1<<53 - 1), want: 1<<53 - 1, wantOK: true}, + {name: "float64 over max exact int", input: float64(1 << 53), wantOK: false}, + {name: "json number int", input: json.Number("42"), want: 42, wantOK: true}, + {name: "json number float", input: json.Number("4.2"), wantOK: false}, + {name: "nan", input: math.NaN(), wantOK: false}, + {name: "inf", input: math.Inf(1), wantOK: false}, + {name: "uint64 overflow", input: uint64(math.MaxInt64) + 1, wantOK: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := anyToInt64(tt.input) + require.Equal(t, tt.wantOK, ok) + if tt.wantOK { + require.Equal(t, tt.want, got) + } + }) + } +} + +func TestAnyToUserLocation(t *testing.T) { + t.Parallel() + + t.Run("pointer passthrough", func(t *testing.T) { + t.Parallel() + + input := &UserLocation{City: "San Francisco", Country: "US"} + got := anyToUserLocation(input) + require.Same(t, input, got) + }) + + t.Run("struct value", func(t *testing.T) { + t.Parallel() + + got := anyToUserLocation(UserLocation{City: "San Francisco", Country: "US"}) + require.NotNil(t, got) + require.Equal(t, "San Francisco", got.City) + require.Equal(t, "US", got.Country) + }) + + t.Run("map value", func(t *testing.T) { + t.Parallel() + + got := anyToUserLocation(map[string]any{ + "city": "San Francisco", + "region": "CA", + "country": "US", + "timezone": "America/Los_Angeles", + "type": "approximate", + }) + require.NotNil(t, got) + require.Equal(t, "San Francisco", got.City) + require.Equal(t, "CA", got.Region) + require.Equal(t, "US", got.Country) + require.Equal(t, "America/Los_Angeles", got.Timezone) + }) + + t.Run("empty map", func(t *testing.T) { + t.Parallel() + + got := anyToUserLocation(map[string]any{"type": "approximate"}) + require.Nil(t, got) + }) + + t.Run("unsupported type", func(t *testing.T) { + t.Parallel() + + got := anyToUserLocation("San Francisco") + require.Nil(t, got) + }) } func TestStream_WebSearchResponse(t *testing.T) {