diff --git a/ai/agent.go b/ai/agent.go
index 4b75f6782893bffa0422f2e78267981788161275..16327e65cf9b24e39f0c5158f23719893d3b6611 100644
--- a/ai/agent.go
+++ b/ai/agent.go
@@ -144,7 +144,6 @@ type AgentCall struct {
 	PresencePenalty  *float64 `json:"presence_penalty"`
 	FrequencyPenalty *float64 `json:"frequency_penalty"`
 	ActiveTools      []string `json:"active_tools"`
-	Headers          map[string]string
 	ProviderOptions  ProviderOptions
 	OnRetry          OnRetryCallback
 	MaxRetries       *int
@@ -336,10 +335,6 @@ func (a *agent) prepareCall(call AgentCall) AgentCall {
 		maps.Copy(headers, a.settings.headers)
 	}
 
-	if call.Headers != nil {
-		maps.Copy(headers, call.Headers)
-	}
-	call.Headers = headers
 
 	return call
 }
@@ -420,7 +415,6 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
 			FrequencyPenalty: opts.FrequencyPenalty,
 			Tools:            preparedTools,
 			ToolChoice:       &stepToolChoice,
-			Headers:          opts.Headers,
 			ProviderOptions:  opts.ProviderOptions,
 		})
 	})
@@ -747,7 +741,6 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
 		PresencePenalty:  opts.PresencePenalty,
 		FrequencyPenalty: opts.FrequencyPenalty,
 		ActiveTools:      opts.ActiveTools,
-		Headers:          opts.Headers,
 		ProviderOptions:  opts.ProviderOptions,
 		MaxRetries:       opts.MaxRetries,
 		StopWhen:         opts.StopWhen,
@@ -838,7 +831,6 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
 			FrequencyPenalty: call.FrequencyPenalty,
 			Tools:            preparedTools,
 			ToolChoice:       &stepToolChoice,
-			Headers:          call.Headers,
 			ProviderOptions:  call.ProviderOptions,
 		}
 
@@ -994,9 +986,8 @@ func (a *agent) createPrompt(system, prompt string, messages []Message, files ..
 	if system != "" {
 		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
 	}
-
-	preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
 	preparedPrompt = append(preparedPrompt, messages...)
+	preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
 	return preparedPrompt, nil
 }
 
@@ -1077,6 +1068,11 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 
 	activeToolCalls := make(map[string]*ToolCallContent)
 	activeTextContent := make(map[string]string)
+	type reasoningContent struct {
+		content string
+		options ProviderMetadata
+	}
+	activeReasoningContent := make(map[string]reasoningContent)
 
 	// Process stream parts
 	for part := range stream {
@@ -1134,7 +1130,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 			}
 
 		case StreamPartTypeReasoningStart:
-			activeTextContent[part.ID] = ""
+			activeReasoningContent[part.ID] = reasoningContent{content: ""}
 			if opts.OnReasoningStart != nil {
 				err := opts.OnReasoningStart(part.ID)
 				if err != nil {
@@ -1143,8 +1139,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 			}
 
 		case StreamPartTypeReasoningDelta:
-			if _, exists := activeTextContent[part.ID]; exists {
-				activeTextContent[part.ID] += part.Delta
+			if active, exists := activeReasoningContent[part.ID]; exists {
+				active.content += part.Delta
+				active.options = part.ProviderMetadata
+				activeReasoningContent[part.ID] = active
 			}
 			if opts.OnReasoningDelta != nil {
 				err := opts.OnReasoningDelta(part.ID, part.Delta)
@@ -1154,21 +1152,19 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 			}
 
 		case StreamPartTypeReasoningEnd:
-			if text, exists := activeTextContent[part.ID]; exists {
-				stepContent = append(stepContent, ReasoningContent{
-					Text:             text,
-					ProviderMetadata: part.ProviderMetadata,
-				})
+			if active, exists := activeReasoningContent[part.ID]; exists {
+				content := ReasoningContent{
+					Text:             active.content,
+					ProviderMetadata: active.options,
+				}
+				stepContent = append(stepContent, content)
 				if opts.OnReasoningEnd != nil {
-					err := opts.OnReasoningEnd(part.ID, ReasoningContent{
-						Text:             text,
-						ProviderMetadata: part.ProviderMetadata,
-					})
+					err := opts.OnReasoningEnd(part.ID, content)
 					if err != nil {
 						return StepResult{}, false, err
 					}
 				}
-				delete(activeTextContent, part.ID)
+				delete(activeReasoningContent, part.ID)
 			}
 
 		case StreamPartTypeToolInputStart:
diff --git a/ai/agent_test.go b/ai/agent_test.go
index 73301a0892892f45d67a9f1ab9e0c865a34a9858..0339c675bc20a3845765d3bc784bed88f6976cce 100644
--- a/ai/agent_test.go
+++ b/ai/agent_test.go
@@ -563,42 +563,6 @@ func TestAgent_Generate_WithSystemPrompt(t *testing.T) {
 	require.NotNil(t, result)
 }
 
-// Test options.headers
-func TestAgent_Generate_OptionsHeaders(t *testing.T) {
-	t.Parallel()
-
-	model := &mockLanguageModel{
-		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
-			// Verify headers are passed
-			require.Equal(t, map[string]string{
-				"custom-request-header": "request-header-value",
-			}, call.Headers)
-
-			return &Response{
-				Content: []Content{
-					TextContent{Text: "Hello, world!"},
-				},
-				Usage: Usage{
-					InputTokens:  3,
-					OutputTokens: 10,
-					TotalTokens:  13,
-				},
-				FinishReason: FinishReasonStop,
-			}, nil
-		},
-	}
-
-	agent := NewAgent(model)
-	result, err := agent.Generate(context.Background(), AgentCall{
-		Prompt:  "test-input",
-		Headers: map[string]string{"custom-request-header": "request-header-value"},
-	})
-
-	require.NoError(t, err)
-	require.NotNil(t, result)
-	require.Equal(t, "Hello, world!", result.Response.Content.Text())
-}
-
 // Test options.activeTools filtering
 func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
 	t.Parallel()
diff --git a/ai/model.go b/ai/model.go
index 6e2f2415de5221e0a150a51c3f6e49d3d2e5dfa8..ac1875517b3c320863e1cf2dad3fb9564a0c1a5f 100644
--- a/ai/model.go
+++ b/ai/model.go
@@ -176,16 +176,15 @@ func SpecificToolChoice(name string) ToolChoice {
 }
 
 type Call struct {
-	Prompt           Prompt            `json:"prompt"`
-	MaxOutputTokens  *int64            `json:"max_output_tokens"`
-	Temperature      *float64          `json:"temperature"`
-	TopP             *float64          `json:"top_p"`
-	TopK             *int64            `json:"top_k"`
-	PresencePenalty  *float64          `json:"presence_penalty"`
-	FrequencyPenalty *float64          `json:"frequency_penalty"`
-	Tools            []Tool            `json:"tools"`
-	ToolChoice       *ToolChoice       `json:"tool_choice"`
-	Headers          map[string]string `json:"headers"`
+	Prompt           Prompt      `json:"prompt"`
+	MaxOutputTokens  *int64      `json:"max_output_tokens"`
+	Temperature      *float64    `json:"temperature"`
+	TopP             *float64    `json:"top_p"`
+	TopK             *int64      `json:"top_k"`
+	PresencePenalty  *float64    `json:"presence_penalty"`
+	FrequencyPenalty *float64    `json:"frequency_penalty"`
+	Tools            []Tool      `json:"tools"`
+	ToolChoice       *ToolChoice `json:"tool_choice"`
 
 	// for provider specific options, the key is the provider id
 	ProviderOptions ProviderOptions `json:"provider_options"`
diff --git a/anthropic/anthropic.go b/anthropic/anthropic.go
index aea53ab8d5cc9c8f5092bad2b9c51e6586524d8f..00a444ec7ae0113b873a15644f7943292725fa5c 100644
--- a/anthropic/anthropic.go
+++ b/anthropic/anthropic.go
@@ -118,7 +118,7 @@ func (a languageModel) Provider() string {
 
 func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, []ai.CallWarning, error) {
 	params := &anthropic.MessageNewParams{}
-	providerOptions := &providerOptions{}
+	providerOptions := &ProviderOptions{}
 	if v, ok := call.ProviderOptions["anthropic"]; ok {
 		err := ai.ParseOptions(v, providerOptions)
 		if err != nil {
@@ -217,21 +217,21 @@ func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams,
 	return params, warnings, nil
 }
 
-func getCacheControl(providerOptions ai.ProviderOptions) *cacheControlProviderOptions {
+func getCacheControl(providerOptions ai.ProviderOptions) *CacheControlProviderOptions {
 	if anthropicOptions, ok := providerOptions["anthropic"]; ok {
 		if cacheControl, ok := anthropicOptions["cache_control"]; ok {
 			if cc, ok := cacheControl.(map[string]any); ok {
-				cacheControlOption := &cacheControlProviderOptions{}
+				cacheControlOption := &CacheControlProviderOptions{}
 				err := ai.ParseOptions(cc, cacheControlOption)
-				if err != nil {
+				if err == nil {
 					return cacheControlOption
 				}
 			}
 		} else if cacheControl, ok := anthropicOptions["cacheControl"]; ok {
 			if cc, ok := cacheControl.(map[string]any); ok {
-				cacheControlOption := &cacheControlProviderOptions{}
+				cacheControlOption := &CacheControlProviderOptions{}
 				err := ai.ParseOptions(cc, cacheControlOption)
-				if err != nil {
+				if err == nil {
 					return cacheControlOption
 				}
 			}
@@ -240,11 +240,11 @@ func getCacheControl(providerOptions ai.ProviderOptions) *cacheControlProviderOp
 	return nil
 }
 
-func getReasoningMetadata(providerOptions ai.ProviderOptions) *reasoningMetadata {
+func getReasoningMetadata(providerOptions ai.ProviderOptions) *ReasoningMetadata {
 	if anthropicOptions, ok := providerOptions["anthropic"]; ok {
-		reasoningMetadata := &reasoningMetadata{}
+		reasoningMetadata := &ReasoningMetadata{}
 		err := ai.ParseOptions(anthropicOptions, reasoningMetadata)
-		if err != nil {
+		if err == nil {
 			return reasoningMetadata
 		}
 	}
@@ -837,7 +837,7 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo
 					if !yield(ai.StreamPart{
 						Type:  ai.StreamPartTypeReasoningDelta,
 						ID:    fmt.Sprintf("%d", chunk.Index),
-						Delta: chunk.Delta.Text,
+						Delta: chunk.Delta.Thinking,
 					}) {
 						return
 					}
diff --git a/anthropic/provider_options.go b/anthropic/provider_options.go
index b98c31c36721b6331e3ad5d5d231461443d0dd34..c22face3491e886382afa3de2e367d629a9bba3d 100644
--- a/anthropic/provider_options.go
+++ b/anthropic/provider_options.go
@@ -1,20 +1,20 @@
 package anthropic
 
-type providerOptions struct {
-	SendReasoning          *bool                   `json:"send_reasoning,omitempty"`
-	Thinking               *thinkingProviderOption `json:"thinking,omitempty"`
-	DisableParallelToolUse *bool                   `json:"disable_parallel_tool_use,omitempty"`
+type ProviderOptions struct {
+	SendReasoning          *bool                   `mapstructure:"send_reasoning,omitempty"`
+	Thinking               *ThinkingProviderOption `mapstructure:"thinking,omitempty"`
+	DisableParallelToolUse *bool                   `mapstructure:"disable_parallel_tool_use,omitempty"`
 }
 
-type thinkingProviderOption struct {
-	BudgetTokens int64 `json:"budget_tokens"`
+type ThinkingProviderOption struct {
+	BudgetTokens int64 `mapstructure:"budget_tokens"`
 }
 
-type reasoningMetadata struct {
-	Signature    string `json:"signature"`
-	RedactedData string `json:"redacted_data"`
+type ReasoningMetadata struct {
+	Signature    string `mapstructure:"signature"`
+	RedactedData string `mapstructure:"redacted_data"`
 }
 
-type cacheControlProviderOptions struct {
-	Type string `json:"type"`
+type CacheControlProviderOptions struct {
+	Type string `mapstructure:"type"`
 }
diff --git a/cspell.json b/cspell.json
new file mode 100644
index 0000000000000000000000000000000000000000..4436732c9d075240149890a296a289b16940ac79
--- /dev/null
+++ b/cspell.json
@@ -0,0 +1,9 @@
+{
+  "language": "en",
+  "version": "0.2",
+  "flagWords": [],
+  "words": [
+    "mapstructure",
+    "mapstructure"
+  ]
+}
diff --git a/openai/openai.go b/openai/openai.go
index 40dab7fe5873ef4424a985c5bf84d99634f3a23d..edf9cba4c19d832514488a7622637238027a0c23 100644
--- a/openai/openai.go
+++ b/openai/openai.go
@@ -145,7 +145,7 @@ func (o languageModel) Provider() string {
 func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
 	params := &openai.ChatCompletionNewParams{}
 	messages, warnings := toPrompt(call.Prompt)
-	providerOptions := &providerOptions{}
+	providerOptions := &ProviderOptions{}
 	if v, ok := call.ProviderOptions["openai"]; ok {
 		err := ai.ParseOptions(v, providerOptions)
 		if err != nil {
@@ -239,13 +239,13 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar
 
 	if providerOptions.ReasoningEffort != nil {
 		switch *providerOptions.ReasoningEffort {
-		case reasoningEffortMinimal:
+		case ReasoningEffortMinimal:
 			params.ReasoningEffort = shared.ReasoningEffortMinimal
-		case reasoningEffortLow:
+		case ReasoningEffortLow:
 			params.ReasoningEffort = shared.ReasoningEffortLow
-		case reasoningEffortMedium:
+		case ReasoningEffortMedium:
 			params.ReasoningEffort = shared.ReasoningEffortMedium
-		case reasoningEffortHigh:
+		case ReasoningEffortHigh:
 			params.ReasoningEffort = shared.ReasoningEffortHigh
 		default:
 			return nil, nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
diff --git a/openai/openai_test.go b/openai/openai_test.go
index 5c05141d45624eb7c75216e0f7c6b4afa518dfea..27fe8c4ee8f0979f218c33683af1f003fbb46b1b 100644
--- a/openai/openai_test.go
+++ b/openai/openai_test.go
@@ -1059,11 +1059,11 @@ func TestDoGenerate(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"logitBias": map[string]int64{
+					"logit_bias": map[string]int64{
 						"50256": -100,
 					},
-					"parallelToolCalls": false,
-					"user":              "test-user-id",
+					"parallel_tool_calls": false,
+					"user":                "test-user-id",
 				},
 			},
 		})
@@ -1103,7 +1103,7 @@ func TestDoGenerate(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"reasoningEffort": "low",
+					"reasoning_effort": "low",
 				},
 			},
 		})
@@ -1143,7 +1143,7 @@ func TestDoGenerate(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"textVerbosity": "low",
+					"text_verbosity": "low",
 				},
 			},
 		})
@@ -1536,7 +1536,7 @@ func TestDoGenerate(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"maxCompletionTokens": 255,
+					"max_completion_tokens": 255,
 				},
 			},
 		})
@@ -1706,7 +1706,7 @@ func TestDoGenerate(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"promptCacheKey": "test-cache-key-123",
+					"prompt_cache_key": "test-cache-key-123",
 				},
 			},
 		})
@@ -1726,7 +1726,7 @@ func TestDoGenerate(t *testing.T) {
 		require.Equal(t, "Hello", message["content"])
 	})
 
-	t.Run("should send safetyIdentifier extension value", func(t *testing.T) {
+	t.Run("should send safety_identifier extension value", func(t *testing.T) {
 		t.Parallel()
 
 		server := newMockServer()
@@ -1746,7 +1746,7 @@ func TestDoGenerate(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"safetyIdentifier": "test-safety-identifier-123",
+					"safety_identifier": "test-safety-identifier-123",
 				},
 			},
 		})
@@ -1818,7 +1818,7 @@ func TestDoGenerate(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"serviceTier": "flex",
+					"service_tier": "flex",
 				},
 			},
 		})
@@ -1856,7 +1856,7 @@ func TestDoGenerate(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"serviceTier": "flex",
+					"service_tier": "flex",
 				},
 			},
 		})
@@ -1891,7 +1891,7 @@ func TestDoGenerate(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"serviceTier": "priority",
+					"service_tier": "priority",
 				},
 			},
 		})
@@ -1929,7 +1929,7 @@ func TestDoGenerate(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"serviceTier": "priority",
+					"service_tier": "priority",
 				},
 			},
 		})
@@ -2693,7 +2693,7 @@ func TestDoStream(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"serviceTier": "flex",
+					"service_tier": "flex",
 				},
 			},
 		})
@@ -2737,7 +2737,7 @@ func TestDoStream(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"serviceTier": "priority",
+					"service_tier": "priority",
 				},
 			},
 		})
diff --git a/openai/provider_options.go b/openai/provider_options.go
index 5496563e1802c681cc1e37a9c52ea43024ae0647..b02e31072e975c01f624e849cc8d5f669c315234 100644
--- a/openai/provider_options.go
+++ b/openai/provider_options.go
@@ -1,28 +1,28 @@
 package openai
 
-type reasoningEffort string
+type ReasoningEffort string
 
 const (
-	reasoningEffortMinimal reasoningEffort = "minimal"
-	reasoningEffortLow     reasoningEffort = "low"
-	reasoningEffortMedium  reasoningEffort = "medium"
-	reasoningEffortHigh    reasoningEffort = "high"
+	ReasoningEffortMinimal ReasoningEffort = "minimal"
+	ReasoningEffortLow     ReasoningEffort = "low"
+	ReasoningEffortMedium  ReasoningEffort = "medium"
+	ReasoningEffortHigh    ReasoningEffort = "high"
 )
 
-type providerOptions struct {
-	LogitBias           map[string]int64 `json:"logit_bias"`
-	LogProbs            *bool            `json:"log_probes"`
-	TopLogProbs         *int64           `json:"top_log_probs"`
-	ParallelToolCalls   *bool            `json:"parallel_tool_calls"`
-	User                *string          `json:"user"`
-	ReasoningEffort     *reasoningEffort `json:"reasoning_effort"`
-	MaxCompletionTokens *int64           `json:"max_completion_tokens"`
-	TextVerbosity       *string          `json:"text_verbosity"`
-	Prediction          map[string]any   `json:"prediction"`
-	Store               *bool            `json:"store"`
-	Metadata            map[string]any   `json:"metadata"`
-	PromptCacheKey      *string          `json:"prompt_cache_key"`
-	SafetyIdentifier    *string          `json:"safety_identifier"`
-	ServiceTier         *string          `json:"service_tier"`
-	StructuredOutputs   *bool            `json:"structured_outputs"`
+type ProviderOptions struct {
+	LogitBias           map[string]int64 `mapstructure:"logit_bias"`
+	LogProbs            *bool            `mapstructure:"log_probes"`
+	TopLogProbs         *int64           `mapstructure:"top_log_probs"`
+	ParallelToolCalls   *bool            `mapstructure:"parallel_tool_calls"`
+	User                *string          `mapstructure:"user"`
+	ReasoningEffort     *ReasoningEffort `mapstructure:"reasoning_effort"`
+	MaxCompletionTokens *int64           `mapstructure:"max_completion_tokens"`
+	TextVerbosity       *string          `mapstructure:"text_verbosity"`
+	Prediction          map[string]any   `mapstructure:"prediction"`
+	Store               *bool            `mapstructure:"store"`
+	Metadata            map[string]any   `mapstructure:"metadata"`
+	PromptCacheKey      *string          `mapstructure:"prompt_cache_key"`
+	SafetyIdentifier    *string          `mapstructure:"safety_identifier"`
+	ServiceTier         *string          `mapstructure:"service_tier"`
+	StructuredOutputs   *bool            `mapstructure:"structured_outputs"`
 }