Detailed changes
@@ -1,7 +1,11 @@
package fantasy
+import (
+ "context"
+)
+
// Provider represents a provider of language models.
type Provider interface {
Name() string
- LanguageModel(modelID string) (LanguageModel, error)
+ LanguageModel(ctx context.Context, modelID string) (LanguageModel, error)
}
@@ -120,7 +120,7 @@ func WithHTTPClient(client option.HTTPClient) Option {
}
}
-func (a *provider) LanguageModel(modelID string) (fantasy.LanguageModel, error) {
+func (a *provider) LanguageModel(ctx context.Context, modelID string) (fantasy.LanguageModel, error) {
clientOptions := make([]option.RequestOption, 0, 5+len(a.options.headers))
if a.options.apiKey != "" && !a.options.useBedrock {
clientOptions = append(clientOptions, option.WithAPIKey(a.options.apiKey))
@@ -140,7 +140,7 @@ func (a *provider) LanguageModel(modelID string) (fantasy.LanguageModel, error)
credentials = &google.Credentials{TokenSource: &googleDummyTokenSource{}}
} else {
var err error
- credentials, err = google.FindDefaultCredentials(context.TODO())
+ credentials, err = google.FindDefaultCredentials(ctx)
if err != nil {
return nil, err
}
@@ -149,7 +149,7 @@ func (a *provider) LanguageModel(modelID string) (fantasy.LanguageModel, error)
clientOptions = append(
clientOptions,
vertex.WithCredentials(
- context.TODO(),
+ ctx,
a.options.vertexLocation,
a.options.vertexProject,
credentials,
@@ -165,7 +165,7 @@ func (a *provider) LanguageModel(modelID string) (fantasy.LanguageModel, error)
} else {
clientOptions = append(
clientOptions,
- bedrock.WithLoadDefaultConfig(context.TODO()),
+ bedrock.WithLoadDefaultConfig(ctx),
)
}
}
@@ -126,7 +126,7 @@ type languageModel struct {
}
// LanguageModel implements fantasy.Provider.
-func (a *provider) LanguageModel(modelID string) (fantasy.LanguageModel, error) {
+func (a *provider) LanguageModel(ctx context.Context, modelID string) (fantasy.LanguageModel, error) {
if strings.Contains(modelID, "anthropic") || strings.Contains(modelID, "claude") {
p, err := anthropic.New(
anthropic.WithVertex(a.options.project, a.options.location),
@@ -136,7 +136,7 @@ func (a *provider) LanguageModel(modelID string) (fantasy.LanguageModel, error)
if err != nil {
return nil, err
}
- return p.LanguageModel(modelID)
+ return p.LanguageModel(ctx, modelID)
}
cc := &genai.ClientConfig{
@@ -160,7 +160,7 @@ func (a *provider) LanguageModel(modelID string) (fantasy.LanguageModel, error)
Headers: headers,
}
}
- client, err := genai.NewClient(context.Background(), cc)
+ client, err := genai.NewClient(ctx, cc)
if err != nil {
return nil, err
}
@@ -3,6 +3,7 @@ package openai
import (
"cmp"
+ "context"
"maps"
"charm.land/fantasy"
@@ -131,7 +132,7 @@ func WithUseResponsesAPI() Option {
}
// LanguageModel implements fantasy.Provider.
-func (o *provider) LanguageModel(modelID string) (fantasy.LanguageModel, error) {
+func (o *provider) LanguageModel(_ context.Context, modelID string) (fantasy.LanguageModel, error) {
openaiClientOptions := make([]option.RequestOption, 0, 5+len(o.options.headers)+len(o.options.sdkOptions))
if o.options.apiKey != "" {
@@ -814,7 +814,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
result, err := model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -847,7 +847,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
result, err := model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -872,7 +872,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -913,7 +913,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
result, err := model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -940,7 +940,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
result, err := model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -975,7 +975,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
result, err := model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1000,7 +1000,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
result, err := model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1025,7 +1025,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1058,7 +1058,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1101,7 +1101,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("o1-mini")
+ model, _ := provider.LanguageModel(t.Context(), "o1-mini")
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1142,7 +1142,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-4o")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-4o")
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1181,7 +1181,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1254,7 +1254,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
result, err := model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1313,7 +1313,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
result, err := model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1356,7 +1356,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-4o-mini")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")
result, err := model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1392,7 +1392,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-4o-mini")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")
result, err := model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1421,7 +1421,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("o1-preview")
+ model, _ := provider.LanguageModel(t.Context(), "o1-preview")
result, err := model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1470,7 +1470,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("o1-preview")
+ model, _ := provider.LanguageModel(t.Context(), "o1-preview")
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1515,7 +1515,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("o1-preview")
+ model, _ := provider.LanguageModel(t.Context(), "o1-preview")
result, err := model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1543,7 +1543,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("o1-preview")
+ model, _ := provider.LanguageModel(t.Context(), "o1-preview")
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1582,7 +1582,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1627,7 +1627,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1666,7 +1666,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1709,7 +1709,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1748,7 +1748,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1785,7 +1785,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-4o-search-preview")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-4o-search-preview")
result, err := model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1820,7 +1820,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("o3-mini")
+ model, _ := provider.LanguageModel(t.Context(), "o3-mini")
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1857,7 +1857,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-4o-mini")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")
result, err := model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1891,7 +1891,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-4o-mini")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -1928,7 +1928,7 @@ func TestDoGenerate(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
result, err := model.Generate(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -2236,7 +2236,7 @@ func TestDoStream(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
stream, err := model.Stream(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -2293,7 +2293,7 @@ func TestDoStream(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
stream, err := model.Stream(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -2380,7 +2380,7 @@ func TestDoStream(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
stream, err := model.Stream(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -2420,7 +2420,7 @@ func TestDoStream(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
stream, err := model.Stream(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -2462,7 +2462,7 @@ func TestDoStream(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
_, err = model.Stream(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -2511,7 +2511,7 @@ func TestDoStream(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
stream, err := model.Stream(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -2562,7 +2562,7 @@ func TestDoStream(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
stream, err := model.Stream(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -2606,7 +2606,7 @@ func TestDoStream(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
_, err = model.Stream(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -2649,7 +2649,7 @@ func TestDoStream(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-3.5-turbo")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo")
_, err = model.Stream(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -2696,7 +2696,7 @@ func TestDoStream(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("o3-mini")
+ model, _ := provider.LanguageModel(t.Context(), "o3-mini")
_, err = model.Stream(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -2739,7 +2739,7 @@ func TestDoStream(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("gpt-4o-mini")
+ model, _ := provider.LanguageModel(t.Context(), "gpt-4o-mini")
_, err = model.Stream(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -2783,7 +2783,7 @@ func TestDoStream(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("o1-preview")
+ model, _ := provider.LanguageModel(t.Context(), "o1-preview")
stream, err := model.Stream(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -2830,7 +2830,7 @@ func TestDoStream(t *testing.T) {
WithBaseURL(server.server.URL),
)
require.NoError(t, err)
- model, _ := provider.LanguageModel("o1-preview")
+ model, _ := provider.LanguageModel(t.Context(), "o1-preview")
stream, err := model.Stream(context.Background(), fantasy.Call{
Prompt: testPrompt,
@@ -136,7 +136,7 @@ func testAnthropicThinking(t *testing.T, result *fantasy.AgentResult) {
}
func anthropicBuilder(model string) builderFunc {
- return func(r *recorder.Recorder) (fantasy.LanguageModel, error) {
+ return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
provider, err := anthropic.New(
anthropic.WithAPIKey(os.Getenv("FANTASY_ANTHROPIC_API_KEY")),
anthropic.WithHTTPClient(&http.Client{Transport: r}),
@@ -144,6 +144,6 @@ func anthropicBuilder(model string) builderFunc {
if err != nil {
return nil, err
}
- return provider.LanguageModel(model)
+ return provider.LanguageModel(t.Context(), model)
}
}
@@ -39,7 +39,7 @@ func testAzureThinking(t *testing.T, result *fantasy.AgentResult) {
require.Greater(t, result.Response.Usage.ReasoningTokens, int64(0), "expected reasoning tokens, got none")
}
-func builderAzureO4Mini(r *recorder.Recorder) (fantasy.LanguageModel, error) {
+func builderAzureO4Mini(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
provider, err := azure.New(
azure.WithBaseURL(cmp.Or(os.Getenv("FANTASY_AZURE_BASE_URL"), defaultBaseURL)),
azure.WithAPIKey(cmp.Or(os.Getenv("FANTASY_AZURE_API_KEY"), "(missing)")),
@@ -48,10 +48,10 @@ func builderAzureO4Mini(r *recorder.Recorder) (fantasy.LanguageModel, error) {
if err != nil {
return nil, err
}
- return provider.LanguageModel("o4-mini")
+ return provider.LanguageModel(t.Context(), "o4-mini")
}
-func builderAzureGpt5Mini(r *recorder.Recorder) (fantasy.LanguageModel, error) {
+func builderAzureGpt5Mini(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
provider, err := azure.New(
azure.WithBaseURL(cmp.Or(os.Getenv("FANTASY_AZURE_BASE_URL"), defaultBaseURL)),
azure.WithAPIKey(cmp.Or(os.Getenv("FANTASY_AZURE_API_KEY"), "(missing)")),
@@ -60,10 +60,10 @@ func builderAzureGpt5Mini(r *recorder.Recorder) (fantasy.LanguageModel, error) {
if err != nil {
return nil, err
}
- return provider.LanguageModel("gpt-5-mini")
+ return provider.LanguageModel(t.Context(), "gpt-5-mini")
}
-func builderAzureGrok3Mini(r *recorder.Recorder) (fantasy.LanguageModel, error) {
+func builderAzureGrok3Mini(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
provider, err := azure.New(
azure.WithBaseURL(cmp.Or(os.Getenv("FANTASY_AZURE_BASE_URL"), defaultBaseURL)),
azure.WithAPIKey(cmp.Or(os.Getenv("FANTASY_AZURE_API_KEY"), "(missing)")),
@@ -72,5 +72,5 @@ func builderAzureGrok3Mini(r *recorder.Recorder) (fantasy.LanguageModel, error)
if err != nil {
return nil, err
}
- return provider.LanguageModel("grok-3-mini")
+ return provider.LanguageModel(t.Context(), "grok-3-mini")
}
@@ -22,7 +22,7 @@ func TestBedrockBasicAuth(t *testing.T) {
testSimple(t, builderPair{"bedrock-anthropic-claude-3-sonnet", buildersBedrockBasicAuth, nil, nil})
}
-func builderBedrockClaude3Sonnet(r *recorder.Recorder) (fantasy.LanguageModel, error) {
+func builderBedrockClaude3Sonnet(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
provider, err := bedrock.New(
bedrock.WithHTTPClient(&http.Client{Transport: r}),
bedrock.WithSkipAuth(!r.IsRecording()),
@@ -30,10 +30,10 @@ func builderBedrockClaude3Sonnet(r *recorder.Recorder) (fantasy.LanguageModel, e
if err != nil {
return nil, err
}
- return provider.LanguageModel("us.anthropic.claude-3-sonnet-20240229-v1:0")
+ return provider.LanguageModel(t.Context(), "us.anthropic.claude-3-sonnet-20240229-v1:0")
}
-func builderBedrockClaude3Opus(r *recorder.Recorder) (fantasy.LanguageModel, error) {
+func builderBedrockClaude3Opus(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
provider, err := bedrock.New(
bedrock.WithHTTPClient(&http.Client{Transport: r}),
bedrock.WithSkipAuth(!r.IsRecording()),
@@ -41,10 +41,10 @@ func builderBedrockClaude3Opus(r *recorder.Recorder) (fantasy.LanguageModel, err
if err != nil {
return nil, err
}
- return provider.LanguageModel("us.anthropic.claude-3-opus-20240229-v1:0")
+ return provider.LanguageModel(t.Context(), "us.anthropic.claude-3-opus-20240229-v1:0")
}
-func builderBedrockClaude3Haiku(r *recorder.Recorder) (fantasy.LanguageModel, error) {
+func builderBedrockClaude3Haiku(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
provider, err := bedrock.New(
bedrock.WithHTTPClient(&http.Client{Transport: r}),
bedrock.WithSkipAuth(!r.IsRecording()),
@@ -52,10 +52,10 @@ func builderBedrockClaude3Haiku(r *recorder.Recorder) (fantasy.LanguageModel, er
if err != nil {
return nil, err
}
- return provider.LanguageModel("us.anthropic.claude-3-haiku-20240307-v1:0")
+ return provider.LanguageModel(t.Context(), "us.anthropic.claude-3-haiku-20240307-v1:0")
}
-func buildersBedrockBasicAuth(r *recorder.Recorder) (fantasy.LanguageModel, error) {
+func buildersBedrockBasicAuth(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
provider, err := bedrock.New(
bedrock.WithHTTPClient(&http.Client{Transport: r}),
bedrock.WithAPIKey(os.Getenv("FANTASY_BEDROCK_API_KEY")),
@@ -64,5 +64,5 @@ func buildersBedrockBasicAuth(r *recorder.Recorder) (fantasy.LanguageModel, erro
if err != nil {
return nil, err
}
- return provider.LanguageModel("us.anthropic.claude-3-sonnet-20240229-v1:0")
+ return provider.LanguageModel(t.Context(), "us.anthropic.claude-3-sonnet-20240229-v1:0")
}
@@ -27,7 +27,7 @@ type testModel struct {
reasoning bool
}
-type builderFunc func(r *recorder.Recorder) (fantasy.LanguageModel, error)
+type builderFunc func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error)
type builderPair struct {
name string
@@ -56,7 +56,7 @@ func testSimple(t *testing.T, pair builderPair) {
t.Run("simple", func(t *testing.T) {
r := newRecorder(t)
- languageModel, err := pair.builder(r)
+ languageModel, err := pair.builder(t, r)
require.NoError(t, err, "failed to build language model")
agent := fantasy.NewAgent(
@@ -75,7 +75,7 @@ func testSimple(t *testing.T, pair builderPair) {
t.Run("simple streaming", func(t *testing.T) {
r := newRecorder(t)
- languageModel, err := pair.builder(r)
+ languageModel, err := pair.builder(t, r)
require.NoError(t, err, "failed to build language model")
agent := fantasy.NewAgent(
@@ -129,7 +129,7 @@ func testTool(t *testing.T, pair builderPair) {
t.Run("tool", func(t *testing.T) {
r := newRecorder(t)
- languageModel, err := pair.builder(r)
+ languageModel, err := pair.builder(t, r)
require.NoError(t, err, "failed to build language model")
agent := fantasy.NewAgent(
@@ -149,7 +149,7 @@ func testTool(t *testing.T, pair builderPair) {
t.Run("tool streaming", func(t *testing.T) {
r := newRecorder(t)
- languageModel, err := pair.builder(r)
+ languageModel, err := pair.builder(t, r)
require.NoError(t, err, "failed to build language model")
agent := fantasy.NewAgent(
@@ -226,7 +226,7 @@ func testMultiTool(t *testing.T, pair builderPair) {
t.Run("multi tool", func(t *testing.T) {
r := newRecorder(t)
- languageModel, err := pair.builder(r)
+ languageModel, err := pair.builder(t, r)
require.NoError(t, err, "failed to build language model")
agent := fantasy.NewAgent(
@@ -247,7 +247,7 @@ func testMultiTool(t *testing.T, pair builderPair) {
t.Run("multi tool streaming", func(t *testing.T) {
r := newRecorder(t)
- languageModel, err := pair.builder(r)
+ languageModel, err := pair.builder(t, r)
require.NoError(t, err, "failed to build language model")
agent := fantasy.NewAgent(
@@ -273,7 +273,7 @@ func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T
t.Run("thinking", func(t *testing.T) {
r := newRecorder(t)
- languageModel, err := pair.builder(r)
+ languageModel, err := pair.builder(t, r)
require.NoError(t, err, "failed to build language model")
type WeatherInput struct {
@@ -310,7 +310,7 @@ func testThinking(t *testing.T, pairs []builderPair, thinkChecks func(*testing.T
t.Run("thinking-streaming", func(t *testing.T) {
r := newRecorder(t)
- languageModel, err := pair.builder(r)
+ languageModel, err := pair.builder(t, r)
require.NoError(t, err, "failed to build language model")
type WeatherInput struct {
@@ -70,7 +70,7 @@ func testGoogleThinking(t *testing.T, result *fantasy.AgentResult) {
}
func geminiBuilder(model string) builderFunc {
- return func(r *recorder.Recorder) (fantasy.LanguageModel, error) {
+ return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
provider, err := google.New(
google.WithGeminiAPIKey(cmp.Or(os.Getenv("FANTASY_GEMINI_API_KEY"), "(missing)")),
google.WithHTTPClient(&http.Client{Transport: r}),
@@ -78,12 +78,12 @@ func geminiBuilder(model string) builderFunc {
if err != nil {
return nil, err
}
- return provider.LanguageModel(model)
+ return provider.LanguageModel(t.Context(), model)
}
}
func vertexBuilder(model string) builderFunc {
- return func(r *recorder.Recorder) (fantasy.LanguageModel, error) {
+ return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
provider, err := google.New(
google.WithVertex(os.Getenv("FANTASY_VERTEX_PROJECT"), os.Getenv("FANTASY_VERTEX_LOCATION")),
google.WithHTTPClient(&http.Client{Transport: r}),
@@ -92,6 +92,6 @@ func vertexBuilder(model string) builderFunc {
if err != nil {
return nil, err
}
- return provider.LanguageModel(model)
+ return provider.LanguageModel(t.Context(), model)
}
}
@@ -15,7 +15,7 @@ import (
)
func anthropicImageBuilder(model string) builderFunc {
- return func(r *recorder.Recorder) (fantasy.LanguageModel, error) {
+ return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
provider, err := anthropic.New(
anthropic.WithAPIKey(cmp.Or(os.Getenv("FANTASY_ANTHROPIC_API_KEY"), "(missing)")),
anthropic.WithHTTPClient(&http.Client{Transport: r}),
@@ -23,12 +23,12 @@ func anthropicImageBuilder(model string) builderFunc {
if err != nil {
return nil, err
}
- return provider.LanguageModel(model)
+ return provider.LanguageModel(t.Context(), model)
}
}
func openAIImageBuilder(model string) builderFunc {
- return func(r *recorder.Recorder) (fantasy.LanguageModel, error) {
+ return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
provider, err := openai.New(
openai.WithAPIKey(cmp.Or(os.Getenv("FANTASY_OPENAI_API_KEY"), "(missing)")),
openai.WithHTTPClient(&http.Client{Transport: r}),
@@ -36,12 +36,12 @@ func openAIImageBuilder(model string) builderFunc {
if err != nil {
return nil, err
}
- return provider.LanguageModel(model)
+ return provider.LanguageModel(t.Context(), model)
}
}
func geminiImageBuilder(model string) builderFunc {
- return func(r *recorder.Recorder) (fantasy.LanguageModel, error) {
+ return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
provider, err := google.New(
google.WithGeminiAPIKey(cmp.Or(os.Getenv("FANTASY_GEMINI_API_KEY"), "(missing)")),
google.WithHTTPClient(&http.Client{Transport: r}),
@@ -49,7 +49,7 @@ func geminiImageBuilder(model string) builderFunc {
if err != nil {
return nil, err
}
- return provider.LanguageModel(model)
+ return provider.LanguageModel(t.Context(), model)
}
}
@@ -79,7 +79,7 @@ func TestImageUploadAgent(t *testing.T) {
t.Run(pair.name, func(t *testing.T) {
r := newRecorder(t)
- lm, err := pair.builder(r)
+ lm, err := pair.builder(t, r)
require.NoError(t, err)
agent := fantasy.NewAgent(
@@ -126,7 +126,7 @@ func TestImageUploadAgentStreaming(t *testing.T) {
t.Run(pair.name+"-stream", func(t *testing.T) {
r := newRecorder(t)
- lm, err := pair.builder(r)
+ lm, err := pair.builder(t, r)
require.NoError(t, err)
agent := fantasy.NewAgent(
@@ -20,7 +20,7 @@ func TestOpenAIResponsesCommon(t *testing.T) {
}
func openAIReasoningBuilder(model string) builderFunc {
- return func(r *recorder.Recorder) (fantasy.LanguageModel, error) {
+ return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
provider, err := openai.New(
openai.WithAPIKey(os.Getenv("FANTASY_OPENAI_API_KEY")),
openai.WithHTTPClient(&http.Client{Transport: r}),
@@ -29,7 +29,7 @@ func openAIReasoningBuilder(model string) builderFunc {
if err != nil {
return nil, err
}
- return provider.LanguageModel(model)
+ return provider.LanguageModel(t.Context(), model)
}
}
@@ -26,7 +26,7 @@ func TestOpenAICommon(t *testing.T) {
}
func openAIBuilder(model string) builderFunc {
- return func(r *recorder.Recorder) (fantasy.LanguageModel, error) {
+ return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
provider, err := openai.New(
openai.WithAPIKey(os.Getenv("FANTASY_OPENAI_API_KEY")),
openai.WithHTTPClient(&http.Client{Transport: r}),
@@ -34,6 +34,6 @@ func openAIBuilder(model string) builderFunc {
if err != nil {
return nil, err
}
- return provider.LanguageModel(model)
+ return provider.LanguageModel(t.Context(), model)
}
}
@@ -48,7 +48,7 @@ func testOpenAICompatThinking(t *testing.T, result *fantasy.AgentResult) {
require.Greater(t, reasoningContentCount, 0, "expected reasoning content, got none")
}
-func builderXAIGrokCodeFast(r *recorder.Recorder) (fantasy.LanguageModel, error) {
+func builderXAIGrokCodeFast(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
provider, err := openaicompat.New(
openaicompat.WithBaseURL("https://api.x.ai/v1"),
openaicompat.WithAPIKey(os.Getenv("FANTASY_XAI_API_KEY")),
@@ -57,10 +57,10 @@ func builderXAIGrokCodeFast(r *recorder.Recorder) (fantasy.LanguageModel, error)
if err != nil {
return nil, err
}
- return provider.LanguageModel("grok-code-fast-1")
+ return provider.LanguageModel(t.Context(), "grok-code-fast-1")
}
-func builderXAIGrok4Fast(r *recorder.Recorder) (fantasy.LanguageModel, error) {
+func builderXAIGrok4Fast(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
provider, err := openaicompat.New(
openaicompat.WithBaseURL("https://api.x.ai/v1"),
openaicompat.WithAPIKey(os.Getenv("FANTASY_XAI_API_KEY")),
@@ -69,10 +69,10 @@ func builderXAIGrok4Fast(r *recorder.Recorder) (fantasy.LanguageModel, error) {
if err != nil {
return nil, err
}
- return provider.LanguageModel("grok-4-fast")
+ return provider.LanguageModel(t.Context(), "grok-4-fast")
}
-func builderXAIGrok3Mini(r *recorder.Recorder) (fantasy.LanguageModel, error) {
+func builderXAIGrok3Mini(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
provider, err := openaicompat.New(
openaicompat.WithBaseURL("https://api.x.ai/v1"),
openaicompat.WithAPIKey(os.Getenv("FANTASY_XAI_API_KEY")),
@@ -81,10 +81,10 @@ func builderXAIGrok3Mini(r *recorder.Recorder) (fantasy.LanguageModel, error) {
if err != nil {
return nil, err
}
- return provider.LanguageModel("grok-3-mini")
+ return provider.LanguageModel(t.Context(), "grok-3-mini")
}
-func builderZAIGLM45(r *recorder.Recorder) (fantasy.LanguageModel, error) {
+func builderZAIGLM45(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
provider, err := openaicompat.New(
openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
openaicompat.WithAPIKey(os.Getenv("FANTASY_ZAI_API_KEY")),
@@ -93,10 +93,10 @@ func builderZAIGLM45(r *recorder.Recorder) (fantasy.LanguageModel, error) {
if err != nil {
return nil, err
}
- return provider.LanguageModel("glm-4.5")
+ return provider.LanguageModel(t.Context(), "glm-4.5")
}
-func builderGroq(r *recorder.Recorder) (fantasy.LanguageModel, error) {
+func builderGroq(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
provider, err := openaicompat.New(
openaicompat.WithBaseURL("https://api.groq.com/openai/v1"),
openaicompat.WithAPIKey(os.Getenv("FANTASY_GROQ_API_KEY")),
@@ -105,10 +105,10 @@ func builderGroq(r *recorder.Recorder) (fantasy.LanguageModel, error) {
if err != nil {
return nil, err
}
- return provider.LanguageModel("moonshotai/kimi-k2-instruct-0905")
+ return provider.LanguageModel(t.Context(), "moonshotai/kimi-k2-instruct-0905")
}
-func builderHuggingFace(r *recorder.Recorder) (fantasy.LanguageModel, error) {
+func builderHuggingFace(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
provider, err := openaicompat.New(
openaicompat.WithBaseURL("https://router.huggingface.co/v1"),
openaicompat.WithAPIKey(os.Getenv("FANTASY_HUGGINGFACE_API_KEY")),
@@ -117,5 +117,5 @@ func builderHuggingFace(r *recorder.Recorder) (fantasy.LanguageModel, error) {
if err != nil {
return nil, err
}
- return provider.LanguageModel("Qwen/Qwen3-Coder-480B-A35B-Instruct:cerebras")
+ return provider.LanguageModel(t.Context(), "Qwen/Qwen3-Coder-480B-A35B-Instruct:cerebras")
}
@@ -115,7 +115,7 @@ func testOpenrouterThinking(t *testing.T, result *fantasy.AgentResult) {
}
func openrouterBuilder(model string) builderFunc {
- return func(r *recorder.Recorder) (fantasy.LanguageModel, error) {
+ return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
provider, err := openrouter.New(
openrouter.WithAPIKey(os.Getenv("FANTASY_OPENROUTER_API_KEY")),
openrouter.WithHTTPClient(&http.Client{Transport: r}),
@@ -123,6 +123,6 @@ func openrouterBuilder(model string) builderFunc {
if err != nil {
return nil, err
}
- return provider.LanguageModel(model)
+ return provider.LanguageModel(t.Context(), model)
}
}