diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 2347a731cb5b5f3590dafcf0a57dc0bab88c380c..0dd387e627a29fcd48b0523dd72990bbc05a5311 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -967,6 +967,10 @@ impl EditPredictionStore { self.mercury.api_token.read(cx).has_key() } + pub fn mercury_has_payment_required_error(&self) -> bool { + self.mercury.has_payment_required_error() + } + pub fn clear_history(&mut self) { for project_state in self.projects.values_mut() { project_state.events.clear(); diff --git a/crates/edit_prediction/src/mercury.rs b/crates/edit_prediction/src/mercury.rs index 0a952f0869b46f626c231e11f8a61370c50490fa..b80498c4ddccfffab02e77ceb20e6e9cf68851f4 100644 --- a/crates/edit_prediction/src/mercury.rs +++ b/crates/edit_prediction/src/mercury.rs @@ -1,19 +1,19 @@ use crate::{ DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput, - EditPredictionStartedDebugEvent, open_ai_response::text_from_response, + EditPredictionStartedDebugEvent, EditPredictionStore, open_ai_response::text_from_response, prediction::EditPredictionResult, zeta::compute_edits, }; use anyhow::{Context as _, Result}; use cloud_llm_client::EditPredictionRejectReason; use futures::AsyncReadExt as _; use gpui::{ - App, AppContext as _, Entity, Global, SharedString, Task, - http_client::{self, AsyncBody, HttpClient, Method}, + App, AppContext as _, Context, Entity, Global, SharedString, Task, + http_client::{self, AsyncBody, HttpClient, Method, StatusCode}, }; use language::{ToOffset, ToPoint as _}; use language_model::{ApiKeyState, EnvVar, env_var}; use release_channel::AppVersion; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant}; use zeta_prompt::ZetaPromptInput; @@ -21,17 +21,27 @@ const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions" pub struct Mercury { pub api_token: Entity, + payment_required_error: bool, } impl Mercury { pub fn new(cx: &mut App) -> Self { Mercury { api_token: mercury_api_token(cx), + payment_required_error: false, } } + pub fn has_payment_required_error(&self) -> bool { + self.payment_required_error + } + + pub fn set_payment_required_error(&mut self, payment_required_error: bool) { + self.payment_required_error = payment_required_error; + } + pub(crate) fn request_prediction( - &self, + &mut self, EditPredictionModelInput { buffer, snapshot, @@ -41,7 +51,7 @@ impl Mercury { debug_tx, .. }: EditPredictionModelInput, - cx: &mut App, + cx: &mut Context, ) -> Task>> { self.api_token.update(cx, |key_state, cx| { _ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx); @@ -163,6 +173,12 @@ impl Mercury { let response_received_at = Instant::now(); if !response.status().is_success() { + if response.status() == StatusCode::PAYMENT_REQUIRED { + anyhow::bail!(MercuryPaymentRequiredError( + mercury_payment_required_message(&body), + )); + } + anyhow::bail!( "Request failed with status: {:?}\nBody: {}", response.status(), @@ -209,9 +225,22 @@ impl Mercury { anyhow::Ok((id, edits, snapshot, response_received_at, inputs)) }); - cx.spawn(async move |cx| { - let (id, edits, old_snapshot, response_received_at, inputs) = - result.await.context("Mercury edit prediction failed")?; + cx.spawn(async move |ep_store, cx| { + let result = result.await.context("Mercury edit prediction failed"); + + let has_payment_required_error = result + .as_ref() + .err() + .is_some_and(is_mercury_payment_required_error); + + ep_store.update(cx, |store, cx| { + store + .mercury + .set_payment_required_error(has_payment_required_error); + cx.notify(); + })?; + + let (id, edits, old_snapshot, response_received_at, inputs) = result?; anyhow::Ok(Some( EditPredictionResult::new( EditPredictionId(id.into()), @@ -315,6 +344,33 @@ fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce( pub const MERCURY_CREDENTIALS_URL: SharedString = SharedString::new_static("https://api.inceptionlabs.ai/v1/edit/completions"); pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token"; + +#[derive(Debug, thiserror::Error)] +#[error("{0}")] +struct MercuryPaymentRequiredError(SharedString); + +#[derive(Deserialize)] +struct MercuryErrorResponse { + error: MercuryErrorMessage, +} + +#[derive(Deserialize)] +struct MercuryErrorMessage { + message: String, +} + +fn is_mercury_payment_required_error(error: &anyhow::Error) -> bool { + error + .downcast_ref::() + .is_some() +} + +fn mercury_payment_required_message(body: &[u8]) -> SharedString { + serde_json::from_slice::(body) + .map(|response| response.error.message.into()) + .unwrap_or_else(|_| String::from_utf8_lossy(body).trim().to_string().into()) +} + pub static MERCURY_TOKEN_ENV_VAR: std::sync::LazyLock = env_var!("MERCURY_AI_TOKEN"); struct GlobalMercuryApiKey(Entity); diff --git a/crates/edit_prediction_ui/src/edit_prediction_button.rs b/crates/edit_prediction_ui/src/edit_prediction_button.rs index dac4c812f8ac1377423f7044c1c250b5a5333f64..1a5e60ca8b27f31d26c6389bbd39a516164f3bf6 100644 --- a/crates/edit_prediction_ui/src/edit_prediction_button.rs +++ b/crates/edit_prediction_ui/src/edit_prediction_button.rs @@ -359,10 +359,16 @@ impl Render for EditPredictionButton { } EditPredictionProvider::Mercury => { ep_icon = if enabled { icons.base } else { icons.disabled }; + let mercury_has_error = + edit_prediction::EditPredictionStore::try_global(cx).is_some_and( + |ep_store| ep_store.read(cx).mercury_has_payment_required_error(), + ); missing_token = edit_prediction::EditPredictionStore::try_global(cx) .is_some_and(|ep_store| !ep_store.read(cx).has_mercury_api_token(cx)); tooltip_meta = if missing_token { "Missing API key for Mercury" + } else if mercury_has_error { + "Mercury free tier limit reached" } else { "Powered by Mercury" }; @@ -414,7 +420,12 @@ impl Render for EditPredictionButton { let show_editor_predictions = self.editor_show_predictions; let user = self.user_store.read(cx).current_user(); - let indicator_color = if missing_token { + let mercury_has_error = matches!(provider, EditPredictionProvider::Mercury) + && edit_prediction::EditPredictionStore::try_global(cx).is_some_and( + |ep_store| ep_store.read(cx).mercury_has_payment_required_error(), + ); + + let indicator_color = if missing_token || mercury_has_error { Some(Color::Error) } else if enabled && (!show_editor_predictions || over_limit) { Some(if over_limit { @@ -1096,96 +1107,116 @@ impl EditPredictionButton { }, ) .separator(); - } else if let Some(usage) = self - .edit_prediction_provider - .as_ref() - .and_then(|provider| provider.usage(cx)) - { - menu = menu.header("Usage"); - menu = menu - .custom_entry( - move |_window, cx| { - let used_percentage = match usage.limit { - UsageLimit::Limited(limit) => { - Some((usage.amount as f32 / limit as f32) * 100.) - } - UsageLimit::Unlimited => None, - }; + } else { + let mercury_payment_required = matches!(provider, EditPredictionProvider::Mercury) + && edit_prediction::EditPredictionStore::try_global(cx).is_some_and( + |ep_store| ep_store.read(cx).mercury_has_payment_required_error(), + ); + + if mercury_payment_required { + menu = menu + .header("Mercury") + .item(ContextMenuEntry::new("Free tier limit reached").disabled(true)) + .item( + ContextMenuEntry::new( + "Upgrade to a paid plan to continue using the service", + ) + .disabled(true), + ) + .separator(); + } + + if let Some(usage) = self + .edit_prediction_provider + .as_ref() + .and_then(|provider| provider.usage(cx)) + { + menu = menu.header("Usage"); + menu = menu + .custom_entry( + move |_window, cx| { + let used_percentage = match usage.limit { + UsageLimit::Limited(limit) => { + Some((usage.amount as f32 / limit as f32) * 100.) + } + UsageLimit::Unlimited => None, + }; - h_flex() - .flex_1() - .gap_1p5() - .children( - used_percentage.map(|percent| { + h_flex() + .flex_1() + .gap_1p5() + .children(used_percentage.map(|percent| { ProgressBar::new("usage", percent, 100., cx) - }), - ) - .child( - Label::new(match usage.limit { - UsageLimit::Limited(limit) => { - format!("{} / {limit}", usage.amount) - } - UsageLimit::Unlimited => format!("{} / ∞", usage.amount), - }) + })) + .child( + Label::new(match usage.limit { + UsageLimit::Limited(limit) => { + format!("{} / {limit}", usage.amount) + } + UsageLimit::Unlimited => { + format!("{} / ∞", usage.amount) + } + }) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .into_any_element() + }, + move |_, cx| cx.open_url(&zed_urls::account_url(cx)), + ) + .when(usage.over_limit(), |menu| -> ContextMenu { + menu.entry("Subscribe to increase your limit", None, |_window, cx| { + telemetry::event!( + "Edit Prediction Menu Action", + action = "upsell_clicked", + reason = "usage_limit", + ); + cx.open_url(&zed_urls::account_url(cx)) + }) + }) + .separator(); + } else if self.user_store.read(cx).account_too_young() { + menu = menu + .custom_entry( + |_window, _cx| { + Label::new("Your GitHub account is less than 30 days old.") .size(LabelSize::Small) - .color(Color::Muted), - ) - .into_any_element() - }, - move |_, cx| cx.open_url(&zed_urls::account_url(cx)), - ) - .when(usage.over_limit(), |menu| -> ContextMenu { - menu.entry("Subscribe to increase your limit", None, |_window, cx| { + .color(Color::Warning) + .into_any_element() + }, + |_window, cx| cx.open_url(&zed_urls::account_url(cx)), + ) + .entry("Upgrade to Zed Pro or contact us.", None, |_window, cx| { telemetry::event!( "Edit Prediction Menu Action", action = "upsell_clicked", - reason = "usage_limit", + reason = "account_age", ); cx.open_url(&zed_urls::account_url(cx)) }) - }) - .separator(); - } else if self.user_store.read(cx).account_too_young() { - menu = menu - .custom_entry( - |_window, _cx| { - Label::new("Your GitHub account is less than 30 days old.") - .size(LabelSize::Small) - .color(Color::Warning) - .into_any_element() - }, - |_window, cx| cx.open_url(&zed_urls::account_url(cx)), - ) - .entry("Upgrade to Zed Pro or contact us.", None, |_window, cx| { - telemetry::event!( - "Edit Prediction Menu Action", - action = "upsell_clicked", - reason = "account_age", - ); - cx.open_url(&zed_urls::account_url(cx)) - }) - .separator(); - } else if self.user_store.read(cx).has_overdue_invoices() { - menu = menu - .custom_entry( - |_window, _cx| { - Label::new("You have an outstanding invoice") - .size(LabelSize::Small) - .color(Color::Warning) - .into_any_element() - }, - |_window, cx| { - cx.open_url(&zed_urls::account_url(cx)) - }, - ) - .entry( - "Check your payment status or contact us at billing-support@zed.dev to continue using this feature.", - None, - |_window, cx| { - cx.open_url(&zed_urls::account_url(cx)) - }, - ) - .separator(); + .separator(); + } else if self.user_store.read(cx).has_overdue_invoices() { + menu = menu + .custom_entry( + |_window, _cx| { + Label::new("You have an outstanding invoice") + .size(LabelSize::Small) + .color(Color::Warning) + .into_any_element() + }, + |_window, cx| { + cx.open_url(&zed_urls::account_url(cx)) + }, + ) + .entry( + "Check your payment status or contact us at billing-support@zed.dev to continue using this feature.", + None, + |_window, cx| { + cx.open_url(&zed_urls::account_url(cx)) + }, + ) + .separator(); + } } if !needs_sign_in {