Detailed changes
@@ -24,6 +24,7 @@ use std::fmt::{Formatter, Write};
use std::ops::Range;
use std::process::ExitStatus;
use std::rc::Rc;
+use std::time::{Duration, Instant};
use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
use ui::App;
use util::ResultExt;
@@ -658,6 +659,15 @@ impl PlanEntry {
}
}
+#[derive(Debug, Clone)]
+pub struct RetryStatus {
+ pub last_error: SharedString,
+ pub attempt: usize,
+ pub max_attempts: usize,
+ pub started_at: Instant,
+ pub duration: Duration,
+}
+
pub struct AcpThread {
title: SharedString,
entries: Vec<AgentThreadEntry>,
@@ -676,6 +686,7 @@ pub enum AcpThreadEvent {
EntryUpdated(usize),
EntriesRemoved(Range<usize>),
ToolAuthorizationRequired,
+ Retry(RetryStatus),
Stopped,
Error,
ServerExited(ExitStatus),
@@ -916,6 +927,10 @@ impl AcpThread {
cx.emit(AcpThreadEvent::NewEntry);
}
+ pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
+ cx.emit(AcpThreadEvent::Retry(status));
+ }
+
pub fn update_tool_call(
&mut self,
update: impl Into<ToolCallUpdate>,
@@ -546,6 +546,11 @@ impl NativeAgentConnection {
thread.update_tool_call(update, cx)
})??;
}
+ AgentResponseEvent::Retry(status) => {
+ acp_thread.update(cx, |thread, cx| {
+ thread.update_retry_status(status, cx)
+ })?;
+ }
AgentResponseEvent::Stop(stop_reason) => {
log::debug!("Assistant message complete: {:?}", stop_reason);
return Ok(acp::PromptResponse { stop_reason });
@@ -6,15 +6,16 @@ use agent_settings::AgentProfileId;
use anyhow::Result;
use client::{Client, UserStore};
use fs::{FakeFs, Fs};
-use futures::channel::mpsc::UnboundedReceiver;
+use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
use gpui::{
App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
};
use indoc::indoc;
use language_model::{
- LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry,
- LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
- Role, StopReason, fake_provider::FakeLanguageModel,
+ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
+ LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
+ LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
+ fake_provider::FakeLanguageModel,
};
use pretty_assertions::assert_eq;
use project::Project;
@@ -24,7 +25,6 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
use settings::SettingsStore;
-use smol::stream::StreamExt;
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
use util::path;
@@ -1435,6 +1435,162 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
);
}
+#[gpui::test]
+async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
+ let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ let mut events = thread
+ .update(cx, |thread, cx| {
+ thread.set_completion_mode(agent_settings::CompletionMode::Burn);
+ thread.send(UserMessageId::new(), ["Hello!"], cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ fake_model.send_last_completion_stream_text_chunk("Hey!");
+ fake_model.end_last_completion_stream();
+
+ let mut retry_events = Vec::new();
+ while let Some(Ok(event)) = events.next().await {
+ match event {
+ AgentResponseEvent::Retry(retry_status) => {
+ retry_events.push(retry_status);
+ }
+ AgentResponseEvent::Stop(..) => break,
+ _ => {}
+ }
+ }
+
+ assert_eq!(retry_events.len(), 0);
+ thread.read_with(cx, |thread, _cx| {
+ assert_eq!(
+ thread.to_markdown(),
+ indoc! {"
+ ## User
+
+ Hello!
+
+ ## Assistant
+
+ Hey!
+ "}
+ )
+ });
+}
+
+#[gpui::test]
+async fn test_send_retry_on_error(cx: &mut TestAppContext) {
+ let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ let mut events = thread
+ .update(cx, |thread, cx| {
+ thread.set_completion_mode(agent_settings::CompletionMode::Burn);
+ thread.send(UserMessageId::new(), ["Hello!"], cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
+ provider: LanguageModelProviderName::new("Anthropic"),
+ retry_after: Some(Duration::from_secs(3)),
+ });
+ fake_model.end_last_completion_stream();
+
+ cx.executor().advance_clock(Duration::from_secs(3));
+ cx.run_until_parked();
+
+ fake_model.send_last_completion_stream_text_chunk("Hey!");
+ fake_model.end_last_completion_stream();
+
+ let mut retry_events = Vec::new();
+ while let Some(Ok(event)) = events.next().await {
+ match event {
+ AgentResponseEvent::Retry(retry_status) => {
+ retry_events.push(retry_status);
+ }
+ AgentResponseEvent::Stop(..) => break,
+ _ => {}
+ }
+ }
+
+ assert_eq!(retry_events.len(), 1);
+ assert!(matches!(
+ retry_events[0],
+ acp_thread::RetryStatus { attempt: 1, .. }
+ ));
+ thread.read_with(cx, |thread, _cx| {
+ assert_eq!(
+ thread.to_markdown(),
+ indoc! {"
+ ## User
+
+ Hello!
+
+ ## Assistant
+
+ Hey!
+ "}
+ )
+ });
+}
+
+#[gpui::test]
+async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
+ let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ let mut events = thread
+ .update(cx, |thread, cx| {
+ thread.set_completion_mode(agent_settings::CompletionMode::Burn);
+ thread.send(UserMessageId::new(), ["Hello!"], cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
+ fake_model.send_last_completion_stream_error(
+ LanguageModelCompletionError::ServerOverloaded {
+ provider: LanguageModelProviderName::new("Anthropic"),
+ retry_after: Some(Duration::from_secs(3)),
+ },
+ );
+ fake_model.end_last_completion_stream();
+ cx.executor().advance_clock(Duration::from_secs(3));
+ cx.run_until_parked();
+ }
+
+ let mut errors = Vec::new();
+ let mut retry_events = Vec::new();
+ while let Some(event) = events.next().await {
+ match event {
+ Ok(AgentResponseEvent::Retry(retry_status)) => {
+ retry_events.push(retry_status);
+ }
+ Ok(AgentResponseEvent::Stop(..)) => break,
+ Err(error) => errors.push(error),
+ _ => {}
+ }
+ }
+
+ assert_eq!(
+ retry_events.len(),
+ crate::thread::MAX_RETRY_ATTEMPTS as usize
+ );
+ for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
+ assert_eq!(retry_events[i].attempt, i + 1);
+ }
+ assert_eq!(errors.len(), 1);
+ let error = errors[0]
+ .downcast_ref::<LanguageModelCompletionError>()
+ .unwrap();
+ assert!(matches!(
+ error,
+ LanguageModelCompletionError::ServerOverloaded { .. }
+ ));
+}
+
/// Filters out the stop events for asserting against in tests
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
result_events
@@ -12,12 +12,12 @@ use futures::{
channel::{mpsc, oneshot},
stream::FuturesUnordered,
};
-use gpui::{App, Context, Entity, SharedString, Task};
+use gpui::{App, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
use language_model::{
- LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
- LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
- LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
- LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
+ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
+ LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
+ LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
+ LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
};
use project::Project;
use prompt_store::ProjectContext;
@@ -25,7 +25,12 @@ use schemars::{JsonSchema, Schema};
use serde::{Deserialize, Serialize};
use settings::{Settings, update_settings_file};
use smol::stream::StreamExt;
-use std::{collections::BTreeMap, path::Path, sync::Arc};
+use std::{
+ collections::BTreeMap,
+ path::Path,
+ sync::Arc,
+ time::{Duration, Instant},
+};
use std::{fmt::Write, ops::Range};
use util::{ResultExt, markdown::MarkdownCodeBlock};
use uuid::Uuid;
@@ -71,6 +76,21 @@ impl std::fmt::Display for PromptId {
}
}
+pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4;
+pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
+
+#[derive(Debug, Clone)]
+enum RetryStrategy {
+ ExponentialBackoff {
+ initial_delay: Duration,
+ max_attempts: u8,
+ },
+ Fixed {
+ delay: Duration,
+ max_attempts: u8,
+ },
+}
+
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Message {
User(UserMessage),
@@ -455,6 +475,7 @@ pub enum AgentResponseEvent {
ToolCall(acp::ToolCall),
ToolCallUpdate(acp_thread::ToolCallUpdate),
ToolCallAuthorization(ToolCallAuthorization),
+ Retry(acp_thread::RetryStatus),
Stop(acp::StopReason),
}
@@ -662,41 +683,18 @@ impl Thread {
})??;
log::info!("Calling model.stream_completion");
- let mut events = model.stream_completion(request, cx).await?;
- log::debug!("Stream completion started successfully");
let mut tool_use_limit_reached = false;
- let mut tool_uses = FuturesUnordered::new();
- while let Some(event) = events.next().await {
- match event? {
- LanguageModelCompletionEvent::StatusUpdate(
- CompletionRequestStatus::ToolUseLimitReached,
- ) => {
- tool_use_limit_reached = true;
- }
- LanguageModelCompletionEvent::Stop(reason) => {
- event_stream.send_stop(reason);
- if reason == StopReason::Refusal {
- this.update(cx, |this, _cx| {
- this.flush_pending_message();
- this.messages.truncate(message_ix);
- })?;
- return Ok(());
- }
- }
- event => {
- log::trace!("Received completion event: {:?}", event);
- this.update(cx, |this, cx| {
- tool_uses.extend(this.handle_streamed_completion_event(
- event,
- &event_stream,
- cx,
- ));
- })
- .ok();
- }
- }
- }
+ let mut tool_uses = Self::stream_completion_with_retries(
+ this.clone(),
+ model.clone(),
+ request,
+ message_ix,
+ &event_stream,
+ &mut tool_use_limit_reached,
+ cx,
+ )
+ .await?;
let used_tools = tool_uses.is_empty();
while let Some(tool_result) = tool_uses.next().await {
@@ -754,10 +752,105 @@ impl Thread {
Ok(events_rx)
}
+ async fn stream_completion_with_retries(
+ this: WeakEntity<Self>,
+ model: Arc<dyn LanguageModel>,
+ request: LanguageModelRequest,
+ message_ix: usize,
+ event_stream: &AgentResponseEventStream,
+ tool_use_limit_reached: &mut bool,
+ cx: &mut AsyncApp,
+ ) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
+ log::debug!("Stream completion started successfully");
+
+ let mut attempt = None;
+ 'retry: loop {
+ let mut events = model.stream_completion(request.clone(), cx).await?;
+ let mut tool_uses = FuturesUnordered::new();
+ while let Some(event) = events.next().await {
+ match event {
+ Ok(LanguageModelCompletionEvent::StatusUpdate(
+ CompletionRequestStatus::ToolUseLimitReached,
+ )) => {
+ *tool_use_limit_reached = true;
+ }
+ Ok(LanguageModelCompletionEvent::Stop(reason)) => {
+ event_stream.send_stop(reason);
+ if reason == StopReason::Refusal {
+ this.update(cx, |this, _cx| {
+ this.flush_pending_message();
+ this.messages.truncate(message_ix);
+ })?;
+ return Ok(tool_uses);
+ }
+ }
+ Ok(event) => {
+ log::trace!("Received completion event: {:?}", event);
+ this.update(cx, |this, cx| {
+ tool_uses.extend(this.handle_streamed_completion_event(
+ event,
+ event_stream,
+ cx,
+ ));
+ })
+ .ok();
+ }
+ Err(error) => {
+ let completion_mode =
+ this.read_with(cx, |thread, _cx| thread.completion_mode())?;
+ if completion_mode == CompletionMode::Normal {
+ return Err(error.into());
+ }
+
+ let Some(strategy) = Self::retry_strategy_for(&error) else {
+ return Err(error.into());
+ };
+
+ let max_attempts = match &strategy {
+ RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
+ RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
+ };
+
+ let attempt = attempt.get_or_insert(0u8);
+
+ *attempt += 1;
+
+ let attempt = *attempt;
+ if attempt > max_attempts {
+ return Err(error.into());
+ }
+
+ let delay = match &strategy {
+ RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
+ let delay_secs =
+ initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
+ Duration::from_secs(delay_secs)
+ }
+ RetryStrategy::Fixed { delay, .. } => *delay,
+ };
+ log::debug!("Retry attempt {attempt} with delay {delay:?}");
+
+ event_stream.send_retry(acp_thread::RetryStatus {
+ last_error: error.to_string().into(),
+ attempt: attempt as usize,
+ max_attempts: max_attempts as usize,
+ started_at: Instant::now(),
+ duration: delay,
+ });
+
+ cx.background_executor().timer(delay).await;
+ continue 'retry;
+ }
+ }
+ }
+ return Ok(tool_uses);
+ }
+ }
+
pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage {
log::debug!("Building system message");
let prompt = SystemPromptTemplate {
- project: &self.project_context.read(cx),
+ project: self.project_context.read(cx),
available_tools: self.tools.keys().cloned().collect(),
}
.render(&self.templates)
@@ -1158,6 +1251,113 @@ impl Thread {
fn advance_prompt_id(&mut self) {
self.prompt_id = PromptId::new();
}
+
+ fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
+ use LanguageModelCompletionError::*;
+ use http_client::StatusCode;
+
+ // General strategy here:
+ // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
+ // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times with exponential backoff.
+ // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
+ match error {
+ HttpResponseError {
+ status_code: StatusCode::TOO_MANY_REQUESTS,
+ ..
+ } => Some(RetryStrategy::ExponentialBackoff {
+ initial_delay: BASE_RETRY_DELAY,
+ max_attempts: MAX_RETRY_ATTEMPTS,
+ }),
+ ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
+ Some(RetryStrategy::Fixed {
+ delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
+ max_attempts: MAX_RETRY_ATTEMPTS,
+ })
+ }
+ UpstreamProviderError {
+ status,
+ retry_after,
+ ..
+ } => match *status {
+ StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
+ Some(RetryStrategy::Fixed {
+ delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
+ max_attempts: MAX_RETRY_ATTEMPTS,
+ })
+ }
+ StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
+ delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
+ // Internal Server Error could be anything, retry up to 3 times.
+ max_attempts: 3,
+ }),
+ status => {
+ // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
+ // but we frequently get them in practice. See https://http.dev/529
+ if status.as_u16() == 529 {
+ Some(RetryStrategy::Fixed {
+ delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
+ max_attempts: MAX_RETRY_ATTEMPTS,
+ })
+ } else {
+ Some(RetryStrategy::Fixed {
+ delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
+ max_attempts: 2,
+ })
+ }
+ }
+ },
+ ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
+ delay: BASE_RETRY_DELAY,
+ max_attempts: 3,
+ }),
+ ApiReadResponseError { .. }
+ | HttpSend { .. }
+ | DeserializeResponse { .. }
+ | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
+ delay: BASE_RETRY_DELAY,
+ max_attempts: 3,
+ }),
+ // Retrying these errors definitely shouldn't help.
+ HttpResponseError {
+ status_code:
+ StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
+ ..
+ }
+ | AuthenticationError { .. }
+ | PermissionError { .. }
+ | NoApiKey { .. }
+ | ApiEndpointNotFound { .. }
+ | PromptTooLarge { .. } => None,
+ // These errors might be transient, so retry them
+ SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
+ delay: BASE_RETRY_DELAY,
+ max_attempts: 1,
+ }),
+ // Retry all other 4xx and 5xx errors once.
+ HttpResponseError { status_code, .. }
+ if status_code.is_client_error() || status_code.is_server_error() =>
+ {
+ Some(RetryStrategy::Fixed {
+ delay: BASE_RETRY_DELAY,
+ max_attempts: 3,
+ })
+ }
+ Other(err)
+ if err.is::<language_model::PaymentRequiredError>()
+ || err.is::<language_model::ModelRequestLimitReachedError>() =>
+ {
+ // Retrying won't help for Payment Required or Model Request Limit errors (where
+ // the user must upgrade to usage-based billing to get more requests, or else wait
+ // for a significant amount of time for the request limit to reset).
+ None
+ }
+ // Conservatively assume that any other errors are non-retryable
+ HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
+ delay: BASE_RETRY_DELAY,
+ max_attempts: 2,
+ }),
+ }
+ }
}
struct RunningTurn {
@@ -1367,6 +1567,12 @@ impl AgentResponseEventStream {
.ok();
}
+ fn send_retry(&self, status: acp_thread::RetryStatus) {
+ self.0
+ .unbounded_send(Ok(AgentResponseEvent::Retry(status)))
+ .ok();
+ }
+
fn send_stop(&self, reason: StopReason) {
match reason {
StopReason::EndTurn => {
@@ -1,7 +1,7 @@
use acp_thread::{
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk,
- AuthRequired, LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus,
- UserMessageId,
+ AuthRequired, LoadError, MentionUri, RetryStatus, ThreadStatus, ToolCall, ToolCallContent,
+ ToolCallStatus, UserMessageId,
};
use acp_thread::{AgentConnection, Plan};
use action_log::ActionLog;
@@ -35,6 +35,7 @@ use prompt_store::PromptId;
use rope::Point;
use settings::{Settings as _, SettingsStore};
use std::sync::Arc;
+use std::time::Instant;
use std::{collections::BTreeMap, process::ExitStatus, rc::Rc, time::Duration};
use text::Anchor;
use theme::ThemeSettings;
@@ -115,6 +116,7 @@ pub struct AcpThreadView {
profile_selector: Option<Entity<ProfileSelector>>,
notifications: Vec<WindowHandle<AgentNotification>>,
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
+ thread_retry_status: Option<RetryStatus>,
thread_error: Option<ThreadError>,
list_state: ListState,
scrollbar_state: ScrollbarState,
@@ -209,6 +211,7 @@ impl AcpThreadView {
notification_subscriptions: HashMap::default(),
list_state: list_state.clone(),
scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
+ thread_retry_status: None,
thread_error: None,
auth_task: None,
expanded_tool_calls: HashSet::default(),
@@ -445,6 +448,7 @@ impl AcpThreadView {
pub fn cancel_generation(&mut self, cx: &mut Context<Self>) {
self.thread_error.take();
+ self.thread_retry_status.take();
if let Some(thread) = self.thread() {
self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx)));
@@ -775,7 +779,11 @@ impl AcpThreadView {
AcpThreadEvent::ToolAuthorizationRequired => {
self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx);
}
+ AcpThreadEvent::Retry(retry) => {
+ self.thread_retry_status = Some(retry.clone());
+ }
AcpThreadEvent::Stopped => {
+ self.thread_retry_status.take();
let used_tools = thread.read(cx).used_tools_since_last_user_message();
self.notify_with_sound(
if used_tools {
@@ -789,6 +797,7 @@ impl AcpThreadView {
);
}
AcpThreadEvent::Error => {
+ self.thread_retry_status.take();
self.notify_with_sound(
"Agent stopped due to an error",
IconName::Warning,
@@ -797,6 +806,7 @@ impl AcpThreadView {
);
}
AcpThreadEvent::ServerExited(status) => {
+ self.thread_retry_status.take();
self.thread_state = ThreadState::ServerExited { status: *status };
}
}
@@ -3413,7 +3423,51 @@ impl AcpThreadView {
})
}
- fn render_thread_error(&self, window: &mut Window, cx: &mut Context<'_, Self>) -> Option<Div> {
+ fn render_thread_retry_status_callout(
+ &self,
+ _window: &mut Window,
+ _cx: &mut Context<Self>,
+ ) -> Option<Callout> {
+ let state = self.thread_retry_status.as_ref()?;
+
+ let next_attempt_in = state
+ .duration
+ .saturating_sub(Instant::now().saturating_duration_since(state.started_at));
+ if next_attempt_in.is_zero() {
+ return None;
+ }
+
+ let next_attempt_in_secs = next_attempt_in.as_secs() + 1;
+
+ let retry_message = if state.max_attempts == 1 {
+ if next_attempt_in_secs == 1 {
+ "Retrying. Next attempt in 1 second.".to_string()
+ } else {
+ format!("Retrying. Next attempt in {next_attempt_in_secs} seconds.")
+ }
+ } else {
+ if next_attempt_in_secs == 1 {
+ format!(
+ "Retrying. Next attempt in 1 second (Attempt {} of {}).",
+ state.attempt, state.max_attempts,
+ )
+ } else {
+ format!(
+ "Retrying. Next attempt in {next_attempt_in_secs} seconds (Attempt {} of {}).",
+ state.attempt, state.max_attempts,
+ )
+ }
+ };
+
+ Some(
+ Callout::new()
+ .severity(Severity::Warning)
+ .title(state.last_error.clone())
+ .description(retry_message),
+ )
+ }
+
+ fn render_thread_error(&self, window: &mut Window, cx: &mut Context<Self>) -> Option<Div> {
let content = match self.thread_error.as_ref()? {
ThreadError::Other(error) => self.render_any_thread_error(error.clone(), cx),
ThreadError::PaymentRequired => self.render_payment_required_error(cx),
@@ -3678,6 +3732,7 @@ impl Render for AcpThreadView {
}
_ => this,
})
+ .children(self.render_thread_retry_status_callout(window, cx))
.children(self.render_thread_error(window, cx))
.child(self.render_message_editor(window, cx))
}
@@ -1523,6 +1523,7 @@ impl AgentDiff {
AcpThreadEvent::EntriesRemoved(_)
| AcpThreadEvent::Stopped
| AcpThreadEvent::ToolAuthorizationRequired
+ | AcpThreadEvent::Retry(_)
| AcpThreadEvent::Error
| AcpThreadEvent::ServerExited(_) => {}
}
@@ -4,10 +4,11 @@ use crate::{
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice,
};
-use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
+use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
use http_client::Result;
use parking_lot::Mutex;
+use smol::stream::StreamExt;
use std::sync::Arc;
#[derive(Clone)]
@@ -100,7 +101,9 @@ pub struct FakeLanguageModel {
current_completion_txs: Mutex<
Vec<(
LanguageModelRequest,
- mpsc::UnboundedSender<LanguageModelCompletionEvent>,
+ mpsc::UnboundedSender<
+ Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
+ >,
)>,
>,
}
@@ -150,7 +153,21 @@ impl FakeLanguageModel {
.find(|(req, _)| req == request)
.map(|(_, tx)| tx)
.unwrap();
- tx.unbounded_send(event.into()).unwrap();
+ tx.unbounded_send(Ok(event.into())).unwrap();
+ }
+
+ pub fn send_completion_stream_error(
+ &self,
+ request: &LanguageModelRequest,
+ error: impl Into<LanguageModelCompletionError>,
+ ) {
+ let current_completion_txs = self.current_completion_txs.lock();
+ let tx = current_completion_txs
+ .iter()
+ .find(|(req, _)| req == request)
+ .map(|(_, tx)| tx)
+ .unwrap();
+ tx.unbounded_send(Err(error.into())).unwrap();
}
pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
@@ -170,6 +187,13 @@ impl FakeLanguageModel {
self.send_completion_stream_event(self.pending_completions().last().unwrap(), event);
}
+ pub fn send_last_completion_stream_error(
+ &self,
+ error: impl Into<LanguageModelCompletionError>,
+ ) {
+ self.send_completion_stream_error(self.pending_completions().last().unwrap(), error);
+ }
+
pub fn end_last_completion_stream(&self) {
self.end_completion_stream(self.pending_completions().last().unwrap());
}
@@ -229,7 +253,7 @@ impl LanguageModel for FakeLanguageModel {
> {
let (tx, rx) = mpsc::unbounded();
self.current_completion_txs.lock().push((request, tx));
- async move { Ok(rx.map(Ok).boxed()) }.boxed()
+ async move { Ok(rx.boxed()) }.boxed()
}
fn as_fake(&self) -> &Self {