diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index ff8275fe40eae6945691a7b8d315414617be0235..241e760887cdf0c4455f6769c79a813de0626028 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -58,6 +58,9 @@ pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str = /// The name of the header used by the client to indicate that it supports receiving xAI models. pub const CLIENT_SUPPORTS_X_AI_HEADER_NAME: &str = "x-zed-client-supports-x-ai"; +/// The maximum number of edit predictions that can be rejected per request. +pub const MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST: usize = 100; + #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum UsageLimit { @@ -192,6 +195,17 @@ pub struct AcceptEditPredictionBody { pub request_id: String, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RejectEditPredictionsBody { + pub rejections: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EditPredictionRejection { + pub request_id: String, + pub was_shown: bool, +} + #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum CompletionMode { diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index c9bb0672a0c9cb7c56c3c703b0e10594d56cc0c1..aebfa5e5229ef1fec50f2d9cf74e354878ddc1c5 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -104,6 +104,7 @@ pub trait EditPredictionProvider: 'static + Sized { ); fn accept(&mut self, cx: &mut Context); fn discard(&mut self, cx: &mut Context); + fn did_show(&mut self, _cx: &mut Context) {} fn suggest( &mut self, buffer: &Entity, @@ -142,6 +143,7 @@ pub trait EditPredictionProviderHandle { direction: Direction, cx: &mut App, ); + fn did_show(&self, cx: &mut App); fn accept(&self, cx: &mut App); fn discard(&self, cx: &mut App); fn suggest( @@ -233,6 +235,10 @@ where self.update(cx, |this, cx| this.discard(cx)) } + fn did_show(&self, cx: &mut App) { + self.update(cx, |this, cx| this.did_show(cx)) + } + fn suggest( &self, buffer: &Entity, diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 2da80a405a0db357712039f06d10c9e6b33e05c8..057d0b223bb43b41c863316010648005b3675119 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -7865,6 +7865,10 @@ impl Editor { self.edit_prediction_preview, EditPredictionPreview::Inactive { .. } ) { + if let Some(provider) = self.edit_prediction_provider.as_ref() { + provider.provider.did_show(cx) + } + self.edit_prediction_preview = EditPredictionPreview::Active { previous_scroll_position: None, since: Instant::now(), @@ -8044,6 +8048,9 @@ impl Editor { && !self.edit_predictions_hidden_for_vim_mode; if show_completions_in_buffer { + if let Some(provider) = &self.edit_prediction_provider { + provider.provider.did_show(cx); + } if edits .iter() .all(|(range, _)| range.to_offset(&multibuffer).is_empty()) diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 708a53ff47bd2c60e6b9620e8bed30b16419ba14..577ca77c13c0b9f8e0eff578c20d0a933c858bce 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -8,7 +8,9 @@ mod rate_completion_modal; pub(crate) use completion_diff_element::*; use db::kvp::{Dismissable, KEY_VALUE_STORE}; +use db::smol::stream::StreamExt as _; use edit_prediction::DataCollectionState; +use futures::channel::mpsc; pub use init::*; use license_detection::LicenseDetectionWatcher; pub use rate_completion_modal::*; @@ -17,8 +19,10 @@ use anyhow::{Context as _, Result, anyhow}; use arrayvec::ArrayVec; use client::{Client, EditPredictionUsage, UserStore}; use cloud_llm_client::{ - AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, - PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, ZED_VERSION_HEADER_NAME, + AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejection, + MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME, + PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, RejectEditPredictionsBody, + ZED_VERSION_HEADER_NAME, }; use collections::{HashMap, HashSet, VecDeque}; use futures::AsyncReadExt; @@ -171,12 +175,15 @@ pub struct Zeta { shown_completions: VecDeque, rated_completions: HashSet, data_collection_choice: DataCollectionChoice, + discarded_completions: Vec, llm_token: LlmApiToken, _llm_token_subscription: Subscription, /// Whether an update to a newer version of Zed is required to continue using Zeta. update_required: bool, user_store: Entity, license_detection_watchers: HashMap>, + discard_completions_debounce_task: Option>, + discard_completions_tx: mpsc::UnboundedSender<()>, } struct ZetaProject { @@ -226,11 +233,25 @@ impl Zeta { fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); let data_collection_choice = Self::load_data_collection_choice(); + let (reject_tx, mut reject_rx) = mpsc::unbounded(); + cx.spawn(async move |this, cx| { + while let Some(()) = reject_rx.next().await { + this.update(cx, |this, cx| this.reject_edit_predictions(cx))? + .await + .log_err(); + } + anyhow::Ok(()) + }) + .detach(); + Self { projects: HashMap::default(), client, shown_completions: VecDeque::new(), rated_completions: HashSet::default(), + discarded_completions: Vec::new(), + discard_completions_debounce_task: None, + discard_completions_tx: reject_tx, data_collection_choice, llm_token: LlmApiToken::default(), _llm_token_subscription: cx.subscribe( @@ -692,6 +713,75 @@ impl Zeta { }) } + fn reject_edit_predictions(&mut self, cx: &mut Context) -> Task> { + let client = self.client.clone(); + let llm_token = self.llm_token.clone(); + let app_version = AppVersion::global(cx); + let last_rejection = self.discarded_completions.last().cloned(); + let body = serde_json::to_string(&RejectEditPredictionsBody { + rejections: self.discarded_completions.clone(), + }) + .ok(); + + let Some(last_rejection) = last_rejection else { + return Task::ready(anyhow::Ok(())); + }; + + cx.spawn(async move |this, cx| { + let http_client = client.http_client(); + let mut response = llm_token_retry(&llm_token, &client, |token| { + let request_builder = http_client::Request::builder().method(Method::POST); + let request_builder = request_builder.uri( + http_client + .build_zed_llm_url("/predict_edits/reject", &[])? + .as_ref(), + ); + Ok(request_builder + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", token)) + .header(ZED_VERSION_HEADER_NAME, app_version.to_string()) + .body( + body.as_ref() + .context("failed to serialize body")? + .clone() + .into(), + )?) + }) + .await?; + + if let Some(minimum_required_version) = response + .headers() + .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME) + .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok()) + && app_version < minimum_required_version + { + return Err(anyhow!(ZedUpdateRequiredError { + minimum_version: minimum_required_version + })); + } + + if response.status().is_success() { + this.update(cx, |this, _| { + if let Some(ix) = this + .discarded_completions + .iter() + .position(|rejection| rejection.request_id == last_rejection.request_id) + { + this.discarded_completions.drain(..ix + 1); + } + }) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + Err(anyhow!( + "error rejecting edit predictions.\nStatus: {:?}\nBody: {}", + response.status(), + body + )) + } + }) + } + fn process_completion_response( prediction_response: PredictEditsResponse, buffer: Entity, @@ -995,6 +1085,31 @@ impl Zeta { ) }); } + + fn discard_completion( + &mut self, + completion_id: EditPredictionId, + was_shown: bool, + cx: &mut Context, + ) { + self.discarded_completions.push(EditPredictionRejection { + request_id: completion_id.to_string(), + was_shown, + }); + + let reached_request_limit = + self.discarded_completions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST; + let discard_completions_tx = self.discard_completions_tx.clone(); + self.discard_completions_debounce_task = Some(cx.spawn(async move |_this, cx| { + const DISCARD_COMPLETIONS_DEBOUNCE: Duration = Duration::from_secs(15); + if !reached_request_limit { + cx.background_executor() + .timer(DISCARD_COMPLETIONS_DEBOUNCE) + .await; + } + discard_completions_tx.unbounded_send(()).log_err(); + })); + } } pub struct PerformPredictEditsParams { @@ -1167,6 +1282,7 @@ impl Event { struct CurrentEditPrediction { buffer_id: EntityId, completion: EditPrediction, + was_shown: bool, } impl CurrentEditPrediction { @@ -1414,6 +1530,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { c.map(|completion| CurrentEditPrediction { buffer_id: buffer.entity_id(), completion, + was_shown: false, }) }) } @@ -1505,9 +1622,19 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { self.pending_completions.clear(); } - fn discard(&mut self, _cx: &mut Context) { + fn discard(&mut self, cx: &mut Context) { self.pending_completions.clear(); - self.current_completion.take(); + if let Some(completion) = self.current_completion.take() { + self.zeta.update(cx, |zeta, cx| { + zeta.discard_completion(completion.completion.id, completion.was_shown, cx); + }); + } + } + + fn did_show(&mut self, _cx: &mut Context) { + if let Some(current_completion) = self.current_completion.as_mut() { + current_completion.was_shown = true; + } } fn suggest(