assistant2: Wire up error messages (#21426)

Marshall Bowers created

This PR wires up the error messages for Assistant 2 so that they are
shown to the user:

<img width="1138" alt="Screenshot 2024-12-02 at 4 28 02 PM"
src="https://github.com/user-attachments/assets/d8a5b9bd-0cef-4304-b561-b2edadbc70ef">
<img width="1138" alt="Screenshot 2024-12-02 at 4 29 09 PM"
src="https://github.com/user-attachments/assets/0dd70841-0d5a-4de6-bebe-82c563246b65">
<img width="1138" alt="Screenshot 2024-12-02 at 4 32 49 PM"
src="https://github.com/user-attachments/assets/a8838866-fad1-43a9-8935-490dc1936016">

@danilo-leal I kept the existing UX from Assistant 1, as I didn't see
any errors in the design prototype, but we can revisit if another
approach would work better.

Release Notes:

- N/A

Change summary

Cargo.lock                                  |   2 
crates/assistant2/Cargo.toml                |   2 
crates/assistant2/src/assistant_panel.rs    | 160 ++++++++++++++++++++++
crates/assistant2/src/thread.rs             |  56 ++++---
crates/language_model/src/language_model.rs |   2 
5 files changed, 194 insertions(+), 28 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -456,6 +456,7 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "assistant_tool",
+ "client",
  "collections",
  "command_palette_hooks",
  "context_server",
@@ -465,6 +466,7 @@ dependencies = [
  "gpui",
  "language_model",
  "language_model_selector",
+ "language_models",
  "log",
  "project",
  "proto",

crates/assistant2/Cargo.toml 🔗

@@ -15,6 +15,7 @@ doctest = false
 [dependencies]
 anyhow.workspace = true
 assistant_tool.workspace = true
+client.workspace = true
 collections.workspace = true
 command_palette_hooks.workspace = true
 context_server.workspace = true
@@ -24,6 +25,7 @@ futures.workspace = true
 gpui.workspace = true
 language_model.workspace = true
 language_model_selector.workspace = true
+language_models.workspace = true
 log.workspace = true
 project.workspace = true
 proto.workspace = true

crates/assistant2/src/assistant_panel.rs 🔗

@@ -2,9 +2,11 @@ use std::sync::Arc;
 
 use anyhow::Result;
 use assistant_tool::ToolWorkingSet;
+use client::zed_urls;
 use gpui::{
-    prelude::*, px, Action, AppContext, AsyncWindowContext, EventEmitter, FocusHandle,
-    FocusableView, Model, Pixels, Subscription, Task, View, ViewContext, WeakView, WindowContext,
+    prelude::*, px, Action, AnyElement, AppContext, AsyncWindowContext, EventEmitter, FocusHandle,
+    FocusableView, FontWeight, Model, Pixels, Subscription, Task, View, ViewContext, WeakView,
+    WindowContext,
 };
 use language_model::{LanguageModelRegistry, Role};
 use language_model_selector::LanguageModelSelector;
@@ -13,7 +15,7 @@ use workspace::dock::{DockPosition, Panel, PanelEvent};
 use workspace::Workspace;
 
 use crate::message_editor::MessageEditor;
-use crate::thread::{Message, Thread, ThreadEvent};
+use crate::thread::{Message, Thread, ThreadError, ThreadEvent};
 use crate::thread_store::ThreadStore;
 use crate::{NewThread, ToggleFocus, ToggleModelSelector};
 
@@ -35,6 +37,7 @@ pub struct AssistantPanel {
     thread: Model<Thread>,
     message_editor: View<MessageEditor>,
     tools: Arc<ToolWorkingSet>,
+    last_error: Option<ThreadError>,
     _subscriptions: Vec<Subscription>,
 }
 
@@ -76,6 +79,7 @@ impl AssistantPanel {
             thread: thread.clone(),
             message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)),
             tools,
+            last_error: None,
             _subscriptions: subscriptions,
         }
     }
@@ -102,6 +106,9 @@ impl AssistantPanel {
         cx: &mut ViewContext<Self>,
     ) {
         match event {
+            ThreadEvent::ShowError(error) => {
+                self.last_error = Some(error.clone());
+            }
             ThreadEvent::StreamedCompletion => {}
             ThreadEvent::UsePendingTools => {
                 let pending_tool_uses = self
@@ -320,6 +327,152 @@ impl AssistantPanel {
             )
             .child(v_flex().p_1p5().child(Label::new(message.text.clone())))
     }
+
+    fn render_last_error(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
+        let last_error = self.last_error.as_ref()?;
+
+        Some(
+            div()
+                .absolute()
+                .right_3()
+                .bottom_12()
+                .max_w_96()
+                .py_2()
+                .px_3()
+                .elevation_2(cx)
+                .occlude()
+                .child(match last_error {
+                    ThreadError::PaymentRequired => self.render_payment_required_error(cx),
+                    ThreadError::MaxMonthlySpendReached => {
+                        self.render_max_monthly_spend_reached_error(cx)
+                    }
+                    ThreadError::Message(error_message) => {
+                        self.render_error_message(error_message, cx)
+                    }
+                })
+                .into_any(),
+        )
+    }
+
+    fn render_payment_required_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
+        const ERROR_MESSAGE: &str = "Free tier exceeded. Subscribe and add payment to continue using Zed LLMs. You'll be billed at cost for tokens used.";
+
+        v_flex()
+            .gap_0p5()
+            .child(
+                h_flex()
+                    .gap_1p5()
+                    .items_center()
+                    .child(Icon::new(IconName::XCircle).color(Color::Error))
+                    .child(Label::new("Free Usage Exceeded").weight(FontWeight::MEDIUM)),
+            )
+            .child(
+                div()
+                    .id("error-message")
+                    .max_h_24()
+                    .overflow_y_scroll()
+                    .child(Label::new(ERROR_MESSAGE)),
+            )
+            .child(
+                h_flex()
+                    .justify_end()
+                    .mt_1()
+                    .child(Button::new("subscribe", "Subscribe").on_click(cx.listener(
+                        |this, _, cx| {
+                            this.last_error = None;
+                            cx.open_url(&zed_urls::account_url(cx));
+                            cx.notify();
+                        },
+                    )))
+                    .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
+                        |this, _, cx| {
+                            this.last_error = None;
+                            cx.notify();
+                        },
+                    ))),
+            )
+            .into_any()
+    }
+
+    fn render_max_monthly_spend_reached_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
+        const ERROR_MESSAGE: &str = "You have reached your maximum monthly spend. Increase your spend limit to continue using Zed LLMs.";
+
+        v_flex()
+            .gap_0p5()
+            .child(
+                h_flex()
+                    .gap_1p5()
+                    .items_center()
+                    .child(Icon::new(IconName::XCircle).color(Color::Error))
+                    .child(Label::new("Max Monthly Spend Reached").weight(FontWeight::MEDIUM)),
+            )
+            .child(
+                div()
+                    .id("error-message")
+                    .max_h_24()
+                    .overflow_y_scroll()
+                    .child(Label::new(ERROR_MESSAGE)),
+            )
+            .child(
+                h_flex()
+                    .justify_end()
+                    .mt_1()
+                    .child(
+                        Button::new("subscribe", "Update Monthly Spend Limit").on_click(
+                            cx.listener(|this, _, cx| {
+                                this.last_error = None;
+                                cx.open_url(&zed_urls::account_url(cx));
+                                cx.notify();
+                            }),
+                        ),
+                    )
+                    .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
+                        |this, _, cx| {
+                            this.last_error = None;
+                            cx.notify();
+                        },
+                    ))),
+            )
+            .into_any()
+    }
+
+    fn render_error_message(
+        &self,
+        error_message: &SharedString,
+        cx: &mut ViewContext<Self>,
+    ) -> AnyElement {
+        v_flex()
+            .gap_0p5()
+            .child(
+                h_flex()
+                    .gap_1p5()
+                    .items_center()
+                    .child(Icon::new(IconName::XCircle).color(Color::Error))
+                    .child(
+                        Label::new("Error interacting with language model")
+                            .weight(FontWeight::MEDIUM),
+                    ),
+            )
+            .child(
+                div()
+                    .id("error-message")
+                    .max_h_32()
+                    .overflow_y_scroll()
+                    .child(Label::new(error_message.clone())),
+            )
+            .child(
+                h_flex()
+                    .justify_end()
+                    .mt_1()
+                    .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
+                        |this, _, cx| {
+                            this.last_error = None;
+                            cx.notify();
+                        },
+                    ))),
+            )
+            .into_any()
+    }
 }
 
 impl Render for AssistantPanel {
@@ -354,5 +507,6 @@ impl Render for AssistantPanel {
                     .border_color(cx.theme().colors().border_variant)
                     .child(self.message_editor.clone()),
             )
+            .children(self.render_last_error(cx))
     }
 }

crates/assistant2/src/thread.rs 🔗

@@ -5,12 +5,13 @@ use assistant_tool::ToolWorkingSet;
 use collections::HashMap;
 use futures::future::Shared;
 use futures::{FutureExt as _, StreamExt as _};
-use gpui::{AppContext, EventEmitter, ModelContext, Task};
+use gpui::{AppContext, EventEmitter, ModelContext, SharedString, Task};
 use language_model::{
     LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
     LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
     StopReason,
 };
+use language_models::provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError};
 use serde::{Deserialize, Serialize};
 use util::post_inc;
 
@@ -210,29 +211,28 @@ impl Thread {
             let result = stream_completion.await;
 
             thread
-                .update(&mut cx, |_thread, cx| {
-                    let error_message = if let Some(error) = result.as_ref().err() {
-                        let error_message = error
-                            .chain()
-                            .map(|err| err.to_string())
-                            .collect::<Vec<_>>()
-                            .join("\n");
-                        Some(error_message)
-                    } else {
-                        None
-                    };
-
-                    if let Some(error_message) = error_message {
-                        eprintln!("Completion failed: {error_message:?}");
-                    }
-
-                    if let Ok(stop_reason) = result {
-                        match stop_reason {
-                            StopReason::ToolUse => {
-                                cx.emit(ThreadEvent::UsePendingTools);
-                            }
-                            StopReason::EndTurn => {}
-                            StopReason::MaxTokens => {}
+                .update(&mut cx, |_thread, cx| match result.as_ref() {
+                    Ok(stop_reason) => match stop_reason {
+                        StopReason::ToolUse => {
+                            cx.emit(ThreadEvent::UsePendingTools);
+                        }
+                        StopReason::EndTurn => {}
+                        StopReason::MaxTokens => {}
+                    },
+                    Err(error) => {
+                        if error.is::<PaymentRequiredError>() {
+                            cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
+                        } else if error.is::<MaxMonthlySpendReachedError>() {
+                            cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
+                        } else {
+                            let error_message = error
+                                .chain()
+                                .map(|err| err.to_string())
+                                .collect::<Vec<_>>()
+                                .join("\n");
+                            cx.emit(ThreadEvent::ShowError(ThreadError::Message(
+                                SharedString::from(error_message.clone()),
+                            )));
                         }
                     }
                 })
@@ -305,8 +305,16 @@ impl Thread {
     }
 }
 
+#[derive(Debug, Clone)]
+pub enum ThreadError {
+    PaymentRequired,
+    MaxMonthlySpendReached,
+    Message(SharedString),
+}
+
 #[derive(Debug, Clone)]
 pub enum ThreadEvent {
+    ShowError(ThreadError),
     StreamedCompletion,
     UsePendingTools,
     ToolFinished {

crates/language_model/src/language_model.rs 🔗

@@ -55,7 +55,7 @@ pub enum LanguageModelCompletionEvent {
     StartMessage { message_id: String },
 }
 
-#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
+#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub enum StopReason {
     EndTurn,