openrouter_test.go

  1package providertests
  2
  3import (
  4	"context"
  5	"fmt"
  6	"net/http"
  7	"os"
  8	"strconv"
  9	"strings"
 10	"testing"
 11
 12	"github.com/charmbracelet/fantasy/ai"
 13	"github.com/charmbracelet/fantasy/openrouter"
 14	"github.com/stretchr/testify/require"
 15	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 16)
 17
 18var openrouterTestModels = []testModel{
 19	{"kimi-k2", "moonshotai/kimi-k2-0905", false},
 20	{"grok-code-fast-1", "x-ai/grok-code-fast-1", false},
 21	{"claude-sonnet-4", "anthropic/claude-sonnet-4", true},
 22	{"gemini-2.5-flash", "google/gemini-2.5-flash", false},
 23	{"deepseek-chat-v3.1-free", "deepseek/deepseek-chat-v3.1:free", false},
 24	{"qwen3-235b-a22b-2507", "qwen/qwen3-235b-a22b-2507", false},
 25	{"gpt-5", "openai/gpt-5", true},
 26	{"glm-4.5", "z-ai/glm-4.5", false},
 27}
 28
 29func TestOpenRouterCommon(t *testing.T) {
 30	var pairs []builderPair
 31	for _, m := range openrouterTestModels {
 32		pairs = append(pairs, builderPair{m.name, openrouterBuilder(m.model), nil})
 33	}
 34	testCommon(t, pairs)
 35}
 36
 37func TestOpenRouterThinking(t *testing.T) {
 38	opts := ai.ProviderOptions{
 39		openrouter.Name: &openrouter.ProviderOptions{
 40			Reasoning: &openrouter.ReasoningOptions{
 41				Effort: openrouter.ReasoningEffortOption(openrouter.ReasoningEffortMedium),
 42			},
 43		},
 44	}
 45
 46	var pairs []builderPair
 47	for _, m := range openrouterTestModels {
 48		if !m.reasoning {
 49			continue
 50		}
 51		pairs = append(pairs, builderPair{m.name, openrouterBuilder(m.model), opts})
 52	}
 53	testThinking(t, pairs, testOpenrouterThinking)
 54}
 55
 56func testOpenrouterThinking(t *testing.T, result *ai.AgentResult) {
 57	reasoningContentCount := 0
 58	for _, step := range result.Steps {
 59		for _, msg := range step.Messages {
 60			for _, content := range msg.Content {
 61				if content.GetType() == ai.ContentTypeReasoning {
 62					reasoningContentCount += 1
 63				}
 64			}
 65		}
 66	}
 67	require.Greater(t, reasoningContentCount, 0)
 68}
 69
 70func TestOpenRouterWithUniqueToolCallIDs(t *testing.T) {
 71	type CalculatorInput struct {
 72		A int `json:"a" description:"first number"`
 73		B int `json:"b" description:"second number"`
 74	}
 75
 76	addTool := ai.NewAgentTool(
 77		"add",
 78		"Add two numbers",
 79		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
 80			result := input.A + input.B
 81			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
 82		},
 83	)
 84	multiplyTool := ai.NewAgentTool(
 85		"multiply",
 86		"Multiply two numbers",
 87		func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
 88			result := input.A * input.B
 89			return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
 90		},
 91	)
 92	checkResult := func(t *testing.T, result *ai.AgentResult) {
 93		require.Len(t, result.Steps, 2)
 94
 95		var toolCalls []ai.ToolCallContent
 96		for _, content := range result.Steps[0].Content {
 97			if content.GetType() == ai.ContentTypeToolCall {
 98				toolCalls = append(toolCalls, content.(ai.ToolCallContent))
 99			}
100		}
101		for _, tc := range toolCalls {
102			require.False(t, tc.Invalid)
103			require.Contains(t, tc.ToolCallID, "test-")
104		}
105		require.Len(t, toolCalls, 2)
106
107		finalText := result.Response.Content.Text()
108		require.Contains(t, finalText, "5", "expected response to contain '5', got: %q", finalText)
109		require.Contains(t, finalText, "6", "expected response to contain '6', got: %q", finalText)
110	}
111
112	id := 0
113	generateIDFunc := func() string {
114		id += 1
115		return fmt.Sprintf("test-%d", id)
116	}
117
118	t.Run("unique tool call ids", func(t *testing.T) {
119		r := newRecorder(t)
120
121		provider := openrouter.New(
122			openrouter.WithAPIKey(os.Getenv("OPENROUTER_API_KEY")),
123			openrouter.WithHTTPClient(&http.Client{Transport: r}),
124			openrouter.WithLanguageUniqueToolCallIds(),
125			openrouter.WithLanguageModelGenerateIDFunc(generateIDFunc),
126		)
127		languageModel, err := provider.LanguageModel("moonshotai/kimi-k2-0905")
128		require.NoError(t, err, "failed to build language model")
129
130		agent := ai.NewAgent(
131			languageModel,
132			ai.WithSystemPrompt("You are a helpful assistant. CRITICAL: Always use both add and multiply at the same time ALWAYS."),
133			ai.WithTools(addTool),
134			ai.WithTools(multiplyTool),
135		)
136		result, err := agent.Generate(t.Context(), ai.AgentCall{
137			Prompt:          "Add and multiply the number 2 and 3",
138			MaxOutputTokens: ai.IntOption(4000),
139		})
140		require.NoError(t, err, "failed to generate")
141		checkResult(t, result)
142	})
143	t.Run("stream unique tool call ids", func(t *testing.T) {
144		r := newRecorder(t)
145
146		provider := openrouter.New(
147			openrouter.WithAPIKey(os.Getenv("OPENROUTER_API_KEY")),
148			openrouter.WithHTTPClient(&http.Client{Transport: r}),
149			openrouter.WithLanguageUniqueToolCallIds(),
150			openrouter.WithLanguageModelGenerateIDFunc(generateIDFunc),
151		)
152		languageModel, err := provider.LanguageModel("moonshotai/kimi-k2-0905")
153		require.NoError(t, err, "failed to build language model")
154
155		agent := ai.NewAgent(
156			languageModel,
157			ai.WithSystemPrompt("You are a helpful assistant. Always use both add and multiply at the same time."),
158			ai.WithTools(addTool),
159			ai.WithTools(multiplyTool),
160		)
161		result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
162			Prompt:          "Add and multiply the number 2 and 3",
163			MaxOutputTokens: ai.IntOption(4000),
164		})
165		require.NoError(t, err, "failed to generate")
166		checkResult(t, result)
167	})
168}
169
170func openrouterBuilder(model string) builderFunc {
171	return func(r *recorder.Recorder) (ai.LanguageModel, error) {
172		provider := openrouter.New(
173			openrouter.WithAPIKey(os.Getenv("OPENROUTER_API_KEY")),
174			openrouter.WithHTTPClient(&http.Client{Transport: r}),
175		)
176		return provider.LanguageModel(model)
177	}
178}