@@ -1,19 +1,19 @@
use crate::{
DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
- EditPredictionStartedDebugEvent, open_ai_response::text_from_response,
+ EditPredictionStartedDebugEvent, EditPredictionStore, open_ai_response::text_from_response,
prediction::EditPredictionResult, zeta::compute_edits,
};
use anyhow::{Context as _, Result};
use cloud_llm_client::EditPredictionRejectReason;
use futures::AsyncReadExt as _;
use gpui::{
- App, AppContext as _, Entity, Global, SharedString, Task,
- http_client::{self, AsyncBody, HttpClient, Method},
+ App, AppContext as _, Context, Entity, Global, SharedString, Task,
+ http_client::{self, AsyncBody, HttpClient, Method, StatusCode},
};
use language::{ToOffset, ToPoint as _};
use language_model::{ApiKeyState, EnvVar, env_var};
use release_channel::AppVersion;
-use serde::Serialize;
+use serde::{Deserialize, Serialize};
use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant};
use zeta_prompt::ZetaPromptInput;
@@ -21,17 +21,27 @@ const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions"
pub struct Mercury {
pub api_token: Entity<ApiKeyState>,
+ payment_required_error: bool,
}
impl Mercury {
pub fn new(cx: &mut App) -> Self {
Mercury {
api_token: mercury_api_token(cx),
+ payment_required_error: false,
}
}
+ pub fn has_payment_required_error(&self) -> bool {
+ self.payment_required_error
+ }
+
+ pub fn set_payment_required_error(&mut self, payment_required_error: bool) {
+ self.payment_required_error = payment_required_error;
+ }
+
pub(crate) fn request_prediction(
- &self,
+ &mut self,
EditPredictionModelInput {
buffer,
snapshot,
@@ -41,7 +51,7 @@ impl Mercury {
debug_tx,
..
}: EditPredictionModelInput,
- cx: &mut App,
+ cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
self.api_token.update(cx, |key_state, cx| {
_ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx);
@@ -163,6 +173,12 @@ impl Mercury {
let response_received_at = Instant::now();
if !response.status().is_success() {
+ if response.status() == StatusCode::PAYMENT_REQUIRED {
+ anyhow::bail!(MercuryPaymentRequiredError(
+ mercury_payment_required_message(&body),
+ ));
+ }
+
anyhow::bail!(
"Request failed with status: {:?}\nBody: {}",
response.status(),
@@ -209,9 +225,22 @@ impl Mercury {
anyhow::Ok((id, edits, snapshot, response_received_at, inputs))
});
- cx.spawn(async move |cx| {
- let (id, edits, old_snapshot, response_received_at, inputs) =
- result.await.context("Mercury edit prediction failed")?;
+ cx.spawn(async move |ep_store, cx| {
+ let result = result.await.context("Mercury edit prediction failed");
+
+ let has_payment_required_error = result
+ .as_ref()
+ .err()
+ .is_some_and(is_mercury_payment_required_error);
+
+ ep_store.update(cx, |store, cx| {
+ store
+ .mercury
+ .set_payment_required_error(has_payment_required_error);
+ cx.notify();
+ })?;
+
+ let (id, edits, old_snapshot, response_received_at, inputs) = result?;
anyhow::Ok(Some(
EditPredictionResult::new(
EditPredictionId(id.into()),
@@ -315,6 +344,33 @@ fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce(
pub const MERCURY_CREDENTIALS_URL: SharedString =
SharedString::new_static("https://api.inceptionlabs.ai/v1/edit/completions");
pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token";
+
+#[derive(Debug, thiserror::Error)]
+#[error("{0}")]
+struct MercuryPaymentRequiredError(SharedString);
+
+#[derive(Deserialize)]
+struct MercuryErrorResponse {
+ error: MercuryErrorMessage,
+}
+
+#[derive(Deserialize)]
+struct MercuryErrorMessage {
+ message: String,
+}
+
+fn is_mercury_payment_required_error(error: &anyhow::Error) -> bool {
+ error
+ .downcast_ref::<MercuryPaymentRequiredError>()
+ .is_some()
+}
+
+fn mercury_payment_required_message(body: &[u8]) -> SharedString {
+ serde_json::from_slice::<MercuryErrorResponse>(body)
+ .map(|response| response.error.message.into())
+ .unwrap_or_else(|_| String::from_utf8_lossy(body).trim().to_string().into())
+}
+
pub static MERCURY_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("MERCURY_AI_TOKEN");
struct GlobalMercuryApiKey(Entity<ApiKeyState>);
@@ -359,10 +359,16 @@ impl Render for EditPredictionButton {
}
EditPredictionProvider::Mercury => {
ep_icon = if enabled { icons.base } else { icons.disabled };
+ let mercury_has_error =
+ edit_prediction::EditPredictionStore::try_global(cx).is_some_and(
+ |ep_store| ep_store.read(cx).mercury_has_payment_required_error(),
+ );
missing_token = edit_prediction::EditPredictionStore::try_global(cx)
.is_some_and(|ep_store| !ep_store.read(cx).has_mercury_api_token(cx));
tooltip_meta = if missing_token {
"Missing API key for Mercury"
+ } else if mercury_has_error {
+ "Mercury free tier limit reached"
} else {
"Powered by Mercury"
};
@@ -414,7 +420,12 @@ impl Render for EditPredictionButton {
let show_editor_predictions = self.editor_show_predictions;
let user = self.user_store.read(cx).current_user();
- let indicator_color = if missing_token {
+ let mercury_has_error = matches!(provider, EditPredictionProvider::Mercury)
+ && edit_prediction::EditPredictionStore::try_global(cx).is_some_and(
+ |ep_store| ep_store.read(cx).mercury_has_payment_required_error(),
+ );
+
+ let indicator_color = if missing_token || mercury_has_error {
Some(Color::Error)
} else if enabled && (!show_editor_predictions || over_limit) {
Some(if over_limit {
@@ -1096,96 +1107,116 @@ impl EditPredictionButton {
},
)
.separator();
- } else if let Some(usage) = self
- .edit_prediction_provider
- .as_ref()
- .and_then(|provider| provider.usage(cx))
- {
- menu = menu.header("Usage");
- menu = menu
- .custom_entry(
- move |_window, cx| {
- let used_percentage = match usage.limit {
- UsageLimit::Limited(limit) => {
- Some((usage.amount as f32 / limit as f32) * 100.)
- }
- UsageLimit::Unlimited => None,
- };
+ } else {
+ let mercury_payment_required = matches!(provider, EditPredictionProvider::Mercury)
+ && edit_prediction::EditPredictionStore::try_global(cx).is_some_and(
+ |ep_store| ep_store.read(cx).mercury_has_payment_required_error(),
+ );
+
+ if mercury_payment_required {
+ menu = menu
+ .header("Mercury")
+ .item(ContextMenuEntry::new("Free tier limit reached").disabled(true))
+ .item(
+ ContextMenuEntry::new(
+ "Upgrade to a paid plan to continue using the service",
+ )
+ .disabled(true),
+ )
+ .separator();
+ }
+
+ if let Some(usage) = self
+ .edit_prediction_provider
+ .as_ref()
+ .and_then(|provider| provider.usage(cx))
+ {
+ menu = menu.header("Usage");
+ menu = menu
+ .custom_entry(
+ move |_window, cx| {
+ let used_percentage = match usage.limit {
+ UsageLimit::Limited(limit) => {
+ Some((usage.amount as f32 / limit as f32) * 100.)
+ }
+ UsageLimit::Unlimited => None,
+ };
- h_flex()
- .flex_1()
- .gap_1p5()
- .children(
- used_percentage.map(|percent| {
+ h_flex()
+ .flex_1()
+ .gap_1p5()
+ .children(used_percentage.map(|percent| {
ProgressBar::new("usage", percent, 100., cx)
- }),
- )
- .child(
- Label::new(match usage.limit {
- UsageLimit::Limited(limit) => {
- format!("{} / {limit}", usage.amount)
- }
- UsageLimit::Unlimited => format!("{} / ∞", usage.amount),
- })
+ }))
+ .child(
+ Label::new(match usage.limit {
+ UsageLimit::Limited(limit) => {
+ format!("{} / {limit}", usage.amount)
+ }
+ UsageLimit::Unlimited => {
+ format!("{} / ∞", usage.amount)
+ }
+ })
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .into_any_element()
+ },
+ move |_, cx| cx.open_url(&zed_urls::account_url(cx)),
+ )
+ .when(usage.over_limit(), |menu| -> ContextMenu {
+ menu.entry("Subscribe to increase your limit", None, |_window, cx| {
+ telemetry::event!(
+ "Edit Prediction Menu Action",
+ action = "upsell_clicked",
+ reason = "usage_limit",
+ );
+ cx.open_url(&zed_urls::account_url(cx))
+ })
+ })
+ .separator();
+ } else if self.user_store.read(cx).account_too_young() {
+ menu = menu
+ .custom_entry(
+ |_window, _cx| {
+ Label::new("Your GitHub account is less than 30 days old.")
.size(LabelSize::Small)
- .color(Color::Muted),
- )
- .into_any_element()
- },
- move |_, cx| cx.open_url(&zed_urls::account_url(cx)),
- )
- .when(usage.over_limit(), |menu| -> ContextMenu {
- menu.entry("Subscribe to increase your limit", None, |_window, cx| {
+ .color(Color::Warning)
+ .into_any_element()
+ },
+ |_window, cx| cx.open_url(&zed_urls::account_url(cx)),
+ )
+ .entry("Upgrade to Zed Pro or contact us.", None, |_window, cx| {
telemetry::event!(
"Edit Prediction Menu Action",
action = "upsell_clicked",
- reason = "usage_limit",
+ reason = "account_age",
);
cx.open_url(&zed_urls::account_url(cx))
})
- })
- .separator();
- } else if self.user_store.read(cx).account_too_young() {
- menu = menu
- .custom_entry(
- |_window, _cx| {
- Label::new("Your GitHub account is less than 30 days old.")
- .size(LabelSize::Small)
- .color(Color::Warning)
- .into_any_element()
- },
- |_window, cx| cx.open_url(&zed_urls::account_url(cx)),
- )
- .entry("Upgrade to Zed Pro or contact us.", None, |_window, cx| {
- telemetry::event!(
- "Edit Prediction Menu Action",
- action = "upsell_clicked",
- reason = "account_age",
- );
- cx.open_url(&zed_urls::account_url(cx))
- })
- .separator();
- } else if self.user_store.read(cx).has_overdue_invoices() {
- menu = menu
- .custom_entry(
- |_window, _cx| {
- Label::new("You have an outstanding invoice")
- .size(LabelSize::Small)
- .color(Color::Warning)
- .into_any_element()
- },
- |_window, cx| {
- cx.open_url(&zed_urls::account_url(cx))
- },
- )
- .entry(
- "Check your payment status or contact us at billing-support@zed.dev to continue using this feature.",
- None,
- |_window, cx| {
- cx.open_url(&zed_urls::account_url(cx))
- },
- )
- .separator();
+ .separator();
+ } else if self.user_store.read(cx).has_overdue_invoices() {
+ menu = menu
+ .custom_entry(
+ |_window, _cx| {
+ Label::new("You have an outstanding invoice")
+ .size(LabelSize::Small)
+ .color(Color::Warning)
+ .into_any_element()
+ },
+ |_window, cx| {
+ cx.open_url(&zed_urls::account_url(cx))
+ },
+ )
+ .entry(
+ "Check your payment status or contact us at billing-support@zed.dev to continue using this feature.",
+ None,
+ |_window, cx| {
+ cx.open_url(&zed_urls::account_url(cx))
+ },
+ )
+ .separator();
+ }
}
if !needs_sign_in {