@@ -10,6 +10,7 @@ import (
"fmt"
"io"
"maps"
+ "math"
"strings"
"charm.land/fantasy"
@@ -407,6 +408,118 @@ func groupIntoBlocks(prompt fantasy.Prompt) []*messageBlock {
return blocks
}
+func anyToStringSlice(v any) []string {
+ switch typed := v.(type) {
+ case []string:
+ if len(typed) == 0 {
+ return nil
+ }
+ out := make([]string, len(typed))
+ copy(out, typed)
+ return out
+ case []any:
+ if len(typed) == 0 {
+ return nil
+ }
+ out := make([]string, 0, len(typed))
+ for _, item := range typed {
+ s, ok := item.(string)
+ if !ok || s == "" {
+ continue
+ }
+ out = append(out, s)
+ }
+ if len(out) == 0 {
+ return nil
+ }
+ return out
+ default:
+ return nil
+ }
+}
+
+const maxExactIntFloat64 = float64(1<<53 - 1)
+
+func anyToInt64(v any) (int64, bool) {
+ switch typed := v.(type) {
+ case int:
+ return int64(typed), true
+ case int8:
+ return int64(typed), true
+ case int16:
+ return int64(typed), true
+ case int32:
+ return int64(typed), true
+ case int64:
+ return typed, true
+ case uint:
+ if uint64(typed) > math.MaxInt64 {
+ return 0, false
+ }
+ return int64(typed), true
+ case uint8:
+ return int64(typed), true
+ case uint16:
+ return int64(typed), true
+ case uint32:
+ return int64(typed), true
+ case uint64:
+ if typed > math.MaxInt64 {
+ return 0, false
+ }
+ return int64(typed), true
+ case float32:
+ f := float64(typed)
+ if math.Trunc(f) != f || math.IsNaN(f) || math.IsInf(f, 0) || f < -maxExactIntFloat64 || f > maxExactIntFloat64 {
+ return 0, false
+ }
+ return int64(f), true
+ case float64:
+ if math.Trunc(typed) != typed || math.IsNaN(typed) || math.IsInf(typed, 0) || typed < -maxExactIntFloat64 || typed > maxExactIntFloat64 {
+ return 0, false
+ }
+ return int64(typed), true
+ case json.Number:
+ parsed, err := typed.Int64()
+ if err != nil {
+ return 0, false
+ }
+ return parsed, true
+ default:
+ return 0, false
+ }
+}
+
+func anyToUserLocation(v any) *UserLocation {
+ switch typed := v.(type) {
+ case *UserLocation:
+ return typed
+ case UserLocation:
+ loc := typed
+ return &loc
+ case map[string]any:
+ loc := &UserLocation{}
+ if city, ok := typed["city"].(string); ok {
+ loc.City = city
+ }
+ if region, ok := typed["region"].(string); ok {
+ loc.Region = region
+ }
+ if country, ok := typed["country"].(string); ok {
+ loc.Country = country
+ }
+ if timezone, ok := typed["timezone"].(string); ok {
+ loc.Timezone = timezone
+ }
+ if loc.City == "" && loc.Region == "" && loc.Country == "" && loc.Timezone == "" {
+ return nil
+ }
+ return loc
+ default:
+ return nil
+ }
+}
+
func (a languageModel) toTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice, disableParallelToolCalls bool) (anthropicTools []anthropic.ToolUnionParam, anthropicToolChoice *anthropic.ToolChoiceUnionParam, warnings []fantasy.CallWarning) {
for _, tool := range tools {
if tool.GetType() == fantasy.ToolTypeFunction {
@@ -449,16 +562,16 @@ func (a languageModel) toTools(tools []fantasy.Tool, toolChoice *fantasy.ToolCho
case "web_search":
webSearchTool := anthropic.WebSearchTool20250305Param{}
if pt.Args != nil {
- if domains, ok := pt.Args["allowed_domains"].([]string); ok && len(domains) > 0 {
+ if domains := anyToStringSlice(pt.Args["allowed_domains"]); len(domains) > 0 {
webSearchTool.AllowedDomains = domains
}
- if domains, ok := pt.Args["blocked_domains"].([]string); ok && len(domains) > 0 {
+ if domains := anyToStringSlice(pt.Args["blocked_domains"]); len(domains) > 0 {
webSearchTool.BlockedDomains = domains
}
- if maxUses, ok := pt.Args["max_uses"].(int64); ok && maxUses > 0 {
+ if maxUses, ok := anyToInt64(pt.Args["max_uses"]); ok && maxUses > 0 {
webSearchTool.MaxUses = param.NewOpt(maxUses)
}
- if loc, ok := pt.Args["user_location"].(*UserLocation); ok && loc != nil {
+ if loc := anyToUserLocation(pt.Args["user_location"]); loc != nil {
var ulp anthropic.UserLocationParam
if loc.City != "" {
ulp.City = param.NewOpt(loc.City)
@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "math"
"net/http"
"net/http/httptest"
"testing"
@@ -1027,6 +1028,183 @@ func TestGenerate_WebSearchToolInRequest(t *testing.T) {
require.True(t, ok, "tool should have max_uses")
require.Equal(t, float64(3), maxUses)
})
+
+ t.Run("with json-round-tripped provider tool args", func(t *testing.T) {
+ t.Parallel()
+
+ server, calls := newAnthropicJSONServer(mockAnthropicGenerateResponse())
+ defer server.Close()
+
+ provider, err := New(
+ WithAPIKey("test-api-key"),
+ WithBaseURL(server.URL),
+ )
+ require.NoError(t, err)
+
+ model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
+ require.NoError(t, err)
+
+ baseTool := WebSearchTool(&WebSearchToolOptions{
+ MaxUses: 7,
+ BlockedDomains: []string{"example.com", "test.com"},
+ UserLocation: &UserLocation{
+ City: "San Francisco",
+ Region: "CA",
+ Country: "US",
+ Timezone: "America/Los_Angeles",
+ },
+ })
+
+ data, err := json.Marshal(baseTool)
+ require.NoError(t, err)
+
+ var roundTripped fantasy.ProviderDefinedTool
+ err = json.Unmarshal(data, &roundTripped)
+ require.NoError(t, err)
+
+ _, err = model.Generate(context.Background(), fantasy.Call{
+ Prompt: testPrompt(),
+ Tools: []fantasy.Tool{roundTripped},
+ })
+ require.NoError(t, err)
+
+ call := awaitAnthropicCall(t, calls)
+ tools, ok := call.body["tools"].([]any)
+ require.True(t, ok)
+ require.Len(t, tools, 1)
+
+ tool, ok := tools[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "web_search_20250305", tool["type"])
+
+ domains, ok := tool["blocked_domains"].([]any)
+ require.True(t, ok, "tool should have blocked_domains")
+ require.Len(t, domains, 2)
+ require.Equal(t, "example.com", domains[0])
+ require.Equal(t, "test.com", domains[1])
+
+ maxUses, ok := tool["max_uses"].(float64)
+ require.True(t, ok, "tool should have max_uses")
+ require.Equal(t, float64(7), maxUses)
+
+ userLoc, ok := tool["user_location"].(map[string]any)
+ require.True(t, ok, "tool should have user_location")
+ require.Equal(t, "San Francisco", userLoc["city"])
+ require.Equal(t, "CA", userLoc["region"])
+ require.Equal(t, "US", userLoc["country"])
+ require.Equal(t, "America/Los_Angeles", userLoc["timezone"])
+ require.Equal(t, "approximate", userLoc["type"])
+ })
+}
+
+func TestAnyToStringSlice(t *testing.T) {
+ t.Parallel()
+
+ t.Run("from string slice", func(t *testing.T) {
+ t.Parallel()
+
+ got := anyToStringSlice([]string{"example.com", ""})
+ require.Equal(t, []string{"example.com", ""}, got)
+ })
+
+ t.Run("from any slice filters non-strings and empty", func(t *testing.T) {
+ t.Parallel()
+
+ got := anyToStringSlice([]any{"example.com", 123, "", "test.com"})
+ require.Equal(t, []string{"example.com", "test.com"}, got)
+ })
+
+ t.Run("unsupported type", func(t *testing.T) {
+ t.Parallel()
+
+ got := anyToStringSlice("example.com")
+ require.Nil(t, got)
+ })
+}
+
+func TestAnyToInt64(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ input any
+ want int64
+ wantOK bool
+ }{
+ {name: "int64", input: int64(7), want: 7, wantOK: true},
+ {name: "float64 integer", input: float64(7), want: 7, wantOK: true},
+ {name: "float32 integer", input: float32(9), want: 9, wantOK: true},
+ {name: "float64 non-integer", input: float64(7.5), wantOK: false},
+ {name: "float64 max exact int ok", input: float64(1<<53 - 1), want: 1<<53 - 1, wantOK: true},
+ {name: "float64 over max exact int", input: float64(1 << 53), wantOK: false},
+ {name: "json number int", input: json.Number("42"), want: 42, wantOK: true},
+ {name: "json number float", input: json.Number("4.2"), wantOK: false},
+ {name: "nan", input: math.NaN(), wantOK: false},
+ {name: "inf", input: math.Inf(1), wantOK: false},
+ {name: "uint64 overflow", input: uint64(math.MaxInt64) + 1, wantOK: false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, ok := anyToInt64(tt.input)
+ require.Equal(t, tt.wantOK, ok)
+ if tt.wantOK {
+ require.Equal(t, tt.want, got)
+ }
+ })
+ }
+}
+
+func TestAnyToUserLocation(t *testing.T) {
+ t.Parallel()
+
+ t.Run("pointer passthrough", func(t *testing.T) {
+ t.Parallel()
+
+ input := &UserLocation{City: "San Francisco", Country: "US"}
+ got := anyToUserLocation(input)
+ require.Same(t, input, got)
+ })
+
+ t.Run("struct value", func(t *testing.T) {
+ t.Parallel()
+
+ got := anyToUserLocation(UserLocation{City: "San Francisco", Country: "US"})
+ require.NotNil(t, got)
+ require.Equal(t, "San Francisco", got.City)
+ require.Equal(t, "US", got.Country)
+ })
+
+ t.Run("map value", func(t *testing.T) {
+ t.Parallel()
+
+ got := anyToUserLocation(map[string]any{
+ "city": "San Francisco",
+ "region": "CA",
+ "country": "US",
+ "timezone": "America/Los_Angeles",
+ "type": "approximate",
+ })
+ require.NotNil(t, got)
+ require.Equal(t, "San Francisco", got.City)
+ require.Equal(t, "CA", got.Region)
+ require.Equal(t, "US", got.Country)
+ require.Equal(t, "America/Los_Angeles", got.Timezone)
+ })
+
+ t.Run("empty map", func(t *testing.T) {
+ t.Parallel()
+
+ got := anyToUserLocation(map[string]any{"type": "approximate"})
+ require.Nil(t, got)
+ })
+
+ t.Run("unsupported type", func(t *testing.T) {
+ t.Parallel()
+
+ got := anyToUserLocation("San Francisco")
+ require.Nil(t, got)
+ })
}
func TestStream_WebSearchResponse(t *testing.T) {