agent: Fix max token count mismatch when not using burn mode (#34025)

Created by Bennet Bo Fenner

Closes #31854

Release Notes:

- agent: Fixed an issue where the maximum token count would be displayed
incorrectly when burn mode was not being used.

Change summary

Cargo.lock                                   |  4 +-
Cargo.toml                                   |  2 +-
crates/agent/src/thread.rs                   | 32 +++++++++++++++++----
crates/agent/src/tool_use.rs                 | 12 ++++++--
crates/agent_ui/src/text_thread_editor.rs    |  6 ++--
crates/language_model/src/language_model.rs  | 18 +++++++++++
crates/language_models/src/provider/cloud.rs |  7 ++++
7 files changed, 64 insertions(+), 17 deletions(-)

Detailed changes

Cargo.lock

@@ -20145,9 +20145,9 @@ dependencies = [
 
 [[package]]
 name = "zed_llm_client"
-version = "0.8.5"
+version = "0.8.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c740e29260b8797ad252c202ea09a255b3cbc13f30faaf92fb6b2490336106e0"
+checksum = "6607f74dee2a18a9ce0f091844944a0e59881359ab62e0768fb0618f55d4c1dc"
 dependencies = [
  "anyhow",
  "serde",

Cargo.toml

@@ -625,7 +625,7 @@ wasmtime = { version = "29", default-features = false, features = [
 wasmtime-wasi = "29"
 which = "6.0.0"
 workspace-hack = "0.1.0"
-zed_llm_client = "= 0.8.5"
+zed_llm_client = "= 0.8.6"
 zstd = "0.11"
 
 [workspace.dependencies.async-stripe]

crates/agent/src/thread.rs

@@ -23,10 +23,11 @@ use gpui::{
 };
 use language_model::{
     ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
-    LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
-    LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
-    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError,
-    PaymentRequiredError, Role, SelectedModel, StopReason, TokenUsage,
+    LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
+    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
+    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
+    ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
+    TokenUsage,
 };
 use postage::stream::Stream as _;
 use project::{
@@ -1582,6 +1583,7 @@ impl Thread {
             tool_name,
             tool_output,
             self.configured_model.as_ref(),
+            self.completion_mode,
         );
 
         pending_tool_use
@@ -1610,6 +1612,10 @@ impl Thread {
             prompt_id: prompt_id.clone(),
         };
 
+        let completion_mode = request
+            .mode
+            .unwrap_or(zed_llm_client::CompletionMode::Normal);
+
         self.last_received_chunk_at = Some(Instant::now());
 
         let task = cx.spawn(async move |thread, cx| {
@@ -1959,7 +1965,11 @@ impl Thread {
                                                 .unwrap_or(0)
                                                 // We know the context window was exceeded in practice, so if our estimate was
                                                 // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
-                                                .max(model.max_token_count().saturating_add(1))
+                                                .max(
+                                                    model
+                                                        .max_token_count_for_mode(completion_mode)
+                                                        .saturating_add(1),
+                                                )
                                         });
                                         thread.exceeded_window_error = Some(ExceededWindowError {
                                             model_id: model.id(),
@@ -2507,6 +2517,7 @@ impl Thread {
             hallucinated_tool_name,
             Err(anyhow!("Missing tool call: {error_message}")),
             self.configured_model.as_ref(),
+            self.completion_mode,
         );
 
         cx.emit(ThreadEvent::MissingToolUse {
@@ -2533,6 +2544,7 @@ impl Thread {
             tool_name,
             Err(anyhow!("Error parsing input JSON: {error}")),
             self.configured_model.as_ref(),
+            self.completion_mode,
         );
         let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
             pending_tool_use.ui_text.clone()
@@ -2608,6 +2620,7 @@ impl Thread {
                             tool_name,
                             output,
                             thread.configured_model.as_ref(),
+                            thread.completion_mode,
                         );
                         thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                     })
@@ -3084,7 +3097,9 @@ impl Thread {
             return TotalTokenUsage::default();
         };
 
-        let max = model.model.max_token_count();
+        let max = model
+            .model
+            .max_token_count_for_mode(self.completion_mode().into());
 
         let index = self
             .messages
@@ -3111,7 +3126,9 @@ impl Thread {
     pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
         let model = self.configured_model.as_ref()?;
 
-        let max = model.model.max_token_count();
+        let max = model
+            .model
+            .max_token_count_for_mode(self.completion_mode().into());
 
         if let Some(exceeded_error) = &self.exceeded_window_error {
             if model.model.id() == exceeded_error.model_id {
@@ -3177,6 +3194,7 @@ impl Thread {
             tool_name,
             err,
             self.configured_model.as_ref(),
+            self.completion_mode,
         );
         self.tool_finished(tool_use_id.clone(), None, true, window, cx);
     }
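
The notable detail in this hunk is the clamp in the exceeded-window path: the token estimate is forced to at least one above the mode-specific window (previously the clamp always used `max_token_count()`, i.e. the normal window). A minimal standalone sketch of that clamp, assuming a free function in place of the real `Thread` method and made-up window sizes:

```rust
// Illustrative only; in the PR this logic lives inside Thread's completion
// handling and gets the window from `max_token_count_for_mode(completion_mode)`.
fn exceeded_token_estimate(estimated_tokens: Option<u64>, max_tokens_for_mode: u64) -> u64 {
    // The provider already reported that the context window was exceeded, so an
    // estimate at or below the window must be wrong; report "exceeded by 1".
    estimated_tokens
        .unwrap_or(0)
        .max(max_tokens_for_mode.saturating_add(1))
}

fn main() {
    assert_eq!(exceeded_token_estimate(None, 200_000), 200_001);
    assert_eq!(exceeded_token_estimate(Some(150_000), 200_000), 200_001);
    assert_eq!(exceeded_token_estimate(Some(250_000), 200_000), 250_000);
}
```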

crates/agent/src/tool_use.rs

@@ -2,6 +2,7 @@ use crate::{
     thread::{MessageId, PromptId, ThreadId},
     thread_store::SerializedMessage,
 };
+use agent_settings::CompletionMode;
 use anyhow::Result;
 use assistant_tool::{
     AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
@@ -11,8 +12,9 @@ use futures::{FutureExt as _, future::Shared};
 use gpui::{App, Entity, SharedString, Task, Window};
 use icons::IconName;
 use language_model::{
-    ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
-    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role,
+    ConfiguredModel, LanguageModel, LanguageModelExt, LanguageModelRequest,
+    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUse,
+    LanguageModelToolUseId, Role,
 };
 use project::Project;
 use std::sync::Arc;
@@ -400,6 +402,7 @@ impl ToolUseState {
         tool_name: Arc<str>,
         output: Result<ToolResultOutput>,
         configured_model: Option<&ConfiguredModel>,
+        completion_mode: CompletionMode,
     ) -> Option<PendingToolUse> {
         let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
 
@@ -426,7 +429,10 @@ impl ToolUseState {
 
                 // Protect from overly large output
                 let tool_output_limit = configured_model
-                    .map(|model| model.model.max_token_count() as usize * BYTES_PER_TOKEN_ESTIMATE)
+                    .map(|model| {
+                        model.model.max_token_count_for_mode(completion_mode.into()) as usize
+                            * BYTES_PER_TOKEN_ESTIMATE
+                    })
                     .unwrap_or(usize::MAX);
 
                 let content = match tool_result {
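
The tool-output cap now derives its byte budget from the same mode-aware window. A rough sketch of the calculation, assuming a hypothetical free function and a placeholder value for `BYTES_PER_TOKEN_ESTIMATE` (the real constant is defined in the crate):

```rust
// Placeholder value; the actual constant lives in crates/agent/src/tool_use.rs.
const BYTES_PER_TOKEN_ESTIMATE: usize = 3;

// With no configured model the output is effectively uncapped.
fn tool_output_limit(max_tokens_for_mode: Option<u64>) -> usize {
    max_tokens_for_mode
        .map(|max| max as usize * BYTES_PER_TOKEN_ESTIMATE)
        .unwrap_or(usize::MAX)
}

fn main() {
    assert_eq!(tool_output_limit(Some(200_000)), 600_000);
    assert_eq!(tool_output_limit(None), usize::MAX);
}
```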

crates/agent_ui/src/text_thread_editor.rs

@@ -38,8 +38,8 @@ use language::{
     language_settings::{SoftWrap, all_language_settings},
 };
 use language_model::{
-    ConfigurationError, LanguageModelImage, LanguageModelProviderTosView, LanguageModelRegistry,
-    Role,
+    ConfigurationError, LanguageModelExt, LanguageModelImage, LanguageModelProviderTosView,
+    LanguageModelRegistry, Role,
 };
 use multi_buffer::MultiBufferRow;
 use picker::{Picker, popover_menu::PickerPopoverMenu};
@@ -3063,7 +3063,7 @@ fn token_state(context: &Entity<AssistantContext>, cx: &App) -> Option<TokenStat
         .default_model()?
         .model;
     let token_count = context.read(cx).token_count()?;
-    let max_token_count = model.max_token_count();
+    let max_token_count = model.max_token_count_for_mode(context.read(cx).completion_mode().into());
     let token_state = if max_token_count.saturating_sub(token_count) == 0 {
         TokenState::NoTokensLeft {
             max_token_count,
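
Here the mode-aware maximum feeds the token-state indicator, so "no tokens left" is judged against the window the request will actually get. A simplified stand-in (the real `TokenState` has more variants and fields):

```rust
// Simplified stand-in for the editor's token-state computation.
enum TokenState {
    NoTokensLeft { max_token_count: u64 },
    HasTokens { remaining: u64, max_token_count: u64 },
}

fn token_state(token_count: u64, max_token_count: u64) -> TokenState {
    match max_token_count.saturating_sub(token_count) {
        0 => TokenState::NoTokensLeft { max_token_count },
        remaining => TokenState::HasTokens { remaining, max_token_count },
    }
}

fn main() {
    // With a 200k normal-mode window fully used, the indicator flips to
    // "no tokens left" even if burn mode would allow more.
    match token_state(200_000, 200_000) {
        TokenState::NoTokensLeft { .. } => println!("no tokens left"),
        TokenState::HasTokens { remaining, .. } => println!("{remaining} tokens remaining"),
    }
}
```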

crates/language_model/src/language_model.rs

@@ -26,7 +26,7 @@ use std::time::Duration;
 use std::{fmt, io};
 use thiserror::Error;
 use util::serde::is_default;
-use zed_llm_client::CompletionRequestStatus;
+use zed_llm_client::{CompletionMode, CompletionRequestStatus};
 
 pub use crate::model::*;
 pub use crate::rate_limiter::*;
@@ -462,6 +462,10 @@ pub trait LanguageModel: Send + Sync {
     }
 
     fn max_token_count(&self) -> u64;
+    /// Returns the maximum token count for this model in burn mode (If `supports_burn_mode` is `false` this returns `None`)
+    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
+        None
+    }
     fn max_output_tokens(&self) -> Option<u64> {
         None
     }
@@ -557,6 +561,18 @@ pub trait LanguageModel: Send + Sync {
     }
 }
 
+pub trait LanguageModelExt: LanguageModel {
+    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
+        match mode {
+            CompletionMode::Normal => self.max_token_count(),
+            CompletionMode::Max => self
+                .max_token_count_in_burn_mode()
+                .unwrap_or_else(|| self.max_token_count()),
+        }
+    }
+}
+impl LanguageModelExt for dyn LanguageModel {}
+
 pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
     fn name() -> String;
     fn description() -> String;
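
`LanguageModelExt` is an extension trait with a single provided method plus a blanket impl for `dyn LanguageModel`, so the mode-aware lookup is available on trait objects without every provider having to implement it. A self-contained illustration of the pattern, using stand-in types and made-up window sizes rather than the real crate:

```rust
use std::sync::Arc;

// Stand-in for zed_llm_client::CompletionMode.
#[derive(Clone, Copy)]
enum CompletionMode {
    Normal,
    Max, // "burn mode"
}

trait LanguageModel {
    fn max_token_count(&self) -> u64;
    // Providers without burn mode keep the default `None`.
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
}

trait LanguageModelExt: LanguageModel {
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}
impl LanguageModelExt for dyn LanguageModel {}

struct FakeModel;
impl LanguageModel for FakeModel {
    fn max_token_count(&self) -> u64 {
        200_000
    }
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        Some(1_000_000)
    }
}

fn main() {
    let model: Arc<dyn LanguageModel> = Arc::new(FakeModel);
    // Callers only need the extension trait in scope
    // (`use language_model::LanguageModelExt as _;` in the real crates).
    assert_eq!(model.max_token_count_for_mode(CompletionMode::Normal), 200_000);
    assert_eq!(model.max_token_count_for_mode(CompletionMode::Max), 1_000_000);
}
```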

crates/language_models/src/provider/cloud.rs

@@ -730,6 +730,13 @@ impl LanguageModel for CloudLanguageModel {
         self.model.max_token_count as u64
     }
 
+    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
+        self.model
+            .max_token_count_in_max_mode
+            .filter(|_| self.model.supports_max_mode)
+            .map(|max_token_count| max_token_count as u64)
+    }
+
     fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
         match &self.model.provider {
             zed_llm_client::LanguageModelProvider::Anthropic => {