chore: make provider options an interface

kujtimiihoxha created

Change summary

ai/agent_test.go              |  3 --
ai/content.go                 |  6 ++--
anthropic/anthropic.go        | 46 +++++++++++++++++++-----------------
anthropic/provider_options.go | 14 ++++++++--
openai/openai.go              | 31 ++++++++++++++----------
openai/provider_options.go    | 20 +++++++++++----
6 files changed, 70 insertions(+), 50 deletions(-)

Detailed changes

ai/agent_test.go 🔗

@@ -101,9 +101,6 @@ func TestAgent_Generate_ResultContent_AllTypes(t *testing.T) {
 						URL:        "https://example.com",
 						Title:      "Example",
 						SourceType: SourceTypeURL,
-						ProviderMetadata: ProviderMetadata{
-							"provider": map[string]any{"custom": "value"},
-						},
 					},
 					FileContent{
 						Data:      []byte{1, 2, 3},

ai/content.go 🔗

@@ -11,10 +11,10 @@ package ai
 //
 //	{
 //	  "anthropic": {
-//	    "cacheControl": { "type": "ephemeral" }
+//	    "signature": "sig....."
 //	  }
 //	}
-type ProviderMetadata map[string]any
+type ProviderMetadata map[string]interface{ Options() }
 
 // ProviderOptions represents additional provider-specific options.
 // Options are additional input to the provider. They are passed through
@@ -34,7 +34,7 @@ type ProviderMetadata map[string]any
 //	    "cacheControl": { "type": "ephemeral" }
 //	  }
 //	}
-type ProviderOptions map[string]any
+type ProviderOptions map[string]interface{ Options() }
 
 // FinishReason represents why a language model finished generating a response.
 //

anthropic/anthropic.go 🔗

@@ -17,6 +17,11 @@ import (
 	"github.com/charmbracelet/ai/ai"
 )
 
+const (
+	ProviderName = "anthropic"
+	DefaultURL   = "https://api.anthropic.com"
+)
+
 type options struct {
 	baseURL string
 	apiKey  string
@@ -32,17 +37,16 @@ type provider struct {
 type Option = func(*options)
 
 func New(opts ...Option) ai.Provider {
-	options := options{
+	providerOptions := options{
 		headers: map[string]string{},
 	}
 	for _, o := range opts {
-		o(&options)
+		o(&providerOptions)
 	}
 
-	options.baseURL = cmp.Or(options.baseURL, "https://api.anthropic.com")
-	options.name = cmp.Or(options.name, "anthropic")
-
-	return &provider{options: options}
+	providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
+	providerOptions.name = cmp.Or(providerOptions.name, ProviderName)
+	return &provider{options: providerOptions}
 }
 
 func WithBaseURL(baseURL string) Option {
@@ -119,7 +123,7 @@ func (a languageModel) Provider() string {
 func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, []ai.CallWarning, error) {
 	params := &anthropic.MessageNewParams{}
 	providerOptions := &ProviderOptions{}
-	if v, ok := call.ProviderOptions["anthropic"]; ok {
+	if v, ok := call.ProviderOptions[ProviderOptionsKey]; ok {
 		providerOptions, ok = v.(*ProviderOptions)
 		if !ok {
 			return nil, nil, ai.NewInvalidArgumentError("providerOptions", "anthropic provider options should be *anthropic.ProviderOptions", nil)
@@ -218,17 +222,17 @@ func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams,
 }
 
 func getCacheControl(providerOptions ai.ProviderOptions) *CacheControl {
-	if anthropicOptions, ok := providerOptions["anthropic"]; ok {
-		if options, ok := anthropicOptions.(*CacheControl); ok {
-			return options
+	if anthropicOptions, ok := providerOptions[ProviderOptionsKey]; ok {
+		if options, ok := anthropicOptions.(*ProviderCacheControlOptions); ok {
+			return &options.CacheControl
 		}
 	}
 	return nil
 }
 
-func getReasoningMetadata(providerOptions ai.ProviderOptions) *ReasoningMetadata {
-	if anthropicOptions, ok := providerOptions["anthropic"]; ok {
-		if reasoning, ok := anthropicOptions.(*ReasoningMetadata); ok {
+func getReasoningMetadata(providerOptions ai.ProviderOptions) *ReasoningOptionMetadata {
+	if anthropicOptions, ok := providerOptions[ProviderOptionsKey]; ok {
+		if reasoning, ok := anthropicOptions.(*ReasoningOptionMetadata); ok {
 			return reasoning
 		}
 	}
@@ -659,8 +663,8 @@ func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response
 			}
 			content = append(content, ai.ReasoningContent{
 				Text: reasoning.Thinking,
-				ProviderMetadata: map[string]any{
-					"anthropic": &ReasoningMetadata{
+				ProviderMetadata: ai.ProviderMetadata{
+					ProviderOptionsKey: &ReasoningOptionMetadata{
 						Signature: reasoning.Signature,
 					},
 				},
@@ -672,8 +676,8 @@ func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response
 			}
 			content = append(content, ai.ReasoningContent{
 				Text: "",
-				ProviderMetadata: map[string]any{
-					"anthropic": &ReasoningMetadata{
+				ProviderMetadata: ai.ProviderMetadata{
+					ProviderOptionsKey: &ReasoningOptionMetadata{
 						RedactedData: reasoning.Data,
 					},
 				},
@@ -752,7 +756,7 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo
 						Type: ai.StreamPartTypeReasoningStart,
 						ID:   fmt.Sprintf("%d", chunk.Index),
 						ProviderMetadata: ai.ProviderMetadata{
-							"anthropic": &ReasoningMetadata{
+							ProviderOptionsKey: &ReasoningOptionMetadata{
 								RedactedData: chunk.ContentBlock.Data,
 							},
 						},
@@ -828,7 +832,7 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo
 						Type: ai.StreamPartTypeReasoningDelta,
 						ID:   fmt.Sprintf("%d", chunk.Index),
 						ProviderMetadata: ai.ProviderMetadata{
-							"anthropic": &ReasoningMetadata{
+							ProviderOptionsKey: &ReasoningOptionMetadata{
 								Signature: chunk.Delta.Signature,
 							},
 						},
@@ -865,9 +869,7 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo
 					CacheCreationTokens: acc.Usage.CacheCreationInputTokens,
 					CacheReadTokens:     acc.Usage.CacheReadInputTokens,
 				},
-				ProviderMetadata: ai.ProviderMetadata{
-					"anthropic": make(map[string]any),
-				},
+				ProviderMetadata: ai.ProviderMetadata{},
 			})
 			return
 		} else {

anthropic/provider_options.go 🔗

@@ -2,37 +2,45 @@ package anthropic
 
 import "github.com/charmbracelet/ai/ai"
 
+const ProviderOptionsKey = "anthropic"
+
 type ProviderOptions struct {
 	SendReasoning          *bool
 	Thinking               *ThinkingProviderOption
 	DisableParallelToolUse *bool
 }
 
+func (o *ProviderOptions) Options() {}
+
 type ThinkingProviderOption struct {
 	BudgetTokens int64
 }
 
-type ReasoningMetadata struct {
+type ReasoningOptionMetadata struct {
 	Signature    string
 	RedactedData string
 }
 
+func (*ReasoningOptionMetadata) Options() {}
+
 type ProviderCacheControlOptions struct {
 	CacheControl CacheControl
 }
 
+func (*ProviderCacheControlOptions) Options() {}
+
 type CacheControl struct {
 	Type string
 }
 
 func NewProviderOptions(opts *ProviderOptions) ai.ProviderOptions {
 	return ai.ProviderOptions{
-		"anthropic": opts,
+		ProviderOptionsKey: opts,
 	}
 }
 
 func NewProviderCacheControlOptions(opts *ProviderCacheControlOptions) ai.ProviderOptions {
 	return ai.ProviderOptions{
-		"anthropic": opts,
+		ProviderOptionsKey: opts,
 	}
 }

openai/openai.go 🔗

@@ -20,6 +20,11 @@ import (
 	"github.com/openai/openai-go/v2/shared"
 )
 
+const (
+	ProviderName = "openai"
+	DefaultURL   = "https://api.openai.com/v1"
+)
+
 type provider struct {
 	options options
 }
@@ -37,24 +42,24 @@ type options struct {
 type Option = func(*options)
 
 func New(opts ...Option) ai.Provider {
-	options := options{
+	providerOptions := options{
 		headers: map[string]string{},
 	}
 	for _, o := range opts {
-		o(&options)
+		o(&providerOptions)
 	}
 
-	options.baseURL = cmp.Or(options.baseURL, "https://api.openai.com/v1")
-	options.name = cmp.Or(options.name, "openai")
+	providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
+	providerOptions.name = cmp.Or(providerOptions.name, ProviderName)
 
-	if options.organization != "" {
-		options.headers["OpenAi-Organization"] = options.organization
+	if providerOptions.organization != "" {
+		providerOptions.headers["OpenAi-Organization"] = providerOptions.organization
 	}
-	if options.project != "" {
-		options.headers["OpenAi-Project"] = options.project
+	if providerOptions.project != "" {
+		providerOptions.headers["OpenAi-Project"] = providerOptions.project
 	}
 
-	return &provider{options: options}
+	return &provider{options: providerOptions}
 }
 
 func WithBaseURL(baseURL string) Option {
@@ -146,7 +151,7 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar
 	params := &openai.ChatCompletionNewParams{}
 	messages, warnings := toPrompt(call.Prompt)
 	providerOptions := &ProviderOptions{}
-	if v, ok := call.ProviderOptions["openai"]; ok {
+	if v, ok := call.ProviderOptions[ProviderOptionsKey]; ok {
 		providerOptions, ok = v.(*ProviderOptions)
 		if !ok {
 			return nil, nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
@@ -466,7 +471,7 @@ func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response
 		},
 		FinishReason: mapOpenAiFinishReason(choice.FinishReason),
 		ProviderMetadata: ai.ProviderMetadata{
-			"openai": providerMetadata,
+			ProviderOptionsKey: providerMetadata,
 		},
 		Warnings: warnings,
 	}, nil
@@ -728,7 +733,7 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo
 				Usage:        usage,
 				FinishReason: finishReason,
 				ProviderMetadata: ai.ProviderMetadata{
-					"openai": streamProviderMetadata,
+					ProviderOptionsKey: streamProviderMetadata,
 				},
 			})
 			return
@@ -918,7 +923,7 @@ func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.
 						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}
 
 						// Check for provider-specific options like image detail
-						if providerOptions, ok := filePart.ProviderOptions["openai"]; ok {
+						if providerOptions, ok := filePart.ProviderOptions[ProviderOptionsKey]; ok {
 							if detail, ok := providerOptions.(*ProviderFileOptions); ok {
 								imageURL.Detail = detail.ImageDetail
 							}

openai/provider_options.go 🔗

@@ -5,6 +5,8 @@ import (
 	"github.com/openai/openai-go/v2"
 )
 
+const ProviderOptionsKey = "openai"
+
 type ReasoningEffort string
 
 const (
@@ -14,16 +16,14 @@ const (
 	ReasoningEffortHigh    ReasoningEffort = "high"
 )
 
-type ProviderFileOptions struct {
-	ImageDetail string
-}
-
 type ProviderMetadata struct {
 	Logprobs                 []openai.ChatCompletionTokenLogprob
 	AcceptedPredictionTokens int64
 	RejectedPredictionTokens int64
 }
 
+func (*ProviderMetadata) Options() {}
+
 type ProviderOptions struct {
 	LogitBias           map[string]int64
 	LogProbs            *bool
@@ -42,18 +42,26 @@ type ProviderOptions struct {
 	StructuredOutputs   *bool
 }
 
+func (*ProviderOptions) Options() {}
+
+type ProviderFileOptions struct {
+	ImageDetail string
+}
+
+func (*ProviderFileOptions) Options() {}
+
 func ReasoningEffortOption(e ReasoningEffort) *ReasoningEffort {
 	return &e
 }
 
 func NewProviderOptions(opts *ProviderOptions) ai.ProviderOptions {
 	return ai.ProviderOptions{
-		"openai": opts,
+		ProviderOptionsKey: opts,
 	}
 }
 
 func NewProviderFileOptions(opts *ProviderFileOptions) ai.ProviderOptions {
 	return ai.ProviderOptions{
-		"openai": opts,
+		ProviderOptionsKey: opts,
 	}
 }