edit prediction: Request trigger (#43588)

Agus Zubiaga created

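Adds a `trigger` field to the zeta1/zeta2 prediction requests so that we
can distinguish between editor, diagnostics-driven, and zeta-cli requests.
These are sent as `Other` (the default when the field is absent),
`Diagnostics`, and `Cli` respectively.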

Release Notes:

- N/A

Change summary

crates/cloud_llm_client/src/cloud_llm_client.rs | 11 +++++
crates/cloud_llm_client/src/predict_edits_v3.rs |  4 +
crates/zeta/src/zeta.rs                         | 35 ++++++++++++++++--
crates/zeta/src/zeta1.rs                        |  7 +++
crates/zeta/src/zeta_tests.rs                   |  2 
crates/zeta_cli/src/main.rs                     |  1 
crates/zeta_cli/src/predict.rs                  |  8 +++
7 files changed, 59 insertions(+), 9 deletions(-)

Detailed changes

crates/cloud_llm_client/src/cloud_llm_client.rs 🔗

@@ -169,6 +169,17 @@ pub struct PredictEditsBody {
     /// Info about the git repository state, only present when can_collect_data is true.
     #[serde(skip_serializing_if = "Option::is_none", default)]
     pub git_info: Option<PredictEditsGitInfo>,
+    /// The trigger for this request.
+    #[serde(default)]
+    pub trigger: PredictEditsRequestTrigger,
+}
+
+#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
+pub enum PredictEditsRequestTrigger {
+    Diagnostics,
+    Cli,
+    #[default]
+    Other,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]

crates/cloud_llm_client/src/predict_edits_v3.rs 🔗

@@ -9,7 +9,7 @@ use std::{
 use strum::EnumIter;
 use uuid::Uuid;
 
-use crate::PredictEditsGitInfo;
+use crate::{PredictEditsGitInfo, PredictEditsRequestTrigger};
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct PlanContextRetrievalRequest {
@@ -53,6 +53,8 @@ pub struct PredictEditsRequest {
     pub prompt_max_bytes: Option<usize>,
     #[serde(default)]
     pub prompt_format: PromptFormat,
+    #[serde(default)]
+    pub trigger: PredictEditsRequestTrigger,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]

crates/zeta/src/zeta.rs 🔗

@@ -5,7 +5,8 @@ use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
 use cloud_llm_client::{
     AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
     EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
-    MINIMUM_REQUIRED_VERSION_HEADER_NAME, RejectEditPredictionsBody, ZED_VERSION_HEADER_NAME,
+    MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBody,
+    ZED_VERSION_HEADER_NAME,
 };
 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
 use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
@@ -1016,7 +1017,13 @@ impl Zeta {
         self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
             let Some(request_task) = this
                 .update(cx, |this, cx| {
-                    this.request_prediction(&project, &buffer, position, cx)
+                    this.request_prediction(
+                        &project,
+                        &buffer,
+                        position,
+                        PredictEditsRequestTrigger::Other,
+                        cx,
+                    )
                 })
                 .log_err()
             else {
@@ -1083,7 +1090,13 @@ impl Zeta {
 
                 let Some(prediction_result) = this
                     .update(cx, |this, cx| {
-                        this.request_prediction(&project, &jump_buffer, jump_position, cx)
+                        this.request_prediction(
+                            &project,
+                            &jump_buffer,
+                            jump_position,
+                            PredictEditsRequestTrigger::Diagnostics,
+                            cx,
+                        )
                     })?
                     .await?
                 else {
@@ -1264,12 +1277,14 @@ impl Zeta {
         project: &Entity<Project>,
         active_buffer: &Entity<Buffer>,
         position: language::Anchor,
+        trigger: PredictEditsRequestTrigger,
         cx: &mut Context<Self>,
     ) -> Task<Result<Option<EditPredictionResult>>> {
         self.request_prediction_internal(
             project.clone(),
             active_buffer.clone(),
             position,
+            trigger,
             cx.has_flag::<Zeta2FeatureFlag>(),
             cx,
         )
@@ -1280,6 +1295,7 @@ impl Zeta {
         project: Entity<Project>,
         active_buffer: Entity<Buffer>,
         position: language::Anchor,
+        trigger: PredictEditsRequestTrigger,
         allow_jump: bool,
         cx: &mut Context<Self>,
     ) -> Task<Result<Option<EditPredictionResult>>> {
@@ -1305,6 +1321,7 @@ impl Zeta {
                 snapshot.clone(),
                 position,
                 events,
+                trigger,
                 cx,
             ),
             ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
@@ -1313,6 +1330,7 @@ impl Zeta {
                 snapshot.clone(),
                 position,
                 events,
+                trigger,
                 cx,
             ),
             ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
@@ -1349,6 +1367,7 @@ impl Zeta {
                                 project,
                                 jump_buffer,
                                 jump_position,
+                                trigger,
                                 false,
                                 cx,
                             )
@@ -1449,6 +1468,7 @@ impl Zeta {
         active_snapshot: BufferSnapshot,
         position: language::Anchor,
         events: Vec<Arc<Event>>,
+        trigger: PredictEditsRequestTrigger,
         cx: &mut Context<Self>,
     ) -> Task<Result<Option<EditPredictionResult>>> {
         let project_state = self.projects.get(&project.entity_id());
@@ -1621,6 +1641,7 @@ impl Zeta {
                             signatures: vec![],
                             excerpt_parent: None,
                             git_info: None,
+                            trigger,
                         }
                     }
                     ContextMode::Syntax(context_options) => {
@@ -1647,6 +1668,7 @@ impl Zeta {
                             index_state.as_deref(),
                             Some(options.max_prompt_bytes),
                             options.prompt_format,
+                            trigger,
                         )
                     }
                 };
@@ -2416,6 +2438,7 @@ impl Zeta {
                     index_state.as_deref(),
                     Some(options.max_prompt_bytes),
                     options.prompt_format,
+                    PredictEditsRequestTrigger::Other,
                 )
             })
         })
@@ -2574,6 +2597,7 @@ fn make_syntax_context_cloud_request(
     index_state: Option<&SyntaxIndexState>,
     prompt_max_bytes: Option<usize>,
     prompt_format: PromptFormat,
+    trigger: PredictEditsRequestTrigger,
 ) -> predict_edits_v3::PredictEditsRequest {
     let mut signatures = Vec::new();
     let mut declaration_to_signature_index = HashMap::default();
@@ -2653,6 +2677,7 @@ fn make_syntax_context_cloud_request(
         debug_info,
         prompt_max_bytes,
         prompt_format,
+        trigger,
     }
 }
 
@@ -3072,7 +3097,7 @@ mod tests {
         let position = snapshot.anchor_before(language::Point::new(1, 3));
 
         let prediction_task = zeta.update(cx, |zeta, cx| {
-            zeta.request_prediction(&project, &buffer, position, cx)
+            zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
         });
 
         let (_, respond_tx) = requests.predict.next().await.unwrap();
@@ -3145,7 +3170,7 @@ mod tests {
         let position = snapshot.anchor_before(language::Point::new(1, 3));
 
         let prediction_task = zeta.update(cx, |zeta, cx| {
-            zeta.request_prediction(&project, &buffer, position, cx)
+            zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
         });
 
         let (request, respond_tx) = requests.predict.next().await.unwrap();

crates/zeta/src/zeta1.rs 🔗

@@ -8,7 +8,8 @@ use crate::{
 };
 use anyhow::{Context as _, Result};
 use cloud_llm_client::{
-    PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, predict_edits_v3::Event,
+    PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
+    predict_edits_v3::Event,
 };
 use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
 use input_excerpt::excerpt_for_cursor_position;
@@ -35,6 +36,7 @@ pub(crate) fn request_prediction_with_zeta1(
     snapshot: BufferSnapshot,
     position: language::Anchor,
     events: Vec<Arc<Event>>,
+    trigger: PredictEditsRequestTrigger,
     cx: &mut Context<Zeta>,
 ) -> Task<Result<Option<EditPredictionResult>>> {
     let buffer = buffer.clone();
@@ -70,6 +72,7 @@ pub(crate) fn request_prediction_with_zeta1(
         &snapshot,
         cursor_point,
         prompt_for_events,
+        trigger,
         cx,
     );
 
@@ -402,6 +405,7 @@ pub fn gather_context(
     snapshot: &BufferSnapshot,
     cursor_point: language::Point,
     prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
+    trigger: PredictEditsRequestTrigger,
     cx: &App,
 ) -> Task<Result<GatherContextOutput>> {
     cx.background_spawn({
@@ -425,6 +429,7 @@ pub fn gather_context(
                 git_info: None,
                 outline: None,
                 speculated_output: None,
+                trigger,
             };
 
             Ok(GatherContextOutput {

crates/zeta/src/zeta_tests.rs 🔗

@@ -536,7 +536,7 @@ async fn run_edit_prediction(
     zeta.update(cx, |zeta, cx| zeta.register_buffer(buffer, &project, cx));
     cx.background_executor.run_until_parked();
     let prediction_task = zeta.update(cx, |zeta, cx| {
-        zeta.request_prediction(&project, buffer, cursor, cx)
+        zeta.request_prediction(&project, buffer, cursor, Default::default(), cx)
     });
     prediction_task.await.unwrap().unwrap().prediction.unwrap()
 }

crates/zeta_cli/src/main.rs 🔗

@@ -454,6 +454,7 @@ async fn zeta1_context(
             &snapshot,
             clipped_cursor,
             prompt_for_events,
+            cloud_llm_client::PredictEditsRequestTrigger::Cli,
             cx,
         )
     })?

crates/zeta_cli/src/predict.rs 🔗

@@ -226,7 +226,13 @@ pub async fn perform_predict(
 
     let prediction = zeta
         .update(cx, |zeta, cx| {
-            zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
+            zeta.request_prediction(
+                &project,
+                &cursor_buffer,
+                cursor_anchor,
+                cloud_llm_client::PredictEditsRequestTrigger::Cli,
+                cx,
+            )
         })?
         .await?;