acp: Simplify control flow for native agent loop (#36868)

Antonio Scandurra and Bennet Bo Fenner created

Release Notes:

- N/A

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>

Change summary

crates/agent2/src/thread.rs | 162 +++++++++++++++++---------------------
1 file changed, 72 insertions(+), 90 deletions(-)

Detailed changes

crates/agent2/src/thread.rs 🔗

@@ -1142,37 +1142,7 @@ impl Thread {
             _task: cx.spawn(async move |this, cx| {
                 log::debug!("Starting agent turn execution");
 
-                let turn_result: Result<()> = async {
-                    let mut intent = CompletionIntent::UserPrompt;
-                    loop {
-                        Self::stream_completion(&this, &model, intent, &event_stream, cx).await?;
-
-                        let mut end_turn = true;
-                        this.update(cx, |this, cx| {
-                            // Generate title if needed.
-                            if this.title.is_none() && this.pending_title_generation.is_none() {
-                                this.generate_title(cx);
-                            }
-
-                            // End the turn if the model didn't use tools.
-                            let message = this.pending_message.as_ref();
-                            end_turn =
-                                message.map_or(true, |message| message.tool_results.is_empty());
-                            this.flush_pending_message(cx);
-                        })?;
-
-                        if this.read_with(cx, |this, _| this.tool_use_limit_reached)? {
-                            log::info!("Tool use limit reached, completing turn");
-                            return Err(language_model::ToolUseLimitReachedError.into());
-                        } else if end_turn {
-                            log::debug!("No tool uses found, completing turn");
-                            return Ok(());
-                        } else {
-                            intent = CompletionIntent::ToolResults;
-                        }
-                    }
-                }
-                .await;
+                let turn_result = Self::run_turn_internal(&this, model, &event_stream, cx).await;
                 _ = this.update(cx, |this, cx| this.flush_pending_message(cx));
 
                 match turn_result {
@@ -1203,20 +1173,17 @@ impl Thread {
         Ok(events_rx)
     }
 
-    async fn stream_completion(
+    async fn run_turn_internal(
         this: &WeakEntity<Self>,
-        model: &Arc<dyn LanguageModel>,
-        completion_intent: CompletionIntent,
+        model: Arc<dyn LanguageModel>,
         event_stream: &ThreadEventStream,
         cx: &mut AsyncApp,
     ) -> Result<()> {
-        log::debug!("Stream completion started successfully");
-
-        let mut attempt = None;
+        let mut attempt = 0;
+        let mut intent = CompletionIntent::UserPrompt;
         loop {
-            let request = this.update(cx, |this, cx| {
-                this.build_completion_request(completion_intent, cx)
-            })??;
+            let request =
+                this.update(cx, |this, cx| this.build_completion_request(intent, cx))??;
 
             telemetry::event!(
                 "Agent Thread Completion",
@@ -1227,23 +1194,19 @@ impl Thread {
                 attempt
             );
 
-            log::debug!(
-                "Calling model.stream_completion, attempt {}",
-                attempt.unwrap_or(0)
-            );
+            log::debug!("Calling model.stream_completion, attempt {}", attempt);
             let mut events = model
                 .stream_completion(request, cx)
                 .await
                 .map_err(|error| anyhow!(error))?;
             let mut tool_results = FuturesUnordered::new();
             let mut error = None;
-
             while let Some(event) = events.next().await {
+                log::trace!("Received completion event: {:?}", event);
                 match event {
                     Ok(event) => {
-                        log::trace!("Received completion event: {:?}", event);
                         tool_results.extend(this.update(cx, |this, cx| {
-                            this.handle_streamed_completion_event(event, event_stream, cx)
+                            this.handle_completion_event(event, event_stream, cx)
                         })??);
                     }
                     Err(err) => {
@@ -1253,6 +1216,7 @@ impl Thread {
                 }
             }
 
+            let end_turn = tool_results.is_empty();
             while let Some(tool_result) = tool_results.next().await {
                 log::debug!("Tool finished {:?}", tool_result);
 
@@ -1275,65 +1239,83 @@ impl Thread {
                 })?;
             }
 
-            if let Some(error) = error {
-                let completion_mode = this.read_with(cx, |thread, _cx| thread.completion_mode())?;
-                if completion_mode == CompletionMode::Normal {
-                    return Err(anyhow!(error))?;
-                }
-
-                let Some(strategy) = Self::retry_strategy_for(&error) else {
-                    return Err(anyhow!(error))?;
-                };
-
-                let max_attempts = match &strategy {
-                    RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
-                    RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
-                };
-
-                let attempt = attempt.get_or_insert(0u8);
-
-                *attempt += 1;
-
-                let attempt = *attempt;
-                if attempt > max_attempts {
-                    return Err(anyhow!(error))?;
+            this.update(cx, |this, cx| {
+                this.flush_pending_message(cx);
+                if this.title.is_none() && this.pending_title_generation.is_none() {
+                    this.generate_title(cx);
                 }
+            })?;
 
-                let delay = match &strategy {
-                    RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
-                        let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
-                        Duration::from_secs(delay_secs)
-                    }
-                    RetryStrategy::Fixed { delay, .. } => *delay,
-                };
-                log::debug!("Retry attempt {attempt} with delay {delay:?}");
-
-                event_stream.send_retry(acp_thread::RetryStatus {
-                    last_error: error.to_string().into(),
-                    attempt: attempt as usize,
-                    max_attempts: max_attempts as usize,
-                    started_at: Instant::now(),
-                    duration: delay,
-                });
-                cx.background_executor().timer(delay).await;
-                this.update(cx, |this, cx| {
-                    this.flush_pending_message(cx);
+            if let Some(error) = error {
+                attempt += 1;
+                let retry =
+                    this.update(cx, |this, _| this.handle_completion_error(error, attempt))??;
+                let timer = cx.background_executor().timer(retry.duration);
+                event_stream.send_retry(retry);
+                timer.await;
+                this.update(cx, |this, _cx| {
                     if let Some(Message::Agent(message)) = this.messages.last() {
                         if message.tool_results.is_empty() {
+                            intent = CompletionIntent::UserPrompt;
                             this.messages.push(Message::Resume);
                         }
                     }
                 })?;
-            } else {
+            } else if this.read_with(cx, |this, _| this.tool_use_limit_reached)? {
+                return Err(language_model::ToolUseLimitReachedError.into());
+            } else if end_turn {
                 return Ok(());
+            } else {
+                intent = CompletionIntent::ToolResults;
+                attempt = 0;
             }
         }
     }
 
+    fn handle_completion_error(
+        &mut self,
+        error: LanguageModelCompletionError,
+        attempt: u8,
+    ) -> Result<acp_thread::RetryStatus> {
+        if self.completion_mode == CompletionMode::Normal {
+            return Err(anyhow!(error));
+        }
+
+        let Some(strategy) = Self::retry_strategy_for(&error) else {
+            return Err(anyhow!(error));
+        };
+
+        let max_attempts = match &strategy {
+            RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
+            RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
+        };
+
+        if attempt > max_attempts {
+            return Err(anyhow!(error));
+        }
+
+        let delay = match &strategy {
+            RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
+                let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
+                Duration::from_secs(delay_secs)
+            }
+            RetryStrategy::Fixed { delay, .. } => *delay,
+        };
+        log::debug!("Retry attempt {attempt} with delay {delay:?}");
+
+        Ok(acp_thread::RetryStatus {
+            last_error: error.to_string().into(),
+            attempt: attempt as usize,
+            max_attempts: max_attempts as usize,
+            started_at: Instant::now(),
+            duration: delay,
+        })
+    }
+
     /// A helper method that's called on every streamed completion event.
     /// Returns an optional tool result task, which the main agentic loop will
     /// send back to the model when it resolves.
-    fn handle_streamed_completion_event(
+    fn handle_completion_event(
         &mut self,
         event: LanguageModelCompletionEvent,
         event_stream: &ThreadEventStream,