diff --git a/cspell.json b/cspell.json
index dad6dc4e9a6dc9e6e42c09d5e920ffae10360f3e..b18a66de464c2ba1bbad75e8bbd795f22c139e75 100644
--- a/cspell.json
+++ b/cspell.json
@@ -1 +1 @@
-{"language":"en","words":["mapstructure","mapstructure","charmbracelet","providertests","joho","godotenv","stretchr"],"version":"0.2","flagWords":[]}
\ No newline at end of file
+{"version":"0.2","words":["mapstructure","mapstructure","charmbracelet","providertests","joho","godotenv","stretchr","Quantizations","Logit","Probs"],"flagWords":[],"language":"en"}
\ No newline at end of file
diff --git a/openai/openai.go b/openai/openai.go
index e7d51767af75ff38e5fe1ad810588ec3ca1fe0f3..335c75cc4787f74dd3bc31cf151cb78615a618e7 100644
--- a/openai/openai.go
+++ b/openai/openai.go
@@ -29,12 +29,19 @@ type provider struct {
 	options options
 }
 
+type PrepareCallWithOptions = func(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error)
+
+type Hooks struct {
+	PrepareCallWithOptions PrepareCallWithOptions
+}
+
 type options struct {
 	baseURL      string
 	apiKey       string
 	organization string
 	project      string
 	name         string
+	hooks        Hooks
 	headers      map[string]string
 	client       option.HTTPClient
 }
@@ -104,6 +111,12 @@ func WithHTTPClient(client option.HTTPClient) Option {
 	}
 }
 
+func WithHooks(hooks Hooks) Option {
+	return func(o *options) {
+		o.hooks = hooks
+	}
+}
+
 // LanguageModel implements ai.Provider.
 func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
 	openaiClientOptions := []option.RequestOption{}
@@ -147,24 +160,19 @@ func (o languageModel) Provider() string {
 	return o.provider
 }
 
-func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
-	params := &openai.ChatCompletionNewParams{}
-	messages, warnings := toPrompt(call.Prompt)
+func prepareCallWithOptions(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
+	if call.ProviderOptions == nil {
+		return nil, nil
+	}
+	var warnings []ai.CallWarning
 	providerOptions := &ProviderOptions{}
 	if v, ok := call.ProviderOptions[Name]; ok {
 		providerOptions, ok = v.(*ProviderOptions)
 		if !ok {
-			return nil, nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
+			return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
 		}
 	}
-	if call.TopK != nil {
-		warnings = append(warnings, ai.CallWarning{
-			Type:    ai.CallWarningTypeUnsupportedSetting,
-			Setting: "top_k",
-		})
-	}
-	params.Messages = messages
-	params.Model = o.modelID
+
 	if providerOptions.LogitBias != nil {
 		params.LogitBias = providerOptions.LogitBias
 	}
@@ -183,23 +191,6 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar
 	if providerOptions.ParallelToolCalls != nil {
 		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
 	}
-
-	if call.MaxOutputTokens != nil {
-		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
-	}
-	if call.Temperature != nil {
-		params.Temperature = param.NewOpt(*call.Temperature)
-	}
-	if call.TopP != nil {
-		params.TopP = param.NewOpt(*call.TopP)
-	}
-	if call.FrequencyPenalty != nil {
-		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
-	}
-	if call.PresencePenalty != nil {
-		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
-	}
-
 	if providerOptions.MaxCompletionTokens != nil {
 		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
 	}
@@ -253,45 +244,11 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar
 		case ReasoningEffortHigh:
 			params.ReasoningEffort = shared.ReasoningEffortHigh
 		default:
-			return nil, nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
+			return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
 		}
 	}
 
-	if isReasoningModel(o.modelID) {
-		// remove unsupported settings for reasoning models
-		// see https://platform.openai.com/docs/guides/reasoning#limitations
-		if call.Temperature != nil {
-			params.Temperature = param.Opt[float64]{}
-			warnings = append(warnings, ai.CallWarning{
-				Type:    ai.CallWarningTypeUnsupportedSetting,
-				Setting: "temperature",
-				Details: "temperature is not supported for reasoning models",
-			})
-		}
-		if call.TopP != nil {
-			params.TopP = param.Opt[float64]{}
-			warnings = append(warnings, ai.CallWarning{
-				Type:    ai.CallWarningTypeUnsupportedSetting,
-				Setting: "TopP",
-				Details: "TopP is not supported for reasoning models",
-			})
-		}
-		if call.FrequencyPenalty != nil {
-			params.FrequencyPenalty = param.Opt[float64]{}
-			warnings = append(warnings, ai.CallWarning{
-				Type:    ai.CallWarningTypeUnsupportedSetting,
-				Setting: "FrequencyPenalty",
-				Details: "FrequencyPenalty is not supported for reasoning models",
-			})
-		}
-		if call.PresencePenalty != nil {
-			params.PresencePenalty = param.Opt[float64]{}
-			warnings = append(warnings, ai.CallWarning{
-				Type:    ai.CallWarningTypeUnsupportedSetting,
-				Setting: "PresencePenalty",
-				Details: "PresencePenalty is not supported for reasoning models",
-			})
-		}
+	if isReasoningModel(model.Model()) {
 		if providerOptions.LogitBias != nil {
 			params.LogitBias = nil
 			warnings = append(warnings, ai.CallWarning{
@@ -324,31 +281,20 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar
 		}
 		params.MaxTokens = param.Opt[int64]{}
 	}
-	}
-
-	// Handle search preview models
-	if isSearchPreviewModel(o.modelID) {
-		if call.Temperature != nil {
-			params.Temperature = param.Opt[float64]{}
-			warnings = append(warnings, ai.CallWarning{
-				Type:    ai.CallWarningTypeUnsupportedSetting,
-				Setting: "temperature",
-				Details: "temperature is not supported for the search preview models and has been removed.",
-			})
-		}
 	}
 
 	// Handle service tier validation
 	if providerOptions.ServiceTier != nil {
 		serviceTier := *providerOptions.ServiceTier
-		if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
+		if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) {
 			params.ServiceTier = ""
 			warnings = append(warnings, ai.CallWarning{
 				Type:    ai.CallWarningTypeUnsupportedSetting,
 				Setting: "ServiceTier",
 				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
 			})
-		} else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
+		} else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) {
 			params.ServiceTier = ""
 			warnings = append(warnings, ai.CallWarning{
 				Type:    ai.CallWarningTypeUnsupportedSetting,
@@ -357,6 +303,99 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar
 			})
 		}
 	}
+	return warnings, nil
+}
+
+func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
+	params := &openai.ChatCompletionNewParams{}
+	messages, warnings := toPrompt(call.Prompt)
+	if call.TopK != nil {
+		warnings = append(warnings, ai.CallWarning{
+			Type:    ai.CallWarningTypeUnsupportedSetting,
+			Setting: "top_k",
+		})
+	}
+	params.Messages = messages
+	params.Model = o.modelID
+
+	if call.MaxOutputTokens != nil {
+		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
+	}
+	if call.Temperature != nil {
+		params.Temperature = param.NewOpt(*call.Temperature)
+	}
+	if call.TopP != nil {
+		params.TopP = param.NewOpt(*call.TopP)
+	}
+	if call.FrequencyPenalty != nil {
+		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
+	}
+	if call.PresencePenalty != nil {
+		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
+	}
+
+	if isReasoningModel(o.modelID) {
+		// remove unsupported settings for reasoning models
+		// see https://platform.openai.com/docs/guides/reasoning#limitations
+		if call.Temperature != nil {
+			params.Temperature = param.Opt[float64]{}
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeUnsupportedSetting,
+				Setting: "temperature",
+				Details: "temperature is not supported for reasoning models",
+			})
+		}
+		if call.TopP != nil {
+			params.TopP = param.Opt[float64]{}
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeUnsupportedSetting,
+				Setting: "TopP",
+				Details: "TopP is not supported for reasoning models",
+			})
+		}
+		if call.FrequencyPenalty != nil {
+			params.FrequencyPenalty = param.Opt[float64]{}
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeUnsupportedSetting,
+				Setting: "FrequencyPenalty",
+				Details: "FrequencyPenalty is not supported for reasoning models",
+			})
+		}
+		if call.PresencePenalty != nil {
+			params.PresencePenalty = param.Opt[float64]{}
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeUnsupportedSetting,
+				Setting: "PresencePenalty",
+				Details: "PresencePenalty is not supported for reasoning models",
+			})
+		}
+	}
+
+	// Handle search preview models
+	if isSearchPreviewModel(o.modelID) {
+		if call.Temperature != nil {
+			params.Temperature = param.Opt[float64]{}
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeUnsupportedSetting,
+				Setting: "temperature",
+				Details: "temperature is not supported for the search preview models and has been removed.",
+			})
+		}
+	}
+
+	prepareOptions := prepareCallWithOptions
+	if o.options.hooks.PrepareCallWithOptions != nil {
+		prepareOptions = o.options.hooks.PrepareCallWithOptions
+	}
+
+	optionsWarnings, err := prepareOptions(o, params, call)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	if len(optionsWarnings) > 0 {
+		warnings = append(warnings, optionsWarnings...)
+	}
 
 	if len(call.Tools) > 0 {
 		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
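Usage note (annotation, not part of the patch): the new hook replaces the built-in prepareCallWithOptions rather than wrapping it, so a custom hook takes over all provider-option handling for a call. Below is a minimal sketch of installing one; the function name and the warning it emits are illustrative, while openai.Hooks, the hook signature, and the ai.CallWarning fields come from this change.

package example

import (
	"github.com/charmbracelet/fantasy/ai"
	"github.com/charmbracelet/fantasy/openai"
	openaiSDK "github.com/openai/openai-go/v2"
)

// newHookedProvider builds an OpenAI provider whose provider-option
// handling is overridden by a custom hook.
func newHookedProvider(apiKey string) ai.Provider {
	return openai.New(
		openai.WithAPIKey(apiKey),
		openai.WithHooks(openai.Hooks{
			PrepareCallWithOptions: func(model ai.LanguageModel, params *openaiSDK.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
				if call.ProviderOptions == nil {
					return nil, nil
				}
				// Mutate params here as needed; returned warnings are merged
				// into the warnings produced by prepareParams.
				return []ai.CallWarning{{
					Type:    ai.CallWarningTypeUnsupportedSetting,
					Setting: "provider_options",
					Details: "custom hook ignores provider options (illustrative)",
				}}, nil
			},
		}),
	)
}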
diff --git a/openrouter/openrouter.go b/openrouter/openrouter.go
new file mode 100644
index 0000000000000000000000000000000000000000..4c3d39c89596c8d0d4be383cb7b9264157c2536b
--- /dev/null
+++ b/openrouter/openrouter.go
@@ -0,0 +1,74 @@
+package openrouter
+
+import (
+	"github.com/charmbracelet/fantasy/ai"
+	"github.com/charmbracelet/fantasy/openai"
+	openaiSDK "github.com/openai/openai-go/v2"
+	"github.com/openai/openai-go/v2/option"
+)
+
+type options struct {
+	openaiOptions []openai.Option
+}
+
+type Option = func(*options)
+
+func prepareCallWithOptions(model ai.LanguageModel, params *openaiSDK.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
+	providerOptions := &ProviderOptions{}
+	if v, ok := call.ProviderOptions[Name]; ok {
+		providerOptions, ok = v.(*ProviderOptions)
+		if !ok {
+			return nil, ai.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil)
+		}
+	}
+	_ = providerOptions
+
+	// HANDLE OPENROUTER call modification here
+
+	return nil, nil
+}
+
+func New(opts ...Option) ai.Provider {
+	providerOptions := options{
+		openaiOptions: []openai.Option{
+			openai.WithHooks(openai.Hooks{
+				PrepareCallWithOptions: prepareCallWithOptions,
+			}),
+		},
+	}
+	for _, o := range opts {
+		o(&providerOptions)
+	}
+	return openai.New(providerOptions.openaiOptions...)
+}
+
+func WithBaseURL(baseURL string) Option {
+	return func(o *options) {
+		o.openaiOptions = append(o.openaiOptions, openai.WithBaseURL(baseURL))
+	}
+}
+
+func WithAPIKey(apiKey string) Option {
+	return func(o *options) {
+		o.openaiOptions = append(o.openaiOptions, openai.WithAPIKey(apiKey))
+	}
+}
+
+func WithName(name string) Option {
+	return func(o *options) {
+		o.openaiOptions = append(o.openaiOptions, openai.WithName(name))
+	}
+}
+
+func WithHeaders(headers map[string]string) Option {
+	return func(o *options) {
+		o.openaiOptions = append(o.openaiOptions, openai.WithHeaders(headers))
+	}
+}
+
+func WithHTTPClient(client option.HTTPClient) Option {
+	return func(o *options) {
+		o.openaiOptions = append(o.openaiOptions, openai.WithHTTPClient(client))
+	}
+}
+
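Usage note (annotation, not part of the patch): openrouter.New is the OpenAI provider with the hook pre-wired, so construction mirrors the OpenAI options. The base URL below is an assumption (OpenRouter's OpenAI-compatible endpoint); this patch does not set a default.

package example

import (
	"github.com/charmbracelet/fantasy/ai"
	"github.com/charmbracelet/fantasy/openrouter"
)

// newOpenRouter builds the provider; every openrouter option simply
// forwards to the corresponding openai option.
func newOpenRouter(apiKey string) ai.Provider {
	return openrouter.New(
		openrouter.WithBaseURL("https://openrouter.ai/api/v1"), // assumed endpoint, not set by this patch
		openrouter.WithAPIKey(apiKey),
		openrouter.WithName("openrouter"),
	)
}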
["anthropic", "openai"]) + Order []string `json:"order"` + // Whether to allow backup providers when primary is unavailable (default: true) + AllowFallbacks *bool `json:"allow_fallbacks"` + // Only use providers that support all parameters in your request (default: false) + RequireParameters *bool `json:"require_parameters"` + // Control whether to use providers that may store data: "allow" | "deny" + DataCollection *string `json:"data_collection"` + // List of provider slugs to allow for this request + Only []string `json:"only"` + // List of provider slugs to skip for this request + Ignore []string `json:"ignore"` + // List of quantization levels to filter by (e.g. ["int4", "int8"]) + Quantizations []string `json:"quantizations"` + // Sort providers by "price" | "throughput" | "latency" + Sort *string `json:"sort"` +} + +type ProviderOptions struct { + Reasoning *ReasoningOptions `json:"reasoning"` + ExtraBody map[string]any `json:"extra_body"` + IncludeUsage *bool `json:"include_usage"` + // Modify the likelihood of specified tokens appearing in the completion. + // Accepts a map that maps tokens (specified by their token ID) to an associated bias value from -100 to 100. + // The bias is added to the logits generated by the model prior to sampling. + LogitBias map[string]int64 `json:"logit_bias"` + // Return the log probabilities of the tokens. Including logprobs will increase the response size. + // Setting to true will return the log probabilities of the tokens that were generated. + LogProbs *bool `json:"log_probs"` + // Whether to enable parallel function calling during tool use. Default to true. + ParallelToolCalls *bool `json:"parallel_tool_calls"` + // A unique identifier representing your end-user, which can help OpenRouter to monitor and detect abuse. + User *string `json:"user"` + // Provider routing preferences to control request routing behavior + Provider *Provider `json:"provider"` + // TODO: add the web search plugin config +} + +func (*ProviderOptions) Options() {} + +func ReasoningEffortOption(e ReasoningEffort) *ReasoningEffort { + return &e +} + +func NewProviderOptions(opts *ProviderOptions) ai.ProviderOptions { + return ai.ProviderOptions{ + Name: opts, + } +}