chore: merge catwalk and user options

Kujtim Hoxha created

Change summary

internal/agent/agent_tool.go  |  2 
internal/agent/coordinator.go | 71 ++++++++++++++++++++++++++++++------
2 files changed, 60 insertions(+), 13 deletions(-)

Detailed changes

internal/agent/agent_tool.go 🔗

@@ -74,7 +74,7 @@ func (c *coordinator) agentTool() (ai.AgentTool, error) {
 				SessionID:        session.ID,
 				Prompt:           params.Prompt,
 				MaxOutputTokens:  maxTokens,
-				ProviderOptions:  c.getProviderOptions(model),
+				ProviderOptions:  getProviderOptions(model),
 				Temperature:      model.ModelCfg.Temperature,
 				TopP:             model.ModelCfg.TopP,
 				TopK:             model.ModelCfg.TopK,

internal/agent/coordinator.go 🔗

@@ -1,8 +1,11 @@
 package agent
 
 import (
+	"cmp"
 	"context"
+	"encoding/json"
 	"errors"
+	"log/slog"
 	"slices"
 	"strings"
 
@@ -24,6 +27,7 @@ import (
 	"github.com/charmbracelet/fantasy/openai"
 	"github.com/charmbracelet/fantasy/openaicompat"
 	"github.com/charmbracelet/fantasy/openrouter"
+	"github.com/qjebbs/go-jsons"
 )
 
 type Coordinator interface {
@@ -103,46 +107,79 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string,
 		attachments = nil
 	}
 
+	mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model)
+
 	return c.currentAgent.Run(ctx, SessionAgentCall{
 		SessionID:        sessionID,
 		Prompt:           prompt,
 		Attachments:      attachments,
 		MaxOutputTokens:  maxTokens,
-		ProviderOptions:  c.getProviderOptions(model),
-		Temperature:      model.ModelCfg.Temperature,
-		TopP:             model.ModelCfg.TopP,
-		TopK:             model.ModelCfg.TopK,
-		FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
-		PresencePenalty:  model.ModelCfg.PresencePenalty,
+		ProviderOptions:  mergedOptions,
+		Temperature:      temp,
+		TopP:             topP,
+		TopK:             topK,
+		FrequencyPenalty: freqPenalty,
+		PresencePenalty:  presPenalty,
 	})
 }
 
-func (c *coordinator) getProviderOptions(model Model) ai.ProviderOptions {
+func getProviderOptions(model Model) ai.ProviderOptions {
 	options := ai.ProviderOptions{}
 
+	cfgOpts := "{}"
+	catwalkOpts := "{}"
+
+	if model.ModelCfg.ProviderOptions != nil {
+		data, err := json.Marshal(model.ModelCfg.ProviderOptions)
+		if err == nil {
+			cfgOpts = string(data)
+		}
+	}
+
+	if model.CatwalkCfg.Options.ProviderOptions != nil {
+		data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
+		if err == nil {
+			catwalkOpts = string(data)
+		}
+	}
+
+	got, err := jsons.Merge([]string{catwalkOpts, cfgOpts})
+	if err != nil {
+		slog.Error("Could not merge call config", "err", err)
+		return options
+	}
+
+	mergedOptions := make(map[string]any)
+
+	err = json.Unmarshal([]byte(got), &mergedOptions)
+	if err != nil {
+		slog.Error("Could not create config for call", "err", err)
+		return options
+	}
+
 	switch model.Model.Provider() {
 	case openai.Name:
-		parsed, err := openai.ParseOptions(model.ModelCfg.ProviderOptions)
+		parsed, err := openai.ParseOptions(mergedOptions)
 		if err == nil {
 			options[openai.Name] = parsed
 		}
 	case anthropic.Name:
-		parsed, err := anthropic.ParseOptions(model.ModelCfg.ProviderOptions)
+		parsed, err := anthropic.ParseOptions(mergedOptions)
 		if err == nil {
 			options[anthropic.Name] = parsed
 		}
 	case openrouter.Name:
-		parsed, err := openrouter.ParseOptions(model.ModelCfg.ProviderOptions)
+		parsed, err := openrouter.ParseOptions(mergedOptions)
 		if err == nil {
 			options[openrouter.Name] = parsed
 		}
 	case google.Name:
-		parsed, err := google.ParseOptions(model.ModelCfg.ProviderOptions)
+		parsed, err := google.ParseOptions(mergedOptions)
 		if err == nil {
 			options[google.Name] = parsed
 		}
 	case openaicompat.Name:
-		parsed, err := openaicompat.ParseOptions(model.ModelCfg.ProviderOptions)
+		parsed, err := openaicompat.ParseOptions(mergedOptions)
 		if err == nil {
 			options[openaicompat.Name] = parsed
 		}
@@ -151,6 +188,16 @@ func (c *coordinator) getProviderOptions(model Model) ai.ProviderOptions {
 	return options
 }
 
+func mergeCallOptions(model Model) (ai.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
+	modelOptions := getProviderOptions(model)
+	temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
+	topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
+	topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
+	freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
+	presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
+	return modelOptions, temp, topP, topK, freqPenalty, presPenalty
+}
+
 func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
 	large, small, err := c.buildAgentModels()
 	if err != nil {