assistant: Add support for displaying billing-related errors (#19082) (#19097)

Marshall Bowers, Antonio, and Richard created

Cherry-picking this change to Preview.

This PR adds support to the assistant for displaying billing-related
errors.

Pulling this out of #19081 to make it easier to cherry-pick.

Release Notes:

- N/A

Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Richard <richard@zed.dev>

Change summary

Cargo.lock                                  |   1 
crates/assistant/src/assistant_panel.rs     | 214 +++++++++++++++++-----
crates/assistant/src/context.rs             |  48 +++-
crates/language_model/Cargo.toml            |   1 
crates/language_model/src/provider/cloud.rs | 138 ++++++++++++-
crates/language_model/src/registry.rs       |   3 
crates/proto/proto/zed.proto                |   5 
crates/proto/src/proto.rs                   |   1 
crates/rpc/src/llm.rs                       |   2 
9 files changed, 330 insertions(+), 83 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -6302,6 +6302,7 @@ dependencies = [
  "strum 0.25.0",
  "text",
  "theme",
+ "thiserror",
  "tiktoken-rs",
  "ui",
  "unindent",

crates/assistant/src/assistant_panel.rs 🔗

@@ -1496,6 +1496,13 @@ struct WorkflowAssist {
 
 type MessageHeader = MessageMetadata;
 
+#[derive(Clone)]
+enum AssistError {
+    PaymentRequired,
+    MaxMonthlySpendReached,
+    Message(SharedString),
+}
+
 pub struct ContextEditor {
     context: Model<Context>,
     fs: Arc<dyn Fs>,
@@ -1514,7 +1521,7 @@ pub struct ContextEditor {
     workflow_steps: HashMap<Range<language::Anchor>, WorkflowStepViewState>,
     active_workflow_step: Option<ActiveWorkflowStep>,
     assistant_panel: WeakView<AssistantPanel>,
-    error_message: Option<SharedString>,
+    last_error: Option<AssistError>,
     show_accept_terms: bool,
     pub(crate) slash_menu_handle:
         PopoverMenuHandle<Picker<slash_command_picker::SlashCommandDelegate>>,
@@ -1585,7 +1592,7 @@ impl ContextEditor {
             workflow_steps: HashMap::default(),
             active_workflow_step: None,
             assistant_panel,
-            error_message: None,
+            last_error: None,
             show_accept_terms: false,
             slash_menu_handle: Default::default(),
             dragged_file_worktrees: Vec::new(),
@@ -1629,7 +1636,7 @@ impl ContextEditor {
         }
 
         if !self.apply_active_workflow_step(cx) {
-            self.error_message = None;
+            self.last_error = None;
             self.send_to_model(cx);
             cx.notify();
         }
@@ -1779,7 +1786,7 @@ impl ContextEditor {
     }
 
     fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
-        self.error_message = None;
+        self.last_error = None;
 
         if self
             .context
@@ -2284,7 +2291,13 @@ impl ContextEditor {
             }
             ContextEvent::Operation(_) => {}
             ContextEvent::ShowAssistError(error_message) => {
-                self.error_message = Some(error_message.clone());
+                self.last_error = Some(AssistError::Message(error_message.clone()));
+            }
+            ContextEvent::ShowPaymentRequiredError => {
+                self.last_error = Some(AssistError::PaymentRequired);
+            }
+            ContextEvent::ShowMaxMonthlySpendReachedError => {
+                self.last_error = Some(AssistError::MaxMonthlySpendReached);
             }
         }
     }
@@ -4298,6 +4311,154 @@ impl ContextEditor {
                 focus_handle.dispatch_action(&Assist, cx);
             })
     }
+
+    fn render_last_error(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
+        let last_error = self.last_error.as_ref()?;
+
+        Some(
+            div()
+                .absolute()
+                .right_3()
+                .bottom_12()
+                .max_w_96()
+                .py_2()
+                .px_3()
+                .elevation_2(cx)
+                .occlude()
+                .child(match last_error {
+                    AssistError::PaymentRequired => self.render_payment_required_error(cx),
+                    AssistError::MaxMonthlySpendReached => {
+                        self.render_max_monthly_spend_reached_error(cx)
+                    }
+                    AssistError::Message(error_message) => {
+                        self.render_assist_error(error_message, cx)
+                    }
+                })
+                .into_any(),
+        )
+    }
+
+    fn render_payment_required_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
+        const ERROR_MESSAGE: &str = "Free tier exceeded. Subscribe and add payment to continue using Zed LLMs. You'll be billed at cost for tokens used.";
+        const SUBSCRIBE_URL: &str = "https://zed.dev/ai/subscribe";
+
+        v_flex()
+            .gap_0p5()
+            .child(
+                h_flex()
+                    .gap_1p5()
+                    .items_center()
+                    .child(Icon::new(IconName::XCircle).color(Color::Error))
+                    .child(Label::new("Free Usage Exceeded").weight(FontWeight::MEDIUM)),
+            )
+            .child(
+                div()
+                    .id("error-message")
+                    .max_h_24()
+                    .overflow_y_scroll()
+                    .child(Label::new(ERROR_MESSAGE)),
+            )
+            .child(
+                h_flex()
+                    .justify_end()
+                    .mt_1()
+                    .child(Button::new("subscribe", "Subscribe").on_click(cx.listener(
+                        |this, _, cx| {
+                            this.last_error = None;
+                            cx.open_url(SUBSCRIBE_URL);
+                            cx.notify();
+                        },
+                    )))
+                    .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
+                        |this, _, cx| {
+                            this.last_error = None;
+                            cx.notify();
+                        },
+                    ))),
+            )
+            .into_any()
+    }
+
+    fn render_max_monthly_spend_reached_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
+        const ERROR_MESSAGE: &str = "You have reached your maximum monthly spend. Increase your spend limit to continue using Zed LLMs.";
+        const ACCOUNT_URL: &str = "https://zed.dev/account";
+
+        v_flex()
+            .gap_0p5()
+            .child(
+                h_flex()
+                    .gap_1p5()
+                    .items_center()
+                    .child(Icon::new(IconName::XCircle).color(Color::Error))
+                    .child(Label::new("Max Monthly Spend Reached").weight(FontWeight::MEDIUM)),
+            )
+            .child(
+                div()
+                    .id("error-message")
+                    .max_h_24()
+                    .overflow_y_scroll()
+                    .child(Label::new(ERROR_MESSAGE)),
+            )
+            .child(
+                h_flex()
+                    .justify_end()
+                    .mt_1()
+                    .child(
+                        Button::new("subscribe", "Update Monthly Spend Limit").on_click(
+                            cx.listener(|this, _, cx| {
+                                this.last_error = None;
+                                cx.open_url(ACCOUNT_URL);
+                                cx.notify();
+                            }),
+                        ),
+                    )
+                    .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
+                        |this, _, cx| {
+                            this.last_error = None;
+                            cx.notify();
+                        },
+                    ))),
+            )
+            .into_any()
+    }
+
+    fn render_assist_error(
+        &self,
+        error_message: &SharedString,
+        cx: &mut ViewContext<Self>,
+    ) -> AnyElement {
+        v_flex()
+            .gap_0p5()
+            .child(
+                h_flex()
+                    .gap_1p5()
+                    .items_center()
+                    .child(Icon::new(IconName::XCircle).color(Color::Error))
+                    .child(
+                        Label::new("Error interacting with language model")
+                            .weight(FontWeight::MEDIUM),
+                    ),
+            )
+            .child(
+                div()
+                    .id("error-message")
+                    .max_h_24()
+                    .overflow_y_scroll()
+                    .child(Label::new(error_message.clone())),
+            )
+            .child(
+                h_flex()
+                    .justify_end()
+                    .mt_1()
+                    .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
+                        |this, _, cx| {
+                            this.last_error = None;
+                            cx.notify();
+                        },
+                    ))),
+            )
+            .into_any()
+    }
 }
 
 /// Returns the contents of the *outermost* fenced code block that contains the given offset.
@@ -4434,48 +4595,7 @@ impl Render for ContextEditor {
                         .child(element),
                 )
             })
-            .when_some(self.error_message.clone(), |this, error_message| {
-                this.child(
-                    div()
-                        .absolute()
-                        .right_3()
-                        .bottom_12()
-                        .max_w_96()
-                        .py_2()
-                        .px_3()
-                        .elevation_2(cx)
-                        .occlude()
-                        .child(
-                            v_flex()
-                                .gap_0p5()
-                                .child(
-                                    h_flex()
-                                        .gap_1p5()
-                                        .items_center()
-                                        .child(Icon::new(IconName::XCircle).color(Color::Error))
-                                        .child(
-                                            Label::new("Error interacting with language model")
-                                                .weight(FontWeight::MEDIUM),
-                                        ),
-                                )
-                                .child(
-                                    div()
-                                        .id("error-message")
-                                        .max_h_24()
-                                        .overflow_y_scroll()
-                                        .child(Label::new(error_message)),
-                                )
-                                .child(h_flex().justify_end().mt_1().child(
-                                    Button::new("dismiss", "Dismiss").on_click(cx.listener(
-                                        |this, _, cx| {
-                                            this.error_message = None;
-                                            cx.notify();
-                                        },
-                                    )),
-                                )),
-                        ),
-                )
-            })
+            .children(self.render_last_error(cx))
             .child(
                 h_flex().w_full().relative().child(
                     h_flex()

crates/assistant/src/context.rs 🔗

@@ -26,6 +26,7 @@ use gpui::{
 
 use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset};
 use language_model::{
+    provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError},
     LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
     LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
     LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
@@ -294,6 +295,8 @@ impl ContextOperation {
 #[derive(Debug, Clone)]
 pub enum ContextEvent {
     ShowAssistError(SharedString),
+    ShowPaymentRequiredError,
+    ShowMaxMonthlySpendReachedError,
     MessagesEdited,
     SummaryChanged,
     StreamedCompletion,
@@ -2112,25 +2115,36 @@ impl Context {
                 let result = stream_completion.await;
 
                 this.update(&mut cx, |this, cx| {
-                    let error_message = result
-                        .as_ref()
-                        .err()
-                        .map(|error| error.to_string().trim().to_string());
-
-                    if let Some(error_message) = error_message.as_ref() {
-                        cx.emit(ContextEvent::ShowAssistError(SharedString::from(
-                            error_message.clone(),
-                        )));
-                    }
-
-                    this.update_metadata(assistant_message_id, cx, |metadata| {
-                        if let Some(error_message) = error_message.as_ref() {
-                            metadata.status =
-                                MessageStatus::Error(SharedString::from(error_message.clone()));
+                    let error_message = if let Some(error) = result.as_ref().err() {
+                        if error.is::<PaymentRequiredError>() {
+                            cx.emit(ContextEvent::ShowPaymentRequiredError);
+                            this.update_metadata(assistant_message_id, cx, |metadata| {
+                                metadata.status = MessageStatus::Canceled;
+                            });
+                            Some(error.to_string())
+                        } else if error.is::<MaxMonthlySpendReachedError>() {
+                            cx.emit(ContextEvent::ShowMaxMonthlySpendReachedError);
+                            this.update_metadata(assistant_message_id, cx, |metadata| {
+                                metadata.status = MessageStatus::Canceled;
+                            });
+                            Some(error.to_string())
                         } else {
-                            metadata.status = MessageStatus::Done;
+                            let error_message = error.to_string().trim().to_string();
+                            cx.emit(ContextEvent::ShowAssistError(SharedString::from(
+                                error_message.clone(),
+                            )));
+                            this.update_metadata(assistant_message_id, cx, |metadata| {
+                                metadata.status =
+                                    MessageStatus::Error(SharedString::from(error_message.clone()));
+                            });
+                            Some(error_message)
                         }
-                    });
+                    } else {
+                        this.update_metadata(assistant_message_id, cx, |metadata| {
+                            metadata.status = MessageStatus::Done;
+                        });
+                        None
+                    };
 
                     if let Some(telemetry) = this.telemetry.as_ref() {
                         let language_name = this

crates/language_model/Cargo.toml 🔗

@@ -47,6 +47,7 @@ settings.workspace = true
 smol.workspace = true
 strum.workspace = true
 theme.workspace = true
+thiserror.workspace = true
 tiktoken-rs.workspace = true
 ui.workspace = true
 util.workspace = true

crates/language_model/src/provider/cloud.rs 🔗

@@ -7,7 +7,10 @@ use crate::{
 };
 use anthropic::AnthropicError;
 use anyhow::{anyhow, Result};
-use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
+use client::{
+    Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME,
+    MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
+};
 use collections::BTreeMap;
 use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
 use futures::{
@@ -15,10 +18,11 @@ use futures::{
     TryStreamExt as _,
 };
 use gpui::{
-    AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
-    Subscription, Task,
+    AnyElement, AnyView, AppContext, AsyncAppContext, EventEmitter, FontWeight, Global, Model,
+    ModelContext, ReadGlobal, Subscription, Task,
 };
-use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response};
+use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
+use proto::TypedEnvelope;
 use schemars::JsonSchema;
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use serde_json::value::RawValue;
@@ -27,12 +31,14 @@ use smol::{
     io::{AsyncReadExt, BufReader},
     lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
 };
+use std::fmt;
 use std::time::Duration;
 use std::{
     future,
     sync::{Arc, LazyLock},
 };
 use strum::IntoEnumIterator;
+use thiserror::Error;
 use ui::{prelude::*, TintColor};
 
 use crate::{LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider};
@@ -90,22 +96,93 @@ pub struct AvailableModel {
     pub default_temperature: Option<f32>,
 }
 
+struct GlobalRefreshLlmTokenListener(Model<RefreshLlmTokenListener>);
+
+impl Global for GlobalRefreshLlmTokenListener {}
+
+pub struct RefreshLlmTokenEvent;
+
+pub struct RefreshLlmTokenListener {
+    _llm_token_subscription: client::Subscription,
+}
+
+impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
+
+impl RefreshLlmTokenListener {
+    pub fn register(client: Arc<Client>, cx: &mut AppContext) {
+        let listener = cx.new_model(|cx| RefreshLlmTokenListener::new(client, cx));
+        cx.set_global(GlobalRefreshLlmTokenListener(listener));
+    }
+
+    pub fn global(cx: &AppContext) -> Model<Self> {
+        GlobalRefreshLlmTokenListener::global(cx).0.clone()
+    }
+
+    fn new(client: Arc<Client>, cx: &mut ModelContext<Self>) -> Self {
+        Self {
+            _llm_token_subscription: client
+                .add_message_handler(cx.weak_model(), Self::handle_refresh_llm_token),
+        }
+    }
+
+    async fn handle_refresh_llm_token(
+        this: Model<Self>,
+        _: TypedEnvelope<proto::RefreshLlmToken>,
+        mut cx: AsyncAppContext,
+    ) -> Result<()> {
+        this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
+    }
+}
+
 pub struct CloudLanguageModelProvider {
     client: Arc<Client>,
-    llm_api_token: LlmApiToken,
     state: gpui::Model<State>,
     _maintain_client_status: Task<()>,
 }
 
 pub struct State {
     client: Arc<Client>,
+    llm_api_token: LlmApiToken,
     user_store: Model<UserStore>,
     status: client::Status,
     accept_terms: Option<Task<Result<()>>>,
-    _subscription: Subscription,
+    _settings_subscription: Subscription,
+    _llm_token_subscription: Subscription,
 }
 
 impl State {
+    fn new(
+        client: Arc<Client>,
+        user_store: Model<UserStore>,
+        status: client::Status,
+        cx: &mut ModelContext<Self>,
+    ) -> Self {
+        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
+
+        Self {
+            client: client.clone(),
+            llm_api_token: LlmApiToken::default(),
+            user_store,
+            status,
+            accept_terms: None,
+            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
+                cx.notify();
+            }),
+            _llm_token_subscription: cx.subscribe(
+                &refresh_llm_token_listener,
+                |this, _listener, _event, cx| {
+                    let client = this.client.clone();
+                    let llm_api_token = this.llm_api_token.clone();
+                    cx.spawn(|_this, _cx| async move {
+                        llm_api_token.refresh(&client).await?;
+                        anyhow::Ok(())
+                    })
+                    .detach_and_log_err(cx);
+                },
+            ),
+        }
+    }
+
     fn is_signed_out(&self) -> bool {
         self.status.is_signed_out()
     }
@@ -144,15 +221,7 @@ impl CloudLanguageModelProvider {
         let mut status_rx = client.status();
         let status = *status_rx.borrow();
 
-        let state = cx.new_model(|cx| State {
-            client: client.clone(),
-            user_store,
-            status,
-            accept_terms: None,
-            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
-                cx.notify();
-            }),
-        });
+        let state = cx.new_model(|cx| State::new(client.clone(), user_store.clone(), status, cx));
 
         let state_ref = state.downgrade();
         let maintain_client_status = cx.spawn(|mut cx| async move {
@@ -172,8 +241,7 @@ impl CloudLanguageModelProvider {
 
         Self {
             client,
-            state,
-            llm_api_token: LlmApiToken::default(),
+            state: state.clone(),
             _maintain_client_status: maintain_client_status,
         }
     }
@@ -272,13 +340,14 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
             models.insert(model.id().to_string(), model.clone());
         }
 
+        let llm_api_token = self.state.read(cx).llm_api_token.clone();
         models
             .into_values()
             .map(|model| {
                 Arc::new(CloudLanguageModel {
                     id: LanguageModelId::from(model.id().to_string()),
                     model,
-                    llm_api_token: self.llm_api_token.clone(),
+                    llm_api_token: llm_api_token.clone(),
                     client: self.client.clone(),
                     request_limiter: RateLimiter::new(4),
                 }) as Arc<dyn LanguageModel>
@@ -377,6 +446,30 @@ pub struct CloudLanguageModel {
 #[derive(Clone, Default)]
 struct LlmApiToken(Arc<RwLock<Option<String>>>);
 
+#[derive(Error, Debug)]
+pub struct PaymentRequiredError;
+
+impl fmt::Display for PaymentRequiredError {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(
+            f,
+            "Payment required to use this language model. Please upgrade your account."
+        )
+    }
+}
+
+#[derive(Error, Debug)]
+pub struct MaxMonthlySpendReachedError;
+
+impl fmt::Display for MaxMonthlySpendReachedError {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(
+            f,
+            "Maximum spending limit reached for this month. For more usage, increase your spending limit."
+        )
+    }
+}
+
 impl CloudLanguageModel {
     async fn perform_llm_completion(
         client: Arc<Client>,
@@ -411,6 +504,15 @@ impl CloudLanguageModel {
             {
                 did_retry = true;
                 token = llm_api_token.refresh(&client).await?;
+            } else if response.status() == StatusCode::FORBIDDEN
+                && response
+                    .headers()
+                    .get(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME)
+                    .is_some()
+            {
+                break Err(anyhow!(MaxMonthlySpendReachedError))?;
+            } else if response.status() == StatusCode::PAYMENT_REQUIRED {
+                break Err(anyhow!(PaymentRequiredError))?;
             } else {
                 let mut body = String::new();
                 response.body_mut().read_to_string(&mut body).await?;

crates/language_model/src/registry.rs 🔗

@@ -1,3 +1,4 @@
+use crate::provider::cloud::RefreshLlmTokenListener;
 use crate::{
     provider::{
         anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
@@ -30,6 +31,8 @@ fn register_language_model_providers(
 ) {
     use feature_flags::FeatureFlagAppExt;
 
+    RefreshLlmTokenListener::register(client.clone(), cx);
+
     registry.register_provider(
         AnthropicLanguageModelProvider::new(client.http_client(), cx),
         cx,

crates/proto/proto/zed.proto 🔗

@@ -271,6 +271,7 @@ message Envelope {
 
         GetLlmToken get_llm_token = 235;
         GetLlmTokenResponse get_llm_token_response = 236;
+        RefreshLlmToken refresh_llm_token = 259; // current max
 
         LspExtSwitchSourceHeader lsp_ext_switch_source_header = 241;
         LspExtSwitchSourceHeaderResponse lsp_ext_switch_source_header_response = 242;
@@ -284,7 +285,7 @@ message Envelope {
         CheckFileExists check_file_exists = 255;
         CheckFileExistsResponse check_file_exists_response = 256;
 
-        ShutdownRemoteServer shutdown_remote_server = 257; // current max
+        ShutdownRemoteServer shutdown_remote_server = 257;
     }
 
     reserved 87 to 88;
@@ -2482,6 +2483,8 @@ message GetLlmTokenResponse {
     string token = 1;
 }
 
+message RefreshLlmToken {}
+
 // Remote FS
 
 message AddWorktree {

crates/proto/src/proto.rs 🔗

@@ -253,6 +253,7 @@ messages!(
     (ProjectEntryResponse, Foreground),
     (CountLanguageModelTokens, Background),
     (CountLanguageModelTokensResponse, Background),
+    (RefreshLlmToken, Background),
     (RefreshInlayHints, Foreground),
     (RejoinChannelBuffers, Foreground),
     (RejoinChannelBuffersResponse, Foreground),

crates/rpc/src/llm.rs 🔗

@@ -3,6 +3,8 @@ use strum::{Display, EnumIter, EnumString};
 
 pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
 
+pub const MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME: &str = "x-zed-llm-max-monthly-spend-reached";
+
 #[derive(
     Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
 )]