ep: Error indication when Mercury free tier limit reached (#51447)

Ben Kunkle created

Release Notes:

- Added an error indicator in the edit prediction menu with an error
message when the free tier limit is exceeded

Change summary

crates/edit_prediction/src/edit_prediction.rs           |   4 
crates/edit_prediction/src/mercury.rs                   |  74 +++
crates/edit_prediction_ui/src/edit_prediction_button.rs | 195 ++++++----
3 files changed, 182 insertions(+), 91 deletions(-)

Detailed changes

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -967,6 +967,10 @@ impl EditPredictionStore {
         self.mercury.api_token.read(cx).has_key()
     }
 
+    pub fn mercury_has_payment_required_error(&self) -> bool {
+        self.mercury.has_payment_required_error()
+    }
+
     pub fn clear_history(&mut self) {
         for project_state in self.projects.values_mut() {
             project_state.events.clear();

crates/edit_prediction/src/mercury.rs 🔗

@@ -1,19 +1,19 @@
 use crate::{
     DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
-    EditPredictionStartedDebugEvent, open_ai_response::text_from_response,
+    EditPredictionStartedDebugEvent, EditPredictionStore, open_ai_response::text_from_response,
     prediction::EditPredictionResult, zeta::compute_edits,
 };
 use anyhow::{Context as _, Result};
 use cloud_llm_client::EditPredictionRejectReason;
 use futures::AsyncReadExt as _;
 use gpui::{
-    App, AppContext as _, Entity, Global, SharedString, Task,
-    http_client::{self, AsyncBody, HttpClient, Method},
+    App, AppContext as _, Context, Entity, Global, SharedString, Task,
+    http_client::{self, AsyncBody, HttpClient, Method, StatusCode},
 };
 use language::{ToOffset, ToPoint as _};
 use language_model::{ApiKeyState, EnvVar, env_var};
 use release_channel::AppVersion;
-use serde::Serialize;
+use serde::{Deserialize, Serialize};
 use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant};
 use zeta_prompt::ZetaPromptInput;
 
@@ -21,17 +21,27 @@ const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions"
 
 pub struct Mercury {
     pub api_token: Entity<ApiKeyState>,
+    payment_required_error: bool,
 }
 
 impl Mercury {
     pub fn new(cx: &mut App) -> Self {
         Mercury {
             api_token: mercury_api_token(cx),
+            payment_required_error: false,
         }
     }
 
+    pub fn has_payment_required_error(&self) -> bool {
+        self.payment_required_error
+    }
+
+    pub fn set_payment_required_error(&mut self, payment_required_error: bool) {
+        self.payment_required_error = payment_required_error;
+    }
+
     pub(crate) fn request_prediction(
-        &self,
+        &mut self,
         EditPredictionModelInput {
             buffer,
             snapshot,
@@ -41,7 +51,7 @@ impl Mercury {
             debug_tx,
             ..
         }: EditPredictionModelInput,
-        cx: &mut App,
+        cx: &mut Context<EditPredictionStore>,
     ) -> Task<Result<Option<EditPredictionResult>>> {
         self.api_token.update(cx, |key_state, cx| {
             _ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx);
@@ -163,6 +173,12 @@ impl Mercury {
 
             let response_received_at = Instant::now();
             if !response.status().is_success() {
+                if response.status() == StatusCode::PAYMENT_REQUIRED {
+                    anyhow::bail!(MercuryPaymentRequiredError(
+                        mercury_payment_required_message(&body),
+                    ));
+                }
+
                 anyhow::bail!(
                     "Request failed with status: {:?}\nBody: {}",
                     response.status(),
@@ -209,9 +225,22 @@ impl Mercury {
             anyhow::Ok((id, edits, snapshot, response_received_at, inputs))
         });
 
-        cx.spawn(async move |cx| {
-            let (id, edits, old_snapshot, response_received_at, inputs) =
-                result.await.context("Mercury edit prediction failed")?;
+        cx.spawn(async move |ep_store, cx| {
+            let result = result.await.context("Mercury edit prediction failed");
+
+            let has_payment_required_error = result
+                .as_ref()
+                .err()
+                .is_some_and(is_mercury_payment_required_error);
+
+            ep_store.update(cx, |store, cx| {
+                store
+                    .mercury
+                    .set_payment_required_error(has_payment_required_error);
+                cx.notify();
+            })?;
+
+            let (id, edits, old_snapshot, response_received_at, inputs) = result?;
             anyhow::Ok(Some(
                 EditPredictionResult::new(
                     EditPredictionId(id.into()),
@@ -315,6 +344,33 @@ fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce(
 pub const MERCURY_CREDENTIALS_URL: SharedString =
     SharedString::new_static("https://api.inceptionlabs.ai/v1/edit/completions");
 pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token";
+
+#[derive(Debug, thiserror::Error)]
+#[error("{0}")]
+struct MercuryPaymentRequiredError(SharedString);
+
+#[derive(Deserialize)]
+struct MercuryErrorResponse {
+    error: MercuryErrorMessage,
+}
+
+#[derive(Deserialize)]
+struct MercuryErrorMessage {
+    message: String,
+}
+
+fn is_mercury_payment_required_error(error: &anyhow::Error) -> bool {
+    error
+        .downcast_ref::<MercuryPaymentRequiredError>()
+        .is_some()
+}
+
+fn mercury_payment_required_message(body: &[u8]) -> SharedString {
+    serde_json::from_slice::<MercuryErrorResponse>(body)
+        .map(|response| response.error.message.into())
+        .unwrap_or_else(|_| String::from_utf8_lossy(body).trim().to_string().into())
+}
+
 pub static MERCURY_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("MERCURY_AI_TOKEN");
 
 struct GlobalMercuryApiKey(Entity<ApiKeyState>);

crates/edit_prediction_ui/src/edit_prediction_button.rs 🔗

@@ -359,10 +359,16 @@ impl Render for EditPredictionButton {
                     }
                     EditPredictionProvider::Mercury => {
                         ep_icon = if enabled { icons.base } else { icons.disabled };
+                        let mercury_has_error =
+                            edit_prediction::EditPredictionStore::try_global(cx).is_some_and(
+                                |ep_store| ep_store.read(cx).mercury_has_payment_required_error(),
+                            );
                         missing_token = edit_prediction::EditPredictionStore::try_global(cx)
                             .is_some_and(|ep_store| !ep_store.read(cx).has_mercury_api_token(cx));
                         tooltip_meta = if missing_token {
                             "Missing API key for Mercury"
+                        } else if mercury_has_error {
+                            "Mercury free tier limit reached"
                         } else {
                             "Powered by Mercury"
                         };
@@ -414,7 +420,12 @@ impl Render for EditPredictionButton {
                 let show_editor_predictions = self.editor_show_predictions;
                 let user = self.user_store.read(cx).current_user();
 
-                let indicator_color = if missing_token {
+                let mercury_has_error = matches!(provider, EditPredictionProvider::Mercury)
+                    && edit_prediction::EditPredictionStore::try_global(cx).is_some_and(
+                        |ep_store| ep_store.read(cx).mercury_has_payment_required_error(),
+                    );
+
+                let indicator_color = if missing_token || mercury_has_error {
                     Some(Color::Error)
                 } else if enabled && (!show_editor_predictions || over_limit) {
                     Some(if over_limit {
@@ -1096,96 +1107,116 @@ impl EditPredictionButton {
                         },
                     )
                     .separator();
-            } else if let Some(usage) = self
-                .edit_prediction_provider
-                .as_ref()
-                .and_then(|provider| provider.usage(cx))
-            {
-                menu = menu.header("Usage");
-                menu = menu
-                    .custom_entry(
-                        move |_window, cx| {
-                            let used_percentage = match usage.limit {
-                                UsageLimit::Limited(limit) => {
-                                    Some((usage.amount as f32 / limit as f32) * 100.)
-                                }
-                                UsageLimit::Unlimited => None,
-                            };
+            } else {
+                let mercury_payment_required = matches!(provider, EditPredictionProvider::Mercury)
+                    && edit_prediction::EditPredictionStore::try_global(cx).is_some_and(
+                        |ep_store| ep_store.read(cx).mercury_has_payment_required_error(),
+                    );
+
+                if mercury_payment_required {
+                    menu = menu
+                        .header("Mercury")
+                        .item(ContextMenuEntry::new("Free tier limit reached").disabled(true))
+                        .item(
+                            ContextMenuEntry::new(
+                                "Upgrade to a paid plan to continue using the service",
+                            )
+                            .disabled(true),
+                        )
+                        .separator();
+                }
+
+                if let Some(usage) = self
+                    .edit_prediction_provider
+                    .as_ref()
+                    .and_then(|provider| provider.usage(cx))
+                {
+                    menu = menu.header("Usage");
+                    menu = menu
+                        .custom_entry(
+                            move |_window, cx| {
+                                let used_percentage = match usage.limit {
+                                    UsageLimit::Limited(limit) => {
+                                        Some((usage.amount as f32 / limit as f32) * 100.)
+                                    }
+                                    UsageLimit::Unlimited => None,
+                                };
 
-                            h_flex()
-                                .flex_1()
-                                .gap_1p5()
-                                .children(
-                                    used_percentage.map(|percent| {
+                                h_flex()
+                                    .flex_1()
+                                    .gap_1p5()
+                                    .children(used_percentage.map(|percent| {
                                         ProgressBar::new("usage", percent, 100., cx)
-                                    }),
-                                )
-                                .child(
-                                    Label::new(match usage.limit {
-                                        UsageLimit::Limited(limit) => {
-                                            format!("{} / {limit}", usage.amount)
-                                        }
-                                        UsageLimit::Unlimited => format!("{} / ∞", usage.amount),
-                                    })
+                                    }))
+                                    .child(
+                                        Label::new(match usage.limit {
+                                            UsageLimit::Limited(limit) => {
+                                                format!("{} / {limit}", usage.amount)
+                                            }
+                                            UsageLimit::Unlimited => {
+                                                format!("{} / ∞", usage.amount)
+                                            }
+                                        })
+                                        .size(LabelSize::Small)
+                                        .color(Color::Muted),
+                                    )
+                                    .into_any_element()
+                            },
+                            move |_, cx| cx.open_url(&zed_urls::account_url(cx)),
+                        )
+                        .when(usage.over_limit(), |menu| -> ContextMenu {
+                            menu.entry("Subscribe to increase your limit", None, |_window, cx| {
+                                telemetry::event!(
+                                    "Edit Prediction Menu Action",
+                                    action = "upsell_clicked",
+                                    reason = "usage_limit",
+                                );
+                                cx.open_url(&zed_urls::account_url(cx))
+                            })
+                        })
+                        .separator();
+                } else if self.user_store.read(cx).account_too_young() {
+                    menu = menu
+                        .custom_entry(
+                            |_window, _cx| {
+                                Label::new("Your GitHub account is less than 30 days old.")
                                     .size(LabelSize::Small)
-                                    .color(Color::Muted),
-                                )
-                                .into_any_element()
-                        },
-                        move |_, cx| cx.open_url(&zed_urls::account_url(cx)),
-                    )
-                    .when(usage.over_limit(), |menu| -> ContextMenu {
-                        menu.entry("Subscribe to increase your limit", None, |_window, cx| {
+                                    .color(Color::Warning)
+                                    .into_any_element()
+                            },
+                            |_window, cx| cx.open_url(&zed_urls::account_url(cx)),
+                        )
+                        .entry("Upgrade to Zed Pro or contact us.", None, |_window, cx| {
                             telemetry::event!(
                                 "Edit Prediction Menu Action",
                                 action = "upsell_clicked",
-                                reason = "usage_limit",
+                                reason = "account_age",
                             );
                             cx.open_url(&zed_urls::account_url(cx))
                         })
-                    })
-                    .separator();
-            } else if self.user_store.read(cx).account_too_young() {
-                menu = menu
-                    .custom_entry(
-                        |_window, _cx| {
-                            Label::new("Your GitHub account is less than 30 days old.")
-                                .size(LabelSize::Small)
-                                .color(Color::Warning)
-                                .into_any_element()
-                        },
-                        |_window, cx| cx.open_url(&zed_urls::account_url(cx)),
-                    )
-                    .entry("Upgrade to Zed Pro or contact us.", None, |_window, cx| {
-                        telemetry::event!(
-                            "Edit Prediction Menu Action",
-                            action = "upsell_clicked",
-                            reason = "account_age",
-                        );
-                        cx.open_url(&zed_urls::account_url(cx))
-                    })
-                    .separator();
-            } else if self.user_store.read(cx).has_overdue_invoices() {
-                menu = menu
-                    .custom_entry(
-                        |_window, _cx| {
-                            Label::new("You have an outstanding invoice")
-                                .size(LabelSize::Small)
-                                .color(Color::Warning)
-                                .into_any_element()
-                        },
-                        |_window, cx| {
-                            cx.open_url(&zed_urls::account_url(cx))
-                        },
-                    )
-                    .entry(
-                        "Check your payment status or contact us at billing-support@zed.dev to continue using this feature.",
-                        None,
-                        |_window, cx| {
-                            cx.open_url(&zed_urls::account_url(cx))
-                        },
-                    )
-                    .separator();
+                        .separator();
+                } else if self.user_store.read(cx).has_overdue_invoices() {
+                    menu = menu
+                        .custom_entry(
+                            |_window, _cx| {
+                                Label::new("You have an outstanding invoice")
+                                    .size(LabelSize::Small)
+                                    .color(Color::Warning)
+                                    .into_any_element()
+                            },
+                            |_window, cx| {
+                                cx.open_url(&zed_urls::account_url(cx))
+                            },
+                        )
+                        .entry(
+                            "Check your payment status or contact us at billing-support@zed.dev to continue using this feature.",
+                            None,
+                            |_window, cx| {
+                                cx.open_url(&zed_urls::account_url(cx))
+                            },
+                        )
+                        .separator();
+                }
             }
 
             if !needs_sign_in {