From d3d8f1500d43685b02379f7dc132a4ba21550107 Mon Sep 17 00:00:00 2001 From: Ben Kunkle Date: Wed, 15 Apr 2026 01:22:14 -0500 Subject: [PATCH] ep: Send edit prediction mode in prediction request (#53812) Adds a `PredictEditsMode` enum (`eager`/`subtle`) and sends the mode from the user's `edit_predictions_mode` language setting as an `X-Zed-Predict-Edits-Mode` header on predict-edits v3 requests. Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Release Notes: - N/A --- .../cloud_llm_client/src/predict_edits_v3.rs | 11 ++++++++++ crates/edit_prediction/src/edit_prediction.rs | 21 +++++++++++++------ crates/edit_prediction/src/zeta.rs | 2 ++ 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index 5002c1a770ec1955d2a96c97098867f20f9bd05d..36c091a3100844872a8ef29bf6ddd37222374d99 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -2,6 +2,17 @@ use crate::PredictEditsRequestTrigger; use serde::{Deserialize, Serialize}; use std::borrow::Cow; use std::ops::Range; +use strum::{AsRefStr, EnumString}; + +pub const PREDICT_EDITS_MODE_HEADER_NAME: &str = "X-Zed-Predict-Edits-Mode"; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, AsRefStr, EnumString)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum PredictEditsMode { + Eager, + Subtle, +} #[derive(Debug, Deserialize, Serialize)] pub struct RawCompletionRequest { diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 6bca0a1639d47d09a94b650bc59ad790dbdcbf46..07ec5366db8c2d5f84c53f8ccfe44f84e393df6c 100644 --- 
a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -3,7 +3,8 @@ use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, globa use cloud_api_client::LlmApiToken; use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody}; use cloud_llm_client::predict_edits_v3::{ - PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse, + PREDICT_EDITS_MODE_HEADER_NAME, PredictEditsMode, PredictEditsV3Request, + PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse, }; use cloud_llm_client::{ EditPredictionRejectReason, EditPredictionRejection, @@ -29,9 +30,10 @@ use gpui::{ prelude::*, }; use heapless::Vec as ArrayVec; -use language::language_settings::all_language_settings; -use language::{Anchor, Buffer, EditPreview, File, Point, TextBufferSnapshot, ToOffset, ToPoint}; -use language::{BufferSnapshot, OffsetRangeExt}; +use language::{ + Anchor, Buffer, BufferSnapshot, EditPredictionsMode, EditPreview, File, OffsetRangeExt, Point, + TextBufferSnapshot, ToOffset, ToPoint, language_settings::all_language_settings, +}; use project::{DisableAiSettings, Project, ProjectPath, WorktreeId}; use release_channel::AppVersion; use semver::Version; @@ -176,6 +178,7 @@ pub struct EditPredictionModelInput { position: Anchor, events: Vec>, related_files: Vec, + mode: PredictEditsMode, trigger: PredictEditsRequestTrigger, diagnostic_search_range: Range, debug_tx: Option>, @@ -2366,6 +2369,10 @@ impl EditPredictionStore { Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0); let related_files = self.context_for_project(&project, cx); + let mode = match all_language_settings(snapshot.file(), cx).edit_predictions_mode() { + EditPredictionsMode::Eager => PredictEditsMode::Eager, + EditPredictionsMode::Subtle => PredictEditsMode::Subtle, + }; let is_open_source = snapshot .file() @@ -2377,7 +2384,6 @@ impl EditPredictionStore { && is_open_source && 
self.is_data_collection_enabled(cx) && matches!(self.edit_prediction_model, EditPredictionModel::Zeta); - let inputs = EditPredictionModelInput { project: project.clone(), buffer: active_buffer, @@ -2385,8 +2391,9 @@ impl EditPredictionStore { position, events, related_files, + mode, trigger, - diagnostic_search_range: diagnostic_search_range, + diagnostic_search_range, debug_tx, can_collect_data, is_open_source, @@ -2584,6 +2591,7 @@ impl EditPredictionStore { organization_id: Option, app_version: Version, trigger: PredictEditsRequestTrigger, + mode: PredictEditsMode, ) -> Result<(PredictEditsV3Response, Option)> { let url = client .http_client() @@ -2599,6 +2607,7 @@ impl EditPredictionStore { let req = builder .uri(url.as_ref()) .header("Content-Encoding", "zstd") + .header(PREDICT_EDITS_MODE_HEADER_NAME, mode.as_ref()) .body(compressed.clone().into()); Ok(req?) }, diff --git a/crates/edit_prediction/src/zeta.rs b/crates/edit_prediction/src/zeta.rs index 1173cd047a93253add13da946f02cbccb8da55f9..1674de5c0a71cf9a63d2e1fc55a58645b9a9314a 100644 --- a/crates/edit_prediction/src/zeta.rs +++ b/crates/edit_prediction/src/zeta.rs @@ -43,6 +43,7 @@ pub fn request_prediction_with_zeta( related_files, events, debug_tx, + mode, trigger, project, diagnostic_search_range, @@ -278,6 +279,7 @@ pub fn request_prediction_with_zeta( organization_id, app_version, trigger, + mode, ) .await?;