acp: Never build a request with a tool use without its corresponding result (#36847)

Created by Antonio Scandurra

Release Notes:

- N/A

Change summary

crates/agent2/src/tests/mod.rs | 76 ++++++++++++++++++++++++++++++++++++
crates/agent2/src/thread.rs    | 74 +++++++++++++++++-----------------
2 files changed, 113 insertions(+), 37 deletions(-)
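
At a high level, when an `AgentMessage` is converted into request messages (`to_request`), a tool use is now emitted only if the message already holds the corresponding tool result, so a built request can no longer contain a dangling tool call (for example, one still waiting on user permission). The sketch below illustrates that check in isolation; the types and tool names are hypothetical simplifications, not the real `agent2`/`language_model` structs.

```rust
use std::collections::HashMap;

// Hypothetical, simplified stand-ins for the real agent2 message types.
struct ToolUse {
    id: String,
    name: String,
}

struct AgentMessage {
    tool_uses: Vec<ToolUse>,
    // Keyed by tool use id, mirroring `self.tool_results` in the real struct.
    tool_results: HashMap<String, String>,
}

impl AgentMessage {
    /// Only tool uses that already have a result are included in a request,
    /// matching the `contains_key` check added in this change.
    fn tool_uses_for_request(&self) -> Vec<&ToolUse> {
        self.tool_uses
            .iter()
            .filter(|tool_use| self.tool_results.contains_key(&tool_use.id))
            .collect()
    }
}

fn main() {
    let mut tool_results = HashMap::new();
    tool_results.insert("tool_id_2".to_string(), "test".to_string());

    let message = AgentMessage {
        tool_uses: vec![
            ToolUse { id: "tool_id_1".into(), name: "permission_tool".into() }, // still pending
            ToolUse { id: "tool_id_2".into(), name: "echo".into() },            // has a result
        ],
        tool_results,
    };

    let included = message.tool_uses_for_request();
    assert_eq!(included.len(), 1);
    assert_eq!(included[0].name, "echo");
}
```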

Detailed changes

crates/agent2/src/tests/mod.rs

@@ -4,6 +4,7 @@ use agent_client_protocol::{self as acp};
 use agent_settings::AgentProfileId;
 use anyhow::Result;
 use client::{Client, UserStore};
+use cloud_llm_client::CompletionIntent;
 use context_server::{ContextServer, ContextServerCommand, ContextServerId};
 use fs::{FakeFs, Fs};
 use futures::{
@@ -1737,6 +1738,81 @@ async fn test_title_generation(cx: &mut TestAppContext) {
     thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
 }
 
+#[gpui::test]
+async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    let _events = thread
+        .update(cx, |thread, cx| {
+            thread.add_tool(ToolRequiringPermission);
+            thread.add_tool(EchoTool);
+            thread.send(UserMessageId::new(), ["Hey!"], cx)
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    let permission_tool_use = LanguageModelToolUse {
+        id: "tool_id_1".into(),
+        name: ToolRequiringPermission::name().into(),
+        raw_input: "{}".into(),
+        input: json!({}),
+        is_input_complete: true,
+    };
+    let echo_tool_use = LanguageModelToolUse {
+        id: "tool_id_2".into(),
+        name: EchoTool::name().into(),
+        raw_input: json!({"text": "test"}).to_string(),
+        input: json!({"text": "test"}),
+        is_input_complete: true,
+    };
+    fake_model.send_last_completion_stream_text_chunk("Hi!");
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+        permission_tool_use,
+    ));
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+        echo_tool_use.clone(),
+    ));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    // Ensure pending tools are skipped when building a request.
+    let request = thread
+        .read_with(cx, |thread, cx| {
+            thread.build_completion_request(CompletionIntent::EditFile, cx)
+        })
+        .unwrap();
+    assert_eq!(
+        request.messages[1..],
+        vec![
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["Hey!".into()],
+                cache: true
+            },
+            LanguageModelRequestMessage {
+                role: Role::Assistant,
+                content: vec![
+                    MessageContent::Text("Hi!".into()),
+                    MessageContent::ToolUse(echo_tool_use.clone())
+                ],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
+                    tool_use_id: echo_tool_use.id.clone(),
+                    tool_name: echo_tool_use.name,
+                    is_error: false,
+                    content: "test".into(),
+                    output: Some("test".into())
+                })],
+                cache: false
+            },
+        ],
+    );
+}
+
 #[gpui::test]
 async fn test_agent_connection(cx: &mut TestAppContext) {
     cx.update(settings::init);

crates/agent2/src/thread.rs

@@ -448,24 +448,33 @@ impl AgentMessage {
             cache: false,
         };
         for chunk in &self.content {
-            let chunk = match chunk {
+            match chunk {
                 AgentMessageContent::Text(text) => {
-                    language_model::MessageContent::Text(text.clone())
+                    assistant_message
+                        .content
+                        .push(language_model::MessageContent::Text(text.clone()));
                 }
                 AgentMessageContent::Thinking { text, signature } => {
-                    language_model::MessageContent::Thinking {
-                        text: text.clone(),
-                        signature: signature.clone(),
-                    }
+                    assistant_message
+                        .content
+                        .push(language_model::MessageContent::Thinking {
+                            text: text.clone(),
+                            signature: signature.clone(),
+                        });
                 }
                 AgentMessageContent::RedactedThinking(value) => {
-                    language_model::MessageContent::RedactedThinking(value.clone())
+                    assistant_message.content.push(
+                        language_model::MessageContent::RedactedThinking(value.clone()),
+                    );
                 }
-                AgentMessageContent::ToolUse(value) => {
-                    language_model::MessageContent::ToolUse(value.clone())
+                AgentMessageContent::ToolUse(tool_use) => {
+                    if self.tool_results.contains_key(&tool_use.id) {
+                        assistant_message
+                            .content
+                            .push(language_model::MessageContent::ToolUse(tool_use.clone()));
+                    }
                 }
             };
-            assistant_message.content.push(chunk);
         }
 
         let mut user_message = LanguageModelRequestMessage {
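
The net effect of the hunk above is the property named in the PR title: a built request never contains an assistant tool use whose result is missing from the message that follows it. A small property check expressing that invariant, again over hypothetical, simplified message types rather than the real ones:

```rust
// Hypothetical simplified view of a request message: the tool-use ids an
// assistant message emits and the tool-result ids the next (user) message carries.
struct SimpleMessage {
    tool_use_ids: Vec<String>,
    tool_result_ids: Vec<String>,
}

/// True when every tool use is answered by a result in the following message
/// and the final message carries no unanswered tool uses.
fn every_tool_use_is_answered(messages: &[SimpleMessage]) -> bool {
    let pairs_ok = messages.windows(2).all(|pair| {
        pair[0]
            .tool_use_ids
            .iter()
            .all(|id| pair[1].tool_result_ids.contains(id))
    });
    let trailing_ok = messages
        .last()
        .map_or(true, |last| last.tool_use_ids.is_empty());
    pairs_ok && trailing_ok
}

fn main() {
    // Mirrors the test: the pending permission tool use is dropped, so only the
    // echo tool use (which already has a result) appears in the request.
    let request = vec![
        SimpleMessage { tool_use_ids: vec![], tool_result_ids: vec![] },                   // "Hey!"
        SimpleMessage { tool_use_ids: vec!["tool_id_2".into()], tool_result_ids: vec![] }, // "Hi!" + echo tool use
        SimpleMessage { tool_use_ids: vec![], tool_result_ids: vec!["tool_id_2".into()] }, // echo tool result
    ];
    assert!(every_tool_use_is_answered(&request));
}
```
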
@@ -1315,23 +1324,6 @@ impl Thread {
         }
     }
 
-    pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage {
-        log::debug!("Building system message");
-        let prompt = SystemPromptTemplate {
-            project: self.project_context.read(cx),
-            available_tools: self.tools.keys().cloned().collect(),
-        }
-        .render(&self.templates)
-        .context("failed to build system prompt")
-        .expect("Invalid template");
-        log::debug!("System message built");
-        LanguageModelRequestMessage {
-            role: Role::System,
-            content: vec![prompt.into()],
-            cache: true,
-        }
-    }
-
     /// A helper method that's called on every streamed completion event.
     /// Returns an optional tool result task, which the main agentic loop will
     /// send back to the model when it resolves.
@@ -1773,7 +1765,7 @@ impl Thread {
     pub(crate) fn build_completion_request(
         &self,
         completion_intent: CompletionIntent,
-        cx: &mut App,
+        cx: &App,
     ) -> Result<LanguageModelRequest> {
         let model = self.model().context("No language model configured")?;
         let tools = if let Some(turn) = self.running_turn.as_ref() {
@@ -1894,21 +1886,29 @@ impl Thread {
             "Building request messages from {} thread messages",
             self.messages.len()
         );
-        let mut messages = vec![self.build_system_message(cx)];
+
+        let system_prompt = SystemPromptTemplate {
+            project: self.project_context.read(cx),
+            available_tools: self.tools.keys().cloned().collect(),
+        }
+        .render(&self.templates)
+        .context("failed to build system prompt")
+        .expect("Invalid template");
+        let mut messages = vec![LanguageModelRequestMessage {
+            role: Role::System,
+            content: vec![system_prompt.into()],
+            cache: false,
+        }];
         for message in &self.messages {
             messages.extend(message.to_request());
         }
 
-        if let Some(message) = self.pending_message.as_ref() {
-            messages.extend(message.to_request());
+        if let Some(last_message) = messages.last_mut() {
+            last_message.cache = true;
         }
 
-        if let Some(last_user_message) = messages
-            .iter_mut()
-            .rev()
-            .find(|message| message.role == Role::User)
-        {
-            last_user_message.cache = true;
+        if let Some(message) = self.pending_message.as_ref() {
+            messages.extend(message.to_request());
         }
 
         messages
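
To summarize the final hunk: the system prompt is now rendered inline where the request is built, the cache marker is set on the last fully persisted message, and the pending (still streaming) message is appended only afterwards, so it can never be the cached message. A simplified sketch of that ordering, using hypothetical stand-in types and parameters rather than the real `LanguageModelRequestMessage`:

```rust
// Hypothetical, simplified stand-in for LanguageModelRequestMessage.
#[derive(Clone, Debug)]
struct RequestMessage {
    role: &'static str,
    content: Vec<String>,
    cache: bool,
}

// Assembles request messages in the order introduced by this change: system
// prompt (uncached), persisted messages, cache marker on the last persisted
// message, and only then the in-progress pending message.
fn build_request_messages(
    system_prompt: String,
    persisted: Vec<RequestMessage>,
    pending: Option<RequestMessage>,
) -> Vec<RequestMessage> {
    let mut messages = vec![RequestMessage {
        role: "system",
        content: vec![system_prompt],
        cache: false,
    }];
    messages.extend(persisted);

    // Cache up to the last fully persisted message...
    if let Some(last_message) = messages.last_mut() {
        last_message.cache = true;
    }

    // ...and append the pending message afterwards, so an in-flight turn
    // (which may still contain unanswered tool calls) never carries the cache flag.
    messages.extend(pending);
    messages
}

fn main() {
    let persisted = vec![RequestMessage {
        role: "user",
        content: vec!["Hey!".into()],
        cache: false,
    }];
    let pending = Some(RequestMessage {
        role: "assistant",
        content: vec!["Hi!".into()],
        cache: false,
    });
    let messages = build_request_messages("You are an agent.".into(), persisted, pending);
    assert_eq!(messages[0].role, "system");
    assert!(messages[1].cache);
    assert!(!messages[2].cache);
}
```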