@@ -21,6 +21,7 @@ use gpui::{
AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
WeakEntity, Window,
};
+use http_client::StatusCode;
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
@@ -51,7 +52,19 @@ use uuid::Uuid;
use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
const MAX_RETRY_ATTEMPTS: u8 = 3;
-const BASE_RETRY_DELAY_SECS: u64 = 5;
+const BASE_RETRY_DELAY: Duration = Duration::from_secs(5); // Default wait before a retry; also the initial delay for exponential backoff.
+
+#[derive(Debug, Clone)]
+enum RetryStrategy { // How a failed completion request is retried: a delay schedule plus a retry budget.
+ ExponentialBackoff { // Wait `initial_delay` before the first retry, doubling the wait on each later attempt.
+ initial_delay: Duration,
+ max_attempts: u8, // Number of retries allowed before giving up.
+ },
+ Fixed { // Wait the same `delay` before every retry.
+ delay: Duration,
+ max_attempts: u8, // Number of retries allowed before giving up.
+ },
+}
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
@@ -1933,18 +1946,6 @@ impl Thread {
project.set_agent_location(None, cx);
});
- fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<Thread>) {
- let error_message = error
- .chain()
- .map(|err| err.to_string())
- .collect::<Vec<_>>()
- .join("\n");
- cx.emit(ThreadEvent::ShowError(ThreadError::Message {
- header: "Error interacting with language model".into(),
- message: SharedString::from(error_message.clone()),
- }));
- }
-
if error.is::<PaymentRequiredError>() {
cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
} else if let Some(error) =
@@ -1956,9 +1957,10 @@ impl Thread {
} else if let Some(completion_error) =
error.downcast_ref::<LanguageModelCompletionError>()
{
- use LanguageModelCompletionError::*;
match &completion_error {
- PromptTooLarge { tokens, .. } => {
+ LanguageModelCompletionError::PromptTooLarge {
+ tokens, ..
+ } => {
let tokens = tokens.unwrap_or_else(|| {
// We didn't get an exact token count from the API, so fall back on our estimate.
thread
@@ -1979,63 +1981,22 @@ impl Thread {
});
cx.notify();
}
- RateLimitExceeded {
- retry_after: Some(retry_after),
- ..
- }
- | ServerOverloaded {
- retry_after: Some(retry_after),
- ..
- } => {
- thread.handle_rate_limit_error(
- &completion_error,
- *retry_after,
- model.clone(),
- intent,
- window,
- cx,
- );
- retry_scheduled = true;
- }
- RateLimitExceeded { .. } | ServerOverloaded { .. } => {
- retry_scheduled = thread.handle_retryable_error(
- &completion_error,
- model.clone(),
- intent,
- window,
- cx,
- );
- if !retry_scheduled {
- emit_generic_error(error, cx);
- }
- }
- ApiInternalServerError { .. }
- | ApiReadResponseError { .. }
- | HttpSend { .. } => {
- retry_scheduled = thread.handle_retryable_error(
- &completion_error,
- model.clone(),
- intent,
- window,
- cx,
- );
- if !retry_scheduled {
- emit_generic_error(error, cx);
+ _ => {
+ if let Some(retry_strategy) =
+ Thread::get_retry_strategy(completion_error)
+ {
+ retry_scheduled = thread
+ .handle_retryable_error_with_delay(
+ &completion_error,
+ Some(retry_strategy),
+ model.clone(),
+ intent,
+ window,
+ cx,
+ );
}
}
- NoApiKey { .. }
- | HttpResponseError { .. }
- | BadRequestFormat { .. }
- | AuthenticationError { .. }
- | PermissionError { .. }
- | ApiEndpointNotFound { .. }
- | SerializeRequest { .. }
- | BuildRequestBody { .. }
- | DeserializeResponse { .. }
- | Other { .. } => emit_generic_error(error, cx),
}
- } else {
- emit_generic_error(error, cx);
}
if !retry_scheduled {
@@ -2162,73 +2123,86 @@ impl Thread {
});
}
- fn handle_rate_limit_error(
- &mut self,
- error: &LanguageModelCompletionError,
- retry_after: Duration,
- model: Arc<dyn LanguageModel>,
- intent: CompletionIntent,
- window: Option<AnyWindowHandle>,
- cx: &mut Context<Self>,
- ) {
- // For rate limit errors, we only retry once with the specified duration
- let retry_message = format!("{error}. Retrying in {} seconds…", retry_after.as_secs());
- log::warn!(
- "Retrying completion request in {} seconds: {error:?}",
- retry_after.as_secs(),
- );
-
- // Add a UI-only message instead of a regular message
- let id = self.next_message_id.post_inc();
- self.messages.push(Message {
- id,
- role: Role::System,
- segments: vec![MessageSegment::Text(retry_message)],
- loaded_context: LoadedContext::default(),
- creases: Vec::new(),
- is_hidden: false,
- ui_only: true,
- });
- cx.emit(ThreadEvent::MessageAdded(id));
- // Schedule the retry
- let thread_handle = cx.entity().downgrade();
-
- cx.spawn(async move |_thread, cx| {
- cx.background_executor().timer(retry_after).await;
+ fn get_retry_strategy(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
+ use LanguageModelCompletionError::*;
- thread_handle
- .update(cx, |thread, cx| {
- // Retry the completion
- thread.send_to_model(model, intent, window, cx);
+ // Picks the retry policy for a completion error (None means "do not retry"). General strategy here:
+ // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
+ // - If it's a time-based issue, retry up to MAX_RETRY_ATTEMPTS times: server overloaded / rate limit exceeded wait a fixed delay (the server's retry_after, or BASE_RETRY_DELAY if absent), while a bare HTTP 429 uses exponential backoff.
+ // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), just retry once.
+ match error {
+ HttpResponseError {
+ status_code: StatusCode::TOO_MANY_REQUESTS,
+ ..
+ } => Some(RetryStrategy::ExponentialBackoff {
+ initial_delay: BASE_RETRY_DELAY,
+ max_attempts: MAX_RETRY_ATTEMPTS,
+ }),
+ ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
+ Some(RetryStrategy::Fixed {
+ delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
+ max_attempts: MAX_RETRY_ATTEMPTS,
})
- .log_err();
- })
- .detach();
- }
-
- fn handle_retryable_error(
- &mut self,
- error: &LanguageModelCompletionError,
- model: Arc<dyn LanguageModel>,
- intent: CompletionIntent,
- window: Option<AnyWindowHandle>,
- cx: &mut Context<Self>,
- ) -> bool {
- self.handle_retryable_error_with_delay(error, None, model, intent, window, cx)
+ }
+ ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
+ delay: BASE_RETRY_DELAY,
+ max_attempts: 1,
+ }),
+ ApiReadResponseError { .. }
+ | HttpSend { .. }
+ | DeserializeResponse { .. }
+ | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
+ delay: BASE_RETRY_DELAY,
+ max_attempts: 1,
+ }),
+ // Retrying these errors definitely shouldn't help.
+ HttpResponseError {
+ status_code:
+ StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
+ ..
+ }
+ | SerializeRequest { .. }
+ | BuildRequestBody { .. }
+ | PromptTooLarge { .. }
+ | AuthenticationError { .. }
+ | PermissionError { .. }
+ | ApiEndpointNotFound { .. }
+ | NoApiKey { .. } => None,
+ // Retry all other 4xx and 5xx errors once (the specific status codes above are matched first).
+ HttpResponseError { status_code, .. }
+ if status_code.is_client_error() || status_code.is_server_error() =>
+ {
+ Some(RetryStrategy::Fixed {
+ delay: BASE_RETRY_DELAY,
+ max_attempts: 1,
+ })
+ }
+ // Conservatively assume that any other errors are non-retryable
+ HttpResponseError { .. } | Other(..) => None,
+ }
}
fn handle_retryable_error_with_delay(
&mut self,
error: &LanguageModelCompletionError,
- custom_delay: Option<Duration>,
+ strategy: Option<RetryStrategy>,
model: Arc<dyn LanguageModel>,
intent: CompletionIntent,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) -> bool {
+ let Some(strategy) = strategy.or_else(|| Self::get_retry_strategy(error)) else {
+ return false;
+ };
+
+ let max_attempts = match &strategy {
+ RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
+ RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
+ };
+
let retry_state = self.retry_state.get_or_insert(RetryState {
attempt: 0,
- max_attempts: MAX_RETRY_ATTEMPTS,
+ max_attempts,
intent,
});
@@ -2238,20 +2212,24 @@ impl Thread {
let intent = retry_state.intent;
if attempt <= max_attempts {
- // Use custom delay if provided (e.g., from rate limit), otherwise exponential backoff
- let delay = if let Some(custom_delay) = custom_delay {
- custom_delay
- } else {
- let delay_secs = BASE_RETRY_DELAY_SECS * 2u64.pow((attempt - 1) as u32);
- Duration::from_secs(delay_secs)
+ let delay = match &strategy {
+ RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
+ let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
+ Duration::from_secs(delay_secs)
+ }
+ RetryStrategy::Fixed { delay, .. } => *delay,
};
// Add a transient message to inform the user
let delay_secs = delay.as_secs();
- let retry_message = format!(
- "{error}. Retrying (attempt {attempt} of {max_attempts}) \
- in {delay_secs} seconds..."
- );
+ let retry_message = if max_attempts == 1 {
+ format!("{error}. Retrying in {delay_secs} seconds...")
+ } else {
+ format!(
+ "{error}. Retrying (attempt {attempt} of {max_attempts}) \
+ in {delay_secs} seconds..."
+ )
+ };
log::warn!(
"Retrying completion request (attempt {attempt} of {max_attempts}) \
in {delay_secs} seconds: {error:?}",
@@ -2290,19 +2268,9 @@ impl Thread {
// Max retries exceeded
self.retry_state = None;
- let notification_text = if max_attempts == 1 {
- "Failed after retrying.".into()
- } else {
- format!("Failed after retrying {} times.", max_attempts).into()
- };
-
// Stop generating since we're giving up on retrying.
self.pending_completions.clear();
- cx.emit(ThreadEvent::RetriesFailed {
- message: notification_text,
- });
-
false
}
}
@@ -3258,9 +3226,6 @@ pub enum ThreadEvent {
CancelEditing,
CompletionCanceled,
ProfileChanged,
- RetriesFailed {
- message: SharedString,
- },
}
impl EventEmitter<ThreadEvent> for Thread {}
@@ -4192,7 +4157,7 @@ fn main() {{
assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
assert_eq!(
retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
- "Should have default max attempts"
+ "Should retry MAX_RETRY_ATTEMPTS times for overloaded errors"
);
});
@@ -4265,7 +4230,7 @@ fn main() {{
let retry_state = thread.retry_state.as_ref().unwrap();
assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
assert_eq!(
- retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
+ retry_state.max_attempts, 1,
"Should have correct max attempts"
);
});
@@ -4281,8 +4246,8 @@ fn main() {{
if let MessageSegment::Text(text) = seg {
text.contains("internal")
&& text.contains("Fake")
- && text
- .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
+ && text.contains("Retrying in")
+ && !text.contains("attempt")
} else {
false
}
@@ -4320,8 +4285,8 @@ fn main() {{
let project = create_test_project(cx, json!({})).await;
let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
- // Create model that returns overloaded error
- let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
+ // Create model that returns internal server error
+ let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
// Insert a user message
thread.update(cx, |thread, cx| {
@@ -4371,50 +4336,17 @@ fn main() {{
assert!(thread.retry_state.is_some(), "Should have retry state");
let retry_state = thread.retry_state.as_ref().unwrap();
assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
- });
-
- // Advance clock for first retry
- cx.executor()
- .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
- cx.run_until_parked();
-
- // Should have scheduled second retry - count retry messages
- let retry_count = thread.update(cx, |thread, _| {
- thread
- .messages
- .iter()
- .filter(|m| {
- m.ui_only
- && m.segments.iter().any(|s| {
- if let MessageSegment::Text(text) = s {
- text.contains("Retrying") && text.contains("seconds")
- } else {
- false
- }
- })
- })
- .count()
- });
- assert_eq!(retry_count, 2, "Should have scheduled second retry");
-
- // Check retry state updated
- thread.read_with(cx, |thread, _| {
- assert!(thread.retry_state.is_some(), "Should have retry state");
- let retry_state = thread.retry_state.as_ref().unwrap();
- assert_eq!(retry_state.attempt, 2, "Should be second retry attempt");
assert_eq!(
- retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
- "Should have correct max attempts"
+ retry_state.max_attempts, 1,
+ "Internal server errors should only retry once"
);
});
- // Advance clock for second retry (exponential backoff)
- cx.executor()
- .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2));
+ // Advance clock for first retry
+ cx.executor().advance_clock(BASE_RETRY_DELAY);
cx.run_until_parked();
- // Should have scheduled third retry
- // Count all retry messages now
+ // No second retry should have been scheduled - count retry messages
let retry_count = thread.update(cx, |thread, _| {
thread
.messages
@@ -4432,56 +4364,24 @@ fn main() {{
.count()
});
assert_eq!(
- retry_count, MAX_RETRY_ATTEMPTS as usize,
- "Should have scheduled third retry"
+ retry_count, 1,
+ "Should have only one retry for internal server errors"
);
- // Check retry state updated
+ // For internal server errors, we only retry once and then give up
+ // Check that retry_state is cleared after the single retry
thread.read_with(cx, |thread, _| {
- assert!(thread.retry_state.is_some(), "Should have retry state");
- let retry_state = thread.retry_state.as_ref().unwrap();
- assert_eq!(
- retry_state.attempt, MAX_RETRY_ATTEMPTS,
- "Should be at max retry attempt"
- );
- assert_eq!(
- retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
- "Should have correct max attempts"
+ assert!(
+ thread.retry_state.is_none(),
+ "Retry state should be cleared after single retry"
);
});
- // Advance clock for third retry (exponential backoff)
- cx.executor()
- .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4));
- cx.run_until_parked();
-
- // No more retries should be scheduled after clock was advanced.
- let retry_count = thread.update(cx, |thread, _| {
- thread
- .messages
- .iter()
- .filter(|m| {
- m.ui_only
- && m.segments.iter().any(|s| {
- if let MessageSegment::Text(text) = s {
- text.contains("Retrying") && text.contains("seconds")
- } else {
- false
- }
- })
- })
- .count()
- });
- assert_eq!(
- retry_count, MAX_RETRY_ATTEMPTS as usize,
- "Should not exceed max retries"
- );
-
- // Final completion count should be initial + max retries
+ // Verify total attempts (1 initial + 1 retry)
assert_eq!(
*completion_count.lock(),
- (MAX_RETRY_ATTEMPTS + 1) as usize,
- "Should have made initial + max retry attempts"
+ 2,
+ "Should have attempted once plus 1 retry"
);
}
@@ -4501,13 +4401,13 @@ fn main() {{
});
// Track events
- let retries_failed = Arc::new(Mutex::new(false));
- let retries_failed_clone = retries_failed.clone();
+ let stopped_with_error = Arc::new(Mutex::new(false));
+ let stopped_with_error_clone = stopped_with_error.clone();
let _subscription = thread.update(cx, |_, cx| {
cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
- if let ThreadEvent::RetriesFailed { .. } = event {
- *retries_failed_clone.lock() = true;
+ if let ThreadEvent::Stopped(Err(_)) = event {
+ *stopped_with_error_clone.lock() = true;
}
})
});
@@ -4519,23 +4419,11 @@ fn main() {{
cx.run_until_parked();
// Advance through all retries
- for i in 0..MAX_RETRY_ATTEMPTS {
- let delay = if i == 0 {
- BASE_RETRY_DELAY_SECS
- } else {
- BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1)
- };
- cx.executor().advance_clock(Duration::from_secs(delay));
+ for _ in 0..MAX_RETRY_ATTEMPTS {
+ cx.executor().advance_clock(BASE_RETRY_DELAY);
cx.run_until_parked();
}
- // After the 3rd retry is scheduled, we need to wait for it to execute and fail
- // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds)
- let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32);
- cx.executor()
- .advance_clock(Duration::from_secs(final_delay));
- cx.run_until_parked();
-
let retry_count = thread.update(cx, |thread, _| {
thread
.messages
@@ -4553,14 +4441,14 @@ fn main() {{
.count()
});
- // After max retries, should emit RetriesFailed event
+ // After max retries, should emit Stopped(Err(...)) event
assert_eq!(
retry_count, MAX_RETRY_ATTEMPTS as usize,
- "Should have attempted max retries"
+ "Should have attempted MAX_RETRY_ATTEMPTS retries for overloaded errors"
);
assert!(
- *retries_failed.lock(),
- "Should emit RetriesFailed event after max retries exceeded"
+ *stopped_with_error.lock(),
+ "Should emit Stopped(Err(...)) event after max retries exceeded"
);
// Retry state should be cleared
@@ -4578,7 +4466,7 @@ fn main() {{
.count();
assert_eq!(
retry_messages, MAX_RETRY_ATTEMPTS as usize,
- "Should have one retry message per attempt"
+ "Should have MAX_RETRY_ATTEMPTS retry messages for overloaded errors"
);
});
}
@@ -4716,8 +4604,7 @@ fn main() {{
});
// Wait for retry
- cx.executor()
- .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
+ cx.executor().advance_clock(BASE_RETRY_DELAY);
cx.run_until_parked();
// Stream some successful content
@@ -4879,8 +4766,7 @@ fn main() {{
});
// Wait for retry delay
- cx.executor()
- .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
+ cx.executor().advance_clock(BASE_RETRY_DELAY);
cx.run_until_parked();
// The retry should now use our FailOnceModel which should succeed
@@ -5039,9 +4925,15 @@ fn main() {{
thread.read_with(cx, |thread, _| {
assert!(
- thread.retry_state.is_none(),
- "Rate limit errors should not set retry_state"
+ thread.retry_state.is_some(),
+ "Rate limit errors should set retry_state"
);
+ if let Some(retry_state) = &thread.retry_state {
+ assert_eq!(
+ retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
+ "Rate limit errors should use MAX_RETRY_ATTEMPTS"
+ );
+ }
});
// Verify we have one retry message
@@ -5074,18 +4966,15 @@ fn main() {{
.find(|msg| msg.role == Role::System && msg.ui_only)
.expect("Should have a retry message");
- // Check that the message doesn't contain attempt count
+ // Check that the message contains attempt count since we use retry_state
if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
assert!(
- !text.contains("attempt"),
- "Rate limit retry message should not contain attempt count"
+ text.contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)),
+ "Rate limit retry message should contain attempt count with MAX_RETRY_ATTEMPTS"
);
assert!(
- text.contains(&format!(
- "Retrying in {} seconds",
- TEST_RATE_LIMIT_RETRY_SECS
- )),
- "Rate limit retry message should contain retry delay"
+ text.contains("Retrying"),
+ "Rate limit retry message should contain retry text"
);
}
});