From 20e275bab2f87828cfc589e758606d994a699590 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 25 Aug 2025 16:31:56 +0200 Subject: [PATCH] chore: agent improvements --- internal/ai/agent.go | 130 +++++++++---------- internal/ai/agent_stream_test.go | 4 +- internal/ai/agent_test.go | 28 ++-- internal/ai/content.go | 38 +++++- internal/ai/examples/streaming-agent/main.go | 2 +- internal/ai/model.go | 7 +- internal/ai/tool.go | 3 + 7 files changed, 123 insertions(+), 89 deletions(-) diff --git a/internal/ai/agent.go b/internal/ai/agent.go index 5ac8b568d5dd8ee53533dd1305d74bfe767af168..d5a8b5a6444e50f5c56da8ed80c3977f849d7aae 100644 --- a/internal/ai/agent.go +++ b/internal/ai/agent.go @@ -104,12 +104,12 @@ type ToolCallRepairOptions struct { } type ( - PrepareStepFunction = func(options PrepareStepFunctionOptions) PrepareStepResult + PrepareStepFunction = func(options PrepareStepFunctionOptions) (PrepareStepResult, error) OnStepFinishedFunction = func(step StepResult) RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) ) -type AgentSettings struct { +type agentSettings struct { systemPrompt string maxOutputTokens *int64 temperature *float64 @@ -129,7 +129,6 @@ type AgentSettings struct { stopWhen []StopCondition prepareStep PrepareStepFunction repairToolCall RepairToolCallFunction - onStepFinished OnStepFinishedFunction onRetry OnRetryCallback } @@ -152,7 +151,6 @@ type AgentCall struct { StopWhen []StopCondition PrepareStep PrepareStepFunction RepairToolCall RepairToolCallFunction - OnStepFinished OnStepFinishedFunction } type AgentStreamCall struct { @@ -184,22 +182,22 @@ type AgentStreamCall struct { OnError func(error) // Called when an error occurs // Stream part callbacks - called for each corresponding stream part type - OnChunk func(StreamPart) // Called for each stream part (catch-all) - OnWarnings func(warnings []CallWarning) // Called for warnings - OnTextStart func(id string) // Called when text starts - OnTextDelta func(id, text string) // Called for text deltas - OnTextEnd func(id string) // Called when text ends - OnReasoningStart func(id string) // Called when reasoning starts - OnReasoningDelta func(id, text string) // Called for reasoning deltas - OnReasoningEnd func(id string) // Called when reasoning ends - OnToolInputStart func(id, toolName string) // Called when tool input starts - OnToolInputDelta func(id, delta string) // Called for tool input deltas - OnToolInputEnd func(id string) // Called when tool input ends - OnToolCall func(toolCall ToolCallContent) // Called when tool call is complete - OnToolResult func(result ToolResultContent) // Called when tool execution completes - OnSource func(source SourceContent) // Called for source references - OnStreamFinish func(usage Usage, finishReason FinishReason, providerMetadata ProviderOptions) // Called when stream finishes - OnStreamError func(error) // Called when stream error occurs + OnChunk func(StreamPart) // Called for each stream part (catch-all) + OnWarnings func(warnings []CallWarning) // Called for warnings + OnTextStart func(id string) // Called when text starts + OnTextDelta func(id, text string) // Called for text deltas + OnTextEnd func(id string) // Called when text ends + OnReasoningStart func(id string) // Called when reasoning starts + OnReasoningDelta func(id, text string) // Called for reasoning deltas + OnReasoningEnd func(id string, reasoning ReasoningContent) // Called when reasoning ends + OnToolInputStart func(id, toolName string) // Called when tool input starts + OnToolInputDelta func(id, delta string) // Called for tool input deltas + OnToolInputEnd func(id string) // Called when tool input ends + OnToolCall func(toolCall ToolCallContent) // Called when tool call is complete + OnToolResult func(result ToolResultContent) // Called when tool execution completes + OnSource func(source SourceContent) // Called for source references + OnStreamFinish func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) // Called when stream finishes + OnStreamError func(error) // Called when stream error occurs } type AgentResult struct { @@ -214,14 +212,14 @@ type Agent interface { Stream(context.Context, AgentStreamCall) (*AgentResult, error) } -type agentOption = func(*AgentSettings) +type AgentOption = func(*agentSettings) type agent struct { - settings AgentSettings + settings agentSettings } -func NewAgent(model LanguageModel, opts ...agentOption) Agent { - settings := AgentSettings{ +func NewAgent(model LanguageModel, opts ...AgentOption) Agent { + settings := agentSettings{ model: model, } for _, o := range opts { @@ -260,9 +258,6 @@ func (a *agent) prepareCall(call AgentCall) AgentCall { if call.RepairToolCall == nil && a.settings.repairToolCall != nil { call.RepairToolCall = a.settings.repairToolCall } - if call.OnStepFinished == nil && a.settings.onStepFinished != nil { - call.OnStepFinished = a.settings.onStepFinished - } if call.OnRetry == nil && a.settings.onRetry != nil { call.OnRetry = a.settings.onRetry } @@ -311,12 +306,15 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err disableAllTools := false if opts.PrepareStep != nil { - prepared := opts.PrepareStep(PrepareStepFunctionOptions{ + prepared, err := opts.PrepareStep(PrepareStepFunctionOptions{ Model: stepModel, Steps: steps, StepNumber: len(steps), Messages: stepInputMessages, }) + if err != nil { + return nil, err + } // Apply prepared step modifications if prepared.Messages != nil { @@ -423,10 +421,6 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err Messages: currentStepMessages, } steps = append(steps, stepResult) - if opts.OnStepFinished != nil { - opts.OnStepFinished(stepResult) - } - shouldStop := isStopConditionMet(opts.StopWhen, steps) if shouldStop || err != nil || len(stepToolCalls) == 0 || result.FinishReason != FinishReasonToolCalls { @@ -709,12 +703,15 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, // Apply step preparation if provided if call.PrepareStep != nil { - prepared := call.PrepareStep(PrepareStepFunctionOptions{ + prepared, err := call.PrepareStep(PrepareStepFunctionOptions{ Model: stepModel, Steps: steps, StepNumber: stepNumber, Messages: stepInputMessages, }) + if err != nil { + return nil, err + } if prepared.Messages != nil { stepInputMessages = prepared.Messages @@ -925,78 +922,72 @@ func (a *agent) createPrompt(system, prompt string, messages []Message, files .. return preparedPrompt, nil } -func WithSystemPrompt(prompt string) agentOption { - return func(s *AgentSettings) { +func WithSystemPrompt(prompt string) AgentOption { + return func(s *agentSettings) { s.systemPrompt = prompt } } -func WithMaxOutputTokens(tokens int64) agentOption { - return func(s *AgentSettings) { +func WithMaxOutputTokens(tokens int64) AgentOption { + return func(s *agentSettings) { s.maxOutputTokens = &tokens } } -func WithTemperature(temp float64) agentOption { - return func(s *AgentSettings) { +func WithTemperature(temp float64) AgentOption { + return func(s *agentSettings) { s.temperature = &temp } } -func WithTopP(topP float64) agentOption { - return func(s *AgentSettings) { +func WithTopP(topP float64) AgentOption { + return func(s *agentSettings) { s.topP = &topP } } -func WithTopK(topK int64) agentOption { - return func(s *AgentSettings) { +func WithTopK(topK int64) AgentOption { + return func(s *agentSettings) { s.topK = &topK } } -func WithPresencePenalty(penalty float64) agentOption { - return func(s *AgentSettings) { +func WithPresencePenalty(penalty float64) AgentOption { + return func(s *agentSettings) { s.presencePenalty = &penalty } } -func WithFrequencyPenalty(penalty float64) agentOption { - return func(s *AgentSettings) { +func WithFrequencyPenalty(penalty float64) AgentOption { + return func(s *agentSettings) { s.frequencyPenalty = &penalty } } -func WithTools(tools ...AgentTool) agentOption { - return func(s *AgentSettings) { +func WithTools(tools ...AgentTool) AgentOption { + return func(s *agentSettings) { s.tools = append(s.tools, tools...) } } -func WithStopConditions(conditions ...StopCondition) agentOption { - return func(s *AgentSettings) { +func WithStopConditions(conditions ...StopCondition) AgentOption { + return func(s *agentSettings) { s.stopWhen = append(s.stopWhen, conditions...) } } -func WithPrepareStep(fn PrepareStepFunction) agentOption { - return func(s *AgentSettings) { +func WithPrepareStep(fn PrepareStepFunction) AgentOption { + return func(s *agentSettings) { s.prepareStep = fn } } -func WithRepairToolCall(fn RepairToolCallFunction) agentOption { - return func(s *AgentSettings) { +func WithRepairToolCall(fn RepairToolCallFunction) AgentOption { + return func(s *agentSettings) { s.repairToolCall = fn } } -func WithOnStepFinished(fn OnStepFinishedFunction) agentOption { - return func(s *AgentSettings) { - s.onStepFinished = fn - } -} - // processStepStream processes a single step's stream and returns the step result func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) { var stepContent []Content @@ -1069,11 +1060,14 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op Text: text, ProviderMetadata: ProviderMetadata(part.ProviderMetadata), }) + if opts.OnReasoningEnd != nil { + opts.OnReasoningEnd(part.ID, ReasoningContent{ + Text: text, + ProviderMetadata: ProviderMetadata(part.ProviderMetadata), + }) + } delete(activeTextContent, part.ID) } - if opts.OnReasoningEnd != nil { - opts.OnReasoningEnd(part.ID) - } case StreamPartTypeToolInputStart: activeToolCalls[part.ID] = &ToolCallContent{ @@ -1194,14 +1188,14 @@ func addUsage(a, b Usage) Usage { } } -func WithHeaders(headers map[string]string) agentOption { - return func(s *AgentSettings) { +func WithHeaders(headers map[string]string) AgentOption { + return func(s *agentSettings) { s.headers = headers } } -func WithProviderOptions(providerOptions ProviderOptions) agentOption { - return func(s *AgentSettings) { +func WithProviderOptions(providerOptions ProviderOptions) AgentOption { + return func(s *agentSettings) { s.providerOptions = providerOptions } } diff --git a/internal/ai/agent_stream_test.go b/internal/ai/agent_stream_test.go index e11011e8b308f95059be9c7d0c076c2ac2ff34e1..9bd1477a04d97e0fb5d49c6c2deebe1f2952a969 100644 --- a/internal/ai/agent_stream_test.go +++ b/internal/ai/agent_stream_test.go @@ -145,7 +145,7 @@ func TestStreamingAgentCallbacks(t *testing.T) { OnReasoningDelta: func(id, text string) { callbacks["OnReasoningDelta"] = true }, - OnReasoningEnd: func(id string) { + OnReasoningEnd: func(id string, content ReasoningContent) { callbacks["OnReasoningEnd"] = true }, OnToolInputStart: func(id, toolName string) { @@ -166,7 +166,7 @@ func TestStreamingAgentCallbacks(t *testing.T) { OnSource: func(source SourceContent) { callbacks["OnSource"] = true }, - OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderOptions) { + OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) { callbacks["OnStreamFinish"] = true }, OnStreamError: func(err error) { diff --git a/internal/ai/agent_test.go b/internal/ai/agent_test.go index f21781c8639697a4bd5aac4c33809885c00242a4..73301a0892892f45d67a9f1ab9e0c865a34a9858 100644 --- a/internal/ai/agent_test.go +++ b/internal/ai/agent_test.go @@ -962,13 +962,13 @@ func TestPrepareStep(t *testing.T) { }, } - prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult { + prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) { newSystem := "Modified system prompt for step " + fmt.Sprintf("%d", options.StepNumber) return PrepareStepResult{ Model: options.Model, Messages: options.Messages, System: &newSystem, - } + }, nil } agent := NewAgent(model, WithSystemPrompt("Original system prompt")) @@ -999,13 +999,13 @@ func TestPrepareStep(t *testing.T) { }, } - prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult { + prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) { toolChoice := ToolChoiceNone return PrepareStepResult{ Model: options.Model, Messages: options.Messages, ToolChoice: &toolChoice, - } + }, nil } agent := NewAgent(model) @@ -1044,13 +1044,13 @@ func TestPrepareStep(t *testing.T) { tool2 := &mockTool{name: "tool2", description: "Tool 2"} tool3 := &mockTool{name: "tool3", description: "Tool 3"} - prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult { + prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) { activeTools := []string{"tool2"} // Only tool2 should be active return PrepareStepResult{ Model: options.Model, Messages: options.Messages, ActiveTools: activeTools, - } + }, nil } agent := NewAgent(model, WithTools(tool1, tool2, tool3)) @@ -1084,12 +1084,12 @@ func TestPrepareStep(t *testing.T) { tool1 := &mockTool{name: "tool1", description: "Tool 1"} - prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult { + prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) { return PrepareStepResult{ Model: options.Model, Messages: options.Messages, DisableAllTools: true, // Disable all tools for this step - } + }, nil } agent := NewAgent(model, WithTools(tool1)) @@ -1139,7 +1139,7 @@ func TestPrepareStep(t *testing.T) { tool1 := &mockTool{name: "tool1", description: "Tool 1"} tool2 := &mockTool{name: "tool2", description: "Tool 2"} - prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult { + prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) { newSystem := "Step-specific system" toolChoice := SpecificToolChoice("tool1") activeTools := []string{"tool1"} @@ -1149,7 +1149,7 @@ func TestPrepareStep(t *testing.T) { System: &newSystem, ToolChoice: &toolChoice, ActiveTools: activeTools, - } + }, nil } agent := NewAgent(model, WithSystemPrompt("Original system"), WithTools(tool1, tool2)) @@ -1202,7 +1202,7 @@ func TestPrepareStep(t *testing.T) { tool1 := &mockTool{name: "tool1", description: "Tool 1"} - prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult { + prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) { // All optional fields are nil, should use parent values return PrepareStepResult{ Model: options.Model, @@ -1210,7 +1210,7 @@ func TestPrepareStep(t *testing.T) { System: nil, // Use parent ToolChoice: nil, // Use parent (auto) ActiveTools: nil, // Use parent (all tools) - } + }, nil } agent := NewAgent(model, WithSystemPrompt("Parent system"), WithTools(tool1)) @@ -1251,12 +1251,12 @@ func TestPrepareStep(t *testing.T) { tool1 := &mockTool{name: "tool1", description: "Tool 1"} tool2 := &mockTool{name: "tool2", description: "Tool 2"} - prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult { + prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) { return PrepareStepResult{ Model: options.Model, Messages: options.Messages, ActiveTools: []string{}, // Empty slice means all tools - } + }, nil } agent := NewAgent(model, WithTools(tool1, tool2)) diff --git a/internal/ai/content.go b/internal/ai/content.go index eec2ec539dedc5c55e762ceff8e11721320a0749..f2f9aa93dbb6c583efd2b113232966c0035c992b 100644 --- a/internal/ai/content.go +++ b/internal/ai/content.go @@ -78,7 +78,22 @@ type Message struct { ProviderOptions ProviderOptions `json:"provider_options"` } -func AsContentType[T MessagePart](content MessagePart) (T, bool) { +func AsContentType[T Content](content Content) (T, bool) { + var zero T + if content == nil { + return zero, false + } + switch v := any(content).(type) { + case T: + return v, true + case *T: + return *v, true + default: + return zero, false + } +} + +func AsMessagePart[T MessagePart](content MessagePart) (T, bool) { var zero T if content == nil { return zero, false @@ -96,6 +111,7 @@ func AsContentType[T MessagePart](content MessagePart) (T, bool) { // MessagePart represents a part of a message content. type MessagePart interface { GetType() ContentType + Options() ProviderOptions } // TextPart represents text content in a message. @@ -109,6 +125,10 @@ func (t TextPart) GetType() ContentType { return ContentTypeText } +func (t TextPart) Options() ProviderOptions { + return t.ProviderOptions +} + // ReasoningPart represents reasoning content in a message. type ReasoningPart struct { Text string `json:"text"` @@ -120,6 +140,10 @@ func (r ReasoningPart) GetType() ContentType { return ContentTypeReasoning } +func (r ReasoningPart) Options() ProviderOptions { + return r.ProviderOptions +} + // FilePart represents file content in a message. type FilePart struct { Filename string `json:"filename"` @@ -133,6 +157,10 @@ func (f FilePart) GetType() ContentType { return ContentTypeFile } +func (f FilePart) Options() ProviderOptions { + return f.ProviderOptions +} + // ToolCallPart represents a tool call in a message. type ToolCallPart struct { ToolCallID string `json:"tool_call_id"` @@ -147,6 +175,10 @@ func (t ToolCallPart) GetType() ContentType { return ContentTypeToolCall } +func (t ToolCallPart) Options() ProviderOptions { + return t.ProviderOptions +} + // ToolResultPart represents a tool result in a message. type ToolResultPart struct { ToolCallID string `json:"tool_call_id"` @@ -159,6 +191,10 @@ func (t ToolResultPart) GetType() ContentType { return ContentTypeToolResult } +func (t ToolResultPart) Options() ProviderOptions { + return t.ProviderOptions +} + // ToolResultContentType represents the type of tool result output. type ToolResultContentType string diff --git a/internal/ai/examples/streaming-agent/main.go b/internal/ai/examples/streaming-agent/main.go index 1c6295da514fc672a88f9aa27daadc5deb2b8b80..2d087cf2f6ac849f3863486bd3153bb6dfc7158e 100644 --- a/internal/ai/examples/streaming-agent/main.go +++ b/internal/ai/examples/streaming-agent/main.go @@ -163,7 +163,7 @@ func main() { OnReasoningDelta: func(id, text string) { reasoningBuffer.WriteString(text) }, - OnReasoningEnd: func(id string) { + OnReasoningEnd: func(id string, content ai.ReasoningContent) { if reasoningBuffer.Len() > 0 { fmt.Printf("%s\n", reasoningBuffer.String()) reasoningBuffer.Reset() diff --git a/internal/ai/model.go b/internal/ai/model.go index 8e5b5d6126c25cec28daa0a0c8c8a41704c98f18..6e2f2415de5221e0a150a51c3f6e49d3d2e5dfa8 100644 --- a/internal/ai/model.go +++ b/internal/ai/model.go @@ -159,15 +159,16 @@ type StreamPart struct { URL string `json:"url"` Title string `json:"title"` - ProviderMetadata ProviderOptions `json:"provider_metadata"` + ProviderMetadata ProviderMetadata `json:"provider_metadata"` } type StreamResponse = iter.Seq[StreamPart] type ToolChoice string const ( - ToolChoiceNone ToolChoice = "none" - ToolChoiceAuto ToolChoice = "auto" + ToolChoiceNone ToolChoice = "none" + ToolChoiceAuto ToolChoice = "auto" + ToolChoiceRequired ToolChoice = "required" ) func SpecificToolChoice(name string) ToolChoice { diff --git a/internal/ai/tool.go b/internal/ai/tool.go index c12b2775be655eeb93576ef68f509c2d6ce6d761..b6ad7ec1ca8c5ef7e8de69470773783671eebf25 100644 --- a/internal/ai/tool.go +++ b/internal/ai/tool.go @@ -110,6 +110,9 @@ type funcToolWrapper[TInput any] struct { } func (w *funcToolWrapper[TInput]) Info() ToolInfo { + if w.schema.Required == nil { + w.schema.Required = []string{} + } return ToolInfo{ Name: w.name, Description: w.description,