fix: honor MaxRetries option and add retry logic to Stream method (#68)

Cristian created

Change summary

agent.go | 30 ++++++++++++++++++++++++++++--
1 file changed, 28 insertions(+), 2 deletions(-)

Detailed changes

agent.go 🔗

@@ -418,6 +418,9 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
 		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
 
 		retryOptions := DefaultRetryOptions()
+		if opts.MaxRetries != nil {
+			retryOptions.MaxRetries = *opts.MaxRetries
+		}
 		retryOptions.OnRetry = opts.OnRetry
 		retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
 
@@ -826,8 +829,17 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
 			ProviderOptions:  call.ProviderOptions,
 		}
 
-		// Get streaming response
-		stream, err := stepModel.Stream(ctx, streamCall)
+		// Get streaming response with retry logic
+		retryOptions := DefaultRetryOptions()
+		if call.MaxRetries != nil {
+			retryOptions.MaxRetries = *call.MaxRetries
+		}
+		retryOptions.OnRetry = call.OnRetry
+		retry := RetryWithExponentialBackoffRespectingRetryHeaders[StreamResponse](retryOptions)
+
+		stream, err := retry(ctx, func() (StreamResponse, error) {
+			return stepModel.Stream(ctx, streamCall)
+		})
 		if err != nil {
 			if opts.OnError != nil {
 				opts.OnError(err)
@@ -1061,6 +1073,20 @@ func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
 	}
 }
 
+// WithMaxRetries sets the maximum number of retries for the agent.
+func WithMaxRetries(maxRetries int) AgentOption {
+	return func(s *agentSettings) {
+		s.maxRetries = &maxRetries
+	}
+}
+
+// WithOnRetry sets the retry callback for the agent.
+func WithOnRetry(callback OnRetryCallback) AgentOption {
+	return func(s *agentSettings) {
+		s.onRetry = callback
+	}
+}
+
 // processStepStream processes a single step's stream and returns the step result.
 func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
 	var stepContent []Content