From cc3c955a222f7ab6e59ecbc8e9b7725fb076fd7f Mon Sep 17 00:00:00 2001 From: kujtimiihoxha Date: Thu, 11 Sep 2025 13:29:03 +0200 Subject: [PATCH] chore: make provider options an interface --- ai/agent_test.go | 3 --- ai/content.go | 6 ++--- anthropic/anthropic.go | 46 ++++++++++++++++++----------------- anthropic/provider_options.go | 14 ++++++++--- openai/openai.go | 31 +++++++++++++---------- openai/provider_options.go | 20 ++++++++++----- 6 files changed, 70 insertions(+), 50 deletions(-) diff --git a/ai/agent_test.go b/ai/agent_test.go index 0339c675bc20a3845765d3bc784bed88f6976cce..23e5a6025a3141c9dc75308fb7e92060a45e4537 100644 --- a/ai/agent_test.go +++ b/ai/agent_test.go @@ -101,9 +101,6 @@ func TestAgent_Generate_ResultContent_AllTypes(t *testing.T) { URL: "https://example.com", Title: "Example", SourceType: SourceTypeURL, - ProviderMetadata: ProviderMetadata{ - "provider": map[string]any{"custom": "value"}, - }, }, FileContent{ Data: []byte{1, 2, 3}, diff --git a/ai/content.go b/ai/content.go index 0588d9103fdeeb26eff5d34f1cb0b2ab2b98e6a1..df1f60e73733a68f2e71772050230eee15709468 100644 --- a/ai/content.go +++ b/ai/content.go @@ -11,10 +11,10 @@ package ai // // { // "anthropic": { -// "cacheControl": { "type": "ephemeral" } +// "signature": "sig....." // } // } -type ProviderMetadata map[string]any +type ProviderMetadata map[string]interface{ Options() } // ProviderOptions represents additional provider-specific options. // Options are additional input to the provider. They are passed through @@ -34,7 +34,7 @@ type ProviderMetadata map[string]any // "cacheControl": { "type": "ephemeral" } // } // } -type ProviderOptions map[string]any +type ProviderOptions map[string]interface{ Options() } // FinishReason represents why a language model finished generating a response. // diff --git a/anthropic/anthropic.go b/anthropic/anthropic.go index 0ee67cb67a2d1929393778c5248aecdf55309fa9..7055c4b77d56071737924d236a201fdb6441bead 100644 --- a/anthropic/anthropic.go +++ b/anthropic/anthropic.go @@ -17,6 +17,11 @@ import ( "github.com/charmbracelet/ai/ai" ) +const ( + ProviderName = "anthropic" + DefaultURL = "https://api.anthropic.com" +) + type options struct { baseURL string apiKey string @@ -32,17 +37,16 @@ type provider struct { type Option = func(*options) func New(opts ...Option) ai.Provider { - options := options{ + providerOptions := options{ headers: map[string]string{}, } for _, o := range opts { - o(&options) + o(&providerOptions) } - options.baseURL = cmp.Or(options.baseURL, "https://api.anthropic.com") - options.name = cmp.Or(options.name, "anthropic") - - return &provider{options: options} + providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL) + providerOptions.name = cmp.Or(providerOptions.name, ProviderName) + return &provider{options: providerOptions} } func WithBaseURL(baseURL string) Option { @@ -119,7 +123,7 @@ func (a languageModel) Provider() string { func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, []ai.CallWarning, error) { params := &anthropic.MessageNewParams{} providerOptions := &ProviderOptions{} - if v, ok := call.ProviderOptions["anthropic"]; ok { + if v, ok := call.ProviderOptions[ProviderOptionsKey]; ok { providerOptions, ok = v.(*ProviderOptions) if !ok { return nil, nil, ai.NewInvalidArgumentError("providerOptions", "anthropic provider options should be *anthropic.ProviderOptions", nil) @@ -218,17 +222,17 @@ func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, } func getCacheControl(providerOptions ai.ProviderOptions) *CacheControl { - if anthropicOptions, ok := providerOptions["anthropic"]; ok { - if options, ok := anthropicOptions.(*CacheControl); ok { - return options + if anthropicOptions, ok := providerOptions[ProviderOptionsKey]; ok { + if options, ok := anthropicOptions.(*ProviderCacheControlOptions); ok { + return &options.CacheControl } } return nil } -func getReasoningMetadata(providerOptions ai.ProviderOptions) *ReasoningMetadata { - if anthropicOptions, ok := providerOptions["anthropic"]; ok { - if reasoning, ok := anthropicOptions.(*ReasoningMetadata); ok { +func getReasoningMetadata(providerOptions ai.ProviderOptions) *ReasoningOptionMetadata { + if anthropicOptions, ok := providerOptions[ProviderOptionsKey]; ok { + if reasoning, ok := anthropicOptions.(*ReasoningOptionMetadata); ok { return reasoning } } @@ -659,8 +663,8 @@ func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response } content = append(content, ai.ReasoningContent{ Text: reasoning.Thinking, - ProviderMetadata: map[string]any{ - "anthropic": &ReasoningMetadata{ + ProviderMetadata: ai.ProviderMetadata{ + ProviderOptionsKey: &ReasoningOptionMetadata{ Signature: reasoning.Signature, }, }, @@ -672,8 +676,8 @@ func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response } content = append(content, ai.ReasoningContent{ Text: "", - ProviderMetadata: map[string]any{ - "anthropic": &ReasoningMetadata{ + ProviderMetadata: ai.ProviderMetadata{ + ProviderOptionsKey: &ReasoningOptionMetadata{ RedactedData: reasoning.Data, }, }, @@ -752,7 +756,7 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo Type: ai.StreamPartTypeReasoningStart, ID: fmt.Sprintf("%d", chunk.Index), ProviderMetadata: ai.ProviderMetadata{ - "anthropic": &ReasoningMetadata{ + ProviderOptionsKey: &ReasoningOptionMetadata{ RedactedData: chunk.ContentBlock.Data, }, }, @@ -828,7 +832,7 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo Type: ai.StreamPartTypeReasoningDelta, ID: fmt.Sprintf("%d", chunk.Index), ProviderMetadata: ai.ProviderMetadata{ - "anthropic": &ReasoningMetadata{ + ProviderOptionsKey: &ReasoningOptionMetadata{ Signature: chunk.Delta.Signature, }, }, @@ -865,9 +869,7 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo CacheCreationTokens: acc.Usage.CacheCreationInputTokens, CacheReadTokens: acc.Usage.CacheReadInputTokens, }, - ProviderMetadata: ai.ProviderMetadata{ - "anthropic": make(map[string]any), - }, + ProviderMetadata: ai.ProviderMetadata{}, }) return } else { diff --git a/anthropic/provider_options.go b/anthropic/provider_options.go index d01fb16a50f8ef455ca3702672e9560cc6b72e82..d46a0ed9597523de611b32853f80e4f2b56cf383 100644 --- a/anthropic/provider_options.go +++ b/anthropic/provider_options.go @@ -2,37 +2,45 @@ package anthropic import "github.com/charmbracelet/ai/ai" +const ProviderOptionsKey = "anthropic" + type ProviderOptions struct { SendReasoning *bool Thinking *ThinkingProviderOption DisableParallelToolUse *bool } +func (o *ProviderOptions) Options() {} + type ThinkingProviderOption struct { BudgetTokens int64 } -type ReasoningMetadata struct { +type ReasoningOptionMetadata struct { Signature string RedactedData string } +func (*ReasoningOptionMetadata) Options() {} + type ProviderCacheControlOptions struct { CacheControl CacheControl } +func (*ProviderCacheControlOptions) Options() {} + type CacheControl struct { Type string } func NewProviderOptions(opts *ProviderOptions) ai.ProviderOptions { return ai.ProviderOptions{ - "anthropic": opts, + ProviderOptionsKey: opts, } } func NewProviderCacheControlOptions(opts *ProviderCacheControlOptions) ai.ProviderOptions { return ai.ProviderOptions{ - "anthropic": opts, + ProviderOptionsKey: opts, } } diff --git a/openai/openai.go b/openai/openai.go index c016fa6bfba4908a7f70aeedd6d4e64822396a85..239620e536b90c01823bd2e962438606f7ff5736 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -20,6 +20,11 @@ import ( "github.com/openai/openai-go/v2/shared" ) +const ( + ProviderName = "openai" + DefaultURL = "https://api.openai.com/v1" +) + type provider struct { options options } @@ -37,24 +42,24 @@ type options struct { type Option = func(*options) func New(opts ...Option) ai.Provider { - options := options{ + providerOptions := options{ headers: map[string]string{}, } for _, o := range opts { - o(&options) + o(&providerOptions) } - options.baseURL = cmp.Or(options.baseURL, "https://api.openai.com/v1") - options.name = cmp.Or(options.name, "openai") + providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL) + providerOptions.name = cmp.Or(providerOptions.name, ProviderName) - if options.organization != "" { - options.headers["OpenAi-Organization"] = options.organization + if providerOptions.organization != "" { + providerOptions.headers["OpenAi-Organization"] = providerOptions.organization } - if options.project != "" { - options.headers["OpenAi-Project"] = options.project + if providerOptions.project != "" { + providerOptions.headers["OpenAi-Project"] = providerOptions.project } - return &provider{options: options} + return &provider{options: providerOptions} } func WithBaseURL(baseURL string) Option { @@ -146,7 +151,7 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar params := &openai.ChatCompletionNewParams{} messages, warnings := toPrompt(call.Prompt) providerOptions := &ProviderOptions{} - if v, ok := call.ProviderOptions["openai"]; ok { + if v, ok := call.ProviderOptions[ProviderOptionsKey]; ok { providerOptions, ok = v.(*ProviderOptions) if !ok { return nil, nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil) @@ -466,7 +471,7 @@ func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response }, FinishReason: mapOpenAiFinishReason(choice.FinishReason), ProviderMetadata: ai.ProviderMetadata{ - "openai": providerMetadata, + ProviderOptionsKey: providerMetadata, }, Warnings: warnings, }, nil @@ -728,7 +733,7 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo Usage: usage, FinishReason: finishReason, ProviderMetadata: ai.ProviderMetadata{ - "openai": streamProviderMetadata, + ProviderOptionsKey: streamProviderMetadata, }, }) return @@ -918,7 +923,7 @@ func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai. imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data} // Check for provider-specific options like image detail - if providerOptions, ok := filePart.ProviderOptions["openai"]; ok { + if providerOptions, ok := filePart.ProviderOptions[ProviderOptionsKey]; ok { if detail, ok := providerOptions.(*ProviderFileOptions); ok { imageURL.Detail = detail.ImageDetail } diff --git a/openai/provider_options.go b/openai/provider_options.go index e13197445c29e06b352c89c03eb0cb287d73f560..62332185fe94f8f317843f1218ca31b8d97e7fa3 100644 --- a/openai/provider_options.go +++ b/openai/provider_options.go @@ -5,6 +5,8 @@ import ( "github.com/openai/openai-go/v2" ) +const ProviderOptionsKey = "openai" + type ReasoningEffort string const ( @@ -14,16 +16,14 @@ const ( ReasoningEffortHigh ReasoningEffort = "high" ) -type ProviderFileOptions struct { - ImageDetail string -} - type ProviderMetadata struct { Logprobs []openai.ChatCompletionTokenLogprob AcceptedPredictionTokens int64 RejectedPredictionTokens int64 } +func (*ProviderMetadata) Options() {} + type ProviderOptions struct { LogitBias map[string]int64 LogProbs *bool @@ -42,18 +42,26 @@ type ProviderOptions struct { StructuredOutputs *bool } +func (*ProviderOptions) Options() {} + +type ProviderFileOptions struct { + ImageDetail string +} + +func (*ProviderFileOptions) Options() {} + func ReasoningEffortOption(e ReasoningEffort) *ReasoningEffort { return &e } func NewProviderOptions(opts *ProviderOptions) ai.ProviderOptions { return ai.ProviderOptions{ - "openai": opts, + ProviderOptionsKey: opts, } } func NewProviderFileOptions(opts *ProviderFileOptions) ai.ProviderOptions { return ai.ProviderOptions{ - "openai": opts, + ProviderOptionsKey: opts, } }