chore: support provider extra body and provider options

Kujtim Hoxha created

Change summary

internal/agent/agent_tool.go  |  2 +-
internal/agent/coordinator.go | 30 ++++++++++++++++++++++--------
internal/config/config.go     |  4 +++-
3 files changed, 26 insertions(+), 10 deletions(-)

Detailed changes

internal/agent/agent_tool.go 🔗

@@ -79,7 +79,7 @@ func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error)
 				SessionID:        session.ID,
 				Prompt:           params.Prompt,
 				MaxOutputTokens:  maxTokens,
-				ProviderOptions:  getProviderOptions(model, providerCfg.Type),
+				ProviderOptions:  getProviderOptions(model, providerCfg),
 				Temperature:      model.ModelCfg.Temperature,
 				TopP:             model.ModelCfg.TopP,
 				TopK:             model.ModelCfg.TopK,

internal/agent/coordinator.go 🔗

@@ -32,6 +32,7 @@ import (
 	"charm.land/fantasy/providers/openai"
 	"charm.land/fantasy/providers/openaicompat"
 	"charm.land/fantasy/providers/openrouter"
+	openaisdk "github.com/openai/openai-go/v2/option"
 	"github.com/qjebbs/go-jsons"
 )
 
@@ -118,7 +119,7 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string,
 		return nil, errors.New("model provider not configured")
 	}
 
-	mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg.Type)
+	mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
 
 	return c.currentAgent.Run(ctx, SessionAgentCall{
 		SessionID:        sessionID,
@@ -134,10 +135,11 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string,
 	})
 }
 
-func getProviderOptions(model Model, tp catwalk.Type) fantasy.ProviderOptions {
+func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
 	options := fantasy.ProviderOptions{}
 
 	cfgOpts := []byte("{}")
+	providerCfgOpts := []byte("{}")
 	catwalkOpts := []byte("{}")
 
 	if model.ModelCfg.ProviderOptions != nil {
@@ -147,6 +149,13 @@ func getProviderOptions(model Model, tp catwalk.Type) fantasy.ProviderOptions {
 		}
 	}
 
+	if providerCfg.ProviderOptions != nil {
+		data, err := json.Marshal(providerCfg.ProviderOptions)
+		if err == nil {
+			providerCfgOpts = data
+		}
+	}
+
 	if model.CatwalkCfg.Options.ProviderOptions != nil {
 		data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
 		if err == nil {
@@ -156,6 +165,7 @@ func getProviderOptions(model Model, tp catwalk.Type) fantasy.ProviderOptions {
 
 	readers := []io.Reader{
 		bytes.NewReader(catwalkOpts),
+		bytes.NewReader(providerCfgOpts),
 		bytes.NewReader(cfgOpts),
 	}
 
@@ -173,7 +183,7 @@ func getProviderOptions(model Model, tp catwalk.Type) fantasy.ProviderOptions {
 		return options
 	}
 
-	switch tp {
+	switch providerCfg.Type {
 	case openai.Name:
 		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
 		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
@@ -255,8 +265,8 @@ func getProviderOptions(model Model, tp catwalk.Type) fantasy.ProviderOptions {
 	return options
 }
 
-func mergeCallOptions(model Model, tp catwalk.Type) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
-	modelOptions := getProviderOptions(model, tp)
+func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
+	modelOptions := getProviderOptions(model, cfg)
 	temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
 	topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
 	topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
@@ -516,7 +526,7 @@ func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[stri
 	return openrouter.New(opts...)
 }
 
-func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
+func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any) (fantasy.Provider, error) {
 	opts := []openaicompat.Option{
 		openaicompat.WithBaseURL(baseURL),
 		openaicompat.WithAPIKey(apiKey),
@@ -529,6 +539,10 @@ func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers
 		opts = append(opts, openaicompat.WithHeaders(headers))
 	}
 
+	for extraKey, extraValue := range extraBody {
+		opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
+	}
+
 	return openaicompat.New(opts...)
 }
 
@@ -646,7 +660,7 @@ func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model con
 	case "google-vertex", "vertexai":
 		return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
 	case openaicompat.Name:
-		return c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
+		return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody)
 	default:
 		return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
 	}
@@ -717,5 +731,5 @@ func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
 	if !ok {
 		return errors.New("model provider not configured")
 	}
-	return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg.Type))
+	return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
 }

internal/config/config.go 🔗

@@ -99,7 +99,9 @@ type ProviderConfig struct {
 	// Extra headers to send with each request to the provider.
 	ExtraHeaders map[string]string `json:"extra_headers,omitempty" jsonschema:"description=Additional HTTP headers to send with requests"`
 	// Extra body
-	ExtraBody map[string]any `json:"extra_body,omitempty" jsonschema:"description=Additional fields to include in request bodies"`
+	ExtraBody map[string]any `json:"extra_body,omitempty" jsonschema:"description=Additional fields to include in request bodies, only works with openai-compatible providers"`
+
+	ProviderOptions map[string]any `json:"provider_options,omitempty" jsonschema:"description=Additional provider-specific options for this provider"`
 
 	// Used to pass extra parameters to the provider.
 	ExtraParams map[string]string `json:"-"`