feat: tool repair and more stop conditions

Kujtim Hoxha created

Change summary

internal/ai/agent.go      | 235 ++++++++++
internal/ai/agent_test.go | 890 ++++++++++++++++++++++++++++++++++++++++
internal/ai/content.go    |   4 
internal/ai/model.go      |  74 +++
internal/ai/retry.go      |   1 
5 files changed, 1,188 insertions(+), 16 deletions(-)

Detailed changes

internal/ai/agent.go 🔗

@@ -2,7 +2,9 @@ package ai
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
+	"fmt"
 	"maps"
 	"slices"
 	"sync"
@@ -12,12 +14,73 @@ import (
 
 type StepResult struct {
 	Response
-	// Messages generated during this step
 	Messages []Message
 }
 
 type StopCondition = func(steps []StepResult) bool
 
+// StepCountIs returns a stop condition that stops after the specified number of steps.
+func StepCountIs(stepCount int) StopCondition {
+	return func(steps []StepResult) bool {
+		return len(steps) >= stepCount
+	}
+}
+
+// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
+func HasToolCall(toolName string) StopCondition {
+	return func(steps []StepResult) bool {
+		if len(steps) == 0 {
+			return false
+		}
+		lastStep := steps[len(steps)-1]
+		toolCalls := lastStep.Content.ToolCalls()
+		for _, toolCall := range toolCalls {
+			if toolCall.ToolName == toolName {
+				return true
+			}
+		}
+		return false
+	}
+}
+
+// HasContent returns a stop condition that stops when the specified content type appears in the last step.
+func HasContent(contentType ContentType) StopCondition {
+	return func(steps []StepResult) bool {
+		if len(steps) == 0 {
+			return false
+		}
+		lastStep := steps[len(steps)-1]
+		for _, content := range lastStep.Content {
+			if content.GetType() == contentType {
+				return true
+			}
+		}
+		return false
+	}
+}
+
+// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs.
+func FinishReasonIs(reason FinishReason) StopCondition {
+	return func(steps []StepResult) bool {
+		if len(steps) == 0 {
+			return false
+		}
+		lastStep := steps[len(steps)-1]
+		return lastStep.FinishReason == reason
+	}
+}
+
+// MaxTokensUsed returns a stop condition that stops when total token usage exceeds the specified limit.
+func MaxTokensUsed(maxTokens int64) StopCondition {
+	return func(steps []StepResult) bool {
+		var totalTokens int64
+		for _, step := range steps {
+			totalTokens += step.Usage.TotalTokens
+		}
+		return totalTokens >= maxTokens
+	}
+}
+
 type PrepareStepFunctionOptions struct {
 	Steps      []StepResult
 	StepNumber int
@@ -26,14 +89,26 @@ type PrepareStepFunctionOptions struct {
 }
 
 type PrepareStepResult struct {
-	Model    LanguageModel
-	Messages []Message
+	Model           LanguageModel
+	Messages        []Message
+	System          *string
+	ToolChoice      *ToolChoice
+	ActiveTools     []string
+	DisableAllTools bool
+}
+
+type ToolCallRepairOptions struct {
+	OriginalToolCall ToolCallContent
+	ValidationError  error
+	AvailableTools   []tools.BaseTool
+	SystemPrompt     string
+	Messages         []Message
 }
 
 type (
 	PrepareStepFunction    = func(options PrepareStepFunctionOptions) PrepareStepResult
 	OnStepFinishedFunction = func(step StepResult)
-	RepairToolCall         = func(ToolCallContent) ToolCallContent
+	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
 )
 
 type AgentSettings struct {
@@ -55,7 +130,7 @@ type AgentSettings struct {
 
 	stopWhen       []StopCondition
 	prepareStep    PrepareStepFunction
-	repairToolCall RepairToolCall
+	repairToolCall RepairToolCallFunction
 	onStepFinished OnStepFinishedFunction
 	onRetry        OnRetryCallback
 }
@@ -78,7 +153,7 @@ type AgentCall struct {
 
 	StopWhen       []StopCondition
 	PrepareStep    PrepareStepFunction
-	RepairToolCall RepairToolCall
+	RepairToolCall RepairToolCallFunction
 	OnStepFinished OnStepFinishedFunction
 }
 
@@ -185,6 +260,11 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
 	for {
 		stepInputMessages := append(initialPrompt, responseMessages...)
 		stepModel := a.settings.model
+		stepSystemPrompt := a.settings.systemPrompt
+		stepActiveTools := opts.ActiveTools
+		stepToolChoice := ToolChoiceAuto
+		disableAllTools := false
+
 		if opts.PrepareStep != nil {
 			prepared := opts.PrepareStep(PrepareStepFunctionOptions{
 				Model:      stepModel,
@@ -192,15 +272,40 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
 				StepNumber: len(steps),
 				Messages:   stepInputMessages,
 			})
-			stepInputMessages = prepared.Messages
+
+			// Apply prepared step modifications
+			if prepared.Messages != nil {
+				stepInputMessages = prepared.Messages
+			}
 			if prepared.Model != nil {
 				stepModel = prepared.Model
 			}
+			if prepared.System != nil {
+				stepSystemPrompt = *prepared.System
+			}
+			if prepared.ToolChoice != nil {
+				stepToolChoice = *prepared.ToolChoice
+			}
+			if len(prepared.ActiveTools) > 0 {
+				stepActiveTools = prepared.ActiveTools
+			}
+			disableAllTools = prepared.DisableAllTools
+		}
+
+		// Recreate prompt with potentially modified system prompt
+		if stepSystemPrompt != a.settings.systemPrompt {
+			stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
+			if err != nil {
+				return nil, err
+			}
+			// Replace system message part, keep the rest
+			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
+				stepInputMessages[0] = stepPrompt[0] // Replace system message
+			}
 		}
 
-		preparedTools := a.prepareTools(a.settings.tools, opts.ActiveTools)
+		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 
-		toolChoice := ToolChoiceAuto
 		retryOptions := DefaultRetryOptions()
 		retryOptions.OnRetry = opts.OnRetry
 		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
@@ -215,7 +320,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
 				PresencePenalty:  opts.PresencePenalty,
 				FrequencyPenalty: opts.FrequencyPenalty,
 				Tools:            preparedTools,
-				ToolChoice:       &toolChoice,
+				ToolChoice:       &stepToolChoice,
 				Headers:          opts.Headers,
 				ProviderOptions:  opts.ProviderOptions,
 			})
@@ -231,13 +336,31 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
 				if !ok {
 					continue
 				}
-				stepToolCalls = append(stepToolCalls, toolCall)
+
+				// Validate and potentially repair the tool call
+				validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
+				stepToolCalls = append(stepToolCalls, validatedToolCall)
 			}
 		}
 
 		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls)
 
-		stepContent := result.Content
+		// Build step content with validated tool calls and tool results
+		stepContent := []Content{}
+		toolCallIndex := 0
+		for _, content := range result.Content {
+			if content.GetType() == ContentTypeToolCall {
+				// Replace with validated tool call
+				if toolCallIndex < len(stepToolCalls) {
+					stepContent = append(stepContent, stepToolCalls[toolCallIndex])
+					toolCallIndex++
+				}
+			} else {
+				// Keep other content as-is
+				stepContent = append(stepContent, content)
+			}
+		}
+		// Add tool results
 		for _, result := range toolResults {
 			stepContent = append(stepContent, result)
 		}
@@ -345,6 +468,10 @@ func toResponseMessages(content []Content) []Message {
 				MediaType:       file.MediaType,
 				ProviderOptions: ProviderOptions(file.ProviderMetadata),
 			})
+		case ContentTypeSource:
+			// Sources are metadata about references used to generate the response.
+			// They don't need to be included in the conversation messages.
+			continue
 		case ContentTypeToolResult:
 			result, ok := AsContentType[ToolResultContent](c)
 			if !ok {
@@ -395,6 +522,19 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too
 		go func(index int, call ToolCallContent) {
 			defer wg.Done()
 
+			// Skip invalid tool calls - create error result
+			if call.Invalid {
+				results[index] = ToolResultContent{
+					ToolCallID: call.ToolCallID,
+					ToolName:   call.ToolName,
+					Result: ToolResultOutputContentError{
+						Error: call.ValidationError,
+					},
+					ProviderExecuted: false,
+				}
+				return
+			}
+
 			tool, exists := toolMap[call.ToolName]
 			if !exists {
 				results[index] = ToolResultContent{
@@ -461,9 +601,17 @@ func (a *agent) Stream(ctx context.Context, opts AgentCall) (StreamResponse, err
 	panic("not implemented")
 }
 
-func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string) []Tool {
+func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string, disableAllTools bool) []Tool {
 	var preparedTools []Tool
+
+	// If explicitly disabling all tools, return no tools
+	if disableAllTools {
+		return preparedTools
+	}
+
 	for _, tool := range tools {
+		// If activeTools has items, only include tools in the list
+		// If activeTools is empty, include all tools
 		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
 			continue
 		}
@@ -481,6 +629,65 @@ func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string) []Too
 	return preparedTools
 }
 
+// validateAndRepairToolCall validates a tool call and attempts repair if validation fails
+func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []tools.BaseTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
+	if err := a.validateToolCall(toolCall, availableTools); err == nil {
+		return toolCall
+	} else {
+		if repairFunc != nil {
+			repairOptions := ToolCallRepairOptions{
+				OriginalToolCall: toolCall,
+				ValidationError:  err,
+				AvailableTools:   availableTools,
+				SystemPrompt:     systemPrompt,
+				Messages:         messages,
+			}
+
+			if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
+				if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
+					return *repairedToolCall
+				}
+			}
+		}
+
+		invalidToolCall := toolCall
+		invalidToolCall.Invalid = true
+		invalidToolCall.ValidationError = err
+		return invalidToolCall
+	}
+}
+
+// validateToolCall validates a tool call against available tools and their schemas
+func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []tools.BaseTool) error {
+	var tool tools.BaseTool
+	for _, t := range availableTools {
+		if t.Info().Name == toolCall.ToolName {
+			tool = t
+			break
+		}
+	}
+
+	if tool == nil {
+		return fmt.Errorf("tool not found: %s", toolCall.ToolName)
+	}
+
+	// Validate JSON parsing
+	var input map[string]any
+	if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
+		return fmt.Errorf("invalid JSON input: %w", err)
+	}
+
+	// Basic schema validation (check required fields)
+	// TODO: more robust schema validation using JSON Schema or similar
+	toolInfo := tool.Info()
+	for _, required := range toolInfo.Required {
+		if _, exists := input[required]; !exists {
+			return fmt.Errorf("missing required parameter: %s", required)
+		}
+	}
+	return nil
+}
+
 func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
 	if prompt == "" {
 		return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
@@ -557,7 +764,7 @@ func WithPrepareStep(fn PrepareStepFunction) agentOption {
 	}
 }
 
-func WithRepairToolCall(fn RepairToolCall) agentOption {
+func WithRepairToolCall(fn RepairToolCallFunction) agentOption {
 	return func(s *AgentSettings) {
 		s.repairToolCall = fn
 	}

internal/ai/agent_test.go 🔗

@@ -3,6 +3,8 @@ package ai
 import (
 	"context"
 	"encoding/json"
+	"errors"
+	"fmt"
 	"testing"
 
 	"github.com/charmbracelet/crush/internal/llm/tools"
@@ -645,4 +647,890 @@ func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
 
 	require.NoError(t, err)
 	require.NotNil(t, result)
-}
+}
+
+func TestResponseContent_Getters(t *testing.T) {
+	t.Parallel()
+
+	// Create test content with all types
+	content := ResponseContent{
+		TextContent{Text: "Hello world"},
+		ReasoningContent{Text: "Let me think..."},
+		FileContent{Data: []byte("file data"), MediaType: "text/plain"},
+		SourceContent{SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"},
+		ToolCallContent{ToolCallID: "call1", ToolName: "test_tool", Input: `{"arg": "value"}`},
+		ToolResultContent{ToolCallID: "call1", ToolName: "test_tool", Result: ToolResultOutputContentText{Text: "result"}},
+	}
+
+	// Test Text()
+	require.Equal(t, "Hello world", content.Text())
+
+	// Test Reasoning()
+	reasoning := content.Reasoning()
+	require.Len(t, reasoning, 1)
+	require.Equal(t, "Let me think...", reasoning[0].Text)
+
+	// Test ReasoningText()
+	require.Equal(t, "Let me think...", content.ReasoningText())
+
+	// Test Files()
+	files := content.Files()
+	require.Len(t, files, 1)
+	require.Equal(t, "text/plain", files[0].MediaType)
+	require.Equal(t, []byte("file data"), files[0].Data)
+
+	// Test Sources()
+	sources := content.Sources()
+	require.Len(t, sources, 1)
+	require.Equal(t, SourceTypeURL, sources[0].SourceType)
+	require.Equal(t, "https://example.com", sources[0].URL)
+	require.Equal(t, "Example", sources[0].Title)
+
+	// Test ToolCalls()
+	toolCalls := content.ToolCalls()
+	require.Len(t, toolCalls, 1)
+	require.Equal(t, "call1", toolCalls[0].ToolCallID)
+	require.Equal(t, "test_tool", toolCalls[0].ToolName)
+	require.Equal(t, `{"arg": "value"}`, toolCalls[0].Input)
+
+	// Test ToolResults()
+	toolResults := content.ToolResults()
+	require.Len(t, toolResults, 1)
+	require.Equal(t, "call1", toolResults[0].ToolCallID)
+	require.Equal(t, "test_tool", toolResults[0].ToolName)
+	result, ok := AsToolResultOutputType[ToolResultOutputContentText](toolResults[0].Result)
+	require.True(t, ok)
+	require.Equal(t, "result", result.Text)
+}
+
+func TestResponseContent_Getters_Empty(t *testing.T) {
+	t.Parallel()
+
+	// Test with empty content
+	content := ResponseContent{}
+
+	require.Equal(t, "", content.Text())
+	require.Equal(t, "", content.ReasoningText())
+	require.Empty(t, content.Reasoning())
+	require.Empty(t, content.Files())
+	require.Empty(t, content.Sources())
+	require.Empty(t, content.ToolCalls())
+	require.Empty(t, content.ToolResults())
+}
+
+func TestResponseContent_Getters_MultipleItems(t *testing.T) {
+	t.Parallel()
+
+	// Test with multiple items of same type
+	content := ResponseContent{
+		ReasoningContent{Text: "First thought"},
+		ReasoningContent{Text: "Second thought"},
+		FileContent{Data: []byte("file1"), MediaType: "text/plain"},
+		FileContent{Data: []byte("file2"), MediaType: "image/png"},
+	}
+
+	// Test multiple reasoning
+	reasoning := content.Reasoning()
+	require.Len(t, reasoning, 2)
+	require.Equal(t, "First thought", reasoning[0].Text)
+	require.Equal(t, "Second thought", reasoning[1].Text)
+
+	// Test concatenated reasoning text
+	require.Equal(t, "First thoughtSecond thought", content.ReasoningText())
+
+	// Test multiple files
+	files := content.Files()
+	require.Len(t, files, 2)
+	require.Equal(t, "text/plain", files[0].MediaType)
+	require.Equal(t, "image/png", files[1].MediaType)
+}
+
+func TestStopConditions(t *testing.T) {
+	t.Parallel()
+
+	// Create test steps
+	step1 := StepResult{
+		Response: Response{
+			Content: ResponseContent{
+				TextContent{Text: "Hello"},
+			},
+			FinishReason: FinishReasonToolCalls,
+			Usage:        Usage{TotalTokens: 10},
+		},
+	}
+
+	step2 := StepResult{
+		Response: Response{
+			Content: ResponseContent{
+				TextContent{Text: "World"},
+				ToolCallContent{ToolCallID: "call1", ToolName: "search", Input: `{"query": "test"}`},
+			},
+			FinishReason: FinishReasonStop,
+			Usage:        Usage{TotalTokens: 15},
+		},
+	}
+
+	step3 := StepResult{
+		Response: Response{
+			Content: ResponseContent{
+				ReasoningContent{Text: "Let me think..."},
+				FileContent{Data: []byte("data"), MediaType: "text/plain"},
+			},
+			FinishReason: FinishReasonLength,
+			Usage:        Usage{TotalTokens: 20},
+		},
+	}
+
+	t.Run("StepCountIs", func(t *testing.T) {
+		condition := StepCountIs(2)
+
+		// Should not stop with 1 step
+		require.False(t, condition([]StepResult{step1}))
+
+		// Should stop with 2 steps
+		require.True(t, condition([]StepResult{step1, step2}))
+
+		// Should stop with more than 2 steps
+		require.True(t, condition([]StepResult{step1, step2, step3}))
+
+		// Should not stop with empty steps
+		require.False(t, condition([]StepResult{}))
+	})
+
+	t.Run("HasToolCall", func(t *testing.T) {
+		condition := HasToolCall("search")
+
+		// Should not stop when tool not called
+		require.False(t, condition([]StepResult{step1}))
+
+		// Should stop when tool is called in last step
+		require.True(t, condition([]StepResult{step1, step2}))
+
+		// Should not stop when tool called in earlier step but not last
+		require.False(t, condition([]StepResult{step1, step2, step3}))
+
+		// Should not stop with empty steps
+		require.False(t, condition([]StepResult{}))
+
+		// Should not stop when different tool is called
+		differentToolCondition := HasToolCall("different_tool")
+		require.False(t, differentToolCondition([]StepResult{step1, step2}))
+	})
+
+	t.Run("HasContent", func(t *testing.T) {
+		reasoningCondition := HasContent(ContentTypeReasoning)
+		fileCondition := HasContent(ContentTypeFile)
+
+		// Should not stop when content type not present
+		require.False(t, reasoningCondition([]StepResult{step1, step2}))
+
+		// Should stop when content type is present in last step
+		require.True(t, reasoningCondition([]StepResult{step1, step2, step3}))
+		require.True(t, fileCondition([]StepResult{step1, step2, step3}))
+
+		// Should not stop with empty steps
+		require.False(t, reasoningCondition([]StepResult{}))
+	})
+
+	t.Run("FinishReasonIs", func(t *testing.T) {
+		stopCondition := FinishReasonIs(FinishReasonStop)
+		lengthCondition := FinishReasonIs(FinishReasonLength)
+
+		// Should not stop when finish reason doesn't match
+		require.False(t, stopCondition([]StepResult{step1}))
+
+		// Should stop when finish reason matches in last step
+		require.True(t, stopCondition([]StepResult{step1, step2}))
+		require.True(t, lengthCondition([]StepResult{step1, step2, step3}))
+
+		// Should not stop with empty steps
+		require.False(t, stopCondition([]StepResult{}))
+	})
+
+	t.Run("MaxTokensUsed", func(t *testing.T) {
+		condition := MaxTokensUsed(30)
+
+		// Should not stop when under limit
+		require.False(t, condition([]StepResult{step1}))        // 10 tokens
+		require.False(t, condition([]StepResult{step1, step2})) // 25 tokens
+
+		// Should stop when at or over limit
+		require.True(t, condition([]StepResult{step1, step2, step3})) // 45 tokens
+
+		// Should not stop with empty steps
+		require.False(t, condition([]StepResult{}))
+	})
+}
+
+func TestStopConditions_Integration(t *testing.T) {
+	t.Parallel()
+
+	t.Run("StepCountIs integration", func(t *testing.T) {
+		model := &mockLanguageModel{
+			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+				return &Response{
+					Content: ResponseContent{
+						TextContent{Text: "Mock response"},
+					},
+					Usage: Usage{
+						InputTokens:  3,
+						OutputTokens: 10,
+						TotalTokens:  13,
+					},
+					FinishReason: FinishReasonStop,
+				}, nil
+			},
+		}
+
+		agent := NewAgent(model, WithStopConditions(StepCountIs(1)))
+
+		result, err := agent.Generate(context.Background(), AgentCall{
+			Prompt: "test prompt",
+		})
+
+		require.NoError(t, err)
+		require.NotNil(t, result)
+		require.Len(t, result.Steps, 1) // Should stop after 1 step
+	})
+
+	t.Run("Multiple stop conditions", func(t *testing.T) {
+		model := &mockLanguageModel{
+			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+				return &Response{
+					Content: ResponseContent{
+						TextContent{Text: "Mock response"},
+					},
+					Usage: Usage{
+						InputTokens:  3,
+						OutputTokens: 10,
+						TotalTokens:  13,
+					},
+					FinishReason: FinishReasonStop,
+				}, nil
+			},
+		}
+
+		agent := NewAgent(model, WithStopConditions(
+			StepCountIs(5),                   // Stop after 5 steps
+			FinishReasonIs(FinishReasonStop), // Or stop on finish reason
+		))
+
+		result, err := agent.Generate(context.Background(), AgentCall{
+			Prompt: "test prompt",
+		})
+
+		require.NoError(t, err)
+		require.NotNil(t, result)
+		// Should stop on first condition met (finish reason stop)
+		require.Equal(t, FinishReasonStop, result.Response.FinishReason)
+	})
+}
+
+func TestPrepareStep(t *testing.T) {
+	t.Parallel()
+
+	t.Run("System prompt modification", func(t *testing.T) {
+		var capturedSystemPrompt string
+		model := &mockLanguageModel{
+			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+				// Capture the system message to verify it was modified
+				if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
+					if len(call.Prompt[0].Content) > 0 {
+						if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
+							capturedSystemPrompt = textPart.Text
+						}
+					}
+				}
+				return &Response{
+					Content: ResponseContent{
+						TextContent{Text: "Response"},
+					},
+					Usage:        Usage{TotalTokens: 10},
+					FinishReason: FinishReasonStop,
+				}, nil
+			},
+		}
+
+		prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+			newSystem := "Modified system prompt for step " + fmt.Sprintf("%d", options.StepNumber)
+			return PrepareStepResult{
+				Model:    options.Model,
+				Messages: options.Messages,
+				System:   &newSystem,
+			}
+		}
+
+		agent := NewAgent(model, WithSystemPrompt("Original system prompt"))
+
+		result, err := agent.Generate(context.Background(), AgentCall{
+			Prompt:      "test prompt",
+			PrepareStep: prepareStepFunc,
+		})
+
+		require.NoError(t, err)
+		require.NotNil(t, result)
+		require.Equal(t, "Modified system prompt for step 0", capturedSystemPrompt)
+	})
+
+	t.Run("Tool choice modification", func(t *testing.T) {
+		var capturedToolChoice *ToolChoice
+		model := &mockLanguageModel{
+			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+				capturedToolChoice = call.ToolChoice
+				return &Response{
+					Content: ResponseContent{
+						TextContent{Text: "Response"},
+					},
+					Usage:        Usage{TotalTokens: 10},
+					FinishReason: FinishReasonStop,
+				}, nil
+			},
+		}
+
+		prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+			toolChoice := ToolChoiceNone
+			return PrepareStepResult{
+				Model:      options.Model,
+				Messages:   options.Messages,
+				ToolChoice: &toolChoice,
+			}
+		}
+
+		agent := NewAgent(model)
+
+		result, err := agent.Generate(context.Background(), AgentCall{
+			Prompt:      "test prompt",
+			PrepareStep: prepareStepFunc,
+		})
+
+		require.NoError(t, err)
+		require.NotNil(t, result)
+		require.NotNil(t, capturedToolChoice)
+		require.Equal(t, ToolChoiceNone, *capturedToolChoice)
+	})
+
+	t.Run("Active tools modification", func(t *testing.T) {
+		var capturedToolNames []string
+		model := &mockLanguageModel{
+			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+				// Capture tool names to verify active tools were modified
+				for _, tool := range call.Tools {
+					capturedToolNames = append(capturedToolNames, tool.GetName())
+				}
+				return &Response{
+					Content: ResponseContent{
+						TextContent{Text: "Response"},
+					},
+					Usage:        Usage{TotalTokens: 10},
+					FinishReason: FinishReasonStop,
+				}, nil
+			},
+		}
+
+		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
+		tool2 := &mockTool{name: "tool2", description: "Tool 2"}
+		tool3 := &mockTool{name: "tool3", description: "Tool 3"}
+
+		prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+			activeTools := []string{"tool2"} // Only tool2 should be active
+			return PrepareStepResult{
+				Model:       options.Model,
+				Messages:    options.Messages,
+				ActiveTools: activeTools,
+			}
+		}
+
+		agent := NewAgent(model, WithTools(tool1, tool2, tool3))
+
+		result, err := agent.Generate(context.Background(), AgentCall{
+			Prompt:      "test prompt",
+			PrepareStep: prepareStepFunc,
+		})
+
+		require.NoError(t, err)
+		require.NotNil(t, result)
+		require.Len(t, capturedToolNames, 1)
+		require.Equal(t, "tool2", capturedToolNames[0])
+	})
+
+	t.Run("No tools when DisableAllTools is true", func(t *testing.T) {
+		var capturedToolCount int
+		model := &mockLanguageModel{
+			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+				capturedToolCount = len(call.Tools)
+				return &Response{
+					Content: ResponseContent{
+						TextContent{Text: "Response"},
+					},
+					Usage:        Usage{TotalTokens: 10},
+					FinishReason: FinishReasonStop,
+				}, nil
+			},
+		}
+
+		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
+
+		prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+			return PrepareStepResult{
+				Model:           options.Model,
+				Messages:        options.Messages,
+				DisableAllTools: true, // Disable all tools for this step
+			}
+		}
+
+		agent := NewAgent(model, WithTools(tool1))
+
+		result, err := agent.Generate(context.Background(), AgentCall{
+			Prompt:      "test prompt",
+			PrepareStep: prepareStepFunc,
+		})
+
+		require.NoError(t, err)
+		require.NotNil(t, result)
+		require.Equal(t, 0, capturedToolCount) // No tools should be passed
+	})
+
+	t.Run("All fields modified together", func(t *testing.T) {
+		var capturedSystemPrompt string
+		var capturedToolChoice *ToolChoice
+		var capturedToolNames []string
+
+		model := &mockLanguageModel{
+			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+				// Capture system prompt
+				if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
+					if len(call.Prompt[0].Content) > 0 {
+						if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
+							capturedSystemPrompt = textPart.Text
+						}
+					}
+				}
+				// Capture tool choice
+				capturedToolChoice = call.ToolChoice
+				// Capture tool names
+				for _, tool := range call.Tools {
+					capturedToolNames = append(capturedToolNames, tool.GetName())
+				}
+				return &Response{
+					Content: ResponseContent{
+						TextContent{Text: "Response"},
+					},
+					Usage:        Usage{TotalTokens: 10},
+					FinishReason: FinishReasonStop,
+				}, nil
+			},
+		}
+
+		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
+		tool2 := &mockTool{name: "tool2", description: "Tool 2"}
+
+		prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+			newSystem := "Step-specific system"
+			toolChoice := SpecificToolChoice("tool1")
+			activeTools := []string{"tool1"}
+			return PrepareStepResult{
+				Model:       options.Model,
+				Messages:    options.Messages,
+				System:      &newSystem,
+				ToolChoice:  &toolChoice,
+				ActiveTools: activeTools,
+			}
+		}
+
+		agent := NewAgent(model, WithSystemPrompt("Original system"), WithTools(tool1, tool2))
+
+		result, err := agent.Generate(context.Background(), AgentCall{
+			Prompt:      "test prompt",
+			PrepareStep: prepareStepFunc,
+		})
+
+		require.NoError(t, err)
+		require.NotNil(t, result)
+		require.Equal(t, "Step-specific system", capturedSystemPrompt)
+		require.NotNil(t, capturedToolChoice)
+		require.Equal(t, SpecificToolChoice("tool1"), *capturedToolChoice)
+		require.Len(t, capturedToolNames, 1)
+		require.Equal(t, "tool1", capturedToolNames[0])
+	})
+
+	t.Run("Nil fields use parent values", func(t *testing.T) {
+		var capturedSystemPrompt string
+		var capturedToolChoice *ToolChoice
+		var capturedToolNames []string
+
+		model := &mockLanguageModel{
+			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+				// Capture system prompt
+				if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
+					if len(call.Prompt[0].Content) > 0 {
+						if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
+							capturedSystemPrompt = textPart.Text
+						}
+					}
+				}
+				// Capture tool choice
+				capturedToolChoice = call.ToolChoice
+				// Capture tool names
+				for _, tool := range call.Tools {
+					capturedToolNames = append(capturedToolNames, tool.GetName())
+				}
+				return &Response{
+					Content: ResponseContent{
+						TextContent{Text: "Response"},
+					},
+					Usage:        Usage{TotalTokens: 10},
+					FinishReason: FinishReasonStop,
+				}, nil
+			},
+		}
+
+		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
+
+		prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+			// All optional fields are nil, should use parent values
+			return PrepareStepResult{
+				Model:       options.Model,
+				Messages:    options.Messages,
+				System:      nil, // Use parent
+				ToolChoice:  nil, // Use parent (auto)
+				ActiveTools: nil, // Use parent (all tools)
+			}
+		}
+
+		agent := NewAgent(model, WithSystemPrompt("Parent system"), WithTools(tool1))
+
+		result, err := agent.Generate(context.Background(), AgentCall{
+			Prompt:      "test prompt",
+			PrepareStep: prepareStepFunc,
+		})
+
+		require.NoError(t, err)
+		require.NotNil(t, result)
+		require.Equal(t, "Parent system", capturedSystemPrompt)
+		require.NotNil(t, capturedToolChoice)
+		require.Equal(t, ToolChoiceAuto, *capturedToolChoice) // Default
+		require.Len(t, capturedToolNames, 1)
+		require.Equal(t, "tool1", capturedToolNames[0])
+	})
+
+	t.Run("Empty ActiveTools means all tools", func(t *testing.T) {
+		var capturedToolNames []string
+		model := &mockLanguageModel{
+			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+				// Capture tool names to verify all tools are included
+				for _, tool := range call.Tools {
+					capturedToolNames = append(capturedToolNames, tool.GetName())
+				}
+				return &Response{
+					Content: ResponseContent{
+						TextContent{Text: "Response"},
+					},
+					Usage:        Usage{TotalTokens: 10},
+					FinishReason: FinishReasonStop,
+				}, nil
+			},
+		}
+
+		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
+		tool2 := &mockTool{name: "tool2", description: "Tool 2"}
+
+		prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+			return PrepareStepResult{
+				Model:       options.Model,
+				Messages:    options.Messages,
+				ActiveTools: []string{}, // Empty slice means all tools
+			}
+		}
+
+		agent := NewAgent(model, WithTools(tool1, tool2))
+
+		result, err := agent.Generate(context.Background(), AgentCall{
+			Prompt:      "test prompt",
+			PrepareStep: prepareStepFunc,
+		})
+
+		require.NoError(t, err)
+		require.NotNil(t, result)
+		require.Len(t, capturedToolNames, 2) // All tools should be included
+		require.Contains(t, capturedToolNames, "tool1")
+		require.Contains(t, capturedToolNames, "tool2")
+	})
+}
+
+func TestToolCallRepair(t *testing.T) {
+	t.Parallel()
+
+	t.Run("Valid tool call passes validation", func(t *testing.T) {
+		model := &mockLanguageModel{
+			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+				return &Response{
+					Content: ResponseContent{
+						TextContent{Text: "Response"},
+						ToolCallContent{
+							ToolCallID: "call1",
+							ToolName:   "test_tool",
+							Input:      `{"value": "test"}`, // Valid JSON with required field
+						},
+					},
+					Usage:        Usage{TotalTokens: 10},
+					FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
+				}, nil
+			},
+		}
+
+		tool := &mockTool{
+			name:        "test_tool",
+			description: "Test tool",
+			parameters: map[string]any{
+				"value": map[string]any{"type": "string"},
+			},
+			required: []string{"value"},
+			executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
+				return tools.ToolResponse{Content: "success", IsError: false}, nil
+			},
+		}
+
+		agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
+
+		result, err := agent.Generate(context.Background(), AgentCall{
+			Prompt: "test prompt",
+		})
+
+		require.NoError(t, err)
+		require.NotNil(t, result)
+		require.Len(t, result.Steps, 1) // Only one step since FinishReason is stop
+
+		// Check that tool call was executed successfully
+		toolCalls := result.Steps[0].Response.Content.ToolCalls()
+		require.Len(t, toolCalls, 1)
+		require.False(t, toolCalls[0].Invalid) // Should be valid
+	})
+
+	t.Run("Invalid tool call without repair function", func(t *testing.T) {
+		model := &mockLanguageModel{
+			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+				return &Response{
+					Content: ResponseContent{
+						TextContent{Text: "Response"},
+						ToolCallContent{
+							ToolCallID: "call1",
+							ToolName:   "test_tool",
+							Input:      `{"wrong_field": "test"}`, // Missing required field
+						},
+					},
+					Usage:        Usage{TotalTokens: 10},
+					FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
+				}, nil
+			},
+		}
+
+		tool := &mockTool{
+			name:        "test_tool",
+			description: "Test tool",
+			parameters: map[string]any{
+				"value": map[string]any{"type": "string"},
+			},
+			required: []string{"value"},
+		}
+
+		agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
+
+		result, err := agent.Generate(context.Background(), AgentCall{
+			Prompt: "test prompt",
+		})
+
+		require.NoError(t, err)
+		require.NotNil(t, result)
+		require.Len(t, result.Steps, 1) // Only one step
+
+		// Check that tool call was marked as invalid
+		toolCalls := result.Steps[0].Response.Content.ToolCalls()
+		require.Len(t, toolCalls, 1)
+		require.True(t, toolCalls[0].Invalid) // Should be invalid
+		require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
+	})
+
+	t.Run("Invalid tool call with successful repair", func(t *testing.T) {
+		model := &mockLanguageModel{
+			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+				return &Response{
+					Content: ResponseContent{
+						TextContent{Text: "Response"},
+						ToolCallContent{
+							ToolCallID: "call1",
+							ToolName:   "test_tool",
+							Input:      `{"wrong_field": "test"}`, // Missing required field
+						},
+					},
+					Usage:        Usage{TotalTokens: 10},
+					FinishReason: FinishReasonStop, // Changed to stop
+				}, nil
+			},
+		}
+
+		tool := &mockTool{
+			name:        "test_tool",
+			description: "Test tool",
+			parameters: map[string]any{
+				"value": map[string]any{"type": "string"},
+			},
+			required: []string{"value"},
+			executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
+				return tools.ToolResponse{Content: "repaired_success", IsError: false}, nil
+			},
+		}
+
+		repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
+			// Simple repair: add the missing required field
+			repairedToolCall := options.OriginalToolCall
+			repairedToolCall.Input = `{"value": "repaired"}`
+			return &repairedToolCall, nil
+		}
+
+		agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
+
+		result, err := agent.Generate(context.Background(), AgentCall{
+			Prompt: "test prompt",
+		})
+
+		require.NoError(t, err)
+		require.NotNil(t, result)
+		require.Len(t, result.Steps, 1) // Only one step
+
+		// Check that tool call was repaired and is now valid
+		toolCalls := result.Steps[0].Response.Content.ToolCalls()
+		require.Len(t, toolCalls, 1)
+		require.False(t, toolCalls[0].Invalid)                        // Should be valid after repair
+		require.Equal(t, `{"value": "repaired"}`, toolCalls[0].Input) // Should have repaired input
+	})
+
+	t.Run("Invalid tool call with failed repair", func(t *testing.T) {
+		model := &mockLanguageModel{
+			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+				return &Response{
+					Content: ResponseContent{
+						TextContent{Text: "Response"},
+						ToolCallContent{
+							ToolCallID: "call1",
+							ToolName:   "test_tool",
+							Input:      `{"wrong_field": "test"}`, // Missing required field
+						},
+					},
+					Usage:        Usage{TotalTokens: 10},
+					FinishReason: FinishReasonStop, // Changed to stop
+				}, nil
+			},
+		}
+
+		tool := &mockTool{
+			name:        "test_tool",
+			description: "Test tool",
+			parameters: map[string]any{
+				"value": map[string]any{"type": "string"},
+			},
+			required: []string{"value"},
+		}
+
+		repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
+			// Repair function fails
+			return nil, errors.New("repair failed")
+		}
+
+		agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
+
+		result, err := agent.Generate(context.Background(), AgentCall{
+			Prompt: "test prompt",
+		})
+
+		require.NoError(t, err)
+		require.NotNil(t, result)
+		require.Len(t, result.Steps, 1) // Only one step
+
+		// Check that tool call was marked as invalid since repair failed
+		toolCalls := result.Steps[0].Response.Content.ToolCalls()
+		require.Len(t, toolCalls, 1)
+		require.True(t, toolCalls[0].Invalid) // Should be invalid
+		require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
+	})
+
+	t.Run("Nonexistent tool call", func(t *testing.T) {
+		model := &mockLanguageModel{
+			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+				return &Response{
+					Content: ResponseContent{
+						TextContent{Text: "Response"},
+						ToolCallContent{
+							ToolCallID: "call1",
+							ToolName:   "nonexistent_tool",
+							Input:      `{"value": "test"}`,
+						},
+					},
+					Usage:        Usage{TotalTokens: 10},
+					FinishReason: FinishReasonStop, // Changed to stop
+				}, nil
+			},
+		}
+
+		tool := &mockTool{name: "test_tool", description: "Test tool"}
+
+		agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
+
+		result, err := agent.Generate(context.Background(), AgentCall{
+			Prompt: "test prompt",
+		})
+
+		require.NoError(t, err)
+		require.NotNil(t, result)
+		require.Len(t, result.Steps, 1) // Only one step
+
+		// Check that tool call was marked as invalid due to nonexistent tool
+		toolCalls := result.Steps[0].Response.Content.ToolCalls()
+		require.Len(t, toolCalls, 1)
+		require.True(t, toolCalls[0].Invalid) // Should be invalid
+		require.Contains(t, toolCalls[0].ValidationError.Error(), "tool not found: nonexistent_tool")
+	})
+
+	t.Run("Invalid JSON in tool call", func(t *testing.T) {
+		model := &mockLanguageModel{
+			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+				return &Response{
+					Content: ResponseContent{
+						TextContent{Text: "Response"},
+						ToolCallContent{
+							ToolCallID: "call1",
+							ToolName:   "test_tool",
+							Input:      `{invalid json}`, // Invalid JSON
+						},
+					},
+					Usage:        Usage{TotalTokens: 10},
+					FinishReason: FinishReasonStop, // Changed to stop
+				}, nil
+			},
+		}
+
+		tool := &mockTool{
+			name:        "test_tool",
+			description: "Test tool",
+			parameters: map[string]any{
+				"value": map[string]any{"type": "string"},
+			},
+			required: []string{"value"},
+		}
+
+		agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
+
+		result, err := agent.Generate(context.Background(), AgentCall{
+			Prompt: "test prompt",
+		})
+
+		require.NoError(t, err)
+		require.NotNil(t, result)
+		require.Len(t, result.Steps, 1) // Only one step
+
+		// Check that tool call was marked as invalid due to invalid JSON
+		toolCalls := result.Steps[0].Response.Content.ToolCalls()
+		require.Len(t, toolCalls, 1)
+		require.True(t, toolCalls[0].Invalid) // Should be invalid
+		require.Contains(t, toolCalls[0].ValidationError.Error(), "invalid JSON input")
+	})
+}

internal/ai/content.go 🔗

@@ -316,6 +316,10 @@ type ToolCallContent struct {
 	ProviderExecuted bool `json:"provider_executed"`
 	// Additional provider-specific metadata for the tool call.
 	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
+	// Whether this tool call is invalid (failed validation/parsing)
+	Invalid bool `json:"invalid,omitempty"`
+	// Error that occurred during validation/parsing (only set if Invalid is true)
+	ValidationError error `json:"validation_error,omitempty"`
 }
 
 // GetType returns the type of the tool call content.

internal/ai/model.go 🔗

@@ -37,6 +37,80 @@ func (r ResponseContent) Text() string {
 	return ""
 }
 
+// Reasoning returns all reasoning content parts.
+func (r ResponseContent) Reasoning() []ReasoningContent {
+	var reasoning []ReasoningContent
+	for _, c := range r {
+		if c.GetType() == ContentTypeReasoning {
+			if reasoningContent, ok := AsContentType[ReasoningContent](c); ok {
+				reasoning = append(reasoning, reasoningContent)
+			}
+		}
+	}
+	return reasoning
+}
+
+// ReasoningText returns all reasoning content as a concatenated string.
+func (r ResponseContent) ReasoningText() string {
+	var text string
+	for _, reasoning := range r.Reasoning() {
+		text += reasoning.Text
+	}
+	return text
+}
+
+// Files returns all file content parts.
+func (r ResponseContent) Files() []FileContent {
+	var files []FileContent
+	for _, c := range r {
+		if c.GetType() == ContentTypeFile {
+			if fileContent, ok := AsContentType[FileContent](c); ok {
+				files = append(files, fileContent)
+			}
+		}
+	}
+	return files
+}
+
+// Sources returns all source content parts.
+func (r ResponseContent) Sources() []SourceContent {
+	var sources []SourceContent
+	for _, c := range r {
+		if c.GetType() == ContentTypeSource {
+			if sourceContent, ok := AsContentType[SourceContent](c); ok {
+				sources = append(sources, sourceContent)
+			}
+		}
+	}
+	return sources
+}
+
+// ToolCalls returns all tool call content parts.
+func (r ResponseContent) ToolCalls() []ToolCallContent {
+	var toolCalls []ToolCallContent
+	for _, c := range r {
+		if c.GetType() == ContentTypeToolCall {
+			if toolCallContent, ok := AsContentType[ToolCallContent](c); ok {
+				toolCalls = append(toolCalls, toolCallContent)
+			}
+		}
+	}
+	return toolCalls
+}
+
+// ToolResults returns all tool result content parts.
+func (r ResponseContent) ToolResults() []ToolResultContent {
+	var toolResults []ToolResultContent
+	for _, c := range r {
+		if c.GetType() == ContentTypeToolResult {
+			if toolResultContent, ok := AsContentType[ToolResultContent](c); ok {
+				toolResults = append(toolResults, toolResultContent)
+			}
+		}
+	}
+	return toolResults
+}
+
 type Response struct {
 	Content      ResponseContent `json:"content"`
 	FinishReason FinishReason    `json:"finish_reason"`

internal/ai/retry.go 🔗

@@ -167,4 +167,3 @@ func retryWithExponentialBackoff[T any](ctx context.Context, fn RetryFn[T], opti
 		newErrors,
 	)
 }
-