openrouter_test.go

  1package providertests
  2
  3import (
  4	"context"
  5	"fmt"
  6	"net/http"
  7	"os"
  8	"strconv"
  9	"strings"
 10	"testing"
 11
 12	"github.com/charmbracelet/fantasy/ai"
 13	"github.com/charmbracelet/fantasy/openrouter"
 14	"github.com/stretchr/testify/require"
 15	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 16)
 17
 18var openrouterTestModels = []testModel{
 19	{"kimi-k2", "moonshotai/kimi-k2-0905", false},
 20	{"grok-code-fast-1", "x-ai/grok-code-fast-1", false},
 21	{"claude-sonnet-4", "anthropic/claude-sonnet-4", true},
 22	{"gemini-2.5-flash", "google/gemini-2.5-flash", false},
 23	{"deepseek-chat-v3.1-free", "deepseek/deepseek-chat-v3.1:free", false},
 24	{"qwen3-235b-a22b-2507", "qwen/qwen3-235b-a22b-2507", false},
 25	{"gpt-5", "openai/gpt-5", true},
 26	{"glm-4.5", "z-ai/glm-4.5", false},
 27}
 28
 29func TestOpenRouterCommon(t *testing.T) {
 30	var pairs []builderPair
 31	for _, m := range openrouterTestModels {
 32		pairs = append(pairs, builderPair{m.name, openrouterBuilder(m.model), nil})
 33	}
 34	testCommon(t, pairs)
 35}
 36
 37func TestOpenRouterThinking(t *testing.T) {
 38	opts := ai.ProviderOptions{
 39		openrouter.Name: &openrouter.ProviderOptions{
 40			Reasoning: &openrouter.ReasoningOptions{
 41				Effort: openrouter.ReasoningEffortOption(openrouter.ReasoningEffortMedium),
 42			},
 43		},
 44	}
 45
 46	var pairs []builderPair
 47	for _, m := range openrouterTestModels {
 48		if !m.reasoning {
 49			continue
 50		}
 51		pairs = append(pairs, builderPair{m.name, openrouterBuilder(m.model), opts})
 52	}
 53	testThinking(t, pairs, testOpenrouterThinking)
 54
 55	// test anthropic signature
 56	testThinking(t, []builderPair{
 57		{"claude-sonnet-4-sig", openrouterBuilder("anthropic/claude-sonnet-4"), opts},
 58	}, testOpenrouterThinkingWithSignature)
 59}
 60
 61func testOpenrouterThinkingWithSignature(t *testing.T, result *ai.AgentResult) {
 62	reasoningContentCount := 0
 63	signaturesCount := 0
 64	// Test if we got the signature
 65	for _, step := range result.Steps {
 66		for _, msg := range step.Messages {
 67			for _, content := range msg.Content {
 68				if content.GetType() == ai.ContentTypeReasoning {
 69					reasoningContentCount += 1
 70					reasoningContent, ok := ai.AsContentType[ai.ReasoningPart](content)
 71					if !ok {
 72						continue
 73					}
 74					if len(reasoningContent.ProviderOptions) == 0 {
 75						continue
 76					}
 77
 78					anthropicReasoningMetadata, ok := reasoningContent.ProviderOptions[openrouter.Name]
 79					if !ok {
 80						continue
 81					}
 82					if reasoningContent.Text != "" {
 83						if typed, ok := anthropicReasoningMetadata.(*openrouter.ReasoningMetadata); ok {
 84							require.NotEmpty(t, typed.Signature)
 85							signaturesCount += 1
 86						}
 87					}
 88				}
 89			}
 90		}
 91	}
 92	require.Greater(t, reasoningContentCount, 0)
 93	require.Greater(t, signaturesCount, 0)
 94	require.Equal(t, reasoningContentCount, signaturesCount)
 95	// we also add the anthropic metadata so test that
 96	testAnthropicThinking(t, result)
 97}
 98
 99func testOpenrouterThinking(t *testing.T, result *ai.AgentResult) {
100	reasoningContentCount := 0
101	for _, step := range result.Steps {
102		for _, msg := range step.Messages {
103			for _, content := range msg.Content {
104				if content.GetType() == ai.ContentTypeReasoning {
105					reasoningContentCount += 1
106				}
107			}
108		}
109	}
110	require.Greater(t, reasoningContentCount, 0)
111}
112
// TestOpenRouterWithUniqueToolCallIDs verifies that, when the provider is
// configured with WithLanguageUniqueToolCallIds and a custom ID generator, the
// tool calls returned by moonshotai/kimi-k2-0905 carry generator-produced IDs
// ("test-N"). It covers both the Generate and the Stream agent paths.
func TestOpenRouterWithUniqueToolCallIDs(t *testing.T) {
	type CalculatorInput struct {
		A int `json:"a" description:"first number"`
		B int `json:"b" description:"second number"`
	}

	addTool := ai.NewAgentTool(
		"add",
		"Add two numbers",
		func(_ context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
			result := input.A + input.B
			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
		},
	)
	multiplyTool := ai.NewAgentTool(
		"multiply",
		"Multiply two numbers",
		func(_ context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
			result := input.A * input.B
			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
		},
	)

	// checkResult expects exactly two steps. The first must contain two valid
	// tool calls whose IDs came from generateIDFunc. The final response text
	// must contain both 5 (2+3) and 6 (2*3).
	checkResult := func(t *testing.T, result *ai.AgentResult) {
		t.Helper()
		require.Len(t, result.Steps, 2)

		var toolCalls []ai.ToolCallContent
		for _, content := range result.Steps[0].Content {
			if content.GetType() != ai.ContentTypeToolCall {
				continue
			}
			// A checked assertion reports an unexpected concrete type as a
			// test failure instead of panicking.
			tc, ok := content.(ai.ToolCallContent)
			require.True(t, ok, "unexpected tool call content type %T", content)
			toolCalls = append(toolCalls, tc)
		}
		require.Len(t, toolCalls, 2)
		for _, tc := range toolCalls {
			require.False(t, tc.Invalid)
			require.Contains(t, tc.ToolCallID, "test-")
		}

		finalText := result.Response.Content.Text()
		require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
		require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
	}

	// generateIDFunc is shared by both subtests, so its counter keeps
	// increasing across them: the stream subtest does not restart at "test-1".
	// NOTE(review): the generated IDs are presumably echoed back to the API in
	// follow-up requests, so recorded cassettes may depend on this sequence.
	// Confirm before giving each subtest its own counter.
	id := 0
	generateIDFunc := func() string {
		id++
		return fmt.Sprintf("test-%d", id)
	}

	// newLanguageModel builds the kimi-k2 model for one subtest. The model
	// routes HTTP through that subtest's recorder, has unique tool call IDs
	// enabled, and uses generateIDFunc as its ID source.
	newLanguageModel := func(t *testing.T) ai.LanguageModel {
		t.Helper()
		r := newRecorder(t)

		provider := openrouter.New(
			openrouter.WithAPIKey(os.Getenv("OPENROUTER_API_KEY")),
			openrouter.WithHTTPClient(&http.Client{Transport: r}),
			openrouter.WithLanguageUniqueToolCallIds(),
			openrouter.WithLanguageModelGenerateIDFunc(generateIDFunc),
		)
		languageModel, err := provider.LanguageModel("moonshotai/kimi-k2-0905")
		require.NoError(t, err, "failed to build language model")
		return languageModel
	}

	t.Run("unique tool call ids", func(t *testing.T) {
		agent := ai.NewAgent(
			newLanguageModel(t),
			ai.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."),
			ai.WithTools(addTool),
			ai.WithTools(multiplyTool),
		)
		result, err := agent.Generate(t.Context(), ai.AgentCall{
			Prompt:          "Add and multiply the number 2 and 3",
			MaxOutputTokens: ai.IntOption(4000),
		})
		require.NoError(t, err, "failed to generate")
		checkResult(t, result)
	})
	t.Run("stream unique tool call ids", func(t *testing.T) {
		agent := ai.NewAgent(
			newLanguageModel(t),
			ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
			ai.WithTools(addTool),
			ai.WithTools(multiplyTool),
		)
		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
			Prompt:          "Add and multiply the number 2 and 3",
			MaxOutputTokens: ai.IntOption(4000),
		})
		require.NoError(t, err, "failed to stream")
		checkResult(t, result)
	})
}
212
// openrouterBuilder returns a builderFunc that creates the OpenRouter
// language model identified by model. The API key is read from the
// OPENROUTER_API_KEY environment variable, and all HTTP traffic is sent
// through the recorder passed to the returned function.
func openrouterBuilder(model string) builderFunc {
	return func(r *recorder.Recorder) (ai.LanguageModel, error) {
		apiKey := os.Getenv("OPENROUTER_API_KEY")
		client := &http.Client{Transport: r}
		p := openrouter.New(
			openrouter.WithAPIKey(apiKey),
			openrouter.WithHTTPClient(client),
		)
		return p.LanguageModel(model)
	}
}