ep: Send edit prediction mode in prediction request (#53812)

Ben Kunkle created

Self-Review Checklist:

- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [x] The content is consistent with the [UI/UX
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [x] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable

Closes #ISSUE

Release Notes:

- N/A or Added/Fixed/Improved ...

Change summary

crates/cloud_llm_client/src/predict_edits_v3.rs | 11 +++++++++
crates/edit_prediction/src/edit_prediction.rs   | 21 +++++++++++++-----
crates/edit_prediction/src/zeta.rs              |  2 +
3 files changed, 28 insertions(+), 6 deletions(-)

Detailed changes

crates/cloud_llm_client/src/predict_edits_v3.rs 🔗

@@ -2,6 +2,17 @@ use crate::PredictEditsRequestTrigger;
 use serde::{Deserialize, Serialize};
 use std::borrow::Cow;
 use std::ops::Range;
+use strum::{AsRefStr, EnumString};
+
+pub const PREDICT_EDITS_MODE_HEADER_NAME: &str = "X-Zed-Predict-Edits-Mode";
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, AsRefStr, EnumString)]
+#[serde(rename_all = "snake_case")]
+#[strum(serialize_all = "snake_case")]
+pub enum PredictEditsMode {
+    Eager,
+    Subtle,
+}
 
 #[derive(Debug, Deserialize, Serialize)]
 pub struct RawCompletionRequest {

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -3,7 +3,8 @@ use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, globa
 use cloud_api_client::LlmApiToken;
 use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
 use cloud_llm_client::predict_edits_v3::{
-    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
+    PREDICT_EDITS_MODE_HEADER_NAME, PredictEditsMode, PredictEditsV3Request,
+    PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
 };
 use cloud_llm_client::{
     EditPredictionRejectReason, EditPredictionRejection,
@@ -29,9 +30,10 @@ use gpui::{
     prelude::*,
 };
 use heapless::Vec as ArrayVec;
-use language::language_settings::all_language_settings;
-use language::{Anchor, Buffer, EditPreview, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
-use language::{BufferSnapshot, OffsetRangeExt};
+use language::{
+    Anchor, Buffer, BufferSnapshot, EditPredictionsMode, EditPreview, File, OffsetRangeExt, Point,
+    TextBufferSnapshot, ToOffset, ToPoint, language_settings::all_language_settings,
+};
 use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
 use release_channel::AppVersion;
 use semver::Version;
@@ -176,6 +178,7 @@ pub struct EditPredictionModelInput {
     position: Anchor,
     events: Vec<Arc<zeta_prompt::Event>>,
     related_files: Vec<RelatedFile>,
+    mode: PredictEditsMode,
     trigger: PredictEditsRequestTrigger,
     diagnostic_search_range: Range<Point>,
     debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
@@ -2366,6 +2369,10 @@ impl EditPredictionStore {
             Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
 
         let related_files = self.context_for_project(&project, cx);
+        let mode = match all_language_settings(snapshot.file(), cx).edit_predictions_mode() {
+            EditPredictionsMode::Eager => PredictEditsMode::Eager,
+            EditPredictionsMode::Subtle => PredictEditsMode::Subtle,
+        };
 
         let is_open_source = snapshot
             .file()
@@ -2377,7 +2384,6 @@ impl EditPredictionStore {
             && is_open_source
             && self.is_data_collection_enabled(cx)
             && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
-
         let inputs = EditPredictionModelInput {
             project: project.clone(),
             buffer: active_buffer,
@@ -2385,8 +2391,9 @@ impl EditPredictionStore {
             position,
             events,
             related_files,
+            mode,
             trigger,
-            diagnostic_search_range: diagnostic_search_range,
+            diagnostic_search_range,
             debug_tx,
             can_collect_data,
             is_open_source,
@@ -2584,6 +2591,7 @@ impl EditPredictionStore {
         organization_id: Option<OrganizationId>,
         app_version: Version,
         trigger: PredictEditsRequestTrigger,
+        mode: PredictEditsMode,
     ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
         let url = client
             .http_client()
@@ -2599,6 +2607,7 @@ impl EditPredictionStore {
                 let req = builder
                     .uri(url.as_ref())
                     .header("Content-Encoding", "zstd")
+                    .header(PREDICT_EDITS_MODE_HEADER_NAME, mode.as_ref())
                     .body(compressed.clone().into());
                 Ok(req?)
             },

crates/edit_prediction/src/zeta.rs 🔗

@@ -43,6 +43,7 @@ pub fn request_prediction_with_zeta(
         related_files,
         events,
         debug_tx,
+        mode,
         trigger,
         project,
         diagnostic_search_range,
@@ -278,6 +279,7 @@ pub fn request_prediction_with_zeta(
                         organization_id,
                         app_version,
                         trigger,
+                        mode,
                     )
                     .await?;