chore: change how max tokens works

Authored by Kujtim Hoxha

Change summary

internal/config/config.go          |  9 ++++++---
internal/llm/agent/agent.go        |  6 ------
internal/llm/provider/anthropic.go | 12 +++++++++++-
internal/llm/provider/gemini.go    | 26 +++++++++++++++++++++-----
internal/llm/provider/openai.go    |  9 +++++++--
internal/llm/provider/provider.go  |  7 -------
6 files changed, 45 insertions(+), 24 deletions(-)

Detailed changes

internal/config/config.go

@@ -157,9 +157,12 @@ type Options struct {
 }
 
 type PreferredModel struct {
-	ModelID         string                     `json:"model_id"`
-	Provider        provider.InferenceProvider `json:"provider"`
-	ReasoningEffort string                     `json:"reasoning_effort,omitempty"`
+	ModelID  string                     `json:"model_id"`
+	Provider provider.InferenceProvider `json:"provider"`
+	// Overrides the default reasoning effort for this model
+	ReasoningEffort string `json:"reasoning_effort,omitempty"`
+	// Overrides the default max tokens for this model
+	MaxTokens int64 `json:"max_tokens,omitempty"`
 }
 
 type PreferredModels struct {
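
With the new MaxTokens field, a per-model override can be set directly in the configuration. A minimal sketch of such an entry, assuming the surrounding keys ("models", "large") mirror the PreferredModels/PreferredModel structs; only "max_tokens" is introduced by this change, and the values are illustrative:

    {
      "models": {
        "large": {
          "model_id": "claude-sonnet-4",
          "provider": "anthropic",
          "max_tokens": 8192
        }
      }
    }

When "max_tokens" is omitted or zero, the provider clients below fall back to the model's DefaultMaxTokens.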

internal/llm/agent/agent.go

@@ -147,7 +147,6 @@ func NewAgent(
 	opts := []provider.ProviderClientOption{
 		provider.WithModel(agentCfg.Model),
 		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
-		provider.WithMaxTokens(model.DefaultMaxTokens),
 	}
 	agentProvider, err := provider.NewProviderV2(providerCfg, opts...)
 	if err != nil {
@@ -184,7 +183,6 @@ func NewAgent(
 	titleOpts := []provider.ProviderClientOption{
 		provider.WithModel(config.SmallModel),
 		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
-		provider.WithMaxTokens(40),
 	}
 	titleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
 	if err != nil {
@@ -193,7 +191,6 @@ func NewAgent(
 	summarizeOpts := []provider.ProviderClientOption{
 		provider.WithModel(config.SmallModel),
 		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
-		provider.WithMaxTokens(smallModel.DefaultMaxTokens),
 	}
 	summarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
 	if err != nil {
@@ -832,7 +829,6 @@ func (a *agent) UpdateModel() error {
 		opts := []provider.ProviderClientOption{
 			provider.WithModel(a.agentCfg.Model),
 			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
-			provider.WithMaxTokens(model.DefaultMaxTokens),
 		}
 
 		newProvider, err := provider.NewProviderV2(currentProviderCfg, opts...)
@@ -877,7 +873,6 @@ func (a *agent) UpdateModel() error {
 		titleOpts := []provider.ProviderClientOption{
 			provider.WithModel(config.SmallModel),
 			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
-			provider.WithMaxTokens(40),
 		}
 		newTitleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
 		if err != nil {
@@ -888,7 +883,6 @@ func (a *agent) UpdateModel() error {
 		summarizeOpts := []provider.ProviderClientOption{
 			provider.WithModel(config.SmallModel),
 			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
-			provider.WithMaxTokens(smallModel.DefaultMaxTokens),
 		}
 		newSummarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
 		if err != nil {

internal/llm/provider/anthropic.go

@@ -164,9 +164,19 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
 	// 	}
 	// }
 
+	cfg := config.Get()
+	modelConfig := cfg.Models.Large
+	if a.providerOptions.modelType == config.SmallModel {
+		modelConfig = cfg.Models.Small
+	}
+	maxTokens := model.DefaultMaxTokens
+	if modelConfig.MaxTokens > 0 {
+		maxTokens = modelConfig.MaxTokens
+	}
+
 	return anthropic.MessageNewParams{
 		Model:       anthropic.Model(model.ID),
-		MaxTokens:   a.providerOptions.maxTokens,
+		MaxTokens:   maxTokens,
 		Temperature: temperature,
 		Messages:    messages,
 		Tools:       tools,
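
Each provider client now resolves the limit the same way: use the configured override when it is positive, otherwise fall back to the model's default. A hedged sketch of that shared logic (the helper name resolveMaxTokens is illustrative; the diff keeps this inline in anthropic.go, gemini.go, and openai.go):

    // resolveMaxTokens picks the effective output-token limit: the per-model
    // override from the config wins, otherwise the model's built-in default.
    func resolveMaxTokens(modelType config.ModelType, model config.Model) int64 {
        cfg := config.Get()
        modelConfig := cfg.Models.Large
        if modelType == config.SmallModel {
            modelConfig = cfg.Models.Small
        }
        if modelConfig.MaxTokens > 0 {
            return modelConfig.MaxTokens
        }
        return model.DefaultMaxTokens
    }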

internal/llm/provider/gemini.go

@@ -155,17 +155,26 @@ func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishRea
 func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
 	// Convert messages
 	geminiMessages := g.convertMessages(messages)
-
+	model := g.providerOptions.model(g.providerOptions.modelType)
 	cfg := config.Get()
 	if cfg.Options.Debug {
 		jsonData, _ := json.Marshal(geminiMessages)
 		logging.Debug("Prepared messages", "messages", string(jsonData))
 	}
 
+	modelConfig := cfg.Models.Large
+	if g.providerOptions.modelType == config.SmallModel {
+		modelConfig = cfg.Models.Small
+	}
+
+	maxTokens := model.DefaultMaxTokens
+	if modelConfig.MaxTokens > 0 {
+		maxTokens = modelConfig.MaxTokens
+	}
 	history := geminiMessages[:len(geminiMessages)-1] // All but last message
 	lastMsg := geminiMessages[len(geminiMessages)-1]
 	config := &genai.GenerateContentConfig{
-		MaxOutputTokens: int32(g.providerOptions.maxTokens),
+		MaxOutputTokens: int32(maxTokens),
 		SystemInstruction: &genai.Content{
 			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
 		},
@@ -173,7 +182,6 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
 	if len(tools) > 0 {
 		config.Tools = g.convertTools(tools)
 	}
-	model := g.providerOptions.model(g.providerOptions.modelType)
 	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
 
 	attempts := 0
@@ -245,16 +253,25 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
 	// Convert messages
 	geminiMessages := g.convertMessages(messages)
 
+	model := g.providerOptions.model(g.providerOptions.modelType)
 	cfg := config.Get()
 	if cfg.Options.Debug {
 		jsonData, _ := json.Marshal(geminiMessages)
 		logging.Debug("Prepared messages", "messages", string(jsonData))
 	}
 
+	modelConfig := cfg.Models.Large
+	if g.providerOptions.modelType == config.SmallModel {
+		modelConfig = cfg.Models.Small
+	}
+	maxTokens := model.DefaultMaxTokens
+	if modelConfig.MaxTokens > 0 {
+		maxTokens = modelConfig.MaxTokens
+	}
 	history := geminiMessages[:len(geminiMessages)-1] // All but last message
 	lastMsg := geminiMessages[len(geminiMessages)-1]
 	config := &genai.GenerateContentConfig{
-		MaxOutputTokens: int32(g.providerOptions.maxTokens),
+		MaxOutputTokens: int32(maxTokens),
 		SystemInstruction: &genai.Content{
 			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
 		},
@@ -262,7 +279,6 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
 	if len(tools) > 0 {
 		config.Tools = g.convertTools(tools)
 	}
-	model := g.providerOptions.model(g.providerOptions.modelType)
 	chat, _ := g.client.Chats.Create(ctx, model.ID, config, history)
 
 	attempts := 0

internal/llm/provider/openai.go

@@ -160,8 +160,13 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar
 		Messages: messages,
 		Tools:    tools,
 	}
+
+	maxTokens := model.DefaultMaxTokens
+	if modelConfig.MaxTokens > 0 {
+		maxTokens = modelConfig.MaxTokens
+	}
 	if model.CanReason {
-		params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens)
+		params.MaxCompletionTokens = openai.Int(maxTokens)
 		switch reasoningEffort {
 		case "low":
 			params.ReasoningEffort = shared.ReasoningEffortLow
@@ -173,7 +178,7 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar
 			params.ReasoningEffort = shared.ReasoningEffortMedium
 		}
 	} else {
-		params.MaxTokens = openai.Int(o.providerOptions.maxTokens)
+		params.MaxTokens = openai.Int(maxTokens)
 	}
 
 	return params

internal/llm/provider/provider.go

@@ -64,7 +64,6 @@ type providerClientOptions struct {
 	modelType     config.ModelType
 	model         func(config.ModelType) config.Model
 	disableCache  bool
-	maxTokens     int64
 	systemMessage string
 	extraHeaders  map[string]string
 	extraParams   map[string]string
@@ -121,12 +120,6 @@ func WithDisableCache(disableCache bool) ProviderClientOption {
 	}
 }
 
-func WithMaxTokens(maxTokens int64) ProviderClientOption {
-	return func(options *providerClientOptions) {
-		options.maxTokens = maxTokens
-	}
-}
-
 func WithSystemMessage(systemMessage string) ProviderClientOption {
 	return func(options *providerClientOptions) {
 		options.systemMessage = systemMessage
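
With WithMaxTokens gone, callers only select the model and system message; the token limit is resolved inside each client at request time. A minimal caller-side sketch, mirroring the agent.go changes above:

    opts := []provider.ProviderClientOption{
        provider.WithModel(agentCfg.Model),
        provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
        // No WithMaxTokens: the limit comes from cfg.Models.*.MaxTokens,
        // falling back to the model's DefaultMaxTokens.
    }
    agentProvider, err := provider.NewProviderV2(providerCfg, opts...)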