chore: change how provider options work

kujtimiihoxha created this pull request

Change summary

ai/content.go                 |   4 ++--
ai/model.go                   |   2 +-
ai/util.go                    |  12 ++
anthropic/anthropic.go        |  60 ++++-------
anthropic/provider_options.go |  34 +++++-
openai/openai.go              |  50 ++++-----
openai/openai_test.go         | 183 ++++++++++++++----------------------
openai/provider_options.go    |  61 +++++++++---
8 files changed, 206 insertions(+), 200 deletions(-)

Detailed changes

ai/content.go

@@ -14,7 +14,7 @@ package ai
 //	    "cacheControl": { "type": "ephemeral" }
 //	  }
 //	}
-type ProviderMetadata map[string]map[string]any
+type ProviderMetadata map[string]any
 
 // ProviderOptions represents additional provider-specific options.
 // Options are additional input to the provider. They are passed through
@@ -34,7 +34,7 @@ type ProviderMetadata map[string]map[string]any
 //	    "cacheControl": { "type": "ephemeral" }
 //	  }
 //	}
-type ProviderOptions map[string]map[string]any
+type ProviderOptions map[string]any
 
 // FinishReason represents why a language model finished generating a response.
 //
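
Note: a provider key now maps straight to a typed value instead of a nested map[string]any. A minimal sketch of the new shape (the github.com/charmbracelet/ai/anthropic import path is assumed from the module layout):

	package main

	import (
		"github.com/charmbracelet/ai/ai"
		"github.com/charmbracelet/ai/anthropic" // import path assumed
	)

	func main() {
		// Each provider key holds a typed struct from that provider's
		// package rather than a map decoded via ai.ParseOptions.
		opts := ai.ProviderOptions{
			"anthropic": &anthropic.ProviderOptions{
				SendReasoning: ai.BoolOption(true),
			},
		}
		_ = opts
	}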

ai/model.go

@@ -118,7 +118,7 @@ type Response struct {
 	Warnings     []CallWarning   `json:"warnings"`
 
 	// for provider specific response metadata, the key is the provider id
-	ProviderMetadata map[string]map[string]any `json:"provider_metadata"`
+	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
 }
 
 type StreamPartType string

ai/util.go

@@ -11,3 +11,15 @@ func ParseOptions[T any](options map[string]any, m *T) error {
 func FloatOption(f float64) *float64 {
 	return &f
 }
+
+func BoolOption(b bool) *bool {
+	return &b
+}
+
+func StringOption(s string) *string {
+	return &s
+}
+
+func IntOption(i int64) *int64 {
+	return &i
+}
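
The new helpers sit alongside the existing FloatOption and return pointers for the optional fields the typed option structs use. For example:

	package main

	import (
		"fmt"

		"github.com/charmbracelet/ai/ai"
	)

	func main() {
		// Build pointer-typed optional values without temporary variables.
		b, s, n := ai.BoolOption(true), ai.StringOption("user-123"), ai.IntOption(255)
		fmt.Println(*b, *s, *n)
	}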

anthropic/anthropic.go

@@ -120,9 +120,9 @@ func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams,
 	params := &anthropic.MessageNewParams{}
 	providerOptions := &ProviderOptions{}
 	if v, ok := call.ProviderOptions["anthropic"]; ok {
-		err := ai.ParseOptions(v, providerOptions)
-		if err != nil {
-			return nil, nil, err
+		providerOptions, ok = v.(*ProviderOptions)
+		if !ok {
+			return nil, nil, ai.NewInvalidArgumentError("providerOptions", "anthropic provider options should be *anthropic.ProviderOptions", nil)
 		}
 	}
 	sendReasoning := true
@@ -217,24 +217,10 @@ func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams,
 	return params, warnings, nil
 }
 
-func getCacheControl(providerOptions ai.ProviderOptions) *CacheControlProviderOptions {
+func getCacheControl(providerOptions ai.ProviderOptions) *CacheControl {
 	if anthropicOptions, ok := providerOptions["anthropic"]; ok {
-		if cacheControl, ok := anthropicOptions["cache_control"]; ok {
-			if cc, ok := cacheControl.(map[string]any); ok {
-				cacheControlOption := &CacheControlProviderOptions{}
-				err := ai.ParseOptions(cc, cacheControlOption)
-				if err == nil {
-					return cacheControlOption
-				}
-			}
-		} else if cacheControl, ok := anthropicOptions["cacheControl"]; ok {
-			if cc, ok := cacheControl.(map[string]any); ok {
-				cacheControlOption := &CacheControlProviderOptions{}
-				err := ai.ParseOptions(cc, cacheControlOption)
-				if err == nil {
-					return cacheControlOption
-				}
-			}
+		if options, ok := anthropicOptions.(*CacheControl); ok {
+			return options
 		}
 	}
 	return nil
@@ -242,10 +228,8 @@ func getCacheControl(providerOptions ai.ProviderOptions) *CacheControlProviderOp
 
 func getReasoningMetadata(providerOptions ai.ProviderOptions) *ReasoningMetadata {
 	if anthropicOptions, ok := providerOptions["anthropic"]; ok {
-		reasoningMetadata := &ReasoningMetadata{}
-		err := ai.ParseOptions(anthropicOptions, reasoningMetadata)
-		if err == nil {
-			return reasoningMetadata
+		if reasoning, ok := anthropicOptions.(*ReasoningMetadata); ok {
+			return reasoning
 		}
 	}
 	return nil
@@ -675,9 +659,9 @@ func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response
 			}
 			content = append(content, ai.ReasoningContent{
 				Text: reasoning.Thinking,
-				ProviderMetadata: map[string]map[string]any{
-					"anthropic": {
-						"signature": reasoning.Signature,
+				ProviderMetadata: map[string]any{
+					"anthropic": &ReasoningMetadata{
+						Signature: reasoning.Signature,
 					},
 				},
 			})
@@ -688,9 +672,9 @@ func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response
 			}
 			content = append(content, ai.ReasoningContent{
 				Text: "",
-				ProviderMetadata: map[string]map[string]any{
-					"anthropic": {
-						"redacted_data": reasoning.Data,
+				ProviderMetadata: map[string]any{
+					"anthropic": &ReasoningMetadata{
+						RedactedData: reasoning.Data,
 					},
 				},
 			})
@@ -717,11 +701,9 @@ func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response
 			CacheCreationTokens: response.Usage.CacheCreationInputTokens,
 			CacheReadTokens:     response.Usage.CacheReadInputTokens,
 		},
-		FinishReason: mapFinishReason(string(response.StopReason)),
-		ProviderMetadata: ai.ProviderMetadata{
-			"anthropic": make(map[string]any),
-		},
-		Warnings: warnings,
+		FinishReason:     mapFinishReason(string(response.StopReason)),
+		ProviderMetadata: ai.ProviderMetadata{},
+		Warnings:         warnings,
 	}, nil
 }
 
@@ -770,8 +752,8 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo
 						Type: ai.StreamPartTypeReasoningStart,
 						ID:   fmt.Sprintf("%d", chunk.Index),
 						ProviderMetadata: ai.ProviderMetadata{
-							"anthropic": {
-								"redacted_data": chunk.ContentBlock.Data,
+							"anthropic": &ReasoningMetadata{
+								RedactedData: chunk.ContentBlock.Data,
 							},
 						},
 					}) {
@@ -846,8 +828,8 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo
 						Type: ai.StreamPartTypeReasoningDelta,
 						ID:   fmt.Sprintf("%d", chunk.Index),
 						ProviderMetadata: ai.ProviderMetadata{
-							"anthropic": {
-								"signature": chunk.Delta.Signature,
+							"anthropic": &ReasoningMetadata{
+								Signature: chunk.Delta.Signature,
 							},
 						},
 					}) {
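
On the read side, reasoning metadata is now a typed struct as well, so consumers type-assert instead of indexing meta["anthropic"]["signature"]. A sketch (import path assumed):

	package main

	import (
		"fmt"

		"github.com/charmbracelet/ai/ai"
		"github.com/charmbracelet/ai/anthropic" // import path assumed
	)

	func main() {
		// ProviderMetadata as emitted by Generate/Stream in the hunks above.
		meta := ai.ProviderMetadata{
			"anthropic": &anthropic.ReasoningMetadata{Signature: "sig"},
		}
		if md, ok := meta["anthropic"].(*anthropic.ReasoningMetadata); ok {
			fmt.Println(md.Signature)
		}
	}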

anthropic/provider_options.go

@@ -1,20 +1,38 @@
 package anthropic
 
+import "github.com/charmbracelet/ai/ai"
+
 type ProviderOptions struct {
-	SendReasoning          *bool                   `mapstructure:"send_reasoning,omitempty"`
-	Thinking               *ThinkingProviderOption `mapstructure:"thinking,omitempty"`
-	DisableParallelToolUse *bool                   `mapstructure:"disable_parallel_tool_use,omitempty"`
+	SendReasoning          *bool
+	Thinking               *ThinkingProviderOption
+	DisableParallelToolUse *bool
 }
 
 type ThinkingProviderOption struct {
-	BudgetTokens int64 `mapstructure:"budget_tokens"`
+	BudgetTokens int64
 }
 
 type ReasoningMetadata struct {
-	Signature    string `mapstructure:"signature"`
-	RedactedData string `mapstructure:"redacted_data"`
+	Signature    string
+	RedactedData string
+}
+
+type ProviderCacheControlOptions struct {
+	CacheControl CacheControl
+}
+
+type CacheControl struct {
+	Type string
+}
+
+func NewProviderOptions(opts *ProviderOptions) ai.ProviderOptions {
+	return ai.ProviderOptions{
+		"anthropic": opts,
+	}
 }
 
-type CacheControlProviderOptions struct {
-	Type string `mapstructure:"type"`
+func NewProviderCacheControlOptions(opts *ProviderCacheControlOptions) ai.ProviderOptions {
+	return ai.ProviderOptions{
+		"anthropic": opts,
+	}
 }
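
The constructors spare callers from assembling the ai.ProviderOptions map by hand. A sketch of the intended call-site usage (import path assumed):

	package main

	import (
		"github.com/charmbracelet/ai/ai"
		"github.com/charmbracelet/ai/anthropic" // import path assumed
	)

	func main() {
		// Nests the typed struct under the "anthropic" key.
		opts := anthropic.NewProviderOptions(&anthropic.ProviderOptions{
			SendReasoning: ai.BoolOption(true),
			Thinking:      &anthropic.ThinkingProviderOption{BudgetTokens: 1024},
		})
		_ = opts
	}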

openai/openai.go

@@ -147,9 +147,9 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar
 	messages, warnings := toPrompt(call.Prompt)
 	providerOptions := &ProviderOptions{}
 	if v, ok := call.ProviderOptions["openai"]; ok {
-		err := ai.ParseOptions(v, providerOptions)
-		if err != nil {
-			return nil, nil, err
+		providerOptions, ok = v.(*ProviderOptions)
+		if !ok {
+			return nil, nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
 		}
 	}
 	if call.TopK != nil {
@@ -439,22 +439,19 @@ func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response
 	promptTokenDetails := response.Usage.PromptTokensDetails
 
 	// Build provider metadata
-	providerMetadata := ai.ProviderMetadata{
-		"openai": make(map[string]any),
-	}
-
+	providerMetadata := &ProviderMetadata{}
 	// Add logprobs if available
 	if len(choice.Logprobs.Content) > 0 {
-		providerMetadata["openai"]["logprobs"] = choice.Logprobs.Content
+		providerMetadata.Logprobs = choice.Logprobs.Content
 	}
 
 	// Add prediction tokens if available
 	if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
 		if completionTokenDetails.AcceptedPredictionTokens > 0 {
-			providerMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
+			providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
 		}
 		if completionTokenDetails.RejectedPredictionTokens > 0 {
-			providerMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
+			providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
 		}
 	}
 
@@ -467,9 +464,11 @@ func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response
 			ReasoningTokens: completionTokenDetails.ReasoningTokens,
 			CacheReadTokens: promptTokenDetails.CachedTokens,
 		},
-		FinishReason:     mapOpenAiFinishReason(choice.FinishReason),
-		ProviderMetadata: providerMetadata,
-		Warnings:         warnings,
+		FinishReason: mapOpenAiFinishReason(choice.FinishReason),
+		ProviderMetadata: ai.ProviderMetadata{
+			"openai": providerMetadata,
+		},
+		Warnings: warnings,
 	}, nil
 }
 
@@ -496,10 +495,7 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo
 	toolCalls := make(map[int64]toolCall)
 
 	// Build provider metadata for streaming
-	streamProviderMetadata := ai.ProviderMetadata{
-		"openai": make(map[string]any),
-	}
-
+	streamProviderMetadata := &ProviderMetadata{}
 	acc := openai.ChatCompletionAccumulator{}
 	var usage ai.Usage
 	return func(yield func(ai.StreamPart) bool) {
@@ -529,10 +525,10 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo
 				// Add prediction tokens if available
 				if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
 					if completionTokenDetails.AcceptedPredictionTokens > 0 {
-						streamProviderMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
+						streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
 					}
 					if completionTokenDetails.RejectedPredictionTokens > 0 {
-						streamProviderMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
+						streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
 					}
 				}
 			}
@@ -706,7 +702,7 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo
 
 			// Add logprobs if available
 			if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
-				streamProviderMetadata["openai"]["logprobs"] = acc.Choices[0].Logprobs.Content
+				streamProviderMetadata.Logprobs = acc.Choices[0].Logprobs.Content
 			}
 
 			// Handle annotations/citations from accumulated response
@@ -728,10 +724,12 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo
 
 			finishReason := mapOpenAiFinishReason(acc.Choices[0].FinishReason)
 			yield(ai.StreamPart{
-				Type:             ai.StreamPartTypeFinish,
-				Usage:            usage,
-				FinishReason:     finishReason,
-				ProviderMetadata: streamProviderMetadata,
+				Type:         ai.StreamPartTypeFinish,
+				Usage:        usage,
+				FinishReason: finishReason,
+				ProviderMetadata: ai.ProviderMetadata{
+					"openai": streamProviderMetadata,
+				},
 			})
 			return
 		} else {
@@ -921,8 +919,8 @@ func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.
 
 						// Check for provider-specific options like image detail
 						if providerOptions, ok := filePart.ProviderOptions["openai"]; ok {
-							if detail, ok := providerOptions["imageDetail"].(string); ok {
-								imageURL.Detail = detail
+							if detail, ok := providerOptions.(*ProviderFileOptions); ok {
+								imageURL.Detail = detail.ImageDetail
 							}
 						}
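
Generate and Stream now publish a typed *ProviderMetadata under the "openai" key, so callers type-assert rather than walking nested maps. A sketch (import path assumed):

	package main

	import (
		"fmt"

		"github.com/charmbracelet/ai/ai"
		"github.com/charmbracelet/ai/openai" // import path assumed
	)

	func main() {
		// Shape of result.ProviderMetadata after a Generate call.
		meta := ai.ProviderMetadata{
			"openai": &openai.ProviderMetadata{AcceptedPredictionTokens: 123},
		}
		if md, ok := meta["openai"].(*openai.ProviderMetadata); ok {
			fmt.Println(md.AcceptedPredictionTokens)
		}
	}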
 

openai/openai_test.go

@@ -157,11 +157,9 @@ func TestToOpenAiPrompt_UserMessages(t *testing.T) {
 					ai.FilePart{
 						MediaType: "image/png",
 						Data:      imageData,
-						ProviderOptions: ai.ProviderOptions{
-							"openai": map[string]any{
-								"imageDetail": "low",
-							},
-						},
+						ProviderOptions: NewProviderFileOptions(&ProviderFileOptions{
+							ImageDetail: "low",
+						}),
 					},
 				},
 			},
@@ -941,20 +939,18 @@ func TestDoGenerate(t *testing.T) {
 
 		result, err := model.Generate(context.Background(), ai.Call{
 			Prompt: testPrompt,
-			ProviderOptions: ai.ProviderOptions{
-				"openai": map[string]any{
-					"logProbs": true,
-				},
-			},
+			ProviderOptions: NewProviderOptions(&ProviderOptions{
+				LogProbs: ai.BoolOption(true),
+			}),
 		})
 
 		require.NoError(t, err)
 		require.NotNil(t, result.ProviderMetadata)
 
-		openaiMeta, ok := result.ProviderMetadata["openai"]
+		openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata)
 		require.True(t, ok)
 
-		logprobs, ok := openaiMeta["logprobs"]
+		logprobs := openaiMeta.Logprobs
 		require.True(t, ok)
 		require.NotNil(t, logprobs)
 	})
@@ -1057,15 +1053,13 @@ func TestDoGenerate(t *testing.T) {
 
 		_, err := model.Generate(context.Background(), ai.Call{
 			Prompt: testPrompt,
-			ProviderOptions: ai.ProviderOptions{
-				"openai": map[string]any{
-					"logit_bias": map[string]int64{
-						"50256": -100,
-					},
-					"parallel_tool_calls": false,
-					"user":                "test-user-id",
+			ProviderOptions: NewProviderOptions(&ProviderOptions{
+				LogitBias: map[string]int64{
+					"50256": -100,
 				},
-			},
+				ParallelToolCalls: ai.BoolOption(false),
+				User:              ai.StringOption("test-user-id"),
+			}),
 		})
 
 		require.NoError(t, err)
@@ -1101,11 +1095,11 @@ func TestDoGenerate(t *testing.T) {
 
 		_, err := model.Generate(context.Background(), ai.Call{
 			Prompt: testPrompt,
-			ProviderOptions: ai.ProviderOptions{
-				"openai": map[string]any{
-					"reasoning_effort": "low",
+			ProviderOptions: NewProviderOptions(
+				&ProviderOptions{
+					ReasoningEffort: ReasoningEffortOption(ReasoningEffortLow),
 				},
-			},
+			),
 		})
 
 		require.NoError(t, err)
@@ -1141,11 +1135,9 @@ func TestDoGenerate(t *testing.T) {
 
 		_, err := model.Generate(context.Background(), ai.Call{
 			Prompt: testPrompt,
-			ProviderOptions: ai.ProviderOptions{
-				"openai": map[string]any{
-					"text_verbosity": "low",
-				},
-			},
+			ProviderOptions: NewProviderOptions(&ProviderOptions{
+				TextVerbosity: ai.StringOption("low"),
+			}),
 		})
 
 		require.NoError(t, err)
@@ -1393,10 +1385,11 @@ func TestDoGenerate(t *testing.T) {
 		require.NoError(t, err)
 		require.NotNil(t, result.ProviderMetadata)
 
-		openaiMeta, ok := result.ProviderMetadata["openai"]
+		openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata)
+
 		require.True(t, ok)
-		require.Equal(t, int64(123), openaiMeta["acceptedPredictionTokens"])
-		require.Equal(t, int64(456), openaiMeta["rejectedPredictionTokens"])
+		require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens)
+		require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens)
 	})
 
 	t.Run("should clear out temperature, top_p, frequency_penalty, presence_penalty for reasoning models", func(t *testing.T) {
@@ -1534,11 +1527,9 @@ func TestDoGenerate(t *testing.T) {
 
 		_, err := model.Generate(context.Background(), ai.Call{
 			Prompt: testPrompt,
-			ProviderOptions: ai.ProviderOptions{
-				"openai": map[string]any{
-					"max_completion_tokens": 255,
-				},
-			},
+			ProviderOptions: NewProviderOptions(&ProviderOptions{
+				MaxCompletionTokens: ai.IntOption(255),
+			}),
 		})
 
 		require.NoError(t, err)
@@ -1574,14 +1565,12 @@ func TestDoGenerate(t *testing.T) {
 
 		_, err := model.Generate(context.Background(), ai.Call{
 			Prompt: testPrompt,
-			ProviderOptions: ai.ProviderOptions{
-				"openai": map[string]any{
-					"prediction": map[string]any{
-						"type":    "content",
-						"content": "Hello, World!",
-					},
+			ProviderOptions: NewProviderOptions(&ProviderOptions{
+				Prediction: map[string]any{
+					"type":    "content",
+					"content": "Hello, World!",
 				},
-			},
+			}),
 		})
 
 		require.NoError(t, err)
@@ -1620,11 +1609,9 @@ func TestDoGenerate(t *testing.T) {
 
 		_, err := model.Generate(context.Background(), ai.Call{
 			Prompt: testPrompt,
-			ProviderOptions: ai.ProviderOptions{
-				"openai": map[string]any{
-					"store": true,
-				},
-			},
+			ProviderOptions: NewProviderOptions(&ProviderOptions{
+				Store: ai.BoolOption(true),
+			}),
 		})
 
 		require.NoError(t, err)
@@ -1660,13 +1647,11 @@ func TestDoGenerate(t *testing.T) {
 
 		_, err := model.Generate(context.Background(), ai.Call{
 			Prompt: testPrompt,
-			ProviderOptions: ai.ProviderOptions{
-				"openai": map[string]any{
-					"metadata": map[string]any{
-						"custom": "value",
-					},
+			ProviderOptions: NewProviderOptions(&ProviderOptions{
+				Metadata: map[string]any{
+					"custom": "value",
 				},
-			},
+			}),
 		})
 
 		require.NoError(t, err)
@@ -1704,11 +1689,9 @@ func TestDoGenerate(t *testing.T) {
 
 		_, err := model.Generate(context.Background(), ai.Call{
 			Prompt: testPrompt,
-			ProviderOptions: ai.ProviderOptions{
-				"openai": map[string]any{
-					"prompt_cache_key": "test-cache-key-123",
-				},
-			},
+			ProviderOptions: NewProviderOptions(&ProviderOptions{
+				PromptCacheKey: ai.StringOption("test-cache-key-123"),
+			}),
 		})
 
 		require.NoError(t, err)
@@ -1744,11 +1727,9 @@ func TestDoGenerate(t *testing.T) {
 
 		_, err := model.Generate(context.Background(), ai.Call{
 			Prompt: testPrompt,
-			ProviderOptions: ai.ProviderOptions{
-				"openai": map[string]any{
-					"safety_identifier": "test-safety-identifier-123",
-				},
-			},
+			ProviderOptions: NewProviderOptions(&ProviderOptions{
+				SafetyIdentifier: ai.StringOption("test-safety-identifier-123"),
+			}),
 		})
 
 		require.NoError(t, err)
@@ -1816,11 +1797,9 @@ func TestDoGenerate(t *testing.T) {
 
 		_, err := model.Generate(context.Background(), ai.Call{
 			Prompt: testPrompt,
-			ProviderOptions: ai.ProviderOptions{
-				"openai": map[string]any{
-					"service_tier": "flex",
-				},
-			},
+			ProviderOptions: NewProviderOptions(&ProviderOptions{
+				ServiceTier: ai.StringOption("flex"),
+			}),
 		})
 
 		require.NoError(t, err)
@@ -1854,11 +1833,9 @@ func TestDoGenerate(t *testing.T) {
 
 		result, err := model.Generate(context.Background(), ai.Call{
 			Prompt: testPrompt,
-			ProviderOptions: ai.ProviderOptions{
-				"openai": map[string]any{
-					"service_tier": "flex",
-				},
-			},
+			ProviderOptions: NewProviderOptions(&ProviderOptions{
+				ServiceTier: ai.StringOption("flex"),
+			}),
 		})
 
 		require.NoError(t, err)
@@ -1889,11 +1866,9 @@ func TestDoGenerate(t *testing.T) {
 
 		_, err := model.Generate(context.Background(), ai.Call{
 			Prompt: testPrompt,
-			ProviderOptions: ai.ProviderOptions{
-				"openai": map[string]any{
-					"service_tier": "priority",
-				},
-			},
+			ProviderOptions: NewProviderOptions(&ProviderOptions{
+				ServiceTier: ai.StringOption("priority"),
+			}),
 		})
 
 		require.NoError(t, err)
@@ -1927,11 +1902,9 @@ func TestDoGenerate(t *testing.T) {
 
 		result, err := model.Generate(context.Background(), ai.Call{
 			Prompt: testPrompt,
-			ProviderOptions: ai.ProviderOptions{
-				"openai": map[string]any{
-					"service_tier": "priority",
-				},
-			},
+			ProviderOptions: NewProviderOptions(&ProviderOptions{
+				ServiceTier: ai.StringOption("priority"),
+			}),
 		})
 
 		require.NoError(t, err)
@@ -2575,10 +2548,10 @@ func TestDoStream(t *testing.T) {
 		require.NotNil(t, finishPart)
 		require.NotNil(t, finishPart.ProviderMetadata)
 
-		openaiMeta, ok := finishPart.ProviderMetadata["openai"]
+		openaiMeta, ok := finishPart.ProviderMetadata["openai"].(*ProviderMetadata)
 		require.True(t, ok)
-		require.Equal(t, int64(123), openaiMeta["acceptedPredictionTokens"])
-		require.Equal(t, int64(456), openaiMeta["rejectedPredictionTokens"])
+		require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens)
+		require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens)
 	})
 
 	t.Run("should send store extension setting", func(t *testing.T) {
@@ -2599,11 +2572,9 @@ func TestDoStream(t *testing.T) {
 
 		_, err := model.Stream(context.Background(), ai.Call{
 			Prompt: testPrompt,
-			ProviderOptions: ai.ProviderOptions{
-				"openai": map[string]any{
-					"store": true,
-				},
-			},
+			ProviderOptions: NewProviderOptions(&ProviderOptions{
+				Store: ai.BoolOption(true),
+			}),
 		})
 
 		require.NoError(t, err)
@@ -2643,13 +2614,11 @@ func TestDoStream(t *testing.T) {
 
 		_, err := model.Stream(context.Background(), ai.Call{
 			Prompt: testPrompt,
-			ProviderOptions: ai.ProviderOptions{
-				"openai": map[string]any{
-					"metadata": map[string]any{
-						"custom": "value",
-					},
+			ProviderOptions: NewProviderOptions(&ProviderOptions{
+				Metadata: map[string]any{
+					"custom": "value",
 				},
-			},
+			}),
 		})
 
 		require.NoError(t, err)
@@ -2691,11 +2660,9 @@ func TestDoStream(t *testing.T) {
 
 		_, err := model.Stream(context.Background(), ai.Call{
 			Prompt: testPrompt,
-			ProviderOptions: ai.ProviderOptions{
-				"openai": map[string]any{
-					"service_tier": "flex",
-				},
-			},
+			ProviderOptions: NewProviderOptions(&ProviderOptions{
+				ServiceTier: ai.StringOption("flex"),
+			}),
 		})
 
 		require.NoError(t, err)
@@ -2735,11 +2702,9 @@ func TestDoStream(t *testing.T) {
 
 		_, err := model.Stream(context.Background(), ai.Call{
 			Prompt: testPrompt,
-			ProviderOptions: ai.ProviderOptions{
-				"openai": map[string]any{
-					"service_tier": "priority",
-				},
-			},
+			ProviderOptions: NewProviderOptions(&ProviderOptions{
+				ServiceTier: ai.StringOption("priority"),
+			}),
 		})
 
 		require.NoError(t, err)

openai/provider_options.go

@@ -1,5 +1,10 @@
 package openai
 
+import (
+	"github.com/charmbracelet/ai/ai"
+	"github.com/openai/openai-go/v2"
+)
+
 type ReasoningEffort string
 
 const (
@@ -9,20 +14,46 @@ const (
 	ReasoningEffortHigh    ReasoningEffort = "high"
 )
 
+type ProviderFileOptions struct {
+	ImageDetail string
+}
+
+type ProviderMetadata struct {
+	Logprobs                 []openai.ChatCompletionTokenLogprob
+	AcceptedPredictionTokens int64
+	RejectedPredictionTokens int64
+}
+
 type ProviderOptions struct {
-	LogitBias           map[string]int64 `mapstructure:"logit_bias"`
-	LogProbs            *bool            `mapstructure:"log_probes"`
-	TopLogProbs         *int64           `mapstructure:"top_log_probs"`
-	ParallelToolCalls   *bool            `mapstructure:"parallel_tool_calls"`
-	User                *string          `mapstructure:"user"`
-	ReasoningEffort     *ReasoningEffort `mapstructure:"reasoning_effort"`
-	MaxCompletionTokens *int64           `mapstructure:"max_completion_tokens"`
-	TextVerbosity       *string          `mapstructure:"text_verbosity"`
-	Prediction          map[string]any   `mapstructure:"prediction"`
-	Store               *bool            `mapstructure:"store"`
-	Metadata            map[string]any   `mapstructure:"metadata"`
-	PromptCacheKey      *string          `mapstructure:"prompt_cache_key"`
-	SafetyIdentifier    *string          `mapstructure:"safety_identifier"`
-	ServiceTier         *string          `mapstructure:"service_tier"`
-	StructuredOutputs   *bool            `mapstructure:"structured_outputs"`
+	LogitBias           map[string]int64
+	LogProbs            *bool
+	TopLogProbs         *int64
+	ParallelToolCalls   *bool
+	User                *string
+	ReasoningEffort     *ReasoningEffort
+	MaxCompletionTokens *int64
+	TextVerbosity       *string
+	Prediction          map[string]any
+	Store               *bool
+	Metadata            map[string]any
+	PromptCacheKey      *string
+	SafetyIdentifier    *string
+	ServiceTier         *string
+	StructuredOutputs   *bool
+}
+
+func ReasoningEffortOption(e ReasoningEffort) *ReasoningEffort {
+	return &e
+}
+
+func NewProviderOptions(opts *ProviderOptions) ai.ProviderOptions {
+	return ai.ProviderOptions{
+		"openai": opts,
+	}
+}
+
+func NewProviderFileOptions(opts *ProviderFileOptions) ai.ProviderOptions {
+	return ai.ProviderOptions{
+		"openai": opts,
+	}
 }
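
Taken together, call options and per-file options each get a constructor that nests the typed struct under the "openai" key. A sketch of the call-site usage (import path assumed):

	package main

	import (
		"github.com/charmbracelet/ai/ai"
		"github.com/charmbracelet/ai/openai" // import path assumed
	)

	func main() {
		opts := openai.NewProviderOptions(&openai.ProviderOptions{
			ReasoningEffort: openai.ReasoningEffortOption(openai.ReasoningEffortLow),
			ServiceTier:     ai.StringOption("flex"),
		})
		fileOpts := openai.NewProviderFileOptions(&openai.ProviderFileOptions{
			ImageDetail: "low", // e.g. image detail on file parts, as in toPrompt above
		})
		_, _ = opts, fileOpts
	}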