From 62545b0481f950518b446d92c330ce1a17a4bbe1 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 13 Oct 2025 13:17:07 +0200 Subject: [PATCH] chore: merge catwalk and user options --- internal/agent/agent_tool.go | 2 +- internal/agent/coordinator.go | 71 +++++++++++++++++++++++++++++------ 2 files changed, 60 insertions(+), 13 deletions(-) diff --git a/internal/agent/agent_tool.go b/internal/agent/agent_tool.go index 8601170f53ea623e4ce82118c9558aff4e5bf53c..5e29b790f5e6e06566c1e947aebe419199d3580a 100644 --- a/internal/agent/agent_tool.go +++ b/internal/agent/agent_tool.go @@ -74,7 +74,7 @@ func (c *coordinator) agentTool() (ai.AgentTool, error) { SessionID: session.ID, Prompt: params.Prompt, MaxOutputTokens: maxTokens, - ProviderOptions: c.getProviderOptions(model), + ProviderOptions: getProviderOptions(model), Temperature: model.ModelCfg.Temperature, TopP: model.ModelCfg.TopP, TopK: model.ModelCfg.TopK, diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 59e2059aa3a6a27c9345faa5054eb2c6c331a4c8..a353ab1680484a6fd37216c4f9b47fef427e1737 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -1,8 +1,11 @@ package agent import ( + "cmp" "context" + "encoding/json" "errors" + "log/slog" "slices" "strings" @@ -24,6 +27,7 @@ import ( "github.com/charmbracelet/fantasy/openai" "github.com/charmbracelet/fantasy/openaicompat" "github.com/charmbracelet/fantasy/openrouter" + "github.com/qjebbs/go-jsons" ) type Coordinator interface { @@ -103,46 +107,79 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments = nil } + mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model) + return c.currentAgent.Run(ctx, SessionAgentCall{ SessionID: sessionID, Prompt: prompt, Attachments: attachments, MaxOutputTokens: maxTokens, - ProviderOptions: c.getProviderOptions(model), - Temperature: model.ModelCfg.Temperature, - TopP: model.ModelCfg.TopP, - TopK: model.ModelCfg.TopK, - FrequencyPenalty: model.ModelCfg.FrequencyPenalty, - PresencePenalty: model.ModelCfg.PresencePenalty, + ProviderOptions: mergedOptions, + Temperature: temp, + TopP: topP, + TopK: topK, + FrequencyPenalty: freqPenalty, + PresencePenalty: presPenalty, }) } -func (c *coordinator) getProviderOptions(model Model) ai.ProviderOptions { +func getProviderOptions(model Model) ai.ProviderOptions { options := ai.ProviderOptions{} + cfgOpts := "{}" + catwalkOpts := "{}" + + if model.ModelCfg.ProviderOptions != nil { + data, err := json.Marshal(model.ModelCfg.ProviderOptions) + if err == nil { + cfgOpts = string(data) + } + } + + if model.CatwalkCfg.Options.ProviderOptions != nil { + data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions) + if err == nil { + catwalkOpts = string(data) + } + } + + got, err := jsons.Merge([]string{catwalkOpts, cfgOpts}) + if err != nil { + slog.Error("Could not merge call config", "err", err) + return options + } + + mergedOptions := make(map[string]any) + + err = json.Unmarshal([]byte(got), &mergedOptions) + if err != nil { + slog.Error("Could not create config for call", "err", err) + return options + } + switch model.Model.Provider() { case openai.Name: - parsed, err := openai.ParseOptions(model.ModelCfg.ProviderOptions) + parsed, err := openai.ParseOptions(mergedOptions) if err == nil { options[openai.Name] = parsed } case anthropic.Name: - parsed, err := anthropic.ParseOptions(model.ModelCfg.ProviderOptions) + parsed, err := anthropic.ParseOptions(mergedOptions) if err == nil { options[anthropic.Name] = parsed } case openrouter.Name: - parsed, err := openrouter.ParseOptions(model.ModelCfg.ProviderOptions) + parsed, err := openrouter.ParseOptions(mergedOptions) if err == nil { options[openrouter.Name] = parsed } case google.Name: - parsed, err := google.ParseOptions(model.ModelCfg.ProviderOptions) + parsed, err := google.ParseOptions(mergedOptions) if err == nil { options[google.Name] = parsed } case openaicompat.Name: - parsed, err := openaicompat.ParseOptions(model.ModelCfg.ProviderOptions) + parsed, err := openaicompat.ParseOptions(mergedOptions) if err == nil { options[openaicompat.Name] = parsed } @@ -151,6 +188,16 @@ func (c *coordinator) getProviderOptions(model Model) ai.ProviderOptions { return options } +func mergeCallOptions(model Model) (ai.ProviderOptions, *float64, *float64, *int64, *float64, *float64) { + modelOptions := getProviderOptions(model) + temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature) + topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP) + topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK) + freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty) + presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty) + return modelOptions, temp, topP, topK, freqPenalty, presPenalty +} + func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) { large, small, err := c.buildAgentModels() if err != nil {