anthropic_test.go

  1package providertests
  2
  3import (
  4	"context"
  5	"net/http"
  6	"os"
  7	"testing"
  8
  9	"charm.land/fantasy"
 10	"charm.land/fantasy/providers/anthropic"
 11	"charm.land/x/vcr"
 12	"github.com/stretchr/testify/require"
 13)
 14
 15var anthropicTestModels = []testModel{
 16	{"claude-sonnet-4", "claude-sonnet-4-20250514", true},
 17}
 18
 19func TestAnthropicCommon(t *testing.T) {
 20	var pairs []builderPair
 21	for _, m := range anthropicTestModels {
 22		pairs = append(pairs, builderPair{m.name, anthropicBuilder(m.model), nil, nil})
 23	}
 24	testCommon(t, pairs)
 25}
 26
 27func addAnthropicCaching(ctx context.Context, options fantasy.PrepareStepFunctionOptions) (context.Context, fantasy.PrepareStepResult, error) {
 28	prepared := fantasy.PrepareStepResult{}
 29	prepared.Messages = options.Messages
 30
 31	for i := range prepared.Messages {
 32		prepared.Messages[i].ProviderOptions = nil
 33	}
 34	providerOption := fantasy.ProviderOptions{
 35		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 36			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 37		},
 38	}
 39
 40	lastSystemRoleInx := 0
 41	systemMessageUpdated := false
 42	for i, msg := range prepared.Messages {
 43		// only add cache control to the last message
 44		if msg.Role == fantasy.MessageRoleSystem {
 45			lastSystemRoleInx = i
 46		} else if !systemMessageUpdated {
 47			prepared.Messages[lastSystemRoleInx].ProviderOptions = providerOption
 48			systemMessageUpdated = true
 49		}
 50		// than add cache control to the last 2 messages
 51		if i > len(prepared.Messages)-3 {
 52			prepared.Messages[i].ProviderOptions = providerOption
 53		}
 54	}
 55	return ctx, prepared, nil
 56}
 57
 58func TestAnthropicCommonWithCacheControl(t *testing.T) {
 59	var pairs []builderPair
 60	for _, m := range anthropicTestModels {
 61		pairs = append(pairs, builderPair{m.name, anthropicBuilder(m.model), nil, addAnthropicCaching})
 62	}
 63	testCommon(t, pairs)
 64}
 65
 66func TestAnthropicThinking(t *testing.T) {
 67	opts := fantasy.ProviderOptions{
 68		anthropic.Name: &anthropic.ProviderOptions{
 69			Thinking: &anthropic.ThinkingProviderOption{
 70				BudgetTokens: 4000,
 71			},
 72		},
 73	}
 74	var pairs []builderPair
 75	for _, m := range anthropicTestModels {
 76		if !m.reasoning {
 77			continue
 78		}
 79		pairs = append(pairs, builderPair{m.name, anthropicBuilder(m.model), opts, nil})
 80	}
 81	testThinking(t, pairs, testAnthropicThinking)
 82}
 83
 84func TestAnthropicThinkingWithCacheControl(t *testing.T) {
 85	opts := fantasy.ProviderOptions{
 86		anthropic.Name: &anthropic.ProviderOptions{
 87			Thinking: &anthropic.ThinkingProviderOption{
 88				BudgetTokens: 4000,
 89			},
 90		},
 91	}
 92	var pairs []builderPair
 93	for _, m := range anthropicTestModels {
 94		if !m.reasoning {
 95			continue
 96		}
 97		pairs = append(pairs, builderPair{m.name, anthropicBuilder(m.model), opts, addAnthropicCaching})
 98	}
 99	testThinking(t, pairs, testAnthropicThinking)
100}
101
102func TestAnthropicObjectGeneration(t *testing.T) {
103	var pairs []builderPair
104	for _, m := range anthropicTestModels {
105		pairs = append(pairs, builderPair{m.name, anthropicBuilder(m.model), nil, nil})
106	}
107	testObjectGeneration(t, pairs)
108}
109
110func testAnthropicThinking(t *testing.T, result *fantasy.AgentResult) {
111	reasoningContentCount := 0
112	signaturesCount := 0
113	// Test if we got the signature
114	for _, step := range result.Steps {
115		for _, msg := range step.Messages {
116			for _, content := range msg.Content {
117				if content.GetType() == fantasy.ContentTypeReasoning {
118					reasoningContentCount += 1
119					reasoningContent, ok := fantasy.AsContentType[fantasy.ReasoningPart](content)
120					if !ok {
121						continue
122					}
123					if len(reasoningContent.ProviderOptions) == 0 {
124						continue
125					}
126
127					anthropicReasoningMetadata, ok := reasoningContent.ProviderOptions[anthropic.Name]
128					if !ok {
129						continue
130					}
131					if reasoningContent.Text != "" {
132						if typed, ok := anthropicReasoningMetadata.(*anthropic.ReasoningOptionMetadata); ok {
133							require.NotEmpty(t, typed.Signature)
134							signaturesCount += 1
135						}
136					}
137				}
138			}
139		}
140	}
141	require.Greater(t, reasoningContentCount, 0)
142	require.Greater(t, signaturesCount, 0)
143	require.Equal(t, reasoningContentCount, signaturesCount)
144}
145
146func anthropicBuilder(model string) builderFunc {
147	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
148		provider, err := anthropic.New(
149			anthropic.WithAPIKey(os.Getenv("FANTASY_ANTHROPIC_API_KEY")),
150			anthropic.WithHTTPClient(&http.Client{Transport: r}),
151		)
152		if err != nil {
153			return nil, err
154		}
155		return provider.LanguageModel(t.Context(), model)
156	}
157}
158
159// TestAnthropicWebSearch tests web search tool support via the agent
160// using WithProviderDefinedTools.
161func TestAnthropicWebSearch(t *testing.T) {
162	model := "claude-sonnet-4-20250514"
163	webSearchTool := anthropic.WebSearchTool(nil)
164
165	t.Run("generate", func(t *testing.T) {
166		r := vcr.NewRecorder(t)
167
168		lm, err := anthropicBuilder(model)(t, r)
169		require.NoError(t, err)
170
171		agent := fantasy.NewAgent(
172			lm,
173			fantasy.WithSystemPrompt("You are a helpful assistant"),
174			fantasy.WithProviderDefinedTools(webSearchTool),
175		)
176
177		result, err := agent.Generate(t.Context(), fantasy.AgentCall{
178			Prompt:          "What is the current population of Tokyo? Cite your source.",
179			MaxOutputTokens: fantasy.Opt(int64(4000)),
180		})
181		require.NoError(t, err)
182
183		got := result.Response.Content.Text()
184		require.NotEmpty(t, got, "should have a text response")
185		require.Contains(t, got, "Tokyo", "response should mention Tokyo")
186
187		// Walk the steps and verify web search content was produced.
188		var sources []fantasy.SourceContent
189		var providerToolCalls []fantasy.ToolCallContent
190		for _, step := range result.Steps {
191			for _, c := range step.Content {
192				switch v := c.(type) {
193				case fantasy.ToolCallContent:
194					if v.ProviderExecuted {
195						providerToolCalls = append(providerToolCalls, v)
196					}
197				case fantasy.SourceContent:
198					sources = append(sources, v)
199				}
200			}
201		}
202
203		require.NotEmpty(t, providerToolCalls, "should have provider-executed tool calls")
204		require.Equal(t, "web_search", providerToolCalls[0].ToolName)
205		require.NotEmpty(t, sources, "should have source citations")
206		require.NotEmpty(t, sources[0].URL, "source should have a URL")
207	})
208
209	t.Run("stream", func(t *testing.T) {
210		r := vcr.NewRecorder(t)
211
212		lm, err := anthropicBuilder(model)(t, r)
213		require.NoError(t, err)
214
215		agent := fantasy.NewAgent(
216			lm,
217			fantasy.WithSystemPrompt("You are a helpful assistant"),
218			fantasy.WithProviderDefinedTools(webSearchTool),
219		)
220
221		// Turn 1: initial query triggers web search.
222		result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
223			Prompt:          "What is the current population of Tokyo? Cite your source.",
224			MaxOutputTokens: fantasy.Opt(int64(4000)),
225		})
226		require.NoError(t, err)
227
228		got := result.Response.Content.Text()
229		require.NotEmpty(t, got, "should have a text response")
230		require.Contains(t, got, "Tokyo", "response should mention Tokyo")
231
232		// Verify provider-executed tool calls and results in steps.
233		var providerToolCalls []fantasy.ToolCallContent
234		var providerToolResults []fantasy.ToolResultContent
235		for _, step := range result.Steps {
236			for _, c := range step.Content {
237				switch v := c.(type) {
238				case fantasy.ToolCallContent:
239					if v.ProviderExecuted {
240						providerToolCalls = append(providerToolCalls, v)
241					}
242				case fantasy.ToolResultContent:
243					if v.ProviderExecuted {
244						providerToolResults = append(providerToolResults, v)
245					}
246				}
247			}
248		}
249		require.NotEmpty(t, providerToolCalls, "should have provider-executed tool calls")
250		require.Equal(t, "web_search", providerToolCalls[0].ToolName)
251		require.NotEmpty(t, providerToolResults, "should have provider-executed tool results")
252
253		// Turn 2: follow-up using step messages from turn 1.
254		// This verifies that the web_search_tool_result block
255		// round-trips correctly through toPrompt.
256		var history fantasy.Prompt
257		history = append(history, fantasy.Message{
258			Role:    fantasy.MessageRoleUser,
259			Content: []fantasy.MessagePart{fantasy.TextPart{Text: "What is the current population of Tokyo? Cite your source."}},
260		})
261		for _, step := range result.Steps {
262			history = append(history, step.Messages...)
263		}
264
265		result2, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
266			Messages:        history,
267			Prompt:          "How does that compare to Osaka?",
268			MaxOutputTokens: fantasy.Opt(int64(4000)),
269		})
270		require.NoError(t, err)
271
272		got2 := result2.Response.Content.Text()
273		require.NotEmpty(t, got2, "turn 2 should have a text response")
274		require.Contains(t, got2, "Osaka", "turn 2 response should mention Osaka")
275	})
276}