@@ -2,6 +2,17 @@ use crate::PredictEditsRequestTrigger;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::ops::Range;
+use strum::{AsRefStr, EnumString};
+
+pub const PREDICT_EDITS_MODE_HEADER_NAME: &str = "X-Zed-Predict-Edits-Mode";
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, AsRefStr, EnumString)]
+#[serde(rename_all = "snake_case")]
+#[strum(serialize_all = "snake_case")]
+pub enum PredictEditsMode {
+ Eager,
+ Subtle,
+}
#[derive(Debug, Deserialize, Serialize)]
pub struct RawCompletionRequest {
@@ -3,7 +3,8 @@ use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, globa
use cloud_api_client::LlmApiToken;
use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
use cloud_llm_client::predict_edits_v3::{
- PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
+ PREDICT_EDITS_MODE_HEADER_NAME, PredictEditsMode, PredictEditsV3Request,
+ PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
};
use cloud_llm_client::{
EditPredictionRejectReason, EditPredictionRejection,
@@ -29,9 +30,10 @@ use gpui::{
prelude::*,
};
use heapless::Vec as ArrayVec;
-use language::language_settings::all_language_settings;
-use language::{Anchor, Buffer, EditPreview, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
-use language::{BufferSnapshot, OffsetRangeExt};
+use language::{
+ Anchor, Buffer, BufferSnapshot, EditPredictionsMode, EditPreview, File, OffsetRangeExt, Point,
+ TextBufferSnapshot, ToOffset, ToPoint, language_settings::all_language_settings,
+};
use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
use release_channel::AppVersion;
use semver::Version;
@@ -176,6 +178,7 @@ pub struct EditPredictionModelInput {
position: Anchor,
events: Vec<Arc<zeta_prompt::Event>>,
related_files: Vec<RelatedFile>,
+ mode: PredictEditsMode,
trigger: PredictEditsRequestTrigger,
diagnostic_search_range: Range<Point>,
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
@@ -2366,6 +2369,10 @@ impl EditPredictionStore {
Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
let related_files = self.context_for_project(&project, cx);
+ let mode = match all_language_settings(snapshot.file(), cx).edit_predictions_mode() {
+ EditPredictionsMode::Eager => PredictEditsMode::Eager,
+ EditPredictionsMode::Subtle => PredictEditsMode::Subtle,
+ };
let is_open_source = snapshot
.file()
@@ -2377,7 +2384,6 @@ impl EditPredictionStore {
&& is_open_source
&& self.is_data_collection_enabled(cx)
&& matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
-
let inputs = EditPredictionModelInput {
project: project.clone(),
buffer: active_buffer,
@@ -2385,8 +2391,9 @@ impl EditPredictionStore {
position,
events,
related_files,
+ mode,
trigger,
- diagnostic_search_range: diagnostic_search_range,
+ diagnostic_search_range,
debug_tx,
can_collect_data,
is_open_source,
@@ -2584,6 +2591,7 @@ impl EditPredictionStore {
organization_id: Option<OrganizationId>,
app_version: Version,
trigger: PredictEditsRequestTrigger,
+ mode: PredictEditsMode,
) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
let url = client
.http_client()
@@ -2599,6 +2607,7 @@ impl EditPredictionStore {
let req = builder
.uri(url.as_ref())
.header("Content-Encoding", "zstd")
+ .header(PREDICT_EDITS_MODE_HEADER_NAME, mode.as_ref())
.body(compressed.clone().into());
Ok(req?)
},
@@ -43,6 +43,7 @@ pub fn request_prediction_with_zeta(
related_files,
events,
debug_tx,
+ mode,
trigger,
project,
diagnostic_search_range,
@@ -278,6 +279,7 @@ pub fn request_prediction_with_zeta(
organization_id,
app_version,
trigger,
+ mode,
)
.await?;