From 08d368d3e67ed95b4a017e57a4a937de09206f84 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Fri, 27 Feb 2026 13:31:38 -0500 Subject: [PATCH] feat(useragent): allow UA to be set on the agent level --- agent.go | 24 +++ agent_useragent_test.go | 137 ++++++++++++++++ model.go | 5 + object.go | 5 + providers/anthropic/anthropic.go | 4 +- providers/anthropic/call_useragent.go | 14 ++ providers/google/call_useragent.go | 53 ++++++ providers/google/google.go | 6 +- providers/internal/httpheaders/httpheaders.go | 18 +++ .../internal/httpheaders/httpheaders_test.go | 25 +++ providers/openai/call_useragent.go | 25 +++ providers/openai/language_model.go | 9 +- providers/openai/openai_test.go | 153 ++++++++++++++++++ providers/openai/responses_language_model.go | 8 +- 14 files changed, 474 insertions(+), 12 deletions(-) create mode 100644 agent_useragent_test.go create mode 100644 providers/anthropic/call_useragent.go create mode 100644 providers/google/call_useragent.go create mode 100644 providers/openai/call_useragent.go diff --git a/agent.go b/agent.go index 426beff16c7ec4c7d7ebb6b1e51266d2477dea68..32c8849ddfabf771f108e9f0e250f49ce9ea7433 100644 --- a/agent.go +++ b/agent.go @@ -138,6 +138,8 @@ type agentSettings struct { presencePenalty *float64 frequencyPenalty *float64 headers map[string]string + userAgent string + modelSegment string providerOptions ProviderOptions // TODO: add support for provider tools @@ -448,6 +450,8 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err FrequencyPenalty: opts.FrequencyPenalty, Tools: preparedTools, ToolChoice: &stepToolChoice, + UserAgent: a.settings.userAgent, + ModelSegment: a.settings.modelSegment, ProviderOptions: opts.ProviderOptions, }) }) @@ -829,6 +833,8 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, FrequencyPenalty: call.FrequencyPenalty, Tools: preparedTools, ToolChoice: &stepToolChoice, + UserAgent: a.settings.userAgent, + ModelSegment: a.settings.modelSegment, ProviderOptions: call.ProviderOptions, } @@ -1418,6 +1424,24 @@ func WithHeaders(headers map[string]string) AgentOption { } } +// WithUserAgent sets the User-Agent header for the agent. This overrides any +// provider-level User-Agent setting. +func WithUserAgent(ua string) AgentOption { + return func(s *agentSettings) { + s.userAgent = ua + } +} + +// WithModelSegment sets the model segment appended to the default User-Agent +// header. The default UA becomes "Fantasy/ ()". An empty +// string clears any previously set segment. This is overridden by WithUserAgent +// at either the agent or provider level. +func WithModelSegment(segment string) AgentOption { + return func(s *agentSettings) { + s.modelSegment = segment + } +} + // WithProviderOptions sets the provider options for the agent. func WithProviderOptions(providerOptions ProviderOptions) AgentOption { return func(s *agentSettings) { diff --git a/agent_useragent_test.go b/agent_useragent_test.go new file mode 100644 index 0000000000000000000000000000000000000000..76e25b10328e3488dcd3273c8272157a06f1abe3 --- /dev/null +++ b/agent_useragent_test.go @@ -0,0 +1,137 @@ +package fantasy + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAgent_WithUserAgent_PropagatesOnGenerate(t *testing.T) { + t.Parallel() + + var capturedCall Call + model := &mockLanguageModel{ + generateFunc: func(_ context.Context, call Call) (*Response, error) { + capturedCall = call + return &Response{ + Content: []Content{TextContent{Text: "ok"}}, + FinishReason: FinishReasonStop, + }, nil + }, + } + + agent := NewAgent(model, WithUserAgent("MyApp/2.0")) + _, err := agent.Generate(context.Background(), AgentCall{Prompt: "hi"}) + require.NoError(t, err) + assert.Equal(t, "MyApp/2.0", capturedCall.UserAgent) + assert.Empty(t, capturedCall.ModelSegment) +} + +func TestAgent_WithModelSegment_PropagatesOnGenerate(t *testing.T) { + t.Parallel() + + var capturedCall Call + model := &mockLanguageModel{ + generateFunc: func(_ context.Context, call Call) (*Response, error) { + capturedCall = call + return &Response{ + Content: []Content{TextContent{Text: "ok"}}, + FinishReason: FinishReasonStop, + }, nil + }, + } + + agent := NewAgent(model, WithModelSegment("Claude 4.6 Opus")) + _, err := agent.Generate(context.Background(), AgentCall{Prompt: "hi"}) + require.NoError(t, err) + assert.Empty(t, capturedCall.UserAgent) + assert.Equal(t, "Claude 4.6 Opus", capturedCall.ModelSegment) +} + +func TestAgent_WithUserAgent_PropagatesOnStream(t *testing.T) { + t.Parallel() + + var capturedCall Call + model := &mockLanguageModel{ + streamFunc: func(_ context.Context, call Call) (StreamResponse, error) { + capturedCall = call + return func(yield func(StreamPart) bool) { + yield(StreamPart{ + Type: StreamPartTypeFinish, + FinishReason: FinishReasonStop, + }) + }, nil + }, + } + + agent := NewAgent(model, WithUserAgent("StreamApp/1.0")) + _, err := agent.Stream(context.Background(), AgentStreamCall{Prompt: "hi"}) + require.NoError(t, err) + assert.Equal(t, "StreamApp/1.0", capturedCall.UserAgent) +} + +func TestAgent_WithModelSegment_PropagatesOnStream(t *testing.T) { + t.Parallel() + + var capturedCall Call + model := &mockLanguageModel{ + streamFunc: func(_ context.Context, call Call) (StreamResponse, error) { + capturedCall = call + return func(yield func(StreamPart) bool) { + yield(StreamPart{ + Type: StreamPartTypeFinish, + FinishReason: FinishReasonStop, + }) + }, nil + }, + } + + agent := NewAgent(model, WithModelSegment("GPT-5")) + _, err := agent.Stream(context.Background(), AgentStreamCall{Prompt: "hi"}) + require.NoError(t, err) + assert.Equal(t, "GPT-5", capturedCall.ModelSegment) +} + +func TestAgent_NoUA_OmitsCallLevelFields(t *testing.T) { + t.Parallel() + + var capturedCall Call + model := &mockLanguageModel{ + generateFunc: func(_ context.Context, call Call) (*Response, error) { + capturedCall = call + return &Response{ + Content: []Content{TextContent{Text: "ok"}}, + FinishReason: FinishReasonStop, + }, nil + }, + } + + agent := NewAgent(model) + _, err := agent.Generate(context.Background(), AgentCall{Prompt: "hi"}) + require.NoError(t, err) + assert.Empty(t, capturedCall.UserAgent) + assert.Empty(t, capturedCall.ModelSegment) +} + +func TestAgent_WithUserAgentAndModelSegment_BothPropagate(t *testing.T) { + t.Parallel() + + var capturedCall Call + model := &mockLanguageModel{ + generateFunc: func(_ context.Context, call Call) (*Response, error) { + capturedCall = call + return &Response{ + Content: []Content{TextContent{Text: "ok"}}, + FinishReason: FinishReasonStop, + }, nil + }, + } + + agent := NewAgent(model, WithUserAgent("App/1.0"), WithModelSegment("Claude 4.6")) + _, err := agent.Generate(context.Background(), AgentCall{Prompt: "hi"}) + require.NoError(t, err) + assert.Equal(t, "App/1.0", capturedCall.UserAgent) + assert.Equal(t, "Claude 4.6", capturedCall.ModelSegment) +} diff --git a/model.go b/model.go index 4d1c3e31bf25ce1e3cef19853c4eb5553f008f19..92dabf3377db4f9311184a0ed55f0abc58eaf026 100644 --- a/model.go +++ b/model.go @@ -218,6 +218,11 @@ type Call struct { Tools []Tool `json:"tools"` ToolChoice *ToolChoice `json:"tool_choice"` + // UserAgent overrides the provider-level User-Agent header for this call. + UserAgent string `json:"-"` + // ModelSegment overrides the provider-level model segment for this call. + ModelSegment string `json:"-"` + // for provider specific options, the key is the provider id ProviderOptions ProviderOptions `json:"provider_options"` } diff --git a/object.go b/object.go index 4b8aed3692eda0c08c95d9726843665ed04dea5b..3e434e3818b9d774b41d2b020f2886270ba7bda7 100644 --- a/object.go +++ b/object.go @@ -41,6 +41,11 @@ type ObjectCall struct { PresencePenalty *float64 FrequencyPenalty *float64 + // UserAgent overrides the provider-level User-Agent header for this call. + UserAgent string `json:"-"` + // ModelSegment overrides the provider-level model segment for this call. + ModelSegment string `json:"-"` + ProviderOptions ProviderOptions RepairText schema.ObjectRepairFunc diff --git a/providers/anthropic/anthropic.go b/providers/anthropic/anthropic.go index 6e574e463d3ae1193dcc752e2fa4cb3afac6873a..dfa3dd3792a09c80f543c12d59805865b74be525 100644 --- a/providers/anthropic/anthropic.go +++ b/providers/anthropic/anthropic.go @@ -793,7 +793,7 @@ func (a languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantas if err != nil { return nil, err } - response, err := a.client.Messages.New(ctx, *params) + response, err := a.client.Messages.New(ctx, *params, callUARequestOptions(call)...) if err != nil { return nil, toProviderErr(err) } @@ -871,7 +871,7 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S return nil, err } - stream := a.client.Messages.NewStreaming(ctx, *params) + stream := a.client.Messages.NewStreaming(ctx, *params, callUARequestOptions(call)...) acc := anthropic.Message{} return func(yield func(fantasy.StreamPart) bool) { if len(warnings) > 0 { diff --git a/providers/anthropic/call_useragent.go b/providers/anthropic/call_useragent.go new file mode 100644 index 0000000000000000000000000000000000000000..d1fc97780f848a27ad772fc1b3c9a24d2828df8a --- /dev/null +++ b/providers/anthropic/call_useragent.go @@ -0,0 +1,14 @@ +package anthropic + +import ( + "charm.land/fantasy" + "charm.land/fantasy/providers/internal/httpheaders" + "github.com/charmbracelet/anthropic-sdk-go/option" +) + +func callUARequestOptions(call fantasy.Call) []option.RequestOption { + if ua, ok := httpheaders.CallUserAgent(fantasy.Version, call.UserAgent, call.ModelSegment); ok { + return []option.RequestOption{option.WithHeader("User-Agent", ua)} + } + return nil +} diff --git a/providers/google/call_useragent.go b/providers/google/call_useragent.go new file mode 100644 index 0000000000000000000000000000000000000000..57d5d3927f591412d50f04ac6f9f23bff4ade514 --- /dev/null +++ b/providers/google/call_useragent.go @@ -0,0 +1,53 @@ +package google + +import ( + "context" + "net/http" + + "charm.land/fantasy" + "charm.land/fantasy/providers/internal/httpheaders" +) + +type callUAKey struct{} + +func withCallUA(ctx context.Context, call fantasy.Call) context.Context { + if ua, ok := httpheaders.CallUserAgent(fantasy.Version, call.UserAgent, call.ModelSegment); ok { + return context.WithValue(ctx, callUAKey{}, ua) + } + return ctx +} + +func withObjectCallUA(ctx context.Context, call fantasy.ObjectCall) context.Context { + if ua, ok := httpheaders.CallUserAgent(fantasy.Version, call.UserAgent, call.ModelSegment); ok { + return context.WithValue(ctx, callUAKey{}, ua) + } + return ctx +} + +func wrapHTTPClient(c *http.Client) *http.Client { + if c == nil { + c = http.DefaultClient + } + transport := c.Transport + if transport == nil { + transport = http.DefaultTransport + } + return &http.Client{ + Transport: &uaTransport{base: transport}, + CheckRedirect: c.CheckRedirect, + Jar: c.Jar, + Timeout: c.Timeout, + } +} + +type uaTransport struct { + base http.RoundTripper +} + +func (t *uaTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if ua, ok := req.Context().Value(callUAKey{}).(string); ok && ua != "" { + req = req.Clone(req.Context()) + req.Header.Set("User-Agent", ua) + } + return t.base.RoundTrip(req) +} diff --git a/providers/google/google.go b/providers/google/google.go index 1830a7cc2c0efaa0580d0e134be55d6fd4161a5f..2b9dd6b08df49c28a0d9dd76b15ae471d08d4b4c 100644 --- a/providers/google/google.go +++ b/providers/google/google.go @@ -193,7 +193,7 @@ func (a *provider) LanguageModel(ctx context.Context, modelID string) (fantasy.L } cc := &genai.ClientConfig{ - HTTPClient: a.options.client, + HTTPClient: wrapHTTPClient(a.options.client), Backend: a.options.backend, APIKey: a.options.apiKey, Project: a.options.project, @@ -558,6 +558,7 @@ func toGooglePrompt(prompt fantasy.Prompt) (*genai.Content, []*genai.Content, [] // Generate implements fantasy.LanguageModel. func (g *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { + ctx = withCallUA(ctx, call) config, contents, warnings, err := g.prepareParams(call) if err != nil { return nil, err @@ -593,6 +594,7 @@ func (g *languageModel) Provider() string { // Stream implements fantasy.LanguageModel. func (g *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { + ctx = withCallUA(ctx, call) config, contents, warnings, err := g.prepareParams(call) if err != nil { return nil, err @@ -919,6 +921,7 @@ func (g *languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCal } func (g *languageModel) generateObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { + ctx = withObjectCallUA(ctx, call) // Convert our Schema to Google's JSON Schema format jsonSchemaMap := schema.ToMap(call.Schema) @@ -1001,6 +1004,7 @@ func (g *languageModel) generateObjectWithJSONMode(ctx context.Context, call fan } func (g *languageModel) streamObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) { + ctx = withObjectCallUA(ctx, call) // Convert our Schema to Google's JSON Schema format jsonSchemaMap := schema.ToMap(call.Schema) diff --git a/providers/internal/httpheaders/httpheaders.go b/providers/internal/httpheaders/httpheaders.go index 167cb0fbbf6cdfd402724a255c85650dded23c6d..263ff204613f2bebeff09356a4070db9a129ac0b 100644 --- a/providers/internal/httpheaders/httpheaders.go +++ b/providers/internal/httpheaders/httpheaders.go @@ -52,6 +52,24 @@ func ResolveHeaders(headers map[string]string, explicitUA, defaultUA string) map return out } +// CallUserAgent resolves the User-Agent for a single API call. It returns the +// resolved UA string and true if a per-call override should be applied, or +// empty string and false if the client-level UA should be used as-is. +// +// Precedence: +// 1. callUA (agent-level WithUserAgent) — highest +// 2. callSegment used to build default UA (agent-level WithModelSegment) +// 3. empty — use client-level UA (return false) +func CallUserAgent(version, callUA, callSegment string) (string, bool) { + if callUA != "" { + return callUA, true + } + if callSegment != "" { + return DefaultUserAgent(version, callSegment), true + } + return "", false +} + func sanitizeAgent(s string) string { s = strings.TrimSpace(s) var b strings.Builder diff --git a/providers/internal/httpheaders/httpheaders_test.go b/providers/internal/httpheaders/httpheaders_test.go index 771af687671aa519ccd832cdac3f1881b57d832f..b04cc4aae2c79190bd50b4192db25fa71cd22a9f 100644 --- a/providers/internal/httpheaders/httpheaders_test.go +++ b/providers/internal/httpheaders/httpheaders_test.go @@ -139,3 +139,28 @@ func TestResolveHeaders_DuplicateCaseInsensitiveKeys(t *testing.T) { _, hasLower := got["user-agent"] assert.False(t, hasLower, "all case-insensitive UA keys must be removed") } + +func TestCallUserAgent(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + callUA string + callSegment string + wantUA string + wantOK bool + }{ + {name: "no override", callUA: "", callSegment: "", wantUA: "", wantOK: false}, + {name: "explicit UA", callUA: "MyAgent/1.0", callSegment: "", wantUA: "MyAgent/1.0", wantOK: true}, + {name: "model segment only", callUA: "", callSegment: "Claude 4.6", wantUA: "Charm Fantasy/0.11.0 (Claude 4.6)", wantOK: true}, + {name: "explicit UA wins over segment", callUA: "MyAgent/1.0", callSegment: "Claude 4.6", wantUA: "MyAgent/1.0", wantOK: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ua, ok := CallUserAgent("0.11.0", tt.callUA, tt.callSegment) + assert.Equal(t, tt.wantOK, ok) + assert.Equal(t, tt.wantUA, ua) + }) + } +} diff --git a/providers/openai/call_useragent.go b/providers/openai/call_useragent.go new file mode 100644 index 0000000000000000000000000000000000000000..007ded93b0822cf3d7efeae688b7db8aa3138234 --- /dev/null +++ b/providers/openai/call_useragent.go @@ -0,0 +1,25 @@ +package openai + +import ( + "charm.land/fantasy" + "charm.land/fantasy/providers/internal/httpheaders" + "github.com/openai/openai-go/v2/option" +) + +// callUARequestOptions returns per-request options that override the +// client-level User-Agent header when the Call carries agent-level UA settings. +func callUARequestOptions(call fantasy.Call) []option.RequestOption { + if ua, ok := httpheaders.CallUserAgent(fantasy.Version, call.UserAgent, call.ModelSegment); ok { + return []option.RequestOption{option.WithHeader("User-Agent", ua)} + } + return nil +} + +// objectCallUARequestOptions returns per-request options that override the +// client-level User-Agent header when the ObjectCall carries agent-level UA settings. +func objectCallUARequestOptions(call fantasy.ObjectCall) []option.RequestOption { + if ua, ok := httpheaders.CallUserAgent(fantasy.Version, call.UserAgent, call.ModelSegment); ok { + return []option.RequestOption{option.WithHeader("User-Agent", ua)} + } + return nil +} diff --git a/providers/openai/language_model.go b/providers/openai/language_model.go index 9df357ac878adbe839f914d91acb0e950a1cf4e3..ae3d87649185aaef2e9eec38f3e71ef39bb07216 100644 --- a/providers/openai/language_model.go +++ b/providers/openai/language_model.go @@ -246,7 +246,7 @@ func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantas if err != nil { return nil, err } - response, err := o.client.Chat.Completions.New(ctx, *params) + response, err := o.client.Chat.Completions.New(ctx, *params, callUARequestOptions(call)...) if err != nil { return nil, toProviderErr(err) } @@ -314,7 +314,7 @@ func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S IncludeUsage: openai.Bool(true), } - stream := o.client.Chat.Completions.NewStreaming(ctx, *params) + stream := o.client.Chat.Completions.NewStreaming(ctx, *params, callUARequestOptions(call)...) isActiveText := false toolCalls := make(map[int64]streamToolCall) @@ -733,11 +733,10 @@ func (o languageModel) generateObjectWithJSONMode(ctx context.Context, call fant }, } - response, err := o.client.Chat.Completions.New(ctx, *params) + response, err := o.client.Chat.Completions.New(ctx, *params, objectCallUARequestOptions(call)...) if err != nil { return nil, toProviderErr(err) } - if len(response.Choices) == 0 { usage, _ := o.usageFunc(*response) return nil, &fantasy.NoObjectGeneratedError{ @@ -818,7 +817,7 @@ func (o languageModel) streamObjectWithJSONMode(ctx context.Context, call fantas IncludeUsage: openai.Bool(true), } - stream := o.client.Chat.Completions.NewStreaming(ctx, *params) + stream := o.client.Chat.Completions.NewStreaming(ctx, *params, objectCallUARequestOptions(call)...) return func(yield func(fantasy.ObjectStreamPart) bool) { if len(warnings) > 0 { diff --git a/providers/openai/openai_test.go b/providers/openai/openai_test.go index 90966d0673292b2f299cf16f060041704a069ca7..67d863bea22f28ca359b3df52db0f1db8e2e8e63 100644 --- a/providers/openai/openai_test.go +++ b/providers/openai/openai_test.go @@ -3396,4 +3396,157 @@ func TestUserAgent(t *testing.T) { require.Len(t, server.calls, 1) assert.Equal(t, "Charm Fantasy/"+fantasy.Version, server.calls[0].headers["User-Agent"]) }) + + t.Run("Call.UserAgent overrides provider UA", func(t *testing.T) { + t.Parallel() + + server := newMockServer() + defer server.close() + server.prepareJSONResponse(map[string]any{}) + + p, err := New( + WithAPIKey("k"), + WithBaseURL(server.server.URL), + WithUserAgent("provider-ua"), + ) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "gpt-4") + _, _ = model.Generate(t.Context(), fantasy.Call{ + Prompt: testPrompt, + UserAgent: "agent-ua", + }) + + require.Len(t, server.calls, 1) + assert.Equal(t, "agent-ua", server.calls[0].headers["User-Agent"]) + }) + + t.Run("Call.ModelSegment overrides provider default", func(t *testing.T) { + t.Parallel() + + server := newMockServer() + defer server.close() + server.prepareJSONResponse(map[string]any{}) + + p, err := New( + WithAPIKey("k"), + WithBaseURL(server.server.URL), + ) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "gpt-4") + _, _ = model.Generate(t.Context(), fantasy.Call{ + Prompt: testPrompt, + ModelSegment: "GPT-5", + }) + + require.Len(t, server.calls, 1) + assert.Equal(t, "Charm Fantasy/"+fantasy.Version+" (GPT-5)", server.calls[0].headers["User-Agent"]) + }) + + t.Run("Call.UserAgent overrides provider WithHeaders UA", func(t *testing.T) { + t.Parallel() + + server := newMockServer() + defer server.close() + server.prepareJSONResponse(map[string]any{}) + + p, err := New( + WithAPIKey("k"), + WithBaseURL(server.server.URL), + WithHeaders(map[string]string{"User-Agent": "header-ua"}), + ) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "gpt-4") + _, _ = model.Generate(t.Context(), fantasy.Call{ + Prompt: testPrompt, + UserAgent: "call-level-ua", + }) + + require.Len(t, server.calls, 1) + assert.Equal(t, "call-level-ua", server.calls[0].headers["User-Agent"]) + }) + + t.Run("no Call UA falls through to provider UA", func(t *testing.T) { + t.Parallel() + + server := newMockServer() + defer server.close() + server.prepareJSONResponse(map[string]any{}) + + p, err := New( + WithAPIKey("k"), + WithBaseURL(server.server.URL), + WithUserAgent("provider-ua"), + ) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "gpt-4") + _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: testPrompt}) + + require.Len(t, server.calls, 1) + assert.Equal(t, "provider-ua", server.calls[0].headers["User-Agent"]) + }) + + t.Run("agent WithUserAgent overrides provider UA end-to-end", func(t *testing.T) { + t.Parallel() + + server := newMockServer() + defer server.close() + server.prepareJSONResponse(map[string]any{}) + + p, err := New( + WithAPIKey("k"), + WithBaseURL(server.server.URL), + WithUserAgent("provider-ua"), + ) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "gpt-4") + + agent := fantasy.NewAgent(model, fantasy.WithUserAgent("agent-ua")) + _, _ = agent.Generate(t.Context(), fantasy.AgentCall{Prompt: "hi"}) + + require.Len(t, server.calls, 1) + assert.Equal(t, "agent-ua", server.calls[0].headers["User-Agent"]) + }) + + t.Run("agent WithModelSegment overrides provider default end-to-end", func(t *testing.T) { + t.Parallel() + + server := newMockServer() + defer server.close() + server.prepareJSONResponse(map[string]any{}) + + p, err := New( + WithAPIKey("k"), + WithBaseURL(server.server.URL), + ) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "gpt-4") + + agent := fantasy.NewAgent(model, fantasy.WithModelSegment("Claude 4.6")) + _, _ = agent.Generate(t.Context(), fantasy.AgentCall{Prompt: "hi"}) + + require.Len(t, server.calls, 1) + assert.Equal(t, "Charm Fantasy/"+fantasy.Version+" (Claude 4.6)", server.calls[0].headers["User-Agent"]) + }) + + t.Run("agent without UA falls through to provider UA end-to-end", func(t *testing.T) { + t.Parallel() + + server := newMockServer() + defer server.close() + server.prepareJSONResponse(map[string]any{}) + + p, err := New( + WithAPIKey("k"), + WithBaseURL(server.server.URL), + WithUserAgent("provider-ua"), + ) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "gpt-4") + + agent := fantasy.NewAgent(model) + _, _ = agent.Generate(t.Context(), fantasy.AgentCall{Prompt: "hi"}) + + require.Len(t, server.calls, 1) + assert.Equal(t, "provider-ua", server.calls[0].headers["User-Agent"]) + }) } diff --git a/providers/openai/responses_language_model.go b/providers/openai/responses_language_model.go index 39ee8e427e881cb88cfb206c2e75849cf1f83ee8..090117de07ded7dd7273d8e6a8655dbfdb39ea8f 100644 --- a/providers/openai/responses_language_model.go +++ b/providers/openai/responses_language_model.go @@ -668,7 +668,7 @@ func toResponsesTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice, opti func (o responsesLanguageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { params, warnings := o.prepareParams(call) - response, err := o.client.Responses.New(ctx, *params) + response, err := o.client.Responses.New(ctx, *params, callUARequestOptions(call)...) if err != nil { return nil, toProviderErr(err) } @@ -806,7 +806,7 @@ func mapResponsesFinishReason(reason string, hasFunctionCall bool) fantasy.Finis func (o responsesLanguageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { params, warnings := o.prepareParams(call) - stream := o.client.Responses.NewStreaming(ctx, *params) + stream := o.client.Responses.NewStreaming(ctx, *params, callUARequestOptions(call)...) finishReason := fantasy.FinishReasonUnknown var usage fantasy.Usage @@ -1106,7 +1106,7 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context, } // Make request - response, err := o.client.Responses.New(ctx, *params) + response, err := o.client.Responses.New(ctx, *params, objectCallUARequestOptions(call)...) if err != nil { return nil, toProviderErr(err) } @@ -1216,7 +1216,7 @@ func (o responsesLanguageModel) streamObjectWithJSONMode(ctx context.Context, ca Format: responses.ResponseFormatTextConfigParamOfJSONSchema(schemaName, jsonSchemaMap), } - stream := o.client.Responses.NewStreaming(ctx, *params) + stream := o.client.Responses.NewStreaming(ctx, *params, objectCallUARequestOptions(call)...) return func(yield func(fantasy.ObjectStreamPart) bool) { if len(warnings) > 0 {