@@ -3,7 +3,9 @@ package loop
import (
"context"
"fmt"
+ "io"
"log/slog"
+ "strings"
"sync"
"time"
@@ -242,7 +244,24 @@ func (l *Loop) processLLMRequest(ctx context.Context) error {
llmCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
defer cancel()
- resp, err := llmService.Do(llmCtx, req)
+ // Retry LLM requests that fail with retryable errors (EOF, connection reset)
+ const maxRetries = 2
+ var resp *llm.Response
+ var err error
+ for attempt := 1; attempt <= maxRetries; attempt++ {
+ resp, err = llmService.Do(llmCtx, req)
+ if err == nil {
+ break
+ }
+ if !isRetryableError(err) || attempt == maxRetries {
+ break
+ }
+ l.logger.Warn("LLM request failed with retryable error, retrying",
+ "error", err,
+ "attempt", attempt,
+ "max_retries", maxRetries)
+ time.Sleep(time.Second * time.Duration(attempt)) // Simple backoff
+ }
if err != nil {
// Record the error as a message so it can be displayed in the UI
// EndOfTurn must be true so the agent working state is properly updated
@@ -647,3 +666,31 @@ func (l *Loop) insertMissingToolResults(req *llm.Request) {
}
}
}
+
+// isRetryableError checks if an error is transient and should be retried.
+// This includes EOF errors (connection closed unexpectedly) and similar network issues.
+func isRetryableError(err error) bool {
+ if err == nil {
+ return false
+ }
+ // Check for io.EOF and io.ErrUnexpectedEOF
+ if err == io.EOF || err == io.ErrUnexpectedEOF {
+ return true
+ }
+ // Check error message for common retryable patterns
+ errStr := err.Error()
+ retryablePatterns := []string{
+ "EOF",
+ "connection reset",
+ "connection refused",
+ "no such host",
+ "network is unreachable",
+ "i/o timeout",
+ }
+ for _, pattern := range retryablePatterns {
+ if strings.Contains(errStr, pattern) {
+ return true
+ }
+ }
+ return false
+}
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
+ "io"
"os"
"os/exec"
"path/filepath"
@@ -1572,6 +1573,164 @@ func (e *errorLLMService) MaxImageDimension() int {
return 2000
}
+// retryableLLMService fails with a retryable error a specified number of times, then succeeds
+type retryableLLMService struct {
+ failuresRemaining int
+ callCount int
+ mu sync.Mutex
+}
+
+func (r *retryableLLMService) Do(ctx context.Context, req *llm.Request) (*llm.Response, error) {
+ r.mu.Lock()
+ r.callCount++
+ if r.failuresRemaining > 0 {
+ r.failuresRemaining--
+ r.mu.Unlock()
+ return nil, fmt.Errorf("connection error: EOF")
+ }
+ r.mu.Unlock()
+ return &llm.Response{
+ Content: []llm.Content{
+ {Type: llm.ContentTypeText, Text: "Success after retry"},
+ },
+ StopReason: llm.StopReasonEndTurn,
+ }, nil
+}
+
+func (r *retryableLLMService) TokenContextWindow() int {
+ return 200000
+}
+
+func (r *retryableLLMService) MaxImageDimension() int {
+ return 2000
+}
+
+func (r *retryableLLMService) getCallCount() int {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ return r.callCount
+}
+
+func TestLLMRequestRetryOnEOF(t *testing.T) {
+ // Test that LLM requests are retried on EOF errors
+ retryService := &retryableLLMService{failuresRemaining: 1}
+
+ var recordedMessages []llm.Message
+ recordFunc := func(ctx context.Context, message llm.Message, usage llm.Usage) error {
+ recordedMessages = append(recordedMessages, message)
+ return nil
+ }
+
+ loop := NewLoop(Config{
+ LLM: retryService,
+ History: []llm.Message{},
+ Tools: []*llm.Tool{},
+ RecordMessage: recordFunc,
+ })
+
+ // Queue a user message
+ userMessage := llm.Message{
+ Role: llm.MessageRoleUser,
+ Content: []llm.Content{{Type: llm.ContentTypeText, Text: "test message"}},
+ }
+ loop.QueueUserMessage(userMessage)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ err := loop.ProcessOneTurn(ctx)
+ if err != nil {
+ t.Fatalf("expected no error after retry, got: %v", err)
+ }
+
+ // Should have been called twice (1 failure + 1 success)
+ if retryService.getCallCount() != 2 {
+ t.Errorf("expected 2 LLM calls (retry), got %d", retryService.getCallCount())
+ }
+
+ // Check that success message was recorded
+ if len(recordedMessages) != 1 {
+ t.Fatalf("expected 1 recorded message (success), got %d", len(recordedMessages))
+ }
+
+ if !strings.Contains(recordedMessages[0].Content[0].Text, "Success after retry") {
+ t.Errorf("expected success message, got: %s", recordedMessages[0].Content[0].Text)
+ }
+}
+
+func TestLLMRequestRetryExhausted(t *testing.T) {
+ // Test that after max retries, error is returned
+ retryService := &retryableLLMService{failuresRemaining: 10} // More than maxRetries
+
+ var recordedMessages []llm.Message
+ recordFunc := func(ctx context.Context, message llm.Message, usage llm.Usage) error {
+ recordedMessages = append(recordedMessages, message)
+ return nil
+ }
+
+ loop := NewLoop(Config{
+ LLM: retryService,
+ History: []llm.Message{},
+ Tools: []*llm.Tool{},
+ RecordMessage: recordFunc,
+ })
+
+ userMessage := llm.Message{
+ Role: llm.MessageRoleUser,
+ Content: []llm.Content{{Type: llm.ContentTypeText, Text: "test message"}},
+ }
+ loop.QueueUserMessage(userMessage)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ err := loop.ProcessOneTurn(ctx)
+ if err == nil {
+ t.Fatal("expected error after exhausting retries")
+ }
+
+ // Should have been called maxRetries times (2)
+ if retryService.getCallCount() != 2 {
+ t.Errorf("expected 2 LLM calls (maxRetries), got %d", retryService.getCallCount())
+ }
+
+ // Check error message was recorded
+ if len(recordedMessages) != 1 {
+ t.Fatalf("expected 1 recorded message (error), got %d", len(recordedMessages))
+ }
+
+ if !strings.Contains(recordedMessages[0].Content[0].Text, "LLM request failed") {
+ t.Errorf("expected error message, got: %s", recordedMessages[0].Content[0].Text)
+ }
+}
+
+func TestIsRetryableError(t *testing.T) {
+ tests := []struct {
+ name string
+ err error
+ retryable bool
+ }{
+ {"nil error", nil, false},
+ {"io.EOF", io.EOF, true},
+ {"io.ErrUnexpectedEOF", io.ErrUnexpectedEOF, true},
+ {"EOF error string", fmt.Errorf("EOF"), true},
+ {"wrapped EOF", fmt.Errorf("connection error: EOF"), true},
+ {"connection reset", fmt.Errorf("connection reset by peer"), true},
+ {"connection refused", fmt.Errorf("connection refused"), true},
+ {"timeout", fmt.Errorf("i/o timeout"), true},
+ {"api error", fmt.Errorf("rate limit exceeded"), false},
+ {"generic error", fmt.Errorf("something went wrong"), false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := isRetryableError(tt.err); got != tt.retryable {
+ t.Errorf("isRetryableError(%v) = %v, want %v", tt.err, got, tt.retryable)
+ }
+ })
+ }
+}
+
func TestCheckGitStateChange(t *testing.T) {
// Create a test repo
tmpDir := t.TempDir()