From eb3550a2bb26f9e4199a6339a3144dcd898dacbe Mon Sep 17 00:00:00 2001
From: Kujtim Hoxha
Date: Sat, 28 Jun 2025 11:56:46 +0200
Subject: [PATCH] chore: change how max tokens works

---
 internal/config/config.go          |  9 ++++++---
 internal/llm/agent/agent.go        |  6 ------
 internal/llm/provider/anthropic.go | 12 +++++++++++-
 internal/llm/provider/gemini.go    | 26 +++++++++++++++++++++-----
 internal/llm/provider/openai.go    |  9 +++++++--
 internal/llm/provider/provider.go  |  7 -------
 6 files changed, 45 insertions(+), 24 deletions(-)

diff --git a/internal/config/config.go b/internal/config/config.go
index f6814db44cdadefd0e88e57e2bcd2521bf8a3c28..32ca8729295bb3994af27ec4359a1b4960527671 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -157,9 +157,12 @@ type Options struct {
 }
 
 type PreferredModel struct {
-	ModelID         string                     `json:"model_id"`
-	Provider        provider.InferenceProvider `json:"provider"`
-	ReasoningEffort string                     `json:"reasoning_effort,omitempty"`
+	ModelID  string                     `json:"model_id"`
+	Provider provider.InferenceProvider `json:"provider"`
+	// Overrides the default reasoning effort for this model
+	ReasoningEffort string `json:"reasoning_effort,omitempty"`
+	// Overrides the default max tokens for this model
+	MaxTokens int64 `json:"max_tokens,omitempty"`
 }
 
 type PreferredModels struct {
diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go
index 8312b0f8965a5d02f7ce049abff50953cc56e422..5f3d41c2eee4cc41df159066379346bad4a97dfc 100644
--- a/internal/llm/agent/agent.go
+++ b/internal/llm/agent/agent.go
@@ -147,7 +147,6 @@ func NewAgent(
 	opts := []provider.ProviderClientOption{
 		provider.WithModel(agentCfg.Model),
 		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
-		provider.WithMaxTokens(model.DefaultMaxTokens),
 	}
 	agentProvider, err := provider.NewProviderV2(providerCfg, opts...)
 	if err != nil {
@@ -184,7 +183,6 @@ func NewAgent(
 	titleOpts := []provider.ProviderClientOption{
 		provider.WithModel(config.SmallModel),
 		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
-		provider.WithMaxTokens(40),
 	}
 	titleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
 	if err != nil {
@@ -193,7 +191,6 @@
 	summarizeOpts := []provider.ProviderClientOption{
 		provider.WithModel(config.SmallModel),
 		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
-		provider.WithMaxTokens(smallModel.DefaultMaxTokens),
 	}
 	summarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
 	if err != nil {
@@ -832,7 +829,6 @@ func (a *agent) UpdateModel() error {
 	opts := []provider.ProviderClientOption{
 		provider.WithModel(a.agentCfg.Model),
 		provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
-		provider.WithMaxTokens(model.DefaultMaxTokens),
 	}
 
 	newProvider, err := provider.NewProviderV2(currentProviderCfg, opts...)
@@ -877,7 +873,6 @@ func (a *agent) UpdateModel() error {
 	titleOpts := []provider.ProviderClientOption{
 		provider.WithModel(config.SmallModel),
 		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
-		provider.WithMaxTokens(40),
 	}
 	newTitleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
 	if err != nil {
@@ -888,7 +883,6 @@ func (a *agent) UpdateModel() error {
 	summarizeOpts := []provider.ProviderClientOption{
 		provider.WithModel(config.SmallModel),
 		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
-		provider.WithMaxTokens(smallModel.DefaultMaxTokens),
 	}
 	newSummarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
 	if err != nil {
diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go
index 626882f283c030454477b27b152bd6a717d08476..d8d4ec002ed35ec06d6932643e070241fed0e227 100644
--- a/internal/llm/provider/anthropic.go
+++ b/internal/llm/provider/anthropic.go
@@ -164,9 +164,19 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
 	// }
 	// }
 
+	cfg := config.Get()
+	modelConfig := cfg.Models.Large
+	if a.providerOptions.modelType == config.SmallModel {
+		modelConfig = cfg.Models.Small
+	}
+	maxTokens := model.DefaultMaxTokens
+	if modelConfig.MaxTokens > 0 {
+		maxTokens = modelConfig.MaxTokens
+	}
+
 	return anthropic.MessageNewParams{
 		Model:       anthropic.Model(model.ID),
-		MaxTokens:   a.providerOptions.maxTokens,
+		MaxTokens:   maxTokens,
 		Temperature: temperature,
 		Messages:    messages,
 		Tools:       tools,
diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go
index 3531c5cc89cced262d5c22f2598216d9cfe4883e..56263a4389e28289db5adf9392f307d11908e1cc 100644
--- a/internal/llm/provider/gemini.go
+++ b/internal/llm/provider/gemini.go
@@ -155,17 +155,26 @@ func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishRea
 func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
 	// Convert messages
 	geminiMessages := g.convertMessages(messages)
-
+	model := g.providerOptions.model(g.providerOptions.modelType)
 	cfg := config.Get()
 	if cfg.Options.Debug {
 		jsonData, _ := json.Marshal(geminiMessages)
 		logging.Debug("Prepared messages", "messages", string(jsonData))
 	}
+	modelConfig := cfg.Models.Large
+	if g.providerOptions.modelType == config.SmallModel {
+		modelConfig = cfg.Models.Small
+	}
+
+	maxTokens := model.DefaultMaxTokens
+	if modelConfig.MaxTokens > 0 {
+		maxTokens = modelConfig.MaxTokens
+	}
 
 	history := geminiMessages[:len(geminiMessages)-1] // All but last message
 	lastMsg := geminiMessages[len(geminiMessages)-1]
 	config := &genai.GenerateContentConfig{
-		MaxOutputTokens:   int32(g.providerOptions.maxTokens),
+		MaxOutputTokens:   int32(maxTokens),
 		SystemInstruction: &genai.Content{
 			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
 		},
@@ -173,7 +182,6 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
 	if len(tools) > 0 {
 		config.Tools = g.convertTools(tools)
 	}
-	model := g.providerOptions.model(g.providerOptions.modelType)
 	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
 
 	attempts := 0
@@ -245,16 +253,25 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
 	// Convert messages
 	geminiMessages := g.convertMessages(messages)
+	model := g.providerOptions.model(g.providerOptions.modelType)
 
 	cfg := config.Get()
 	if cfg.Options.Debug {
 		jsonData, _ := json.Marshal(geminiMessages)
 		logging.Debug("Prepared messages", "messages", string(jsonData))
 	}
+	modelConfig := cfg.Models.Large
+	if g.providerOptions.modelType == config.SmallModel {
+		modelConfig = cfg.Models.Small
+	}
+	maxTokens := model.DefaultMaxTokens
+	if modelConfig.MaxTokens > 0 {
+		maxTokens = modelConfig.MaxTokens
+	}
 
 	history := geminiMessages[:len(geminiMessages)-1] // All but last message
 	lastMsg := geminiMessages[len(geminiMessages)-1]
 	config := &genai.GenerateContentConfig{
-		MaxOutputTokens:   int32(g.providerOptions.maxTokens),
+		MaxOutputTokens:   int32(maxTokens),
 		SystemInstruction: &genai.Content{
 			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
 		},
@@ -262,7 +279,6 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
 	if len(tools) > 0 {
 		config.Tools = g.convertTools(tools)
 	}
-	model := g.providerOptions.model(g.providerOptions.modelType)
 	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
 
 	attempts := 0
diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go
index f6aaacb0ce09ae665fab4bdbb14b28e13e7684c7..71040b485b426b5d80e078aea7c06c710f93e4e2 100644
--- a/internal/llm/provider/openai.go
+++ b/internal/llm/provider/openai.go
@@ -160,8 +160,13 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar
 		Messages: messages,
 		Tools:    tools,
 	}
+
+	maxTokens := model.DefaultMaxTokens
+	if modelConfig.MaxTokens > 0 {
+		maxTokens = modelConfig.MaxTokens
+	}
 	if model.CanReason {
-		params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens)
+		params.MaxCompletionTokens = openai.Int(maxTokens)
 		switch reasoningEffort {
 		case "low":
 			params.ReasoningEffort = shared.ReasoningEffortLow
@@ -173,7 +178,7 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar
 			params.ReasoningEffort = shared.ReasoningEffortMedium
 		}
 	} else {
-		params.MaxTokens = openai.Int(o.providerOptions.maxTokens)
+		params.MaxTokens = openai.Int(maxTokens)
 	}
 
 	return params
diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go
index 9723dc9fe55af414ed415653e3e9e31031395a02..6da5188b2bf6b17e1f91c9ff5e7eb2bd20931392 100644
--- a/internal/llm/provider/provider.go
+++ b/internal/llm/provider/provider.go
@@ -64,7 +64,6 @@ type providerClientOptions struct {
 	modelType     config.ModelType
 	model         func(config.ModelType) config.Model
 	disableCache  bool
-	maxTokens     int64
 	systemMessage string
 	extraHeaders  map[string]string
 	extraParams   map[string]string
@@ -121,12 +120,6 @@ func WithDisableCache(disableCache bool) ProviderClientOption {
 	}
 }
 
-func WithMaxTokens(maxTokens int64) ProviderClientOption {
-	return func(options *providerClientOptions) {
-		options.maxTokens = maxTokens
-	}
-}
-
 func WithSystemMessage(systemMessage string) ProviderClientOption {
 	return func(options *providerClientOptions) {
 		options.systemMessage = systemMessage
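
Note (not part of the patch): the sketch below restates the fallback rule this change inlines into every provider client, in place of the removed WithMaxTokens option: a positive per-model max_tokens override from the user config wins, otherwise the model's catalog DefaultMaxTokens is used. The resolveMaxTokens helper and the trimmed Model/PreferredModel types are hypothetical stand-ins for the real config package types, kept only so the example is self-contained; the literal token values are illustrative.

package main

import "fmt"

// Trimmed stand-ins for the config types; only the fields the rule touches.
type Model struct {
	DefaultMaxTokens int64 // default output-token limit from the model catalog
}

type PreferredModel struct {
	MaxTokens int64 // user override; zero means "not configured"
}

// resolveMaxTokens mirrors the logic each provider now applies inline:
// prefer the configured override when positive, else fall back to the default.
func resolveMaxTokens(model Model, modelConfig PreferredModel) int64 {
	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}
	return maxTokens
}

func main() {
	m := Model{DefaultMaxTokens: 4096}
	fmt.Println(resolveMaxTokens(m, PreferredModel{}))                // 4096: no override set
	fmt.Println(resolveMaxTokens(m, PreferredModel{MaxTokens: 8192})) // 8192: override wins
}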