cleanup retry

Kujtim Hoxha created

Change summary

internal/llm/provider/anthropic.go | 117 ++++++++++++-------------------
internal/llm/tools/fetch.go        |   6 -
2 files changed, 45 insertions(+), 78 deletions(-)

Detailed changes

internal/llm/provider/anthropic.go 🔗

@@ -135,40 +135,9 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa
 		attempts := 0
 
 		for {
-			// If this isn't the first attempt, we're retrying
-			if attempts > 0 {
-				if attempts > maxRetries {
-					eventChan <- ProviderEvent{
-						Type:  EventError,
-						Error: errors.New("maximum retry attempts reached for rate limit (429)"),
-					}
-					return
-				}
-
-				// Inform user we're retrying with attempt number
-				eventChan <- ProviderEvent{
-					Type: EventWarning,
-					Info: fmt.Sprintf("[Retrying due to rate limit... attempt %d of %d]", attempts, maxRetries),
-				}
-
-				// Calculate backoff with exponential backoff and jitter
-				backoffMs := 2000 * (1 << (attempts - 1)) // 2s, 4s, 8s, 16s, 32s
-				jitterMs := int(float64(backoffMs) * 0.2)
-				totalBackoffMs := backoffMs + jitterMs
-
-				// Sleep with backoff, respecting context cancellation
-				select {
-				case <-ctx.Done():
-					eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
-					return
-				case <-time.After(time.Duration(totalBackoffMs) * time.Millisecond):
-					// Continue with retry
-				}
-			}
 
 			attempts++
 
-			// Create new streaming request
 			stream := a.client.Messages.NewStreaming(
 				ctx,
 				anthropic.MessageNewParams{
@@ -189,11 +158,8 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa
 				},
 			)
 
-			// Process stream events
 			accumulatedMessage := anthropic.Message{}
-			streamSuccess := false
 
-			// Process the stream until completion or error
 			for stream.Next() {
 				event := stream.Current()
 				err := accumulatedMessage.Accumulate(event)
@@ -223,7 +189,6 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa
 					eventChan <- ProviderEvent{Type: EventContentStop}
 
 				case anthropic.MessageStopEvent:
-					streamSuccess = true
 					content := ""
 					for _, block := range accumulatedMessage.Content {
 						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
@@ -246,51 +211,59 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa
 				}
 			}
 
-			// If the stream completed successfully, we're done
-			if streamSuccess {
+			err := stream.Err()
+			if err == nil {
 				return
 			}
 
-			// Check for stream errors
-			err := stream.Err()
-			if err != nil {
-				var apierr *anthropic.Error
-				if errors.As(err, &apierr) {
-					if apierr.StatusCode == 429 || apierr.StatusCode == 529 {
-						// Check for Retry-After header
-						if retryAfterValues := apierr.Response.Header.Values("Retry-After"); len(retryAfterValues) > 0 {
-							// Parse the retry after value (seconds)
-							var retryAfterSec int
-							if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryAfterSec); err == nil {
-								retryMs := retryAfterSec * 1000
-
-								// Inform user of retry with specific wait time
-								eventChan <- ProviderEvent{
-									Type: EventWarning,
-									Info: fmt.Sprintf("[Rate limited: waiting %d seconds as specified by API]", retryAfterSec),
-								}
-
-								// Sleep respecting context cancellation
-								select {
-								case <-ctx.Done():
-									eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
-									return
-								case <-time.After(time.Duration(retryMs) * time.Millisecond):
-									// Continue with retry after specified delay
-									continue
-								}
-							}
-						}
+			var apierr *anthropic.Error
+			if !errors.As(err, &apierr) {
+				eventChan <- ProviderEvent{Type: EventError, Error: err}
+				return
+			}
 
-						// Fall back to exponential backoff if Retry-After parsing failed
-						continue
+			if apierr.StatusCode != 429 && apierr.StatusCode != 529 {
+				eventChan <- ProviderEvent{Type: EventError, Error: err}
+				return
+			}
+
+			if attempts > maxRetries {
+				eventChan <- ProviderEvent{
+					Type:  EventError,
+					Error: errors.New("maximum retry attempts reached for rate limit (429)"),
+				}
+				return
+			}
+
+			retryMs := 0
+			retryAfterValues := apierr.Response.Header.Values("Retry-After")
+			if len(retryAfterValues) > 0 {
+				var retryAfterSec int
+				if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryAfterSec); err == nil {
+					retryMs = retryAfterSec * 1000
+					eventChan <- ProviderEvent{
+						Type: EventWarning,
+						Info: fmt.Sprintf("[Rate limited: waiting %d seconds as specified by API]", retryAfterSec),
 					}
 				}
+			} else {
+				eventChan <- ProviderEvent{
+					Type: EventWarning,
+					Info: fmt.Sprintf("[Retrying due to rate limit... attempt %d of %d]", attempts, maxRetries),
+				}
 
-				// For non-rate limit errors, report and exit
-				eventChan <- ProviderEvent{Type: EventError, Error: err}
+				backoffMs := 2000 * (1 << (attempts - 1))
+				jitterMs := int(float64(backoffMs) * 0.2)
+				retryMs = backoffMs + jitterMs
+			}
+			select {
+			case <-ctx.Done():
+				eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
 				return
+			case <-time.After(time.Duration(retryMs) * time.Millisecond):
+				continue
 			}
+
 		}
 	}()
 
@@ -388,7 +361,6 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
 				blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
 			}
 
-			// Skip empty assistant messages completely
 			if len(blocks) > 0 {
 				anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
 			}
@@ -404,4 +376,3 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
 
 	return anthropicMessages
 }
-

internal/llm/tools/fetch.go 🔗

@@ -121,11 +121,7 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
 			ToolName:    FetchToolName,
 			Action:      "fetch",
 			Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
-			Params: FetchPermissionsParams{
-				URL:     params.URL,
-				Format:  params.Format,
-				Timeout: params.Timeout,
-			},
+			Params:      FetchPermissionsParams(params),
 		},
 	)