Meter edit predictions by acceptance in free plan (#30984)

Created by Max Brunsfeld, Mikayla Maki, Antonio Scandurra, Ben Brandt, and Marshall Bowers

TODO:

- [x] Release a new version of `zed_llm_client`

Release Notes:

- N/A
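
For context: accepting an edit prediction now reports the acceptance back to the zed.dev LLM service, so free-plan usage can be metered by acceptance. Below is a minimal sketch of the request involved; the actual call site is `Zeta::accept_edit_prediction` in the `zeta.rs` diff further down, which additionally handles LLM-token refresh, the `ZED_ACCEPT_PREDICTION_URL` override, and the minimum-version header.

```rust
use std::sync::Arc;

use anyhow::Result;
use client::Client;
use http_client::{AsyncBody, Method, Request};
use zed_llm_client::AcceptEditPredictionBody;

// Sketch only; not the code added by this PR (see crates/zeta/src/zeta.rs below).
fn build_accept_request(
    client: &Arc<Client>,
    token: &str,
    body: &AcceptEditPredictionBody,
) -> Result<Request<AsyncBody>> {
    let http_client = client.http_client();
    Ok(Request::builder()
        .method(Method::POST)
        .uri(
            http_client
                .build_zed_llm_url("/predict_edits/accept", &[])?
                .as_ref(),
        )
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {token}"))
        .body(serde_json::to_string(body)?.into())?)
}
```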

---------

Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Marshall Bowers <git@maxdeviant.com>

Change summary

Cargo.lock              |   4 
Cargo.toml              |   2 
crates/zeta/src/zeta.rs | 117 +++++++++++++++++++++++++++++++++++++++++-
3 files changed, 115 insertions(+), 8 deletions(-)

Detailed changes

Cargo.lock

@@ -19847,9 +19847,9 @@ dependencies = [
 
 [[package]]
 name = "zed_llm_client"
-version = "0.8.1"
+version = "0.8.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "16d993fc42f9ec43ab76fa46c6eb579a66e116bb08cd2bc9a67f3afcaa05d39d"
+checksum = "9be71e2f9b271e1eb8eb3e0d986075e770d1a0a299fb036abc3f1fc13a2fa7eb"
 dependencies = [
  "anyhow",
  "serde",

Cargo.toml

@@ -615,7 +615,7 @@ wasmtime-wasi = "29"
 which = "6.0.0"
 wit-component = "0.221"
 workspace-hack = "0.1.0"
-zed_llm_client = "0.8.1"
+zed_llm_client = "0.8.2"
 zstd = "0.11"
 
 [workspace.dependencies.async-stripe]

crates/zeta/src/zeta.rs

@@ -14,7 +14,7 @@ use license_detection::LICENSE_FILES_TO_CHECK;
 pub use license_detection::is_license_eligible_for_data_collection;
 pub use rate_completion_modal::*;
 
-use anyhow::{Context as _, Result};
+use anyhow::{Context as _, Result, anyhow};
 use arrayvec::ArrayVec;
 use client::{Client, UserStore};
 use collections::{HashMap, HashSet, VecDeque};
@@ -23,7 +23,7 @@ use gpui::{
     App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
     Subscription, Task, WeakEntity, actions,
 };
-use http_client::{HttpClient, Method};
+use http_client::{AsyncBody, HttpClient, Method, Request, Response};
 use input_excerpt::excerpt_for_cursor_position;
 use language::{
     Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
@@ -54,8 +54,8 @@ use workspace::Workspace;
 use workspace::notifications::{ErrorMessagePrompt, NotificationId};
 use worktree::Worktree;
 use zed_llm_client::{
-    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsBody,
-    PredictEditsResponse, ZED_VERSION_HEADER_NAME,
+    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
+    PredictEditsBody, PredictEditsResponse, ZED_VERSION_HEADER_NAME,
 };
 
 const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>";
@@ -823,6 +823,74 @@ and then another
         }
     }
 
+    fn accept_edit_prediction(
+        &mut self,
+        request_id: InlineCompletionId,
+        cx: &mut Context<Self>,
+    ) -> Task<Result<()>> {
+        let client = self.client.clone();
+        let llm_token = self.llm_token.clone();
+        let app_version = AppVersion::global(cx);
+        cx.spawn(async move |this, cx| {
+            let http_client = client.http_client();
+            let mut response = llm_token_retry(&llm_token, &client, |token| {
+                let request_builder = http_client::Request::builder().method(Method::POST);
+                let request_builder =
+                    if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
+                        request_builder.uri(accept_prediction_url)
+                    } else {
+                        request_builder.uri(
+                            http_client
+                                .build_zed_llm_url("/predict_edits/accept", &[])?
+                                .as_ref(),
+                        )
+                    };
+                Ok(request_builder
+                    .header("Content-Type", "application/json")
+                    .header("Authorization", format!("Bearer {}", token))
+                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
+                    .body(
+                        serde_json::to_string(&AcceptEditPredictionBody {
+                            request_id: request_id.0,
+                        })?
+                        .into(),
+                    )?)
+            })
+            .await?;
+
+            if let Some(minimum_required_version) = response
+                .headers()
+                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
+                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
+            {
+                if app_version < minimum_required_version {
+                    return Err(anyhow!(ZedUpdateRequiredError {
+                        minimum_version: minimum_required_version
+                    }));
+                }
+            }
+
+            if response.status().is_success() {
+                if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
+                    this.update(cx, |this, cx| {
+                        this.last_usage = Some(usage);
+                        cx.notify();
+                    })?;
+                }
+
+                Ok(())
+            } else {
+                let mut body = String::new();
+                response.body_mut().read_to_string(&mut body).await?;
+                Err(anyhow!(
+                    "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
+                    response.status(),
+                    body
+                ))
+            }
+        })
+    }
+
     fn process_completion_response(
         prediction_response: PredictEditsResponse,
         buffer: Entity<Buffer>,
@@ -1381,6 +1449,34 @@ impl ProviderDataCollection {
     }
 }
 
+async fn llm_token_retry(
+    llm_token: &LlmApiToken,
+    client: &Arc<Client>,
+    build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
+) -> Result<Response<AsyncBody>> {
+    let mut did_retry = false;
+    let http_client = client.http_client();
+    let mut token = llm_token.acquire(client).await?;
+    loop {
+        let request = build_request(token.clone())?;
+        let response = http_client.send(request).await?;
+
+        if !did_retry
+            && !response.status().is_success()
+            && response
+                .headers()
+                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
+                .is_some()
+        {
+            did_retry = true;
+            token = llm_token.refresh(client).await?;
+            continue;
+        }
+
+        return Ok(response);
+    }
+}
+
 pub struct ZetaInlineCompletionProvider {
     zeta: Entity<Zeta>,
     pending_completions: ArrayVec<PendingCompletion, 2>,
@@ -1597,7 +1693,18 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider
         // Right now we don't support cycling.
     }
 
-    fn accept(&mut self, _cx: &mut Context<Self>) {
+    fn accept(&mut self, cx: &mut Context<Self>) {
+        let completion_id = self
+            .current_completion
+            .as_ref()
+            .map(|completion| completion.completion.id);
+        if let Some(completion_id) = completion_id {
+            self.zeta
+                .update(cx, |zeta, cx| {
+                    zeta.accept_edit_prediction(completion_id, cx)
+                })
+                .detach();
+        }
         self.pending_completions.clear();
     }
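
As a side note for reviewers, the new `llm_token_retry` helper is generic over the request builder, so other zed.dev LLM endpoints could reuse the expired-token retry. A hedged sketch of a hypothetical caller follows; the `/example` path and `post_example` name are made up for illustration, and only `accept_edit_prediction` uses the helper in this PR.

```rust
// Hypothetical caller inside crates/zeta; not part of this PR. The helper
// sends the built request and retries once with a refreshed token when the
// response carries EXPIRED_LLM_TOKEN_HEADER_NAME.
async fn post_example(client: Arc<Client>, llm_token: LlmApiToken) -> Result<Response<AsyncBody>> {
    let http_client = client.http_client();
    llm_token_retry(&llm_token, &client, |token| {
        Ok(http_client::Request::builder()
            .method(Method::POST)
            // Placeholder route; substitute a real zed.dev LLM endpoint.
            .uri(http_client.build_zed_llm_url("/example", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(String::new().into())?)
    })
    .await
}
```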