refactor: rework error types

Andrey Nering created

Change summary

agent.go                                       |   2 
agent_test.go                                  |   2 
errors.go                                      | 134 +++++--------------
providers/anthropic/anthropic.go               |  24 +-
providers/google/google.go                     |   2 
providers/openai/language_model.go             |  26 +-
providers/openai/language_model_hooks.go       |   2 
providers/openai/responses_language_model.go   |  20 +-
providers/openaicompat/language_model_hooks.go |   4 
providers/openrouter/language_model_hooks.go   |   4 
retry.go                                       |  54 +------
11 files changed, 90 insertions(+), 184 deletions(-)

Detailed changes

agent.go 🔗

@@ -971,7 +971,7 @@ func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []Agen
 
 func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
 	if prompt == "" {
-		return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
+		return nil, &Error{Title: "invalid argument", Message: "prompt can't be empty"}
 	}
 
 	var preparedPrompt Prompt

agent_test.go 🔗

@@ -528,7 +528,7 @@ func TestAgent_Generate_EmptyPrompt(t *testing.T) {
 
 	require.Error(t, err)
 	require.Nil(t, result)
-	require.Contains(t, err.Error(), "Prompt can't be empty")
+	require.Contains(t, err.Error(), "invalid argument: prompt can't be empty")
 }
 
 // Test with system prompt

errors.go 🔗

@@ -1,128 +1,70 @@
 package fantasy
 
 import (
-	"encoding/json"
-	"errors"
 	"fmt"
+	"net/http"
+
+	"github.com/charmbracelet/x/exp/slice"
 )
 
-// AIError is a custom error type for AI SDK related errors.
-type AIError struct {
+// Error is a custom error type for the fantasy package.
+type Error struct {
 	Message string
+	Title   string
 	Cause   error
 }
 
-// Error implements the error interface.
-func (e *AIError) Error() string {
-	return e.Message
-}
-
-// Unwrap returns the underlying cause of the error.
-func (e *AIError) Unwrap() error {
-	return e.Cause
-}
-
-// NewAIError creates a new AI SDK Error.
-func NewAIError(message string, cause error) *AIError {
-	return &AIError{
-		Message: message,
-		Cause:   cause,
+func (err *Error) Error() string {
+	if err.Title == "" {
+		return err.Message
 	}
+	return fmt.Sprintf("%s: %s", err.Title, err.Message)
 }
 
-// IsAIError checks if the given error is an AI SDK Error.
-func IsAIError(err error) bool {
-	var sdkErr *AIError
-	return errors.As(err, &sdkErr)
+func (err Error) Unwrap() error {
+	return err.Cause
 }
 
-// APICallError represents an error from an API call.
-type APICallError struct {
-	*AIError
+// ProviderError represents an error returned by an external provider.
+type ProviderError struct {
+	Message string
+	Title   string
+	Cause   error
+
 	URL             string
-	RequestDump     string
 	StatusCode      int
+	RequestBody     []byte
 	ResponseHeaders map[string]string
-	ResponseDump    string
-	IsRetryable     bool
-}
-
-// NewAPICallError creates a new API call error.
-func NewAPICallError(message, url string, requestDump string, statusCode int, responseHeaders map[string]string, responseDump string, cause error, isRetryable bool) *APICallError {
-	if !isRetryable && statusCode != 0 {
-		isRetryable = statusCode == 408 || statusCode == 409 || statusCode == 429 || statusCode >= 500
-	}
-
-	return &APICallError{
-		AIError:         NewAIError(message, cause),
-		URL:             url,
-		RequestDump:     requestDump,
-		StatusCode:      statusCode,
-		ResponseHeaders: responseHeaders,
-		ResponseDump:    responseDump,
-		IsRetryable:     isRetryable,
-	}
-}
-
-// InvalidArgumentError represents an invalid function argument error.
-type InvalidArgumentError struct {
-	*AIError
-	Argument string
+	ResponseBody    []byte
 }
 
-// NewInvalidArgumentError creates a new invalid argument error.
-func NewInvalidArgumentError(argument, message string, cause error) *InvalidArgumentError {
-	return &InvalidArgumentError{
-		AIError:  NewAIError(message, cause),
-		Argument: argument,
+func (m *ProviderError) Error() string {
+	if m.Title == "" {
+		return m.Message
 	}
+	return fmt.Sprintf("%s: %s", m.Title, m.Message)
 }
 
-// InvalidPromptError represents an invalid prompt error.
-type InvalidPromptError struct {
-	*AIError
-	Prompt any
+// IsRetryable checks if the error is retryable based on the status code.
+func (m *ProviderError) IsRetryable() bool {
+	return m.StatusCode == http.StatusRequestTimeout || m.StatusCode == http.StatusConflict || m.StatusCode == http.StatusTooManyRequests
 }
 
-// NewInvalidPromptError creates a new invalid prompt error.
-func NewInvalidPromptError(prompt any, message string, cause error) *InvalidPromptError {
-	return &InvalidPromptError{
-		AIError: NewAIError(fmt.Sprintf("Invalid prompt: %s", message), cause),
-		Prompt:  prompt,
-	}
+// RetryError represents an error that occurred during retry operations.
+type RetryError struct {
+	Errors []error
 }
 
-// InvalidResponseDataError represents invalid response data from the server.
-type InvalidResponseDataError struct {
-	*AIError
-	Data any
-}
-
-// NewInvalidResponseDataError creates a new invalid response data error.
-func NewInvalidResponseDataError(data any, message string) *InvalidResponseDataError {
-	if message == "" {
-		dataJSON, _ := json.Marshal(data)
-		message = fmt.Sprintf("Invalid response data: %s.", string(dataJSON))
-	}
-	return &InvalidResponseDataError{
-		AIError: NewAIError(message, nil),
-		Data:    data,
+func (e *RetryError) Error() string {
+	if err, ok := slice.Last(e.Errors); ok {
+		return fmt.Sprintf("retry error: %v", err)
 	}
+	return "retry error: no underlying errors"
 }
 
-// UnsupportedFunctionalityError represents an unsupported functionality error.
-type UnsupportedFunctionalityError struct {
-	*AIError
-	Functionality string
-}
-
-// NewUnsupportedFunctionalityError creates a new unsupported functionality error.
-func NewUnsupportedFunctionalityError(functionality, message string) *UnsupportedFunctionalityError {
-	if message == "" {
-		message = fmt.Sprintf("'%s' functionality not supported.", functionality)
-	}
-	return &UnsupportedFunctionalityError{
-		AIError:       NewAIError(message, nil),
-		Functionality: functionality,
+func (e RetryError) Unwrap() error {
+	if err, ok := slice.Last(e.Errors); ok {
+		return err
 	}
+	return nil
 }

providers/anthropic/anthropic.go 🔗

@@ -202,7 +202,7 @@ func (a languageModel) prepareParams(call fantasy.Call) (*anthropic.MessageNewPa
 	if v, ok := call.ProviderOptions[Name]; ok {
 		providerOptions, ok = v.(*ProviderOptions)
 		if !ok {
-			return nil, nil, fantasy.NewInvalidArgumentError("providerOptions", "anthropic provider options should be *anthropic.ProviderOptions", nil)
+			return nil, nil, &fantasy.Error{Title: "invalid argument", Message: "anthropic provider options should be *anthropic.ProviderOptions"}
 		}
 	}
 	sendReasoning := true
@@ -251,7 +251,7 @@ func (a languageModel) prepareParams(call fantasy.Call) (*anthropic.MessageNewPa
 	}
 	if isThinking {
 		if thinkingBudget == 0 {
-			return nil, nil, fantasy.NewUnsupportedFunctionalityError("thinking requires budget", "")
+			return nil, nil, &fantasy.Error{Title: "no budget", Message: "thinking requires budget"}
 		}
 		params.Thinking = anthropic.ThinkingConfigParamOfEnabled(thinkingBudget)
 		if call.Temperature != nil {
@@ -700,16 +700,16 @@ func (a languageModel) handleError(err error) error {
 			v := h[len(h)-1]
 			headers[strings.ToLower(k)] = v
 		}
-		return fantasy.NewAPICallError(
-			apiErr.Error(),
-			apiErr.Request.URL.String(),
-			string(requestDump),
-			apiErr.StatusCode,
-			headers,
-			string(responseDump),
-			apiErr,
-			false,
-		)
+		return &fantasy.ProviderError{
+			Title:           "provider request failed",
+			Message:         apiErr.Error(),
+			Cause:           apiErr,
+			URL:             apiErr.Request.URL.String(),
+			StatusCode:      apiErr.StatusCode,
+			RequestBody:     requestDump,
+			ResponseHeaders: headers,
+			ResponseBody:    responseDump,
+		}
 	}
 	return err
 }

providers/google/google.go 🔗

@@ -197,7 +197,7 @@ func (g languageModel) prepareParams(call fantasy.Call) (*genai.GenerateContentC
 	if v, ok := call.ProviderOptions[Name]; ok {
 		providerOptions, ok = v.(*ProviderOptions)
 		if !ok {
-			return nil, nil, nil, fantasy.NewInvalidArgumentError("providerOptions", "google provider options should be *google.ProviderOptions", nil)
+			return nil, nil, nil, &fantasy.Error{Title: "invalid argument", Message: "google provider options should be *google.ProviderOptions"}
 		}
 	}
 

providers/openai/language_model.go 🔗

@@ -233,16 +233,16 @@ func (o languageModel) handleError(err error) error {
 			v := h[len(h)-1]
 			headers[strings.ToLower(k)] = v
 		}
-		return fantasy.NewAPICallError(
-			apiErr.Message,
-			apiErr.Request.URL.String(),
-			string(requestDump),
-			apiErr.StatusCode,
-			headers,
-			string(responseDump),
-			apiErr,
-			false,
-		)
+		return &fantasy.ProviderError{
+			Title:           "provider request failed",
+			Message:         apiErr.Message,
+			Cause:           apiErr,
+			URL:             apiErr.Request.URL.String(),
+			StatusCode:      apiErr.StatusCode,
+			RequestBody:     requestDump,
+			ResponseHeaders: headers,
+			ResponseBody:    responseDump,
+		}
 	}
 	return err
 }
@@ -422,13 +422,13 @@ func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S
 							// Does not exist
 							var err error
 							if toolCallDelta.Type != "function" {
-								err = fantasy.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
+								err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function' type."}
 							}
 							if toolCallDelta.ID == "" {
-								err = fantasy.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
+								err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'id' to be a string."}
 							}
 							if toolCallDelta.Function.Name == "" {
-								err = fantasy.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
+								err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function.name' to be a string."}
 							}
 							if err != nil {
 								yield(fantasy.StreamPart{

providers/openai/language_model_hooks.go 🔗

@@ -45,7 +45,7 @@ func DefaultPrepareCallFunc(model fantasy.LanguageModel, params *openai.ChatComp
 	if v, ok := call.ProviderOptions[Name]; ok {
 		providerOptions, ok = v.(*ProviderOptions)
 		if !ok {
-			return nil, fantasy.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
+			return nil, &fantasy.Error{Title: "invalid argument", Message: "openai provider options should be *openai.ProviderOptions"}
 		}
 	}
 

providers/openai/responses_language_model.go 🔗

@@ -659,16 +659,16 @@ func (o responsesLanguageModel) handleError(err error) error {
 			v := h[len(h)-1]
 			headers[strings.ToLower(k)] = v
 		}
-		return fantasy.NewAPICallError(
-			apiErr.Message,
-			apiErr.Request.URL.String(),
-			string(requestDump),
-			apiErr.StatusCode,
-			headers,
-			string(responseDump),
-			apiErr,
-			false,
-		)
+		return &fantasy.ProviderError{
+			Title:           "provider request failed",
+			Message:         apiErr.Message,
+			Cause:           apiErr,
+			URL:             apiErr.Request.URL.String(),
+			StatusCode:      apiErr.StatusCode,
+			RequestBody:     requestDump,
+			ResponseHeaders: headers,
+			ResponseBody:    responseDump,
+		}
 	}
 	return err
 }

providers/openaicompat/language_model_hooks.go 🔗

@@ -19,7 +19,7 @@ func PrepareCallFunc(_ fantasy.LanguageModel, params *openaisdk.ChatCompletionNe
 	if v, ok := call.ProviderOptions[Name]; ok {
 		providerOptions, ok = v.(*ProviderOptions)
 		if !ok {
-			return nil, fantasy.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil)
+			return nil, &fantasy.Error{Title: "invalid argument", Message: "openrouter provider options should be *openrouter.ProviderOptions"}
 		}
 	}
 
@@ -86,7 +86,7 @@ func StreamExtraFunc(chunk openaisdk.ChatCompletionChunk, yield func(fantasy.Str
 		if err != nil {
 			yield(fantasy.StreamPart{
 				Type:  fantasy.StreamPartTypeError,
-				Error: fantasy.NewAIError("error unmarshalling delta", err),
+				Error: &fantasy.Error{Title: "stream error", Message: "error unmarshalling delta", Cause: err},
 			})
 			return ctx, false
 		}

providers/openrouter/language_model_hooks.go 🔗

@@ -21,7 +21,7 @@ func languagePrepareModelCall(_ fantasy.LanguageModel, params *openaisdk.ChatCom
 	if v, ok := call.ProviderOptions[Name]; ok {
 		providerOptions, ok = v.(*ProviderOptions)
 		if !ok {
-			return nil, fantasy.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil)
+			return nil, &fantasy.Error{Title: "invalid argument", Message: "openrouter provider options should be *openrouter.ProviderOptions"}
 		}
 	}
 
@@ -180,7 +180,7 @@ func languageModelStreamExtra(chunk openaisdk.ChatCompletionChunk, yield func(fa
 	if err != nil {
 		yield(fantasy.StreamPart{
 			Type:  fantasy.StreamPartTypeError,
-			Error: fantasy.NewAIError("error unmarshalling delta", err),
+			Error: &fantasy.Error{Title: "stream error", Message: "error unmarshalling delta", Cause: err},
 		})
 		return ctx, false
 	}

retry.go 🔗

@@ -3,7 +3,6 @@ package fantasy
 import (
 	"context"
 	"errors"
-	"fmt"
 	"strconv"
 	"time"
 )
@@ -14,40 +13,14 @@ type RetryFn[T any] func() (T, error)
 // RetryFunction is a function that retries another function.
 type RetryFunction[T any] func(ctx context.Context, fn RetryFn[T]) (T, error)
 
-// RetryReason represents the reason why a retry operation failed.
-type RetryReason string
-
-const (
-	// RetryReasonMaxRetriesExceeded indicates the maximum number of retries was exceeded.
-	RetryReasonMaxRetriesExceeded RetryReason = "maxRetriesExceeded"
-	// RetryReasonErrorNotRetryable indicates the error is not retryable.
-	RetryReasonErrorNotRetryable RetryReason = "errorNotRetryable"
-)
-
-// RetryError represents an error that occurred during retry operations.
-type RetryError struct {
-	*AIError
-	Reason RetryReason
-	Errors []error
-}
-
-// NewRetryError creates a new retry error.
-func NewRetryError(message string, reason RetryReason, errors []error) *RetryError {
-	return &RetryError{
-		AIError: NewAIError(message, nil),
-		Reason:  reason,
-		Errors:  errors,
-	}
-}
-
 // getRetryDelayInMs calculates the retry delay based on error headers and exponential backoff.
 func getRetryDelayInMs(err error, exponentialBackoffDelay time.Duration) time.Duration {
-	var apiErr *APICallError
-	if !errors.As(err, &apiErr) || apiErr.ResponseHeaders == nil {
+	var providerErr *ProviderError
+	if !errors.As(err, &providerErr) || providerErr.ResponseHeaders == nil {
 		return exponentialBackoffDelay
 	}
 
-	headers := apiErr.ResponseHeaders
+	headers := providerErr.ResponseHeaders
 	var ms time.Duration
 
 	// retry-ms is more precise than retry-after and used by e.g. OpenAI
@@ -101,7 +74,7 @@ type RetryOptions struct {
 }
 
 // OnRetryCallback defines a function that is called when a retry occurs.
-type OnRetryCallback = func(err *APICallError, delay time.Duration)
+type OnRetryCallback = func(err *ProviderError, delay time.Duration)
 
 // DefaultRetryOptions returns the default retry options.
 // DefaultRetryOptions returns the default retry options.
@@ -129,23 +102,18 @@ func retryWithExponentialBackoff[T any](ctx context.Context, fn RetryFn[T], opti
 		return zero, err // don't wrap the error when retries are disabled
 	}
 
-	errorMessage := err.Error()
 	newErrors := append(allErrors, err)
 	tryNumber := len(newErrors)
 
 	if tryNumber > options.MaxRetries {
-		return zero, NewRetryError(
-			fmt.Sprintf("Failed after %d attempts. Last error: %v", tryNumber, errorMessage),
-			RetryReasonMaxRetriesExceeded,
-			newErrors,
-		)
+		return zero, &RetryError{newErrors}
 	}
 
-	var apiErr *APICallError
-	if errors.As(err, &apiErr) && apiErr.IsRetryable && tryNumber <= options.MaxRetries {
+	var providerErr *ProviderError
+	if errors.As(err, &providerErr) && providerErr.IsRetryable() && tryNumber <= options.MaxRetries {
 		delay := getRetryDelayInMs(err, options.InitialDelayIn)
 		if options.OnRetry != nil {
-			options.OnRetry(apiErr, delay)
+			options.OnRetry(providerErr, delay)
 		}
 
 		select {
@@ -165,9 +133,5 @@ func retryWithExponentialBackoff[T any](ctx context.Context, fn RetryFn[T], opti
 		return zero, err // don't wrap the error when a non-retryable error occurs on the first try
 	}
 
-	return zero, NewRetryError(
-		fmt.Sprintf("Failed after %d attempts with non-retryable error: %v", tryNumber, errorMessage),
-		RetryReasonErrorNotRetryable,
-		newErrors,
-	)
+	return zero, &RetryError{newErrors}
 }