anthropic_test.go

  1package providertests
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"net/http"
  7	"os"
  8	"testing"
  9
 10	"charm.land/fantasy"
 11	"charm.land/fantasy/providers/anthropic"
 12	"charm.land/x/vcr"
 13	"github.com/stretchr/testify/require"
 14)
 15
 16var anthropicTestModels = []testModel{
 17	{"claude-sonnet-4", "claude-sonnet-4-20250514", true},
 18}
 19
 20func TestAnthropicCommon(t *testing.T) {
 21	var pairs []builderPair
 22	for _, m := range anthropicTestModels {
 23		pairs = append(pairs, builderPair{m.name, anthropicBuilder(m.model), nil, nil})
 24	}
 25	testCommon(t, pairs)
 26}
 27
 28func addAnthropicCaching(ctx context.Context, options fantasy.PrepareStepFunctionOptions) (context.Context, fantasy.PrepareStepResult, error) {
 29	prepared := fantasy.PrepareStepResult{}
 30	prepared.Messages = options.Messages
 31
 32	for i := range prepared.Messages {
 33		prepared.Messages[i].ProviderOptions = nil
 34	}
 35	providerOption := fantasy.ProviderOptions{
 36		anthropic.Name: &anthropic.ProviderCacheControlOptions{
 37			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
 38		},
 39	}
 40
 41	lastSystemRoleInx := 0
 42	systemMessageUpdated := false
 43	for i, msg := range prepared.Messages {
 44		// only add cache control to the last message
 45		if msg.Role == fantasy.MessageRoleSystem {
 46			lastSystemRoleInx = i
 47		} else if !systemMessageUpdated {
 48			prepared.Messages[lastSystemRoleInx].ProviderOptions = providerOption
 49			systemMessageUpdated = true
 50		}
 51		// than add cache control to the last 2 messages
 52		if i > len(prepared.Messages)-3 {
 53			prepared.Messages[i].ProviderOptions = providerOption
 54		}
 55	}
 56	return ctx, prepared, nil
 57}
 58
 59func TestAnthropicCommonWithCacheControl(t *testing.T) {
 60	var pairs []builderPair
 61	for _, m := range anthropicTestModels {
 62		pairs = append(pairs, builderPair{m.name, anthropicBuilder(m.model), nil, addAnthropicCaching})
 63	}
 64	testCommon(t, pairs)
 65}
 66
 67func TestAnthropicThinking(t *testing.T) {
 68	opts := fantasy.ProviderOptions{
 69		anthropic.Name: &anthropic.ProviderOptions{
 70			Thinking: &anthropic.ThinkingProviderOption{
 71				BudgetTokens: 4000,
 72			},
 73		},
 74	}
 75	var pairs []builderPair
 76	for _, m := range anthropicTestModels {
 77		if !m.reasoning {
 78			continue
 79		}
 80		pairs = append(pairs, builderPair{m.name, anthropicBuilder(m.model), opts, nil})
 81	}
 82	testThinking(t, pairs, testAnthropicThinking)
 83}
 84
 85func TestAnthropicThinkingWithCacheControl(t *testing.T) {
 86	opts := fantasy.ProviderOptions{
 87		anthropic.Name: &anthropic.ProviderOptions{
 88			Thinking: &anthropic.ThinkingProviderOption{
 89				BudgetTokens: 4000,
 90			},
 91		},
 92	}
 93	var pairs []builderPair
 94	for _, m := range anthropicTestModels {
 95		if !m.reasoning {
 96			continue
 97		}
 98		pairs = append(pairs, builderPair{m.name, anthropicBuilder(m.model), opts, addAnthropicCaching})
 99	}
100	testThinking(t, pairs, testAnthropicThinking)
101}
102
103func TestAnthropicObjectGeneration(t *testing.T) {
104	var pairs []builderPair
105	for _, m := range anthropicTestModels {
106		pairs = append(pairs, builderPair{m.name, anthropicBuilder(m.model), nil, nil})
107	}
108	testObjectGeneration(t, pairs)
109}
110
111func testAnthropicThinking(t *testing.T, result *fantasy.AgentResult) {
112	reasoningContentCount := 0
113	signaturesCount := 0
114	// Test if we got the signature
115	for _, step := range result.Steps {
116		for _, msg := range step.Messages {
117			for _, content := range msg.Content {
118				if content.GetType() == fantasy.ContentTypeReasoning {
119					reasoningContentCount += 1
120					reasoningContent, ok := fantasy.AsContentType[fantasy.ReasoningPart](content)
121					if !ok {
122						continue
123					}
124					if len(reasoningContent.ProviderOptions) == 0 {
125						continue
126					}
127
128					anthropicReasoningMetadata, ok := reasoningContent.ProviderOptions[anthropic.Name]
129					if !ok {
130						continue
131					}
132					if reasoningContent.Text != "" {
133						if typed, ok := anthropicReasoningMetadata.(*anthropic.ReasoningOptionMetadata); ok {
134							require.NotEmpty(t, typed.Signature)
135							signaturesCount += 1
136						}
137					}
138				}
139			}
140		}
141	}
142	require.Greater(t, reasoningContentCount, 0)
143	require.Greater(t, signaturesCount, 0)
144	require.Equal(t, reasoningContentCount, signaturesCount)
145}
146
147func anthropicBuilder(model string) builderFunc {
148	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
149		provider, err := anthropic.New(
150			anthropic.WithAPIKey(os.Getenv("FANTASY_ANTHROPIC_API_KEY")),
151			anthropic.WithHTTPClient(&http.Client{Transport: r}),
152		)
153		if err != nil {
154			return nil, err
155		}
156		return provider.LanguageModel(t.Context(), model)
157	}
158}
159
160// TestAnthropicWebSearch tests web search tool support via the agent
161// using WithProviderDefinedTools.
162func TestAnthropicWebSearch(t *testing.T) {
163	model := "claude-sonnet-4-20250514"
164	webSearchTool := anthropic.WebSearchTool(nil)
165
166	t.Run("generate", func(t *testing.T) {
167		r := vcr.NewRecorder(t)
168
169		lm, err := anthropicBuilder(model)(t, r)
170		require.NoError(t, err)
171
172		agent := fantasy.NewAgent(
173			lm,
174			fantasy.WithSystemPrompt("You are a helpful assistant"),
175			fantasy.WithProviderDefinedTools(webSearchTool),
176		)
177
178		result, err := agent.Generate(t.Context(), fantasy.AgentCall{
179			Prompt:          "What is the current population of Tokyo? Cite your source.",
180			MaxOutputTokens: fantasy.Opt(int64(4000)),
181		})
182		require.NoError(t, err)
183
184		got := result.Response.Content.Text()
185		require.NotEmpty(t, got, "should have a text response")
186		require.Contains(t, got, "Tokyo", "response should mention Tokyo")
187
188		// Walk the steps and verify web search content was produced.
189		var sources []fantasy.SourceContent
190		var providerToolCalls []fantasy.ToolCallContent
191		for _, step := range result.Steps {
192			for _, c := range step.Content {
193				switch v := c.(type) {
194				case fantasy.ToolCallContent:
195					if v.ProviderExecuted {
196						providerToolCalls = append(providerToolCalls, v)
197					}
198				case fantasy.SourceContent:
199					sources = append(sources, v)
200				}
201			}
202		}
203
204		require.NotEmpty(t, providerToolCalls, "should have provider-executed tool calls")
205		require.Equal(t, "web_search", providerToolCalls[0].ToolName)
206		require.NotEmpty(t, sources, "should have source citations")
207		require.NotEmpty(t, sources[0].URL, "source should have a URL")
208	})
209
210	t.Run("stream", func(t *testing.T) {
211		r := vcr.NewRecorder(t)
212
213		lm, err := anthropicBuilder(model)(t, r)
214		require.NoError(t, err)
215
216		agent := fantasy.NewAgent(
217			lm,
218			fantasy.WithSystemPrompt("You are a helpful assistant"),
219			fantasy.WithProviderDefinedTools(webSearchTool),
220		)
221
222		// Turn 1: initial query triggers web search.
223		result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
224			Prompt:          "What is the current population of Tokyo? Cite your source.",
225			MaxOutputTokens: fantasy.Opt(int64(4000)),
226		})
227		require.NoError(t, err)
228
229		got := result.Response.Content.Text()
230		require.NotEmpty(t, got, "should have a text response")
231		require.Contains(t, got, "Tokyo", "response should mention Tokyo")
232
233		// Verify provider-executed tool calls and results in steps.
234		var providerToolCalls []fantasy.ToolCallContent
235		var providerToolResults []fantasy.ToolResultContent
236		for _, step := range result.Steps {
237			for _, c := range step.Content {
238				switch v := c.(type) {
239				case fantasy.ToolCallContent:
240					if v.ProviderExecuted {
241						providerToolCalls = append(providerToolCalls, v)
242					}
243				case fantasy.ToolResultContent:
244					if v.ProviderExecuted {
245						providerToolResults = append(providerToolResults, v)
246					}
247				}
248			}
249		}
250		require.NotEmpty(t, providerToolCalls, "should have provider-executed tool calls")
251		require.Equal(t, "web_search", providerToolCalls[0].ToolName)
252		require.NotEmpty(t, providerToolResults, "should have provider-executed tool results")
253
254		// Turn 2: follow-up using step messages from turn 1.
255		// This verifies that the web_search_tool_result block
256		// round-trips correctly through toPrompt.
257		var history fantasy.Prompt
258		history = append(history, fantasy.Message{
259			Role:    fantasy.MessageRoleUser,
260			Content: []fantasy.MessagePart{fantasy.TextPart{Text: "What is the current population of Tokyo? Cite your source."}},
261		})
262		for _, step := range result.Steps {
263			history = append(history, step.Messages...)
264		}
265
266		result2, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{
267			Messages:        history,
268			Prompt:          "How does that compare to Osaka?",
269			MaxOutputTokens: fantasy.Opt(int64(4000)),
270		})
271		require.NoError(t, err)
272
273		got2 := result2.Response.Content.Text()
274		require.NotEmpty(t, got2, "turn 2 should have a text response")
275		require.Contains(t, got2, "Osaka", "turn 2 response should mention Osaka")
276	})
277}
278
const (
	// screenshotBase64 is a base64-encoded 1x1 PNG image. It serves as the
	// stub screenshot returned to the model in the computer use tests.
	screenshotBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
)
282
283func TestAnthropicComputerUse(t *testing.T) {
284	type computerUseModel struct {
285		name        string
286		model       string
287		toolVersion anthropic.ComputerUseToolVersion
288	}
289	computerUseModels := []computerUseModel{
290		{"claude-sonnet-4", "claude-sonnet-4-20250514", anthropic.ComputerUse20250124},
291		{"claude-opus-4-6", "claude-opus-4-6", anthropic.ComputerUse20251124},
292	}
293	for _, m := range computerUseModels {
294		t.Run(m.name, func(t *testing.T) {
295			t.Run("computer use", func(t *testing.T) {
296				r := vcr.NewRecorder(t)
297
298				model, err := anthropicBuilder(m.model)(t, r)
299				require.NoError(t, err)
300
301				cuTool := jsonRoundTripTool(t, anthropic.NewComputerUseTool(anthropic.ComputerUseToolOptions{
302					DisplayWidthPx:  1920,
303					DisplayHeightPx: 1080,
304					ToolVersion:     m.toolVersion,
305				}, noopComputerRun))
306
307				// First call: expect a screenshot tool call.
308				resp, err := model.Generate(t.Context(), fantasy.Call{
309					Prompt: fantasy.Prompt{
310						{Role: fantasy.MessageRoleSystem, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "You are a helpful assistant"}}},
311						{Role: fantasy.MessageRoleUser, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "Take a screenshot of the desktop"}}},
312					},
313					Tools: []fantasy.Tool{cuTool},
314				})
315				require.NoError(t, err)
316				require.Equal(t, fantasy.FinishReasonToolCalls, resp.FinishReason)
317
318				toolCalls := resp.Content.ToolCalls()
319				require.Len(t, toolCalls, 1)
320				require.Equal(t, "computer", toolCalls[0].ToolName)
321				require.Contains(t, toolCalls[0].Input, "screenshot")
322
323				// Second call: send the tool result back, expect text.
324				resp2, err := model.Generate(t.Context(), fantasy.Call{
325					Prompt: fantasy.Prompt{
326						{Role: fantasy.MessageRoleSystem, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "You are a helpful assistant"}}},
327						{Role: fantasy.MessageRoleUser, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "Take a screenshot of the desktop"}}},
328						{
329							Role: fantasy.MessageRoleAssistant,
330							Content: []fantasy.MessagePart{
331								fantasy.ToolCallPart{
332									ToolCallID: toolCalls[0].ToolCallID,
333									ToolName:   toolCalls[0].ToolName,
334									Input:      toolCalls[0].Input,
335								},
336							},
337						},
338						{
339							Role: fantasy.MessageRoleTool,
340							Content: []fantasy.MessagePart{
341								fantasy.ToolResultPart{
342									ToolCallID: toolCalls[0].ToolCallID,
343									Output: fantasy.ToolResultOutputContentMedia{
344										Data:      screenshotBase64,
345										MediaType: "image/png",
346									},
347								},
348							},
349						},
350					},
351					Tools: []fantasy.Tool{cuTool},
352				})
353				require.NoError(t, err)
354				require.NotEmpty(t, resp2.Content.Text())
355				require.Contains(t, resp2.Content.Text(), "desktop")
356			})
357
358			t.Run("computer use streaming", func(t *testing.T) {
359				r := vcr.NewRecorder(t)
360
361				model, err := anthropicBuilder(m.model)(t, r)
362				require.NoError(t, err)
363
364				cuTool := jsonRoundTripTool(t, anthropic.NewComputerUseTool(anthropic.ComputerUseToolOptions{
365					DisplayWidthPx:  1920,
366					DisplayHeightPx: 1080,
367					ToolVersion:     m.toolVersion,
368				}, noopComputerRun))
369
370				// First call: stream, collect tool call.
371				stream, err := model.Stream(t.Context(), fantasy.Call{
372					Prompt: fantasy.Prompt{
373						{Role: fantasy.MessageRoleSystem, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "You are a helpful assistant"}}},
374						{Role: fantasy.MessageRoleUser, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "Take a screenshot of the desktop"}}},
375					},
376					Tools: []fantasy.Tool{cuTool},
377				})
378				require.NoError(t, err)
379
380				var toolCallID, toolCallName, toolCallInput string
381				var finishReason fantasy.FinishReason
382				stream(func(part fantasy.StreamPart) bool {
383					switch part.Type {
384					case fantasy.StreamPartTypeToolCall:
385						toolCallID = part.ID
386						toolCallName = part.ToolCallName
387						toolCallInput = part.ToolCallInput
388					case fantasy.StreamPartTypeFinish:
389						finishReason = part.FinishReason
390					}
391					return true
392				})
393
394				require.Equal(t, fantasy.FinishReasonToolCalls, finishReason)
395				require.Equal(t, "computer", toolCallName)
396				require.Contains(t, toolCallInput, "screenshot")
397
398				// Second call: send tool result, stream text back.
399				stream2, err := model.Stream(t.Context(), fantasy.Call{
400					Prompt: fantasy.Prompt{
401						{Role: fantasy.MessageRoleSystem, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "You are a helpful assistant"}}},
402						{Role: fantasy.MessageRoleUser, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "Take a screenshot of the desktop"}}},
403						{
404							Role: fantasy.MessageRoleAssistant,
405							Content: []fantasy.MessagePart{
406								fantasy.ToolCallPart{
407									ToolCallID: toolCallID,
408									ToolName:   toolCallName,
409									Input:      toolCallInput,
410								},
411							},
412						},
413						{
414							Role: fantasy.MessageRoleTool,
415							Content: []fantasy.MessagePart{
416								fantasy.ToolResultPart{
417									ToolCallID: toolCallID,
418									Output: fantasy.ToolResultOutputContentMedia{
419										Data:      screenshotBase64,
420										MediaType: "image/png",
421									},
422								},
423							},
424						},
425					},
426					Tools: []fantasy.Tool{cuTool},
427				})
428				require.NoError(t, err)
429
430				var text string
431				stream2(func(part fantasy.StreamPart) bool {
432					if part.Type == fantasy.StreamPartTypeTextDelta {
433						text += part.Delta
434					}
435					return true
436				})
437				require.NotEmpty(t, text)
438				require.Contains(t, text, "desktop")
439			})
440		})
441	}
442}
443
444// noopComputerRun is a no-op run function for tests that only need
445// to inspect the tool definition, not execute it.
446var noopComputerRun = func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
447	return fantasy.ToolResponse{}, nil
448}
449
450// jsonRoundTripTool simulates a JSON round-trip on a ProviderDefinedTool
451// so numeric values become float64 as they would in real usage.
452func jsonRoundTripTool(t *testing.T, tool fantasy.ExecutableProviderTool) fantasy.ProviderDefinedTool {
453	t.Helper()
454	pdt := tool.Definition()
455	data, err := json.Marshal(pdt.Args)
456	require.NoError(t, err)
457	var args map[string]any
458	require.NoError(t, json.Unmarshal(data, &args))
459	pdt.Args = args
460	return pdt
461}