From 5f8efc937006dfce95edab0c57573093a308c4d5 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Thu, 17 Apr 2025 15:07:40 -0400 Subject: [PATCH] zeta: Extract usage information from response headers (#28999) This PR updates the Zeta provider to extract the usage information from the response headers, if they are present. For now we just log the information, but we'll need to figure out where this needs to get threaded through to in order to display it in the UI. Release Notes: - N/A --- Cargo.lock | 5 +++-- Cargo.toml | 2 +- crates/zeta/src/zeta.rs | 50 ++++++++++++++++++++++++++++++++++++----- 3 files changed, 48 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index be58cbea8acd5e0ada4fddba002b5347334a26de..5b837fb446d446440806626f0821cb735de95123 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -18388,10 +18388,11 @@ dependencies = [ [[package]] name = "zed_llm_client" -version = "0.5.1" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ee4d410dbc030c3e6e3af78fc76296f6bebe20dcb6d7d3fa24bca306fc8c1ce" +checksum = "b91b8b05f1028157205026e525869eb860fa89bec87ea60b445efc91d05df31f" dependencies = [ + "anyhow", "serde", "serde_json", "strum 0.27.1", diff --git a/Cargo.toml b/Cargo.toml index 045a627024d0f54b34b723b497858bdf0856d9a7..edfe44483ff5a556fbf920c1030cff8e88df6cab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -605,7 +605,7 @@ wasmtime-wasi = "29" which = "6.0.0" wit-component = "0.221" workspace-hack = "0.1.0" -zed_llm_client = "0.5.1" +zed_llm_client = "0.6.0" zstd = "0.11" metal = "0.29" diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 14aa2820cb3f3c6d7f43d3dda3765aa08dda5b8a..b5367816fd9d433824bf6dbd4356be9b08645b69 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -8,6 +8,7 @@ mod rate_completion_modal; pub(crate) use completion_diff_element::*; use db::kvp::KEY_VALUE_STORE; +use http_client::http::{HeaderMap, HeaderValue}; pub use init::*; use inline_completion::DataCollectionState; use license_detection::LICENSE_FILES_TO_CHECK; @@ -54,8 +55,9 @@ use workspace::Workspace; use workspace::notifications::{ErrorMessagePrompt, NotificationId}; use worktree::Worktree; use zed_llm_client::{ + EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsBody, - PredictEditsResponse, + PredictEditsResponse, UsageLimit, }; const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>"; @@ -74,6 +76,32 @@ const MAX_EVENT_COUNT: usize = 16; actions!(edit_prediction, [ClearHistory]); +#[derive(Debug, Clone, Copy)] +pub struct Usage { + pub limit: UsageLimit, + pub amount: i32, +} + +impl Usage { + pub fn from_headers(headers: &HeaderMap) -> Result { + let limit = headers + .get(EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME) + .ok_or_else(|| { + anyhow!("missing {EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME:?} header") + })?; + let limit = UsageLimit::from_str(limit.to_str()?)?; + + let amount = headers + .get(EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME) + .ok_or_else(|| { + anyhow!("missing {EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME:?} header") + })?; + let amount = amount.to_str()?.parse::()?; + + Ok(Self { limit, amount }) + } +} + #[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] pub struct InlineCompletionId(Uuid); @@ -359,7 +387,7 @@ impl Zeta { ) -> Task>> where F: FnOnce(PerformPredictEditsParams) -> R + 'static, - R: Future> + Send + 'static, + R: Future)>> + Send + 'static, { let snapshot = self.report_changes_for_buffer(&buffer, cx); let diagnostic_groups = snapshot.diagnostic_groups(None); @@ -467,7 +495,7 @@ impl Zeta { body, }) .await; - let response = match response { + let (response, usage) = match response { Ok(response) => response, Err(err) => { if err.is::() { @@ -503,6 +531,14 @@ impl Zeta { log::debug!("completion response: {}", &response.output_excerpt); + if let Some(usage) = usage { + let limit = match usage.limit { + UsageLimit::Limited(limit) => limit.to_string(), + UsageLimit::Unlimited => "unlimited".to_string(), + }; + log::info!("edit prediction usage: {} / {}", usage.amount, limit); + } + Self::process_completion_response( response, buffer, @@ -685,7 +721,7 @@ and then another use std::future::ready; self.request_completion_impl(None, project, buffer, position, false, cx, |_params| { - ready(Ok(response)) + ready(Ok((response, None))) }) } @@ -714,7 +750,7 @@ and then another fn perform_predict_edits( params: PerformPredictEditsParams, - ) -> impl Future> { + ) -> impl Future)>> { async move { let PerformPredictEditsParams { client, @@ -760,9 +796,11 @@ and then another } if response.status().is_success() { + let usage = Usage::from_headers(response.headers()).ok(); + let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; - return Ok(serde_json::from_str(&body)?); + return Ok((serde_json::from_str(&body)?, usage)); } else if !did_retry && response .headers()