chore: change how we handle max tokens for anthropic

Created by Kujtim Hoxha

Change summary

internal/config/config.go          | 38 ----------------
internal/llm/agent/agent.go        |  5 --
internal/llm/provider/anthropic.go | 74 ++++++++++++++++++++++++++-----
internal/tui/tui.go                | 20 -------
4 files changed, 63 insertions(+), 74 deletions(-)

Detailed changes

internal/config/config.go 🔗

@@ -890,44 +890,6 @@ func GetAgentModel(agentID AgentID) Model {
 	return Model{}
 }
 
-// GetAgentEffectiveMaxTokens returns the effective max tokens for an agent,
-// considering any overrides from the preferred model configuration
-func GetAgentEffectiveMaxTokens(agentID AgentID) int64 {
-	cfg := Get()
-	agent, ok := cfg.Agents[agentID]
-	if !ok {
-		logging.Error("Agent not found", "agent_id", agentID)
-		return 0
-	}
-
-	var preferredModel PreferredModel
-	switch agent.Model {
-	case LargeModel:
-		preferredModel = cfg.Models.Large
-	case SmallModel:
-		preferredModel = cfg.Models.Small
-	default:
-		logging.Warn("Unknown model type for agent", "agent_id", agentID, "model_type", agent.Model)
-		preferredModel = cfg.Models.Large // Fallback to large model
-	}
-
-	// Get the base model configuration
-	baseModel := GetAgentModel(agentID)
-	if baseModel.ID == "" {
-		return 0
-	}
-
-	// Start with the default max tokens from the base model
-	maxTokens := baseModel.DefaultMaxTokens
-
-	// Override with preferred model max tokens if set
-	if preferredModel.MaxTokens > 0 {
-		maxTokens = preferredModel.MaxTokens
-	}
-
-	return maxTokens
-}
-
 func GetAgentProvider(agentID AgentID) ProviderConfig {
 	cfg := Get()
 	agent, ok := cfg.Agents[agentID]

internal/llm/agent/agent.go 🔗

@@ -50,7 +50,6 @@ type AgentEvent struct {
 type Service interface {
 	pubsub.Suscriber[AgentEvent]
 	Model() config.Model
-	EffectiveMaxTokens() int64
 	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
 	Cancel(sessionID string)
 	CancelAll()
@@ -230,10 +229,6 @@ func (a *agent) Model() config.Model {
 	return config.GetAgentModel(a.agentCfg.ID)
 }
 
-func (a *agent) EffectiveMaxTokens() int64 {
-	return config.GetAgentEffectiveMaxTokens(a.agentCfg.ID)
-}
-
 func (a *agent) Cancel(sessionID string) {
 	// Cancel regular requests
 	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {

internal/llm/provider/anthropic.go 🔗

@@ -6,6 +6,8 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"regexp"
+	"strconv"
 	"time"
 
 	"github.com/anthropics/anthropic-sdk-go"
@@ -19,9 +21,10 @@ import (
 )
 
 type anthropicClient struct {
-	providerOptions providerClientOptions
-	useBedrock      bool
-	client          anthropic.Client
+	providerOptions   providerClientOptions
+	useBedrock        bool
+	client            anthropic.Client
+	adjustedMaxTokens int // Used when context limit is hit
 }
 
 type AnthropicClient ProviderClient
@@ -171,6 +174,11 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
 		maxTokens = a.providerOptions.maxTokens
 	}
 
+	// Use adjusted max tokens if context limit was hit
+	if a.adjustedMaxTokens > 0 {
+		maxTokens = int64(a.adjustedMaxTokens)
+	}
+
 	return anthropic.MessageNewParams{
 		Model:       anthropic.Model(model.ID),
 		MaxTokens:   maxTokens,
@@ -190,16 +198,18 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
 }
 
 func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
-	preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
 	cfg := config.Get()
-	if cfg.Options.Debug {
-		jsonData, _ := json.Marshal(preparedMessages)
-		logging.Debug("Prepared messages", "messages", string(jsonData))
-	}
 
 	attempts := 0
 	for {
 		attempts++
+		// Prepare messages on each attempt in case max_tokens was adjusted
+		preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
+		if cfg.Options.Debug {
+			jsonData, _ := json.Marshal(preparedMessages)
+			logging.Debug("Prepared messages", "messages", string(jsonData))
+		}
+
 		anthropicResponse, err := a.client.Messages.New(
 			ctx,
 			preparedMessages,
@@ -239,17 +249,19 @@ func (a *anthropicClient) send(ctx context.Context, messages []message.Message,
 }
 
 func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
-	preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
 	cfg := config.Get()
-	if cfg.Options.Debug {
-		// jsonData, _ := json.Marshal(preparedMessages)
-		// logging.Debug("Prepared messages", "messages", string(jsonData))
-	}
 	attempts := 0
 	eventChan := make(chan ProviderEvent)
 	go func() {
 		for {
 			attempts++
+			// Prepare messages on each attempt in case max_tokens was adjusted
+			preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
+			if cfg.Options.Debug {
+				jsonData, _ := json.Marshal(preparedMessages)
+				logging.Debug("Prepared messages", "messages", string(jsonData))
+			}
+
 			anthropicStream := a.client.Messages.NewStreaming(
 				ctx,
 				preparedMessages,
@@ -395,6 +407,15 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err
 		return true, 0, nil
 	}
 
+	// Handle context limit exceeded error (400 Bad Request)
+	if apiErr.StatusCode == 400 {
+		if adjusted, ok := a.handleContextLimitError(apiErr); ok {
+			a.adjustedMaxTokens = adjusted
+			logging.Debug("Adjusted max_tokens due to context limit", "new_max_tokens", adjusted)
+			return true, 0, nil
+		}
+	}
+
 	if apiErr.StatusCode != 429 && apiErr.StatusCode != 529 {
 		return false, 0, err
 	}
@@ -413,6 +434,33 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err
 	return true, int64(retryMs), nil
 }
 
+// handleContextLimitError parses context limit error and returns adjusted max_tokens
+func (a *anthropicClient) handleContextLimitError(apiErr *anthropic.Error) (int, bool) {
+	// Parse error message like: "input length and max_tokens exceed context limit: 154978 + 50000 > 200000"
+	errorMsg := apiErr.Error()
+	re := regexp.MustCompile(`input length and max_tokens exceed context limit: (\d+) \+ (\d+) > (\d+)`)
+	matches := re.FindStringSubmatch(errorMsg)
+
+	if len(matches) != 4 {
+		return 0, false
+	}
+
+	inputTokens, err1 := strconv.Atoi(matches[1])
+	contextLimit, err2 := strconv.Atoi(matches[3])
+
+	if err1 != nil || err2 != nil {
+		return 0, false
+	}
+
+	// Calculate safe max_tokens with a buffer of 1000 tokens
+	safeMaxTokens := contextLimit - inputTokens - 1000
+
+	// Ensure we don't go below a minimum threshold
+	safeMaxTokens = max(safeMaxTokens, 1000)
+
+	return safeMaxTokens, true
+}
+
 func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall {
 	var toolCalls []message.ToolCall
 

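For reference, the arithmetic behind the new retry path can be traced with a standalone sketch. The `parseContextLimit` helper below is hypothetical (only `handleContextLimitError` exists in the diff); it mirrors that function's regex and buffer math and replays the example message from the code comment: max_tokens is reduced to `contextLimit - inputTokens - 1000`, clamped to at least 1000.

```go
// Hypothetical standalone sketch of the context-limit parsing logic above.
package main

import (
	"fmt"
	"regexp"
	"strconv"
)

// parseContextLimit mirrors handleContextLimitError: it pulls the input length
// and the context limit out of the API error text and returns a reduced
// max_tokens with a 1000-token safety buffer, never going below 1000.
func parseContextLimit(errorMsg string) (int, bool) {
	re := regexp.MustCompile(`input length and max_tokens exceed context limit: (\d+) \+ (\d+) > (\d+)`)
	matches := re.FindStringSubmatch(errorMsg)
	if len(matches) != 4 {
		return 0, false
	}
	inputTokens, err1 := strconv.Atoi(matches[1])
	contextLimit, err2 := strconv.Atoi(matches[3])
	if err1 != nil || err2 != nil {
		return 0, false
	}
	safe := contextLimit - inputTokens - 1000
	return max(safe, 1000), true
}

func main() {
	// Example message from the code comment in the diff.
	msg := "input length and max_tokens exceed context limit: 154978 + 50000 > 200000"
	adjusted, ok := parseContextLimit(msg)
	fmt.Println(adjusted, ok) // 44022 true: 200000 - 154978 - 1000
}
```

On the next attempt, preparedMessages sees `a.adjustedMaxTokens > 0` and sends the reduced value, so the retried request fits inside the reported context limit.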
internal/tui/tui.go 🔗

@@ -228,24 +228,8 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 			if err == nil {
 				model := a.app.CoderAgent.Model()
 				contextWindow := model.ContextWindow
-				usedTokens := session.CompletionTokens + session.PromptTokens
-				remainingTokens := contextWindow - usedTokens
-
-				// Get effective max tokens for this agent (considering overrides)
-				maxTokens := a.app.CoderAgent.EffectiveMaxTokens()
-
-				// Apply 10% margin to max tokens
-				maxTokensWithMargin := int64(float64(maxTokens) * 1.1)
-
-				// Trigger auto-summarize if remaining tokens < max tokens + 10% margin
-				// Also ensure we have a reasonable minimum threshold to avoid too-frequent summaries
-				minThreshold := int64(1000) // Minimum 1000 tokens remaining before triggering
-				if maxTokensWithMargin < minThreshold {
-					maxTokensWithMargin = minThreshold
-				}
-
-				if remainingTokens < maxTokensWithMargin && !config.Get().Options.DisableAutoSummarize {
-					// Show compact confirmation dialog
+				tokens := session.CompletionTokens + session.PromptTokens
+				if (tokens >= int64(float64(contextWindow)*0.95)) && !config.Get().Options.DisableAutoSummarize { // Show compact confirmation dialog
 					cmds = append(cmds, util.CmdHandler(dialogs.OpenDialogMsg{
 						Model: compact.NewCompactDialogCmp(a.app.CoderAgent, a.selectedSessionID, false),
 					}))
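The replacement check in the TUI is a plain percentage threshold: the compact dialog is offered once the session's prompt plus completion tokens reach 95% of the model's context window, unless auto-summarize is disabled. A minimal sketch of the arithmetic, with assumed values:

```go
// Minimal sketch of the new auto-summarize trigger, with assumed values.
package main

import "fmt"

func main() {
	contextWindow := int64(200_000) // assumed context window for the selected model
	tokens := int64(191_500)        // session.PromptTokens + session.CompletionTokens (assumed)

	// 95% of 200,000 is 190,000, so this session would trigger the compact dialog.
	if tokens >= int64(float64(contextWindow)*0.95) {
		fmt.Println("show compact confirmation dialog")
	}
}
```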