diff --git a/loop/loop.go b/loop/loop.go
index 140b484b55b8b1f790b096c87445ba143b142f2d..acbc517440d509b168aa6df2a8565333403256d4 100644
--- a/loop/loop.go
+++ b/loop/loop.go
@@ -3,7 +3,10 @@ package loop
 import (
 	"context"
+	"errors"
 	"fmt"
+	"io"
 	"log/slog"
+	"strings"
 	"sync"
 	"time"
 
@@ -242,7 +245,28 @@ func (l *Loop) processLLMRequest(ctx context.Context) error {
 	llmCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
 	defer cancel()
 
-	resp, err := llmService.Do(llmCtx, req)
+	// Retry LLM requests that fail with retryable errors (EOF, connection reset)
+	const maxRetries = 2
+	var resp *llm.Response
+	var err error
+	for attempt := 1; attempt <= maxRetries; attempt++ {
+		resp, err = llmService.Do(llmCtx, req)
+		if err == nil {
+			break
+		}
+		if !isRetryableError(err) || attempt == maxRetries {
+			break
+		}
+		l.logger.Warn("LLM request failed with retryable error, retrying",
+			"error", err,
+			"attempt", attempt,
+			"max_retries", maxRetries)
+		// Simple linear backoff; wake early if the request context is cancelled.
+		select {
+		case <-llmCtx.Done():
+		case <-time.After(time.Second * time.Duration(attempt)):
+		}
+	}
 	if err != nil {
 		// Record the error as a message so it can be displayed in the UI
 		// EndOfTurn must be true so the agent working state is properly updated
@@ -647,3 +671,31 @@ func (l *Loop) insertMissingToolResults(req *llm.Request) {
 		}
 	}
 }
+
+// isRetryableError reports whether err is transient and should be retried.
+// This includes EOF errors (connection closed unexpectedly) and similar network issues.
+func isRetryableError(err error) bool {
+	if err == nil {
+		return false
+	}
+	// Use errors.Is so wrapped io.EOF / io.ErrUnexpectedEOF are matched too.
+	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
+		return true
+	}
+	// Check error message for common retryable patterns
+	errStr := err.Error()
+	retryablePatterns := []string{
+		"EOF",
+		"connection reset",
+		"connection refused",
+		"no such host",
+		"network is unreachable",
+		"i/o timeout",
+	}
+	for _, pattern := range retryablePatterns {
+		if strings.Contains(errStr, pattern) {
+			return true
+		}
+	}
+	return false
+}
diff --git a/loop/loop_test.go b/loop/loop_test.go
index 692783bea4722f3c9e48ae5736c2db6312386b05..86f7d9c856076cc94db6f6db943e99cd8916639d 100644
--- a/loop/loop_test.go
+++ b/loop/loop_test.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"io"
 	"os"
 	"os/exec"
 	"path/filepath"
@@ -1572,6 +1573,164 @@ func (e *errorLLMService) MaxImageDimension() int {
 	return 2000
 }
 
+// retryableLLMService fails with a retryable error a specified number of times, then succeeds
+type retryableLLMService struct {
+	failuresRemaining int
+	callCount         int
+	mu                sync.Mutex
+}
+
+func (r *retryableLLMService) Do(ctx context.Context, req *llm.Request) (*llm.Response, error) {
+	r.mu.Lock()
+	r.callCount++
+	if r.failuresRemaining > 0 {
+		r.failuresRemaining--
+		r.mu.Unlock()
+		return nil, fmt.Errorf("connection error: EOF")
+	}
+	r.mu.Unlock()
+	return &llm.Response{
+		Content: []llm.Content{
+			{Type: llm.ContentTypeText, Text: "Success after retry"},
+		},
+		StopReason: llm.StopReasonEndTurn,
+	}, nil
+}
+
+func (r *retryableLLMService) TokenContextWindow() int {
+	return 200000
+}
+
+func (r *retryableLLMService) MaxImageDimension() int {
+	return 2000
+}
+
+func (r *retryableLLMService) getCallCount() int {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	return r.callCount
+}
+
+func TestLLMRequestRetryOnEOF(t *testing.T) {
+	// Test that LLM requests are retried on EOF errors
+	retryService := &retryableLLMService{failuresRemaining: 1}
+
+	var recordedMessages []llm.Message
+	recordFunc := func(ctx context.Context, message llm.Message, usage llm.Usage) error {
+		recordedMessages = append(recordedMessages, message)
+		return nil
+	}
+
+	loop := NewLoop(Config{
+		LLM:           retryService,
+		History:       []llm.Message{},
+		Tools:         []*llm.Tool{},
+		RecordMessage: recordFunc,
+	})
+
+	// Queue a user message
+	userMessage := llm.Message{
+		Role:    llm.MessageRoleUser,
+		Content: []llm.Content{{Type: llm.ContentTypeText, Text: "test message"}},
+	}
+	loop.QueueUserMessage(userMessage)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	err := loop.ProcessOneTurn(ctx)
+	if err != nil {
+		t.Fatalf("expected no error after retry, got: %v", err)
+	}
+
+	// Should have been called twice (1 failure + 1 success)
+	if retryService.getCallCount() != 2 {
+		t.Errorf("expected 2 LLM calls (retry), got %d", retryService.getCallCount())
+	}
+
+	// Check that success message was recorded
+	if len(recordedMessages) != 1 {
+		t.Fatalf("expected 1 recorded message (success), got %d", len(recordedMessages))
+	}
+
+	if !strings.Contains(recordedMessages[0].Content[0].Text, "Success after retry") {
+		t.Errorf("expected success message, got: %s", recordedMessages[0].Content[0].Text)
+	}
+}
+
+func TestLLMRequestRetryExhausted(t *testing.T) {
+	// Test that after max retries, error is returned
+	retryService := &retryableLLMService{failuresRemaining: 10} // More than maxRetries
+
+	var recordedMessages []llm.Message
+	recordFunc := func(ctx context.Context, message llm.Message, usage llm.Usage) error {
+		recordedMessages = append(recordedMessages, message)
+		return nil
+	}
+
+	loop := NewLoop(Config{
+		LLM:           retryService,
+		History:       []llm.Message{},
+		Tools:         []*llm.Tool{},
+		RecordMessage: recordFunc,
+	})
+
+	userMessage := llm.Message{
+		Role:    llm.MessageRoleUser,
+		Content: []llm.Content{{Type: llm.ContentTypeText, Text: "test message"}},
+	}
+	loop.QueueUserMessage(userMessage)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	err := loop.ProcessOneTurn(ctx)
+	if err == nil {
+		t.Fatal("expected error after exhausting retries")
+	}
+
+	// Should have been called maxRetries times (2)
+	if retryService.getCallCount() != 2 {
+		t.Errorf("expected 2 LLM calls (maxRetries), got %d", retryService.getCallCount())
+	}
+
+	// Check error message was recorded
+	if len(recordedMessages) != 1 {
+		t.Fatalf("expected 1 recorded message (error), got %d", len(recordedMessages))
+	}
+
+	if !strings.Contains(recordedMessages[0].Content[0].Text, "LLM request failed") {
+		t.Errorf("expected error message, got: %s", recordedMessages[0].Content[0].Text)
+	}
+}
+
+func TestIsRetryableError(t *testing.T) {
+	tests := []struct {
+		name      string
+		err       error
+		retryable bool
+	}{
+		{"nil error", nil, false},
+		{"io.EOF", io.EOF, true},
+		{"io.ErrUnexpectedEOF", io.ErrUnexpectedEOF, true},
+		{"EOF error string", fmt.Errorf("EOF"), true},
+		{"wrapped EOF", fmt.Errorf("connection error: EOF"), true},
+		{"connection reset", fmt.Errorf("connection reset by peer"), true},
+		{"connection refused", fmt.Errorf("connection refused"), true},
+		{"timeout", fmt.Errorf("i/o timeout"), true},
+		{"api error", fmt.Errorf("rate limit exceeded"), false},
+		{"generic error", fmt.Errorf("something went wrong"), false},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := isRetryableError(tt.err); got != tt.retryable {
+				t.Errorf("isRetryableError(%v) = %v, want %v", tt.err, got, tt.retryable)
+			}
+		})
+	}
+}
+
 func TestCheckGitStateChange(t *testing.T) {
 	// Create a test repo
 	tmpDir := t.TempDir()