agent2: Port Zed AI features (#36172)

Bennet Bo Fenner and Antonio Scandurra created

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>

Change summary

crates/acp_thread/src/acp_thread.rs            | 173 ++++++--
crates/acp_thread/src/connection.rs            |  26 +
crates/agent2/src/agent.rs                     | 282 ++++++++------
crates/agent2/src/tests/mod.rs                 | 202 ++++++++++
crates/agent2/src/tests/test_tools.rs          |   2 
crates/agent2/src/thread.rs                    | 186 ++++++---
crates/agent2/src/tools/edit_file_tool.rs      |   2 
crates/agent_servers/src/acp/v0.rs             |   6 
crates/agent_servers/src/acp/v1.rs             |   6 
crates/agent_servers/src/claude.rs             |   5 
crates/agent_ui/src/acp/thread_view.rs         | 376 +++++++++++++++++--
crates/agent_ui/src/agent_ui.rs                |   1 
crates/agent_ui/src/burn_mode_tooltip.rs       |  61 ---
crates/agent_ui/src/message_editor.rs          |   4 
crates/agent_ui/src/text_thread_editor.rs      |   2 
crates/agent_ui/src/ui/burn_mode_tooltip.rs    |   6 
crates/language_model/src/model/cloud_model.rs |  12 
17 files changed, 994 insertions(+), 358 deletions(-)

Detailed changes

crates/acp_thread/src/acp_thread.rs 🔗

@@ -33,13 +33,23 @@ pub struct UserMessage {
     pub id: Option<UserMessageId>,
     pub content: ContentBlock,
     pub chunks: Vec<acp::ContentBlock>,
-    pub checkpoint: Option<GitStoreCheckpoint>,
+    pub checkpoint: Option<Checkpoint>,
+}
+
+#[derive(Debug)]
+pub struct Checkpoint {
+    git_checkpoint: GitStoreCheckpoint,
+    pub show: bool,
 }
 
 impl UserMessage {
     fn to_markdown(&self, cx: &App) -> String {
         let mut markdown = String::new();
-        if let Some(_) = self.checkpoint {
+        if self
+            .checkpoint
+            .as_ref()
+            .map_or(false, |checkpoint| checkpoint.show)
+        {
             writeln!(markdown, "## User (checkpoint)").unwrap();
         } else {
             writeln!(markdown, "## User").unwrap();
@@ -1145,9 +1155,12 @@ impl AcpThread {
             self.project.read(cx).languages().clone(),
             cx,
         );
+        let request = acp::PromptRequest {
+            prompt: message.clone(),
+            session_id: self.session_id.clone(),
+        };
         let git_store = self.project.read(cx).git_store().clone();
 
-        let old_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
         let message_id = if self
             .connection
             .session_editor(&self.session_id, cx)
@@ -1161,68 +1174,63 @@ impl AcpThread {
             AgentThreadEntry::UserMessage(UserMessage {
                 id: message_id.clone(),
                 content: block,
-                chunks: message.clone(),
+                chunks: message,
                 checkpoint: None,
             }),
             cx,
         );
+
+        self.run_turn(cx, async move |this, cx| {
+            let old_checkpoint = git_store
+                .update(cx, |git, cx| git.checkpoint(cx))?
+                .await
+                .context("failed to get old checkpoint")
+                .log_err();
+            this.update(cx, |this, cx| {
+                if let Some((_ix, message)) = this.last_user_message() {
+                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
+                        git_checkpoint,
+                        show: false,
+                    });
+                }
+                this.connection.prompt(message_id, request, cx)
+            })?
+            .await
+        })
+    }
+
+    pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
+        self.run_turn(cx, async move |this, cx| {
+            this.update(cx, |this, cx| {
+                this.connection
+                    .resume(&this.session_id, cx)
+                    .map(|resume| resume.run(cx))
+            })?
+            .context("resuming a session is not supported")?
+            .await
+        })
+    }
+
+    fn run_turn(
+        &mut self,
+        cx: &mut Context<Self>,
+        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
+    ) -> BoxFuture<'static, Result<()>> {
         self.clear_completed_plan_entries(cx);
 
-        let (old_checkpoint_tx, old_checkpoint_rx) = oneshot::channel();
         let (tx, rx) = oneshot::channel();
         let cancel_task = self.cancel(cx);
-        let request = acp::PromptRequest {
-            prompt: message,
-            session_id: self.session_id.clone(),
-        };
-
-        self.send_task = Some(cx.spawn({
-            let message_id = message_id.clone();
-            async move |this, cx| {
-                cancel_task.await;
 
-                old_checkpoint_tx.send(old_checkpoint.await).ok();
-                if let Ok(result) = this.update(cx, |this, cx| {
-                    this.connection.prompt(message_id, request, cx)
-                }) {
-                    tx.send(result.await).log_err();
-                }
-            }
+        self.send_task = Some(cx.spawn(async move |this, cx| {
+            cancel_task.await;
+            tx.send(f(this, cx).await).ok();
         }));
 
         cx.spawn(async move |this, cx| {
-            let old_checkpoint = old_checkpoint_rx
-                .await
-                .map_err(|_| anyhow!("send canceled"))
-                .flatten()
-                .context("failed to get old checkpoint")
-                .log_err();
-
             let response = rx.await;
 
-            if let Some((old_checkpoint, message_id)) = old_checkpoint.zip(message_id) {
-                let new_checkpoint = git_store
-                    .update(cx, |git, cx| git.checkpoint(cx))?
-                    .await
-                    .context("failed to get new checkpoint")
-                    .log_err();
-                if let Some(new_checkpoint) = new_checkpoint {
-                    let equal = git_store
-                        .update(cx, |git, cx| {
-                            git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
-                        })?
-                        .await
-                        .unwrap_or(true);
-                    if !equal {
-                        this.update(cx, |this, cx| {
-                            if let Some((ix, message)) = this.user_message_mut(&message_id) {
-                                message.checkpoint = Some(old_checkpoint);
-                                cx.emit(AcpThreadEvent::EntryUpdated(ix));
-                            }
-                        })?;
-                    }
-                }
-            }
+            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
+                .await?;
 
             this.update(cx, |this, cx| {
                 match response {
@@ -1294,7 +1302,10 @@ impl AcpThread {
             return Task::ready(Err(anyhow!("message not found")));
         };
 
-        let checkpoint = message.checkpoint.clone();
+        let checkpoint = message
+            .checkpoint
+            .as_ref()
+            .map(|c| c.git_checkpoint.clone());
 
         let git_store = self.project.read(cx).git_store().clone();
         cx.spawn(async move |this, cx| {
@@ -1316,6 +1327,59 @@ impl AcpThread {
         })
     }
 
+    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
+        let git_store = self.project.read(cx).git_store().clone();
+
+        let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
+            if let Some(checkpoint) = message.checkpoint.as_ref() {
+                checkpoint.git_checkpoint.clone()
+            } else {
+                return Task::ready(Ok(()));
+            }
+        } else {
+            return Task::ready(Ok(()));
+        };
+
+        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
+        cx.spawn(async move |this, cx| {
+            let new_checkpoint = new_checkpoint
+                .await
+                .context("failed to get new checkpoint")
+                .log_err();
+            if let Some(new_checkpoint) = new_checkpoint {
+                let equal = git_store
+                    .update(cx, |git, cx| {
+                        git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
+                    })?
+                    .await
+                    .unwrap_or(true);
+                this.update(cx, |this, cx| {
+                    let (ix, message) = this.last_user_message().context("no user message")?;
+                    let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
+                    checkpoint.show = !equal;
+                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
+                    anyhow::Ok(())
+                })??;
+            }
+
+            Ok(())
+        })
+    }
+
+    fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
+        self.entries
+            .iter_mut()
+            .enumerate()
+            .rev()
+            .find_map(|(ix, entry)| {
+                if let AgentThreadEntry::UserMessage(message) = entry {
+                    Some((ix, message))
+                } else {
+                    None
+                }
+            })
+    }
+
     fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
         self.entries.iter().find_map(|entry| {
             if let AgentThreadEntry::UserMessage(message) = entry {
@@ -1552,6 +1616,7 @@ mod tests {
     use settings::SettingsStore;
     use smol::stream::StreamExt as _;
     use std::{
+        any::Any,
         cell::RefCell,
         path::Path,
         rc::Rc,
@@ -2284,6 +2349,10 @@ mod tests {
                 _session_id: session_id.clone(),
             }))
         }
+
+        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
+            self
+        }
     }
 
     struct FakeAgentSessionEditor {

crates/acp_thread/src/connection.rs 🔗

@@ -4,7 +4,7 @@ use anyhow::Result;
 use collections::IndexMap;
 use gpui::{Entity, SharedString, Task};
 use project::Project;
-use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
+use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
 use ui::{App, IconName};
 use uuid::Uuid;
 
@@ -36,6 +36,14 @@ pub trait AgentConnection {
         cx: &mut App,
     ) -> Task<Result<acp::PromptResponse>>;
 
+    fn resume(
+        &self,
+        _session_id: &acp::SessionId,
+        _cx: &mut App,
+    ) -> Option<Rc<dyn AgentSessionResume>> {
+        None
+    }
+
     fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 
     fn session_editor(
@@ -53,12 +61,24 @@ pub trait AgentConnection {
     fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
         None
     }
+
+    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
+}
+
+impl dyn AgentConnection {
+    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
+        self.into_any().downcast().ok()
+    }
 }
 
 pub trait AgentSessionEditor {
     fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
 }
 
+pub trait AgentSessionResume {
+    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
+}
+
 #[derive(Debug)]
 pub struct AuthRequired;
 
@@ -299,6 +319,10 @@ mod test_support {
         ) -> Option<Rc<dyn AgentSessionEditor>> {
             Some(Rc::new(StubAgentSessionEditor))
         }
+
+        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
+            self
+        }
     }
 
     struct StubAgentSessionEditor;

crates/agent2/src/agent.rs 🔗

@@ -1,9 +1,8 @@
-use crate::{AgentResponseEvent, Thread, templates::Templates};
 use crate::{
-    ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
-    EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
-    OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent,
-    WebSearchTool,
+    AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
+    DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
+    MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
+    ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
 };
 use acp_thread::AgentModelSelector;
 use agent_client_protocol as acp;
@@ -11,6 +10,7 @@ use agent_settings::AgentSettings;
 use anyhow::{Context as _, Result, anyhow};
 use collections::{HashSet, IndexMap};
 use fs::Fs;
+use futures::channel::mpsc;
 use futures::{StreamExt, future};
 use gpui::{
     App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
@@ -21,6 +21,7 @@ use prompt_store::{
     ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
 };
 use settings::update_settings_file;
+use std::any::Any;
 use std::cell::RefCell;
 use std::collections::HashMap;
 use std::path::Path;
@@ -426,9 +427,9 @@ impl NativeAgent {
         self.models.refresh_list(cx);
         for session in self.sessions.values_mut() {
             session.thread.update(cx, |thread, _| {
-                let model_id = LanguageModels::model_id(&thread.selected_model);
+                let model_id = LanguageModels::model_id(&thread.model());
                 if let Some(model) = self.models.model_from_id(&model_id) {
-                    thread.selected_model = model.clone();
+                    thread.set_model(model.clone());
                 }
             });
         }
@@ -439,6 +440,124 @@ impl NativeAgent {
 #[derive(Clone)]
 pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 
+impl NativeAgentConnection {
+    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
+        self.0
+            .read(cx)
+            .sessions
+            .get(session_id)
+            .map(|session| session.thread.clone())
+    }
+
+    fn run_turn(
+        &self,
+        session_id: acp::SessionId,
+        cx: &mut App,
+        f: impl 'static
+        + FnOnce(
+            Entity<Thread>,
+            &mut App,
+        ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
+    ) -> Task<Result<acp::PromptResponse>> {
+        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
+            agent
+                .sessions
+                .get_mut(&session_id)
+                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
+        }) else {
+            return Task::ready(Err(anyhow!("Session not found")));
+        };
+        log::debug!("Found session for: {}", session_id);
+
+        let mut response_stream = match f(thread, cx) {
+            Ok(stream) => stream,
+            Err(err) => return Task::ready(Err(err)),
+        };
+        cx.spawn(async move |cx| {
+            // Handle response stream and forward to session.acp_thread
+            while let Some(result) = response_stream.next().await {
+                match result {
+                    Ok(event) => {
+                        log::trace!("Received completion event: {:?}", event);
+
+                        match event {
+                            AgentResponseEvent::Text(text) => {
+                                acp_thread.update(cx, |thread, cx| {
+                                    thread.push_assistant_content_block(
+                                        acp::ContentBlock::Text(acp::TextContent {
+                                            text,
+                                            annotations: None,
+                                        }),
+                                        false,
+                                        cx,
+                                    )
+                                })?;
+                            }
+                            AgentResponseEvent::Thinking(text) => {
+                                acp_thread.update(cx, |thread, cx| {
+                                    thread.push_assistant_content_block(
+                                        acp::ContentBlock::Text(acp::TextContent {
+                                            text,
+                                            annotations: None,
+                                        }),
+                                        true,
+                                        cx,
+                                    )
+                                })?;
+                            }
+                            AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
+                                tool_call,
+                                options,
+                                response,
+                            }) => {
+                                let recv = acp_thread.update(cx, |thread, cx| {
+                                    thread.request_tool_call_authorization(tool_call, options, cx)
+                                })?;
+                                cx.background_spawn(async move {
+                                    if let Some(option) = recv
+                                        .await
+                                        .context("authorization sender was dropped")
+                                        .log_err()
+                                    {
+                                        response
+                                            .send(option)
+                                            .map(|_| anyhow!("authorization receiver was dropped"))
+                                            .log_err();
+                                    }
+                                })
+                                .detach();
+                            }
+                            AgentResponseEvent::ToolCall(tool_call) => {
+                                acp_thread.update(cx, |thread, cx| {
+                                    thread.upsert_tool_call(tool_call, cx)
+                                })?;
+                            }
+                            AgentResponseEvent::ToolCallUpdate(update) => {
+                                acp_thread.update(cx, |thread, cx| {
+                                    thread.update_tool_call(update, cx)
+                                })??;
+                            }
+                            AgentResponseEvent::Stop(stop_reason) => {
+                                log::debug!("Assistant message complete: {:?}", stop_reason);
+                                return Ok(acp::PromptResponse { stop_reason });
+                            }
+                        }
+                    }
+                    Err(e) => {
+                        log::error!("Error in model response stream: {:?}", e);
+                        return Err(e);
+                    }
+                }
+            }
+
+            log::info!("Response stream completed");
+            anyhow::Ok(acp::PromptResponse {
+                stop_reason: acp::StopReason::EndTurn,
+            })
+        })
+    }
+}
+
 impl AgentModelSelector for NativeAgentConnection {
     fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
         log::debug!("NativeAgentConnection::list_models called");
@@ -472,7 +591,7 @@ impl AgentModelSelector for NativeAgentConnection {
         };
 
         thread.update(cx, |thread, _cx| {
-            thread.selected_model = model.clone();
+            thread.set_model(model.clone());
         });
 
         update_settings_file::<AgentSettings>(
@@ -502,7 +621,7 @@ impl AgentModelSelector for NativeAgentConnection {
         else {
             return Task::ready(Err(anyhow!("Session not found")));
         };
-        let model = thread.read(cx).selected_model.clone();
+        let model = thread.read(cx).model().clone();
         let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
         else {
             return Task::ready(Err(anyhow!("Provider not found")));
@@ -644,25 +763,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
     ) -> Task<Result<acp::PromptResponse>> {
         let id = id.expect("UserMessageId is required");
         let session_id = params.session_id.clone();
-        let agent = self.0.clone();
         log::info!("Received prompt request for session: {}", session_id);
         log::debug!("Prompt blocks count: {}", params.prompt.len());
 
-        cx.spawn(async move |cx| {
-            // Get session
-            let (thread, acp_thread) = agent
-                .update(cx, |agent, _| {
-                    agent
-                        .sessions
-                        .get_mut(&session_id)
-                        .map(|s| (s.thread.clone(), s.acp_thread.clone()))
-                })?
-                .ok_or_else(|| {
-                    log::error!("Session not found: {}", session_id);
-                    anyhow::anyhow!("Session not found")
-                })?;
-            log::debug!("Found session for: {}", session_id);
-
+        self.run_turn(session_id, cx, |thread, cx| {
             let content: Vec<UserMessageContent> = params
                 .prompt
                 .into_iter()
@@ -672,99 +776,27 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
             log::debug!("Message id: {:?}", id);
             log::debug!("Message content: {:?}", content);
 
-            // Get model using the ModelSelector capability (always available for agent2)
-            // Get the selected model from the thread directly
-            let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
-
-            // Send to thread
-            log::info!("Sending message to thread with model: {:?}", model.name());
-            let mut response_stream =
-                thread.update(cx, |thread, cx| thread.send(id, content, cx))?;
-
-            // Handle response stream and forward to session.acp_thread
-            while let Some(result) = response_stream.next().await {
-                match result {
-                    Ok(event) => {
-                        log::trace!("Received completion event: {:?}", event);
-
-                        match event {
-                            AgentResponseEvent::Text(text) => {
-                                acp_thread.update(cx, |thread, cx| {
-                                    thread.push_assistant_content_block(
-                                        acp::ContentBlock::Text(acp::TextContent {
-                                            text,
-                                            annotations: None,
-                                        }),
-                                        false,
-                                        cx,
-                                    )
-                                })?;
-                            }
-                            AgentResponseEvent::Thinking(text) => {
-                                acp_thread.update(cx, |thread, cx| {
-                                    thread.push_assistant_content_block(
-                                        acp::ContentBlock::Text(acp::TextContent {
-                                            text,
-                                            annotations: None,
-                                        }),
-                                        true,
-                                        cx,
-                                    )
-                                })?;
-                            }
-                            AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
-                                tool_call,
-                                options,
-                                response,
-                            }) => {
-                                let recv = acp_thread.update(cx, |thread, cx| {
-                                    thread.request_tool_call_authorization(tool_call, options, cx)
-                                })?;
-                                cx.background_spawn(async move {
-                                    if let Some(option) = recv
-                                        .await
-                                        .context("authorization sender was dropped")
-                                        .log_err()
-                                    {
-                                        response
-                                            .send(option)
-                                            .map(|_| anyhow!("authorization receiver was dropped"))
-                                            .log_err();
-                                    }
-                                })
-                                .detach();
-                            }
-                            AgentResponseEvent::ToolCall(tool_call) => {
-                                acp_thread.update(cx, |thread, cx| {
-                                    thread.upsert_tool_call(tool_call, cx)
-                                })?;
-                            }
-                            AgentResponseEvent::ToolCallUpdate(update) => {
-                                acp_thread.update(cx, |thread, cx| {
-                                    thread.update_tool_call(update, cx)
-                                })??;
-                            }
-                            AgentResponseEvent::Stop(stop_reason) => {
-                                log::debug!("Assistant message complete: {:?}", stop_reason);
-                                return Ok(acp::PromptResponse { stop_reason });
-                            }
-                        }
-                    }
-                    Err(e) => {
-                        log::error!("Error in model response stream: {:?}", e);
-                        // TODO: Consider sending an error message to the UI
-                        break;
-                    }
-                }
-            }
-
-            log::info!("Response stream completed");
-            anyhow::Ok(acp::PromptResponse {
-                stop_reason: acp::StopReason::EndTurn,
-            })
+            Ok(thread.update(cx, |thread, cx| {
+                log::info!(
+                    "Sending message to thread with model: {:?}",
+                    thread.model().name()
+                );
+                thread.send(id, content, cx)
+            }))
         })
     }
 
+    fn resume(
+        &self,
+        session_id: &acp::SessionId,
+        _cx: &mut App,
+    ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
+        Some(Rc::new(NativeAgentSessionResume {
+            connection: self.clone(),
+            session_id: session_id.clone(),
+        }) as _)
+    }
+
     fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
         log::info!("Cancelling on session: {}", session_id);
         self.0.update(cx, |agent, cx| {
@@ -786,6 +818,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
                 .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
         })
     }
+
+    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
+        self
+    }
 }
 
 struct NativeAgentSessionEditor(Entity<Thread>);
@@ -796,6 +832,20 @@ impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
     }
 }
 
+struct NativeAgentSessionResume {
+    connection: NativeAgentConnection,
+    session_id: acp::SessionId,
+}
+
+impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
+    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
+        self.connection
+            .run_turn(self.session_id.clone(), cx, |thread, cx| {
+                thread.update(cx, |thread, cx| thread.resume(cx))
+            })
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -957,7 +1007,7 @@ mod tests {
         agent.read_with(cx, |agent, _| {
             let session = agent.sessions.get(&session_id).unwrap();
             session.thread.read_with(cx, |thread, _| {
-                assert_eq!(thread.selected_model.id().0, "fake");
+                assert_eq!(thread.model().id().0, "fake");
             });
         });
 

crates/agent2/src/tests/mod.rs 🔗

@@ -12,9 +12,9 @@ use gpui::{
 };
 use indoc::indoc;
 use language_model::{
-    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
-    LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, Role, StopReason,
-    fake_provider::FakeLanguageModel,
+    LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry,
+    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
+    Role, StopReason, fake_provider::FakeLanguageModel,
 };
 use project::Project;
 use prompt_store::ProjectContext;
@@ -394,8 +394,194 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
     assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
 }
 
+#[gpui::test]
+async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    let events = thread.update(cx, |thread, cx| {
+        thread.add_tool(EchoTool);
+        thread.send(UserMessageId::new(), ["abc"], cx)
+    });
+    cx.run_until_parked();
+    let tool_use = LanguageModelToolUse {
+        id: "tool_id_1".into(),
+        name: EchoTool.name().into(),
+        raw_input: "{}".into(),
+        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
+        is_input_complete: true,
+    };
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
+    fake_model.end_last_completion_stream();
+
+    cx.run_until_parked();
+    let completion = fake_model.pending_completions().pop().unwrap();
+    let tool_result = LanguageModelToolResult {
+        tool_use_id: "tool_id_1".into(),
+        tool_name: EchoTool.name().into(),
+        is_error: false,
+        content: "def".into(),
+        output: Some("def".into()),
+    };
+    assert_eq!(
+        completion.messages[1..],
+        vec![
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["abc".into()],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::Assistant,
+                content: vec![MessageContent::ToolUse(tool_use.clone())],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec![MessageContent::ToolResult(tool_result.clone())],
+                cache: false
+            },
+        ]
+    );
+
+    // Simulate reaching tool use limit.
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
+        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
+    ));
+    fake_model.end_last_completion_stream();
+    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
+    assert!(
+        last_event
+            .unwrap_err()
+            .is::<language_model::ToolUseLimitReachedError>()
+    );
+
+    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
+    cx.run_until_parked();
+    let completion = fake_model.pending_completions().pop().unwrap();
+    assert_eq!(
+        completion.messages[1..],
+        vec![
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["abc".into()],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::Assistant,
+                content: vec![MessageContent::ToolUse(tool_use)],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec![MessageContent::ToolResult(tool_result)],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["Continue where you left off".into()],
+                cache: false
+            }
+        ]
+    );
+
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
+    fake_model.end_last_completion_stream();
+    events.collect::<Vec<_>>().await;
+    thread.read_with(cx, |thread, _cx| {
+        assert_eq!(
+            thread.last_message().unwrap().to_markdown(),
+            indoc! {"
+                ## Assistant
+
+                Done
+            "}
+        )
+    });
+
+    // Ensure we error if calling resume when tool use limit was *not* reached.
+    let error = thread
+        .update(cx, |thread, cx| thread.resume(cx))
+        .unwrap_err();
+    assert_eq!(
+        error.to_string(),
+        "can only resume after tool use limit is reached"
+    )
+}
+
+#[gpui::test]
+async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    let events = thread.update(cx, |thread, cx| {
+        thread.add_tool(EchoTool);
+        thread.send(UserMessageId::new(), ["abc"], cx)
+    });
+    cx.run_until_parked();
+
+    let tool_use = LanguageModelToolUse {
+        id: "tool_id_1".into(),
+        name: EchoTool.name().into(),
+        raw_input: "{}".into(),
+        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
+        is_input_complete: true,
+    };
+    let tool_result = LanguageModelToolResult {
+        tool_use_id: "tool_id_1".into(),
+        tool_name: EchoTool.name().into(),
+        is_error: false,
+        content: "def".into(),
+        output: Some("def".into()),
+    };
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
+        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
+    ));
+    fake_model.end_last_completion_stream();
+    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
+    assert!(
+        last_event
+            .unwrap_err()
+            .is::<language_model::ToolUseLimitReachedError>()
+    );
+
+    thread.update(cx, |thread, cx| {
+        thread.send(UserMessageId::new(), vec!["ghi"], cx)
+    });
+    cx.run_until_parked();
+    let completion = fake_model.pending_completions().pop().unwrap();
+    assert_eq!(
+        completion.messages[1..],
+        vec![
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["abc".into()],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::Assistant,
+                content: vec![MessageContent::ToolUse(tool_use)],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec![MessageContent::ToolResult(tool_result)],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["ghi".into()],
+                cache: false
+            }
+        ]
+    );
+}
+
 async fn expect_tool_call(
-    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
+    events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
 ) -> acp::ToolCall {
     let event = events
         .next()
@@ -411,7 +597,7 @@ async fn expect_tool_call(
 }
 
 async fn expect_tool_call_update_fields(
-    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
+    events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
 ) -> acp::ToolCallUpdate {
     let event = events
         .next()
@@ -429,7 +615,7 @@ async fn expect_tool_call_update_fields(
 }
 
 async fn next_tool_call_authorization(
-    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
+    events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
 ) -> ToolCallAuthorization {
     loop {
         let event = events
@@ -1007,9 +1193,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
 }
 
 /// Filters out the stop events for asserting against in tests
-fn stop_events(
-    result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
-) -> Vec<acp::StopReason> {
+fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
     result_events
         .into_iter()
         .filter_map(|event| match event.unwrap() {

crates/agent2/src/tests/test_tools.rs 🔗

@@ -7,7 +7,7 @@ use std::future;
 #[derive(JsonSchema, Serialize, Deserialize)]
 pub struct EchoToolInput {
     /// The text to echo.
-    text: String,
+    pub text: String,
 }
 
 pub struct EchoTool;

crates/agent2/src/thread.rs 🔗

@@ -2,10 +2,10 @@ use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
 use acp_thread::{MentionUri, UserMessageId};
 use action_log::ActionLog;
 use agent_client_protocol as acp;
-use agent_settings::{AgentProfileId, AgentSettings};
+use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
 use anyhow::{Context as _, Result, anyhow};
 use assistant_tool::adapt_schema_to_format;
-use cloud_llm_client::{CompletionIntent, CompletionMode};
+use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
 use collections::IndexMap;
 use fs::Fs;
 use futures::{
@@ -14,10 +14,10 @@ use futures::{
 };
 use gpui::{App, Context, Entity, SharedString, Task};
 use language_model::{
-    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
-    LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
-    LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
-    LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
+    LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
+    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
+    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
+    LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
 };
 use project::Project;
 use prompt_store::ProjectContext;
@@ -33,6 +33,7 @@ use util::{ResultExt, markdown::MarkdownCodeBlock};
 pub enum Message {
     User(UserMessage),
     Agent(AgentMessage),
+    Resume,
 }
 
 impl Message {
@@ -47,6 +48,7 @@ impl Message {
         match self {
             Message::User(message) => message.to_markdown(),
             Message::Agent(message) => message.to_markdown(),
+            Message::Resume => "[resumed after tool use limit was reached]".into(),
         }
     }
 }
@@ -320,7 +322,11 @@ impl AgentMessage {
     }
 
     pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
-        let mut content = Vec::with_capacity(self.content.len());
+        let mut assistant_message = LanguageModelRequestMessage {
+            role: Role::Assistant,
+            content: Vec::with_capacity(self.content.len()),
+            cache: false,
+        };
         for chunk in &self.content {
             let chunk = match chunk {
                 AgentMessageContent::Text(text) => {
@@ -342,29 +348,30 @@ impl AgentMessage {
                     language_model::MessageContent::Image(value.clone())
                 }
             };
-            content.push(chunk);
+            assistant_message.content.push(chunk);
         }
 
-        let mut messages = vec![LanguageModelRequestMessage {
-            role: Role::Assistant,
-            content,
+        let mut user_message = LanguageModelRequestMessage {
+            role: Role::User,
+            content: Vec::new(),
             cache: false,
-        }];
+        };
 
-        if !self.tool_results.is_empty() {
-            let mut tool_results = Vec::with_capacity(self.tool_results.len());
-            for tool_result in self.tool_results.values() {
-                tool_results.push(language_model::MessageContent::ToolResult(
+        for tool_result in self.tool_results.values() {
+            user_message
+                .content
+                .push(language_model::MessageContent::ToolResult(
                     tool_result.clone(),
                 ));
-            }
-            messages.push(LanguageModelRequestMessage {
-                role: Role::User,
-                content: tool_results,
-                cache: false,
-            });
         }
 
+        let mut messages = Vec::new();
+        if !assistant_message.content.is_empty() {
+            messages.push(assistant_message);
+        }
+        if !user_message.content.is_empty() {
+            messages.push(user_message);
+        }
         messages
     }
 }
@@ -413,11 +420,12 @@ pub struct Thread {
     running_turn: Option<Task<()>>,
     pending_message: Option<AgentMessage>,
     tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
+    tool_use_limit_reached: bool,
     context_server_registry: Entity<ContextServerRegistry>,
     profile_id: AgentProfileId,
     project_context: Rc<RefCell<ProjectContext>>,
     templates: Arc<Templates>,
-    pub selected_model: Arc<dyn LanguageModel>,
+    model: Arc<dyn LanguageModel>,
     project: Entity<Project>,
     action_log: Entity<ActionLog>,
 }
@@ -429,7 +437,7 @@ impl Thread {
         context_server_registry: Entity<ContextServerRegistry>,
         action_log: Entity<ActionLog>,
         templates: Arc<Templates>,
-        default_model: Arc<dyn LanguageModel>,
+        model: Arc<dyn LanguageModel>,
         cx: &mut Context<Self>,
     ) -> Self {
         let profile_id = AgentSettings::get_global(cx).default_profile.clone();
@@ -439,11 +447,12 @@ impl Thread {
             running_turn: None,
             pending_message: None,
             tools: BTreeMap::default(),
+            tool_use_limit_reached: false,
             context_server_registry,
             profile_id,
             project_context,
             templates,
-            selected_model: default_model,
+            model,
             project,
             action_log,
         }
@@ -457,7 +466,19 @@ impl Thread {
         &self.action_log
     }
 
-    pub fn set_mode(&mut self, mode: CompletionMode) {
+    pub fn model(&self) -> &Arc<dyn LanguageModel> {
+        &self.model
+    }
+
+    pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
+        self.model = model;
+    }
+
+    pub fn completion_mode(&self) -> CompletionMode {
+        self.completion_mode
+    }
+
+    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
         self.completion_mode = mode;
     }
 
@@ -499,36 +520,59 @@ impl Thread {
         Ok(())
     }
 
+    pub fn resume(
+        &mut self,
+        cx: &mut Context<Self>,
+    ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
+        anyhow::ensure!(
+            self.tool_use_limit_reached,
+            "can only resume after tool use limit is reached"
+        );
+
+        self.messages.push(Message::Resume);
+        cx.notify();
+
+        log::info!("Total messages in thread: {}", self.messages.len());
+        Ok(self.run_turn(cx))
+    }
+
     /// Sending a message results in the model streaming a response, which could include tool calls.
     /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
     /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
     pub fn send<T>(
         &mut self,
-        message_id: UserMessageId,
+        id: UserMessageId,
         content: impl IntoIterator<Item = T>,
         cx: &mut Context<Self>,
-    ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>
+    ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>>
     where
         T: Into<UserMessageContent>,
     {
-        let model = self.selected_model.clone();
+        log::info!("Thread::send called with model: {:?}", self.model.name());
+
         let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
-        log::info!("Thread::send called with model: {:?}", model.name());
         log::debug!("Thread::send content: {:?}", content);
 
+        self.messages
+            .push(Message::User(UserMessage { id, content }));
         cx.notify();
-        let (events_tx, events_rx) =
-            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
-        let event_stream = AgentResponseEventStream(events_tx);
 
-        self.messages.push(Message::User(UserMessage {
-            id: message_id.clone(),
-            content,
-        }));
         log::info!("Total messages in thread: {}", self.messages.len());
+        self.run_turn(cx)
+    }
+
+    fn run_turn(
+        &mut self,
+        cx: &mut Context<Self>,
+    ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
+        let model = self.model.clone();
+        let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
+        let event_stream = AgentResponseEventStream(events_tx);
+        let message_ix = self.messages.len().saturating_sub(1);
+        self.tool_use_limit_reached = false;
         self.running_turn = Some(cx.spawn(async move |this, cx| {
             log::info!("Starting agent turn execution");
-            let turn_result = async {
+            let turn_result: Result<()> = async {
                 let mut completion_intent = CompletionIntent::UserPrompt;
                 loop {
                     log::debug!(
@@ -543,13 +587,22 @@ impl Thread {
                     let mut events = model.stream_completion(request, cx).await?;
                     log::debug!("Stream completion started successfully");
 
+                    let mut tool_use_limit_reached = false;
                     let mut tool_uses = FuturesUnordered::new();
                     while let Some(event) = events.next().await {
                         match event? {
+                            LanguageModelCompletionEvent::StatusUpdate(
+                                CompletionRequestStatus::ToolUseLimitReached,
+                            ) => {
+                                tool_use_limit_reached = true;
+                            }
                             LanguageModelCompletionEvent::Stop(reason) => {
                                 event_stream.send_stop(reason);
                                 if reason == StopReason::Refusal {
-                                    this.update(cx, |this, _cx| this.truncate(message_id))??;
+                                    this.update(cx, |this, _cx| {
+                                        this.flush_pending_message();
+                                        this.messages.truncate(message_ix);
+                                    })?;
                                     return Ok(());
                                 }
                             }
@@ -567,12 +620,7 @@ impl Thread {
                         }
                     }
 
-                    if tool_uses.is_empty() {
-                        log::info!("No tool uses found, completing turn");
-                        return Ok(());
-                    }
-                    log::info!("Found {} tool uses to execute", tool_uses.len());
-
+                    let used_tools = tool_uses.is_empty();
                     while let Some(tool_result) = tool_uses.next().await {
                         log::info!("Tool finished {:?}", tool_result);
 
@@ -596,8 +644,17 @@ impl Thread {
                         .ok();
                     }
 
-                    this.update(cx, |this, _| this.flush_pending_message())?;
-                    completion_intent = CompletionIntent::ToolResults;
+                    if tool_use_limit_reached {
+                        log::info!("Tool use limit reached, completing turn");
+                        this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
+                        return Err(language_model::ToolUseLimitReachedError.into());
+                    } else if used_tools {
+                        log::info!("No tool uses found, completing turn");
+                        return Ok(());
+                    } else {
+                        this.update(cx, |this, _| this.flush_pending_message())?;
+                        completion_intent = CompletionIntent::ToolResults;
+                    }
                 }
             }
             .await;
@@ -678,10 +735,10 @@ impl Thread {
     fn handle_text_event(
         &mut self,
         new_text: String,
-        events_stream: &AgentResponseEventStream,
+        event_stream: &AgentResponseEventStream,
         cx: &mut Context<Self>,
     ) {
-        events_stream.send_text(&new_text);
+        event_stream.send_text(&new_text);
 
         let last_message = self.pending_message();
         if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
@@ -798,8 +855,9 @@ impl Thread {
             status: Some(acp::ToolCallStatus::InProgress),
             ..Default::default()
         });
-        let supports_images = self.selected_model.supports_images();
+        let supports_images = self.model.supports_images();
         let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
+        log::info!("Running tool {}", tool_use.name);
         Some(cx.foreground_executor().spawn(async move {
             let tool_result = tool_result.await.and_then(|output| {
                 if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
@@ -902,7 +960,7 @@ impl Thread {
                         name: tool_name,
                         description: tool.description().to_string(),
                         input_schema: tool
-                            .input_schema(self.selected_model.tool_input_format())
+                            .input_schema(self.model.tool_input_format())
                             .log_err()?,
                     })
                 })
@@ -917,7 +975,7 @@ impl Thread {
             thread_id: None,
             prompt_id: None,
             intent: Some(completion_intent),
-            mode: Some(self.completion_mode),
+            mode: Some(self.completion_mode.into()),
             messages,
             tools,
             tool_choice: None,
@@ -935,7 +993,7 @@ impl Thread {
             .profiles
             .get(&self.profile_id)
             .context("profile not found")?;
-        let provider_id = self.selected_model.provider_id();
+        let provider_id = self.model.provider_id();
 
         Ok(self
             .tools
@@ -971,6 +1029,11 @@ impl Thread {
             match message {
                 Message::User(message) => messages.push(message.to_request()),
                 Message::Agent(message) => messages.extend(message.to_request()),
+                Message::Resume => messages.push(LanguageModelRequestMessage {
+                    role: Role::User,
+                    content: vec!["Continue where you left off".into()],
+                    cache: false,
+                }),
             }
         }
 
@@ -1123,9 +1186,7 @@ where
 }
 
 #[derive(Clone)]
-struct AgentResponseEventStream(
-    mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
-);
+struct AgentResponseEventStream(mpsc::UnboundedSender<Result<AgentResponseEvent>>);
 
 impl AgentResponseEventStream {
     fn send_text(&self, text: &str) {
@@ -1212,8 +1273,8 @@ impl AgentResponseEventStream {
         }
     }
 
-    fn send_error(&self, error: LanguageModelCompletionError) {
-        self.0.unbounded_send(Err(error)).ok();
+    fn send_error(&self, error: impl Into<anyhow::Error>) {
+        self.0.unbounded_send(Err(error.into())).ok();
     }
 }
 
@@ -1229,8 +1290,7 @@ pub struct ToolCallEventStream {
 impl ToolCallEventStream {
     #[cfg(test)]
     pub fn test() -> (Self, ToolCallEventStreamReceiver) {
-        let (events_tx, events_rx) =
-            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
+        let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
 
         let stream = ToolCallEventStream::new(
             &LanguageModelToolUse {
@@ -1351,9 +1411,7 @@ impl ToolCallEventStream {
 }
 
 #[cfg(test)]
-pub struct ToolCallEventStreamReceiver(
-    mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
-);
+pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<AgentResponseEvent>>);
 
 #[cfg(test)]
 impl ToolCallEventStreamReceiver {
@@ -1381,7 +1439,7 @@ impl ToolCallEventStreamReceiver {
 
 #[cfg(test)]
 impl std::ops::Deref for ToolCallEventStreamReceiver {
-    type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;
+    type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent>>;
 
     fn deref(&self) -> &Self::Target {
         &self.0

crates/agent2/src/tools/edit_file_tool.rs 🔗

@@ -241,7 +241,7 @@ impl AgentTool for EditFileTool {
             thread.build_completion_request(CompletionIntent::ToolResults, cx)
         });
         let thread = self.thread.read(cx);
-        let model = thread.selected_model.clone();
+        let model = thread.model().clone();
         let action_log = thread.action_log().clone();
 
         let authorize = self.authorize(&input, &event_stream, cx);

crates/agent_servers/src/acp/v0.rs 🔗

@@ -5,7 +5,7 @@ use anyhow::{Context as _, Result, anyhow};
 use futures::channel::oneshot;
 use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
 use project::Project;
-use std::{cell::RefCell, path::Path, rc::Rc};
+use std::{any::Any, cell::RefCell, path::Path, rc::Rc};
 use ui::App;
 use util::ResultExt as _;
 
@@ -507,4 +507,8 @@ impl AgentConnection for AcpConnection {
             })
             .detach_and_log_err(cx)
     }
+
+    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
+        self
+    }
 }

crates/agent_servers/src/acp/v1.rs 🔗

@@ -3,9 +3,9 @@ use anyhow::anyhow;
 use collections::HashMap;
 use futures::channel::oneshot;
 use project::Project;
-use std::cell::RefCell;
 use std::path::Path;
 use std::rc::Rc;
+use std::{any::Any, cell::RefCell};
 
 use anyhow::{Context as _, Result};
 use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
@@ -191,6 +191,10 @@ impl AgentConnection for AcpConnection {
             .spawn(async move { conn.cancel(params).await })
             .detach();
     }
+
+    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
+        self
+    }
 }
 
 struct ClientDelegate {

crates/agent_servers/src/claude.rs 🔗

@@ -6,6 +6,7 @@ use context_server::listener::McpServerTool;
 use project::Project;
 use settings::SettingsStore;
 use smol::process::Child;
+use std::any::Any;
 use std::cell::RefCell;
 use std::fmt::Display;
 use std::path::Path;
@@ -289,6 +290,10 @@ impl AgentConnection for ClaudeAgentConnection {
             })
             .log_err();
     }
+
+    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
+        self
+    }
 }
 
 #[derive(Clone, Copy)]

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -7,20 +7,21 @@ use action_log::ActionLog;
 use agent::{TextThreadStore, ThreadStore};
 use agent_client_protocol::{self as acp};
 use agent_servers::AgentServer;
-use agent_settings::{AgentSettings, NotifyWhenAgentWaiting};
+use agent_settings::{AgentSettings, CompletionMode, NotifyWhenAgentWaiting};
 use anyhow::bail;
 use audio::{Audio, Sound};
 use buffer_diff::BufferDiff;
+use client::zed_urls;
 use collections::{HashMap, HashSet};
 use editor::scroll::Autoscroll;
 use editor::{Editor, EditorMode, MultiBuffer, PathKey, SelectionEffects};
 use file_icons::FileIcons;
 use gpui::{
-    Action, Animation, AnimationExt, App, BorderStyle, ClickEvent, EdgesRefinement, Empty, Entity,
-    FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, PlatformDisplay,
-    SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement,
-    Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop,
-    linear_gradient, list, percentage, point, prelude::*, pulsating_between,
+    Action, Animation, AnimationExt, App, BorderStyle, ClickEvent, ClipboardItem, EdgesRefinement,
+    Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton,
+    PlatformDisplay, SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle,
+    TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div,
+    linear_color_stop, linear_gradient, list, percentage, point, prelude::*, pulsating_between,
 };
 use language::Buffer;
 use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
@@ -32,8 +33,8 @@ use std::{collections::BTreeMap, process::ExitStatus, rc::Rc, time::Duration};
 use text::Anchor;
 use theme::ThemeSettings;
 use ui::{
-    Disclosure, Divider, DividerColor, KeyBinding, PopoverMenuHandle, Scrollbar, ScrollbarState,
-    Tooltip, prelude::*,
+    Callout, Disclosure, Divider, DividerColor, ElevationIndex, KeyBinding, PopoverMenuHandle,
+    Scrollbar, ScrollbarState, Tooltip, prelude::*,
 };
 use util::{ResultExt, size::format_file_size, time::duration_alt_display};
 use workspace::{CollaboratorId, Workspace};
@@ -44,16 +45,39 @@ use super::entry_view_state::EntryViewState;
 use crate::acp::AcpModelSelectorPopover;
 use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
 use crate::agent_diff::AgentDiff;
-use crate::ui::{AgentNotification, AgentNotificationEvent};
+use crate::ui::{AgentNotification, AgentNotificationEvent, BurnModeTooltip};
 use crate::{
-    AgentDiffPane, AgentPanel, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll,
+    AgentDiffPane, AgentPanel, ContinueThread, ContinueWithBurnMode, ExpandMessageEditor, Follow,
+    KeepAll, OpenAgentDiff, RejectAll, ToggleBurnMode,
 };
 
 const RESPONSE_PADDING_X: Pixels = px(19.);
-
 pub const MIN_EDITOR_LINES: usize = 4;
 pub const MAX_EDITOR_LINES: usize = 8;
 
+enum ThreadError {
+    PaymentRequired,
+    ModelRequestLimitReached(cloud_llm_client::Plan),
+    ToolUseLimitReached,
+    Other(SharedString),
+}
+
+impl ThreadError {
+    fn from_err(error: anyhow::Error) -> Self {
+        if error.is::<language_model::PaymentRequiredError>() {
+            Self::PaymentRequired
+        } else if error.is::<language_model::ToolUseLimitReachedError>() {
+            Self::ToolUseLimitReached
+        } else if let Some(error) =
+            error.downcast_ref::<language_model::ModelRequestLimitReachedError>()
+        {
+            Self::ModelRequestLimitReached(error.plan)
+        } else {
+            Self::Other(error.to_string().into())
+        }
+    }
+}
+
 pub struct AcpThreadView {
     agent: Rc<dyn AgentServer>,
     workspace: WeakEntity<Workspace>,
@@ -66,7 +90,7 @@ pub struct AcpThreadView {
     model_selector: Option<Entity<AcpModelSelectorPopover>>,
     notifications: Vec<WindowHandle<AgentNotification>>,
     notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
-    last_error: Option<Entity<Markdown>>,
+    thread_error: Option<ThreadError>,
     list_state: ListState,
     scrollbar_state: ScrollbarState,
     auth_task: Option<Task<()>>,
@@ -151,7 +175,7 @@ impl AcpThreadView {
             entry_view_state: EntryViewState::default(),
             list_state: list_state.clone(),
             scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
-            last_error: None,
+            thread_error: None,
             auth_task: None,
             expanded_tool_calls: HashSet::default(),
             expanded_thinking_blocks: HashSet::default(),
@@ -316,7 +340,7 @@ impl AcpThreadView {
     }
 
     pub fn cancel_generation(&mut self, cx: &mut Context<Self>) {
-        self.last_error.take();
+        self.thread_error.take();
 
         if let Some(thread) = self.thread() {
             self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx)));
@@ -371,6 +395,25 @@ impl AcpThreadView {
         }
     }
 
+    fn resume_chat(&mut self, cx: &mut Context<Self>) {
+        self.thread_error.take();
+        let Some(thread) = self.thread() else {
+            return;
+        };
+
+        let task = thread.update(cx, |thread, cx| thread.resume(cx));
+        cx.spawn(async move |this, cx| {
+            let result = task.await;
+
+            this.update(cx, |this, cx| {
+                if let Err(err) = result {
+                    this.handle_thread_error(err, cx);
+                }
+            })
+        })
+        .detach();
+    }
+
     fn send(&mut self, window: &mut Window, cx: &mut Context<Self>) {
         let contents = self
             .message_editor
@@ -384,7 +427,7 @@ impl AcpThreadView {
         window: &mut Window,
         cx: &mut Context<Self>,
     ) {
-        self.last_error.take();
+        self.thread_error.take();
         self.editing_message.take();
 
         let Some(thread) = self.thread().cloned() else {
@@ -409,11 +452,9 @@ impl AcpThreadView {
         });
 
         cx.spawn(async move |this, cx| {
-            if let Err(e) = task.await {
+            if let Err(err) = task.await {
                 this.update(cx, |this, cx| {
-                    this.last_error =
-                        Some(cx.new(|cx| Markdown::new(e.to_string().into(), None, None, cx)));
-                    cx.notify()
+                    this.handle_thread_error(err, cx);
                 })
                 .ok();
             }
@@ -476,6 +517,16 @@ impl AcpThreadView {
         })
     }
 
+    fn handle_thread_error(&mut self, error: anyhow::Error, cx: &mut Context<Self>) {
+        self.thread_error = Some(ThreadError::from_err(error));
+        cx.notify();
+    }
+
+    fn clear_thread_error(&mut self, cx: &mut Context<Self>) {
+        self.thread_error = None;
+        cx.notify();
+    }
+
     fn handle_thread_event(
         &mut self,
         thread: &Entity<AcpThread>,
@@ -551,7 +602,7 @@ impl AcpThreadView {
             return;
         };
 
-        self.last_error.take();
+        self.thread_error.take();
         let authenticate = connection.authenticate(method, cx);
         self.auth_task = Some(cx.spawn_in(window, {
             let project = self.project.clone();
@@ -561,9 +612,7 @@ impl AcpThreadView {
 
                 this.update_in(cx, |this, window, cx| {
                     if let Err(err) = result {
-                        this.last_error = Some(cx.new(|cx| {
-                            Markdown::new(format!("Error: {err}").into(), None, None, cx)
-                        }))
+                        this.handle_thread_error(err, cx);
                     } else {
                         this.thread_state = Self::initial_state(
                             agent,
@@ -620,9 +669,7 @@ impl AcpThreadView {
                 .py_4()
                 .px_2()
                 .children(message.id.clone().and_then(|message_id| {
-                    message.checkpoint.as_ref()?;
-
-                    Some(
+                    message.checkpoint.as_ref()?.show.then(|| {
                         Button::new("restore-checkpoint", "Restore Checkpoint")
                             .icon(IconName::Undo)
                             .icon_size(IconSize::XSmall)
@@ -630,8 +677,8 @@ impl AcpThreadView {
                             .label_size(LabelSize::XSmall)
                             .on_click(cx.listener(move |this, _, _window, cx| {
                                 this.rewind(&message_id, cx);
-                            })),
-                    )
+                            }))
+                    })
                 }))
                 .child(
                     v_flex()
@@ -2322,7 +2369,12 @@ impl AcpThreadView {
                 h_flex()
                     .flex_none()
                     .justify_between()
-                    .child(self.render_follow_toggle(cx))
+                    .child(
+                        h_flex()
+                            .gap_1()
+                            .child(self.render_follow_toggle(cx))
+                            .children(self.render_burn_mode_toggle(cx)),
+                    )
                     .child(
                         h_flex()
                             .gap_1()
@@ -2333,6 +2385,68 @@ impl AcpThreadView {
             .into_any()
     }
 
+    fn as_native_connection(&self, cx: &App) -> Option<Rc<agent2::NativeAgentConnection>> {
+        let acp_thread = self.thread()?.read(cx);
+        acp_thread.connection().clone().downcast()
+    }
+
+    fn as_native_thread(&self, cx: &App) -> Option<Entity<agent2::Thread>> {
+        let acp_thread = self.thread()?.read(cx);
+        self.as_native_connection(cx)?
+            .thread(acp_thread.session_id(), cx)
+    }
+
+    fn toggle_burn_mode(
+        &mut self,
+        _: &ToggleBurnMode,
+        _window: &mut Window,
+        cx: &mut Context<Self>,
+    ) {
+        let Some(thread) = self.as_native_thread(cx) else {
+            return;
+        };
+
+        thread.update(cx, |thread, _cx| {
+            let current_mode = thread.completion_mode();
+            thread.set_completion_mode(match current_mode {
+                CompletionMode::Burn => CompletionMode::Normal,
+                CompletionMode::Normal => CompletionMode::Burn,
+            });
+        });
+    }
+
+    fn render_burn_mode_toggle(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
+        let thread = self.as_native_thread(cx)?.read(cx);
+
+        if !thread.model().supports_burn_mode() {
+            return None;
+        }
+
+        let active_completion_mode = thread.completion_mode();
+        let burn_mode_enabled = active_completion_mode == CompletionMode::Burn;
+        let icon = if burn_mode_enabled {
+            IconName::ZedBurnModeOn
+        } else {
+            IconName::ZedBurnMode
+        };
+
+        Some(
+            IconButton::new("burn-mode", icon)
+                .icon_size(IconSize::Small)
+                .icon_color(Color::Muted)
+                .toggle_state(burn_mode_enabled)
+                .selected_icon_color(Color::Error)
+                .on_click(cx.listener(|this, _event, window, cx| {
+                    this.toggle_burn_mode(&ToggleBurnMode, window, cx);
+                }))
+                .tooltip(move |_window, cx| {
+                    cx.new(|_| BurnModeTooltip::new().selected(burn_mode_enabled))
+                        .into()
+                })
+                .into_any_element(),
+        )
+    }
+
     fn start_editing_message(&mut self, index: usize, window: &mut Window, cx: &mut Context<Self>) {
         let Some(thread) = self.thread() else {
             return;
@@ -3002,6 +3116,187 @@ impl AcpThreadView {
     }
 }
 
+impl AcpThreadView {
+    fn render_thread_error(&self, window: &mut Window, cx: &mut Context<'_, Self>) -> Option<Div> {
+        let content = match self.thread_error.as_ref()? {
+            ThreadError::Other(error) => self.render_any_thread_error(error.clone(), cx),
+            ThreadError::PaymentRequired => self.render_payment_required_error(cx),
+            ThreadError::ModelRequestLimitReached(plan) => {
+                self.render_model_request_limit_reached_error(*plan, cx)
+            }
+            ThreadError::ToolUseLimitReached => {
+                self.render_tool_use_limit_reached_error(window, cx)?
+            }
+        };
+
+        Some(
+            div()
+                .border_t_1()
+                .border_color(cx.theme().colors().border)
+                .child(content),
+        )
+    }
+
+    fn render_any_thread_error(&self, error: SharedString, cx: &mut Context<'_, Self>) -> Callout {
+        let icon = Icon::new(IconName::XCircle)
+            .size(IconSize::Small)
+            .color(Color::Error);
+
+        Callout::new()
+            .icon(icon)
+            .title("Error")
+            .description(error.clone())
+            .secondary_action(self.create_copy_button(error.to_string()))
+            .primary_action(self.dismiss_error_button(cx))
+            .bg_color(self.error_callout_bg(cx))
+    }
+
+    fn render_payment_required_error(&self, cx: &mut Context<Self>) -> Callout {
+        const ERROR_MESSAGE: &str =
+            "You reached your free usage limit. Upgrade to Zed Pro for more prompts.";
+
+        let icon = Icon::new(IconName::XCircle)
+            .size(IconSize::Small)
+            .color(Color::Error);
+
+        Callout::new()
+            .icon(icon)
+            .title("Free Usage Exceeded")
+            .description(ERROR_MESSAGE)
+            .tertiary_action(self.upgrade_button(cx))
+            .secondary_action(self.create_copy_button(ERROR_MESSAGE))
+            .primary_action(self.dismiss_error_button(cx))
+            .bg_color(self.error_callout_bg(cx))
+    }
+
+    fn render_model_request_limit_reached_error(
+        &self,
+        plan: cloud_llm_client::Plan,
+        cx: &mut Context<Self>,
+    ) -> Callout {
+        let error_message = match plan {
+            cloud_llm_client::Plan::ZedPro => "Upgrade to usage-based billing for more prompts.",
+            cloud_llm_client::Plan::ZedProTrial | cloud_llm_client::Plan::ZedFree => {
+                "Upgrade to Zed Pro for more prompts."
+            }
+        };
+
+        let icon = Icon::new(IconName::XCircle)
+            .size(IconSize::Small)
+            .color(Color::Error);
+
+        Callout::new()
+            .icon(icon)
+            .title("Model Prompt Limit Reached")
+            .description(error_message)
+            .tertiary_action(self.upgrade_button(cx))
+            .secondary_action(self.create_copy_button(error_message))
+            .primary_action(self.dismiss_error_button(cx))
+            .bg_color(self.error_callout_bg(cx))
+    }
+
+    fn render_tool_use_limit_reached_error(
+        &self,
+        window: &mut Window,
+        cx: &mut Context<Self>,
+    ) -> Option<Callout> {
+        let thread = self.as_native_thread(cx)?;
+        let supports_burn_mode = thread.read(cx).model().supports_burn_mode();
+
+        let focus_handle = self.focus_handle(cx);
+
+        let icon = Icon::new(IconName::Info)
+            .size(IconSize::Small)
+            .color(Color::Info);
+
+        Some(
+            Callout::new()
+                .icon(icon)
+                .title("Consecutive tool use limit reached.")
+                .when(supports_burn_mode, |this| {
+                    this.secondary_action(
+                        Button::new("continue-burn-mode", "Continue with Burn Mode")
+                            .style(ButtonStyle::Filled)
+                            .style(ButtonStyle::Tinted(ui::TintColor::Accent))
+                            .layer(ElevationIndex::ModalSurface)
+                            .label_size(LabelSize::Small)
+                            .key_binding(
+                                KeyBinding::for_action_in(
+                                    &ContinueWithBurnMode,
+                                    &focus_handle,
+                                    window,
+                                    cx,
+                                )
+                                .map(|kb| kb.size(rems_from_px(10.))),
+                            )
+                            .tooltip(Tooltip::text("Enable Burn Mode for unlimited tool use."))
+                            .on_click({
+                                cx.listener(move |this, _, _window, cx| {
+                                    thread.update(cx, |thread, _cx| {
+                                        thread.set_completion_mode(CompletionMode::Burn);
+                                    });
+                                    this.resume_chat(cx);
+                                })
+                            }),
+                    )
+                })
+                .primary_action(
+                    Button::new("continue-conversation", "Continue")
+                        .layer(ElevationIndex::ModalSurface)
+                        .label_size(LabelSize::Small)
+                        .key_binding(
+                            KeyBinding::for_action_in(&ContinueThread, &focus_handle, window, cx)
+                                .map(|kb| kb.size(rems_from_px(10.))),
+                        )
+                        .on_click(cx.listener(|this, _, _window, cx| {
+                            this.resume_chat(cx);
+                        })),
+                ),
+        )
+    }
+
+    fn create_copy_button(&self, message: impl Into<String>) -> impl IntoElement {
+        let message = message.into();
+
+        IconButton::new("copy", IconName::Copy)
+            .icon_size(IconSize::Small)
+            .icon_color(Color::Muted)
+            .tooltip(Tooltip::text("Copy Error Message"))
+            .on_click(move |_, _, cx| {
+                cx.write_to_clipboard(ClipboardItem::new_string(message.clone()))
+            })
+    }
+
+    fn dismiss_error_button(&self, cx: &mut Context<Self>) -> impl IntoElement {
+        IconButton::new("dismiss", IconName::Close)
+            .icon_size(IconSize::Small)
+            .icon_color(Color::Muted)
+            .tooltip(Tooltip::text("Dismiss Error"))
+            .on_click(cx.listener({
+                move |this, _, _, cx| {
+                    this.clear_thread_error(cx);
+                    cx.notify();
+                }
+            }))
+    }
+
+    fn upgrade_button(&self, cx: &mut Context<Self>) -> impl IntoElement {
+        Button::new("upgrade", "Upgrade")
+            .label_size(LabelSize::Small)
+            .style(ButtonStyle::Tinted(ui::TintColor::Accent))
+            .on_click(cx.listener({
+                move |this, _, _, cx| {
+                    this.clear_thread_error(cx);
+                    cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx));
+                }
+            }))
+    }
+
+    fn error_callout_bg(&self, cx: &Context<Self>) -> Hsla {
+        cx.theme().status().error.opacity(0.08)
+    }
+}
+
 impl Focusable for AcpThreadView {
     fn focus_handle(&self, cx: &App) -> FocusHandle {
         self.message_editor.focus_handle(cx)
@@ -3016,6 +3311,7 @@ impl Render for AcpThreadView {
             .size_full()
             .key_context("AcpThread")
             .on_action(cx.listener(Self::open_agent_diff))
+            .on_action(cx.listener(Self::toggle_burn_mode))
             .bg(cx.theme().colors().panel_background)
             .child(match &self.thread_state {
                 ThreadState::Unauthenticated { connection } => v_flex()
@@ -3100,19 +3396,7 @@ impl Render for AcpThreadView {
                 }
                 _ => this,
             })
-            .when_some(self.last_error.clone(), |el, error| {
-                el.child(
-                    div()
-                        .p_2()
-                        .text_xs()
-                        .border_t_1()
-                        .border_color(cx.theme().colors().border)
-                        .bg(cx.theme().status().error_background)
-                        .child(
-                            self.render_markdown(error, default_markdown_style(false, window, cx)),
-                        ),
-                )
-            })
+            .children(self.render_thread_error(window, cx))
             .child(self.render_message_editor(window, cx))
     }
 }
@@ -3299,8 +3583,6 @@ fn terminal_command_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
 
 #[cfg(test)]
 pub(crate) mod tests {
-    use std::path::Path;
-
     use acp_thread::StubAgentConnection;
     use agent::{TextThreadStore, ThreadStore};
     use agent_client_protocol::SessionId;
@@ -3310,6 +3592,8 @@ pub(crate) mod tests {
     use project::Project;
     use serde_json::json;
     use settings::SettingsStore;
+    use std::any::Any;
+    use std::path::Path;
 
     use super::*;
 
@@ -3547,6 +3831,10 @@ pub(crate) mod tests {
         fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
             unimplemented!()
         }
+
+        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
+            self
+        }
     }
 
     pub(crate) fn init_test(cx: &mut TestAppContext) {

crates/agent_ui/src/agent_ui.rs 🔗

@@ -5,7 +5,6 @@ mod agent_diff;
 mod agent_model_selector;
 mod agent_panel;
 mod buffer_codegen;
-mod burn_mode_tooltip;
 mod context_picker;
 mod context_server_configuration;
 mod context_strip;

crates/agent_ui/src/burn_mode_tooltip.rs 🔗

@@ -1,61 +0,0 @@
-use gpui::{Context, FontWeight, IntoElement, Render, Window};
-use ui::{prelude::*, tooltip_container};
-
-pub struct BurnModeTooltip {
-    selected: bool,
-}
-
-impl BurnModeTooltip {
-    pub fn new() -> Self {
-        Self { selected: false }
-    }
-
-    pub fn selected(mut self, selected: bool) -> Self {
-        self.selected = selected;
-        self
-    }
-}
-
-impl Render for BurnModeTooltip {
-    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
-        let (icon, color) = if self.selected {
-            (IconName::ZedBurnModeOn, Color::Error)
-        } else {
-            (IconName::ZedBurnMode, Color::Default)
-        };
-
-        let turned_on = h_flex()
-            .h_4()
-            .px_1()
-            .border_1()
-            .border_color(cx.theme().colors().border)
-            .bg(cx.theme().colors().text_accent.opacity(0.1))
-            .rounded_sm()
-            .child(
-                Label::new("ON")
-                    .size(LabelSize::XSmall)
-                    .weight(FontWeight::SEMIBOLD)
-                    .color(Color::Accent),
-            );
-
-        let title = h_flex()
-            .gap_1p5()
-            .child(Icon::new(icon).size(IconSize::Small).color(color))
-            .child(Label::new("Burn Mode"))
-            .when(self.selected, |title| title.child(turned_on));
-
-        tooltip_container(window, cx, |this, _, _| {
-            this
-                .child(title)
-                .child(
-                    div()
-                        .max_w_64()
-                        .child(
-                            Label::new("Enables models to use large context windows, unlimited tool calls, and other capabilities for expanded reasoning.")
-                                .size(LabelSize::Small)
-                                .color(Color::Muted)
-                        )
-                )
-        })
-    }
-}

crates/agent_ui/src/message_editor.rs 🔗

@@ -6,7 +6,7 @@ use crate::agent_diff::AgentDiffThread;
 use crate::agent_model_selector::AgentModelSelector;
 use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
 use crate::ui::{
-    MaxModeTooltip,
+    BurnModeTooltip,
     preview::{AgentPreview, UsageCallout},
 };
 use agent::history_store::HistoryStore;
@@ -605,7 +605,7 @@ impl MessageEditor {
                     this.toggle_burn_mode(&ToggleBurnMode, window, cx);
                 }))
                 .tooltip(move |_window, cx| {
-                    cx.new(|_| MaxModeTooltip::new().selected(burn_mode_enabled))
+                    cx.new(|_| BurnModeTooltip::new().selected(burn_mode_enabled))
                         .into()
                 })
                 .into_any_element(),

crates/agent_ui/src/text_thread_editor.rs 🔗

@@ -1,6 +1,6 @@
 use crate::{
-    burn_mode_tooltip::BurnModeTooltip,
     language_model_selector::{LanguageModelSelector, language_model_selector},
+    ui::BurnModeTooltip,
 };
 use agent_settings::{AgentSettings, CompletionMode};
 use anyhow::Result;

crates/agent_ui/src/ui/burn_mode_tooltip.rs 🔗

@@ -2,11 +2,11 @@ use crate::ToggleBurnMode;
 use gpui::{Context, FontWeight, IntoElement, Render, Window};
 use ui::{KeyBinding, prelude::*, tooltip_container};
 
-pub struct MaxModeTooltip {
+pub struct BurnModeTooltip {
     selected: bool,
 }
 
-impl MaxModeTooltip {
+impl BurnModeTooltip {
     pub fn new() -> Self {
         Self { selected: false }
     }
@@ -17,7 +17,7 @@ impl MaxModeTooltip {
     }
 }
 
-impl Render for MaxModeTooltip {
+impl Render for BurnModeTooltip {
     fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
         let (icon, color) = if self.selected {
             (IconName::ZedBurnModeOn, Color::Error)

crates/language_model/src/model/cloud_model.rs 🔗

@@ -42,6 +42,18 @@ impl fmt::Display for ModelRequestLimitReachedError {
     }
 }
 
+#[derive(Error, Debug)]
+pub struct ToolUseLimitReachedError;
+
+impl fmt::Display for ToolUseLimitReachedError {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(
+            f,
+            "Consecutive tool use limit reached. Enable Burn Mode for unlimited tool use."
+        )
+    }
+}
+
 #[derive(Clone, Default)]
 pub struct LlmApiToken(Arc<RwLock<Option<String>>>);