acp: Refactor agent2 `send` to have a clearer control flow (#36689)

Antonio Scandurra created

Release Notes:

- N/A

Change summary

Cargo.lock                  |   1 
crates/agent2/Cargo.toml    |   1 
crates/agent2/src/thread.rs | 295 +++++++++++++++++---------------------
3 files changed, 134 insertions(+), 163 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -244,6 +244,7 @@ dependencies = [
  "terminal",
  "text",
  "theme",
+ "thiserror 2.0.12",
  "tree-sitter-rust",
  "ui",
  "unindent",

crates/agent2/Cargo.toml 🔗

@@ -61,6 +61,7 @@ sqlez.workspace = true
 task.workspace = true
 telemetry.workspace = true
 terminal.workspace = true
+thiserror.workspace = true
 text.workspace = true
 ui.workspace = true
 util.workspace = true

crates/agent2/src/thread.rs 🔗

@@ -499,6 +499,16 @@ pub struct ToolCallAuthorization {
     pub response: oneshot::Sender<acp::PermissionOptionId>,
 }
 
+#[derive(Debug, thiserror::Error)]
+enum CompletionError {
+    #[error("max tokens")]
+    MaxTokens,
+    #[error("refusal")]
+    Refusal,
+    #[error(transparent)]
+    Other(#[from] anyhow::Error),
+}
+
 pub struct Thread {
     id: acp::SessionId,
     prompt_id: PromptId,
@@ -1077,101 +1087,62 @@ impl Thread {
             _task: cx.spawn(async move |this, cx| {
                 log::info!("Starting agent turn execution");
                 let mut update_title = None;
-                let turn_result: Result<StopReason> = async {
-                    let mut completion_intent = CompletionIntent::UserPrompt;
+                let turn_result: Result<()> = async {
+                    let mut intent = CompletionIntent::UserPrompt;
                     loop {
-                        log::debug!(
-                            "Building completion request with intent: {:?}",
-                            completion_intent
-                        );
-                        let request = this.update(cx, |this, cx| {
-                            this.build_completion_request(completion_intent, cx)
-                        })??;
-
-                        log::info!("Calling model.stream_completion");
-
-                        let mut tool_use_limit_reached = false;
-                        let mut refused = false;
-                        let mut reached_max_tokens = false;
-                        let mut tool_uses = Self::stream_completion_with_retries(
-                            this.clone(),
-                            model.clone(),
-                            request,
-                            &event_stream,
-                            &mut tool_use_limit_reached,
-                            &mut refused,
-                            &mut reached_max_tokens,
-                            cx,
-                        )
-                        .await?;
-
-                        if refused {
-                            return Ok(StopReason::Refusal);
-                        } else if reached_max_tokens {
-                            return Ok(StopReason::MaxTokens);
-                        }
-
-                        let end_turn = tool_uses.is_empty();
-                        while let Some(tool_result) = tool_uses.next().await {
-                            log::info!("Tool finished {:?}", tool_result);
-
-                            event_stream.update_tool_call_fields(
-                                &tool_result.tool_use_id,
-                                acp::ToolCallUpdateFields {
-                                    status: Some(if tool_result.is_error {
-                                        acp::ToolCallStatus::Failed
-                                    } else {
-                                        acp::ToolCallStatus::Completed
-                                    }),
-                                    raw_output: tool_result.output.clone(),
-                                    ..Default::default()
-                                },
-                            );
-                            this.update(cx, |this, _cx| {
-                                this.pending_message()
-                                    .tool_results
-                                    .insert(tool_result.tool_use_id.clone(), tool_result);
-                            })?;
-                        }
+                        Self::stream_completion(&this, &model, intent, &event_stream, cx).await?;
 
+                        let mut end_turn = true;
                         this.update(cx, |this, cx| {
+                            // Generate title if needed.
                             if this.title.is_none() && update_title.is_none() {
                                 update_title = Some(this.update_title(&event_stream, cx));
                             }
+
+                            // End the turn if the model didn't use tools.
+                            let message = this.pending_message.as_ref();
+                            end_turn =
+                                message.map_or(true, |message| message.tool_results.is_empty());
+                            this.flush_pending_message(cx);
                         })?;
 
-                        if tool_use_limit_reached {
+                        if this.read_with(cx, |this, _| this.tool_use_limit_reached)? {
                             log::info!("Tool use limit reached, completing turn");
-                            this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
                             return Err(language_model::ToolUseLimitReachedError.into());
                         } else if end_turn {
                             log::info!("No tool uses found, completing turn");
-                            return Ok(StopReason::EndTurn);
+                            return Ok(());
                         } else {
-                            this.update(cx, |this, cx| this.flush_pending_message(cx))?;
-                            completion_intent = CompletionIntent::ToolResults;
+                            intent = CompletionIntent::ToolResults;
                         }
                     }
                 }
                 .await;
                 _ = this.update(cx, |this, cx| this.flush_pending_message(cx));
 
-                match turn_result {
-                    Ok(reason) => {
-                        log::info!("Turn execution completed: {:?}", reason);
-
-                        if let Some(update_title) = update_title {
-                            update_title.await.context("update title failed").log_err();
-                        }
+                if let Some(update_title) = update_title {
+                    update_title.await.context("update title failed").log_err();
+                }
 
-                        event_stream.send_stop(reason);
-                        if reason == StopReason::Refusal {
-                            _ = this.update(cx, |this, _| this.messages.truncate(message_ix));
-                        }
+                match turn_result {
+                    Ok(()) => {
+                        log::info!("Turn execution completed");
+                        event_stream.send_stop(acp::StopReason::EndTurn);
                     }
                     Err(error) => {
                         log::error!("Turn execution failed: {:?}", error);
-                        event_stream.send_error(error);
+                        match error.downcast::<CompletionError>() {
+                            Ok(CompletionError::Refusal) => {
+                                event_stream.send_stop(acp::StopReason::Refusal);
+                                _ = this.update(cx, |this, _| this.messages.truncate(message_ix));
+                            }
+                            Ok(CompletionError::MaxTokens) => {
+                                event_stream.send_stop(acp::StopReason::MaxTokens);
+                            }
+                            Ok(CompletionError::Other(error)) | Err(error) => {
+                                event_stream.send_error(error);
+                            }
+                        }
                     }
                 }
 
@@ -1181,17 +1152,17 @@ impl Thread {
         Ok(events_rx)
     }
 
-    async fn stream_completion_with_retries(
-        this: WeakEntity<Self>,
-        model: Arc<dyn LanguageModel>,
-        request: LanguageModelRequest,
+    async fn stream_completion(
+        this: &WeakEntity<Self>,
+        model: &Arc<dyn LanguageModel>,
+        completion_intent: CompletionIntent,
         event_stream: &ThreadEventStream,
-        tool_use_limit_reached: &mut bool,
-        refusal: &mut bool,
-        max_tokens_reached: &mut bool,
         cx: &mut AsyncApp,
-    ) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
+    ) -> Result<()> {
         log::debug!("Stream completion started successfully");
+        let request = this.update(cx, |this, cx| {
+            this.build_completion_request(completion_intent, cx)
+        })??;
 
         let mut attempt = None;
         'retry: loop {
@@ -1204,68 +1175,33 @@ impl Thread {
                 attempt
             );
 
-            let mut events = model.stream_completion(request.clone(), cx).await?;
-            let mut tool_uses = FuturesUnordered::new();
+            log::info!(
+                "Calling model.stream_completion, attempt {}",
+                attempt.unwrap_or(0)
+            );
+            let mut events = model
+                .stream_completion(request.clone(), cx)
+                .await
+                .map_err(|error| anyhow!(error))?;
+            let mut tool_results = FuturesUnordered::new();
+
             while let Some(event) = events.next().await {
                 match event {
-                    Ok(LanguageModelCompletionEvent::StatusUpdate(
-                        CompletionRequestStatus::ToolUseLimitReached,
-                    )) => {
-                        *tool_use_limit_reached = true;
-                    }
-                    Ok(LanguageModelCompletionEvent::StatusUpdate(
-                        CompletionRequestStatus::UsageUpdated { amount, limit },
-                    )) => {
-                        this.update(cx, |this, cx| {
-                            this.update_model_request_usage(amount, limit, cx)
-                        })?;
-                    }
-                    Ok(LanguageModelCompletionEvent::UsageUpdate(usage)) => {
-                        telemetry::event!(
-                            "Agent Thread Completion Usage Updated",
-                            thread_id = this.read_with(cx, |this, _| this.id.to_string())?,
-                            prompt_id = this.read_with(cx, |this, _| this.prompt_id.to_string())?,
-                            model = model.telemetry_id(),
-                            model_provider = model.provider_id().to_string(),
-                            attempt,
-                            input_tokens = usage.input_tokens,
-                            output_tokens = usage.output_tokens,
-                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
-                            cache_read_input_tokens = usage.cache_read_input_tokens,
-                        );
-
-                        this.update(cx, |this, cx| this.update_token_usage(usage, cx))?;
-                    }
-                    Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
-                        *refusal = true;
-                        return Ok(FuturesUnordered::default());
-                    }
-                    Ok(LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)) => {
-                        *max_tokens_reached = true;
-                        return Ok(FuturesUnordered::default());
-                    }
-                    Ok(LanguageModelCompletionEvent::Stop(
-                        StopReason::ToolUse | StopReason::EndTurn,
-                    )) => break,
                     Ok(event) => {
                         log::trace!("Received completion event: {:?}", event);
-                        this.update(cx, |this, cx| {
-                            tool_uses.extend(this.handle_streamed_completion_event(
-                                event,
-                                event_stream,
-                                cx,
-                            ));
-                        })?;
+                        tool_results.extend(this.update(cx, |this, cx| {
+                            this.handle_streamed_completion_event(event, event_stream, cx)
+                        })??);
                     }
                     Err(error) => {
                         let completion_mode =
                             this.read_with(cx, |thread, _cx| thread.completion_mode())?;
                         if completion_mode == CompletionMode::Normal {
-                            return Err(error.into());
+                            return Err(anyhow!(error))?;
                         }
 
                         let Some(strategy) = Self::retry_strategy_for(&error) else {
-                            return Err(error.into());
+                            return Err(anyhow!(error))?;
                         };
 
                         let max_attempts = match &strategy {
@@ -1279,7 +1215,7 @@ impl Thread {
 
                         let attempt = *attempt;
                         if attempt > max_attempts {
-                            return Err(error.into());
+                            return Err(anyhow!(error))?;
                         }
 
                         let delay = match &strategy {
@@ -1306,7 +1242,29 @@ impl Thread {
                 }
             }
 
-            return Ok(tool_uses);
+            while let Some(tool_result) = tool_results.next().await {
+                log::info!("Tool finished {:?}", tool_result);
+
+                event_stream.update_tool_call_fields(
+                    &tool_result.tool_use_id,
+                    acp::ToolCallUpdateFields {
+                        status: Some(if tool_result.is_error {
+                            acp::ToolCallStatus::Failed
+                        } else {
+                            acp::ToolCallStatus::Completed
+                        }),
+                        raw_output: tool_result.output.clone(),
+                        ..Default::default()
+                    },
+                );
+                this.update(cx, |this, _cx| {
+                    this.pending_message()
+                        .tool_results
+                        .insert(tool_result.tool_use_id.clone(), tool_result);
+                })?;
+            }
+
+            return Ok(());
         }
     }
 
@@ -1328,14 +1286,14 @@ impl Thread {
     }
 
     /// A helper method that's called on every streamed completion event.
-    /// Returns an optional tool result task, which the main agentic loop in
-    /// send will send back to the model when it resolves.
+    /// Returns an optional tool result task, which the main agentic loop will
+    /// send back to the model when it resolves.
     fn handle_streamed_completion_event(
         &mut self,
         event: LanguageModelCompletionEvent,
         event_stream: &ThreadEventStream,
         cx: &mut Context<Self>,
-    ) -> Option<Task<LanguageModelToolResult>> {
+    ) -> Result<Option<Task<LanguageModelToolResult>>> {
         log::trace!("Handling streamed completion event: {:?}", event);
         use LanguageModelCompletionEvent::*;
 
@@ -1350,7 +1308,7 @@ impl Thread {
             }
             RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
             ToolUse(tool_use) => {
-                return self.handle_tool_use_event(tool_use, event_stream, cx);
+                return Ok(self.handle_tool_use_event(tool_use, event_stream, cx));
             }
             ToolUseJsonParseError {
                 id,
@@ -1358,18 +1316,46 @@ impl Thread {
                 raw_input,
                 json_parse_error,
             } => {
-                return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
-                    id,
-                    tool_name,
-                    raw_input,
-                    json_parse_error,
+                return Ok(Some(Task::ready(
+                    self.handle_tool_use_json_parse_error_event(
+                        id,
+                        tool_name,
+                        raw_input,
+                        json_parse_error,
+                    ),
                 )));
             }
-            StatusUpdate(_) => {}
-            UsageUpdate(_) | Stop(_) => unreachable!(),
+            UsageUpdate(usage) => {
+                telemetry::event!(
+                    "Agent Thread Completion Usage Updated",
+                    thread_id = self.id.to_string(),
+                    prompt_id = self.prompt_id.to_string(),
+                    model = self.model.as_ref().map(|m| m.telemetry_id()),
+                    model_provider = self.model.as_ref().map(|m| m.provider_id().to_string()),
+                    input_tokens = usage.input_tokens,
+                    output_tokens = usage.output_tokens,
+                    cache_creation_input_tokens = usage.cache_creation_input_tokens,
+                    cache_read_input_tokens = usage.cache_read_input_tokens,
+                );
+                self.update_token_usage(usage, cx);
+            }
+            StatusUpdate(CompletionRequestStatus::UsageUpdated { amount, limit }) => {
+                self.update_model_request_usage(amount, limit, cx);
+            }
+            StatusUpdate(
+                CompletionRequestStatus::Started
+                | CompletionRequestStatus::Queued { .. }
+                | CompletionRequestStatus::Failed { .. },
+            ) => {}
+            StatusUpdate(CompletionRequestStatus::ToolUseLimitReached) => {
+                self.tool_use_limit_reached = true;
+            }
+            Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()),
+            Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()),
+            Stop(StopReason::ToolUse | StopReason::EndTurn) => {}
         }
 
-        None
+        Ok(None)
     }
 
     fn handle_text_event(
@@ -2225,25 +2211,8 @@ impl ThreadEventStream {
         self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
     }
 
-    fn send_stop(&self, reason: StopReason) {
-        match reason {
-            StopReason::EndTurn => {
-                self.0
-                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
-                    .ok();
-            }
-            StopReason::MaxTokens => {
-                self.0
-                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
-                    .ok();
-            }
-            StopReason::Refusal => {
-                self.0
-                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
-                    .ok();
-            }
-            StopReason::ToolUse => {}
-        }
+    fn send_stop(&self, reason: acp::StopReason) {
+        self.0.unbounded_send(Ok(ThreadEvent::Stop(reason))).ok();
     }
 
     fn send_canceled(&self) {