Minor fixes (#8)

Kujtim Hoxha created

* chore: remove unused headers

* feat: expose provider options

* fix: use mapstruct tag

* fix: thinking delta

* fix: agent prompt

* fix: fix reasoning metadata

* fix: fix reasoning content

* fix: test

Change summary

ai/agent.go                   | 42 ++++++++++++++++--------------------
ai/agent_test.go              | 36 -------------------------------
ai/model.go                   | 19 +++++++--------
anthropic/anthropic.go        | 20 ++++++++--------
anthropic/provider_options.go | 22 +++++++++---------
cspell.json                   |  9 +++++++
openai/openai.go              | 10 ++++----
openai/openai_test.go         | 30 +++++++++++++-------------
openai/provider_options.go    | 42 ++++++++++++++++++------------------
9 files changed, 99 insertions(+), 131 deletions(-)

Detailed changes

ai/agent.go

@@ -144,7 +144,6 @@ type AgentCall struct {
 	PresencePenalty  *float64 `json:"presence_penalty"`
 	FrequencyPenalty *float64 `json:"frequency_penalty"`
 	ActiveTools      []string `json:"active_tools"`
-	Headers          map[string]string
 	ProviderOptions  ProviderOptions
 	OnRetry          OnRetryCallback
 	MaxRetries       *int
@@ -336,10 +335,6 @@ func (a *agent) prepareCall(call AgentCall) AgentCall {
 		maps.Copy(headers, a.settings.headers)
 	}
 
-	if call.Headers != nil {
-		maps.Copy(headers, call.Headers)
-	}
-	call.Headers = headers
 	return call
 }
 
@@ -420,7 +415,6 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
 				FrequencyPenalty: opts.FrequencyPenalty,
 				Tools:            preparedTools,
 				ToolChoice:       &stepToolChoice,
-				Headers:          opts.Headers,
 				ProviderOptions:  opts.ProviderOptions,
 			})
 		})
@@ -747,7 +741,6 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
 		PresencePenalty:  opts.PresencePenalty,
 		FrequencyPenalty: opts.FrequencyPenalty,
 		ActiveTools:      opts.ActiveTools,
-		Headers:          opts.Headers,
 		ProviderOptions:  opts.ProviderOptions,
 		MaxRetries:       opts.MaxRetries,
 		StopWhen:         opts.StopWhen,
@@ -838,7 +831,6 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
 			FrequencyPenalty: call.FrequencyPenalty,
 			Tools:            preparedTools,
 			ToolChoice:       &stepToolChoice,
-			Headers:          call.Headers,
 			ProviderOptions:  call.ProviderOptions,
 		}
 
@@ -994,9 +986,8 @@ func (a *agent) createPrompt(system, prompt string, messages []Message, files ..
 	if system != "" {
 		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
 	}
-
-	preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
 	preparedPrompt = append(preparedPrompt, messages...)
+	preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
 	return preparedPrompt, nil
 }
 
@@ -1077,6 +1068,11 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 
 	activeToolCalls := make(map[string]*ToolCallContent)
 	activeTextContent := make(map[string]string)
+	type reasoningContent struct {
+		content string
+		options ProviderMetadata
+	}
+	activeReasoningContent := make(map[string]reasoningContent)
 
 	// Process stream parts
 	for part := range stream {
@@ -1134,7 +1130,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 			}
 
 		case StreamPartTypeReasoningStart:
-			activeTextContent[part.ID] = ""
+			activeReasoningContent[part.ID] = reasoningContent{content: ""}
 			if opts.OnReasoningStart != nil {
 				err := opts.OnReasoningStart(part.ID)
 				if err != nil {
@@ -1143,8 +1139,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 			}
 
 		case StreamPartTypeReasoningDelta:
-			if _, exists := activeTextContent[part.ID]; exists {
-				activeTextContent[part.ID] += part.Delta
+			if active, exists := activeReasoningContent[part.ID]; exists {
+				active.content += part.Delta
+				active.options = part.ProviderMetadata
+				activeReasoningContent[part.ID] = active
 			}
 			if opts.OnReasoningDelta != nil {
 				err := opts.OnReasoningDelta(part.ID, part.Delta)
@@ -1154,21 +1152,19 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 			}
 
 		case StreamPartTypeReasoningEnd:
-			if text, exists := activeTextContent[part.ID]; exists {
-				stepContent = append(stepContent, ReasoningContent{
-					Text:             text,
-					ProviderMetadata: part.ProviderMetadata,
-				})
+			if active, exists := activeReasoningContent[part.ID]; exists {
+				content := ReasoningContent{
+					Text:             active.content,
+					ProviderMetadata: active.options,
+				}
+				stepContent = append(stepContent, content)
 				if opts.OnReasoningEnd != nil {
-					err := opts.OnReasoningEnd(part.ID, ReasoningContent{
-						Text:             text,
-						ProviderMetadata: part.ProviderMetadata,
-					})
+					err := opts.OnReasoningEnd(part.ID, content)
 					if err != nil {
 						return StepResult{}, false, err
 					}
 				}
-				delete(activeTextContent, part.ID)
+				delete(activeReasoningContent, part.ID)
 			}
 
 		case StreamPartTypeToolInputStart:
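
Reasoning deltas are now accumulated per part ID, and the provider metadata attached to each delta is carried along, so the final ReasoningContent reflects the last delta's metadata rather than whatever happens to be on the end event. A minimal sketch of observing this through the stream callbacks — assuming the OnReasoning* hooks are fields on ai.AgentStreamCall (the diff only shows them on the stream-processing options), with agent and ctx in scope:

	opts := ai.AgentStreamCall{
		Prompt: "solve this step by step",
		OnReasoningDelta: func(id, delta string) error {
			fmt.Print(delta) // reasoning text streams in as per-ID deltas
			return nil
		},
		OnReasoningEnd: func(id string, rc ai.ReasoningContent) error {
			// rc.Text is the accumulated reasoning for this part ID;
			// rc.ProviderMetadata is now taken from the last delta received.
			return nil
		},
	}
	result, err := agent.Stream(ctx, opts)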

ai/agent_test.go

@@ -563,42 +563,6 @@ func TestAgent_Generate_WithSystemPrompt(t *testing.T) {
 	require.NotNil(t, result)
 }
 
-// Test options.headers
-func TestAgent_Generate_OptionsHeaders(t *testing.T) {
-	t.Parallel()
-
-	model := &mockLanguageModel{
-		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
-			// Verify headers are passed
-			require.Equal(t, map[string]string{
-				"custom-request-header": "request-header-value",
-			}, call.Headers)
-
-			return &Response{
-				Content: []Content{
-					TextContent{Text: "Hello, world!"},
-				},
-				Usage: Usage{
-					InputTokens:  3,
-					OutputTokens: 10,
-					TotalTokens:  13,
-				},
-				FinishReason: FinishReasonStop,
-			}, nil
-		},
-	}
-
-	agent := NewAgent(model)
-	result, err := agent.Generate(context.Background(), AgentCall{
-		Prompt:  "test-input",
-		Headers: map[string]string{"custom-request-header": "request-header-value"},
-	})
-
-	require.NoError(t, err)
-	require.NotNil(t, result)
-	require.Equal(t, "Hello, world!", result.Response.Content.Text())
-}
-
 // Test options.activeTools filtering
 func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
 	t.Parallel()

ai/model.go

@@ -176,16 +176,15 @@ func SpecificToolChoice(name string) ToolChoice {
 }
 
 type Call struct {
-	Prompt           Prompt            `json:"prompt"`
-	MaxOutputTokens  *int64            `json:"max_output_tokens"`
-	Temperature      *float64          `json:"temperature"`
-	TopP             *float64          `json:"top_p"`
-	TopK             *int64            `json:"top_k"`
-	PresencePenalty  *float64          `json:"presence_penalty"`
-	FrequencyPenalty *float64          `json:"frequency_penalty"`
-	Tools            []Tool            `json:"tools"`
-	ToolChoice       *ToolChoice       `json:"tool_choice"`
-	Headers          map[string]string `json:"headers"`
+	Prompt           Prompt      `json:"prompt"`
+	MaxOutputTokens  *int64      `json:"max_output_tokens"`
+	Temperature      *float64    `json:"temperature"`
+	TopP             *float64    `json:"top_p"`
+	TopK             *int64      `json:"top_k"`
+	PresencePenalty  *float64    `json:"presence_penalty"`
+	FrequencyPenalty *float64    `json:"frequency_penalty"`
+	Tools            []Tool      `json:"tools"`
+	ToolChoice       *ToolChoice `json:"tool_choice"`
 
 	// for provider specific options, the key is the provider id
 	ProviderOptions ProviderOptions `json:"provider_options"`
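
Per-call headers are gone from Call; anything provider-specific now travels in ProviderOptions, keyed by provider ID. A hedged sketch of a call after this change (field values and the prebuilt prompt are illustrative):

	temperature := 0.2
	call := ai.Call{
		Prompt:      prompt, // assumed to be an already-built ai.Prompt
		Temperature: &temperature,
		ProviderOptions: ai.ProviderOptions{
			"openai": map[string]any{
				"user": "example-user-id", // snake_case keys, per the mapstructure tags
			},
		},
	}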

anthropic/anthropic.go

@@ -118,7 +118,7 @@ func (a languageModel) Provider() string {
 
 func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, []ai.CallWarning, error) {
 	params := &anthropic.MessageNewParams{}
-	providerOptions := &providerOptions{}
+	providerOptions := &ProviderOptions{}
 	if v, ok := call.ProviderOptions["anthropic"]; ok {
 		err := ai.ParseOptions(v, providerOptions)
 		if err != nil {
@@ -217,21 +217,21 @@ func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams,
 	return params, warnings, nil
 }
 
-func getCacheControl(providerOptions ai.ProviderOptions) *cacheControlProviderOptions {
+func getCacheControl(providerOptions ai.ProviderOptions) *CacheControlProviderOptions {
 	if anthropicOptions, ok := providerOptions["anthropic"]; ok {
 		if cacheControl, ok := anthropicOptions["cache_control"]; ok {
 			if cc, ok := cacheControl.(map[string]any); ok {
-				cacheControlOption := &cacheControlProviderOptions{}
+				cacheControlOption := &CacheControlProviderOptions{}
 				err := ai.ParseOptions(cc, cacheControlOption)
-				if err != nil {
+				if err == nil {
 					return cacheControlOption
 				}
 			}
 		} else if cacheControl, ok := anthropicOptions["cacheControl"]; ok {
 			if cc, ok := cacheControl.(map[string]any); ok {
-				cacheControlOption := &cacheControlProviderOptions{}
+				cacheControlOption := &CacheControlProviderOptions{}
 				err := ai.ParseOptions(cc, cacheControlOption)
-				if err != nil {
+				if err == nil {
 					return cacheControlOption
 				}
 			}
@@ -240,11 +240,11 @@ func getCacheControl(providerOptions ai.ProviderOptions) *cacheControlProviderOp
 	return nil
 }
 
-func getReasoningMetadata(providerOptions ai.ProviderOptions) *reasoningMetadata {
+func getReasoningMetadata(providerOptions ai.ProviderOptions) *ReasoningMetadata {
 	if anthropicOptions, ok := providerOptions["anthropic"]; ok {
-		reasoningMetadata := &reasoningMetadata{}
+		reasoningMetadata := &ReasoningMetadata{}
 		err := ai.ParseOptions(anthropicOptions, reasoningMetadata)
-		if err != nil {
+		if err == nil {
 			return reasoningMetadata
 		}
 	}
@@ -837,7 +837,7 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo
 					if !yield(ai.StreamPart{
 						Type:  ai.StreamPartTypeReasoningDelta,
 						ID:    fmt.Sprintf("%d", chunk.Index),
-						Delta: chunk.Delta.Text,
+						Delta: chunk.Delta.Thinking,
 					}) {
 						return
 					}
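
Two fixes here: the error checks in getCacheControl and getReasoningMetadata were inverted (they returned the options only when parsing *failed*) and now return them on success, and streaming reasoning deltas read chunk.Delta.Thinking instead of chunk.Delta.Text. For context, a sketch of supplying cache control through provider options so getCacheControl can pick it up — per the code above, both the snake_case and camelCase keys are accepted, and "ephemeral" is Anthropic's documented cache type:

	opts := ai.ProviderOptions{
		"anthropic": map[string]any{
			"cache_control": map[string]any{
				"type": "ephemeral",
			},
		},
	}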

anthropic/provider_options.go

@@ -1,20 +1,20 @@
 package anthropic
 
-type providerOptions struct {
-	SendReasoning          *bool                   `json:"send_reasoning,omitempty"`
-	Thinking               *thinkingProviderOption `json:"thinking,omitempty"`
-	DisableParallelToolUse *bool                   `json:"disable_parallel_tool_use,omitempty"`
+type ProviderOptions struct {
+	SendReasoning          *bool                   `mapstructure:"send_reasoning,omitempty"`
+	Thinking               *ThinkingProviderOption `mapstructure:"thinking,omitempty"`
+	DisableParallelToolUse *bool                   `mapstructure:"disable_parallel_tool_use,omitempty"`
 }
 
-type thinkingProviderOption struct {
-	BudgetTokens int64 `json:"budget_tokens"`
+type ThinkingProviderOption struct {
+	BudgetTokens int64 `mapstructure:"budget_tokens"`
 }
 
-type reasoningMetadata struct {
-	Signature    string `json:"signature"`
-	RedactedData string `json:"redacted_data"`
+type ReasoningMetadata struct {
+	Signature    string `mapstructure:"signature"`
+	RedactedData string `mapstructure:"redacted_data"`
 }
 
-type cacheControlProviderOptions struct {
-	Type string `json:"type"`
+type CacheControlProviderOptions struct {
+	Type string `mapstructure:"type"`
 }
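
These types are now exported and decoded via mapstructure tags, so callers pass plain maps with snake_case keys that ai.ParseOptions maps onto the struct. A minimal sketch (the budget value is illustrative):

	opts := ai.ProviderOptions{
		"anthropic": map[string]any{
			"send_reasoning": true,
			"thinking": map[string]any{
				"budget_tokens": int64(2048), // illustrative budget
			},
		},
	}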

cspell.json

@@ -0,0 +1,9 @@
+{
+  "language": "en",
+  "version": "0.2",
+  "flagWords": [],
+  "words": [
+    "mapstructure",
+    "mapstructure"
+  ]
+}

openai/openai.go

@@ -145,7 +145,7 @@ func (o languageModel) Provider() string {
 func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
 	params := &openai.ChatCompletionNewParams{}
 	messages, warnings := toPrompt(call.Prompt)
-	providerOptions := &providerOptions{}
+	providerOptions := &ProviderOptions{}
 	if v, ok := call.ProviderOptions["openai"]; ok {
 		err := ai.ParseOptions(v, providerOptions)
 		if err != nil {
@@ -239,13 +239,13 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar
 
 	if providerOptions.ReasoningEffort != nil {
 		switch *providerOptions.ReasoningEffort {
-		case reasoningEffortMinimal:
+		case ReasoningEffortMinimal:
 			params.ReasoningEffort = shared.ReasoningEffortMinimal
-		case reasoningEffortLow:
+		case ReasoningEffortLow:
 			params.ReasoningEffort = shared.ReasoningEffortLow
-		case reasoningEffortMedium:
+		case ReasoningEffortMedium:
 			params.ReasoningEffort = shared.ReasoningEffortMedium
-		case reasoningEffortHigh:
+		case ReasoningEffortHigh:
 			params.ReasoningEffort = shared.ReasoningEffortHigh
 		default:
 			return nil, nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
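
The exported constants map one-to-one onto the shared SDK values, and anything else is rejected with an error. A sketch of selecting an effort from the caller's side, using the constants from openai/provider_options.go below (here `openai` is this repo's package, not the SDK):

	opts := ai.ProviderOptions{
		"openai": map[string]any{
			// accepted values: "minimal", "low", "medium", "high"
			"reasoning_effort": string(openai.ReasoningEffortLow),
		},
	}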

openai/openai_test.go

@@ -1059,11 +1059,11 @@ func TestDoGenerate(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"logitBias": map[string]int64{
+					"logit_bias": map[string]int64{
 						"50256": -100,
 					},
-					"parallelToolCalls": false,
-					"user":              "test-user-id",
+					"parallel_tool_calls": false,
+					"user":                "test-user-id",
 				},
 			},
 		})
@@ -1103,7 +1103,7 @@ func TestDoGenerate(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"reasoningEffort": "low",
+					"reasoning_effort": "low",
 				},
 			},
 		})
@@ -1143,7 +1143,7 @@ func TestDoGenerate(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"textVerbosity": "low",
+					"text_verbosity": "low",
 				},
 			},
 		})
@@ -1536,7 +1536,7 @@ func TestDoGenerate(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"maxCompletionTokens": 255,
+					"max_completion_tokens": 255,
 				},
 			},
 		})
@@ -1706,7 +1706,7 @@ func TestDoGenerate(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"promptCacheKey": "test-cache-key-123",
+					"prompt_cache_key": "test-cache-key-123",
 				},
 			},
 		})
@@ -1726,7 +1726,7 @@ func TestDoGenerate(t *testing.T) {
 		require.Equal(t, "Hello", message["content"])
 	})
 
-	t.Run("should send safetyIdentifier extension value", func(t *testing.T) {
+	t.Run("should send safety_identifier extension value", func(t *testing.T) {
 		t.Parallel()
 
 		server := newMockServer()
@@ -1746,7 +1746,7 @@ func TestDoGenerate(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"safetyIdentifier": "test-safety-identifier-123",
+					"safety_identifier": "test-safety-identifier-123",
 				},
 			},
 		})
@@ -1818,7 +1818,7 @@ func TestDoGenerate(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"serviceTier": "flex",
+					"service_tier": "flex",
 				},
 			},
 		})
@@ -1856,7 +1856,7 @@ func TestDoGenerate(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"serviceTier": "flex",
+					"service_tier": "flex",
 				},
 			},
 		})
@@ -1891,7 +1891,7 @@ func TestDoGenerate(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"serviceTier": "priority",
+					"service_tier": "priority",
 				},
 			},
 		})
@@ -1929,7 +1929,7 @@ func TestDoGenerate(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"serviceTier": "priority",
+					"service_tier": "priority",
 				},
 			},
 		})
@@ -2693,7 +2693,7 @@ func TestDoStream(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"serviceTier": "flex",
+					"service_tier": "flex",
 				},
 			},
 		})
@@ -2737,7 +2737,7 @@ func TestDoStream(t *testing.T) {
 			Prompt: testPrompt,
 			ProviderOptions: ai.ProviderOptions{
 				"openai": map[string]any{
-					"serviceTier": "priority",
+					"service_tier": "priority",
 				},
 			},
 		})

openai/provider_options.go

@@ -1,28 +1,28 @@
 package openai
 
-type reasoningEffort string
+type ReasoningEffort string
 
 const (
-	reasoningEffortMinimal reasoningEffort = "minimal"
-	reasoningEffortLow     reasoningEffort = "low"
-	reasoningEffortMedium  reasoningEffort = "medium"
-	reasoningEffortHigh    reasoningEffort = "high"
+	ReasoningEffortMinimal ReasoningEffort = "minimal"
+	ReasoningEffortLow     ReasoningEffort = "low"
+	ReasoningEffortMedium  ReasoningEffort = "medium"
+	ReasoningEffortHigh    ReasoningEffort = "high"
 )
 
-type providerOptions struct {
-	LogitBias           map[string]int64 `json:"logit_bias"`
-	LogProbs            *bool            `json:"log_probes"`
-	TopLogProbs         *int64           `json:"top_log_probs"`
-	ParallelToolCalls   *bool            `json:"parallel_tool_calls"`
-	User                *string          `json:"user"`
-	ReasoningEffort     *reasoningEffort `json:"reasoning_effort"`
-	MaxCompletionTokens *int64           `json:"max_completion_tokens"`
-	TextVerbosity       *string          `json:"text_verbosity"`
-	Prediction          map[string]any   `json:"prediction"`
-	Store               *bool            `json:"store"`
-	Metadata            map[string]any   `json:"metadata"`
-	PromptCacheKey      *string          `json:"prompt_cache_key"`
-	SafetyIdentifier    *string          `json:"safety_identifier"`
-	ServiceTier         *string          `json:"service_tier"`
-	StructuredOutputs   *bool            `json:"structured_outputs"`
+type ProviderOptions struct {
+	LogitBias           map[string]int64 `mapstructure:"logit_bias"`
+	LogProbs            *bool            `mapstructure:"log_probes"`
+	TopLogProbs         *int64           `mapstructure:"top_log_probs"`
+	ParallelToolCalls   *bool            `mapstructure:"parallel_tool_calls"`
+	User                *string          `mapstructure:"user"`
+	ReasoningEffort     *ReasoningEffort `mapstructure:"reasoning_effort"`
+	MaxCompletionTokens *int64           `mapstructure:"max_completion_tokens"`
+	TextVerbosity       *string          `mapstructure:"text_verbosity"`
+	Prediction          map[string]any   `mapstructure:"prediction"`
+	Store               *bool            `mapstructure:"store"`
+	Metadata            map[string]any   `mapstructure:"metadata"`
+	PromptCacheKey      *string          `mapstructure:"prompt_cache_key"`
+	SafetyIdentifier    *string          `mapstructure:"safety_identifier"`
+	ServiceTier         *string          `mapstructure:"service_tier"`
+	StructuredOutputs   *bool            `mapstructure:"structured_outputs"`
 }