@@ -58,6 +58,9 @@ pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
/// The name of the header used by the client to indicate that it supports receiving xAI models.
pub const CLIENT_SUPPORTS_X_AI_HEADER_NAME: &str = "x-zed-client-supports-x-ai";
+/// The maximum number of edit predictions that can be rejected per request.
+pub const MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST: usize = 100;
+
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum UsageLimit {
@@ -192,6 +195,17 @@ pub struct AcceptEditPredictionBody {
pub request_id: String,
}
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct RejectEditPredictionsBody {
+ pub rejections: Vec<EditPredictionRejection>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct EditPredictionRejection {
+ pub request_id: String,
+ pub was_shown: bool,
+}
+
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionMode {
@@ -104,6 +104,7 @@ pub trait EditPredictionProvider: 'static + Sized {
);
fn accept(&mut self, cx: &mut Context<Self>);
fn discard(&mut self, cx: &mut Context<Self>);
+ fn did_show(&mut self, _cx: &mut Context<Self>) {}
fn suggest(
&mut self,
buffer: &Entity<Buffer>,
@@ -142,6 +143,7 @@ pub trait EditPredictionProviderHandle {
direction: Direction,
cx: &mut App,
);
+ fn did_show(&self, cx: &mut App);
fn accept(&self, cx: &mut App);
fn discard(&self, cx: &mut App);
fn suggest(
@@ -233,6 +235,10 @@ where
self.update(cx, |this, cx| this.discard(cx))
}
+ fn did_show(&self, cx: &mut App) {
+ self.update(cx, |this, cx| this.did_show(cx))
+ }
+
fn suggest(
&self,
buffer: &Entity<Buffer>,
@@ -8,7 +8,9 @@ mod rate_completion_modal;
pub(crate) use completion_diff_element::*;
use db::kvp::{Dismissable, KEY_VALUE_STORE};
+use db::smol::stream::StreamExt as _;
use edit_prediction::DataCollectionState;
+use futures::channel::mpsc;
pub use init::*;
use license_detection::LicenseDetectionWatcher;
pub use rate_completion_modal::*;
@@ -17,8 +19,10 @@ use anyhow::{Context as _, Result, anyhow};
use arrayvec::ArrayVec;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::{
- AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
- PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, ZED_VERSION_HEADER_NAME,
+ AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejection,
+ MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
+ PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, RejectEditPredictionsBody,
+ ZED_VERSION_HEADER_NAME,
};
use collections::{HashMap, HashSet, VecDeque};
use futures::AsyncReadExt;
@@ -171,12 +175,15 @@ pub struct Zeta {
shown_completions: VecDeque<EditPrediction>,
rated_completions: HashSet<EditPredictionId>,
data_collection_choice: DataCollectionChoice,
+ discarded_completions: Vec<EditPredictionRejection>,
llm_token: LlmApiToken,
_llm_token_subscription: Subscription,
/// Whether an update to a newer version of Zed is required to continue using Zeta.
update_required: bool,
user_store: Entity<UserStore>,
license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
+ discard_completions_debounce_task: Option<Task<()>>,
+ discard_completions_tx: mpsc::UnboundedSender<()>,
}
struct ZetaProject {
@@ -226,11 +233,25 @@ impl Zeta {
fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
let data_collection_choice = Self::load_data_collection_choice();
+ let (reject_tx, mut reject_rx) = mpsc::unbounded();
+ cx.spawn(async move |this, cx| {
+ while let Some(()) = reject_rx.next().await {
+ this.update(cx, |this, cx| this.reject_edit_predictions(cx))?
+ .await
+ .log_err();
+ }
+ anyhow::Ok(())
+ })
+ .detach();
+
Self {
projects: HashMap::default(),
client,
shown_completions: VecDeque::new(),
rated_completions: HashSet::default(),
+ discarded_completions: Vec::new(),
+ discard_completions_debounce_task: None,
+ discard_completions_tx: reject_tx,
data_collection_choice,
llm_token: LlmApiToken::default(),
_llm_token_subscription: cx.subscribe(
@@ -692,6 +713,75 @@ impl Zeta {
})
}
+ fn reject_edit_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let client = self.client.clone();
+ let llm_token = self.llm_token.clone();
+ let app_version = AppVersion::global(cx);
+ let last_rejection = self.discarded_completions.last().cloned();
+ let body = serde_json::to_string(&RejectEditPredictionsBody {
+ rejections: self.discarded_completions.clone(),
+ })
+ .ok();
+
+ let Some(last_rejection) = last_rejection else {
+ return Task::ready(anyhow::Ok(()));
+ };
+
+ cx.spawn(async move |this, cx| {
+ let http_client = client.http_client();
+ let mut response = llm_token_retry(&llm_token, &client, |token| {
+ let request_builder = http_client::Request::builder().method(Method::POST);
+ let request_builder = request_builder.uri(
+ http_client
+ .build_zed_llm_url("/predict_edits/reject", &[])?
+ .as_ref(),
+ );
+ Ok(request_builder
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", token))
+ .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
+ .body(
+ body.as_ref()
+ .context("failed to serialize body")?
+ .clone()
+ .into(),
+ )?)
+ })
+ .await?;
+
+ if let Some(minimum_required_version) = response
+ .headers()
+ .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
+ .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
+ && app_version < minimum_required_version
+ {
+ return Err(anyhow!(ZedUpdateRequiredError {
+ minimum_version: minimum_required_version
+ }));
+ }
+
+ if response.status().is_success() {
+ this.update(cx, |this, _| {
+ if let Some(ix) = this
+ .discarded_completions
+ .iter()
+ .position(|rejection| rejection.request_id == last_rejection.request_id)
+ {
+ this.discarded_completions.drain(..ix + 1);
+ }
+ })
+ } else {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ Err(anyhow!(
+ "error rejecting edit predictions.\nStatus: {:?}\nBody: {}",
+ response.status(),
+ body
+ ))
+ }
+ })
+ }
+
fn process_completion_response(
prediction_response: PredictEditsResponse,
buffer: Entity<Buffer>,
@@ -995,6 +1085,31 @@ impl Zeta {
)
});
}
+
+ fn discard_completion(
+ &mut self,
+ completion_id: EditPredictionId,
+ was_shown: bool,
+ cx: &mut Context<Self>,
+ ) {
+ self.discarded_completions.push(EditPredictionRejection {
+ request_id: completion_id.to_string(),
+ was_shown,
+ });
+
+ let reached_request_limit =
+ self.discarded_completions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST;
+ let discard_completions_tx = self.discard_completions_tx.clone();
+ self.discard_completions_debounce_task = Some(cx.spawn(async move |_this, cx| {
+ const DISCARD_COMPLETIONS_DEBOUNCE: Duration = Duration::from_secs(15);
+ if !reached_request_limit {
+ cx.background_executor()
+ .timer(DISCARD_COMPLETIONS_DEBOUNCE)
+ .await;
+ }
+ discard_completions_tx.unbounded_send(()).log_err();
+ }));
+ }
}
pub struct PerformPredictEditsParams {
@@ -1167,6 +1282,7 @@ impl Event {
struct CurrentEditPrediction {
buffer_id: EntityId,
completion: EditPrediction,
+ was_shown: bool,
}
impl CurrentEditPrediction {
@@ -1414,6 +1530,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
c.map(|completion| CurrentEditPrediction {
buffer_id: buffer.entity_id(),
completion,
+ was_shown: false,
})
})
}
@@ -1505,9 +1622,19 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
self.pending_completions.clear();
}
- fn discard(&mut self, _cx: &mut Context<Self>) {
+ fn discard(&mut self, cx: &mut Context<Self>) {
self.pending_completions.clear();
- self.current_completion.take();
+ if let Some(completion) = self.current_completion.take() {
+ self.zeta.update(cx, |zeta, cx| {
+ zeta.discard_completion(completion.completion.id, completion.was_shown, cx);
+ });
+ }
+ }
+
+ fn did_show(&mut self, _cx: &mut Context<Self>) {
+ if let Some(current_completion) = self.current_completion.as_mut() {
+ current_completion.was_shown = true;
+ }
}
fn suggest(