From e36cc2fa2c79dc425101b5fd3a03e7b389b3ff8b Mon Sep 17 00:00:00 2001
From: Kujtim Hoxha
Date: Sat, 28 Jun 2025 12:06:40 +0200
Subject: [PATCH] chore: make it possible to override maxTokens

We still need to handle the case where the max tokens is more than 5% of
the total context window; this can cause the endpoint to error.
---
 internal/llm/agent/agent.go        | 18 +++++++++++-------
 internal/llm/provider/anthropic.go |  5 +++++
 internal/llm/provider/gemini.go    |  5 +++++
 internal/llm/provider/openai.go    |  5 +++++
 internal/llm/provider/provider.go  |  9 ++++++++-
 5 files changed, 34 insertions(+), 8 deletions(-)

diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go
index 5f3d41c2eee4cc41df159066379346bad4a97dfc..57771a7dc98efd2fa897d655aa04b7fef628dab5 100644
--- a/internal/llm/agent/agent.go
+++ b/internal/llm/agent/agent.go
@@ -148,7 +148,7 @@ func NewAgent(
 		provider.WithModel(agentCfg.Model),
 		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
 	}
-	agentProvider, err := provider.NewProviderV2(providerCfg, opts...)
+	agentProvider, err := provider.NewProvider(providerCfg, opts...)
 	if err != nil {
 		return nil, err
 	}
@@ -184,7 +184,7 @@ func NewAgent(
 		provider.WithModel(config.SmallModel),
 		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 	}
-	titleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
+	titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
 	if err != nil {
 		return nil, err
 	}
@@ -192,7 +192,7 @@
 		provider.WithModel(config.SmallModel),
 		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
 	}
-	summarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
+	summarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
 	if err != nil {
 		return nil, err
 	}
@@ -277,7 +277,9 @@ func (a *agent) generateTitle(ctx context.Context, sessionID string, content str
 	if err != nil {
 		return err
 	}
-	parts := []message.ContentPart{message.TextContent{Text: content}}
+	parts := []message.ContentPart{message.TextContent{
+		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
+	}}
 
 	// Use streaming approach like summarization
 	response := a.titleProvider.StreamResponse(
@@ -831,7 +833,7 @@ func (a *agent) UpdateModel() error {
 		provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
 	}
 
-	newProvider, err := provider.NewProviderV2(currentProviderCfg, opts...)
+	newProvider, err := provider.NewProvider(currentProviderCfg, opts...)
 	if err != nil {
 		return fmt.Errorf("failed to create new provider: %w", err)
 	}
@@ -873,8 +875,10 @@ func (a *agent) UpdateModel() error {
 	titleOpts := []provider.ProviderClientOption{
 		provider.WithModel(config.SmallModel),
 		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
+		// We want the title to be short, so we limit the max tokens
+		provider.WithMaxTokens(40),
 	}
-	newTitleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
+	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
 	if err != nil {
 		return fmt.Errorf("failed to create new title provider: %w", err)
 	}
@@ -884,7 +888,7 @@ func (a *agent) UpdateModel() error {
 		provider.WithModel(config.SmallModel),
 		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
 	}
-	newSummarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
+	newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
 	if err != nil {
 		return fmt.Errorf("failed to create new summarize provider: %w", err)
 	}
diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go
index d8d4ec002ed35ec06d6932643e070241fed0e227..df6b8490ebc48abc7c01a2a938c6f7d395526654 100644
--- a/internal/llm/provider/anthropic.go
+++ b/internal/llm/provider/anthropic.go
@@ -174,6 +174,11 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
 		maxTokens = modelConfig.MaxTokens
 	}
 
+	// Override max tokens if set in provider options
+	if a.providerOptions.maxTokens > 0 {
+		maxTokens = a.providerOptions.maxTokens
+	}
+
 	return anthropic.MessageNewParams{
 		Model:     anthropic.Model(model.ID),
 		MaxTokens: maxTokens,
diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go
index 56263a4389e28289db5adf9392f307d11908e1cc..f644d118b4ef642c5f9e835ecfaa450d9f835f4d 100644
--- a/internal/llm/provider/gemini.go
+++ b/internal/llm/provider/gemini.go
@@ -268,6 +268,11 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
 	if modelConfig.MaxTokens > 0 {
 		maxTokens = modelConfig.MaxTokens
 	}
+
+	// Override max tokens if set in provider options
+	if g.providerOptions.maxTokens > 0 {
+		maxTokens = g.providerOptions.maxTokens
+	}
 	history := geminiMessages[:len(geminiMessages)-1] // All but last message
 	lastMsg := geminiMessages[len(geminiMessages)-1]
 	config := &genai.GenerateContentConfig{
diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go
index 71040b485b426b5d80e078aea7c06c710f93e4e2..1ae8847db441181a1a65bcacc8b4bd039b45a0fc 100644
--- a/internal/llm/provider/openai.go
+++ b/internal/llm/provider/openai.go
@@ -165,6 +165,11 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar
 	if modelConfig.MaxTokens > 0 {
 		maxTokens = modelConfig.MaxTokens
 	}
+
+	// Override max tokens if set in provider options
+	if o.providerOptions.maxTokens > 0 {
+		maxTokens = o.providerOptions.maxTokens
+	}
 	if model.CanReason {
 		params.MaxCompletionTokens = openai.Int(maxTokens)
 		switch reasoningEffort {
diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go
index 6da5188b2bf6b17e1f91c9ff5e7eb2bd20931392..2133e23309b4d92d8c8b2efbf1bb386a2e7753cd 100644
--- a/internal/llm/provider/provider.go
+++ b/internal/llm/provider/provider.go
@@ -65,6 +65,7 @@ type providerClientOptions struct {
 	model         func(config.ModelType) config.Model
 	disableCache  bool
 	systemMessage string
+	maxTokens     int64
 	extraHeaders  map[string]string
 	extraParams   map[string]string
 }
@@ -126,7 +127,13 @@ func WithSystemMessage(systemMessage string) ProviderClientOption {
 	}
 }
 
-func NewProviderV2(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
+func WithMaxTokens(maxTokens int64) ProviderClientOption {
+	return func(options *providerClientOptions) {
+		options.maxTokens = maxTokens
+	}
+}
+
+func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
 	clientOptions := providerClientOptions{
 		baseURL:       cfg.BaseURL,
 		apiKey:        cfg.APIKey,
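
Usage sketch (not part of the patch): with WithMaxTokens in place, any call
site can cap completion length the way the title provider does above. The
sketch below also clamps the override so it cannot exceed the model's
context window, which is the unhandled failure mode noted in the commit
message. clampMaxTokens, the contextWindow parameter, and newTitleProvider
are illustrative names assumed for this example, not code in the repository.

	// clampMaxTokens keeps a caller-supplied override within the model's
	// context window. Illustrative only: the patch itself does not yet
	// validate the override, so an oversized value can still make the
	// endpoint error.
	func clampMaxTokens(requested, contextWindow int64) int64 {
		if contextWindow > 0 && requested > contextWindow {
			return contextWindow
		}
		return requested
	}

	// newTitleProvider mirrors how agent.go builds its title provider
	// after this patch, clamping the override before it reaches the
	// endpoint.
	func newTitleProvider(cfg config.ProviderConfig, contextWindow int64) (provider.Provider, error) {
		opts := []provider.ProviderClientOption{
			provider.WithModel(config.SmallModel),
			// Titles should be short, so cap the completion at 40 tokens.
			provider.WithMaxTokens(clampMaxTokens(40, contextWindow)),
		}
		return provider.NewProvider(cfg, opts...)
	}

A clamp like this would be one way to close the gap the commit message
leaves open; whether to clamp, error, or cap at a percentage of the window
is a policy choice the patch defers.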