@@ -120,9 +120,10 @@ func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams,
params := &anthropic.MessageNewParams{}
providerOptions := &ProviderOptions{}
if v, ok := call.ProviderOptions["anthropic"]; ok {
- err := ai.ParseOptions(v, providerOptions)
- if err != nil {
- return nil, nil, err
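+ // Provider options, when present, are typed structs; assert the concrete type instead of re-parsing a map.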
+ providerOptions, ok = v.(*ProviderOptions)
+ if !ok {
+ return nil, nil, ai.NewInvalidArgumentError("providerOptions", "anthropic provider options should be *anthropic.ProviderOptions", nil)
}
}
sendReasoning := true
@@ -217,24 +217,11 @@ func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams,
return params, warnings, nil
}
-func getCacheControl(providerOptions ai.ProviderOptions) *CacheControlProviderOptions {
+func getCacheControl(providerOptions ai.ProviderOptions) *CacheControl {
if anthropicOptions, ok := providerOptions["anthropic"]; ok {
- if cacheControl, ok := anthropicOptions["cache_control"]; ok {
- if cc, ok := cacheControl.(map[string]any); ok {
- cacheControlOption := &CacheControlProviderOptions{}
- err := ai.ParseOptions(cc, cacheControlOption)
- if err == nil {
- return cacheControlOption
- }
- }
- } else if cacheControl, ok := anthropicOptions["cacheControl"]; ok {
- if cc, ok := cacheControl.(map[string]any); ok {
- cacheControlOption := &CacheControlProviderOptions{}
- err := ai.ParseOptions(cc, cacheControlOption)
- if err == nil {
- return cacheControlOption
- }
- }
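+ // The "anthropic" entry now carries a typed *CacheControl value.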
+ if options, ok := anthropicOptions.(*CacheControl); ok {
+ return options
}
}
return nil
@@ -242,10 +228,9 @@ func getCacheControl(providerOptions ai.ProviderOptions) *CacheControlProviderOptions {
func getReasoningMetadata(providerOptions ai.ProviderOptions) *ReasoningMetadata {
if anthropicOptions, ok := providerOptions["anthropic"]; ok {
- reasoningMetadata := &ReasoningMetadata{}
- err := ai.ParseOptions(anthropicOptions, reasoningMetadata)
- if err == nil {
- return reasoningMetadata
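+ // The "anthropic" entry now carries a typed *ReasoningMetadata value.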
+ if reasoning, ok := anthropicOptions.(*ReasoningMetadata); ok {
+ return reasoning
}
}
return nil
@@ -675,9 +659,9 @@ func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response
}
content = append(content, ai.ReasoningContent{
Text: reasoning.Thinking,
- ProviderMetadata: map[string]map[string]any{
- "anthropic": {
- "signature": reasoning.Signature,
+ ProviderMetadata: ai.ProviderMetadata{
+ "anthropic": &ReasoningMetadata{
+ Signature: reasoning.Signature,
},
},
})
@@ -688,9 +672,9 @@ func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response
}
content = append(content, ai.ReasoningContent{
Text: "",
- ProviderMetadata: map[string]map[string]any{
- "anthropic": {
- "redacted_data": reasoning.Data,
+ ProviderMetadata: ai.ProviderMetadata{
+ "anthropic": &ReasoningMetadata{
+ RedactedData: reasoning.Data,
},
},
})
@@ -717,11 +701,9 @@ func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response
CacheCreationTokens: response.Usage.CacheCreationInputTokens,
CacheReadTokens: response.Usage.CacheReadInputTokens,
},
- FinishReason: mapFinishReason(string(response.StopReason)),
- ProviderMetadata: ai.ProviderMetadata{
- "anthropic": make(map[string]any),
- },
- Warnings: warnings,
+ FinishReason: mapFinishReason(string(response.StopReason)),
+ ProviderMetadata: ai.ProviderMetadata{},
+ Warnings: warnings,
}, nil
}
@@ -770,8 +752,8 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo
Type: ai.StreamPartTypeReasoningStart,
ID: fmt.Sprintf("%d", chunk.Index),
ProviderMetadata: ai.ProviderMetadata{
- "anthropic": {
- "redacted_data": chunk.ContentBlock.Data,
+ "anthropic": &ReasoningMetadata{
+ RedactedData: chunk.ContentBlock.Data,
},
},
}) {
@@ -846,8 +828,8 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo
Type: ai.StreamPartTypeReasoningDelta,
ID: fmt.Sprintf("%d", chunk.Index),
ProviderMetadata: ai.ProviderMetadata{
- "anthropic": {
- "signature": chunk.Delta.Signature,
+ "anthropic": &ReasoningMetadata{
+ Signature: chunk.Delta.Signature,
},
},
}) {
@@ -147,9 +147,10 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar
messages, warnings := toPrompt(call.Prompt)
providerOptions := &ProviderOptions{}
if v, ok := call.ProviderOptions["openai"]; ok {
- err := ai.ParseOptions(v, providerOptions)
- if err != nil {
- return nil, nil, err
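+ // OpenAI options likewise arrive as a typed *ProviderOptions value.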
+ providerOptions, ok = v.(*ProviderOptions)
+ if !ok {
+ return nil, nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
}
}
if call.TopK != nil {
@@ -439,22 +439,19 @@ func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response
promptTokenDetails := response.Usage.PromptTokensDetails
// Build provider metadata
- providerMetadata := ai.ProviderMetadata{
- "openai": make(map[string]any),
- }
-
+ providerMetadata := &ProviderMetadata{}
// Add logprobs if available
if len(choice.Logprobs.Content) > 0 {
- providerMetadata["openai"]["logprobs"] = choice.Logprobs.Content
+ providerMetadata.Logprobs = choice.Logprobs.Content
}
// Add prediction tokens if available
if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
if completionTokenDetails.AcceptedPredictionTokens > 0 {
- providerMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
+ providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
}
if completionTokenDetails.RejectedPredictionTokens > 0 {
- providerMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
+ providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
}
}
@@ -467,9 +464,11 @@ func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response
ReasoningTokens: completionTokenDetails.ReasoningTokens,
CacheReadTokens: promptTokenDetails.CachedTokens,
},
- FinishReason: mapOpenAiFinishReason(choice.FinishReason),
- ProviderMetadata: providerMetadata,
- Warnings: warnings,
+ FinishReason: mapOpenAiFinishReason(choice.FinishReason),
+ ProviderMetadata: ai.ProviderMetadata{
+ "openai": providerMetadata,
+ },
+ Warnings: warnings,
}, nil
}
@@ -496,10 +495,7 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo
toolCalls := make(map[int64]toolCall)
// Build provider metadata for streaming
- streamProviderMetadata := ai.ProviderMetadata{
- "openai": make(map[string]any),
- }
-
+ streamProviderMetadata := &ProviderMetadata{}
acc := openai.ChatCompletionAccumulator{}
var usage ai.Usage
return func(yield func(ai.StreamPart) bool) {
@@ -529,10 +525,10 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo
// Add prediction tokens if available
if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
if completionTokenDetails.AcceptedPredictionTokens > 0 {
- streamProviderMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
+ streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
}
if completionTokenDetails.RejectedPredictionTokens > 0 {
- streamProviderMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
+ streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
}
}
}
@@ -706,7 +702,7 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo
// Add logprobs if available
if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
- streamProviderMetadata["openai"]["logprobs"] = acc.Choices[0].Logprobs.Content
+ streamProviderMetadata.Logprobs = acc.Choices[0].Logprobs.Content
}
// Handle annotations/citations from accumulated response
@@ -728,10 +724,12 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo
finishReason := mapOpenAiFinishReason(acc.Choices[0].FinishReason)
yield(ai.StreamPart{
- Type: ai.StreamPartTypeFinish,
- Usage: usage,
- FinishReason: finishReason,
- ProviderMetadata: streamProviderMetadata,
+ Type: ai.StreamPartTypeFinish,
+ Usage: usage,
+ FinishReason: finishReason,
+ ProviderMetadata: ai.ProviderMetadata{
+ "openai": streamProviderMetadata,
+ },
})
return
} else {
@@ -921,8 +919,8 @@ func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.
// Check for provider-specific options like image detail
if providerOptions, ok := filePart.ProviderOptions["openai"]; ok {
- if detail, ok := providerOptions["imageDetail"].(string); ok {
- imageURL.Detail = detail
+ if fileOptions, ok := providerOptions.(*ProviderFileOptions); ok {
+ imageURL.Detail = fileOptions.ImageDetail
}
}
@@ -157,11 +157,9 @@ func TestToOpenAiPrompt_UserMessages(t *testing.T) {
ai.FilePart{
MediaType: "image/png",
Data: imageData,
- ProviderOptions: ai.ProviderOptions{
- "openai": map[string]any{
- "imageDetail": "low",
- },
- },
+ ProviderOptions: NewProviderFileOptions(&ProviderFileOptions{
+ ImageDetail: "low",
+ }),
},
},
},
@@ -941,20 +939,17 @@ func TestDoGenerate(t *testing.T) {
result, err := model.Generate(context.Background(), ai.Call{
Prompt: testPrompt,
- ProviderOptions: ai.ProviderOptions{
- "openai": map[string]any{
- "logProbs": true,
- },
- },
+ ProviderOptions: NewProviderOptions(&ProviderOptions{
+ LogProbs: ai.BoolOption(true),
+ }),
})
require.NoError(t, err)
require.NotNil(t, result.ProviderMetadata)
- openaiMeta, ok := result.ProviderMetadata["openai"]
+ openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata)
require.True(t, ok)
- logprobs, ok := openaiMeta["logprobs"]
- require.True(t, ok)
+ logprobs := openaiMeta.Logprobs
require.NotNil(t, logprobs)
})
@@ -1057,15 +1053,13 @@ func TestDoGenerate(t *testing.T) {
_, err := model.Generate(context.Background(), ai.Call{
Prompt: testPrompt,
- ProviderOptions: ai.ProviderOptions{
- "openai": map[string]any{
- "logit_bias": map[string]int64{
- "50256": -100,
- },
- "parallel_tool_calls": false,
- "user": "test-user-id",
+ ProviderOptions: NewProviderOptions(&ProviderOptions{
+ LogitBias: map[string]int64{
+ "50256": -100,
},
- },
+ ParallelToolCalls: ai.BoolOption(false),
+ User: ai.StringOption("test-user-id"),
+ }),
})
require.NoError(t, err)
@@ -1101,11 +1095,9 @@ func TestDoGenerate(t *testing.T) {
_, err := model.Generate(context.Background(), ai.Call{
Prompt: testPrompt,
- ProviderOptions: ai.ProviderOptions{
- "openai": map[string]any{
- "reasoning_effort": "low",
- },
- },
+ ProviderOptions: NewProviderOptions(&ProviderOptions{
+ ReasoningEffort: ReasoningEffortOption(ReasoningEffortLow),
+ }),
})
require.NoError(t, err)
@@ -1141,11 +1135,9 @@ func TestDoGenerate(t *testing.T) {
_, err := model.Generate(context.Background(), ai.Call{
Prompt: testPrompt,
- ProviderOptions: ai.ProviderOptions{
- "openai": map[string]any{
- "text_verbosity": "low",
- },
- },
+ ProviderOptions: NewProviderOptions(&ProviderOptions{
+ TextVerbosity: ai.StringOption("low"),
+ }),
})
require.NoError(t, err)
@@ -1393,10 +1385,10 @@ func TestDoGenerate(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, result.ProviderMetadata)
- openaiMeta, ok := result.ProviderMetadata["openai"]
+ openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata)
require.True(t, ok)
- require.Equal(t, int64(123), openaiMeta["acceptedPredictionTokens"])
- require.Equal(t, int64(456), openaiMeta["rejectedPredictionTokens"])
+ require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens)
+ require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens)
})
t.Run("should clear out temperature, top_p, frequency_penalty, presence_penalty for reasoning models", func(t *testing.T) {
@@ -1534,11 +1527,9 @@ func TestDoGenerate(t *testing.T) {
_, err := model.Generate(context.Background(), ai.Call{
Prompt: testPrompt,
- ProviderOptions: ai.ProviderOptions{
- "openai": map[string]any{
- "max_completion_tokens": 255,
- },
- },
+ ProviderOptions: NewProviderOptions(&ProviderOptions{
+ MaxCompletionTokens: ai.IntOption(255),
+ }),
})
require.NoError(t, err)
@@ -1574,14 +1565,12 @@ func TestDoGenerate(t *testing.T) {
_, err := model.Generate(context.Background(), ai.Call{
Prompt: testPrompt,
- ProviderOptions: ai.ProviderOptions{
- "openai": map[string]any{
- "prediction": map[string]any{
- "type": "content",
- "content": "Hello, World!",
- },
+ ProviderOptions: NewProviderOptions(&ProviderOptions{
+ Prediction: map[string]any{
+ "type": "content",
+ "content": "Hello, World!",
},
- },
+ }),
})
require.NoError(t, err)
@@ -1620,11 +1609,9 @@ func TestDoGenerate(t *testing.T) {
_, err := model.Generate(context.Background(), ai.Call{
Prompt: testPrompt,
- ProviderOptions: ai.ProviderOptions{
- "openai": map[string]any{
- "store": true,
- },
- },
+ ProviderOptions: NewProviderOptions(&ProviderOptions{
+ Store: ai.BoolOption(true),
+ }),
})
require.NoError(t, err)
@@ -1660,13 +1647,11 @@ func TestDoGenerate(t *testing.T) {
_, err := model.Generate(context.Background(), ai.Call{
Prompt: testPrompt,
- ProviderOptions: ai.ProviderOptions{
- "openai": map[string]any{
- "metadata": map[string]any{
- "custom": "value",
- },
+ ProviderOptions: NewProviderOptions(&ProviderOptions{
+ Metadata: map[string]any{
+ "custom": "value",
},
- },
+ }),
})
require.NoError(t, err)
@@ -1704,11 +1689,9 @@ func TestDoGenerate(t *testing.T) {
_, err := model.Generate(context.Background(), ai.Call{
Prompt: testPrompt,
- ProviderOptions: ai.ProviderOptions{
- "openai": map[string]any{
- "prompt_cache_key": "test-cache-key-123",
- },
- },
+ ProviderOptions: NewProviderOptions(&ProviderOptions{
+ PromptCacheKey: ai.StringOption("test-cache-key-123"),
+ }),
})
require.NoError(t, err)
@@ -1744,11 +1727,9 @@ func TestDoGenerate(t *testing.T) {
_, err := model.Generate(context.Background(), ai.Call{
Prompt: testPrompt,
- ProviderOptions: ai.ProviderOptions{
- "openai": map[string]any{
- "safety_identifier": "test-safety-identifier-123",
- },
- },
+ ProviderOptions: NewProviderOptions(&ProviderOptions{
+ SafetyIdentifier: ai.StringOption("test-safety-identifier-123"),
+ }),
})
require.NoError(t, err)
@@ -1816,11 +1797,9 @@ func TestDoGenerate(t *testing.T) {
_, err := model.Generate(context.Background(), ai.Call{
Prompt: testPrompt,
- ProviderOptions: ai.ProviderOptions{
- "openai": map[string]any{
- "service_tier": "flex",
- },
- },
+ ProviderOptions: NewProviderOptions(&ProviderOptions{
+ ServiceTier: ai.StringOption("flex"),
+ }),
})
require.NoError(t, err)
@@ -1854,11 +1833,9 @@ func TestDoGenerate(t *testing.T) {
result, err := model.Generate(context.Background(), ai.Call{
Prompt: testPrompt,
- ProviderOptions: ai.ProviderOptions{
- "openai": map[string]any{
- "service_tier": "flex",
- },
- },
+ ProviderOptions: NewProviderOptions(&ProviderOptions{
+ ServiceTier: ai.StringOption("flex"),
+ }),
})
require.NoError(t, err)
@@ -1889,11 +1866,9 @@ func TestDoGenerate(t *testing.T) {
_, err := model.Generate(context.Background(), ai.Call{
Prompt: testPrompt,
- ProviderOptions: ai.ProviderOptions{
- "openai": map[string]any{
- "service_tier": "priority",
- },
- },
+ ProviderOptions: NewProviderOptions(&ProviderOptions{
+ ServiceTier: ai.StringOption("priority"),
+ }),
})
require.NoError(t, err)
@@ -1927,11 +1902,9 @@ func TestDoGenerate(t *testing.T) {
result, err := model.Generate(context.Background(), ai.Call{
Prompt: testPrompt,
- ProviderOptions: ai.ProviderOptions{
- "openai": map[string]any{
- "service_tier": "priority",
- },
- },
+ ProviderOptions: NewProviderOptions(&ProviderOptions{
+ ServiceTier: ai.StringOption("priority"),
+ }),
})
require.NoError(t, err)
@@ -2575,10 +2548,10 @@ func TestDoStream(t *testing.T) {
require.NotNil(t, finishPart)
require.NotNil(t, finishPart.ProviderMetadata)
- openaiMeta, ok := finishPart.ProviderMetadata["openai"]
+ openaiMeta, ok := finishPart.ProviderMetadata["openai"].(*ProviderMetadata)
require.True(t, ok)
- require.Equal(t, int64(123), openaiMeta["acceptedPredictionTokens"])
- require.Equal(t, int64(456), openaiMeta["rejectedPredictionTokens"])
+ require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens)
+ require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens)
})
t.Run("should send store extension setting", func(t *testing.T) {
@@ -2599,11 +2572,9 @@ func TestDoStream(t *testing.T) {
_, err := model.Stream(context.Background(), ai.Call{
Prompt: testPrompt,
- ProviderOptions: ai.ProviderOptions{
- "openai": map[string]any{
- "store": true,
- },
- },
+ ProviderOptions: NewProviderOptions(&ProviderOptions{
+ Store: ai.BoolOption(true),
+ }),
})
require.NoError(t, err)
@@ -2643,13 +2614,11 @@ func TestDoStream(t *testing.T) {
_, err := model.Stream(context.Background(), ai.Call{
Prompt: testPrompt,
- ProviderOptions: ai.ProviderOptions{
- "openai": map[string]any{
- "metadata": map[string]any{
- "custom": "value",
- },
+ ProviderOptions: NewProviderOptions(&ProviderOptions{
+ Metadata: map[string]any{
+ "custom": "value",
},
- },
+ }),
})
require.NoError(t, err)
@@ -2691,11 +2660,9 @@ func TestDoStream(t *testing.T) {
_, err := model.Stream(context.Background(), ai.Call{
Prompt: testPrompt,
- ProviderOptions: ai.ProviderOptions{
- "openai": map[string]any{
- "service_tier": "flex",
- },
- },
+ ProviderOptions: NewProviderOptions(&ProviderOptions{
+ ServiceTier: ai.StringOption("flex"),
+ }),
})
require.NoError(t, err)
@@ -2735,11 +2702,9 @@ func TestDoStream(t *testing.T) {
_, err := model.Stream(context.Background(), ai.Call{
Prompt: testPrompt,
- ProviderOptions: ai.ProviderOptions{
- "openai": map[string]any{
- "service_tier": "priority",
- },
- },
+ ProviderOptions: NewProviderOptions(&ProviderOptions{
+ ServiceTier: ai.StringOption("priority"),
+ }),
})
require.NoError(t, err)