@@ -2,7 +2,9 @@ package ai
import (
"context"
+ "encoding/json"
"errors"
+ "fmt"
"maps"
"slices"
"sync"
@@ -12,12 +14,73 @@ import (
type StepResult struct {
Response
-// Messages generated during this step
+// Messages generated during this step; appended to the conversation
+// before the next step runs.
Messages []Message
}
type StopCondition = func(steps []StepResult) bool
+// StepCountIs returns a stop condition that stops after the specified number of steps.
+//
+// NOTE(review): the comparison is >=, so StepCountIs(0) reports "stop" on
+// every call, including with zero steps — confirm that degenerate case is
+// intended by callers.
+func StepCountIs(stepCount int) StopCondition {
+ return func(steps []StepResult) bool {
+ return len(steps) >= stepCount
+ }
+}
+
+// HasToolCall returns a stop condition that stops when the specified tool is called in the last step.
+// Only the most recent step is inspected; matching tool calls in earlier
+// steps do not trigger a stop.
+func HasToolCall(toolName string) StopCondition {
+ return func(steps []StepResult) bool {
+ // No steps yet: nothing to match against.
+ if len(steps) == 0 {
+ return false
+ }
+ lastStep := steps[len(steps)-1]
+ toolCalls := lastStep.Content.ToolCalls()
+ for _, toolCall := range toolCalls {
+ if toolCall.ToolName == toolName {
+ return true
+ }
+ }
+ return false
+ }
+}
+
+// HasContent returns a stop condition that stops when the specified content type appears in the last step.
+// Like HasToolCall, only the most recent step is examined.
+func HasContent(contentType ContentType) StopCondition {
+ return func(steps []StepResult) bool {
+ // No steps yet: nothing to inspect.
+ if len(steps) == 0 {
+ return false
+ }
+ lastStep := steps[len(steps)-1]
+ for _, content := range lastStep.Content {
+ if content.GetType() == contentType {
+ return true
+ }
+ }
+ return false
+ }
+}
+
+// FinishReasonIs returns a stop condition that stops when the specified finish reason occurs
+// in the most recent step; earlier steps' finish reasons are ignored.
+func FinishReasonIs(reason FinishReason) StopCondition {
+ return func(steps []StepResult) bool {
+ // No steps yet: no finish reason to compare.
+ if len(steps) == 0 {
+ return false
+ }
+ lastStep := steps[len(steps)-1]
+ return lastStep.FinishReason == reason
+ }
+}
+
+// MaxTokensUsed returns a stop condition that stops once total token usage
+// reaches or exceeds the specified limit (the comparison is >=, not >).
+// Usage is summed across all steps so far, not just the last one.
+func MaxTokensUsed(maxTokens int64) StopCondition {
+ return func(steps []StepResult) bool {
+ var totalTokens int64
+ for _, step := range steps {
+ totalTokens += step.Usage.TotalTokens
+ }
+ return totalTokens >= maxTokens
+ }
+}
+
type PrepareStepFunctionOptions struct {
Steps []StepResult
StepNumber int
@@ -26,14 +89,26 @@ type PrepareStepFunctionOptions struct {
}
type PrepareStepResult struct {
- Model LanguageModel
- Messages []Message
+ Model LanguageModel
+ Messages []Message
+ System *string
+ ToolChoice *ToolChoice
+ ActiveTools []string
+ DisableAllTools bool
+}
+
+// ToolCallRepairOptions carries the context a RepairToolCallFunction needs
+// to attempt fixing a tool call that failed validation.
+type ToolCallRepairOptions struct {
+ // OriginalToolCall is the tool call as emitted by the model.
+ OriginalToolCall ToolCallContent
+ // ValidationError explains why validation failed (unknown tool,
+ // malformed JSON input, or a missing required parameter).
+ ValidationError error
+ // AvailableTools lists the tools a repaired call may target.
+ AvailableTools []tools.BaseTool
+ // SystemPrompt and Messages give the conversational context of the step.
+ SystemPrompt string
+ Messages []Message
}
type (
PrepareStepFunction = func(options PrepareStepFunctionOptions) PrepareStepResult
OnStepFinishedFunction = func(step StepResult)
- RepairToolCall = func(ToolCallContent) ToolCallContent
+ RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
)
type AgentSettings struct {
@@ -55,7 +130,7 @@ type AgentSettings struct {
stopWhen []StopCondition
prepareStep PrepareStepFunction
- repairToolCall RepairToolCall
+ repairToolCall RepairToolCallFunction
onStepFinished OnStepFinishedFunction
onRetry OnRetryCallback
}
@@ -78,7 +153,7 @@ type AgentCall struct {
StopWhen []StopCondition
PrepareStep PrepareStepFunction
- RepairToolCall RepairToolCall
+ RepairToolCall RepairToolCallFunction
OnStepFinished OnStepFinishedFunction
}
@@ -185,6 +260,11 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
for {
stepInputMessages := append(initialPrompt, responseMessages...)
stepModel := a.settings.model
+ stepSystemPrompt := a.settings.systemPrompt
+ stepActiveTools := opts.ActiveTools
+ stepToolChoice := ToolChoiceAuto
+ disableAllTools := false
+
if opts.PrepareStep != nil {
prepared := opts.PrepareStep(PrepareStepFunctionOptions{
Model: stepModel,
@@ -192,15 +272,40 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
StepNumber: len(steps),
Messages: stepInputMessages,
})
- stepInputMessages = prepared.Messages
+
+ // Apply prepared step modifications
+ if prepared.Messages != nil {
+ stepInputMessages = prepared.Messages
+ }
if prepared.Model != nil {
stepModel = prepared.Model
}
+ if prepared.System != nil {
+ stepSystemPrompt = *prepared.System
+ }
+ if prepared.ToolChoice != nil {
+ stepToolChoice = *prepared.ToolChoice
+ }
+ if len(prepared.ActiveTools) > 0 {
+ stepActiveTools = prepared.ActiveTools
+ }
+ disableAllTools = prepared.DisableAllTools
+ }
+
+ // Recreate prompt with potentially modified system prompt
+ if stepSystemPrompt != a.settings.systemPrompt {
+ stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
+ if err != nil {
+ return nil, err
+ }
+ // Replace system message part, keep the rest
+ if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
+ stepInputMessages[0] = stepPrompt[0] // Replace system message
+ }
}
- preparedTools := a.prepareTools(a.settings.tools, opts.ActiveTools)
+ preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
- toolChoice := ToolChoiceAuto
retryOptions := DefaultRetryOptions()
retryOptions.OnRetry = opts.OnRetry
retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
@@ -215,7 +320,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
PresencePenalty: opts.PresencePenalty,
FrequencyPenalty: opts.FrequencyPenalty,
Tools: preparedTools,
- ToolChoice: &toolChoice,
+ ToolChoice: &stepToolChoice,
Headers: opts.Headers,
ProviderOptions: opts.ProviderOptions,
})
@@ -231,13 +336,31 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
if !ok {
continue
}
- stepToolCalls = append(stepToolCalls, toolCall)
+
+ // Validate and potentially repair the tool call
+ validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
+ stepToolCalls = append(stepToolCalls, validatedToolCall)
}
}
toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls)
- stepContent := result.Content
+ // Build step content with validated tool calls and tool results
+ stepContent := []Content{}
+ toolCallIndex := 0
+ for _, content := range result.Content {
+ if content.GetType() == ContentTypeToolCall {
+ // Replace with validated tool call
+ if toolCallIndex < len(stepToolCalls) {
+ stepContent = append(stepContent, stepToolCalls[toolCallIndex])
+ toolCallIndex++
+ }
+ } else {
+ // Keep other content as-is
+ stepContent = append(stepContent, content)
+ }
+ }
+ // Add tool results
for _, result := range toolResults {
stepContent = append(stepContent, result)
}
@@ -345,6 +468,10 @@ func toResponseMessages(content []Content) []Message {
MediaType: file.MediaType,
ProviderOptions: ProviderOptions(file.ProviderMetadata),
})
+ case ContentTypeSource:
+ // Sources are metadata about references used to generate the response.
+ // They don't need to be included in the conversation messages.
+ continue
case ContentTypeToolResult:
result, ok := AsContentType[ToolResultContent](c)
if !ok {
@@ -395,6 +522,19 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too
go func(index int, call ToolCallContent) {
defer wg.Done()
+ // Skip invalid tool calls - create error result
+ if call.Invalid {
+ results[index] = ToolResultContent{
+ ToolCallID: call.ToolCallID,
+ ToolName: call.ToolName,
+ Result: ToolResultOutputContentError{
+ Error: call.ValidationError,
+ },
+ ProviderExecuted: false,
+ }
+ return
+ }
+
tool, exists := toolMap[call.ToolName]
if !exists {
results[index] = ToolResultContent{
@@ -461,9 +601,17 @@ func (a *agent) Stream(ctx context.Context, opts AgentCall) (StreamResponse, err
panic("not implemented")
}
-func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string) []Tool {
+func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string, disableAllTools bool) []Tool {
var preparedTools []Tool
+
+ // If explicitly disabling all tools, return no tools
+ if disableAllTools {
+ return preparedTools
+ }
+
for _, tool := range tools {
+ // If activeTools has items, only include tools in the list
+ // If activeTools is empty, include all tools
if len(activeTools) > 0 && !slices.Contains(activeTools, tool.Info().Name) {
continue
}
@@ -481,6 +629,65 @@ func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string) []Too
return preparedTools
}
+// validateAndRepairToolCall validates a tool call and attempts repair if validation fails.
+// A call that validates (either originally or after repair) is returned as-is;
+// otherwise the original call is returned with Invalid set and the validation
+// error attached, so executeTools can surface it as an error tool result
+// instead of executing it.
+func (a *agent) validateAndRepairToolCall(ctx context.Context, toolCall ToolCallContent, availableTools []tools.BaseTool, systemPrompt string, messages []Message, repairFunc RepairToolCallFunction) ToolCallContent {
+ if err := a.validateToolCall(toolCall, availableTools); err == nil {
+ return toolCall
+ } else {
+ if repairFunc != nil {
+ repairOptions := ToolCallRepairOptions{
+ OriginalToolCall: toolCall,
+ ValidationError: err,
+ AvailableTools: availableTools,
+ SystemPrompt: systemPrompt,
+ Messages: messages,
+ }
+
+ // NOTE(review): a non-nil repairErr is silently discarded here — the
+ // flow just falls through to the invalid-marking path below. Consider
+ // logging it so failed repair attempts are observable.
+ if repairedToolCall, repairErr := repairFunc(ctx, repairOptions); repairErr == nil && repairedToolCall != nil {
+ // A repaired call must itself pass validation before being accepted.
+ if validateErr := a.validateToolCall(*repairedToolCall, availableTools); validateErr == nil {
+ return *repairedToolCall
+ }
+ }
+ }
+
+ // Repair was unavailable or unsuccessful: mark the call invalid and keep
+ // the ORIGINAL validation error (not any later repair/validation error).
+ invalidToolCall := toolCall
+ invalidToolCall.Invalid = true
+ invalidToolCall.ValidationError = err
+ return invalidToolCall
+ }
+}
+
+// validateToolCall validates a tool call against available tools and their schemas.
+// Checks are performed in order: the named tool exists, the input parses as a
+// JSON object, and every schema-required parameter key is present. Parameter
+// *types* and nested constraints are not checked (see TODO below).
+func (a *agent) validateToolCall(toolCall ToolCallContent, availableTools []tools.BaseTool) error {
+ var tool tools.BaseTool
+ for _, t := range availableTools {
+ if t.Info().Name == toolCall.ToolName {
+ tool = t
+ break
+ }
+ }
+
+ if tool == nil {
+ return fmt.Errorf("tool not found: %s", toolCall.ToolName)
+ }
+
+ // Validate JSON parsing
+ var input map[string]any
+ if err := json.Unmarshal([]byte(toolCall.Input), &input); err != nil {
+ return fmt.Errorf("invalid JSON input: %w", err)
+ }
+
+ // Basic schema validation (check required fields)
+ // TODO: more robust schema validation using JSON Schema or similar
+ toolInfo := tool.Info()
+ for _, required := range toolInfo.Required {
+ if _, exists := input[required]; !exists {
+ return fmt.Errorf("missing required parameter: %s", required)
+ }
+ }
+ return nil
+}
+
func (a *agent) createPrompt(system, prompt string, messages []Message, files ...FilePart) (Prompt, error) {
if prompt == "" {
return nil, NewInvalidPromptError(prompt, "Prompt can't be empty", nil)
@@ -557,7 +764,7 @@ func WithPrepareStep(fn PrepareStepFunction) agentOption {
}
}
-func WithRepairToolCall(fn RepairToolCall) agentOption {
+func WithRepairToolCall(fn RepairToolCallFunction) agentOption {
return func(s *AgentSettings) {
s.repairToolCall = fn
}
@@ -3,6 +3,8 @@ package ai
import (
"context"
"encoding/json"
+ "errors"
+ "fmt"
"testing"
"github.com/charmbracelet/crush/internal/llm/tools"
@@ -645,4 +647,890 @@ func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, result)
-}
+}
+
+func TestResponseContent_Getters(t *testing.T) {
+ t.Parallel()
+
+ // Create test content with all types
+ content := ResponseContent{
+ TextContent{Text: "Hello world"},
+ ReasoningContent{Text: "Let me think..."},
+ FileContent{Data: []byte("file data"), MediaType: "text/plain"},
+ SourceContent{SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"},
+ ToolCallContent{ToolCallID: "call1", ToolName: "test_tool", Input: `{"arg": "value"}`},
+ ToolResultContent{ToolCallID: "call1", ToolName: "test_tool", Result: ToolResultOutputContentText{Text: "result"}},
+ }
+
+ // Test Text()
+ require.Equal(t, "Hello world", content.Text())
+
+ // Test Reasoning()
+ reasoning := content.Reasoning()
+ require.Len(t, reasoning, 1)
+ require.Equal(t, "Let me think...", reasoning[0].Text)
+
+ // Test ReasoningText()
+ require.Equal(t, "Let me think...", content.ReasoningText())
+
+ // Test Files()
+ files := content.Files()
+ require.Len(t, files, 1)
+ require.Equal(t, "text/plain", files[0].MediaType)
+ require.Equal(t, []byte("file data"), files[0].Data)
+
+ // Test Sources()
+ sources := content.Sources()
+ require.Len(t, sources, 1)
+ require.Equal(t, SourceTypeURL, sources[0].SourceType)
+ require.Equal(t, "https://example.com", sources[0].URL)
+ require.Equal(t, "Example", sources[0].Title)
+
+ // Test ToolCalls()
+ toolCalls := content.ToolCalls()
+ require.Len(t, toolCalls, 1)
+ require.Equal(t, "call1", toolCalls[0].ToolCallID)
+ require.Equal(t, "test_tool", toolCalls[0].ToolName)
+ require.Equal(t, `{"arg": "value"}`, toolCalls[0].Input)
+
+ // Test ToolResults()
+ toolResults := content.ToolResults()
+ require.Len(t, toolResults, 1)
+ require.Equal(t, "call1", toolResults[0].ToolCallID)
+ require.Equal(t, "test_tool", toolResults[0].ToolName)
+ result, ok := AsToolResultOutputType[ToolResultOutputContentText](toolResults[0].Result)
+ require.True(t, ok)
+ require.Equal(t, "result", result.Text)
+}
+
+func TestResponseContent_Getters_Empty(t *testing.T) {
+ t.Parallel()
+
+ // Test with empty content
+ content := ResponseContent{}
+
+ require.Equal(t, "", content.Text())
+ require.Equal(t, "", content.ReasoningText())
+ require.Empty(t, content.Reasoning())
+ require.Empty(t, content.Files())
+ require.Empty(t, content.Sources())
+ require.Empty(t, content.ToolCalls())
+ require.Empty(t, content.ToolResults())
+}
+
+func TestResponseContent_Getters_MultipleItems(t *testing.T) {
+ t.Parallel()
+
+ // Test with multiple items of same type
+ content := ResponseContent{
+ ReasoningContent{Text: "First thought"},
+ ReasoningContent{Text: "Second thought"},
+ FileContent{Data: []byte("file1"), MediaType: "text/plain"},
+ FileContent{Data: []byte("file2"), MediaType: "image/png"},
+ }
+
+ // Test multiple reasoning
+ reasoning := content.Reasoning()
+ require.Len(t, reasoning, 2)
+ require.Equal(t, "First thought", reasoning[0].Text)
+ require.Equal(t, "Second thought", reasoning[1].Text)
+
+ // Test concatenated reasoning text
+ require.Equal(t, "First thoughtSecond thought", content.ReasoningText())
+
+ // Test multiple files
+ files := content.Files()
+ require.Len(t, files, 2)
+ require.Equal(t, "text/plain", files[0].MediaType)
+ require.Equal(t, "image/png", files[1].MediaType)
+}
+
+func TestStopConditions(t *testing.T) {
+ t.Parallel()
+
+ // Create test steps
+ step1 := StepResult{
+ Response: Response{
+ Content: ResponseContent{
+ TextContent{Text: "Hello"},
+ },
+ FinishReason: FinishReasonToolCalls,
+ Usage: Usage{TotalTokens: 10},
+ },
+ }
+
+ step2 := StepResult{
+ Response: Response{
+ Content: ResponseContent{
+ TextContent{Text: "World"},
+ ToolCallContent{ToolCallID: "call1", ToolName: "search", Input: `{"query": "test"}`},
+ },
+ FinishReason: FinishReasonStop,
+ Usage: Usage{TotalTokens: 15},
+ },
+ }
+
+ step3 := StepResult{
+ Response: Response{
+ Content: ResponseContent{
+ ReasoningContent{Text: "Let me think..."},
+ FileContent{Data: []byte("data"), MediaType: "text/plain"},
+ },
+ FinishReason: FinishReasonLength,
+ Usage: Usage{TotalTokens: 20},
+ },
+ }
+
+ t.Run("StepCountIs", func(t *testing.T) {
+ condition := StepCountIs(2)
+
+ // Should not stop with 1 step
+ require.False(t, condition([]StepResult{step1}))
+
+ // Should stop with 2 steps
+ require.True(t, condition([]StepResult{step1, step2}))
+
+ // Should stop with more than 2 steps
+ require.True(t, condition([]StepResult{step1, step2, step3}))
+
+ // Should not stop with empty steps
+ require.False(t, condition([]StepResult{}))
+ })
+
+ t.Run("HasToolCall", func(t *testing.T) {
+ condition := HasToolCall("search")
+
+ // Should not stop when tool not called
+ require.False(t, condition([]StepResult{step1}))
+
+ // Should stop when tool is called in last step
+ require.True(t, condition([]StepResult{step1, step2}))
+
+ // Should not stop when tool called in earlier step but not last
+ require.False(t, condition([]StepResult{step1, step2, step3}))
+
+ // Should not stop with empty steps
+ require.False(t, condition([]StepResult{}))
+
+ // Should not stop when different tool is called
+ differentToolCondition := HasToolCall("different_tool")
+ require.False(t, differentToolCondition([]StepResult{step1, step2}))
+ })
+
+ t.Run("HasContent", func(t *testing.T) {
+ reasoningCondition := HasContent(ContentTypeReasoning)
+ fileCondition := HasContent(ContentTypeFile)
+
+ // Should not stop when content type not present
+ require.False(t, reasoningCondition([]StepResult{step1, step2}))
+
+ // Should stop when content type is present in last step
+ require.True(t, reasoningCondition([]StepResult{step1, step2, step3}))
+ require.True(t, fileCondition([]StepResult{step1, step2, step3}))
+
+ // Should not stop with empty steps
+ require.False(t, reasoningCondition([]StepResult{}))
+ })
+
+ t.Run("FinishReasonIs", func(t *testing.T) {
+ stopCondition := FinishReasonIs(FinishReasonStop)
+ lengthCondition := FinishReasonIs(FinishReasonLength)
+
+ // Should not stop when finish reason doesn't match
+ require.False(t, stopCondition([]StepResult{step1}))
+
+ // Should stop when finish reason matches in last step
+ require.True(t, stopCondition([]StepResult{step1, step2}))
+ require.True(t, lengthCondition([]StepResult{step1, step2, step3}))
+
+ // Should not stop with empty steps
+ require.False(t, stopCondition([]StepResult{}))
+ })
+
+ t.Run("MaxTokensUsed", func(t *testing.T) {
+ condition := MaxTokensUsed(30)
+
+ // Should not stop when under limit
+ require.False(t, condition([]StepResult{step1})) // 10 tokens
+ require.False(t, condition([]StepResult{step1, step2})) // 25 tokens
+
+ // Should stop when at or over limit
+ require.True(t, condition([]StepResult{step1, step2, step3})) // 45 tokens
+
+ // Should not stop with empty steps
+ require.False(t, condition([]StepResult{}))
+ })
+}
+
+func TestStopConditions_Integration(t *testing.T) {
+ t.Parallel()
+
+ t.Run("StepCountIs integration", func(t *testing.T) {
+ model := &mockLanguageModel{
+ generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+ return &Response{
+ Content: ResponseContent{
+ TextContent{Text: "Mock response"},
+ },
+ Usage: Usage{
+ InputTokens: 3,
+ OutputTokens: 10,
+ TotalTokens: 13,
+ },
+ FinishReason: FinishReasonStop,
+ }, nil
+ },
+ }
+
+ agent := NewAgent(model, WithStopConditions(StepCountIs(1)))
+
+ result, err := agent.Generate(context.Background(), AgentCall{
+ Prompt: "test prompt",
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Len(t, result.Steps, 1) // Should stop after 1 step
+ })
+
+ t.Run("Multiple stop conditions", func(t *testing.T) {
+ model := &mockLanguageModel{
+ generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+ return &Response{
+ Content: ResponseContent{
+ TextContent{Text: "Mock response"},
+ },
+ Usage: Usage{
+ InputTokens: 3,
+ OutputTokens: 10,
+ TotalTokens: 13,
+ },
+ FinishReason: FinishReasonStop,
+ }, nil
+ },
+ }
+
+ agent := NewAgent(model, WithStopConditions(
+ StepCountIs(5), // Stop after 5 steps
+ FinishReasonIs(FinishReasonStop), // Or stop on finish reason
+ ))
+
+ result, err := agent.Generate(context.Background(), AgentCall{
+ Prompt: "test prompt",
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ // Should stop on first condition met (finish reason stop)
+ require.Equal(t, FinishReasonStop, result.Response.FinishReason)
+ })
+}
+
+func TestPrepareStep(t *testing.T) {
+ t.Parallel()
+
+ t.Run("System prompt modification", func(t *testing.T) {
+ var capturedSystemPrompt string
+ model := &mockLanguageModel{
+ generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+ // Capture the system message to verify it was modified
+ if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
+ if len(call.Prompt[0].Content) > 0 {
+ if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
+ capturedSystemPrompt = textPart.Text
+ }
+ }
+ }
+ return &Response{
+ Content: ResponseContent{
+ TextContent{Text: "Response"},
+ },
+ Usage: Usage{TotalTokens: 10},
+ FinishReason: FinishReasonStop,
+ }, nil
+ },
+ }
+
+ prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+ newSystem := "Modified system prompt for step " + fmt.Sprintf("%d", options.StepNumber)
+ return PrepareStepResult{
+ Model: options.Model,
+ Messages: options.Messages,
+ System: &newSystem,
+ }
+ }
+
+ agent := NewAgent(model, WithSystemPrompt("Original system prompt"))
+
+ result, err := agent.Generate(context.Background(), AgentCall{
+ Prompt: "test prompt",
+ PrepareStep: prepareStepFunc,
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "Modified system prompt for step 0", capturedSystemPrompt)
+ })
+
+ t.Run("Tool choice modification", func(t *testing.T) {
+ var capturedToolChoice *ToolChoice
+ model := &mockLanguageModel{
+ generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+ capturedToolChoice = call.ToolChoice
+ return &Response{
+ Content: ResponseContent{
+ TextContent{Text: "Response"},
+ },
+ Usage: Usage{TotalTokens: 10},
+ FinishReason: FinishReasonStop,
+ }, nil
+ },
+ }
+
+ prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+ toolChoice := ToolChoiceNone
+ return PrepareStepResult{
+ Model: options.Model,
+ Messages: options.Messages,
+ ToolChoice: &toolChoice,
+ }
+ }
+
+ agent := NewAgent(model)
+
+ result, err := agent.Generate(context.Background(), AgentCall{
+ Prompt: "test prompt",
+ PrepareStep: prepareStepFunc,
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.NotNil(t, capturedToolChoice)
+ require.Equal(t, ToolChoiceNone, *capturedToolChoice)
+ })
+
+ t.Run("Active tools modification", func(t *testing.T) {
+ var capturedToolNames []string
+ model := &mockLanguageModel{
+ generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+ // Capture tool names to verify active tools were modified
+ for _, tool := range call.Tools {
+ capturedToolNames = append(capturedToolNames, tool.GetName())
+ }
+ return &Response{
+ Content: ResponseContent{
+ TextContent{Text: "Response"},
+ },
+ Usage: Usage{TotalTokens: 10},
+ FinishReason: FinishReasonStop,
+ }, nil
+ },
+ }
+
+ tool1 := &mockTool{name: "tool1", description: "Tool 1"}
+ tool2 := &mockTool{name: "tool2", description: "Tool 2"}
+ tool3 := &mockTool{name: "tool3", description: "Tool 3"}
+
+ prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+ activeTools := []string{"tool2"} // Only tool2 should be active
+ return PrepareStepResult{
+ Model: options.Model,
+ Messages: options.Messages,
+ ActiveTools: activeTools,
+ }
+ }
+
+ agent := NewAgent(model, WithTools(tool1, tool2, tool3))
+
+ result, err := agent.Generate(context.Background(), AgentCall{
+ Prompt: "test prompt",
+ PrepareStep: prepareStepFunc,
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Len(t, capturedToolNames, 1)
+ require.Equal(t, "tool2", capturedToolNames[0])
+ })
+
+ t.Run("No tools when DisableAllTools is true", func(t *testing.T) {
+ var capturedToolCount int
+ model := &mockLanguageModel{
+ generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+ capturedToolCount = len(call.Tools)
+ return &Response{
+ Content: ResponseContent{
+ TextContent{Text: "Response"},
+ },
+ Usage: Usage{TotalTokens: 10},
+ FinishReason: FinishReasonStop,
+ }, nil
+ },
+ }
+
+ tool1 := &mockTool{name: "tool1", description: "Tool 1"}
+
+ prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+ return PrepareStepResult{
+ Model: options.Model,
+ Messages: options.Messages,
+ DisableAllTools: true, // Disable all tools for this step
+ }
+ }
+
+ agent := NewAgent(model, WithTools(tool1))
+
+ result, err := agent.Generate(context.Background(), AgentCall{
+ Prompt: "test prompt",
+ PrepareStep: prepareStepFunc,
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, 0, capturedToolCount) // No tools should be passed
+ })
+
+ t.Run("All fields modified together", func(t *testing.T) {
+ var capturedSystemPrompt string
+ var capturedToolChoice *ToolChoice
+ var capturedToolNames []string
+
+ model := &mockLanguageModel{
+ generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+ // Capture system prompt
+ if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
+ if len(call.Prompt[0].Content) > 0 {
+ if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
+ capturedSystemPrompt = textPart.Text
+ }
+ }
+ }
+ // Capture tool choice
+ capturedToolChoice = call.ToolChoice
+ // Capture tool names
+ for _, tool := range call.Tools {
+ capturedToolNames = append(capturedToolNames, tool.GetName())
+ }
+ return &Response{
+ Content: ResponseContent{
+ TextContent{Text: "Response"},
+ },
+ Usage: Usage{TotalTokens: 10},
+ FinishReason: FinishReasonStop,
+ }, nil
+ },
+ }
+
+ tool1 := &mockTool{name: "tool1", description: "Tool 1"}
+ tool2 := &mockTool{name: "tool2", description: "Tool 2"}
+
+ prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+ newSystem := "Step-specific system"
+ toolChoice := SpecificToolChoice("tool1")
+ activeTools := []string{"tool1"}
+ return PrepareStepResult{
+ Model: options.Model,
+ Messages: options.Messages,
+ System: &newSystem,
+ ToolChoice: &toolChoice,
+ ActiveTools: activeTools,
+ }
+ }
+
+ agent := NewAgent(model, WithSystemPrompt("Original system"), WithTools(tool1, tool2))
+
+ result, err := agent.Generate(context.Background(), AgentCall{
+ Prompt: "test prompt",
+ PrepareStep: prepareStepFunc,
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "Step-specific system", capturedSystemPrompt)
+ require.NotNil(t, capturedToolChoice)
+ require.Equal(t, SpecificToolChoice("tool1"), *capturedToolChoice)
+ require.Len(t, capturedToolNames, 1)
+ require.Equal(t, "tool1", capturedToolNames[0])
+ })
+
+ t.Run("Nil fields use parent values", func(t *testing.T) {
+ var capturedSystemPrompt string
+ var capturedToolChoice *ToolChoice
+ var capturedToolNames []string
+
+ model := &mockLanguageModel{
+ generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+ // Capture system prompt
+ if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
+ if len(call.Prompt[0].Content) > 0 {
+ if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
+ capturedSystemPrompt = textPart.Text
+ }
+ }
+ }
+ // Capture tool choice
+ capturedToolChoice = call.ToolChoice
+ // Capture tool names
+ for _, tool := range call.Tools {
+ capturedToolNames = append(capturedToolNames, tool.GetName())
+ }
+ return &Response{
+ Content: ResponseContent{
+ TextContent{Text: "Response"},
+ },
+ Usage: Usage{TotalTokens: 10},
+ FinishReason: FinishReasonStop,
+ }, nil
+ },
+ }
+
+ tool1 := &mockTool{name: "tool1", description: "Tool 1"}
+
+ prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+ // All optional fields are nil, should use parent values
+ return PrepareStepResult{
+ Model: options.Model,
+ Messages: options.Messages,
+ System: nil, // Use parent
+ ToolChoice: nil, // Use parent (auto)
+ ActiveTools: nil, // Use parent (all tools)
+ }
+ }
+
+ agent := NewAgent(model, WithSystemPrompt("Parent system"), WithTools(tool1))
+
+ result, err := agent.Generate(context.Background(), AgentCall{
+ Prompt: "test prompt",
+ PrepareStep: prepareStepFunc,
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "Parent system", capturedSystemPrompt)
+ require.NotNil(t, capturedToolChoice)
+ require.Equal(t, ToolChoiceAuto, *capturedToolChoice) // Default
+ require.Len(t, capturedToolNames, 1)
+ require.Equal(t, "tool1", capturedToolNames[0])
+ })
+
+ t.Run("Empty ActiveTools means all tools", func(t *testing.T) {
+ var capturedToolNames []string
+ model := &mockLanguageModel{
+ generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+ // Capture tool names to verify all tools are included
+ for _, tool := range call.Tools {
+ capturedToolNames = append(capturedToolNames, tool.GetName())
+ }
+ return &Response{
+ Content: ResponseContent{
+ TextContent{Text: "Response"},
+ },
+ Usage: Usage{TotalTokens: 10},
+ FinishReason: FinishReasonStop,
+ }, nil
+ },
+ }
+
+ tool1 := &mockTool{name: "tool1", description: "Tool 1"}
+ tool2 := &mockTool{name: "tool2", description: "Tool 2"}
+
+ prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+ return PrepareStepResult{
+ Model: options.Model,
+ Messages: options.Messages,
+ ActiveTools: []string{}, // Empty slice means all tools
+ }
+ }
+
+ agent := NewAgent(model, WithTools(tool1, tool2))
+
+ result, err := agent.Generate(context.Background(), AgentCall{
+ Prompt: "test prompt",
+ PrepareStep: prepareStepFunc,
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Len(t, capturedToolNames, 2) // All tools should be included
+ require.Contains(t, capturedToolNames, "tool1")
+ require.Contains(t, capturedToolNames, "tool2")
+ })
+}
+
+func TestToolCallRepair(t *testing.T) {
+ t.Parallel()
+
+ t.Run("Valid tool call passes validation", func(t *testing.T) {
+ model := &mockLanguageModel{
+ generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+ return &Response{
+ Content: ResponseContent{
+ TextContent{Text: "Response"},
+ ToolCallContent{
+ ToolCallID: "call1",
+ ToolName: "test_tool",
+ Input: `{"value": "test"}`, // Valid JSON with required field
+ },
+ },
+ Usage: Usage{TotalTokens: 10},
+ FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
+ }, nil
+ },
+ }
+
+ tool := &mockTool{
+ name: "test_tool",
+ description: "Test tool",
+ parameters: map[string]any{
+ "value": map[string]any{"type": "string"},
+ },
+ required: []string{"value"},
+ executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
+ return tools.ToolResponse{Content: "success", IsError: false}, nil
+ },
+ }
+
+ agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
+
+ result, err := agent.Generate(context.Background(), AgentCall{
+ Prompt: "test prompt",
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Len(t, result.Steps, 1) // Only one step since FinishReason is stop
+
+ // Check that tool call was executed successfully
+ toolCalls := result.Steps[0].Response.Content.ToolCalls()
+ require.Len(t, toolCalls, 1)
+ require.False(t, toolCalls[0].Invalid) // Should be valid
+ })
+
+ t.Run("Invalid tool call without repair function", func(t *testing.T) {
+ model := &mockLanguageModel{
+ generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+ return &Response{
+ Content: ResponseContent{
+ TextContent{Text: "Response"},
+ ToolCallContent{
+ ToolCallID: "call1",
+ ToolName: "test_tool",
+ Input: `{"wrong_field": "test"}`, // Missing required field
+ },
+ },
+ Usage: Usage{TotalTokens: 10},
+ FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
+ }, nil
+ },
+ }
+
+ tool := &mockTool{
+ name: "test_tool",
+ description: "Test tool",
+ parameters: map[string]any{
+ "value": map[string]any{"type": "string"},
+ },
+ required: []string{"value"},
+ }
+
+ agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
+
+ result, err := agent.Generate(context.Background(), AgentCall{
+ Prompt: "test prompt",
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Len(t, result.Steps, 1) // Only one step
+
+ // Check that tool call was marked as invalid
+ toolCalls := result.Steps[0].Response.Content.ToolCalls()
+ require.Len(t, toolCalls, 1)
+ require.True(t, toolCalls[0].Invalid) // Should be invalid
+ require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
+ })
+
+ t.Run("Invalid tool call with successful repair", func(t *testing.T) { // Repair hook rewrites the bad input; the repaired call must validate and execute
+ model := &mockLanguageModel{
+ generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+ return &Response{
+ Content: ResponseContent{
+ TextContent{Text: "Response"},
+ ToolCallContent{
+ ToolCallID: "call1",
+ ToolName: "test_tool",
+ Input: `{"wrong_field": "test"}`, // Missing the required "value" field; the repair func below fixes this
+ },
+ },
+ Usage: Usage{TotalTokens: 10},
+ FinishReason: FinishReasonStop, // Changed to stop
+ }, nil
+ },
+ }
+
+ tool := &mockTool{
+ name: "test_tool",
+ description: "Test tool",
+ parameters: map[string]any{
+ "value": map[string]any{"type": "string"},
+ },
+ required: []string{"value"}, // Schema requires "value", absent from the model's original input
+ executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
+ return tools.ToolResponse{Content: "repaired_success", IsError: false}, nil
+ },
+ }
+
+ repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
+ // Simple repair: replace the input with one that supplies the missing required field
+ repairedToolCall := options.OriginalToolCall
+ repairedToolCall.Input = `{"value": "repaired"}`
+ return &repairedToolCall, nil
+ }
+
+ agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
+
+ result, err := agent.Generate(context.Background(), AgentCall{
+ Prompt: "test prompt",
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Len(t, result.Steps, 1) // Only one step
+
+ // Check that tool call was repaired and is now valid
+ toolCalls := result.Steps[0].Response.Content.ToolCalls()
+ require.Len(t, toolCalls, 1)
+ require.False(t, toolCalls[0].Invalid) // Should be valid after repair
+ require.Equal(t, `{"value": "repaired"}`, toolCalls[0].Input) // Recorded input is the repaired one, not the original
+ })
+
+ t.Run("Invalid tool call with failed repair", func(t *testing.T) { // Repair hook errors out; the call stays invalid with its original validation error
+ model := &mockLanguageModel{
+ generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+ return &Response{
+ Content: ResponseContent{
+ TextContent{Text: "Response"},
+ ToolCallContent{
+ ToolCallID: "call1",
+ ToolName: "test_tool",
+ Input: `{"wrong_field": "test"}`, // Missing the required "value" field; repair will not fix it
+ },
+ },
+ Usage: Usage{TotalTokens: 10},
+ FinishReason: FinishReasonStop, // Changed to stop
+ }, nil
+ },
+ }
+
+ tool := &mockTool{
+ name: "test_tool",
+ description: "Test tool",
+ parameters: map[string]any{
+ "value": map[string]any{"type": "string"},
+ },
+ required: []string{"value"}, // Schema requires "value", absent from the model's input
+ }
+
+ repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
+ // Repair function fails unconditionally
+ return nil, errors.New("repair failed")
+ }
+
+ agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
+
+ result, err := agent.Generate(context.Background(), AgentCall{
+ Prompt: "test prompt",
+ })
+
+ require.NoError(t, err) // Repair failure does not propagate as a Generate error
+ require.NotNil(t, result)
+ require.Len(t, result.Steps, 1) // Only one step
+
+ // Check that tool call was marked as invalid since repair failed
+ toolCalls := result.Steps[0].Response.Content.ToolCalls()
+ require.Len(t, toolCalls, 1)
+ require.True(t, toolCalls[0].Invalid) // Should be invalid
+ require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value") // Original validation error is kept, not the repair error
+ })
+
+ t.Run("Nonexistent tool call", func(t *testing.T) { // Model calls a tool name that is not registered on the agent
+ model := &mockLanguageModel{
+ generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+ return &Response{
+ Content: ResponseContent{
+ TextContent{Text: "Response"},
+ ToolCallContent{
+ ToolCallID: "call1",
+ ToolName: "nonexistent_tool", // Only "test_tool" is registered below
+ Input: `{"value": "test"}`,
+ },
+ },
+ Usage: Usage{TotalTokens: 10},
+ FinishReason: FinishReasonStop, // Changed to stop
+ }, nil
+ },
+ }
+
+ tool := &mockTool{name: "test_tool", description: "Test tool"} // Registered tool; name does not match the model's call
+
+ agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
+
+ result, err := agent.Generate(context.Background(), AgentCall{
+ Prompt: "test prompt",
+ })
+
+ require.NoError(t, err) // Unknown tool is recorded on the call, not returned as a Generate error
+ require.NotNil(t, result)
+ require.Len(t, result.Steps, 1) // Only one step
+
+ // Check that tool call was marked as invalid due to nonexistent tool
+ toolCalls := result.Steps[0].Response.Content.ToolCalls()
+ require.Len(t, toolCalls, 1)
+ require.True(t, toolCalls[0].Invalid) // Should be invalid
+ require.Contains(t, toolCalls[0].ValidationError.Error(), "tool not found: nonexistent_tool")
+ })
+
+ t.Run("Invalid JSON in tool call", func(t *testing.T) { // Malformed JSON input must fail parsing before schema validation
+ model := &mockLanguageModel{
+ generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+ return &Response{
+ Content: ResponseContent{
+ TextContent{Text: "Response"},
+ ToolCallContent{
+ ToolCallID: "call1",
+ ToolName: "test_tool",
+ Input: `{invalid json}`, // Not parseable as JSON at all
+ },
+ },
+ Usage: Usage{TotalTokens: 10},
+ FinishReason: FinishReasonStop, // Changed to stop
+ }, nil
+ },
+ }
+
+ tool := &mockTool{
+ name: "test_tool",
+ description: "Test tool",
+ parameters: map[string]any{
+ "value": map[string]any{"type": "string"},
+ },
+ required: []string{"value"},
+ }
+
+ agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
+
+ result, err := agent.Generate(context.Background(), AgentCall{
+ Prompt: "test prompt",
+ })
+
+ require.NoError(t, err) // Parse failure is recorded on the call, not returned as a Generate error
+ require.NotNil(t, result)
+ require.Len(t, result.Steps, 1) // Only one step
+
+ // Check that tool call was marked as invalid due to invalid JSON
+ toolCalls := result.Steps[0].Response.Content.ToolCalls()
+ require.Len(t, toolCalls, 1)
+ require.True(t, toolCalls[0].Invalid) // Should be invalid
+ require.Contains(t, toolCalls[0].ValidationError.Error(), "invalid JSON input")
+ })
+}