@@ -74,7 +74,7 @@ func (c *coordinator) agentTool() (ai.AgentTool, error) {
SessionID: session.ID,
Prompt: params.Prompt,
MaxOutputTokens: maxTokens,
- ProviderOptions: c.getProviderOptions(model),
+ ProviderOptions: getProviderOptions(model),
Temperature: model.ModelCfg.Temperature,
TopP: model.ModelCfg.TopP,
TopK: model.ModelCfg.TopK,
@@ -1,8 +1,11 @@
package agent
import (
+ "cmp"
"context"
+ "encoding/json"
"errors"
+ "log/slog"
"slices"
"strings"
@@ -24,6 +27,7 @@ import (
"github.com/charmbracelet/fantasy/openai"
"github.com/charmbracelet/fantasy/openaicompat"
"github.com/charmbracelet/fantasy/openrouter"
+ "github.com/qjebbs/go-jsons"
)
type Coordinator interface {
@@ -103,46 +107,79 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string,
attachments = nil
}
+ mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model)
+
return c.currentAgent.Run(ctx, SessionAgentCall{
SessionID: sessionID,
Prompt: prompt,
Attachments: attachments,
MaxOutputTokens: maxTokens,
- ProviderOptions: c.getProviderOptions(model),
- Temperature: model.ModelCfg.Temperature,
- TopP: model.ModelCfg.TopP,
- TopK: model.ModelCfg.TopK,
- FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
- PresencePenalty: model.ModelCfg.PresencePenalty,
+ ProviderOptions: mergedOptions,
+ Temperature: temp,
+ TopP: topP,
+ TopK: topK,
+ FrequencyPenalty: freqPenalty,
+ PresencePenalty: presPenalty,
})
}
-func (c *coordinator) getProviderOptions(model Model) ai.ProviderOptions {
+func getProviderOptions(model Model) ai.ProviderOptions {
options := ai.ProviderOptions{}
+ cfgOpts := "{}"
+ catwalkOpts := "{}"
+
+ if model.ModelCfg.ProviderOptions != nil {
+ data, err := json.Marshal(model.ModelCfg.ProviderOptions)
+ if err == nil {
+ cfgOpts = string(data)
+ }
+ }
+
+ if model.CatwalkCfg.Options.ProviderOptions != nil {
+ data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
+ if err == nil {
+ catwalkOpts = string(data)
+ }
+ }
+
+ got, err := jsons.Merge([]string{catwalkOpts, cfgOpts})
+ if err != nil {
+ slog.Error("Could not merge call config", "err", err)
+ return options
+ }
+
+ mergedOptions := make(map[string]any)
+
+ err = json.Unmarshal([]byte(got), &mergedOptions)
+ if err != nil {
+ slog.Error("Could not create config for call", "err", err)
+ return options
+ }
+
switch model.Model.Provider() {
case openai.Name:
- parsed, err := openai.ParseOptions(model.ModelCfg.ProviderOptions)
+ parsed, err := openai.ParseOptions(mergedOptions)
if err == nil {
options[openai.Name] = parsed
}
case anthropic.Name:
- parsed, err := anthropic.ParseOptions(model.ModelCfg.ProviderOptions)
+ parsed, err := anthropic.ParseOptions(mergedOptions)
if err == nil {
options[anthropic.Name] = parsed
}
case openrouter.Name:
- parsed, err := openrouter.ParseOptions(model.ModelCfg.ProviderOptions)
+ parsed, err := openrouter.ParseOptions(mergedOptions)
if err == nil {
options[openrouter.Name] = parsed
}
case google.Name:
- parsed, err := google.ParseOptions(model.ModelCfg.ProviderOptions)
+ parsed, err := google.ParseOptions(mergedOptions)
if err == nil {
options[google.Name] = parsed
}
case openaicompat.Name:
- parsed, err := openaicompat.ParseOptions(model.ModelCfg.ProviderOptions)
+ parsed, err := openaicompat.ParseOptions(mergedOptions)
if err == nil {
options[openaicompat.Name] = parsed
}
@@ -151,6 +188,16 @@ func (c *coordinator) getProviderOptions(model Model) ai.ProviderOptions {
return options
}
+func mergeCallOptions(model Model) (ai.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
+ modelOptions := getProviderOptions(model)
+ temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
+ topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
+ topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
+ freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
+ presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
+ return modelOptions, temp, topP, topK, freqPenalty, presPenalty
+}
+
func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
large, small, err := c.buildAgentModels()
if err != nil {