agent: Show a notice when reaching consecutive tool use limits (#29833)

Created by Marshall Bowers

This PR adds a notice when the consecutive tool use limit is reached while
using a model in normal mode.

Here's an example with the limit artificially lowered to 2 consecutive
tool uses:


https://github.com/user-attachments/assets/32da8d38-67de-4d6b-8f24-754d2518e5d4

Release Notes:

- agent: Added a notice when the consecutive tool use limit is reached while
using a model in normal mode.

Change summary

Cargo.lock                                   |  4 
Cargo.toml                                   |  2 
crates/agent/src/assistant_panel.rs          | 36 +++++++++
crates/agent/src/thread.rs                   | 29 ++++++-
crates/language_model/src/language_model.rs  |  5 
crates/language_models/src/provider/cloud.rs | 83 ++++++++++++++++++---
6 files changed, 134 insertions(+), 25 deletions(-)

Detailed changes

Cargo.lock

@@ -18826,9 +18826,9 @@ dependencies = [
 
 [[package]]
 name = "zed_llm_client"
-version = "0.7.1"
+version = "0.7.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cc9ec491b7112cb8c2fba3c17d9a349d8ab695fb1a4ef6c5c4b9fd8d7aa975c1"
+checksum = "226e0b479b3aed072d83db276866d54bce631e3a8600fcdf4f309d73389af9c7"
 dependencies = [
  "anyhow",
  "serde",

Cargo.toml

@@ -611,7 +611,7 @@ wasmtime-wasi = "29"
 which = "6.0.0"
 wit-component = "0.221"
 workspace-hack = "0.1.0"
-zed_llm_client = "0.7.1"
+zed_llm_client = "0.7.2"
 zstd = "0.11"
 
 [workspace.dependencies.async-stripe]

crates/agent/src/assistant_panel.rs

@@ -1957,6 +1957,41 @@ impl AssistantPanel {
         Some(UsageBanner::new(plan, usage).into_any_element())
     }
 
+    fn render_tool_use_limit_reached(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
+        let tool_use_limit_reached = self
+            .thread
+            .read(cx)
+            .thread()
+            .read(cx)
+            .tool_use_limit_reached();
+        if !tool_use_limit_reached {
+            return None;
+        }
+
+        let model = self
+            .thread
+            .read(cx)
+            .thread()
+            .read(cx)
+            .configured_model()?
+            .model;
+
+        let max_mode_upsell = if model.supports_max_mode() {
+            " Enable max mode for unlimited tool use."
+        } else {
+            ""
+        };
+
+        Some(
+            Banner::new()
+                .severity(ui::Severity::Info)
+                .children(h_flex().child(Label::new(format!(
+                    "Consecutive tool use limit reached.{max_mode_upsell}"
+                ))))
+                .into_any_element(),
+        )
+    }
+
     fn render_last_error(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
         let last_error = self.thread.read(cx).last_error()?;
 
@@ -2238,6 +2273,7 @@ impl Render for AssistantPanel {
             .map(|parent| match &self.active_view {
                 ActiveView::Thread { .. } => parent
                     .child(self.render_active_thread_or_empty_state(window, cx))
+                    .children(self.render_tool_use_limit_reached(cx))
                     .children(self.render_usage_banner(cx))
                     .child(h_flex().child(self.message_editor.clone()))
                     .children(self.render_last_error(cx)),
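
Note on the wiring above: `render_tool_use_limit_reached` returns an
`Option<AnyElement>` and is passed to `.children(...)` rather than
`.child(...)`, presumably because `children` accepts any `IntoIterator` and
`Option<T>` yields zero or one items. A minimal sketch of that idiom in plain
Rust (no gpui types; the child names are placeholders):

```rust
fn main() {
    // `Option<T>` implements `IntoIterator`, yielding zero or one items,
    // so a conditional child can be chained in and simply vanishes when
    // it is `None`.
    let banner: Option<&str> = None;
    let children: Vec<&str> = ["active_thread", "message_editor"]
        .into_iter()
        .chain(banner) // contributes nothing when `None`
        .collect();
    assert_eq!(children, ["active_thread", "message_editor"]);
}
```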

crates/agent/src/thread.rs

@@ -355,6 +355,7 @@ pub struct Thread {
     request_token_usage: Vec<TokenUsage>,
     cumulative_token_usage: TokenUsage,
     exceeded_window_error: Option<ExceededWindowError>,
+    tool_use_limit_reached: bool,
     feedback: Option<ThreadFeedback>,
     message_feedback: HashMap<MessageId, ThreadFeedback>,
     last_auto_capture_at: Option<Instant>,
@@ -417,6 +418,7 @@ impl Thread {
             request_token_usage: Vec::new(),
             cumulative_token_usage: TokenUsage::default(),
             exceeded_window_error: None,
+            tool_use_limit_reached: false,
             feedback: None,
             message_feedback: HashMap::default(),
             last_auto_capture_at: None,
@@ -524,6 +526,7 @@ impl Thread {
             request_token_usage: serialized.request_token_usage,
             cumulative_token_usage: serialized.cumulative_token_usage,
             exceeded_window_error: None,
+            tool_use_limit_reached: false,
             feedback: None,
             message_feedback: HashMap::default(),
             last_auto_capture_at: None,
@@ -814,6 +817,10 @@ impl Thread {
             .unwrap_or(false)
     }
 
+    pub fn tool_use_limit_reached(&self) -> bool {
+        self.tool_use_limit_reached
+    }
+
     /// Returns whether all of the tool uses have finished running.
     pub fn all_tools_finished(&self) -> bool {
         // If the only pending tool uses left are the ones with errors, then
@@ -1331,6 +1338,8 @@ impl Thread {
         window: Option<AnyWindowHandle>,
         cx: &mut Context<Self>,
     ) {
+        self.tool_use_limit_reached = false;
+
         let pending_completion_id = post_inc(&mut self.completion_count);
         let mut request_callback_parameters = if self.request_callback.is_some() {
             Some((request.clone(), Vec::new()))
@@ -1506,17 +1515,27 @@ impl Thread {
                                     });
                                 }
                             }
-                            LanguageModelCompletionEvent::QueueUpdate(queue_event) => {
+                            LanguageModelCompletionEvent::QueueUpdate(status) => {
                                 if let Some(completion) = thread
                                     .pending_completions
                                     .iter_mut()
                                     .find(|completion| completion.id == pending_completion_id)
                                 {
-                                    completion.queue_state = match queue_event {
-                                        language_model::QueueState::Queued { position } => {
-                                            QueueState::Queued { position }
+                                    let queue_state = match status {
+                                        language_model::CompletionRequestStatus::Queued {
+                                            position,
+                                        } => Some(QueueState::Queued { position }),
+                                        language_model::CompletionRequestStatus::Started => {
+                                            Some(QueueState::Started)
                                         }
-                                        language_model::QueueState::Started => QueueState::Started,
+                                        language_model::CompletionRequestStatus::ToolUseLimitReached => {
+                                            thread.tool_use_limit_reached = true;
+                                            None
+                                        }
+                                    };
+
+                                    if let Some(queue_state) = queue_state {
+                                        completion.queue_state = queue_state;
                                     }
                                 }
                             }
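
The reshaped `QueueUpdate` arm above has the match return an
`Option<QueueState>` so that the new `ToolUseLimitReached` status can flip a
flag on the thread without overwriting the completion's current queue state. A
standalone sketch of that control flow (the types are simplified stand-ins for
the ones in the diff):

```rust
#[derive(Debug, PartialEq)]
enum QueueState {
    Queued { position: usize },
    Started,
}

enum CompletionRequestStatus {
    Queued { position: usize },
    Started,
    ToolUseLimitReached,
}

fn main() {
    let mut tool_use_limit_reached = false;
    let mut queue_state = QueueState::Queued { position: 5 };

    let next = match CompletionRequestStatus::ToolUseLimitReached {
        CompletionRequestStatus::Queued { position } => Some(QueueState::Queued { position }),
        CompletionRequestStatus::Started => Some(QueueState::Started),
        CompletionRequestStatus::ToolUseLimitReached => {
            // Side effect only: record the limit, keep the queue state as-is.
            tool_use_limit_reached = true;
            None
        }
    };
    if let Some(next) = next {
        queue_state = next;
    }

    assert!(tool_use_limit_reached);
    assert_eq!(queue_state, QueueState::Queued { position: 5 });
}
```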

crates/language_model/src/language_model.rs

@@ -66,15 +66,16 @@ pub struct LanguageModelCacheConfiguration {
 
 #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
 #[serde(tag = "status", rename_all = "snake_case")]
-pub enum QueueState {
+pub enum CompletionRequestStatus {
     Queued { position: usize },
     Started,
+    ToolUseLimitReached,
 }
 
 /// A completion event from a language model.
 #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 pub enum LanguageModelCompletionEvent {
-    QueueUpdate(QueueState),
+    QueueUpdate(CompletionRequestStatus),
     Stop(StopReason),
     Text(String),
     Thinking {
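
Since the renamed enum is internally tagged (`#[serde(tag = "status")]`) with
snake_case variants, each status travels as a JSON object whose `status` field
carries the variant name. A small sketch of the resulting wire format (the
enum is re-declared from the diff; assumes `serde` with the `derive` feature
and `serde_json` as dependencies):

```rust
use serde::{Deserialize, Serialize};

#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
enum CompletionRequestStatus {
    Queued { position: usize },
    Started,
    ToolUseLimitReached,
}

fn main() {
    // The internal tag plus snake_case renaming produce:
    assert_eq!(
        serde_json::to_string(&CompletionRequestStatus::Queued { position: 3 }).unwrap(),
        r#"{"status":"queued","position":3}"#
    );
    assert_eq!(
        serde_json::to_string(&CompletionRequestStatus::ToolUseLimitReached).unwrap(),
        r#"{"status":"tool_use_limit_reached"}"#
    );
}
```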

crates/language_models/src/provider/cloud.rs

@@ -9,11 +9,12 @@ use futures::{
 use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
 use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
 use language_model::{
-    AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
-    LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
-    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
-    ModelRequestLimitReachedError, QueueState, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
+    AuthenticateError, CloudModel, CompletionRequestStatus, LanguageModel,
+    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId,
+    LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
+    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
+    ZED_CLOUD_PROVIDER_ID,
 };
 use language_model::{
     LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
@@ -38,6 +39,7 @@ use zed_llm_client::{
     CURRENT_PLAN_HEADER_NAME, CompletionBody, CountTokensBody, CountTokensResponse,
     EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
     MODEL_REQUESTS_RESOURCE_HEADER_VALUE, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
+    TOOL_USE_LIMIT_REACHED_HEADER_NAME,
 };
 
 use crate::AllLanguageModelSettings;
@@ -511,6 +513,13 @@ pub struct CloudLanguageModel {
     request_limiter: RateLimiter,
 }
 
+struct PerformLlmCompletionResponse {
+    response: Response<AsyncBody>,
+    usage: Option<RequestUsage>,
+    tool_use_limit_reached: bool,
+    includes_queue_events: bool,
+}
+
 impl CloudLanguageModel {
     const MAX_RETRIES: usize = 3;
 
@@ -518,7 +527,7 @@ impl CloudLanguageModel {
         client: Arc<Client>,
         llm_api_token: LlmApiToken,
         body: CompletionBody,
-    ) -> Result<(Response<AsyncBody>, Option<RequestUsage>, bool)> {
+    ) -> Result<PerformLlmCompletionResponse> {
         let http_client = &client.http_client();
 
         let mut token = llm_api_token.acquire(&client).await?;
@@ -545,9 +554,18 @@ impl CloudLanguageModel {
                     .headers()
                     .get("x-zed-server-supports-queueing")
                     .is_some();
+                let tool_use_limit_reached = response
+                    .headers()
+                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
+                    .is_some();
                 let usage = RequestUsage::from_headers(response.headers()).ok();
 
-                return Ok((response, usage, includes_queue_events));
+                return Ok(PerformLlmCompletionResponse {
+                    response,
+                    usage,
+                    includes_queue_events,
+                    tool_use_limit_reached,
+                });
             } else if response
                 .headers()
                 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
@@ -787,7 +805,12 @@ impl LanguageModel for CloudLanguageModel {
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream_with_usage(async move {
-                    let (response, usage, includes_queue_events) = Self::perform_llm_completion(
+                    let PerformLlmCompletionResponse {
+                        response,
+                        usage,
+                        includes_queue_events,
+                        tool_use_limit_reached,
+                    } = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -819,7 +842,10 @@ impl LanguageModel for CloudLanguageModel {
                     let mut mapper = AnthropicEventMapper::new();
                     Ok((
                         map_cloud_completion_events(
-                            Box::pin(response_lines(response, includes_queue_events)),
+                            Box::pin(
+                                response_lines(response, includes_queue_events)
+                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                            ),
                             move |event| mapper.map_event(event),
                         ),
                         usage,
@@ -836,7 +862,12 @@ impl LanguageModel for CloudLanguageModel {
                 let request = into_open_ai(request, model, model.max_output_tokens());
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream_with_usage(async move {
-                    let (response, usage, includes_queue_events) = Self::perform_llm_completion(
+                    let PerformLlmCompletionResponse {
+                        response,
+                        usage,
+                        includes_queue_events,
+                        tool_use_limit_reached,
+                    } = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -853,7 +884,10 @@ impl LanguageModel for CloudLanguageModel {
                     let mut mapper = OpenAiEventMapper::new();
                     Ok((
                         map_cloud_completion_events(
-                            Box::pin(response_lines(response, includes_queue_events)),
+                            Box::pin(
+                                response_lines(response, includes_queue_events)
+                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                            ),
                             move |event| mapper.map_event(event),
                         ),
                         usage,
@@ -870,7 +904,12 @@ impl LanguageModel for CloudLanguageModel {
                 let request = into_google(request, model.id().into());
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream_with_usage(async move {
-                    let (response, usage, includes_queue_events) = Self::perform_llm_completion(
+                    let PerformLlmCompletionResponse {
+                        response,
+                        usage,
+                        includes_queue_events,
+                        tool_use_limit_reached,
+                    } = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -883,10 +922,14 @@ impl LanguageModel for CloudLanguageModel {
                         },
                     )
                     .await?;
+
                     let mut mapper = GoogleEventMapper::new();
                     Ok((
                         map_cloud_completion_events(
-                            Box::pin(response_lines(response, includes_queue_events)),
+                            Box::pin(
+                                response_lines(response, includes_queue_events)
+                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                            ),
                             move |event| mapper.map_event(event),
                         ),
                         usage,
@@ -905,7 +948,7 @@ impl LanguageModel for CloudLanguageModel {
 #[derive(Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub enum CloudCompletionEvent<T> {
-    Queue(QueueState),
+    System(CompletionRequestStatus),
     Event(T),
 }
 
@@ -925,7 +968,7 @@ where
                 Err(error) => {
                     vec![Err(LanguageModelCompletionError::Other(error))]
                 }
-                Ok(CloudCompletionEvent::Queue(event)) => {
+                Ok(CloudCompletionEvent::System(event)) => {
                     vec![Ok(LanguageModelCompletionEvent::QueueUpdate(event))]
                 }
                 Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
@@ -934,6 +977,16 @@ where
         .boxed()
 }
 
+fn tool_use_limit_reached_event<T>(
+    tool_use_limit_reached: bool,
+) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
+    futures::stream::iter(tool_use_limit_reached.then(|| {
+        Ok(CloudCompletionEvent::System(
+            CompletionRequestStatus::ToolUseLimitReached,
+        ))
+    }))
+}
+
 fn response_lines<T: DeserializeOwned>(
     response: Response<AsyncBody>,
     includes_queue_events: bool,
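
The `tool_use_limit_reached_event` helper builds a stream of zero or one
system events from the response-header flag, and each call site chains it
after the SSE lines so the limit notice is always the final event when
present. A standalone sketch of that pattern using the `futures` crate (the
event payloads here are placeholder strings):

```rust
use futures::{executor::block_on, stream, Stream, StreamExt};

// Zero-or-one-item stream: empty when the header flag is false.
fn trailing_event(reached: bool) -> impl Stream<Item = &'static str> {
    stream::iter(reached.then(|| "tool_use_limit_reached"))
}

fn main() {
    // `.chain(...)` appends the trailing event after the main stream ends.
    let events = stream::iter(["text", "stop"]).chain(trailing_event(true));
    let collected = block_on(events.collect::<Vec<_>>());
    assert_eq!(collected, ["text", "stop", "tool_use_limit_reached"]);
}
```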