Detailed changes
@@ -2,45 +2,96 @@ package ai
import (
"context"
+ "errors"
+ "maps"
+ "slices"
+ "sync"
+
+ "github.com/charmbracelet/crush/internal/llm/tools"
)
-type StepResponse struct {
+type StepResult struct {
Response
// Messages generated during this step
Messages []Message
}
-type StepCondition = func(steps []StepResponse) bool
+type StopCondition = func(steps []StepResult) bool
type PrepareStepFunctionOptions struct {
- Steps []StepResponse
+ Steps []StepResult
StepNumber int
Model LanguageModel
Messages []Message
}
type PrepareStepResult struct {
- SystemPrompt string
- Model LanguageModel
- Messages []Message
+ Model LanguageModel
+ Messages []Message
}
-type PrepareStepFunction = func(options PrepareStepFunctionOptions) PrepareStepResult
-
-type OnStepFinishedFunction = func(step StepResponse)
+type (
+ PrepareStepFunction = func(options PrepareStepFunctionOptions) PrepareStepResult
+ OnStepFinishedFunction = func(step StepResult)
+ RepairToolCall = func(ToolCallContent) ToolCallContent
+)
type AgentSettings struct {
- Call
- Model LanguageModel
+ systemPrompt string
+ maxOutputTokens *int64
+ temperature *float64
+ topP *float64
+ topK *int64
+ presencePenalty *float64
+ frequencyPenalty *float64
+ headers map[string]string
+ providerOptions ProviderOptions
+
+ // TODO: add support for provider tools
+ tools []tools.BaseTool
+ maxRetries *int
+
+ model LanguageModel
+
+ stopWhen []StopCondition
+ prepareStep PrepareStepFunction
+ repairToolCall RepairToolCall
+ onStepFinished OnStepFinishedFunction
+ onRetry OnRetryCallback
+}
- StopWhen []StepCondition
+type AgentCall struct {
+ Prompt string `json:"prompt"`
+ Files []FilePart `json:"files"`
+ Messages []Message `json:"messages"`
+ MaxOutputTokens *int64
+ Temperature *float64 `json:"temperature"`
+ TopP *float64 `json:"top_p"`
+ TopK *int64 `json:"top_k"`
+ PresencePenalty *float64 `json:"presence_penalty"`
+ FrequencyPenalty *float64 `json:"frequency_penalty"`
+ ActiveTools []string `json:"active_tools"`
+ Headers map[string]string
+ ProviderOptions ProviderOptions
+ OnRetry OnRetryCallback
+ MaxRetries *int
+
+ StopWhen []StopCondition
PrepareStep PrepareStepFunction
+ RepairToolCall RepairToolCall
OnStepFinished OnStepFinishedFunction
}
+type AgentResult struct {
+ Steps []StepResult
+ // Final response
+ Response Response
+ TotalUsage Usage
+}
+
type Agent interface {
- Generate(context.Context, Call) (*Response, error)
- Stream(context.Context, Call) (StreamResponse, error)
+ Generate(context.Context, AgentCall) (*AgentResult, error)
+ Stream(context.Context, AgentCall) (StreamResponse, error)
}
type agentOption = func(*AgentSettings)
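
Note: a StopCondition only inspects the steps accumulated so far. A minimal sketch of a "stop after N steps" condition (the helper name is illustrative and not part of this change):

    // stepCountIs is a hypothetical helper showing the StopCondition shape.
    func stepCountIs(n int) StopCondition {
        return func(steps []StepResult) bool {
            return len(steps) >= n
        }
    }

    // usage: NewAgent(model, WithStopConditions(stepCountIs(3)))
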
@@ -51,7 +102,7 @@ type agent struct {
func NewAgent(model LanguageModel, opts ...agentOption) Agent {
settings := AgentSettings{
- Model: model,
+ model: model,
}
for _, o := range opts {
o(&settings)
@@ -61,48 +112,465 @@ func NewAgent(model LanguageModel, opts ...agentOption) Agent {
}
}
-func mergeCall(agentOpts, opts Call) Call {
- if len(opts.Prompt) > 0 {
- agentOpts.Prompt = opts.Prompt
+func (a *agent) prepareCall(call AgentCall) AgentCall {
+ if call.MaxOutputTokens == nil && a.settings.maxOutputTokens != nil {
+ call.MaxOutputTokens = a.settings.maxOutputTokens
+ }
+ if call.Temperature == nil && a.settings.temperature != nil {
+ call.Temperature = a.settings.temperature
}
- if opts.MaxOutputTokens != nil {
- agentOpts.MaxOutputTokens = opts.MaxOutputTokens
+ if call.TopP == nil && a.settings.topP != nil {
+ call.TopP = a.settings.topP
}
- if opts.Temperature != nil {
- agentOpts.Temperature = opts.Temperature
+ if call.TopK == nil && a.settings.topK != nil {
+ call.TopK = a.settings.topK
}
- if opts.TopP != nil {
- agentOpts.TopP = opts.TopP
+ if call.PresencePenalty == nil && a.settings.presencePenalty != nil {
+ call.PresencePenalty = a.settings.presencePenalty
}
- if opts.TopK != nil {
- agentOpts.TopK = opts.TopK
+ if call.FrequencyPenalty == nil && a.settings.frequencyPenalty != nil {
+ call.FrequencyPenalty = a.settings.frequencyPenalty
}
- if opts.PresencePenalty != nil {
- agentOpts.PresencePenalty = opts.PresencePenalty
+ if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
+ call.StopWhen = a.settings.stopWhen
}
- if opts.FrequencyPenalty != nil {
- agentOpts.FrequencyPenalty = opts.FrequencyPenalty
+ if call.PrepareStep == nil && a.settings.prepareStep != nil {
+ call.PrepareStep = a.settings.prepareStep
}
- if opts.Tools != nil {
- agentOpts.Tools = opts.Tools
+ if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
+ call.RepairToolCall = a.settings.repairToolCall
}
- if opts.Headers != nil {
- agentOpts.Headers = opts.Headers
+ if call.OnStepFinished == nil && a.settings.onStepFinished != nil {
+ call.OnStepFinished = a.settings.onStepFinished
}
- if opts.ProviderOptions != nil {
- agentOpts.ProviderOptions = opts.ProviderOptions
+ if call.OnRetry == nil && a.settings.onRetry != nil {
+ call.OnRetry = a.settings.onRetry
+ }
+ if call.MaxRetries == nil && a.settings.maxRetries != nil {
+ call.MaxRetries = a.settings.maxRetries
+ }
+
+ providerOptions := ProviderOptions{}
+ if a.settings.providerOptions != nil {
+ maps.Copy(providerOptions, a.settings.providerOptions)
}
- return agentOpts
+ if call.ProviderOptions != nil {
+ maps.Copy(providerOptions, call.ProviderOptions)
+ }
+ call.ProviderOptions = providerOptions
+
+ headers := map[string]string{}
+
+ if a.settings.headers != nil {
+ maps.Copy(headers, a.settings.headers)
+ }
+
+ if call.Headers != nil {
+ maps.Copy(headers, call.Headers)
+ }
+ call.Headers = headers
+ return call
}
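
For clarity on precedence: prepareCall only fills values the call left unset, so call-level values win over agent-level defaults, while Headers and ProviderOptions are merged with the call's keys taking precedence. A small sketch against the public API (model and ctx are assumed to exist):

    a := NewAgent(model,
        WithTemperature(0.2),                             // agent-level default
        WithHeaders(map[string]string{"x-env": "prod"}),
    )
    _, _ = a.Generate(ctx, AgentCall{
        Prompt:      "hi",
        Temperature: FloatOption(0.9),                    // overrides the 0.2 default
        Headers:     map[string]string{"x-trace": "abc"}, // merged alongside x-env
    })
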
// Generate implements Agent.
-func (a *agent) Generate(ctx context.Context, opts Call) (*Response, error) {
- // TODO: implement the agentic stuff
- return a.settings.Model.Generate(ctx, mergeCall(a.settings.Call, opts))
+func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
+ opts = a.prepareCall(opts)
+ initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
+ if err != nil {
+ return nil, err
+ }
+ var responseMessages []Message
+ var steps []StepResult
+
+ for {
+ stepInputMessages := append(initialPrompt, responseMessages...)
+ stepModel := a.settings.model
+ if opts.PrepareStep != nil {
+ prepared := opts.PrepareStep(PrepareStepFunctionOptions{
+ Model: stepModel,
+ Steps: steps,
+ StepNumber: len(steps),
+ Messages: stepInputMessages,
+ })
+ stepInputMessages = prepared.Messages
+ if prepared.Model != nil {
+ stepModel = prepared.Model
+ }
+ }
+
+ preparedTools := a.prepareTools(a.settings.tools, opts.ActiveTools)
+
+ toolChoice := ToolChoiceAuto
+ retryOptions := DefaultRetryOptions()
+ retryOptions.OnRetry = opts.OnRetry
+ retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
+
+ result, err := retry(ctx, func() (*Response, error) {
+ return stepModel.Generate(ctx, Call{
+ Prompt: stepInputMessages,
+ MaxOutputTokens: opts.MaxOutputTokens,
+ Temperature: opts.Temperature,
+ TopP: opts.TopP,
+ TopK: opts.TopK,
+ PresencePenalty: opts.PresencePenalty,
+ FrequencyPenalty: opts.FrequencyPenalty,
+ Tools: preparedTools,
+ ToolChoice: &toolChoice,
+ Headers: opts.Headers,
+ ProviderOptions: opts.ProviderOptions,
+ })
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ var stepToolCalls []ToolCallContent
+ for _, content := range result.Content {
+ if content.GetType() == ContentTypeToolCall {
+ toolCall, ok := AsContentType[ToolCallContent](content)
+ if !ok {
+ continue
+ }
+ stepToolCalls = append(stepToolCalls, toolCall)
+ }
+ }
+
+ toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls)
+
+ stepContent := result.Content
+ for _, result := range toolResults {
+ stepContent = append(stepContent, result)
+ }
+ currentStepMessages := toResponseMessages(stepContent)
+ responseMessages = append(responseMessages, currentStepMessages...)
+
+ stepResult := StepResult{
+ Response: *result,
+ Messages: currentStepMessages,
+ }
+ steps = append(steps, stepResult)
+ if opts.OnStepFinished != nil {
+ opts.OnStepFinished(stepResult)
+ }
+
+ shouldStop := isStopConditionMet(opts.StopWhen, steps)
+
+ if shouldStop || err != nil || len(stepToolCalls) == 0 {
+ break
+ }
+ }
+
+ totalUsage := Usage{}
+
+ for _, step := range steps {
+ usage := step.Usage
+ totalUsage.InputTokens += usage.InputTokens
+ totalUsage.OutputTokens += usage.OutputTokens
+ totalUsage.ReasoningTokens += usage.ReasoningTokens
+ totalUsage.CacheCreationTokens += usage.CacheCreationTokens
+ totalUsage.CacheReadTokens += usage.CacheReadTokens
+ totalUsage.TotalTokens += usage.TotalTokens
+ }
+
+ agentResult := &AgentResult{
+ Steps: steps,
+ Response: steps[len(steps)-1].Response,
+ TotalUsage: totalUsage,
+ }
+ return agentResult, nil
+}
+
+func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
+ if len(conditions) == 0 {
+ return false
+ }
+
+ for _, condition := range conditions {
+ if condition(steps) {
+ return true
+ }
+ }
+ return false
+}
+
+func toResponseMessages(content []Content) []Message {
+ var assistantParts []MessagePart
+ var toolParts []MessagePart
+
+ for _, c := range content {
+ switch c.GetType() {
+ case ContentTypeText:
+ text, ok := AsContentType[TextContent](c)
+ if !ok {
+ continue
+ }
+ assistantParts = append(assistantParts, TextPart{
+ Text: text.Text,
+ ProviderOptions: ProviderOptions(text.ProviderMetadata),
+ })
+ case ContentTypeReasoning:
+ reasoning, ok := AsContentType[ReasoningContent](c)
+ if !ok {
+ continue
+ }
+ assistantParts = append(assistantParts, ReasoningPart{
+ Text: reasoning.Text,
+ ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
+ })
+ case ContentTypeToolCall:
+ toolCall, ok := AsContentType[ToolCallContent](c)
+ if !ok {
+ continue
+ }
+ assistantParts = append(assistantParts, ToolCallPart{
+ ToolCallID: toolCall.ToolCallID,
+ ToolName: toolCall.ToolName,
+ Input: toolCall.Input,
+ ProviderExecuted: toolCall.ProviderExecuted,
+ ProviderOptions: ProviderOptions(toolCall.ProviderMetadata),
+ })
+ case ContentTypeFile:
+ file, ok := AsContentType[FileContent](c)
+ if !ok {
+ continue
+ }
+ assistantParts = append(assistantParts, FilePart{
+ Data: file.Data,
+ MediaType: file.MediaType,
+ ProviderOptions: ProviderOptions(file.ProviderMetadata),
+ })
+ case ContentTypeToolResult:
+ result, ok := AsContentType[ToolResultContent](c)
+ if !ok {
+ continue
+ }
+ toolParts = append(toolParts, ToolResultPart{
+ ToolCallID: result.ToolCallID,
+ Output: result.Result,
+ ProviderOptions: ProviderOptions(result.ProviderMetadata),
+ })
+ }
+ }
+
+ var messages []Message
+ if len(assistantParts) > 0 {
+ messages = append(messages, Message{
+ Role: MessageRoleAssistant,
+ Content: assistantParts,
+ })
+ }
+ if len(toolParts) > 0 {
+ messages = append(messages, Message{
+ Role: MessageRoleTool,
+ Content: toolParts,
+ })
+ }
+ return messages
+}
+
+func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, toolCalls []ToolCallContent) ([]ToolResultContent, error) {
+ if len(toolCalls) == 0 {
+ return nil, nil
+ }
+
+ // Create a map for quick tool lookup
+ toolMap := make(map[string]tools.BaseTool)
+ for _, tool := range allTools {
+ toolMap[tool.Info().Name] = tool
+ }
+
+ // Execute all tool calls in parallel
+ results := make([]ToolResultContent, len(toolCalls))
+ var toolExecutionError error
+ var errMu sync.Mutex // guards toolExecutionError, which the goroutines below may write concurrently
+ var wg sync.WaitGroup
+
+ for i, toolCall := range toolCalls {
+ wg.Add(1)
+ go func(index int, call ToolCallContent) {
+ defer wg.Done()
+
+ tool, exists := toolMap[call.ToolName]
+ if !exists {
+ results[index] = ToolResultContent{
+ ToolCallID: call.ToolCallID,
+ ToolName: call.ToolName,
+ Result: ToolResultOutputContentError{
+ Error: errors.New("Error: Tool not found: " + call.ToolName),
+ },
+ ProviderExecuted: false,
+ }
+ return
+ }
+
+ // Execute the tool
+ result, err := tool.Run(ctx, tools.ToolCall{
+ ID: call.ToolCallID,
+ Name: call.ToolName,
+ Input: call.Input,
+ })
+ if err != nil {
+ results[index] = ToolResultContent{
+ ToolCallID: call.ToolCallID,
+ ToolName: call.ToolName,
+ Result: ToolResultOutputContentError{
+ Error: err,
+ },
+ ProviderExecuted: false,
+ }
+ errMu.Lock()
+ toolExecutionError = err
+ errMu.Unlock()
+ return
+ }
+
+ if result.IsError {
+ results[index] = ToolResultContent{
+ ToolCallID: call.ToolCallID,
+ ToolName: call.ToolName,
+ Result: ToolResultOutputContentError{
+ Error: errors.New(result.Content),
+ },
+ ProviderExecuted: false,
+ }
+ } else {
+ results[index] = ToolResultContent{
+ ToolCallID: call.ToolCallID,
+ ToolName: call.ToolName,
+ Result: ToolResultOutputContentText{
+ Text: result.Content,
+ },
+ ProviderExecuted: false,
+ }
+ }
+ }(i, toolCall)
+ }
+
+ // Wait for all tool executions to complete
+ wg.Wait()
+
+ return results, toolExecutionError
}
// Stream implements Agent.
-func (a *agent) Stream(ctx context.Context, opts Call) (StreamResponse, error) {
+func (a *agent) Stream(ctx context.Context, opts AgentCall) (StreamResponse, error) {
// TODO: implement the agentic stuff
- return a.settings.Model.Stream(ctx, mergeCall(a.settings.Call, opts))
+ panic("not implemented")
+}
+
+func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string) []Tool {
+ var preparedTools []Tool
+ for _, tool := range tools {
+ if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
+ continue
+ }
+ info := tool.Info()
+ preparedTools = append(preparedTools, FunctionTool{
+ Name: info.Name,
+ Description: info.Description,
+ InputSchema: map[string]any{
+ "type": "object",
+ "properties": info.Parameters,
+ "required": info.Required,
+ },
+ })
+ }
+ return preparedTools
+}
+
+func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
+ if prompt == "" {
+ return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
+ }
+
+ var preparedPrompt Prompt
+
+ if system != "" {
+ preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
+ }
+
+ preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
+ preparedPrompt = append(preparedPrompt, messages...)
+ return preparedPrompt, nil
+}
+
+func WithSystemPrompt(prompt string) agentOption {
+ return func(s *AgentSettings) {
+ s.systemPrompt = prompt
+ }
+}
+
+func WithMaxOutputTokens(tokens int64) agentOption {
+ return func(s *AgentSettings) {
+ s.maxOutputTokens = &tokens
+ }
+}
+
+func WithTemperature(temp float64) agentOption {
+ return func(s *AgentSettings) {
+ s.temperature = &temp
+ }
+}
+
+func WithTopP(topP float64) agentOption {
+ return func(s *AgentSettings) {
+ s.topP = &topP
+ }
+}
+
+func WithTopK(topK int64) agentOption {
+ return func(s *AgentSettings) {
+ s.topK = &topK
+ }
+}
+
+func WithPresencePenalty(penalty float64) agentOption {
+ return func(s *AgentSettings) {
+ s.presencePenalty = &penalty
+ }
+}
+
+func WithFrequencyPenalty(penalty float64) agentOption {
+ return func(s *AgentSettings) {
+ s.frequencyPenalty = &penalty
+ }
+}
+
+func WithTools(tools ...tools.BaseTool) agentOption {
+ return func(s *AgentSettings) {
+ s.tools = append(s.tools, tools...)
+ }
+}
+
+func WithStopConditions(conditions ...StopCondition) agentOption {
+ return func(s *AgentSettings) {
+ s.stopWhen = append(s.stopWhen, conditions...)
+ }
+}
+
+func WithPrepareStep(fn PrepareStepFunction) agentOption {
+ return func(s *AgentSettings) {
+ s.prepareStep = fn
+ }
+}
+
+func WithRepairToolCall(fn RepairToolCall) agentOption {
+ return func(s *AgentSettings) {
+ s.repairToolCall = fn
+ }
+}
+
+func WithOnStepFinished(fn OnStepFinishedFunction) agentOption {
+ return func(s *AgentSettings) {
+ s.onStepFinished = fn
+ }
+}
+
+func WithHeaders(headers map[string]string) agentOption {
+ return func(s *AgentSettings) {
+ s.headers = headers
+ }
+}
+
+func WithProviderOptions(providerOptions ProviderOptions) agentOption {
+ return func(s *AgentSettings) {
+ s.providerOptions = providerOptions
+ }
}
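
As a usage sketch for the step hooks, WithPrepareStep can rewrite the per-step input or swap the model between steps; the truncation policy below is made up purely for illustration (model is assumed to be a LanguageModel):

    a := NewAgent(model,
        WithPrepareStep(func(o PrepareStepFunctionOptions) PrepareStepResult {
            msgs := o.Messages
            if o.StepNumber > 0 && len(msgs) > 20 {
                // keep the first message plus the 19 most recent ones
                msgs = append([]Message{msgs[0]}, msgs[len(msgs)-19:]...)
            }
            return PrepareStepResult{Model: o.Model, Messages: msgs}
        }),
    )
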
@@ -184,7 +184,7 @@ func (t ToolResultOutputContentText) GetType() ToolResultContentType {
}
type ToolResultOutputContentError struct {
- Error string `json:"error"`
+ Error error `json:"error"`
}
func (t ToolResultOutputContentError) GetType() ToolResultContentType {
@@ -268,11 +268,9 @@ type FileContent struct {
// The IANA media type of the file, e.g. `image/png` or `audio/mp3`.
// @see https://www.iana.org/assignments/media-types/media-types.xhtml
MediaType string `json:"media_type"`
- // Generated file data as base64 encoded strings or binary data.
- // If the API returns base64 encoded strings, the file data should be returned
- // as base64 encoded strings. If the API returns binary data, the file data should
- // be returned as binary data.
- Data any `json:"data"` // string (base64) or []byte
+ // Generated file data as binary data.
+ Data []byte `json:"data"`
+ ProviderMetadata ProviderMetadata `json:"provider_metadata"`
}
// GetType returns the type of the file content.
@@ -332,9 +330,7 @@ type ToolResultContent struct {
// Name of the tool that generated this result.
ToolName string `json:"tool_name"`
// Result of the tool call. This is a JSON-serializable object.
- Result any `json:"result"`
- // Optional flag if the result is an error or an error message.
- IsError bool `json:"is_error"`
+ Result ToolResultOutputContent `json:"result"`
// Whether the tool result was generated by the provider.
// If this flag is set to true, the tool result was generated by the provider.
// If this flag is not set or is false, the tool result was generated by the client.
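
With Result now typed, a tool result is built from one of the output variants instead of any plus IsError; for example (values are illustrative):

    ok := ToolResultContent{ToolCallID: "c1", ToolName: "weather",
        Result: ToolResultOutputContentText{Text: "40 C"}}
    failed := ToolResultContent{ToolCallID: "c2", ToolName: "weather",
        Result: ToolResultOutputContentError{Error: errors.New("city not found")}}
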
@@ -430,3 +426,15 @@ func NewUserMessage(prompt string, files ...FilePart) Message {
Content: content,
}
}
+
+func NewSystemMessage(prompt ...string) Message {
+ var content []MessagePart
+ for _, p := range prompt {
+ content = append(content, TextPart{Text: p})
+ }
+
+ return Message{
+ Role: MessageRoleSystem,
+ Content: content,
+ }
+}
@@ -46,30 +46,28 @@ func IsAIError(err error) bool {
// APICallError represents an error from an API call.
type APICallError struct {
*AIError
- URL string
- RequestBodyValues any
- StatusCode int
- ResponseHeaders map[string]string
- ResponseBody string
- IsRetryable bool
- Data any
+ URL string
+ RequestDump string
+ StatusCode int
+ ResponseHeaders map[string]string
+ ResponseDump string
+ IsRetryable bool
}
// NewAPICallError creates a new API call error.
-func NewAPICallError(message, url string, requestBodyValues any, statusCode int, responseHeaders map[string]string, responseBody string, cause error, isRetryable bool, data any) *APICallError {
+func NewAPICallError(message, url string, requestDump string, statusCode int, responseHeaders map[string]string, responseDump string, cause error, isRetryable bool) *APICallError {
if !isRetryable && statusCode != 0 {
isRetryable = statusCode == 408 || statusCode == 409 || statusCode == 429 || statusCode >= 500
}
return &APICallError{
- AIError: NewAIError("AI_APICallError", message, cause),
- URL: url,
- RequestBodyValues: requestBodyValues,
- StatusCode: statusCode,
- ResponseHeaders: responseHeaders,
- ResponseBody: responseBody,
- IsRetryable: isRetryable,
- Data: data,
+ AIError: NewAIError("AI_APICallError", message, cause),
+ URL: url,
+ RequestDump: requestDump,
+ StatusCode: statusCode,
+ ResponseHeaders: responseHeaders,
+ ResponseDump: responseDump,
+ IsRetryable: isRetryable,
}
}
@@ -1,26 +1,6 @@
package ai
-import (
- "encoding/json"
-
- "github.com/go-viper/mapstructure/v2"
-)
-
type Provider interface {
LanguageModel(modelID string) LanguageModel
// TODO: add other model types when needed
}
-
-func ParseOptions[T any](options map[string]any, m *T) error {
- return mapstructure.Decode(options, m)
-}
-
-func FloatOption(f float64) *float64 {
- return &f
-}
-
-func IsParsableJSON(data string) bool {
- var m map[string]any
- err := json.Unmarshal([]byte(data), &m)
- return err == nil
-}
@@ -0,0 +1,72 @@
+package main
+
+import (
+ "context"
+ "fmt"
+ "os"
+
+ "github.com/charmbracelet/crush/internal/ai"
+ "github.com/charmbracelet/crush/internal/ai/providers"
+ "github.com/charmbracelet/crush/internal/llm/tools"
+)
+
+type weatherTool struct{}
+
+// Info implements tools.BaseTool.
+func (w *weatherTool) Info() tools.ToolInfo {
+ return tools.ToolInfo{
+ Name: "weather",
+ Parameters: map[string]any{
+ "location": map[string]string{
+ "type": "string",
+ "description": "the city",
+ },
+ },
+ Required: []string{"location"},
+ }
+}
+
+// Name implements tools.BaseTool.
+func (w *weatherTool) Name() string {
+ return "weather"
+}
+
+// Run implements tools.BaseTool.
+func (w *weatherTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
+ return tools.NewTextResponse("40 C"), nil
+}
+
+func newWeatherTool() tools.BaseTool {
+ return &weatherTool{}
+}
+
+func main() {
+ provider := providers.NewOpenAIProvider(
+ providers.WithOpenAIApiKey(os.Getenv("OPENAI_API_KEY")),
+ )
+ model := provider.LanguageModel("gpt-4o")
+
+ agent := ai.NewAgent(
+ model,
+ ai.WithSystemPrompt("You are a helpful assistant"),
+ ai.WithTools(newWeatherTool()),
+ )
+
+ result, err := agent.Generate(context.Background(), ai.AgentCall{
+ Prompt: "What's the weather in pristina",
+ })
+ if err != nil {
+ fmt.Println("Error: ", err)
+ os.Exit(1)
+ }
+
+ fmt.Println("Steps: ", len(result.Steps))
+ for _, s := range result.Steps {
+ for _, c := range s.Content {
+ if c.GetType() == ai.ContentTypeToolCall {
+ tc, _ := ai.AsContentType[ai.ToolCallContent](c)
+ fmt.Println("ToolCall: ", tc.ToolName)
+
+ }
+ }
+ }
+
+ fmt.Println("Final Response: ", result.Response.Content.Text())
+ fmt.Println("Total Usage: ", result.TotalUsage)
+}
@@ -394,6 +394,30 @@ func (o openAILanguageModel) prepareParams(call ai.Call) (*openai.ChatCompletion
return params, warnings, nil
}
+func (o openAILanguageModel) handleError(err error) error {
+ var apiErr *openai.Error
+ if errors.As(err, &apiErr) {
+ requestDump := apiErr.DumpRequest(true)
+ responseDump := apiErr.DumpResponse(true)
+ headers := map[string]string{}
+ for k, h := range apiErr.Response.Header {
+ v := h[len(h)-1]
+ headers[strings.ToLower(k)] = v
+ }
+ return ai.NewAPICallError(
+ apiErr.Message,
+ apiErr.Request.URL.String(),
+ string(requestDump),
+ apiErr.StatusCode,
+ headers,
+ string(responseDump),
+ apiErr,
+ false,
+ )
+ }
+ return err
+}
+
// Generate implements ai.LanguageModel.
func (o openAILanguageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
params, warnings, err := o.prepareParams(call)
@@ -402,7 +426,7 @@ func (o openAILanguageModel) Generate(ctx context.Context, call ai.Call) (*ai.Re
}
response, err := o.client.Chat.Completions.New(ctx, *params)
if err != nil {
- return nil, err
+ return nil, o.handleError(err)
}
if len(response.Choices) == 0 {
@@ -626,7 +650,7 @@ func (o openAILanguageModel) Stream(ctx context.Context, call ai.Call) (ai.Strea
if err != nil {
yield(ai.StreamPart{
Type: ai.StreamPartTypeError,
- Error: stream.Err(),
+ Error: o.handleError(stream.Err()),
})
return
}
@@ -1097,7 +1121,7 @@ func toOpenAIPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion,
})
continue
}
- messages = append(messages, openai.ToolMessage(output.Error, toolResultPart.ToolCallID))
+ messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
}
}
}
@@ -4,6 +4,7 @@ import (
"context"
"encoding/base64"
"encoding/json"
+ "errors"
"net/http"
"net/http/httptest"
"strings"
@@ -496,7 +497,7 @@ func TestToOpenAIPrompt_ToolCalls(t *testing.T) {
ai.ToolResultPart{
ToolCallID: "error-tool",
Output: ai.ToolResultOutputContentError{
- Error: "Something went wrong",
+ Error: errors.New("Something went wrong"),
},
},
},
@@ -0,0 +1,170 @@
+package ai
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strconv"
+ "time"
+)
+
+// RetryFn is a function that returns a value and an error.
+type RetryFn[T any] func() (T, error)
+
+// RetryFunction is a function that retries another function.
+type RetryFunction[T any] func(ctx context.Context, fn RetryFn[T]) (T, error)
+
+// RetryReason represents the reason why a retry operation failed.
+type RetryReason string
+
+const (
+ RetryReasonMaxRetriesExceeded RetryReason = "maxRetriesExceeded"
+ RetryReasonErrorNotRetryable RetryReason = "errorNotRetryable"
+)
+
+// RetryError represents an error that occurred during retry operations.
+type RetryError struct {
+ *AIError
+ Reason RetryReason
+ Errors []error
+}
+
+// NewRetryError creates a new retry error.
+func NewRetryError(message string, reason RetryReason, errors []error) *RetryError {
+ return &RetryError{
+ AIError: NewAIError("AI_RetryError", message, nil),
+ Reason: reason,
+ Errors: errors,
+ }
+}
+
+// getRetryDelayInMs calculates the retry delay based on error headers and exponential backoff.
+func getRetryDelayInMs(err error, exponentialBackoffDelay time.Duration) time.Duration {
+ var apiErr *APICallError
+ if !errors.As(err, &apiErr) || apiErr.ResponseHeaders == nil {
+ return exponentialBackoffDelay
+ }
+
+ headers := apiErr.ResponseHeaders
+ var ms time.Duration
+
+ // retry-after-ms is more precise than retry-after and is used by e.g. OpenAI
+ if retryAfterMs, exists := headers["retry-after-ms"]; exists {
+ if timeoutMs, err := strconv.ParseFloat(retryAfterMs, 64); err == nil {
+ ms = time.Duration(timeoutMs) * time.Millisecond
+ }
+ }
+
+ // About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
+ if retryAfter, exists := headers["retry-after"]; exists && ms == 0 {
+ if timeoutSeconds, err := strconv.ParseFloat(retryAfter, 64); err == nil {
+ ms = time.Duration(timeoutSeconds) * time.Second
+ } else {
+ // Try parsing as HTTP date
+ if t, err := time.Parse(time.RFC1123, retryAfter); err == nil {
+ ms = time.Until(t)
+ }
+ }
+ }
+
+ // Check that the delay is reasonable:
+ // 0 <= ms < 60 seconds or ms < exponentialBackoffDelay
+ if ms > 0 && (ms < 60*time.Second || ms < exponentialBackoffDelay) {
+ return ms
+ }
+
+ return exponentialBackoffDelay
+}
+
+// isAbortError checks if the error is a context cancellation error.
+func isAbortError(err error) bool {
+ return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
+}
+
+// RetryWithExponentialBackoffRespectingRetryHeaders creates a retry function that retries
+// a failed operation with exponential backoff, while respecting rate limit headers
+// (retry-after-ms and retry-after) if they are provided and reasonable (0-60 seconds).
+func RetryWithExponentialBackoffRespectingRetryHeaders[T any](options RetryOptions) RetryFunction[T] {
+ return func(ctx context.Context, fn RetryFn[T]) (T, error) {
+ return retryWithExponentialBackoff(ctx, fn, options, nil)
+ }
+}
+
+// RetryOptions configures the retry behavior.
+type RetryOptions struct {
+ MaxRetries int
+ InitialDelayIn time.Duration
+ BackoffFactor float64
+ OnRetry OnRetryCallback
+}
+
+type OnRetryCallback = func(err *APICallError, delay time.Duration)
+
+// DefaultRetryOptions returns the default retry options.
+func DefaultRetryOptions() RetryOptions {
+ return RetryOptions{
+ MaxRetries: 2,
+ InitialDelayIn: 2000 * time.Millisecond,
+ BackoffFactor: 2.0,
+ }
+}
+
+// retryWithExponentialBackoff implements the retry logic with exponential backoff.
+func retryWithExponentialBackoff[T any](ctx context.Context, fn RetryFn[T], options RetryOptions, allErrors []error) (T, error) {
+ var zero T
+ result, err := fn()
+ if err == nil {
+ return result, nil
+ }
+
+ if isAbortError(err) {
+ return zero, err // don't retry when the request was aborted
+ }
+
+ if options.MaxRetries == 0 {
+ return zero, err // don't wrap the error when retries are disabled
+ }
+
+ errorMessage := GetErrorMessage(err)
+ newErrors := append(allErrors, err)
+ tryNumber := len(newErrors)
+
+ if tryNumber > options.MaxRetries {
+ return zero, NewRetryError(
+ fmt.Sprintf("Failed after %d attempts. Last error: %s", tryNumber, errorMessage),
+ RetryReasonMaxRetriesExceeded,
+ newErrors,
+ )
+ }
+
+ var apiErr *APICallError
+ if errors.As(err, &apiErr) && apiErr.IsRetryable && tryNumber <= options.MaxRetries {
+ delay := getRetryDelayInMs(err, options.InitialDelayIn)
+ if options.OnRetry != nil {
+ options.OnRetry(apiErr, delay)
+ }
+
+ select {
+ case <-time.After(delay):
+ // Continue with retry
+ case <-ctx.Done():
+ return zero, ctx.Err()
+ }
+
+ newOptions := options
+ newOptions.InitialDelayIn = time.Duration(float64(options.InitialDelayIn) * options.BackoffFactor)
+
+ return retryWithExponentialBackoff(ctx, fn, newOptions, newErrors)
+ }
+
+ if tryNumber == 1 {
+ return zero, err // don't wrap the error when a non-retryable error occurs on the first try
+ }
+
+ return zero, NewRetryError(
+ fmt.Sprintf("Failed after %d attempts with non-retryable error: '%s'", tryNumber, errorMessage),
+ RetryReasonErrorNotRetryable,
+ newErrors,
+ )
+}
+
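
A small usage sketch of the retry helper with an OnRetry hook (model, call, and ctx are assumed; the logging is illustrative):

    opts := DefaultRetryOptions()
    opts.OnRetry = func(apiErr *APICallError, delay time.Duration) {
        fmt.Printf("retrying in %s after status %d\n", delay, apiErr.StatusCode)
    }
    retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](opts)
    resp, err := retry(ctx, func() (*Response, error) {
        return model.Generate(ctx, call)
    })
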
@@ -0,0 +1,234 @@
+// WIP NEED TO REVISIT
+package ai
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+)
+
+// AgentTool represents a function that can be called by a language model.
+type AgentTool interface {
+ Name() string
+ Description() string
+ InputSchema() Schema
+ Execute(ctx context.Context, input json.RawMessage) (json.RawMessage, error)
+}
+
+// Schema represents a JSON schema for tool input validation.
+type Schema struct {
+ Type string `json:"type"`
+ Properties map[string]*Schema `json:"properties,omitempty"`
+ Required []string `json:"required,omitempty"`
+ Items *Schema `json:"items,omitempty"`
+ Description string `json:"description,omitempty"`
+ Enum []any `json:"enum,omitempty"`
+ Format string `json:"format,omitempty"`
+ Minimum *float64 `json:"minimum,omitempty"`
+ Maximum *float64 `json:"maximum,omitempty"`
+ MinLength *int `json:"minLength,omitempty"`
+ MaxLength *int `json:"maxLength,omitempty"`
+}
+
+// BasicTool provides a basic implementation of the Tool interface
+//
+// Example usage:
+//
+// calculator := &tools.BasicTool{
+// ToolName: "calculate",
+// ToolDescription: "Evaluates mathematical expressions",
+// ToolInputSchema: tools.Schema{
+// Type: "object",
+// Properties: map[string]*tools.Schema{
+// "expression": {
+// Type: "string",
+// Description: "Mathematical expression to evaluate",
+// },
+// },
+// Required: []string{"expression"},
+// },
+// ExecuteFunc: func(ctx context.Context, input json.RawMessage) (json.RawMessage, error) {
+// var req struct {
+// Expression string `json:"expression"`
+// }
+// if err := json.Unmarshal(input, &req); err != nil {
+// return nil, err
+// }
+// result := evaluateExpression(req.Expression)
+// return json.Marshal(map[string]any{"result": result})
+// },
+// }
+type BasicTool struct {
+ ToolName string
+ ToolDescription string
+ ToolInputSchema Schema
+ ExecuteFunc func(context.Context, json.RawMessage) (json.RawMessage, error)
+}
+
+// Name returns the tool name.
+func (t *BasicTool) Name() string {
+ return t.ToolName
+}
+
+// Description returns the tool description.
+func (t *BasicTool) Description() string {
+ return t.ToolDescription
+}
+
+// InputSchema returns the tool input schema.
+func (t *BasicTool) InputSchema() Schema {
+ return t.ToolInputSchema
+}
+
+// Execute executes the tool with the given input.
+func (t *BasicTool) Execute(ctx context.Context, input json.RawMessage) (json.RawMessage, error) {
+ if t.ExecuteFunc == nil {
+ return nil, fmt.Errorf("tool %s has no execute function", t.ToolName)
+ }
+ return t.ExecuteFunc(ctx, input)
+}
+
+// ToolBuilder provides a fluent interface for building tools.
+type ToolBuilder struct {
+ tool *BasicTool
+}
+
+// NewTool creates a new tool builder.
+func NewTool(name string) *ToolBuilder {
+ return &ToolBuilder{
+ tool: &BasicTool{
+ ToolName: name,
+ },
+ }
+}
+
+// Description sets the tool description.
+func (b *ToolBuilder) Description(desc string) *ToolBuilder {
+ b.tool.ToolDescription = desc
+ return b
+}
+
+// InputSchema sets the tool input schema.
+func (b *ToolBuilder) InputSchema(schema Schema) *ToolBuilder {
+ b.tool.ToolInputSchema = schema
+ return b
+}
+
+// Execute sets the tool execution function.
+func (b *ToolBuilder) Execute(fn func(context.Context, json.RawMessage) (json.RawMessage, error)) *ToolBuilder {
+ b.tool.ExecuteFunc = fn
+ return b
+}
+
+// Build creates the final tool.
+func (b *ToolBuilder) Build() AgentTool {
+ return b.tool
+}
+
+// SchemaBuilder provides a fluent interface for building JSON schemas.
+type SchemaBuilder struct {
+ schema Schema
+}
+
+// NewSchema creates a new schema builder.
+func NewSchema(schemaType string) *SchemaBuilder {
+ return &SchemaBuilder{
+ schema: Schema{
+ Type: schemaType,
+ },
+ }
+}
+
+// Object creates a schema builder for an object type.
+func Object() *SchemaBuilder {
+ return NewSchema("object")
+}
+
+// String creates a schema builder for a string type.
+func String() *SchemaBuilder {
+ return NewSchema("string")
+}
+
+// Number creates a schema builder for a number type.
+func Number() *SchemaBuilder {
+ return NewSchema("number")
+}
+
+// Array creates a schema builder for an array type.
+func Array() *SchemaBuilder {
+ return NewSchema("array")
+}
+
+// Description sets the schema description.
+func (b *SchemaBuilder) Description(desc string) *SchemaBuilder {
+ b.schema.Description = desc
+ return b
+}
+
+// Properties sets the schema properties.
+func (b *SchemaBuilder) Properties(props map[string]*Schema) *SchemaBuilder {
+ b.schema.Properties = props
+ return b
+}
+
+// Property adds a property to the schema.
+func (b *SchemaBuilder) Property(name string, schema *Schema) *SchemaBuilder {
+ if b.schema.Properties == nil {
+ b.schema.Properties = make(map[string]*Schema)
+ }
+ b.schema.Properties[name] = schema
+ return b
+}
+
+// Required marks fields as required.
+func (b *SchemaBuilder) Required(fields ...string) *SchemaBuilder {
+ b.schema.Required = append(b.schema.Required, fields...)
+ return b
+}
+
+// Items sets the schema for array items.
+func (b *SchemaBuilder) Items(schema *Schema) *SchemaBuilder {
+ b.schema.Items = schema
+ return b
+}
+
+// Enum sets allowed values for the schema.
+func (b *SchemaBuilder) Enum(values ...any) *SchemaBuilder {
+ b.schema.Enum = values
+ return b
+}
+
+// Format sets the string format.
+func (b *SchemaBuilder) Format(format string) *SchemaBuilder {
+ b.schema.Format = format
+ return b
+}
+
+// Min sets the minimum value.
+func (b *SchemaBuilder) Min(minimum float64) *SchemaBuilder {
+ b.schema.Minimum = &minimum
+ return b
+}
+
+// Max sets the maximum value.
+func (b *SchemaBuilder) Max(maximum float64) *SchemaBuilder {
+ b.schema.Maximum = &maximum
+ return b
+}
+
+// MinLength sets the minimum string length.
+func (b *SchemaBuilder) MinLength(minimum int) *SchemaBuilder {
+ b.schema.MinLength = &minimum
+ return b
+}
+
+// MaxLength sets the maximum string length.
+func (b *SchemaBuilder) MaxLength(maximum int) *SchemaBuilder {
+ b.schema.MaxLength = &maximum
+ return b
+}
+
+// Build creates the final schema.
+func (b *SchemaBuilder) Build() *Schema {
+ return &b.schema
+}
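
A quick sketch of the fluent builders above, mirroring the weather tool used elsewhere in this change (values are illustrative):

    weather := NewTool("weather").
        Description("Returns the current weather for a city").
        InputSchema(*Object().
            Property("location", String().Description("the city").Build()).
            Required("location").
            Build()).
        Execute(func(ctx context.Context, in json.RawMessage) (json.RawMessage, error) {
            return json.Marshal(map[string]any{"temperature": "40 C"})
        }).
        Build()
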
@@ -0,0 +1,21 @@
+package ai
+
+import (
+ "encoding/json"
+
+ "github.com/go-viper/mapstructure/v2"
+)
+
+func ParseOptions[T any](options map[string]any, m *T) error {
+ return mapstructure.Decode(options, m)
+}
+
+func FloatOption(f float64) *float64 {
+ return &f
+}
+
+func IsParsableJSON(data string) bool {
+ var m map[string]any
+ err := json.Unmarshal([]byte(data), &m)
+ return err == nil
+}
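
ParseOptions is a thin wrapper over mapstructure; a minimal sketch (the option struct here is hypothetical):

    type reasoningOptions struct {
        Effort string `mapstructure:"effort"`
    }
    var o reasoningOptions
    if err := ParseOptions(map[string]any{"effort": "high"}, &o); err == nil {
        // o.Effort == "high"
    }
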