1package providertests
2
3import (
4 "context"
5 "net/http"
6 "os"
7 "testing"
8
9 "charm.land/fantasy"
10 "charm.land/fantasy/providers/anthropic"
11 "charm.land/x/vcr"
12 "github.com/stretchr/testify/require"
13)
14
15var anthropicTestModels = []testModel{
16 {"claude-sonnet-4", "claude-sonnet-4-20250514", true},
17}
18
19func TestAnthropicCommon(t *testing.T) {
20 var pairs []builderPair
21 for _, m := range anthropicTestModels {
22 pairs = append(pairs, builderPair{m.name, anthropicBuilder(m.model), nil, nil})
23 }
24 testCommon(t, pairs)
25}
26
27func addAnthropicCaching(ctx context.Context, options fantasy.PrepareStepFunctionOptions) (context.Context, fantasy.PrepareStepResult, error) {
28 prepared := fantasy.PrepareStepResult{}
29 prepared.Messages = options.Messages
30
31 for i := range prepared.Messages {
32 prepared.Messages[i].ProviderOptions = nil
33 }
34 providerOption := fantasy.ProviderOptions{
35 anthropic.Name: &anthropic.ProviderCacheControlOptions{
36 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
37 },
38 }
39
40 lastSystemRoleInx := 0
41 systemMessageUpdated := false
42 for i, msg := range prepared.Messages {
43 // only add cache control to the last message
44 if msg.Role == fantasy.MessageRoleSystem {
45 lastSystemRoleInx = i
46 } else if !systemMessageUpdated {
47 prepared.Messages[lastSystemRoleInx].ProviderOptions = providerOption
48 systemMessageUpdated = true
49 }
50 // than add cache control to the last 2 messages
51 if i > len(prepared.Messages)-3 {
52 prepared.Messages[i].ProviderOptions = providerOption
53 }
54 }
55 return ctx, prepared, nil
56}
57
58func TestAnthropicCommonWithCacheControl(t *testing.T) {
59 var pairs []builderPair
60 for _, m := range anthropicTestModels {
61 pairs = append(pairs, builderPair{m.name, anthropicBuilder(m.model), nil, addAnthropicCaching})
62 }
63 testCommon(t, pairs)
64}
65
66func TestAnthropicThinking(t *testing.T) {
67 opts := fantasy.ProviderOptions{
68 anthropic.Name: &anthropic.ProviderOptions{
69 Thinking: &anthropic.ThinkingProviderOption{
70 BudgetTokens: 4000,
71 },
72 },
73 }
74 var pairs []builderPair
75 for _, m := range anthropicTestModels {
76 if !m.reasoning {
77 continue
78 }
79 pairs = append(pairs, builderPair{m.name, anthropicBuilder(m.model), opts, nil})
80 }
81 testThinking(t, pairs, testAnthropicThinking)
82}
83
84func TestAnthropicThinkingWithCacheControl(t *testing.T) {
85 opts := fantasy.ProviderOptions{
86 anthropic.Name: &anthropic.ProviderOptions{
87 Thinking: &anthropic.ThinkingProviderOption{
88 BudgetTokens: 4000,
89 },
90 },
91 }
92 var pairs []builderPair
93 for _, m := range anthropicTestModels {
94 if !m.reasoning {
95 continue
96 }
97 pairs = append(pairs, builderPair{m.name, anthropicBuilder(m.model), opts, addAnthropicCaching})
98 }
99 testThinking(t, pairs, testAnthropicThinking)
100}
101
102func TestAnthropicObjectGeneration(t *testing.T) {
103 var pairs []builderPair
104 for _, m := range anthropicTestModels {
105 pairs = append(pairs, builderPair{m.name, anthropicBuilder(m.model), nil, nil})
106 }
107 testObjectGeneration(t, pairs)
108}
109
110func testAnthropicThinking(t *testing.T, result *fantasy.AgentResult) {
111 reasoningContentCount := 0
112 signaturesCount := 0
113 // Test if we got the signature
114 for _, step := range result.Steps {
115 for _, msg := range step.Messages {
116 for _, content := range msg.Content {
117 if content.GetType() == fantasy.ContentTypeReasoning {
118 reasoningContentCount += 1
119 reasoningContent, ok := fantasy.AsContentType[fantasy.ReasoningPart](content)
120 if !ok {
121 continue
122 }
123 if len(reasoningContent.ProviderOptions) == 0 {
124 continue
125 }
126
127 anthropicReasoningMetadata, ok := reasoningContent.ProviderOptions[anthropic.Name]
128 if !ok {
129 continue
130 }
131 if reasoningContent.Text != "" {
132 if typed, ok := anthropicReasoningMetadata.(*anthropic.ReasoningOptionMetadata); ok {
133 require.NotEmpty(t, typed.Signature)
134 signaturesCount += 1
135 }
136 }
137 }
138 }
139 }
140 }
141 require.Greater(t, reasoningContentCount, 0)
142 require.Greater(t, signaturesCount, 0)
143 require.Equal(t, reasoningContentCount, signaturesCount)
144}
145
146func anthropicBuilder(model string) builderFunc {
147 return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
148 provider, err := anthropic.New(
149 anthropic.WithAPIKey(os.Getenv("FANTASY_ANTHROPIC_API_KEY")),
150 anthropic.WithHTTPClient(&http.Client{Transport: r}),
151 )
152 if err != nil {
153 return nil, err
154 }
155 return provider.LanguageModel(t.Context(), model)
156 }
157}
158
159// TestAnthropicWebSearch tests web search tool support via the agent
160// using WithProviderDefinedTools.
161func TestAnthropicWebSearch(t *testing.T) {
162 model := "claude-sonnet-4-20250514"
163 webSearchTool := anthropic.WebSearchTool(nil)
164
165 t.Run("generate", func(t *testing.T) {
166 r := vcr.NewRecorder(t)
167
168 lm, err := anthropicBuilder(model)(t, r)
169 require.NoError(t, err)
170
171 agent := fantasy.NewAgent(
172 lm,
173 fantasy.WithSystemPrompt("You are a helpful assistant"),
174 fantasy.WithProviderDefinedTools(webSearchTool),
175 )
176
177 result, err := agent.Generate(t.Context(), fantasy.AgentCall{
178 Prompt: "What is the current population of Tokyo? Cite your source.",
179 MaxOutputTokens: fantasy.Opt(int64(4000)),
180 })
181 require.NoError(t, err)
182
183 got := result.Response.Content.Text()
184 require.NotEmpty(t, got, "should have a text response")
185 require.Contains(t, got, "Tokyo", "response should mention Tokyo")
186
187 // Walk the steps and verify web search content was produced.
188 var sources []fantasy.SourceContent
189 var providerToolCalls []fantasy.ToolCallContent
190 for _, step := range result.Steps {
191 for _, c := range step.Content {
192 switch v := c.(type) {
193 case fantasy.ToolCallContent:
194 if v.ProviderExecuted {
195 providerToolCalls = append(providerToolCalls, v)
196 }
197 case fantasy.SourceContent:
198 sources = append(sources, v)
199 }
200 }
201 }
202
203 require.NotEmpty(t, providerToolCalls, "should have provider-executed tool calls")
204 require.Equal(t, "web_search", providerToolCalls[0].ToolName)
205 require.NotEmpty(t, sources, "should have source citations")
206 require.NotEmpty(t, sources[0].URL, "source should have a URL")
207 })
208
209 t.Run("stream", func(t *testing.T) {
210 r := vcr.NewRecorder(t)
211
212 lm, err := anthropicBuilder(model)(t, r)
213 require.NoError(t, err)
214
215 agent := fantasy.NewAgent(
216 lm,
217 fantasy.WithSystemPrompt("You are a helpful assistant"),
218 fantasy.WithProviderDefinedTools(webSearchTool),
219 )
220
221 // Turn 1: initial query triggers web search.
222 result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
223 Prompt: "What is the current population of Tokyo? Cite your source.",
224 MaxOutputTokens: fantasy.Opt(int64(4000)),
225 })
226 require.NoError(t, err)
227
228 got := result.Response.Content.Text()
229 require.NotEmpty(t, got, "should have a text response")
230 require.Contains(t, got, "Tokyo", "response should mention Tokyo")
231
232 // Verify provider-executed tool calls and results in steps.
233 var providerToolCalls []fantasy.ToolCallContent
234 var providerToolResults []fantasy.ToolResultContent
235 for _, step := range result.Steps {
236 for _, c := range step.Content {
237 switch v := c.(type) {
238 case fantasy.ToolCallContent:
239 if v.ProviderExecuted {
240 providerToolCalls = append(providerToolCalls, v)
241 }
242 case fantasy.ToolResultContent:
243 if v.ProviderExecuted {
244 providerToolResults = append(providerToolResults, v)
245 }
246 }
247 }
248 }
249 require.NotEmpty(t, providerToolCalls, "should have provider-executed tool calls")
250 require.Equal(t, "web_search", providerToolCalls[0].ToolName)
251 require.NotEmpty(t, providerToolResults, "should have provider-executed tool results")
252
253 // Turn 2: follow-up using step messages from turn 1.
254 // This verifies that the web_search_tool_result block
255 // round-trips correctly through toPrompt.
256 var history fantasy.Prompt
257 history = append(history, fantasy.Message{
258 Role: fantasy.MessageRoleUser,
259 Content: []fantasy.MessagePart{fantasy.TextPart{Text: "What is the current population of Tokyo? Cite your source."}},
260 })
261 for _, step := range result.Steps {
262 history = append(history, step.Messages...)
263 }
264
265 result2, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
266 Messages: history,
267 Prompt: "How does that compare to Osaka?",
268 MaxOutputTokens: fantasy.Opt(int64(4000)),
269 })
270 require.NoError(t, err)
271
272 got2 := result2.Response.Content.Text()
273 require.NotEmpty(t, got2, "turn 2 should have a text response")
274 require.Contains(t, got2, "Osaka", "turn 2 response should mention Osaka")
275 })
276}