chore(anthroipic): harden web_search coercion for json round-trips

Christian Rocha created

This guards against cases when web_search options could be silently
dropped when tool args pass through generic/JSON-shaped data, so domain
filters, max uses, and user location reliably reach Anthropic.

Change summary

providers/anthropic/anthropic.go      | 121 +++++++++++++++++++
providers/anthropic/anthropic_test.go | 178 +++++++++++++++++++++++++++++
2 files changed, 295 insertions(+), 4 deletions(-)

Detailed changes

providers/anthropic/anthropic.go 🔗

@@ -10,6 +10,7 @@ import (
 	"fmt"
 	"io"
 	"maps"
+	"math"
 	"strings"
 
 	"charm.land/fantasy"
@@ -407,6 +408,118 @@ func groupIntoBlocks(prompt fantasy.Prompt) []*messageBlock {
 	return blocks
 }
 
+func anyToStringSlice(v any) []string {
+	switch typed := v.(type) {
+	case []string:
+		if len(typed) == 0 {
+			return nil
+		}
+		out := make([]string, len(typed))
+		copy(out, typed)
+		return out
+	case []any:
+		if len(typed) == 0 {
+			return nil
+		}
+		out := make([]string, 0, len(typed))
+		for _, item := range typed {
+			s, ok := item.(string)
+			if !ok || s == "" {
+				continue
+			}
+			out = append(out, s)
+		}
+		if len(out) == 0 {
+			return nil
+		}
+		return out
+	default:
+		return nil
+	}
+}
+
+const maxExactIntFloat64 = float64(1<<53 - 1)
+
+func anyToInt64(v any) (int64, bool) {
+	switch typed := v.(type) {
+	case int:
+		return int64(typed), true
+	case int8:
+		return int64(typed), true
+	case int16:
+		return int64(typed), true
+	case int32:
+		return int64(typed), true
+	case int64:
+		return typed, true
+	case uint:
+		if uint64(typed) > math.MaxInt64 {
+			return 0, false
+		}
+		return int64(typed), true
+	case uint8:
+		return int64(typed), true
+	case uint16:
+		return int64(typed), true
+	case uint32:
+		return int64(typed), true
+	case uint64:
+		if typed > math.MaxInt64 {
+			return 0, false
+		}
+		return int64(typed), true
+	case float32:
+		f := float64(typed)
+		if math.Trunc(f) != f || math.IsNaN(f) || math.IsInf(f, 0) || f < -maxExactIntFloat64 || f > maxExactIntFloat64 {
+			return 0, false
+		}
+		return int64(f), true
+	case float64:
+		if math.Trunc(typed) != typed || math.IsNaN(typed) || math.IsInf(typed, 0) || typed < -maxExactIntFloat64 || typed > maxExactIntFloat64 {
+			return 0, false
+		}
+		return int64(typed), true
+	case json.Number:
+		parsed, err := typed.Int64()
+		if err != nil {
+			return 0, false
+		}
+		return parsed, true
+	default:
+		return 0, false
+	}
+}
+
+func anyToUserLocation(v any) *UserLocation {
+	switch typed := v.(type) {
+	case *UserLocation:
+		return typed
+	case UserLocation:
+		loc := typed
+		return &loc
+	case map[string]any:
+		loc := &UserLocation{}
+		if city, ok := typed["city"].(string); ok {
+			loc.City = city
+		}
+		if region, ok := typed["region"].(string); ok {
+			loc.Region = region
+		}
+		if country, ok := typed["country"].(string); ok {
+			loc.Country = country
+		}
+		if timezone, ok := typed["timezone"].(string); ok {
+			loc.Timezone = timezone
+		}
+		if loc.City == "" && loc.Region == "" && loc.Country == "" && loc.Timezone == "" {
+			return nil
+		}
+		return loc
+	default:
+		return nil
+	}
+}
+
 func (a languageModel) toTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice, disableParallelToolCalls bool) (anthropicTools []anthropic.ToolUnionParam, anthropicToolChoice *anthropic.ToolChoiceUnionParam, warnings []fantasy.CallWarning) {
 	for _, tool := range tools {
 		if tool.GetType() == fantasy.ToolTypeFunction {
@@ -449,16 +562,16 @@ func (a languageModel) toTools(tools []fantasy.Tool, toolChoice *fantasy.ToolCho
 			case "web_search":
 				webSearchTool := anthropic.WebSearchTool20250305Param{}
 				if pt.Args != nil {
-					if domains, ok := pt.Args["allowed_domains"].([]string); ok && len(domains) > 0 {
+					if domains := anyToStringSlice(pt.Args["allowed_domains"]); len(domains) > 0 {
 						webSearchTool.AllowedDomains = domains
 					}
-					if domains, ok := pt.Args["blocked_domains"].([]string); ok && len(domains) > 0 {
+					if domains := anyToStringSlice(pt.Args["blocked_domains"]); len(domains) > 0 {
 						webSearchTool.BlockedDomains = domains
 					}
-					if maxUses, ok := pt.Args["max_uses"].(int64); ok && maxUses > 0 {
+					if maxUses, ok := anyToInt64(pt.Args["max_uses"]); ok && maxUses > 0 {
 						webSearchTool.MaxUses = param.NewOpt(maxUses)
 					}
-					if loc, ok := pt.Args["user_location"].(*UserLocation); ok && loc != nil {
+					if loc := anyToUserLocation(pt.Args["user_location"]); loc != nil {
 						var ulp anthropic.UserLocationParam
 						if loc.City != "" {
 							ulp.City = param.NewOpt(loc.City)

providers/anthropic/anthropic_test.go 🔗

@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"math"
 	"net/http"
 	"net/http/httptest"
 	"testing"
@@ -1027,6 +1028,183 @@ func TestGenerate_WebSearchToolInRequest(t *testing.T) {
 		require.True(t, ok, "tool should have max_uses")
 		require.Equal(t, float64(3), maxUses)
 	})
+
+	t.Run("with json-round-tripped provider tool args", func(t *testing.T) {
+		t.Parallel()
+
+		server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
+		defer server.Close()
+
+		provider, err := New(
+			WithAPIKey("test-api-key"),
+			WithBaseURL(server.URL),
+		)
+		require.NoError(t, err)
+
+		model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
+		require.NoError(t, err)
+
+		baseTool := WebSearchTool(&WebSearchToolOptions{
+			MaxUses:        7,
+			BlockedDomains: []string{"example.com", "test.com"},
+			UserLocation: &UserLocation{
+				City:     "San Francisco",
+				Region:   "CA",
+				Country:  "US",
+				Timezone: "America/Los_Angeles",
+			},
+		})
+
+		data, err := json.Marshal(baseTool)
+		require.NoError(t, err)
+
+		var roundTripped fantasy.ProviderDefinedTool
+		err = json.Unmarshal(data, &roundTripped)
+		require.NoError(t, err)
+
+		_, err = model.Generate(context.Background(), fantasy.Call{
+			Prompt: testPrompt(),
+			Tools:  []fantasy.Tool{roundTripped},
+		})
+		require.NoError(t, err)
+
+		call := awaitAnthropicCall(t, calls)
+		tools, ok := call.body["tools"].([]any)
+		require.True(t, ok)
+		require.Len(t, tools, 1)
+
+		tool, ok := tools[0].(map[string]any)
+		require.True(t, ok)
+		require.Equal(t, "web_search_20250305", tool["type"])
+
+		domains, ok := tool["blocked_domains"].([]any)
+		require.True(t, ok, "tool should have blocked_domains")
+		require.Len(t, domains, 2)
+		require.Equal(t, "example.com", domains[0])
+		require.Equal(t, "test.com", domains[1])
+
+		maxUses, ok := tool["max_uses"].(float64)
+		require.True(t, ok, "tool should have max_uses")
+		require.Equal(t, float64(7), maxUses)
+
+		userLoc, ok := tool["user_location"].(map[string]any)
+		require.True(t, ok, "tool should have user_location")
+		require.Equal(t, "San Francisco", userLoc["city"])
+		require.Equal(t, "CA", userLoc["region"])
+		require.Equal(t, "US", userLoc["country"])
+		require.Equal(t, "America/Los_Angeles", userLoc["timezone"])
+		require.Equal(t, "approximate", userLoc["type"])
+	})
+}
+
+func TestAnyToStringSlice(t *testing.T) {
+	t.Parallel()
+
+	t.Run("from string slice", func(t *testing.T) {
+		t.Parallel()
+
+		got := anyToStringSlice([]string{"example.com", ""})
+		require.Equal(t, []string{"example.com", ""}, got)
+	})
+
+	t.Run("from any slice filters non-strings and empty", func(t *testing.T) {
+		t.Parallel()
+
+		got := anyToStringSlice([]any{"example.com", 123, "", "test.com"})
+		require.Equal(t, []string{"example.com", "test.com"}, got)
+	})
+
+	t.Run("unsupported type", func(t *testing.T) {
+		t.Parallel()
+
+		got := anyToStringSlice("example.com")
+		require.Nil(t, got)
+	})
+}
+
+func TestAnyToInt64(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name   string
+		input  any
+		want   int64
+		wantOK bool
+	}{
+		{name: "int64", input: int64(7), want: 7, wantOK: true},
+		{name: "float64 integer", input: float64(7), want: 7, wantOK: true},
+		{name: "float32 integer", input: float32(9), want: 9, wantOK: true},
+		{name: "float64 non-integer", input: float64(7.5), wantOK: false},
+		{name: "float64 max exact int ok", input: float64(1<<53 - 1), want: 1<<53 - 1, wantOK: true},
+		{name: "float64 over max exact int", input: float64(1 << 53), wantOK: false},
+		{name: "json number int", input: json.Number("42"), want: 42, wantOK: true},
+		{name: "json number float", input: json.Number("4.2"), wantOK: false},
+		{name: "nan", input: math.NaN(), wantOK: false},
+		{name: "inf", input: math.Inf(1), wantOK: false},
+		{name: "uint64 overflow", input: uint64(math.MaxInt64) + 1, wantOK: false},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, ok := anyToInt64(tt.input)
+			require.Equal(t, tt.wantOK, ok)
+			if tt.wantOK {
+				require.Equal(t, tt.want, got)
+			}
+		})
+	}
+}
+
+func TestAnyToUserLocation(t *testing.T) {
+	t.Parallel()
+
+	t.Run("pointer passthrough", func(t *testing.T) {
+		t.Parallel()
+
+		input := &UserLocation{City: "San Francisco", Country: "US"}
+		got := anyToUserLocation(input)
+		require.Same(t, input, got)
+	})
+
+	t.Run("struct value", func(t *testing.T) {
+		t.Parallel()
+
+		got := anyToUserLocation(UserLocation{City: "San Francisco", Country: "US"})
+		require.NotNil(t, got)
+		require.Equal(t, "San Francisco", got.City)
+		require.Equal(t, "US", got.Country)
+	})
+
+	t.Run("map value", func(t *testing.T) {
+		t.Parallel()
+
+		got := anyToUserLocation(map[string]any{
+			"city":     "San Francisco",
+			"region":   "CA",
+			"country":  "US",
+			"timezone": "America/Los_Angeles",
+			"type":     "approximate",
+		})
+		require.NotNil(t, got)
+		require.Equal(t, "San Francisco", got.City)
+		require.Equal(t, "CA", got.Region)
+		require.Equal(t, "US", got.Country)
+		require.Equal(t, "America/Los_Angeles", got.Timezone)
+	})
+
+	t.Run("empty map", func(t *testing.T) {
+		t.Parallel()
+
+		got := anyToUserLocation(map[string]any{"type": "approximate"})
+		require.Nil(t, got)
+	})
+
+	t.Run("unsupported type", func(t *testing.T) {
+		t.Parallel()
+
+		got := anyToUserLocation("San Francisco")
+		require.Nil(t, got)
+	})
 }
 
 func TestStream_WebSearchResponse(t *testing.T) {