@@ -32,6 +32,7 @@ import (
"charm.land/fantasy/providers/openai"
"charm.land/fantasy/providers/openaicompat"
"charm.land/fantasy/providers/openrouter"
+ openaisdk "github.com/openai/openai-go/v2/option"
"github.com/qjebbs/go-jsons"
)
@@ -118,7 +119,7 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string,
return nil, errors.New("model provider not configured")
}
- mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg.Type)
+ mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
return c.currentAgent.Run(ctx, SessionAgentCall{
SessionID: sessionID,
@@ -134,10 +135,11 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string,
})
}
-func getProviderOptions(model Model, tp catwalk.Type) fantasy.ProviderOptions {
+func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
options := fantasy.ProviderOptions{}
cfgOpts := []byte("{}")
+ providerCfgOpts := []byte("{}")
catwalkOpts := []byte("{}")
if model.ModelCfg.ProviderOptions != nil {
@@ -147,6 +149,13 @@ func getProviderOptions(model Model, tp catwalk.Type) fantasy.ProviderOptions {
}
}
+ if providerCfg.ProviderOptions != nil {
+ data, err := json.Marshal(providerCfg.ProviderOptions)
+ if err == nil {
+ providerCfgOpts = data
+ }
+ }
+
if model.CatwalkCfg.Options.ProviderOptions != nil {
data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
if err == nil {
@@ -156,6 +165,7 @@ func getProviderOptions(model Model, tp catwalk.Type) fantasy.ProviderOptions {
readers := []io.Reader{
bytes.NewReader(catwalkOpts),
+ bytes.NewReader(providerCfgOpts),
bytes.NewReader(cfgOpts),
}
@@ -173,7 +183,7 @@ func getProviderOptions(model Model, tp catwalk.Type) fantasy.ProviderOptions {
return options
}
- switch tp {
+ switch providerCfg.Type {
case openai.Name:
_, hasReasoningEffort := mergedOptions["reasoning_effort"]
if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
@@ -255,8 +265,8 @@ func getProviderOptions(model Model, tp catwalk.Type) fantasy.ProviderOptions {
return options
}
-func mergeCallOptions(model Model, tp catwalk.Type) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
- modelOptions := getProviderOptions(model, tp)
+func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
+ modelOptions := getProviderOptions(model, cfg)
temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
@@ -516,7 +526,7 @@ func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[stri
return openrouter.New(opts...)
}
-func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
+func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any) (fantasy.Provider, error) {
opts := []openaicompat.Option{
openaicompat.WithBaseURL(baseURL),
openaicompat.WithAPIKey(apiKey),
@@ -529,6 +539,10 @@ func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers
opts = append(opts, openaicompat.WithHeaders(headers))
}
+ for extraKey, extraValue := range extraBody {
+ opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
+ }
+
return openaicompat.New(opts...)
}
@@ -646,7 +660,7 @@ func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model con
case "google-vertex", "vertexai":
return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
case openaicompat.Name:
- return c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
+ return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody)
default:
return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
}
@@ -717,5 +731,5 @@ func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
if !ok {
return errors.New("model provider not configured")
}
- return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg.Type))
+ return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
}