feat: detect context-too-large errors (#125)

mhpenta created

Change summary

errors.go                             |  9 ++++
providers/anthropic/anthropic_test.go | 60 ++++++++++++++++++++++++++++
providers/anthropic/error.go          | 21 +++++++++
providers/google/error.go             | 21 +++++++++
providers/openai/error.go             | 23 ++++++++++
providers/openai/openai_test.go       | 55 ++++++++++++++++++++++++++
6 files changed, 185 insertions(+), 4 deletions(-)

Detailed changes

errors.go 🔗

@@ -38,6 +38,10 @@ type ProviderError struct {
 	RequestBody     []byte
 	ResponseHeaders map[string]string
 	ResponseBody    []byte
+
+	ContextUsedTokens  int
+	ContextMaxTokens   int
+	ContextTooLargeErr bool
 }
 
 func (m *ProviderError) Error() string {
@@ -52,6 +56,11 @@ func (m *ProviderError) IsRetryable() bool {
 	return m.StatusCode == http.StatusRequestTimeout || m.StatusCode == http.StatusConflict || m.StatusCode == http.StatusTooManyRequests
 }
 
+// IsContextTooLarge checks if the error is due to the context exceeding the model's limit.
+func (m *ProviderError) IsContextTooLarge() bool {
+	return m.ContextTooLargeErr || m.ContextMaxTokens > 0 || m.ContextUsedTokens > 0
+}
+
 // RetryError represents an error that occurred during retry operations.
 type RetryError struct {
 	Errors []error

providers/anthropic/anthropic_test.go 🔗

@@ -341,3 +341,63 @@ func TestToPrompt_DropsEmptyMessages(t *testing.T) {
 		require.Empty(t, warnings)
 	})
 }
+
+func TestParseContextTooLargeError(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name     string
+		message  string
+		wantErr  bool
+		wantUsed int
+		wantMax  int
+	}{
+		{
+			name:     "matches anthropic format",
+			message:  "prompt is too long: 202630 tokens > 200000 maximum",
+			wantErr:  true,
+			wantUsed: 202630,
+			wantMax:  200000,
+		},
+		{
+			name:     "matches with different numbers",
+			message:  "prompt is too long: 150000 tokens > 128000 maximum",
+			wantErr:  true,
+			wantUsed: 150000,
+			wantMax:  128000,
+		},
+		{
+			name:     "matches with extra whitespace",
+			message:  "prompt is too long:  202630  tokens  >  200000  maximum",
+			wantErr:  true,
+			wantUsed: 202630,
+			wantMax:  200000,
+		},
+		{
+			name:    "does not match unrelated error",
+			message: "invalid api key",
+			wantErr: false,
+		},
+		{
+			name:    "does not match rate limit error",
+			message: "rate limit exceeded",
+			wantErr: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			providerErr := &fantasy.ProviderError{Message: tt.message}
+			parseContextTooLargeError(tt.message, providerErr)
+
+			if tt.wantErr {
+				require.True(t, providerErr.IsContextTooLarge())
+				require.Equal(t, tt.wantUsed, providerErr.ContextUsedTokens)
+				require.Equal(t, tt.wantMax, providerErr.ContextMaxTokens)
+			} else {
+				require.False(t, providerErr.IsContextTooLarge())
+			}
+		})
+	}
+}

providers/anthropic/error.go 🔗

@@ -4,16 +4,20 @@ import (
 	"cmp"
 	"errors"
 	"net/http"
+	"regexp"
+	"strconv"
 	"strings"
 
 	"charm.land/fantasy"
 	"github.com/charmbracelet/anthropic-sdk-go"
 )
 
+var anthropicContextPattern = regexp.MustCompile(`prompt is too long:\s*(\d+)\s*tokens?\s*>\s*(\d+)\s*maximum`)
+
 func toProviderErr(err error) error {
 	var apiErr *anthropic.Error
 	if errors.As(err, &apiErr) {
-		return &fantasy.ProviderError{
+		providerErr := &fantasy.ProviderError{
 			Title:           cmp.Or(fantasy.ErrorTitleForStatusCode(apiErr.StatusCode), "provider request failed"),
 			Message:         apiErr.Error(),
 			Cause:           apiErr,
@@ -23,10 +27,25 @@ func toProviderErr(err error) error {
 			ResponseHeaders: toHeaderMap(apiErr.Response.Header),
 			ResponseBody:    apiErr.DumpResponse(true),
 		}
+
+		parseContextTooLargeError(apiErr.Error(), providerErr)
+
+		return providerErr
 	}
 	return err
 }
 
+func parseContextTooLargeError(message string, providerErr *fantasy.ProviderError) {
+	matches := anthropicContextPattern.FindStringSubmatch(message)
+	if matches == nil {
+		return
+	}
+
+	providerErr.ContextTooLargeErr = true
+	providerErr.ContextUsedTokens, _ = strconv.Atoi(matches[1])
+	providerErr.ContextMaxTokens, _ = strconv.Atoi(matches[2])
+}
+
 func toHeaderMap(in http.Header) (out map[string]string) {
 	out = make(map[string]string, len(in))
 	for k, v := range in {

providers/google/error.go 🔗

@@ -3,21 +3,40 @@ package google
 import (
 	"cmp"
 	"errors"
+	"regexp"
+	"strconv"
 
 	"charm.land/fantasy"
 	"google.golang.org/genai"
 )
 
+var googleContextPattern = regexp.MustCompile(`input token count.*?(\d+).*?exceeds.*?maximum.*?(\d+)`)
+
 func toProviderErr(err error) error {
 	var apiErr genai.APIError
 	if !errors.As(err, &apiErr) {
 		return err
 	}
-	return &fantasy.ProviderError{
+
+	providerErr := &fantasy.ProviderError{
 		Message:      apiErr.Message,
 		Title:        cmp.Or(fantasy.ErrorTitleForStatusCode(apiErr.Code), "provider request failed"),
 		Cause:        err,
 		StatusCode:   apiErr.Code,
 		ResponseBody: []byte(apiErr.Message),
 	}
+
+	parseContextTooLargeError(apiErr.Message, providerErr)
+
+	return providerErr
+}
+
+func parseContextTooLargeError(message string, providerErr *fantasy.ProviderError) {
+	matches := googleContextPattern.FindStringSubmatch(message)
+	if matches == nil {
+		return
+	}
+	providerErr.ContextTooLargeErr = true
+	providerErr.ContextUsedTokens, _ = strconv.Atoi(matches[1])
+	providerErr.ContextMaxTokens, _ = strconv.Atoi(matches[2])
 }

providers/openai/error.go 🔗

@@ -5,18 +5,23 @@ import (
 	"errors"
 	"io"
 	"net/http"
+	"regexp"
+	"strconv"
 	"strings"
 
 	"charm.land/fantasy"
 	"github.com/openai/openai-go/v2"
 )
 
+var openaiContextPattern = regexp.MustCompile(`maximum context length is (\d+) tokens.*?(?:resulted in|requested) (\d+) tokens`)
+
 func toProviderErr(err error) error {
 	var apiErr *openai.Error
 	if errors.As(err, &apiErr) {
-		return &fantasy.ProviderError{
+		message := toProviderErrMessage(apiErr)
+		providerErr := &fantasy.ProviderError{
 			Title:           cmp.Or(fantasy.ErrorTitleForStatusCode(apiErr.StatusCode), "provider request failed"),
-			Message:         toProviderErrMessage(apiErr),
+			Message:         message,
 			Cause:           apiErr,
 			URL:             apiErr.Request.URL.String(),
 			StatusCode:      apiErr.StatusCode,
@@ -24,10 +29,24 @@ func toProviderErr(err error) error {
 			ResponseHeaders: toHeaderMap(apiErr.Response.Header),
 			ResponseBody:    apiErr.DumpResponse(true),
 		}
+
+		parseContextTooLargeError(message, providerErr)
+
+		return providerErr
 	}
 	return err
 }
 
+func parseContextTooLargeError(message string, providerErr *fantasy.ProviderError) {
+	matches := openaiContextPattern.FindStringSubmatch(message)
+	if matches == nil {
+		return
+	}
+	providerErr.ContextTooLargeErr = true
+	providerErr.ContextMaxTokens, _ = strconv.Atoi(matches[1])
+	providerErr.ContextUsedTokens, _ = strconv.Atoi(matches[2])
+}
+
 func toProviderErrMessage(apiErr *openai.Error) string {
 	if apiErr.Message != "" {
 		return apiErr.Message

providers/openai/openai_test.go 🔗

@@ -3247,3 +3247,58 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) {
 		require.Empty(t, warnings)
 	})
 }
+
+func TestParseContextTooLargeError(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name     string
+		message  string
+		wantErr  bool
+		wantUsed int
+		wantMax  int
+	}{
+		{
+			name:     "matches openai format with resulted in",
+			message:  "This model's maximum context length is 128000 tokens. However, your messages resulted in 150000 tokens.",
+			wantErr:  true,
+			wantUsed: 150000,
+			wantMax:  128000,
+		},
+		{
+			name:     "matches openai format with requested",
+			message:  "maximum context length is 8192 tokens, however you requested 10000 tokens",
+			wantErr:  true,
+			wantUsed: 10000,
+			wantMax:  8192,
+		},
+		{
+			name:    "does not match unrelated error",
+			message: "invalid api key",
+			wantErr: false,
+		},
+		{
+			name:    "does not match rate limit error",
+			message: "rate limit exceeded",
+			wantErr: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			providerErr := &fantasy.ProviderError{Message: tt.message}
+			parseContextTooLargeError(tt.message, providerErr)
+
+			if tt.wantErr {
+				require.True(t, providerErr.IsContextTooLarge())
+				if tt.wantUsed > 0 {
+					require.Equal(t, tt.wantUsed, providerErr.ContextUsedTokens)
+					require.Equal(t, tt.wantMax, providerErr.ContextMaxTokens)
+				}
+			} else {
+				require.False(t, providerErr.IsContextTooLarge())
+			}
+		})
+	}
+}