@@ -1 +1 @@
-{"language":"en","words":["mapstructure","mapstructure","charmbracelet","providertests","joho","godotenv","stretchr"],"version":"0.2","flagWords":[]}
+{"version":"0.2","words":["mapstructure","mapstructure","charmbracelet","providertests","joho","godotenv","stretchr","Quantizations","Logit","Probs"],"flagWords":[],"language":"en"}
@@ -29,12 +29,19 @@ type provider struct {
options options
}
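+
+// PrepareCallWithOptions is the signature of the hook that applies
+// provider-specific options from an ai.Call to the OpenAI chat completion
+// parameters, returning warnings for any unsupported settings.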
+type PrepareCallWithOptions = func(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error)
+
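+// Hooks bundles optional overrides for the provider's default request
+// preparation.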
+type Hooks struct {
+ PrepareCallWithOptions PrepareCallWithOptions
+}
+
type options struct {
baseURL string
apiKey string
organization string
project string
name string
+ hooks Hooks
headers map[string]string
client option.HTTPClient
}
@@ -104,6 +111,12 @@ func WithHTTPClient(client option.HTTPClient) Option {
}
}
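+
+// WithHooks installs custom Hooks on the provider. For example, with a
+// hypothetical prepare function myPrepare:
+//
+//	openai.New(openai.WithHooks(openai.Hooks{
+//		PrepareCallWithOptions: myPrepare,
+//	}))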
+func WithHooks(hooks Hooks) Option {
+ return func(o *options) {
+ o.hooks = hooks
+ }
+}
+
// LanguageModel implements ai.Provider.
func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
openaiClientOptions := []option.RequestOption{}
@@ -147,24 +160,19 @@ func (o languageModel) Provider() string {
return o.provider
}
-func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
- params := &openai.ChatCompletionNewParams{}
- messages, warnings := toPrompt(call.Prompt)
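+// prepareCallWithOptions is the default PrepareCallWithOptions hook. It maps
+// OpenAI provider options (logit bias, reasoning effort, service tier, and so
+// on) onto the request parameters and collects warnings for settings the
+// model does not support.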
+func prepareCallWithOptions(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
+ if call.ProviderOptions == nil {
+ return nil, nil
+ }
+ var warnings []ai.CallWarning
providerOptions := &ProviderOptions{}
if v, ok := call.ProviderOptions[Name]; ok {
providerOptions, ok = v.(*ProviderOptions)
if !ok {
- return nil, nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
+ return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
}
}
- if call.TopK != nil {
- warnings = append(warnings, ai.CallWarning{
- Type: ai.CallWarningTypeUnsupportedSetting,
- Setting: "top_k",
- })
- }
- params.Messages = messages
- params.Model = o.modelID
+
if providerOptions.LogitBias != nil {
params.LogitBias = providerOptions.LogitBias
}
@@ -183,23 +191,6 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar
if providerOptions.ParallelToolCalls != nil {
params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
}
-
- if call.MaxOutputTokens != nil {
- params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
- }
- if call.Temperature != nil {
- params.Temperature = param.NewOpt(*call.Temperature)
- }
- if call.TopP != nil {
- params.TopP = param.NewOpt(*call.TopP)
- }
- if call.FrequencyPenalty != nil {
- params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
- }
- if call.PresencePenalty != nil {
- params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
- }
-
if providerOptions.MaxCompletionTokens != nil {
params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
}
@@ -253,45 +244,11 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar
case ReasoningEffortHigh:
params.ReasoningEffort = shared.ReasoningEffortHigh
default:
- return nil, nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
+ return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
}
}
- if isReasoningModel(o.modelID) {
- // remove unsupported settings for reasoning models
- // see https://platform.openai.com/docs/guides/reasoning#limitations
- if call.Temperature != nil {
- params.Temperature = param.Opt[float64]{}
- warnings = append(warnings, ai.CallWarning{
- Type: ai.CallWarningTypeUnsupportedSetting,
- Setting: "temperature",
- Details: "temperature is not supported for reasoning models",
- })
- }
- if call.TopP != nil {
- params.TopP = param.Opt[float64]{}
- warnings = append(warnings, ai.CallWarning{
- Type: ai.CallWarningTypeUnsupportedSetting,
- Setting: "TopP",
- Details: "TopP is not supported for reasoning models",
- })
- }
- if call.FrequencyPenalty != nil {
- params.FrequencyPenalty = param.Opt[float64]{}
- warnings = append(warnings, ai.CallWarning{
- Type: ai.CallWarningTypeUnsupportedSetting,
- Setting: "FrequencyPenalty",
- Details: "FrequencyPenalty is not supported for reasoning models",
- })
- }
- if call.PresencePenalty != nil {
- params.PresencePenalty = param.Opt[float64]{}
- warnings = append(warnings, ai.CallWarning{
- Type: ai.CallWarningTypeUnsupportedSetting,
- Setting: "PresencePenalty",
- Details: "PresencePenalty is not supported for reasoning models",
- })
- }
+ if isReasoningModel(model.Model()) {
if providerOptions.LogitBias != nil {
params.LogitBias = nil
warnings = append(warnings, ai.CallWarning{
@@ -324,31 +281,20 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar
}
params.MaxTokens = param.Opt[int64]{}
}
- }
- // Handle search preview models
- if isSearchPreviewModel(o.modelID) {
- if call.Temperature != nil {
- params.Temperature = param.Opt[float64]{}
- warnings = append(warnings, ai.CallWarning{
- Type: ai.CallWarningTypeUnsupportedSetting,
- Setting: "temperature",
- Details: "temperature is not supported for the search preview models and has been removed.",
- })
- }
}
// Handle service tier validation
if providerOptions.ServiceTier != nil {
serviceTier := *providerOptions.ServiceTier
- if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
+ if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) {
params.ServiceTier = ""
warnings = append(warnings, ai.CallWarning{
Type: ai.CallWarningTypeUnsupportedSetting,
Setting: "ServiceTier",
Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
})
- } else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
+ } else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) {
params.ServiceTier = ""
warnings = append(warnings, ai.CallWarning{
Type: ai.CallWarningTypeUnsupportedSetting,
@@ -357,6 +303,99 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar
})
}
}
+ return warnings, nil
+}
+
+func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
+ params := &openai.ChatCompletionNewParams{}
+ messages, warnings := toPrompt(call.Prompt)
+ if call.TopK != nil {
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeUnsupportedSetting,
+ Setting: "top_k",
+ })
+ }
+ params.Messages = messages
+ params.Model = o.modelID
+
+ if call.MaxOutputTokens != nil {
+ params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
+ }
+ if call.Temperature != nil {
+ params.Temperature = param.NewOpt(*call.Temperature)
+ }
+ if call.TopP != nil {
+ params.TopP = param.NewOpt(*call.TopP)
+ }
+ if call.FrequencyPenalty != nil {
+ params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
+ }
+ if call.PresencePenalty != nil {
+ params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
+ }
+
+ if isReasoningModel(o.modelID) {
+ // remove unsupported settings for reasoning models
+ // see https://platform.openai.com/docs/guides/reasoning#limitations
+ if call.Temperature != nil {
+ params.Temperature = param.Opt[float64]{}
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeUnsupportedSetting,
+ Setting: "temperature",
+ Details: "temperature is not supported for reasoning models",
+ })
+ }
+ if call.TopP != nil {
+ params.TopP = param.Opt[float64]{}
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeUnsupportedSetting,
+ Setting: "TopP",
+ Details: "TopP is not supported for reasoning models",
+ })
+ }
+ if call.FrequencyPenalty != nil {
+ params.FrequencyPenalty = param.Opt[float64]{}
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeUnsupportedSetting,
+ Setting: "FrequencyPenalty",
+ Details: "FrequencyPenalty is not supported for reasoning models",
+ })
+ }
+ if call.PresencePenalty != nil {
+ params.PresencePenalty = param.Opt[float64]{}
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeUnsupportedSetting,
+ Setting: "PresencePenalty",
+ Details: "PresencePenalty is not supported for reasoning models",
+ })
+ }
+ }
+
+ // Handle search preview models
+ if isSearchPreviewModel(o.modelID) {
+ if call.Temperature != nil {
+ params.Temperature = param.Opt[float64]{}
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeUnsupportedSetting,
+ Setting: "temperature",
+ Details: "temperature is not supported for the search preview models and has been removed.",
+ })
+ }
+ }
+
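+ // Use the custom hook when one is installed; otherwise fall back to the
+ // default provider-options handling.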
+ prepareOptions := prepareCallWithOptions
+ if o.options.hooks.PrepareCallWithOptions != nil {
+ prepareOptions = o.options.hooks.PrepareCallWithOptions
+ }
+
+ optionsWarnings, err := prepareOptions(o, params, call)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ if len(optionsWarnings) > 0 {
+ warnings = append(warnings, optionsWarnings...)
+ }
if len(call.Tools) > 0 {
tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
@@ -0,0 +1,74 @@
+package openrouter
+
+import (
+ "github.com/charmbracelet/fantasy/ai"
+ "github.com/charmbracelet/fantasy/openai"
+ openaiSDK "github.com/openai/openai-go/v2"
+ "github.com/openai/openai-go/v2/option"
+)
+
+type options struct {
+ openaiOptions []openai.Option
+}
+
+type Option = func(*options)
+
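+// prepareCallWithOptions is the OpenRouter-specific hook passed to the
+// underlying OpenAI provider. It validates the provider options and is where
+// OpenRouter-only request fields will be applied.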
+func prepareCallWithOptions(model ai.LanguageModel, params *openaiSDK.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
+ providerOptions := &ProviderOptions{}
+ if v, ok := call.ProviderOptions[Name]; ok {
+ providerOptions, ok = v.(*ProviderOptions)
+ if !ok {
+ return nil, ai.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil)
+ }
+ }
+ // TODO: apply OpenRouter-specific call modifications here, e.g. map
+ // providerOptions onto OpenRouter-only request fields.
+ _ = providerOptions
+
+ return nil, nil
+}
+
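+// New constructs an OpenRouter provider on top of the OpenAI-compatible
+// client, installing the OpenRouter prepare hook by default.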
+func New(opts ...Option) ai.Provider {
+ providerOptions := options{
+ openaiOptions: []openai.Option{
+ openai.WithHooks(openai.Hooks{
+ PrepareCallWithOptions: prepareCallWithOptions,
+ }),
+ },
+ }
+ for _, o := range opts {
+ o(&providerOptions)
+ }
+ return openai.New(providerOptions.openaiOptions...)
+}
+
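+// WithBaseURL sets the API base URL on the underlying OpenAI client.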
+func WithBaseURL(baseURL string) Option {
+ return func(o *options) {
+ o.openaiOptions = append(o.openaiOptions, openai.WithBaseURL(baseURL))
+ }
+}
+
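+// WithAPIKey sets the API key used to authenticate requests.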
+func WithAPIKey(apiKey string) Option {
+ return func(o *options) {
+ o.openaiOptions = append(o.openaiOptions, openai.WithAPIKey(apiKey))
+ }
+}
+
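+// WithName sets the provider name reported by the underlying client.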
+func WithName(name string) Option {
+ return func(o *options) {
+ o.openaiOptions = append(o.openaiOptions, openai.WithName(name))
+ }
+}
+
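+// WithHeaders sets additional HTTP headers to send with every request.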
+func WithHeaders(headers map[string]string) Option {
+ return func(o *options) {
+ o.openaiOptions = append(o.openaiOptions, openai.WithHeaders(headers))
+ }
+}
+
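+// WithHTTPClient sets a custom HTTP client for the underlying OpenAI client.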
+func WithHTTPClient(client option.HTTPClient) Option {
+ return func(o *options) {
+ o.openaiOptions = append(o.openaiOptions, openai.WithHTTPClient(client))
+ }
+}
+