diff --git a/errors.go b/errors.go index 0f9efaf986f038b982b129a6ccc7f6bcba12d3c1..f9228de08e182020dcadf04e9d498655917ece1f 100644 --- a/errors.go +++ b/errors.go @@ -38,6 +38,10 @@ type ProviderError struct { RequestBody []byte ResponseHeaders map[string]string ResponseBody []byte + + ContextUsedTokens int + ContextMaxTokens int + ContextTooLargeErr bool } func (m *ProviderError) Error() string { @@ -52,6 +56,11 @@ func (m *ProviderError) IsRetryable() bool { return m.StatusCode == http.StatusRequestTimeout || m.StatusCode == http.StatusConflict || m.StatusCode == http.StatusTooManyRequests } +// IsContextTooLarge checks if the error is due to the context exceeding the model's limit. +func (m *ProviderError) IsContextTooLarge() bool { + return m.ContextTooLargeErr || m.ContextMaxTokens > 0 || m.ContextUsedTokens > 0 +} + // RetryError represents an error that occurred during retry operations. type RetryError struct { Errors []error diff --git a/providers/anthropic/anthropic_test.go b/providers/anthropic/anthropic_test.go index b826882c59b5ce1a2002dce6e0aca8982498f62b..4e808d40587652616646ccbc612db7713dba2400 100644 --- a/providers/anthropic/anthropic_test.go +++ b/providers/anthropic/anthropic_test.go @@ -341,3 +341,63 @@ func TestToPrompt_DropsEmptyMessages(t *testing.T) { require.Empty(t, warnings) }) } + +func TestParseContextTooLargeError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + message string + wantErr bool + wantUsed int + wantMax int + }{ + { + name: "matches anthropic format", + message: "prompt is too long: 202630 tokens > 200000 maximum", + wantErr: true, + wantUsed: 202630, + wantMax: 200000, + }, + { + name: "matches with different numbers", + message: "prompt is too long: 150000 tokens > 128000 maximum", + wantErr: true, + wantUsed: 150000, + wantMax: 128000, + }, + { + name: "matches with extra whitespace", + message: "prompt is too long: 202630 tokens > 200000 maximum", + wantErr: true, + wantUsed: 202630, + wantMax: 200000, + }, + { + name: "does not match unrelated error", + message: "invalid api key", + wantErr: false, + }, + { + name: "does not match rate limit error", + message: "rate limit exceeded", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + providerErr := &fantasy.ProviderError{Message: tt.message} + parseContextTooLargeError(tt.message, providerErr) + + if tt.wantErr { + require.True(t, providerErr.IsContextTooLarge()) + require.Equal(t, tt.wantUsed, providerErr.ContextUsedTokens) + require.Equal(t, tt.wantMax, providerErr.ContextMaxTokens) + } else { + require.False(t, providerErr.IsContextTooLarge()) + } + }) + } +} diff --git a/providers/anthropic/error.go b/providers/anthropic/error.go index 38393b550a76562b9268d25bda3499b38c6ba236..c4022d4641f04570a34a2f93751d4abe155823bc 100644 --- a/providers/anthropic/error.go +++ b/providers/anthropic/error.go @@ -4,16 +4,20 @@ import ( "cmp" "errors" "net/http" + "regexp" + "strconv" "strings" "charm.land/fantasy" "github.com/charmbracelet/anthropic-sdk-go" ) +var anthropicContextPattern = regexp.MustCompile(`prompt is too long:\s*(\d+)\s*tokens?\s*>\s*(\d+)\s*maximum`) + func toProviderErr(err error) error { var apiErr *anthropic.Error if errors.As(err, &apiErr) { - return &fantasy.ProviderError{ + providerErr := &fantasy.ProviderError{ Title: cmp.Or(fantasy.ErrorTitleForStatusCode(apiErr.StatusCode), "provider request failed"), Message: apiErr.Error(), Cause: apiErr, @@ -23,10 +27,25 @@ func toProviderErr(err error) error { ResponseHeaders: toHeaderMap(apiErr.Response.Header), ResponseBody: apiErr.DumpResponse(true), } + + parseContextTooLargeError(apiErr.Error(), providerErr) + + return providerErr } return err } +func parseContextTooLargeError(message string, providerErr *fantasy.ProviderError) { + matches := anthropicContextPattern.FindStringSubmatch(message) + if matches == nil { + return + } + + providerErr.ContextTooLargeErr = true + providerErr.ContextUsedTokens, _ = strconv.Atoi(matches[1]) + providerErr.ContextMaxTokens, _ = strconv.Atoi(matches[2]) +} + func toHeaderMap(in http.Header) (out map[string]string) { out = make(map[string]string, len(in)) for k, v := range in { diff --git a/providers/google/error.go b/providers/google/error.go index 710cffff61217f0c69bf525f5efb7083b26c3e3d..25e915c2c154cac37a730a7beac37c8a47c1554f 100644 --- a/providers/google/error.go +++ b/providers/google/error.go @@ -3,21 +3,40 @@ package google import ( "cmp" "errors" + "regexp" + "strconv" "charm.land/fantasy" "google.golang.org/genai" ) +var googleContextPattern = regexp.MustCompile(`input token count.*?(\d+).*?exceeds.*?maximum.*?(\d+)`) + func toProviderErr(err error) error { var apiErr genai.APIError if !errors.As(err, &apiErr) { return err } - return &fantasy.ProviderError{ + + providerErr := &fantasy.ProviderError{ Message: apiErr.Message, Title: cmp.Or(fantasy.ErrorTitleForStatusCode(apiErr.Code), "provider request failed"), Cause: err, StatusCode: apiErr.Code, ResponseBody: []byte(apiErr.Message), } + + parseContextTooLargeError(apiErr.Message, providerErr) + + return providerErr +} + +func parseContextTooLargeError(message string, providerErr *fantasy.ProviderError) { + matches := googleContextPattern.FindStringSubmatch(message) + if matches == nil { + return + } + providerErr.ContextTooLargeErr = true + providerErr.ContextUsedTokens, _ = strconv.Atoi(matches[1]) + providerErr.ContextMaxTokens, _ = strconv.Atoi(matches[2]) } diff --git a/providers/openai/error.go b/providers/openai/error.go index fed072088b117f4c0e8cd28c38c3ee8c3bd73584..63e5c4dc561114a49d4f7e20f2926b99ecfae3e6 100644 --- a/providers/openai/error.go +++ b/providers/openai/error.go @@ -5,18 +5,23 @@ import ( "errors" "io" "net/http" + "regexp" + "strconv" "strings" "charm.land/fantasy" "github.com/openai/openai-go/v2" ) +var openaiContextPattern = regexp.MustCompile(`maximum context length is (\d+) tokens.*?(?:resulted in|requested) (\d+) tokens`) + func toProviderErr(err error) error { var apiErr *openai.Error if errors.As(err, &apiErr) { - return &fantasy.ProviderError{ + message := toProviderErrMessage(apiErr) + providerErr := &fantasy.ProviderError{ Title: cmp.Or(fantasy.ErrorTitleForStatusCode(apiErr.StatusCode), "provider request failed"), - Message: toProviderErrMessage(apiErr), + Message: message, Cause: apiErr, URL: apiErr.Request.URL.String(), StatusCode: apiErr.StatusCode, @@ -24,10 +29,24 @@ func toProviderErr(err error) error { ResponseHeaders: toHeaderMap(apiErr.Response.Header), ResponseBody: apiErr.DumpResponse(true), } + + parseContextTooLargeError(message, providerErr) + + return providerErr } return err } +func parseContextTooLargeError(message string, providerErr *fantasy.ProviderError) { + matches := openaiContextPattern.FindStringSubmatch(message) + if matches == nil { + return + } + providerErr.ContextTooLargeErr = true + providerErr.ContextMaxTokens, _ = strconv.Atoi(matches[1]) + providerErr.ContextUsedTokens, _ = strconv.Atoi(matches[2]) +} + func toProviderErrMessage(apiErr *openai.Error) string { if apiErr.Message != "" { return apiErr.Message diff --git a/providers/openai/openai_test.go b/providers/openai/openai_test.go index adc1cf2e8df311722b7aab878a32e47109226f7f..7592b18b2c8c361f7377ba277bf214621c77fbff 100644 --- a/providers/openai/openai_test.go +++ b/providers/openai/openai_test.go @@ -3247,3 +3247,58 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) { require.Empty(t, warnings) }) } + +func TestParseContextTooLargeError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + message string + wantErr bool + wantUsed int + wantMax int + }{ + { + name: "matches openai format with resulted in", + message: "This model's maximum context length is 128000 tokens. However, your messages resulted in 150000 tokens.", + wantErr: true, + wantUsed: 150000, + wantMax: 128000, + }, + { + name: "matches openai format with requested", + message: "maximum context length is 8192 tokens, however you requested 10000 tokens", + wantErr: true, + wantUsed: 10000, + wantMax: 8192, + }, + { + name: "does not match unrelated error", + message: "invalid api key", + wantErr: false, + }, + { + name: "does not match rate limit error", + message: "rate limit exceeded", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + providerErr := &fantasy.ProviderError{Message: tt.message} + parseContextTooLargeError(tt.message, providerErr) + + if tt.wantErr { + require.True(t, providerErr.IsContextTooLarge()) + if tt.wantUsed > 0 { + require.Equal(t, tt.wantUsed, providerErr.ContextUsedTokens) + require.Equal(t, tt.wantMax, providerErr.ContextMaxTokens) + } + } else { + require.False(t, providerErr.IsContextTooLarge()) + } + }) + } +}