@@ -1283,6 +1283,7 @@ struct CurrentEditPrediction {
buffer_id: EntityId,
completion: EditPrediction,
was_shown: bool,
+ was_accepted: bool,
}
impl CurrentEditPrediction {
@@ -1310,7 +1311,7 @@ impl CurrentEditPrediction {
struct PendingCompletion {
id: usize,
- _task: Task<()>,
+ task: Task<()>,
}
#[derive(Debug, Clone, Copy)]
@@ -1386,6 +1387,7 @@ pub struct ZetaEditPredictionProvider {
zeta: Entity<Zeta>,
singleton_buffer: Option<Entity<Buffer>>,
pending_completions: ArrayVec<PendingCompletion, 2>,
+ canceled_completions: HashMap<usize, Task<()>>,
next_pending_completion_id: usize,
current_completion: Option<CurrentEditPrediction>,
last_request_timestamp: Instant,
@@ -1399,17 +1401,34 @@ impl ZetaEditPredictionProvider {
zeta: Entity<Zeta>,
project: Entity<Project>,
singleton_buffer: Option<Entity<Buffer>>,
+ cx: &mut Context<Self>,
) -> Self {
+ cx.on_release(|this, cx| {
+ this.take_current_edit_prediction(cx);
+ })
+ .detach();
+
Self {
zeta,
singleton_buffer,
pending_completions: ArrayVec::new(),
+ canceled_completions: HashMap::default(),
next_pending_completion_id: 0,
current_completion: None,
last_request_timestamp: Instant::now(),
project,
}
}
+
+ fn take_current_edit_prediction(&mut self, cx: &mut App) {
+ if let Some(completion) = self.current_completion.take() {
+ if !completion.was_accepted {
+ self.zeta.update(cx, |zeta, cx| {
+ zeta.discard_completion(completion.completion.id, completion.was_shown, cx);
+ });
+ }
+ }
+ }
}
impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
@@ -1531,42 +1550,65 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
buffer_id: buffer.entity_id(),
completion,
was_shown: false,
+ was_accepted: false,
})
})
}
Err(error) => Err(error),
};
- let Some(new_completion) = completion
- .context("edit prediction failed")
- .log_err()
- .flatten()
- else {
- this.update(cx, |this, cx| {
- if this.pending_completions[0].id == pending_completion_id {
+
+ let discarded = this
+ .update(cx, |this, cx| {
+ if this
+ .pending_completions
+ .first()
+ .is_some_and(|completion| completion.id == pending_completion_id)
+ {
this.pending_completions.remove(0);
} else {
- this.pending_completions.clear();
+ if let Some(discarded) = this.pending_completions.drain(..).next() {
+ this.canceled_completions
+ .insert(discarded.id, discarded.task);
+ }
+ }
+
+ let canceled = this.canceled_completions.remove(&pending_completion_id);
+
+ if canceled.is_some()
+ && let Ok(Some(new_completion)) = &completion
+ {
+ this.zeta.update(cx, |zeta, cx| {
+ zeta.discard_completion(new_completion.completion.id, false, cx);
+ });
+ return true;
}
cx.notify();
+ false
})
- .ok();
+ .ok()
+ .unwrap_or(true);
+
+ if discarded {
+ return;
+ }
+
+ let Some(new_completion) = completion
+ .context("edit prediction failed")
+ .log_err()
+ .flatten()
+ else {
return;
};
this.update(cx, |this, cx| {
- if this.pending_completions[0].id == pending_completion_id {
- this.pending_completions.remove(0);
- } else {
- this.pending_completions.clear();
- }
-
if let Some(old_completion) = this.current_completion.as_ref() {
let snapshot = buffer.read(cx).snapshot();
if new_completion.should_replace_completion(old_completion, &snapshot) {
this.zeta.update(cx, |zeta, cx| {
zeta.completion_shown(&new_completion.completion, cx);
});
+ this.take_current_edit_prediction(cx);
this.current_completion = Some(new_completion);
}
} else {
@@ -1586,13 +1628,16 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
if self.pending_completions.len() <= 1 {
self.pending_completions.push(PendingCompletion {
id: pending_completion_id,
- _task: task,
+ task,
});
} else if self.pending_completions.len() == 2 {
- self.pending_completions.pop();
+ if let Some(discarded) = self.pending_completions.pop() {
+ self.canceled_completions
+ .insert(discarded.id, discarded.task);
+ }
self.pending_completions.push(PendingCompletion {
id: pending_completion_id,
- _task: task,
+ task,
});
}
}
@@ -1608,14 +1653,12 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
}
fn accept(&mut self, cx: &mut Context<Self>) {
- let completion_id = self
- .current_completion
- .as_ref()
- .map(|completion| completion.completion.id);
- if let Some(completion_id) = completion_id {
+ let completion = self.current_completion.as_mut();
+ if let Some(completion) = completion {
+ completion.was_accepted = true;
self.zeta
.update(cx, |zeta, cx| {
- zeta.accept_edit_prediction(completion_id, cx)
+ zeta.accept_edit_prediction(completion.completion.id, cx)
})
.detach();
}
@@ -1624,11 +1667,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
fn discard(&mut self, cx: &mut Context<Self>) {
self.pending_completions.clear();
- if let Some(completion) = self.current_completion.take() {
- self.zeta.update(cx, |zeta, cx| {
- zeta.discard_completion(completion.completion.id, completion.was_shown, cx);
- });
- }
+ self.take_current_edit_prediction(cx);
}
fn did_show(&mut self, _cx: &mut Context<Self>) {
@@ -1651,13 +1690,13 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
// Invalidate previous completion if it was generated for a different buffer.
if *buffer_id != buffer.entity_id() {
- self.current_completion.take();
+ self.take_current_edit_prediction(cx);
return None;
}
let buffer = buffer.read(cx);
let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
- self.current_completion.take();
+ self.take_current_edit_prediction(cx);
return None;
};