chore: agent improvements

Kujtim Hoxha created

Change summary

internal/ai/agent.go                         | 130 ++++++++++-----------
internal/ai/agent_stream_test.go             |   4 
internal/ai/agent_test.go                    |  28 ++--
internal/ai/content.go                       |  38 ++++++
internal/ai/examples/streaming-agent/main.go |   2 
internal/ai/model.go                         |   7 
internal/ai/tool.go                          |   3 
7 files changed, 123 insertions(+), 89 deletions(-)

Detailed changes

internal/ai/agent.go 🔗

@@ -104,12 +104,12 @@ type ToolCallRepairOptions struct {
 }
 
 type (
-	PrepareStepFunction    = func(options PrepareStepFunctionOptions) PrepareStepResult
+	PrepareStepFunction    = func(options PrepareStepFunctionOptions) (PrepareStepResult, error)
 	OnStepFinishedFunction = func(step StepResult)
 	RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)
 )
 
-type AgentSettings struct {
+type agentSettings struct {
 	systemPrompt     string
 	maxOutputTokens  *int64
 	temperature      *float64
@@ -129,7 +129,6 @@ type AgentSettings struct {
 	stopWhen       []StopCondition
 	prepareStep    PrepareStepFunction
 	repairToolCall RepairToolCallFunction
-	onStepFinished OnStepFinishedFunction
 	onRetry        OnRetryCallback
 }
 
@@ -152,7 +151,6 @@ type AgentCall struct {
 	StopWhen       []StopCondition
 	PrepareStep    PrepareStepFunction
 	RepairToolCall RepairToolCallFunction
-	OnStepFinished OnStepFinishedFunction
 }
 
 type AgentStreamCall struct {
@@ -184,22 +182,22 @@ type AgentStreamCall struct {
 	OnError       func(error)                 // Called when an error occurs
 
 	// Stream part callbacks - called for each corresponding stream part type
-	OnChunk          func(StreamPart)                                                               // Called for each stream part (catch-all)
-	OnWarnings       func(warnings []CallWarning)                                                   // Called for warnings
-	OnTextStart      func(id string)                                                                // Called when text starts
-	OnTextDelta      func(id, text string)                                                          // Called for text deltas
-	OnTextEnd        func(id string)                                                                // Called when text ends
-	OnReasoningStart func(id string)                                                                // Called when reasoning starts
-	OnReasoningDelta func(id, text string)                                                          // Called for reasoning deltas
-	OnReasoningEnd   func(id string)                                                                // Called when reasoning ends
-	OnToolInputStart func(id, toolName string)                                                      // Called when tool input starts
-	OnToolInputDelta func(id, delta string)                                                         // Called for tool input deltas
-	OnToolInputEnd   func(id string)                                                                // Called when tool input ends
-	OnToolCall       func(toolCall ToolCallContent)                                                 // Called when tool call is complete
-	OnToolResult     func(result ToolResultContent)                                                 // Called when tool execution completes
-	OnSource         func(source SourceContent)                                                     // Called for source references
-	OnStreamFinish   func(usage Usage, finishReason FinishReason, providerMetadata ProviderOptions) // Called when stream finishes
-	OnStreamError    func(error)                                                                    // Called when stream error occurs
+	OnChunk          func(StreamPart)                                                                // Called for each stream part (catch-all)
+	OnWarnings       func(warnings []CallWarning)                                                    // Called for warnings
+	OnTextStart      func(id string)                                                                 // Called when text starts
+	OnTextDelta      func(id, text string)                                                           // Called for text deltas
+	OnTextEnd        func(id string)                                                                 // Called when text ends
+	OnReasoningStart func(id string)                                                                 // Called when reasoning starts
+	OnReasoningDelta func(id, text string)                                                           // Called for reasoning deltas
+	OnReasoningEnd   func(id string, reasoning ReasoningContent)                                     // Called when reasoning ends
+	OnToolInputStart func(id, toolName string)                                                       // Called when tool input starts
+	OnToolInputDelta func(id, delta string)                                                          // Called for tool input deltas
+	OnToolInputEnd   func(id string)                                                                 // Called when tool input ends
+	OnToolCall       func(toolCall ToolCallContent)                                                  // Called when tool call is complete
+	OnToolResult     func(result ToolResultContent)                                                  // Called when tool execution completes
+	OnSource         func(source SourceContent)                                                      // Called for source references
+	OnStreamFinish   func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) // Called when stream finishes
+	OnStreamError    func(error)                                                                     // Called when stream error occurs
 }
 
 type AgentResult struct {
@@ -214,14 +212,14 @@ type Agent interface {
 	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
 }
 
-type agentOption = func(*AgentSettings)
+type AgentOption = func(*agentSettings)
 
 type agent struct {
-	settings AgentSettings
+	settings agentSettings
 }
 
-func NewAgent(model LanguageModel, opts ...agentOption) Agent {
-	settings := AgentSettings{
+func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
+	settings := agentSettings{
 		model: model,
 	}
 	for _, o := range opts {
@@ -260,9 +258,6 @@ func (a *agent) prepareCall(call AgentCall) AgentCall {
 	if call.RepairToolCall == nil && a.settings.repairToolCall != nil {
 		call.RepairToolCall = a.settings.repairToolCall
 	}
-	if call.OnStepFinished == nil && a.settings.onStepFinished != nil {
-		call.OnStepFinished = a.settings.onStepFinished
-	}
 	if call.OnRetry == nil && a.settings.onRetry != nil {
 		call.OnRetry = a.settings.onRetry
 	}
@@ -311,12 +306,15 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
 		disableAllTools := false
 
 		if opts.PrepareStep != nil {
-			prepared := opts.PrepareStep(PrepareStepFunctionOptions{
+			prepared, err := opts.PrepareStep(PrepareStepFunctionOptions{
 				Model:      stepModel,
 				Steps:      steps,
 				StepNumber: len(steps),
 				Messages:   stepInputMessages,
 			})
+			if err != nil {
+				return nil, err
+			}
 
 			// Apply prepared step modifications
 			if prepared.Messages != nil {
@@ -423,10 +421,6 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
 			Messages: currentStepMessages,
 		}
 		steps = append(steps, stepResult)
-		if opts.OnStepFinished != nil {
-			opts.OnStepFinished(stepResult)
-		}
-
 		shouldStop := isStopConditionMet(opts.StopWhen, steps)
 
 		if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls {
@@ -709,12 +703,15 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
 
 		// Apply step preparation if provided
 		if call.PrepareStep != nil {
-			prepared := call.PrepareStep(PrepareStepFunctionOptions{
+			prepared, err := call.PrepareStep(PrepareStepFunctionOptions{
 				Model:      stepModel,
 				Steps:      steps,
 				StepNumber: stepNumber,
 				Messages:   stepInputMessages,
 			})
+			if err != nil {
+				return nil, err
+			}
 
 			if prepared.Messages != nil {
 				stepInputMessages = prepared.Messages
@@ -925,78 +922,72 @@ func (a *agent) createPrompt(system, prompt string, messages []Message, files ..
 	return preparedPrompt, nil
 }
 
-func WithSystemPrompt(prompt string) agentOption {
-	return func(s *AgentSettings) {
+func WithSystemPrompt(prompt string) AgentOption {
+	return func(s *agentSettings) {
 		s.systemPrompt = prompt
 	}
 }
 
-func WithMaxOutputTokens(tokens int64) agentOption {
-	return func(s *AgentSettings) {
+func WithMaxOutputTokens(tokens int64) AgentOption {
+	return func(s *agentSettings) {
 		s.maxOutputTokens = &tokens
 	}
 }
 
-func WithTemperature(temp float64) agentOption {
-	return func(s *AgentSettings) {
+func WithTemperature(temp float64) AgentOption {
+	return func(s *agentSettings) {
 		s.temperature = &temp
 	}
 }
 
-func WithTopP(topP float64) agentOption {
-	return func(s *AgentSettings) {
+func WithTopP(topP float64) AgentOption {
+	return func(s *agentSettings) {
 		s.topP = &topP
 	}
 }
 
-func WithTopK(topK int64) agentOption {
-	return func(s *AgentSettings) {
+func WithTopK(topK int64) AgentOption {
+	return func(s *agentSettings) {
 		s.topK = &topK
 	}
 }
 
-func WithPresencePenalty(penalty float64) agentOption {
-	return func(s *AgentSettings) {
+func WithPresencePenalty(penalty float64) AgentOption {
+	return func(s *agentSettings) {
 		s.presencePenalty = &penalty
 	}
 }
 
-func WithFrequencyPenalty(penalty float64) agentOption {
-	return func(s *AgentSettings) {
+func WithFrequencyPenalty(penalty float64) AgentOption {
+	return func(s *agentSettings) {
 		s.frequencyPenalty = &penalty
 	}
 }
 
-func WithTools(tools ...AgentTool) agentOption {
-	return func(s *AgentSettings) {
+func WithTools(tools ...AgentTool) AgentOption {
+	return func(s *agentSettings) {
 		s.tools = append(s.tools, tools...)
 	}
 }
 
-func WithStopConditions(conditions ...StopCondition) agentOption {
-	return func(s *AgentSettings) {
+func WithStopConditions(conditions ...StopCondition) AgentOption {
+	return func(s *agentSettings) {
 		s.stopWhen = append(s.stopWhen, conditions...)
 	}
 }
 
-func WithPrepareStep(fn PrepareStepFunction) agentOption {
-	return func(s *AgentSettings) {
+func WithPrepareStep(fn PrepareStepFunction) AgentOption {
+	return func(s *agentSettings) {
 		s.prepareStep = fn
 	}
 }
 
-func WithRepairToolCall(fn RepairToolCallFunction) agentOption {
-	return func(s *AgentSettings) {
+func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
+	return func(s *agentSettings) {
 		s.repairToolCall = fn
 	}
 }
 
-func WithOnStepFinished(fn OnStepFinishedFunction) agentOption {
-	return func(s *AgentSettings) {
-		s.onStepFinished = fn
-	}
-}
-
 // processStepStream processes a single step's stream and returns the step result
 func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
 	var stepContent []Content
@@ -1069,11 +1060,14 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 					Text:             text,
 					ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
 				})
+				if opts.OnReasoningEnd != nil {
+					opts.OnReasoningEnd(part.ID, ReasoningContent{
+						Text:             text,
+						ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
+					})
+				}
 				delete(activeTextContent, part.ID)
 			}
-			if opts.OnReasoningEnd != nil {
-				opts.OnReasoningEnd(part.ID)
-			}
 
 		case StreamPartTypeToolInputStart:
 			activeToolCalls[part.ID] = &ToolCallContent{
@@ -1194,14 +1188,14 @@ func addUsage(a, b Usage) Usage {
 	}
 }
 
-func WithHeaders(headers map[string]string) agentOption {
-	return func(s *AgentSettings) {
+func WithHeaders(headers map[string]string) AgentOption {
+	return func(s *agentSettings) {
 		s.headers = headers
 	}
 }
 
-func WithProviderOptions(providerOptions ProviderOptions) agentOption {
-	return func(s *AgentSettings) {
+func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
+	return func(s *agentSettings) {
 		s.providerOptions = providerOptions
 	}
 }

internal/ai/agent_stream_test.go 🔗

@@ -145,7 +145,7 @@ func TestStreamingAgentCallbacks(t *testing.T) {
 		OnReasoningDelta: func(id, text string) {
 			callbacks["OnReasoningDelta"] = true
 		},
-		OnReasoningEnd: func(id string) {
+		OnReasoningEnd: func(id string, content ReasoningContent) {
 			callbacks["OnReasoningEnd"] = true
 		},
 		OnToolInputStart: func(id, toolName string) {
@@ -166,7 +166,7 @@ func TestStreamingAgentCallbacks(t *testing.T) {
 		OnSource: func(source SourceContent) {
 			callbacks["OnSource"] = true
 		},
-		OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderOptions) {
+		OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) {
 			callbacks["OnStreamFinish"] = true
 		},
 		OnStreamError: func(err error) {

internal/ai/agent_test.go 🔗

@@ -962,13 +962,13 @@ func TestPrepareStep(t *testing.T) {
 			},
 		}
 
-		prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+		prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
 			newSystem := "Modified system prompt for step " + fmt.Sprintf("%d", options.StepNumber)
 			return PrepareStepResult{
 				Model:    options.Model,
 				Messages: options.Messages,
 				System:   &newSystem,
-			}
+			}, nil
 		}
 
 		agent := NewAgent(model, WithSystemPrompt("Original system prompt"))
@@ -999,13 +999,13 @@ func TestPrepareStep(t *testing.T) {
 			},
 		}
 
-		prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+		prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
 			toolChoice := ToolChoiceNone
 			return PrepareStepResult{
 				Model:      options.Model,
 				Messages:   options.Messages,
 				ToolChoice: &toolChoice,
-			}
+			}, nil
 		}
 
 		agent := NewAgent(model)
@@ -1044,13 +1044,13 @@ func TestPrepareStep(t *testing.T) {
 		tool2 := &mockTool{name: "tool2", description: "Tool 2"}
 		tool3 := &mockTool{name: "tool3", description: "Tool 3"}
 
-		prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+		prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
 			activeTools := []string{"tool2"} // Only tool2 should be active
 			return PrepareStepResult{
 				Model:       options.Model,
 				Messages:    options.Messages,
 				ActiveTools: activeTools,
-			}
+			}, nil
 		}
 
 		agent := NewAgent(model, WithTools(tool1, tool2, tool3))
@@ -1084,12 +1084,12 @@ func TestPrepareStep(t *testing.T) {
 
 		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
 
-		prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+		prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
 			return PrepareStepResult{
 				Model:           options.Model,
 				Messages:        options.Messages,
 				DisableAllTools: true, // Disable all tools for this step
-			}
+			}, nil
 		}
 
 		agent := NewAgent(model, WithTools(tool1))
@@ -1139,7 +1139,7 @@ func TestPrepareStep(t *testing.T) {
 		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
 		tool2 := &mockTool{name: "tool2", description: "Tool 2"}
 
-		prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+		prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
 			newSystem := "Step-specific system"
 			toolChoice := SpecificToolChoice("tool1")
 			activeTools := []string{"tool1"}
@@ -1149,7 +1149,7 @@ func TestPrepareStep(t *testing.T) {
 				System:      &newSystem,
 				ToolChoice:  &toolChoice,
 				ActiveTools: activeTools,
-			}
+			}, nil
 		}
 
 		agent := NewAgent(model, WithSystemPrompt("Original system"), WithTools(tool1, tool2))
@@ -1202,7 +1202,7 @@ func TestPrepareStep(t *testing.T) {
 
 		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
 
-		prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+		prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
 			// All optional fields are nil, should use parent values
 			return PrepareStepResult{
 				Model:       options.Model,
@@ -1210,7 +1210,7 @@ func TestPrepareStep(t *testing.T) {
 				System:      nil, // Use parent
 				ToolChoice:  nil, // Use parent (auto)
 				ActiveTools: nil, // Use parent (all tools)
-			}
+			}, nil
 		}
 
 		agent := NewAgent(model, WithSystemPrompt("Parent system"), WithTools(tool1))
@@ -1251,12 +1251,12 @@ func TestPrepareStep(t *testing.T) {
 		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
 		tool2 := &mockTool{name: "tool2", description: "Tool 2"}
 
-		prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
+		prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
 			return PrepareStepResult{
 				Model:       options.Model,
 				Messages:    options.Messages,
 				ActiveTools: []string{}, // Empty slice means all tools
-			}
+			}, nil
 		}
 
 		agent := NewAgent(model, WithTools(tool1, tool2))

internal/ai/content.go 🔗

@@ -78,7 +78,22 @@ type Message struct {
 	ProviderOptions ProviderOptions `json:"provider_options"`
 }
 
-func AsContentType[T MessagePart](content MessagePart) (T, bool) {
+func AsContentType[T Content](content Content) (T, bool) {
+	var zero T
+	if content == nil {
+		return zero, false
+	}
+	switch v := any(content).(type) {
+	case T:
+		return v, true
+	case *T:
+		return *v, true
+	default:
+		return zero, false
+	}
+}
+
+func AsMessagePart[T MessagePart](content MessagePart) (T, bool) {
 	var zero T
 	if content == nil {
 		return zero, false
@@ -96,6 +111,7 @@ func AsContentType[T MessagePart](content MessagePart) (T, bool) {
 // MessagePart represents a part of a message content.
 type MessagePart interface {
 	GetType() ContentType
+	Options() ProviderOptions
 }
 
 // TextPart represents text content in a message.
@@ -109,6 +125,10 @@ func (t TextPart) GetType() ContentType {
 	return ContentTypeText
 }
 
+func (t TextPart) Options() ProviderOptions {
+	return t.ProviderOptions
+}
+
 // ReasoningPart represents reasoning content in a message.
 type ReasoningPart struct {
 	Text            string          `json:"text"`
@@ -120,6 +140,10 @@ func (r ReasoningPart) GetType() ContentType {
 	return ContentTypeReasoning
 }
 
+func (r ReasoningPart) Options() ProviderOptions {
+	return r.ProviderOptions
+}
+
 // FilePart represents file content in a message.
 type FilePart struct {
 	Filename        string          `json:"filename"`
@@ -133,6 +157,10 @@ func (f FilePart) GetType() ContentType {
 	return ContentTypeFile
 }
 
+func (f FilePart) Options() ProviderOptions {
+	return f.ProviderOptions
+}
+
 // ToolCallPart represents a tool call in a message.
 type ToolCallPart struct {
 	ToolCallID       string          `json:"tool_call_id"`
@@ -147,6 +175,10 @@ func (t ToolCallPart) GetType() ContentType {
 	return ContentTypeToolCall
 }
 
+func (t ToolCallPart) Options() ProviderOptions {
+	return t.ProviderOptions
+}
+
 // ToolResultPart represents a tool result in a message.
 type ToolResultPart struct {
 	ToolCallID      string                  `json:"tool_call_id"`
@@ -159,6 +191,10 @@ func (t ToolResultPart) GetType() ContentType {
 	return ContentTypeToolResult
 }
 
+func (t ToolResultPart) Options() ProviderOptions {
+	return t.ProviderOptions
+}
+
 // ToolResultContentType represents the type of tool result output.
 type ToolResultContentType string
 

internal/ai/examples/streaming-agent/main.go 🔗

@@ -163,7 +163,7 @@ func main() {
 		OnReasoningDelta: func(id, text string) {
 			reasoningBuffer.WriteString(text)
 		},
-		OnReasoningEnd: func(id string) {
+		OnReasoningEnd: func(id string, content ai.ReasoningContent) {
 			if reasoningBuffer.Len() > 0 {
 				fmt.Printf("%s\n", reasoningBuffer.String())
 				reasoningBuffer.Reset()

internal/ai/model.go 🔗

@@ -159,15 +159,16 @@ type StreamPart struct {
 	URL        string     `json:"url"`
 	Title      string     `json:"title"`
 
-	ProviderMetadata ProviderOptions `json:"provider_metadata"`
+	ProviderMetadata ProviderMetadata `json:"provider_metadata"`
 }
 type StreamResponse = iter.Seq[StreamPart]
 
 type ToolChoice string
 
 const (
-	ToolChoiceNone ToolChoice = "none"
-	ToolChoiceAuto ToolChoice = "auto"
+	ToolChoiceNone     ToolChoice = "none"
+	ToolChoiceAuto     ToolChoice = "auto"
+	ToolChoiceRequired ToolChoice = "required"
 )
 
 func SpecificToolChoice(name string) ToolChoice {

internal/ai/tool.go 🔗

@@ -110,6 +110,9 @@ type funcToolWrapper[TInput any] struct {
 }
 
 func (w *funcToolWrapper[TInput]) Info() ToolInfo {
+	if w.schema.Required == nil {
+		w.schema.Required = []string{}
+	}
 	return ToolInfo{
 		Name:        w.name,
 		Description: w.description,