@@ -15,10 +15,11 @@ use agent_settings::{
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::adapt_schema_to_format;
use chrono::{DateTime, Utc};
-use client::{ModelRequestUsage, RequestUsage};
-use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
+use client::{ModelRequestUsage, RequestUsage, UserStore};
+use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
use collections::{HashMap, HashSet, IndexMap};
use fs::Fs;
+use futures::stream;
use futures::{
FutureExt,
channel::{mpsc, oneshot},
@@ -34,7 +35,7 @@ use language_model::{
LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
- LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
+ LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage, ZED_CLOUD_PROVIDER_ID,
};
use project::{
Project,
@@ -585,6 +586,7 @@ pub struct Thread {
pending_title_generation: Option<Task<()>>,
summary: Option<SharedString>,
messages: Vec<Message>,
+ user_store: Entity<UserStore>,
completion_mode: CompletionMode,
/// Holds the task that handles agent interaction until the end of the turn.
/// Survives across multiple requests as the model performs tool calls and
@@ -641,6 +643,7 @@ impl Thread {
pending_title_generation: None,
summary: None,
messages: Vec::new(),
+ user_store: project.read(cx).user_store(),
completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
running_turn: None,
pending_message: None,
@@ -820,6 +823,7 @@ impl Thread {
pending_title_generation: None,
summary: db_thread.detailed_summary,
messages: db_thread.messages,
+ user_store: project.read(cx).user_store(),
completion_mode: db_thread.completion_mode.unwrap_or_default(),
running_turn: None,
pending_message: None,
@@ -1249,12 +1253,12 @@ impl Thread {
);
log::debug!("Calling model.stream_completion, attempt {}", attempt);
- let mut events = model
- .stream_completion(request, cx)
- .await
- .map_err(|error| anyhow!(error))?;
+
+ let (mut events, mut error) = match model.stream_completion(request, cx).await {
+ Ok(events) => (events, None),
+ Err(err) => (stream::empty().boxed(), Some(err)),
+ };
let mut tool_results = FuturesUnordered::new();
- let mut error = None;
while let Some(event) = events.next().await {
log::trace!("Received completion event: {:?}", event);
match event {
@@ -1302,8 +1306,10 @@ impl Thread {
if let Some(error) = error {
attempt += 1;
- let retry =
- this.update(cx, |this, _| this.handle_completion_error(error, attempt))??;
+ let retry = this.update(cx, |this, cx| {
+ let user_store = this.user_store.read(cx);
+ this.handle_completion_error(error, attempt, user_store.plan())
+ })??;
let timer = cx.background_executor().timer(retry.duration);
event_stream.send_retry(retry);
timer.await;
@@ -1330,8 +1336,23 @@ impl Thread {
&mut self,
error: LanguageModelCompletionError,
attempt: u8,
+ plan: Option<Plan>,
) -> Result<acp_thread::RetryStatus> {
- if self.completion_mode == CompletionMode::Normal {
+ let Some(model) = self.model.as_ref() else {
+ return Err(anyhow!(error));
+ };
+
+ let auto_retry = if model.provider_id() == ZED_CLOUD_PROVIDER_ID {
+ match plan {
+ Some(Plan::V2(_)) => true,
+ Some(Plan::V1(_)) => self.completion_mode == CompletionMode::Burn,
+ None => false,
+ }
+ } else {
+ true
+ };
+
+ if !auto_retry {
return Err(anyhow!(error));
}