shelley: retry LLM requests on EOF and transient network errors

Created by Philip Zeyliger and Shelley

Prompt: "LLM request failed: EOF" If the LLM fails with EOF, Shelley should retry at least once

When an LLM request fails with EOF or another transient network error
(connection reset, connection refused, timeout, etc.), retry it — up to
2 attempts in total, with a short linearly increasing delay between
them — before recording and returning the error.

This helps handle temporary connection issues that can occur with
long-running LLM requests.

Co-authored-by: Shelley <shelley@exe.dev>

Change summary

loop/loop.go      |  49 ++++++++++++++
loop/loop_test.go | 159 +++++++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 207 insertions(+), 1 deletion(-)

Detailed changes

loop/loop.go 🔗

@@ -3,7 +3,9 @@ package loop
import (
	"context"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"strings"
	"sync"
	"time"
 
@@ -242,7 +244,24 @@ func (l *Loop) processLLMRequest(ctx context.Context) error {
 	llmCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
 	defer cancel()
 
-	resp, err := llmService.Do(llmCtx, req)
+	// Retry LLM requests that fail with retryable errors (EOF, connection reset)
+	const maxRetries = 2
+	var resp *llm.Response
+	var err error
+	for attempt := 1; attempt <= maxRetries; attempt++ {
+		resp, err = llmService.Do(llmCtx, req)
+		if err == nil {
+			break
+		}
+		if !isRetryableError(err) || attempt == maxRetries {
+			break
+		}
+		l.logger.Warn("LLM request failed with retryable error, retrying",
+			"error", err,
+			"attempt", attempt,
+			"max_retries", maxRetries)
+		time.Sleep(time.Second * time.Duration(attempt)) // Simple backoff
+	}
 	if err != nil {
 		// Record the error as a message so it can be displayed in the UI
 		// EndOfTurn must be true so the agent working state is properly updated
@@ -647,3 +666,31 @@ func (l *Loop) insertMissingToolResults(req *llm.Request) {
 		}
 	}
 }
+
+// isRetryableError checks if an error is transient and should be retried.
+// This includes EOF errors (connection closed unexpectedly) and similar network issues.
+func isRetryableError(err error) bool {
+	if err == nil {
+		return false
+	}
+	// Check for io.EOF and io.ErrUnexpectedEOF
+	if err == io.EOF || err == io.ErrUnexpectedEOF {
+		return true
+	}
+	// Check error message for common retryable patterns
+	errStr := err.Error()
+	retryablePatterns := []string{
+		"EOF",
+		"connection reset",
+		"connection refused",
+		"no such host",
+		"network is unreachable",
+		"i/o timeout",
+	}
+	for _, pattern := range retryablePatterns {
+		if strings.Contains(errStr, pattern) {
+			return true
+		}
+	}
+	return false
+}

loop/loop_test.go 🔗

@@ -4,6 +4,7 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"io"
 	"os"
 	"os/exec"
 	"path/filepath"
@@ -1572,6 +1573,164 @@ func (e *errorLLMService) MaxImageDimension() int {
 	return 2000
 }
 
+// retryableLLMService fails with a retryable error a specified number of times, then succeeds
+type retryableLLMService struct {
+	failuresRemaining int
+	callCount         int
+	mu                sync.Mutex
+}
+
+func (r *retryableLLMService) Do(ctx context.Context, req *llm.Request) (*llm.Response, error) {
+	r.mu.Lock()
+	r.callCount++
+	if r.failuresRemaining > 0 {
+		r.failuresRemaining--
+		r.mu.Unlock()
+		return nil, fmt.Errorf("connection error: EOF")
+	}
+	r.mu.Unlock()
+	return &llm.Response{
+		Content: []llm.Content{
+			{Type: llm.ContentTypeText, Text: "Success after retry"},
+		},
+		StopReason: llm.StopReasonEndTurn,
+	}, nil
+}
+
+func (r *retryableLLMService) TokenContextWindow() int {
+	return 200000
+}
+
+func (r *retryableLLMService) MaxImageDimension() int {
+	return 2000
+}
+
+func (r *retryableLLMService) getCallCount() int {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	return r.callCount
+}
+
+func TestLLMRequestRetryOnEOF(t *testing.T) {
+	// Test that LLM requests are retried on EOF errors
+	retryService := &retryableLLMService{failuresRemaining: 1}
+
+	var recordedMessages []llm.Message
+	recordFunc := func(ctx context.Context, message llm.Message, usage llm.Usage) error {
+		recordedMessages = append(recordedMessages, message)
+		return nil
+	}
+
+	loop := NewLoop(Config{
+		LLM:           retryService,
+		History:       []llm.Message{},
+		Tools:         []*llm.Tool{},
+		RecordMessage: recordFunc,
+	})
+
+	// Queue a user message
+	userMessage := llm.Message{
+		Role:    llm.MessageRoleUser,
+		Content: []llm.Content{{Type: llm.ContentTypeText, Text: "test message"}},
+	}
+	loop.QueueUserMessage(userMessage)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	err := loop.ProcessOneTurn(ctx)
+	if err != nil {
+		t.Fatalf("expected no error after retry, got: %v", err)
+	}
+
+	// Should have been called twice (1 failure + 1 success)
+	if retryService.getCallCount() != 2 {
+		t.Errorf("expected 2 LLM calls (retry), got %d", retryService.getCallCount())
+	}
+
+	// Check that success message was recorded
+	if len(recordedMessages) != 1 {
+		t.Fatalf("expected 1 recorded message (success), got %d", len(recordedMessages))
+	}
+
+	if !strings.Contains(recordedMessages[0].Content[0].Text, "Success after retry") {
+		t.Errorf("expected success message, got: %s", recordedMessages[0].Content[0].Text)
+	}
+}
+
+func TestLLMRequestRetryExhausted(t *testing.T) {
+	// Test that after max retries, error is returned
+	retryService := &retryableLLMService{failuresRemaining: 10} // More than maxRetries
+
+	var recordedMessages []llm.Message
+	recordFunc := func(ctx context.Context, message llm.Message, usage llm.Usage) error {
+		recordedMessages = append(recordedMessages, message)
+		return nil
+	}
+
+	loop := NewLoop(Config{
+		LLM:           retryService,
+		History:       []llm.Message{},
+		Tools:         []*llm.Tool{},
+		RecordMessage: recordFunc,
+	})
+
+	userMessage := llm.Message{
+		Role:    llm.MessageRoleUser,
+		Content: []llm.Content{{Type: llm.ContentTypeText, Text: "test message"}},
+	}
+	loop.QueueUserMessage(userMessage)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	err := loop.ProcessOneTurn(ctx)
+	if err == nil {
+		t.Fatal("expected error after exhausting retries")
+	}
+
+	// Should have been called maxRetries times (2)
+	if retryService.getCallCount() != 2 {
+		t.Errorf("expected 2 LLM calls (maxRetries), got %d", retryService.getCallCount())
+	}
+
+	// Check error message was recorded
+	if len(recordedMessages) != 1 {
+		t.Fatalf("expected 1 recorded message (error), got %d", len(recordedMessages))
+	}
+
+	if !strings.Contains(recordedMessages[0].Content[0].Text, "LLM request failed") {
+		t.Errorf("expected error message, got: %s", recordedMessages[0].Content[0].Text)
+	}
+}
+
+func TestIsRetryableError(t *testing.T) {
+	tests := []struct {
+		name      string
+		err       error
+		retryable bool
+	}{
+		{"nil error", nil, false},
+		{"io.EOF", io.EOF, true},
+		{"io.ErrUnexpectedEOF", io.ErrUnexpectedEOF, true},
+		{"EOF error string", fmt.Errorf("EOF"), true},
+		{"wrapped EOF", fmt.Errorf("connection error: EOF"), true},
+		{"connection reset", fmt.Errorf("connection reset by peer"), true},
+		{"connection refused", fmt.Errorf("connection refused"), true},
+		{"timeout", fmt.Errorf("i/o timeout"), true},
+		{"api error", fmt.Errorf("rate limit exceeded"), false},
+		{"generic error", fmt.Errorf("something went wrong"), false},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := isRetryableError(tt.err); got != tt.retryable {
+				t.Errorf("isRetryableError(%v) = %v, want %v", tt.err, got, tt.retryable)
+			}
+		})
+	}
+}
+
 func TestCheckGitStateChange(t *testing.T) {
 	// Create a test repo
 	tmpDir := t.TempDir()