diff --git a/providers/anthropic/error.go b/providers/anthropic/error.go index c4022d4641f04570a34a2f93751d4abe155823bc..a2a351d50ff29fb66956b0e0d1623aba7a9c01d9 100644 --- a/providers/anthropic/error.go +++ b/providers/anthropic/error.go @@ -3,6 +3,7 @@ package anthropic import ( "cmp" "errors" + "io" "net/http" "regexp" "strconv" @@ -32,6 +33,14 @@ func toProviderErr(err error) error { return providerErr } + // Wrap in a `ProviderError` so `.IsRetriable()` works. + if errors.Is(err, io.ErrUnexpectedEOF) { + return &fantasy.ProviderError{ + Title: "stream transport error", + Message: err.Error(), + Cause: err, + } + } return err } diff --git a/providers/anthropic/error_test.go b/providers/anthropic/error_test.go new file mode 100644 index 0000000000000000000000000000000000000000..26273a3528b1afa0adbbf92a6d1ed76c756d3562 --- /dev/null +++ b/providers/anthropic/error_test.go @@ -0,0 +1,67 @@ +package anthropic + +import ( + "errors" + "fmt" + "io" + "testing" + + "charm.land/fantasy" +) + +func TestToProviderErr_WrapsUnexpectedEOF(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + err error + }{ + {"direct", io.ErrUnexpectedEOF}, + {"wrapped", fmt.Errorf("read stream: %w", io.ErrUnexpectedEOF)}, + {"double_wrapped", fmt.Errorf("anthropic: %w", fmt.Errorf("sse: %w", io.ErrUnexpectedEOF))}, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := toProviderErr(tc.err) + + var providerErr *fantasy.ProviderError + if !errors.As(got, &providerErr) { + t.Fatalf("toProviderErr did not wrap %v as *fantasy.ProviderError (got %T)", tc.err, got) + } + if !errors.Is(providerErr.Cause, io.ErrUnexpectedEOF) { + t.Errorf("ProviderError.Cause = %v, want chain containing io.ErrUnexpectedEOF", providerErr.Cause) + } + if !providerErr.IsRetryable() { + t.Error("wrapped io.ErrUnexpectedEOF must be retryable so retry.go engages") + } + }) + } +} + +func TestToProviderErr_PassesThroughUnrelatedErrors(t *testing.T) { + t.Parallel() + + err := errors.New("something unrelated") + got := toProviderErr(err) + if got != err { + t.Errorf("toProviderErr mutated unrelated error: got %v, want %v", got, err) + } +} + +func TestToProviderErr_PassesThroughPlainEOF(t *testing.T) { + t.Parallel() + + // A clean io.EOF at the end of a stream is not a failure — the streaming + // handler in anthropic.go treats it as a normal terminator and never + // calls toProviderErr with io.EOF. But if it ever did, we should not + // wrap it: io.EOF is not "retryable" in the ProviderError sense. + got := toProviderErr(io.EOF) + var providerErr *fantasy.ProviderError + if errors.As(got, &providerErr) { + t.Errorf("toProviderErr wrapped io.EOF as ProviderError; should pass through") + } +} diff --git a/providers/google/error.go b/providers/google/error.go index 25e915c2c154cac37a730a7beac37c8a47c1554f..cd706384e61c4016694b110760a6892f42857802 100644 --- a/providers/google/error.go +++ b/providers/google/error.go @@ -3,6 +3,7 @@ package google import ( "cmp" "errors" + "io" "regexp" "strconv" @@ -15,6 +16,14 @@ var googleContextPattern = regexp.MustCompile(`input token count.*?(\d+).*?excee func toProviderErr(err error) error { var apiErr genai.APIError if !errors.As(err, &apiErr) { + // Wrap in a `ProviderError` so `.IsRetriable()` works. + if errors.Is(err, io.ErrUnexpectedEOF) { + return &fantasy.ProviderError{ + Title: "stream transport error", + Message: err.Error(), + Cause: err, + } + } return err } diff --git a/providers/google/error_test.go b/providers/google/error_test.go new file mode 100644 index 0000000000000000000000000000000000000000..693f35ee2fa6487ea8ec4954d31421b2e7d9a3c7 --- /dev/null +++ b/providers/google/error_test.go @@ -0,0 +1,63 @@ +package google + +import ( + "errors" + "fmt" + "io" + "testing" + + "charm.land/fantasy" +) + +func TestToProviderErr_WrapsUnexpectedEOF(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + err error + }{ + {"direct", io.ErrUnexpectedEOF}, + {"wrapped", fmt.Errorf("read stream: %w", io.ErrUnexpectedEOF)}, + {"double_wrapped", fmt.Errorf("google: %w", fmt.Errorf("sse: %w", io.ErrUnexpectedEOF))}, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := toProviderErr(tc.err) + + var providerErr *fantasy.ProviderError + if !errors.As(got, &providerErr) { + t.Fatalf("toProviderErr did not wrap %v as *fantasy.ProviderError (got %T)", tc.err, got) + } + if !errors.Is(providerErr.Cause, io.ErrUnexpectedEOF) { + t.Errorf("ProviderError.Cause = %v, want chain containing io.ErrUnexpectedEOF", providerErr.Cause) + } + if !providerErr.IsRetryable() { + t.Error("wrapped io.ErrUnexpectedEOF must be retryable so retry.go engages") + } + }) + } +} + +func TestToProviderErr_PassesThroughUnrelatedErrors(t *testing.T) { + t.Parallel() + + err := errors.New("something unrelated") + got := toProviderErr(err) + if got != err { + t.Errorf("toProviderErr mutated unrelated error: got %v, want %v", got, err) + } +} + +func TestToProviderErr_PassesThroughPlainEOF(t *testing.T) { + t.Parallel() + + got := toProviderErr(io.EOF) + var providerErr *fantasy.ProviderError + if errors.As(got, &providerErr) { + t.Errorf("toProviderErr wrapped io.EOF as ProviderError; should pass through") + } +} diff --git a/providers/openai/error.go b/providers/openai/error.go index 861b79eea14480ac9f9d94342bcf14797c4f2c80..5cbc76be04bfe23e9f79e03906b7a81df7cfdb76 100644 --- a/providers/openai/error.go +++ b/providers/openai/error.go @@ -34,6 +34,14 @@ func toProviderErr(err error) error { return providerErr } + // Wrap in a `ProviderError` so `.IsRetriable()` works. + if errors.Is(err, io.ErrUnexpectedEOF) { + return &fantasy.ProviderError{ + Title: "stream transport error", + Message: err.Error(), + Cause: err, + } + } return err } diff --git a/providers/openai/error_test.go b/providers/openai/error_test.go new file mode 100644 index 0000000000000000000000000000000000000000..2914f85012f19350fd6f38894a9ec59cf16b6ce5 --- /dev/null +++ b/providers/openai/error_test.go @@ -0,0 +1,63 @@ +package openai + +import ( + "errors" + "fmt" + "io" + "testing" + + "charm.land/fantasy" +) + +func TestToProviderErr_WrapsUnexpectedEOF(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + err error + }{ + {"direct", io.ErrUnexpectedEOF}, + {"wrapped", fmt.Errorf("read stream: %w", io.ErrUnexpectedEOF)}, + {"double_wrapped", fmt.Errorf("openai: %w", fmt.Errorf("sse: %w", io.ErrUnexpectedEOF))}, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := toProviderErr(tc.err) + + var providerErr *fantasy.ProviderError + if !errors.As(got, &providerErr) { + t.Fatalf("toProviderErr did not wrap %v as *fantasy.ProviderError (got %T)", tc.err, got) + } + if !errors.Is(providerErr.Cause, io.ErrUnexpectedEOF) { + t.Errorf("ProviderError.Cause = %v, want chain containing io.ErrUnexpectedEOF", providerErr.Cause) + } + if !providerErr.IsRetryable() { + t.Error("wrapped io.ErrUnexpectedEOF must be retryable so retry.go engages") + } + }) + } +} + +func TestToProviderErr_PassesThroughUnrelatedErrors(t *testing.T) { + t.Parallel() + + err := errors.New("something unrelated") + got := toProviderErr(err) + if got != err { + t.Errorf("toProviderErr mutated unrelated error: got %v, want %v", got, err) + } +} + +func TestToProviderErr_PassesThroughPlainEOF(t *testing.T) { + t.Parallel() + + got := toProviderErr(io.EOF) + var providerErr *fantasy.ProviderError + if errors.As(got, &providerErr) { + t.Errorf("toProviderErr wrapped io.EOF as ProviderError; should pass through") + } +}