feat: initial agent implementation

Kujtim Hoxha created

still need to setup streaming

Change summary

agent.go                         | 552 +++++++++++++++++++++++++++++++--
content.go                       |  26 +
errors.go                        |  30 
provider.go                      |  20 -
providers/examples/agent/main.go |  72 ++++
providers/openai.go              |  30 +
providers/openai_test.go         |   3 
retry.go                         | 170 ++++++++++
tool.go                          | 234 ++++++++++++++
util.go                          |  21 +
10 files changed, 1,067 insertions(+), 91 deletions(-)

Detailed changes

agent.go 🔗

@@ -2,45 +2,96 @@ package ai
 
 import (
 	"context"
+	"errors"
+	"maps"
+	"slices"
+	"sync"
+
+	"github.com/charmbracelet/crush/internal/llm/tools"
 )
 
-type StepResponse struct {
+type StepResult struct {
 	Response
 	// Messages generated during this step
 	Messages []Message
 }
 
-type StepCondition = func(steps []StepResponse) bool
+type StopCondition = func(steps []StepResult) bool
 
 type PrepareStepFunctionOptions struct {
-	Steps      []StepResponse
+	Steps      []StepResult
 	StepNumber int
 	Model      LanguageModel
 	Messages   []Message
 }
 
 type PrepareStepResult struct {
-	SystemPrompt string
-	Model        LanguageModel
-	Messages     []Message
+	Model    LanguageModel
+	Messages []Message
 }
 
-type PrepareStepFunction = func(options PrepareStepFunctionOptions) PrepareStepResult
-
-type OnStepFinishedFunction = func(step StepResponse)
+type (
+	PrepareStepFunction    = func(options PrepareStepFunctionOptions) PrepareStepResult
+	OnStepFinishedFunction = func(step StepResult)
+	RepairToolCall         = func(ToolCallContent) ToolCallContent
+)
 
 type AgentSettings struct {
-	Call
-	Model LanguageModel
+	systemPrompt     string
+	maxOutputTokens  *int64
+	temperature      *float64
+	topP             *float64
+	topK             *int64
+	presencePenalty  *float64
+	frequencyPenalty *float64
+	headers          map[string]string
+	providerOptions  ProviderOptions
+
+	// TODO: add support for provider tools
+	tools      []tools.BaseTool
+	maxRetries *int
+
+	model LanguageModel
+
+	stopWhen       []StopCondition
+	prepareStep    PrepareStepFunction
+	repairToolCall RepairToolCall
+	onStepFinished OnStepFinishedFunction
+	onRetry        OnRetryCallback
+}
 
-	StopWhen       []StepCondition
+type AgentCall struct {
+	Prompt           string     `json:"prompt"`
+	Files            []FilePart `json:"files"`
+	Messages         []Message  `json:"messages"`
+	MaxOutputTokens  *int64
+	Temperature      *float64 `json:"temperature"`
+	TopP             *float64 `json:"top_p"`
+	TopK             *int64   `json:"top_k"`
+	PresencePenalty  *float64 `json:"presence_penalty"`
+	FrequencyPenalty *float64 `json:"frequency_penalty"`
+	ActiveTools      []string `json:"active_tools"`
+	Headers          map[string]string
+	ProviderOptions  ProviderOptions
+	OnRetry          OnRetryCallback
+	MaxRetries       *int
+
+	StopWhen       []StopCondition
 	PrepareStep    PrepareStepFunction
+	RepairToolCall RepairToolCall
 	OnStepFinished OnStepFinishedFunction
 }
 
+type AgentResult struct {
+	Steps []StepResult
+	// Final response
+	Response   Response
+	TotalUsage Usage
+}
+
 type Agent interface {
-	Generate(context.Context, Call) (*Response, error)
-	Stream(context.Context, Call) (StreamResponse, error)
+	Generate(context.Context, AgentCall) (*AgentResult, error)
+	Stream(context.Context, AgentCall) (StreamResponse, error)
 }
 
 type agentOption = func(*AgentSettings)
@@ -51,7 +102,7 @@ type agent struct {
 
 func NewAgent(model LanguageModel, opts ...agentOption) Agent {
 	settings := AgentSettings{
-		Model: model,
+		model: model,
 	}
 	for _, o := range opts {
 		o(&settings)
@@ -61,48 +112,465 @@ func NewAgent(model LanguageModel, opts ...agentOption) Agent {
 	}
 }
 
-func mergeCall(agentOpts, opts Call) Call {
-	if len(opts.Prompt) > 0 {
-		agentOpts.Prompt = opts.Prompt
+func (a *agent) prepareCall(call AgentCall) AgentCall {
+	if call.MaxOutputTokens == nil && a.settings.maxOutputTokens != nil {
+		call.MaxOutputTokens = a.settings.maxOutputTokens
+	}
+	if call.Temperature == nil && a.settings.temperature != nil {
+		call.Temperature = a.settings.temperature
 	}
-	if opts.MaxOutputTokens != nil {
-		agentOpts.MaxOutputTokens = opts.MaxOutputTokens
+	if call.TopP == nil && a.settings.topP != nil {
+		call.TopP = a.settings.topP
 	}
-	if opts.Temperature != nil {
-		agentOpts.Temperature = opts.Temperature
+	if call.TopK == nil && a.settings.topK != nil {
+		call.TopK = a.settings.topK
 	}
-	if opts.TopP != nil {
-		agentOpts.TopP = opts.TopP
+	if call.PresencePenalty == nil && a.settings.presencePenalty != nil {
+		call.PresencePenalty = a.settings.presencePenalty
 	}
-	if opts.TopK != nil {
-		agentOpts.TopK = opts.TopK
+	if call.FrequencyPenalty == nil && a.settings.frequencyPenalty != nil {
+		call.FrequencyPenalty = a.settings.frequencyPenalty
 	}
-	if opts.PresencePenalty != nil {
-		agentOpts.PresencePenalty = opts.PresencePenalty
+	if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
+		call.StopWhen = a.settings.stopWhen
 	}
-	if opts.FrequencyPenalty != nil {
-		agentOpts.FrequencyPenalty = opts.FrequencyPenalty
+	if call.PrepareStep == nil && a.settings.prepareStep != nil {
+		call.PrepareStep = a.settings.prepareStep
 	}
-	if opts.Tools != nil {
-		agentOpts.Tools = opts.Tools
+	if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
+		call.RepairToolCall = a.settings.repairToolCall
 	}
-	if opts.Headers != nil {
-		agentOpts.Headers = opts.Headers
+	if call.OnStepFinished == nil && a.settings.onStepFinished != nil {
+		call.OnStepFinished = a.settings.onStepFinished
 	}
-	if opts.ProviderOptions != nil {
-		agentOpts.ProviderOptions = opts.ProviderOptions
+	if call.OnRetry == nil && a.settings.onRetry != nil {
+		call.OnRetry = a.settings.onRetry
+	}
+	if call.MaxRetries == nil && a.settings.maxRetries != nil {
+		call.MaxRetries = a.settings.maxRetries
+	}
+
+	providerOptions := ProviderOptions{}
+	if a.settings.providerOptions != nil {
+		maps.Copy(providerOptions, a.settings.providerOptions)
 	}
-	return agentOpts
+	if call.ProviderOptions != nil {
+		maps.Copy(providerOptions, call.ProviderOptions)
+	}
+	call.ProviderOptions = providerOptions
+
+	headers := map[string]string{}
+
+	if a.settings.headers != nil {
+		maps.Copy(headers, a.settings.headers)
+	}
+
+	if call.Headers != nil {
+		maps.Copy(headers, call.Headers)
+	}
+	call.Headers = headers
+	return call
 }
 
 // Generate implements Agent.
-func (a *agent) Generate(ctx context.Context, opts Call) (*Response, error) {
-	// TODO: implement the agentic stuff
-	return a.settings.Model.Generate(ctx, mergeCall(a.settings.Call, opts))
+func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
+	opts = a.prepareCall(opts)
+	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
+	if err != nil {
+		return nil, err
+	}
+	var responseMessages []Message
+	var steps []StepResult
+
+	for {
+		stepInputMessages := append(initialPrompt, responseMessages...)
+		stepModel := a.settings.model
+		if opts.PrepareStep != nil {
+			prepared := opts.PrepareStep(PrepareStepFunctionOptions{
+				Model:      stepModel,
+				Steps:      steps,
+				StepNumber: len(steps),
+				Messages:   stepInputMessages,
+			})
+			stepInputMessages = prepared.Messages
+			if prepared.Model != nil {
+				stepModel = prepared.Model
+			}
+		}
+
+		preparedTools := a.prepareTools(a.settings.tools, opts.ActiveTools)
+
+		toolChoice := ToolChoiceAuto
+		retryOptions := DefaultRetryOptions()
+		retryOptions.OnRetry = opts.OnRetry
+		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
+
+		result, err := retry(ctx, func() (*Response, error) {
+			return stepModel.Generate(ctx, Call{
+				Prompt:           stepInputMessages,
+				MaxOutputTokens:  opts.MaxOutputTokens,
+				Temperature:      opts.Temperature,
+				TopP:             opts.TopP,
+				TopK:             opts.TopK,
+				PresencePenalty:  opts.PresencePenalty,
+				FrequencyPenalty: opts.FrequencyPenalty,
+				Tools:            preparedTools,
+				ToolChoice:       &toolChoice,
+				Headers:          opts.Headers,
+				ProviderOptions:  opts.ProviderOptions,
+			})
+		})
+		if err != nil {
+			return nil, err
+		}
+
+		var stepToolCalls []ToolCallContent
+		for _, content := range result.Content {
+			if content.GetType() == ContentTypeToolCall {
+				toolCall, ok := AsContentType[ToolCallContent](content)
+				if !ok {
+					continue
+				}
+				stepToolCalls = append(stepToolCalls, toolCall)
+			}
+		}
+
+		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls)
+
+		stepContent := result.Content
+		for _, result := range toolResults {
+			stepContent = append(stepContent, result)
+		}
+		currentStepMessages := toResponseMessages(stepContent)
+		responseMessages = append(responseMessages, currentStepMessages...)
+
+		stepResult := StepResult{
+			Response: *result,
+			Messages: currentStepMessages,
+		}
+		steps = append(steps, stepResult)
+		if opts.OnStepFinished != nil {
+			opts.OnStepFinished(stepResult)
+		}
+
+		shouldStop := isStopConditionMet(opts.StopWhen, steps)
+
+		if shouldStop || err != nil || len(stepToolCalls) == 0 {
+			break
+		}
+	}
+
+	totalUsage := Usage{}
+
+	for _, step := range steps {
+		usage := step.Usage
+		totalUsage.InputTokens += usage.InputTokens
+		totalUsage.OutputTokens += usage.OutputTokens
+		totalUsage.ReasoningTokens += usage.ReasoningTokens
+		totalUsage.CacheCreationTokens += usage.CacheCreationTokens
+		totalUsage.CacheReadTokens += usage.CacheReadTokens
+		totalUsage.TotalTokens += totalUsage.TotalTokens
+	}
+
+	agentResult := &AgentResult{
+		Steps:      steps,
+		Response:   steps[len(steps)-1].Response,
+		TotalUsage: totalUsage,
+	}
+	return agentResult, nil
+}
+
+func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {
+	if len(conditions) == 0 {
+		return false
+	}
+
+	for _, condition := range conditions {
+		if condition(steps) {
+			return true
+		}
+	}
+	return false
+}
+
+func toResponseMessages(content []Content) []Message {
+	var assistantParts []MessagePart
+	var toolParts []MessagePart
+
+	for _, c := range content {
+		switch c.GetType() {
+		case ContentTypeText:
+			text, ok := AsContentType[TextContent](c)
+			if !ok {
+				continue
+			}
+			assistantParts = append(assistantParts, TextPart{
+				Text:            text.Text,
+				ProviderOptions: ProviderOptions(text.ProviderMetadata),
+			})
+		case ContentTypeReasoning:
+			reasoning, ok := AsContentType[ReasoningContent](c)
+			if !ok {
+				continue
+			}
+			assistantParts = append(assistantParts, ReasoningPart{
+				Text:            reasoning.Text,
+				ProviderOptions: ProviderOptions(reasoning.ProviderMetadata),
+			})
+		case ContentTypeToolCall:
+			toolCall, ok := AsContentType[ToolCallContent](c)
+			if !ok {
+				continue
+			}
+			assistantParts = append(assistantParts, ToolCallPart{
+				ToolCallID:       toolCall.ToolCallID,
+				ToolName:         toolCall.ToolName,
+				Input:            toolCall.Input,
+				ProviderExecuted: toolCall.ProviderExecuted,
+				ProviderOptions:  ProviderOptions(toolCall.ProviderMetadata),
+			})
+		case ContentTypeFile:
+			file, ok := AsContentType[FileContent](c)
+			if !ok {
+				continue
+			}
+			assistantParts = append(assistantParts, FilePart{
+				Data:            file.Data,
+				MediaType:       file.MediaType,
+				ProviderOptions: ProviderOptions(file.ProviderMetadata),
+			})
+		case ContentTypeToolResult:
+			result, ok := AsContentType[ToolResultContent](c)
+			if !ok {
+				continue
+			}
+			toolParts = append(toolParts, ToolResultPart{
+				ToolCallID:      result.ToolCallID,
+				Output:          result.Result,
+				ProviderOptions: ProviderOptions(result.ProviderMetadata),
+			})
+		}
+	}
+
+	var messages []Message
+	if len(assistantParts) > 0 {
+		messages = append(messages, Message{
+			Role:    MessageRoleAssistant,
+			Content: assistantParts,
+		})
+	}
+	if len(toolParts) > 0 {
+		messages = append(messages, Message{
+			Role:    MessageRoleTool,
+			Content: toolParts,
+		})
+	}
+	return messages
+}
+
+func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, toolCalls []ToolCallContent) ([]ToolResultContent, error) {
+	if len(toolCalls) == 0 {
+		return nil, nil
+	}
+
+	// Create a map for quick tool lookup
+	toolMap := make(map[string]tools.BaseTool)
+	for _, tool := range allTools {
+		toolMap[tool.Info().Name] = tool
+	}
+
+	// Execute all tool calls in parallel
+	results := make([]ToolResultContent, len(toolCalls))
+	var toolExecutionError error
+	var wg sync.WaitGroup
+
+	for i, toolCall := range toolCalls {
+		wg.Add(1)
+		go func(index int, call ToolCallContent) {
+			defer wg.Done()
+
+			tool, exists := toolMap[call.ToolName]
+			if !exists {
+				results[index] = ToolResultContent{
+					ToolCallID: call.ToolCallID,
+					ToolName:   call.ToolName,
+					Result: ToolResultOutputContentError{
+						Error: errors.New("Error: Tool not found: " + call.ToolName),
+					},
+					ProviderExecuted: false,
+				}
+				return
+			}
+
+			// Execute the tool
+			result, err := tool.Run(ctx, tools.ToolCall{
+				ID:    call.ToolCallID,
+				Name:  call.ToolName,
+				Input: call.Input,
+			})
+			if err != nil {
+				results[index] = ToolResultContent{
+					ToolCallID: call.ToolCallID,
+					ToolName:   call.ToolName,
+					Result: ToolResultOutputContentError{
+						Error: err,
+					},
+					ProviderExecuted: false,
+				}
+				toolExecutionError = err
+				return
+			}
+
+			if result.IsError {
+				results[index] = ToolResultContent{
+					ToolCallID: call.ToolCallID,
+					ToolName:   call.ToolName,
+					Result: ToolResultOutputContentError{
+						Error: errors.New(result.Content),
+					},
+					ProviderExecuted: false,
+				}
+			} else {
+				results[index] = ToolResultContent{
+					ToolCallID: call.ToolCallID,
+					ToolName:   toolCall.ToolName,
+					Result: ToolResultOutputContentText{
+						Text: result.Content,
+					},
+					ProviderExecuted: false,
+				}
+			}
+		}(i, toolCall)
+	}
+
+	// Wait for all tool executions to complete
+	wg.Wait()
+
+	return results, toolExecutionError
 }
 
 // Stream implements Agent.
-func (a *agent) Stream(ctx context.Context, opts Call) (StreamResponse, error) {
+func (a *agent) Stream(ctx context.Context, opts AgentCall) (StreamResponse, error) {
 	// TODO: implement the agentic stuff
-	return a.settings.Model.Stream(ctx, mergeCall(a.settings.Call, opts))
+	panic("not implemented")
+}
+
+func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string) []Tool {
+	var preparedTools []Tool
+	for _, tool := range tools {
+		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
+			continue
+		}
+		info := tool.Info()
+		preparedTools = append(preparedTools, FunctionTool{
+			Name:        info.Name,
+			Description: info.Description,
+			InputSchema: map[string]any{
+				"type":       "object",
+				"properties": info.Parameters,
+				"required":   info.Required,
+			},
+		})
+	}
+	return preparedTools
+}
+
+func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
+	if prompt == "" {
+		return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
+	}
+
+	var preparedPrompt Prompt
+
+	if system != "" {
+		preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
+	}
+
+	preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
+	preparedPrompt = append(preparedPrompt, messages...)
+	return preparedPrompt, nil
+}
+
+func WithSystemPrompt(prompt string) agentOption {
+	return func(s *AgentSettings) {
+		s.systemPrompt = prompt
+	}
+}
+
+func WithMaxOutputTokens(tokens int64) agentOption {
+	return func(s *AgentSettings) {
+		s.maxOutputTokens = &tokens
+	}
+}
+
+func WithTemperature(temp float64) agentOption {
+	return func(s *AgentSettings) {
+		s.temperature = &temp
+	}
+}
+
+func WithTopP(topP float64) agentOption {
+	return func(s *AgentSettings) {
+		s.topP = &topP
+	}
+}
+
+func WithTopK(topK int64) agentOption {
+	return func(s *AgentSettings) {
+		s.topK = &topK
+	}
+}
+
+func WithPresencePenalty(penalty float64) agentOption {
+	return func(s *AgentSettings) {
+		s.presencePenalty = &penalty
+	}
+}
+
+func WithFrequencyPenalty(penalty float64) agentOption {
+	return func(s *AgentSettings) {
+		s.frequencyPenalty = &penalty
+	}
+}
+
+func WithTools(tools ...tools.BaseTool) agentOption {
+	return func(s *AgentSettings) {
+		s.tools = append(s.tools, tools...)
+	}
+}
+
+func WithStopConditions(conditions ...StopCondition) agentOption {
+	return func(s *AgentSettings) {
+		s.stopWhen = append(s.stopWhen, conditions...)
+	}
+}
+
+func WithPrepareStep(fn PrepareStepFunction) agentOption {
+	return func(s *AgentSettings) {
+		s.prepareStep = fn
+	}
+}
+
+func WithRepairToolCall(fn RepairToolCall) agentOption {
+	return func(s *AgentSettings) {
+		s.repairToolCall = fn
+	}
+}
+
+func WithOnStepFinished(fn OnStepFinishedFunction) agentOption {
+	return func(s *AgentSettings) {
+		s.onStepFinished = fn
+	}
+}
+
+func WithHeaders(headers map[string]string) agentOption {
+	return func(s *AgentSettings) {
+		s.headers = headers
+	}
+}
+
+func WithProviderOptions(providerOptions ProviderOptions) agentOption {
+	return func(s *AgentSettings) {
+		s.providerOptions = providerOptions
+	}
 }

content.go 🔗

@@ -184,7 +184,7 @@ func (t ToolResultOutputContentText) GetType() ToolResultContentType {
 }
 
 type ToolResultOutputContentError struct {
-	Error string `json:"error"`
+	Error error `json:"error"`
 }
 
 func (t ToolResultOutputContentError) GetType() ToolResultContentType {
@@ -268,11 +268,9 @@ type FileContent struct {
 	// The IANA media type of the file, e.g. `image/png` or `audio/mp3`.
 	// @see https://www.iana.org/assignments/media-types/media-types.xhtml
 	MediaType string `json:"media_type"`
-	// Generated file data as base64 encoded strings or binary data.
-	// If the API returns base64 encoded strings, the file data should be returned
-	// as base64 encoded strings. If the API returns binary data, the file data should
-	// be returned as binary data.
-	Data any `json:"data"` // string (base64) or []byte
+	// Generated file data as binary data.
+	Data             []byte           `json:"data"`
+	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
 }
 
 // GetType returns the type of the file content.
@@ -332,9 +330,7 @@ type ToolResultContent struct {
 	// Name of the tool that generated this result.
 	ToolName string `json:"tool_name"`
 	// Result of the tool call. This is a JSON-serializable object.
-	Result any `json:"result"`
-	// Optional flag if the result is an error or an error message.
-	IsError bool `json:"is_error"`
+	Result ToolResultOutputContent `json:"result"`
 	// Whether the tool result was generated by the provider.
 	// If this flag is set to true, the tool result was generated by the provider.
 	// If this flag is not set or is false, the tool result was generated by the client.
@@ -430,3 +426,15 @@ func NewUserMessage(prompt string, files ...FilePart) Message {
 		Content: content,
 	}
 }
+
+func NewSystemMessage(prompt ...string) Message {
+	var content []MessagePart
+	for _, p := range prompt {
+		content = append(content, TextPart{Text: p})
+	}
+
+	return Message{
+		Role:    MessageRoleSystem,
+		Content: content,
+	}
+}

errors.go 🔗

@@ -46,30 +46,28 @@ func IsAIError(err error) bool {
 // APICallError represents an error from an API call.
 type APICallError struct {
 	*AIError
-	URL               string
-	RequestBodyValues any
-	StatusCode        int
-	ResponseHeaders   map[string]string
-	ResponseBody      string
-	IsRetryable       bool
-	Data              any
+	URL             string
+	RequestDump     string
+	StatusCode      int
+	ResponseHeaders map[string]string
+	ResponseDump    string
+	IsRetryable     bool
 }
 
 // NewAPICallError creates a new API call error.
-func NewAPICallError(message, url string, requestBodyValues any, statusCode int, responseHeaders map[string]string, responseBody string, cause error, isRetryable bool, data any) *APICallError {
+func NewAPICallError(message, url string, requestDump string, statusCode int, responseHeaders map[string]string, responseDump string, cause error, isRetryable bool) *APICallError {
 	if !isRetryable && statusCode != 0 {
 		isRetryable = statusCode == 408 || statusCode == 409 || statusCode == 429 || statusCode >= 500
 	}
 
 	return &APICallError{
-		AIError:           NewAIError("AI_APICallError", message, cause),
-		URL:               url,
-		RequestBodyValues: requestBodyValues,
-		StatusCode:        statusCode,
-		ResponseHeaders:   responseHeaders,
-		ResponseBody:      responseBody,
-		IsRetryable:       isRetryable,
-		Data:              data,
+		AIError:         NewAIError("AI_APICallError", message, cause),
+		URL:             url,
+		RequestDump:     requestDump,
+		StatusCode:      statusCode,
+		ResponseHeaders: responseHeaders,
+		ResponseDump:    responseDump,
+		IsRetryable:     isRetryable,
 	}
 }
 

provider.go 🔗

@@ -1,26 +1,6 @@
 package ai
 
-import (
-	"encoding/json"
-
-	"github.com/go-viper/mapstructure/v2"
-)
-
 type Provider interface {
 	LanguageModel(modelID string) LanguageModel
 	// TODO: add other model types when needed
 }
-
-func ParseOptions[T any](options map[string]any, m *T) error {
-	return mapstructure.Decode(options, m)
-}
-
-func FloatOption(f float64) *float64 {
-	return &f
-}
-
-func IsParsableJSON(data string) bool {
-	var m map[string]any
-	err := json.Unmarshal([]byte(data), &m)
-	return err == nil
-}

providers/examples/agent/main.go 🔗

@@ -0,0 +1,72 @@
+package main
+
+import (
+	"context"
+	"fmt"
+	"os"
+
+	"github.com/charmbracelet/crush/internal/ai"
+	"github.com/charmbracelet/crush/internal/ai/providers"
+	"github.com/charmbracelet/crush/internal/llm/tools"
+)
+
+type weatherTool struct{}
+
+// Info implements tools.BaseTool.
+func (w *weatherTool) Info() tools.ToolInfo {
+	return tools.ToolInfo{
+		Name: "weather",
+		Parameters: map[string]any{
+			"location": map[string]string{
+				"type":        "string",
+				"description": "the city",
+			},
+		},
+		Required: []string{"location"},
+	}
+}
+
+// Name implements tools.BaseTool.
+func (w *weatherTool) Name() string {
+	return "weather"
+}
+
+// Run implements tools.BaseTool.
+func (w *weatherTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
+	return tools.NewTextResponse("40 C"), nil
+}
+
+func newWeatherTool() tools.BaseTool {
+	return &weatherTool{}
+}
+
+func main() {
+	provider := providers.NewOpenAIProvider(
+		providers.WithOpenAIApiKey(os.Getenv("OPENAI_API_KEY")),
+	)
+	model := provider.LanguageModel("gpt-4o")
+
+	agent := ai.NewAgent(
+		model,
+		ai.WithSystemPrompt("You are a helpful assistant"),
+		ai.WithTools(newWeatherTool()),
+	)
+
+	result, _ := agent.Generate(context.Background(), ai.AgentCall{
+		Prompt: "What's the weather in pristina",
+	})
+
+	fmt.Println("Steps: ", len(result.Steps))
+	for _, s := range result.Steps {
+		for _, c := range s.Content {
+			if c.GetType() == ai.ContentTypeToolCall {
+				tc, _ := ai.AsContentType[ai.ToolCallContent](c)
+				fmt.Println("ToolCall: ", tc.ToolName)
+
+			}
+		}
+	}
+
+	fmt.Println("Final Response: ", result.Response.Content.Text())
+	fmt.Println("Total Usage: ", result.TotalUsage)
+}

providers/openai.go 🔗

@@ -394,6 +394,30 @@ func (o openAILanguageModel) prepareParams(call ai.Call) (*openai.ChatCompletion
 	return params, warnings, nil
 }
 
+func (o openAILanguageModel) handleError(err error) error {
+	var apiErr *openai.Error
+	if errors.As(err, &apiErr) {
+		requestDump := apiErr.DumpRequest(true)
+		responseDump := apiErr.DumpResponse(true)
+		headers := map[string]string{}
+		for k, h := range apiErr.Response.Header {
+			v := h[len(h)-1]
+			headers[strings.ToLower(k)] = v
+		}
+		return ai.NewAPICallError(
+			apiErr.Message,
+			apiErr.Request.URL.String(),
+			string(requestDump),
+			apiErr.StatusCode,
+			headers,
+			string(responseDump),
+			apiErr,
+			false,
+		)
+	}
+	return err
+}
+
 // Generate implements ai.LanguageModel.
 func (o openAILanguageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
 	params, warnings, err := o.prepareParams(call)
@@ -402,7 +426,7 @@ func (o openAILanguageModel) Generate(ctx context.Context, call ai.Call) (*ai.Re
 	}
 	response, err := o.client.Chat.Completions.New(ctx, *params)
 	if err != nil {
-		return nil, err
+		return nil, o.handleError(err)
 	}
 
 	if len(response.Choices) == 0 {
@@ -626,7 +650,7 @@ func (o openAILanguageModel) Stream(ctx context.Context, call ai.Call) (ai.Strea
 							if err != nil {
 								yield(ai.StreamPart{
 									Type:  ai.StreamPartTypeError,
-									Error: stream.Err(),
+									Error: o.handleError(stream.Err()),
 								})
 								return
 							}
@@ -1097,7 +1121,7 @@ func toOpenAIPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion,
 						})
 						continue
 					}
-					messages = append(messages, openai.ToolMessage(output.Error, toolResultPart.ToolCallID))
+					messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
 				}
 			}
 		}

providers/openai_test.go 🔗

@@ -4,6 +4,7 @@ import (
 	"context"
 	"encoding/base64"
 	"encoding/json"
+	"errors"
 	"net/http"
 	"net/http/httptest"
 	"strings"
@@ -496,7 +497,7 @@ func TestToOpenAIPrompt_ToolCalls(t *testing.T) {
 					ai.ToolResultPart{
 						ToolCallID: "error-tool",
 						Output: ai.ToolResultOutputContentError{
-							Error: "Something went wrong",
+							Error: errors.New("Something went wrong"),
 						},
 					},
 				},

retry.go 🔗

@@ -0,0 +1,170 @@
+package ai
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"strconv"
+	"time"
+)
+
+// RetryFn is a function that returns a value and an error.
+type RetryFn[T any] func() (T, error)
+
+// RetryFunction is a function that retries another function.
+type RetryFunction[T any] func(ctx context.Context, fn RetryFn[T]) (T, error)
+
+// RetryReason represents the reason why a retry operation failed.
+type RetryReason string
+
+const (
+	RetryReasonMaxRetriesExceeded RetryReason = "maxRetriesExceeded"
+	RetryReasonErrorNotRetryable  RetryReason = "errorNotRetryable"
+)
+
+// RetryError represents an error that occurred during retry operations.
+type RetryError struct {
+	*AIError
+	Reason RetryReason
+	Errors []error
+}
+
+// NewRetryError creates a new retry error.
+func NewRetryError(message string, reason RetryReason, errors []error) *RetryError {
+	return &RetryError{
+		AIError: NewAIError("AI_RetryError", message, nil),
+		Reason:  reason,
+		Errors:  errors,
+	}
+}
+
+// getRetryDelayInMs calculates the retry delay based on error headers and exponential backoff.
+func getRetryDelayInMs(err error, exponentialBackoffDelay time.Duration) time.Duration {
+	var apiErr *APICallError
+	if !errors.As(err, &apiErr) || apiErr.ResponseHeaders == nil {
+		return exponentialBackoffDelay
+	}
+
+	headers := apiErr.ResponseHeaders
+	var ms time.Duration
+
+	// retry-ms is more precise than retry-after and used by e.g. OpenAI
+	if retryAfterMs, exists := headers["retry-after-ms"]; exists {
+		if timeoutMs, err := strconv.ParseFloat(retryAfterMs, 64); err == nil {
+			ms = time.Duration(timeoutMs) * time.Millisecond
+		}
+	}
+
+	// About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
+	if retryAfter, exists := headers["retry-after"]; exists && ms == 0 {
+		if timeoutSeconds, err := strconv.ParseFloat(retryAfter, 64); err == nil {
+			ms = time.Duration(timeoutSeconds) * time.Second
+		} else {
+			// Try parsing as HTTP date
+			if t, err := time.Parse(time.RFC1123, retryAfter); err == nil {
+				ms = time.Until(t)
+			}
+		}
+	}
+
+	// Check that the delay is reasonable:
+	// 0 <= ms < 60 seconds or ms < exponentialBackoffDelay
+	if ms > 0 && (ms < 60*time.Second || ms < exponentialBackoffDelay) {
+		return ms
+	}
+
+	return exponentialBackoffDelay
+}
+
+// isAbortError checks if the error is a context cancellation error.
+func isAbortError(err error) bool {
+	return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
+}
+
+// RetryWithExponentialBackoffRespectingRetryHeaders creates a retry function that retries
+// a failed operation with exponential backoff, while respecting rate limit headers
+// (retry-after-ms and retry-after) if they are provided and reasonable (0-60 seconds).
+func RetryWithExponentialBackoffRespectingRetryHeaders[T any](options RetryOptions) RetryFunction[T] {
+	return func(ctx context.Context, fn RetryFn[T]) (T, error) {
+		return retryWithExponentialBackoff(ctx, fn, options, nil)
+	}
+}
+
+// RetryOptions configures the retry behavior.
+type RetryOptions struct {
+	MaxRetries     int
+	InitialDelayIn time.Duration
+	BackoffFactor  float64
+	OnRetry        OnRetryCallback
+}
+
+type OnRetryCallback = func(err *APICallError, delay time.Duration)
+
+// DefaultRetryOptions returns the default retry options.
+func DefaultRetryOptions() RetryOptions {
+	return RetryOptions{
+		MaxRetries:     2,
+		InitialDelayIn: 2000 * time.Millisecond,
+		BackoffFactor:  2.0,
+	}
+}
+
+// retryWithExponentialBackoff implements the retry logic with exponential backoff.
+func retryWithExponentialBackoff[T any](ctx context.Context, fn RetryFn[T], options RetryOptions, allErrors []error) (T, error) {
+	var zero T
+	result, err := fn()
+	if err == nil {
+		return result, nil
+	}
+
+	if isAbortError(err) {
+		return zero, err // don't retry when the request was aborted
+	}
+
+	if options.MaxRetries == 0 {
+		return zero, err // don't wrap the error when retries are disabled
+	}
+
+	errorMessage := GetErrorMessage(err)
+	newErrors := append(allErrors, err)
+	tryNumber := len(newErrors)
+
+	if tryNumber > options.MaxRetries {
+		return zero, NewRetryError(
+			fmt.Sprintf("Failed after %d attempts. Last error: %s", tryNumber, errorMessage),
+			RetryReasonMaxRetriesExceeded,
+			newErrors,
+		)
+	}
+
+	var apiErr *APICallError
+	if errors.As(err, &apiErr) && apiErr.IsRetryable && tryNumber <= options.MaxRetries {
+		delay := getRetryDelayInMs(err, options.InitialDelayIn)
+		if options.OnRetry != nil {
+			options.OnRetry(apiErr, delay)
+		}
+
+		select {
+		case <-time.After(delay):
+			// Continue with retry
+		case <-ctx.Done():
+			return zero, ctx.Err()
+		}
+
+		newOptions := options
+		newOptions.InitialDelayIn = time.Duration(float64(options.InitialDelayIn) * options.BackoffFactor)
+
+		return retryWithExponentialBackoff(ctx, fn, newOptions, newErrors)
+	}
+
+	if tryNumber == 1 {
+		return zero, err // don't wrap the error when a non-retryable error occurs on the first try
+	}
+
+	return zero, NewRetryError(
+		fmt.Sprintf("Failed after %d attempts with non-retryable error: '%s'", tryNumber, errorMessage),
+		RetryReasonErrorNotRetryable,
+		newErrors,
+	)
+}
+

tool.go 🔗

@@ -0,0 +1,234 @@
+// WIP NEED TO REVISIT
+package ai
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+)
+
+// AgentTool represents a function that can be called by a language model.
+type AgentTool interface {
+	Name() string
+	Description() string
+	InputSchema() Schema
+	Execute(ctx context.Context, input json.RawMessage) (json.RawMessage, error)
+}
+
+// Schema represents a JSON schema for tool input validation.
+type Schema struct {
+	Type        string             `json:"type"`
+	Properties  map[string]*Schema `json:"properties,omitempty"`
+	Required    []string           `json:"required,omitempty"`
+	Items       *Schema            `json:"items,omitempty"`
+	Description string             `json:"description,omitempty"`
+	Enum        []any              `json:"enum,omitempty"`
+	Format      string             `json:"format,omitempty"`
+	Minimum     *float64           `json:"minimum,omitempty"`
+	Maximum     *float64           `json:"maximum,omitempty"`
+	MinLength   *int               `json:"minLength,omitempty"`
+	MaxLength   *int               `json:"maxLength,omitempty"`
+}
+
+// BasicTool provides a basic implementation of the Tool interface
+//
+// Example usage:
+//
+//	calculator := &tools.BasicTool{
+//	    ToolName:        "calculate",
+//	    ToolDescription: "Evaluates mathematical expressions",
+//	    ToolInputSchema: tools.Schema{
+//	        Type: "object",
+//	        Properties: map[string]*tools.Schema{
+//	            "expression": {
+//	                Type:        "string",
+//	                Description: "Mathematical expression to evaluate",
+//	            },
+//	        },
+//	        Required: []string{"expression"},
+//	    },
+//	    ExecuteFunc: func(ctx context.Context, input json.RawMessage) (json.RawMessage, error) {
+//	        var req struct {
+//	            Expression string `json:"expression"`
+//	        }
+//	        if err := json.Unmarshal(input, &req); err != nil {
+//	            return nil, err
+//	        }
+//	        result := evaluateExpression(req.Expression)
+//	        return json.Marshal(map[string]any{"result": result})
+//	    },
+//	}
+type BasicTool struct {
+	ToolName        string
+	ToolDescription string
+	ToolInputSchema Schema
+	ExecuteFunc     func(context.Context, json.RawMessage) (json.RawMessage, error)
+}
+
+// Name returns the tool name.
+func (t *BasicTool) Name() string {
+	return t.ToolName
+}
+
+// Description returns the tool description.
+func (t *BasicTool) Description() string {
+	return t.ToolDescription
+}
+
+// InputSchema returns the tool input schema.
+func (t *BasicTool) InputSchema() Schema {
+	return t.ToolInputSchema
+}
+
+// Execute executes the tool with the given input.
+func (t *BasicTool) Execute(ctx context.Context, input json.RawMessage) (json.RawMessage, error) {
+	if t.ExecuteFunc == nil {
+		return nil, fmt.Errorf("tool %s has no execute function", t.ToolName)
+	}
+	return t.ExecuteFunc(ctx, input)
+}
+
+// ToolBuilder provides a fluent interface for building tools.
+type ToolBuilder struct {
+	tool *BasicTool
+}
+
+// NewTool creates a new tool builder.
+func NewTool(name string) *ToolBuilder {
+	return &ToolBuilder{
+		tool: &BasicTool{
+			ToolName: name,
+		},
+	}
+}
+
+// Description sets the tool description.
+func (b *ToolBuilder) Description(desc string) *ToolBuilder {
+	b.tool.ToolDescription = desc
+	return b
+}
+
+// InputSchema sets the tool input schema.
+func (b *ToolBuilder) InputSchema(schema Schema) *ToolBuilder {
+	b.tool.ToolInputSchema = schema
+	return b
+}
+
+// Execute sets the tool execution function.
+func (b *ToolBuilder) Execute(fn func(context.Context, json.RawMessage) (json.RawMessage, error)) *ToolBuilder {
+	b.tool.ExecuteFunc = fn
+	return b
+}
+
+// Build creates the final tool.
+func (b *ToolBuilder) Build() AgentTool {
+	return b.tool
+}
+
+// SchemaBuilder provides a fluent interface for building JSON schemas.
+type SchemaBuilder struct {
+	schema Schema
+}
+
+// NewSchema creates a new schema builder.
+func NewSchema(schemaType string) *SchemaBuilder {
+	return &SchemaBuilder{
+		schema: Schema{
+			Type: schemaType,
+		},
+	}
+}
+
+// Object creates a schema builder for an object type.
+func Object() *SchemaBuilder {
+	return NewSchema("object")
+}
+
+// String creates a schema builder for a string type.
+func String() *SchemaBuilder {
+	return NewSchema("string")
+}
+
+// Number creates a schema builder for a number type.
+func Number() *SchemaBuilder {
+	return NewSchema("number")
+}
+
+// Array creates a schema builder for an array type.
+func Array() *SchemaBuilder {
+	return NewSchema("array")
+}
+
+// Description sets the schema description.
+func (b *SchemaBuilder) Description(desc string) *SchemaBuilder {
+	b.schema.Description = desc
+	return b
+}
+
+// Properties sets the schema properties.
+func (b *SchemaBuilder) Properties(props map[string]*Schema) *SchemaBuilder {
+	b.schema.Properties = props
+	return b
+}
+
+// Property adds a property to the schema.
+func (b *SchemaBuilder) Property(name string, schema *Schema) *SchemaBuilder {
+	if b.schema.Properties == nil {
+		b.schema.Properties = make(map[string]*Schema)
+	}
+	b.schema.Properties[name] = schema
+	return b
+}
+
+// Required marks fields as required.
+func (b *SchemaBuilder) Required(fields ...string) *SchemaBuilder {
+	b.schema.Required = append(b.schema.Required, fields...)
+	return b
+}
+
+// Items sets the schema for array items.
+func (b *SchemaBuilder) Items(schema *Schema) *SchemaBuilder {
+	b.schema.Items = schema
+	return b
+}
+
+// Enum sets allowed values for the schema.
+func (b *SchemaBuilder) Enum(values ...any) *SchemaBuilder {
+	b.schema.Enum = values
+	return b
+}
+
+// Format sets the string format.
+func (b *SchemaBuilder) Format(format string) *SchemaBuilder {
+	b.schema.Format = format
+	return b
+}
+
+// Min sets the minimum value.
+func (b *SchemaBuilder) Min(minimum float64) *SchemaBuilder {
+	b.schema.Minimum = &minimum
+	return b
+}
+
+// Max sets the maximum value.
+func (b *SchemaBuilder) Max(maximum float64) *SchemaBuilder {
+	b.schema.Maximum = &maximum
+	return b
+}
+
+// MinLength sets the minimum string length.
+func (b *SchemaBuilder) MinLength(minimum int) *SchemaBuilder {
+	b.schema.MinLength = &minimum
+	return b
+}
+
+// MaxLength sets the maximum string length.
+func (b *SchemaBuilder) MaxLength(maximum int) *SchemaBuilder {
+	b.schema.MaxLength = &maximum
+	return b
+}
+
+// Build creates the final schema.
+func (b *SchemaBuilder) Build() *Schema {
+	return &b.schema
+}

util.go 🔗

@@ -0,0 +1,21 @@
+package ai
+
+import (
+	"encoding/json"
+
+	"github.com/go-viper/mapstructure/v2"
+)
+
+func ParseOptions[T any](options map[string]any, m *T) error {
+	return mapstructure.Decode(options, m)
+}
+
+func FloatOption(f float64) *float64 {
+	return &f
+}
+
+func IsParsableJSON(data string) bool {
+	var m map[string]any
+	err := json.Unmarshal([]byte(data), &m)
+	return err == nil
+}