From 789f18b460252a5c6e6998edd2502be9e5c95718 Mon Sep 17 00:00:00 2001 From: Andrey Nering Date: Wed, 5 Nov 2025 14:40:08 -0300 Subject: [PATCH] refactor: rework error types --- agent.go | 2 +- agent_test.go | 2 +- errors.go | 134 +++++------------- providers/anthropic/anthropic.go | 24 ++-- providers/google/google.go | 2 +- providers/openai/language_model.go | 26 ++-- providers/openai/language_model_hooks.go | 2 +- providers/openai/responses_language_model.go | 20 +-- .../openaicompat/language_model_hooks.go | 4 +- providers/openrouter/language_model_hooks.go | 4 +- retry.go | 54 ++----- 11 files changed, 90 insertions(+), 184 deletions(-) diff --git a/agent.go b/agent.go index 214cb388918fa15115716ccf58d96fffce36cf41..52015b5fd90ca988cb8e7af763e38693bfd4c9e1 100644 --- a/agent.go +++ b/agent.go @@ -971,7 +971,7 @@ func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []Agen func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) { if prompt == "" { - return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil) + return nil, &Error{Title: "invalid argument", Message: "prompt can't be empty"} } var preparedPrompt Prompt diff --git a/agent_test.go b/agent_test.go index 2929e3c0c98100231f498a8084dc46247d57cb1f..dce488116c012e2be15034750cf87779a70671f7 100644 --- a/agent_test.go +++ b/agent_test.go @@ -528,7 +528,7 @@ func TestAgent_Generate_EmptyPrompt(t *testing.T) { require.Error(t, err) require.Nil(t, result) - require.Contains(t, err.Error(), "Prompt can't be empty") + require.Contains(t, err.Error(), "invalid argument: prompt can't be empty") } // Test with system prompt diff --git a/errors.go b/errors.go index ffa07aa6ee9cd7adcc5601f5fe016564fbc2ecc7..df7998e3ab717b2f287d2fb53a822d459c1c3dd2 100644 --- a/errors.go +++ b/errors.go @@ -1,128 +1,70 @@ package fantasy import ( - "encoding/json" - "errors" "fmt" + "net/http" + + "github.com/charmbracelet/x/exp/slice" ) -// AIError is a custom error type for AI SDK related errors. -type AIError struct { +// Error is a custom error type for the fantasy package. +type Error struct { Message string + Title string Cause error } -// Error implements the error interface. -func (e *AIError) Error() string { - return e.Message -} - -// Unwrap returns the underlying cause of the error. -func (e *AIError) Unwrap() error { - return e.Cause -} - -// NewAIError creates a new AI SDK Error. -func NewAIError(message string, cause error) *AIError { - return &AIError{ - Message: message, - Cause: cause, +func (err *Error) Error() string { + if err.Title == "" { + return err.Message } + return fmt.Sprintf("%s: %s", err.Title, err.Message) } -// IsAIError checks if the given error is an AI SDK Error. -func IsAIError(err error) bool { - var sdkErr *AIError - return errors.As(err, &sdkErr) +func (err Error) Unwrap() error { + return err.Cause } -// APICallError represents an error from an API call. -type APICallError struct { - *AIError +// ProviderError represents an error returned by an external provider. +type ProviderError struct { + Message string + Title string + Cause error + URL string - RequestDump string StatusCode int + RequestBody []byte ResponseHeaders map[string]string - ResponseDump string - IsRetryable bool -} - -// NewAPICallError creates a new API call error. -func NewAPICallError(message, url string, requestDump string, statusCode int, responseHeaders map[string]string, responseDump string, cause error, isRetryable bool) *APICallError { - if !isRetryable && statusCode != 0 { - isRetryable = statusCode == 408 || statusCode == 409 || statusCode == 429 || statusCode >= 500 - } - - return &APICallError{ - AIError: NewAIError(message, cause), - URL: url, - RequestDump: requestDump, - StatusCode: statusCode, - ResponseHeaders: responseHeaders, - ResponseDump: responseDump, - IsRetryable: isRetryable, - } -} - -// InvalidArgumentError represents an invalid function argument error. -type InvalidArgumentError struct { - *AIError - Argument string + ResponseBody []byte } -// NewInvalidArgumentError creates a new invalid argument error. -func NewInvalidArgumentError(argument, message string, cause error) *InvalidArgumentError { - return &InvalidArgumentError{ - AIError: NewAIError(message, cause), - Argument: argument, +func (m *ProviderError) Error() string { + if m.Title == "" { + return m.Message } + return fmt.Sprintf("%s: %s", m.Title, m.Message) } -// InvalidPromptError represents an invalid prompt error. -type InvalidPromptError struct { - *AIError - Prompt any +// IsRetryable checks if the error is retryable based on the status code. +func (m *ProviderError) IsRetryable() bool { + return m.StatusCode == http.StatusRequestTimeout || m.StatusCode == http.StatusConflict || m.StatusCode == http.StatusTooManyRequests } -// NewInvalidPromptError creates a new invalid prompt error. -func NewInvalidPromptError(prompt any, message string, cause error) *InvalidPromptError { - return &InvalidPromptError{ - AIError: NewAIError(fmt.Sprintf("Invalid prompt: %s", message), cause), - Prompt: prompt, - } +// RetryError represents an error that occurred during retry operations. +type RetryError struct { + Errors []error } -// InvalidResponseDataError represents invalid response data from the server. -type InvalidResponseDataError struct { - *AIError - Data any -} - -// NewInvalidResponseDataError creates a new invalid response data error. -func NewInvalidResponseDataError(data any, message string) *InvalidResponseDataError { - if message == "" { - dataJSON, _ := json.Marshal(data) - message = fmt.Sprintf("Invalid response data: %s.", string(dataJSON)) - } - return &InvalidResponseDataError{ - AIError: NewAIError(message, nil), - Data: data, +func (e *RetryError) Error() string { + if err, ok := slice.Last(e.Errors); ok { + return fmt.Sprintf("retry error: %v", err) } + return "retry error: no underlying errors" } -// UnsupportedFunctionalityError represents an unsupported functionality error. -type UnsupportedFunctionalityError struct { - *AIError - Functionality string -} - -// NewUnsupportedFunctionalityError creates a new unsupported functionality error. -func NewUnsupportedFunctionalityError(functionality, message string) *UnsupportedFunctionalityError { - if message == "" { - message = fmt.Sprintf("'%s' functionality not supported.", functionality) - } - return &UnsupportedFunctionalityError{ - AIError: NewAIError(message, nil), - Functionality: functionality, +func (e RetryError) Unwrap() error { + if err, ok := slice.Last(e.Errors); ok { + return err } + return nil } diff --git a/providers/anthropic/anthropic.go b/providers/anthropic/anthropic.go index 42f8d72f1039cabf9c85341120655b4a24125fa3..9b0fb97e3d7617685a31445cd9e1eab539348360 100644 --- a/providers/anthropic/anthropic.go +++ b/providers/anthropic/anthropic.go @@ -202,7 +202,7 @@ func (a languageModel) prepareParams(call fantasy.Call) (*anthropic.MessageNewPa if v, ok := call.ProviderOptions[Name]; ok { providerOptions, ok = v.(*ProviderOptions) if !ok { - return nil, nil, fantasy.NewInvalidArgumentError("providerOptions", "anthropic provider options should be *anthropic.ProviderOptions", nil) + return nil, nil, &fantasy.Error{Title: "invalid argument", Message: "anthropic provider options should be *anthropic.ProviderOptions"} } } sendReasoning := true @@ -251,7 +251,7 @@ func (a languageModel) prepareParams(call fantasy.Call) (*anthropic.MessageNewPa } if isThinking { if thinkingBudget == 0 { - return nil, nil, fantasy.NewUnsupportedFunctionalityError("thinking requires budget", "") + return nil, nil, &fantasy.Error{Title: "no budget", Message: "thinking requires budget"} } params.Thinking = anthropic.ThinkingConfigParamOfEnabled(thinkingBudget) if call.Temperature != nil { @@ -700,16 +700,16 @@ func (a languageModel) handleError(err error) error { v := h[len(h)-1] headers[strings.ToLower(k)] = v } - return fantasy.NewAPICallError( - apiErr.Error(), - apiErr.Request.URL.String(), - string(requestDump), - apiErr.StatusCode, - headers, - string(responseDump), - apiErr, - false, - ) + return &fantasy.ProviderError{ + Title: "provider request failed", + Message: apiErr.Error(), + Cause: apiErr, + URL: apiErr.Request.URL.String(), + StatusCode: apiErr.StatusCode, + RequestBody: requestDump, + ResponseHeaders: headers, + ResponseBody: responseDump, + } } return err } diff --git a/providers/google/google.go b/providers/google/google.go index c59ffb3c8d2542bc546b98f592572bab5a8063fa..0434c9ea01db63452a6704cad1fe61093f56b4f9 100644 --- a/providers/google/google.go +++ b/providers/google/google.go @@ -197,7 +197,7 @@ func (g languageModel) prepareParams(call fantasy.Call) (*genai.GenerateContentC if v, ok := call.ProviderOptions[Name]; ok { providerOptions, ok = v.(*ProviderOptions) if !ok { - return nil, nil, nil, fantasy.NewInvalidArgumentError("providerOptions", "google provider options should be *google.ProviderOptions", nil) + return nil, nil, nil, &fantasy.Error{Title: "invalid argument", Message: "google provider options should be *google.ProviderOptions"} } } diff --git a/providers/openai/language_model.go b/providers/openai/language_model.go index 72d22808a27fce5b8c1688b94b248eb86bfdd114..d0b8212f6e00b2cb1e94d3dc23d42e00ef2b6ede 100644 --- a/providers/openai/language_model.go +++ b/providers/openai/language_model.go @@ -233,16 +233,16 @@ func (o languageModel) handleError(err error) error { v := h[len(h)-1] headers[strings.ToLower(k)] = v } - return fantasy.NewAPICallError( - apiErr.Message, - apiErr.Request.URL.String(), - string(requestDump), - apiErr.StatusCode, - headers, - string(responseDump), - apiErr, - false, - ) + return &fantasy.ProviderError{ + Title: "provider request failed", + Message: apiErr.Message, + Cause: apiErr, + URL: apiErr.Request.URL.String(), + StatusCode: apiErr.StatusCode, + RequestBody: requestDump, + ResponseHeaders: headers, + ResponseBody: responseDump, + } } return err } @@ -422,13 +422,13 @@ func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S // Does not exist var err error if toolCallDelta.Type != "function" { - err = fantasy.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.") + err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function' type."} } if toolCallDelta.ID == "" { - err = fantasy.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.") + err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'id' to be a string."} } if toolCallDelta.Function.Name == "" { - err = fantasy.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.") + err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function.name' to be a string."} } if err != nil { yield(fantasy.StreamPart{ diff --git a/providers/openai/language_model_hooks.go b/providers/openai/language_model_hooks.go index 576dcf46ebc8375ec0a74547ff7b5969c9fd692f..9ca2c64f520a1dbf01f95f46ac097689b2e6d045 100644 --- a/providers/openai/language_model_hooks.go +++ b/providers/openai/language_model_hooks.go @@ -45,7 +45,7 @@ func DefaultPrepareCallFunc(model fantasy.LanguageModel, params *openai.ChatComp if v, ok := call.ProviderOptions[Name]; ok { providerOptions, ok = v.(*ProviderOptions) if !ok { - return nil, fantasy.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil) + return nil, &fantasy.Error{Title: "invalid argument", Message: "openai provider options should be *openai.ProviderOptions"} } } diff --git a/providers/openai/responses_language_model.go b/providers/openai/responses_language_model.go index faa2f07c2c3b0b0862ad4da2135cdb7c19a126d0..2f69551011d14df85b29a0f89ad7ae06a10eec48 100644 --- a/providers/openai/responses_language_model.go +++ b/providers/openai/responses_language_model.go @@ -659,16 +659,16 @@ func (o responsesLanguageModel) handleError(err error) error { v := h[len(h)-1] headers[strings.ToLower(k)] = v } - return fantasy.NewAPICallError( - apiErr.Message, - apiErr.Request.URL.String(), - string(requestDump), - apiErr.StatusCode, - headers, - string(responseDump), - apiErr, - false, - ) + return &fantasy.ProviderError{ + Title: "provider request failed", + Message: apiErr.Message, + Cause: apiErr, + URL: apiErr.Request.URL.String(), + StatusCode: apiErr.StatusCode, + RequestBody: requestDump, + ResponseHeaders: headers, + ResponseBody: responseDump, + } } return err } diff --git a/providers/openaicompat/language_model_hooks.go b/providers/openaicompat/language_model_hooks.go index f9f16b33f7cc97d46d059b444b9edf07d5ed28e0..d3c5a8c6b5cdbc67c58825399754325131e323fe 100644 --- a/providers/openaicompat/language_model_hooks.go +++ b/providers/openaicompat/language_model_hooks.go @@ -19,7 +19,7 @@ func PrepareCallFunc(_ fantasy.LanguageModel, params *openaisdk.ChatCompletionNe if v, ok := call.ProviderOptions[Name]; ok { providerOptions, ok = v.(*ProviderOptions) if !ok { - return nil, fantasy.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil) + return nil, &fantasy.Error{Title: "invalid argument", Message: "openrouter provider options should be *openrouter.ProviderOptions"} } } @@ -86,7 +86,7 @@ func StreamExtraFunc(chunk openaisdk.ChatCompletionChunk, yield func(fantasy.Str if err != nil { yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeError, - Error: fantasy.NewAIError("error unmarshalling delta", err), + Error: &fantasy.Error{Title: "stream error", Message: "error unmarshalling delta", Cause: err}, }) return ctx, false } diff --git a/providers/openrouter/language_model_hooks.go b/providers/openrouter/language_model_hooks.go index 59a1935be43a96e297c9f590b6f844bba2388632..77cc839414d2aeb2b856a92a5b1b22d5b9946d66 100644 --- a/providers/openrouter/language_model_hooks.go +++ b/providers/openrouter/language_model_hooks.go @@ -21,7 +21,7 @@ func languagePrepareModelCall(_ fantasy.LanguageModel, params *openaisdk.ChatCom if v, ok := call.ProviderOptions[Name]; ok { providerOptions, ok = v.(*ProviderOptions) if !ok { - return nil, fantasy.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil) + return nil, &fantasy.Error{Title: "invalid argument", Message: "openrouter provider options should be *openrouter.ProviderOptions"} } } @@ -180,7 +180,7 @@ func languageModelStreamExtra(chunk openaisdk.ChatCompletionChunk, yield func(fa if err != nil { yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeError, - Error: fantasy.NewAIError("error unmarshalling delta", err), + Error: &fantasy.Error{Title: "stream error", Message: "error unmarshalling delta", Cause: err}, }) return ctx, false } diff --git a/retry.go b/retry.go index 4ef40d71361953837c718434559dc1a0df94c91e..6b2c2a412f728b589099f0fe8e41f4e1e93f78a8 100644 --- a/retry.go +++ b/retry.go @@ -3,7 +3,6 @@ package fantasy import ( "context" "errors" - "fmt" "strconv" "time" ) @@ -14,40 +13,14 @@ type RetryFn[T any] func() (T, error) // RetryFunction is a function that retries another function. type RetryFunction[T any] func(ctx context.Context, fn RetryFn[T]) (T, error) -// RetryReason represents the reason why a retry operation failed. -type RetryReason string - -const ( - // RetryReasonMaxRetriesExceeded indicates the maximum number of retries was exceeded. - RetryReasonMaxRetriesExceeded RetryReason = "maxRetriesExceeded" - // RetryReasonErrorNotRetryable indicates the error is not retryable. - RetryReasonErrorNotRetryable RetryReason = "errorNotRetryable" -) - -// RetryError represents an error that occurred during retry operations. -type RetryError struct { - *AIError - Reason RetryReason - Errors []error -} - -// NewRetryError creates a new retry error. -func NewRetryError(message string, reason RetryReason, errors []error) *RetryError { - return &RetryError{ - AIError: NewAIError(message, nil), - Reason: reason, - Errors: errors, - } -} - // getRetryDelayInMs calculates the retry delay based on error headers and exponential backoff. func getRetryDelayInMs(err error, exponentialBackoffDelay time.Duration) time.Duration { - var apiErr *APICallError - if !errors.As(err, &apiErr) || apiErr.ResponseHeaders == nil { + var providerErr *ProviderError + if !errors.As(err, &providerErr) || providerErr.ResponseHeaders == nil { return exponentialBackoffDelay } - headers := apiErr.ResponseHeaders + headers := providerErr.ResponseHeaders var ms time.Duration // retry-ms is more precise than retry-after and used by e.g. OpenAI @@ -101,7 +74,7 @@ type RetryOptions struct { } // OnRetryCallback defines a function that is called when a retry occurs. -type OnRetryCallback = func(err *APICallError, delay time.Duration) +type OnRetryCallback = func(err *ProviderError, delay time.Duration) // DefaultRetryOptions returns the default retry options. // DefaultRetryOptions returns the default retry options. @@ -129,23 +102,18 @@ func retryWithExponentialBackoff[T any](ctx context.Context, fn RetryFn[T], opti return zero, err // don't wrap the error when retries are disabled } - errorMessage := err.Error() newErrors := append(allErrors, err) tryNumber := len(newErrors) if tryNumber > options.MaxRetries { - return zero, NewRetryError( - fmt.Sprintf("Failed after %d attempts. Last error: %v", tryNumber, errorMessage), - RetryReasonMaxRetriesExceeded, - newErrors, - ) + return zero, &RetryError{newErrors} } - var apiErr *APICallError - if errors.As(err, &apiErr) && apiErr.IsRetryable && tryNumber <= options.MaxRetries { + var providerErr *ProviderError + if errors.As(err, &providerErr) && providerErr.IsRetryable() && tryNumber <= options.MaxRetries { delay := getRetryDelayInMs(err, options.InitialDelayIn) if options.OnRetry != nil { - options.OnRetry(apiErr, delay) + options.OnRetry(providerErr, delay) } select { @@ -165,9 +133,5 @@ func retryWithExponentialBackoff[T any](ctx context.Context, fn RetryFn[T], opti return zero, err // don't wrap the error when a non-retryable error occurs on the first try } - return zero, NewRetryError( - fmt.Sprintf("Failed after %d attempts with non-retryable error: %v", tryNumber, errorMessage), - RetryReasonErrorNotRetryable, - newErrors, - ) + return zero, &RetryError{newErrors} }