Detailed changes
@@ -20145,9 +20145,9 @@ dependencies = [
[[package]]
name = "zed_llm_client"
-version = "0.8.5"
+version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c740e29260b8797ad252c202ea09a255b3cbc13f30faaf92fb6b2490336106e0"
+checksum = "6607f74dee2a18a9ce0f091844944a0e59881359ab62e0768fb0618f55d4c1dc"
dependencies = [
"anyhow",
"serde",
@@ -625,7 +625,7 @@ wasmtime = { version = "29", default-features = false, features = [
wasmtime-wasi = "29"
which = "6.0.0"
workspace-hack = "0.1.0"
-zed_llm_client = "= 0.8.5"
+zed_llm_client = "= 0.8.6"
zstd = "0.11"
[workspace.dependencies.async-stripe]
@@ -23,10 +23,11 @@ use gpui::{
};
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
- LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
- LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
- LanguageModelToolUse, LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError,
- PaymentRequiredError, Role, SelectedModel, StopReason, TokenUsage,
+ LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
+ LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
+ LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
+ ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
+ TokenUsage,
};
use postage::stream::Stream as _;
use project::{
@@ -1582,6 +1583,7 @@ impl Thread {
tool_name,
tool_output,
self.configured_model.as_ref(),
+ self.completion_mode,
);
pending_tool_use
@@ -1610,6 +1612,10 @@ impl Thread {
prompt_id: prompt_id.clone(),
};
+ let completion_mode = request
+ .mode
+ .unwrap_or(zed_llm_client::CompletionMode::Normal);
+
self.last_received_chunk_at = Some(Instant::now());
let task = cx.spawn(async move |thread, cx| {
@@ -1959,7 +1965,11 @@ impl Thread {
.unwrap_or(0)
// We know the context window was exceeded in practice, so if our estimate was
// lower than max tokens, the estimate was wrong; return that we exceeded by 1.
- .max(model.max_token_count().saturating_add(1))
+ .max(
+ model
+ .max_token_count_for_mode(completion_mode)
+ .saturating_add(1),
+ )
});
thread.exceeded_window_error = Some(ExceededWindowError {
model_id: model.id(),
@@ -2507,6 +2517,7 @@ impl Thread {
hallucinated_tool_name,
Err(anyhow!("Missing tool call: {error_message}")),
self.configured_model.as_ref(),
+ self.completion_mode,
);
cx.emit(ThreadEvent::MissingToolUse {
@@ -2533,6 +2544,7 @@ impl Thread {
tool_name,
Err(anyhow!("Error parsing input JSON: {error}")),
self.configured_model.as_ref(),
+ self.completion_mode,
);
let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
pending_tool_use.ui_text.clone()
@@ -2608,6 +2620,7 @@ impl Thread {
tool_name,
output,
thread.configured_model.as_ref(),
+ thread.completion_mode,
);
thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
})
@@ -3084,7 +3097,9 @@ impl Thread {
return TotalTokenUsage::default();
};
- let max = model.model.max_token_count();
+ let max = model
+ .model
+ .max_token_count_for_mode(self.completion_mode().into());
let index = self
.messages
@@ -3111,7 +3126,9 @@ impl Thread {
pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
let model = self.configured_model.as_ref()?;
- let max = model.model.max_token_count();
+ let max = model
+ .model
+ .max_token_count_for_mode(self.completion_mode().into());
if let Some(exceeded_error) = &self.exceeded_window_error {
if model.model.id() == exceeded_error.model_id {
@@ -3177,6 +3194,7 @@ impl Thread {
tool_name,
err,
self.configured_model.as_ref(),
+ self.completion_mode,
);
self.tool_finished(tool_use_id.clone(), None, true, window, cx);
}
@@ -2,6 +2,7 @@ use crate::{
thread::{MessageId, PromptId, ThreadId},
thread_store::SerializedMessage,
};
+use agent_settings::CompletionMode;
use anyhow::Result;
use assistant_tool::{
AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
@@ -11,8 +12,9 @@ use futures::{FutureExt as _, future::Shared};
use gpui::{App, Entity, SharedString, Task, Window};
use icons::IconName;
use language_model::{
- ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
- LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role,
+ ConfiguredModel, LanguageModel, LanguageModelExt, LanguageModelRequest,
+ LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUse,
+ LanguageModelToolUseId, Role,
};
use project::Project;
use std::sync::Arc;
@@ -400,6 +402,7 @@ impl ToolUseState {
tool_name: Arc<str>,
output: Result<ToolResultOutput>,
configured_model: Option<&ConfiguredModel>,
+ completion_mode: CompletionMode,
) -> Option<PendingToolUse> {
let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
@@ -426,7 +429,10 @@ impl ToolUseState {
// Protect from overly large output
let tool_output_limit = configured_model
- .map(|model| model.model.max_token_count() as usize * BYTES_PER_TOKEN_ESTIMATE)
+ .map(|model| {
+ model.model.max_token_count_for_mode(completion_mode.into()) as usize
+ * BYTES_PER_TOKEN_ESTIMATE
+ })
.unwrap_or(usize::MAX);
let content = match tool_result {
@@ -38,8 +38,8 @@ use language::{
language_settings::{SoftWrap, all_language_settings},
};
use language_model::{
- ConfigurationError, LanguageModelImage, LanguageModelProviderTosView, LanguageModelRegistry,
- Role,
+ ConfigurationError, LanguageModelExt, LanguageModelImage, LanguageModelProviderTosView,
+ LanguageModelRegistry, Role,
};
use multi_buffer::MultiBufferRow;
use picker::{Picker, popover_menu::PickerPopoverMenu};
@@ -3063,7 +3063,7 @@ fn token_state(context: &Entity<AssistantContext>, cx: &App) -> Option<TokenStat
.default_model()?
.model;
let token_count = context.read(cx).token_count()?;
- let max_token_count = model.max_token_count();
+ let max_token_count = model.max_token_count_for_mode(context.read(cx).completion_mode().into());
let token_state = if max_token_count.saturating_sub(token_count) == 0 {
TokenState::NoTokensLeft {
max_token_count,
@@ -26,7 +26,7 @@ use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;
-use zed_llm_client::CompletionRequestStatus;
+use zed_llm_client::{CompletionMode, CompletionRequestStatus};
pub use crate::model::*;
pub use crate::rate_limiter::*;
@@ -462,6 +462,10 @@ pub trait LanguageModel: Send + Sync {
}
fn max_token_count(&self) -> u64;
+ /// Returns the maximum token count for this model in burn mode (If `supports_burn_mode` is `false` this returns `None`)
+ fn max_token_count_in_burn_mode(&self) -> Option<u64> {
+ None
+ }
fn max_output_tokens(&self) -> Option<u64> {
None
}
@@ -557,6 +561,18 @@ pub trait LanguageModel: Send + Sync {
}
}
+pub trait LanguageModelExt: LanguageModel {
+ fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
+ match mode {
+ CompletionMode::Normal => self.max_token_count(),
+ CompletionMode::Max => self
+ .max_token_count_in_burn_mode()
+ .unwrap_or_else(|| self.max_token_count()),
+ }
+ }
+}
+impl LanguageModelExt for dyn LanguageModel {}
+
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
fn name() -> String;
fn description() -> String;
@@ -730,6 +730,13 @@ impl LanguageModel for CloudLanguageModel {
self.model.max_token_count as u64
}
+ fn max_token_count_in_burn_mode(&self) -> Option<u64> {
+ self.model
+ .max_token_count_in_max_mode
+ .filter(|_| self.model.supports_max_mode)
+ .map(|max_token_count| max_token_count as u64)
+ }
+
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
match &self.model.provider {
zed_llm_client::LanguageModelProvider::Anthropic => {