fix: handle anthropic 429s

adamdottv created

Change summary

internal/llm/provider/anthropic.go | 186 +++++++++++++++++++++----------
1 file changed, 124 insertions(+), 62 deletions(-)

Detailed changes

internal/llm/provider/anthropic.go 🔗

@@ -4,7 +4,10 @@ import (
 	"context"
 	"encoding/json"
 	"errors"
+	"fmt"
+	"log"
 	"strings"
+	"time"
 
 	"github.com/anthropics/anthropic-sdk-go"
 	"github.com/anthropics/anthropic-sdk-go/option"
@@ -125,87 +128,145 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa
 		temperature = anthropic.Float(1)
 	}
 
-	stream := a.client.Messages.NewStreaming(
-		ctx,
-		anthropic.MessageNewParams{
-			Model:       anthropic.Model(a.model.APIModel),
-			MaxTokens:   a.maxTokens,
-			Temperature: temperature,
-			Messages:    anthropicMessages,
-			Tools:       anthropicTools,
-			Thinking:    thinkingParam,
-			System: []anthropic.TextBlockParam{
-				{
-					Text: a.systemMessage,
-					CacheControl: anthropic.CacheControlEphemeralParam{
-						Type: "ephemeral",
-					},
-				},
-			},
-		},
-		option.WithMaxRetries(8),
-	)
-
 	eventChan := make(chan ProviderEvent)
 
 	go func() {
 		defer close(eventChan)
 
-		accumulatedMessage := anthropic.Message{}
+		const maxRetries = 8
+		attempts := 0
 
-		for stream.Next() {
-			event := stream.Current()
-			err := accumulatedMessage.Accumulate(event)
-			if err != nil {
-				eventChan <- ProviderEvent{Type: EventError, Error: err}
-				return
+		for {
+			// If this isn't the first attempt, we're retrying
+			if attempts > 0 {
+				if attempts > maxRetries {
+					eventChan <- ProviderEvent{
+						Type:  EventError,
+						Error: errors.New("maximum retry attempts reached for rate limit (429)"),
+					}
+					return
+				}
+
+				// Inform user we're retrying with attempt number
+				eventChan <- ProviderEvent{
+					Type:    EventContentDelta,
+					Content: fmt.Sprintf("\n\n[Retrying due to rate limit... attempt %d of %d]\n\n", attempts, maxRetries),
+				}
+
+				// Exponential backoff (base 2s, doubling each attempt) plus a fixed 20% padding — note: deterministic, not randomized jitter
+				backoffMs := 2000 * (1 << (attempts - 1)) // 2s, 4s, 8s, ..., up to 256s on the 8th retry
+				jitterMs := int(float64(backoffMs) * 0.2)
+				totalBackoffMs := backoffMs + jitterMs
+
+				// Sleep with backoff, respecting context cancellation
+				select {
+				case <-ctx.Done():
+					eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
+					return
+				case <-time.After(time.Duration(totalBackoffMs) * time.Millisecond):
+					// Continue with retry
+				}
 			}
 
-			switch event := event.AsAny().(type) {
-			case anthropic.ContentBlockStartEvent:
-				eventChan <- ProviderEvent{Type: EventContentStart}
+			attempts++
+
+			// Create new streaming request
+			stream := a.client.Messages.NewStreaming(
+				ctx,
+				anthropic.MessageNewParams{
+					Model:       anthropic.Model(a.model.APIModel),
+					MaxTokens:   a.maxTokens,
+					Temperature: temperature,
+					Messages:    anthropicMessages,
+					Tools:       anthropicTools,
+					Thinking:    thinkingParam,
+					System: []anthropic.TextBlockParam{
+						{
+							Text: a.systemMessage,
+							CacheControl: anthropic.CacheControlEphemeralParam{
+								Type: "ephemeral",
+							},
+						},
+					},
+				},
+			)
 
-			case anthropic.ContentBlockDeltaEvent:
-				if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
-					eventChan <- ProviderEvent{
-						Type:     EventThinkingDelta,
-						Thinking: event.Delta.Thinking,
+			// Process stream events
+			accumulatedMessage := anthropic.Message{}
+			streamSuccess := false
+
+			// Process the stream until completion or error
+			for stream.Next() {
+				event := stream.Current()
+				err := accumulatedMessage.Accumulate(event)
+				if err != nil {
+					eventChan <- ProviderEvent{Type: EventError, Error: err}
+					return // Don't retry on accumulation errors
+				}
+
+				switch event := event.AsAny().(type) {
+				case anthropic.ContentBlockStartEvent:
+					eventChan <- ProviderEvent{Type: EventContentStart}
+
+				case anthropic.ContentBlockDeltaEvent:
+					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
+						eventChan <- ProviderEvent{
+							Type:     EventThinkingDelta,
+							Thinking: event.Delta.Thinking,
+						}
+					} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
+						eventChan <- ProviderEvent{
+							Type:    EventContentDelta,
+							Content: event.Delta.Text,
+						}
 					}
-				} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
-					eventChan <- ProviderEvent{
-						Type:    EventContentDelta,
-						Content: event.Delta.Text,
+
+				case anthropic.ContentBlockStopEvent:
+					eventChan <- ProviderEvent{Type: EventContentStop}
+
+				case anthropic.MessageStopEvent:
+					streamSuccess = true
+					content := ""
+					for _, block := range accumulatedMessage.Content {
+						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
+							content += text.Text
+						}
 					}
-				}
 
-			case anthropic.ContentBlockStopEvent:
-				eventChan <- ProviderEvent{Type: EventContentStop}
+					toolCalls := a.extractToolCalls(accumulatedMessage.Content)
+					tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)
 
-			case anthropic.MessageStopEvent:
-				content := ""
-				for _, block := range accumulatedMessage.Content {
-					if text, ok := block.AsAny().(anthropic.TextBlock); ok {
-						content += text.Text
+					eventChan <- ProviderEvent{
+						Type: EventComplete,
+						Response: &ProviderResponse{
+							Content:      content,
+							ToolCalls:    toolCalls,
+							Usage:        tokenUsage,
+							FinishReason: string(accumulatedMessage.StopReason),
+						},
 					}
 				}
+			}
 
-				toolCalls := a.extractToolCalls(accumulatedMessage.Content)
-				tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)
+			// If the stream completed successfully, we're done
+			if streamSuccess {
+				return
+			}
 
-				eventChan <- ProviderEvent{
-					Type: EventComplete,
-					Response: &ProviderResponse{
-						Content:      content,
-						ToolCalls:    toolCalls,
-						Usage:        tokenUsage,
-						FinishReason: string(accumulatedMessage.StopReason),
-					},
+			// Check for stream errors
+			err := stream.Err()
+			if err != nil {
+				log.Println("error", err)
+
+				var apierr *anthropic.Error
+				if errors.As(err, &apierr) && apierr.StatusCode == 429 {
+					continue
 				}
-			}
-		}
 
-		if stream.Err() != nil {
-			eventChan <- ProviderEvent{Type: EventError, Error: stream.Err()}
+				// For non-rate limit errors, report and exit
+				eventChan <- ProviderEvent{Type: EventError, Error: err}
+				return
+			}
 		}
 	}()
 
@@ -319,3 +380,4 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
 
 	return anthropicMessages
 }
+