fix(google): return `*fantasy.ProviderError` when we should

Andrey Nering created

Change summary

errors.go                                    | 17 +++++++
providers/anthropic/anthropic.go             | 28 -----------
providers/anthropic/error.go                 | 39 ++++++++++++++++
providers/google/error.go                    | 23 +++++++++
providers/google/google.go                   |  4 
providers/openai/error.go                    | 52 ++++++++++++++++++++++
providers/openai/language_model.go           | 32 +-----------
providers/openai/responses_language_model.go | 34 ++-----------
8 files changed, 145 insertions(+), 84 deletions(-)

Detailed changes

errors.go 🔗

@@ -68,3 +68,20 @@ func (e RetryError) Unwrap() error {
 	}
 	return nil
 }
+
+var statusCodeToTitle = map[int]string{
+	http.StatusBadRequest:          "bad request",
+	http.StatusUnauthorized:        "authentication failed",
+	http.StatusForbidden:           "permission denied",
+	http.StatusNotFound:            "resource not found",
+	http.StatusTooManyRequests:     "rate limit exceeded",
+	http.StatusInternalServerError: "internal server error",
+	http.StatusBadGateway:          "bad gateway",
+	http.StatusServiceUnavailable:  "service unavailable",
+	http.StatusGatewayTimeout:      "gateway timeout",
+}
+
+// ErrorTitleForStatusCode returns a human-readable title for a given HTTP status code.
+func ErrorTitleForStatusCode(statusCode int) string {
+	return statusCodeToTitle[statusCode]
+}

providers/anthropic/anthropic.go 🔗

@@ -690,30 +690,6 @@ func toPrompt(prompt fantasy.Prompt, sendReasoningData bool) ([]anthropic.TextBl
 	return systemBlocks, messages, warnings
 }
 
-func (a languageModel) handleError(err error) error {
-	var apiErr *anthropic.Error
-	if errors.As(err, &apiErr) {
-		requestDump := apiErr.DumpRequest(true)
-		responseDump := apiErr.DumpResponse(true)
-		headers := map[string]string{}
-		for k, h := range apiErr.Response.Header {
-			v := h[len(h)-1]
-			headers[strings.ToLower(k)] = v
-		}
-		return &fantasy.ProviderError{
-			Title:           "provider request failed",
-			Message:         apiErr.Error(),
-			Cause:           apiErr,
-			URL:             apiErr.Request.URL.String(),
-			StatusCode:      apiErr.StatusCode,
-			RequestBody:     requestDump,
-			ResponseHeaders: headers,
-			ResponseBody:    responseDump,
-		}
-	}
-	return err
-}
-
 func mapFinishReason(finishReason string) fantasy.FinishReason {
 	switch finishReason {
 	case "end_turn", "pause_turn", "stop_sequence":
@@ -735,7 +711,7 @@ func (a languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantas
 	}
 	response, err := a.client.Messages.New(ctx, *params)
 	if err != nil {
-		return nil, a.handleError(err)
+		return nil, toProviderErr(err)
 	}
 
 	var content []fantasy.Content
@@ -968,7 +944,7 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S
 		} else { //nolint: revive
 			yield(fantasy.StreamPart{
 				Type:  fantasy.StreamPartTypeError,
-				Error: a.handleError(err),
+				Error: toProviderErr(err),
 			})
 			return
 		}

providers/anthropic/error.go 🔗

@@ -0,0 +1,39 @@
+package anthropic
+
+import (
+	"cmp"
+	"errors"
+	"net/http"
+	"strings"
+
+	"charm.land/fantasy"
+	"github.com/charmbracelet/anthropic-sdk-go"
+)
+
+func toProviderErr(err error) error {
+	var apiErr *anthropic.Error
+	if errors.As(err, &apiErr) {
+		return &fantasy.ProviderError{
+			Title:           cmp.Or(fantasy.ErrorTitleForStatusCode(apiErr.StatusCode), "provider request failed"),
+			Message:         apiErr.Error(),
+			Cause:           apiErr,
+			URL:             apiErr.Request.URL.String(),
+			StatusCode:      apiErr.StatusCode,
+			RequestBody:     apiErr.DumpRequest(true),
+			ResponseHeaders: toHeaderMap(apiErr.Response.Header),
+			ResponseBody:    apiErr.DumpResponse(true),
+		}
+	}
+	return err
+}
+
+func toHeaderMap(in http.Header) (out map[string]string) {
+	out = make(map[string]string, len(in))
+	for k, v := range in {
+		if l := len(v); l > 0 {
+			out[k] = v[l-1]
+			in[strings.ToLower(k)] = v
+		}
+	}
+	return out
+}

providers/google/error.go 🔗

@@ -0,0 +1,23 @@
+package google
+
+import (
+	"cmp"
+	"errors"
+
+	"charm.land/fantasy"
+	"google.golang.org/genai"
+)
+
+func toProviderErr(err error) error {
+	var apiErr genai.APIError
+	if !errors.As(err, &apiErr) {
+		return err
+	}
+	return &fantasy.ProviderError{
+		Message:      apiErr.Message,
+		Title:        cmp.Or(fantasy.ErrorTitleForStatusCode(apiErr.Code), "provider request failed"),
+		Cause:        err,
+		StatusCode:   apiErr.Code,
+		ResponseBody: []byte(apiErr.Message),
+	}
+}

providers/google/google.go 🔗

@@ -514,7 +514,7 @@ func (g *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fanta
 
 	response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...)
 	if err != nil {
-		return nil, err
+		return nil, toProviderErr(err)
 	}
 
 	return g.mapResponse(response, warnings)
@@ -571,7 +571,7 @@ func (g *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.
 			if err != nil {
 				yield(fantasy.StreamPart{
 					Type:  fantasy.StreamPartTypeError,
-					Error: err,
+					Error: toProviderErr(err),
 				})
 				return
 			}

providers/openai/error.go 🔗

@@ -0,0 +1,52 @@
+package openai
+
+import (
+	"cmp"
+	"errors"
+	"io"
+	"net/http"
+	"strings"
+
+	"charm.land/fantasy"
+	"github.com/openai/openai-go/v2"
+)
+
+func toProviderErr(err error) error {
+	var apiErr *openai.Error
+	if errors.As(err, &apiErr) {
+		return &fantasy.ProviderError{
+			Title:           cmp.Or(fantasy.ErrorTitleForStatusCode(apiErr.StatusCode), "provider request failed"),
+			Message:         toProviderErrMessage(apiErr),
+			Cause:           apiErr,
+			URL:             apiErr.Request.URL.String(),
+			StatusCode:      apiErr.StatusCode,
+			RequestBody:     apiErr.DumpRequest(true),
+			ResponseHeaders: toHeaderMap(apiErr.Response.Header),
+			ResponseBody:    apiErr.DumpResponse(true),
+		}
+	}
+	return err
+}
+
+func toProviderErrMessage(apiErr *openai.Error) string {
+	if apiErr.Message != "" {
+		return apiErr.Message
+	}
+
+	// For some OpenAI-compatible providers, the SDK is not always able to parse
+	// the error message correctly.
+	// Fallback to returning the raw response body in such cases.
+	data, _ := io.ReadAll(apiErr.Response.Body)
+	return string(data)
+}
+
+func toHeaderMap(in http.Header) (out map[string]string) {
+	out = make(map[string]string, len(in))
+	for k, v := range in {
+		if l := len(v); l > 0 {
+			out[k] = v[l-1]
+			in[strings.ToLower(k)] = v
+		}
+	}
+	return out
+}

providers/openai/language_model.go 🔗

@@ -223,30 +223,6 @@ func (o languageModel) prepareParams(call fantasy.Call) (*openai.ChatCompletionN
 	return params, warnings, nil
 }
 
-func (o languageModel) handleError(err error) error {
-	var apiErr *openai.Error
-	if errors.As(err, &apiErr) {
-		requestDump := apiErr.DumpRequest(true)
-		responseDump := apiErr.DumpResponse(true)
-		headers := map[string]string{}
-		for k, h := range apiErr.Response.Header {
-			v := h[len(h)-1]
-			headers[strings.ToLower(k)] = v
-		}
-		return &fantasy.ProviderError{
-			Title:           "provider request failed",
-			Message:         apiErr.Message,
-			Cause:           apiErr,
-			URL:             apiErr.Request.URL.String(),
-			StatusCode:      apiErr.StatusCode,
-			RequestBody:     requestDump,
-			ResponseHeaders: headers,
-			ResponseBody:    responseDump,
-		}
-	}
-	return err
-}
-
 // Generate implements fantasy.LanguageModel.
 func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
 	params, warnings, err := o.prepareParams(call)
@@ -255,11 +231,11 @@ func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantas
 	}
 	response, err := o.client.Chat.Completions.New(ctx, *params)
 	if err != nil {
-		return nil, o.handleError(err)
+		return nil, toProviderErr(err)
 	}
 
 	if len(response.Choices) == 0 {
-		return nil, errors.New("no response generated")
+		return nil, &fantasy.Error{Title: "no response", Message: "no response generated"}
 	}
 	choice := response.Choices[0]
 	content := make([]fantasy.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
@@ -433,7 +409,7 @@ func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S
 							if err != nil {
 								yield(fantasy.StreamPart{
 									Type:  fantasy.StreamPartTypeError,
-									Error: o.handleError(stream.Err()),
+									Error: toProviderErr(stream.Err()),
 								})
 								return
 							}
@@ -563,7 +539,7 @@ func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S
 		} else { //nolint: revive
 			yield(fantasy.StreamPart{
 				Type:  fantasy.StreamPartTypeError,
-				Error: o.handleError(err),
+				Error: toProviderErr(err),
 			})
 			return
 		}

providers/openai/responses_language_model.go 🔗

@@ -4,7 +4,6 @@ import (
 	"context"
 	"encoding/base64"
 	"encoding/json"
-	"errors"
 	"fmt"
 	"strings"
 
@@ -649,39 +648,18 @@ func toResponsesTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice, opti
 	return openaiTools, openaiToolChoice, warnings
 }
 
-func (o responsesLanguageModel) handleError(err error) error {
-	var apiErr *openai.Error
-	if errors.As(err, &apiErr) {
-		requestDump := apiErr.DumpRequest(true)
-		responseDump := apiErr.DumpResponse(true)
-		headers := map[string]string{}
-		for k, h := range apiErr.Response.Header {
-			v := h[len(h)-1]
-			headers[strings.ToLower(k)] = v
-		}
-		return &fantasy.ProviderError{
-			Title:           "provider request failed",
-			Message:         apiErr.Message,
-			Cause:           apiErr,
-			URL:             apiErr.Request.URL.String(),
-			StatusCode:      apiErr.StatusCode,
-			RequestBody:     requestDump,
-			ResponseHeaders: headers,
-			ResponseBody:    responseDump,
-		}
-	}
-	return err
-}
-
 func (o responsesLanguageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
 	params, warnings := o.prepareParams(call)
 	response, err := o.client.Responses.New(ctx, *params)
 	if err != nil {
-		return nil, o.handleError(err)
+		return nil, toProviderErr(err)
 	}
 
 	if response.Error.Message != "" {
-		return nil, o.handleError(fmt.Errorf("response error: %s (code: %s)", response.Error.Message, response.Error.Code))
+		return nil, &fantasy.Error{
+			Title:   "provider error",
+			Message: fmt.Sprintf("%s (code: %s)", response.Error.Message, response.Error.Code),
+		}
 	}
 
 	var content []fantasy.Content
@@ -1023,7 +1001,7 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call fantasy.Call) (
 		if err != nil {
 			yield(fantasy.StreamPart{
 				Type:  fantasy.StreamPartTypeError,
-				Error: o.handleError(err),
+				Error: toProviderErr(err),
 			})
 			return
 		}