Detailed changes
@@ -1373,6 +1373,10 @@ impl AcpThread {
})
}
+ pub fn can_resume(&self, cx: &App) -> bool {
+ self.connection.resume(&self.session_id, cx).is_some()
+ }
+
pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
self.run_turn(cx, async move |this, cx| {
this.update(cx, |this, cx| {
@@ -2659,7 +2663,7 @@ mod tests {
fn truncate(
&self,
session_id: &acp::SessionId,
- _cx: &mut App,
+ _cx: &App,
) -> Option<Rc<dyn AgentSessionTruncate>> {
Some(Rc::new(FakeAgentSessionEditor {
_session_id: session_id.clone(),
@@ -43,7 +43,7 @@ pub trait AgentConnection {
fn resume(
&self,
_session_id: &acp::SessionId,
- _cx: &mut App,
+ _cx: &App,
) -> Option<Rc<dyn AgentSessionResume>> {
None
}
@@ -53,7 +53,7 @@ pub trait AgentConnection {
fn truncate(
&self,
_session_id: &acp::SessionId,
- _cx: &mut App,
+ _cx: &App,
) -> Option<Rc<dyn AgentSessionTruncate>> {
None
}
@@ -61,7 +61,7 @@ pub trait AgentConnection {
fn set_title(
&self,
_session_id: &acp::SessionId,
- _cx: &mut App,
+ _cx: &App,
) -> Option<Rc<dyn AgentSessionSetTitle>> {
None
}
@@ -439,7 +439,7 @@ mod test_support {
fn truncate(
&self,
_session_id: &agent_client_protocol::SessionId,
- _cx: &mut App,
+ _cx: &App,
) -> Option<Rc<dyn AgentSessionTruncate>> {
Some(Rc::new(StubAgentSessionEditor))
}
@@ -936,7 +936,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
fn resume(
&self,
session_id: &acp::SessionId,
- _cx: &mut App,
+ _cx: &App,
) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
Some(Rc::new(NativeAgentSessionResume {
connection: self.clone(),
@@ -956,9 +956,9 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
fn truncate(
&self,
session_id: &agent_client_protocol::SessionId,
- cx: &mut App,
+ cx: &App,
) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
- self.0.update(cx, |agent, _cx| {
+ self.0.read_with(cx, |agent, _cx| {
agent.sessions.get(session_id).map(|session| {
Rc::new(NativeAgentSessionEditor {
thread: session.thread.clone(),
@@ -971,7 +971,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
fn set_title(
&self,
session_id: &acp::SessionId,
- _cx: &mut App,
+ _cx: &App,
) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
Some(Rc::new(NativeAgentSessionSetTitle {
connection: self.clone(),
@@ -5,6 +5,7 @@ use agent_settings::AgentProfileId;
use anyhow::Result;
use client::{Client, UserStore};
use cloud_llm_client::CompletionIntent;
+use collections::IndexMap;
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
use fs::{FakeFs, Fs};
use futures::{
@@ -673,15 +674,6 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
"}
)
});
-
- // Ensure we error if calling resume when tool use limit was *not* reached.
- let error = thread
- .update(cx, |thread, cx| thread.resume(cx))
- .unwrap_err();
- assert_eq!(
- error.to_string(),
- "can only resume after tool use limit is reached"
- )
}
#[gpui::test]
@@ -2105,6 +2097,7 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
.unwrap();
cx.run_until_parked();
+ fake_model.send_last_completion_stream_text_chunk("Hey,");
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
provider: LanguageModelProviderName::new("Anthropic"),
retry_after: Some(Duration::from_secs(3)),
@@ -2114,8 +2107,9 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
cx.executor().advance_clock(Duration::from_secs(3));
cx.run_until_parked();
- fake_model.send_last_completion_stream_text_chunk("Hey!");
+ fake_model.send_last_completion_stream_text_chunk("there!");
fake_model.end_last_completion_stream();
+ cx.run_until_parked();
let mut retry_events = Vec::new();
while let Some(Ok(event)) = events.next().await {
@@ -2143,12 +2137,94 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
## Assistant
- Hey!
+ Hey,
+
+ [resume]
+
+ ## Assistant
+
+ there!
"}
)
});
}
+#[gpui::test]
+async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
+ let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ let events = thread
+ .update(cx, |thread, cx| {
+ thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
+ thread.add_tool(EchoTool);
+ thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ let tool_use_1 = LanguageModelToolUse {
+ id: "tool_1".into(),
+ name: EchoTool::name().into(),
+ raw_input: json!({"text": "test"}).to_string(),
+ input: json!({"text": "test"}),
+ is_input_complete: true,
+ };
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+ tool_use_1.clone(),
+ ));
+ fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
+ provider: LanguageModelProviderName::new("Anthropic"),
+ retry_after: Some(Duration::from_secs(3)),
+ });
+ fake_model.end_last_completion_stream();
+
+ cx.executor().advance_clock(Duration::from_secs(3));
+ let completion = fake_model.pending_completions().pop().unwrap();
+ assert_eq!(
+ completion.messages[1..],
+ vec![
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec!["Call the echo tool!".into()],
+ cache: false
+ },
+ LanguageModelRequestMessage {
+ role: Role::Assistant,
+ content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
+ cache: false
+ },
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![language_model::MessageContent::ToolResult(
+ LanguageModelToolResult {
+ tool_use_id: tool_use_1.id.clone(),
+ tool_name: tool_use_1.name.clone(),
+ is_error: false,
+ content: "test".into(),
+ output: Some("test".into())
+ }
+ )],
+ cache: true
+ },
+ ]
+ );
+
+ fake_model.send_last_completion_stream_text_chunk("Done");
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+ events.collect::<Vec<_>>().await;
+ thread.read_with(cx, |thread, _cx| {
+ assert_eq!(
+ thread.last_message(),
+ Some(Message::Agent(AgentMessage {
+ content: vec![AgentMessageContent::Text("Done".into())],
+ tool_results: IndexMap::default()
+ }))
+ );
+ })
+}
+
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
@@ -123,7 +123,7 @@ impl Message {
match self {
Message::User(message) => message.to_markdown(),
Message::Agent(message) => message.to_markdown(),
- Message::Resume => "[resumed after tool use limit was reached]".into(),
+ Message::Resume => "[resume]\n".into(),
}
}
@@ -1085,11 +1085,6 @@ impl Thread {
&mut self,
cx: &mut Context<Self>,
) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
- anyhow::ensure!(
- self.tool_use_limit_reached,
- "can only resume after tool use limit is reached"
- );
-
self.messages.push(Message::Resume);
cx.notify();
@@ -1216,12 +1211,13 @@ impl Thread {
cx: &mut AsyncApp,
) -> Result<()> {
log::debug!("Stream completion started successfully");
- let request = this.update(cx, |this, cx| {
- this.build_completion_request(completion_intent, cx)
- })??;
let mut attempt = None;
- 'retry: loop {
+ loop {
+ let request = this.update(cx, |this, cx| {
+ this.build_completion_request(completion_intent, cx)
+ })??;
+
telemetry::event!(
"Agent Thread Completion",
thread_id = this.read_with(cx, |this, _| this.id.to_string())?,
@@ -1236,10 +1232,11 @@ impl Thread {
attempt.unwrap_or(0)
);
let mut events = model
- .stream_completion(request.clone(), cx)
+ .stream_completion(request, cx)
.await
.map_err(|error| anyhow!(error))?;
let mut tool_results = FuturesUnordered::new();
+ let mut error = None;
while let Some(event) = events.next().await {
match event {
@@ -1249,51 +1246,9 @@ impl Thread {
this.handle_streamed_completion_event(event, event_stream, cx)
})??);
}
- Err(error) => {
- let completion_mode =
- this.read_with(cx, |thread, _cx| thread.completion_mode())?;
- if completion_mode == CompletionMode::Normal {
- return Err(anyhow!(error))?;
- }
-
- let Some(strategy) = Self::retry_strategy_for(&error) else {
- return Err(anyhow!(error))?;
- };
-
- let max_attempts = match &strategy {
- RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
- RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
- };
-
- let attempt = attempt.get_or_insert(0u8);
-
- *attempt += 1;
-
- let attempt = *attempt;
- if attempt > max_attempts {
- return Err(anyhow!(error))?;
- }
-
- let delay = match &strategy {
- RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
- let delay_secs =
- initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
- Duration::from_secs(delay_secs)
- }
- RetryStrategy::Fixed { delay, .. } => *delay,
- };
- log::debug!("Retry attempt {attempt} with delay {delay:?}");
-
- event_stream.send_retry(acp_thread::RetryStatus {
- last_error: error.to_string().into(),
- attempt: attempt as usize,
- max_attempts: max_attempts as usize,
- started_at: Instant::now(),
- duration: delay,
- });
-
- cx.background_executor().timer(delay).await;
- continue 'retry;
+ Err(err) => {
+ error = Some(err);
+ break;
}
}
}
@@ -1320,7 +1275,58 @@ impl Thread {
})?;
}
- return Ok(());
+ if let Some(error) = error {
+ let completion_mode = this.read_with(cx, |thread, _cx| thread.completion_mode())?;
+ if completion_mode == CompletionMode::Normal {
+ return Err(anyhow!(error))?;
+ }
+
+ let Some(strategy) = Self::retry_strategy_for(&error) else {
+ return Err(anyhow!(error))?;
+ };
+
+ let max_attempts = match &strategy {
+ RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
+ RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
+ };
+
+ let attempt = attempt.get_or_insert(0u8);
+
+ *attempt += 1;
+
+ let attempt = *attempt;
+ if attempt > max_attempts {
+ return Err(anyhow!(error))?;
+ }
+
+ let delay = match &strategy {
+ RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
+ let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
+ Duration::from_secs(delay_secs)
+ }
+ RetryStrategy::Fixed { delay, .. } => *delay,
+ };
+ log::debug!("Retry attempt {attempt} with delay {delay:?}");
+
+ event_stream.send_retry(acp_thread::RetryStatus {
+ last_error: error.to_string().into(),
+ attempt: attempt as usize,
+ max_attempts: max_attempts as usize,
+ started_at: Instant::now(),
+ duration: delay,
+ });
+ cx.background_executor().timer(delay).await;
+ this.update(cx, |this, cx| {
+ this.flush_pending_message(cx);
+ if let Some(Message::Agent(message)) = this.messages.last() {
+ if message.tool_results.is_empty() {
+ this.messages.push(Message::Resume);
+ }
+ }
+ })?;
+ } else {
+ return Ok(());
+ }
}
}
@@ -1737,6 +1743,10 @@ impl Thread {
return;
};
+ if message.content.is_empty() {
+ return;
+ }
+
for content in &message.content {
let AgentMessageContent::ToolUse(tool_use) = content else {
continue;
@@ -820,6 +820,9 @@ impl AcpThreadView {
let Some(thread) = self.thread() else {
return;
};
+ if !thread.read(cx).can_resume(cx) {
+ return;
+ }
let task = thread.update(cx, |thread, cx| thread.resume(cx));
cx.spawn(async move |this, cx| {
@@ -4459,12 +4462,53 @@ impl AcpThreadView {
}
fn render_any_thread_error(&self, error: SharedString, cx: &mut Context<'_, Self>) -> Callout {
+ let can_resume = self
+ .thread()
+ .map_or(false, |thread| thread.read(cx).can_resume(cx));
+
+ let can_enable_burn_mode = self.as_native_thread(cx).map_or(false, |thread| {
+ let thread = thread.read(cx);
+ let supports_burn_mode = thread
+ .model()
+ .map_or(false, |model| model.supports_burn_mode());
+ supports_burn_mode && thread.completion_mode() == CompletionMode::Normal
+ });
+
Callout::new()
.severity(Severity::Error)
.title("Error")
.icon(IconName::XCircle)
.description(error.clone())
- .actions_slot(self.create_copy_button(error.to_string()))
+ .actions_slot(
+ h_flex()
+ .gap_0p5()
+ .when(can_resume && can_enable_burn_mode, |this| {
+ this.child(
+ Button::new("enable-burn-mode-and-retry", "Enable Burn Mode and Retry")
+ .icon(IconName::ZedBurnMode)
+ .icon_position(IconPosition::Start)
+ .icon_size(IconSize::Small)
+ .label_size(LabelSize::Small)
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.toggle_burn_mode(&ToggleBurnMode, window, cx);
+ this.resume_chat(cx);
+ })),
+ )
+ })
+ .when(can_resume, |this| {
+ this.child(
+ Button::new("retry", "Retry")
+ .icon(IconName::RotateCw)
+ .icon_position(IconPosition::Start)
+ .icon_size(IconSize::Small)
+ .label_size(LabelSize::Small)
+ .on_click(cx.listener(|this, _, _window, cx| {
+ this.resume_chat(cx);
+ })),
+ )
+ })
+ .child(self.create_copy_button(error.to_string())),
+ )
.dismiss_action(self.dismiss_error_button(cx))
}