wip: initial setup with hooks

Created by kujtimiihoxha

Change summary

cspell.json                    |   2 +-
openai/openai.go               | 195 +++++++++++++++++++++--------------
openrouter/openrouter.go       |  74 +++++++++++++
openrouter/provider_options.go |  81 ++++++++++++++
4 files changed, 273 insertions(+), 79 deletions(-)

Detailed changes

cspell.json

@@ -1 +1 @@
-{"language":"en","words":["mapstructure","mapstructure","charmbracelet","providertests","joho","godotenv","stretchr"],"version":"0.2","flagWords":[]}
+{"version":"0.2","words":["mapstructure","mapstructure","charmbracelet","providertests","joho","godotenv","stretchr","Quantizations","Logit","Probs"],"flagWords":[],"language":"en"}

openai/openai.go

@@ -29,12 +29,19 @@ type provider struct {
 	options options
 }
 
+type PrepareCallWithOptions = func(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error)
+
+type Hooks struct {
+	PrepareCallWithOptions PrepareCallWithOptions
+}
+
 type options struct {
 	baseURL      string
 	apiKey       string
 	organization string
 	project      string
 	name         string
+	hooks        Hooks
 	headers      map[string]string
 	client       option.HTTPClient
 }
@@ -104,6 +111,12 @@ func WithHTTPClient(client option.HTTPClient) Option {
 	}
 }
 
+func WithHooks(hooks Hooks) Option {
+	return func(o *options) {
+		o.hooks = hooks
+	}
+}
+
 // LanguageModel implements ai.Provider.
 func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
 	openaiClientOptions := []option.RequestOption{}
@@ -147,24 +160,19 @@ func (o languageModel) Provider() string {
 	return o.provider
 }
 
-func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
-	params := &openai.ChatCompletionNewParams{}
-	messages, warnings := toPrompt(call.Prompt)
+func prepareCallWithOptions(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
+	if call.ProviderOptions == nil {
+		return nil, nil
+	}
+	var warnings []ai.CallWarning
 	providerOptions := &ProviderOptions{}
 	if v, ok := call.ProviderOptions[Name]; ok {
 		providerOptions, ok = v.(*ProviderOptions)
 		if !ok {
-			return nil, nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
+			return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
 		}
 	}
-	if call.TopK != nil {
-		warnings = append(warnings, ai.CallWarning{
-			Type:    ai.CallWarningTypeUnsupportedSetting,
-			Setting: "top_k",
-		})
-	}
-	params.Messages = messages
-	params.Model = o.modelID
+
 	if providerOptions.LogitBias != nil {
 		params.LogitBias = providerOptions.LogitBias
 	}
@@ -183,23 +191,6 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar
 	if providerOptions.ParallelToolCalls != nil {
 		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
 	}
-
-	if call.MaxOutputTokens != nil {
-		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
-	}
-	if call.Temperature != nil {
-		params.Temperature = param.NewOpt(*call.Temperature)
-	}
-	if call.TopP != nil {
-		params.TopP = param.NewOpt(*call.TopP)
-	}
-	if call.FrequencyPenalty != nil {
-		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
-	}
-	if call.PresencePenalty != nil {
-		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
-	}
-
 	if providerOptions.MaxCompletionTokens != nil {
 		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
 	}
@@ -253,45 +244,11 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar
 		case ReasoningEffortHigh:
 			params.ReasoningEffort = shared.ReasoningEffortHigh
 		default:
-			return nil, nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
+			return nil, fmt.Errorf("reasoning effort `%s` not supported", *providerOptions.ReasoningEffort)
 		}
 	}
 
-	if isReasoningModel(o.modelID) {
-		// remove unsupported settings for reasoning models
-		// see https://platform.openai.com/docs/guides/reasoning#limitations
-		if call.Temperature != nil {
-			params.Temperature = param.Opt[float64]{}
-			warnings = append(warnings, ai.CallWarning{
-				Type:    ai.CallWarningTypeUnsupportedSetting,
-				Setting: "temperature",
-				Details: "temperature is not supported for reasoning models",
-			})
-		}
-		if call.TopP != nil {
-			params.TopP = param.Opt[float64]{}
-			warnings = append(warnings, ai.CallWarning{
-				Type:    ai.CallWarningTypeUnsupportedSetting,
-				Setting: "TopP",
-				Details: "TopP is not supported for reasoning models",
-			})
-		}
-		if call.FrequencyPenalty != nil {
-			params.FrequencyPenalty = param.Opt[float64]{}
-			warnings = append(warnings, ai.CallWarning{
-				Type:    ai.CallWarningTypeUnsupportedSetting,
-				Setting: "FrequencyPenalty",
-				Details: "FrequencyPenalty is not supported for reasoning models",
-			})
-		}
-		if call.PresencePenalty != nil {
-			params.PresencePenalty = param.Opt[float64]{}
-			warnings = append(warnings, ai.CallWarning{
-				Type:    ai.CallWarningTypeUnsupportedSetting,
-				Setting: "PresencePenalty",
-				Details: "PresencePenalty is not supported for reasoning models",
-			})
-		}
+	if isReasoningModel(model.Model()) {
 		if providerOptions.LogitBias != nil {
 			params.LogitBias = nil
 			warnings = append(warnings, ai.CallWarning{
@@ -324,31 +281,20 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar
 			}
 			params.MaxTokens = param.Opt[int64]{}
 		}
-	}
 
-	// Handle search preview models
-	if isSearchPreviewModel(o.modelID) {
-		if call.Temperature != nil {
-			params.Temperature = param.Opt[float64]{}
-			warnings = append(warnings, ai.CallWarning{
-				Type:    ai.CallWarningTypeUnsupportedSetting,
-				Setting: "temperature",
-				Details: "temperature is not supported for the search preview models and has been removed.",
-			})
-		}
 	}
 
 	// Handle service tier validation
 	if providerOptions.ServiceTier != nil {
 		serviceTier := *providerOptions.ServiceTier
-		if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
+		if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) {
 			params.ServiceTier = ""
 			warnings = append(warnings, ai.CallWarning{
 				Type:    ai.CallWarningTypeUnsupportedSetting,
 				Setting: "ServiceTier",
 				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
 			})
-		} else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
+		} else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) {
 			params.ServiceTier = ""
 			warnings = append(warnings, ai.CallWarning{
 				Type:    ai.CallWarningTypeUnsupportedSetting,
@@ -357,6 +303,99 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar
 			})
 		}
 	}
+	return warnings, nil
+}
+
+func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
+	params := &openai.ChatCompletionNewParams{}
+	messages, warnings := toPrompt(call.Prompt)
+	if call.TopK != nil {
+		warnings = append(warnings, ai.CallWarning{
+			Type:    ai.CallWarningTypeUnsupportedSetting,
+			Setting: "top_k",
+		})
+	}
+	params.Messages = messages
+	params.Model = o.modelID
+
+	if call.MaxOutputTokens != nil {
+		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
+	}
+	if call.Temperature != nil {
+		params.Temperature = param.NewOpt(*call.Temperature)
+	}
+	if call.TopP != nil {
+		params.TopP = param.NewOpt(*call.TopP)
+	}
+	if call.FrequencyPenalty != nil {
+		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
+	}
+	if call.PresencePenalty != nil {
+		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
+	}
+
+	if isReasoningModel(o.modelID) {
+		// remove unsupported settings for reasoning models
+		// see https://platform.openai.com/docs/guides/reasoning#limitations
+		if call.Temperature != nil {
+			params.Temperature = param.Opt[float64]{}
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeUnsupportedSetting,
+				Setting: "temperature",
+				Details: "temperature is not supported for reasoning models",
+			})
+		}
+		if call.TopP != nil {
+			params.TopP = param.Opt[float64]{}
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeUnsupportedSetting,
+				Setting: "TopP",
+				Details: "TopP is not supported for reasoning models",
+			})
+		}
+		if call.FrequencyPenalty != nil {
+			params.FrequencyPenalty = param.Opt[float64]{}
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeUnsupportedSetting,
+				Setting: "FrequencyPenalty",
+				Details: "FrequencyPenalty is not supported for reasoning models",
+			})
+		}
+		if call.PresencePenalty != nil {
+			params.PresencePenalty = param.Opt[float64]{}
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeUnsupportedSetting,
+				Setting: "PresencePenalty",
+				Details: "PresencePenalty is not supported for reasoning models",
+			})
+		}
+	}
+
+	// Handle search preview models
+	if isSearchPreviewModel(o.modelID) {
+		if call.Temperature != nil {
+			params.Temperature = param.Opt[float64]{}
+			warnings = append(warnings, ai.CallWarning{
+				Type:    ai.CallWarningTypeUnsupportedSetting,
+				Setting: "temperature",
+				Details: "temperature is not supported for the search preview models and has been removed.",
+			})
+		}
+	}
+
+	prepareOptions := prepareCallWithOptions
+	if o.options.hooks.PrepareCallWithOptions != nil {
+		prepareOptions = o.options.hooks.PrepareCallWithOptions
+	}
+
+	optionsWarnings, err := prepareOptions(o, params, call)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	if len(optionsWarnings) > 0 {
+		warnings = append(warnings, optionsWarnings...)
+	}
 
 	if len(call.Tools) > 0 {
 		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
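
Taken together, the openai.go hunks split parameter preparation in two: prepareParams keeps the model-agnostic call settings, while the provider-option handling moves into a standalone prepareCallWithOptions that the new Hooks struct can replace wholesale. A minimal sketch of how a consumer might install a custom hook, assuming only the names visible in this diff (the API key and hook body are placeholders):

package main

import (
	"github.com/charmbracelet/fantasy/ai"
	"github.com/charmbracelet/fantasy/openai"
	openaiSDK "github.com/openai/openai-go/v2"
)

func main() {
	provider := openai.New(
		openai.WithAPIKey("sk-placeholder"),
		openai.WithHooks(openai.Hooks{
			// Replaces the default provider-option handling for every call.
			PrepareCallWithOptions: func(model ai.LanguageModel, params *openaiSDK.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
				// Mutate params based on call.ProviderOptions here; returning
				// (nil, nil) leaves the prepared params untouched.
				return nil, nil
			},
		}),
	)
	_ = provider
}

Because the hook receives the fully prepared params, it runs after the shared call settings (temperature, penalties, reasoning-model stripping) have already been applied.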

openrouter/openrouter.go

@@ -0,0 +1,74 @@
+package openrouter
+
+import (
+	"github.com/charmbracelet/fantasy/ai"
+	"github.com/charmbracelet/fantasy/openai"
+	openaiSDK "github.com/openai/openai-go/v2"
+	"github.com/openai/openai-go/v2/option"
+)
+
+type options struct {
+	openaiOptions []openai.Option
+}
+
+type Option = func(*options)
+
+func prepareCallWithOptions(model ai.LanguageModel, params *openaiSDK.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
+	providerOptions := &ProviderOptions{}
+	if v, ok := call.ProviderOptions[Name]; ok {
+		providerOptions, ok = v.(*ProviderOptions)
+		if !ok {
+			return nil, ai.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil)
+		}
+	}
+	_ = providerOptions
+
+	// HANDLE OPENROUTER call modification here
+
+	return nil, nil
+}
+
+func New(opts ...Option) ai.Provider {
+	providerOptions := options{
+		openaiOptions: []openai.Option{
+			openai.WithHooks(openai.Hooks{
+				PrepareCallWithOptions: prepareCallWithOptions,
+			}),
+		},
+	}
+	for _, o := range opts {
+		o(&providerOptions)
+	}
+	return openai.New(providerOptions.openaiOptions...)
+}
+
+func WithBaseURL(baseURL string) Option {
+	return func(o *options) {
+		o.openaiOptions = append(o.openaiOptions, openai.WithBaseURL(baseURL))
+	}
+}
+
+func WithAPIKey(apiKey string) Option {
+	return func(o *options) {
+		o.openaiOptions = append(o.openaiOptions, openai.WithAPIKey(apiKey))
+	}
+}
+
+func WithName(name string) Option {
+	return func(o *options) {
+		o.openaiOptions = append(o.openaiOptions, openai.WithName(name))
+	}
+}
+
+func WithHeaders(headers map[string]string) Option {
+	return func(o *options) {
+		o.openaiOptions = append(o.openaiOptions, openai.WithHeaders(headers))
+	}
+}
+
+func WithHTTPClient(client option.HTTPClient) Option {
+	return func(o *options) {
+		o.openaiOptions = append(o.openaiOptions, openai.WithHTTPClient(client))
+	}
+}
+
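
Since New wires the openrouter hook into the shared openai provider, using it looks the same as using the openai package directly. A short usage sketch; note that New does not set a default base URL yet, so the standard OpenRouter endpoint and an example model slug are supplied by hand here (neither appears in the diff):

package main

import (
	"fmt"
	"os"

	"github.com/charmbracelet/fantasy/openrouter"
)

func main() {
	provider := openrouter.New(
		openrouter.WithBaseURL("https://openrouter.ai/api/v1"), // assumed endpoint, not set by New yet
		openrouter.WithAPIKey(os.Getenv("OPENROUTER_API_KEY")),
		openrouter.WithName("openrouter"),
	)
	model, err := provider.LanguageModel("anthropic/claude-3.5-sonnet") // example slug
	if err != nil {
		fmt.Println(err)
		return
	}
	_ = model
}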

openrouter/provider_options.go

@@ -0,0 +1,81 @@
+package openrouter
+
+import (
+	"github.com/charmbracelet/fantasy/ai"
+)
+
+const Name = "openrouter"
+
+type ReasoningEffort string
+
+const (
+	ReasoningEffortLow    ReasoningEffort = "low"
+	ReasoningEffortMedium ReasoningEffort = "medium"
+	ReasoningEffortHigh   ReasoningEffort = "high"
+)
+
+type ProviderMetadata struct{}
+
+func (*ProviderMetadata) Options() {}
+
+type ReasoningOptions struct {
+	// Whether reasoning is enabled
+	Enabled *bool `json:"enabled"`
+	// Whether to exclude reasoning from the response
+	Exclude *bool `json:"exclude"`
+	// Maximum number of tokens to use for reasoning
+	MaxTokens *int64 `json:"max_tokens"`
+	// Reasoning effort level: "low" | "medium" | "high"
+	Effort *ReasoningEffort `json:"effort"`
+}
+
+type Provider struct {
+	// List of provider slugs to try in order (e.g. ["anthropic", "openai"])
+	Order []string `json:"order"`
+	// Whether to allow backup providers when primary is unavailable (default: true)
+	AllowFallbacks *bool `json:"allow_fallbacks"`
+	// Only use providers that support all parameters in your request (default: false)
+	RequireParameters *bool `json:"require_parameters"`
+	// Control whether to use providers that may store data: "allow" | "deny"
+	DataCollection *string `json:"data_collection"`
+	// List of provider slugs to allow for this request
+	Only []string `json:"only"`
+	// List of provider slugs to skip for this request
+	Ignore []string `json:"ignore"`
+	// List of quantization levels to filter by (e.g. ["int4", "int8"])
+	Quantizations []string `json:"quantizations"`
+	// Sort providers by "price" | "throughput" | "latency"
+	Sort *string `json:"sort"`
+}
+
+type ProviderOptions struct {
+	Reasoning    *ReasoningOptions `json:"reasoning"`
+	ExtraBody    map[string]any    `json:"extra_body"`
+	IncludeUsage *bool             `json:"include_usage"`
+	// Modify the likelihood of specified tokens appearing in the completion.
+	// Accepts a map that maps tokens (specified by their token ID) to an associated bias value from -100 to 100.
+	// The bias is added to the logits generated by the model prior to sampling.
+	LogitBias map[string]int64 `json:"logit_bias"`
+	// Return the log probabilities of the tokens. Including logprobs will increase the response size.
+	// Setting to true will return the log probabilities of the tokens that were generated.
+	LogProbs *bool `json:"log_probs"`
+	// Whether to enable parallel function calling during tool use. Defaults to true.
+	ParallelToolCalls *bool `json:"parallel_tool_calls"`
+	// A unique identifier representing your end-user, which can help OpenRouter monitor and detect abuse.
+	User *string `json:"user"`
+	// Provider routing preferences to control request routing behavior
+	Provider *Provider `json:"provider"`
+	// TODO: add the web search plugin config
+}
+
+func (*ProviderOptions) Options() {}
+
+func ReasoningEffortOption(e ReasoningEffort) *ReasoningEffort {
+	return &e
+}
+
+func NewProviderOptions(opts *ProviderOptions) ai.ProviderOptions {
+	return ai.ProviderOptions{
+		Name: opts,
+	}
+}
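
The prepareCallWithOptions stub in openrouter.go above only parses these options so far (the actual call modification is still marked as a TODO), but NewProviderOptions already shows how they are meant to reach the hook: keyed under Name in the call's ProviderOptions map. A sketch of constructing them, with example routing values and a hypothetical ptr helper for the pointer-typed fields:

package main

import (
	"github.com/charmbracelet/fantasy/openrouter"
)

// ptr is a small local helper (not part of the diff) for pointer-typed fields.
func ptr[T any](v T) *T { return &v }

func main() {
	// These options end up at call.ProviderOptions[openrouter.Name], which is
	// exactly where the openrouter prepareCallWithOptions hook reads them back.
	opts := openrouter.NewProviderOptions(&openrouter.ProviderOptions{
		Reasoning: &openrouter.ReasoningOptions{
			Enabled: ptr(true),
			Effort:  openrouter.ReasoningEffortOption(openrouter.ReasoningEffortHigh),
		},
		Provider: &openrouter.Provider{
			Order:          []string{"anthropic", "openai"},
			AllowFallbacks: ptr(false),
		},
	})
	_ = opts // attach to the ai.Call that is passed to the language model
}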