From fac77962e90277d561ba1f58207756e1ad144805 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Thu, 21 Aug 2025 21:53:48 +0200 Subject: [PATCH] feat: initial agent implementation still need to setup streaming --- agent.go | 552 ++++++++++++++++++++++++++++--- content.go | 26 +- errors.go | 30 +- provider.go | 20 -- providers/examples/agent/main.go | 72 ++++ providers/openai.go | 30 +- providers/openai_test.go | 3 +- retry.go | 170 ++++++++++ tool.go | 234 +++++++++++++ util.go | 21 ++ 10 files changed, 1067 insertions(+), 91 deletions(-) create mode 100644 providers/examples/agent/main.go create mode 100644 retry.go create mode 100644 tool.go create mode 100644 util.go diff --git a/agent.go b/agent.go index bcc63e8285c1dead6c31536f6f8a10d1e6649415..4d43855a7985b420168c2f97af90babf8b378615 100644 --- a/agent.go +++ b/agent.go @@ -2,45 +2,96 @@ package ai import ( "context" + "errors" + "maps" + "slices" + "sync" + + "github.com/charmbracelet/crush/internal/llm/tools" ) -type StepResponse struct { +type StepResult struct { Response // Messages generated during this step Messages []Message } -type StepCondition = func(steps []StepResponse) bool +type StopCondition = func(steps []StepResult) bool type PrepareStepFunctionOptions struct { - Steps []StepResponse + Steps []StepResult StepNumber int Model LanguageModel Messages []Message } type PrepareStepResult struct { - SystemPrompt string - Model LanguageModel - Messages []Message + Model LanguageModel + Messages []Message } -type PrepareStepFunction = func(options PrepareStepFunctionOptions) PrepareStepResult - -type OnStepFinishedFunction = func(step StepResponse) +type ( + PrepareStepFunction = func(options PrepareStepFunctionOptions) PrepareStepResult + OnStepFinishedFunction = func(step StepResult) + RepairToolCall = func(ToolCallContent) ToolCallContent +) type AgentSettings struct { - Call - Model LanguageModel + systemPrompt string + maxOutputTokens *int64 + temperature *float64 + topP *float64 + topK *int64 + 
presencePenalty *float64 + frequencyPenalty *float64 + headers map[string]string + providerOptions ProviderOptions + + // TODO: add support for provider tools + tools []tools.BaseTool + maxRetries *int + + model LanguageModel + + stopWhen []StopCondition + prepareStep PrepareStepFunction + repairToolCall RepairToolCall + onStepFinished OnStepFinishedFunction + onRetry OnRetryCallback +} - StopWhen []StepCondition +type AgentCall struct { + Prompt string `json:"prompt"` + Files []FilePart `json:"files"` + Messages []Message `json:"messages"` + MaxOutputTokens *int64 + Temperature *float64 `json:"temperature"` + TopP *float64 `json:"top_p"` + TopK *int64 `json:"top_k"` + PresencePenalty *float64 `json:"presence_penalty"` + FrequencyPenalty *float64 `json:"frequency_penalty"` + ActiveTools []string `json:"active_tools"` + Headers map[string]string + ProviderOptions ProviderOptions + OnRetry OnRetryCallback + MaxRetries *int + + StopWhen []StopCondition PrepareStep PrepareStepFunction + RepairToolCall RepairToolCall OnStepFinished OnStepFinishedFunction } +type AgentResult struct { + Steps []StepResult + // Final response + Response Response + TotalUsage Usage +} + type Agent interface { - Generate(context.Context, Call) (*Response, error) - Stream(context.Context, Call) (StreamResponse, error) + Generate(context.Context, AgentCall) (*AgentResult, error) + Stream(context.Context, AgentCall) (StreamResponse, error) } type agentOption = func(*AgentSettings) @@ -51,7 +102,7 @@ type agent struct { func NewAgent(model LanguageModel, opts ...agentOption) Agent { settings := AgentSettings{ - Model: model, + model: model, } for _, o := range opts { o(&settings) @@ -61,48 +112,465 @@ func NewAgent(model LanguageModel, opts ...agentOption) Agent { } } -func mergeCall(agentOpts, opts Call) Call { - if len(opts.Prompt) > 0 { - agentOpts.Prompt = opts.Prompt +func (a *agent) prepareCall(call AgentCall) AgentCall { + if call.MaxOutputTokens == nil && a.settings.maxOutputTokens != 
nil { + call.MaxOutputTokens = a.settings.maxOutputTokens + } + if call.Temperature == nil && a.settings.temperature != nil { + call.Temperature = a.settings.temperature } - if opts.MaxOutputTokens != nil { - agentOpts.MaxOutputTokens = opts.MaxOutputTokens + if call.TopP == nil && a.settings.topP != nil { + call.TopP = a.settings.topP } - if opts.Temperature != nil { - agentOpts.Temperature = opts.Temperature + if call.TopK == nil && a.settings.topK != nil { + call.TopK = a.settings.topK } - if opts.TopP != nil { - agentOpts.TopP = opts.TopP + if call.PresencePenalty == nil && a.settings.presencePenalty != nil { + call.PresencePenalty = a.settings.presencePenalty } - if opts.TopK != nil { - agentOpts.TopK = opts.TopK + if call.FrequencyPenalty == nil && a.settings.frequencyPenalty != nil { + call.FrequencyPenalty = a.settings.frequencyPenalty } - if opts.PresencePenalty != nil { - agentOpts.PresencePenalty = opts.PresencePenalty + if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 { + call.StopWhen = a.settings.stopWhen } - if opts.FrequencyPenalty != nil { - agentOpts.FrequencyPenalty = opts.FrequencyPenalty + if call.PrepareStep == nil && a.settings.prepareStep != nil { + call.PrepareStep = a.settings.prepareStep } - if opts.Tools != nil { - agentOpts.Tools = opts.Tools + if call.RepairToolCall == nil && a.settings.repairToolCall != nil { + call.RepairToolCall = a.settings.repairToolCall } - if opts.Headers != nil { - agentOpts.Headers = opts.Headers + if call.OnStepFinished == nil && a.settings.onStepFinished != nil { + call.OnStepFinished = a.settings.onStepFinished } - if opts.ProviderOptions != nil { - agentOpts.ProviderOptions = opts.ProviderOptions + if call.OnRetry == nil && a.settings.onRetry != nil { + call.OnRetry = a.settings.onRetry + } + if call.MaxRetries == nil && a.settings.maxRetries != nil { + call.MaxRetries = a.settings.maxRetries + } + + providerOptions := ProviderOptions{} + if a.settings.providerOptions != nil { + 
maps.Copy(providerOptions, a.settings.providerOptions) } - return agentOpts + if call.ProviderOptions != nil { + maps.Copy(providerOptions, call.ProviderOptions) + } + call.ProviderOptions = providerOptions + + headers := map[string]string{} + + if a.settings.headers != nil { + maps.Copy(headers, a.settings.headers) + } + + if call.Headers != nil { + maps.Copy(headers, call.Headers) + } + call.Headers = headers + return call } // Generate implements Agent. -func (a *agent) Generate(ctx context.Context, opts Call) (*Response, error) { - // TODO: implement the agentic stuff - return a.settings.Model.Generate(ctx, mergeCall(a.settings.Call, opts)) +func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) { + opts = a.prepareCall(opts) + initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...) + if err != nil { + return nil, err + } + var responseMessages []Message + var steps []StepResult + + for { + stepInputMessages := append(initialPrompt, responseMessages...) 
+ stepModel := a.settings.model + if opts.PrepareStep != nil { + prepared := opts.PrepareStep(PrepareStepFunctionOptions{ + Model: stepModel, + Steps: steps, + StepNumber: len(steps), + Messages: stepInputMessages, + }) + stepInputMessages = prepared.Messages + if prepared.Model != nil { + stepModel = prepared.Model + } + } + + preparedTools := a.prepareTools(a.settings.tools, opts.ActiveTools) + + toolChoice := ToolChoiceAuto + retryOptions := DefaultRetryOptions() + retryOptions.OnRetry = opts.OnRetry + retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions) + + result, err := retry(ctx, func() (*Response, error) { + return stepModel.Generate(ctx, Call{ + Prompt: stepInputMessages, + MaxOutputTokens: opts.MaxOutputTokens, + Temperature: opts.Temperature, + TopP: opts.TopP, + TopK: opts.TopK, + PresencePenalty: opts.PresencePenalty, + FrequencyPenalty: opts.FrequencyPenalty, + Tools: preparedTools, + ToolChoice: &toolChoice, + Headers: opts.Headers, + ProviderOptions: opts.ProviderOptions, + }) + }) + if err != nil { + return nil, err + } + + var stepToolCalls []ToolCallContent + for _, content := range result.Content { + if content.GetType() == ContentTypeToolCall { + toolCall, ok := AsContentType[ToolCallContent](content) + if !ok { + continue + } + stepToolCalls = append(stepToolCalls, toolCall) + } + } + + toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls) + + stepContent := result.Content + for _, result := range toolResults { + stepContent = append(stepContent, result) + } + currentStepMessages := toResponseMessages(stepContent) + responseMessages = append(responseMessages, currentStepMessages...) 
+ + stepResult := StepResult{ + Response: *result, + Messages: currentStepMessages, + } + steps = append(steps, stepResult) + if opts.OnStepFinished != nil { + opts.OnStepFinished(stepResult) + } + + shouldStop := isStopConditionMet(opts.StopWhen, steps) + + if shouldStop || err != nil || len(stepToolCalls) == 0 { + break + } + } + + totalUsage := Usage{} + + for _, step := range steps { + usage := step.Usage + totalUsage.InputTokens += usage.InputTokens + totalUsage.OutputTokens += usage.OutputTokens + totalUsage.ReasoningTokens += usage.ReasoningTokens + totalUsage.CacheCreationTokens += usage.CacheCreationTokens + totalUsage.CacheReadTokens += usage.CacheReadTokens + totalUsage.TotalTokens += usage.TotalTokens // add this step's total, not the running sum + } + + agentResult := &AgentResult{ + Steps: steps, + Response: steps[len(steps)-1].Response, + TotalUsage: totalUsage, + } + return agentResult, nil +} + +func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool { + if len(conditions) == 0 { + return false + } + + for _, condition := range conditions { + if condition(steps) { + return true + } + } + return false +} + +func toResponseMessages(content []Content) []Message { + var assistantParts []MessagePart + var toolParts []MessagePart + + for _, c := range content { + switch c.GetType() { + case ContentTypeText: + text, ok := AsContentType[TextContent](c) + if !ok { + continue + } + assistantParts = append(assistantParts, TextPart{ + Text: text.Text, + ProviderOptions: ProviderOptions(text.ProviderMetadata), + }) + case ContentTypeReasoning: + reasoning, ok := AsContentType[ReasoningContent](c) + if !ok { + continue + } + assistantParts = append(assistantParts, ReasoningPart{ + Text: reasoning.Text, + ProviderOptions: ProviderOptions(reasoning.ProviderMetadata), + }) + case ContentTypeToolCall: + toolCall, ok := AsContentType[ToolCallContent](c) + if !ok { + continue + } + assistantParts = append(assistantParts, ToolCallPart{ + ToolCallID: toolCall.ToolCallID, + ToolName:
toolCall.ToolName, + Input: toolCall.Input, + ProviderExecuted: toolCall.ProviderExecuted, + ProviderOptions: ProviderOptions(toolCall.ProviderMetadata), + }) + case ContentTypeFile: + file, ok := AsContentType[FileContent](c) + if !ok { + continue + } + assistantParts = append(assistantParts, FilePart{ + Data: file.Data, + MediaType: file.MediaType, + ProviderOptions: ProviderOptions(file.ProviderMetadata), + }) + case ContentTypeToolResult: + result, ok := AsContentType[ToolResultContent](c) + if !ok { + continue + } + toolParts = append(toolParts, ToolResultPart{ + ToolCallID: result.ToolCallID, + Output: result.Result, + ProviderOptions: ProviderOptions(result.ProviderMetadata), + }) + } + } + + var messages []Message + if len(assistantParts) > 0 { + messages = append(messages, Message{ + Role: MessageRoleAssistant, + Content: assistantParts, + }) + } + if len(toolParts) > 0 { + messages = append(messages, Message{ + Role: MessageRoleTool, + Content: toolParts, + }) + } + return messages +} + +func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, toolCalls []ToolCallContent) ([]ToolResultContent, error) { + if len(toolCalls) == 0 { + return nil, nil + } + + // Create a map for quick tool lookup + toolMap := make(map[string]tools.BaseTool) + for _, tool := range allTools { + toolMap[tool.Info().Name] = tool + } + + // Execute all tool calls in parallel + results := make([]ToolResultContent, len(toolCalls)) + var toolExecutionError error + var wg sync.WaitGroup + + for i, toolCall := range toolCalls { + wg.Add(1) + go func(index int, call ToolCallContent) { + defer wg.Done() + + tool, exists := toolMap[call.ToolName] + if !exists { + results[index] = ToolResultContent{ + ToolCallID: call.ToolCallID, + ToolName: call.ToolName, + Result: ToolResultOutputContentError{ + Error: errors.New("Error: Tool not found: " + call.ToolName), + }, + ProviderExecuted: false, + } + return + } + + // Execute the tool + result, err := tool.Run(ctx, 
tools.ToolCall{ + ID: call.ToolCallID, + Name: call.ToolName, + Input: call.Input, + }) + if err != nil { + results[index] = ToolResultContent{ + ToolCallID: call.ToolCallID, + ToolName: call.ToolName, + Result: ToolResultOutputContentError{ + Error: err, + }, + ProviderExecuted: false, + } + toolExecutionError = err + return + } + + if result.IsError { + results[index] = ToolResultContent{ + ToolCallID: call.ToolCallID, + ToolName: call.ToolName, + Result: ToolResultOutputContentError{ + Error: errors.New(result.Content), + }, + ProviderExecuted: false, + } + } else { + results[index] = ToolResultContent{ + ToolCallID: call.ToolCallID, + ToolName: call.ToolName, // use the goroutine's own argument, not the captured loop variable + Result: ToolResultOutputContentText{ + Text: result.Content, + }, + ProviderExecuted: false, + } + } + }(i, toolCall) + } + + // Wait for all tool executions to complete + wg.Wait() + + return results, toolExecutionError } // Stream implements Agent. -func (a *agent) Stream(ctx context.Context, opts Call) (StreamResponse, error) { +func (a *agent) Stream(ctx context.Context, opts AgentCall) (StreamResponse, error) { // TODO: implement the agentic stuff - return a.settings.Model.Stream(ctx, mergeCall(a.settings.Call, opts)) + panic("not implemented") +} + +func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string) []Tool { + var preparedTools []Tool + for _, tool := range tools { + if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) { + continue + } + info := tool.Info() + preparedTools = append(preparedTools, FunctionTool{ + Name: info.Name, + Description: info.Description, + InputSchema: map[string]any{ + "type": "object", + "properties": info.Parameters, + "required": info.Required, + }, + }) + } + return preparedTools +} + +func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) { + if prompt == "" { + return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil) + } + + var preparedPrompt Prompt +
+ + if system != "" { + preparedPrompt = append(preparedPrompt, NewSystemMessage(system)) + } + + preparedPrompt = append(preparedPrompt, messages...) // prior conversation first, so the new user turn comes last + preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...)) + return preparedPrompt, nil +} + +func WithSystemPrompt(prompt string) agentOption { + return func(s *AgentSettings) { + s.systemPrompt = prompt + } +} + +func WithMaxOutputTokens(tokens int64) agentOption { + return func(s *AgentSettings) { + s.maxOutputTokens = &tokens + } +} + +func WithTemperature(temp float64) agentOption { + return func(s *AgentSettings) { + s.temperature = &temp + } +} + +func WithTopP(topP float64) agentOption { + return func(s *AgentSettings) { + s.topP = &topP + } +} + +func WithTopK(topK int64) agentOption { + return func(s *AgentSettings) { + s.topK = &topK + } +} + +func WithPresencePenalty(penalty float64) agentOption { + return func(s *AgentSettings) { + s.presencePenalty = &penalty + } +} + +func WithFrequencyPenalty(penalty float64) agentOption { + return func(s *AgentSettings) { + s.frequencyPenalty = &penalty + } +} + +func WithTools(tools ...tools.BaseTool) agentOption { + return func(s *AgentSettings) { + s.tools = append(s.tools, tools...) + } +} + +func WithStopConditions(conditions ...StopCondition) agentOption { + return func(s *AgentSettings) { + s.stopWhen = append(s.stopWhen, conditions...)
+ } +} + +func WithPrepareStep(fn PrepareStepFunction) agentOption { + return func(s *AgentSettings) { + s.prepareStep = fn + } +} + +func WithRepairToolCall(fn RepairToolCall) agentOption { + return func(s *AgentSettings) { + s.repairToolCall = fn + } +} + +func WithOnStepFinished(fn OnStepFinishedFunction) agentOption { + return func(s *AgentSettings) { + s.onStepFinished = fn + } +} + +func WithHeaders(headers map[string]string) agentOption { + return func(s *AgentSettings) { + s.headers = headers + } +} + +func WithProviderOptions(providerOptions ProviderOptions) agentOption { + return func(s *AgentSettings) { + s.providerOptions = providerOptions + } } diff --git a/content.go b/content.go index 094a0e6e592fb4ce4aea69cb72995c1b3c0e0f7a..0359df767c80e3d20c5cc3a1d5884154a8a6b85c 100644 --- a/content.go +++ b/content.go @@ -184,7 +184,7 @@ func (t ToolResultOutputContentText) GetType() ToolResultContentType { } type ToolResultOutputContentError struct { - Error string `json:"error"` + Error error `json:"error"` } func (t ToolResultOutputContentError) GetType() ToolResultContentType { @@ -268,11 +268,9 @@ type FileContent struct { // The IANA media type of the file, e.g. `image/png` or `audio/mp3`. // @see https://www.iana.org/assignments/media-types/media-types.xhtml MediaType string `json:"media_type"` - // Generated file data as base64 encoded strings or binary data. - // If the API returns base64 encoded strings, the file data should be returned - // as base64 encoded strings. If the API returns binary data, the file data should - // be returned as binary data. - Data any `json:"data"` // string (base64) or []byte + // Generated file data as binary data. + Data []byte `json:"data"` + ProviderMetadata ProviderMetadata `json:"provider_metadata"` } // GetType returns the type of the file content. @@ -332,9 +330,7 @@ type ToolResultContent struct { // Name of the tool that generated this result. ToolName string `json:"tool_name"` // Result of the tool call. 
This is a JSON-serializable object. - Result any `json:"result"` - // Optional flag if the result is an error or an error message. - IsError bool `json:"is_error"` + Result ToolResultOutputContent `json:"result"` // Whether the tool result was generated by the provider. // If this flag is set to true, the tool result was generated by the provider. // If this flag is not set or is false, the tool result was generated by the client. @@ -430,3 +426,15 @@ func NewUserMessage(prompt string, files ...FilePart) Message { Content: content, } } + +func NewSystemMessage(prompt ...string) Message { + var content []MessagePart + for _, p := range prompt { + content = append(content, TextPart{Text: p}) + } + + return Message{ + Role: MessageRoleSystem, + Content: content, + } +} diff --git a/errors.go b/errors.go index 8cbd9748be0656d3d4d1a57553f6b4517f4a1610..293f4e1f6abfd3c5b710a53cc6089094bb05f324 100644 --- a/errors.go +++ b/errors.go @@ -46,30 +46,28 @@ func IsAIError(err error) bool { // APICallError represents an error from an API call. type APICallError struct { *AIError - URL string - RequestBodyValues any - StatusCode int - ResponseHeaders map[string]string - ResponseBody string - IsRetryable bool - Data any + URL string + RequestDump string + StatusCode int + ResponseHeaders map[string]string + ResponseDump string + IsRetryable bool } // NewAPICallError creates a new API call error. 
-func NewAPICallError(message, url string, requestBodyValues any, statusCode int, responseHeaders map[string]string, responseBody string, cause error, isRetryable bool, data any) *APICallError { +func NewAPICallError(message, url string, requestDump string, statusCode int, responseHeaders map[string]string, responseDump string, cause error, isRetryable bool) *APICallError { if !isRetryable && statusCode != 0 { isRetryable = statusCode == 408 || statusCode == 409 || statusCode == 429 || statusCode >= 500 } return &APICallError{ - AIError: NewAIError("AI_APICallError", message, cause), - URL: url, - RequestBodyValues: requestBodyValues, - StatusCode: statusCode, - ResponseHeaders: responseHeaders, - ResponseBody: responseBody, - IsRetryable: isRetryable, - Data: data, + AIError: NewAIError("AI_APICallError", message, cause), + URL: url, + RequestDump: requestDump, + StatusCode: statusCode, + ResponseHeaders: responseHeaders, + ResponseDump: responseDump, + IsRetryable: isRetryable, } } diff --git a/provider.go b/provider.go index 877326150d1efe3c5515c79f162d0cbfc871f4f7..b6d0e16a09c7c78646c4273cd9690f23f4199ab8 100644 --- a/provider.go +++ b/provider.go @@ -1,26 +1,6 @@ package ai -import ( - "encoding/json" - - "github.com/go-viper/mapstructure/v2" -) - type Provider interface { LanguageModel(modelID string) LanguageModel // TODO: add other model types when needed } - -func ParseOptions[T any](options map[string]any, m *T) error { - return mapstructure.Decode(options, m) -} - -func FloatOption(f float64) *float64 { - return &f -} - -func IsParsableJSON(data string) bool { - var m map[string]any - err := json.Unmarshal([]byte(data), &m) - return err == nil -} diff --git a/providers/examples/agent/main.go b/providers/examples/agent/main.go new file mode 100644 index 0000000000000000000000000000000000000000..7076a6fd3b458c892a6dec5e40531b67a0866f22 --- /dev/null +++ b/providers/examples/agent/main.go @@ -0,0 +1,72 @@ +package main + +import ( + "context" + "fmt" + "os" 
+ + "github.com/charmbracelet/crush/internal/ai" + "github.com/charmbracelet/crush/internal/ai/providers" + "github.com/charmbracelet/crush/internal/llm/tools" +) + +type weatherTool struct{} + +// Info implements tools.BaseTool. +func (w *weatherTool) Info() tools.ToolInfo { + return tools.ToolInfo{ + Name: "weather", + Parameters: map[string]any{ + "location": map[string]string{ + "type": "string", + "description": "the city", + }, + }, + Required: []string{"location"}, + } +} + +// Name implements tools.BaseTool. +func (w *weatherTool) Name() string { + return "weather" +} + +// Run implements tools.BaseTool. +func (w *weatherTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) { + return tools.NewTextResponse("40 C"), nil +} + +func newWeatherTool() tools.BaseTool { + return &weatherTool{} +} + +func main() { + provider := providers.NewOpenAIProvider( + providers.WithOpenAIApiKey(os.Getenv("OPENAI_API_KEY")), + ) + model := provider.LanguageModel("gpt-4o") + + agent := ai.NewAgent( + model, + ai.WithSystemPrompt("You are a helpful assistant"), + ai.WithTools(newWeatherTool()), + ) + + result, _ := agent.Generate(context.Background(), ai.AgentCall{ + Prompt: "What's the weather in pristina", + }) + + fmt.Println("Steps: ", len(result.Steps)) + for _, s := range result.Steps { + for _, c := range s.Content { + if c.GetType() == ai.ContentTypeToolCall { + tc, _ := ai.AsContentType[ai.ToolCallContent](c) + fmt.Println("ToolCall: ", tc.ToolName) + + } + } + } + + fmt.Println("Final Response: ", result.Response.Content.Text()) + fmt.Println("Total Usage: ", result.TotalUsage) +} diff --git a/providers/openai.go b/providers/openai.go index dcfcc9d64872ad87dba1591d881dae26c4ad79ff..d6ff2ac9fa78d6c21244f151df199c352cc6e18c 100644 --- a/providers/openai.go +++ b/providers/openai.go @@ -394,6 +394,30 @@ func (o openAILanguageModel) prepareParams(call ai.Call) (*openai.ChatCompletion return params, warnings, nil } +func (o openAILanguageModel) 
handleError(err error) error { + var apiErr *openai.Error + if errors.As(err, &apiErr) { + requestDump := apiErr.DumpRequest(true) + responseDump := apiErr.DumpResponse(true) + headers := map[string]string{} + for k, h := range apiErr.Response.Header { + v := h[len(h)-1] + headers[strings.ToLower(k)] = v + } + return ai.NewAPICallError( + apiErr.Message, + apiErr.Request.URL.String(), + string(requestDump), + apiErr.StatusCode, + headers, + string(responseDump), + apiErr, + false, + ) + } + return err +} + // Generate implements ai.LanguageModel. func (o openAILanguageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) { params, warnings, err := o.prepareParams(call) @@ -402,7 +426,7 @@ func (o openAILanguageModel) Generate(ctx context.Context, call ai.Call) (*ai.Re } response, err := o.client.Chat.Completions.New(ctx, *params) if err != nil { - return nil, err + return nil, o.handleError(err) } if len(response.Choices) == 0 { @@ -626,7 +650,7 @@ func (o openAILanguageModel) Stream(ctx context.Context, call ai.Call) (ai.Strea if err != nil { yield(ai.StreamPart{ Type: ai.StreamPartTypeError, - Error: stream.Err(), + Error: o.handleError(stream.Err()), }) return } @@ -1097,7 +1121,7 @@ func toOpenAIPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, }) continue } - messages = append(messages, openai.ToolMessage(output.Error, toolResultPart.ToolCallID)) + messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID)) } } } diff --git a/providers/openai_test.go b/providers/openai_test.go index 1ad9479239654a9ae9a57d44152e63ba33c0e6ad..1b52399f40884bd40a45e71a14a3276cc68eb497 100644 --- a/providers/openai_test.go +++ b/providers/openai_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "encoding/json" + "errors" "net/http" "net/http/httptest" "strings" @@ -496,7 +497,7 @@ func TestToOpenAIPrompt_ToolCalls(t *testing.T) { ai.ToolResultPart{ ToolCallID: "error-tool", Output: 
ai.ToolResultOutputContentError{ - Error: "Something went wrong", + Error: errors.New("Something went wrong"), }, }, }, diff --git a/retry.go b/retry.go new file mode 100644 index 0000000000000000000000000000000000000000..71d209c231579144acfa9645a76998d61132afbe --- /dev/null +++ b/retry.go @@ -0,0 +1,170 @@ +package ai + +import ( + "context" + "errors" + "fmt" + "strconv" + "time" +) + +// RetryFn is a function that returns a value and an error. +type RetryFn[T any] func() (T, error) + +// RetryFunction is a function that retries another function. +type RetryFunction[T any] func(ctx context.Context, fn RetryFn[T]) (T, error) + +// RetryReason represents the reason why a retry operation failed. +type RetryReason string + +const ( + RetryReasonMaxRetriesExceeded RetryReason = "maxRetriesExceeded" + RetryReasonErrorNotRetryable RetryReason = "errorNotRetryable" +) + +// RetryError represents an error that occurred during retry operations. +type RetryError struct { + *AIError + Reason RetryReason + Errors []error +} + +// NewRetryError creates a new retry error. +func NewRetryError(message string, reason RetryReason, errors []error) *RetryError { + return &RetryError{ + AIError: NewAIError("AI_RetryError", message, nil), + Reason: reason, + Errors: errors, + } +} + +// getRetryDelayInMs calculates the retry delay based on error headers and exponential backoff. +func getRetryDelayInMs(err error, exponentialBackoffDelay time.Duration) time.Duration { + var apiErr *APICallError + if !errors.As(err, &apiErr) || apiErr.ResponseHeaders == nil { + return exponentialBackoffDelay + } + + headers := apiErr.ResponseHeaders + var ms time.Duration + + // retry-ms is more precise than retry-after and used by e.g. 
OpenAI + if retryAfterMs, exists := headers["retry-after-ms"]; exists { + if timeoutMs, err := strconv.ParseFloat(retryAfterMs, 64); err == nil { + ms = time.Duration(timeoutMs * float64(time.Millisecond)) // scale in float so fractional values are not truncated + } + } + + // About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After + if retryAfter, exists := headers["retry-after"]; exists && ms == 0 { + if timeoutSeconds, err := strconv.ParseFloat(retryAfter, 64); err == nil { + ms = time.Duration(timeoutSeconds * float64(time.Second)) // scale in float so fractional values are not truncated + } else { + // Try parsing as HTTP date + if t, err := time.Parse(time.RFC1123, retryAfter); err == nil { + ms = time.Until(t) + } + } + } + + // Check that the delay is reasonable: + // 0 < ms < 60 seconds or ms < exponentialBackoffDelay + if ms > 0 && (ms < 60*time.Second || ms < exponentialBackoffDelay) { + return ms + } + + return exponentialBackoffDelay +} + +// isAbortError checks if the error is a context cancellation error. +func isAbortError(err error) bool { + return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) +} + +// RetryWithExponentialBackoffRespectingRetryHeaders creates a retry function that retries +// a failed operation with exponential backoff, while respecting rate limit headers +// (retry-after-ms and retry-after) if they are provided and reasonable (0-60 seconds). +func RetryWithExponentialBackoffRespectingRetryHeaders[T any](options RetryOptions) RetryFunction[T] { + return func(ctx context.Context, fn RetryFn[T]) (T, error) { + return retryWithExponentialBackoff(ctx, fn, options, nil) + } +} + +// RetryOptions configures the retry behavior. +type RetryOptions struct { + MaxRetries int + InitialDelayIn time.Duration + BackoffFactor float64 + OnRetry OnRetryCallback +} + +type OnRetryCallback = func(err *APICallError, delay time.Duration) + +// DefaultRetryOptions returns the default retry options.
+func DefaultRetryOptions() RetryOptions { + return RetryOptions{ + MaxRetries: 2, + InitialDelayIn: 2000 * time.Millisecond, + BackoffFactor: 2.0, + } +} + +// retryWithExponentialBackoff implements the retry logic with exponential backoff. +func retryWithExponentialBackoff[T any](ctx context.Context, fn RetryFn[T], options RetryOptions, allErrors []error) (T, error) { + var zero T + result, err := fn() + if err == nil { + return result, nil + } + + if isAbortError(err) { + return zero, err // don't retry when the request was aborted + } + + if options.MaxRetries == 0 { + return zero, err // don't wrap the error when retries are disabled + } + + errorMessage := GetErrorMessage(err) + newErrors := append(allErrors, err) + tryNumber := len(newErrors) + + if tryNumber > options.MaxRetries { + return zero, NewRetryError( + fmt.Sprintf("Failed after %d attempts. Last error: %s", tryNumber, errorMessage), + RetryReasonMaxRetriesExceeded, + newErrors, + ) + } + + var apiErr *APICallError + if errors.As(err, &apiErr) && apiErr.IsRetryable && tryNumber <= options.MaxRetries { + delay := getRetryDelayInMs(err, options.InitialDelayIn) + if options.OnRetry != nil { + options.OnRetry(apiErr, delay) + } + + select { + case <-time.After(delay): + // Continue with retry + case <-ctx.Done(): + return zero, ctx.Err() + } + + newOptions := options + newOptions.InitialDelayIn = time.Duration(float64(options.InitialDelayIn) * options.BackoffFactor) + + return retryWithExponentialBackoff(ctx, fn, newOptions, newErrors) + } + + if tryNumber == 1 { + return zero, err // don't wrap the error when a non-retryable error occurs on the first try + } + + return zero, NewRetryError( + fmt.Sprintf("Failed after %d attempts with non-retryable error: '%s'", tryNumber, errorMessage), + RetryReasonErrorNotRetryable, + newErrors, + ) +} + diff --git a/tool.go b/tool.go new file mode 100644 index 0000000000000000000000000000000000000000..b0c7c518c0852cf78f8d3687a284c49a552ff220 --- /dev/null +++ 
b/tool.go @@ -0,0 +1,234 @@ +// WIP NEED TO REVISIT +package ai + +import ( + "context" + "encoding/json" + "fmt" +) + +// AgentTool represents a function that can be called by a language model. +type AgentTool interface { + Name() string + Description() string + InputSchema() Schema + Execute(ctx context.Context, input json.RawMessage) (json.RawMessage, error) +} + +// Schema represents a JSON schema for tool input validation. +type Schema struct { + Type string `json:"type"` + Properties map[string]*Schema `json:"properties,omitempty"` + Required []string `json:"required,omitempty"` + Items *Schema `json:"items,omitempty"` + Description string `json:"description,omitempty"` + Enum []any `json:"enum,omitempty"` + Format string `json:"format,omitempty"` + Minimum *float64 `json:"minimum,omitempty"` + Maximum *float64 `json:"maximum,omitempty"` + MinLength *int `json:"minLength,omitempty"` + MaxLength *int `json:"maxLength,omitempty"` +} + +// BasicTool provides a basic implementation of the Tool interface +// +// Example usage: +// +// calculator := &tools.BasicTool{ +// ToolName: "calculate", +// ToolDescription: "Evaluates mathematical expressions", +// ToolInputSchema: tools.Schema{ +// Type: "object", +// Properties: map[string]*tools.Schema{ +// "expression": { +// Type: "string", +// Description: "Mathematical expression to evaluate", +// }, +// }, +// Required: []string{"expression"}, +// }, +// ExecuteFunc: func(ctx context.Context, input json.RawMessage) (json.RawMessage, error) { +// var req struct { +// Expression string `json:"expression"` +// } +// if err := json.Unmarshal(input, &req); err != nil { +// return nil, err +// } +// result := evaluateExpression(req.Expression) +// return json.Marshal(map[string]any{"result": result}) +// }, +// } +type BasicTool struct { + ToolName string + ToolDescription string + ToolInputSchema Schema + ExecuteFunc func(context.Context, json.RawMessage) (json.RawMessage, error) +} + +// Name returns the tool name. 
+func (t *BasicTool) Name() string { + return t.ToolName +} + +// Description returns the tool description. +func (t *BasicTool) Description() string { + return t.ToolDescription +} + +// InputSchema returns the tool input schema. +func (t *BasicTool) InputSchema() Schema { + return t.ToolInputSchema +} + +// Execute executes the tool with the given input. +func (t *BasicTool) Execute(ctx context.Context, input json.RawMessage) (json.RawMessage, error) { + if t.ExecuteFunc == nil { + return nil, fmt.Errorf("tool %s has no execute function", t.ToolName) + } + return t.ExecuteFunc(ctx, input) +} + +// ToolBuilder provides a fluent interface for building tools. +type ToolBuilder struct { + tool *BasicTool +} + +// NewTool creates a new tool builder. +func NewTool(name string) *ToolBuilder { + return &ToolBuilder{ + tool: &BasicTool{ + ToolName: name, + }, + } +} + +// Description sets the tool description. +func (b *ToolBuilder) Description(desc string) *ToolBuilder { + b.tool.ToolDescription = desc + return b +} + +// InputSchema sets the tool input schema. +func (b *ToolBuilder) InputSchema(schema Schema) *ToolBuilder { + b.tool.ToolInputSchema = schema + return b +} + +// Execute sets the tool execution function. +func (b *ToolBuilder) Execute(fn func(context.Context, json.RawMessage) (json.RawMessage, error)) *ToolBuilder { + b.tool.ExecuteFunc = fn + return b +} + +// Build creates the final tool. +func (b *ToolBuilder) Build() AgentTool { + return b.tool +} + +// SchemaBuilder provides a fluent interface for building JSON schemas. +type SchemaBuilder struct { + schema Schema +} + +// NewSchema creates a new schema builder. +func NewSchema(schemaType string) *SchemaBuilder { + return &SchemaBuilder{ + schema: Schema{ + Type: schemaType, + }, + } +} + +// Object creates a schema builder for an object type. +func Object() *SchemaBuilder { + return NewSchema("object") +} + +// String creates a schema builder for a string type. 
+func String() *SchemaBuilder { + return NewSchema("string") +} + +// Number creates a schema builder for a number type. +func Number() *SchemaBuilder { + return NewSchema("number") +} + +// Array creates a schema builder for an array type. +func Array() *SchemaBuilder { + return NewSchema("array") +} + +// Description sets the schema description. +func (b *SchemaBuilder) Description(desc string) *SchemaBuilder { + b.schema.Description = desc + return b +} + +// Properties sets the schema properties. +func (b *SchemaBuilder) Properties(props map[string]*Schema) *SchemaBuilder { + b.schema.Properties = props + return b +} + +// Property adds a property to the schema. +func (b *SchemaBuilder) Property(name string, schema *Schema) *SchemaBuilder { + if b.schema.Properties == nil { + b.schema.Properties = make(map[string]*Schema) + } + b.schema.Properties[name] = schema + return b +} + +// Required marks fields as required. +func (b *SchemaBuilder) Required(fields ...string) *SchemaBuilder { + b.schema.Required = append(b.schema.Required, fields...) + return b +} + +// Items sets the schema for array items. +func (b *SchemaBuilder) Items(schema *Schema) *SchemaBuilder { + b.schema.Items = schema + return b +} + +// Enum sets allowed values for the schema. +func (b *SchemaBuilder) Enum(values ...any) *SchemaBuilder { + b.schema.Enum = values + return b +} + +// Format sets the string format. +func (b *SchemaBuilder) Format(format string) *SchemaBuilder { + b.schema.Format = format + return b +} + +// Min sets the minimum value. +func (b *SchemaBuilder) Min(minimum float64) *SchemaBuilder { + b.schema.Minimum = &minimum + return b +} + +// Max sets the maximum value. +func (b *SchemaBuilder) Max(maximum float64) *SchemaBuilder { + b.schema.Maximum = &maximum + return b +} + +// MinLength sets the minimum string length. 
+func (b *SchemaBuilder) MinLength(minimum int) *SchemaBuilder { + b.schema.MinLength = &minimum + return b +} + +// MaxLength sets the maximum string length and returns the builder for chaining. +func (b *SchemaBuilder) MaxLength(maximum int) *SchemaBuilder { + b.schema.MaxLength = &maximum + return b +} + +// Build returns a pointer to the schema held inside the builder (not a copy), +// so builder calls made after Build are visible through the returned pointer. +func (b *SchemaBuilder) Build() *Schema { + return &b.schema +} diff --git a/util.go b/util.go new file mode 100644 index 0000000000000000000000000000000000000000..6f0012d66d132a5810a58c7f8f8bede59cb41956 --- /dev/null +++ b/util.go @@ -0,0 +1,21 @@ +package ai + +import ( + "encoding/json" + + "github.com/go-viper/mapstructure/v2" +) + +// ParseOptions decodes the options map into the value m points to by calling +// mapstructure.Decode, and returns that call's error. +func ParseOptions[T any](options map[string]any, m *T) error { + return mapstructure.Decode(options, m) +} + +// FloatOption returns a pointer to a copy of f. +func FloatOption(f float64) *float64 { + return &f +} + +// IsParsableJSON reports whether data can be unmarshaled into a map[string]any. +// NOTE(review): valid JSON whose top-level value is an array, string, number or +// boolean fails this check; confirm that rejecting those is intended. +func IsParsableJSON(data string) bool { + var m map[string]any + err := json.Unmarshal([]byte(data), &m) + return err == nil +}