Detailed changes
@@ -144,7 +144,6 @@ type AgentCall struct {
PresencePenalty *float64 `json:"presence_penalty"`
FrequencyPenalty *float64 `json:"frequency_penalty"`
ActiveTools []string `json:"active_tools"`
- Headers map[string]string
ProviderOptions ProviderOptions
OnRetry OnRetryCallback
MaxRetries *int
@@ -336,10 +335,6 @@ func (a *agent) prepareCall(call AgentCall) AgentCall {
maps.Copy(headers, a.settings.headers)
}
- if call.Headers != nil {
- maps.Copy(headers, call.Headers)
- }
- call.Headers = headers
return call
}
@@ -420,7 +415,6 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
FrequencyPenalty: opts.FrequencyPenalty,
Tools: preparedTools,
ToolChoice: &stepToolChoice,
- Headers: opts.Headers,
ProviderOptions: opts.ProviderOptions,
})
})
@@ -747,7 +741,6 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
PresencePenalty: opts.PresencePenalty,
FrequencyPenalty: opts.FrequencyPenalty,
ActiveTools: opts.ActiveTools,
- Headers: opts.Headers,
ProviderOptions: opts.ProviderOptions,
MaxRetries: opts.MaxRetries,
StopWhen: opts.StopWhen,
@@ -838,7 +831,6 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
FrequencyPenalty: call.FrequencyPenalty,
Tools: preparedTools,
ToolChoice: &stepToolChoice,
- Headers: call.Headers,
ProviderOptions: call.ProviderOptions,
}
@@ -994,9 +986,8 @@ func (a *agent) createPrompt(system, prompt string, messages []Message, files ..
if system != "" {
preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
}
-
- preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
preparedPrompt = append(preparedPrompt, messages...)
+ preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
return preparedPrompt, nil
}
@@ -1077,6 +1068,11 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
activeToolCalls := make(map[string]*ToolCallContent)
activeTextContent := make(map[string]string)
+ type reasoningContent struct {
+ content string
+ options ProviderMetadata
+ }
+ activeReasoningContent := make(map[string]reasoningContent)
// Process stream parts
for part := range stream {
@@ -1134,7 +1130,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
}
case StreamPartTypeReasoningStart:
- activeTextContent[part.ID] = ""
+ activeReasoningContent[part.ID] = reasoningContent{content: ""}
if opts.OnReasoningStart != nil {
err := opts.OnReasoningStart(part.ID)
if err != nil {
@@ -1143,8 +1139,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
}
case StreamPartTypeReasoningDelta:
- if _, exists := activeTextContent[part.ID]; exists {
- activeTextContent[part.ID] += part.Delta
+ if active, exists := activeReasoningContent[part.ID]; exists {
+ active.content += part.Delta
+ active.options = part.ProviderMetadata
+ activeReasoningContent[part.ID] = active
}
if opts.OnReasoningDelta != nil {
err := opts.OnReasoningDelta(part.ID, part.Delta)
@@ -1154,21 +1152,19 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
}
case StreamPartTypeReasoningEnd:
- if text, exists := activeTextContent[part.ID]; exists {
- stepContent = append(stepContent, ReasoningContent{
- Text: text,
- ProviderMetadata: part.ProviderMetadata,
- })
+ if active, exists := activeReasoningContent[part.ID]; exists {
+ content := ReasoningContent{
+ Text: active.content,
+ ProviderMetadata: active.options,
+ }
+ stepContent = append(stepContent, content)
if opts.OnReasoningEnd != nil {
- err := opts.OnReasoningEnd(part.ID, ReasoningContent{
- Text: text,
- ProviderMetadata: part.ProviderMetadata,
- })
+ err := opts.OnReasoningEnd(part.ID, content)
if err != nil {
return StepResult{}, false, err
}
}
- delete(activeTextContent, part.ID)
+ delete(activeReasoningContent, part.ID)
}
case StreamPartTypeToolInputStart:
@@ -563,42 +563,6 @@ func TestAgent_Generate_WithSystemPrompt(t *testing.T) {
require.NotNil(t, result)
}
-// Test options.headers
-func TestAgent_Generate_OptionsHeaders(t *testing.T) {
- t.Parallel()
-
- model := &mockLanguageModel{
- generateFunc: func(ctx context.Context, call Call) (*Response, error) {
- // Verify headers are passed
- require.Equal(t, map[string]string{
- "custom-request-header": "request-header-value",
- }, call.Headers)
-
- return &Response{
- Content: []Content{
- TextContent{Text: "Hello, world!"},
- },
- Usage: Usage{
- InputTokens: 3,
- OutputTokens: 10,
- TotalTokens: 13,
- },
- FinishReason: FinishReasonStop,
- }, nil
- },
- }
-
- agent := NewAgent(model)
- result, err := agent.Generate(context.Background(), AgentCall{
- Prompt: "test-input",
- Headers: map[string]string{"custom-request-header": "request-header-value"},
- })
-
- require.NoError(t, err)
- require.NotNil(t, result)
- require.Equal(t, "Hello, world!", result.Response.Content.Text())
-}
-
// Test options.activeTools filtering
func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
t.Parallel()
@@ -176,16 +176,15 @@ func SpecificToolChoice(name string) ToolChoice {
}
type Call struct {
- Prompt Prompt `json:"prompt"`
- MaxOutputTokens *int64 `json:"max_output_tokens"`
- Temperature *float64 `json:"temperature"`
- TopP *float64 `json:"top_p"`
- TopK *int64 `json:"top_k"`
- PresencePenalty *float64 `json:"presence_penalty"`
- FrequencyPenalty *float64 `json:"frequency_penalty"`
- Tools []Tool `json:"tools"`
- ToolChoice *ToolChoice `json:"tool_choice"`
- Headers map[string]string `json:"headers"`
+ Prompt Prompt `json:"prompt"`
+ MaxOutputTokens *int64 `json:"max_output_tokens"`
+ Temperature *float64 `json:"temperature"`
+ TopP *float64 `json:"top_p"`
+ TopK *int64 `json:"top_k"`
+ PresencePenalty *float64 `json:"presence_penalty"`
+ FrequencyPenalty *float64 `json:"frequency_penalty"`
+ Tools []Tool `json:"tools"`
+ ToolChoice *ToolChoice `json:"tool_choice"`
// for provider specific options, the key is the provider id
ProviderOptions ProviderOptions `json:"provider_options"`
@@ -118,7 +118,7 @@ func (a languageModel) Provider() string {
func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, []ai.CallWarning, error) {
params := &anthropic.MessageNewParams{}
- providerOptions := &providerOptions{}
+ providerOptions := &ProviderOptions{}
if v, ok := call.ProviderOptions["anthropic"]; ok {
err := ai.ParseOptions(v, providerOptions)
if err != nil {
@@ -217,21 +217,21 @@ func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams,
return params, warnings, nil
}
-func getCacheControl(providerOptions ai.ProviderOptions) *cacheControlProviderOptions {
+func getCacheControl(providerOptions ai.ProviderOptions) *CacheControlProviderOptions {
if anthropicOptions, ok := providerOptions["anthropic"]; ok {
if cacheControl, ok := anthropicOptions["cache_control"]; ok {
if cc, ok := cacheControl.(map[string]any); ok {
- cacheControlOption := &cacheControlProviderOptions{}
+ cacheControlOption := &CacheControlProviderOptions{}
err := ai.ParseOptions(cc, cacheControlOption)
- if err != nil {
+ if err == nil {
return cacheControlOption
}
}
} else if cacheControl, ok := anthropicOptions["cacheControl"]; ok {
if cc, ok := cacheControl.(map[string]any); ok {
- cacheControlOption := &cacheControlProviderOptions{}
+ cacheControlOption := &CacheControlProviderOptions{}
err := ai.ParseOptions(cc, cacheControlOption)
- if err != nil {
+ if err == nil {
return cacheControlOption
}
}
@@ -240,11 +240,11 @@ func getCacheControl(providerOptions ai.ProviderOptions) *cacheControlProviderOp
return nil
}
-func getReasoningMetadata(providerOptions ai.ProviderOptions) *reasoningMetadata {
+func getReasoningMetadata(providerOptions ai.ProviderOptions) *ReasoningMetadata {
if anthropicOptions, ok := providerOptions["anthropic"]; ok {
- reasoningMetadata := &reasoningMetadata{}
+ reasoningMetadata := &ReasoningMetadata{}
err := ai.ParseOptions(anthropicOptions, reasoningMetadata)
- if err != nil {
+ if err == nil {
return reasoningMetadata
}
}
@@ -837,7 +837,7 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo
if !yield(ai.StreamPart{
Type: ai.StreamPartTypeReasoningDelta,
ID: fmt.Sprintf("%d", chunk.Index),
- Delta: chunk.Delta.Text,
+ Delta: chunk.Delta.Thinking,
}) {
return
}
@@ -1,20 +1,20 @@
package anthropic
-type providerOptions struct {
- SendReasoning *bool `json:"send_reasoning,omitempty"`
- Thinking *thinkingProviderOption `json:"thinking,omitempty"`
- DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"`
+type ProviderOptions struct {
+ SendReasoning *bool `mapstructure:"send_reasoning,omitempty"`
+ Thinking *ThinkingProviderOption `mapstructure:"thinking,omitempty"`
+ DisableParallelToolUse *bool `mapstructure:"disable_parallel_tool_use,omitempty"`
}
-type thinkingProviderOption struct {
- BudgetTokens int64 `json:"budget_tokens"`
+type ThinkingProviderOption struct {
+ BudgetTokens int64 `mapstructure:"budget_tokens"`
}
-type reasoningMetadata struct {
- Signature string `json:"signature"`
- RedactedData string `json:"redacted_data"`
+type ReasoningMetadata struct {
+ Signature string `mapstructure:"signature"`
+ RedactedData string `mapstructure:"redacted_data"`
}
-type cacheControlProviderOptions struct {
- Type string `json:"type"`
+type CacheControlProviderOptions struct {
+ Type string `mapstructure:"type"`
}
@@ -0,0 +1,9 @@
+{
+ "language": "en",
+ "version": "0.2",
+ "flagWords": [],
+ "words": [
+ "mapstructure",
+ "mapstructure"
+ ]
+}
@@ -145,7 +145,7 @@ func (o languageModel) Provider() string {
func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
params := &openai.ChatCompletionNewParams{}
messages, warnings := toPrompt(call.Prompt)
- providerOptions := &providerOptions{}
+ providerOptions := &ProviderOptions{}
if v, ok := call.ProviderOptions["openai"]; ok {
err := ai.ParseOptions(v, providerOptions)
if err != nil {
@@ -239,13 +239,13 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar
if providerOptions.ReasoningEffort != nil {
switch *providerOptions.ReasoningEffort {
- case reasoningEffortMinimal:
+ case ReasoningEffortMinimal:
params.ReasoningEffort = shared.ReasoningEffortMinimal
- case reasoningEffortLow:
+ case ReasoningEffortLow:
params.ReasoningEffort = shared.ReasoningEffortLow
- case reasoningEffortMedium:
+ case ReasoningEffortMedium:
params.ReasoningEffort = shared.ReasoningEffortMedium
- case reasoningEffortHigh:
+ case ReasoningEffortHigh:
params.ReasoningEffort = shared.ReasoningEffortHigh
default:
return nil, nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
@@ -1059,11 +1059,11 @@ func TestDoGenerate(t *testing.T) {
Prompt: testPrompt,
ProviderOptions: ai.ProviderOptions{
"openai": map[string]any{
- "logitBias": map[string]int64{
+ "logit_bias": map[string]int64{
"50256": -100,
},
- "parallelToolCalls": false,
- "user": "test-user-id",
+ "parallel_tool_calls": false,
+ "user": "test-user-id",
},
},
})
@@ -1103,7 +1103,7 @@ func TestDoGenerate(t *testing.T) {
Prompt: testPrompt,
ProviderOptions: ai.ProviderOptions{
"openai": map[string]any{
- "reasoningEffort": "low",
+ "reasoning_effort": "low",
},
},
})
@@ -1143,7 +1143,7 @@ func TestDoGenerate(t *testing.T) {
Prompt: testPrompt,
ProviderOptions: ai.ProviderOptions{
"openai": map[string]any{
- "textVerbosity": "low",
+ "text_verbosity": "low",
},
},
})
@@ -1536,7 +1536,7 @@ func TestDoGenerate(t *testing.T) {
Prompt: testPrompt,
ProviderOptions: ai.ProviderOptions{
"openai": map[string]any{
- "maxCompletionTokens": 255,
+ "max_completion_tokens": 255,
},
},
})
@@ -1706,7 +1706,7 @@ func TestDoGenerate(t *testing.T) {
Prompt: testPrompt,
ProviderOptions: ai.ProviderOptions{
"openai": map[string]any{
- "promptCacheKey": "test-cache-key-123",
+ "prompt_cache_key": "test-cache-key-123",
},
},
})
@@ -1726,7 +1726,7 @@ func TestDoGenerate(t *testing.T) {
require.Equal(t, "Hello", message["content"])
})
- t.Run("should send safetyIdentifier extension value", func(t *testing.T) {
+ t.Run("should send safety_identifier extension value", func(t *testing.T) {
t.Parallel()
server := newMockServer()
@@ -1746,7 +1746,7 @@ func TestDoGenerate(t *testing.T) {
Prompt: testPrompt,
ProviderOptions: ai.ProviderOptions{
"openai": map[string]any{
- "safetyIdentifier": "test-safety-identifier-123",
+ "safety_identifier": "test-safety-identifier-123",
},
},
})
@@ -1818,7 +1818,7 @@ func TestDoGenerate(t *testing.T) {
Prompt: testPrompt,
ProviderOptions: ai.ProviderOptions{
"openai": map[string]any{
- "serviceTier": "flex",
+ "service_tier": "flex",
},
},
})
@@ -1856,7 +1856,7 @@ func TestDoGenerate(t *testing.T) {
Prompt: testPrompt,
ProviderOptions: ai.ProviderOptions{
"openai": map[string]any{
- "serviceTier": "flex",
+ "service_tier": "flex",
},
},
})
@@ -1891,7 +1891,7 @@ func TestDoGenerate(t *testing.T) {
Prompt: testPrompt,
ProviderOptions: ai.ProviderOptions{
"openai": map[string]any{
- "serviceTier": "priority",
+ "service_tier": "priority",
},
},
})
@@ -1929,7 +1929,7 @@ func TestDoGenerate(t *testing.T) {
Prompt: testPrompt,
ProviderOptions: ai.ProviderOptions{
"openai": map[string]any{
- "serviceTier": "priority",
+ "service_tier": "priority",
},
},
})
@@ -2693,7 +2693,7 @@ func TestDoStream(t *testing.T) {
Prompt: testPrompt,
ProviderOptions: ai.ProviderOptions{
"openai": map[string]any{
- "serviceTier": "flex",
+ "service_tier": "flex",
},
},
})
@@ -2737,7 +2737,7 @@ func TestDoStream(t *testing.T) {
Prompt: testPrompt,
ProviderOptions: ai.ProviderOptions{
"openai": map[string]any{
- "serviceTier": "priority",
+ "service_tier": "priority",
},
},
})
@@ -1,28 +1,28 @@
package openai
-type reasoningEffort string
+type ReasoningEffort string
const (
- reasoningEffortMinimal reasoningEffort = "minimal"
- reasoningEffortLow reasoningEffort = "low"
- reasoningEffortMedium reasoningEffort = "medium"
- reasoningEffortHigh reasoningEffort = "high"
+ ReasoningEffortMinimal ReasoningEffort = "minimal"
+ ReasoningEffortLow ReasoningEffort = "low"
+ ReasoningEffortMedium ReasoningEffort = "medium"
+ ReasoningEffortHigh ReasoningEffort = "high"
)
-type providerOptions struct {
- LogitBias map[string]int64 `json:"logit_bias"`
- LogProbs *bool `json:"log_probes"`
- TopLogProbs *int64 `json:"top_log_probs"`
- ParallelToolCalls *bool `json:"parallel_tool_calls"`
- User *string `json:"user"`
- ReasoningEffort *reasoningEffort `json:"reasoning_effort"`
- MaxCompletionTokens *int64 `json:"max_completion_tokens"`
- TextVerbosity *string `json:"text_verbosity"`
- Prediction map[string]any `json:"prediction"`
- Store *bool `json:"store"`
- Metadata map[string]any `json:"metadata"`
- PromptCacheKey *string `json:"prompt_cache_key"`
- SafetyIdentifier *string `json:"safety_identifier"`
- ServiceTier *string `json:"service_tier"`
- StructuredOutputs *bool `json:"structured_outputs"`
+type ProviderOptions struct {
+ LogitBias map[string]int64 `mapstructure:"logit_bias"`
+ LogProbs *bool `mapstructure:"log_probes"`
+ TopLogProbs *int64 `mapstructure:"top_log_probs"`
+ ParallelToolCalls *bool `mapstructure:"parallel_tool_calls"`
+ User *string `mapstructure:"user"`
+ ReasoningEffort *ReasoningEffort `mapstructure:"reasoning_effort"`
+ MaxCompletionTokens *int64 `mapstructure:"max_completion_tokens"`
+ TextVerbosity *string `mapstructure:"text_verbosity"`
+ Prediction map[string]any `mapstructure:"prediction"`
+ Store *bool `mapstructure:"store"`
+ Metadata map[string]any `mapstructure:"metadata"`
+ PromptCacheKey *string `mapstructure:"prompt_cache_key"`
+ SafetyIdentifier *string `mapstructure:"safety_identifier"`
+ ServiceTier *string `mapstructure:"service_tier"`
+ StructuredOutputs *bool `mapstructure:"structured_outputs"`
}