diff --git a/agent.go b/agent.go index 07b73b6c72a1a21e40bcb4da4b5cbf3521600c47..54a83a56394f1e92cd2803dcf1cad87567829aaa 100644 --- a/agent.go +++ b/agent.go @@ -174,12 +174,12 @@ type AgentStreamCall struct { RepairToolCall RepairToolCallFunction // Agent-level callbacks - OnAgentStart func() // Called when agent starts - OnAgentFinish func(result *AgentResult) // Called when agent finishes - OnStepStart func(stepNumber int) // Called when a step starts - OnStepFinish func(stepResult StepResult) // Called when a step finishes - OnFinish func(result *AgentResult) // Called when entire agent completes - OnError func(error) // Called when an error occurs + OnAgentStart func() // Called when agent starts + OnAgentFinish func(result *AgentResult) error // Called when agent finishes + OnStepStart func(stepNumber int) error // Called when a step starts + OnStepFinish func(stepResult StepResult) error // Called when a step finishes + OnFinish func(result *AgentResult) // Called when entire agent completes + OnError func(error) // Called when an error occurs // Stream part callbacks - called for each corresponding stream part type OnChunk func(StreamPart) error // Called for each stream part (catch-all) diff --git a/agent_stream_test.go b/agent_stream_test.go index d3c0846dcd2a8d18d948f1cf65770f66941da9d6..c32de4557fdb8c496f74957ad74ea5274bcdf6c3 100644 --- a/agent_stream_test.go +++ b/agent_stream_test.go @@ -109,14 +109,17 @@ func TestStreamingAgentCallbacks(t *testing.T) { OnAgentStart: func() { callbacks["OnAgentStart"] = true }, - OnAgentFinish: func(result *AgentResult) { + OnAgentFinish: func(result *AgentResult) error { callbacks["OnAgentFinish"] = true + return nil }, - OnStepStart: func(stepNumber int) { + OnStepStart: func(stepNumber int) error { callbacks["OnStepStart"] = true + return nil }, - OnStepFinish: func(stepResult StepResult) { + OnStepFinish: func(stepResult StepResult) error { callbacks["OnStepFinish"] = true + return nil }, OnFinish: func(result *AgentResult) { callbacks["OnFinish"] = true diff --git a/examples/streaming-agent-simple/main.go b/examples/streaming-agent-simple/main.go index bd0e1bc971868baa69e17535bea7f5007ccd5bf9..1333c76de3862e4c58b149d0b9227fc877b9a754 100644 --- a/examples/streaming-agent-simple/main.go +++ b/examples/streaming-agent-simple/main.go @@ -76,8 +76,9 @@ func main() { }, // Show when each step completes - OnStepFinish: func(step ai.StepResult) { + OnStepFinish: func(step ai.StepResult) error { fmt.Printf("\n[Step completed: %s]\n", step.FinishReason) + return nil }, } diff --git a/examples/streaming-agent/main.go b/examples/streaming-agent/main.go index 0aa2846c78985749122e2b9aa4167808d8792220..ce5445c5cca14cfdad50b578c99deeb13e693697 100644 --- a/examples/streaming-agent/main.go +++ b/examples/streaming-agent/main.go @@ -126,15 +126,18 @@ func main() { OnAgentStart: func() { fmt.Println("🎬 Agent started") }, - OnAgentFinish: func(result *ai.AgentResult) { + OnAgentFinish: func(result *ai.AgentResult) error { fmt.Printf("🏁 Agent finished with %d steps, total tokens: %d\n", len(result.Steps), result.TotalUsage.TotalTokens) + return nil }, - OnStepStart: func(stepNumber int) { + OnStepStart: func(stepNumber int) error { stepCount++ fmt.Printf("📝 Step %d started\n", stepNumber+1) + return nil }, - OnStepFinish: func(stepResult ai.StepResult) { + OnStepFinish: func(stepResult ai.StepResult) error { fmt.Printf("✅ Step completed (reason: %s, tokens: %d)\n", stepResult.FinishReason, stepResult.Usage.TotalTokens) + return nil }, OnFinish: func(result *ai.AgentResult) { fmt.Printf("🎯 Final result ready with %d steps\n", len(result.Steps))