Simplify error management in stream_completion (#43035)

Created by Mikayla Maki and Michael Benfield

This PR simplifies error and event handling by removing the
`Ok(LanguageModelCompletionEvent::StatusUpdate(CompletionRequestStatus::Failed { .. }))`
state from the stream returned by `LanguageModel::stream_completion()`,
turning it into an `Err(LanguageModelCompletionError)` instead. The
remaining valid `CompletionRequestStatus` values were collapsed into
dedicated `LanguageModelCompletionEvent` variants (`Queued`, `Started`,
`UsageUpdated`, and `ToolUseLimitReached`).
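
For illustration only (not part of this PR's diff), here is a minimal sketch of what consuming the stream looks like after this change, assuming the types from the `language_model` crate; `drain_completion` is a hypothetical helper, not code from the PR:

use futures::StreamExt as _;
use futures::stream::BoxStream;
use language_model::{LanguageModelCompletionError, LanguageModelCompletionEvent};

// Hypothetical consumer: failures now surface as `Err(...)` items in the
// stream itself, so the event match needs no `Failed` arm.
async fn drain_completion(
    mut events: BoxStream<
        'static,
        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
    >,
) -> Result<String, LanguageModelCompletionError> {
    let mut text = String::new();
    while let Some(event) = events.next().await {
        match event? {
            LanguageModelCompletionEvent::Text(chunk) => text.push_str(&chunk),
            // Request lifecycle statuses are now plain event variants.
            LanguageModelCompletionEvent::Queued { .. }
            | LanguageModelCompletionEvent::Started
            | LanguageModelCompletionEvent::UsageUpdated { .. }
            | LanguageModelCompletionEvent::ToolUseLimitReached => {}
            LanguageModelCompletionEvent::Stop(_) => break,
            // Thinking, tool-use, and token-usage events elided here.
            _ => {}
        }
    }
    Ok(text)
}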

Release Notes:

- N/A

---------

Co-authored-by: Michael Benfield <mbenfield@zed.dev>

Change summary

crates/agent/src/tests/mod.rs                   |  8 --
crates/agent/src/thread.rs                      | 20 ++-----
crates/assistant_text_thread/src/text_thread.rs | 20 ++++---
crates/eval/src/instance.rs                     | 12 +++-
crates/language_model/src/language_model.rs     | 48 +++++++++++++++++-
crates/language_models/src/provider/cloud.rs    | 27 +++++++++
6 files changed, 98 insertions(+), 37 deletions(-)

Detailed changes

crates/agent/src/tests/mod.rs

@@ -664,9 +664,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
     );
 
     // Simulate reaching tool use limit.
-    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
-        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
-    ));
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
     fake_model.end_last_completion_stream();
     let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
     assert!(
@@ -749,9 +747,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
     };
     fake_model
         .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
-    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
-        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
-    ));
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
     fake_model.end_last_completion_stream();
     let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
     assert!(

crates/agent/src/thread.rs

@@ -15,7 +15,7 @@ use agent_settings::{
 use anyhow::{Context as _, Result, anyhow};
 use chrono::{DateTime, Utc};
 use client::{ModelRequestUsage, RequestUsage, UserStore};
-use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
+use cloud_llm_client::{CompletionIntent, Plan, UsageLimit};
 use collections::{HashMap, HashSet, IndexMap};
 use fs::Fs;
 use futures::stream;
@@ -1430,20 +1430,16 @@ impl Thread {
                 );
                 self.update_token_usage(usage, cx);
             }
-            StatusUpdate(CompletionRequestStatus::UsageUpdated { amount, limit }) => {
+            UsageUpdated { amount, limit } => {
                 self.update_model_request_usage(amount, limit, cx);
             }
-            StatusUpdate(
-                CompletionRequestStatus::Started
-                | CompletionRequestStatus::Queued { .. }
-                | CompletionRequestStatus::Failed { .. },
-            ) => {}
-            StatusUpdate(CompletionRequestStatus::ToolUseLimitReached) => {
+            ToolUseLimitReached => {
                 self.tool_use_limit_reached = true;
             }
             Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()),
             Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()),
             Stop(StopReason::ToolUse | StopReason::EndTurn) => {}
+            Started | Queued { .. } => {}
         }
 
         Ok(None)
@@ -1687,9 +1683,7 @@ impl Thread {
                     let event = event.log_err()?;
                     let text = match event {
                         LanguageModelCompletionEvent::Text(text) => text,
-                        LanguageModelCompletionEvent::StatusUpdate(
-                            CompletionRequestStatus::UsageUpdated { amount, limit },
-                        ) => {
+                        LanguageModelCompletionEvent::UsageUpdated { amount, limit } => {
                             this.update(cx, |thread, cx| {
                                 thread.update_model_request_usage(amount, limit, cx);
                             })
@@ -1753,9 +1747,7 @@ impl Thread {
                     let event = event?;
                     let text = match event {
                         LanguageModelCompletionEvent::Text(text) => text,
-                        LanguageModelCompletionEvent::StatusUpdate(
-                            CompletionRequestStatus::UsageUpdated { amount, limit },
-                        ) => {
+                        LanguageModelCompletionEvent::UsageUpdated { amount, limit } => {
                             this.update(cx, |thread, cx| {
                                 thread.update_model_request_usage(amount, limit, cx);
                             })?;

crates/assistant_text_thread/src/text_thread.rs

@@ -7,9 +7,10 @@ use assistant_slash_command::{
 use assistant_slash_commands::FileCommandMetadata;
 use client::{self, ModelRequestUsage, RequestUsage, proto, telemetry::Telemetry};
 use clock::ReplicaId;
-use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
+use cloud_llm_client::{CompletionIntent, UsageLimit};
 use collections::{HashMap, HashSet};
 use fs::{Fs, RenameOptions};
+
 use futures::{FutureExt, StreamExt, future::Shared};
 use gpui::{
     App, AppContext as _, Context, Entity, EventEmitter, RenderImage, SharedString, Subscription,
@@ -2073,14 +2074,15 @@ impl TextThread {
                                     });
 
                                 match event {
-                                    LanguageModelCompletionEvent::StatusUpdate(status_update) => {
-                                        if let CompletionRequestStatus::UsageUpdated { amount, limit } = status_update {
-                                            this.update_model_request_usage(
-                                                amount as u32,
-                                                limit,
-                                                cx,
-                                            );
-                                        }
+                                    LanguageModelCompletionEvent::Started |
+                                    LanguageModelCompletionEvent::Queued {..} |
+                                    LanguageModelCompletionEvent::ToolUseLimitReached { .. } => {}
+                                    LanguageModelCompletionEvent::UsageUpdated { amount, limit } => {
+                                        this.update_model_request_usage(
+                                            amount as u32,
+                                            limit,
+                                            cx,
+                                        );
                                     }
                                     LanguageModelCompletionEvent::StartMessage { .. } => {}
                                     LanguageModelCompletionEvent::Stop(reason) => {

crates/eval/src/instance.rs

@@ -1251,8 +1251,11 @@ pub fn response_events_to_markdown(
             }
             Ok(
                 LanguageModelCompletionEvent::UsageUpdate(_)
+                | LanguageModelCompletionEvent::ToolUseLimitReached
                 | LanguageModelCompletionEvent::StartMessage { .. }
-                | LanguageModelCompletionEvent::StatusUpdate { .. },
+                | LanguageModelCompletionEvent::UsageUpdated { .. }
+                | LanguageModelCompletionEvent::Queued { .. }
+                | LanguageModelCompletionEvent::Started,
             ) => {}
             Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                 json_parse_error, ..
@@ -1337,9 +1340,12 @@ impl ThreadDialog {
                 // Skip these
                 Ok(LanguageModelCompletionEvent::UsageUpdate(_))
                 | Ok(LanguageModelCompletionEvent::RedactedThinking { .. })
-                | Ok(LanguageModelCompletionEvent::StatusUpdate { .. })
                 | Ok(LanguageModelCompletionEvent::StartMessage { .. })
-                | Ok(LanguageModelCompletionEvent::Stop(_)) => {}
+                | Ok(LanguageModelCompletionEvent::Stop(_))
+                | Ok(LanguageModelCompletionEvent::Queued { .. })
+                | Ok(LanguageModelCompletionEvent::Started)
+                | Ok(LanguageModelCompletionEvent::UsageUpdated { .. })
+                | Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => {}
 
                 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                     json_parse_error,

crates/language_model/src/language_model.rs

@@ -12,7 +12,7 @@ pub mod fake_provider;
 use anthropic::{AnthropicError, parse_prompt_too_long};
 use anyhow::{Result, anyhow};
 use client::Client;
-use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
+use cloud_llm_client::{CompletionMode, CompletionRequestStatus, UsageLimit};
 use futures::FutureExt;
 use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
 use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
@@ -70,7 +70,15 @@ pub fn init_settings(cx: &mut App) {
 /// A completion event from a language model.
 #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 pub enum LanguageModelCompletionEvent {
-    StatusUpdate(CompletionRequestStatus),
+    Queued {
+        position: usize,
+    },
+    Started,
+    UsageUpdated {
+        amount: usize,
+        limit: UsageLimit,
+    },
+    ToolUseLimitReached,
     Stop(StopReason),
     Text(String),
     Thinking {
@@ -93,6 +101,37 @@ pub enum LanguageModelCompletionEvent {
     UsageUpdate(TokenUsage),
 }
 
+impl LanguageModelCompletionEvent {
+    pub fn from_completion_request_status(
+        status: CompletionRequestStatus,
+        upstream_provider: LanguageModelProviderName,
+    ) -> Result<Self, LanguageModelCompletionError> {
+        match status {
+            CompletionRequestStatus::Queued { position } => {
+                Ok(LanguageModelCompletionEvent::Queued { position })
+            }
+            CompletionRequestStatus::Started => Ok(LanguageModelCompletionEvent::Started),
+            CompletionRequestStatus::UsageUpdated { amount, limit } => {
+                Ok(LanguageModelCompletionEvent::UsageUpdated { amount, limit })
+            }
+            CompletionRequestStatus::ToolUseLimitReached => {
+                Ok(LanguageModelCompletionEvent::ToolUseLimitReached)
+            }
+            CompletionRequestStatus::Failed {
+                code,
+                message,
+                request_id: _,
+                retry_after,
+            } => Err(LanguageModelCompletionError::from_cloud_failure(
+                upstream_provider,
+                code,
+                message,
+                retry_after.map(Duration::from_secs_f64),
+            )),
+        }
+    }
+}
+
 #[derive(Error, Debug)]
 pub enum LanguageModelCompletionError {
     #[error("prompt too large for context window")]
@@ -633,7 +672,10 @@ pub trait LanguageModel: Send + Sync {
                         let last_token_usage = last_token_usage.clone();
                         async move {
                             match result {
-                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
+                                Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
+                                Ok(LanguageModelCompletionEvent::Started) => None,
+                                Ok(LanguageModelCompletionEvent::UsageUpdated { .. }) => None,
+                                Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => None,
                                 Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                 Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                 Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,

crates/language_models/src/provider/cloud.rs

@@ -752,6 +752,7 @@ impl LanguageModel for CloudLanguageModel {
         let mode = request.mode;
         let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
         let thinking_allowed = request.thinking_allowed;
+        let provider_name = provider_name(&self.model.provider);
         match self.model.provider {
             cloud_llm_client::LanguageModelProvider::Anthropic => {
                 let request = into_anthropic(
@@ -801,8 +802,9 @@ impl LanguageModel for CloudLanguageModel {
                         Box::pin(
                             response_lines(response, includes_status_messages)
                                 .chain(usage_updated_event(usage))
                                 .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                         ),
+                        &provider_name,
                         move |event| mapper.map_event(event),
                     ))
                 });
@@ -849,6 +851,7 @@ impl LanguageModel for CloudLanguageModel {
                                 .chain(usage_updated_event(usage))
                                 .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                         ),
+                        &provider_name,
                         move |event| mapper.map_event(event),
                     ))
                 });
@@ -895,6 +898,7 @@ impl LanguageModel for CloudLanguageModel {
                                 .chain(usage_updated_event(usage))
                                 .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                         ),
+                        &provider_name,
                         move |event| mapper.map_event(event),
                     ))
                 });
@@ -935,6 +939,7 @@ impl LanguageModel for CloudLanguageModel {
                                 .chain(usage_updated_event(usage))
                                 .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                         ),
+                        &provider_name,
                         move |event| mapper.map_event(event),
                     ))
                 });
@@ -946,6 +951,7 @@ impl LanguageModel for CloudLanguageModel {
 
 fn map_cloud_completion_events<T, F>(
     stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
+    provider: &LanguageModelProviderName,
     mut map_callback: F,
 ) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 where
@@ -954,6 +960,7 @@ where
         + Send
         + 'static,
 {
+    let provider = provider.clone();
     stream
         .flat_map(move |event| {
             futures::stream::iter(match event {
@@ -961,7 +968,12 @@ where
                     vec![Err(LanguageModelCompletionError::from(error))]
                 }
                 Ok(CompletionEvent::Status(event)) => {
-                    vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
+                    vec![
+                        LanguageModelCompletionEvent::from_completion_request_status(
+                            event,
+                            provider.clone(),
+                        ),
+                    ]
                 }
                 Ok(CompletionEvent::Event(event)) => map_callback(event),
             })
@@ -969,6 +981,17 @@ where
         .boxed()
 }
 
+fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {
+    match provider {
+        cloud_llm_client::LanguageModelProvider::Anthropic => {
+            language_model::ANTHROPIC_PROVIDER_NAME
+        }
+        cloud_llm_client::LanguageModelProvider::OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
+        cloud_llm_client::LanguageModelProvider::Google => language_model::GOOGLE_PROVIDER_NAME,
+        cloud_llm_client::LanguageModelProvider::XAi => language_model::X_AI_PROVIDER_NAME,
+    }
+}
+
 fn usage_updated_event<T>(
     usage: Option<ModelRequestUsage>,
 ) -> impl Stream<Item = Result<CompletionEvent<T>>> {