Cargo.lock 🔗
@@ -6302,6 +6302,7 @@ dependencies = [
"strum 0.25.0",
"text",
"theme",
+ "thiserror",
"tiktoken-rs",
"ui",
"unindent",
Marshall Bowers, Antonio, and Richard created
Cherry-picking this change to Preview.
This PR adds support to the assistant for displaying billing-related
errors.
Pulling this out of #19081 to make it easier to cherry-pick.
Release Notes:
- N/A
Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Richard <richard@zed.dev>
Cargo.lock | 1
crates/assistant/src/assistant_panel.rs | 214 +++++++++++++++++-----
crates/assistant/src/context.rs | 48 +++-
crates/language_model/Cargo.toml | 1
crates/language_model/src/provider/cloud.rs | 138 ++++++++++++-
crates/language_model/src/registry.rs | 3
crates/proto/proto/zed.proto | 5
crates/proto/src/proto.rs | 1
crates/rpc/src/llm.rs | 2
9 files changed, 330 insertions(+), 83 deletions(-)
@@ -6302,6 +6302,7 @@ dependencies = [
"strum 0.25.0",
"text",
"theme",
+ "thiserror",
"tiktoken-rs",
"ui",
"unindent",
@@ -1496,6 +1496,13 @@ struct WorkflowAssist {
type MessageHeader = MessageMetadata;
+#[derive(Clone)]
+enum AssistError {
+ PaymentRequired,
+ MaxMonthlySpendReached,
+ Message(SharedString),
+}
+
pub struct ContextEditor {
context: Model<Context>,
fs: Arc<dyn Fs>,
@@ -1514,7 +1521,7 @@ pub struct ContextEditor {
workflow_steps: HashMap<Range<language::Anchor>, WorkflowStepViewState>,
active_workflow_step: Option<ActiveWorkflowStep>,
assistant_panel: WeakView<AssistantPanel>,
- error_message: Option<SharedString>,
+ last_error: Option<AssistError>,
show_accept_terms: bool,
pub(crate) slash_menu_handle:
PopoverMenuHandle<Picker<slash_command_picker::SlashCommandDelegate>>,
@@ -1585,7 +1592,7 @@ impl ContextEditor {
workflow_steps: HashMap::default(),
active_workflow_step: None,
assistant_panel,
- error_message: None,
+ last_error: None,
show_accept_terms: false,
slash_menu_handle: Default::default(),
dragged_file_worktrees: Vec::new(),
@@ -1629,7 +1636,7 @@ impl ContextEditor {
}
if !self.apply_active_workflow_step(cx) {
- self.error_message = None;
+ self.last_error = None;
self.send_to_model(cx);
cx.notify();
}
@@ -1779,7 +1786,7 @@ impl ContextEditor {
}
fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
- self.error_message = None;
+ self.last_error = None;
if self
.context
@@ -2284,7 +2291,13 @@ impl ContextEditor {
}
ContextEvent::Operation(_) => {}
ContextEvent::ShowAssistError(error_message) => {
- self.error_message = Some(error_message.clone());
+ self.last_error = Some(AssistError::Message(error_message.clone()));
+ }
+ ContextEvent::ShowPaymentRequiredError => {
+ self.last_error = Some(AssistError::PaymentRequired);
+ }
+ ContextEvent::ShowMaxMonthlySpendReachedError => {
+ self.last_error = Some(AssistError::MaxMonthlySpendReached);
}
}
}
@@ -4298,6 +4311,154 @@ impl ContextEditor {
focus_handle.dispatch_action(&Assist, cx);
})
}
+
+ fn render_last_error(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
+ let last_error = self.last_error.as_ref()?;
+
+ Some(
+ div()
+ .absolute()
+ .right_3()
+ .bottom_12()
+ .max_w_96()
+ .py_2()
+ .px_3()
+ .elevation_2(cx)
+ .occlude()
+ .child(match last_error {
+ AssistError::PaymentRequired => self.render_payment_required_error(cx),
+ AssistError::MaxMonthlySpendReached => {
+ self.render_max_monthly_spend_reached_error(cx)
+ }
+ AssistError::Message(error_message) => {
+ self.render_assist_error(error_message, cx)
+ }
+ })
+ .into_any(),
+ )
+ }
+
+ fn render_payment_required_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
+ const ERROR_MESSAGE: &str = "Free tier exceeded. Subscribe and add payment to continue using Zed LLMs. You'll be billed at cost for tokens used.";
+ const SUBSCRIBE_URL: &str = "https://zed.dev/ai/subscribe";
+
+ v_flex()
+ .gap_0p5()
+ .child(
+ h_flex()
+ .gap_1p5()
+ .items_center()
+ .child(Icon::new(IconName::XCircle).color(Color::Error))
+ .child(Label::new("Free Usage Exceeded").weight(FontWeight::MEDIUM)),
+ )
+ .child(
+ div()
+ .id("error-message")
+ .max_h_24()
+ .overflow_y_scroll()
+ .child(Label::new(ERROR_MESSAGE)),
+ )
+ .child(
+ h_flex()
+ .justify_end()
+ .mt_1()
+ .child(Button::new("subscribe", "Subscribe").on_click(cx.listener(
+ |this, _, cx| {
+ this.last_error = None;
+ cx.open_url(SUBSCRIBE_URL);
+ cx.notify();
+ },
+ )))
+ .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
+ |this, _, cx| {
+ this.last_error = None;
+ cx.notify();
+ },
+ ))),
+ )
+ .into_any()
+ }
+
+ fn render_max_monthly_spend_reached_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
+ const ERROR_MESSAGE: &str = "You have reached your maximum monthly spend. Increase your spend limit to continue using Zed LLMs.";
+ const ACCOUNT_URL: &str = "https://zed.dev/account";
+
+ v_flex()
+ .gap_0p5()
+ .child(
+ h_flex()
+ .gap_1p5()
+ .items_center()
+ .child(Icon::new(IconName::XCircle).color(Color::Error))
+ .child(Label::new("Max Monthly Spend Reached").weight(FontWeight::MEDIUM)),
+ )
+ .child(
+ div()
+ .id("error-message")
+ .max_h_24()
+ .overflow_y_scroll()
+ .child(Label::new(ERROR_MESSAGE)),
+ )
+ .child(
+ h_flex()
+ .justify_end()
+ .mt_1()
+ .child(
+ Button::new("subscribe", "Update Monthly Spend Limit").on_click(
+ cx.listener(|this, _, cx| {
+ this.last_error = None;
+ cx.open_url(ACCOUNT_URL);
+ cx.notify();
+ }),
+ ),
+ )
+ .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
+ |this, _, cx| {
+ this.last_error = None;
+ cx.notify();
+ },
+ ))),
+ )
+ .into_any()
+ }
+
+ fn render_assist_error(
+ &self,
+ error_message: &SharedString,
+ cx: &mut ViewContext<Self>,
+ ) -> AnyElement {
+ v_flex()
+ .gap_0p5()
+ .child(
+ h_flex()
+ .gap_1p5()
+ .items_center()
+ .child(Icon::new(IconName::XCircle).color(Color::Error))
+ .child(
+ Label::new("Error interacting with language model")
+ .weight(FontWeight::MEDIUM),
+ ),
+ )
+ .child(
+ div()
+ .id("error-message")
+ .max_h_24()
+ .overflow_y_scroll()
+ .child(Label::new(error_message.clone())),
+ )
+ .child(
+ h_flex()
+ .justify_end()
+ .mt_1()
+ .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
+ |this, _, cx| {
+ this.last_error = None;
+ cx.notify();
+ },
+ ))),
+ )
+ .into_any()
+ }
}
/// Returns the contents of the *outermost* fenced code block that contains the given offset.
@@ -4434,48 +4595,7 @@ impl Render for ContextEditor {
.child(element),
)
})
- .when_some(self.error_message.clone(), |this, error_message| {
- this.child(
- div()
- .absolute()
- .right_3()
- .bottom_12()
- .max_w_96()
- .py_2()
- .px_3()
- .elevation_2(cx)
- .occlude()
- .child(
- v_flex()
- .gap_0p5()
- .child(
- h_flex()
- .gap_1p5()
- .items_center()
- .child(Icon::new(IconName::XCircle).color(Color::Error))
- .child(
- Label::new("Error interacting with language model")
- .weight(FontWeight::MEDIUM),
- ),
- )
- .child(
- div()
- .id("error-message")
- .max_h_24()
- .overflow_y_scroll()
- .child(Label::new(error_message)),
- )
- .child(h_flex().justify_end().mt_1().child(
- Button::new("dismiss", "Dismiss").on_click(cx.listener(
- |this, _, cx| {
- this.error_message = None;
- cx.notify();
- },
- )),
- )),
- ),
- )
- })
+ .children(self.render_last_error(cx))
.child(
h_flex().w_full().relative().child(
h_flex()
@@ -26,6 +26,7 @@ use gpui::{
use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset};
use language_model::{
+ provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError},
LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
@@ -294,6 +295,8 @@ impl ContextOperation {
#[derive(Debug, Clone)]
pub enum ContextEvent {
ShowAssistError(SharedString),
+ ShowPaymentRequiredError,
+ ShowMaxMonthlySpendReachedError,
MessagesEdited,
SummaryChanged,
StreamedCompletion,
@@ -2112,25 +2115,36 @@ impl Context {
let result = stream_completion.await;
this.update(&mut cx, |this, cx| {
- let error_message = result
- .as_ref()
- .err()
- .map(|error| error.to_string().trim().to_string());
-
- if let Some(error_message) = error_message.as_ref() {
- cx.emit(ContextEvent::ShowAssistError(SharedString::from(
- error_message.clone(),
- )));
- }
-
- this.update_metadata(assistant_message_id, cx, |metadata| {
- if let Some(error_message) = error_message.as_ref() {
- metadata.status =
- MessageStatus::Error(SharedString::from(error_message.clone()));
+ let error_message = if let Some(error) = result.as_ref().err() {
+ if error.is::<PaymentRequiredError>() {
+ cx.emit(ContextEvent::ShowPaymentRequiredError);
+ this.update_metadata(assistant_message_id, cx, |metadata| {
+ metadata.status = MessageStatus::Canceled;
+ });
+ Some(error.to_string())
+ } else if error.is::<MaxMonthlySpendReachedError>() {
+ cx.emit(ContextEvent::ShowMaxMonthlySpendReachedError);
+ this.update_metadata(assistant_message_id, cx, |metadata| {
+ metadata.status = MessageStatus::Canceled;
+ });
+ Some(error.to_string())
} else {
- metadata.status = MessageStatus::Done;
+ let error_message = error.to_string().trim().to_string();
+ cx.emit(ContextEvent::ShowAssistError(SharedString::from(
+ error_message.clone(),
+ )));
+ this.update_metadata(assistant_message_id, cx, |metadata| {
+ metadata.status =
+ MessageStatus::Error(SharedString::from(error_message.clone()));
+ });
+ Some(error_message)
}
- });
+ } else {
+ this.update_metadata(assistant_message_id, cx, |metadata| {
+ metadata.status = MessageStatus::Done;
+ });
+ None
+ };
if let Some(telemetry) = this.telemetry.as_ref() {
let language_name = this
@@ -47,6 +47,7 @@ settings.workspace = true
smol.workspace = true
strum.workspace = true
theme.workspace = true
+thiserror.workspace = true
tiktoken-rs.workspace = true
ui.workspace = true
util.workspace = true
@@ -7,7 +7,10 @@ use crate::{
};
use anthropic::AnthropicError;
use anyhow::{anyhow, Result};
-use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
+use client::{
+ Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME,
+ MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
+};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
use futures::{
@@ -15,10 +18,11 @@ use futures::{
TryStreamExt as _,
};
use gpui::{
- AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
- Subscription, Task,
+ AnyElement, AnyView, AppContext, AsyncAppContext, EventEmitter, FontWeight, Global, Model,
+ ModelContext, ReadGlobal, Subscription, Task,
};
-use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response};
+use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
+use proto::TypedEnvelope;
use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::value::RawValue;
@@ -27,12 +31,14 @@ use smol::{
io::{AsyncReadExt, BufReader},
lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
};
+use std::fmt;
use std::time::Duration;
use std::{
future,
sync::{Arc, LazyLock},
};
use strum::IntoEnumIterator;
+use thiserror::Error;
use ui::{prelude::*, TintColor};
use crate::{LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider};
@@ -90,22 +96,93 @@ pub struct AvailableModel {
pub default_temperature: Option<f32>,
}
+struct GlobalRefreshLlmTokenListener(Model<RefreshLlmTokenListener>);
+
+impl Global for GlobalRefreshLlmTokenListener {}
+
+pub struct RefreshLlmTokenEvent;
+
+pub struct RefreshLlmTokenListener {
+ _llm_token_subscription: client::Subscription,
+}
+
+impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
+
+impl RefreshLlmTokenListener {
+ pub fn register(client: Arc<Client>, cx: &mut AppContext) {
+ let listener = cx.new_model(|cx| RefreshLlmTokenListener::new(client, cx));
+ cx.set_global(GlobalRefreshLlmTokenListener(listener));
+ }
+
+ pub fn global(cx: &AppContext) -> Model<Self> {
+ GlobalRefreshLlmTokenListener::global(cx).0.clone()
+ }
+
+ fn new(client: Arc<Client>, cx: &mut ModelContext<Self>) -> Self {
+ Self {
+ _llm_token_subscription: client
+ .add_message_handler(cx.weak_model(), Self::handle_refresh_llm_token),
+ }
+ }
+
+ async fn handle_refresh_llm_token(
+ this: Model<Self>,
+ _: TypedEnvelope<proto::RefreshLlmToken>,
+ mut cx: AsyncAppContext,
+ ) -> Result<()> {
+ this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
+ }
+}
+
pub struct CloudLanguageModelProvider {
client: Arc<Client>,
- llm_api_token: LlmApiToken,
state: gpui::Model<State>,
_maintain_client_status: Task<()>,
}
pub struct State {
client: Arc<Client>,
+ llm_api_token: LlmApiToken,
user_store: Model<UserStore>,
status: client::Status,
accept_terms: Option<Task<Result<()>>>,
- _subscription: Subscription,
+ _settings_subscription: Subscription,
+ _llm_token_subscription: Subscription,
}
impl State {
+ fn new(
+ client: Arc<Client>,
+ user_store: Model<UserStore>,
+ status: client::Status,
+ cx: &mut ModelContext<Self>,
+ ) -> Self {
+ let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
+
+ Self {
+ client: client.clone(),
+ llm_api_token: LlmApiToken::default(),
+ user_store,
+ status,
+ accept_terms: None,
+ _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
+ cx.notify();
+ }),
+ _llm_token_subscription: cx.subscribe(
+ &refresh_llm_token_listener,
+ |this, _listener, _event, cx| {
+ let client = this.client.clone();
+ let llm_api_token = this.llm_api_token.clone();
+ cx.spawn(|_this, _cx| async move {
+ llm_api_token.refresh(&client).await?;
+ anyhow::Ok(())
+ })
+ .detach_and_log_err(cx);
+ },
+ ),
+ }
+ }
+
fn is_signed_out(&self) -> bool {
self.status.is_signed_out()
}
@@ -144,15 +221,7 @@ impl CloudLanguageModelProvider {
let mut status_rx = client.status();
let status = *status_rx.borrow();
- let state = cx.new_model(|cx| State {
- client: client.clone(),
- user_store,
- status,
- accept_terms: None,
- _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
- cx.notify();
- }),
- });
+ let state = cx.new_model(|cx| State::new(client.clone(), user_store.clone(), status, cx));
let state_ref = state.downgrade();
let maintain_client_status = cx.spawn(|mut cx| async move {
@@ -172,8 +241,7 @@ impl CloudLanguageModelProvider {
Self {
client,
- state,
- llm_api_token: LlmApiToken::default(),
+ state: state.clone(),
_maintain_client_status: maintain_client_status,
}
}
@@ -272,13 +340,14 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
models.insert(model.id().to_string(), model.clone());
}
+ let llm_api_token = self.state.read(cx).llm_api_token.clone();
models
.into_values()
.map(|model| {
Arc::new(CloudLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
model,
- llm_api_token: self.llm_api_token.clone(),
+ llm_api_token: llm_api_token.clone(),
client: self.client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
@@ -377,6 +446,30 @@ pub struct CloudLanguageModel {
#[derive(Clone, Default)]
struct LlmApiToken(Arc<RwLock<Option<String>>>);
+#[derive(Error, Debug)]
+pub struct PaymentRequiredError;
+
+impl fmt::Display for PaymentRequiredError {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(
+ f,
+ "Payment required to use this language model. Please upgrade your account."
+ )
+ }
+}
+
+#[derive(Error, Debug)]
+pub struct MaxMonthlySpendReachedError;
+
+impl fmt::Display for MaxMonthlySpendReachedError {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(
+ f,
+ "Maximum spending limit reached for this month. For more usage, increase your spending limit."
+ )
+ }
+}
+
impl CloudLanguageModel {
async fn perform_llm_completion(
client: Arc<Client>,
@@ -411,6 +504,15 @@ impl CloudLanguageModel {
{
did_retry = true;
token = llm_api_token.refresh(&client).await?;
+ } else if response.status() == StatusCode::FORBIDDEN
+ && response
+ .headers()
+ .get(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME)
+ .is_some()
+ {
+ break Err(anyhow!(MaxMonthlySpendReachedError))?;
+ } else if response.status() == StatusCode::PAYMENT_REQUIRED {
+ break Err(anyhow!(PaymentRequiredError))?;
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
@@ -1,3 +1,4 @@
+use crate::provider::cloud::RefreshLlmTokenListener;
use crate::{
provider::{
anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
@@ -30,6 +31,8 @@ fn register_language_model_providers(
) {
use feature_flags::FeatureFlagAppExt;
+ RefreshLlmTokenListener::register(client.clone(), cx);
+
registry.register_provider(
AnthropicLanguageModelProvider::new(client.http_client(), cx),
cx,
@@ -271,6 +271,7 @@ message Envelope {
GetLlmToken get_llm_token = 235;
GetLlmTokenResponse get_llm_token_response = 236;
+ RefreshLlmToken refresh_llm_token = 259; // current max
LspExtSwitchSourceHeader lsp_ext_switch_source_header = 241;
LspExtSwitchSourceHeaderResponse lsp_ext_switch_source_header_response = 242;
@@ -284,7 +285,7 @@ message Envelope {
CheckFileExists check_file_exists = 255;
CheckFileExistsResponse check_file_exists_response = 256;
- ShutdownRemoteServer shutdown_remote_server = 257; // current max
+ ShutdownRemoteServer shutdown_remote_server = 257;
}
reserved 87 to 88;
@@ -2482,6 +2483,8 @@ message GetLlmTokenResponse {
string token = 1;
}
+message RefreshLlmToken {}
+
// Remote FS
message AddWorktree {
@@ -253,6 +253,7 @@ messages!(
(ProjectEntryResponse, Foreground),
(CountLanguageModelTokens, Background),
(CountLanguageModelTokensResponse, Background),
+ (RefreshLlmToken, Background),
(RefreshInlayHints, Foreground),
(RejoinChannelBuffers, Foreground),
(RejoinChannelBuffersResponse, Foreground),
@@ -3,6 +3,8 @@ use strum::{Display, EnumIter, EnumString};
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
+pub const MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME: &str = "x-zed-llm-max-monthly-spend-reached";
+
#[derive(
Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
)]