Cargo.lock 🔗
@@ -21791,6 +21791,7 @@ dependencies = [
  "pretty_assertions",
  "project",
  "release_channel",
+ "serde",
  "serde_json",
  "settings",
  "thiserror 2.0.12",
  Agus Zubiaga created
Release Notes:
- N/A
  
  
  
Cargo.lock                            |   1 
crates/http_client/src/http_client.rs |   3 
crates/zeta2/Cargo.toml               |   1 
crates/zeta2/src/prediction.rs        |   6 
crates/zeta2/src/provider.rs          |   4 
crates/zeta2/src/zeta2.rs             | 211 ++++++++++++++++++----------
6 files changed, 148 insertions(+), 78 deletions(-)
@@ -21791,6 +21791,7 @@ dependencies = [
  "pretty_assertions",
  "project",
  "release_channel",
+ "serde",
  "serde_json",
  "settings",
  "thiserror 2.0.12",
  @@ -6,13 +6,12 @@ pub use anyhow::{Result, anyhow};
 pub use async_body::{AsyncBody, Inner};
 use derive_more::Deref;
 use http::HeaderValue;
-pub use http::{self, Method, Request, Response, StatusCode, Uri};
+pub use http::{self, Method, Request, Response, StatusCode, Uri, request::Builder};
 
 use futures::{
     FutureExt as _,
     future::{self, BoxFuture},
 };
-use http::request::Builder;
 use parking_lot::Mutex;
 #[cfg(feature = "test-support")]
 use std::fmt;
  @@ -28,6 +28,7 @@ language_model.workspace = true
 log.workspace = true
 project.workspace = true
 release_channel.workspace = true
+serde.workspace = true
 serde_json.workspace = true
 thiserror.workspace = true
 util.workspace = true
  @@ -13,6 +13,12 @@ use uuid::Uuid;
 #[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
 pub struct EditPredictionId(Uuid);
 
+impl Into<Uuid> for EditPredictionId {
+    fn into(self) -> Uuid {
+        self.0
+    }
+}
+
 impl From<EditPredictionId> for gpui::ElementId {
     fn from(value: EditPredictionId) -> Self {
         gpui::ElementId::Uuid(value.0)
  @@ -179,8 +179,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
     }
 
     fn accept(&mut self, cx: &mut Context<Self>) {
-        self.zeta.update(cx, |zeta, _cx| {
-            zeta.accept_current_prediction(&self.project);
+        self.zeta.update(cx, |zeta, cx| {
+            zeta.accept_current_prediction(&self.project, cx);
         });
         self.pending_predictions.clear();
     }
  @@ -3,7 +3,8 @@ use chrono::TimeDelta;
 use client::{Client, EditPredictionUsage, UserStore};
 use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
 use cloud_llm_client::{
-    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
+    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
+    ZED_VERSION_HEADER_NAME,
 };
 use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt};
 use edit_prediction_context::{
@@ -12,7 +13,7 @@ use edit_prediction_context::{
 };
 use futures::AsyncReadExt as _;
 use futures::channel::{mpsc, oneshot};
-use gpui::http_client::Method;
+use gpui::http_client::{AsyncBody, Method};
 use gpui::{
     App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
     http_client, prelude::*,
@@ -22,6 +23,7 @@ use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
 use language_model::{LlmApiToken, RefreshLlmTokenListener};
 use project::Project;
 use release_channel::AppVersion;
+use serde::de::DeserializeOwned;
 use std::collections::{HashMap, VecDeque, hash_map};
 use std::path::Path;
 use std::str::FromStr as _;
@@ -391,11 +393,46 @@ impl Zeta {
         }
     }
 
-    fn accept_current_prediction(&mut self, project: &Entity<Project>) {
-        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
-            project_state.current_prediction.take();
+    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
+        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
+            return;
+        };
+
+        let Some(prediction) = project_state.current_prediction.take() else {
+            return;
         };
-        // TODO report accepted
+        let request_id = prediction.prediction.id.into();
+
+        let client = self.client.clone();
+        let llm_token = self.llm_token.clone();
+        let app_version = AppVersion::global(cx);
+        cx.spawn(async move |this, cx| {
+            let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
+                http_client::Url::parse(&predict_edits_url)?
+            } else {
+                client
+                    .http_client()
+                    .build_zed_llm_url("/predict_edits/accept", &[])?
+            };
+
+            let response = cx
+                .background_spawn(Self::send_api_request::<()>(
+                    move |builder| {
+                        let req = builder.uri(url.as_ref()).body(
+                            serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
+                        );
+                        Ok(req?)
+                    },
+                    client,
+                    llm_token,
+                    app_version,
+                ))
+                .await;
+
+            Self::handle_api_response(&this, response, cx)?;
+            anyhow::Ok(())
+        })
+        .detach_and_log_err(cx);
     }
 
     fn discard_current_prediction(&mut self, project: &Entity<Project>) {
@@ -545,7 +582,7 @@ impl Zeta {
                     &options.context,
                     index_state.as_deref(),
                 ) else {
-                    return Ok(None);
+                    return Ok((None, None));
                 };
 
                 let retrieval_time = chrono::Utc::now() - before_retrieval;
@@ -607,7 +644,8 @@ impl Zeta {
                     anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                 }
 
-                let response = Self::perform_request(client, llm_token, app_version, request).await;
+                let response =
+                    Self::send_prediction_request(client, llm_token, app_version, request).await;
 
                 if let Some(debug_response_tx) = debug_response_tx {
                     debug_response_tx
@@ -620,7 +658,7 @@ impl Zeta {
                         .ok();
                 }
 
-                anyhow::Ok(Some(response?))
+                response.map(|(res, usage)| (Some(res), usage))
             }
         });
 
@@ -629,60 +667,18 @@ impl Zeta {
         cx.spawn({
             let project = project.clone();
             async move |this, cx| {
-                match request_task.await {
-                    Ok(Some((response, usage))) => {
-                        if let Some(usage) = usage {
-                            this.update(cx, |this, cx| {
-                                this.user_store.update(cx, |user_store, cx| {
-                                    user_store.update_edit_prediction_usage(usage, cx);
-                                });
-                            })
-                            .ok();
-                        }
-
-                        let prediction = EditPrediction::from_response(
-                            response, &snapshot, &buffer, &project, cx,
-                        )
-                        .await;
-
-                        // TODO telemetry: duration, etc
-                        Ok(prediction)
-                    }
-                    Ok(None) => Ok(None),
-                    Err(err) => {
-                        if err.is::<ZedUpdateRequiredError>() {
-                            cx.update(|cx| {
-                                this.update(cx, |this, _cx| {
-                                    this.update_required = true;
-                                })
-                                .ok();
-
-                                let error_message: SharedString = err.to_string().into();
-                                show_app_notification(
-                                    NotificationId::unique::<ZedUpdateRequiredError>(),
-                                    cx,
-                                    move |cx| {
-                                        cx.new(|cx| {
-                                            ErrorMessagePrompt::new(error_message.clone(), cx)
-                                                .with_link_button(
-                                                    "Update Zed",
-                                                    "https://zed.dev/releases",
-                                                )
-                                        })
-                                    },
-                                );
-                            })
-                            .ok();
-                        }
+                let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
+                else {
+                    return Ok(None);
+                };
 
-                        Err(err)
-                    }
-                }
+                // TODO telemetry: duration, etc
+                Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
             }
         })
     }
 
-    async fn perform_request(
+    async fn send_prediction_request(
         client: Arc<Client>,
         llm_token: LlmApiToken,
         app_version: SemanticVersion,
@@ -691,27 +687,94 @@ impl Zeta {
         predict_edits_v3::PredictEditsResponse,
         Option<EditPredictionUsage>,
     )> {
+        let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
+            http_client::Url::parse(&predict_edits_url)?
+        } else {
+            client
+                .http_client()
+                .build_zed_llm_url("/predict_edits/v3", &[])?
+        };
+
+        Self::send_api_request(
+            |builder| {
+                let req = builder
+                    .uri(url.as_ref())
+                    .body(serde_json::to_string(&request)?.into());
+                Ok(req?)
+            },
+            client,
+            llm_token,
+            app_version,
+        )
+        .await
+    }
+
+    fn handle_api_response<T>(
+        this: &WeakEntity<Self>,
+        response: Result<(T, Option<EditPredictionUsage>)>,
+        cx: &mut gpui::AsyncApp,
+    ) -> Result<T> {
+        match response {
+            Ok((data, usage)) => {
+                if let Some(usage) = usage {
+                    this.update(cx, |this, cx| {
+                        this.user_store.update(cx, |user_store, cx| {
+                            user_store.update_edit_prediction_usage(usage, cx);
+                        });
+                    })
+                    .ok();
+                }
+                Ok(data)
+            }
+            Err(err) => {
+                if err.is::<ZedUpdateRequiredError>() {
+                    cx.update(|cx| {
+                        this.update(cx, |this, _cx| {
+                            this.update_required = true;
+                        })
+                        .ok();
+
+                        let error_message: SharedString = err.to_string().into();
+                        show_app_notification(
+                            NotificationId::unique::<ZedUpdateRequiredError>(),
+                            cx,
+                            move |cx| {
+                                cx.new(|cx| {
+                                    ErrorMessagePrompt::new(error_message.clone(), cx)
+                                        .with_link_button("Update Zed", "https://zed.dev/releases")
+                                })
+                            },
+                        );
+                    })
+                    .ok();
+                }
+                Err(err)
+            }
+        }
+    }
+
+    async fn send_api_request<Res>(
+        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
+        client: Arc<Client>,
+        llm_token: LlmApiToken,
+        app_version: SemanticVersion,
+    ) -> Result<(Res, Option<EditPredictionUsage>)>
+    where
+        Res: DeserializeOwned,
+    {
         let http_client = client.http_client();
         let mut token = llm_token.acquire(&client).await?;
         let mut did_retry = false;
 
         loop {
             let request_builder = http_client::Request::builder().method(Method::POST);
-            let request_builder =
-                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
-                    request_builder.uri(predict_edits_url)
-                } else {
-                    request_builder.uri(
-                        http_client
-                            .build_zed_llm_url("/predict_edits/v3", &[])?
-                            .as_ref(),
-                    )
-                };
-            let request = request_builder
-                .header("Content-Type", "application/json")
-                .header("Authorization", format!("Bearer {}", token))
-                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
-                .body(serde_json::to_string(&request)?.into())?;
+
+            let request = build(
+                request_builder
+                    .header("Content-Type", "application/json")
+                    .header("Authorization", format!("Bearer {}", token))
+                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
+            )?;
 
             let mut response = http_client.send(request).await?;
 
@@ -746,7 +809,7 @@ impl Zeta {
                 let mut body = String::new();
                 response.body_mut().read_to_string(&mut body).await?;
                 anyhow::bail!(
-                    "error predicting edits.\nStatus: {:?}\nBody: {}",
+                    "Request failed with status: {:?}\nBody: {}",
                     response.status(),
                     body
                 );