diff --git a/ai/content.go b/ai/content.go index 5ff7f0805aa7a6a533527e4525cfb657dff9750a..0588d9103fdeeb26eff5d34f1cb0b2ab2b98e6a1 100644 --- a/ai/content.go +++ b/ai/content.go @@ -14,7 +14,7 @@ package ai // "cacheControl": { "type": "ephemeral" } // } // } -type ProviderMetadata map[string]map[string]any +type ProviderMetadata map[string]any // ProviderOptions represents additional provider-specific options. // Options are additional input to the provider. They are passed through @@ -34,7 +34,7 @@ type ProviderMetadata map[string]map[string]any // "cacheControl": { "type": "ephemeral" } // } // } -type ProviderOptions map[string]map[string]any +type ProviderOptions map[string]any // FinishReason represents why a language model finished generating a response. // diff --git a/ai/model.go b/ai/model.go index ac1875517b3c320863e1cf2dad3fb9564a0c1a5f..01fb3dd818e7b9cfd9a4ca2361d6a24b6557d0b4 100644 --- a/ai/model.go +++ b/ai/model.go @@ -118,7 +118,7 @@ type Response struct { Warnings []CallWarning `json:"warnings"` // for provider specific response metadata, the key is the provider id - ProviderMetadata map[string]map[string]any `json:"provider_metadata"` + ProviderMetadata ProviderMetadata `json:"provider_metadata"` } type StreamPartType string diff --git a/ai/util.go b/ai/util.go index dc50b4c6c27a59fa4a261f4132cbb8d0d3b4c1b3..7af28c22ccabfbeb45fb2b5a288a2d6b9063ccff 100644 --- a/ai/util.go +++ b/ai/util.go @@ -11,3 +11,15 @@ func ParseOptions[T any](options map[string]any, m *T) error { func FloatOption(f float64) *float64 { return &f } + +func BoolOption(b bool) *bool { + return &b +} + +func StringOption(s string) *string { + return &s +} + +func IntOption(i int64) *int64 { + return &i +} diff --git a/anthropic/anthropic.go b/anthropic/anthropic.go index 00a444ec7ae0113b873a15644f7943292725fa5c..0ee67cb67a2d1929393778c5248aecdf55309fa9 100644 --- a/anthropic/anthropic.go +++ b/anthropic/anthropic.go @@ -120,9 +120,9 @@ func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, params := &anthropic.MessageNewParams{} providerOptions := &ProviderOptions{} if v, ok := call.ProviderOptions["anthropic"]; ok { - err := ai.ParseOptions(v, providerOptions) - if err != nil { - return nil, nil, err + providerOptions, ok = v.(*ProviderOptions) + if !ok { + return nil, nil, ai.NewInvalidArgumentError("providerOptions", "anthropic provider options should be *anthropic.ProviderOptions", nil) } } sendReasoning := true @@ -217,24 +217,10 @@ func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, return params, warnings, nil } -func getCacheControl(providerOptions ai.ProviderOptions) *CacheControlProviderOptions { +func getCacheControl(providerOptions ai.ProviderOptions) *CacheControl { if anthropicOptions, ok := providerOptions["anthropic"]; ok { - if cacheControl, ok := anthropicOptions["cache_control"]; ok { - if cc, ok := cacheControl.(map[string]any); ok { - cacheControlOption := &CacheControlProviderOptions{} - err := ai.ParseOptions(cc, cacheControlOption) - if err == nil { - return cacheControlOption - } - } - } else if cacheControl, ok := anthropicOptions["cacheControl"]; ok { - if cc, ok := cacheControl.(map[string]any); ok { - cacheControlOption := &CacheControlProviderOptions{} - err := ai.ParseOptions(cc, cacheControlOption) - if err == nil { - return cacheControlOption - } - } + if options, ok := anthropicOptions.(*CacheControl); ok { + return options } } return nil @@ -242,10 +228,8 @@ func getCacheControl(providerOptions ai.ProviderOptions) *CacheControlProviderOp func getReasoningMetadata(providerOptions ai.ProviderOptions) *ReasoningMetadata { if anthropicOptions, ok := providerOptions["anthropic"]; ok { - reasoningMetadata := &ReasoningMetadata{} - err := ai.ParseOptions(anthropicOptions, reasoningMetadata) - if err == nil { - return reasoningMetadata + if reasoning, ok := anthropicOptions.(*ReasoningMetadata); ok { + return reasoning } } return nil @@ -675,9 +659,9 @@ func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response } content = append(content, ai.ReasoningContent{ Text: reasoning.Thinking, - ProviderMetadata: map[string]map[string]any{ - "anthropic": { - "signature": reasoning.Signature, + ProviderMetadata: map[string]any{ + "anthropic": &ReasoningMetadata{ + Signature: reasoning.Signature, }, }, }) @@ -688,9 +672,9 @@ func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response } content = append(content, ai.ReasoningContent{ Text: "", - ProviderMetadata: map[string]map[string]any{ - "anthropic": { - "redacted_data": reasoning.Data, + ProviderMetadata: map[string]any{ + "anthropic": &ReasoningMetadata{ + RedactedData: reasoning.Data, }, }, }) @@ -717,11 +701,9 @@ func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response CacheCreationTokens: response.Usage.CacheCreationInputTokens, CacheReadTokens: response.Usage.CacheReadInputTokens, }, - FinishReason: mapFinishReason(string(response.StopReason)), - ProviderMetadata: ai.ProviderMetadata{ - "anthropic": make(map[string]any), - }, - Warnings: warnings, + FinishReason: mapFinishReason(string(response.StopReason)), + ProviderMetadata: ai.ProviderMetadata{}, + Warnings: warnings, }, nil } @@ -770,8 +752,8 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo Type: ai.StreamPartTypeReasoningStart, ID: fmt.Sprintf("%d", chunk.Index), ProviderMetadata: ai.ProviderMetadata{ - "anthropic": { - "redacted_data": chunk.ContentBlock.Data, + "anthropic": &ReasoningMetadata{ + RedactedData: chunk.ContentBlock.Data, }, }, }) { @@ -846,8 +828,8 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo Type: ai.StreamPartTypeReasoningDelta, ID: fmt.Sprintf("%d", chunk.Index), ProviderMetadata: ai.ProviderMetadata{ - "anthropic": { - "signature": chunk.Delta.Signature, + "anthropic": &ReasoningMetadata{ + Signature: chunk.Delta.Signature, }, }, }) { diff --git a/anthropic/provider_options.go b/anthropic/provider_options.go index c22face3491e886382afa3de2e367d629a9bba3d..d01fb16a50f8ef455ca3702672e9560cc6b72e82 100644 --- a/anthropic/provider_options.go +++ b/anthropic/provider_options.go @@ -1,20 +1,38 @@ package anthropic +import "github.com/charmbracelet/ai/ai" + type ProviderOptions struct { - SendReasoning *bool `mapstructure:"send_reasoning,omitempty"` - Thinking *ThinkingProviderOption `mapstructure:"thinking,omitempty"` - DisableParallelToolUse *bool `mapstructure:"disable_parallel_tool_use,omitempty"` + SendReasoning *bool + Thinking *ThinkingProviderOption + DisableParallelToolUse *bool } type ThinkingProviderOption struct { - BudgetTokens int64 `mapstructure:"budget_tokens"` + BudgetTokens int64 } type ReasoningMetadata struct { - Signature string `mapstructure:"signature"` - RedactedData string `mapstructure:"redacted_data"` + Signature string + RedactedData string +} + +type ProviderCacheControlOptions struct { + CacheControl CacheControl +} + +type CacheControl struct { + Type string +} + +func NewProviderOptions(opts *ProviderOptions) ai.ProviderOptions { + return ai.ProviderOptions{ + "anthropic": opts, + } } -type CacheControlProviderOptions struct { - Type string `mapstructure:"type"` +func NewProviderCacheControlOptions(opts *ProviderCacheControlOptions) ai.ProviderOptions { + return ai.ProviderOptions{ + "anthropic": opts, + } } diff --git a/openai/openai.go b/openai/openai.go index edf9cba4c19d832514488a7622637238027a0c23..c016fa6bfba4908a7f70aeedd6d4e64822396a85 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -147,9 +147,9 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar messages, warnings := toPrompt(call.Prompt) providerOptions := &ProviderOptions{} if v, ok := call.ProviderOptions["openai"]; ok { - err := ai.ParseOptions(v, providerOptions) - if err != nil { - return nil, nil, err + providerOptions, ok = v.(*ProviderOptions) + if !ok { + return nil, nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil) } } if call.TopK != nil { @@ -439,22 +439,19 @@ func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response promptTokenDetails := response.Usage.PromptTokensDetails // Build provider metadata - providerMetadata := ai.ProviderMetadata{ - "openai": make(map[string]any), - } - + providerMetadata := &ProviderMetadata{} // Add logprobs if available if len(choice.Logprobs.Content) > 0 { - providerMetadata["openai"]["logprobs"] = choice.Logprobs.Content + providerMetadata.Logprobs = choice.Logprobs.Content } // Add prediction tokens if available if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 { if completionTokenDetails.AcceptedPredictionTokens > 0 { - providerMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens + providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens } if completionTokenDetails.RejectedPredictionTokens > 0 { - providerMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens + providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens } } @@ -467,9 +464,11 @@ func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response ReasoningTokens: completionTokenDetails.ReasoningTokens, CacheReadTokens: promptTokenDetails.CachedTokens, }, - FinishReason: mapOpenAiFinishReason(choice.FinishReason), - ProviderMetadata: providerMetadata, - Warnings: warnings, + FinishReason: mapOpenAiFinishReason(choice.FinishReason), + ProviderMetadata: ai.ProviderMetadata{ + "openai": providerMetadata, + }, + Warnings: warnings, }, nil } @@ -496,10 +495,7 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo toolCalls := make(map[int64]toolCall) // Build provider metadata for streaming - streamProviderMetadata := ai.ProviderMetadata{ - "openai": make(map[string]any), - } - + streamProviderMetadata := &ProviderMetadata{} acc := openai.ChatCompletionAccumulator{} var usage ai.Usage return func(yield func(ai.StreamPart) bool) { @@ -529,10 +525,10 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo // Add prediction tokens if available if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 { if completionTokenDetails.AcceptedPredictionTokens > 0 { - streamProviderMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens + streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens } if completionTokenDetails.RejectedPredictionTokens > 0 { - streamProviderMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens + streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens } } } @@ -706,7 +702,7 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo // Add logprobs if available if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 { - streamProviderMetadata["openai"]["logprobs"] = acc.Choices[0].Logprobs.Content + streamProviderMetadata.Logprobs = acc.Choices[0].Logprobs.Content } // Handle annotations/citations from accumulated response @@ -728,10 +724,12 @@ func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo finishReason := mapOpenAiFinishReason(acc.Choices[0].FinishReason) yield(ai.StreamPart{ - Type: ai.StreamPartTypeFinish, - Usage: usage, - FinishReason: finishReason, - ProviderMetadata: streamProviderMetadata, + Type: ai.StreamPartTypeFinish, + Usage: usage, + FinishReason: finishReason, + ProviderMetadata: ai.ProviderMetadata{ + "openai": streamProviderMetadata, + }, }) return } else { @@ -921,8 +919,8 @@ func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai. // Check for provider-specific options like image detail if providerOptions, ok := filePart.ProviderOptions["openai"]; ok { - if detail, ok := providerOptions["imageDetail"].(string); ok { - imageURL.Detail = detail + if detail, ok := providerOptions.(*ProviderFileOptions); ok { + imageURL.Detail = detail.ImageDetail } } diff --git a/openai/openai_test.go b/openai/openai_test.go index 27fe8c4ee8f0979f218c33683af1f003fbb46b1b..c6ef09b6b81e4e4de291cd9fbbf25a788e16dd6d 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -157,11 +157,9 @@ func TestToOpenAiPrompt_UserMessages(t *testing.T) { ai.FilePart{ MediaType: "image/png", Data: imageData, - ProviderOptions: ai.ProviderOptions{ - "openai": map[string]any{ - "imageDetail": "low", - }, - }, + ProviderOptions: NewProviderFileOptions(&ProviderFileOptions{ + ImageDetail: "low", + }), }, }, }, @@ -941,20 +939,18 @@ func TestDoGenerate(t *testing.T) { result, err := model.Generate(context.Background(), ai.Call{ Prompt: testPrompt, - ProviderOptions: ai.ProviderOptions{ - "openai": map[string]any{ - "logProbs": true, - }, - }, + ProviderOptions: NewProviderOptions(&ProviderOptions{ + LogProbs: ai.BoolOption(true), + }), }) require.NoError(t, err) require.NotNil(t, result.ProviderMetadata) - openaiMeta, ok := result.ProviderMetadata["openai"] + openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata) require.True(t, ok) - logprobs, ok := openaiMeta["logprobs"] + logprobs := openaiMeta.Logprobs require.True(t, ok) require.NotNil(t, logprobs) }) @@ -1057,15 +1053,13 @@ func TestDoGenerate(t *testing.T) { _, err := model.Generate(context.Background(), ai.Call{ Prompt: testPrompt, - ProviderOptions: ai.ProviderOptions{ - "openai": map[string]any{ - "logit_bias": map[string]int64{ - "50256": -100, - }, - "parallel_tool_calls": false, - "user": "test-user-id", + ProviderOptions: NewProviderOptions(&ProviderOptions{ + LogitBias: map[string]int64{ + "50256": -100, }, - }, + ParallelToolCalls: ai.BoolOption(false), + User: ai.StringOption("test-user-id"), + }), }) require.NoError(t, err) @@ -1101,11 +1095,11 @@ func TestDoGenerate(t *testing.T) { _, err := model.Generate(context.Background(), ai.Call{ Prompt: testPrompt, - ProviderOptions: ai.ProviderOptions{ - "openai": map[string]any{ - "reasoning_effort": "low", + ProviderOptions: NewProviderOptions( + &ProviderOptions{ + ReasoningEffort: ReasoningEffortOption(ReasoningEffortLow), }, - }, + ), }) require.NoError(t, err) @@ -1141,11 +1135,9 @@ func TestDoGenerate(t *testing.T) { _, err := model.Generate(context.Background(), ai.Call{ Prompt: testPrompt, - ProviderOptions: ai.ProviderOptions{ - "openai": map[string]any{ - "text_verbosity": "low", - }, - }, + ProviderOptions: NewProviderOptions(&ProviderOptions{ + TextVerbosity: ai.StringOption("low"), + }), }) require.NoError(t, err) @@ -1393,10 +1385,11 @@ func TestDoGenerate(t *testing.T) { require.NoError(t, err) require.NotNil(t, result.ProviderMetadata) - openaiMeta, ok := result.ProviderMetadata["openai"] + openaiMeta, ok := result.ProviderMetadata["openai"].(*ProviderMetadata) + require.True(t, ok) - require.Equal(t, int64(123), openaiMeta["acceptedPredictionTokens"]) - require.Equal(t, int64(456), openaiMeta["rejectedPredictionTokens"]) + require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens) + require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens) }) t.Run("should clear out temperature, top_p, frequency_penalty, presence_penalty for reasoning models", func(t *testing.T) { @@ -1534,11 +1527,9 @@ func TestDoGenerate(t *testing.T) { _, err := model.Generate(context.Background(), ai.Call{ Prompt: testPrompt, - ProviderOptions: ai.ProviderOptions{ - "openai": map[string]any{ - "max_completion_tokens": 255, - }, - }, + ProviderOptions: NewProviderOptions(&ProviderOptions{ + MaxCompletionTokens: ai.IntOption(255), + }), }) require.NoError(t, err) @@ -1574,14 +1565,12 @@ func TestDoGenerate(t *testing.T) { _, err := model.Generate(context.Background(), ai.Call{ Prompt: testPrompt, - ProviderOptions: ai.ProviderOptions{ - "openai": map[string]any{ - "prediction": map[string]any{ - "type": "content", - "content": "Hello, World!", - }, + ProviderOptions: NewProviderOptions(&ProviderOptions{ + Prediction: map[string]any{ + "type": "content", + "content": "Hello, World!", }, - }, + }), }) require.NoError(t, err) @@ -1620,11 +1609,9 @@ func TestDoGenerate(t *testing.T) { _, err := model.Generate(context.Background(), ai.Call{ Prompt: testPrompt, - ProviderOptions: ai.ProviderOptions{ - "openai": map[string]any{ - "store": true, - }, - }, + ProviderOptions: NewProviderOptions(&ProviderOptions{ + Store: ai.BoolOption(true), + }), }) require.NoError(t, err) @@ -1660,13 +1647,11 @@ func TestDoGenerate(t *testing.T) { _, err := model.Generate(context.Background(), ai.Call{ Prompt: testPrompt, - ProviderOptions: ai.ProviderOptions{ - "openai": map[string]any{ - "metadata": map[string]any{ - "custom": "value", - }, + ProviderOptions: NewProviderOptions(&ProviderOptions{ + Metadata: map[string]any{ + "custom": "value", }, - }, + }), }) require.NoError(t, err) @@ -1704,11 +1689,9 @@ func TestDoGenerate(t *testing.T) { _, err := model.Generate(context.Background(), ai.Call{ Prompt: testPrompt, - ProviderOptions: ai.ProviderOptions{ - "openai": map[string]any{ - "prompt_cache_key": "test-cache-key-123", - }, - }, + ProviderOptions: NewProviderOptions(&ProviderOptions{ + PromptCacheKey: ai.StringOption("test-cache-key-123"), + }), }) require.NoError(t, err) @@ -1744,11 +1727,9 @@ func TestDoGenerate(t *testing.T) { _, err := model.Generate(context.Background(), ai.Call{ Prompt: testPrompt, - ProviderOptions: ai.ProviderOptions{ - "openai": map[string]any{ - "safety_identifier": "test-safety-identifier-123", - }, - }, + ProviderOptions: NewProviderOptions(&ProviderOptions{ + SafetyIdentifier: ai.StringOption("test-safety-identifier-123"), + }), }) require.NoError(t, err) @@ -1816,11 +1797,9 @@ func TestDoGenerate(t *testing.T) { _, err := model.Generate(context.Background(), ai.Call{ Prompt: testPrompt, - ProviderOptions: ai.ProviderOptions{ - "openai": map[string]any{ - "service_tier": "flex", - }, - }, + ProviderOptions: NewProviderOptions(&ProviderOptions{ + ServiceTier: ai.StringOption("flex"), + }), }) require.NoError(t, err) @@ -1854,11 +1833,9 @@ func TestDoGenerate(t *testing.T) { result, err := model.Generate(context.Background(), ai.Call{ Prompt: testPrompt, - ProviderOptions: ai.ProviderOptions{ - "openai": map[string]any{ - "service_tier": "flex", - }, - }, + ProviderOptions: NewProviderOptions(&ProviderOptions{ + ServiceTier: ai.StringOption("flex"), + }), }) require.NoError(t, err) @@ -1889,11 +1866,9 @@ func TestDoGenerate(t *testing.T) { _, err := model.Generate(context.Background(), ai.Call{ Prompt: testPrompt, - ProviderOptions: ai.ProviderOptions{ - "openai": map[string]any{ - "service_tier": "priority", - }, - }, + ProviderOptions: NewProviderOptions(&ProviderOptions{ + ServiceTier: ai.StringOption("priority"), + }), }) require.NoError(t, err) @@ -1927,11 +1902,9 @@ func TestDoGenerate(t *testing.T) { result, err := model.Generate(context.Background(), ai.Call{ Prompt: testPrompt, - ProviderOptions: ai.ProviderOptions{ - "openai": map[string]any{ - "service_tier": "priority", - }, - }, + ProviderOptions: NewProviderOptions(&ProviderOptions{ + ServiceTier: ai.StringOption("priority"), + }), }) require.NoError(t, err) @@ -2575,10 +2548,10 @@ func TestDoStream(t *testing.T) { require.NotNil(t, finishPart) require.NotNil(t, finishPart.ProviderMetadata) - openaiMeta, ok := finishPart.ProviderMetadata["openai"] + openaiMeta, ok := finishPart.ProviderMetadata["openai"].(*ProviderMetadata) require.True(t, ok) - require.Equal(t, int64(123), openaiMeta["acceptedPredictionTokens"]) - require.Equal(t, int64(456), openaiMeta["rejectedPredictionTokens"]) + require.Equal(t, int64(123), openaiMeta.AcceptedPredictionTokens) + require.Equal(t, int64(456), openaiMeta.RejectedPredictionTokens) }) t.Run("should send store extension setting", func(t *testing.T) { @@ -2599,11 +2572,9 @@ func TestDoStream(t *testing.T) { _, err := model.Stream(context.Background(), ai.Call{ Prompt: testPrompt, - ProviderOptions: ai.ProviderOptions{ - "openai": map[string]any{ - "store": true, - }, - }, + ProviderOptions: NewProviderOptions(&ProviderOptions{ + Store: ai.BoolOption(true), + }), }) require.NoError(t, err) @@ -2643,13 +2614,11 @@ func TestDoStream(t *testing.T) { _, err := model.Stream(context.Background(), ai.Call{ Prompt: testPrompt, - ProviderOptions: ai.ProviderOptions{ - "openai": map[string]any{ - "metadata": map[string]any{ - "custom": "value", - }, + ProviderOptions: NewProviderOptions(&ProviderOptions{ + Metadata: map[string]any{ + "custom": "value", }, - }, + }), }) require.NoError(t, err) @@ -2691,11 +2660,9 @@ func TestDoStream(t *testing.T) { _, err := model.Stream(context.Background(), ai.Call{ Prompt: testPrompt, - ProviderOptions: ai.ProviderOptions{ - "openai": map[string]any{ - "service_tier": "flex", - }, - }, + ProviderOptions: NewProviderOptions(&ProviderOptions{ + ServiceTier: ai.StringOption("flex"), + }), }) require.NoError(t, err) @@ -2735,11 +2702,9 @@ func TestDoStream(t *testing.T) { _, err := model.Stream(context.Background(), ai.Call{ Prompt: testPrompt, - ProviderOptions: ai.ProviderOptions{ - "openai": map[string]any{ - "service_tier": "priority", - }, - }, + ProviderOptions: NewProviderOptions(&ProviderOptions{ + ServiceTier: ai.StringOption("priority"), + }), }) require.NoError(t, err) diff --git a/openai/provider_options.go b/openai/provider_options.go index b02e31072e975c01f624e849cc8d5f669c315234..e13197445c29e06b352c89c03eb0cb287d73f560 100644 --- a/openai/provider_options.go +++ b/openai/provider_options.go @@ -1,5 +1,10 @@ package openai +import ( + "github.com/charmbracelet/ai/ai" + "github.com/openai/openai-go/v2" +) + type ReasoningEffort string const ( @@ -9,20 +14,46 @@ const ( ReasoningEffortHigh ReasoningEffort = "high" ) +type ProviderFileOptions struct { + ImageDetail string +} + +type ProviderMetadata struct { + Logprobs []openai.ChatCompletionTokenLogprob + AcceptedPredictionTokens int64 + RejectedPredictionTokens int64 +} + type ProviderOptions struct { - LogitBias map[string]int64 `mapstructure:"logit_bias"` - LogProbs *bool `mapstructure:"log_probes"` - TopLogProbs *int64 `mapstructure:"top_log_probs"` - ParallelToolCalls *bool `mapstructure:"parallel_tool_calls"` - User *string `mapstructure:"user"` - ReasoningEffort *ReasoningEffort `mapstructure:"reasoning_effort"` - MaxCompletionTokens *int64 `mapstructure:"max_completion_tokens"` - TextVerbosity *string `mapstructure:"text_verbosity"` - Prediction map[string]any `mapstructure:"prediction"` - Store *bool `mapstructure:"store"` - Metadata map[string]any `mapstructure:"metadata"` - PromptCacheKey *string `mapstructure:"prompt_cache_key"` - SafetyIdentifier *string `mapstructure:"safety_identifier"` - ServiceTier *string `mapstructure:"service_tier"` - StructuredOutputs *bool `mapstructure:"structured_outputs"` + LogitBias map[string]int64 + LogProbs *bool + TopLogProbs *int64 + ParallelToolCalls *bool + User *string + ReasoningEffort *ReasoningEffort + MaxCompletionTokens *int64 + TextVerbosity *string + Prediction map[string]any + Store *bool + Metadata map[string]any + PromptCacheKey *string + SafetyIdentifier *string + ServiceTier *string + StructuredOutputs *bool +} + +func ReasoningEffortOption(e ReasoningEffort) *ReasoningEffort { + return &e +} + +func NewProviderOptions(opts *ProviderOptions) ai.ProviderOptions { + return ai.ProviderOptions{ + "openai": opts, + } +} + +func NewProviderFileOptions(opts *ProviderFileOptions) ai.ProviderOptions { + return ai.ProviderOptions{ + "openai": opts, + } }