diff --git a/agent.go b/agent.go index 19beb0ecb6e89f8f9be3a8eb224c78cb1b52cd3b..df087a8e4fc50df0676652789ceebc82afb165be 100644 --- a/agent.go +++ b/agent.go @@ -1029,7 +1029,7 @@ func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []Agen func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) { if prompt == "" { - return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil) + return nil, &Error{Title: "invalid argument", Message: "prompt can't be empty"} } var preparedPrompt Prompt diff --git a/agent_test.go b/agent_test.go index 2929e3c0c98100231f498a8084dc46247d57cb1f..dce488116c012e2be15034750cf87779a70671f7 100644 --- a/agent_test.go +++ b/agent_test.go @@ -528,7 +528,7 @@ func TestAgent_Generate_EmptyPrompt(t *testing.T) { require.Error(t, err) require.Nil(t, result) - require.Contains(t, err.Error(), "Prompt can't be empty") + require.Contains(t, err.Error(), "invalid argument: prompt can't be empty") } // Test with system prompt diff --git a/errors.go b/errors.go index c096ff2fdae84fdafab4546b26b6ffbec0040337..479b0c1a45eefe84bfcf11064865eb44b80f5b5f 100644 --- a/errors.go +++ b/errors.go @@ -1,298 +1,76 @@ package fantasy import ( - "encoding/json" - "errors" "fmt" -) + "net/http" + "strings" -// markerSymbol is used for identifying AI SDK Error instances. -var markerSymbol = "fantasy.error" + "github.com/charmbracelet/x/exp/slice" +) -// AIError is a custom error type for AI SDK related errors. -type AIError struct { - Name string +// Error is a custom error type for the fantasy package. +type Error struct { Message string + Title string Cause error - marker string -} - -// Error implements the error interface. -func (e *AIError) Error() string { - return e.Message -} - -// Unwrap returns the underlying cause of the error. -func (e *AIError) Unwrap() error { - return e.Cause } -// NewAIError creates a new AI SDK Error. -func NewAIError(name, message string, cause error) *AIError { - return &AIError{ - Name: name, - Message: message, - Cause: cause, - marker: markerSymbol, +func (err *Error) Error() string { + if err.Title == "" { + return err.Message } + return fmt.Sprintf("%s: %s", err.Title, err.Message) } -// IsAIError checks if the given error is an AI SDK Error. -func IsAIError(err error) bool { - var sdkErr *AIError - return errors.As(err, &sdkErr) && sdkErr.marker == markerSymbol +func (err Error) Unwrap() error { + return err.Cause } -// APICallError represents an error from an API call. -type APICallError struct { - *AIError +// ProviderError represents an error returned by an external provider. +type ProviderError struct { + Message string + Title string + Cause error + URL string - RequestDump string StatusCode int + RequestBody []byte ResponseHeaders map[string]string - ResponseDump string - IsRetryable bool + ResponseBody []byte } -// NewAPICallError creates a new API call error. -func NewAPICallError(message, url string, requestDump string, statusCode int, responseHeaders map[string]string, responseDump string, cause error, isRetryable bool) *APICallError { - if !isRetryable && statusCode != 0 { - isRetryable = statusCode == 408 || statusCode == 409 || statusCode == 429 || statusCode >= 500 +func (m *ProviderError) Error() string { + if m.Title == "" { + return m.Message } - - return &APICallError{ - AIError: NewAIError("AI_APICallError", message, cause), - URL: url, - RequestDump: requestDump, - StatusCode: statusCode, - ResponseHeaders: responseHeaders, - ResponseDump: responseDump, - IsRetryable: isRetryable, - } -} - -// EmptyResponseBodyError represents an empty response body error. -type EmptyResponseBodyError struct { - *AIError + return fmt.Sprintf("%s: %s", m.Title, m.Message) } -// NewEmptyResponseBodyError creates a new empty response body error. -func NewEmptyResponseBodyError(message string) *EmptyResponseBodyError { - if message == "" { - message = "Empty response body" - } - return &EmptyResponseBodyError{ - AIError: NewAIError("AI_EmptyResponseBodyError", message, nil), - } +// IsRetryable checks if the error is retryable based on the status code. +func (m *ProviderError) IsRetryable() bool { + return m.StatusCode == http.StatusRequestTimeout || m.StatusCode == http.StatusConflict || m.StatusCode == http.StatusTooManyRequests } -// InvalidArgumentError represents an invalid function argument error. -type InvalidArgumentError struct { - *AIError - Argument string +// RetryError represents an error that occurred during retry operations. +type RetryError struct { + Errors []error } -// NewInvalidArgumentError creates a new invalid argument error. -func NewInvalidArgumentError(argument, message string, cause error) *InvalidArgumentError { - return &InvalidArgumentError{ - AIError: NewAIError("AI_InvalidArgumentError", message, cause), - Argument: argument, +func (e *RetryError) Error() string { + if err, ok := slice.Last(e.Errors); ok { + return fmt.Sprintf("retry error: %v", err) } + return "retry error: no underlying errors" } -// InvalidPromptError represents an invalid prompt error. -type InvalidPromptError struct { - *AIError - Prompt any -} - -// NewInvalidPromptError creates a new invalid prompt error. -func NewInvalidPromptError(prompt any, message string, cause error) *InvalidPromptError { - return &InvalidPromptError{ - AIError: NewAIError("AI_InvalidPromptError", fmt.Sprintf("Invalid prompt: %s", message), cause), - Prompt: prompt, +func (e RetryError) Unwrap() error { + if err, ok := slice.Last(e.Errors); ok { + return err } + return nil } -// InvalidResponseDataError represents invalid response data from the server. -type InvalidResponseDataError struct { - *AIError - Data any -} - -// NewInvalidResponseDataError creates a new invalid response data error. -func NewInvalidResponseDataError(data any, message string) *InvalidResponseDataError { - if message == "" { - dataJSON, _ := json.Marshal(data) - message = fmt.Sprintf("Invalid response data: %s.", string(dataJSON)) - } - return &InvalidResponseDataError{ - AIError: NewAIError("AI_InvalidResponseDataError", message, nil), - Data: data, - } -} - -// JSONParseError represents a JSON parsing error. -type JSONParseError struct { - *AIError - Text string -} - -// NewJSONParseError creates a new JSON parse error. -func NewJSONParseError(text string, cause error) *JSONParseError { - message := fmt.Sprintf("JSON parsing failed: Text: %s.\nError message: %s", text, GetErrorMessage(cause)) - return &JSONParseError{ - AIError: NewAIError("AI_JSONParseError", message, cause), - Text: text, - } -} - -// LoadAPIKeyError represents an error loading an API key. -type LoadAPIKeyError struct { - *AIError -} - -// NewLoadAPIKeyError creates a new load API key error. -func NewLoadAPIKeyError(message string) *LoadAPIKeyError { - return &LoadAPIKeyError{ - AIError: NewAIError("AI_LoadAPIKeyError", message, nil), - } -} - -// LoadSettingError represents an error loading a setting. -type LoadSettingError struct { - *AIError -} - -// NewLoadSettingError creates a new load setting error. -func NewLoadSettingError(message string) *LoadSettingError { - return &LoadSettingError{ - AIError: NewAIError("AI_LoadSettingError", message, nil), - } -} - -// NoContentGeneratedError is thrown when the AI provider fails to generate any content. -type NoContentGeneratedError struct { - *AIError -} - -// NewNoContentGeneratedError creates a new no content generated error. -func NewNoContentGeneratedError(message string) *NoContentGeneratedError { - if message == "" { - message = "No content generated." - } - return &NoContentGeneratedError{ - AIError: NewAIError("AI_NoContentGeneratedError", message, nil), - } -} - -// ModelType represents the type of model. -type ModelType string - -const ( - // ModelTypeLanguage represents a language model. - ModelTypeLanguage ModelType = "languageModel" - // ModelTypeTextEmbedding represents a text embedding model. - ModelTypeTextEmbedding ModelType = "textEmbeddingModel" - // ModelTypeImage represents an image model. - ModelTypeImage ModelType = "imageModel" - // ModelTypeTranscription represents a transcription model. - ModelTypeTranscription ModelType = "transcriptionModel" - // ModelTypeSpeech represents a speech model. - ModelTypeSpeech ModelType = "speechModel" -) - -// NoSuchModelError represents an error when a model is not found. -type NoSuchModelError struct { - *AIError - ModelID string - ModelType ModelType -} - -// NewNoSuchModelError creates a new no such model error. -func NewNoSuchModelError(modelID string, modelType ModelType, message string) *NoSuchModelError { - if message == "" { - message = fmt.Sprintf("No such %s: %s", modelType, modelID) - } - return &NoSuchModelError{ - AIError: NewAIError("AI_NoSuchModelError", message, nil), - ModelID: modelID, - ModelType: modelType, - } -} - -// TooManyEmbeddingValuesForCallError represents an error when too many values are provided for embedding. -type TooManyEmbeddingValuesForCallError struct { - *AIError - Provider string - ModelID string - MaxEmbeddingsPerCall int - Values []any -} - -// NewTooManyEmbeddingValuesForCallError creates a new too many embedding values error. -func NewTooManyEmbeddingValuesForCallError(provider, modelID string, maxEmbeddingsPerCall int, values []any) *TooManyEmbeddingValuesForCallError { - message := fmt.Sprintf( - "Too many values for a single embedding call. The %s model \"%s\" can only embed up to %d values per call, but %d values were provided.", - provider, modelID, maxEmbeddingsPerCall, len(values), - ) - return &TooManyEmbeddingValuesForCallError{ - AIError: NewAIError("AI_TooManyEmbeddingValuesForCallError", message, nil), - Provider: provider, - ModelID: modelID, - MaxEmbeddingsPerCall: maxEmbeddingsPerCall, - Values: values, - } -} - -// TypeValidationError represents a type validation error. -type TypeValidationError struct { - *AIError - Value any -} - -// NewTypeValidationError creates a new type validation error. -func NewTypeValidationError(value any, cause error) *TypeValidationError { - valueJSON, _ := json.Marshal(value) - message := fmt.Sprintf( - "Type validation failed: Value: %s.\nError message: %s", - string(valueJSON), GetErrorMessage(cause), - ) - return &TypeValidationError{ - AIError: NewAIError("AI_TypeValidationError", message, cause), - Value: value, - } -} - -// WrapTypeValidationError wraps an error into a TypeValidationError. -func WrapTypeValidationError(value any, cause error) *TypeValidationError { - if tvErr, ok := cause.(*TypeValidationError); ok && tvErr.Value == value { - return tvErr - } - return NewTypeValidationError(value, cause) -} - -// UnsupportedFunctionalityError represents an unsupported functionality error. -type UnsupportedFunctionalityError struct { - *AIError - Functionality string -} - -// NewUnsupportedFunctionalityError creates a new unsupported functionality error. -func NewUnsupportedFunctionalityError(functionality, message string) *UnsupportedFunctionalityError { - if message == "" { - message = fmt.Sprintf("'%s' functionality not supported.", functionality) - } - return &UnsupportedFunctionalityError{ - AIError: NewAIError("AI_UnsupportedFunctionalityError", message, nil), - Functionality: functionality, - } -} - -// GetErrorMessage extracts a message from an error. -func GetErrorMessage(err error) string { - if err == nil { - return "unknown error" - } - return err.Error() +// ErrorTitleForStatusCode returns a human-readable title for a given HTTP status code. +func ErrorTitleForStatusCode(statusCode int) string { + return strings.ToLower(http.StatusText(statusCode)) } diff --git a/providers/anthropic/anthropic.go b/providers/anthropic/anthropic.go index 42f8d72f1039cabf9c85341120655b4a24125fa3..86c0e03c96e7a10cac7145620c9e6461c6a035f2 100644 --- a/providers/anthropic/anthropic.go +++ b/providers/anthropic/anthropic.go @@ -122,6 +122,8 @@ func WithHTTPClient(client option.HTTPClient) Option { func (a *provider) LanguageModel(ctx context.Context, modelID string) (fantasy.LanguageModel, error) { clientOptions := make([]option.RequestOption, 0, 5+len(a.options.headers)) + clientOptions = append(clientOptions, option.WithMaxRetries(0)) + if a.options.apiKey != "" && !a.options.useBedrock { clientOptions = append(clientOptions, option.WithAPIKey(a.options.apiKey)) } @@ -202,7 +204,7 @@ func (a languageModel) prepareParams(call fantasy.Call) (*anthropic.MessageNewPa if v, ok := call.ProviderOptions[Name]; ok { providerOptions, ok = v.(*ProviderOptions) if !ok { - return nil, nil, fantasy.NewInvalidArgumentError("providerOptions", "anthropic provider options should be *anthropic.ProviderOptions", nil) + return nil, nil, &fantasy.Error{Title: "invalid argument", Message: "anthropic provider options should be *anthropic.ProviderOptions"} } } sendReasoning := true @@ -251,7 +253,7 @@ func (a languageModel) prepareParams(call fantasy.Call) (*anthropic.MessageNewPa } if isThinking { if thinkingBudget == 0 { - return nil, nil, fantasy.NewUnsupportedFunctionalityError("thinking requires budget", "") + return nil, nil, &fantasy.Error{Title: "no budget", Message: "thinking requires budget"} } params.Thinking = anthropic.ThinkingConfigParamOfEnabled(thinkingBudget) if call.Temperature != nil { @@ -690,30 +692,6 @@ func toPrompt(prompt fantasy.Prompt, sendReasoningData bool) ([]anthropic.TextBl return systemBlocks, messages, warnings } -func (a languageModel) handleError(err error) error { - var apiErr *anthropic.Error - if errors.As(err, &apiErr) { - requestDump := apiErr.DumpRequest(true) - responseDump := apiErr.DumpResponse(true) - headers := map[string]string{} - for k, h := range apiErr.Response.Header { - v := h[len(h)-1] - headers[strings.ToLower(k)] = v - } - return fantasy.NewAPICallError( - apiErr.Error(), - apiErr.Request.URL.String(), - string(requestDump), - apiErr.StatusCode, - headers, - string(responseDump), - apiErr, - false, - ) - } - return err -} - func mapFinishReason(finishReason string) fantasy.FinishReason { switch finishReason { case "end_turn", "pause_turn", "stop_sequence": @@ -735,7 +713,7 @@ func (a languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantas } response, err := a.client.Messages.New(ctx, *params) if err != nil { - return nil, a.handleError(err) + return nil, toProviderErr(err) } var content []fantasy.Content @@ -968,7 +946,7 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S } else { //nolint: revive yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeError, - Error: a.handleError(err), + Error: toProviderErr(err), }) return } diff --git a/providers/anthropic/error.go b/providers/anthropic/error.go new file mode 100644 index 0000000000000000000000000000000000000000..38393b550a76562b9268d25bda3499b38c6ba236 --- /dev/null +++ b/providers/anthropic/error.go @@ -0,0 +1,39 @@ +package anthropic + +import ( + "cmp" + "errors" + "net/http" + "strings" + + "charm.land/fantasy" + "github.com/charmbracelet/anthropic-sdk-go" +) + +func toProviderErr(err error) error { + var apiErr *anthropic.Error + if errors.As(err, &apiErr) { + return &fantasy.ProviderError{ + Title: cmp.Or(fantasy.ErrorTitleForStatusCode(apiErr.StatusCode), "provider request failed"), + Message: apiErr.Error(), + Cause: apiErr, + URL: apiErr.Request.URL.String(), + StatusCode: apiErr.StatusCode, + RequestBody: apiErr.DumpRequest(true), + ResponseHeaders: toHeaderMap(apiErr.Response.Header), + ResponseBody: apiErr.DumpResponse(true), + } + } + return err +} + +func toHeaderMap(in http.Header) (out map[string]string) { + out = make(map[string]string, len(in)) + for k, v := range in { + if l := len(v); l > 0 { + out[k] = v[l-1] + in[strings.ToLower(k)] = v + } + } + return out +} diff --git a/providers/google/error.go b/providers/google/error.go new file mode 100644 index 0000000000000000000000000000000000000000..710cffff61217f0c69bf525f5efb7083b26c3e3d --- /dev/null +++ b/providers/google/error.go @@ -0,0 +1,23 @@ +package google + +import ( + "cmp" + "errors" + + "charm.land/fantasy" + "google.golang.org/genai" +) + +func toProviderErr(err error) error { + var apiErr genai.APIError + if !errors.As(err, &apiErr) { + return err + } + return &fantasy.ProviderError{ + Message: apiErr.Message, + Title: cmp.Or(fantasy.ErrorTitleForStatusCode(apiErr.Code), "provider request failed"), + Cause: err, + StatusCode: apiErr.Code, + ResponseBody: []byte(apiErr.Message), + } +} diff --git a/providers/google/google.go b/providers/google/google.go index c59ffb3c8d2542bc546b98f592572bab5a8063fa..c64cc379cb60467236cdad53923b1f3dfcc0fc0b 100644 --- a/providers/google/google.go +++ b/providers/google/google.go @@ -197,7 +197,7 @@ func (g languageModel) prepareParams(call fantasy.Call) (*genai.GenerateContentC if v, ok := call.ProviderOptions[Name]; ok { providerOptions, ok = v.(*ProviderOptions) if !ok { - return nil, nil, nil, fantasy.NewInvalidArgumentError("providerOptions", "google provider options should be *google.ProviderOptions", nil) + return nil, nil, nil, &fantasy.Error{Title: "invalid argument", Message: "google provider options should be *google.ProviderOptions"} } } @@ -514,7 +514,7 @@ func (g *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fanta response, err := chat.SendMessage(ctx, depointerSlice(lastMessage.Parts)...) if err != nil { - return nil, err + return nil, toProviderErr(err) } return g.mapResponse(response, warnings) @@ -571,7 +571,7 @@ func (g *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy. if err != nil { yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeError, - Error: err, + Error: toProviderErr(err), }) return } diff --git a/providers/openai/error.go b/providers/openai/error.go new file mode 100644 index 0000000000000000000000000000000000000000..fed072088b117f4c0e8cd28c38c3ee8c3bd73584 --- /dev/null +++ b/providers/openai/error.go @@ -0,0 +1,52 @@ +package openai + +import ( + "cmp" + "errors" + "io" + "net/http" + "strings" + + "charm.land/fantasy" + "github.com/openai/openai-go/v2" +) + +func toProviderErr(err error) error { + var apiErr *openai.Error + if errors.As(err, &apiErr) { + return &fantasy.ProviderError{ + Title: cmp.Or(fantasy.ErrorTitleForStatusCode(apiErr.StatusCode), "provider request failed"), + Message: toProviderErrMessage(apiErr), + Cause: apiErr, + URL: apiErr.Request.URL.String(), + StatusCode: apiErr.StatusCode, + RequestBody: apiErr.DumpRequest(true), + ResponseHeaders: toHeaderMap(apiErr.Response.Header), + ResponseBody: apiErr.DumpResponse(true), + } + } + return err +} + +func toProviderErrMessage(apiErr *openai.Error) string { + if apiErr.Message != "" { + return apiErr.Message + } + + // For some OpenAI-compatible providers, the SDK is not always able to parse + // the error message correctly. + // Fallback to returning the raw response body in such cases. + data, _ := io.ReadAll(apiErr.Response.Body) + return string(data) +} + +func toHeaderMap(in http.Header) (out map[string]string) { + out = make(map[string]string, len(in)) + for k, v := range in { + if l := len(v); l > 0 { + out[k] = v[l-1] + in[strings.ToLower(k)] = v + } + } + return out +} diff --git a/providers/openai/language_model.go b/providers/openai/language_model.go index 72d22808a27fce5b8c1688b94b248eb86bfdd114..ef85d2ccf0c35dd46bf09bb2cb26b768a43b725a 100644 --- a/providers/openai/language_model.go +++ b/providers/openai/language_model.go @@ -223,30 +223,6 @@ func (o languageModel) prepareParams(call fantasy.Call) (*openai.ChatCompletionN return params, warnings, nil } -func (o languageModel) handleError(err error) error { - var apiErr *openai.Error - if errors.As(err, &apiErr) { - requestDump := apiErr.DumpRequest(true) - responseDump := apiErr.DumpResponse(true) - headers := map[string]string{} - for k, h := range apiErr.Response.Header { - v := h[len(h)-1] - headers[strings.ToLower(k)] = v - } - return fantasy.NewAPICallError( - apiErr.Message, - apiErr.Request.URL.String(), - string(requestDump), - apiErr.StatusCode, - headers, - string(responseDump), - apiErr, - false, - ) - } - return err -} - // Generate implements fantasy.LanguageModel. func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { params, warnings, err := o.prepareParams(call) @@ -255,11 +231,11 @@ func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantas } response, err := o.client.Chat.Completions.New(ctx, *params) if err != nil { - return nil, o.handleError(err) + return nil, toProviderErr(err) } if len(response.Choices) == 0 { - return nil, errors.New("no response generated") + return nil, &fantasy.Error{Title: "no response", Message: "no response generated"} } choice := response.Choices[0] content := make([]fantasy.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations)) @@ -422,18 +398,18 @@ func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S // Does not exist var err error if toolCallDelta.Type != "function" { - err = fantasy.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.") + err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function' type."} } if toolCallDelta.ID == "" { - err = fantasy.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.") + err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'id' to be a string."} } if toolCallDelta.Function.Name == "" { - err = fantasy.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.") + err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function.name' to be a string."} } if err != nil { yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeError, - Error: o.handleError(stream.Err()), + Error: toProviderErr(stream.Err()), }) return } @@ -563,7 +539,7 @@ func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S } else { //nolint: revive yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeError, - Error: o.handleError(err), + Error: toProviderErr(err), }) return } diff --git a/providers/openai/language_model_hooks.go b/providers/openai/language_model_hooks.go index 576dcf46ebc8375ec0a74547ff7b5969c9fd692f..9ca2c64f520a1dbf01f95f46ac097689b2e6d045 100644 --- a/providers/openai/language_model_hooks.go +++ b/providers/openai/language_model_hooks.go @@ -45,7 +45,7 @@ func DefaultPrepareCallFunc(model fantasy.LanguageModel, params *openai.ChatComp if v, ok := call.ProviderOptions[Name]; ok { providerOptions, ok = v.(*ProviderOptions) if !ok { - return nil, fantasy.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil) + return nil, &fantasy.Error{Title: "invalid argument", Message: "openai provider options should be *openai.ProviderOptions"} } } diff --git a/providers/openai/openai.go b/providers/openai/openai.go index f3f20e2751ed1629afb41b169e3dc483ccb9ae57..7f0e8f14d3d95b57ea91f080af50d8fa364d8a27 100644 --- a/providers/openai/openai.go +++ b/providers/openai/openai.go @@ -134,6 +134,7 @@ func WithUseResponsesAPI() Option { // LanguageModel implements fantasy.Provider. func (o *provider) LanguageModel(_ context.Context, modelID string) (fantasy.LanguageModel, error) { openaiClientOptions := make([]option.RequestOption, 0, 5+len(o.options.headers)+len(o.options.sdkOptions)) + openaiClientOptions = append(openaiClientOptions, option.WithMaxRetries(0)) if o.options.apiKey != "" { openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey)) diff --git a/providers/openai/responses_language_model.go b/providers/openai/responses_language_model.go index faa2f07c2c3b0b0862ad4da2135cdb7c19a126d0..9d90cf78bc7f4776599cd71b982d489227d59d2c 100644 --- a/providers/openai/responses_language_model.go +++ b/providers/openai/responses_language_model.go @@ -4,7 +4,6 @@ import ( "context" "encoding/base64" "encoding/json" - "errors" "fmt" "strings" @@ -649,39 +648,18 @@ func toResponsesTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice, opti return openaiTools, openaiToolChoice, warnings } -func (o responsesLanguageModel) handleError(err error) error { - var apiErr *openai.Error - if errors.As(err, &apiErr) { - requestDump := apiErr.DumpRequest(true) - responseDump := apiErr.DumpResponse(true) - headers := map[string]string{} - for k, h := range apiErr.Response.Header { - v := h[len(h)-1] - headers[strings.ToLower(k)] = v - } - return fantasy.NewAPICallError( - apiErr.Message, - apiErr.Request.URL.String(), - string(requestDump), - apiErr.StatusCode, - headers, - string(responseDump), - apiErr, - false, - ) - } - return err -} - func (o responsesLanguageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { params, warnings := o.prepareParams(call) response, err := o.client.Responses.New(ctx, *params) if err != nil { - return nil, o.handleError(err) + return nil, toProviderErr(err) } if response.Error.Message != "" { - return nil, o.handleError(fmt.Errorf("response error: %s (code: %s)", response.Error.Message, response.Error.Code)) + return nil, &fantasy.Error{ + Title: "provider error", + Message: fmt.Sprintf("%s (code: %s)", response.Error.Message, response.Error.Code), + } } var content []fantasy.Content @@ -1023,7 +1001,7 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call fantasy.Call) ( if err != nil { yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeError, - Error: o.handleError(err), + Error: toProviderErr(err), }) return } diff --git a/providers/openaicompat/language_model_hooks.go b/providers/openaicompat/language_model_hooks.go index 5fa06818531331feb18b90d88f1c1948b1b961e8..d3c5a8c6b5cdbc67c58825399754325131e323fe 100644 --- a/providers/openaicompat/language_model_hooks.go +++ b/providers/openaicompat/language_model_hooks.go @@ -19,7 +19,7 @@ func PrepareCallFunc(_ fantasy.LanguageModel, params *openaisdk.ChatCompletionNe if v, ok := call.ProviderOptions[Name]; ok { providerOptions, ok = v.(*ProviderOptions) if !ok { - return nil, fantasy.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil) + return nil, &fantasy.Error{Title: "invalid argument", Message: "openrouter provider options should be *openrouter.ProviderOptions"} } } @@ -86,7 +86,7 @@ func StreamExtraFunc(chunk openaisdk.ChatCompletionChunk, yield func(fantasy.Str if err != nil { yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeError, - Error: fantasy.NewAIError("Unexpected", "error unmarshalling delta", err), + Error: &fantasy.Error{Title: "stream error", Message: "error unmarshalling delta", Cause: err}, }) return ctx, false } diff --git a/providers/openrouter/language_model_hooks.go b/providers/openrouter/language_model_hooks.go index 60488360bd2a7d07175c2a2a073df4171ebf850f..77cc839414d2aeb2b856a92a5b1b22d5b9946d66 100644 --- a/providers/openrouter/language_model_hooks.go +++ b/providers/openrouter/language_model_hooks.go @@ -21,7 +21,7 @@ func languagePrepareModelCall(_ fantasy.LanguageModel, params *openaisdk.ChatCom if v, ok := call.ProviderOptions[Name]; ok { providerOptions, ok = v.(*ProviderOptions) if !ok { - return nil, fantasy.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil) + return nil, &fantasy.Error{Title: "invalid argument", Message: "openrouter provider options should be *openrouter.ProviderOptions"} } } @@ -180,7 +180,7 @@ func languageModelStreamExtra(chunk openaisdk.ChatCompletionChunk, yield func(fa if err != nil { yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeError, - Error: fantasy.NewAIError("Unexpected", "error unmarshalling delta", err), + Error: &fantasy.Error{Title: "stream error", Message: "error unmarshalling delta", Cause: err}, }) return ctx, false } diff --git a/retry.go b/retry.go index 23581fd2286ea7e676b51316e42a7514bfd26789..6b2c2a412f728b589099f0fe8e41f4e1e93f78a8 100644 --- a/retry.go +++ b/retry.go @@ -3,7 +3,6 @@ package fantasy import ( "context" "errors" - "fmt" "strconv" "time" ) @@ -14,40 +13,14 @@ type RetryFn[T any] func() (T, error) // RetryFunction is a function that retries another function. type RetryFunction[T any] func(ctx context.Context, fn RetryFn[T]) (T, error) -// RetryReason represents the reason why a retry operation failed. -type RetryReason string - -const ( - // RetryReasonMaxRetriesExceeded indicates the maximum number of retries was exceeded. - RetryReasonMaxRetriesExceeded RetryReason = "maxRetriesExceeded" - // RetryReasonErrorNotRetryable indicates the error is not retryable. - RetryReasonErrorNotRetryable RetryReason = "errorNotRetryable" -) - -// RetryError represents an error that occurred during retry operations. -type RetryError struct { - *AIError - Reason RetryReason - Errors []error -} - -// NewRetryError creates a new retry error. -func NewRetryError(message string, reason RetryReason, errors []error) *RetryError { - return &RetryError{ - AIError: NewAIError("AI_RetryError", message, nil), - Reason: reason, - Errors: errors, - } -} - // getRetryDelayInMs calculates the retry delay based on error headers and exponential backoff. func getRetryDelayInMs(err error, exponentialBackoffDelay time.Duration) time.Duration { - var apiErr *APICallError - if !errors.As(err, &apiErr) || apiErr.ResponseHeaders == nil { + var providerErr *ProviderError + if !errors.As(err, &providerErr) || providerErr.ResponseHeaders == nil { return exponentialBackoffDelay } - headers := apiErr.ResponseHeaders + headers := providerErr.ResponseHeaders var ms time.Duration // retry-ms is more precise than retry-after and used by e.g. OpenAI @@ -101,7 +74,7 @@ type RetryOptions struct { } // OnRetryCallback defines a function that is called when a retry occurs. -type OnRetryCallback = func(err *APICallError, delay time.Duration) +type OnRetryCallback = func(err *ProviderError, delay time.Duration) // DefaultRetryOptions returns the default retry options. // DefaultRetryOptions returns the default retry options. @@ -129,23 +102,18 @@ func retryWithExponentialBackoff[T any](ctx context.Context, fn RetryFn[T], opti return zero, err // don't wrap the error when retries are disabled } - errorMessage := GetErrorMessage(err) newErrors := append(allErrors, err) tryNumber := len(newErrors) if tryNumber > options.MaxRetries { - return zero, NewRetryError( - fmt.Sprintf("Failed after %d attempts. Last error: %s", tryNumber, errorMessage), - RetryReasonMaxRetriesExceeded, - newErrors, - ) + return zero, &RetryError{newErrors} } - var apiErr *APICallError - if errors.As(err, &apiErr) && apiErr.IsRetryable && tryNumber <= options.MaxRetries { + var providerErr *ProviderError + if errors.As(err, &providerErr) && providerErr.IsRetryable() && tryNumber <= options.MaxRetries { delay := getRetryDelayInMs(err, options.InitialDelayIn) if options.OnRetry != nil { - options.OnRetry(apiErr, delay) + options.OnRetry(providerErr, delay) } select { @@ -165,9 +133,5 @@ func retryWithExponentialBackoff[T any](ctx context.Context, fn RetryFn[T], opti return zero, err // don't wrap the error when a non-retryable error occurs on the first try } - return zero, NewRetryError( - fmt.Sprintf("Failed after %d attempts with non-retryable error: '%s'", tryNumber, errorMessage), - RetryReasonErrorNotRetryable, - newErrors, - ) + return zero, &RetryError{newErrors} }