chore: make it possible to override maxTokens

Authored by Kujtim Hoxha

We still need to handle the case where the configured max tokens exceeds 5% of
the model's total context window, since that can cause the endpoint to return an error.
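
A follow-up could clamp the value before the request is sent. A minimal sketch of that guard, assuming a hypothetical contextWindow value on the model config and reusing the 5% threshold mentioned above:

	// clampMaxTokens caps a requested max-tokens value at 5% of the model's
	// context window, guarding against the endpoint error described above.
	// The threshold and the contextWindow parameter are illustrative
	// assumptions, not part of this change.
	func clampMaxTokens(requested, contextWindow int64) int64 {
		limit := contextWindow * 5 / 100
		if limit > 0 && requested > limit {
			return limit
		}
		return requested
	}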

Change summary

internal/llm/agent/agent.go        | 18 +++++++++++-------
internal/llm/provider/anthropic.go |  5 +++++
internal/llm/provider/gemini.go    |  5 +++++
internal/llm/provider/openai.go    |  5 +++++
internal/llm/provider/provider.go  |  9 ++++++++-
5 files changed, 34 insertions(+), 8 deletions(-)

Detailed changes

internal/llm/agent/agent.go

@@ -148,7 +148,7 @@ func NewAgent(
 		provider.WithModel(agentCfg.Model),
 		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
 	}
-	agentProvider, err := provider.NewProviderV2(providerCfg, opts...)
+	agentProvider, err := provider.NewProvider(providerCfg, opts...)
 	if err != nil {
 		return nil, err
 	}
@@ -184,7 +184,7 @@ func NewAgent(
 		provider.WithModel(config.SmallModel),
 		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 	}
-	titleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
+	titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
 	if err != nil {
 		return nil, err
 	}
@@ -192,7 +192,7 @@ func NewAgent(
 		provider.WithModel(config.SmallModel),
 		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
 	}
-	summarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
+	summarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
 	if err != nil {
 		return nil, err
 	}
@@ -277,7 +277,9 @@ func (a *agent) generateTitle(ctx context.Context, sessionID string, content str
 	if err != nil {
 		return err
 	}
-	parts := []message.ContentPart{message.TextContent{Text: content}}
+	parts := []message.ContentPart{message.TextContent{
+		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
+	}}
 
 	// Use streaming approach like summarization
 	response := a.titleProvider.StreamResponse(
@@ -831,7 +833,7 @@ func (a *agent) UpdateModel() error {
 			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
 		}
 
-		newProvider, err := provider.NewProviderV2(currentProviderCfg, opts...)
+		newProvider, err := provider.NewProvider(currentProviderCfg, opts...)
 		if err != nil {
 			return fmt.Errorf("failed to create new provider: %w", err)
 		}
@@ -873,8 +875,10 @@ func (a *agent) UpdateModel() error {
 		titleOpts := []provider.ProviderClientOption{
 			provider.WithModel(config.SmallModel),
 			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
+			// We want the title to be short, so we limit the max tokens
+			provider.WithMaxTokens(40),
 		}
-		newTitleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
+		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
 		if err != nil {
 			return fmt.Errorf("failed to create new title provider: %w", err)
 		}
@@ -884,7 +888,7 @@ func (a *agent) UpdateModel() error {
 			provider.WithModel(config.SmallModel),
 			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
 		}
-		newSummarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
+		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
 		if err != nil {
 			return fmt.Errorf("failed to create new summarize provider: %w", err)
 		}

internal/llm/provider/anthropic.go

@@ -174,6 +174,11 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
 		maxTokens = modelConfig.MaxTokens
 	}
 
+	// Override max tokens if set in provider options
+	if a.providerOptions.maxTokens > 0 {
+		maxTokens = a.providerOptions.maxTokens
+	}
+
 	return anthropic.MessageNewParams{
 		Model:       anthropic.Model(model.ID),
 		MaxTokens:   maxTokens,

internal/llm/provider/gemini.go

@@ -268,6 +268,11 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
 	if modelConfig.MaxTokens > 0 {
 		maxTokens = modelConfig.MaxTokens
 	}
+
+	// Override max tokens if set in provider options
+	if g.providerOptions.maxTokens > 0 {
+		maxTokens = g.providerOptions.maxTokens
+	}
 	history := geminiMessages[:len(geminiMessages)-1] // All but last message
 	lastMsg := geminiMessages[len(geminiMessages)-1]
 	config := &genai.GenerateContentConfig{

internal/llm/provider/openai.go

@@ -165,6 +165,11 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar
 	if modelConfig.MaxTokens > 0 {
 		maxTokens = modelConfig.MaxTokens
 	}
+
+	// Override max tokens if set in provider options
+	if o.providerOptions.maxTokens > 0 {
+		maxTokens = o.providerOptions.maxTokens
+	}
 	if model.CanReason {
 		params.MaxCompletionTokens = openai.Int(maxTokens)
 		switch reasoningEffort {

internal/llm/provider/provider.go

@@ -65,6 +65,7 @@ type providerClientOptions struct {
 	model         func(config.ModelType) config.Model
 	disableCache  bool
 	systemMessage string
+	maxTokens     int64
 	extraHeaders  map[string]string
 	extraParams   map[string]string
 }
@@ -126,7 +127,13 @@ func WithSystemMessage(systemMessage string) ProviderClientOption {
 	}
 }
 
-func NewProviderV2(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
+func WithMaxTokens(maxTokens int64) ProviderClientOption {
+	return func(options *providerClientOptions) {
+		options.maxTokens = maxTokens
+	}
+}
+
+func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
 	clientOptions := providerClientOptions{
 		baseURL:      cfg.BaseURL,
 		apiKey:       cfg.APIKey,
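
With the new WithMaxTokens option, callers can cap the completion length per provider instance, as the title provider above now does. A usage sketch mirroring that call site (error handling belongs to the surrounding function):

	titleOpts := []provider.ProviderClientOption{
		provider.WithModel(config.SmallModel),
		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
		// Keep generated titles short by capping the completion length.
		provider.WithMaxTokens(40),
	}
	titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
	if err != nil {
		return fmt.Errorf("failed to create title provider: %w", err)
	}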