Detailed changes
@@ -971,7 +971,7 @@ func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []Agen
func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
if prompt == "" {
- return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
+ return nil, &Error{Title: "invalid argument", Message: "prompt can't be empty"}
}
var preparedPrompt Prompt
@@ -528,7 +528,7 @@ func TestAgent_Generate_EmptyPrompt(t *testing.T) {
require.Error(t, err)
require.Nil(t, result)
- require.Contains(t, err.Error(), "Prompt can't be empty")
+ require.Contains(t, err.Error(), "invalid argument: prompt can't be empty")
}
// Test with system prompt
@@ -1,128 +1,70 @@
package fantasy
import (
- "encoding/json"
- "errors"
"fmt"
+ "net/http"
+
+ "github.com/charmbracelet/x/exp/slice"
)
-// AIError is a custom error type for AI SDK related errors.
-type AIError struct {
+// Error is a custom error type for the fantasy package.
+type Error struct {
Message string
+ Title string
Cause error
}
-// Error implements the error interface.
-func (e *AIError) Error() string {
- return e.Message
-}
-
-// Unwrap returns the underlying cause of the error.
-func (e *AIError) Unwrap() error {
- return e.Cause
-}
-
-// NewAIError creates a new AI SDK Error.
-func NewAIError(message string, cause error) *AIError {
- return &AIError{
- Message: message,
- Cause: cause,
+func (err *Error) Error() string {
+ if err.Title == "" {
+ return err.Message
}
+ return fmt.Sprintf("%s: %s", err.Title, err.Message)
}
-// IsAIError checks if the given error is an AI SDK Error.
-func IsAIError(err error) bool {
- var sdkErr *AIError
- return errors.As(err, &sdkErr)
+func (err Error) Unwrap() error {
+ return err.Cause
}
-// APICallError represents an error from an API call.
-type APICallError struct {
- *AIError
+// ProviderError represents an error returned by an external provider.
+type ProviderError struct {
+ Message string
+ Title string
+ Cause error
+
URL string
- RequestDump string
StatusCode int
+ RequestBody []byte
ResponseHeaders map[string]string
- ResponseDump string
- IsRetryable bool
-}
-
-// NewAPICallError creates a new API call error.
-func NewAPICallError(message, url string, requestDump string, statusCode int, responseHeaders map[string]string, responseDump string, cause error, isRetryable bool) *APICallError {
- if !isRetryable && statusCode != 0 {
- isRetryable = statusCode == 408 || statusCode == 409 || statusCode == 429 || statusCode >= 500
- }
-
- return &APICallError{
- AIError: NewAIError(message, cause),
- URL: url,
- RequestDump: requestDump,
- StatusCode: statusCode,
- ResponseHeaders: responseHeaders,
- ResponseDump: responseDump,
- IsRetryable: isRetryable,
- }
-}
-
-// InvalidArgumentError represents an invalid function argument error.
-type InvalidArgumentError struct {
- *AIError
- Argument string
+ ResponseBody []byte
}
-// NewInvalidArgumentError creates a new invalid argument error.
-func NewInvalidArgumentError(argument, message string, cause error) *InvalidArgumentError {
- return &InvalidArgumentError{
- AIError: NewAIError(message, cause),
- Argument: argument,
+func (m *ProviderError) Error() string {
+ if m.Title == "" {
+ return m.Message
}
+ return fmt.Sprintf("%s: %s", m.Title, m.Message)
}
-// InvalidPromptError represents an invalid prompt error.
-type InvalidPromptError struct {
- *AIError
- Prompt any
+// IsRetryable checks if the error is retryable based on the status code.
+func (m *ProviderError) IsRetryable() bool {
+ return m.StatusCode == http.StatusRequestTimeout || m.StatusCode == http.StatusConflict || m.StatusCode == http.StatusTooManyRequests
}
-// NewInvalidPromptError creates a new invalid prompt error.
-func NewInvalidPromptError(prompt any, message string, cause error) *InvalidPromptError {
- return &InvalidPromptError{
- AIError: NewAIError(fmt.Sprintf("Invalid prompt: %s", message), cause),
- Prompt: prompt,
- }
+// RetryError represents an error that occurred during retry operations.
+type RetryError struct {
+ Errors []error
}
-// InvalidResponseDataError represents invalid response data from the server.
-type InvalidResponseDataError struct {
- *AIError
- Data any
-}
-
-// NewInvalidResponseDataError creates a new invalid response data error.
-func NewInvalidResponseDataError(data any, message string) *InvalidResponseDataError {
- if message == "" {
- dataJSON, _ := json.Marshal(data)
- message = fmt.Sprintf("Invalid response data: %s.", string(dataJSON))
- }
- return &InvalidResponseDataError{
- AIError: NewAIError(message, nil),
- Data: data,
+func (e *RetryError) Error() string {
+ if err, ok := slice.Last(e.Errors); ok {
+ return fmt.Sprintf("retry error: %v", err)
}
+ return "retry error: no underlying errors"
}
-// UnsupportedFunctionalityError represents an unsupported functionality error.
-type UnsupportedFunctionalityError struct {
- *AIError
- Functionality string
-}
-
-// NewUnsupportedFunctionalityError creates a new unsupported functionality error.
-func NewUnsupportedFunctionalityError(functionality, message string) *UnsupportedFunctionalityError {
- if message == "" {
- message = fmt.Sprintf("'%s' functionality not supported.", functionality)
- }
- return &UnsupportedFunctionalityError{
- AIError: NewAIError(message, nil),
- Functionality: functionality,
+func (e RetryError) Unwrap() error {
+ if err, ok := slice.Last(e.Errors); ok {
+ return err
}
+ return nil
}
@@ -202,7 +202,7 @@ func (a languageModel) prepareParams(call fantasy.Call) (*anthropic.MessageNewPa
if v, ok := call.ProviderOptions[Name]; ok {
providerOptions, ok = v.(*ProviderOptions)
if !ok {
- return nil, nil, fantasy.NewInvalidArgumentError("providerOptions", "anthropic provider options should be *anthropic.ProviderOptions", nil)
+ return nil, nil, &fantasy.Error{Title: "invalid argument", Message: "anthropic provider options should be *anthropic.ProviderOptions"}
}
}
sendReasoning := true
@@ -251,7 +251,7 @@ func (a languageModel) prepareParams(call fantasy.Call) (*anthropic.MessageNewPa
}
if isThinking {
if thinkingBudget == 0 {
- return nil, nil, fantasy.NewUnsupportedFunctionalityError("thinking requires budget", "")
+ return nil, nil, &fantasy.Error{Title: "no budget", Message: "thinking requires budget"}
}
params.Thinking = anthropic.ThinkingConfigParamOfEnabled(thinkingBudget)
if call.Temperature != nil {
@@ -700,16 +700,16 @@ func (a languageModel) handleError(err error) error {
v := h[len(h)-1]
headers[strings.ToLower(k)] = v
}
- return fantasy.NewAPICallError(
- apiErr.Error(),
- apiErr.Request.URL.String(),
- string(requestDump),
- apiErr.StatusCode,
- headers,
- string(responseDump),
- apiErr,
- false,
- )
+ return &fantasy.ProviderError{
+ Title: "provider request failed",
+ Message: apiErr.Error(),
+ Cause: apiErr,
+ URL: apiErr.Request.URL.String(),
+ StatusCode: apiErr.StatusCode,
+ RequestBody: requestDump,
+ ResponseHeaders: headers,
+ ResponseBody: responseDump,
+ }
}
return err
}
@@ -197,7 +197,7 @@ func (g languageModel) prepareParams(call fantasy.Call) (*genai.GenerateContentC
if v, ok := call.ProviderOptions[Name]; ok {
providerOptions, ok = v.(*ProviderOptions)
if !ok {
- return nil, nil, nil, fantasy.NewInvalidArgumentError("providerOptions", "google provider options should be *google.ProviderOptions", nil)
+ return nil, nil, nil, &fantasy.Error{Title: "invalid argument", Message: "google provider options should be *google.ProviderOptions"}
}
}
@@ -233,16 +233,16 @@ func (o languageModel) handleError(err error) error {
v := h[len(h)-1]
headers[strings.ToLower(k)] = v
}
- return fantasy.NewAPICallError(
- apiErr.Message,
- apiErr.Request.URL.String(),
- string(requestDump),
- apiErr.StatusCode,
- headers,
- string(responseDump),
- apiErr,
- false,
- )
+ return &fantasy.ProviderError{
+ Title: "provider request failed",
+ Message: apiErr.Message,
+ Cause: apiErr,
+ URL: apiErr.Request.URL.String(),
+ StatusCode: apiErr.StatusCode,
+ RequestBody: requestDump,
+ ResponseHeaders: headers,
+ ResponseBody: responseDump,
+ }
}
return err
}
@@ -422,13 +422,13 @@ func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S
// Does not exist
var err error
if toolCallDelta.Type != "function" {
- err = fantasy.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
+ err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function' type."}
}
if toolCallDelta.ID == "" {
- err = fantasy.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
+ err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'id' to be a string."}
}
if toolCallDelta.Function.Name == "" {
- err = fantasy.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
+ err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function.name' to be a string."}
}
if err != nil {
yield(fantasy.StreamPart{
@@ -45,7 +45,7 @@ func DefaultPrepareCallFunc(model fantasy.LanguageModel, params *openai.ChatComp
if v, ok := call.ProviderOptions[Name]; ok {
providerOptions, ok = v.(*ProviderOptions)
if !ok {
- return nil, fantasy.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
+ return nil, &fantasy.Error{Title: "invalid argument", Message: "openai provider options should be *openai.ProviderOptions"}
}
}
@@ -659,16 +659,16 @@ func (o responsesLanguageModel) handleError(err error) error {
v := h[len(h)-1]
headers[strings.ToLower(k)] = v
}
- return fantasy.NewAPICallError(
- apiErr.Message,
- apiErr.Request.URL.String(),
- string(requestDump),
- apiErr.StatusCode,
- headers,
- string(responseDump),
- apiErr,
- false,
- )
+ return &fantasy.ProviderError{
+ Title: "provider request failed",
+ Message: apiErr.Message,
+ Cause: apiErr,
+ URL: apiErr.Request.URL.String(),
+ StatusCode: apiErr.StatusCode,
+ RequestBody: requestDump,
+ ResponseHeaders: headers,
+ ResponseBody: responseDump,
+ }
}
return err
}
@@ -19,7 +19,7 @@ func PrepareCallFunc(_ fantasy.LanguageModel, params *openaisdk.ChatCompletionNe
if v, ok := call.ProviderOptions[Name]; ok {
providerOptions, ok = v.(*ProviderOptions)
if !ok {
- return nil, fantasy.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil)
+ return nil, &fantasy.Error{Title: "invalid argument", Message: "openrouter provider options should be *openrouter.ProviderOptions"}
}
}
@@ -86,7 +86,7 @@ func StreamExtraFunc(chunk openaisdk.ChatCompletionChunk, yield func(fantasy.Str
if err != nil {
yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeError,
- Error: fantasy.NewAIError("error unmarshalling delta", err),
+ Error: &fantasy.Error{Title: "stream error", Message: "error unmarshalling delta", Cause: err},
})
return ctx, false
}
@@ -21,7 +21,7 @@ func languagePrepareModelCall(_ fantasy.LanguageModel, params *openaisdk.ChatCom
if v, ok := call.ProviderOptions[Name]; ok {
providerOptions, ok = v.(*ProviderOptions)
if !ok {
- return nil, fantasy.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil)
+ return nil, &fantasy.Error{Title: "invalid argument", Message: "openrouter provider options should be *openrouter.ProviderOptions"}
}
}
@@ -180,7 +180,7 @@ func languageModelStreamExtra(chunk openaisdk.ChatCompletionChunk, yield func(fa
if err != nil {
yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeError,
- Error: fantasy.NewAIError("error unmarshalling delta", err),
+ Error: &fantasy.Error{Title: "stream error", Message: "error unmarshalling delta", Cause: err},
})
return ctx, false
}
@@ -3,7 +3,6 @@ package fantasy
import (
"context"
"errors"
- "fmt"
"strconv"
"time"
)
@@ -14,40 +13,14 @@ type RetryFn[T any] func() (T, error)
// RetryFunction is a function that retries another function.
type RetryFunction[T any] func(ctx context.Context, fn RetryFn[T]) (T, error)
-// RetryReason represents the reason why a retry operation failed.
-type RetryReason string
-
-const (
- // RetryReasonMaxRetriesExceeded indicates the maximum number of retries was exceeded.
- RetryReasonMaxRetriesExceeded RetryReason = "maxRetriesExceeded"
- // RetryReasonErrorNotRetryable indicates the error is not retryable.
- RetryReasonErrorNotRetryable RetryReason = "errorNotRetryable"
-)
-
-// RetryError represents an error that occurred during retry operations.
-type RetryError struct {
- *AIError
- Reason RetryReason
- Errors []error
-}
-
-// NewRetryError creates a new retry error.
-func NewRetryError(message string, reason RetryReason, errors []error) *RetryError {
- return &RetryError{
- AIError: NewAIError(message, nil),
- Reason: reason,
- Errors: errors,
- }
-}
-
// getRetryDelayInMs calculates the retry delay based on error headers and exponential backoff.
func getRetryDelayInMs(err error, exponentialBackoffDelay time.Duration) time.Duration {
- var apiErr *APICallError
- if !errors.As(err, &apiErr) || apiErr.ResponseHeaders == nil {
+ var providerErr *ProviderError
+ if !errors.As(err, &providerErr) || providerErr.ResponseHeaders == nil {
return exponentialBackoffDelay
}
- headers := apiErr.ResponseHeaders
+ headers := providerErr.ResponseHeaders
var ms time.Duration
// retry-ms is more precise than retry-after and used by e.g. OpenAI
@@ -101,7 +74,7 @@ type RetryOptions struct {
}
// OnRetryCallback defines a function that is called when a retry occurs.
-type OnRetryCallback = func(err *APICallError, delay time.Duration)
+type OnRetryCallback = func(err *ProviderError, delay time.Duration)
// DefaultRetryOptions returns the default retry options.
// DefaultRetryOptions returns the default retry options.
@@ -129,23 +102,18 @@ func retryWithExponentialBackoff[T any](ctx context.Context, fn RetryFn[T], opti
return zero, err // don't wrap the error when retries are disabled
}
- errorMessage := err.Error()
newErrors := append(allErrors, err)
tryNumber := len(newErrors)
if tryNumber > options.MaxRetries {
- return zero, NewRetryError(
- fmt.Sprintf("Failed after %d attempts. Last error: %v", tryNumber, errorMessage),
- RetryReasonMaxRetriesExceeded,
- newErrors,
- )
+ return zero, &RetryError{newErrors}
}
- var apiErr *APICallError
- if errors.As(err, &apiErr) && apiErr.IsRetryable && tryNumber <= options.MaxRetries {
+ var providerErr *ProviderError
+ if errors.As(err, &providerErr) && providerErr.IsRetryable() && tryNumber <= options.MaxRetries {
delay := getRetryDelayInMs(err, options.InitialDelayIn)
if options.OnRetry != nil {
- options.OnRetry(apiErr, delay)
+ options.OnRetry(providerErr, delay)
}
select {
@@ -165,9 +133,5 @@ func retryWithExponentialBackoff[T any](ctx context.Context, fn RetryFn[T], opti
return zero, err // don't wrap the error when a non-retryable error occurs on the first try
}
- return zero, NewRetryError(
- fmt.Sprintf("Failed after %d attempts with non-retryable error: %v", tryNumber, errorMessage),
- RetryReasonErrorNotRetryable,
- newErrors,
- )
+ return zero, &RetryError{newErrors}
}