zeta: Extract usage information from response headers (#28999)

Marshall Bowers created

This PR updates the Zeta provider to extract the usage information from
the response headers, if they are present.

For now we just log the information, but we'll need to figure out where
this needs to get threaded through to in order to display it in the UI.

Release Notes:

- N/A

Change summary

Cargo.lock              |  5 ++-
Cargo.toml              |  2 
crates/zeta/src/zeta.rs | 50 +++++++++++++++++++++++++++++++++++++-----
3 files changed, 48 insertions(+), 9 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -18361,10 +18361,11 @@ dependencies = [
 
 [[package]]
 name = "zed_llm_client"
-version = "0.5.1"
+version = "0.6.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9ee4d410dbc030c3e6e3af78fc76296f6bebe20dcb6d7d3fa24bca306fc8c1ce"
+checksum = "b91b8b05f1028157205026e525869eb860fa89bec87ea60b445efc91d05df31f"
 dependencies = [
+ "anyhow",
  "serde",
  "serde_json",
  "strum 0.27.1",

Cargo.toml 🔗

@@ -604,7 +604,7 @@ wasmtime-wasi = "29"
 which = "6.0.0"
 wit-component = "0.221"
 workspace-hack = "0.1.0"
-zed_llm_client = "0.5.1"
+zed_llm_client = "0.6.0"
 zstd = "0.11"
 metal = "0.29"
 

crates/zeta/src/zeta.rs 🔗

@@ -8,6 +8,7 @@ mod rate_completion_modal;
 
 pub(crate) use completion_diff_element::*;
 use db::kvp::KEY_VALUE_STORE;
+use http_client::http::{HeaderMap, HeaderValue};
 pub use init::*;
 use inline_completion::DataCollectionState;
 use license_detection::LICENSE_FILES_TO_CHECK;
@@ -54,8 +55,9 @@ use workspace::Workspace;
 use workspace::notifications::{ErrorMessagePrompt, NotificationId};
 use worktree::Worktree;
 use zed_llm_client::{
+    EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME,
     EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsBody,
-    PredictEditsResponse,
+    PredictEditsResponse, UsageLimit,
 };
 
 const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>";
@@ -74,6 +76,32 @@ const MAX_EVENT_COUNT: usize = 16;
 
 actions!(edit_prediction, [ClearHistory]);
 
+#[derive(Debug, Clone, Copy)]
+pub struct Usage {
+    pub limit: UsageLimit,
+    pub amount: i32,
+}
+
+impl Usage {
+    pub fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
+        let limit = headers
+            .get(EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME)
+            .ok_or_else(|| {
+                anyhow!("missing {EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME:?} header")
+            })?;
+        let limit = UsageLimit::from_str(limit.to_str()?)?;
+
+        let amount = headers
+            .get(EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME)
+            .ok_or_else(|| {
+                anyhow!("missing {EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME:?} header")
+            })?;
+        let amount = amount.to_str()?.parse::<i32>()?;
+
+        Ok(Self { limit, amount })
+    }
+}
+
 #[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
 pub struct InlineCompletionId(Uuid);
 
@@ -359,7 +387,7 @@ impl Zeta {
     ) -> Task<Result<Option<InlineCompletion>>>
     where
         F: FnOnce(PerformPredictEditsParams) -> R + 'static,
-        R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
+        R: Future<Output = Result<(PredictEditsResponse, Option<Usage>)>> + Send + 'static,
     {
         let snapshot = self.report_changes_for_buffer(&buffer, cx);
         let diagnostic_groups = snapshot.diagnostic_groups(None);
@@ -467,7 +495,7 @@ impl Zeta {
                 body,
             })
             .await;
-            let response = match response {
+            let (response, usage) = match response {
                 Ok(response) => response,
                 Err(err) => {
                     if err.is::<ZedUpdateRequiredError>() {
@@ -503,6 +531,14 @@ impl Zeta {
 
             log::debug!("completion response: {}", &response.output_excerpt);
 
+            if let Some(usage) = usage {
+                let limit = match usage.limit {
+                    UsageLimit::Limited(limit) => limit.to_string(),
+                    UsageLimit::Unlimited => "unlimited".to_string(),
+                };
+                log::info!("edit prediction usage: {} / {}", usage.amount, limit);
+            }
+
             Self::process_completion_response(
                 response,
                 buffer,
@@ -685,7 +721,7 @@ and then another
         use std::future::ready;
 
         self.request_completion_impl(None, project, buffer, position, false, cx, |_params| {
-            ready(Ok(response))
+            ready(Ok((response, None)))
         })
     }
 
@@ -714,7 +750,7 @@ and then another
 
     fn perform_predict_edits(
         params: PerformPredictEditsParams,
-    ) -> impl Future<Output = Result<PredictEditsResponse>> {
+    ) -> impl Future<Output = Result<(PredictEditsResponse, Option<Usage>)>> {
         async move {
             let PerformPredictEditsParams {
                 client,
@@ -760,9 +796,11 @@ and then another
                 }
 
                 if response.status().is_success() {
+                    let usage = Usage::from_headers(response.headers()).ok();
+
                     let mut body = String::new();
                     response.body_mut().read_to_string(&mut body).await?;
-                    return Ok(serde_json::from_str(&body)?);
+                    return Ok((serde_json::from_str(&body)?, usage));
                 } else if !did_retry
                     && response
                         .headers()